diff --git a/.gitattributes b/.gitattributes index f2393ace7c10193d730e9e98485264cacbfa51eb..628226a30d9d140532e3c2070459358042cb3ccc 100644 --- a/.gitattributes +++ b/.gitattributes @@ -785,3 +785,165 @@ lib/python3.10/site-packages/scipy/stats/tests/__pycache__/test_distributions.cp lib/python3.10/site-packages/scipy/stats/tests/__pycache__/test_morestats.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text lib/python3.10/site-packages/scipy/stats/tests/__pycache__/test_multivariate.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text lib/python3.10/site-packages/scipy/stats/tests/__pycache__/test_stats.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/special/__pycache__/_add_newdocs.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/special/__pycache__/_basic.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/sparse/csgraph/_matching.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/sparse/csgraph/_min_spanning_tree.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/sparse/csgraph/_reordering.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/sparse/csgraph/_shortest_path.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/sparse/csgraph/_tools.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/sparse/csgraph/_traversal.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/sparse/tests/__pycache__/test_base.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/sparse/linalg/_dsolve/_superlu.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/sparse/linalg/_propack/_cpropack.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/sparse/linalg/_propack/_dpropack.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/sparse/linalg/_propack/_spropack.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/sparse/linalg/_propack/_zpropack.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/sparse/linalg/_eigen/arpack/_arpack.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/special/tests/__pycache__/test_basic.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/spatial/transform/_rotation.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/sparse/csgraph/_flow.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/optimize/_lsq/givens_elimination.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/signal/__pycache__/_filter_design.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/signal/__pycache__/_signaltools.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/signal/tests/__pycache__/test_filter_design.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text 
+lib/python3.10/site-packages/scipy/signal/tests/__pycache__/test_signaltools.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/optimize/_trlib/_trlib.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/optimize/__pycache__/_optimize.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/interpolate/tests/__pycache__/test_bsplines.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/io/matlab/_mio5_utils.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/io/matlab/_streams.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/io/_fast_matrix_market/_fmm_core.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/optimize/_highspy/_core.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/optimize/_highspy/_highs_options.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/fft/_pocketfft/pypocketfft.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/constants/__pycache__/_codata.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/cluster/__pycache__/hierarchy.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/scipy/integrate/__pycache__/_lebedev.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/fontTools/varLib/iup.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/fontTools/__pycache__/agl.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/fontTools/cu2qu/cu2qu.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/fontTools/feaLib/lexer.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/fontTools/misc/bezierTools.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/fontTools/pens/momentsPen.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/fontTools/qu2cu/qu2cu.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/fontTools/ttLib/tables/__pycache__/otData.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/fontTools/subset/__pycache__/__init__.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/fontTools/otlLib/__pycache__/builder.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/tensorboard/_vendor/html5lib/__pycache__/constants.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/tensorboard_data_server/bin/server filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/taichi/assets/Go-Regular.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/taichi/_lib/core/taichi_python.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/taichi/_lib/runtime/runtime_cuda.bc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/taichi/_lib/runtime/runtime_x64.bc filter=lfs diff=lfs merge=lfs -text 
+lib/python3.10/site-packages/taichi/_lib/runtime/slim_libdevice.10.bc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/taichi/assets/static/imgs/ti_gallery.png filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/core/__pycache__/function.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/core/__pycache__/numbers.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/core/__pycache__/expr.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/core/tests/__pycache__/test_args.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/stats/__pycache__/crv_types.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/utilities/tests/__pycache__/test_wester.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/tensor/__pycache__/tensor.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/polys/__pycache__/polyquinticconst.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/polys/__pycache__/polytools.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/combinatorics/__pycache__/perm_groups.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/solvers/__pycache__/solvers.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/solvers/__pycache__/solveset.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_solvers.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_solveset.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/ode.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/single.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/solvers/ode/tests/__pycache__/test_systems.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/solvers/diophantine/__pycache__/diophantine.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/printing/__pycache__/latex.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_latex.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/printing/pretty/tests/__pycache__/test_pretty.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/domainmatrix.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_spin.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/physics/control/__pycache__/lti.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/physics/continuum_mechanics/__pycache__/beam.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/polys/benchmarks/__pycache__/bench_solvers.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/polys/tests/__pycache__/test_polytools.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text 
+lib/python3.10/site-packages/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_cython.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/debugpy/_vendored/pydevd/_pydevd_frame_eval/pydevd_frame_evaluator.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/logic/__pycache__/boolalg.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/matrices/__pycache__/matrixbase.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/matrices/tests/__pycache__/test_matrices.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/matrices/tests/__pycache__/test_matrixbase.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/debugpy/_vendored/pydevd/pydevd_attach_to_process/winappdbg/__pycache__/breakpoint.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/debugpy/_vendored/pydevd/pydevd_attach_to_process/winappdbg/__pycache__/process.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sympy/parsing/latex/_antlr/__pycache__/latexparser.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/skvideo/datasets/data/bigbuckbunny.mp4 filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/skvideo/datasets/data/bikes.mp4 filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/skvideo/datasets/data/carphone_pristine.mp4 filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/debugpy/_vendored/pydevd/_pydevd_bundle/_debug_adapter/__pycache__/pydevd_schema.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/narwhals/__pycache__/dataframe.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/narwhals/__pycache__/expr.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSans-Bold.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSans-BoldOblique.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSans-ExtraLight.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSans-Oblique.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSans.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSansCondensed-Bold.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSansCondensed-BoldOblique.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSansCondensed-Oblique.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSansCondensed.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/mpl_toolkits/mplot3d/__pycache__/axes3d.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/__pycache__/widgets.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/__pycache__/_cm_listed.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/__pycache__/backend_bases.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/__pycache__/colors.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text 
+lib/python3.10/site-packages/mpmath/__pycache__/function_docs.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/tests/__pycache__/test_axes.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/__pycache__/figure.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/__pycache__/patches.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/__pycache__/pyplot.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/backends/_backend_agg.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/backends/_tkagg.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/axes/__pycache__/_axes.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/axes/__pycache__/_base.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSansMono-Oblique.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSansMono.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSerif-Bold.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSerif-BoldItalic.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSerif-Italic.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSerif.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/STIXGeneralBol.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/STIXGeneral.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/STIXGeneralBolIta.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/STIXGeneralItalic.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans-BoldOblique.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans-Bold.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans-Oblique.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSansMono-Bold.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSansMono-BoldOblique.ttf filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/pygments/lexers/__pycache__/lisp.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/pycparser/__pycache__/yacctab.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/plotly/__pycache__/basedatatypes.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/plotly/graph_objs/__pycache__/_figure.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/plotly/graph_objs/__pycache__/_box.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text 
+lib/python3.10/site-packages/plotly/graph_objs/__pycache__/_figurewidget.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/plotly/graph_objs/__pycache__/_layout.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/plotly/graph_objs/layout/__pycache__/_xaxis.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/plotly/graph_objs/layout/__pycache__/_yaxis.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sklearn/neighbors/_quad_tree.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sklearn/preprocessing/_csr_polynomial_expansion.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sklearn/preprocessing/_target_encoder_fast.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sklearn/svm/_liblinear.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sklearn/svm/_libsvm.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sklearn/svm/_libsvm_sparse.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sklearn/tree/_criterion.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sklearn/tree/_partitioner.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sklearn/tree/_splitter.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sklearn/tree/_tree.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sklearn/tree/_utils.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/pkg_resources/__pycache__/__init__.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sklearn/metrics/_dist_metrics.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sklearn/metrics/_pairwise_fast.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sklearn/neighbors/_ball_tree.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sklearn/neighbors/_kd_tree.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sklearn/utils/_seq_dataset.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sklearn/utils/_typedefs.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +lib/python3.10/site-packages/sklearn/utils/_vector_sentinel.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text diff --git a/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSans-Bold.ttf b/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSans-Bold.ttf new file mode 100644 index 0000000000000000000000000000000000000000..d375dc1dee74b3795888e95e7f7ca2e6eede0602 --- /dev/null +++ b/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSans-Bold.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e49221bbf17c0361274143169f7e6f16d8715f65d49f1d9f216eb3d661400308 +size 672300 diff --git a/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSans-BoldOblique.ttf b/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSans-BoldOblique.ttf new file mode 100644 index 
0000000000000000000000000000000000000000..0d67fa43858c06c7dc3bafa38a8929edd3b56815 --- /dev/null +++ b/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSans-BoldOblique.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:735f3dd7381b6f4e0ca519b72620c75aa953ee8ee89b88cbe4f38ca3c23d6a79 +size 611212 diff --git a/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSans-ExtraLight.ttf b/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSans-ExtraLight.ttf new file mode 100644 index 0000000000000000000000000000000000000000..a104dbf7fa6ac6b1bc1f088544502310c01c8a1c --- /dev/null +++ b/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSans-ExtraLight.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:908d6ec802f28155c8de86192b5a77a9fb41792f072e03526f0536c234f3e9a0 +size 345204 diff --git a/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSans-Oblique.ttf b/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSans-Oblique.ttf new file mode 100644 index 0000000000000000000000000000000000000000..9944cdd6a22da9b76dcb879a7b710c7f6d7114f5 --- /dev/null +++ b/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSans-Oblique.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3bc9c02fefcadd517e5a158b2f34233dd354d67f4302486d88e84bca467d1d43 +size 611556 diff --git a/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSans.ttf b/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSans.ttf new file mode 100644 index 0000000000000000000000000000000000000000..8c80a32611676de4bbf2975f0ba0c5882560f609 --- /dev/null +++ b/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSans.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:15da2d8f12e6950001b1cc8225c1ba72ddce1938837d37702ff3e9bf6d79bd5e +size 720012 diff --git a/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSansCondensed-Bold.ttf b/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSansCondensed-Bold.ttf new file mode 100644 index 0000000000000000000000000000000000000000..9f326bf225091bacbcef4529a2e5240340bf9c4a --- /dev/null +++ b/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSansCondensed-Bold.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f86c6d40a52ffe2b40f19d0b972ca4b9ce347fc04dcfc4d0b4e9277a8712c0dd +size 631992 diff --git a/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSansCondensed-BoldOblique.ttf b/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSansCondensed-BoldOblique.ttf new file mode 100644 index 0000000000000000000000000000000000000000..4cc12324520dd52cb54f8b867f7d241d15a96865 --- /dev/null +++ b/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSansCondensed-BoldOblique.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6657693a18ecefee2667d9d0ecb1abb68524d627999bd8365039c19e04b42381 +size 580168 diff --git a/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSansCondensed-Oblique.ttf b/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSansCondensed-Oblique.ttf new file mode 100644 index 0000000000000000000000000000000000000000..70b3dc1f48d97574e470784c98a15e7511c82750 --- /dev/null +++ b/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSansCondensed-Oblique.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:48f994e81284666ab4bf89ef4d73085b07fae6c2c7e28820ab243e9941c4829e +size 576004 diff --git a/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSansCondensed.ttf b/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSansCondensed.ttf new file mode 100644 index 
0000000000000000000000000000000000000000..0a5db4d95da618635b71569c61ab434c447eeb46 --- /dev/null +++ b/lib/python3.10/site-packages/cv2/qt/fonts/DejaVuSansCondensed.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:69f1355c9eef0a3d11a6c06f3cbf1d46eabfdadcc993589a3be93a44ed8678b4 +size 643852 diff --git a/lib/python3.10/site-packages/dash/labextension/dist/dash-jupyterlab.tgz b/lib/python3.10/site-packages/dash/labextension/dist/dash-jupyterlab.tgz new file mode 100644 index 0000000000000000000000000000000000000000..33e8e06748a28df84826be6d7a1f468cd4d4484e --- /dev/null +++ b/lib/python3.10/site-packages/dash/labextension/dist/dash-jupyterlab.tgz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc668e3cecc33fc68775489e7931ccaff8bfd7105565a6e143e856eb4f40af7e +size 2371 diff --git a/lib/python3.10/site-packages/dateutil/zoneinfo/dateutil-zoneinfo.tar.gz b/lib/python3.10/site-packages/dateutil/zoneinfo/dateutil-zoneinfo.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..ff62a934a5009337271c60501278a7a34913a20b --- /dev/null +++ b/lib/python3.10/site-packages/dateutil/zoneinfo/dateutil-zoneinfo.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d3ea52e7b6e968de0d884df1288193596fa95b803db4f92a18279a7398004475 +size 156400 diff --git a/lib/python3.10/site-packages/debugpy/_vendored/pydevd/_pydevd_bundle/_debug_adapter/__pycache__/pydevd_schema.cpython-310.pyc b/lib/python3.10/site-packages/debugpy/_vendored/pydevd/_pydevd_bundle/_debug_adapter/__pycache__/pydevd_schema.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db9741fa453d21e7a7c15b2f60c6244b1b5af361 --- /dev/null +++ b/lib/python3.10/site-packages/debugpy/_vendored/pydevd/_pydevd_bundle/_debug_adapter/__pycache__/pydevd_schema.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dfdeb33559f7b970ee2f252d9c4a4b517b80d17ba51ff28afc27880053e11440 +size 452096 diff --git a/lib/python3.10/site-packages/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_cython.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_cython.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..4ba35d84da86a38d25934fe1d389c9e0517f7c3a --- /dev/null +++ b/lib/python3.10/site-packages/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_cython.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c61c01ea66354c0856cabdad97267a859b347c5eaa61b17c0e07269f1a82ca8e +size 786048 diff --git a/lib/python3.10/site-packages/debugpy/_vendored/pydevd/_pydevd_frame_eval/pydevd_frame_evaluator.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/debugpy/_vendored/pydevd/_pydevd_frame_eval/pydevd_frame_evaluator.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..c2c9c8a7ca276ca8771473bccaae05a2ec52108a --- /dev/null +++ b/lib/python3.10/site-packages/debugpy/_vendored/pydevd/_pydevd_frame_eval/pydevd_frame_evaluator.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bc5bd9a9dacec855e09b85ca9c8238a8273ef1b4e7faa04b540817d690ff06c1 +size 284920 diff --git a/lib/python3.10/site-packages/debugpy/_vendored/pydevd/pydevd_attach_to_process/winappdbg/__pycache__/breakpoint.cpython-310.pyc 
b/lib/python3.10/site-packages/debugpy/_vendored/pydevd/pydevd_attach_to_process/winappdbg/__pycache__/breakpoint.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7110b0b366a4f097c836ccc3e7aeed95fde2468 --- /dev/null +++ b/lib/python3.10/site-packages/debugpy/_vendored/pydevd/pydevd_attach_to_process/winappdbg/__pycache__/breakpoint.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6455e2e9eb97e314167500d4b63c4e4ae4e86b806f4a3f9658100fc6254cd76f +size 129973 diff --git a/lib/python3.10/site-packages/debugpy/_vendored/pydevd/pydevd_attach_to_process/winappdbg/__pycache__/process.cpython-310.pyc b/lib/python3.10/site-packages/debugpy/_vendored/pydevd/pydevd_attach_to_process/winappdbg/__pycache__/process.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed999d3c6fd89c306b156f9b388804d277f2b88e --- /dev/null +++ b/lib/python3.10/site-packages/debugpy/_vendored/pydevd/pydevd_attach_to_process/winappdbg/__pycache__/process.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a755a0a1ac61641b839eab38d70996ab3f706d5860552907a8b33895034b0c11 +size 133248 diff --git a/lib/python3.10/site-packages/fontTools/__pycache__/agl.cpython-310.pyc b/lib/python3.10/site-packages/fontTools/__pycache__/agl.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72567c1fe2412bc9d3c311e4a04021e0b1e45ec7 --- /dev/null +++ b/lib/python3.10/site-packages/fontTools/__pycache__/agl.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f626ab0321e5e3d4665aa106160800a27baf2ecb883673bed5ffca69758989ea +size 111058 diff --git a/lib/python3.10/site-packages/fontTools/cu2qu/cu2qu.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/fontTools/cu2qu/cu2qu.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..5f79fef6fce65a4beef574eb07e64a6a40e1008d --- /dev/null +++ b/lib/python3.10/site-packages/fontTools/cu2qu/cu2qu.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:87ffb69cda69e138e9379d211599ef21bd0f142fcf400e78bfb187ca510d0fb9 +size 1024424 diff --git a/lib/python3.10/site-packages/fontTools/feaLib/lexer.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/fontTools/feaLib/lexer.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..507896cc102fba8ae63de6e2eb22f9ee0b150caf --- /dev/null +++ b/lib/python3.10/site-packages/fontTools/feaLib/lexer.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1957b028ab8af856ac774507f0dc4755ae2e59cf9243f513ad90678583ce560c +size 1414432 diff --git a/lib/python3.10/site-packages/fontTools/misc/bezierTools.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/fontTools/misc/bezierTools.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..676c61039c62ab1dbe367179a5b044e113bcc934 --- /dev/null +++ b/lib/python3.10/site-packages/fontTools/misc/bezierTools.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0bba19fdd711e02fbc29d772e313ef511939a56a5cee9f88d192e023c6d79624 +size 4714576 diff --git a/lib/python3.10/site-packages/fontTools/otlLib/__pycache__/builder.cpython-310.pyc b/lib/python3.10/site-packages/fontTools/otlLib/__pycache__/builder.cpython-310.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..004160a5664f8b4283c7a29e82cf511d66e00d8e --- /dev/null +++ b/lib/python3.10/site-packages/fontTools/otlLib/__pycache__/builder.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:794277539c08f264633b446a0a352b8b20a2c99ce04fed20f441eb951241c3c1 +size 112587 diff --git a/lib/python3.10/site-packages/fontTools/pens/momentsPen.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/fontTools/pens/momentsPen.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..196f8ca1b517c65876989badf43dfaa7364d5bd5 --- /dev/null +++ b/lib/python3.10/site-packages/fontTools/pens/momentsPen.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0635c9682db43ff193dfb6686186f67a2549c7331f4d58303f0cea8b4b072199 +size 916288 diff --git a/lib/python3.10/site-packages/fontTools/qu2cu/qu2cu.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/fontTools/qu2cu/qu2cu.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..6698f390ab57ad087a0d0c52f3663c8b8bdc749a --- /dev/null +++ b/lib/python3.10/site-packages/fontTools/qu2cu/qu2cu.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2e6e546258484a18ea729ee84f5591143e6b376e79c711454e2de9d23adb4b0f +size 1120912 diff --git a/lib/python3.10/site-packages/fontTools/subset/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/fontTools/subset/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03bb3de6f43f81d44d538fb8fcceed2d12d9d9a1 --- /dev/null +++ b/lib/python3.10/site-packages/fontTools/subset/__pycache__/__init__.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1af5b75709f0c2cc034b2380833401eb611a62de8e2ed5489c2f0aa46f616023 +size 106176 diff --git a/lib/python3.10/site-packages/fontTools/ttLib/tables/__pycache__/otData.cpython-310.pyc b/lib/python3.10/site-packages/fontTools/ttLib/tables/__pycache__/otData.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a26e4b87ced2275e6c852b320079e8f456e6aea7 --- /dev/null +++ b/lib/python3.10/site-packages/fontTools/ttLib/tables/__pycache__/otData.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:876469703f30576fecb1ce49de3651c4b00e546d08f9d19d0b20280a922c7b18 +size 107142 diff --git a/lib/python3.10/site-packages/fontTools/varLib/iup.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/fontTools/varLib/iup.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..dc6aff2966ea6d54c29ad9204c15b1c0ec10e720 --- /dev/null +++ b/lib/python3.10/site-packages/fontTools/varLib/iup.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d5daf0eafee222c8254bdccc814a454952e20a29168e811e758b09c2a4732793 +size 1568136 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_compressed_pickle_py27_np16.gz b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_compressed_pickle_py27_np16.gz new file mode 100644 index 0000000000000000000000000000000000000000..fedefdd304054a85fa995801885f997ca8e1a44f --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_compressed_pickle_py27_np16.gz @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:418447e90d83486568ae3092a960b18d358230e24ac9ec38365daa99f415bd0f +size 769 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_compressed_pickle_py27_np17.gz b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_compressed_pickle_py27_np17.gz new file mode 100644 index 0000000000000000000000000000000000000000..3fd32f71887ddd0c94d06c8a77afccc322fed583 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_compressed_pickle_py27_np17.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a1f4e8cccfca94f25ae744d1f050b0734f663263ba38ed0642181404b348b17b +size 757 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_compressed_pickle_py33_np18.gz b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_compressed_pickle_py33_np18.gz new file mode 100644 index 0000000000000000000000000000000000000000..7cd1fcc9dc7a04d7ac251d3b1bbf973609b947b8 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_compressed_pickle_py33_np18.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d9e215780f978ce693e48110ead23652e1c6de1c2189172232690198f7088788 +size 792 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_compressed_pickle_py34_np19.gz b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_compressed_pickle_py34_np19.gz new file mode 100644 index 0000000000000000000000000000000000000000..7cc0ad77a184883b0268bffd1e2f3d195ae99f6b --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_compressed_pickle_py34_np19.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1abdb3ff5b555831f51f7ff00951e66a49277fc2aa787293f18cf8775be65023 +size 794 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_compressed_pickle_py35_np19.gz b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_compressed_pickle_py35_np19.gz new file mode 100644 index 0000000000000000000000000000000000000000..878decdcad534f6d2cdd14a487c207f8c6133261 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_compressed_pickle_py35_np19.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a56c3fc6e0db3a4102aaed4a19fd4e154eecd956f30b6bf9179897844ed3c01e +size 790 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py27_np17.pkl b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py27_np17.pkl new file mode 100644 index 0000000000000000000000000000000000000000..38202b72d880ff53255dbb86eccd73bdb6224c74 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py27_np17.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:89c4508e3dfbe01f801e4e739f1aded13f685941e89281c8050f0ca8aa3c97e5 +size 986 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py27_np17.pkl.bz2 b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py27_np17.pkl.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..77ea56324322ed8657967221aad47bc0063a12fc --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py27_np17.pkl.bz2 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a18415232322531c918164ae04148ebc258acd3a00fa4529728416755e14a15e +size 997 diff --git 
a/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py27_np17.pkl.xz b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py27_np17.pkl.xz new file mode 100644 index 0000000000000000000000000000000000000000..7812497bc95e5894c8e880736bfb06aa22bb2fae --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py27_np17.pkl.xz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:efb146d450c6d061d06affb56f17384e7f64cbab9b516fcc6c4d3f8869b3e707 +size 712 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py33_np18.pkl b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py33_np18.pkl new file mode 100644 index 0000000000000000000000000000000000000000..d6cf697b1c1c752d4d8a78d702a70042ad047ce9 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py33_np18.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e064c2eecfdc58d552844467da7bd56eca596098322bfd266a7e1312abdd5735 +size 1068 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py33_np18.pkl.bz2 b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py33_np18.pkl.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..317981559c4a9987aa099efeb68e4359c08d71ec --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py33_np18.pkl.bz2 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e86d6f6ecfe2626cf691827ac38a81d64ec3ebb527c5432eb344b8496781b45a +size 1000 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py33_np18.pkl.xz b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py33_np18.pkl.xz new file mode 100644 index 0000000000000000000000000000000000000000..826c9ba7b9579a988f8f1718219bbc41fd1ad756 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py33_np18.pkl.xz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e9a63dcc7df38ab0a1137a9b44b436b13cebfa300eb19dba4ae4bce50d0fa81 +size 752 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py34_np19.pkl b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py34_np19.pkl new file mode 100644 index 0000000000000000000000000000000000000000..f22c25bdb59d15a3771104dff6dfebe564e98add --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py34_np19.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1cbe456f5b91f5a3cb8e386838f276c30335432a351426686187761d5c34168b +size 1068 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py34_np19.pkl.bz2 b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py34_np19.pkl.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..80818a8baa1e2481b62bed06bb2b95f4a614cc3a --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py34_np19.pkl.bz2 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f2af67ea667c1f5315ddcab06bfa447005863c1c0fd88bb7e04a0b8acb9a54b +size 1021 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py34_np19.pkl.xz b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py34_np19.pkl.xz new file mode 100644 index 0000000000000000000000000000000000000000..1cd5660dd8d32121e8dd7b40534187add5981639 --- /dev/null 
+++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py34_np19.pkl.xz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04d7e68907e978b56975f9309492b8849e42a60974beb795c9e93273977f3cd3 +size 752 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py35_np19.pkl b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py35_np19.pkl new file mode 100644 index 0000000000000000000000000000000000000000..360af38dc3a9bde47e3b18b144dc1c5257e7daca --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py35_np19.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:97b9ef2e896104321d3c5ce73b3de504788c38f04f08c8b56d7a29d6d1520a96 +size 1068 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py35_np19.pkl.bz2 b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py35_np19.pkl.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..06e8395437874c25cfdf6a6783eab12a6c178f90 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py35_np19.pkl.bz2 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a6a1a9b884be654e2e3fc9a48251ecf0c6920e255c3f2ee5dd71d8252a694606 +size 1005 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py35_np19.pkl.xz b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py35_np19.pkl.xz new file mode 100644 index 0000000000000000000000000000000000000000..cec2871b09ae347e07c81eb55e7979300748ccd1 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.10.0_pickle_py35_np19.pkl.xz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:02cf30d8b196c303662b2dd035d2a58caeb762ae3a82345ffd1274961e7f5aa0 +size 752 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.11.0_compressed_pickle_py36_np111.gz b/lib/python3.10/site-packages/joblib/test/data/joblib_0.11.0_compressed_pickle_py36_np111.gz new file mode 100644 index 0000000000000000000000000000000000000000..f2e65e202609648f0a5464ae5b78b9f9fba8dd6e --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.11.0_compressed_pickle_py36_np111.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d56ae75c3a83a0d10f60e657d50e56af6e3addbf2f555e9fc385a6e52e1b32de +size 800 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.11.0_pickle_py36_np111.pkl b/lib/python3.10/site-packages/joblib/test/data/joblib_0.11.0_pickle_py36_np111.pkl new file mode 100644 index 0000000000000000000000000000000000000000..4dda21d9ad4ce279b8474ecce9697e3290e96bfa --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.11.0_pickle_py36_np111.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5e6b0e171782d5fd5a61d1844dc946eb27c5f6b2e8075d436b23808433142ebc +size 1068 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.11.0_pickle_py36_np111.pkl.bz2 b/lib/python3.10/site-packages/joblib/test/data/joblib_0.11.0_pickle_py36_np111.pkl.bz2 new file mode 100644 index 0000000000000000000000000000000000000000..895dd324d574d9b2298833317a76f3794209bbb3 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.11.0_pickle_py36_np111.pkl.bz2 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bc8db259be742ca2ff36067277f5e4a03e6d78883ddee238da65a7c7d79ef804 +size 991 diff --git 
a/lib/python3.10/site-packages/joblib/test/data/joblib_0.11.0_pickle_py36_np111.pkl.xz b/lib/python3.10/site-packages/joblib/test/data/joblib_0.11.0_pickle_py36_np111.pkl.xz new file mode 100644 index 0000000000000000000000000000000000000000..c7607dcdb2b09e7a50acc3239cc585974e7a09e6 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.11.0_pickle_py36_np111.pkl.xz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dd787f35b3197418d8c7bca77c9dc5ca47b6f22cd24524b3ccd074cf90f893d6 +size 752 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.8.4_compressed_pickle_py27_np17.gz b/lib/python3.10/site-packages/joblib/test/data/joblib_0.8.4_compressed_pickle_py27_np17.gz new file mode 100644 index 0000000000000000000000000000000000000000..fc4e28719d5acc118ac1d8bd8cdd15227eef25ba --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.8.4_compressed_pickle_py27_np17.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4a9f994fb8baa63e689f681ba6db33bbb45aaf32693a61c9ebb50a3a608f40c8 +size 659 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_compressed_pickle_py27_np16.gz b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_compressed_pickle_py27_np16.gz new file mode 100644 index 0000000000000000000000000000000000000000..1238376dd6ac2e166bf56f263862afe56b866da3 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_compressed_pickle_py27_np16.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:34bb43aefa365c81f42af51402f84ea8c7a85c48c65b422e4e4fe8b2ee57883c +size 658 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_compressed_pickle_py27_np17.gz b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_compressed_pickle_py27_np17.gz new file mode 100644 index 0000000000000000000000000000000000000000..1238376dd6ac2e166bf56f263862afe56b866da3 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_compressed_pickle_py27_np17.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:34bb43aefa365c81f42af51402f84ea8c7a85c48c65b422e4e4fe8b2ee57883c +size 658 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_compressed_pickle_py34_np19.gz b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_compressed_pickle_py34_np19.gz new file mode 100644 index 0000000000000000000000000000000000000000..0720a70aee276c37f9457817922ae60b67600d47 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_compressed_pickle_py34_np19.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9f33bd8a21a41b729b05dac5deeb0e868f218a092b0e3fe5988094cf167217f6 +size 673 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_compressed_pickle_py35_np19.gz b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_compressed_pickle_py35_np19.gz new file mode 100644 index 0000000000000000000000000000000000000000..0720a70aee276c37f9457817922ae60b67600d47 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_compressed_pickle_py35_np19.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9f33bd8a21a41b729b05dac5deeb0e868f218a092b0e3fe5988094cf167217f6 +size 673 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np16.pkl b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np16.pkl new file mode 100644 index 
0000000000000000000000000000000000000000..f7ca0addc6d032e93d0b530a2b42a583fb0d4b81 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np16.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9da8a3764db121e29d21ade67c9c3426598e76d88deae44cd7238983af8cef73 +size 670 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np16.pkl_01.npy b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np16.pkl_01.npy new file mode 100644 index 0000000000000000000000000000000000000000..15574a4193ad4ad724b2b8053c701a82efa78fd5 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np16.pkl_01.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0efbd7d9ce7eec3a6e0a0db41e795e0396cca3d6b037dad6c61b464843d28809 +size 120 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np16.pkl_02.npy b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np16.pkl_02.npy new file mode 100644 index 0000000000000000000000000000000000000000..f00f08fbeeda280fa3ce00069c313c5412a33eca --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np16.pkl_02.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1c1cf36cb781fbcc21b953bb0a0b45df092da0eae0e765882e5963ccd70105b1 +size 120 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np16.pkl_03.npy b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np16.pkl_03.npy new file mode 100644 index 0000000000000000000000000000000000000000..ccc84c361de2569ed5cb91967f9063efcd84dd14 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np16.pkl_03.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a0c45ae2a289841cbeba2443b7ebaa3b31c0a9e9dcc73294aca5729da0092405 +size 236 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np16.pkl_04.npy b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np16.pkl_04.npy new file mode 100644 index 0000000000000000000000000000000000000000..e9b5e77c73268dfff541b576126f06fc6fed3d59 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np16.pkl_04.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ecbe244294ba93e08479b16c1b9a9411e3569ff660ed0459dca1d241381df05 +size 104 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np17.pkl b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np17.pkl new file mode 100644 index 0000000000000000000000000000000000000000..976cba8c28be9a3dd0075efe5a6b3ce704319161 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np17.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2f29d7f1d2ceca07f10df172c0e826ef08163a14b12c6ef3fa80ec53a5fcdc3c +size 670 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np17.pkl_01.npy b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np17.pkl_01.npy new file mode 100644 index 0000000000000000000000000000000000000000..15574a4193ad4ad724b2b8053c701a82efa78fd5 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np17.pkl_01.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:0efbd7d9ce7eec3a6e0a0db41e795e0396cca3d6b037dad6c61b464843d28809 +size 120 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np17.pkl_02.npy b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np17.pkl_02.npy new file mode 100644 index 0000000000000000000000000000000000000000..f00f08fbeeda280fa3ce00069c313c5412a33eca --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np17.pkl_02.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1c1cf36cb781fbcc21b953bb0a0b45df092da0eae0e765882e5963ccd70105b1 +size 120 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np17.pkl_03.npy b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np17.pkl_03.npy new file mode 100644 index 0000000000000000000000000000000000000000..ccc84c361de2569ed5cb91967f9063efcd84dd14 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np17.pkl_03.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a0c45ae2a289841cbeba2443b7ebaa3b31c0a9e9dcc73294aca5729da0092405 +size 236 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np17.pkl_04.npy b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np17.pkl_04.npy new file mode 100644 index 0000000000000000000000000000000000000000..e9b5e77c73268dfff541b576126f06fc6fed3d59 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py27_np17.pkl_04.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ecbe244294ba93e08479b16c1b9a9411e3569ff660ed0459dca1d241381df05 +size 104 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py33_np18.pkl b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py33_np18.pkl new file mode 100644 index 0000000000000000000000000000000000000000..e739b6d035cdf110063dbb8b2cdceb116e187019 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py33_np18.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c3d4cbc690d3ce9e5323a714ea546f32c01ab1710285c420184f6cdf4b26fc25 +size 691 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py33_np18.pkl_01.npy b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py33_np18.pkl_01.npy new file mode 100644 index 0000000000000000000000000000000000000000..15574a4193ad4ad724b2b8053c701a82efa78fd5 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py33_np18.pkl_01.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0efbd7d9ce7eec3a6e0a0db41e795e0396cca3d6b037dad6c61b464843d28809 +size 120 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py33_np18.pkl_02.npy b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py33_np18.pkl_02.npy new file mode 100644 index 0000000000000000000000000000000000000000..f00f08fbeeda280fa3ce00069c313c5412a33eca --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py33_np18.pkl_02.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1c1cf36cb781fbcc21b953bb0a0b45df092da0eae0e765882e5963ccd70105b1 +size 120 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py33_np18.pkl_03.npy 
b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py33_np18.pkl_03.npy new file mode 100644 index 0000000000000000000000000000000000000000..73976395be90d4b2b2d955c79a90721e16cebc82 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py33_np18.pkl_03.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8ede9a64a52b25d7db30950956c978ec0b3932b7d14acd5abc63216e64babde7 +size 307 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py33_np18.pkl_04.npy b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py33_np18.pkl_04.npy new file mode 100644 index 0000000000000000000000000000000000000000..e9b5e77c73268dfff541b576126f06fc6fed3d59 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py33_np18.pkl_04.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ecbe244294ba93e08479b16c1b9a9411e3569ff660ed0459dca1d241381df05 +size 104 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py34_np19.pkl b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py34_np19.pkl new file mode 100644 index 0000000000000000000000000000000000000000..19b2b0b8ee910063dcc1b24a87b1f387604c3706 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py34_np19.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8a538100e6ae94b16f2ab0f7d92d4d7e7a622be2dfcc0f6b0b73b623bc513ae2 +size 691 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py34_np19.pkl_01.npy b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py34_np19.pkl_01.npy new file mode 100644 index 0000000000000000000000000000000000000000..15574a4193ad4ad724b2b8053c701a82efa78fd5 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py34_np19.pkl_01.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0efbd7d9ce7eec3a6e0a0db41e795e0396cca3d6b037dad6c61b464843d28809 +size 120 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py34_np19.pkl_02.npy b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py34_np19.pkl_02.npy new file mode 100644 index 0000000000000000000000000000000000000000..f00f08fbeeda280fa3ce00069c313c5412a33eca --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py34_np19.pkl_02.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1c1cf36cb781fbcc21b953bb0a0b45df092da0eae0e765882e5963ccd70105b1 +size 120 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py34_np19.pkl_03.npy b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py34_np19.pkl_03.npy new file mode 100644 index 0000000000000000000000000000000000000000..73976395be90d4b2b2d955c79a90721e16cebc82 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py34_np19.pkl_03.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8ede9a64a52b25d7db30950956c978ec0b3932b7d14acd5abc63216e64babde7 +size 307 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py34_np19.pkl_04.npy b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py34_np19.pkl_04.npy new file mode 100644 index 0000000000000000000000000000000000000000..e9b5e77c73268dfff541b576126f06fc6fed3d59 --- /dev/null +++ 
b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py34_np19.pkl_04.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ecbe244294ba93e08479b16c1b9a9411e3569ff660ed0459dca1d241381df05 +size 104 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py35_np19.pkl b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py35_np19.pkl new file mode 100644 index 0000000000000000000000000000000000000000..93417ab8e94e4542a24211ad514948f9d1b80a3a --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py35_np19.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:59f0d522a29c333ce1d60480b2121fcc1a08a5d2dd650b86efdc987f991fa4ea +size 691 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py35_np19.pkl_01.npy b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py35_np19.pkl_01.npy new file mode 100644 index 0000000000000000000000000000000000000000..15574a4193ad4ad724b2b8053c701a82efa78fd5 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py35_np19.pkl_01.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0efbd7d9ce7eec3a6e0a0db41e795e0396cca3d6b037dad6c61b464843d28809 +size 120 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py35_np19.pkl_02.npy b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py35_np19.pkl_02.npy new file mode 100644 index 0000000000000000000000000000000000000000..f00f08fbeeda280fa3ce00069c313c5412a33eca --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py35_np19.pkl_02.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1c1cf36cb781fbcc21b953bb0a0b45df092da0eae0e765882e5963ccd70105b1 +size 120 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py35_np19.pkl_03.npy b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py35_np19.pkl_03.npy new file mode 100644 index 0000000000000000000000000000000000000000..73976395be90d4b2b2d955c79a90721e16cebc82 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py35_np19.pkl_03.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8ede9a64a52b25d7db30950956c978ec0b3932b7d14acd5abc63216e64babde7 +size 307 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py35_np19.pkl_04.npy b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py35_np19.pkl_04.npy new file mode 100644 index 0000000000000000000000000000000000000000..e9b5e77c73268dfff541b576126f06fc6fed3d59 --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.2_pickle_py35_np19.pkl_04.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ecbe244294ba93e08479b16c1b9a9411e3569ff660ed0459dca1d241381df05 +size 104 diff --git a/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.4.dev0_compressed_cache_size_pickle_py35_np19.gz b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.4.dev0_compressed_cache_size_pickle_py35_np19.gz new file mode 100644 index 0000000000000000000000000000000000000000..e3125fe0fd4709dbd0067e67a06a3f24073934ad --- /dev/null +++ b/lib/python3.10/site-packages/joblib/test/data/joblib_0.9.4.dev0_compressed_cache_size_pickle_py35_np19.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:f2361f589b31d2863627edcb96612280ae5c0a59c9496d89dab7de493670f93b +size 802 diff --git a/lib/python3.10/site-packages/matplotlib/__pycache__/_cm_listed.cpython-310.pyc b/lib/python3.10/site-packages/matplotlib/__pycache__/_cm_listed.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a06ab0055b3e1a75ef281fbdfea72770eb671cbc --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/__pycache__/_cm_listed.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b98494ac39ae02387f165dc2e152852da6a5ce58ca5b3674110014bdc5e74ea +size 128809 diff --git a/lib/python3.10/site-packages/matplotlib/__pycache__/backend_bases.cpython-310.pyc b/lib/python3.10/site-packages/matplotlib/__pycache__/backend_bases.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b11285a72bc8f0b51f60439463915ccdea137f8 --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/__pycache__/backend_bases.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb4111defb1fcc8538d11eab1ccb8cabbcdc5ad729e8a53bafe3ef109234ba4e +size 117357 diff --git a/lib/python3.10/site-packages/matplotlib/__pycache__/colors.cpython-310.pyc b/lib/python3.10/site-packages/matplotlib/__pycache__/colors.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a79e31012a88c8bd0f2595ae7d9015a2a0e0dba --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/__pycache__/colors.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d3d7f16472a922be70153536682f03eebf4f5a43bf5549c2653703f5da8105f5 +size 118928 diff --git a/lib/python3.10/site-packages/matplotlib/__pycache__/figure.cpython-310.pyc b/lib/python3.10/site-packages/matplotlib/__pycache__/figure.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..beff81642781fd61d88f055346a6627d3aa0149b --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/__pycache__/figure.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba49d07d533322674c8a9ed024a3ad425a418a95accd110235ccac87888d4ba1 +size 119764 diff --git a/lib/python3.10/site-packages/matplotlib/__pycache__/patches.cpython-310.pyc b/lib/python3.10/site-packages/matplotlib/__pycache__/patches.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2dd9afca3bc797374aef6555b7ca39427626f085 --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/__pycache__/patches.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:581f6c3fe6fec95b2c5fca591195426fc79529750411f49e6a45e71917c2e466 +size 141974 diff --git a/lib/python3.10/site-packages/matplotlib/__pycache__/pyplot.cpython-310.pyc b/lib/python3.10/site-packages/matplotlib/__pycache__/pyplot.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5727959e295c2ea0ae1e812ea1e759dece8c1e61 --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/__pycache__/pyplot.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9a3e2be69200fa41f0e6813673c47800070205c86bbaef774237afee5f29f889 +size 121133 diff --git a/lib/python3.10/site-packages/matplotlib/__pycache__/widgets.cpython-310.pyc b/lib/python3.10/site-packages/matplotlib/__pycache__/widgets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b72fa1b01b947bc2551b1509547c3884566a8569 --- /dev/null +++ 
b/lib/python3.10/site-packages/matplotlib/__pycache__/widgets.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4f4612291b5cb2fed4ba7e084d86a426a1f9cc4fe1f7a4a4cae26ff2431a129b +size 120139 diff --git a/lib/python3.10/site-packages/matplotlib/axes/__pycache__/_axes.cpython-310.pyc b/lib/python3.10/site-packages/matplotlib/axes/__pycache__/_axes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f62a3d5b58df08fa4a1474552f8ced332b6faa92 --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/axes/__pycache__/_axes.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8160441d38059ebf23febc2e6d39c7808269ed3d9d589cc2e5c8bd4de0f31a3f +size 275734 diff --git a/lib/python3.10/site-packages/matplotlib/axes/__pycache__/_base.cpython-310.pyc b/lib/python3.10/site-packages/matplotlib/axes/__pycache__/_base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9eada12e13a4b0956874d8c96560eb40441e2161 --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/axes/__pycache__/_base.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1cd8d6b607a6c79f6e4650163b636dc12442b102758951b051a1213232345bd7 +size 149511 diff --git a/lib/python3.10/site-packages/matplotlib/backends/_backend_agg.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/matplotlib/backends/_backend_agg.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..3332716429f37bb14faf5f18bd62782c94a499d1 --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/backends/_backend_agg.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e3a3611f9d739868a34b767fb85387c6b200628c147dae881a7ba02244c60b28 +size 741680 diff --git a/lib/python3.10/site-packages/matplotlib/backends/_tkagg.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/matplotlib/backends/_tkagg.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..585f678fd2955a455c631d1f4c14a475654b897a --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/backends/_tkagg.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:99882d0e2dcf8ad31d19c4669a3f8d0acf7f54719da08d38d071b10653aa93a7 +size 279120 diff --git a/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans-Bold.ttf b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans-Bold.ttf new file mode 100644 index 0000000000000000000000000000000000000000..39699bd84ae17507955ce571caa859ad14997f4e --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans-Bold.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b184b89e3c1075f22f6b71575b6fc20d4972b3cfd3b23322ca6fd596dcaef167 +size 704128 diff --git a/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans-BoldOblique.ttf b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans-BoldOblique.ttf new file mode 100644 index 0000000000000000000000000000000000000000..cfd64ecefe8e5cdf2a0ea7d128019650a68ee682 --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans-BoldOblique.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6edf0283160186af451cbee71e7b845f2e4cabf264bb992ce668c83c25465e6f +size 641720 diff --git 
a/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans-Oblique.ttf b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans-Oblique.ttf new file mode 100644 index 0000000000000000000000000000000000000000..4ddf1fabb8648ac811ba05df9980121ec6a36452 --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans-Oblique.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ccdf74b350f11fd3dd5774de50e5e6346a1a5da1f5b7d5fb83590665e97a5213 +size 633840 diff --git a/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans.ttf b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans.ttf new file mode 100644 index 0000000000000000000000000000000000000000..f47aa8cffd396d3ac221ad7a2a518bc9a554c9ce --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3fdf69cabf06049ea70a00b5919340e2ce1e6d02b0cc3c4b44fb6801bd1e0d22 +size 756072 diff --git a/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSansMono-Bold.ttf b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSansMono-Bold.ttf new file mode 100644 index 0000000000000000000000000000000000000000..fb4675de56d8c2b2b09734578191896c7bf6948e --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSansMono-Bold.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:baada9a5172fe20886251aff0433fc38461912d5daf07287e7bee56620a8da96 +size 331536 diff --git a/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSansMono-BoldOblique.ttf b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSansMono-BoldOblique.ttf new file mode 100644 index 0000000000000000000000000000000000000000..ee85d0dc6915489b236e932ab89838b200443aba --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSansMono-BoldOblique.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a69081c15c76c827e0a27a5a7f5c74b6135c843499955495ffa8c20d3a98288b +size 253116 diff --git a/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSansMono-Oblique.ttf b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSansMono-Oblique.ttf new file mode 100644 index 0000000000000000000000000000000000000000..ba2147393770bbb669b1a6c735190f3cf6e732a8 --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSansMono-Oblique.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:28052813f7a709fc89f52d192dc995ef4f0fdc5c3d7b73a49d6849b1916d0cd0 +size 251472 diff --git a/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSansMono.ttf b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSansMono.ttf new file mode 100644 index 0000000000000000000000000000000000000000..60f42148ea53bf551fc77faef9d6382ff0efcb16 --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSansMono.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:602ec86b8948cfcd956482fe64f94c36c867770149ef2f791d4613f443bcecb3 +size 340240 diff --git a/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSerif-Bold.ttf b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSerif-Bold.ttf new file mode 100644 index 0000000000000000000000000000000000000000..0137479cb3e30823976d9137d9e75ed8e225c57f --- /dev/null +++ 
b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSerif-Bold.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c3753f2ed6bc673f15846dc45addbeb3b9c872f32fb18fd53a21f1bef1ed7676 +size 355692 diff --git a/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSerif-BoldItalic.ttf b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSerif-BoldItalic.ttf new file mode 100644 index 0000000000000000000000000000000000000000..b89793b3fe00a3510c0a22b54d9ee7735e5e0803 --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSerif-BoldItalic.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d93efec7a9d2e826768d1a2ee95b95870e15e29599a84f3484af1de1cec2e181 +size 347064 diff --git a/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSerif-Italic.ttf b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSerif-Italic.ttf new file mode 100644 index 0000000000000000000000000000000000000000..c1a70e8374879a0c04e370eb5c57861c4fd5abce --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSerif-Italic.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3e7994fbc54fa10ce3352a42d548fadd7d9cadb69cb1109bc9d960f6dac57f04 +size 345612 diff --git a/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSerif.ttf b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSerif.ttf new file mode 100644 index 0000000000000000000000000000000000000000..73a4204e08a49a83ab7336cb55add8b32200ae24 --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSerif.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:107244956e9962b9e96faccdc551825e0ae0898ae13737133e1b921a2fd35ffa +size 379740 diff --git a/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/STIXGeneral.ttf b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/STIXGeneral.ttf new file mode 100644 index 0000000000000000000000000000000000000000..cbf159757c5c0a47f1d8f2b984f0c20f52b7e25b --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/STIXGeneral.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:167378031e2dddc6216d67819c9260e9a06ffc4c478e4e23cb98a6fd44b183c2 +size 448228 diff --git a/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/STIXGeneralBol.ttf b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/STIXGeneralBol.ttf new file mode 100644 index 0000000000000000000000000000000000000000..64ca51419ec522ee78c91e6f36f57fdb3bde290b --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/STIXGeneralBol.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e8533dc7083fa346bda1933d60ea4a83b67d0945bceaf1b3541f82b4a0e2c6a0 +size 237360 diff --git a/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/STIXGeneralBolIta.ttf b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/STIXGeneralBolIta.ttf new file mode 100644 index 0000000000000000000000000000000000000000..585e03c626a96775f018e690fc6e10f6cd3e4047 --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/STIXGeneralBolIta.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98788fd4ba48dfbb2bd026c0e20a247a8b06c7372879628b7a6bb0d5bb09736c +size 181152 diff --git a/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/STIXGeneralItalic.ttf 
b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/STIXGeneralItalic.ttf new file mode 100644 index 0000000000000000000000000000000000000000..67ce0a807f6ae1dc2b4f7b863ce99620d29e6cb7 --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/mpl-data/fonts/ttf/STIXGeneralItalic.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6cfcb333d22b7c3c623bdfd40174f14c85c3d6731ca6166c1edc80140eae8527 +size 175040 diff --git a/lib/python3.10/site-packages/matplotlib/mpl-data/sample_data/axes_grid/bivariate_normal.npy b/lib/python3.10/site-packages/matplotlib/mpl-data/sample_data/axes_grid/bivariate_normal.npy new file mode 100644 index 0000000000000000000000000000000000000000..a76ff8b1f63c970d5e02b481a16b3231408aebe4 --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/mpl-data/sample_data/axes_grid/bivariate_normal.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e9599f6e74087aa2ca58aa77846b6ec3e8491180e445c07a2c69c65756ef7c5 +size 1880 diff --git a/lib/python3.10/site-packages/matplotlib/mpl-data/sample_data/goog.npz b/lib/python3.10/site-packages/matplotlib/mpl-data/sample_data/goog.npz new file mode 100644 index 0000000000000000000000000000000000000000..bd82e71bf7d72da7db030381c31f769a3d9736bb --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/mpl-data/sample_data/goog.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:400917cf30e6b664f7b0da93d7c745860d3aa9008da8b7f160d2dd12e6a318b1 +size 22845 diff --git a/lib/python3.10/site-packages/matplotlib/mpl-data/sample_data/jacksboro_fault_dem.npz b/lib/python3.10/site-packages/matplotlib/mpl-data/sample_data/jacksboro_fault_dem.npz new file mode 100644 index 0000000000000000000000000000000000000000..065732313f852fe861ea6852a84f8c59666c90b1 --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/mpl-data/sample_data/jacksboro_fault_dem.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d493f50a33e82a4420494c54d1fca1539d177bdc27ab190bc5fe6e92f62fb637 +size 174061 diff --git a/lib/python3.10/site-packages/matplotlib/mpl-data/sample_data/s1045.ima.gz b/lib/python3.10/site-packages/matplotlib/mpl-data/sample_data/s1045.ima.gz new file mode 100644 index 0000000000000000000000000000000000000000..29d1c7f70afaa882d9673861bc3d17a4ffd7235d --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/mpl-data/sample_data/s1045.ima.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:32b424d64f62b7e71cb24d29fd53938ad5664d608055a67ab2b2af4369f8b89e +size 33229 diff --git a/lib/python3.10/site-packages/matplotlib/mpl-data/sample_data/topobathy.npz b/lib/python3.10/site-packages/matplotlib/mpl-data/sample_data/topobathy.npz new file mode 100644 index 0000000000000000000000000000000000000000..67fc6c403643c5b4e0624005b7bd99ac59e856fd --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/mpl-data/sample_data/topobathy.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0244e03291702df45024dcb5cacbc4f3d4cb30d72dfa7fd371c4ac61c42b4fbf +size 45224 diff --git a/lib/python3.10/site-packages/matplotlib/tests/__pycache__/test_axes.cpython-310.pyc b/lib/python3.10/site-packages/matplotlib/tests/__pycache__/test_axes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4af09786e5eac0037c85651ef29fd47e1aabf9a3 --- /dev/null +++ b/lib/python3.10/site-packages/matplotlib/tests/__pycache__/test_axes.cpython-310.pyc @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:20cfbda4aa7bbb858d30ed0bbb31121e47ec16ee452ce20be47c9b0fc66a727d +size 292141 diff --git a/lib/python3.10/site-packages/mpl_toolkits/mplot3d/__pycache__/axes3d.cpython-310.pyc b/lib/python3.10/site-packages/mpl_toolkits/mplot3d/__pycache__/axes3d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a36ae282342b797249a3f9dc9ca0b3561d909e9a --- /dev/null +++ b/lib/python3.10/site-packages/mpl_toolkits/mplot3d/__pycache__/axes3d.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b639fca647b8e600a3fdc94e9e20c058b6870570d6aa5673c950de5a837d2c3 +size 122039 diff --git a/lib/python3.10/site-packages/mpmath/__pycache__/function_docs.cpython-310.pyc b/lib/python3.10/site-packages/mpmath/__pycache__/function_docs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eda6486ea0ee380f4b5a405507c4a8a1651494dc --- /dev/null +++ b/lib/python3.10/site-packages/mpmath/__pycache__/function_docs.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5fc48195e306938f14c6634be0078b3a10cacc1d7fc20e34b6a6c47126ec1ec7 +size 283793 diff --git a/lib/python3.10/site-packages/narwhals/__pycache__/dataframe.cpython-310.pyc b/lib/python3.10/site-packages/narwhals/__pycache__/dataframe.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..337b1b5d0524de9cd895f5e3cb981d7b0ef1b738 --- /dev/null +++ b/lib/python3.10/site-packages/narwhals/__pycache__/dataframe.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a8afd2699ff27b0123409263df0a2efa91f973cceef73d747bfbb62cd349f85f +size 128986 diff --git a/lib/python3.10/site-packages/narwhals/__pycache__/expr.cpython-310.pyc b/lib/python3.10/site-packages/narwhals/__pycache__/expr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d064e73a0804b9b6f97a74851eab5ae9b467f0eb --- /dev/null +++ b/lib/python3.10/site-packages/narwhals/__pycache__/expr.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:700d25751b3c9cb44b3ff1493a944b5e8bccb362e448957ac128162a754a8442 +size 102388 diff --git a/lib/python3.10/site-packages/pkg_resources/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/pkg_resources/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ccf303152e176f61d071b4f24901f1f08569450 --- /dev/null +++ b/lib/python3.10/site-packages/pkg_resources/__pycache__/__init__.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:88806c75aa02617f894fd74ae39b59650dc92198181513f41fcc3534a818c468 +size 115662 diff --git a/lib/python3.10/site-packages/pkg_resources/tests/data/my-test-package-zip/my-test-package.zip b/lib/python3.10/site-packages/pkg_resources/tests/data/my-test-package-zip/my-test-package.zip new file mode 100644 index 0000000000000000000000000000000000000000..400905eecf0385de4d3b8e50b9892e1302b2b894 --- /dev/null +++ b/lib/python3.10/site-packages/pkg_resources/tests/data/my-test-package-zip/my-test-package.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:01845c437f4655e3cf9cc4fc4e49cfd607431f22675e1b611129a90239f34822 +size 1809 diff --git a/lib/python3.10/site-packages/plotly/__pycache__/basedatatypes.cpython-310.pyc b/lib/python3.10/site-packages/plotly/__pycache__/basedatatypes.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..3295d02d77281103e08382d8386cc736e8adf9ff --- /dev/null +++ b/lib/python3.10/site-packages/plotly/__pycache__/basedatatypes.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee0cdc421ee15e3d63f26d386fdaa01a325a677429a9e3e6ab1bdfd515d0a501 +size 163302 diff --git a/lib/python3.10/site-packages/plotly/graph_objs/__pycache__/_box.cpython-310.pyc b/lib/python3.10/site-packages/plotly/graph_objs/__pycache__/_box.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fba59f7b6a2af7235917fb420c8bdfc0abe141b7 --- /dev/null +++ b/lib/python3.10/site-packages/plotly/graph_objs/__pycache__/_box.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:58b6514121115461b5d5155afbb1063598b71eb9d924d8ae36241cc64311fdbe +size 106261 diff --git a/lib/python3.10/site-packages/plotly/graph_objs/__pycache__/_figure.cpython-310.pyc b/lib/python3.10/site-packages/plotly/graph_objs/__pycache__/_figure.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46b638f6031a08a4c893f94c5f48c367f871a7b0 --- /dev/null +++ b/lib/python3.10/site-packages/plotly/graph_objs/__pycache__/_figure.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b690b2db311bed8987c0be11ab594edf23b6b37a1f8e3fc04afdcf14ecf7c1d6 +size 968352 diff --git a/lib/python3.10/site-packages/plotly/graph_objs/__pycache__/_figurewidget.cpython-310.pyc b/lib/python3.10/site-packages/plotly/graph_objs/__pycache__/_figurewidget.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e09f8880996059518f53780da0e62012f52d12a --- /dev/null +++ b/lib/python3.10/site-packages/plotly/graph_objs/__pycache__/_figurewidget.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6762c2b496645a2222196eae92fb3a79616106a6e0cba13a53bb352052206a24 +size 969557 diff --git a/lib/python3.10/site-packages/plotly/graph_objs/__pycache__/_layout.cpython-310.pyc b/lib/python3.10/site-packages/plotly/graph_objs/__pycache__/_layout.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..268114177946203c65a826236a3337d6af6dff7d --- /dev/null +++ b/lib/python3.10/site-packages/plotly/graph_objs/__pycache__/_layout.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c6cb6f6f17bc24cd732e92dd6b0414f181efda885d160281c2af4640f53649a +size 122390 diff --git a/lib/python3.10/site-packages/plotly/graph_objs/layout/__pycache__/_xaxis.cpython-310.pyc b/lib/python3.10/site-packages/plotly/graph_objs/layout/__pycache__/_xaxis.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e128d8fa9b43dcc12aff19cdb87ba39679675d4 --- /dev/null +++ b/lib/python3.10/site-packages/plotly/graph_objs/layout/__pycache__/_xaxis.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:11218f87e47116b0dace50b126ebe09aadd253bdfbdf311969da3893a847efe2 +size 125605 diff --git a/lib/python3.10/site-packages/plotly/graph_objs/layout/__pycache__/_yaxis.cpython-310.pyc b/lib/python3.10/site-packages/plotly/graph_objs/layout/__pycache__/_yaxis.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07f5b7e47b131358e4574ceb84e242f21fbf4a44 --- /dev/null +++ b/lib/python3.10/site-packages/plotly/graph_objs/layout/__pycache__/_yaxis.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:87c2469a6aaa6a15536ea934003bf538edce1527d30cfbcd3ed39a27cdd521fd +size 127142 diff --git a/lib/python3.10/site-packages/plotly/package_data/datasets/carshare.csv.gz b/lib/python3.10/site-packages/plotly/package_data/datasets/carshare.csv.gz new file mode 100644 index 0000000000000000000000000000000000000000..2c41f79c63ae8f66f87992f59139fd6e5c0cc5a6 --- /dev/null +++ b/lib/python3.10/site-packages/plotly/package_data/datasets/carshare.csv.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a0bbf2c15a87142158e5975719a8ea197ce27bd08430df7c903a3e225aba5cd0 +size 6215 diff --git a/lib/python3.10/site-packages/plotly/package_data/datasets/election.csv.gz b/lib/python3.10/site-packages/plotly/package_data/datasets/election.csv.gz new file mode 100644 index 0000000000000000000000000000000000000000..5f471b911945dc6ae4b10796727fd599b894a42d --- /dev/null +++ b/lib/python3.10/site-packages/plotly/package_data/datasets/election.csv.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:54bf047516fde7930c3fbc6b11b1dabae29b838507b56772f0e41a349abb82be +size 1656 diff --git a/lib/python3.10/site-packages/plotly/package_data/datasets/election.geojson.gz b/lib/python3.10/site-packages/plotly/package_data/datasets/election.geojson.gz new file mode 100644 index 0000000000000000000000000000000000000000..24fd6ee308e7b0b6ec31a0c49665bdc966f25a91 --- /dev/null +++ b/lib/python3.10/site-packages/plotly/package_data/datasets/election.geojson.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3c67ef7244a8d39d254a8659a569437f6ab214e78e25dc34b198b8b4a455f6fc +size 31857 diff --git a/lib/python3.10/site-packages/plotly/package_data/datasets/experiment.csv.gz b/lib/python3.10/site-packages/plotly/package_data/datasets/experiment.csv.gz new file mode 100644 index 0000000000000000000000000000000000000000..419eae9bcd97691b4d3692ff8088555edad6019f --- /dev/null +++ b/lib/python3.10/site-packages/plotly/package_data/datasets/experiment.csv.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ff3ffab64fa8078e9162ab8fd10b70c2c1f17f5d5adba36467f0826768d00f91 +size 3154 diff --git a/lib/python3.10/site-packages/plotly/package_data/datasets/gapminder.csv.gz b/lib/python3.10/site-packages/plotly/package_data/datasets/gapminder.csv.gz new file mode 100644 index 0000000000000000000000000000000000000000..bd32c6c7edd79a2f4e6f3c7fe23c904551b85ef9 --- /dev/null +++ b/lib/python3.10/site-packages/plotly/package_data/datasets/gapminder.csv.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa7af7b739ac5a4cdf47e32a085ce74e351bbf5d2ae33890a298b1904afd0dfd +size 34016 diff --git a/lib/python3.10/site-packages/plotly/package_data/datasets/iris.csv.gz b/lib/python3.10/site-packages/plotly/package_data/datasets/iris.csv.gz new file mode 100644 index 0000000000000000000000000000000000000000..7fb68e454a0d7824f9451fb4cf6b22e95d49468d --- /dev/null +++ b/lib/python3.10/site-packages/plotly/package_data/datasets/iris.csv.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b92fa7bbdbf9720fba1c021a6dc0fb7eac0827094bb4045e70c9960582eaab5 +size 875 diff --git a/lib/python3.10/site-packages/plotly/package_data/datasets/medals.csv.gz b/lib/python3.10/site-packages/plotly/package_data/datasets/medals.csv.gz new file mode 100644 index 0000000000000000000000000000000000000000..61bd9630f4603c5142909b862cd22ff75ae57909 --- /dev/null +++ b/lib/python3.10/site-packages/plotly/package_data/datasets/medals.csv.gz @@ -0,0 
+1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:71b233248152a78b51612463628682e0d46855b6be6c1838ed8b091bbaf12dd9 +size 110 diff --git a/lib/python3.10/site-packages/plotly/package_data/datasets/stocks.csv.gz b/lib/python3.10/site-packages/plotly/package_data/datasets/stocks.csv.gz new file mode 100644 index 0000000000000000000000000000000000000000..bbacdf10e489f607f4ececddf6b862b64580a6d1 --- /dev/null +++ b/lib/python3.10/site-packages/plotly/package_data/datasets/stocks.csv.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8bb4cd39dc16c3408c311383344d359690930a10f480e522895d3983a30d789a +size 5895 diff --git a/lib/python3.10/site-packages/plotly/package_data/datasets/tips.csv.gz b/lib/python3.10/site-packages/plotly/package_data/datasets/tips.csv.gz new file mode 100644 index 0000000000000000000000000000000000000000..5a4cd2ef6f02faaea5ce0056eb150ff5d5cd6d3b --- /dev/null +++ b/lib/python3.10/site-packages/plotly/package_data/datasets/tips.csv.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cba42a1384173aafc7b8574f063d48189800f9e4fdc6d9ae0b8f0a2be106fa2a +size 1740 diff --git a/lib/python3.10/site-packages/plotly/package_data/datasets/wind.csv.gz b/lib/python3.10/site-packages/plotly/package_data/datasets/wind.csv.gz new file mode 100644 index 0000000000000000000000000000000000000000..7042744d93c60dcffae845f229cdf844fe4525c0 --- /dev/null +++ b/lib/python3.10/site-packages/plotly/package_data/datasets/wind.csv.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9fbad0081d48606f87ff940b23c8f0a31d5530c364f6c87acb631004a43fbabf +size 424 diff --git a/lib/python3.10/site-packages/pycparser/__pycache__/yacctab.cpython-310.pyc b/lib/python3.10/site-packages/pycparser/__pycache__/yacctab.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..145b2fb4da4143d9e11e4087f5d7cff5776f6c52 --- /dev/null +++ b/lib/python3.10/site-packages/pycparser/__pycache__/yacctab.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f5016929da2b621a42951b4aa4fc1684c718d67f679f4c13f9464198ee84b15d +size 177070 diff --git a/lib/python3.10/site-packages/pygments/lexers/__pycache__/lisp.cpython-310.pyc b/lib/python3.10/site-packages/pygments/lexers/__pycache__/lisp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4aa19d49e18f5420f9d2b2da0d541ff0b47f130 --- /dev/null +++ b/lib/python3.10/site-packages/pygments/lexers/__pycache__/lisp.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:89f43ca911aca28f4a0259334c25525d68e42fa2ca787061b7f892b25685a8ef +size 107477 diff --git a/lib/python3.10/site-packages/scipy/cluster/__pycache__/hierarchy.cpython-310.pyc b/lib/python3.10/site-packages/scipy/cluster/__pycache__/hierarchy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..638a345addf4201eb71e5f85bfd437d5bc22c8f1 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/cluster/__pycache__/hierarchy.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9a7e83c47b888a1b3a922d6d6a1f51b5b73c8b0f8dff45d12abdac2d29b7e74e +size 131124 diff --git a/lib/python3.10/site-packages/scipy/constants/__pycache__/_codata.cpython-310.pyc b/lib/python3.10/site-packages/scipy/constants/__pycache__/_codata.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e65925667d12edd67ba0052cf18093886db70e8 --- /dev/null +++ 
b/lib/python3.10/site-packages/scipy/constants/__pycache__/_codata.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a1052d7d742f833dee8c304e3d8502bb461756cf4b12ad221f8a392670a14901 +size 198559 diff --git a/lib/python3.10/site-packages/scipy/fft/_pocketfft/pypocketfft.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/scipy/fft/_pocketfft/pypocketfft.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..e8caae18a064cbf2a94f99c71f14fe813f75432b --- /dev/null +++ b/lib/python3.10/site-packages/scipy/fft/_pocketfft/pypocketfft.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b8d1cd66430404b76033dd8db08a2a2e81297dd185ce538beac8cb8c278f0286 +size 1051696 diff --git a/lib/python3.10/site-packages/scipy/fftpack/tests/fftw_double_ref.npz b/lib/python3.10/site-packages/scipy/fftpack/tests/fftw_double_ref.npz new file mode 100644 index 0000000000000000000000000000000000000000..e1e3d620400746177b560b9193efce03c2841e99 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/fftpack/tests/fftw_double_ref.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a60c649415b645223924d8342ccc5c097801c86901287a369e53fc9259f5ec4e +size 162120 diff --git a/lib/python3.10/site-packages/scipy/fftpack/tests/fftw_longdouble_ref.npz b/lib/python3.10/site-packages/scipy/fftpack/tests/fftw_longdouble_ref.npz new file mode 100644 index 0000000000000000000000000000000000000000..b1a646889c9889541e8d368c8c2d96520d183dc4 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/fftpack/tests/fftw_longdouble_ref.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a406cbd4dad04d0c59dd38f54416fb49424c82229c1a074b6a44ec0cde2000e3 +size 296072 diff --git a/lib/python3.10/site-packages/scipy/fftpack/tests/fftw_single_ref.npz b/lib/python3.10/site-packages/scipy/fftpack/tests/fftw_single_ref.npz new file mode 100644 index 0000000000000000000000000000000000000000..a42748dba14b7ff0d2f53ce4cd5a86a4f08e5d93 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/fftpack/tests/fftw_single_ref.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:276a9141318e6fc36e4ab6ff54a61b64054ef8849b660f17359e5f541b43c526 +size 95144 diff --git a/lib/python3.10/site-packages/scipy/fftpack/tests/test.npz b/lib/python3.10/site-packages/scipy/fftpack/tests/test.npz new file mode 100644 index 0000000000000000000000000000000000000000..1e5a4e06615c6bcc58f0feff20f73e83439a937d --- /dev/null +++ b/lib/python3.10/site-packages/scipy/fftpack/tests/test.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:36de804a22d8fdea054590ce49ddf3c859838b7d89193c56b3bcb660cbf43797 +size 11968 diff --git a/lib/python3.10/site-packages/scipy/integrate/__pycache__/_lebedev.cpython-310.pyc b/lib/python3.10/site-packages/scipy/integrate/__pycache__/_lebedev.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9e35a16e1fa192b36f14f254ccab8868e2e2d74 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/integrate/__pycache__/_lebedev.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ae8c491d25f893805388038e56058da252d8b06645473db666572303cec2f4e +size 100121 diff --git a/lib/python3.10/site-packages/scipy/interpolate/tests/__pycache__/test_bsplines.cpython-310.pyc b/lib/python3.10/site-packages/scipy/interpolate/tests/__pycache__/test_bsplines.cpython-310.pyc new file 
mode 100644 index 0000000000000000000000000000000000000000..e59050298ef0b04fffaa814fb5849c0fda830a21 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/interpolate/tests/__pycache__/test_bsplines.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:93026a4722e5d62c111b1a4960a15b5e40b424db1611d843cefab10e951aa415 +size 117534 diff --git a/lib/python3.10/site-packages/scipy/interpolate/tests/data/bug-1310.npz b/lib/python3.10/site-packages/scipy/interpolate/tests/data/bug-1310.npz new file mode 100644 index 0000000000000000000000000000000000000000..8bddf805c36b29dc449556c27a2b489691f841af --- /dev/null +++ b/lib/python3.10/site-packages/scipy/interpolate/tests/data/bug-1310.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8d6803c0b398f2704c236f1d1b9e8e5ede06bd165a0abb0f228281abbd455ae9 +size 2648 diff --git a/lib/python3.10/site-packages/scipy/interpolate/tests/data/estimate_gradients_hang.npy b/lib/python3.10/site-packages/scipy/interpolate/tests/data/estimate_gradients_hang.npy new file mode 100644 index 0000000000000000000000000000000000000000..c5ef8f63f263a476823ddeacf2571551c2fe4690 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/interpolate/tests/data/estimate_gradients_hang.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:406c10857417ff5ea98d8cd28945c9d0e4f5c24f92a48ad0e8fab955bf2477f1 +size 35680 diff --git a/lib/python3.10/site-packages/scipy/interpolate/tests/data/gcvspl.npz b/lib/python3.10/site-packages/scipy/interpolate/tests/data/gcvspl.npz new file mode 100644 index 0000000000000000000000000000000000000000..50e9348dcca79eae861e67092add93cdb8ff1ca3 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/interpolate/tests/data/gcvspl.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:03ce8155a6cba0c1bf0a2441a10c228191f916dec36cb820723429811296bba8 +size 3138 diff --git a/lib/python3.10/site-packages/scipy/io/_fast_matrix_market/_fmm_core.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/scipy/io/_fast_matrix_market/_fmm_core.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..7b6e991b70def287ea37ad1e10e9f32d5e7b12fe --- /dev/null +++ b/lib/python3.10/site-packages/scipy/io/_fast_matrix_market/_fmm_core.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e89d0067690706ec97bcde3eec9a0d0bafad769551e1be97d064ae9e73c30ad0 +size 4074120 diff --git a/lib/python3.10/site-packages/scipy/io/matlab/_mio5_utils.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/scipy/io/matlab/_mio5_utils.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..7f31415efe8528cce893322588856270ca141cc0 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/io/matlab/_mio5_utils.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:84d2d5f3b2290b8fb22b5a01c2513c79d5a6104270da99c6f5fdc641da187ee9 +size 242560 diff --git a/lib/python3.10/site-packages/scipy/io/matlab/_streams.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/scipy/io/matlab/_streams.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..701f3306c7a8ecde34a3a83460d4119a3ac72c8f --- /dev/null +++ b/lib/python3.10/site-packages/scipy/io/matlab/_streams.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:b9764c76103a1d3363f72f1d9fa1e4039e2f13ab4963785aa1e598b3ed1d82a7 +size 145904 diff --git a/lib/python3.10/site-packages/scipy/linalg/tests/data/carex_15_data.npz b/lib/python3.10/site-packages/scipy/linalg/tests/data/carex_15_data.npz new file mode 100644 index 0000000000000000000000000000000000000000..660bbb41b7fad43ed945dc701693451ceb60166c --- /dev/null +++ b/lib/python3.10/site-packages/scipy/linalg/tests/data/carex_15_data.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:13f3e1491a876bbf59d7ea10ad29c1f9b5996a2ab99216f31d5bfcd659012c1e +size 34462 diff --git a/lib/python3.10/site-packages/scipy/linalg/tests/data/carex_18_data.npz b/lib/python3.10/site-packages/scipy/linalg/tests/data/carex_18_data.npz new file mode 100644 index 0000000000000000000000000000000000000000..0b3d569a1a65e9b5ff153ae4121a6a5a69409f7c --- /dev/null +++ b/lib/python3.10/site-packages/scipy/linalg/tests/data/carex_18_data.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:59f839467f2752b7df6fb6d4094396edd32a5929b764f7ffa1e6666431e6cac6 +size 161487 diff --git a/lib/python3.10/site-packages/scipy/linalg/tests/data/carex_19_data.npz b/lib/python3.10/site-packages/scipy/linalg/tests/data/carex_19_data.npz new file mode 100644 index 0000000000000000000000000000000000000000..90168ad4e888fba29a772ee13798ec126016140e --- /dev/null +++ b/lib/python3.10/site-packages/scipy/linalg/tests/data/carex_19_data.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:38e8fc7b041df0b23d7e5ca15ead1a065e6467611ef9a848cc7db93f80adfd87 +size 34050 diff --git a/lib/python3.10/site-packages/scipy/linalg/tests/data/carex_20_data.npz b/lib/python3.10/site-packages/scipy/linalg/tests/data/carex_20_data.npz new file mode 100644 index 0000000000000000000000000000000000000000..87266deb46238307347362b63a4878f2565baf56 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/linalg/tests/data/carex_20_data.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:14e222d34a7118c7284a1675c6feceee77b84df951a5c6ba2a5ee9ff3054fa1d +size 31231 diff --git a/lib/python3.10/site-packages/scipy/linalg/tests/data/carex_6_data.npz b/lib/python3.10/site-packages/scipy/linalg/tests/data/carex_6_data.npz new file mode 100644 index 0000000000000000000000000000000000000000..35d1681786c95602c4f0d5260fc5ad0ff4236189 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/linalg/tests/data/carex_6_data.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1b2a0736b541ebf5c4b9b4c00d6dab281e73c9fb9913c6e2581a781b37b602f9 +size 15878 diff --git a/lib/python3.10/site-packages/scipy/linalg/tests/data/gendare_20170120_data.npz b/lib/python3.10/site-packages/scipy/linalg/tests/data/gendare_20170120_data.npz new file mode 100644 index 0000000000000000000000000000000000000000..ff967f2ca0d0868aacf7d7e67402599e64bab817 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/linalg/tests/data/gendare_20170120_data.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a3dfab451d9d5c20243e0ed85cd8b6c9657669fb9a0f83b5be165585783d55b5 +size 2164 diff --git a/lib/python3.10/site-packages/scipy/optimize/__pycache__/_optimize.cpython-310.pyc b/lib/python3.10/site-packages/scipy/optimize/__pycache__/_optimize.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..deee2356609fcce6e118f2014f72674e9723c501 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/optimize/__pycache__/_optimize.cpython-310.pyc @@ -0,0 +1,3 @@ 
+version https://git-lfs.github.com/spec/v1 +oid sha256:9606c52d56b4006e9949b0dbbcecc3eba97a5e7f2fb244a8080d4e29cdc710cc +size 115288 diff --git a/lib/python3.10/site-packages/scipy/optimize/_highspy/_core.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/scipy/optimize/_highspy/_core.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..f218abcc1daa77b9cd7fe701ce77c51e1e41b0aa --- /dev/null +++ b/lib/python3.10/site-packages/scipy/optimize/_highspy/_core.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:652d5532e6f86158c1c71fedfe54d8d64b99ef9a0990289cff69379f72a86bf6 +size 5293088 diff --git a/lib/python3.10/site-packages/scipy/optimize/_highspy/_highs_options.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/scipy/optimize/_highspy/_highs_options.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..f998120c2ebe7fc1942facfaa1f0ebaa00b37edf --- /dev/null +++ b/lib/python3.10/site-packages/scipy/optimize/_highspy/_highs_options.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6c00f935c744fdec13ed4e4a54b103d1e9a5d0a9aaa322af4f2c0b699e0d27b9 +size 318456 diff --git a/lib/python3.10/site-packages/scipy/optimize/_lsq/givens_elimination.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/scipy/optimize/_lsq/givens_elimination.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..c2f2af62fb7dbc2a0ed1614df00f533fdf80e848 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/optimize/_lsq/givens_elimination.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e51c13cae6a472158617d36d7ae62eed96138693912ccb9798ce86e40c0d51b9 +size 193496 diff --git a/lib/python3.10/site-packages/scipy/optimize/_trlib/_trlib.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/scipy/optimize/_trlib/_trlib.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..4d92e9c5d7b2834eb5425bf2bed6742d33703a56 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/optimize/_trlib/_trlib.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37aa146106095bcd2a318942de8f4643de1f7511c5fe236d9d6c6139c5e408d3 +size 336376 diff --git a/lib/python3.10/site-packages/scipy/signal/__pycache__/_filter_design.cpython-310.pyc b/lib/python3.10/site-packages/scipy/signal/__pycache__/_filter_design.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a977fe6879e662e7716ca2e490263d49ffb7b25 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/signal/__pycache__/_filter_design.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:14ae00b56ea3a8ffcd58008ce9bd897ed1d7637eab1e431e697450a1787d5667 +size 169581 diff --git a/lib/python3.10/site-packages/scipy/signal/__pycache__/_signaltools.cpython-310.pyc b/lib/python3.10/site-packages/scipy/signal/__pycache__/_signaltools.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..852491fdff9be57d6d6153ed6ee70caf1ca20532 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/signal/__pycache__/_signaltools.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c0317eed1b0719bf5e9707632ca1767b778de3064572c1f58c58cf5844a68d19 +size 151242 
diff --git a/lib/python3.10/site-packages/scipy/signal/tests/__pycache__/test_filter_design.cpython-310.pyc b/lib/python3.10/site-packages/scipy/signal/tests/__pycache__/test_filter_design.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e654a22dd189bea910fabc392995138e02bad1e --- /dev/null +++ b/lib/python3.10/site-packages/scipy/signal/tests/__pycache__/test_filter_design.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d13b37be4e6be9efaf09f0dd1d24d18fb67149402be84028d1506c7200f809b4 +size 121026 diff --git a/lib/python3.10/site-packages/scipy/signal/tests/__pycache__/test_signaltools.cpython-310.pyc b/lib/python3.10/site-packages/scipy/signal/tests/__pycache__/test_signaltools.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1b211e8c8ff3e5eec81edf31f252795c739edf4 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/signal/tests/__pycache__/test_signaltools.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2e4afc9ca2da9f3626a67a2ea4fede617afab42cf766f1ecfb87e2a14b166ee8 +size 127629 diff --git a/lib/python3.10/site-packages/scipy/sparse/csgraph/_flow.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/scipy/sparse/csgraph/_flow.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..c23bcd820c28007e62730980c8cc1a8fa732d200 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/sparse/csgraph/_flow.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:beeb1d249e3b26f4cadb7c853840d0de65c293892eceac05a5b3b3c8134f01bc +size 308104 diff --git a/lib/python3.10/site-packages/scipy/sparse/csgraph/_matching.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/scipy/sparse/csgraph/_matching.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..12ae42c4be65b2aae7cc0aaeb766fcba1dd2341b --- /dev/null +++ b/lib/python3.10/site-packages/scipy/sparse/csgraph/_matching.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:11d859f76f3109c84d9d35a340e7f1bf88de4b0e2d95ea5e6111aa724703ae97 +size 315064 diff --git a/lib/python3.10/site-packages/scipy/sparse/csgraph/_min_spanning_tree.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/scipy/sparse/csgraph/_min_spanning_tree.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..6bd5d3c0da39287ddda1d4ab363cbced3b1cacb6 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/sparse/csgraph/_min_spanning_tree.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4face0ee92ed88d1f41b9d65d7e88f54d24940eb4f7bbcbcd23fb3b520c017b8 +size 234080 diff --git a/lib/python3.10/site-packages/scipy/sparse/csgraph/_reordering.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/scipy/sparse/csgraph/_reordering.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..347b0ba7ffef6e54429262eb8401d465d13e0b0a --- /dev/null +++ b/lib/python3.10/site-packages/scipy/sparse/csgraph/_reordering.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:49256a4431f3c369739990af8fad7c81c01047288b6f9ba0c6d92713006c3fa2 +size 293680 diff --git 
a/lib/python3.10/site-packages/scipy/sparse/csgraph/_shortest_path.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/scipy/sparse/csgraph/_shortest_path.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..eb5e867a4e2102f8e800abf0b9321bb586843bae --- /dev/null +++ b/lib/python3.10/site-packages/scipy/sparse/csgraph/_shortest_path.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4f192e32ecb4e9990fbb52648e8886c5f8fd23cde8edec58e3c595c084d7914d +size 518208 diff --git a/lib/python3.10/site-packages/scipy/sparse/csgraph/_tools.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/scipy/sparse/csgraph/_tools.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..6f2860d24a7db19415147906996c29a13b384fc2 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/sparse/csgraph/_tools.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:79360f981b071bfa9afb20bcfcdcf1d7291d20911fb492f2f082a414fd28b8f1 +size 200880 diff --git a/lib/python3.10/site-packages/scipy/sparse/csgraph/_traversal.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/scipy/sparse/csgraph/_traversal.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..9dfee610cae94ef857077ba89ef72fc05bddb1db --- /dev/null +++ b/lib/python3.10/site-packages/scipy/sparse/csgraph/_traversal.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ed054a57baee533acf3fe70f59d634ba811e9a82c162dfa5ee34cc03f1bc3650 +size 592200 diff --git a/lib/python3.10/site-packages/scipy/sparse/linalg/_dsolve/_superlu.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/scipy/sparse/linalg/_dsolve/_superlu.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..f70d25b3cee131305bb6ad04140b9b57e4364944 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/sparse/linalg/_dsolve/_superlu.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d02199521827f3819d20db4c1e45363b24345bf334a17f403abc5d131de531a6 +size 318352 diff --git a/lib/python3.10/site-packages/scipy/sparse/linalg/_eigen/arpack/_arpack.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/scipy/sparse/linalg/_eigen/arpack/_arpack.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..c5ec38b6dc06e513915102fe46d5e765ccb6412e --- /dev/null +++ b/lib/python3.10/site-packages/scipy/sparse/linalg/_eigen/arpack/_arpack.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4f8e0e297b637a485fa4984bb6333b994de6f78fb8caa1dfd4ef1d00290d826d +size 470024 diff --git a/lib/python3.10/site-packages/scipy/sparse/linalg/_propack/_cpropack.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/scipy/sparse/linalg/_propack/_cpropack.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..5769917e181da132fbbecdc44f94fa9d6c734938 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/sparse/linalg/_propack/_cpropack.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c592d78e3e04746a2338f98ccdb98c226c729d64636ade5a13ad4d6d915e6844 +size 121560 diff --git 
a/lib/python3.10/site-packages/scipy/sparse/linalg/_propack/_dpropack.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/scipy/sparse/linalg/_propack/_dpropack.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..41c5e4ef646d3778ece84bafcbe1862449392caf --- /dev/null +++ b/lib/python3.10/site-packages/scipy/sparse/linalg/_propack/_dpropack.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e928176fe844d1b85bec191e75c74aa23b7648bd9abd382d320386740698c80c +size 112872 diff --git a/lib/python3.10/site-packages/scipy/sparse/linalg/_propack/_spropack.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/scipy/sparse/linalg/_propack/_spropack.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..3faaa0c91e2982f4d3e891996812bb01c8c53154 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/sparse/linalg/_propack/_spropack.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fe1f28b29b19c3187f7f5db4a4eff60ccbc8eaccb9c75068ed05b4ddff45b03a +size 112872 diff --git a/lib/python3.10/site-packages/scipy/sparse/linalg/_propack/_zpropack.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/scipy/sparse/linalg/_propack/_zpropack.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..817449aac07efcb13f3d0cdffda0ca2cda2835d2 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/sparse/linalg/_propack/_zpropack.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e724009f261c1601a1cc806fbaf22c9fbcdebd734e84ee78f6dc0df0281263ce +size 121752 diff --git a/lib/python3.10/site-packages/scipy/sparse/linalg/tests/propack_test_data.npz b/lib/python3.10/site-packages/scipy/sparse/linalg/tests/propack_test_data.npz new file mode 100644 index 0000000000000000000000000000000000000000..0bf01015610346655c749ead87a47d5575e2b67b --- /dev/null +++ b/lib/python3.10/site-packages/scipy/sparse/linalg/tests/propack_test_data.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bfe34d9a92353e08f400f3837136e553a8e91d441186913d39b59bf8a627bba3 +size 600350 diff --git a/lib/python3.10/site-packages/scipy/sparse/tests/__pycache__/test_base.cpython-310.pyc b/lib/python3.10/site-packages/scipy/sparse/tests/__pycache__/test_base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6093cc8df65ccfedf2d9f0e4f55e98f0f9c34dbb --- /dev/null +++ b/lib/python3.10/site-packages/scipy/sparse/tests/__pycache__/test_base.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd32f3b09f2ef68b5395c20b5bbe645b0c0a4ca2e59e99eb8f205d50b1a4a04b +size 162340 diff --git a/lib/python3.10/site-packages/scipy/sparse/tests/data/csc_py2.npz b/lib/python3.10/site-packages/scipy/sparse/tests/data/csc_py2.npz new file mode 100644 index 0000000000000000000000000000000000000000..d4459ff2786fabe4bcf4653d880cbf0afd4bfdcf --- /dev/null +++ b/lib/python3.10/site-packages/scipy/sparse/tests/data/csc_py2.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bac27f1a3eb1fdd102dae39b7dd61ce83e82f096388e344e14285071984d01fa +size 846 diff --git a/lib/python3.10/site-packages/scipy/sparse/tests/data/csc_py3.npz b/lib/python3.10/site-packages/scipy/sparse/tests/data/csc_py3.npz new file mode 100644 index 
0000000000000000000000000000000000000000..e40a38584bc4647621601075d946ce46a8e065dc --- /dev/null +++ b/lib/python3.10/site-packages/scipy/sparse/tests/data/csc_py3.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6b1b84315c7077417e720512d086a5a6217c2875b818d27704ae9b7237c69dfe +size 851 diff --git a/lib/python3.10/site-packages/scipy/spatial/tests/data/degenerate_pointset.npz b/lib/python3.10/site-packages/scipy/spatial/tests/data/degenerate_pointset.npz new file mode 100644 index 0000000000000000000000000000000000000000..4f22bd3a3c941a683747944a0f12c7914f4b3f07 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/spatial/tests/data/degenerate_pointset.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:048abc1ddd924bf2d4d1f216015552ed9431f9e99546fbf382768eda58788175 +size 22548 diff --git a/lib/python3.10/site-packages/scipy/spatial/transform/_rotation.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/scipy/spatial/transform/_rotation.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..d51d6c177a1fa84039e5bc61241d2bca7a74b005 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/spatial/transform/_rotation.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6d64264c512d24a1b03d821f1dd1e7bdd7960fdc49fd9ce2ad0733c5188fdbaf +size 935512 diff --git a/lib/python3.10/site-packages/scipy/special/__pycache__/_add_newdocs.cpython-310.pyc b/lib/python3.10/site-packages/scipy/special/__pycache__/_add_newdocs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3aabc6f5f855ee9163095fc63222906d8d3c7c9f --- /dev/null +++ b/lib/python3.10/site-packages/scipy/special/__pycache__/_add_newdocs.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f393310405fc01fb2929bfc5e8530c639cbe70d4d22c82301f18ca56c9f07bd +size 287955 diff --git a/lib/python3.10/site-packages/scipy/special/__pycache__/_basic.cpython-310.pyc b/lib/python3.10/site-packages/scipy/special/__pycache__/_basic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da4e075bea2ef1f96b1f1c0637c038faf99d9f16 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/special/__pycache__/_basic.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3a283546a998c1acd3a5703a858ee9148db4c9bbda3aaae5b3e3a98fe1bea9e8 +size 103062 diff --git a/lib/python3.10/site-packages/scipy/special/tests/__pycache__/test_basic.cpython-310.pyc b/lib/python3.10/site-packages/scipy/special/tests/__pycache__/test_basic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4edef00e7d194be23aeee468f59019dc5739913c --- /dev/null +++ b/lib/python3.10/site-packages/scipy/special/tests/__pycache__/test_basic.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0308814b0fa62e3f6e7e3e2fa4550564133babe336f21cf1d5b9b6e364f5de8e +size 165966 diff --git a/lib/python3.10/site-packages/scipy/special/tests/data/boost.npz b/lib/python3.10/site-packages/scipy/special/tests/data/boost.npz new file mode 100644 index 0000000000000000000000000000000000000000..a3cba7656ee5445c3c94b8695f526de05973cadf --- /dev/null +++ b/lib/python3.10/site-packages/scipy/special/tests/data/boost.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d73ecbbb51654522342ba0470a6263a9684e617c2b8374565fe3a79593f4b231 +size 1270643 diff 
--git a/lib/python3.10/site-packages/scipy/special/tests/data/gsl.npz b/lib/python3.10/site-packages/scipy/special/tests/data/gsl.npz new file mode 100644 index 0000000000000000000000000000000000000000..b090dae17b5b0403f7c4919c46a464a09509aeab --- /dev/null +++ b/lib/python3.10/site-packages/scipy/special/tests/data/gsl.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:acab700208cbb301ee51eb1f512cb1c27e7b4e7533fc5a5f5cd5c5d6aa197dd8 +size 51433 diff --git a/lib/python3.10/site-packages/scipy/special/tests/data/local.npz b/lib/python3.10/site-packages/scipy/special/tests/data/local.npz new file mode 100644 index 0000000000000000000000000000000000000000..7a1d159f5fa6dc3c5521bda8cf3049ee24945857 --- /dev/null +++ b/lib/python3.10/site-packages/scipy/special/tests/data/local.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:102b876c27ec4d2f8041d5ab2fb6dfefc8147021335160f515455e53e06871ff +size 203438 diff --git a/lib/python3.10/site-packages/sklearn/metrics/_dist_metrics.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/sklearn/metrics/_dist_metrics.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..0d3852aaa193c7784d76b5c664021c2868c7fc26 --- /dev/null +++ b/lib/python3.10/site-packages/sklearn/metrics/_dist_metrics.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9f71f87cc46d7d489009c40af5ba2c5ef3e3eaad32948e9fc15fd42a6c9da2bf +size 644840 diff --git a/lib/python3.10/site-packages/sklearn/metrics/_pairwise_fast.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/sklearn/metrics/_pairwise_fast.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..c493477f7d467ac705ae56f26da4060ae7fcd532 --- /dev/null +++ b/lib/python3.10/site-packages/sklearn/metrics/_pairwise_fast.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:28010eb38160ceeff0e948c066a5406095a673271b5903f8e38a4d988c852697 +size 179401 diff --git a/lib/python3.10/site-packages/sklearn/neighbors/_ball_tree.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/sklearn/neighbors/_ball_tree.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..ea31bc7281bd9c41120a0c575ef1958cc5c90603 --- /dev/null +++ b/lib/python3.10/site-packages/sklearn/neighbors/_ball_tree.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b84619b1f985248572aa988c4b00caa37bf536d00a4fec707964e4260b54e1dc +size 664400 diff --git a/lib/python3.10/site-packages/sklearn/neighbors/_kd_tree.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/sklearn/neighbors/_kd_tree.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..022cb54da565de7eba0c7c419532d6addc060c3e --- /dev/null +++ b/lib/python3.10/site-packages/sklearn/neighbors/_kd_tree.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7c14795ff153794bbef5636ffe42a5d554a85192f35fcf21c7d442c9e5357886 +size 674424 diff --git a/lib/python3.10/site-packages/sklearn/neighbors/_quad_tree.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/sklearn/neighbors/_quad_tree.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..9cbf22b61e16af18faeb5719f1e2f9a3d38e4968 --- /dev/null +++ 
b/lib/python3.10/site-packages/sklearn/neighbors/_quad_tree.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c5c95d14e82f10b938b84e535c1960389ef527b8bf817ce59453e3790360199 +size 191440 diff --git a/lib/python3.10/site-packages/sklearn/preprocessing/_csr_polynomial_expansion.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/sklearn/preprocessing/_csr_polynomial_expansion.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..2f35b5b1110d3acabad3a39e268239c28c920bb1 --- /dev/null +++ b/lib/python3.10/site-packages/sklearn/preprocessing/_csr_polynomial_expansion.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9d36a3dab9e6cd9989a455061380e40a2808d766c70c7de475012aab2e2f37c1 +size 359576 diff --git a/lib/python3.10/site-packages/sklearn/preprocessing/_target_encoder_fast.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/sklearn/preprocessing/_target_encoder_fast.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..ee1c5ba5588be8ce810a1a9e313a45678ee7f47e --- /dev/null +++ b/lib/python3.10/site-packages/sklearn/preprocessing/_target_encoder_fast.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e447b691c5fca6d200140ef914cef996d3de089d9c7a62c245adb9218b8dc702 +size 456536 diff --git a/lib/python3.10/site-packages/sklearn/svm/_liblinear.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/sklearn/svm/_liblinear.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..fb001d759aca51379464615aa6e38347ba70d39a --- /dev/null +++ b/lib/python3.10/site-packages/sklearn/svm/_liblinear.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:57c845c6ec5e535e27a638c365e19ca88f9816d83349db6c722375fc13c46660 +size 186840 diff --git a/lib/python3.10/site-packages/sklearn/svm/_libsvm.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/sklearn/svm/_libsvm.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..2794d3d4ec25c6b4d5117dd8ed54be2755b12ce9 --- /dev/null +++ b/lib/python3.10/site-packages/sklearn/svm/_libsvm.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de04b9defe666efd925fcc3174a1eeb5f20c4a0246da4f2ef2622e1fc03855ab +size 449200 diff --git a/lib/python3.10/site-packages/sklearn/svm/_libsvm_sparse.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/sklearn/svm/_libsvm_sparse.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..2eeb21242e78a03f7ef66285ee9712a19ccea27e --- /dev/null +++ b/lib/python3.10/site-packages/sklearn/svm/_libsvm_sparse.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a7aa4dfe2f0a2c1e8ef8483e466f9f9c5acaf03531b05d2c3a29a3467906625c +size 380448 diff --git a/lib/python3.10/site-packages/sklearn/tree/_criterion.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/sklearn/tree/_criterion.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..6eeceb87d652a41c23ec37d42fc4d391d2b678c0 --- /dev/null +++ b/lib/python3.10/site-packages/sklearn/tree/_criterion.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:1cf34eaea2ef7d44a7d24976259adddd42c2677774ed84528290c30a7841873f +size 213496 diff --git a/lib/python3.10/site-packages/sklearn/tree/_partitioner.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/sklearn/tree/_partitioner.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..ee523d2254cc6c51a6763fab064eff97b5f15536 --- /dev/null +++ b/lib/python3.10/site-packages/sklearn/tree/_partitioner.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:33fda9f71ece78f468316eee34b6ee02910e37d39c40f945fe84e124165d9fb9 +size 182352 diff --git a/lib/python3.10/site-packages/sklearn/tree/_splitter.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/sklearn/tree/_splitter.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..19a85f25168169b611c181c411de7e5474057977 --- /dev/null +++ b/lib/python3.10/site-packages/sklearn/tree/_splitter.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:483fb332bbf5b0fdecba665178de52ed2168d51bf81511c2e4ff839707def8dc +size 157232 diff --git a/lib/python3.10/site-packages/sklearn/tree/_tree.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/sklearn/tree/_tree.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..56ed044102c5b1ff102ccbbf855cbe97deb52e3f --- /dev/null +++ b/lib/python3.10/site-packages/sklearn/tree/_tree.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4f7248f1084d29ce93cb63b144e1acf818f28fa9d731682873e9d87271c1b216 +size 512600 diff --git a/lib/python3.10/site-packages/sklearn/tree/_utils.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/sklearn/tree/_utils.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..cb87e74058fae8bc150113b8cb2a1ef92acd90f8 --- /dev/null +++ b/lib/python3.10/site-packages/sklearn/tree/_utils.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b17766d55594c884a6cba41662b0dc8a68bdd045dde77a7a844f79b0273e36e +size 147576 diff --git a/lib/python3.10/site-packages/sklearn/utils/_seq_dataset.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/sklearn/utils/_seq_dataset.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..e46e6fc5e07b020d3d769d21c8e6ea64ef75f6c9 --- /dev/null +++ b/lib/python3.10/site-packages/sklearn/utils/_seq_dataset.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:decfc380afc6b9bda7d0373bc02d956df6717962077b194de4b70e2da79e69be +size 226520 diff --git a/lib/python3.10/site-packages/sklearn/utils/_typedefs.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/sklearn/utils/_typedefs.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..972173fd3707ff6526116d3b71603cfec83618e2 --- /dev/null +++ b/lib/python3.10/site-packages/sklearn/utils/_typedefs.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:982b442417b594c344b0bd60b9aa4572faa9b54b5850869898763781623cebe4 +size 153240 diff --git a/lib/python3.10/site-packages/sklearn/utils/_vector_sentinel.cpython-310-x86_64-linux-gnu.so 
b/lib/python3.10/site-packages/sklearn/utils/_vector_sentinel.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..a9f92cae223bf0caece6b3dbefebfe98e0679177 --- /dev/null +++ b/lib/python3.10/site-packages/sklearn/utils/_vector_sentinel.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:881fd44e66f16615aa20a79a9d22848711ac3257990f1cd322017042d7e74f68 +size 166544 diff --git a/lib/python3.10/site-packages/skvideo/datasets/data/bigbuckbunny.mp4 b/lib/python3.10/site-packages/skvideo/datasets/data/bigbuckbunny.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..f2e5a8fb10a2c6d81ab774cf93e5011df3eefdbe --- /dev/null +++ b/lib/python3.10/site-packages/skvideo/datasets/data/bigbuckbunny.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f25b31f155970c46300934bda4a76cd2f581acab45c49762832ffdfddbcf9fdd +size 1055736 diff --git a/lib/python3.10/site-packages/skvideo/datasets/data/bikes.mp4 b/lib/python3.10/site-packages/skvideo/datasets/data/bikes.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..ed6bd30a8ed47724a9f368049250baa354b4d85c --- /dev/null +++ b/lib/python3.10/site-packages/skvideo/datasets/data/bikes.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91028f9d6c72cc8137d8bd05678bdfcf5ab7c8fd9d7b77de70ce7a3ade257bb5 +size 509868 diff --git a/lib/python3.10/site-packages/skvideo/datasets/data/carphone_pristine.mp4 b/lib/python3.10/site-packages/skvideo/datasets/data/carphone_pristine.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..482e4907e56b31f0acae6b6edcd71509842d00c5 --- /dev/null +++ b/lib/python3.10/site-packages/skvideo/datasets/data/carphone_pristine.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1c4add7838b07b4d65ad9d66e9491758c7dbb6c717490db4b79ecf9ff82bab28 +size 588804 diff --git a/lib/python3.10/site-packages/skvideo/measure/data/niqe_cov_96.pkl b/lib/python3.10/site-packages/skvideo/measure/data/niqe_cov_96.pkl new file mode 100644 index 0000000000000000000000000000000000000000..48a6bd0eccc45d78c869814e474c3d719be2d168 --- /dev/null +++ b/lib/python3.10/site-packages/skvideo/measure/data/niqe_cov_96.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d9a4d73c23bda65cfb395b9bc35dac948687375cb5083d18a5b6d4a3120d3f1d +size 5507 diff --git a/lib/python3.10/site-packages/skvideo/measure/data/niqe_mu_96.pkl b/lib/python3.10/site-packages/skvideo/measure/data/niqe_mu_96.pkl new file mode 100644 index 0000000000000000000000000000000000000000..3cdef24809cf46fa70182d43814508e936756b84 --- /dev/null +++ b/lib/python3.10/site-packages/skvideo/measure/data/niqe_mu_96.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:69f82cec3e39509e41c800ea19d1f9ab69447d793d1c3014dc8ffb46dd069ff0 +size 440 diff --git a/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/perm_groups.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/perm_groups.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9be538b989ba2ae1c3f9596953ad45b1041d0bdd --- /dev/null +++ b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/perm_groups.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:259d06f4b9408a838686665014fbcf351463fdc7708bc71faa8321722047259c +size 153009 diff --git 
a/lib/python3.10/site-packages/sympy/core/__pycache__/expr.cpython-310.pyc b/lib/python3.10/site-packages/sympy/core/__pycache__/expr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c8c101fdde60137d5c9f01ee5e9dcc1f5c29cf3 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/core/__pycache__/expr.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4dbe6b0b31f08ec0dcb49615881f7e956a65ddaa87702947f453bd2675abf8ea +size 115097 diff --git a/lib/python3.10/site-packages/sympy/core/__pycache__/function.cpython-310.pyc b/lib/python3.10/site-packages/sympy/core/__pycache__/function.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d25d6542f3492bc91cb6f8ee890afdd9f80dc584 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/core/__pycache__/function.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2151a83ffcc961415cd255626d22516490d991b95fab26d6d8700e06b30ff4f5 +size 101122 diff --git a/lib/python3.10/site-packages/sympy/core/__pycache__/numbers.cpython-310.pyc b/lib/python3.10/site-packages/sympy/core/__pycache__/numbers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d961b136784d4d3957d10375cfa62788323e6c17 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/core/__pycache__/numbers.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b3d359c9e520212e571b77557bcf3602b7c44264aa1a81ddbac569c87a1d7c9 +size 118182 diff --git a/lib/python3.10/site-packages/sympy/core/tests/__pycache__/test_args.cpython-310.pyc b/lib/python3.10/site-packages/sympy/core/tests/__pycache__/test_args.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..739ab52a9ac18b0dd7eb0dcf317319d8d9b2e147 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/core/tests/__pycache__/test_args.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c9b567d3797310322b0a3f82732f6abfbfe996e8d994e3aebe6627fb2697c31c +size 220839 diff --git a/lib/python3.10/site-packages/sympy/logic/__pycache__/boolalg.cpython-310.pyc b/lib/python3.10/site-packages/sympy/logic/__pycache__/boolalg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e8d0ee3cfc61860d448816d6a922b23f40685d4 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/logic/__pycache__/boolalg.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2eb30a51ecf00bbfa525d500514b88745f44ddb9db1c5d5e63871f7bd75442ed +size 100386 diff --git a/lib/python3.10/site-packages/sympy/matrices/__pycache__/matrixbase.cpython-310.pyc b/lib/python3.10/site-packages/sympy/matrices/__pycache__/matrixbase.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07f49e915ef574cfad1fc6b724b5323b891028f5 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/matrices/__pycache__/matrixbase.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b1cd3aefd1a6e801104e23f6309547813a8d25b985bb97a880e1fa020f7c3945 +size 164922 diff --git a/lib/python3.10/site-packages/sympy/matrices/tests/__pycache__/test_matrices.cpython-310.pyc b/lib/python3.10/site-packages/sympy/matrices/tests/__pycache__/test_matrices.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3241547a68cac11e98685537d0cdf597b980fea --- /dev/null +++ 
b/lib/python3.10/site-packages/sympy/matrices/tests/__pycache__/test_matrices.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:59772a46f54b8c8507e5dfd6ebb98fd1df3e09b4c39589ef6b924eb59d0ced42 +size 144515 diff --git a/lib/python3.10/site-packages/sympy/matrices/tests/__pycache__/test_matrixbase.cpython-310.pyc b/lib/python3.10/site-packages/sympy/matrices/tests/__pycache__/test_matrixbase.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5609ca97d4dc61d77ddc9ca11420a6287f05812f --- /dev/null +++ b/lib/python3.10/site-packages/sympy/matrices/tests/__pycache__/test_matrixbase.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:149ada415d7dbd679b0b4d6270943271253b408f087e5ac53a8c783fc5eed873 +size 153973 diff --git a/lib/python3.10/site-packages/sympy/parsing/latex/_antlr/__pycache__/latexparser.cpython-310.pyc b/lib/python3.10/site-packages/sympy/parsing/latex/_antlr/__pycache__/latexparser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0aa05a86eb128d7be5f678713fb222265532ce48 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/parsing/latex/_antlr/__pycache__/latexparser.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fb96a384c472d027dd04009f3f40b9758f2c38e27308fdc4212198cfb776c171 +size 109854 diff --git a/lib/python3.10/site-packages/sympy/physics/continuum_mechanics/__pycache__/beam.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/continuum_mechanics/__pycache__/beam.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02832e92ec4cf8a556d99869a47f4c4023c3d4e8 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/continuum_mechanics/__pycache__/beam.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:726306e6760c204fe7e91426e8a91668c497360fa7d5b6346f469f412ba6c087 +size 122437 diff --git a/lib/python3.10/site-packages/sympy/physics/control/__pycache__/lti.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/control/__pycache__/lti.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8f8e1415836e96b2a262c0da554633aee15e35c --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/control/__pycache__/lti.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d41d2d12646f22654fb096ea1a9668e3ed2fec0fab5c8e38e9133c34f7d8f877 +size 154543 diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_spin.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_spin.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3c254e54fa934bd7b1c2359600e5a212a21776b --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_spin.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ebf87c0af26392d4e216e5ef7080970f13f1c9762ced7227cc81d231f98ed027 +size 199617 diff --git a/lib/python3.10/site-packages/sympy/polys/__pycache__/polyquinticconst.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/__pycache__/polyquinticconst.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12b860361528ca036e88e56ff4d562af8b61315c --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/__pycache__/polyquinticconst.cpython-310.pyc @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:3dba9ec26671b46a9c440b8e0681300c41045875a100be3640e68bb3faa10af8 +size 132132 diff --git a/lib/python3.10/site-packages/sympy/polys/__pycache__/polytools.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/__pycache__/polytools.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..002aa856df28603809e06b3034244d8a9d1b53d8 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/__pycache__/polytools.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cca90bf7310fb7b8dcf935f0ec3697964fed3aba0125108ff73c12d0f6f0a384 +size 186997 diff --git a/lib/python3.10/site-packages/sympy/polys/benchmarks/__pycache__/bench_solvers.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/benchmarks/__pycache__/bench_solvers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9100765b55488926863c0dafc53c3be53823736a --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/benchmarks/__pycache__/bench_solvers.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ebfcec3a1c745c095d9fdcfe2bf888f44adf8383187dd7f078224c335c4dfceb +size 334887 diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/domainmatrix.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/domainmatrix.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..382ac45eb2b9a3e1512eff5a319d26faa77cd454 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/domainmatrix.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a6cb4b7688ca8aa6ffc4cc1b122a0aa39cfdb973f96c0750d9e4f47b963d5274 +size 110810 diff --git a/lib/python3.10/site-packages/sympy/polys/tests/__pycache__/test_polytools.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/tests/__pycache__/test_polytools.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65bf67f6158bf3d14f0ecb95f3396f83b2c2f929 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/tests/__pycache__/test_polytools.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2cd15a720f7153d61e23b8bd6714ba2f561d59805bd194cad6a4f464b32f8d52 +size 141303 diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/latex.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/latex.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28927ad9c06ac8aa86e500be642629889917d708 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/__pycache__/latex.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c35909cc0581aa62f4002c02132a9235ae33c07f9ab95da4f1a94fe4ecb0d20b +size 119189 diff --git a/lib/python3.10/site-packages/sympy/printing/pretty/tests/__pycache__/test_pretty.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/pretty/tests/__pycache__/test_pretty.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49886c64545210e6e799105e0fa993c02fd3ef2f --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/pretty/tests/__pycache__/test_pretty.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1d21a854352f4c58c651843bc50fc7b68ee55af01cb20143e52563bfab0b6130 +size 154804 diff --git 
a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_latex.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_latex.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bcbd0af61d711bf74fe4aa8b44d0aaa7c1a1214 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_latex.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec7b8bee47fbdd72848ee8aec4fe1f09b6c7cb6cc6f71a4abee0c5cbdb9a3060 +size 127484 diff --git a/lib/python3.10/site-packages/sympy/solvers/__pycache__/solvers.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/__pycache__/solvers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43e4362c3456379c204e929b9177d45e14343c6c --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/__pycache__/solvers.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8c31c1e13be5bdd71aa53093f6b03dfeae219de0524323a505252465e139b49a +size 100435 diff --git a/lib/python3.10/site-packages/sympy/solvers/__pycache__/solveset.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/__pycache__/solveset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eccaddcb607d727dd8c6845e9b872f84cfcfdaa3 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/__pycache__/solveset.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d0a374718d412e1a5687656bf114e2adf5333f886dd9e5b4f1686727a7335605 +size 112020 diff --git a/lib/python3.10/site-packages/sympy/solvers/diophantine/__pycache__/diophantine.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/diophantine/__pycache__/diophantine.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4228712c24ae34ec22784c08b0739969bbb89150 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/diophantine/__pycache__/diophantine.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1f56119789c24cfc0bbb3ed7ca5e088448774dd7b8346d0699112610806a4d1a +size 106669 diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/ode.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/ode.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88211b3c1e721dee9029c441362a37295b1d4638 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/ode.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3df35ecf4e93743182620bbec8b83fc0526de60b5fa57e1fab8e8c928876d13c +size 121505 diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/single.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/single.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1d750848df1db6ae5539ce8d3ba4f3825b99c1e --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/single.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:52f70a82caefae3216fb74bb51f29760f4a6b485e902a321c6756e7bdd558564 +size 105045 diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/tests/__pycache__/test_systems.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/ode/tests/__pycache__/test_systems.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..d1ecc7aaacf657efcdf8f64e6a1446daef447c41 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/ode/tests/__pycache__/test_systems.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:900ff23598cfc334239ca96a1983dec126a019b02da4d4242cffae5e40820616 +size 112073 diff --git a/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_solvers.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_solvers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb8091798d036922aefc60f044090827882a0dd6 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_solvers.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ca9cf0f9b8b876be23d98e4fe0dee89673ace52f4fee9e9c67abb7fa34daaf90 +size 108425 diff --git a/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_solveset.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_solveset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e14764b2daf8d7091b08a939de84170afc8b490 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_solveset.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:53445532e9c7b52fd978383fecc7fe48072a71a724208e7030e6bf1b7b8e0777 +size 137749 diff --git a/lib/python3.10/site-packages/sympy/stats/__pycache__/crv_types.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/__pycache__/crv_types.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e05c6ee4f56c5000bfdef57a99988c4d5e0dac5d --- /dev/null +++ b/lib/python3.10/site-packages/sympy/stats/__pycache__/crv_types.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:784591b50497ee7f4ac665bfa945e72c287dc31df3a503d7dc2d91ade6d39f0e +size 129114 diff --git a/lib/python3.10/site-packages/sympy/tensor/__pycache__/tensor.cpython-310.pyc b/lib/python3.10/site-packages/sympy/tensor/__pycache__/tensor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a79f271088693a5c98266daaf1b6beb9237a0914 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/tensor/__pycache__/tensor.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3108c15dd85d7247dfc8fe82f223918fd05338415a81f99b7f38ea3ec30cb329 +size 153033 diff --git a/lib/python3.10/site-packages/sympy/utilities/tests/__pycache__/test_wester.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/tests/__pycache__/test_wester.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bee63a8e25fe08839c23701ffd79b15421561ca --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/tests/__pycache__/test_wester.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ecb252ead2ba9c7b5a22bb73e0952affc2e2648e9327a4258908ecdca487ea45 +size 113455 diff --git a/lib/python3.10/site-packages/taichi/_lib/core/taichi_python.cpython-310-x86_64-linux-gnu.so b/lib/python3.10/site-packages/taichi/_lib/core/taichi_python.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..1eb48e996c2bcb8abba35a5e5deca73f664d5389 --- /dev/null +++ b/lib/python3.10/site-packages/taichi/_lib/core/taichi_python.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:6cab8023b4b997ea4402f6bb00934495ba1be16570a6b7bad8ed7c545912c6f4 +size 92546200 diff --git a/lib/python3.10/site-packages/taichi/_lib/runtime/runtime_cuda.bc b/lib/python3.10/site-packages/taichi/_lib/runtime/runtime_cuda.bc new file mode 100644 index 0000000000000000000000000000000000000000..64e060e4a2dc4a3772dfe4dc066f9202ce873a7a --- /dev/null +++ b/lib/python3.10/site-packages/taichi/_lib/runtime/runtime_cuda.bc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:08ed89a836668974da04995890500ded1a66e1754d2bde10ae331f58e9d20657 +size 147512 diff --git a/lib/python3.10/site-packages/taichi/_lib/runtime/runtime_x64.bc b/lib/python3.10/site-packages/taichi/_lib/runtime/runtime_x64.bc new file mode 100644 index 0000000000000000000000000000000000000000..0be071a387bfbeb1b6d0736f06c1afd4fb3bb66b --- /dev/null +++ b/lib/python3.10/site-packages/taichi/_lib/runtime/runtime_x64.bc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:39301c58a0f02d09460036439962fb8ef6b8097001fc8cbaec0e0f0553956c0e +size 134380 diff --git a/lib/python3.10/site-packages/taichi/_lib/runtime/slim_libdevice.10.bc b/lib/python3.10/site-packages/taichi/_lib/runtime/slim_libdevice.10.bc new file mode 100644 index 0000000000000000000000000000000000000000..3af755e6e5ea1be58584dee6bf57265e80811226 --- /dev/null +++ b/lib/python3.10/site-packages/taichi/_lib/runtime/slim_libdevice.10.bc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f5d943157315e2e20020acbed7253ba69355fe831d2070d1f58a8e4ffeec7dac +size 171892 diff --git a/lib/python3.10/site-packages/taichi/assets/Go-Regular.ttf b/lib/python3.10/site-packages/taichi/assets/Go-Regular.ttf new file mode 100644 index 0000000000000000000000000000000000000000..9e6676d98af6024113628beec21089a277bdd83f --- /dev/null +++ b/lib/python3.10/site-packages/taichi/assets/Go-Regular.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4bb829593136416c6a39ecdc45482e052f75ac374e6459b9af68d4fba279396c +size 134988 diff --git a/lib/python3.10/site-packages/taichi/assets/static/imgs/ti_gallery.png b/lib/python3.10/site-packages/taichi/assets/static/imgs/ti_gallery.png new file mode 100644 index 0000000000000000000000000000000000000000..6384aeb73cef1f21f1071fffe85040776b4f3d2f --- /dev/null +++ b/lib/python3.10/site-packages/taichi/assets/static/imgs/ti_gallery.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ace7a1f099545dbf997644da7a2ddc57225764158ea9c2412dd9dd00e45476d5 +size 263871 diff --git a/lib/python3.10/site-packages/taichi/examples/autodiff/diff_sph/fc1_pretrained.pkl b/lib/python3.10/site-packages/taichi/examples/autodiff/diff_sph/fc1_pretrained.pkl new file mode 100644 index 0000000000000000000000000000000000000000..fb673489716b266647304cbebfa0ceaa3028e00e --- /dev/null +++ b/lib/python3.10/site-packages/taichi/examples/autodiff/diff_sph/fc1_pretrained.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:38b579cfa26e0d6b84f86fb7b934a3fd25e4e2430d7072bdadd760068c3c99ea +size 711 diff --git a/lib/python3.10/site-packages/taichi/examples/autodiff/diff_sph/fc2_pretrained.pkl b/lib/python3.10/site-packages/taichi/examples/autodiff/diff_sph/fc2_pretrained.pkl new file mode 100644 index 0000000000000000000000000000000000000000..f7fab4758be61e2713acc4fa9f950b5ec7c9592b --- /dev/null +++ b/lib/python3.10/site-packages/taichi/examples/autodiff/diff_sph/fc2_pretrained.pkl @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:fa72b901b02e2b8a3e6ea968ef5b6f2fb2c7b004cfeb40e010832ce6d33633d0 +size 2375 diff --git a/lib/python3.10/site-packages/tensorboard/_vendor/html5lib/__pycache__/constants.cpython-310.pyc b/lib/python3.10/site-packages/tensorboard/_vendor/html5lib/__pycache__/constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de36a7e7b1aaf7005cdaac8816457fb7f704bf0f --- /dev/null +++ b/lib/python3.10/site-packages/tensorboard/_vendor/html5lib/__pycache__/constants.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:99930a36e537eb4165d7b280ea9103b64d9bf25b663c910011936f8aec946eb9 +size 161288 diff --git a/lib/python3.10/site-packages/tensorboard_data_server/bin/server b/lib/python3.10/site-packages/tensorboard_data_server/bin/server new file mode 100644 index 0000000000000000000000000000000000000000..038f40647d5cb5e9066266f5fbaf437ff9d53f81 --- /dev/null +++ b/lib/python3.10/site-packages/tensorboard_data_server/bin/server @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:756d02d60bc64c9d6218b500eb7fb02910020cc43c6baec48104da2dd957488e +size 20822272 diff --git a/lib/python3.10/site-packages/torch/_C/_aoti.pyi b/lib/python3.10/site-packages/torch/_C/_aoti.pyi new file mode 100644 index 0000000000000000000000000000000000000000..a5e782fe62123e6e6c141c80a10e70fd991ee654 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_C/_aoti.pyi @@ -0,0 +1,20 @@ +from ctypes import c_void_p + +from torch import Tensor + +# Defined in torch/csrc/inductor/aoti_runner/pybind.cpp + +# Tensor to AtenTensorHandle +def unsafe_alloc_void_ptrs_from_tensors(tensors: list[Tensor]) -> list[c_void_p]: ... +def unsafe_alloc_void_ptr_from_tensor(tensor: Tensor) -> c_void_p: ... + +# AtenTensorHandle to Tensor +def alloc_tensors_by_stealing_from_void_ptrs( + handles: list[c_void_p], +) -> list[Tensor]: ... +def alloc_tensor_by_stealing_from_void_ptr( + handle: c_void_p, +) -> Tensor: ... + +class AOTIModelContainerRunnerCpu: ... +class AOTIModelContainerRunnerCuda: ... diff --git a/lib/python3.10/site-packages/torch/_C/_autograd.pyi b/lib/python3.10/site-packages/torch/_C/_autograd.pyi new file mode 100644 index 0000000000000000000000000000000000000000..f756828ed6c9bf23a901e7416474fc86215276f4 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_C/_autograd.pyi @@ -0,0 +1,135 @@ +# mypy: allow-untyped-defs +from enum import Enum +from typing import Any, Callable + +import torch +from torch._C._profiler import ( + _ProfilerEvent, + ActiveProfilerType, + ProfilerActivity, + ProfilerConfig, +) + +# Defined in torch/csrc/autograd/init.cpp + +class DeviceType(Enum): + CPU = ... + CUDA = ... + XPU = ... + MKLDNN = ... + OPENGL = ... + OPENCL = ... + IDEEP = ... + HIP = ... + FPGA = ... + MAIA = ... + XLA = ... + MTIA = ... + MPS = ... + HPU = ... + Meta = ... + Vulkan = ... + Metal = ... + PrivateUse1 = ... + +class ProfilerEvent: + def cpu_elapsed_us(self, other: ProfilerEvent) -> float: ... + def cpu_memory_usage(self) -> int: ... + def cuda_elapsed_us(self, other: ProfilerEvent) -> float: ... + def privateuse1_elapsed_us(self, other: ProfilerEvent) -> float: ... + def cuda_memory_usage(self) -> int: ... + def device(self) -> int: ... + def handle(self) -> int: ... + def has_cuda(self) -> bool: ... + def is_remote(self) -> bool: ... + def kind(self) -> int: ... + def name(self) -> str: ... + def node_id(self) -> int: ... + def sequence_nr(self) -> int: ... 
+ def shapes(self) -> list[list[int]]: ... + def thread_id(self) -> int: ... + def flops(self) -> float: ... + def is_async(self) -> bool: ... + +class _KinetoEvent: + def name(self) -> str: ... + def device_index(self) -> int: ... + def device_resource_id(self) -> int: ... + def start_ns(self) -> int: ... + def end_ns(self) -> int: ... + def duration_ns(self) -> int: ... + def is_async(self) -> bool: ... + def linked_correlation_id(self) -> int: ... + def shapes(self) -> list[list[int]]: ... + def dtypes(self) -> list[str]: ... + def concrete_inputs(self) -> list[Any]: ... + def kwinputs(self) -> dict[str, Any]: ... + def device_type(self) -> DeviceType: ... + def start_thread_id(self) -> int: ... + def end_thread_id(self) -> int: ... + def correlation_id(self) -> int: ... + def fwd_thread_id(self) -> int: ... + def stack(self) -> list[str]: ... + def scope(self) -> int: ... + def sequence_nr(self) -> int: ... + def flops(self) -> int: ... + def cuda_elapsed_us(self) -> int: ... + def privateuse1_elapsed_us(self) -> int: ... + def is_user_annotation(self) -> bool: ... + +class _ProfilerResult: + def events(self) -> list[_KinetoEvent]: ... + def legacy_events(self) -> list[list[ProfilerEvent]]: ... + def save(self, path: str) -> None: ... + def experimental_event_tree(self) -> list[_ProfilerEvent]: ... + def trace_start_ns(self) -> int: ... + +class SavedTensor: ... + +def _enable_profiler( + config: ProfilerConfig, + activities: set[ProfilerActivity], +) -> None: ... +def _prepare_profiler( + config: ProfilerConfig, + activities: set[ProfilerActivity], +) -> None: ... +def _toggle_collection_dynamic( + enable: bool, + activities: set[ProfilerActivity], +) -> None: ... +def _disable_profiler() -> _ProfilerResult: ... +def _profiler_enabled() -> bool: ... +def _add_metadata_json(key: str, value: str) -> None: ... +def _kineto_step() -> None: ... +def _get_current_graph_task_keep_graph() -> bool: ... +def _get_sequence_nr() -> int: ... +def kineto_available() -> bool: ... +def _record_function_with_args_enter(name: str, *args) -> torch.Tensor: ... +def _record_function_with_args_exit(handle: torch.Tensor) -> None: ... +def _supported_activities() -> set[ProfilerActivity]: ... +def _enable_record_function(enable: bool) -> None: ... +def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ... +def _push_saved_tensors_default_hooks( + pack_hook: Callable[[torch.Tensor], Any], + unpack_hook: Callable[[Any], torch.Tensor], +) -> None: ... +def _pop_saved_tensors_default_hooks() -> None: ... +def _unsafe_set_version_counter(t: torch.Tensor, prev_version: int) -> None: ... +def _enable_profiler_legacy(config: ProfilerConfig) -> None: ... +def _disable_profiler_legacy() -> list[list[ProfilerEvent]]: ... +def _profiler_type() -> ActiveProfilerType: ... +def _saved_tensors_hooks_enable() -> None: ... +def _saved_tensors_hooks_disable(message: str) -> None: ... +def _saved_tensors_hooks_get_disabled_error_message() -> str | None: ... +def _saved_tensors_hooks_set_tracing(is_tracing: bool) -> bool: ... + +class CreationMeta(Enum): + DEFAULT = ... + IN_CUSTOM_FUNCTION = ... + MULTI_OUTPUT_NODE = ... + NO_GRAD_MODE = ... + INFERENCE_MODE = ... + +def _set_creation_meta(t: torch.Tensor, creation_meta: CreationMeta) -> None: ... +def _get_creation_meta(t: torch.Tensor) -> CreationMeta: ... 
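The `_autograd.pyi` stub above declares the private profiler bindings (`_enable_profiler`, `_disable_profiler`, `_kineto_step`, and the saved-tensor hook helpers); these are normally driven through the public `torch.profiler` API rather than called directly. A minimal, illustrative sketch of that public path, assuming a CPU-only run (this sketch is not part of the vendored files):

    import torch
    from torch.profiler import ProfilerActivity, profile

    # Illustrative only: profile a small CPU matmul and print a summary table.
    # The profile() context manager is the supported entry point that ultimately
    # drives the private _enable_profiler/_disable_profiler bindings declared above.
    x = torch.randn(256, 256)
    with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
        y = x @ x
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))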
diff --git a/lib/python3.10/site-packages/torch/_C/_cpu.pyi b/lib/python3.10/site-packages/torch/_C/_cpu.pyi new file mode 100644 index 0000000000000000000000000000000000000000..6593222a119f4dfbd532859009ecb24ea4522323 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_C/_cpu.pyi @@ -0,0 +1,12 @@ +from torch.types import _bool, _int + +# Defined in torch/csrc/cpu/Module.cpp + +def _is_avx2_supported() -> _bool: ... +def _is_avx512_supported() -> _bool: ... +def _is_avx512_vnni_supported() -> _bool: ... +def _is_avx512_bf16_supported() -> _bool: ... +def _is_amx_tile_supported() -> _bool: ... +def _init_amx() -> _bool: ... +def _L1d_cache_size() -> _int: ... +def _L2_cache_size() -> _int: ... diff --git a/lib/python3.10/site-packages/torch/_C/_cudnn.pyi b/lib/python3.10/site-packages/torch/_C/_cudnn.pyi new file mode 100644 index 0000000000000000000000000000000000000000..689c984b9d7de1ca98329495223dcb0a13a54f4e --- /dev/null +++ b/lib/python3.10/site-packages/torch/_C/_cudnn.pyi @@ -0,0 +1,17 @@ +from enum import Enum + +from torch.types import _bool, Tuple + +# Defined in torch/csrc/cuda/shared/cudnn.cpp +is_cuda: _bool + +def getRuntimeVersion() -> Tuple[int, int, int]: ... +def getCompileVersion() -> Tuple[int, int, int]: ... +def getVersionInt() -> int: ... + +class RNNMode(int, Enum): + value: int + rnn_relu = ... + rnn_tanh = ... + lstm = ... + gru = ... diff --git a/lib/python3.10/site-packages/torch/_C/_cusparselt.pyi b/lib/python3.10/site-packages/torch/_C/_cusparselt.pyi new file mode 100644 index 0000000000000000000000000000000000000000..a1c4bbb217777d50b9a3384da11d9992db5e18f0 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_C/_cusparselt.pyi @@ -0,0 +1 @@ +def getVersionInt() -> int: ... diff --git a/lib/python3.10/site-packages/torch/__pycache__/_appdirs.cpython-310.pyc b/lib/python3.10/site-packages/torch/__pycache__/_appdirs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9de3256a7c01fe32fd3c590c6d18252a2c044111 Binary files /dev/null and b/lib/python3.10/site-packages/torch/__pycache__/_appdirs.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torch/__pycache__/_classes.cpython-310.pyc b/lib/python3.10/site-packages/torch/__pycache__/_classes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17227b5673b790dbddd1c96d299a31475077269c Binary files /dev/null and b/lib/python3.10/site-packages/torch/__pycache__/_classes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torch/__pycache__/_compile.cpython-310.pyc b/lib/python3.10/site-packages/torch/__pycache__/_compile.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1f42763502387c4280cf19d8f33240cf6ffafb6 Binary files /dev/null and b/lib/python3.10/site-packages/torch/__pycache__/_compile.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torch/__pycache__/_custom_ops.cpython-310.pyc b/lib/python3.10/site-packages/torch/__pycache__/_custom_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..352c6af19b930fb0cf0f1f3c3af6481af1a8ae00 Binary files /dev/null and b/lib/python3.10/site-packages/torch/__pycache__/_custom_ops.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torch/__pycache__/_deploy.cpython-310.pyc b/lib/python3.10/site-packages/torch/__pycache__/_deploy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd49afdddedc20c70a89f290b5b2a5f4e1daff00 Binary files /dev/null and 
b/lib/python3.10/site-packages/torch/__pycache__/_deploy.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torch/__pycache__/_guards.cpython-310.pyc b/lib/python3.10/site-packages/torch/__pycache__/_guards.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0796589315774b96bd880a9a5010b3b5726d9c8 Binary files /dev/null and b/lib/python3.10/site-packages/torch/__pycache__/_guards.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torch/__pycache__/_jit_internal.cpython-310.pyc b/lib/python3.10/site-packages/torch/__pycache__/_jit_internal.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c89e2bdae03ad9db2f07eff4aad9a006c12e2db Binary files /dev/null and b/lib/python3.10/site-packages/torch/__pycache__/_jit_internal.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torch/__pycache__/_linalg_utils.cpython-310.pyc b/lib/python3.10/site-packages/torch/__pycache__/_linalg_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7543c008c4e7484b450207637347648972c07e8 Binary files /dev/null and b/lib/python3.10/site-packages/torch/__pycache__/_linalg_utils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torch/__pycache__/_lobpcg.cpython-310.pyc b/lib/python3.10/site-packages/torch/__pycache__/_lobpcg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6127ec47021e30d7c8d12a966a60dfebe8a996f8 Binary files /dev/null and b/lib/python3.10/site-packages/torch/__pycache__/_lobpcg.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torch/__pycache__/storage.cpython-310.pyc b/lib/python3.10/site-packages/torch/__pycache__/storage.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bb0fec9c9abcf4b94ada07f45fe4bcc33788a8e Binary files /dev/null and b/lib/python3.10/site-packages/torch/__pycache__/storage.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torch/__pycache__/torch_version.cpython-310.pyc b/lib/python3.10/site-packages/torch/__pycache__/torch_version.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e713e17dd67c10a31dfeae0715d8bec55592dec3 Binary files /dev/null and b/lib/python3.10/site-packages/torch/__pycache__/torch_version.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torch/__pycache__/types.cpython-310.pyc b/lib/python3.10/site-packages/torch/__pycache__/types.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65d405ab9564b6d8af8ce0478ef14f0d7a94cd38 Binary files /dev/null and b/lib/python3.10/site-packages/torch/__pycache__/types.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torch/__pycache__/version.cpython-310.pyc b/lib/python3.10/site-packages/torch/__pycache__/version.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff1b96490e0094f9003748245bf8bf75866a1649 Binary files /dev/null and b/lib/python3.10/site-packages/torch/__pycache__/version.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torch/_custom_op/__init__.py b/lib/python3.10/site-packages/torch/_custom_op/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/torch/_custom_op/autograd.py b/lib/python3.10/site-packages/torch/_custom_op/autograd.py new file mode 100644 index 
0000000000000000000000000000000000000000..35727197d03c1c4c1e00584d2c25e1830d6bcbd8 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_custom_op/autograd.py @@ -0,0 +1,275 @@ +# mypy: allow-untyped-defs +import torch +import torch.utils._pytree as pytree +from collections import namedtuple +import functools + + +# NOTE [CustomOp autograd kernel indirection] +# We register `inner` as the autograd kernel for this custom_op. +# `inner` either calls the autograd formula registered by the user, +# or goes into an `autograd_not_implemented` kernel. +# +# The reason why this indirection exists is +# so that we can swap out the autograd kernel (the PyTorch dispatcher +# doesn't actually allow us to do this). By default, we want +# the `autograd_not_implemented` behavior, but then the user may come +# and register something that is actually a backward formula +def autograd_kernel_indirection(custom_op): + autograd_fallback = autograd_not_implemented(custom_op) + + def inner(*args, **kwargs): + if custom_op._has_impl('autograd'): + kernel = custom_op._get_impl('autograd').func + return kernel(*args, **kwargs) + # As explained in NOTE ["backward", "save_for_backward", and "autograd"], + # after the user gives us "backward" and "save_for_backward", we generate + # the "autograd" impl. If the user only provided one, then we tell + # the user they've done something wrong. + if custom_op._has_impl('save_for_backward') or custom_op._has_impl('backward'): + missing = ( + 'save_for_backward' if custom_op._has_impl('backward') + else 'backward' + ) + found = 'save_for_backward' if missing == 'backward' else 'backward' + loc = custom_op._get_impl(found).location + raise RuntimeError( + f"We found a '{found}' registration for {custom_op} at " + f"{loc} but were unable to find a '{missing}' registration. " + f"To use the CustomOp API to register a backward formula, " + f"please provide us both a backward function and a " + f"'save for backward' function via `impl_backward` and " + f"`impl_save_for_backward` respectively.") + return autograd_fallback(*args, **kwargs) + return inner + + +# TODO(#101191): Use the actual C++ autograd not implemented fallback, +# or change the default autograd fallback to the autograd not implemented fallback. +def autograd_not_implemented(custom_op): + def kernel(*args, **kwargs): + if torch.is_grad_enabled() and pytree.tree_any( + lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs) + ): + raise RuntimeError("Autograd has not been implemented for operator") + with torch._C._AutoDispatchBelowAutograd(): + return custom_op(*args, **kwargs) + return kernel + + +def mark_non_differentiable(ctx, output, output_differentiability): + # Output types are restricted to be: + # - Tensor + # - Tensor[] + # - int, bool, Scalar, float + # See _check_can_register_backward + if output_differentiability is not None: + if not isinstance(output, tuple): + tuple_output = (output,) + else: + tuple_output = output # type: ignore[assignment] + assert len(output_differentiability) == len(tuple_output) + non_differentiable_tensors = [] + for idx, (differentiable, out) in enumerate(zip(output_differentiability, tuple_output)): + if isinstance(out, torch.Tensor): + if not differentiable: + non_differentiable_tensors.append(out) + continue + if isinstance(out, list): + if not differentiable: + non_differentiable_tensors.extend(out) + continue + if differentiable: + raise RuntimeError( + f"With output_differentiability={output_differentiability}. 
" + f"At idx {idx}, we received an object of type {type(out)} that " + f"is not a Tensor, so it cannot have be marked as differentiable in " + f"output_differentiability.") + if non_differentiable_tensors: + ctx.mark_non_differentiable(*non_differentiable_tensors) + + +def construct_autograd_kernel( + schema, + output_differentiability, + custom_op, + op_overload, + save_for_backward_fn, + backward_fn): + + def apply(*args): + flat_args, spec = pytree.tree_flatten(args) + out_spec = None + + def forward(ctx, *flat_args): + ctx.set_materialize_grads(True) + args = pytree.tree_unflatten(list(flat_args), spec) + with torch._C._AutoDispatchBelowAutograd(): + output = op_overload(*args) + + # We use the info about args to give better error messages in backward + args_info = namedtuple_args( + schema, pytree.tree_map(type, args)) + + save_for_backward_fn_inputs = namedtuple_args(schema, args) + to_save = save_for_backward_fn(save_for_backward_fn_inputs, output) + + save_pytree_for_backward(ctx, (to_save, args_info)) + mark_non_differentiable(ctx, output, output_differentiability) + + nonlocal out_spec + flat_output, out_spec = pytree.tree_flatten(output) + return tuple(flat_output) + + def backward(ctx, *flat_grad_output): + assert out_spec is not None + grads = pytree.tree_unflatten(list(flat_grad_output), out_spec) + saved, args_info = unpack_saved(ctx) + # There is nothing on the ctx object for now, it is just there so + # that we can add additional things in the future. + inner_ctx = object() + if not isinstance(grads, tuple): + grads = (grads,) + grad_inputs_dict = backward_fn(inner_ctx, saved, *grads) + + # Massage the grad_inputs_dict to a form acceptable by + # autograd.Function. + validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info) + return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info) + + generated_cls = gen_autograd_function( + custom_op._opname + '_customop', forward, backward) + + flat_output = generated_cls.apply(*flat_args) + assert out_spec is not None + return pytree.tree_unflatten(list(flat_output), out_spec) + return apply + + +def gen_autograd_function(name, forward, backward): + generated_cls = type( + name, + (torch.autograd.Function,), + { + 'forward': staticmethod(forward), + 'backward': staticmethod(backward), + } + ) + return generated_cls + + +@functools.lru_cache +def namedtuple_args_cls(schema): + attribs = [arg.name for arg in schema.arguments.flat_all] + name = str(schema.name) + "_args" + # mypy doesn't support dynamic namedtuple name + tuple_cls = namedtuple(name, attribs) # type: ignore[misc] + return tuple_cls + + +def namedtuple_args(schema, args): + assert isinstance(args, tuple) + tuple_cls = namedtuple_args_cls(schema) + return tuple_cls(*args) + + +def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info): + def error(what): + backward = forward_op._get_impl('backward') + raise RuntimeError( + f"In the backward function defined for {forward_op} at " + f"{backward.location} using the CustomOp API, {what}") + + if not isinstance(grad_inputs_dict, dict): + error(f"expected the output of the backward function to be a dict but " + f"got {type(grad_inputs_dict)}") + + expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all + if arg.type.is_tensor_like()} + actual_keys = grad_inputs_dict.keys() + if expected_keys != actual_keys: + error(f"expected the returned grad_input dict to have keys " + f"{expected_keys} but got {actual_keys}. 
The backward " + f"function must return a gradient (can be None) for each arg " + f"to the CustomOp that may be a Tensor or Sequence[Tensor]. " + f"Args declared to be non-Tensor-like types should not appear " + f"in the grad_input dict") + + for name, grad in grad_inputs_dict.items(): + arg_info = getattr(args_info, name) + + if isinstance(arg_info, list): + if not isinstance(grad, (tuple, list)): + error(f"for input '{name}' expected the grad_input dict to " + f"hold a list of gradients but got object of type " + f"{type(grad)}.") + if not len(grad) == len(arg_info): + error(f"for input '{name}' expected the grad_input dict to " + f"hold a list of {len(arg_info)} gradients but got " + f"{len(grad)}") + for idx, (g, info) in enumerate(zip(grad, arg_info)): + if g is None: + continue + if not isinstance(g, torch.Tensor): + error(f"for input '{name}' expected the grad_input dict to " + f"hold a list of None or Tensor gradients but got " + f"object of {type(g)} at index {idx}") + if not issubclass(info, torch.Tensor): + error(f"for input '{name}', got a Tensor as the gradient " + f"for the {idx}-th value but expected None because " + f"the {idx}-th value was not a Tensor (it was " + f"type {arg_info}") + continue + + if grad is None: + continue + if not isinstance(grad, torch.Tensor): + error(f"got object of type {type(grad)} as the gradient for input " + f"'{name}', " + f"but expected the gradient to be either None or a Tensor") + if not issubclass(arg_info, torch.Tensor): + error(f"got a Tensor as the gradient for input '{name}' but " + f"expected None as the gradient because input '{name}' " + f"was not a Tensor (it was type {arg_info}).") + + +def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info): + result = [] + for name, arg_info in args_info._asdict().items(): + if name not in grad_inputs_dict: + result.append(pytree.tree_map(lambda x: None, arg_info)) + continue + result.append(grad_inputs_dict[name]) + return tuple(pytree.tree_leaves(result)) + +# Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it. +# autograd.Function prefers that users use ctx.save_for_backward to +# save Tensors (to avoid reference cycles) and for non-Tensors to go onto the +# ctx object. 
+def save_pytree_for_backward(ctx, stuff): + flat_stuff, spec = pytree.tree_flatten(stuff) + num_elts = len(flat_stuff) + tensor_idxs = [idx for idx, thing in enumerate(flat_stuff) + if isinstance(thing, torch.Tensor)] + non_tensor_idxs = [idx for idx, thing in enumerate(flat_stuff) + if not isinstance(thing, torch.Tensor)] + tensors = [thing for thing in flat_stuff if isinstance(thing, torch.Tensor)] + non_tensors = [thing for thing in flat_stuff if not isinstance(thing, torch.Tensor)] + + ctx.spec = spec + ctx.num_elts = num_elts + ctx.save_for_backward(*tensors) + ctx.tensor_idxs = tensor_idxs + ctx.saved_non_tensors = non_tensors + ctx.non_tensor_idxs = non_tensor_idxs + + +# Inverse operation to save_pytree_for_backward +def unpack_saved(ctx): + flat_stuff = [None] * ctx.num_elts + for tensor, idx in zip(ctx.saved_tensors, ctx.tensor_idxs): + flat_stuff[idx] = tensor + for non_tensor, idx in zip(ctx.saved_non_tensors, ctx.non_tensor_idxs): + flat_stuff[idx] = non_tensor + stuff = pytree.tree_unflatten(flat_stuff, ctx.spec) + return stuff diff --git a/lib/python3.10/site-packages/torch/_custom_op/functional.py b/lib/python3.10/site-packages/torch/_custom_op/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..57ff351e2e2d53a217008e793c57b1e3867ebe54 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_custom_op/functional.py @@ -0,0 +1,188 @@ +# mypy: allow-untyped-defs +import weakref + +import torch +import torch.utils._pytree as pytree +from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet +from torch._ops import OpOverload +from torch.library import Library +from torchgen.model import ( + BaseTy, + BaseType, + FunctionSchema, + OperatorName, + OptionalType, + SchemaKind, +) + +from .autograd import autograd_not_implemented + + +def register_functional_op( + lib: Library, + new_op_name: str, + mutable_op: OpOverload, +) -> None: + """Given a mutable operator, registers the functional variant. + + This API also correctly links the functional variant with the mutable + operator for the purposes of functionalization. + + All of the new registrations are performed on the ``lib`` passed in. + + Arguments: + lib (Library): Should be a torch.library.Library object that has + the same namespace as ``mutable_op``'s namespace. + lib will be used to register the new functional op as well + as a functionalization kernel for the ``mutable_op`` + If you don't have a library handy, use + ``torch.library.Library(ns, 'FRAGMENT')`` to construct one. + new_op_name (str): The name of the functional operator (without the + namespace). If no namespace, the new functional variant will be + accessible under ``torch.ops.{lib.ns}.new_op_name``. + mutable_op (OpOverload): The mutable custom operator. Note + that you may need to add a `.default` to it, like + `torch.ops.aten.abs_.default`. + + """ + validate(mutable_op) + schema = functional_schema(new_op_name, mutable_op) + lib.define(schema) + + functional_impl = construct_functional_impl(mutable_op) + lib.impl(new_op_name, functional_impl, 'CompositeExplicitAutograd') + + functional_op = getattr(getattr(torch.ops, lib.ns), new_op_name).default + + # There's no easy way for us to generate the autograd kernel, so we + # use autograd_not_implemented. Also, this makes it so that the user + # is unable to register an autograd formula themselves. This shouldn't + # be a problem if the user doesn't use the functional op direclty + # in their program, but we may need to revist this in the future. 
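    # Editor's illustration (not part of the original source): putting the pieces of
    # register_functional_op together, a hypothetical call looks roughly like
    #
    #     lib = torch.library.Library("mylib", "FRAGMENT")
    #     lib.define("my_sin_inplace(Tensor(a!) x) -> ()")
    #     ...  # register CPU/CUDA kernels for the mutable op
    #     register_functional_op(lib, "my_sin",
    #                            torch.ops.mylib.my_sin_inplace.default)
    #
    # The namespace "mylib" and the op names are made up. Afterwards the functional
    # variant is available as torch.ops.mylib.my_sin.default; it computes into fresh
    # tensors instead of mutating its input (see construct_functional_impl below)
    # and is linked to the mutable op for functionalization.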
+ lib.impl(new_op_name, autograd_not_implemented(functional_op), 'Autograd') + + f_kernel = construct_functionalization_kernel(weakref.proxy(mutable_op), functional_op) + + lib.impl(mutable_op, f_kernel, 'Functionalize') + + +def construct_functional_impl(mutable_op): + def functional_impl(*args): + # Strategy: + # - clone args that would have been mutated + # - run mutable_op + # - return the cloned args as additional outputs + new_args = [] + extra_rets = [] + for is_write, arg in zip(mutable_args(mutable_op), args): + if is_write: + cloned = arg.clone() if arg is not None else None + new_args.append(cloned) + extra_rets.append(cloned) + else: + new_args.append(arg) + result = mutable_op(*new_args) + if result is None: + return tuple(extra_rets) + if isinstance(result, tuple): + return (*result, *extra_rets) + return (result, *extra_rets) + return functional_impl + + +def construct_functionalization_kernel(mutable_op, functional_op): + def kernel(*args): + # There's nothing to be functionalized! + # We can still end up here because DispatchKey::Functionalize is a mode key + if pytree.tree_all_only(torch.Tensor, lambda x: not torch._is_functional_tensor(x), args): + with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)): + return mutable_op(*args) + + # NB: This differs from the codegen -- codegen handles cases where there + # are mixed FunctionalTensorWrapper and non-FunctionalTensorWrapper. + # This only really matters for XLA (mixed CPU-XLA tensors) and + # running functionalization without the PT2 stack (which guarantees to us that + # all tensors are FunctionalTensorWrapper). + if not pytree.tree_all_only(torch.Tensor, torch._is_functional_tensor, args): + raise RuntimeError("{mutable_op}: expected all args to be FunctionalTensorWrapper") + + unwrapped_args = [] + for arg in args: + if isinstance(arg, torch.Tensor) and torch._is_functional_tensor(arg): + torch._sync(arg) + unwrapped = torch._from_functional_tensor(arg) + unwrapped_args.append(unwrapped) + else: + unwrapped_args.append(arg) + + with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)): + output = functional_op(*unwrapped_args) + + num_actual_output = len(mutable_op._schema.returns) + actual_output = pytree.tree_map( + torch._to_functional_tensor, output[:num_actual_output]) + + new_values_to_propagate = output[num_actual_output:] + inputs_to_replace = [arg for is_write, arg in zip(mutable_args(mutable_op), args) + if is_write] + assert len(new_values_to_propagate) == len(inputs_to_replace) + for new_value, arg in zip(new_values_to_propagate, inputs_to_replace): + if (arg is None and new_value is None) or (arg is not None and new_value is not None): + continue + torch._C._propagate_xla_data(arg, new_value) + torch._C._replace_(arg, new_value) + torch._C._commit_update(arg) + torch._sync(arg) + + if len(actual_output) == 1: + return actual_output[0] + elif len(actual_output) == 0: + return None + return actual_output + + return kernel + + +def validate(mutable_op: OpOverload): + if not isinstance(mutable_op, OpOverload): + raise TypeError( + f"register_functional_op(mutable_op): expected mutable_op to be instance of " + f"OpOverload but got {type(mutable_op)}") + + # There are generally three types of "in-place" or "mutable" ops. 
+ # Each of them have their own conventions: + # - inplace (first input modified in-place and returned as only output) + # - out= (some args modified in-place and returned as outputs) + # - mutable (some args modified in-place but none of those returned as outputs) + # In theory we can support all three, but we'll just support the last + # option right now for simplicity. + schema = FunctionSchema.parse(str(mutable_op._schema)) + if not schema.kind() == SchemaKind.mutable: + raise RuntimeError("Expected op to be mutable (as opposed to functional, inplace or out)") + for ret in schema.returns: + # construct_functionalization_kernel assumes this for simplicity + if ret.annotation is not None: + raise NotImplementedError( + "NYI: register_functional_op(op) where op returns a mutated or aliased value. " + "Please file an issue (and as a workaround, modify your operator to " + "not return the mutated value or aliases)") + for arg in schema.arguments.flat_all: + # construct_functionalization_kernel assumes this for simplicity + if arg.type.is_tensor_like() and ( + arg.type != BaseType(BaseTy.Tensor) + and arg.type != OptionalType(BaseType(BaseTy.Tensor)) + ): + raise NotImplementedError( + "NYI: register_functional_op(op) where op has a List[Tensor] input." + "Please file an issue.") + + +def functional_schema(new_op_name, op: OpOverload): + schema = FunctionSchema.parse(str(op._schema)) + schema = schema.signature().with_name(OperatorName.parse(new_op_name)) + return str(schema) + + +def mutable_args(op: OpOverload): + return tuple(False if arg.alias_info is None else arg.alias_info.is_write + for arg in op._schema.arguments) diff --git a/lib/python3.10/site-packages/torch/_custom_op/impl.py b/lib/python3.10/site-packages/torch/_custom_op/impl.py new file mode 100644 index 0000000000000000000000000000000000000000..c00e25ec7316b1dca66a8f77d9738c170b36d625 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_custom_op/impl.py @@ -0,0 +1,670 @@ +# mypy: allow-untyped-defs +import dataclasses +import functools +import inspect +import sys +import typing +import weakref +import warnings + +from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseType, ListType, BaseTy + +import torch +import torch._C as _C +import torch.library as library +from torch.library import get_ctx + +from .autograd import autograd_kernel_indirection, construct_autograd_kernel +import torch._library.infer_schema +from torch._library.infer_schema import infer_schema + +""" +torch._custom_op is deprecated. We shipped a production-ready version of it into torch.library. +Please use those APIs instead. +""" + +__all__ = ["custom_op", "CustomOp", "get_ctx"] + + +SUPPORTED_DEVICE_TYPE_TO_KEY = { + "cpu": "CPU", + "cuda": "CUDA", +} + +# We will not let users register CustomOps with anything that could look like +# PyTorch internals to avoid confusion. 
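For context, a minimal sketch of how this deprecated decorator was used (an editor's illustration, not part of the vendored file). The namespace "mylib" and the function names are hypothetical; the RESERVED_NS set just below means namespaces such as "aten" or "torch" would be rejected, and per the module note above, new code should use torch.library.custom_op instead:

import torch
from torch._custom_op.impl import custom_op   # the deprecated API defined in this file

@custom_op("mylib::numpy_sin")                 # "mylib" is a hypothetical user namespace
def numpy_sin(x: torch.Tensor) -> torch.Tensor:
    ...                                        # body unused; schema inferred from annotations

@numpy_sin.impl(["cpu", "cuda"])               # device kernels (see CustomOp.impl below)
def numpy_sin_impl(x):
    return torch.sin(x)

@numpy_sin.impl_abstract()                     # shape/meta function (see impl_abstract below)
def numpy_sin_abstract(x):
    return torch.empty_like(x)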
+RESERVED_NS = { + "prim", + "prims", + "aten", + "at", + "torch", + "pytorch", +} + +def warn_deprecated(): + warnings.warn( + "torch._custom_op is deprecated and will be removed in PyTorch 2.6, please " + "use the equivalent torch.library API instead.", DeprecationWarning) + + +def custom_op( + qualname: str, manual_schema: typing.Optional[str] = None +) -> typing.Callable: + r""" + This API is deprecated, please use torch.library.custom_op instead + """ + warn_deprecated() + + def inner(func): + if not inspect.isfunction(func): + raise ValueError( + f"custom_op(...)(func): Expected `func` to be a Python " + f"function, got: {type(func)}" + ) + + ns, name = parse_qualname(qualname) + validate_namespace(ns) + if func.__name__ != name: + raise ValueError( + f"custom_op(qualname='{qualname}', ...)(func): expected `func` " + f"to have name '{name}' but got '{func.__name__}'. " + f"Please either change the name of `func` or the qualname that " + f"is passed to `custom_op`" + ) + + schema = infer_schema(func, mutates_args=()) if manual_schema is None else manual_schema + schema_str = f"{name}{schema}" + function_schema = FunctionSchema.parse(schema_str) + validate_schema(function_schema) + if manual_schema is not None: + validate_function_matches_schema(function_schema, func) + + lib = library.Library(ns, "FRAGMENT") + lib.define(schema_str) + ophandle = find_ophandle_or_throw(ns, function_schema.name) + result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True) + + result.__name__ = func.__name__ + result.__module__ = func.__module__ + result.__doc__ = func.__doc__ + + library.impl(lib, result._opname, "Autograd")( + autograd_kernel_indirection(weakref.proxy(result)) + ) + + torch._C._dispatch_set_report_error_callback( + ophandle, functools.partial(report_error_callback, weakref.proxy(result)) + ) + + return result + + return inner + + +# Global dictionary holding references to all CustomOp objects +# Yes, it keeps all CustomOps alive (see NOTE [CustomOp lifetime]) +# Used to query the CustomOp associated with a specific C++ dispatcher operator. +# An example usage is FakeTensor: FakeTensor checks if a specific operator +# has an implementation registered via the CustomOp API. +# Indexed by qualname (e.g. aten::foo) +global_registry: typing.Dict[str, "CustomOp"] = {} + + +class CustomOp: + r""" + This API is deprecated, please use torch.library.custom_op instead + """ + + def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False): + super().__init__() + warn_deprecated() + if not _private_access: + raise RuntimeError( + "The CustomOp constructor is private and we do not guarantee " + "BC for it. Please use custom_op(...) to create a CustomOp object" + ) + name = f"{cpp_ns}::{operator_name}" + self._schema = schema + self._cpp_ns = cpp_ns + self._lib: library.Library = lib + self._ophandle: _C._DispatchOperatorHandle = ophandle + # Has the name of the op, e.g. "foo". We cache here for convenience. + self._opname: str = operator_name + # this is _opname but with namespace. e.g. "custom::foo" + self._qualname: str = name + self.__name__ = None # mypy requires this + # NB: Some of these impls are registered as kernels to DispatchKeys. + # Modifying the _impls dict directly won't do anything in that case. 
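        # Editor's note (illustrative, not in the original source): after a typical
        # registration sequence, the _impls dict initialized just below might look like
        #     {"cpu": FuncAndLocation(fn, "my_kernels.py:12"),
        #      "abstract": FuncAndLocation(meta_fn, "my_kernels.py:30")}
        # where the keys are device types ("cpu"/"cuda") or the special kinds
        # "abstract", "factory", "save_for_backward", "backward" and "autograd".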
+ self._impls: typing.Dict[str, typing.Optional[FuncAndLocation]] = {} + # See NOTE [CustomOp autograd kernel indirection] + self._registered_autograd_kernel_indirection = False + + global_registry[self._qualname] = self + + def _register_autograd_kernel_indirection(self): + assert not self._registered_autograd_kernel_indirection + self._lib.impl(self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd") + self._registered_autograd_kernel_indirection = True + + # Records the impl and the source location in self._impls + # Note that this doesn't cause torch.library to use the impl, that + # needs to be done in a separate self._lib.impl call. + def _register_impl(self, kind, func, stacklevel=2): + if self._has_impl(kind): + func_and_location = self._impls[kind] + assert func_and_location is not None # Pacify mypy + location = func_and_location.location + raise RuntimeError( + f"Attempting to register a {kind} impl for operator {self._qualname} " + f"that already has a {kind} impl registered from Python at " + f"{location}. This is not supported." + ) + frame = inspect.getframeinfo(sys._getframe(stacklevel)) + location = f"{frame.filename}:{frame.lineno}" + self._impls[kind] = FuncAndLocation(func, location) + + def _get_impl(self, kind): + return self._impls[kind] + + def _has_impl(self, kind): + return kind in self._impls + + def _destroy(self): + # NOTE: [CustomOp lifetime] + # A CustomOp, once created, lives forever. The mechanism is that the + # global registry holds a reference to it. However, to make testing + # easier, we want to be able to destroy CustomOp objects. + # CustomOp._destroy does the job, though it leaves the CustomOp + # in a garbage state. + del self._lib + + opnamespace = getattr(torch.ops, self._cpp_ns) + if hasattr(opnamespace, self._opname): + delattr(opnamespace, self._opname) + + del global_registry[self._qualname] + + def __repr__(self): + return f'' + + def __call__(self, *args, **kwargs): + # Bypass torch.ops.* and directly do OperatorHandle::callBoxed. + # Using torch.ops.* is a bit of a pain (it can be slow and it has lifetime + # issues from caching operators that make testing CustomOp difficult). 
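        # Editor's note (illustrative): for a hypothetical op "mylib::numpy_sin",
        # calling the CustomOp object as numpy_sin(x) here is roughly equivalent to
        # torch.ops.mylib.numpy_sin.default(x), except that it invokes the cached
        # operator handle directly instead of going through torch.ops attribute lookup.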
+ result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs) + return result + + def impl( + self, device_types: typing.Union[str, typing.Iterable[str]], _stacklevel=2, + ) -> typing.Callable: + r""" + This API is deprecated, please use torch.library.custom_op instead + """ + if isinstance(device_types, str): + device_types = [device_types] + for device_type in device_types: + validate_device_type(device_type) + + def inner(f): + for device_type in set(device_types): + self._check_doesnt_have_library_impl(device_type) + self._register_impl(device_type, f, stacklevel=_stacklevel) + dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type] + library.impl(self._lib, self._opname, dispatch_key)(f) + return f + + return inner + + def _check_doesnt_have_library_impl(self, device_type): + if self._has_impl(device_type): + return + key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type] + if _C._dispatch_has_computed_kernel_for_dispatch_key(self._qualname, key): + raise RuntimeError( + f"impl(..., device_types={device_type}): the operator {self._qualname} " + f"already has an implementation for this device type via a " + f"pre-existing torch.library or TORCH_LIBRARY registration.") + + def impl_factory(self) -> typing.Callable: + r"""Register an implementation for a factory function.""" + + def inner(f): + self._register_impl("factory", f) + library.impl(self._lib, self._opname, "BackendSelect")(f) + return f + + return inner + + def impl_abstract(self, _stacklevel=2) -> typing.Callable: + r""" + This API is deprecated, please use torch.library.custom_op instead + """ + + def inner(f): + self._check_doesnt_have_library_meta_impl() + self._register_impl("abstract", f, stacklevel=_stacklevel) + location = self._get_impl("abstract").location + + qualname = self._qualname + + # Handle DispatchKey.Meta registration + @functools.wraps(f) + def f_with_ctx(*args, **kwargs): + def error_on_ctx(): + raise RuntimeError( + f"Attempted to call get_ctx() for the meta implementation " + f"for {qualname}." + f"You have presumably called get_ctx() because the operator " + f"has a data-dependent output shape; if so, there is no " + f"such meta implementation and this error is the correct " + f"behavior. Otherwise, please remove the call to get_ctx() " + f"in the implementation registered with impl_abstract " + f"at {location}" + ) + + with torch._library.fake_impl.set_ctx_getter(error_on_ctx): + return f(*args, **kwargs) + + self._lib.impl(self._opname, f_with_ctx, "Meta") + return f + + return inner + + def _check_can_register_backward(self): + def error(detail): + raise RuntimeError( + f"Cannot use torch._custom_ops APIs to register backward " + f"formula for {detail}. Got operator " + f"{self._qualname} with schema: {schema}" + ) + + schema = self._schema + if schema.kind() != SchemaKind.functional: + error("non-functional operator") + + rets = schema.returns + if not schema.returns: + error("operator with no returns") + + assert len(rets) > 0 + is_non_mutating_view = any( + r.annotation is not None and not r.annotation.is_write for r in rets + ) + if is_non_mutating_view: + error("operator that returns views") + + # We make assumptions about the schema's return types. 
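        # Editor's note (illustrative, hypothetical schemas): a schema like
        #     "mylib::foo(Tensor x) -> Tensor"
        # passes these checks, whereas
        #     "mylib::bar(Tensor(a!) x) -> Tensor"   (mutates its input) or
        #     "mylib::baz(Tensor x) -> str"          (return type not listed below)
        # would be rejected by the errors above and below.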
+ allowed_return_types = { + BaseType(BaseTy.int): "int", + BaseType(BaseTy.SymInt): "SymInt", + BaseType(BaseTy.bool): "bool", + BaseType(BaseTy.float): "float", + BaseType(BaseTy.Tensor): "Tensor", + ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]", + } + for ret in schema.returns: + if ret.type in allowed_return_types: + continue + error(f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})") + + def _check_doesnt_have_library_autograd_impl(self): + if self._registered_autograd_kernel_indirection: + return + + if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"): + raise RuntimeError( + f"impl_backward/impl_save_for_backward: the operator {self._qualname} " + f"already has an implementation for this device type via a " + f"pre-existing registration to DispatchKey::CompositeImplicitAutograd." + f"CompositeImplicitAutograd operators do not need an autograd formula; " + f"instead, the operator will decompose into its constituents and those " + f"can have autograd formulas defined on them.") + + # We can improve this by adding "all Autograd keys", but + # realistically people will just be using this API for CPU/CUDA for now. + for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]: + if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key): + raise RuntimeError( + f"impl_backward/impl_save_for_backward: " + f"the operator {self._qualname} already has an Autograd kernel " + f"registered to DispatchKey::{key} vi a pre-existing " + f"torch.library or TORCH_LIBRARY registration. Please either " + f"remove those registrations or don't use the torch._custom_ops APIs") + + def _check_doesnt_have_library_meta_impl(self): + if self._has_impl("abstract"): + return + + # If the user's operator is CompositeExplicitAutograd, + # allow them to impl_abstract. This is being pragmatic + # (existing custom ops may have CompositeExplicitAutograd + # registration that don't work with Meta kernels, so this + # gives them an escape hatch). + if ( + _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeExplicitAutograd") + and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta") + ): + return + + # Otherwise, if the user's already has a Meta kernel or their + # op is CompositeImplicitAutograd or some other alias dispatch key, + # raise. + + # Special case for CompositeImplicitAutograd + if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"): + raise RuntimeError( + f"impl_abstract(...): the operator {self._qualname} " + f"already has an implementation for this device type via a " + f"pre-existing registration to DispatchKey::CompositeImplicitAutograd." + f"CompositeImplicitAutograd operators do not need an abstract impl; " + f"instead, the operator will decompose into its constituents and those " + f"can have abstract impls defined on them.") + + if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"): + raise RuntimeError( + f"impl_abstract(...): the operator {self._qualname} " + f"already has an DispatchKey::Meta implementation via a " + f"pre-existing torch.library or TORCH_LIBRARY registration. " + f"Please either remove that registration or don't call impl_abstract.") + + # NOTE ["backward", "save_for_backward", and "autograd"] + # As a part of the explicit autograd API, a user must provide us + # a "save_for_backward" function and a "backward" function. + # When both of these have been provided, then we automatically + # construct the "autograd" kernel. 
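    # Editor's sketch (illustrative; hypothetical op "mylib::numpy_sin" with a single
    # Tensor argument named "x"):
    #
    #     @numpy_sin.impl_save_for_backward()
    #     def save(inputs, output):
    #         return inputs.x
    #
    #     @numpy_sin.impl_backward()
    #     def backward(ctx, saved, grad):
    #         return {"x": grad * saved.cos()}
    #
    # Once both pieces have been registered, _register_autograd_kernel (below)
    # combines them into the generated "autograd" impl via construct_autograd_kernel.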
+ def _register_autograd_kernel(self): + assert self._has_impl("backward") + assert self._has_impl("save_for_backward") + kernel = construct_autograd_kernel( + self._schema, + self._output_differentiability, + self, + get_op(self._qualname), + self._get_impl("save_for_backward").func, + self._get_impl("backward").func) + self._register_impl("autograd", kernel) + + def impl_save_for_backward(self, _stacklevel=2): + r"""Register a function that tells us what to save for backward. + + Please see impl_backward for more details. + """ + def inner(f): + self._check_can_register_backward() + self._check_doesnt_have_library_autograd_impl() + if not self._registered_autograd_kernel_indirection: + self._register_autograd_kernel_indirection() + self._register_impl("save_for_backward", f, stacklevel=_stacklevel) + if self._has_impl("backward"): + self._register_autograd_kernel() + return inner + + def impl_backward(self, output_differentiability=None, _stacklevel=2): + r""" + This API is deprecated, please use torch.library.custom_op instead + """ + if output_differentiability is not None: + def yell(): + raise RuntimeError( + f"impl_backward(output_differentiability): expected " + f"output_differentiability to be a list of bools with " + f"length equal to the number of outputs of this CustomOp " + f"got: {output_differentiability}") + + if not isinstance(output_differentiability, list): + yell() + for diff in output_differentiability: + if not isinstance(diff, bool): + yell() + if len(self._schema.returns) != len(output_differentiability): + yell() + + def inner(f): + self._check_can_register_backward() + self._check_doesnt_have_library_autograd_impl() + if not self._registered_autograd_kernel_indirection: + self._register_autograd_kernel_indirection() + self._register_impl("backward", f, stacklevel=_stacklevel) + self._output_differentiability = output_differentiability + if self._has_impl("save_for_backward"): + self._register_autograd_kernel() + return inner + + +@dataclasses.dataclass +class FuncAndLocation: + func: typing.Callable + location: str + + +def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName): + overload_name = ( + "" if operator_name.overload_name is None else operator_name.overload_name + ) + return _C._dispatch_find_schema_or_throw( + f"{cpp_ns}::{str(operator_name.name)}", overload_name + ) + + +def validate_namespace(ns: str) -> None: + if "." in ns: + raise ValueError( + f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a ' + f"valid variable name)" + ) + if ns in RESERVED_NS: + raise ValueError( + f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, " + f"please choose something else. " + ) + +def validate_schema(schema: FunctionSchema) -> None: + if not torch._library.utils.is_functional_schema(schema): + raise ValueError( + f"custom_op only supports functional operators " + f"(ops that do not mutate any inputs, do not return " + f"views of the inputs, and has at least one return). " + f"Got the following non-functional schema: {schema}" + ) + + # For simplicity: don't allow self arguments + if schema.arguments.self_arg is not None: + raise ValueError( + f"custom_op does not support arguments named 'self'. Please " + f"rename your argument. Got: {schema}" + ) + + +def parse_qualname(qualname: str) -> typing.Tuple[str, str]: + names = qualname.split("::", 1) + if len(names) != 2: + raise ValueError(f"Expected there to be a namespace in {qualname}, i.e. The " + f"operator name should look something like ns::foo") + if '.' 
in names[1]: + raise ValueError(f"The torch.custom_ops APIs do not handle overloads, " + f"i.e. operator names with '.' in them. " + f"Please name your operator something like ns::foo. " + f"Got: {qualname}") + return names[0], names[1] + + +def validate_device_type(device_type: str) -> None: + if device_type not in SUPPORTED_DEVICE_TYPE_TO_KEY: + raise ValueError( + f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type " + f"in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}." + ) + + +def supported_param(param: inspect.Parameter) -> bool: + return param.kind in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + + +def validate_function_matches_schema( + schema: FunctionSchema, func: typing.Callable +) -> None: + sig = inspect.signature(func) + + if not all(supported_param(p) for _, p in sig.parameters.items()): + raise ValueError( + f"custom_op(..., manual_schema)(func): positional-only args, " + f"varargs, and kwargs are not supported. Please rewrite `func` " + f"to not have them. Got `func` with signature: {sig}" + ) + + if ( + any( + p.annotation is not inspect.Parameter.empty + for _, p in sig.parameters.items() + ) + or sig.return_annotation is not inspect.Signature.empty + ): + raise ValueError( + f"custom_op(..., manual_schema)(func): When passing in a manual " + f"schema, we expect `func` to have no type annotations to avoid " + f"ambiguity. Got `func` with signature: {sig}" + ) + + positional = [ + (name, param) + for name, param in sig.parameters.items() + if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + ] + kwargonly = [ + (name, param) + for name, param in sig.parameters.items() + if param.kind == inspect.Parameter.KEYWORD_ONLY + ] + + def error(): + raise ValueError( + f"custom_op(..., manual_schema)(func): When passing in a manual " + f"schema, we expect `func`'s signature to match `manual_schema` " + f"(aside from type annotations). " + f"func's signature: {sig}, manual_schema: {schema}" + ) + + def error_default_args(): + raise ValueError( + f"custom_op(..., manual_schema)(func): " + f"neither func nor manual_schema should have default " + f"arguments. Got " + f"func's signature: {sig}, manual_schema: {schema}" + ) + + def compare(sig_args, schema_args): + if len(sig_args) != len(schema_args): + error() + for (name, param), arg in zip(sig_args, schema_args): + if name != arg.name: + error() + if param.default is not inspect.Parameter.empty or arg.default is not None: + error_default_args() + + compare(positional, schema.arguments.flat_positional) + compare(kwargonly, schema.arguments.flat_kwarg_only) + + +def report_error_callback(custom_op: typing.Any, key: str) -> None: + if key == "Undefined": + raise NotImplementedError( + f"{custom_op}: There were no Tensor inputs to this operator " + f"(e.g. you passed an empty list of Tensors). If your operator is a " + f"factory function (that is, it takes no Tensors and constructs " + f"a new one), then please use CustomOp.impl_factory to register " + f"an implementation for it" + ) + if key == "Meta": + raise NotImplementedError( + f"{custom_op}: when running with device='Meta' tensors: there is no " + f"abstract impl registered for this CustomOp. Please register one via " + f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors" + ) + if key in ("CPU", "CUDA"): + device = key.lower() + raise NotImplementedError( + f"{custom_op}: when running with device='{device}' tensors: there is no " + f"{device} impl registered for this CustomOp. 
Please register one via " + f"CustomOp.impl(device_type='{device}')" + ) + raise NotImplementedError( + f"{custom_op}: No implementation for dispatch key {key}. It is likely " + f"that we have not added this functionality yet, please either open an " + f"issue or if you're feeling adventurous, use the low-level " + f"torch.library API" + ) + + +def custom_op_from_existing(op): + ns = op.namespace + lib = torch.library.Library(ns, "FRAGMENT") + name = op.name().split("::")[-1] + schema_str = str(op._schema) + # CustomOp expects the schema string without the namespace + schema_str = schema_str.split("::")[-1] + schema = FunctionSchema.parse(schema_str) + return CustomOp(lib, ns, schema, name, op, _private_access=True) + + +def get_op(qualname): + def error_not_found(): + raise ValueError( + f"Could not find the operator {qualname}. Please make sure you have " + f"already registered the operator and (if registered from C++) " + f"loaded it via torch.ops.load_library.") + + ns, name = parse_qualname(qualname) + if not hasattr(torch.ops, ns): + error_not_found() + opnamespace = getattr(torch.ops, ns) + if not hasattr(opnamespace, name): + error_not_found() + packet = getattr(opnamespace, name) + if not hasattr(packet, 'default'): + error_not_found() + return packet.default + + +def _find_custom_op(qualname, also_check_torch_library=False): + if qualname in global_registry: + return global_registry[qualname] + if not also_check_torch_library: + raise RuntimeError( + f'Could not find custom op "{qualname}". Did you register it via ' + f"the torch._custom_ops API?") + overload = get_op(qualname) + result = custom_op_from_existing(overload) + return result + + +def get_abstract_impl(qualname): + if qualname not in torch._custom_op.impl.global_registry: + return None + custom_op = torch._custom_op.impl.global_registry[qualname] + if custom_op is None: + return None + if not custom_op._has_impl("abstract"): + return None + return custom_op._get_impl("abstract").func + + +def _custom_op_with_schema(qualname, schema, needs_fixed_stride_order=True): + ns, name = qualname.split("::") + schema_str = f"{name}{schema}" + function_schema = FunctionSchema.parse(schema_str) + validate_schema(function_schema) + tags = [torch._C.Tag.needs_fixed_stride_order] if needs_fixed_stride_order else [] + lib = library.Library(ns, "FRAGMENT") + lib.define(schema_str, tags=tags) + ophandle = find_ophandle_or_throw(ns, function_schema.name) + result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True) + result._register_autograd_kernel_indirection() + + torch._C._dispatch_set_report_error_callback( + ophandle, functools.partial(report_error_callback, weakref.proxy(result)) + ) + return get_op(qualname) diff --git a/lib/python3.10/site-packages/torch/_decomp/__init__.py b/lib/python3.10/site-packages/torch/_decomp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..93bbec04a425be731a4958ea075703c8386e121a --- /dev/null +++ b/lib/python3.10/site-packages/torch/_decomp/__init__.py @@ -0,0 +1,484 @@ +# mypy: allow-untyped-defs +import inspect +from collections import defaultdict +from functools import wraps +from itertools import chain +from typing import Callable, Dict, List, Sequence, TypeVar, Union +from typing_extensions import ParamSpec + +import torch +import torch.library +from torch._ops import HigherOrderOperator, OpOverload, OpOverloadPacket +from torch._prims_common import CustomOutParamAnnotation +from torch.utils import _pytree as pytree + + +__all__ = [ + 
"decomposition_table", + "pre_autograd_decomposition_table", + "meta_table", + "register_decomposition", + "get_decompositions", + "core_aten_decompositions", +] + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +# TODO: relax key type here; torch registrations should be possible to; but +# right now this type is accurate +global_decomposition_table: Dict[ + str, Dict[torch._ops.OperatorBase, Callable] +] = defaultdict(dict) + +decomposition_table = global_decomposition_table["post_autograd"] +pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"] +meta_table = global_decomposition_table["meta"] + + +def _add_op_to_registry(registry, op, fn): + """ + This is an internal API for adding an op to the decomposition table. + + If op is OpOverload, it will be added to the registry directly. + If op is OpOverloadPacket, all the valid op_overloads in the packet will be added to the registry. + """ + overloads: List[Union[torch._ops.OperatorBase]] = [] + if isinstance(op, HigherOrderOperator): + # There's no concept of overloads for HigherOrderOperator + registry[op] = fn + return + elif isinstance(op, OpOverload): + overloads.append(op) + else: + assert isinstance(op, OpOverloadPacket) + for ol in op.overloads(): + overloads.append(getattr(op, ol)) + + for op_overload in overloads: + if op_overload in registry: + raise RuntimeError(f"duplicate registrations for {op_overload}") + # TorchScript dumps a bunch of extra nonsense overloads + # which don't have corresponding dispatcher entries, we need + # to filter those out, e.g aten.add.float_int + if torch._C._dispatch_has_kernel(op_overload.name()): + registry[op_overload] = fn + + +def _convert_out_params(f): + out_annotation = f.__annotations__.get("out") + + # If there are no out params, do not wrap the function. + if not out_annotation: + return f + + # Hack to detect when out is a Tuple. There seems to be no pretty way of doing this + if getattr(out_annotation, "__origin__", None) is tuple: + sig = inspect.signature(f) + out_names = sig.return_annotation._fields + # If out is a tuple, we need to register a function that unpacks all the out + # elements as this is what native_functions.yaml expects + + @wraps(f) + def _fn(*args, **kwargs): + out_kwargs = tuple(kwargs.pop(o, None) for o in out_names) + # Either all of the out kwargs are set or none of them + is_none = out_kwargs[0] is None + assert all((o is None) == is_none for o in out_kwargs) + return f(*args, **kwargs, out=None if is_none else out_kwargs) + + out_params = [ + inspect.Parameter( + o, + kind=inspect.Parameter.KEYWORD_ONLY, + default=None, + annotation=t, + ) + for o, t in zip(out_names, out_annotation.__args__) + ] + # Drop the out parameter and concatenate the new kwargs in the signature + params = chain((v for k, v in sig.parameters.items() if k != "out"), out_params) + _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined] + parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type] + ) + # Drop the out parameter and concatenate the new kwargs in the annotations + _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"} + for o in out_params: + _fn.__annotations__[o.name] = o.annotation + + # Propagate that this function is wrapped by `out_wrapper` + _fn._torch_decompositions_out_wrapper = f._torch_decompositions_out_wrapper # type: ignore[attr-defined] + + return _fn + + # Alternatively, there may be a single tensor out parameter with a name + # other than "out". 
This will need special treatment and is indicated by an + # annotation, which we will remove here so it is not exposed after wrapping. + custom_out_param_name = f.__annotations__.pop(CustomOutParamAnnotation, None) + if custom_out_param_name: + + @wraps(f) + def _fn(*args, **kwargs): + out_kwarg = kwargs.pop(custom_out_param_name, None) + return f(*args, **kwargs, out=out_kwarg) + + out_param = inspect.Parameter( + custom_out_param_name, + kind=inspect.Parameter.KEYWORD_ONLY, + default=None, + annotation=out_annotation, + ) + + # Drop the out parameter and concatenate the new kwarg in the signature + sig = inspect.signature(f) + params = chain( + (v for k, v in sig.parameters.items() if k != "out"), (out_param,) + ) + _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined] + parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type] + ) + + # Drop the out parameter and concatenate the new kwargs in the annotations + _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"} + _fn.__annotations__[out_param.name] = out_param.annotation + + return _fn + + return f + + +def register_decomposition( + aten_op, registry=None, *, type="post_autograd", unsafe=False +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + """ + A decorator to register a function as a decomposition to the Python + decomposition table. Use it like this:: + + @register_decomposition(torch.ops.aten.clamp_min) + def clamp_min(x): + return torch.clamp(self, min=min) + + If you are writing a new decomposition, consider contributing it + directly to PyTorch in torch._decomp.decompositions. + + This API is experimental; we are almost certainly going to extend + the API when we make decompositions eligible for use in transforms (e.g., + autograd) and not just backend tracing, where we then need to know if a + decomposition can be used to simulate a transform. + + By default, we also will register it to the Meta key of dispatcher, + and replace the c++ Meta implementation if there is already one. + + unsafe kwarg is for reuse of this function for registering non-function + things + """ + + assert type in {"post_autograd", "pre_autograd", "meta"} + + def decomposition_decorator(fn: Callable[_P, _T]) -> Callable[_P, _T]: + orig_fn = fn + if not unsafe: + fn = _convert_out_params(fn) + + nonlocal registry + if registry is None: + registry = global_decomposition_table[type] + + def register(op): + _add_op_to_registry(registry, op, fn) + + # To handle allowing multiple aten_ops at once + pytree.tree_map_(register, aten_op) + return orig_fn + + return decomposition_decorator + + +def get_decompositions( + aten_ops: Sequence[Union[torch._ops.OperatorBase, OpOverloadPacket]], + type: str = "post_autograd", +) -> Dict[torch._ops.OperatorBase, Callable]: + """ + Retrieve a dictionary of decompositions corresponding to the list of + operator overloads and overload packets passed as input. Overload + packets will include all decomposed overloads in the packet. If there is + no decomposition for a requested operator, it is silently ignored. + + This API is experimental; we are almost certainly going to give an alternate, + more recommended formulation, where a user provides the set of operators + they know how to implement, and we provide decompositions for everything + not in this set. 
+ """ + assert type in {"post_autograd", "pre_autograd", "meta"} + + registry = global_decomposition_table[type] + packets_to_overloads = defaultdict(list) + for opo in registry: + if isinstance(opo, (OpOverload, OpOverloadPacket)): + packets_to_overloads[opo.overloadpacket].append(opo) + decompositions: Dict[torch._ops.OperatorBase, Callable] = {} + for op in aten_ops: + if isinstance(op, OpOverloadPacket) and op in packets_to_overloads: + for op_overload in packets_to_overloads[op]: + decompositions[op_overload] = registry[op_overload] + elif isinstance(op, (torch._ops.OperatorBase)) and op in registry: + decompositions[op] = registry[op] + return decompositions + + +def remove_decompositions( + decompositions: Dict[torch._ops.OperatorBase, Callable], + aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]], +) -> None: + """ + Given a dictionary of decompositions obtained from get_decompositions(), removes + operators associated with a list of operator overloads and overload packets passed + as input. If the decomposition dictionary does not contain a decomposition that is + specified to be removed, it is silently ignored. + """ + for op in aten_ops: + if isinstance(op, OpOverloadPacket): + for overload_name in op.overloads(): + opo = getattr(op, overload_name) + decompositions.pop(opo, None) + elif isinstance(op, OpOverload): + decompositions.pop(op, None) + + +# populate the table +import torch._decomp.decompositions +import torch._refs + + +# See NOTE [Core ATen Ops] +# +# list was copied from torch/_inductor/decomposition.py +# excluding decompositions that results in prim ops +# Resulting opset of decomposition is core aten ops +def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]: + aten = torch.ops.aten + return get_decompositions( + [ + aten.addcdiv, + aten.addcdiv_, + aten.addcmul, + aten.addcmul_, + aten.addr, + aten.affine_grid_generator, + aten.alias_copy, + aten.all, + aten.aminmax, + aten.arange.default, + aten.arange.start, + aten.avg_pool2d_backward, + aten.baddbmm, + aten.binary_cross_entropy, + aten.binary_cross_entropy_backward, + aten.binary_cross_entropy_with_logits, + aten.block_diag, + aten.celu, + aten.celu_, + aten.channel_shuffle, + aten.clamp_max, + aten.clamp_min, + aten.col2im, + aten.count_nonzero, + aten.linalg_cross, + aten.cudnn_batch_norm, + aten.cudnn_batch_norm_backward, + aten.miopen_batch_norm_backward, + aten.deg2rad, + aten.deg2rad_, + aten.detach, + aten.diag_embed, + aten.diagonal_backward, + aten.dot, + aten.vdot, + aten.elu, + aten.elu_, + aten.elu_backward, + aten._embedding_bag, + aten.embedding_dense_backward, + aten.empty_like, + aten._euclidean_dist.default, + aten.expand_as, + aten.expand_copy, + aten.eye, + aten.fill, + aten.fill_, + aten.floor_divide, + aten.frac, + aten.frac_, + aten._fused_moving_avg_obs_fq_helper, + aten.gelu_, + aten.gelu_backward, + aten.glu, + aten.glu_backward, + aten.hardshrink, + aten.hardsigmoid, + aten.hardsigmoid_, + aten.hardsigmoid_backward, + aten.hardswish, + aten.hardswish_, + aten.hardswish_backward, + aten.hardtanh_, + aten.hardtanh_backward, + aten.heaviside, + aten.heaviside_, + aten.huber_loss, + aten.huber_loss_backward, + aten.im2col, + aten.index_add, + aten.index_add_, + aten.index_copy, + aten.index_copy_, + aten.index_fill, + aten.index_fill_, + aten.isin, + aten.isneginf, + aten.isposinf, + aten.l1_loss, + aten._lazy_clone, + aten._test_parallel_materialize, + aten.leaky_relu_, + aten.leaky_relu_backward, + aten.lerp, + aten.lerp_, + aten.linspace, + aten.logaddexp, + 
aten.logaddexp2, + aten.logit, + aten.logit_, + aten.logit_backward, + aten.log_sigmoid_backward, + aten.log_sigmoid_forward, + aten._log_softmax_backward_data, + aten.logspace, + aten.logsumexp.default, + aten.masked_fill, + aten.masked_fill_, + aten.mish, + aten.mish_, + aten.mse_loss, + aten.mse_loss_backward, + aten.multi_margin_loss, + aten.multilabel_margin_loss_forward, + aten.mv, + aten.mvlgamma, + aten.mvlgamma_, + aten.nansum, + aten.nan_to_num, + aten.nan_to_num_, + aten.narrow, + aten.native_batch_norm_backward, + aten.native_dropout_backward, + aten.native_group_norm_backward, + aten.native_layer_norm_backward, + aten.new_empty, + aten.new_full, + aten.new_ones, + aten.new_zeros, + aten.nll_loss2d_forward, + aten.nll_loss2d_backward, + aten.nll_loss_backward, + aten.nll_loss_forward, + aten.norm, + aten.ones, + aten.ones_like, + aten.pixel_shuffle, + aten.pixel_unshuffle, + aten._prelu_kernel, + aten._prelu_kernel_backward, + aten._reshape_alias, + aten.rad2deg, + aten.rad2deg_, + aten.reflection_pad1d, + aten.reflection_pad1d_backward, + aten.reflection_pad2d, + aten.reflection_pad2d_backward, + aten.reflection_pad3d, + aten.reflection_pad3d_backward, + aten.replication_pad1d, + aten.replication_pad2d, + aten.replication_pad3d, + aten.renorm, + aten.renorm_, + aten.replication_pad2d, + aten.resize_as, + aten.roll, + aten.rot90, + aten.rrelu_with_noise, + aten.rrelu_with_noise_, + aten.rsub, + aten._safe_softmax, + aten._scaled_dot_product_flash_attention_for_cpu.default, + aten.select_backward, + aten.select_scatter, + aten.sgn, + aten.sgn_, + aten.sigmoid_backward, + aten.silu, + aten.silu_, + aten.silu_backward, + aten.sinc, + aten.sinc_, + aten.slice_backward, + aten.smooth_l1_loss, + aten.smooth_l1_loss_backward, + aten.soft_margin_loss, + aten.soft_margin_loss_backward, + aten._softmax_backward_data, + aten.softplus, + aten.softplus_backward, + aten.softshrink, + aten.special_entr, + aten.special_log_ndtr, + aten.special_xlog1py, + aten.split.Tensor, + aten.split_with_sizes_copy, + aten.squeeze.default, + aten.squeeze.dim, + aten.std, + aten.std_mean, + aten.stack, + aten.sum.default, + aten.sum.out, + aten.t, + aten.t_copy, + aten.take, + aten.tanh_backward, + aten.threshold, + aten.threshold_, + aten.threshold_backward, + aten.trace, + aten.transpose.int, + aten.tril, + aten.tril_, + aten.triu, + aten.triu_, + aten.unbind, + aten.unfold_backward, + aten.unfold_copy, + aten._unsafe_index, + aten._unsafe_index_put, + aten._unsafe_masked_index, + aten._unsafe_masked_index_put_accumulate, + aten.unsafe_split.Tensor, + aten.unsafe_split_with_sizes, + aten.unsqueeze_copy, + aten._unsafe_view, + aten.upsample_linear1d, + aten.upsample_bilinear2d, + aten.upsample_trilinear3d, + aten.upsample_nearest2d_backward, + aten.view_as_complex, + aten.xlogy, + aten.xlogy_, + aten.zero, + aten.zero_, + aten.zeros, + aten.zeros_like, + aten._chunk_cat, + aten._weight_norm_interface, + ] + ) diff --git a/lib/python3.10/site-packages/torch/_decomp/decompositions.py b/lib/python3.10/site-packages/torch/_decomp/decompositions.py new file mode 100644 index 0000000000000000000000000000000000000000..c35d7a72774f1643b3688446983e7ee27442ca00 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_decomp/decompositions.py @@ -0,0 +1,5113 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import functools +import itertools +import numbers +import operator +import sys +from enum import Enum +from functools import partial, reduce +from itertools import chain, product +from typing import 
Any, Callable, cast, Iterable, List, Optional, Tuple, Union + +import torch +import torch._meta_registrations +import torch._prims as prims +import torch._prims_common as utils +import torch.nn.functional as F +from torch import sym_float, sym_int, Tensor +from torch._decomp import register_decomposition +from torch._higher_order_ops.out_dtype import out_dtype +from torch._prims_common import ( + IntLike, + NumberType, + suggest_memory_format, + TensorLike, + TensorSequenceType, +) +from torch._prims_common.wrappers import ( + _maybe_convert_to_dtype, + _maybe_resize_out, + _safe_copy_out, + out_wrapper, +) +from torch.utils import _pytree as pytree +from torch.utils._pytree import tree_map + + +DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] + +# None of these functions are publicly accessible; get at them +# from torch._decomps +__all__: List[str] = [] + +aten = torch._ops.ops.aten + + +class Reduction(Enum): + NONE = 0 + MEAN = 1 + SUM = 2 + + +# This wraps a decomposition and performs various type promotion logic within it, depending on the strategy provided +# We're currently re-using ELEMENTWISE_TYPE_PROMOTION_KIND, although some of the usages are on non-elementwise ops +# Will need to validate the non-elementwise uses +def type_casts( + f: Callable, + type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND, + compute_dtype_only: bool = False, +): + @functools.wraps(f) + def inner(*args, **kwargs): + flat_args = [ + x for x in pytree.arg_tree_leaves(*args, **kwargs) if isinstance(x, Tensor) + ] + computation_dtype, result_dtype = utils.elementwise_dtypes( + *flat_args, type_promotion_kind=type_promotion + ) + + # TODO: pretty sure this is not quite right + def increase_prec(x): + if isinstance(x, Tensor): + return x.to(computation_dtype) + else: + return x + + def decrease_prec(x): + if isinstance(x, Tensor): + return x.to(result_dtype) + else: + return x + + r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs)) + if compute_dtype_only: + return r + else: + return tree_map(decrease_prec, r) + + return inner + + +compute_only_pw_cast_for_opmath = partial( + type_casts, + type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + compute_dtype_only=True, +) +pw_cast_for_opmath = partial( + type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT +) +pw_cast_for_int_to_real = partial( + type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT +) + + +# This expands x until x.dim() == dim. 
Might be useful as an operator +def _unsqueeze_to_dim(x: Tensor, dim: int) -> Tensor: + for _ in range(dim - x.dim()): + x = x.unsqueeze(-1) + return x + + +@register_decomposition(aten.tanh_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def tanh_backward(out_grad: Tensor, y: Tensor): + return out_grad * (1 - y * y).conj_physical() + + +@register_decomposition(aten.sigmoid_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def sigmoid_backward(out_grad: Tensor, y: Tensor): + return out_grad * (y * (1 - y)).conj_physical() + + +@register_decomposition(aten.softplus_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def softplus_backward(out_grad: Tensor, x: Tensor, beta: float, threshold: float): + z = (x * beta).exp() + return torch.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0)) + + +@register_decomposition(aten.elu_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def elu_backward( + grad_output: Tensor, + alpha: float, + scale: float, + input_scale: float, + is_result: bool, + self_or_result: Tensor, +): + negcoef = alpha * scale + poscoef = scale + negiptcoef = input_scale + if is_result: + return torch.where( + self_or_result <= 0, + grad_output * negiptcoef * (self_or_result + negcoef), + grad_output * poscoef, + ) + else: + return torch.where( + self_or_result <= 0, + grad_output * negiptcoef * negcoef * torch.exp(self_or_result * negiptcoef), + grad_output * poscoef, + ) + + +@register_decomposition([aten.fill.Scalar]) +def fill_scalar(self, value): + return torch.full_like(self, value) + + +@register_decomposition([aten.fill.Tensor]) +def fill_tensor(self, value: Tensor): + torch._check( + value.dim() == 0, + lambda: f"fill only supports 0-dimension value tensor but got tensor with {value.dim()} dimensions", + ) + return aten.copy(self, value) + + +@register_decomposition(aten.hardsigmoid) +@out_wrapper() +@pw_cast_for_opmath +def hardsigmoid(self: Tensor) -> Tensor: + return torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6 + + +@register_decomposition(aten.hardsigmoid_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def hardsigmoid_backward(grad_output: Tensor, self: Tensor): + return torch.where( + (self > -3.0) & (self < 3.0), + grad_output * (1.0 / 6.0), + 0.0, + ) + + +@register_decomposition(aten.hardtanh_backward) +@out_wrapper("grad_input") +def hardtanh_backward( + grad_output: Tensor, self: Tensor, min_val: float, max_val: float +): + return torch.where((self <= min_val) | (self >= max_val), 0.0, grad_output) + + +@register_decomposition(aten.hardswish) +@out_wrapper() +@pw_cast_for_opmath +def hardswish(self: Tensor) -> Tensor: + return self * torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6 + + +@register_decomposition(aten.hardswish_backward) +@out_wrapper() +@pw_cast_for_opmath +def hardswish_backward(grad_output: Tensor, self: Tensor) -> Tensor: + return torch.where( + self < -3, + 0.0, + torch.where(self <= 3, grad_output * ((self / 3) + 0.5), grad_output), + ) + + +@register_decomposition(aten.threshold_backward) +@out_wrapper("grad_input") +def threshold_backward(grad_output: Tensor, self: Tensor, threshold: float): + return torch.where(self <= threshold, 0, grad_output) + + +@register_decomposition(aten.leaky_relu_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def leaky_relu_backward( + grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool +): + return torch.where(self > 0, grad_output, grad_output * negative_slope) + + 
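Every decomposition in this file is registered into the table exposed by torch._decomp, so one way to see the mechanism end to end is to pull an entry out of core_aten_decompositions() and compare it against the eager kernel. A minimal, self-contained sketch (an editor's addition, not part of the vendored file; hardsigmoid is used purely as an example):

import torch
from torch._decomp import core_aten_decompositions

decomp_table = core_aten_decompositions()
hardsigmoid_decomp = decomp_table[torch.ops.aten.hardsigmoid.default]

x = torch.randn(8)
torch.testing.assert_close(
    hardsigmoid_decomp(x),               # Python decomposition defined in this file
    torch.nn.functional.hardsigmoid(x),  # eager ATen kernel
)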
+@register_decomposition(aten.gelu_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def gelu_backward(grad: Tensor, self: Tensor, approximate: str = "none"): + M_SQRT2 = 1.41421356237309504880 + M_SQRT1_2 = 0.70710678118654752440 + M_2_SQRTPI = 1.12837916709551257390 + if approximate == "tanh": + kBeta = M_SQRT2 * M_2_SQRTPI * 0.5 + kKappa = 0.044715 + x_sq = self * self + x_cube = x_sq * self + inner = kBeta * (self + kKappa * x_cube) + tanh_inner = torch.tanh(inner) + + left = 0.5 * self + right = 1 + tanh_inner + + left_derivative = 0.5 * right + + tanh_derivative = 1 - tanh_inner * tanh_inner + inner_derivative = kBeta * (1 + 3 * kKappa * x_sq) + right_derivative = left * tanh_derivative * inner_derivative + + return grad * (left_derivative + right_derivative) + else: + kAlpha = M_SQRT1_2 + kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5 + cdf = 0.5 * (1 + torch.erf(self * kAlpha)) + pdf = kBeta * torch.exp(self * self * -0.5) + return grad * (cdf + self * pdf) + + +@register_decomposition(aten.mish_backward) +@pw_cast_for_opmath +def mish_backward(grad_output: Tensor, input: Tensor): + input_tanh_softplus = torch.tanh(F.softplus(input)) + input_sigmoid = torch.sigmoid(input) + out = input * input_sigmoid * (1 - input_tanh_softplus * input_tanh_softplus) + return grad_output * (input_tanh_softplus + out) + + +@register_decomposition(aten.silu) +@out_wrapper() +@pw_cast_for_opmath +def silu(self: Tensor) -> Tensor: + return self * torch.sigmoid(self) + + +@register_decomposition(aten.silu_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def silu_backward(grad_output: Tensor, self: Tensor) -> Tensor: + sigmoid = 1 / (1 + torch.exp(-self)) + return grad_output * sigmoid * (1 + self * (1 - sigmoid)) + + +@register_decomposition(aten._prelu_kernel) +def _prelu_kernel(self: Tensor, weight: Tensor) -> Tensor: + return torch.where(self > 0, self, weight * self) + + +@register_decomposition(aten._prelu_kernel_backward) +def _prelu_kernel_backward( + grad_output: Tensor, + self: Tensor, + weight: Tensor, +) -> Tuple[Tensor, Tensor]: + input_grad = torch.where(self > 0, grad_output, weight * grad_output) + weight_grad = torch.where(self > 0, 0.0, self * grad_output) + return (input_grad, weight_grad) + + +@register_decomposition(aten.rrelu_with_noise) +@aten.rrelu_with_noise.default.py_impl(DispatchKey.AutogradCUDA) +@out_wrapper() +@pw_cast_for_opmath +def rrelu_with_noise( + self: Tensor, + noise: Tensor, + lower: float = 0.125, + upper: float = 0.3333333333333333, + training: bool = False, + generator: Optional[torch.Generator] = None, +) -> Tensor: + assert generator is None + if training: + not_positive = self <= 0 + r = aten.uniform(self, lower, upper) + output = torch.where(not_positive, self * r, self) + noise.copy_(torch.where(not_positive, r, 1)) + return output + else: + negative_slope = (lower + upper) / 2 + return aten.leaky_relu(self, negative_slope) + + +@register_decomposition(aten.rrelu_with_noise_) +@aten.rrelu_with_noise_.default.py_impl(DispatchKey.AutogradCUDA) +@pw_cast_for_opmath +def rrelu_with_noise_( + self: Tensor, + noise: Tensor, + lower: float = 0.125, + upper: float = 0.3333333333333333, + training: bool = False, + generator: Optional[torch.Generator] = None, +) -> Tensor: + return self.copy_(rrelu_with_noise(self, noise, lower, upper, training, generator)) + + +@register_decomposition(aten.rrelu_with_noise_backward) +@out_wrapper() +@pw_cast_for_opmath +def rrelu_with_noise_backward( + grad_output: Tensor, + self: Tensor, + noise: Tensor, + lower: float, 
+ upper: float, + training: bool, + self_is_result: bool, +) -> Tensor: + if training and upper - lower > 1e-6: + return grad_output.mul(noise) + else: + negative_slope = (lower + upper) / 2 + return aten.leaky_relu_backward( + grad_output, self, negative_slope, self_is_result + ) + + +@register_decomposition(aten.log_sigmoid_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> Tensor: + in_negative = self < 0 + max_deriv = torch.where(in_negative, 1, 0) + sign = torch.where(in_negative, 1, -1) + z = torch.exp(-torch.abs(self)) + return grad_output * (max_deriv - sign * (z / (1 + z))) + # CPU has a special formula that uses buffer, but disabled for convenience sake + # return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output + + +def apply_loss_reduction(loss: Tensor, reduction: int): + if reduction == Reduction.MEAN.value: + return torch.mean(loss) + elif reduction == Reduction.SUM.value: + return torch.sum(loss) + else: + return loss + + +def to_real_dtype(dtype: torch.dtype): + if dtype == torch.complex32: + return torch.float16 + elif dtype == torch.complex64: + return torch.float32 + elif dtype == torch.complex128: + return torch.float64 + + +# TODO: None of these loss castings are quite correct, see +# https://github.com/pytorch/pytorch/issues/76870. Also, the ATen kernels +# perform the pointwise portion in opmath, but don't maintain it between the +# pointwise portion and the reduction + + +@register_decomposition(aten.mse_loss) +@out_wrapper() +@pw_cast_for_opmath +def mse_loss( + self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value +) -> Tensor: + loss = (self - target) ** 2 + return apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.mse_loss_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def mse_loss_backward( + grad_output: Tensor, input: Tensor, target: Tensor, reduction: int +): + norm = 2.0 / input.numel() if reduction == Reduction.MEAN.value else 2.0 + return norm * (input - target) * grad_output + + +@register_decomposition(aten._safe_softmax) +def safe_softmax(self, dim, dtype=None): + out = torch.softmax(self, dim=dim, dtype=dtype) + masked = self.eq(float("-inf")) + masked_rows = torch.all(masked, dim=dim, keepdim=True) + zeros = torch.zeros_like(out) + return torch.where(masked_rows, zeros, out) + + +@register_decomposition(aten.smooth_l1_loss) +@out_wrapper() +@pw_cast_for_opmath +def smooth_l1_loss( + self: Tensor, + target: Tensor, + reduction: int = Reduction.MEAN.value, + beta: float = 1.0, +): + loss = (self - target).abs() + loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta) + return apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.smooth_l1_loss_backward.default) +@pw_cast_for_opmath +def smooth_l1_loss_backward( + grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, beta: float +): + norm = 1.0 / self.numel() if reduction == Reduction.MEAN.value else 1.0 + x = self - target + abs_x = torch.abs(x) + norm_grad = norm * grad_output + return torch.where( + abs_x < beta, + norm_grad * x / beta, + norm_grad * torch.sign(x), + ) + + +@register_decomposition(aten.smooth_l1_loss_backward.grad_input) +@pw_cast_for_opmath +def smooth_l1_loss_backward_out( + grad_output: Tensor, + self: Tensor, + target: Tensor, + reduction: int, + beta: float, + grad_input: Tensor, +): + result = smooth_l1_loss_backward(grad_output, self, target, reduction, beta) + 
_maybe_resize_out(grad_input, result.shape) + return _safe_copy_out(copy_from=result, copy_to=grad_input, exact_dtype=True) + + +@register_decomposition(aten.huber_loss_backward.default) +@pw_cast_for_opmath +def huber_loss_backward( + grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float +): + norm = 1.0 / self.numel() if reduction == Reduction.MEAN.value else 1.0 + x = self - target + return torch.where( + x < -delta, + -norm * grad_output * delta, + torch.where(x > delta, norm * grad_output * delta, norm * x * grad_output), + ) + + +# We cannot use @out_wrapper() here, because the output tensor is not named 'out', it's 'grad_input' +@register_decomposition(aten.huber_loss_backward.out) +@pw_cast_for_opmath +def huber_loss_backward_out( + grad_output: Tensor, + self: Tensor, + target: Tensor, + reduction: int, + delta: float, + grad_input: Tensor, +): + result = huber_loss_backward(grad_output, self, target, reduction, delta) + _maybe_resize_out(grad_input, result.shape) + return _safe_copy_out(copy_from=result, copy_to=grad_input, exact_dtype=True) + + +def _nll_loss_backward( + grad_output: Tensor, + self: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, + total_weight: Tensor, +) -> Tensor: + channel_dim = 0 if self.dim() < 2 else 1 + if reduction == Reduction.MEAN.value: + grad_output = grad_output / total_weight + + target = target.unsqueeze(channel_dim) + safe_target = torch.where(target != ignore_index, target, 0) + grad_input = torch.zeros_like(self) + grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0) + + if grad_input.dim() > grad_output.dim() > 0: + grad_output = grad_output.unsqueeze(channel_dim) + + if weight is not None: + new_shape = [1 for _ in range(self.dim())] + new_shape[channel_dim] = weight.shape[0] + weight = weight.reshape(new_shape) + grad_output = grad_output * weight + + grad_output = torch.where(target != ignore_index, grad_output, 0) + + return grad_input * grad_output + + +@register_decomposition(aten.glu_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def glu_backward(grad_output: Tensor, self: Tensor, dim: int) -> Tensor: + assert self.dim() > 0, "glu does not support 0-dimensional tensors" + wrap_dim = utils.canonicalize_dim(self.dim(), dim) + nIn = self.size(wrap_dim) + assert ( + nIn % 2 == 0 + ), f"Halving dimension must be even, but dimension {wrap_dim} is size {nIn}" + inputSize = nIn // 2 + firstHalf = self.narrow(wrap_dim, 0, inputSize) + secondHalf = self.narrow(wrap_dim, inputSize, inputSize) + gradInputFirstHalf = torch.sigmoid(secondHalf) + gradInputSecondHalf = ( + (1.0 - gradInputFirstHalf) * gradInputFirstHalf * firstHalf * grad_output + ) + gradInputFirstHalf = gradInputFirstHalf * grad_output + return torch.cat([gradInputFirstHalf, gradInputSecondHalf], dim=wrap_dim) + + +@register_decomposition(aten.nll_loss_backward) +@out_wrapper("grad_input") +def nll_loss_backward( + grad_output: Tensor, + self: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, + total_weight: Tensor, +) -> Tensor: + assert 0 <= self.dim() <= 2, "input tensor should be 1D or 2D" + assert ( + target.dim() <= 1 + ), "0D or 1D target tensor expected, multi-target not supported" + + no_batch_dim = self.dim() == 1 and target.dim() == 0 + assert no_batch_dim or ( + self.shape[0] == target.shape[0] + ), f"size mismatch (got input: {self.shape}, target: {target.shape})" + assert total_weight.numel() == 1, ( + "expected total_weight 
to be a single element tensor, got: ", + f"{total_weight.shape} ({total_weight.numel()} elements)", + ) + + assert ( + weight is None or weight.numel() == self.shape[-1] + ), "weight tensor should be defined either for all or no classes" + + if reduction == Reduction.NONE.value and self.dim() == 2: + assert grad_output.dim() == 1 and grad_output.shape[0] == self.shape[0], ( + f"Expected a tensor of dimension 1 and tensor.size[0] == {self.shape[0]} but " + f"got: dimension {grad_output.dim()} and tensor.size[0] == {grad_output.shape[0]}" + ) + else: + assert ( + grad_output.dim() <= 1 and grad_output.numel() == 1 + ), f"Expected a single element grad_output tensor, but got: {grad_output.shape}" + + return _nll_loss_backward( + grad_output, self, target, weight, reduction, ignore_index, total_weight + ) + + +@register_decomposition(aten.nll_loss2d_backward) +@out_wrapper("grad_input") +def nll_loss2d_backward( + grad_output: Tensor, + self: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, + total_weight: Tensor, +) -> Tensor: + assert ( + self.dim() == 4 + ), f"only batches of spatial inputs supported (4D tensors), but got input of dimension: {self.dim()}" + + assert ( + target.dim() == 3 + ), f"only batches of spatial targets supported (3D tensors) but got targets of dimension: {target.dim()}" + + assert ( + self.shape[0] == target.shape[0] + and self.shape[2] == target.shape[1] + and self.shape[3] == target.shape[2] + ), f"size mismatch (got input: {self.shape}, target: {target.shape}" + + assert total_weight.numel() == 1, ( + "expected total_weight to be a single element tensor, " + f"got: {total_weight.shape} ( {total_weight.numel()}, elements)" + ) + + return _nll_loss_backward( + grad_output, self, target, weight, reduction, ignore_index, total_weight + ) + + +@register_decomposition(aten.binary_cross_entropy) +@out_wrapper() +@pw_cast_for_opmath +def binary_cross_entropy( + self: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + reduction: int = Reduction.MEAN.value, +) -> Tensor: + # We cannot currently model this without introducing data-dependent control flow + # TORCH_CHECK( + # (input_val >= 0) && (input_val <= 1), + # "all elements of input should be between 0 and 1" + # ) + loss = (target - 1) * torch.maximum( + torch.log1p(-self), self.new_full((), -100) + ) - target * torch.maximum(torch.log(self), self.new_full((), -100)) + if weight is not None: + loss = loss * weight + return apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.binary_cross_entropy_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def binary_cross_entropy_backward( + grad_output: Tensor, + self: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + reduction: int = Reduction.MEAN.value, +) -> Tensor: + EPSILON = 1e-12 + result = grad_output * (self - target) / torch.clamp(self * (1 - self), min=EPSILON) + if weight is not None: + result = result * weight + if reduction == Reduction.MEAN.value: + result = result / self.numel() + return result + + +@register_decomposition(aten.soft_margin_loss) +@out_wrapper() +@pw_cast_for_opmath +def soft_margin_loss( + input: Tensor, + target: Tensor, + reduction: int = Reduction.MEAN.value, +) -> Tensor: + loss = torch.log1p(torch.exp(-input * target)) + return apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.soft_margin_loss_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def soft_margin_loss_backward( + grad_output: Tensor, + self: Tensor, + 
target: Tensor, + reduction: int = Reduction.MEAN.value, +) -> Tensor: + grad_input = target * grad_output * (torch.sigmoid(target * self) - 1) + if reduction == Reduction.MEAN.value: + grad_input = grad_input / self.numel() + return grad_input + + +@register_decomposition(aten.dist) +@out_wrapper() +def dist(input: Tensor, other: Tensor, p: float = 2): + return aten.norm(input - other, p=p) + + +@register_decomposition(aten._euclidean_dist) +@out_wrapper() +def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor: + x1_norm = x1.pow(2).sum(-1, True) + x1_pad = torch.ones_like(x1_norm, memory_format=torch.contiguous_format) + x2_norm = x2.pow(2).sum(-1, True) + x2_pad = torch.ones_like(x2_norm, memory_format=torch.contiguous_format) + x1_ = torch.cat([x1.mul(-2), x1_norm, x1_pad], -1) + x2_ = torch.cat([x2, x2_pad, x2_norm], -1) + result = x1_.matmul(x2_.mT) + return result.clamp_min(0).sqrt() + + +@register_decomposition(aten.slice_backward) +@out_wrapper() +def slice_backward( + grad_output: Tensor, + input_sizes: List[int], + dim: int, + start: int, + end: int, + step: int, +): + grad_input = grad_output.new_zeros(input_sizes) + return torch.slice_scatter(grad_input, grad_output, dim, start, end, step) + + +@register_decomposition(aten.slice.Tensor) +def slice_forward( + # Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1 + self: Tensor, + dim: int = 0, + start: Optional[int] = None, + end: Optional[int] = None, + step: int = 1, +): + from torch.fx.experimental.symbolic_shapes import ( + guard_size_oblivious, + statically_known_true, + ) + + ndim = self.dim() + if ndim == 0: + raise RuntimeError("slice() cannot be applied to a 0-dim tensor.") + dim = utils.canonicalize_dim(self.dim(), dim) + sizes = list(self.size()) + strides = list(self.stride()) + + if step <= 0: + raise RuntimeError("slice step must be positive") + + start_val = start if start is not None else 0 + end_val = end if end is not None else sys.maxsize # 2^63 - 1 + + if guard_size_oblivious(start_val < 0): + start_val += sizes[dim] + + if guard_size_oblivious(end_val < 0): + end_val += sizes[dim] + + if guard_size_oblivious(start_val < 0): + start_val = 0 + elif guard_size_oblivious(start_val > sizes[dim]): + start_val = sizes[dim] + + if guard_size_oblivious(end_val < start_val): + end_val = start_val + elif statically_known_true(end_val == sys.maxsize) or guard_size_oblivious( + end_val > sizes[dim] + ): + end_val = sizes[dim] + + storage_offset = self.storage_offset() + start_val * strides[dim] + len = end_val - start_val + sizes[dim] = (len + step - 1) // step + strides[dim] *= step + + if self.is_quantized: + raise NotImplementedError( + "Slice decomposition for quantized tensors aren't implemented" + ) + else: + return self.as_strided(sizes, strides, storage_offset) + + +def _normalize_start_end( + x: Tensor, dim: int, start: Optional[int], end: Optional[int] +) -> Tuple[int, int]: + """ + Normalize start and end such that both are in the range + [0, x.get_size()[dim]] and start <= end. + """ + dim_size = x.shape[dim] + + def clamp_wrap(val, lower, upper, default) -> int: + if val is None: + return default + if val < 0: + val = val + dim_size + return min(max(val, lower), upper) + + start = clamp_wrap(start, 0, dim_size, 0) + end = clamp_wrap(end, start, dim_size, dim_size) + return start, end + + +# This is not in torch._refs because aten.index used by +# aten._unsafe_masked_index does not have a decomposition. 
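+# Illustrative note (not from the upstream source): slice_scatter(input, src, dim, start, end, step)
+# returns a copy of `input` whose slice along `dim` (taken with start:end:step) is replaced by `src`.
+# A small sketch of the expected semantics, assuming 1-D tensors:
+#
+#   >>> x = torch.zeros(8)
+#   >>> src = torch.ones(3)
+#   >>> torch.slice_scatter(x, src, dim=0, start=1, end=7, step=2)
+#   tensor([0., 1., 0., 1., 0., 1., 0., 0.])
+#
+# The decomposition below reproduces this functionally: it builds a boolean mask of the slice
+# positions and gathers from `src` via aten._unsafe_masked_index, avoiding in-place writes.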
+@register_decomposition(aten.slice_scatter) +@out_wrapper() +def slice_scatter( + input: Tensor, + src: Tensor, + dim: int = 0, + start: Optional[int] = None, + end: Optional[int] = None, + step: int = 1, +): + dim = utils.canonicalize_dim(input.ndim, dim) + dim_size = input.shape[dim] + start, end = _normalize_start_end(input, dim, start, end) + + src_size = list(input.shape) + src_size[dim] = (end - start + (step - 1)) // step + src = src.expand(src_size) + + if start == 0 and end == dim_size and step == 1: + return src.clone() + + indices = [None] * input.dim() + idx = torch.arange(dim_size, device=input.device) + indices[dim] = (idx - start) // step + + mask = torch.ones(dim_size, device=input.device, dtype=torch.bool) + if start != 0: + mask = torch.logical_and(mask, idx >= start) + + if end != dim_size: + mask = torch.logical_and(mask, idx < end) + + if step != 1: + mask = torch.logical_and(mask, (idx - start) % step == 0) + + mask_shape = [1] * input.dim() + mask_shape[dim] = -1 + mask = mask.view(mask_shape) + return aten.where(mask, aten._unsafe_masked_index(src, mask, indices, 0), input) + + +@register_decomposition(aten.select_backward) +@out_wrapper() +def select_backward(grad_output: Tensor, input_sizes: List[int], dim: int, index: int): + grad_input = grad_output.new_zeros(input_sizes) + return torch.select_scatter(grad_input, grad_output, dim, index) + + +@register_decomposition(aten.diagonal_backward) +@out_wrapper() +def diagonal_backward( + grad_output: Tensor, input_sizes: List[int], offset: int, dim1: int, dim2: int +): + grad_input = grad_output.new_zeros(input_sizes) + return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2) + + +def _cast_grad_to_input_dtype( + grad_output: Tensor, grad_input: Tensor, input_dtype: torch.dtype +): + if grad_output.dtype != input_dtype: + grad_input = grad_input.to(input_dtype) + return grad_input + + +@register_decomposition(aten._softmax_backward_data) +@out_wrapper("grad_input") +@compute_only_pw_cast_for_opmath +def _softmax_backward_data( + grad_output: Tensor, output: Tensor, dim: int, input_dtype: torch.dtype +): + new_grad_output = grad_output * output + grad_input = new_grad_output - output * torch.sum( + new_grad_output, dim=dim, keepdim=True + ) + + # CPU kernel doesn't respect input_dtype, but following check doesn't work for meta tensor + # if grad_output.device == torch.device("cpu"): + # return grad_input.contiguous() + + return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype).contiguous() + + +@register_decomposition(aten._log_softmax_backward_data) +@out_wrapper() +@compute_only_pw_cast_for_opmath +def _log_softmax_backward_data( + grad_output: Tensor, output: Tensor, dim: int, input_dtype: torch.dtype +): + grad_input = grad_output - torch.exp(output) * torch.sum( + grad_output, dim=dim, keepdim=True + ) + return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype) + + +def _im2col_col2im_indices_along_dim( + input_d, kernel_d, dilation_d, padding_d, stride_d, device +): + """Utility function to implement im2col and col2im""" + blocks_d = input_d + padding_d * 2 - dilation_d * (kernel_d - 1) + + arange_kw = partial(torch.arange, dtype=torch.int64, device=device) + + # Stride kernel over input and find starting indices along dim d + blocks_d_indices = arange_kw(0, blocks_d, stride_d).unsqueeze(0) + + # Apply dilation on kernel and find its indices along dim d + kernel_grid = arange_kw(0, kernel_d * dilation_d, dilation_d).unsqueeze(-1) + + # Broadcast and add kernel starting 
positions (indices) with + # kernel_grid along dim d, to get block indices along dim d + return blocks_d_indices + kernel_grid + + +@register_decomposition(aten.im2col) +@out_wrapper() +def im2col( + input: Tensor, + kernel_size: List[int], + dilation: List[int], + padding: List[int], + stride: List[int], +) -> Tensor: + torch._check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported") + torch._check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported") + torch._check(len(padding) == 2, lambda: "im2col(): only 2D padding supported") + torch._check(len(stride) == 2, lambda: "im2col(): only 2D stride supported") + + def check_positive(param, param_name, strict=True): + cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param) + torch._check( + cond, lambda: "{param_name} should be greater than zero, but got {param}" + ) + + check_positive(kernel_size, "kernel_size") + check_positive(dilation, "dilation") + check_positive(padding, "padding", strict=False) + check_positive(stride, "stride") + + shape = input.shape + ndim = len(shape) + torch._check( + ndim in (3, 4) and all(d != 0 for d in shape[-3:]), + lambda: "Expected 3D or 4D (batch mode) tensor for input with possible 0 batch size " + f"and non-zero dimensions, but got: {tuple(shape)}", + ) + output_size = tuple( + 1 + (out + 2 * pad - dil * (ker - 1) - 1) // st + for out, pad, dil, ker, st in zip( + shape[-2:], padding, dilation, kernel_size, stride + ) + ) + torch._check( + all(c > 0 for c in output_size), + lambda: f"Given an input with spatial size {tuple(shape[-2:])}, " + f"kernel_size={kernel_size}, dilation={dilation}, " + f"padding={padding}, stride={stride}, " + "the calculated shape of the array of sliding blocks " + f"is {output_size}, but its components must be at least one.", + ) + batched_input = ndim == 4 + if not batched_input: + input = input.unsqueeze(0) + + batch_dim, channel_dim, input_h, input_w = input.shape + + stride_h, stride_w = stride + padding_h, padding_w = padding + dilation_h, dilation_w = dilation + kernel_h, kernel_w = kernel_size + + blocks_row_indices = _im2col_col2im_indices_along_dim( + input_h, kernel_h, dilation_h, padding_h, stride_h, input.device + ) + blocks_col_indices = _im2col_col2im_indices_along_dim( + input_w, kernel_w, dilation_w, padding_w, stride_w, input.device + ) + + # Note that F.pad takes (padding_left, padding_right, padding_top, padding_bottom) + # ugh + padded_input = F.pad(input, (padding_w, padding_w, padding_h, padding_h)) + + blocks_row_indices = blocks_row_indices.unsqueeze(-1).unsqueeze(-1) + output = padded_input[:, :, blocks_row_indices, blocks_col_indices] + output = output.permute(0, 1, 2, 4, 3, 5) + num_blocks_row = blocks_row_indices.size(1) + num_blocks_col = blocks_col_indices.size(1) + output = output.reshape( + batch_dim, channel_dim * kernel_h * kernel_w, num_blocks_row * num_blocks_col + ) + + if not batched_input: + output = output.squeeze(0) + return output + + +@register_decomposition(aten.col2im) +@out_wrapper() +@pw_cast_for_opmath +def col2im( + input: Tensor, + output_size: List[int], + kernel_size: List[int], + dilation: List[int], + padding: List[int], + stride: List[int], +) -> Tensor: + torch._check(len(output_size) == 2, lambda: "only 2D output_size supported") + torch._check(len(kernel_size) == 2, lambda: "only 2D kernel supported") + torch._check(len(dilation) == 2, lambda: "only 2D dilation supported") + torch._check(len(padding) == 2, lambda: "only 2D padding supported") +
torch._check(len(stride) == 2, lambda: "only 2D stride supported") + + def check_positive(param, param_name, strict=True): + cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param) + torch._check( + cond, lambda: "{param_name} should be greater than zero, but got {param}" + ) + + check_positive(kernel_size, "kernel_size") + check_positive(dilation, "dilation") + check_positive(padding, "padding", strict=False) + check_positive(stride, "stride") + check_positive(output_size, "output_size") + + shape = input.shape + ndim = len(shape) + torch._check( + ndim in (2, 3) and all(d != 0 for d in shape[-2:]), + lambda: "Expected 2D or 3D (batch mode) tensor for input with possible 0 batch size " + f"and non-zero dimensions, but got: {tuple(shape)}", + ) + prod_kernel_size = kernel_size[0] * kernel_size[1] + torch._check( + shape[-2] % prod_kernel_size == 0, + lambda: "Expected size of input's first non-batch dimension to be divisible by the " + f"product of kernel_size, but got input.shape[-2] = {shape[-2]} and " + f"kernel_size={kernel_size}", + ) + col = [ + 1 + (out + 2 * pad - dil * (ker - 1) - 1) // st + for out, pad, dil, ker, st in zip( + output_size, padding, dilation, kernel_size, stride + ) + ] + L = col[0] * col[1] + torch._check( + shape[-1] == L, + lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, " + f"dilation={dilation}, padding={padding}, stride={stride}, " + f"expected input.size(-1) to be {L} but got {shape[-1]}.", + ) + torch._check( + L > 0, + lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, " + f"dilation={dilation}, padding={padding}, stride={stride}, " + f"expected input.size(-1) to be {L} but got {shape[-1]}.", + ) + batched_input = ndim == 3 + if not batched_input: + input = input.unsqueeze(0) + + shape = input.shape + + out_h, out_w = output_size + stride_h, stride_w = stride + padding_h, padding_w = padding + dilation_h, dilation_w = dilation + kernel_h, kernel_w = kernel_size + + # col2im is defined as the backwards of im2col, so we differentiate its decomposition by hand + input = input.reshape([shape[0], shape[1] // prod_kernel_size] + kernel_size + col) + input = input.permute(0, 1, 2, 4, 3, 5) + + indices_row = _im2col_col2im_indices_along_dim( + out_h, kernel_h, dilation_h, padding_h, stride_h, input.device + ) + indices_row = _unsqueeze_to_dim(indices_row, 4) + indices_col = _im2col_col2im_indices_along_dim( + out_w, kernel_w, dilation_w, padding_w, stride_w, input.device + ) + + output_padded_size = [o + 2 * p for o, p in zip(output_size, padding)] + output = input.new_zeros( + [shape[0], shape[1] // prod(kernel_size)] + output_padded_size + ) + idx = (None, None, indices_row, indices_col) + output = aten._unsafe_index_put(output, idx, input, accumulate=True) + output = F.pad(output, (-padding_w, -padding_w, -padding_h, -padding_h)) + + if not batched_input: + output = output.squeeze(0) + return output + + +@register_decomposition(aten.native_dropout_backward) +@out_wrapper() +def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float): + # According to the CUDA kernel implementation we should have this test; + # but it seems to fail tests! 
+ # torch._check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}") + + # Mimicking CUDA kernel's behavior for output stride: output follow input's memory format + # This different from TensorIterator's behavior + r = (grad_output * (mask.type_as(grad_output) * scale)).clone( + memory_format=utils.suggest_memory_format(grad_output) + ) + return r + + +@register_decomposition(aten.unfold_backward) +@out_wrapper() +def unfold_backward( + grad: Tensor, input_size: List[int], dimension: int, size: int, step: int +) -> Tensor: + if len(input_size) == 0: + return torch.squeeze_copy(grad, 0) + dim = utils.canonicalize_dim(len(input_size), dimension) + idx = torch.arange(input_size[dim], device=grad.device, dtype=torch.int32) + idx = idx.unfold(0, size, step).flatten() + grad = grad.movedim(-1, dim + 1).flatten(dim, dim + 1) + # nb. At the moment this generates two kernels in triton + # It could potentially be fused into one call to scatter_reduce, + # in the case step <= size provided scatter_reduce generates 1 kernel + grad_input = grad.new_zeros(input_size) + index = (None,) * dim + (idx,) + return aten._unsafe_index_put(grad_input, index, grad, accumulate=True).contiguous() + + +@register_decomposition(aten.logit_backward.default) +@pw_cast_for_opmath +def logit_backward( + grad_output: Tensor, self: Tensor, eps: Optional[float] = None +) -> Tensor: + if eps is not None: + lo = eps + hi = 1.0 - lo + return torch.where( + torch.logical_and(self >= lo, self <= hi), + grad_output / (self * (1.0 - self)), + 0.0, + ) + else: + return torch.where( + torch.logical_and(self >= 0.0, self <= 1.0), + grad_output / (self * (1.0 - self)), + self.new_full((), float("nan")), + ) + + +@register_decomposition(aten.dropout) +@aten.dropout.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.dropout.default.py_impl(DispatchKey.Autograd) +def dropout(input: Tensor, p: float, train: Optional[bool]): + if train and p != 0: + return aten.native_dropout(input, p, train)[0] + else: + return input.clone() + + +@register_decomposition(aten.native_dropout) +@out_wrapper("out0", "out1") +def native_dropout(input: Tensor, p: float, train: Optional[bool]): + if train and p != 0: + if p == 1: + return (torch.zeros_like(input), torch.zeros_like(input, dtype=torch.bool)) + if not input.dtype.is_floating_point: + raise RuntimeError( + "result type Float can't be cast to the desired output type Long" + ) + bool_mask = torch.rand_like(input) > p + res = bool_mask * input * float(1.0 / (1.0 - p)) + return (res, bool_mask) + else: + return (input, torch.ones_like(input, dtype=torch.bool)) + + +@register_decomposition(aten._softmax) +@out_wrapper() +def _softmax(x: Tensor, dim: int, half_to_float: bool): + # eager softmax returns a contiguous tensor. Ensure that decomp also returns + # a contiguous tensor. + x = x.contiguous() + if half_to_float: + assert x.dtype == torch.half + computation_dtype, result_dtype = utils.elementwise_dtypes( + x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + x = x.to(computation_dtype) + if x.numel() == 0: + unnormalized = torch.exp(x) + else: + x_max = torch.amax(x, dim, keepdim=True) + unnormalized = torch.exp(x - x_max) + result = unnormalized / torch.sum(unnormalized, dim, keepdim=True) + if not half_to_float: + result = result.to(result_dtype) + return result + + +@register_decomposition(aten._log_softmax) +@out_wrapper() +def _log_softmax(x: Tensor, dim: int, half_to_float: bool): + # eager log_softmax returns a contiguous tensor. 
Ensure that decomp also + # returns a contiguous tensor. + x = x.contiguous() + if half_to_float: + assert x.dtype == torch.half + computation_dtype, result_dtype = utils.elementwise_dtypes( + x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + x = x.to(computation_dtype) + if x.numel() == 0: + shifted = x + else: + x_max = torch.amax(x, dim, keepdim=True) + shifted = x - x_max + shifted_logsumexp = torch.log(torch.sum(torch.exp(shifted), dim, keepdim=True)) + result = shifted - shifted_logsumexp + if not half_to_float: + result = result.to(result_dtype) + return result + + +@register_decomposition(aten.embedding) +@out_wrapper() +def embedding( + weight: Tensor, + indices: Tensor, + padding_idx: int = -1, + scale_grad_by_freq: bool = False, + sparse: bool = False, +) -> Tensor: + assert weight.dim() == 2, "'weight' must be 2-D" + # Nb. scale_grad_by_freq is not used in the forward + if indices.ndim <= 1: + # We need this one as weight[indices] calls item() in these cases + out = weight.index_select(0, indices) + if indices.ndim == 0: + out = out.squeeze(0) + return out + else: + return weight[indices] + + +@register_decomposition(aten.embedding_dense_backward) +@out_wrapper() +def embedding_dense_backward( + grad_output: Tensor, + indices: Tensor, + num_weights: int, + padding_idx: int, + scale_grad_by_freq: bool, +): + computation_dtype, result_dtype = utils.elementwise_dtypes( + grad_output, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + grad_output = grad_output.to(computation_dtype) + indices = _maybe_convert_to_dtype(indices, torch.long) # type: ignore[assignment] + if scale_grad_by_freq: + counts = indices.new_zeros((num_weights,)) + ones = torch.ones_like(indices) + counts = aten._unsafe_index_put(counts, [indices], ones, accumulate=True) + grad_weights_scale = counts[indices] + grad_output = grad_output / grad_weights_scale.unsqueeze(-1) + + mask = _unsqueeze_to_dim(indices == padding_idx, grad_output.ndim) + grad = grad_output.masked_fill(mask, 0) + grad_weight = grad_output.new_zeros( + (num_weights,) + grad_output.shape[indices.ndim :] + ) + return aten._unsafe_index_put(grad_weight, [indices], grad, accumulate=True).to( + result_dtype + ) + + +def prod(x: List[int]): + r = 1 + for i in x: + r *= i + return r + + +def _pad_chunk( + tensors: List[Tensor], + dim: int, + num_chunks: int, +) -> List[Tensor]: + padded_tensors = [] + for tensor in tensors: + tensor_size = tensor.size() + pad_along_dim = (tensor_size[dim] + num_chunks - 1) // num_chunks * num_chunks + if pad_along_dim != tensor_size[dim]: + # Use aten.constant_pad_nd instead of copy_ for functionalization + pad = [0] * 2 * (tensor.ndim - dim - 1) + [ + 0, + pad_along_dim - tensor_size[dim], + ] + tensor = aten.constant_pad_nd(tensor, pad, 0) + view_size = tensor_size[:dim] + torch.Size([num_chunks, -1]) + padded_tensors.append(tensor.view(view_size)) + return padded_tensors + + +def have_same_ndims(tensors: List[Tensor]): + ndim = tensors[0].ndim + for tensor in tensors: + if tensor.ndim != ndim: + return False + return True + + +def leading_dimension_matches(tensors: List[Tensor], dim: int): + leading_dim_sizes = tensors[0].size()[:dim] + for tensor in tensors: + torch._check( + tensor.size()[:dim] == leading_dim_sizes, + lambda: "_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors", + ) + + +def _preprocess_chunk_cat_inputs( + tensors: List[Tensor], + dim: int, + num_chunks: int, +): + torch._check(num_chunks >= 1, lambda: "_chunk_cat expects 
positive num_chunks") + torch._check( + len(tensors) > 0, lambda: "_chunk_cat expects a non-empty input tensor list" + ) + expected_dtype = tensors[0].dtype + expected_device = tensors[0].device + for tensor in tensors: + torch._check(tensor.numel() > 0, lambda: "_chunk_cat expects non-empty tensor") + torch._check( + tensor.dtype == expected_dtype, + lambda: "_chunk_cat expects all input tensors with the same dtype", + ) + torch._check( + tensor.device == expected_device, + lambda: "_chunk_cat expects all inputs tensors on the same device", + ) + if have_same_ndims(tensors): + dim = utils.canonicalize_dim(tensors[0].dim(), dim) + else: + torch._check( + dim >= 0, + lambda: "_chunk_cat expects non-negative dim when input tensors have different ndims", + ) + for tensor in tensors: + torch._check( + dim < tensor.ndim, + lambda: "_chunk_cat expects dim < ndim for all input tensors", + ) + leading_dimension_matches(tensors, dim) + return dim + + +@register_decomposition([aten._chunk_cat.default, aten._chunk_cat.out]) +def _chunk_cat( + tensors: List[Tensor], + dim: int, + num_chunks: int, + out: Optional[Tensor] = None, +) -> Tensor: + dim = _preprocess_chunk_cat_inputs(tensors, dim, num_chunks) + padded_tensors = _pad_chunk(tensors, dim, num_chunks) + if out is None: + return torch.cat(padded_tensors, dim + 1) + else: + torch.cat(padded_tensors, dim + 1, out=out) + return out + + +@register_decomposition(aten.split_with_sizes) +def split_with_sizes( + self: Tensor, split_sizes: List[int], dim: int = 0 +) -> List[Tensor]: + # NB: Perform the check_is_size tests first so that the + # sum test does not try to do a replacement + for i in range(len(split_sizes)): + torch._check_is_size( + split_sizes[i], + lambda: "split_with_sizes expects split_sizes have only non-negative entries", + ) + torch._check_with( + ValueError, + sum(split_sizes) == self.shape[dim], + lambda: f"Split sizes add up to {sum(split_sizes)} but got the tensor's size of {self.shape[dim]}", + ) + num_splits = len(split_sizes) + splits = [] + start_idx = 0 + + for i in range(num_splits): + length = split_sizes[i] + splits.append(self.narrow(dim, start_idx, length)) + start_idx += length + return splits + + +# out_wrapper currently does not allow optional outputs +@register_decomposition( + [aten.split_with_sizes_copy.default, aten.split_with_sizes_copy.out] +) +def split_with_sizes_copy( + self: Tensor, + split_sizes: List[int], + dim: int = 0, + out: Optional[List[Tensor]] = None, +) -> Optional[List[Tensor]]: + splits = split_with_sizes(self, split_sizes, dim=dim) + if out is None: + return [s.clone(memory_format=torch.contiguous_format) for s in splits] + else: + for output, split in zip(out, splits): + _maybe_resize_out(output, split.shape) + _safe_copy_out(copy_from=split, copy_to=output, exact_dtype=True) + return None + + +@register_decomposition(aten.unsafe_split.Tensor) +def unsafe_split(input: Tensor, split_size: int, dim: int = 0) -> Tuple[Tensor, ...]: + return aten.split.Tensor(input, split_size, dim) + + +@register_decomposition(aten.unsafe_split_with_sizes.default) +def unsafe_split_with_sizes( + input: Tensor, split_sizes: List[int], dim: int = 0 +) -> Tuple[Tensor, ...]: + return aten.split_with_sizes.default(input, split_sizes, dim) + + +@register_decomposition(aten.split.Tensor) +def split(self: Tensor, split_size: int, dim: int = 0) -> Tuple[Tensor, ...]: + input_sizes = self.shape + dim_size = input_sizes[dim] + if split_size == 0: + assert dim_size == 0 + return (self,) + chunks = (dim_size + split_size - 
1) // split_size + + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import guard_int + + chunks = guard_int(chunks) + split_sizes = [split_size for i in range(chunks)] + split_sizes[-1] = split_size - (split_size * chunks - dim_size) + return torch.split(self, split_sizes, dim) + + +@aten.tensor_split.tensor_indices_or_sections.py_impl( + DispatchKey.CompositeImplicitAutograd +) +def tensor_split_tensor_indices_or_sections_py_impl( + self: Tensor, + tensor_indices_or_sections: Tensor, + dim: int = 0, +) -> Tuple[Tensor, ...]: + assert tensor_indices_or_sections.device.type == "cpu" + assert tensor_indices_or_sections.dtype == torch.int64 + split_dim = tensor_indices_or_sections.dim() + torch._check( + split_dim == 1 or split_dim == 0, + lambda: "tensor_split expected tensor_indices_or_sections to be a zero-dimensional " + f"or one-dimensional tensor, but got a tensor with {split_dim} dims", + ) + if split_dim == 0: + sections = tensor_indices_or_sections.item() + assert isinstance(sections, IntLike) + return self.tensor_split(sections, dim) + else: + indices = [i.item() for i in tensor_indices_or_sections] + # WARNING: Tempted to torch._check_is_size on the indices here? You + # can't: tensor_split works with negative values in indices: + # + # >>> torch.tensor_split(torch.randn(10), torch.tensor([-5, 5])) + # (tensor([ 0.3540, 2.1074, -0.8507, 1.1639, 0.3055]), tensor([]), + # tensor([-0.4285, 1.0692, -0.1776, 0.9362, 1.6143])) + # + # Sorry, I don't make the rules. Explicitly do the item call in user + # code if you KNOW that they are non-negative. + return self.tensor_split(indices, dim) + + +# TODO: this doesn't appear to have enough precision in bfloat16 +@register_decomposition(aten.addmm) +@out_wrapper() +@pw_cast_for_opmath +def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1): + if not self.is_floating_point() and not self.is_complex(): + beta = int(beta) + alpha = int(alpha) + out = alpha * torch.mm(mat1, mat2) + if beta == 0: + return out + + # The output of aten.addmm is contiguous, we need to match this behavior in the decomposition. + # The original implementation 'beta * self + out' would return a strided tensor if `self` is strided. + # We thus use `out`, the output of torch.mm, which is always contiguous, as the first argument for addition. + # This is relying on TensorIterator's behavior that it takes higher precedence on the stride of first input. + # Alternative, we can write `(beta * self + out).contiguous()`, but it introduces another copy in some cases. + # This implementation is not ideal, and we should revisit this when we have a better solution. 
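+ # Rough illustration of the point above (hypothetical shapes, relying on the stride-precedence
+ # behaviour described in this comment): with self = torch.randn(4, 4).t() (non-contiguous) and
+ # out = torch.mm(mat1, mat2) (contiguous), `beta * self + out` would inherit the transposed
+ # strides of `self`, whereas `out + beta * self` inherits the contiguous strides of `out`,
+ # matching the contiguous output of eager aten.addmm.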
+ return out + beta * self + + +@register_decomposition(aten._addmm_activation) +@out_wrapper() +@pw_cast_for_opmath +def _addmm_activation( + self: Tensor, + mat1: Tensor, + mat2: Tensor, + beta: int = 1, + alpha: int = 1, + use_gelu: bool = False, +): + out = addmm(self, mat1, mat2, beta, alpha) + if use_gelu: + if self.is_cuda: + return aten.gelu(out, approximate="tanh") + else: + return aten.gelu(out) + return aten.relu(out) + + +@register_decomposition(aten.addmv) +@out_wrapper() +@pw_cast_for_opmath +def addmv(self: Tensor, mat1: Tensor, vec: Tensor, beta: int = 1, alpha: int = 1): + if not self.is_floating_point() and not self.is_complex(): + beta = int(beta) + alpha = int(alpha) + out = alpha * torch.mv(mat1, vec) + if beta == 0: + return out + return out + beta * self + + +@register_decomposition(aten.native_group_norm_backward.default) +@pw_cast_for_opmath +def native_group_norm_backward( + grad_output: Tensor, + input: Tensor, + mean: Tensor, + rstd: Tensor, + gamma: Optional[Tensor], + N: int, + C: int, + HxW: int, + group: int, + output_mask: List[bool], +) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + utils.check_same_device( + grad_output, input, mean, rstd, allow_cpu_scalar_tensors=False + ) + utils.check_same_shape(input, grad_output, allow_cpu_scalar_tensors=False) + utils.check_same_shape(mean, rstd, allow_cpu_scalar_tensors=False) + torch._check( + input.numel() == N * C * HxW, + lambda: f"Expect input to have {N * C * HxW} elements", + ) + torch._check( + mean.shape == (N, group), + lambda: f"Expect mean to have shape ({N}, {group}, but got {mean.shape}", + ) + torch._check( + gamma is None or gamma.numel() == C, + lambda: f"Expect gamma to have {C} elements but got {gamma.numel() if gamma is not None else -1}", + ) + + cpg, _rem = divmod(C, group) + torch._check( + _rem == 0, + lambda: f"Expect number of channels {C} to be evenly-divisible by number of groups {group}", + ) + + # Compute Internal gradients + ds = torch.mul(grad_output, input).view(N, C, HxW).sum(dim=[2]) + db = grad_output.view(N, C, HxW).sum(dim=[2]) + + d_input: Optional[Tensor] = None + d_gamma: Optional[Tensor] = None + d_bias: Optional[Tensor] = None + if output_mask[0]: + s = 1.0 / (HxW * cpg) + if gamma is not None: + ds_val = torch.mul(ds, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2) + db_val = torch.mul(db, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2) + c1 = torch.mul( + rstd.unsqueeze(-1), + gamma.reshape(1, group, cpg), + ) + else: + ds_val = ds.reshape(N, group, cpg).sum(2) + db_val = db.reshape(N, group, cpg).sum(2) + c1 = torch.mul( + rstd.unsqueeze(-1), + torch.ones((1, group, cpg), device=rstd.device), + ) + c2 = (db_val * mean - ds_val) * rstd * rstd * rstd * s + c3 = -c2 * mean - db_val * rstd * s + + c1 = c1.unsqueeze(-1) + c2 = _unsqueeze_to_dim(c2, 4) + c3 = _unsqueeze_to_dim(c3, 4) + d_input = ( + torch.mul(grad_output.reshape(N, group, cpg, HxW), c1) + + torch.mul(input.reshape(N, group, cpg, HxW), c2) + + c3 + ) + d_input = d_input.reshape(input.shape).to(input.dtype) + if output_mask[1]: + d_gamma = ( + ( + (ds.view(N, group, cpg) - db.view(N, group, cpg) * mean.unsqueeze(-1)) + * rstd.unsqueeze(-1) + ) + .sum(dim=[0]) + .reshape(C) + ) + if output_mask[2]: + d_bias = db.sum(dim=[0]) + + return (d_input, d_gamma, d_bias) + + +# out_wrapper currently does not allow optional outputs +@register_decomposition(aten.native_group_norm_backward.out) +def native_group_norm_backward_out( + grad_output: Tensor, + input: Tensor, + mean: Tensor, + rstd: Tensor, 
+ gamma: Optional[Tensor], + N: int, + C: int, + HxW: int, + group: int, + output_mask: List[bool], + *, + out0: torch.Tensor, + out1: torch.Tensor, + out2: torch.Tensor, +) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + result = native_group_norm_backward( + grad_output, input, mean, rstd, gamma, N, C, HxW, group, output_mask + ) + grad_input = (out0, out1, out2) + for i, r in enumerate(result): + if r is not None: + _maybe_resize_out(grad_input[i], r.shape) + _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True) + + return grad_input + + +def _maybe_cast(x: Optional[Tensor], dtype) -> Optional[Tensor]: + if x is not None: + return x.to(dtype) + return x + + +# TODO: Take a closer look at the type promotion semantics +@register_decomposition(aten.native_layer_norm_backward.default) +def native_layer_norm_backward( + grad_out: Tensor, + input: Tensor, + normalized_shape: List[int], + mean: Tensor, + rstd: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + output_mask: List[bool], +) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + input_shape = input.shape + input_ndim = input.dim() + computation_dtype = utils.get_computation_dtype(input.dtype) + grad_out_cast, input_cast, weight_cast, bias_cast = ( + x.to(computation_dtype).contiguous() if x is not None else x + for x in (grad_out, input, weight, bias) + ) + assert grad_out_cast is not None + + axis = input_ndim - len(normalized_shape) + inner_dims = input_shape[axis:] + outer_dims = input_shape[:axis] + inner_dim_indices: List[int] = [] + outer_dim_indices: List[int] = [] + for i in range(input_ndim): + if i >= axis: + inner_dim_indices.append(i) + else: + outer_dim_indices.append(i) + + N = prod(inner_dims) # type: ignore[arg-type] + M = prod(outer_dims) # type: ignore[arg-type] + if M <= 0 or N <= 0: + return ( + input.new_zeros(input_shape) if output_mask[0] else None, + input.new_zeros(input_shape[axis:]) if output_mask[1] else None, + input.new_zeros(input_shape[axis:]) if output_mask[2] else None, + ) + mean = _unsqueeze_to_dim(mean, input_cast.dim()) # type: ignore[union-attr] + rstd = _unsqueeze_to_dim(rstd, input_cast.dim()) # type: ignore[union-attr] + x_hat = (input_cast - mean) * rstd + if weight_cast is not None: + grad_x_hat = grad_out_cast * weight_cast + else: + grad_x_hat = grad_out_cast + a = grad_x_hat * N + b = torch.sum(grad_x_hat, inner_dim_indices, True) + c1 = torch.mul(grad_x_hat, x_hat) + c2 = torch.sum(c1, inner_dim_indices, True) + c3 = torch.mul(x_hat, c2) + + inner = a - b - c3 + d_input: Optional[Tensor] = None + d_weight: Optional[Tensor] = None + d_bias: Optional[Tensor] = None + if output_mask[0]: + d_input = (rstd / N) * inner + + if output_mask[1] and weight_cast is not None: + if len(outer_dim_indices) > 0: + d_weight = torch.sum(grad_out_cast * x_hat, outer_dim_indices, False) + else: + d_weight = grad_out_cast * x_hat + + if output_mask[2] and bias_cast is not None: + if len(outer_dim_indices) > 0: + d_bias = torch.sum(grad_out_cast, outer_dim_indices, False) + else: + d_bias = grad_out_cast.clone() + + return ( + _maybe_cast(d_input, input.dtype), + _maybe_cast(d_weight, input.dtype), + _maybe_cast(d_bias, input.dtype), + ) + + +# out_wrapper currently does not allow optional outputs +@register_decomposition(aten.native_layer_norm_backward.out) +def native_layer_norm_backward_out( + grad_out: Tensor, + input: Tensor, + normalized_shape: List[int], + mean: Tensor, + rstd: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + 
output_mask: List[bool], + *, + out0: torch.Tensor, + out1: torch.Tensor, + out2: torch.Tensor, +) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + result = native_layer_norm_backward( + grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask + ) + grad_input = (out0, out1, out2) + for i, r in enumerate(result): + if r is not None: + _maybe_resize_out(grad_input[i], r.shape) + _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True) + + return grad_input + + +def native_batch_norm_helper( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + training: bool, + momentum: float, + eps: float, + functional: bool, +) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + reduction_dims = [0] + list(range(2, input.dim())) + computation_dtype = utils.get_computation_dtype(input.dtype) + new_running_mean = running_mean + new_running_var = running_var + if training: + computation_dtype = utils.get_computation_dtype(input.dtype) + input_acc = input.to(dtype=computation_dtype) + biased_var, mean = torch.var_mean( + input_acc, dim=reduction_dims, correction=0, keepdim=True + ) + rstd = torch.rsqrt(biased_var + eps) + + output = (input - mean) * rstd + + save_mean = torch.squeeze(mean, reduction_dims) + save_rstd = torch.squeeze(rstd, reduction_dims) + if running_mean is not None: + new_running_mean = momentum * save_mean + (1 - momentum) * running_mean + if not functional: + running_mean.copy_(new_running_mean) + if running_var is not None: + n = input.numel() / input.shape[1] + # This doesn't strictly match eager's numerics, which accumulates var sum and then directly applies the correction + # But... that would require re-implementing var here, for negligible numerics gain on a tensor whose + # numerics probably don't matter. 
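+ # Worked example of the correction applied below (illustrative): with n = input.numel() / input.shape[1]
+ # reduced elements per channel, the biased variance is rescaled by n / (n - 1), e.g. n = 4 gives a
+ # factor of 4 / 3, which is the unbiased estimate folded into running_var.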
+ squeezed_var = torch.squeeze(biased_var, reduction_dims) + unbiased_var = squeezed_var * (n / (n - 1)) + new_running_var = momentum * unbiased_var + (1 - momentum) * running_var + if not functional: + running_var.copy_(new_running_var) + else: + assert running_mean is not None and running_var is not None + running_mean = running_mean.to(dtype=computation_dtype, copy=True) + new_running_mean = running_mean + running_var = running_var.to(dtype=computation_dtype, copy=True) + new_running_var = running_var + mean = running_mean + invstd = 1 / (torch.sqrt(running_var + eps)) + # Very annoying inconsistency where CPU and CUDA give different shapes + if input.device.type != "cpu": + save_mean = running_mean + save_rstd = invstd + else: + save_mean = input.new_zeros((0,)) + save_rstd = input.new_zeros((0,)) + mean = _unsqueeze_to_dim(mean, input.dim() - 1) + invstd = _unsqueeze_to_dim(invstd, input.dim() - 1) + output = (input - mean) * invstd + + if weight is not None: + weight = weight.flatten() + weight = _unsqueeze_to_dim(weight, input.dim() - 1) + output = output * weight + + if bias is not None: + bias = bias.flatten() + bias = _unsqueeze_to_dim(bias, input.dim() - 1) + output = output + bias + + if input.device.type == "cpu": + save_mean = save_mean.to(dtype=input.dtype) + save_rstd = save_rstd.to(dtype=input.dtype) + return ( + output.to(dtype=input.dtype), + save_mean, + save_rstd, + new_running_mean, + new_running_var, + ) + + +@register_decomposition(aten.native_batch_norm) +@out_wrapper("out", "save_mean", "save_invstd") +def native_batch_norm( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + training: bool, + momentum: float, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor]: + output, save_mean, save_rstd, _, _ = native_batch_norm_helper( + input, weight, bias, running_mean, running_var, training, momentum, eps, False + ) + return output, save_mean, save_rstd + + +# TODO: this decomposition is NOT here to stay. We would much prefer replacing native_batch_norm +# with our new correctly schema'd _native_batch_norm_legit and its variants, but +# we cannot do that immediately in the C++ because it would be forwards incompatible +# with some mobile use cases. +# +# Since this change is most impactful for aot autograd/functionalization, we simply +# register this decomposition on the Autograd key for the python dispatcher (which is +# currently only used by aot autograd/functionalization and no one else, really). +# In two weeks or so, we should remove this decomposition and phase out the current native_batch_norm +# to be _native_batch_norm_legit and have the right schema (stating that there are input mutations). +@aten.native_batch_norm.default.py_impl(DispatchKey.Autograd) +@aten.native_batch_norm.default.py_impl(DispatchKey.CompositeImplicitAutograd) +def native_batch_norm_decomposition( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + training: bool, + momentum: float, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor]: + if running_mean is None and running_var is None: + return aten._native_batch_norm_legit( + input, weight, bias, training, momentum, eps + ) + if running_mean is None: + raise RuntimeError( + "running_mean is None, but running_var is provided. " + "They should both be None or both be provided." 
+ ) + if running_var is None: + raise RuntimeError( + "running_var is None, but running_mean is provided. " + "They should both be None or both be provided." + ) + if training: + # HACK: batch norm consolidation should clean this up so this op doesn't take in a training arg. + return aten._native_batch_norm_legit( + input, weight, bias, running_mean, running_var, training, momentum, eps + ) + else: + return aten._native_batch_norm_legit_no_training( + input, weight, bias, running_mean, running_var, momentum, eps + ) + + +@aten.unsafe_chunk.default.py_impl(DispatchKey.CompositeImplicitAutograd) +def unsafe_chunk_py_impl(tensor, chunks, dim=0) -> List[Tensor]: + dim_size = tensor.size(dim) + split_size = (dim_size + chunks - 1) // chunks + + if split_size == 0 and dim_size == 0: + split_sizes = [split_size for _ in range(chunks)] + split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size) + return torch.ops.aten.unsafe_split_with_sizes.default(tensor, split_sizes, dim) + return torch.ops.aten.unsafe_split.Tensor(tensor, split_size, dim) + + +@register_decomposition(aten._native_batch_norm_legit_no_training.default) +def _native_batch_norm_legit_no_training( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + momentum: float, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor]: + return aten._native_batch_norm_legit.default( + input, + weight, + bias, + running_mean, + running_var, + False, # training + momentum, + eps, + ) + + +@register_decomposition(aten._native_batch_norm_legit.default) +def _native_batch_norm_legit( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + training: bool, + momentum: float, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor]: + output, save_mean, save_rstd, _, _ = native_batch_norm_helper( + input, weight, bias, running_mean, running_var, training, momentum, eps, False + ) + return output, save_mean, save_rstd + + +@register_decomposition(aten._native_batch_norm_legit.no_stats) +def _native_batch_norm_legit_no_stats( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + training: bool, + momentum: float, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor]: + output, save_mean, save_rstd, _, _ = native_batch_norm_helper( + input, weight, bias, None, None, training, momentum, eps, False + ) + return output, save_mean, save_rstd + + +@register_decomposition(aten._native_batch_norm_legit_functional.default) +def _native_batch_norm_legit_functional( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + training: bool, + momentum: float, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + ( + output, + save_mean, + save_rstd, + new_running_mean, + new_running_var, + ) = native_batch_norm_helper( + input, weight, bias, running_mean, running_var, training, momentum, eps, True + ) + assert new_running_mean is not None, "new_running_mean should not be None" + assert new_running_var is not None, "new_running_var should not be None" + return output, save_mean, save_rstd, new_running_mean, new_running_var + + +def _get_batch_norm_reserve_tensor( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + eps: float, + training: bool, +) -> Tensor: + """ + Return a reserve tensor for batch norm, used only by cudnn to pass forward state to the + backward pass.
This is needed for `_batch_norm_with_update` and `_batch_norm_no_update`, + which support a variety of backends including cudnn. We create this tensor here to get + the correct shape in the traced graph if we detect that will call the cudnn kernel, + and rely on DCE to avoid materializing this tensor. + """ + backend = torch._C._select_batch_norm_backend( # type: ignore[attr-defined] + input, weight, bias, running_mean, running_var, True, eps + ) + reserve_size = 0 + if backend == torch._C._BatchNormBackend.Cudnn: # type: ignore[attr-defined] + reserve_size = torch._C._get_cudnn_batch_norm_reserve_space_size(input, training) # type: ignore[attr-defined] + return torch.empty( + reserve_size, dtype=torch.uint8, layout=input.layout, device=input.device + ) + + +@register_decomposition(aten._batch_norm_with_update.default) +def _batch_norm_with_update( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + momentum: float, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + output, save_mean, save_rstd, _, _ = native_batch_norm_helper( + input, + weight, + bias, + running_mean, + running_var, + True, # training + momentum, + eps, + False, # functional + ) + reserve = _get_batch_norm_reserve_tensor( + input, weight, bias, running_mean, running_var, eps, training=True + ) + return output, save_mean, save_rstd, reserve + + +@register_decomposition(aten._batch_norm_with_update_functional.default) +def _batch_norm_with_update_functional( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + momentum: float, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + ( + output, + save_mean, + save_rstd, + new_rm, + new_rv, + ) = native_batch_norm_helper( + input, weight, bias, running_mean, running_var, True, momentum, eps, True + ) + reserve = _get_batch_norm_reserve_tensor( + input, weight, bias, running_mean, running_var, eps, training=True + ) + assert new_rm is not None, "new_running_mean should not be None" + assert new_rv is not None, "new_running_var should not be None" + return (output, save_mean, save_rstd, reserve, new_rm, new_rv) + + +@register_decomposition(aten._batch_norm_no_update.default) +def _batch_norm_no_update( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + momentum: float, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + output, save_mean, save_rstd, _, _ = native_batch_norm_helper( + input, + weight, + bias, + running_mean, + running_var, + False, # training + momentum, + eps, + False, # functional + ) + reserve = _get_batch_norm_reserve_tensor( + input, weight, bias, running_mean, running_var, eps, training=False + ) + return output, save_mean, save_rstd, reserve + + +@register_decomposition(aten._fused_dropout) +@out_wrapper("out0", "out1") +@pw_cast_for_opmath +def _fused_dropout_decomposition(input, p, generator=None): + assert generator is None + mask = (torch.rand_like(input) < p).to(dtype=torch.uint8) + res = mask.type_as(input) * input * (1.0 / p) + return (res, mask) + + +@register_decomposition(aten._to_copy) +@out_wrapper() +def _to_copy( + x: Union[Tensor, NumberType], + *, + dtype: Optional[torch.dtype] = None, + layout=None, + device: Optional[torch.device] = None, + pin_memory: bool = False, + non_blocking: bool = False, + memory_format: Optional[torch.memory_format] = None, +): + assert not layout or layout == 
torch.strided, "TODO" + assert not pin_memory, "TODO" + assert isinstance(x, (torch.Tensor, int, float, bool, complex)) + if device is None and dtype is None and memory_format is None: + if isinstance(x, torch.Tensor): + return x.clone() + else: + return x + dtype_converted = False + + if isinstance(x, torch.Tensor): + x_tensor = x + else: + x_tensor = torch.scalar_tensor(x) + + if device is not None and device != x_tensor.device: + # avoid conversions on cpu + if dtype is not None and device.type == "cpu": + x_tensor = torch._prims.convert_element_type(x_tensor, dtype) + dtype_converted = True + x_tensor = torch._prims.device_put(x_tensor, device) + + if dtype is not None and not dtype_converted: + x_tensor = torch._prims.convert_element_type(x_tensor, dtype) + dtype_converted = True + + if memory_format is not None: # no ref/prim for memory format + return torch.clone(x_tensor, memory_format=memory_format) + return x_tensor + + +# Questionable decompositions +# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced. +# Note that this decomposition causes issues with in-place ops +@register_decomposition([aten.detach, aten.lift, aten.lift_fresh]) +@out_wrapper() +def nop_decomposition(x): + return aten.alias(x) + + +# Also register to the Autograd dispatch key, so this decomp can run above autograd. +# native_batch_norm needs to decompose into other ops before autograd. +@aten.cudnn_batch_norm.default.py_impl(DispatchKey.Autograd) +@register_decomposition(aten.cudnn_batch_norm) +@out_wrapper("out0", "out1", "out2", "out3") +def cudnn_batch_norm( + input: Tensor, + weight: Tensor, + bias: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + training: bool, + exponential_average_factor: float, + epsilon: float, +): + a, b, c = aten.native_batch_norm( + input, + weight, + bias, + running_mean, + running_var, + training, + exponential_average_factor, + epsilon, + ) + # Cudnn return running mean and variance when training is True + if training: + return (a, b, c, input.new_zeros((0,), dtype=torch.uint8)) + return ( + a, + weight.new_zeros((0,)), + weight.new_zeros((0,)), + input.new_zeros((0,), dtype=torch.uint8), + ) + + +def _broadcast_batch_norm_backward(x, broadcast_mask): + for axis, mask in enumerate(broadcast_mask): + if mask == 1 and not (axis < x.ndim and x.shape[axis] == mask): + x = x.unsqueeze(axis) + return x + + +@register_decomposition(aten.batch_norm_backward.default) +def batch_norm_backward( + grad_out: Tensor, + input: Tensor, + weight: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_invstd: Optional[Tensor], + train: bool, + eps: float, + output_mask: List[bool], + reserve: Tensor, +) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + return native_batch_norm_backward( + grad_out, + input, + weight, + running_mean, + running_var, + save_mean, + save_invstd, + train, + eps, + output_mask, + ) + + +@register_decomposition(aten.native_batch_norm_backward.default) +def native_batch_norm_backward( + grad_out: Tensor, + input: Tensor, + weight: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_invstd: Optional[Tensor], + train: bool, + eps: float, + output_mask: List[bool], +) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + input_dtype = input.dtype + if weight is not None: + weight_dtype = weight.dtype + else: + weight_dtype = 
input_dtype + computation_dtype = utils.get_computation_dtype(input.dtype) + ( + grad_out_cast, + input_cast, + weight_cast, + running_mean_cast, + running_var_cast, + save_mean_cast, + save_invstd_cast, + ) = ( + x.to(computation_dtype) if x is not None else x + for x in ( + grad_out, + input, + weight, + running_mean, + running_var, + save_mean, + save_invstd, + ) + ) + input_shape = input.shape + input_rank = input.dim() + assert input_rank >= 2, "rank of the input must be at least 2" + + axis = 1 + num_features = prod(list(input_shape)) / input_shape[axis] + mean = save_mean_cast + invstd = save_invstd_cast + if train: + assert save_mean_cast is not None and save_invstd_cast is not None + else: + assert running_mean_cast is not None and running_var_cast is not None + mean = running_mean_cast + invstd = torch.rsqrt(running_var_cast + eps) + + broadcast_mask: List[int] = [1] * input_rank + broadcast_mask[axis] = input_shape[axis] + + reduction_axes: List[int] = [] + for i in range(input_rank): + if i != axis: + reduction_axes.append(i) + + mean = _broadcast_batch_norm_backward(mean, broadcast_mask) # type: ignore[arg-type] + norm = 1.0 / num_features + grad_output_sum = torch.sum(grad_out_cast, reduction_axes) # type: ignore[arg-type] + dot_p = torch.sum(grad_out_cast * (input_cast - mean), reduction_axes) # type: ignore[operator] + + grad_mean = _broadcast_batch_norm_backward(grad_output_sum * norm, broadcast_mask) + proj_scale = _broadcast_batch_norm_backward(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask) # type: ignore[operator] + + if weight_cast is None: + grad_scale = _broadcast_batch_norm_backward(invstd, broadcast_mask) * 1.0 # type: ignore[arg-type] + else: + grad_scale = _broadcast_batch_norm_backward( + invstd * weight_cast, broadcast_mask + ) + + if train: + proj = (input_cast - mean) * proj_scale # type: ignore[operator] + grad_input = ((grad_out_cast - proj) - grad_mean) * grad_scale + else: + grad_input = grad_out_cast * grad_scale + + if output_mask[1]: + grad_weight = dot_p * invstd + else: + grad_weight = None # "None" doesn't work with vjp, should use zeros for vjp + + if output_mask[2]: + grad_bias = grad_output_sum + else: + grad_bias = None # "None" doesn't work with vjp, should use zeros for vjp + + return ( + grad_input.to(input_dtype), + _maybe_cast(grad_weight, weight_dtype), + _maybe_cast(grad_bias, weight_dtype), + ) + + +# out_wrapper currently does not allow optional outputs +@register_decomposition(aten.native_batch_norm_backward.out) +def native_batch_norm_backward_out( + grad_out: Tensor, + input: Tensor, + weight: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_invstd: Optional[Tensor], + train: bool, + eps: float, + output_mask: List[bool], + *, + out0: torch.Tensor, + out1: torch.Tensor, + out2: torch.Tensor, +) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + result = native_batch_norm_backward( + grad_out, + input, + weight, + running_mean, + running_var, + save_mean, + save_invstd, + train, + eps, + output_mask, + ) + grad_input = (out0, out1, out2) + for i, r in enumerate(result): + if r is not None: + _maybe_resize_out(grad_input[i], r.shape) + _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True) + + return grad_input + + +@register_decomposition(aten.miopen_batch_norm_backward) +@out_wrapper("out0", "out1", "out2") +def miopen_batch_norm_backward( + input: Tensor, + grad_output: Tensor, + weight: Tensor, + running_mean: Optional[Tensor], + 
running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_var: Optional[Tensor], + epsilon: float, +): + return aten.native_batch_norm_backward( + grad_output, + input, + weight, + running_mean, + running_var, + save_mean, + save_var, + True, + epsilon, + [True, True, True], + ) + + +@register_decomposition(aten.cudnn_batch_norm_backward) +@out_wrapper("out0", "out1", "out2") +def cudnn_batch_norm_backward( + input: Tensor, + grad_output: Tensor, + weight: Tensor, + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_var: Optional[Tensor], + epsilon: float, + reserveSpace: Tensor, +): + return aten.native_batch_norm_backward( + grad_output, + input, + weight, + running_mean, + running_var, + save_mean, + save_var, + True, + epsilon, + [True, True, True], + ) + + +@register_decomposition(aten._adaptive_avg_pool2d) +@out_wrapper() +@pw_cast_for_opmath +def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]): + # Preconditions + device = input.device + shape = input.shape + ndim = len(shape) + torch._check( + ndim in (3, 4), + lambda: f"adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got {ndim}", + ) + for d in input.shape[-2:]: + torch._check( + d != 0, + lambda: "adaptive_avg_pool2d(): Expected input to have non-zero size for " + f"non-batch dimensions, but input has shape {tuple(shape)}.", + ) + + # Optimisation (we should also do this in the kernel implementation) + if shape[-2] % output_size[-2] == 0 and shape[-1] % output_size[-1] == 0: + stride = tuple(i // o for i, o in zip(shape[-2:], output_size)) + kernel = tuple( + i - (o - 1) * s for i, o, s in zip(shape[-2:], output_size, stride) + ) + return torch.nn.functional.avg_pool2d(input, kernel, stride) + + def start_index(a, b, c): + return torch.div(a * c, b, rounding_mode="trunc") + + def end_index(a, b, c): + return torch.div((a + 1) * c + b - 1, b, rounding_mode="trunc") + + def compute_idx(in_size, out_size): + orange = torch.arange(out_size, device=device, dtype=torch.int64) + i0 = start_index(orange, out_size, in_size) + # Let length = end_index - start_index, i.e. 
the length of the pooling kernels + # length.max() can be computed analytically as follows: + maxlength = in_size // out_size + 1 + in_size_mod = in_size % out_size + # adaptive = True iff there are kernels with different lengths + adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0) + if adaptive: + maxlength += 1 + elif in_size_mod == 0: + maxlength -= 1 + + range_max = torch.arange(maxlength, device=device, dtype=torch.int64) + idx = i0.unsqueeze(-1) + range_max + if adaptive: + # Need to clamp to avoid accessing out-of-bounds memory + # TODO make minimum accept scalars + maxval = torch.scalar_tensor( + in_size - 1, dtype=idx.dtype, device=idx.device + ) + idx = torch.minimum(idx, maxval) + + # Compute the length + i1 = end_index(orange, out_size, in_size) + length = i1 - i0 + else: + length = maxlength + return idx, length, range_max, adaptive + + # length is not None if it's constant, otherwise we'll need to compute it + idxh, length_h, range_max_h, adaptive_h = compute_idx(shape[-2], output_size[-2]) + idxw, length_w, range_max_w, adaptive_w = compute_idx(shape[-1], output_size[-1]) + + vals = input[..., _unsqueeze_to_dim(idxh, 4), idxw] + # Shortcut for the simpler case + if not adaptive_h and not adaptive_w: + return torch.mean(vals, dim=(-3, -1)) + + def maybe_mask(vals, length, range_max, adaptive, dim): + if isinstance(length, IntLike): + return vals, length + else: + # zero-out the things we didn't really want to select + assert dim < 0 + # hack + mask = range_max >= length.unsqueeze(-1) + if dim == -2: + mask = _unsqueeze_to_dim(mask, 4) + vals = torch.masked_fill(vals, mask, 0.0) + # Compute the length of each window + length = _unsqueeze_to_dim(length, -dim) + return vals, length + + vals, length_h = maybe_mask( + vals, length_h, range_max_h, adaptive=adaptive_h, dim=-2 + ) + vals, length_w = maybe_mask( + vals, length_w, range_max_w, adaptive=adaptive_w, dim=-1 + ) + + # We unroll the sum as we assume that the kernels are going to be small + ret = None + for i, j in product(range(vals.shape[-3]), range(vals.shape[-1])): + if ret is None: + ret = vals[..., i, :, j] + else: + ret = ret + vals[..., i, :, j] + return ret / (length_h * length_w) + + +@register_decomposition(aten.index_add_) +def index_add_( + x: TensorLike, + dim: int, + index: TensorLike, + tensor: TensorLike, + *, + alpha: NumberType = 1, +): + return _index_add(x, dim, index, tensor, inplace=True, alpha=alpha) + + +@register_decomposition(aten.index_add) +@out_wrapper() +def index_add( + x: TensorLike, + dim: int, + index: TensorLike, + tensor: TensorLike, + *, + alpha: NumberType = 1, +): + return _index_add(x, dim, index, tensor, inplace=False, alpha=alpha) + + +def _index_add( + x: TensorLike, + dim: int, + index: TensorLike, + tensor: TensorLike, + *, + inplace: bool, + alpha: NumberType = 1, +): + dim = utils.canonicalize_dims(x.ndim, dim) + torch._check( + index.ndim <= 1, + lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", + ) + index_size = index.size(0) if index.ndim == 1 else 1 + tensor_size = tensor.size(dim) if tensor.ndim > 0 else 1 + torch._check( + tensor_size == index_size, + lambda: f"Number of indices ({index_size}) should be equal to tensor.size(dim) ({tensor_size}), for {dim=}", + ) + if alpha != 1: + python_type = utils.dtype_to_type(x.dtype) + torch._check( + python_type == bool + or utils.is_weakly_lesser_type(type(alpha), python_type), + lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!", + ) + tensor = tensor * 
alpha + # Treat scalars as elements of \R^1 + zero_dim = x.ndim == 0 + x1 = x.unsqueeze(0) if zero_dim else x + idx = (None,) * dim + (index,) + index_put = aten.index_put_ if inplace else aten.index_put + out = index_put(x1, idx, tensor, accumulate=True) + if inplace: + return x + else: + return out.squeeze(0) if zero_dim else out.contiguous() + + +@register_decomposition(aten.pad_sequence.default) +@aten.pad_sequence.default.py_impl(DispatchKey.CompositeImplicitAutograd) +def pad_sequence(sequences, batch_first=False, padding_value=0.0): + torch._check(len(sequences) > 0, lambda: "received an empty list of sequences") + sequences_size = len(sequences) + max_size = sequences[0].size() + trailing_dims = max_size[1:] + max_len = max(x.size(0) for x in sequences) + if batch_first: + out_dims = (sequences_size, max_len) + else: + out_dims = (max_len, sequences_size) + out_dims = out_dims + trailing_dims + out = sequences[0].new_full(out_dims, padding_value) + dim_paddings = (0, 0) * len(trailing_dims) + for i in range(sequences_size): + currseq = sequences[i] + row = aten.constant_pad_nd( + currseq, dim_paddings + (0, max_len - currseq.size(0)), padding_value + ) + if batch_first: + out = aten.select_scatter(out, row, dim=0, index=i) + else: + out = aten.select_scatter(out, row, dim=1, index=i) + return out + + +@register_decomposition(aten.index_copy_) +def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): + return _index_copy(x, dim, index, tensor, inplace=True) + + +@register_decomposition(aten.index_copy) +@out_wrapper() +def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): + return _index_copy(x, dim, index, tensor, inplace=False) + + +def _index_copy( + x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike, *, inplace: bool +): + dim = utils.canonicalize_dims(x.ndim, dim) + torch._check( + index.ndim <= 1, + lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", + ) + # Treat scalars as elements of \R^1 + zero_dim = x.ndim == 0 + x1 = x.unsqueeze(0) if zero_dim else x + index = index.unsqueeze(0) if index.ndim == 0 else index + idx = (None,) * dim + (index,) + index_put = aten.index_put_ if inplace else aten.index_put + out = index_put(x1, idx, tensor) + if inplace: + return x + else: + return out.squeeze(0) if zero_dim else out.contiguous() + + +# nb: Should use acc_t, not op_math +@register_decomposition(aten.log_sigmoid_forward) +@out_wrapper("output", "buffer") +@pw_cast_for_opmath +def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]: + min = torch.minimum(self.new_zeros(()), self) + z = torch.exp(-torch.abs(self)) + if self.is_cuda: + buffer = self.new_zeros((0,)) + else: + buffer = z + return min - torch.log1p(z), buffer + + +@register_decomposition(aten.uniform) +@out_wrapper() +def uniform( + x: Tensor, + low: Union[bool, int, float] = 0.0, + high: Union[bool, int, float] = 1.0, + generator: Optional[torch.Generator] = None, +): + return prims._uniform_helper( + x.shape, + low=sym_float(low), + high=sym_float(high), + dtype=x.dtype, + device=x.device, + generator=generator, + ) + + +@register_decomposition(aten.uniform_) +def uniform_(self, low=0, high=1, generator=None): + return self.copy_(uniform(self, low, high, generator)) + + +# aten/src/ATen/native/UpSample.cpp compute_output_size +def upsample_compute_output_size(input_size, output_size, scale_factors): + spatial_dimensions = len(input_size) - 2 + if output_size is not None: + torch._check( + scale_factors is None, + lambda: "Must 
specify exactly one of output_size and scale_factors", + ) + torch._check(len(output_size) == spatial_dimensions, lambda: "") + return output_size + if scale_factors is not None: + # NB: this isn't necessary lol + torch._check( + output_size is None, + lambda: "Must specify exactly one of output_size and scale_factors", + ) + torch._check(len(scale_factors) == spatial_dimensions, lambda: "") + output_size = [] + for i, s in enumerate(scale_factors): + if int(s) == s: + output_size.append(input_size[i + 2] * int(s)) + else: + output_size.append(sym_int(input_size[i + 2] * s)) + return output_size + torch._check( + False, lambda: "Must specify exactly one of output_size and scale_factors" + ) + + +def get_scale_value(scales, idx): + if scales is None: + return None + return scales[idx] + + +@register_decomposition(aten.upsample_nearest1d.vec) +@register_decomposition(aten.upsample_nearest2d.vec) +@register_decomposition(aten.upsample_nearest3d.vec) +@aten.upsample_nearest1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_nearest1d.vec.py_impl(DispatchKey.Autograd) +@aten.upsample_nearest2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_nearest2d.vec.py_impl(DispatchKey.Autograd) +@aten.upsample_nearest3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_nearest3d.vec.py_impl(DispatchKey.Autograd) +def _upsample_nearest_vec( + input: Tensor, + output_size: Optional[List[int]], + scale_factors: Optional[List[float]], +) -> Tensor: + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scales = ( + scale_factors if scale_factors else [None] * len(osize) # type: ignore[list-item] + ) + return _upsample_nearest(input, osize, scales) + + +@register_decomposition(aten._upsample_nearest_exact1d.vec) +@register_decomposition(aten._upsample_nearest_exact2d.vec) +@register_decomposition(aten._upsample_nearest_exact3d.vec) +@aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.Autograd) +@aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.Autograd) +@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.Autograd) +def _upsample_nearest_exact_vec( + input: Tensor, + output_size: Optional[List[int]], + scale_factors: Optional[List[float]], +) -> Tensor: + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scales = ( + scale_factors if scale_factors else [None] * len(osize) # type: ignore[list-item] + ) + return _upsample_nearest(input, osize, scales, exact=True) + + +def _compute_upsample_nearest_indices(input, output_size, scales, exact=False): + # For each dim in output_size, compute the set of input indices used + # to produce the upsampled output. 
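    # A small worked illustration of the index math used below (values chosen
    # purely for exposition, not taken from the surrounding code): with
    # isize=5, osize=3 and no explicit scale, scale = 5 / 3, so
    #   exact=False: trunc([0, 1, 2] * 5/3)         -> input indices [0, 1, 3]
    #   exact=True:  trunc(([0, 1, 2] + 0.5) * 5/3) -> input indices [0, 2, 4]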
+ indices = [] + num_spatial_dims = len(output_size) + offset = 0.5 if exact else 0.0 + + for d in range(num_spatial_dims): + # Math matches aten/src/ATen/native/cpu/UpSampleKernel.cpp + # + # Indices are computed as following: + # scale = isize / osize + # Case: exact=False + # input_index = floor(output_index * scale) + # Same as OpenCV INTER_NEAREST + # + # Case: exact=False + # index_f32 = (output_index + 0.5) * scale - 0.5 + # input_index = round(index_f32) + # Same as Pillow and Scikit-Image/Scipy ndi.zoom + osize = output_size[d] + isize = input.shape[-num_spatial_dims + d] + scale = isize / (isize * scales[d]) if scales[d] is not None else isize / osize + + output_indices = torch.arange(osize, dtype=torch.float32, device=input.device) + input_indices = ((output_indices + offset) * scale).to(torch.int64) + for _ in range(num_spatial_dims - 1 - d): + input_indices = input_indices.unsqueeze(-1) + indices.append(input_indices) + return indices + + +@register_decomposition([aten.upsample_nearest1d.default, aten.upsample_nearest1d.out]) +@aten.upsample_nearest1d.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_nearest1d.default.py_impl(DispatchKey.Autograd) +@out_wrapper(preserve_memory_format=True, exact_dtype=True) +def upsample_nearest1d( + input: Tensor, + output_size: List[int], + scales: Optional[float] = None, +) -> Tensor: + return _upsample_nearest(input, output_size, [scales]) + + +@register_decomposition( + [aten._upsample_nearest_exact1d.default, aten._upsample_nearest_exact1d.out] +) +@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.Autograd) +@out_wrapper(preserve_memory_format=True, exact_dtype=True) +def upsample_nearest_exact1d( + input: Tensor, + output_size: List[int], + scales: Optional[float] = None, +) -> Tensor: + return _upsample_nearest(input, output_size, [scales], exact=True) + + +@register_decomposition([aten.upsample_nearest2d.default, aten.upsample_nearest2d.out]) +@aten.upsample_nearest2d.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd) +@out_wrapper(preserve_memory_format=True, exact_dtype=True) +def upsample_nearest2d( + input: Tensor, + output_size: List[int], + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> Tensor: + return _upsample_nearest(input, output_size, [scales_h, scales_w]) + + +@register_decomposition( + [aten._upsample_nearest_exact2d.default, aten._upsample_nearest_exact2d.out] +) +@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.Autograd) +@out_wrapper(preserve_memory_format=True, exact_dtype=True) +def _upsample_nearest_exact2d( + input: Tensor, + output_size: List[int], + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> Tensor: + return _upsample_nearest(input, output_size, [scales_h, scales_w], exact=True) + + +@register_decomposition([aten.upsample_nearest3d.default, aten.upsample_nearest3d.out]) +@aten.upsample_nearest3d.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_nearest3d.default.py_impl(DispatchKey.Autograd) +@out_wrapper(preserve_memory_format=True, exact_dtype=True) +def upsample_nearest3d( + input: Tensor, + output_size: List[int], + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> Tensor: + return 
_upsample_nearest(input, output_size, [scales_d, scales_h, scales_w]) + + +@register_decomposition( + [aten._upsample_nearest_exact3d.default, aten._upsample_nearest_exact3d.out] +) +@aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.Autograd) +@out_wrapper(preserve_memory_format=True, exact_dtype=True) +def _upsample_nearest_exact3d( + input: Tensor, + output_size: List[int], + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> Tensor: + return _upsample_nearest( + input, output_size, [scales_d, scales_h, scales_w], exact=True + ) + + +@pw_cast_for_opmath +def _upsample_nearest( + input: Tensor, + output_size: List[int], + scales: List[Optional[float]], + exact: bool = False, +) -> Tensor: + spatial_indices = _compute_upsample_nearest_indices( + input, output_size, scales, exact=exact + ) + + indices = [None, None] + spatial_indices + result = aten._unsafe_index(input, indices) + + if result.ndim == 4: + # convert output to correct memory format, if necessary + memory_format = utils.suggest_memory_format(input) + + # following "heuristic: only use channels_last path when it's faster than the contiguous path" + n_channels = input.shape[1] + if input.device.type == "cuda" and n_channels < 4: + memory_format = torch.contiguous_format + + result = result.contiguous(memory_format=memory_format) + return result + + +def gather_params(params, has_biases, has_projections): + if has_biases and has_projections: + group_size = 5 + elif has_biases: + group_size = 4 + elif has_projections: + group_size = 3 + else: + group_size = 2 + + assert len(params) % group_size == 0, len(params) + return [ + tuple(params[i : i + group_size]) for i in range(0, len(params), group_size) + ] + + +def params_hiddens(params, hiddens, i, bidirectional): + if bidirectional: + cur_params, cur_hidden = params[2 * i], hiddens[2 * i] + bidir_params, bidir_hidden = params[2 * i + 1], hiddens[2 * i + 1] + else: + cur_params, cur_hidden = params[i], hiddens[i] + bidir_params, bidir_hidden = None, None + + return cur_params, cur_hidden, bidir_params, bidir_hidden + + +def update_hidden_for_packed(cur_hidden, last_batch_size, batch_size, hiddens): + assert last_batch_size > batch_size + hiddens.append(cur_hidden.narrow(0, batch_size, last_batch_size - batch_size)) + return cur_hidden.narrow(0, 0, batch_size) + + +def update_hidden_for_packed_reverse( + cur_hidden, last_batch_size, batch_size, inp_hidden +): + if last_batch_size == batch_size: + return cur_hidden + assert last_batch_size < batch_size + return torch.concat( + ( + cur_hidden, + inp_hidden.narrow(0, last_batch_size, batch_size - last_batch_size), + ) + ) + + +def one_layer_rnn_data( + inp, hidden, params, has_biases, hidden_fn, batch_sizes, reverse=False +): + ih_weight = params[0] + hh_weight = params[1] + ih_bias = params[2] if has_biases else None + hh_bias = params[3] if has_biases else None + + step_output = [] + hiddens: List[torch.Tensor] = [] + + last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0] + cur_hidden = hidden.narrow(0, 0, last_batch_size) + split_inp = torch.split(inp, list(batch_sizes)) + if reverse: + split_inp = split_inp[::-1] + for inp in split_inp: + i = inp.shape[0] + + if last_batch_size == i: + pass # don't update cur_hidden + # this will only happen when reverse=False, since batch sizes are sorted largest -> smallest + elif reverse: + cur_hidden = 
update_hidden_for_packed_reverse( + cur_hidden, last_batch_size, i, hidden + ) + else: + cur_hidden = update_hidden_for_packed( + cur_hidden, last_batch_size, i, hiddens + ) + + cur_hidden = hidden_fn(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias) + last_batch_size = i + step_output.append(cur_hidden) + + if reverse: + step_output.reverse() + else: + hiddens.append(cur_hidden) + hiddens.reverse() + + out = torch.cat(step_output, 0) + hidden_out = torch.cat(hiddens, 0) if not reverse else cur_hidden + return out, hidden_out + + +def rnn_cell(nonlinearity): + def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias): + return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i) + + return inner + + +def rnn_cell_data(nonlinearity): + def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias): + i = F.linear(i, ih_weight, ih_bias) + return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i) + + return inner + + +def one_layer_rnn(inp, hidden, params, has_biases, hidden_fn, reverse=False): + ih_weight = params[0] + hh_weight = params[1] + ih_bias = params[2] if has_biases else None + hh_bias = params[3] if has_biases else None + + precomputed_input = F.linear(inp, ih_weight, ih_bias) + precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input + cur_hidden = hidden.unsqueeze(0) + step_output = [] + for i in precomputed_input: + cur_hidden = hidden_fn(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias) + step_output.append(cur_hidden) + + if reverse: + step_output.reverse() + + out = torch.cat(step_output, 0) + + return out, cur_hidden.squeeze(0) + + +def mkldnn_one_layer_lstm(inp, hidden, params, has_biases, reverse=False): + w0 = params[0] + w1 = params[1] + if has_biases: + w2 = params[2] + w3 = params[3] + else: + w2 = torch.zeros(w0.size()) + w3 = torch.zeros(w1.size()) + + hx = hidden[0].unsqueeze(0) + cx = hidden[1].unsqueeze(0) + + batch_sizes: List[int] = [] + mode = 2 # third_party/ideep/include/ideep/abstract_types.hpp: ideep::rnn_kind::LSTM = 2 + hidden_size = hx.size(2) + num_layers = 1 + + # _rnn_helper already handles bidirectional and batch_first so we hard-code them to False here + bidirectional = False + batch_first = False + + train = False + # If batch_first, inp has been permuted in _rnn_helper. Convert to contiguous here. 
+ # Same as aten/src/ATen/native/mkldnn/RNN.cpp: mkldnn_rnn: input = input.contiguous(); + inp = inp.contiguous() + hx = hx.contiguous() + cx = cx.contiguous() + outputs = torch.ops.aten.mkldnn_rnn_layer.default( + inp, + w0, + w1, + w2, + w3, + hx, + cx, + reverse, + batch_sizes, + mode, + hidden_size, + num_layers, + has_biases, + bidirectional, + batch_first, + train, + ) + y, hy, cy = outputs[0], outputs[1], outputs[2] + return y, (hy.squeeze(0), cy.squeeze(0)) + + +def _rnn_helper( + input, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + layer_fn, +): + input = input.transpose(0, 1) if batch_first else input + final_hiddens = [] + + for i in range(num_layers): + cur_params, cur_hidden, bidir_params, bidir_hidden = params_hiddens( + params, hidden, i, bidirectional + ) + dropout = dropout if (train and num_layers < i - 1) else 0.0 + fwd_inp, fwd_hidden = layer_fn(input, cur_hidden, cur_params, has_biases) + final_hiddens.append(fwd_hidden) + + if bidirectional: + bwd_inp, bwd_hidden = layer_fn( + input, bidir_hidden, bidir_params, has_biases, reverse=True + ) + final_hiddens.append(bwd_hidden) + + if bidirectional: + input = torch.cat([fwd_inp, bwd_inp], fwd_inp.dim() - 1) # type: ignore[possibly-undefined] + else: + input = fwd_inp + + if dropout != 0 and train and i < num_layers - 1: + input = torch.dropout(input, dropout, train=True) + + input = input.transpose(0, 1) if batch_first else input + return input, final_hiddens + + +@register_decomposition(aten.rnn_tanh.input) +@aten.rnn_tanh.input.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.rnn_tanh.input.py_impl(DispatchKey.Autograd) +def rnn_tanh_input( + input, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, +): + hidden = hx.unbind(0) + params = gather_params(params, has_biases, False) + out, final_hiddens = _rnn_helper( + input, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + partial(one_layer_rnn, hidden_fn=rnn_cell(torch.tanh)), + ) + return out, torch.stack(final_hiddens, 0) + + +@register_decomposition(aten.rnn_relu.input) +@aten.rnn_relu.input.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.rnn_relu.input.py_impl(DispatchKey.Autograd) +def rnn_relu_input( + input, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, +): + hidden = hx.unbind(0) + params = gather_params(params, has_biases, False) + out, final_hiddens = _rnn_helper( + input, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + partial(one_layer_rnn, hidden_fn=rnn_cell(torch.relu)), + ) + return out, torch.stack(final_hiddens, 0) + + +@register_decomposition(aten.rnn_relu.data) +@aten.rnn_relu.data.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.rnn_relu.data.py_impl(DispatchKey.Autograd) +def rnn_relu_data( + data, + batch_sizes, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, +): + hidden = hx.unbind(0) + params = gather_params(params, has_biases, False) + out, final_hiddens = _rnn_helper( + data, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + False, + partial( + one_layer_rnn_data, + batch_sizes=batch_sizes, + hidden_fn=rnn_cell_data(torch.relu), + ), + ) + return out, torch.stack(final_hiddens, 0) + + +@register_decomposition(aten.rnn_tanh.data) +@aten.rnn_tanh.data.py_impl(DispatchKey.CompositeImplicitAutograd) 
+@aten.rnn_tanh.data.py_impl(DispatchKey.Autograd) +def rnn_tanh_data( + data, + batch_sizes, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, +): + hidden = hx.unbind(0) + params = gather_params(params, has_biases, False) + out, final_hiddens = _rnn_helper( + data, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + False, + partial( + one_layer_rnn_data, + batch_sizes=batch_sizes, + hidden_fn=rnn_cell_data(torch.tanh), + ), + ) + return out, torch.stack(final_hiddens, 0) + + +def lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim): + gates = F.linear(hx, hh_weight, hh_bias) + inp + chunked_gates = gates.chunk(4, chunk_dim) + in_gate = chunked_gates[0].sigmoid() + forget_gate = chunked_gates[1].sigmoid() + cell_gate = chunked_gates[2].tanh() + out_gate = chunked_gates[3].sigmoid() + cy = forget_gate * cx + (in_gate * cell_gate) + hy = out_gate * cy.tanh() + hy = hy if hr_weight is None else F.linear(hy, hr_weight, None) + + return hy, cy + + +def one_layer_lstm(inp, hidden, params, has_biases, reverse=False): + ih_weight = params[0] + hh_weight = params[1] + ih_bias = params[2] if has_biases else None + hh_bias = params[3] if has_biases else None + hr_weight = ( + params[4] if len(params) == 5 else params[2] if len(params) == 3 else None + ) + + hx = hidden[0].unsqueeze(0) + cx = hidden[1].unsqueeze(0) + + precomputed_input = F.linear(inp, ih_weight, ih_bias) + precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input + step_output = [] + for inp in precomputed_input: + hx, cx = lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim=2) + step_output.append(hx) + + if reverse: + step_output.reverse() + + out = torch.cat(step_output, 0) + + return out, (hx.squeeze(1), cx.squeeze(1)) + + +def one_layer_lstm_data(inp, hidden, params, has_biases, batch_sizes, reverse=False): + ih_weight = params[0] + hh_weight = params[1] + ih_bias = params[2] if has_biases else None + hh_bias = params[3] if has_biases else None + hr_weight = ( + params[4] if len(params) == 5 else params[2] if len(params) == 3 else None + ) + + step_output = [] + hiddens = [] + + last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0] + split_inp = torch.split(inp, list(batch_sizes)) + if reverse: + split_inp = split_inp[::-1] + + orig_hx = hidden[0] + orig_cx = hidden[1] + hx, cx = orig_hx.narrow(0, 0, last_batch_size), orig_cx.narrow( + 0, 0, last_batch_size + ) + + for inp in split_inp: + i = inp.shape[0] + inp = F.linear(inp, ih_weight, ih_bias) + + # this will only happen when reverse=False, since batch sizes are sorted largest -> smallest + if i < last_batch_size: + hiddens.append( + ( + hx.narrow(0, i, last_batch_size - i), + cx.narrow(0, i, last_batch_size - i), + ) + ) + hx, cx = hx.narrow(0, 0, i), cx.narrow(0, 0, i) + + # this will only happen when reverse=True + if i > last_batch_size: + hx = torch.concat( + (hx, orig_hx.narrow(0, last_batch_size, i - last_batch_size)), 0 + ) + cx = torch.concat( + (cx, orig_cx.narrow(0, last_batch_size, i - last_batch_size)), 0 + ) + + hx, cx = lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim=1) + last_batch_size = i + step_output.append(hx) + + if reverse: + step_output.reverse() + hidden_out = (hx, cx) + else: + hiddens.append((hx, cx)) + hiddens.reverse() + hidden0, hidden1 = zip(*hiddens) + hidden_out = torch.cat(hidden0, 0), torch.cat(hidden1, 0) + + out = torch.cat(step_output, 0) + return out, hidden_out + + +def 
select_one_layer_lstm_function(input, hx, params): + r"""Check whether we could use decompose lstm with mkldnn_rnn_layer. + All the below conditions need to be met: + * ``torch._C._get_mkldnn_enabled()`` returns ``True``. + * All the input args are on CPU. + * The dtypes of args are either torch.float or torch.bfloat16. + * Inference. + * ``has_projections`` returns ``False``. + + Args: + * input: the input sequence to LSTM + * hx: a tuple of the input hidden state and cell state ``(h_0, c_0)`` to LSTM + * params: the weight and bias tensors of LSTM + """ + + def use_mkldnn(input, hx, params): + if not torch._C._get_mkldnn_enabled(): + return False + + tensors = [input] + list(hx) + list(chain.from_iterable(params)) + devices = {t.device for t in tensors} + if len(devices) != 1: + return False + + device = devices.pop() + if device != torch.device("cpu"): + return False + # With autocast, possible to have mixed dtype here + dtypes = {t.dtype for t in tensors} + for dtype in dtypes: + if dtype not in [torch.float, torch.bfloat16]: + return False + + if input.requires_grad: + return False + + has_projections = hx[0].size(2) != hx[1].size(2) + if has_projections: + return False + + return True + + # mkldnn_one_layer_lstm does not depend on seq_len while one_layer_lstm + # will expand over the seq_len dim + if use_mkldnn(input, hx, params): + return mkldnn_one_layer_lstm + else: + return one_layer_lstm + + +@register_decomposition(aten.lstm.input) +@aten.lstm.input.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.lstm.input.py_impl(DispatchKey.Autograd) +def lstm_impl( + input, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, +): + assert len(hx) == 2, "lstm expects two hidden states" + params = gather_params(params, has_biases, hx[0].size(2) != hx[1].size(2)) + hidden = list(zip(hx[0], hx[1])) + layer_fn = select_one_layer_lstm_function(input, hx, params) + out, final_hiddens = _rnn_helper( + input, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + layer_fn, + ) + final_hiddens = list(zip(*final_hiddens)) + return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0) + + +@register_decomposition(aten.lstm.data) +@aten.lstm.data.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.lstm.data.py_impl(DispatchKey.Autograd) +def lstm_data_impl( + data, + batch_sizes, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, +): + assert len(hx) == 2, "lstm expects two hidden states" + params = gather_params(params, has_biases, hx[0].size(2) != hx[1].size(2)) + hidden = list(zip(hx[0], hx[1])) + out, final_hiddens = _rnn_helper( + data, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + False, + partial(one_layer_lstm_data, batch_sizes=batch_sizes), + ) + final_hiddens = list(zip(*final_hiddens)) + return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0) + + +def gru_cell(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias): + chunked_igates = inp.chunk(3, 1) + chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 2) + reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid() + input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid() + new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh() + return (cur_hidden - new_gate) * input_gate + new_gate + + +def gru_cell_data(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias): + chunked_igates = 
F.linear(inp, ih_weight, ih_bias).chunk(3, 1) + chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 1) + reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid() + input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid() + new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh() + return (cur_hidden - new_gate) * input_gate + new_gate + + +@register_decomposition(aten.gru.data) +@aten.gru.data.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.gru.data.py_impl(DispatchKey.Autograd) +def gru_impl_data( + data, + batch_sizes, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, +): + params = gather_params(params, has_biases, False) + out, final_hiddens = _rnn_helper( + data, + hx.unbind(0), + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + False, + partial(one_layer_rnn_data, batch_sizes=batch_sizes, hidden_fn=gru_cell_data), + ) + return out, torch.stack(final_hiddens, 0) + + +@register_decomposition(aten.gru.input) +@aten.gru.input.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.gru.input.py_impl(DispatchKey.Autograd) +def gru_impl( + input, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, +): + params = gather_params(params, has_biases, False) + out, final_hiddens = _rnn_helper( + input, + hx.unbind(0), + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + partial(one_layer_rnn, hidden_fn=gru_cell), + ) + return out, torch.stack(final_hiddens, 0) + + +@register_decomposition(aten._upsample_bilinear2d_aa.vec) +@aten._upsample_bilinear2d_aa.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_bilinear2d_aa.vec.py_impl(DispatchKey.Autograd) +def upsample_bilinear2d_aa_vec(input, output_size, align_corners, scale_factors): + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scale_h = get_scale_value(scale_factors, 0) + scale_w = get_scale_value(scale_factors, 1) + return torch.ops.aten._upsample_bilinear2d_aa( + input, osize, align_corners, scale_h, scale_w + ) + + +@register_decomposition(aten._upsample_bicubic2d_aa.vec) +@aten._upsample_bicubic2d_aa.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_bicubic2d_aa.vec.py_impl(DispatchKey.Autograd) +def upsample_bicubic2d_aa_vec(input, output_size, align_corners, scale_factors): + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scale_h = get_scale_value(scale_factors, 0) + scale_w = get_scale_value(scale_factors, 1) + return torch.ops.aten._upsample_bicubic2d_aa( + input, osize, align_corners, scale_h, scale_w + ) + + +@register_decomposition(aten.upsample_bilinear2d.vec) +@register_decomposition(aten.upsample_trilinear3d.vec) +@aten.upsample_linear1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_linear1d.vec.py_impl(DispatchKey.Autograd) +@aten.upsample_bilinear2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_bilinear2d.vec.py_impl(DispatchKey.Autograd) +@aten.upsample_trilinear3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_trilinear3d.vec.py_impl(DispatchKey.Autograd) +def _upsample_linear_vec(input, output_size, align_corners, scale_factors): + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scales = scale_factors if scale_factors else [None] * len(osize) + return _upsample_linear(input, osize, align_corners, scales) + + 
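# Note on the GRU update used by gru_cell / gru_cell_data above: the returned
# expression (cur_hidden - new_gate) * input_gate + new_gate is algebraically the
# textbook GRU update (1 - z) * n + z * h, with z = input_gate, n = new_gate and
# h = cur_hidden. A minimal sanity sketch (illustrative only, not part of the
# decomposition):
#
#   h, z, n = torch.rand(3), torch.rand(3), torch.rand(3)
#   torch.allclose((h - n) * z + n, (1 - z) * n + z * h)  # True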
+@register_decomposition([aten.upsample_linear1d.default, aten.upsample_linear1d.out]) +@out_wrapper() +def upsample_linear1d( + input: Tensor, + output_size: List[int], + align_corners: bool, + scales_w: Optional[float] = None, +) -> Tensor: + return _upsample_linear(input, output_size, align_corners, [scales_w]) + + +@register_decomposition( + [aten.upsample_bilinear2d.default, aten.upsample_bilinear2d.out] +) +@aten.upsample_bilinear2d.default.py_impl(DispatchKey.Autograd) +@out_wrapper() +def upsample_bilinear2d( + input: Tensor, + output_size: List[int], + align_corners: bool, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> Tensor: + return _upsample_linear(input, output_size, align_corners, [scales_h, scales_w]) + + +@register_decomposition( + [aten.upsample_trilinear3d.default, aten.upsample_trilinear3d.out] +) +@out_wrapper() +def upsample_trilinear3d( + input: Tensor, + output_size: List[int], + align_corners: bool, + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> Tensor: + return _upsample_linear( + input, output_size, align_corners, [scales_d, scales_h, scales_w] + ) + + +def _compute_scale(in_size, out_size, align_corners, scale=None): + if align_corners: + return (in_size - 1.0) / (out_size - 1.0) if out_size > 1 else 0 + else: + return 1.0 / scale if scale is not None and scale > 0 else in_size / out_size + + +def _compute_source_index(scale, dst_index, align_corners): + if align_corners: + return scale * dst_index + else: + return scale * (dst_index + 0.5) - 0.5 + + +def _sum_tensors_uint8( + src: Iterable[Tensor], weights: Iterable[Tensor], weights_precision: Tensor +) -> Tensor: + output = _sum_tensors( + s.to(torch.int32) * c.to(torch.int32) for s, c in zip(src, weights) + ) + (1 << (weights_precision - 1)) + output = output >> weights_precision + return torch.clamp(output, 0, 255).to(torch.uint8) + + +def _compute_weight_precision(weights: TensorSequenceType) -> Tensor: + max_weight = torch.stack(weights).max() + max_weight_precision = 22 + precisions = torch.arange(max_weight_precision, device=max_weight.device) + values = 0.5 + max_weight * (1 << (precisions + 1)) + mask = values >= (1 << 15) + return max_weight_precision - mask.sum() + + +@pw_cast_for_opmath +def _upsample_linear( + input: Tensor, + output_size: List[int], + align_corners: bool, + scales: List[Optional[float]], +) -> Tensor: + # get dimensions of original image + n_batch, n_channels = input.shape[:2] + inp_sizes = input.shape[2:] + n_dims = len(inp_sizes) + + _, dtype = utils.elementwise_dtypes( + input, + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ) + + def get_values(inp_size, out_size, scales, nsqueeze): + # First Calculate scaling factor + scale_factor = _compute_scale(inp_size, out_size, align_corners, scales) + # We have to create arange with int64 dtype and use .to in order to avoid + # additional kernels creation in inductor and get a perf slowdown + i = torch.arange(out_size, device=input.device).to(dtype=dtype) + + x_f32 = _compute_source_index(scale_factor, i, align_corners).clamp(min=0.0) + x_f32 = x_f32.reshape(x_f32.shape[0], *[1] * (nsqueeze)) + x = x_f32.to(torch.int64) + xp1 = (x + 1).clamp(max=inp_size - 1) + return x_f32, x, xp1 + + values = [ + get_values(inp_size, out_size, scales, n_dims - 1 - i) + for i, (inp_size, out_size, scales) in enumerate( + zip(inp_sizes, output_size, scales) + ) + ] + xs_f32, xs, xp1s = list(zip(*values)) + + vs = [] + for a in 
product(*[[0, 1]] * n_dims): + idx = [None, None] + [xs[k] if a[k] == 0 else xp1s[k] for k in range(n_dims)] + v = aten._unsafe_index(input, idx) + v = _maybe_convert_to_dtype(v, dtype) + vs.append(v) + + for i in reversed(range(n_dims)): + xscale = (xs_f32[i] - xs[i]).clamp(0.0, 1.0).to(dtype) + vs = [ + # x1 * (1 - alpha) + x2 * alpha == x1 + (x2 - x1) * alpha + v1 + torch.mul(v2 - v1, xscale) + for v1, v2 in zip(vs[::2], vs[1::2]) + ] + + assert len(vs) == 1 + result = vs[0] + + # convert output to correct memory format, if necessary + memory_format = utils.suggest_memory_format(input) + + # following "heuristic: only use channels_last path when it's faster than the contiguous path" + if input.device.type == "cuda" and n_channels < 16: + memory_format = torch.contiguous_format + + assert isinstance(result, torch.Tensor) + + result = result.contiguous(memory_format=memory_format) + + if not input.is_floating_point(): + result = result.round() + + return result + + +# We should be applying decompositions after all transformations +@register_decomposition(aten.is_same_size.default) +def is_same_size(a: Tensor, b: Tensor) -> bool: + return a.shape == b.shape + + +@register_decomposition([aten._reshape_alias, aten._unsafe_view]) +@out_wrapper() +def _reshape_alias(x, shape, *args): + return aten.view(x, shape) + + +@register_decomposition([aten._unsafe_index]) +def _unsafe_index(x, indices): + return aten.index(x, indices) + + +@register_decomposition([aten._unsafe_index_put]) +def _unsafe_index_put(x, indices, value, accumulate=False): + return aten.index_put(x, indices, value, accumulate) + + +@register_decomposition([aten._unsafe_masked_index]) +def _unsafe_masked_index(x, mask, indices, fill): + for index in indices: + if index is not None: + torch._check( + index.dtype in [torch.long, torch.int], + lambda: "tensors used as indices must be long or int tensors", + ) + + torch._check( + mask.dtype == torch.bool, + lambda: "tensors used as masks must be bool tensors", + ) + + if x.numel() == 0: + meta_result = torch._meta_registrations.meta_index_Tensor(x, indices) + return x.new_full(meta_result.shape, fill) + + for i in range(len(indices)): + index = indices[i] + if index is not None: + indices[i] = index.clamp(min=0, max=x.size(i) - 1) + + return aten._unsafe_index(x, indices).masked_fill(~mask, fill) + + +@register_decomposition([aten._unsafe_masked_index_put_accumulate]) +def _unsafe_masked_index_put_accumulate(x, mask, indices, values): + for index in indices: + if index is not None: + torch._check( + index.dtype in [torch.long, torch.int], + lambda: "tensors used as indices must be long or int tensors", + ) + + torch._check( + mask.dtype == torch.bool, + lambda: "tensors used as masks must be bool tensors", + ) + + if x.numel() == 0: + return x.clone() + + for i in range(len(indices)): + index = indices[i] + if index is not None: + indices[i] = index.clamp(min=-x.size(i), max=x.size(i) - 1) + + masked_value = values.masked_fill(~mask, 0) + return aten._unsafe_index_put(x, indices, masked_value, accumulate=True) + + +def _nll_loss_forward( + self: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, +) -> Tuple[Tensor, Tensor]: + # self can be [N, C] or [C] + # target can be [N] or [] + + n_dims = self.dim() + channel_dim = 1 + if n_dims < 2: + channel_dim = 0 + + if weight is not None: + if n_dims > 1: + shape = [ + 1, + ] * n_dims + shape[channel_dim] = weight.shape[0] + w = weight.view(shape) + else: + w = weight + self = self * w + 
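    # The gather below computes the per-element loss -x[i, target[i]] along the
    # channel dim. Targets equal to ignore_index are first remapped to class 0 so
    # the gather stays in bounds ("safe_target"), and their contributions are
    # zeroed out again afterwards; class weights, if given, were already folded
    # into `self` above.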
safe_target = torch.where(target != ignore_index, target, 0) + safe_target_ = safe_target.unsqueeze(channel_dim) + # target can be [N, 1] or [1] + + result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim) + + result = torch.where(target != ignore_index, result, 0) + + if reduction == Reduction.NONE.value and n_dims > 1: + total_weight = self.new_full((), 0.0) + return result, total_weight + + if weight is not None: + w = w.expand(self.shape) + wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim) + wsum = torch.where(target != ignore_index, wsum, 0) + total_weight = wsum.sum() + else: + total_weight = (target != ignore_index).sum().to(self) + + if reduction == Reduction.SUM.value: + result = result.sum() + elif reduction == Reduction.MEAN.value: + result = result.sum() / total_weight + + return result, total_weight + + +@register_decomposition(aten.nll_loss_forward) +@out_wrapper("output", "total_weight") +def nll_loss_forward( + self: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, +) -> Tuple[Tensor, Tensor]: + assert self.dim() > 0 and self.dim() <= 2, "input tensor should be 1D or 2D" + assert ( + target.dim() <= 1 + ), "0D or 1D target tensor expected, multi-target not supported" + + no_batch_dim = self.dim() == 1 and target.dim() == 0 + assert no_batch_dim or ( + self.shape[0] == target.shape[0] + ), f"size mismatch (got input: {self.shape}, target: {target.shape})" + + n_classes = self.shape[-1] + + assert weight is None or ( + weight.dim() == 1 and weight.numel() == n_classes + ), f"weight tensor should be defined either for all {n_classes} classes or no classes but got weight tensor of shape: {weight.shape}" # noqa: B950 + + return _nll_loss_forward(self, target, weight, reduction, ignore_index) + + +@register_decomposition(aten.nll_loss2d_forward) +@out_wrapper("output", "total_weight") +def nll_loss2d_forward( + self: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, +) -> Tuple[Tensor, Tensor]: + return _nll_loss_forward(self, target, weight, reduction, ignore_index) + + +# These are adapted from aten/src/ATen/native/UpSample.h, wich is based on +# https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm +def _upsample_cubic_convolution1(x: Tensor, A: float) -> Tensor: + return ((A + 2) * x - (A + 3)) * x * x + 1 + + +def _upsample_cubic_convolution2(x: Tensor, A: float) -> Tensor: + return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A + + +def _upsample_get_cubic_coefficients(t: Tensor) -> TensorSequenceType: + A = -0.75 + + if t.device == torch.device("cpu"): + tt1 = torch.stack([t, 1.0 - t], dim=0) + tt2 = torch.stack([t + 1.0, 2.0 - t], dim=0) + w03 = _upsample_cubic_convolution2(tt2, A) + w12 = _upsample_cubic_convolution1(tt1, A) + w0, w3 = torch.unbind(w03, dim=0) + w1, w2 = torch.unbind(w12, dim=0) + return w0, w1, w2, w3 + else: + return ( + _upsample_cubic_convolution2(t + 1.0, A), + _upsample_cubic_convolution1(t, A), + _upsample_cubic_convolution1(1.0 - t, A), + _upsample_cubic_convolution2(2.0 - t, A), + ) + + +def _upsample_cubic_interp1d(coeffs: TensorSequenceType, ts: Tensor) -> Tensor: + coeffs2 = _upsample_get_cubic_coefficients(ts) + return _sum_tensors(c1 * c2 for (c1, c2) in zip(coeffs, coeffs2)) + + +# Need this instead of just sum() to keep mypy happy +def _sum_tensors(ts: Iterable[Tensor]) -> Tensor: + return reduce(torch.add, ts) + + +def _linspace_from_neg_one( + num_steps: int, align_corners: bool, dtype: 
torch.dtype, device: torch.device +): + if num_steps <= 1: + return torch.tensor(0, device=device, dtype=dtype) + + a = ((num_steps - 1) / num_steps) if not align_corners else 1 + return torch.linspace(-a, a, steps=num_steps, device=device, dtype=dtype) + + +def _make_base_grid_4d(theta: Tensor, h: int, w: int, align_corners: bool): + dtype = theta.dtype + device = theta.device + + # Using padding and summation generates a single kernel vs using torch.stack where 3 kernels generated + # corresponding to each individual tensor: grid_x, grid_y, grid_one + grid_x = _linspace_from_neg_one(w, align_corners, dtype, device).view(1, w, 1) + grid_y = _linspace_from_neg_one(h, align_corners, dtype, device).view(h, 1, 1) + grid_one = torch.ones((1, 1, 1), dtype=dtype, device=device) + + # this is just a temporary hack and we should use torch.stack here once #104480 is merged + grid_x = torch.nn.functional.pad(grid_x, pad=(0, 2), mode="constant", value=0) + grid_y = torch.nn.functional.pad(grid_y, pad=(1, 1), mode="constant", value=0) + grid_one = torch.nn.functional.pad(grid_one, pad=(2, 0), mode="constant", value=0) + return grid_x + grid_y + grid_one + + +def _make_base_grid_5d(theta: Tensor, d: int, h: int, w: int, align_corners: bool): + dtype = theta.dtype + device = theta.device + + grid_x = _linspace_from_neg_one(w, align_corners, dtype, device).view(1, 1, w, 1) + grid_y = _linspace_from_neg_one(h, align_corners, dtype, device).view(1, h, 1, 1) + grid_z = _linspace_from_neg_one(d, align_corners, dtype, device).view(d, 1, 1, 1) + grid_one = torch.ones((1, 1, 1, 1), dtype=dtype, device=device) + + # this is just a temporary hack and we should use torch.stack here once #104480 is merged + grid_x = torch.nn.functional.pad(grid_x, pad=(0, 3), mode="constant", value=0) + grid_y = torch.nn.functional.pad(grid_y, pad=(1, 2), mode="constant", value=0) + grid_z = torch.nn.functional.pad(grid_z, pad=(2, 1), mode="constant", value=0) + grid_one = torch.nn.functional.pad(grid_one, pad=(3, 0), mode="constant", value=0) + return grid_x + grid_y + grid_z + grid_one + + +def _affine_grid_generator_4d(theta: Tensor, size: List[int], align_corners: bool): + n, _, h, w = size + base_grid = _make_base_grid_4d(theta, h, w, align_corners=align_corners) + # base_grid shape is (h, w, 3) and theta shape is (n, 2, 3) + # We do manually a matrix multiplication which is faster than mm() + # (h * w, 3, 1) * (n, 1, 3, 2) -> (n, h * w, 2) + grid = (base_grid.view(-1, 3, 1) * theta.mT.unsqueeze(1)).sum(-2) + return grid.view(n, h, w, 2) + + +def _affine_grid_generator_5d(theta: Tensor, size: List[int], align_corners: bool): + n, _, d, h, w = size + base_grid = _make_base_grid_5d(theta, d, h, w, align_corners=align_corners) + # base_grid shape is (d, h, w, 4) and theta shape is (n, 3, 4) + # We do manually a matrix multiplication which is faster than mm() + # (d * h * w, 4, 1) * (n, 1, 4, 3) -> (n, h * w, 3) + grid = (base_grid.view(-1, 4, 1) * theta.mT.unsqueeze(1)).sum(-2) + return grid.view(n, d, h, w, 3) + + +@register_decomposition(aten.affine_grid_generator) +@out_wrapper() +@pw_cast_for_opmath +def affine_grid_generator(theta: Tensor, size: List[int], align_corners: bool): + torch._check( + len(size) in (4, 5), + lambda: "affine_grid_generator needs 4d (spatial) or 5d (volumetric) inputs.", + ) + if len(size) == 4: + return _affine_grid_generator_4d(theta, size, align_corners=align_corners) + else: + return _affine_grid_generator_5d(theta, size, align_corners=align_corners) + + +def _grid_sampler_2d( + a: Tensor, + 
grid: Tensor, + interpolation_mode: int = 0, + padding_mode: int = 0, + align_corners: bool = False, + _expand_grid: bool = True, +) -> Tensor: + # This method is a copy of grid_sampler_2d implementation and introduced with additional arg _expand_grid to + # optionally expand the input grid for performance reasons. + # Experimenting locally it was found that compiled CUDA code is accelerated by ~5x + # and CPU code by ~2x on bicubic mode, if we expand the grid from (N, H, W, 2) into (N, C, H, W, 2) + # However, this leads to a slowdown around ~0.8x on CPU bilinear mode, channels first. + # Thus we apply this hack to not expand the grid for this case. + + torch._check( + interpolation_mode in (0, 1, 2), + lambda: f"Invalid interpolation mode {interpolation_mode}", + ) + torch._check( + padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}" + ) + + def unnormalize(coords: Tensor, size: int) -> Tensor: + # Rescale coordinates from [-1, 1] to: + # [0, size - 1] if align_corners is True + # [-.5, size -.5] if align_corners is False + mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5) + ofs = size * 0.5 - 0.5 + return coords * mul + ofs + + # Reflects coordinates until they fall between low and high (inclusive). + # The bounds are passed as twice their value so that half-integer values + # can be represented as ints. + def reflect_coordinates(coords: Tensor, twice_low: int, twice_high: int) -> Tensor: + if twice_low == twice_high: + return torch.zeros_like(coords) + coords_min = twice_low / 2 + coords_span = (twice_high - twice_low) / 2 + coords2 = (coords - coords_min).abs() + extra = torch.fmod(coords2, coords_span) + flips = (coords2 / coords_span).floor().to(dtype=torch.int8) + return torch.where( + flips & 1 == 0, extra + coords_min, coords_span + coords_min - extra + ) + + def compute_coordinates(coords: Tensor, size: int) -> Tensor: + if padding_mode == 0: # Zero + return coords + elif padding_mode == 1: # Borders + return torch.clamp(coords, 0, size - 1) + else: # padding_mode == 2, Reflection + if align_corners: + coords_reflected = reflect_coordinates(coords, 0, 2 * (size - 1)) + else: + coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1) + return torch.clamp(coords_reflected, 0, size - 1) + + def compute_source_index(coords: Tensor, size: int) -> Tensor: + coords_un = unnormalize(coords, size) + return compute_coordinates(coords_un, size) + + N, C, iH, iW = a.shape + _, oH, oW, two = grid.shape + assert two == 2 + + if _expand_grid: + # Let's expand grid to [N, C, oH, oW, 2] + # This allows to generate a single triton cuda kernel instead of two kernels. 
+ # Two kernels are due source indices, weights have shape (N, 1, oH, oW), xnumel=N*oH*oW + # and output has shape (N, C, oH, oW), xnumel=N*C*oH*oW + # Expanding grid to (N, C, oH, oW, two) unifies xnumel to N*C*oH*oW + grid = grid.view(N, 1, oH, oW, two).expand(N, C, oH, oW, 2) + + def in_bounds_cond(xs: Tensor, ys: Tensor) -> Tensor: + return torch.logical_and( + 0 <= xs, torch.logical_and(xs < iW, torch.logical_and(0 <= ys, ys < iH)) + ) + + N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1) + C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1) + + def clip(xs: Tensor, ys: Tensor, ws: Tensor) -> TensorSequenceType: + cond = in_bounds_cond(xs, ys) + # To clip to inside valid coordinates, we map the coordinates + # to (x, y) = (0, 0) and also set the weight to 0 + # We also change the shape of the tensor to the appropriate one for + # broadcasting with N_idx, C_idx for the purposes of advanced indexing + c = C if _expand_grid else 1 + return tuple( + torch.where(cond, t, 0).view(N, c, oH, oW) + for t in (xs.to(dtype=torch.int64), ys.to(dtype=torch.int64), ws) + ) + + def get_summand(ix: Tensor, iy: Tensor, w) -> Tensor: + # Perform clipping, index into input tensor and multiply by weight + idx_x, idx_y, w_ = clip(ix, iy, w) + return a[N_idx, C_idx, idx_y, idx_x] * w_ + + x = grid[..., 0] + y = grid[..., 1] + + if interpolation_mode == 0: # Bilinear + ix = compute_source_index(x, iW) + iy = compute_source_index(y, iH) + + ix_nw, iy_nw = ix.floor(), iy.floor() + ix_ne, iy_ne = ix_nw + 1, iy_nw + ix_sw, iy_sw = ix_nw, iy_nw + 1 + ix_se, iy_se = ix_ne, iy_sw + + w_nw = (ix_se - ix) * (iy_se - iy) + w_ne = (ix - ix_sw) * (iy_sw - iy) + w_sw = (ix_ne - ix) * (iy - iy_ne) + w_se = (ix - ix_nw) * (iy - iy_nw) + + return _sum_tensors( + get_summand(ix, iy, w) + for (ix, iy, w) in ( + (ix_nw, iy_nw, w_nw), + (ix_ne, iy_ne, w_ne), + (ix_sw, iy_sw, w_sw), + (ix_se, iy_se, w_se), + ) + ) + elif interpolation_mode == 1: # Nearest + ix = compute_source_index(x, iW) + iy = compute_source_index(y, iH) + + ix_nearest = ix.round() + iy_nearest = iy.round() + + return get_summand(ix_nearest, iy_nearest, 1) + else: # interpolation_mode == 2, Bicubic + ix = unnormalize(x, iW) + iy = unnormalize(y, iH) + + ix_nw = ix.floor() + iy_nw = iy.floor() + + tx = ix - ix_nw + ty = iy - iy_nw + + if not _expand_grid: + tx = tx.unsqueeze(1) + ty = ty.unsqueeze(1) + + def get_value_bounded(ix: Tensor, iy: Tensor) -> Tensor: + x = compute_coordinates(ix, iW) + y = compute_coordinates(iy, iH) + return get_summand(x, y, 1) + + def get_coeff(ofs: int) -> Tensor: + iy_ofs = iy_nw + (ofs - 1) + cs = ( + get_value_bounded(ix_nw - 1, iy_ofs), + get_value_bounded(ix_nw, iy_ofs), + get_value_bounded(ix_nw + 1, iy_ofs), + get_value_bounded(ix_nw + 2, iy_ofs), + ) + return _upsample_cubic_interp1d(cs, tx) + + coeffs = tuple(get_coeff(ofs) for ofs in range(4)) + return _upsample_cubic_interp1d(coeffs, ty) + + +@register_decomposition(aten.grid_sampler_2d) +@out_wrapper() +@pw_cast_for_opmath +def grid_sampler_2d( + a: Tensor, + grid: Tensor, + interpolation_mode: int = 0, + padding_mode: int = 0, + align_corners: bool = False, +) -> Tensor: + return _grid_sampler_2d( + a, + grid=grid, + interpolation_mode=interpolation_mode, + padding_mode=padding_mode, + align_corners=align_corners, + ) + + +@register_decomposition(aten.mv) +@out_wrapper() +@pw_cast_for_opmath +def mv(self, vec): + torch._check( + self.dim() == 2 and vec.dim() == 1, + lambda: f"matrix @ vector expected, got {self.dim()}, {vec.dim()}", + ) + torch._check( 
+ self.size(1) == vec.size(0), + lambda: f"size mismatch, got input ({self.size(0)}x{self.size(1)}), vec ({vec.size(0)})", + ) + return (self * vec).sum(dim=1) + + +@register_decomposition(aten.binary_cross_entropy_with_logits) +@out_wrapper() +def binary_cross_entropy_with_logits( + self, target, weight=None, pos_weight=None, reduction=Reduction.MEAN.value +): + if pos_weight is not None: + log_weight = (pos_weight - 1) * target + 1 + loss = (1 - target) * self - (log_weight * F.logsigmoid(self)) + else: + loss = (1 - target) * self - F.logsigmoid(self) + + if weight is not None: + loss = loss * weight + + return apply_loss_reduction(loss, reduction) + + +def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> bool: + # For comments of the logic of this function see eager in /native/LinearAlgebra.cpp + + t1, t2 = (tensor1, tensor2) if tensor1.ndim >= tensor2.ndim else (tensor2, tensor1) + + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if not (t1.ndim >= 3 and t2.ndim <= 2): + return False + if t2.requires_grad and not is_out: + return True + if tensor1.ndim == 2: + return False + if guard_size_oblivious(t1.numel() == 0): + return True + + t1_shape = t1.shape + t1_stride = t1.stride() + return all( + st1 == st2 * s2 + for (st1, st2, s2) in zip(t1_stride[:-2], t1_stride[1:-1], t1_shape[1:-1]) + ) + + +@aten.matmul.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.matmul.out.py_impl(DispatchKey.CompositeImplicitAutograd) +@out_wrapper(pass_is_out=True) +def matmul(tensor1, tensor2, *, is_out=False): + dim_tensor1 = tensor1.dim() + dim_tensor2 = tensor2.dim() + assert dim_tensor1 != 0 and dim_tensor2 != 0 + if dim_tensor1 == 1 and dim_tensor2 == 1: + return torch.dot(tensor1, tensor2) + elif dim_tensor1 == 2 and dim_tensor2 == 1: + return torch.mv(tensor1, tensor2) + elif dim_tensor1 == 1 and dim_tensor2 == 2: + return torch.squeeze(torch.mm(torch.unsqueeze(tensor1, 0), tensor2), 0) + elif dim_tensor1 == 2 and dim_tensor2 == 2: + return torch.mm(tensor1, tensor2) + elif should_fold(tensor1, tensor2, is_out): + # dim_tensor1 >=3 && (dim_tensor2 == 1 || dim_tensor2 == 2) || + # dim_tensor2 >=3 && (dim_tensor1 == 1 || dim_tensor1 == 2) + # and some condition on the strides is fulfilled + + # optimization: use mm instead of bmm by folding the batch of the larger tensor + # into its leading matrix dimension + transpose = dim_tensor2 > dim_tensor1 + t1 = tensor2.mT if transpose else tensor1 + t2 = ( + tensor2 if not transpose else (tensor1.t() if dim_tensor1 == 2 else tensor1) + ) + # Invariant: t1.dim() >= 3 && (t2.dim() == 1 || t2.dim() == 2) + # and t1 and t2 are matmul-compatible + + # Why not t1.view(-1, sizes_1[-1])? + # If the last dim is 0, then view(-1, 0) won't work because the -1 becomes ambiguous. + # This can happen in e.g. [3, 5, 0] @ [0, 0]. + sizes_1 = t1.shape + output_shape = list(sizes_1[:-1]) + folded_dim1 = reduce(operator.mul, output_shape) + + # Readjust output_shape if we are multiplying by a matrix + t2_is_matrix = t2.dim() == 2 + if t2_is_matrix: + output_shape.append(t2.shape[1]) + + # This will almost always be a view. + # It may not be a view if t2->requires_grad(). See should_fold in aten/ for an explanation + t1_folded = t1.reshape(folded_dim1, sizes_1[-1]) + if t2_is_matrix: + # This copies if we perform a 2D @ 3D and the first tensor requires_grad + # See should_fold native/LinearAlgebra.cpp for why. 
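+            # (Editor's note, illustrative shape walk-through: folding t1 of
+            # shape (4, 3, 5) against a matrix t2 of shape (5, 2) gives
+            # t1_folded = t1.reshape(12, 5); t1_folded.mm(t2) has shape (12, 2);
+            # the view below restores output_shape = [4, 3, 2], i.e. one mm
+            # call instead of a bmm over the batch.)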
+ output = t1_folded.mm(t2).view(output_shape) + return output.mT.contiguous() if transpose else output + else: + return t1_folded.mv(t2).view(output_shape) + + elif dim_tensor1 >= 1 and dim_tensor2 >= 1: + # We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list); + # we track m1 vs m2 separately even though they must match for nicer error messages + n = tensor1.size(-2) if dim_tensor1 > 1 else 1 + m1 = tensor1.size(-1) + batch_tensor1 = tensor1.shape[:-2] + m2 = tensor2.size(-2) if dim_tensor2 > 1 else tensor2.size(-1) + p = tensor2.size(-1) if dim_tensor2 > 1 else 1 + + batch_tensor2: List[int] = [] + # TODO: handling of slice + for i in range(dim_tensor2 - 2): + batch_tensor2.append(tensor2.size(i)) + + # Same optimization for the gradients as that in should_fold + # If we're going to broadcast, we force it to go through the should_fold branch + if ( + dim_tensor1 == 3 + and dim_tensor2 == 3 + and batch_tensor1[0] != batch_tensor2[0] + ): + if batch_tensor1[0] == 1 and tensor1.requires_grad: + return matmul(tensor1.squeeze(0), tensor2) + if batch_tensor2[0] == 1 and tensor2.requires_grad: + return matmul(tensor1, tensor2.squeeze(0)) + + # expand the batch portion (i.e. cut off matrix dimensions and expand rest) + expand_batch_portion = list( + torch.broadcast_shapes(batch_tensor1, batch_tensor2) + ) + + tensor1_expand_size = expand_batch_portion + [n, m1] + + expand_batch_product = prod(expand_batch_portion) + + # HACK: We need reshape with symint support + tensor1_expanded = tensor1.expand(tensor1_expand_size).reshape( + expand_batch_product, n, m1 + ) + + vector_rhs = dim_tensor2 == 1 + if vector_rhs: + tensor2_expand_size = expand_batch_portion + [m2] + tensor2_expanded = ( + tensor2.expand(tensor2_expand_size) + .reshape(expand_batch_product, m2) + .unsqueeze(2) + ) + else: + tensor2_expand_size = expand_batch_portion + [m2, p] + tensor2_expanded = tensor2.expand(tensor2_expand_size).reshape( + expand_batch_product, m2, p + ) + + output_shape = expand_batch_portion + if dim_tensor1 > 1: + output_shape.append(n) + + if dim_tensor2 > 1: + output_shape.append(p) + + if vector_rhs: + return tensor1_expanded.bmm(tensor2_expanded).squeeze(-1).view(output_shape) + else: + return tensor1_expanded.bmm(tensor2_expanded).view(output_shape) + else: + torch._check(False, lambda: "both arguments to matmul need to be at least 1D") + + +@register_decomposition([aten.upsample_bicubic2d.default, aten.upsample_bicubic2d.out]) +@aten.upsample_bicubic2d.default.py_impl(DispatchKey.Autograd) +@out_wrapper() +@pw_cast_for_opmath +def upsample_bicubic2d_default( + input: Tensor, + output_size: Tuple[int, int], + align_corners: bool, + scale_h: Optional[float] = None, + scale_w: Optional[float] = None, +) -> Tensor: + # get dimensions of original image + _, _, in_h, in_w = input.shape + + # Calculate horizontal and vertical scaling factor + h_scale_factor = _compute_scale(in_h, output_size[0], align_corners, scale_h) + w_scale_factor = _compute_scale(in_w, output_size[1], align_corners, scale_w) + + _, dtype = utils.elementwise_dtypes( + input, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + + # We have to create arange with int64 dtype and use .to in order to avoid + # additional kernels creation in inductor and get a perf slowdown + i = torch.arange(output_size[0], device=input.device).to(dtype=dtype) + j = torch.arange(output_size[1], device=input.device).to(dtype=dtype) + + x_float = _compute_source_index(w_scale_factor, j, align_corners) + y_float = 
_compute_source_index(h_scale_factor, i, align_corners) + y_float = y_float.unsqueeze(-1) + + x = x_float.floor() + y = y_float.floor() + + # We should also clamp xscale/yscale + # See guard_index_and_lambda in UpSample.h + yscale = (y_float - y).clamp(0.0, 1.0) + xscale = (x_float - x).clamp(0.0, 1.0) + x = x.to(torch.int64) + y = y.to(torch.int64) + + iys_ofs = (y - 1, y, y + 1, y + 2) + ixs_ofs = (x - 1, x, x + 1, x + 2) + + weights_x = _upsample_get_cubic_coefficients(xscale) + weights_y = _upsample_get_cubic_coefficients(yscale) + + weights_precision_x, weights_precision_y = None, None + if input.dtype == torch.uint8: + weights_precision_x = _compute_weight_precision(weights_x) + weights_precision_y = _compute_weight_precision(weights_y) + + weights_x = [ + (w * (1 << weights_precision_x) + torch.sign(w) * 0.5).to(torch.int16) + for w in weights_x + ] + weights_y = [ + (w * (1 << weights_precision_y) + torch.sign(w) * 0.5).to(torch.int16) + for w in weights_y + ] + + def load_bounded(ys, xs): + y_idx = torch.clamp(ys, 0, in_h - 1) + x_idx = torch.clamp(xs, 0, in_w - 1) + v = aten._unsafe_index(input, [None, None, y_idx, x_idx]) + return v + + def get_x_interp(y): + src_x = tuple(load_bounded(y, x_ofs) for x_ofs in ixs_ofs) + if input.dtype == torch.uint8: + assert weights_precision_x is not None + return _sum_tensors_uint8(src_x, weights_x, weights_precision_x) + return _sum_tensors(c1 * c2 for (c1, c2) in zip(src_x, weights_x)) + + src_y = tuple(get_x_interp(y_ofs) for y_ofs in iys_ofs) + if input.dtype == torch.uint8: + assert weights_precision_y is not None + result = _sum_tensors_uint8(src_y, weights_y, weights_precision_y) + else: + result = _sum_tensors(c1 * c2 for (c1, c2) in zip(src_y, weights_y)) + + # convert output to correct memory format, if necessary + memory_format = utils.suggest_memory_format(input) + result = result.contiguous(memory_format=memory_format) + return result + + +@register_decomposition(aten.upsample_bicubic2d.vec) +@aten.upsample_bicubic2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_bicubic2d.vec.py_impl(DispatchKey.Autograd) +@out_wrapper() +@pw_cast_for_opmath +def upsample_bicubic2d_vec( + a: Tensor, + output_size: Optional[Tuple[int, int]], + align_corners: bool, + scale_factors: Optional[Tuple[float, float]] = None, +) -> Tensor: + torch._check( + bool(output_size) + bool(scale_factors) == 1, + lambda: "Must specify exactly one of output_size and scale_factors.", + ) + if output_size is None: + assert scale_factors is not None + output_size = cast( + Tuple[int, int], + tuple( + sym_int(sym_float(w) * scale) + for w, scale in zip(a.shape[2:], scale_factors) + ), + ) + scale_h, scale_w = scale_factors if scale_factors else (None, None) + return upsample_bicubic2d_default(a, output_size, align_corners, scale_h, scale_w) + + +@register_decomposition(aten.reflection_pad1d) +@register_decomposition(aten.reflection_pad2d) +@register_decomposition(aten.reflection_pad3d) +@pw_cast_for_opmath +@out_wrapper() +def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor: + def idx(left, middle, right): + dim_idx = torch.arange(-left, middle + right, device=a.device) + return middle - 1 - (middle - 1 - dim_idx.abs()).abs() + + return _reflection_or_replication_pad( + a, + padding, + idx, + ) + + +@register_decomposition(aten.replication_pad1d) +@register_decomposition(aten.replication_pad2d) +@register_decomposition(aten.replication_pad3d) +@pw_cast_for_opmath +@out_wrapper() +def _replication_pad(a: Tensor, padding: Tuple[int, 
...]) -> Tensor: + def idx(left, middle, right): + dim_idx = torch.arange(-left, middle + right, device=a.device) + return torch.clamp(dim_idx, 0, middle - 1) + + return _reflection_or_replication_pad( + a, + padding, + idx, + ) + + +def _reflection_or_replication_pad( + a: Tensor, + padding: Tuple[int, ...], + idx_fn: Callable[[int, int, int], Tensor], +) -> Tensor: + dim = len(padding) // 2 + torch._check( + a.dim() in (dim + 1, dim + 2), + lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input", + ) + inp_shape = a.shape[-dim:] + nc_dim = a.dim() - dim + + padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)] + padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)] + + result = a + for i in range(dim): + idx: List[Any] = [None] * result.dim() + idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i]) + result = aten._unsafe_index(result, idx) + + # convert output to correct memory format, if necessary + memory_format = utils.suggest_memory_format(result) + result = result.contiguous(memory_format=memory_format) + return result + + +@register_decomposition(aten.reflection_pad1d_backward) +@register_decomposition(aten.reflection_pad2d_backward) +@register_decomposition(aten.reflection_pad3d_backward) +@out_wrapper("grad_input") +def _reflection_pad_backward(grad_output, x, padding): + dim = len(padding) // 2 + + dhw = [h - 1 for h in x.shape[-dim:]] + + padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)] + padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)] + + indices = [] + for i in range(x.ndim): + view_shape = [1] * x.ndim + view_shape[i] = -1 + indices.append(torch.arange(x.shape[i], device=x.device).view(view_shape)) + + b = indices[:-dim] + xyz = indices[-dim:] + + def index_range_condition(index_range): + i, lb, ub = index_range + return torch.logical_and(i >= lb, i <= ub) + + # Areas after reflection: + # + # top-left | top | top-right + # ----------------------------------------- + # left | center | right + # ----------------------------------------- + # bottom-left | bottom | bottom-right + # + # The center area is the original matrix. Other areas are reflections. + + center = [xyz[i] + padding_left[i] for i in range(dim)] + left_reflect = [padding_left[i] - xyz[i] for i in range(dim)] + right_reflect = [2 * dhw[i] + padding_left[i] - xyz[i] for i in range(dim)] + + # Accumulate gradients from different areas + # If some of the padding is negative, center load is not always valid + range_c = [ + (center[i], 0, dhw[i] + padding_left[i] + padding_right[i]) for i in range(dim) + ] + cond = functools.reduce( + aten.logical_and, [index_range_condition(range_c[i]) for i in range(dim)] + ) + grad = aten._unsafe_masked_index(grad_output, cond, b + center, 0.0) + + def accumulate(grad, out, index_ranges): + # If the upper bound is less than the lower bound, we can get rid of one accumulation. + # This happens when the padding size is zero. + for i in range(dim): + upper_less_than_lower = index_ranges[i][2] < index_ranges[i][1] + if isinstance(upper_less_than_lower, bool) and upper_less_than_lower: + return grad + + cond = functools.reduce( + aten.logical_and, + [index_range_condition(index_range) for index_range in index_ranges], + ) + g = aten._unsafe_masked_index(grad_output, cond, b + out, 0.0) + return grad + g + + for area in itertools.product(*[[-1, 0, 1] for _ in range(dim)]): + if area == tuple([0] * dim): + # center, this is already done. 
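+            # (Editor's note, illustrative: for dim == 2 this loop enumerates
+            # the nine areas (-1, -1) ... (1, 1) of the diagram above; the
+            # (0, 0) center contribution was already accumulated into `grad`
+            # before the loop, so it is skipped here.)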
+ continue + + outs = [] + index_ranges = [] + + for i in range(dim): + if area[i] == 0: + out = center[i] + index_range = range_c[i] + elif area[i] == -1: + out = left_reflect[i] + index_range = (xyz[i], 1, padding_left[i]) + elif area[i] == 1: + out = right_reflect[i] + index_range = (xyz[i], dhw[i] - padding_right[i], dhw[i] - 1) + + outs.append(out) # type: ignore[possibly-undefined] + index_ranges.append(index_range) # type: ignore[possibly-undefined] + + grad = accumulate(grad, outs, index_ranges) + + return grad + + +@register_decomposition(aten.aminmax) +@out_wrapper("min", "max") +def aminmax(self, *, dim=None, keepdim=False): + amin = torch.amin(self, dim=dim, keepdim=keepdim) + amax = torch.amax(self, dim=dim, keepdim=keepdim) + return amin, amax + + +@register_decomposition(aten.nansum) +@out_wrapper() +def nansum(self, dim=None, keepdim=False, *, dtype=None): + return aten.sum(torch.where(torch.isnan(self), 0, self), dim, keepdim, dtype=dtype) + + +@register_decomposition([aten.arange.default, aten.arange.out]) +@out_wrapper() +def arange_default( + end: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + pin_memory: bool = False, +): + return aten.arange.start_step( + 0, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_decomposition([aten.arange.start]) +def arange_start( + start: NumberType, + end: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + pin_memory: bool = False, +): + return aten.arange.start_step( + start, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_decomposition(out_dtype) +def out_dtype_decomp(*args, **kwargs): + from torch._higher_order_ops.out_dtype import out_dtype_dense + + return out_dtype_dense(*args, **kwargs) + + +@register_decomposition(aten.multi_margin_loss) +@aten.multi_margin_loss.default.py_impl(DispatchKey.Autograd) +@out_wrapper() +def multi_margin_loss( + input: Tensor, + target: Tensor, + p: NumberType = 1, + margin: NumberType = 1, + weight: Optional[Tensor] = None, + reduction: int = Reduction.MEAN.value, +) -> Tensor: + input = torch.atleast_2d(input) + target = torch.atleast_1d(target) + nframe = input.shape[0] + dim = input.shape[1] + torch._check(p == 1 or p == 2, lambda: "only p == 1 and p == 2 supported") + torch._check( + input.ndim == 2 and dim != 0, + lambda: f"Expected non-empty vector or matrix with optional 0-dim batch size, but got: {input.shape}", + ) + torch._check( + target.ndim == 1 and target.numel() == nframe, + lambda: f"inconsistent target size, expected {nframe} but got {target.shape}", + ) + if weight is not None: + weight = torch.atleast_1d(weight) + torch._check( + weight.ndim == 1 and weight.numel() == dim, # type: ignore[union-attr] + lambda: f"inconsistent weight size, expected {dim} but got {weight.shape}", # type: ignore[union-attr] + ) + target = target.unsqueeze(1) + u = torch.gather(input, dim=1, index=target) + z = margin - u + input + z = z.clamp_min(0) + z = z if p == 1 else z * z + if weight is not None: + z = z * weight[target] + idx = torch.arange(dim, device=input.device) + z = torch.where(idx != target, z, 0) + if reduction == Reduction.MEAN.value: + return z.mean() + elif reduction == Reduction.SUM.value: + return z.sum() / z.shape[1] + else: + return z.mean(dim=1) + + +@register_decomposition(aten.multilabel_margin_loss_forward) 
+@aten.multilabel_margin_loss_forward.default.py_impl(DispatchKey.Autograd) +@out_wrapper("output", "is_target") +def multilabel_margin_loss_forward( + input: Tensor, + target: Tensor, + reduction: int, +) -> Tuple[Tensor, Tensor]: + orig_input_shape = input.shape + orig_target_shape = target.shape + input = torch.atleast_2d(input) + target = torch.atleast_2d(target) + dim = input.shape[1] + torch._check( + len(orig_input_shape) <= 2 and dim != 0, + lambda: f"Expected non-empty vector or matrix with optional 0-dim batch size, but got: {orig_input_shape}", + ) + torch._check( + len(orig_target_shape) <= 2 and orig_target_shape == orig_input_shape, + lambda: f"inconsistent target size: {orig_target_shape} for input of size: {orig_input_shape}", + ) + # ignores labels after the first -1, detects when -1 is not present + idx = torch.arange(dim, device=target.device) + is_end = target == -1 + end_idx = torch.amin(torch.where(is_end, idx, dim), dim=-1, keepdim=True) + # target indices + target_mask = idx < end_idx + # masks target to be able to use gather, which doesn't allow -1 + tidx0 = torch.where(target_mask, target, 0) + u = torch.gather(input, dim=-1, index=tidx0) + # is_target + tidx1 = torch.where(target_mask, target, -1) + is_target = torch.any(idx == tidx1.unsqueeze(dim=-1), dim=1) + # loss + z = 1.0 - u.T.unsqueeze(dim=-1) + input + z = z.clamp_min(0) + z = z / dim + # masks loss + z = torch.where(is_target, 0, z) + # reduction + if reduction == Reduction.MEAN.value: + z = z.sum(dim=(0, -1)).mean() + elif reduction == Reduction.SUM.value: + z = z.sum() + else: + z = z.sum(dim=(0, -1)) + # result + is_target = is_target.to(input.dtype).reshape(orig_target_shape) + return z, is_target + + +# scaled_dot_product_attention used to be decomposed in pre-autograd, given that +# it calls _scaled_dot_product_attention_math and +# _scaled_dot_product_attention_math only has a CompositeImplicitAutograd +# kernel. As a result it's decomposed into ops with finer granularity. +# However recent PRs (#103826 #105131 #115913) added new logic in +# scaled_dot_product_attention and now it calls +# _scaled_dot_product_flash_attention_for_cpu in export path. This results +# in _scaled_dot_product_flash_attention_for_cpu showing up in export result. +# This decomposition ensures scaled_dot_product_attention is still decomposed +# the same way as before, i.e., going through +# _scaled_dot_product_attention_math. Notice that this decomp rule should be +# excluded by inductor. 
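+# (Editor's note: the decomposition below routes to the "math" variant of
+# scaled dot product attention.  The next few lines are an illustrative,
+# self-contained sketch of what that math formulation computes; they are not
+# part of the upstream module.  The sketch ignores attn_mask, dropout and the
+# causal path, and `_sdpa_math_sketch` is a hypothetical helper name.)
+def _sdpa_math_sketch(q, k, v, scale=None):
+    import math
+
+    # default scale follows the usual 1/sqrt(head_dim) convention
+    scale = 1.0 / math.sqrt(q.shape[-1]) if scale is None else scale
+    # attention weights: softmax over the key dimension of q @ k^T * scale
+    attn = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
+    # output and attention weights, in the same order as the call below
+    return attn @ v, attn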
+@register_decomposition(aten._scaled_dot_product_flash_attention_for_cpu.default) +def scaled_dot_product_flash_attention_for_cpu( + query: Tensor, + key: Tensor, + value: Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + *, + attn_mask: Optional[Tensor] = None, + scale: Optional[float] = None, +) -> Tuple[Tensor, Tensor]: + dtype = query.dtype + torch._check( + torch.is_floating_point(query), + lambda: f"query must be FP32, FP64, BF16, FP16 but got {query.dtype}", + ) + torch._check( + query.dim() == 4 and key.dim() == 4 and value.dim() == 4, + lambda: f"q, k, v must be a 4 dimensional tensor, got {query.dim()}, {key.dim()}, {value.dim()}", + ) + torch._check( + dropout_p == 0.0, lambda: f"dropout probability must be zero, got {dropout_p}" + ) + torch._check( + query.shape[3] == value.shape[3] and key.shape[3] == value.shape[3], + lambda: "q, k, v should have the same head size", + ) + + output, attn = aten._scaled_dot_product_attention_math.default( + query, + key, + value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + dropout_mask=None, + scale=scale, + ) + # Why this change? + # In pre-dispatch export scaled_dot_product_attention is executed via + # * flash_attention. + # flash_attention allocates output tensor as (N, L, H, E) + # it then transposes that to get (N, H, L, E) which is supposed to be the return + # tensor dim for scaled_dot_product_attention + # assume x: [N, H, L, E] is the output sdpa + # In MHA code, this output is then permuted via (2, 0, 1, 3) to get + # (L, N, H, E) dim tensor + # x = x.permute(2, 0, 1, 3).contiguous() and the viewed via + # x = x.view(L * N, H * E) + # During pre autograd dispatch call to contiguous is not traced because + # flash_attention output after the x.permute is already contiguous + # on which the view is valid + # However, during 2nd stage export, post-dispatch, we run _match variant + # instead of flash* to get the decomposition. _match variant returns + # x: [N, H, L, E] applying x.permute(2, 0, 1, 3) returns + # x: [L, N, H, E] and without converting this to contiguous tensor + # subsequent view is not valid and the export fails + # solution is to maintain the return tensor view from the decomp to be + # exactly same as *flash* variant. + # flash variants output is contiguous as [N, L, H, E] + # _match variant out is contiguous as [N, H, L, E] + # out = out.transpose(1, 2).contiguous gets output as contiguous + # in [N, L, H, E]. 
+ # Subsrequent transpose(1, 2) then returns a view on which + # aforementioned code snippet, as showm below, is valid + # x = x.permute(2, 0, 1, 3).contiguous() and the viewed via + # x = x.view(L * N, H * E) + + # Really the invariant you want to maintain is: + # pre-dispatch op-output and its decomposed representation must + # return tensor with same view and dims + output = output.transpose(1, 2).contiguous(memory_format=torch.contiguous_format) + return (output.transpose(1, 2), attn) + + +def register_inplace(aten_op, outplace_op): + @register_decomposition(aten_op) + def inplace_op(*args, **kwargs): + out = outplace_op(*args, **kwargs) + return args[0].copy_(out) + + return inplace_op + + +@register_decomposition([aten.baddbmm]) +@out_wrapper() +@pw_cast_for_opmath +def baddbmm(self, batch1, batch2, beta=1, alpha=1): + if not self.is_floating_point() and not self.is_complex(): + beta = int(beta) + alpha = int(alpha) + result = torch.bmm(batch1, batch2) + if not isinstance(alpha, numbers.Number) or alpha != 1: + result = result * alpha + if beta == 0: + return result + if not isinstance(beta, numbers.Number) or beta != 1: + self = self * beta + return self + result + + +@register_decomposition(aten.floor_divide) +@out_wrapper() +def floor_divide(self, other): + return torch.div(self, other, rounding_mode="floor") + + +@register_decomposition(aten.sym_numel) +def sym_numel(t): + return functools.reduce(operator.mul, t.shape, 1) + + +@register_decomposition([aten.sum.default, aten.sum.out]) +def sum_default( + self: Tensor, + *, + dtype: Optional[torch.dtype] = None, + out: Optional[Tensor] = None, +) -> Tensor: + if out is None: + return aten.sum.dim_IntList(self, [], dtype=dtype) + else: + return aten.sum.IntList_out(self, [], dtype=dtype, out=out) + + +@register_decomposition([aten.squeeze.default, aten.squeeze.dim]) +def squeeze_default(self: Tensor, dim: Optional[int] = None): + # handle a scalar directly + if not isinstance(self, torch.Tensor): + return self + # perform squeeze + if dim is None: + return aten.squeeze.dims(self, list(range(self.dim()))) + else: + return aten.squeeze.dims(self, [dim]) + + +@register_decomposition(torch.ops.aten._weight_norm_interface) +def _weight_norm_interface(v, g, dim=0): + # https://github.com/pytorch/pytorch/blob/852f8526c52190125446adc9a6ecbcc28fb66182/aten/src/ATen/native/WeightNorm.cpp#L58 + keep_dim = tuple(i for i in range(len(v.shape)) if i != dim) + # align with cuda behavior, keep norm in 'float' when g is 'bfloat16' + norm_dtype = torch.float if g.dtype == torch.bfloat16 else None + norm = v.norm(2, keep_dim, keepdim=True, dtype=norm_dtype) + return v * (g / norm.to(g.dtype)), norm + + +@register_decomposition(aten.isin) +@out_wrapper() +def isin(elements, test_elements, *, assume_unique=False, invert=False): + # handle when either elements or test_elements are Scalars (they can't both be) + if not isinstance(elements, torch.Tensor): + elements = torch.tensor(elements, device=test_elements.device) + if not isinstance(test_elements, torch.Tensor): + test_elements = torch.tensor(test_elements, device=elements.device) + + if test_elements.numel() < 10.0 * pow(elements.numel(), 0.145): + return isin_default(elements, test_elements, invert=invert) + else: + return isin_sorting( + elements, test_elements, assume_unique=assume_unique, invert=invert + ) + + +def isin_default(elements, test_elements, *, invert=False): + if elements.numel() == 0: + return torch.empty_like(elements, dtype=torch.bool) + + x = elements.view(*elements.shape, *((1,) 
* test_elements.ndim)) + if not invert: + cmp = x == test_elements + else: + cmp = x != test_elements + dim = tuple(range(-1, -test_elements.ndim - 1, -1)) + return cmp.any(dim=dim) + + +def isin_sorting(elements, test_elements, *, assume_unique=False, invert=False): + elements_flat = elements.flatten() + test_elements_flat = test_elements.flatten() + if assume_unique: + # This is the same as the aten implementation. For + # assume_unique=False, we cannot use unique() here, so we use a + # version with searchsorted instead. + all_elements = torch.cat([elements_flat, test_elements_flat]) + sorted_elements, sorted_order = torch.sort(all_elements, stable=True) + + duplicate_mask = sorted_elements[1:] == sorted_elements[:-1] + duplicate_mask = torch.constant_pad_nd(duplicate_mask, [0, 1], False) + + if invert: + duplicate_mask = duplicate_mask.logical_not() + + mask = torch.empty_like(duplicate_mask) + mask = mask.index_copy(0, sorted_order, duplicate_mask) + + return mask[0 : elements.numel()] + else: + sorted_test_elements, _ = torch.sort(test_elements_flat) + idx = torch.searchsorted(sorted_test_elements, elements_flat) + test_idx = torch.where(idx < sorted_test_elements.numel(), idx, 0) + cmp = sorted_test_elements[test_idx] == elements_flat + cmp = cmp.logical_not() if invert else cmp + return cmp.reshape(elements.shape) + + +@register_decomposition(aten.take) +@out_wrapper() +def take(self, index): + flattened = self.reshape(-1) + return flattened[index] + + +@register_decomposition(aten.resize_as) +def resize_as(self, other, memory_format=None): + if memory_format is None: + memory_format = torch.contiguous_format + if memory_format == torch.preserve_format: + memory_format = suggest_memory_format(other) + return aten.resize(self, other.shape, memory_format=memory_format) + + +register_inplace(aten.addbmm_, aten.addbmm) +register_inplace(aten.addmm_, aten.addmm) +register_inplace(aten.addmv_, aten.addmv) +register_inplace(aten.baddbmm_, aten.baddbmm) +register_inplace(aten.fill_, aten.fill) +register_inplace(aten.gelu_, aten.gelu) +register_inplace(aten.hardswish_, aten.hardswish) +register_inplace(aten.hardtanh_, aten.hardtanh) +register_inplace(aten.hardsigmoid_, aten.hardsigmoid) +register_inplace(aten.__iand__, aten.__and__) +register_inplace(aten.__ilshift__, aten.__lshift__) +register_inplace(aten.index_put_, aten.index_put) +register_inplace(aten.index_reduce_, aten.index_reduce) +register_inplace(aten.__ior__, aten.__or__) +register_inplace(aten.__irshift__, aten.__rshift__) +register_inplace(aten.__ixor__, aten.__xor__) +register_inplace(aten.leaky_relu_, aten.leaky_relu) +register_inplace(aten.logit_, aten.logit) +register_inplace(aten.relu_, aten.relu) +register_inplace(aten.renorm_, aten.renorm) +register_inplace(aten.round_, aten.round) +register_inplace(aten.scatter_, aten.scatter) +register_inplace(aten.scatter_add_, aten.scatter_add) +register_inplace(aten.scatter_reduce_, aten.scatter_reduce) +register_inplace(aten.silu_, aten.silu) diff --git a/lib/python3.10/site-packages/torch/_decomp/decompositions_for_jvp.py b/lib/python3.10/site-packages/torch/_decomp/decompositions_for_jvp.py new file mode 100644 index 0000000000000000000000000000000000000000..b542b7c511c4ad10bdc3ab083a991145d0262de3 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_decomp/decompositions_for_jvp.py @@ -0,0 +1,335 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import inspect +from typing import Callable, Dict, List, Optional, Tuple + +import torch +import torch._decomp 
+from torch import Tensor +from torch._prims_common.wrappers import _maybe_remove_out_wrapper + + +decomposition_table = torch._decomp.decomposition_table +decomposition_table_for_jvp: Dict[torch._ops.OperatorBase, Callable] = {} +register_decomposition = torch._decomp.register_decomposition +aten = torch.ops.aten + +# NOTE: [forward-mode AD decompositions mechanism] +# +# The mechanism is in VariableType, +# IF any inputs have forward grad +# AND there is no forward AD formula implemented +# AND the function is actually differentiable +# run the decomposition +# See run_jit_decomposition_with_args_for_jvp +# We currently use python decompositions that we torchscript. +# +# Note that we would be building the backward graph at the decomposed level +# too, but that is OK, because we would've errored out otherwise anyway. +# +# TODO: The mechanism we are using to register decompositions doesn't +# seem to be exclusively used for jvp. So an open question here is whether +# torch/csrc/jit/runtime/decomposition_registry.cpp is being used for other things. +# If that is the case, we may go down the decomposition path unexpectedly +# (and possibly produce an unintelligible error) vs erroring out earlier and +# printing that the forward AD formula is not implemented. +# +# The solution to this may be to have an explicit white list controlling when +# to enable the decomposition. + + +def maybe_register_decomposition(op): + def decorator(f): + try: + return register_decomposition(op)(f) + except Exception: + return f + + return decorator + + +# Functions where we need a special decomposition for jvp but there's another version that +# should be used more generally (ex. for jvp we need to recompute the mean and variance for +# the backwards of a normalization function. Without jvp, it should use the saved value) +decomposition_table_for_jvp = {} + + +def register_decomposition_for_jvp(fn): + return register_decomposition(fn, registry=decomposition_table_for_jvp) + + +def _register_jit_decomposition_for_jvp(decomp, use_python=False): + if decomp in decomposition_table_for_jvp: + decomposition_table_used = decomposition_table_for_jvp + elif decomp in decomposition_table: + decomposition_table_used = decomposition_table + else: + raise RuntimeError(f"could not find decomposition for {decomp}") + decomp_fn = decomposition_table_used[decomp] + + # `out_wrapper` extends a decomposition's signature with + # an `out` parameter. However jit will use the unwrapped function's + # signature instead so we need to unwrap here to prevent an error + decomp_fn = _maybe_remove_out_wrapper(decomp_fn) + + if use_python: + decomp_fn = torch.jit.ignore(decomp_fn) + sig = inspect.signature(decomp_fn) + + # Create a string wrapping the function from the signature + # example output: + # def wrapped_decomp(x: torch.Tensor, y: int, z: int): + # return decomp_fn(x, y, z) + # Thanks copilot! + def get_function_def(sig): + param_def = [f"{param_str}" for param_str in sig.parameters.values()] + param_use = [f"{param_str}" for param_str in sig.parameters.keys()] + + return f"def wrapped_decomp({', '.join(param_def)}):\n return decomp_fn({', '.join(param_use)})\n" + + f_str = get_function_def(sig) + graph = torch.jit.CompilationUnit(f_str).wrapped_decomp.graph + else: + graph = torch.jit.script(decomp_fn).graph + torch.jit._register_decomposition(decomp, graph) + + +# The only decompositions here are temporary or hacks for the purposes of jvp + + +# TODO: do these also belong here? 
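+# (Editor's note: the next few lines are an illustrative sketch added for this
+# document, not part of the upstream module.  They show the registration
+# pattern used throughout this file: a decomposition is a plain Python
+# function keyed by an OpOverload in a dict registry, so it can be looked up
+# and called like any other callable.  `_example_registry` and `_example_trace`
+# are hypothetical names.)
+_example_registry: Dict[torch._ops.OperatorBase, Callable] = {}
+
+
+@register_decomposition(aten.trace.default, registry=_example_registry)
+def _example_trace(self: Tensor) -> Tensor:
+    # trace(A) is the sum of the main diagonal, same as the `trace` decomposition below
+    return torch.sum(torch.diag(self))
+
+
+# e.g. _example_registry[aten.trace.default](torch.eye(3)) evaluates to tensor(3.)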
+@maybe_register_decomposition(aten.trace.default) +def trace(self: Tensor) -> Tensor: + return torch.sum(torch.diag(self)) + + +@maybe_register_decomposition(aten.log_sigmoid_forward.default) +def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]: + min = torch.minimum(self.new_zeros(()), self) + z = torch.exp(-torch.abs(self)) + if self.is_cuda: + buffer = self.new_zeros((0,)) + else: + buffer = z + return min - torch.log1p(z), buffer + + +def recompute_mean_var( + input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool +): + # for most norm decompositions, it will be the same as the core version except for here. + # We recompute the mean and variance so that they track gradients through input + + mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim) + var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim) + eps = torch.pow(1 / rstd, 2) - var # this makes me so sad inside + eps = eps.detach() + rstd = 1 / torch.sqrt(var + eps) + return mean, rstd + + +@register_decomposition_for_jvp(aten.native_layer_norm_backward) +def native_layer_norm_backward( + grad_out: Tensor, + input: Tensor, + normalized_shape: List[int], + mean: Tensor, + rstd: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + output_mask: List[bool], +) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + input_shape = input.shape + input_ndim = input.dim() + + axis = input_ndim - len(normalized_shape) + inner_dims = input_shape[axis:] + outer_dims = input_shape[:axis] + inner_dim_indices = list(range(axis, input_ndim)) + outer_dim_indices = list(range(0, axis)) + + N = 1 + for i in inner_dims: + N *= i + M = 1 + for i in outer_dims: + M *= i + if M <= 0 or N <= 0: + return ( + input.new_zeros(input_shape), + input.new_zeros(input_shape[axis:]), + input.new_zeros(input_shape[axis:]), + ) + + mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True) + + x_hat = (input - mean_) * rstd_ + if weight is not None: + grad_x_hat = grad_out * weight + else: + grad_x_hat = grad_out + a = grad_x_hat * N + b = torch.sum(grad_x_hat, inner_dim_indices, True) + c1 = torch.mul(grad_x_hat, x_hat) + c2 = torch.sum(c1, inner_dim_indices, True) + c3 = torch.mul(x_hat, c2) + inner = a - b - c3 + + if output_mask[0]: + d_input: Optional[Tensor] = (rstd_ / N) * inner + else: + d_input = torch.zeros_like(input) # should be None but doesn't work with vjp + + if output_mask[1] and weight is not None: + if len(outer_dim_indices) > 0: + d_weight: Optional[Tensor] = torch.sum( + grad_out * x_hat, outer_dim_indices, False + ) + else: + d_weight = grad_out * x_hat + elif weight is not None: + d_weight = torch.zeros_like(weight) # should be None but doesn't work with vjp + else: + d_weight = torch.zeros(()) # should be None but doesn't work with vjp + + if output_mask[2] and bias is not None: + if len(outer_dim_indices) > 0: + d_bias: Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False) + else: + d_bias = grad_out.clone() + elif bias is not None: + d_bias = torch.zeros_like(bias) # should be None but doesn't work with vjp + else: + d_bias = torch.zeros(()) # should be None but doesn't work with vjp + + return (d_input, d_weight, d_bias) + + +def prod(x: List[int]): + r = 1 + for i in x: + r *= i + return r + + +@register_decomposition_for_jvp(aten.native_batch_norm_backward) +def native_batch_norm_backward( + grad_out: Tensor, + input: Tensor, + weight: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + 
save_mean: Optional[Tensor], + save_invstd: Optional[Tensor], + train: bool, + eps: float, + output_mask: List[bool], +) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + input_shape = input.shape + input_rank = input.dim() + assert input_rank >= 2, "rank of the input must be at least 2" + + axis = 1 + num_features = prod(input_shape) / input_shape[axis] # type: ignore[arg-type] + mean = save_mean + invstd = save_invstd + if train: + assert ( + save_mean is not None and save_invstd is not None + ), "when train=True, save_mean and save_invstd are required" + + reduciton_dims = [0] + list(range(2, input.dim())) + assert invstd is not None # for typing + mean, invstd = recompute_mean_var(input, invstd, reduciton_dims, keepdim=False) + else: + assert running_mean is not None and running_var is not None + mean = running_mean + invstd = torch.rsqrt(running_var + eps) + + assert invstd is not None and mean is not None + + broadcast_mask = [1] * input_rank + broadcast_mask[axis] = input_shape[axis] + + reduction_axes: List[int] = [] + for i in range(input_rank): + if i != axis: + reduction_axes.append(i) + + mean = torch.reshape(mean, broadcast_mask) + norm = 1.0 / num_features + grad_output_sum = torch.sum(grad_out, reduction_axes) + dot_p = torch.sum(grad_out * (input - mean), reduction_axes) + + grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask) + proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask) + + if weight is None: + grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0 + else: + grad_scale = torch.reshape(invstd * weight, broadcast_mask) + + if train: + proj = (input - mean) * proj_scale + grad_input = ((grad_out - proj) - grad_mean) * grad_scale + else: + grad_input = grad_out * grad_scale + + if output_mask[1]: + grad_weight = dot_p * invstd + elif weight is not None: + grad_weight = torch.zeros_like( + weight + ) # should be None but doesn't work with vjp + else: + grad_weight = torch.zeros(()) # should be None but doesn't work with vjp + + if output_mask[2]: + grad_bias = grad_output_sum + else: + grad_bias = torch.zeros_like( + grad_output_sum + ) # should be None but doesn't work with vjp + + return (grad_input, grad_weight, grad_bias) + + +@register_decomposition_for_jvp(aten.batch_norm_backward) +def batch_norm_backward( + grad_out: Tensor, + input: Tensor, + weight: Tensor, + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_var: Optional[Tensor], + update: bool, + eps: float, + output_mask: List[bool], + reserve: Tensor, +) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + return native_batch_norm_backward( + grad_out, + input, + weight, + running_mean, + running_var, + save_mean, + save_var, + update, + eps, + output_mask, + ) + + +_register_jit_decomposition_for_jvp(torch.ops.aten.trace.default, use_python=True) +_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss_backward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss2d_backward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten._log_softmax_backward_data.default) +_register_jit_decomposition_for_jvp(torch.ops.aten._softmax_backward_data.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.log_sigmoid_forward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.native_layer_norm_backward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.native_batch_norm_backward.default) 
+_register_jit_decomposition_for_jvp(torch.ops.aten.cudnn_batch_norm_backward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.batch_norm_backward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.miopen_batch_norm_backward.default) diff --git a/lib/python3.10/site-packages/torch/_decomp/decompositions_for_rng.py b/lib/python3.10/site-packages/torch/_decomp/decompositions_for_rng.py new file mode 100644 index 0000000000000000000000000000000000000000..a62a28f783b7131dbccdae2ac9198aca13c1bf53 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_decomp/decompositions_for_rng.py @@ -0,0 +1,266 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import functools +from collections import defaultdict +from typing import Callable, Dict + +import torch +import torch._decomp as decomp +from torch._decomp import get_decompositions +from torch._ops import OpOverload + + +aten = torch.ops.aten + +rng_decompositions: Dict[str, Dict[OpOverload, Callable]] = defaultdict(dict) + + +def register_rng_decomposition(aten_op): + return decomp.register_decomposition(aten_op, rng_decompositions) + + +def throw_on_non_cuda(device): + raise RuntimeError( + f"You are trying to functionalize a {device.type} RNG operator but {device.type} does not " + f"use Philox/counter-based RNG. Therefore, functionalizing a {device.type} RNG operator is " + "not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU." + ) + + +# TODO - We have to register many more distributions here, and also higher level +# ops like dropout which have fused implementation and can hide the rand inside. +@register_rng_decomposition(aten.rand) +def rand(shape, dtype=None, layout=torch.strided, device=None, pin_memory=False): + if device and device.type != "cuda": + throw_on_non_cuda(device) + seed, offset = PhiloxStateTracker.get_state_as_tuple() + dtype = dtype or torch.float32 + out, offset_jump = torch.ops.rngprims.philox_rand( + shape, seed, offset, None, device, dtype + ) + PhiloxStateTracker.advance_offset(offset_jump) + return out + + +@register_rng_decomposition(aten.rand_like) +def rand_like( + x: torch.Tensor, + dtype=None, + layout=None, + device=None, + pin_memory=False, + memory_format=torch.preserve_format, +): + device = device or x.device + if device.type != "cuda": + throw_on_non_cuda(device) + dtype = dtype or x.dtype + seed, offset = PhiloxStateTracker.get_state_as_tuple() + out, offset_jump = torch.ops.rngprims.philox_rand( + x.shape, seed, offset, None, device, dtype + ) + PhiloxStateTracker.advance_offset(offset_jump) + return out + + +class PhiloxState: + """ + Represents a PhiloxRngState - (seed, offset) where offset = base_offset + + relative_offset. seed and base_offset basically point to the rng state just + before tracing starts. relative offset tracks the totally consumed offset at + trace time. 
+ """ + + def __init__(self) -> None: + self.reset() + + def reset(self): + self.seed = torch.tensor(()) + self.base_offset = torch.tensor(()) + self.relative_offset = 0 + self.offset_advanced_alteast_once = False + + def validate_state(self): + assert self.seed.numel() != 0 and self.base_offset.numel() != 0 + + def advance_offset(self, consumed_offset): + self.offset_advanced_alteast_once = True + self.relative_offset = self.relative_offset + consumed_offset + + def set_state(self, seed, base_offset, relative_offset=0): + self.seed = seed + self.base_offset = base_offset + self.relative_offset = relative_offset + + def get_state_as_tuple(self): + self.validate_state() + return (self.seed, self.base_offset + self.relative_offset) + + def get_state_as_tensor(self): + # Only needed because we override get_rng_state. + self.validate_state() + return torch.stack([self.seed, self.base_offset + self.relative_offset]) + + def set_state_from_tensor(self, state): + # Only needed because we override set_rng_state. + self.seed, self.base_offset = torch.unbind(state) + self.relative_offset = 0 + + +class PhiloxStateTracker: + """ + Singleton class to track the philox rng state during AOT Autograd tracing. + For each aot tracing instance, AOT Autograd resets this tracker and keeps + track of both forward and backward offsets. At runtime, we only care about + the total consumed forward and backward offsets. For dynamic shapes, these + offsets are a function of input shapes. Therefore, the AOT generated graphs + have additional outputs that compute total consumed forward and backward + offsets. + """ + + running_state: PhiloxState + fwd_state: PhiloxState + bwd_state: PhiloxState + + def __enter__(self): + PhiloxStateTracker.reset() + return self + + def __exit__(self, exc_type, exc_cal, exc_tb): + PhiloxStateTracker.reset() + + @classmethod + def reset(cls): + cls.running_state = PhiloxState() + cls.fwd_state = PhiloxState() + cls.bwd_state = PhiloxState() + + @classmethod + def mark_beginning_of_forward(cls): + # Tells the tracker to use fwd_state as the running state + cls.running_state = cls.fwd_state + + @classmethod + def mark_beginning_of_backward(cls): + # Tells the tracker to use bwd_state as the running state + cls.running_state = cls.bwd_state + + @classmethod + def record_state(cls, seed, offset, mode): + # Records the seed and offset tensors. These tensors are used to invoke + # the philox_rand functional primitives. + if mode == "forward": + cls.fwd_state.set_state(seed, offset) + cls.mark_beginning_of_forward() + else: + assert mode == "backward" + cls.bwd_state.set_state(seed, offset) + + @classmethod + def get_state_as_tensor(cls): + # The only reason this exists is because we override get_rng_state and + # set_rng_state during tracing. get_rng_state expects a tensor output, + # so return (seed, offset) tuple upset other parts of the program like + # ctx.saved_tensors. + + # A bad consequence is that if user saves and restores rng state, we + # have little bit of ugliness in the generated code, where we first + # concat the (seed, offset) to create a tensor for get_rng_state, and + # then split it back to get (seed, offset) tuple in set_rng_state. + + # TODO: Investigate if there is be a better way to wrap the tuple in a + # false Tensor object, and then desugar it later on. 
+ return cls.running_state.get_state_as_tensor() + + @classmethod + def get_state_as_tuple(cls): + return cls.running_state.get_state_as_tuple() + + @classmethod + def set_state_from_tensor(cls, x): + # This is only needed because we override set_rng_state. Look at the + # comment in get_state_from_tensor method. + cls.running_state.set_state_from_tensor(x) + + @classmethod + def advance_offset(cls, consumed_offset): + cls.running_state.advance_offset(consumed_offset) + + @classmethod + def get_current_relative_offset(cls): + return cls.running_state.relative_offset + + @staticmethod + def multiple_of_4(offset): + # torch cuda rng state offset must be a multiple of 4. For inductor, as + # we sum up all the numel, the result might not be a multiple of 4. This + # method achieves that. + return (offset + 3) // 4 * 4 + + @classmethod + def get_updated_fwd_offset(cls): + # Short circuit if no rand ops were observed + if not cls.fwd_state.offset_advanced_alteast_once: + return cls.fwd_state.base_offset + return cls.multiple_of_4( + cls.fwd_state.base_offset + cls.fwd_state.relative_offset + ) + + @classmethod + def get_updated_bwd_offset(cls): + # Short circuit if no rand ops were observed + if not cls.bwd_state.offset_advanced_alteast_once: + return cls.bwd_state.base_offset + return cls.multiple_of_4( + cls.bwd_state.base_offset + cls.bwd_state.relative_offset + ) + + +# Adding more decompositions which eventually use rand_like inside decomps. +# Adding these in rng_decompositions ensures the functionalization of rand_like +# ops used in these decomps. The list is copied from inductor codebase, which +# uses it for similar purpose. +# +# Caution - These decomps do not have same accuracy as that of eager. However, +# we can't just disable them with a config flag like fallback_random, because +# for functionalization of rng ops, we have to decompose these ops. 
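+# (Editor's note: illustrative, self-contained sketch of the offset
+# bookkeeping that PhiloxStateTracker performs above, using plain ints instead
+# of tensors; it is not part of the upstream module and the values are
+# hypothetical.)
+def _offset_bookkeeping_example():
+    def round_up_to_multiple_of_4(offset: int) -> int:
+        # same rounding rule as PhiloxStateTracker.multiple_of_4
+        return (offset + 3) // 4 * 4
+
+    base_offset, consumed = 4, 10
+    state_offset = base_offset + consumed  # offset handed to the next philox_rand: 14
+    reported = round_up_to_multiple_of_4(state_offset)  # total reported back to the CUDA RNG: 16
+    assert (state_offset, reported) == (14, 16)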
+extra_random_decomps = get_decompositions( + [ + aten.cauchy, + aten.cauchy_, + aten.exponential, + aten.exponential_, + aten.geometric, + aten.geometric_, + aten.native_dropout, + aten.normal, + aten.normal_, + aten.normal_functional, + aten.log_normal, + aten.log_normal_, + aten.rrelu_with_noise, + aten.rrelu_with_noise_, + aten.uniform_, + ] +) +register_extra_random_decomp = functools.partial( + decomp.register_decomposition, registry=extra_random_decomps +) + + +@register_extra_random_decomp([aten.bernoulli_]) +def bernoulli_(self, p=0.5): + if self.device == torch.device("cpu"): + return NotImplemented + return self.copy_(torch.rand_like(self, dtype=torch.float32) < p) + + +@register_extra_random_decomp([aten.bernoulli.p]) +def bernoulli_p(self, p=0.5, *, generator=None): + if self.device == torch.device("cpu"): + return NotImplemented + assert generator is None + return torch.rand_like(self, dtype=torch.float32) < p + + +rng_decompositions.update(extra_random_decomps) # type: ignore[arg-type] diff --git a/lib/python3.10/site-packages/torch/_dispatch/__init__.py b/lib/python3.10/site-packages/torch/_dispatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/torch/_dispatch/python.py b/lib/python3.10/site-packages/torch/_dispatch/python.py new file mode 100644 index 0000000000000000000000000000000000000000..8b0eb69e9c3845ec5086a1f79aae4c1a6befa1ef --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dispatch/python.py @@ -0,0 +1,180 @@ +# mypy: allow-untyped-defs +import itertools +import unittest.mock +from contextlib import contextmanager +from typing import Iterator + +import torch +import torch._C +import torch._ops +import torch.utils._python_dispatch +import torch.utils._pytree as pytree + + +__all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"] + +no_python_dispatcher = torch._C._DisablePythonDispatcher +enable_python_dispatcher = torch._C._EnablePythonDispatcher +enable_pre_dispatch = torch._C._EnablePreDispatch + +CROSSREF_FUNCTIONALIZE = False + + +def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]: + """ + Warning: the set of overloads this will report is very subtle. It is precisely + the set of torch.ops functions that have actually been accessed from Python + (e.g., we actually called torch.ops.aten.blah at some point. This is DIFFERENT + from the set of registered operators, which will in general be a larger set, + as this would include all operators which we ran C++ static initializers or + Python operator registration on. This does not eagerly populate the list on + torch.ops.aten; this list is lazy! + + In other words, this is good for traversing over everything that has an + OpOverload object allocated in Python. We use it for cache invalidation, but + don't rely on this list being complete. + + Note that even if we did report all C++ registered overloads, this isn't guaranteed + to be complete either, as a subsequent lazy load of a library which triggers more + registrations could add more things to the set. 
+ """ + for ns in torch.ops: + packets = getattr(torch.ops, ns) + for op_name in packets: + packet = getattr(packets, op_name) + for overload in packet: + yield getattr(packet, overload) + + +@contextmanager +def suspend_functionalization(): + f_tls = torch._C._dispatch_tls_is_dispatch_key_included( + torch._C.DispatchKey.Functionalize + ) + f_rv = torch._C._functionalization_reapply_views_tls() + if f_tls: + torch._disable_functionalization() + try: + yield + finally: + if f_tls: + torch._enable_functionalization(reapply_views=f_rv) + + +def check_tensor_metadata_matches(nv, rv, desc): + assert callable(desc) + assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}" + assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}" + same_strides, idx = torch._prims_common.check_significant_strides( + nv, rv, only_cuda=False + ) + assert ( + same_strides + ), f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})" + + +def check_metadata_matches(n, r, desc): + assert callable(desc) + n_vals, n_spec = pytree.tree_flatten(n) + r_vals, r_spec = pytree.tree_flatten(r) + # TODO: test the specs match; empirically sometimes we have a tuple + # on one side and a list on the other + assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}" + for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals): + if not isinstance(rv, torch.Tensor): + continue + check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}") + + +class Lit: + def __init__(self, s): + self.s = s + + def __repr__(self): + return self.s + + +def _fmt(a: object) -> object: + if isinstance(a, torch.Tensor): + return Lit( + f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})" + ) + else: + return a + + +def make_crossref_functionalize(op, final_key): + from torch._subclasses.fake_tensor import FakeTensorMode + + # This case is pretty weird, suppress it for now + if op == torch.ops.aten.lift_fresh.default: + return final_key + + def handler(*args, **kwargs): + fake_mode = FakeTensorMode() + + def fakeify_defun(t): + if isinstance(t, torch.Tensor): + if torch._is_functional_tensor(t): + r = torch._from_functional_tensor(t) + # NB: This assumes that the inner tensor sizes/strides match + # the outer tensor sizes/strides. 
This doesn't necessarily have to + # be the case, see discussion at + # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456 + assert t.size() == r.size() + assert t.stride() == r.stride() + else: + r = t + # TODO: suppress guards + return fake_mode.from_tensor(r) + return t + + def maybe_detach(t): + if isinstance(t, torch.Tensor): + return t.detach() + else: + return t + + # TODO: This probably does the wrong thing if you're running other + # substantive modes with the normal op outside here + with torch.utils._python_dispatch._disable_current_modes(), suspend_functionalization(): + f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs)) + orig_f_args, orig_f_kwargs = pytree.tree_map( + maybe_detach, (f_args, f_kwargs) + ) + with fake_mode: + f_r = op(*f_args, **f_kwargs) + r = op._op_dk(final_key, *args, **kwargs) + + def desc(): + fmt_args = ", ".join( + itertools.chain( + (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args), + ( + f"{k}={pytree.tree_map(_fmt, v)}" + for k, v in orig_f_kwargs.items() + ), + ) + ) + return f"{op}({fmt_args})" + + check_metadata_matches(f_r, r, desc) + return r + + return handler + + +# NB: enabling this is slow, don't do it in a hot loop. This is purely +# for debugging purposes. +@contextmanager +def enable_crossref_functionalize(): + for op in all_py_loaded_overloads(): + op._uncache_dispatch(torch._C.DispatchKey.Functionalize) + try: + with enable_python_dispatcher(), unittest.mock.patch( + "torch._dispatch.python.CROSSREF_FUNCTIONALIZE", True + ): + yield + finally: + for op in all_py_loaded_overloads(): + op._uncache_dispatch(torch._C.DispatchKey.Functionalize) diff --git a/lib/python3.10/site-packages/torch/_dynamo/__init__.py b/lib/python3.10/site-packages/torch/_dynamo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7f58ba7f7bf7f1331b9e483e103ac57cbc0dee71 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/__init__.py @@ -0,0 +1,109 @@ +import torch + +from . 
import convert_frame, eval_frame, resume_execution +from .backends.registry import list_backends, lookup_backend, register_backend +from .callback import callback_handler, on_compile_end, on_compile_start +from .code_context import code_context +from .convert_frame import replay +from .decorators import ( + allow_in_graph, + assume_constant_result, + disable, + disallow_in_graph, + forbid_in_graph, + graph_break, + mark_dynamic, + mark_static, + mark_static_address, + maybe_mark_dynamic, + run, + substitute_in_graph, +) +from .eval_frame import ( + _reset_guarded_backend_cache, + explain, + export, + is_dynamo_supported, + is_inductor_supported, + optimize, + optimize_assert, + OptimizedModule, + reset_code, +) +from .external_utils import is_compiling +from .mutation_guard import GenerationTracker +from .utils import graph_break_reasons, guard_failures, orig_code_map, reset_frame_count + + +# Register polyfill functions +from .polyfills import loader as _ # usort: skip # noqa: F401 + + +__all__ = [ + "allow_in_graph", + "assume_constant_result", + "disallow_in_graph", + "forbid_in_graph", + "substitute_in_graph", + "graph_break", + "mark_dynamic", + "maybe_mark_dynamic", + "mark_static", + "mark_static_address", + "optimize", + "optimize_assert", + "export", + "explain", + "run", + "replay", + "disable", + "reset", + "OptimizedModule", + "is_compiling", + "register_backend", + "list_backends", + "lookup_backend", +] + +if torch.manual_seed is torch.random.manual_seed: + import torch.jit._builtins + + # Wrap manual_seed with the disable decorator. + # Can't do it at its implementation due to dependency issues. + torch.manual_seed = torch._disable_dynamo(torch.manual_seed) + # Add the new manual_seed to the builtin registry. + torch.jit._builtins._register_builtin(torch.manual_seed, "aten::manual_seed") + + +def reset() -> None: + """Clear all compile caches and restore initial state""" + with convert_frame.compile_lock: + reset_code_caches() + convert_frame.input_codes.clear() + convert_frame.output_codes.clear() + orig_code_map.clear() + guard_failures.clear() + graph_break_reasons.clear() + resume_execution.ContinueExecutionCache.cache.clear() + _reset_guarded_backend_cache() + reset_frame_count() + torch._C._dynamo.compiled_autograd.clear_cache() + convert_frame.FRAME_COUNTER = 0 + convert_frame.FRAME_COMPILE_COUNTER.clear() + callback_handler.clear() + GenerationTracker.clear() + torch._dynamo.utils.warn_once_cache.clear() + torch._dynamo.utils.user_obj_id_to_weakref.clear() + torch._C._autograd._saved_tensors_hooks_set_tracing(False) + + +def reset_code_caches() -> None: + """Clear compile caches that are keyed by code objects""" + with convert_frame.compile_lock: + for weak_code in ( + convert_frame.input_codes.seen + convert_frame.output_codes.seen + ): + code = weak_code() + if code: + reset_code(code) + code_context.clear() diff --git a/lib/python3.10/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py b/lib/python3.10/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py new file mode 100644 index 0000000000000000000000000000000000000000..c698ded100943a6add2f80637e0873bfea5aa4b5 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py @@ -0,0 +1,127 @@ +# mypy: allow-untyped-defs +import torch +from torch._C import DispatchKey +from torch._higher_order_ops.utils import autograd_not_implemented +from torch._ops import HigherOrderOperator +from torch._subclasses import FakeTensorMode +from 
torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree +from torch.utils._python_dispatch import _get_current_dispatch_mode +from torch.utils._pytree import tree_map_only + + +__all__ = ["trace_wrapped"] + + +# trace_wrapped(*args, fn) is equivalent to fn(*args), but with a twist: +# if you make_fx trace through this call, we will not actually trace into fn; instead, +# we will directly insert it as a call_function to fn in the graph. +# (Unlike make_fx, Dynamo WILL inline into fn.) +# You can think of this as a one off allow_in_graph equivalent for proxy tensor tracing. +# +# Because proxy tensor tracing does not actually run the function, there are +# requirements on the behavior of fn. We are still figuring it out, but here is the current state: +# +# 1) fn SHOULD only take a single argument, which must be a tensor +# 2) fn MUST return a new tensor with the same metadata as the original tensor +# (e.g., zeros_like(input) is a permissible implementation of fn). +# This is verified via an extra assert that is inserted into the traced graph. +# 3) fn MAY have side effects, but it MAY NOT perform metadata mutation on other tensors +# participating in proxy tensor tracing (it MAY mutate other tensors, it MAY mutate Python state) +# These requirements stem from the requirement that we need to continue performing proxy tensor tracing, +# which assumes accurate fake tensor metadata, without actually running fn. +# In the future, we may allow for a "meta" function associated with fn to allow for more interesting input-output patterns. +# +# Note that tensors / Python state are allowed to be mutated. +# This is relaxed constraint is not always sound, but it is sound for backward tracing with fake +# tensors as it takes place in AOTAutograd, as the backward pass is guaranteed not to depend on concrete +# tensor values (via fake tensor) or Python state (because the autograd engine doesn't depend on Python). +# +# The intended use case for this function is to allow AOTAutograd to defer complex +# backward hooks to compiled autograd. AOTAutograd performs a make_fx trace which preserves +# the function call as is in the graph, and only when we Dynamo through the backward graph in +# compiled autograd do we inline into the function. 
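+# --- Illustrative sketch (editorial addition, not part of the upstream file) ---
+# A minimal example, under the requirements listed above, of a hook that is
+# legal to pass as `fn`: it takes a single tensor and returns a fresh tensor
+# with the same metadata, so proxy tensor tracing can record it as a
+# call_function node without actually running it. `my_backward_hook` is a
+# hypothetical name used only for illustration.
+#
+#     def my_backward_hook(grad):
+#         # Side effects are permitted; they only run when compiled autograd
+#         # actually executes the traced graph.
+#         print("saw grad with shape", grad.shape)
+#         return torch.zeros_like(grad)  # same metadata as the input
+#
+#     out = trace_wrapped(grad, fn=my_backward_hook)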
+ + +def trace_wrapped(*args, **kwargs): + with torch.no_grad(): + return _trace_wrapped_op(*args, **kwargs) + + +class TraceWrapped(HigherOrderOperator): + def __init__(self): + super().__init__("trace_wrapped") + + def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) + + +# TODO(jansel): need to ensure this does not get DCEed +_trace_wrapped_op = TraceWrapped() + + +def _assert_meta(grad, size, stride, dtype): + assert grad.size() == size, "size mismatch" + assert grad.stride() == stride, "stride mismatch" + assert grad.dtype == dtype, "dtype mismatch" + return grad + + +@_trace_wrapped_op.py_impl(ProxyTorchDispatchMode) +def inner_trace(mode, *args, bw_state=None, **kwargs): + def self_invoke(*args, **dyn_kwargs): + with torch.no_grad(): + return _trace_wrapped_op(*args, **dyn_kwargs, **kwargs) + + def unwrap_proxies(x): + if isinstance(x, torch.Tensor): + return mode.tracer.unwrap_proxy(x) + if isinstance(x, (list, tuple)): + return type(x)(map(unwrap_proxies, x)) + if x is None: + return None + raise AssertionError(f"unhandled type: {type(x)}") + + proxy_kwargs = {} + if bw_state is not None: + assert isinstance(bw_state, BackwardState) and bw_state.proxy is not None + proxy_kwargs["bw_state"] = bw_state.proxy + out_proxy = mode.tracer.create_proxy( + "call_function", + self_invoke, + unwrap_proxies(args), + proxy_kwargs, + name="trace_wrapped", + ) + + if args[0] is None: + grad = args[1] # module backward hooks + else: + grad = args[0] # other backward hooks + grad = tree_map_only(torch.Tensor, torch.empty_like, grad) + track_tensor_tree(grad, out_proxy, constant=None, tracer=mode.tracer) + return grad + + +@_trace_wrapped_op.py_impl(FakeTensorMode) +def inner_fake(*args, **kwargs): + raise RuntimeError("This op should never be invoked here") + + +@_trace_wrapped_op.py_impl(DispatchKey.CompositeExplicitAutograd) +def _trace_wrapped_op_dense(*args, fn, **kwargs): + mode = _get_current_dispatch_mode() + assert mode is None, "Mode should never be enabled for CPU/CUDA key" + return fn(*args, **kwargs) + + +_trace_wrapped_op.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(_trace_wrapped_op, deferred_error=True) +) + + +@_trace_wrapped_op.py_functionalize_impl +def _trace_wrapped_functionalized(ctx, *args, **kwargs): + unwrapped_args = ctx.unwrap_tensors(args) + with ctx.redispatch_to_next(): + return ctx.wrap_tensors(_trace_wrapped_op(*unwrapped_args, **kwargs)) diff --git a/lib/python3.10/site-packages/torch/_dynamo/bytecode_analysis.py b/lib/python3.10/site-packages/torch/_dynamo/bytecode_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..fe2ea31b09e558a6c0f5bf80a3bb037475e10b1a --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/bytecode_analysis.py @@ -0,0 +1,257 @@ +# mypy: allow-untyped-defs +import bisect +import dataclasses +import dis +import sys +from typing import Any, Set, Union + + +TERMINAL_OPCODES = { + dis.opmap["RETURN_VALUE"], + dis.opmap["JUMP_FORWARD"], + dis.opmap["RAISE_VARARGS"], + # TODO(jansel): double check exception handling +} +if sys.version_info >= (3, 9): + TERMINAL_OPCODES.add(dis.opmap["RERAISE"]) +if sys.version_info >= (3, 11): + TERMINAL_OPCODES.add(dis.opmap["JUMP_BACKWARD"]) + TERMINAL_OPCODES.add(dis.opmap["JUMP_FORWARD"]) +else: + TERMINAL_OPCODES.add(dis.opmap["JUMP_ABSOLUTE"]) +if sys.version_info >= (3, 12): + TERMINAL_OPCODES.add(dis.opmap["RETURN_CONST"]) +JUMP_OPCODES = set(dis.hasjrel + dis.hasjabs) +JUMP_OPNAMES = {dis.opname[opcode] for opcode in JUMP_OPCODES} 
+HASLOCAL = set(dis.haslocal) +HASFREE = set(dis.hasfree) + +stack_effect = dis.stack_effect + + +def get_indexof(insts): + """ + Get a mapping from instruction memory address to index in instruction list. + Additionally checks that each instruction only appears once in the list. + """ + indexof = {} + for i, inst in enumerate(insts): + assert inst not in indexof + indexof[inst] = i + return indexof + + +def remove_dead_code(instructions): + """Dead code elimination""" + indexof = get_indexof(instructions) + live_code = set() + + def find_live_code(start): + for i in range(start, len(instructions)): + if i in live_code: + return + live_code.add(i) + inst = instructions[i] + if inst.exn_tab_entry: + find_live_code(indexof[inst.exn_tab_entry.target]) + if inst.opcode in JUMP_OPCODES: + find_live_code(indexof[inst.target]) + if inst.opcode in TERMINAL_OPCODES: + return + + find_live_code(0) + + # change exception table entries if start/end instructions are dead + # assumes that exception table entries have been propagated, + # e.g. with bytecode_transformation.propagate_inst_exn_table_entries, + # and that instructions with an exn_tab_entry lies within its start/end. + if sys.version_info >= (3, 11): + live_idx = sorted(live_code) + for i, inst in enumerate(instructions): + if i in live_code and inst.exn_tab_entry: + # find leftmost live instruction >= start + start_idx = bisect.bisect_left( + live_idx, indexof[inst.exn_tab_entry.start] + ) + assert start_idx < len(live_idx) + # find rightmost live instruction <= end + end_idx = ( + bisect.bisect_right(live_idx, indexof[inst.exn_tab_entry.end]) - 1 + ) + assert end_idx >= 0 + assert live_idx[start_idx] <= i <= live_idx[end_idx] + inst.exn_tab_entry.start = instructions[live_idx[start_idx]] + inst.exn_tab_entry.end = instructions[live_idx[end_idx]] + + return [inst for i, inst in enumerate(instructions) if i in live_code] + + +def remove_pointless_jumps(instructions): + """Eliminate jumps to the next instruction""" + pointless_jumps = { + id(a) + for a, b in zip(instructions, instructions[1:]) + if a.opname == "JUMP_ABSOLUTE" and a.target is b + } + return [inst for inst in instructions if id(inst) not in pointless_jumps] + + +def propagate_line_nums(instructions): + """Ensure every instruction has line number set in case some are removed""" + cur_line_no = None + + def populate_line_num(inst): + nonlocal cur_line_no + if inst.starts_line: + cur_line_no = inst.starts_line + + inst.starts_line = cur_line_no + + for inst in instructions: + populate_line_num(inst) + + +def remove_extra_line_nums(instructions): + """Remove extra starts line properties before packing bytecode""" + + cur_line_no = None + + def remove_line_num(inst): + nonlocal cur_line_no + if inst.starts_line is None: + return + elif inst.starts_line == cur_line_no: + inst.starts_line = None + else: + cur_line_no = inst.starts_line + + for inst in instructions: + remove_line_num(inst) + + +@dataclasses.dataclass +class ReadsWrites: + reads: Set[Any] + writes: Set[Any] + visited: Set[Any] + + +def livevars_analysis(instructions, instruction): + indexof = get_indexof(instructions) + must = ReadsWrites(set(), set(), set()) + may = ReadsWrites(set(), set(), set()) + + def walk(state, start): + if start in state.visited: + return + state.visited.add(start) + + for i in range(start, len(instructions)): + inst = instructions[i] + if inst.opcode in HASLOCAL or inst.opcode in HASFREE: + if "LOAD" in inst.opname or "DELETE" in inst.opname: + if inst.argval not in must.writes: + 
state.reads.add(inst.argval) + elif "STORE" in inst.opname: + state.writes.add(inst.argval) + elif inst.opname == "MAKE_CELL": + pass + else: + raise NotImplementedError(f"unhandled {inst.opname}") + if inst.exn_tab_entry: + walk(may, indexof[inst.exn_tab_entry.target]) + if inst.opcode in JUMP_OPCODES: + walk(may, indexof[inst.target]) + state = may + if inst.opcode in TERMINAL_OPCODES: + return + + walk(must, indexof[instruction]) + return must.reads | may.reads + + +@dataclasses.dataclass +class FixedPointBox: + value: bool = True + + +@dataclasses.dataclass +class StackSize: + low: Union[int, float] + high: Union[int, float] + fixed_point: FixedPointBox + + def zero(self): + self.low = 0 + self.high = 0 + self.fixed_point.value = False + + def offset_of(self, other, n): + prior = (self.low, self.high) + self.low = min(self.low, other.low + n) + self.high = max(self.high, other.high + n) + if (self.low, self.high) != prior: + self.fixed_point.value = False + + def exn_tab_jump(self, depth): + prior = (self.low, self.high) + self.low = min(self.low, depth) + self.high = max(self.high, depth) + if (self.low, self.high) != prior: + self.fixed_point.value = False + + +def stacksize_analysis(instructions) -> Union[int, float]: + assert instructions + fixed_point = FixedPointBox() + stack_sizes = { + inst: StackSize(float("inf"), float("-inf"), fixed_point) + for inst in instructions + } + stack_sizes[instructions[0]].zero() + + for _ in range(100): + if fixed_point.value: + break + fixed_point.value = True + + for inst, next_inst in zip(instructions, instructions[1:] + [None]): + stack_size = stack_sizes[inst] + # CALL_FINALLY in Python 3.8 is handled differently when determining stack depth. + # See https://github.com/python/cpython/blob/3.8/Python/compile.c#L5450. + # Essentially, the stack effect of CALL_FINALLY is computed with jump=True, + # but the resulting stack depth is propagated to the next instruction, not the + # jump target. + is_call_finally = ( + sys.version_info < (3, 9) and inst.opcode == dis.opmap["CALL_FINALLY"] + ) + if inst.opcode not in TERMINAL_OPCODES: + assert next_inst is not None, f"missing next inst: {inst}" + # total stack effect of CALL_FINALLY and END_FINALLY in 3.8 is 0 + eff = ( + 0 + if is_call_finally + else stack_effect(inst.opcode, inst.arg, jump=False) + ) + stack_sizes[next_inst].offset_of(stack_size, eff) + if inst.opcode in JUMP_OPCODES and not is_call_finally: + stack_sizes[inst.target].offset_of( + stack_size, stack_effect(inst.opcode, inst.arg, jump=True) + ) + if inst.exn_tab_entry: + # see https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt + # on why depth is computed this way. 
+ depth = inst.exn_tab_entry.depth + int(inst.exn_tab_entry.lasti) + 1 + stack_sizes[inst.exn_tab_entry.target].exn_tab_jump(depth) + + if False: + for inst in instructions: + stack_size = stack_sizes[inst] + print(stack_size.low, stack_size.high, inst) + + low = min(x.low for x in stack_sizes.values()) + high = max(x.high for x in stack_sizes.values()) + + assert fixed_point.value, "failed to reach fixed point" + assert low >= 0 + return high diff --git a/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py b/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py new file mode 100644 index 0000000000000000000000000000000000000000..5c9a0ce5d4eb5aa2c32e6d0c433057b8c1afc9a6 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py @@ -0,0 +1,1503 @@ +# mypy: allow-untyped-defs +import copy +import dataclasses +import dis +import itertools +import sys +import types +from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, Union + +from .bytecode_analysis import ( + get_indexof, + propagate_line_nums, + remove_extra_line_nums, + stacksize_analysis, +) + + +@dataclasses.dataclass +class InstructionExnTabEntry: + start: "Instruction" + end: "Instruction" + target: "Instruction" + depth: int + lasti: bool + + def __repr__(self) -> str: + return ( + f"InstructionExnTabEntry(start={self.start.short_inst_repr()}, " + f"end={self.end.short_inst_repr()}, " + f"target={self.target.short_inst_repr()}, " + f"depth={self.depth}, lasti={self.lasti})" + ) + + def __eq__(self, o) -> bool: + return ( + self.start is o.start + and self.end is o.end + and self.target is o.target + and self.depth == o.depth + and self.lasti == o.lasti + ) + + +@dataclasses.dataclass +class Instruction: + """A mutable version of dis.Instruction""" + + opcode: int + opname: str + arg: Optional[int] + argval: Any + offset: Optional[int] = None + starts_line: Optional[int] = None + is_jump_target: bool = False + positions: Optional["dis.Positions"] = None + # extra fields to make modification easier: + target: Optional["Instruction"] = None + exn_tab_entry: Optional[InstructionExnTabEntry] = None + + def __hash__(self) -> int: + return id(self) + + def __eq__(self, other) -> bool: + return id(self) == id(other) + + def short_inst_repr(self) -> str: + return f"Instruction(opname={self.opname}, offset={self.offset})" + + +def convert_instruction(i: dis.Instruction) -> Instruction: + if sys.version_info >= (3, 13): + starts_line = i.line_number + else: + starts_line = i.starts_line + return Instruction( + i.opcode, + i.opname, + i.arg, + i.argval, + i.offset, + starts_line, + i.is_jump_target, + getattr(i, "positions", None), + ) + + +class _NotProvided: + def __repr__(self) -> str: + return "_NotProvided" + + +def inst_has_op_bits(name): + return (sys.version_info >= (3, 11) and name == "LOAD_GLOBAL") or ( + sys.version_info >= (3, 12) and name in ("LOAD_ATTR", "LOAD_SUPER_ATTR") + ) + + +def create_instruction( + name, *, arg=None, argval=_NotProvided, target=None +) -> Instruction: + """ + At most one of `arg`, `argval`, and `target` can be not None/_NotProvided. + This is to prevent ambiguity, e.g. does + create_instruction("LOAD_CONST", 5) + mean load the constant at co_consts[5], or load the constant 5? + + If `arg` is not provided, it will be computed during assembly from + `argval` or `target`. + + Bits in the args of instructions LOAD_GLOBAL, LOAD_ATTR (3.12+), and LOAD_SUPER_ATTR + modify the behavior of the instruction. 
In this case, we allow both `arg` + and `argval` to be set. The value of `arg` here is expected to be the value of + the op bits and the true value of `arg` will be computed during assembly. + If `arg` is not set, the bits are assumed to be 0. + """ + + # allow for instructions with op bits to have both arg and argval specified + if inst_has_op_bits(name): + if target is not None: + raise RuntimeError("target cannot be specified for instruction") + if arg is None: + arg = 0 + else: + cnt = (arg is not None) + (argval is not _NotProvided) + (target is not None) + if cnt > 1: + raise RuntimeError( + "only one of arg, argval, and target can be not None/_NotProvided" + ) + if arg is not None and not isinstance(arg, int): + raise RuntimeError("instruction arg must be int or None") + return Instruction( + opcode=dis.opmap[name], opname=name, arg=arg, argval=argval, target=target + ) + + +# Python 3.11 remaps +def create_jump_absolute(target) -> Instruction: + inst = "JUMP_FORWARD" if sys.version_info >= (3, 11) else "JUMP_ABSOLUTE" + return create_instruction(inst, target=target) + + +def create_dup_top() -> Instruction: + if sys.version_info >= (3, 11): + return create_instruction("COPY", arg=1) + return create_instruction("DUP_TOP") + + +def create_rot_n(n) -> List[Instruction]: + """ + Returns a "simple" sequence of instructions that rotates TOS to the n-th + position in the stack. For Python < 3.11, returns a single ROT_* + instruction. If no such instruction exists, an error is raised and the + caller is expected to generate an equivalent sequence of instructions. + For Python >= 3.11, any rotation can be expressed as a simple sequence of + swaps. + """ + if n <= 1: + # don't rotate + return [] + + if sys.version_info >= (3, 11): + # rotate can be expressed as a sequence of swap operations + # e.g. rotate 3 is equivalent to swap 3, swap 2 + return [create_instruction("SWAP", arg=i) for i in range(n, 1, -1)] + + # ensure desired rotate function exists + if sys.version_info < (3, 8) and n >= 4: + raise AttributeError(f"rotate {n} not supported for Python < 3.8") + if sys.version_info < (3, 10) and n >= 5: + raise AttributeError(f"rotate {n} not supported for Python < 3.10") + + if n <= 4: + return [create_instruction("ROT_" + ["TWO", "THREE", "FOUR"][n - 2])] + return [create_instruction("ROT_N", arg=n)] + + +def add_push_null( + inst_or_insts: Union[Instruction, List[Instruction]], +) -> List[Instruction]: + """ + Appends or prepends a PUSH_NULL instruction to `inst_or_insts`, + depending on Python version. Used when you know that + `inst_or_insts` generates a callable that will be called. + + NOTE: Assumes `inst_or_insts` is a single instruction or sequence of + instructions that pushes exactly 1 object to the stack that is to + be called. It is important that you include ALL instructions that + construct the callable - not just the first instruction/a prefix. + + Will attempt to use the NULL push bit for instructions + with such bits (LOAD_GLOBAL 3.11+, LOAD_ATTR 3.12+, LOAD_SUPER_ATTR). + In this case, instructions WILL be modified. 
+ """ + if isinstance(inst_or_insts, Instruction): + insts = [inst_or_insts] + else: + insts = inst_or_insts + + def inst_has_bit_set(idx): + assert insts[idx].arg is not None + return insts[idx].arg & 1 == 1 + + def set_inst_bit(idx): + assert insts[idx].arg is not None + insts[idx].arg |= 1 + + if sys.version_info >= (3, 13): + # In 3.13, NULL follows the callable + if inst_has_op_bits(insts[-1].opname) and not inst_has_bit_set(-1): + # All insts with op bits have the push_null bit as the last one. + # Only set the bit if it hasn't been set - otherwise, we need + # to add another PUSH_NULL. + set_inst_bit(-1) + else: + insts = insts + [create_instruction("PUSH_NULL")] + elif sys.version_info >= (3, 12): + # LOAD_ATTR/LOAD_SUPER_ATTR at the end + # We assume that `insts` will only load 1 object, so + # LOAD_GLOBAL at the end doesn't need to be checked + if inst_has_op_bits(insts[-1].opname) and not inst_has_bit_set(-1): + set_inst_bit(-1) + elif insts[0].opname == "LOAD_GLOBAL" and not inst_has_bit_set(0): + set_inst_bit(0) + else: + insts = [create_instruction("PUSH_NULL")] + insts + elif sys.version_info >= (3, 11): + # 3.11 introduced NULL preceding callable + if inst_has_op_bits(insts[0].opname) and not inst_has_bit_set(0): + set_inst_bit(0) + else: + insts = [create_instruction("PUSH_NULL")] + insts + return insts + + +def add_push_null_call_function_ex( + inst_or_insts: Union[Instruction, List[Instruction]], +) -> List[Instruction]: + """Like add_push_null, but the low bit of LOAD_ATTR/LOAD_SUPER_ATTR + is not set, due to an expected CALL_FUNCTION_EX instruction. + """ + if isinstance(inst_or_insts, Instruction): + insts = [inst_or_insts] + else: + insts = inst_or_insts + + if sys.version_info < (3, 11): + return insts + + idx = -1 if sys.version_info >= (3, 13) else 0 + if insts[idx].opname == "LOAD_GLOBAL": + assert insts[idx].arg is not None + if insts[idx].arg & 1 == 0: # type: ignore[operator] + insts[idx].arg |= 1 # type: ignore[operator] + return insts + + if sys.version_info >= (3, 13): + insts = insts + [create_instruction("PUSH_NULL")] + else: + insts = [create_instruction("PUSH_NULL")] + insts + + return insts + + +def create_call_function(nargs, push_null) -> List[Instruction]: + """ + Creates a sequence of instructions that makes a function call. + + `push_null` is used in Python 3.11+ only. It is used in codegen when + a function call is intended to be made with the NULL + fn convention, + and we know that the NULL has not been pushed yet. We will push a + NULL and rotate it to the correct position immediately before making + the function call. + + `push_null` should be True if no NULL is pushed for the callable. + Conversely, `push_null` should be False if a NULL was pushed for the callable. + Prefer using `push_null=False` when possible since we will not need to rotate + NULL to the right place, which is less efficient. + + Generally, you should codegen a function by using `add_push_null` then + `create_call_function` with `push_null=False`. 
+ + Example of when to set push_null False: + + insts = [ + create_instruction("LOAD_GLOBAL", argval="torch"), + create_instruction("LOAD_ATTR", argval="nn"), + create_instruction("LOAD_ATTR", argval="functional"), + create_instruction("LOAD_ATTR", argval="relu"), + ] + insts = add_push_null(insts) + insts.append(create_instruction("LOAD_FAST", argval="x")) + insts.extend(create_call_function(1, False)) + + Example of when to set push_null True: + + insts = [create_instruction("LOAD_FAST", x)] + for should_wrap, wrapper_name in wrappers: + if should_wrap: + insts.extend([ + create_instruction("LOAD_GLOBAL", argval="wrapper1"), + create_instruction("SWAP", arg=2), + *create_call_function(1, True), + ) + """ + if sys.version_info >= (3, 11): + output = [] + if push_null: + output.append(create_instruction("PUSH_NULL")) + # 3.13 swapped NULL and callable + rots = nargs + 1 if sys.version_info >= (3, 13) else nargs + 2 + output.extend(create_rot_n(rots)) + if sys.version_info < (3, 12): + output.append(create_instruction("PRECALL", arg=nargs)) + output.append(create_instruction("CALL", arg=nargs)) + return output + return [create_instruction("CALL_FUNCTION", arg=nargs)] + + +def create_call_method(nargs) -> List[Instruction]: + if sys.version_info >= (3, 12): + return [create_instruction("CALL", arg=nargs)] + if sys.version_info >= (3, 11): + return [ + create_instruction("PRECALL", arg=nargs), + create_instruction("CALL", arg=nargs), + ] + return [create_instruction("CALL_METHOD", arg=nargs)] + + +def create_load_method(name) -> Instruction: + if sys.version_info >= (3, 12): + # in 3.12, create a LOAD_ATTR instruction with the low bit set + return create_instruction("LOAD_ATTR", arg=1, argval=name) + return create_instruction("LOAD_METHOD", argval=name) + + +def create_setup_with(target) -> Instruction: + opname = "BEFORE_WITH" if sys.version_info >= (3, 11) else "SETUP_WITH" + return create_instruction(opname, target=target) + + +def create_swap(n) -> List[Instruction]: + if sys.version_info >= (3, 11): + return [create_instruction("SWAP", arg=n)] + # in Python < 3.11, SWAP is a macro that expands to multiple instructions + if n == 1: + return [] + """ + e.g. 
swap "a" and "b" in this stack: + 0 a 1 2 3 b + 0 a [1 2 3 b] + 0 a [1 2 3 b] [1 2 3 b] + 0 a [1 2 3 b] [1 2 3 b] -1 + 0 a [1 2 3 b] b + 0 b a [1 2 3 b] + 0 b a [1 2 3 b] [1 2 3 b] + 0 b [1 2 3 b] a [1 2 3 b] + 0 b [1 2 3 b] a [1 2 3 b] -1 + 0 b [1 2 3 a] + 0 b [1 2 3 a] [1 2 3 a] + 0 b [1 2 3 a] [1 2 3 a] reverse + 0 b [a 3 2 1] None + 0 b [a 3 2 1] + 0 b 1 2 3 a + """ + return [ + create_instruction("BUILD_LIST", arg=n - 1), + create_instruction("DUP_TOP"), + create_instruction("LOAD_CONST", argval=-1), + create_instruction("BINARY_SUBSCR"), + create_instruction("ROT_THREE"), + create_instruction("DUP_TOP"), + create_instruction("ROT_THREE"), + create_instruction("LOAD_CONST", argval=-1), + create_instruction("STORE_SUBSCR"), + create_instruction("DUP_TOP"), + create_load_method("reverse"), + *create_call_method(0), + create_instruction("POP_TOP"), + create_instruction("UNPACK_SEQUENCE", arg=n - 1), + ] + + +def lnotab_writer( + lineno: int, byteno: int = 0 +) -> Tuple[List[int], Callable[[int, int], None]]: + """ + Used to create typing.CodeType.co_lnotab + See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt + This is the internal format of the line number table if Python < 3.10 + """ + assert sys.version_info < (3, 10) + lnotab: List[int] = [] + + def update(lineno_new, byteno_new): + nonlocal byteno, lineno + while byteno_new != byteno or lineno_new != lineno: + byte_offset = max(0, min(byteno_new - byteno, 255)) + line_offset = max(-128, min(lineno_new - lineno, 127)) + assert byte_offset != 0 or line_offset != 0 + byteno += byte_offset + lineno += line_offset + lnotab.extend((byte_offset, line_offset & 0xFF)) + + return lnotab, update + + +def linetable_310_writer(first_lineno): + """ + Used to create typing.CodeType.co_linetable + See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt + This is the internal format of the line number table for Python 3.10 + """ + assert sys.version_info >= (3, 10) and sys.version_info < (3, 11) + linetable: List[int] = [] + lineno = first_lineno + lineno_delta = 0 + byteno = 0 + + def _update(byteno_delta, lineno_delta): + while byteno_delta != 0 or lineno_delta != 0: + byte_offset = max(0, min(byteno_delta, 254)) + line_offset = max(-127, min(lineno_delta, 127)) + assert byte_offset != 0 or line_offset != 0 + byteno_delta -= byte_offset + lineno_delta -= line_offset + linetable.extend((byte_offset, line_offset & 0xFF)) + + def update(lineno_new, byteno_new): + nonlocal lineno, lineno_delta, byteno + byteno_delta = byteno_new - byteno + byteno = byteno_new + _update(byteno_delta, lineno_delta) + lineno_delta = lineno_new - lineno + lineno = lineno_new + + def end(total_bytes): + _update(total_bytes - byteno, lineno_delta) + + return linetable, update, end + + +def encode_varint(n: int) -> List[int]: + """ + 6-bit chunk encoding of an unsigned integer + See https://github.com/python/cpython/blob/3.11/Objects/locations.md + """ + assert n >= 0 + b = [n & 63] + n >>= 6 + while n > 0: + b[-1] |= 64 + b.append(n & 63) + n >>= 6 + return b + + +def linetable_311_writer(first_lineno: int): + """ + Used to create typing.CodeType.co_linetable + See https://github.com/python/cpython/blob/3.11/Objects/locations.md + This is the internal format of the line number table for Python 3.11 + """ + assert sys.version_info >= (3, 11) + linetable = [] + lineno = first_lineno + + def update(positions: "dis.Positions", inst_size): + nonlocal lineno + lineno_new = positions.lineno if positions else None + + def _update(delta, size): 
+ assert 0 < size <= 8 + # first byte - use 13 (no column info) is positions is + # malformed, otherwise use 14 (long form) + other_varints: Tuple[int, ...] = () + if ( + positions + and positions.lineno is not None + and positions.end_lineno is not None + and positions.col_offset is not None + and positions.end_col_offset is not None + ): + linetable.append(0b1_1110_000 + size - 1) + # for whatever reason, column offset needs `+ 1` + # https://github.com/python/cpython/blob/1931c2a438c50e6250725c84dff94fc760b9b951/Python/compile.c#L7603 + other_varints = ( + positions.end_lineno - positions.lineno, + positions.col_offset + 1, + positions.end_col_offset + 1, + ) + else: + linetable.append(0b1_1101_000 + size - 1) + # encode signed int + if delta < 0: + delta = ((-delta) << 1) | 1 + else: + delta <<= 1 + # encode unsigned int + linetable.extend(encode_varint(delta)) + for n in other_varints: + linetable.extend(encode_varint(n)) + + if lineno_new is None: + lineno_delta = 0 + else: + lineno_delta = lineno_new - lineno + lineno = lineno_new + while inst_size > 8: + _update(lineno_delta, 8) + inst_size -= 8 + _update(lineno_delta, inst_size) + + return linetable, update + + +@dataclasses.dataclass +class ExceptionTableEntry: + start: int + end: int + target: int + depth: int + lasti: bool + + +def encode_exception_table_varint(n: int) -> List[int]: + """ + Similar to `encode_varint`, but the 6-bit chunks are ordered in reverse. + """ + assert n >= 0 + b = [n & 63] + n >>= 6 + while n > 0: + b.append(n & 63) + n >>= 6 + b.reverse() + for i in range(len(b) - 1): + b[i] |= 64 + return b + + +def decode_exception_table_varint(bytes_iter: Iterator[int]) -> int: + """ + Inverse of `encode_exception_table_varint`. + """ + b = next(bytes_iter) + val = b & 63 + while b & 64: + val <<= 6 + b = next(bytes_iter) + val |= b & 63 + return val + + +def check_exception_table(tab: List[ExceptionTableEntry]) -> None: + """ + Verifies that a list of ExceptionTableEntries will make a well-formed + jump table: entries are non-empty, sorted, and do not overlap. + """ + for i in range(len(tab) - 1): + assert ( + tab[i].start <= tab[i].end + and tab[i].end < tab[i + 1].start + and tab[i + 1].start <= tab[i + 1].end + ) + + +def parse_exception_table(exntab: bytes) -> List[ExceptionTableEntry]: + """ + Parse the exception table according to + https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt + """ + exntab_iter = iter(exntab) + tab = [] + try: + while True: + start = decode_exception_table_varint(exntab_iter) * 2 + length = decode_exception_table_varint(exntab_iter) * 2 + end = start + length - 2 + target = decode_exception_table_varint(exntab_iter) * 2 + dl = decode_exception_table_varint(exntab_iter) + depth = dl >> 1 + lasti = bool(dl & 1) + tab.append(ExceptionTableEntry(start, end, target, depth, lasti)) + except StopIteration: + check_exception_table(tab) + return tab + + +def assemble_exception_table(tab: List[ExceptionTableEntry]) -> bytes: + """ + Inverse of parse_exception_table - encodes list of exception + table entries into bytes. 
+ """ + b = [] + for entry in tab: + first_entry = encode_exception_table_varint(entry.start // 2) + first_entry[0] |= 1 << 7 + b.extend(first_entry) + length = entry.end - entry.start + 2 + b.extend(encode_exception_table_varint(length // 2)) + b.extend(encode_exception_table_varint(entry.target // 2)) + dl = (entry.depth << 1) + entry.lasti + b.extend(encode_exception_table_varint(dl)) + return bytes(b) + + +def assemble(instructions: List[Instruction], firstlineno: int) -> Tuple[bytes, bytes]: + """Do the opposite of dis.get_instructions()""" + code: List[int] = [] + if sys.version_info >= (3, 11): + lnotab, update_lineno = linetable_311_writer(firstlineno) + num_ext = 0 + for i, inst in enumerate(instructions): + if inst.opname == "EXTENDED_ARG": + inst_size = 1 + num_ext += 1 + # copy positions from the actual instruction + for j in (1, 2, 3): + if instructions[i + j].opname != "EXTENDED_ARG": + inst.positions = instructions[i + j].positions + break + else: + inst_size = instruction_size(inst) // 2 + num_ext + num_ext = 0 + update_lineno(inst.positions, inst_size) + num_ext = 0 + arg = inst.arg or 0 + code.extend((inst.opcode, arg & 0xFF)) + for _ in range(instruction_size(inst) // 2 - 1): + code.extend((0, 0)) + else: + if sys.version_info < (3, 10): + lnotab, update_lineno = lnotab_writer(firstlineno) + else: + lnotab, update_lineno, end = linetable_310_writer(firstlineno) + + for inst in instructions: + if inst.starts_line is not None: + update_lineno(inst.starts_line, len(code)) + arg = inst.arg or 0 + code.extend((inst.opcode, arg & 0xFF)) + + if sys.version_info >= (3, 10): + end(len(code)) + + return bytes(code), bytes(lnotab) + + +def _get_instruction_by_offset(offset_to_inst: Dict[int, Instruction], offset: int): + """ + Get the instruction located at a given offset, accounting for EXTENDED_ARGs + """ + for n in (0, 2, 4, 6): + if offset_to_inst[offset + n].opcode != dis.EXTENDED_ARG: + return offset_to_inst[offset + n] + return None + + +def virtualize_jumps(instructions) -> None: + """Replace jump targets with pointers to make editing easier""" + jump_targets = {inst.offset: inst for inst in instructions} + + for inst in instructions: + if inst.opcode in dis.hasjabs or inst.opcode in dis.hasjrel: + inst.target = _get_instruction_by_offset(jump_targets, inst.argval) + + +_REL_JUMPS = set(dis.hasjrel) + + +def flip_jump_direction(instruction: Instruction) -> None: + if sys.version_info < (3, 11): + raise RuntimeError("Cannot flip jump direction in Python < 3.11") + if "FORWARD" in instruction.opname: + instruction.opname = instruction.opname.replace("FORWARD", "BACKWARD") + elif "BACKWARD" in instruction.opname: + instruction.opname = instruction.opname.replace("BACKWARD", "FORWARD") + else: + raise AttributeError("Instruction is not a forward or backward jump") + instruction.opcode = dis.opmap[instruction.opname] + assert instruction.opcode in _REL_JUMPS + + +def _get_instruction_front(instructions: List[Instruction], idx: int): + """ + i.e. get the first EXTENDED_ARG instruction (if any) when targeting + instructions[idx] with a jump. 
+ """ + target = instructions[idx] + for offset in (1, 2, 3): + if idx >= offset and instructions[idx - offset].opcode == dis.EXTENDED_ARG: + target = instructions[idx - offset] + else: + break + return target + + +def devirtualize_jumps(instructions): + """Fill in args for virtualized jump target after instructions may have moved""" + jumps = set(dis.hasjabs).union(set(dis.hasjrel)) + + # check for negative jump args and fix them + for inst in instructions: + if inst.opcode in jumps: + if inst.opcode not in dis.hasjabs: + if inst.target.offset < inst.offset: + if sys.version_info < (3, 11): + raise RuntimeError("Got negative jump offset for Python < 3.11") + # forward jumps become backward + if "FORWARD" in inst.opname: + flip_jump_direction(inst) + else: + # backward jumps become forward + if sys.version_info >= (3, 11) and "BACKWARD" in inst.opname: + flip_jump_direction(inst) + + # jump instruction size may have changed due to flips + update_offsets(instructions) + indexof = get_indexof(instructions) + + # compute jump instruction arg + for inst in instructions: + if inst.opcode in jumps: + target = _get_instruction_front(instructions, indexof[inst.target]) + if inst.opcode in dis.hasjabs: + if sys.version_info < (3, 10): + inst.arg = target.offset + elif sys.version_info < (3, 11): + # `arg` is expected to be bytecode offset, whereas `offset` is byte offset. + # Divide since bytecode is 2 bytes large. + inst.arg = int(target.offset / 2) + else: + raise RuntimeError("Python 3.11+ should not have absolute jumps") + else: # relative jump + # byte offset between target and next instruction + inst.arg = abs( + int(target.offset - inst.offset - instruction_size(inst)) + ) + if sys.version_info >= (3, 10): + # see bytecode size comment in the absolute jump case above + inst.arg //= 2 + inst.argval = target.offset + inst.argrepr = f"to {target.offset}" + + +def virtualize_exception_table(exn_tab_bytes: bytes, instructions: List[Instruction]): + """Replace exception table entries with pointers to make editing easier""" + exn_tab = parse_exception_table(exn_tab_bytes) + offset_to_inst = {cast(int, inst.offset): inst for inst in instructions} + offsets = sorted(offset_to_inst.keys()) + end_offset_idx = 0 + exn_tab_iter = iter(exn_tab) + try: + + def step(): + nonlocal end_offset_idx + entry = next(exn_tab_iter) + # find rightmost offset <= entry.end, since entry.end may not be + # an actual instruction, e.g. if the end instruction is LOAD_GLOBAL, + # which takes more than 2 bytes, then entry.end points to the end + # of the LOAD_GLOBAL instruction, not the beginning. 
+ while ( + end_offset_idx < len(offsets) and offsets[end_offset_idx] <= entry.end + ): + end_offset_idx += 1 + assert end_offset_idx > 0 + end_offset = offsets[end_offset_idx - 1] + inst_entry = InstructionExnTabEntry( + _get_instruction_by_offset(offset_to_inst, entry.start), + _get_instruction_by_offset(offset_to_inst, end_offset), + _get_instruction_by_offset(offset_to_inst, entry.target), + entry.depth, + entry.lasti, + ) + return entry, inst_entry + + entry, inst_entry = step() + for inst in instructions: + while inst.offset > entry.end: + entry, inst_entry = step() + if inst.offset >= entry.start: + inst.exn_tab_entry = copy.copy(inst_entry) + except StopIteration: + pass + + +def compute_exception_table( + instructions: List[Instruction], +) -> List[ExceptionTableEntry]: + """Compute exception table in list format from instructions with exn_tab_entries""" + exn_dict: Dict[Tuple[int, int], Tuple[int, int, bool]] = {} + indexof = get_indexof(instructions) + + for inst in instructions: + if inst.exn_tab_entry: + # account for prefixed EXTENDED_ARGS + start = _get_instruction_front( + instructions, indexof[inst.exn_tab_entry.start] + ).offset + # point to the last 2 bytes of the end instruction + end = ( + cast(int, inst.exn_tab_entry.end.offset) + + instruction_size(inst.exn_tab_entry.end) + - 2 + ) + target = _get_instruction_front( + instructions, indexof[inst.exn_tab_entry.target] + ).offset + key = (start, end) + val = (target, inst.exn_tab_entry.depth, inst.exn_tab_entry.lasti) + if key in exn_dict: + assert exn_dict[key] == val + exn_dict[key] = val + + # Dynamo may construct nested exception table entries for convenience, + # but Python expects exception table entries to not overlap. + # NOTE: below, "keys" refer to old instruction entries' starts and ends, + # and "entries" refer to the generated exception table entries. + + # Sort keys by increasing start, then decreasing end + keys_sorted = sorted(exn_dict.keys(), key=lambda t: (t[0], -t[1])) + # smallest byte that the next exception table entry can start at + nexti = 0 + # stack of current nested keys + key_stack: List[Tuple[int, int]] = [] + exn_tab: List[ExceptionTableEntry] = [] + + def pop(): + """ + Pop the key_stack and append an exception table entry if possible. + """ + nonlocal nexti + if key_stack: + key = key_stack.pop() + if nexti <= key[1]: + exn_tab.append( + ExceptionTableEntry(max(key[0], nexti), key[1], *exn_dict[key]) + ) + nexti = key[1] + 2 + + for key in keys_sorted: + # pop keys that are no longer nested over the current key + while key_stack and key_stack[-1][1] < key[0]: + pop() + if key_stack: + # create an entry covering to the current key, if possible + assert key_stack[-1][0] <= key[0] <= key[1] <= key_stack[-1][1] + left = max(nexti, key_stack[-1][0]) + if left < key[0]: + exn_tab.append( + ExceptionTableEntry(left, key[0] - 2, *exn_dict[key_stack[-1]]) + ) + nexti = key[0] + key_stack.append(key) + while key_stack: + pop() + check_exception_table(exn_tab) + return exn_tab + + +def check_inst_exn_tab_entries_nested( + tab: List[InstructionExnTabEntry], indexof +) -> None: + """ + Checks `tab` is a properly sorted list of nested InstructionExnTabEntry's, + i.e. no entries partially overlap. + "Properly sorted" means entries are sorted by increasing starts, then + decreasing ends. 
+ """ + entry_stack: List[Tuple[int, int]] = [] + for entry in tab: + key = (indexof[entry.start], indexof[entry.end]) + while entry_stack and entry_stack[-1][1] < key[0]: + entry_stack.pop() + if entry_stack: + assert entry_stack[-1][0] <= key[0] <= key[1] <= entry_stack[-1][1] + entry_stack.append(key) + + +def propagate_inst_exn_table_entries(instructions: List[Instruction]) -> None: + """ + Copies exception table entries to all instructions in an entry's range. + Supports nested exception table entries. + """ + indexof = get_indexof(instructions) + entries: Dict[Tuple[int, int], InstructionExnTabEntry] = {} + for inst in instructions: + if inst.exn_tab_entry: + key = ( + indexof[inst.exn_tab_entry.start], + indexof[inst.exn_tab_entry.end], + ) + if key in entries: + assert inst.exn_tab_entry == entries[key] + entries[key] = inst.exn_tab_entry + sorted_entries = [ + entries[key] for key in sorted(entries.keys(), key=lambda t: (t[0], -t[1])) + ] + check_inst_exn_tab_entries_nested(sorted_entries, indexof) + # Propagation of nested entries works since nested entries come later + # in sorted order. + for entry in sorted_entries: + for i in range(indexof[entry.start], indexof[entry.end] + 1): + instructions[i].exn_tab_entry = copy.copy(entry) + + +def check_inst_exn_tab_entries_valid(instructions: List[Instruction]): + """ + Checks that exn_tab_entries of instructions are valid. + An entry's start, end, and target must be in instructions. + Instructions with an exn_tab_entry are located within + the entry's start and end instructions. + Instructions do not share exn_tab_entries. + + Implicitly checks for no duplicate instructions. + """ + indexof = get_indexof(instructions) + exn_tab_entry_set = set() + for i, inst in enumerate(instructions): + if inst.exn_tab_entry: + assert sys.version_info >= (3, 11) + assert id(inst.exn_tab_entry) not in exn_tab_entry_set + exn_tab_entry_set.add(id(inst.exn_tab_entry)) + entry = inst.exn_tab_entry + assert entry.start in indexof + assert entry.end in indexof + assert entry.target in indexof + assert indexof[entry.start] <= i <= indexof[entry.end] + + +def strip_extended_args(instructions: List[Instruction]) -> None: + instructions[:] = [i for i in instructions if i.opcode != dis.EXTENDED_ARG] + + +def remove_load_call_method(instructions: List[Instruction]) -> List[Instruction]: + """LOAD_METHOD puts a NULL on the stack which causes issues, so remove it""" + assert sys.version_info < (3, 11) + rewrites = {"LOAD_METHOD": "LOAD_ATTR", "CALL_METHOD": "CALL_FUNCTION"} + for inst in instructions: + if inst.opname in rewrites: + inst.opname = rewrites[inst.opname] + inst.opcode = dis.opmap[inst.opname] + return instructions + + +def remove_jump_if_none(instructions: List[Instruction]) -> None: + new_insts = [] + for inst in instructions: + new_insts.append(inst) + if "_NONE" in inst.opname: + is_op = create_instruction("IS_OP", arg=int("NOT" in inst.opname)) + is_op.argval = is_op.arg + is_op.positions = inst.positions + if sys.version_info < (3, 12): + jump_op = create_instruction( + "POP_JUMP_FORWARD_IF_TRUE" + if "FORWARD" in inst.opname + else "POP_JUMP_BACKWARD_IF_TRUE", + target=inst.target, + ) + else: + jump_op = create_instruction("POP_JUMP_IF_TRUE", target=inst.target) + jump_op.positions = inst.positions + # update inst.exn_tab_entry.end if necessary + if inst.exn_tab_entry and inst.exn_tab_entry.end is inst: + inst.exn_tab_entry.end = jump_op + # preserve exception table entries + is_op.exn_tab_entry = copy.copy(inst.exn_tab_entry) + 
jump_op.exn_tab_entry = copy.copy(inst.exn_tab_entry) + # modify inst in-place to preserve jump target + inst.opcode = dis.opmap["LOAD_CONST"] + inst.opname = "LOAD_CONST" + inst.arg = None + inst.argval = None + new_insts.extend([is_op, jump_op]) + instructions[:] = new_insts + + +def remove_binary_store_slice(instructions: List[Instruction]) -> None: + new_insts = [] + for inst in instructions: + new_insts.append(inst) + if inst.opname in ("BINARY_SLICE", "STORE_SLICE"): + # new instruction + subscr_inst = create_instruction(inst.opname.replace("SLICE", "SUBSCR")) + if inst.exn_tab_entry and inst.exn_tab_entry.end is inst: + inst.exn_tab_entry.end = subscr_inst + subscr_inst.exn_tab_entry = copy.copy(inst.exn_tab_entry) + subscr_inst.positions = inst.positions + # modify inst in-place to preserve jump target + inst.opcode = dis.opmap["BUILD_SLICE"] + inst.opname = "BUILD_SLICE" + inst.arg = 2 + inst.argval = 2 + new_insts.append(subscr_inst) + instructions[:] = new_insts + + +FUSED_INSTS = { + "LOAD_FAST_LOAD_FAST": ("LOAD_FAST", "LOAD_FAST"), + "STORE_FAST_STORE_FAST": ("STORE_FAST", "STORE_FAST"), + "STORE_FAST_LOAD_FAST": ("STORE_FAST", "LOAD_FAST"), +} + + +def remove_fused_load_store(instructions: List[Instruction]) -> None: + new_insts = [] + for inst in instructions: + new_insts.append(inst) + if inst.opname in FUSED_INSTS: + inst0, inst1 = FUSED_INSTS[inst.opname] + argval0, argval1 = inst.argval + + # modify inst in-place to preserve jump target + inst.opcode = dis.opmap[inst0] + inst.opname = inst0 + inst.argval = argval0 + + new_inst = create_instruction(inst1, argval=argval1) + # update inst.exn_tab_entry.end if necessary + if inst.exn_tab_entry and inst.exn_tab_entry.end is inst: + inst.exn_tab_entry.end = new_inst + # preserve exception table entries + new_inst.exn_tab_entry = copy.copy(inst.exn_tab_entry) + + new_insts.append(new_inst) + instructions[:] = new_insts + + +def explicit_super(code: types.CodeType, instructions: List[Instruction]) -> None: + """convert super() with no args into explicit arg form""" + cell_and_free = (code.co_cellvars or ()) + (code.co_freevars or ()) + if not len(code.co_varnames): + # A function with no argument cannot contain a valid "super()" call + return + output = [] + for idx, inst in enumerate(instructions): + output.append(inst) + if inst.opname == "LOAD_GLOBAL" and inst.argval == "super": + nexti = instructions[idx + 1] + if nexti.arg == 0 and ( + (sys.version_info >= (3, 12) and nexti.opname == "CALL") + or ( + sys.version_info >= (3, 11) + and sys.version_info < (3, 12) + and nexti.opname == "PRECALL" + ) + or (sys.version_info < (3, 11) and nexti.opname == "CALL_FUNCTION") + ): + assert "__class__" in cell_and_free + output.append(create_instruction("LOAD_DEREF", argval="__class__")) + first_var = code.co_varnames[0] + if first_var in cell_and_free: + output.append(create_instruction("LOAD_DEREF", argval=first_var)) + else: + output.append(create_instruction("LOAD_FAST", argval=first_var)) + nexti.arg = 2 + nexti.argval = 2 + if nexti.opname == "PRECALL": + # also update the following CALL instruction + call_inst = instructions[idx + 2] + call_inst.arg = 2 + call_inst.argval = 2 + + instructions[:] = output + + +def fix_extended_args(instructions: List[Instruction]) -> int: + """Fill in correct argvals for EXTENDED_ARG ops""" + output: List[Instruction] = [] + + def maybe_pop_n(n): + for _ in range(n): + if output and output[-1].opcode == dis.EXTENDED_ARG: + output.pop() + + for inst in instructions: + if inst.opcode == 
dis.EXTENDED_ARG: + # Leave this instruction alone for now so we never shrink code + inst.arg = 0 + elif inst.arg and inst.arg > 0xFFFFFF: + maybe_pop_n(3) + output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 24)) + output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 16)) + output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 8)) + elif inst.arg and inst.arg > 0xFFFF: + maybe_pop_n(2) + output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 16)) + output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 8)) + elif inst.arg and inst.arg > 0xFF: + maybe_pop_n(1) + output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 8)) + output.append(inst) + + added = len(output) - len(instructions) + assert added >= 0 + instructions[:] = output + return added + + +def instruction_size(inst) -> int: + import torch + + if sys.version_info >= (3, 11): + return 2 * (torch._C._dynamo.eval_frame.py_opcode_caches[inst.opcode] + 1) + return 2 + + +def check_offsets(instructions) -> None: + offset = 0 + for inst in instructions: + assert inst.offset == offset + offset += instruction_size(inst) + + +def update_offsets(instructions) -> None: + offset = 0 + for inst in instructions: + inst.offset = offset + offset += instruction_size(inst) + + +def debug_bytes(*args) -> str: + index = range(max(map(len, args))) + result = [] + for arg in ( + [index] + list(args) + [[int(a != b) for a, b in zip(args[-1], args[-2])]] + ): + result.append(" ".join(f"{x:03}" for x in arg)) + + return "bytes mismatch\n" + "\n".join(result) + + +def debug_checks(code): + """Make sure our assembler produces same bytes as we start with""" + dode = transform_code_object(code, lambda x, y: None, safe=True) + assert code.co_code == dode.co_code, debug_bytes(code.co_code, dode.co_code) + assert code.co_lnotab == dode.co_lnotab, debug_bytes(code.co_lnotab, dode.co_lnotab) + + +HAS_LOCAL = set(dis.haslocal) +HAS_NAME = set(dis.hasname) +HAS_FREE = set(dis.hasfree) +HAS_CONST = set(dis.hasconst) + + +def get_const_index(code_options, val) -> int: + for i, v in enumerate(code_options["co_consts"]): + # NOTE: stronger comparison is required, since we have + # examples where two values compare equal but have + # different semantic meaning in some cases, e.g. + # 0.0 == -0.0 but have different effects in torch.copysign. 
+ if val is v: + return i + code_options["co_consts"] += (val,) + return len(code_options["co_consts"]) - 1 + + +def fix_vars(instructions: List[Instruction], code_options, varname_from_oparg=None): + # compute instruction arg from argval if arg is not provided + names = {name: idx for idx, name in enumerate(code_options["co_names"])} + + def get_name_index(name) -> int: + try: + idx = names[name] + except KeyError: + # Add a missing item to co_names + idx = names[name] = len(names) + code_options["co_names"] = (*code_options["co_names"], name) + assert len(code_options["co_names"]) == len(names) + return idx + + if sys.version_info < (3, 11): + assert varname_from_oparg is None + varnames = {name: idx for idx, name in enumerate(code_options["co_varnames"])} + freenames = { + name: idx + for idx, name in enumerate( + code_options["co_cellvars"] + code_options["co_freevars"] + ) + } + else: + assert callable(varname_from_oparg) + allnames = {} + for idx in itertools.count(): + try: + name = varname_from_oparg(idx) + allnames[name] = idx + except IndexError: + break + varnames = {name: allnames[name] for name in code_options["co_varnames"]} + freenames = { + name: allnames[name] + for name in code_options["co_cellvars"] + code_options["co_freevars"] + } + for i in range(len(instructions)): + + def should_compute_arg(): + # argval is prioritized over arg + return instructions[i].argval is not _NotProvided + + if instructions[i].opname == "LOAD_GLOBAL": + # 3.11 LOAD_GLOBAL requires both arg and argval - see create_instruction + assert instructions[i].argval is not _NotProvided + if sys.version_info >= (3, 11): + assert instructions[i].arg is not None + instructions[i].arg = (get_name_index(instructions[i].argval) << 1) + ( + cast(int, instructions[i].arg) % 2 + ) + else: + instructions[i].arg = get_name_index(instructions[i].argval) + elif instructions[i].opname == "LOAD_ATTR": + # 3.12 LOAD_ATTR requires both arg and argval, like LOAD_GLOBAL + assert instructions[i].argval is not _NotProvided + if sys.version_info >= (3, 12): + assert instructions[i].arg is not None + instructions[i].arg = (get_name_index(instructions[i].argval) << 1) + ( + cast(int, instructions[i].arg) % 2 + ) + else: + instructions[i].arg = get_name_index(instructions[i].argval) + elif instructions[i].opname == "LOAD_SUPER_ATTR": + assert instructions[i].arg is not None + assert instructions[i].argval is not _NotProvided + # Copy low bit, force second bit on for explicit super (the "+ 2") + instructions[i].arg = ( + (get_name_index(instructions[i].argval) << 2) + + (cast(int, instructions[i].arg) % 2) + + 2 + ) + elif instructions[i].opcode in HAS_LOCAL: + if should_compute_arg(): + if ( + sys.version_info >= (3, 13) + and instructions[i].argval not in varnames + ): + # instructions like LOAD_FAST used for both local and free vars + instructions[i].arg = freenames[instructions[i].argval] + else: + instructions[i].arg = varnames[instructions[i].argval] + elif instructions[i].opcode in HAS_NAME: + if should_compute_arg(): + instructions[i].arg = get_name_index(instructions[i].argval) + elif instructions[i].opcode in HAS_FREE: + if should_compute_arg(): + instructions[i].arg = freenames[instructions[i].argval] + elif instructions[i].opcode in HAS_CONST: + # NOTE: only update argval if arg is not provided. This assumes + # that any additions to co_consts are appended. 
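+            # Hypothetical example of that assumption: with co_consts == (None, 1),
+            # an instruction carrying argval="x" gets "x" appended at index 2, so the
+            # indices of existing consts (and any previously computed args) stay valid.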
+ if instructions[i].arg is None: + # cannot use a dictionary since consts may not be hashable + idx = get_const_index(code_options, instructions[i].argval) + assert idx >= 0 + instructions[i].arg = idx + + +def clear_instruction_args(instructions): + # Clear the instruction arg for instructions that have argvals. + # Useful for using dis'd bytecode within generated bytecode. + for inst in instructions: + if ( + inst.argval is not _NotProvided + and ( + inst.opcode in HAS_LOCAL + or inst.opcode in HAS_NAME + or inst.opcode in HAS_FREE + or inst.opcode in HAS_CONST + ) + and inst.opname not in ("LOAD_GLOBAL", "LOAD_ATTR", "LOAD_SUPER_ATTR") + ): + inst.arg = None + + +def get_code_keys() -> List[str]: + # Python 3.11 changes to code keys are not fully documented. + # See https://github.com/python/cpython/blob/3.11/Objects/clinic/codeobject.c.h#L24 + # for new format. + keys = ["co_argcount"] + keys.append("co_posonlyargcount") + keys.extend( + [ + "co_kwonlyargcount", + "co_nlocals", + "co_stacksize", + "co_flags", + "co_code", + "co_consts", + "co_names", + "co_varnames", + "co_filename", + "co_name", + ] + ) + if sys.version_info >= (3, 11): + keys.append("co_qualname") + keys.append("co_firstlineno") + if sys.version_info >= (3, 10): + keys.append("co_linetable") + else: + keys.append("co_lnotab") + if sys.version_info >= (3, 11): + # not documented, but introduced in https://github.com/python/cpython/issues/84403 + keys.append("co_exceptiontable") + keys.extend( + [ + "co_freevars", + "co_cellvars", + ] + ) + return keys + + +def transform_code_object(code, transformations, safe=False) -> types.CodeType: + keys = get_code_keys() + code_options = {k: getattr(code, k) for k in keys} + assert len(code_options["co_varnames"]) == code_options["co_nlocals"] + + instructions = cleaned_instructions(code, safe) + propagate_line_nums(instructions) + + transformations(instructions, code_options) + return clean_and_assemble_instructions(instructions, keys, code_options)[1] + + +def clean_and_assemble_instructions( + instructions: List[Instruction], keys: List[str], code_options: Dict[str, Any] +) -> Tuple[List[Instruction], types.CodeType]: + # also implicitly checks for no duplicate instructions + check_inst_exn_tab_entries_valid(instructions) + + code_options["co_nlocals"] = len(code_options["co_varnames"]) + varname_from_oparg = None + if sys.version_info >= (3, 11): + # temporary code object with updated names + tmp_code = types.CodeType(*[code_options[k] for k in keys]) + varname_from_oparg = tmp_code._varname_from_oparg # type: ignore[attr-defined] + fix_vars(instructions, code_options, varname_from_oparg=varname_from_oparg) + + dirty = True + while dirty: + update_offsets(instructions) + devirtualize_jumps(instructions) + # this pass might change offsets, if so we need to try again + dirty = bool(fix_extended_args(instructions)) + + remove_extra_line_nums(instructions) + bytecode, lnotab = assemble(instructions, code_options["co_firstlineno"]) + if sys.version_info < (3, 10): + code_options["co_lnotab"] = lnotab + else: + code_options["co_linetable"] = lnotab + + code_options["co_code"] = bytecode + code_options["co_stacksize"] = stacksize_analysis(instructions) + assert set(keys) - {"co_posonlyargcount"} == set(code_options.keys()) - { + "co_posonlyargcount" + } + if sys.version_info >= (3, 11): + code_options["co_exceptiontable"] = assemble_exception_table( + compute_exception_table(instructions) + ) + + return instructions, types.CodeType(*[code_options[k] for k in keys]) + + +def 
populate_kw_names_argval(instructions, consts): + for inst in instructions: + if inst.opname == "KW_NAMES": + inst.argval = consts[inst.arg] + + +def cleaned_instructions(code, safe=False) -> List[Instruction]: + instructions = list(map(convert_instruction, dis.get_instructions(code))) + check_offsets(instructions) + if sys.version_info >= (3, 11): + populate_kw_names_argval(instructions, code.co_consts) + virtualize_exception_table(code.co_exceptiontable, instructions) + virtualize_jumps(instructions) + strip_extended_args(instructions) + if not safe: + if sys.version_info < (3, 11): + remove_load_call_method(instructions) + if sys.version_info < (3, 12): + explicit_super(code, instructions) + if sys.version_info >= (3, 11): + remove_jump_if_none(instructions) + if sys.version_info >= (3, 12): + remove_binary_store_slice(instructions) + if sys.version_info >= (3, 13): + remove_fused_load_store(instructions) + update_offsets(instructions) + devirtualize_jumps(instructions) + return instructions + + +_unique_id_counter = itertools.count() + + +def unique_id(name) -> str: + return f"{name}_{next(_unique_id_counter)}" + + +def is_generator(code: types.CodeType) -> bool: + co_generator = 0x20 + return (code.co_flags & co_generator) > 0 + + +def bytecode_from_template(fn, varname_map=None, noreturn=True, noprefix=True): + """Generates bytecode from a template function `fn` for use in + dynamo bytecode generation. + + For example, we can generate Python-version-independent bytecode + for looping through a dictionary and copying the values to a new dictionary. + + def template(d1, d2): + for k, v in d1.items(): + d2[k] = v + + + or a try block: + + def template(): + try: + dummy1 + except: + dummy2 + raise + dummy3 + + Args: + fn: a function template to generate bytecode from + varname_map: a mapping of `fn`'s varnames to new names. This + map will be applied to the generated bytecode's varnames. + For example, local variables in `fn` can be replaced with + new names that are generated by `OutputGraph.new_var`. + noreturn: remove all RETURN_* bytecodes and replace them with a jump + to the end of the bytecode. + noprefix: remove prefix bytecodes (all bytecode before the first RESUME, inclusive). + """ + insts = cleaned_instructions(fn.__code__) + clear_instruction_args(insts) + + if noprefix: + for i, inst in enumerate(insts): + if inst.opname == "RESUME": + insts = insts[i + 1 :] + break + + for inst in insts: + # If we don't reset starts_line, then the generated + # bytecode's line number will be based on fn's. + inst.starts_line = None + if varname_map and inst.argval in varname_map: + inst.argval = varname_map[inst.argval] + + if noreturn: + if sys.version_info >= (3, 12): + # replace RETURN_CONST with LOAD_CONST RETURN_VALUE + new_insts = [] + for inst in insts: + if inst.opname == "RETURN_CONST": + inst.opcode = dis.opmap["LOAD_CONST"] + inst.opname = "LOAD_CONST" + new_insts.append(inst) + # no need to propagate target/exn table + new_insts.append(create_instruction("RETURN_VALUE")) + else: + new_insts.append(inst) + insts = new_insts + + returns = [] + for inst in insts: + if inst.opname == "RETURN_VALUE": + returns.append(inst) + + if len(returns) == 1 and returns[0] is insts[-1]: + # only 1 return at the end - just pop it + insts.pop(-1) + elif len(returns) > 0: + # create jump target - if the last inst is a return, + # we can replace it with a NOP and make that the jump target. 
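+            # Rough before/after sketch (opcode names only, purely illustrative):
+            #   before: ...  RETURN_VALUE  ...  RETURN_VALUE
+            #   after:  ...  JUMP(->NOP)   ...  NOP    <- the final return becomes the target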
+ if insts[-1] is returns[-1]: + insts[-1].opname = "NOP" + insts[-1].opcode = dis.opmap["NOP"] + insts[-1].arg = None + insts[-1].argval = _NotProvided + returns.pop(-1) + else: + insts.append(create_instruction("NOP")) + + # replace returns with jumps + for inst in returns: + # don't replace inst with new instruction + # due to targetting/exn table/etc. + jump_inst = create_jump_absolute(insts[-1]) + inst.opname = jump_inst.opname + inst.opcode = jump_inst.opcode + inst.arg = jump_inst.arg + inst.argval = jump_inst.argval + inst.target = jump_inst.target + + return insts diff --git a/lib/python3.10/site-packages/torch/_dynamo/cache_size.py b/lib/python3.10/site-packages/torch/_dynamo/cache_size.py new file mode 100644 index 0000000000000000000000000000000000000000..c5a793aa06c47ebf800ae705d85cbcc484de9c46 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/cache_size.py @@ -0,0 +1,185 @@ +# mypy: allow-untyped-defs +import logging +import types +import weakref +from dataclasses import dataclass +from typing import Tuple + +from torch._guards import CompileId + +from . import config + + +log = logging.getLogger(__name__) +""" +[Note on cache size limit] + +Background - TorchDynamo cache is a linked list. Each cache entry is a +(check_fn, out_code, next pointer). These are stored on the f_code's co_extra +scratch space. When a frame is invoked, we walk this linked list and run +check_fn in each cache_entry to decide if the frame needs recompilation. If none +of the check_fn's returns True, we recompile and add a new entry. To ensure we +don't end up recompiling infinitely, we put limits on the cache size. + +There are two limits +1) cache_size_limit +2) accumulated_cache_size_limit + + +Earlier we used to have only limit - maximum number of entries in 1 cache line +(which is now represented by (2) above). So, why do we need two limits? Lets try +to understand that. + +In general, we want our cache limit value to be a small number (e.g. 8 or even +lower). This ensures that for frames that cause too many recompilation fall to +eager quickly. However, there is another problem that prevents us from lowering +the value of cache_size_limit. This is due to ID_MATCH'd guards. Today, we put +ID_MATCH guards on nn module if there is a graph break. This means we will have +many recompilations for the same code object because the ID_MATCH guard fails +for different instances of the nn module. This is a common pattern in how models +are authored. Therefore, this requires us to keep the cache_size_limit high. + +We resolve this by introducing these two limits. The first limit (1) limits the +number of cache entries that have an ID_MATCH'd guard for an nn module instance. +And, (2)nd limit becomes a safeguard mechanism to have a maximum compilations +for a code object. One important question is - what is the limit for the code +object that does not have any ID_MATCH guard? For such code objects, we choose +(1) as the cache size limit. + +Lets take an example to understand how these limits help. Suppose, we have 16 +instances of a nn module and we ID_MATCH on the self object. Further, suppose +the inputs to these functions have varying batch size, leading to one +recompilation. In total, there will be 32 recompilations, and therefore 32 cache +entries on the forward code object. In the older case when we had only 1 limit, +our cache size limit must be >= 32 to capture all these recompilations. 
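+(That is, 16 nn module instances x 2 compilations each = 32 cache entries on one
+code object.)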
Now, +suppose there is a separate function in the same program which is very dynamic +and unsuitable for compilation. Such a function will need to undergo 32 +compilations to burst the cache and fallback to eager. These 32 recompilations +are too many and we want to fallback for these compilation-unfriendly functions +sooner. + +In the new scenario, we can have (1) cache_size_limit = 2, (2) +accumulated_cache_size_limit = 32. This means that each ID_MATCH'd object can +have maximum of two cache entries, and the maximum number of cache entries +(irrespective of ID_MATCH obj) is 32. This covers the case of forward code +object which has 32 recompilations. For the other function, the one unsuitable +for recompilation, our limit is 2. So, we will burst the cache in just 2 +recompilations. In this manner, these 2 limits help us resolve the tension +mentioned earlier. +""" + + +@dataclass +class CacheSizeRelevantForFrame: + """ + We track the number of cache entries that have same id_match objects as the + given frame. + + TODO(janimesh) - Consider adding a map from tuple_of_match_ids to count - + https://github.com/pytorch/pytorch/pull/107496#discussion_r1304564682 - this + could be useful for debugging as well. + """ + + # Total number of CacheEntry objects in the Dynamo linked list + num_cache_entries: int = 0 + + # Number of CacheEntry objects having same ID_MATCH'd objects as given frame. + num_cache_entries_with_same_id_matched_objs: int = 0 + + def will_compilation_exceed(self, limit: int) -> bool: + # Checks if a compilation will exceed the given limit (thats why >=). + return ( + self.will_compilation_exceed_accumulated_limit() + or self.will_compilation_exceed_specific_limit(limit) + ) + + def will_compilation_exceed_accumulated_limit(self) -> bool: + return self.num_cache_entries >= config.accumulated_cache_size_limit + + def will_compilation_exceed_specific_limit(self, limit: int) -> bool: + return self.num_cache_entries_with_same_id_matched_objs >= limit + + +def _get_weakref_from_f_locals(frame: types.FrameType, local_name: str): + obj = frame.f_locals.get(local_name, None) + weak_id = None + try: + weak_id = weakref.ref(obj) + except TypeError: + pass # cannot weakref bool object + return weak_id + + +def _has_same_id_matched_objs(frame: types.FrameType, cache_entry) -> bool: + """ + Checks if the ID_MATCH'd objects saved on cache_entry are same as the ones + in frame.f_locals. + """ + if not cache_entry: + return False + + for ( + local_name, + weakref_from_cache_entry, + ) in cache_entry.check_fn.id_matched_objs.items(): + if weakref_from_cache_entry() is not None: + weakref_from_frame = _get_weakref_from_f_locals(frame, local_name) + if weakref_from_frame != weakref_from_cache_entry: + return False + + # Also covers the case where no ID_MATCH objects are saved in frame.f_locals + return True + + +def compute_cache_size( + frame: types.FrameType, cache_entry +) -> CacheSizeRelevantForFrame: + # Walk the linked list to calculate the cache size + num_cache_entries = 0 + num_cache_entries_with_same_id_matched_objs = 0 + + while cache_entry: + num_cache_entries += 1 + # Track the number of cache entries having same ID_MATCH'd objects as + # that of frame.f_locals. This will be used later to compare against the + # cache_size_limit. 
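+        # Hypothetical outcome: with 3 entries in the linked list, 1 of which guards
+        # the same nn.Module instance as this frame, the value returned below would be
+        # CacheSizeRelevantForFrame(num_cache_entries=3,
+        #                           num_cache_entries_with_same_id_matched_objs=1).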
+ if _has_same_id_matched_objs(frame, cache_entry): + num_cache_entries_with_same_id_matched_objs += 1 + cache_entry = cache_entry.next + + return CacheSizeRelevantForFrame( + num_cache_entries, num_cache_entries_with_same_id_matched_objs + ) + + +def is_recompilation(cache_size: CacheSizeRelevantForFrame) -> bool: + """ + If the frame (earlier parsed by compute_cache_size) has more than 1 cache + entry with same ID_MATCH'd objects, then its a recompilation. + """ + # Note that you can have multiple entries in the cache but still not a + # recompile, e.g., you can have 64 nn module instances, each one having an + # ID_MATCH guard, and each one having just 1 cache entry in the cache. In + # this case, we can have 64 entries in the cache, but no recompilation + # because there is only one entry for each id_matched_obj. + return cache_size.will_compilation_exceed(1) + + +def exceeds_cache_size_limit( + cache_size: CacheSizeRelevantForFrame, compile_id: CompileId +) -> Tuple[bool, str]: + """ + Checks if we are exceeding the cache size limit. + """ + if cache_size.will_compilation_exceed_accumulated_limit(): + return True, "accumulated_cache_size_limit" + if cache_size.will_compilation_exceed_specific_limit(config.cache_size_limit): + return True, "cache_size_limit" + # NOTE this check is needed in the case that the frame's cache doesn't grow + # and we keep recompiling. This can happen if the guard check_fn becomes invalidated, + # e.g. due to guarded objects being freed. This technically makes the + # will_compilation_exceed_accumulated_limit check unnecessary, but we will keep the + # check in case we have a better fix in the future. + if compile_id.frame_compile_id >= config.accumulated_cache_size_limit: + return True, "accumulated_cache_size_limit" + return False, "" diff --git a/lib/python3.10/site-packages/torch/_dynamo/callback.py b/lib/python3.10/site-packages/torch/_dynamo/callback.py new file mode 100644 index 0000000000000000000000000000000000000000..35f447a8034903833d142ad4225995bdded9e3a1 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/callback.py @@ -0,0 +1,83 @@ +# mypy: allow-untyped-defs +class CompilationCallbackHandler: + def __init__(self): + self.start_callbacks = [] + self.end_callbacks = [] + + def register_start_callback(self, callback): + """ + Register a callback function to be called when the compilation starts. + + Args: + - callback (callable): The callback function to register. + """ + self.start_callbacks.append(callback) + return callback + + def register_end_callback(self, callback): + """ + Register a callback function to be called when the compilation ends. + + Args: + - callback (callable): The callback function to register. + """ + self.end_callbacks.append(callback) + return callback + + def remove_start_callback(self, callback): + """ + Remove a registered start callback function. + + Args: + - callback (callable): The callback function to remove. + """ + self.start_callbacks.remove(callback) + + def remove_end_callback(self, callback): + """ + Remove a registered end callback function. + + Args: + - callback (callable): The callback function to remove. + """ + self.end_callbacks.remove(callback) + + def run_start_callbacks(self): + """ + Execute all registered start callbacks. + """ + for callback in self.start_callbacks: + callback() + + def run_end_callbacks(self): + """ + Execute all registered end callbacks. + """ + for callback in self.end_callbacks: + callback() + + def clear(self): + """ + Clear all registered callbacks. 
+ """ + self.start_callbacks.clear() + self.end_callbacks.clear() + + +callback_handler = CompilationCallbackHandler() + + +def on_compile_start(callback): + """ + Decorator to register a callback function for the start of the compilation. + """ + callback_handler.register_start_callback(callback) + return callback + + +def on_compile_end(callback): + """ + Decorator to register a callback function for the end of the compilation. + """ + callback_handler.register_end_callback(callback) + return callback diff --git a/lib/python3.10/site-packages/torch/_dynamo/code_context.py b/lib/python3.10/site-packages/torch/_dynamo/code_context.py new file mode 100644 index 0000000000000000000000000000000000000000..727aad9349555f363a727c5200c22c044c0a5083 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/code_context.py @@ -0,0 +1,30 @@ +# mypy: allow-untyped-defs +import types + +from .utils import ExactWeakKeyDictionary + + +class CodeContextDict: + def __init__(self) -> None: + self.code_context = ExactWeakKeyDictionary() + + def has_context(self, code: types.CodeType): + return code in self.code_context + + def get_context(self, code: types.CodeType): + ctx = self.code_context.get(code) + if ctx is None: + ctx = {} + self.code_context[code] = ctx + return ctx + + def pop_context(self, code: types.CodeType): + ctx = self.get_context(code) + self.code_context._remove_id(id(code)) + return ctx + + def clear(self): + self.code_context.clear() + + +code_context = CodeContextDict() diff --git a/lib/python3.10/site-packages/torch/_dynamo/codegen.py b/lib/python3.10/site-packages/torch/_dynamo/codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..5cc8361a974eb65de316009be7dcae05626056eb --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/codegen.py @@ -0,0 +1,511 @@ +# mypy: allow-untyped-defs +import collections +import dataclasses +import re +import sys +import types +from typing import Counter, Dict, List, Optional + +import torch.nn + +from . 
import utils +from .bytecode_transformation import ( + add_push_null, + add_push_null_call_function_ex, + create_call_function, + create_call_method, + create_dup_top, + create_instruction, + create_load_method, + create_rot_n, + Instruction, +) +from .exc import unimplemented +from .source import AttrSource, Source +from .utils import is_safe_constant, rot_n_helper +from .variables.base import VariableTracker +from .variables.nn_module import NNModuleVariable +from .variables.tensor import ( + NumpyNdarrayVariable, + SymNodeVariable, + TensorVariable, + UnspecializedPythonVariable, +) +from .variables.torch_function import TensorWithTFOverrideVariable + + +@dataclasses.dataclass +class GraphOutputEntry: + index: int + variable: VariableTracker + + +class PyCodegen: + """ + Helper class uses for constructing Python bytecode + """ + + def __init__( + self, + tx=None, + root: Optional[torch.nn.Module] = None, + graph_output_var: Optional[str] = None, + tempvars=None, + ) -> None: + self.root = root + self.top_of_stack: Optional[VariableTracker] = None + self.uses: Counter[VariableTracker] = collections.Counter() + self.graph_outputs: Dict[int, GraphOutputEntry] = {} + self._output: List[Instruction] = [] + self.tempvars = tempvars or {} + self.tx = tx + self.graph_output_var = graph_output_var + self.code_options = self.tx.output.code_options + self.cell_and_freevars = self.tx.cell_and_freevars + self.new_var = self.tx.output.new_var + self.mutable_side_effects_from_source = False + self.value_from_source: bool = True + + def restore_stack(self, stack_values, *, value_from_source=True): + prior = self.mutable_side_effects_from_source + self.mutable_side_effects_from_source = True + prev = self.value_from_source + self.value_from_source &= value_from_source + try: + self.foreach(stack_values) + finally: + self.mutable_side_effects_from_source = prior + self.value_from_source = prev + + def graph_output_vars(self): + return [x.variable for x in self.graph_outputs.values()] + + def call_reconstruct(self, value): + res = value.reconstruct(self) + assert res is None, f"reconstruct!=None {value}" + + def add_push_null(self, gen_fn, call_function_ex=False): + """ + `gen_fn` generates instructions via PyCodegen methods + that push a single callable to the stack. + + `add_push_null` pushes a NULL to the stack before or after the + instructions generated by `gen_fn`, depending on Python version. + + Will attempt to use the NULL push bit for instructions + with such bits (LOAD_GLOBAL 3.11+, LOAD_ATTR 3.12+, LOAD_SUPER_ATTR). + """ + old_len = len(self._output) + if sys.version_info < (3, 13): + # gen_fn may DUP_TOP instead if TOS is not cleared. 
+ # Will cause problems since NULL will be pushed right + # before the generated instructions in <= 3.12 + self.clear_tos() + gen_fn() + # inplace modify self._output + added_insts = self._output[old_len:] + del self._output[old_len:] + if call_function_ex: + self._output.extend(add_push_null_call_function_ex(added_insts)) + else: + self._output.extend(add_push_null(added_insts)) + if sys.version_info >= (3, 13): + # NULL will be at top of stack + self.clear_tos() + + def __call__(self, value, allow_cache=True): + """Generate code such that top-of-stack (TOS) is set to value""" + if isinstance(value, Source): + self.call_reconstruct(value) + self.clear_tos() + return + + assert isinstance(value, VariableTracker) + output = self._output + graph_outputs = self.graph_outputs + + if self.top_of_stack is value and allow_cache: + output.append(create_dup_top()) + return + + if self.mutable_side_effects_from_source: + # this is needed to get aliasing relationships right + # value.mutable_local.source will get mutated to hold `value` + # mutable_side_effects_from_source=False is used to codegen the mutation + # mutable_side_effects_from_source=True is used to codegen a reference + from .side_effects import MutableSideEffects + + if isinstance(value.mutable_local, MutableSideEffects): + self(value.mutable_local.source) + return + + if allow_cache: + if value.mutable_local and value.mutable_local in self.tempvars: + output.append(self.create_load(self.tempvars[value.mutable_local])) + self.top_of_stack = value + return + if self.tempvars.get(value) is not None: + output.append(self.create_load(self.tempvars[value])) + self.top_of_stack = value + return + + if value.source is not None and allow_cache and self.value_from_source: + self.call_reconstruct(value.source) + elif value.is_python_constant() and is_safe_constant( + value.as_python_constant() + ): + output.append(self.create_load_const(value.as_python_constant())) + elif isinstance(value, TensorWithTFOverrideVariable): + graph_outputs_key = self.add_graph_output(value) + + self.add_push_null( + lambda: self.load_import_from(utils.__name__, "to_subclass") + ) + self.load_graph_output(graph_outputs[graph_outputs_key].index) + output.append( + self.create_load_global( + value.global_mangled_class_name(self.tx), add=True + ) + ) + output.extend(create_call_function(2, False)) + elif ( + isinstance(value, SymNodeVariable) + and value.python_type() == float + and not self.tx.export + ): + # This is a little unusual; force the output convention to be a + # Tensor here. Don't do this for export because this is + # apparently load bearing for export tests (but I am a bit + # doubtful it actually works in the real world) + # NB: It works to add_graph_output on a computed expression + # as_tensor here, because we memoize as_tensor calls on + # SymNodeVariable! 
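+            # Rough shape of the emitted bytecode (illustrative only; where the NULL
+            # lands depends on the Python version):
+            #   LOAD_FAST graph_output_var; LOAD_CONST idx; BINARY_SUBSCR;
+            #   LOAD_ATTR "item"; CALL 0   -> a plain Python float is left on the stack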
+ graph_outputs_key = self.add_graph_output(value.as_tensor(self.tx)) + + def gen_fn(): + self.load_graph_output(graph_outputs[graph_outputs_key].index) + output.append(self.create_load_attr("item")) + + self.add_push_null(gen_fn) + output.extend(create_call_function(0, False)) + elif isinstance( + value, + ( + TensorVariable, + SymNodeVariable, + UnspecializedPythonVariable, + NumpyNdarrayVariable, + ), + ): + graph_outputs_key = self.add_graph_output(value) + + if isinstance(value, NumpyNdarrayVariable): + self.add_push_null( + lambda: self.load_import_from(utils.__name__, "to_numpy_helper") + ) + self.load_graph_output(graph_outputs[graph_outputs_key].index) + output.extend(create_call_function(1, False)) + elif isinstance(value, UnspecializedPythonVariable) and value.need_unwrap: + + def gen_fn(): + self.load_graph_output(graph_outputs[graph_outputs_key].index) + output.append(self.create_load_attr("item")) + + self.add_push_null(gen_fn) + output.extend(create_call_function(0, False)) + else: + self.load_graph_output(graph_outputs[graph_outputs_key].index) + elif isinstance(value, NNModuleVariable): + parts = value.module_key.split(".") + if parts[0] in self.code_options["co_varnames"]: + output.append(self.create_load(parts[0])) + parts = parts[1:] + else: + assert self.root is not None + output.append(self.create_load_output(self.root)) + for part in parts: + output.append(self.create_load_attr(part)) + else: + self.uses[value] += 1 + try: + self.call_reconstruct(value) + except NotImplementedError: + unimplemented(f"reconstruct: {value}") + if allow_cache and value in self.tempvars: + self._output.append(create_dup_top()) + self.add_cache(value) + + self.top_of_stack = value + + def add_graph_output(self, value): + graph_outputs_key = id(value.as_proxy()) + if graph_outputs_key not in self.graph_outputs: + self.graph_outputs[graph_outputs_key] = GraphOutputEntry( + len(self.graph_outputs), value + ) + return graph_outputs_key + + def load_graph_output(self, index): + output = self._output + output.append(self.create_load(self.graph_output_var)) + output.append(self._create_load_const(index)) + output.append(create_instruction("BINARY_SUBSCR")) + + def add_cache(self, value): + var = self.new_var() + self.tempvars[value] = var + if value.mutable_local: + self.tempvars[value.mutable_local] = var + self._output.append(self.create_store(var)) + + def foreach(self, items): + for i in items: + self(i) + + def setup_globally_cached(self, name, value): + """Store value in a new global""" + name = re.sub(r"[^a-zA-Z0-9_]+", "_", name) + f_globals = self.tx.f_globals + if name in f_globals: + assert id(f_globals[name]) == id(value) + else: + f_globals[name] = value + return [self.create_load_global(name, add=True)] + + def clear_tos(self): + self.top_of_stack = None + + def append_output(self, inst): + assert isinstance(inst, Instruction) + self._output.append(inst) + self.clear_tos() + + def extend_output(self, insts): + assert all(isinstance(x, Instruction) for x in insts) + self._output.extend(insts) + self.clear_tos() + + def get_instructions(self) -> List[Instruction]: + return self._output + + def create_load(self, name) -> Instruction: + if name in self.cell_and_freevars(): + return create_instruction("LOAD_DEREF", argval=name) + assert name in self.code_options["co_varnames"], f"{name} missing" + return create_instruction("LOAD_FAST", argval=name) + + def create_load_closure(self, name) -> Instruction: + assert name in self.cell_and_freevars() + inst_name = "LOAD_FAST" if 
sys.version_info >= (3, 13) else "LOAD_CLOSURE" + return create_instruction(inst_name, argval=name) + + def create_store(self, name) -> Instruction: + if name in self.cell_and_freevars(): + return create_instruction("STORE_DEREF", argval=name) + assert name in self.code_options["co_varnames"] + return create_instruction("STORE_FAST", argval=name) + + def create_load_global(self, name, add=False) -> Instruction: + if add: + self.tx.output.update_co_names(name) + assert name in self.code_options["co_names"], f"{name} not in co_names" + return create_instruction("LOAD_GLOBAL", argval=name) + + def create_load_const(self, value) -> Instruction: + assert is_safe_constant(value), f"unsafe constant {value}" + return self._create_load_const(value) + + def _create_load_const(self, value) -> Instruction: + return create_instruction("LOAD_CONST", argval=value) + + create_load_output = _create_load_const + + def load_method(self, name): + self.tx.output.update_co_names(name) + self.append_output(create_load_method(name)) + + def call_method(self, nargs): + self.extend_output(create_call_method(nargs)) + + def create_load_attr(self, name) -> Instruction: + if name not in self.code_options["co_names"]: + self.code_options["co_names"] += (name,) + return create_instruction("LOAD_ATTR", argval=name) + + def load_attr(self, name): + self.append_output(self.create_load_attr(name)) + + def create_load_attrs(self, names): + return [self.create_load_attr(name) for name in names.split(".")] + + def create_store_attr(self, name) -> Instruction: + if name not in self.code_options["co_names"]: + self.code_options["co_names"] += (name,) + return create_instruction("STORE_ATTR", argval=name) + + def store_attr(self, name): + self.append_output(self.create_store_attr(name)) + + def load_function_name(self, fn_name, push_null, num_on_stack=0): + """Load the global fn_name on the stack num_on_stack down""" + output = [] + if push_null and sys.version_info >= (3, 11): + output.extend(add_push_null(self.create_load_global(fn_name, add=True))) + if num_on_stack > 0: + output.extend( + [ + *self.rot_n(num_on_stack + 2), + *self.rot_n(num_on_stack + 2), + ] + ) + else: + output.extend( + [ + self.create_load_global(fn_name, add=True), + *self.rot_n(num_on_stack + 1), + ] + ) + return output + + def rot_n(self, n): + try: + return create_rot_n(n) + except AttributeError: + # desired rotate bytecode doesn't exist, generate equivalent bytecode + return [ + create_instruction("BUILD_TUPLE", arg=n), + self._create_load_const(rot_n_helper(n)), + *create_rot_n(2), + create_instruction("CALL_FUNCTION_EX", arg=0), + create_instruction("UNPACK_SEQUENCE", arg=n), + ] + + def pop_null(self): + # POP_TOP doesn't work for null, so we pop nulls by pushing in a + # nop function, calling it (which consumes the null), and popping the result. 
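+        # Stack sketch (3.11/3.12, illustrative): [NULL] -> [NULL, <lambda>]
+        # -> [None] after the CALL (which consumes both the NULL and the callable)
+        # -> [] after POP_TOP.  On 3.13 the SWAP below reorders callable/NULL first.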
+ assert sys.version_info >= (3, 11) + return [ + self._create_load_const(lambda: None), + # 3.13 swapped NULL and callable + *( + (create_instruction("SWAP", arg=2),) + if sys.version_info >= (3, 13) + else () + ), + *create_call_function(0, False), + create_instruction("POP_TOP"), + ] + + def pop_top(self): + self.append_output(create_instruction("POP_TOP")) + + def call_function(self, nargs: int, push_null: bool): + self.extend_output(create_call_function(nargs, push_null=push_null)) + + def dup_top(self): + self.append_output(create_dup_top()) + + def store(self, varname): + self.append_output(self.create_store(varname)) + + def make_function_with_closure( + self, fn_name: str, code: types.CodeType, push_null: bool, num_on_stack=0 + ): + freevars = code.co_freevars + assert freevars + output = self._output + + def gen_fn(): + for var in freevars: + assert var in self.cell_and_freevars() + inst_name = ( + "LOAD_FAST" if sys.version_info >= (3, 13) else "LOAD_CLOSURE" + ) + output.append(create_instruction(inst_name, argval=var)) + output.append(create_instruction("BUILD_TUPLE", arg=len(freevars))) + output.append(self.create_load_const(code)) + if sys.version_info < (3, 11): + output.append(self.create_load_const(fn_name)) + if sys.version_info >= (3, 13): + output.extend( + [ + create_instruction("MAKE_FUNCTION"), + create_instruction("SET_FUNCTION_ATTRIBUTE", arg=0x08), + ] + ) + else: + output.append(create_instruction("MAKE_FUNCTION", arg=0x08)) + + if push_null and sys.version_info >= (3, 11): + self.add_push_null(gen_fn) + output.extend(self.rot_n(num_on_stack + 2)) + output.extend(self.rot_n(num_on_stack + 2)) + else: + gen_fn() + output.extend(self.rot_n(num_on_stack + 1)) + self.clear_tos() + + def create_load_python_module(self, mod) -> Instruction: + """ + Generate a LOAD_GLOBAL instruction to fetch a given python module. 
+ """ + output = self.tx.output + global_scope = output.global_scope + name = re.sub(r"^.*[.]", "", mod.__name__) + if global_scope.get(name, None) is mod: + return self.create_load_global(name, add=True) + prefix = f"___module_{name}" + global_name = self.tx.output.install_global_by_id(prefix, mod) + return self.create_load_global(global_name, add=True) + + def make_call_generated_code(self, fn_name: str) -> None: + """Call the generated code function stored in fn_name""" + self.extend_output(self.load_function_name(fn_name, True)) + + graphargs = self.tx.output.graphargs + for arg in graphargs: + if arg.pass_arg_as_tensor: + self.add_push_null( + lambda: self.extend_output( + [ + self.create_load_python_module(torch), + self.create_load_attr("as_tensor"), + ] + ) + ) + self.call_reconstruct(arg) + self.extend_output(create_call_function(1, False)) + else: + self.call_reconstruct(arg) + + self.extend_output(create_call_function(len(graphargs), False)) + + def load_import_from(self, module_name, object_name) -> None: + self(AttrSource(self.tx.import_source(module_name), object_name)) + + def create_call_function_kw(self, nargs, kw_names, push_null) -> List[Instruction]: + if sys.version_info >= (3, 13): + output = create_call_function(nargs, push_null) + assert output[-1].opname == "CALL" + output.insert(-1, self.create_load_const(kw_names)) + output[-1] = create_instruction("CALL_KW", arg=nargs) + return output + elif sys.version_info >= (3, 11): + output = create_call_function(nargs, push_null) + if sys.version_info >= (3, 12): + idx = -1 + expected_inst = "CALL" + else: + idx = -2 + expected_inst = "PRECALL" + assert output[idx].opname == expected_inst + kw_names_inst = create_instruction("KW_NAMES", argval=kw_names) + output.insert(idx, kw_names_inst) + return output + return [ + self.create_load_const(kw_names), + create_instruction("CALL_FUNCTION_KW", arg=nargs), + ] + + def create_delete(self, value) -> Instruction: + return create_instruction("DELETE_FAST", argval=value) diff --git a/lib/python3.10/site-packages/torch/_dynamo/compiled_autograd.py b/lib/python3.10/site-packages/torch/_dynamo/compiled_autograd.py new file mode 100644 index 0000000000000000000000000000000000000000..e7c5d2414f6e2b8b5a4610ee4f2e192d89f1cca1 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/compiled_autograd.py @@ -0,0 +1,533 @@ +# mypy: allow-untyped-defs +import contextlib +import functools +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union + +import torch +from torch._dynamo.external_utils import ( + call_backward, + call_hook, + FakeCompiledAutogradEngine, +) +from torch._dynamo.source import GetItemSource, LocalSource +from torch._dynamo.utils import counters, lazy_format_graph_code, set_locals_to_steal +from torch._logging import getArtifactLogger, trace_structured +from torch._prims_common import clone_preserve_strides +from torch._subclasses import FakeTensorMode +from torch.fx import GraphModule +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.proxy_tensor import ( + decompose, + disable_autocast_cache, + disable_proxy_modes_tracing, + fetch_object_proxy, + ProxyTorchDispatchMode, + PythonKeyTracer, + track_tensor_tree, +) +from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv +from torch.fx.traceback import preserve_node_meta, set_stack_trace +from torch.utils._traceback import CapturedTraceback + + +if TYPE_CHECKING: + from torch.fx.proxy import Proxy + + +compiled_autograd_log = 
getArtifactLogger(__name__, "compiled_autograd") +verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose") + + +def snapshot_verbose_logging_enabled(): + return torch._logging._internal.log_state.is_artifact_enabled( + "compiled_autograd_verbose" + ) + + +def cpp_verbose_log_fn(msg: str) -> None: + verbose_log.debug(msg) + + +def snapshot_cudagraph_enabled(): + return torch._inductor.config.triton.cudagraphs + + +def maybe_clone(x): + if x is not None: + return clone_preserve_strides(x) + return x + + +class AutogradCompilerInstance: + def __init__(self, compiler_fn) -> None: + self.compiler_fn = compiler_fn + self.stack = contextlib.ExitStack() + self.close = self.stack.close + self.shape_env = ShapeEnv() + self.fake_tensor_mode = FakeTensorMode( + allow_fallback_kernels=True, + allow_non_fake_inputs=True, + shape_env=self.shape_env, + ) + self.fx_tracer = PythonKeyTracer() + self.proxy_mode = ProxyTorchDispatchMode(self.fx_tracer, "symbolic") + self.hooks_proxy: Optional[Proxy] = None + self.graph_placeholders = ["inputs", "sizes", "scalars", "hooks"] + + def wrap_fake(self, x, source): + assert isinstance(x, torch.Tensor) + return self.fake_tensor_mode.from_tensor(x, source=source) + + @staticmethod + def source(name, idx) -> GetItemSource: + return GetItemSource(LocalSource(name), idx) + + def begin_capture( + self, + inputs: List[torch.Tensor], + sizes: List[int], + scalars: List[Union[int, float]], + ): + counters["compiled_autograd"]["captures"] += 1 + self.aot_graph_cls_name: Optional[str] = None + self.aot_graph_infos: Dict[int, Dict[str, Any]] = {} + self.fx_tracer.root = torch.nn.Module() + self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer) + self.fx_tracer.tensor_attrs = {} + args_proxy, sizes_proxy, scalars_proxy, self.hooks_proxy = ( + self.fx_tracer.create_proxy("placeholder", name, (), {}) + for name in self.graph_placeholders + ) + + # tensor inputs to fake tensors + inputs = [ + self.wrap_fake(x, self.source("inputs", idx)) + for idx, x in enumerate(inputs) + ] + self.bind_tensors_to_proxies(inputs, args_proxy) + + # size inputs to symints + sizes = [ + self.shape_env.create_unspecified_symint_and_symbol( + val, + self.source("sizes", idx), + DimDynamic.DYNAMIC, + ) + for idx, val in enumerate(sizes) + ] + self.bind_tensors_to_proxies(sizes, sizes_proxy) + + for idx, val in enumerate(scalars): + source = self.source("scalars", idx) + if isinstance(val, int): + scalars[idx] = self.shape_env.create_unspecified_symint_and_symbol( + val, + source, + DimDynamic.DYNAMIC, + ) + elif isinstance(val, float): + scalars[idx] = self.shape_env.create_symfloatnode( + self.shape_env.create_unspecified_symbol( + val, + source=source, + dynamic_dim=DimDynamic.DYNAMIC, + ), + hint=val, + source=source, + ) + else: + raise AssertionError("Unexpected scalar type: ", type(val)) + self.bind_tensors_to_proxies(scalars, scalars_proxy) + + # TODO(jansel): are all these modes needed? 
+ self.stack.enter_context(decompose({})) + self.stack.enter_context(self.fake_tensor_mode) + self.stack.enter_context(self.proxy_mode) + self.stack.enter_context(disable_autocast_cache()) + self.stack.enter_context(preserve_node_meta()) + return inputs, sizes, scalars + + def proxy_call_backward( + self, + inputs, + output_metadatas, + saved_tensors, + backward_idx: int, + ): + assert self.hooks_proxy is not None + backward_c_function = self.hooks_proxy[backward_idx] # type: ignore[index] + proxies = self.fx_tracer.create_proxy( + kind="call_function", + target=call_backward, + args=( + backward_c_function, + self.to_proxy(saved_tensors), + *self.to_proxy(inputs), + ), + kwargs={}, + ) + + with disable_proxy_modes_tracing(): + # create fake Tensors + grad_ins: List[Optional[torch.Tensor]] = [] + for output_metadata in output_metadatas: + if output_metadata is None: + grad_ins.append(None) + continue + + layout, device, dtype, size = output_metadata + grad_ins.append( + torch.empty(size=size, dtype=dtype, layout=layout, device=device) + ) + self.bind_tensors_to_proxies(grad_ins, proxies) + return tuple(grad_ins) + + def proxy_call_hook(self, hook, *args, **kwargs): + return self.fx_tracer.create_proxy( + "call_function", + call_hook, + ( + hook, + *[self.to_proxy(x) for x in args], + ), + kwargs, + ) + + def tensor_pre_hook(self, inputs, hook_id, i: int): + assert self.hooks_proxy is not None + hook = self.hooks_proxy[hook_id] # type: ignore[index] + proxy = self.proxy_call_hook( + hook, + inputs[i], + hook_type="tensor_pre_hook", + ) + with disable_proxy_modes_tracing(): + inputs[i] = maybe_clone(inputs[i]) + self.bind_tensors_to_proxies([inputs[i]], [proxy]) + return inputs + + def pre_hook(self, inputs, hook_id): + assert self.hooks_proxy is not None + hook = self.hooks_proxy[hook_id] # type: ignore[index] + proxies = self.proxy_call_hook( + hook, + inputs, + hook_type="pre_hook", + ) + with disable_proxy_modes_tracing(): + inputs = [maybe_clone(x) for x in inputs] + self.bind_tensors_to_proxies(inputs, proxies) + return inputs + + def post_hook(self, outputs, inputs, hook_id): + assert self.hooks_proxy is not None + hook = self.hooks_proxy[hook_id] # type: ignore[index] + proxies = self.proxy_call_hook( + hook, + outputs, + inputs, + hook_type="post_hook", + ) + with disable_proxy_modes_tracing(): + outputs = [maybe_clone(x) for x in outputs] + self.bind_tensors_to_proxies(outputs, proxies) + return outputs + + def post_acc_grad_hook(self, input, hook_id): + assert isinstance(input, torch.Tensor) + assert self.hooks_proxy is not None + hook = self.hooks_proxy[hook_id] # type: ignore[index] + proxy = self.proxy_call_hook( + hook, + input, + hook_type="post_acc_grad_hook", + ) + with disable_proxy_modes_tracing(): + input = [maybe_clone(input)] + self.bind_tensors_to_proxies(input, [proxy]) + return input + + # Note: [Compiled autograd and cudagraphs] + # Eager autograd backward implements scalars as 0-dim tensors, see DivBackward0::other_. + # When compiled autograd traces those nodes, it lifts the scalar tensors, resulting in a graph + # with some cpu 0-dim tensor inputs. To prevent the entire graph from skipping cudagraph, we move the + # scalars tensors to cuda. This works because ATen/prims ops will accept cuda 0-dim tensors too. 
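+    # Concretely (hypothetical graph): a cpu 0-dim input like tensor(2.0) has its
+    # node.meta["val"] retagged as cuda here, and its input index is returned so the
+    # runtime wrapper can later do inputs[i].pin_memory().cuda(non_blocking=True).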
+ def move_graph_nodes_to_cuda(self, graph) -> List[int]: + to_move: Dict[int, torch.fx.Node] = {} + has_cuda_inputs = False + nodes = list(graph.nodes) + assert nodes[0].target == "inputs" + inputs = nodes[0] + inputs_users = list(inputs.users.keys()) + # input access nodes should immediately follow placeholder nodes + first_getitem_idx = len(self.graph_placeholders) + assert nodes[first_getitem_idx] == inputs_users[0] + last_getitem_idx = first_getitem_idx + len(inputs_users) - 1 + assert nodes[last_getitem_idx] == inputs_users[-1] + for i, node in enumerate(inputs_users): + if not has_cuda_inputs and node.meta["val"].device.type == "cuda": + has_cuda_inputs = True + continue + + is_cpu = node.meta["val"].device.type == "cpu" + is_scalar = len(node.meta["val"].size()) == 0 + if is_cpu and is_scalar: + node_users = list(node.users.keys()) + if all( + isinstance(user.target, torch._ops.OpOverload) + and user.target.namespace in ("prims", "aten") + for user in node_users + ): + # all users are prims/aten, can move safely + to_move[i] = node + + # only move cpu scalars to cuda if there were cuda activations in this graph, + # this is to handle the case where cudagraphs is enabled on a cpu-only graph + if has_cuda_inputs: + for node in to_move.values(): + node.meta["val"] = node.meta["val"].cuda() + + # return runtime indices we need to move to cuda + return list(to_move.keys()) + + return [] + + def end_capture(self, outputs): + self.fx_tracer.create_proxy( + "call_function", + FakeCompiledAutogradEngine._exec_final_callbacks_stub, + (), + {}, + ) + self.stack.close() + self.fx_tracer.create_node( + "output", + "output", + (self.fx_tracer.create_arg(self.to_proxy(outputs)),), + {}, + ) + self.rename_aot_dispatcher_nodes() + self.reorder_accumulate_grad_nodes() + runtime_inputs_to_move: List[int] = [] + if snapshot_cudagraph_enabled(): + runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph) + + graph = GraphModule( + self.fx_tracer.root, self.fx_tracer.graph, "CompiledAutograd" + ) + set_locals_to_steal(graph, ["inputs"]) + lazy_graph_code = lazy_format_graph_code( + "Compiled autograd graph", + graph, + include_device=True, + include_stride=True, + colored=True, + ) + compiled_autograd_log.info("%s", lazy_graph_code) + verbose_log.debug("%s", lazy_graph_code) + trace_structured( + "compiled_autograd_graph", + payload_fn=lambda: graph.print_readable(print_output=False), + ) + + def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks): + global in_compiled_autograd_region + try: + in_compiled_autograd_region = True + for i in runtime_inputs_to_move: + inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True) + + return compiled_fn(inputs, sizes, scalars, hooks) + finally: + in_compiled_autograd_region = False + + return runtime_wrapper, self.compiler_fn(graph) + + def rename_aot_dispatcher_nodes(self): + """ + Renames nodes as they appear in the AOTDispatcher backward graphs, prefixed by AOT id + e.g. 
AOTDispatcher backward graph X's `sin_Y` -> `aotX_sin_Y` + """ + if self.aot_graph_cls_name is None: + return + + def is_similar(a: torch.fx.node.Node, b: torch.fx.node.Node): + target_match = a.target == b.target + if not target_match: + target_match = ( + hasattr(a.target, "__name__") + and hasattr(b.target, "__name__") + and a.target.__name__ == b.target.__name__ + ) + return ( + target_match + and a.op == b.op + and a.type == b.type + and len(a.all_input_nodes) == len(b.all_input_nodes) + ) + + for nodecall_index, info in self.aot_graph_infos.items(): + ca_node_start_idx = info["ca_node_start_idx"] + aot_id = info["aot_id"] + aot_graph = info["aot_gm"].graph + + # 1. Find the first op from user code in the AOT graph + aot_it = iter(aot_graph.nodes) + aot_node = next(aot_it) + assert aot_node is not None + try: + while aot_node.op != "call_function": + aot_node = next(aot_it) + except StopIteration: + continue + + try: + # 2. Find the first op in the compiled autograd graph segment + ca_it = iter(self.fx_tracer.graph.nodes) + for _ in range(ca_node_start_idx): + next(ca_it) + ca_node = next(ca_it) + + # Graphs should all end with output node + while ca_node.op != "output" and not is_similar(ca_node, aot_node): + # The compiled autograd graph may contain lazily inserted ops + # We skip those when aligning nodes + ca_node = next(ca_it) + + # 3. Keep alligned and rename nodes + while aot_node.op != "output" and ca_node.op != "output": + if not ca_node.users: + # TODO: DCE for compiled autograd graph + ca_node = next(ca_it) + continue + + if not is_similar(aot_node, ca_node): + # There should be no lazily inserted ops in the middle of a match + # So any deviation is an error + raise StopIteration + + ca_node.name = f"aot{aot_id}_{aot_node.name}" + for i, inp in enumerate(aot_node.all_input_nodes): + ca_node.all_input_nodes[i].name = f"aot{aot_id}_{inp.name}" + + aot_node = next(aot_it) + ca_node = next(ca_it) + except StopIteration: + verbose_log.debug( + "Failed to match %s%s (NodeCall %s) nodes with AOT backward graph %s nodes", + self.aot_graph_cls_name, + aot_id, + nodecall_index, + aot_id, + ) + + def reorder_accumulate_grad_nodes(self): + """ + Usage of AOTAutograd causes all the accumulate_grad_ nodes to get pushed to the end of + the graph. This differs from eager mode, which schedules them as soon as possible. This + pass attempts to reorder the graph to mimic eager behavior. + """ + for node in self.fx_tracer.graph.find_nodes( + op="call_function", target=torch.ops.inductor.accumulate_grad_.default + ): + arg = max(node.args) # last arg + if arg is not node.prev and arg.op != "placeholder": + arg.append(node) + + def to_proxy(self, t): + if t is None: + return None + if isinstance(t, list): + return [self.to_proxy(x) for x in t] + if isinstance(t, tuple): + return tuple(self.to_proxy(x) for x in t) + # can it be torch.SymInt as the code used to imply? 
+ assert isinstance(t, torch.Tensor) + proxy_tensor = fetch_object_proxy(self.fx_tracer, t) + assert isinstance(proxy_tensor, torch.fx.experimental.proxy_tensor._ProxyTensor) + return proxy_tensor.proxy + + def bind_tensors_to_proxies(self, tensors, proxies): + if isinstance(proxies, torch.fx.Proxy): + proxies = [proxies[i] for i in range(len(tensors))] # type: ignore[index] + assert len(tensors) == len(proxies) + track_tensor_tree(tensors, proxies, constant=None, tracer=self.fx_tracer) + + def bind_backward_state(self, index: int): + assert self.hooks_proxy is not None + proxy = self.hooks_proxy[index] # type: ignore[index] + bw_state = BackwardState() + track_tensor_tree(bw_state, proxy, constant=None, tracer=self.fx_tracer) + return bw_state + + def set_node_origin( + self, + node_name: str, + nodecall_index: int, + pyobj: Optional[torch.autograd.Function], + ): + maybe_aot_id = "" + if pyobj is not None: + forward_cls = pyobj._forward_cls # type: ignore[attr-defined] + if hasattr(forward_cls, "_aot_id"): + # backward was created by AOT Dispatcher + self.aot_graph_cls_name = node_name + maybe_aot_id = forward_cls._aot_id + self.aot_graph_infos[nodecall_index] = { + "ca_node_start_idx": len(self.fx_tracer.graph.nodes), + "aot_id": maybe_aot_id, + "aot_gm": forward_cls._lazy_backward_info.bw_module, + } + + new_code = f"{node_name}{maybe_aot_id} (NodeCall {nodecall_index})" + raw_stack_trace = CapturedTraceback.extract().format()[-1] + new_stack_trace = raw_stack_trace.replace( + "raw_stack_trace = CapturedTraceback.extract().format()[-1]", new_code + ) + set_stack_trace(new_stack_trace) + + +# state of the autograd engine dispatch, kept in sync by enable/disable context managers +compiled_autograd_enabled = False + +# global flag to check if we are processing graphs produced from a compiled autograd graph +in_compiled_autograd_region = False + + +@contextlib.contextmanager +def enable(compiler_fn): + prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler( + functools.partial(AutogradCompilerInstance, compiler_fn) + ) + if snapshot_verbose_logging_enabled(): + torch._C._dynamo.compiled_autograd.set_verbose_logger(cpp_verbose_log_fn) + global compiled_autograd_enabled + compiled_autograd_enabled = True + try: + with torch.autograd.set_multithreading_enabled(False): + yield + finally: + if not prior: + compiled_autograd_enabled = False + torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior) + + +@contextlib.contextmanager +def disable(): + prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None) + global compiled_autograd_enabled + compiled_autograd_enabled = False + try: + yield + finally: + if prior: + compiled_autograd_enabled = True + torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior) + + +# return to starting state of a new process +def reset() -> None: + compiled_autograd_enable = False + assert not in_compiled_autograd_region + torch._C._dynamo.compiled_autograd.set_autograd_compiler(None) + torch._C._dynamo.compiled_autograd.set_verbose_logger(None) diff --git a/lib/python3.10/site-packages/torch/_dynamo/comptime.py b/lib/python3.10/site-packages/torch/_dynamo/comptime.py new file mode 100644 index 0000000000000000000000000000000000000000..972d79d48fa8b2e841aa491e315798c7e56e0a51 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/comptime.py @@ -0,0 +1,401 @@ +# mypy: allow-untyped-defs +# This file establishes the public comptime interface to Dynamo. 
+# This allows Dynamo users to execute arbitrary Python code while +# Dynamo is symbolically evaluating their original programs. +# +# The goal of the public API is to give users rope, without actually +# leaking private implementation details of Dynamo. + +import builtins +import dis +import time +import traceback +from typing import Optional, Union + +import torch +from torch.fx.experimental.symbolic_shapes import free_symbols + +from .exc import unimplemented +from .variables import NewCellVariable +from .variables.constant import ConstantVariable +from .variables.misc import ClosureVariable +from .variables.tensor import SymNodeVariable + + +class ComptimeVar: + """ + A ComptimeVar represents a Python value, at some particular point + in time, in the Python code we are symbolically evaluating with + torchdynamo. This must be distinguished from a runtime value, as + at compile-time there are some properties of the variable we + do not know (for example, if the ComptimeVar represents a Tensor, + we only know metadata about the tensor; we do NOT know what the + actual data in the Tensor is.) + """ + + def __init__(self, v) -> None: + self.__variable = v + + def as_proxy(self): + """ + Returns an fx.Proxy (or tuple/list of fx.Proxy) representing + this variable in the FX graph we are assembling to pass + to the user compiler. + + This method only works for variables we actually track in + the FX graph, aka Tensors (and ints, if you are compiling + with dynamic shapes). In particular, if you have a list + or tuple of tensors, you will get a list/tuple of proxies + (not a single proxy representing the entire list/tuple). + """ + return self.__variable.as_proxy() + + def is_proxy(self): + """ + Returns True if as_proxy() would succeed. + """ + return self.__variable.is_proxy() + + def as_fake(self): + """ + Returns a "fake" value (either a FakeTensor or a SymInt) + representing the variable in question. This only works + for variables that denote Tensor or int. You can use + this to query metadata; e.g., v.as_fake().size(0) will + tell you the compile-time known size of the tensor. + + WARNING: Do NOT mutate the returned tensor. + """ + return self.__variable.as_proxy().node.meta["example_value"] + + def size(self, dim: Optional[int] = None) -> Union[int, torch.SymInt]: + """ + Returns the size of the tensor (if dim is None) or the size + at the dimension dim. The returned size may be a SymInt. + """ + return self.as_fake().size(dim) + + def python_type(self): + """ + Returns what type(v) would have returned for the variable + at compile time. + """ + return self.__variable.python_type() + + def as_python_constant(self): + """ + Returns the Python value this variable would have, but only if it is + completely known at compile-time (e.g., it is constant). + + WARNING: Do NOT mutate the returned constant. The returned constant + may or may not correspond to the actual value this variable may take + on at runtime; for example, if the variable in question is a constant + list, we may return a copy of that list. + """ + return self.__variable.as_python_constant() + + def is_python_constant(self): + """ + Returns True if as_python_constant would succeed. 
+ """ + return self.__variable.is_python_constant() + + def is_dynamic(self): + if isinstance(self.__variable, SymNodeVariable): + fs = free_symbols(self.__variable.sym_num) + return bool(fs) + return False + + def force_static(self): + """ + Forces that a value is static, inducing a guard on its specific value + """ + if isinstance(self.__variable, SymNodeVariable): + self.__variable.evaluate_expr() + elif isinstance(self.__variable, ConstantVariable): + # TODO: Maybe complain if this isn't a int/bool/float variable + pass + else: + raise AssertionError( + f"cannot force {self.__variable} ({type(self.__variable)}) static" + ) + + def _i_will_not_complain_if_bc_breaks_VariableTracker(self): + """ + Returns the internal data structure VariableTracker that Dynamo uses + to represent variables at compile time. There are no BC guarantees on + this API and WE RESERVE THE RIGHT TO BREAK YOUR CODE if you rely on + it. + """ + return self.__variable + + def __repr__(self) -> str: + return self.__variable.debug_repr() + + # TODO: API for adding a custom guard + + +class ComptimeContext: + """ + This context class provides access to a public API for Dynamo's internals. + If there is something here you would find useful that is missing, please + file a feature request at https://github.com/pytorch/pytorch/ + """ + + def __init__(self, tx) -> None: + self.__tx = tx + + def get_local(self, name: str, *, stacklevel=0) -> ComptimeVar: + """ + Retrieve the compile-time known information about a local. + """ + tx = self.__get_tx(stacklevel) + + # This is analogous to LOAD_DEREF + if hasattr(tx, "closure_cells") and name in tx.closure_cells: + cell = tx.closure_cells[name] + if isinstance(cell, ClosureVariable): + return ComptimeVar(tx.output.root_tx.symbolic_locals[cell.name]) + else: + return ComptimeVar(tx.output.side_effects.load_cell(cell)) + else: + r = tx.symbolic_locals[name] + if isinstance(r, NewCellVariable): + return ComptimeVar(tx.output.side_effects.load_cell(r)) + else: + return ComptimeVar(r) + + def graph_break(self, msg="ComptimeContext.graph_break"): + """ + Manually trigger a graph break + """ + unimplemented(msg) + + def graph(self): + """ + Retrieve the partially constructed FX graph that would be + passed to the user compiler after compilation. + """ + return self.__tx.output.graph + + def assert_static(self, val): + """ + Asserts that the int is static (and not dynamic, per dynamic shapes) + """ + assert ( + not val.is_dynamic() + ), "expected static but got dynamic (run with TORCH_LOGS=dynamic for more info)" + + def print_graph(self, *, verbose=True, file=None): + """ + Print the partially constructed FX graph that would be passed + to the user compiler after compilation. + """ + print( + self.__tx.output.graph.python_code("self", verbose=verbose).src, file=file + ) + + def parent(self): + return ComptimeContext(self.__tx.parent) + + def __get_tx(self, stacklevel): + tx = self.__tx + for _ in range(stacklevel): + tx = tx.parent + return tx + + def print(self, val, *, file=None): + print(repr(val), file=file) + + def print_disas(self, *, file=None, stacklevel=0): + """ + Print the current series of opcodes being executed (not including + parent frames), including where you are in the particular opcode + stream. + """ + tx = self.__get_tx(stacklevel) + print( + dis.Bytecode( + tx.f_code, + current_offset=tx.instructions[tx.instruction_pointer].offset, + ).dis(), + file=file, + ) + + def print_value_stack(self, *, file=None, stacklevel=0): + """ + Print the current Python value stack. 
Note that this is NOT the same + as the traceback; use print_bt() to print that. Note that at + stacklevel=0, this will typically be empty, as comptime cannot + currently be used in an expression context where there would be + intermediates on the stack. If you would find this useful, please + file a bug at https://github.com/pytorch/pytorch/ + + NB: Stack grows downwards in our print + """ + tx = self.__get_tx(stacklevel) + for s in tx.stack: + print(f"- {s.debug_repr()}", file=file) + + def print_locals(self, *, file=None, stacklevel=0): + """ + Print all of the locals available in the current context. + By default this view is very limited; you can get more information + about any individual local using get_local(). + """ + tx = self.__get_tx(stacklevel) + for k, v in tx.symbolic_locals.items(): + print(f"{k} = {v.debug_repr()}", file=file) + + def print_bt(self, *, file=None, stacklevel=0): + """ + Print the user code backtrace, starting at the beginning of the + frame Dynamo started evaluating. Note that this MAY NOT go all + the way to the torch.compile invocation, as we may have done + a graph break and are compiling an intermediate frame as the + starting point. If you think the other behavior would be better, + file a bug at https://github.com/pytorch/pytorch/ + """ + stack = [] + tx = self.__get_tx(stacklevel) + while tx is not None: + stack.append(tx.frame_summary()) + tx = getattr(tx, "parent", None) + print( + "".join(traceback.StackSummary.from_list(reversed(stack)).format()), + file=file, + ) + + def print_guards(self, *, file=None): + """ + Print the currently installed guards for the Dynamo context. + This does NOT include guards associated with variables that + may or may not be installed in the future if those variables + are used. + """ + # TODO: improve print format, current guard format is extremely + # verbose + print( + "\n".join(f"{repr(guard)}" for guard in sorted(self.__tx.output.guards)), + file=file, + ) + + def _i_will_not_complain_if_bc_breaks_InstructionTranslator(self): + """ + Returns the internal data structure InstructionTranslator that Dynamo + uses to track state of symbolic evaluation. There are no BC + guarantees on this API and WE RESERVE THE RIGHT TO BREAK YOUR CODE if + you rely on it. 
+ """ + return self.__tx + + def sleep(self, sec): + time.sleep(sec) + + +class _Comptime: + @staticmethod + def __call__(fn, fallback_fn=lambda: None): + """fn gets called at compile time in TorchDynamo, calls fallback_fn otherwise""" + fallback_fn() + + # Convenience wrappers that are more compact to use + + @staticmethod + def graph_break(): + comptime(lambda ctx: ctx.graph_break()) + + @staticmethod + def print(e): + comptime(lambda ctx: ctx.print(ctx.get_local("e")), lambda: print(e)) + + @staticmethod + def print_graph(): + comptime(lambda ctx: ctx.print_graph()) + + @staticmethod + def print_disas(*, stacklevel=0): + comptime( + lambda ctx: ctx.print_disas( + stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 + ) + ) + + @staticmethod + def print_value_stack(*, stacklevel=0): + comptime( + lambda ctx: ctx.print_value_stack( + stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 + ) + ) + + # This is a more useful variant of print_value_stack that can be used + # in an expression context; e.g., x + print_value_stack_and_return(y + z), + # you will see x on the stack prior to the addition operation + @staticmethod + def print_value_stack_and_return(e, *, stacklevel=0): + comptime( + lambda ctx: ctx.print_value_stack( + stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 + ) + ) + return e + + @staticmethod + def print_locals(*, stacklevel=0): + comptime( + lambda ctx: ctx.print_locals( + stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 + ) + ) + + @staticmethod + def print_bt(*, stacklevel=0): + comptime( + lambda ctx: ctx.print_bt( + stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 + ) + ) + + @staticmethod + def print_guards(): + comptime(lambda ctx: ctx.print_guards()) + + @staticmethod + def assert_static(val): + comptime(lambda ctx: ctx.assert_static(ctx.get_local("val"))) + + @staticmethod + def force_static(val): + comptime(lambda ctx: ctx.get_local("val").force_static()) + + @staticmethod + def breakpoint(): + """ + Like pdb breakpoint(), but drop into pdb whenever this line + of code is compiled by dynamo. 
Use it by putting + this in your model code:: + + from torch._dynamo.comptime import comptime + comptime.breakpoint() + + And then, inside pdb, you can access 'ctx' to query things + about the compilation context:: + + (Pdb) !ctx.print_bt() + (Pdb) !ctx.print_locals() + (Pdb) p ctx.get_local("attention").as_fake() + """ + + def inner(inner_ctx): + ctx = inner_ctx.parent() + builtins.breakpoint() + + comptime(inner) + + @staticmethod + def sleep(sec): + comptime(lambda ctx: ctx.sleep(ctx.get_local("sec").as_python_constant())) + + +comptime = _Comptime() diff --git a/lib/python3.10/site-packages/torch/_dynamo/config.py b/lib/python3.10/site-packages/torch/_dynamo/config.py new file mode 100644 index 0000000000000000000000000000000000000000..2ba29961af36e93b5c16d3ce02181fe989e0458c --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/config.py @@ -0,0 +1,490 @@ +# mypy: allow-untyped-defs +import getpass +import inspect +import os +import re +import sys +import tempfile +from os.path import abspath, dirname +from typing import Any, Callable, Dict, Optional, Set, Type, TYPE_CHECKING, Union + +import torch + + +def is_fbcode(): + return not hasattr(torch.version, "git_version") + + +# to configure logging for dynamo, aot, and inductor +# use the following API in the torch._logging module +# torch._logging.set_logs(dynamo=, aot=, inductor) +# or use the environment variable TORCH_LOGS="dynamo,aot,inductor" (use a prefix + to indicate higher verbosity) +# see this design doc for more detailed info +# Design doc: https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit# +# the name of a file to write the logs to +# [@compile_ignored: debug] +log_file_name: Optional[str] = None + +# [@compile_ignored: debug] Verbose will print full stack traces on warnings and errors +verbose = os.environ.get("TORCHDYNAMO_VERBOSE", "0") == "1" + +# [@compile_ignored: runtime_behaviour] verify the correctness of optimized backend +verify_correctness = False + +# need this many ops to create an FX graph +minimum_call_count = 1 + +# turn on/off DCE pass +dead_code_elimination = True + +# disable (for a function) when cache reaches this size + +# controls the maximum number of cache entries with a guard on same ID_MATCH'd +# object. It also controls the maximum size of cache entries if they don't have +# any ID_MATCH'd guards. +# [@compile_ignored: runtime_behaviour] +cache_size_limit = 8 + +# [@compile_ignored: runtime_behaviour] safeguarding to prevent horrible recomps +accumulated_cache_size_limit = 256 + +# [@compile_ignored: runtime_behaviour] skip tracing recursively if cache limit is hit +skip_code_recursive_on_cache_limit_hit = True + +# whether or not to specialize on int inputs. This only has an effect with +# dynamic_shapes; when dynamic_shapes is False, we ALWAYS specialize on int +# inputs. Note that assume_static_by_default will also cause ints to get +# specialized, so this is mostly useful for export, where we want inputs +# to be dynamic, but accesses to ints should NOT get promoted into inputs. +specialize_int = False + +# Whether or not to specialize on float inputs. Dynamo will always promote +# float inputs into Tensor inputs, but at the moment, backends inconsistently +# support codegen on float (this is to be fixed). +specialize_float = True + +# legacy config, does nothing now! 
+dynamic_shapes = True + +use_lazy_graph_module = ( + os.environ.get("TORCH_COMPILE_USE_LAZY_GRAPH_MODULE", "1") == "1" +) + +# This is a temporarily flag, which changes the behavior of dynamic_shapes=True. +# When assume_static_by_default is True, we only allocate symbols for shapes marked dynamic via mark_dynamic. +# NOTE - this flag can be removed once we can run dynamic_shapes=False w/ the mark_dynamic API +# see [Note - on the state of mark_dynamic] +assume_static_by_default = True + +# This flag changes how dynamic_shapes=True works, and is meant to be used in conjunction +# with assume_static_by_default=True. +# With this flag enabled, we always compile a frame as fully static for the first time, and, if we fail +# any guards due to wobbles in shape, we recompile with *all* the wobbled shapes as being marked dynamic. +automatic_dynamic_shapes = True + +# This flag changes how the shapes of parameters are treated. +# If this flag is set to True, then the shapes of torch.nn.Parameter as well as of torch.Tensor are attempted to be dynamic +# If this flag is set to False, then the shapes of torch.nn.Parameter are assumed to be static, +# while the shapes of torch.Tensor are assumed to be dynamic. +force_parameter_static_shapes = True + +# This flag ensures that the shapes of a nn module are always assumed to be static +# If the flag is set to True, then the shapes of a nn.module are assumed to be static +# If the flag is set to False, then the shapes of a nn.module can be dynamic +force_nn_module_property_static_shapes = True + +# Typically, if you mark_dynamic a dimension, we will error if the dimension +# actually ended up getting specialized. This knob changes the behavior so +# that we don't error at all. This is helpful for our CI where I'm using a +# heuristic to mark batch dimensions as dynamic and the heuristic may get it +# wrong. +allow_ignore_mark_dynamic = False + +# Set this to False to assume nn.Modules() contents are immutable (similar assumption as freezing) +guard_nn_modules = True + +# Uses CPython internal dictionary tags to detect mutation. There is some +# overlap between guard_nn_modules_using_dict_tags and guard_nn_modules flag. +# guard_nn_modules unspecializes the nn module instance and adds guard for each +# relevant member of the nn modules. On the other hand, +# guard_nn_modules_using_dict_tags specializes on each nn module instance but +# uses low overhead dict version matching to detect mutations, obviating the +# need to guard on members of the nn modules. With +# guard_nn_modules_using_dict_tags, the guard_nn_modules is not really required +# but kept around for debugging and discussing unspecializing nn module +# variables. +# TODO(janimesh, voz): Remove both of these flags (or atleast guard_nn_modules) +# once we have reached stability for the guard_nn_modules_using_dict_tags. +guard_nn_modules_using_dict_tags = True + +# This feature doesn't really work. We offer this flag for experimental +# purposes / if you want to help us build out support. +# +# torchdynamo has limited support for tensor subclasses that implement +# __torch_function__ see [Note: __torch_function__] in torch_function.py. +# Our current support is limited to tensor subclasses +# that DO NOT store metadata on the tensor (in general, dynamo does not +# support Python code that stores extra attributes on tensors at present). 
+# If your tensor subclass purely changes function call behavior via +# __torch_function__, you can allow torchdynamo to trace into it by +# adding it to traceable_tensor_subclasses. We don't do any safety checks, +# so it is up to you to ensure that your subclass is well behaved. See also +# https://github.com/pytorch/torchdynamo/issues/1948 +# +# We do NOT currently support __torch_dispatch__. The implementation is +# currently buggy, the main show stopper for nontrivial use is +# https://github.com/pytorch/torchdynamo/issues/1952 +traceable_tensor_subclasses: Set[Type[Any]] = set() + +# Suppress errors in torch._dynamo.optimize, instead forcing a fallback to eager. +# This is a good way to get your model to work one way or another, but you may +# lose optimization opportunities this way. Devs, if your benchmark model is failing +# this way, you should figure out why instead of suppressing it. +suppress_errors = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False)) + +# Record and write an execution record of the current frame to a file +# if an exception is encountered +# @compile_ignored[debug] +replay_record_enabled = os.environ.get("TORCH_COMPILE_REPLAY_RECORD", "0") == "1" + +# Rewrite assert statement in python with torch._assert +rewrite_assert_with_torch_assert = True + +# Disable dynamo +disable = os.environ.get("TORCH_COMPILE_DISABLE", False) + +# [@compile_ignored: runtime_behaviour] Get a cprofile trace of Dynamo +cprofile = os.environ.get("TORCH_COMPILE_CPROFILE", False) + +# legacy config, does nothing now! +skipfiles_inline_module_allowlist: Dict[Any, Any] = {} + +# If a string representing a PyTorch module is in this ignorelist, +# the `allowed_functions.is_allowed` function will not consider it +# when creating a list of PyTorch functions that will appear in +# FX IR. +allowed_functions_module_string_ignorelist = { + "torch.distributions", + "torch.testing", + "torch._refs", + "torch._prims", + "torch._decomp", +} + +# Debug Flag to try minifier at different stages. Possible values are {None, "aot", "dynamo"} +# None - Minifier is switched off +# dynamo - Runs minifier on the TorchDynamo produced graphs, if compilation fails +# aot - Runs minifier on the Aot Autograd produced graphs, if compilation fails +# [@compile_ignored: debug] +repro_after = os.environ.get("TORCHDYNAMO_REPRO_AFTER", None) + +# Compiler compilation debug info +# 1: Dumps the original graph out to repro.py if compilation fails +# 2: Dumps a minifier_launcher.py if compilation fails. +# 3: Always dumps a minifier_launcher.py. Good for segfaults. +# 4: Dumps a minifier_launcher.py if the accuracy fails. +# [@compile_ignored: debug] +repro_level = int(os.environ.get("TORCHDYNAMO_REPRO_LEVEL", 2)) + +# By default, we try to detect accuracy failure by running both forward +# and backward of a torchdynamo produced graph (if you are using repro_after +# 'dynamo'). This setting forces us to only test the forward graph and +# not the backward graph. 
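For the repro_after / repro_level knobs documented above, here is a hedged sketch of turning the minifier on from Python rather than through the TORCHDYNAMO_* environment variables; the compiled workload is a placeholder:

import torch

# Run the minifier on Dynamo-produced graphs if compilation fails, and also
# dump a minifier_launcher.py on accuracy failure (level 4, per the levels above).
torch._dynamo.config.repro_after = "dynamo"
torch._dynamo.config.repro_level = 4

compiled = torch.compile(my_model)   # my_model / example_input are placeholders
compiled(example_input)
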
This can be helpful if you're trying to debug +# an inference only problem, but the minifier seems to be choking on the +# backwards step +# TODO: Detect this situation automatically so the user doesn't need +# to manually configure this +# [@compile_ignored: debug] +repro_forward_only = os.environ.get("TORCHDYNAMO_REPRO_FORWARD_ONLY") == "1" + +# The tolerance we should use when testing if a compiled graph +# has diverged so that we should treat it as an accuracy failure +# [@compile_ignored: debug] +repro_tolerance = 1e-3 + + +# Whether to ignore non-floating point values when checking accuracy. +# Checking accuracy of non-floating point values such as boolean tensors +# can lead to false positives. +# [@compile_ignored: debug] +repro_ignore_non_fp = os.environ.get("TORCHDYNAMO_REPRO_IGNORE_NON_FP") == "1" + +# If True, when testing if two models are the same, we will test them against +# a third fp64 reference and only report a problem if the RMSE relative to the +# fp64 is greater. However, this will use more memory; you may disable this +# if memory usage is too high. +# [@compile_ignored: runtime_behaviour] +same_two_models_use_fp64 = True + +# Not all backends support scalars. Some calls on torch.Tensor (like .item()) return a scalar type. +# When this flag is set to False, we introduce a graph break instead of capturing. +# This requires dynamic_shapes to be True. +capture_scalar_outputs = os.environ.get("TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS") == "1" + +# Not all backends support operators that have dynamic output shape (e.g., +# nonzero, unique). When this flag is set to False, we introduce a graph +# break instead of capturing. This requires dynamic_shapes to be True. +# If you set this to True, you probably also want capture_scalar_outputs +# (these are separated for historical reasons). +capture_dynamic_output_shape_ops = ( + os.environ.get("TORCHDYNAMO_CAPTURE_DYNAMIC_OUTPUT_SHAPE_OPS", "0") == "1" +) + +# hybrid backed unbacked symints +prefer_deferred_runtime_asserts_over_guards = False + +# For complex dynamic shapes guards that we're unable to specify with dynamo/export's +# range constraints + dims + derived dims language, we raise constraint violation +# errors or specialize by default. If set to True, this flag avoids crashing/specialization, +# and allows complex guards as runtime assertions in the graph. +allow_complex_guards_as_runtime_asserts = False + +# By default, dynamo will treat all ints as backed SymInts, which means (1) it +# will wait to see the int change over multiple runs before generalizing and +# (2) it will still always 0/1 specialize an int. When true, this knob +# forces dynamo to treat _length_per_key and _offset_per_key on +# KeyedJaggedTensor from torchrec as size-like unbacked SymInts, so that +# they (1) generalize immediately and (2) unsoundly never compare equal to +# 0/1. This is not on by default as AOTAutograd/Inductor cannot currently +# compile this code; however, this can be useful for export. +force_unspec_int_unbacked_size_like_on_torchrec_kjt = False + +# Should almost always be true in prod. This relaxes the requirement that cond's true_fn and +# false_fn produces code with identical guards. +enforce_cond_guards_match = True + +# Specify how to optimize a compiled DDP module. The flag accepts a boolean +# value or a string. There are 4 modes. +# 1. "ddp_optimizer" (or True): with "ddp_ptimizer", Dynamo will automatically +# split model graph into pieces to match DDP bucket sizes to allow DDP +# comm/compute overlap. +# 2. 
"python_reducer" (experimental): this optimization requires the usage +# of compiled_autograd. With "python_reducer", DDP will disable the C++ reducer +# and use the Python reducer to allow compiled_autograd to trace the +# communication and allow comm/compute overlap without graph-breaks. +# 3. "python_reducer_without_compiled_forward" (experimental): this mode is +# similar to "python_reducer". One should only use this optimization mode +# when compiled_autograd is used but the DDP module is not compiled. +# 4. "no_optimization" (or False): Dynamo won't split the model graph, nor +# will Python reducer be used. With this mode, there will be no graph-breaks +# and the original DDP C++ reducer will be used. There will no comm/compute +# overlap. This mode CANNOT be used with compiled_autograd. +# Note that to avoid breaking the existing usage, mode 1 and mode 4 can be +# specified with a boolean value. True is using ddp_optimizer and False is +# no optimization. +optimize_ddp: Union[bool, str] = True + +# By default, Dynamo emits runtime asserts (e.g. torch._check, torch._check_is_size) in the graph. +# In some cases those asserts could be performance costly +# E.g. torch._check(tensor[0].item() > 2) for tensor on cuda will require cuda sync. +# Setting this to True keeps them hinting to symbolic shapes engine, +# but not be emitted in the graph. +do_not_emit_runtime_asserts: bool = ( + os.environ.get("TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS", "0") == "1" +) + +_ddp_optimization_mode = [ + "ddp_optimizer", + "python_reducer", # experimental mode + "python_reducer_without_compiled_forward", # experimental mode + "no_optimization", +] + + +def _get_optimize_ddp_mode(): + m = sys.modules[__name__] + if isinstance(m.optimize_ddp, bool): + if m.optimize_ddp: + mode = "ddp_optimizer" + else: + mode = "no_optimization" + elif isinstance(m.optimize_ddp, str): + mode = m.optimize_ddp + else: + raise ValueError(f"Invalid type, {type(optimize_ddp)=}") + + assert mode in m._ddp_optimization_mode, f"Invalid mode {mode=}" + return mode + + +# Skip tracing the torchrec files added to trace_rules.FBCODE_SKIP_DIRS +skip_torchrec = True + + +# No longer used +optimize_ddp_lazy_compile = False + +# Whether to skip guarding on FSDP-managed modules +skip_fsdp_guards = True +# Whether to apply torch._dynamo.disable() to FSDP2 hooks. +# Defaults to True. If Traceable FSDP2 is used, set this to False. +skip_fsdp_hooks = True + +# Make dynamo skip guarding on hooks on nn modules +# Note: unsafe: if your model actually has hooks and you remove them, or doesn't and you add them, +# dynamo will not notice and will execute whichever version you first compiled. +skip_nnmodule_hook_guards = True + +# If True, raises exception if TorchDynamo is called with a context manager +raise_on_ctx_manager_usage = True + +# If True, raise when aot autograd is unsafe to use +raise_on_unsafe_aot_autograd = False + +# If true, error if you torch.jit.trace over a dynamo-optimized function. +# If false, silently suppress dynamo +error_on_nested_jit_trace = True + +# If true, error with a better message if we symbolically trace over a +# dynamo-optimized function. If false, silently suppress dynamo. +error_on_nested_fx_trace = True + +# Disables graph breaking on rnn. YMMV with backends. +allow_rnn = False + +# If true, enables feature that captures PyTorch sparsity in the +# exported FX graph. This flag should become the default eventually +# and be removed, but currently provides a way to fall back to old +# graph breaking behavior. 
+capture_sparse_compute = False if is_fbcode() else True + +# If true, error if we try to compile a function that has +# been seen before. +# [@compile_ignored: runtime_behaviour] +error_on_recompile = False + +# [@compile_ignored: debug] Whether to report any guard failures (deprecated: does not do anything) +report_guard_failures = True + +# [@compile_ignored: debug] root folder of the project +base_dir = dirname(dirname(dirname(abspath(__file__)))) + +# Trace through NumPy or graphbreak +trace_numpy = True + +# Default NumPy dtypes when tracing with torch.compile +# We default to 64bits. For efficiency, one may want to change these to float32 +numpy_default_float = "float64" +numpy_default_complex = "complex128" +numpy_default_int = "int64" + +# use numpy's PRNG if True, pytorch otherwise +use_numpy_random_stream = False + +# Use C++ guard manager +enable_cpp_guard_manager = os.environ.get("TORCHDYNAMO_CPP_GUARD_MANAGER", "1") == "1" + +# Inline inbuilt nn modules +inline_inbuilt_nn_modules = not is_fbcode() + +# When set, total compile time instruction count is recorded using +# torch._dynamo.utilsCompileTimeInstructionCounter. +record_compile_time_instruction_count = False + + +def default_debug_dir_root(): + # [@compile_ignored: debug] + DEBUG_DIR_VAR_NAME = "TORCH_COMPILE_DEBUG_DIR" + if DEBUG_DIR_VAR_NAME in os.environ: + return os.path.join(os.environ[DEBUG_DIR_VAR_NAME], "torch_compile_debug") + elif is_fbcode(): + return os.path.join( + tempfile.gettempdir(), getpass.getuser(), "torch_compile_debug" + ) + else: + return os.path.join(os.getcwd(), "torch_compile_debug") + + +# [@compile_ignored: debug] +debug_dir_root = default_debug_dir_root() + +# [@compile_ignored: debug] +_save_config_ignore = { + "repro_after", + "repro_level", + # workaround: "cannot pickle PyCapsule" + "constant_functions", + # workaround: "cannot pickle module" + "skipfiles_inline_module_allowlist", +} + +# for backend="cudagraphs", mutations on input be sent to the cudagraph backend +# or replayed in aot_autograd epilogue. default is False because mutation on inputs +# can prevent cudagraphing. +cudagraph_backend_keep_input_mutation = False + +# enable cudagraph support for mutated inputs from prior cudagraph pool +cudagraph_backend_support_input_mutation = False + +# When True, only ops that have the torch.Tag.pt2_compliant tag +# will be allowed into the graph; all other ops will be disallowed +# and will fall back to eager-mode PyTorch. Useful to ensure +# correctness of custom ops. +only_allow_pt2_compliant_ops = False + +capture_autograd_function = True + +# enable/disable dynamo tracing for `torch.func` transforms +capture_func_transforms = True + +# If to log Dynamo compilation metrics into log files (for OSS) and Scuba tables (for fbcode). +log_compilation_metrics = True + +# A set of logging functions which will be reordered to the end of graph breaks, +# allowing dynamo to construct larget graph. Note that there are some +# limitations to this, such as how it does not correctly print objects that were +# mutated after the print statement. 
+reorderable_logging_functions: Set[Callable[[Any], None]] = set() + +# simulates what would happen if we didn't have support for BUILD_SET opcode, +# used for testing +inject_BUILD_SET_unimplemented_TESTING_ONLY = False + +_autograd_backward_strict_mode_banned_ops = [ + "stride", + "requires_grad", + "storage_offset", + "layout", + "data", +] + +_autograd_backward_strict_mode_banned_ops.extend( + [name for name, _ in inspect.getmembers(torch.Tensor) if re.match(r"^is_.*", name)] +) + +# Enables caching of dispatches to fake tensors. +fake_tensor_cache_enabled = ( + os.environ.get("TORCH_FAKE_TENSOR_DISPATCH_CACHE", "1") == "1" +) + +# Enables cross checking between the fake tensor cache and dispatch. +fake_tensor_cache_crosscheck_enabled = ( + os.environ.get("TORCH_FAKE_TENSOR_DISPATCH_CACHE_CROSSCHECK", "0") == "1" +) + +# Enables the Compiled Autograd engine to trace .backward() calls made under torch.compile(). +# Note: AOT Autograd will still trace joint graphs. +compiled_autograd = False + +# Enables use of collectives *during* compilation to synchronize behavior +# across ranks. Today, this is used solely to modify automatic_dynamic_shapes +# behavior, making it so that we infer that if an input is dynamic by +# inspecting whether or not its input size varies across ranks. Because +# this synchronization uses collectives, all ranks must run compilation at +# the same time; ranks must not diverge with graph breaks. This can be most +# reliably achieved by ensuring PT2 only is run on SPMD programs. If this +# invariant is inviolated, you will likely deadlock NCCL and encounter a +# NCCL timeout. +enable_compiler_collectives = os.environ.get("TORCH_COMPILER_COLLECTIVES", "0") == "1" + +if TYPE_CHECKING: + from torch.utils._config_typing import * # noqa: F401, F403 + + def _make_closure_patcher(**changes): + ... 
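Because install_config_module (imported just below) turns this module into a patchable config object, the knobs documented above can be overridden either by plain assignment or, scoped, via config.patch. A small sketch, with an arbitrary placeholder workload:

import torch

# Scoped override of two of the knobs defined above; values revert on exit.
with torch._dynamo.config.patch(cache_size_limit=16, verbose=True):
    compiled = torch.compile(torch.sin)   # placeholder workload
    compiled(torch.randn(8))

# Plain assignment also works and stays in effect for the rest of the process.
torch._dynamo.config.suppress_errors = True
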
+ + +from torch.utils._config_module import install_config_module + + +install_config_module(sys.modules[__name__]) diff --git a/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py b/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py new file mode 100644 index 0000000000000000000000000000000000000000..d15f5c4aa43dcffe4dcd343a895e546147dab2d0 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py @@ -0,0 +1,1277 @@ +# mypy: allow-untyped-decorators +from __future__ import annotations + +import collections +import contextlib +import cProfile +import dis +import functools +import itertools +import logging +import os +import pstats +import random +import subprocess +import sys +import threading +import time +import traceback +import typing +import weakref +from pathlib import Path +from types import CodeType, FrameType, FunctionType, ModuleType +from typing import Any, Callable, Dict, List, Optional, Set, TypeVar, Union +from typing_extensions import ParamSpec +from weakref import ReferenceType + +import torch +import torch._logging +from torch._C._dynamo.guards import GlobalStateGuard +from torch._dynamo.distributed import get_compile_pg +from torch._dynamo.utils import CompileTimeInstructionCounter +from torch._guards import compile_context, CompileContext, CompileId, tracing +from torch._logging import structured +from torch._utils_internal import ( + compile_time_strobelight_meta, + justknobs_check, + maybe_upload_prof_stats_to_manifold, + signpost_event, +) +from torch.fx._lazy_graph_module import _use_lazy_graph_module +from torch.fx.experimental.symbolic_shapes import ( + ConstraintViolationError, + GuardOnDataDependentSymNode, +) +from torch.fx.graph_module import _forward_from_src as original_forward_from_src +from torch.nn.parallel.distributed import DistributedDataParallel +from torch.utils._python_dispatch import ( + _disable_current_modes, + is_in_torch_dispatch_mode, +) +from torch.utils._traceback import CapturedTraceback, format_traceback_short + +from . 
import config, exc, trace_rules +from .bytecode_analysis import remove_dead_code, remove_pointless_jumps +from .bytecode_transformation import ( + check_inst_exn_tab_entries_valid, + Instruction, + is_generator, + propagate_inst_exn_table_entries, + transform_code_object, +) +from .cache_size import ( + CacheSizeRelevantForFrame, + compute_cache_size, + exceeds_cache_size_limit, + is_recompilation, +) +from .eval_frame import always_optimize_code_objects, skip_code, TorchPatcher +from .exc import ( + augment_exc_message, + BackendCompilerFailed, + CacheLimitExceeded, + format_error_msg, + InternalTorchDynamoError, + SkipCodeRecursiveException, + TorchRuntimeError, + UncapturedHigherOrderOpError, + unimplemented, + Unsupported, +) +from .guards import ( + CheckFunctionManager, + get_and_maybe_log_recompilation_reason, + GuardedCode, +) +from .hooks import Hooks +from .replay_record import ExecutionRecord +from .symbolic_convert import ( + DistributedState, + InstructionTranslator, + LocalState, + SpeculationLog, +) +from .trace_rules import is_numpy +from .utils import ( + CleanupManager, + CompilationMetrics, + counters, + dynamo_timed, + format_bytecode, + frame_phase_timing, + gen_record_file_name, + get_chromium_event_logger, + increment_frame, + is_namedtuple, + istype, + LazyString, + orig_code_map, + record_compilation_metrics, + reset_graph_break_dup_checker, + setup_compile_debug, + troubleshooting_url, + write_record_to_file, +) + + +np: Optional[ModuleType] +try: + import numpy as np +except ModuleNotFoundError: + np = None + + +if typing.TYPE_CHECKING: + from .backends.registry import CompilerFn + from .repro.after_dynamo import WrapBackendDebug + from .types import BytecodeHook, CacheEntry + from .variables.builder import FrameStateSizeEntry + + +log = logging.getLogger(__name__) +bytecode_log = torch._logging.getArtifactLogger(__name__, "bytecode") +graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks") + + +compile_lock = threading.RLock() + +_T = TypeVar("_T") +_P = ParamSpec("_P") + + +class TODO_UNKNOWN: + pass + + +class Tracker: + def __init__(self) -> None: + self.seen: List[ReferenceType[CodeType]] = [] + self.seen_ids: Set[int] = set() + + def add(self, strong_obj: CodeType) -> None: + idx = id(strong_obj) + if idx not in self.seen_ids: + obj = weakref.ref(strong_obj, lambda _: self.seen_ids.remove(idx)) + self.seen.append(obj) + self.seen_ids.add(idx) + + def __contains__(self, item: CodeType) -> bool: + return id(item) in self.seen_ids + + def clear(self) -> None: + self.seen.clear() + self.seen_ids.clear() + + +input_codes = Tracker() +output_codes = Tracker() + +initial_global_state: Optional[GlobalStateGuard] = None + + +@functools.wraps(original_forward_from_src) +def fx_forward_from_src_skip_result( + src: str, globals: Dict[str, Any], co_fields: Optional[Dict[str, str]] = None +) -> FunctionType: + # we monkey patch FX to prevent infinite loop of trying to convert + # our generated code + result = original_forward_from_src(src, globals, co_fields) + skip_code(result.__code__) + return result + + +def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]: + """ + Context manager to: + 1) Save/restore torch.is_grad_enabled() state + 2) Save/restore python random state + 3) Save/restore torch random state + 4) Monkey patch torch.fx.graph_module._forward_from_src + """ + + @functools.wraps(fn) + def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: + guards = GlobalStateGuard() + prior_grad_mode = torch.is_grad_enabled() + # Just in 
case we get left in a bad dispatch state we want to restore + # it. This can happen because the dispatch bits aren't a true + # stack/counter - so we can't just increment/decrement them as we enter + # and leave. + with torch._C._PreserveDispatchKeyGuard(): + prior_inference_mode = torch.is_inference_mode_enabled() + prior_deterministic = torch.are_deterministic_algorithms_enabled() + prior_warn_only = torch.is_deterministic_algorithms_warn_only_enabled() + py_rng_state = random.getstate() + torch_rng_state = torch.random.get_rng_state() + cuda_rng_state = None + if torch.cuda.is_available(): + cuda_rng_state = torch.cuda.get_rng_state() + allow_tf32 = torch._C._get_cublas_allow_tf32() + prior_fwd_from_src = torch.fx.graph_module._forward_from_src + torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result + cleanup = setup_compile_debug() + + exit_stack = contextlib.ExitStack() + exit_stack.enter_context( + torch.fx._symbolic_trace._maybe_revert_all_patches() + ) + try: + return fn(*args, **kwargs) + finally: + cleanup.close() + exit_stack.close() + torch._C._set_grad_enabled(prior_grad_mode) + torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode) + torch.use_deterministic_algorithms( + prior_deterministic, warn_only=prior_warn_only + ) + random.setstate(py_rng_state) + torch.random.set_rng_state(torch_rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + torch._C._set_cublas_allow_tf32(allow_tf32) + torch.fx.graph_module._forward_from_src = prior_fwd_from_src + assert ( + guards.check() + ), f"Global {guards.reason()}state changed while dynamo tracing, please report a bug" + + _fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined] + return _fn + + +@TorchPatcher.suppress_torch_distributed_warnings +def has_tensor_in_frame(frame: FrameType) -> bool: + """Check if the frame has torch.* related bits""" + # Check if the function was decorated using torch._dynamo.optimize + if frame.f_code in always_optimize_code_objects: + return True + + # Check if there is global import of torch.* + for co_name in frame.f_code.co_names: + if co_name in frame.f_globals: + obj = frame.f_globals[co_name] + if isinstance(obj, ModuleType) and ( + obj.__name__.startswith("torch.") or obj is torch + ): + return True + # ... or a global import of numpy.* + if np and config.trace_numpy and (obj is np or is_numpy(obj)): + return True + + seen_ids: Dict[int, bool] = {} + + def has_tensor(obj: object) -> bool: + """Recursively check if the obj has a tensor""" + obj_id = id(obj) + if obj_id in seen_ids: + return seen_ids[obj_id] + seen_ids[obj_id] = False + + if isinstance(obj, (torch.Tensor, torch.nn.Module)) or ( + istype(obj, type) and issubclass(obj, torch.nn.Module) + ): + seen_ids[obj_id] = True + return seen_ids[obj_id] + elif ( + config.trace_numpy + and np + and (istype(obj, np.ndarray) or isinstance(obj, np.generic)) + ): + seen_ids[obj_id] = True + return seen_ids[obj_id] + elif istype(obj, (list, tuple)): + seen_ids[obj_id] = any(has_tensor(v) for v in obj) + return seen_ids[obj_id] + elif istype(obj, dict): + # Some packages like pytest can be updated during runtime. 
So, make a + # copy of values to avoid issues like "RuntimeError: dictionary + # changed size during iteration" + values = list(obj.values()) + seen_ids[obj_id] = any(has_tensor(v) for v in values) + return seen_ids[obj_id] + elif istype(obj, (str, int, float, type(None), bool)): + seen_ids[obj_id] = False + return seen_ids[obj_id] + elif is_namedtuple(obj) and hasattr(obj, "_fields"): + seen_ids[obj_id] = any(has_tensor(getattr(obj, v)) for v in obj._fields) + return seen_ids[obj_id] + else: + # if config.debug: + # print( + # f"Assuming that object of type {type(obj)} does not have a tensor" + # ) + return False + + # Check if the passed arguments are of type Tensor + for value in frame.f_locals.values(): + if has_tensor(value): + return True + + log.debug( + "skipping because no torch.* %s \ + %s %s", + frame.f_code.co_name, + frame.f_code.co_filename, + frame.f_code.co_firstlineno, + ) + + return False + + +def exception_handler( + e: Exception, + code: CodeType, + frame: Optional[FrameType] = None, + export: bool = False, +) -> None: + record_filename = None + if hasattr(e, "exec_record"): + record_filename = gen_record_file_name(e, code) + write_record_to_file(record_filename, e.exec_record) + e.record_filename = record_filename # type: ignore[attr-defined] + + augment_exc_message(e, export=export) + + +FRAME_COUNTER = 0 +FRAME_COMPILE_COUNTER: typing.Counter[ + Union[int, FrameStateSizeEntry] +] = collections.Counter() + + +def maybe_cprofile(func: Callable[_P, _T]) -> Callable[_P, _T]: + if config.cprofile: + return cprofile_wrapper(func) + return func + + +def cprofile_wrapper(func: Callable[_P, _T]) -> Callable[_P, _T]: + @functools.wraps(func) + def profile_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + trace_id = CompileContext.current_trace_id() + assert trace_id, "Trace id is None" + profile_path = Path( + f"/tmp/{func.__name__}_{str(trace_id).replace('/', '_')}.profile" + ) + prof = cProfile.Profile() + prof.enable() + start_ts = time.time() + retval = prof.runcall(func, *args, **kwargs) + profile_latency = time.time() - start_ts + prof.disable() + log.warning( + "### Cprofile for %s trace id [%s] took %.3f seconds ###", + func.__name__, + trace_id, + profile_latency, + ) + ps = pstats.Stats(prof) + try: + prof.dump_stats(profile_path) + except PermissionError: + log.exception("Cannot write to %s", profile_path) + log.warning("Raw profile at %s", profile_path) + svg_path = profile_path.with_suffix(".svg") + try: + gprof2dot_process = subprocess.Popen( + [ + "gprof2dot", + "-f", + "pstats", + "--node-label=total-time-percentage", + "--node-label=self-time-percentage", + "--node-label=total-time", + str(profile_path), + ], + stdout=subprocess.PIPE, + ) + subprocess.check_call( + ["dot", "-Tsvg", "-o", str(svg_path)], + stdin=gprof2dot_process.stdout, + ) + log.warning("Generated SVG from profile at %s", svg_path) + except FileNotFoundError: + log.warning( + "Failed to generate SVG from profile -- dumping stats instead." 
+ "Try installing gprof2dot and dot for a better visualization" + ) + ps.sort_stats(pstats.SortKey.TIME).print_stats(20) + ps.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(20) + + if manifold_link := maybe_upload_prof_stats_to_manifold( + str(profile_path) + ): # fb-only + torch._logging.trace_structured( + "link", + lambda: {"name": "cprofile_manifold_url", "url": manifold_link}, + ) + return retval + + return profile_wrapper + + +class ConvertFrameAssert: + def __init__( + self, + compiler_fn: CompilerFn, + one_graph: bool = True, + export: bool = False, + export_constraints: Optional[typing.Never] = None, + ) -> None: + # assert export_constraints is None + reset_graph_break_dup_checker() + self._torchdynamo_orig_callable = compiler_fn + self._one_graph = one_graph + self._export = export + self._export_constraints = export_constraints + + @property + def _clone_with_backend(self) -> Callable[[CompilerFn], ConvertFrameAssert]: + return lambda backend: convert_frame_assert( + backend, self._one_graph, self._export, self._export_constraints + ) + + def __call__( + self, + frame: FrameType, + cache_entry: Optional[CacheEntry], + hooks: Hooks, + frame_state: Dict[str, Union[int, FrameStateSizeEntry]], + *, + skip: int = 0, + ) -> Optional[GuardedCode]: + increment_frame() + + code = frame.f_code + + cache_size = compute_cache_size(frame, cache_entry) + input_codes.add(code) + if code in output_codes: + return None + if ( + os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION") + and os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION") != code.co_name + ): + return None + if code.co_name == "" and code.co_filename.endswith( + ( + "transformers/file_utils.py", + "transformers/utils/generic.py", + "diffusers/utils/outputs.py", + ) + ): + # not needed, but cleans up torchbench error stats + return None + if code.co_name == "__setattr__": + # setattr could be tricky to handle generally, + # but also not likely useful to compile- skip the whole frame + return None + if code.co_name == "__init__" and code.co_filename.startswith( + os.path.dirname(torch.optim.__file__) + ): + # optimizer support is still incomplete see + # test_state_dict in test/dynamo/test_optimizers.py + return None + + # Check if the frame is generated by an exec builtin call + # TODO - Running exec generated frame seems propagates f_globals to the + # next frames. + if code.co_name == "" and code.co_filename == "": + return None + + if ( + code.co_name == "" + and code.co_filename == "" + and not bool(frame.f_builtins) + ): + # namedtuple subclass constructor. Empty builtins cause issue with + # len keyword in LIST_LEN guard. 
+ return None + + if is_generator(code): + unimplemented("generator") + + if not has_tensor_in_frame(frame): + return None + + global initial_global_state + initial_global_state = GlobalStateGuard() + + global FRAME_COUNTER + if "_id" not in frame_state: + frame_state["_id"] = FRAME_COUNTER + FRAME_COUNTER += 1 + frame_id = frame_state["_id"] + assert isinstance(frame_id, int) + + frame_compile_id = FRAME_COMPILE_COUNTER[frame_id] + FRAME_COMPILE_COUNTER[frame_id] += 1 + + compile_id = CompileId(frame_id, frame_compile_id) + + signpost_event( + "dynamo", + "_convert_frame_assert._compile", + { + "co_name": code.co_name, + "frame_id": frame_id, + "compile_id": str(compile_id), + "co_filename": code.co_filename, + "co_firstlineno": code.co_firstlineno, + "cache_size": cache_size.num_cache_entries_with_same_id_matched_objs, + "accumulated_cache_size": cache_size.num_cache_entries, + }, + ) + + return _compile( + frame.f_code, + frame.f_globals, + frame.f_locals, + frame.f_builtins, + self._torchdynamo_orig_callable, + self._one_graph, + self._export, + self._export_constraints, + hooks, + cache_entry, + cache_size, + frame, + frame_state=frame_state, + compile_id=compile_id, + skip=skip + 1, + ) + + +def convert_frame_assert( + compiler_fn: CompilerFn, + one_graph: bool = True, + export: bool = False, + export_constraints: Optional[typing.Never] = None, +) -> ConvertFrameAssert: + """Fully convert a frame into an FX graph""" + return ConvertFrameAssert(compiler_fn, one_graph, export, export_constraints) + + +from collections import OrderedDict + +from torch.utils.hooks import RemovableHandle + + +if typing.TYPE_CHECKING: + from .output_graph import OutputGraph + +# we have to use `OrderedDict` to make `RemovableHandle` work. +_bytecode_hooks: Dict[int, BytecodeHook] = OrderedDict() + + +def register_bytecode_hook(hook: BytecodeHook) -> RemovableHandle: + """Register hooks for bytecode generated by Dynamo. The hook can do some + logging, as well as return a new code object to be used. Please refer + to `BytecodeHook` for the hook signature. + """ + handle = RemovableHandle(_bytecode_hooks) + _bytecode_hooks[handle.id] = hook + return handle + + +def _compile( + code: CodeType, + globals: Dict[str, object], + locals: Dict[str, object], + builtins: Dict[str, object], + compiler_fn: CompilerFn, + one_graph: bool, + export: bool, + export_constraints: Optional[typing.Never], + hooks: Hooks, + cache_entry: Optional[CacheEntry], + cache_size: CacheSizeRelevantForFrame, + frame: Optional[FrameType] = None, + frame_state: Optional[Dict[str, Union[int, FrameStateSizeEntry]]] = None, + *, + compile_id: CompileId, + skip: int = 0, +) -> Optional[GuardedCode]: + from torch.fx.experimental.validator import ( + bisect, + BisectValidationException, + translation_validation_enabled, + ValidationException, + ) + + # Only nonlocal defs here please! 
+ # Time spent compiling this frame before restarting or failing analysis + dynamo_time_before_restart: float = 0.0 + output: Optional[OutputGraph] = None + tracer: Optional[InstructionTranslator] = None + + @preserve_global_state + def transform( + instructions: List[Instruction], code_options: Dict[str, object] + ) -> None: + nonlocal output + nonlocal tracer + speculation_log.restart() + tracer = InstructionTranslator( + instructions, + code, + locals, + globals, + builtins, + code_options, + compiler_fn, + one_graph, + export, + export_constraints, + mutated_closure_cell_contents, + frame_state=frame_state, + speculation_log=speculation_log, + distributed_state=distributed_state, + ) + + try: + with tracing(tracer.output.tracing_context), tracer.set_current_tx(): + tracer.run() + except exc.UnspecializeRestartAnalysis: + speculation_log.clear() + raise + except (exc.SpeculationRestartAnalysis, exc.SkipFrame): + raise + except Exception: + if translation_validation_enabled(): + bisect(tracer.output.shape_env) + raise + finally: + tracer.output.call_cleanup_hooks() + + output = tracer.output + assert output is not None + assert output.output_instructions + instructions[:] = output.output_instructions + code_options.update(output.code_options) + + if config.dead_code_elimination: + propagate_inst_exn_table_entries(instructions) + check_inst_exn_tab_entries_valid(instructions) + instructions[:] = remove_pointless_jumps(remove_dead_code(instructions)) + + def compile_inner( + code: CodeType, + one_graph: bool, + hooks: Hooks, + transform: Callable[[List[Instruction], Dict[str, Any]], Any], + ) -> Optional[GuardedCode]: + with dynamo_timed("_compile.compile_inner", phase_name="entire_frame_compile"): + with CompileTimeInstructionCounter.record(): + return _compile_inner(code, one_graph, hooks, transform) + + @compile_time_strobelight_meta(phase_name="compile_inner") + @maybe_cprofile + def _compile_inner( + code: CodeType, + one_graph: bool, + hooks: Hooks, + transform: Callable[[List[Instruction], Dict[str, Any]], Any], + ) -> Optional[GuardedCode]: + nonlocal dynamo_time_before_restart + last_attempt_start_time = start_time = time.time() + + def log_bytecode( + prefix: str, name: str, filename: str, line_no: int, code: CodeType + ) -> None: + if bytecode_log.isEnabledFor(logging.DEBUG): + bytecode_log.debug( + format_bytecode(prefix, name, filename, line_no, code) + ) + + log_bytecode( + "ORIGINAL BYTECODE", + code.co_name, + code.co_filename, + code.co_firstlineno, + code, + ) + + out_code = None + for attempt in itertools.count(): + CompileContext.get().attempt = attempt + try: + out_code = transform_code_object(code, transform) + break + except exc.RestartAnalysis as e: + log.info( + "Restarting analysis due to %s", + LazyString(format_traceback_short, e.__traceback__), + ) + # If restart reason is None just log the type of the exception + restart_reasons.add(e.restart_reason or str(type(e))) + # We now have a new "last attempt", reset the clock + last_attempt_start_time = time.time() + if attempt > 100: + unimplemented("100+ RestartAnalysis() calls") + except exc.SkipFrame as e: + log.debug( + "Skipping frame %s %s \ + %s %s", + e, + code.co_name, + code.co_filename, + code.co_firstlineno, + ) + if one_graph: + log.debug("No graph captured with one_graph=True") + return None + + assert ( + distributed_state is None or distributed_state.all_states is not None + ), "compiler collective wasn't run before compilation completed" + + assert out_code is not None + log_bytecode( + "MODIFIED 
BYTECODE", + code.co_name, + code.co_filename, + code.co_firstlineno, + out_code, + ) + + for hook in _bytecode_hooks.values(): + hook_output = hook(code, out_code) + if hook_output is not None: + out_code = hook_output + + orig_code_map[out_code] = code + output_codes.add(out_code) + dynamo_time_before_restart = last_attempt_start_time - start_time + assert output is not None + + # Tests for new code objects. + # The rationale for these tests can be found in torch/csrc/dynamo/eval_frame.c + # Only test once the code object is created. + # They are not tested during runtime. + + def count_args(code: CodeType) -> int: + import inspect + + return ( + code.co_argcount + + code.co_kwonlyargcount + + bool(code.co_flags & inspect.CO_VARARGS) + + bool(code.co_flags & inspect.CO_VARKEYWORDS) + ) + + assert out_code is not None + + total_argcount_old = count_args(code) + total_argcount_new = count_args(out_code) + msg = "arg mismatch: " + msg += f"old code object has args {code.co_varnames[:total_argcount_old]}, " + msg += f"new code object has args {out_code.co_varnames[:total_argcount_new]}" + assert ( + code.co_varnames[:total_argcount_old] + == out_code.co_varnames[:total_argcount_new] + ), msg + + msg = "free var mismatch: " + msg += f"old code object has free var {code.co_freevars}, " + msg += f"new code object has free var {out_code.co_freevars}" + assert code.co_freevars == out_code.co_freevars, msg + + msg = "cell var mismatch: " + msg += f"old code object has cell var {code.co_cellvars}, " + msg += f"new code object has cell var {out_code.co_cellvars}" + assert code.co_cellvars == out_code.co_cellvars, msg + + # Skipping Dynamo on a frame without any extracted graph. + # This does not affect eager functionality. But this is necessary + # for export for cases where Dynamo-reconstructed bytecode can create + # new function frames, confusing export in thinking that there + # are extra graphs now. + + if output.export and output.is_empty_graph(): + return None + + assert output.guards is not None + CleanupManager.instance[out_code] = output.cleanups + check_fn = CheckFunctionManager( + output, + hooks.guard_fail_fn if hooks else None, + ) + + guarded_code = GuardedCode(out_code, check_fn.check_fn, compile_id) + + if not output.is_empty_graph() and hooks.guard_export_fn is not None: + # We should not run the guard_export_fn when Dynamo does not + # generate any graph. This can happen in export when TorchDynamo + # generated bytecode has some reconstruction logic for mutated + # variables which can trigger TorchDynamo on the children frames but + # they are benign and do not generate any new graphs. 
+ hooks.guard_export_fn(output.guards) + + return guarded_code + + with _use_lazy_graph_module(config.use_lazy_graph_module), compile_context( + CompileContext(compile_id) + ): + restart_reasons: set[str] = set() + # This is shared across restarts + mutated_closure_cell_contents: Set[str] = set() + speculation_log = SpeculationLog() + if compile_pg := get_compile_pg(): + distributed_state = DistributedState(compile_pg, LocalState()) + else: + distributed_state = None + torch._dynamo.callback_handler.run_start_callbacks() + + # Check recompilations + recompile_reasons = None + if is_recompilation(cache_size) and frame: + recompile_reasons = get_and_maybe_log_recompilation_reason( + cache_entry, frame + ) + + exceeded, limit_type = exceeds_cache_size_limit(cache_size, compile_id) + if exceeded: + + def format_func_info(code: CodeType) -> str: + return f"'{code.co_name}' ({code.co_filename}:{code.co_firstlineno})" + + def format_guard_failures() -> str: + if not recompile_reasons: + return "Unable to find recompilation reasons" + return recompile_reasons[-1] + + log.warning( + "torch._dynamo hit config.%s (%s)\n" + " function: %s\n" + " last reason: %s\n" + 'To log all recompilation reasons, use TORCH_LOGS="recompiles".\n' + "To diagnose recompilation issues, see %s.", + limit_type, + getattr(config, limit_type), + format_func_info(code), + format_guard_failures(), + troubleshooting_url, + ) + if config.skip_code_recursive_on_cache_limit_hit and justknobs_check( + "pytorch/compiler:skip_code_recursive_on_cache_limit_hit" + ): + raise CacheLimitExceeded(f"{limit_type} reached") + else: + # do not recursively skip frames + unimplemented(f"{limit_type} reached") + + log.debug( + "torchdynamo start compiling %s %s:%s, stack (elided %s frames):\n%s", + code.co_name, + code.co_filename, + code.co_firstlineno, + skip + 2, + # -2: omit current frame, omit contextlib decorator + "".join(CapturedTraceback.extract(skip=2 + skip).format()), + ) + # -4: -2 as above, plus trace_structured frames + # + # NB: the frame looks like this: + # + # # handled by skip argument + # torch/_dynamo/convert_frame.py:1069 in catch_errors + # torch/_dynamo/convert_frame.py:910 in _convert_frame + # torch/_dynamo/convert_frame.py:464 in _convert_frame_assert + # torch/_utils_internal.py:70 in wrapper_function + # + # # 2 current frame and context lib + # env/lib/python3.10/contextlib.py:79 in inner + # torch/_dynamo/convert_frame.py:776 in _compile + # + # # 2 extra here + # torch/_logging/_internal.py:1064 in trace_structured + # torch/_dynamo/convert_frame.py:780 in + convert_frame_intern = structured.intern_string(__file__) + # Initialize the ChromiumEventLogger on start + chromium_event_log = get_chromium_event_logger() + chromium_event_log.reset() + torch._logging.trace_structured( + "dynamo_start", + lambda: { + "stack": list( + itertools.takewhile( + lambda f: f["filename"] != convert_frame_intern, + structured.from_traceback( + CapturedTraceback.extract(skip=4 + skip).summary() + ), + ) + ) + + [ + { + "line": code.co_firstlineno, + "name": code.co_name, + "filename": structured.intern_string(code.co_filename), + } + ] + }, + ) + start_time = time.time() + fail_type: Optional[str] = None + fail_reason: Optional[str] = None + fail_user_frame_filename: Optional[str] = None + fail_user_frame_lineno: Optional[int] = None + start_possibly_missed_reinplacing_opportunities = torch._dynamo.utils.counters[ + "inductor" + ]["possibly_missed_reinplacing_opportunities"] + guarded_code = None + try: + guarded_code = 
compile_inner(code, one_graph, hooks, transform) + return guarded_code + except Exception as e: + fail_type = type(e).__qualname__ + fail_reason = str(e) + # NB: e's msg is mutated here to add user stack, but we DON'T want + # that stack in the Scuba logged fail_reason + exception_handler(e, code, frame, export=export) + fail_user_frame_filename, fail_user_frame_lineno = exc.get_exc_message( + e, compile_id + ) + if isinstance( + e, + ( + Unsupported, + TorchRuntimeError, + BackendCompilerFailed, + AssertionError, + ConstraintViolationError, + GuardOnDataDependentSymNode, + ValidationException, + UncapturedHigherOrderOpError, + BisectValidationException, + ), + ): + raise + else: + # Rewrap for clarity + raise InternalTorchDynamoError( + f"{type(e).__qualname__}: {str(e)}" + ).with_traceback(e.__traceback__) from None + finally: + if tracer: + tracer.output.local_scope = {} + + from .utils import curr_frame + + frame_key = str(curr_frame) + if ( + fail_reason is None + and output is not None + and frame_key in frame_phase_timing + ): + guard_count = len(output.guards) + shape_env_guard_count = len(output.shape_env.guards) + graph_op_count = output.count_calls() + graph_node_count = len(output.graph.nodes) + graph_input_count = len(output.placeholders) + entire_frame_compile_time = frame_phase_timing[frame_key].get( + "entire_frame_compile", None + ) + backend_compile_time = frame_phase_timing[frame_key].get( + "backend_compile", None + ) + inductor_compile_time = frame_phase_timing[frame_key].get( + "inductor_compile", None + ) + code_gen_time = frame_phase_timing[frame_key].get("code_gen", None) + non_compliant_ops = {op.__qualname__ for op in output.non_compliant_ops} + compliant_custom_ops = { + op.__qualname__ for op in output.compliant_custom_ops + } + possibly_missed_reinplacing_opportunities = ( + torch._dynamo.utils.counters["inductor"][ + "possibly_missed_reinplacing_opportunities" + ] + - start_possibly_missed_reinplacing_opportunities + ) + else: + guard_count = None + shape_env_guard_count = None + graph_op_count = None + graph_node_count = None + graph_input_count = None + entire_frame_compile_time = None + backend_compile_time = None + inductor_compile_time = None + code_gen_time = None + non_compliant_ops = set({}) + compliant_custom_ops = set({}) + restart_reasons = set() + # If compilation failed, the entire time is wasted + dynamo_time_before_restart = time.time() - start_time + possibly_missed_reinplacing_opportunities = None + + metrics = CompilationMetrics( + str(compile_id), + frame_key, + code.co_name, + code.co_filename, + code.co_firstlineno, + cache_size.num_cache_entries_with_same_id_matched_objs, + cache_size.num_cache_entries, + guard_count, + shape_env_guard_count, + graph_op_count, + graph_node_count, + graph_input_count, + start_time, + entire_frame_compile_time, + backend_compile_time, + inductor_compile_time, + code_gen_time, + fail_type, + fail_reason, + fail_user_frame_filename, + fail_user_frame_lineno, + non_compliant_ops, + compliant_custom_ops, + restart_reasons, + dynamo_time_before_restart, + guarded_code is not None, + possibly_missed_reinplacing_opportunities, + ) + record_compilation_metrics(metrics) + torch._dynamo.callback_handler.run_end_callbacks() + + +class ConvertFrame: + def __init__(self, compiler_fn: CompilerFn, hooks: Hooks) -> None: + self._torchdynamo_orig_callable = compiler_fn + self._inner_convert = convert_frame_assert(compiler_fn, one_graph=False) + self._hooks = hooks + + @property + def _clone_with_backend(self) -> 
Callable[[WrapBackendDebug], ConvertFrame]: + return lambda backend: convert_frame(backend, self._hooks) + + def __call__( + self, + frame: FrameType, + cache_entry: Optional[CacheEntry], + hooks: Hooks, + frame_state: Dict[str, Union[int, FrameStateSizeEntry]], + skip: int = 0, + ) -> Optional[ + Union[GuardedCode, torch._C._dynamo.eval_frame.SkipCodeRecursiveFlag] + ]: + counters["frames"]["total"] += 1 + try: + result = self._inner_convert( + frame, cache_entry, hooks, frame_state, skip=skip + 1 + ) + counters["frames"]["ok"] += 1 + return result + except Exception as e: + # These two exception types are "soft" failure, in the sense that + # we know this is due to something we didn't implement all the + # way, scare the user less about it. That being said, if you + # are trying to understand why a graph break happened, it's still + # important to have this information, so offer it. + # + # NB: NotImplementedError used to be on this list, but actually + # it is impossible for it to reach here, as it is converted into + # InternalTorchDynamoError. This behavior seemed reasonable + # to me (ezyang, Aug 2023) so I kept it, but maybe at some point + # someone wanted these to also get suppressed. If so, you'll + # need to make these exceptions not get wrapped + + # We intentionally don't want to suppress error here. + if isinstance(e, UncapturedHigherOrderOpError): + raise + + soft_fail = isinstance(e, Unsupported) + + # This is a soft failure. In the sense, the code path reaches here + # when we do not support graph breaks on bytecodes like LOAD_ATTR, + # BUILD_SET etc. In such case, we can fallback to eager without + # scaring users. + if isinstance(e, Unsupported) and graph_break_log.isEnabledFor( + logging.DEBUG + ): + # Log this message in the graph break. Also use the string + # "skip: " to tell that the whole frame is falling back to + # eager. + if hasattr(e, "compile_id"): + with compile_context(CompileContext(e.compile_id)): # type: ignore[attr-defined] + user_stack = e.real_stack + user_stack_formatted = "".join( + traceback.format_list(user_stack) + ) + graph_break_log.debug( + "Graph break: skip: from user code at:\n%s", + user_stack_formatted, + exc_info=True, + ) + + if not config.suppress_errors and not soft_fail: + raise + + # Suppress the error. NB: It's very important to do the + # suppression logging HERE, where the actual suppression + # happens. Previously it was somewhere else and so it was + # possible to accidentally not log at all. + record_filename = getattr(e, "record_filename", None) + code = frame.f_code + error_msg = format_error_msg(e, code, record_filename, frame) + + if soft_fail: + log.info(error_msg, exc_info=True) + else: + log.warning(error_msg, exc_info=True) + + # If we encounter SkipCodeRecursiveException, return skip_code_recursive_flag + # to signal to Dynamo eval frame to skip the current frame and any recursive calls. 
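+ # Note (editor, inferred from the surrounding code): returning None (the
+ # fallthrough below) only falls back to eager for this frame, while the
+ # recursive-skip flag also suppresses compilation of the frames it calls.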
+ if isinstance(e, SkipCodeRecursiveException): + return torch._C._dynamo.eval_frame.skip_code_recursive_flag + + return None + + +def convert_frame(compiler_fn: CompilerFn, hooks: Hooks) -> ConvertFrame: + """Try to convert a frame into an FX graph, if error leave frame unmodified""" + return ConvertFrame(compiler_fn, hooks) + + +# TODO mlazos: add support for same args, or record them +def replay(filename: str) -> None: + from .backends.debugging import eager + + original_replay_val = config.replay_record_enabled + config.replay_record_enabled = False + with open(filename, "rb") as in_file: + record = ExecutionRecord.load(in_file) + record.globals = dict(itertools.chain(record.globals.items(), globals().items())) + + try: + _compile( + record.code, + record.globals, + record.locals, + record.builtins, + compiler_fn=eager, + one_graph=False, + export=False, + export_constraints=None, + hooks=Hooks(), + cache_size=CacheSizeRelevantForFrame(0, 0), + cache_entry=None, + frame=None, + frame_state={}, + compile_id=CompileId(42, 999), + ) + finally: + config.replay_record_enabled = original_replay_val + + +def first_real_inst_idx(code: CodeType) -> int: + if sys.version_info < (3, 11): + return 0 + for inst in dis.get_instructions(code): + if inst.opname == "RESUME": + return inst.offset // 2 + raise RuntimeError("RESUME instruction not found in code") + + +class ConvertFrameProtocol(typing.Protocol): + def __call__( + self, + frame: FrameType, + cache_entry: Optional[CacheEntry], + hooks: Hooks, + frame_state: Dict[str, Union[int, FrameStateSizeEntry]], + *, + skip: int = 0, + ) -> Optional[GuardedCode]: + ... + + +class CatchErrorsWrapper: + def __init__(self, callback: ConvertFrameProtocol, hooks: Hooks) -> None: + functools.wraps(callback)(self) + self._torchdynamo_orig_callable = callback + self.hooks = hooks + + def __call__( + self, + frame: FrameType, + cache_entry: Optional[CacheEntry], + frame_state: Dict[str, Union[int, FrameStateSizeEntry]], + ) -> Optional[GuardedCode]: + assert frame_state is not None + + is_skipfile = trace_rules.check(frame.f_code) + if sys.version_info >= (3, 13): + has_started_execution = frame.f_lasti > first_real_inst_idx(frame.f_code) + else: + has_started_execution = frame.f_lasti >= first_real_inst_idx(frame.f_code) + if ( + # TODO: the first condition is not covered by any test + has_started_execution + or is_skipfile + or config.disable + or ( + is_in_torch_dispatch_mode(include_infra_modes=False) + and not getattr(self._torchdynamo_orig_callable, "_export", False) + ) + ): + if log.isEnabledFor(logging.DEBUG): + print(frame.f_lasti, first_real_inst_idx(frame.f_code)) + + if has_started_execution: + skip_reason = "traced frame already" + elif trace_rules.check(frame.f_code): + skip_reason = "in skipfiles" + elif is_in_torch_dispatch_mode(include_infra_modes=False): + skip_reason = "non-infra torch dispatch mode present, this is not supported today in torch.compile" + else: + skip_reason = "dynamo tracing is disabled" + + log.debug( + "skipping: %s (reason: %s, file: %s)", + frame.f_code.co_name, + skip_reason, + frame.f_code.co_filename, + ) + return None + + if frame.f_code.co_filename == "" and frame.f_code.co_name == "__new__": + # nametuple constructor + return None + if config._get_optimize_ddp_mode() == "ddp_optimizer": + ddp_module = DistributedDataParallel._get_active_ddp_module() + if ddp_module: + with compile_lock: + from torch._dynamo.backends.distributed import DDPOptimizer + + ddp_optimizer = DDPOptimizer( + 
bucket_bytes_cap=ddp_module.bucket_bytes_cap, + backend_compile_fn=self._torchdynamo_orig_callable._torchdynamo_orig_callable, # type: ignore[attr-defined] + ) + assert hasattr( + self._torchdynamo_orig_callable, "_clone_with_backend" + ), "DDPOptimizer only supports callback fns that know how to clone themselves." + hijacked_callback = ( + self._torchdynamo_orig_callable._clone_with_backend( + ddp_optimizer.compile_fn, + ) + ) + return hijacked_callback( + frame, cache_entry, self.hooks, frame_state + ) + + with compile_lock, _disable_current_modes(): + # skip=1: skip this frame + return self._torchdynamo_orig_callable( + frame, cache_entry, self.hooks, frame_state, skip=1 + ) + + +def catch_errors_wrapper( + callback: ConvertFrameProtocol, hooks: Hooks +) -> CatchErrorsWrapper: + return CatchErrorsWrapper(callback, hooks) diff --git a/lib/python3.10/site-packages/torch/_dynamo/create_parameter_op.py b/lib/python3.10/site-packages/torch/_dynamo/create_parameter_op.py new file mode 100644 index 0000000000000000000000000000000000000000..6661078859211c969287e646a3bc8f078e364df2 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/create_parameter_op.py @@ -0,0 +1,60 @@ +# mypy: allow-untyped-defs +import threading +from contextlib import contextmanager + +import torch + + +doc = """ +This is used when dynamo traces torch.nn.Parameter, which normally would not trace properly +with AOTAutograd. We instead create a placeholder torch.nn.Parameter before the graph, which +becomes a graph arg and has no storage backing it. At the point in the graph where the parameter +actually should be created we mutate this sacrificial placeholder into it. This allows gradients +to flow into the parameter as if it were an input to the graph (which is the only thing we are +allowed to compute gradients on). +""".strip() + + +class TracableCreateParameter(torch.autograd.Function): + @staticmethod + def forward(ctx, tensor, placeholder): + assert not tensor.requires_grad + return placeholder.set_(tensor) + + @staticmethod + def backward(ctx, grad): + return None, grad # grad flows to placeholder + + +def tracable_create_parameter(tensor, placeholder): + with torch.set_grad_enabled(placeholder.requires_grad): + out = TracableCreateParameter.apply(tensor, placeholder) + return out + + +def new_parameter_placeholder(size, dtype, device, requires_grad): + """Create a placeholder to be passed to the above functions""" + result = torch.nn.Parameter( + torch.empty(size, dtype=dtype, device=device), requires_grad=requires_grad + ) + # TODO(jansel): alloc followed by free is inefficient, need a way to allocate an unbacked tensor. + # Allocating a zero tensor would causes assert failures in autograd. 
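+ # resize_(0) frees the freshly allocated storage right away; the placeholder
+ # is given real storage later when TracableCreateParameter.forward set_()s
+ # the traced tensor into it.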
+ result.untyped_storage().resize_(0) + return result + + +_TLS = threading.local() + + +@contextmanager +def do_not_convert_to_tracable_parameter(): + old_flag = getattr(_TLS, "convert_tracable_parameter", True) + _TLS.convert_tracable_parameter = False + try: + yield False + finally: + _TLS.convert_tracable_parameter = old_flag + + +def can_convert_to_tracable_parameter(): + return getattr(_TLS, "convert_tracable_parameter", True) diff --git a/lib/python3.10/site-packages/torch/_dynamo/current_scope_id.py b/lib/python3.10/site-packages/torch/_dynamo/current_scope_id.py new file mode 100644 index 0000000000000000000000000000000000000000..c0337b78462fa094447437ccd07004b9f4c525a8 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/current_scope_id.py @@ -0,0 +1,25 @@ +# mypy: allow-untyped-defs +import contextlib +import threading + + +# Global variable to identify which SubgraphTracer we are in. +# It is sometimes difficult to find an InstructionTranslator to use. +_current_scope_id = threading.local() + + +def current_scope_id(): + global _current_scope_id + if not hasattr(_current_scope_id, "value"): + _current_scope_id.value = 1 + return _current_scope_id.value + + +@contextlib.contextmanager +def enter_new_scope(): + global _current_scope_id + try: + _current_scope_id.value = current_scope_id() + 1 + yield + finally: + _current_scope_id.value = current_scope_id() - 1 diff --git a/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py b/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..94687ff2747bf375899c5428836e0790e06bb2d3 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py @@ -0,0 +1,824 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code="method-assign" +import atexit +import copy +import cProfile +import functools +import getpass +import inspect +import itertools +import logging +import os +import re +import subprocess +import sys +import tempfile +import textwrap +from collections import Counter +from importlib import import_module +from typing import Any, Callable, Dict, List, Optional, TypeVar + +import torch +import torch._prims_common as utils +import torch._subclasses.meta_utils +from torch import Tensor +from torch._dynamo.testing import rand_strided +from torch._prims_common import is_float_dtype +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._content_store import ContentStoreReader, ContentStoreWriter + +from . 
import config +from .utils import clone_inputs, get_debug_dir + + +log = logging.getLogger(__name__) + +T = TypeVar("T") + + +inductor_config = import_module("torch._inductor.config") +use_buck = inductor_config.is_fbcode() + +if use_buck: + import libfb.py.build_info + + +extra_deps = [] +extra_imports = "" +if use_buck: + extra_deps = [ + "//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu", + "//caffe2/torch/fb/sparsenn:sparsenn_operators", + "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu", + "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops", + ] + cur_target = libfb.py.build_info.BuildInfo.get_build_rule().replace("fbcode:", "//") # type: ignore[possibly-undefined] + extra_imports = "\n".join([f'torch.ops.load_library("{x}")' for x in extra_deps]) + + +BUCK_CMD_PREFIX = ["buck2", "run", "@mode/dev-nosan"] + + +class BuckTargetWriter: + def __init__(self, filename): + self.subdir, self.py_file = os.path.split(os.path.abspath(filename)) + self.target = self.py_file.replace(".py", "") + + # Get main_module path from fbcode + self.path = f'{self.subdir.replace("/", ".")}.{self.target}' + self.path = self.path[self.path.find("fbcode.") :] + self.path = self.path[7:] + + # Get cmd line path + tmp = self.subdir + tmp = tmp[tmp.find("fbcode/") :][7:] + self.cmd_line_path = f"//{tmp}:{self.target}" + + def build(self): + extra_cpp_deps = "\n".join([f' "{x}",' for x in extra_deps]) + return textwrap.dedent( + f""" +load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary") + +python_binary( + name="{self.target}", + srcs = ["{self.py_file}"], + compile = False, + deps = [ + "//caffe2:torch", + "//caffe2/functorch:functorch", + "//triton:triton", + "{cur_target}", + ], + cpp_deps = [ +{extra_cpp_deps} + ], + main_module = "{self.path}", + par_style = "xar", +) +""" + ) + + def write(self, print_msg=True): + target_file = os.path.join(self.subdir, "TARGETS") + with open(target_file, "w") as fd: + fd.write(self.build()) + # log.warning("Wrote isolation TARGETS file at %s", target_file) + cmd_split = BUCK_CMD_PREFIX + [self.cmd_line_path] + if print_msg: + log.warning( + "Found an example that reproduces the error. Run this cmd to repro - %s", + " ".join(cmd_split), + ) + return cmd_split + + +def minifier_dir(): + path = os.path.join(get_debug_dir(), "minifier") + if path is None: + path = f"{tempfile.gettempdir()}/minifier_{getpass.getuser()}" + if not os.path.exists(path): + os.makedirs(path, exist_ok=True) + return path + + +MAX_CONSTANT_NUMEL_INLINE = 4 + + +class NNModuleToString: + safe_reprs = [ + torch.nn.Linear, + torch.nn.Conv1d, + torch.nn.Conv2d, + torch.nn.Conv3d, + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + torch.nn.LayerNorm, + torch.nn.Dropout, + torch.nn.Softmax, + torch.nn.ReLU, + torch.nn.GELU, + torch.nn.Identity, + torch.nn.MaxPool2d, + torch.nn.Embedding, + torch.nn.Tanh, + torch.nn.ConvTranspose1d, + torch.nn.GLU, + torch.nn.LSTM, + torch.nn.Flatten, + torch.nn.AdaptiveAvgPool2d, + ] + + @staticmethod + def can_convert_to_string(gm): + cant_convert = set() + for _, module in gm.named_children(): + if type(module) not in NNModuleToString.safe_reprs: + cant_convert.add(module) + + if len(cant_convert) > 0: + log.warning("We have not tested reprs of some modules - %s", cant_convert) + # TODO - Assuming that all modules can be safely repr'd. Check if that assumption is correct. 
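+ # Conversion proceeds even when untested module reprs are present; the
+ # warning above is the only signal given to the user.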
+ return True + + @staticmethod + def convert(gm): + from torch.nn.modules.module import _addindent + + tab = " " * 4 + + model_str = textwrap.dedent( + """ + from torch.nn import * + class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + """ + ) + + for module_name, module in gm.named_children(): + module_str = f"{module.__repr__()}" + # module should be a core torch.nn.Module, so all parameters + # should be on the same device. + example_param = next(module.parameters(), None) + if example_param is not None and example_param.is_cuda: + module_str = f"{module_str}.cuda()" + model_str += f"{tab*2}self.{module_name} = {module_str}\n" + + for buffer_name, buffer in gm._buffers.items(): + if buffer is None: + continue + # Serialize full data for small buffers + if buffer.numel() <= MAX_CONSTANT_NUMEL_INLINE: + from torch._tensor_str import PRINT_OPTS + + assert PRINT_OPTS.threshold >= MAX_CONSTANT_NUMEL_INLINE + tensor_str = repr(buffer) + elif torch.is_floating_point(buffer): + tensor_str = f"torch.randn({list(buffer.shape)}, dtype={buffer.dtype})" + else: + tensor_str = ( + f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})" + ) + if buffer.is_cuda: + tensor_str = f"{tensor_str}.cuda()" + model_str += f"{tab*2}self.register_buffer('{buffer_name}', {tensor_str})\n" + + for param_name, param in gm._parameters.items(): + if param is None: + continue + maybe_device = "" + if param.is_cuda: + maybe_device = ', device="cuda"' + tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}{maybe_device}))" + model_str += f"{tab*2}self.{param_name} = {tensor_str}\n" + + # TODO - Keep this code for now. But, I don't think we will need this. + # attrs = dir(gm) + # for attr in attrs: + # if "_tensor_constant" in attr: + # val = getattr(gm, attr) + # model_str += f" {attr} = {val!r}\n" + + model_str += f"{_addindent(gm.code, 4)}\n" + return model_str + + +@functools.lru_cache(None) # subprocess is expensive +def _cuda_system_info_comment(): + if not torch.cuda.is_available(): + return "# torch.cuda.is_available()==False, no GPU info collected\n" + + model_str = "# CUDA Info: \n" + try: + cuda_version_out = subprocess.check_output(["nvcc", "--version"]) + cuda_version_lines = cuda_version_out.decode().split("\n") + comment = "".join([f"# {s} \n" for s in cuda_version_lines if s not in [""]]) + model_str += f"{comment}\n" + except (FileNotFoundError, subprocess.CalledProcessError): + model_str += "# nvcc not found\n" + + gpu_names = Counter( + torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count()) + ) + + model_str += "# GPU Hardware Info: \n" + for name, count in gpu_names.items(): + model_str += f"# {name} : {count} \n" + model_str += "\n" + return model_str + + +def generate_config_string(*, stable_output=False): + import torch._functorch.config + import torch._inductor.config + + if stable_output: + return "# config omitted due to stable_output=True" + + experimental_config = torch.fx.experimental._config.codegen_config() # type: ignore[attr-defined] + return f"""\ +import torch._dynamo.config +import torch._inductor.config +import torch._functorch.config +import torch.fx.experimental._config +{torch._dynamo.config.codegen_config()} +{torch._inductor.config.codegen_config()} +{torch._functorch.config.codegen_config()} +{experimental_config} +""" + + +def get_minifier_repro_path(): + return os.path.join(minifier_dir(), "minifier_launcher.py") + + +def helper_for_dump_minify(contents): + minified_repro_path = 
get_minifier_repro_path() + log.warning("Writing minified repro to:\n%s", minified_repro_path) + + if use_buck: + BuckTargetWriter(minified_repro_path).write() + try: + with open(minified_repro_path, "w") as fd: + fd.write(contents) + + except OSError as e: + log.exception("") + raise NotImplementedError(f"Could not write to {minified_repro_path}") from e + + +class AccuracyError(Exception): + pass + + +def clone_inputs_retaining_gradness(example_inputs): + """ + This cloning of inputs is different from utils clone_input. In case of minifier, + all the tensors are leaf tensors while creating a new graph. So, we set the + requires_grad field w/o checking the leafness of the tensor. + """ + cloned_inputs = clone_inputs(example_inputs) + for idx in range(len(example_inputs)): + if isinstance(cloned_inputs[idx], torch.Tensor): + cloned_inputs[idx].requires_grad_(example_inputs[idx].requires_grad) + return cloned_inputs + + +def run_fwd_maybe_bwd(gm, args, only_fwd=False, disable_clone=False): + """ + Runs a forward and possibly backward iteration for a given mod and args. + + When disable_clone is True, we will use args as-is without cloning. + This is higher fidelity but we may destroy the args in the process. + """ + from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass + + gm = copy.deepcopy(gm) + if not disable_clone: + args = clone_inputs_retaining_gradness(args) + + if hasattr(gm, "zero_grad"): + gm.zero_grad(True) + + # TorchInductor returned callable expects lists. So, may need a boxed calling convention. + out = gm(args) if hasattr(gm, "_boxed_call") else gm(*args) + + if only_fwd: + return out + if requires_bwd_pass(out): + loss = reduce_to_scalar_loss(out) + loss.backward() + return collect_results(gm, out, None, args) + + +def same_two_models( + gm, + opt_gm, + example_inputs, + only_fwd=False, + *, + require_fp64=False, + ignore_non_fp=False, +): + """ + Check that two models have the same accuracy. + + require_fp64: if True, raise an error if we are unable to calculate the fp64 reference + ignore_non_fp: if True, do not compare outputs which are not floating point. This + is mostly useful for the minifier (which wants to avoid quantizing floating point + error into integer/boolean error) + """ + from .utils import same + + ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd) + + fp64_ref = None + if config.same_two_models_use_fp64: + try: + fp64_model, fp64_examples = cast_to_fp64( + copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs) + ) + fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd) + except Exception: + if require_fp64: + raise RuntimeError( # noqa: B904 + "Could not generate fp64 outputs, workaround with torch._dynamo.config.same_two_models_use_fp64 = False" + ) + log.warning("Could not generate fp64 outputs") + + try: + res = run_fwd_maybe_bwd(opt_gm, example_inputs, only_fwd) + except Exception as e: + # This means that the minified graph is bad/exposes a different problem. + # As we are checking accuracy here, let's log the exception and return True. + log.exception( + "While minifying the program in accuracy minification mode, " + "ran into a runtime exception which is likely an unrelated issue." + " Skipping this graph."
+ ) + return True + + passing = same( + ref, + res, + fp64_ref, + tol=config.repro_tolerance, + equal_nan=True, + ignore_non_fp=ignore_non_fp, + ) + return passing + + +def cast_dtype_args_to_fp64(model): + for node in model.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.prims.convert_element_type.default + ): + assert len(node.args) == 2 + if is_float_dtype(node.args[1]) and node.args[1] != torch.float64: + node.args = (node.args[0], torch.float64) + if node.op == "call_function": + dtype = node.kwargs.get("dtype") + if dtype is not None and is_float_dtype(dtype): + new_kwargs = dict(node.kwargs) + new_kwargs["dtype"] = torch.float64 + node.kwargs = new_kwargs + + model.graph.lint() + model.recompile() + return model + + +def cast_to(dtype, model, inputs): + from torch.utils._pytree import tree_map + + model = model.to(dtype) + if dtype == torch.float64: + # If casting to fp64 for accuracy comparison, we need to + # replace dtype arguments embedded in the graph with fp64 + model = cast_dtype_args_to_fp64(model) + + inputs = tree_map( + lambda x: x.to(dtype) + if isinstance(x, torch.Tensor) and x.is_floating_point() + else x, + inputs, + ) + return model, inputs + + +def cast_to_fp64(model, inputs): + return cast_to(torch.float64, model, inputs) + + +def backend_accuracy_fails( + gm, + example_inputs, + compiler_fn, + only_fwd=False, + *, + require_fp64=False, + ignore_non_fp=False, +): + try: + compiled_gm = compiler_fn( + copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs) + ) + return not same_two_models( + gm, + compiled_gm, + example_inputs, + only_fwd, + require_fp64=require_fp64, + ignore_non_fp=ignore_non_fp, + ) + except Exception as e: + # This means that the minified graph is bad/exposes a different problem. + # As we are checking accuracy here, lets log the exception and return False. + log.exception( + "While minifying the program in accuracy minification mode, " + "ran into a runtime exception which is likely an unrelated issue." + " Skipping this graph" + ) + return False + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# REPRO SUPPORT CODE +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +# Helper functions for computing what the default values of tensor +# values should be. These all coincide with factory functions, e.g., torch.empty + + +def _stride_or_default( + stride: Optional["torch._prims_common.StrideType"], + *, + shape: "torch._prims_common.ShapeType", +) -> "torch._prims_common.StrideType": + return stride if stride is not None else utils.make_contiguous_strides_for(shape) + + +def _mk_defaulter(d: T) -> Callable[[Optional[T]], T]: + return lambda x: x if x is not None else d + + +_dtype_or_default = _mk_defaulter(torch.float32) +_device_or_default = _mk_defaulter(torch.device("cpu")) +_storage_offset_or_default = _mk_defaulter(0) +_requires_grad_or_default = _mk_defaulter(False) +_is_leaf_or_default = _mk_defaulter(False) + + +class NopInputReader: + def __init__(self) -> None: + self.total = 0 + + def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None): + self.total += 1 + + def tensor(self, *args, **kwargs): + pass + + def symint(self, *args, **kwargs): + pass + + +# TODO: Support bundling the entire repro into a zip file for ease of +# transferring around +class InputReader: + def __init__(self, save_dir=None, *, pbar=None): + # If None, we will generate random data instead. 
It's important + # to natively support this use case as it will allow people to + # share repros without including the real data, if the problem + # reproduces even on random data. + if save_dir is None: + log.warning("no save_dir specified, will generate random data") + self.store = ContentStoreReader(save_dir) if save_dir is not None else None + self.args = [] + self.pbar = pbar + + def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None): + if self.pbar is not None: + self.pbar.update(1) + device = _device_or_default(device) + dtype_hint = _dtype_or_default(dtype_hint) + if self.store is not None and storage_hash is not None: + try: + storage = self.store.read_storage(storage_hash) + except FileNotFoundError: + pass + else: + if device != storage.device: + log.warning("device mismatch: %s != %s", device, storage.device) + # TODO: transfer it to the right device? But failing this + # way would be very mysterious! Would have been better + # not to store device in the serialized format... + return storage + log.warning("could not load %s, generating random data instead", storage_hash) + shape = (nbytes // dtype_hint.itemsize,) + stride = _stride_or_default(None, shape=shape) + return rand_strided(shape, stride, dtype_hint, device).untyped_storage() + + def tensor( + self, + storage, + shape, + stride=None, + *, + storage_offset=None, + dtype=None, + requires_grad=None, + is_leaf=None, + **metadata, + ): + stride = _stride_or_default(stride, shape=shape) + storage_offset = _storage_offset_or_default(storage_offset) + dtype = _dtype_or_default(dtype) + is_leaf = _is_leaf_or_default(is_leaf) + requires_grad = _requires_grad_or_default(requires_grad) + t = torch.tensor( + [], dtype=dtype, device=storage.device, requires_grad=requires_grad + ) + with torch.no_grad(): + t.set_(storage, storage_offset, shape, stride) + if not is_leaf: + # Fake up some autograd history in a very naughty way + with torch.enable_grad(): + t = t.clone(memory_format=torch.preserve_format) + with torch.no_grad(): + t.set_(storage, storage_offset, shape, stride) + assert torch._subclasses.meta_utils.safe_is_leaf(t) == is_leaf + torch._utils.set_tensor_metadata(t, metadata) + self.args.append(t) + return t # for BC + + def symint(self, val): + self.args.append(val) + return val # for BC + + +# Here is our writer strategy: +# 1. We will stream all of the inputs to disk +# 2. You can now deterministically randomize the inputs, or reload +# the inputs from disk +# 3. You can YOLO run the script without the inputs, in which case +# we'll fill the inputs with random data and pray. This is the +# legacy behavior, but it's also useful if you want to find out +# if we're so broken even random inputs trigger it +# 4. We could offer an in process "check if the randomized thing +# works too" but this is delicate so we don't do it + + +class InputWriter: + def __init__(self, save_dir, *, stable_hash=False): + self._lines = [] + # TODO: consider ensuring tensor and storage counters line up? 
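+ # storage_counter generates the sequential buf0, buf1, ... names used for the
+ # reader.storage(...) lines emitted into the generated load_args function.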
+ self.storage_counter = itertools.count() + self.save_dir = save_dir + self.store = ( + ContentStoreWriter(save_dir, stable_hash=stable_hash) + if save_dir is not None + else None + ) + self.seen_storages = {} + + def lines(self): + r = [ + "def load_args(reader):", + ] + r.extend(f" {l}" for l in self._lines) + # In case we need to change the internal format of load_args + # in an FC-breaking way + r.append("load_args._version = 0") + return r + + # Storages are untyped, but we need to initialize them with data if + # we don't have the real data, so we give a hint saying what kind + # of initialization may be appropriate + # + # If we had a FakeTensor, device_hint tells us what device should be + def storage(self, untyped_storage, *, dtype_hint=None, device_hint=None) -> str: + ws = StorageWeakRef(untyped_storage) + v = self.seen_storages.get(ws) + if v is not None: + return v + v = f"buf{next(self.storage_counter)}" + maybe_dtype_hint = "" + if _dtype_or_default(None) != _dtype_or_default(dtype_hint): + maybe_dtype_hint = f", dtype_hint={dtype_hint!r}" + # TODO: being optional on device is kind of pointless as the default + # is CPU but most repros we care about are CUDA + maybe_device = "" + device = untyped_storage.device + if device.type == "meta": + assert device_hint is not None + device = device_hint + if _device_or_default(None) != device: + maybe_device = f", device={device!r}" + nbytes = untyped_storage.nbytes() + storage_hash = None + if self.store is not None and untyped_storage.device.type != "meta": + storage_hash = self.store.write_storage(untyped_storage) + self._lines.append( + f"{v} = reader.storage({storage_hash!r}, {nbytes!r}{maybe_device}{maybe_dtype_hint})" + ) + self.seen_storages[ws] = v + return v + + def tensor(self, name, t) -> None: + from torch.fx.experimental.symbolic_shapes import statically_known_true + + storage = self.storage( + t.untyped_storage(), dtype_hint=t.dtype, device_hint=t.device + ) + args = [] + # NB: this is positional, must come first + if _stride_or_default(None, shape=t.shape) != t.stride(): + args.append(str(tuple(t.stride()))) + if _dtype_or_default(None) != t.dtype: + args.append(f"dtype={t.dtype!r}") + if not statically_known_true( + _storage_offset_or_default(None) == t.storage_offset() + ): + args.append(f"storage_offset={t.storage_offset()!r}") + tensor_metadata = torch._utils.get_tensor_metadata(t) + if tensor_metadata: + args.extend(f"{k}={v!r}" for k, v in tensor_metadata.items()) + if _requires_grad_or_default(None) != t.requires_grad: + args.append(f"requires_grad={t.requires_grad!r}") + is_leaf = torch._subclasses.meta_utils.safe_is_leaf(t) + if _is_leaf_or_default(None) != is_leaf: + args.append(f"is_leaf={is_leaf!r}") + self._lines.append( + "reader.tensor(" + + ", ".join([storage, str(tuple(t.shape)), *args]) + + f") # {name}" + ) + + # TODO: this doesn't actually symint atm + def symint(self, name, val) -> None: + if isinstance(val, torch.SymInt): + val = val.node.hint + self._lines.append(f"reader.symint({val!r}) # {name}") + + +def aot_graph_input_parser( + func: Callable[[List[Tensor]], List[Tensor]], + device: str = "cuda", + sym_shapes: Optional[Dict[str, int]] = None, + default_sym_shape: Optional[int] = None, +) -> Dict[str, Any]: + """ + Takes in a function which has been printed with print_readable() and constructs kwargs to run it. + + Handles Tensor inputs, Symints, and a graph module which might have tensor constants. 
+ + Consider a function `forward` defined as follows: + + def forward(self, primals_1: "f32[1001, 6]", primals_2: "f32[s0]", primals_3: "Sym(s0)",): + _tensor_constant0: "i64[4190]" = self._tensor_constant0 + # Further implementation + + kwargs = aot_graph_input_parser(forward) + forward(**kwargs) + """ + + from torch.fx.graph import dtype_abbrs + + dtype_map = {value: key for key, value in dtype_abbrs.items()} + dtype_pattern = "|".join(dtype_abbrs.values()) + + # Extracting the source code from the function + source = inspect.getsource(func) + + # Regular expressions + tensor_assignment_regex = rf"(_tensor_constant\d+): \"({dtype_pattern})\[\s*(.*?)\s*\]\" = self\.(_tensor_constant\d+)" + tensor_regex = rf"({dtype_pattern})\[\s*(.*?)\s*\]" + sym_shape_regex = r"Sym\((s\d+)\)" + + class TensorContainer: + "Container for tensors as attributes" + + # Dictionary for tensors from annotations + kwargs: Dict[str, Any] = {} + + sym_shapes = sym_shapes or {} + + def get_sym_int(symint): + torch._check( + symint in sym_shapes or default_sym_shape is not None, + lambda: f"{symint} not in symbolic_shapes and default sym shape not passed in", + ) + return sym_shapes.get(symint, default_sym_shape) + + def gen_tensor(shape, dtype) -> Tensor: + # Resolve symbolic shapes to concrete values + resolved_shape = [] + dynamic_dims = [] + for i, dim in enumerate(shape): + dim = dim.strip() + if "s" in dim: + s = get_sym_int(dim) + resolved_shape.append(s) + dynamic_dims.append(i) + else: + if dim: + resolved_shape.append(int(dim)) + + constructor = torch.randn if dtype.is_floating_point else torch.zeros + out = constructor(resolved_shape, dtype=dtype, device=device) # type: ignore[call-arg] + for d in dynamic_dims: + torch._dynamo.mark_dynamic(out, d) + return out + + # Parse function annotations for tensor generation + annotations = func.__annotations__ + for param, annotation in annotations.items(): + # Skip 'return' annotation + if param == "return": + continue + + match = re.search(tensor_regex, annotation) + if match: + data_type, shape_str = match.groups() + shape = tuple(shape_str.split(",")) + dtype = dtype_map[data_type] + kwargs[param] = gen_tensor(shape, dtype) + + match = re.search(sym_shape_regex, annotation) + if match: + kwargs[param] = get_sym_int(match.group(1)) + + if "self" in inspect.signature(func).parameters: + container = TensorContainer() + kwargs["self"] = container + for match in re.finditer(tensor_assignment_regex, source): + attr_name, data_type, shape_str, _ = match.groups() + shape = tuple(shape_str.split(",")) + dtype = dtype_map[data_type] + setattr(container, attr_name, gen_tensor(shape, dtype)) + + return kwargs + + +def profile_to_file(filename: str) -> Callable[[T], T]: + """ + Decorator to cProfile a given function and save the result to disk on process exit. 
+ + Args: + filename: filename to save profile to + """ + prof = cProfile.Profile() + filename = os.path.abspath(os.path.expanduser(filename)) + + def decorator(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + prof.enable() + try: + return fn(*args, **kwargs) + finally: + prof.disable() + + return wrapper + + def save_it(): + prof.dump_stats(filename) + sys.stderr.write( + textwrap.dedent( + f"""\ + Wrote profile to {filename}, view with: + + snakeviz {filename} + + """ + ) + ) + + atexit.register(save_it) + return decorator diff --git a/lib/python3.10/site-packages/torch/_dynamo/decorators.py b/lib/python3.10/site-packages/torch/_dynamo/decorators.py new file mode 100644 index 0000000000000000000000000000000000000000..235db547dc7b1c251a23a18e9911250ed23abfd5 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/decorators.py @@ -0,0 +1,580 @@ +# mypy: allow-untyped-defs +# ruff: noqa: TCH004 +import functools +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, Type, TYPE_CHECKING, TypeVar + +import torch +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from . import trace_rules, variables +from .comptime import comptime +from .eval_frame import DisableContext, innermost_fn, RunOnlyContext +from .exc import IncorrectUsage +from .external_utils import is_compiling +from .utils import is_function + + +if TYPE_CHECKING: + from types import FunctionType + + from torch._C._dynamo.eval_frame import ( # noqa: F401 + reset_code, + set_eval_frame, + set_guard_error_hook, + skip_code, + unsupported, + ) + + from .variables import VariableTracker +else: + for name in dir(torch._C._dynamo.eval_frame): + if name.startswith("__"): + continue + globals()[name] = getattr(torch._C._dynamo.eval_frame, name) + + +_F = TypeVar("_F", bound=Callable[..., Any]) + + +def run(fn=None): + """Don't do any dynamic compiles, just use prior optimizations""" + if fn is not None: + fn = innermost_fn(fn) + assert callable(fn) + return RunOnlyContext()(fn) + return RunOnlyContext() + + +def disable(fn=None, recursive=True): + """ + Decorator and context manager to disable TorchDynamo + + If recursive=True, Dynamo is completely skipped on the decorated function + frame as well as the recursively invoked functions. + + If recursive=False, Dynamo skips frames associated with the function code, + but still process recursively invoked frames. + """ + if recursive: + if fn is not None: + fn = innermost_fn(fn) + assert callable(fn) + return DisableContext()(fn) + return DisableContext() + else: + return skip(fn) + + +def skip(fn=None): + """ + Skip frames associated with the function code, but still process recursively + invoked frames + """ + if fn is None: + return skip + fn = innermost_fn(fn) + assert callable(fn) + skip_code(fn.__code__) + fn._torchdynamo_disable = True + return fn + + +def assume_constant_result(fn): + fn._dynamo_marked_constant = True + return fn + + +def allow_in_graph(fn): + """ + Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function + and instead directly write it to the graph when encountered. + + See :func:`torch.compiler.allow_in_graph`'s docstring for the full documentation + + WARNING: this API can be a footgun, please read the documentation carefully. 
+ """ + if isinstance(fn, (list, tuple)): + return [allow_in_graph(x) for x in fn] + assert callable(fn), "allow_in_graph expects a callable" + if trace_rules.lookup_callable(fn) != variables.TorchInGraphFunctionVariable: + trace_rules._disallowed_callable_ids.remove(id(fn)) + trace_rules._allowed_callable_ids.add(id(fn)) + return fn + + +def _disallow_in_graph_helper(throw_if_not_allowed): + def inner(fn): + if isinstance(fn, (list, tuple)): + return [disallow_in_graph(x) for x in fn] + assert callable(fn), "disallow_in_graph expects a callable" + if ( + throw_if_not_allowed + and trace_rules.lookup_callable(fn) + != variables.TorchInGraphFunctionVariable + and trace_rules.lookup(fn) != variables.TorchInGraphFunctionVariable + ): + raise IncorrectUsage( + "disallow_in_graph is expected to be used on an already allowed callable (like torch.* ops). " + "Allowed callables means callables that TorchDynamo puts as-is in the extracted graph." + ) + trace_rules._allowed_callable_ids.remove(id(fn)) + trace_rules._disallowed_callable_ids.add(id(fn)) + return fn + + return inner + + +def disallow_in_graph(fn): + """ + Customize which functions TorchDynamo will exclude in the generated + graph and force a graph break on. + :: + + torch._dynamo.disallow_in_graph(torch.sub) + + @torch._dynamo.optimize(...) + def fn(a): + x = torch.add(x, 1) + x = torch.sub(x, 1) + x = torch.add(x, 1) + return x + + fn(...) + + Will break the graph on `torch.sub`, and give two graphs each with a + single `torch.add()` op. + """ + return _disallow_in_graph_helper(throw_if_not_allowed=True)(fn) + + +@_disallow_in_graph_helper(throw_if_not_allowed=False) +def graph_break(): + """Force a graph break""" + + +def forbid_in_graph(fn): + """ + Customize which functions TorchDynamo will assert are not present while tracing. + + If you want a graph break on this function instead, use disallow_in_graph. + TODO(voz): We now have allow_in_graph, disallow_in_graph, forbid_in_graph - some more robust + documentation would not be amiss. + """ + if isinstance(fn, (list, tuple)): + return [forbid_in_graph(x) for x in fn] + assert callable(fn), "forbid_in_graph applies only to callables" + fn._dynamo_forbidden = True + return fn + + +def substitute_in_graph( + original_fn: _F, + *, + can_constant_fold_through: bool = False, + skip_signature_check: bool = False, + # type that is embedded in the Python interpreter + is_embedded_type: bool = False, # internal use only +) -> Callable[[_F], _F]: + """ + Register a polyfill handler for a function, usually a C function from the C extension, to be + used in place of the original function when inlining the original function in the graph. + + .. note:: + + The polyfill handler is only used when inlining the original function. It is not used when + the original function is called directly. In the eager mode, the decorated function calls + the performant C function rather than the polyfill handler. + + The polyfill handler is a function that will be called in place of the original function when + inlining the original function. The polyfill handler should have the same signature and the same + behavior as the original function. + + Args: + original_fn (callable): The original function, usually a C function, to register a polyfill + handler for. + can_constant_fold_through (bool, optional): Whether the polyfill handler can be constant + folded through. 
That is, if the polyfill handler is a pure function and its arguments + are constant, the result of the polyfill handler can be constant folded during the + compilation. Defaults to ``False``. + skip_signature_check (bool, optional): Whether to skip the signature check between the + original function and the polyfill handler. Defaults to ``False``. + + Returns: + A decorator that registers the polyfill handler for the original function. + + Example:: + + >>> # xdoctest: +SKIP("conflict with the tests: duplicate polyfill handlers") + >>> import operator + >>> operator.indexOf([1, 2, 3, 4, 5], 3) + 2 + >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3) + Traceback (most recent call last): + ... + torch._dynamo.exc.Unsupported: ... + + >>> @torch.compiler.substitute_in_graph(operator.indexOf) + ... def indexOf(a, b, /): + ... for i, item in enumerate(a): + ... if item is b or item == b: + ... return i + ... raise ValueError("sequence.index(x): x not in sequence") + >>> + >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3) + 2 + """ + if not is_function(original_fn) and not ( + is_embedded_type and inspect.isclass(original_fn) + ): + raise TypeError( + f"substitute_in_graph expects a function but got {type(original_fn)!r}" + ) + if is_embedded_type: + if not inspect.isclass(original_fn): + raise TypeError( + f"substitute_in_graph expects a class but got {type(original_fn)!r}" + ) + + from .variables.builder import ITERTOOLS_POLYFILLED_TYPE_IDS, ITERTOOLS_TYPE_IDS + + if id(original_fn) in ITERTOOLS_TYPE_IDS: + ITERTOOLS_POLYFILLED_TYPE_IDS.add(id(original_fn)) + + def wrapper(traceable_fn: _F) -> _F: + if not is_function(traceable_fn): + raise TypeError( + f"@substitute_in_graph(...) expects a function but got {type(traceable_fn)!r}" + ) + + if not skip_signature_check: + try: + original_sig = inspect.signature(original_fn) + except ValueError: + pass + else: + traceable_sig = inspect.signature(traceable_fn) + + def sig_ident(sig): + # Ignore annotations for parameters and return type + return ( + tuple( + p.name + for p in sig.parameters.values() + if ( + p.kind + not in { + p.KEYWORD_ONLY, + # the name of *args and **kwargs is not important + p.VAR_POSITIONAL, + p.VAR_KEYWORD, + } + ) + ), + { + p.name + for p in sig.parameters.values() + if p.kind == p.KEYWORD_ONLY + }, + { + p.name: p.default + for p in sig.parameters.values() + # the name of *args and **kwargs is not important + if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD} + }, + ) + + wildcard_sig = inspect.signature(lambda *args, **kwargs: None) + + if ( + sig_ident(original_sig) != sig_ident(traceable_sig) + and sig_ident(original_sig) != sig_ident(wildcard_sig) + and sig_ident(traceable_sig) != sig_ident(wildcard_sig) + ): + raise TypeError( + f"Signature mismatch between {original_fn} and {traceable_fn}: " + f"{original_sig} != {traceable_sig}" + ) + + from torch._dynamo.guards import GuardBuilder + from torch._dynamo.trace_rules import get_torch_obj_rule_map + from torch._dynamo.variables import PolyfilledFunctionVariable + from torch._dynamo.variables.builder import VariableBuilder + + id_dispatch_map = VariableBuilder._id_dispatch() + if id(original_fn) in id_dispatch_map: + raise ValueError( + f"Duplicate dispatch rule for {original_fn}: " + "already registered in VariableBuilder's id dispatch map" + ) + + rule_map: Dict[Any, Type[VariableTracker]] = get_torch_obj_rule_map() + if original_fn in rule_map: + raise ValueError( + f"Duplicate object {original_fn} with different rules: " + 
f"{PolyfilledFunctionVariable}, {rule_map[original_fn]}" + ) + + polyfill_handlers: Dict[Callable[..., Any], FunctionType] + polyfill_handlers = PolyfilledFunctionVariable._get_polyfill_handlers() + if original_fn in polyfill_handlers: + raise ValueError( + f"Duplicate polyfill handlers for {original_fn}: " + f"already handled by {polyfill_handlers[original_fn]}" + ) + + # Need to wrap the function because we may cannot assign __torch_dynamo_polyfill__ to a + # C++ function. + @functools.wraps(traceable_fn) + def wrapped(*args, **kwargs): + return original_fn(*args, **kwargs) + + def dispatch_fn(self, value: _F) -> PolyfilledFunctionVariable: + return PolyfilledFunctionVariable( + value, + source=self.source, + **self.install_guards(GuardBuilder.FUNCTION_MATCH), + ) + + id_dispatch_map[id(original_fn)] = id_dispatch_map[id(wrapped)] = dispatch_fn + rule_map[original_fn] = rule_map[wrapped] = PolyfilledFunctionVariable + polyfill_handlers[original_fn] = polyfill_handlers[wrapped] = wrapped # type: ignore[assignment] + + wrapped.__torch_dynamo_original__ = original_fn # type: ignore[attr-defined] + wrapped.__torch_dynamo_polyfill__ = traceable_fn # type: ignore[attr-defined] + wrapped.__torch_dynamo_can_constant_fold_through__ = can_constant_fold_through # type: ignore[attr-defined] + + return wrapped # type: ignore[return-value] + + return wrapper + + +# Helper function to flatten a tensor subclass and apply a function to +# all inner tensors that match the outer dim. Used to reduce duplication +# across the various marking APIs. +def _apply_func_to_inner_tensors_of_same_dim(func, t, *args, **kwargs): + assert is_traceable_wrapper_subclass(t) + + attrs, ctx = t.__tensor_flatten__() + assert isinstance(t, torch.Tensor) + for attr in attrs: + inner = getattr(t, attr) + if inner.dim() == t.dim(): + func(inner, *args, **kwargs) + + +@dataclass(frozen=True) +class _DimRange: + """ + This represents an dimension of a tensor and the corresponding + min and max values it can take. Don't create this + class directly; instead, use :func:`mark_dynamic`. + """ + + dim: int + min: int + max: int + + +@forbid_in_graph +def mark_unbacked(t, index): + """ + Mark a tensor as having an unbacked dim. This changes the semantics of operations, + we will always report the size does not equal zero/one, we will turn asserts + on this index into runtime asserts, and if you try to get the real value we will + raise an exception. In other words, we will treat this dimension as if it was + data dependent (we do not know anything about its value.) + """ + # You could have copied the mark_dynamic behavior but I'm not convinced + # it's what you want + assert not is_traceable_wrapper_subclass(t), "not implemented yet" + + if isinstance(index, int): + if not hasattr(t, "_dynamo_unbacked_indices"): + t._dynamo_unbacked_indices = set() + t._dynamo_unbacked_indices.add(index) + return + + assert isinstance(index, (list, tuple)) + for i in index: + mark_unbacked(t, i) + + +@forbid_in_graph +def mark_dynamic(t, index, *, min=None, max=None): + """ + Mark a tensor as having a dynamic dim and set corresponding min and max range for the dim. + + [Note - on the state of mark_dynamic] + + The behavior of having a dynamic dimension on a tensor is governed by a few factors: + + 1) torch._dynamo.config dynamic_shapes True or False. + a) dynamic_shapes=True - dynamic_shapes must be True for mark_dynamic to work. + a) dynamic_shapes=False - This config will raise an exception when used in conjunction with + mark_dynamic. 
We will eventually support this. + + 2) If the dimension is fully constrained - as in, it does not allow more than a single value + in both eager (torch.compile, torch._dynamo.optimize) mode and export mode (torch._dynamo.export), + we will raise an error + + 3) If the dimension is partially constrained - allowing at least 2 values but not the full unbounded + range of shapes, in eager we will pass it through, but export will raise an error. + + 4) Attempts to trace this function will explicitly raise. As such, all calls to mark_dynamic must be made + before torch.compile. + + """ + if is_traceable_wrapper_subclass(t): + # default behavior: mirror mark_dynamic() on all inner tensors with same dim as t + # TODO: Make this configurable via a supported public API + _apply_func_to_inner_tensors_of_same_dim( + mark_dynamic, t, index, min=min, max=max + ) + + if isinstance(index, int): + if not hasattr(t, "_dynamo_dynamic_indices"): + t._dynamo_dynamic_indices = set() + t._dynamo_dynamic_range = set() + # TODO(voz): Should we bounds check? + t._dynamo_dynamic_indices.add(index) + t._dynamo_dynamic_range.add(_DimRange(index, min, max)) + return + + assert isinstance(index, (list, tuple)) + for i in index: + mark_dynamic(t, i, min=min, max=max) + + +@forbid_in_graph +def maybe_mark_dynamic(t, index): + """ + Mark a tensor as having a dynamic dim, but don't enforce it (i.e., if this + dimension ends up getting specialized, don't error). + """ + if is_traceable_wrapper_subclass(t): + # default behavior: mirror maybe_mark_dynamic() on all inner tensors with same dim as t + # TODO: Make this configurable via a supported public API + _apply_func_to_inner_tensors_of_same_dim(maybe_mark_dynamic, t, index) + + if isinstance(index, int): + if not hasattr(t, "_dynamo_weak_dynamic_indices"): + t._dynamo_weak_dynamic_indices = set() + # TODO(voz): Should we bounds check? + t._dynamo_weak_dynamic_indices.add(index) + return + + assert isinstance(index, (list, tuple)) + for i in index: + maybe_mark_dynamic(t, i) + + +def mark_static(t, index=None): + """ + Mark a tensor as having a static dim or mark a nn module class as static. + + For tensors + =========== + This will prevent us from attempting to compile it dynamically + when dynamic=True; this can improve trace-time performance. + + This has lower precedence than mark_dynamic. + + Unlike mark_dynamic, this can be done inside a graph, in which case it + induces specialization on the tensor. + + For nn.Module classes + ===================== + For static nn.Module classes, TorchDynamo assumes that the module instance + attributes will not be modified after compilation. This will ensure that + TorchDynamo keeps integer attributes CONSTANT and not symints. + + From TorchDynamo implementation side, the instances of static-marked + nn.Module class will be converted to UnspecializedBuiltinNNModuleVariable, + which have the same properties. + + Note that we still have to guard on the attributes, because different + instances of the nn.Module can have different values of the attributes. The + key point here is that the attributes are static. 
+ """ + if is_compiling(): + if index is None: + for s in t.size(): + comptime.force_static(s) + else: + comptime.force_static(t.size(index)) + return + + if is_traceable_wrapper_subclass(t): + # default behavior: mirror mark_static() on all inner tensors with same dim as t + # TODO: Make this configurable via a supported public API + _apply_func_to_inner_tensors_of_same_dim(mark_static, t, index) + + if not isinstance(t, torch.Tensor) and issubclass(t, torch.nn.Module): + t._dynamo_marked_static = True + return t + + if not isinstance(t, torch.Tensor): + raise TypeError( + f"mark_static expects a tensor/nn.Module class but recieved {type(t)}" + ) + + if isinstance(index, int): + if not hasattr(t, "_dynamo_static_indices"): + t._dynamo_static_indices = set() # type: ignore[attr-defined] + # TODO(voz): Should we bounds check? + t._dynamo_static_indices.add(index) # type: ignore[attr-defined] + elif index is None: + for i in range(t.dim()): + mark_static(t, i) + else: + assert isinstance(index, (list, tuple)) + for i in index: + mark_static(t, i) + + +@forbid_in_graph +def mark_static_address(t, guard=True): + """ + Marks an input tensor whose data_ptr will not change across multiple calls + to a dynamo-compiled function. This indicates to cudagraphs that an extra allocation + is not needed for this input. The data_ptr will be guarded if guard=True. Note: + Tensors marked in this way will be kept alive until `torch._dynamo.reset()` is called. + """ + if not isinstance(t, torch.Tensor): + raise TypeError(f"mark_static_address expects a tensor but recieved {type(t)}") + + if guard: + t._dynamo_static_input_type = "guarded" # type: ignore[attr-defined] + else: + t._dynamo_static_input_type = "unguarded" # type: ignore[attr-defined] + + +# Note: this carefully avoids eagerly import einops. +# TODO: we should delete this whole _allow_in_graph_einops logic by approximately 2024 Q2 +def _allow_in_graph_einops(): + import einops + + try: + # requires einops > 0.6.1, torch >= 2.0 + from einops._torch_specific import ( # type: ignore[attr-defined] # noqa: F401 + _ops_were_registered_in_torchdynamo, + ) + + # einops > 0.6.1 will call the op registration logic as it is imported. 
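+ # If the import above succeeded, einops has already registered its ops with
+ # TorchDynamo, so no explicit allow_in_graph calls are needed.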
+ except ImportError: + # einops <= 0.6.1 + allow_in_graph(einops.rearrange) + allow_in_graph(einops.reduce) + if hasattr(einops, "repeat"): + allow_in_graph(einops.repeat) # available since einops 0.2.0 + if hasattr(einops, "einsum"): + allow_in_graph(einops.einsum) # available since einops 0.5.0 + if hasattr(einops, "pack"): + allow_in_graph(einops.pack) # available since einops 0.6.0 + if hasattr(einops, "unpack"): + allow_in_graph(einops.unpack) # available since einops 0.6.0 + + +trace_rules.add_module_init_func("einops", _allow_in_graph_einops) diff --git a/lib/python3.10/site-packages/torch/_dynamo/device_interface.py b/lib/python3.10/site-packages/torch/_dynamo/device_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..5670172c49c52c5f5db2e213d0428512fec83946 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/device_interface.py @@ -0,0 +1,330 @@ +# mypy: allow-untyped-defs +import inspect +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union + +import torch +from torch._streambase import _EventBase, _StreamBase + + +get_cuda_stream: Optional[Callable[[int], int]] +if torch.cuda._is_compiled(): + from torch._C import _cuda_getCurrentRawStream as get_cuda_stream +else: + get_cuda_stream = None + +_device_t = Union[torch.device, str, int, None] + +# Recording the device properties in the main process but used in worker process. +caching_worker_device_properties: Dict[str, Any] = {} +caching_worker_current_devices: Dict[str, int] = {} + + +class DeviceInterfaceMeta(type): + def __new__(metacls, *args, **kwargs): + class_member = args[2] + if "Event" in class_member: + assert inspect.isclass(class_member["Event"]) and issubclass( + class_member["Event"], _EventBase + ), "DeviceInterface member Event should be inherit from _EventBase" + if "Stream" in class_member: + assert inspect.isclass(class_member["Stream"]) and issubclass( + class_member["Stream"], _StreamBase + ), "DeviceInterface member Stream should be inherit from _StreamBase" + return super().__new__(metacls, *args, **kwargs) + + +class DeviceInterface(metaclass=DeviceInterfaceMeta): + """ + This is a simple device runtime interface for Inductor. It enables custom + backends to be integrated with Inductor in a device-agnostic semantic. + """ + + class device: + def __new__(cls, device: _device_t): + raise NotImplementedError + + class Worker: + """ + Worker API to query device properties that will work in multi processing + workers that cannot use the GPU APIs (due to processing fork() and + initialization time issues). Properties are recorded in the main process + before we fork the workers. 
+ """ + + @staticmethod + def set_device(device: int): + raise NotImplementedError + + @staticmethod + def current_device() -> int: + raise NotImplementedError + + @staticmethod + def get_device_properties(device: _device_t = None): + raise NotImplementedError + + @staticmethod + def current_device(): + raise NotImplementedError + + @staticmethod + def set_device(device: _device_t): + raise NotImplementedError + + @staticmethod + def maybe_exchange_device(device: int) -> int: + raise NotImplementedError + + @staticmethod + def exchange_device(device: int) -> int: + raise NotImplementedError + + @staticmethod + def device_count(): + raise NotImplementedError + + @staticmethod + def is_available() -> bool: + raise NotImplementedError + + @staticmethod + def stream(stream: torch.Stream): + raise NotImplementedError + + @staticmethod + def current_stream(): + raise NotImplementedError + + @staticmethod + def set_stream(stream: torch.Stream): + raise NotImplementedError + + @staticmethod + def _set_stream_by_id(stream_id: int, device_index: int, device_type: int): + raise NotImplementedError + + @staticmethod + def get_raw_stream(device_idx: int) -> int: + raise NotImplementedError + + @staticmethod + def synchronize(device: _device_t = None): + raise NotImplementedError + + @staticmethod + def get_device_properties(device: _device_t = None): + raise NotImplementedError + + @staticmethod + def get_compute_capability(device: _device_t = None): + raise NotImplementedError + + @staticmethod + def is_bf16_supported(including_emulation: bool = False): + raise NotImplementedError + + +class DeviceGuard: + """ + This class provides a context manager for device switching. This is a stripped + down version of torch.{device_name}.device. + + The context manager changes the current device to the given device index + on entering the context and restores the original device on exiting. + The device is switched using the provided device interface. 
+ """ + + def __init__( + self, device_interface: Type[DeviceInterface], index: Optional[int] + ) -> None: + self.device_interface = device_interface + self.idx = index + self.prev_idx = -1 + + def __enter__(self): + if self.idx is not None: + self.prev_idx = self.device_interface.exchange_device(self.idx) + + def __exit__(self, type: Any, value: Any, traceback: Any): + if self.idx is not None: + self.idx = self.device_interface.maybe_exchange_device(self.prev_idx) + return False + + +class CudaInterface(DeviceInterface): + device = torch.cuda.device + + # register Event and Stream class into the backend interface + # make sure Event and Stream are implemented and inherited from the _EventBase and _StreamBase + Event = torch.cuda.Event + Stream = torch.cuda.Stream + + class Worker: + @staticmethod + def set_device(device: int): + caching_worker_current_devices["cuda"] = device + + @staticmethod + def current_device() -> int: + if "cuda" in caching_worker_current_devices: + return caching_worker_current_devices["cuda"] + return torch.cuda.current_device() + + @staticmethod + def get_device_properties(device: _device_t = None): + if device is not None: + if isinstance(device, str): + device = torch.device(device) + assert device.type == "cuda" + if isinstance(device, torch.device): + device = device.index + if device is None: + device = CudaInterface.Worker.current_device() + + if "cuda" not in caching_worker_device_properties: + device_prop = [ + torch.cuda.get_device_properties(i) + for i in range(torch.cuda.device_count()) + ] + caching_worker_device_properties["cuda"] = device_prop + + return caching_worker_device_properties["cuda"][device] + + current_device = staticmethod(torch.cuda.current_device) + set_device = staticmethod(torch.cuda.set_device) + device_count = staticmethod(torch.cuda.device_count) + stream = staticmethod(torch.cuda.stream) # type: ignore[assignment] + current_stream = staticmethod(torch.cuda.current_stream) + set_stream = staticmethod(torch.cuda.set_stream) # type: ignore[assignment] + _set_stream_by_id = staticmethod(torch.cuda._set_stream_by_id) # type: ignore[assignment] + synchronize = staticmethod(torch.cuda.synchronize) + get_device_properties = staticmethod(torch.cuda.get_device_properties) # type: ignore[assignment] + get_raw_stream = staticmethod(get_cuda_stream) # type: ignore[assignment, arg-type] + exchange_device = staticmethod(torch.cuda._exchange_device) # type: ignore[arg-type] + maybe_exchange_device = staticmethod(torch.cuda._maybe_exchange_device) # type: ignore[arg-type] + is_bf16_supported = staticmethod(torch.cuda.is_bf16_supported) # type: ignore[arg-type] + + # Can be mock patched by @patch decorator. 
+ @staticmethod + def is_available() -> bool: + return torch.cuda.is_available() + + @staticmethod + def get_compute_capability(device: _device_t = None): + if torch.version.hip is None: + major, min = torch.cuda.get_device_capability(device) + return major * 10 + min + else: + return torch.cuda.get_device_properties(device).gcnArchName.split(":", 1)[0] + + +get_xpu_stream: Optional[Callable[[int], int]] +if torch.xpu._is_compiled(): + from torch._C import _xpu_getCurrentRawStream as get_xpu_stream +else: + get_xpu_stream = None + + +class XpuInterface(DeviceInterface): + device = torch.xpu.device + Event = torch.xpu.Event + Stream = torch.xpu.Stream + + class Worker: + @staticmethod + def set_device(device: int): + caching_worker_current_devices["xpu"] = device + + @staticmethod + def current_device() -> int: + if "xpu" in caching_worker_current_devices: + return caching_worker_current_devices["xpu"] + return torch.xpu.current_device() + + @staticmethod + def get_device_properties(device: _device_t = None): + if device is not None: + if isinstance(device, str): + device = torch.device(device) + assert device.type == "xpu" + if isinstance(device, torch.device): + device = device.index + if device is None: + device = XpuInterface.Worker.current_device() + + if "xpu" not in caching_worker_device_properties: + device_prop = [ + torch.xpu.get_device_properties(i) + for i in range(torch.xpu.device_count()) + ] + caching_worker_device_properties["xpu"] = device_prop + + return caching_worker_device_properties["xpu"][device] + + current_device = staticmethod(torch.xpu.current_device) + set_device = staticmethod(torch.xpu.set_device) + device_count = staticmethod(torch.xpu.device_count) + stream = staticmethod(torch.xpu.stream) # type: ignore[assignment] + current_stream = staticmethod(torch.xpu.current_stream) + set_stream = staticmethod(torch.xpu.set_stream) # type: ignore[assignment] + _set_stream_by_id = staticmethod(torch.xpu._set_stream_by_id) # type: ignore[assignment] + synchronize = staticmethod(torch.xpu.synchronize) + get_device_properties = staticmethod(torch.xpu.get_device_properties) # type: ignore[assignment] + get_raw_stream = staticmethod(get_xpu_stream) # type: ignore[assignment, arg-type] + exchange_device = staticmethod(torch.xpu._exchange_device) # type: ignore[arg-type] + maybe_exchange_device = staticmethod(torch.xpu._maybe_exchange_device) # type: ignore[arg-type] + + # Can be mock patched by @patch decorator. 
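As a small aside on CudaInterface.get_compute_capability above: on a non-ROCm build the (major, minor) tuple is packed into a single integer, while ROCm builds return the gcnArchName prefix. A hedged sketch of consuming that value (assumes a visible CUDA device; the threshold is only an example):

# Illustrative sketch: branching on the packed compute capability.
import torch
from torch._dynamo.device_interface import get_interface_for_device

if torch.cuda.is_available():
    iface = get_interface_for_device("cuda")
    cc = iface.get_compute_capability(0)
    if torch.version.hip is None:
        # e.g. capability (8, 6) on an RTX 3090 is reported as the integer 86
        is_ampere_or_newer = isinstance(cc, int) and cc >= 80
    else:
        # e.g. "gfx90a"; the ":" feature suffix of gcnArchName has been stripped
        arch_name = str(cc)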
+ @staticmethod + def is_available() -> bool: + return torch.xpu.is_available() + + @staticmethod + def get_compute_capability(device: _device_t = None): + cc = torch.xpu.get_device_capability(device) + return cc + + @staticmethod + def is_bf16_supported(including_emulation: bool = False) -> bool: + return torch.xpu.is_bf16_supported() + + +device_interfaces: Dict[str, Type[DeviceInterface]] = {} +_device_initialized = False + + +def register_interface_for_device( + device: Union[str, torch.device], device_interface: Type[DeviceInterface] +): + if isinstance(device, torch.device): + device = str(device) + device_interfaces[device] = device_interface + + +def get_interface_for_device(device: Union[str, torch.device]) -> Type[DeviceInterface]: + if isinstance(device, torch.device): + device = str(device) + if not _device_initialized: + init_device_reg() + if device in device_interfaces: + return device_interfaces[device] + raise NotImplementedError(f"No interface for device {device}") + + +def get_registered_device_interfaces() -> Iterable[Tuple[str, Type[DeviceInterface]]]: + if not _device_initialized: + init_device_reg() + return device_interfaces.items() + + +def init_device_reg(): + global _device_initialized + register_interface_for_device("cuda", CudaInterface) + for i in range(torch.cuda.device_count()): + register_interface_for_device(f"cuda:{i}", CudaInterface) + + register_interface_for_device("xpu", XpuInterface) + for i in range(torch.xpu.device_count()): + register_interface_for_device(f"xpu:{i}", XpuInterface) + + _device_initialized = True diff --git a/lib/python3.10/site-packages/torch/_dynamo/distributed.py b/lib/python3.10/site-packages/torch/_dynamo/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..90a1376c7a13b83891cf3b80fc51b14fca13c3f7 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/distributed.py @@ -0,0 +1,25 @@ +from typing import Optional + +import torch.distributed as dist + +from . import config + + +_COMPILE_PG: Optional[dist.ProcessGroup] = None + + +def get_compile_pg() -> Optional[dist.ProcessGroup]: + if ( + config.enable_compiler_collectives + and dist.is_available() + and dist.is_initialized() + ): + global _COMPILE_PG + if _COMPILE_PG is None: + # , timeout=datetime.timedelta(seconds=2) + _COMPILE_PG = dist.distributed_c10d._new_group_with_tag( + pg_tag="pt2_compile_pg" + ) + return _COMPILE_PG + + return None diff --git a/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py b/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py new file mode 100644 index 0000000000000000000000000000000000000000..c04e4ccb00a97a7508320d1be7e446516eae0421 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py @@ -0,0 +1,1717 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code="method-assign" + +""" +Functions in this file are responsible for modifying the eval frame +handler at RUNTIME. Therefore, all functions in this file are hot. +Functions that only execute at compile time should be placed +in torch._dynamo.convert_frame. 
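Before moving on, a short sketch of the lazy registry that device_interface.py builds up via init_device_reg (illustrative; the exact entries depend on which runtimes are compiled in and how many devices are visible):

# Illustrative sketch: enumerating the interfaces Dynamo/Inductor know about.
from torch._dynamo.device_interface import get_registered_device_interfaces

for device_name, iface in get_registered_device_interfaces():
    # e.g. ("cuda", CudaInterface), ("cuda:0", CudaInterface), ("xpu", XpuInterface), ...
    print(device_name, iface.__name__, "available:", iface.is_available())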
+""" + +from __future__ import annotations + +import contextlib +import functools +import inspect +import logging +import os +import sys +import textwrap +import traceback +import types +import warnings +import weakref +from enum import Enum +from os.path import dirname, join +from typing import ( + Any, + Callable, + Dict, + List, + NamedTuple, + Optional, + Set, + Tuple, + TYPE_CHECKING, + Union, +) +from unittest.mock import patch + +import sympy + +import torch +import torch.fx +import torch.utils._pytree as pytree +import torch.utils.checkpoint +from torch import _guards + +# see discussion at https://github.com/pytorch/pytorch/issues/120699 +from torch._C._dynamo.eval_frame import ( # noqa: F401 + reset_code, + set_guard_error_hook, + skip_code, + unsupported, +) +from torch._dispatch.python import enable_python_dispatcher +from torch._subclasses.fake_tensor import unset_fake_temporarily +from torch._utils_internal import justknobs_check, log_export_usage +from torch.export.dynamic_shapes import _combine_args, _process_dynamic_shapes +from torch.fx import GraphModule +from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.experimental.symbolic_shapes import ( + ConstraintViolationError, + DimDynamic, + ShapeEnv, + StatelessSymbolicContext, +) +from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo + +from . import config, convert_frame, external_utils, trace_rules, utils +from .backends.registry import CompilerFn, lookup_backend +from .code_context import code_context +from .exc import CondOpArgsMismatchError, UserError, UserErrorType +from .hooks import Hooks +from .mutation_guard import install_generation_tagging_init +from .utils import common_constant_types, compile_times + + +if TYPE_CHECKING: + from torch._subclasses import fake_tensor + + from .types import CacheEntry, DynamoCallback + + +log = logging.getLogger(__name__) + + +always_optimize_code_objects = utils.ExactWeakKeyDictionary() +null_context = contextlib.nullcontext + + +# See https://github.com/python/typing/pull/240 +class Unset(Enum): + token = 0 + + +cached_backends: Dict[int, CompilerFn] = {} + +unset = Unset.token + + +def _maybe_set_eval_frame(callback: DynamoCallback): + # A wrapper on set_eval_frame that is guarded by a Justknob. + # Users can disable torchDynamo by setting the JK to False. + from torch._C._dynamo.eval_frame import set_eval_frame + + if not justknobs_check("pytorch/compiler:enable_compiler_set_eval_frame"): + torch._dynamo.utils.warn_once( + "Dynamo disabled by Justknob: enable_compiler_set_eval_frame, skipping set_eval_frame" + ) + return callback + else: + return set_eval_frame(callback) + + +def _reset_guarded_backend_cache(): + global cached_backends + for backend in cached_backends.values(): + if hasattr(backend, "reset"): + backend.reset() + cached_backends.clear() + + +DONT_WRAP_FILES = { + # For tracing into fx modules + inspect.getsourcefile(GraphModule), + join(dirname(dirname(__file__)), "onnx/_internal/fx/dynamo_graph_extractor.py"), +} + + +def _debug_get_cache_entry_list( + code: Union[types.CodeType, Callable[..., Any]] +) -> List[CacheEntry]: + """ + Given a code object or a callable object, retrieve the cache entries + stored in this code. + """ + if callable(code): + code = code.__code__ + return torch._C._dynamo.eval_frame._debug_get_cache_entry_list(code) + + +class OptimizedModule(torch.nn.Module): + """ + Wraps the original nn.Module object and later patches its + forward method to optimized self.forward method. 
+ """ + + _torchdynamo_orig_callable: Callable[..., Any] + get_compiler_config: Callable[[], Any] + + _opt_mod_attributes = { + "_orig_mod", + "dynamo_ctx", + "_torchdynamo_orig_callable", + "get_compiler_config", + "forward", + "_forward", + "__dict__", + "named_children_walk", + } + + def __init__(self, mod: torch.nn.Module, dynamo_ctx) -> None: + super().__init__() + # Installs the params/buffer + self._orig_mod = mod + self.dynamo_ctx = dynamo_ctx + self._initialize() + self.training = self._orig_mod.training + + def _initialize(self): + # Do this stuff in constructor to lower overhead slightly + if isinstance(self.dynamo_ctx, DisableContext): + # No need to check trace rules + self.forward = self.dynamo_ctx(self._orig_mod.__call__) + elif isinstance(self._orig_mod.forward, types.MethodType) and ( + trace_rules.check(self._orig_mod.forward) + or getattr(self._orig_mod, "_is_fsdp_managed_module", False) + ): + # This may be a torch.nn.* instance in trace_rules.py which + # won't trigger a frame evaluation workaround to add an extra + # frame we can capture + self.forward = self.dynamo_ctx(external_utils.wrap_inline(self._orig_mod)) + else: + # Invoke hooks outside of dynamo then pickup the inner frame + self.forward = self.dynamo_ctx(self._orig_mod.__call__) + + if hasattr(self._orig_mod, "_initialize_hook"): + self._forward = self.forward + self.forward = self._call_lazy_check + + def __reduce__(self): + return (self.__class__, (self._orig_mod, self.dynamo_ctx)) + + def __getstate__(self): + state = dict(self.__dict__) + state.pop("forward", None) + state.pop("__call__", None) + return state + + def __setstate__(self, state): + self.__dict__ = state + self._initialize() + + @property + def training(self): + return self._orig_mod.training + + @training.setter + def training(self, value): + try: + super().__getattr__("_orig_mod") + self._orig_mod.training = value + except AttributeError: + # still initializing + pass + + def __getattr__(self, name): + if name == "_orig_mod": + return self._modules["_orig_mod"] + return getattr(self._orig_mod, name) + + def __setattr__(self, name, val) -> None: + # Allow patching over class attributes + if hasattr(type(self), name): + return super().__setattr__(name, val) + + if name in OptimizedModule._opt_mod_attributes: + return super().__setattr__(name, val) + return setattr(self._orig_mod, name, val) + + def _call_lazy_check(self, *args, **kwargs): + if hasattr(self._orig_mod, "_initialize_hook"): + # In the case of a lazy module, we want to run + # the pre-hooks which initialize it. + # Afterwards, lazy module deletes its pre-hooks + # to avoid treating it as lazy on subsequent recompile. + self._orig_mod._infer_parameters(self._orig_mod, args, kwargs) + return self._forward(*args, **kwargs) + + def __dir__(self): + orig_mod_attrs = self._orig_mod.__dir__() + return orig_mod_attrs + [ + attr for attr in super().__dir__() if attr not in orig_mod_attrs + ] + + +def remove_from_cache(f): + """ + Make sure f.__code__ is not cached to force a recompile + """ + if isinstance(f, types.CodeType): + reset_code(f) + elif hasattr(f, "__code__"): + reset_code(f.__code__) + elif hasattr(getattr(f, "forward", None), "__code__"): + reset_code(f.forward.__code__) + else: + from . 
import reset # type: ignore[attr-defined] + + reset() + log.warning("could not determine __code__ for %s", f) + + +def nothing(): + pass + + +def always_false(): + return False + + +def innermost_fn(fn): + """ + In case of nesting of _TorchDynamoContext calls, find the innermost + function. TorchDynamo caches on fn.__code__ object, so its necessary to find + the innermost function to pass on the optimize, run, disable etc. + """ + unaltered_fn = fn + while hasattr(unaltered_fn, "_torchdynamo_orig_callable"): + unaltered_fn = unaltered_fn._torchdynamo_orig_callable + assert callable(unaltered_fn) + return unaltered_fn + + +def make_set_enable_dynamic(enable: bool): + assert isinstance(enable, bool) + if enable: + # Assume everything is dynamic by default + return config._make_closure_patcher(assume_static_by_default=False) + else: + return config._make_closure_patcher( + automatic_dynamic_shapes=False, assume_static_by_default=True + ) + + +class _TorchDynamoContext: + def __init__( + self, + callback: DynamoCallback, + on_enter=nothing, + backend_ctx_ctor=null_context, + patch_fn=nothing, + first_ctx=False, + *, + export=False, + dynamic=None, + compiler_config=None, + ) -> None: + super().__init__() + assert callable(callback) or callback is False or callback is None + self.callback: DynamoCallback = callback + self._backend_ctx_ctor = backend_ctx_ctor + self.prior: Union[Unset, DynamoCallback] = unset + self.first_ctx = first_ctx + self.export = export + self._dynamic = dynamic + self.compiler_config = compiler_config + self.cleanup_fns: List[Callable[[], Any]] = [] + self.enter_exit_hooks = [] + patch_fn() + + # Save the backends so that we can reset them during torch._dynamo.reset + backend = innermost_fn(callback) + cached_backends.setdefault(id(backend), backend) + + if dynamic is not None: + self.enter_exit_hooks.append(make_set_enable_dynamic(dynamic)) + + if on_enter is not nothing: + # this case is not common + def call_on_enter(): + on_enter() + return nothing + + self.enter_exit_hooks.append(call_on_enter) + + if backend_ctx_ctor is not contextlib.nullcontext: + # this case is not common + def call_backend_ctx(): + ctx = backend_ctx_ctor() + ctx.__enter__() + return functools.partial(ctx.__exit__, None, None, None) + + self.enter_exit_hooks.append(call_backend_ctx) + + def __enter__(self): + if config.raise_on_ctx_manager_usage: + raise RuntimeError( + "torch._dynamo.optimize(...) is used with a context manager. " + "Please refer to https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html " + "to use torch._dynamo.optimize(...) as an annotation/decorator. 
" + ) + self.cleanup_fns = [enter() for enter in self.enter_exit_hooks] + self.prior = _maybe_set_eval_frame(self.callback) + + def __exit__(self, exc_type, exc_val, exc_tb): + assert self.prior is not unset + _maybe_set_eval_frame(self.prior) + self.prior = unset + for cleanup in self.cleanup_fns: + cleanup() + self.cleanup_fns.clear() + + def __call__(self, fn): + # public api for compiler config/options + def get_compiler_config(): + return self.compiler_config + + fn = innermost_fn(fn) + + # add context containing GraphModule to any GraphModule forward functions + if isinstance(fn, GraphModule): + # add context containing GraphModule to any GraphModule forward functions + code_context.get_context(fn.forward.__code__)[ + "orig_graphmodule" + ] = weakref.ref(fn) + + # Optimize the forward method of torch.nn.Module object + if isinstance(fn, torch.nn.Module): + mod = fn + new_mod = OptimizedModule(mod, self) + # Save the function pointer to find the original callable while nesting + # of decorators. + new_mod._torchdynamo_orig_callable = mod.forward + + # when compiling torch.nn.Module, + # provide public api OptimizedModule.get_compiler_config() + assert not hasattr(new_mod, "get_compiler_config") + new_mod.get_compiler_config = get_compiler_config + + return new_mod + + if inspect.isclass(fn): + # User has wrapped the class with compile/disable decorator. Apply + # disable to init/call method. + cls_obj = fn + cls_obj.__call__ = self(cls_obj.__call__) + if issubclass(cls_obj, torch.nn.Module): + # NN module variable tracker directly inlines the _call_impl. + cls_obj._call_impl = self(cls_obj._call_impl) + return cls_obj + + assert callable(fn) + + try: + filename = inspect.getsourcefile(fn) + except TypeError: + filename = None + if ( + (filename is None or trace_rules.check(fn)) + and ( + getattr(fn, "__name__", "") + not in ["_call_impl", "_wrapped_call_impl", "_lazy_forward"] + ) + and filename not in DONT_WRAP_FILES + ): + # call to a builtin without a frame for us to capture + fn = external_utils.wrap_inline(fn) + + def do_nothing(*arg, **kwargs): + pass + + if hasattr(self, "callback"): + callback = self.callback + else: + callback = do_nothing + + is_jit_tracing = torch._C._is_tracing + is_fx_tracing = torch.fx._symbolic_trace.is_fx_tracing + + @functools.wraps(fn) + def _fn(*args, **kwargs): + if is_fx_tracing(): + if config.error_on_nested_fx_trace: + raise RuntimeError( + "Detected that you are using FX to symbolically trace " + "a dynamo-optimized function. This is not supported at the moment." + ) + else: + return fn(*args, **kwargs) + + if is_jit_tracing(): + if config.error_on_nested_jit_trace: + raise RuntimeError( + "Detected that you are using FX to torch.jit.trace " + "a dynamo-optimized function. This is not supported at the moment." + ) + else: + return fn(*args, **kwargs) + + cleanups = [enter() for enter in self.enter_exit_hooks] + prior = _maybe_set_eval_frame(callback) + + # Ensure that if an assertion occurs after graph pushes + # something onto the DynamicLayerStack then we pop it off (the + # constructed graph code isn't guarded with try/finally). + # + # This used to be a context but putting a `with` here is a noticible + # perf regression (#126293) + saved_dynamic_layer_stack_depth = ( + torch._C._functorch.get_dynamic_layer_stack_depth() + ) + + try: + return fn(*args, **kwargs) + finally: + # Restore the dynamic layer stack depth if necessary. 
+ torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth( + saved_dynamic_layer_stack_depth + ) + + _maybe_set_eval_frame(prior) + for cleanup in cleanups: + cleanup() + + # hooks to properly handle inlining + _fn._torchdynamo_inline = fn # type: ignore[attr-defined] + + # Save the function pointer to find the original callable while nesting + # of decorators. + _fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined] + + # when compiling user function instead of nn.Module + # provide public api _fn.get_compiler_config() + assert not hasattr(_fn, "get_compiler_config") + _fn.get_compiler_config = get_compiler_config # type: ignore[attr-defined] + + # If the function is called using torch._dynamo.optimize decorator, we + # should prevent any type of skipping. + if callback not in (None, False): + if not hasattr(fn, "__code__"): + raise RuntimeError( + textwrap.dedent( + """ + + torch._dynamo.optimize is called on a non function object. + If this is a callable class, please wrap the relevant code into a function and optimize the + wrapper function. + + >> class CallableClass: + >> def __init__(self) -> None: + >> super().__init__() + >> self.relu = torch.nn.ReLU() + >> + >> def __call__(self, x): + >> return self.relu(torch.sin(x)) + >> + >> def print_hello(self): + >> print("Hello world") + >> + >> mod = CallableClass() + + If you want to optimize the __call__ function and other code, wrap that up in a function + + >> def wrapper_fn(x): + >> y = mod(x) + >> return y.sum() + + and then optimize the wrapper_fn + + >> opt_wrapper_fn = torch._dynamo.optimize(wrapper_fn) + """ + ) + ) + always_optimize_code_objects[fn.__code__] = True + + return _fn + + +class OptimizeContext(_TorchDynamoContext): + def __init__( + self, + callback, + backend_ctx_ctor, + first_ctx=False, + *, + export=False, + dynamic=None, + compiler_config=None, + rebuild_ctx: Optional[ + Callable[[], Union[OptimizeContext, _NullDecorator]] + ] = None, + ) -> None: + def on_enter(): + install_generation_tagging_init() + + super().__init__( + callback=callback, + on_enter=on_enter, + backend_ctx_ctor=backend_ctx_ctor, + patch_fn=TorchPatcher.patch, + first_ctx=first_ctx, + export=export, + dynamic=dynamic, + compiler_config=compiler_config, + ) + + if config.compiled_autograd: + + def call_compiled_autograd(): + assert rebuild_ctx is not None + compiler_fn = rebuild_ctx() + ctx = torch._dynamo.compiled_autograd.enable(compiler_fn) + ctx.__enter__() + return functools.partial(ctx.__exit__, None, None, None) + + self.enter_exit_hooks.append(call_compiled_autograd) + + def __reduce__(self): + return ( + self.__class__, + (self.callback, self._backend_ctx_ctor, self.first_ctx), + { + "export": self.export, + "dynamic": self._dynamic, + "compiler_config": self.compiler_config, + }, + ) + + +class RunOnlyContext(_TorchDynamoContext): + def __init__(self) -> None: + # cudagraph trees relies on generation increment + def on_enter(): + torch._dynamo.mutation_guard.GenerationTracker.generation += 1 + + super().__init__(callback=False, on_enter=on_enter) + + def __reduce__(self): + return (self.__class__, ()) + + +class DisableContext(_TorchDynamoContext): + def __init__(self) -> None: + super().__init__(callback=None) + + def __call__(self, fn): + # Earlier this code was in the base class _TorchDynamoContext. But we + # moved it here to have better code organization. For disable, we just + # want the callback to be None. We don't have to check trace_rules or + # create any wrapper. 
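DisableContext is the machinery behind the public torch._dynamo.disable decorator; a brief usage sketch (the function names are arbitrary):

# Illustrative sketch: keep one helper out of compilation while compiling its caller.
import torch


@torch._dynamo.disable
def log_stats(x):
    # Runs eagerly; Dynamo does not trace into this function.
    print("mean:", x.mean().item())
    return x


@torch.compile
def fn(x):
    x = x * 2
    return log_stats(x) + 1  # calling the disabled helper causes a graph break here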
+ fn = innermost_fn(fn) + + if isinstance(fn, torch.nn.Module): + mod = fn + new_mod = OptimizedModule(mod, self) + new_mod._torchdynamo_orig_callable = mod.forward + return new_mod + + if inspect.isclass(fn): + # User has wrapped the class with compile/disable decorator. Apply + # disable to init/call method. + cls_obj = fn + # Disable on init is useful for reconstruction of bytecodes where we + # want to prevent Dynamo from tracing into the init function. Check + # test_reconstruction in test_model_output.py. + cls_obj.__init__ = self(cls_obj.__init__) + cls_obj.__call__ = self(cls_obj.__call__) + if issubclass(cls_obj, torch.nn.Module): + # NN module variable tracker directly inlines the _call_impl. Disable it. + cls_obj._call_impl = self(cls_obj._call_impl) + return cls_obj + + assert callable(fn) + + callback = self.callback + + @functools.wraps(fn) + def _fn(*args, **kwargs): + prior = _maybe_set_eval_frame(callback) + try: + return fn(*args, **kwargs) + finally: + _maybe_set_eval_frame(prior) + + _fn._torchdynamo_disable = True # type: ignore[attr-defined] + + # Save the function pointer to find the original callable while nesting + # of decorators. + _fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined] + + return _fn + + def __reduce__(self): + return (self.__class__, ()) + + +def _optimize_catch_errors( + compile_fn, + hooks: Hooks, + backend_ctx_ctor=null_context, + export=False, + dynamic=None, + compiler_config=None, + rebuild_ctx=None, +): + return OptimizeContext( + convert_frame.catch_errors_wrapper(compile_fn, hooks), + backend_ctx_ctor=backend_ctx_ctor, + first_ctx=True, + export=export, + dynamic=dynamic, + compiler_config=compiler_config, + rebuild_ctx=rebuild_ctx, + ) + + +def get_compiler_fn(compiler_fn): + from .repro.after_dynamo import wrap_backend_debug + + if hasattr(compiler_fn, "compiler_name"): + compiler_str = compiler_fn.compiler_name + elif isinstance(compiler_fn, str): + compiler_str = compiler_fn + else: + compiler_str = None + compiler_fn = lookup_backend(compiler_fn) + return wrap_backend_debug(compiler_fn, compiler_str) + + +class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg] + def __call__(self, fn): + assert callable(fn) + return fn + + +def check_if_dynamo_supported(): + if sys.version_info >= (3, 13): + raise RuntimeError("Python 3.13+ not yet supported for torch.compile") + + +def is_dynamo_supported(): + try: + check_if_dynamo_supported() + return True + except Exception: + return False + + +def check_if_inductor_supported(): + check_if_dynamo_supported() + + +def is_inductor_supported(): + try: + check_if_inductor_supported() + return True + except Exception: + return False + + +def optimize(*args, **kwargs): + def rebuild_ctx(): + return optimize(*args, **kwargs) + + return _optimize(rebuild_ctx, *args, **kwargs) + + +def _optimize( + rebuild_ctx: Callable[[], Union[OptimizeContext, _NullDecorator]], + backend="inductor", + *, + nopython=False, + guard_export_fn=None, + guard_fail_fn=None, + disable=False, + dynamic=None, +) -> Union[OptimizeContext, _NullDecorator]: + """ + The main entrypoint of TorchDynamo. Do graph capture and call + backend() to optimize extracted graphs. + + Args: + backend: One of the two things: + - Either, a function/callable taking a torch.fx.GraphModule and + example_inputs and returning a python callable that runs the + graph faster. + One can also provide additional context for the backend, like + torch.jit.fuser("fuser2"), by setting the backend_ctx_ctor attribute. 
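A minimal backend matching the contract just described, i.e. a callable taking a GraphModule and example inputs and returning something runnable (the name my_debug_backend is made up for illustration):

# Illustrative sketch: a pass-through backend that prints each captured graph.
from typing import List

import torch


def my_debug_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    gm.print_readable()  # inspect the FX graph Dynamo captured
    return gm.forward  # return a callable that simply runs the graph as-is


@torch._dynamo.optimize(my_debug_backend)
def fn(x):
    return torch.sin(x) + torch.cos(x)


fn(torch.randn(8))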
+ See AOTAutogradMemoryEfficientFusionWithContext for the usage. + - Or, a string backend name in `torch._dynamo.list_backends()` + nopython: If True, graph breaks will be errors and there will + be a single whole-program graph. + disable: If True, turn this decorator into a no-op + dynamic: If True, upfront compile as dynamic a kernel as possible. If False, + disable all dynamic shapes support (always specialize). If None, automatically + detect when sizes vary and generate dynamic kernels upon recompile. + + Example Usage:: + + @torch._dynamo.optimize() + def toy_example(a, b): + ... + """ + check_if_dynamo_supported() + # Note: The hooks object could be global instead of passed around, *however* that would make + # for a confusing API usage and plumbing story wherein we nest multiple .optimize calls. + # There is some prior art around this, w/r/t nesting backend calls are enforced to be the same + # compiler, however, this feels onerous for callback and hooks, and it feels better to give our users an + # easier to understand UX at the cost of a little more plumbing on our end. + hooks = Hooks(guard_export_fn=guard_export_fn, guard_fail_fn=guard_fail_fn) + torch._C._log_api_usage_once("torch._dynamo.optimize") + if ( + disable + or os.environ.get("TORCHDYNAMO_DISABLE", "") == "1" + or (not justknobs_check("pytorch/compiler:enable_dynamo")) + ): + return _NullDecorator() + + backend = get_compiler_fn(backend) + + # Find if backend has any extra context manager + backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context) + + if nopython: + return optimize_assert( + backend, + dynamic=dynamic, + hooks=hooks, + rebuild_ctx=rebuild_ctx, + ) + # The backend function is stashed in the callable returned by + # _optimize_catch_errors in the field _torchdynamo_orig_callable. This can + # be used by eval_frame.c to insert a guard on the backend. + return _optimize_catch_errors( + convert_frame.convert_frame(backend, hooks=hooks), + hooks, + backend_ctx_ctor, + dynamic=dynamic, + compiler_config=backend.get_compiler_config() + if hasattr(backend, "get_compiler_config") + else None, + rebuild_ctx=rebuild_ctx, + ) + + +# TODO(voz): Consider making "explain" output alongside a run / part of a run +@patch("torch._dynamo.symbolic_convert.explain", True) +def explain(f, *extra_args, **extra_kwargs): + def inner(*args, **kwargs): + # TODO(voz): Do we want a decorator for this? + from . import reset # type: ignore[attr-defined] + + reset() + + graphs: List[torch.fx.GraphModule] = [] + break_reasons: List[Any] = [] + op_count: int = 0 + ops_per_graph: List[torch.fx.Node] = [] + out_guards: List[_guards.Guard] = [] + + def dynamo_graph_accumulating_compiler( + gm: torch.fx.GraphModule, example_inputs + ): + from .backends.debugging import _explain_graph_detail + + nonlocal graphs + nonlocal op_count + nonlocal ops_per_graph + nonlocal break_reasons + + gm, graphs, op_count, ops_per_graph, break_reasons = _explain_graph_detail( + gm, graphs, op_count, ops_per_graph, break_reasons + ) + + return gm.forward + + def guard_export_print(guards): + nonlocal out_guards + out_guards.extend(guards) + + opt_f = optimize( + dynamo_graph_accumulating_compiler, + nopython=False, + guard_export_fn=guard_export_print, + )(f) + # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideeffects and reject. + opt_f(*args, **kwargs) + + graph_count = len(graphs) + graph_break_count = graph_count - 1 + compile_time = compile_times(repr="str") + + # TODO(voz): Do we want a decorator for this? 
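For orientation, a usage sketch of the explain() helper assembled here; the fields read from the result mirror the ExplainOutput construction below (the toy function is arbitrary):

# Illustrative sketch: summarizing captured graphs and graph breaks.
import torch


def toy(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:  # data-dependent branch -> graph break
        b = b * -1
    return x * b


explanation = torch._dynamo.explain(toy)(torch.randn(10), torch.randn(10))
print(explanation.graph_count, explanation.graph_break_count)
for reason in explanation.break_reasons:
    print(reason)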
+ reset() + from .backends.debugging import ExplainOutput + + return ExplainOutput( + graphs, + graph_count, + graph_break_count, + break_reasons, + op_count, + ops_per_graph, + out_guards, + compile_time, + ) + + if extra_args or extra_kwargs: + warnings.warn( + "explain(f, *args, **kwargs) is deprecated, use explain(f)(*args, **kwargs) instead. " + "If you don't migrate, we may break your explain call in the future if your user defined kwargs " + "conflict with future kwargs added to explain(f).", + FutureWarning, + stacklevel=2, + ) + return inner(*extra_args, **extra_kwargs) + else: + return inner + + +class FlattenInputOutputSignature(torch.fx.interpreter.Transformer): + def __init__( + self, + m: torch.fx.GraphModule, + flat_args: Tuple[Any], + matched_input_elements_positions: List[int], + flat_results: List[Any], + matched_output_elements_positions: List[int], + example_fake_inputs: List[torch.Tensor], + flat_args_dynamic_dims: List[Set[int]], + fake_mode: Optional[fake_tensor.FakeTensorMode] = None, + ) -> None: + super().__init__(m) + + assert len(flat_args_dynamic_dims) == len(flat_args) + matched_input_elements_to_fake = { + val: example_fake_inputs[ix] + for ix, val in enumerate(matched_input_elements_positions) + } + + self.new_args = [] + for i in range(0, len(flat_args)): + arg = super().placeholder(f"arg{i}", (), {}) + if i in matched_input_elements_to_fake: + arg.node.meta["val"] = matched_input_elements_to_fake[i] + else: + # Fill node.mata["val"] with faketensor from the input, + # if it's not found in matched_input_elements_positions + if fake_mode is not None and isinstance(flat_args[i], torch.Tensor): + # TODO(zhxchen17) Also preserve all the user constraints here. + arg.node.meta["val"] = fake_mode.from_tensor( + flat_args[i], + symbolic_context=StatelessSymbolicContext( + dynamic_sizes=[ + DimDynamic.DYNAMIC + if d in flat_args_dynamic_dims[i] + else DimDynamic.STATIC + for d in range(len(flat_args[i].shape)) + ], + constraint_sizes=[None] * len(flat_args[i].shape), + ), + ) + self.new_args.append(arg) + self.old_args_gen = (self.new_args[i] for i in matched_input_elements_positions) + self.matched_output_elements_positions = matched_output_elements_positions + self.flat_results = flat_results + + def placeholder(self, target, args, kwargs): + arg = next(self.old_args_gen) + if "val" in self.current_node.meta: + arg.node.meta["val"] = self.current_node.meta["val"] + if "tensor_dict" in self.current_node.meta: + arg.node.meta["tensor_dict"] = self.current_node.meta["tensor_dict"] + if "example_value" in self.current_node.meta: + # NB: intentionally do not use set_example_value + arg.node.meta["example_value"] = self.current_node.meta["example_value"] + if "unbacked_bindings" in self.current_node.meta: + arg.node.meta["unbacked_bindings"] = self.current_node.meta[ + "unbacked_bindings" + ] + return arg + + def output(self, target, args, kwargs): + dynamo_result_flat = args[0] + lookup = [*dynamo_result_flat, *self.new_args] + new_results_flat = [] + for i in range(len(self.flat_results)): + if self.matched_output_elements_positions[i] is not None: + new_results_flat.append( + lookup[self.matched_output_elements_positions[i]] + ) + else: + const_val = self.flat_results[i] + assert isinstance(const_val, tuple(common_constant_types)) + new_results_flat.append(const_val) + return super().output(target, (new_results_flat,), {}) + + def run_node(self, n): + self.current_node = n + result_proxy = super().run_node(n) + if "val" in self.current_node.meta: + 
result_proxy.node.meta["val"] = self.current_node.meta["val"] + if "example_value" in self.current_node.meta: + # NB: intentionally do not use set_example_value + result_proxy.node.meta["example_value"] = self.current_node.meta[ + "example_value" + ] + if "unbacked_bindings" in self.current_node.meta: + result_proxy.node.meta["unbacked_bindings"] = self.current_node.meta[ + "unbacked_bindings" + ] + if self.current_node.op != "output": + result_proxy.node._rename( + getattr(self.current_node, "name", result_proxy.node.name) + ) + return result_proxy + + def transform(self): + result_gm = super().transform() + if "dynamo_flat_name_to_original_fqn" in self.module.meta: + result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[ + "dynamo_flat_name_to_original_fqn" + ] + return result_gm + + +class ExportResult(NamedTuple): + graph_module: torch.fx.GraphModule + guards: _guards.GuardsSet + # NB: Do not add new fields without overriding __iter__; people are + # destructuring so it is BC-breaking + + +def check_signature_rewritable(graph): + input_errors = [] + for node in graph.graph.find_nodes(op="placeholder"): + assert hasattr(node, "_dynamo_source") + source = node._dynamo_source + user_stacks = graph._source_to_user_stacks.get(source) + if user_stacks is None: + continue + assert len(user_stacks) > 0 + # In some cases we may not have a useful stack. Look for a + # useful stack + stack = None + for s in user_stacks: + if len(s) == 0: + continue + stack = s + break + if stack is None: + msg = f"{source.name()}, a closed over free variable" + else: + tb = "".join(traceback.format_list(stack)) + extra = "" + if len(user_stacks) > 1: + extra = f"(elided {len(user_stacks) - 1} more accesses)" + msg = f"{source.name()}, accessed at:\n{tb}{extra}" + # TODO: option to print ALL of the stack traces at once + input_errors.append(msg) + + if input_errors: + raise UserError( + UserErrorType.INVALID_INPUT, + "Cannot export model which references tensors that are neither " + "buffers/parameters/constants nor are direct inputs. For each tensor, if you'd " + "like this tensor to be an explicit input, add it as a dummy argument " + "to the top-level model definition you are exporting; if you would " + "like its value to be embedded as an exported constant, wrap its access " + "in a function marked with @assume_constant_result.\n\n" + + "\n\n".join(input_errors), + ) + + +def rewrite_signature( + f_sig, + graph, + fake_mode, + flat_args, + in_spec, + example_fake_inputs, + graph_captured_input, + graph_captured_output, + dynamo_traced_result, + flat_args_dynamic_dims, +): + orig_args, orig_kwargs = pytree.tree_unflatten(flat_args, in_spec) + + def check_user_input_output(flat_values, error_type): + supported_types = [ + torch.Tensor, + torch.SymInt, + torch.SymFloat, + torch.SymBool, + torch._C.ScriptObject, + ] + list(common_constant_types) + + def is_supported_type(val): + return isinstance(val, tuple(supported_types)) + + value_type = "input" if error_type == UserErrorType.INVALID_INPUT else "output" + # We only check that the outputs are not None. Inputs can be None. + for v in flat_values: + if not is_supported_type(v): + if error_type == UserErrorType.INVALID_INPUT and v is None: + continue + + raise UserError( + error_type, + f"It looks like one of the {value_type}s with type `{type(v)}` " + "is not supported or pytree-flattenable. \n" + f"Exported graphs {value_type}s can only contain the " + f"following supported types: {supported_types}. 
\n" + "If you are using a custom class object, " + "please register a pytree_flatten/unflatten function " + "using `torch.utils._pytree.register_pytree_node` or " + "`torch.export.register_dataclass`.", + ) + + check_user_input_output(flat_args, UserErrorType.INVALID_INPUT) + flat_results_traced, out_spec_traced = pytree.tree_flatten(dynamo_traced_result) + check_user_input_output(flat_results_traced, UserErrorType.INVALID_OUTPUT) + + def check_optional_input_and_error(f_sig: inspect.Signature): + # Check if function has optional input. + for name, param in f_sig.parameters.items(): + if param.default is not inspect.Parameter.empty: + from torch._dynamo.exc import Unsupported + + log.error( + "Parameter %s is optional with a default value of %s", + name, + param.default, + ) + raise Unsupported( + "Tracing through optional input is not supported yet", + case_name="optional_input", + ) + + def produce_matching(debug_type, sources, candidates): + matched_elements_positions: List[Optional[int]] = [] + dict_of_source_vals = {} + for i, val in enumerate(sources): + dict_of_source_vals[id(val)] = i + + for i, val in enumerate(candidates): + if isinstance(val, tuple(common_constant_types)): + matched_elements_positions.append(None) + elif id(val) not in dict_of_source_vals: + if debug_type == "inputs": + check_optional_input_and_error(f_sig) + raise AssertionError( + f"Unexpectedly found a {type(val)} in the {debug_type}.\n" + 'Please file an issue along with a paste of the logs from TORCH_LOGS="+export"', + ) + else: + matched_elements_positions.append(dict_of_source_vals[id(val)]) + + return matched_elements_positions + + matched_input_elements_positions = produce_matching( + "inputs", flat_args, graph_captured_input + ) + + assert graph_captured_output is not None + matched_output_elements_positions = produce_matching( + "outputs", list(graph_captured_output) + flat_args, flat_results_traced + ) + + new_graph = FlattenInputOutputSignature( + graph, + flat_args, + matched_input_elements_positions, + flat_results_traced, + matched_output_elements_positions, + example_fake_inputs, + flat_args_dynamic_dims, + fake_mode, + ).transform() + + # Make dynamo graph to have same input/output spec as user code + def argument_names(f_sig, args, kwargs) -> List[str]: + def signature_to_fullargspec(sig: inspect.Signature): + # Get a list of Parameter objects from the Signature object + params = list(sig.parameters.values()) + # Separate positional arguments, keyword-only arguments and varargs/varkw + args = [ + p.name + for p in params + if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + ] + kwonlyargs = [ + p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY + ] + varargs = next( + (p.name for p in params if p.kind == inspect.Parameter.VAR_POSITIONAL), + None, + ) + varkw = next( + (p.name for p in params if p.kind == inspect.Parameter.VAR_KEYWORD), + None, + ) + # Get default values for positional arguments and keyword-only arguments + defaults = tuple( + p.default + for p in params + if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + and p.default is not inspect.Parameter.empty + ) + kwonlydefaults = { + p.name: p.default + for p in params + if p.kind == inspect.Parameter.KEYWORD_ONLY + and p.default is not inspect.Parameter.empty + } + # Get annotations for parameters and return value + annotations = {} + if sig.return_annotation: + annotations = {"return": sig.return_annotation} + for parameter in params: + annotations[parameter.name] = parameter.annotation + # Return a 
FullArgSpec object with the extracted attributes + return inspect.FullArgSpec( + args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations + ) + + fullargspec = signature_to_fullargspec(f_sig) + + # 1. Map `args` 1-to-1 to positional arguments in original signature. + input_strs = fullargspec.args[: len(args)] + + if len(args) > len(fullargspec.args): + # 2. If there are more arguments left in `args`, they map to varargs in original + # signature. Assign names as {varargs}_0, {varargs}_1, ... + assert fullargspec.varargs is not None, "More arguments than expected" + input_strs += [ + f"{fullargspec.varargs}_{i}" + for i in range(0, len(args) - len(input_strs)) + ] + elif len(args) < len(fullargspec.args): + # 3. If there are fewer arguments in `args` than `fullargspec.args`, + # it implies these are arguments either with default values, or provided in + # `kwargs`. The former can be safely ignored. Because Dynamo.export does not + # export them as part of the function signature. The latter will be handled + # in the next step. + for unprovided_arg in fullargspec.args[ + len(args) : -len(fullargspec.defaults or []) + ]: + assert unprovided_arg in kwargs, f"Missing argument {unprovided_arg}" + + # 4. Keyword arguments provided in `kwargs`. + input_strs += list(kwargs.keys()) + + # 5. Keyword-only arguments with default values if not provided are not exported + # as part of the function signature. + for kwonly_arg in fullargspec.kwonlyargs: + kwonlydefaults = fullargspec.kwonlydefaults or {} + assert ( + kwonly_arg in kwargs or kwonly_arg in kwonlydefaults + ), f"Missing keyword only argument {kwonly_arg}" + + return input_strs + + new_graph.graph._codegen = _PyTreeCodeGen( + _PyTreeInfo( + argument_names(f_sig, orig_args, orig_kwargs), + in_spec, + out_spec_traced, + ) + ) + new_graph.recompile() + return new_graph + + +def export( + f: Callable[..., Any], + *extra_args, + aten_graph: bool = False, + pre_dispatch: bool = False, + decomposition_table: Optional[ + Dict[torch._ops.OpOverload, Callable[..., Any]] + ] = None, + tracing_mode: str = "symbolic", + dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None, + assume_static_by_default: bool = False, + same_signature: bool = True, + disable_constraint_solver: bool = False, + prefer_deferred_runtime_asserts_over_guards: bool = False, + allow_complex_guards_as_runtime_asserts: bool = False, + _log_export_usage: bool = True, + **extra_kwargs, +) -> Callable[..., ExportResult]: + """ + Export an input function f to a format that can be executed outside of PyTorch using the FX graph. + + Args: + f (callable): A PyTorch function to be exported. + + aten_graph (bool): If True, exports a graph with ATen operators. + If False, exports a graph with Python operators. Default is False. + + pre_dispatch (bool): If True, exports a graph with ATen operators, + but before any logic in the PyTorch dispatcher has run. + This can be useful if you want to apply further transformations on a graph before running it + through autograd, autocast, or any other functionalities that are integrated into the dispatcher. + This flag is only valid if aten_graph=True is set. + Default is False. + + decomposition_table (dict): A dictionary that maps operators to their decomposition functions. + Required if aten_graph or tracing_mode is specified. Default is None. + + tracing_mode (str): If "symbolic", turn on dynamic shapes support. Default is "symbolic". 
+ + dynamic_shapes: + An optional argument where the type should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. + + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + same_signature (bool): If True, rewrite the returned graph's signature to be the same as f. + + disable_constraint_solver (bool): Whether the dim constraint solver must be disabled. + + Returns: + A function that given args and kwargs, returns a tuple of (graph, guards) + Graph: An FX graph representing the execution of the input PyTorch function with the provided arguments and options. + Guards: The guards we accumulated during tracing f above + + Raises: + AssertionError: If decomposition_table is specified without setting aten_graph=True, + or if graph breaks during tracing in export. + + AssertionError: If Dynamo input and output is not consistent with traced input/output. + + Note - this headerdoc was authored by ChatGPT, with slight modifications by the author. + """ + if _log_export_usage: + log_export_usage(event="export.private_api", flags={"_dynamo"}) + + # Deal with "local variable referenced before assignment" + _f = f + _assume_static_by_default = assume_static_by_default + + def inner(*args, **kwargs): + combined_args = _combine_args(_f, args, kwargs) + constraints = _process_dynamic_shapes(combined_args, dynamic_shapes) + f = _f + assume_static_by_default = _assume_static_by_default + check_if_dynamo_supported() + torch._C._log_api_usage_once("torch._dynamo.export") + if decomposition_table is not None: + assert ( + aten_graph + ), "Specifying a decomposition_table table or tracing mode is illegal without setting aten_graph=True" + if pre_dispatch: + assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True" + f = innermost_fn(f) + call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f + original_signature = inspect.signature(call_to_inspect) + graph = None + out_guards = None + graph_captured_input = None + graph_captured_result: Optional[Tuple[torch.Tensor, ...]] = None + fake_mode = None + result_traced = None + + def guard_export_print(guards: _guards.GuardsSet): + nonlocal out_guards + assert ( + out_guards is None + ), "whole graph export entails exactly one guard export" + out_guards = guards + + example_inputs = [] + + def dynamo_normalization_capturing_compiler( + gm: torch.fx.GraphModule, inner_example_inputs + ): + nonlocal graph + assert ( + graph is None + ), "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph." 
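A hedged usage sketch for the export API documented above; the shapes, the Dim name, and marking only dim 0 dynamic are illustrative choices, not requirements:

# Illustrative sketch: capture a single whole-program graph plus its guards.
import torch


def f(x):
    return x.cos() + x.sin()


batch = torch.export.Dim("batch")
# dynamic_shapes uses the dict-of-argument-names form described in the docstring:
# dim 0 of `x` is dynamic, dim 1 stays static.
graph_module, guards = torch._dynamo.export(f, dynamic_shapes={"x": {0: batch}})(
    torch.randn(4, 8)
)
graph_module.print_readable()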
+ graph = gm + + nonlocal fake_mode, example_inputs + # NB: do NOT pass inner_example_inputs here, we are detecting the + # Dynamo allocated fake mode, which should be DISTINCT from a + # potential outer ambient fake mode which the user provided. + # example_inputs is always the user specified inputs, so they + # would have the wrong fake mode attached to them + fake_mode = _guards.detect_fake_mode() + example_inputs = inner_example_inputs + + def result_capturing_wrapper(*graph_inputs): + nonlocal graph_captured_result + nonlocal graph_captured_input + + graph_captured_input = graph_inputs + assert graph is not None + + named_parameters = dict(graph.named_parameters(remove_duplicate=False)) + named_buffers = dict(graph.named_buffers(remove_duplicate=False)) + + ambient_fake_mode = ( + _guards.detect_fake_mode(graph_inputs) + if _guards.detect_fake_mode(graph_inputs) is not None + else fake_mode + ) + + # We reran fake tensor propagation, but we didn't do + # anything with the resulting unbacked SymInts. Drop them + # from the pending list. + # NB: this is wrong if graph_captured_result has + # data-dependent output size! + ignore_fresh_unbacked = null_context() + if shape_env := ambient_fake_mode.shape_env: + ignore_fresh_unbacked = shape_env.ignore_fresh_unbacked_symbols() + + with ( + ambient_fake_mode + ), enable_python_dispatcher(), ignore_fresh_unbacked: + params_and_buffers = { + **named_parameters, + **named_buffers, + } + fake_params_buffers = {} + + for name, value in params_and_buffers.items(): + fake_params_buffers[name] = ambient_fake_mode.from_tensor( + value, static_shapes=True + ) + + fake_graph_inputs = pytree.tree_map( + ambient_fake_mode.from_tensor, graph_inputs + ) + graph_captured_result = torch.func.functional_call( + graph, fake_params_buffers, fake_graph_inputs + ) + + return graph_captured_result + + return result_capturing_wrapper + + # Note: This is needed by rewrite_signature. We need to put it before + # optimize_assert since user program may mutate the inputs. + flat_args, in_spec = pytree.tree_flatten((args, kwargs)) + + remove_from_cache(f) + constraint_violation_error = None + if tracing_mode != "symbolic": + assume_static_by_default = True + with config.patch( + specialize_int=True, + assume_static_by_default=assume_static_by_default, + automatic_dynamic_shapes=False, + capture_dynamic_output_shape_ops=True, + capture_scalar_outputs=True, + prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, + ): + opt_f = optimize_assert( + dynamo_normalization_capturing_compiler, + hooks=Hooks( + guard_export_fn=guard_export_print, + guard_fail_fn=None, + ), + export=True, + export_constraints=constraints, + )(f) + # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideeffects and reject. 
+ try:
+ result_traced = opt_f(*args, **kwargs)
+ except ConstraintViolationError as e:
+ constraint_violation_error = e
+ remove_from_cache(f)
+
+ if (
+ not disable_constraint_solver
+ and (shape_env := getattr(fake_mode, "shape_env", None)) is not None
+ and (dim_constraints := shape_env.dim_constraints) is not None
+ and not isinstance(
+ call_to_inspect, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)
+ )
+ and not trace_rules.check(call_to_inspect)
+ ):
+ dim_constraints.solve()
+ forced_specializations = dim_constraints.forced_specializations()
+ msg = dim_constraints.prettify_results(
+ original_signature,
+ dynamic_shapes,
+ constraint_violation_error,
+ forced_specializations,
+ )
+ if constraint_violation_error:
+ constraint_violation_error.args = (
+ constraint_violation_error.args[0] + msg,
+ )
+ else:
+ if forced_specializations:
+ constraint_violation_error = ConstraintViolationError(msg)
+ else:
+ log.info(
+ "Summary of dimension constraints:%s",
+ msg,
+ )
+
+ # Error if we have any constraints on static values
+ for k in shape_env.var_to_range.keys():
+ if isinstance(k, sympy.Integer):
+ constraint_violation_error = ConstraintViolationError(
+ f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
+ "It appears that you're trying to set a constraint on a "
+ f"value which we evaluated to have a static value of {k}. "
+ 'Set TORCH_LOGS="+export" for more information.'
+ )
+ if constraint_violation_error:
+ raise constraint_violation_error
+
+ if graph is None:
+ assert (
+ same_signature
+ ), "Failed to produce a graph during tracing as no tensor operations were found and same_signature is False."
+ # If the module does not contain any tensor computation, we would create a graph with inputs and outputs.
+ # To be consistent with the graph traced by Dynamo, `graph` will have only tensor inputs as placeholders
+ # and tensor outputs as output nodes. Non-tensor inputs and outputs will be added when rewriting the signature.
+ # We will also construct the `example_inputs`, `graph_captured_input`, and `graph_captured_result` corresponding
+ # to `graph`.
+ example_inputs = [] + graph_captured_input = () + graph_captured_result = () + fake_mode = torch._subclasses.FakeTensorMode( + shape_env=ShapeEnv(), export=True + ) + if out_guards is None: + out_guards = _guards.GuardsSet() + assert out_guards is not None # suppress mypy error + parameter_names = list(original_signature.parameters.keys()) + fx_graph = torch.fx.Graph() + for i, name in enumerate(parameter_names): + if torch.is_tensor(flat_args[i]): + node = fx_graph.placeholder(name) + node.meta["val"] = fake_mode.from_tensor( + flat_args[i], static_shapes=True + ) + graph_captured_input = graph_captured_input + (flat_args[i],) + example_inputs.append(flat_args[i]) + fx_graph.output(graph_captured_result) + module = torch.nn.Module() + graph = torch.fx.GraphModule(module, fx_graph) + log.info( + "Failed to capture a graph during tracing as no tensor operations were found.:\n\n%s", + graph.print_readable(print_output=False, colored=True), + ) + else: + assert hasattr(graph, "_source_to_user_stacks") + assert out_guards is not None, "Failed to produce guards during tracing" + assert fake_mode is not None + + log.info( + "Dynamo captured graph:\n\n%s", + graph.print_readable(print_output=False, colored=True), + ) + + # This check need to happened before aten_graph + # because placeholder's _source_node attribute is not preserved by make_fx + if same_signature: + check_signature_rewritable(graph) + + # NB: This is mostly hitting the cache; Dynamo already converted these + example_fake_inputs = [fake_mode.from_tensor(t) for t in example_inputs] + + if aten_graph: + # Running graph with interpreter is needed for propagating the stack_trace + def graph_with_interpreter(*args): + with torch.fx.traceback.preserve_node_meta(): + return torch.fx.Interpreter(graph).run(*args) # type: ignore[arg-type] + + with unset_fake_temporarily(), enable_python_dispatcher(), fake_mode: + try: + graph = make_fx( + graph_with_interpreter, + decomposition_table=decomposition_table, + tracing_mode="real", + _allow_non_fake_inputs=True, + pre_dispatch=pre_dispatch, + _allow_fake_constant=False, + )(*example_fake_inputs) + except CondOpArgsMismatchError as e: + # Wrap the internal error to the user-facing error + raise UserError( # noqa: B904 + UserErrorType.DYNAMIC_CONTROL_FLOW, + str(e), + case_name="cond_operands", + ) + + assert graph is not None + for node in graph.graph.find_nodes(op="get_attr"): + if isinstance(getattr(graph, node.target), torch.Tensor): # type: ignore[arg-type] + node.meta["val"] = fake_mode.from_tensor( + getattr(graph, node.target), static_shapes=True # type: ignore[arg-type] + ) + + if same_signature: + flat_args_dynamic_dims = [ + { + c.dim + for c in (constraints or ()) + if ( + c.t_id == id(x) + and c.constraint_range.vr.lower != c.constraint_range.vr.upper + ) + } + for x in flat_args + ] + graph = rewrite_signature( + original_signature, + graph, + fake_mode, + flat_args, + in_spec, + example_fake_inputs, + graph_captured_input, + graph_captured_result, + result_traced, # type: ignore[possibly-undefined] + flat_args_dynamic_dims, + ) + return ExportResult(graph, out_guards) # type: ignore[arg-type] + + if extra_args or extra_kwargs: + warnings.warn( + "export(f, *args, **kwargs) is deprecated, use export(f)(*args, **kwargs) instead. 
" + "If you don't migrate, we may break your export call in the future if your user defined kwargs " + "conflict with future kwargs added to export(f).", + FutureWarning, + stacklevel=2, + ) + return inner(*extra_args, **extra_kwargs) + else: + return inner + + +def optimize_assert( + backend, + *, + hooks=Hooks(None, None), + export=False, + export_constraints=None, + dynamic=None, + rebuild_ctx=None, +): + """ + The same as `torch._dynamo.optimize(backend, nopython=True)` + """ + backend = get_compiler_fn(backend) + + # Find if backend has any extra context manager + backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context) + + return _optimize_catch_errors( + convert_frame.convert_frame_assert( + backend, export=export, export_constraints=export_constraints + ), + hooks, + backend_ctx_ctor, + export=export, + dynamic=dynamic, + rebuild_ctx=rebuild_ctx, + ) + + +class TorchPatcher: + @staticmethod + @functools.lru_cache(None) + def patch(): + # A better way to disable the following would be decorate the source + # functions with @torch._disable_dynamo. However, this causes issues + # with torch.deploy internally. + from .decorators import disable + + torch.jit.trace = disable(torch.jit.trace) + torch.jit.trace_module = disable(torch.jit.trace_module) + torch.jit._get_trace_graph = disable(torch.jit._get_trace_graph) + torch.fx._symbolic_trace.Tracer.trace = disable( + torch.fx._symbolic_trace.Tracer.trace + ) + torch.distributions.Distribution.set_default_validate_args(False) + + from torch.optim import ( + adadelta, + adagrad, + adam, + adamax, + adamw, + asgd, + lbfgs, + nadam, + radam, + rmsprop, + rprop, + sgd, + sparse_adam, + ) + + optimizer_modules = { + adadelta, + adagrad, + adam, + adamax, + adamw, + asgd, + lbfgs, + nadam, + radam, + rmsprop, + rprop, + sgd, + sparse_adam, + } + + for opt_mod in optimizer_modules: + opt_name = opt_mod.__name__.split(".")[-1] + fused_fn_name = f"_fused_{opt_name}" + single_tensor_fn_name = f"_single_tensor_{opt_name}" + + if hasattr(opt_mod, fused_fn_name): + setattr( + opt_mod, fused_fn_name, disable(getattr(opt_mod, fused_fn_name)) + ) + + optimizer_classes = [ + opt + for opt in torch.optim.__dict__.values() + if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer) + ] + + # Note: we don't support sparsity or tracing through backwards + excluded_optimizer_classes = { + torch.optim.SparseAdam, + torch.optim.LBFGS, + } + + for opt in optimizer_classes: + if opt in excluded_optimizer_classes: + opt.step = disable(opt.step) + + if hasattr(opt, "_init_group"): + opt._init_group = disable(opt._init_group) + + @staticmethod + def suppress_torch_distributed_warnings(fn): + def inner_fn(*args, **kwargs): + warnings.filterwarnings( + "ignore", category=UserWarning, module="torch.distributed" + ) + return fn(*args, **kwargs) + + return inner_fn diff --git a/lib/python3.10/site-packages/torch/_dynamo/exc.py b/lib/python3.10/site-packages/torch/_dynamo/exc.py new file mode 100644 index 0000000000000000000000000000000000000000..0d2108ada9e10e9dc37990e1041027133968b1ff --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/exc.py @@ -0,0 +1,454 @@ +# mypy: allow-untyped-defs +import os +import textwrap +from enum import auto, Enum +from traceback import extract_stack, format_exc, format_list, StackSummary +from typing import Any, cast, NoReturn, Optional, Tuple, TYPE_CHECKING + +import torch._guards + +from . 
import config +from .utils import counters + + +if TYPE_CHECKING: + from torch._guards import CompileId + + +def exportdb_error_message(case_name): + return ( + "For more information about this error, see: " + + "https://pytorch.org/docs/main/generated/exportdb/index.html#" + + case_name.replace("_", "-") + ) + + +import logging + + +log = logging.getLogger(__name__) +graph_breaks_log = torch._logging.getArtifactLogger(__name__, "graph_breaks") + + +class TorchDynamoException(RuntimeError): + pass + + +class InternalTorchDynamoError(TorchDynamoException): + pass + + +class RestartAnalysis(TorchDynamoException): + restart_reason: str + + def __init__(self, *args, restart_reason=None) -> None: + self.restart_reason = restart_reason + super().__init__(*args) + + +class SpeculationRestartAnalysis(RestartAnalysis): + pass + + +class UnspecializeRestartAnalysis(RestartAnalysis): + pass + + +class CompileCollectiveRestartAnalysis(RestartAnalysis): + pass + + +class SkipFrame(TorchDynamoException): + pass + + +class TorchRuntimeError(TorchDynamoException): + pass + + +class InvalidBackend(TorchDynamoException): + def __init__(self, name) -> None: + super().__init__( + f"Invalid backend: {name!r}, see `torch._dynamo.list_backends()` for available backends." + ) + + +class ResetRequired(TorchDynamoException): + def __init__(self) -> None: + super().__init__( + textwrap.dedent( + """ + Must call `torch._dynamo.reset()` before changing backends. Detected two calls to + `torch.compile()` with a different backend compiler arguments. + """ + ) + ) + + +class BackendCompilerFailed(TorchDynamoException): + def __init__(self, backend_fn, inner_exception) -> None: + self.backend_name = getattr(backend_fn, "__name__", "?") + self.inner_exception = inner_exception + msg = f"backend={self.backend_name!r} raised:\n{type(inner_exception).__name__}: {inner_exception}" + super().__init__(msg) + + +class Unsupported(TorchDynamoException): + def __init__(self, msg, *, case_name=None) -> None: + super().__init__(msg) + self.real_stack = torch._guards.TracingContext.extract_stack() + self.msg = msg + self.category: Optional[str] = None + self.add_to_stats() + self.case_name: Optional[str] = case_name + + def remove_from_stats(self): + assert self.category is not None + counters[self.category][self.msg] -= 1 + if counters[self.category][self.msg] <= 0: + del counters[self.category][self.msg] + + def add_to_stats(self, category="unimplemented"): + self.category = category + counters[category][self.msg] += 1 + + +class RecompileError(TorchDynamoException): + pass + + +class ArgsMismatchError(Unsupported): + def __init__(self, msg) -> None: + super().__init__(msg) + + +class AttributeMutationError(Unsupported): + def __init__(self, msg) -> None: + super().__init__(msg) + + +class CondOpArgsMismatchError(ArgsMismatchError): + """ + Internal error from cond() due to arguments mismatch. + """ + + def __init__(self, msg) -> None: + super().__init__(msg) + + +class UserErrorType(Enum): + DYNAMIC_CONTROL_FLOW = auto() + ANTI_PATTERN = auto() + STANDARD_LIBRARY = auto() + CONSTRAINT_VIOLATION = auto() + DYNAMIC_DIM = auto() + INVALID_INPUT = auto() + INVALID_OUTPUT = auto() + + +class UserError(Unsupported): + def __init__(self, error_type: UserErrorType, msg, case_name=None) -> None: + """ + Type of errors that would be valid in Eager, but not supported in TorchDynamo. + The error message should tell user about next actions. 
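+ For example, a CondOpArgsMismatchError raised while tracing torch.cond is
+ re-raised by export as UserError(UserErrorType.DYNAMIC_CONTROL_FLOW, ...,
+ case_name="cond_operands"), so the resulting message ends with the exportdb
+ link produced by exportdb_error_message("cond_operands").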
+ + error_type: Type of user error + msg: Actionable error message + case_name: (Optional) Unique name (snake case) for the usage example in exportdb. + """ + if case_name is not None: + assert isinstance(case_name, str) + if msg.endswith("."): + msg += " " + else: + msg += "\n" + msg += exportdb_error_message(case_name) + super().__init__(msg) + self.error_type = error_type + self.message = msg + + +class SkipCodeRecursiveException(TorchDynamoException): + pass + + +class CacheLimitExceeded(SkipCodeRecursiveException, Unsupported): + pass + + +class UnsafeScriptObjectError(TorchDynamoException): + pass + + +class UncapturedHigherOrderOpError(TorchDynamoException): + pass + + +class IncorrectUsage(Exception): + pass + + +class ObservedException(TorchDynamoException): + # An exception observed during the tracing. This exception is used by Dynamo to handle exceptions. + pass + + +class ObservedUserStopIteration(ObservedException): + # An UserStopIteraion exception observed during the Dynamo tracing (e.g Dynamo tracing __next__) + value: Optional[Any] + + # Reference `StopIteration_init` in CPython + # https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L568-L584 + def __init__(self, *args, **kwargs) -> None: + super().__init__("unhandled `raise StopIteration`") + if len(args) > 0: + self.value = args[0] + else: + self.value = None + + +class ObservedKeyError(ObservedException): + # A KeyError exception to be raised from inside Dynamo tracing. This can happen on dict __getitem__ + pass + + +class ObservedAttributeError(ObservedException): + # An AttributeError exception to be raised from inside Dynamo tracing. This can happen on user defined object __getattr__ + pass + + +observed_exception_map = { + StopIteration: ObservedUserStopIteration, + KeyError: ObservedKeyError, + AttributeError: ObservedAttributeError, +} + + +def raise_observed_exception(e, tx, vt): + from .variables import BuiltinVariable + + # CPython here raises an exception. Since there is no python code, we have to manually setup the exception + # stack and raise the exception. + exception_vt = BuiltinVariable(e).call_function(vt, [], {}) + tx.exn_vt_stack.append(exception_vt) + raise observed_exception_map[e] + + +def handle_observed_exception(tx): + # This is essentially exception handling code, equivalent of this pseudo code + # + # try: + # ... somebody raising StopIteration + # except StopIteration + # pass + # + # If this was going through the python code, we would have called exception_handler method, but FOR_ITER + # handles the exception completely in CPython. For example for 3.11, the resulting bytecode is + # + # + # 6 46 LOAD_GLOBAL 2 (StopIteration) + # 58 RAISE_VARARGS 1 + # >> 60 PUSH_EXC_INFO + + # 7 62 LOAD_GLOBAL 2 (StopIteration) + # 74 CHECK_EXC_MATCH + # 76 POP_JUMP_FORWARD_IF_FALSE 3 (to 84) + # 78 POP_TOP + + # 8 80 POP_EXCEPT + # + + # Fortunately this translates to a simple pop from the exn_vt_stack + tx.exn_vt_stack.pop() + + +# These exceptions are ok to fallback to eager/graph_break. +exceptions_allowed_to_be_fallback = ( + torch._subclasses.fake_tensor.DataDependentOutputException, + torch._subclasses.fake_tensor.DynamicOutputShapeException, + torch._subclasses.fake_tensor.UnsupportedOperatorException, + torch._subclasses.fake_tensor.UnsupportedFakeTensorException, +) + + +def unimplemented_with_warning(e: Exception, code, msg: str) -> NoReturn: + # This function calls unimplemented internally and eventually graph breaks + # or falls to eager. 
unimplemented itself does not print any user warnings, + # i.e., its very silent. This helper function is intended when an error is + # encountered in the torch.compile stack which is worth showing as warning + # to the user. For example, if AOT Autograd backend fails with a fake tensor + # exception, its ok to fallback to eager but not silently. Here, we can use + # this function to log the message and the stack trace. + graph_break_msg = format_error_msg_verbose(e, code) + graph_breaks_log.debug("%s", graph_break_msg) + log.warning(msg) + unimplemented(msg, from_exc=e) + + +_NOTHING = object() + + +def unimplemented( + msg: str, *, from_exc: Any = _NOTHING, case_name: Optional[str] = None +) -> NoReturn: + assert msg != os.environ.get("BREAK", False) + if from_exc is not _NOTHING: + raise Unsupported(msg, case_name=case_name) from from_exc + raise Unsupported(msg, case_name=case_name) + + +def warning(msg: str) -> None: + counters["warnings"][msg] += 1 + assert msg != os.environ.get("BREAK", False) + + +# KeyError has special handling for its args +# see https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L2534 for details +class KeyErrorMsg: + def __init__(self, value) -> None: + self.value = value + + def __str__(self) -> str: + return str(self.value) + + def __repr__(self) -> str: + return self.__str__() + + +def augment_exc_message(exc: Exception, msg: str = "\n", export: bool = False) -> None: + import traceback + + exc.innermost_user_frame_summary = None # type: ignore[attr-defined] + + real_stack = get_real_stack(exc) + if real_stack is not None and len(real_stack) > 0: + exc.innermost_user_frame_summary = real_stack[-1] # type: ignore[attr-defined] + msg += f"\nfrom user code:\n {''.join(traceback.format_list(real_stack))}" + + if config.replay_record_enabled and hasattr(exc, "record_filename"): + msg += f"\nLast frame execution written to {exc.record_filename}. To run only this frame while debugging, run\ + torch._dynamo.replay('{exc.record_filename}').\n" + + if not config.verbose and hasattr(exc, "real_stack"): + msg += '\nSet TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information\n' + + if hasattr(exc, "inner_exception") and hasattr( + exc.inner_exception, "minifier_path" + ): + if hasattr(exc.inner_exception, "buck_command"): + msg += ( + f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run " + f"this buck command to find the smallest traced graph " + f"which reproduces this error: {exc.inner_exception.buck_command}\n" + ) + else: + msg += ( + f"\nMinifier script written to {exc.inner_exception.minifier_path}. 
Run " + "this script to find the smallest traced graph which reproduces this error.\n" + ) + + if not config.suppress_errors and not export: + msg += ( + "\n\n" + "You can suppress this exception and fall back to eager by setting:\n" + " import torch._dynamo\n" + " torch._dynamo.config.suppress_errors = True\n" + ) + + old_msg = "" if len(exc.args) == 0 else str(exc.args[0]) + + if isinstance(exc, KeyError): + exc.args = (KeyErrorMsg(old_msg + msg),) + exc.args[1:] + else: + new_msg = old_msg + msg + exc.args = (new_msg,) + exc.args[1:] + + +def get_exc_message( + e: Exception, compile_id: "CompileId" +) -> Tuple[Optional[str], Optional[int]]: + filename = None + lineno = None + if e.innermost_user_frame_summary is not None: # type: ignore[attr-defined] + filename = e.innermost_user_frame_summary.filename # type: ignore[attr-defined] + lineno = e.innermost_user_frame_summary.lineno # type: ignore[attr-defined] + e.compile_id = compile_id # type: ignore[attr-defined] + return filename, lineno + + +def get_real_stack(exc: Exception, frame=None) -> Optional[StackSummary]: + real_stack = getattr(exc, "real_stack", None) + if real_stack is None: + return None + + # NB: it's possible for real_stack to be []; we still attempt to + # report a stack anyway because the stack_above_dynamo may still + # be useful for debugging + + stack_above_dynamo = [] + if frame is not None: + # NB: frame is PyInterpreterFrame on Python 3.11 and later, + # not a TRUE frame object. You can't actually feed it + # to traceback because it doesn't have enough information. + # To solve this problem, we technically should just materialize + # the frame, the same way _PyFrame_GetFrameObject would do + # (but we cannot actually do this, because this populates + # frame_obj field, which default eval frame doesn't like). + # + # Fortunately, in this case, we can hack it: there's no need + # to actually use the truly top frame, we can just extract + # from where we are right now and rely on filter_stack to + # get rid of all the dynamo frames. 
For ease of testing + # we apply this behavior to ALL Python versions + stack_above_dynamo = filter_stack(extract_stack()) + + return cast(StackSummary, stack_above_dynamo + real_stack) + + +# filter out all frames after entering dynamo +def filter_stack(stack): + user_stack = [] + for frame in stack: + if "convert_frame" in frame.filename: + break + if "eval_frame" in frame.filename or "torch._dynamo.optimize(" in frame.line: + continue + user_stack.append(frame) + + return user_stack + + +def format_error_msg_verbose( + exc: Exception, code, record_filename=None, frame=None +) -> str: + msg = ( + f"WON'T CONVERT {code.co_name} {code.co_filename} line {code.co_firstlineno}\n" + ) + msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n" + msg += format_exc() + real_stack = get_real_stack(exc, frame) + if real_stack is not None: + msg += ( + "\n" + + "=" * 10 + + " The above exception occurred while processing the following code " + + "=" * 10 + + "\n\n" + ) + msg += "".join(format_list(real_stack)) + msg += "\n" + msg += "=" * 10 + + return msg + + +def format_error_msg(exc: Exception, code, record_filename=None, frame=None) -> str: + msg = os.linesep * 2 + + if config.verbose: + msg = format_error_msg_verbose(exc, code, record_filename, frame) + else: + msg = f"WON'T CONVERT {code.co_name} {code.co_filename}\ + line {code.co_firstlineno} \ndue to: \n{format_exc()}" + + return msg diff --git a/lib/python3.10/site-packages/torch/_dynamo/external_utils.py b/lib/python3.10/site-packages/torch/_dynamo/external_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..91663b4a99d175c14f8fd02df96d26d6a1d2e013 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/external_utils.py @@ -0,0 +1,144 @@ +# mypy: allow-untyped-defs +# This module contains functions that *will be allowed* by dynamo + +import functools +import warnings +from typing import List + +import torch +import torch.utils._pytree as pytree + + +try: + import numpy as np +except ModuleNotFoundError: + np = None # type: ignore[assignment] + + +def is_compiling() -> bool: + """ + Indicates whether we are tracing/compiling with torch.compile() or torch.export(). + + If need to check specifically that TorchDynamo is used, then use + torch.compiler.is_dynamo_compiling(). + + TODO(khabinov): we should deprecate this function and use one of these two: + * torch.compiler.is_compiling(), + * torch.compiler.is_dynamo_compiling(). + It will depend on the context where to use what. + """ + return torch.compiler.is_compiling() + + +def wrap_inline(fn): + """ + Create an extra frame around fn that is not in skipfiles + """ + + @functools.wraps(fn) + def inner(*args, **kwargs): + return fn(*args, **kwargs) + + return inner + + +def call_hook(hook, *args, **kwargs): + """ + Used by compiled autograd to handle hook returning None + """ + result = hook(*args) + if result is None: + return args[0] + elif kwargs["hook_type"] == "post_acc_grad_hook": + raise RuntimeError("Tensor post accumulate grad hooks should return None.") + return result + + +def wrap_numpy(f): + r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function + from ``torch.Tensor``s to ``torch.Tensor``s. 
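+ A minimal usage sketch (illustrative; ``np_add`` is a made-up name and NumPy
+ is assumed to be installed):
+
+ @wrap_numpy
+ def np_add(a, b):
+ # inside the body, a and b are np.ndarray
+ return a + b
+
+ out = np_add(torch.ones(3), torch.ones(3)) # out is a torch.Tensor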
+ """ + if not np: + return f + + @functools.wraps(f) + def wrap(*args, **kwargs): + args, kwargs = pytree.tree_map_only( + torch.Tensor, lambda x: x.numpy(), (args, kwargs) + ) + out = f(*args, **kwargs) + return pytree.tree_map_only(np.ndarray, lambda x: torch.as_tensor(x), out) + + return wrap + + +class FakeBackwardCFunction: + def __init__( + self, + real: torch.autograd.function.BackwardCFunction, + saved_tensors: List[torch.Tensor], + ) -> None: + self.real = real + self.saved_tensors = saved_tensors + + def __getattr__(self, name): + if name == "saved_variables": + warnings.warn( + "'saved_variables' is deprecated; use 'saved_tensors'", + DeprecationWarning, + ) + return self.saved_tensors + + # route any attribute that isn't defined on this obj + return getattr(self.real, name) + + +# This function corresponds to the "eager" implementation of a lifted autograd.Function.backward +def call_backward(backward_c_function, saved_tensors, *args): + fake = FakeBackwardCFunction(backward_c_function, saved_tensors) + grads = fake._forward_cls.backward(fake, *args) # type: ignore[attr-defined] + + # in eager, we wrap in a tuple when there's only one grad output + if type(grads) is not tuple: + grads = (grads,) + + return grads + + +def untyped_storage_size(x: torch.Tensor): + return x.untyped_storage().size() + + +class FakeCompiledAutogradEngine: + @staticmethod + def queue_callback(final_callbacks, cb): + final_callbacks.append(cb) + + @staticmethod + def exec_final_callbacks(final_callbacks): + i = 0 + while i < len(final_callbacks): + cb = final_callbacks[i] + cb() + i += 1 + final_callbacks.clear() + + @staticmethod + def _exec_final_callbacks_stub(): + pass + + +def call_hook_from_backward_state(*args, bw_state, hook_name: str, **kwargs): + return getattr(bw_state, hook_name)(*args, **kwargs) + + +def call_module_hooks_from_backward_state( + _, result, *args, bw_state, hooks_name: str, module_name: str +): + module = getattr(bw_state, module_name) + hooks = getattr(bw_state, hooks_name) + for hook in hooks: + new_result = hook(module, result, *args) + if new_result is not None: + result = new_result + return result diff --git a/lib/python3.10/site-packages/torch/_dynamo/funcname_cache.py b/lib/python3.10/site-packages/torch/_dynamo/funcname_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..fd9e278c871e2247a102dd917e8ba8a588f17ba8 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/funcname_cache.py @@ -0,0 +1,57 @@ +import tokenize +from typing import Dict, List, Optional + + +cache: Dict[str, Dict[int, str]] = {} + + +def clearcache() -> None: + cache.clear() + + +def _add_file(filename: str) -> None: + try: + with tokenize.open(filename) as f: + tokens = list(tokenize.generate_tokens(f.readline)) + except OSError: + cache[filename] = {} + return + + # NOTE: undefined behavior if file is not valid Python source, + # since tokenize will have undefined behavior. + result: Dict[int, str] = {} + # current full funcname, e.g. 
xxx.yyy.zzz + cur_name = "" + cur_indent = 0 + significant_indents: List[int] = [] + + for i, token in enumerate(tokens): + if token.type == tokenize.INDENT: + cur_indent += 1 + elif token.type == tokenize.DEDENT: + cur_indent -= 1 + # possible end of function or class + if significant_indents and cur_indent == significant_indents[-1]: + significant_indents.pop() + # pop the last name + cur_name = cur_name.rpartition(".")[0] + elif ( + token.type == tokenize.NAME + and i + 1 < len(tokens) + and tokens[i + 1].type == tokenize.NAME + and (token.string == "class" or token.string == "def") + ): + # name of class/function always follows class/def token + significant_indents.append(cur_indent) + if cur_name: + cur_name += "." + cur_name += tokens[i + 1].string + result[token.start[0]] = cur_name + + cache[filename] = result + + +def get_funcname(filename: str, lineno: int) -> Optional[str]: + if filename not in cache: + _add_file(filename) + return cache[filename].get(lineno, None) diff --git a/lib/python3.10/site-packages/torch/_dynamo/guards.py b/lib/python3.10/site-packages/torch/_dynamo/guards.py new file mode 100644 index 0000000000000000000000000000000000000000..1923426bf060f8630d3b3b1325f7c8e45f349f93 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/guards.py @@ -0,0 +1,2914 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import ast +import builtins +import collections +import dataclasses +import enum +import functools +import importlib +import inspect +import itertools +import logging +import math +import os +import re +import sys +import textwrap +import types +import weakref +from contextlib import contextmanager +from copy import deepcopy +from inspect import currentframe, getframeinfo +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Set, + Tuple, + Type, + TYPE_CHECKING, + Union, +) +from weakref import ReferenceType + +import torch +import torch.utils._device +from torch._C._dynamo.guards import ( + check_obj_id, + check_type_id, + dict_version, + DictGuardManager, + install_no_tensor_aliasing_guard, + install_object_aliasing_guard, + RootGuardManager, + TensorGuards, +) +from torch._dynamo.source import ( + is_from_flatten_script_object_source, + is_from_local_source, + is_from_optimizer_source, + TensorProperty, + TensorPropertySource, +) +from torch._guards import ( + CompileContext, + CompileId, + DuplicateInputs, + Guard, + GuardBuilderBase, + GuardEnvExpr, + GuardSource, + Source, +) +from torch._logging import structured +from torch._utils_internal import justknobs_check +from torch.fx.experimental.symbolic_shapes import ( + EqualityConstraint, + is_symbolic, + SYMPY_INTERP, +) +from torch.utils._traceback import format_frame, report_compile_source_on_error +from torch.utils.weak import TensorWeakRef + +from . 
import config, convert_frame, exc, mutation_guard +from .eval_frame import set_guard_error_hook +from .source import ( + AttrProxySource, + AttrSource, + ChainedSource, + ConstDictKeySource, + DefaultsSource, + FlattenScriptObjectSource, + FSDPNNModuleSource, + GetItemSource, + GlobalSource, + GlobalStateSource, + GlobalWeakRefSource, + GradSource, + LocalSource, + NNModuleSource, + NumpyTensorSource, + ODictGetItemSource, + OptimizerSource, + ScriptObjectQualifiedNameSource, + ShapeEnvSource, + SubclassAttrListSource, + TupleIteratorGetItemSource, + TypeSource, + UnspecializedBuiltinNNModuleSource, + UnspecializedNNModuleSource, + UnspecializedParamBufferSource, + WeakRefCallSource, +) +from .types import CacheEntry, ExtraState, GuardedCode, GuardFail, GuardFn # noqa: F401 +from .utils import ( + common_constant_types, + dict_keys_repr, + get_custom_getattr, + get_torch_function_mode_stack, + guard_failures, + istype, + key_is_id, + key_to_id, + orig_code_map, + tensor_always_has_static_shape, + tuple_iterator_getitem, + tuple_iterator_len, + unpatched_nn_module_getattr, + verify_guard_fn_signature, +) + + +try: + import numpy as np +except ModuleNotFoundError: + np = None # type: ignore[assignment] + + +if TYPE_CHECKING: + from sympy import Symbol + + +log = logging.getLogger(__name__) +guards_log = torch._logging.getArtifactLogger(__name__, "guards") +recompiles_log = torch._logging.getArtifactLogger(__name__, "recompiles") +recompiles_verbose_log = torch._logging.getArtifactLogger( + __name__, "recompiles_verbose" +) +verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards") + + +class GuardManager: + """ + A helper class that contains the root guard manager. An instance of this + class is stored in the Dynamo cache entry, so that the cache entry can + access the RootGuardManager stored in the "root" attribute and directly call + the check_nopybind from C++. 
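+ For debugging, a populated instance can also be evaluated directly from
+ Python (illustrative; ``f_locals`` stands in for the frame's local-scope dict
+ that the installed RootGuardManager expects):
+
+ ok = guard_manager.check(f_locals) # bool, see check() below
+ info = guard_manager.check_verbose(f_locals) # verbose failure details
+ print(guard_manager) # pretty-printed guard tree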
+ """ + + def __init__(self): + self.root = RootGuardManager() + + self.closure_vars = None + self.args = None + self.code_parts = [] + self.verbose_code_parts = None + self.global_scope = None + self.guard_fail_fn = None + self.cache_entry = None + self.extra_state = None + self.id_matched_objs = None + self.no_tensor_aliasing_sources = [] + + self.print_no_tensor_aliasing_guard = True + + @contextmanager + def _preserve_print_no_tensor_aliasing_flag(self): + self.print_no_tensor_aliasing_guard = True + try: + yield + finally: + self.print_no_tensor_aliasing_guard = True + + def get_guard_lines(self, guard): + guard_name = guard.__class__.__name__ + parts = guard.verbose_code_parts() + parts = [guard_name + ": " + part for part in parts] + return parts + + def get_manager_line(self, guard_manager, accessor_str=None): + source = guard_manager.get_source() + t = guard_manager.__class__.__name__ + s = t + ": source=" + source + if accessor_str: + s += ", " + accessor_str + return s + + def construct_dict_manager_string(self, mgr, body): + for idx, (key_mgr, val_mgr) in sorted(mgr.get_key_value_managers().items()): + body.writeline(f"KeyValueManager pair at index={idx}") + with body.indent(): + if key_mgr: + body.writeline(f"KeyManager: {self.get_manager_line(key_mgr)}") + self.construct_manager_string(key_mgr, body) + + if val_mgr: + body.writeline(f"ValueManager: {self.get_manager_line(val_mgr)}") + self.construct_manager_string(val_mgr, body) + + def construct_manager_string(self, mgr, body): + with body.indent(): + for guard in mgr.get_leaf_guards(): + if isinstance(guard, torch._C._dynamo.guards.NO_TENSOR_ALIASING): # type: ignore[attr-defined] + if self.print_no_tensor_aliasing_guard: + self.print_no_tensor_aliasing_guard = False + body.writelines(self.get_guard_lines(guard)) + else: + body.writelines( + [ + guard.__class__.__name__, + ] + ) + else: + body.writelines(self.get_guard_lines(guard)) + + # This works for both DictGuardManager and SubclassedDictGuardManager + if isinstance(mgr, DictGuardManager): + self.construct_dict_manager_string(mgr, body) + + # General case of GuardManager/RootGuardManager + for accessor, child_mgr in zip( + mgr.get_accessors(), mgr.get_child_managers() + ): + body.writeline( + self.get_manager_line(child_mgr, f"accessed_by={accessor.repr()}") + ) + self.construct_manager_string(child_mgr, body) + + def __str__(self): + from torch._inductor.utils import IndentedBuffer + + class IndentedBufferWithPrefix(IndentedBuffer): + def prefix(self): + return "| " * (self._indent * self.tabwidth) + + def writeline(self, line, skip_prefix=False): + if skip_prefix: + super().writeline(line) + else: + super().writeline("+- " + line) + + with self._preserve_print_no_tensor_aliasing_flag(): + body = IndentedBufferWithPrefix() + body.tabwidth = 1 + body.writeline("", skip_prefix=True) + body.writeline("TREE_GUARD_MANAGER:", skip_prefix=True) + body.writeline("RootGuardManager") + self.construct_manager_string(self.root, body) + for guard in self.root.get_epilogue_lambda_guards(): + body.writelines(self.get_guard_lines(guard)) + return body.getvalue() + + def check(self, x): + # Only needed for debugging purposes. + return self.root.check(x) + + def check_verbose(self, x): + # Only needed for debugging purposes. 
+ return self.root.check_verbose(x) + + def populate_code_parts_for_debugging(self): + # This should be called when the guard manager is fully populated + tensor_aliasing_guard_seen = False + + def get_code_parts(leaf_guard): + code_parts = [] + for verbose_code_part in leaf_guard.verbose_code_parts(): + code_part = verbose_code_part.split("#")[0].rstrip() + code_parts.append(code_part) + return code_parts + + def visit(mgr): + nonlocal tensor_aliasing_guard_seen + for guard in mgr.get_leaf_guards(): + if isinstance(guard, torch._C._dynamo.guards.NO_TENSOR_ALIASING): # type: ignore[attr-defined] + if not tensor_aliasing_guard_seen: + self.code_parts.extend(get_code_parts(guard)) + tensor_aliasing_guard_seen = True + else: + self.code_parts.extend(get_code_parts(guard)) + + for child_mgr in mgr.get_child_managers(): + visit(child_mgr) + + visit(self.root) + + +def from_numpy(a): + # If not numpy array, piggy back on e.g. tensor guards to check type + return torch.as_tensor(a) if isinstance(a, (np.generic, np.ndarray)) else a + + +# For user stack printing +@functools.lru_cache(None) +def uninteresting_files(): + import torch._dynamo.external_utils + + mods = [ + torch._dynamo.external_utils, + ] + return {inspect.getfile(m) for m in mods} + + +CLOSURE_VARS = { + "___check_type_id": check_type_id, + "___check_obj_id": check_obj_id, + "___odict_getitem": collections.OrderedDict.__getitem__, + "___key_to_id": key_to_id, + "___dict_version": dict_version, + "___dict_contains": lambda a, b: a in b, + "___tuple_iterator_len": tuple_iterator_len, + "___tuple_iterator_getitem": tuple_iterator_getitem, + "__math_isnan": math.isnan, + "__numpy_isnan": None if np is None else np.isnan, + "inf": float("inf"), + "__load_module": importlib.import_module, + "utils_device": torch.utils._device, + "device": torch.device, + "___from_numpy": from_numpy, + "___as_tensor": torch.as_tensor, + "torch": torch, + "inspect": inspect, +} + +if sys.version_info[:2] <= (3, 8): + # [Note: Python Version <= 3.8] + # This branch should be dropped when we drop support for Python 3.8. + # Reason: 'ast.unparse' function was introduced in Python 3.9. 
+ + try: + import astunparse # type: ignore[import] + + def _ast_unparse(node: ast.AST) -> str: + return astunparse.unparse(node).replace("\n", "") + + HAS_UNPARSE_FUNCTIONS = True + except ImportError: + HAS_UNPARSE_FUNCTIONS = False +else: + HAS_UNPARSE_FUNCTIONS = True + + def _ast_unparse(node: ast.AST) -> str: + return ast.unparse(node).replace("\n", "") + + +def strip_function_call(name): + """ + "___odict_getitem(a, 1)" => "a" + "a.layers[slice(2)][0]._xyz" ==> "a" + "getattr(a.layers[slice(2)][0]._abc, '0')" ==> "a" + "getattr(getattr(a.x[3], '0'), '3')" ==> "a" + "a.layers[slice(None, -1, None)][0]._xyz" ==> "a" + """ + # recursively find valid object name in function + valid_name = re.compile("[A-Za-z_].*") + curr = "" + for char in name: + if char in " (": + curr = "" + elif char in "),[]": + if curr and curr != "None" and valid_name.match(curr): + return strip_function_call(curr) + else: + curr += char + + return strip_getattr_getitem(name) + + +def strip_getattr_getitem(name): + """ + "a[1]" => "a" + "a.foo" => "a" + """ + return re.split(r"[.\[]", name)[0] + + +def get_verbose_code_part(code_part: str, guard: Guard) -> str: + extra = "" + if guard.user_stack: + for fs in reversed(guard.user_stack): + if fs.filename not in uninteresting_files(): + extra = f" # {format_frame(fs, line=True)}" + break + elif guard.stack: + extra = f" # {format_frame(guard.stack.summary()[-1])}" + + return f"{code_part:<60}{extra}" + + +def get_verbose_code_parts( + code_parts: Union[str | List[str]], guard: Guard +) -> List[str]: + if not isinstance(code_parts, list): + code_parts = [code_parts] + return [get_verbose_code_part(code_part, guard) for code_part in code_parts] + + +def convert_to_concrete_values(size_or_stride): + converted: List[Optional[int]] = [] + for dim in size_or_stride: + if not is_symbolic(dim): + converted.append(dim) + else: + assert isinstance(dim, torch.SymInt) + converted.append(dim.node.maybe_as_int()) + return converted + + +def get_tensor_guard_code_part(value, name, sizes, strides): + pytype = type(value) + dispatch_key = ( + torch._C._dispatch_keys(value) | torch._C._dispatch_tls_local_include_set() + ) - torch._C._dispatch_tls_local_exclude_set() + dtype = value.dtype + device_index = value.device.index + requires_grad = value.requires_grad + guard_str = ( + f"check_tensor({name}, {pytype.__qualname__}, {dispatch_key}, {dtype}, " + f"device={device_index}, requires_grad={requires_grad}, size={sizes}, stride={strides})" + ) + return guard_str + + +def get_key_index(dct, key): + return list(dct.keys()).index(key) + + +def get_key_index_source(source, index): + return f"list({source}.keys())[{index}]" + + +@dataclasses.dataclass(frozen=True) +class NNModuleAttrAccessorInfo: + # Represents where is the attr name is present in the nn module attribute + # access + + # Tells that the attribute can be accessed via __dict__ + present_in_generic_dict: bool = False + + # Either the actual name or _parameters/_buffers/_modules + l1_key: Optional[str] = None + + # Actual paramter/buffer/submodule name + l2_key: Optional[str] = None + + +def getitem_on_dict_manager( + source, base_guard_manager, base_example_value, example_value, guard_manager_enum +): + base_source_name = source.base.name() + source_name = source.name() + if isinstance(source.index, ConstDictKeySource): + index = source.index.index + else: + assert isinstance(base_example_value, dict) + index = get_key_index(base_example_value, source.index) + + key_source = get_key_index_source(base_source_name, index) + 
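+ # e.g. key_source == "list(L['d'].keys())[0]" for index 0 (illustrative value;
+ # built by get_key_index_source above)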
key_example_value = list(base_example_value.keys())[index] + if isinstance(key_example_value, (int, str)): + value_source = f"{base_source_name}[{key_example_value!r}]" + else: + value_source = f"{base_source_name}[{key_source}]" + if not isinstance(source.index, ConstDictKeySource): + # We have to insert a key manager guard here + # TODO - source debug string is probably wrong here. + base_guard_manager.get_key_manager( + index=index, + source=key_source, + example_value=source.index, + guard_manager_enum=GuardManagerType.GUARD_MANAGER, + ).add_equals_match_guard( + source.index, [f"{key_source} == {key_example_value!r}"] + ) + + return base_guard_manager.get_value_manager( + index=index, + source=value_source, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + + +def match_on_id_for_tensor(guard): + source = guard.originating_source + return source.is_dict_key() and not isinstance(source, GradSource) + + +# The ready to eval generated code (possibly multiple parts) for a guard, plus +# the original guard object that created it for provenance +@dataclasses.dataclass +class GuardCodeList: + code_list: List[str] + guard: Guard + + +class GuardManagerType(enum.Enum): + GUARD_MANAGER = 1 + DICT_GUARD_MANAGER = 2 + DICT_SUBCLASS_GUARD_MANAGER = 3 + + +class GuardBuilder(GuardBuilderBase): + def __init__( + self, + id_ref: Callable[[Any], str], + source_ref: Callable[[Source], str], + lookup_weakrefs: Callable[[object], ReferenceType[object]], + local_scope: Dict[str, object], + global_scope: Dict[str, object], + guard_manager: Optional[GuardManager], + check_fn_manager: CheckFunctionManager, + ): + self.id_ref = id_ref + self.source_ref = source_ref + self.lookup_weakrefs = lookup_weakrefs + self.scope: Dict[str, Dict[str, object]] = {"L": local_scope, "G": global_scope} + self.scope["__builtins__"] = builtins.__dict__.copy() + for ( + name, + package_module, + ) in torch.package.package_importer._package_imported_modules.items(): + name = name.replace(">", "_").replace("<", "_").replace(".", "_dot_") + # Write the package module into the scope so that we can import it + self.scope["__builtins__"][name] = package_module + # Write the demangled name to the scope so that we can use it + self.scope[name] = package_module + self.guard_manager = guard_manager + + self.argnames: List[str] = [] + # Code is python expression strings generated for each guard + self.code: List[GuardCodeList] = [] + # shape_env_code is only used by builder and is used for + # shape env code. This exists only because we need to make sure + # shape env guards get run after tensor match guards (since the + # tensor match guards make sure we actually have tensors) + self.shape_env_code: List[GuardCodeList] = [] + + # [Note - On Eager Tensor Guards] + # Most of the time, we generate Python code in a guard to directly + # check various properties. However, tensors are a bit special; + # it is too slow to check their properties one-by-one in Python. + # Instead, there is a C++ function TensorGuards.check which takes + # all of the tensor arguments and checks them all against compile-time + # examples entirely in C++. Thus, every time we process a + # TENSOR_MATCH guard, we just add another entry to + # tensor_check_names/tensor_check_examples, saying "for this local, + # check it against this example", and it all ends up getting + # swept up into a single call to ___check_tensors. Invariant: + # len(tensor_check_names) == len(tensor_check_examples). 
+ # TODO: something here + self.tensor_check_names: List[str] = [] + self.tensor_check_examples: List[torch.Tensor] = [] + self.tensor_check_guards: List[Guard] = [] + self.tensor_check_guard_managers: List[GuardManager] = [] + + self.check_fn_manager: CheckFunctionManager = check_fn_manager + + # Collect the ids of dicts which need key order guarding. source_name is + # not sufficient because for nn modules, we can have different sources + # to access the same object - self._module["param"] is same as + # self.param. + self.key_order_guarded_dict_ids = set() + for source_name in self.check_fn_manager.output_graph.guard_on_key_order: + self.key_order_guarded_dict_ids.add(id(self.get(source_name))) + + # Keep track of weak references of objects with ID_MATCH guard. This + # info is stored alongside optimized_code and check_fn and is used to + # limit the number of cache entries with same ID_MATCH'd object. + self.id_matched_objs: Dict[str, ReferenceType[object]] = {} + + # Save the guard managers to avoid repeatedly traversing sources. + self._cached_guard_managers: Dict[ + str, torch._C._dynamo.guards.GuardManager + ] = {} + + self._cached_duplicate_input_guards: Set[Tuple[str, str]] = set() + + def guard_on_dict_keys_and_ignore_order(self, example_value, guard): + dict_mgr = self.get_guard_manager(guard) + if isinstance(dict_mgr, DictGuardManager): + raise NotImplementedError( + "Not expecting a DictGuardManager. Seems like Dynamo incorrectly " + f"added the dict to tx.output.guard_on_key_order for {guard.name}" + ) + + # Iterate over the dicts and install a dict_getitem_manager. + dict_source = guard.originating_source.name() + for key in example_value.keys(): + value = example_value[key] + value_source = GetItemSource(guard.originating_source, index=key) + guard_manager_enum = self.get_guard_manager_type( + value_source, example_value + ) + dict_mgr.dict_getitem_manager( + key=key, + source=f"{dict_source}[{key!r}]", + example_value=value, + guard_manager_enum=guard_manager_enum, + ) + + def guard_on_dict_keys_and_order(self, value, guard): + # Add key managers for the DictGuardManager. Then add either an + # ID_MATCH or EQUALS_MATCH guard on the key. + dict_mgr = self.get_guard_manager(guard) + if not isinstance(dict_mgr, DictGuardManager): + raise NotImplementedError( + "Expecting a DictGuardManager. Seems like Dynamo forgot " + f"to set the right guard manager enum for {guard.name}" + ) + assert isinstance(dict_mgr, DictGuardManager) + + for idx, key in enumerate(value.keys()): + key_source = get_key_index_source(guard.name, idx) + key_manager = dict_mgr.get_key_manager( + index=idx, + source=key_source, + example_value=key, + guard_manager_enum=GuardManagerType.GUARD_MANAGER, + ) + if key_is_id(key): + # Install ID_MATCH guard + id_val = self.id_ref(key) + key_manager.add_id_match_guard( + id_val, + get_verbose_code_parts( + f"__check_obj_id({key_source}, {id_val})", guard + ), + ) + else: + # Install EQUALS_MATCH guard + key_manager.add_equals_match_guard( + key, get_verbose_code_parts(f"{key_source} == {key!r}", guard) + ) + + def getattr_on_nn_module( + self, + source, + base_guard_manager, + base_example_value, + example_value, + base_source_name, + source_name, + guard_manager_enum, + ): + """ + This tries to avoid calling the expensive nn module custom getattr method by + checking if the attribute is accessible via __dict__. For attributes that + are not accessible via __dict__ (like descriptors), we fallback to + PyObject_GetAttr. 
+ + There are two cases that we optimize for + 1) attributes present directly in __dict__, e.g training. + 2) parameters/buffers/modules - they can be accessed via _parameters, + _buffers, _modules keys in __dict__. For example, mod.linear can be + accessed as mod.__dict__["_parameters"]["linear"] + + The most common and expensive case for nn module guards is of type + mod.submod1.submod2.submod3.training. We avoid the python getattr of nn + modules by going through the __dict__. + """ + + def getitem_on_dict_mgr( + mgr, key, source_name, base_example_value, example_value, guard_manager_enum + ): + if isinstance(mgr, DictGuardManager): + # Case where the user code relies on key order, e.g., + # named_parameters + index = get_key_index(base_example_value, key) + + # Install the key manager and add equals match guard + key_source = f"list({source_name}.keys())[{index!r}]" + mgr.get_key_manager( + index=index, + source=key_source, + example_value=key, + guard_manager_enum=GuardManagerType.GUARD_MANAGER, + ).add_equals_match_guard(key, [f"{key_source} == {key!r}"]) + + # Install the value manager + return mgr.get_value_manager( + index=index, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + else: + return mgr.dict_getitem_manager( + key=key, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + + attr_name = source.member + mod_dict = base_example_value.__dict__ + + all_class_attribute_names: Set[str] = set() + for x in inspect.getmro(base_example_value.__class__): + all_class_attribute_names.update(x.__dict__.keys()) + + accessor_info = NNModuleAttrAccessorInfo(False, None, None) + + if attr_name in mod_dict: + accessor_info = NNModuleAttrAccessorInfo(True, attr_name, None) + elif "_parameters" in mod_dict and attr_name in mod_dict["_parameters"]: + accessor_info = NNModuleAttrAccessorInfo(True, "_parameters", attr_name) + elif "_buffers" in mod_dict and attr_name in mod_dict["_buffers"]: + accessor_info = NNModuleAttrAccessorInfo(True, "_buffers", attr_name) + elif ( + attr_name not in all_class_attribute_names + and "_modules" in mod_dict + and attr_name in mod_dict["_modules"] + ): + # Check test_attr_precedence test - instance attributes always take precedence unless its an nn.Module. + accessor_info = NNModuleAttrAccessorInfo(True, "_modules", attr_name) + + if not accessor_info.present_in_generic_dict: + # The attribute can be accessed by __getattribute__ call, so rely on + # PyObject_GetAttr + return base_guard_manager.getattr_manager( + attr=source.member, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + else: + assert accessor_info.l1_key + l1_key = accessor_info.l1_key + l2_key = accessor_info.l2_key + + # Set source strings for debug info + mod_dict_source = f"{base_source_name}.__dict__" + l1_source_name = l2_source_name = None + l1_value = l2_value = None + l1_guard_manager_enum = l2_guard_manager_enum = None + if l2_key: + l1_source = AttrSource(source.base, l1_key) + l1_source_name = l1_source.name() + l1_value = mod_dict[l1_key] + # do not guard on key order for _parameters etc unless the user code + # actually needs the key order (e.g. 
calling named_parameters) + l1_guard_manager_enum = self.get_guard_manager_type(l1_source, l1_value) + + l2_source_name = source_name + l2_value = example_value + l2_guard_manager_enum = self.get_guard_manager_type( + source, example_value + ) + else: + l1_source_name = source_name + l1_value = example_value + l1_guard_manager_enum = self.get_guard_manager_type( + source, example_value + ) + + # Get __dict__ accessor. No need to guard on dict key order, so use base + # Guard Manager + mod_generic_dict_manager = base_guard_manager.get_generic_dict_manager( + source=mod_dict_source, + example_value=mod_dict, + guard_manager_enum=GuardManagerType.GUARD_MANAGER, + ) + + l1_mgr = getitem_on_dict_mgr( + mgr=mod_generic_dict_manager, + key=l1_key, + source_name=l1_source_name, + base_example_value=mod_dict, + example_value=l1_value, + guard_manager_enum=l1_guard_manager_enum, + ) + + if l2_key: + return getitem_on_dict_mgr( + mgr=l1_mgr, + key=l2_key, + source_name=l2_source_name, + base_example_value=l1_value, + example_value=l2_value, + guard_manager_enum=l2_guard_manager_enum, + ) + return l1_mgr + + def requires_key_order_guarding(self, source): + source_name = source.name() + if source_name == "": + return False + obj_id = id(self.get(source_name)) + return obj_id in self.key_order_guarded_dict_ids + + def get_guard_manager_type(self, source, example_value): + guard_manager_enum = GuardManagerType.GUARD_MANAGER + if self.requires_key_order_guarding(source): + assert isinstance(example_value, dict) + # If keys method is not overriden, we can use PyDict_Next to get key + # orderings. Read more in guards.cpp + if type(example_value).keys is type({}).keys: + guard_manager_enum = GuardManagerType.DICT_GUARD_MANAGER + else: + guard_manager_enum = GuardManagerType.DICT_SUBCLASS_GUARD_MANAGER + return guard_manager_enum + + def manager_guards_on_keys(self, mgr_enum): + return ( + mgr_enum == GuardManagerType.DICT_GUARD_MANAGER + or mgr_enum == GuardManagerType.DICT_SUBCLASS_GUARD_MANAGER + ) + + def get_global_guard_manager(self): + assert self.guard_manager # to make mypy happy + return self.guard_manager.root.globals_dict_manager( + f_globals=self.scope["G"], + source="G", + example_value=self.scope["G"], + guard_manager_enum=GuardManagerType.GUARD_MANAGER, + ) + + def get_guard_manager_from_source(self, source): + assert self.guard_manager # to make mypy happy + root_guard_manager = self.guard_manager.root + + example_value = None + source_name = source.name() + + if source_name != "" and source_name in self._cached_guard_managers: + return self._cached_guard_managers[source_name] + + if source_name != "": + example_value = self.get(source_name) + + guard_manager_enum = self.get_guard_manager_type(source, example_value) + + # Get base manager related information + base_source_name = None + base_example_value = None + base_guard_manager = None + base_guard_manager_enum = GuardManagerType.GUARD_MANAGER + if isinstance(source, ChainedSource): + base_source_name = source.base.name() + base_example_value = self.get(base_source_name) + base_guard_manager = self.get_guard_manager_from_source(source.base) + base_guard_manager_enum = self.get_guard_manager_type( + source.base, base_example_value + ) + + # Use istype instead of isinstance to check for exact type of source. + if istype(source, LocalSource): + # RootGuardManager accepts a dict but still its not a + # DictGuardManager because we will eventually move to + # fastlocals. 
+ out = root_guard_manager.dict_getitem_manager( + key=source.local_name, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, GlobalSource): + # Global manager accepts a dict but it is not a DictGuardManager + # because globals dict is big and we typically guard on a very + # selected items on globals. + out = self.get_global_guard_manager().dict_getitem_manager( + key=source.global_name, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, GlobalWeakRefSource): + out = self.get_global_guard_manager().global_weakref_manager( + global_name=source.global_name, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, GlobalStateSource): + # Don't do anything here. We guard on global state completely in + # C++. So just return the root mgr. + return root_guard_manager + elif istype(source, ShapeEnvSource): + return root_guard_manager + elif istype(source, TypeSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.type_manager( + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype( + source, + ( + OptimizerSource, + NNModuleSource, + UnspecializedNNModuleSource, + UnspecializedBuiltinNNModuleSource, + FSDPNNModuleSource, + ), + ): + assert base_guard_manager # to make mypy happy + out = base_guard_manager + elif istype(source, GradSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.grad_manager( + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, (AttrSource, UnspecializedParamBufferSource)): + assert base_guard_manager # to make mypy happy + + if ( + isinstance(base_example_value, torch.nn.Module) + and get_custom_getattr(base_example_value) + is unpatched_nn_module_getattr + ): + out = self.getattr_on_nn_module( + source, + base_guard_manager, + base_example_value, + example_value, + base_source_name, + source_name, + guard_manager_enum, + ) + else: + out = base_guard_manager.getattr_manager( + attr=source.member, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, GetItemSource): + assert base_guard_manager # to make mypy happy + if isinstance(base_example_value, (dict, collections.OrderedDict)): + # TODO(anijain2305) - Consider isolating GetItemSource and + # DictGetItemSource (or maybe use ODictGetItemSource for + # dicts) so that GetItemSource is only for non dict objects. + if isinstance(base_guard_manager, DictGuardManager): + assert self.manager_guards_on_keys(base_guard_manager_enum) + out = getitem_on_dict_manager( + source, + base_guard_manager, + base_example_value, + example_value, + guard_manager_enum, + ) + else: + if isinstance(source.index, ConstDictKeySource): + raise RuntimeError( + "Expecting clean index here. 
Likely Dynamo forgot to mark" + " a dict as guard_on_key_order" + ) + out = base_guard_manager.dict_getitem_manager( + key=source.index, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif isinstance(base_example_value, list) and not source.index_is_slice: + out = base_guard_manager.list_getitem_manager( + key=source.index, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif isinstance(base_example_value, tuple) and not source.index_is_slice: + out = base_guard_manager.tuple_getitem_manager( + key=source.index, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + else: + index = source.index + if source.index_is_slice: + index = source.unpack_slice() + out = base_guard_manager.getitem_manager( + key=index, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, ODictGetItemSource): + if isinstance(base_guard_manager, DictGuardManager): + assert self.manager_guards_on_keys(base_guard_manager_enum) + out = getitem_on_dict_manager( + source, + base_guard_manager, + base_example_value, + example_value, + guard_manager_enum, + ) + else: + assert base_guard_manager # to make mypy happy + out = base_guard_manager.dict_getitem_manager( + key=source.index, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, DefaultsSource): + assert base_guard_manager # to make mypy happy + assert callable(base_example_value) + if not source.is_kw: + out = base_guard_manager.func_defaults_manager( + source=base_source_name, + example_value=base_example_value.__defaults__, + guard_manager_enum=GuardManagerType.GUARD_MANAGER, + ).getitem_manager( + key=source.idx_key, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + else: + # kwdefauts is a dict, so use a DictGuardManager + kwdefaults = base_example_value.__kwdefaults__ + assert base_source_name is not None + kw_source = base_source_name + ".__kwdefaults__" + + # kwdefaults is a dict. No need to guard on dict order. 
+ dict_mgr = base_guard_manager.func_kwdefaults_manager( + source=kw_source, + example_value=kwdefaults, + guard_manager_enum=GuardManagerType.GUARD_MANAGER, + ) + assert not isinstance(dict_mgr, DictGuardManager) + + out = dict_mgr.dict_getitem_manager( + key=source.idx_key, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, NumpyTensorSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.lambda_manager( + python_lambda=from_numpy, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, SubclassAttrListSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.lambda_manager( + python_lambda=lambda x: x.__tensor_flatten__()[0], + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, FlattenScriptObjectSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.lambda_manager( + python_lambda=lambda x: x.__obj_flatten__(), + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, ScriptObjectQualifiedNameSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.lambda_manager( + python_lambda=lambda x: x._type().qualified_name(), + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, AttrProxySource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.lambda_manager( + python_lambda=lambda x: x.get_base(), + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, TupleIteratorGetItemSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.tuple_iterator_getitem_manager( + index=source.index, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif isinstance(source, ConstDictKeySource): + if not isinstance(base_guard_manager, DictGuardManager): + raise AssertionError( + "ConstDictKeySource can only work on DictGuardManager" + ) + out = base_guard_manager.get_key_manager( + index=source.index, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif isinstance(source, WeakRefCallSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.weakref_call_manager( + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + else: + raise AssertionError( + f"missing guard manager builder {source} - {source.name()}" + ) + + self._cached_guard_managers[source.name()] = out + return out + + def get_guard_manager(self, guard: Guard): + return self.get_guard_manager_from_source(guard.originating_source) + + def add_python_lambda_leaf_guard_to_root( + self, + code_parts, + verbose_code_parts, + closure_vars=CLOSURE_VARS, + is_epilogue=True, + ): + # Adds a lambda leaf guard to the root guard manager. It wraps the + # code_parts in a function object which is then passed on to the leaf + # guard. 
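+ # Roughly (illustrative sketch, not the exact generated source), the pycode
+ # built below defines ___make_guard_fn(<closure vars>) returning a guard(L)
+ # callable that evaluates each code_part in turn and returns False on the
+ # first failing part, True otherwise; exec'ing it and calling
+ # ___make_guard_fn(*closure_vars.values()) yields the function handed to
+ # add_lambda_guard / add_epilogue_lambda_guard.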
+ make_guard_fn_args = ", ".join(closure_vars.keys()) + guard_body, pycode = build_guard_function(code_parts, make_guard_fn_args) + out: Dict[str, Any] = {} + globals_for_guard_fn = {"G": self.scope["G"]} + exec(pycode, globals_for_guard_fn, out) + guard_fn = out["___make_guard_fn"](*closure_vars.values()) + assert self.guard_manager # to make mypy happy + if is_epilogue: + # Epilogue guards are run after all the other guards have finished. + # If epilogue guards contain a getattr or getitem access, one of the + # other guards would fail preventing the epilogue guards to run. + self.guard_manager.root.add_epilogue_lambda_guard( + guard_fn, verbose_code_parts + ) + else: + self.guard_manager.root.add_lambda_guard(guard_fn, verbose_code_parts) + + # Warning: use this with care! This lets you access what the current + # value of the value you are guarding on is. You probably don't want + # to actually durably save this value though (because it's specific + # to this frame!) Instead, you should be reading out some property + # (like its type) which is what you permanently install into the + # guard code. + def get(self, name: str) -> Any: + return eval(name, self.scope, CLOSURE_VARS) + + # Registers the usage of the source name referenced by the + # string (or stored in the Guard) as being guarded upon. It's important + # to call this before generating some code that makes use of 'guard', + # because without this call, we won't actually bind the variable + # you reference in the actual guard closure (oops!) + def arg_ref(self, guard: Union[str, Guard]) -> str: + name: str + if isinstance(guard, str): + name = guard + else: + name = guard.name + base = strip_getattr_getitem(strip_function_call(name)) + if base not in self.argnames: + if re.match(r"[a-zA-Z0-9_]+", base): + if re.match(r"^\d+$", base): + log.warning("invalid var name: %s", guard) + self.argnames.append(base) + + return name + + def _guard_on_attribute(self, guard: Guard, attr_name: str, guard_fn): + attr_source = AttrSource(guard.originating_source, attr_name) + # Copy the stack info + new_guard = Guard( + attr_source, guard_fn, stack=guard.stack, user_stack=guard.user_stack + ) + new_guard.create(self) + + # Note: the order of the guards in this file matters since we sort guards on the same object by lineno + def HASATTR(self, guard: Guard): + source = guard.originating_source + if isinstance(source, NNModuleSource): + source = source.base + assert isinstance(source, AttrSource), f"invalid source {guard.name}" + base_source = source.base + base = base_source.name() + attr = source.member + + ref = self.arg_ref(base) + val = hasattr(self.get(base), attr) + code = None + if val: + code = f"hasattr({ref}, {attr!r})" + else: + code = f"not hasattr({ref}, {attr!r})" + self._set_guard_export_info( + guard, [code], provided_guarded_object=self.get(base) + ) + + if config.enable_cpp_guard_manager: + base_manager = self.get_guard_manager_from_source(base_source) + if val: + # Just install a getattr manager. GetAttrGuardAccessor itself + # acts as hasattr guard. + example_value = self.get(source.name()) + base_example_value = self.get(base) + guard_manager_enum = self.get_guard_manager_type(source, example_value) + + # if the base value is nn.Module, check if we can speedup the + # guard by going through __dict__ attrs. 
+ if ( + isinstance(base_example_value, torch.nn.Module) + and get_custom_getattr(base_example_value) + is unpatched_nn_module_getattr + ): + return self.getattr_on_nn_module( + source, + base_manager, + base_example_value, + example_value, + base, + source.name(), + guard_manager_enum, + ) + else: + base_manager.getattr_manager( + attr=attr, + source=guard.name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + else: + base_manager.add_no_hasattr_guard( + attr, get_verbose_code_parts(code, guard) + ) + else: + self._produce_guard_code(guard, [code]) + + def NOT_PRESENT_IN_GENERIC_DICT(self, guard: Guard, attr=None) -> None: + assert attr is not None + ref = self.arg_ref(guard) + val = self.get(guard.name) + assert isinstance(val, torch.nn.Module) + + base_manager = self.get_guard_manager(guard) + + mod_dict_source = f"{guard.name}.__dict__" + mod_generic_dict_manager = base_manager.get_generic_dict_manager( + source=mod_dict_source, + example_value=val.__dict__, + guard_manager_enum=GuardManagerType.GUARD_MANAGER, + ) + + code = f"not ___dict_contains({attr!r}, {ref}.__dict__)" + mod_generic_dict_manager.add_dict_contains_guard( + False, attr, get_verbose_code_parts(code, guard) + ) + + def TYPE_MATCH(self, guard: Guard) -> None: + # ___check_type_id is same as `id(type(x)) == y` + t = type(self.get(guard.name)) + obj_id = self.id_ref(t) + code = f"___check_type_id({self.arg_ref(guard)}, {obj_id})" + self._set_guard_export_info(guard, [code]) + + if config.enable_cpp_guard_manager: + self.get_guard_manager(guard).add_type_match_guard( + obj_id, get_verbose_code_parts(code, guard) + ) + else: + self._produce_guard_code(guard, [code]) + + def DICT_VERSION(self, guard: Guard): + # ___check_dict_version is same as `dict_version(x) == y` + ref = self.arg_ref(guard) + val = self.get(guard.name) + version = dict_version(self.get(guard.name)) + code = f"___dict_version({ref}) == {version}" + self._set_guard_export_info(guard, [code]) + + if config.enable_cpp_guard_manager: + # TODO(anijain2305) - Delete this when DictGuardManager uses tags + # for dicts. + self.get_guard_manager(guard).add_dict_version_guard( + val, get_verbose_code_parts(code, guard) + ) + else: + self._produce_guard_code(guard, [code]) + + def DICT_CONTAINS(self, guard: Guard, key: str, invert: bool): + dict_ref = self.arg_ref(guard) + + maybe_not = "not " if invert else "" + code = f"{maybe_not}___dict_contains({key!r}, {dict_ref})" + self._set_guard_export_info(guard, [code]) + + if config.enable_cpp_guard_manager: + self.get_guard_manager(guard).add_dict_contains_guard( + not invert, key, get_verbose_code_parts(code, guard) + ) + else: + self._produce_guard_code(guard, [code]) + + def ID_MATCH(self, guard: Guard): + # ___check_obj_id is same as `id(x) == y` + if isinstance(guard.originating_source, TypeSource): + # optional optimization to produce cleaner/faster guard code + return self.TYPE_MATCH( + Guard(guard.originating_source.base, GuardBuilder.TYPE_MATCH) # type: ignore[arg-type] + ) + + ref = self.arg_ref(guard) + val = self.get(guard.name) + id_val = self.id_ref(val) + code = f"___check_obj_id({ref}, {id_val})" + self._set_guard_export_info(guard, [code]) + + if config.enable_cpp_guard_manager: + self.get_guard_manager(guard).add_id_match_guard( + id_val, get_verbose_code_parts(code, guard) + ) + else: + self._produce_guard_code(guard, [code]) + + # Keep track of ID_MATCH'd objects. 
This will be used to modify the + # cache size logic + if isinstance(guard.originating_source, LocalSource): + # TODO(anijain2305) - This is currently restricted to nn.Module objects + # because many other ID_MATCH'd objects fail - like DeviceMesh. + # Increase the scope of ID_MATCH'd objects. + if isinstance(val, torch.nn.Module): + local_name = guard.originating_source.local_name + weak_id = self.lookup_weakrefs(val) + if weak_id is not None: + self.id_matched_objs[local_name] = weak_id + + def NOT_NONE_MATCH(self, guard: Guard, value=None): + ref = self.arg_ref(guard) + val = self.get(guard.name) + assert isinstance(val, torch.Tensor) + code = f"{ref} is not None" + self._set_guard_export_info(guard, [code]) + + if config.enable_cpp_guard_manager: + self.get_guard_manager(guard).add_not_none_guard( + get_verbose_code_parts(code, guard) + ) + else: + self._produce_guard_code(guard, [code]) + + def NAME_MATCH(self, guard: Guard): + self._guard_on_attribute(guard, "__name__", GuardBuilder.EQUALS_MATCH) + + def DATA_PTR_MATCH(self, guard: Guard): + # Add a type check. C++ guard has the type check internally, so only + # enable it for Python guards. + if not config.enable_cpp_guard_manager: + self.TYPE_MATCH(guard) + + obj = self.get(guard.name) + code = f"{self.arg_ref(guard)}.data_ptr() == {obj.data_ptr()}" + self._set_guard_export_info(guard, [code]) + + if config.enable_cpp_guard_manager: + self.get_guard_manager(guard).add_data_ptr_guard( + obj, get_verbose_code_parts(code, guard) + ) + else: + self._produce_guard_code(guard, [code]) + + def DUAL_LEVEL(self, guard: Guard): + # Invalidate dual level if current dual level is different than the one + # in the fx graph + dual_level = torch.autograd.forward_ad._current_level + code = [f"torch.autograd.forward_ad._current_level == {dual_level}"] + self._set_guard_export_info(guard, [code]) + if config.enable_cpp_guard_manager: + # TODO(anijain2305) - Consider this moving this guard to C++ + forward_ad = torch.autograd.forward_ad + + def fn(x): + return forward_ad._current_level == dual_level + + assert self.guard_manager # to make mypy happy + self.guard_manager.root.add_lambda_guard( + fn, get_verbose_code_parts(code, guard) + ) + else: + self._produce_guard_code(guard, code) + + def FUNCTORCH_STACK_MATCH(self, guard: Guard): + # Invalidate functorch code if current level is different than + # the one when FX graph was generated + cis = torch._functorch.pyfunctorch.retrieve_all_functorch_interpreters() + states = [ci.get_state() for ci in cis] + code = [f"torch._functorch.pyfunctorch.compare_functorch_state({states})"] + self._set_guard_export_info(guard, code) + + if config.enable_cpp_guard_manager: + # TODO(anijain2305) - Consider this moving this guard to C++ + compare_fn = torch._functorch.pyfunctorch.compare_functorch_state + + def fn(x): + return compare_fn(states) + + assert self.guard_manager # to make mypy happy + self.guard_manager.root.add_lambda_guard( + fn, get_verbose_code_parts(code, guard) + ) + else: + self._produce_guard_code(guard, code) + + def TENSOR_SUBCLASS_METADATA_MATCH(self, guard: Guard): + value = self.get(guard.name) + original_metadata = deepcopy(self.get(guard.name).__tensor_flatten__()[1]) + if hasattr(value, "__metadata_guard__"): + verify_guard_fn_signature(value) + + def metadata_checker(x): + return value.__metadata_guard__( + original_metadata, x.__tensor_flatten__()[1] + ) + + else: + + def metadata_checker(x): + return x.__tensor_flatten__()[1] == original_metadata + + global_name = 
f"___check_metadata_{id(metadata_checker)}_c{CompileContext.current_compile_id()}" + if config.enable_cpp_guard_manager: + self.get_guard_manager(guard).add_lambda_guard( + metadata_checker, get_verbose_code_parts(global_name, guard) + ) + else: + global_scope = self.get("G") + global_scope[global_name] = metadata_checker + code = [f"{global_name}({self.get(guard.name)})"] + self._produce_guard_code(guard, code) + + def EQUALS_MATCH(self, guard: Guard): + ref = self.arg_ref(guard) + val = self.get(guard.name) + t = type(val) + if np: + np_types: Tuple[Type[Any], ...] = ( + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.float16, + np.float32, + np.float64, + ) + else: + np_types = () + + ok_mutable_types = (list, set) + + ok_types = tuple( + common_constant_types + | { + type, + tuple, + frozenset, + slice, + range, + torch.Size, + *np_types, + *ok_mutable_types, + } + ) + + if torch.distributed.is_available(): + from torch.distributed.device_mesh import DeviceMesh + from torch.distributed.tensor.placement_types import ( + Partial, + Replicate, + Shard, + ) + + ok_types = ok_types + ( + Shard, + Replicate, + Partial, + DeviceMesh, + ) + + if istype(val, dict): + assert all( + istype(x, ok_types) for x in itertools.chain(val.keys(), val.values()) + ) + else: + assert istype( + val, + ok_types, + ), f"Unexpected type {type(val)}, not in {ok_types}" + + # Special case for nan because float("nan") == float("nan") evaluates to False + if istype(val, float) and math.isnan(val): + self.TYPE_MATCH(guard) + code = [] + code.append(f"__math_isnan({ref})") + self._set_guard_export_info(guard, code) + + if config.enable_cpp_guard_manager: + self.get_guard_manager(guard).add_lambda_guard( + CLOSURE_VARS["__math_isnan"], get_verbose_code_parts(code, guard) + ) + else: + self._produce_guard_code(guard, code) + return + + # Python math library doesn't support complex nan, so we need to use numpy + if istype(val, complex) and np.isnan(val): + self.TYPE_MATCH(guard) + code = [] + code.append(f"__numpy_isnan({ref})") + self._set_guard_export_info(guard, code) + + if config.enable_cpp_guard_manager: + self.get_guard_manager(guard).add_lambda_guard( + CLOSURE_VARS["__numpy_isnan"], get_verbose_code_parts(code, guard) + ) + else: + self._produce_guard_code(guard, code) + return + + if config.enable_cpp_guard_manager: + # Construct a debug string to put into the c++ equals match guard. + code = [f"{ref} == {val!r}"] + if istype(val, ok_mutable_types): + # C++ guards perform a pointer equality check to speedup guards, but the assumption is that the object + # is mutable. For a few corner cases like sets and lists, we make a deepcopy to purposefully fail the + # pointer equality check. + val = deepcopy(val) + self.get_guard_manager(guard).add_equals_match_guard( + val, get_verbose_code_parts(code, guard) + ) + self._set_guard_export_info(guard, code) + return + + code = [] + + # If matching equality against list/tuple, we must also check that + # the internal types match. (TODO: what about nested lists?) + if istype(val, (list, tuple)): + # NB: SEQUENCE_LENGTH takes care of the outer __check_type_id test + self.SEQUENCE_LENGTH(guard) + + for idx, elem in enumerate(val): + code.append( + f"___check_type_id({ref}[{idx}], {self.id_ref(type(elem))})" + ) + else: + # Add type check to prevent equality check between tensor and non-tensor. 
+ self.TYPE_MATCH(guard) + + if istype(val, torch.Size): + val = tuple(val) + + # Code object can not be compared against their string representation + # I.e `eval(f"{compile('2+2','','exec')!r}")` raises SyntaxError + assert not istype(val, types.CodeType) + + # TODO: It feels like it would be better to just implement our own + # equality test in C that handles all of the necessary type checking + # and NaN tests + code.append(f"{ref} == {val!r}") + self._produce_guard_code(guard, code) + self._set_guard_export_info(guard, code) + + def CONSTANT_MATCH(self, guard: Guard): + val = self.get(guard.name) + if istype(val, (bool, type(None), types.CodeType)): + self.ID_MATCH(guard) + else: + self.EQUALS_MATCH(guard) + + def NN_MODULE(self, guard: Guard): + self.ID_MATCH(guard) + val = self.get(guard.name) + if hasattr(val, "training"): + assert istype(val.training, bool) + self._guard_on_attribute(guard, "training", GuardBuilder.CONSTANT_MATCH) + else: + exc.unimplemented(f"Guard setup for uninitialized class {type(val)}") + + def FUNCTION_MATCH(self, guard: Guard): + """things like torch.add and user defined functions""" + return self.ID_MATCH(guard) + + def CLOSURE_MATCH(self, guard: Guard): + """matches a closure by __code__ id.""" + val = self.get(guard.name) + # Strictly only want user-defined functions + if type(val) == types.FunctionType and hasattr(val, "__code__"): + self._guard_on_attribute(guard, "__code__", GuardBuilder.HASATTR) + self._guard_on_attribute(guard, "__code__", GuardBuilder.FUNCTION_MATCH) + else: + self.FUNCTION_MATCH(guard) + + def BUILTIN_MATCH(self, guard: Guard): + return self.FUNCTION_MATCH(guard) + + def PYMODULE_MATCH(self, guard: Guard): + return self.FUNCTION_MATCH(guard) + + def SEQUENCE_LENGTH(self, guard): + # This guard is used to check lenght of PySequence objects like list, + # tuple, collections.deque etc + ref = self.arg_ref(guard) + value = self.get(guard.name) + t = type(value) + + if not (config.enable_cpp_guard_manager and isinstance(value, dict)): + # C++ DICT_LENGTH checks for type + self.TYPE_MATCH(guard) + + code = [] + if len(value) == 0: + code.append(f"not {ref}") + else: + code.append(f"len({ref}) == {len(value)}") + + self._set_guard_export_info(guard, code) + if config.enable_cpp_guard_manager: + if isinstance(value, dict): + self.get_guard_manager(guard).add_dict_length_check_guard( + len(value), get_verbose_code_parts(code, guard) + ) + else: + self.get_guard_manager(guard).add_length_check_guard( + len(value), get_verbose_code_parts(code, guard) + ) + else: + self._produce_guard_code(guard, code) + + def TUPLE_ITERATOR_LEN(self, guard): + ref = self.arg_ref(guard) + value = self.get(guard.name) + t = type(value) + + if not config.enable_cpp_guard_manager: + # C++ guard already checks the type + self.TYPE_MATCH(guard) + + code = [] + code.append(f"___tuple_iterator_len({ref}) == {tuple_iterator_len(value)}") + self._set_guard_export_info(guard, code) + + if config.enable_cpp_guard_manager: + t = type(value) + obj_id = self.id_ref(t) + + self.get_guard_manager(guard).add_tuple_iterator_length_guard( + tuple_iterator_len(value), obj_id, get_verbose_code_parts(code, guard) + ) + else: + self._produce_guard_code(guard, code) + + # TODO(voz): Deduplicate w/ AOTAutograd dupe input guards + def DUPLICATE_INPUT(self, guard, source_b): + ref_a = self.arg_ref(guard) + ref_b = self.arg_ref(source_b.name()) + + if is_from_optimizer_source( + guard.originating_source + ) or is_from_optimizer_source(source_b): + return + + code = [f"{ref_b} is 
{ref_a}"] + self._set_guard_export_info(guard, code) + + if config.enable_cpp_guard_manager: + # Check that the guard has not been inserted already + key = (ref_a, ref_b) + if key in self._cached_duplicate_input_guards: + return + self._cached_duplicate_input_guards.add((ref_a, ref_b)) + self._cached_duplicate_input_guards.add((ref_b, ref_a)) + + install_object_aliasing_guard( + self.get_guard_manager(guard), + self.get_guard_manager_from_source(source_b), + get_verbose_code_parts(code, guard), + ) + else: + self._produce_guard_code(guard, code) + + def DICT_KEYS(self, guard): + # Guard on the keys and their order + ref = self.arg_ref(guard) + value = self.get(guard.name) + t = type(value) + + self.TYPE_MATCH(guard) + code = [] + any_key_is_id = any(key_is_id(k) for k in value.keys()) + const_keys_repr = dict_keys_repr( + key_to_id(value), + local=is_from_local_source(guard.originating_source), + ) + if any_key_is_id: + code.append(f"___key_to_id({ref}) == {const_keys_repr}") + else: + code.append(f"list({ref}.keys()) == {const_keys_repr}") + + self._set_guard_export_info(guard, code) + if config.enable_cpp_guard_manager: + if self.requires_key_order_guarding(guard.originating_source): + self.guard_on_dict_keys_and_order(value, guard) + else: + self.guard_on_dict_keys_and_ignore_order(value, guard) + else: + self._produce_guard_code(guard, code) + + def WEAKREF_ALIVE(self, guard): + code = [f"{self.arg_ref(guard)} is not None"] + + self._set_guard_export_info(guard, code) + if config.enable_cpp_guard_manager: + self.get_guard_manager(guard).add_not_none_guard( + get_verbose_code_parts(code, guard) + ) + else: + self._produce_guard_code(guard, code) + + def DICT_CONST_KEYS(self, guard): + """Constant keys match""" + ref = self.arg_ref(guard) + value = self.get(guard.name) + t = type(value) + + if not config.enable_cpp_guard_manager: + # DictGuardManager supports TYPE_MATCH internally + self.TYPE_MATCH(guard) + + code = [] + code.append(f"list({ref}.keys()) == {list(value.keys())!r}") + self._set_guard_export_info(guard, code) + + if config.enable_cpp_guard_manager: + if self.requires_key_order_guarding(guard.originating_source): + self.guard_on_dict_keys_and_order(value, guard) + else: + self.guard_on_dict_keys_and_ignore_order(value, guard) + else: + self._produce_guard_code(guard, code) + + def EMPTY_NN_MODULE_HOOKS_DICT(self, guard): + """Special guard to skip guards on empty hooks. 
This is controlled by skip_nnmodule_hook_guards""" + if config.skip_nnmodule_hook_guards: + # This is unsafe if you add/remove a hook on nn module variable + return + self.SEQUENCE_LENGTH(guard) + + def OBJECT_MUTATION(self, guard: Guard): + mutation_guard.watch(self.get(guard.name), self.check_fn_manager) + + def GRAD_MODE(self, guard: Guard): + pass # we always guard on this via GlobalStateGuard() + + def DETERMINISTIC_ALGORITHMS(self, guard: Guard): + pass # we always guard on this via GlobalStateGuard() + + def TORCH_FUNCTION_STATE(self, guard: Guard): + pass # we always guard on this via GlobalStateGuard() + + def FSDP_TRAINING_STATE(self, guard: Guard): + pass # we always guard on this via GlobalStateGuard() + + def DEFAULT_DEVICE(self, guard: Guard): + """Guard on CURRENT_DEVICE per torch.utils._device""" + assert guard.source is GuardSource.GLOBAL + import torch.utils._device as m + + code = [f"utils_device.CURRENT_DEVICE == {m.CURRENT_DEVICE!r}"] + self._set_guard_export_info(guard, code) + + if config.enable_cpp_guard_manager: + self.get_guard_manager(guard).add_default_device_guard( + get_verbose_code_parts(code, guard) + ) + else: + self._produce_guard_code(guard, code) + + def SHAPE_ENV(self, guard: Guard): + # Let's handle ShapeEnv guards. To do this, we will resolve + # shape variables to sources from tracked_fakes. This must happen after + # tensor checks. + assert guard.name == "" + output_graph = self.check_fn_manager.output_graph + # NB: self.output_graph can be None in the debug_nops tests + fs = output_graph.tracked_fakes + input_contexts = [a.symbolic_context for a in fs] + + def get_sources(t_id, dim): + # Looks up base sources mapped to a tensor id and uses them to create + # sources for the corresponding tensor dimension. + return [ + TensorPropertySource(source, TensorProperty.SIZE, dim) + for source in output_graph.tracked_fakes_id_to_source[t_id] + ] + + if output_graph.export_constraints: + names: Dict[str, Tuple[int, int]] = {} + source_pairs: List[Tuple[Source, Source]] = [] + derived_equalities: List[ # type: ignore[type-arg] + Tuple[Source, Union[Source, Symbol], Callable] + ] = [] + phantom_symbols: Dict[str, Symbol] = {} + for constraint in output_graph.export_constraints: + if constraint.t_id in output_graph.tracked_fakes_id_to_source: + torch.export.dynamic_shapes._process_equalities( + constraint, + get_sources, + output_graph.shape_env, + names, + source_pairs, + derived_equalities, + phantom_symbols, + ) + else: + log.warning("Untracked tensor used in export constraints") + equalities_inputs = EqualityConstraint( + source_pairs=source_pairs, + derived_equalities=derived_equalities, + phantom_symbols=list(phantom_symbols.values()), + warn_only=False, + ) + else: + equalities_inputs = None + guards = output_graph.shape_env.produce_guards( + [a.fake for a in fs], + [a.source for a in fs], + input_contexts=input_contexts, + equalities_inputs=equalities_inputs, + source_ref=self.source_ref, + # Export keeps static. + ignore_static=(not self.check_fn_manager.output_graph.export), + ) + # When exporting, we may work with the shape constraints some more in + # postprocessing, so don't freeze yet + if not self.check_fn_manager.output_graph.export: + output_graph.shape_env.freeze() + + for shape_guard in guards: + self._set_guard_export_info(guard, [shape_guard]) + + if config.enable_cpp_guard_manager: + # Install all the symbolic guards in one lambda guard. These are run + # at the very end of the RootGuardManager via epilogue guards. 
+ # TODO(anijain2305,williamwen42) - Consider moving this to C++. + code_parts = guards + self.add_python_lambda_leaf_guard_to_root( + code_parts, + get_verbose_code_parts(code_parts, guard), + closure_vars={**SYMPY_INTERP, **CLOSURE_VARS}, + ) + else: + for shape_guard in guards: + self._produce_guard_code(guard, [shape_guard], shape_env=True) + + def TENSOR_MATCH(self, guard: Guard, value=None): + # For FSDP modules, we can skip guards on nn module tensors because FSDP + # eager assumes that the params are unchanged once the model is wrapped. + if guard.is_fsdp_module(): + return + + # For tensors that are part of the Dynamo extracted Fx graph module, an + # ID_MATCH suffices. Once we turn on inline_inbuilt_nn_modules, these + # will be lifted as inputs and have a TENSOR_MATCH guard. + # For numpy tensors, always use TENSOR_MATCH because __from_numpy leads + # to a new tensor everytime and therefore id differs. + if ( + guard.is_specialized_nn_module() + and not isinstance(guard.originating_source, NumpyTensorSource) + ) or match_on_id_for_tensor(guard): + self.ID_MATCH(guard) + else: + if isinstance(value, TensorWeakRef): + value = value() + + value = value if value is not None else self.get(guard.name) + assert isinstance(value, torch.Tensor) + + tensor_name = self.arg_ref(guard) + # [Note - On Export Tensor Guards] + # + # In eager mode, tensor guards are evaluated through C++, in guards.cpp + # see [Note - On Eager Tensor Guards] for more info. + # + # In export mode, we instead maintain parallel logic between C++ and python + # here, with an exception of checking the dispatch key - with the idea that a dispatch key + # is an entirely runtime notion that would make no sense to keep in an exported graph. + # + # Now, this idea is okay, but to paraphrase @ezyang, this mental model is sufficient for now, although + # not entirely true. + # For example, suppose one of the input tensors had the negative dispatch key. + # You should end up with a graph that is specialized for tensors that have a negative dispatch key. + # If you allow a Tensor that does NOT have this bit set, you will accidentally run it "as if" it were negated. + # Now, negative key only shows up for complex numbers, and most likely, the exported to target doesn't + # support this feature at all, but the point stands that :some: tensor state only shows up on dispatch key. + # TODO(voz): Either populate a dispatch_key check into the guards, or error on users passing in an unsupported + # subset of keys during export. + # + # The list of tensor fields and calls we care about can be found in `terms` below. + # TODO(voz): We are missing storage offset in all our tensor guards? + code: List[str] = [] + if self.check_fn_manager.output_graph.export: + self.TYPE_MATCH(guard) + terms = [ + "dtype", + "device", + "requires_grad", + "ndimension()", + ] + + for term in terms: + real_value = self.get(tensor_name + "." + term) + if istype(real_value, (torch.device, torch.dtype)): + # copy pasted from EQUALS_MATCH + code.append(f"str({tensor_name}.{term}) == {str(real_value)!r}") + else: + code.append(f"{tensor_name}.{term} == {real_value}") + else: + self.tensor_check_examples.append(value) + self.tensor_check_names.append(tensor_name) + self.tensor_check_guards.append(guard) + + if config.enable_cpp_guard_manager: + guard_manager = self.get_guard_manager(guard) + # Keep track of all the tensor guard managers to insert + # NoAliasing check at the end. 
+ self.tensor_check_guard_managers.append(guard_manager) + + output_graph = self.check_fn_manager.output_graph + metadata = output_graph.input_source_to_sizes_strides[ + guard.originating_source + ] + size = convert_to_concrete_values(metadata["size"]) + stride = convert_to_concrete_values(metadata["stride"]) + + verbose_code_parts = get_verbose_code_parts( + get_tensor_guard_code_part(value, tensor_name, size, stride), + guard, + ) + guard_manager.add_tensor_match_guard( + value, + size, + stride, + tensor_name, + verbose_code_parts, + ) + + # A frame is valid for reuse with dynamic dimensions if the new + # (user-requested) dynamic dimensions are a subset of the old + # (already compiled) dynamic dimensions. + # + # It's a little non-obvious why you'd want this: in particular, + # if an already compiled frame matches all of the guards, why + # not just use it, why force a recompile? + # + # We force it for two reasons: + # + # - The user *required* us to compile with a new dynamic dimension, + # we should not ignore that and serve up the old, specialized + # frame. Listen to the user! + # + # - In fact, we are obligated to *raise an error* if we fail to + # make the requested dimension dynamic. If we don't + # recompile, we can't tell if that dimension can actually be + # made dynamic. + # + # If the new dynamic dims are a subset of the old, we already know + # we can make them dynamic (since we made them dynamic in old). + # This is slightly unsound, because maybe your input size is + # [s0, s0, s1] and so you can do it dynamic if you say dynamic + # dims {0, 1, 2} but you can't if you only do {0, 2} (because now + # the second s0 is specialized). But we're not entirely sure if + # this is a good idea anyway lol... (if you want to try removing + # this logic, be my guest! -- ezyang 2024) + # + assert guard.source is not None + static, reason = tensor_always_has_static_shape( + value, is_tensor=True, tensor_source=guard.originating_source + ) + + if not static: + if hasattr(value, "_dynamo_dynamic_indices"): + dynamic_indices = value._dynamo_dynamic_indices + code_part = f"(({tensor_name}._dynamo_dynamic_indices.issubset({dynamic_indices})) if hasattr({tensor_name}, '_dynamo_dynamic_indices') else True)" # noqa: B950 + code.append(code_part) + if config.enable_cpp_guard_manager: + self.get_guard_manager(guard).add_dynamic_indices_guard( + dynamic_indices, get_verbose_code_parts(code_part, guard) + ) + # In the case of us not having any dynamic dimension indices, we compiled the frame with no chance of + # raising for this specific tensor - and any inputs with more dynamic user directives specified must be recompiled. 
+ else: + code_part = ( + f"hasattr({tensor_name}, '_dynamo_dynamic_indices') == False" + ) + code.append(code_part) + if config.enable_cpp_guard_manager: + self.get_guard_manager(guard).add_no_hasattr_guard( + "_dynamo_dynamic_indices", + get_verbose_code_parts(code_part, guard), + ) + if len(code) > 0: + self._set_guard_export_info(guard, code) + if not config.enable_cpp_guard_manager: + self._produce_guard_code(guard, code) + + # A util that appends guarded code + def _produce_guard_code(self, guard, code_list, shape_env=False): + assert not config.enable_cpp_guard_manager + if shape_env: + self.shape_env_code.append(GuardCodeList(code_list, guard)) + else: + self.code.append(GuardCodeList(code_list, guard)) + + # A util that in the case of export, adds data onto guards + def _set_guard_export_info(self, guard, code_list, provided_guarded_object=None): + # WARNING: It is important that cur_frame/caller do NOT stay in + # the current frame, because they will keep things live longer + # than they should. See TestMisc.test_release_module_memory + cur_frame = currentframe() + assert cur_frame is not None + caller = cur_frame.f_back + del cur_frame + assert caller is not None + func_name = getframeinfo(caller)[2] + del caller + # We use func_name for export, so might as well get a nice defensive check out of it + assert func_name in dir( + self.__class__ + ), f"_produce_guard_code must be called from inside GuardedCode. Called from {func_name}" + + # Not all guards have names, some can be installed globally (see asserts on HAS_GRAD) + if provided_guarded_object is None: + name_valid = guard.name is not None and guard.name != "" + + guarded_object = self.get(guard.name) if name_valid else None + else: + guarded_object = provided_guarded_object + + guarded_object_type = ( + weakref.ref(type(guarded_object)) if guarded_object is not None else None + ) + obj_ref = None + # Not necessary to have weakref for Enum type, but there is a bug that + # makes hasattr(guarded_object.__class__, "__weakref__") return True. + if hasattr(guarded_object.__class__, "__weakref__") and not isinstance( + guarded_object, enum.Enum + ): + obj_ref = weakref.ref(guarded_object) + + guard.set_export_info( + func_name, + guarded_object_type, + code_list, + obj_ref, + ) + + +# Common Sub-Expression Elimination for Python expressions. +# +# There are 2 steps to this pass: +# 1. Count the frequency of each sub-expression (i.e. inner +# node in the AST tree) +# +# 2. Replace those that occur more than once by a fresh variable 'v'. +# 'v' will be defined in the 'preface' list (output argument to +# 'NodeTransformer') +# +# NB: the use of 'ast.unparse' while visiting the nodes makes this pass +# quadratic on the depth of the tree. +# +# NB: this pass creates a new variable for each AST node that is repeated +# more than 'USE_THRESHOLD'. e.g. if 'a.b.c.d' is used 10 times, 'a.b.c' +# and 'a.b' are also used 10 times. So, there will be a new variable for +# each of them. +class PyExprCSEPass: + # Maximum number of times a given expression can be used without being + # replaced by a fresh variable. + USE_THRESHOLD = 1 + + # Ad-Hoc: AST nodes this pass focuses on. 
+ ALLOWED_NODE_TYPES = (ast.Attribute, ast.Call, ast.Subscript) + + @dataclasses.dataclass + class Config: + expr_count: Dict[str, int] + expr_to_name: Dict[str, str] + + class ExprCounter(ast.NodeVisitor): + def __init__(self, config: PyExprCSEPass.Config) -> None: + self._config = config + + def visit(self, node: ast.AST) -> Any: + if isinstance(node, PyExprCSEPass.ALLOWED_NODE_TYPES): + self._config.expr_count[_ast_unparse(node)] += 1 + super().visit(node) + + class Replacer(ast.NodeTransformer): + def __init__( + self, + config: PyExprCSEPass.Config, + gen_name: Callable[[], str], + ) -> None: + super().__init__() + self._config = config + self._gen_name = gen_name + self.preface: List[str] = [] + + def visit(self, node: ast.AST) -> Any: + if isinstance(node, PyExprCSEPass.ALLOWED_NODE_TYPES): + expr = _ast_unparse(node) + + # Replacement only occurs if a given expression is used more + # than once. + if self._config.expr_count[expr] > PyExprCSEPass.USE_THRESHOLD: + if expr not in self._config.expr_to_name: + # Parent 'visit' is called so that we CSE the inner expressions first. + # + # The resulting expression is used as right-hand-side of the variable + # assignment. i.e. we are CSE-ing the children before the parents. + # + # Indexing still uses the old 'node', since that's what was counted + # by the 'NodeVisitor'. + node_ = super().visit(node) + expr_ = _ast_unparse(node_) + var_name = self._gen_name() + self.preface.append(f"{var_name} = {expr_}") + self._config.expr_to_name[expr] = var_name + else: + var_name = self._config.expr_to_name[expr] + return ast.Name(var_name, ast.Load()) + + return super().visit(node) + + def __init__(self) -> None: + self._counter = 0 + self._config = self.Config( + expr_count=collections.defaultdict(lambda: 0), expr_to_name={} + ) + + def _new_var(self, prefix: str = "_var") -> str: + name = f"{prefix}{self._counter}" + self._counter += 1 + return name + + def count(self, exprs: List[str]) -> None: + counter = self.ExprCounter(self._config) + for e in exprs: + try: + counter.visit(ast.parse(e)) + except SyntaxError as ex: + log.exception("Failed to visit expr at line %s.\n%s", ex.lineno, e) + raise + + def replace(self, expr: str) -> Tuple[List[str], str]: + replacer = self.Replacer(self._config, self._new_var) + new_node = replacer.visit(ast.parse(expr)) + return replacer.preface, _ast_unparse(new_node) + + +def must_add_nn_module_guards(guard): + # For config.guard_nn_modules=False, we can skip all the guards that + # originate from inside of nn module except for a few categories. + return ( + # Guard for defaults + isinstance(guard.originating_source, DefaultsSource) + # Guard using dict tags if the config flag is set + or ( + config.guard_nn_modules_using_dict_tags + and guard.create_fn is GuardBuilder.NN_MODULE + ) + ) + + +class DeletedGuardFn: + pass + + +# NB: Naively, you'd expect this to only be a function that produces +# the callable that constitutes the guard. However, there is some +# delicate handling for invalidating this check function when the +# locals/globals get invalidated, so there's some extra state +# we have to hold in this manager class. 
+class CheckFunctionManager: + def __init__( + self, + output_graph=None, + guard_fail_fn: Optional[Callable[[GuardFail], None]] = None, + ): + guards = output_graph.guards if output_graph else None + self._weakrefs: Dict[int, ReferenceType[object]] = {} + self.guard_manager = None + if config.enable_cpp_guard_manager: + self.guard_manager = GuardManager() + self.output_graph = output_graph + w_builder = None + + self.torch_function_mode_stack = ( + output_graph.torch_function_mode_stack if output_graph else None + ) + + def source_ref(source): + guard_source = source.guard_source() + if guard_source is GuardSource.CONSTANT: + # No need to track constants + return source.name() + assert w_builder + r_builder = w_builder() + assert r_builder is not None + return r_builder.arg_ref(source.name()) + + builder = GuardBuilder( + self.id_ref, + source_ref, + self.lookup_weakrefs, + output_graph.local_scope, + output_graph.global_scope, + self.guard_manager, + self, + ) + + # Break retain cycle. See test_release_scope_memory + def cleanup_builder(weak_b): + b = weak_b() + if b: + b.scope = None + + # Break retain cycle. See test_release_input_memory + w_builder = weakref.ref(builder, cleanup_builder) + + guard_on_nn_modules = config.guard_nn_modules and justknobs_check( + "pytorch/compiler:guard_nn_modules" + ) + + if not justknobs_check("pytorch/compiler:guard_nn_modules"): + log.warning("guard_nn_modules is turned off using justknobs killswitch") + + for guard in sorted(guards or [], key=Guard.sort_key): + if ( + not guard_on_nn_modules + and guard.is_specialized_nn_module() + # Default func args must be guarded on. + # TODO: we could make use of 'DefaultsSource' and offer a .guard.is_defaults() API + and "__defaults__" not in guard.name + and "__kwdefaults__" not in guard.name + and (config.skip_nnmodule_hook_guards or "hooks" not in guard.name) + ): + continue + + guard.create(builder) + + self.check_fn = self.compile_check_fn(builder, guards, guard_fail_fn) + + # Keep track of weak references of objects with ID_MATCH guard. This + # info is stored alongside optimized_code and check_fn and is used to + # limit the number of cache entries with same ID_MATCH'd object. + # TODO(anijain2305) - Currently this information is stored as an attr on + # the check_fn itself to avoid changing CacehEntry datastructure in + # eval_frame.c. In future, we should probably replace check_fn with a + # queryable data structure such that this information is already present + # in some form. + self.check_fn.id_matched_objs = builder.id_matched_objs + + if config.enable_cpp_guard_manager: + # TODO: don't do the string rep, do something more structured here + torch._logging.trace_structured( + "dynamo_cpp_guards_str", payload_fn=lambda: str(self.guard_manager) + ) + guards_log.debug("%s", self.guard_manager) + assert self.guard_manager # to make mypy happy + self.guard_manager.id_matched_objs = builder.id_matched_objs + self.check_fn = self.guard_manager + + # Check that the guard returns True. False means that we will always + # recompile. 
+ # TODO(anijain2305, ydwu4) - Skipping export because of following test + # python -s test/dynamo/test_export.py -k test_export_with_symbool_inputs + if not output_graph.export: + if not self.guard_manager.check(output_graph.local_scope): + reasons = get_guard_fail_reason_helper( + self.guard_manager, # type: ignore[arg-type] + output_graph.local_scope, + CompileContext.current_compile_id(), + ) + raise AssertionError(f"Guard check failed: {reasons}") + + # NB - We have to very careful of cleaning up here. Because of the + # invalidate function, we can create a weakref finalizer that keeps + # `self` alive for very long. Sometimes by mistake, we can run + # invalidate for a type/object (check id_ref method) that Python can + # leak by design, preventing us from calling the finalizer. In that + # case, the `self` will be alive even though the cache entry will be + # deleted (check invalidate method), which can cause a memory leak, + # e.g., not setting output_graph = None can keep hold of nn_modules. + self._weakrefs.clear() + self.output_graph = None + + def compile_check_fn(self, builder, guards_out, guard_fail_fn): + # see parallel handling of ".0" / "___implicit0" in _eval_frame.c + largs = builder.argnames + largs += ["**___kwargs_ignored"] + + guards_log.debug("GUARDS:") + + code_parts = [] + verbose_code_parts = [] + structured_guard_fns: list[Callable[[], dict[str, Any]]] = [] + + torch_function_mode_stack_check_fn = make_torch_function_mode_stack_guard( + self.torch_function_mode_stack + ) + + if config.enable_cpp_guard_manager: + from .variables.torch_function import IGNORED_MODES + + # Insert the global_state guard + assert self.guard_manager # to make mypy happy + self.guard_manager.root.add_global_state_guard(["___check_global_state()"]) + + self.guard_manager.root.add_torch_function_mode_stack_guard( + self.torch_function_mode_stack, + list(IGNORED_MODES), + ["___check_torch_function_mode_stack()"], + ) + # Clear references to torch_function modes held in the list + self.torch_function_mode_stack = None + else: + # Don't report this guard, it's always the same, useless! + global_guard = "___check_global_state()" + code_parts.append(global_guard) + verbose_code_parts.append(global_guard) + + tf_mode_stack_guard = "___check_torch_function_mode_stack()" + code_parts.append(tf_mode_stack_guard) + verbose_code_parts.append(tf_mode_stack_guard) + + def add_code_part(code_part, guard, log_only=False): + verbose_code_part = get_verbose_code_part(code_part, guard) + guards_log.debug("%s", verbose_code_part) + + structured_guard_fns.append( + lambda: { + "code": code_part, + "stack": structured.from_traceback(guard.stack.summary()) + if guard.stack + else None, + "user_stack": structured.from_traceback(guard.user_stack) + if guard.user_stack + else None, + } + ) + + if verbose_guards_log.isEnabledFor(logging.DEBUG): + maybe_stack = "" + maybe_user_stack = "" + if guard is not None: + if guard.stack: + maybe_stack = f"\nStack:\n{''.join(guard.stack.format())}" + if guard.user_stack: + maybe_user_stack = ( + f"\nUser stack:\n{''.join(guard.user_stack.format())}" + ) + verbose_guards_log.debug( + "Guard: %s%s%s", + code_part, + maybe_stack, + maybe_user_stack, + ) + + if not log_only: + code_parts.append(code_part) + verbose_code_parts.append(verbose_code_part) + + seen = set() + for gcl in builder.code: + for code in gcl.code_list: + if code not in seen: + # If Cpp guard manager is enabled, we don't need to add to + # code_parts. 
+ add_code_part(code, gcl.guard, config.enable_cpp_guard_manager) + seen.add(code) + + tensor_check_names = builder.tensor_check_names + check_tensors_fn = None + check_tensors_verbose_fn = None + if tensor_check_names and not config.enable_cpp_guard_manager: + tensor_check_guards = builder.tensor_check_guards + assert ( + not self.output_graph.export + ), "Illegal to set tensor_check_names in export." + tensor_check_examples = builder.tensor_check_examples + + dynamic_dims_sizes = [] + dynamic_dims_strides = [] + for t, g in zip(tensor_check_examples, tensor_check_guards): + metadata = self.output_graph.input_source_to_sizes_strides[ + g.originating_source + ] + dynamic_dims_sizes.append(convert_to_concrete_values(metadata["size"])) + dynamic_dims_strides.append( + convert_to_concrete_values(metadata["stride"]) + ) + + tensor_guards = TensorGuards( + *tensor_check_examples, + dynamic_dims_sizes=dynamic_dims_sizes, + dynamic_dims_strides=dynamic_dims_strides, + ) + check_tensors_fn = tensor_guards.check + check_tensors_verbose_fn = tensor_guards.check_verbose + tensor_check_args = ", ".join( + tensor_check_names + ["tensor_check_names=tensor_check_names"] + ) + # Do this manually, to un-stagger the guards in log message + code_parts.append(f"___check_tensors({tensor_check_args})") + verbose_code_parts.append(f"___check_tensors({tensor_check_args})") + + for i, name in enumerate(tensor_check_names): + # This is a copy of what guards.cpp checks against + # Keep this in sync with TensorCheck constructor + t = tensor_check_examples[i] + sizes = dynamic_dims_sizes[i] + strides = dynamic_dims_strides[i] + code_part = get_tensor_guard_code_part(t, name, sizes, strides) + add_code_part(code_part, tensor_check_guards[i], log_only=True) + + if len(tensor_check_names) > 1 and config.enable_cpp_guard_manager: + # Install tensor aliasing guard. TENSOR_MATCH guards are already + # installed for cpp guard manager. + install_no_tensor_aliasing_guard( + builder.tensor_check_guard_managers, + tensor_check_names, + ["check_no_aliasing(" + ", ".join(tensor_check_names) + ")"], + ) + + aotautograd_guards: List[GuardEnvExpr] = ( + self.output_graph.tracing_context.guards_context.aotautograd_guards + if self.output_graph + else [] + ) + + # TODO(anijain2305) - There is a duplicate logic in Dynamo to find + # aliased input tensors. So most probably we don't need this here. + # Revisit. + for guard in aotautograd_guards: + if isinstance(guard, DuplicateInputs): + source_a = guard.input_source_a + source_b = guard.input_source_b + code_part = f"{source_a.name()} is {source_b.name()}" + if config.enable_cpp_guard_manager: + install_object_aliasing_guard( + builder.get_guard_manager_from_source(source_a), + builder.get_guard_manager_from_source(source_b), + [code_part], + ) + add_code_part(code_part, None, config.enable_cpp_guard_manager) + else: + raise RuntimeError(f"Unknown GuardEnvExpr: {guard}") + + # TODO: the "guard" here is actually just the top level SHAPE_ENV + # which is useless. Get ShapeEnv to pass in more provenance. + for gcl in builder.shape_env_code: + for code in gcl.code_list: + # Shape env guards are already added for CPP guard manager in + # SHAPE_ENV implementation. 
+ add_code_part(code, gcl.guard, config.enable_cpp_guard_manager) + + # OK, all done generating guards + if structured_guard_fns: + torch._logging.trace_structured( + "dynamo_guards", payload_fn=lambda: [f() for f in structured_guard_fns] + ) + + global_state = convert_frame.initial_global_state + if global_state is None: + # we should only hit this case in NopTests() + global_state = convert_frame.GlobalStateGuard() + closure_vars = { + "___check_tensors": check_tensors_fn, + "___check_tensors_verbose": check_tensors_verbose_fn, + "___check_global_state": global_state.check, + "___check_torch_function_mode_stack": torch_function_mode_stack_check_fn, + "tensor_check_names": tensor_check_names, + **SYMPY_INTERP, + **CLOSURE_VARS, + } + + globals_for_guard_fn = {"G": builder.scope["G"]} + if config.enable_cpp_guard_manager: + # Guard manager construction is complete + assert self.guard_manager # to make mypy happy + # TODO (anijain2305) - When enable_cpp_guard_manager is ON by + # default, change the guard_fn name to be guard_manager everywhere + # to avoid confusion. + guard_fn = self.guard_manager + # Ensure we did not miss to insert a guard in cpp guard manager. + assert len(code_parts) == 0 + else: + unique_code_parts = list(unique(code_parts)) + make_guard_fn_args = ", ".join(closure_vars.keys()) + guard_body, pycode = build_guard_function( + unique_code_parts, make_guard_fn_args + ) + + if os.environ.get("TORCHDYNAMO_PRINT_GUARDS", None) == "1": + print("GUARDS\n", guard_body) + + out: Dict[str, Any] = {} + + # We don't put builder.scope as the globals in exec call because + # guard_fn.__globals__ becomes equal to builder.scope. This causes + # guard_fn to hold a referece to f_locals sitting in builder.scope["L"] + try: + exec(pycode, globals_for_guard_fn, out) + except SyntaxError as ex: + log.exception("Failed to exec guard at line %s.\n%s", ex.lineno, pycode) + raise + guard_fn = out["___make_guard_fn"](*closure_vars.values()) + + guard_fn.closure_vars = closure_vars + # TODO(whc) maybe '.code_parts' was only kept around for the guard callback? so we don't need both + guard_fn.args = largs + if config.enable_cpp_guard_manager: + guard_fn.populate_code_parts_for_debugging() + else: + guard_fn.code_parts = code_parts + guard_fn.verbose_code_parts = verbose_code_parts + # Grab only G, but preserve "G" because guards access it as "G" + guard_fn.global_scope = globals_for_guard_fn + guard_fn.guard_fail_fn = guard_fail_fn + # will be populated by a non-owning reference to CacheEntry/ExtraState + # when the CacheEntry is constructed + guard_fn.cache_entry = None + guard_fn.extra_state = None + guard_fn.no_tensor_aliasing_sources = tensor_check_names + return guard_fn + + def invalidate(self): + # Some tests reveal that CheckFunctionManager has no attribute + # check_fn, but this case should not be of any concern. + # This case doesn't seem easy to repro. 
+ if ( + hasattr(self, "check_fn") + and self.check_fn is not DeletedGuardFn + and (cache_entry := self.check_fn.cache_entry) is not None + and (extra_state := self.check_fn.extra_state) is not None + ): + assert isinstance(cache_entry, CacheEntry) + assert isinstance(extra_state, ExtraState) + extra_state.invalidate(cache_entry) + self.check_fn.cache_entry = None + self.check_fn.extra_state = None + self.check_fn = DeletedGuardFn + + def id_ref(self, obj): + """add a weakref, return the id""" + try: + if id(obj) not in self._weakrefs: + # We will clear the _weakrefs dict at the end of __init__ + # function, which will delete the callbacks as well. Therefore, + # we are using a finalizer which is kept alive. + self._weakrefs[id(obj)] = weakref.ref(obj) + weakref.finalize(obj, self.invalidate) + except TypeError: + pass # cannot weakref bool object + return id(obj) + + def lookup_weakrefs(self, obj): + """Lookup the _weakrefs created in id_ref function for ID_MATCH'd objects""" + if id(obj) in self._weakrefs: + return self._weakrefs[id(obj)] + return None + + +def build_guard_function(code_parts, closure_args) -> Tuple[str, str]: + from torch._inductor.utils import IndentedBuffer + + if HAS_UNPARSE_FUNCTIONS: + csepass = PyExprCSEPass() + csepass.count(code_parts) + + def replace(expr: str) -> Tuple[List[str], str]: + return csepass.replace(expr) + + else: + + def replace(expr: str) -> Tuple[List[str], str]: + return [], expr + + # Generate the inner body of the guard function. + # i.e. if-chain of the guard expressions. + guard_body = IndentedBuffer() + for expr in code_parts: + preface, expr = replace(expr) + guard_body.writelines(preface) + guard_body.writeline(f"if not ({expr}):") + with guard_body.indent(): + guard_body.writeline("return False") + + # Wrap the inner body into the actual guard function. + guard = IndentedBuffer() + guard.writeline("def guard(L):") + with guard.indent(): + guard.splice(guard_body) + guard.writeline("return True") + + # Wrap the whole guard function into another function + # with the closure variables. 
+ make_guard_fn = IndentedBuffer() + make_guard_fn.writeline(f"def ___make_guard_fn({closure_args}):") + with make_guard_fn.indent(): + make_guard_fn.splice(guard) + make_guard_fn.writeline("return guard") + + return guard_body.getvalue(), make_guard_fn.getvalue() + + +def is_recompiles_enabled(): + return torch._logging._internal.log_state.is_artifact_enabled("recompiles") + + +def is_recompiles_verbose_enabled(): + return torch._logging._internal.log_state.is_artifact_enabled("recompiles_verbose") + + +# this will only be used if cpp guards are disabled +def make_torch_function_mode_stack_guard(intial_stack): + types = [type(x) for x in intial_stack] + from .variables.torch_function import IGNORED_MODES + + def check_torch_function_mode_stack(): + cur_stack = get_torch_function_mode_stack() + if len(cur_stack) != len(types): + return False + + for ty, mode in zip(types, cur_stack): + if ty in IGNORED_MODES: + continue + if ty != type(mode): + return False + + return True + + return check_torch_function_mode_stack + + +def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope): + duplicate_tensors = [] + global_scope = dict(guard_manager.global_scope) + ids_to_source = collections.defaultdict(list) + for tensor_source in guard_manager.no_tensor_aliasing_sources: # type: ignore[attr-defined] + global_scope["__compile_source__"] = tensor_source + tensor_id = id(eval(tensor_source, global_scope, scope)) + ids_to_source[tensor_id].append(tensor_source) + + for key in ids_to_source: + if len(ids_to_source[key]) > 1: + duplicate_tensors.append(f"{ids_to_source[key]}") + + reason = ", ".join(duplicate_tensors) + return [f"Duplicate tensors found: {reason}"] + + +def get_guard_fail_reason_helper( + guard_fn: GuardFn, + f_locals: Dict[str, object], + compile_id: CompileId, +) -> str: + """ + Return the reason why `guard_fn` failed. + Updates `guard_failures` with the generated reason. + Only the first failed check of guard_fn is reported. + """ + scope = {"L": f_locals, "G": guard_fn.global_scope["G"]} + scope.update(guard_fn.closure_vars) + reasons: List[str] = [] + + no_tensor_aliasing_check_failed = False + + verbose_code_parts: List[str] = [] + if config.enable_cpp_guard_manager: + guard_manager = guard_fn + guard_debug_info = guard_manager.check_verbose(f_locals) # type: ignore[attr-defined] + # For test_export_with_map_cond, the check_verbose fail even without the + # C++ guard manager. We need to fix the issue to remove the comment. + # assert not guard_debug_info.result + if not guard_debug_info.result: + verbose_code_parts = guard_debug_info.verbose_code_parts + # verbose_code_parts is either the actual reason (e.g. in case of + # TENSOR_MATCH) or it could be a list of verbose_code_part that we + # passed to the leaf guard at construction time. If its a list, we + # walk through this list and find the guard that failed. This is + # very important for symbolic shape guards which are currently + # installed as a lambda guard and can encompass a long list of code_parts. + + if len(verbose_code_parts) == 1: + if "Duplicate tensor found" in verbose_code_parts[0]: + no_tensor_aliasing_check_failed = True + else: + reasons = verbose_code_parts + verbose_code_parts = [] + else: + verbose_code_parts = guard_fn.verbose_code_parts + # This is not needed for CPP guard because the verbose check is already + # run in C++. 
+ scope["___check_tensors"] = scope["___check_tensors_verbose"] + + if no_tensor_aliasing_check_failed: + reasons = recompilation_reason_for_no_tensor_aliasing_guard(guard_fn, scope) + else: + for part in verbose_code_parts: + global_scope = dict(guard_fn.global_scope) + global_scope["__compile_source__"] = part + with report_compile_source_on_error(): + try: + fail_reason = eval(part, global_scope, scope) + except Exception as e: + if is_recompiles_verbose_enabled(): + continue + else: + raise + # Only ___check_tensors knows how to return a fancy fail reason; + # for everything else we just report the code that failed + + if isinstance(fail_reason, bool) and not fail_reason: + fail_reason = part + if isinstance(fail_reason, str): + reasons.append(fail_reason) + if not is_recompiles_verbose_enabled(): + break + + reason_str = f"{compile_id}: " + "; ".join(reasons) + return reason_str + + +def get_guard_fail_reason( + guard_fn: GuardFn, + code: types.CodeType, + f_locals: Dict[str, object], + compile_id: CompileId, +) -> str: + reason_str = get_guard_fail_reason_helper(guard_fn, f_locals, compile_id) + guard_failures[orig_code_map[code]].append(reason_str) + + try: + if guard_fn.guard_fail_fn is not None: + guard_fn.guard_fail_fn( + GuardFail(reason_str or "unknown reason", orig_code_map[code]) + ) + except Exception as e: + log.exception( + "Failure in guard_fail_fn callback - raising here will cause a NULL Error on guard eval", + ) + + return reason_str + + +def get_and_maybe_log_recompilation_reason( + cache_entry, frame: types.FrameType +) -> List[str]: + """ + Return the list of guard failure reasons using cache_entry. + Logs the recompilation reason if `recompiles` logging is enabled. + Raises a RecompileError if `config.error_on_recompile` is enabled. 
+ """ + reasons = [] + while cache_entry is not None: + reason = get_guard_fail_reason( + cache_entry.check_fn, + cache_entry.code, + frame.f_locals, + cache_entry.compile_id, + ) + if reason: + reasons.append(reason) + cache_entry = cache_entry.next + + code = frame.f_code + + # at least one of "recompiles" or "recompiles_verbose" is enabled + do_recompiles_log = is_recompiles_enabled() or is_recompiles_verbose_enabled() + + if do_recompiles_log or config.error_on_recompile: + if is_recompiles_verbose_enabled(): + failures = "\n\n".join( + f"guard {i} failures:\n" + textwrap.indent(reason, "- ") + for i, reason in enumerate(reasons) + ) + else: + failures = textwrap.indent("\n".join(reasons), "- ") + guard_failure_details = ( + f"triggered by the following guard failure(s):\n{failures}" + ) + message = ( + f"Recompiling function {code.co_name} in {code.co_filename}:{code.co_firstlineno}\n" + f"{textwrap.indent(guard_failure_details, ' ')}" + ) + if do_recompiles_log: + if is_recompiles_verbose_enabled(): + recompiles_verbose_log.debug(message) + else: + recompiles_log.debug(message) + if config.error_on_recompile: + raise exc.RecompileError(message) + + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "recompile_reasons", + "encoding": "json", + }, + payload_fn=lambda: reasons, + ) + + return reasons + + +def guard_error_hook( + guard_fn: GuardFn, + code: types.CodeType, + f_locals: Dict[str, object], + index: int, + last: bool, +): + print( + f"ERROR RUNNING GUARDS {code.co_name} {code.co_filename}:{code.co_firstlineno}" + ) + print("lambda " + ", ".join(guard_fn.args) + ":") + print(" ", " and\n ".join(guard_fn.code_parts)) + + if config.enable_cpp_guard_manager: + print(guard_fn) + + local_scope = {"L": f_locals, **guard_fn.closure_vars} + for guard in guard_fn.code_parts: + try: + eval(guard, guard_fn.global_scope, local_scope) + except: # noqa: B001,E722 + print(f"Malformed guard:\n{guard}") + + +set_guard_error_hook(guard_error_hook) + + +def unique(seq): + seen = set() + for x in seq: + if x not in seen: + yield x + seen.add(x) + + +def make_dupe_guard(obj_source, dupe_source): + # Note - we may end up in a situation where we invoke something like + # def fn(x, y) + # with fn(x, x) + # Prior to the addition of tracking to all relevant objects, we would handle this just fine by + # eagerly re-entering VB and rewrapping inputs, correctly creating graphargs and placeholders. However, + # with tracking on inputs, duplicate inputs or aliased relationships may end up getting erased here - + # In the fn(x, x) example call above look like a graph with a single input. + # In order to ensure that we do not reuse fn(x, x) for fn(x, y), we create a duplicate input guard. + + # Note - we may not have a source, that is fine, it just means we had an object that is safe to have + # leave unsourced - like a local list created and discharged entirely within a local scope. + if dupe_source and dupe_source != obj_source: + ser_source_is_local = is_from_local_source(dupe_source) + source_is_local = is_from_local_source(obj_source) + if is_from_flatten_script_object_source( + dupe_source + ) or is_from_flatten_script_object_source(obj_source): + raise exc.UnsafeScriptObjectError( + f"{obj_source.name()} is alising {dupe_source.name()}. This is not supported." + f" Please do a clone for corresponding input." + ) + + # Note - both must be local, or global, or we will run afoul of a lack of merging in how we currently + # reconcile guards builder scopes in compile_check_fn. 
This technically means we miss a guard here, + # so maybe we should do this refactor before we land this... + # TODO(voz): Combine local and global guard builders. + if ser_source_is_local == source_is_local: + # Note - this is a little aggressive - these being duplicate input does not always matter. + # However, this should always be a sound guard to add here. + return functools.partial(GuardBuilder.DUPLICATE_INPUT, source_b=dupe_source) + return None + + +def install_guard(*guards, skip=0): + """ + Add dynamo guards to the current tracing context. + + Args: + guards: guard(s) to add + skip: number of stack frames to ignore for debug stack trace + """ + from torch._guards import TracingContext + + collect_debug_stack = guards_log.isEnabledFor( + logging.DEBUG + ) or verbose_guards_log.isEnabledFor(logging.DEBUG) + add = TracingContext.get().guards_context.dynamo_guards.add + for guard in guards: + assert isinstance(guard, Guard) + add(guard, collect_debug_stack=collect_debug_stack, skip=skip + 1) diff --git a/lib/python3.10/site-packages/torch/_dynamo/hooks.py b/lib/python3.10/site-packages/torch/_dynamo/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..9371dee9d8184c85eb6378a23a8d7a6ae1b47604 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/hooks.py @@ -0,0 +1,12 @@ +import dataclasses +from typing import Callable, Optional + +from torch._guards import GuardsSet + +from .types import GuardFail + + +@dataclasses.dataclass +class Hooks: + guard_export_fn: Optional[Callable[[GuardsSet], None]] = None + guard_fail_fn: Optional[Callable[[GuardFail], None]] = None diff --git a/lib/python3.10/site-packages/torch/_dynamo/logging.py b/lib/python3.10/site-packages/torch/_dynamo/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..55bf1b1d199a518a4d225d65f889bd655caf69c9 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/logging.py @@ -0,0 +1,59 @@ +# mypy: allow-untyped-defs +import itertools +import logging + +from torch.hub import _Faketqdm, tqdm + + +# Disable progress bar by default, not in dynamo config because otherwise get a circular import +disable_progress = True + + +# Return all loggers that torchdynamo/torchinductor is responsible for +def get_loggers(): + return [ + logging.getLogger("torch.fx.experimental.symbolic_shapes"), + logging.getLogger("torch._dynamo"), + logging.getLogger("torch._inductor"), + ] + + +# Creates a logging function that logs a message with a step # prepended. +# get_step_logger should be lazily called (i.e. at runtime, not at module-load time) +# so that step numbers are initialized properly. 
e.g.: + +# @functools.lru_cache(None) +# def _step_logger(): +# return get_step_logger(logging.getLogger(...)) + +# def fn(): +# _step_logger()(logging.INFO, "msg") + +_step_counter = itertools.count(1) + +# Update num_steps if more phases are added: Dynamo, AOT, Backend +# This is very inductor centric +# _inductor.utils.has_triton() gives a circular import error here + +if not disable_progress: + try: + import triton # noqa: F401 + + num_steps = 3 + except ImportError: + num_steps = 2 + pbar = tqdm(total=num_steps, desc="torch.compile()", delay=0) + + +def get_step_logger(logger): + if not disable_progress: + pbar.update(1) + if not isinstance(pbar, _Faketqdm): + pbar.set_postfix_str(f"{logger.name}") + + step = next(_step_counter) + + def log(level, msg, **kwargs): + logger.log(level, "Step %s: %s", step, msg, **kwargs) + + return log diff --git a/lib/python3.10/site-packages/torch/_dynamo/mutation_guard.py b/lib/python3.10/site-packages/torch/_dynamo/mutation_guard.py new file mode 100644 index 0000000000000000000000000000000000000000..d0e21a1ebd295447c39e4e0d00bad3cad66ff160 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/mutation_guard.py @@ -0,0 +1,150 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code="method-assign" + +import functools +import weakref + +import torch.nn +from torch.nn import Module + +from . import config +from .utils import ExactWeakKeyDictionary, is_lazy_module, nn_module_has_global_hooks + + +unpatched_nn_module_init = torch.nn.Module.__init__ + + +class MutationTracker: + db = ExactWeakKeyDictionary() + + def __init__(self): + self.mutation_count = 0 + self.watchers = [] + + def on_mutation(self, name): + self.mutation_count += 1 + tmp = self.watchers + self.watchers = [] + for ref in tmp: + guarded = ref() + if guarded is not None: + guarded.invalidate(ref) + + def track(self, guarded_code): + self.watchers.append(weakref.ref(guarded_code)) + + +def watch(obj, guarded_code): + """invalidate guarded_code when obj is mutated""" + ensure_patched(type(obj)) + + if obj not in MutationTracker.db: + MutationTracker.db[obj] = MutationTracker() + tracker = MutationTracker.db[obj] + tracker.track(guarded_code) + + +def ensure_patched(cls): + if getattr(cls, "___needs_mutation_patch", True): + cls.___needs_mutation_patch = False + original_setattr = cls.__setattr__ + + @functools.wraps(original_setattr) + def custom_setattr(self, key, value): + try: + MutationTracker.db[self].on_mutation(key) + except KeyError: + pass + return original_setattr(self, key, value) + + cls.__setattr__ = custom_setattr + + +class GenerationTracker: + generation = 0 + dynamic_classes = ExactWeakKeyDictionary() + generation_values = ExactWeakKeyDictionary() + + @classmethod + def tag(cls, obj): + cls.generation_values[obj] = cls.generation + + @staticmethod + def mark_class_dynamic(cls): + assert issubclass(cls, torch.nn.Module) + GenerationTracker.dynamic_classes[cls] = True + + @classmethod + def get_generation_value(cls, obj): + if obj not in cls.generation_values: + return -1 + return cls.generation_values[obj] + + @classmethod + def check(cls, obj): + return ( + obj in cls.generation_values + and cls.generation_values[obj] == cls.generation + ) + + @classmethod + def clear(cls): + cls.generation = 0 + cls.dynamic_classes = ExactWeakKeyDictionary() + cls.generation_values = ExactWeakKeyDictionary() + + +def is_dynamic_nn_module(obj, is_export): + """Check for nn.Modules() created dynamically or mutated""" + if isinstance(obj, torch.nn.Module) and "forward" in 
obj.__dict__: + # A monkey patched `.forward` indicates something wacky is going on + return True + if hasattr(obj, "torchdynamo_force_dynamic"): + return obj.torchdynamo_force_dynamic + if is_lazy_module(obj): + return False + # For export, we will have to fix + # 1) Input signature problem because params are lifted as inputs + # 2) nn module stack info changes + # 3) adjust failing tests + if ( + isinstance(obj, torch.nn.Module) + and config.inline_inbuilt_nn_modules + and not is_export + ): + return True + + if isinstance(obj, torch.nn.Module) and nn_module_has_global_hooks(): + return True + dyn = GenerationTracker.dynamic_classes.get(type(obj)) or GenerationTracker.check( + obj + ) + return dyn + + +def install_generation_tagging_init(): + """ + Monkey patch torch.nn.Module.__init__ and torch.nn.Module.__setstate__ + so we can detect nn.Module instances created dynamically inside forward methods. + """ + + if getattr(Module, "___needs_generation_tag_patch", True): + init = Module.__init__ + + def patched_init(self, *args, **kwargs): + init(self, *args, **kwargs) + GenerationTracker.tag(self) + + Module.__init__ = patched_init + + setstate = Module.__setstate__ + + def patched_setstate(self, state): + setstate(self, state) + GenerationTracker.tag(self) + + Module.__setstate__ = patched_setstate + + Module.___needs_generation_tag_patch = False # type: ignore[attr-defined] + + GenerationTracker.generation += 1 diff --git a/lib/python3.10/site-packages/torch/_dynamo/output_graph.py b/lib/python3.10/site-packages/torch/_dynamo/output_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..5684de69601565f783c82f3be4905711a4882ae0 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/output_graph.py @@ -0,0 +1,2190 @@ +# mypy: allow-untyped-defs +import collections +import contextlib +import copy +import dataclasses +import functools +import itertools +import json +import logging +import operator +import re +import sys +import traceback +import weakref +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union + +import sympy + +import torch._guards +import torch._logging +import torch.distributed as dist +import torch.nn +import torch.utils._pytree as pytree +from torch import fx +from torch._guards import GlobalContextCheckpointState, Source, TracingContext +from torch._utils_internal import signpost_event +from torch.fx._lazy_graph_module import _make_graph_module # type: ignore[attr-defined] +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.symbolic_shapes import free_symbols, is_symbolic, ShapeEnv +from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from . 
import config, exc, logging as torchdynamo_logging, variables +from .backends.registry import CompiledFn, CompilerFn +from .bytecode_transformation import ( + create_call_function, + create_instruction, + Instruction, + unique_id, +) +from .code_context import code_context +from .codegen import PyCodegen +from .current_scope_id import enter_new_scope +from .exc import ( + BackendCompilerFailed, + exceptions_allowed_to_be_fallback, + SkipFrame, + unimplemented, + unimplemented_with_warning, +) +from .guards import GuardBuilder, install_guard +from .mutation_guard import is_dynamic_nn_module +from .side_effects import AttributeMutationExisting, SideEffects +from .source import ( + AttrSource, + BackwardStateSource, + ConstantSource, + GetItemSource, + GlobalStateSource, + is_constant_source, + is_from_local_source, + LocalSource, + ParamBufferSource, + ShapeEnvSource, + SyntheticLocalSource, + TensorProperty, + TensorPropertySource, +) +from .utils import ( + _extract_tensor_dict, + checkpoint_params, + CleanupHook, + clone_inputs, + count_calls, + counters, + dynamo_timed, + get_instruction_source_311, + get_locals_to_steal, + get_static_address_type, + get_torch_function_mode_stack, + graph_break_reasons, + increment_op_count, + lazy_format_graph_code, + LazyString, + nn_module_proxy, + same, + set_example_value, +) +from .variables.base import VariableTracker +from .variables.builder import ( + BackwardStateGraphArg, + GraphArg, + TrackedFake, + VariableBuilder, + wrap_fx_proxy, +) +from .variables.lists import BaseListVariable +from .variables.misc import NullVariable +from .variables.nn_module import NNModuleVariable +from .variables.tensor import ( + NumpyNdarrayVariable, + SymNodeVariable, + TensorVariable, + UnspecializedPythonVariable, +) +from .variables.torch_function import TensorWithTFOverrideVariable + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslatorBase + + +log = logging.getLogger(__name__) +graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph") +graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code") +graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes") +trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call") + + +@dataclass(frozen=True) +class VariableTrackerCacheKey: + vt_id: int + # Two different source can point to the same object. However, Dynamo handles + # globals and local source differently when it comes to guards and possibly + # some other parts as well. So, cache also relies on the source. + source: Source + + +class VariableTrackerCache: + def __init__(self): + self.cache = {} + + def lookup(self, value, source): + key = VariableTrackerCacheKey(id(value), source) + if key not in self.cache: + return None + return self.cache[key] + + def add(self, value, source, vt): + key = VariableTrackerCacheKey(id(value), source) + self.cache[key] = vt + + def clone(self): + # Needed for copy and restore graph state + new_cache = VariableTrackerCache() + new_cache.cache.update(self.cache) + return new_cache + + def clear(self): + self.cache.clear() + + +@functools.lru_cache(None) +def _step_logger(): + return torchdynamo_logging.get_step_logger(log) + + +@dataclass +class GraphCompileReason: + """Stores why a given output graph was compiled; i.e. what caused the graph break.""" + + reason: str + user_stack: List[traceback.FrameSummary] + + # Indicates if this was a graph compile reason due to graph break. 
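# One way (standard torch._dynamo API; the toy function below is illustrative) to surface
# the graph-break bookkeeping that GraphCompileReason collects here: torch._dynamo.explain
# reports how many graphs were produced and the recorded break reasons.
import torch

def toy(x):
    x = torch.sin(x)
    print("this call is not traceable and forces a graph break")
    return torch.cos(x)

explanation = torch._dynamo.explain(toy)(torch.randn(4))
print(explanation)  # includes graph counts and the break reasons gathered during tracing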
+ graph_break: bool = True + + def __post_init__(self): + if self.graph_break: + graph_break_reasons.append(self) + + +def _get_gen_rand_values_fn(random_calls): + def _gen_rand_values(): + return [fn(*args, **kwargs) for fn, args, kwargs in random_calls] + + return _gen_rand_values + + +class FakeRootModule(torch.nn.Module): + """Trick the constructor of fx.GraphModule""" + + def __init__(self, nn_modules: Dict[str, torch.nn.Module]): + super().__init__() + for k, v in nn_modules.items(): + setattr(self, k, v) + + def __repr__(self): + return "FakeRootModule(...)" + + +class WrapperBackend: + def __init__(self, backend: CompilerFn): + self.backend: CompilerFn = backend + + def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + self.restore = checkpoint_params(gm) + self.gm = gm + copy_gm = copy.deepcopy(self.gm) + self.candidate = self.backend(copy_gm, example_inputs) + + if self.candidate is None or self.candidate is self.gm.forward: + return self.gm.forward + + if not config.verify_correctness: + return self.candidate + + # if verify_correctness=True + try: + correct = self.gm.forward(*clone_inputs(example_inputs)) + result = self.candidate(*clone_inputs(example_inputs)) + + # TODO: replace `same` function with the one in testing + if same(correct, result): + return self.candidate + + raise RuntimeError(f"incorrect results of backend {self}") + return self.gm.forward + + except Exception: + log.exception("error in verify_correctness") + raise + finally: + self.restore() + + +Scope = Dict[str, object] + + +class OutputGraph: + """ + Wrapper class to hold outputs of InstructionTranslator. Mainly the + generated fx.Graph. + + OutputGraph is 1:1 with a frame being processed. Each frame is associated + with some root InstructionTranslator. When user code calls a function, + we construct a InliningInstructionTranslator that continues to write into + the root InstructionTranslator's OutputGraph. + """ + + def __init__( + self, + code_options: Dict[str, Any], + compiler_fn: Optional[CompilerFn], + root_tx, + export: bool, + export_constraints, + frame_state, + local_scope: Scope, + global_scope: Scope, + f_code, + ): + super().__init__() + self.tracers = [SubgraphTracer(self, export_root=export)] + # Map from graph input's `Source` to its `VariableTracker` to + # de-duplicate graph inputs by source and reuse the tracker + self.input_source_to_var: Dict[Source, VariableTracker] = {} + self.export = export + self.export_constraints = export_constraints + self.frame_state = frame_state + # Map from graph input's `Source` to sizes / strides metadata + self.input_source_to_sizes_strides: Dict[Source, Dict[str, Any]] = {} + self.cleanup_hooks: List[Callable[[], Any]] = [] + # compile_id is an id number for the current torch.compile + self.compile_id: int = next(_compile_id_counter) + # Set of globals installed via install_global* APIs + self.installed_globals: Set[str] = set() + + # TODO: maybe should just pass the entire f_code in here? Not + # sure... + self.co_fields = { + "co_name": f_code.co_name, + "co_filename": f_code.co_filename, + "co_firstlineno": f_code.co_firstlineno, + } + + # tracked_fakes says where any tensor that was wrapped to fake came + # from. It is similar to GraphArg, in that all GraphArgs will get + # will get added to TrackedFakes, but TrackedFakes also contains + # GraphArgs that got pruned, and things like Tensor attributes which + # aren't explicit graph inputs. 
Used by shape guard + self.tracked_fakes: List[TrackedFake] = [] + + # List of symbols for which we have exact bindings in the arguments + # already + self.bound_symbols: Set[sympy.Symbol] = set() + + shape_env = ShapeEnv( + # Reference Cycle! + # Share a reference to the list of TrackedFake. + # + # ShapeEnv needs this in order to be able to reproduce the call + # to produce_guards at an arbitrary time point. That is because + # TrackedFake instances may have its metadata changed throughout + # the program execution. + tracked_fakes=self.tracked_fakes, + allow_scalar_outputs=config.capture_scalar_outputs, + allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops, + prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards, + allow_complex_guards_as_runtime_asserts=config.allow_complex_guards_as_runtime_asserts, + co_fields=self.co_fields, + ) + + # In export mode, we force the shape_env to strictly disallow any constraining + # of the user marked dynamic dims + import torch._functorch.config as _config + + with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False): + fake_mode = torch._subclasses.FakeTensorMode( + shape_env=shape_env, + # TODO (tmanlaibaatar) Remove this once we always lift params and buffers + allow_non_fake_inputs=True if self.export else False, + export=self.export, + ) + self.tracing_context: TracingContext = TracingContext(fake_mode) + self.init_ambient_guards() + + # Map each tensor id to a list of sources. This is necessary because + # tensor ids cannot be recovered from tracked fakes (in general). + # We use this map to interpret (i.e., check for violations of) constraints, + # specifically equality constraints, which have shared tensor ids in them. + # This map should also be generally useful, e.g., for (de)serialization. + self.tracked_fakes_id_to_source: Dict[ + int, List[Source] + ] = collections.defaultdict(list) + # Stores the full fqn of a param or buffer to the relevant source. + self.param_name_to_source: Optional[Dict[str, Source]] = {} + self.side_effects = SideEffects() + # Cached variable trackers. This makes symbolic analysis of LOAD_GLOBAL + # and LOAD_ATTR for same python objects free. + self.variable_tracker_cache = VariableTrackerCache() + self.unique_var_id = itertools.count() + self.code_options = dict(code_options) + self.output_instructions: List[Instruction] = [] + # used to track nodes that are added between calls of copy_graphstate + # and restore_graphstate + self.timestamp = 0 + + # A list of register_finalizer_fns to apply to the output graph module + self.register_finalizer_fns: List[Callable[[fx.GraphModule], None]] = [] + + # Not checkpointed + self.compiler_fn: Optional[CompilerFn] = compiler_fn + self.global_scope = global_scope + self.local_scope = local_scope + self.root_tx = root_tx + + # Given a source, what are the user stacks of all locations that + # accessed it? + # + # For efficiency, we only populate this: + # - During export, and + # - If the source could potentially lead to a spurious export input + # + # Feel free to populate this more frequently if other use-cases arise, + # but be aware that we have to generate full stacks for each + # recording! 
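# A standalone sketch (illustrative sizes, using the same APIs this constructor uses above)
# of the ShapeEnv + FakeTensorMode pairing: fake tensors carry only metadata, and the
# ShapeEnv is what later turns their sizes into shape guards.
import torch
import torch._subclasses
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
fake_mode = torch._subclasses.FakeTensorMode(shape_env=shape_env)
with fake_mode:
    t = torch.empty(8, 16)  # a FakeTensor: shape/dtype/device only, no real storage
print(type(t).__name__, tuple(t.shape))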
+ self.source_to_user_stacks: Dict[Source, List[traceback.StackSummary]] = {} + + self._current_tx: List[InstructionTranslatorBase] = [] + self.cleanups: List[CleanupHook] = [] + self.should_exit = False + self.unspec_variable_map: Dict[str, UnspecializedPythonVariable] = {} + + # Note this returns true iff TF Mode and TF Subclasses are enabled + self.torch_function_enabled = torch._C._is_torch_function_enabled() + # This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty + self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled() + # This records the initial torch function mode stack for guarding + self.torch_function_mode_stack = get_torch_function_mode_stack() + + # Tracks if the output graph has a user defined allowed function in the + # graph. This is used later to determine if we should fallback to eager + # for certain exceptions. THe idea is that if the user has applied + # allow_in_graph, they would like to see the error instead of falling + # back for backend errors. + self.has_user_defined_allowed_in_graph = False + + # Tracks a list of called ops that were not tagged with "pt2_compliant_tag". + # This information is useful for logging. + self.non_compliant_ops: Set[torch._ops.OpOverload] = set({}) + + # Tracks a list of called custom ops that were tagged with "pt2_compliant_tag". + # This information is useful for logging. + self.compliant_custom_ops: Set[torch._ops.OpOverload] = set({}) + + # We save the global torch state here to be restored in case of graph + # breaks. The relevant issue is seen here + # https://github.com/pytorch/pytorch/pull/100570#issuecomment-1543427086 + # where inlining of a function changes the global state (because of the + # presence of torch.no_grad) and there is a graph break. + self.save_global_state() + + # Tracks the original FQNs of the constant tensors from the original graph, + # i.e. buffers and parameters. + self.dynamo_flat_name_to_original_fqn: Dict[str, str] = {} + + # All calls to random() are replaced with a single call to __gen_rand_values + # functions that returns a tuple of random values for each original call. + # random_calls tracks calls to random() and random_values_var stores the name of + # the variable that stores __gen_rand_values results. + self.random_calls: List[ + Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]] + ] = [] + self.random_values_var = None + + # Bytecode to insert right before we call the graph + self.pregraph_bytecode: List[Instruction] = [] + + # Use to pass values to backward hooks when using compiled autograd + self.backward_state: Dict[str, VariableTracker] = {} + self.backward_state_proxy: Optional[torch.fx.Proxy] = None + self.backward_state_var: Optional[str] = None + + self.name_of_builtins_dict_key_in_fglobals: str = ( + self.install_builtins_dict_in_fglobals() + ) + + self.guard_on_key_order: Set[str] = set() + + def install_builtins_dict_in_fglobals(self): + # f_globals["__builtins__"] can be a dict or a module. This is an + # implemenation detail - + # https://docs.python.org/3/library/builtins.html. + + # This makes guarding on any builtin messy because the guard check_fn + # has to check if the __builtins__ is a module or dict, and then access + # by either using getattr or getitem respectively. + + # To solve this problem, we insert a new entry in f_globals which points + # to the builtins __dict__ and then we guard any builtin on this dict. 
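# A self-contained illustration (hypothetical helper, not Dynamo's code) of the
# dict-or-module ambiguity described above: f_globals["__builtins__"] may be either the
# builtins module or its __dict__, so guard code wants a single normalized dict to key on.
import builtins

def builtins_as_dict(f_globals):
    b = f_globals.get("__builtins__", builtins)
    return b if isinstance(b, dict) else b.__dict__

assert builtins_as_dict({"__builtins__": builtins})["len"] is len
assert builtins_as_dict({"__builtins__": builtins.__dict__})["len"] is len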
+ # To avoid any collision with the pre-existing keys, we use the + # install_global to give us a unique dict key. + + f_builtins = self.global_scope["__builtins__"] + if not isinstance(f_builtins, dict): + f_builtins = f_builtins.__dict__ + return self.install_global("__builtins_dict__", f_builtins) + + def add_backward_state_hook(self, hook: VariableTracker, prefix="hook"): + name = f"{prefix}{len(self.backward_state)}" + assert name not in self.backward_state + self.backward_state[name] = hook + return name, self.get_backward_state_proxy() + + def get_backward_state_proxy(self): + if self.backward_state_proxy is None: + if self.export: + unimplemented("backward_state does not support export") + self.backward_state_proxy = self.root_tracer.create_graph_input( + "dynamo_backward_state", BackwardState, source=BackwardStateSource() + ) + self.backward_state_proxy.node.meta["grapharg"] = BackwardStateGraphArg() + set_example_value(self.backward_state_proxy.node, BackwardState()) + self.backward_state_var = self.new_var() + return self.backward_state_proxy + + # This gets its own helper function so guards DEBUG logs are more informative + def init_ambient_guards(self): + # Register a SHAPE_ENV guard to make sure we setup shape guards + # that show up in ShapeEnv + self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV)) + + self.guards.add( + GlobalStateSource().make_guard(GuardBuilder.DETERMINISTIC_ALGORITHMS) + ) + + self.guards.add(GlobalStateSource().make_guard(GuardBuilder.GRAD_MODE)) + + self.guards.add(GlobalStateSource().make_guard(GuardBuilder.DEFAULT_DEVICE)) + + self.guards.add( + GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE) + ) + + ci = torch._C._functorch.peek_interpreter_stack() + if ci is not None: + self.guards.add( + GlobalStateSource().make_guard(GuardBuilder.FUNCTORCH_STACK_MATCH) + ) + + def synthetic_graph_input(self, fn, args): + """ + call fn(*args) before the graph runs and turn the result into a fake input. + """ + example_value = fn(*args) + varname = self.new_var() + cg = PyCodegen(self.root_tx) + cg.add_push_null( + lambda: cg.load_import_from( + fn.__module__, + fn.__name__, + ) + ) + cg.foreach(map(variables.ConstantVariable.create, args)) + cg.call_function(len(args), False) + cg.store(varname) + self.pregraph_bytecode.extend(cg.get_instructions()) + source = SyntheticLocalSource(varname) + result = VariableBuilder(self.root_tx, source)(example_value) + TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source( + source + ) + return result + + def add_cleanup_hook(self, fn: Callable[[], Any]): + self.cleanup_hooks.append(fn) + + def call_cleanup_hooks(self): + for hook in reversed(self.cleanup_hooks): + hook() + self.cleanup_hooks.clear() + + @property + def root_tracer(self): + return self.tracers[0] + + @property + def current_tracer(self): + return self.tracers[-1] + + def is_root_tracer(self): + # Helper to tell if we are inside the higher order operator tracing. + return len(self.tracers) == 1 + + @property + def graph(self): + return self.current_tracer.graph + + # TODO(rzou): can delete after we refactor speculate_subgraph to use nested GraphTracer. 
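# A minimal sketch (toy class, hypothetical names) of the tracer-stack discipline behind
# the root_tracer / current_tracer properties above and the subtracer() context manager
# that follows: a nested tracer is pushed for the duration of the `with` block and is
# always popped again, even if tracing the subgraph raises.
import contextlib

class ToyGraph:
    def __init__(self):
        self.tracers = ["root"]

    @property
    def current_tracer(self):
        return self.tracers[-1]

    @contextlib.contextmanager
    def subtracer(self, name):
        self.tracers.append(name)
        try:
            yield self.tracers[-1]
        finally:
            self.tracers.pop()

g = ToyGraph()
with g.subtracer("higher_order_body"):
    assert g.current_tracer == "higher_order_body"
assert g.current_tracer == "root"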
+ @graph.setter + def graph(self, value): + self.current_tracer.graph = value + + @property + def input_name_to_proxy(self): + return self.current_tracer.input_name_to_proxy + + @property + def real_value_cache(self): + return self.current_tracer.real_value_cache + + # If you are here, and you're looking for create_graph_input, + # to avoid ambiguity, please call one of the following: + # - self.current_tracer.create_graph_input + # - self.root_tracer.create_graph_input + # See NOTE [HigherOrderOperator tracing design] for more context. + + def create_proxy(self, *args, **kwargs): + return self.current_tracer.create_proxy(*args, **kwargs) + + def create_node(self, *args, **kwargs): + return self.current_tracer.create_node(*args, **kwargs) + + def remove_node(self, *args, **kwargs): + return self.current_tracer.remove_node(*args, **kwargs) + + @contextlib.contextmanager + def subtracer(self, source_target, prior_tracer): + new_scope_ctx = enter_new_scope() + try: + if prior_tracer: + # Lineage MUST stay preserved + assert prior_tracer.parent is self.current_tracer + new_scope_ctx.__enter__() + tracer = ( + prior_tracer + if prior_tracer + else SubgraphTracer( + self, parent=self.current_tracer, source_target=source_target + ) + ) + self.tracers.append(tracer) + yield tracer + finally: + new_scope_ctx.__exit__(None, None, None) + self.tracers.pop() + + @property + def output(self): + return self + + @property + def fake_mode(self): + return self.tracing_context.fake_mode + + @property + def shape_env(self): + return self.tracing_context.fake_mode.shape_env + + @property + def guards(self) -> torch._guards.GuardsSet: + return self.tracing_context.guards_context.dynamo_guards + + @property + def nn_modules(self) -> Dict[str, Any]: + return self.tracing_context.module_context.nn_modules + + def save_global_state(self, out=None): + """ + Saves to out if it is provided. Else saves to the tracing context's global_state. + """ + global_state = ( + out if out is not None else self.tracing_context.global_context.global_state + ) + + # TODO - Consider having a torch level API for torch_function_state. As + # of now, we create a ref cycle by passing the + # output.set_torch_function_state to + # output.tracing_context.global_context.global_state. In the interim, + # the problem can be solved by manually set + # output.tracing_context.global_context.global_state to None at cleanup. 
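# A small self-contained sketch (not the real TracingContext) of the
# "(restore_fn, saved_value)" convention that save_global_state() assembles just below:
# each entry records how to put one piece of global state back after a graph break.
import torch

def snapshot():
    return {"grad_enabled": (torch.set_grad_enabled, torch.is_grad_enabled())}

def restore(state):
    for restore_fn, value in state.values():
        restore_fn(value)

saved = snapshot()             # e.g. grad mode as it was before tracing
torch.set_grad_enabled(False)  # something flips the global state mid-trace
restore(saved)                 # the snapshot puts it back
assert torch.is_grad_enabled() == saved["grad_enabled"][1]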
+ global_state["torch_function_enabled"] = ( + self.set_torch_function_state, + self.torch_function_enabled, + ) + global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled()) + + global_state["autocast_enabled"] = ( + functools.partial(torch.set_autocast_enabled, "cuda"), + torch.is_autocast_enabled("cuda"), + ) + global_state["autocast_cpu_enabled"] = ( + functools.partial(torch.set_autocast_enabled, "cpu"), + torch.is_autocast_enabled("cpu"), + ) + global_state["autocast_gpu_dtype"] = ( + functools.partial(torch.set_autocast_dtype, "cuda"), + torch.get_autocast_dtype("cuda"), + ) + global_state["autocast_cpu_dtype"] = ( + functools.partial(torch.set_autocast_dtype, "cpu"), + torch.get_autocast_dtype("cpu"), + ) + global_state["autocast_cache_enabled"] = ( + torch.set_autocast_cache_enabled, + torch.is_autocast_cache_enabled(), + ) + + def push_tx(self, tx): + self._current_tx.append(tx) + + def pop_tx(self): + return self._current_tx.pop() + + @property + def current_tx(self): + return self.root_tx if not self._current_tx else self._current_tx[-1] + + def add_symbol_bindings(self, arg: GraphArg): + # Insert implicit size vars as necessary. With dynamic shapes, we + # maintain the invariant that every sizevar gets a direct SymInt input + # into the graph. This means downstream graph transforms can assume + # every size variable is explicitly bound and accessible, instead of + # having to pull it out implicitly from tensors. + + if self.export: + return + + assert arg.fake_tensor is not None + + def bind_symint(s, prop): + if not (is_symbolic(s) and isinstance(s.node.expr, sympy.Symbol)): + return + s0 = s.node.expr + if s0 in self.bound_symbols: + return + self.bound_symbols.add(s0) + log.debug("bind_symint %s %s", s, prop.name()) + # TODO: don't readd symint if we already have it in graph + # (this is harmless because we do remove the unused ones later) + proxy = self.root_tracer.create_graph_input( + str(s0), + torch.SymInt, + before=True, + source=prop, + ) + set_example_value(proxy.node, s) + proxy.node.meta["grapharg"] = GraphArg( + prop, + s, + pass_arg_as_tensor=False, + fake_tensor=None, + is_tensor=False, + ) + + def handle_tensor(t, src): + for i, s in enumerate(t.size()): + bind_symint(s, TensorPropertySource(src, TensorProperty.SIZE, i)) + if t.layout is torch.strided: + for i, s in enumerate(t.stride()): + bind_symint(s, TensorPropertySource(src, TensorProperty.STRIDE, i)) + bind_symint( + t.storage_offset(), + TensorPropertySource(src, TensorProperty.STORAGE_OFFSET), + ) + elif t.layout is torch.sparse_coo: + handle_tensor(t._indices(), src) + handle_tensor(t._values(), src) + elif t.layout in {torch.sparse_csr, torch.sparse_bsr}: + handle_tensor(t.crow_indices(), src) + handle_tensor(t.col_indices(), src) + elif t.layout in {torch.sparse_csc, torch.sparse_bsc}: + handle_tensor(t.ccol_indices(), src) + handle_tensor(t.row_indices(), src) + if is_traceable_wrapper_subclass(t): + attrs, ctx = t.__tensor_flatten__() + for attr in attrs: + inner_t = getattr(t, attr) + handle_tensor(inner_t, AttrSource(src, attr)) + + handle_tensor(arg.fake_tensor, arg.source) + + def count_calls(self): + return count_calls(self.graph) + + def is_empty_graph(self): + return len(list(self.graph.nodes)) == 0 + + def get_submodule(self, keys): + assert keys + obj: Union[torch.nn.Module, Dict[str, torch.nn.Module]] = self.nn_modules + for k in keys.split("."): + if isinstance(obj, dict): + obj = obj[k] + else: + obj = getattr(obj, k) + return obj + + def new_var(self, 
name="tmp"): + existing = set(self.code_options["co_varnames"]) + # In common case, this will be O(1) + while True: + var = f"{name}_{next(self.unique_var_id)}" + if var not in existing: + self.code_options["co_varnames"] += (var,) + return var + + def update_co_names(self, name): + """Ensure self.code_options.co_names contains name""" + if name not in self.code_options["co_names"]: + self.code_options["co_names"] += (name,) + + @staticmethod + def module_key_name(*names): + # create a new unique name + name = "_".join(map(str, names)) + # Strip the guard lookup L/G access + name = re.sub(r"^[GL]\['?(.*?)'?\]$", r"\1", name) + # e.g. replace abc.xyz[123].qkv with abc.xyz_123.qkv + name = re.sub(r"\[(\d+)\]", r"_\g<1>", name) + # e.g. replace abc.xyz_123.qkv with abc_xyz_123_qkv + name = re.sub(r"[^a-zA-Z0-9]", "_", name) + + if not name or not name[0].isalpha(): + name = "sub" + name + + return name + + def register_attr_or_module( + self, + target: Union[torch.nn.Module, torch.Tensor, Any], + *names, + **options, + ): + if is_dynamic_nn_module(target, self.root_tx.export): + # Instead of returning UnspecializedNNModuleVariable, call + # VariableBuilder so that it is tracked for mutation. + return VariableBuilder(self.current_tx, **options)(target) + + options = dict(options) + assert "source" in options + source = options["source"] + assert not isinstance(source, ParamBufferSource) + + if isinstance(target, torch.Tensor): + tracer = self.current_tracer + if not self.is_root_tracer(): + # For higher order ops, we don't want to insert the get_attr in + # innermost graph. Instead, we want to raise the params/buffers + # as inputs to the higher-order graph, and register them as + # get_attrs in the root tracer. + + # Note that Dynamo will still call lift_tracked_freevar_to_input + # when these inputs are encountered for the inner graph. The + # only difference is what happens at the root tracer for + # nn.Parameters vs free inputs. The free inputs are registered + # as placeholders in the root graph, whereas the nn.Parameters + # are registered as get_attr nodes in the root graph. + tracer = self.root_tracer + + def wrap_name(module_key): + assert self.param_name_to_source is not None + self.param_name_to_source[module_key] = source + + # Check if the attr has already been registered. This can happen + # when two different sources point to the same tensor. + if target in self.root_tx.output.side_effects: + return self.root_tx.output.side_effects[target] + + if get_static_address_type(target) == "guarded": + install_guard(source.make_guard(GuardBuilder.ID_MATCH)) + elif not is_constant_source(source): + install_guard(source.make_guard(GuardBuilder.TENSOR_MATCH)) + + vt = wrap_fx_proxy( + self.root_tx, + tracer.create_proxy("get_attr", module_key, (), {}), + example_value=target, + **options, + ) + + # Track the object so to avoid duplicate registration in case of + # different sources pointing to the same tensor object. + vt = self.root_tx.output.side_effects.track_object_existing(target, vt) + + assert "tensor_dict" not in vt.proxy.node.meta + vt.proxy.node.meta["tensor_dict"] = _extract_tensor_dict(target) + + return vt + + elif isinstance(target, torch.nn.Module): + assert isinstance(target, torch.nn.Module) + + if source: + install_guard(source.make_guard(GuardBuilder.NN_MODULE)) + + def wrap_name(module_key): + return NNModuleVariable(type(target), module_key, target, **options) + + else: + # This is Dynamo created graph module, e.g., graph module coming + # from higher order ops. 
NNModuleVariable tracker can't be + # sourceless, so let's return a unspecializedNNModule variable + # tracker. + def wrap_name(module_key): + return variables.UnspecializedNNModuleVariable(target, **options) + + elif isinstance(target, (torch.SymInt, torch.SymFloat)): + # HACKY CODE REGION BEGIN + # WE ARE PIGGYBACKING ON EXISTING INFRA TO REGISTER ATTRS + # This ultimately gets written to self.nn_modules, which is unfortunate + # Attrs that are tenors and symints and such need to be migrated to have their + # own storage + # alas, this is like this for now + + def wrap_name(module_key): + return SymNodeVariable.create( + self, + self.create_proxy("get_attr", module_key, (), {}), + sym_num=target, + **options, + ) + + # HACKY CODE REGION END + else: + + def wrap_name(module_key): + self.output.update_co_names(module_key) + self.global_scope[module_key] = target + return VariableBuilder(self, ConstantSource(source_name=module_key))( + target + ) + + for k, v in self.nn_modules.items(): + if v is target: + # it already exists + return wrap_name(k) + + name = OutputGraph.module_key_name(*names) + + base = name + for i in itertools.count(): + if name not in self.nn_modules: + self.nn_modules[name] = target + if isinstance(target, torch.nn.Module): + + def register_leaf_name(leaf_name): + assert self.param_name_to_source is not None + new_source = ParamBufferSource(source, leaf_name) + new_name = f"{name}.{leaf_name}" + self.param_name_to_source[new_name] = new_source + if isinstance(source, LocalSource): + self.dynamo_flat_name_to_original_fqn[ + OutputGraph.module_key_name(new_source.name()) + ] = leaf_name + + # annoying, but there are cases when we do not have parameters + # see test_nn_moduledict_contains + if hasattr(target, "_parameters"): + for leaf_name, _ in target.named_parameters(): + register_leaf_name(leaf_name) + if hasattr(target, "_buffers"): + for leaf_name, _ in target.named_buffers(): + register_leaf_name(leaf_name) + + return wrap_name(name) + name = f"{base}_{i}" + + raise AssertionError("unreachable") + + def handle_aliases_for_stolen_lists(self, tx): + # If list inputs are stolen, but still needed after the function call, create aliases to keep them alive + maybe_gm = self.local_scope.get("self") + stolen_list_names = get_locals_to_steal(maybe_gm) + if not stolen_list_names: + return [] + + alias_insts = [] + needs_alias: Dict[ + str, List[Union[VariableTracker, AttributeMutationExisting]] + ] = {} + + queue = [ + *tx.stack, + *tx.symbolic_locals.values(), + *self.side_effects.store_attr_mutations.keys(), + ] + + while queue: + x = queue.pop() + if isinstance(x, BaseListVariable): + assert isinstance(x.items, List) + queue += x.items + continue + + if not ( + isinstance(x, (VariableTracker, AttributeMutationExisting)) + and isinstance(x.source, GetItemSource) + and isinstance(x.source.base, LocalSource) + and x.source.base.local_name in stolen_list_names + ): + continue + + stolen_name = x.source.base.local_name + if stolen_name not in needs_alias: + needs_alias[stolen_name] = [] + needs_alias[stolen_name].append(x) + + visited = {} + for arg in self.graphargs: + if not ( + isinstance(arg._example, list) + and isinstance(arg.source, LocalSource) + and arg.source.local_name in needs_alias + ): + continue + + # arg is a list that will be cleared by the compiled function + list_name = arg.source.local_name + assert list_name in self.code_options["co_varnames"] + for x in needs_alias[list_name]: + list_idx = x.source.index + if list_idx not in visited: + alias_name = 
self.new_var( + f"{list_name}_ref" + ) # self.new_var already adds unique id suffix + + visited[list_idx] = alias_name + # bytecode of `alias_name = list_name[list_idx]` + alias_insts.extend( + [ + create_instruction("LOAD_FAST", argval=list_name), + create_instruction("LOAD_CONST", argval=list_idx), + create_instruction("BINARY_SUBSCR"), + create_instruction("STORE_FAST", argval=alias_name), + ] + ) + + # operate on alias, handled by suffix codegen + x.source = LocalSource(visited[list_idx]) + + return alias_insts + + def compile_subgraph( + self, tx, partial_convert=False, reason: Optional[GraphCompileReason] = None + ): + """ + Generate a subgraph to continue execution on user code. + Automatically restore live variables. + """ + assert reason is not None + + from .decorators import disable + + self.partial_convert = partial_convert + self.compile_subgraph_reason = reason + self.should_exit = True + + log.debug("COMPILING GRAPH due to %s", reason) + + if not all(block.can_restore() for block in tx.block_stack): + unimplemented("compile_subgraph with block_depth != 0") + + prefix_insts: List[Instruction] = [] + if sys.version_info >= (3, 11): + # prefix instructions (Python 3.11+) + for inst in tx.prefix_insts: + if inst.opname == "MAKE_CELL": + prefix_insts.append( + create_instruction("MAKE_CELL", argval=inst.argval) + ) + elif inst.opname == "COPY_FREE_VARS": + prefix_insts.append( + create_instruction( + "COPY_FREE_VARS", arg=len(tx.code_options["co_freevars"]) + ) + ) + else: + prefix_insts.append(copy.copy(inst)) + assert not ( + self.pregraph_bytecode and self.export + ), "export does not support pregraph_bytecode" + prefix_insts.extend(self.pregraph_bytecode) + prefix_insts.extend(self.handle_aliases_for_stolen_lists(tx)) + + def append_prefix_insts(): + self.add_output_instructions(prefix_insts) + prefix_insts.clear() + + for block in reversed(tx.block_stack): + block.exit(tx) + + self.cleanup_graph() + tx.prune_dead_locals() + stack_values = list(tx.stack) + + # realize any unrealized tensor VTs in case they + # need to be added to self.nn_modules as attributes + for value in stack_values: + value.realize() + + # Use nn.Module "proxies" in the constructed GraphModule so that + # the resulting GM does not hold additional strong references to the original modules. + # This prevents a strong ref cycle where Dynamo created code holds on to references + # to modules that also have Dynamo code cache invalidation checks. + # When cache invalidation runs, the generated GM will be invalidated, which also deletes + # the proxies. + nn_modules_proxies = { + name: nn_module_proxy(mod) for name, mod in self.nn_modules.items() + } + root = FakeRootModule(nn_modules_proxies) + # Add all the local vars to the "stack" so restore at the end + restore_vars = [] + val_to_names: Dict[VariableTracker, List[str]] = {} + if stack_values: + val_to_names[stack_values[-1]] = [] + # NB: Typically (i.e., for graph compile from RETURN_VALUE), + # symbolic_locals will be empty at this point, as prune_dead_locals + # will clear out all of symbolic_locals because RETURN_VALUE is the + # last instruction and no more locals are used. The fanciness here + # is only needed for partial graphs. + for k, v in tx.symbolic_locals.items(): + # Note! this explicitly uses .local_name for matching + # Failure to do so will cause spurious registrations in val_to_names. + # This will in turn result in spurious variables showing up in the graph. + # This was very tricky to debug. 
For an example, dump the graph at call_user_compiler + # while running test_subgraphs.py + if isinstance(v.source, LocalSource) and v.source.local_name == k: + continue # no need to restore initial state + # Do not load variable if it is NULL. + if sys.version_info >= (3, 12): + # Continuation function will load the NULL for v. + if type.__instancecheck__(NullVariable, v): + continue + else: + # A variable should never be NULL in < 3.12 + assert not type.__instancecheck__(NullVariable, v) + if v not in val_to_names: + val_to_names[v] = [] + val_to_names[v].append(k) + for v in val_to_names.keys(): + restore_vars.extend(val_to_names[v]) + stack_values.extend([v] * len(val_to_names[v])) + + # to handle random calls + if len(self.random_calls) > 0: + append_prefix_insts() + random_calls_instructions = [] + self.random_values_var = self.new_var("random_values") + rand_fn = disable(_get_gen_rand_values_fn(self.random_calls)) + rand_fn_name = self.install_global("__gen_rand_values", rand_fn) + codegen = PyCodegen(tx, root) + random_calls_instructions.extend( + codegen.load_function_name(rand_fn_name, True) + ) + random_calls_instructions.extend(create_call_function(0, False)) + random_calls_instructions.append( + codegen.create_store(tx.output.random_values_var), + ) + self.add_output_instructions(random_calls_instructions) + + if ( + stack_values + and all( + not isinstance( + v, + ( + UnspecializedPythonVariable, + NumpyNdarrayVariable, + TensorWithTFOverrideVariable, + ), + ) + and not (isinstance(v, SymNodeVariable) and v.python_type() is float) + for v in stack_values + ) + and all(isinstance(x, TensorVariable) for x in stack_values) + and len(set(stack_values)) == len(stack_values) + and self.side_effects.is_empty() + and not len(tx.debug_locals) != 0 + and not self.backward_state + ): + append_prefix_insts() + # optimization to generate better code in a common case + self.add_output_instructions( + self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root) + + [create_instruction("UNPACK_SEQUENCE", arg=len(stack_values))] + ) + # restore all the live local vars + self.add_output_instructions( + [PyCodegen(tx).create_store(var) for var in reversed(restore_vars)] + ) + else: + graph_output_var = self.new_var("graph_out") + pass1 = PyCodegen(tx, root, graph_output_var) + self.codegen_suffix(tx, stack_values, pass1) + + # one more time now that we have established tempvars + pass2 = PyCodegen( + tx, + root, + graph_output_var, + tempvars={val: None for val, count in pass1.uses.items() if count > 1}, + ) + self.codegen_suffix(tx, stack_values, pass2) + + stored_graph_output_var = False + output = [] + if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0: + output.extend( + self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root) + ) + + if len(pass2.graph_outputs) != 0: + output.append(pass2.create_store(graph_output_var)) + stored_graph_output_var = True + else: + output.append(create_instruction("POP_TOP")) + else: + # NB: Important to run compiler collective even when there is + # a graph break + self.run_compiler_collective(tx) + append_prefix_insts() + self.add_output_instructions(output + pass2.get_instructions()) + + # restore all the live local vars + self.add_output_instructions( + [PyCodegen(tx).create_store(var) for var in reversed(restore_vars)] + ) + + if stored_graph_output_var: + self.add_output_instructions( + [PyCodegen(tx).create_delete(graph_output_var)] + ) + + def codegen_suffix(self, tx, stack_values, cg): + if self.backward_state: + 
assert not self.export + for name, val in self.backward_state.items(): + cg(val) + cg.append_output(cg.create_load(self.backward_state_var)) + cg.store_attr(name) + self.side_effects.codegen_hooks(cg) + self.side_effects.codegen_save_tempvars(cg) + + # Return variables used for logging at the end + for debug_var, args in tx.debug_locals: + cg.add_push_null(lambda: cg(debug_var)) + for arg in args: + cg(arg) + cg.extend_output(create_call_function(len(args), False)) + cg.extend_output([create_instruction("POP_TOP")]) + + cg.restore_stack(stack_values, value_from_source=not tx.export) + self.side_effects.codegen_update_mutated(cg) + + def cleanup_graph(self): + """ + Remove "creation_timestamp" from node meta + + Remove this pattern from the graph: + torch._C._set_grad_enabled(False) + torch._C._set_grad_enabled(True) + """ + assert self.should_exit + nodes = list(self.graph.nodes) + for node in nodes: + node.meta.pop("creation_timestamp", None) + + grad_enabled = torch.is_grad_enabled() + for node1, node2 in zip(nodes, nodes[1:]): + if ( + node1.target is torch._C._set_grad_enabled + and tuple(node1.args) == (not grad_enabled,) + and not node1._erased + ): + grad_enabled = node1.args[0] + if ( + node2.target is torch._C._set_grad_enabled + and tuple(node2.args) == (not grad_enabled,) + and not node2._erased + ): + grad_enabled = node2.args[0] + self.graph.erase_node(node1) + self.graph.erase_node(node2) + + def get_graph_sizes_structured(self): + ret = {} + for node in self.graph.nodes: + example_value = node.meta.get("example_value", None) + if isinstance(example_value, torch._subclasses.FakeTensor): + size = example_value.size() + ret[node.name] = [s if isinstance(s, int) else repr(s) for s in size] + return ret + + def get_graph_sizes(self, name: str): + graph_sizes_str = "TRACED GRAPH TENSOR SIZES\n" + graph_sizes_str += f"===== {name} =====\n" + for node in self.graph.nodes: + example_value = node.meta.get("example_value", None) + if isinstance(example_value, torch._subclasses.FakeTensor): + size = example_value.size() + graph_sizes_str += f"{node.name}: {tuple(size)}\n" + concrete_size = [] + has_symint = False + for sz in size: + if isinstance(sz, int): + concrete_size.append(sz) + elif isinstance(sz, torch.SymInt): + has_symint = True + concrete_size.append(sz.node.hint) + else: + break + else: + if has_symint: + graph_sizes_str += ( + f"{node.name} (concrete): {tuple(concrete_size)}\n" + ) + return graph_sizes_str + + @contextlib.contextmanager + def restore_global_state(self): + """ + Momentarily restores the global state to what it was prior to tracing the current output + """ + prior_global_state = self.tracing_context.global_context.copy_graphstate() + current_global_state: Dict[str, Tuple[Any, bool]] = {} + self.save_global_state(out=current_global_state) + try: + # Set to state prior to tracing the graph + self.tracing_context.global_context.restore_graphstate(prior_global_state) + yield + finally: + # Reset to state at the current time (e.g. 
before calling the user compiler) + self.tracing_context.global_context.restore_graphstate( + GlobalContextCheckpointState(current_global_state) + ) + + def run_compiler_collective(self, tx): + if (ds := tx.distributed_state) is not None and ds.all_states is None: + compile_pg = ds.compile_pg + log.info("compiler_collective %s", ds.local_state) + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "compiler_collective", + "encoding": "json", + }, + payload_fn=lambda: json.dumps( + dataclasses.asdict(ds.local_state), + ), + ) + with torch.cuda.device(compile_pg.rank() % torch.cuda.device_count()): + all_states = [None] * compile_pg.size() + dist.all_gather_object(all_states, ds.local_state, group=compile_pg) + ds.all_states = all_states + # Clear speculation log, because are tracing may diverge due to + # this information from the compiler collective + tx.speculation_log.clear() + raise exc.CompileCollectiveRestartAnalysis + + def compile_and_call_fx_graph(self, tx, rv, root): + """ + Generate code from self.graph and return the Instruction()s to + call that generated code. + """ + with torch._guards.TracingContext.clear_frame(): + from .decorators import disable + + assert self.should_exit + + self.run_compiler_collective(tx) + + name = unique_id("__compiled_fn") + + assert isinstance(rv, list) + assert isinstance(root, FakeRootModule) + output_node = self.create_node( + "output", + "output", + (self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),), + {}, + ) + tx.output.current_tracer._maybe_preserve_original_meta(tx, output_node) + if not config.do_not_emit_runtime_asserts: + insert_deferred_runtime_asserts( + fx.GraphModule(root, self.graph), + self.shape_env, + name, + ) + # NB: deferred runtime asserts can keep graphargs live, so make sure + # those are inserted before pruning + self.remove_unused_graphargs() + ncalls = count_calls(self.graph) + counters["stats"]["calls_captured"] += ncalls + + # free a bit of memory + self.real_value_cache.clear() + + gm = _make_graph_module(root, self.graph) + for register_finalizer in self.register_finalizer_fns: + register_finalizer(gm) + + gm.compile_subgraph_reason = self.compile_subgraph_reason + gm.meta[ + "dynamo_flat_name_to_original_fqn" + ] = self.dynamo_flat_name_to_original_fqn.copy() + + graph_code_log.debug( + "%s", + lazy_format_graph_code( + name, gm, include_stride=True, include_device=True, colored=True + ), + ) + torch._logging.trace_structured( + "dynamo_output_graph", + lambda: {"sizes": self.get_graph_sizes_structured()}, + payload_fn=lambda: gm.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) + self.call_cleanup_hooks() + old_fake_mode = self.tracing_context.fake_mode + if not self.export: + import torch._functorch.config as _config + + with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False): + # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting + backend_fake_mode = torch._subclasses.FakeTensorMode( + shape_env=old_fake_mode.shape_env, + ) + # TODO(voz): Ostensibily, this should be scoped and + # restore back to old_fake_mode, but doing so currently violates + # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode + self.tracing_context.fake_mode = backend_fake_mode + + with self.restore_global_state(): + compiled_fn = self.call_user_compiler(gm) + + from torch.fx._lazy_graph_module import _LazyGraphModule + + if isinstance(compiled_fn, _LazyGraphModule) or ( + 
isinstance(getattr(compiled_fn, "__self__", None), _LazyGraphModule) + and compiled_fn.__name__ == "_lazy_forward" # type: ignore[attr-defined] + ): + # Since dynamo will run the forward method for the GraphModule shortly + # anyways, it does not hurt to do the real recompilation here if + # this is a _LazyGraphModule. This makes it easier for dynamo to + # optimize a _LazyGraphModule. + + lazy_gm = ( + compiled_fn + if isinstance(compiled_fn, _LazyGraphModule) + else compiled_fn.__self__ # type: ignore[attr-defined] + ) + + _LazyGraphModule.force_recompile(lazy_gm) + + if not isinstance(compiled_fn, _LazyGraphModule): + # replace compiled_fn with the real forward method + compiled_fn = lazy_gm.forward + + compiled_fn = disable(compiled_fn) + + counters["stats"]["unique_graphs"] += 1 + # This is safe because we pre-process name to be unique + self.install_global_unsafe(name, compiled_fn) + + cg = PyCodegen(tx) + cg.make_call_generated_code(name) + return cg.get_instructions() + + @property + def placeholders(self) -> List[fx.Node]: + return self.graph.find_nodes(op="placeholder") + + @property + def graphargs(self) -> List[GraphArg]: + return [node.meta["grapharg"] for node in self.placeholders] + + def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn: + with dynamo_timed( + "OutputGraph.call_user_compiler", phase_name="backend_compile" + ): + return self._call_user_compiler(gm) + + def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn: + assert self.compiler_fn is not None + tot = 0 + placeholders = [] + for node in gm.graph.nodes: + if node.op in ("call_function", "call_method", "call_module"): + tot += 1 + if node.op == "placeholder": + placeholders.append(node) + increment_op_count(tot) + for pl in placeholders: + arg = pl.meta["grapharg"] + # TODO: Why isn't this stored in meta :think: + pl._dynamo_source = arg.source + + gm._param_name_to_source = self.param_name_to_source # type: ignore[assignment] + gm._source_to_user_stacks = self.source_to_user_stacks # type: ignore[assignment] + + try: + name = ( + self.compiler_fn.__name__ + if hasattr(self.compiler_fn, "__name__") + else "" + ) + _step_logger()(logging.INFO, f"calling compiler function {name}") + compiler_fn = self.compiler_fn + if config.verify_correctness: + compiler_fn = WrapperBackend(compiler_fn) + compiled_fn = compiler_fn(gm, self.example_inputs()) + _step_logger()(logging.INFO, f"done compiler function {name}") + assert callable(compiled_fn), "compiler_fn did not return callable" + except exceptions_allowed_to_be_fallback as e: + if self.has_user_defined_allowed_in_graph: + raise BackendCompilerFailed(self.compiler_fn, e).with_traceback( + e.__traceback__ + ) from None + msg = ( + "Backend compiler failed with a fake tensor exception at \n" + f"{self.root_tx.format_frame_summary()}" + "Adding a graph break." + ) + unimplemented_with_warning(e, self.root_tx.f_code, msg) + except SkipFrame as e: + # The backend compiler has requested that we skip the frame, instead of + # aborting execution. 
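# A minimal custom backend of the shape _call_user_compiler() above hands graphs to:
# a callable taking an fx.GraphModule plus example inputs and returning a callable.
# (Standard torch.compile backend convention; the names here are illustrative only.)
import torch

def my_backend(gm: torch.fx.GraphModule, example_inputs):
    print(f"compiling a graph with {len(list(gm.graph.nodes))} nodes")
    return gm.forward  # just run the traced graph eagerly

@torch.compile(backend=my_backend)
def f(x):
    return torch.sin(x) + 1.0

f(torch.randn(4))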
+ raise e + except Exception as e: + raise BackendCompilerFailed(self.compiler_fn, e) from e + + signpost_event( + "dynamo", + "OutputGraph.call_user_compiler", + { + **self.co_fields, + "op_count": tot, + "node_count": len(gm.graph.nodes), + "input_count": len(placeholders), + }, + ) + + return compiled_fn + + def example_inputs(self) -> List[torch.Tensor]: + result = [] + for arg in self.graphargs: + result.append(arg.example) + return result + + def remove_unused_graphargs(self) -> None: + # NB: It's always OK to drop GraphArg for symbols that ended up being + # specialized. You don't even have to make a guard for it, because + # ShapeEnv produce_guards operates on tracked_fakes, which never gets + # pruned. That being said, you'll get marginally better generated + # guard code if you promote the guard into a Dynamo guard (since that + # allows for the guard to be done using C++ guards.) If we get + # ShapeEnv guards to go into C++ guards, this will stop being a thing + # though! + + assert self.should_exit + + # Miniature DCE pass, but only for obviously trivial operations + def is_static_true(b_node: fx.node.Argument): + if b_node is True: + return True + if not isinstance(b_node, fx.Node): + return False + b = b_node.meta.get("example_value") + if b is None: + return False + if b is True: + return True + if ( + isinstance(b, torch.SymBool) + and (r := b.node.maybe_as_bool()) is not None + ): + return r + # TODO: We can also technically remove all cases when the input + # doesn't have unbacked inputs, since it's all in the ShapeEnv + return False + + def is_symnode_arg(a: fx.node.Argument): + from torch.fx.experimental.sym_node import SymTypes + + if isinstance(a, (int, float, bool)): + return True + if isinstance(a, fx.Node): + return isinstance(a.meta.get("example_value"), SymTypes) + return False + + # NB: We assume that you cannot do mutations on int/float/bool, + # because they are immutable types, and therefore is always safe to + # DCE. + def is_symnode_compute_node(node): + from torch.fx.experimental.sym_node import SymTypes + + if node.op != "call_function": + return False + # TODO: I don't think it's possible to have a bare int/float here? 
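# A tiny illustration (plain torch.fx, hypothetical helper) of the reverse-order
# dead-node sweep remove_unused_graphargs() performs just below: walking the graph
# from the end and erasing user-less nodes lets chains of dead ops disappear in one pass.
import torch
from torch import fx

def trivial_dce(graph: fx.Graph) -> None:
    for node in reversed(list(graph.nodes)):
        if node.op in ("placeholder", "output"):
            continue
        if len(node.users) == 0:
            graph.erase_node(node)

def f(x):
    unused = x + 1  # traced, but its result is never used
    return x * 2

gm = fx.symbolic_trace(f)
n_before = len(list(gm.graph.nodes))
trivial_dce(gm.graph)
assert len(list(gm.graph.nodes)) < n_before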
+ if not isinstance(node.meta.get("example_value"), SymTypes): + return False + # TODO: This will bail here if you ever end up with a more complicated + # computation function, like sum(list_of_ints), even though it + # should be DCE'able + if not all(is_symnode_arg(a) for a in node.args): + return False + if not all(is_symnode_arg(a) for a in node.kwargs.values()): + return False + return True + + from torch.fx.experimental.symbolic_shapes import is_accessor_node + + for node in reversed(list(self.graph.nodes)): + if len(list(node.users)) == 0: + if ( + node.op == "get_attr" + or (node.op == "call_function" and node.target is operator.getitem) + or ( + node.op == "call_function" + and node.target is torch._check + and is_static_true(node.args[0]) + ) + or is_symnode_compute_node(node) + or is_accessor_node(node) + ): + self.remove_node(node) + + def placeholder_binds_symbol(node): + arg = node.meta["grapharg"] + example = arg.example + if isinstance(example, torch.SymInt) and isinstance( + example.node.expr, sympy.Symbol + ): + return example.node.expr + return None + + def remove_unused(node): + log.debug("REMOVE UNUSED GRAPHARG %s", node.meta["grapharg"].source.name()) + # I'm not really sure why you need to delete these from the + # node since the node is going to get removed + del node.meta["grapharg"] + self.remove_node(node) + self.real_value_cache.pop(node, None) + + used_symbols: Set[sympy.Symbol] = set() + + def update_used_symbols(used_symbols, fake: Union[torch.SymInt, torch.Tensor]): + used_symbols |= free_symbols(fake) + + recheck_placeholders = [] + for node in self.placeholders: + binds_symbol = placeholder_binds_symbol(node) is not None + # Don't delete symbol bindings yet + if binds_symbol: + if not node.users: + recheck_placeholders.append(node) + else: + if not node.users and not isinstance( + node.meta["grapharg"], BackwardStateGraphArg + ): + remove_unused(node) + else: + # Register the free symbols as uses + arg = node.meta["grapharg"] + if isinstance(arg, BackwardStateGraphArg): + continue + if isinstance(node.meta["grapharg"].example, torch.ScriptObject): + real_script_obj = node.meta["grapharg"].example + fake_script_obj = node.meta["grapharg"].example_strong_ref + if not torch._library.fake_class_registry.tracing_with_real( + real_script_obj + ): + flat_dict = dict(real_script_obj.__obj_flatten__()) # type: ignore[attr-defined] + for attr in flat_dict.keys(): + fake_attr_val = getattr( + fake_script_obj.wrapped_obj, attr + ) + pytree.tree_map_only( + (torch.SymInt, torch.Tensor), + lambda t: update_used_symbols(used_symbols, t), + fake_attr_val, + ) + continue + fake = ( + arg.fake_tensor if arg.fake_tensor is not None else arg.example + ) + update_used_symbols(used_symbols, fake) + + # After removing unused graphargs, prune unused binds_symbol + for node in recheck_placeholders: + symbol = placeholder_binds_symbol(node) + if symbol is not None: + if symbol not in used_symbols: + remove_unused(node) + else: + # Make sure we delete later occurrences of the same symbol + used_symbols.remove(symbol) + + def add_output_instructions(self, prefix: List[Instruction]) -> None: + """ + We call this on the creation of a new compiled subgraph that is inserted + before user code. + """ + self.output_instructions.extend(prefix) + self.should_exit = True + + def install_global_unsafe(self, name, value) -> None: + """ + WARNING: prefer the safer `install_global_by_id/install_global`. 
torch.compile instances should be independent of each other; + one footgun is to have one instance depend on the existence of + a global installed by another instance. This can happen if we mangle + a global the same way across both instances. + """ + assert name not in self.installed_globals + self.installed_globals.add(name) + self.cleanups.append(CleanupHook.create(self.global_scope, name, value)) + + def install_global_by_id(self, prefix, value) -> str: + """ + Installs a global if it hasn't been installed already. + This is determined by (prefix, id(value)) pair. + + Returns the name of the newly installed global. + """ + # NB: need self.compile_id to distinguish this global + # from another global created in a different torch.compile instance + name = f"{prefix}_{id(value)}_c{self.compile_id}" + if name in self.installed_globals: + return name + self.install_global_unsafe(name, value) + return name + + def install_global(self, prefix, value) -> str: + """ + Installs a global, generating a unique name for it. + + Returns the name of the newly installed global. + """ + # NB: unique_id is unique, even across torch.compile instances + name = unique_id(prefix) + self.install_global_unsafe(name, value) + return name + + def cleanup(self) -> None: + # There is a reference cycle between tracer and OutputGraph, causing + # some of the tensor objects to be held alive for longer than necessary. + self.root_tx = None + self.nn_modules.clear() + self.param_name_to_source = None + + for node in self.graph.nodes: + if "grapharg" in node.meta: + del node.meta["grapharg"] + self.real_value_cache.clear() + self.input_name_to_proxy.clear() + self.side_effects.clear() + self.variable_tracker_cache.clear() + self.register_finalizer_fns.clear() + self.dynamo_flat_name_to_original_fqn.clear() + self.tracing_context.clear() + + def set_torch_function_state(self, enabled: bool) -> None: + self.torch_function_enabled = enabled + + def add_graph_finalizer( + self, register_finalizer: Callable[[fx.GraphModule], None] + ) -> None: + self.register_finalizer_fns.append(register_finalizer) + + def example_value_from_input_node(self, node: torch.fx.Node): + """Extract the non-fake example tensor""" + if node.op == "placeholder": + return node.meta["grapharg"].example + assert node.op == "get_attr" + return self.nn_modules[node.target] # type: ignore[index] + + +err_epilogue = ( + "With the current config, we will graph break " + "(and fall back to eager-mode PyTorch) on all ops " + "that do not have the 'pt2_compliant_tag'. 
" + "Please see the following doc for how to mark this op as PT2 compliant " + "https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html" +) + + +def check_pt2_compliant_op(output_graph, kind, target, args, kwargs): + if kind != "call_function": + return + + def encountered_compliant_op(target): + if target.namespace in {"prim", "prims", "aten"}: + return + output_graph.compliant_custom_ops.add(target) + + def encountered_non_compliant_op(target, msg): + output_graph.non_compliant_ops.add(target) + if config.only_allow_pt2_compliant_ops: + unimplemented(msg + " " + err_epilogue) + + if isinstance(target, torch._ops.OpOverload): + if torch.Tag.pt2_compliant_tag in target.tags: + encountered_compliant_op(target) + return + encountered_non_compliant_op( + target, + f"Encountered the torch.ops.OpOverload {target} " + f"that is not PT2 compliant.", + ) + return + + if isinstance(target, torch._ops.OpOverloadPacket): + overloads = tuple(target.overloads()) + # Optimization: Overload resolution is expensive. + # If there's only one overload, we know what it will resolve to. + if len(overloads) == 1: + op = getattr(target, overloads[0]) + if torch.Tag.pt2_compliant_tag in op.tags: + encountered_compliant_op(op) + return + encountered_non_compliant_op( + op, + f"Encountered the non-overloaded " + f"torch.ops.OpOverloadPacket {target} " + f"that is not PT2 compliant. ", + ) + return + + args, kwargs = torch._dynamo.utils.get_fake_values_from_nodes( + output_graph.current_tx, (args, kwargs), False + ) + try: + overload = torch._C._jit_resolve_packet( + target._qualified_op_name, *args, **kwargs + ) + except RuntimeError as e: + unimplemented(str(e)) + + op = getattr(target, overload) + if torch.Tag.pt2_compliant_tag in op.tags: + encountered_compliant_op(op) + else: + encountered_non_compliant_op( + op, + f"Encountered the torch.ops.OpOverloadPacket {target} " + f"which resolves to the overload ({overload}) that is " + f"not PT2 compliant.", + ) + + +_compile_id_counter = itertools.count() + + +class SubgraphTracer(fx.Tracer): + """ + Holds an FX graph that is being traced. OutputGraph owns a SubgraphTracer + and the separation of responsibilities is that SubgraphTracer is + responsible for building the graph while OutputGraph is responsible for + compiling and executing the graph. + """ + + def __init__( + self, output_graph, parent=None, export_root=False, source_target=None + ): + super().__init__() + self.output_graph = weakref.proxy(output_graph) + self.graph = torch.fx.Graph() + + # The export is only ever set for the ROOT tracer. It controls + # whether or not certain inputs are allowed to be added or not. + # Look at call sites of create_graph_input to see how it is used. + if export_root: + assert parent is None + self.export_root = export_root + # Map from graph input name to its placeholder proxy object, where the + # map's keys give all current placeholder node names and can be used to + # create unique node names + self.input_name_to_proxy: Dict[str, fx.Proxy] = {} + # Node => computed real value (see utils.get_real_value) + self.real_value_cache: Dict[fx.Node, torch.Tensor] = {} + + # SubgraphTracers can be nested. See NOTE [HigherOrderOperator tracing design] + self.parent = parent + # A dict mapping previously free variables (Proxy objects) + # to new Proxy objects that wrap inputs to this subgraph. + # + # This dict serves two purposes: + # - Proxies are associated with VariableTrackers. 
If we see + # the same VariableTracker twice (and it is a free variable), + # then we want to use the same Proxy in the current subgraph to + # record the tracing. + # - If we are tracing a HigherOrderOperator's body_fn, then we + # need to keep track of what free variables were lifted so we can + # rewrite the HigherOrderOperator call using the traced body_fn. + # Dicts maintain the order of args for the HigherOrderOperator call. + self.lifted_freevars = {} + self.prev_inst = None + + self._cur_code = None + self._orig_gm_meta = None + self._orig_gm_lineno_map = None + self._orig_gm_firstlineno = None + # Each SubgraphTracer is associated with a source target, which indicates + # which operator this subgraph is attached to. We compute a source_fn_stack + # based on the source target. For the root tracer, it's set to []. + # This is useful for debugging and transforming the exported graph. + if self.parent is None: + self.source_fn_stack = [] + else: + self.source_fn_stack = self.parent.source_fn_stack + [ + (self.graph._target_to_str(source_target), source_target) + ] + + # preserve original meta if it is available + def _maybe_preserve_original_meta(self, tx, node): + if ( + self._orig_gm_meta + and self._orig_gm_lineno_map + and self._orig_gm_firstlineno + ): + lineno = tx.current_instruction.starts_line + node_idx = None + if lineno is not None: + node_idx = self._orig_gm_lineno_map.get( + lineno - self._orig_gm_firstlineno, None + ) + if node_idx is not None: + meta = self._orig_gm_meta[node_idx] + for field in fx.proxy._COPY_META_FIELDS: + if field in meta: + node.meta[field] = meta[field] + if "stack_trace" in meta: + node.meta["stack_trace"] = meta["stack_trace"] + + def create_proxy( + self, + kind, + target, + args, + kwargs, + name=None, + type_expr=None, + proxy_factory_fn=None, + ): + # NOTE: [Nested SubgraphTracer and free_variable handling] + # -------------------------------------------------------- + # Read NOTE [HigherOrderOperator tracing design] first. + # + # Let's say we're in the middle of introspecting the body of a possibly + # nested HigherOrderOperator, and we see a free variable. + # + # There are two cases: + # 1. We see a free variable that is already tracked by Dynamo. + # 2. We see a free variable that has not been tracked by Dynamo + # + # In case 1, we call `maybe_lift_tracked_freevar_to_input` (below) + # which will lift the freevar to be an input of this subgraph + # and also recursively lift it to be an input on the parent(s). + # + # In case 2, before the call to `create_proxy`, the InstructionTranslator + # will see the freevar when it gets loaded by Python bytecode. + # E.g. for Python 3.11 the bytecodes that may do this are LOAD_DEREF or + # LOAD_GLOBAL. + # There, the InstructionTranslator asks Dynamo to begin tracking the + # freevar by building a new Variable. + # Building a new Variable automatically lifts the freevar to be an + # input of the root SubgraphTracer. + # + # The implications for the code below are: + # - We will always be in Case 1 when we get to this code. + # - Any "free variable" we encounter here is guaranteed to already be + # bound, that is, it is either a graph input of the root graph, or + # some local variable of the root graph or a subgraph. 
+ # - The additional work we need to do here is *only* that we need to + # lift this free variable into inputs (recursively) of each nested + # higher-order-op subgraph until we hit the subgraph where the free + # variable is bound + if self.parent is not None: + flat_args, tree_spec = pytree.tree_flatten((args, kwargs)) + new_flat_args = [] + for arg in flat_args: + maybe_new_arg = self.maybe_lift_tracked_freevar_to_input(arg) + new_flat_args.append(maybe_new_arg) + + args, kwargs = pytree.tree_unflatten(new_flat_args, tree_spec) + + rv = super().create_proxy( + kind, target, args, kwargs, name, type_expr, proxy_factory_fn + ) + + # append stack trace to fx node + tx = self.output_graph.current_tx + + # log detailed location of line of code in 3.11 + if sys.version_info >= (3, 11) and kind in ( + "call_function", + "call_method", + "call_module", + ): + cur_inst = tx.current_instruction + if ( + cur_inst is not self.prev_inst + and cur_inst.positions is not None + and cur_inst.positions.lineno is not None + ): + tx_code = tx.f_code + header = tx.get_line_of_code_header(lineno=cur_inst.positions.lineno) + + def get_trace_call_log_str(): + line = get_instruction_source_311(tx_code, cur_inst).rstrip() + return f"TRACE FX call {rv.node.name} from {header}\n{line}" + + trace_call_log.debug("%s", LazyString(get_trace_call_log_str)) + self.prev_inst = cur_inst + + # update reference to original meta if we're tracing a new code object + is_retracing = False + if tx.f_code is not self._cur_code: + orig_graphmodule_maybe = code_context.get_context(tx.f_code).get( + "orig_graphmodule", lambda: None + )() + if isinstance(orig_graphmodule_maybe, torch.fx.GraphModule): + is_retracing = True + self._orig_gm_meta = [ + nd.meta for nd in orig_graphmodule_maybe.graph.nodes + ] + self._orig_gm_lineno_map = orig_graphmodule_maybe._lineno_map + self._orig_gm_firstlineno = ( + orig_graphmodule_maybe.forward.__code__.co_firstlineno + ) + else: + self._orig_gm_meta = None + self._orig_gm_lineno_map = None + self._orig_gm_firstlineno = None + nn_module_stack = tx.nn_module_stack + if nn_module_stack: + rv.node.meta["nn_module_stack"] = nn_module_stack.copy() + + if kind in {"call_function", "call_method"}: + rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ + (rv.node.name, target) + ] + elif kind == "call_module": + if self.parent is not None: + unimplemented("Invoking an nn.Module inside HigherOrderOperator") + # For modules we store the class + rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ + ( + rv.node.name, + rv.node.meta["nn_module_stack"][target][1], + ) + ] + + self._maybe_preserve_original_meta(tx, rv.node) + + if not is_retracing: + if "nn_module_stack" not in rv.node.meta: + nn_module_stack = tx.nn_module_stack + if nn_module_stack: + rv.node.meta["nn_module_stack"] = nn_module_stack.copy() + + if "source_fn_stack" not in rv.node.meta: + if kind in {"call_function", "call_method"}: + rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ + (rv.node.name, target) + ] + elif kind == "call_module": + if self.parent is not None: + unimplemented( + "Invoking an nn.Module inside HigherOrderOperator" + ) + # For modules we store the class + rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ + ( + rv.node.name, + rv.node.meta["nn_module_stack"][target][1], + ) + ] + + if "stack_trace" not in rv.node.meta: + frame_summaries: List[traceback.FrameSummary] = [] + while tx: + # Avoid frame summaries from inside the torch/nn/modules. 
This ensures that we keep the stack trace of + # the user code. + if not tx.is_co_filename_from_nn_modules(): + frame_summaries.append(tx.frame_summary()) + tx = getattr(tx, "parent", None) + # Reverse the frame_summaries, such that the innermost frame is at the last + frame_summaries.reverse() + + # official from_list stub doesn't have new-style type + msgs = traceback.StackSummary.from_list(frame_summaries).format() + rv.node.stack_trace = "".join(msgs) + + return rv + + def create_node( + self, op, target, args=None, kwargs=None, name=None, type_expr=None + ): + check_pt2_compliant_op(self.output_graph, op, target, args, kwargs) + if self.parent is not None: + flat_args = pytree.arg_tree_leaves(*args, **kwargs) + for arg in flat_args: + if not isinstance(arg, torch.fx.Node): + continue + assert ( + arg.graph == self.graph + ), "create_node using arg not from this SubgraphTracer" + + node = super().create_node(op, target, args, kwargs, name, type_expr) + node.meta["creation_timestamp"] = self.output_graph.timestamp + return node + + # Note: we did not override erase_node since + # we call self.graph.erase_node elsewhere + def remove_node(self, node): + if len(node.users) > 0: + user_graph_nodes: List[torch.fx.Node] = [] + for user in node.users.keys(): + # For the case where user.graph == self.graph, that is a real bug and will raise + # properly. + if user.graph != self.graph: + # This is a nested graph, which needs to be deleted. + # If we do not do this, we will raise on attempting to remove this. + # As we only get here during restoration cleanup, this is sound. + user_graph_nodes.extend(reversed(list(user.graph.nodes))) + for other_graph_node in user_graph_nodes: + other_graph_node.graph.erase_node(other_graph_node) + self.graph.erase_node(node) + self.input_name_to_proxy.pop(node.name, None) + + # when before=True, we will insert this input before the most recent + # inserted proxy. This is a hack to get around an ordering problem, + # where we first insert a tensor argument, and then insert bindings + # for SymInts that may occur in the tensor argument. + # Remove this if https://github.com/pytorch/pytorch/issues/99007 gets + # fixed. + def create_graph_input(self, name, type_expr=None, before=False, source=None): + log.debug( + "create_graph_input %s %s", + name, + source.name() if source is not None else "(none)", + ) + if source is None: + assert ( + self.parent is not None + ), "you are required to provide a source for inputs on the root tracer" + + # In eager, we are generally OK with adding graph inputs whenever we + # want, because we take care of writing the bytecode that knows how + # to source all the inputs. + # + # In export, this is bad, because you want a self-contained export + # object which only depends on the inputs you explicitly passed to it. 
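+ # For example, an input sourced from a Python global could not be reconstructed + # from the exported artifact alone.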
+ # So we are a bit more strict about what sources can become inputs + # in export + if self.export_root: + if not is_from_local_source(source, allow_cell_or_freevar=False): + self.output_graph.source_to_user_stacks.setdefault(source, []).append( + TracingContext.extract_stack() + ) + + # unique + if name in self.input_name_to_proxy: + for i in itertools.count(): + candidate_name = f"{name}_{i}" + if candidate_name not in self.input_name_to_proxy: + name = candidate_name + break + + if self.input_name_to_proxy: + prev_name = next(reversed(self.input_name_to_proxy)) + node = self.input_name_to_proxy[prev_name].node + if before: + ctx = self.graph.inserting_before(node) + else: + ctx = self.graph.inserting_after(node) + else: + ctx = self.graph.inserting_before(None) + with ctx: + proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr) + if self.input_name_to_proxy and before: + k, v = self.input_name_to_proxy.popitem() + self.input_name_to_proxy[name] = proxy + self.input_name_to_proxy[k] = v + else: + self.input_name_to_proxy[name] = proxy + return proxy + + # See NOTE: [Nested SubgraphTracer and free_variable handling] for more details + def lift_tracked_freevar_to_input(self, proxy): + # You're doing something wrong if we are the root SubgraphTracer because + # Dynamo adds tensors to graph inputs before creating a proxy for them. + assert ( + self.parent is not None + ), "lift_tracked_freevar_to_input should not be called on root SubgraphTracer" + # Proxys are associated with VariableTracker. + # It is possible that we've already lifted the Proxy to be an input. + # If that is the case, just return the already lifted Proxy. + if proxy in self.lifted_freevars: + return self.lifted_freevars[proxy] + new_proxy = self.create_graph_input(proxy.node.name) + set_example_value(new_proxy.node, proxy.node.meta["example_value"]) + self.lifted_freevars[proxy] = new_proxy + if self.parent is not None and proxy.tracer != self.parent: + self.parent.lift_tracked_freevar_to_input(proxy) + return new_proxy + + def maybe_lift_tracked_freevar_to_input(self, arg): + """ + If arg is a free variable, then lift it to be an input. + Returns the new lifted arg (if arg was a freevar), else the + original arg. + """ + if not isinstance(arg, torch.fx.Proxy): + return arg + elif arg.tracer == self: + return arg + return self.lift_tracked_freevar_to_input(arg) + + +# NOTE: [HigherOrderOperator tracing design] +# Ignoring HigherOrderOperators for a moment, +# OutputGraph represents the graph being built by Dynamo that may be compiled +# and executed. It holds a root SubgraphTracer where the FX graph is built. +# +# HigherOrderOperators are operators that take functions as their arguments. +# When Dynamo encounters a HigherOrderOperator, then it attempts to introspect +# the function passed to it (call this the "body function"), capture it into a +# GraphModule, and rewrite the call to the HigherOrderOperator to use the +# GraphModule. +# +# The way we handle the capture of body functions is through having +# (possibly nested) SubgraphTracers, one per body function. +# +# Mechanically, we do the introspection by: +# - Creating a new SubgraphTracer via OutputGraph.subtracer +# - Executing the body function. +# This constructs the graph of the body function in the new SubgraphTracer +# while modifying the state of the OutputGraph. 
For example: +# - the OutputGraph can receive new GraphArgs (if we discover any new +# untracked Tensors) +# - side effects from the body function get accumulated into +# OutputGraph.side_effects +# - guards produced by the body function get accumulated into OutputGraph.guards +# +# The traced function has some special properties that make it easier for us +# to transform later down the line: +# - we lift all free variables to being inputs. +# +# If the introspection fails (due to the existence of graph breaks), then +# we roll back the current OutputGraph state and graph break on the +# HigherOrderOperator. diff --git a/lib/python3.10/site-packages/torch/_dynamo/profiler.py b/lib/python3.10/site-packages/torch/_dynamo/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..b06fead4c845e790a667d655294b836b87c1160b --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/profiler.py @@ -0,0 +1,156 @@ +# mypy: allow-untyped-defs +import dataclasses +import os +from typing import Any, List + +import torch + +from .utils import print_once + + +@dataclasses.dataclass +class ProfileMetrics: + microseconds: float = 0.0 + operators: int = 0 + fusions: int = 0 + graphs: int = 0 + + def __iadd__(self, other: "ProfileMetrics"): + self.microseconds += other.microseconds + self.operators += other.operators + self.fusions += other.fusions + return self + + def __add__(self, other: "ProfileMetrics"): + assert isinstance(other, ProfileMetrics) + return ProfileMetrics( + self.microseconds + other.microseconds, + self.operators + other.operators, + self.fusions + other.fusions, + ) + + def __truediv__(self, other): + if isinstance(other, int): + other = ProfileMetrics(other, other, other) + return ProfileMetrics( + self.microseconds / max(1, other.microseconds), + self.operators / max(1, other.operators), + self.fusions / max(1, other.fusions), + ) + + def __str__(self) -> str: + return f"{self.operators:4.0%} ops {self.microseconds:4.0%} time" + + def tocsv(self): + return [self.operators, self.microseconds] + + +class ProfileResult: + def __init__(self, captured, total, unique_graphs) -> None: + self.captured: ProfileMetrics = captured or ProfileMetrics() + self.total: ProfileMetrics = total or ProfileMetrics() + self.unique_graphs: int = unique_graphs + + def __iadd__(self, other: "ProfileResult"): + self.captured += other.captured + self.total += other.total + self.unique_graphs += other.unique_graphs + return self + + def percent(self): + return self.captured / self.total + + def __str__(self) -> str: + return ( + f"{self.unique_graphs:2} graphs {self.captured.graphs:2} graph calls " + f"{self.captured.operators:4}/{self.total.operators:4} = " + + str(self.percent()) + ) + + def tocsv(self): + return [ + self.unique_graphs, + self.captured.graphs, + self.captured.operators, + self.total.operators, + ] + self.percent().tocsv() + + +def should_print_missing(): + return os.environ.get("TORCHDYNAMO_PRINT_MISSING") == "1" + + +def print_missing(stack): + if any("/torch/autograd/profiler.py" in x for x in stack): + return + stack = [ + x for x in stack if ("<built-in" not in x and "site-packages/torch/" not in x) + ] + print_once("MISSING", " >> ".join(stack[-3:])) + + +class Profiler: + unique_graphs = 0 + + def __init__(self) -> None: + self.prof = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU], + with_stack=should_print_missing(), + ) + + def results(self): + captured_regions = 0 + captured_ops = 0 + captured_microseconds = 0 + total_ops = 0 + total_microseconds = 0 + + last_op_end_time = -1 + captured_region_end_time = -1 + events = 
sorted(self.prof.events(), key=lambda x: x.time_range.start) + for e in events: + if e.name == "TORCHDYNAMO": + captured_region_end_time = e.time_range.end + captured_regions += 1 + # ignore `handle = torch.zeros(1)` in record_function.__init__() + total_ops -= 1 + elif e.time_range.start >= last_op_end_time: + last_op_end_time = e.time_range.end + if e.time_range.end <= captured_region_end_time: + captured_ops += 1 + captured_microseconds += e.time_range.elapsed_us() + elif should_print_missing(): + print_missing(e.stack) + total_ops += 1 + total_microseconds += e.time_range.elapsed_us() + else: + pass # ops recursively called from other ops (ignored) + + unique_graphs = Profiler.unique_graphs + Profiler.unique_graphs = 0 + # we counted one extra op that is part of the profiler setup code + total_ops -= 1 + + return ProfileResult( + captured=ProfileMetrics( + microseconds=captured_microseconds, + operators=captured_ops, + fusions=captured_ops - captured_regions, + graphs=captured_regions, + ), + total=ProfileMetrics( + microseconds=total_microseconds, + operators=total_ops, + fusions=total_ops - 1, + ), + unique_graphs=unique_graphs, + ) + + +def fx_insert_profiling(gm: torch.fx.GraphModule, example_inputs: List[Any]): + def _wrapped(*args): + with torch.profiler.record_function("TORCHDYNAMO"): + return gm.forward(*args) + + Profiler.unique_graphs += 1 + return _wrapped diff --git a/lib/python3.10/site-packages/torch/_dynamo/replay_record.py b/lib/python3.10/site-packages/torch/_dynamo/replay_record.py new file mode 100644 index 0000000000000000000000000000000000000000..8a259b6156aa18d3043cf189c9320bce4ebe48fe --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/replay_record.py @@ -0,0 +1,112 @@ +# mypy: allow-untyped-defs +import dataclasses +from dataclasses import field +from types import CodeType, ModuleType +from typing import Any, Dict + +from torch.utils._import_utils import import_dill + + +dill = import_dill() + + +@dataclasses.dataclass +class ModuleRecord: + module: ModuleType + accessed_attrs: Dict[str, Any] = field(default_factory=dict) + + +@dataclasses.dataclass +class DummyModule: + name: str + is_torch: bool = False + + @property + def __name__(self): + return self.name + + +@dataclasses.dataclass +class ExecutionRecord: + code: CodeType + globals: Dict[str, Any] = field(default_factory=dict) + locals: Dict[str, Any] = field(default_factory=dict) + builtins: Dict[str, Any] = field(default_factory=dict) + code_options: Dict[str, Any] = field(default_factory=dict) + + def dump(self, f): + assert dill is not None, "replay_record requires `pip install dill`" + dill.dump(self, f) + + @classmethod + def load(cls, f): + assert dill is not None, "replay_record requires `pip install dill`" + return dill.load(f) + + +@dataclasses.dataclass +class ExecutionRecorder: + LOCAL_MOD_PREFIX = "___local_mod_" + + code: CodeType + globals: Dict[str, Any] = field(default_factory=dict) + locals: Dict[str, Any] = field(default_factory=dict) + builtins: Dict[str, Any] = field(default_factory=dict) + code_options: Dict[str, Any] = field(default_factory=dict) + name_to_modrec: Dict[str, Any] = field(default_factory=dict) + + def add_local_var(self, name, var): + if isinstance(var, ModuleType): + self.locals[name] = self._add_mod(var) + else: + self.locals[name] = var + + def add_global_var(self, name, var): + if isinstance(var, ModuleType): + self.globals[name] = self._add_mod(var) + else: + self.globals[name] = var + + def add_local_mod(self, name, mod): + assert isinstance(mod, 
ModuleType) + + self.add_global_var(name, mod) + + def record_module_access(self, mod, name, val): + if isinstance(val, ModuleType): + self.name_to_modrec[mod.__name__].accessed_attrs[name] = self._add_mod(val) + return + + if mod.__name__ in self.name_to_modrec: + self.name_to_modrec[mod.__name__].accessed_attrs[name] = val + + def get_record(self): + return ExecutionRecord( + self.code, + ExecutionRecorder._resolve_modules(self.globals), + ExecutionRecorder._resolve_modules(self.locals), + self.builtins.copy(), + self.code_options.copy(), + ) + + def _add_mod(self, mod): + if mod.__name__ not in self.name_to_modrec: + self.name_to_modrec[mod.__name__] = ModuleRecord(mod) + + return self.name_to_modrec[mod.__name__] + + # Convert ModuleRecords -> DummyModule tree + @classmethod + def _resolve_modules(cls, vars): + def resolve_module(var): + if not isinstance(var, ModuleRecord): + return var + + dummy_mod = DummyModule(var.module.__name__) + for attr_name, attr_value in var.accessed_attrs.items(): + attr_value = resolve_module(attr_value) + dummy_mod.__setattr__(attr_name, attr_value) + + return dummy_mod + + return {k: resolve_module(v) for k, v in vars.items()} diff --git a/lib/python3.10/site-packages/torch/_dynamo/resume_execution.py b/lib/python3.10/site-packages/torch/_dynamo/resume_execution.py new file mode 100644 index 0000000000000000000000000000000000000000..132e9e4081bceb4ac53d2ec5627839f8b9b807c9 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/resume_execution.py @@ -0,0 +1,720 @@ +# mypy: allow-untyped-defs +import copy +import dataclasses +import sys +import types +from typing import Any, cast, Dict, List, Optional, Tuple + +from .bytecode_transformation import ( + create_call_function, + create_call_method, + create_dup_top, + create_instruction, + create_jump_absolute, + create_load_method, + Instruction, + InstructionExnTabEntry, + transform_code_object, + unique_id, +) +from .utils import ExactWeakKeyDictionary + + +# taken from code.h in cpython +CO_OPTIMIZED = 0x0001 +CO_NEWLOCALS = 0x0002 +CO_VARARGS = 0x0004 +CO_VARKEYWORDS = 0x0008 +CO_NESTED = 0x0010 +CO_GENERATOR = 0x0020 +CO_NOFREE = 0x0040 +CO_COROUTINE = 0x0080 +CO_ITERABLE_COROUTINE = 0x0100 +CO_ASYNC_GENERATOR = 0x0200 + +# trace_rules.py import this constant for consistency +TORCH_DYNAMO_RESUME_IN_PREFIX = "torch_dynamo_resume_in" + + +def _initial_push_null(insts): + if sys.version_info >= (3, 11): + insts.append(create_instruction("PUSH_NULL")) + if sys.version_info < (3, 13): + insts.append(create_instruction("SWAP", arg=2)) + + +@dataclasses.dataclass(frozen=True) +class ReenterWith: + stack_index: int + target_values: Optional[Tuple[Any, ...]] = None + + # If we do not want to destroy the stack, we can do the same thing as a + # `SETUP_WITH` block, only that we store the context manager in a local_symbol + def try_except(self, code_options, cleanup: List[Instruction]): + """ + Codegen based off of: + load args + enter context + try: + (rest) + finally: + exit context + """ + # NOTE: we assume that TOS is a context manager CLASS! 
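+ # e.g. with target_values ("cuda",), the prologue below LOAD_CONSTs "cuda", calls + # the context manager class, stores the instance in a synthetic local, and invokes + # its __enter__ before the protected region starts.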
+ load_args = [] + if self.target_values: + load_args = [ + create_instruction("LOAD_CONST", argval=val) + for val in self.target_values + ] + ctx_name = unique_id(f"___context_manager_{self.stack_index}") + if ctx_name not in code_options["co_varnames"]: + code_options["co_varnames"] += (ctx_name,) + for name in ["__enter__", "__exit__"]: + if name not in code_options["co_names"]: + code_options["co_names"] += (name,) + + except_jump_target = create_instruction( + "NOP" if sys.version_info < (3, 11) else "PUSH_EXC_INFO" + ) + cleanup_complete_jump_target = create_instruction("NOP") + + setup_finally: List[Instruction] = [] + _initial_push_null(setup_finally) + + # TODO(williamwen42) call method order is wrong for 3.13+ - will fix later + setup_finally.extend( + [ + *load_args, + *create_call_function(len(load_args), False), + create_instruction("STORE_FAST", argval=ctx_name), + create_instruction("LOAD_FAST", argval=ctx_name), + create_load_method("__enter__"), + *create_call_method(0), + create_instruction("POP_TOP"), + ] + ) + + if sys.version_info < (3, 11): + setup_finally.append( + create_instruction("SETUP_FINALLY", target=except_jump_target) + ) + else: + exn_tab_begin = create_instruction("NOP") + exn_tab_end = create_instruction("NOP") + exn_tab_begin.exn_tab_entry = InstructionExnTabEntry( + exn_tab_begin, + exn_tab_end, + except_jump_target, + self.stack_index + 1, + False, + ) + setup_finally.append(exn_tab_begin) + + def create_reset(): + return [ + create_instruction("LOAD_FAST", argval=ctx_name), + create_load_method("__exit__"), + create_instruction("LOAD_CONST", argval=None), + create_dup_top(), + create_dup_top(), + *create_call_method(3), + create_instruction("POP_TOP"), + ] + + if sys.version_info < (3, 9): + epilogue = [ + create_instruction("POP_BLOCK"), + create_instruction("BEGIN_FINALLY"), + except_jump_target, + *create_reset(), + create_instruction("END_FINALLY"), + ] + elif sys.version_info < (3, 11): + epilogue = [ + create_instruction("POP_BLOCK"), + *create_reset(), + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + except_jump_target, + *create_reset(), + create_instruction("RERAISE"), + cleanup_complete_jump_target, + ] + else: + finally_exn_tab_end = create_instruction("RERAISE", arg=0) + finally_exn_tab_target = create_instruction("COPY", arg=3) + except_jump_target.exn_tab_entry = InstructionExnTabEntry( + except_jump_target, + finally_exn_tab_end, + finally_exn_tab_target, + self.stack_index + 2, + True, + ) + epilogue = [ + exn_tab_end, + *create_reset(), + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + except_jump_target, # PUSH_EXC_INFO + *create_reset(), + finally_exn_tab_end, # RERAISE 0 + finally_exn_tab_target, # COPY 3 + create_instruction("POP_EXCEPT"), + create_instruction("RERAISE", arg=1), + cleanup_complete_jump_target, + ] + + cleanup[:] = epilogue + cleanup + return setup_finally + + def __call__(self, code_options, cleanup): + """ + Codegen based off of: + with ctx(args): + (rest) + """ + # NOTE: we assume that TOS is a context manager CLASS! 
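+ # Unlike try_except above, this emits a real with-block (SETUP_WITH on Python <= 3.10, + # BEFORE_WITH on 3.11+), so the interpreter's own block/exception-table machinery + # calls __exit__.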
+ load_args = [] + if self.target_values: + load_args = [ + create_instruction("LOAD_CONST", argval=val) + for val in self.target_values + ] + if sys.version_info < (3, 9): + with_cleanup_start = create_instruction("WITH_CLEANUP_START") + begin_finally = create_instruction("BEGIN_FINALLY") + cleanup[:] = [ + create_instruction("POP_BLOCK"), + begin_finally, + with_cleanup_start, + create_instruction("WITH_CLEANUP_FINISH"), + create_instruction("END_FINALLY"), + ] + cleanup + + return [ + *load_args, + create_instruction("CALL_FUNCTION", arg=len(load_args)), + create_instruction("SETUP_WITH", target=with_cleanup_start), + create_instruction("POP_TOP"), + ], None + elif sys.version_info < (3, 11): + with_except_start = create_instruction("WITH_EXCEPT_START") + pop_top_after_with_except_start = create_instruction("POP_TOP") + + cleanup_complete_jump_target = create_instruction("NOP") + + cleanup[:] = [ + create_instruction("POP_BLOCK"), + create_instruction("LOAD_CONST", argval=None), + create_instruction("DUP_TOP"), + create_instruction("DUP_TOP"), + create_instruction("CALL_FUNCTION", arg=3), + create_instruction("POP_TOP"), + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + with_except_start, + create_instruction( + "POP_JUMP_IF_TRUE", target=pop_top_after_with_except_start + ), + create_instruction("RERAISE"), + pop_top_after_with_except_start, + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + create_instruction("POP_EXCEPT"), + create_instruction("POP_TOP"), + cleanup_complete_jump_target, + ] + cleanup + + return [ + *load_args, + create_instruction("CALL_FUNCTION", arg=len(load_args)), + create_instruction("SETUP_WITH", target=with_except_start), + create_instruction("POP_TOP"), + ], None + else: + pop_top_after_with_except_start = create_instruction("POP_TOP") + cleanup_complete_jump_target = create_instruction("NOP") + + def create_load_none(): + return create_instruction("LOAD_CONST", argval=None) + + exn_tab_1_begin = create_instruction("POP_TOP") + exn_tab_1_end = create_instruction("NOP") + exn_tab_1_target = create_instruction("PUSH_EXC_INFO") + exn_tab_2_end = create_instruction("RERAISE", arg=2) + exn_tab_2_target = create_instruction("COPY", arg=3) + + exn_tab_1_begin.exn_tab_entry = InstructionExnTabEntry( + exn_tab_1_begin, + exn_tab_1_end, + exn_tab_1_target, + self.stack_index + 1, + True, + ) + exn_tab_1_target.exn_tab_entry = InstructionExnTabEntry( + exn_tab_1_target, + exn_tab_2_end, + exn_tab_2_target, + self.stack_index + 3, + True, + ) + pop_top_after_with_except_start.exn_tab_entry = InstructionExnTabEntry( + pop_top_after_with_except_start, + pop_top_after_with_except_start, + exn_tab_2_target, + self.stack_index + 3, + True, + ) + + cleanup[:] = [ + exn_tab_1_end, + create_load_none(), + create_load_none(), + create_load_none(), + *create_call_function(2, False), + create_instruction("POP_TOP"), + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + exn_tab_1_target, # PUSH_EXC_INFO + create_instruction("WITH_EXCEPT_START"), + create_instruction( + "POP_JUMP_FORWARD_IF_TRUE" + if sys.version_info < (3, 12) + else "POP_JUMP_IF_TRUE", + target=pop_top_after_with_except_start, + ), + exn_tab_2_end, # RERAISE 2 + exn_tab_2_target, # COPY 3 + create_instruction("POP_EXCEPT"), + create_instruction("RERAISE", arg=1), + pop_top_after_with_except_start, + create_instruction("POP_EXCEPT"), + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + cleanup_complete_jump_target, + ] + cleanup + + 
ret: List[Instruction] = [] + _initial_push_null(ret) + ret.extend( + [ + *load_args, + *create_call_function(len(load_args), False), + create_instruction("BEFORE_WITH"), + exn_tab_1_begin, # POP_TOP + ] + ) + return ret, exn_tab_1_target + + +@dataclasses.dataclass +class ResumeFunctionMetadata: + code: types.CodeType + instructions: List[Instruction] = dataclasses.field(default_factory=list) + # Python 3.11+ fields + # NOTE: Python 3.11 removed blocks, but for our purposes, a "block" consists + # of instructions of all exception table entries that have the same target. + + # map from PUSH_EXC_INFO's in the prefix to original block target offset + prefix_block_target_offset_remap: List[int] = dataclasses.field( + default_factory=list + ) + # map from new block target offsets to original block target offsets + block_target_offset_remap: Optional[Dict[int, int]] = None + + +def _filter_iter(l1, l2, cond): + """ + Two-pointer conditional filter. + e.g. _filter_iter(insts, sorted_offsets, lambda i, o: i.offset == o) + returns the instructions with offsets in sorted_offsets + """ + it = iter(l2) + res: List[Instruction] = [] + try: + cur = next(it) + for val in l1: + if cond(val, cur): + res.append(val) + cur = next(it) + except StopIteration: + pass + return res + + +def _load_tuple_and_call(tup): + insts: List[Instruction] = [] + _initial_push_null(insts) + for val in tup: + insts.append(create_instruction("LOAD_CONST", argval=val)) + insts.extend(create_call_function(len(tup), False)) + return insts + + +class ContinueExecutionCache: + cache = ExactWeakKeyDictionary() + generated_code_metadata = ExactWeakKeyDictionary() + + @classmethod + def lookup(cls, code, lineno, *key): + if code not in cls.cache: + cls.cache[code] = {} + key = tuple(key) + if key not in cls.cache[code]: + cls.cache[code][key] = cls.generate(code, lineno, *key) + return cls.cache[code][key] + + @classmethod + def generate( + cls, + code, + lineno, + offset: int, + setup_fn_target_offsets: Tuple[int], # only used in Python 3.11+ + nstack: int, + argnames: Tuple[str], + argnames_null: Tuple[str], + setup_fns: Tuple[ReenterWith], + stack_ctx_vars: Tuple[int, Tuple[Any]], + argnames_ctx_vars: Tuple[str, Tuple[Any]], + null_idxes: Tuple[int], + ) -> types.CodeType: + assert offset is not None + assert not ( + code.co_flags + & (CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR) + ) + assert code.co_flags & CO_OPTIMIZED + if code in ContinueExecutionCache.generated_code_metadata: + return cls.generate_based_on_original_code_object( + code, + lineno, + offset, + setup_fn_target_offsets, + nstack, + argnames, + argnames_null, + setup_fns, + stack_ctx_vars, + argnames_ctx_vars, + null_idxes, + ) + + is_py311_plus = sys.version_info >= (3, 11) + meta = ResumeFunctionMetadata(code) + + def update(instructions: List[Instruction], code_options: Dict[str, Any]): + meta.instructions = copy.deepcopy(instructions) + + args = [f"___stack{i}" for i in range(nstack)] + args.extend(v for v in argnames if v not in args) + freevars = tuple(code_options["co_cellvars"] or []) + tuple( + code_options["co_freevars"] or [] + ) + freevars = tuple(sorted(freevars)) + code_options[ + "co_name" + ] = f"{TORCH_DYNAMO_RESUME_IN_PREFIX}_{code_options['co_name']}_at_{lineno}" + if is_py311_plus: + qualified_path = code_options["co_qualname"].rsplit(".", maxsplit=1) + if len(qualified_path) == 1: + code_options["co_qualname"] = code_options["co_name"] + else: + assert len(qualified_path) == 2 + module_name, co_name = qualified_path 
+ code_options[ + "co_qualname" + ] = f"{module_name}.{TORCH_DYNAMO_RESUME_IN_PREFIX}_{co_name}_at_{lineno}" + code_options["co_firstlineno"] = lineno + code_options["co_cellvars"] = () + code_options["co_freevars"] = freevars + code_options["co_argcount"] = len(args) + code_options["co_posonlyargcount"] = 0 + code_options["co_kwonlyargcount"] = 0 + code_options["co_varnames"] = tuple( + args + + [v for v in argnames_null if v not in args] + + [v for v in code_options["co_varnames"] if v not in args] + ) + code_options["co_flags"] = code_options["co_flags"] & ~( + CO_VARARGS | CO_VARKEYWORDS + ) + target = next(i for i in instructions if i.offset == offset) + + prefix = [] + if is_py311_plus: + if freevars: + prefix.append( + create_instruction("COPY_FREE_VARS", arg=len(freevars)) + ) + prefix.append(create_instruction("RESUME", arg=0)) + + cleanup: List[Instruction] = [] + hooks = {fn.stack_index: fn for fn in setup_fns} + hook_target_offsets = { + fn.stack_index: setup_fn_target_offsets[i] + for i, fn in enumerate(setup_fns) + } + offset_to_inst = {inst.offset: inst for inst in instructions} + # map old hook targets to new targets generated by the hook + old_hook_target_remap = {} + null_idxes_i = 0 + stack_ctx_vars_d = dict(stack_ctx_vars) # type: ignore[var-annotated,arg-type] + for i in range(nstack): + while ( + null_idxes_i < len(null_idxes) + and null_idxes[null_idxes_i] == i + null_idxes_i + ): + prefix.append(create_instruction("PUSH_NULL")) + null_idxes_i += 1 + prefix.append(create_instruction("LOAD_FAST", argval=f"___stack{i}")) + if i in hooks: + hook = hooks.pop(i) + hook_insts, exn_target = hook(code_options, cleanup) + prefix.extend(hook_insts) + if is_py311_plus: + hook_target_offset = hook_target_offsets.pop(i) + old_hook_target = offset_to_inst[hook_target_offset] + meta.prefix_block_target_offset_remap.append(hook_target_offset) + old_hook_target_remap[old_hook_target] = exn_target + real_i = i + null_idxes_i + if real_i in stack_ctx_vars_d: + # NOTE: we assume that current stack var is a context manager CLASS! + # Load args for context variable and construct it + prefix.extend(_load_tuple_and_call(stack_ctx_vars_d[real_i])) + + if is_py311_plus: + # reverse the mapping since targets of later/nested contexts are inserted + # into the mapping later, but show up earlier in the prefix. + meta.prefix_block_target_offset_remap = list( + reversed(meta.prefix_block_target_offset_remap) + ) + + assert not hooks + + # NOTE: we assume that local var is a context manager CLASS! 
+ # initialize inactive context vars in argnames + for name, vals in argnames_ctx_vars: + prefix.append(create_instruction("LOAD_FAST", argval=name)) + prefix.extend(_load_tuple_and_call(vals)) + prefix.append(create_instruction("STORE_FAST", argval=name)) + + # 3.12+: store NULL into variables that were NULL + if argnames_null: + assert sys.version_info >= (3, 12) + for v in argnames_null: + assert v not in args + prefix.extend( + [ + create_instruction("PUSH_NULL"), + create_instruction("STORE_FAST", argval=v), + ] + ) + + prefix.append(create_jump_absolute(target)) + + # because the line number table monotonically increases from co_firstlineno + # remove starts_line for any instructions before the graph break instruction + # this will ensure the instructions after the break have the correct line numbers + for inst in instructions: + if inst.offset == target.offset: + break + inst.starts_line = None + if sys.version_info >= (3, 11): + inst.positions = None + + if cleanup: + prefix.extend(cleanup) + prefix.extend(cls.unreachable_codes(code_options)) + + # remap original instructions' exception table entries + if old_hook_target_remap: + assert is_py311_plus + for inst in instructions: + if ( + inst.exn_tab_entry + and inst.exn_tab_entry.target in old_hook_target_remap + ): + inst.exn_tab_entry.target = old_hook_target_remap[ + inst.exn_tab_entry.target + ] + + # TODO(jansel): add dead code elimination here + instructions[:] = prefix + instructions + + new_code = transform_code_object(code, update) + ContinueExecutionCache.generated_code_metadata[new_code] = meta + return new_code + + @staticmethod + def unreachable_codes(code_options) -> List[Instruction]: + """Codegen a `raise None` to make analysis work for unreachable code""" + return [ + create_instruction("LOAD_CONST", argval=None), + create_instruction("RAISE_VARARGS", arg=1), + ] + + @classmethod + def generate_based_on_original_code_object( + cls, code, lineno, offset: int, setup_fn_target_offsets: Tuple[int, ...], *args + ): + """ + This handles the case of generating a resume into code generated + to resume something else. We want to always generate starting + from the original code object so that if control flow paths + converge we only generated 1 resume function (rather than 2^n + resume functions). + """ + + meta: ResumeFunctionMetadata = ContinueExecutionCache.generated_code_metadata[ + code + ] + new_offset = None + + def find_new_offset( + instructions: List[Instruction], code_options: Dict[str, Any] + ): + nonlocal new_offset + (target,) = (i for i in instructions if i.offset == offset) + # match the functions starting at the last instruction as we have added a prefix + (new_target,) = ( + i2 + for i1, i2 in zip(reversed(instructions), reversed(meta.instructions)) + if i1 is target + ) + assert target.opcode == new_target.opcode + new_offset = new_target.offset + + transform_code_object(code, find_new_offset) + + if sys.version_info >= (3, 11): + # setup_fn_target_offsets currently contains the target offset of + # each setup_fn, based on `code`. When we codegen the resume function + # based on the original code object, `meta.code`, the offsets in + # setup_fn_target_offsets must be based on `meta.code` instead. 
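+ # Build (and cache on `meta`) a map from block target offsets in this generated + # code object back to offsets in the original `meta.code`; it only needs to be + # computed once per resume function.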
+ if not meta.block_target_offset_remap: + block_target_offset_remap = meta.block_target_offset_remap = {} + + def remap_block_offsets( + instructions: List[Instruction], code_options: Dict[str, Any] + ): + # NOTE: each prefix block generates exactly one PUSH_EXC_INFO, + # so we can tell which block a prefix PUSH_EXC_INFO belongs to, + # by counting. Then we can use meta.prefix_block-target_offset_remap + # to determine where in the original code the PUSH_EXC_INFO offset + # replaced. + prefix_blocks: List[Instruction] = [] + for inst in instructions: + if len(prefix_blocks) == len( + meta.prefix_block_target_offset_remap + ): + break + if inst.opname == "PUSH_EXC_INFO": + prefix_blocks.append(inst) + + # offsets into prefix + for inst, o in zip( + prefix_blocks, meta.prefix_block_target_offset_remap + ): + block_target_offset_remap[cast(int, inst.offset)] = o + + # old bytecode targets are after the prefix PUSH_EXC_INFO's + old_start_offset = ( + cast(int, prefix_blocks[-1].offset) if prefix_blocks else -1 + ) + # offsets into old bytecode + old_inst_offsets = sorted( + n for n in setup_fn_target_offsets if n > old_start_offset + ) + targets = _filter_iter( + instructions, old_inst_offsets, lambda inst, o: inst.offset == o + ) + new_targets = _filter_iter( + zip(reversed(instructions), reversed(meta.instructions)), + targets, + lambda v1, v2: v1[0] is v2, + ) + for new, old in zip(new_targets, targets): + block_target_offset_remap[old.offset] = new[1].offset + + transform_code_object(code, remap_block_offsets) + + # if offset is not in setup_fn_target_offsets, it is an error + setup_fn_target_offsets = tuple( + meta.block_target_offset_remap[n] for n in setup_fn_target_offsets + ) + return ContinueExecutionCache.lookup( + meta.code, lineno, new_offset, setup_fn_target_offsets, *args + ) + + +""" +# partially finished support for with statements + +def convert_locals_to_cells( + instructions: List[Instruction], + code_options: Dict[str, Any]): + + code_options["co_cellvars"] = tuple( + var + for var in code_options["co_varnames"] + if var not in code_options["co_freevars"] + and not var.startswith("___stack") + ) + cell_and_free = code_options["co_cellvars"] + code_options["co_freevars"] + for inst in instructions: + if str(inst.argval).startswith("___stack"): + continue + elif inst.opname == "LOAD_FAST": + inst.opname = "LOAD_DEREF" + elif inst.opname == "STORE_FAST": + inst.opname = "STORE_DEREF" + elif inst.opname == "DELETE_FAST": + inst.opname = "DELETE_DEREF" + else: + continue + inst.opcode = dis.opmap[inst.opname] + assert inst.argval in cell_and_free, inst.argval + inst.arg = cell_and_free.index(inst.argval) + +def patch_setup_with( + instructions: List[Instruction], + code_options: Dict[str, Any] +): + nonlocal need_skip + need_skip = True + target_index = next( + idx for idx, i in enumerate(instructions) if i.offset == offset + ) + assert instructions[target_index].opname == "SETUP_WITH" + convert_locals_to_cells(instructions, code_options) + + stack_depth_before = nstack + stack_effect(instructions[target_index].opcode, + instructions[target_index].arg) + + inside_with = [] + inside_with_resume_at = None + stack_depth = stack_depth_before + idx = target_index + 1 + for idx in range(idx, len(instructions)): + inst = instructions[idx] + if inst.opname == "BEGIN_FINALLY": + inside_with_resume_at = inst + break + elif inst.target is not None: + unimplemented("jump from with not supported") + elif inst.opname in ("BEGIN_FINALLY", "WITH_CLEANUP_START", "WITH_CLEANUP_FINISH", 
"END_FINALLY", + "POP_FINALLY", "POP_EXCEPT", + "POP_BLOCK", "END_ASYNC_FOR"): + unimplemented("block ops not supported") + inside_with.append(inst) + stack_depth += stack_effect(inst.opcode, inst.arg) + assert inside_with_resume_at + + instructions = [ + create_instruction("LOAD_FAST", f"___stack{i}") for i in range(nstack) + ] + [ + create_instruction("SETUP_WITH", target=instructions[target_index].target) + ... call the function ... + unpack_tuple + ] + [ + create_instruction("JUMP_ABSOLUTE", target=inside_with_resume_at) + ] +""" diff --git a/lib/python3.10/site-packages/torch/_dynamo/side_effects.py b/lib/python3.10/site-packages/torch/_dynamo/side_effects.py new file mode 100644 index 0000000000000000000000000000000000000000..fc1bd976ff57df31c87c38b58b60f9f886cac4db --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/side_effects.py @@ -0,0 +1,701 @@ +# mypy: allow-untyped-defs +import functools +import inspect +import warnings +from collections.abc import MutableMapping +from typing import Any, Dict, List, Optional, Type, Union + +import torch.nn + +from . import utils, variables +from .bytecode_transformation import ( + bytecode_from_template, + create_call_function, + create_call_method, + create_instruction, +) +from .codegen import PyCodegen +from .exc import unimplemented +from .source import GlobalSource, LocalSource, Source +from .utils import is_frozen_dataclass, nn_module_new, object_new +from .variables.base import ( + is_side_effect_safe, + MutableLocalBase, + MutableLocalSource, + VariableTracker, +) +from .variables.user_defined import FrozenDataClassVariable + + +class MutableSideEffects(MutableLocalBase): + """ + VariableTracker.mutable_local marker to indicate a list passed as + an input that if we mutate we need to re-apply those mutations after + the graph runs. + """ + + def __init__(self, source: Source, is_modified: bool = False): + super().__init__(MutableLocalSource.Existing) + self.source = source + self.is_modified = is_modified + + +class AttributeMutation(MutableLocalBase): + """ + VariableTracker.mutable_local marker to track changes to attributes + """ + + def __init__(self, typ: MutableLocalSource, source: Optional[Source]): + super().__init__(typ) + self.source = source + + +class AttributeMutationExisting(AttributeMutation): + def __init__(self, source: Source): + super().__init__(MutableLocalSource.Existing, source) + self.source = source + + +class AttributeMutationNew(AttributeMutation): + def __init__(self, source: Optional[Source], cls_source: Optional[Source]): + super().__init__(MutableLocalSource.Local, source) + self.cls_source = cls_source + + +def _manual_update_dict(dict_from, dict_to): + for k, v in dict_from.items(): + dict_to[k] = v + + +class SideEffects: + """ + Track side effects (list mutation, setattr, etc) that need to be + applied after an FX graph is run. + """ + + id_to_variable: Dict[int, VariableTracker] + store_attr_mutations: Dict[MutableLocalBase, Dict[str, VariableTracker]] + keepalive: List[Any] + + def __init__( + self, + id_to_variable=None, + store_attr_mutations=None, + keepalive=None, + save_for_backward=None, + tensor_hooks=None, + ): + super().__init__() + self.id_to_variable = id_to_variable or {} + self.store_attr_mutations = store_attr_mutations or {} + self.keepalive = keepalive or [] + self.save_for_backward = save_for_backward or [] + self.tensor_hooks = tensor_hooks or {} + # Track Compiled Autograd final callbacks that must be called at the end of Compiled Autograd backward graph. 
+ # Only applicable if this graph is created from Dynamo tracing in Compiled Autograd. + self.ca_final_callbacks_var = None + + def __eq__(self, other: object) -> bool: + assert isinstance(other, SideEffects) + # NB: do NOT test keepalive + return ( + self.id_to_variable == other.id_to_variable + and self.store_attr_mutations == other.store_attr_mutations + and self.save_for_backward == other.save_for_backward + and self.tensor_hooks == other.tensor_hooks + ) + + def diff(self, other: "SideEffects") -> Optional[str]: + if self.id_to_variable != other.id_to_variable: + sk_itv = self.id_to_variable.keys() + ok_itv = other.id_to_variable.keys() + if sk_itv != ok_itv: + return f"id_to_variable keys: {sk_itv} != {ok_itv}" + # Feel free to augment this with more fancy diffing logic + # if needed for debugging + return "id_to_variable: unknown diff" + elif self.store_attr_mutations != other.store_attr_mutations: + sk_sam = self.store_attr_mutations.keys() + ok_sam = other.store_attr_mutations.keys() + if sk_sam != ok_sam: + return f"store_attr_mutations keys: {sk_sam} != {ok_sam}" + return "store_attr_mutations: unknown diff" + elif self.save_for_backward != other.save_for_backward: + return "save_for_backward" + elif self.tensor_hooks != other.tensor_hooks: + return "tensor_hooks" + else: + return None + + def clone(self): + """Create a shallow copy""" + return self.__class__( + id_to_variable=dict(self.id_to_variable), + store_attr_mutations={ + k: dict(v) for k, v in self.store_attr_mutations.items() + }, + keepalive=list(self.keepalive), + save_for_backward=self.save_for_backward, + tensor_hooks=self.tensor_hooks, + ) + + def __contains__(self, item): + return id(item) in self.id_to_variable + + def __getitem__(self, item): + return self.id_to_variable[id(item)] + + def check_allowed_side_effect(self, item): + from torch._dynamo.variables.misc import AutogradFunctionContextVariable + + # People do things like self.dim = dim inside autograd.Function. + # These are benign. 
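+ # Such writes only touch the ctx object Dynamo created for this call, so they + # are exempted from the scope check below.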
+ if isinstance(item, AutogradFunctionContextVariable): + return True + if not is_side_effect_safe(item.mutable_local): + unimplemented( + "HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)" + ) + + def store_attr(self, item: VariableTracker, name: str, value: VariableTracker): + assert self.is_attribute_mutation(item) + self.check_allowed_side_effect(item) + if item.mutable_local not in self.store_attr_mutations: + self.store_attr_mutations[item.mutable_local] = {} + self.store_attr_mutations[item.mutable_local][name] = value + + def load_attr(self, item, name, deleted_ok=False): + assert self.is_attribute_mutation(item) + result = self.store_attr_mutations[item.mutable_local][name] + if not deleted_ok and isinstance(result, variables.DeletedVariable): + unimplemented("read deleted attribute") + return result + + def store_cell(self, cellvar, value): + assert isinstance(cellvar, variables.NewCellVariable) + assert isinstance(value, variables.VariableTracker) + self.store_attr(cellvar, "cell_contents", value) + + def load_cell(self, cellvar): + assert isinstance(cellvar, variables.NewCellVariable) + return self.load_attr(cellvar, "cell_contents") + + def load_global(self, gvar: VariableTracker, name: str): + assert isinstance(gvar, variables.VariableTracker) + return self.load_attr(gvar, name) + + def store_global(self, gvar: VariableTracker, name: str, value: VariableTracker): + assert isinstance(gvar, variables.VariableTracker) + assert isinstance(value, variables.VariableTracker) + self.store_attr(gvar, name, value) + + @staticmethod + def cls_supports_mutation_side_effects(cls): + return ( + inspect.getattr_static(cls, "__getattribute__", None) + is object.__getattribute__ + ) + + def is_attribute_mutation(self, item): + return isinstance(item.mutable_local, AttributeMutation) + + def has_pending_mutation(self, item): + return self.is_attribute_mutation(item) and bool( + self.store_attr_mutations.get(item.mutable_local) + ) + + def has_pending_mutation_of_attr(self, item, name): + return self.is_attribute_mutation( + item + ) and name in self.store_attr_mutations.get(item.mutable_local, ()) + + def is_modified(self, item): + if isinstance(item.mutable_local, AttributeMutationNew): + return True + if self.is_attribute_mutation(item): + return item.mutable_local in self.store_attr_mutations + return item.mutable_local.is_modified + + def _track_obj( + self, + item: Any, + variable: VariableTracker, + mutable_cls=MutableSideEffects, + ): + """Start tracking a new variable for mutation""" + assert variable.source is not None + + if id(item) in self.id_to_variable: + raise AssertionError( + f"{variable} is already tracked for mutation. This could be " + "because you are not using VariableBuilder to construct " + "the variable tracker. " + f"Source of new object: {variable.source}. " + f"Source of previously tracked object: {self.id_to_variable[id(item)].source}." 
+ ) + + variable.mutable_local = mutable_cls(variable.source) + self.id_to_variable[id(item)] = variable + self.keepalive.append(item) + return variable + + track_mutable = _track_obj + + def track_object_existing( + self, + item: Any, + variable: VariableTracker, + ): + return self._track_obj(item, variable, mutable_cls=AttributeMutationExisting) + + def track_object_new( + self, + cls_source: Source, + user_cls: Any, + variable_cls: Any, + options, + ): + if user_cls is torch.autograd.function.FunctionCtx: + with warnings.catch_warnings(record=True): + obj = torch.autograd.Function() + elif issubclass(user_cls, torch.nn.Module): + obj = nn_module_new(user_cls) + else: + obj = object_new(user_cls) + variable = variable_cls( + obj, + mutable_local=AttributeMutationNew(None, cls_source), + **options, + ) + self.id_to_variable[id(obj)] = variable + self.keepalive.append(obj) + return variable + + def track_object_new_from_user_defined_class( + self, + cls_variable: "variables.UserDefinedClassVariable", + ): + cls_source = cls_variable.source + user_cls = cls_variable.value + + # Find the variable class + variable_cls: Type[ + variables.UserDefinedObjectVariable + ] = variables.UserDefinedObjectVariable + if issubclass(user_cls, torch.nn.Module): + variable_cls = variables.UnspecializedNNModuleVariable + elif issubclass(user_cls, MutableMapping): + variable_cls = variables.MutableMappingVariable + elif is_frozen_dataclass(user_cls): + variable_cls = FrozenDataClassVariable + else: + variable_cls = variables.UserDefinedObjectVariable + + assert issubclass(variable_cls, variables.UserDefinedObjectVariable) + + variable_cls = functools.partial(variable_cls, cls_source=cls_source) + + return self.track_object_new(cls_source, user_cls, variable_cls, {}) + + def track_cell_new( + self, + ): + obj = object() + variable = variables.NewCellVariable( + mutable_local=AttributeMutationNew(None, None), + ) + self.id_to_variable[id(obj)] = variable + self.keepalive.append(obj) + return variable + + def track_cell_existing(self, source: Source, item: Any): + variable = variables.NewCellVariable( + mutable_local=AttributeMutationExisting(source), + ) + self.id_to_variable[id(item)] = variable + self.keepalive.append(item) + return variable + + def track_global_existing(self, source: Source, item: Any): + variable = variables.NewGlobalVariable( + mutable_local=AttributeMutationExisting(source), + ) + self.id_to_variable[id(item)] = variable + self.keepalive.append(item) + return variable + + def track_save_for_backward(self, ctx, args): + assert isinstance(ctx, variables.AutogradFunctionContextVariable) + self.save_for_backward.append((ctx, args)) + + def track_tensor_variables_from_runahead_side_effects(self, other): + # In higher order ops we want to keep track of tensors seen in the + # speculate_subgraph so that we don't lift them again as a new input in + # other speculate_subgraph or in the root tracer. + for other_item in other.keepalive: + other_id = id(other_item) + other_variable = other.id_to_variable[other_id] + if other_id not in self.id_to_variable and isinstance( + other_variable, variables.TensorVariable + ): + self.track_object_existing(other_item, other_variable) + + def prune_dead_object_new(self, tx): + live_new_objects = set() + + # use this to avoid cycles in mutable_local (though I'm not sure if that + # can actually happen). 
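+        # Put differently, this is a small mark-and-sweep pass: the roots are
+        # tx.stack, tx.symbolic_locals and the pre-existing (non
+        # AttributeMutationNew) tracked variables; visit() marks every
+        # reachable AttributeMutationNew as live, and the filtering at the
+        # end of this function drops whatever was never marked.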
+ visited: Any = set({}) + + def visit(var: VariableTracker): + mutable_local = var.mutable_local + if mutable_local is None: + return + if mutable_local in visited: + return + visited.add(mutable_local) + # Object may have been mutated, store this mutation. + if isinstance(mutable_local, AttributeMutationNew): + live_new_objects.add(mutable_local) + # It's possible that we have mutated the value of this variable + # to be another one. The new value is in store_attr_mutations. + # Also recurse through the new value to detect alive AttributeMutationNew. + if var.mutable_local in self.store_attr_mutations: + VariableTracker.visit( + visit, self.store_attr_mutations[var.mutable_local] + ) + + def is_live(var: Union[MutableLocalBase, VariableTracker]): + if isinstance(var, AttributeMutationNew): + return var in live_new_objects + if isinstance(var, VariableTracker): + return is_live(var.mutable_local) + return True + + pre_existing_vars = [ + var + for var in self.id_to_variable.values() + if not isinstance(var.mutable_local, AttributeMutationNew) + ] + + # The only live side effects come from returns (tx.stack), any intermediates + # during a graph break (tx.symbolic_locals), and mutation on pre-existing variables. + # Recursively visit Variables and see if any of them have been mutated. + VariableTracker.visit(visit, (tx.stack, tx.symbolic_locals, pre_existing_vars)) + + # NB: cell variable handling.is tricky. + # cell variables must stay alive if any NestedUserFunctionVariable + # are live. "visit"-ing the NestedUserFunctionVariable visits + # the .closures field, from which we will see if we need to keep + # any mutations to cell variables alive. + + self.id_to_variable = { + k: v for k, v in self.id_to_variable.items() if is_live(v) + } + self.store_attr_mutations = { + k: v for k, v in self.store_attr_mutations.items() if is_live(k) + } + + def mutation(self, var): + self.check_allowed_side_effect(var) + if isinstance(var.mutable_local, MutableSideEffects): + var.mutable_local = MutableSideEffects(var.mutable_local.source, True) + + def _get_modified_vars(self): + return [var for var in self.id_to_variable.values() if self.is_modified(var)] + + def codegen_save_tempvars(self, cg: PyCodegen): + for var in self._get_modified_vars(): + if isinstance( + var.mutable_local, (AttributeMutationExisting, AttributeMutationNew) + ) and isinstance(var, variables.NewCellVariable): + cg.add_push_null( + lambda: cg.load_import_from(utils.__name__, "make_cell") + ) + cg.extend_output(create_call_function(0, False)) + cg.add_cache(var) + if isinstance(var.mutable_local, AttributeMutationNew): + var.mutable_local.source = LocalSource(cg.tempvars[var]) # type: ignore[attr-defined] + elif isinstance(var.mutable_local, AttributeMutationNew): + if isinstance(var, variables.AutogradFunctionContextVariable): + unimplemented("AutogradFunctionContextVariable escaped") + cg.add_push_null( + lambda: cg.load_import_from(utils.__name__, "object_new") + ) + cg(var.mutable_local.cls_source) + cg.extend_output(create_call_function(1, False)) + cg.add_cache(var) + var.mutable_local.source = LocalSource(cg.tempvars[var]) + elif var in cg.tempvars: + assert cg.tempvars.get(var) is None + # subsequent usage should point to the original variable + cg(var.mutable_local.source) + cg.add_cache(var) + + for ctx, args in self.save_for_backward: + cg(ctx.source) + cg.load_method("save_for_backward") + for arg in args: + cg(arg) + cg.extend_output( + [ + *create_call_method(len(args)), + create_instruction("POP_TOP"), + ] + ) + + 
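+    # The save_for_backward loop above emits residual bytecode roughly
+    # equivalent to (a sketch; ``ctx``, ``a`` and ``b`` stand in for the real
+    # sources):
+    #
+    #     ctx.save_for_backward(a, b)
+    #
+    # i.e. load ctx, load the bound method, push each arg, call it, then
+    # POP_TOP the None result.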
def register_hook(self, tensor, hook, handle, name): + assert isinstance(tensor, variables.TensorVariable) + assert isinstance(hook, variables.VariableTracker) + assert ( + isinstance(handle, variables.RemovableHandleVariable) + and handle.mutable_local + ) + assert hasattr(torch.Tensor, name) + idx = len(self.tensor_hooks.keys()) + # duplicate index possible because of self.remove_hook() + while idx in self.tensor_hooks: + idx += 1 + self.tensor_hooks[idx] = (tensor, hook, handle, name) + assert not handle.idx + handle.idx = idx + + def remove_hook(self, idx): + del self.tensor_hooks[idx] + + def codegen_hooks(self, cg): + for ( + tensor, + hook, + handle, + name, + ) in self.tensor_hooks.values(): + # Note: [On tensor.register_hook] + # + # register_hook on a tensor, AKA backward hooks, have slightly nuanced differences in how they are implemented + # when it comes to hooks on objects with sources (inputs, params) vs objects without sources (intermediaries). + # + # For tensors with a source, we bypass direct inclusion of register_hook calls in the graph. + # Instead, these are tracked and stashed as a global variable, enabling their association with tensors in + # the residuals. During dynamo's frame creation, these hooks are invoked seamlessly on known reconstructible/fetch-able + # tensors. Because a source indicates knowledge of this object outside the torch compile region, and + # because we are running residuals firmly before .backward() can be run, it is sound to invoke + # `register_hook` on a known tensor. + # + # For tensors without a source, we support a limited subset of hooks. Global functions only, and + # compiled_autograd must be enabled or we will graph break. + # + # Handling the Handle: When a user retains the register_hook result in a handle, we intercept the + # STORE_FAST operation to record the user-designated local variable name. This ensures the reconstructed + # bytecode retains this name. If no handle is defined, we simply pop the generated value to keep the + # stack intact. + # + # Dynamo Tensor Hooks Workflow: + # - Functions passed to register_hook are lifted globally. + # - For tensors with sources: + # - In the "side_effects" phase of codegen, we iterate over tensors with hooks to: + # - Generate the tensor. + # - Issue a register_hook call on the tensor, linking to the globally stored function. + # - Incorporate a handle if one was established in the eager phase. + # - For tensors without sources: + # - We don't generate any instructions for registering a hook. + # - Handles from intermediary hooks are NYI. + # - We produce a call function that utilizes the trace_wrapped higher order op, closing over it. + # - We then manually insert the call function above into the graph. + # - The handle's exact user-specified name, "user_code_variable_name", is discerned and associated during STORE_FAST. + assert tensor.source, "Hooks on non input tensors NYI - should not get here" + + def gen_fn(): + cg(tensor) + cg.extend_output([cg.create_load_attr(name)]) + + cg.add_push_null(gen_fn) + cg(hook) + cg.extend_output(create_call_function(1, False)) + + # Adding the handle to the cache means RemovableHandleVariable().reconstruct() will + # be associated with the return value of register_hook(). This consumes the top of stack. 
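+            # At the source level the residual generated here behaves roughly
+            # like (illustrative names only):
+            #
+            #     h = t.register_hook(my_hook)
+            #
+            # where ``h`` is bound to the user's variable name only if the
+            # original code kept the handle; otherwise the value is popped.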
+ cg.add_cache(handle) + + def get_ca_final_callbacks_var(self): + from .variables.base import MutableLocal + + if self.ca_final_callbacks_var is None: + self.ca_final_callbacks_var = variables.ListVariable( + [], mutable_local=MutableLocal() + ) + return self.ca_final_callbacks_var + + def codegen_update_mutated(self, cg: PyCodegen): + suffixes = [] + for var in self._get_modified_vars(): + if isinstance(var, variables.ListVariable): + # old[:] = new + cg(var, allow_cache=False) + cg(var.mutable_local.source) # type: ignore[attr-defined] + cg.extend_output( + [ + cg.create_load_const(None), + cg.create_load_const(None), + create_instruction("BUILD_SLICE", arg=2), + ] + ) + suffixes.append([create_instruction("STORE_SUBSCR")]) + elif isinstance(var, variables.CustomizedDictVariable): + # need to update the dict manually since update method may be invalid + varname_map = {} + for name in _manual_update_dict.__code__.co_varnames: + varname_map[name] = cg.tx.output.new_var() + + cg(var.mutable_local.source) # type: ignore[attr-defined] + cg.extend_output( + [create_instruction("STORE_FAST", argval=varname_map["dict_to"])] + ) + + cg(var, allow_cache=False) + cg.extend_output( + [create_instruction("STORE_FAST", argval=varname_map["dict_from"])] + ) + + cg(var.mutable_local.source) # type: ignore[attr-defined] + cg.load_method("clear") + + # unfortunately can't just use DICT_MERGE due to possible custom behaviors + dict_update_insts = bytecode_from_template( + _manual_update_dict, varname_map=varname_map + ) + + suffixes.append( + [ + *create_call_method(0), # clear + create_instruction("POP_TOP"), + *dict_update_insts, + create_instruction("POP_TOP"), + ] + ) + + elif isinstance(var, variables.ConstDictVariable): + cg(var.mutable_local.source) # type: ignore[attr-defined] + cg.load_method("update") + cg(var, allow_cache=False) + + cg(var.mutable_local.source) # type: ignore[attr-defined] + cg.load_method("clear") + + suffixes.append( + [ + *create_call_method(0), # clear + create_instruction("POP_TOP"), + *create_call_method(1), # update + create_instruction("POP_TOP"), + ] + ) + elif isinstance( + var, variables.torch_function.TorchFunctionModeStackVariable + ): + cg.add_push_null( + lambda: cg.load_import_from( + utils.__name__, "set_torch_function_mode_stack" + ) + ) + cg.foreach(var.symbolic_stack) + cg.append_output( + create_instruction("BUILD_LIST", arg=len(var.symbolic_stack)) + ) + cg.call_function(1, False) + cg.append_output(create_instruction("POP_TOP")) + elif self.is_attribute_mutation(var): + # Applying mutations involves two steps: 1) Push all + # reconstructed objects onto the stack. 2) Call STORE_ATTR to + # apply the mutations. + # + # Dynamo must ensure that mutations are applied in the same + # order as in the original program. Therefore, two reverse + # operations occur below. + # + # The first reverse operation concerns `suffixes`. We apply + # suffixes in reverse order due to the way Python handles the + # stack. In Step 1, we push all reconstructed objects onto the + # stack, but the item at the top of the stack refers to the last + # attribute in the mutation order. If not fixed, this will apply + # the mutations of attributes in the reverse order. To account + # for this reversal, we iterate through the mutable attributes + # in reverse order. 
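+                # Worked example of the double reversal (illustrative names):
+                # if the traced code ran ``obj.x = a`` then ``obj.y = b``, the
+                # reversed() below visits y first, so x's operands end up on
+                # top of the stack; the final ``for suffix in
+                # reversed(suffixes)`` loop then emits STORE_ATTR x before
+                # STORE_ATTR y, reproducing the original mutation order.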
+ for name, value in reversed( + self.store_attr_mutations.get(var.mutable_local, {}).items() + ): + if isinstance(var, variables.NewGlobalVariable): + cg.tx.output.update_co_names(name) + cg(value) + assert isinstance(var.mutable_local.source, GlobalSource) # type: ignore[attr-defined] + suffixes.append( + [create_instruction("STORE_GLOBAL", argval=name)] + ) + elif isinstance(value, variables.DeletedVariable): + if isinstance( + var.mutable_local, AttributeMutationExisting + ) and hasattr(getattr(var, "value", None), name): + cg.tx.output.update_co_names(name) + cg(var.mutable_local.source) + suffixes.append( + [create_instruction("DELETE_ATTR", argval=name)] + ) + elif ( + isinstance(var, variables.UserDefinedObjectVariable) + and var.needs_slow_setattr() + ): + # __setattr__ is defined on this object, so call object.__setattr__ directly + cg.load_import_from("builtins", "object") + cg.load_method("__setattr__") + cg(var.mutable_local.source) # type: ignore[attr-defined] + cg(variables.ConstantVariable(name)) + cg(value) + suffixes.append( + [*create_call_method(3), create_instruction("POP_TOP")] + ) + else: + cg.tx.output.update_co_names(name) + cg(value) + cg(var.mutable_local.source) + suffixes.append([create_instruction("STORE_ATTR", argval=name)]) + elif isinstance(var, variables.TupleIteratorVariable): + for _ in range(var.index): + cg.add_push_null( + lambda: cg.load_import_from(utils.__name__, "iter_next") + ) + cg(var.mutable_local.source) # type: ignore[attr-defined] + cg.call_function(1, False) + cg.pop_top() + elif isinstance(var, variables.RandomVariable): + # set correct random seed state + def gen_fn(): + cg(var.mutable_local.source) # type: ignore[attr-defined] + cg.load_attr("setstate") + + cg.add_push_null(gen_fn) + cg(var.wrap_state(var.random.getstate())) + + suffixes.append( + [ + *create_call_function(1, False), # setstate + create_instruction("POP_TOP"), + ] + ) + else: + raise AssertionError(type(var)) + + # do all the actual mutations at the very end to handle dependencies + for suffix in reversed(suffixes): + cg.extend_output(suffix) + + def is_empty(self): + return not ( + any(map(self.is_modified, self.id_to_variable.values())) + or self.tensor_hooks + or self.save_for_backward + or self.tensor_hooks + ) + + def clear(self): + self.keepalive.clear() + self.id_to_variable.clear() diff --git a/lib/python3.10/site-packages/torch/_dynamo/source.py b/lib/python3.10/site-packages/torch/_dynamo/source.py new file mode 100644 index 0000000000000000000000000000000000000000..2d3a4424167da9a75d2f62bd51e2813de0baeadd --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/source.py @@ -0,0 +1,759 @@ +# mypy: allow-untyped-defs +import collections +import dataclasses +import enum +from typing import Any, Optional, Union + +from torch._guards import ChainedSource, GuardSource, Source + +from . 
import utils +from .bytecode_transformation import create_call_function, create_instruction +from .utils import enum_repr + + +# It shouldn't be supported to construct an NNModuleVariable inside an FSDP module, +# so those cases are omitted intentionally + +# represents nn.Modules tracked with NNModuleVariable (specialized is implicit in the variable name) +_GUARD_SOURCE_SPECIALIZED_NN_MODULE = { + GuardSource.LOCAL: GuardSource.LOCAL_SPECIALIZED_NN_MODULE, + GuardSource.GLOBAL: GuardSource.GLOBAL_SPECIALIZED_NN_MODULE, + GuardSource.LOCAL_SPECIALIZED_NN_MODULE: GuardSource.LOCAL_SPECIALIZED_NN_MODULE, + GuardSource.GLOBAL_SPECIALIZED_NN_MODULE: GuardSource.GLOBAL_SPECIALIZED_NN_MODULE, + # Just to ensure that guard_source() works + GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE, + GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE, + GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE, +} + +# represents nn.Modules tracked with UnspecializedNNModuleVariable +_GUARD_SOURCE_UNSPECIALIZED_NN_MODULE = { + GuardSource.LOCAL: GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE, + GuardSource.GLOBAL: GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE, + GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE, + GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE, + # this happens for an UnspecializedNNModule submodule on a NNModuleVariable + GuardSource.LOCAL_SPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE, + GuardSource.GLOBAL_SPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE, + # Just to ensure that guard_source() works + GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE, +} + +# represents nn.Modules tracked with UnspecializedBuiltinNNModuleVariable +_GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE = { + GuardSource.LOCAL: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.GLOBAL: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.LOCAL_SPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.GLOBAL_SPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + # Just to ensure that guard_source() works + GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE, +} + +_GUARD_SOURCE_FSDP_MODULE = { + GuardSource.LOCAL: GuardSource.LOCAL_FSDP_MODULE, + GuardSource.GLOBAL: GuardSource.GLOBAL_FSDP_MODULE, + GuardSource.LOCAL_SPECIALIZED_NN_MODULE: GuardSource.LOCAL_FSDP_MODULE, + GuardSource.GLOBAL_SPECIALIZED_NN_MODULE: GuardSource.GLOBAL_FSDP_MODULE, + GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE, + GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE, + GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE: GuardSource.LOCAL_FSDP_MODULE, + GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE: GuardSource.GLOBAL_FSDP_MODULE, + 
GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_FSDP_MODULE, + GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_FSDP_MODULE, +} + + +def is_constant_source(source): + if isinstance(source, ConstantSource): + return True + try: + if source.guard_source() == GuardSource.CONSTANT: + return True + except NotImplementedError: + pass + + return False + + +def reconstruct_getitem( + source: Union["GetItemSource", "ODictGetItemSource"], codegen, index_is_slice +): + source.base.reconstruct(codegen) + if isinstance(source.index, Source): + source.index.reconstruct(codegen) + else: + if index_is_slice: + assert isinstance(source, GetItemSource) + codegen.append_output(codegen.create_load_const(source.unpack_slice())) + else: + codegen.append_output(codegen.create_load_const(source.index)) + + +@dataclasses.dataclass(frozen=True) +class LocalSource(Source): + local_name: str + cell_or_freevar: bool = False + + def reconstruct(self, codegen): + codegen.append_output(codegen.create_load(self.local_name)) + + def guard_source(self): + return GuardSource.LOCAL + + def name(self): + return f"L[{repr(self.local_name)}]" + + +@dataclasses.dataclass(frozen=True) +class SyntheticLocalSource(Source): + local_name: str + + def reconstruct(self, codegen): + codegen.append_output(codegen.create_load(self.local_name)) + + def guard_source(self): + return GuardSource.SYNTHETIC_LOCAL + + def name(self): + return f"SYNTHETIC_LOCAL[{self.local_name!r}]" + + +@dataclasses.dataclass(frozen=True) +class RandomValueSource(Source): + random_call_index: int + + def guard_source(self): + return GuardSource.RANDOM_VALUE + + def reconstruct(self, codegen): + codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var)) + codegen.append_output(codegen.create_load_const(self.random_call_index)) + codegen.append_output(create_instruction("BINARY_SUBSCR")) + + def name(self): + return f"random_value_{self.random_call_index}" + + +@dataclasses.dataclass(frozen=True) +class GlobalSource(Source): + global_name: str + + def reconstruct(self, codegen): + codegen.append_output(codegen.create_load_global(self.global_name, add=True)) + + def guard_source(self): + return GuardSource.GLOBAL + + def name(self): + return f"G[{repr(self.global_name)}]" + + +@dataclasses.dataclass(frozen=True) +class GlobalWeakRefSource(Source): + global_name: str + + def reconstruct(self, codegen): + codegen.add_push_null( + lambda: codegen.append_output( + codegen.create_load_global(self.global_name, add=True) + ) + ) + codegen.extend_output(create_call_function(0, False)) + + def guard_source(self): + return GuardSource.GLOBAL + + def name(self): + return f"G[{repr(self.global_name)}]()" + + +@dataclasses.dataclass(frozen=True) +class WeakRefCallSource(ChainedSource): + def reconstruct(self, codegen): + codegen.add_push_null(lambda: self.base.reconstruct(codegen)) + codegen.extend_output(create_call_function(0, False)) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"{self.base.name()}()" + + +@dataclasses.dataclass(frozen=True) +class AttrSource(ChainedSource): + member: str + + def __post_init__(self): + assert self.base, "Can't construct an AttrSource without a valid base source" + if "." 
in self.member: + member_parts = self.member.split(".") + object.__setattr__( + self, "base", AttrSource(self.base, ".".join(member_parts[:-1])) + ) + object.__setattr__(self, "member", member_parts[-1]) + + def reconstruct(self, codegen): + self.base.reconstruct(codegen) + codegen.extend_output(codegen.create_load_attrs(self.member)) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + if not self.member.isidentifier(): + return f"getattr({self.base.name()}, {self.member!r})" + return f"{self.base.name()}.{self.member}" + + +# Represents tensor.grad source. It could be represented by AttrSource as well. +# But, we could access grad field on tensor directly in C++ without going +# through the Python bytecodes. Therefore, we use a separate source for grad +# field. +@dataclasses.dataclass(frozen=True) +class GradSource(ChainedSource): + member: str = "grad" + + def reconstruct(self, codegen): + self.base.reconstruct(codegen) + codegen.extend_output(codegen.create_load_attrs(self.member)) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"{self.base.name()}.{self.member}" + + +@dataclasses.dataclass(frozen=True) +class ParamBufferSource(AttrSource): + def guard_source(self): + return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()] + + +# Special AttrSource to differentiate module._buffers or module._parameters +@dataclasses.dataclass(frozen=True) +class UnspecializedParamBufferSource(AttrSource): + pass + + +# This source is intended to be used in places where a source is needed but it is expected +# that the symbol will be simplified out later on. Symbols with ephemeral sources are +# prioritized to be simplified out when e.g. compared against a symbol without an ephemeral +# source. Guarding on this source is an error. +# +# Example: During subclass view fake-ification, any close-over ViewFunc state should be +# symbolicized / fake-ified to avoid invalid specialization during view replay. This source +# is useful for symbols utilized in the middle of the view chain that are not expected to be +# present within the final view shape metadata. 
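+# Hedged example of the intended effect: if a symbol s0 carries an
+# EphemeralSource and is later found equal to a symbol s1 with a real,
+# guardable source, the simplification machinery prefers to substitute s0
+# away in favor of s1, so no guard ever has to name the ephemeral symbol
+# (make_guard below raises for exactly that reason).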
+@dataclasses.dataclass(frozen=True) +class EphemeralSource(Source): + desc: Optional[str] = None + + def guard_source(self): + return GuardSource.EPHEMERAL + + def name(self): + return f"" + + def make_guard(self): + raise NotImplementedError + + def is_ephemeral(self): + return True + + +class TensorProperty(enum.Enum): + SIZE = 0 + STRIDE = 1 + STORAGE_OFFSET = 2 + + def method_name(self): + if self is TensorProperty.SIZE: + return "size" + elif self is TensorProperty.STRIDE: + return "stride" + elif self is TensorProperty.STORAGE_OFFSET: + return "storage_offset" + + +@dataclasses.dataclass(frozen=True) +class TensorPropertySource(ChainedSource): + prop: TensorProperty + idx: Optional[int] = None # None for STORAGE_OFFSET + + def __post_init__(self): + assert self.base is not None + if self.prop is TensorProperty.STORAGE_OFFSET: + assert self.idx is None + else: + assert self.idx is not None + + def reconstruct(self, codegen): + def gen_fn(): + self.base.reconstruct(codegen) + codegen.append_output(codegen.create_load_attr(self.prop.method_name())) + + codegen.add_push_null(gen_fn) + if self.idx is not None: + codegen.append_output(codegen.create_load_const(self.idx)) + codegen.extend_output( + create_call_function(1 if self.idx is not None else 0, False) + ) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + if self.prop is TensorProperty.SIZE: + return f"{self.base.name()}.size()[{self.idx}]" + elif self.prop is TensorProperty.STRIDE: + return f"{self.base.name()}.stride()[{self.idx}]" + elif self.prop is TensorProperty.STORAGE_OFFSET: + assert self.idx is None + return f"{self.base.name()}.storage_offset()" + else: + raise AssertionError(f"unhandled {self.prop}") + + +@dataclasses.dataclass(frozen=True) +class NegateSource(ChainedSource): + def __post_init__(self): + assert self.base is not None + + def reconstruct(self, codegen): + raise NotImplementedError + + def guard_source(self): + return self.base.guard_source() + + def name(self): + # NB: use method call so that function stripping regexes work + return f"{self.base.name()}.__neg__()" + + +@dataclasses.dataclass(frozen=True) +class ConvertIntSource(ChainedSource): + def __post_init__(self): + assert self.base is not None + + def reconstruct(self, codegen): + self.base.reconstruct(codegen) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"cast_symbool_to_symint_guardless({self.base.name()})" + + +@dataclasses.dataclass(frozen=True) +class FlattenScriptObjectSource(ChainedSource): + def __post_init__(self): + assert self.base is not None + + def reconstruct(self, codegen): + self.base.reconstruct(codegen) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"{self.base.name()}.__obj_flatten__()" + + +@dataclasses.dataclass(frozen=True) +class ScriptObjectQualifiedNameSource(ChainedSource): + def __post_init__(self): + assert self.base is not None + + def reconstruct(self, codegen): + self.base.reconstruct(codegen) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"{self.base.name()}._type().qualified_name()" + + +class AttrProxySource(ChainedSource): + def reconstruct(self, codegen): + self.base.reconstruct(codegen) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"{self.base.name()}.get_base()" + + +@dataclasses.dataclass(frozen=True) +class DefaultsSource(ChainedSource): + idx_key: Union[int, str] + is_kw: bool = False + field: 
str = dataclasses.field(init=False, repr=False, compare=False) + _name: str = dataclasses.field(init=False, repr=False, compare=False) + + def __post_init__(self): + assert ( + self.base + ), "Base must be a valid source in order to properly track and guard this Defaults to its origin." + if self.is_kw: + assert isinstance(self.idx_key, str) + object.__setattr__(self, "field", "__kwdefaults__") + object.__setattr__( + self, "_name", f"{self.base.name()}.{self.field}['{self.idx_key}']" + ) + else: + assert isinstance(self.idx_key, int) + object.__setattr__(self, "field", "__defaults__") + object.__setattr__( + self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]" + ) + + def reconstruct(self, codegen): + self.base.reconstruct(codegen) + codegen.extend_output(codegen.create_load_attrs(self.field)) + codegen.append_output(codegen.create_load_const(self.idx_key)) + codegen.append_output(create_instruction("BINARY_SUBSCR")) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return self._name + + +@dataclasses.dataclass(frozen=True) +class GetItemSource(ChainedSource): + index: Any + index_is_slice: bool = False + + def __post_init__(self): + assert self.base is not None + if isinstance(self.index, slice): + # store the hashable version of the slice so the whole GetItemSource is hashable + super().__setattr__("index", self.index.__reduce__()) + super().__setattr__("index_is_slice", True) + + def reconstruct(self, codegen): + reconstruct_getitem(self, codegen, index_is_slice=self.index_is_slice) + codegen.append_output(create_instruction("BINARY_SUBSCR")) + + def guard_source(self): + return self.base.guard_source() + + def unpack_slice(self): + assert self.index_is_slice + slice_class, slice_args = self.index + return slice_class(*slice_args) + + def name(self): + # Index can be of following types + # 1) ConstDictKeySource + # 2) enum.Enum + # 3) index is a slice - example 1:4 + # 4) index is a constant - example string, integer + if isinstance(self.index, Source): + if not isinstance(self.index, ConstDictKeySource): + raise ValueError( + "GetItemSource index must be a constant, enum or ConstDictKeySource" + ) + return f"{self.base.name()}[{self.index.name()}]" + elif self.index_is_slice: + return f"{self.base.name()}[{self.unpack_slice()!r}]" + elif isinstance(self.index, enum.Enum): + return f"{self.base.name()}[{enum_repr(self.index, self.guard_source().is_local())}]" + else: + return f"{self.base.name()}[{self.index!r}]" + + +@dataclasses.dataclass(frozen=True) +class ConstDictKeySource(GetItemSource): + def is_dict_key(self): + return True + + def reconstruct(self, codegen): + codegen.add_push_null( + lambda: codegen.load_import_from(utils.__name__, "dict_keys_getitem") + ) + self.base.reconstruct(codegen) + codegen.append_output(codegen.create_load_const(self.index)) + codegen.extend_output(create_call_function(2, False)) + + def name(self): + # The list creation will be CSE'd by PyExprCSEPass + return f"list({self.base.name()}.keys())[{self.index!r}]" + + +@dataclasses.dataclass(frozen=True) +class TupleIteratorGetItemSource(GetItemSource): + def reconstruct(self, codegen): + codegen.add_push_null( + lambda: codegen.load_import_from(utils.__name__, "tuple_iterator_getitem") + ) + self.base.reconstruct(codegen) + codegen.append_output(codegen.create_load_const(self.index)) + codegen.extend_output(create_call_function(2, False)) + + def name(self): + return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})" + + 
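+# Illustrative note on how chained sources compose their guard names: for
+# hypothetical user code ``foo = lst[0].bar``,
+#   LocalSource("lst")                        -> "L['lst']"
+#   GetItemSource(LocalSource("lst"), 0)      -> "L['lst'][0]"
+#   AttrSource(GetItemSource(...), "bar")     -> "L['lst'][0].bar"
+# and these strings are the form in which guard expressions refer to the
+# underlying values.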
+@dataclasses.dataclass(frozen=True) +class TypeSource(ChainedSource): + def __post_init__(self): + assert self.base is not None + + def reconstruct(self, codegen): + codegen.add_push_null(lambda: codegen.load_import_from("builtins", "type")) + self.base.reconstruct(codegen) + codegen.extend_output(create_call_function(1, False)) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"type({self.base.name()})" + + +@dataclasses.dataclass(frozen=True) +class ODictGetItemSource(ChainedSource): + index: Any + + def __post_init__(self): + assert self.base is not None + + def reconstruct(self, codegen): + codegen.add_push_null( + lambda: codegen.append_output( + codegen._create_load_const(collections.OrderedDict.__getitem__) + ) + ) + reconstruct_getitem(self, codegen, index_is_slice=False) + codegen.extend_output(create_call_function(2, False)) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + if isinstance(self.index, type): + rep = f'__load_module("{self.index.__module__}").{self.index.__qualname__}' + return f"___odict_getitem({self.base.name()}, {rep})" + elif isinstance(self.index, Source): + return f"___odict_getitem({self.base.name()}, {self.index.name()})" + else: + return f"___odict_getitem({self.base.name()}, {self.index!r})" + + +@dataclasses.dataclass(frozen=True) +class OptimizerSource(ChainedSource): + def reconstruct(self, codegen): + self.base.reconstruct(codegen) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return self.base.name() + + +@dataclasses.dataclass(frozen=True) +class NNModuleSource(ChainedSource): + def reconstruct(self, codegen): + self.base.reconstruct(codegen) + + def guard_source(self): + return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()] + + def name(self): + return self.base.name() + + +@dataclasses.dataclass(frozen=True) +class UnspecializedNNModuleSource(NNModuleSource): + def guard_source(self): + return _GUARD_SOURCE_UNSPECIALIZED_NN_MODULE[self.base.guard_source()] + + +@dataclasses.dataclass(frozen=True) +class UnspecializedBuiltinNNModuleSource(UnspecializedNNModuleSource): + def guard_source(self): + return _GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE[self.base.guard_source()] + + +@dataclasses.dataclass(frozen=True) +class FSDPNNModuleSource(NNModuleSource): + def guard_source(self): + return _GUARD_SOURCE_FSDP_MODULE[self.base.guard_source()] + + +@dataclasses.dataclass(frozen=True) +class GlobalStateSource(Source): + def name(self): + return "" + + def guard_source(self): + return GuardSource.GLOBAL + + +@dataclasses.dataclass(frozen=True) +class TorchFunctionModeStackSource(Source): + ind: int + + def name(self): + return "" + + def _get_index(self): + from .variables.torch_function import TorchFunctionModeStackVariable + + return TorchFunctionModeStackVariable.get_mode_index(self.ind) + + def reconstruct(self, codegen): + codegen.add_push_null( + lambda: codegen.load_import_from( + utils.__name__, "get_torch_function_mode_stack_at" + ) + ) + codegen.extend_output([codegen.create_load_const(self._get_index())]) + codegen.extend_output(create_call_function(1, False)) + + def guard_source(self): + return GuardSource.GLOBAL + + +@dataclasses.dataclass(frozen=True) +class ConstantSource(Source): + source_name: str + + def reconstruct(self, codegen): + codegen.append_output(codegen.create_load_global(self.source_name, add=False)) + + def guard_source(self): + return GuardSource.CONSTANT + + def name(self): + return 
self.source_name + + def make_guard(self, fn): + raise NotImplementedError + + +@dataclasses.dataclass(frozen=True) +class NumpyTensorSource(ChainedSource): + def name(self) -> str: + return f"___from_numpy({self.base.name()})" + + def guard_source(self): + return self.base.guard_source() + + def reconstruct(self, codegen): + codegen.add_push_null(lambda: codegen.load_import_from("torch", "as_tensor")) + self.base.reconstruct(codegen) + codegen.extend_output(create_call_function(1, False)) + + +@dataclasses.dataclass(frozen=True) +class SubclassAttrListSource(ChainedSource): + def name(self) -> str: + return f"{self.base.name()}.__tensor_flatten__()[0]" + + def guard_source(self): + return self.base.guard_source() + + +# NB: We don't expect you to actually ever generate guards against this +# source, it is ephemeral +@dataclasses.dataclass(frozen=True) +class FloatTensorSource(ChainedSource): + def name(self) -> str: + return f"___as_tensor({self.base.name()})" + + def guard_source(self): + return self.base.guard_source() + + +@dataclasses.dataclass(frozen=True) +class CallMethodItemSource(ChainedSource): + def name(self) -> str: + return f"{self.base.name()}.item()" + + def guard_source(self): + return self.base.guard_source() + + +# This is a synthetic source that is associated with the singleton +# shape env guard we always register for all frames. We get the actual +# guard contents from the ambient ShapeEnv +@dataclasses.dataclass(frozen=True) +class ShapeEnvSource(Source): + def name(self): + return "" + + def guard_source(self): + return GuardSource.SHAPE_ENV + + +@dataclasses.dataclass(frozen=True) +class BackwardStateSource(Source): + def name(self): + return "" + + def guard_source(self): + return GuardSource.BACKWARD_STATE + + +def is_from_local_source(source: Source, *, allow_cell_or_freevar=True): + if isinstance(source, ChainedSource): + return is_from_local_source( + source.base, allow_cell_or_freevar=allow_cell_or_freevar + ) + if not isinstance(source, LocalSource): + return False + if not allow_cell_or_freevar and source.cell_or_freevar: + return False + return True + + +def is_from_unspecialized_param_buffer_source(source: Source): + if isinstance(source, UnspecializedParamBufferSource): + return True + if isinstance(source, ChainedSource): + return is_from_unspecialized_param_buffer_source(source.base) + return False + + +def is_from_flatten_script_object_source(source: Source): + if isinstance(source, FlattenScriptObjectSource): + return True + elif isinstance(source, ChainedSource): + return is_from_flatten_script_object_source(source.base) + return False + + +def is_from_optimizer_source(source: Source): + if isinstance(source, OptimizerSource): + return True + if isinstance(source, ChainedSource): + return is_from_optimizer_source(source.base) + return False + + +# TODO: can probably write a generic "test this on everything in the chain" +# helper +def is_from_defaults(source: Source): + if isinstance(source, DefaultsSource): + return True + if isinstance(source, ChainedSource): + return is_from_defaults(source.base) + return False + + +def is_cell_contents(source: Source): + return isinstance(source, AttrSource) and source.member == "cell_contents" diff --git a/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py b/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py new file mode 100644 index 0000000000000000000000000000000000000000..ab92f82aa0f630dd3ca2ac84bae4b8ad000d421f --- /dev/null +++ 
b/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py @@ -0,0 +1,3440 @@ +# mypy: allow-untyped-defs +import collections +import collections.abc +import contextlib +import copy +import dataclasses +import dis +import functools +import importlib +import inspect +import itertools +import linecache +import logging +import operator +import re +import sys +import threading +import traceback +import types +import typing +import weakref +from typing import ( + Any, + Callable, + cast, + Deque, + Dict, + List, + Optional, + Set, + Tuple, + Type, + TYPE_CHECKING, + Union, +) +from unittest.mock import patch + +import torch +import torch._logging +from torch._guards import tracing, TracingContext + +from . import config, exc, logging as torchdynamo_logging, trace_rules, variables +from .bytecode_analysis import ( + get_indexof, + JUMP_OPNAMES, + livevars_analysis, + propagate_line_nums, +) +from .bytecode_transformation import ( + cleaned_instructions, + create_call_function, + create_instruction, + create_jump_absolute, + create_swap, + get_code_keys, + Instruction, + is_generator, + unique_id, +) +from .code_context import code_context +from .codegen import PyCodegen +from .exc import ArgsMismatchError, BackendCompilerFailed, unimplemented, Unsupported +from .funcname_cache import get_funcname +from .guards import GuardBuilder, install_guard +from .output_graph import GraphCompileReason, OutputGraph +from .replay_record import DummyModule, ExecutionRecorder +from .resume_execution import ContinueExecutionCache, ReenterWith +from .source import ( + AttrSource, + GetItemSource, + GlobalSource, + GlobalWeakRefSource, + LocalSource, + Source, + TorchFunctionModeStackSource, +) +from .trace_rules import is_builtin_constant, is_forbidden +from .utils import ( + counters, + get_fake_value, + get_instruction_source_311, + get_torch_function_mode_stack, + graph_break_dup_warning_checker, + istype, + LazyString, + proxy_args_kwargs, +) +from .variables.base import is_side_effect_safe, MutableLocal, typestr, VariableTracker +from .variables.builder import VariableBuilder, wrap_fx_proxy +from .variables.builtin import BuiltinVariable +from .variables.constant import ConstantVariable +from .variables.ctx_manager import ( + ContextWrappingVariable, + GenericContextWrappingVariable, + WithExitFunctionVariable, +) +from .variables.dicts import ConstDictVariable, SetVariable +from .variables.functions import ( + BaseUserFunctionVariable, + NestedUserFunctionVariable, + SkipFunctionVariable, + UserFunctionVariable, + UserMethodVariable, +) +from .variables.iter import MAX_ITERATOR_LIMIT +from .variables.lists import ( + BaseListVariable, + ListIteratorVariable, + ListVariable, + SliceVariable, + TupleVariable, +) +from .variables.misc import ( + ClosureVariable, + GetAttrVariable, + InlinedClosureVariable, + NullVariable, + PythonModuleVariable, + UnknownVariable, +) +from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable +from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable + + +if TYPE_CHECKING: + from .variables.torch_function import TorchFunctionModeVariable + +from .variables.user_defined import ( + RemovableHandleVariable, + UserDefinedClassVariable, + UserDefinedObjectVariable, +) + + +log = logging.getLogger(__name__) +graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks") +trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call") +trace_source_log = torch._logging.getArtifactLogger(__name__, 
"trace_source") +trace_bytecode_log = torch._logging.getArtifactLogger(__name__, "trace_bytecode") +tls = threading.local() +compare_op_handlers: Dict[str, Any] = { + k: BuiltinVariable(v).call_function for k, v in supported_comparison_ops.items() +} +handle_contains = BuiltinVariable(operator.contains).call_function +handle_not = BuiltinVariable(operator.not_).call_function +compare_op_handlers["in"] = lambda tx, args, _: handle_contains( + tx, [*reversed(args)], {} +) +compare_op_handlers["not in"] = lambda tx, args, _: handle_not( + tx, [handle_contains(tx, [*reversed(args)], {})], {} +) + + +PT2_ISSUE_TRACKER_URL = "https://github.com/pytorch/pytorch/issues/new?&labels=oncall%3A+pt2&projects=&template=pt2-bug-report.yml" + + +@dataclasses.dataclass +class SpeculationEntry: + filename: str + lineno: int + instruction_pointer: int + inst: Instruction # for debugging only + failed: bool = False + reason: Optional[GraphCompileReason] = None + + def fail_and_restart_analysis(self): + """ + Start tracing of the current frame over again, and don't take this branch. + """ + self.failed = True + if self.reason is not None: + restart_reason = self.reason.reason + else: + restart_reason = "Unknown fail_and_restart_analysis" + raise exc.SpeculationRestartAnalysis(restart_reason=restart_reason) + + +@dataclasses.dataclass +class SpeculationLog: + """ + SpeculationLog replaces the prior copy_graphstate/restore_graphstate + checkpointing. Rather than saving/restoring state, we restart the + dynamo conversion process over from the beginning -- but when we + hit the start of the speculation that failed, we instead generate + a graph break. + """ + + entries: List[SpeculationEntry] = dataclasses.field(default_factory=list) + index: int = 0 + + def restart(self): + self.index = 0 + + def clear(self): + self.entries.clear() + self.index = 0 + + def next( + self, filename: str, lineno: int, instruction_pointer, inst + ) -> SpeculationEntry: + """ + Lookup or create a SpeculationEntry() that is shared across + RestartAnalysis calls. Args are used only for debug checks. + """ + if len(self.entries) == self.index: + self.entries.append( + SpeculationEntry(filename, lineno, instruction_pointer, inst) + ) + entry = self.entries[self.index] + prev_entry_msg = "" + if self.index != 0: + prev_entry = self.entries[self.index - 1] + prev_entry_msg = ( + f"Previous instruction: {prev_entry.filename}:{prev_entry.lineno}" + f"({prev_entry.inst.opname} @ {prev_entry.instruction_pointer})\n" + ) + assert ( + entry.instruction_pointer == instruction_pointer + and entry.filename == filename + and entry.lineno == lineno + ), f""" +SpeculationLog diverged at index {self.index} (log had {len(self.entries)} entries): +- Expected: {entry.filename}:{entry.lineno} ({entry.inst.opname} at ip={entry.instruction_pointer}) +- Actual: {filename}:{lineno} ({inst.opname} at ip={instruction_pointer}) +{prev_entry_msg} +There are two usual reasons why this may have occured: +- When Dynamo analysis restarted, the second run took a different path than + the first. If this occurred, the previous instruction is the critical instruction that + behaved differently. +- Speculation entries are only added under certain conditions (as seen in + step()), e.g., there must exist operators in the graph; those conditions may + have changed on restart. + +If this divergence was intentional, clear the speculation log before restarting (do NOT +do this for graph breaks, you will infinite loop). 
+ +Otherwise, please submit a bug report, ideally including the contents of TORCH_LOGS=+dynamo +""" + self.index += 1 + return entry + + +@dataclasses.dataclass +class LocalState: + input_sizes: Dict[str, List[int]] = dataclasses.field(default_factory=dict) + input_strides: Dict[str, List[int]] = dataclasses.field(default_factory=dict) + + +# Mutable box that is shared across restarts +@dataclasses.dataclass +class DistributedState: + compile_pg: Any + local_state: LocalState + all_states: Optional[List[LocalState]] = None + + +@functools.lru_cache(None) +def _step_logger(): + return torchdynamo_logging.get_step_logger(log) + + +@dataclasses.dataclass +class BlockStackEntry: + # Current instruction that pushes something to block_stack + inst: Instruction + target: Instruction + stack_index: Optional[int] = None + with_context: Optional[ + Union[ContextWrappingVariable, GenericContextWrappingVariable] + ] = None + + def can_restore(self): + return self.with_context is not None + + def resume_fn(self): + assert self.stack_index is not None + if ( + self.with_context + and hasattr(self.with_context, "target_values") + and self.with_context.target_values + ): + return ReenterWith(self.stack_index, tuple(self.with_context.target_values)) + else: + return ReenterWith(self.stack_index) + + def exit(self, tx): + assert self.with_context is not None + return self.with_context.exit(tx) + + +class ReturnValueOp(Exception): + pass + + +def stack_op(fn: typing.Callable[..., object]): + nargs = len(inspect.signature(fn).parameters) + fn_var = BuiltinVariable(fn) + + @functools.wraps(fn) + def impl(self: "InstructionTranslator", inst: Instruction): + self.push(fn_var.call_function(self, self.popn(nargs), {})) + + return impl + + +def _detect_and_normalize_assert_statement( + self: "InstructionTranslatorBase", + truth_fn: typing.Callable[[object], bool], + push: bool, +): + # Detect if this jump instruction is assert and normalize the assert + # by pushing dummy error message when nothing is given. 
+ # + # Python 3.9 assertion is in following format: + # 18 POP_JUMP_IF_TRUE 28 + # 20 LOAD_ASSERTION_ERROR + # 22 LOAD_CONST 3 ('Assert message') -> optional instruction + # 24 CALL_FUNCTION 1 -> optional instruction + # 26 RAISE_VARARGS + # + # Python 3.8 assertion is in following format: + # 18 POP_JUMP_IF_TRUE 28 + # 20 LOAD_GLOBAL 0 (Assertion type) + # 22 LOAD_CONST 3 ('Assert message') -> optional instruction + # 24 CALL_FUNCTION 1 -> optional instruction + # 26 RAISE_VARARGS 1 + + if (truth_fn is not operator.truth) or push: + return False + + assert isinstance(self.instruction_pointer, int) + current_instruction_pointer = self.instruction_pointer + inst = self.instructions[current_instruction_pointer] + # Detect LOAD_ASSERTION_ERROR or LOAD_GLOBAL 0 + if sys.version_info < (3, 9): + if inst.opname != "LOAD_GLOBAL" or inst.argval != "AssertionError": + return False + else: + if inst.opname != "LOAD_ASSERTION_ERROR": + return False + + current_instruction_pointer += 1 + + # Use dummy error message if its hard to extract + error_msg = "assertion error" + + inst = self.instructions[current_instruction_pointer] + # DETECT RAISE_VARARGS or LOAD CONST + if inst.opname == "LOAD_CONST": + if not isinstance(inst.argval, str): + return False + error_msg = inst.argval + + # if it is LOAD_CONSTANT, it must be followed by CALL_FUNCTION + # (PRECALL for Python 3.11, CALL for Python 3.12+) + current_instruction_pointer += 1 + inst = self.instructions[current_instruction_pointer] + if inst.opname not in ("CALL_FUNCTION", "PRECALL", "CALL"): + return False + + # for Python 3.11, PRECALL should be followed by CALL, then RAISE_VARARGS + # for Python != 3.11, CALL_FUNCTION/CALL should be followed by RAISE_VARARGS + current_instruction_pointer += 1 + if inst.opname == "PRECALL": + current_instruction_pointer += 1 + inst = self.instructions[current_instruction_pointer] + + if inst.opname != "RAISE_VARARGS": + return False + + self.push(ConstantVariable.create(error_msg)) + + return True + + +def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool): + def jump_graph_break(self, inst, value, extra_msg=""): + if not self.should_compile_partial_graph(): + unimplemented("should_compile_partial_graph=False") + # compile a partial subgraph prefix then jump into user code + if self.maybe_has_backedge(): + msg = ( + "Skipping frame because there is a graph break in a for/while loop\n" + f"{self.frame_summary()}" + ) + log.info(msg) + raise exc.SkipFrame(msg) + + self.push(value) + log.debug("generic_jump triggered compile") + self.output.compile_subgraph( + self, + reason=GraphCompileReason( + f"generic_jump {typestr(value)}{extra_msg}", [self.frame_summary()] + ), + ) + self.pop() + + if_next = self.create_call_resume_at(self.next_instruction) + if push: + self.push(value) + if_jump = self.create_call_resume_at(inst.target) + + if sys.version_info >= (3, 13): + # 3.13 requires stack[-1] to be bool type + self.output.add_output_instructions([create_instruction("TO_BOOL")]) + + self.output.add_output_instructions( + [create_instruction(inst.opname, target=if_jump[0])] + if_next + if_jump + ) + + def inner(self: "InstructionTranslatorBase", inst: Instruction): + value: VariableTracker = self.pop() + if ( + config.rewrite_assert_with_torch_assert + and _detect_and_normalize_assert_statement(self, truth_fn, push) + ): + error_msg: VariableTracker = self.pop() + # Skip over things like `assert True` + if value.is_python_constant(): + if bool(value.as_python_constant()): + return self.jump(inst) + else: 
+ jump_graph_break(self, inst, value) + + # TODO maybe should respect DtoH sync intention of users later?? + # Manually insert torch._assert_async instead of python assert and jump over + # assert related instructions as we don't need them anymore. + + # if we see Tensor as assert statement, no need to call scalar_tensor + if isinstance(value, TensorVariable): + self.output.create_proxy( + "call_function", + torch._assert_async, + *proxy_args_kwargs((value, error_msg), {}), + ) + self.jump(inst) + return + + if isinstance(value, SymNodeVariable): + # if the assertion is normal shape expression. + # just install guard and bail out. + sym_expr = value.sym_num + if not isinstance(sym_expr, torch.SymBool): + sym_expr = sym_expr != 0 + + result = torch.fx.experimental.symbolic_shapes.expect_true(sym_expr) + if not result: + unimplemented( + "Assertion failed on symbolic shapes. Did you make sure eager mode succeeds?" + ) + self.jump(inst) + return + + scalar_to_tensor_proxy = self.output.create_proxy( + "call_function", torch.scalar_tensor, *proxy_args_kwargs((value,), {}) + ) + + scalar_to_tensor = wrap_fx_proxy( + self, + scalar_to_tensor_proxy, + example_value=get_fake_value(scalar_to_tensor_proxy.node, self), + ) + + self.output.create_proxy( + "call_function", + torch._assert_async, + *proxy_args_kwargs((scalar_to_tensor, error_msg), {}), + ) + self.jump(inst) + return + + if value.is_python_constant(): + if truth_fn(value.as_python_constant()): + if push: + self.push(value) + self.jump(inst) + elif ( + isinstance(value, (TensorVariable)) and self.should_compile_partial_graph() + ): + jump_graph_break(self, inst, value) + elif isinstance(value, NNModuleVariable): + # Equivalent of "self.nn_module is not None" + mod = self.output.get_submodule(value.module_key) + if truth_fn(mod): + if push: + self.push(value) + self.jump(inst) + elif isinstance(value, UnspecializedNNModuleVariable): + mod = value.value + if truth_fn(mod): + if push: + self.push(value) + self.jump(inst) + elif isinstance(value, UserDefinedObjectVariable): + try: + x = value.var_getattr(self, "__bool__") # type: ignore[arg-type] + except exc.ObservedAttributeError: + exc.handle_observed_exception(self) + # if __bool__ is missing, trying __len__ to infer a truth value. 
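+                # (Mirrors CPython truth testing: a hypothetical object that
+                # defines only __len__ is truthy exactly when len(obj) != 0,
+                # so __len__ is an acceptable stand-in for __bool__ here.)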
+ try: + x = value.var_getattr(self, "__len__") # type: ignore[arg-type] + except exc.ObservedAttributeError: + exc.handle_observed_exception(self) + x = None + + # __bool__ or __len__ is function + if isinstance(x, UserMethodVariable): + result = x.call_function(self, [], {}) # type: ignore[arg-type] + if isinstance(result, ConstantVariable) and isinstance( + result.value, (bool, int) + ): + if truth_fn(result.value): + if push: + self.push(value) + self.jump(inst) + else: + unimplemented( + "generic_jump on UserDefined with __bool__ returning non-constant" + ) + # __bool__ or __len__ is non-function or not existed in the user defined object + else: + if truth_fn(True): + if push: + self.push(value) + self.jump(inst) + elif not isinstance(value, TensorVariable) and value.has_unpack_var_sequence( + self + ): + if truth_fn(len(value.unpack_var_sequence(self))): + if push: + self.push(value) + self.jump(inst) + elif isinstance(value, SymNodeVariable): + try: + eval_result = value.evaluate_expr(self.output) + except exc.UserError as e: + if self.should_compile_partial_graph(): + return jump_graph_break(self, inst, value, extra_msg=f"\n{e}") + raise + if truth_fn(eval_result): + if push: + self.push(value) + self.jump(inst) + elif isinstance(value, variables.BackwardHookVariable): + if truth_fn(True): + if push: + self.push(value) + self.jump(inst) + else: + from .source import is_constant_source + + if value.source is not None and is_constant_source(value.source): + if truth_fn(value.get_real_value()): # type: ignore[attr-defined] + if push: + self.push(value) + self.jump(inst) + else: + # TODO link the torch.cond doc later + raise exc.UserError( + exc.UserErrorType.DYNAMIC_CONTROL_FLOW, + "Dynamic control flow is not supported at the moment. Please use " + "functorch.experimental.control_flow.cond to explicitly capture the control flow.", + case_name="cond_operands", + ) + + return inner + + +explain = False + + +def break_graph_if_unsupported(*, push): + def decorator(inner_fn): + @functools.wraps(inner_fn) + def wrapper(self: "InstructionTranslatorBase", inst: Instruction): + speculation = self.speculate() + if speculation.failed: + assert speculation.reason is not None + return handle_graph_break(self, inst, speculation.reason) + try: + return inner_fn(self, inst) + except Unsupported as excp: + if self.generic_context_manager_depth > 0: + # We don't support graph break under GenericContextWrappingVariable, + # If there is, we roll back to the checkpoint and fall back. 
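+ # (a user-defined context manager may run arbitrary Python in __exit__ that
+ # cannot be replayed from a resume function, so bail out here)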
+ excp.remove_from_stats() + unimplemented("Graph break under GenericContextWrappingVariable") + + if isinstance(excp, exc.UncapturedHigherOrderOpError): + raise + + if not self.should_compile_partial_graph(): + raise + + user_stack = excp.real_stack + # TODO: Also report the traceback from the parent frame + try: + frame_loc = (user_stack[-1].filename, user_stack[-1].lineno) + except IndexError: + # first instruction + code_options = self.code_options + frame_loc = ( + code_options["co_filename"], + code_options["co_firstlineno"], + ) + # torch._dynamo.explain() formats this a little nicer, and presents a slightly + # more actionable user code pointer + if ( + graph_break_log.isEnabledFor(logging.DEBUG) + and not explain + and graph_break_dup_warning_checker.add(frame_loc) + ): + user_stack_formatted = "".join(traceback.format_list(user_stack)) + # This log line is exercised from + # python test/dynamo/test_exc.py -k test_graph_break_log + graph_break_log.debug( + "Graph break: from user code at:\n%s", + user_stack_formatted, + exc_info=True, + ) + else: + # This log line MUST NOT contain the string "Graph break", + # exercised by + # python test/dynamo/test_misc.py -k test_duplicate_graph_break_log + log.debug( + "Unsupported break in user code at %s:%s (details suppressed)", + *frame_loc, + ) + + if self.maybe_has_backedge(): + msg = ( + "Skipping frame because there is a graph break in a for/while loop\n" + f"{self.frame_summary()}" + ) + log.info(msg) + raise exc.SkipFrame(msg) from excp + + excp.remove_from_stats() + excp.add_to_stats("graph_break") + speculation.reason = GraphCompileReason(excp.msg, user_stack) + speculation.fail_and_restart_analysis() + + def handle_graph_break( + self: "InstructionTranslatorBase", + inst: Instruction, + reason: GraphCompileReason, + ): + self.output.compile_subgraph(self, reason=reason) + cg = PyCodegen(self) + cleanup: List[Instruction] = [] + # Reconstruct the context variable CLASS in the block stack + for b in self.block_stack: + assert b.with_context is not None + assert isinstance(b.with_context, ContextWrappingVariable) + b.with_context.reconstruct_type(cg) + cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup)) + self.output.add_output_instructions(cg.get_instructions()) + del cg + + if sys.version_info >= (3, 11) and inst.opname == "CALL": + kw_names = ( + self.kw_names.as_python_constant() + if self.kw_names is not None + else () + ) + if len(kw_names) > 0: + # KW_NAMES no longer used in 3.13 + assert sys.version_info < (3, 13) + self.output.add_output_instructions( + [create_instruction("KW_NAMES", argval=kw_names)] + ) + self.output.add_output_instructions( + create_call_function(inst.arg, False) + ) + else: + # copy instruction, but without exception table data + assert inst.target is None + inst_copy = copy.copy(inst) + inst_copy.exn_tab_entry = None + self.output.add_output_instructions([inst_copy]) + + self.output.add_output_instructions(cleanup) + + if ( + sys.version_info >= (3, 11) + and sys.version_info < (3, 12) + and inst.opname == "CALL" + ): + # stack effect for PRECALL + CALL is split between the two instructions + stack_effect = dis.stack_effect( + dis.opmap["PRECALL"], inst.arg + ) + dis.stack_effect(dis.opmap["CALL"], inst.arg) + else: + stack_effect = dis.stack_effect(inst.opcode, inst.arg) + self.popn(push - stack_effect) + + for _ in range(push): + self.push(UnknownVariable()) + self.output.add_output_instructions( + self.create_call_resume_at(self.next_instruction) + ) + + return wrapper + + return 
decorator + + +class BytecodeDistpatchTableMeta(type): + """Installs a `cls.dispatch_table` on every subclass to speed up calls to self.OPCODE()""" + + def __init__(cls, name, bases, dct) -> None: + super().__init__(name, bases, dct) + + def _missing(opname, *args): + unimplemented(f"missing: {opname}") + + dispatch_table = { + op: getattr(cls, opname, functools.partial(_missing, opname)) + for opname, op in dis.opmap.items() + } + cls.dispatch_table = [dispatch_table.get(i) for i in range(2**8)] + + +class InstructionTranslatorBase( + metaclass=BytecodeDistpatchTableMeta, +): + output: OutputGraph + symbolic_locals: Dict[str, VariableTracker] + symbolic_globals: Dict[str, VariableTracker] + symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"] + stack: List[VariableTracker] + instruction_pointer: Optional[int] + current_instruction: Instruction + block_stack: List[BlockStackEntry] + lineno: int + kw_names: Optional[ConstantVariable] + accept_prefix_inst: bool + prefix_insts: List[Instruction] + inline_depth: int + inconsistent_side_effects: bool + current_speculation: Optional[SpeculationEntry] + dispatch_table: List[Any] + exn_vt_stack: List[VariableTracker] + exec_recorder: Optional[ExecutionRecorder] + strict_checks_fn: Optional[Callable[[VariableTracker], bool]] + + def mark_inconsistent_side_effects(self): + """ + InstructionTranslator has encountered instructions which may cause + dynamo to see a different version of history from eager + See: https://github.com/pytorch/pytorch/issues/110765 + """ + self.inconsistent_side_effects = True + + def maybe_has_backedge(self): + # This function employs a heuristic. It does not reliably detect a backedge. + # The heuristic is straightforward: starting from the current instruction and + # continuing to the end, if any jump instruction targets an instruction before + # the current one, there might be a backedge. + + # Python 3.12 introduced changes to bytecode that group common paths in + # blockstacks (with or try...else) and allow for early returns. Consequently, + # there can be multiple RETURN_VALUE instructions. Another heuristic is to + # halt detection upon encountering the first RETURN_VALUE or RETURN_CONST. + + # These heuristics can result in both false positives and negatives, but + # in either case, the Dynamo code remains valid. For false positives + # (where an edge is incorrectly marked as a backedge), Dynamo will + # perform a SkipFrame instead of potentially applying optimizations. For + # false negatives (where an edge that should be marked as a backedge + # isn't), multiple graphs may be generated if there's a break in the + # graph during a for loop. In general, its better to have fewer false + # negatives so that Dynamo does not skip the whole frame. + + cur_offset = self.current_instruction.offset + assert self.instruction_pointer is not None + for inst in self.instructions[self.instruction_pointer :]: + if inst.opname in ("RETURN_VALUE", "RETURN_CONST"): + return False + if inst.opname in JUMP_OPNAMES: + jump_offset = inst.argval + if jump_offset < cur_offset: + return True + return False + + def cell_and_freevars(self): + if not hasattr(self, "_cell_and_freevars"): + self._cell_and_freevars = tuple( + self.code_options["co_cellvars"] or [] + ) + tuple(self.code_options["co_freevars"] or []) + + # An inlined function might depend on the freevar of the parent + # function. So, recursively obtain parent cell and freevars. 
+ if isinstance(self, InliningInstructionTranslator): + self._cell_and_freevars += self.parent.cell_and_freevars() + return self._cell_and_freevars + + def prune_dead_locals(self): + reads = livevars_analysis(self.instructions, self.current_instruction) + # implicit use by super() + # reads = reads | {"__class__"} + # output variables? + reads = reads | set(self.cell_and_freevars()) + self.symbolic_locals = { + k: v for k, v in self.symbolic_locals.items() if k in reads + } + self.output.side_effects.prune_dead_object_new(self) + + def call_function( + self, + fn: VariableTracker, + args: List[VariableTracker], + kwargs: Dict[str, VariableTracker], + ): + assert isinstance(fn, VariableTracker) + assert isinstance(args, list) + assert isinstance(kwargs, dict) + assert all( + isinstance(x, VariableTracker) + for x in itertools.chain(args, kwargs.values()) + ) + inner_fn = None + if hasattr(fn, "value"): + inner_fn = fn.value + if hasattr(fn, "fn"): + inner_fn = fn.fn + if inner_fn and callable(inner_fn) and is_forbidden(inner_fn): + raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}") + self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type] + + def inline_user_function_return(self, fn, args, kwargs): + """ + A call to some user defined function by inlining it. + """ + return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) + + def get_line_of_code_header(self, lineno=None): + if lineno is None: + lineno = self.lineno + inline_depth_str = ( + f" (inline depth: {self.inline_depth})" if self.inline_depth > 0 else "" + ) + funcname = get_funcname(self.f_code.co_filename, lineno) + funcname_str = "" if funcname is None else f" ({funcname})" + return f"{self.f_code.co_filename}:{lineno} in {self.f_code.co_name}{funcname_str}{inline_depth_str}" + + def get_log_starts_line_log_str(self): + log_str = f"TRACE starts_line {self.get_line_of_code_header()}\n" + line = linecache.getline(self.f_code.co_filename, self.lineno).rstrip() + log_str += f" {line}" + return log_str + + def starts_line(self, lineno): + if self.lineno == lineno: + return + self.lineno = lineno + TracingContext.set_current_loc( + self.f_code.co_filename, lineno, self.f_code.co_name + ) + from torch._logging.structured import dump_file + + dump_file(self.f_code.co_filename) + if trace_source_log.isEnabledFor(logging.DEBUG): + trace_source_log.debug("%s", LazyString(self.get_log_starts_line_log_str)) + + def step(self): + """Process exactly one instruction, return False we should exit""" + ip = self.instruction_pointer + if ip is None: + return False + self.current_instruction = inst = self.instructions[ip] + self.instruction_pointer = ip + 1 + + if inst.starts_line: + self.starts_line(inst.starts_line) + + if ( + not self.stack + and self.should_compile_partial_graph() + and self.is_non_empty_graph() + ): + self.current_speculation = self.speculate() + if self.current_speculation.failed: + return self.step_graph_break(inst) + + if trace_bytecode_log.isEnabledFor(logging.DEBUG): + trace_bytecode_log.debug( + "TRACE %s %s %s", inst.opname, inst.argval, self.stack + ) + + self.update_block_stack(inst) + + try: + self.dispatch_table[inst.opcode](self, inst) + return not self.output.should_exit + except exc.ObservedException as e: + self.exception_handler(e) + return True + except ReturnValueOp: + return False + except Unsupported: + if self.current_speculation is None: + log.debug("empty checkpoint") + raise + log.debug("step triggered compile", exc_info=True) + + 
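+ # roll back to the last speculation checkpoint and restart the analysis;
+ # on the retry, the failing instruction is handled as a graph break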
self.current_speculation.fail_and_restart_analysis() + + if sys.version_info >= (3, 11): + + def update_block_stack(self, inst): + # 3.11+ no longer uses a block stack, but we still keep track of one + # so that we know which contexts are currently active. + # For our purposes, all exception table entries with the same target + # are considered to be part of the same "block". + # NOTE: we only keep track of with blocks that are not contained in try blocks. + # This is because we will not create continuation functions on graph breaks in try blocks, + # but we may for with blocks. We do not push blocks here since + # with blocks are pushed when handling BEFORE_WITH. + entry = inst.exn_tab_entry + if entry: + # Detect when we have exited the top with block. + # The with blocks on the block stack are not enclosed in try + # blocks, so a with block's cleanup code should be in the + # previous with block (if any). + if ( + len(self.block_stack) >= 2 + and entry.target is not self.block_stack[-1].target + and entry.target is self.block_stack[-2].target + ): + # exit the current block + self.block_stack.pop() + else: + # no longer in any block + # It is possible for NOPs to be between two instructions + # in the same block, but the NOPs are not covered by an + # exception table entry. In this case, assume that we + # are still in the same block. + # In 3.12+, JUMP_BACKWARD might also not be covered by + # an exception table entry, so we also assume that we + # are still in the same block. It is probably safe to do + # this in 3.11, even though we haven't encountered this case before. + if self.block_stack and inst.opname not in ("NOP", "JUMP_BACKWARD"): + # If we really escape from a block and the current + # instruction is not in another block, then there + # should be no other nested blocks that we are in. + assert len(self.block_stack) == 1 + self.block_stack.pop() + + else: + + def update_block_stack(self, inst): + pass + + @property + def next_instruction(self): + return self.instructions[self.instruction_pointer] # type: ignore[index] + + def step_graph_break(self, continue_inst): + # generate code from checkpoint + assert not self.output.output_instructions + assert self.current_speculation is not None + self.output.compile_subgraph( + self, + partial_convert=True, + reason=GraphCompileReason("step_unsupported", [self.frame_summary()]), + ) + self.output.add_output_instructions( + [create_jump_absolute(continue_inst)] + self.instructions + ) + + def run_ctx_mgr(self): + # NB: Don't push the top level frame summary; set_current_loc will + # take care of it. However, DO make sure we attach real_stack to + # exceptions + return TracingContext.current_frame(None) + + def run(self): + with self.run_ctx_mgr(): + try: + self.output.push_tx(self) + while self.step(): + pass + except BackendCompilerFailed: + raise + except Exception as e: + if self.exec_recorder: + e.exec_record = self.exec_recorder.get_record() # type: ignore[attr-defined] + raise + finally: + self.output.pop_tx() + # Cleanup the outputGraph to delete the held tensors. We perform the + # cleanup only for InstructionTranslator and not + # InliningInstructionTranslator. The InliningInstructionTranslator + # mutates the output object and is restored to original state if + # there was an exception. 
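+ # (inlined translators share the root InstructionTranslator's OutputGraph,
+ # so only the root performs this cleanup)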
+ if isinstance(self, InstructionTranslator): + self.output.cleanup() + + def push(self, val: Optional[VariableTracker], name: Any = None): + assert val is None or isinstance( + val, VariableTracker + ), f"push expects VariableTracker, got {typestr(val)}" + self.stack.append(val) # type: ignore[arg-type] + if sys.version_info >= (3, 13): + self.name_stack.append(name) + assert len(self.stack) == len(self.name_stack) + + def push_many(self, vals: List[VariableTracker]): + for val in vals: + self.push(val) + + def pop(self) -> VariableTracker: + if sys.version_info >= (3, 13): + assert len(self.stack) == len(self.name_stack) + self.name_stack.pop() + return self.stack.pop() + + def popn(self, n: int) -> List[VariableTracker]: + return [*reversed([self.pop() for _ in range(n)])] + + def _load_closure(self, name): + return ClosureVariable(name=name) + + def _load_fast(self, name): + if self.exec_recorder and name in self.f_locals: + self.exec_recorder.add_local_var(name, self.f_locals[name]) + + try: + self.push(self.symbolic_locals[name].unwrap(), name=name) + except KeyError: + if sys.version_info >= (3, 13) and name in self.cell_and_freevars(): + # 3.13 merged LOAD_CLOSURE into LOAD_FAST + # If we fail to LOAD_FAST, then we probably should have done LOAD_CLOSURE. + # Closure variable creation is actually done in SET_FUNCTION_ATTRIBUTE, + # but we'll do it again here so that we don't need to push a dummy variable. + # We shouldn't actually be doing anything with this variable anyway. + self.push(self._load_closure(name), name=name) + elif name.startswith("."): + try: + # This happens in dict/list comprehensions + new_name = name.replace(".", "implicit") + self.push(self.symbolic_locals[new_name], name=new_name) + except KeyError: + unimplemented("undefined LOAD_FAST (implicit)") + else: + unimplemented("undefined LOAD_FAST") + + # for continuation functions + if name.startswith("___stack"): + self.symbolic_locals.pop(name) + + def LOAD_FAST(self, inst): + self._load_fast(inst.argval) + + def LOAD_DEREF(self, inst): + assert inst.argval in self.cell_and_freevars() + + if self.exec_recorder and inst.argval in self.f_locals: + self.exec_recorder.add_local_var(inst.argval, self.f_locals[inst.argval]) + + if inst.argval not in self.symbolic_locals: + unimplemented(f"undefined LOAD_DEREF {inst.argval}") + self.push(self.symbolic_locals[inst.argval]) + + def _store_fast(self, name): + loaded_vt = self.pop() + loaded_vt.set_name_hint(name) + self.symbolic_locals[name] = loaded_vt + + def STORE_FAST(self, inst): + self._store_fast(inst.argval) + + def DELETE_FAST(self, inst): + del self.symbolic_locals[inst.argval] + + STORE_DEREF = STORE_FAST + + def LOAD_CLOSURE(self, inst): + self.push(self._load_closure(inst.argval)) + + def _load_const(self, inst): + i = inst.arg + if i is None: + return ConstantVariable.create(value=inst.argval) + val = self._constants_cache[i] + if not val: + self._constants_cache[i] = val = ConstantVariable.create(value=inst.argval) + return val + + def LOAD_CONST(self, inst): + self.push(self._load_const(inst)) + + def _load_global(self, inst): + name = inst.argval + + if self.exec_recorder: + if name in self.f_globals: + self.exec_recorder.add_global_var(name, self.f_globals[name]) + else: + assert name in self.f_builtins + self.exec_recorder.builtins[name] = self.f_builtins[name] + + if name in self.symbolic_globals: + variable = self.output.side_effects[self.symbolic_globals[name]] + self.push(self.output.side_effects.load_global(variable, name)) + return + + try: + value 
= self.f_globals[name] + except KeyError: + return self.load_builtin(inst) + + source = GlobalSource(name) + self.push(VariableBuilder(self, source)(value)) + + @functools.cached_property + def nn_modules_globals_vt(self): + module_name = "torch.nn.modules.module" + module_source = self.import_source(module_name) + fglobals_value = importlib.import_module(module_name) # type: ignore[assignment] + return VariableBuilder(self, module_source)(fglobals_value) + + def LOAD_GLOBAL(self, inst): + if sys.version_info >= (3, 11) and sys.version_info < (3, 13) and inst.arg % 2: + self.PUSH_NULL(inst) + self._load_global(inst) + if sys.version_info >= (3, 13) and inst.arg % 2: + self.PUSH_NULL(inst) + + def STORE_GLOBAL(self, inst): + value = self.pop() + name = inst.argval + source = GlobalSource(name) + if name not in self.symbolic_globals: + self.symbolic_globals[name] = object() # type: ignore[assignment] # sentinel object + variable = self.output.side_effects.track_global_existing( + source, self.symbolic_globals[name] + ) + if isinstance(value, RemovableHandleVariable): + unimplemented("Storing handles in globals - NYI") + self.output.side_effects.store_global(variable, name, value) + + def import_source(self, module_name): + """Create an alias to a module for use in guards""" + if "torch_package" in module_name: + value = torch.package.package_importer._package_imported_modules[ + module_name + ] + alias = ( + module_name.replace(">", "_").replace("<", "_").replace(".", "_dot_") + ) + else: + value = importlib.import_module(module_name) + alias = f"__import_{module_name.replace('.', '_dot_')}" + f_globals = self.output.global_scope + assert alias not in f_globals or f_globals[alias] is value + f_globals[alias] = value + self.output.update_co_names(alias) + return GlobalSource(alias) + + def resolve_name(self, name, package, level): + """ + Copied from the Cpython implementation of __import__ + Resolve a relative module name to an absolute one. + https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L902 + """ + bits = package.rsplit(".", level - 1) + if len(bits) < level: + raise ImportError("attempted relative import beyond top-level package") + base = bits[0] + return f"{base}.{name}" if name else base + + def calc_package(self): + """ + Copied from the Cpython implementation of __import__ + https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L1090 + """ + package = self.f_globals.get("__package__") + spec = self.f_globals.get("__spec__") + if package is not None: + if spec is not None and package != spec.parent: + log.warning( + "__package__ != __spec__.parent (%r != %r)", + package, + spec.parent, + stacklevel=3, + ) + return package + elif spec is not None: + return spec.parent + else: + log.warning( + "can't resolve package from __spec__ or __package__, " + "falling back on __name__ and __path__", + stacklevel=3, + ) + package = self.f_globals["__name__"] + if "__path__" not in self.f_globals: + package = package.rpartition(".")[0] + return package + + def IMPORT_NAME(self, inst): + level, fromlist = self.popn(2) + level = level.as_python_constant() + fromlist = fromlist.as_python_constant() + module_name = inst.argval + + # Are we replaying? 
if so, load recorded module + recorded_name = ( + f"{ExecutionRecorder.LOCAL_MOD_PREFIX}_{level}_{fromlist}_{module_name}" + ) + if recorded_name in self.f_globals: + value = self.f_globals[recorded_name] + source = GlobalSource(recorded_name) + else: + try: + value = __import__( + module_name, + fromlist=fromlist, + level=level, + globals=self.f_globals, + ) + except ImportError: + unimplemented("import a module that does not exist") + + if level != 0: + pkg = self.calc_package() + module_name = self.resolve_name(module_name, pkg, level) + + # For __import__, when the name variable is of the form package.module, + # normally, the top-level package (the name up till the first dot) is + # returned, not the module named by module_name. However, when a + # non-empty fromlist argument is given, the module named by name is + # returned. Therefore, we set the source correctly here. + if not fromlist: + top_level_module_name = module_name.partition(".")[0] + source = self.import_source(top_level_module_name) + else: + source = self.import_source(module_name) + + if self.exec_recorder: + self.exec_recorder.add_local_mod(recorded_name, value) + + if istype(value, (types.ModuleType, DummyModule)): + self.push(PythonModuleVariable(value, source=source)) + else: + unimplemented(f"IMPORT_NAME {typestr(value)}") + + def IMPORT_FROM(self, inst): + self.DUP_TOP(inst) + self._load_attr(inst) + + def load_builtin_from_argval(self, argval): + if argval not in self.f_builtins: + raise NameError(f"name '{argval}' is not defined") + val = self.f_builtins[argval] + + if callable(val): + builtins_source = GlobalSource( + self.output.name_of_builtins_dict_key_in_fglobals + ) + var_source = GetItemSource(builtins_source, argval) + self.push(VariableBuilder(self, var_source)(val)) + else: + assert is_builtin_constant(val) + self.push(ConstantVariable.create(value=val)) + + def load_builtin(self, inst): + self.load_builtin_from_argval(inst.argval) + + def jump(self, inst): + self.instruction_pointer = self.indexof[inst.target] + + JUMP_FORWARD = jump + JUMP_ABSOLUTE = jump + + POP_JUMP_IF_FALSE = generic_jump(operator.not_, False) + POP_JUMP_IF_TRUE = generic_jump(operator.truth, False) + JUMP_IF_FALSE_OR_POP = generic_jump(operator.not_, True) + JUMP_IF_TRUE_OR_POP = generic_jump(operator.truth, True) + + def SETUP_LOOP(self, inst): + # only exists in python<=3.7 + self.block_stack.append(BlockStackEntry(inst, inst.target)) + + def SETUP_EXCEPT(self, inst): + # only exists in python<=3.7 + self.block_stack.append(BlockStackEntry(inst, inst.target)) + + def POP_BLOCK(self, inst): + self.block_stack.pop() + + def SETUP_WITH(self, inst): + self.setup_or_before_with(inst) + + def SETUP_FINALLY(self, inst): + self.block_stack.append(BlockStackEntry(inst, inst.target)) + + def BEGIN_FINALLY(self, inst): + self.push(None) + + def WITH_CLEANUP_START(self, inst): + exit, exc = self.popn(2) + assert exc is None + self.push(exc) + self.push(exit.call_function(self, [ConstantVariable.create(None)] * 3, {})) + + def WITH_CLEANUP_FINISH(self, inst): + self.popn(2) + self.push(None) + + def CALL_FINALLY(self, inst): + """ + pushes the address of the next instruction onto the stack and increments + bytecode counter by delta + """ + # Python 3.8 only + addr = self.indexof[self.next_instruction] + self.push(ConstantVariable.create(addr)) + self.jump(inst) + + def END_FINALLY(self, inst): + # Python 3.8 only + # https://docs.python.org/3.8/library/dis.html#opcode-END_FINALLY + tos = self.pop() + if isinstance(tos, ConstantVariable): + 
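+ # a constant TOS is the return address pushed by CALL_FINALLY; resume there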
self.instruction_pointer = tos.as_python_constant() + else: + pass + + def POP_FINALLY(self, inst): + # Python 3.8 only + preserve_tos = inst.argval + if preserve_tos: + tos = self.pop() + _ = self.pop() + if preserve_tos: + self.push(tos) # type: ignore[possibly-undefined] + + def FOR_ITER(self, inst): + it = self.pop().realize() + try: + val = it.next_variable(self) + self.push(it) + self.push(val) + except (StopIteration, exc.ObservedUserStopIteration) as e: + if isinstance(e, exc.ObservedUserStopIteration): + exc.handle_observed_exception(self) + + # leave iterator upon exhaustion in 3.12 + if sys.version_info >= (3, 12): + # CPython 3.12 actually jumps to the instruction after the END_FOR + # and performs the action of END_FOR as part of FOR_ITER. We jump + # to the END_FOR and run it, so we need to make sure 2 values are + # on the stack for it to pop. + self.push(it) + self.push(ConstantVariable.create(None)) + self.jump(inst) + + def _raise_exception_variable(self, inst): + val = self.pop() + # User can raise exception in 2 ways + # 1) raise exception type - raise NotImplementedError + # 2) raise execption instance - raise NotImplemetedError("foo") + + # 1) when user raises exception type + if isinstance(val, variables.BuiltinVariable): + # Create the instance of the exception type + # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549 + val = val.call_function(self, [], {}) # type: ignore[arg-type] + + # Save the exception in a global data structure + self.exn_vt_stack.append(val) + + # 2) when user raises exception instance + if isinstance(val, variables.ExceptionVariable): + if observed_exception_type := exc.observed_exception_map.get(val.exc_type): + raise observed_exception_type(f"raised exception {val}") + raise exc.ObservedException(f"raised exception {val}") + unimplemented(f"raise {exc}") + + def RAISE_VARARGS(self, inst): + if inst.arg == 0: + unimplemented("re-raise") + elif inst.arg == 1: + self._raise_exception_variable(inst) + else: + # Support raise .. from None ... Dynamo does not track __cause__ and other attributes of exception. So we + # ignore `from None` part. + from_vt = self.pop() + if isinstance(from_vt, ConstantVariable) and from_vt.value is None: + self._raise_exception_variable(inst) + unimplemented("raise ... from ...") + + def RERAISE(self, inst): + if sys.version_info >= (3, 11): + # RERAISE is currently supported in a narrow case of `raise ... from None` + self._raise_exception_variable(inst) + unimplemented("RERAISE") + + def exception_handler(self, raised_exception): + if sys.version_info >= (3, 11): + exn_tab_entry = self.current_instruction.exn_tab_entry + if exn_tab_entry: + # Implementation is based on https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt + + # 1) pop values from the stack until it matches the stack depth + # for the handler + while len(self.stack) > exn_tab_entry.depth: + self.pop() + + # 2) if 'lasti' is true, then push the offset that the exception was raised at + if exn_tab_entry.lasti: + self.push( + variables.ConstantVariable(self.current_instruction.offset) + ) + + # 3) push the exception to the stack + assert len(self.exn_vt_stack) + self.push(self.exn_vt_stack[-1]) + + # 4) jump to the handler + self.jump(exn_tab_entry) + else: + # No handler found. Bubble the exception to the parent + # instruction translater. We use special exception for this. 
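+ # discard this frame's symbolic stack before propagating upwards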
+ self.stack.clear() + if type(self) is InstructionTranslator: + raise Unsupported("Observed exception") + raise raised_exception + else: + if len(self.block_stack): + # base implementation - https://github.com/python/cpython/blob/3.10/Python/ceval.c#L4455 + + assert len(self.exn_vt_stack) + exception_var = self.exn_vt_stack[-1] + + block_stack_entry = self.block_stack.pop() + + while block_stack_entry.inst.opname == "EXCEPT_HANDLER": + # TODO(anijain2305) - This is not tested .. unable to create a testcase + # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456 + self.popn(3) + if len(self.block_stack) == 0: + # No handler found in this frame. Bubble the exception to the parent + # instruction translater. + self.stack.clear() + if type(self) is InstructionTranslator: + raise Unsupported("Observed exception") + raise raised_exception + block_stack_entry = self.block_stack.pop() + + if block_stack_entry.inst.opname != "SETUP_FINALLY": + unimplemented( + "exception is raised when top of the block stack " + "is not exception handler (e.g. try .. with .. except). " + f"Current TOS is {block_stack_entry.inst}" + ) + + # Push a dummy block stack entry of EXCEPT_HANDLER + # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456 + except_handler_inst = Instruction(1e6, "EXCEPT_HANDLER", None, 0) + self.block_stack.append(BlockStackEntry(except_handler_inst, None)) + + # Push old exception + if len(self.exn_vt_stack) >= 2: + old_exception = self.exn_vt_stack[-2] + + # Push the old exception on to stack - tb, value, type + # Traceback is currently mapped to UnknownVariable + self.push(variables.UnknownVariable()) + self.push(old_exception) + self.push(variables.BuiltinVariable(old_exception.exc_type)) + else: + # Push empty exception tb, value, type + self.push(variables.ConstantVariable(None)) + self.push(variables.ConstantVariable(None)) + self.push(variables.ConstantVariable(None)) + + # Push new exception - tb, val, type + # Traceback is currently mapped to UnknownVariable + self.push(variables.UnknownVariable()) + self.push(exception_var) + self.push(variables.BuiltinVariable(exception_var.exc_type)) + + # Jump to target + self.jump(block_stack_entry) + else: + # No handler found. Bubble the exception to the parent + # instruction translater. We use special exception for this. + self.stack.clear() + if type(self) is InstructionTranslator: + raise Unsupported("Observed exception") + raise raised_exception + + def PUSH_EXC_INFO(self, inst): + val = self.pop() + assert len(self.exn_vt_stack) + self.push(self.exn_vt_stack[-1]) + self.push(val) + + def POP_EXCEPT(self, inst): + if sys.version_info >= (3, 11): + val = self.pop() + assert isinstance(val, variables.ExceptionVariable) + + # This exception is handled and therefore we can clear the error indicator + assert len(self.exn_vt_stack) + self.exn_vt_stack.pop() + else: + assert len(self.block_stack) > 0 + if self.block_stack[-1].inst.opname != "EXCEPT_HANDLER": + raise AssertionError( + "Bug in Dynamo tracing of exception handling." + "Top of the block stack is not EXCEPT_HANDLER." + ) + self.block_stack.pop() + + self.popn(3) + + # This exception is handled and therefore we can clear the error indicator + assert len(self.exn_vt_stack) + self.exn_vt_stack.pop() + + def check_if_exc_matches(self): + assert len(self.stack) >= 2 + expected_exc_types = self.pop() + if sys.version_info >= (3, 11): + # CHECK_EXC_MATCH (which is used from 3.11 onwards) does not pop. 
+ # This is the description from the disassembly doc + # + # Performs exception matching for ``except``. Tests whether the ``STACK[-2]`` + # is an exception matching ``STACK[-1]``. Pops ``STACK[-1]`` and pushes the boolean + # result of the test. + exc_instance = self.stack[-1] + else: + # This is used prior to 3.11 via opcode JUMP_IF_NOT_EXC_MATCH + # There is no documentation but here is the code pointer that does 2 pops + # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L3650-L3665 + exc_instance = self.stack.pop() + + # Users can check exception in 2 ways + # 1) except NotImplementedError --> BuilinVariable + # 2) except (NotImplemetedError, AttributeError) -> TupleVariable + + if not isinstance(expected_exc_types, (BuiltinVariable, TupleVariable)): + unimplemented( + f"except has an unsupported types of objects {expected_exc_types}" + ) + + if sys.version_info >= (3, 11): + if not isinstance(exc_instance, variables.ExceptionVariable): + unimplemented( + f"except expects to recieve an object of exception type but received {exc_instance}" + ) + + if isinstance(expected_exc_types, TupleVariable): + expected_types = expected_exc_types.items + else: + expected_types = [ + expected_exc_types, + ] + + for expected_type in expected_types: + if not isinstance(expected_type, BuiltinVariable): + unimplemented( + f"except has an unsupported types of object {expected_type}" + ) + if isinstance(exc_instance, variables.ExceptionVariable) and issubclass( + exc_instance.exc_type, expected_type.fn + ): + return True + elif isinstance(exc_instance, variables.BuiltinVariable) and issubclass( + exc_instance.fn, expected_type.fn + ): + return True + + return False + + def CHECK_EXC_MATCH(self, inst): + self.push(variables.ConstantVariable(self.check_if_exc_matches())) + + def JUMP_IF_NOT_EXC_MATCH(self, inst): + if not self.check_if_exc_matches(): + self.jump(inst) + + def COMPARE_OP(self, inst): + if inst.argval == "exception match": + self.CHECK_EXC_MATCH(inst) + else: + self.push(compare_op_handlers[inst.argval](self, self.popn(2), {})) + + def GET_ITER(self, inst): + self.call_function(BuiltinVariable(iter), [self.pop()], {}) + + @break_graph_if_unsupported(push=1) + def CALL_FUNCTION(self, inst): + args = self.popn(inst.argval) + fn = self.pop() + self.call_function(fn, args, {}) + + @break_graph_if_unsupported(push=1) + def CALL_FUNCTION_EX(self, inst): + kwargsvars: VariableTracker + if inst.argval == 0: + kwargsvars = ConstDictVariable({}) + argsvars = self.pop() + elif inst.argval == 1: + kwargsvars = self.pop() + argsvars = self.pop() + else: + unimplemented("CALL_FUNCTION_EX") + + if sys.version_info >= (3, 13): + # 3.13 swapped null and callable + null = self.pop() + assert isinstance(null, NullVariable) + + fn = self.pop() + + if sys.version_info >= (3, 11) and sys.version_info < (3, 13): + null = self.pop() + assert isinstance(null, NullVariable) + + if isinstance(fn, GetAttrVariable) and isinstance(fn.obj, TensorVariable): + # realize is requires for Python 3.8 + kwargsvars = kwargsvars.realize() + if fn.name == "view" and isinstance( + argsvars, (ConstantVariable, TensorVariable) + ): + # Hack to handle special case in some bert models. Converts + # x.view(*shape) into x.view(shape), which is correct for view() + # but not generally. See test_transpose_for_scores(). 
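+ # (Tensor.view accepts either view(*sizes) or a single sequence view(sizes),
+ # which is why collapsing the unpacked args into one tuple is safe for view)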
+ argsvars = TupleVariable([argsvars]) + elif ( + fn.name == "random_" + and isinstance(argsvars, TupleVariable) + and len(argsvars.items) == 0 + and isinstance(kwargsvars, ConstDictVariable) + and ConstantVariable.create("from") in kwargsvars + ): + # `from`` is python keyword. Adding random_ with `from` in the + # Fx graph causes syntax error. Even if we convert the kwargs to + # args, aot_autograd/inductor while lowering generates + # aten.random.from, again causing syntax errors. Since this + # usecase is uncommon, graph break. + unimplemented("random_ op is called with from keyword") + elif ( + fn.name == "uniform_" + and isinstance(argsvars, TupleVariable) + and len(argsvars.items) == 0 + and isinstance(kwargsvars, ConstDictVariable) + and ConstantVariable.create("from") in kwargsvars + ): + # `from`` is python keyword. Adding uniform_ with `from` in the + # Fx graph causes syntax error. Even if we convert the kwargs to + # args, aot_autograd/inductor while lowering generates + # aten.uniform.from, again causing syntax errors. Since this + # usecase is uncommon, graph break. + unimplemented("uniform_ op is called with from keyword") + + if not isinstance( + argsvars, BaseListVariable + ) and argsvars.has_force_unpack_var_sequence(self): + argsvars = TupleVariable(argsvars.force_unpack_var_sequence(self)) + + # Unpack for cases like fn(**obj) where obj is a map + if isinstance(kwargsvars, UserDefinedObjectVariable): + kwargsvars = BuiltinVariable.call_custom_dict(self, dict, kwargsvars) # type: ignore[arg-type] + + if not isinstance(argsvars, BaseListVariable) or not isinstance( + kwargsvars, ConstDictVariable + ): + unimplemented(f"non-static call {typestr(argsvars)} {typestr(kwargsvars)}") + + # Map to a dictionary of str -> VariableTracker + kwargsvars = kwargsvars.keys_as_python_constant() + self.call_function(fn, argsvars.items, kwargsvars) + + @break_graph_if_unsupported(push=1) + def CALL_FUNCTION_KW(self, inst): + argnames = self.pop() + args = self.popn(inst.argval) + fn = self.pop() + assert isinstance(argnames, TupleVariable) and argnames.is_python_constant() + argnames = argnames.as_python_constant() + args, kwargs_list = args[: -len(argnames)], args[-len(argnames) :] + kwargs = dict(zip(argnames, kwargs_list)) + assert len(kwargs) == len(argnames) + self.call_function(fn, args, kwargs) + + def LOAD_METHOD_SUPER(self, inst): + self.CALL_FUNCTION(dataclasses.replace(inst, argval=2)) + arg = inst.argval[0] + argval = self.code_options["co_names"][arg] + if sys.version_info < (3, 11): + self._load_attr(dataclasses.replace(inst, argval=argval)) + else: + self.LOAD_METHOD(dataclasses.replace(inst, argval=argval)) + + def LOAD_ATTR_SUPER(self, inst): + self.CALL_FUNCTION(dataclasses.replace(inst, argval=2)) + arg = inst.argval[0] + argval = self.code_options["co_names"][arg] + self._load_attr(dataclasses.replace(inst, argval=argval)) + + def LOAD_METHOD(self, inst): + self._load_attr(inst) + obj = self.pop() + if sys.version_info >= (3, 13): + self.push(obj) + self.PUSH_NULL(inst) + elif sys.version_info >= (3, 11): + # always follow the NULL + fn convention, since if obj + # is actually a method, self is already bound to it, so it + # doesn't need to be passed in as an arg. 
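+ # stack after this branch: [..., NULL, obj], the same NULL + callable layout
+ # that the CALL handling in _call expects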
+ self.PUSH_NULL(inst) + self.push(obj) + else: + self.push(obj) + self.push(None) + + def CALL_METHOD(self, inst): + args = self.popn(inst.argval) + dummy = self.pop() + assert dummy is None + fn = self.pop() + self.call_function(fn, args, {}) + + def _load_attr(self, inst): + obj = self.pop() + result = BuiltinVariable(getattr).call_function( + self, [obj, ConstantVariable.create(inst.argval)], {} # type: ignore[arg-type] + ) + self.push(result) + + def LOAD_ATTR(self, inst): + if sys.version_info >= (3, 12): + if inst.arg % 2: + self.LOAD_METHOD(inst) + return + self._load_attr(inst) + + def STORE_ATTR(self, inst): + speculation = self.speculate() + if speculation.failed: + return self.store_attr_graph_break(inst) + val, obj = self.popn(2) + + if isinstance(obj, NNModuleVariable) and not isinstance(val, ConstantVariable): + # We don't allow side effects during export on non-constant values + # https://github.com/pytorch/torchdynamo/issues/1475 + assert ( + not self.export + ), f"Mutating module attribute {inst.argval} during export." + + try: + BuiltinVariable(setattr).call_function( + self, [obj, ConstantVariable.create(inst.argval), val], {} # type: ignore[arg-type] + ) + return + except Unsupported as e: + if not self.should_compile_partial_graph(): + raise + log.debug("STORE_ATTR triggered compile", exc_info=True) + e.remove_from_stats() + e.add_to_stats("graph_break") + speculation.fail_and_restart_analysis() + + def store_attr_graph_break(self, inst): + if not self.should_compile_partial_graph(): + unimplemented("should_compile_partial_graph=False") + self.output.compile_subgraph( + self, reason=GraphCompileReason("store_attr", [self.frame_summary()]) + ) + self.output.add_output_instructions([copy.copy(inst)]) + self.popn(2) + self.output.add_output_instructions( + self.create_call_resume_at(self.next_instruction) + ) + + def DELETE_ATTR(self, inst): + obj = self.pop() + BuiltinVariable(delattr).call_function( + self, [obj, ConstantVariable.create(inst.argval)], {} # type: ignore[arg-type] + ) + + def create_call_resume_at(self, offset): + raise AssertionError( + f"create_call_resume_at not overridden by subclass {type(self)}" + ) + + def should_compile_partial_graph(self) -> bool: + raise AssertionError( + f"should_compile_partial_graph not overridden by subclass {type(self)}" + ) + + @break_graph_if_unsupported(push=0) + def STORE_SUBSCR(self, inst): + val, obj, key = self.popn(3) + result = obj.call_method(self, "__setitem__", [key, val], {}) + + def DELETE_SUBSCR(self, inst): + obj, key = self.popn(2) + obj.call_method(self, "__delitem__", [key], {}) + + def BUILD_TUPLE(self, inst): + name_tuple = None + if sys.version_info >= (3, 13): + name_tuple = tuple(self.name_stack[-inst.argval :]) + items = self.popn(inst.argval) + self.push(TupleVariable(items), name=name_tuple) + + def BUILD_SLICE(self, inst): + items = self.popn(inst.argval) + self.push(SliceVariable(items)) + + def BUILD_LIST(self, inst): + items = self.popn(inst.argval) + self.push(ListVariable(items, mutable_local=MutableLocal())) + + def BUILD_SET(self, inst): + if config.inject_BUILD_SET_unimplemented_TESTING_ONLY: + unimplemented("missing: BUILD_SET") + items = self.popn(inst.argval) + new_set = SetVariable(items, mutable_local=MutableLocal()) + self.push(new_set) + + def BUILD_LIST_UNPACK(self, inst, cls=ListVariable): + seqs = self.popn(inst.argval) + items = [] + for seq in seqs: + try: + items.extend(seq.force_unpack_var_sequence(self)) + except NotImplementedError: + unimplemented(f"BUILD_LIST_UNPACK 
{seq}") + self.push(cls(items, mutable_local=MutableLocal())) + + def BUILD_TUPLE_UNPACK(self, inst): + self.BUILD_LIST_UNPACK(inst, cls=TupleVariable) + + BUILD_TUPLE_UNPACK_WITH_CALL = BUILD_TUPLE_UNPACK + + def BUILD_MAP(self, inst): + items = self.popn(inst.argval * 2) + d = dict(zip(items[::2], items[1::2])) + self.push(ConstDictVariable(d, mutable_local=MutableLocal())) + + def BUILD_MAP_UNPACK(self, inst): + items = self.popn(inst.argval) + # ensure everything is a dict + items = [BuiltinVariable(dict).call_function(self, [x], {}) for x in items] # type: ignore[arg-type] + result = {} + for x in items: + assert isinstance(x, ConstDictVariable) + result.update(x.items) + self.push( + ConstDictVariable( + result, + mutable_local=MutableLocal(), + ) + ) + + BUILD_MAP_UNPACK_WITH_CALL = BUILD_MAP_UNPACK + + def BUILD_CONST_KEY_MAP(self, inst): + keys = self.pop() + values = self.popn(inst.argval) + assert isinstance(keys, TupleVariable) + assert keys.is_python_constant() + + keys = keys.force_unpack_var_sequence(self) + assert len(keys) == len(values) + + self.push( + ConstDictVariable( + dict(zip(keys, values)), + mutable_local=MutableLocal(), + ) + ) + + def MAP_ADD(self, inst): + k, v = self.popn(2) + assert inst.argval > 0 + obj = self.stack[-inst.arg].realize() + assert isinstance(obj, ConstDictVariable) + obj.call_method(self, "__setitem__", (k, v), {}) # type: ignore[arg-type] + + def SET_ADD(self, inst): + v = self.pop() + assert inst.argval > 0 + obj = self.stack[-inst.arg] + assert isinstance(obj, SetVariable) + assert obj.mutable_local + return obj.call_method(self, "add", [v], {}) + + def SET_UPDATE(self, inst): + v = self.pop() + assert inst.argval > 0 + obj = self.stack[-inst.arg] + assert isinstance(obj, SetVariable) + assert obj.mutable_local + obj.call_method(self, "update", [v], {}) + + def LIST_APPEND(self, inst): + v = self.pop() + assert inst.argval > 0 + obj = self.stack[-inst.arg].realize() + assert isinstance(obj, ListVariable) + assert obj.mutable_local + self.output.side_effects.mutation(obj) + obj.items.append(v) + + def MAKE_FUNCTION(self, inst): + flags = inst.arg + old_stack = list(self.stack) + if sys.version_info < (3, 11): + fn_name = self.pop() + code = self.pop() + if sys.version_info >= (3, 11): + # MAKE_FUNCTION behavior actually changed in 3.11, see + # https://github.com/python/cpython/pull/93189/ + assert hasattr(code.value, "co_qualname") # type: ignore[attr-defined] + fn_name = ConstantVariable.create(value=code.value.co_qualname) # type: ignore[attr-defined] + defaults = None + closure = None + annotations = None + kwdefaults = None + + if sys.version_info < (3, 13): + # in 3.13, this is handled in SET_FUNCTION_ATTRIBUTE + if flags & 0x08: + closure = self.pop() + if flags & 0x04: + annotations = self.pop() + if flags & 0x02: + kwdefaults = self.pop() + if flags & 0x01: + defaults = self.pop() + + self.push( + NestedUserFunctionVariable( + fn_name, + code, + self.f_globals, + defaults, + kwdefaults, + annotations, + closure, + closure_scope=self, + ) + ) + + def UNPACK_SEQUENCE(self, inst): + seq = self.pop() + if isinstance(seq, TensorVariable): + val = seq.unpack_var_sequence(self, idxes=range(inst.argval)) # type: ignore[arg-type] + elif isinstance(seq, GetAttrVariable) and isinstance(seq.obj, TensorVariable): + # x, y = a.shape + proxy = getattr(seq.obj.as_proxy(), seq.name) + val = [wrap_fx_proxy(self, proxy[i]) for i in range(inst.argval)] + elif seq.has_force_unpack_var_sequence(self): + val = seq.force_unpack_var_sequence(self) + else: 
+ unimplemented(f"UNPACK_SEQUENCE {seq}") + if len(val) != inst.argval: + unimplemented("UNPACK_SEQUENCE length mismatch") + for i in reversed(val): + self.push(i) + + def UNPACK_EX(self, inst): + assert 0 <= inst.argval <= 0xFFFF + prefix = inst.argval & 0xFF # low byte + suffix = inst.argval >> 8 # high byte + seq = self.pop() + if seq.has_force_unpack_var_sequence(self): + vals = list(seq.force_unpack_var_sequence(self)) + assert len(vals) >= prefix + suffix + vals_prefix = vals[:prefix] + vals_list = vals[prefix : len(vals) - suffix] + vals_suffix = vals[len(vals) - suffix :] + for item in reversed(vals_suffix): + self.push(item) + self.push(TupleVariable(vals_list)) + for item in reversed(vals_prefix): + self.push(item) + else: + unimplemented(f"UNPACK_EX {seq}") + + def NOP(self, inst): + pass + + def POP_TOP(self, inst): + self.pop() + + def ROT_TWO(self, inst): + a = self.pop() + b = self.pop() + self.push(a) + self.push(b) + + def ROT_THREE(self, inst): + a = self.pop() + b = self.pop() + c = self.pop() + self.push(a) + self.push(c) + self.push(b) + + def ROT_FOUR(self, inst): + a = self.pop() + b = self.pop() + c = self.pop() + d = self.pop() + self.push(a) + self.push(d) + self.push(c) + self.push(b) + + def DUP_TOP(self, inst): + a = self.pop() + self.push(a) + self.push(a) + + def DUP_TOP_TWO(self, inst): + a = self.pop() + b = self.pop() + self.push(b) + self.push(a) + self.push(b) + self.push(a) + + def FORMAT_VALUE(self, inst): + flags = inst.arg + if (flags & 0x04) == 0x04: + fmt_spec = self.pop() + else: + fmt_spec = ConstantVariable.create("") + + value = self.pop() + if isinstance(value, SymNodeVariable): + from torch._dynamo.variables.lazy import ( + LazySymNodeFormatString, + LazyVariableTracker, + ) + + value = LazyVariableTracker.create( + LazySymNodeFormatString(value, fmt_spec), source=value.source + ) + self.push(value) + return + if (flags & 0x03) == 0x01: + value = BuiltinVariable(str).call_function(self, [value], {}) # type: ignore[arg-type] + elif (flags & 0x03) == 0x02: + value = BuiltinVariable(repr).call_function(self, [value], {}) # type: ignore[arg-type] + elif (flags & 0x03) == 0x03: + value = BuiltinVariable(ascii).call_function(self, [value], {}) # type: ignore[arg-type] + + fmt_var = ConstantVariable.create("{:" + fmt_spec.as_python_constant() + "}") + + self.call_function(BuiltinVariable(str.format), [fmt_var, value], {}) + + def BUILD_STRING(self, inst): + format_string_parts: List[str] = [] + args: List[VariableTracker] = [] + kwargs: Dict[str, VariableTracker] = {} + for part in self.popn(inst.arg): + if isinstance(part, ConstantVariable): + format_string_parts.append("{}") + args.append(part) + elif isinstance(part, variables.StringFormatVariable): + format_string_parts.append(part.format_string) + args.extend(part.sym_args) + if set(kwargs.keys()) & set(part.sym_kwargs.keys()): + unimplemented( + f"BUILD_STRING key conflict {kwargs} & {part.sym_kwargs}" + ) + kwargs.update(part.sym_kwargs) + else: + unimplemented(f"BUILD_STRING {part}") + self.push( + variables.StringFormatVariable.create( + "".join(format_string_parts), args, kwargs + ) + ) + + def IS_OP(self, inst): + assert inst.argval == 0 or inst.argval == 1 + if inst.argval == 0: + new_argval = "is" + else: + new_argval = "is not" + new_inst = create_instruction("COMPARE_OP", argval=new_argval) + self.COMPARE_OP(new_inst) + + def CONTAINS_OP(self, inst): + assert inst.argval == 0 or inst.argval == 1 + left, right = self.popn(2) + op = inst.argval + self.push(right.call_method(self, 
"__contains__", [left], {})) + if op == 1: + self.UNARY_NOT(inst) + + def LIST_EXTEND(self, inst): + v = self.pop() + assert inst.argval > 0 + obj = self.stack[-inst.arg] + assert isinstance(obj, ListVariable) + assert obj.mutable_local + obj.call_method(self, "extend", [v], {}) + + def LIST_TO_TUPLE(self, inst): + self.push(BuiltinVariable(tuple).call_function(self, [self.pop()], {})) # type: ignore[arg-type] + + def DICT_MERGE(self, inst): + v = self.pop() + assert inst.argval > 0 + obj = self.stack[-inst.arg].realize() + assert isinstance(obj, ConstDictVariable) + assert obj.mutable_local + obj.call_method(self, "update", [v], {}) + + DICT_UPDATE = DICT_MERGE + + def GEN_START(self, inst): + self.pop() + + def GET_LEN(self, inst): + tos = self.stack[-1] + if tos.is_python_constant(): + self.push(ConstantVariable.create(len(tos.as_python_constant()))) + else: + self.push(tos.call_method(self, "__len__", [], {})) + + def MATCH_MAPPING(self, inst): + tos = self.stack[-1] + assert isinstance(tos, ConstDictVariable) + if isinstance(tos.items, collections.abc.Mapping): + self.push(ConstantVariable.create(True)) + else: + self.push(ConstantVariable.create(False)) + + def MATCH_SEQUENCE(self, inst): + tos = self.stack[-1] + assert tos.is_python_constant() + tos_value = tos.as_python_constant() + if isinstance(tos_value, collections.abc.Sequence) and not isinstance( + tos_value, (str, bytes, bytearray) + ): + self.push(ConstantVariable.create(True)) + else: + self.push(ConstantVariable.create(False)) + + def MATCH_KEYS(self, inst): + tos = self.stack[-1] + tos1 = self.stack[-2] + assert isinstance(tos1, ConstDictVariable) + + if all(k in tos1 for k in tos): # type: ignore[attr-defined] + self.push(TupleVariable([tos1.getitem_const(self, k) for k in tos])) # type: ignore[attr-defined,arg-type] + if sys.version_info < (3, 11): + self.push(ConstantVariable.create(True)) + else: + self.push(ConstantVariable.create(None)) + if sys.version_info < (3, 11): + self.push(ConstantVariable.create(False)) + + def LOAD_ASSERTION_ERROR(self, inst): + self.load_builtin_from_argval("AssertionError") + + UNARY_POSITIVE = stack_op(operator.pos) + UNARY_NEGATIVE = stack_op(operator.neg) + UNARY_NOT = stack_op(operator.not_) + UNARY_INVERT = stack_op(operator.invert) + + BINARY_POWER = stack_op(operator.pow) + BINARY_MULTIPLY = stack_op(operator.mul) + BINARY_MATRIX_MULTIPLY = stack_op(operator.matmul) + BINARY_FLOOR_DIVIDE = stack_op(operator.floordiv) + BINARY_TRUE_DIVIDE = stack_op(operator.truediv) + BINARY_MODULO = stack_op(operator.mod) + BINARY_REMAINDER = stack_op(operator.mod) + BINARY_ADD = stack_op(operator.add) + BINARY_SUBTRACT = stack_op(operator.sub) + BINARY_SUBSCR = break_graph_if_unsupported(push=1)(stack_op(operator.getitem)) + BINARY_LSHIFT = stack_op(operator.lshift) + BINARY_RSHIFT = stack_op(operator.rshift) + BINARY_AND = stack_op(operator.and_) + BINARY_OR = stack_op(operator.or_) + BINARY_XOR = stack_op(operator.xor) + + INPLACE_POWER = stack_op(operator.ipow) + INPLACE_MULTIPLY = stack_op(operator.imul) + INPLACE_MATRIX_MULTIPLY = stack_op(operator.imatmul) + INPLACE_FLOOR_DIVIDE = stack_op(operator.ifloordiv) + INPLACE_TRUE_DIVIDE = stack_op(operator.itruediv) + INPLACE_MODULO = stack_op(operator.imod) + INPLACE_REMAINDER = stack_op(operator.imod) + INPLACE_ADD = stack_op(operator.iadd) + INPLACE_SUBTRACT = stack_op(operator.isub) + INPLACE_LSHIFT = stack_op(operator.ilshift) + INPLACE_RSHIFT = stack_op(operator.irshift) + INPLACE_AND = stack_op(operator.iand) + INPLACE_XOR = 
stack_op(operator.ixor) + INPLACE_OR = stack_op(operator.ior) + + # 3.11 opcodes + def RESUME(self, inst): + if inst.arg == 0: + self.append_prefix_inst(inst) + self.accept_prefix_inst = False + else: + assert not self.accept_prefix_inst + + if sys.version_info >= (3, 11): + + def BINARY_OP(self, inst): + return _binary_op_lookup[inst.arg](self, inst) + + def PRECALL(self, inst): + pass + + def KW_NAMES(self, inst): + kw_names = self.code_options["co_consts"][inst.arg] + assert isinstance(kw_names, tuple) + for name in kw_names: + assert isinstance(name, str) + assert self.kw_names is None + self.kw_names = ConstantVariable.create(value=kw_names) # type: ignore[assignment] + + def PUSH_NULL(self, inst): + self.push(NullVariable()) + + def _call(self, inst, call_kw=False): + # see https://docs.python.org/3.11/library/dis.html#opcode-CALL + # for convention + if call_kw: + # TOS is kw_names for CALL_KW instruction + assert sys.version_info >= (3, 13) + kw_names = self.pop() + assert isinstance(kw_names, TupleVariable) and kw_names.is_python_constant() + kw_names = kw_names.as_python_constant() + else: + kw_names = self.kw_names.value if self.kw_names else () + + contents = self.popn(inst.arg + 2) + if sys.version_info >= (3, 13): + # NULL and callable swapped + fn = contents[0] + args = [] if isinstance(contents[1], NullVariable) else [contents[1]] + else: + if isinstance(contents[0], NullVariable): + fn = contents[1] + args = [] + else: + fn = contents[0] + args = [contents[1]] + + if kw_names: + args = args + contents[2 : -len(kw_names)] + kwargs_list = contents[-len(kw_names) :] + kwargs = dict(zip(kw_names, kwargs_list)) + assert len(kwargs) == len(kw_names) + else: + args = args + contents[2:] + kwargs = {} + + try: + # if call_function fails, need to set kw_names to None, otherwise + # a subsequent call may have self.kw_names set to an old value + self.call_function(fn, args, kwargs) + finally: + self.kw_names = None + + @break_graph_if_unsupported(push=1) + def CALL(self, inst): + self._call(inst) + + def COPY(self, inst): + self.push(self.stack[-inst.arg]) + + def SWAP(self, inst): + self.stack[-1], self.stack[-inst.arg] = self.stack[-inst.arg], self.stack[-1] + + JUMP_BACKWARD = jump + JUMP_BACKWARD_NO_INTERRUPT = jump + + POP_JUMP_FORWARD_IF_TRUE = generic_jump(operator.truth, False) + POP_JUMP_BACKWARD_IF_TRUE = generic_jump(operator.truth, False) + POP_JUMP_FORWARD_IF_FALSE = generic_jump(operator.not_, False) + POP_JUMP_BACKWARD_IF_FALSE = generic_jump(operator.not_, False) + + def CACHE(self, inst): + pass + + def BEFORE_WITH(self, inst): + self.setup_or_before_with(inst) + + def setup_or_before_with(self, inst): + ctx = self.pop() + if not isinstance( + ctx, (ContextWrappingVariable, GenericContextWrappingVariable) + ): + unimplemented(f"{inst.opname} {ctx}") + + if isinstance(ctx, GenericContextWrappingVariable): + self.generic_context_manager_depth += 1 + + # Need this redundant check for mypy + assert isinstance( + ctx, (ContextWrappingVariable, GenericContextWrappingVariable) + ) + + exit = WithExitFunctionVariable( + ctx, + inst.target, + ) + + if sys.version_info >= (3, 11): + # See create_call_resume_at for block stack details. + # Only push a block if the current instruction's block is a + # with block that is not nested in a try block - that is, the current + # instruction's block target is the same as the top block's target. 
+ if inst.exn_tab_entry and ( + not self.block_stack + or inst.exn_tab_entry.target is not self.block_stack[-1].target + ): + target = None + else: + target = self.next_instruction.exn_tab_entry.target + else: + target = inst.target + + if target: + if isinstance(self, InstructionTranslator): + self.block_stack.append( + BlockStackEntry(inst, target, len(self.stack), ctx) + ) + else: + self.block_stack.append(BlockStackEntry(inst, target)) + + self.push(exit) + self.push(ctx.enter(self)) + + def append_prefix_inst(self, inst): + assert self.accept_prefix_inst + self.prefix_insts.append(inst) + + def MAKE_CELL(self, inst): + if sys.version_info >= (3, 12) and not self.accept_prefix_inst: + # In 3.12+, MAKE_CELL is not longer necessarily a prefix instruction. + # It can be generated by inlined comprehensions. + assert isinstance(self.symbolic_locals[inst.argval], NullVariable) + self.symbolic_locals[ + inst.argval + ] = self.output.side_effects.track_cell_new() + else: + self.append_prefix_inst(inst) + + def COPY_FREE_VARS(self, inst): + self.append_prefix_inst(inst) + + def RETURN_GENERATOR(self, inst): + self.append_prefix_inst(inst) + + # 3.12 opcodes + # BINARY/STORE_SLICE opcodes are broken down into + # BUILD_SLICE 2 and BINARY/STORE_SUBSCR + + def END_FOR(self, inst): + if sys.version_info >= (3, 13): + self.pop() + else: + self.popn(2) + + def LOAD_FAST_CHECK(self, inst): + if isinstance(self.symbolic_locals[inst.argval], NullVariable): + unimplemented("LOAD_FAST_CHECK on uninitialized variable") + self.LOAD_FAST(inst) + + def LOAD_FAST_AND_CLEAR(self, inst): + if inst.argval not in self.symbolic_locals: + self.push(NullVariable()) + else: + self.LOAD_FAST(inst) + self.symbolic_locals[inst.argval] = NullVariable() + + def LOAD_SUPER_ATTR(self, inst): + self.CALL_FUNCTION(dataclasses.replace(inst, argval=2)) + if inst.arg & 1: + self.LOAD_METHOD(inst) + else: + self._load_attr(inst) + + def CALL_INTRINSIC_1(self, inst): + if inst.argval == 5: + # INTRINSIC_UNARY_POSITIVE + self.UNARY_POSITIVE(inst) + elif inst.argval == 6: + # INTRINSIC_LIST_TO_TUPLE + self.push(TupleVariable(self.pop().force_unpack_var_sequence(self))) + else: + unimplemented(f"missing CALL_INTRINSIC_1 operand {inst.argval}") + + def END_SEND(self, inst): + tos = self.pop() + self.pop() + self.push(tos) + + # 3.13 opcodes + # fused instructions LOAD_FAST_LOAD_FAST, STORE_FAST_STORE_FAST, STORE_FAST_LOAD_FAST + # are broken down. + @break_graph_if_unsupported(push=1) + def CALL_KW(self, inst): + self._call(inst, call_kw=True) + + def TO_BOOL(self, inst): + # TO_BOOL only precedes a conditional jump or UNARY_NOT (see compile.c in CPython) + # So we can skip this instruction as long as we remember to codegen a TO_BOOL + # before conditional jumps/UNARY_NOT. + assert self.next_instruction.opname in ( + "POP_JUMP_IF_TRUE", + "POP_JUMP_IF_FALSE", + "UNARY_NOT", + ) + + def SET_FUNCTION_ATTRIBUTE(self, inst): + flags = inst.arg + fn = self.pop() + assert isinstance(fn, NestedUserFunctionVariable) + attr_names = self.name_stack[-1] + attr = self.pop() + + if flags & 0x08: + # 3.13 merged LOAD_CLOSURE into LOAD_FAST, so we won't know if a given LOAD_FAST + # is meant to load a closure variable or not. Our workaround is to maintain a stack + # of LOAD_FAST variable names and tuples (self.name_stack). So if we are indeed + # constructing a closure tuple, we can use self.name_stack to construct the closure + # variables here. 
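As an aside to the flag handling in `SET_FUNCTION_ATTRIBUTE` (which continues just below), the bit values follow CPython's `MAKE_FUNCTION` convention: 0x01 defaults, 0x02 keyword-only defaults, 0x04 annotations, 0x08 closure, with one attribute set per instruction as the `if`/`elif` chain suggests. A small illustrative decoder, independent of the translator:

FLAG_DEFAULTS, FLAG_KWDEFAULTS, FLAG_ANNOTATIONS, FLAG_CLOSURE = 0x01, 0x02, 0x04, 0x08

def describe_flags(flags):
    # decode which function attribute a flag word selects
    names = []
    for bit, name in [
        (FLAG_CLOSURE, "closure"),
        (FLAG_ANNOTATIONS, "annotations"),
        (FLAG_KWDEFAULTS, "kwdefaults"),
        (FLAG_DEFAULTS, "defaults"),
    ]:
        if flags & bit:
            names.append(name)
    return names

assert describe_flags(0x08) == ["closure"]
assert describe_flags(0x01 | 0x02) == ["kwdefaults", "defaults"]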
+ assert isinstance(attr_names, tuple) and all( + isinstance(name, str) for name in attr_names + ) + fn.closure = TupleVariable( + [self._load_closure(name) for name in attr_names] + ) + fn.closure_scope = self + elif flags & 0x04: + fn.annotations = attr + elif flags & 0x02: + fn.kwdefaults = attr + elif flags & 0x01: + fn.defaults = attr + + self.push(fn) + + def _format_value_313(self, fmt_spec): + value = self.pop() + if isinstance(value, SymNodeVariable): + value = ConstantVariable.create(str(value.sym_num)) + + fmt_var = ConstantVariable.create("{:" + fmt_spec.as_python_constant() + "}") + + self.call_function(BuiltinVariable(str.format), [fmt_var, value], {}) + + def FORMAT_SIMPLE(self, inst): + self._format_value_313(ConstantVariable.create("")) + + def FORMAT_WITH_SPEC(self, inst): + self._format_value_313(self.pop()) + + def is_non_empty_graph(self): + if self.output.count_calls() > 1: + # perf optimization only + self.is_non_empty_graph = lambda: True # type: ignore[method-assign] + return True + return False + + def format_frame_summary(self, additional_stack_frames=None): + if additional_stack_frames is None: + additional_stack_frames = [] + return "".join( + traceback.format_list( + [self.frame_summary()] + list(reversed(additional_stack_frames)) + ) + ) + + def frame_summary(self): + return traceback.FrameSummary( + getattr(self.f_code, "co_filename", ""), + self.lineno, + getattr(self.f_code, "co_name", ""), + lookup_line=False, + ) + + def is_co_filename_from_nn_modules(self): + filename = getattr(self.f_code, "co_filename", "") + nn_modules_pattern = re.compile(r".*torch/nn/modules.*") + return nn_modules_pattern.match(filename) is not None + + def store_global_weakref_by_id(self, prefix, value): + global_name = self.output.install_global_by_id(prefix, weakref.ref(value)) + install_guard( + GlobalWeakRefSource(global_name).make_guard(GuardBuilder.WEAKREF_ALIVE) + ) + return global_name + + @property + def fake_mode(self): + return self.output.tracing_context.fake_mode + + def find_symbolic_locals_name(self, tensor_variable): + for key, value in self.symbolic_locals.items(): + if value is tensor_variable: + return key + return None + + @contextlib.contextmanager + def strict_translation_mode(self, check_fn: Callable[[VariableTracker], bool]): + """ + Strict mode is enabled on a per-VariableTracker level depending on the return value of check_fn(node). 
+ """ + prior = self.strict_checks_fn + self.strict_checks_fn = check_fn + try: + yield + finally: + self.strict_checks_fn = prior + + def speculate(self) -> SpeculationEntry: + assert self.instruction_pointer is not None + assert self.instruction_pointer > 0 + return self.speculation_log.next( + self.f_code.co_filename, + self.lineno, + self.instruction_pointer - 1, + self.instructions[self.instruction_pointer - 1], + ) + + def __init__( + self, + output: OutputGraph, + instructions: List[Instruction], + f_locals: Dict[str, Any], + f_globals: Dict[str, Any], + f_builtins: Dict[str, Any], + code_options: Dict[str, Any], + symbolic_locals: Dict[str, VariableTracker], + symbolic_globals: Dict[str, VariableTracker], + symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"], + f_code: types.CodeType, + export: bool, + inline_depth: int, + speculation_log: SpeculationLog, + distributed_state: Optional[DistributedState], + ) -> None: + super().__init__() + self.speculation_log = speculation_log + self.distributed_state = distributed_state + + # Mutable state checkpointed by copy_graphstate() + self.output = output + self.symbolic_locals = symbolic_locals + self.symbolic_globals = symbolic_globals + self.symbolic_torch_function_mode_stack = symbolic_torch_function_mode_stack + self.stack = [] + # stack of variable names for tracking 3.13 closures + self.name_stack: list[Any] = [] + self.instruction_pointer = 0 + self.current_instruction = create_instruction("NOP") + self.block_stack = [] + # states before SETUP_WITH for checkpointing and fallback + self.generic_context_manager_depth = 0 + self.lineno = -1 + self.kw_names = None + self.accept_prefix_inst = True + self.prefix_insts = [] + self.exn_vt_stack = [] + + # Properties of the input/output code + self.instructions: List[Instruction] = instructions + self.indexof: Dict[Instruction, int] = get_indexof(self.instructions) + self.f_locals: Dict[ + str, Any + ] = f_locals # needed for recording accessed locals for replay + self.f_globals: Dict[str, Any] = f_globals + self.f_builtins: Dict[str, Any] = f_builtins + self.code_options: Dict[str, Any] = code_options + self.f_code: types.CodeType = f_code + + # Execution record for replaying errors + if config.replay_record_enabled: + self.exec_recorder = ExecutionRecorder( + code=f_code, code_options=code_options + ) + else: + self.exec_recorder = None + # Stack of module being parsed, current nn.module is at the end of ordered dict. + # The first field of tuple is the fully qualified name of current module + # in original hierarchy. The second field is the type of current nn.module + self.nn_module_stack: Dict[str, Tuple[str, Type[Any]]] = {} + # Flag to indicate whether tracing is used for export. 
+ self.export = export + self.one_graph = False + + self.current_speculation = None + + self.strict_checks_fn = None + + if sys.version_info >= (3, 10): + from .resume_execution import ( + CO_ASYNC_GENERATOR, + CO_COROUTINE, + CO_GENERATOR, + CO_ITERABLE_COROUTINE, + ) + + if f_code.co_flags & ( + CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR + ): + self.push(BuiltinVariable(None)) + + self.inline_depth = inline_depth + self.inconsistent_side_effects = False + self._constants_cache: List[Optional[VariableTracker]] = [None] * len( + f_code.co_consts + ) + linecache.lazycache(f_code.co_filename, f_globals) + + +class InstructionTranslator(InstructionTranslatorBase): + mutated_closure_cell_contents: Set[str] + + @staticmethod + def current_tx() -> "InstructionTranslator": + return tls.current_tx + + @contextlib.contextmanager + def set_current_tx(self): + prior = getattr(tls, "current_tx", None) + tls.current_tx = self + try: + yield + finally: + tls.current_tx = prior + + def __init__( + self, + instructions: List[Instruction], + f_code, + f_locals, + f_globals, + f_builtins, + code_options, + compiler_fn, + one_graph, + export, + export_constraints, + mutated_closure_cell_contents: Set[str], + frame_state, + speculation_log: SpeculationLog, + distributed_state: Optional[DistributedState], + ) -> None: + _step_logger()( + logging.INFO, + f"torchdynamo start tracing {f_code.co_name} {code_options['co_filename']}:{code_options['co_firstlineno']}", + ) + super().__init__( + output=OutputGraph( + code_options, + compiler_fn, + self, + export, + export_constraints, + frame_state, + local_scope=f_locals, + global_scope=f_globals, + f_code=f_code, + ), + instructions=instructions, + f_locals=f_locals, + f_globals=f_globals, + f_builtins=f_builtins, + code_options=code_options, + symbolic_locals={}, # set below + # A global var is inserted only after a STORE_GLOBAL happens to it + symbolic_globals={}, + symbolic_torch_function_mode_stack=collections.deque(), + f_code=f_code, + export=export, + inline_depth=0, + speculation_log=speculation_log, + distributed_state=distributed_state, + ) + + self._throw_if_in_functorch() + + # as soon as we create the tracing context we should keep it active, so any calls + # into dynamo apis can rely on finding it + with tracing(self.output.tracing_context), self.set_current_tx(): + self.one_graph: bool = one_graph + self.export = export + self.mutated_closure_cell_contents = mutated_closure_cell_contents + if self.export: + assert ( + self.one_graph + ), "Export without one graph - something has gone wrong." 
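The `set_current_tx` context manager defined earlier in this class is a plain thread-local save/restore; a self-contained sketch with stand-in values rather than a real translator:

import contextlib
import threading

tls = threading.local()

@contextlib.contextmanager
def set_current(tx):
    # remember whatever was active, install the new value, always restore on exit
    prior = getattr(tls, "current_tx", None)
    tls.current_tx = tx
    try:
        yield
    finally:
        tls.current_tx = prior

with set_current("outer"):
    with set_current("inner"):
        assert tls.current_tx == "inner"
    assert tls.current_tx == "outer"
assert tls.current_tx is None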
+ + vars = list(code_options["co_varnames"]) + cells_and_freevars = [x for x in self.cell_and_freevars() if x not in vars] + vars.extend(cells_and_freevars) + cells_and_freevars_set = set(cells_and_freevars) + + self.symbolic_locals = { + k: variables.LazyVariableTracker.create( + f_locals[k], + source=LocalSource(k, cell_or_freevar=k in cells_and_freevars_set), + ) + for k in vars + if k in f_locals + } + + self._init_torch_function_mode_stack() + + self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = [] + if export: + # export gets confused if we never realize unused inputs + # in export mode just eagerly realize everything + self.symbolic_locals = variables.LazyVariableTracker.realize_all( + self.symbolic_locals + ) + + self._freevars_ids = {} + for name in self.code_options["co_freevars"]: + if name in f_locals: + self._freevars_ids[name] = id(f_locals[name]) + + def _throw_if_in_functorch(self): + # Fallback to eager in case of a graph break inside vmap + eager = torch._dynamo.lookup_backend("eager") + compiler_fn = inspect.getattr_static( + self.output.compiler_fn, "compiler_fn", self.output.compiler_fn + ) + ci = torch._C._functorch.peek_interpreter_stack() + forbidden_keys = ( + torch._C._functorch.TransformType.Vmap, + torch._C._functorch.TransformType.Grad, + torch._C._functorch.TransformType.Jvp, + ) + + if ci is not None and ci.key() in forbidden_keys and compiler_fn is not eager: + name = ci.key().name.lower() + msg = ( + "If you are reaching here, it means dynamo failed for one of the following reasons:\n" + # Calling a torch.compiled function + f"- Calling torch.func.{name}(compiled_fn) function from eager mode is not supported. " + f"Ensure that torch.func.{name} is also wrapped within a torch.compile function. " + "For more information, see PyTorch issue #128711.\n" + # if it reaches here, it means Dynamo failed to inline a functorch function + f"- torch.func.{name}(fn) requires the function to be inlined by dynamo" + ) + unimplemented(msg) + + def _init_torch_function_mode_stack(self): + from .variables.torch_function import TorchFunctionModeStackVariable + + TorchFunctionModeStackVariable.reset() + + self.symbolic_torch_function_mode_stack: Deque[ + TorchFunctionModeVariable + ] = collections.deque() + # We want to retrieve all modes to properly reconstruct the stack if needed + py_stack = get_torch_function_mode_stack(filter_ignored=False) + + if py_stack: + has_device_context = isinstance( + py_stack[0], torch.utils._device.DeviceContext + ) + + for i, val in enumerate(py_stack): + self.symbolic_torch_function_mode_stack.append( + variables.LazyVariableTracker.create( + val, source=TorchFunctionModeStackSource(i) + ) + ) + + def get_example_value(self, source: Source): + if isinstance(source, LocalSource): + return self.f_locals[source.local_name] + if isinstance(source, GlobalSource): + return self.f_globals[source.global_name] + raise KeyError + + def run(self): + super().run() + + def match_nested_cell(self, name, cell): + """Match a cell in this method to one in a function we are inlining""" + try: + value = cell.cell_contents + except ValueError: + return None + # TODO(jansel): check the id of the cell rather than the contents + if id(value) != self._freevars_ids.get(name): + return None + return self.symbolic_locals[name] + + def should_compile_partial_graph(self): + if sys.version_info >= (3, 11): + # Do not compile if current instruction's block is not the top with block + entry = self.current_instruction.exn_tab_entry + if entry and ( + not 
self.block_stack or entry.target is not self.block_stack[-1].target + ): + return False + return ( + all(b.can_restore() for b in self.block_stack) + and not self.one_graph + and self.generic_context_manager_depth == 0 + ) + + def create_call_resume_at(self, inst): + self.instruction_pointer = None + + if inst.opname == "RETURN_VALUE": + return [create_instruction("RETURN_VALUE")] + elif inst.opname == "RETURN_CONST": + return [create_instruction("RETURN_CONST", argval=inst.argval)] + + reads = livevars_analysis(self.instructions, inst) + all_argnames = tuple( + k + for k in self.symbolic_locals.keys() + if k in reads and k not in self.cell_and_freevars() + ) + # NOTE: do not use isinstance, since it realizes lazy VT's + argnames = tuple( + k + for k in all_argnames + if not type.__instancecheck__(NullVariable, self.symbolic_locals[k]) + ) + argnames_null = tuple( + k + for k in all_argnames + if type.__instancecheck__(NullVariable, self.symbolic_locals[k]) + ) + if sys.version_info < (3, 12): + assert len(argnames_null) == 0, "variables should not be NULL in < 3.12" + + cg = PyCodegen(self) + + # Handle inactive context variables. + # The resume function assumes that context variables are the class, NOT the object. + # e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled + stack_ctx_vars = [] + for i, var in enumerate(self.stack): + if type.__instancecheck__(ContextWrappingVariable, var): + ctx = cast(ContextWrappingVariable, var) + target_values = ( + () if ctx.target_values is None else tuple(ctx.target_values) + ) + stack_ctx_vars.append((i, target_values)) + # Replace the current stack var with the context class + ctx.reconstruct_type(cg) + cg.extend_output(create_swap(len(self.stack) - i + 1)) + cg.append_output(create_instruction("POP_TOP")) + + argnames_ctx_vars = [] + for name in argnames: + if type.__instancecheck__( + ContextWrappingVariable, var := self.symbolic_locals[name] + ): + ctx = cast(ContextWrappingVariable, var) + target_values = ( + () if ctx.target_values is None else tuple(ctx.target_values) + ) + argnames_ctx_vars.append((name, target_values)) + # Replace the local with the context class + ctx.reconstruct_type(cg) + cg.append_output(create_instruction("STORE_FAST", argval=name)) + + # Python does not allow null to be an arg to a function, so + # we remove nulls from the stack and restore them in the + # prologue of the resume function + + # sorted list of indices of nulls on the stack + null_idxes: List[int] = [] + if sys.version_info >= (3, 11): + # find indices of NullVariables + for i, var in enumerate(self.stack): + if type.__instancecheck__(NullVariable, var): + null_idxes.append(i) + # generate bytecode to pop the nulls + null_cnt = 0 + for i, var in enumerate(reversed(self.stack)): + if type.__instancecheck__(NullVariable, var): + for j in range(2, i + 2 - null_cnt): + cg.append_output(create_instruction("SWAP", arg=j)) + cg.extend_output(cg.pop_null()) + null_cnt += 1 + + # we popped all nulls from the stack at runtime, + # so we should not count NullVariables + stack_len = len(self.stack) - len(null_idxes) + nargs = stack_len + len(argnames) + + name = unique_id(f"__resume_at_{inst.offset}") + + new_code: types.CodeType = ContinueExecutionCache.lookup( + self.f_code, + self.lineno, + inst.offset, + tuple(b.target.offset for b in self.block_stack), + stack_len, + argnames, + argnames_null, + tuple(b.resume_fn() for b in self.block_stack), + tuple(stack_ctx_vars), + tuple(argnames_ctx_vars), + tuple(null_idxes), + ) + + # Add 
original GraphModule context to the resume function to handle + # the case of a graph break while tracing a GraphModule + orig_graphmodule_maybe = code_context.get_context(self.f_code).get( + "orig_graphmodule", lambda: None + )() + if orig_graphmodule_maybe is not None: + code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref( + orig_graphmodule_maybe + ) + + if new_code.co_freevars: + # expose code object for debugging purposes + self.output.install_global_unsafe(name, new_code) + cg.make_function_with_closure(name, new_code, True, stack_len) + else: + # This is safe: we pre-generate a unique name + self.output.install_global_unsafe( + name, types.FunctionType(new_code, self.f_globals, name) + ) + cg.extend_output(cg.load_function_name(name, True, stack_len)) + + cg.extend_output([cg.create_load(k) for k in argnames]) + cg.extend_output(create_call_function(nargs, False)) + cg.append_output(create_instruction("RETURN_VALUE")) + return cg.get_instructions() + + def symbolic_locals_contain_module_class(self): + for v in self.symbolic_locals.values(): + if isinstance(v, UserDefinedClassVariable) and issubclass( + v.as_python_constant(), torch.nn.Module + ): + return True + return False + + def _return(self, inst): + if ( + self.output.count_calls() == 0 + and not self.inconsistent_side_effects + and not self.symbolic_locals_contain_module_class() + and not self.export + ): + raise exc.SkipFrame("because no content in function call") + self.instruction_pointer = None + _step_logger()( + logging.INFO, + f"torchdynamo done tracing {self.f_code.co_name} ({inst.opname})", + ) + log.debug("%s triggered compile", inst.opname) + self.output.compile_subgraph( + self, + reason=GraphCompileReason( + "return_value", [self.frame_summary()], graph_break=False + ), + ) + return_inst = ( + create_instruction("RETURN_VALUE") + if inst.opname == "RETURN_VALUE" + else create_instruction("RETURN_CONST", argval=inst.argval) + ) + self.output.add_output_instructions([return_inst]) + raise ReturnValueOp + + def RETURN_VALUE(self, inst): + self._return(inst) + + def RETURN_CONST(self, inst): + self._return(inst) + + +if sys.version_info >= (3, 11): + _binary_op_lookup = [ + getattr( + InstructionTranslator, + opname[3:] if "INPLACE" in opname else f"BINARY_{opname[3:]}", + ) + for opname, _ in dis._nb_ops # type: ignore[attr-defined] + ] + + +class InliningInstructionTranslator(InstructionTranslatorBase): + """Trace and inline a called method""" + + symbolic_result: Optional[TensorVariable] + + @classmethod + def inline_call(cls, parent, func, args, kwargs): + with patch.dict(counters, {"unimplemented": counters["inline_call"]}): + return cls.inline_call_(parent, func, args, kwargs) + + @staticmethod + def check_inlineable(func): + if func.has_self(): + unimplemented("inline with __self__") + + result = trace_rules.check_verbose(func, is_inlined_call=True) + if result.skipped: + from torch._dynamo.variables.misc import produce_trampoline_autograd_apply + + # _origin marks this as coming from an internal dynamo known function that is safe to + # trace through. 
+ if hasattr(getattr(func, "fn", None), "_origin") and func.fn._origin in [ + produce_trampoline_autograd_apply, + ]: + # Known sound + return trace_rules.SkipResult( + False, "allowlist in dynamo known function" + ) + fn_qualname = func.fn.__qualname__ if hasattr(func, "fn") else "" + unimplemented( + f"'inline in skipfiles: {fn_qualname} | {func.get_name()} {func.get_filename()}, {result.reason}'" + ) + + if isinstance(func, UserFunctionVariable) and inspect.getattr_static( + func.get_function(), "_torchdynamo_disable", False + ): + unimplemented( + f"call torch._dynamo.disable() wrapped function {func.get_function()}" + ) + else: + return result + + @staticmethod + def inline_call_( + parent, func: VariableTracker, args: List[VariableTracker], kwargs + ): + if isinstance(func, SkipFunctionVariable): + unimplemented("inline with functions in skip files") + assert isinstance( + func, + (UserFunctionVariable, NestedUserFunctionVariable), + ) + result = InliningInstructionTranslator.check_inlineable(func) + assert result.skipped is False + try: + sub_locals, closure_cells = func.bind_args(parent, args, kwargs) + except TypeError as e: + # Wrap the general TypeError during bind_args() to the internal ArgsMismatchError with detailed info + raise ArgsMismatchError( # noqa: B904 + "{reason}.\n func = {func}, args = {args}, kwargs = {kwargs}".format( + reason=str(e), + func=f"'{func.get_name()}' {func.get_filename()}:{func.get_code().co_firstlineno}", + args=[arg.python_type() for arg in args], + kwargs=kwargs, + ), + ) + + for v in itertools.chain(sub_locals.values(), closure_cells.values()): + if not isinstance(v, VariableTracker): + unimplemented(f"unconverted arg {v}") + + code: types.CodeType = func.get_code() + if code.co_name in ("__setitem__", "__setattr__") and not ( + args + and isinstance( + args[0], + (variables.CustomizedDictVariable, variables.UserDefinedObjectVariable), + ) + ): + unimplemented(f"inline {code.co_name}") + + suffix = "" + # TODO: mlazos, add support for enabling multiple artifact logs + # with a single alias + if torch._logging._internal.log_state.is_artifact_enabled("bytecode"): + suffix = f"\n{dis.Bytecode(code).dis()}" + if sys.version_info >= (3, 11): + cur_inst = parent.current_instruction + parent_code = parent.f_code + header = parent.get_line_of_code_header(lineno=cur_inst.positions.lineno) + + def get_trace_call_log_str(): + line = get_instruction_source_311(parent_code, cur_inst).rstrip() + return f"TRACE inlined call {code.co_name} from {header}\n{line}" + + trace_call_log.debug("%s", LazyString(get_trace_call_log_str)) + log.debug("INLINING %s%s, %s", code, suffix, result.reason) + + # Detect inline GraphModule calls in order to propagate node metadata, + # by checking if the first argument (self) is a variable tracking a GraphModule. + if args and isinstance(args[0], NNModuleVariable): + module = parent.output.get_submodule(args[0].module_key) + if isinstance(module, torch.fx.GraphModule): + # The inline call might not actually be a call to `forward`, + # but it is enough to add a context for `forward` in case it is called. 
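The `orig_graphmodule` entry written just below (and read back in `create_call_resume_at` via `.get("orig_graphmodule", lambda: None)()`) relies on a small weakref idiom: store a `weakref.ref` and call whatever comes back, with `lambda: None` standing in when the key is absent. A standalone sketch with hypothetical names (in CPython, refcounting clears the referent as soon as the last strong reference is dropped):

import weakref

class Module:          # stand-in for a GraphModule
    pass

context = {}
module = Module()
context["orig_graphmodule"] = weakref.ref(module)

assert context.get("orig_graphmodule", lambda: None)() is module  # present and alive
assert context.get("missing_key", lambda: None)() is None         # absent -> None
del module                                                         # last strong ref gone
assert context["orig_graphmodule"]() is None                       # dead referent -> None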
+ code_context.get_context(module.forward.__code__)[ + "orig_graphmodule" + ] = weakref.ref(module) + + tracer: InliningInstructionTranslator + if is_generator(code): + tracer = InliningGeneratorInstructionTranslator( + parent, + code, + sub_locals, + parent.symbolic_globals, + parent.symbolic_torch_function_mode_stack, + closure_cells, + func, + ) + else: + tracer = InliningInstructionTranslator( + parent, + code, + sub_locals, + parent.symbolic_globals, + parent.symbolic_torch_function_mode_stack, + closure_cells, + func, + ) + + strict_ctx: Any = contextlib.nullcontext() + if parent.strict_checks_fn: + strict_ctx = tracer.strict_translation_mode(parent.strict_checks_fn) + try: + with strict_ctx: + tracer.run() + except exc.ObservedException as e: + msg = f"Observed exception DURING INLING {code} : {e}" + # TODO(anijain2305) - This works but we should probably have a + # global/central data structure for the exception stack. + parent.exn_vt_stack.extend(tracer.exn_vt_stack) + log.debug(msg) + # bubble up the exception to the parent frame. + raise + except exc.SkipFrame as e: + msg = f"SKIPPED INLINING {code}: {e}" + log.debug(msg) + raise Unsupported(msg) from e + except Exception as e: + log.debug("FAILED INLINING %s", code) + raise + assert tracer.symbolic_result is not None + func.export_freevars(parent, tracer) + + if tracer.f_globals is parent.f_globals: + # Merge symbolic_globals back if parent and child are in the same namespace + parent.symbolic_globals.update(tracer.symbolic_globals) + + parent.inconsistent_side_effects |= tracer.inconsistent_side_effects + + log.debug("DONE INLINING %s", code) + + if is_generator(code): + assert isinstance(tracer, InliningGeneratorInstructionTranslator) + assert tracer.symbolic_result.as_python_constant() is None + return ListIteratorVariable( + tracer.generated_items, + mutable_local=MutableLocal(), + ) + else: + return tracer.symbolic_result + + def __init__( + self, + parent: InstructionTranslatorBase, + code: types.CodeType, + symbolic_locals: Dict[str, VariableTracker], + symbolic_globals: Dict[str, VariableTracker], + symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"], + closure_cells: Dict[str, VariableTracker], + funcvar: BaseUserFunctionVariable, + ) -> None: + f_globals = funcvar.get_globals() # type: ignore[attr-defined] + f_builtins = f_globals["__builtins__"] + if not isinstance(f_builtins, dict): + f_builtins = f_builtins.__dict__ + instructions = cleaned_instructions(code) + propagate_line_nums(instructions) + super().__init__( + output=parent.output, + f_locals={}, + f_globals=f_globals, + f_builtins=f_builtins, + symbolic_locals=symbolic_locals, + symbolic_globals=symbolic_globals, + symbolic_torch_function_mode_stack=symbolic_torch_function_mode_stack, + instructions=instructions, + code_options={k: getattr(code, k) for k in get_code_keys()}, + f_code=code, + export=parent.export, + inline_depth=parent.inline_depth + 1, + speculation_log=parent.speculation_log, + distributed_state=parent.distributed_state, + ) + self.parent = parent + self.symbolic_result = None + self.closure_cells = closure_cells + self.nn_module_stack = parent.nn_module_stack.copy() + self.one_graph = parent.one_graph + + @property + def fake_mode(self): + return self.parent.fake_mode + + def run_ctx_mgr(self): + return TracingContext.current_frame(self.parent.frame_summary()) + + def STORE_DEREF(self, inst): + if inst.argval in self.closure_cells: + cell = self.closure_cells[inst.argval] + val = self.pop() + if isinstance(cell, 
ClosureVariable): + if not self.output.is_root_tracer(): + unimplemented( + "HigherOrderOperator: Mutating a variable not in the current scope (ClosureVariable)" + ) + self.output.root_tx.symbolic_locals[cell.name] = val + else: + self.output.side_effects.store_cell(cell, val) + else: + maybe_cell = self.symbolic_locals.get(inst.argval) + if isinstance( + maybe_cell, + variables.NewCellVariable, + ): + self.output.side_effects.store_cell( + self.symbolic_locals[inst.argval], self.pop() + ) + else: + if ( + maybe_cell is not None + and maybe_cell.source.name() + not in self.output.root_tx.mutated_closure_cell_contents + ): + # Why is the source name here unique? + # mutated_closure_cell_contents is a per-frame + # concept, and sources identify, e.g., particular + # locals from the frame. If you had two locals, + # they'll get different source names, and therefore + # differ here. + self.output.root_tx.mutated_closure_cell_contents.add( + maybe_cell.source.name() + ) + raise exc.UnspecializeRestartAnalysis + unimplemented("write to __closure__ while inlining") + + def LOAD_DEREF(self, inst): + if inst.argval in self.closure_cells: + cell = self.closure_cells[inst.argval] + if isinstance(cell, ClosureVariable): + self.push(self.output.root_tx.symbolic_locals[cell.name]) + else: + self.push(self.output.side_effects.load_cell(cell)) + else: + maybe_sym_local = self.symbolic_locals.get(inst.argval, None) + if isinstance(maybe_sym_local, variables.NewCellVariable): + self.push(self.output.side_effects.load_cell(maybe_sym_local)) + else: + super().LOAD_DEREF(inst) + + def _load_closure(self, name): + assert name in self.cell_and_freevars() + if name in self.closure_cells: + return self.closure_cells[name] + else: + return InlinedClosureVariable(name=name) + + def check_replace_is_safe(self, oldvar): + if not is_side_effect_safe(oldvar.mutable_local): + unimplemented( + "HigherOrderOperator: Mutating a variable not in the current scope (replace_all)" + ) + + def should_compile_partial_graph(self): + return False # inlining functions is all-or-nothing + + def create_call_resume_at(self, offset): + unimplemented("cant resume while inlining") + + def RETURN_VALUE(self, inst): + self.symbolic_result = self.pop() # type: ignore[assignment] + self.instruction_pointer = None + raise ReturnValueOp + + def RETURN_CONST(self, inst): + self.symbolic_result = self._load_const(inst) + self.instruction_pointer = None + raise ReturnValueOp + + def get_globals_source_and_value(self, name): + if "__name__" in self.f_globals: + module_name = self.f_globals["__name__"] + module_source = self.import_source(module_name) + if "torch_package" in module_name: + fglobals_value = torch.package.package_importer._package_imported_modules[module_name] # type: ignore[assignment] + else: + fglobals_value = importlib.import_module(module_name) # type: ignore[assignment] + fglobals_vt = VariableBuilder(self, module_source)(fglobals_value) + global_source = AttrSource(module_source, name) + else: + globals_name = self.output.install_global_by_id( + "___unnamed_scope", self.f_globals + ) + globals_source = GlobalSource(globals_name) + fglobals_value = self.f_globals # type: ignore[assignment] + fglobals_vt = VariableBuilder(self, globals_source)(fglobals_value) + global_source = GetItemSource(globals_source, name) # type: ignore[assignment] + return fglobals_value, fglobals_vt, global_source + + def _load_global(self, inst): + if self.output.global_scope is self.f_globals: + super()._load_global(inst) + else: + name = inst.argval + 
+ _, fglobals_vt, global_source = self.get_globals_source_and_value(name) + if self.output.side_effects.has_pending_mutation_of_attr(fglobals_vt, name): + self.push(self.output.side_effects.load_attr(fglobals_vt, name)) + else: + try: + value = self.f_globals[name] + except KeyError: + return self.load_builtin(inst) + + self.push(VariableBuilder(self, global_source)(value)) + + def STORE_GLOBAL(self, inst): + if self.f_globals is self.parent.f_globals: + super().STORE_GLOBAL(inst) + else: + value = self.pop() + if isinstance(value, RemovableHandleVariable): + unimplemented("Storing handles in globals - NYI") + name = inst.argval + fglobals_value, fglobals_vt, _ = self.get_globals_source_and_value(name) + self.output.side_effects.store_attr(fglobals_vt, name, value) + + +class InliningGeneratorInstructionTranslator(InliningInstructionTranslator): + generated_items: List[VariableTracker] + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.generated_items = [] + + def YIELD_VALUE(self, inst: Instruction): + self.generated_items.append(self.pop()) + if len(self.generated_items) > MAX_ITERATOR_LIMIT: + unimplemented( + "Too many yield values in generator. Maybe you are inlining an infinite generator. " + f"If not, please report a bug at {PT2_ISSUE_TRACKER_URL}", + ) + self.push(ConstantVariable.create(None)) + + def GET_YIELD_FROM_ITER(self, inst): + tos = self.stack[-1] + if not isinstance(tos, ListIteratorVariable): + self.pop() + res = BuiltinVariable(iter).call_function(self, [tos], {}) # type: ignore[arg-type] + self.push(res) + + def YIELD_FROM(self, inst): + assert len(self.stack) >= 2 + val = self.pop() + tos = self.stack[-1] + if not (isinstance(val, ConstantVariable) and val.value is None): + # invoke send + # Unreachable code - if you hit this, you are implementing generator support and have + # lifted the `unimplemented("generator")` in frame conversion. This codepath handles + # subgenerator and lines up with this line in Python 3.10 + # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L2599 + unimplemented("Unreachable sub-generator code") + + try: + val = tos.next_variable(self) + except (StopIteration, exc.ObservedUserStopIteration) as ex: + if isinstance(ex, exc.ObservedUserStopIteration): + exc.handle_observed_exception(self) + + # The iterator is exhausted. Stop the loop and return. + self.pop() + self.push(ConstantVariable.create(ex.value)) + else: + self.push(val) + # Add the value to yield into generated_items and replace the top of the stack with None + self.YIELD_VALUE(inst) + + # Repeat the YIELD_FROM instruction in the next eval loop + assert ( + isinstance(self.instruction_pointer, int) + and self.instruction_pointer > 0 + ) + self.instruction_pointer -= 1 + + def SEND(self, inst): + assert len(self.stack) >= 2 + val = self.pop() + tos = self.stack[-1] + if isinstance(tos, ListIteratorVariable) or ( + isinstance(tos, UserDefinedObjectVariable) + and isinstance(tos.value, collections.abc.Iterator) + ): + if isinstance(val, ConstantVariable) and val.value is None: + try: + val = tos.next_variable(self) + except (StopIteration, exc.ObservedUserStopIteration) as ex: + # To implement SEND, we have to look at the implementation + # when the iterator returns StopIteration. This translates to this code + # 3.11: https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2613-L2619 + # 3.12: https://github.com/python/cpython/blob/3.12/Python/bytecodes.c#L863-L866 + # The implementation is different in 3.11 and 3.12. 
In 3.12, we rely + # on END_SEND to clean up. In 3.11, SEND does the cleanup as well. + if sys.version_info < (3, 12): + self.pop() # Python 3.12 uses new opcode END_SEND + self.push(ConstantVariable.create(ex.value)) + self.jump(inst) + else: + self.push(val) + else: + # invoke send + # Unreachable code - if you hit this, you are implementing generator support and have + # lifted the `unimplemented("generator")` in frame conversion. This codepath handles + # subgenerator and lines up with this line in Python 3.11 + # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2597 + unimplemented("Unreachable sub-generator code") + else: + unimplemented(f"SEND {typestr(tos)}") diff --git a/lib/python3.10/site-packages/torch/_dynamo/tensor_version_op.py b/lib/python3.10/site-packages/torch/_dynamo/tensor_version_op.py new file mode 100644 index 0000000000000000000000000000000000000000..889b2450409f48bd0334dbe9e3d0197820a2c820 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/tensor_version_op.py @@ -0,0 +1,59 @@ +# mypy: allow-untyped-defs +import torch +from torch._prims import _make_prim, RETURN_TYPE +from torch._subclasses import FakeTensorMode +from torch._subclasses.functional_tensor import FunctionalTensorMode + + +_tensor_version = _make_prim( + schema="_tensor_version(Tensor self) -> SymInt", + return_type=RETURN_TYPE.NEW, + meta=torch.ops.aten._version.default, + impl_aten=torch.ops.aten._version.default, + doc="Tracable unbacked SymInt version of torch.Tensor._version", +) + + +@_tensor_version.py_impl(FakeTensorMode) +def _tensor_version_fake(fake_mode, self_tensor): + """ + The initial dynamo capture of _tensor_version + _unsafe_set_version_counter turns the + `._version` into an unbacked SymInt so that we don't need to specialize on the `._version` + of input tensors to the graph. + """ + return fake_mode.shape_env.create_unbacked_symint() + + +_unsafe_set_version_counter = _make_prim( + schema="_unsafe_set_version_counter(Tensor self, SymInt version) -> ()", + return_type=RETURN_TYPE.NEW, + meta=lambda self, version: None, + impl_aten=torch._C._autograd._unsafe_set_version_counter, + doc="Tracable+SymInt version of torch._C._autograd._unsafe_set_version_counter", +) +torch.fx.node.has_side_effect(_unsafe_set_version_counter) + + +""" +When we functionalize _tensor_version + _unsafe_set_version_counter, +the ops disappear from the traced graph. We run them eagerly on the +fake tensors used for tracing, in order to get past asserts that would +fail in autograd. + +Why is this ok? +1) Versions on functional tensors don't make any sense since you can't mutate a functional tensor. +2) The whole point of version munging is to trick autograd into doing what we want, and after + AotAtuograd there is no longer any need for these ops. + +Note this is similar to how no_grad is handled. 
+""" + + +@_tensor_version.py_impl(FunctionalTensorMode) +def _tensor_version_functional(mode, self): + return self._version + + +@_unsafe_set_version_counter.py_impl(FunctionalTensorMode) +def _unsafe_set_version_counter_functional(ctx, self, version): + torch._C._autograd._unsafe_set_version_counter(self, version) diff --git a/lib/python3.10/site-packages/torch/_dynamo/test_case.py b/lib/python3.10/site-packages/torch/_dynamo/test_case.py new file mode 100644 index 0000000000000000000000000000000000000000..81c0407833f67cb240f99d1d45ac4eda1b369028 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/test_case.py @@ -0,0 +1,75 @@ +# mypy: allow-untyped-defs +import contextlib +import importlib +import logging + +import torch +import torch.testing +from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] + IS_WINDOWS, + TEST_WITH_CROSSREF, + TEST_WITH_TORCHDYNAMO, + TestCase as TorchTestCase, +) + +from . import config, reset, utils + + +log = logging.getLogger(__name__) + + +def run_tests(needs=()): + from torch.testing._internal.common_utils import run_tests + + if TEST_WITH_TORCHDYNAMO or IS_WINDOWS or TEST_WITH_CROSSREF: + return # skip testing + + if isinstance(needs, str): + needs = (needs,) + for need in needs: + if need == "cuda": + if not torch.cuda.is_available(): + return + else: + try: + importlib.import_module(need) + except ImportError: + return + run_tests() + + +class TestCase(TorchTestCase): + _exit_stack: contextlib.ExitStack + + @classmethod + def tearDownClass(cls): + cls._exit_stack.close() + super().tearDownClass() + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._exit_stack = contextlib.ExitStack() # type: ignore[attr-defined] + cls._exit_stack.enter_context( # type: ignore[attr-defined] + config.patch( + raise_on_ctx_manager_usage=True, + suppress_errors=False, + log_compilation_metrics=False, + ), + ) + + def setUp(self): + self._prior_is_grad_enabled = torch.is_grad_enabled() + super().setUp() + reset() + utils.counters.clear() + + def tearDown(self): + for k, v in utils.counters.items(): + print(k, v.most_common()) + reset() + utils.counters.clear() + super().tearDown() + if self._prior_is_grad_enabled is not torch.is_grad_enabled(): + log.warning("Running test changed grad mode") + torch.set_grad_enabled(self._prior_is_grad_enabled) diff --git a/lib/python3.10/site-packages/torch/_dynamo/test_minifier_common.py b/lib/python3.10/site-packages/torch/_dynamo/test_minifier_common.py new file mode 100644 index 0000000000000000000000000000000000000000..b05542d578f43bd6f4eaadb919b2852233d84e2d --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/test_minifier_common.py @@ -0,0 +1,249 @@ +# mypy: allow-untyped-defs +import dataclasses +import io +import logging +import os +import re +import shutil +import subprocess +import sys +import tempfile +import traceback +from typing import Optional +from unittest.mock import patch + +import torch +import torch._dynamo +import torch._dynamo.test_case +from torch._dynamo.trace_rules import _as_posix_path +from torch.utils._traceback import report_compile_source_on_error + + +@dataclasses.dataclass +class MinifierTestResult: + minifier_code: str + repro_code: str + + def _get_module(self, t): + match = re.search(r"class Repro\(torch\.nn\.Module\):\s+([ ].*\n| *\n)+", t) + assert match is not None, "failed to find module" + r = match.group(0) + r = re.sub(r"\s+$", "\n", r, flags=re.MULTILINE) + r = re.sub(r"\n{3,}", "\n\n", r) + return r.strip() + + def 
minifier_module(self): + return self._get_module(self.minifier_code) + + def repro_module(self): + return self._get_module(self.repro_code) + + +class MinifierTestBase(torch._dynamo.test_case.TestCase): + DEBUG_DIR = tempfile.mkdtemp() + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._exit_stack.enter_context( # type: ignore[attr-defined] + torch._dynamo.config.patch(debug_dir_root=cls.DEBUG_DIR) + ) + # These configurations make new process startup slower. Disable them + # for the minification tests to speed them up. + cls._exit_stack.enter_context( # type: ignore[attr-defined] + torch._inductor.config.patch( + { + # https://github.com/pytorch/pytorch/issues/100376 + "pattern_matcher": False, + # multiprocess compilation takes a long time to warmup + "compile_threads": 1, + # https://github.com/pytorch/pytorch/issues/100378 + "cpp.vec_isa_ok": False, + } + ) + ) + + @classmethod + def tearDownClass(cls): + if os.getenv("PYTORCH_KEEP_TMPDIR", "0") != "1": + shutil.rmtree(cls.DEBUG_DIR) + else: + print(f"test_minifier_common tmpdir kept at: {cls.DEBUG_DIR}") + cls._exit_stack.close() # type: ignore[attr-defined] + + def _gen_codegen_fn_patch_code(self, device, bug_type): + assert bug_type in ("compile_error", "runtime_error", "accuracy") + return f"""\ +{torch._dynamo.config.codegen_config()} +{torch._inductor.config.codegen_config()} +torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_TESTING_ONLY = {bug_type!r} +""" + + def _maybe_subprocess_run(self, args, *, isolate, cwd=None): + if not isolate: + assert len(args) >= 2, args + assert args[0] == "python3", args + if args[1] == "-c": + assert len(args) == 3, args + code = args[2] + args = ["-c"] + else: + assert len(args) >= 2, args + with open(args[1]) as f: + code = f.read() + args = args[1:] + + # WARNING: This is not a perfect simulation of running + # the program out of tree. We only interpose on things we KNOW we + # need to handle for tests. If you need more stuff, you will + # need to augment this appropriately. + + # NB: Can't use save_config because that will omit some fields, + # but we must save and reset ALL fields + dynamo_config = torch._dynamo.config.shallow_copy_dict() + inductor_config = torch._inductor.config.shallow_copy_dict() + try: + stderr = io.StringIO() + log_handler = logging.StreamHandler(stderr) + log = logging.getLogger("torch._dynamo") + log.addHandler(log_handler) + try: + prev_cwd = _as_posix_path(os.getcwd()) + if cwd is not None: + cwd = _as_posix_path(cwd) + os.chdir(cwd) + with patch("sys.argv", args), report_compile_source_on_error(): + exec(code, {"__name__": "__main__", "__compile_source__": code}) + rc = 0 + except Exception: + rc = 1 + traceback.print_exc(file=stderr) + finally: + log.removeHandler(log_handler) + if cwd is not None: + os.chdir(prev_cwd) # type: ignore[possibly-undefined] + # Make sure we don't leave buggy compiled frames lying + # around + torch._dynamo.reset() + finally: + torch._dynamo.config.load_config(dynamo_config) + torch._inductor.config.load_config(inductor_config) + + # TODO: return a more appropriate data structure here + return subprocess.CompletedProcess( + args, + rc, + b"", + stderr.getvalue().encode("utf-8"), + ) + else: + if cwd is not None: + cwd = _as_posix_path(cwd) + return subprocess.run(args, capture_output=True, cwd=cwd, check=False) + + # Run `code` in a separate python process. + # Returns the completed process state and the directory containing the + # minifier launcher script, if `code` outputted it. 
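The `_run_test_code` helper defined next locates the repro directory by scanning the test's stderr for the path printed in front of `minifier_launcher.py`. A tiny illustration of that capture, using a made-up path:

import re

stderr = "Wrote minifier launcher to /tmp/dynamo_debug/run_0/minifier_launcher.py\n"
match = re.search(r"(\S+)minifier_launcher.py", stderr)
assert match is not None
assert match.group(1) == "/tmp/dynamo_debug/run_0/"   # everything up to the file name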
+ def _run_test_code(self, code, *, isolate): + proc = self._maybe_subprocess_run( + ["python3", "-c", code], isolate=isolate, cwd=self.DEBUG_DIR + ) + + print("test stdout:", proc.stdout.decode("utf-8")) + print("test stderr:", proc.stderr.decode("utf-8")) + repro_dir_match = re.search( + r"(\S+)minifier_launcher.py", proc.stderr.decode("utf-8") + ) + if repro_dir_match is not None: + return proc, repro_dir_match.group(1) + return proc, None + + # Runs the minifier launcher script in `repro_dir` + def _run_minifier_launcher(self, repro_dir, isolate, *, minifier_args=()): + self.assertIsNotNone(repro_dir) + launch_file = _as_posix_path(os.path.join(repro_dir, "minifier_launcher.py")) + with open(launch_file) as f: + launch_code = f.read() + self.assertTrue(os.path.exists(launch_file)) + + args = ["python3", launch_file, "minify", *minifier_args] + if not isolate: + args.append("--no-isolate") + launch_proc = self._maybe_subprocess_run(args, isolate=isolate, cwd=repro_dir) + print("minifier stdout:", launch_proc.stdout.decode("utf-8")) + stderr = launch_proc.stderr.decode("utf-8") + print("minifier stderr:", stderr) + self.assertNotIn("Input graph did not fail the tester", stderr) + + return launch_proc, launch_code + + # Runs the repro script in `repro_dir` + def _run_repro(self, repro_dir, *, isolate=True): + self.assertIsNotNone(repro_dir) + repro_file = _as_posix_path(os.path.join(repro_dir, "repro.py")) + with open(repro_file) as f: + repro_code = f.read() + self.assertTrue(os.path.exists(repro_file)) + + repro_proc = self._maybe_subprocess_run( + ["python3", repro_file], isolate=isolate, cwd=repro_dir + ) + print("repro stdout:", repro_proc.stdout.decode("utf-8")) + print("repro stderr:", repro_proc.stderr.decode("utf-8")) + return repro_proc, repro_code + + # Template for testing code. + # `run_code` is the code to run for the test case. + # `patch_code` is the code to be patched in every generated file; usually + # just use this to turn on bugs via the config + def _gen_test_code(self, run_code, repro_after, repro_level): + return f"""\ +import torch +import torch._dynamo +{_as_posix_path(torch._dynamo.config.codegen_config())} +{_as_posix_path(torch._inductor.config.codegen_config())} +torch._dynamo.config.repro_after = "{repro_after}" +torch._dynamo.config.repro_level = {repro_level} +torch._dynamo.config.debug_dir_root = "{_as_posix_path(self.DEBUG_DIR)}" +{run_code} +""" + + # Runs a full minifier test. + # Minifier tests generally consist of 3 stages: + # 1. Run the problematic code + # 2. Run the generated minifier launcher script + # 3. 
Run the generated repro script + # + # If possible, you should run the test with isolate=False; use + # isolate=True only if the bug you're testing would otherwise + # crash the process + def _run_full_test( + self, run_code, repro_after, expected_error, *, isolate, minifier_args=() + ) -> Optional[MinifierTestResult]: + if isolate: + repro_level = 3 + elif expected_error is None or expected_error == "AccuracyError": + repro_level = 4 + else: + repro_level = 2 + test_code = self._gen_test_code(run_code, repro_after, repro_level) + print("running test", file=sys.stderr) + test_proc, repro_dir = self._run_test_code(test_code, isolate=isolate) + if expected_error is None: + # Just check that there was no error + self.assertEqual(test_proc.returncode, 0) + self.assertIsNone(repro_dir) + return None + # NB: Intentionally do not test return code; we only care about + # actually generating the repro, we don't have to crash + self.assertIn(expected_error, test_proc.stderr.decode("utf-8")) + self.assertIsNotNone(repro_dir) + print("running minifier", file=sys.stderr) + minifier_proc, minifier_code = self._run_minifier_launcher( + repro_dir, isolate=isolate, minifier_args=minifier_args + ) + print("running repro", file=sys.stderr) + repro_proc, repro_code = self._run_repro(repro_dir, isolate=isolate) + self.assertIn(expected_error, repro_proc.stderr.decode("utf-8")) + self.assertNotEqual(repro_proc.returncode, 0) + return MinifierTestResult(minifier_code=minifier_code, repro_code=repro_code) diff --git a/lib/python3.10/site-packages/torch/_dynamo/testing.py b/lib/python3.10/site-packages/torch/_dynamo/testing.py new file mode 100644 index 0000000000000000000000000000000000000000..4922b521bada34aaf0653702078da7cf369c931b --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/testing.py @@ -0,0 +1,409 @@ +# mypy: allow-untyped-defs +import contextlib +import dis +import functools +import logging +import os.path +import random +import re +import sys +import types +import unittest +from typing import List, Optional, Sequence, Union +from unittest.mock import patch + +import torch +from torch import fx +from torch._dynamo.output_graph import OutputGraph + +from . import config, eval_frame, optimize_assert, reset +from .bytecode_transformation import ( + create_instruction, + debug_checks, + is_generator, + transform_code_object, +) +from .guards import CheckFunctionManager, CompileId, GuardedCode +from .utils import same + + +np: Optional[types.ModuleType] = None +try: + import numpy as np +except ModuleNotFoundError: + np = None + + +unsupported = eval_frame.unsupported +three = 3 + +log = logging.getLogger(__name__) + + +def clone_me(x): + if x is None: + return None + return x.detach().clone().requires_grad_(x.requires_grad) + + +def remove_optimized_module_prefix(name) -> str: + return re.sub(r"^_orig_mod[.]", "", name) + + +def collect_results(model, prediction, loss, example_inputs): + results = [] + results.append(prediction) + results.append(loss) + # if isinstance(loss, torch.Tensor) and loss.item() > 1: + # log.warning( + # f"High loss value alert - {loss:.2f}. Can result in unstable gradients." 
+ # ) + + grads = {} + params = {} + for name, param in model.named_parameters(): + if isinstance(model, eval_frame.OptimizedModule): + name = remove_optimized_module_prefix(name) + param_copy = param + grad = param.grad + # Treat None and zero grad as same + if param.grad is None: + grad = torch.zeros_like(param) + grads[name + ".grad"] = grad + params[name] = param_copy + results.append(grads) + results.append(params) + buffers = {} + for name, buffer in model.named_buffers(): + if isinstance(model, eval_frame.OptimizedModule): + name = remove_optimized_module_prefix(name) + buffers[name] = buffer + results.append(buffers) + for example in example_inputs: + if isinstance(example, (tuple, list)): + for inp in example: + if isinstance(inp, torch.Tensor): + results.append(inp.grad) + else: + if isinstance(example, torch.Tensor): + results.append(example.grad) + return results + + +def requires_bwd_pass(out): + if isinstance(out, torch.Tensor): + return out.requires_grad + elif isinstance(out, (list, tuple)): + return any(requires_bwd_pass(x) for x in out) + elif out is None: + return False + elif isinstance(out, int): + return False + raise NotImplementedError("Don't know how to reduce", type(out)) + + +def reduce_to_scalar_loss(out): + """Reduce the output of a model to get scalar loss""" + if isinstance(out, torch.Tensor): + # Mean does not work on integer tensors + return out.sum() / out.numel() + elif isinstance(out, (list, tuple)): + return sum(reduce_to_scalar_loss(x) for x in out) / len(out) + elif type(out).__name__ in ( + "MaskedLMOutput", + "Seq2SeqLMOutput", + "CausalLMOutputWithCrossAttentions", + ): + return reduce_to_scalar_loss(out.logits) + elif type(out).__name__ == "SquashedNormal": + return out.mean.sum() + elif isinstance(out, dict): + return sum(reduce_to_scalar_loss(value) for value in out.values()) / len( + out.keys() + ) + raise NotImplementedError("Don't know how to reduce", type(out)) + + +def debug_dir() -> str: + path = os.path.join(os.path.dirname(__file__), "../debug") + if not os.path.exists(path): + os.mkdir(path) + return path + + +def debug_dump(name, code: types.CodeType, extra="") -> None: + with open(os.path.join(debug_dir(), name), "w") as fd: + fd.write( + f"{dis.Bytecode(code).info()}\n\n{dis.Bytecode(code).dis()}\n\n{extra}\n" + ) + + +def debug_insert_nops( + frame, cache_size, hooks, _, *, skip: int = 0 +) -> Optional[GuardedCode]: + """used to debug jump updates""" + + def insert_nops(instructions, code_options): + instructions.insert(0, create_instruction("NOP")) + instructions.insert(0, create_instruction("NOP")) + + if is_generator(frame.f_code): + return None + + debug_checks(frame.f_code) + code = transform_code_object(frame.f_code, insert_nops) + graph = OutputGraph( + code_options={}, + compiler_fn=None, + root_tx=None, + export=False, + export_constraints=None, + frame_state={"_id": 0}, + # TODO: shouldn't this be f_locals/f_globals from frame? 
+ local_scope=locals(), + global_scope=globals(), + f_code=frame.f_code, + ) + + return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0)) + + +class CompileCounter: + def __init__(self): + self.frame_count = 0 + self.op_count = 0 + + def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + self.frame_count += 1 + for node in gm.graph.nodes: + if "call" in node.op: + self.op_count += 1 + return gm.forward + + def clear(self): + self.frame_count = 0 + self.op_count = 0 + + +class CompileCounterWithBackend: + def __init__(self, backend): + self.frame_count = 0 + self.op_count = 0 + self.backend = backend + self.graphs = [] + + def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + from .backends.registry import lookup_backend + + self.frame_count += 1 + for node in gm.graph.nodes: + if "call" in node.op: + self.op_count += 1 + self.graphs.append(gm) + return lookup_backend(self.backend)(gm, example_inputs) + + +# Equivalent to backend="eager", but also records graphs that +# we can assert on +class EagerAndRecordGraphs: + def __init__(self): + self.graphs = [] + + def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + self.graphs.append(gm) + return gm.forward + + +def strip_comment(code) -> str: + code = str(code) + return re.sub(r"(?m)^ *#.*\n?", "", code) + + +def remove_trailing_space(code) -> str: + return "\n".join([line.rstrip() for line in code.split("\n")]) + + +def normalize_gm(gm_str) -> str: + # strip comments as comments have path to files which may differ from + # system to system. + return remove_trailing_space(strip_comment(gm_str)) + + +def empty_line_normalizer(code: str) -> str: + """ + Normalize code: remove empty lines. + """ + normal_code = re.sub(r"[\r\n]+", "\n", code) + return normal_code + + +def standard_test( + self, + fn, + nargs, + expected_ops=None, + expected_ops_dynamic=None, + expected_frame_count=1, +): + if not config.assume_static_by_default and expected_ops_dynamic is not None: + expected_ops = expected_ops_dynamic + + actual = CompileCounter() + + args1 = [torch.randn(10, 10) for _ in range(nargs)] + args2 = [torch.randn(10, 10) for _ in range(nargs)] + correct1 = fn(*args1) + correct2 = fn(*args2) + reset() + opt_fn = optimize_assert(actual)(fn) + val1a = opt_fn(*args1) + val2a = opt_fn(*args2) + val1b = opt_fn(*args1) + val2b = opt_fn(*args2) + reset() + self.assertTrue(same(val1a, correct1)) + self.assertTrue(same(val1b, correct1)) + self.assertTrue(same(val2a, correct2)) + self.assertTrue(same(val2b, correct2)) + self.assertEqual(actual.frame_count, expected_frame_count) + if expected_ops is not None: + self.assertEqual(actual.op_count, expected_ops) + + +def dummy_fx_compile(gm: fx.GraphModule, example_inputs): + return gm.forward + + +def format_speedup(speedup, pvalue, is_correct=True, pvalue_threshold=0.1): + if not is_correct: + return "ERROR" + if pvalue > pvalue_threshold: + return f"{speedup:.3f}x SAME" + return f"{speedup:.3f}x p={pvalue:.2f}" + + +def rand_strided( + size: Sequence[int], + stride: Sequence[int], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + extra_size: int = 0, +): + needed_size = ( + sum((shape - 1) * stride for shape, stride in zip(size, stride)) + + 1 + + extra_size + ) + if dtype.is_floating_point: + if dtype.itemsize == 1: + """ + normal distribution kernel is not implemented for fp8.. + Workaround that by creating a fp16 tensor and then cast. 
+ """ + buffer = torch.randn(needed_size, dtype=torch.float16, device=device).to( + dtype=dtype + ) + else: + buffer = torch.randn(needed_size, dtype=dtype, device=device) + else: + buffer = torch.zeros(size=[needed_size], dtype=dtype, device=device) + return torch.as_strided(buffer, size, stride) + + +def _make_fn_with_patches(fn, *patches): + @functools.wraps(fn) + def _fn(*args, **kwargs): + with contextlib.ExitStack() as stack: + for module, attr, val in patches: + stack.enter_context(patch.object(module, attr, val)) + + return fn(*args, **kwargs) + + return _fn + + +def make_test_cls_with_patches( + cls, cls_prefix, fn_suffix, *patches, xfail_prop=None, decorator=lambda x: x +): + DummyTestClass = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {}) + DummyTestClass.__qualname__ = DummyTestClass.__name__ + + for name in dir(cls): + if name.startswith("test_"): + fn = getattr(cls, name) + if not callable(fn): + setattr(DummyTestClass, name, getattr(cls, name)) + continue + new_name = f"{name}{fn_suffix}" + new_fn = _make_fn_with_patches(fn, *patches) + new_fn.__name__ = new_name + if xfail_prop is not None and hasattr(fn, xfail_prop): + new_fn = unittest.expectedFailure(new_fn) + setattr(DummyTestClass, new_name, decorator(new_fn)) + # NB: Doesn't handle slots correctly, but whatever + elif not hasattr(DummyTestClass, name): + setattr(DummyTestClass, name, getattr(cls, name)) + + return DummyTestClass + + +# test Python 3.11+ specific features +def skipIfNotPy311(fn): + if sys.version_info >= (3, 11): + return fn + return unittest.skip(fn) + + +def skipIfNotPy312(fn): + if sys.version_info >= (3, 12): + return fn + return unittest.skip(fn) + + +def xfailIfPy312(fn): + if sys.version_info >= (3, 12): + return unittest.expectedFailure(fn) + return fn + + +def skipIfPy312(fn): + if sys.version_info >= (3, 12): + return unittest.skip(fn) + return fn + + +def requiresPy310(fn): + if sys.version_info >= (3, 10): + return fn + else: + unittest.skip(fn) + + +# Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py +# and test/dynamo/test_dynamic_shapes.py +def expectedFailureDynamic(fn): + fn._expected_failure_dynamic = True + return fn + + +# Controls tests generated in test/inductor/test_torchinductor_codegen_dynamic_shapes.py +def expectedFailureCodegenDynamic(fn): + fn._expected_failure_codegen_dynamic = True + return fn + + +# Controls test generated in test/inductor/test_cpp_wrapper.py +def expectedFailureDynamicWrapper(fn): + fn._expected_failure_dynamic_wrapper = True + return fn + + +def reset_rng_state(use_xla=False): + torch.manual_seed(1337) + random.seed(1337) + if np: + np.random.seed(1337) + if use_xla: + import torch_xla.core.xla_model as xm + + xm.set_rng_state(1337, str(xm.xla_device())) diff --git a/lib/python3.10/site-packages/torch/_dynamo/trace_rules.py b/lib/python3.10/site-packages/torch/_dynamo/trace_rules.py new file mode 100644 index 0000000000000000000000000000000000000000..01cc4dc51be129bf09c6c5235bfacae9583919d9 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/trace_rules.py @@ -0,0 +1,3637 @@ +# mypy: allow-untyped-defs +import _collections_abc +import _weakrefset +import abc +import builtins +import collections +import contextlib +import copy +import copyreg +import dataclasses +import enum +import functools +import importlib +import inspect +import linecache +import logging +import multiprocessing +import operator +import os +import posixpath +import random +import re +import selectors +import signal +import sys +import 
diff --git a/lib/python3.10/site-packages/torch/_dynamo/trace_rules.py b/lib/python3.10/site-packages/torch/_dynamo/trace_rules.py new file mode 100644 index 0000000000000000000000000000000000000000..01cc4dc51be129bf09c6c5235bfacae9583919d9 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/trace_rules.py @@ -0,0 +1,3637 @@ +# mypy: allow-untyped-defs +import _collections_abc +import _weakrefset +import abc +import builtins +import collections +import contextlib +import copy +import copyreg +import dataclasses +import enum +import functools +import importlib +import inspect +import linecache +import logging +import multiprocessing +import operator +import os +import posixpath +import random +import re +import selectors +import signal +import sys +import tempfile +import threading +import tokenize +import traceback +import types +import typing +import unittest +import weakref +from collections import defaultdict +from pathlib import Path +from typing import Any, Callable, cast, Dict, List, Optional, Set, Type, Union + +import torch +import torch._inductor.test_operators +import torch.distributed +import torch.utils._content_store +from torch.utils import _config_module + +from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX +from .utils import getfile, hashable, NP_SUPPORTED_MODULES, unwrap_if_wrapper +from .variables import ( + BuiltinVariable, + FunctionalCallVariable, + FunctorchHigherOrderVariable, + NestedUserFunctionVariable, + PolyfilledFunctionVariable, + SkipFunctionVariable, + TorchInGraphFunctionVariable, + UserFunctionVariable, + UserMethodVariable, +) + + +np: Optional[types.ModuleType] = None +try: + import numpy as np +except ModuleNotFoundError: + pass + + +if typing.TYPE_CHECKING: + from .variables.base import VariableTracker + + +""" +A note on skip/inline rules: + +Dynamo consults this file to determine whether a function should be inlined or skipped. + +A skip applies at the frame boundary, meaning dynamo either triggers a graph break +at the beginning of the frame or attempts to trace/inline the whole frame. When skipping +a frame, recursively called frames are still traced by dynamo unless also skipped. + +Skipfiles (skipped at the file level instead of function level) still apply on a +frame-by-frame boundary as dynamo traces, but apply to all functions in that file. + +@skip is a helper decorator that can be applied to your function to cause it to be +included here. + +Dynamo skip/inline rules & priorities are defined as follows: +* Inline is the default behavior and will be used unless explicitly skipped. +* Dynamo has two SKIPLISTs: BUILTIN_SKIPLIST and THIRDPARTY_SKIPLIST. + * BUILTIN_SKIPLIST contains builtin python modules, such as abc, collections, etc. + * THIRDPARTY_SKIPLIST contains common third party libraries, such as numpy, pandas, etc. +* Functions in these two SKIPLISTs are always skipped, except: + * They have an explicitly defined rule in `manual_torch_name_rule_map`; + * The corresponding python module has been put into MOD_INLINELIST. +* PyTorch (torch) is in the BUILTIN_SKIPLIST by default, but there are many cases + where we want to inline the functions under the torch namespace. + We should specify inline for the functions in `manual_torch_name_rule_map` or + put the corresponding python module into MOD_INLINELIST to make dynamo inline them. +* If you call functions under skipped modules/files, Dynamo will wrap these functions + as SkipFunctionVariable. There are a few functions (e.g., collections.OrderedDict) that + have special handling at SkipFunctionVariable.call_function. + +Overall: *_INLINELIST has precedence over *_SKIPLIST, which has precedence over DEFAULT (inline). + +To figure out what the behavior is, check the following list in order: +* `manual_torch_name_rule_map` (Inline if YES) +* MOD_INLINELIST (Inline if YES) +* BUILTIN_SKIPLIST & THIRDPARTY_SKIPLIST (Skip if YES) +* Inline by default + +In general, if you want to force inline a function or module, please consider adding +the function's python module to MOD_INLINELIST first. +Use the `manual_torch_name_rule_map` only when there are other functions under the same module that +you don't want to inline. +"""
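+# Editor's note (illustrative, not part of the upstream file): a minimal sketch of
+# the precedence documented in the note above. The helper name and its parameters
+# are hypothetical stand-ins for `manual_torch_name_rule_map`, MOD_INLINELIST and
+# BUILTIN_SKIPLIST/THIRDPARTY_SKIPLIST; only the lookup order is taken from the
+# note: manual rule map first, then MOD_INLINELIST, then the skiplists, then the
+# default of inlining.
+#
+#   def _sketch_inline_or_skip(qualname, module_name, rule_map, mod_inlinelist, skiplists):
+#       if qualname in rule_map:                # 1. explicit manual rule wins
+#           return rule_map[qualname]
+#       if module_name in mod_inlinelist:       # 2. module explicitly forced to inline
+#           return "inline"
+#       if any(module_name.startswith(m) for m in skiplists):
+#           return "skip"                       # 3. builtin / third-party skiplists
+#       return "inline"                         # 4. inline by default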
+ +""" +Map of function objects to their tracing rules (Dynamo variables). +* TorchInGraphFunctionVariable: The functions should be put into the FX graph or can be constant folded. E.g., + - torch.add: should be put into the FX graph. + - torch.is_floating_point: constant folded. +* SkipFunctionVariable: The objects should be skipped from tracing. +* UserFunctionVariable: The functions should be inlined. + +For developers: If you add/remove a torch level API, it may trigger failures from +test/dynamo/test_trace_rules.py:test_torch_name_rule_map_updated. To fix the failures: +If you are adding a new torch level API or Dynamo implementation: +* Add the name with the corresponding tracing rule to this map + if you are adding a new in-graph function or Dynamo implementation for an existing function. +* Remove the object name from test/dynamo/test_trace_rules.ignored_c_binding_in_graph_function_names if it's there. + +If you are removing an existing torch level API: +* Remove the entry representing the API from this map or from test/dynamo/test_trace_rules.ignored_c_binding_in_graph_function_names, + depending on where it is. + + +""" +manual_torch_name_rule_map = { + "torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable, + "torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable, + "torch.overrides.is_tensor_like": TorchInGraphFunctionVariable, + "torch.jit.is_scripting": TorchInGraphFunctionVariable, + "torch.jit.is_tracing": TorchInGraphFunctionVariable, + "torch.jit.annotate": TorchInGraphFunctionVariable, + "torch.distributed.is_available": TorchInGraphFunctionVariable, + "torch.distributed.is_initialized": TorchInGraphFunctionVariable, + "torch.distributed.get_rank": TorchInGraphFunctionVariable, + "torch.distributed.get_world_size": TorchInGraphFunctionVariable, + "torch.distributed.tensor._api.DTensor#from_local": TorchInGraphFunctionVariable, + "torch.distributed.distributed_c10d._get_group_size_by_name": TorchInGraphFunctionVariable, + "torch.distributed.distributed_c10d._resolve_group_name_by_ranks_and_tag": TorchInGraphFunctionVariable, + "torch.distributed.distributed_c10d._get_group_tag": TorchInGraphFunctionVariable, + "torch.distributed.distributed_c10d.get_process_group_ranks": TorchInGraphFunctionVariable, + "torch._utils.is_compiling": TorchInGraphFunctionVariable, + "torch.fx._symbolic_trace.is_fx_tracing": TorchInGraphFunctionVariable, + "torch._dynamo.external_utils.is_compiling": TorchInGraphFunctionVariable, + "torch.compiler.is_compiling": TorchInGraphFunctionVariable, + "torch.compiler.is_dynamo_compiling": TorchInGraphFunctionVariable, + "torch.autograd._profiler_enabled": SkipFunctionVariable, + "torch._C._to_dlpack": SkipFunctionVariable, + "torch.to_dlpack": SkipFunctionVariable, + # We graph break on RNG state setters or getters like + # `torch.get_rng_state` or `torch.set_rng_state`. These functions + # are not aten operations and therefore they are completely ignored + # by the AOT dispatcher. As a result, the AOT graph does not have + # these setter or getter functions, producing an incorrect graph + # when it comes to rng states.
+ "torch.default_generator#get_state": SkipFunctionVariable, + "torch._C.Generator#get_state": SkipFunctionVariable, + "torch.get_rng_state": SkipFunctionVariable, + "torch.cuda.get_rng_state": SkipFunctionVariable, + "torch.default_generator#set_state": SkipFunctionVariable, + "torch._C.Generator#set_state": SkipFunctionVariable, + "torch.set_rng_state": SkipFunctionVariable, + "torch.cuda.set_rng_state": SkipFunctionVariable, + # https://github.com/pytorch/pytorch/issues/107187 + "torch.manual_seed": SkipFunctionVariable, + # https://github.com/pytorch/pytorch/issues/93501 + "torch.nn.utils.rnn.pack_padded_sequence": SkipFunctionVariable, + "torch.nn.Parameter": TorchInGraphFunctionVariable, + "torch.nn.Buffer": TorchInGraphFunctionVariable, + "torch._nested_tensor_from_mask": SkipFunctionVariable, + "torch._nested_from_padded": SkipFunctionVariable, + "torch.nested.nested_tensor_from_jagged": UserFunctionVariable, + # symbol operators implemented in Python + "torch.sym_not": TorchInGraphFunctionVariable, + "torch.sym_float": TorchInGraphFunctionVariable, + "torch.sym_int": TorchInGraphFunctionVariable, + "torch.sym_max": TorchInGraphFunctionVariable, + "torch.sym_min": TorchInGraphFunctionVariable, + "torch.sym_sqrt": TorchInGraphFunctionVariable, + "torch.sym_ite": TorchInGraphFunctionVariable, + "torch.Tensor#_make_wrapper_subclass": SkipFunctionVariable, + "torch.Tensor#__init__": SkipFunctionVariable, + "torch.cuda.set_device": SkipFunctionVariable, + "torch.cuda.current_device": SkipFunctionVariable, + "torch._C.autocast_decrement_nesting": SkipFunctionVariable, + "torch._C.autocast_increment_nesting": SkipFunctionVariable, + "torch.autograd.grad": SkipFunctionVariable, + "torch.autograd.backward": SkipFunctionVariable, + "torch._C.clear_autocast_cache": SkipFunctionVariable, + "torch.distributions.constraints.is_dependent": SkipFunctionVariable, + "torch.jit.isinstance": SkipFunctionVariable, + "torch._C.set_anomaly_enabled": SkipFunctionVariable, + "torch._C.set_autocast_cache_enabled": SkipFunctionVariable, + "torch._C.set_autocast_cpu_dtype": SkipFunctionVariable, + "torch._C.set_autocast_cpu_enabled": SkipFunctionVariable, + "torch._C.set_autocast_enabled": SkipFunctionVariable, + "torch._C.set_autocast_gpu_dtype": SkipFunctionVariable, + "torch._C.set_autocast_ipu_dtype": SkipFunctionVariable, + "torch._C.set_autocast_ipu_enabled": SkipFunctionVariable, + "torch._C.set_autocast_xla_dtype": SkipFunctionVariable, + "torch._C.set_autocast_xla_enabled": SkipFunctionVariable, + "torch.resize_as_": SkipFunctionVariable, + "torch.resize_as_sparse_": SkipFunctionVariable, + "torch.get_default_device": TorchInGraphFunctionVariable, + # functorch/vmap + "torch._functorch.vmap._check_int_or_none": UserFunctionVariable, + "torch._functorch.vmap._check_out_dims_is_int_or_int_pytree": UserFunctionVariable, + "torch._functorch.vmap._check_randomness_arg": UserFunctionVariable, + "torch._functorch.vmap._chunked_vmap": UserFunctionVariable, + "torch._functorch.vmap._concat_chunked_outputs": UserFunctionVariable, + "torch._functorch.vmap._create_batched_inputs": UserFunctionVariable, + "torch._functorch.vmap._flat_vmap": UserFunctionVariable, + "torch._functorch.vmap._flatten_chunks_output": UserFunctionVariable, + "torch._functorch.vmap._get_chunked_inputs": UserFunctionVariable, + "torch._functorch.vmap._get_name": UserFunctionVariable, + "torch._functorch.vmap._maybe_remove_batch_dim": UserFunctionVariable, + "torch._functorch.vmap._num_outputs": UserFunctionVariable, + 
"torch._functorch.vmap._process_batched_inputs": UserFunctionVariable, + "torch._functorch.vmap._unwrap_batched": UserFunctionVariable, + "torch._functorch.vmap._validate_and_get_batch_size": UserFunctionVariable, + "torch._functorch.vmap.doesnt_support_saved_tensors_hooks": UserFunctionVariable, + "torch._functorch.vmap.get_chunk_sizes": UserFunctionVariable, + # lazy_load_decompositions uses a lock that is not supported yet in dynamo + # "torch._functorch.vmap.lazy_load_decompositions": UserFunctionVariable, + "torch._functorch.vmap.restore_vmap": UserFunctionVariable, + "torch._functorch.apis.vmap": UserFunctionVariable, + "torch._functorch.vmap.unwrap_batched": UserFunctionVariable, + "torch._functorch.vmap.vmap_impl": FunctorchHigherOrderVariable, + "torch._functorch.vmap.wrap_batched": UserFunctionVariable, + # functorch/grad + "torch._functorch.eager_transforms.grad_impl": FunctorchHigherOrderVariable, + "torch._functorch.apis.grad_and_value": UserFunctionVariable, + "torch._functorch.eager_transforms._as_tuple": UserFunctionVariable, + "torch._functorch.eager_transforms._check_unique_non_empty": UserFunctionVariable, + "torch._functorch.eager_transforms._create_differentiable": UserFunctionVariable, + "torch._functorch.eager_transforms._slice_argnums": UserFunctionVariable, + "torch._functorch.eager_transforms._undo_create_differentiable": UserFunctionVariable, + "torch._functorch.eager_transforms._validate_and_wrap_argnum": UserFunctionVariable, + "torch._functorch.eager_transforms._validate_and_wrap_argnums": UserFunctionVariable, + "torch._functorch.eager_transforms._wrap_all_tensors": UserFunctionVariable, + "torch._functorch.eager_transforms._wrap_tensor_for_grad": UserFunctionVariable, + # functorch/jacrev + "torch._functorch.eager_transforms.jacrev": FunctorchHigherOrderVariable, + "torch._functorch.eager_transforms.error_if_complex": UserFunctionVariable, + "torch._functorch.eager_transforms._chunked_standard_basis_for_": UserFunctionVariable, + "torch._functorch.eager_transforms._safe_zero_index": UserFunctionVariable, + # functorch/vjp + "torch._functorch.eager_transforms.vjp": FunctorchHigherOrderVariable, + "torch._functorch.eager_transforms._vjp_with_argnums": UserFunctionVariable, + "torch._functorch.eager_transforms.assert_non_empty_tensor_output": UserFunctionVariable, + # functorch/jvp + "torch._functorch.eager_transforms._jvp_with_argnums": UserFunctionVariable, + "torch._functorch.eager_transforms.jvp": FunctorchHigherOrderVariable, + "torch._functorch.eager_transforms._replace_args": UserFunctionVariable, + "torch._functorch.eager_transforms.safe_unpack_dual": UserFunctionVariable, + "torch._functorch.eager_transforms.assert_non_empty_list_of_tensors": UserFunctionVariable, + "torch._functorch.eager_transforms.assert_output_is_tensor_or_tensors": UserFunctionVariable, + "torch.autograd.forward_ad.enter_dual_level": UserFunctionVariable, + "torch.autograd.forward_ad.exit_dual_level": UserFunctionVariable, + "torch.autograd.forward_ad.make_dual": UserFunctionVariable, + "torch.autograd.forward_ad.unpack_dual": UserFunctionVariable, + # functorch/linearize + "torch._functorch.eager_transforms.linearize": FunctorchHigherOrderVariable, + # functorch/jacfwd + "torch._functorch.eager_transforms.jacfwd": FunctorchHigherOrderVariable, + "torch._functorch.eager_transforms._construct_standard_basis_for": UserFunctionVariable, + "torch._functorch.eager_transforms.safe_unflatten": UserFunctionVariable, + # functorch/hessian + "torch._functorch.eager_transforms.hessian": 
FunctorchHigherOrderVariable, + # functional_call + "torch._functorch.functional_call.functional_call": FunctionalCallVariable, + "torch.nn.utils.stateless._groupby_tensor": TorchInGraphFunctionVariable, + # functorch/deprecated + "torch._functorch.deprecated.jvp": UserFunctionVariable, + "torch._functorch.deprecated.hessian": UserFunctionVariable, + "torch._functorch.deprecated.jacfwd": UserFunctionVariable, + "torch._functorch.deprecated.jacrev": UserFunctionVariable, + "torch._functorch.deprecated.grad": UserFunctionVariable, + "torch._functorch.deprecated.grad_and_value": UserFunctionVariable, + "torch._functorch.deprecated.vjp": UserFunctionVariable, + # everything else + "torch._constrain_as_size": UserFunctionVariable, + "torch._tensor._convert": UserFunctionVariable, + "torch.jit._unwrap_optional": UserFunctionVariable, + "torch.backends.mha.get_fastpath_enabled": UserFunctionVariable, + "torch._C._functorch._add_batch_dim": TorchInGraphFunctionVariable, + "torch._C._functorch._remove_batch_dim": TorchInGraphFunctionVariable, + "torch._C._functorch._wrap_for_grad": TorchInGraphFunctionVariable, + "torch._C._functorch._unwrap_for_grad": TorchInGraphFunctionVariable, + "torch._C._functorch.maybe_current_level": TorchInGraphFunctionVariable, + "torch._C._functorch.is_batchedtensor": TorchInGraphFunctionVariable, + "torch._dynamo.mark_static": UserFunctionVariable, + "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable, + "torch.cuda._get_device_properties": TorchInGraphFunctionVariable, + "torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable, + "torch.sparse_bsc_tensor": SkipFunctionVariable, + "torch.sparse_bsr_tensor": SkipFunctionVariable, + "torch.sparse_csc_tensor": SkipFunctionVariable, + "torch.sparse_csr_tensor": SkipFunctionVariable, + "torch.sparse_compressed_tensor": SkipFunctionVariable, + "torch._C._autograd._unsafe_set_version_counter": TorchInGraphFunctionVariable, + # avoid skipping user defined modules in distributed unit tests + "torch/testing/_internal/common_fsdp.py#forward": UserFunctionVariable, + f"torch/testing/_internal/common_fsdp.py#{TORCH_DYNAMO_RESUME_IN_PREFIX}": UserFunctionVariable, + "torch/testing/_internal/distributed/_tensor/common_dtensor.py#forward": UserFunctionVariable, + f"torch/testing/_internal/distributed/_tensor/common_dtensor.py#{TORCH_DYNAMO_RESUME_IN_PREFIX}": UserFunctionVariable, + "torch/testing/_internal/common_distributed.py#forward": UserFunctionVariable, + f"torch/testing/_internal/common_distributed.py#{TORCH_DYNAMO_RESUME_IN_PREFIX}": UserFunctionVariable, +} + + +# In graph functions (including constant folding) that are C bindings +torch_c_binding_in_graph_functions = dict.fromkeys( + [ + "math.acos", + "math.acosh", + "math.asin", + "math.asinh", + "math.atan", + "math.atan2", + "math.atanh", + "math.ceil", + "math.comb", + "math.copysign", + "math.cos", + "math.cosh", + "math.degrees", + "math.dist", + "math.erf", + "math.erfc", + "math.exp", + "math.expm1", + "math.fabs", + "math.factorial", + "math.floor", + "math.fmod", + "math.frexp", + "math.fsum", + "math.gamma", + "math.gcd", + "math.hypot", + "math.isclose", + "math.isfinite", + "math.isinf", + "math.isnan", + "math.isqrt", + "math.ldexp", + "math.lgamma", + "math.log", + "math.log10", + "math.log1p", + "math.log2", + "math.modf", + "math.nextafter", + "math.perm", + "math.pow", + "math.prod", + "math.radians", + "math.remainder", + "math.sin", + "math.sinh", + "math.tan", + "math.tanh", + "math.trunc", + "math.ulp", + 
"torch._adaptive_avg_pool2d", + "torch._adaptive_avg_pool3d", + "torch._add_batch_dim", + "torch._add_relu_", + "torch._add_relu", + "torch._addmm_activation", + "torch._aminmax", + "torch._amp_foreach_non_finite_check_and_unscale_", + "torch._amp_update_scale_", + "torch._assert_async", + "torch._assert_tensor_metadata", + "torch._batch_norm_impl_index", + "torch._C._activate_gpu_trace", + "torch._C._add_cached_tensor", + "torch._C._add_docstr", + "torch._C._are_functorch_transforms_active", + "torch._C._autograd_init", + "torch._C._awaitable_nowait", + "torch._C._awaitable_wait", + "torch._C._awaitable", + "torch._C._backport_for_mobile_from_buffer_to_buffer", + "torch._C._backport_for_mobile_from_buffer", + "torch._C._backport_for_mobile_to_buffer", + "torch._C._backport_for_mobile", + "torch._C._broadcast_coalesced", + "torch._C._broadcast_out", + "torch._C._broadcast", + "torch._C._c10d_init", + "torch._C._calculate_package_version_based_on_upgraders", + "torch._C._can_use_flash_attention", + "torch._C._can_use_mem_efficient_attention", + "torch._C._can_use_cudnn_attention", + "torch._C._check_onnx_proto", + "torch._C._check_sparse_tensor_invariants", + "torch._C._collect_all", + "torch._C._commit_update", + "torch._C._compile_graph_to_code_table", + "torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata", + "torch._C._construct_storage_from_data_pointer", + "torch._C._conv_determine_backend_memory_format", + "torch._C._cpu._is_avx2_supported", + "torch._C._cpu._is_avx512_supported", + "torch._C._cpu._is_avx512_vnni_supported", + "torch._C._cpu._is_avx512_bf16_supported", + "torch._C._cpu._is_amx_tile_supported", + "torch._C._cpu._init_amx", + "torch._C._crash_if_aten_asan", + "torch._C._crash_if_csrc_asan", + "torch._C._crash_if_csrc_ubsan", + "torch._C._crash_if_debug_asserts_fail", + "torch._C._crash_if_vptr_ubsan", + "torch._C._create_function_from_graph", + "torch._C._create_function_from_trace_with_dict", + "torch._C._create_function_from_trace", + "torch._C._create_graph_by_tracing", + "torch._C._create_module_with_type", + "torch._C._create_object_with_type", + "torch._C._cuda_attach_out_of_memory_observer", + "torch._C._cuda_beginAllocateCurrentStreamToPool", + "torch._C._cuda_canDeviceAccessPeer", + "torch._C._cuda_changeCurrentAllocator", + "torch._C._cuda_checkPoolLiveAllocations", + "torch._C._cuda_clearCublasWorkspaces", + "torch._C._cuda_cudaCachingAllocator_raw_alloc", + "torch._C._cuda_cudaCachingAllocator_raw_delete", + "torch._C._cuda_cudaCachingAllocator_set_allocator_settings", + "torch._C._cuda_cudaHostAllocator", + "torch._C._cuda_customAllocator", + "torch._C._cuda_emptyCache", + "torch._C._cuda_endAllocateCurrentStreamToPool", + "torch._C._cuda_exchangeDevice", + "torch._C._cuda_get_conv_benchmark_empty_cache", + "torch._C._cuda_get_cudnn_benchmark_limit", + "torch._C._cuda_get_sync_debug_mode", + "torch._C._cuda_getAllocator", + "torch._C._cuda_getAllocatorBackend", + "torch._C._cuda_getArchFlags", + "torch._C._cuda_getCheckpointState", + "torch._C._cuda_getCompiledVersion", + "torch._C._cuda_getCurrentBlasHandle", + "torch._C._cuda_getCurrentRawStream", + "torch._C._cuda_getCurrentStream", + "torch._C._cuda_getDefaultStream", + "torch._C._cuda_getDevice", + "torch._C._cuda_getDeviceCount", + "torch._C._cuda_hasPrimaryContext", + "torch._C._cuda_init", + "torch._C._cuda_ipc_collect", + "torch._C._cuda_isCurrentStreamCapturing", + "torch._C._cuda_isHistoryEnabled", + "torch._C._cuda_isInBadFork", + "torch._C._cuda_jiterator_compile_and_launch_kernel", + 
"torch._C._cuda_lock_mutex", + "torch._C._cuda_maybeExchangeDevice", + "torch._C._cuda_memorySnapshot", + "torch._C._cuda_memoryStats", + "torch._C._cuda_record_memory_history_legacy", + "torch._C._cuda_record_memory_history", + "torch._C._cuda_releasePool", + "torch._C._cuda_resetAccumulatedMemoryStats", + "torch._C._cuda_resetPeakMemoryStats", + "torch._C._cuda_set_cudnn_benchmark_limit", + "torch._C._cuda_set_sync_debug_mode", + "torch._C._cuda_setCheckpointPoolState", + "torch._C._cuda_setDevice", + "torch._C._cuda_setMemoryFraction", + "torch._C._cuda_setStream", + "torch._C._cuda_sleep", + "torch._C._cuda_synchronize", + "torch._C._cuda_unlock_mutex", + "torch._C._cudnn_set_conv_benchmark_empty_cache", + "torch._C._cudnn.getCompileVersion", + "torch._C._cudnn.getRuntimeVersion", + "torch._C._cudnn.getVersionInt", + "torch._C._current_autograd_node", + "torch._C._current_graph_task_execution_order", + "torch._C._current_graph_task_id", + "torch._C._cxx_flags", + "torch._C._debug_get_fusion_group_inlining", + "torch._C._debug_only_are_vmap_fallback_warnings_enabled", + "torch._C._debug_only_display_vmap_fallback_warnings", + "torch._C._debug_set_autodiff_subgraph_inlining", + "torch._C._debug_set_fusion_group_inlining", + "torch._C._demangle", + "torch._C._disabled_torch_dispatch_impl", + "torch._C._disabled_torch_function_impl", + "torch._C._dispatch_call_boxed", + "torch._C._dispatch_check_all_invariants", + "torch._C._dispatch_check_invariants", + "torch._C._dispatch_dump_table", + "torch._C._dispatch_dump", + "torch._C._dispatch_find_dangling_impls", + "torch._C._dispatch_find_schema_or_throw", + "torch._C._dispatch_get_all_op_names", + "torch._C._dispatch_get_backend_keyset_from_autograd", + "torch._C._dispatch_get_registrations_for_dispatch_key", + "torch._C._dispatch_has_backend_fallback", + "torch._C._dispatch_has_computed_kernel_for_dispatch_key", + "torch._C._dispatch_has_kernel_for_any_dispatch_key", + "torch._C._dispatch_has_kernel_for_dispatch_key", + "torch._C._dispatch_has_kernel", + "torch._C._dispatch_is_alias_key", + "torch._C._dispatch_is_included_in_alias", + "torch._C._dispatch_is_main_interpreter", + "torch._C._dispatch_isTensorSubclassLike", + "torch._C._dispatch_key_for_device", + "torch._C._dispatch_key_name", + "torch._C._dispatch_key_parse", + "torch._C._dispatch_key_set", + "torch._C._dispatch_keys", + "torch._C._dispatch_keyset_full_after", + "torch._C._dispatch_keyset_full", + "torch._C._dispatch_keyset_to_string", + "torch._C._dispatch_library", + "torch._C._dispatch_num_backends", + "torch._C._dispatch_print_registrations_for_dispatch_key", + "torch._C._dispatch_pystub", + "torch._C._dispatch_set_report_error_callback", + "torch._C._dispatch_tls_is_dispatch_key_excluded", + "torch._C._dispatch_tls_is_dispatch_key_included", + "torch._C._dispatch_tls_local_exclude_set", + "torch._C._dispatch_tls_local_include_set", + "torch._C._dispatch_tls_set_dispatch_key_excluded", + "torch._C._dispatch_tls_set_dispatch_key_included", + "torch._C._dist_autograd_init", + "torch._C._dump_local_tls_set", + "torch._C._dump_upgraders_map", + "torch._C._enable_mobile_interface_call_export", + "torch._C._enter_dual_level", + "torch._C._error_if_any_worker_fails", + "torch._C._exit_dual_level", + "torch._C._export_operator_list", + "torch._C._export_opnames", + "torch._C._faulty_agent_init", + "torch._C._fft.fft_fft", + "torch._C._fft.fft_fft2", + "torch._C._fft.fft_fftfreq", + "torch._C._fft.fft_fftn", + "torch._C._fft.fft_fftshift", + "torch._C._fft.fft_hfft", + 
"torch._C._fft.fft_hfft2", + "torch._C._fft.fft_hfftn", + "torch._C._fft.fft_ifft", + "torch._C._fft.fft_ifft2", + "torch._C._fft.fft_ifftn", + "torch._C._fft.fft_ifftshift", + "torch._C._fft.fft_ihfft", + "torch._C._fft.fft_ihfft2", + "torch._C._fft.fft_ihfftn", + "torch._C._fft.fft_irfft", + "torch._C._fft.fft_irfft2", + "torch._C._fft.fft_irfftn", + "torch._C._fft.fft_rfft", + "torch._C._fft.fft_rfft2", + "torch._C._fft.fft_rfftfreq", + "torch._C._fft.fft_rfftn", + "torch._C._free_And_Remove_DeleterFn", + "torch._C._freeze_module", + "torch._C._from_dlpack", + "torch._C._functionality_to_backend_keys", + "torch._C._functionalization_reapply_views_tls", + "torch._C._fuse_to_static_module", + "torch._C._gather_out", + "torch._C._gather", + "torch._C._generate_upgraders_graph", + "torch._C._get_autograd_fallback_mode", + "torch._C._get_backcompat_broadcast_warn", + "torch._C._get_backcompat_keepdim_warn", + "torch._C._get_blas_preferred_backend", + "torch._C._get_caught_jit_exception_class_name", + "torch._C._get_caught_jit_exception_original_msg", + "torch._C._get_constant_bool_symnode", + "torch._C._get_cpp_backtrace", + "torch._C._get_cpu_capability", + "torch._C._get_cublas_allow_bf16_reduced_precision_reduction", + "torch._C._get_cublas_allow_fp16_reduced_precision_reduction", + "torch._C._get_cublas_allow_tf32", + "torch._C._get_cudnn_allow_tf32", + "torch._C._get_cudnn_benchmark", + "torch._C._get_cudnn_deterministic", + "torch._C._get_cudnn_enabled", + "torch._C._get_custom_class_python_wrapper", + "torch._C._get_default_device", + "torch._C._get_deterministic_algorithms_warn_only", + "torch._C._get_deterministic_algorithms", + "torch._C._get_deterministic_fill_uninitialized_memory", + "torch._C._get_dispatch_mode", + "torch._C._get_dispatch_stack_at", + "torch._C._get_file_format", + "torch._C._get_flash_sdp_enabled", + "torch._C._get_float32_matmul_precision", + "torch._C._get_function_stack_at", + "torch._C._get_graph_executor_optimize", + "torch._C._get_linalg_preferred_backend", + "torch._C._get_math_sdp_enabled", + "torch._C._get_math_sdp_allow_fp16_bf16_reduction", + "torch._C._get_max_operator_version", + "torch._C._get_mem_efficient_sdp_enabled", + "torch._C._get_mkldnn_enabled", + "torch._C._get_cudnn_sdp_enabled", + "torch._C._set_sdp_use_cudnn", + "torch._C._get_mobile_model_contained_types_from_buffer", + "torch._C._get_mobile_model_contained_types", + "torch._C._get_model_bytecode_version_from_buffer", + "torch._C._get_model_bytecode_version", + "torch._C._get_model_extra_files_from_buffer", + "torch._C._get_model_extra_files", + "torch._C._get_model_ops_and_info_from_buffer", + "torch._C._get_model_ops_and_info", + "torch._C._get_module_info_from_flatbuffer", + "torch._C._get_nnpack_enabled", + "torch._C._get_obj_in_tls", + "torch._C._get_operation_overload", + "torch._C._get_operator_version_map", + "torch._C._get_privateuse1_backend_name", + "torch._C._get_qengine", + "torch._C._get_schema", + "torch._C._get_nested_int", + "torch._C._get_tensor_metadata", + "torch._C._get_tracing_state", + "torch._C._get_upgrader_ranges", + "torch._C._get_upgraders_entry_map", + "torch._C._get_upgraders_map_size", + "torch._C._get_value_trace", + "torch._C._get_version_calculator_flag", + "torch._C._get_warnAlways", + "torch._C._graph_pool_handle", + "torch._C._group_tensors_by_device_and_dtype", + "torch._C._hack_do_not_use_clone_module_with_class", + "torch._C._has_distributed", + "torch._C._has_Standard_Deleter", + "torch._C._has_storage", + "torch._C._has_tensorexpr_cpp_tests", 
+ "torch._C._run_tensorexpr_cpp_tests", + "torch._C._has_torch_function_unary", + "torch._C._has_torch_function_variadic", + "torch._C._has_torch_function", + "torch._C._import_ir_module_from_package", + "torch._C._increment_version", + "torch._C._infer_size", + "torch._C._init_names", + "torch._C._initExtension", + "torch._C._is_alias_of", + "torch._C._is_any_autocast_enabled", + "torch._C._is_cached_tensor", + "torch._C._is_flash_attention_available", + "torch._C._is_fwd_grad_enabled", + "torch._C._is_key_in_tls", + "torch._C._is_multithreading_enabled", + "torch._C._is_torch_function_enabled", + "torch._C._is_torch_function_mode_enabled", + "torch._C._is_tracing", + "torch._C._is_view_replay_enabled", + "torch._C._is_xnnpack_enabled", + "torch._C._itt.is_available", + "torch._C._itt.mark", + "torch._C._itt.rangePop", + "torch._C._itt.rangePush", + "torch._C._ivalue_debug_python_object", + "torch._C._ivalue_tags_match", + "torch._C._jit_assert_is_instance", + "torch._C._jit_can_fuse_on_cpu_legacy", + "torch._C._jit_can_fuse_on_cpu", + "torch._C._jit_can_fuse_on_gpu", + "torch._C._jit_cat_wo_conditionals", + "torch._C._jit_check_alias_annotation", + "torch._C._jit_clear_class_registry", + "torch._C._jit_debug_fuser_num_cached_kernel_specs", + "torch._C._jit_debug_module_iterators", + "torch._C._jit_decay_packed_param_input_types", + "torch._C._jit_decomposition_graph_for_node", + "torch._C._jit_differentiate", + "torch._C._jit_erase_non_input_shape_information", + "torch._C._jit_flatten", + "torch._C._jit_fuser_get_fused_kernel_code", + "torch._C._jit_get_all_schemas", + "torch._C._jit_get_custom_class_schemas", + "torch._C._jit_get_emit_hooks", + "torch._C._jit_get_inline_everything_mode", + "torch._C._jit_get_logging_option", + "torch._C._jit_get_num_profiled_runs", + "torch._C._jit_get_operation", + "torch._C._jit_get_schemas_for_operator", + "torch._C._jit_get_te_cuda_pointwise_block_count", + "torch._C._jit_get_te_cuda_pointwise_block_size", + "torch._C._jit_get_te_cuda_pointwise_loop_levels", + "torch._C._jit_get_te_generate_block_code", + "torch._C._jit_get_te_must_use_llvm_cpu", + "torch._C._jit_get_tracer_state_warn", + "torch._C._jit_has_cpp_tests", + "torch._C._jit_init", + "torch._C._jit_interpret_graph", + "torch._C._jit_is_onnx_log_enabled", + "torch._C._jit_is_script_object", + "torch._C._jit_llga_enabled", + "torch._C._jit_nvfuser_can_be_enabled", + "torch._C._jit_nvfuser_clear_comparison_callback", + "torch._C._jit_nvfuser_enabled", + "torch._C._jit_nvfuser_horizontal_mode", + "torch._C._jit_nvfuser_set_comparison_callback", + "torch._C._jit_nvfuser_single_node_mode", + "torch._C._jit_object_is_non_holding", + "torch._C._jit_onnx_convert_pattern_from_subblock", + "torch._C._jit_onnx_create_full_scope_name", + "torch._C._jit_onnx_list_model_parameters", + "torch._C._jit_onnx_log", + "torch._C._jit_opt_conditionals", + "torch._C._jit_override_can_fuse_on_cpu_legacy", + "torch._C._jit_override_can_fuse_on_cpu", + "torch._C._jit_override_can_fuse_on_gpu", + "torch._C._jit_pass_autocast", + "torch._C._jit_pass_batch_mm", + "torch._C._jit_pass_canonicalize_graph_fuser_ops", + "torch._C._jit_pass_canonicalize", + "torch._C._jit_pass_complete_shape_analysis", + "torch._C._jit_pass_concat_frozen_linear", + "torch._C._jit_pass_constant_loop_unrolling", + "torch._C._jit_pass_constant_pooling", + "torch._C._jit_pass_constant_propagation_immutable_types", + "torch._C._jit_pass_constant_propagation", + "torch._C._jit_pass_convert_frozen_ops_to_mkldnn", + 
"torch._C._jit_pass_create_autodiff_subgraphs", + "torch._C._jit_pass_create_functional_graphs", + "torch._C._jit_pass_cse", + "torch._C._jit_pass_custom_pattern_based_rewrite_graph", + "torch._C._jit_pass_custom_pattern_based_rewrite", + "torch._C._jit_pass_dbr_quant_remove_redundant_aliases", + "torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects", + "torch._C._jit_pass_dce", + "torch._C._jit_pass_decompose_ops", + "torch._C._jit_pass_dedup_module_uses", + "torch._C._jit_pass_erase_number_types", + "torch._C._jit_pass_erase_shape_information", + "torch._C._jit_pass_filter_non_tensor_arguments", + "torch._C._jit_pass_fixup_onnx_controlflow_node", + "torch._C._jit_pass_fold_convbn", + "torch._C._jit_pass_fold_frozen_conv_add_or_sub", + "torch._C._jit_pass_fold_frozen_conv_bn", + "torch._C._jit_pass_fold_frozen_conv_mul_or_div", + "torch._C._jit_pass_fold_frozen_linear_bn", + "torch._C._jit_pass_fold_prepacking_ops", + "torch._C._jit_pass_functional_to_inplace_activation", + "torch._C._jit_pass_fuse_add_relu", + "torch._C._jit_pass_fuse_addmm", + "torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv", + "torch._C._jit_pass_fuse_frozen_conv_add_relu", + "torch._C._jit_pass_fuse_linear", + "torch._C._jit_pass_fuse_quantized_add_relu", + "torch._C._jit_pass_fuse_tensorexprs", + "torch._C._jit_pass_fuse", + "torch._C._jit_pass_inline_fork_wait", + "torch._C._jit_pass_inline_functional_graphs", + "torch._C._jit_pass_inline", + "torch._C._jit_pass_inplace_to_functional_activation", + "torch._C._jit_pass_insert_observer_method_for_ondevice_ptq", + "torch._C._jit_pass_insert_observers", + "torch._C._jit_pass_insert_prepack_unpack", + "torch._C._jit_pass_insert_prepacked_ops", + "torch._C._jit_pass_insert_quant_dequant_for_ondevice_ptq", + "torch._C._jit_pass_insert_quant_dequant", + "torch._C._jit_pass_integer_value_refinement", + "torch._C._jit_pass_lint", + "torch._C._jit_pass_loop_unrolling", + "torch._C._jit_pass_lower_all_tuples", + "torch._C._jit_pass_lower_graph", + "torch._C._jit_pass_metal_fold_prepacking_ops", + "torch._C._jit_pass_metal_fuse_clamp_w_prepacked_conv", + "torch._C._jit_pass_metal_insert_prepacked_ops", + "torch._C._jit_pass_metal_optimize_for_mobile", + "torch._C._jit_pass_onnx_assign_output_shape", + "torch._C._jit_pass_onnx_assign_scoped_names_for_node_and_value", + "torch._C._jit_pass_onnx_autograd_function_process", + "torch._C._jit_pass_onnx_block", + "torch._C._jit_pass_onnx_cast_all_constant_to_floating", + "torch._C._jit_pass_onnx_clear_scope_records", + "torch._C._jit_pass_onnx_constant_fold", + "torch._C._jit_pass_onnx_deduplicate_initializers", + "torch._C._jit_pass_onnx_eliminate_unused_items", + "torch._C._jit_pass_onnx_eval_peephole", + "torch._C._jit_pass_onnx_function_extraction", + "torch._C._jit_pass_onnx_function_substitution", + "torch._C._jit_pass_onnx_graph_shape_type_inference", + "torch._C._jit_pass_onnx_lint", + "torch._C._jit_pass_onnx_node_shape_type_inference", + "torch._C._jit_pass_onnx_peephole", + "torch._C._jit_pass_onnx_preprocess_caffe2", + "torch._C._jit_pass_onnx_preprocess", + "torch._C._jit_pass_onnx_quantization_insert_permutes", + "torch._C._jit_pass_onnx_remove_inplace_ops_for_onnx", + "torch._C._jit_pass_onnx_remove_print", + "torch._C._jit_pass_onnx_scalar_type_analysis", + "torch._C._jit_pass_onnx_set_dynamic_input_shape", + "torch._C._jit_pass_onnx_track_scope_attributes", + "torch._C._jit_pass_onnx_unpack_quantized_weights", + "torch._C._jit_pass_onnx", + "torch._C._jit_pass_optimize_for_inference", + 
"torch._C._jit_pass_optimize_for_mobile", + "torch._C._jit_pass_optimize_frozen_graph", + "torch._C._jit_pass_pattern_based_rewrite", + "torch._C._jit_pass_peephole_list_idioms", + "torch._C._jit_pass_peephole", + "torch._C._jit_pass_prepare_division_for_onnx", + "torch._C._jit_pass_propagate_device", + "torch._C._jit_pass_propagate_dtype", + "torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute", + "torch._C._jit_pass_propagate_shapes_on_graph", + "torch._C._jit_pass_quant_finalize_for_ondevice_ptq", + "torch._C._jit_pass_quant_finalize", + "torch._C._jit_pass_quant_fusion", + "torch._C._jit_pass_refine_integer_values", + "torch._C._jit_pass_refine_tuple_types", + "torch._C._jit_pass_remove_dropout", + "torch._C._jit_pass_remove_expands", + "torch._C._jit_pass_remove_inplace_ops", + "torch._C._jit_pass_remove_mutation", + "torch._C._jit_pass_replace_old_ops_with_upgraders", + "torch._C._jit_pass_replicate_dequantize", + "torch._C._jit_pass_run_decompositions", + "torch._C._jit_pass_specialize_autogradzero", + "torch._C._jit_pass_swap_functional_linear", + "torch._C._jit_pass_transform_conv1d_to_conv2d", + "torch._C._jit_pass_transpose_frozen_linear", + "torch._C._jit_pass_vulkan_fold_prepacking_ops", + "torch._C._jit_pass_vulkan_fuse_clamp_w_prepacked_conv", + "torch._C._jit_pass_vulkan_insert_prepacked_ops", + "torch._C._jit_pass_vulkan_optimize_for_mobile", + "torch._C._jit_register_decomposition_for_schema", + "torch._C._jit_register_shape_compute_graph_for_node", + "torch._C._jit_resolve_packet", + "torch._C._jit_run_cpp_tests", + "torch._C._jit_script_class_compile", + "torch._C._jit_script_compile_overload", + "torch._C._jit_script_compile", + "torch._C._jit_script_interface_compile", + "torch._C._jit_set_autocast_mode", + "torch._C._jit_set_bailout_depth", + "torch._C._jit_set_emit_hooks", + "torch._C._jit_set_fusion_strategy", + "torch._C._jit_set_inline_everything_mode", + "torch._C._jit_set_llga_enabled", + "torch._C._jit_set_logging_option", + "torch._C._jit_set_logging_stream", + "torch._C._jit_set_num_profiled_runs", + "torch._C._jit_set_nvfuser_enabled", + "torch._C._jit_set_nvfuser_guard_mode", + "torch._C._jit_set_nvfuser_horizontal_mode", + "torch._C._jit_set_nvfuser_single_node_mode", + "torch._C._jit_set_nvfuser_skip_node_kind", + "torch._C._jit_set_onnx_log_enabled", + "torch._C._jit_set_onnx_log_output_stream", + "torch._C._jit_set_profiling_executor", + "torch._C._jit_set_profiling_mode", + "torch._C._jit_set_symbolic_shapes_test_mode", + "torch._C._jit_set_te_cuda_pointwise_block_count", + "torch._C._jit_set_te_cuda_pointwise_block_size", + "torch._C._jit_set_te_cuda_pointwise_loop_levels", + "torch._C._jit_set_te_generate_block_code", + "torch._C._jit_set_te_must_use_llvm_cpu", + "torch._C._jit_set_texpr_dynamic_shape_enabled", + "torch._C._jit_set_texpr_fuser_enabled", + "torch._C._jit_set_texpr_reductions_enabled", + "torch._C._jit_set_tracer_state_warn", + "torch._C._jit_set_utf8_decoding_ignore", + "torch._C._jit_shape_compute_graph_for_node", + "torch._C._jit_symbolic_shapes_test_mode_enabled", + "torch._C._jit_texpr_dynamic_shape_enabled", + "torch._C._jit_texpr_fallback_allowed", + "torch._C._jit_texpr_fuser_enabled", + "torch._C._jit_texpr_reductions_enabled", + "torch._C._jit_texpr_set_fallback_allowed", + "torch._C._jit_to_backend_selective", + "torch._C._jit_to_backend", + "torch._C._jit_to_static_module", + "torch._C._jit_trace_graph", + "torch._C._jit_trace_module", + "torch._C._jit_tree_views.FalseLiteral", + 
"torch._C._jit_tree_views.NoneLiteral", + "torch._C._jit_tree_views.TrueLiteral", + "torch._C._jit_try_infer_type", + "torch._C._jit_unflatten", + "torch._C._last_executed_optimized_graph", + "torch._C._len_torch_dispatch_stack", + "torch._C._len_torch_function_stack", + "torch._C._linalg._linalg_eigvals", + "torch._C._linalg.linalg_cholesky_ex", + "torch._C._linalg.linalg_cholesky", + "torch._C._linalg.linalg_cond", + "torch._C._linalg.linalg_cross", + "torch._C._linalg.linalg_det", + "torch._C._linalg.linalg_diagonal", + "torch._C._linalg.linalg_eig", + "torch._C._linalg.linalg_eigh", + "torch._C._linalg.linalg_eigvals", + "torch._C._linalg.linalg_eigvalsh", + "torch._C._linalg.linalg_householder_product", + "torch._C._linalg.linalg_inv_ex", + "torch._C._linalg.linalg_inv", + "torch._C._linalg.linalg_ldl_factor_ex", + "torch._C._linalg.linalg_ldl_factor", + "torch._C._linalg.linalg_ldl_solve", + "torch._C._linalg.linalg_lstsq", + "torch._C._linalg.linalg_lu_factor_ex", + "torch._C._linalg.linalg_lu_factor", + "torch._C._linalg.linalg_lu_solve", + "torch._C._linalg.linalg_lu", + "torch._C._linalg.linalg_matmul", + "torch._C._linalg.linalg_matrix_exp", + "torch._C._linalg.linalg_matrix_norm", + "torch._C._linalg.linalg_matrix_power", + "torch._C._linalg.linalg_matrix_rank", + "torch._C._linalg.linalg_multi_dot", + "torch._C._linalg.linalg_norm", + "torch._C._linalg.linalg_pinv", + "torch._C._linalg.linalg_qr", + "torch._C._linalg.linalg_slogdet", + "torch._C._linalg.linalg_solve_ex", + "torch._C._linalg.linalg_solve_triangular", + "torch._C._linalg.linalg_solve", + "torch._C._linalg.linalg_svd", + "torch._C._linalg.linalg_svdvals", + "torch._C._linalg.linalg_tensorinv", + "torch._C._linalg.linalg_tensorsolve", + "torch._C._linalg.linalg_vander", + "torch._C._linalg.linalg_vecdot", + "torch._C._linalg.linalg_vector_norm", + "torch._C._llvm_enabled", + "torch._C._load_for_lite_interpreter_from_buffer", + "torch._C._load_for_lite_interpreter", + "torch._C._load_jit_module_from_bytes", + "torch._C._load_jit_module_from_file", + "torch._C._load_mobile_module_from_bytes", + "torch._C._load_mobile_module_from_file", + "torch._C._log_api_usage_metadata", + "torch._C._log_api_usage_once", + "torch._C._logging_set_logger", + "torch._C._meta_in_tls_dispatch_include", + "torch._C._mps_acquireEvent", + "torch._C._mps_currentAllocatedMemory", + "torch._C._mps_deviceSynchronize", + "torch._C._mps_driverAllocatedMemory", + "torch._C._mps_recommendedMaxMemory", + "torch._C._mps_elapsedTimeOfEvents", + "torch._C._mps_emptyCache", + "torch._C._mps_get_default_generator", + "torch._C._mps_is_available", + "torch._C._mps_is_in_bad_fork", + "torch._C._mps_is_on_macos_13_or_newer", + "torch._C._mps_profilerStartTrace", + "torch._C._mps_profilerStopTrace", + "torch._C._mps_queryEvent", + "torch._C._mps_recordEvent", + "torch._C._mps_releaseEvent", + "torch._C._mps_setMemoryFraction", + "torch._C._mps_synchronizeEvent", + "torch._C._mps_waitForEvent", + "torch._C._multiprocessing_init", + "torch._C._nccl_all_gather", + "torch._C._nccl_all_reduce", + "torch._C._nccl_broadcast", + "torch._C._nccl_init_rank", + "torch._C._nccl_reduce_scatter", + "torch._C._nccl_reduce", + "torch._C._nccl_unique_id", + "torch._C._nccl_version_suffix", + "torch._C._nccl_version", + "torch._C._nested.nested_tensor", + "torch._C._nested.nested_to_padded_tensor", + "torch._C._new_symbolic_shape_symbol", + "torch._C._nn_module_to_mobile", + "torch._C._nn._conv_depthwise2d", + "torch._C._nn._pad_circular", + "torch._C._nn._pad_enum", + 
"torch._C._nn._parse_to", + "torch._C._nn._test_ambiguous_defaults", + "torch._C._nn._test_optional_filled_intlist", + "torch._C._nn._test_optional_floatlist", + "torch._C._nn._test_optional_intlist", + "torch._C._nn._test_string_default", + "torch._C._nn._test_warn_in_autograd", + "torch._C._nn._upsample_bicubic2d_aa", + "torch._C._nn._upsample_bilinear2d_aa", + "torch._C._nn._upsample_nearest_exact1d", + "torch._C._nn._upsample_nearest_exact2d", + "torch._C._nn._upsample_nearest_exact3d", + "torch._C._nn.adaptive_avg_pool2d", + "torch._C._nn.adaptive_avg_pool3d", + "torch._C._nn.adaptive_max_pool2d", + "torch._C._nn.adaptive_max_pool3d", + "torch._C._nn.avg_pool2d", + "torch._C._nn.avg_pool3d", + "torch._C._nn.binary_cross_entropy", + "torch._C._nn.col2im", + "torch._C._nn.conv_depthwise3d", + "torch._C._nn.cross_entropy_loss", + "torch._C._nn.elu_", + "torch._C._nn.elu", + "torch._C._nn.flatten_dense_tensors", + "torch._C._nn.fractional_max_pool2d", + "torch._C._nn.fractional_max_pool3d", + "torch._C._nn.gelu_", + "torch._C._nn.gelu", + "torch._C._nn.glu", + "torch._C._nn.hardsigmoid_", + "torch._C._nn.hardsigmoid", + "torch._C._nn.hardswish_", + "torch._C._nn.hardswish", + "torch._C._nn.hardtanh_", + "torch._C._nn.hardtanh", + "torch._C._nn.huber_loss", + "torch._C._nn.im2col", + "torch._C._nn.l1_loss", + "torch._C._nn.leaky_relu_", + "torch._C._nn.leaky_relu", + "torch._C._nn.linear", + "torch._C._nn.log_sigmoid", + "torch._C._nn.max_pool2d_with_indices", + "torch._C._nn.max_pool3d_with_indices", + "torch._C._nn.max_unpool2d", + "torch._C._nn.max_unpool3d", + "torch._C._nn.mish_", + "torch._C._nn.mish", + "torch._C._nn.mkldnn_linear", + "torch._C._nn.mkldnn_reorder_conv2d_weight", + "torch._C._nn.mkldnn_reorder_conv3d_weight", + "torch._C._nn.mse_loss", + "torch._C._nn.multi_margin_loss", + "torch._C._nn.multilabel_margin_loss", + "torch._C._nn.nll_loss_nd", + "torch._C._nn.nll_loss", + "torch._C._nn.nll_loss2d", + "torch._C._nn.one_hot", + "torch._C._nn.pad_sequence", + "torch._C._nn.pad", + "torch._C._nn.reflection_pad1d", + "torch._C._nn.reflection_pad2d", + "torch._C._nn.reflection_pad3d", + "torch._C._nn.relu6_", + "torch._C._nn.relu6", + "torch._C._nn.replication_pad1d", + "torch._C._nn.replication_pad2d", + "torch._C._nn.replication_pad3d", + "torch._C._nn.rrelu_with_noise_", + "torch._C._nn.rrelu_with_noise", + "torch._C._nn.scaled_dot_product_attention", + "torch._C._nn.silu_", + "torch._C._nn.silu", + "torch._C._nn.slow_conv_dilated2d", + "torch._C._nn.slow_conv_dilated3d", + "torch._C._nn.slow_conv_transpose2d", + "torch._C._nn.slow_conv_transpose3d", + "torch._C._nn.slow_conv3d", + "torch._C._nn.smooth_l1_loss", + "torch._C._nn.soft_margin_loss", + "torch._C._nn.softplus", + "torch._C._nn.softshrink", + "torch._C._nn.thnn_conv2d", + "torch._C._nn.unflatten_dense_tensors", + "torch._C._nn.upsample_bicubic2d", + "torch._C._nn.upsample_bilinear2d", + "torch._C._nn.upsample_linear1d", + "torch._C._nn.upsample_nearest1d", + "torch._C._nn.upsample_nearest2d", + "torch._C._nn.upsample_nearest3d", + "torch._C._nn.upsample_trilinear3d", + "torch._C._non_sym_sizes", + "torch._C._overlaps", + "torch._C._parallel_info", + "torch._C._parse_dispatch_key", + "torch._C._parse_source_def", + "torch._C._pop_torch_dispatch_stack", + "torch._C._pop_torch_function_stack", + "torch._C._propagate_and_assign_input_shapes", + "torch._C._propagate_shapes", + "torch._C._propagate_xla_data", + "torch._C._push_on_torch_dispatch_stack", + "torch._C._push_on_torch_function_stack", + 
"torch._C._quantize_ondevice_ptq_dynamic", + "torch._C._register_py_class_for_device", + "torch._C._remove_cached_tensor", + "torch._C._remove_worker_pids", + "torch._C._rename_privateuse1_backend", + "torch._C._replace_", + "torch._C._replace_overloaded_method_decl", + "torch._C._resolve_type_from_object", + "torch._C._resolve_type", + "torch._C._rocm_is_backward_pass", + "torch._C._rpc_init", + "torch._C._run_emit_module_hook", + "torch._C._save_jit_module_to_bytes", + "torch._C._save_jit_module", + "torch._C._save_mobile_module_to_bytes", + "torch._C._save_mobile_module", + "torch._C._save_parameters", + "torch._C._scatter_out", + "torch._C._scatter", + "torch._C._select_conv_backend", + "torch._C._select_batch_norm_backend", + "torch._C._set_autograd_fallback_mode", + "torch._C._set_backcompat_broadcast_warn", + "torch._C._set_backcompat_keepdim_warn", + "torch._C._set_blas_preferred_backend", + "torch._C._set_cached_tensors_enabled", + "torch._C._set_check_sparse_tensor_invariants", + "torch._C._set_conj", + "torch._C._set_cublas_allow_bf16_reduced_precision_reduction", + "torch._C._set_cublas_allow_fp16_reduced_precision_reduction", + "torch._C._set_cublas_allow_tf32", + "torch._C._set_cudnn_allow_tf32", + "torch._C._set_cudnn_benchmark", + "torch._C._set_cudnn_deterministic", + "torch._C._set_cudnn_enabled", + "torch._C._set_default_dtype", + "torch._C._set_default_mobile_cpu_allocator", + "torch._C._set_default_tensor_type", + "torch._C._set_deterministic_algorithms", + "torch._C._set_deterministic_fill_uninitialized_memory", + "torch._C._set_dispatch_mode", + "torch._C._set_float32_matmul_precision", + "torch._C._set_fwd_grad_enabled", + "torch._C._set_grad_enabled", + "torch._C._set_graph_executor_optimize", + "torch._C._set_linalg_preferred_backend", + "torch._C._set_meta_in_tls_dispatch_include", + "torch._C._set_mkldnn_enabled", + "torch._C._set_multithreading_enabled", + "torch._C._set_neg", + "torch._C._set_nnpack_enabled", + "torch._C._set_print_stack_traces_on_fatal_signal", + "torch._C._set_qengine", + "torch._C._set_sdp_use_flash", + "torch._C._set_sdp_use_math", + "torch._C._set_math_sdp_allow_fp16_bf16_reduction", + "torch._C._set_sdp_use_mem_efficient", + "torch._C._set_should_use_format_with_string_table", + "torch._C._set_storage_access_error_msg", + "torch._C._set_tensor_metadata", + "torch._C._set_tracing_state", + "torch._C._set_value_trace", + "torch._C._set_view_replay_enabled", + "torch._C._set_warnAlways", + "torch._C._set_worker_pids", + "torch._C._set_worker_signal_handlers", + "torch._C._should_allow_numbers_as_tensors", + "torch._C._show_config", + "torch._C._sparse._sparse_addmm", + "torch._C._sparse._sparse_log_softmax", + "torch._C._sparse._sparse_mm_reduce_impl", + "torch._C._sparse._sparse_mm", + "torch._C._sparse._sparse_softmax", + "torch._C._sparse._spdiags", + "torch._C._sparse.sparse_sampled_addmm", + "torch._C._special.special_airy_ai", + "torch._C._special.special_bessel_j0", + "torch._C._special.special_bessel_j1", + "torch._C._special.special_bessel_y0", + "torch._C._special.special_bessel_y1", + "torch._C._special.special_chebyshev_polynomial_t", + "torch._C._special.special_chebyshev_polynomial_u", + "torch._C._special.special_chebyshev_polynomial_v", + "torch._C._special.special_chebyshev_polynomial_w", + "torch._C._special.special_digamma", + "torch._C._special.special_entr", + "torch._C._special.special_erf", + "torch._C._special.special_erfc", + "torch._C._special.special_erfcx", + "torch._C._special.special_erfinv", + 
"torch._C._special.special_exp2", + "torch._C._special.special_expit", + "torch._C._special.special_expm1", + "torch._C._special.special_gammainc", + "torch._C._special.special_gammaincc", + "torch._C._special.special_gammaln", + "torch._C._special.special_hermite_polynomial_h", + "torch._C._special.special_hermite_polynomial_he", + "torch._C._special.special_i0", + "torch._C._special.special_i0e", + "torch._C._special.special_i1", + "torch._C._special.special_i1e", + "torch._C._special.special_laguerre_polynomial_l", + "torch._C._special.special_legendre_polynomial_p", + "torch._C._special.special_log_ndtr", + "torch._C._special.special_log_softmax", + "torch._C._special.special_log1p", + "torch._C._special.special_logit", + "torch._C._special.special_logsumexp", + "torch._C._special.special_modified_bessel_i0", + "torch._C._special.special_modified_bessel_i1", + "torch._C._special.special_modified_bessel_k0", + "torch._C._special.special_modified_bessel_k1", + "torch._C._special.special_multigammaln", + "torch._C._special.special_ndtr", + "torch._C._special.special_ndtri", + "torch._C._special.special_polygamma", + "torch._C._special.special_psi", + "torch._C._special.special_round", + "torch._C._special.special_scaled_modified_bessel_k0", + "torch._C._special.special_scaled_modified_bessel_k1", + "torch._C._special.special_shifted_chebyshev_polynomial_t", + "torch._C._special.special_shifted_chebyshev_polynomial_u", + "torch._C._special.special_shifted_chebyshev_polynomial_v", + "torch._C._special.special_shifted_chebyshev_polynomial_w", + "torch._C._special.special_sinc", + "torch._C._special.special_softmax", + "torch._C._special.special_spherical_bessel_j0", + "torch._C._special.special_xlog1py", + "torch._C._special.special_xlogy", + "torch._C._special.special_zeta", + "torch._C._stash_obj_in_tls", + "torch._C._storage_id", + "torch._C._storage_Use_Count", + "torch._C._supported_qengines", + "torch._C._te.abs", + "torch._C._te.acos", + "torch._C._te.annotate_input_shapes", + "torch._C._te.asin", + "torch._C._te.atan", + "torch._C._te.atan2", + "torch._C._te.ceil", + "torch._C._te.Compute", + "torch._C._te.Compute2", + "torch._C._te.construct_codegen", + "torch._C._te.cos", + "torch._C._te.cosh", + "torch._C._te.erf", + "torch._C._te.erfc", + "torch._C._te.exp", + "torch._C._te.expm1", + "torch._C._te.fixup_missing_shape_info", + "torch._C._te.floor", + "torch._C._te.fmod", + "torch._C._te.frac", + "torch._C._te.ifThenElse", + "torch._C._te.is_graph_compilable", + "torch._C._te.isnan", + "torch._C._te.lgamma", + "torch._C._te.log", + "torch._C._te.log10", + "torch._C._te.log1p", + "torch._C._te.log2", + "torch._C._te.lower", + "torch._C._te.make_shapes_symbolic", + "torch._C._te.pow", + "torch._C._te.Reduce", + "torch._C._te.remainder", + "torch._C._te.remove_graph_output", + "torch._C._te.remove_unused_self_argument", + "torch._C._te.replace_list_output_with_tuple", + "torch._C._te.round", + "torch._C._te.rsqrt", + "torch._C._te.sigmoid", + "torch._C._te.simplify", + "torch._C._te.sin", + "torch._C._te.sinh", + "torch._C._te.sqrt", + "torch._C._te.tan", + "torch._C._te.tanh", + "torch._C._te.trim_graph", + "torch._C._te.trunc", + "torch._C._tensor_impl_raw_handle", + "torch._C._test_only_add_entry_to_op_version_map", + "torch._C._test_only_populate_upgraders", + "torch._C._test_only_remove_entry_to_op_version_map", + "torch._C._test_only_remove_upgraders", + "torch._C._to_functionality_key", + "torch._C._tracer_set_force_outplace", + "torch._C._tracer_set_get_unique_name_fn", + 
"torch._C._tracer_warn_use_python", + "torch._C._unset_default_mobile_cpu_allocator", + "torch._C._unset_dispatch_mode", + "torch._C._valgrind_supported_platform", + "torch._C._valgrind_toggle_and_dump_stats", + "torch._C._valgrind_toggle", + "torch._C._verbose.mkl_set_verbose", + "torch._C._verbose.mkldnn_set_verbose", + "torch._C._vmapmode_decrement_nesting", + "torch._C._vmapmode_increment_nesting", + "torch._C._warn_deprecation", + "torch._C._warn", + "torch._C._will_engine_execute_node", + "torch._C._wrap_tensor_impl", + "torch._C.fork", + "torch._C.get_autocast_cpu_dtype", + "torch._C.get_autocast_dtype", + "torch._C.get_autocast_gpu_dtype", + "torch._C.get_autocast_ipu_dtype", + "torch._C.get_autocast_xla_dtype", + "torch._C.get_default_dtype", + "torch._C.get_num_interop_threads", + "torch._C.get_num_threads", + "torch._C.import_ir_module_from_buffer", + "torch._C.import_ir_module", + "torch._C.init_num_threads", + "torch._C.is_anomaly_check_nan_enabled", + "torch._C.is_anomaly_enabled", + "torch._C.is_autocast_cache_enabled", + "torch._C.is_autocast_cpu_enabled", + "torch._C.is_autocast_enabled", + "torch._C.is_autocast_ipu_enabled", + "torch._C.is_autocast_xla_enabled", + "torch._C.is_grad_enabled", + "torch._C.is_inference_mode_enabled", + "torch._C.merge_type_from_type_comment", + "torch._C.parse_ir", + "torch._C.parse_schema", + "torch._C.parse_type_comment", + "torch._C.read_vitals", + "torch._C.set_vital", + "torch._C.unify_type_list", + "torch._C.vitals_enabled", + "torch._C.wait", + "torch._cast_Byte", + "torch._cast_Char", + "torch._cast_Double", + "torch._cast_Float", + "torch._cast_Half", + "torch._cast_Int", + "torch._cast_Long", + "torch._cast_Short", + "torch._choose_qparams_per_tensor", + "torch._chunk_cat", + "torch._coalesce", + "torch._compute_linear_combination", + "torch._conj_copy", + "torch._conj_physical", + "torch._conj", + "torch._convert_indices_from_coo_to_csr", + "torch._convert_indices_from_csr_to_coo", + "torch._convert_weight_to_int4pack", + "torch._convolution_mode", + "torch._convolution", + "torch._copy_from_and_resize", + "torch._copy_from", + "torch._cslt_compress", + "torch._cslt_sparse_mm", + "torch._ctc_loss", + "torch._cudnn_ctc_loss", + "torch._cudnn_init_dropout_state", + "torch._cudnn_rnn_flatten_weight", + "torch._cudnn_rnn", + "torch._cufft_clear_plan_cache", + "torch._cufft_get_plan_cache_max_size", + "torch._cufft_get_plan_cache_size", + "torch._cufft_set_plan_cache_max_size", + "torch._cummax_helper", + "torch._cummin_helper", + "torch._debug_has_internal_overlap", + "torch._dim_arange", + "torch._dirichlet_grad", + "torch._disable_functionalization", + "torch._efficientzerotensor", + "torch._embedding_bag_forward_only", + "torch._embedding_bag", + "torch._empty_affine_quantized", + "torch._empty_per_channel_affine_quantized", + "torch._enable_functionalization", + "torch._euclidean_dist", + "torch._fake_quantize_learnable_per_channel_affine", + "torch._fake_quantize_learnable_per_tensor_affine", + "torch._fake_quantize_per_tensor_affine_cachemask_tensor_qparams", + "torch._fft_c2c", + "torch._fft_c2r", + "torch._fft_r2c", + "torch._fill_mem_eff_dropout_mask_", + "torch._foobar", + "torch._foreach_abs_", + "torch._foreach_abs", + "torch._foreach_acos_", + "torch._foreach_acos", + "torch._foreach_add_", + "torch._foreach_add", + "torch._foreach_addcdiv_", + "torch._foreach_addcdiv", + "torch._foreach_addcmul_", + "torch._foreach_addcmul", + "torch._foreach_asin_", + "torch._foreach_asin", + "torch._foreach_atan_", + 
"torch._foreach_atan", + "torch._foreach_ceil_", + "torch._foreach_ceil", + "torch._foreach_clamp_max_", + "torch._foreach_clamp_max", + "torch._foreach_clamp_min_", + "torch._foreach_clamp_min", + "torch._foreach_copy_", + "torch._foreach_cos_", + "torch._foreach_cos", + "torch._foreach_cosh_", + "torch._foreach_cosh", + "torch._foreach_div_", + "torch._foreach_div", + "torch._foreach_erf_", + "torch._foreach_erf", + "torch._foreach_erfc_", + "torch._foreach_erfc", + "torch._foreach_exp_", + "torch._foreach_exp", + "torch._foreach_expm1_", + "torch._foreach_expm1", + "torch._foreach_floor_", + "torch._foreach_floor", + "torch._foreach_frac_", + "torch._foreach_frac", + "torch._foreach_lerp_", + "torch._foreach_lerp", + "torch._foreach_lgamma_", + "torch._foreach_lgamma", + "torch._foreach_log_", + "torch._foreach_log", + "torch._foreach_log10_", + "torch._foreach_log10", + "torch._foreach_log1p_", + "torch._foreach_log1p", + "torch._foreach_log2_", + "torch._foreach_log2", + "torch._foreach_maximum_", + "torch._foreach_maximum", + "torch._foreach_minimum_", + "torch._foreach_minimum", + "torch._foreach_mul_", + "torch._foreach_mul", + "torch._foreach_neg_", + "torch._foreach_neg", + "torch._foreach_norm", + "torch._foreach_pow_", + "torch._foreach_pow", + "torch._foreach_reciprocal_", + "torch._foreach_reciprocal", + "torch._foreach_round_", + "torch._foreach_round", + "torch._foreach_sigmoid_", + "torch._foreach_sigmoid", + "torch._foreach_sign_", + "torch._foreach_sign", + "torch._foreach_sin_", + "torch._foreach_sin", + "torch._foreach_sinh_", + "torch._foreach_sinh", + "torch._foreach_sqrt_", + "torch._foreach_sqrt", + "torch._foreach_sub_", + "torch._foreach_sub", + "torch._foreach_tan_", + "torch._foreach_tan", + "torch._foreach_tanh_", + "torch._foreach_tanh", + "torch._foreach_trunc_", + "torch._foreach_trunc", + "torch._foreach_zero_", + "torch._freeze_functional_tensor", + "torch._from_functional_tensor", + "torch._functional_assert_async", + "torch._functional_sym_constrain_range_for_size", + "torch._functional_sym_constrain_range", + "torch._functionalize_are_all_mutations_hidden_from_autograd", + "torch._functionalize_commit_update", + "torch._functionalize_enable_reapply_views", + "torch._functionalize_has_data_mutation", + "torch._functionalize_has_metadata_mutation", + "torch._functionalize_is_multi_output_view", + "torch._functionalize_mark_mutation_hidden_from_autograd", + "torch._functionalize_replace", + "torch._functionalize_sync", + "torch._functionalize_was_storage_changed", + "torch._fused_adam_", + "torch._fused_adamw_", + "torch._fused_dropout", + "torch._fused_moving_avg_obs_fq_helper", + "torch._fused_sdp_choice", + "torch._fw_primal_copy", + "torch._grid_sampler_2d_cpu_fallback", + "torch._has_compatible_shallow_copy_type", + "torch._histogramdd_bin_edges", + "torch._histogramdd_from_bin_cts", + "torch._histogramdd_from_bin_tensors", + "torch._index_put_impl_", + "torch._indices_copy", + "torch._int_mm", + "torch._is_all_true", + "torch._is_any_true", + "torch._is_functional_tensor", + "torch._is_zerotensor", + "torch._linalg_check_errors", + "torch._linalg_det", + "torch._linalg_eigh", + "torch._linalg_eigvals", + "torch._linalg_slogdet", + "torch._linalg_solve_ex", + "torch._linalg_svd", + "torch._log_softmax_backward_data", + "torch._log_softmax", + "torch._logcumsumexp", + "torch._lstm_mps", + "torch._lu_with_info", + "torch._make_dep_token", + "torch._make_dual_copy", + "torch._make_dual", + "torch._make_per_channel_quantized_tensor", + 
"torch._make_per_tensor_quantized_tensor", + "torch._masked_scale", + "torch._masked_softmax", + "torch._mirror_autograd_meta_to", + "torch._mixed_dtypes_linear", + "torch._mkldnn_reshape", + "torch._mkldnn_transpose_", + "torch._mkldnn_transpose", + "torch._mps_convolution_transpose", + "torch._mps_convolution", + "torch._native_batch_norm_legit_no_training", + "torch._native_batch_norm_legit", + "torch._native_multi_head_attention", + "torch._neg_view_copy", + "torch._neg_view", + "torch._nested_from_padded_and_nested_example", + "torch._nested_tensor_from_mask_left_aligned", + "torch._nested_tensor_from_tensor_list", + "torch._nested_tensor_softmax_with_shape", + "torch._nested_view_from_buffer_copy", + "torch._nested_view_from_buffer", + "torch._nnpack_available", + "torch._nnpack_spatial_convolution", + "torch._pack_padded_sequence", + "torch._pad_packed_sequence", + "torch._pin_memory", + "torch._prelu_kernel", + "torch._propagate_xla_data", + "torch._remove_batch_dim", + "torch._reshape_alias_copy", + "torch._reshape_from_tensor", + "torch._resize_output_", + "torch._rowwise_prune", + "torch._sample_dirichlet", + "torch._saturate_weight_to_fp16", + "torch._scaled_dot_product_attention_math", + "torch._scaled_dot_product_efficient_attention", + "torch._scaled_dot_product_flash_attention", + "torch._scaled_dot_product_flash_attention_for_cpu", + "torch._scaled_dot_product_cudnn_attention", + "torch._scaled_mm", + "torch._shape_as_tensor", + "torch._sobol_engine_draw", + "torch._sobol_engine_ff_", + "torch._sobol_engine_initialize_state_", + "torch._sobol_engine_scramble_", + "torch._softmax_backward_data", + "torch._softmax", + "torch._sparse_broadcast_to_copy", + "torch._sparse_broadcast_to", + "torch._sparse_csr_prod", + "torch._sparse_csr_sum", + "torch._sparse_log_softmax_backward_data", + "torch._sparse_semi_structured_addmm", + "torch._sparse_semi_structured_linear", + "torch._sparse_semi_structured_mm", + "torch._sparse_softmax_backward_data", + "torch._sparse_sparse_matmul", + "torch._sparse_sum", + "torch._stack", + "torch._standard_gamma_grad", + "torch._standard_gamma", + "torch._test_autograd_multiple_dispatch_view_copy", + "torch._test_autograd_multiple_dispatch_view", + "torch._test_autograd_multiple_dispatch", + "torch._test_check_tensor", + "torch._test_functorch_fallback", + "torch._test_serialization_subcmul", + "torch._to_cpu", + "torch._to_functional_tensor", + "torch._to_sparse_semi_structured", + "torch._transform_bias_rescale_qkv", + "torch._transformer_encoder_layer_fwd", + "torch._trilinear", + "torch._triton_multi_head_attention", + "torch._triton_scaled_dot_attention", + "torch._unique", + "torch._unique2", + "torch._unpack_dual", + "torch._unsafe_index_put", + "torch._unsafe_index", + "torch._unsafe_masked_index_put_accumulate", + "torch._unsafe_masked_index", + "torch._use_cudnn_ctc_loss", + "torch._use_cudnn_rnn_flatten_weight", + "torch._values_copy", + "torch._weight_int4pack_mm", + "torch._weight_int8pack_mm", + "torch._weight_norm_interface", + "torch._weight_norm", + "torch.abs_", + "torch.abs", + "torch.absolute", + "torch.acos_", + "torch.acos", + "torch.acosh_", + "torch.acosh", + "torch.adaptive_avg_pool1d", + "torch.adaptive_max_pool1d", + "torch.add", + "torch.addbmm", + "torch.addcdiv", + "torch.addcmul", + "torch.addmm", + "torch.addmv_", + "torch.addmv", + "torch.addr", + "torch.adjoint", + "torch.affine_grid_generator", + "torch.alias_copy", + "torch.all", + "torch.allclose", + "torch.alpha_dropout_", + "torch.alpha_dropout", + 
"torch.amax", + "torch.amin", + "torch.aminmax", + "torch.angle", + "torch.any", + "torch.arange", + "torch.arccos_", + "torch.arccos", + "torch.arccosh_", + "torch.arccosh", + "torch.arcsin_", + "torch.arcsin", + "torch.arcsinh_", + "torch.arcsinh", + "torch.arctan_", + "torch.arctan", + "torch.arctan2", + "torch.arctanh_", + "torch.arctanh", + "torch.argmax", + "torch.argmin", + "torch.argsort", + "torch.argwhere", + "torch.as_strided_", + "torch.as_strided_copy", + "torch.as_strided_scatter", + "torch.as_strided", + "torch.as_tensor", + "torch.asarray", + "torch.asin_", + "torch.asin", + "torch.asinh_", + "torch.asinh", + "torch.atan_", + "torch.atan", + "torch.atan2", + "torch.atanh_", + "torch.atanh", + "torch.avg_pool1d", + "torch.baddbmm", + "torch.bartlett_window", + "torch.batch_norm_backward_elemt", + "torch.batch_norm_backward_reduce", + "torch.batch_norm_elemt", + "torch.batch_norm_gather_stats_with_counts", + "torch.batch_norm_gather_stats", + "torch.batch_norm_stats", + "torch.batch_norm_update_stats", + "torch.batch_norm", + "torch.bernoulli", + "torch.bilinear", + "torch.binary_cross_entropy_with_logits", + "torch.bincount", + "torch.binomial", + "torch.bitwise_and", + "torch.bitwise_left_shift", + "torch.bitwise_not", + "torch.bitwise_or", + "torch.bitwise_right_shift", + "torch.bitwise_xor", + "torch.blackman_window", + "torch.bmm", + "torch.broadcast_to", + "torch.bucketize", + "torch.can_cast", + "torch.cat", + "torch.ccol_indices_copy", + "torch.ceil_", + "torch.ceil", + "torch.celu_", + "torch.celu", + "torch.channel_shuffle", + "torch.cholesky_inverse", + "torch.cholesky_solve", + "torch.cholesky", + "torch.choose_qparams_optimized", + "torch.chunk", + "torch.clamp_", + "torch.clamp_max_", + "torch.clamp_max", + "torch.clamp_min_", + "torch.clamp_min", + "torch.clamp", + "torch.clip_", + "torch.clip", + "torch.clone", + "torch.col_indices_copy", + "torch.column_stack", + "torch.combinations", + "torch.complex", + "torch.concat", + "torch.concatenate", + "torch.conj_physical_", + "torch.conj_physical", + "torch.conj", + "torch.constant_pad_nd", + "torch.conv_tbc", + "torch.conv_transpose1d", + "torch.conv_transpose2d", + "torch.conv_transpose3d", + "torch.conv1d", + "torch.conv2d", + "torch.conv3d", + "torch.convolution", + "torch.copysign", + "torch.corrcoef", + "torch.cos_", + "torch.cos", + "torch.cosh_", + "torch.cosh", + "torch.cosine_embedding_loss", + "torch.cosine_similarity", + "torch.count_nonzero", + "torch.cov", + "torch.cross", + "torch.crow_indices_copy", + "torch.ctc_loss", + "torch.cudnn_affine_grid_generator", + "torch.cudnn_batch_norm", + "torch.cudnn_convolution_add_relu", + "torch.cudnn_convolution_relu", + "torch.cudnn_convolution_transpose", + "torch.cudnn_convolution", + "torch.cudnn_grid_sampler", + "torch.cudnn_is_acceptable", + "torch.cummax", + "torch.cummin", + "torch.cumprod", + "torch.cumsum", + "torch.cumulative_trapezoid", + "torch.deg2rad_", + "torch.deg2rad", + "torch.dequantize", + "torch.det", + "torch.detach_", + "torch.detach_copy", + "torch.detach", + "torch.diag_embed", + "torch.diag", + "torch.diagflat", + "torch.diagonal_copy", + "torch.diagonal_scatter", + "torch.diagonal", + "torch.diff", + "torch.digamma", + "torch.dist", + "torch.div", + "torch.divide", + "torch.dot", + "torch.dropout_", + "torch.dropout", + "torch.dsmm", + "torch.dsplit", + "torch.dstack", + "torch.embedding_bag", + "torch.embedding_renorm_", + "torch.embedding", + "torch.empty_like", + "torch.empty_permuted", + "torch.empty_quantized", + 
"torch.empty_strided", + "torch.empty", + "torch.eq", + "torch.equal", + "torch.erf_", + "torch.erf", + "torch.erfc_", + "torch.erfc", + "torch.erfinv", + "torch.exp_", + "torch.exp", + "torch.exp2_", + "torch.exp2", + "torch.expand_copy", + "torch.expm1_", + "torch.expm1", + "torch.eye", + "torch.fake_quantize_per_channel_affine", + "torch.fake_quantize_per_tensor_affine", + "torch.fbgemm_linear_fp16_weight_fp32_activation", + "torch.fbgemm_linear_fp16_weight", + "torch.fbgemm_linear_int8_weight_fp32_activation", + "torch.fbgemm_linear_int8_weight", + "torch.fbgemm_linear_quantize_weight", + "torch.fbgemm_pack_gemm_matrix_fp16", + "torch.fbgemm_pack_quantized_matrix", + "torch.feature_alpha_dropout_", + "torch.feature_alpha_dropout", + "torch.feature_dropout_", + "torch.feature_dropout", + "torch.fill_", + "torch.fill", + "torch.fix_", + "torch.fix", + "torch.flatten", + "torch.flip", + "torch.fliplr", + "torch.flipud", + "torch.float_power", + "torch.floor_", + "torch.floor_divide", + "torch.floor", + "torch.fmax", + "torch.fmin", + "torch.fmod", + "torch.frac_", + "torch.frac", + "torch.frexp", + "torch.frobenius_norm", + "torch.from_file", + "torch.from_numpy", + "torch.frombuffer", + "torch.full_like", + "torch.full", + "torch.fused_moving_avg_obs_fake_quant", + "torch.gather", + "torch.gcd_", + "torch.gcd", + "torch.ge", + "torch.geqrf", + "torch.ger", + "torch.get_device", + "torch.gradient", + "torch.greater_equal", + "torch.greater", + "torch.grid_sampler_2d", + "torch.grid_sampler_3d", + "torch.grid_sampler", + "torch.group_norm", + "torch.gru_cell", + "torch.gru", + "torch.gt", + "torch.hamming_window", + "torch.hann_window", + "torch.hardshrink", + "torch.heaviside", + "torch.hinge_embedding_loss", + "torch.histc", + "torch.histogram", + "torch.histogramdd", + "torch.hsmm", + "torch.hsplit", + "torch.hspmm", + "torch.hstack", + "torch.hypot", + "torch.i0_", + "torch.i0", + "torch.igamma", + "torch.igammac", + "torch.imag", + "torch.index_add", + "torch.index_copy", + "torch.index_fill", + "torch.index_put_", + "torch.index_put", + "torch.index_reduce", + "torch.index_select", + "torch.indices_copy", + "torch.inner", + "torch.instance_norm", + "torch.int_repr", + "torch.inverse", + "torch.is_complex", + "torch.is_conj", + "torch.is_distributed", + "torch.is_floating_point", + "torch.is_inference", + "torch.is_neg", + "torch.is_nonzero", + "torch.is_same_size", + "torch.is_signed", + "torch.is_vulkan_available", + "torch.isclose", + "torch.isfinite", + "torch.isin", + "torch.isinf", + "torch.isnan", + "torch.isneginf", + "torch.isposinf", + "torch.isreal", + "torch.istft", + "torch.kaiser_window", + "torch.kl_div", + "torch.kron", + "torch.kthvalue", + "torch.layer_norm", + "torch.lcm_", + "torch.lcm", + "torch.ldexp_", + "torch.ldexp", + "torch.le", + "torch.lerp", + "torch.less_equal", + "torch.less", + "torch.lgamma", + "torch.linspace", + "torch.log_", + "torch.log_softmax", + "torch.log", + "torch.log10_", + "torch.log10", + "torch.log1p_", + "torch.log1p", + "torch.log2_", + "torch.log2", + "torch.logaddexp", + "torch.logaddexp2", + "torch.logcumsumexp", + "torch.logdet", + "torch.logical_and", + "torch.logical_not", + "torch.logical_or", + "torch.logical_xor", + "torch.logit_", + "torch.logit", + "torch.logspace", + "torch.logsumexp", + "torch.lstm_cell", + "torch.lstm", + "torch.lt", + "torch.lu_solve", + "torch.lu_unpack", + "torch.margin_ranking_loss", + "torch.masked_fill", + "torch.masked_scatter", + "torch.masked_select", + "torch.matmul", + "torch.matrix_exp", + 
"torch.matrix_power", + "torch.max_pool1d_with_indices", + "torch.max_pool1d", + "torch.max_pool2d", + "torch.max_pool3d", + "torch.max", + "torch.maximum", + "torch.mean", + "torch.median", + "torch.min", + "torch.minimum", + "torch.miopen_batch_norm", + "torch.miopen_convolution_add_relu", + "torch.miopen_convolution_relu", + "torch.miopen_convolution_transpose", + "torch.miopen_convolution", + "torch.miopen_depthwise_convolution", + "torch.miopen_rnn", + "torch.mkldnn_adaptive_avg_pool2d", + "torch.mkldnn_convolution", + "torch.mkldnn_linear_backward_weights", + "torch.mkldnn_max_pool2d", + "torch.mkldnn_max_pool3d", + "torch.mkldnn_rnn_layer", + "torch.mm", + "torch.mode", + "torch.moveaxis", + "torch.movedim", + "torch.msort", + "torch.mul", + "torch.multinomial", + "torch.multiply", + "torch.mv", + "torch.mvlgamma", + "torch.nan_to_num_", + "torch.nan_to_num", + "torch.nanmean", + "torch.nanmedian", + "torch.nanquantile", + "torch.nansum", + "torch.narrow_copy", + "torch.narrow", + "torch.native_batch_norm", + "torch.native_channel_shuffle", + "torch.native_dropout", + "torch.native_group_norm", + "torch.native_layer_norm", + "torch.native_norm", + "torch.ne", + "torch.neg_", + "torch.neg", + "torch.negative_", + "torch.negative", + "torch.nextafter", + "torch.nonzero_static", + "torch.nonzero", + "torch.norm_except_dim", + "torch.normal", + "torch.not_equal", + "torch.nuclear_norm", + "torch.numel", + "torch.ones_like", + "torch.ones", + "torch.orgqr", + "torch.ormqr", + "torch.outer", + "torch.pairwise_distance", + "torch.pdist", + "torch.permute_copy", + "torch.permute", + "torch.pinverse", + "torch.pixel_shuffle", + "torch.pixel_unshuffle", + "torch.poisson_nll_loss", + "torch.poisson", + "torch.polar", + "torch.polygamma", + "torch.positive", + "torch.pow", + "torch.prelu", + "torch._print", + "torch.prod", + "torch.promote_types", + "torch.put", + "torch.q_per_channel_axis", + "torch.q_per_channel_scales", + "torch.q_per_channel_zero_points", + "torch.q_scale", + "torch.q_zero_point", + "torch.qr", + "torch.quantile", + "torch.quantize_per_channel", + "torch.quantize_per_tensor_dynamic", + "torch.quantize_per_tensor", + "torch.quantized_batch_norm", + "torch.quantized_gru_cell", + "torch.quantized_lstm_cell", + "torch.quantized_max_pool1d", + "torch.quantized_max_pool2d", + "torch.quantized_max_pool3d", + "torch.quantized_rnn_relu_cell", + "torch.quantized_rnn_tanh_cell", + "torch.rad2deg_", + "torch.rad2deg", + "torch.rand_like", + "torch.rand", + "torch.randint_like", + "torch.randint", + "torch.randn_like", + "torch.randn", + "torch.randperm", + "torch.range", + "torch.ravel", + "torch.real", + "torch.reciprocal_", + "torch.reciprocal", + "torch.relu_", + "torch.relu", + "torch.remainder", + "torch.renorm", + "torch.repeat_interleave", + "torch.reshape", + "torch.resolve_conj", + "torch.resolve_neg", + "torch.result_type", + "torch.rms_norm", + "torch.rnn_relu_cell", + "torch.rnn_relu", + "torch.rnn_tanh_cell", + "torch.rnn_tanh", + "torch.roll", + "torch.rot90", + "torch.round_", + "torch.round", + "torch.row_indices_copy", + "torch.row_stack", + "torch.rrelu_", + "torch.rrelu", + "torch.rsqrt_", + "torch.rsqrt", + "torch.rsub", + "torch.saddmm", + "torch.scalar_tensor", + "torch.scatter_add", + "torch.scatter_reduce", + "torch.scatter", + "torch.searchsorted", + "torch.segment_reduce", + "torch.select_copy", + "torch.select_scatter", + "torch.select", + "torch.selu_", + "torch.selu", + "torch.sgn", + "torch.sigmoid_", + "torch.sigmoid", + "torch.sign", + 
"torch.signal.windows.windows.sqrt", + "torch.signbit", + "torch.sin_", + "torch.sin", + "torch.sinc_", + "torch.sinc", + "torch.sinh_", + "torch.sinh", + "torch.slice_copy", + "torch.slice_scatter", + "torch.slogdet", + "torch.smm", + "torch.softmax", + "torch.sort", + "torch.split_copy", + "torch.split_with_sizes_copy", + "torch.split_with_sizes", + "torch.spmm", + "torch.sqrt_", + "torch.sqrt", + "torch.square_", + "torch.square", + "torch.squeeze_copy", + "torch.squeeze", + "torch.sspaddmm", + "torch.stack", + "torch.std_mean", + "torch.std", + "torch.sub", + "torch.subtract", + "torch.sum", + "torch.svd", + "torch.swapaxes", + "torch.swapdims", + "torch.sym_constrain_range_for_size", + "torch.sym_constrain_range", + "torch.t_copy", + "torch.t", + "torch.take_along_dim", + "torch.take", + "torch.tan_", + "torch.tan", + "torch.tanh_", + "torch.tanh", + "torch.tensor_split", + "torch.tensor", + "torch.threshold_", + "torch.threshold", + "torch.tile", + "torch.topk", + "torch.trace", + "torch.transpose_copy", + "torch.transpose", + "torch.trapezoid", + "torch.trapz", + "torch.triangular_solve", + "torch.tril_indices", + "torch.tril", + "torch.triplet_margin_loss", + "torch.triu_indices", + "torch.triu", + "torch.true_divide", + "torch.trunc_", + "torch.trunc", + "torch.unbind_copy", + "torch.unbind", + "torch.unflatten", + "torch.unfold_copy", + "torch.unsafe_chunk", + "torch.unsafe_split_with_sizes", + "torch.unsafe_split", + "torch.unsqueeze_copy", + "torch.unsqueeze", + "torch.values_copy", + "torch.vander", + "torch.var_mean", + "torch.var", + "torch.vdot", + "torch.view_as_complex_copy", + "torch.view_as_complex", + "torch.view_as_real_copy", + "torch.view_as_real", + "torch.view_copy", + "torch.vsplit", + "torch.vstack", + "torch.where", + "torch.xlogy_", + "torch.xlogy", + "torch.zero_", + "torch.zeros", + "torch.zeros_like", + "torch._fused_sgd_", + "torch.slice_inverse", + "torch._assert_scalar", + "torch._functional_assert_scalar", + ], + TorchInGraphFunctionVariable, +) + + +if sys.version_info >= (3, 9): + torch_c_binding_in_graph_functions["math.lcm"] = TorchInGraphFunctionVariable +if sys.version_info >= (3, 11): + torch_c_binding_in_graph_functions["math.exp2"] = TorchInGraphFunctionVariable + torch_c_binding_in_graph_functions["math.cbrt"] = TorchInGraphFunctionVariable + + +# In graph functions (including constant folding) that are not C bindings +torch_non_c_binding_in_graph_functions = dict.fromkeys( + [ + "torch.__future__.get_overwrite_module_params_on_conversion", + "torch.__future__.set_overwrite_module_params_on_conversion", + "torch.__getattr__", + "torch._assert", + "torch._check_index", + "torch._check_is_size", + "torch._check_not_implemented", + "torch._check_tensor_all_with", + "torch._check_tensor_all", + "torch._check_type", + "torch._check_value", + "torch._check_with", + "torch._check", + "torch._compile._disable_dynamo", + "torch._functorch.apis.chunk_vmap", + "torch._functorch.autograd_function.custom_function_call_functionalize", + "torch._functorch.autograd_function.custom_function_call_grad", + "torch._functorch.autograd_function.custom_function_call_vmap_generate_rule", + "torch._functorch.autograd_function.custom_function_call_vmap", + "torch._functorch.autograd_function.generate_single_level_function", + "torch._functorch.autograd_function.get_tangents_in_dims", + "torch._functorch.autograd_function.has_overriden_vmap_rule", + "torch._functorch.autograd_function.reductify_leaf", + "torch._functorch.autograd_function.reductify", + 
"torch._functorch.autograd_function.validate_vmap_returns_tuple_of_two_elements", + "torch._functorch.autograd_function.vmapify_autograd_function", + "torch._functorch.autograd_function.wrap_outputs_maintaining_identity", + "torch._functorch.batch_norm_replacement.batch_norm_without_running_stats", + "torch._functorch.batch_norm_replacement.replace_all_batch_norm_modules_", + "torch._functorch.deprecated.combine_state_for_ensemble", + "torch._functorch.deprecated.functionalize", + "torch._functorch.deprecated.get_warning", + "torch._functorch.deprecated.make_functional_with_buffers", + "torch._functorch.deprecated.make_functional", + "torch._functorch.deprecated.setup_docs", + "torch._functorch.deprecated.warn_deprecated", + "torch._functorch.eager_transforms._any_differentiable", + "torch._functorch.eager_transforms._autograd_grad", + "torch._functorch.eager_transforms._vjp_treespec_compare", + "torch._functorch.eager_transforms._set_tensor_requires_grad", + "torch._functorch.eager_transforms._jvp_treespec_compare", + "torch._functorch.eager_transforms._linearize_treespec_compare", + "torch._functorch.eager_transforms._is_differentiable", + "torch._functorch.eager_transforms._maybe_unwrap_functional_tensor", + "torch._functorch.eager_transforms._maybe_wrap_functional_tensor", + "torch._functorch.eager_transforms._unwrap_all_tensors_from_functional", + "torch._functorch.eager_transforms._wrap_all_tensors_to_functional", + "torch._functorch.eager_transforms.assert_flat_tuple_of_tensors", + "torch._functorch.eager_transforms.functionalize", + "torch._functorch.eager_transforms.lazy_dynamo_disable", + "torch._functorch.eager_transforms.noop", + "torch._functorch.pyfunctorch.coerce_cinterpreter", + "torch._functorch.pyfunctorch.dispatch_functorch", + "torch._functorch.pyfunctorch.nested", + "torch._functorch.pyfunctorch.retrieve_current_functorch_interpreter", + "torch._functorch.pyfunctorch.temporarily_pop_interpreter_stack", + "torch._functorch.utils.enable_single_level_autograd_function", + "torch._functorch.utils.exposed_in", + "torch._functorch.utils.unwrap_dead_wrappers", + "torch._functorch.vmap.lazy_load_decompositions", + "torch._guards.compile_context", + "torch._guards.detect_fake_mode", + "torch._guards.tracing", + "torch._higher_order_ops.map._has_potential_branch_input_alias", + "torch._higher_order_ops.map._has_potential_branch_input_mutation", + "torch._higher_order_ops.map._stack_pytree", + "torch._higher_order_ops.map._unstack_pytree", + "torch._higher_order_ops.map.create_fw_bw_graph", + "torch._higher_order_ops.map.map_autograd", + "torch._higher_order_ops.map.map_dense", + "torch._higher_order_ops.map.map_fake_tensor_mode", + "torch._higher_order_ops.map.map_functionalize", + "torch._higher_order_ops.map.map_proxy_torch_dispatch_mode", + "torch._higher_order_ops.map.map_wrapper", + "torch._higher_order_ops.map.trace_map", + "torch._higher_order_ops.out_dtype.elementwise_dtypes", + "torch._higher_order_ops.out_dtype.is_int_mm", + "torch._higher_order_ops.out_dtype.out_dtype_dense", + "torch._higher_order_ops.out_dtype.out_dtype_fake_tensor_mode", + "torch._higher_order_ops.out_dtype.out_dtype_fallback", + "torch._higher_order_ops.out_dtype.out_dtype_func", + "torch._higher_order_ops.out_dtype.out_dtype_proxy", + "torch._higher_order_ops.out_dtype.trace_out_dtype", + "torch._higher_order_ops.utils.autograd_not_implemented_inner", + "torch._higher_order_ops.utils.autograd_not_implemented", + "torch._linalg_utils._symeig", + "torch._linalg_utils.basis", + 
"torch._linalg_utils.bform", + "torch._linalg_utils.eig", + "torch._linalg_utils.get_floating_dtype", + "torch._linalg_utils.is_sparse", + "torch._linalg_utils.lstsq", + "torch._linalg_utils.matmul", + "torch._linalg_utils.matrix_rank", + "torch._linalg_utils.qform", + "torch._linalg_utils.solve", + "torch._linalg_utils.symeig", + "torch._load_global_deps", + "torch._lowrank._svd_lowrank", + "torch._lowrank.get_approximate_basis", + "torch._lowrank.pca_lowrank", + "torch._lowrank.svd_lowrank", + "torch._ops._compute_keyset", + "torch._ops._get_tensors", + "torch._ops._to_flat_tuple", + "torch._ops.add_cached_op", + "torch._ops.dl_open_guard", + "torch._ops.get_cached_ops", + "torch._ops.key_extractor", + "torch._ops.reset_cached_ops", + "torch._ops.resolve_key", + "torch._preload_cuda_deps", + "torch._register_device_module", + "torch._running_with_deploy", + "torch._utils._dummy_type", + "torch._weights_only_unpickler._get_allowed_globals", + "torch._weights_only_unpickler.load", + "torch.align_tensors", + "torch.amp.autocast_mode._enter_autocast", + "torch.amp.autocast_mode._exit_autocast", + "torch.amp.autocast_mode.autocast_decorator", + "torch.amp.autocast_mode.custom_bwd", + "torch.amp.autocast_mode.custom_fwd", + "torch.are_deterministic_algorithms_enabled", + "torch.atleast_1d", + "torch.atleast_2d", + "torch.atleast_3d", + "torch.autograd._calculate_shape", + "torch.autograd._is_checkpoint_valid", + "torch.autograd._make_grads", + "torch.autograd._register_py_tensor_class_for_device", + "torch.autograd._tensor_or_tensors_to_tuple", + "torch.autograd.forward_ad._maybe_load_decompositions", + "torch.autograd.function._iter_filter", + "torch.autograd.function._iter_jit_values", + "torch.autograd.function._iter_None_tensors", + "torch.autograd.function._iter_tensors_permissive", + "torch.autograd.function._iter_tensors", + "torch.autograd.function._jit_unwrap_structured", + "torch.autograd.function._map_tensor_data", + "torch.autograd.function._nested_map", + "torch.autograd.function._unflatten", + "torch.autograd.function.once_differentiable", + "torch.autograd.function.traceable", + "torch.autograd.functional._as_tuple_nocheck", + "torch.autograd.functional._as_tuple", + "torch.autograd.functional._autograd_grad", + "torch.autograd.functional._check_requires_grad", + "torch.autograd.functional._construct_standard_basis_for", + "torch.autograd.functional._fill_in_zeros", + "torch.autograd.functional._grad_postprocess", + "torch.autograd.functional._grad_preprocess", + "torch.autograd.functional._jacfwd", + "torch.autograd.functional._tuple_postprocess", + "torch.autograd.functional._validate_v", + "torch.autograd.functional.hessian", + "torch.autograd.functional.hvp", + "torch.autograd.functional.jacobian", + "torch.autograd.functional.jvp", + "torch.autograd.functional.vhp", + "torch.autograd.functional.vjp", + "torch.autograd.grad_mode._enter_inference_mode", + "torch.autograd.grad_mode._exit_inference_mode", + "torch.autograd.graph._get_sid", + "torch.autograd.graph._get_tid", + "torch.autograd.graph.allow_mutation_on_saved_tensors", + "torch.autograd.graph.get_gradient_edge", + "torch.autograd.graph.increment_version", + "torch.autograd.graph.register_multi_grad_hook", + "torch.autograd.variable", + "torch.backends.__allow_nonbracketed_mutation", + "torch.backends.cpu.get_cpu_capability", + "torch.backends.cuda.can_use_efficient_attention", + "torch.backends.cuda.can_use_flash_attention", + "torch.backends.cuda.can_use_cudnn_attention", + "torch.backends.cuda.enable_flash_sdp", + 
"torch.backends.cuda.enable_math_sdp", + "torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp", + "torch.backends.cuda.enable_mem_efficient_sdp", + "torch.backends.cuda.flash_sdp_enabled", + "torch.backends.cuda.is_built", + "torch.backends.cuda.is_flash_attention_available", + "torch.backends.cuda.math_sdp_enabled", + "torch.backends.cuda.fp16_bf16_reduction_math_sdp_allowed", + "torch.backends.cuda.mem_efficient_sdp_enabled", + "torch.backends.cuda.cudnn_sdp_enabled", + "torch.backends.cuda.enable_cudnn_sdp", + "torch.backends.cuda.preferred_blas_library", + "torch.backends.cuda.preferred_linalg_library", + "torch.backends.cuda.sdp_kernel", + "torch.backends.cudnn._init", + "torch.backends.cudnn.flags", + "torch.backends.cudnn.is_acceptable", + "torch.backends.cudnn.is_available", + "torch.backends.cudnn.set_flags", + "torch.backends.cudnn.version", + "torch.backends.disable_global_flags", + "torch.backends.flags_frozen", + "torch.backends.mkl.is_available", + "torch.backends.mkldnn.flags", + "torch.backends.mkldnn.is_available", + "torch.backends.mkldnn.set_flags", + "torch.backends.mps._init", + "torch.backends.mps.is_available", + "torch.backends.mps.is_built", + "torch.backends.mps.is_macos13_or_newer", + "torch.backends.openmp.is_available", + "torch.backends.quantized._get_qengine_id", + "torch.backends.quantized._get_qengine_str", + "torch.block_diag", + "torch.broadcast_tensors", + "torch.cartesian_prod", + "torch.cdist", + "torch.chain_matmul", + "torch.compile", + "torch.compiled_with_cxx11_abi", + "torch._C._cpu._is_avx2_supported", + "torch._C._cpu._is_avx512_supported", + "torch._C._cpu._is_avx512_vnni_supported", + "torch._C._cpu._is_avx512_bf16_supported", + "torch._C._cpu._is_amx_tile_supported", + "torch.cpu._init_amx", + "torch.cpu.current_device", + "torch.cpu.current_stream", + "torch.cpu.device_count", + "torch.cpu.is_available", + "torch.cpu.set_device", + "torch.cpu.stream", + "torch.cpu.synchronize", + "torch.cuda._check_capability", + "torch.cuda._check_cubins", + "torch.cuda._device_count_amdsmi", + "torch.cuda._device_count_nvml", + "torch.cuda._get_amdsmi_handler", + "torch.cuda._get_amdsmi_device_index", + "torch.cuda._get_device", + "torch.cuda._get_generator", + "torch.cuda._get_nvml_device_index", + "torch.cuda._get_pynvml_handler", + "torch.cuda._get_rng_state_offset", + "torch.cuda._is_compiled", + "torch.cuda._lazy_call", + "torch.cuda._lazy_init", + "torch.cuda._memory_viz._block_extra_legacy", + "torch.cuda._memory_viz._block_extra", + "torch.cuda._memory_viz._format_size", + "torch.cuda._memory_viz._format_viz", + "torch.cuda._memory_viz._frame_filter", + "torch.cuda._memory_viz._frame_fmt", + "torch.cuda._memory_viz._frames_fmt", + "torch.cuda._memory_viz._profile_to_snapshot", + "torch.cuda._memory_viz._report_free", + "torch.cuda._memory_viz._write_blocks", + "torch.cuda._memory_viz.calc_active", + "torch.cuda._memory_viz.compare", + "torch.cuda._memory_viz.format_flamegraph", + "torch.cuda._memory_viz.memory", + "torch.cuda._memory_viz.profile_plot", + "torch.cuda._memory_viz.segment_plot", + "torch.cuda._memory_viz.segments", + "torch.cuda._memory_viz.segsum", + "torch.cuda._memory_viz.trace_plot", + "torch.cuda._memory_viz.trace", + "torch.cuda._nvml_based_avail", + "torch.cuda._parse_visible_devices", + "torch.cuda._raw_device_count_amdsmi", + "torch.cuda._raw_device_count_nvml", + "torch.cuda._raw_device_uuid_amdsmi", + "torch.cuda._raw_device_uuid_nvml", + "torch.cuda._register_triton_kernels", + "torch.cuda._set_rng_state_offset", + 
"torch.cuda._set_stream_by_id", + "torch.cuda._sleep", + "torch.cuda._transform_uuid_to_ordinals", + "torch.cuda._utils._get_device_index", + "torch.cuda.amp.autocast_mode._cast", + "torch.cuda.amp.autocast_mode.custom_bwd", + "torch.cuda.amp.autocast_mode.custom_fwd", + "torch.cuda.amp.common.amp_definitely_not_available", + "torch.amp.grad_scaler._refresh_per_optimizer_state", + "torch.cuda.can_device_access_peer", + "torch.cuda.check_error", + "torch.cuda.clock_rate", + "torch.cuda.cudart", + "torch.cuda.current_blas_handle", + "torch.cuda.current_stream", + "torch.cuda.default_stream", + "torch.cuda.device_count", + "torch.cuda.get_arch_list", + "torch.cuda.get_device_capability", + "torch.cuda.get_device_name", + "torch.cuda.get_device_properties", + "torch.cuda.get_gencode_flags", + "torch.cuda.get_sync_debug_mode", + "torch.cuda.graphs.graph_pool_handle", + "torch.cuda.graphs.is_current_stream_capturing", + "torch.cuda.graphs.make_graphed_callables", + "torch.cuda.init", + "torch.cuda.ipc_collect", + "torch.cuda.is_available", + "torch.cuda.is_bf16_supported", + "torch.cuda.is_initialized", + "torch.cuda.jiterator._create_jit_fn", + "torch.cuda.jiterator._create_multi_output_jit_fn", + "torch.cuda.memory_usage", + "torch.cuda.memory._dump_snapshot", + "torch.cuda.memory._free_mutex", + "torch.cuda.memory._get_current_allocator", + "torch.cuda.memory._host_allocator", + "torch.cuda.memory._record_memory_history_impl", + "torch.cuda.memory._record_memory_history_legacy", + "torch.cuda.memory._record_memory_history", + "torch.cuda.memory._save_memory_usage", + "torch.cuda.memory._save_segment_usage", + "torch.cuda.memory._set_allocator_settings", + "torch.cuda.memory._snapshot", + "torch.cuda.memory.caching_allocator_alloc", + "torch.cuda.memory.caching_allocator_delete", + "torch.cuda.memory.change_current_allocator", + "torch.cuda.memory.empty_cache", + "torch.cuda.memory.get_allocator_backend", + "torch.cuda.memory.list_gpu_processes", + "torch.cuda.memory.max_memory_allocated", + "torch.cuda.memory.max_memory_cached", + "torch.cuda.memory.max_memory_reserved", + "torch.cuda.memory.mem_get_info", + "torch.cuda.memory.memory_allocated", + "torch.cuda.memory.memory_cached", + "torch.cuda.memory.memory_reserved", + "torch.cuda.memory.memory_snapshot", + "torch.cuda.memory.memory_stats_as_nested_dict", + "torch.cuda.memory.memory_stats", + "torch.cuda.memory.memory_summary", + "torch.cuda.memory.reset_accumulated_memory_stats", + "torch.cuda.memory.reset_max_memory_allocated", + "torch.cuda.memory.reset_max_memory_cached", + "torch.cuda.memory.reset_peak_memory_stats", + "torch.cuda.memory.set_per_process_memory_fraction", + "torch.cuda.nccl._check_sequence_type", + "torch.cuda.nccl.all_gather", + "torch.cuda.nccl.all_reduce", + "torch.cuda.nccl.broadcast", + "torch.cuda.nccl.init_rank", + "torch.cuda.nccl.is_available", + "torch.cuda.nccl.reduce_scatter", + "torch.cuda.nccl.reduce", + "torch.cuda.nccl.unique_id", + "torch.cuda.nccl.version", + "torch.cuda.nvtx.mark", + "torch.cuda.nvtx.range_end", + "torch.cuda.nvtx.range_pop", + "torch.cuda.nvtx.range_push", + "torch.cuda.nvtx.range_start", + "torch.cuda.nvtx.range", + "torch.cuda.power_draw", + "torch.cuda.profiler.init", + "torch.cuda.profiler.profile", + "torch.cuda.profiler.start", + "torch.cuda.profiler.stop", + "torch.cuda.random.get_rng_state_all", + "torch.cuda.random.initial_seed", + "torch.cuda.random.manual_seed_all", + "torch.cuda.random.manual_seed", + "torch.cuda.random.seed_all", + "torch.cuda.random.seed", + 
"torch.cuda.random.set_rng_state_all", + "torch.cuda.set_stream", + "torch.cuda.set_sync_debug_mode", + "torch.cuda.stream", + "torch.cuda.synchronize", + "torch.cuda.temperature", + "torch.cuda.utilization", + "torch.einsum", + "torch.functional._check_list_size", + "torch.functional._consecutive_return_counts", + "torch.functional._consecutive_return_inverse_false", + "torch.functional._consecutive_return_inverse_true", + "torch.functional._consecutive_return_inverse", + "torch.functional._consecutive_return_output", + "torch.functional._lu_impl", + "torch.functional._lu_no_infos", + "torch.functional._lu_with_infos", + "torch.functional._meshgrid", + "torch.functional._return_counts", + "torch.functional._return_inverse_false", + "torch.functional._return_inverse_true", + "torch.functional._return_inverse", + "torch.functional._return_output", + "torch.functional._unique_consecutive_impl", + "torch.functional._unique_impl", + "torch.functional._unravel_index", + "torch.functional.broadcast_shapes", + "torch.functional.lu", + "torch.functional.unique", + "torch.functional.unravel_index", + "torch.futures.collect_all", + "torch.futures.wait_all", + "torch.fx.experimental.const_fold.split_const_subgraphs", + "torch.fx.experimental.proxy_tensor.make_fx", + "torch.get_deterministic_debug_mode", + "torch.get_float32_matmul_precision", + "torch.is_deterministic_algorithms_warn_only_enabled", + "torch.is_storage", + "torch.is_tensor", + "torch.is_warn_always_enabled", + "torch.masked._ops._any", + "torch.masked._ops._apply_docstring_templates", + "torch.masked._ops._canonical_dim", + "torch.masked._ops._combine_input_and_mask", + "torch.masked._ops._generate_docstring", + "torch.masked._ops._input_mask", + "torch.masked._ops._output_mask", + "torch.masked._ops._reduction_identity", + "torch.masked._ops._sparse_coo_flatten_indices", + "torch.masked._ops._sparse_coo_scatter_reduction_helper", + "torch.masked._ops._sparse_coo_where", + "torch.masked._ops._sparse_csr_segment_reduction_helper", + "torch.masked._ops._sparse_csr_where", + "torch.masked._ops._std_var", + "torch.masked._ops._where", + "torch.masked._ops.amax", + "torch.masked._ops.amin", + "torch.masked._ops.argmax", + "torch.masked._ops.argmin", + "torch.masked._ops.corresponding_real_dtype", + "torch.masked._ops.cumprod", + "torch.masked._ops.cumsum", + "torch.masked._ops.log_softmax", + "torch.masked._ops.logaddexp", + "torch.masked._ops.logsumexp", + "torch.masked._ops.mean", + "torch.masked._ops.median", + "torch.masked._ops.norm", + "torch.masked._ops.normalize", + "torch.masked._ops.prod", + "torch.masked._ops.softmax", + "torch.masked._ops.softmin", + "torch.masked._ops.std", + "torch.masked._ops.sum", + "torch.masked._ops.var", + "torch.meshgrid", + "torch.mps._get_default_mps_generator", + "torch.mps.current_allocated_memory", + "torch.mps.driver_allocated_memory", + "torch.mps.empty_cache", + "torch.mps.get_rng_state", + "torch.mps.manual_seed", + "torch.mps.profiler.profile", + "torch.mps.profiler.start", + "torch.mps.profiler.stop", + "torch.mps.seed", + "torch.mps.set_per_process_memory_fraction", + "torch.mps.set_rng_state", + "torch.mps.synchronize", + "torch.nested._internal.nested_tensor.buffer_from_jagged", + "torch.nested._internal.nested_tensor.get_tensor_symint", + "torch.nested._internal.nested_tensor.is_expandable_to", + "torch.nested._internal.nested_tensor.jagged_from_list", + "torch.nested._internal.nested_tensor.jagged_from_tensor_and_lengths", + 
"torch.nested._internal.nested_tensor.nested_view_from_values_offsets", + "torch.nested._internal.nested_tensor.nested_view_from_values_offsets_lengths", + "torch.nested.as_nested_tensor", + "torch.nested.narrow", + "torch.nested.nested_tensor", + "torch.nn._reduction.get_enum", + "torch.nn._reduction.legacy_get_enum", + "torch.nn._reduction.legacy_get_string", + "torch.nn.factory_kwargs", + "torch.nn.functional.adaptive_avg_pool2d", + "torch.nn.functional.adaptive_avg_pool3d", + "torch.nn.functional.adaptive_max_pool1d_with_indices", + "torch.nn.functional.adaptive_max_pool1d", + "torch.nn.functional.adaptive_max_pool2d_with_indices", + "torch.nn.functional.adaptive_max_pool2d", + "torch.nn.functional.adaptive_max_pool3d_with_indices", + "torch.nn.functional.adaptive_max_pool3d", + "torch.nn.functional.affine_grid", + "torch.nn.functional.alpha_dropout", + "torch.nn.functional.assert_int_or_pair", + "torch.nn.functional.batch_norm", + "torch.nn.functional.binary_cross_entropy_with_logits", + "torch.nn.functional.binary_cross_entropy", + "torch.nn.functional.celu", + "torch.nn.functional.cosine_embedding_loss", + "torch.nn.functional.cross_entropy", + "torch.nn.functional.ctc_loss", + "torch.nn.functional.dropout", + "torch.nn.functional.dropout1d", + "torch.nn.functional.dropout2d", + "torch.nn.functional.dropout3d", + "torch.nn.functional.elu", + "torch.nn.functional.embedding_bag", + "torch.nn.functional.embedding", + "torch.nn.functional.feature_alpha_dropout", + "torch.nn.functional.fold", + "torch.nn.functional.fractional_max_pool2d_with_indices", + "torch.nn.functional.fractional_max_pool2d", + "torch.nn.functional.fractional_max_pool3d_with_indices", + "torch.nn.functional.fractional_max_pool3d", + "torch.nn.functional.gaussian_nll_loss", + "torch.nn.functional.glu", + "torch.nn.functional.grid_sample", + "torch.nn.functional.group_norm", + "torch.nn.functional.gumbel_softmax", + "torch.nn.functional.hardsigmoid", + "torch.nn.functional.hardswish", + "torch.nn.functional.hardtanh", + "torch.nn.functional.hinge_embedding_loss", + "torch.nn.functional.huber_loss", + "torch.nn.functional.instance_norm", + "torch.nn.functional.interpolate", + "torch.nn.functional.kl_div", + "torch.nn.functional.l1_loss", + "torch.nn.functional.layer_norm", + "torch.nn.functional.leaky_relu", + "torch.nn.functional.local_response_norm", + "torch.nn.functional.log_softmax", + "torch.nn.functional.lp_pool1d", + "torch.nn.functional.lp_pool2d", + "torch.nn.functional.margin_ranking_loss", + "torch.nn.functional.max_pool1d_with_indices", + "torch.nn.functional.max_pool1d", + "torch.nn.functional.max_pool2d_with_indices", + "torch.nn.functional.max_pool2d", + "torch.nn.functional.max_pool3d_with_indices", + "torch.nn.functional.max_pool3d", + "torch.nn.functional.max_unpool1d", + "torch.nn.functional.max_unpool2d", + "torch.nn.functional.max_unpool3d", + "torch.nn.functional.mish", + "torch.nn.functional.mse_loss", + "torch.nn.functional.multi_head_attention_forward", + "torch.nn.functional.multi_margin_loss", + "torch.nn.functional.multilabel_margin_loss", + "torch.nn.functional.multilabel_soft_margin_loss", + "torch.nn.functional.nll_loss", + "torch.nn.functional.normalize", + "torch.nn.functional.poisson_nll_loss", + "torch.nn.functional.relu", + "torch.nn.functional.relu6", + "torch.nn.functional.rrelu", + "torch.nn.functional.selu", + "torch.nn.functional.sigmoid", + "torch.nn.functional.silu", + "torch.nn.functional.smooth_l1_loss", + "torch.nn.functional.soft_margin_loss", + 
"torch.nn.functional.softmax", + "torch.nn.functional.softmin", + "torch.nn.functional.softsign", + "torch.nn.functional.tanh", + "torch.nn.functional.tanhshrink", + "torch.nn.functional.triplet_margin_loss", + "torch.nn.functional.unfold", + "torch.nn.functional.upsample_bilinear", + "torch.nn.functional.upsample_nearest", + "torch.nn.functional.upsample", + "torch.nn.grad._pair", + "torch.nn.grad._single", + "torch.nn.grad._triple", + "torch.nn.grad.conv1d_input", + "torch.nn.grad.conv1d_weight", + "torch.nn.grad.conv2d_input", + "torch.nn.grad.conv2d_weight", + "torch.nn.grad.conv3d_input", + "torch.nn.grad.conv3d_weight", + "torch.nn.modules.activation._is_make_fx_tracing", + "torch.nn.modules.utils._list_with_default", + "torch.nn.modules.utils._ntuple", + "torch.nn.modules.utils._quadruple", + "torch.nn.modules.utils._reverse_repeat_tuple", + "torch.nn.modules.utils.consume_prefix_in_state_dict_if_present", + "torch.nn.parameter.is_lazy", + "torch.norm", + "torch.quantization.default_eval_fn", + "torch.random._seed_custom_device", + "torch.random.fork_rng", + "torch.random.initial_seed", + "torch.random.seed", + "torch.return_types.pytree_register_structseq", + "torch.set_default_device", + "torch.set_default_dtype", + "torch.set_default_tensor_type", + "torch.set_deterministic_debug_mode", + "torch.set_float32_matmul_precision", + "torch.set_warn_always", + "torch.signal.windows.windows._add_docstr", + "torch.signal.windows.windows._window_function_checks", + "torch.signal.windows.windows.bartlett", + "torch.signal.windows.windows.blackman", + "torch.signal.windows.windows.cosine", + "torch.signal.windows.windows.exponential", + "torch.signal.windows.windows.gaussian", + "torch.signal.windows.windows.general_cosine", + "torch.signal.windows.windows.general_hamming", + "torch.signal.windows.windows.hamming", + "torch.signal.windows.windows.hann", + "torch.signal.windows.windows.kaiser", + "torch.signal.windows.windows.merge_dicts", + "torch.signal.windows.windows.nuttall", + "torch.signal.windows.windows.parse_kwargs", + "torch.sparse.semi_structured.to_sparse_semi_structured", + "torch.sparse.sum", + "torch.split", + "torch.stft", + "torch.sym_float", + "torch.sym_int", + "torch.sym_ite", + "torch.sym_max", + "torch.sym_min", + "torch.sym_not", + "torch.tensordot", + "torch.typename", + "torch.unique_consecutive", + "torch.use_deterministic_algorithms", + ], + TorchInGraphFunctionVariable, +) + + +torch_name_rule_map = [ + manual_torch_name_rule_map, + torch_c_binding_in_graph_functions, + torch_non_c_binding_in_graph_functions, +] + + +""" +Generate the torch object - Dynamo tracing rule (the wrapping variable) map. +""" + + +@functools.lru_cache(None) +def get_torch_obj_rule_map() -> Dict[Any, Type["VariableTracker"]]: + d: Dict[Any, Type[VariableTracker]] = {} + for m in torch_name_rule_map: + for k, v in m.items(): # type: ignore[attr-defined] + if ".py#" not in k: + obj = load_object(k) + else: + obj = _module_dir(torch) + k[len("torch/") :] + if obj is not None: + if obj in d and d[obj] != v: + raise AssertionError( + f"Duplicate torch object {obj} with different rules: {v}, {d[obj]}" + ) + else: + d[obj] = v + return d + + +def _load_obj_from_str(fully_qualified_name): + module, obj_name = fully_qualified_name.rsplit(".", maxsplit=1) + return getattr(importlib.import_module(module), obj_name) + + +""" +Load string represented torch objects. 
+""" + + +def load_object(name): + try: + x = name.split("#") + if len(x) == 2: + obj = _load_obj_from_str(x[0]) + val = getattr(obj, x[1]) + else: + assert len(x) == 1, f"Invalid obj name {name}" + val = _load_obj_from_str(x[0]) + val = unwrap_if_wrapper(val) + except (AttributeError, ImportError): + val = None + return val + + +""" +Get all torch.Tensor methods which are allowed to be in graph functions. +""" + + +@functools.lru_cache(None) +def get_tensor_method(): + s = set() + for name in dir(torch.Tensor): + method = getattr(torch.Tensor, name) + if isinstance( + method, (types.MethodDescriptorType, types.WrapperDescriptorType) + ): + s.add(method) + return frozenset(s) + + +""" +Return if a torch object is ATen op or torch.Tensor method. +""" + + +def is_aten_op_or_tensor_method(obj): + return obj in get_tensor_method() or isinstance( + obj, + (torch._ops.OpOverloadPacket, torch._ops.OpOverload), + ) + + +class FunctionIdSet: + """ + Track a set of `id()`s of objects which are either allowed or not + allowed to go into the generated FX graph. Use to test for torch.*, + numpy.*, builtins.*, etc. + + Support user modification to permit customization of what can be + added to the graph and what will cause a graph break. + """ + + function_ids: Optional[Set[int]] = None + function_names: Optional[Dict[int, str]] = None + + def __init__( + self, lazy_initializer: Callable[[], Union[Dict[int, str], Set[int]]] + ) -> None: + self.lazy_initializer = lazy_initializer + + def __call__(self) -> Set[int]: + if self.function_ids is None: + value = self.lazy_initializer() + if isinstance(value, dict): + self.function_ids = set(value.keys()) + self.function_names = value + else: + assert isinstance(value, set) + self.function_ids = value + return self.function_ids + + def get_name(self, idx: int, default: str): + self() # lazy init + assert self.function_names is not None + return self.function_names.get(idx, default) + + def add(self, idx: int): + function_ids = self() # lazy init + function_ids.add(idx) + + def remove(self, idx: int): + function_ids = self() + if idx in function_ids: + function_ids.remove(idx) + + def __contains__(self, idx: int) -> bool: + return idx in self() + + +@FunctionIdSet +def _allowed_callable_ids() -> Dict[int, str]: + rv: Dict[int, str] = {} + return rv + + +@FunctionIdSet +def _disallowed_callable_ids() -> Dict[int, str]: + rv: Dict[int, str] = {} + return rv + + +@FunctionIdSet +def _builtin_function_ids() -> Dict[int, str]: + # See also torch/_dynamo/polyfills/loader.py, which removes items in _builtin_function_ids + rv = { + id(v): f"builtins.{k}" + for k, v in builtins.__dict__.items() + if not k.startswith("_") and callable(v) + } + rv.update( + { + id(v): f"operator.{k}" + for k, v in operator.__dict__.items() + if not k.startswith("_") and callable(v) + } + ) + rv.update( + { + id(cast): "typing.cast", + id(functools.reduce): "functools.reduce", + id(copy.deepcopy): "copy.deepcopy", + } + ) + return rv + + +@FunctionIdSet +def _numpy_function_ids() -> Dict[int, str]: + rv = {} + for mod in NP_SUPPORTED_MODULES: + rv.update( + { + id(v): f"{mod.__name__}.{k}" + for k, v in mod.__dict__.items() + if callable(v) + and (getattr(v, "__module__", None) or mod.__name__) == mod.__name__ + } + ) + return rv + + +@FunctionIdSet +def _builtin_constant_ids() -> Dict[int, str]: + """ + Collects constant builtins by eliminating callable items. 
+ """ + rv = { + id(v): f"builtins.{k}" + for k, v in builtins.__dict__.items() + if not k.startswith("_") and not callable(v) + } + return rv + + +_lazy_module_init: Dict[str, List[Callable[[], None]]] = defaultdict(list) + + +def add_module_init_func(name: str, init_func: Callable[[], None]) -> None: + """Register a module without eagerly importing it""" + # If the module is already imported, eagerly run init + assert "." not in name, f"Expected a root module name, but got {name}" + assert name not in _lazy_module_init + _lazy_module_init[name].append(init_func) + + +def _maybe_init_lazy_module(obj: object) -> None: + module = getattr(obj, "__module__", None) + if module is None: + return + + base_module = module.split(".")[0] + init_funcs = _lazy_module_init.pop(base_module, None) + if init_funcs is not None: + for fn in init_funcs: + fn() + + +def is_callable_allowed(obj) -> bool: + _maybe_init_lazy_module(obj) + return id(obj) in _allowed_callable_ids + + +def is_callable_disallowed(obj) -> bool: + _maybe_init_lazy_module(obj) + return id(obj) in _disallowed_callable_ids + + +def is_forbidden(obj) -> bool: + _maybe_init_lazy_module(obj) + return inspect.getattr_static(obj, "_dynamo_forbidden", False) + + +def is_builtin_callable(obj) -> bool: + # See also torch/_dynamo/polyfills/loader.py, which removes items in _builtin_function_ids + return id(obj) in _builtin_function_ids + + +def is_builtin_constant(obj) -> bool: + return id(obj) in _builtin_constant_ids + + +def is_numpy(obj) -> bool: + if np is None: + return False + return isinstance(obj, (np.ndarray, np.generic)) or id(obj) in _numpy_function_ids + + +def is_numpy_dtype(obj) -> bool: + if np is None: + return False + return isinstance(obj, np.dtype) + + +def is_numpy_type_info(obj) -> bool: + if np is None: + return False + return isinstance(obj, (np.finfo, np.iinfo)) + + +BUILTIN_SKIPLIST = ( + abc, + collections, + contextlib, + copy, + copyreg, + dataclasses, + enum, + functools, + importlib, + inspect, + linecache, + logging, + multiprocessing, + operator, + posixpath, + random, + re, + selectors, + signal, + tempfile, + threading, + tokenize, + torch, # torch/* is skipped by default unless specified in FUNC_INLINELIST or MOD_INLINELIST + traceback, + types, + typing, + unittest, + weakref, + _collections_abc, + _weakrefset, +) + +# third party libraries skiplist is defined by str, because users may not use these libraries. +# we should use lazy import & skip in the future. +THIRDPARTY_SKIPLIST = ( + "fx2trt_oss", + "hypothesis", + "networkx", + "numpy", + "omegaconf", + "onnx", + "onnxruntime", + "onnx_tf", + "pandas", + "sklearn", + "tabulate", + "tensorflow", + "tensorrt", + "torch2trt", + "tqdm", + "tree", + "tvm", + "xarray", +) + + +def _as_posix_path(path): + posix_path = Path(os.path.normpath(path)).as_posix() + # os.path.normpath and pathlib.Path remove trailing slash, so we need to add it back + if path.endswith((os.path.sep, "/")): + posix_path += "/" + return posix_path + + +def _strip_init_py(s): + # TODO: Once we require py3.9 use removesuffix instead. + suffix = "__init__.py" + if s.endswith(suffix): + s = s[: -len(suffix)] + return _as_posix_path(s) + + +def _module_dir(m: types.ModuleType): + # Protect against a module not exporting __file__ - this can happen for + # frozen modules, for example. + file = getattr(m, "__file__", None) + return file and _strip_init_py(file) + + +# These are legacy workarounds, don't add new modules to this list. 
+# Please use the MOD_INLINELIST instead to force inline functions under particular modules. +LEGACY_MOD_INLINELIST = { + "torch._dynamo.external_utils", + "torch._export.db.examples", + "torch._export.wrappers", + "torch._functorch.apis", + "torch._functorch.deprecated", + "torch._higher_order_ops.cond", + "torch._higher_order_ops.while_loop", + "torch._higher_order_ops.associative_scan", + "torch.nn.attention.flex_attention", + "torch.ao.quantization.pt2e.export_utils", + "torch.ao.quantization.pt2e.qat_utils", + "torch.ao.quantization.pt2e.representation.rewrite", + "torch.ao.quantization.pt2e.utils", + "torch.ao.quantization.quantizer.xnnpack_quantizer", + "torch.export.unflatten", + "torch.optim", +} + +if torch.distributed.is_available(): + LEGACY_MOD_INLINELIST |= { + "torch.distributed.tensor._api", + "torch.distributed.tensor.device_mesh", + "torch.distributed.device_mesh", + "torch.distributed.algorithms._checkpoint.checkpoint_wrapper", + "torch.distributed.tensor.parallel._data_parallel_utils", + "torch.distributed.tensor.parallel._utils", + "torch.distributed.tensor.parallel.style", + # we have to add replicate to LEGACY_MOD_INLINELIST to ensure + # the forward_hook won't be ignored. + "torch.distributed._composable.replicate", + } + if not torch._dynamo.config.skip_fsdp_hooks: + LEGACY_MOD_INLINELIST.add("torch.distributed._composable.fsdp") + + +# Force inline functions under these modules, even they are in *_SKIPLIST. +# We are using python module name instead of file or directory object to avoid circular dependency. +# Please keep this sorted alphabetically. +MOD_INLINELIST = [ + "torch._decomp", + "torch._dynamo._trace_wrapped_higher_order_op", + "torch._dynamo.comptime", + "torch._dynamo.polyfills", + "torch._functorch.autograd_function", + "torch._functorch.eager_transforms", + "torch._functorch.functional_call", + "torch._functorch.vmap", + "torch._higher_order_ops.associative_scan", + "torch._higher_order_ops.strict_mode", + "torch._higher_order_ops.while_loop", + "torch._inductor.test_operators", + "torch._library.autograd", + "torch._library.custom_ops", + "torch._prims", + "torch._refs", + "torch._tensor", + "torch.amp.autocast_mode", + "torch.ao.nn", + "torch.autograd.function", + "torch.backends.cuda", + "torch.cuda.amp.autocast_mode", + "torch.distributions", + "torch.export._tree_utils", + "torch.fx._pytree", + "torch.fx._symbolic_trace", + "torch.fx.experimental.proxy_tensor", + "torch.fx.passes.shape_prop", + "torch.nn", + "torch.overrides", + "torch.random", + "torch.sparse", + "torch.testing", + "torch.utils._content_store", + "torch.utils._contextlib", + "torch.utils._foreach_utils", + "torch.utils._python_dispatch", + "torch.utils._pytree", + "torch.utils.hooks", +] +assert sorted(set(MOD_INLINELIST)) == MOD_INLINELIST +MOD_INLINELIST = set(MOD_INLINELIST) + + +if torch.distributed.is_available(): + MOD_INLINELIST.add("torch.distributed") + if not torch._dynamo.config.skip_fsdp_hooks: + MOD_INLINELIST.add("torch.distributed._composable.fsdp") + + +@functools.lru_cache(None) +def get_legacy_mod_inlinelist(): + inlinelist = { + _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/")) + for m in LEGACY_MOD_INLINELIST + } + return inlinelist + + +@functools.lru_cache(None) +def get_mod_inlinelist(): + inlinelist = { + _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/")) + for m in MOD_INLINELIST + } + return inlinelist + + +# skip some standard python builtin libs +SKIP_DIRS = [ + "", + 
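To make the module-name-to-path conversion in get_mod_inlinelist()/get_legacy_mod_inlinelist() concrete, here is a hedged sketch with an invented torch install prefix; the derived prefix is what is_torch_inline_allowed and check_file later match filenames against with startswith. The helper name is made up for the example.

def _mod_inlinelist_prefix_example():
    torch_dir = "/site-packages/torch/"  # stand-in for _module_dir(torch)
    mod = "torch.utils._pytree"          # one entry from MOD_INLINELIST
    prefix = _as_posix_path(torch_dir + mod[len("torch.") :].replace(".", "/"))
    # Files under .../torch/utils/_pytree* are then force-inlined.
    assert prefix == "/site-packages/torch/utils/_pytree"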
_as_posix_path(_config_module.__file__), + "triton/backends", +] +SKIP_DIRS.extend(map(_as_posix_path, filter(None, map(_module_dir, BUILTIN_SKIPLIST)))) + +SKIP_DIRS_RE = re.compile(r"match nothing^") + +is_fbcode = importlib.import_module("torch._inductor.config").is_fbcode() +# Skip fbcode paths (including torch.package paths) containing +# one of the following strings. +FBCODE_SKIP_DIRS: Set[str] = set() + +FBCODE_SKIP_DIRS_RE = re.compile(f".*({'|'.join(map(re.escape, FBCODE_SKIP_DIRS))})") + +# Remove this after fbcode is fully migrated to tracing through torchrec. +FBCODE_SKIP_TORCHREC_DIRS = { + "torchrec/distributed", + "torchrec/fb/distributed", + "caffe2/torch/fb/sparsenn/pooled_embeddings_modules.py", +} + +FBCODE_SKIP_TORCHREC_DIRS_RE = re.compile( + f".*({'|'.join(re.escape(_as_posix_path(d)) for d in FBCODE_SKIP_TORCHREC_DIRS)})" +) + +# TODO(yanboliang, anijain2305) - There are a few concerns that we should +# resolve: +# 1) Audit if torchrec/distributed is even required in FBCODE_SKIP_DIRS +# 2) To inline just one file but skip others in a directory, we could use +# manual_torch_name_rule_map, but this is hard because FBCODE can add unusual +# names like torch_package. +# So this is a stopgap solution until then. +FBCODE_INLINE_FILES_IN_SKIPPED_DIRS = { + "torchrec/distributed/types.py", +} +FBCODE_INLINE_FILES_IN_SKIPPED_DIRS_RE = re.compile( + f".*({'|'.join(re.escape(_as_posix_path(d)) for d in FBCODE_INLINE_FILES_IN_SKIPPED_DIRS)})" +) + +# torch.optim is a special case: +# we usually want to inline it, but the directory +# structure does not match the module structure, +# and we want to skip the functions in optim/lr_scheduler.py. +# This has precedence over all other rules in check_file. +FORCE_SKIP_FILES = {f"{_module_dir(torch)}optim/lr_scheduler.py"} + + +def _recompile_re(): + global SKIP_DIRS_RE + SKIP_DIRS_RE = re.compile( + rf"^[^\s<]*({'|'.join(re.escape(_as_posix_path(d)) for d in SKIP_DIRS)})" + ) + + +def add(import_name: str): + if isinstance(import_name, types.ModuleType): + return add(import_name.__name__) + assert isinstance(import_name, str) + from importlib.util import find_spec + + module_spec = find_spec(import_name) + if not module_spec: + return + origin = module_spec.origin + if origin is None: + return + SKIP_DIRS.append(_strip_init_py(origin)) + _recompile_re() + + +@dataclasses.dataclass +class SkipResult: + skipped: bool + reason: Optional[str] + + +def check_file(filename, is_inlined_call=False): + """Should this file be skipped?""" + if filename is None: + return SkipResult(True, "filename is None") + filename = _as_posix_path(filename) + if filename in FORCE_SKIP_FILES: + return SkipResult(True, "FORCE_SKIP_FILES") + if any(filename.startswith(d) for d in get_legacy_mod_inlinelist()): + return SkipResult( + False, + "LEGACY_MOD_INLINELIST", + ) + if is_inlined_call and is_torch_inline_allowed(filename): + return SkipResult( + False, + "MOD_INLINELIST", + ) + if ( + is_fbcode + and FBCODE_SKIP_DIRS + and bool(FBCODE_SKIP_DIRS_RE.match(filename)) + and not bool(FBCODE_INLINE_FILES_IN_SKIPPED_DIRS_RE.match(filename)) + ): + return SkipResult( + True, + "FBCODE_SKIP_DIRS", + ) + + if ( + is_fbcode + and torch._dynamo.config.skip_torchrec + and FBCODE_SKIP_TORCHREC_DIRS + and bool(FBCODE_SKIP_TORCHREC_DIRS_RE.match(filename)) + and not bool(FBCODE_INLINE_FILES_IN_SKIPPED_DIRS_RE.match(filename)) + ): + return SkipResult(True, "FBCODE_SKIP_TORCHREC_DIRS") + + if bool(SKIP_DIRS_RE.match(filename)): + return SkipResult(True, "SKIP_DIRS") + else: +
        return SkipResult(False, "inlined by default")
+
+
+@dataclasses.dataclass
+class FunctionInfo:
+    py_obj: Optional[object]
+    name: Optional[str]
+    filename: str
+    code: Optional[types.CodeType]
+
+
+"""
+This is the main entry point to determine whether an object (function) should be inlined or skipped.
+Let's illustrate the logic with an example:
+    @torch.compile
+    def f1(x, y):
+        ......
+        f2(x, y)
+        ......
+
+    def f2(x, y):
+        ......
+        f3(x, y)
+        ......
+
+    def f3(x, y):
+        ......
+
+There are mainly three call sites of check/check_verbose:
+* The compile region entrance (like function f1); the corresponding code is located in eval_frame.py.
+* When tracing recursively called functions (like functions f2 and f3).
+    * Dynamo decides whether to inline or skip every time it encounters a new recursive function call;
+      the call site is in InliningInstructionTranslator.check_inlineable of symbolic_convert.py.
+    * If f2 is skipped by Dynamo, then when evaluating the frame of f3, Dynamo needs the inline/skip check again,
+      and the call site is in catch_errors_wrapper.catch_errors of convert_frame.py.
+* For global variables and function arguments, Dynamo needs to decide if they are wrapped as SkipFunctionVariable in builder.py.
+
+`is_inlined_call` indicates whether the current function call is inlined (f2 is an inlined call if it passes the check)
+or not (f3 is not an inlined call if f2 is skipped). Inside `check_verbose`, additional rules are checked when
+`is_inlined_call` is set.
+The reason to have this flag is that if the upper-level function call (e.g., f2) is skipped,
+we don't want to inline the lower-level function call (e.g., f3) by default.
+"""
+
+
+def check_verbose(obj, is_inlined_call=False):
+    if isinstance(
+        obj, (UserFunctionVariable, UserMethodVariable, NestedUserFunctionVariable)
+    ):
+        try:
+            py_obj = obj.get_function()
+        except NotImplementedError:
+            py_obj = None
+        fi = FunctionInfo(py_obj, obj.get_name(), obj.get_filename(), obj.get_code())
+    elif isinstance(obj, types.CodeType):
+        fi = FunctionInfo(None, obj.co_name, obj.co_filename, obj)
+    elif isinstance(obj, (types.FunctionType, types.MethodType)):
+        fi = FunctionInfo(
+            obj, obj.__name__, getfile(obj), obj.__code__  # type: ignore[union-attr] # FIXME Add MethodType.__code__ to typeshed
+        )
+    else:
+        fi = FunctionInfo(obj, None, getfile(obj), None)
+
+    # Consult the central trace rules defined in torch._dynamo.trace_rules.
+    reasons: Set[str] = set()
+    rule = lookup_inner(fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons)
+    if issubclass(rule, (UserFunctionVariable, PolyfilledFunctionVariable)):
+        return SkipResult(
+            False,
+            f"inlined according to trace_rules.lookup {reasons.pop()}",
+        )
+    else:
+        assert rule == SkipFunctionVariable, rule
+        return SkipResult(
+            True,
+            f"skipped according to trace_rules.lookup {reasons.pop()}",
+        )
+
+
+def check(obj, is_inlined_call=False):
+    return check_verbose(obj, is_inlined_call).skipped
+
+
+# skip common third party libs
+for _name in THIRDPARTY_SKIPLIST:
+    add(_name)
+
+_recompile_re()
+
+
+def is_torch_inline_allowed(filename):
+    return any(filename.startswith(d) for d in get_mod_inlinelist())
+
+
+@functools.lru_cache(None)
+def dynamo_dir():
+    import torch._dynamo
+
+    return _module_dir(torch._dynamo)
+
+
+def is_torch(filename):
+    if filename.startswith(dynamo_dir()):
+        return False
+    return filename.startswith(_module_dir(torch))
+
+
+"""
+Main entry point for looking up the trace rule (the Dynamo variable) for a given callable object.
+""" + + +def lookup_callable(obj): + if not hashable(obj): + return None + # Custom allow/disallow in graph takes precedence over the general lookup. + if is_callable_disallowed(obj): + return SkipFunctionVariable + if is_callable_allowed(obj): + return TorchInGraphFunctionVariable + if is_builtin_callable(obj): + return BuiltinVariable + return None + + +""" +Main entry point for looking up the trace rule (the Dynamo variable) for a given function object. +E.g, the lookup result of `torch.sin` is `TorchInGraphFunctionVariable`. +""" + + +def lookup(obj): + return lookup_inner(obj) + + +def lookup_inner( + obj, + name=None, + filename=None, + is_direct_call=True, + reasons: Union[None, Set[str]] = None, +): + # Step 1: lookup obj's tracing rule in `torch_name_rule_map`. + # The rules defined in `torch_name_rule_map` mainly includes two parts: + # - Manually defined rules for any functions. + # - The list of torch in graph functions. + try: + can_hash = hashable(obj) + except Exception: + can_hash = False + if not can_hash: + if reasons is not None: + reasons.add("obj is not hashable") + return None + if obj is not None: + if is_aten_op_or_tensor_method(obj): + return TorchInGraphFunctionVariable + rule = get_torch_obj_rule_map().get(obj, None) + if rule is not None: + if reasons is not None: + reasons.add("get_torch_obj_rule_map") + return rule + elif name is not None and filename is not None and not is_direct_call: + if name.startswith(TORCH_DYNAMO_RESUME_IN_PREFIX): + rule = get_torch_obj_rule_map().get( + filename + "#" + TORCH_DYNAMO_RESUME_IN_PREFIX, None + ) + else: + rule = get_torch_obj_rule_map().get(filename + "#" + name, None) + if rule is not None: + if reasons is not None: + reasons.add("get_torch_obj_rule_map") + return rule + + # Step 2: lookup obj's tracing rule by function name. + if is_direct_call: + if name == "patched_init": + if reasons is not None: + reasons.add("func name is patched_init") + return SkipFunctionVariable + elif name == "__torch_function__": + if reasons is not None: + reasons.add("func name is __torch_function__") + return UserFunctionVariable + + if not is_direct_call: + if name == "__getattr__": + # is_direct_call = False indicates that this is the top-level frame + # being traced (i.e., it is not inlined and not called from + # InliningInstructionTranslator). Tracing __getattr__ at the top + # level is unlikely because we inline it for + # UserDefinedObjectVariable. This scenario occurs only for + # UnspecializedNNModuleVariable, where Dynamo directly calls + # __getattr__ during trace time, generating LOAD_ATTR bytecode + # without going through the underlying __getattr__ data structures. + # When this optimized bytecode is executed, Dynamo is triggered + # again on the __getattr__ call. Therefore, we skip Dynamo tracing + # in this case. + if reasons is not None: + reasons.add( + "Tracing __getattr__ as the top level frame, unsuitable for tracing." + ) + return SkipFunctionVariable + + # Step 3: lookup obj's tracing rule by filename. 
+ if filename is None: + filename = getfile(obj) + + skip_result = check_file(filename, is_direct_call) + if reasons is not None: + reasons.add(skip_result.reason) + if skip_result.skipped: + return SkipFunctionVariable + else: + return UserFunctionVariable + + +def clear_lru_cache(): + torch._dynamo.trace_rules.get_torch_obj_rule_map.cache_clear() + torch._dynamo.trace_rules.get_tensor_method.cache_clear() + torch._dynamo.trace_rules.get_legacy_mod_inlinelist.cache_clear() + torch._dynamo.trace_rules.get_mod_inlinelist.cache_clear() + torch._dynamo.trace_rules.dynamo_dir.cache_clear() diff --git a/lib/python3.10/site-packages/torch/_dynamo/types.py b/lib/python3.10/site-packages/torch/_dynamo/types.py new file mode 100644 index 0000000000000000000000000000000000000000..8cab8ed5197fc354970080a072c7d50a2171b906 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/types.py @@ -0,0 +1,96 @@ +import dataclasses +import sys +import types +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, Union + +# CacheEntry has a `check_fn` field for the guard, and a `code` field for the code object. +from torch._C._dynamo.eval_frame import ( + _CacheEntry as CacheEntry, + _ExtraState as ExtraState, +) +from torch._guards import CompileId + + +if sys.version_info >= (3, 11): + from torch._C._dynamo.eval_frame import _PyInterpreterFrame as DynamoFrameType +else: + from types import FrameType as DynamoFrameType + + +# We use a dict to store additional data per frame. +FrameState = Dict[Any, Any] + + +class GuardFail(NamedTuple): + # A string repr of the piece of failed guard code we eval-ed + reason: str + # A code object where we failed a guard + orig_code: types.CodeType + + +class GuardFn(Protocol): + closure_vars: Dict[str, object] + args: List[str] + code_parts: List[str] + verbose_code_parts: List[str] + global_scope: Dict[str, object] + guard_fail_fn: Optional[Callable[[GuardFail], None]] + cache_entry: Optional[CacheEntry] + extra_state: Optional[ExtraState] + + # maps locals of user function to bool + def __call__(self, f_locals: Dict[str, object]) -> bool: + ... + + +@dataclasses.dataclass +class GuardedCode: + code: types.CodeType + check_fn: GuardFn + compile_id: CompileId + + +class DynamoCallbackFn(Protocol): + def __call__( + self, + frame: DynamoFrameType, + cache_entry: Optional[CacheEntry], + frame_state: FrameState, + ) -> Optional[GuardedCode]: + ... + + +DynamoCallback = Union[DynamoCallbackFn, None, bool] + + +class DynamoGuardHook(Protocol): + def __call__( + self, + guard_fn: GuardFn, + code: types.CodeType, + f_locals: Dict[str, object], + index: int, + last: bool, + ) -> None: + ... + + +class ProfilerStartHook(Protocol): + def __call__( + self, + name: str, + # TODO(whc) how do I annotate a _RecordFunction here? + ) -> Any: + ... + + +class ProfilerEndHook(Protocol): + def __call__(self, record: Any) -> None: + ... + + +class BytecodeHook(Protocol): + def __call__( + self, code: types.CodeType, new_code: types.CodeType + ) -> Optional[types.CodeType]: + ... 
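A minimal sketch of how the GuardFail record defined above is typically consumed. It assumes the guard_fail_fn keyword accepted by torch._dynamo.optimize in this tree; treat the wiring as illustrative rather than a stable public API.

import torch
import torch._dynamo


def on_guard_fail(failure) -> None:
    # `failure` matches the GuardFail NamedTuple above: a string describing the failed
    # guard plus the code object that was originally compiled.
    print(f"guard failed for {failure.orig_code.co_name}: {failure.reason}")


@torch._dynamo.optimize("eager", guard_fail_fn=on_guard_fail)
def f(x):
    return x + 1


f(torch.ones(3))      # first call compiles and installs guards
f(torch.ones(4, 5))   # a shape/rank change can fail a guard and invoke the callback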
diff --git a/lib/python3.10/site-packages/torch/_dynamo/utils.py b/lib/python3.10/site-packages/torch/_dynamo/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7b0cbbb75d2a192abdadd3c131cefb726a8e63b5 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_dynamo/utils.py @@ -0,0 +1,3181 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import atexit +import collections +import contextlib +import copy +import dataclasses +import datetime +import dis +import enum +import functools +import gc +import importlib +import inspect +import itertools +import linecache +import logging +import math +import operator +import os +import re +import sys +import threading +import time +import types +import typing +import uuid +import warnings +import weakref +from contextlib import contextmanager +from dataclasses import is_dataclass +from functools import lru_cache +from types import MethodWrapperType +from typing import ( + Any, + Callable, + cast, + ClassVar, + Counter, + DefaultDict, + Deque, + Dict, + Iterable, + Iterator, + KeysView, + List, + Optional, + overload, + Set, + Tuple, + Type, + TypeVar, + Union, + ValuesView, +) +from typing_extensions import Literal, TypeGuard + +import torch +import torch._functorch.config +import torch._inductor.config as inductor_config +import torch.fx.experimental.symbolic_shapes +import torch.utils._pytree as pytree +from torch import fx +from torch._C import ( + _get_function_stack_at, + _instruction_counter, + _len_torch_function_stack, + _pop_torch_function_stack, + _push_on_torch_function_stack, +) +from torch._dispatch.python import enable_python_dispatcher +from torch._guards import Source, TracingContext +from torch._subclasses.meta_utils import is_sparse_compressed +from torch._utils_internal import log_chromium_event_internal, log_compilation_event +from torch.fx._utils import _format_graph_code, lazy_format_graph_code +from torch.nn.modules.lazy import LazyModuleMixin +from torch.utils._triton import has_triton, has_triton_package +from torch.utils.hooks import RemovableHandle + + +try: + import numpy as np +except ModuleNotFoundError: + np = None # type: ignore[assignment] + +try: + import torch._logging + import torch._numpy as tnp + from torch._guards import detect_fake_mode # noqa: F401n + from torch._logging import LazyString + + from . import config + + # NOTE: Make sure `NP_SUPPORTED_MODULES` and `NP_TO_TNP_MODULE` are in sync. + if np: + NP_SUPPORTED_MODULES: Tuple[types.ModuleType, ...] = ( + np, + np.fft, + np.linalg, + np.random, + ) + + NP_TO_TNP_MODULE = { + np: tnp, + np.fft: tnp.fft, + np.linalg: tnp.linalg, + np.random: tnp.random, + } + else: + NP_SUPPORTED_MODULES = () + + NP_TO_TNP_MODULE = {} + from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode +except ImportError: + pass + + +T = TypeVar("T") + +unpatched_nn_module_getattr = torch.nn.Module.__getattr__ + +counters: DefaultDict[str, Counter[str]] = collections.defaultdict(collections.Counter) +optimus_scuba_log: Dict[str, Any] = {} +troubleshooting_url = ( + "https://pytorch.org/docs/main/torch.compiler_troubleshooting.html" +) +nnmodule_doc_url = "https://pytorch.org/docs/main/torch.compiler_nn_module.html" +nnmodule_doc_url_msg = f"See {nnmodule_doc_url} for more information and limitations." 
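The counters mapping defined above is a two-level tally that helpers later in this file (e.g. get_cache_stats()) read from. A standalone sketch of the access pattern; the category and key names are illustrative, not an exhaustive list.

import collections
from typing import Counter, DefaultDict

counters: DefaultDict[str, Counter[str]] = collections.defaultdict(collections.Counter)

# Writers bump a named counter under a free-form category...
counters["graph_break"]["call to a skipped function"] += 1
counters["inductor"]["fxgraph_cache_hit"] += 1

# ...and readers aggregate per category; missing categories or keys simply read as 0.
print(dict(counters["inductor"]))                   # {'fxgraph_cache_hit': 1}
print(counters["inductor"]["fxgraph_cache_miss"])   # 0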
+log = logging.getLogger(__name__) + +# profiling compilation time by function +compilation_time_metrics: Dict[str, List[float]] = {} + +# profiling compilation time by frame phase +frame_phase_timing: Dict[str, Dict[str, float]] = collections.defaultdict( + lambda: collections.defaultdict(float) +) + +timer_counter = itertools.count() + + +def tabulate( + rows: Union[List[Tuple[str, object]], List[List[object]]], + headers: Union[Tuple[str, ...], List[str]], +) -> str: + try: + import tabulate + + return tabulate.tabulate(rows, headers=headers) + except ImportError: + return "\n".join( + ", ".join(map(str, row)) for row in itertools.chain([headers], rows) + ) + + +curr_frame = 0 + + +# Note: Called for you by dynamo - you almost never ever want to invoke this yourself. +def increment_frame() -> None: + global curr_frame + curr_frame = curr_frame + 1 + + +# Note: Called for you by dynamo - you almost never ever want to invoke this yourself. +def reset_frame_count() -> None: + global curr_frame + frame_phase_timing.clear() + compilation_time_metrics.clear() + curr_frame = 0 + + +op_count = 0 + + +def increment_op_count(cnt: int) -> None: + global op_count + op_count += cnt + + +# Calculate total time spent so far for each phase +# For example, {'entire_frame_compile':8.574629999999999, 'backend_compile':5.26806} +def calculate_time_spent() -> Dict[str, float]: + total_wall_time = 0.0 + total_by_key = {} + for timings in frame_phase_timing.values(): + total_wall_time += timings.get( + "entire_frame_compile", timings.get("inductor_compile", 0) + ) + + for key, timing in timings.items(): + if key not in total_by_key: + total_by_key[key] = timing + else: + total_by_key[key] += timing + + if total_by_key: + total_by_key["total_wall_time"] = total_wall_time + + return total_by_key + + +# Print a report of time spent so far +# Ex: +# TIMING: +# entire_frame_compile:8.574629999999999 +# backend_compile:5.26806 +def print_time_report() -> None: + total_by_key = calculate_time_spent() + + out = "TIMING:" + for key, value in total_by_key.items(): + out = f"{out} {key}:{round(value, 5)}" + + print(out) + + +def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None: + frame_phase_timing[key][phase_name] += time_spent + + +def get_cache_stats() -> Dict[str, Any]: + """Get a bunch of metadata about cache hits and misses to use in chromium events""" + cache_stats = { + "fxgraph_cache_hit": counters["inductor"]["fxgraph_cache_hit"], + "fxgraph_cache_miss": counters["inductor"]["fxgraph_cache_miss"], + "fxgraph_cache_bypass": counters["inductor"]["fxgraph_cache_bypass"], + } + return cache_stats + + +# dynamo_timed is a context manager +# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics +# where the key is the functions name. +# For example: +# +# def _foo(...): +# with dynamo_timed("_foo"): +# ... +# +# Would show up as an entry in our timing dict: +# OrderedDict([('_foo', [0.083690, 0.23949, 3.1425e-05])]) +# This is extremely useful for granular debugging. +# +# Although it is tempting to use dynamo_timed as a decorator, please do not. +# In its decorator form it makes cProfile traces less useful as dynamo_timed +# suddenly becomes a bottleneck for lots of function calls (as only one parent +# pointer is recorded). +# +# For a higher-level mode, pass a phase_name into dynamo_timed +# phase_names record an extra record into a separate compilation timing structure, +# one keyed on frame+name rather than function. 
+# The frame is incremented outside of this function, in def increment_frame() above. +# `fwd_only` is used to identify if this phase or function is only called +# during compiling fwd graphs, e.g, `entire_frame_compile` and `backend_compile`. +# The other phases (`inductor_compile` and `code_gen`) are called for both fwd and bwd graphs. + + +@contextmanager +def dynamo_timed( + key: str, + phase_name: Optional[str] = None, + fwd_only: bool = True, +): + chromium_log: ChromiumEventLogger = get_chromium_event_logger() + if key not in compilation_time_metrics: + compilation_time_metrics[key] = [] + + fail_type: Optional[str] = None + fail_reason: Optional[str] = None + time_spent = float("-inf") + start = time.time_ns() + try: + with torch.profiler.record_function(f"{key} (dynamo_timed)"): + t0 = time.time() + chromium_log.log_event_start(key, start, None) + if phase_name: + chromium_log.log_event_start(phase_name, start) + yield + time_spent = time.time() - t0 + compilation_time_metrics[key].append(time_spent) + except Exception as e: + fail_type = str(type(e)) + fail_reason = str(e) + raise + finally: + # Always log the end event even on exception + if phase_name: + chromium_log.log_event_end( + phase_name, + time.time_ns(), + {"cache_stats": get_cache_stats()}, + start, + ) + chromium_log.log_event_end( + key, time.time_ns(), {"cache_stats": get_cache_stats()}, start + ) + # Only record backward compilation metrics if phase_name is not None! + if phase_name: + frame_key = str(curr_frame) + # fwd only compilation stages: entire_frame_compile, backend_compile. + # use frame_key as time aggregation key. + if fwd_only and fail_type is None: + _add_time_spent(frame_key, phase_name, time_spent) + else: + # fwd + bwd compilation stages: inductor_compile, code_gen. + # use frame_key as time aggregation key for fwd graphs; + # use compile_id as time aggregation key for bwd graphs. + if torch._guards.TracingContext.try_get() is not None: + aot_graph_name = str( + torch._guards.TracingContext.get().aot_graph_name + ) + if ( + "forward" in aot_graph_name or "inference" in aot_graph_name + ) and fail_type is None: + _add_time_spent(frame_key, phase_name, time_spent) + elif "backward" in aot_graph_name: + compile_id = str( + torch._guards.CompileContext.current_compile_id() + ) + if fail_type is None: + _add_time_spent(compile_id, phase_name, time_spent) + + # log backward compilation metrics at the end of `inductor_compile` of bwd graph, + # one record for one bwd graph. + if phase_name == "inductor_compile": + if fail_type is None: + inductor_compile_time = frame_phase_timing[ + compile_id + ].get("inductor_compile", None) + code_gen_time = frame_phase_timing[compile_id].get( + "code_gen", None + ) + else: + inductor_compile_time = None + code_gen_time = None + metrics = BwdCompilationMetrics( + compile_id, + inductor_compile_time, + code_gen_time, + fail_type, + fail_reason, + ) + record_compilation_metrics(metrics) + + +@overload +def compile_times(repr: Literal["str"], aggregate: bool = False) -> str: + ... + + +@overload +def compile_times( + repr: Literal["csv"], aggregate: bool = False +) -> Tuple[List[str], List[object]]: + ... + + +def compile_times(repr="str", aggregate: bool = False): + """ + Get metrics about torchdynamo frontend/backend compilation times. + + Accumulates information from functions tagged with `dynamo_timed`. 
+ + repr='str' returns a printable string for user interaction, and 'csv' + returns headers, rows which can be logged for output + + aggregate causes values from multiple compilations (e.g. split graphs) + to be accumulated into one value. If false, expect more than one value + per metric. + """ + + def fmt_fn(values, item_fn=lambda x: x): + if aggregate: + return item_fn(sum(values)) + return ", ".join(map(item_fn, values)) + + if repr == "str": + rows = [ + (k, fmt_fn(compilation_time_metrics[k], item_fn=lambda x: f"{x:.4f}")) + for k in compilation_time_metrics + ] + out = "TorchDynamo compilation metrics:\n" + out += tabulate(rows, headers=("Function", "Runtimes (s)")) + return out + elif repr == "csv": + values = [ + fmt_fn(v, item_fn=lambda x: f"{x:.6f}") + for v in compilation_time_metrics.values() + ] + headers = list(compilation_time_metrics.keys()) + return headers, values + return None + + +@atexit.register +def dump_compile_times() -> None: + log.info(compile_times(repr="str", aggregate=True)) + + +tensortype_to_dtype = { + torch.FloatTensor: (torch.float32, torch.float), + torch.DoubleTensor: (torch.float64, torch.double), + torch.HalfTensor: (torch.float16, torch.half), + torch.BFloat16Tensor: (torch.bfloat16,), + torch.ByteTensor: (torch.uint8,), + torch.CharTensor: (torch.int8,), + torch.LongTensor: (torch.int64, torch.long), + torch.IntTensor: (torch.int32, torch.int), + torch.ShortTensor: (torch.int16, torch.short), + torch.BoolTensor: (torch.bool,), +} + + +class DuplicateWarningChecker: + def __init__(self, maxsize: int = 4096) -> None: + self.maxsize = maxsize + self.reset() + + def reset(self): + self.set = collections.OrderedDict() + + def add(self, key: Union[str, Tuple[object, object]]) -> bool: + if key in self.set: + self.set.move_to_end(key, last=True) + if not config.verbose: + return False + else: + self.set[key] = None + while len(self.set) > self.maxsize: + self.set.popitem(last=False) + return True + + +graph_break_dup_warning_checker = DuplicateWarningChecker() + + +def setup_compile_debug(): + compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" + + if compile_debug: + return add_file_handler() + + return contextlib.ExitStack() + + +def reset_graph_break_dup_checker() -> None: + graph_break_dup_warning_checker.reset() + + +def add_file_handler(): + log_path = os.path.join(get_debug_dir(), "torchdynamo") + os.makedirs(log_path, exist_ok=True) + + log_file_handler = logging.FileHandler(os.path.join(log_path, "debug.log")) + logger = logging.getLogger("torch._dynamo") + logger.addHandler(log_file_handler) + + exitstack = contextlib.ExitStack() + exitstack.callback(lambda: logger.removeHandler(log_file_handler)) + return exitstack + + +def setup_log_file(): + exitstack = contextlib.ExitStack() + if config.log_file_name is not None: + log_file_handler = logging.FileHandler(config.log_file_name) + for logger in torch._logging._internal.get_loggers(): + logger.addHandler(log_file_handler) + exitstack.callback(lambda: logger.removeHandler(log_file_handler)) + return exitstack + + return exitstack + + +def gen_record_file_name(exc, code) -> str: + return f"{get_debug_dir()}/error_recordings/\ +{code.co_name}_{type(exc).__name__}_{code.co_firstlineno}.rec" + + +def write_record_to_file(filename: str, exec_record) -> None: + try: + if os.path.exists(filename): + log.warning( + "Unable to write execution record %s; file already exists.", filename + ) + else: + os.makedirs(os.path.dirname(filename), exist_ok=True) + with open(filename, "wb") as f: + 
exec_record.dump(f) + except Exception: + log.exception("Unable to write execution record %s", filename) + + +def count_calls(g: fx.Graph) -> int: + c = 0 + for n in g.nodes: + if "call" in n.op: + c += 1 + return c + + +def identity(x): + return x + + +def hashable(x): + try: + hash(x) + return True + except TypeError: + return False + # cannot hash writable memoryview object + except ValueError: + return False + + +def nothing(*args, **kwargs): + pass + + +class ExactWeakKeyDictionary: + """Similar to weakref.WeakKeyDictionary, but use `is`/`id` rather than `==` to compare equality""" + + def __init__(self): + self.values = {} + self.refs = {} + + def __getitem__(self, key): + return self.values[id(key)] + + def get(self, key, default=None): + return self.values.get(id(key), default) + + def __contains__(self, key): + return id(key) in self.values + + def __setitem__(self, key, value): + idx = id(key) + if idx not in self.refs: + self.refs[idx] = weakref.ref(key, lambda ref: self._remove_id(idx)) + self.values[idx] = value + + def _remove_id(self, idx): + if idx in self.values: + del self.values[idx] + if idx in self.refs: + del self.refs[idx] + + def clear(self): + self.refs.clear() + self.values.clear() + + +@overload +def istype(obj: object, allowed_types: Type[T]) -> TypeGuard[T]: + ... + + +@overload +def istype( + obj: object, allowed_types: Tuple[Type[List[T]], Type[Tuple[T, ...]]] +) -> TypeGuard[T]: + ... + + +@overload +def istype(obj: object, allowed_types: Iterable[type]) -> bool: + ... + + +def istype(obj, allowed_types): + """isinstance() without subclasses""" + if isinstance(allowed_types, (tuple, list, set)): + return type(obj) in allowed_types + return type(obj) is allowed_types + + +if sys.version_info >= (3, 12): + # Some typing classes moved to C in 3.12, + # which no longer have the _Final mixin. + _builtin_final_typing_classes = ( + typing.ParamSpecArgs, + typing.ParamSpecKwargs, + typing.ParamSpec, + typing.TypeVar, + typing.TypeVarTuple, + typing.TypeAliasType, + ) + + +def is_typing(value): + # _Final catches most of typing classes: + # - Any + # - Callable + # - Union + # ... + # + # NB: we intentionally ignore classes that inherit from Generic, since they + # can be used as both TypingVariable as well as UserDefinedClassVariable. + if sys.version_info >= (3, 12) and isinstance(value, _builtin_final_typing_classes): + return True + return isinstance(value, typing._Final) or value is typing.Generic # type: ignore[attr-defined] + + +def is_numpy_int_type(value): + if not np: + return False + + return istype( + value, + ( + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + ), + ) + + +def is_numpy_float_type(value): + if not np: + return False + + return istype( + value, + ( + np.float16, + np.float32, + np.float64, + ), + ) + + +def is_lru_cache_wrapped_function(value): + return isinstance(value, functools._lru_cache_wrapper) and is_function( + inspect.getattr_static(value, "__wrapped__") + ) + + +def is_function_or_wrapper(value): + return is_function(value) or isinstance( + value, (torch._ops.OpOverloadPacket, torch._ops.OpOverload) + ) + + +def is_function(value): + return isinstance( + value, + ( + types.FunctionType, + types.BuiltinFunctionType, + types.MethodDescriptorType, + types.WrapperDescriptorType, + ), + ) + + +def is_wrapper_or_member_descriptor(value): + return isinstance( + value, + ( + # set up by PyGetSetDef + types.GetSetDescriptorType, + # set by PyMethodDef, e.g. 
list.append + types.MethodDescriptorType, + # slots - list.__add__ + types.WrapperDescriptorType, + # set up by PyMemberDef + types.MemberDescriptorType, + # wrapper over C functions + types.MethodWrapperType, + ), + ) + + +def unwrap_if_wrapper(fn): + return unwrap_with_attr_name_if_wrapper(fn)[0] + + +def unwrap_with_attr_name_if_wrapper(fn): + # TODO(anijain2305) - Investigate if we can get rid of this function + # unpack @torch._dynamo.optimize()(fn) wrapped function + if is_function(fn) and inspect.getattr_static(fn, "_torchdynamo_inline", False): + fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn) + attr_name = "_torchdynamo_inline" + else: + attr_name = None + return fn, attr_name + + +def is_numpy_ndarray(value): + if not np: + return False + + return istype(value, np.ndarray) + + +def istensor(obj): + """Check of obj is a tensor""" + tensor_list: Tuple[type, ...] = ( + torch.Tensor, + torch.nn.Parameter, + *config.traceable_tensor_subclasses, + ) + tensor_list = tensor_list + (torch._subclasses.FakeTensor,) + return istype(obj, tensor_list) + + +def is_lazy_module(mod): + return isinstance(mod, LazyModuleMixin) + + +@functools.lru_cache(4096) +def print_once(*args): + print(*args) + + +def make_cell(val=None): + """Some black magic to create a cell object that usually only exists in a closure""" + x = val + + def f(): + return x + + assert f.__closure__ is not None and len(f.__closure__) == 1 + return f.__closure__[0] + + +def proxy_args_kwargs(args, kwargs): + try: + proxy_args = tuple(arg.as_proxy() for arg in args) + proxy_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()} + return proxy_args, proxy_kwargs + except NotImplementedError as e: + from .exc import unimplemented + from .variables.base import typestr + + unimplemented( + f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}", + from_exc=e, + ) + + +@dataclasses.dataclass +class CompilationMetrics: + compile_id: str + frame_key: str + co_name: str + co_filename: str + co_firstlineno: int + cache_size: int + accumulated_cache_size: int + guard_count: Optional[int] + shape_env_guard_count: Optional[int] + graph_op_count: Optional[int] + graph_node_count: Optional[int] + graph_input_count: Optional[int] + start_time: float + entire_frame_compile_time_s: Optional[float] + backend_compile_time_s: Optional[float] + inductor_compile_time_s: Optional[float] + code_gen_time_s: Optional[float] + fail_type: Optional[str] + fail_reason: Optional[str] + fail_user_frame_filename: Optional[str] + fail_user_frame_lineno: Optional[int] + non_compliant_ops: Set[str] + compliant_custom_ops: Set[str] + restart_reasons: Set[str] + dynamo_time_before_restart_s: float + # Sometimes, we will finish analyzing a frame but conclude we don't want + # to install any guarded code. 
True means we actually decided to install + # a compiled frame + has_guarded_code: bool + possibly_missed_reinplacing_opportunities: Optional[int] + + +@dataclasses.dataclass +class BwdCompilationMetrics: + compile_id: str + inductor_compile_time_s: Optional[float] + code_gen_time_s: Optional[float] + fail_type: Optional[str] + fail_reason: Optional[str] + + +DEFAULT_COMPILATION_METRICS_LIMIT = 64 + + +_compilation_metrics: Deque[ + Union[CompilationMetrics, BwdCompilationMetrics] +] = collections.deque(maxlen=DEFAULT_COMPILATION_METRICS_LIMIT) + + +def record_compilation_metrics( + compilation_metrics: Union[CompilationMetrics, BwdCompilationMetrics] +): + global _compilation_metrics + _compilation_metrics.append(compilation_metrics) + if isinstance(compilation_metrics, CompilationMetrics): + name = "compilation_metrics" + else: + name = "bwd_compilation_metrics" + torch._logging.trace_structured( + name, + lambda: { + k: list(v) if isinstance(v, set) else v + for k, v in dataclasses.asdict(compilation_metrics).items() + }, + ) + if config.log_compilation_metrics: + log_compilation_event(compilation_metrics) + + +def set_compilation_metrics_limit(new_size: int) -> None: + global _compilation_metrics + while len(_compilation_metrics) > new_size: + _compilation_metrics.popleft() + new_deque = collections.deque(_compilation_metrics, maxlen=new_size) + _compilation_metrics = new_deque + + +def clear_compilation_metrics() -> None: + global _compilation_metrics + _compilation_metrics.clear() + + +def get_compilation_metrics() -> List[Union[CompilationMetrics, BwdCompilationMetrics]]: + return list(_compilation_metrics) + + +class ChromiumEventLogger: + """Logs chromium events to structured logs. tlparse will concatenate these into a perfetto UI link. + + See https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview#heading=h.yr4qxyxotyw for + a specification of the Chromium Event JSON format. + """ + + def get_stack(self): + if hasattr(self.tls, "stack"): + return self.tls.stack + else: + self.tls.stack = ["__start__"] + return self.tls.stack + + def __init__(self): + self.tls = threading.local() + # Generate a unique id for this logger, which we can use in scuba to filter down + # to a single python run. + self.id_ = str(uuid.uuid4()) + + # TODO: log to init/id tlparse after I add support for it + log.info("ChromiumEventLogger initialized with id %s", self.id_) + + def log_event_start( + self, + event_name: str, + time_ns: int, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Logs the start of a single event. + :param str event_name Name of event to appear in trace + :param time_ns Timestamp in nanoseconds + :param metadata: Any extra metadata associated with this event + """ + event = self._log_timed_event( + event_name, + time_ns, + "B", + metadata, + ) + log_chromium_event_internal(event, self.get_stack(), self.id_) + self.get_stack().append(event_name) + + def reset(self) -> None: + # We this on every compile in case a compile crashes or restarts and we haven't + # cleared the stack. + stack = self.get_stack() + stack.clear() + stack.append("__start__") + + def log_event_end( + self, + event_name: str, + time_ns: int, + metadata: Optional[Dict[str, Any]] = None, + start_time_ns: Optional[int] = None, + ) -> None: + """ + Logs the end of a single event. This function should only be + called after log_event_start with the same event_name. 
+ :param event_name: Name of event to appear in trace + :param time_ns: Timestamp in nanoseconds + :param metadata: Any extra metadata associated with this event + """ + # These stack health checks currently never happen, + # but they're written this way to future proof any weird event + # overlaps in the future. + stack = self.get_stack() + if event_name not in stack: + # Something went wrong, we never called start on this event, + # or it was skipped due to overlapping events below + log.warning("ChromiumEventLogger: Start event not in stack, ignoring") + return + + event = self._log_timed_event( + event_name, + time_ns, + "E", + metadata, + ) + + while event_name != stack[-1]: + # If the event isn't the most recent one to end, pop + # off the stack until it is. + # Since event_name in self.stack, this pop is always safe + log.warning( + "ChromiumEventLogger: Detected overlapping events, fixing stack" + ) + stack.pop() + + log_chromium_event_internal(event, stack, self.id_, start_time_ns) + # Finally pop the actual event off the stack + stack.pop() + + def _log_timed_event( + self, + event_name: str, + time_ns: int, + phase: str, + metadata: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """ + Logs a timed event in chromium format. See log_event_start, log_event_end, etc. + """ + event = { + "name": event_name, + "ts": time_ns / 1000, # Chromium events are in micro seconds + "args": metadata, + "ph": phase, + # These categories are needed in all chromium traces + "cat": "dynamo_timed", + "tid": 0, + "pid": 0, # pid should be specified on all logs, we don't personally care about the actual process id + } + torch._logging.trace_structured( + "chromium_event", + payload_fn=lambda: event, + suppress_context=False, + expect_trace_id=False, # Not every chromium event will have a trace_id + ) + return event + + def log_instant_event( + self, + event_name: str, + time_ns: int, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Log an instant event with no associated duration. + :param str event_name: Name of event to appear in trace + :param int time_ns Timestamp in nanoseconds + :param Optional[Dict[str, Any]] metadata: Any extra metadata associated with this event + :param str cname optional color for the arrow in the trace + """ + event = { + "name": event_name, + "ts": time_ns / 1000, + "args": metadata, + "ph": "i", + # These categories are needed in all chromium traces + "cat": "dynamo_timed", + "tid": 0, + "pid": 0, + "s": "p", # We use "process" level instant events so they all appear on the same row in the trace. 
+ } + torch._logging.trace_structured( + "chromium_event", + payload_fn=lambda: event, + suppress_context=False, + expect_trace_id=True, + ) + # Log an instant event with the same start and end time + log_chromium_event_internal(event, self.get_stack(), self.id_) + + +CHROMIUM_EVENT_LOG: Optional[ChromiumEventLogger] = None + + +def get_chromium_event_logger() -> ChromiumEventLogger: + global CHROMIUM_EVENT_LOG + if CHROMIUM_EVENT_LOG is None: + CHROMIUM_EVENT_LOG = ChromiumEventLogger() + return CHROMIUM_EVENT_LOG + + +@dataclasses.dataclass +class CleanupHook: + """Remove a global variable when hook is called""" + + scope: Dict[str, Any] + name: str + + def __call__(self, *args): + # Make sure we're not shutting down + if CleanupManager is not None: + CleanupManager.count -= 1 + del self.scope[self.name] + + @staticmethod + def create(scope, name, val): + assert name not in scope + CleanupManager.count += 1 + scope[name] = val + return CleanupHook(scope, name) + + +class CleanupManager(ExactWeakKeyDictionary): + count = 0 + instance: ClassVar[CleanupManager] + + def _remove_id(self, idx): + for hook in self.values[idx]: + hook() + super()._remove_id(idx) + + +CleanupManager.instance = CleanupManager() + + +def clone_tensor(x): + """Clone the tensor and its gradient""" + y = x.clone().requires_grad_(x.requires_grad) + if x.is_leaf and x.grad is not None: + y.grad = x.grad.clone() + return y + + +def clone_input(x, *, dtype=None): + """copy while preserving strides""" + # TODO: this is questionable + if is_fake(x): + # this func fails on fake tensors in __torch_dispatch__ + return x + + def torch_clone(x): + y = torch.clone(x) + if x.is_leaf: + y.requires_grad_(x.requires_grad) + if x.is_leaf and x.grad is not None: + y.grad = clone_input(x.grad, dtype=dtype) + if hasattr(x, "_dynamo_dynamic_indices"): + y._dynamo_dynamic_indices = x._dynamo_dynamic_indices.copy() # type: ignore[attr-defined] + return y + + with torch.no_grad(): + if x.device.type == "xla": + # Access data_ptr() for a xla tensor will cause crash + return torch_clone(x) + + # Handle sparse storage (no stride). + if x.layout is torch.sparse_coo: + return torch.sparse_coo_tensor( + torch_clone(x._indices()), + torch_clone(x._values()), + x.shape, + is_coalesced=x.is_coalesced(), + ) + elif is_sparse_compressed(x): + if x.layout in {torch.sparse_csr, torch.sparse_bsr}: + compressed_indices = x.crow_indices() + plain_indices = x.col_indices() + else: + compressed_indices = x.ccol_indices() + plain_indices = x.row_indices() + return torch.sparse_compressed_tensor( + torch_clone(compressed_indices), + torch_clone(plain_indices), + torch_clone(x.values()), + x.shape, + layout=x.layout, + ) + + needed_size = sum( + (shape - 1) * stride for shape, stride in zip(x.size(), x.stride()) + ) + if x.is_quantized: + result = torch.empty_quantized((needed_size + 32,), x) + else: + result = torch.empty( + needed_size + 32, dtype=dtype or x.dtype, device=x.device + ) + cache_line_offset = ( + (x.data_ptr() - result.data_ptr()) % 32 + ) // x.element_size() + result.as_strided_(x.size(), x.stride(), cache_line_offset) + try: + result.copy_(x.clone()) + if x.is_leaf: + result.requires_grad_(x.requires_grad) + if x.is_leaf and x.grad is not None: + result.grad = clone_input(x.grad, dtype=dtype) + except RuntimeError: + # RuntimeError: unsupported operation: more than one element of the written-to + # tensor refers to a single memory location. Please clone() the tensor before + # performing the operation. 
+ return torch_clone(x) + if hasattr(x, "_dynamo_dynamic_indices"): + result._dynamo_dynamic_indices = x._dynamo_dynamic_indices.copy() # type: ignore[attr-defined] + return result + + +def clone_inputs(example_inputs): + res: Union[Dict[Any, Any], List[Any]] + if type(example_inputs) is dict: + res = dict(example_inputs) + for key, value in res.items(): + if isinstance(value, tuple): + res[key] = clone_inputs(value) + else: + assert isinstance(value, torch.Tensor), type(value) + res[key] = clone_input(value) + return res + + res = list(example_inputs) + for i in range(len(res)): + if isinstance(res[i], torch.Tensor): + res[i] = clone_input(res[i]) + return res + + +def skip_frame_if_in_functorch_mode(val: torch.Tensor): + try: + val.data_ptr() # will throw for functorch tensors + except RuntimeError as e: + from .exc import SkipFrame + + # This will be GradTrackingTensor/BatchedTensor/etc + functorch_subclass_name = re.sub(r"\(.*", "", repr(val)) + raise SkipFrame( + f"torch.compile cannot be run in context: {functorch_subclass_name}" + ) from e + + +@contextmanager +def preserve_rng_state(): + disable_functorch = torch._C._DisableFuncTorch + disable_current_modes = torch.utils._python_dispatch._disable_current_modes + with disable_current_modes(), disable_functorch(): + rng_state = torch.clone(torch.random.get_rng_state()) + skip_frame_if_in_functorch_mode(rng_state) + if torch.cuda.is_available(): + cuda_rng_state = torch.clone(torch.cuda.get_rng_state()) + try: + yield + finally: + with torch.utils._python_dispatch._disable_current_modes(): + torch.random.set_rng_state(rng_state) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(cuda_rng_state) # type: ignore[possibly-undefined] + + +def is_jit_model(model0): + return isinstance( + model0, + ( + torch.jit._trace.TopLevelTracedModule, + torch.jit._script.RecursiveScriptModule, + torch.jit.ScriptFunction, + torch.jit.ScriptModule, + ), + ) + + +def torchscript(model, example_inputs, verbose=False): + if is_jit_model(model): + # already done? + return model + + try: + return torch.jit.trace(model, example_inputs) + except Exception: + try: + return torch.jit.script(model) + except Exception: + if verbose: + log.exception("jit error") + else: + log.error("Both torch.jit.trace and torch.jit.script failed") + return None + + +def getfile(obj): + try: + return inspect.getfile(obj) + except (TypeError, OSError): + return None + + +def is_namedtuple(obj): + """Test if an object is a namedtuple or a torch.return_types.* quasi-namedtuple""" + return is_namedtuple_cls(type(obj)) + + +def is_namedtuple_cls(cls): + """Test if an object is a namedtuple or a (torch.return_types|torch.autograd.forward_ad).* quasi-namedtuple""" + try: + if issubclass(cls, tuple): + bases = getattr(cls, "__bases__", []) or [None] + module = getattr(cls, "__module__", None) + return module in ("torch.return_types", "torch.autograd.forward_ad") or ( + bases[0] is tuple and hasattr(cls, "_make") and hasattr(cls, "_fields") + ) + except TypeError: + pass + return False + + +@functools.lru_cache(1) +def namedtuple_fields(cls): + """Get the fields of a namedtuple or a torch.return_types.* quasi-namedtuple""" + if cls is slice: + return ["start", "stop", "step"] + + assert issubclass(cls, tuple) + if hasattr(cls, "_fields"): + # normal namedtuples + return cls._fields + + @dataclasses.dataclass + class Marker: + index: int + + # frustrating ones e.g. 
torch.return_types.max + assert cls.__module__ == "torch.return_types" + obj = cls(map(Marker, range(cls.n_fields))) + fields: List[Optional[str]] = [None] * cls.n_fields + for name in dir(obj): + if name[0] != "_" and isinstance(getattr(obj, name), Marker): + fields[getattr(obj, name).index] = name + return fields + + +def checkpoint_params(gm): + with torch.no_grad(): + rng_state = torch.clone(torch.random.get_rng_state()) + if torch.cuda.is_available(): + cuda_rng_state = torch.clone(torch.cuda.get_rng_state()) + saved_state = [] + for param in itertools.chain(gm.parameters(), gm.buffers()): + saved_state.append((param, param._version, torch.clone(param))) + + def restore(): + with torch.no_grad(): + torch.random.set_rng_state(rng_state) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(cuda_rng_state) + for param, version, original_value in saved_state: + if param._version != version: + param.copy_(original_value) + + return restore + + +def timed(model, example_inputs, times=1): + if torch.cuda.is_available(): + synchronize = torch.cuda.synchronize + else: + synchronize = nothing + + synchronize() + gc.collect() + torch.manual_seed(1337) + t0 = time.perf_counter() + for _ in range(times): + result = model(*example_inputs) + synchronize() + t1 = time.perf_counter() + return result, t1 - t0 # type: ignore[possibly-undefined] + + +def check_is_cuda(gm, example_inputs): + return all(x.is_cuda for x in itertools.chain(example_inputs, gm.parameters(True))) + + +@lru_cache(32) +def rot_n_helper(n): + assert n > 1 + vars = [f"v{i}" for i in range(n)] + rotated = reversed(vars[-1:] + vars[:-1]) + fn = eval(f"lambda {','.join(vars)}: ({','.join(rotated)})") + fn.__name__ = f"rot_{n}_helper" + return fn + + +common_constant_types: Set[type] = { + int, + float, + complex, + bool, + str, + bytes, + type(None), + Ellipsis.__class__, + types.CodeType, + torch.device, + torch.dtype, + torch.memory_format, + torch.layout, +} + +if has_triton_package(): + import triton + + common_constant_types.add(triton.language.dtype) + +""" + Difference between is_safe_constant and common_constant_types. + * common_constant_types: Constants would be wrapped by VariableBuilder.wrap_literal + as ConstantVariable. + * is_safe_constant: Constants can be loaded by LOAD_CONST bytecode. 
+""" + + +def is_safe_constant(v): + if istype(v, (tuple, frozenset)): + return all(map(is_safe_constant, v)) + return isinstance(v, (enum.Enum, type, torch.Size)) or istype( + v, + common_constant_types | {slice}, + ) + + +def specialize_symnode(arg): + from .variables import ConstantVariable, SymNodeVariable + + # Guard and specialize + if isinstance(arg, SymNodeVariable): + return ConstantVariable.create(arg.evaluate_expr()) + + return arg + + +def guard_if_dyn(arg): + from .variables import ConstantVariable + + arg = specialize_symnode(arg) + + if isinstance(arg, ConstantVariable): + return arg.as_python_constant() + + return arg + + +def check_constant_args(args, kwargs): + return all(x.is_python_constant() for x in itertools.chain(args, kwargs.values())) + + +def check_unspec_python_args(args, kwargs): + from .variables.constant import ConstantVariable + from .variables.tensor import UnspecializedPythonVariable + + unspec_count = 0 + for x in itertools.chain(args, kwargs.values()): + if isinstance(x, UnspecializedPythonVariable): + unspec_count += 1 + elif not isinstance(x, ConstantVariable): + return False + return unspec_count > 0 + + +def check_unspec_or_constant_args(args, kwargs): + # A fused version of: + # return check_constant_args(args, kwargs) or check_unspec_python_args(args, kwargs) + from .variables.tensor import UnspecializedPythonVariable + + for x in itertools.chain(args, kwargs.values()): + if not (x.is_python_constant() or isinstance(x, UnspecializedPythonVariable)): + return False + return True + + +def check_numpy_ndarray_args(args, kwargs): + from .variables.tensor import NumpyNdarrayVariable + + return any( + isinstance(x, NumpyNdarrayVariable) + for x in itertools.chain(args, kwargs.values()) + ) + + +dict_keys: Type[KeysView[Any]] = type({}.keys()) +dict_values: Type[ValuesView[Any]] = type({}.values()) +odict_values: Type[ValuesView[Any]] = type(collections.OrderedDict().values()) +tuple_iterator: Type[Iterator[Any]] = type(iter(())) +tuple_iterator_len = tuple_iterator.__length_hint__ # type: ignore[attr-defined] +object_new = object.__new__ + + +def nn_module_new(cls): + obj = object_new(cls) + torch.nn.Module.__init__(obj) + return obj + + +def product(it): + return functools.reduce(operator.mul, it, 1) + + +def tuple_iterator_getitem(it, index): + _, (obj,), start = it.__reduce__() + return obj[start + index] + + +iter_next = next + + +def to_subclass(t, cls): + return t.as_subclass(cls) + + +def dict_keys_getitem(d, n): + return next(itertools.islice(iter(d), n, n + 1)) + + +def enum_repr(value, local): + # enum class can override __str__ method. Use __class__ and name attribute + # to extract the class name and key name. + name = value.__class__.__name__ + val = value.name + scope = "L" if local else "G" + local_name = f'{scope}["{name}"].{val}' + return local_name + + +def set_example_value(node, example_value): + # NB: example_value is a bit of a misnomer, because this is always a fake + # tensor of some sort. Furthermore, these example values serve as the + # runtime state of Dynamo tracing, which means if metadata mutation + # occurs, the example_value gets directly updated (so you can't rely on + # this to accurately reflect what the state of the value was at the time + # the program was traced). 
+ node.meta["example_value"] = example_value + shape_env = TracingContext.get().fake_mode.shape_env + if symbol_to_path := torch.fx.experimental.symbolic_shapes.compute_unbacked_bindings( + shape_env, example_value + ): + node.meta["unbacked_bindings"] = symbol_to_path + + +def _get_fake_tensor(vt): + fake_tensor = vt.as_proxy().node.meta.get("example_value") + if not is_fake(fake_tensor): + from .exc import unimplemented + + unimplemented("Cannot check Tensor object identity without its fake value") + return fake_tensor + + +def iter_contains(items, search, tx, check_tensor_identity=False): + from .variables import ( + BuiltinVariable, + ConstantVariable, + TensorVariable, + VariableTracker, + ) + + if search.is_python_constant(): + found_const = any( + x.is_python_constant() + and x.as_python_constant() == search.as_python_constant() + for x in items + ) + return ConstantVariable.create(found_const) + + must_check_tensor_id = False + if check_tensor_identity and isinstance(search, TensorVariable): + must_check_tensor_id = True + # Match of Tensor means match of FakeTensor + search = _get_fake_tensor(search) + + found: Optional[VariableTracker] = None + for x in items: + if must_check_tensor_id: + if isinstance(x, TensorVariable): + if search is _get_fake_tensor(x): # Object equivalence + return ConstantVariable.create(True) + else: + check = BuiltinVariable(operator.eq).call_function(tx, [x, search], {}) + if found is None: + found = check + else: + found = BuiltinVariable(operator.or_).call_function( + tx, [check, found], {} + ) + if found is None: + found = ConstantVariable.create(False) + return found + + +def key_is_id(k): + """Returns whether it indexes dictionaries using its id""" + return isinstance(k, (torch.Tensor, torch.nn.Module, MethodWrapperType)) + + +def key_to_id(value): + return [id(k) if key_is_id(k) else k for k in value.keys()] + + +def const_repr(x, *, local) -> str: + from .trace_rules import is_builtin_callable + + if isinstance(x, (list, tuple)): + elems_repr = ",".join(const_repr(s, local=local) for s in x) + if isinstance(x, list): + return f"[{elems_repr}]" + else: + assert isinstance(x, tuple) + if len(x) == 1: + return f"({elems_repr},)" + else: + return f"({elems_repr})" + elif isinstance(x, enum.Enum): + # To workaround repr(Enum) returning invalid global reference before python 3.11 + # by calling enum_repr and removing quotes to render enum in guard code. + return enum_repr(x, local=local).replace("'", "") + elif is_builtin_callable(x): + return x.__name__ + elif isinstance(x, type): + + def fullname(o): + klass = o.__class__ + module = klass.__module__ + if module == "builtins": + return klass.__qualname__ # avoid outputs like 'builtins.str' + return module + "." + klass.__qualname__ + + return fullname(x) + else: + return f"{x!r}" + + +def dict_keys_repr(const_keys, *, local) -> str: + keys_str = ",".join(const_repr(s, local=local) for s in const_keys) + return "[" + keys_str + "]" + + +GLOBAL_KEY_PREFIX = "__dict_key" + + +from torch._subclasses import UnsupportedFakeTensorException # noqa: F401 + + +def get_safe_global_name(tx, root, obj): + # The global_mangled_class_name should be different for different + # invocations of torch.compile. Otherwise, we can run into a situation + # where multiple torch.compile invocations re-use the same global name, + # but the global's lifetime is tied to the first invocation (and + # may be deleted when the first torch.compile invocation is deleted) + # We mangle it based off of the output_graph's id. 
+ return f"{root}_{id(obj)}_c{tx.output.compile_id}" + + +def wrap_fake_exception(fn): + try: + return fn() + except UnsupportedFakeTensorException as e: + from .exc import unimplemented + + msg = f"Unsupported: {e.reason} with fake tensor propagation." + log.warning(msg) + unimplemented(msg, from_exc=e) + + +def deepcopy_to_fake_tensor(obj, fake_mode): + with torch._subclasses.fake_tensor.FakeCopyMode(fake_mode): + return wrap_fake_exception(lambda: copy.deepcopy(obj)) + + +def rmse(ref, res): + """ + Calculate root mean squared error + """ + return torch.sqrt(torch.mean(torch.square(ref - res))) + + +def same( + ref, + res, + fp64_ref=None, + cos_similarity=False, + tol=1e-4, + equal_nan=False, + exact_dtype=True, + relax_numpy_equality=False, + ignore_non_fp=False, + log_error=log.error, + use_larger_multiplier_for_smaller_tensor=False, +): + """Check correctness to see if ref and res match""" + if fp64_ref is None: + fp64_ref = ref + if isinstance( + ref, (list, tuple, collections.deque, torch.nn.ParameterList, torch.Size) + ): + assert isinstance( + res, (list, tuple, collections.deque) + ), f"type mismatch {type(ref)} {type(res)}" + if len(ref) != len(res): + log_error("Length mismatch") + return False + return len(ref) == len(res) and all( + same( + ai, + bi, + fp64_refi, + cos_similarity, + tol, + equal_nan, + exact_dtype, + relax_numpy_equality, + ignore_non_fp, + log_error=log_error, + use_larger_multiplier_for_smaller_tensor=use_larger_multiplier_for_smaller_tensor, + ) + for ai, bi, fp64_refi in zip(ref, res, fp64_ref) + ) + elif type(ref).__name__ == "QuestionAnsweringModelOutput": + # This skips checking accuracy for start_logits/end_logits. + # Tentatively, start_logits/end_logits appear to be very prone to + # inaccuracies and is somewhat subsumed by checking the loss. 
+ return same( + ref.loss, + res.loss, + fp64_ref.loss, + cos_similarity, + tol, + equal_nan, + exact_dtype, + relax_numpy_equality, + ignore_non_fp, + log_error=log_error, + use_larger_multiplier_for_smaller_tensor=use_larger_multiplier_for_smaller_tensor, + ) + elif isinstance(ref, dict): + assert isinstance(res, dict) + assert set(ref.keys()) == set( + res.keys() + ), f"keys mismatch {set(ref.keys())} == {set(res.keys())}" + for k in sorted(ref.keys()): + if not ( + same( + ref[k], + res[k], + fp64_ref[k], + cos_similarity=cos_similarity, + tol=tol, + equal_nan=equal_nan, + exact_dtype=exact_dtype, + relax_numpy_equality=relax_numpy_equality, + ignore_non_fp=ignore_non_fp, + log_error=log_error, + use_larger_multiplier_for_smaller_tensor=use_larger_multiplier_for_smaller_tensor, + ) + ): + log_error("Accuracy failed for key name %s", k) + return False + return True + elif isinstance(ref, set): + assert isinstance(res, set) + assert set(ref) == set(res), f"elements mismatch {set(ref)} == {set(res)}" + return True + elif isinstance(ref, (torch.Tensor, float)): + assert not isinstance(ref, torch._subclasses.FakeTensor) + assert not isinstance(res, torch._subclasses.FakeTensor) + + def to_tensor(t): + return t if isinstance(t, torch.Tensor) else torch.tensor(t) + + ref, res, fp64_ref = (to_tensor(val) for val in (ref, res, fp64_ref)) + + if ref.is_sparse: + assert res.is_sparse + ref = ref.to_dense() + res = res.to_dense() + assert isinstance(res, torch.Tensor), f"type mismatch {type(ref)} {type(res)}" + if exact_dtype: + if ref.dtype != res.dtype: + log_error("dtype mismatch %s, %s", ref.dtype, res.dtype) + return False + if ref.dtype == torch.bool: + if ignore_non_fp: + return True + # triton stores bool as int8, so add this for more accurate checking + r = torch.allclose( + ref.to(dtype=torch.uint8), + res.to(dtype=torch.uint8), + atol=tol, + rtol=tol, + equal_nan=equal_nan, + ) + if not r: + log_error("Accuracy failed: uint8 tensor did not match") + return r + + if cos_similarity: + ref = ref.flatten().to(torch.float32) + res = res.flatten().to(torch.float32) + if torch.allclose(ref, res, atol=tol, rtol=tol, equal_nan=True): + # early exit that handles zero/nan better + # cosine_similarity(zeros(10), zeros(10), dim=0) is 0 + return True + score = torch.nn.functional.cosine_similarity(ref, res, dim=0, eps=1e-6) + if score < 0.99: + log.warning("Similarity score=%s", score.cpu().detach().item()) + return score >= 0.99 + else: + if not exact_dtype: + ref = ref.to(res.dtype) + + # First try usual allclose + if torch.allclose(ref, res, atol=tol, rtol=tol, equal_nan=equal_nan): + return True + + # Check error from fp64 version + if fp64_ref.dtype == torch.float64: + # Fix a corner case that res and fp64_ref does not contains NaN and match (with loose tolerance) + # while the ref contains NaN. In this case, RMSE should not match any ways. + # But res is 'BETTER' than ref so we count it pass. + # + # This happens for Super_SloMo when loop ordering after fusion is enabled: + # https://gist.github.com/shunting314/11f235c70f7db0d52718d26f4a701cab + loose_tol = 1e-2 * 4 + if ( + not fp64_ref.isnan().any() + and not res.isnan().any() + and ref.isnan().any() + and torch.allclose( + fp64_ref.to(dtype=res.dtype), + res, + atol=loose_tol, + rtol=loose_tol, + equal_nan=equal_nan, + ) + ): + return True + ref_error = rmse(fp64_ref, ref).item() + # ref unable to produce this with stable numerics in this precision, ignore + if math.isnan(ref_error): + log.warning( + "Found nan in reference. 
Consider running in higher precision." + ) + + res_error = rmse(fp64_ref, res).item() + + # In the case of using AMP (Automatic Mixed Precision), certain models have + # failed the benchmark's correctness check. However, the end-to-end model's + # accuracy when comparing AMP with FP32 is within a difference of less than 0.1%. + # Thus, it's possible that the correctness check failures for these models are + # false alarms. We use multiplier of 3 instead of 2 to avoid these false alarms. + multiplier = ( + 3.0 if res.dtype in (torch.float16, torch.bfloat16) else 2.0 + ) + + if use_larger_multiplier_for_smaller_tensor and ( + fp64_ref.numel() <= 10 and tol >= 4 * 1e-2 + ): + multiplier = 10.0 + elif use_larger_multiplier_for_smaller_tensor and ( + fp64_ref.numel() <= 500 and tol >= 4 * 1e-2 + ): + multiplier = 5.0 + elif ( + fp64_ref.numel() < 1000 + or (ref.ndim == 4 and ref.shape[-1] == ref.shape[-2] == 1) + # large tol means a benchmark has been specified as REQUIRE_HIGHER_TOLERANCE + or tol >= 2 * 1e-2 + ): + # In the presence of noise, noise might dominate our error + # metric for smaller tensors. + # Similary, for 1x1 kernels, there seems to be high noise with amp. + multiplier = 3.0 + + passes_test = res_error <= (multiplier * ref_error + tol / 10.0) + if ( + not passes_test + and equal_nan + and math.isnan(ref_error) + and math.isnan(res_error) + # Some unit test for the accuracy minifier relies on + # returning false in this case. + and not inductor_config.cpp.inject_relu_bug_TESTING_ONLY + ): + passes_test = True + if not passes_test: + log_error( + "RMSE (res-fp64): %.5f, (ref-fp64): %.5f and shape=%s. res.dtype: %s, multiplier: %f, tol: %f" + ", use_larger_multiplier_for_smaller_tensor: %d", + res_error, + ref_error, + res.size(), + res.dtype, + multiplier, + tol, + use_larger_multiplier_for_smaller_tensor, + ) + return passes_test + + if ignore_non_fp: + return True + + log_error("Accuracy failed: allclose not within tol=%s", tol) + return False + elif isinstance(ref, (str, int, type(None), bool, torch.device)): + if ignore_non_fp: + return True + r = ref == res + if not r: + log_error("Accuracy failed (%s): %s != %s", type(ref), ref, res) + return r + elif is_numpy_int_type(ref) or is_numpy_float_type(ref): + if relax_numpy_equality and not ( + is_numpy_int_type(res) or is_numpy_float_type(res) + ): + ref = ref.item() + r = (type(ref) is type(res)) and (ref == res) + if not r: + log_error("Accuracy failed (numpy): %s != %s", ref, res) + return r + elif is_numpy_ndarray(ref): + return (type(ref) is type(res)) and same( + torch.as_tensor(ref), + torch.as_tensor(res), + fp64_ref, + cos_similarity=cos_similarity, + tol=tol, + equal_nan=equal_nan, + exact_dtype=exact_dtype, + relax_numpy_equality=relax_numpy_equality, + ignore_non_fp=ignore_non_fp, + log_error=log_error, + use_larger_multiplier_for_smaller_tensor=use_larger_multiplier_for_smaller_tensor, + ) + elif type(ref).__name__ in ( + "MaskedLMOutput", + "Seq2SeqLMOutput", + "CausalLMOutputWithCrossAttentions", + "LongformerMaskedLMOutput", + "Instances", + "SquashedNormal", + "Boxes", + "Normal", + "TanhTransform", + "Foo", + "Variable", + ): + assert type(ref) is type(res) + return all( + same( + getattr(ref, key), + getattr(res, key), + getattr(fp64_ref, key), + cos_similarity=cos_similarity, + tol=tol, + equal_nan=equal_nan, + exact_dtype=exact_dtype, + relax_numpy_equality=relax_numpy_equality, + ignore_non_fp=ignore_non_fp, + log_error=log_error, + 
use_larger_multiplier_for_smaller_tensor=use_larger_multiplier_for_smaller_tensor, + ) + for key in ref.__dict__.keys() + ) + else: + raise RuntimeError(f"unsupported type: {type(ref).__name__}") + + +def format_func_info(code): + short_filename = code.co_filename.split("/")[-1] + return f"'{code.co_name}' ({short_filename}:{code.co_firstlineno})" + + +@contextlib.contextmanager +def disable_cache_limit(): + prior = config.cache_size_limit + config.cache_size_limit = sys.maxsize + prior_acc_limit = config.accumulated_cache_size_limit + config.accumulated_cache_size_limit = sys.maxsize + + try: + yield + finally: + config.cache_size_limit = prior + config.accumulated_cache_size_limit = prior_acc_limit + + +# map from transformed code back to original user code +orig_code_map = ExactWeakKeyDictionary() + +# keep a record of code_obj -> list of guard failure reasons for logging +guard_failures: DefaultDict[Any, List[Any]] = collections.defaultdict(list) + +# Keep a record of graph break reasons for logging +graph_break_reasons: List[torch._dynamo.output_graph.GraphCompileReason] = [] + +# keep record of compiled code, if we are in "error if recompile" +# to track code that dynamo has compiled previously +seen_code_map = ExactWeakKeyDictionary() + + +# return same dir unless user changes config between calls +@functools.lru_cache(None) +def _get_debug_dir(root_dir): + dir_name = ( + "run_" + + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") + # use pid to avoid conflicts among ranks + + "-pid_" + + str(os.getpid()) + ) + return os.path.join(root_dir, dir_name) + + +def get_debug_dir(): + debug_root = config.debug_dir_root + return _get_debug_dir(debug_root) + + +def extract_fake_example_value(node, required=True): + if "example_value" in node.meta and is_fake(node.meta["example_value"]): + return node.meta["example_value"] + elif required: + from torch._dynamo.exc import unimplemented + + unimplemented("`FakeTensor` example value was required but not available") + else: + return None + + +def ensure_graph_fake(e, tx): + assert maybe_get_fake_mode(e) is tx.fake_mode + return e + + +def get_fake_values_from_nodes(tx, nodes, allow_non_graph_fake): + def visit(n: torch.fx.Node): + if n.op == "call_function" and "example_value" not in n.meta: + # fake tensor validity is checked inside get_fake_value using + # ensure_graph_fake + return get_fake_value(n, tx, allow_non_graph_fake) + + out = n.meta["example_value"] + if not allow_non_graph_fake and isinstance(out, torch.Tensor): + return ensure_graph_fake(out, tx) + return out + + return torch.fx.node.map_arg(nodes, visit) + + +def get_fake_value(node, tx, allow_non_graph_fake=False): + """ + Run the computation represented by `node` using fake tensors and return the result. + + allow_non_graph_fake: whether to allow the return result to be: + 1. non-fake or 2. fake that is not created by this instance of Dynamo. + If `True`, you must be prepared to deal with such return values, ideally + by further wrapping them as this graph's fakes. 
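    Illustrative usage (editorial sketch, assuming `tx` is the current
    InstructionTranslator):

        fake_out = get_fake_value(node, tx)

    If the node already carries a fake ``meta["example_value"]``, that value is
    returned directly instead of re-running the op.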
+ """ + from torch.utils._sympy.value_ranges import ValueRangeError + + from .exc import ( + TorchRuntimeError, + unimplemented, + Unsupported, + UserError, + UserErrorType, + ) + + op = node.op + + # FX Node should always return the same fake value + if "example_value" in node.meta and is_fake(node.meta["example_value"]): + return node.meta["example_value"] + + args, kwargs = get_fake_values_from_nodes( + tx, (node.args, node.kwargs), allow_non_graph_fake + ) + + nnmodule = None + if op == "call_method" and len(args) > 0 and isinstance(args[0], torch.nn.Module): + # If the first argument is nn.Module, should copy to fake mode. + args = (deepcopy_to_fake_tensor(args[0], tx.fake_mode),) + tuple(args[1:]) + + if op == "call_module": + nnmodule = tx.output.nn_modules[node.target] + + if is_lazy_module(nnmodule) and hasattr(nnmodule, "_initialize_hook"): + # In the case of a lazy module, we want to run + # the pre-hooks which initialize it. + # Afterwards, lazy module deletes its pre-hooks + # to avoid treating it as lazy on subsequent recompile. + nnmodule._infer_parameters(nnmodule, args) + + # no matter it's lazy module or not, we should copy to fake mode. + nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode) + + try: + with tx.fake_mode, enable_python_dispatcher(): + ret_val = wrap_fake_exception( + lambda: run_node(tx.output, node, args, kwargs, nnmodule) + ) + except Unsupported: + raise + except RuntimeError as e: + cause: BaseException = e + if e.__cause__ is not None: + cause = e.__cause__ + + if isinstance( + cause, torch._subclasses.fake_tensor.DataDependentOutputException + ): + unimplemented( + f"data dependent operator: {cause.func}; " + "to enable, set torch._dynamo.config.capture_scalar_outputs = True" + ) + elif isinstance( + cause, torch._subclasses.fake_tensor.DynamicOutputShapeException + ): + if not torch._dynamo.config.capture_dynamic_output_shape_ops: + unimplemented( + f"dynamic shape operator: {cause.func}; " + "to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True" + ) + else: + unimplemented( + f"dynamic shape operator: {cause.func}; " + "Operator does not have a meta kernel that supports dynamic output shapes, " + "please report an issue to PyTorch" + ) + elif isinstance( + cause, torch._subclasses.fake_tensor.UnsupportedOperatorException + ): + op = cause.func + import_suggestion = "" + if isinstance(op, torch._ops.OpOverload): + maybe_pystub = torch._C._dispatch_pystub( + op._schema.name, op._schema.overload_name + ) + if maybe_pystub is not None: + module, ctx = maybe_pystub + import_suggestion = ( + f"It's possible that the support was implemented in " + f"module `{module}` and you may need to `import {module}`" + f"({ctx}), otherwise " + ) + unimplemented( + f"unsupported operator: {cause.func} ({import_suggestion}see " + "https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0" + " for how to fix)" + ) + elif isinstance( + cause, torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode + ): + raise UserError( # noqa: B904 + UserErrorType.CONSTRAINT_VIOLATION, + str(cause), + case_name="constrain_as_size_example", + ) + elif isinstance(cause, ValueRangeError): + raise UserError(UserErrorType.CONSTRAINT_VIOLATION, e.args[0]) from e + elif isinstance(cause, TypeError) and "argument" in str(cause): + unimplemented(f"TypeError {node.target}: {cause}") + + raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None + + if not allow_non_graph_fake: + _ = 
pytree.tree_map_only( + torch.Tensor, functools.partial(ensure_graph_fake, tx=tx), ret_val + ) + return ret_val + + +_current_node = threading.local() + + +def get_current_node(): + return getattr(_current_node, "value", None) + + +@contextmanager +def set_current_node(node): + old = get_current_node() + _current_node.value = node + try: + yield + finally: + _current_node.value = old + + +def run_node(tracer, node, args, kwargs, nnmodule): + """ + Runs a given node, with the given args and kwargs. + + Behavior is dictated by a node's op. + + run_node is useful for extracting real values out of nodes. + See get_real_value for more info on common usage. + + Note: The tracer arg is only used for 'get_attr' ops + Note: The nnmodule arg is only used for 'call_module' ops + + Nodes that are not call_function, call_method, call_module, or get_attr will + raise an AssertionError. + """ + op = node.op + + with set_current_node(node): + + def make_error_message(e): + return f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n" + str(e) + + try: + if op == "call_function": + return node.target(*args, **kwargs) + elif op == "call_method": + return getattr(args[0], node.target)(*args[1:], **kwargs) + elif op == "call_module": + assert nnmodule is not None + return nnmodule(*args, **kwargs) + elif op == "get_attr": + return tracer.output_graph.get_submodule(node.target) + elif op == "placeholder": + assert "example_value" in node.meta + return node.meta["example_value"] + + except (NotImplementedError, UnsupportedFakeTensorException) as e: + # NB: mimic how wrap_fake_exception does it + from .exc import unimplemented + + unimplemented(make_error_message(e), from_exc=e) + except Exception as e: + raise RuntimeError(make_error_message(e)).with_traceback( + e.__traceback__ + ) from e + + raise AssertionError(op) + + +def get_real_value(node, tracer): + """ + Run the actual computation represented by `node` and return the result. + This will execute any dependent nodes in the graph as well. + """ + from .exc import TorchRuntimeError + + cache = tracer.real_value_cache + if node in cache: + return cache[node] + + op = node.op + args, kwargs = torch.fx.node.map_arg( # type: ignore[misc] + (node.args, node.kwargs), + lambda n: get_real_value(n, tracer), + ) + + if op == "placeholder" and "grapharg" in node.meta: + return node.meta["grapharg"].example + + if op == "call_module": + nn_module = tracer.output_graph.nn_modules[node.target] + if not is_lazy_module(nn_module): + nn_module = copy.deepcopy(nn_module) + else: + # In the case of a lazy module, we want to run + # the pre-hooks which initialize it + nn_module(*args, **kwargs) + else: + nn_module = None + + try: + real_value = run_node(tracer, node, args, kwargs, nn_module) + cache[node] = real_value + except RuntimeError as e: + raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None + return real_value + + +def assert_no_fake_params_or_buffers(gm): + from torch._subclasses.fake_tensor import FakeTensorConfig, is_fake + + def stack_or_hint(t): + if FakeTensorConfig.debug: + import traceback + + return f"FAKE TENSOR CREATION TRACEBACK: \n {traceback.format_list(t._debug_trace)}" + else: + return "Enable TORCH_FAKE_TENSOR_DEBUG=1 to get creation stack traces on fake tensors." 
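    # (Editorial note, not in the original source) The checks below walk
    # gm.named_buffers() and gm.named_parameters() and raise if any fake tensor
    # has leaked into the module's state.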
+ + for name, buffer in gm.named_buffers(): + assert not is_fake( + buffer + ), f"Unexpected fake buffer {name} {stack_or_hint(buffer)}" + for name, param in gm.named_parameters(): + assert not is_fake( + param + ), f"Unexpected fake param {name} {stack_or_hint(param)}" + + +def fqn(obj: Any): + """ + Returns the fully qualified name of the object. + """ + return f"{obj.__module__}.{obj.__qualname__}" + + +def ifdynstaticdefault(count1, count2): + if torch._dynamo.config.assume_static_by_default: + return count1 + else: + return count2 + + +def import_submodule(mod: types.ModuleType): + """ + Ensure all the files in a given submodule are imported + """ + for filename in sorted(os.listdir(os.path.dirname(cast(str, mod.__file__)))): + if filename.endswith(".py") and filename[0] != "_": + importlib.import_module(f"{mod.__name__}.{filename[:-3]}") + + +def object_has_getattribute(value: Any): + return class_has_getattribute(type(value)) + + +def class_has_getattribute(cls: type): + try: + if isinstance( + inspect.getattr_static(cls, "__getattribute__"), + types.FunctionType, + ): + return True + except AttributeError: + pass + return False + + +def get_custom_getattr(value: Any, ignore_nn_module_getattr: bool = False): + try: + getattr_fn = inspect.getattr_static(type(value), "__getattr__") + except AttributeError: + getattr_fn = None + if ignore_nn_module_getattr and getattr_fn is torch.nn.Module.__getattr__: + # ignore this case of getattr + getattr_fn = None + return getattr_fn + + +class TensorStaticReason(enum.Enum): + PARAMETER = 2 + NOT_TENSOR = 4 + NN_MODULE_PROPERTY = 5 + + +def tensor_static_reason_to_message(reason: TensorStaticReason): + if reason == TensorStaticReason.PARAMETER: + return "mark_dynamic on parameter, parameters are always static today." + if reason == TensorStaticReason.NOT_TENSOR: + return "mark_dynamic on a non tensor, how did this happen?" + if reason == TensorStaticReason.NN_MODULE_PROPERTY: + return "tensor is static because it is nn module associated." + raise AssertionError(f"Illegal reason {reason}") + + +def tensor_always_has_static_shape( + tensor: Union[torch.Tensor, Any], + is_tensor: bool, + tensor_source: Source, +) -> Tuple[bool, Optional[TensorStaticReason]]: + """ + Given a tensor, source, and is_tensor flag, determine if a shape should be static. + + Args: + tensor - the real tensor to evaluate, parameters force a static shape. + is_tensor - internal dynamo check, essentially "is_tensor": target_cls is TensorVariable, + tensors not in a TensorVariable for whatever reason are forced static. + + Returns a tuple, where the first element is the bool of whether or not this tensor should have a static shape. + The second element is a TensorStaticReason, useful for passing to tensor_static_reason_to_message if needed. 
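    Illustrative behaviour (editorial sketch, under the default config where
    force_parameter_static_shapes is enabled): a plain ``nn.Parameter`` yields
    roughly ``(True, TensorStaticReason.PARAMETER)``, while an ordinary graph
    input tensor yields ``(False, None)``.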
+ """ + from .source import is_from_unspecialized_param_buffer_source + + if ( + tensor_source.guard_source().is_specialized_nn_module() + or tensor_source.guard_source().is_unspecialized_builtin_nn_module() + ) and config.force_nn_module_property_static_shapes: + return True, TensorStaticReason.NN_MODULE_PROPERTY + + if ( + type(tensor) is torch.nn.Parameter + or is_from_unspecialized_param_buffer_source(tensor_source) + ) and config.force_parameter_static_shapes: + return True, TensorStaticReason.PARAMETER + if not is_tensor: + return True, TensorStaticReason.NOT_TENSOR + return False, None + + +def lazy_format_graph_tabular(fn_name, gm): + def inner(): + try: + from tabulate import tabulate # TODO: Check that this is installed + except ImportError: + return ( + "Tabulate module missing, please install tabulate to log the graph in tabular format, logging code instead:\n" + + str(lazy_format_graph_code(fn_name, gm)) + ) + + node_specs = [ + [n.op, n.name, n.target, n.args, n.kwargs] for n in gm.graph.nodes + ] + graph_str = tabulate( + node_specs, headers=["opcode", "name", "target", "args", "kwargs"] + ) + return _format_graph_code(fn_name, gm.forward.__code__.co_filename, graph_str) + + return LazyString(inner) + + +def format_bytecode(prefix, name, filename, line_no, code): + return f"{prefix} {name} {filename} line {line_no} \n{dis.Bytecode(code).dis()}\n" + + +forward_hook_names = ["_forward_pre_hooks", "_forward_hooks"] +backward_hook_names = ["_backward_pre_hooks", "_backward_hooks"] +state_dict_hook_names = [ + "_state_dict_pre_hooks", + "_state_dict_hooks", + "_load_state_dict_pre_hooks", + "_load_state_dict_post_hooks", +] +all_hook_names = forward_hook_names + backward_hook_names + state_dict_hook_names + + +def nn_module_has_global_hooks(): + # This is limited to backward hooks for now because NNModuleVariable + # supports fwd hooks underneath. + return len(torch.nn.modules.module._global_backward_hooks) or len( + torch.nn.modules.module._global_backward_pre_hooks + ) + + +def nn_module_get_all_hooks( + mod, + check_forward_hooks=False, + check_backward_hooks=False, + check_state_dict_hooks=False, +): + """ + Sometimes its useful to differentiate between types of hooks such as forward/backward/pre + hooks executed during module.__call__, and state_dict hooks which are executed separately. + """ + hook_dicts_to_check = [] + check_all_hooks = ( + not check_forward_hooks + and not check_backward_hooks + and not check_state_dict_hooks + ) + if check_forward_hooks or check_all_hooks: + hook_dicts_to_check.extend(forward_hook_names) + if check_backward_hooks or check_all_hooks: + hook_dicts_to_check.extend(backward_hook_names) + if check_state_dict_hooks: + hook_dicts_to_check.extend(state_dict_hook_names) + + all_hooks = [] + for hook_dict_name in hook_dicts_to_check: + hooks = getattr(mod, hook_dict_name, []) + for hook_name in hooks: + hook = hooks[hook_name] + + all_hooks.append(hook) + return all_hooks + + +def nnmodule_has_hooks( + mod, + check_forward_hooks=False, + check_backward_hooks=False, + check_state_dict_hooks=False, +): + """ + Helper function to check if a module has any hooks attached to it. 
+ """ + hooks = nn_module_get_all_hooks( + mod, + check_forward_hooks=check_forward_hooks, + check_backward_hooks=check_backward_hooks, + check_state_dict_hooks=check_state_dict_hooks, + ) + return bool(hooks) + + +def to_numpy_helper(value): + """Convert tensor and tnp.ndarray to numpy.ndarray.""" + if is_fake(value): + return value + if isinstance(value, tnp.ndarray): + return to_numpy_helper(value.tensor) + elif isinstance(value, torch.Tensor): + return value.numpy(force=True) + elif isinstance(value, (tuple, list)): + return type(value)(to_numpy_helper(obj) for obj in value) + else: + return value + + +def numpy_to_tensor(value): + """Convert tnp.ndarray to tensor, leave other types intact. If a list/tuple, loop through it to convert.""" + assert np is not None + if isinstance(value, np.ndarray): + return torch.as_tensor(value) + if isinstance(value, tnp.ndarray): + return value.tensor + elif isinstance(value, (tuple, list)): + return type(value)(numpy_to_tensor(obj) for obj in value) + else: + return value + + +class numpy_to_tensor_wrapper: + def __init__(self, f): + self.f = f + self.__name__ = "wrapped_" + self.f.__name__ + + def __repr__(self): + return f">" + + def __call__(self, *args, **kwargs): + out = self.f(*args, **kwargs) + return numpy_to_tensor(out) + + +def numpy_attr_wrapper(obj, name): + if isinstance(obj, tnp.ndarray): + out = getattr(obj, name) + return numpy_to_tensor(out) + elif isinstance(obj, torch.Tensor): + out = getattr(tnp.ndarray(obj), name) + return numpy_to_tensor(out) + + +class numpy_method_wrapper: + """Convert obj from torch.Tensor to tnp.ndarray and call method. Then convert result back to torch.Tensor.""" + + def __init__(self, method: str): + self.method = method + self.__name__ = "wrapped_" + self.method + + def __repr__(self): + return f">" + + def __call__(self, *args, **kwargs): + obj = args[0] + if isinstance(obj, torch.Tensor): + obj = tnp.ndarray(obj) + method_callable = getattr(obj, self.method) + out = method_callable(*args[1:], **kwargs) + return numpy_to_tensor(out) + + +class numpy_operator_wrapper: + """Implements dunder methods for tnp.ndarray via functions from the operator library""" + + def __init__(self, op: Callable[..., Any]): + self.op = op + self.__name__ = f"wrapped_{op.__name__}" + + def __repr__(self): + return f">" + + def __call__(self, *args, **kwargs): + assert not kwargs + + args = ( + tnp.ndarray(arg) if isinstance(arg, torch.Tensor) else arg for arg in args + ) + out = self.op(*args) + return numpy_to_tensor(out) + + +def defake(x): + if not isinstance(x, FakeTensor): + return x + size: torch._prims_common.ShapeType + stride: torch._prims_common.StrideType + if x._has_symbolic_sizes_strides: + size = [] + for s in x.size(): + if isinstance(s, torch.SymInt): + size.append(s.node.shape_env.size_hint(s.node.expr)) + else: + size.append(s) + stride = [] + for s in x.stride(): + if isinstance(s, torch.SymInt): + stride.append(s.node.shape_env.size_hint(s.node.expr)) + else: + stride.append(s) + else: + size = x.size() + stride = x.stride() + y = torch.empty_strided( + size, + stride, + dtype=x.dtype, + device=x.device, + requires_grad=x.requires_grad, + ) + y.zero_() + return y + + +def is_utils_checkpoint(obj): + # Lazy import to avoid circular dependencies + import torch.utils.checkpoint + + return obj is torch.utils.checkpoint.checkpoint + + +def build_checkpoint_variable(**options): + import torch._higher_order_ops.wrap as higher_order_ops + + from .variables.higher_order_ops import TorchHigherOrderOperatorVariable + 
+ # TODO - This is a temporary situation where we have two versions of + # checkpointing implementation. We will converge on one and remove the other. + activation_checkpoint_op: torch._ops.HigherOrderOperator = ( + higher_order_ops.tag_activation_checkpoint + ) + if torch._functorch.config.functionalize_rng_ops: + activation_checkpoint_op = higher_order_ops.wrap_activation_checkpoint + + return TorchHigherOrderOperatorVariable.make( + activation_checkpoint_op, + **options, + ) + + +def is_compile_supported(device_type): + from .eval_frame import is_dynamo_supported + + compile_supported = is_dynamo_supported() + if device_type == "cpu": + pass + elif device_type == "cuda" and compile_supported: + compile_supported = has_triton() + else: + compile_supported = False + return compile_supported + + +# The following 3.11 source code functions are adapted from +# https://github.com/python/cpython/blob/v3.11.4/Lib/traceback.py +# in order to output source code corresponding to bytecode in 3.11+. +# We need our own versions since we want to support multiline expressions. +def _fix_offset(str: str, offset: int) -> int: + """ + Convert byte offset `offset` of `str` into character offset. + Byte offset is used for 3.11+ instruction column data. + Takes things like unicode characters into consideration. + + Unchanged from CPython implementation. + """ + as_utf8 = str.encode("utf-8") + return len(as_utf8[:offset].decode("utf-8", errors="replace")) + + +@dataclasses.dataclass +class _Anchors: + # inclusive + left_end_lineno: int + left_end_offset: int + right_start_lineno: int + # exclusive + right_start_offset: int + + +def _extract_anchors_from_expr(segment: str) -> Optional[_Anchors]: + """ + Given source code `segment` corresponding to a bytecode + instruction, determine: + - for binary ops, the location of the binary op + - for indexing, the location of the brackets. + `segment` is expected to be a valid Python expression + """ + assert sys.version_info >= (3, 11) + + import ast + + try: + # Without brackets, `segment` is parsed as a statement. + # We expect an expression, so wrap `segment` in + # brackets to handle multi-line expressions. + tree = ast.parse("(\n" + segment + "\n)") + except SyntaxError: + return None + + if len(tree.body) != 1: + return None + + lines = segment.split("\n") + + # get character index given byte offset + def normalize(lineno, offset): + return _fix_offset(lines[lineno], offset) + + # Gets the next valid character index in `lines`, if + # the current location is not valid. Handles empty lines. + def next_valid_char(lineno, col): + while lineno < len(lines) and col >= len(lines[lineno]): + col = 0 + lineno += 1 + assert lineno < len(lines) and col < len(lines[lineno]) + return lineno, col + + # Get the next valid character index in `lines`. + def increment(lineno, col): + col += 1 + lineno, col = next_valid_char(lineno, col) + assert lineno < len(lines) and col < len(lines[lineno]) + return lineno, col + + # Get the next valid character at least on the next line + def nextline(lineno, col): + col = 0 + lineno += 1 + lineno, col = next_valid_char(lineno, col) + assert lineno < len(lines) and col < len(lines[lineno]) + return lineno, col + + statement = tree.body[0] + if isinstance(statement, ast.Expr): + expr = statement.value + if isinstance(expr, ast.BinOp): + # ast gives locations for BinOp subexpressions, e.g. 
+ # ( left_expr ) + ( right_expr ) + # left^^^^^ right^^^^^ + # -2 since end_lineno is 1-indexed and because we added an extra + # bracket to `segment` when calling ast.parse + cur_lineno = cast(int, expr.left.end_lineno) - 2 + cur_col = normalize(cur_lineno, expr.left.end_col_offset) + cur_lineno, cur_col = next_valid_char(cur_lineno, cur_col) + + # Heuristic to find the operator character. + # The original CPython implementation did not look for ), \, or #, + # leading to incorrect anchor location, e.g. + # (x) + (y) + # ~~^~~~~~~ + while (ch := lines[cur_lineno][cur_col]).isspace() or ch in ")\\#": + if ch in "\\#": + cur_lineno, cur_col = nextline(cur_lineno, cur_col) + else: + cur_lineno, cur_col = increment(cur_lineno, cur_col) + + # binary op is 1 or 2 characters long, on the same line + right_col = cur_col + 1 + if ( + right_col < len(lines[cur_lineno]) + and not (ch := lines[cur_lineno][right_col]).isspace() + and ch not in "\\#" + ): + right_col += 1 + # right_col can be invalid since it is exclusive + + return _Anchors(cur_lineno, cur_col, cur_lineno, right_col) + elif isinstance(expr, ast.Subscript): + # ast gives locations for value and slice subexpressions, e.g. + # ( value_expr ) [ slice_expr ] + # value^^^^^ slice^^^^^ + # subscript^^^^^^^^^^^^^^^^^^^^ + # find left bracket (first '[' after value) + left_lineno = cast(int, expr.value.end_lineno) - 2 + left_col = normalize(left_lineno, expr.value.end_col_offset) + left_lineno, left_col = next_valid_char(left_lineno, left_col) + while lines[left_lineno][left_col] != "[": + left_lineno, left_col = increment(left_lineno, left_col) + # find right bracket (final character of expression) + right_lineno = cast(int, expr.end_lineno) - 2 + right_col = normalize(right_lineno, expr.end_col_offset) + return _Anchors(left_lineno, left_col, right_lineno, right_col) + elif isinstance(expr, ast.Call): + # ( func_expr ) (args, kwargs) + # func^^^^^ + # call^^^^^^^^^^^^^^^^^^^^^^^^ + # find left bracket (first '(' after func) + left_lineno = cast(int, expr.func.end_lineno) - 2 + left_col = normalize(left_lineno, expr.func.end_col_offset) + left_lineno, left_col = next_valid_char(left_lineno, left_col) + while lines[left_lineno][left_col] != "(": + left_lineno, left_col = increment(left_lineno, left_col) + # find right bracket (final character of expression) + right_lineno = cast(int, expr.end_lineno) - 2 + right_col = normalize(right_lineno, expr.end_col_offset) + return _Anchors(left_lineno, left_col, right_lineno, right_col) + + return None + + +def get_instruction_source_311(code: types.CodeType, inst: dis.Instruction) -> str: + """ + Python 3.11+ only. Returns lines of source code (from code object `code`) + corresponding to `inst`'s location data, and underlines relevant code to `inst`. + + Example: CALL on `g`: + f(g( + ^^ + h(x))) + ^^^^^ + + We need our own implementation since `format_frame_summary` in + Python's `traceback` module doesn't handle multi-line expressions + (and their anchor extraction code is not completely correct). + """ + assert inst.positions is not None + if inst.positions.lineno is None: + return "" + # The rstrip + "\n" pattern is used throughout this function to handle + # linecache.getline errors. Error lines are treated as empty strings "", but we want + # to treat them as blank lines "\n". 
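    # (Editorial note, not in the original source) The code below builds `segment`,
    # the full (possibly multi-line) source text of the instruction, plus one
    # `~`/`^` marker string per source line; the two are interleaved in the result.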
+ first_line = linecache.getline(code.co_filename, inst.positions.lineno).rstrip() + if inst.positions.end_lineno is None: + return first_line + if inst.positions.col_offset is None or inst.positions.end_col_offset is None: + return first_line + + # character index of the start of the instruction + start_offset = _fix_offset(first_line, inst.positions.col_offset) + # character index of the end of the instruction + # compute later since end may be a different line + end_offset = None + # expression corresponding to the instruction so we can get anchors + segment = "" + # underline markers to be printed - start with `~` marker and replace with `^` later + markers = [] + + # Compute segment and initial markers + if inst.positions.end_lineno == inst.positions.lineno: + end_offset = _fix_offset(first_line, inst.positions.end_col_offset) + segment = first_line[start_offset:end_offset] + markers.append(" " * start_offset + "~" * (end_offset - start_offset)) + else: + segment = first_line[start_offset:] + "\n" + markers.append(" " * start_offset + "~" * (len(first_line) - start_offset)) + last_line = linecache.getline( + code.co_filename, inst.positions.end_lineno + ).rstrip() + end_offset = _fix_offset(last_line, inst.positions.end_col_offset) + for lineno in range(inst.positions.lineno + 1, inst.positions.end_lineno): + line = linecache.getline(code.co_filename, lineno).rstrip() + segment += line + "\n" + # don't underline leading spaces + num_spaces = len(line) - len(line.lstrip()) + markers.append(" " * num_spaces + "~" * (len(line) - num_spaces)) + segment += last_line[:end_offset] + num_spaces = len(last_line) - len(last_line.lstrip()) + markers.append(" " * num_spaces + "~" * (end_offset - num_spaces)) + + anchors: Optional[_Anchors] = None + try: + anchors = _extract_anchors_from_expr(segment) + except AssertionError: + pass + + # replace `~` markers with `^` where necessary + if anchors is None: + markers = [marker.replace("~", "^") for marker in markers] + else: + # make markers mutable + mutable_markers: List[List[str]] = [list(marker) for marker in markers] + + # anchor positions do not take start_offset into account + if anchors.left_end_lineno == 0: + anchors.left_end_offset += start_offset + if anchors.right_start_lineno == 0: + anchors.right_start_offset += start_offset + + # Turn `~`` markers between anchors to `^` + for lineno in range(len(markers)): + for col in range(len(mutable_markers[lineno])): + if lineno < anchors.left_end_lineno: + continue + if lineno == anchors.left_end_lineno and col < anchors.left_end_offset: + continue + if ( + lineno == anchors.right_start_lineno + and col >= anchors.right_start_offset + ): + continue + if lineno > anchors.right_start_lineno: + continue + if mutable_markers[lineno][col] == "~": + mutable_markers[lineno][col] = "^" + + # make markers into strings again + markers = ["".join(marker) for marker in mutable_markers] + + result = "" + for i in range(len(markers)): + result += ( + linecache.getline(code.co_filename, inst.positions.lineno + i).rstrip() + + "\n" + ) + result += markers[i] + "\n" + return result + + +def get_static_address_type(t): + if isinstance(t, torch.Tensor): + return getattr(t, "_dynamo_static_input_type", None) + + return None + + +def is_rng_state_getter_or_setter(value): + getters = ( + # The following two functions are not identical, so don't remove anyone! 
+ torch._C.Generator.get_state, + torch.default_generator.get_state, + torch.get_rng_state, + torch.cuda.get_rng_state, + ) + setters = ( + torch._C.Generator.set_state, + torch.default_generator.set_state, + torch.set_rng_state, + torch.cuda.set_rng_state, + ) + return value in (*setters, *getters) + + +def is_tensor_base_attr_getter(value): + return ( + isinstance(value, types.MethodWrapperType) + and value.__name__ == "__get__" + and value.__self__.__objclass__ is torch._C._TensorBase # type: ignore[attr-defined] + ) + + +def is_torch_function_object(value): + return hasattr(value, "__torch_function__") + + +def has_torch_function(vt: torch._dynamo.variables.base.VariableTracker) -> bool: + from torch._dynamo.variables import LazyVariableTracker, UserDefinedObjectVariable + from torch._dynamo.variables.torch_function import TensorWithTFOverrideVariable + + if isinstance(vt, TensorWithTFOverrideVariable): + return True + + if isinstance(vt, LazyVariableTracker): + LazyVariableTracker.realize(vt) + + return isinstance(vt, UserDefinedObjectVariable) and hasattr( + vt.value, "__torch_function__" + ) + + +# see note [Tensor Fakification and Symbol Caching] +def to_fake_tensor(t, fake_mode): + symbolic_context = None + source = None + if tracing_context := torch._guards.TracingContext.try_get(): + if t in tracing_context.tensor_to_context: + symbolic_context = tracing_context.tensor_to_context[t] + source = symbolic_context.tensor_source + + return fake_mode.from_tensor( + t, static_shapes=False, symbolic_context=symbolic_context, source=source + ) + + +# NB: this works for both classes and instances +def is_frozen_dataclass(value): + return ( + not object_has_getattribute(value) + and not class_has_getattribute(value) + and is_dataclass(value) + and value.__dataclass_params__.frozen + ) + + +def get_first_attr(obj, *attrs): + """ + Return the first available attribute or throw an exception if none is present. + """ + for attr in attrs: + if hasattr(obj, attr): + return getattr(obj, attr) + + raise AssertionError(f"{obj} does not has any of the attributes: {attrs}") + + +@contextlib.contextmanager +def maybe_enable_compiled_autograd(should_enable, fullgraph=True, dynamic=True): + if not should_enable: + yield + else: + + def compiler_fn(gm): + def inner_compiler(gm_, example_inputs_): + torch._dynamo.utils.counters["compiled_autograd"]["compiles"] += 1 + return torch._inductor.compile(gm_, example_inputs_) + + return torch.compile( + gm, backend=inner_compiler, fullgraph=fullgraph, dynamic=dynamic + ) + + with torch._dynamo.compiled_autograd.enable(compiler_fn) as ctx: + yield ctx + + +def invalid_removeable_handle(): + # need a subclass so weakref works + class Invalid(dict): # type: ignore[type-arg] + pass + + return RemovableHandle(Invalid()) + + +# Returns a "proxy" (new object with the same class and dict) for (non-GraphModule) nn.Module's. +# Attribute changes to the original object/proxy will be reflected in the other. +# This is useful for cases where we want a keep-alive reference to a module without increasing +# its reference count. 
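# Illustrative sketch (editorial, not part of the original file):
#   proxy = nn_module_proxy(mod)
#   proxy.training = False   # visible as mod.training too, since __dict__ is shared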
+def nn_module_proxy(mod): + if not isinstance(mod, torch.nn.Module): + return mod + if isinstance(mod, torch.fx.GraphModule): + # Dynamo-generated GM's shouldn't contain user-created GM's + return mod + proxy = mod.__class__.__new__(mod.__class__) + proxy.__dict__ = mod.__dict__ + return proxy + + +class GmWrapper(torch.nn.Module): + def __init__(self, gm, unflatten_fn): + super().__init__() + self.gm = gm + self.unflatten_fn = unflatten_fn + + def forward(self, *args): + args: List[Any] = list(args) + return self.gm(*self.unflatten_fn(args)) + + +def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm): + """ + Mutate inputs so that they are flat and wrap gm such that it + accepts those inputs. This is needed for graphs that take + bumpy inputs. + """ + inputs_idx_to_clear = [ + i + for i, node in enumerate(gm.graph.nodes) + if node.op == "placeholder" and node.meta.get("steal_arg", False) + ] + + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + # fast path, avoid pytree overhead + # compiled autograd inputs are always a list of tensors, maybe followed by symints + assert inputs_idx_to_clear == [0] + assert isinstance(inputs[0], list) + boxed_inputs_count = len(inputs[0]) + + def flatten_fn(args): + return args[0] + list(args[1:]) + + def unflatten_fn(flat_args): + return (flat_args[:boxed_inputs_count], *flat_args[boxed_inputs_count:]) + + compiled_fn = compile_gm(GmWrapper(gm, unflatten_fn), flatten_fn(inputs)) + else: + # slow path, don't know inputs structure + flat_inputs, spec = pytree.tree_flatten(inputs) + unflatten_fn = functools.partial(pytree.tree_unflatten, treespec=spec) + compiled_fn = compile_gm(GmWrapper(gm, unflatten_fn), flat_inputs) + # note this doesn't check the spec, assuming it is the same + flatten_fn = pytree.arg_tree_leaves + + def wrapper(*args): + flat_args = flatten_fn(args) + + # flat_args is a new list, so we need to clear references from the old list + for i in inputs_idx_to_clear: + args[i].clear() + + # this call is boxed to avoid increasing refcount until we reach aot_module_simplified forward + return compiled_fn(flat_args) + + return wrapper + + +def get_locals_to_steal(maybe_gm): + if not isinstance(maybe_gm, torch.fx.GraphModule) or not hasattr(maybe_gm, "meta"): + return [] + return maybe_gm.meta.get("locals_to_steal", []) + + +def set_locals_to_steal(gm, locals_to_steal): + gm.meta["locals_to_steal"] = locals_to_steal + + +class Lit: + def __init__(self, s): + self.s = s + + def __repr__(self): + return self.s + + +warn_once_cache: Set[str] = set() + + +def warn_once(msg, stacklevel=1): + # Dynamo causes all warnings.warn (in user code and in Dynamo code) to print all the time. + # https://github.com/pytorch/pytorch/issues/128427. + # warn_once is a workaround: if the msg has been warned on before, then we will not + # warn again. + # NB: it's totally ok to store a cache of all the strings: this is what warnings.warn does as well. 
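    # (Editorial note) stacklevel is bumped by one below so the emitted warning
    # points at warn_once's caller rather than at this helper.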
+ if msg in warn_once_cache: + return + warn_once_cache.add(msg) + warnings.warn(msg, stacklevel=stacklevel + 1) + + +def strip_color_from_string(text): + # This regular expression matches ANSI escape codes + ansi_escape = re.compile(r"\x1B[@-_][0-?]*[ -/]*[@-~]") + return ansi_escape.sub("", text) + + +@contextlib.contextmanager +def _disable_saved_tensors_hooks_during_tracing(): + # See NOTE: [Deferring tensor pack/unpack hooks until runtime] + try: + prior = torch._C._autograd._saved_tensors_hooks_set_tracing(True) + yield + finally: + torch._C._autograd._saved_tensors_hooks_set_tracing(prior) + + +def is_parameter_freezing(): + return torch._inductor.config.freezing and not torch.is_grad_enabled() + + +def get_torch_function_mode_stack(filter_ignored=True): + from .variables.torch_function import IGNORED_MODES + + stack = [_get_function_stack_at(i) for i in range(_len_torch_function_stack())] + if filter_ignored: + stack = [mode for mode in stack if type(mode) not in IGNORED_MODES] + + return stack + + +def get_torch_function_mode_stack_at(ind): + assert ind < _len_torch_function_stack() and ind >= 0 + return torch._C._get_function_stack_at(ind) + + +def set_torch_function_mode_stack(stack): + for i in range(_len_torch_function_stack()): + _pop_torch_function_stack() + + for mode in stack: + _push_on_torch_function_stack(mode) + + +def verify_guard_fn_signature(value): + fn = value.__metadata_guard__ + sig = inspect.signature(fn) + if len(sig.parameters) != 2: + from .exc import InternalTorchDynamoError + + raise InternalTorchDynamoError( + "Tensor subclass method __metadata_guard__ must take exactly two subclass metadata arguments" + ) + if fn.__self__ != value.__class__: + from .exc import InternalTorchDynamoError + + raise InternalTorchDynamoError( + "Tensor subclass method __metadata_guard__ must be a classmethod" + ) + + +def does_not_override_dict_iter_methods(user_cls): + return ( + user_cls.items in (dict.items, collections.OrderedDict.items) + and user_cls.values in (dict.values, collections.OrderedDict.values) + and user_cls.keys in (dict.keys, collections.OrderedDict.keys) + and user_cls.__iter__ in (dict.__iter__, collections.OrderedDict.__iter__) + ) + + +# Helper function to extract relevant parts of a tensor's __dict__ to store in node meta. +# To avoid ref cycles, it's important that no tensors are present here, so leave those out. +def _extract_tensor_dict(t): + KEYS_TO_COPY = [ + "_dynamo_static_input_type", + "tag", + ] + + tensor_dict = { + key: copy.copy(t.__dict__[key]) for key in KEYS_TO_COPY if key in t.__dict__ + } + + return tensor_dict + + +# This is useful for reconstructing within the Dynamo graph the non-graph-input objects +# whose lifetime is governed by the user. +# e.g. torch.cuda.Event is a prime example. 
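# Illustrative usage (editorial sketch, not in the original source):
#   store_user_object_weakref(event)             # recorded while tracing
#   evt = get_user_object_from_id(id(event))     # looked up when the compiled code runs
# The weakref means Dynamo never extends the lifetime of the user's object.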
+user_obj_id_to_weakref: Dict[int, weakref.ReferenceType[object]] = {} + + +def get_user_object_from_id(obj_id): + obj = user_obj_id_to_weakref[obj_id]() + assert obj is not None, "User object is no longer alive" + return obj + + +def store_user_object_weakref(obj): + obj_id = id(obj) + user_obj_id_to_weakref[obj_id] = weakref.ref(obj) + + +class CompileTimeInstructionCounter: + _counter: int = 0 + _id: int = -1 + _depth = 0 + + @classmethod + def start(cls) -> None: + cls._depth = cls._depth + 1 + if cls._depth == 1: + cls._id = _instruction_counter.start() + + @classmethod + def end(cls) -> None: + cls._depth = cls._depth - 1 + if cls._depth == 0: + cls._counter += _instruction_counter.end(cls._id) + cls._id = -1 + + @classmethod + def clear(cls) -> None: + cls._counter = 0 + + @classmethod + def value(cls) -> int: + return cls._counter + + @classmethod + @contextmanager + def record(cls): + try: + if config.record_compile_time_instruction_count: + cls.start() + yield + finally: + if config.record_compile_time_instruction_count: + cls.end() diff --git a/lib/python3.10/site-packages/torch/_export/__init__.py b/lib/python3.10/site-packages/torch/_export/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..91d893e05cb89c65a9bf2a07b5ed973d20f9a2c6 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_export/__init__.py @@ -0,0 +1,317 @@ +# mypy: allow-untyped-defs +import copy +import dataclasses +import functools +import io +import json +import logging +import os +import re +import sys +import types +import warnings +import weakref +import zipfile +from collections import OrderedDict +from contextlib import contextmanager +from functools import lru_cache + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from unittest.mock import patch + +import torch +import torch.fx +import torch.utils._pytree as pytree + +from torch._dispatch.python import enable_python_dispatcher +from torch._utils_internal import log_export_usage +from torch.export._tree_utils import reorder_kwargs +from torch.export.graph_signature import ( + ArgumentSpec, + ConstantArgument, + ExportGraphSignature, + InputKind, + InputSpec, + OutputKind, + OutputSpec, + SymIntArgument, + TensorArgument, +) +from torch.fx import traceback as fx_traceback +from torch.fx._compatibility import compatibility +from torch.fx.experimental.proxy_tensor import make_fx +from torch._subclasses.fake_tensor import unset_fake_temporarily +from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo + +from .wrappers import _wrap_submodules + +log = logging.getLogger(__name__) + +@dataclasses.dataclass +class ExportDynamoConfig: + """ + Manage Export-specific configurations of Dynamo. + """ + allow_rnn: bool = True + + +# We only want to print this once to avoid flooding logs in workflows where capture_pre_autograd_graph +# is called multiple times. +@lru_cache +def capture_pre_autograd_graph_warning(): + from torch._inductor import config + + log.warning("+============================+") + log.warning("| !!! WARNING !!! 
|") + log.warning("+============================+") + log.warning("capture_pre_autograd_graph() is deprecated and doesn't provide any function guarantee moving forward.") + log.warning("Please switch to use torch.export.export_for_training instead.") + if config.is_fbcode(): + log.warning("Unless the unittest is in the blocklist, capture_pre_autograd_graph() will fallback to torch.export.export_for_training.") # noqa: B950 + + +@compatibility(is_backward_compatible=False) +def capture_pre_autograd_graph( + f: torch.nn.Module, + args: Tuple[Any], + kwargs: Optional[Dict[str, Any]] = None, + dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, +) -> torch.nn.Module: + """ + A helper function that is intended to trace a module before any pre-autograd + decomposition is run. The produced module will be "non-functional" and + composed of aten operators. Later this API will be deleted in favor of more general + torch.export API. + + Args: + f: nn.Module to be traced + + args: example positional inputs. + + kwargs: optional example keyword inputs. + + dynamic_shapes: Should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. + + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + Returns: + An nn.Module containing the traced method. + + """ + from torch.export._trace import _extract_fake_inputs, DEFAULT_EXPORT_DYNAMO_CONFIG, _ignore_backend_decomps + from torch._utils_internal import capture_pre_autograd_graph_using_training_ir + from torch._export.non_strict_utils import make_constraints + from torch._subclasses.functional_tensor import FunctionalTensor + from torch.export._unlift import _create_stateful_graph_module + from torch.export.dynamic_shapes import _combine_args + + capture_pre_autograd_graph_warning() + + if sys.platform == "win32": + raise RuntimeError("capture_pre_autograd_graph not yet supported on Windows") + + assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance." + + if kwargs is None: + kwargs = {} + + if capture_pre_autograd_graph_using_training_ir(): + @lru_cache + def print_export_warning(): + log.warning("Using torch.export.export_for_training(...,strict=True)") + print_export_warning() + module = torch.export.export_for_training(f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=True).module() + else: + log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"}) + + # Do not decompose dropout for exported models, because in eval mode the dropout + # op disappears from the graph, which makes it difficult to switch to train mode. + # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832. 
+ decomp_table = { + op: op.decompose + for op in FunctionalTensor.maybe_aliasing_or_mutating_ops + if op != torch.ops.aten.dropout.default + } + with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps(): + m = torch._dynamo.export( + f, + dynamic_shapes=dynamic_shapes, + assume_static_by_default=True, + tracing_mode="symbolic", + decomposition_table=decomp_table, + pre_dispatch=True, + aten_graph=True, + _log_export_usage=False, + )( + *args, + **kwargs, + )[0] + + _, _, fake_mode = _extract_fake_inputs(m, args, kwargs) + + m.meta["inline_constraints"] = { + k: v + for k, v in fake_mode.shape_env.var_to_range.items() + if re.match(r"^[if]\d+$", str(k)) + } + + if isinstance(f, torch.nn.Module): + from torch.export._trace import _restore_state_dict + _restore_state_dict(f, m) + + flat_args, _ = pytree.tree_flatten((args, kwargs or {})) + combined_args = _combine_args(f, args, kwargs) + range_constraints = make_constraints( + fake_mode, + m, + combined_args, + dynamic_shapes, + 0, + ) + + module = _create_stateful_graph_module( + m, + range_constraints=range_constraints, + ) + + error_message = \ + """ + Calling train() or eval() is not supported for exported models. + Alternatively, you may override these methods to do custom user behavior as follows: + + def _my_train(self, mode: bool = True): + ... + + def _my_eval(self): + ... + + model.train = types.MethodType(_my_train, model) + model.eval = types.MethodType(_my_eval, model) + """ + + def _train(self, mode: bool = True): + raise NotImplementedError(error_message) + + def _eval(self, mode: bool = True): + raise NotImplementedError(error_message) + + module.train = types.MethodType(_train, module) # type: ignore[method-assign] + module.eval = types.MethodType(_eval, module) # type: ignore[method-assign] + + # Remove Proxy because they cannot be deepcopied or pickled. + if hasattr(module, "_buffers"): + torch._export.utils.remove_proxy_from_state_dict( + module._buffers, in_place=True + ) + return module + + +def aot_compile( + f: Callable, + args: Tuple[Any], + kwargs: Optional[Dict[str, Any]] = None, + *, + dynamic_shapes: Optional[Dict[str, Any]] = None, + options: Optional[Dict[str, Any]] = None, + remove_runtime_assertions: bool = False, + disable_constraint_solver: bool = False, + same_signature: bool = True, +) -> str: + """ + Note: this function is not stable yet + + Traces either an nn.Module's forward function or just a callable with PyTorch + operations inside, generates executable cpp code from the program, and returns + the path to the generated shared library + + Args: + f: the `nn.Module` or callable to trace. + + args: example positional inputs. + + kwargs: optional example keyword inputs. + + dynamic_shapes: Should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. + + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. 
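        Editorial example (illustrative only): for ``f(x)`` where the first
        dimension of ``x`` should be dynamic, one could pass
        ``dynamic_shapes={"x": {0: torch.export.Dim("batch")}}``.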
Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + options: A dictionary of options to control inductor + + disable_constraint_solver: Whether the dim constraint solver must be disabled. + + Returns: + Path to the generated shared library + """ + from torch.export._trace import _export_to_torch_ir + from torch._inductor.decomposition import select_decomp_table + from torch._inductor import config + + if config.is_predispatch: + gm = torch.export._trace._export(f, args, kwargs, dynamic_shapes, pre_dispatch=True).module() + else: + # We want to export to Torch IR here to utilize the pre_grad passes in + # inductor, which run on Torch IR. + gm = _export_to_torch_ir( + f, + args, + kwargs, + dynamic_shapes, + disable_constraint_solver=disable_constraint_solver, + same_signature=same_signature, + # Disabling this flag, because instead we can rely on the mapping + # dynamo_flat_name_to_original_fqn which is coming from Dynamo. + restore_fqn=False, + ) + + with torch.no_grad(): + so_path = torch._inductor.aot_compile(gm, args, kwargs, options=options) # type: ignore[arg-type] + + return so_path + +def aot_load(so_path: str, device: str) -> Callable: + """ + Loads a shared library generated by aot_compile and returns a callable + + Args: + so_path: Path to the shared library + + Returns: + A callable + """ + if device == "cpu": + runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) # type: ignore[call-arg] + elif device == "cuda" or device.startswith("cuda:"): + runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) # type: ignore[assignment, call-arg] + else: + raise RuntimeError("Unsupported device " + device) + + def optimized(*args, **kwargs): + call_spec = runner.get_call_spec() # type: ignore[attr-defined] + in_spec = pytree.treespec_loads(call_spec[0]) + out_spec = pytree.treespec_loads(call_spec[1]) + flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0] + flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)] + flat_outputs = runner.run(flat_inputs) # type: ignore[attr-defined] + return pytree.tree_unflatten(flat_outputs, out_spec) + + return optimized diff --git a/lib/python3.10/site-packages/torch/_export/converter.py b/lib/python3.10/site-packages/torch/_export/converter.py new file mode 100644 index 0000000000000000000000000000000000000000..b45d7849b29ae04ff1e77a812b0ccf86a90a4b0d --- /dev/null +++ b/lib/python3.10/site-packages/torch/_export/converter.py @@ -0,0 +1,1584 @@ +# mypy: allow-untyped-defs +import builtins +import logging +import operator +import typing +import warnings +from contextlib import contextmanager +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union + +import torch +import torch.export._trace +from torch import _C +from torch._export.passes.replace_quantized_ops_with_standard_ops_pass import ( + replace_quantized_ops_with_standard_ops, +) +from torch.export.exported_program import ExportedProgram +from torch.export.graph_signature import ( + ConstantArgument, + CustomObjArgument, + InputKind, + InputSpec, + OutputKind, + OutputSpec, + TensorArgument, +) +from torch.fx import subgraph_rewriter + + +log = logging.getLogger(__name__) + + +def _get_param_count_list(method_graph, args_params): + param_count_list = [] + for input_, arg_params_ in zip(method_graph.inputs(), args_params): + if "PackedParams" in str(input_.type()): + in_vars, _ = torch.jit._flatten(arg_params_) + 
param_count_list.append(len(in_vars)) + else: + param_count_list.append(arg_params_ is not None) + + return param_count_list + + +def _trace_and_get_graph_from_model(model, args): + # A basic sanity check: make sure the state_dict keys are the same + # before and after running the model. Fail fast! + orig_state_dict_keys = torch.jit._unique_state_dict(model).keys() + + # Disable Autocast cache because it replaces kernel's weight and bias + # by (undesired) constants. + # No perf impact for when there are reused weights since https://github.com/pytorch/pytorch/pull/85665 + prev_autocast_cache_enabled = torch.is_autocast_cache_enabled() + torch.set_autocast_cache_enabled(False) + trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph( + model, + args, + strict=False, + _force_outplace=False, + _return_inputs_states=True, + ) + torch.set_autocast_cache_enabled(prev_autocast_cache_enabled) + + if orig_state_dict_keys != torch.jit._unique_state_dict(model).keys(): + raise RuntimeError( + "state_dict changed after running the tracer; " + "something weird is happening in your model!" + ) + + return trace_graph, torch_out + + +def _create_jit_graph( + model: Union[torch.nn.Module, torch.jit.ScriptFunction], args: Sequence[Any] +) -> Tuple[torch.Graph, List["_C.IValue"], Any, Optional[torch.ScriptModule]]: + if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)): + flattened_args = tuple(torch.jit._flatten(tuple(args))[0]) + torch_out = None + + if isinstance(model, torch.jit.ScriptModule): + try: + graph = model.forward.graph # type: ignore[attr-defined] + except AttributeError as e: + raise RuntimeError("'forward' method must be a script method") from e + _C._jit_pass_onnx_function_substitution(graph) + freezed_module = _C._freeze_module( + typing.cast(_C.ScriptModule, model._c), preserveParameters=True + ) + module, params = _C._jit_onnx_list_model_parameters(freezed_module) + method_graph = module._get_method("forward").graph + args_params = tuple(args) + tuple(params) + param_count_list = _get_param_count_list(method_graph, args_params) + in_vars, _ = torch.jit._flatten(args_params) + graph = _C._propagate_and_assign_input_shapes( + method_graph, tuple(in_vars), param_count_list, False, False + ) + return graph, params, torch_out, module + + # torch.jit.ScriptFunction + params = [] + graph = model.graph + _C._jit_pass_onnx_function_substitution(graph) + param_count_list = _get_param_count_list(graph, args) + graph = _C._propagate_and_assign_input_shapes( + graph, flattened_args, param_count_list, False, False + ) + return graph, params, torch_out, None + + graph, torch_out = _trace_and_get_graph_from_model(model, args) + _C._jit_pass_onnx_lint(graph) + state_dict = torch.jit._unique_state_dict(model) + params = list(state_dict.values()) + graph_inputs = list(graph.inputs()) + user_input_num = len(graph_inputs) - len(state_dict) + param_names = list(state_dict.keys()) + for i, inp in enumerate(graph_inputs): + if i >= user_input_num: + inp.setDebugName(param_names[i - user_input_num]) + _C._jit_pass_onnx_function_substitution(graph) + return graph, params, torch_out, None + + +def list_add(a, b): + return a + b + + +def list_append(container, element): + return container + [element] + + +def execute_subgraph_from_prim_loop( + subgraph, iter_idx, len_loop_local_arguments, *args, **kwargs +): + """ + subgraph: GraphModule from sub-block. + iter_idx: The index of interation. + len_loop_local_arguments: The number of loop local arguments in args. 
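    Editorial sketch of the calling convention (not in the original docstring):
        execute_subgraph_from_prim_loop(sub_gm, i, n_locals, *loop_locals, *globals)
    forwards to
        sub_gm(*globals, i, *loop_locals)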
+ """ + + # Loop local variables. TS graph create those as inputs because their values + # are updated inside the loop. + loop_local_args = args[:len_loop_local_arguments] + # Global variables that are not passed in as inputs to the loop sub-blocks + # but are directly used. Most of time, their values are not updated, but + # the only exception is when there are some operations that perform inplace + # updates. + global_args = args[len_loop_local_arguments:] + return subgraph(*global_args, iter_idx, *loop_local_args, **kwargs) + + +def inplace_optimize_sym_size_div(gm: torch.fx.GraphModule): + def pattern(im, dim, scale): + sym_size_int = torch.ops.aten.sym_size.int(im, dim) + scalar_tensor = torch.ops.aten.scalar_tensor(sym_size_int) + div_scalar_mode = torch.ops.aten.div.Scalar_mode( + scalar_tensor, scale, rounding_mode="trunc" + ) + int_tensor = torch.ops.aten.Int.Tensor(div_scalar_mode) + return int_tensor + + def replacement(im, dim, scale): + sym_size_int = torch.ops.aten.sym_size.int(im, dim) + return sym_size_int // scale + + replaced_patterns = subgraph_rewriter.replace_pattern(gm, pattern, replacement) + + +def is_valid_for_codegen(name): + if len(name) == 0: + raise RuntimeError("Empty argument name for codegen") + if name[0].isdigit(): + return False + return True + + +def normalize_name(name: str, prefix: str = "rename") -> str: + name = name.replace(".", "_") + if is_valid_for_codegen(name): + return name + return f"{prefix}_{name}" + + +def ir_name_to_func_name(name: str) -> str: + """prim::If -> convert_prim_If""" + name_list = name.split("::") + return "convert_" + "_".join(name_list) + + +def get_node_as_placeholder_or_get_attr(fx_graph, name, is_top_level_graph): + if is_top_level_graph: + return fx_graph.get_attr(name) + return fx_graph.placeholder(name) + + +_TORCH_DTYPE_TO_ENUM = { + torch.uint8: 0, + torch.int8: 1, + torch.int16: 2, + torch.int32: 3, + torch.int64: 4, + torch.float16: 5, + torch.float32: 6, + torch.float64: 7, + torch.complex32: 8, + torch.complex64: 9, + torch.complex128: 10, + torch.bool: 11, + torch.qint8: 12, + torch.quint8: 13, + torch.bfloat16: 15, +} + +_TORCH_ENUM_TO_DTYPE = {value: key for key, value in _TORCH_DTYPE_TO_ENUM.items()} + + +def get_dtype_as_int(tensor): + """ + prim::dtype has the signature "Tensor a) -> int", where it gets the dtype of + the tensor and returns the integer corresponding to this dtype based on the + enum in ScalarType.h + """ + dtype = tensor.dtype + if dtype not in _TORCH_DTYPE_TO_ENUM: + raise RuntimeError(f"Unsupported dtype {dtype}") + return _TORCH_DTYPE_TO_ENUM[dtype] + + +# Those operators will be automatically populated to a instance method +# of TS2FXGraphConverter with name convert__(). +# Please check __init__ for method population implementations. +kind_to_standard_operators = { + "prim::max": builtins.max, + "prim::min": builtins.min, + "prim::TupleIndex": operator.getitem, + "aten::__is__": operator.is_, + "aten::__isnot__": operator.is_not, + "aten::__not__": operator.not_, + "aten::__contains__": operator.contains, + "prim::dtype": get_dtype_as_int, + "aten::len": len, + # Mapping from specialized op to its symbolic counterpart. + # They currently do not have any other overrides. 
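+    # For example, a size query recorded as aten::size in the TS graph is converted
+    # to torch.ops.aten.sym_size, so the result stays symbolic under dynamic shapes
+    # instead of being burned in as a constant.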
+ "aten::numel": torch.ops.aten.sym_numel, + "aten::size": torch.ops.aten.sym_size, + "aten::storage_offset": torch.ops.aten.sym_storage_offset, + "aten::stride": torch.ops.aten.sym_stride, +} + + +def get_ir_value_parent_name_and_attr_name(node): + irv_parent_name, irv_name = node.input().debugName(), node.output().debugName() + attr_name = node.s("name") + return irv_name, irv_parent_name, attr_name + + +def construct_fqn(ir, ref_map, name_map): + name_list = [] + while ir in ref_map: + name_list.append(name_map[ir]) + ir = ref_map[ir] + return ".".join(reversed(name_list)) + + +def get_block_to_lifted_attrs(graph: torch._C.Graph) -> Dict[torch._C.Block, Set[str]]: + """ + Perform two passes to get a mapping of blocks to a set of FQNs of its lifted attributes. + When a graph has control flow, the graph will be divided into multiple blocks. We want to convert + each block to a graph which will be passed into torch.cond. A restriction for torch.cond is that model + parameters/buffers are expected to be lifted as inputs to the subgraphs. Before converting the model, + we will run this pass which will: + 1. Figure out which params/buffers are used within blocks through tracing the GetAttr calls. + 2. Process the graph bottom up to find the lifted attributes of each block by taking the union + of the attributes used in the current block, and the lifted attributes of all its child blocks. + + Returns: + A mapping of blocks to a set of FQNs of its lifted attributes. + """ + + # A map from a block to its expected to be lifted arguments. + blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]] = {} + + # Reference map stores the input (i.e., src) and output (i.e., dest) IR of a + # GetAttr node. By traversing this reference map, we can figure out the + # full IR aliasing pass and figure out the FQN of an attribute. + # E.g., %2 = GetAttr(linear)[%1] --> node_to_parent_map["%2"] = "%1" + node_to_parent_map: Dict[str, str] = {} + + # Used for reconstructing the FQN of an attribute based on the reference map. + # In nutshell, for each GetAttr call, GetAttr(input IR, attribute name) -> output IR + # This name map stores which attribute name is called for a src IR --> dest IR action. + # E.g., %2 = GetAttr(linear)[%1] --> node_to_attr_name["%2"] = "linear" + node_to_attr_name: Dict[str, str] = {} + + def _dfs_get_attr_dependency(entry): + """ + First DFS path to construct reference map and name map. + """ + for node in entry.nodes(): + if node.kind() == "prim::GetAttr": + ( + irv_name, + irv_parent_name, + attr_name, + ) = get_ir_value_parent_name_and_attr_name(node) + node_to_parent_map[irv_name] = irv_parent_name + node_to_attr_name[irv_name] = attr_name + for block in node.blocks(): + _dfs_get_attr_dependency(block) + + def _map_blocks_to_lifted_attrs(entry): + """ + Walk the graph in a bottom-up fashion to build the expected to be + lifted arguments for each block. + """ + arguments: Set[str] = set() + for node in entry.nodes(): + for block in node.blocks(): + # Recursively build. + arguments = arguments.union(_map_blocks_to_lifted_attrs(block)) + if node.kind() == "prim::GetAttr": + irv_name = node.output().debugName() + # Skip for intermediate GetAttr, which will anyway not result a FQN. 
+ # E.g., node_to_parent_name: {"%3": "%2", "%2": "%1"} + # node_to_attr_name: {"%3": "weight", "%2": "linear", "%1": "self"} + # There is only one FQN %3-->%2-->%1: self.linear.weight + # %2-->%1 is not a FQN: self.linear + if irv_name not in set(node_to_parent_map.values()): + arguments.add( + construct_fqn(irv_name, node_to_parent_map, node_to_attr_name) + ) + if not isinstance(entry, torch._C.Graph): # Skip the top level. + blocks_to_lifted_attrs[entry] = arguments + return arguments + + _dfs_get_attr_dependency(graph) + _map_blocks_to_lifted_attrs(graph) + + return blocks_to_lifted_attrs + + +def get_attribute_fqn_from_ts_node( + name_to_attribute_fqn: Dict[str, str], node: torch._C.Node +) -> str: + def get_attr(name: str): + if name in name_to_attribute_fqn: + return name_to_attribute_fqn[name] + else: + raise ValueError(f"Attribute {name} not found") + + if node.kind() == "prim::SetAttr": + input_name = next(node.inputs()).debugName() + elif node.kind() == "prim::GetAttr": + input_name = node.input().debugName() + else: + raise RuntimeError( + f"Unexpected node kind when getting attribute fqn. node: {node} " + ) + + attr_name = node.s("name") + root_attr_name = get_attr(input_name) + attr_fqn = f"{root_attr_name}.{attr_name}" if root_attr_name else attr_name + + return attr_fqn + + +def get_op_overload(node: torch._C.Node): + schema_str = node.schema() + assert schema_str != "(no schema)", f"got empty schema for {node}" + schema: torch._C.FunctionSchema = torch._C.parse_schema(schema_str) + ns, op_name = str(schema.name).split("::") + override = schema.overload_name + + try: + op_overload_mod = getattr(torch.ops, ns) + op_overload_packet = getattr(op_overload_mod, op_name) + if override: + op_overload = getattr(op_overload_packet, override) + else: + op_overload = op_overload_packet.default + except Exception as e: + raise RuntimeError( + f"Unable to find operator {node.kind()} with schema {node.schema()}" + ) from e + + return op_overload + + +class TS2FXGraphConverter: + def __init__( + self, + ts_graph: Union[torch._C.Graph, torch._C.Block], + name_to_param: Dict[str, torch.Tensor], + name_to_buffer: Dict[str, torch.Tensor], + blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]], + name_to_non_tensor_attribute: Dict[str, Any], + name_to_constant: Dict[str, Any], + ): + self.ts_graph = ts_graph + self.name_to_param = name_to_param + self.name_to_buffer = name_to_buffer + + self.fx_graph: torch.fx.Graph = torch.fx.Graph() + self.input_specs: List[InputSpec] = [] + self.output_specs: List[OutputSpec] = [] + + self.name_to_node: Dict[ + str, Union[torch.fx.Node, List[torch.fx.Node], Dict[Any, torch.fx.Node]] + ] = {} + self.name_to_constant: Dict[str, Any] = name_to_constant + + # Mapping from torchscript node output name to attribute fully qualified name + self.name_to_attribute_fqn: Dict[str, str] = {} + + # Mapping from fully qualified name to real values or a fx graph node + # During convert, this represents the current value of a non-tensor attribute + # One use case is: + # def forward(self, x): + # c1 = self.count + # self.count += 1 + # c2 = self.count + # return x + c1 + c2 + self.name_to_non_tensor_attribute_node: Dict[str, Any] = {} + + # Mapping from fully qualified name to initial real values inputs + # We separate it from self.name_to_non_tensor_attribute_node since + # we need initial real value input when we construct fx.GraphModule + self.name_to_non_tensor_attribute: Dict[str, Any] = name_to_non_tensor_attribute + + self.subgraphs: Dict[str, torch.fx.GraphModule] 
= {} + + self.blocks_to_lifted_attrs = blocks_to_lifted_attrs + + # Populate methods for the standard operators. + for k in kind_to_standard_operators.keys(): + handler_func_name = ir_name_to_func_name(k) + # Create an indirect function call: + # convert__ --> lambda node: _convert_standard_operator(node) + setattr( + self, + handler_func_name, + lambda node: self._convert_standard_operators(node), + ) + + # This stores a list of return results that do not appear in the original TS + # graph's outputs. The reason we maintain this is because some operations in the sub-block + # might have inplace updates to the variable defined in the parent fx graph. After + # the execution of that sub-block, the variable defined in the parent fx graph also + # needs to be updated. + self.name_update_from_subblock_to_parent: Set[str] = set() + + def _is_get_attr_node(self, fqn): + return ( + fqn in self.name_to_buffer + or fqn in self.name_to_param + or ( + fqn in self.name_to_constant + and isinstance(self.name_to_constant[fqn], torch.ScriptObject) + ) + ) + + def _convert_block_to_subgraph(self, node: torch._C.Node, arguments: List[str]): + subgraph_nodes, subgraph_converters = [], [] + for block in node.blocks(): + subgraph_converter = TS2FXGraphConverter( + block, + self.name_to_param, + self.name_to_buffer, + self.blocks_to_lifted_attrs, + {}, + self.name_to_constant, + ) + subgraph_converter.name_to_attribute_fqn = self.name_to_attribute_fqn + + for block_arg in arguments: + normalized_block_arg_name = normalize_name(block_arg) + placeholder_node = subgraph_converter.fx_graph.placeholder( + normalized_block_arg_name + ) + subgraph_converter.name_to_node[block_arg] = placeholder_node + + subgraph = subgraph_converter.convert() + subgraph_name = self.add_subgraph(subgraph) + subgraph_nodes.append(self.fx_graph.get_attr(subgraph_name)) + subgraph_converters.append(subgraph_converter) + return subgraph_nodes, subgraph_converters + + def _identify_inputs_as_arguments(self, entry): + """ + Identify inputs from the innermost sub-block. This is needed + for nested sub-blocks when the input is hidden in the nested sub-block. + E.g., example IR of input is hidden in the nested sub-block. + Graph[x.1] + %1 = ... + Block[] + Block[x.1] + %2 = x.1 ... 
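+        Here x.1 is defined outside the nested block but only consumed inside it,
+        so it must be collected and passed to the generated subgraph explicitly.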
+ """ + arguments: Set[str] = set() + for block in entry.blocks(): + for block_node in block.nodes(): + for block_node_in in block_node.inputs(): + if ( + block_node_in.debugName() in self.name_to_node + and block_node_in.debugName() not in self.name_to_attribute_fqn + ): + arguments.add(block_node_in.debugName()) + arguments = arguments.union( + self._identify_inputs_as_arguments(block_node) + ) + return arguments + + def is_top_level_graph(self): + return isinstance(self.ts_graph, torch._C.Graph) + + def add_subgraph(self, subgraph) -> str: + name = f"subgraph_{len(self.subgraphs)}" + self.subgraphs[name] = subgraph + return name + + def get_args_kwargs(self, node: torch._C.Node, schema): + args = [] + kwargs = {} + for input, schema_arg in zip(node.inputs(), schema.arguments): + if schema_arg.kwarg_only: + kwargs[schema_arg.name] = self.get_fx_value_by_ir_value(input) + else: + args.append(self.get_fx_value_by_ir_value(input)) + + return tuple(args), kwargs + + def get_fx_value_by_ir_value(self, value: torch._C.Value): + value_name = value.debugName() + + if value_name in self.name_to_node: + input_node = self.name_to_node[value_name] + return input_node + elif value_name in self.name_to_constant: + if isinstance(self.name_to_constant[value_name], torch.ScriptObject): + return self.fx_graph.get_attr(value_name) + return self.name_to_constant[value_name] + else: + raise ValueError(f"Input {value_name} not found") + + def get_fx_value_by_fqn(self, name): + if name in self.name_to_node: + fx_node = self.name_to_node[name] + elif name in self.name_to_constant: + fx_node = self.name_to_constant[name] + elif name in self.name_to_non_tensor_attribute_node: + fx_node = self.name_to_non_tensor_attribute_node[name] + elif name in self.name_to_non_tensor_attribute: + fx_node = self.name_to_non_tensor_attribute[name] + else: + raise ValueError(f"Attribute {name} not found") + return fx_node + + def convert(self) -> torch.fx.GraphModule: + self.convert_graph_inputs() + + for node in self.ts_graph.nodes(): + self.convert_node(node) + + self.convert_graph_outputs() + + # Pass parameter and buffer to the root for lookup. 
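+        # torch.fx.GraphModule accepts a plain dict as the root module, so every
+        # get_attr target emitted during conversion (subgraphs, params, buffers,
+        # lifted constants) is resolvable by its fully qualified name.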
+ gm = torch.fx.GraphModule( + { + **self.subgraphs, + **self.name_to_param, + **self.name_to_buffer, + **self.name_to_non_tensor_attribute, + **self.name_to_constant, + }, + self.fx_graph, + ) + + inplace_optimize_sym_size_div(gm) + + gm.graph.lint() + + return gm + + def convert_graph_inputs(self): + for graph_input in self.ts_graph.inputs(): + name = graph_input.debugName() + + if name in self.name_to_param: + normalized_name = normalize_name(name) + self.input_specs.append( + InputSpec( + InputKind.PARAMETER, + arg=TensorArgument(name=normalized_name), + target=name, + ) + ) + fx_node = get_node_as_placeholder_or_get_attr( + self.fx_graph, name, self.is_top_level_graph() + ) + elif name in self.name_to_buffer: + normalized_name = normalize_name(name) + self.input_specs.append( + InputSpec( + InputKind.BUFFER, + arg=TensorArgument(name=normalized_name), + target=name, + persistent=True, + ) + ) + fx_node = get_node_as_placeholder_or_get_attr( + self.fx_graph, name, self.is_top_level_graph() + ) + elif name in self.name_to_constant: + assert isinstance( + self.name_to_constant[name], torch.ScriptObject + ), "Input conversion only handles ScriptObject" + normalized_name = normalize_name(name) + self.input_specs.append( + InputSpec( + InputKind.CUSTOM_OBJ, + arg=CustomObjArgument( + name=normalized_name, class_fqn=normalized_name + ), + target=name, + persistent=False, + ) + ) + fx_node = get_node_as_placeholder_or_get_attr( + self.fx_graph, name, self.is_top_level_graph() + ) + elif isinstance(graph_input.type(), torch.ClassType): + # Directly skip inputs that are ScriptObject but not used in the graph. + continue + else: + normalized_name = normalize_name(name, prefix="input") + self.input_specs.append( + InputSpec( + InputKind.USER_INPUT, + arg=TensorArgument(name=normalized_name), + target=name, + ) + ) + fx_node = self.fx_graph.placeholder(normalized_name) + + self.name_to_node[name] = fx_node + + def convert_aten_Float(self, node: torch._C.Node): + def to_float_tensor(t): + return t.to(dtype=torch.float).item() + + inp_list = [ + self.get_fx_value_by_ir_value(inp) for inp in node.inputs() + ] # noqa: C416 + fx_node = self.fx_graph.call_function( + to_float_tensor, + tuple(inp_list), + ) + self.name_to_node[node.output().debugName()] = fx_node + + def convert_aten_tensor(self, node: torch._C.Node): + """aten::tensor creates a constant tensor ad-hoc --> GetAttr""" + args, kwargs = self.get_args_kwargs(node, torch.ops.aten.tensor.default._schema) + + for k in kwargs: + if k == "requires_grad": + kwargs[k] = bool(kwargs[k]) # 0 -> False, 1 -> True + + to_tensor = ( + torch.tensor + if all(isinstance(a, int) for a in args) + else torch._refs.tensor + ) + + def target(*args, **kwargs): + if "dtype" in kwargs and kwargs["dtype"] is not None: + kwargs["dtype"] = _TORCH_ENUM_TO_DTYPE[kwargs["dtype"]] + return to_tensor(*args, **kwargs) + + # def to_dynamic_tensor(*args, **kwargs): + # if "dtype" in kwargs and kwargs["dtype"] is not None: + # kwargs["dtype"] = _TORCH_ENUM_TO_DTYPE[kwargs["dtype"]] + # return torch._refs.tensor(*args, **kwargs) + + output_name = node.output().debugName() + fx_node = self.fx_graph.call_function(target, args, kwargs) + self.name_to_node[output_name] = fx_node + + def convert_aten_append(self, node: torch._C.Node): + # special handle python list append: "aten::append.t(t[](a!) self, t(c -> *) el) -> t[](a!)" + + # inplace append to the list!! 
This is kinda crazy, as we are inplace mutating the list + # This makes the converter "non-functional", and the result depends on the order of the nodes being converter + # In a sense, the converter now becomes an stateful interpreter + warnings.warn( + "Converting aten::append.t, which is a inplace mutation of the list. " + "This makes the converter non-functional: the result depends on the order of the append nodes being converter!" + ) + + args = tuple(self.get_fx_value_by_ir_value(inp) for inp in node.inputs()) + fx_node = self.fx_graph.call_function(list_append, args) + self.name_to_node[node.output().debugName()] = fx_node + + # inplace mutate arg[0], which is the python list + self.name_to_node[node.inputsAt(0).debugName()] = fx_node + + # Variables that need to be updated to parent module. + if not self.is_top_level_graph() and args[0].op == "placeholder": + self.name_update_from_subblock_to_parent.add(node.inputsAt(0).debugName()) + + def convert_prim_Constant(self, node: torch._C.Node): + name = node.output().debugName() + + value: Any = None + if node.hasAttribute("value"): + constant_kind = node.kindOf("value") + if constant_kind == "i": + value = node.i("value") + elif constant_kind == "f": + value = node.f("value") + elif constant_kind == "s": + value = node.s("value") + elif constant_kind == "t": + alias_name = ( + f"lifted_tensor_{name}" # Follow naming convention from EP tracing. + ) + fx_node = self.fx_graph.get_attr(alias_name) + self.name_to_node[name] = fx_node + name, value = alias_name, node.t("value") + elif constant_kind == "ival": + value = node.ival("value") + else: + raise ValueError(f"Unsupported constant type: {node.kindOf('value')}") + else: + value = None + + self.name_to_constant[name] = value + + def convert_prim_CallMethod(self, node: torch._C.Node): + inp_list = [ + self.get_fx_value_by_ir_value(inp) for inp in node.inputs() + ] # noqa: C416 + fx_node = self.fx_graph.call_method( + node.s("name"), + tuple(inp_list), + ) + self.name_to_node[node.output().debugName()] = fx_node + + def convert_prim_device(self, node: torch._C.Node): + input_type = node.input().type() + if input_type.isSubtypeOf(torch._C.TensorType.get()): + device = input_type.device() # type: ignore[attr-defined] + output_name = node.output().debugName() + self.name_to_constant[output_name] = device + else: + raise ValueError(f"Unsupported JitType ({input_type}) when get device") + + def convert_prim_GetAttr(self, node: torch._C.Node): + # Build fully qulified name + attr_fqn = get_attribute_fqn_from_ts_node(self.name_to_attribute_fqn, node) + output_name = node.output().debugName() + self.name_to_attribute_fqn[output_name] = attr_fqn + + if self.is_top_level_graph(): + if self._is_get_attr_node(attr_fqn): + # We insert a get_attr node due to two reasons. + # First, ts graph does not lift tensor constants as input nodes. So + # tensor constants may be ignored by in convert_graph_inputs(). + # Second, attr_fqn may have been written to via SetAttr. Two + # GetAttr may give different values. + self.name_to_node[output_name] = self.fx_graph.get_attr(attr_fqn) + else: + if attr_fqn not in self.name_to_non_tensor_attribute_node: + self.name_to_non_tensor_attribute_node[ + attr_fqn + ] = self.name_to_non_tensor_attribute[attr_fqn] + self.name_to_node[output_name] = self.name_to_non_tensor_attribute_node[ + attr_fqn + ] + else: + # Special support for if blocks which do not allow SetAttr TorchScript + # node and get_attr FX Graph Node. 
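+            # Inside a sub-block, params/buffers have already been lifted to
+            # placeholder inputs keyed by their FQN (see _convert_block_to_subgraph),
+            # so the attribute is resolved by name lookup rather than a get_attr node.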
+ if self._is_get_attr_node(attr_fqn): + self.name_to_node[output_name] = self.name_to_node[attr_fqn] + + def convert_prim_SetAttr(self, node: torch._C.Node): + attr_fqn = get_attribute_fqn_from_ts_node(self.name_to_attribute_fqn, node) + attr_value = tuple(node.inputs())[1] + ts_graph_tensor_input = self.get_fx_value_by_ir_value(attr_value) + if self._is_get_attr_node(attr_fqn): + fx_attr_node = self.fx_graph.get_attr(attr_fqn) + self.fx_graph.call_function( + torch.Tensor.copy_, (fx_attr_node, ts_graph_tensor_input) + ) + else: + self.name_to_non_tensor_attribute_node[attr_fqn] = ts_graph_tensor_input + + def convert_call_function_op(self, node: torch._C.Node): + target = get_op_overload(node) + + args, kwargs = self.get_args_kwargs(node, target._schema) + + fx_node = self.fx_graph.call_function(target, args, kwargs) + + # TODO: covnert sourceRange() into stack_trace + # fx_node.meta["stack_trace"] = node.sourceRange() + + if node.outputsSize() == 1: + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + else: + for i, outp in enumerate(node.outputs()): + output_name = outp.debugName() + next_fx_node = self.fx_graph.call_function( + operator.getitem, (fx_node, i) + ) + self.name_to_node[output_name] = next_fx_node + + def convert_prim_TupleConstruct(self, node: torch._C.Node): + self._convert_prim_iterator(node) + + def convert_prim_ListConstruct(self, node: torch._C.Node): + self._convert_prim_iterator(node) + + def _convert_prim_iterator(self, node: torch._C.Node): + output_list = [] + for inp in node.inputs(): + output_list.append(self.get_fx_value_by_ir_value(inp)) + + output_name = node.output().debugName() + self.name_to_node[output_name] = output_list + + def convert_prim_DictConstruct(self, node: torch._C.Node): + output_dict = {} + k, v = None, None + for i, inp in enumerate(node.inputs()): + # We assume key value are stored in pair in the DictConstruct. + # The first element is the key and the following is the value. + if i % 2 == 0: + k = self.get_fx_value_by_ir_value(inp) + else: + v = self.get_fx_value_by_ir_value(inp) + assert ( + k is not None and v is not None + ), "DictConstruct has an empty key value pair." + output_dict[k] = v + k, v = None, None + + assert ( + k is None and v is None + ), "DictConstruct has an odd number of elements (violating our assumption)." + + output_name = node.output().debugName() + self.name_to_node[output_name] = output_dict + + def convert_prim_ListUnpack(self, node: torch._C.Node): + self._convert_prim_unpack_iterator(node) + + def convert_prim_TupleUnpack(self, node: torch._C.Node): + self._convert_prim_unpack_iterator(node) + + def _convert_prim_unpack_iterator(self, node: torch._C.Node): + # Single input and multiple outputs for unpacking. 
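+        # E.g., "a, b = prim::TupleUnpack(%t)" becomes two FX nodes:
+        # operator.getitem(t, 0) and operator.getitem(t, 1).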
+ for i, outp in enumerate(node.outputs()): + outp_name = outp.debugName() + inp = self.get_fx_value_by_ir_value(node.input()) + fx_node = self.fx_graph.call_function(operator.getitem, (inp, i)) + self.name_to_node[outp_name] = fx_node + + def convert_aten_Int(self, node: torch._C.Node): + # converts aten::Int as aten._to_copy + aten::_local_scalar_dense + target = torch.ops.aten._to_copy.default + args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs()) + to_copy_node = self.fx_graph.call_function(target, args, {"dtype": torch.int32}) + + fx_node = self.fx_graph.call_function( + torch.ops.aten._local_scalar_dense.default, (to_copy_node,) + ) + + # TODO: covnert sourceRange() into stack_trace + # fx_node.meta["stack_trace"] = node.sourceRange() + + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + + def convert_prim_NumToTensor(self, node: torch._C.Node): + # Converts prim::NumToTensor as aten.scalar_tensor. + # prim::NumToTensor IRs are currently triggered by: + # .size() https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/frontend/tracer.cpp#L950 + # .numel() https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/frontend/tracer.cpp#L971 + # For both of those APIs, torch.jit.trace implicitly sets the output tensor type + # to be LongTensor. + target = torch.ops.aten.scalar_tensor + args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs()) + + fx_node = self.fx_graph.call_function(target, args, {"dtype": torch.long}) + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + + def convert_prim_CreateObject(self, node: torch._C.Node): + output_name = node.output().debugName() + self.name_to_attribute_fqn[output_name] = "" + + def convert_aten__convolution(self, node: torch._C.Node): + # converts aten::_convolution as aten.convolution, since aten::_convolution + # doesn't have a meta function + target = torch.ops.aten.convolution.default + args, kwargs = self.get_args_kwargs(node, target._schema) + + fx_node = self.fx_graph.call_function(target, args, kwargs) + + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + + def convert_aten_div(self, node: torch._C.Node): + target = get_op_overload(node) + schema = target._schema + + args, kwargs = self.get_args_kwargs(node, schema) + + # converts aten::div.Tensor_mode(x, tensor_constant) + # as aten.div.Scalar_mode(x, tensor_constant.item()) + if schema.overload_name == "Tensor_mode": + arg1_name = args[1].name + if arg1_name in self.name_to_constant and isinstance( + self.name_to_constant[arg1_name], torch.Tensor + ): + tensor_constant = self.name_to_constant[arg1_name] + if tensor_constant.numel() == 1: + updated_args = list(args) + updated_args[1] = self.name_to_constant[arg1_name].item() + + fx_node = self.fx_graph.call_function( + torch.ops.aten.div.Scalar_mode, + tuple(updated_args), + kwargs, + ) + + # TODO: covnert sourceRange() into stack_trace + # fx_node.meta["stack_trace"] = node.sourceRange() + + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + return + + self.convert_call_function_op(node) + + def convert_aten___getitem__(self, node: torch._C.Node): + input_container, index = tuple( + self.get_fx_value_by_ir_value(input) for input in node.inputs() + ) + fx_node = self.fx_graph.call_function( + operator.getitem, (input_container, index) + ) + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + + def 
convert_aten_to(self, node: torch._C.Node): + target = get_op_overload(node) + args, kwargs = self.get_args_kwargs(node, target._schema) + + # special handle aten.to.dtype and aten.to.prim_dtype followed by inplace_mutation_op + # coz aten.to + inplace_mutation_op pattern would trigger + # "cannot mutate tensors with frozen storage" functionalization error. + # To work around the issue, we override the copy to be True, so that the output + # is for sure not an alias of input + if target == torch.ops.aten.to.dtype or target == torch.ops.aten.to.prim_dtype: + user_nodes = [use.user for use in node.output().uses()] + user_targets = [ + get_op_overload(user_node) + for user_node in user_nodes + if user_node.schema() != "(no schema)" + ] + has_mutable_target = any( + target._schema.is_mutable for target in user_targets + ) + + if has_mutable_target: + assert len(args) >= 4 + new_args = list(args) + new_args[3] = True # copy, override to True + fx_node = self.fx_graph.call_function( + torch.ops.aten.to.dtype, tuple(new_args) + ) + # temp hack to work around the issue https://github.com/pytorch/pytorch/issues/131679 + # When this issue is fixed, the clone node would be no longer needed + clone_node = self.fx_graph.call_function( + torch.ops.aten.clone.default, (fx_node,) + ) + output_name = node.output().debugName() + self.name_to_node[output_name] = clone_node + return + + self.convert_call_function_op(node) + + def convert_aten_add(self, node: torch._C.Node): + if node.schema() == "(no schema)": + if isinstance(node.inputsAt(0).type(), torch.ListType) and isinstance( + node.inputsAt(1).type(), torch.ListType + ): + target = torch.ops.aten.add.t + else: + raise RuntimeError(f"unable to determind the target for {node}") + else: + target = get_op_overload(node) + + if target == torch.ops.aten.add.t: + # special handle python list/tuple add: "aten::add.t(t[] a, t[] b) -> t[]" for + # RuntimeError: aten::add() Expected a value of type 'List[t]' for argument 'a' but instead found type 'immutable_list'. + args, kwargs = self.get_args_kwargs(node, target._schema) + output_name = node.output().debugName() + self.name_to_node[output_name] = self.fx_graph.call_function(list_add, args) + else: + self.convert_call_function_op(node) + + def _check_prim_loop_support(self, node): + inputs = list(node.inputs()) + + # TODO: (1/N) stage. + if inputs[0].debugName() not in self.name_to_constant: + raise RuntimeError( + "prim::Loop currently cannot run with dynamic value of number of iterations." + ) + + # Make sure the condition is not updated in the subblock. + subblock = next(node.blocks()) + condition_output_name = next(subblock.outputs()).debugName() + for node in subblock.nodes(): + if ( + node.outputsSize() == 1 + and node.output().debugName() == condition_output_name + ): + raise RuntimeError( + "prim::Loop currently cannot run with dynamic value of condition." + ) + if node.outputsSize() >= 2: + for outp in node.outputs(): + if outp.debugName() == condition_output_name: + raise RuntimeError( + "prim::Loop currently cannot run with dynamic value of condition." + ) + + def convert_prim_Loop(self, node: torch._C.Node): + inputs = list(node.inputs()) + self._check_prim_loop_support(node) + + num_iterations = self.get_fx_value_by_ir_value(inputs[0]) + + # Find inputs. + loop_local_arguments = [inp.debugName() for inp in inputs[2:]] + + global_arguments = self._identify_inputs_as_arguments(node) + + # Lift parameters as inputs. 
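+        # Params/buffers used inside the loop body are not block inputs in the TS IR,
+        # so their FQNs (collected by get_block_to_lifted_attrs) are added to the
+        # argument list and passed to the subgraph explicitly.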
+ for block in node.blocks(): + global_arguments = global_arguments.union( + self.blocks_to_lifted_attrs[block] + ) + + global_arguments = list(global_arguments) + + subgraph_nodes, subgraph_converters = self._convert_block_to_subgraph( + node, global_arguments + ) + + assert len(subgraph_nodes) == 1 + subgraph_converter = subgraph_converters[0] + if not self.is_top_level_graph(): + self.name_update_from_subblock_to_parent = ( + self.name_update_from_subblock_to_parent.union( + subgraph_converter.name_update_from_subblock_to_parent + ) + ) + + fx_block_args = [ + self.get_fx_value_by_fqn(name) + for name in loop_local_arguments + global_arguments + ] + for iter_idx in range(num_iterations): + loop_node = self.fx_graph.call_function( + execute_subgraph_from_prim_loop, + # Check execute_node function for the expected arguments order. + ( + subgraph_nodes[0], + iter_idx, + len(loop_local_arguments), + *fx_block_args, + ), + {}, + ) + + # Update the value of loop local variables. + if node.outputsSize() >= 1: + for i, outp in enumerate(node.outputs()): + output_name = outp.debugName() + self.name_to_node[output_name] = self.fx_graph.call_function( + operator.getitem, + ( + loop_node, + i + 1, + ), # + 1 because the 0th element is the condition. + ) + fx_block_args[i] = self.name_to_node[output_name] + + # Update the value of global variables, whose values are modified inplace. + for i, name in enumerate( + subgraph_converter.name_update_from_subblock_to_parent + ): + self.name_to_node[name] = self.fx_graph.call_function( + operator.getitem, + ( + loop_node, + i + node.outputsSize() + 1, + ), # + 1 because the 0th element is the condition. + ) + global_argument_index = global_arguments.index(name) + fx_block_args[ + i + node.outputsSize() + global_argument_index + ] = self.name_to_node[name] + + def _check_set_attr_in_if_block(self, if_node: torch._C.Node): + for block in if_node.blocks(): + for node in block.nodes(): + if node.kind() == "prim::SetAttr": + raise RuntimeError( + "During converting prim::If to torch.cond, found prim::SetAttr op" + " which is not supported yet. Please file an issue if you come " + "across this error." + ) + + def convert_prim_If(self, node: torch._C.Node): + self._check_set_attr_in_if_block(node) + + inputs = list(node.inputs()) + assert len(inputs) == 1 + predicate = self.get_fx_value_by_ir_value(inputs[0]) + + # Find inputs. + arguments = self._identify_inputs_as_arguments(node) + + # Lift parameters as inputs. + for block in node.blocks(): + arguments = arguments.union(self.blocks_to_lifted_attrs[block]) + + arguments = list(arguments) + subgraph_nodes, _ = self._convert_block_to_subgraph(node, arguments) + + assert len(subgraph_nodes) == 2 + + fx_block_args = [self.get_fx_value_by_fqn(name) for name in arguments] + + args = ( + predicate, + subgraph_nodes[0], + subgraph_nodes[1], + tuple(fx_block_args), + ) + + cond_node = self.fx_graph.call_function(torch.cond, args, {}) + + # prim::If can also have zero output. 
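+        # With one output the torch.cond result is used directly; with several, the
+        # result is unpacked with operator.getitem, mirroring the TS outputs below.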
+ if node.outputsSize() == 1: + output_name = node.output().debugName() + self.name_to_node[output_name] = cond_node + elif node.outputsSize() > 1: + for i, output in enumerate(node.outputs()): + output_name = output.debugName() + getitem = self.fx_graph.call_function(operator.getitem, (cond_node, i)) + self.name_to_node[output_name] = getitem + + def convert_aten_Bool(self, node: torch._C.Node): + self._convert_as_noop(node) + + def convert_prim_Enter(self, node: torch._C.Node): + # export generally treats prim::Enter as noop + # The only context manager export supports is aten::enable_grad. + # Unfortunately, TorchScript does not support aten::enable_grad yet. + # TODO: support aten::enable_grad in both TorchScript and Converter. + return + + def convert_prim_Exit(self, node: torch._C.Node): + # export treats prim::Exit as noop + return + + def _convert_as_noop(self, node: torch._C.Node): + # Converts the node as a no-op by mapping its output node as arg[0] + + target = get_op_overload(node) + schema = target._schema + + args, kwargs = self.get_args_kwargs(node, schema) + + output_name = node.output().debugName() + self.name_to_node[output_name] = args[0] + + def convert_profiler__record_function_exit(self, node: torch._C.Node): + # _record_function_exit has side effect so we keep it in fx.graph + # currently, _record_function_enter_new and _record_function_exit are + # discarded during `retrace_as_exported_program`. + target = torch.ops.profiler._record_function_exit + args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs()) + self.fx_graph.call_function(target, args) + + def convert_prim_tolist(self, node: torch._C.Node): + # prim::tolist cannot be supported by `_convert_standard_operators` + # since it requires call_method instead of call_function. + target = "tolist" + args = (self.get_fx_value_by_ir_value(next(node.inputs())),) + fx_node = self.fx_graph.call_method(target, args) + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + + def convert_prim_Uninitialized(self, node: torch._C.Node): + # `prim::Uninitialized` is inserted by the compiler when it can prove + # the value will never be used. It can be introduced by exceptions, + # breaks, continues, and returns. + # So we add a dummy constant to the graph. + output_name = node.output().debugName() + self.name_to_constant[output_name] = torch.Tensor() + + def _convert_standard_operators(self, node: torch._C.Node): + target = kind_to_standard_operators[node.kind()] + args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs()) + fx_node = self.fx_graph.call_function(target, args) + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + + def convert_node(self, node: torch._C.Node): + node_kind = node.kind() + + # Get handler based on namespace and operator name. + # Provide a default node handler as well in case we don't find + # matching converter for that. + handler_func_name = ir_name_to_func_name(node_kind) + handler_func = getattr(self, handler_func_name, self.convert_call_function_op) + + # str calls print function implemented in CPP. To avoid repeating + # the entire logic here, we simply keep first line from node string (getting rid + # of sub-blocks IR prints). 
+ node_str = "".join(str(node).split("\n")[:1]) + log.debug("[%s] converts [%s]", handler_func.__name__, node_str) + try: + handler_func(node) + except Exception as e: + raise RuntimeError(f"TS2EPConverter failed for node {node_kind}") from e + + def convert_graph_outputs(self): + args = [] + outp_name_list = [outp.debugName() for outp in self.ts_graph.outputs()] + list( + self.name_update_from_subblock_to_parent + ) + for output_name in outp_name_list: + if output_name in self.name_to_node: + fx_node = self.name_to_node[output_name] + # TODO: Revisit this later after HigherOrderOp design changes. + # Currently, we cannot directly return input as output. + if ( + not self.is_top_level_graph() + and isinstance(fx_node, torch.fx.Node) + and fx_node.op == "placeholder" + ): + fx_node = self.fx_graph.call_function(torch.clone, (fx_node,)) + args.append(fx_node) + self.output_specs.append( + OutputSpec( + OutputKind.USER_OUTPUT, + arg=TensorArgument(name=output_name), + target=output_name, + ) + ) + elif output_name in self.name_to_constant: + args.append(self.name_to_constant[output_name]) + self.output_specs.append( + OutputSpec( + OutputKind.USER_OUTPUT, + arg=ConstantArgument( + name=output_name, value=self.name_to_constant[output_name] + ), + target=output_name, + ) + ) + else: + raise ValueError(f"Output {output_name} not found") + + if len(args) == 0: + # Sub-block of prim::If can have zero output. + self.fx_graph.output([]) + elif len(args) == 1: + self.fx_graph.output( + args[0] + ) # Get rid of an extra list wrapped around final output. + elif len(args) > 1: + self.fx_graph.output( + args + ) # For prim::Loop and prim::If with multiple outputs. + else: + # Sub-block of prim::Loop can have multiple outputs. + self.fx_graph.output(args) + + +class ExplainTS2FXGraphConverter(TS2FXGraphConverter): + """ + Run TS2FXGraphConverter in an explain mode. It collects all failed operators conversions + and provide that information to users. In order to collect all failed conversions, it + also mocks some internal attributes (e.g., name_to_node). + """ + + class _DictMock(dict): + def __init__(self, dict_data, mock_value): + super().__init__(dict_data) + self.mock_value = mock_value + + def __getitem__(self, key): + # If the original dictionary has the key, return its value. + # Otherwise, return the mock value. + if not super().__contains__(key): + return self.mock_value + return super().__getitem__(key) + + def __contains__(self, key): + return True + + def __init__( + self, + ts_graph: Union[torch._C.Graph, torch._C.Block], + name_to_param: Dict[str, torch.Tensor], + name_to_buffer: Dict[str, torch.Tensor], + blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]], + name_to_non_tensor_attribute: Dict[str, Any], + name_to_constant: Dict[str, Any], + ): + super().__init__( + ts_graph, + name_to_param, + name_to_buffer, + blocks_to_lifted_attrs, + name_to_non_tensor_attribute, + name_to_constant, + ) + + # Data to keep track of unsupported nodes. + self.unsupported_node_list: List[torch._C.Node] = [] + + # Add mock to needed attributes. + self.name_to_node = ExplainTS2FXGraphConverter._DictMock( + self.name_to_node, + # Dummy node. 
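+            # Any name whose producer failed to convert resolves to this mock node,
+            # so explain() can keep walking the graph and collect every unsupported
+            # op instead of stopping at the first failure.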
+ torch.fx.Node( + None, # type: ignore[arg-type] + "mock", + "call_function", + lambda: None, + (), + {}, + ), + ) + + def explain(self): + self.convert_graph_inputs() + for node in self.ts_graph.nodes(): + self.convert_node(node) + self.convert_graph_outputs() + + def convert_node(self, node): + try: + super().convert_node(node) + except Exception as e: + self.unsupported_node_list.append(node) + + +@contextmanager +def disable_logging(log): + disabled = log.disabled + log.disabled = True + try: + yield + finally: + log.disabled = disabled + + +class TS2EPConverter: + # TorchScript model to ExportedProgram converter + def __init__( + self, + ts_model: Union[torch.jit.ScriptModule, torch.jit.ScriptFunction], + sample_args: Tuple[Any, ...], + sample_kwargs: Optional[Dict[str, Any]] = None, + ): + self.ts_model = ts_model + self.ts_graph, self.params, _, _ = _create_jit_graph(ts_model, sample_args) + + self.sample_args = sample_args + self.sample_kwargs = sample_kwargs + + self.name_to_param: Dict[str, torch.Tensor] = {} + self.name_to_buffer: Dict[str, torch.Tensor] = {} + param_list = ( + list(self.ts_model.parameters()) + if not isinstance(self.ts_model, torch._C.ScriptFunction) + else [] + ) + if not isinstance(self.ts_model, torch._C.ScriptFunction): + for k, tensor in self.ts_model.state_dict().items(): # type: ignore[union-attr] + # Check if tensor belongs to any parameter. + if any( + (tensor == param).all() + for param in param_list + if tensor.shape == param.shape + ): + self.name_to_param[k] = tensor + else: + self.name_to_buffer[k] = tensor + + self.name_to_non_tensor_attributes: Dict[str, Any] = {} + self.name_to_constant: Dict[str, Any] = {} + + self.lift_get_attr() + + def convert(self) -> ExportedProgram: + log.info( + """ +TS2EPConverter logging starts from here. + +INFO: (TORCH_LOGS="export" ) + * Log TorchScript IR. + +DEBUG: (TORCH_LOGS="+export" ), additionally + * Log conversion IR by IR in a format of [] converts []. + """ + ) + log.info("TorchScript graph\n\n%s\n", self.ts_graph) + + blocks_to_lifted_attrs = get_block_to_lifted_attrs(self.ts_graph) + + graph_converter = TS2FXGraphConverter( + self.ts_graph, + self.name_to_param, + self.name_to_buffer, + blocks_to_lifted_attrs, + self.name_to_non_tensor_attributes, + self.name_to_constant, + ) + gm = graph_converter.convert() + + # Post-proccessing step to deal with quantized operators. + replace_quantized_ops_with_standard_ops(gm) + log.info("GraphModule: %s", gm.print_readable(print_output=False)) + + ep = self.retrace_as_exported_program( + gm, + graph_converter.name_to_constant, + ) + log.info("%s", ep) + + # Post-processing step to ensure ExportedProgram has the same state_dict as + # the original TorchScript model. Throw warnings for additionally populated + # state_dict entries. + if not isinstance(self.ts_model, torch._C.ScriptFunction): + for k, tensor in self.ts_model.state_dict().items(): # type: ignore[union-attr] + if k not in ep.state_dict: + warnings.warn( + f"Manually populate {k} into state_dict ExportedProgram, but it is never used by the ExportedProgram." 
+ ) + ep.state_dict[k] = tensor + + return ep + + @disable_logging(log) + def explain(self, print_output=True): + blocks_to_lifted_attrs = get_block_to_lifted_attrs(self.ts_graph) + + graph_converter = ExplainTS2FXGraphConverter( + self.ts_graph, + self.name_to_param, + self.name_to_buffer, + blocks_to_lifted_attrs, + self.name_to_non_tensor_attributes, + self.name_to_constant, + ) + graph_converter.explain() + if len(graph_converter.unsupported_node_list) > 0: + explain_str = "Unsupported nodes are found in the following list:" + for i, n in enumerate(graph_converter.unsupported_node_list): + node_str = "".join(str(n).split("\n")[:1]) + explain_str += f"\n\n {i}. {n.kind()} [{node_str}]" + else: + explain_str = "Success!" + if print_output: + print(explain_str) + return explain_str + + def retrace_as_exported_program( + self, + gm: torch.fx.GraphModule, + name_to_constant: Dict[str, Any], + ): + # TODO: adjust input orders to match GraphSignature convention + ep = torch.export._trace._export( + gm, + self.sample_args, + strict=False, + pre_dispatch=True, + ) + + # Post-processing to make sure the ExportedProgram states are correct. + # Because during conversion, we set tensor constants as GetAttr, + # retracing cannot recognize them as tensor constants but instead + # treat them as buffers. We need to set them again here. + ep._constants.update( + { + k: v + for k, v in name_to_constant.items() + if isinstance(v, (torch.Tensor, torch.ScriptObject)) + } + ) + for k in name_to_constant: + ep.state_dict.pop(k, None) + + for i, spec in enumerate(ep.graph_signature.input_specs): + # Mark as constant tensors for erroneously traced buffers. + if spec.kind == InputKind.BUFFER and spec.target in name_to_constant: + assert isinstance( + name_to_constant[spec.target], torch.Tensor + ), f"{type(name_to_constant[spec.target])} has been erroneously marked as buffer" + spec.kind = InputKind.CONSTANT_TENSOR + ep.verifier().check(ep) + + return ep + + def lift_get_attr(self): + # This function lifts multiple data types. + + # 1. Tensor constants attributes (e.g., self.data = torch.tensor([2,3])) + # to buffers. Currently, when there are tensor constants, export + # would error and ask users to register tensor constants as buffers. + # Since it is hard to manually do so for TorchScript models + # (e.g., source code is missing), this function automatically + # lifts tensor constants to be buffers. + + # 2. ScriptObbject to constant. It will then be converted to getattr in + # in the fx graph. + # + # This function should happen in TS2EPConverter instead of + # TS2FXGraphConverter since it gets attributes from self.ts_model + # which is not accessable in TS2FXGraphConverter. It is similar to where + # we collect self.name_to_param and self.name_to_buffer. 
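+        # Illustrative example (hypothetical model): a scripted module with
+        # `self.data = torch.tensor([2, 3])` used in forward() shows up as a chain of
+        # prim::GetAttr nodes; the walk below records it in self.name_to_buffer under
+        # its FQN so that retracing treats it as a buffer rather than erroring out.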
+        name_to_attribute_fqn: Dict[str, str] = {}
+
+        def get_attr(fqn: str):
+            name = fqn.split(".")
+            v = self.ts_model
+            for n in name:
+                v = getattr(v, n)
+            return v
+
+        def get_fqn(node: torch._C.Node):
+            attr_name = node.s("name")
+            input_name = node.input().debugName()
+            root_attr_name = name_to_attribute_fqn[input_name]
+            attr_fqn = f"{root_attr_name}.{attr_name}" if root_attr_name else attr_name
+            return attr_fqn
+
+        def _dfs_get_attr(block):
+            for node in block.nodes():
+                if node.kind() == "prim::CreateObject":
+                    output_name = node.output().debugName()
+                    name_to_attribute_fqn[output_name] = ""
+
+                if node.kind() == "prim::GetAttr":
+                    attr_fqn = get_fqn(node)
+                    value = get_attr(attr_fqn)
+                    output_name = node.output().debugName()
+                    name_to_attribute_fqn[output_name] = attr_fqn
+                    if isinstance(value, torch.Tensor):
+                        if attr_fqn not in self.name_to_buffer:
+                            # Lift tensor constants to be a buffer
+                            self.name_to_buffer[attr_fqn] = value
+                    elif isinstance(value, torch.ScriptObject):
+                        if attr_fqn not in self.name_to_constant:
+                            self.name_to_constant[attr_fqn] = value
+                    else:
+                        self.name_to_non_tensor_attributes[attr_fqn] = value
+
+                for subblock in node.blocks():
+                    _dfs_get_attr(subblock)
+
+        _dfs_get_attr(self.ts_graph)
diff --git a/lib/python3.10/site-packages/torch/_export/error.py b/lib/python3.10/site-packages/torch/_export/error.py
new file mode 100644
index 0000000000000000000000000000000000000000..03b7f52fb9de435b9e58fa4a0bb141cc191e84c5
--- /dev/null
+++ b/lib/python3.10/site-packages/torch/_export/error.py
@@ -0,0 +1,56 @@
+from enum import Enum
+
+
+class ExportErrorType(Enum):
+    # User provided invalid inputs to either the tracer or other public-facing APIs.
+    INVALID_INPUT_TYPE = 1
+
+    # User returned values from their model that we don't support.
+    INVALID_OUTPUT_TYPE = 2
+
+    # Generated IR does not conform to the Export IR specification.
+    VIOLATION_OF_SPEC = 3
+
+    # User's code contains types and functionalities we don't support.
+    NOT_SUPPORTED = 4
+
+    # User's code didn't provide the details necessary to successfully trace and export.
+    # For example, we use a lot of decorators and ask users to annotate their model.
+    MISSING_PROPERTY = 5
+
+    # User is using an API without the proper initialization step.
+    UNINITIALIZED = 6
+
+
+def internal_assert(pred: bool, assert_msg: str) -> None:
+    """
+    This is exir's custom assert method. It internally just throws InternalError.
+    The sole purpose is to raise our own error type while keeping a syntax
+    similar to Python's built-in assert.
+    """
+
+    if not pred:
+        raise InternalError(assert_msg)
+
+
+class InternalError(Exception):
+    """
+    Raised when an internal invariant is violated in the EXIR stack.
+    It should hint users to report a bug to the developers and expose the
+    original error message.
+    """
+
+    def __init__(self, message: str) -> None:
+        super().__init__(message)
+
+
+class ExportError(Exception):
+    """
+    This type of exception is raised for errors that are directly caused by the user's
+    code. In general, user errors happen during model authoring, tracing, using our
+    public-facing APIs, and writing graph passes.
+ """ + + def __init__(self, error_code: ExportErrorType, message: str) -> None: + prefix = f"[{error_code}]: " + super().__init__(prefix + message) diff --git a/lib/python3.10/site-packages/torch/_export/non_strict_utils.py b/lib/python3.10/site-packages/torch/_export/non_strict_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5c7331659d26a00dc68e0d169a70328cec251c2b --- /dev/null +++ b/lib/python3.10/site-packages/torch/_export/non_strict_utils.py @@ -0,0 +1,523 @@ +# mypy: allow-untyped-defs +import contextlib +import inspect +import logging +from collections import defaultdict +from typing import Any, Callable, Dict, List, Tuple, TYPE_CHECKING, Union + +import torch +import torch.utils._pytree as pytree +from torch._dynamo.source import ( + AttrSource, + GetItemSource, + LocalSource, + TensorProperty, + TensorPropertySource, +) +from torch._dynamo.variables.builder import TrackedFake +from torch._export.passes.add_runtime_assertions_for_constraints_pass import InputDim +from torch._export.passes.lift_constants_pass import ConstantAttrMap +from torch._guards import Source +from torch._library.fake_class_registry import FakeScriptObject +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.export import Constraint +from torch.export.dynamic_shapes import ( + _check_dynamic_shapes, + _combine_args, + _DimHint, + _process_dynamic_shapes, + _transform_shapes_for_default_dynamic, + _tree_map_with_path, +) +from torch.export.graph_signature import CustomObjArgument +from torch.fx.experimental import _config as config +from torch.fx.experimental.symbolic_shapes import ( + _find_user_code_frame, + _suggest_fixes_for_data_dependent_error_non_strict, + ConstraintViolationError, + DimDynamic, + EqualityConstraint, + GuardOnDataDependentSymNode, + ShapeEnv, + StatelessSymbolicContext, + ValueRanges, +) +from torch.utils._pytree import ( + GetAttrKey, + KeyPath, + MappingKey, + SequenceKey, + tree_map_with_path, +) + + +if TYPE_CHECKING: + from sympy import Symbol + + +log = logging.getLogger(__name__) + + +def key_path_to_source(kp: KeyPath) -> Source: + """ + Given a key path, return the source for the key path. 
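+    For example (illustrative), the key path (SequenceKey(idx=0), GetAttrKey(name="weight"))
+    refers to args[0].weight and is translated to roughly
+    AttrSource(GetItemSource(LocalSource("args"), 0), "weight").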
+ """ + source: Source = LocalSource("args") + for k in kp: + if isinstance(k, SequenceKey): + source = GetItemSource(source, k.idx) + elif isinstance(k, MappingKey): + source = GetItemSource(source, k.key) + elif isinstance(k, GetAttrKey): + source = AttrSource(source, k.name) + else: + raise ValueError(f"Unknown KeyEntry {k}") + + return source + + +def _is_constant_argument(t): + return t is None or isinstance(t, (int, float, bool, str)) + + +def fakify( + mode: FakeTensorMode, + kp: KeyPath, + t: Any, + t_constraints: Dict[int, Dict[int, Constraint]], + sources: Dict[Tuple[int, int], List[Source]], +): + source = key_path_to_source(kp) + if _is_constant_argument(t) or isinstance(t, torch.ScriptObject): + return t + + if not isinstance(t, torch.Tensor): + raise ValueError(f"Unsupported input type {type(t)}") + n_dims = len(t.shape) + symbolic_context = StatelessSymbolicContext( + dynamic_sizes=[DimDynamic.DYNAMIC] * n_dims, + constraint_sizes=[None] * n_dims, + ) + t_id = id(t) + assert mode.shape_env is not None + if t_id in t_constraints: + for i, constraint in t_constraints[t_id].items(): + symbolic_context.constraint_sizes[i] = constraint.constraint_range + src = TensorPropertySource(base=source, prop=TensorProperty.SIZE, idx=i) + sources[(t_id, i)].append(src) + mode.shape_env.source_name_to_debug_name[src.name()] = constraint.name # type: ignore[assignment] + fake = mode.from_tensor(t, source=source, symbolic_context=symbolic_context) + mode.shape_env.tracked_fakes.append(TrackedFake(fake, source, symbolic_context)) # type: ignore[union-attr] + return fake + + +def make_fake_inputs( + nn_module, + args, + kwargs, + dynamic_shapes, + _is_torch_jit_trace=False, + allow_complex_guards_as_runtime_asserts=False, +): + """ + Given an nn module, example inputs, and constraints, return a new fake mode, + fake inputs created in that mode whose dynamic shape dimensions are constrained + by the given ranges, and sources for pairs of dynamic shape dimensions that are + constrained to be equal. + """ + # TODO(avik): refactor Dynamo to avoid duplication of the following code + # between non-strict and strict. + # Specifically, here (non-strict) we do the following pre-tracing steps: + # - Fakify inputs. + # - Process input shape equalities. + # In strict, these steps are spread across multiple files: + # - output_graph.py fakifies inputs. + # - [post-tracing] guards.py processes input shape equalities. + + combined_args = _combine_args(nn_module, args, kwargs) + _check_dynamic_shapes(combined_args, dynamic_shapes) + transformed_dynamic_shapes = _transform_shapes_for_default_dynamic( + combined_args, dynamic_shapes + ) + constraints = _process_dynamic_shapes(combined_args, transformed_dynamic_shapes) + t_constraints: Dict[int, Dict[int, Constraint]] = defaultdict(dict) + for constraint in constraints: + t_constraints[constraint.t_id][constraint.dim] = constraint + + context = torch._guards.TracingContext.try_get() + if context is not None: + # This occurs when we are exporting within dynamo. There already exists + # a toplevel TracingContext with a fake mode, so we do not want to + # create another fake mode. 
+ fake_mode = context.fake_mode + elif not _is_torch_jit_trace: + code = nn_module.forward.__code__ + co_fields = { + "co_name": code.co_name, + "co_filename": code.co_filename, + "co_firstlineno": code.co_firstlineno, + } + fake_mode = FakeTensorMode( + shape_env=ShapeEnv( + tracked_fakes=[], + co_fields=co_fields, + prefer_deferred_runtime_asserts_over_guards=True, + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, + ), + allow_non_fake_inputs=True, + export=True, + ) + else: + fake_mode = FakeTensorMode( + shape_env=ShapeEnv( + tracked_fakes=[], + prefer_deferred_runtime_asserts_over_guards=True, + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, + ), + allow_non_fake_inputs=True, + ) + if fake_mode.shape_env is None or fake_mode.shape_env.tracked_fakes is None: + raise ValueError( + "Detected fake_mode does not have a shape_env with tracked fakes. " + "If you constructed the module under a FakeTensorMode, " + "please initialize it like: FakeTensorMode(shape_env=ShapeEnv(tracked_fakes=[]))" + ) + + with fake_mode: + # FIXME(ycao) ScriptMethod doesn't have signature, I am using an empty one to unblock + if not _is_torch_jit_trace: + original_signature = inspect.signature(nn_module.forward) + else: + original_signature = None + sources: Dict[Tuple[int, int], List[Source]] = defaultdict(list) + fake_args, fake_kwargs = tree_map_with_path( + lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources), + (args, kwargs), + ) + + names: Dict[str, Tuple[int, int]] = {} + source_pairs: List[Tuple[Source, Source]] = [] + derived_equalities: List[Tuple[Source, Union[Source, Symbol], Callable]] = [] + phantom_symbols: Dict[str, Symbol] = {} + for constraint in constraints: + torch.export.dynamic_shapes._process_equalities( + constraint, + lambda t_id, dim: sources[(t_id, dim)], + fake_mode.shape_env, + names, + source_pairs, + derived_equalities, + phantom_symbols, + ) + + equalities_inputs = EqualityConstraint( + source_pairs=source_pairs, + derived_equalities=derived_equalities, + phantom_symbols=list(phantom_symbols.values()), + warn_only=False, + ) + return ( + fake_mode, + fake_args, + fake_kwargs, + equalities_inputs, + original_signature, + transformed_dynamic_shapes, + ) + + +def _flatten_dynamic_shapes( + combined_args: Dict[str, Any], + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]], +) -> List[Any]: + flat_shapes = [] + + def _tree_map_helper(path, t, shape): + nonlocal flat_shapes + flat_shapes.append(shape) + + _tree_map_with_path(_tree_map_helper, combined_args, dynamic_shapes) + return flat_shapes + + +def produce_guards_and_solve_constraints( + fake_mode: FakeTensorMode, + gm: torch.fx.GraphModule, + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], + equalities_inputs: EqualityConstraint, + original_signature: inspect.Signature, + _is_torch_jit_trace=False, +): + """ + Given a fake mode, sources pairs corresponding to equal dynamic shape dimensions, + and a graph module, produce guards on the fake mode's shape env (raising constraint + violations if any), solve (to suggest simplifications or fixes). + Dynamo already performs this, so this is for non-strict mode. 
+ + Additional inputs: + equalities_inputs: the equality constraints to use for guards + original_signature: the signature of the forward method + """ + shape_env = fake_mode.shape_env + assert shape_env is not None + assert shape_env.tracked_fakes is not None + + placeholders = [tf.fake for tf in shape_env.tracked_fakes] + sources = [tf.source for tf in shape_env.tracked_fakes] + input_contexts = [tf.symbolic_context for tf in shape_env.tracked_fakes] + constraint_violation_error = None + try: + shape_env.produce_guards( + placeholders, + sources, + input_contexts=input_contexts, + equalities_inputs=equalities_inputs, + ignore_static=False, + ) + except ConstraintViolationError as e: + constraint_violation_error = e + + shape_env.frozen = True + dim_constraints = shape_env.dim_constraints + if dim_constraints is None: + # Expected when shape_env.produce_guards throws an early constraint violation error. + # There is nothing to solve for in this case. + # TODO(avik): Maybe record the constraint violation error instead and replay later? + assert constraint_violation_error + raise constraint_violation_error + dim_constraints.solve() + forced_specializations = dim_constraints.forced_specializations() + if not _is_torch_jit_trace: + msg = dim_constraints.prettify_results( + original_signature, + dynamic_shapes, + constraint_violation_error, + forced_specializations, + ) + else: + # FIXME(ycao): This is a hack to get around missing signature from ScriptMethod + msg = "dummy constraint violation message" + if constraint_violation_error: + constraint_violation_error.args = (constraint_violation_error.args[0] + msg,) + elif forced_specializations: + constraint_violation_error = ConstraintViolationError(msg) + if constraint_violation_error: + raise constraint_violation_error + + +def make_constraints( + fake_mode: FakeTensorMode, + gm: torch.fx.GraphModule, + combined_args: Dict[str, Any], + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], + num_lifted_inputs: int, +): + """ + Given a fake mode's shape env and user-specified dynamic shapes, + return the resulting range constraints and equality constraints. + + Additional args: + num_lifted_inputs: the number of non-user-input placeholder nodes in the graph + (used only to enumerate the user-input nodes) + """ + + shape_env = fake_mode.shape_env + assert shape_env is not None + inline_constraints = gm.meta.get("inline_constraints", []) + range_constraints = { + symbol: inline_constraints[symbol] for symbol in inline_constraints + } + if not dynamic_shapes: + return range_constraints + + # get individual dynamic shapes spec for each input + if not isinstance(dynamic_shapes, dict): + assert isinstance(dynamic_shapes, (tuple, list)) + combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc] + flat_dynamic_shapes = _flatten_dynamic_shapes(combined_args, dynamic_shapes) + + # check number of shapes vs. 
number of inputs + num_placeholders = [node.op == "placeholder" for node in gm.graph.nodes].count(True) + assert len(flat_dynamic_shapes) == num_placeholders - num_lifted_inputs + + input_dims = defaultdict(list) + free_symbols = set() + for input_index, node in enumerate(gm.graph.nodes): + if input_index < num_lifted_inputs or node.op != "placeholder": + continue + if _is_constant_argument(node.meta["val"]) or isinstance( + node.meta["val"], CustomObjArgument + ): + continue + shape_spec = flat_dynamic_shapes[input_index - num_lifted_inputs] + for i, d in enumerate(node.meta["val"].shape): + if isinstance(d, torch.SymInt) and not d.node.expr.is_number: + # Look up the range constraint for the symbol corresponding to this shape dimension + # and store it indexed by the symbolic expression corresponding to it. + # NOTE(avik): Use node._expr instead of node.expr for the lookup here because + # we want the symbol, not its replacement, which could be an expression. Maybe + # there's a better way to do this, e.g., by (re)computing value ranges for expressions? + dim = shape_spec[i] if shape_spec else None + if dim is None or isinstance(dim, _DimHint): + range_constraints[d.node.expr] = shape_env.var_to_range[ + d.node._expr + ] + else: + range_constraints[d.node.expr] = ValueRanges( + lower=dim.min, upper=dim.max + ) + input_dims[d.node.expr].append(InputDim(input_name=node.name, dim=i)) + free_symbols.update(d.node.expr.free_symbols) + + for symbol in free_symbols: + if symbol not in range_constraints: + # Placeholders can have symbolic shapes that are derived expressions. + # The above code will record direct range constraints for them + # so that we can do runtime assertions. In addition, for serde checks + # we want to record range constraints for their root symbols. + range_constraints[symbol] = shape_env.var_to_range[symbol] + + return range_constraints + + +def _gather_constant_attrs(m: torch.nn.Module) -> ConstantAttrMap: + """Search the module hierarchy, gathering up all tensor and ScriptObject constants. + + Returns a dictionary mapping hash(value) to the name of the constant. We + have to abuse `hash` here unfortunately, see: [ScriptObject hash]. + """ + constants = ConstantAttrMap() + buffers_parameters = set(m.buffers()) + buffers_parameters.update(m.parameters()) + + def inner(m: torch.nn.Module, prefix_atoms: List[str], constants): + for k, v in m.__dict__.items(): + if isinstance( + v, + ( + torch.Tensor, + torch.ScriptObject, + FakeScriptObject, + ), + ): + if v in buffers_parameters: + # filter out buffers and parameters, leaving only constants + continue + + fqn = ".".join(prefix_atoms + [k]) + constants.add(v, fqn) + for k, v in m.named_children(): + inner(v, prefix_atoms + [k], constants) + + inner(m, [], constants) + return constants + + +@contextlib.contextmanager +def _fakify_script_objects( + mod: torch.nn.Module, + args: Tuple[Any], + kwargs: Dict[Any, Any], + fake_mode: torch._subclasses.fake_tensor.FakeTensorMode, +): + # This context manager is used to fakify script objects into FakeScriptObject. + # Inputs: + # mod: the module to be exported, it (and its recursive submodules)'s script object attrs haven't been fakified. + # args, kwargs: the args and kwargs inputs for mod, script object inputs haven't been fakified. + # fake_mode: the fake mode to be used for fakifying script objects. It's the same mode that fakify input tensors. + # + # Returns: + # mod: the patched module, its (and its recursive submodules) script object attrs have been fakified. 
+ # fake_args, fake_kwargs: new fakified args and kwargs. + # Script object inputs have been fakified. Don't touch the tensors. + # fake_constant_attrs: a new map from FakeScriptObject to the fqn of the original script object. + # fake_to_real: a mapping between FakeScriptObject and the original script object in order to un-do the patching. + + constant_attrs: ConstantAttrMap = _gather_constant_attrs(mod) + assert not any( + isinstance(obj, FakeScriptObject) for obj in constant_attrs.values() + ), "Mod shouldn't contain any FakeScriptObject." + assert not pytree.tree_any( + lambda obj: isinstance(obj, FakeScriptObject), (args, kwargs) + ), "args and kwargs shouldn't contain any FakeScriptObject." + + patched_attr = {} + fake_constant_attrs = ConstantAttrMap() + fake_to_real = {} + + def _maybe_fakify_obj(obj): + fake_obj = torch._library.fake_class_registry.maybe_to_fake_obj(fake_mode, obj) + fake_to_real[fake_obj] = obj + return fake_obj + + def _leaf_mod_and_attr( + mod: torch.nn.Module, attr_fqn: str + ) -> Tuple[torch.nn.Module, str]: + *prefix_attr, last_attr = attr_fqn.split(".") + cur_mod = mod + for attr in prefix_attr: + cur_mod = getattr(cur_mod, attr) + return cur_mod, last_attr + + try: + for obj, fqns in constant_attrs.items(): + if isinstance(obj, torch.ScriptObject): + fake_script_obj = _maybe_fakify_obj(obj) + for fqn in fqns: + cur_mod, attr = _leaf_mod_and_attr(mod, fqn) + assert obj is getattr(cur_mod, attr) + setattr(cur_mod, attr, fake_script_obj) + fake_constant_attrs.add(fake_script_obj, fqn) + patched_attr[fqn] = obj + else: + for fqn in fqns: + fake_constant_attrs.add(obj, fqn) + + fake_args, fake_kwargs = pytree.tree_map_only( + torch.ScriptObject, _maybe_fakify_obj, (args, kwargs) + ) + yield (mod, fake_args, fake_kwargs, fake_constant_attrs, fake_to_real) + finally: + for fqn, orig_obj in patched_attr.items(): + cur_mod, attr = _leaf_mod_and_attr(mod, fqn) + setattr(cur_mod, attr, orig_obj) + + +class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode): + """ + 1. Handles data-dependent errors raised by torch function calls in non-strict. + + Any data-dependent error is due to some condition on unbacked symints + that cannot be resolved. A mechanical way of fixing the error is to use + a torch._check() call to assert either that condition or its negation. + The handler suggests these options as code and points to the location + of the torch function call that raised the error as part of the error + message shown to the user, who can then simply select and copy-paste + a suggested fix at that location. + + NOTE: Not all data-dependent errors are raised by torch function calls. + In particular, conditions on unbacked symints can appear outside such + calls, and as such are not handled here. + + 2. Handles line-of-code logging for each torch function call in non-strict. + + Usage: TORCHEXPORT_EXTENDED_DEBUG_CURRENT_LOC=1 TORCH_LOGS="+export" ... 
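+
+    A minimal sketch of the suggested-fix workflow (hypothetical user code, not taken
+    from this file): when export hits a data-dependent error on an unbacked symint,
+    paste the suggested torch._check() at the reported location, e.g.
+
+        def forward(self, x):
+            n = x.sum().item()     # produces an unbacked symint during export
+            torch._check(n >= 0)   # suggested fix copied from the error message
+            return torch.zeros(n)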
+ """ + + def __torch_function__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc: + frame = _find_user_code_frame() + if frame is not None: + log.debug( + "%s called at %s:%s in %s", + func.__qualname__, + frame.f_code.co_filename, + frame.f_lineno, + frame.f_code.co_name, + ) + try: + return func(*args, **kwargs) + except GuardOnDataDependentSymNode as e: + _suggest_fixes_for_data_dependent_error_non_strict(e) + raise diff --git a/lib/python3.10/site-packages/torch/_export/pass_base.py b/lib/python3.10/site-packages/torch/_export/pass_base.py new file mode 100644 index 0000000000000000000000000000000000000000..55612c98ce8d51d95999f0f4e124f3479070deb1 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_export/pass_base.py @@ -0,0 +1,441 @@ +# mypy: allow-untyped-defs +import operator +import traceback +import typing +from contextlib import nullcontext +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +import torch +from functorch.experimental.control_flow import _unstack_pytree +from torch import fx +from torch._dispatch.python import enable_python_dispatcher +from torch._export.pass_infra.node_metadata import NodeMetadata +from torch._export.pass_infra.proxy_value import ProxyValue +from torch._subclasses import FakeTensor, UnsupportedFakeTensorException +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx import traceback as fx_traceback +from torch.fx.experimental.proxy_tensor import PythonKeyTracer +from torch.fx.graph import CodeGen +from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata +from torch.utils import _pytree as pytree +from torch.fx.experimental.symbolic_shapes import PropagateUnbackedSymInts, compute_unbacked_bindings + + +__all__ = ["_ExportPassBaseDeprecatedDoNotUse"] + + +Argument = Any +Value = Any +Fn = Callable[..., Any] +PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]] + + +_TORCH_SYM_OPS: Set[Callable] = { + torch.sym_int, + torch.sym_float, + torch.sym_ite, + torch.sym_max, + torch.sym_min, + torch.sym_not, + torch.sym_sqrt, +} + + +class ExportPassBaseError(RuntimeError): + pass + + +class _ExportPassBaseDeprecatedDoNotUse(PassBase): + """ + Interpreter-based pass class to help users maintain the IR spec while writing + transformations. 
+ """ + + @staticmethod + def _create_dummy_node_metadata(): + return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))}) + + + class ExportTracer(PythonKeyTracer): + def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeGen) -> None: + super().__init__() + self.callback = callback + self.root = torch.nn.Module() + self.graph = torch.fx.Graph() + self.graph.set_codegen(codegen) + self.tensor_attrs: Dict[str, torch.Tensor] = {} # type: ignore[assignment] + self.fake_tensor_mode: Optional[FakeTensorMode] = None + self.submodules: Dict[torch.nn.Module, str] = {} + + def trace(self) -> None: # type: ignore[override] + raise ExportPassBaseError("ExportTracer doesn't support trace().") + + def create_arg(self, a: Argument) -> torch.fx.Node: + if isinstance(a, torch.nn.Module): + if a not in self.submodules: + name_submodule = f"submodule_{len(self.submodules)}" + self.root.add_module(name_submodule, a) + self.submodules[a] = name_submodule + elif isinstance(a, FakeTensor): + if not hasattr(a, "constant") or a.constant is None: + raise ExportPassBaseError(f"Cannot add {a} to graph.") + a = a.constant + node = super().create_arg(a) + if ( + isinstance(a, torch.Tensor) + and isinstance(node, torch.fx.Node) + and node.op == "get_attr" + ): + self.set_metadata(node, a) + self.callback.on_attr(ProxyValue(a, node)) + return node + + def set_metadata( + self, node: torch.fx.Node, value: Argument, + ) -> None: + # propagate the fake tensor or sym nodes + def make_val( + x: Argument, + ) -> Union[FakeTensor, torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str, None]: + if isinstance(x, FakeTensor): + return x + elif isinstance(x, torch.Tensor): + if x.is_quantized: + # TODO (tmanlaibaatar) properly support Quantized FakeTensor + x = torch.dequantize(x) + + try: + assert self.fake_tensor_mode is not None + # TODO we should allocate static shapes + # for param/buffer values + if isinstance(x, torch.nn.Parameter): + fake_tensor = self.fake_tensor_mode.from_tensor( + x, static_shapes=True + ) + else: + fake_tensor = self.fake_tensor_mode.from_tensor(x) + except UnsupportedFakeTensorException: + # TODO: This is just a workaround to get over the + # x.as_subclass error + print( + "Fakeifying a Tensor subclass is not supported \ + right now. Instead a TensorMetadata is used." 
+ ) + fake_tensor = None + return fake_tensor + elif isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str)): + return x + else: + return None + + node.meta["val"] = pytree.tree_map(make_val, value) + + # Set the tensor_metadata for values that do not have a corresponding FakeTensor + def make_tensor_meta(x: Argument) -> Optional[TensorMetadata]: + if not isinstance(x, FakeTensor) and isinstance(x, torch.Tensor): + if x.is_quantized: + # TODO (tmanlaibaatar) properly support Quantized FakeTensor + x = torch.dequantize(x) + + try: + assert self.fake_tensor_mode is not None + _ = self.fake_tensor_mode.from_tensor(x) + tensor_meta = None + except UnsupportedFakeTensorException: + # TODO: This is just a workaround to get over the + # x.as_subclass error + tensor_meta = _extract_tensor_metadata(x) + return tensor_meta + else: + return None + + node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value) + + class ExportInterpreter(fx.Interpreter): + def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphModule) -> None: + super().__init__(gm) + self.callback = callback + self.node: torch.fx.Node = next(iter(gm.graph.nodes)) + + def placeholder( + self, + target: str, # type: ignore[override] + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + ) -> ProxyValue: + arg = super().placeholder(target, args, kwargs) + return self.callback.placeholder(target, arg, NodeMetadata(self.node.meta)) + + def output( + self, + target: torch.fx.node.Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + ) -> ProxyValue: + return self.callback.output(args[0], NodeMetadata(self.node.meta)).data + + def call_function( + self, + target: torch.fx.node.Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + ) -> ProxyValue: + meta = NodeMetadata(self.node.meta) + + if target == operator.getitem: + value, key = args + return self.callback.call_getitem(value, key, meta) + elif getattr(target, "__module__", None) in {"_operator", "math"}: + assert callable(target) + return self.callback.call_sym(target, args, meta) + elif target in _TORCH_SYM_OPS: + assert callable(target) + return self.callback.call_sym(target, args, meta) + elif isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)): + return self.callback.call_operator( + target, + args, + kwargs, + meta, + ) + elif target == torch.ops.higher_order.cond: + pred, true_fn, false_fn, inputs = args + return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta) + elif target == torch.ops.higher_order.map_impl: + f, mapped_args, operands = args # type: ignore[assignment] + return self.callback.call_map(f, mapped_args, operands, meta) + # For other unregistered HigherOrderOps, just interpret them blindly + elif isinstance(target, torch._ops.HigherOrderOperator): + return self.callback._fx( + "call_function", + target, + args, + kwargs, + meta, + ) + else: + raise ExportPassBaseError(f"Unsupported target type: {target}") + + def get_attr( + self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override] + ) -> Argument: + return super().get_attr(target, args, kwargs) + + def call_module( + self, + target: torch.fx.node.Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + ) -> None: + raise ExportPassBaseError("call_module is not supported.") + + def call_method( + self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override] + ) -> None: + raise 
ExportPassBaseError("call_method is not supported.") + + def run_node(self, n: torch.fx.Node) -> Argument: + self.node = n + self.callback.node_debug_str = n.format_node() + return super().run_node(n) + + def __init__(self) -> None: + self.interpreter = PropagateUnbackedSymInts( + torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + ) + self.tracer = self.ExportTracer(self, CodeGen()) + self.fake_tensor_mode: Optional[FakeTensorMode] = None + self._initialized = True + self.node_debug_str: typing.Optional[str] = None + + def _fx( + self, + kind: str, + target: torch.fx.node.Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + args_data, kwargs_data = pytree.tree_map_only( + ProxyValue, lambda x: x.data, (args, kwargs) + ) + res_data = getattr(self.interpreter, kind)(target, args_data, kwargs_data) + args_proxy, kwargs_proxy = pytree.tree_map_only( + ProxyValue, lambda x: x.proxy, (args, kwargs) + ) + + name = None + if isinstance(target, torch._ops.OpOverload): + name = self.tracer.graph._target_to_str(target.overloadpacket.__name__) + + res_proxy = self.tracer.create_proxy(kind, target, args_proxy, kwargs_proxy, name=name) + res_proxy.node.meta.update(meta.data) + if self.fake_tensor_mode and (shape_env := self.fake_tensor_mode.shape_env): + if symbol_to_path := compute_unbacked_bindings(shape_env, res_data): + res_proxy.node.meta["unbacked_bindings"] = symbol_to_path + self.tracer.set_metadata(res_proxy.node, res_data) + return ProxyValue(res_data, res_proxy) + + def inputs(self, graph_module: torch.fx.GraphModule) -> List[Argument]: + # TODO(angelayi): Update this with what we decide to do for metadata in + # the exported graph module + if (args := graph_module.meta.get("args", None)) is not None: + return list(args) + + def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]: + if "val" in node.meta: + fake = node.meta["val"] + if hasattr(fake, "constant") and fake.constant is not None: + return fake.constant + return fake + elif tensor_meta := node.meta.get("tensor_meta"): + assert self.fake_tensor_mode is not None + return FakeTensor( + self.fake_tensor_mode, + torch.empty( + tensor_meta.shape, + dtype=tensor_meta.dtype, + device="meta", + requires_grad=tensor_meta.requires_grad, + memory_format=tensor_meta.memory_format, + ), + torch.device("cpu"), + ) + elif len(node.users) == 0: + return None + raise ExportPassBaseError( + f"Cannot construct an input for graph module: {graph_module}.", + ) + + return [ + extract_input(node) + for node in graph_module.graph.nodes + if node.op == "placeholder" + ] + + def on_attr(self, attr: ProxyValue) -> None: + pass + + def placeholder(self, name: str, arg: Argument, meta: NodeMetadata) -> ProxyValue: + arg_proxy = self.tracer.create_proxy("placeholder", name, (), {}) + arg_proxy.node.meta = meta.data + self.tracer.set_metadata(arg_proxy.node, arg) + return ProxyValue(arg, arg_proxy) + + def call_operator( + self, + op, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + return self._fx("call_function", op, args, kwargs, meta) + + def call_sym( + self, + target: Fn, + args: Tuple[Argument, ...], + meta: NodeMetadata, + ) -> ProxyValue: + return self._fx("call_function", target, args, {}, meta) + + def call_cond( + self, + pred: ProxyValue, + true_fn: torch.fx.GraphModule, + false_fn: torch.fx.GraphModule, + inputs: List[Argument], + meta: NodeMetadata, + ) -> ProxyValue: + true_branch = self.call_submodule(true_fn, 
tuple(inputs)) + false_branch = self.call_submodule(false_fn, tuple(inputs)) + assert true_branch is not None + assert false_branch is not None + return self._fx( + "call_function", + torch.ops.higher_order.cond, + (pred, true_branch.graph_module, false_branch.graph_module, list(inputs)), + {}, + meta, + ) + + def call_map( + self, + f: torch.fx.GraphModule, + mapped_args: List[ProxyValue], + operands: List[ProxyValue], + meta: NodeMetadata, + ) -> ProxyValue: + xs = _unstack_pytree([arg.data for arg in mapped_args])[0] + f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in operands])) + assert f_branch is not None + return self._fx( + "call_function", + torch.ops.higher_order.map_impl, + (f_branch.graph_module, mapped_args, operands), + {}, + meta, + ) + + def call_getitem( + self, value: ProxyValue, key: int, meta: NodeMetadata + ) -> ProxyValue: + return self._fx("call_function", operator.getitem, (value, key), {}, meta) + + def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue: + return self._fx("output", "output", (results,), {}, meta) + + def call_submodule( + self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...] + ) -> PassResult: + prev_tracer, self.tracer = self.tracer, self.ExportTracer( + self, graph_module.graph._codegen + ) + self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode + interpreter = self.ExportInterpreter(self, graph_module) + prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter( # type: ignore[assignment] + torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + ) + inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs) + with fx_traceback.preserve_node_meta(): + interpreter.run(*inputs_data) + + new_graph_module = torch.fx.GraphModule(self.tracer.root, self.tracer.graph) + + self.tracer = prev_tracer + self.interpreter = prev_interpreter + return PassResult( + new_graph_module, + True, + ) + + def call(self, graph_module: fx.GraphModule) -> PassResult: + if not getattr(self, "_initialized", False): + raise ExportPassBaseError( + "ExportPass is not initialized with __init__().", + ) + + inputs = self.inputs(graph_module) + + fake_tensor_mode = None + for i in inputs: + if isinstance(i, FakeTensor): + assert ( + fake_tensor_mode is None or fake_tensor_mode is i.fake_mode + ), "Multiple fake tensor mode detected." 
+                fake_tensor_mode = i.fake_mode
+        if fake_tensor_mode is None:
+            self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True)
+            fake_tensor_mode = nullcontext()  # type: ignore[assignment]
+            dispatcher_mode = nullcontext()  # type: ignore[assignment]
+        else:
+            fake_tensor_mode.allow_non_fake_inputs = True
+            self.tracer.fake_tensor_mode = fake_tensor_mode
+            dispatcher_mode = enable_python_dispatcher()  # type: ignore[assignment]
+        self.fake_tensor_mode = self.tracer.fake_tensor_mode
+
+        with fake_tensor_mode, dispatcher_mode:  # type: ignore[assignment, union-attr]
+            result = self.call_submodule(graph_module, tuple(inputs))
+
+        return result
diff --git a/lib/python3.10/site-packages/torch/_export/tools.py b/lib/python3.10/site-packages/torch/_export/tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4b96f909d1642f888546d1068d31d1a5f4ee9f1
--- /dev/null
+++ b/lib/python3.10/site-packages/torch/_export/tools.py
@@ -0,0 +1,146 @@
+# mypy: allow-untyped-defs
+import logging
+import warnings
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+import torch.export
+import torch.export._trace
+from torch._utils_internal import log_export_usage
+
+
+log = logging.getLogger(__name__)
+
+__all__ = ["report_exportability"]
+
+
+def _generate_inputs_for_submodules(
+    model: torch.nn.Module,
+    target_submodules: Iterable[str],
+    args: Tuple[Any, ...],
+    kwargs: Optional[Dict[str, Any]] = None,
+) -> Dict[str, Tuple[Any, Any]]:
+    """
+    Generate inputs for the targeted submodules in the given model. Note that if two
+    submodules refer to the same object, this function doesn't work.
+
+    Args:
+        model: root model.
+        args, kwargs: inputs to the root model.
+        target_submodules: submodules that we want to generate inputs for.
+
+    Returns:
+        A dict that maps from submodule name to its inputs.
+    """
+    kwargs = kwargs or {}
+
+    handles = []
+    results = {}
+    submodule_to_names = {mod: name for name, mod in model.named_modules()}
+
+    def pre_forward(module, module_args, module_kwargs):
+        results[submodule_to_names[module]] = (module_args, module_kwargs)
+
+    try:
+        for name, mod in model.named_modules():
+            if name in target_submodules:
+                handles.append(
+                    mod.register_forward_pre_hook(pre_forward, with_kwargs=True)
+                )
+        model(*args, **kwargs)
+    except Exception as e:
+        warnings.warn(
+            f"Failed to generate submodule inputs because of the following error:\n{e}"
+        )
+    finally:
+        for h in handles:
+            h.remove()
+    return results
+
+
+def report_exportability(
+    mod: torch.nn.Module,
+    args: Tuple[Any, ...],
+    kwargs: Optional[Dict[str, Any]] = None,
+    *,
+    strict: bool = True,
+    pre_dispatch: bool = False,
+) -> Dict[str, Optional[Exception]]:
+    """
+    Report exportability issues for a module in one shot.
+
+    Args:
+        mod: root module.
+        args: args to the root module.
+        kwargs: kwargs to the root module.
+    Returns:
+        A dict that maps from submodule name to the exception that was raised when trying to export it.
+        `None` means the module is exportable without issue.
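+
+    A minimal call sketch (MyModel and its example input are hypothetical):
+
+        report = report_exportability(MyModel(), (torch.randn(2, 3),), strict=False)
+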
+ Sample output: + { + '': UnsupportedOperatorException(func=), + 'submod_1': UnsupportedOperatorException(func=), + 'submod_2': None + } + """ + + log_export_usage(event="export.report_exportability") + + kwargs = kwargs or {} + + all_submod_names = [name for name, _ in mod.named_modules() if name != ""] + submod_inputs = _generate_inputs_for_submodules(mod, all_submod_names, args, kwargs) + + tried_module_types = set() + report: Dict[str, Optional[Exception]] = {} + + def try_export(module, module_name, args, kwargs): + nonlocal submod_inputs, report, strict, pre_dispatch, tried_module_types + + if type(module) in tried_module_types: + return + tried_module_types.add(type(module)) + + if args is not None or kwargs is not None: + try: + torch.export._trace._export( + module, + args, + kwargs, + strict=strict, + pre_dispatch=pre_dispatch, + ) + report[module_name] = None + log.info("Successfully exported `%s`", module_name) + return + except Exception as e: + short_msg = repr(e).split("\n")[0] + log.warning( + "Failed exporting `%s` with exception: %s", module_name, short_msg + ) + report[module_name] = e + + for name, submod in module.named_children(): + sub_module_name = name if module_name == "" else f"{module_name}.{name}" + + submod_args, submod_kwargs = submod_inputs.get( + sub_module_name, (None, None) + ) + + try_export(submod, sub_module_name, submod_args, submod_kwargs) + + return + + try_export(mod, "", args, kwargs) + + unique_issues = set() + for exception in report.values(): + if exception is not None: + key = repr(exception).split("\\n")[0] + unique_issues.add(key) + + log.warning("Found %d export issues:", len(unique_issues)) + for issue in unique_issues: + log.warning(issue) + + return report diff --git a/lib/python3.10/site-packages/torch/_export/utils.py b/lib/python3.10/site-packages/torch/_export/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e085e18b68a20753efa38d6f85e6eadf78f8895a --- /dev/null +++ b/lib/python3.10/site-packages/torch/_export/utils.py @@ -0,0 +1,893 @@ +# mypy: allow-untyped-defs +import ast +import dataclasses +import inspect +import math +import operator +import re +from inspect import Parameter +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING + +import torch +from torch._guards import detect_fake_mode +from torch._subclasses.fake_tensor import FakeTensor + + +if TYPE_CHECKING: + from torch._export.passes.lift_constants_pass import ConstantAttrMap + from torch.export import ExportedProgram + from torch.export.graph_signature import ExportGraphSignature + +from torch.export.graph_signature import InputKind, OutputKind +from torch.utils._pytree import ( + _register_pytree_node, + Context, + FlattenFunc, + FromDumpableContextFn, + GetAttrKey, + KeyPath, + keystr, + MappingKey, + SequenceKey, + ToDumpableContextFn, + tree_flatten_with_path, + UnflattenFunc, +) + + +placeholder_prefixes = { + InputKind.USER_INPUT: "", + InputKind.PARAMETER: "p_", + InputKind.BUFFER: "b_", + InputKind.CONSTANT_TENSOR: "c_", + InputKind.CUSTOM_OBJ: "obj_", + InputKind.TOKEN: "token", +} + + +def _collect_and_set_constant_attrs( + graph_signature, constants, mod +) -> "ConstantAttrMap": + # the exported module will store constants & non-persistent buffers such that + # retracing treats them as persistent buffers, so we inform the constants lifting pass + # and overwrite the new graph signature using the previous program. 
This is intended to only be used + # in run_decompositions where we still have access to original EP. + from torch._export.passes.lift_constants_pass import ConstantAttrMap + + constant_attrs = ConstantAttrMap() + non_persistent_buffers = { + spec.target + for spec in graph_signature.input_specs + if spec.kind == InputKind.BUFFER and not spec.persistent + } + for name, value in constants.items(): + if name in non_persistent_buffers: + continue + # recursive getattr + _mod = mod + *atoms, attr = name.split(".") + for atom in atoms: + _mod = getattr(_mod, atom) + # remove as buffer, reassign as constant/non-persistent buffer + _mod._buffers.pop(attr, None) + setattr(_mod, attr, value) + constant_attrs.add(value, name) + return constant_attrs + + +def _overwrite_signature_for_non_persistent_buffers( + old_sig: "ExportGraphSignature", new_sig: "ExportGraphSignature" +): + # overwrite signature for non-persistent buffers + non_persistent_buffers = { + spec.target + for spec in old_sig.input_specs + if spec.kind == InputKind.BUFFER and not spec.persistent + } + + for spec in new_sig.input_specs: + if spec.kind == InputKind.BUFFER and spec.target in non_persistent_buffers: + spec.persistent = False + return new_sig + + +def _collect_param_buffer_metadata(mod: torch.fx.GraphModule) -> Dict[str, Any]: + """ + Param/buffer metadata needs to be saved before lowering to aten IR + because aten IR lifts them, as a result, automatic preservation doesn't work. + This is intended to be called on the strict mode tracing right before lowering to + aten IR OR run_decomposition pass. + """ + params_buffers_to_node_meta = {} + + def _getattr(model: torch.fx.GraphModule, attr_name: str): + *prefix, field = attr_name.split(".") + t = model + for item in prefix: + t = getattr(t, item, None) # type: ignore[assignment] + assert t is not None + + return getattr(t, field) + + for node in mod.graph.nodes: + target = node.target + meta = node.meta + if node.op == "call_module": + submodule = _getattr(mod, target) + if isinstance(submodule, torch.nn.Module): + for name, _ in submodule.named_parameters( + recurse=True, remove_duplicate=False + ): + params_buffers_to_node_meta[target + "." + name] = meta + + for name, _ in submodule.named_buffers( + recurse=True, remove_duplicate=False + ): + params_buffers_to_node_meta[target + "." + name] = meta + + if node.op == "get_attr": + submodule = _getattr(mod, target) + if not isinstance(submodule, torch.fx.GraphModule): + params_buffers_to_node_meta[target] = meta + + # If the call_function uses param as input, we also need to update params' meta + # with this call_function node's meta. 
+ # This is basically the same flow as torch.fx.traceback.preserve_meta() + if node.op == "call_function" and not isinstance( + node.target, torch._ops.HigherOrderOperator + ): + for arg in node._input_nodes: + if arg.op == "get_attr": + for entry in torch.fx.proxy._COPY_META_FIELDS: + if entry in meta: + params_buffers_to_node_meta[arg.target][entry] = meta[entry] + + return params_buffers_to_node_meta + + +def _populate_param_buffer_metadata_to_new_gm( + params_buffers_to_node_meta: Dict[str, Any], + gm: torch.fx.GraphModule, + new_sig: "ExportGraphSignature", +) -> None: + """ + Given that we collected param'buffer metadata before, we put them back in + newly traced graph module + """ + # Don't copy over nn_module_stack, stack_trace metadata for params/buffers nodes + for metadata in params_buffers_to_node_meta.values(): + metadata.pop("nn_module_stack", None) + metadata.pop("stack_trace", None) + + for node in gm.graph.nodes: + if node.op == "placeholder": + if node.target in new_sig.inputs_to_parameters: + param_name = new_sig.inputs_to_parameters[node.target] + if param_name in params_buffers_to_node_meta: + for k, v in params_buffers_to_node_meta[param_name].items(): + node.meta[k] = v + if node.target in new_sig.inputs_to_buffers: + buffer_name = new_sig.inputs_to_buffers[node.target] + if buffer_name in params_buffers_to_node_meta: + for k, v in params_buffers_to_node_meta[buffer_name].items(): + node.meta[k] = v + + +def _get_shape_env_from_gm(gm: torch.fx.GraphModule): + vals = [ + node.meta["val"] + for node in gm.graph.nodes + if node.meta.get("val", None) is not None + ] + + fake_mode = _detect_fake_mode_from_gm(gm) + if fake_mode is not None: + return fake_mode.shape_env + for v in vals: + if isinstance(v, torch.SymInt): + return v.node.shape_env + + +def _rename_without_collisions( + name_map: Dict[str, str], + orig_name: str, + name: str, + is_placeholder: bool = False, +): + """ + Renames nodes to avoid name collisions, with suffixing. + name_map: map from original name to new name + orig_name: mapping key + name: candidate name (potentially suffixed, e.g. mul_2) + is_placeholder: if the node is a placeholder, avoid detecting suffix + """ + if name in name_map.values(): + # non-placeholder nodes may be suffixed with the count + # instead of adding another suffix, we will try to increment it + match = re.match(r"(.*)_(\d+)", name) + if match and not is_placeholder: + name, n = match.group(1), int(match.group(2)) + else: + n = 0 + while (dup_name := f"{name}_{n + 1}") in name_map.values(): + n += 1 + name_map[orig_name] = dup_name + else: + name_map[orig_name] = name + return name_map[orig_name] + + +def _check_input_constraints_for_graph( + input_placeholders: List[torch.fx.Node], flat_args_with_path, range_constraints +): + def get_keystr(key_path: KeyPath) -> str: + """For a given index into the flat_args, return a human readable string + describing how to access it, e.g. "*args["foo"][0].bar" + """ + # Prefix the keypath with "*args" or "**kwargs" to make it clearer where + # the arguments come from. Ultimately we ought to serialize the + # original arg names for the best error message here. 
+ args_kwargs_key_path = key_path[0] + assert isinstance(args_kwargs_key_path, SequenceKey) + if args_kwargs_key_path.idx == 0: + return f"*args{keystr(key_path[1:])}" + else: + kwarg_key = key_path[1] + assert isinstance(kwarg_key, MappingKey) + name = str(kwarg_key)[1:-1] # get rid of the enclosed [] + return f"{name}{keystr(key_path[2:])}" + + import sympy + + from torch._export.passes.add_runtime_assertions_for_constraints_pass import ( + _convert_range_to_int, + ) + from torch.utils._sympy.solve import try_solve + + if len(flat_args_with_path) != len(input_placeholders): + raise RuntimeError( + "Unexpected number of inputs " + f"(expected {len(input_placeholders)}, got {len(flat_args_with_path)})" + ) + # NOTE: export already guarantees that the same symbol is used in metadata + # for all InputDims related by equality constraints, so we can just unify + # symbols with given input dimension values to check equality constraints. + unification_map: Dict[sympy.Symbol, Any] = {} + for (key_path, arg), node in zip(flat_args_with_path, input_placeholders): + node_val = node.meta.get("val") + if isinstance(node_val, FakeTensor): + if not isinstance(arg, torch.Tensor): + raise RuntimeError( + f"Expected input at {get_keystr(key_path)} to be a tensor, but got {type(arg)}", + ) + + if len(node_val.shape) != len(arg.shape): + raise RuntimeError( + f"Unexpected number of dimensions in input at {get_keystr(key_path)}.shape " + f"(expected {node_val.shape}, got {arg.shape})" + ) + + for j, (arg_dim, node_dim) in enumerate(zip(arg.shape, node_val.shape)): + # TODO(avik): Assert the following property in the IR verifier: + # node_dim is either an int or a SymInt containing an int or a unary sympy.Expr + if ( + isinstance(node_dim, torch.SymInt) + and len(node_dim.node.expr.free_symbols) == 1 + ): + symbol = next(iter(node_dim.node.expr.free_symbols)) + if symbol in unification_map: + existing_dim = node_dim.node.expr.subs(unification_map) + if arg_dim != existing_dim: + raise RuntimeError( + f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to " + f"{existing_dim}, but got {arg_dim}", + ) + else: + if ( + isinstance(arg_dim, torch.SymInt) + and not arg_dim.node.expr.is_number + ): + # This can happen when, say, arg is a fake tensor. + # We do not run checks on symbolic shapes of fake inputs as + # such checks can affect the shape env. + pass + else: + if isinstance(node_dim.node.expr, sympy.Symbol): + # Short cut for try_solve below. Also useful in cases where + # sympy.Eq(node_dim.node.expr, arg_dim) would evaluate to False + # purely because symbol is constrained to be size-like, + # e.g., when node_dim.node.expr = symbol and arg_dim = 0. 
+ unification_map[symbol] = int(arg_dim) + else: + solution = try_solve( + sympy.Eq(node_dim.node.expr, arg_dim), symbol + ) + if solution is None: + raise RuntimeError( # noqa: B904 + f"Expected input {node.name}.shape[{j}] = {arg_dim} to be " + f"of the form {node_dim.node.expr}, where {symbol} is an integer" + ) + else: + unification_map[symbol] = int(solution[1]) + + if node_dim.node.expr in range_constraints: + min_val, max_val = _convert_range_to_int( + range_constraints[node_dim.node.expr] + ) + # NOTE: we allow dimensions to be 0/1 at runtime + if min_val > 2: + if arg_dim < min_val: + raise RuntimeError( + f"Expected input at {get_keystr(key_path)}.shape[{j}] to be >= " + f"{min_val}, but got {arg_dim}", + ) + if max_val < math.inf: + if arg_dim > max_val: + raise RuntimeError( + f"Expected input at {get_keystr(key_path)}.shape[{j}] to be <= " + f"{max_val}, but got {arg_dim}", + ) + else: + if arg_dim != node_dim: + if ( + isinstance(node_dim, torch.SymInt) + and not node_dim.node.expr.is_number + ): + # this means we deferred a guard from export analysis to runtime, let this pass + # we'll add a runtime assert checking equality to this replacement expression + continue + raise RuntimeError( + f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to " + f"{node_dim}, but got {arg_dim}", + ) + elif isinstance(node_val, (int, float, str)): + if type(arg) != type(node_val) or arg != node_val: + raise RuntimeError( + f"Expected input at {get_keystr(key_path)} to be equal to {node_val}, but got {arg}", + ) + + +def register_dataclass_as_pytree_node( + cls: Type[Any], + flatten_fn: Optional[FlattenFunc] = None, + unflatten_fn: Optional[UnflattenFunc] = None, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, + return_none_fields: bool = False, +) -> None: + assert dataclasses.is_dataclass( + cls + ), f"Only dataclasses can be registered with this function: {cls}" + + def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]: + flattened = [] + flat_names = [] + none_names = [] + for f in dataclasses.fields(obj): + name, val = f.name, getattr(obj, f.name) + if val is not None or return_none_fields: + flattened.append(val) + flat_names.append(name) + else: + none_names.append(name) + return flattened, [flat_names, none_names] + + def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any: + flat_names, none_names = context + return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names)) + + def default_flatten_fn_with_keys(obj: Any) -> Tuple[List[Any], Context]: + flattened, (flat_names, none_names) = flatten_fn(obj) # type: ignore[misc] + return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], flat_names + + flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn + unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn + + if (to_dumpable_context is None) ^ (from_dumpable_context is None): + raise ValueError( + f"Both to_dumpable_context and from_dumpable_context for {cls} must " + "be None or registered." 
+ ) + + _register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + flatten_with_keys_fn=default_flatten_fn_with_keys, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + + +def is_param(program: "ExportedProgram", node: torch.fx.Node) -> bool: + """ + Checks if the given node is a parameter within the exported program + """ + + return node.name in program.graph_signature.inputs_to_parameters + + +def get_param( + program: "ExportedProgram", + node: torch.fx.Node, +) -> Optional[torch.nn.Parameter]: + """ + Returns the parameter associated with the given node in the exported program. + Returns None if the node is not a parameter within the exported program + """ + + if is_param(program, node): + parameter_name = program.graph_signature.inputs_to_parameters[node.name] + return program.state_dict[parameter_name] + + return None + + +def is_buffer(program: "ExportedProgram", node: torch.fx.Node) -> bool: + """ + Checks if the given node is a buffer within the exported program + """ + + return node.name in program.graph_signature.inputs_to_buffers + + +def get_buffer( + program: "ExportedProgram", + node: torch.fx.Node, +) -> Optional[torch.Tensor]: + """ + Returns the buffer associated with the given node in the exported program. + Returns None if the node is not a buffer within the exported program + """ + + if is_buffer(program, node): + buffer_name = program.graph_signature.inputs_to_buffers[node.name] + if buffer_name in program.graph_signature.non_persistent_buffers: + return program.constants[buffer_name] + else: + return program.state_dict[buffer_name] + + return None + + +def is_lifted_tensor_constant( + program: "ExportedProgram", + node: torch.fx.Node, +) -> bool: + """ + Checks if the given node is a lifted tensor constant within the exported program + """ + + return node.name in program.graph_signature.inputs_to_lifted_tensor_constants + + +def get_lifted_tensor_constant( + program: "ExportedProgram", + node: torch.fx.Node, +) -> Optional[torch.Tensor]: + """ + Returns the lifted tensor constant associated with the given node in the exported program. + Returns None if the node is not a lifted tensor constant within the exported program + """ + + if is_lifted_tensor_constant(program, node): + lifted_tensor_name = program.graph_signature.inputs_to_lifted_tensor_constants[ + node.name + ] + return program.constants[lifted_tensor_name] + + return None + + +def sequential_split(gm: torch.fx.GraphModule, node_call_back) -> torch.fx.GraphModule: + """ + sequential_split creates a new graph module that splits the input graph module into multiple submodules + based on the node_call_back. It doesn't mutate the input graph module. The node_call_back should return + True if the node is a delimiter. Delimiter will be the first node in the next submodule. + """ + from torch.fx.passes.split_module import split_module + + split_map = {} + split_id = 0 + for node in gm.graph.nodes: + if node_call_back(node): + split_id += 1 + split_map[node] = split_id + + new_gm = split_module( + gm, + gm, + lambda node: split_map[node], + keep_original_order=True, + keep_original_node_name=True, + ) + # Keep the codegen from original graph module to preserve e.g. pytree info. 
+ new_gm.graph._codegen = gm.graph._codegen + new_gm.recompile() + return new_gm + + +def nodes_filter(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]: + """Returns the nodes that match the node_call_back as a list.""" + return [node for node in nodes if node_call_back(node)] + + +def nodes_first( + nodes: List[torch.fx.Node], node_call_back=None +) -> Optional[torch.fx.Node]: + """ + Returns the first node that matches the node_call_back. If no node matches, returns None. + When node_call_back is None, returns the first node in the node list. + """ + ret = nodes_filter(nodes, node_call_back if node_call_back else lambda node: True) + if len(ret) > 0: + return ret[0] + return None + + +def nodes_count(nodes: List[torch.fx.Node], node_call_back) -> int: + """Returns the number of nodes that match the node_call_back.""" + return len(nodes_filter(nodes, node_call_back)) + + +def nodes_map(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]: + """ + Sequentially visit the nodes list and invoke node_call_back on each element. + Returns the nodes list after the node_call_back is invoked on each element. + """ + for node in nodes: + node_call_back(node) + return nodes + + +def node_replace_(old_node: torch.fx.Node, new_node: torch.fx.Node) -> None: + """ + Replace all uses of old_node with new_node. + """ + old_node.replace_all_uses_with(new_node) + old_node.users.clear() + old_node.graph.erase_node(old_node) + + +def node_inline_(call_mod_node: torch.fx.Node) -> None: + """ + Inline the submodule of the given node into the parent module. + Note: we only support the case where submodule takes tensors inputs. + """ + assert call_mod_node.op == "call_module" + gm = call_mod_node.graph.owning_module + + assert isinstance(call_mod_node.target, str) + sub_gm = getattr(gm, call_mod_node.target) + + phs = (node for node in sub_gm.graph.nodes if node.op == "placeholder") + body = ( + node for node in sub_gm.graph.nodes if node.op not in ("placeholder", "output") + ) + output = [node for node in sub_gm.graph.nodes if node.op == "output"] + + for ph, arg in zip(phs, call_mod_node.args): + assert isinstance(arg, torch.fx.Node) + node_replace_(ph, arg) + + with gm.graph.inserting_before(call_mod_node): + for node in body: + new_node = gm.graph.node_copy(node) + node_replace_(node, new_node) + + if len(output) > 0: + assert len(output) == 1 and len(output[0].args) == 1 + new_output = output[0].args[0] + + if isinstance(new_output, torch.fx.Node): + # Clear the users of the output node and set + # the users to be the users of original call_module node. + new_output.users.clear() + node_replace_(call_mod_node, new_output) + elif isinstance(new_output, (list, tuple)): + # Pop subgraph output node from users. + for node in new_output: + node.users.pop(output[0]) + + # Inline the get_item calls for the output node. + get_item_users = nodes_filter( + list(call_mod_node.users.keys()), + lambda node: node.op == "call_function" + and node.target == operator.getitem, + ) + # get_item_node.args[1] is the idx referring to new_output[idx] + nodes_map( + get_item_users, + lambda get_item_node: node_replace_( + get_item_node, + new_output[get_item_node.args[1]], + ), + ) + call_mod_node.graph.erase_node(call_mod_node) + else: + raise NotImplementedError( + f"Unsupported output type {type(new_output)}. Expect it to be a Node or a list/tuple of Nodes." 
+ ) + else: + call_mod_node.graph.erase_node(call_mod_node) + + gm.delete_all_unused_submodules() + gm.recompile() + return gm + + +def _get_torch_jit_trace_forward_signature(mod: torch.nn.Module): + """ + Get source code and parse argument names using AST. The function returns + a signature of the forward() function. + + # TODO: Directly provide inspect.signature compatible TS-d module. + """ + ast_mod = ast.parse(mod.code) + ast_func_def: ast.FunctionDef = ast_mod.body[0] # type: ignore[assignment] + + # FIXME(jiashenc): TorchScript should only allow positional or keywords arguments. + arg_type_map = {"args": Parameter.POSITIONAL_OR_KEYWORD} + + # Traverse all argument types in AST tree and create associated parameters. + param_list = [] + for arg_type, param_type in arg_type_map.items(): + arg_name_list = [a.arg for a in getattr(ast_func_def.args, arg_type)] + for arg_name in arg_name_list: + if arg_name == "self": + continue # Skip self argument. + param_list.append(inspect.Parameter(arg_name, param_type)) + + return inspect.Signature(parameters=param_list) + + +def _bind_signature_to_inputs(mod, fake_args, fake_kwargs): + if isinstance(mod, (torch.jit.ScriptModule, torch.jit.TracedModule)): + sig = _get_torch_jit_trace_forward_signature(mod) + + # Sanity check for placeholder names coming from TorchScript. + assert len(sig.parameters) == len(fake_args) + len(fake_kwargs), ( + "Arguments other than POSITIONAL_OR_KEYWORD kinds in forward() " + "are not supported in _get_torch_jit_trace_forward_signature" + ) + else: + sig = inspect.signature(mod.forward) + + return sig.bind(*fake_args, **fake_kwargs).arguments + + +def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None: + """ + Propagate placeholder names from the top-level graph into HigherOrderOp subgraphs, + and handle collisions with non-placeholders by count suffixing. + Different HOO subgraph types have different input schemas, so we first enumerate them + and gather the top-level named placeholder nodes. 
+ """ + # gather all HOO subgraphs and their top-level named placeholder nodes + subgraph_ph_tuples: List[Tuple[torch.fx.GraphModule, List[torch.fx.Node]]] = [] + for node in gm.graph.nodes: + if node.op == "call_function" and isinstance( + node.target, torch._ops.HigherOrderOperator + ): + # HOO subgraphs have varying input schemas, so we enumerate them there + if node.target._name == "cond": + _, true_graph, false_graph, cond_args = node._args + subgraph_ph_tuples.append((getattr(gm, true_graph.target), cond_args)) + subgraph_ph_tuples.append((getattr(gm, false_graph.target), cond_args)) + elif node.target._name == "wrap_with_set_grad_enabled": + subgraph, phs = node._args[1], node._args[2:] + subgraph_ph_tuples.append((getattr(gm, subgraph.target), phs)) + elif node.target._name == "map_impl": + body_graph, array, args = node._args + subgraph_ph_tuples.append( + (getattr(gm, body_graph.target), array + args) + ) + + # propagate names + for subgraph, hoo_phs in subgraph_ph_tuples: + name_map: Dict[str, str] = {} + for i, node in enumerate(subgraph.graph.nodes): + if i < len(hoo_phs): # placeholder, retain name + name_map[node.name] = hoo_phs[i].name + node.name = node.target = hoo_phs[i].name + else: # non-placeholder, check for collisions + node.name = _rename_without_collisions(name_map, node.name, node.name) + + # recurse and recompile + _name_hoo_subgraph_placeholders(subgraph) + subgraph.recompile() + + +def placeholder_naming_pass( + gm: torch.fx.GraphModule, + export_graph_signature: "ExportGraphSignature", + mod: torch.nn.Module, + fake_args, + fake_kwargs, + fake_params_buffers, + constants: Dict[str, Any], +) -> None: + """ + This pass is run at the end of _export_non_strict() to assign better placeholder node names: + - User inputs: + These follow the signature of mod.forward(), e.g. forward(x, y) produces nodes x, y. + For nested inputs from dictionaries, lists, tuples, or dataclasses, + the names are a concatenation of the path to the tensor. + e.g. x = { + 'a': torch.randn(), + 'b': [torch.randn(), torch.randn()] + } + produces nodes x_a, x_b_0, x_b_1. + - Parameters/buffers/constants/custom objects: + These follow the FQN of the object, prefixed by "p", "b", "c", "obj" respectively. + e.g. self.bar.l0.weight produces "p_bar_l0_weight". + - Effect tokens: + These are named token, token_1, ... 
+ """ + + def _strip_name(x): + if x.startswith("L__self___"): + x = x[len("L__self___") :] + elif x.startswith("self_"): + x = x[len("self_") :] + x = re.sub(r"[^a-zA-Z0-9]", "_", x) + return x + + def _extract_pytree_key(x): + if isinstance(x, MappingKey): + x = re.sub(r"[^a-zA-Z0-9]", "_", str(x.key)) + return x + elif isinstance(x, SequenceKey): + return str(x.idx) + elif isinstance(x, GetAttrKey): + return x.name + else: + raise RuntimeError(f"Pytree key of type {type(x)} not handled for {x}") + + name_map: Dict[str, str] = {} + + # map user input names with mod.forward() signature + combined_args = _bind_signature_to_inputs(mod, fake_args, fake_kwargs) + + flat_args_with_path, _ = tree_flatten_with_path(combined_args) + user_input_names = [ + spec.arg.name + for spec in export_graph_signature.input_specs + if spec.kind == InputKind.USER_INPUT + ] + + # use pytree path to name nested user inputs + for (arg_path, arg), user_input_name in zip(flat_args_with_path, user_input_names): + if user_input_name: + _rename_without_collisions( + name_map, + user_input_name, + placeholder_prefixes[InputKind.USER_INPUT] + + "_".join(_extract_pytree_key(x).lower() for x in arg_path), + is_placeholder=True, + ) + + # use graph signature input specs to map param/buffer/constant names + # name effect tokens as token, token_1, ... (these aren't visible to user) + for spec in export_graph_signature.input_specs: + if spec.kind == InputKind.USER_INPUT: + continue + if spec.kind == InputKind.TOKEN: + base_name = "" + else: + base_name = _strip_name(spec.target).lower() + base_name = re.sub(r"[^a-zA-Z0-9]", "_", base_name) + + _rename_without_collisions( + name_map, + spec.arg.name, + placeholder_prefixes[spec.kind] + base_name, + is_placeholder=True, + ) + + # handle naming collisions with call_function/get_attr inputs. + # here, we want to prioritize user input names over call_function names + # e.g. 
not have forward(self, mul): lead to a placeholder node called mul_13, + # so we increment the suffix of call_function nodes as needed + for node in gm.graph.nodes: + if node.op == "placeholder": + continue + _rename_without_collisions(name_map, node.name, node.name) + + # assign new node names + for node in gm.graph.nodes: + if node.op == "placeholder": + assert node.name in name_map + node.name = node.target = name_map[node.name] + elif node.name in name_map: + node.name = name_map[node.name] + + # propagate names to higher order op subgraphs + _name_hoo_subgraph_placeholders(gm) + + # re-generate graph module code + gm.recompile() + + # modify graph signature (input specs, output specs, user input mutations) + for spec in export_graph_signature.input_specs: + assert spec.arg.name in name_map + spec.arg.name = name_map[spec.arg.name] + if ( # handle targets for custom objects + spec.kind == InputKind.CUSTOM_OBJ and spec.target in name_map + ): + spec.target = name_map[spec.target][4:] # strip obj_ prefix + + for spec in export_graph_signature.output_specs: + if spec.arg.name in name_map: + spec.arg.name = name_map[spec.arg.name] + if spec.kind == OutputKind.USER_INPUT_MUTATION and spec.target in name_map: + spec.target = name_map[spec.target] + + # rename keys in constants dict for custom objects + for name in list(constants.keys()): + constant = constants[name] + if name in name_map and not isinstance( + constant, torch.Tensor + ): # rename custom objects with generic names + new_name = name_map[name] + if ( + new_name != name + and re.match(r"arg(\d+)_1", name) + and new_name != placeholder_prefixes[InputKind.CUSTOM_OBJ] + name + ): + constants[new_name] = constant + del constants[name] + + +def remove_proxy_from_state_dict(state_dict: Dict, in_place: bool) -> Dict: + """ + If `in_place` is false, return a new copy of `state_dict` with "proxy" removed from `v.__dict__`. + `v` is the values in the dictionary. + If `in_place` is true, modify `state_dict` in place. + """ + if in_place: + for k, v in state_dict.items(): + if hasattr(v, "proxy"): + delattr(state_dict[k], "proxy") + return state_dict + else: + new_state_dict = {} + for k, v in state_dict.items(): + if hasattr(v, "proxy"): + new_state_dict[k] = v.clone().detach() + else: + new_state_dict[k] = v + return new_state_dict + + +def _detect_fake_mode_from_gm( + gm: torch.fx.GraphModule, +) -> torch._subclasses.fake_tensor.FakeTensorMode: + """ + For a given graph module, we look at the "val" of placeholder nodes to find the fake inputs. + Additionally, if gm doesn't have placeholders, we further look at the "example_value" or "val" of other nodes. + If no fake mode is found, we return None for fake_mode. 
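+
+    A minimal usage sketch (assuming gm is an already-traced torch.fx.GraphModule):
+
+        fake_mode = _detect_fake_mode_from_gm(gm)
+        shape_env = fake_mode.shape_env if fake_mode is not None else None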
+ """ + + fake_inps: List[torch.Tensor] = [] + fake_vals: List[torch.Tensor] = [] + for node in gm.graph.nodes: + if node.op == "placeholder" and "val" in node.meta: + fake_val = node.meta["val"] + if fake_val is not None and isinstance(fake_val, torch.Tensor): + fake_inps.append(fake_val) + elif len(fake_inps) == 0 and ( + "example_value" in node.meta or "val" in node.meta + ): + fake_val = None + if "example_value" in node.meta: + fake_val = node.meta["example_value"] + elif "val" in node.meta: + fake_val = node.meta["val"] + if fake_val is not None and isinstance(fake_val, torch.Tensor): + fake_vals.append(fake_val) + + return detect_fake_mode(fake_inps + fake_vals) diff --git a/lib/python3.10/site-packages/torch/_export/verifier.py b/lib/python3.10/site-packages/torch/_export/verifier.py new file mode 100644 index 0000000000000000000000000000000000000000..68c5bcaae39af69f5527e8dcf0e08ed49bad4563 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_export/verifier.py @@ -0,0 +1,456 @@ +# mypy: allow-untyped-defs +import inspect +import math +import operator +from collections.abc import Iterable +from typing import Any, Dict, final, List, Tuple, Type, TYPE_CHECKING + +import torch +from torch._ops import HigherOrderOperator, OpOverload +from torch._subclasses.fake_tensor import FakeTensor +from torch.export.graph_signature import ( + CustomObjArgument, + InputKind, + SymIntArgument, + TensorArgument, + TokenArgument, +) +from torch.fx import GraphModule + +if TYPE_CHECKING: + from torch.export.exported_program import ExportedProgram + +class SpecViolationError(Exception): + pass + + +def is_functional(op: OpOverload) -> bool: + return not op._schema.is_mutable + + +def _check_has_fake_tensor(node: torch.fx.Node) -> None: + # TODO(angelayi): remove this in favor of _check_val + return _check_val(node) + + +def _check_val(node: torch.fx.Node) -> None: + from torch.fx.experimental.symbolic_shapes import SymBool, SymFloat, SymInt + + def _check_correct_val(val): + if val is None: + return True + elif isinstance(val, (int, bool, str, float)): + return True + elif isinstance(val, (torch.memory_format, torch.dtype, torch.device, torch.layout)): + return True + elif isinstance(val, (FakeTensor, torch.Tensor)): # TODO(zhxchen17) Remove Tensor. 
+ return True + elif isinstance(val, (SymInt, SymFloat, SymBool)): + return True + elif isinstance(val, CustomObjArgument): + return True + elif isinstance(val, Iterable): + return all(_check_correct_val(x) for x in val) + return False + + def _no_returns(op): + if not isinstance(op, OpOverload): + return False + return len(op._schema.returns) == 0 + + if "val" not in node.meta: + if node.op == "call_function" and _no_returns(node.target): + return + raise SpecViolationError(f"Node.meta {node.name} is missing val field.") + + val = node.meta["val"] + if not _check_correct_val(val): + raise SpecViolationError(f"Node.meta {node.name} has invalid val field {val}") + + +def _check_torch_fn(node: torch.fx.Node) -> None: + torch_fn = node.meta.get("torch_fn") + if torch_fn is None: + raise SpecViolationError(f"Unable to find torch_fn metadata for node {node.name}") + if ( + not isinstance(torch_fn, tuple) and + isinstance(torch_fn[0], str) and + isinstance(torch_fn[1], str) + ): + raise SpecViolationError(f"Node.meta {node.name} has invalid torch_fn field {torch_fn}") + +class _VerifierMeta(type): + _registry: Dict[str, Type['Verifier']] = {} + + def __new__(metacls, name, bases, attrs): + if bases: + if "check" in attrs or "_check_graph_module" in attrs: + raise SyntaxError("Overriding method check is not allowed.") + assert "dialect" in attrs and attrs["dialect"] != "ATEN" + else: + assert "check" in attrs + assert "_check_graph_module" in attrs + assert attrs["dialect"] == "ATEN" + + assert isinstance(attrs["dialect"], str) + ret = type.__new__(metacls, name, bases, attrs) + metacls._registry[attrs["dialect"]] = ret # type: ignore[assignment] + return ret + +def getattr_recursive(obj: Any, target: str) -> Any: + target_atoms = target.split('.') + attr_itr = obj + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}") + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +class Verifier(metaclass=_VerifierMeta): + dialect = "ATEN" + + def allowed_builtin_ops(self) -> List: + return [ + operator.getitem, + operator.add, + operator.mul, + operator.sub, + operator.truediv, + operator.ge, + operator.le, + operator.gt, + operator.lt, + operator.eq, + operator.ne, + operator.floordiv, + operator.mod, + operator.and_, + operator.or_, + operator.not_, + operator.pow, + operator.neg, + operator.abs, + math.ceil, + math.floor, + math.trunc, + ] + + def allowed_op_types(self) -> Tuple[Type[Any], ...]: + return (OpOverload, HigherOrderOperator) + + def allowed_getattr_types(self) -> Tuple[Type[Any], ...]: + return (torch.fx.GraphModule,) + + def check_valid_op(self, op): + pass + + def check_additional(self, gm: GraphModule) -> None: + """ + Additional checks that are specific to some dialects. 
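+
+        A hypothetical dialect-specific verifier (sketch only; "MYDIALECT" is a
+        made-up dialect name) would override this hook rather than ``check`` or
+        ``_check_graph_module``, which the metaclass forbids overriding::
+
+            class MyDialectVerifier(Verifier):
+                dialect = "MYDIALECT"
+
+                def check_additional(self, gm: GraphModule) -> None:
+                    for node in gm.graph.nodes:
+                        if node.op == "call_function":
+                            _check_val(node)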
+ """ + + @final + def check(self, ep: "ExportedProgram") -> None: + self._check_graph_module(ep.graph_module) + _verify_exported_program_module_call_graph(ep) + _verify_exported_program_signature(ep) + + @final + def _check_graph_module(self, gm: torch.fx.GraphModule) -> None: + def _allowed_getattr_types() -> Tuple[Type[Any], ...]: + ret = self.allowed_getattr_types() + assert not any(t is object for t in ret) + return ret + + def _check_valid_op(op) -> None: + def _allowed_builtin_ops() -> List: + ret = self.allowed_builtin_ops() + assert all(inspect.isbuiltin(op) for op in ret) + return ret + + def _allowed_op_types() -> Tuple[Type[Any], ...]: + ret = self.allowed_op_types() + assert not any(t is object for t in ret) + return ret + + # TODO Remove this allowlist. + _allowed_torch_functions = ( + torch.autograd.grad_mode.set_grad_enabled, + torch.sym_int, + torch.sym_float, + torch.sym_ite, + torch.sym_max, + torch.sym_min, + torch.sym_not, + torch.sym_sqrt, + # TODO (tmanlaibaatar) + # Predispatch export is able to contain autograd ops. + # These will be modeled as HOO later + torch._C._set_grad_enabled, + ) + + if not isinstance(op, _allowed_op_types()): + if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions: + raise SpecViolationError( + f"Operator '{op}' is not an allowed operator type: {_allowed_op_types()}\n" + f"Valid builtin ops: {_allowed_builtin_ops()}" + f"Valid torch functions: {_allowed_torch_functions}" + ) + + if isinstance(op, OpOverload): + # All ops functional + # TODO (tmanlaibaatar) more proper way is needed here + if self.dialect != "TRAINING" and not is_functional(op): + raise SpecViolationError( + f"operator '{op}' is not functional" + ) + self.check_valid_op(op) + + for mod in gm.modules(): + if not isinstance(mod, torch.fx.GraphModule): + continue + + mod.graph.lint() + for node in mod.graph.nodes: + # TODO(T140410192): should have fake tensor for all dialects + if node.op in {"call_module", "call_method"}: + raise SpecViolationError( + f"call_module is not valid: got a class '{node.target}' ", + ) + + elif node.op == "call_function": + _check_val(node) + + _check_valid_op(node.target) + + elif node.op == "get_attr": + if not isinstance(node.target, str): + raise SpecViolationError( + f"Expected get_attr target to be string, but got {type(node.target)}" + ) + + attr = getattr_recursive(mod, node.target) + if isinstance(attr, torch.nn.Module): + def _is_type(name, ty): + return isinstance(getattr(attr, name, None), ty) + if type(attr).__name__ == "LoweredBackendModule": + if _is_type("backend_id", str) \ + and _is_type("processed_bytes", bytes) \ + and _is_type("compile_specs", list) \ + and hasattr(attr, "original_module"): + continue + else: + backend_id = getattr(attr, "backend_id", None) + processed_bytes = getattr(attr, "processed_bytes", None) + compile_specs = getattr(attr, "compile_specs", None) + raise SpecViolationError( + f"Invalid get_attr type {type(attr)}. \n" + f"LoweredBackendModule fields: " + f"backend_id(str) : {type(backend_id)}, " + f"processed_bytes(bytes) : {type(processed_bytes)}, " + f"compile_specs(list) : {type(compile_specs)}" + ) + + if not isinstance(attr, _allowed_getattr_types()): + raise SpecViolationError( + f"Invalid get_attr type {type(attr)}. 
\n" + f"Valid get_attr types: {_allowed_getattr_types()}" + ) + + + elif node.op == "placeholder": + _check_val(node) + # TODO(zhxchen17) + # elif node.op == "output": + # _check_flattened_outputs() + + self.check_additional(gm) + + +class TrainingIRVerifier(Verifier): + dialect = "TRAINING" + + +def _verify_exported_program_module_call_graph(exported_program) -> None: + module_call_graph = exported_program.module_call_graph + nodes = { + node.name for node in exported_program.graph.nodes + } + for entry in module_call_graph: + if entry.signature is not None: + for arg in entry.signature.inputs: + if arg.name and arg.name not in nodes: + raise SpecViolationError( + f"Input {arg.name} does not exist in the graph." + ) + for arg in entry.signature.outputs: + if arg.name and arg.name not in nodes: + raise SpecViolationError( + f"Output {arg.name} does not exist in the graph." + ) + + +def _verify_exported_program_signature(exported_program) -> None: + # Check ExportedProgram signature matches + gs = exported_program.graph_signature + + # Check every node in the signature exists in the graph + input_node_names = [node.name for node in exported_program.graph.nodes if node.op == "placeholder"] + + if len(input_node_names) != len(gs.input_specs): + raise SpecViolationError( + f"Number of graph inputs ({len(input_node_names)}) " + f"does not match number of inputs in the graph signature ({len(gs.input_specs)})" + ) + + for input_spec, node in zip(gs.input_specs, input_node_names): + if isinstance(input_spec.arg, (TensorArgument, SymIntArgument)): + if input_spec.arg.name != node: + raise SpecViolationError( + f"Input spec name {input_spec.arg.name} does not match node name {node}" + ) + + if input_spec.kind == InputKind.USER_INPUT: + continue + + elif input_spec.kind == InputKind.PARAMETER: + if not isinstance(input_spec.arg, TensorArgument): + raise SpecViolationError( + f"Parameter {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." + ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." + ) + + param = input_spec.target + if param not in exported_program.state_dict: + raise SpecViolationError( + f"Parameter {param} is not in the state dict." + ) + + if not isinstance(exported_program.state_dict[param], torch.nn.Parameter): + raise SpecViolationError( + f"State dict entry for parameter {param} is not an instance of torch.nn.Parameter." + ) + + elif input_spec.kind == InputKind.BUFFER: + if not isinstance(input_spec.arg, TensorArgument): + raise SpecViolationError( + f"Buffer {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." + ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." + ) + + buffer = input_spec.target + if input_spec.persistent is None: + raise SpecViolationError( + f"Buffer {buffer} is missing a persistence flag" + ) + + if input_spec.persistent is True and buffer not in exported_program.state_dict: + raise SpecViolationError( + f"Buffer {buffer} is not in the state dict." + ) + + if input_spec.persistent is False and buffer in exported_program.state_dict: + raise SpecViolationError( + f"Non-persistent buffer {buffer} is in the state dict, it should not be." + ) + elif input_spec.kind == InputKind.CONSTANT_TENSOR: + if not isinstance(input_spec.arg, TensorArgument): + raise SpecViolationError( + f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." 
+ ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." + ) + + tensor_const = input_spec.target + if tensor_const not in exported_program.constants: + raise SpecViolationError( + f"Constant tensor {tensor_const} is not in the constants dictionary." + ) + elif input_spec.kind == InputKind.CUSTOM_OBJ: + if not isinstance(input_spec.arg, CustomObjArgument): + raise SpecViolationError( + f"Custom object {input_spec.name} is not a custom object argument. Found {input_spec.arg} instead." + ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." + ) + + custom_obj = input_spec.target + if custom_obj not in exported_program.constants: + raise SpecViolationError( + f"Custom object {custom_obj} is not in the constants dictionary." + ) + elif input_spec.kind == InputKind.TOKEN: + if not isinstance(input_spec.arg, TokenArgument): + raise SpecViolationError( + f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." + ) + else: + raise SpecViolationError( + f"Unknown InputKind {input_spec.kind}." + ) + + # Check outputs + output_node = list(exported_program.graph.nodes)[-1] + assert output_node.op == "output" + output_nodes = [ + arg.name if isinstance(arg, torch.fx.Node) else arg + for arg in output_node.args[0] + ] + + if len(output_nodes) != len(gs.output_specs): + raise SpecViolationError( + f"Number of output nodes {len(output_nodes)} is different " + "Than the number of outputs specified by the graph signature: \n" + f"Number of mutated buffers: {len(gs.buffers_to_mutate)}. \n" + f"Number of user outputs: {len(gs.user_outputs)}. \n" + ) + + num_tokens = len(gs.output_tokens) + end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + num_tokens + mutate_nodes: List[str] = output_nodes[num_tokens:end] + user_output_nodes = output_nodes[end:end + len(gs.user_outputs)] + + for mutation_node in mutate_nodes: + if mutation_node in gs.buffers_to_mutate: + if gs.buffers_to_mutate[mutation_node] not in gs.buffers: + raise SpecViolationError( + f"Buffer output {mutation_node} does not point to a buffer that exists. \n" + f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n" + f"Buffer nodes available: {gs.buffers} \n" + ) + elif mutation_node in gs.user_inputs_to_mutate: + if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs: + raise SpecViolationError( + f"User input output {mutation_node} does not point to a user input that exists. \n" + f"Dict of user inputs that are mutated, in order: {gs.user_inputs_to_mutate} \n" + f"User input nodes available: {gs.user_inputs} \n") + else: + raise SpecViolationError( + f"Mutation node {mutation_node} is neither a buffer nor a user input. " + f"Buffers to mutate: {gs.buffers_to_mutate}, User inputs to mutate: {gs.user_inputs_to_mutate}" + ) + + for user_output_node, user_output_name in zip(user_output_nodes, gs.user_outputs): + if user_output_node != user_output_name: + raise SpecViolationError( + f"User output {user_output_node} is not in the correct " + "order or is not found in the " + f"exported program's user_output list: {gs.user_outputs}. 
" + ) + + +def load_verifier(dialect: str) -> Type[Verifier]: + if dialect == "ATEN" or dialect == "": + return _VerifierMeta._registry.get(dialect, Verifier) + return _VerifierMeta._registry[dialect] diff --git a/lib/python3.10/site-packages/torch/_export/wrappers.py b/lib/python3.10/site-packages/torch/_export/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..d57ff46de41c8f5961a859d3d1e2871984929b8d --- /dev/null +++ b/lib/python3.10/site-packages/torch/_export/wrappers.py @@ -0,0 +1,121 @@ +# mypy: allow-untyped-defs +from contextlib import contextmanager + +import torch +import torch._custom_ops +from torch._C import DispatchKey +from torch._higher_order_ops.strict_mode import strict_mode +from torch._higher_order_ops.utils import autograd_not_implemented +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree +from torch.utils import _pytree as pytree + + +class ExportTracepoint(HigherOrderOperator): + def __init__(self): + super().__init__("_export_tracepoint") + + def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) + + +_export_tracepoint = ExportTracepoint() + + +@_export_tracepoint.py_impl(ProxyTorchDispatchMode) +def export_tracepoint_dispatch_mode(mode, *args, **kwargs): + p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs)) + proxy = mode.tracer.create_proxy( + "call_function", _export_tracepoint, p_args, p_kwargs + ) + return track_tensor_tree(args, proxy, constant=None, tracer=mode.tracer) + + +@_export_tracepoint.py_impl(FakeTensorMode) +def export_tracepoint_fake_tensor_mode(mode, *args, **kwargs): + with mode: + return args + + +@_export_tracepoint.py_functionalize_impl +def export_tracepoint_functional(ctx, *args, **kwargs): + unwrapped_args = ctx.unwrap_tensors(args) + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + + with ctx.redispatch_to_next(): + out = _export_tracepoint(*unwrapped_args, **unwrapped_kwargs) + return ctx.wrap_tensors(out) + + +_export_tracepoint.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(_export_tracepoint, deferred_error=True) +) + + +@_export_tracepoint.py_impl(DispatchKey.CPU) +def export_tracepoint_cpu(*args, **kwargs): + return args + + +def _wrap_submodule(mod, path, module_call_specs): + assert isinstance(mod, torch.nn.Module) + assert path != "" + submodule = mod + for name in path.split("."): + if not hasattr(submodule, name): + raise RuntimeError(f"Couldn't find submodule at path {path}") + submodule = getattr(submodule, name) + + def update_module_call_signatures(path, in_spec, out_spec): + if path in module_call_specs: + assert module_call_specs[path]["in_spec"] == in_spec + assert module_call_specs[path]["out_spec"] == out_spec + module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec} + + def check_flattened(flat_args): + for a in flat_args: + if not (isinstance(a, (torch.Tensor, str, int, float, bool)) or a is None): + raise AssertionError( + f"Only Tensors or scalars are supported as pytree flattened inputs, got: {a}" + ) + + def pre_hook(module, args, kwargs): + flat_args, in_spec = pytree.tree_flatten((args, kwargs)) + check_flattened(flat_args) + flat_args = _export_tracepoint(*flat_args, kind="module_call_inputs", path=path) + args, kwargs = pytree.tree_unflatten(flat_args, in_spec) + return args, kwargs + + def post_hook(module, args, kwargs, res): + _, in_spec = pytree.tree_flatten((args, 
kwargs)) + flat_res, out_spec = pytree.tree_flatten(res) + check_flattened(flat_res) + flat_res = _export_tracepoint(*flat_res, kind="module_call_outputs", path=path) + update_module_call_signatures(path, in_spec, out_spec) + return pytree.tree_unflatten(flat_res, out_spec) + + pre_handle = submodule.register_forward_pre_hook(pre_hook, with_kwargs=True) + post_handle = submodule.register_forward_hook(post_hook, with_kwargs=True) + return pre_handle, post_handle + + +@contextmanager +def _wrap_submodules(f, preserve_signature, module_call_signatures): + handles = [] + + try: + for path in preserve_signature: + handles.extend(_wrap_submodule(f, path, module_call_signatures)) + yield + finally: + for handle in handles: + handle.remove() + + +def _mark_strict_experimental(cls): + def call(self, *args): + return strict_mode(self, args) + + cls.__call__ = call + return cls diff --git a/lib/python3.10/site-packages/torch/_functorch/__init__.py b/lib/python3.10/site-packages/torch/_functorch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10a55772ab58b21573a6eba0356ddd3080164ac7 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_functorch/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py b/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py new file mode 100644 index 0000000000000000000000000000000000000000..f78ebb31c6cb4e6cab3d5519c3f1545ff82bd5c3 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py @@ -0,0 +1,1557 @@ +# mypy: ignore-errors + +import itertools +from contextlib import contextmanager, nullcontext +from functools import partial, wraps +from typing import Any, Callable, Dict, List, NewType, Optional, Tuple +from unittest.mock import patch + +import torch +import torch._dynamo.logging +import torch.nn as nn +import torch.utils._pytree as pytree +import torch.utils.dlpack +from torch import Tensor +from torch._decomp.decompositions_for_rng import PhiloxStateTracker, rng_decompositions +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo import compiled_autograd +from torch._dynamo.utils import dynamo_timed, preserve_rng_state +from torch._guards import detect_fake_mode +from torch._inductor.utils import BoxedBool +from torch._subclasses import FakeTensor, FakeTensorMode +from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.experimental.symbolic_shapes import ShapeEnv +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + +static_inputs_log = torch._logging.getArtifactLogger( + __name__, "cudagraph_static_inputs" +) + +from . 
import config +from ._aot_autograd.autograd_cache import ( # noqa: F401 + AOTAutogradCache, + autograd_cache_key, +) +from ._aot_autograd.collect_metadata_analysis import ( # noqa: F401 + run_functionalized_fw_and_collect_metadata, +) +from ._aot_autograd.functional_utils import ( # noqa: F401 + _check_if_mutation_can_be_in_graph, + are_all_mutations_hidden_from_autograd, + are_all_mutations_under_no_grad_or_inference_mode, + assert_functional_graph, + from_fun, + gen_alias_from_base, + has_data_mutation, + has_metadata_mutation, + is_fun, + sync_functional_tensor, + to_fun, +) +from ._aot_autograd.input_output_analysis import ( # noqa: F401 + _tensors_definitely_do_not_overlap, + compute_overlapping_inputs, + create_graph_signature, + create_synthetic_base_metadata, + remove_dupe_metadata, +) +from ._aot_autograd.jit_compile_runtime_wrappers import ( # noqa: F401 + aot_dispatch_autograd, + aot_dispatch_base, + aot_dispatch_export, +) +from ._aot_autograd.logging_utils import ( # noqa: F401 + callback_set, + describe_input, + format_guard_bug_msg, + get_aot_compilation_context, + get_aot_graph_name, + get_graph_being_compiled, + graph_being_compiled, + model_name, + nth_graph, + set_model_name, + setup_stacktrace_preservation_hooks, + track_graph_compiling, +) +from ._aot_autograd.runtime_wrappers import ( # noqa: F401 + AOTDedupeWrapper, + AOTSyntheticBaseWrapper, +) +from ._aot_autograd.schemas import ( # noqa: F401 + AOTConfig, + BackwardSignature, + FQN, + GraphInputName, + GraphOutputName, + GraphSignature, + InputAliasInfo, + MutationType, + OutputAliasInfo, + OutputType, + SubclassCreationMeta, + SubclassMeta, + TensorAlias, + ViewAndMutationMeta, +) +from ._aot_autograd.subclass_utils import ( # noqa: F401 + create_metadata_for_subclass, + requires_subclass_dispatch, + unwrap_tensor_subclasses, + wrap_tensor_subclasses, + wrap_tensor_subclasses_maybe_joint, +) +from ._aot_autograd.traced_function_transforms import ( # noqa: F401 + aot_dispatch_subclass, + create_functional_call, + create_functionalized_fn, + create_functionalized_rng_ops_wrapper, + create_joint, + fn_input_mutations_to_outputs, + fn_prepped_for_autograd, +) +from ._aot_autograd.utils import ( # noqa: F401 + _get_autocast_states, + _get_symint_hints, + call_func_at_runtime_with_args, + create_tree_flattened_fn, + KNOWN_TYPES, + make_boxed_compiler, + make_boxed_func, + maybe_to_fresh_input, + normalize_as_list, + partial_flatten_asdict, + root_module_when_exporting_non_strict, + strict_zip, +) +from .partitioners import default_partition + + +zip = strict_zip + +# This global counter increments every time we compile a graph with +# AOTAutograd. You can use this to correlate runtime error messages +# with compile time (e.g., if you get an error at runtime saying +# compiled graph 3 failed, you can set a breakpoint at compile time +# for this graph number to investigate further at compile time.) +# +# NB: this is different from get_aot_compilation_context, which tracks +# each underlying graph that is compiled. 
In contrast, AOT_COUNTER +# corresponds to top-level invocations of aot_module/aot_function; +# one counter is allocated per entire compiled block (but this block +# may involve compiling multiple subgraphs; e.g., for forwards/backwards) +AOT_COUNTER = itertools.count() + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# AOT Autograd contains a pretty non-trivial amount of logic to handle edge cases around aliasing and mutation +# that are external to the graph (they show up as side effects in some way when you run the graph). +# +# Take a look at `test_aotdispatch.py TestAOTAutograd.test_input_mutation*` tests for some examples functions +# and what they're compiled graphs looks like. +# Below is a very long comment detailing several edge cases, and showing how AOT Autograd handles them. +# +# Note [AOT Autograd: input data mutations] +# +# If we compile a function that mutates inputs, then those input mutations are real side effects +# that a user expects to see after running the compiled graph. +# However, the graph that we want to send to a backend needs to be *entirely* functional. +# The way we reconcile this difference is that we remove the mutations completely from the graph that we compile +# but we update the graph to return (updated_inputs, user_outputs). +# In the epilogue that runs after the compiled graph is executed, we copy the updated inputs back to the originals. +# +# Example: original user code: +# def f(x): +# x.mul_(2) +# out = x.mul(3) +# return out +# +# After AOT Autograd compiles, we end up with a: +# (a) compiled graph +# (b) autograd.Function.forward() method, that executes the compiled graph +# (c) wrapper function, that calls the autograd.Function.forward() and performs the epilogue +# +# The output of (a, b, c) are all written below. +# +# def compiled_forward_graph(x): +# x_updated = x.mul(2) +# out = x_updated.mul(3) +# return x_updated, out +# +# # x_updated gets a gradient in the compiled backward +# def compiled_backward_graph(grad_x_updated, grad_out): +# grad_x = ... +# return grad_x +# +# def autograd.Function.forward(x): +# x_updated, out = compiled_forward_graph(x) +# return x_updated, out +# +# def compiled_wrapper(x): +# x_updated, out = autograd.Function.apply(x) +# x.copy_(x_updated) +# return out +# +# Another important thing to note is that updated inputs (due to data mutations) *do* participate +# in the compiled backward graph! Since the compiled forward graph gets N extra outputs +# (due to updated inputs showing up as graph outputs), +# The compiled backward gets an additional N inputs. +# That way, during the x.copy_(x_updated) bit in the epilogue, gradients will flow from the updated input +# back to the original input. + + +# Note [AOT Autograd: input metadata mutations] +# +# For the same reason as input mutations, we also don't put input metadata mutations in the graph. 
+# Instead, we return the updated version of the input (a view), and mutate the input's metadata outside of the graph +# +# Example: original user code: +# def f(x): +# x.t_() +# out = x.mul(3) +# return out +# +# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function): +# def compiled_forward_graph(x): +# x_updated = x.t() +# out = x_updated.mul(3) +# return x_updated, out +# +# # x_updated does *not* get a gradient in the compiled backward +# def compiled_backward_graph(grad_out): +# grad_x = ... +# return grad_x +# +# def autograd.Function.forward(x): +# x_updated, out = compiled_forward_graph(x) +# return x_updated, out +# +# def compiled_wrapper(x): +# x_updated, out = autograd.Function.apply(x) +# x.as_strided_(x_updated) +# return out + + +# Note [AOT Autograd: outputs aliasing inputs or intermediates!] +# +# AOT Autograd needs special handling for outputs that alias graph inputs or intermediates! +# Why? +# (1) autograd.Function.forward() has a limitation, where views that returned in the forward cannot later be mutated. +# (2) views don't need to be compiled in the graph anyway - it's cheap to generate them outside of the compiled graph, +# in an epilogue. +# For outputs that alias inputs, we do the following: +# (a) *still* return the aliased output as a graph output +# (b) In the AOT Autograd wrapper/epilogue, we don't return that aliased output. Instead, we use it to regenerate the output. +# +# For outputs that alias *intermediates*, we do the following: +# (a) Return the output in the compiled forward, **and** return it's ._base (a graph intermediates) as an output in the forward +# (b) Use (output, graph_intermediate) to regenerate the alias, and return that to the user (instead of the compiled fw output). +# You might wonder why we return the aliased output directly in the graph (and making the graph compute it), +# only to not return it and instead generate a fresh alias off of the intermediate, +# instead of (say) just storing metadata about the size/stride of the output somewhere to generate the alias. There are two reasons: +# (1) Getting the actual alias tensor allows us to use view-replay to generate the alias, instead of an as_strided() call +# (2) Inductor (and other backends) are free to change the memory format of graph outputs, if it results in better performance. +# This can result in problems if a user later tries to .view() that output expecting it to have one set of strides, +# when it has a different set of strides. +# By including the view op directly in the graph, inductor takes that into account when deciding what memory format +# the graph intermediate should be. +# +# Another important thing to note is how our traced backward() graph handles aliases. +# (this applies to outputs aliasing inputs, outputs aliasing intermediates, +# *and* updated inputs returned in the compiled forward due to metadata-only mutations). +# Any outputs that alias (either inputs or intermediates) do NOT participate in the compiled backward graph +# It would be wasteful to include them in the compiled backward(), because we regenerate them eagerly +# at the end of the forward. 
+# +# Example: original user code: +# def f(x): +# out1 = x.t() +# intermediate = x.mul(2) +# out2 = intermediate.view(-1) +# return out1, out2 +# +# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function): +# def compiled_forward_graph(x): +# out1 = x.t() +# intermediate = x.mul(2) +# out2 = intermediate.view(-1) +# # the compiled graph also returns the intermediate +# return out1, out2, intermediate +# +# # intermediate gets a gradient in the compiled backward. +# # both output aliases (out1 and out2) do not. +# def compiled_backward_graph(grad_intermediate): +# grad_x = ... +# return grad_x +# +# def autograd.Function.forward(x): +# out1, out2, intermediate = compiled_forward_graph(x) +# return out1, out2, intermediate +# +# def compiled_wrapper(x): +# out1, out2, intermediate = autograd.Function.apply(x) +# # regenerate out1 from the input +# out1_regenerated = out1._view_func(x) +# # regenerate out1 from the intermediate +# out2_regenerated = out2._view_func(intermediate) +# return out1_regenerated, out2_regenerated + + +# Note [AOT Autograd: mutations to inputs that alias other inputs] +# +# Another edge case that is (only partially) handled today is when an input is mutated, but itself aliases another input. +# AOT Autograd needs to **ensure** that functionalization knows that the two inputs are aliased to each other. +# That way, when the aliased input is accessed later in the graph, functionalization knows to "update" the alias +# given the mutation that occurred. +# +# This is handled by updating the calling convention: we create a "synthetic base" that becomes a new input +# in the compiled function, and we regenerate the original (aliased) inputs directly off of the base +# inside of the compiled function. +# +# This logic is fully encapsulated in aot_wrapper_synthetic_base() +# +# Example: original user code: +# def f(x, x_view): +# x.mul_(2) +# out = x * x_view +# return out +# f(x, x.view(-1)) +# +# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function): +# def compiled_forward_graph(base) +# x = generate_x(base) +# x_view = generate_x_view(base) +# x_updated = x.mul(2) +# x_view_updated = x_updated.view(-1) +# out = x_updated * x_view_updated +# return x_updated, out +# +# # The calling convention change from (aliases) -> (base) happens +# # *outside* of the autograd.Function.forward(). +# # That means the forward() only has 1 input (base), +# # and the backward() only has 1 output (grad_base) +# def compiled_backward_graph(grad_out): +# grad_base = ... +# return grad_base +# +# def autograd.Function.forward(base): +# x_updated, out = compiled_forward_graph(base) +# return x_updated, out +# +# # The compiled wrapper is where we create synthetic bases. +# # The info on which inputs are mutated is also tracked *before* synthetic base creation. +# def compiled_wrapper(x, x_view): +# base = merge_view_inputs(x, x_view) +# x_updated, out = autograd.Function.apply(base) +# # x and x_view are aliased in eager mode, so this mutation to x will automatically affect x_view. +# x.copy_(x_updated) +# return out + + +# Note [AOT Autograd: Views to avoid tangents aliasing inputs] +# +# We view every forward output when creating out tangent tensors to handle the problematic +# case in which a subclass does extra aliasing between graph outputs/inputs in a way that +# is not visible above the sublass. 
+# +# Ordinarily, when constructing the joint function that we want to trace in AOTAutograd, +# we're guaranteed that the tangent tensors that we pass +# into the joint are distinct tensors from the primals. This is because when +# decide which forward outputs to create tangents for, we only create tangents +# for forward outputs that are not aliases of inputs (See Note +# [AOT Autograd: outputs aliasing inputs or intermediates!]). +# +# However, when wrapper tensor subclasses enter the picture, it is possible +# to have an output of the forward that is a subclass that is not an +# input / alias of an input, but one of its inner tensors is an alias! +# NestedTensor is an example: Performing an out-of-place pointwise op on a +# NestedTensor constructs a fresh NestedTensor that holds onto the input's +# offsets tensor directly. +# +# Having tangent tensors that are the same as the (primal) forward inputs, +# can cause problems during tracing as make_fx() will specialize on our +# duplicate inputs: If we passed in the same tensor for primals_1 and +# tangents_1 during tracing, make_fx() will happily sub out all usages of +# tangents_1 with primals_1 in the graph, which is not what we want. +# +# To work around this, we view every forward output when creating out tangent +# tensors so that tangents can never be the same as forward inputs even if +# forward inputs alias forward outputs. + +# Note [Side-Effectful Tokens in AOTAutograd] +# +# We allow some some side-effectful operators in +# the post-AOTAutograd (functional) graph, such as prints and torchbind operations. +# To ensure that these side-effects are compatible to future graph passes that +# assume that the graph is functional, we will thread "effect tokens" to show +# data dependence between these side-effectful operators. Practically speaking, +# effect tokens are just dummy values (torch.tensor([])). The graph would look +# like the following: +# +# def gm(self, token0, reader): +# token1, frame = with_token(ordered_effect_op, (reader,), token0) +# frame = frame * 2 +# token2, frame2 = with_token(ordered_effect_op, (reader,), token1) +# frame2 = frame2 * 2 +# return token2, frame, frame2 +# +# We will pass the token as an input to the graph, thread it through +# side-effectful operators using the `with_effects` high order operator, and then +# return the updated token as an output. +# So the signature of the graph input would look something like +# (*tokens, *params_buffers, *user_inputs), and the signature of the graph +# output would look something like (*tokens, *outputs). +# +# However, Inductor does not want the concept of tokens in the final generated +# code's input and output. 
Since changing the graph signature inside of inductor +# is difficult, after generating the forward graph, we will run a pass to +# remove the tokens from the inputgenerate the following graph for Inductor, where +# the tokens are created and sunk within the graph, rather than as inputs and +# outputs: +# +# def gm(self, reader): +# token0 = torch.ops.prims._make_token() +# token1, frame = with_token(ordered_effect_op, (reader,), token0) +# frame = frame * 2 +# token2, frame2 = with_token(ordered_effect_op, (reader,), token1) +# frame2 = frame2 * 2 +# sink_token = torch.ops.prims._sink_tokens([token2]) +# return frame, frame2 + +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +aot_autograd_decompositions = {} + +FakifiedFlatArgs = NewType("FakifiedFlatArgs", List[Any]) + + +def process_inputs( + flat_args: List[Any], + aot_config: AOTConfig, + fake_mode: FakeTensorMode, + shape_env: Optional[ShapeEnv], +) -> FakifiedFlatArgs: + with fake_mode: + + def convert(idx, x): + if shape_env is not None: + from torch._dynamo.source import ConstantSource + + if isinstance(x, int): + # We always specialize on scalar values in export. + if aot_config.is_export: + return x + source = ConstantSource(f"sym_{idx}") + return shape_env.create_symintnode( + shape_env.create_symbol(x, source), hint=x, source=source + ) + if isinstance(x, torch.ScriptObject): + return torch._library.fake_class_registry.maybe_to_fake_obj( + fake_mode, x + ) + if not isinstance(x, torch.Tensor): + return x + if isinstance(x, FakeTensor): + assert x.fake_mode is fake_mode + return x + if is_traceable_wrapper_subclass(x): + attrs, _ = x.__tensor_flatten__() + if all(isinstance(getattr(x, attr), FakeTensor) for attr in attrs): + assert all( + getattr(x, attr).fake_mode is fake_mode for attr in attrs + ) + return x + + # see note [Tensor Fakification and Symbol Caching] + symbolic_context = None + source = None + trace = True + if tracing_context := torch._guards.TracingContext.try_get(): + if x in tracing_context.tensor_to_context: + symbolic_context = tracing_context.tensor_to_context[x] + source = symbolic_context.tensor_source + # We already fakeified this tensor in Dynamo, don't + # dump the trace for it again + trace = False + if ( + idx < aot_config.num_params_buffers + and config.static_weight_shapes + and not symbolic_context + ): + # TODO: Ensure that this codepath is never exercised from + # Dynamo + return fake_mode.from_tensor(x, static_shapes=True) + + return fake_mode.from_tensor( + x, + static_shapes=False, + symbolic_context=symbolic_context, + source=source, + trace=trace, + ) + + return FakifiedFlatArgs([convert(idx, x) for idx, x in enumerate(flat_args)]) + + +def construct_fake_mode( + flat_args: List[Any], aot_config: AOTConfig +) -> Tuple[FakeTensorMode, Optional[ShapeEnv]]: + fake_mode = detect_fake_mode(flat_args) + if fake_mode is None: + shape_env = ShapeEnv() if aot_config.dynamic_shapes else None + fake_mode = FakeTensorMode(shape_env=shape_env) + else: + shape_env = fake_mode.shape_env + return (fake_mode, shape_env) + + +def create_aot_dispatcher_function( + flat_fn, + fake_flat_args: FakifiedFlatArgs, + aot_config: AOTConfig, + fake_mode: FakeTensorMode, + shape_env: Optional[ShapeEnv], +) -> Tuple[Callable, ViewAndMutationMeta]: + with dynamo_timed("create_aot_dispatcher_function"): + return 
_create_aot_dispatcher_function( + flat_fn, fake_flat_args, aot_config, fake_mode, shape_env + ) + + +def _create_aot_dispatcher_function( + flat_fn, + fake_flat_args: FakifiedFlatArgs, + aot_config: AOTConfig, + fake_mode: FakeTensorMode, + shape_env: Optional[ShapeEnv], +) -> Tuple[Callable, ViewAndMutationMeta]: + """ + Traces the forward and backward graphs of the attr:`flat_fn` to generate a + joint graph. The joint graph is an Fx graph with Aten ops. Please refer to + the tracing mechanism to understand the graph capturing details. + + The joint graph is then passed through attr:`partition_fn` to isolate the + forward and backward portions, which are then respectively compiled via the + provided attr:`fw_compiler` and attr:`bw_compiler`. + + The resulting compiled forward and backward graphs are then wrapped up in a + ``torch.autograd.Function`` object. + + The calling convention here is that the first aot_config.num_params_buffers + inputs in flat_args are parameters and buffers, and the rest are inputs. + + We use this to assume that parameters/buffer's shapes don't change. + + Note: this function is used both by aot_function and aot_export (controlled by aot_config.is_export) + When aot_config.is_export is True, we return an FX graph + metadata + When aot_config.is_export is False, we return an ordinary runtime function + """ + + # This is the main entry point. + # TODO: Chillee argues that dynamo itself should pass in fake tensors to + # the list of arguments when compiling; at the moment we do not do this + + if aot_config.decompositions is None: + aot_config.decompositions = {} + + aot_config.decompositions = { + **aot_autograd_decompositions, + **aot_config.decompositions, + } + + if config.functionalize_rng_ops: + # Update the decompositions with functionalized random decompositions + aot_config.decompositions = { + **rng_decompositions, + **aot_config.decompositions, + } + + # Check flat_args to see if they're already fake. If so, use that fake + # mode instead. + + python_dispatcher_mode = ( + enable_python_dispatcher() if shape_env is not None else nullcontext() + ) + + # See NOTE: [Deferring tensor pack/unpack hooks until runtime] + # If any saved tensor hooks are active, we **don't** want to trace them. + # Instead, we'll let them run at runtime, around the custom autograd.Function + # that we generate in torch.compile. + with torch.autograd.set_multithreading_enabled( + False + ), preserve_rng_state(), ( + fake_mode + ), ( + python_dispatcher_mode + ), PhiloxStateTracker(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(): + from torch._library.fake_class_registry import ( + FakeScriptObject, + maybe_to_fake_obj, + ) + + # Tracing may mutate the states the fake script object, + # so we need to duplicate the fake script objects so that subsequent tracing + # won't be affected. + def _dup_fake_script_obj(fake_flat_args): + return [ + maybe_to_fake_obj(detect_fake_mode(fake_flat_args), arg.real_obj) + if isinstance(arg, FakeScriptObject) + else arg + for arg in fake_flat_args + ] + + needs_autograd = any( + x.requires_grad for x in fake_flat_args if isinstance(x, Tensor) + ) + + with enable_python_dispatcher(): + # Patch set_rng_state as set_rng_state with fake tensors is + # nonsensical. This does not affect the collection of metadata. 
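+                # Roughly speaking, the block below runs flat_fn once on the fake
+                # inputs under functionalization purely to collect
+                # ViewAndMutationMeta (which inputs are mutated, which outputs
+                # alias inputs or intermediates, and so on); no real RNG state
+                # should change while doing so, hence the no-op patch.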
+ with patch("torch.cuda.set_rng_state", lambda *args: None): + mod = root_module_when_exporting_non_strict(flat_fn) + if mod is not None: + ctx = _detect_attribute_assignment(mod) + else: + ctx = nullcontext() + with ctx: + fw_metadata = run_functionalized_fw_and_collect_metadata( + flat_fn, + static_input_indices=aot_config.static_input_indices, + keep_input_mutations=aot_config.keep_inference_input_mutations, + is_train=needs_autograd, + pre_dispatch=aot_config.pre_dispatch, + )(*_dup_fake_script_obj(fake_flat_args)) + + req_subclass_dispatch = requires_subclass_dispatch( + fake_flat_args, fw_metadata + ) + + output_and_mutation_safe = not any( + x.requires_grad + # view-type operations preserve requires_grad even in no_grad. + # Do not count aliases of inputs with requires_grad as reason to make a training graph, + # as AOTAutograd will perform view-replay to regenerate the view outputs at runtime, + # setting their grad_fn properly. + and not ( + x.output_type + in (OutputType.alias_of_input, OutputType.is_input) + and fw_metadata.input_info[x.base_idx].requires_grad + ) + for x in fw_metadata.output_info + ) and not any( + x.requires_grad + and x.mutates_data + and not x.mutations_under_no_grad_or_inference_mode + and not x.mutations_hidden_from_autograd + for x in fw_metadata.input_info + ) + + if needs_autograd and output_and_mutation_safe: + # We realized that none of the outputs require grad, + # and none of the inputs that require grad are mutated. + # so we actually have an inference graph. + needs_autograd = False + # A bit silly: right now in the subclass codepath, our ViewAndMutationMeta + # changes depending on whether we pass in is_train / keep_input_mutations, + # so we're forced to recompute the metadata. + # TODO: refactor the subclass path of run_functionalized_fw_and_collect_metadata + # so that this is unnecessary. + if req_subclass_dispatch: + fw_metadata = run_functionalized_fw_and_collect_metadata( + flat_fn, + keep_input_mutations=aot_config.keep_inference_input_mutations, + is_train=False, + pre_dispatch=aot_config.pre_dispatch, + static_input_indices=aot_config.static_input_indices, + )(*fake_flat_args) + else: + fw_metadata = ViewAndMutationMeta( + input_info=fw_metadata.input_info, + output_info=fw_metadata.output_info, + num_intermediate_bases=fw_metadata.num_intermediate_bases, + keep_input_mutations=aot_config.keep_inference_input_mutations, + traced_tangents=fw_metadata.traced_tangents, + subclass_inp_meta=fw_metadata.subclass_inp_meta, + subclass_fw_graph_out_meta=fw_metadata.subclass_fw_graph_out_meta, + subclass_tangent_meta=fw_metadata.subclass_tangent_meta, + is_train=False, + tokens=fw_metadata.tokens, + static_input_indices=fw_metadata.static_input_indices, + ) + + if fw_metadata.num_intermediate_bases > 0: + assert not req_subclass_dispatch, f"""\ +torch.compile is currently being used with tensor subclass inputs: +{','.join([str(type(x)) for x in fake_flat_args])}. We are attempting to a compile a graph with two graph outputs +that alias one another, which is currently unsupported in the subclass use case. If you run into this, +please file a github issue""" + + if aot_config.is_export: + # aot_export: ban input metadata mutations for now to keep shared code paths simpler. + # Keeping .resize_() in the graph will require some work + # Allowing it but keeping the graph functional will require some calling convention changes. 
+ if len([x for x in fw_metadata.input_info if x.mutates_metadata]) != 0: + raise RuntimeError( + f"""\ +Found an input that received a metadata mutation, through e.g. a call to `.resize_()` or `.transpose_()`. +This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue. + +fw_metadata={str(fw_metadata)}""" + ) + # In export, banning data mutations on inputs that require grad for now. + # This should be rare, and is tricky to get right. When we trace the backward, + # we currently trace with autograd.grad instead of .backward(), which makes it difficult + # to ensure that we run autograd all the way through the input **before** it saw the mutation. + if ( + len( + [ + x + for x in fw_metadata.input_info + if x.requires_grad and x.mutates_data + ] + ) + != 0 + ): + raise RuntimeError( + f"""\ +Found a graph input that requires gradients, and received a mutation. +This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue. + +fw_metadata={str(fw_metadata)}""" + ) + if req_subclass_dispatch: + raise RuntimeError( + """\ +aot_export is not currently supported with traceable tensor subclass. +If you need this feature, please comment on """ + ) + + # Need to decide on a strategy for functionalized RNG: toggling via global config seems bad, + # and turning it on will require a non-trivial calling convention change for any export runtime. + if config.functionalize_rng_ops: + raise RuntimeError( + """\ +Functionalized RNG is not currently supported in the aot_export workflow. Please file a github issue, +or otherwise set torch._functorch.config.functionalize_rng_ops = False.""" + ) + + def choose_dispatcher(needs_autograd, aot_config): + """ + Pick a dispatcher based on the config rules. + """ + if aot_config.is_export: + # export uses just the "graph bits", whereas the other + # two dispatchers include some extra work around handling a runtime epilogue + return partial(aot_dispatch_export, needs_autograd=needs_autograd) + elif needs_autograd and not aot_config.pre_dispatch: + return aot_dispatch_autograd + else: + return aot_dispatch_base + + compiler_fn = choose_dispatcher(needs_autograd, aot_config) + + compiled_fn, fw_metadata = compiler_fn( + flat_fn, + _dup_fake_script_obj(fake_flat_args), + aot_config, + fw_metadata=fw_metadata, + ) + return compiled_fn, fw_metadata + + +def aot_function( + fn: Callable, + fw_compiler: Callable, + bw_compiler: Optional[Callable] = None, + partition_fn: Callable = default_partition, + decompositions: Optional[Dict] = None, + num_params_buffers: int = 0, + keep_inference_input_mutations: bool = False, + inference_compiler: Optional[Callable] = None, + *, + # Whether or not to trace with dynamic shapes + dynamic=False, + enable_log=True, +) -> Callable: + """ + Traces the forward and backward graph of :attr:`fn` using torch dispatch + mechanism, and then compiles the generated forward and backward graphs + through :attr:`fw_compiler` and :attr:`bw_compiler`. + + :func:`aot_function` traces the forward and backward graph ahead of time, + and generates a joint forward and backward graph. :attr:`partition_fn` is + then used to separate out forward and backward graphs. The partitioner + function can be used to perform optimizations such as recomputation. One can + set `decompositions` dictionary to decompose the operators into a sequence + of core or simpler operators supported by the backend compilers. + + .. 
warning:: + This API is experimental and likely to change. + + Args: + fn (Callable): A Python function that takes one ore more arguments. Must + return one or more Tensors. + fw_compiler (Callable): A Python function that accepts an Fx graph with + Aten ops and input args, and returns a Callable that semantically is + equivalent to the input Fx graph. + bw_compiler (Optional[Callable]): A Python function that accepts an + Fx graph with Aten ops and input args, and returns a Callable that + semantically is equivalent to the input Fx graph. Default: None + (when None, it defaults to the :attr:`fw_compiler`) + partition_fn (Callable): A Python function that takes a joint forward + and backward graph, and partitions it into separate forward and + backward graphs. + decompositions (Dict): A dictionary to define the decomposition of + larger Aten ops into simpler or core Aten ops. + inference_compiler (Optional[Callable]): A Python function that accepts an + Fx graph with Aten ops and input args, and returns a Callable that + semantically is equivalent to the input Fx graph. inference_compiler is invoked + if no autograd is needed. Default: None + (when None, it defaults to the :attr:`fw_compiler`) + Returns: + Returns a ``Callable`` that retains the eager behavior of the original + :attr:`fn`, but with forward and backward graph compiled via + :attr:`fw_compile` and :attr:`bw_compile`. + + A simple example usage of :func:`aot_function` is as follows. This example + will print the forward and backward graphs of the function ``fn`` + + >>> fn = lambda x : x.sin().cos() + >>> def print_compile_fn(fx_module, args): + >>> print(fx_module) + >>> return fx_module + >>> aot_fn = aot_function(fn, print_compile_fn) + >>> x = torch.randn(4, 5, requires_grad=True) + >>> aot_fn(x) + """ + + if bw_compiler is None: + bw_compiler = fw_compiler + if inference_compiler is None: + inference_compiler = fw_compiler + aot_config = AOTConfig( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + inference_compiler=inference_compiler, + partition_fn=partition_fn, + decompositions=decompositions, + num_params_buffers=num_params_buffers, + aot_id=next(AOT_COUNTER), + keep_inference_input_mutations=keep_inference_input_mutations, + dynamic_shapes=dynamic, + aot_autograd_arg_pos_to_source=None, + is_export=False, + no_tangents=False, + enable_log=enable_log, + ) + cached_res = None + + @wraps(fn) + def returned_function(*args, **kwargs): + nonlocal cached_res + # Now flatten the tensor args + flat_args = pytree.arg_tree_leaves(*args, **kwargs) + + # Compile the function and save it in the cache + if cached_res is None: + flat_fn, out_spec = create_tree_flattened_fn(fn, args, kwargs) + (fake_mode, shape_env) = construct_fake_mode(flat_args, aot_config) + fake_flat_args: FakifiedFlatArgs = process_inputs( + flat_args, aot_config, fake_mode, shape_env + ) + compiled_fn, _ = create_aot_dispatcher_function( + flat_fn, + fake_flat_args, + aot_config, + fake_mode, + shape_env, + ) + cached_res = (compiled_fn, out_spec) + + cached_fn, out_spec = cached_res + out = cached_fn(flat_args) + return out_spec.unflatten(out) + + return returned_function + + +def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module: + """ + Traces the forward and backward graph of :attr:`mod` using torch dispatch + tracing mechanism. It is wrapper function, that underneath uses + :func:`aot_function` to perform tracing and compilation. 
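+
+    For example (an illustrative sketch, using a pass-through compiler;
+    ``mod`` and ``example_inputs`` stand in for any ``nn.Module`` and its
+    inputs):
+
+    >>> def print_compile_fn(fx_module, args):
+    >>>     print(fx_module)
+    >>>     return fx_module
+    >>> compiled_mod = aot_module(mod, fw_compiler=print_compile_fn)
+    >>> out = compiled_mod(*example_inputs)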
+ + :func:`aot_module` lifts the parameters and buffers of ``nn.Module`` as inputs + to a new callable which is then compiled through :func:`aot_function`. + + .. warning:: + This API is experimental and likely to change. + + Args: + mod (Callable): A ``nn.Module`` module. + args : args to be passed to :func:`aot_function` + kwargs : kwargs to be passed to :func:`aot_function` + + Returns: + Returns a ``nn.Module`` that retains the eager behavior of the original + :attr:`mod`, but with forward and backward graph compiled. + + """ + # See Note: [Fake Modules and AOTAutograd] + torch._dynamo.utils.assert_no_fake_params_or_buffers(mod) + + def functional_call(named_params, named_buffers, *args, **kwargs): + params_and_buffers = {**named_params, **named_buffers} + return torch.func.functional_call(mod, params_and_buffers, args, kwargs) + + named_params = dict(mod.named_parameters(remove_duplicate=False)) + named_buffers = dict(mod.named_buffers(remove_duplicate=False)) + num_params_buffers = len(named_params) + len(named_buffers) + compiled_f = aot_function( + functional_call, *args, num_params_buffers=num_params_buffers, **kwargs + ) + + class AOTModule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.orig_module = mod + + def forward(self, *args, **kwargs): + return compiled_f( + named_params, + named_buffers, + *args, + **kwargs, + ) + + return AOTModule() + + +def aot_module_simplified( + mod: nn.Module, + args, + fw_compiler: Callable, + bw_compiler: Optional[Callable] = None, + partition_fn: Callable = default_partition, + decompositions: Optional[Dict] = None, + keep_inference_input_mutations=False, + inference_compiler: Optional[Callable] = None, + cudagraphs: Optional[BoxedBool] = None, +) -> nn.Module: + """ + This is the simplified or low overhead version of aot_module. For frontends + like TorchDynamo, the input functions/modules to AOT are static and have + unpacked inputs/outputs. This gives us an opportunity to remove the + (1) pytree overhead to parse inputs/outputs, + (2) AOT Autograd cache, + (3) Reading of params/buffers in every forward call + + :func:`aot_module_simplified` removes these overheads. + """ + params = { + **dict(mod.named_parameters(remove_duplicate=False)), + **dict(mod.named_buffers(remove_duplicate=False)), + } + params_flat, params_spec = pytree.tree_flatten(params) + params_flat = list(params_flat) + params_len = len(params_flat) + + if cudagraphs is None: + cudagraphs = BoxedBool(torch._inductor.config.triton.cudagraphs) + + if bw_compiler is None: + bw_compiler = fw_compiler + if inference_compiler is None: + inference_compiler = fw_compiler + + seen_sources = set() + + full_args = [] + # First, the params + full_args.extend(params_flat) + + if tracing_context := torch._guards.TracingContext.try_get(): + tracing_context.params_flat = params_flat + + aot_autograd_arg_pos_to_source = None + # Then, the params 1:1 mapped sources, if relevant. + if hasattr(mod, "_param_name_to_source"): + aot_autograd_arg_pos_to_source = [] + # We now know this came from dynamo, and (1) we care about guards, + # so setting up aot_autograd_arg_pos_to_source for downstream dedup guards + # can now be done safely. (2) Dynamo logic protects the 1:1 sizing below. + for name in params.keys(): + assert name in mod._param_name_to_source, f"{name} not found." 
+ source = mod._param_name_to_source[name] + assert source not in seen_sources, source + seen_sources.add(source) + aot_autograd_arg_pos_to_source.append(source) + + # Next, the input args + full_args.extend(args) + + static_input_indices = [] + if hasattr(mod, "graph"): + # Non dynamo entrypoints can get to here... + for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")): + if hasattr(node, "_dynamo_source"): + # ... but not here! + if aot_autograd_arg_pos_to_source is None: + aot_autograd_arg_pos_to_source = [] + source = node._dynamo_source + assert source not in seen_sources, source + seen_sources.add(source) + aot_autograd_arg_pos_to_source.append(source) + source_name = source.name() if source else str(source) + + if "tensor_dict" in node.meta and node.meta["tensor_dict"].get( + "_dynamo_static_input_type", None + ): + static_inputs_log.debug( + "Adding static input pos %s for source %s", pos, source_name + ) + static_input_indices.append(pos) + else: + static_inputs_log.debug( + "Non-static input pos %s for source %s", pos, source_name + ) + + if aot_autograd_arg_pos_to_source is not None: + assert len(full_args) == len(aot_autograd_arg_pos_to_source) + + dynamic_shapes = False + for x in full_args: + if isinstance(x, FakeTensor): + dynamic_shapes = x.fake_mode.shape_env is not None + break + + aot_config = AOTConfig( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + inference_compiler=inference_compiler, + partition_fn=partition_fn, + decompositions=decompositions, + num_params_buffers=params_len, + aot_id=next(AOT_COUNTER), + keep_inference_input_mutations=keep_inference_input_mutations, + dynamic_shapes=dynamic_shapes, + aot_autograd_arg_pos_to_source=aot_autograd_arg_pos_to_source, + static_input_indices=static_input_indices, + is_export=False, + no_tangents=False, + cache_key=None, + ) + fake_mode, shape_env = construct_fake_mode(full_args, aot_config) + fake_flat_args = process_inputs(full_args, aot_config, fake_mode, shape_env) + + def dispatch_and_compile(): + functional_call = create_functional_call(mod, params_spec, params_len) + with compiled_autograd.disable(): + compiled_fn, _ = create_aot_dispatcher_function( + functional_call, + fake_flat_args, + aot_config, + fake_mode, + shape_env, + ) + return compiled_fn + + # Autograd cache stuff + if config.enable_autograd_cache: + compiled_fn = AOTAutogradCache.load( + dispatch_and_compile, mod, fake_flat_args, aot_config, cudagraphs + ) + else: + compiled_fn = dispatch_and_compile() + + if isinstance(mod, torch._dynamo.utils.GmWrapper): + # This function is called by the flatten_graph_inputs wrapper, which boxes + # the inputs so that they can be freed before the end of this scope. + # For overhead reasons, this is not the default wrapper, see comment: + # https://github.com/pytorch/pytorch/pull/122535/files#r1560096481 + def boxed_forward(runtime_args: List[Any]): + flat_args = [] + flat_args.extend(params_flat) + flat_args.extend(runtime_args) + runtime_args.clear() + return compiled_fn(flat_args) + + # Just for convenience + boxed_forward.zero_grad = mod.zero_grad + boxed_forward.named_parameters = mod.named_parameters + boxed_forward.named_buffers = mod.named_buffers + return boxed_forward + + # TODO: There is something deeply wrong here; compiled_fn running with + # the boxed calling convention, but aot_module_simplified somehow + # historically returned a function that was not the boxed calling + # convention. This should get fixed... 
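+    # (Illustrative sketch, not part of this function.) The two calling conventions
+    # differ roughly like this:
+    #
+    #   out = boxed_forward([input0, input1])   # boxed: a single list arg that the callee may clear
+    #   out = forward(input0, input1)           # non-boxed: plain positional args (the wrapper below)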
+ # NB: GraphModule/nn.Module rely on the non-boxed calling convention here + def forward(*runtime_args: Tuple[Any]): + full_args = [] + full_args.extend(params_flat) + full_args.extend(runtime_args) + return compiled_fn(full_args) + + # Just for convenience + forward.zero_grad = mod.zero_grad + forward.named_parameters = mod.named_parameters + forward.named_buffers = mod.named_buffers + + return forward + + +def aot_export_module( + mod: nn.Module, + args, + *, + decompositions: Optional[Dict] = None, + # If true, we'll return a joint forward-backward graph, + # As well as metadata on the loss + gradients in the backward. + trace_joint: bool, + # If trace_joint is True, we expect your module to return a scalar loss. + # Your module can return multiple outputs, so you must specify which output the loss is. + output_loss_index: Optional[int] = None, + pre_dispatch: bool = False, + # If None, will be infered from inputs and mod.graph.nodes if mod is a graph module, but the inferred result might be wrong. + dynamic_shapes: Optional[bool] = None, + kwargs=None, +) -> Tuple[torch.fx.GraphModule, GraphSignature]: + """ + This function takes in a module, and returns: + (1) an FX graph that can be exported + (2) some metadata about the graph + + If `trace_joint=True` we will return a joint graph of the forward + backward. + + The traced FX graph will have the following properties compared to the original module: + (1) Inputs and outputs to the module will be pytree-flattened + (2) Parameters and buffers on the module will be lifted into graph inputs, + graph_inputs = (*parameters, *buffers, *user_inputs) + (3) The graph will be fully functionalized + (4) Any input mutations will be converted into additional outputs in the graph, + meaning whoever calls this graph is responsible for applying the mutations + back to the original inputs. + (5) If is_joint is provided the graph will return parameter gradients in addition to user outputs. + The graph output will look like: + graph_outputs = (*updated_inputs, *user_outputs, *param_gradients) + + There are also several restrictions on what modules can use this API. In particular: + (1) If trace_joint is specified, we expect the loss function to be **fused** + into the module forward. One of the outputs to the forward must be a scalar loss, + which is specified with `output_loss_index`. + All other outputs to the forward are presumed to not require gradients. + (2) This API cannot capture optimizers (although in theory we could build an API for this). + (3) Metadata mutations on params/buffers/inputs are banned. + (4) Data mutations on anything that requires gradients are banned (parameters) + (5) If an input is mutated, it is not allowed to alias any other inputs. + (6) Parameters must not be duplicated. 
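+
+    A rough usage sketch (illustrative only; the module and input shapes here are arbitrary):
+
+    >>> mod = torch.nn.Linear(3, 1)
+    >>> gm, signature = aot_export_module(mod, (torch.randn(2, 3),), trace_joint=False)
+    >>> gm.print_readable()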
+    """
+    if pre_dispatch and trace_joint:
+        raise RuntimeError("pre_dispatch is not supported when trace_joint is True.")
+    named_parameters = dict(mod.named_parameters(remove_duplicate=False))
+    named_buffers = dict(mod.named_buffers(remove_duplicate=False))
+
+    params_and_buffers = {
+        **dict(named_parameters),
+        **dict(named_buffers),
+    }
+    params_and_buffers_flat, params_spec = pytree.tree_flatten(params_and_buffers)
+    params_and_buffers_flat = tuple(params_and_buffers_flat)
+    params_len = len(params_and_buffers_flat)
+
+    kwargs = kwargs or {}
+
+    functional_call = create_functional_call(
+        mod, params_spec, params_len, store_orig_mod=True
+    )
+
+    num_fw_outs = None
+
+    if trace_joint:
+        # This helper effectively just adds some extra asserts about what the backward will look like:
+        # outputs must include a scalar loss that we compute gradients w.r.t.
+        # We don't compute gradients w.r.t. anything else, so every other output
+        # must not require gradients (this is checked below).
+        def fn_to_trace(*args):
+            nonlocal num_fw_outs
+            out = functional_call(*args)
+            if output_loss_index is None:
+                raise RuntimeError(
+                    """\
+If trace_joint=True, it is required that one of your forward outputs is a scalar loss.
+You must specify which output is the loss (by index) with output_loss_index."""
+                )
+            if isinstance(out, torch.Tensor):
+                out = (out,)
+            if not isinstance(out, (tuple, list)):
+                raise RuntimeError(
+                    f"Expected forward output to be either a tensor or a list/tuple of tensors. Found {type(out)}"
+                )
+
+            for i, o in enumerate(out):
+                # We only want to create a backward graph w.r.t. the loss that the user passed in.
+                # This implies that every other output should not require gradients.
+                # If one does, we raise an error here and ask the user to detach that
+                # output before returning it from the forward.
+                if o.requires_grad and i != output_loss_index:
+                    raise RuntimeError(
+                        f"""\
+Found an output of the forward that requires gradients, but is not the scalar loss.
+We require all outputs to the forward that are not the scalar loss to not require gradients,
+because we will only compute a backward graph against the scalar loss.
+You can fix this by calling .detach() on each of your forward outputs that is not the loss.
+You specified that output index {output_loss_index} is the loss, but we found that
+the output at index {i} requires gradients."""
+                    )
+            out_loss = out[output_loss_index]
+            num_fw_outs = len(out)
+            if not out_loss.requires_grad:
+                raise RuntimeError(
+                    f"""\
+The output at index {output_loss_index} was marked as the loss, but it does not require gradients"""
+                )
+            if out_loss.numel() != 1:
+                raise RuntimeError(
+                    f"""\
+We require the output marked as the loss (at index {output_loss_index}) to be a scalar, but it has shape {out_loss.shape}"""
+                )
+            return out
+
+        ctx = nullcontext
+    else:
+        # Run under no_grad, so our tracing machinery only traces an inference graph.
+        # However, if pre_dispatch=True, we want to correctly trace set_grad_enabled calls for training.
+ ctx = nullcontext if pre_dispatch else torch.no_grad + fn_to_trace = functional_call + + full_args = [] + # First, the params + # NB: It is REQUIRED that parameters come first, Inductor infers "fixed" + # parameters by looking at the difference in parameter count outside + # and inside AOTAutograd, and assumes the prefix of arguments are fixed + # arguments + full_args.extend(params_and_buffers_flat) + # Next, the input args + full_args.extend(args) + + with ctx(): + fx_g, metadata, in_spec, out_spec = _aot_export_function( + fn_to_trace, + full_args, + decompositions=decompositions, + num_params_buffers=params_len, + no_tangents=True, + pre_dispatch=pre_dispatch, + dynamic_shapes=dynamic_shapes, + kwargs=kwargs, + ) + if trace_joint: + + def flattened_joint(*args): + # The idea here is that the joint graph that AOTAutograd creates has some strict properties: + # (1) It accepts two arguments (primals, tangents), and pytree_flattens them + # (2) It returns a tuple of (fw_outs, gradients) + # This is a very useful convention for anyone who wants to partition the joint graph + # into a separate forward and backward graph. + # However, + # (1) for people exporting a single joint graph, it would be preferable not to have + # any pytrees in the graph. + # (2) We are guaranteed in the aot_export_module case that the forward outputs a loss, + # and there are therefore no tangents that are needed to run the joint graph. + # (3) AOTAutograd creates a grad_input for every input in the forward, + # including None's for inputs that are not grad-requiring tensors. + # we don't want these in our export graph. + # and there are therefore no tangents that are needed to run the joint graph. + # This function "fixes" both of the above by removing any tangent inputs, + # and removing pytrees from the original FX graph. + fake_tangents = [ + None + for _ in range( + metadata.num_outputs + metadata.num_mutated_inp_runtime_indices + ) + ] + fw_outs, gradients = fx_g(args, fake_tangents) + assert len(gradients) == len(args) + output_gradients = [] + for i, (a, grad) in enumerate(zip(args, gradients)): + if isinstance(a, torch.Tensor) and a.requires_grad: + assert ( + grad is not None + ), """\ +Found a parameter that did not receive a gradient. +"This is most likely a bug, but if this needs to be supported please comment on this Github issue: +https://github.com/pytorch/pytorch/issues/101192 +""" + output_gradients.append(grad) + else: + assert grad is None + return *fw_outs, *output_gradients + + fx_g = make_fx(flattened_joint)(*full_args) + + user_args_flat = pytree.arg_tree_leaves(*args, **kwargs) + return fx_g, create_graph_signature( + fx_g, + metadata, + in_spec, + out_spec, + user_args_flat=user_args_flat, + params_and_buffers_flat=params_and_buffers_flat, + param_names=list(named_parameters.keys()), + buffer_names=list(named_buffers.keys()), + trace_joint=trace_joint, + num_user_fw_outs=num_fw_outs, + loss_index=output_loss_index, + ) + + +def aot_export_joint_simple( + func: Callable, + args, + *, + trace_joint: bool, + # It looks like the main consequence of this API is that for dynamic shapes, + # it will assume that parms/buffers are static. + # With the new inferred dynamic shapes API, maybe this doesn't matter? + num_params_buffers: int = 0, + decompositions: Optional[Dict] = None, +) -> torch.fx.GraphModule: + """ + A simplified version of export. Used by higher order operators. 
+ + This function makes a high-level "no calling convention changes" guarantee: + - If no inputs require grad (so we export an inference graph), + there are *no* calling convention change between the exported graph, and "func". + - If at least one input requires grad (so we trace out and export a joint fw-bw graph), + Then if you were partition the graph into a separate forward and backward graph, + The forward graph will have no calling convention changes compared to "func". + + The above also relies on some strong restrictions around which functions this API accepts: + (1) `args` cannot contain any pytrees (they must have been pytree_flattened already) + (2) `func` cannot mutate any inputs + (3) The outputs of `func` cannot alias any inputs. + + Note: this function is only lightly tested today. It will probably be tested more heavily by higher order ops. + """ + if trace_joint: + ctx = nullcontext + else: + # Run under no_grad, so our tracing machinery only traces an inference graph. + ctx = torch.no_grad + + with ctx(): + fx_g, metadata, in_spec, out_spec = _aot_export_function( + func, + args, + decompositions=decompositions, + ) + in_spec, _kw_in_spec = in_spec.children_specs + # At this point, we can just directly return the (joint or inference graph) that we traced. + # First though: a bunch of assertions to make sure that our graph doesn't require + # any calling convention changes compared to the original function. + # These restrictions are *in addition to* the general restrictions on export. + + # No input mutations + if ( + len([x for x in metadata.input_info if x.mutates_data or x.mutates_metadata]) + != 0 + ): + raise RuntimeError( + f"aot_export_joint_simple does not support input mutations. {str(metadata)}" + ) + # No output aliasing + if ( + len([x for x in metadata.output_info if x.output_type != OutputType.non_alias]) + != 0 + ): + raise RuntimeError( + f"aot_export_joint_simple does not support outputs that alias inputs. {str(metadata)}" + ) + # No pytrees + if in_spec.is_leaf(): + raise RuntimeError( + f"aot_export_joint_simple requires inputs to be a single list/tuple. in_spec={str(in_spec)}" + ) + if not all(child.is_leaf() for child in in_spec.children_specs): + raise RuntimeError( + f"aot_export_joint_simple requires individual inputs not to be pytrees. in_spec={str(in_spec)}" + ) + if out_spec.is_leaf(): + raise RuntimeError( + f"aot_export_joint_simple requires outputs to be a single list/tuple. out_spec={str(out_spec)}" + ) + if not all(child.is_leaf() for child in out_spec.children_specs): + raise RuntimeError( + f"aot_export_joint_simple requires individual outputs not to be pytrees. out_spec={str(out_spec)}" + ) + # TODO: we might have to temporarily patch config.functionalize_rng + # so that it doesn't run when we're exporting a higher order op. + + if config.debug_assert: + # Smoke test that after partitioning, we can run the forward without any calling convention changes. + fw_module, bw_module = aot_config.default_partition( # noqa: F821 + fx_g, args, num_fwd_outputs=len(fw_metadata.output_infos) # noqa: F821 + ) + # Attempt to run the fw_module with the original user inputs + fake_mode = detect_fake_mode(args) + if fake_mode is None: + fake_mode = FakeTensorMode() + with fake_mode: + fw_module(*args) + return fx_g + + +# Private for now because we aren't providing a contract on what to return +# for joint graphs (we could when there's a clearer use case) +# In the future, we may need to add more export API's that provide their own strong guarantees. 
+# This is meant as a general helper function for handling various export-y use cases. +def _aot_export_function( + func: Callable, + args, + *, + num_params_buffers: int = 0, + decompositions: Optional[Dict] = None, + # If we're exporting a joint graph and we don't want any tangent inputs in the graph + # (because we are backpropping through a scalar 1 loss), + # we need to explicitly specify not to include tangents in the graph. + # It's not enough just to check that our tangent is a scalar, since we also + # need to know if it is a 1 (no need to make it a graph input), or something else + # (requiring it to be a graph input). + # We don't know this info at trace time though, so we need to make it an explicit config. + no_tangents: bool = False, + pre_dispatch: bool = False, + # If None, `dynamic_shapes` will be infered from inputs, but the inferred result might be wrong. + dynamic_shapes: Optional[bool] = None, + kwargs=None, +) -> Tuple[torch.fx.GraphModule, ViewAndMutationMeta, pytree.TreeSpec, pytree.TreeSpec]: + kwargs = kwargs or {} + + flat_fn, out_spec = create_tree_flattened_fn(func, args, kwargs) + flat_args, in_spec = pytree.tree_flatten((args, kwargs)) + + if dynamic_shapes is None: + # Try to infer `dynamic_shapes from inputs and graph nodes + fake_mode = detect_fake_mode(flat_args) + if ( + fake_mode is None + and hasattr(func, "_orig_mod") + and isinstance(func._orig_mod, torch.fx.GraphModule) + ): + vals = [ + node.meta["val"] + for node in func._orig_mod.graph.nodes + if "val" in node.meta + ] + fake_mode = detect_fake_mode(vals) + dynamic_shapes = fake_mode is not None and fake_mode.shape_env is not None + + # The export use case doesn't care about several bits of AOTConfig + # (1) compilers (we just export the graph) + # (2) partitioners (export is only full graph, user can partition themselves) + aot_config = AOTConfig( + fw_compiler=None, + bw_compiler=None, + inference_compiler=None, + partition_fn=None, + decompositions=decompositions, + num_params_buffers=num_params_buffers, + aot_id=next(AOT_COUNTER), + # For now there's no use case involving keeping input mutations in the graph + # (which we can only do in the inference case anyway). + # We can add this later if we need to. + keep_inference_input_mutations=False, + dynamic_shapes=dynamic_shapes, + aot_autograd_arg_pos_to_source=None, + is_export=True, + no_tangents=no_tangents, + pre_dispatch=pre_dispatch, + ) + fake_mode, shape_env = construct_fake_mode(flat_args, aot_config) + fake_flat_args = process_inputs(flat_args, aot_config, fake_mode, shape_env) + + fx_g, meta = create_aot_dispatcher_function( + flat_fn, + fake_flat_args, + aot_config, + fake_mode, + shape_env, + ) + return fx_g, meta, in_spec, out_spec.spec + + +@contextmanager +def _detect_attribute_assignment(mod: torch.nn.Module): + # Do not allow assignment of tensor attributes during export unless + # the attribute is registered as a buffer. 
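+    #
+    # For example (illustrative), a module whose forward does
+    #
+    #     def forward(self, x):
+    #         self.last_input = x  # plain tensor attribute, not a registered buffer
+    #         return x + 1
+    #
+    # would trigger the ValueError below, whereas assigning to an attribute created
+    # with `self.register_buffer(...)` is allowed.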
+ + STD_ATTRS = { + "_backward_hooks", + "_backward_pre_hooks", + "_buffers", + "_forward_hooks", + "_forward_hooks_always_called", + "_forward_hooks_with_kwargs", + "_forward_pre_hooks", + "_forward_pre_hooks_with_kwargs", + "_is_full_backward_hook", + "_load_state_dict_post_hooks", + "_load_state_dict_pre_hooks", + "_modules", + "_non_persistent_buffers_set", + "_parameters", + "_state_dict_hooks", + "_state_dict_pre_hooks", + "training", + } + + def _get_attributes(mod): + # return any attributes of a module that are not standard attributes + return {k: v for k, v in mod.__dict__.items() if k not in STD_ATTRS} + + # save state of attributes before enter + snapshot = pytree.tree_map(lambda x: x, _get_attributes(mod)) + try: + yield + finally: + # after exit, compare state of attributes with snapshot + # to detect which tensor attributes were assigned + assigned_tensor_attributes = [] + + def _collect_assigned_tensor_attributes(kp, v, _v): + if _v is not v: + attr, *rest = kp + if isinstance(v, torch.Tensor): + assigned_tensor_attributes.append( + f"self.{attr.key}{pytree.keystr(rest)}" + ) + # TODO(avik): Assigning all other types are allowed right now. + # Maybe in the future we want to limit this to primitive types? + + pytree.tree_map_with_path( + _collect_assigned_tensor_attributes, snapshot, _get_attributes(mod) + ) + # restore state of all attributes (including, e.g., of primitive types) + mod.__dict__.update(snapshot) + + if assigned_tensor_attributes: + if len(assigned_tensor_attributes) > 1: + noun, verb = "attributes", "were" + else: + noun, verb = "attribute", "was" + raise ValueError( + f"The tensor {noun} {', '.join(assigned_tensor_attributes)} {verb} assigned during export. " + "Such attributes must be registered as buffers using the `register_buffer` API " + "(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)." + ) + + +compiled_function = aot_function +compiled_module = aot_module diff --git a/lib/python3.10/site-packages/torch/_functorch/apis.py b/lib/python3.10/site-packages/torch/_functorch/apis.py new file mode 100644 index 0000000000000000000000000000000000000000..d906f3c906c9893bfda830dbc6f9220106d0fe72 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_functorch/apis.py @@ -0,0 +1,449 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +# NOTE: We allow Dynamo to see this file (via torch/_dynamo/trace_rules.py) so that it can +# trace through functorch transforms. +# Currently, we can't allow Dynamo to see `eager_transforms.py`/`vmap.py` as that break a lot of thing +# and there isn't a mechanism to selectively expose only some functions (eg. grad) from a file +# to Dynamo. +import functools + +from torch._functorch.utils import argnums_t, exposed_in +from torch._functorch.vmap import ( + _check_out_dims_is_int_or_int_pytree, + _check_randomness_arg, + _chunked_vmap, + _process_batched_inputs, + Callable, + in_dims_t, + out_dims_t, + vmap_impl, +) + + +# vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors, +# sends those into func, and then unwraps the output BatchedTensors. Operations +# on BatchedTensors perform the batched operations that the user is asking for. +# +# vmap's randomness behavior differs from JAX's, which would require a PRNG key +# to be passed everywhere. 
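+#
+# Conceptually (an illustrative sketch, not real code):
+#
+#   vmap(f)(x)  ~=  unwrap_batched(f(BatchedTensor(x, bdim=0)))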
+ + +@exposed_in("torch.func") +def vmap( + func: Callable, + in_dims: in_dims_t = 0, + out_dims: out_dims_t = 0, + randomness: str = "error", + *, + chunk_size=None, +) -> Callable: + """ + vmap is the vectorizing map; ``vmap(func)`` returns a new function that + maps ``func`` over some dimension of the inputs. Semantically, vmap + pushes the map into PyTorch operations called by ``func``, effectively + vectorizing those operations. + + vmap is useful for handling batch dimensions: one can write a function + ``func`` that runs on examples and then lift it to a function that can + take batches of examples with ``vmap(func)``. vmap can also be used to + compute batched gradients when composed with autograd. + + .. note:: + :func:`torch.vmap` is aliased to :func:`torch.func.vmap` for + convenience. Use whichever one you'd like. + + Args: + func (function): A Python function that takes one or more arguments. + Must return one or more Tensors. + in_dims (int or nested structure): Specifies which dimension of the + inputs should be mapped over. ``in_dims`` should have a + structure like the inputs. If the ``in_dim`` for a particular + input is None, then that indicates there is no map dimension. + Default: 0. + out_dims (int or Tuple[int]): Specifies where the mapped dimension + should appear in the outputs. If ``out_dims`` is a Tuple, then + it should have one element per output. Default: 0. + randomness (str): Specifies whether the randomness in this + vmap should be the same or different across batches. If 'different', + the randomness for each batch will be different. If 'same', the + randomness will be the same across batches. If 'error', any calls to + random functions will error. Default: 'error'. WARNING: this flag + only applies to random PyTorch operations and does not apply to + Python's random module or numpy randomness. + chunk_size (None or int): If None (default), apply a single vmap over inputs. + If not None, then compute the vmap :attr:`chunk_size` samples at a time. + Note that :attr:`chunk_size=1` is equivalent to computing the vmap with a for-loop. + If you run into memory issues computing the vmap, please try a non-None chunk_size. + + Returns: + Returns a new "batched" function. It takes the same inputs as + ``func``, except each input has an extra dimension at the index + specified by ``in_dims``. It takes returns the same outputs as + ``func``, except each output has an extra dimension at the index + specified by ``out_dims``. + + .. warning: + :func:`vmap` works best with functional-style code. Please do not + perform any side-effects in ``func``, with the exception of + in-place PyTorch operations. Examples of side-effects include mutating + Python data structures and assigning values to variables not captured + in ``func``. + + One example of using :func:`vmap` is to compute batched dot products. PyTorch + doesn't provide a batched ``torch.dot`` API; instead of unsuccessfully + rummaging through docs, use :func:`vmap` to construct a new function. + + >>> torch.dot # [D], [D] -> [] + >>> batched_dot = torch.func.vmap(torch.dot) # [N, D], [N, D] -> [N] + >>> x, y = torch.randn(2, 5), torch.randn(2, 5) + >>> batched_dot(x, y) + + :func:`vmap` can be helpful in hiding batch dimensions, leading to a simpler + model authoring experience. 
+ + >>> batch_size, feature_size = 3, 5 + >>> weights = torch.randn(feature_size, requires_grad=True) + >>> + >>> def model(feature_vec): + >>> # Very simple linear model with activation + >>> return feature_vec.dot(weights).relu() + >>> + >>> examples = torch.randn(batch_size, feature_size) + >>> result = torch.vmap(model)(examples) + + :func:`vmap` can also help vectorize computations that were previously difficult + or impossible to batch. One example is higher-order gradient computation. + The PyTorch autograd engine computes vjps (vector-Jacobian products). + Computing a full Jacobian matrix for some function f: R^N -> R^N usually + requires N calls to ``autograd.grad``, one per Jacobian row. Using :func:`vmap`, + we can vectorize the whole computation, computing the Jacobian in a single + call to ``autograd.grad``. + + >>> # Setup + >>> N = 5 + >>> f = lambda x: x ** 2 + >>> x = torch.randn(N, requires_grad=True) + >>> y = f(x) + >>> I_N = torch.eye(N) + >>> + >>> # Sequential approach + >>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0] + >>> for v in I_N.unbind()] + >>> jacobian = torch.stack(jacobian_rows) + >>> + >>> # vectorized gradient computation + >>> def get_vjp(v): + >>> return torch.autograd.grad(y, x, v) + >>> jacobian = torch.vmap(get_vjp)(I_N) + + :func:`vmap` can also be nested, producing an output with multiple batched dimensions + + >>> torch.dot # [D], [D] -> [] + >>> batched_dot = torch.vmap(torch.vmap(torch.dot)) # [N1, N0, D], [N1, N0, D] -> [N1, N0] + >>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5) + >>> batched_dot(x, y) # tensor of size [2, 3] + + If the inputs are not batched along the first dimension, ``in_dims`` specifies + the dimension that each inputs are batched along as + + >>> torch.dot # [N], [N] -> [] + >>> batched_dot = torch.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D] + >>> x, y = torch.randn(2, 5), torch.randn(2, 5) + >>> batched_dot(x, y) # output is [5] instead of [2] if batched along the 0th dimension + + If there are multiple inputs each of which is batched along different dimensions, + ``in_dims`` must be a tuple with the batch dimension for each input as + + >>> torch.dot # [D], [D] -> [] + >>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N] + >>> x, y = torch.randn(2, 5), torch.randn(5) + >>> batched_dot(x, y) # second arg doesn't have a batch dim because in_dim[1] was None + + If the input is a Python struct, ``in_dims`` must be a tuple containing a struct + matching the shape of the input: + + >>> f = lambda dict: torch.dot(dict['x'], dict['y']) + >>> x, y = torch.randn(2, 5), torch.randn(5) + >>> input = {'x': x, 'y': y} + >>> batched_dot = torch.vmap(f, in_dims=({'x': 0, 'y': None},)) + >>> batched_dot(input) + + By default, the output is batched along the first dimension. However, it can be batched + along any dimension by using ``out_dims`` + + >>> f = lambda x: x ** 2 + >>> x = torch.randn(2, 5) + >>> batched_pow = torch.vmap(f, out_dims=1) + >>> batched_pow(x) # [5, 2] + + For any function that uses kwargs, the returned function will not batch the kwargs but will + accept kwargs + + >>> x = torch.randn([2, 5]) + >>> def fn(x, scale=4.): + >>> return x * scale + >>> + >>> batched_pow = torch.vmap(fn) + >>> assert torch.allclose(batched_pow(x), x * 4) + >>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5] + + .. note:: + vmap does not provide general autobatching or handle variable-length + sequences out of the box. 
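+
+    A small sketch of ``chunk_size`` (illustrative; any positive chunk size gives
+    the same values, trading peak memory for more sequential steps):
+
+    >>> xs = torch.randn(64, 5)
+    >>> per_row_sum = torch.vmap(torch.sum, chunk_size=16)
+    >>> per_row_sum(xs)  # same result as torch.vmap(torch.sum)(xs)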
+    """
+    from torch._dynamo import is_compiling
+
+    _check_randomness_arg(randomness)
+    if not (chunk_size is None or chunk_size > 0):
+        raise ValueError(
+            f"vmap: chunk_size should be None or greater than 0. (got {chunk_size})"
+        )
+
+    def wrapped(*args, **kwargs):
+        return vmap_impl(
+            func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
+        )
+
+    if not is_compiling():
+        wrapped = functools.wraps(func)(wrapped)
+
+    return wrapped
+
+
+def chunk_vmap(
+    func: Callable,
+    in_dims: in_dims_t = 0,
+    out_dims: out_dims_t = 0,
+    randomness: str = "error",
+    chunks=2,
+) -> Callable:
+    """
+    chunk_vmap is the vectorizing map (vmap) using chunks of input data. It is a mix of vmap (which vectorizes
+    everything) and map (which executes things sequentially). ``chunk_vmap`` splits the input into ``chunks``
+    chunks and vectorizes one chunk at a time. For more details about the vectorizing map, see :func:`vmap`.
+
+    .. note::
+        Please use :func:`vmap` with the ``chunk_size`` argument instead of this API.
+
+    Args:
+        func (function): A Python function that takes one or more arguments.
+            Must return one or more Tensors.
+        in_dims (int or nested structure): Specifies which dimension of the
+            inputs should be mapped over. ``in_dims`` should have a
+            structure like the inputs. If the ``in_dim`` for a particular
+            input is None, then that indicates there is no map dimension.
+            Default: 0.
+        out_dims (int or Tuple[int]): Specifies where the mapped dimension
+            should appear in the outputs. If ``out_dims`` is a Tuple, then
+            it should have one element per output. Default: 0.
+        randomness (str): Specifies whether the randomness in this
+            vmap should be the same or different across batches. If 'different',
+            the randomness for each batch will be different. If 'same', the
+            randomness will be the same across batches. If 'error', any calls to
+            random functions will error. Default: 'error'. WARNING: this flag
+            only applies to random PyTorch operations and does not apply to
+            Python's random module or numpy randomness.
+        chunks (int): Number of chunks to use to split the input data. Default is 2.
+            If equal to 1, then :func:`vmap` is called.
+
+    Returns:
+        Returns a new "batched" function. It takes the same inputs as
+        ``func``, except each input has an extra dimension at the index
+        specified by ``in_dims``. It returns the same outputs as
+        ``func``, except each output has an extra dimension at the index
+        specified by ``out_dims``.
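+
+    A rough usage sketch (illustrative only; prefer :func:`vmap` with ``chunk_size``):
+
+    >>> x, y = torch.randn(8, 5), torch.randn(8, 5)
+    >>> batched_dot = chunk_vmap(torch.dot, chunks=4)
+    >>> batched_dot(x, y)  # 8 dot products, computed 2 rows at a time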
+ """ + _check_randomness_arg(randomness) + + if chunks == 1: + return vmap(func, in_dims=in_dims, out_dims=out_dims, randomness=randomness) + + def _get_chunk_flat_args(flat_args_, flat_in_dims_, chunks_): + flat_args_chunks = tuple( + t.chunk(chunks_, dim=in_dim) + if in_dim is not None + else [ + t, + ] + * chunks_ + for t, in_dim in zip(flat_args_, flat_in_dims_) + ) + # transpose chunk dim and flatten structure + # chunks_flat_args is a list of flatten args + chunks_flat_args = zip(*flat_args_chunks) + return chunks_flat_args + + @functools.wraps(func) + def wrapped_with_chunks(*args, **kwargs): + _check_out_dims_is_int_or_int_pytree(out_dims, func) + _, flat_in_dims, flat_args, args_spec = _process_batched_inputs( + in_dims, args, func + ) + # Chunk flat arguments + chunks_flat_args = _get_chunk_flat_args(flat_args, flat_in_dims, chunks) + + # Apply vmap on chunks + return _chunked_vmap( + func, + flat_in_dims, + chunks_flat_args, + args_spec, + out_dims, + randomness, + **kwargs, + ) + + return wrapped_with_chunks + + +@exposed_in("torch.func") +def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable: + """``grad`` operator helps computing gradients of ``func`` with respect to the + input(s) specified by ``argnums``. This operator can be nested to + compute higher-order gradients. + + Args: + func (Callable): A Python function that takes one or more arguments. + Must return a single-element Tensor. If specified ``has_aux`` equals ``True``, + function can return a tuple of single-element Tensor and other auxiliary objects: + ``(output, aux)``. + argnums (int or Tuple[int]): Specifies arguments to compute gradients with respect to. + ``argnums`` can be single integer or tuple of integers. Default: 0. + has_aux (bool): Flag indicating that ``func`` returns a tensor and other + auxiliary objects: ``(output, aux)``. Default: False. + + Returns: + Function to compute gradients with respect to its inputs. By default, the output of + the function is the gradient tensor(s) with respect to the first argument. + If specified ``has_aux`` equals ``True``, tuple of gradients and output auxiliary objects + is returned. If ``argnums`` is a tuple of integers, a tuple of output gradients with + respect to each ``argnums`` value is returned. 
+ + Example of using ``grad``: + + >>> # xdoctest: +SKIP + >>> from torch.func import grad + >>> x = torch.randn([]) + >>> cos_x = grad(lambda x: torch.sin(x))(x) + >>> assert torch.allclose(cos_x, x.cos()) + >>> + >>> # Second-order gradients + >>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x) + >>> assert torch.allclose(neg_sin_x, -x.sin()) + + When composed with ``vmap``, ``grad`` can be used to compute per-sample-gradients: + + >>> # xdoctest: +SKIP + >>> from torch.func import grad, vmap + >>> batch_size, feature_size = 3, 5 + >>> + >>> def model(weights, feature_vec): + >>> # Very simple linear model with activation + >>> assert feature_vec.dim() == 1 + >>> return feature_vec.dot(weights).relu() + >>> + >>> def compute_loss(weights, example, target): + >>> y = model(weights, example) + >>> return ((y - target) ** 2).mean() # MSELoss + >>> + >>> weights = torch.randn(feature_size, requires_grad=True) + >>> examples = torch.randn(batch_size, feature_size) + >>> targets = torch.randn(batch_size) + >>> inputs = (weights, examples, targets) + >>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs) + + Example of using ``grad`` with ``has_aux`` and ``argnums``: + + >>> # xdoctest: +SKIP + >>> from torch.func import grad + >>> def my_loss_func(y, y_pred): + >>> loss_per_sample = (0.5 * y_pred - y) ** 2 + >>> loss = loss_per_sample.mean() + >>> return loss, (y_pred, loss_per_sample) + >>> + >>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True) + >>> y_true = torch.rand(4) + >>> y_preds = torch.rand(4, requires_grad=True) + >>> out = fn(y_true, y_preds) + >>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample)) + + .. note:: + Using PyTorch ``torch.no_grad`` together with ``grad``. + + Case 1: Using ``torch.no_grad`` inside a function: + + >>> # xdoctest: +SKIP + >>> def f(x): + >>> with torch.no_grad(): + >>> c = x ** 2 + >>> return x - c + + In this case, ``grad(f)(x)`` will respect the inner ``torch.no_grad``. + + Case 2: Using ``grad`` inside ``torch.no_grad`` context manager: + + >>> # xdoctest: +SKIP + >>> with torch.no_grad(): + >>> grad(f)(x) + + In this case, ``grad`` will respect the inner ``torch.no_grad``, but not the + outer one. This is because ``grad`` is a "function transform": its result + should not depend on the result of a context manager outside of ``f``. + + """ + # To avoid cyclical dependency. + import torch._functorch.eager_transforms as eager_transforms + from torch._dynamo import is_compiling + + def wrapper(*args, **kwargs): + return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs) + + if not is_compiling(): + wrapper = functools.wraps(func)(wrapper) + + return wrapper + + +@exposed_in("torch.func") +def grad_and_value( + func: Callable, argnums: argnums_t = 0, has_aux: bool = False +) -> Callable: + """ + Returns a function to compute a tuple of the gradient and primal, or + forward, computation. + + Args: + func (Callable): A Python function that takes one or more arguments. + Must return a single-element Tensor. If specified ``has_aux`` + equals ``True``, function can return a tuple of single-element + Tensor and other auxiliary objects: ``(output, aux)``. + argnums (int or Tuple[int]): Specifies arguments to compute gradients + with respect to. ``argnums`` can be single integer or tuple of + integers. Default: 0. + has_aux (bool): Flag indicating that ``func`` returns a tensor and + other auxiliary objects: ``(output, aux)``. Default: False. 
+ + Returns: + Function to compute a tuple of gradients with respect to its inputs + and the forward computation. By default, the output of the function is + a tuple of the gradient tensor(s) with respect to the first argument + and the primal computation. If specified ``has_aux`` equals + ``True``, tuple of gradients and tuple of the forward computation with + output auxiliary objects is returned. If ``argnums`` is a tuple of + integers, a tuple of a tuple of the output gradients with respect to + each ``argnums`` value and the forward computation is returned. + + See :func:`grad` for examples + """ + from torch._dynamo import is_compiling + from torch._functorch import eager_transforms + + def wrapper(*args, **kwargs): + return eager_transforms.grad_and_value_impl( + func, argnums, has_aux, args, kwargs + ) + + if not is_compiling(): + wrapper = functools.wraps(func)(wrapper) + + return wrapper diff --git a/lib/python3.10/site-packages/torch/_functorch/autograd_function.py b/lib/python3.10/site-packages/torch/_functorch/autograd_function.py new file mode 100644 index 0000000000000000000000000000000000000000..cb501e2c924213c5d31baa5f7e256920ce3351c2 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_functorch/autograd_function.py @@ -0,0 +1,752 @@ +# mypy: allow-untyped-defs +from typing import Any, NamedTuple, Tuple + +import torch +import torch.utils._pytree as pytree +from torch._C._functorch import ( + _unwrap_for_grad, + _wrap_for_grad, + current_level, + TransformType, +) +from torch._functorch.apis import vmap +from torch._functorch.utils import enable_single_level_autograd_function +from torch._functorch.vmap import ( + _add_batch_dim, + _broadcast_to_and_flatten, + restore_vmap, + unwrap_batched, + wrap_batched, +) +from torch._ops import HigherOrderOperator +from torch.autograd.forward_ad import _set_fwd_grad_enabled + + +# autograd.Function technically runs before the regular PyTorch dispatcher. +# This is how features like autocast and torch_dispatch (e.g. PythonTLSSnapshot) +# work with it. One day we might decide to change this, but until then, +# we need to give the illusion that autograd.Function runs before those things. +# +# We do this by using creating a custom HigherOrderOperator that only functorch +# dispatches specially. +class CustomFunctionHigherOrderOperator(HigherOrderOperator): + def __init__(self) -> None: + super().__init__("custom_function_call") + + def __call__(self, autograd_function, *args, **kwargs): + # When custom_function_call is done dispatching through functorch, + # it should just invoke the autograd.Function. This is consistent + # with the autograd.Function behavior of being invoked before the + # PyTorch dispatcher. + # + # This will lead us into trouble later down the line, but this is + # pre-existing. There is an invariant that a function traced by + # make_fx should have the same behavior when provided the same + # Tensor. However, make_fx sees autograd.Function as a composite + # (because autograd.Function happens before the Python dispatch key) + # and only traces the forward pass. + if torch._C._are_functorch_transforms_active(): + return super().__call__(autograd_function, *args, **kwargs) + return autograd_function.apply(*args, **kwargs) + + +# "custom_function_call" +# This is the mechanism for an autograd.Function that works with functorch transforms. 
+# It wraps an autograd.Function; interactions with functorch transforms are defined +# via PyDispatcher and HigherOrderOperator rather than through the traditional PyTorch +# dispatcher. +custom_function_call = CustomFunctionHigherOrderOperator() + + +# The grad rule for custom_function_call is to construct a new _SingleLevelFunction +# (autograd.Function that only works with a single layer (level) of functorch) that: +# - unwraps the inputs +# - redispatches to custom_function_call +# - wraps the outputs +# and whose backward pass calls the original autograd.Function's backward. +# +# Why do we need to redispatch to custom_function_call? +# ----------------------------------------------------- +# This is consistent with how ATen operators work with functorch's grad transform: +# they always redispatch to the original operator. +# Consider torch.sin, and let's say we do grad0(grad1(torch.sin))(x) +# +# grad1 will: +# - set up the autograd graph +# - unwrap the inputs +# - redispatch to at::sin (*) +# - rewrap the outputs on the return +# +# On the redispatch in (*), grad0 will: +# - set up the autograd graph +# - unwrap the inputs +# - redispatch to at::sin +# - rewrap the outputs on the return +# +# To "set up the autograd graph", we generate a _SingleLevelFunction +# and apply it. +@custom_function_call.py_impl(TransformType.Grad) +@custom_function_call.py_impl(TransformType.Jvp) +def custom_function_call_grad(interpreter, autograd_function, *operands): + Generated = generate_single_level_function(interpreter, autograd_function) + with enable_single_level_autograd_function(): + flat_out = Generated.apply(*operands) + return flat_out + + +def generate_single_level_function(interpreter, autograd_function): + level = interpreter.level() + + def forward(*operands): + unwrapped_operands = pytree.tree_map_only( + torch.Tensor, lambda x: _unwrap_for_grad(x, level), operands + ) + # Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter + # the transform. _SingleLevelFunction will turn off both fwd and bwd + # gradient computation and we need to turn it back on here. + with torch.enable_grad(), _set_fwd_grad_enabled(True), interpreter.lower(): + unwrapped_output = custom_function_call( + autograd_function, *unwrapped_operands + ) + + # See NOTE [mark_dirty object identity check] + def wrap_fn(output): + return _wrap_for_grad(output, level) + + return wrap_outputs_maintaining_identity( + unwrapped_output, unwrapped_operands, operands, wrap_fn + ) + + def setup_context(ctx, inputs, output): + return autograd_function.setup_context(ctx, inputs, output) + + # backward is only used if the transform is TransformType.Grad + def backward(ctx, *grads): + result = autograd_function.backward(ctx, *grads) + return result + + # jvp is only used if the transform is TransformType.Jvp + def jvp(ctx, *tangents): + result = autograd_function.jvp(ctx, *tangents) + return result + + # This is the sequence of magic words to dynamically generate a Subclass with + # a given name. A Tensor's .grad_fn field has a class name that is the original + # autograd.Function's name + Backward, so we do this to generate some + # meaningful name. 
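+    # For example (illustrative): wrapping a user-defined `class MySin(autograd.Function)`
+    # yields a generated class named "MySinGenerated", so tensors produced by it show a
+    # grad_fn like <MySinGeneratedBackward> instead of an unhelpful generic name.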
+    name = f"{autograd_function.__name__}Generated"
+    Generated = type(
+        name,
+        (torch.autograd.function._SingleLevelFunction,),
+        {
+            "forward": staticmethod(forward),
+            "backward": staticmethod(backward),
+            "jvp": staticmethod(jvp),
+            "setup_context": staticmethod(setup_context),
+        },
+    )
+    return Generated
+
+
+# wrap_outputs_maintaining_identity handles outputs from the vmap,
+# backward (vjp), and jvp staticmethod. The way it distinguishes
+# between the vmap case and the {backward, jvp} case is whether the
+# out_dims are specified or not.
+#
+# NB: we cannot use out_dims=None as the deciding factor. This is because
+# out_dims=None can still happen in the vmap staticmethod! What the
+# user is saying in that case is that their output does not have a
+# dimension that is being vmapped over, which is valid.
+NO_OUT_DIMS = "not specified"
+
+
+# NOTE [mark_dirty object identity check]
+# autograd.Function's ctx.mark_dirty expects a returned input
+# to have the same object identity as the input.
+# Mode-only functorch will greatly simplify this logic.
+def wrap_outputs_maintaining_identity(
+    outputs, unwrapped_inputs, orig_inputs, wrap_fn, out_dims=NO_OUT_DIMS
+):
+    flat_unwrapped_inputs = pytree.arg_tree_leaves(*unwrapped_inputs)
+    flat_orig_inputs = pytree.arg_tree_leaves(*orig_inputs)
+
+    unwrapped_input_to_orig_input = {
+        id(unwrapped): orig
+        for unwrapped, orig in zip(flat_unwrapped_inputs, flat_orig_inputs)
+    }
+
+    flat_outputs, spec = pytree.tree_flatten(outputs)
+    result = []
+
+    out_dims_specified = out_dims != NO_OUT_DIMS
+
+    if out_dims_specified:
+        flat_out_dims = _broadcast_to_and_flatten(out_dims, spec)
+        # _broadcast_to_and_flatten returns None if it is unable to broadcast.
+        # TODO: update following link from master to stable once that's out
+        if flat_out_dims is None:
+            raise RuntimeError(
+                f"The autograd.Function's vmap staticmethod returned an "
+                f"incompatible (output, out_dims) tuple. "
+                f"Expected out_dims={out_dims} "
+                f"to be compatible with the structure of `output`. "
+                f"out_dims has structure {pytree.tree_flatten(out_dims)[1]} "
+                f"but output has structure {spec}. "
+                f"For more details, please see "
+                f"https://pytorch.org/docs/main/notes/extending.func.html"
+            )
+
+    for i, output in enumerate(flat_outputs):
+        if not isinstance(output, torch.Tensor):
+            result.append(output)
+            continue
+        if id(output) in unwrapped_input_to_orig_input:
+            result.append(unwrapped_input_to_orig_input[id(output)])
+            continue
+        if out_dims_specified:
+            result.append(wrap_fn(output, flat_out_dims[i]))  # type: ignore[possibly-undefined, index]
+        else:
+            result.append(wrap_fn(output))
+
+    return pytree.tree_unflatten(result, spec)
+
+
+# NOTE: [functorch vjp and autograd interaction]
+# There's an edge case with the functorch vjp and autograd interaction
+# that will eventually be fixed by mode-only functorch.
+# The TL;DR is that there's no way to unwrap a dead GradTensorWrapper,
+# so we (the framework) need to do it manually. Regular PyTorch operators
+# automatically do so, so doing it manually here keeps things consistent.
+# +# class MyExp(torch.autograd.Function): +# @staticmethod +# def forward(x): +# return x.exp() +# +# @staticmethod +# def setup_context(ctx, inputs, output): +# y = output +# ctx.save_for_backward(y) +# +# @staticmethod +# def backward(gy): +# y, = ctx.saved_tensors() +# return MyMul.apply(gy, y) +# +# x = torch.randn([], requires_grad=True) +# gy = torch.randn([], requires_grad=True) +# _, vjp_fn = vjp(MySin.apply, x) +# result = vjp_fn(gy) +# +# MyMul is an autograd.Function that is not shown here. +# It saves a `y` for backward (since gy requires grad). +# +# in vjp_fn(gy), we get: +# > MyMul.apply(gy, GradTensorWrapper(y, level=dead)) +# Because the y that is saved for backward by MyExp is a GradTensorWrapper +# but is now dead since we are outside the vjp context. +# +# PyTorch dispatcher operations, upon seeing a dead GradTensorWrapper, +# will automatically unwrap the GradTensorWrapper when applied. +# But since autograd.Function technically sits above the regular PyTorch +# dispatcher, it doesn't get this treatment. So we manually do +# the unwrapping to be consistent with regular PyTorch dispatcher operations. + + +class VmapInfo(NamedTuple): + batch_size: int + randomness: str + + +def has_overriden_vmap_rule(autograd_function): + return autograd_function.vmap is not torch.autograd.Function.vmap + + +def validate_vmap_returns_tuple_of_two_elements(result): + base_error_msg = ( + "Expected the vmap staticmethod to have two returns, an output " + "and out_dims with pytree structure compatible with the output. " + ) + if not isinstance(result, tuple): + raise RuntimeError(base_error_msg + f"Got a {type(result)} instead") + if not len(result) == 2: + raise RuntimeError(base_error_msg + f"Got {len(result)} returns instead") + + +@custom_function_call.py_impl(TransformType.Vmap) +def custom_function_call_vmap(interpreter, autograd_function, *operands, **kwargs): + if any( + isinstance(val, torch.Tensor) + for val in torch.utils._pytree.tree_flatten(kwargs)[0] + ): + raise NotImplementedError( + f"Run vmap on autograd.Function with kwarg-only Tensor args. " + f"Please do not pass kwarg-only Tensors to autograd.Function. " + f"Got: {kwargs}" + ) + + if autograd_function.generate_vmap_rule: + if has_overriden_vmap_rule(autograd_function): + # TODO: Update link to stable once that's out + # https://github.com/pytorch/pytorch/issues/92029 + raise RuntimeError( + f"You tried to vmap over {autograd_function.__name__}, but " + f"it has both generate_vmap_rule=True and an overriden vmap " + f"staticmethod. Please set generate_vmap_rule=False or delete " + f"the overriden vmap staticmethod to avoid ambiguity. " + f"For more details, please see " + f"https://pytorch.org/docs/main/notes/extending.func.html" + ) + return custom_function_call_vmap_generate_rule( + interpreter, autograd_function, *operands + ) + + if not has_overriden_vmap_rule(autograd_function): + # TODO: Update link to stable once that's out + # https://github.com/pytorch/pytorch/issues/92029 + raise RuntimeError( + f"You tried to vmap over {autograd_function.__name__}, but " + f"it does not have vmap support. Please override and implement the " + f"vmap staticmethod or set generate_vmap_rule=True. 
" + f"For more details, please see " + f"https://pytorch.org/docs/main/notes/extending.func.html" + ) + + return custom_function_call_vmap_helper( + interpreter, autograd_function.vmap, autograd_function, *operands, **kwargs + ) + + +def custom_function_call_vmap_helper( + interpreter, vmap_function, op, *operands, **kwargs +): + current_level = interpreter.level() + info = VmapInfo( + batch_size=interpreter.batch_size(), + randomness=interpreter.randomness(), + ) + unwrapped_operands, in_dims = unwrap_batched(operands, current_level) + # If none of the tensors are batched at the current level, then we skip the + # current level. This saves the user from needing to handle this case in + # their vmap staticmethod (and is consistent with our C++ batching rule API) + if pytree.tree_all(lambda dim: dim is None, in_dims): + with interpreter.lower(): + if isinstance(op, torch.autograd.function.FunctionMeta): + return custom_function_call(op, *operands) + else: + return op(*operands, **kwargs) + + with interpreter.lower(): + result = vmap_function(info, in_dims, *unwrapped_operands, **kwargs) + validate_vmap_returns_tuple_of_two_elements(result) + unwrapped_output, out_dims = result + + # See NOTE [mark_dirty object identity check] + def wrap_fn(output, out_dim): + return ( + output + if out_dim is None + else _add_batch_dim(output, out_dim, current_level) + ) + + return wrap_outputs_maintaining_identity( + unwrapped_output, unwrapped_operands, operands, wrap_fn, out_dims=out_dims + ) + + +def custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands): + unwrapped_operands, in_dims = unwrap_batched(operands, interpreter.level()) + vmapped_function, get_out_dims = vmapify_autograd_function( + autograd_function, in_dims, interpreter.batch_size(), interpreter.randomness() + ) + + with interpreter.lower(): + output = custom_function_call(vmapped_function, *unwrapped_operands) + + out_dims = get_out_dims() + return wrap_batched(output, out_dims, interpreter.level()) + + +@custom_function_call.py_impl(TransformType.Functionalize) +def custom_function_call_functionalize( + interpreter, autograd_function, generate_vmap_rule, *operands +): + raise RuntimeError("NYI: Functionalize rule for custom_function_call") + + +def vmapify_autograd_function(autograd_function, in_dims, batch_size, randomness): + # The following values are saved from the forward() and setup_context() + # and used in backward(). + # Why do we save the values out here instead of on the ctx object? + # - out_dims: There's no way to retrieve this from forward() + # - input_shapes, saved_tensors_bdims: I'm a bit scared of nesting + # vmap(vmap( but not completely sure if it is a problem. If we + # assigned those fields to the ctx object, the worry is that they + # get overwritten. 
+ init_val = "not populated" + out_dims = init_val + input_shapes: Any = init_val + saved_tensors_bdims: Any = init_val + + def forward(*operands): + nonlocal out_dims + outputs, out_dims = restore_vmap( + autograd_function.forward, in_dims, batch_size, randomness + )(*operands) + return outputs + + def setup_context(ctx, inputs, outputs): + input_shapes_ = None + saved_tensors_bdims_ = None + + def inner(inputs, outputs): + # wrapped_ctx.save_for_backward will: + # - unwrap batchedtensors into (tensor, bdim) + # - save_for_backward(*unwrapped_tensors) + # - assign the bdims to wrapped_ctx._pt_saved_tensors_bdims + wrapped_ctx = CtxCustomSave(ctx, current_level()) + autograd_function.setup_context(wrapped_ctx, inputs, outputs) + + # input_shapes are used for reductify later to reduce expanded gradients + # to the correct shape. + # See NOTE: [Why can't we rely on autograd to reduce expanded gradients?] + # for more details + nonlocal input_shapes_ + input_shapes_ = tuple( + inp.shape if isinstance(inp, torch.Tensor) else None for inp in inputs + ) + nonlocal saved_tensors_bdims_ + saved_tensors_bdims_ = wrapped_ctx._pt_saved_tensors_bdims + + # See NOTE: [Why do we need to run setup_context under a vmap?] + restore_vmap( + inner, + (in_dims, out_dims), + batch_size, + randomness, + )(inputs, outputs) + + nonlocal input_shapes + input_shapes = input_shapes_ + nonlocal saved_tensors_bdims + saved_tensors_bdims = saved_tensors_bdims_ + + def jvp(ctx, *tangents): + assert out_dims != init_val + assert saved_tensors_bdims != init_val + + def jvp_no_context(saved_tensors, tangents): + wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors) + return autograd_function.jvp(wrapped_ctx, *tangents) + + tangent_in_dims = get_tangents_in_dims(in_dims, tangents) + out_tangents, out_tangents_dims = restore_vmap( + jvp_no_context, + (saved_tensors_bdims, tangent_in_dims), + batch_size, + randomness, + )(ctx.saved_tensors, tangents) + + result = reductify(out_tangents, out_tangents_dims, out_dims, batch_size) + return result + + def backward(ctx, *grad_outputs): + assert out_dims != init_val + assert input_shapes != init_val + assert saved_tensors_bdims != init_val + + def backward_no_context(inputs): + saved_tensors, grad_outputs = inputs + wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors) + return autograd_function.backward(wrapped_ctx, *grad_outputs) + + grad_ins, grad_ins_dims = restore_vmap( + backward_no_context, + ((saved_tensors_bdims, out_dims),), + batch_size, + randomness, + )((ctx.saved_tensors, grad_outputs)) + result = reductify(grad_ins, grad_ins_dims, in_dims, batch_size, input_shapes) + return result + + name = f"Vmapped{autograd_function.__name__}" + Generated = type( + name, + (torch.autograd.Function,), + { + "forward": staticmethod(forward), + "backward": staticmethod(backward), + "jvp": staticmethod(jvp), + "setup_context": staticmethod(setup_context), + "generate_vmap_rule": True, + }, + ) + + def get_out_dims(): + assert out_dims != init_val + return out_dims + + return Generated, get_out_dims + + +# tangents might be None, so we need to replace +# the corresponding in_dims with None. +def get_tangents_in_dims(input_dims, tangents): + flat_in_dims, spec = pytree.tree_flatten(input_dims) + flat_tangents = pytree.arg_tree_leaves(*tangents) + result = [ + None if tangent is None else in_dim + for in_dim, tangent in zip(flat_in_dims, flat_tangents) + ] + return pytree.tree_unflatten(result, spec) + + +# NOTE: [Why do we need to run setup_context under a vmap?] 
+# Consider the following autograd.Function +# +# class Sum(torch.autograd.Function): +# @staticmethod +# def forward(x): +# return x.sum() +# @staticmethod +# def setup_context(ctx, inputs, outputs): +# ctx.x_shape = inputs[0] +# @staticmethod +# def backward(ctx, gy): +# return gy.expand(ctx.x_shape) +# +# x = torch.randn(B, 4) +# in_dims = 0 +# vmap(Sum.apply, in_dims)(x) +# +# Let's assume for a moment that we didn't vmap setup_context in VmappedSum: +# +# class VmappedSum(torch.autograd.Function): +# @staticmethod +# def forward(x): +# return vmap(Sum.forward, in_dims)(x) +# +# @staticmethod +# def setup_context(ctx, inputs, outputs): +# Sum.setup_context(ctx, inputs, outputs) +# +# @staticmethod +# def backward(ctx, gy): +# def backward_no_context(gy): +# return gy.expand(ctx.x_shape) +# +# dims = (0,) +# gx = vmap(backward_no_context, dims)(gy) +# return gx +# +# We end up saving [B, 4] as x_shape. In the backward, gy has shape [B], +# and we're doing: +# +# def backward_no_context(gy): +# return gy.expand([B, 4]) +# +# gx = vmap(backward_no_context, dims)(gy: "Tensor[B]") +# +# This gives us the wrong result (gx has shape [B, B, 4], but it should +# have shape [4]). Performing vmap over setup_context means the shape +# saved has shape [4] and leads to a correct result shape for gx. + + +# Wraps a ctx object. Forwards all attr accesses to the underlying object +# except for the attrs in _pt_attrs +class WrappedCtx: + _pt_reserved_attrs: Tuple[str, ...] = ("_pt_reserved_attrs", "_pt_inner_ctx") + + def __init__(self, ctx): + if not isinstance(ctx, WrappedCtx): + reserved_attrs = type(self)._pt_reserved_attrs + for name in reserved_attrs: + if not hasattr(ctx, name): + continue + raise RuntimeError( + f"PyTorch reserves the {reserved_attrs} field on ctx. " + "Please name your fields on ctx something else to avoid name " + "collision." + ) + self._pt_inner_ctx = ctx + + def __getattr__(self, name): + return getattr(self._pt_inner_ctx, name) + + def __setattr__(self, name, value): + if name in type(self)._pt_reserved_attrs: + self.__dict__[name] = value + return + return setattr(self._pt_inner_ctx, name, value) + + +# Wraps ctx to create a new ctx object that overrides saved_tensors. 
+class CtxWithSavedTensors(WrappedCtx): + _pt_reserved_attrs = ("_pt_new_saved_tensors", *WrappedCtx._pt_reserved_attrs) + + def __init__(self, ctx, new_saved_tensors): + super().__init__(ctx) + self._pt_new_saved_tensors = new_saved_tensors + + @property + def saved_tensors(self): + return self._pt_new_saved_tensors + + +class CtxCustomSave(WrappedCtx): + _pt_reserved_attrs = ( + "_pt_saved_tensors_bdims", + "_pt_current_level", + *WrappedCtx._pt_reserved_attrs, + ) + + def __init__(self, ctx, current_level): + super().__init__(ctx) + self._pt_saved_tensors_bdims = () + self._pt_current_level = current_level + + def save_for_backward(self, *tensors): + unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level) + self._pt_inner_ctx.save_for_backward(*unwrapped_tensors) + self._pt_saved_tensors_bdims = bdims + + def save_for_forward(self, *tensors): + unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level) + self._pt_inner_ctx.save_for_forward(*unwrapped_tensors) + self._pt_saved_tensors_bdims = bdims + + +def reductify( + grad_input, + grad_input_bdim, + input_bdim, + batch_size, + target_shape_without_bdim_to_reduce_to=None, +): + if not isinstance(grad_input, tuple): + grad_input = (grad_input,) + if not isinstance(grad_input_bdim, tuple): + grad_input_bdim = (grad_input_bdim,) + if not isinstance(input_bdim, tuple): + input_bdim = (input_bdim,) + + if target_shape_without_bdim_to_reduce_to is None: + target_shape_without_bdim_to_reduce_to = len(grad_input) * (None,) + result = tuple( + reductify_leaf(gi, gi_bdim, i_bdim, batch_size, maybe_ishape) + for gi, gi_bdim, i_bdim, maybe_ishape in zip( + grad_input, + grad_input_bdim, + input_bdim, + target_shape_without_bdim_to_reduce_to, + ) + ) + return result + + +def reductify_leaf( + grad_input, + grad_input_bdim, + input_bdim, + batch_size, + target_shape_without_bdim_to_reduce_to=None, +): + if grad_input is None: + return None + + if grad_input_bdim is None and input_bdim is None: + return grad_input + + if grad_input_bdim is not None and input_bdim is None: + return grad_input.sum(grad_input_bdim) + + # NOTE: [Why can't we rely on autograd to reduce expanded gradients?] + # For reverse-mode AD, + # given a grad_input and input, it is valid for the user to return a + # grad_input that has a broadcasted shape when compared to the input. + # In this situation, autograd automatically reduces the grad_input to + # the shape of the input. + # + # However, when input_bdim is not None, we have problems. + # + # [example 1] + # grad_input: Tensor[3, 4], input: Tensor[B, 4] + # We can expand grad_input to Tensor[B, 3, 4], but that isn't broadcastable + # from [B, 4]. + # + # [example 2] + # grad_input: Tensor[3, B, 4], input: Tensor[B, 4] + # We can swizzle grad_input to Tensor[B, 3, 4], but that isn't broadcastable + # from [B, 4]. + # + # This means that we need to also reduce the grad_input to the shape of the + # input. This behavior is controlled by the `target_shape_without_bdim_to_reduce_to` flag; + # if not-None then we do the reducing manually, otherwise, we do not do a reduction. 
+ assert input_bdim is not None + + if grad_input_bdim is None: + grad_input = grad_input.unsqueeze(input_bdim) + new_shape = list(grad_input.shape) + new_shape[input_bdim] = batch_size + grad_input = grad_input.expand(new_shape) + grad_input_bdim = input_bdim + + if target_shape_without_bdim_to_reduce_to is not None: + return vmap( + torch.Tensor.sum_to_size, + in_dims=(grad_input_bdim, None), + out_dims=input_bdim, + )(grad_input, target_shape_without_bdim_to_reduce_to) + + if input_bdim != grad_input_bdim: + grad_input = grad_input.movedim(grad_input_bdim, input_bdim) + return grad_input + + +def autograd_function_forward_rewritten(original_forward, original_setup_context): + def new_forward(ctx, *args, **kwargs): + output = original_forward(*args, **kwargs) + original_setup_context(ctx, args, output) + return output + + return new_forward + + +class AutogradFunctionApply(HigherOrderOperator): + def __init__(self) -> None: + super().__init__("autograd_function_apply") + + def __call__(self, fwd, bwd, *fwd_args, **fwd_kwargs): + saved_values = None + args_tensor_mask = fwd_kwargs["args_tensor_mask"] + non_differentiable_idx = fwd_kwargs["non_differentiable_idx"] + length_of_tensor_args = sum(args_tensor_mask) + # Filter out the original tensor args from fwd_args, + # lifted freevars should not be args of ApplyTemplate.apply + # since we don't need to calculate the gradients of them. + new_fwd_args = fwd_args[:length_of_tensor_args] + + class ApplyTemplate(torch.autograd.Function): + @staticmethod + def forward(ctx, *args): + nonlocal saved_values + output, saved_values = fwd(None, *fwd_args) + + # If users call ctx.mark_non_differentiable() in the original fwd function. + if len(non_differentiable_idx) > 0: + non_differentiable_output = [] + for i, x in enumerate(output): + if i in non_differentiable_idx: + non_differentiable_output.append(x) + ctx.mark_non_differentiable(*non_differentiable_output) + + return output + + @staticmethod + def backward(ctx, *grad): + return bwd(None, *grad, *saved_values) + + return ApplyTemplate.apply(*new_fwd_args) + + +autograd_function_apply = AutogradFunctionApply() diff --git a/lib/python3.10/site-packages/torch/_functorch/batch_norm_replacement.py b/lib/python3.10/site-packages/torch/_functorch/batch_norm_replacement.py new file mode 100644 index 0000000000000000000000000000000000000000..90e4fec99b554a4f3f267caf7c2a0873618fb44d --- /dev/null +++ b/lib/python3.10/site-packages/torch/_functorch/batch_norm_replacement.py @@ -0,0 +1,29 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import torch.nn as nn +from torch._functorch.utils import exposed_in + + +def batch_norm_without_running_stats(module: nn.Module): + if ( + isinstance(module, nn.modules.batchnorm._BatchNorm) + and module.track_running_stats + ): + module.running_mean = None + module.running_var = None + module.num_batches_tracked = None + module.track_running_stats = False + + +@exposed_in("torch.func") +def replace_all_batch_norm_modules_(root: nn.Module) -> nn.Module: + """ + In place updates :attr:`root` by setting the ``running_mean`` and ``running_var`` to be None and + setting track_running_stats to be False for any nn.BatchNorm module in :attr:`root` + """ + # base case + batch_norm_without_running_stats(root) + + for obj in root.modules(): + batch_norm_without_running_stats(obj) + return root diff --git a/lib/python3.10/site-packages/torch/_functorch/benchmark_utils.py b/lib/python3.10/site-packages/torch/_functorch/benchmark_utils.py new file mode 100644 index 
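A usage sketch for `replace_all_batch_norm_modules_` defined just above (the small network is only a placeholder): the call mutates the module in place so BatchNorm layers stop tracking running statistics, which is the usual prerequisite for running them under `vmap`/`grad` transforms.

```
import torch
import torch.nn as nn
from torch.func import replace_all_batch_norm_modules_

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
replace_all_batch_norm_modules_(net)  # in place; also returns net

bn = net[1]
assert bn.track_running_stats is False
assert bn.running_mean is None and bn.running_var is None
```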
0000000000000000000000000000000000000000..e0bcae4c836e9331f01fda7b27ec475f10d8a00b --- /dev/null +++ b/lib/python3.10/site-packages/torch/_functorch/benchmark_utils.py @@ -0,0 +1,231 @@ +# mypy: ignore-errors + +import contextlib +import json +import operator +import os +import time + +import torch +from torch.profiler import profile, ProfilerActivity + + +def synchronize(): + pass + + +def dump_chrome_trace( + f, + input, + trace_filename, + optimize_ctx, + activities, + num_runs=1, + devices=None, + kwargs_for_f=None, + kwargs_for_profiler=None, +): + """ + Output the chrome trace of running f(input, **kwargs_for_f) with [optimize_ctx] + [num_runs] times to [trace_filename]. + + [activities] are the activities that the profiler will record, e.g. ProfilerActivity.CUDA. + Return total runtime without the profiler + + Outputs to trace_filename + """ + + if devices is None: + devices = ["cuda"] + + global synchronize + if devices != ["cpu"] and torch.cuda.is_available(): + synchronize = torch.cuda.synchronize + + if kwargs_for_f is None: + kwargs_for_f = {} + if kwargs_for_profiler is None: + kwargs_for_profiler = {} + + with optimize_ctx: + torch.manual_seed(1337) + for _ in range(5): # warmup runs + f(input, **kwargs_for_f) + synchronize() + torch.manual_seed(1337) + t0 = time.perf_counter() + for _ in range(num_runs): + f(input, **kwargs_for_f) + synchronize() + t1 = time.perf_counter() + timing = t1 - t0 + + with profile(activities=activities, **kwargs_for_profiler) as prof: + with optimize_ctx: + synchronize() + torch.manual_seed(1337) + for _ in range(num_runs): + f(input, **kwargs_for_f) + synchronize() + prof.export_chrome_trace(trace_filename) + + return timing + + +def get_chrome_trace_events(filename): + f = open(filename) + data = json.load(f) + events = data["traceEvents"] + return events + + +def is_gpu_compute_event(event): + global gpu_pids + return ( + "pid" in event + and event["pid"] in gpu_pids + and "ph" in event + and event["ph"] == "X" + ) + + +def get_sorted_gpu_events(events): + sorted_gpu_events = [] + for event in events: + if not is_gpu_compute_event(event): + continue + sorted_gpu_events.append(event) + return sorted(sorted_gpu_events, key=operator.itemgetter("ts")) + + +def get_duration(sorted_gpu_events): + if len(sorted_gpu_events) == 0: + return 0 + event = sorted_gpu_events[0] + current_end_time = event["ts"] + event["dur"] + total_duration = event["dur"] + for event in sorted_gpu_events[1:]: + start_time = max(event["ts"], current_end_time) + end_time = event["ts"] + event["dur"] + total_duration = total_duration + max(end_time - start_time, 0) + current_end_time = max(current_end_time, end_time) + return total_duration + + +def get_sorted_gpu_mm_conv_events(events): + def is_mm_conv_event(event): + return "name" in event and ( + "gemm" in event["name"] + or "conv" in event["name"] + or "cutlass" in event["name"] + or "wgrad" in event["name"] + ) + + gpu_events = get_sorted_gpu_events(events) + sorted_events = [] + for event in gpu_events: + if not is_mm_conv_event(event): + continue + sorted_events.append(event) + return sorted_events + + +gpu_pids = [] + + +def compute_utilization(filename: str, total_length: float): + """ + Process the chrome traces outputs by the pytorch profiler to compute GPU Utilization + and percent of times spent on matmul and convolution + + Args: + filename(str): Name of chrome traces file produced by pytorch profiler + + total_length(float): total length of the process without profiler in second + + Return: + tuple: (GPU 
Utilization, percent of time spent on matmul and convolution) + """ + events = get_chrome_trace_events(filename) + + # get pids of GPU events + global gpu_pids + gpu_pids = [] + for event in events: + if "name" not in event: + continue + if event["name"] == "process_labels" and "GPU" in event["args"]["labels"]: + gpu_pids.append(event["pid"]) + + total_length = total_length * 1e6 + sorted_gpu_events = get_sorted_gpu_events(events) + utilization = get_duration(sorted_gpu_events) / total_length + + sorted_gpu_mm_conv_events = get_sorted_gpu_mm_conv_events(events) + mm_conv_utilization = get_duration(sorted_gpu_mm_conv_events) / total_length + + return utilization, mm_conv_utilization + + +def benchmark_utilization( + f, + input, + trace_folder, + optimize_ctx=None, + trace_file_name="tmp_chrome_trace", + num_runs=1, +): + """ + Benchmark the GPU Utilization and percent of time spent on matmul and convolution operations of + running f(input, **kwargs_for_f) with [optimize_ctx] [num_runs] times. + It will produce a chrome trace file in trace_folder/trace_file_name.json + + Example: + + ``` + def f(a): + return a.sum() + a = torch.rand(2**20, device="cuda") + utilization, mm_conv_utilization = benchmark_utilization(f, a, "tmp", trace_file_name = "tmp_chrome_trace") + ``` + + Args: + f: function to benchmark + + input: input to :attr:`f` + + trace_folder: name of the folder to store the chrome trace + + optimize_ctx: the context in which f will run + + trace_file_name: name of the dumped chrome trace file, default to "tmp_chrome_trace" + + num_runs: number of times to run f, excluding the warm-up runs, default to 1. + + Return: + tuple: (GPU Utilization, percent of time spent on matmul and convolution) + + """ + isExist = os.path.exists(trace_folder) + if not isExist: + os.makedirs(trace_folder) + print("create folder " + trace_folder) + + if optimize_ctx is None: + optimize_ctx = contextlib.nullcontext() + + chrome_trace_file_name = os.path.join(trace_folder, trace_file_name + ".json") + total_length = dump_chrome_trace( + f, + input, + chrome_trace_file_name, + optimize_ctx, + [ProfilerActivity.CUDA], + num_runs=num_runs, + devices="cuda", + ) + utilization, mm_conv_utilization = compute_utilization( + chrome_trace_file_name, total_length + ) + + return utilization, mm_conv_utilization diff --git a/lib/python3.10/site-packages/torch/_functorch/compile_utils.py b/lib/python3.10/site-packages/torch/_functorch/compile_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3bf61f1af3bf3c11250911720873197cae78a8c1 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_functorch/compile_utils.py @@ -0,0 +1,176 @@ +# mypy: ignore-errors + + +from typing import Callable + +import torch +import torch.fx as fx +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils import _pytree as pytree +from torch.utils._pytree import tree_flatten + + +aten = torch.ops.aten + + +def get_aten_target(node: fx.Node) -> Callable: + if hasattr(node.target, "overloadpacket"): + return node.target.overloadpacket + return node.target + + +rand_ops = [ + aten.dropout, + aten._fused_dropout, + aten._standard_gamma, + aten.bernoulli, + aten.multinomial, + aten.native_dropout, + aten.normal, + aten.poisson, + aten.binomial, + aten.rrelu, + aten.rand_like, + aten.rand, + aten.randint, + aten.randn, + aten.randperm, +] + + +# return a new copy of torch.fx.graph.Graph with CSE applied to the input graph +def fx_graph_cse(fx_g: torch.fx.graph.Graph): + new_graph = fx.Graph() + env = {} 
# map from node in the old graph to node in the new graph + hash_env = {} # map from hash to a node in the new graph + token_map = {} # map from hash to token + + from torch._inductor.pattern_matcher import ( + compute_mutation_region_ids, + same_mutation_regions, + ) + + compute_mutation_region_ids(fx_g) # type: ignore[arg-type] + + # Make a set of separate storages returned from the output, which will be preserved + # when pruning. This prevents us from deduplicating returned tensors which have + # experienced identical operations, but are separate data structures in eager mode. + output_node: fx.Node = list(fx_g.nodes)[-1] + assert output_node.op == "output" + + def checkable_node(node: fx.Node) -> bool: + """We can evaluate only nodes that represent tensors with defined storage.""" + if "val" not in node.meta or not isinstance(node.meta["val"], torch.Tensor): + return False + + try: + node.meta["val"].untyped_storage() + except NotImplementedError: + return False + + return True + + output_storages = { + StorageWeakRef(n.meta["val"].untyped_storage()) + for n in output_node.all_input_nodes + if checkable_node(n) + } + nodes_that_alias_outputs = { + n + for n in fx_g.nodes + if checkable_node(n) + and StorageWeakRef(n.meta["val"].untyped_storage()) in output_storages + } + + for n in fx_g.nodes: + # The placeholder, output, and get_attr nodes are copied to the new graph without change + # do not CSE away random operations + if ( + n.op == "placeholder" + or n.op == "output" + or n.op == "get_attr" + or get_aten_target(n) in rand_ops + # aten.empty is non-deterministic, so don't CSE it. + # Also, aten.empty is almost always fusible into its consumer, + # so it's not worth CSEing. + or get_aten_target(n) is aten.empty + or n in nodes_that_alias_outputs + ): + new_node = new_graph.node_copy(n, lambda x: env[x]) + env[n] = new_node + else: # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method' + # substitute args and kwargs members to their mapping in env if exists + # specs can be used to reconstruct nested list/dictionaries + def substitute(arg_list): + arg_list, spec = tree_flatten(arg_list) + for i in range(len(arg_list)): + v = arg_list[i] + if isinstance(v, torch.fx.node.Node) and v in env: + arg_list[i] = env[v] + if isinstance(v, (torch.SymBool, torch.SymInt, torch.SymFloat)): + arg_list[i] = v.node + return tuple(arg_list), spec + + args, args_spec = substitute(n.args) + kwargs, kwargs_spec = substitute(n.kwargs) + + # each token corresponds to a unique node + # nodes with the same token can be substituted + token = { + "target": n.target, + "args": args, + "args_spec": args_spec, + "kwargs": kwargs, + "kwargs_spec": kwargs_spec, + } + + # hash substituted args to a number, do not hash specs because specs are not hashable + # We need to add type into hash to avoid situations like: + # hash((primals_2, 1.0)) == hash((primals_2, 1)) + hash_arg = hash( + (tuple((a, type(a)) for a in args), tuple((a, type(a)) for a in kwargs)) + ) + hash_val = (n.target, hash_arg) + + # check if a node has a substitute and can be eliminated + hash_val_in_hash_env = hash_val in hash_env + overwrite_due_to_mutation = False + if hash_val_in_hash_env and token_map[hash_val] == token: + duplicate_n_prev = hash_env[hash_val] + if same_mutation_regions(n, duplicate_n_prev): + env[n] = duplicate_n_prev + continue + else: + # any futures duplicates should replace with n, not duplicate_n_prev + overwrite_due_to_mutation = True + + new_node = new_graph.node_copy(n, lambda x: env[x]) + 
env[n] = new_node + if overwrite_due_to_mutation or not hash_val_in_hash_env: + hash_env[hash_val] = new_node + token_map[hash_val] = token + + return new_graph + + +def strip_overloads(gm): + """ + Modifies the target of graph nodes in :attr:`gm` to strip overloads. + + Args: + gm(fx.GraphModule): The input Fx graph module to be modified + """ + for node in gm.graph.nodes: + if isinstance(node.target, torch._ops.OpOverload): + node.target = node.target.overloadpacket + gm.recompile() + + +def get_placeholders(graph): + return graph.find_nodes(op="placeholder") + + +def get_outputs(graph): + for node in graph.find_nodes(op="output"): + return pytree.tree_leaves(node.args[0]) + raise AssertionError("No output node found") diff --git a/lib/python3.10/site-packages/torch/_functorch/compilers.py b/lib/python3.10/site-packages/torch/_functorch/compilers.py new file mode 100644 index 0000000000000000000000000000000000000000..b420daca5ac347f9c1bc5354d495e91bb53177ec --- /dev/null +++ b/lib/python3.10/site-packages/torch/_functorch/compilers.py @@ -0,0 +1,445 @@ +# mypy: ignore-errors + +import copy +import logging +import os +import pickle +import random +from contextlib import contextmanager +from functools import partial +from typing import Callable, Union + +import sympy + +import torch +import torch.fx as fx +import torch.nn as nn +import torch.utils._pytree as pytree +from torch import SymInt +from torch._decomp import get_decompositions +from torch.fx.experimental.symbolic_shapes import bind_symbols + +from .aot_autograd import aot_function, aot_module, make_boxed_compiler +from .compile_utils import strip_overloads +from .partitioners import ( + default_partition, + draw_graph, + min_cut_rematerialization_partition, +) + + +log = logging.getLogger(__name__) + + +# These canonicalizations are needed here (and not decompositions), as the ops +# we're trying to canonicalize to CompositeImplicitAutograd. +def _canonicalize(fx_g): + for node in fx_g.graph.find_nodes( + op="call_function", target=torch.ops.aten._to_copy + ): + node.target = torch.ops.aten.to + fx_g.recompile() + return fx_g + + +@contextmanager +def _disable_jit_autocast(): + old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False) + try: + yield + finally: + torch._C._jit_set_autocast_mode(old_jit_autocast_flag) + + +@make_boxed_compiler +def ts_compile(fx_g: fx.GraphModule, inps) -> Callable: + """ + Compiles the :attr:`fx_g` with Torchscript compiler. + + .. warning:: + This API is experimental and likely to change. + + Args: + fx_g(fx.GraphModule): The input Fx graph module to be compiled. + + Returns: + Torch scripted model. 
+ """ + + with _disable_jit_autocast(): + strip_overloads(fx_g) + + for node in fx_g.graph.find_nodes( + op="call_function", target=torch.ops.aten._to_copy + ): + if len(node.args) == 1 and len(node.kwargs) == 1 and "dtype" in node.kwargs: + node.target = torch.ops.aten.to + + for node in fx_g.graph.nodes: + new_kwargs = {} + for k, v in node.kwargs.items(): + if isinstance(v, torch.device): + v = v.type + new_kwargs[k] = v + node.kwargs = new_kwargs + + fx_g.graph.lint() + + fx_g.recompile() + + f = torch.jit.script(fx_g) + + torch._C._jit_pass_remove_mutation(f.graph) + + f = torch.jit.freeze(f.eval()) + f = torch.jit.optimize_for_inference(f) + if not any(isinstance(t, torch._subclasses.FakeTensor) for t in inps): + f(*inps) + return f + + +def _draw_graph_compile(fx_g, _, name, clear_meta=True): + print(fx_g.code) + draw_graph(fx_g, name, clear_meta=clear_meta) + return fx_g + + +def draw_graph_compile(name): + return make_boxed_compiler(partial(_draw_graph_compile, name=name)) + + +@make_boxed_compiler +def nop(fx_g: fx.GraphModule, _) -> Callable: + """ + Returns the :attr:`fx_g` Fx graph module as it is. This is a no-op compiler + and can be used to check accuracy. + + .. warning:: + This API is experimental and likely to change. + + """ + return fx_g + + +class DebugInterpreter(fx.Interpreter): + def run(self, *args): + self.symbol_mapping = bind_symbols(self.module, *args) + super().run(*args) + + def run_node(self, n): + def subst_symint(ni): + if not isinstance(ni, SymInt): + return ni + r = sympy.expand(ni.node.expr.xreplace(self.symbol_mapping)) + assert r.is_number, r + return int(r) + + def subst_symint_tuple(nis): + return tuple(subst_symint(ni) for ni in nis) + + def check_significant_strides(a, b): + if subst_symint(a.numel()) > 0: + for idx in range(a.ndim): + if ( + subst_symint(a.stride(idx)) != b.stride(idx) + and subst_symint(a.size(idx)) > 1 + ): + return False + return True + + def check(nv, rv, desc): + assert callable(desc) + assert nv.dtype == rv.dtype, f"{desc()}: {nv.dtype} != {rv.dtype}" + assert ( + subst_symint_tuple(nv.size()) == rv.size() + ), f"{desc()}: {nv.size()} aka {subst_symint_tuple(nv.size())} != {rv.size()}" + same_strides = check_significant_strides(nv, rv) + assert ( + same_strides + ), f"{desc()}: {nv.stride()} aka {subst_symint_tuple(nv.stride())} != {rv.stride()}" + + r = super().run_node(n) + if "val" in n.meta: + n_vals, n_spec = pytree.tree_flatten(n.meta["val"]) + r_vals, r_spec = pytree.tree_flatten(r) + # TODO: There is some sort of problem where we record that an + # operator returned a tuple/list, and then later it turns out the + # real version of the operator returned a list/tuple. Need to + # figure out what's actually going on here, the error itself is + # harmless enough as we only getitem out the outputs. + # assert n_spec == r_spec, f"{n_spec} != {r_spec}" + assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}" + for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals): + if not isinstance(rv, torch.Tensor): + continue + check(nv, rv, lambda: f"output {i} where {self.symbol_mapping}") + return r + + +@make_boxed_compiler +def debug_nop(fx_g: fx.GraphModule, _) -> Callable: + """ + Returns a (slow) interpreter over the FX graph module that also checks + various debugging properties (e.g., that tracing strides matched real + strides.) 
+ """ + return DebugInterpreter(fx_g).run + + +@make_boxed_compiler +def simple_ts_compile(fx_g, _): + strip_overloads(fx_g) + f = torch.jit.script(fx_g) + f = torch.jit.freeze(f.eval()) + return f + + +def nnc_jit(f): + return aot_function(f, simple_ts_compile) + + +aten = torch.ops.aten +default_decompositions = { + aten.detach, + aten.gelu_backward, + aten.leaky_relu_backward, + aten.sigmoid_backward, + aten.threshold_backward, + aten.hardtanh_backward, + aten.hardsigmoid_backward, + aten.hardswish_backward, + aten.tanh_backward, + aten.silu_backward, + aten.elu_backward, + aten.cudnn_batch_norm, + aten.cudnn_batch_norm_backward, + aten.masked_fill.Scalar, + aten.masked_fill.Tensor, + aten.elu, + aten.leaky_relu, + aten.hardtanh, + aten.hardswish, + aten.hardsigmoid, + aten.conj_physical, + aten.is_same_size, +} + +default_decompositions = get_decompositions(default_decompositions) + + +@make_boxed_compiler +def print_compile(fx_g, _): + print(fx_g.code) + return fx_g + + +def memory_efficient_fusion( + fn: Union[Callable, nn.Module], + **kwargs, +): + """ + Wrapper function over :func:`aot_function` and :func:`aot_module` to perform + memory efficient fusion. It uses the + :func:`min_cut_rematerialization_partition` partitioner to perform efficient + recomputation. It uses NVFuser to compile the generated forward and backward + graphs. + + .. warning:: + This API is experimental and likely to change. + + Args: + fn (Union[Callable, nn.Module]): A Python function or a ``nn.Module`` + that takes one ore more arguments. Must return one or more Tensors. + **kwargs: Any other overrides you want to make to the settings + + Returns: + Returns a ``Callable`` or ``nn.Module`` that retains the eager behavior + of the original :attr:`fn`, but whose forward and backward graphs have + gone through recomputation optimizations, and the graphs have been + compiled with nvfuser. + + """ + config = { + "fw_compiler": ts_compile, + "bw_compiler": ts_compile, + "partition_fn": min_cut_rematerialization_partition, + "decompositions": default_decompositions, + } + config.update(kwargs) + if isinstance(fn, torch.nn.Module): + return aot_module(fn, **config) + else: + return aot_function(fn, **config) + + +def debug_compile(fx_g, inps): + fx_g.to_folder("foo") + print( + f""" +############################################################## +# To minimize FX graph, copy and paste the below and run it # +############################################################## + +import torch +import torch.fx as fx +from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess + +inps = {[(i.shape, i.dtype) for i in inps]} +inps = [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps] +from foo import FxModule +mod = FxModule().cuda() + +with torch.jit.fuser("fuser2"): + # check_nvfuser_subprocess can be replaced with check_nvfuser_correctness_subprocess + minifier(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess) +""" + ) + from foo import FxModule + + FxModule().cuda()(*inps) + + return ts_compile(fx_g, inps) + + +graph_index = 0 + + +def get_inputs(input_data_path): + """ + Return a random input for the given inputs meta generated from _save_fx_default. 
+ """ + inputs = [] + with open(input_data_path, "rb") as f: + inputs_meta = pickle.load(f) + inputs = [] + for meta in inputs_meta: + if len(meta) == 1: + type = meta + input = type(random.rand()) + else: + type, shape, stride, dtype, device = meta + if dtype in { + torch.int, + torch.int32, + torch.int64, + torch.bool, + torch.int, + torch.uint8, + int, + float, + }: + input = torch.randint(0, 1, shape, dtype=dtype, device=device) + else: + input = torch.rand(shape, dtype=dtype, device=device) + inputs.append(input) + return inputs + + +def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_inputs): + """ + The forward, backward, and joint computation graph will be stored in + {folder_name}/{current_name}/{current_name}_forward_{graph_index}, + {folder_name}/{current_name}/{current_name}_backward_{graph_index}, and + {folder_name}/{current_name}/{current_name}_joint_{graph_index} respectively. + The input shape of the graphs will be stored in the .input files. + These files can be loaded with pickle, + and is a list of format (type, shape, stride, dtype, device). + In the case of type = int or float, it is just (type,). + For joint graph input, it is a nested list [[],[]] + where the two inner lists have the same format. + If dump_example_input is True, example_inputs will be stored in .pt file. + Since each function might produce multiple graphs, + the graph_index is used to distinguish difference graphs + """ + from functorch.compile import aot_module_simplified + + def get_input_meta(args): + input_meta = [] + if len(args) > 0 and isinstance(args[0], tuple): # joint input + input_meta += get_input_meta(args[0]) + input_meta += get_input_meta(args[1]) + return input_meta + for arg in args: + if type(arg) == int or type(arg) == float: + input_meta.append((type(arg),)) + else: + input_meta.append( + (type(arg), arg.shape, arg.stride(), arg.dtype, arg.device) + ) + return input_meta + + def graph_saver_helper(gm_to_save, args, type_name): + global graph_index + if len(gm_to_save.graph.nodes) == 0: + log.log( + logging.WARNING, + "No nodes in graph {%s}_{%s}_{%s}.", + current_name, + type_name, + graph_index, + ) + return + + gm = copy.deepcopy(gm_to_save) + gm.graph.set_codegen(torch.fx.graph.CodeGen()) # remove codegen + gm.recompile() + + input_meta = get_input_meta(args) + + os.makedirs(f"{folder_name}/{current_name}", exist_ok=True) + gm.to_folder( + f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}" + ) + pickle.dump( + input_meta, + open( + f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.input", # noqa: B950 + "wb", + ), + ) # noqa: E501 + if dump_example_input: + torch.save( + args, + f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.pt", # noqa: B950 + ) # noqa: E501 + + def graph_saver_forward(gm, fw_args): + graph_saver_helper(gm, fw_args, "forward") + return gm + + def graph_saver_backward(gm, bw_args): + graph_saver_helper(gm, bw_args, "backward") + global graph_index + graph_index += 1 + return gm + + def graph_saver_joint(gm, joint_args): + graph_saver_helper(gm, joint_args, "joint") + return default_partition(gm, joint_args) + + return aot_module_simplified( + gm, + example_inputs, + fw_compiler=graph_saver_forward, + bw_compiler=graph_saver_backward, + partition_fn=graph_saver_joint, + decompositions=default_decompositions, + ) + + +# WARNING: This isn't tested anywhere!! 
+def graph_dumper_aot(current_name, folder_name, dump_example_input=False): + """ + Dump the forward, backward, and joint computation graph. + Example Usage: + save_fx_func = graph_dumper_aot(current_name, folder_name, dump_example_input = False) + optimize_ctx = torchdynamo.optimize( + save_fx_func + ) + with torch.enable_grad(): + with optimize_ctx: + result = forward_and_backward_pass(model, example_inputs) + """ + global graph_index + graph_index = 0 + return partial(_save_fx_default, current_name, folder_name, dump_example_input) diff --git a/lib/python3.10/site-packages/torch/_functorch/config.py b/lib/python3.10/site-packages/torch/_functorch/config.py new file mode 100644 index 0000000000000000000000000000000000000000..04976c06965c0047e3352846477ef6544275921a --- /dev/null +++ b/lib/python3.10/site-packages/torch/_functorch/config.py @@ -0,0 +1,203 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Global flags for aot autograd +""" +import os +import sys +from typing import TYPE_CHECKING + + +# Converts torch rng ops to their functional philox rng equivalents. Note that +# we functionalize only CUDA rng ops today. +functionalize_rng_ops = False + +# can be useful for debugging if we are incorrectly creating meta fake tensors +fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", "1") != "0" + +# Enables optional asserts in hotpath code to check for errors. If +# you are seeing weird accuracy problems, try turning this on. +# This is currently off by default as it will harm tracing time, +# but it is on by default for aot_eager. +debug_assert = False + +debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", "0") != "0" + +# Today, if you are in a situation where there is "false aliasing" +# (e.g. you have a bunch of model parameters that all alias the same underlying buffer), +# our checks for this situation are very slow if these inputs have dynamic shapes. +# This config is set to ensure that there aren't too many aliased inputs in this situation, +# so that we error loudly instead of compiling forever. +# Eventually, we should make these checks faster. +# For now, however, you can simply turn off dynamic shapes by marking your inputs static +# when you run into this situation. +_max_aliased_inputs_with_dynamic_shapes_enabled = 5 + +static_weight_shapes = True + +# Applies CSE to the graph before partitioning +cse = True + + +enable_autograd_cache = os.environ.get("ENABLE_AOT_AUTOGRAD_CACHE", "0") == "1" + +# When AOTAutograd regenerates aliased graph outputs, +# attempt to use functionalization's view-replay logic +# before falling back to the autograd engine's view replay or as_strided. +# This can have some perf implications +# (although for many models this will not matter). +# (1) If you have many view ops chained together, replaying all of them +# at runtime can have more overhead compared to a single as_strided call +# (2) If you are doing training, AsStridedBackward is quite slow, +# and the individual view op backward formulas will likely be faster. 
+# (3) Some backends like XLA do not support as_strided + +# Temporary hack: disable this flag for internal +# (needed to fix an internal issue while avoiding bumping XLA pin) +# eventually: either default this config to false completely +# once XLA pin update works, +# or default config to true and fix relevant bugs +from torch._inductor.config import is_fbcode + + +# View replay is currently not compatible with AOTAutogradCache, since +# FunctionalTensors are not serializable. We'll need to make them +# serializable before enabling warm cache with this config turned on. +view_replay_for_aliased_outputs = (not is_fbcode()) and (not enable_autograd_cache) + +# Restricts the amount of computation AOTAutograd can do. +# NB: We have essentially disabled this heuristic now. However, this is kept +# here for now in case it's useful. Setting it low can artificially reduce the +# amount of recomputation AOTAutograd performs, although not in any kind of +# principled way. +max_dist_from_bw = 1000 + + +# Bans recomputation of nodes that are reading from nodes that is far before +# the current node +ban_recompute_used_far_apart = True +# Breaks up long chain of fusible ops, as otherwise we can have an arbitrarily +# long chain of recomputation in the backwards pass. +ban_recompute_long_fusible_chains = True +# Bans recomputation of nodes that must be materialized in the backwards pass +# (used by a non-fusible node) +ban_recompute_materialized_backward = True +# Chooses to ban recomputation of nodes based off an allowlist. Setting it to +# False changes it to use a denylist. Main change is on operators like +# sort/pool/stuff that isn't cheap enough to be fusible for free but also isn't +# that expensive +ban_recompute_not_in_allowlist = True +# Chooses to ban recomputation of reductions. This is generally a good idea, as +# the result of reductions is generally very small but recomputing reductions in +# a fusion can be expensive. +ban_recompute_reductions = True +# Prevents the partitioner from ever saving views (i.e. always recompute them). +# Generally a good idea since views are free to recompute. +recompute_views = False + +# By default, the partitioner is purely trying to optimize for runtime (although +# it should always use less memory than eager) +# This knob controls the partitioner to make that tradeoff for you, choosing the +# fastest option that saves less activations than the memory budget. +# Specifically, 0.0 corresponds to the activation memory from applying +# activation checkpointing to the full compiled region, and 1.0 corresponds to +# the activation memory from the default runtime-optimized strategy. So, 0.4 +# would result in a strategy that saves 40% of the activations compared to the +# default strategy. +# It solves a 0-1 knapsack to find the minimum recompute necessary to stay below +# the activation memory budget. +# NOTE: This *cannot* be treated as +activation_memory_budget = 1.0 + +# This controls how we estimate the runtime when deciding what the cheapest +# operators to recompute are. The 3 options are +# "flops": Bases it off of the flop count provided by torch.utils.flop_counter +# "profile": Benchmarks each operator to come up with a runtime +# "testing": Returns 1 for everything +activation_memory_budget_runtime_estimator = "flops" + +# This controls the solver used for the 0-1 knapsack. By default we use a +# quantized DP solution ("dp"). The other approaches are a "greedy" and a "ilp" +# (which has a scipy dependency). 
+activation_memory_budget_solver = "dp" + +# This dumps out a png visualization of the expected runtime vs. activation +# memory tradeoffs for all memory budget values from 0 to 1 in increments of +# 0.5. See an example here: +# https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015 +visualize_memory_budget_pareto = ( + os.environ.get("PARTITIONER_MEMORY_BUDGET_PARETO", "0") == "1" +) + +# Sets all of the ban_recompute heuristics to False except ban_recompute_reductions +# Generally, this will probably result in some memory improvement, but at the +# cost of some performance +aggressive_recomputation = False + +# If FakeTensor.data_ptr() should error. +# This option is independent of AOTAutograd and torch.compile, but our policy +# is to turn it off during torch.compile. +fake_tensor_allow_unsafe_data_ptr_access = True + +# Unlifts effect tokens from the inputs/outputs in the traced graph and instead +# inserts make_token/sink_token calls in the graph to create tokens and then +# sink them at the end. Note that this means the graph is no longer functional +# which may lead to silent errors unless the backend knows how to handle the +# tokens. +unlift_effect_tokens = False + +# This mode specifies that we should also keep track of the real +# tensor along with the fake tensor, and do real compute. While +# seemingly this eliminates the whole point of fake tensors, there are +# two obvious use cases for it: +# +# 1. When users call item()/other data dependent operations, +# if we propagate_real_tensors we are able to determine what +# the true value is and keep going. +# +# 2. It can be useful for testing, when you want to see if the fake +# and real tensors agree with each other. (Note that there are +# currently known inaccuracies in how we clone real tensors, that +# would have to be tightened up for this to be useful in this +# case.) +# +# Note that fake tensors are typically understood to be cheap to store +# indefinitely, so we tend to hold on to them longer than we would +# hold onto the real tensors. So we also support you explicitly +# deallocating the real tensor associated with a fake tensor, at which +# point we will stop propagating real tensors. +# +# One more thing: when you provide a real tensor to fakeify, we will +# clone it, so that we can safely perform mutations on it if necessary. +# This will increase live memory usage. This could potentially be +# optimized by using COW. We also currently do not faithfully +# maintain autograd metadata on the real tensor; this is fine because +# AOTAutograd will only use the fake tensor to determine leafness/etc +# of tensors in question. +fake_tensor_propagate_real_tensors = False + +# This controls whether we collect donated buffer. This flag must be set +# False if a user wants to retain_graph=True for backward. 
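A sketch of how these partitioner knobs are typically flipped for a single compilation (the `patch` helper is added by `install_config_module` at the end of this file; the model and input below are placeholders):

```
import torch
import torch.nn as nn
import torch._functorch.config as functorch_config

model = nn.Linear(16, 16)
x = torch.randn(4, 16)

# Save roughly 40% of the activations that the default runtime-optimized
# strategy would save, using the default "dp" knapsack solver.
with functorch_config.patch(activation_memory_budget=0.4):
    loss = torch.compile(model)(x).sum()
    loss.backward()
```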
+donated_buffer = False + +# Controls the default graph output format used by draw_graph +# Supported formats are defined here https://graphviz.org/docs/outputs/ +torch_compile_graph_format = os.environ.get("TORCH_COMPILE_GRAPH_FORMAT", "svg") + + +# Error on BypassAOTAutogradCache instead of just a warning +# Used for tests +strict_autograd_cache = False + +if TYPE_CHECKING: + from torch.utils._config_typing import * # noqa: F401, F403 + +from torch.utils._config_module import install_config_module + + +# adds patch, save_config, invalid config checks, etc +install_config_module(sys.modules[__name__]) diff --git a/lib/python3.10/site-packages/torch/_functorch/deprecated.py b/lib/python3.10/site-packages/torch/_functorch/deprecated.py new file mode 100644 index 0000000000000000000000000000000000000000..ebb930e8ecb742d23a6cfeefb76ffe8587cea952 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_functorch/deprecated.py @@ -0,0 +1,172 @@ +# mypy: allow-untyped-defs +""" +The APIs in this file are exposed as `functorch.*`. They are thin wrappers +around the torch.func.* APIs that have deprecation warnings -- we're trying +to move people to the torch.func.* equivalents. + +NB: We don't use *args, **kwargs in the signatures because that changes the +documentation. +""" + +import textwrap +import warnings +from typing import Any, Callable, Optional, Tuple, Union + +import torch._functorch.apis as apis +import torch._functorch.eager_transforms as _impl +import torch._functorch.make_functional as _nn_impl +import torch.nn as nn +from torch._functorch.eager_transforms import argnums_t +from torch._functorch.vmap import in_dims_t, out_dims_t + + +def get_warning(api, new_api=None, replace_newlines=False): + if new_api is None: + new_api = f"torch.func.{api}" + warning = ( + f"We've integrated functorch into PyTorch. As the final step of the \n" + f"integration, `functorch.{api}` is deprecated as of PyTorch \n" + f"2.0 and will be deleted in a future version of PyTorch >= 2.3. \n" + f"Please use `{new_api}` instead; see the PyTorch 2.0 release notes \n" + f"and/or the `torch.func` migration guide for more details \n" + f"https://pytorch.org/docs/main/func.migrating.html" + ) + if replace_newlines: + warning = warning.replace("\n", "") + return warning + + +def warn_deprecated(api, new_api=None): + warning = get_warning(api, new_api, replace_newlines=True) + warnings.warn(warning, FutureWarning, stacklevel=3) + + +def setup_docs(functorch_api, torch_func_api=None, new_api_name=None): + api_name = functorch_api.__name__ + if torch_func_api is None: + torch_func_api = getattr(_impl, api_name) + # See https://docs.python.org/3/using/cmdline.html#cmdoption-OO + if torch_func_api.__doc__ is None: + return + + warning = get_warning(api_name, new_api_name) + warning_note = "\n.. 
warning::\n\n" + textwrap.indent(warning, " ") + warning_note = textwrap.indent(warning_note, " ") + functorch_api.__doc__ = torch_func_api.__doc__ + warning_note + + +def vmap( + func: Callable, + in_dims: in_dims_t = 0, + out_dims: out_dims_t = 0, + randomness: str = "error", + *, + chunk_size=None, +) -> Callable: + warn_deprecated("vmap", "torch.vmap") + return apis.vmap(func, in_dims, out_dims, randomness, chunk_size=chunk_size) + + +def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable: + warn_deprecated("grad") + return apis.grad(func, argnums, has_aux) + + +def grad_and_value( + func: Callable, argnums: argnums_t = 0, has_aux: bool = False +) -> Callable: + warn_deprecated("grad_and_value") + return apis.grad_and_value(func, argnums, has_aux) + + +def vjp(func: Callable, *primals, has_aux: bool = False): + warn_deprecated("vjp") + return _impl.vjp(func, *primals, has_aux=has_aux) + + +def jvp( + func: Callable, + primals: Any, + tangents: Any, + *, + strict: bool = False, + has_aux: bool = False, +): + warn_deprecated("jvp") + return _impl.jvp(func, primals, tangents, strict=strict, has_aux=has_aux) + + +def jacrev( + func: Callable, + argnums: Union[int, Tuple[int]] = 0, + *, + has_aux=False, + chunk_size: Optional[int] = None, + _preallocate_and_copy=False, +): + warn_deprecated("jacrev") + return _impl.jacrev( + func, + argnums, + has_aux=has_aux, + chunk_size=chunk_size, + _preallocate_and_copy=_preallocate_and_copy, + ) + + +def jacfwd( + func: Callable, + argnums: argnums_t = 0, + has_aux: bool = False, + *, + randomness: str = "error", +): + warn_deprecated("jacfwd") + return _impl.jacfwd(func, argnums, has_aux, randomness=randomness) + + +def hessian(func, argnums=0): + warn_deprecated("hessian") + return _impl.hessian(func, argnums=argnums) + + +def functionalize(func: Callable, *, remove: str = "mutations") -> Callable: + warn_deprecated("functionalize") + return _impl.functionalize(func, remove=remove) + + +def make_functional(model: nn.Module, disable_autograd_tracking: bool = False): + warn_deprecated("make_functional", "torch.func.functional_call") + return _nn_impl.make_functional(model, disable_autograd_tracking) + + +def make_functional_with_buffers( + model: nn.Module, disable_autograd_tracking: bool = False +): + warn_deprecated("make_functional_with_buffers", "torch.func.functional_call") + return _nn_impl.make_functional_with_buffers(model, disable_autograd_tracking) + + +def combine_state_for_ensemble(models): + warn_deprecated("combine_state_for_ensemble", "torch.func.stack_module_state") + return _nn_impl.combine_state_for_ensemble(models) + + +setup_docs(vmap, apis.vmap, "torch.vmap") +setup_docs(grad, apis.grad) +setup_docs(grad_and_value, apis.grad_and_value) +setup_docs(vjp) +setup_docs(jvp) +setup_docs(jacrev) +setup_docs(jacfwd) +setup_docs(hessian) +setup_docs(functionalize) +setup_docs(make_functional, _nn_impl.make_functional, "torch.func.functional_call") +setup_docs( + make_functional_with_buffers, _nn_impl.make_functional, "torch.func.functional_call" +) +setup_docs( + combine_state_for_ensemble, + _nn_impl.combine_state_for_ensemble, + "torch.func.stack_module_state", +) diff --git a/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py b/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..d389c7fda789497025a98ed84fdc0e9108685c67 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py @@ -0,0 
+1,1839 @@ +# mypy: ignore-errors + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +from functools import partial, wraps +from typing import Any, Callable, List, Optional, Tuple, Union + +import torch +import torch.autograd.forward_ad as fwAD +from torch._C._functorch import ( + _assert_wrapped_functional, + _func_decrement_nesting, + _func_increment_nesting, + _grad_decrement_nesting, + _grad_increment_nesting, + _jvp_decrement_nesting, + _jvp_increment_nesting, + _propagate_functional_input_mutation, + _unwrap_for_grad, + _unwrap_functional_tensor, + _wrap_for_grad, + _wrap_functional_tensor, + get_inplace_requires_grad_allowed, + set_inplace_requires_grad_allowed, +) +from torch._functorch.utils import argnums_t, exposed_in +from torch._subclasses.functional_tensor import FunctionalTensor +from torch.fx.experimental import const_fold +from torch.fx.experimental.proxy_tensor import make_fx +from torch.utils import _pytree as pytree +from torch.utils._pytree import ( + tree_flatten, + tree_map, + tree_map_, + tree_map_only, + tree_unflatten, + treespec_pprint, +) + +from .apis import vmap +from .vmap import doesnt_support_saved_tensors_hooks, get_chunk_sizes + + +def lazy_dynamo_disallow(func): + import torch._dynamo + + return torch._dynamo.disallow_in_graph(func) + + +@contextlib.contextmanager +def enable_inplace_requires_grad(enabled): + prev_state = get_inplace_requires_grad_allowed() + set_inplace_requires_grad_allowed(enabled) + try: + yield + finally: + set_inplace_requires_grad_allowed(prev_state) + + +def _vjp_treespec_compare(primals_out, cotangents): + # Revert this once #116264 gets fixed + _, primals_out_spec = tree_flatten(primals_out) + _, cotangents_spec = tree_flatten(cotangents) + # Dynamo fails to trace operator.ne below. To bypass this limitation, this + # function is not inlined. + if primals_out_spec != cotangents_spec: + raise RuntimeError( + f"Expected pytree structure of cotangents to be the same " + f"as pytree structure of outputs to the function. " + f"cotangents: {treespec_pprint(cotangents_spec)}, " + f"primal output: {treespec_pprint(primals_out_spec)}" + ) + + +def _jvp_treespec_compare(primals, tangents): + # Revert this once #116264 gets fixed + _, primals_spec = tree_flatten(primals) + _, tangents_spec = tree_flatten(tangents) + if primals_spec != tangents_spec: + raise RuntimeError( + f"{jvp_str}: Expected primals and tangents to have the same python " + f"structure. For example, if primals is a tuple of 3 tensors, " + f"tangents also must be. 
Got primals with structure {primals_spec} " + f"and tangents with structure {tangents_spec}" + ) + + +def _linearize_treespec_compare(primals, tangents): + # Revert this once #116264 gets fixed + _, primals_argspec = tree_flatten(primals) + _, tangent_argspec = tree_flatten(tangents) + if tangent_argspec != primals_argspec: + raise RuntimeError( + f"Expected the tangents {tangent_argspec} to have " + f"the same argspec as the primals {primals_argspec}" + ) + + +def _set_tensor_requires_grad(x): + # avoid graph-break on x.requires_grad_() + # https://github.com/pytorch/pytorch/pull/110053 + return x.requires_grad_() + + +def _create_differentiable(inps, level=None): + def create_differentiable(x): + if isinstance(x, torch.Tensor): + with enable_inplace_requires_grad(True): + return _set_tensor_requires_grad(x) + raise ValueError( + f"Thing passed to transform API must be Tensor, " f"got {type(x)}" + ) + + return tree_map(create_differentiable, inps) + + +def _undo_create_differentiable(inps, level=None): + def unwrap_tensors(x): + if isinstance(x, torch.Tensor): + return _unwrap_for_grad(x, level) + # TODO: Remove the following hack for namedtuples + if isinstance(x, tuple): + return tree_map(unwrap_tensors, tuple(x)) + + raise RuntimeError(f"Expected tensors, got unsupported type {type(x)}") + + return tree_map(unwrap_tensors, inps) + + +def _is_differentiable(maybe_tensor): + if not isinstance(maybe_tensor, torch.Tensor): + return False + return maybe_tensor.requires_grad + + +def _any_differentiable(tensor_or_tuple_of_tensors): + flat_args, _ = tree_unflatten(tensor_or_tuple_of_tensors) + return any(tuple(map(_is_differentiable, flat_args))) + + +def _wrap_tensor_for_grad(maybe_tensor, level): + if not isinstance(maybe_tensor, torch.Tensor): + return maybe_tensor + return _wrap_for_grad(maybe_tensor, level) + + +def _wrap_all_tensors(tensor_pytree, level): + return tree_map(partial(_wrap_tensor_for_grad, level=level), tensor_pytree) + + +def _as_tuple(val): + if isinstance(val, tuple): + return val + return (val,) + + +# Version of autograd.grad that handles outputs that don't depend on inputs + + +def _autograd_grad( + outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True +): + if grad_outputs is None: + diff_outputs = tuple(out for out in outputs if out.requires_grad) + else: + result = tuple( + (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad + ) + if len(result) == 0: + diff_outputs, grad_outputs = (), () + else: + diff_outputs, grad_outputs = zip(*result) + if len(diff_outputs) == 0: + return tuple(torch.zeros_like(inp) for inp in inputs) + grad_inputs = torch.autograd.grad( + diff_outputs, + inputs, + grad_outputs, + retain_graph=retain_graph, + create_graph=create_graph, + allow_unused=True, + ) + grad_inputs = tuple( + torch.zeros_like(inp) if gi is None else gi + for gi, inp in zip(grad_inputs, inputs) + ) + return grad_inputs + + +# NOTE [grad and vjp interaction with no_grad] +# +# def f(x): +# with torch.no_grad(): +# c = x ** 2 +# return x - c +# +# The thing to consider is if enable_grad is on/off before grad gets called. +# +# Case 1: enable_grad is on. +# grad(f)(x) +# In this case, `grad` should respect the inner torch.no_grad. +# +# Case 2: enable_grad is off +# with torch.no_grad(): +# grad(f)(x) +# In this case, `grad` should respect the inner torch.no_grad, but not the +# outer one. This is because `grad` is a "function transform": its result +# should not depend on the result of a context manager outside of `f`. 
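A concrete sketch of the two cases above (`c` is treated as a constant because it is computed under the inner `no_grad`, so df/dx = 1 in both cases):

```
import torch
from torch.func import grad

def f(x):
    with torch.no_grad():
        c = x ** 2
    return x - c

x = torch.tensor(3.0)

# Case 1: grad mode enabled outside; the inner no_grad is respected.
assert torch.allclose(grad(f)(x), torch.tensor(1.0))

# Case 2: grad called under an outer no_grad; the outer no_grad is ignored
# because grad is a function transform.
with torch.no_grad():
    assert torch.allclose(grad(f)(x), torch.tensor(1.0))
```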
+# +# This gives us the following desired behavior: +# - (nested) grad transforms must obey torch.no_grad inside them +# - (nested) grad transforms should not obey torch.no_grad outside them +# +# To achieve this behavior, upon entering grad/vjp: +# - we save the current ("previous") is_grad_enabled (*) +# - we unconditionally enable grad. +# +# Inside DynamicLayerBackFallback, when we're temporarily popping `grad` layer +# off the stack: +# - if grad_mode is disabled, then we do nothing. (there is a torch.no_grad +# active, all subsequent grad transforms must obey it). +# - if grad_mode is enabled, and the previous is_grad_enabled (*) is False, +# then we temporarily restore the previous `is_grad_enabled`. This is +# because we're crossing the boundary from a `grad` outside the +# no_grad to a `grad` inside the no_grad. +# +# NB: vjp has some interesting behavior because the vjp's callable can be called +# under a different grad_mode than the forward computation... +# +# NB: forward-mode AD: forward-mode AD doesn't respect torch.no_grad, but +# it respects c10::AutoFwGradMode. We've implemented the same logic for +# our jvp transform (it will have special handling if FwGradMode is disabled). + + +# How do we increment and decrement the nesting? I don't think we can. +@exposed_in("torch.func") +def vjp(func: Callable, *primals, has_aux: bool = False): + """ + Standing for the vector-Jacobian product, returns a tuple containing the + results of ``func`` applied to ``primals`` and a function that, when + given ``cotangents``, computes the reverse-mode Jacobian of ``func`` with + respect to ``primals`` times ``cotangents``. + + Args: + func (Callable): A Python function that takes one or more arguments. Must + return one or more Tensors. + primals (Tensors): Positional arguments to ``func`` that must all be + Tensors. The returned function will also be computing the + derivative with respect to these arguments + has_aux (bool): Flag indicating that ``func`` returns a + ``(output, aux)`` tuple where the first element is the output of + the function to be differentiated and the second element is + other auxiliary objects that will not be differentiated. + Default: False. + + Returns: + Returns a ``(output, vjp_fn)`` tuple containing the output of ``func`` + applied to ``primals`` and a function that computes the vjp of + ``func`` with respect to all ``primals`` using the cotangents passed + to the returned function. If ``has_aux is True``, then instead returns a + ``(output, vjp_fn, aux)`` tuple. + The returned ``vjp_fn`` function will return a tuple of each VJP. 
+ + When used in simple cases, :func:`vjp` behaves the same as :func:`grad` + + >>> x = torch.randn([5]) + >>> f = lambda x: x.sin().sum() + >>> (_, vjpfunc) = torch.func.vjp(f, x) + >>> grad = vjpfunc(torch.tensor(1.))[0] + >>> assert torch.allclose(grad, torch.func.grad(f)(x)) + + However, :func:`vjp` can support functions with multiple outputs by + passing in the cotangents for each of the outputs + + >>> x = torch.randn([5]) + >>> f = lambda x: (x.sin(), x.cos()) + >>> (_, vjpfunc) = torch.func.vjp(f, x) + >>> vjps = vjpfunc((torch.ones([5]), torch.ones([5]))) + >>> assert torch.allclose(vjps[0], x.cos() + -x.sin()) + + :func:`vjp` can even support outputs being Python structs + + >>> x = torch.randn([5]) + >>> f = lambda x: {'first': x.sin(), 'second': x.cos()} + >>> (_, vjpfunc) = torch.func.vjp(f, x) + >>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])} + >>> vjps = vjpfunc(cotangents) + >>> assert torch.allclose(vjps[0], x.cos() + -x.sin()) + + The function returned by :func:`vjp` will compute the partials with + respect to each of the ``primals`` + + >>> x, y = torch.randn([5, 4]), torch.randn([4, 5]) + >>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y) + >>> cotangents = torch.randn([5, 5]) + >>> vjps = vjpfunc(cotangents) + >>> assert len(vjps) == 2 + >>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1))) + >>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents)) + + ``primals`` are the positional arguments for ``f``. All kwargs use their + default value + + >>> x = torch.randn([5]) + >>> def f(x, scale=4.): + >>> return x * scale + >>> + >>> (_, vjpfunc) = torch.func.vjp(f, x) + >>> vjps = vjpfunc(torch.ones_like(x)) + >>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.)) + + .. note:: + Using PyTorch ``torch.no_grad`` together with ``vjp``. + Case 1: Using ``torch.no_grad`` inside a function: + + >>> def f(x): + >>> with torch.no_grad(): + >>> c = x ** 2 + >>> return x - c + + In this case, ``vjp(f)(x)`` will respect the inner ``torch.no_grad``. + + Case 2: Using ``vjp`` inside ``torch.no_grad`` context manager: + + >>> # xdoctest: +SKIP(failing) + >>> with torch.no_grad(): + >>> vjp(f)(x) + + In this case, ``vjp`` will respect the inner ``torch.no_grad``, but not the + outer one. This is because ``vjp`` is a "function transform": its result + should not depend on the result of a context manager outside of ``f``. + """ + return _vjp_with_argnums(func, *primals, has_aux=has_aux) + + +@contextlib.contextmanager +def grad_increment_nesting(): + try: + grad_level = _grad_increment_nesting() + yield grad_level + finally: + _grad_decrement_nesting() + + +def enter_jvp_nesting(): + global JVP_NESTING + jvp_level = _jvp_increment_nesting() + JVP_NESTING += 1 + return jvp_level + + +def exit_jvp_nesting(): + global JVP_NESTING + _jvp_decrement_nesting() + JVP_NESTING -= 1 + + +@contextlib.contextmanager +def jvp_increment_nesting(): + try: + yield enter_jvp_nesting() + finally: + exit_jvp_nesting() + + +@doesnt_support_saved_tensors_hooks +def _vjp_with_argnums( + func: Callable, *primals, argnums: Optional[argnums_t] = None, has_aux: bool = False +): + # This is the same function as vjp but also accepts an argnums argument + # All args are the same as vjp except for the added argument + # argnums (Optional[int or tuple[int]]): Optional, specifies the argument(s) to compute gradients with respect to. + # If None, computes the gradients with respect to all inputs (used for vjp). 
Default: None + # + # WARN: Users should NOT call this function directly and should just be calling vjp. + # It is only separated so that inputs passed to jacrev but not differentiated get the correct wrappers. + # + # NOTE: All error messages are produced as if vjp was being called, even if this was called by jacrev + # + # Returns the same two elements as :func:`vjp` but the function returned, vjp_fn, returns a tuple of VJPs + # for only the primal elements given by argnums. + with grad_increment_nesting() as level: + # See NOTE [grad and vjp interaction with no_grad] + with torch.enable_grad(): + primals = _wrap_all_tensors(primals, level) + # Note for the reviewer: This is extremely odd but it passes the + # assertion "len(self.block_stack) == 1" on symbolic_convert.py + # The equivalent "if argnums is None" fails for some reason + if not isinstance(argnums, int) and not argnums: + diff_primals = _create_differentiable(primals, level) + else: + diff_primals = _slice_argnums(primals, argnums, as_tuple=False) + tree_map_(partial(_create_differentiable, level=level), diff_primals) + primals_out = func(*primals) + + if has_aux: + if not (isinstance(primals_out, tuple) and len(primals_out) == 2): + raise RuntimeError( + "vjp(f, *primals): output of function f should be a tuple: (output, aux) " + "if has_aux is True" + ) + primals_out, aux = primals_out + aux = _undo_create_differentiable(aux, level) + + flat_primals_out, primals_out_spec = tree_flatten(primals_out) + assert_non_empty_tensor_output(flat_primals_out, "vjp(f, *primals)") + flat_diff_primals, primals_spec = tree_flatten(diff_primals) + results = _undo_create_differentiable(primals_out, level) + + for primal_out in flat_primals_out: + assert isinstance(primal_out, torch.Tensor) + if primal_out.is_floating_point() or primal_out.is_complex(): + continue + raise RuntimeError( + "vjp(f, ...): All outputs of f must be " + "floating-point or complex Tensors, got Tensor " + f"with dtype {primal_out.dtype}" + ) + + def wrapper(cotangents, retain_graph=True, create_graph=None): + if create_graph is None: + create_graph = torch.is_grad_enabled() + flat_cotangents, cotangents_spec = tree_flatten(cotangents) + _vjp_treespec_compare(primals_out, cotangents) + result = _autograd_grad( + flat_primals_out, + flat_diff_primals, + flat_cotangents, + retain_graph=retain_graph, + create_graph=create_graph, + ) + return tree_unflatten(result, primals_spec) + + if has_aux: + return results, wrapper, aux + else: + return results, wrapper + + +def _safe_zero_index(x): + assert len(x) == 1 + return x[0] + + +# jacrev and jacfwd don't support complex functions +# Helper function to throw appropriate error. +def error_if_complex(func_name, args, is_input): + flat_args = pytree.tree_leaves(args) + for idx, arg in enumerate(flat_args): + if isinstance(arg, torch.Tensor) and arg.dtype.is_complex: + input_or_output = "inputs" if is_input else "outputs" + err_msg = ( + f"{func_name}: Expected all {input_or_output} " + f"to be real but received complex tensor at flattened input idx: {idx}" + ) + raise RuntimeError(err_msg) + + +@exposed_in("torch.func") +def jacrev( + func: Callable, + argnums: Union[int, Tuple[int]] = 0, + *, + has_aux=False, + chunk_size: Optional[int] = None, + _preallocate_and_copy=False, +): + """ + Computes the Jacobian of ``func`` with respect to the arg(s) at index + ``argnum`` using reverse mode autodiff + + .. note:: + Using :attr:`chunk_size=1` is equivalent to computing the jacobian + row-by-row with a for-loop i.e. 
the constraints of :func:`vmap` are + not applicable. + + Args: + func (function): A Python function that takes one or more arguments, + one of which must be a Tensor, and returns one or more Tensors + argnums (int or Tuple[int]): Optional, integer or tuple of integers, + saying which arguments to get the Jacobian with respect to. + Default: 0. + has_aux (bool): Flag indicating that ``func`` returns a + ``(output, aux)`` tuple where the first element is the output of + the function to be differentiated and the second element is + auxiliary objects that will not be differentiated. + Default: False. + chunk_size (None or int): If None (default), use the maximum chunk size + (equivalent to doing a single vmap over vjp to compute the jacobian). + If 1, then compute the jacobian row-by-row with a for-loop. + If not None, then compute the jacobian :attr:`chunk_size` rows at a time + (equivalent to doing multiple vmap over vjp). If you run into memory issues computing + the jacobian, please try to specify a non-None chunk_size. + + Returns: + Returns a function that takes in the same inputs as ``func`` and + returns the Jacobian of ``func`` with respect to the arg(s) at + ``argnums``. If ``has_aux is True``, then the returned function + instead returns a ``(jacobian, aux)`` tuple where ``jacobian`` + is the Jacobian and ``aux`` is auxiliary objects returned by ``func``. + + A basic usage with a pointwise, unary operation will give a diagonal array + as the Jacobian + + >>> from torch.func import jacrev + >>> x = torch.randn(5) + >>> jacobian = jacrev(torch.sin)(x) + >>> expected = torch.diag(torch.cos(x)) + >>> assert torch.allclose(jacobian, expected) + + If you would like to compute the output of the function as well as the + jacobian of the function, use the ``has_aux`` flag to return the output + as an auxiliary object: + + >>> from torch.func import jacrev + >>> x = torch.randn(5) + >>> + >>> def f(x): + >>> return x.sin() + >>> + >>> def g(x): + >>> result = f(x) + >>> return result, result + >>> + >>> jacobian_f, f_x = jacrev(g, has_aux=True)(x) + >>> assert torch.allclose(f_x, f(x)) + + :func:`jacrev` can be composed with vmap to produce batched + Jacobians: + + >>> from torch.func import jacrev, vmap + >>> x = torch.randn(64, 5) + >>> jacobian = vmap(jacrev(torch.sin))(x) + >>> assert jacobian.shape == (64, 5, 5) + + Additionally, :func:`jacrev` can be composed with itself to produce + Hessians + + >>> from torch.func import jacrev + >>> def f(x): + >>> return x.sin().sum() + >>> + >>> x = torch.randn(5) + >>> hessian = jacrev(jacrev(f))(x) + >>> assert torch.allclose(hessian, torch.diag(-x.sin())) + + By default, :func:`jacrev` computes the Jacobian with respect to the first + input. 
However, it can compute the Jacboian with respect to a different + argument by using ``argnums``: + + >>> from torch.func import jacrev + >>> def f(x, y): + >>> return x + y ** 2 + >>> + >>> x, y = torch.randn(5), torch.randn(5) + >>> jacobian = jacrev(f, argnums=1)(x, y) + >>> expected = torch.diag(2 * y) + >>> assert torch.allclose(jacobian, expected) + + Additionally, passing a tuple to ``argnums`` will compute the Jacobian + with respect to multiple arguments + + >>> from torch.func import jacrev + >>> def f(x, y): + >>> return x + y ** 2 + >>> + >>> x, y = torch.randn(5), torch.randn(5) + >>> jacobian = jacrev(f, argnums=(0, 1))(x, y) + >>> expectedX = torch.diag(torch.ones_like(x)) + >>> expectedY = torch.diag(2 * y) + >>> assert torch.allclose(jacobian[0], expectedX) + >>> assert torch.allclose(jacobian[1], expectedY) + + .. note:: + Using PyTorch ``torch.no_grad`` together with ``jacrev``. + Case 1: Using ``torch.no_grad`` inside a function: + + >>> def f(x): + >>> with torch.no_grad(): + >>> c = x ** 2 + >>> return x - c + + In this case, ``jacrev(f)(x)`` will respect the inner ``torch.no_grad``. + + Case 2: Using ``jacrev`` inside ``torch.no_grad`` context manager: + + >>> with torch.no_grad(): + >>> jacrev(f)(x) + + In this case, ``jacrev`` will respect the inner ``torch.no_grad``, but not the + outer one. This is because ``jacrev`` is a "function transform": its result + should not depend on the result of a context manager outside of ``f``. + """ + if not (chunk_size is None or chunk_size > 0): + raise ValueError("jacrev: `chunk_size` should be greater than 0.") + + def wrapper_fn(*args): + error_if_complex("jacrev", args, is_input=True) + vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux) + if has_aux: + output, vjp_fn, aux = vjp_out + else: + output, vjp_fn = vjp_out + + # See NOTE: [Computing jacobian with vmap and vjp for multiple outputs] + flat_output, output_spec = tree_flatten(output) + + error_if_complex("jacrev", flat_output, is_input=False) + + # NB: vjp already checks that all outputs are tensors + # Step 1: Construct grad_outputs by splitting the standard basis + flat_output_numels = tuple(out.numel() for out in flat_output) + + primals = _slice_argnums(args, argnums) + flat_primals, primals_spec = tree_flatten(primals) + + def compute_jacobian_stacked(): + # Helper function to compute chunked Jacobian + # The intermediate chunked calculation are only + # scoped at this function level. + chunked_results = [] + for flat_basis_chunk in _chunked_standard_basis_for_( + flat_output, flat_output_numels, chunk_size=chunk_size + ): + if chunk_size == 1: + # sanity check. + for t in flat_basis_chunk: + assert t.size(0) == 1 + + flat_basis_chunk = tree_map( + lambda t: torch.squeeze(t, 0), flat_basis_chunk + ) + + basis = tree_unflatten(flat_basis_chunk, output_spec) + + if chunk_size == 1: + # Behaviour with `chunk_size=1` is same as `for-loop` + # i.e. user shouldn't deal with the limitations of vmap. + chunked_result = vjp_fn(basis) + else: # chunk_size is None or chunk_size != 1 + chunked_result = vmap(vjp_fn)(basis) + + flat_results = pytree.tree_leaves(chunked_result) + + if chunk_size == 1: + flat_results = tree_map( + lambda t: torch.unsqueeze(t, 0), flat_results + ) + + chunked_results.append(flat_results) + + if len(chunked_results) == 1: + # Short-circuit if we used a single chunk + return chunked_results[0] + + # Concatenate chunks. + flat_results = [] + # Iterate and concat the jacobians of different + # inputs. 
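+ # `chunked_results` is a list over chunks; each entry is a flat list over
+ # primals, where entry [idx] has shape (chunk_rows, *flat_primals[idx].shape).
+ # Gathering index `idx` across all chunks and concatenating along dim 0
+ # rebuilds the full (total_output_numel, *primal.shape) block for that primal.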
+ for idx in range(len(flat_primals)): + r = tuple(r_[idx] for r_ in chunked_results) + flat_results.append(torch.cat(r, 0)) + + return flat_results + + def compute_jacobian_preallocate_and_copy(): + # Helper function to compute chunked Jacobian + # The intermediate chunked calculation are only + # scoped at this function level. + out_vec_size = sum(flat_output_numels) + + # Don't pre-allocate if we have a single chunk. + if not (chunk_size is None or chunk_size >= out_vec_size): + stacked_results = [ + primal.new_zeros(out_vec_size, *primal.shape) + for primal in flat_primals + ] + + for idx, flat_basis_chunk in enumerate( + _chunked_standard_basis_for_( + flat_output, flat_output_numels, chunk_size=chunk_size + ) + ): + if chunk_size == 1: + # sanity check. + for t in flat_basis_chunk: + assert t.size(0) == 1 + + flat_basis_chunk = [torch.squeeze(t, 0) for t in flat_basis_chunk] + + basis = tree_unflatten(flat_basis_chunk, output_spec) + + if chunk_size == 1: + # Behaviour with `chunk_size=1` is same as `for-loop` + # i.e. user shouldn't deal with the limitations of vmap. + chunked_result = vjp_fn(basis) + else: # chunk_size is None or chunk_size != 1 + chunked_result = vmap(vjp_fn)(basis) + + flat_results = pytree.tree_leaves(chunked_result) + + # Short-circuit if we have a single chunk. + if chunk_size is None or chunk_size >= out_vec_size: + if chunk_size == 1: # and out_vec_size == 1 + # Since we squeezed the output dim + flat_results = tree_map( + lambda t: torch.unsqueeze(t, 0), flat_results + ) + return flat_results + + for r, sr in zip(flat_results, stacked_results): + sr[idx * chunk_size : (idx + 1) * chunk_size].copy_(r) + + return stacked_results + + if _preallocate_and_copy: + flat_jacobians_per_input = compute_jacobian_preallocate_and_copy() + else: + flat_jacobians_per_input = compute_jacobian_stacked() + + # Step 2: The returned jacobian is one big tensor per input. In this step, + # we split each Tensor by output. + flat_jacobians_per_input = [ + result.split(flat_output_numels, dim=0) + for result in flat_jacobians_per_input + ] + flat_input_flat_output = [ + tuple( + split.view(out.shape + primal.shape) + for split, out in zip(splits, flat_output) + ) + for splits, primal in zip(flat_jacobians_per_input, flat_primals) + ] + + # Step 3: Right now, `jacobian` is a List[List[Tensor]]. + # The outer List corresponds to the number of primals, + # the inner List corresponds to the number of outputs. + # We need to: + # a. Exchange the order of the outer List and inner List + # b. tree_unflatten the inner Lists (which correspond to the primals) + # c. handle the argnums=int case + # d. tree_unflatten the outer List (which corresponds to the outputs) + flat_output_flat_input = tuple(zip(*flat_input_flat_output)) + + flat_output_input = tuple( + tree_unflatten(flat_input, primals_spec) + for flat_input in flat_output_flat_input + ) + + if isinstance(argnums, int): + flat_output_input = tuple( + _safe_zero_index(flat_input) for flat_input in flat_output_input + ) + output_input = tree_unflatten(flat_output_input, output_spec) + if has_aux: + return output_input, aux + return output_input + + # Dynamo does not support HOP composition if their inner function is + # annotated with @functools.wraps(...). We circumvent this issue by applying + # wraps only if we're not tracing with dynamo. 
+ if not torch._dynamo.is_compiling(): + wrapper_fn = wraps(func)(wrapper_fn) + + return wrapper_fn + + +# NOTE: [Computing jacobian with vmap and vjp for multiple outputs] +# +# Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3). +# It turns out we can compute the jacobian of this function with a single +# call to autograd.grad by using vmap over the correct grad_outputs. +# +# Firstly, one way to compute the jacobian is to stack x**2 and x.sum() +# into a 4D vector. E.g., use g(x) = torch.stack([x**2, x.sum()]) +# +# To get the first row of the jacobian, we call +# >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1, 0, 0, 0])) +# To get the 2nd row of the jacobian, we call +# >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0, 1, 0, 0])) +# and so on. +# +# Using vmap, we can vectorize all 4 of these computations into one by +# passing the standard basis for R^4 as the grad_output. +# vmap(partial(autograd.grad, g(x), x))(torch.eye(4)). +# +# Now, how do we compute the jacobian *without stacking the output*? +# We can just split the standard basis across the outputs. So to +# compute the jacobian of f(x), we'd use +# >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...)) +# The grad_outputs looks like the following: +# ( torch.tensor([[1, 0, 0], +# [0, 1, 0], +# [0, 0, 1], +# [0, 0, 0]]), +# torch.tensor([[0], +# [0], +# [0], +# [1]]) ) +# +# But we're not done yet! +# >>> vmap(partial(autograd.grad(f(x), x, grad_outputs=...))) +# returns a Tensor of shape [4, 3]. We have to remember to split the +# jacobian of shape [4, 3] into two: +# - one of shape [3, 3] for the first output +# - one of shape [ 3] for the second output + + +def _chunked_standard_basis_for_(tensors, tensor_numels, chunk_size=None): + # This function: + # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix. + # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`. + # - Each chunk corresponds to one tensor. The chunk has the same dtype and + # device as the tensor + # + # For example, with tensor_numels = [1, 2, 1], this function returns: + # ( tensor([[1], tensor([[0, 0], tensor([[0], + # [0], [1, 0], [0], + # [0], [0, 1], [0], + # [0]]) , [0, 0]]) , [1]]) ) + # + # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors) + # Precondition: tensors always has at least one element. + # + # See NOTE: [Computing jacobian with vmap and grad for multiple tensors] + # for context behind this function. + # NOTE: Argument `chunk_size` is used to generate chunked basis instead of + # one huge basis matrix. `chunk_size` dictates the maximum size of the + # basis matrix along dim=0. 
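+ # For illustration, with the same tensor_numels = [1, 2, 1] and chunk_size=2
+ # (assuming get_chunk_sizes(4, 2) yields two chunks of 2 rows each), the basis
+ # above is yielded in two pieces:
+ #   chunk 0: ( [[1], [0]], [[0, 0], [1, 0]], [[0], [0]] )
+ #   chunk 1: ( [[0], [0]], [[0, 1], [0, 0]], [[0], [1]] )
+ # i.e. rows 0-1 of the NxN identity followed by rows 2-3, each split across the tensors.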
+ assert len(tensors) == len(tensor_numels) + assert len(tensors) > 0 + assert chunk_size is None or chunk_size > 0 + total_numel = sum(tensor_numels) + if chunk_size and chunk_size < total_numel: + chunk_numels = get_chunk_sizes(total_numel, chunk_size) + else: # chunk_size is None or chunk_size >= total_numel + chunk_size = total_numel + chunk_numels = [total_numel] + + diag_start_indices = ( + 0, + *torch.tensor(tensor_numels).cumsum(dim=0)[:-1].neg().unbind(), + ) + + for chunk_idx, total_numel in enumerate(chunk_numels): + chunks = tuple( + tensor.new_zeros(total_numel, tensor_numel) + for tensor, tensor_numel in zip(tensors, tensor_numels) + ) + + for chunk, diag_start_idx in zip(chunks, diag_start_indices): + chunk.diagonal(diag_start_idx + chunk_idx * chunk_size).fill_(1) + chunks = tuple( + chunk.view(total_numel, *tensor.shape) + for chunk, tensor in zip(chunks, tensors) + ) + yield chunks + + +def _construct_standard_basis_for(tensors, tensor_numels): + for basis in _chunked_standard_basis_for_(tensors, tensor_numels, chunk_size=None): + return basis + + +def _validate_and_wrap_argnum(argnum, num_args): + if not isinstance(argnum, int): + raise RuntimeError(f"argnum must be int, got: {type(argnum)}") + if argnum >= 0 and argnum < num_args: + return argnum + if argnum < 0 and argnum >= -num_args: + return argnum + num_args + raise RuntimeError(f"Got argnum={argnum}, but only {num_args} positional inputs") + + +def _check_unique_non_empty(argnums): + if isinstance(argnums, tuple): + if len(argnums) == 0: + raise RuntimeError("argnums must be non-empty") + if len(set(argnums)) != len(argnums): + raise RuntimeError(f"argnums elements must be unique, got {argnums}") + + +def _replace_args(old_args, new_args, argnums): + if isinstance(argnums, int): + if len(new_args) != 1: + raise RuntimeError( + f"new_args should be of size 1, was of size {len(new_args)}" + ) + return tuple( + new_args[0] if i == argnums else old_args[i] for i in range(len(old_args)) + ) + if isinstance(argnums, tuple): + if len(new_args) != len(argnums): + raise RuntimeError( + "new_args should have the same size as argnums. 
" + f"Argnums size {len(argnums)}, new_args size {len(new_args)}" + ) + + def get_right_elem(i): + return new_args[argnums.index(i)] if i in argnums else old_args[i] + + return tuple(get_right_elem(i) for i in range(len(old_args))) + raise RuntimeError(f"argnums must be int or Tuple[int, ...], got: {type(argnums)}") + + +def _validate_and_wrap_argnums(argnums, num_args): + if isinstance(argnums, int): + return _validate_and_wrap_argnum(argnums, num_args) + if isinstance(argnums, tuple): + return tuple(_validate_and_wrap_argnum(argnum, num_args) for argnum in argnums) + raise AssertionError("Should never get here") + + +def _slice_argnums(args, argnums, as_tuple=True): + if not isinstance(argnums, int) and not isinstance(argnums, tuple): + raise RuntimeError( + f"argnums must be int or Tuple[int, ...], got: {type(argnums)}" + ) + argnums = _validate_and_wrap_argnums(argnums, len(args)) + _check_unique_non_empty(argnums) + if isinstance(argnums, int): + if as_tuple: + return (args[argnums],) + else: + return args[argnums] + return tuple(args[i] for i in argnums) + + +JVP_NESTING = 0 + + +def assert_flat_tuple_of_tensors(elts: Any, api: str, argname: str) -> None: + if not isinstance(elts, tuple): + raise RuntimeError( + f"{api}: Expected {argname} to be a tuple of Tensors, got {type(elts)}" + ) + for elt in elts: + if isinstance(elt, torch.Tensor): + continue + raise RuntimeError( + f"{api}: Expected {argname} to be a tuple of Tensors, got " + f"a tuple with an element of type {type(elt)}" + ) + if len(elts) == 0: + raise RuntimeError( + f"{api}: Expected {argname} to be a non-empty tuple of Tensors." + ) + + +def assert_non_empty_tensor_output(output: List[Any], api: str) -> None: + if (len(output) == 1 and output[0] is None) or len(output) < 1: + raise RuntimeError( + f"{api}: Expected f to be a function that has non-empty output (got output = {output})" + ) + for o in output: + if not isinstance(o, torch.Tensor): + raise RuntimeError( + f"{api}: expected f(*primals) to return only tensors" + f", got unsupported type {type(o)}" + ) + + +def assert_output_is_tensor_or_tensors(output: Any, api: str) -> None: + if isinstance(output, torch.Tensor): + return + if not isinstance(output, tuple): + raise RuntimeError( + f"{api}: Expected output of f to be a Tensor or Tensors, got " + f"{type(output)}" + ) + if len(output) == 0: + raise RuntimeError( + f"{api}: Expected output of f to be a non-empty tuple of Tensors." + ) + for out in output: + if isinstance(out, torch.Tensor): + continue + raise RuntimeError( + f"{api}: Expected output of f to be a Tensor or Tensors, got " + f"{type(out)} as an output" + ) + + +def assert_non_empty_list_of_tensors( + output: List[torch.Tensor], api: str, argname: str +) -> None: + if len(output) == 0: + raise RuntimeError(f"{api}: Expected {argname} to contain at least one Tensor.") + for out in output: + if isinstance(out, torch.Tensor): + continue + raise RuntimeError( + f"{api}: Expected {argname} to only contain Tensors, got " f"{type(out)}" + ) + + +jvp_str = "jvp(f, primals, tangents)" + + +def safe_unpack_dual(dual, strict): + if not isinstance(dual, torch.Tensor): + raise RuntimeError( + f"{jvp_str}: expected f(*args) to return only tensors" + f", got unsupported type {type(dual)}" + ) + + primal, tangent = fwAD.unpack_dual(dual) + if tangent is None: + if strict: + raise RuntimeError( + "jvp(f, primals, tangents, strict=True): " + "The output of f is independent of " + "the inputs. This is not allowed with strict=True." 
+ ) + tangent = torch.zeros_like(primal) + return primal, tangent + + +@exposed_in("torch.func") +def jvp( + func: Callable, + primals: Any, + tangents: Any, + *, + strict: bool = False, + has_aux: bool = False, +): + """ + Standing for the Jacobian-vector product, returns a tuple containing + the output of `func(*primals)` and the "Jacobian of ``func`` evaluated at + ``primals``" times ``tangents``. This is also known as forward-mode autodiff. + + Args: + func (function): A Python function that takes one or more arguments, + one of which must be a Tensor, and returns one or more Tensors + primals (Tensors): Positional arguments to ``func`` that must all be + Tensors. The returned function will also be computing the + derivative with respect to these arguments + tangents (Tensors): The "vector" for which Jacobian-vector-product is + computed. Must be the same structure and sizes as the inputs to + ``func``. + has_aux (bool): Flag indicating that ``func`` returns a + ``(output, aux)`` tuple where the first element is the output of + the function to be differentiated and the second element is + other auxiliary objects that will not be differentiated. + Default: False. + + Returns: + Returns a ``(output, jvp_out)`` tuple containing the output of ``func`` + evaluated at ``primals`` and the Jacobian-vector product. + If ``has_aux is True``, then instead returns a ``(output, jvp_out, aux)`` tuple. + + .. note:: + You may see this API error out with "forward-mode AD not implemented + for operator X". If so, please file a bug report and we will prioritize it. + + jvp is useful when you wish to compute gradients of a function R^1 -> R^N + + >>> from torch.func import jvp + >>> x = torch.randn([]) + >>> f = lambda x: x * torch.tensor([1., 2., 3]) + >>> value, grad = jvp(f, (x,), (torch.tensor(1.),)) + >>> assert torch.allclose(value, f(x)) + >>> assert torch.allclose(grad, torch.tensor([1., 2, 3])) + + :func:`jvp` can support functions with multiple inputs by passing in the + tangents for each of the inputs + + >>> from torch.func import jvp + >>> x = torch.randn(5) + >>> y = torch.randn(5) + >>> f = lambda x, y: (x * y) + >>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5))) + >>> assert torch.allclose(output, x + y) + + """ + + return _jvp_with_argnums( + func, primals, tangents, argnums=None, strict=strict, has_aux=has_aux + ) + + +def _jvp_with_argnums( + func: Callable, + primals: Any, + tangents: Any, + argnums: Optional[argnums_t], + *, + strict: bool = False, + has_aux: bool, +): + # This is the same function as jvp but also accepts an argnums argument + # Most args are the same as jvp except for the added argument + # argnums (Optional[int or tuple[int]]): Optional, specifies the argument(s) to compute gradients with respect to. + # If None, computes the gradients with respect to all inputs (used for jvp). Default: None + # Because of this, tangents must be of length argnums and matches up to the corresponding primal whose index is + # given by argnums + # + # WARN: Users should NOT call this function directly and should just be calling jvp. + # It is only separated so that inputs passed to jacfwd but not differentiated get the correct wrappers. 
+ # + # NOTE: All error messages are produced as if jvp was being called, even if this was called by jacfwd + # + # Returns the same two elements as :func:`jvp` but the returned tuple, ``jvp_out``, only has JVPs with respect to + # the primals given by argnums + if not isinstance(primals, tuple): + raise RuntimeError( + f"{jvp_str}: Expected primals to be a tuple. " + f"E.g. it should be valid to call f(*primals)." + ) + diff_args = primals if argnums is None else _slice_argnums(primals, argnums) + flat_primals, primals_spec = tree_flatten(diff_args) + flat_tangents, tangents_spec = tree_flatten(tangents) + _jvp_treespec_compare(diff_args, tangents) + assert_non_empty_list_of_tensors(flat_primals, jvp_str, "primals") + assert_non_empty_list_of_tensors(flat_tangents, jvp_str, "tangents") + + global JVP_NESTING + + with jvp_increment_nesting() as level: + with fwAD._set_fwd_grad_enabled(True): + ctx = fwAD.dual_level if JVP_NESTING == 1 else contextlib.nullcontext + with ctx(): + flat_duals = tuple( + fwAD.make_dual(p, t) for p, t in zip(flat_primals, flat_tangents) + ) + duals = tree_unflatten(flat_duals, primals_spec) + # Note for the reviewer: This is extremely odd but it passes the + # assertion "len(self.block_stack) == 1" on symbolic_convert.py + # The equivalent "if argnums is not None" fails for some reason + if isinstance(argnums, (int, tuple)): + primals = _wrap_all_tensors(primals, level) + duals = _replace_args(primals, duals, argnums) + result_duals = func(*duals) + if has_aux: + if not (isinstance(result_duals, tuple) and len(result_duals) == 2): + raise RuntimeError( + f"{jvp_str}: output of function f should be a tuple: (output, aux) " + "if has_aux is True" + ) + result_duals, aux = result_duals + aux = _undo_create_differentiable(aux, level) + + result_duals, spec = tree_flatten(result_duals) + assert_non_empty_tensor_output(result_duals, jvp_str) + + primals_out, tangents_out = zip( + *[safe_unpack_dual(dual, strict) for dual in result_duals] + ) + primals_out = tree_map( + partial(_undo_create_differentiable, level=level), primals_out + ) + tangents_out = tree_map( + partial(_undo_create_differentiable, level=level), tangents_out + ) + + primals_out_unflatten = tree_unflatten(primals_out, spec) + tangents_out_unflatten = tree_unflatten(tangents_out, spec) + if has_aux: + return primals_out_unflatten, tangents_out_unflatten, aux + + return primals_out_unflatten, tangents_out_unflatten + + +def safe_unflatten(tensor, dim, shape): + if len(shape) == 0: + assert tensor.shape[dim] == 1 + return tensor.squeeze(dim) + return tensor.unflatten(dim, shape) + + +@exposed_in("torch.func") +def jacfwd( + func: Callable, + argnums: argnums_t = 0, + has_aux: bool = False, + *, + randomness: str = "error", +): + """ + Computes the Jacobian of ``func`` with respect to the arg(s) at index + ``argnum`` using forward-mode autodiff + + Args: + func (function): A Python function that takes one or more arguments, + one of which must be a Tensor, and returns one or more Tensors + argnums (int or Tuple[int]): Optional, integer or tuple of integers, + saying which arguments to get the Jacobian with respect to. + Default: 0. + has_aux (bool): Flag indicating that ``func`` returns a + ``(output, aux)`` tuple where the first element is the output of + the function to be differentiated and the second element is + auxiliary objects that will not be differentiated. + Default: False. + randomness(str): Flag indicating what type of randomness to use. + See :func:`vmap` for more detail. 
Allowed: "different", "same", "error". + Default: "error" + + Returns: + Returns a function that takes in the same inputs as ``func`` and + returns the Jacobian of ``func`` with respect to the arg(s) at + ``argnums``. If ``has_aux is True``, then the returned function + instead returns a ``(jacobian, aux)`` tuple where ``jacobian`` + is the Jacobian and ``aux`` is auxiliary objects returned by ``func``. + + .. note:: + You may see this API error out with "forward-mode AD not implemented + for operator X". If so, please file a bug report and we will prioritize it. + An alternative is to use :func:`jacrev`, which has better operator coverage. + + A basic usage with a pointwise, unary operation will give a diagonal array + as the Jacobian + + >>> from torch.func import jacfwd + >>> x = torch.randn(5) + >>> jacobian = jacfwd(torch.sin)(x) + >>> expected = torch.diag(torch.cos(x)) + >>> assert torch.allclose(jacobian, expected) + + :func:`jacfwd` can be composed with vmap to produce batched + Jacobians: + + >>> from torch.func import jacfwd, vmap + >>> x = torch.randn(64, 5) + >>> jacobian = vmap(jacfwd(torch.sin))(x) + >>> assert jacobian.shape == (64, 5, 5) + + If you would like to compute the output of the function as well as the + jacobian of the function, use the ``has_aux`` flag to return the output + as an auxiliary object: + + >>> from torch.func import jacfwd + >>> x = torch.randn(5) + >>> + >>> def f(x): + >>> return x.sin() + >>> + >>> def g(x): + >>> result = f(x) + >>> return result, result + >>> + >>> jacobian_f, f_x = jacfwd(g, has_aux=True)(x) + >>> assert torch.allclose(f_x, f(x)) + + Additionally, :func:`jacrev` can be composed with itself or :func:`jacrev` + to produce Hessians + + >>> from torch.func import jacfwd, jacrev + >>> def f(x): + >>> return x.sin().sum() + >>> + >>> x = torch.randn(5) + >>> hessian = jacfwd(jacrev(f))(x) + >>> assert torch.allclose(hessian, torch.diag(-x.sin())) + + By default, :func:`jacfwd` computes the Jacobian with respect to the first + input. 
However, it can compute the Jacboian with respect to a different + argument by using ``argnums``: + + >>> from torch.func import jacfwd + >>> def f(x, y): + >>> return x + y ** 2 + >>> + >>> x, y = torch.randn(5), torch.randn(5) + >>> jacobian = jacfwd(f, argnums=1)(x, y) + >>> expected = torch.diag(2 * y) + >>> assert torch.allclose(jacobian, expected) + + Additionally, passing a tuple to ``argnums`` will compute the Jacobian + with respect to multiple arguments + + >>> from torch.func import jacfwd + >>> def f(x, y): + >>> return x + y ** 2 + >>> + >>> x, y = torch.randn(5), torch.randn(5) + >>> jacobian = jacfwd(f, argnums=(0, 1))(x, y) + >>> expectedX = torch.diag(torch.ones_like(x)) + >>> expectedY = torch.diag(2 * y) + >>> assert torch.allclose(jacobian[0], expectedX) + >>> assert torch.allclose(jacobian[1], expectedY) + + """ + + def wrapper_fn(*args): + error_if_complex("jacfwd", args, is_input=True) + primals = args if argnums is None else _slice_argnums(args, argnums) + flat_primals, primals_spec = tree_flatten(primals) + flat_primals_numels = tuple(p.numel() for p in flat_primals) + flat_basis = _construct_standard_basis_for(flat_primals, flat_primals_numels) + basis = tree_unflatten(flat_basis, primals_spec) + + def push_jvp(basis): + output = _jvp_with_argnums( + func, args, basis, argnums=argnums, has_aux=has_aux + ) + # output[0] is the output of `func(*args)` + error_if_complex("jacfwd", output[0], is_input=False) + if has_aux: + _, jvp_out, aux = output + return jvp_out, aux + _, jvp_out = output + return jvp_out + + results = vmap(push_jvp, randomness=randomness)(basis) + if has_aux: + results, aux = results + # aux is in the standard basis format, e.g. NxN matrix + # We need to fetch the first element as original `func` output + flat_aux, aux_spec = tree_flatten(aux) + flat_aux = [value[0] for value in flat_aux] + aux = tree_unflatten(flat_aux, aux_spec) + + jac_outs, spec = tree_flatten(results) + # Most probably below output check can never raise an error + # as jvp should test the output before + # assert_non_empty_output(jac_outs, 'jacfwd(f, ...)(*args)') + + jac_outs_ins = tuple( + tuple( + safe_unflatten(jac_out_in, -1, primal.shape) + for primal, jac_out_in in zip( + flat_primals, + jac_out.movedim(0, -1).split(flat_primals_numels, dim=-1), + ) + ) + for jac_out in jac_outs + ) + jac_outs_ins = tuple( + tree_unflatten(jac_ins, primals_spec) for jac_ins in jac_outs_ins + ) + + if isinstance(argnums, int): + jac_outs_ins = tuple(jac_ins[0] for jac_ins in jac_outs_ins) + if has_aux: + return tree_unflatten(jac_outs_ins, spec), aux + return tree_unflatten(jac_outs_ins, spec) + + # Dynamo does not support HOP composition if their inner function is + # annotated with @functools.wraps(...). We circumvent this issue by applying + # wraps only if we're not tracing with dynamo. + if not torch._dynamo.is_compiling(): + wrapper_fn = wraps(func)(wrapper_fn) + + return wrapper_fn + + +@exposed_in("torch.func") +def hessian(func, argnums=0): + """ + Computes the Hessian of ``func`` with respect to the arg(s) at index + ``argnum`` via a forward-over-reverse strategy. + + The forward-over-reverse strategy (composing ``jacfwd(jacrev(func))``) is + a good default for good performance. It is possible to compute Hessians + through other compositions of :func:`jacfwd` and :func:`jacrev` like + ``jacfwd(jacfwd(func))`` or ``jacrev(jacrev(func))``. 
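+
+ As an illustrative sketch, on a smooth function the different compositions
+ produce the same Hessian (up to numerical tolerance):
+
+ >>> # xdoctest: +SKIP
+ >>> from torch.func import hessian, jacfwd, jacrev
+ >>> x = torch.randn(3)
+ >>> f = lambda x: (x ** 3).sum()
+ >>> h1 = hessian(f)(x)          # forward-over-reverse, i.e. jacfwd(jacrev(f))
+ >>> h2 = jacrev(jacrev(f))(x)   # reverse-over-reverse
+ >>> assert torch.allclose(h1, h2)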
+ + Args: + func (function): A Python function that takes one or more arguments, + one of which must be a Tensor, and returns one or more Tensors + argnums (int or Tuple[int]): Optional, integer or tuple of integers, + saying which arguments to get the Hessian with respect to. + Default: 0. + + Returns: + Returns a function that takes in the same inputs as ``func`` and + returns the Hessian of ``func`` with respect to the arg(s) at + ``argnums``. + + .. note:: + You may see this API error out with "forward-mode AD not implemented + for operator X". If so, please file a bug report and we will prioritize it. + An alternative is to use ``jacrev(jacrev(func))``, which has better + operator coverage. + + A basic usage with a R^N -> R^1 function gives a N x N Hessian: + + >>> from torch.func import hessian + >>> def f(x): + >>> return x.sin().sum() + >>> + >>> x = torch.randn(5) + >>> hess = hessian(f)(x) # equivalent to jacfwd(jacrev(f))(x) + >>> assert torch.allclose(hess, torch.diag(-x.sin())) + + """ + return jacfwd(jacrev(func, argnums), argnums) + + +@doesnt_support_saved_tensors_hooks +def grad_and_value_impl(func, argnums, has_aux, args, kwargs) -> Callable: + with grad_increment_nesting() as level: + output, aux, grad_input = None, None, None + # See NOTE [grad and vjp interaction with no_grad] + with torch.enable_grad(): + args = _wrap_all_tensors(args, level) + kwargs = _wrap_all_tensors(kwargs, level) + diff_args = _slice_argnums(args, argnums, as_tuple=False) + tree_map_(partial(_create_differentiable, level=level), diff_args) + + output = func(*args, **kwargs) + if has_aux: + if not (isinstance(output, tuple) and len(output) == 2): + raise RuntimeError( + "grad_and_value(f)(*args): output of function f should be a tuple: (output, aux) " + "if has_aux is True" + ) + output, aux = output + + if not isinstance(output, torch.Tensor): + raise RuntimeError( + "grad_and_value(f)(*args): Expected f(*args) " + f"to return a Tensor, got {type(output)}" + ) + if output.dim() != 0: + raise RuntimeError( + "grad_and_value(f)(*args): Expected f(*args) " + "to return a scalar Tensor, got tensor with " + f"{output.dim()} dims. Maybe you wanted to " + "use the vjp or jacrev APIs instead?" 
+ ) + + flat_diff_args, spec = tree_flatten(diff_args) + + # NB: need create_graph so that backward pass isn't run in no_grad mode + flat_outputs = _as_tuple(output) + flat_grad_input = _autograd_grad( + flat_outputs, flat_diff_args, create_graph=True + ) + grad_input = tree_unflatten(flat_grad_input, spec) + + grad_input = _undo_create_differentiable(grad_input, level) + output = _undo_create_differentiable(output, level) + if has_aux: + aux = _undo_create_differentiable(aux, level) + + if has_aux: + return grad_input, (output, aux) + return grad_input, output + + +def grad_impl(func: Callable, argnums: argnums_t, has_aux: bool, args, kwargs): + results = grad_and_value_impl(func, argnums, has_aux, args, kwargs) + if has_aux: + grad, (_, aux) = results + return grad, aux + grad, _ = results + return grad + + +def _maybe_wrap_functional_tensor( + maybe_tensor, level, *, _python_functionalize: bool = False +): + if not isinstance(maybe_tensor, torch.Tensor): + return maybe_tensor + wrapped = _wrap_functional_tensor(maybe_tensor, level) + _assert_wrapped_functional(maybe_tensor, wrapped) + if _python_functionalize: + out = FunctionalTensor(wrapped) + torch._mirror_autograd_meta_to(maybe_tensor, out) + return out + return wrapped + + +def _wrap_all_tensors_to_functional( + tensor_pytree, level, *, _python_functionalize: bool = False +): + return tree_map( + partial( + lambda x: _maybe_wrap_functional_tensor( + x, level, _python_functionalize=_python_functionalize + ) + ), + tensor_pytree, + ) + + +def _maybe_unwrap_functional_tensor(maybe_tensor, *, reapply_views: bool): + if not isinstance(maybe_tensor, torch.Tensor): + return maybe_tensor + if isinstance(maybe_tensor, FunctionalTensor): + maybe_tensor = maybe_tensor.elem + + if not torch._is_functional_tensor(maybe_tensor): + # If it's not a functional tensor, just return it. + # This can happen if we functionalize a fn that returns a global, + # which was never wrapped properly. + return maybe_tensor + # Sync any pending updates on the output tensor + torch._sync(maybe_tensor) + return _unwrap_functional_tensor(maybe_tensor, reapply_views) + + +def _unwrap_all_tensors_from_functional(tensor_pytree, *, reapply_views: bool): + return tree_map( + lambda t: _maybe_unwrap_functional_tensor(t, reapply_views=reapply_views), + tensor_pytree, + ) + + +@exposed_in("torch.func") +def functionalize(func: Callable, *, remove: str = "mutations") -> Callable: + """ + functionalize is a transform that can be used to remove (intermediate) + mutations and aliasing from a function, while preserving the function's + semantics. + + ``functionalize(func)`` returns a new function with the same semantics + as ``func``, but with all intermediate mutations removed. + Every inplace operation performed on an intermediate tensor: + ``intermediate.foo_()`` + gets replaced by its out-of-place equivalent: + ``intermediate_updated = intermediate.foo()``. + + functionalize is useful for shipping a pytorch program off to + backends or compilers that aren't able to easily represent + mutations or aliasing operators. + + Args: + func (Callable): A Python function that takes one or more arguments. + remove (str): An optional string argument, that takes on either + the value 'mutations' or 'mutations_and_views'. + If 'mutations' is passed in then all mutating operators + will be replaced with their non-mutating equivalents. + If 'mutations_and_views' is passed in, then additionally, all aliasing + operators will be replaced with their non-aliasing equivalents. 
+ Default: 'mutations'. + + Returns: + Returns a new "functionalized" function. It takes the same inputs as + ``func``, and has the same behavior, but any mutations + (and optionally aliasing) performed on intermediate tensors + in the function will be removed. + + functionalize will also remove mutations (and views) that were performed on function inputs. + However to preserve semantics, functionalize will "fix up" the mutations after + the transform has finished running, by detecting if any tensor inputs "should have" + been mutated, and copying the new data back to the inputs if necessary. + + + Example:: + + >>> # xdoctest: +SKIP + >>> import torch + >>> from torch.fx.experimental.proxy_tensor import make_fx + >>> from torch.func import functionalize + >>> + >>> # A function that uses mutations and views, but only on intermediate tensors. + >>> def f(a): + ... b = a + 1 + ... c = b.view(-1) + ... c.add_(1) + ... return b + ... + >>> inpt = torch.randn(2) + >>> + >>> out1 = f(inpt) + >>> out2 = functionalize(f)(inpt) + >>> + >>> # semantics are the same (outputs are equivalent) + >>> print(torch.allclose(out1, out2)) + True + >>> + >>> f_traced = make_fx(f)(inpt) + >>> f_no_mutations_traced = make_fx(functionalize(f))(inpt) + >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt) + >>> + >>> print(f_traced.code) + + + + def forward(self, a_1): + add = torch.ops.aten.add(a_1, 1); a_1 = None + view = torch.ops.aten.view(add, [-1]) + add_ = torch.ops.aten.add_(view, 1); view = None + return add + + >>> print(f_no_mutations_traced.code) + + + + def forward(self, a_1): + add = torch.ops.aten.add(a_1, 1); a_1 = None + view = torch.ops.aten.view(add, [-1]); add = None + add_1 = torch.ops.aten.add(view, 1); view = None + view_1 = torch.ops.aten.view(add_1, [2]); add_1 = None + return view_1 + + >>> print(f_no_mutations_and_views_traced.code) + + + + def forward(self, a_1): + add = torch.ops.aten.add(a_1, 1); a_1 = None + view_copy = torch.ops.aten.view_copy(add, [-1]); add = None + add_1 = torch.ops.aten.add(view_copy, 1); view_copy = None + view_copy_1 = torch.ops.aten.view_copy(add_1, [2]); add_1 = None + return view_copy_1 + + + >>> # A function that mutates its input tensor + >>> def f(a): + ... b = a.view(-1) + ... b.add_(1) + ... return a + ... + >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt) + >>> # + >>> # All mutations and views have been removed, + >>> # but there is an extra copy_ in the graph to correctly apply the mutation to the input + >>> # after the function has completed. + >>> print(f_no_mutations_and_views_traced.code) + + + + def forward(self, a_1): + view_copy = torch.ops.aten.view_copy(a_1, [-1]) + add = torch.ops.aten.add(view_copy, 1); view_copy = None + view_copy_1 = torch.ops.aten.view_copy(add, [2]); add = None + copy_ = torch.ops.aten.copy_(a_1, view_copy_1); a_1 = None + return view_copy_1 + + + There are a few "failure modes" for functionalize that are worth calling out: + (1) Like other torch.func transforms, `functionalize()` doesn't work with functions + that directly use `.backward()`. The same is true for torch.autograd.grad. + If you want to use autograd, you can compute gradients directly + with `functionalize(grad(f))`. + (2) Like other torch.func transforms, `functionalize()` doesn't work with global state. 
+ If you call `functionalize(f)` on a function that takes views / mutations of
+ non-local state, functionalization will simply no-op and pass the view/mutation
+ calls directly to the backend.
+ One way to work around this is to ensure that any non-local state creation
+ is wrapped into a larger function, which you then call functionalize on.
+ (3) `resize_()` has some limitations: functionalize will only work on programs
+ that use `resize_()` as long as the tensor being resized is not a view.
+ (4) `as_strided()` has some limitations: functionalize will not work on
+ `as_strided()` calls that result in tensors with overlapping memory.
+
+
+ Finally, a helpful mental model for understanding functionalization is that
+ most user pytorch programs are written with the public torch API.
+ When executed, torch operators are generally decomposed into
+ our internal C++ "ATen" API.
+ The logic for functionalization happens entirely at the level of ATen.
+ Functionalization knows how to take every aliasing operator in ATen,
+ and map it to its non-aliasing equivalent
+ (e.g. ``tensor.view({-1})`` -> ``at::view_copy(tensor, {-1})``),
+ and how to take every mutating operator in ATen,
+ and map it to its non-mutating equivalent
+ (e.g. ``tensor.add_(1)`` -> ``at::add(tensor, 1)``),
+ while tracking aliases and mutations out-of-line to know when to fix things up.
+ Information about which ATen operators are aliasing or mutating all comes from
+ https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml.
+ """
+ if remove == "mutations":
+ reapply_views = True
+ elif remove == "mutations_and_views":
+ reapply_views = False
+ else:
+ raise RuntimeError(
+ f"functionalize(f, remove='mutations'): received invalid argument for remove={remove}."
+ " Valid options are:\n"
+ " remove='mutations': all inplace and out= operators will be removed from the program, and replaced"
+ " with their out-of-place equivalents.\n"
+ " remove='mutations_and_views': In addition to the above, all aliasing operators {view} will be"
+ " replaced with their non-aliasing counterparts, {view}_copy.\n"
+ )
+
+ @wraps(func)
+ def wrapped(*args, **kwargs):
+ try:
+ func_level = _func_increment_nesting(reapply_views)
+ func_args = _wrap_all_tensors_to_functional(args, func_level)
+ func_kwargs = _wrap_all_tensors_to_functional(kwargs, func_level)
+
+ flattened_unwrapped_args = pytree.arg_tree_leaves(*args)
+ flattened_wrapped_args = pytree.arg_tree_leaves(*func_args)
+ flattened_unwrapped_kwargs = pytree.arg_tree_leaves(**kwargs)
+ flattened_wrapped_kwargs = pytree.arg_tree_leaves(**func_kwargs)
+
+ func_outputs = func(*func_args, **func_kwargs)
+ outputs = _unwrap_all_tensors_from_functional(
+ func_outputs, reapply_views=reapply_views
+ )
+ flat_outputs, func_out_spec = tree_flatten(outputs)
+
+ for a in flattened_wrapped_args + flattened_wrapped_kwargs:
+ if isinstance(a, torch.Tensor):
+ # Call sync_() on the inputs, to ensure that any pending mutations have been applied.
+ torch._sync(a)
+
+ # And if any mutations were applied to the inputs, we need to propagate them back to the user.
+ for unwrapped, wrapped in zip( + flattened_unwrapped_args, flattened_wrapped_args + ): + if isinstance(unwrapped, torch.Tensor) and isinstance( + wrapped, torch.Tensor + ): + _propagate_functional_input_mutation(unwrapped, wrapped) + for unwrapped, wrapped in zip( + flattened_unwrapped_kwargs, flattened_wrapped_kwargs + ): + if isinstance(unwrapped, torch.Tensor) and isinstance( + wrapped, torch.Tensor + ): + _propagate_functional_input_mutation(unwrapped, wrapped) + + return outputs + finally: + _func_decrement_nesting() + + return wrapped + + +@exposed_in("torch.func") +def linearize(func: Callable, *primals) -> Tuple[Any, Callable]: + """ + Returns the value of ``func`` at ``primals`` and linear approximation + at ``primals``. + + Args: + func (Callable): A Python function that takes one or more arguments. + primals (Tensors): Positional arguments to ``func`` that must all be + Tensors. These are the values at which the function is linearly approximated. + + Returns: + Returns a ``(output, jvp_fn)`` tuple containing the output of ``func`` + applied to ``primals`` and a function that computes the jvp of + ``func`` evaluated at ``primals``. + + linearize is useful if jvp is to be computed multiple times at ``primals``. However, + to achieve this, linearize saves intermediate computation and has higher memory requirements + than directly applying `jvp`. So, if all the ``tangents`` are known, it maybe more efficient + to compute vmap(jvp) instead of using linearize. + + .. note:: + linearize evaluates ``func`` twice. Please file an issue for an implementation + with a single evaluation. + + Example:: + >>> import torch + >>> from torch.func import linearize + >>> def fn(x): + ... return x.sin() + ... + >>> output, jvp_fn = linearize(fn, torch.zeros(3, 3)) + >>> jvp_fn(torch.ones(3, 3)) + tensor([[1., 1., 1.], + [1., 1., 1.], + [1., 1., 1.]]) + >>> + + """ + # Note: We evaluate `fn` twice. + # Once for returning the output and other while + # tracing the graph. + # If this becomes a bottle-neck, we should update + # make_fx such that it also returns the output. + + output = func(*primals) + _, output_spec = tree_flatten(output) + + flat_primals, primals_argspec = tree_flatten(primals) + + # tangents for tracing + flat_tangents = tuple(p.new_empty(()).expand_as(p) for p in flat_primals) + + # function to trace + def trace_fn(flat_tangents): + with fwAD.dual_level(): + flat_duals = tuple( + fwAD.make_dual(p, t) for p, t in zip(flat_primals, flat_tangents) + ) + duals = tree_unflatten(flat_duals, primals_argspec) + output = func(*duals) + tangents = tree_map_only( + torch.Tensor, lambda dual: safe_unpack_dual(dual, False)[1], output + ) + + return tangents + + jvp_graph = lazy_dynamo_disallow(make_fx)(trace_fn)(flat_tangents) + const_folded_jvp_graph = lazy_dynamo_disallow(const_fold.split_const_subgraphs)( + jvp_graph + ) + + # Hold only the meta-data regarding the primals. + flat_primals_shape = tuple(p.shape for p in flat_primals) + flat_primals_device = tuple(p.device for p in flat_primals) + flat_primals_dtype = tuple(p.dtype for p in flat_primals) + + def forward_ad_checks(flat_tangents): + for idx, t in enumerate(flat_tangents): + if t.shape != flat_primals_shape[idx]: + msg = ( + f"tangent:{idx} with shape {t.shape} in flattened " + f"pytree doesn't match the shape {flat_primals_shape[idx]} " + "of the corresponding primal." 
+ ) + raise RuntimeError(msg) + + if t.device != flat_primals_device[idx]: + msg = ( + f"tangent:{idx} with device {t.device} in flattened " + f"pytree doesn't match the device {flat_primals_device[idx]} " + "of the corresponding primal." + ) + raise RuntimeError(msg) + + if t.dtype != flat_primals_dtype[idx]: + msg = ( + f"tangent:{idx} with dtype {t.dtype} in flattened " + f"pytree doesn't match the dtype {flat_primals_dtype[idx]} " + "of the corresponding primal." + ) + raise RuntimeError(msg) + + # jvp_fn : callable to return + # It takes care of checking the argspec of tangents, + # calling the folded fx graph and unflattening fx graph output + def jvp_fn(*tangents): + flat_tangents, tangent_argspec = tree_flatten(tangents) + _linearize_treespec_compare(primals, tangents) + + forward_ad_checks(flat_tangents) + + flat_output = const_folded_jvp_graph(*flat_tangents) + # const folded graph can return flat output, + # so transform output. + return tree_unflatten(flat_output, output_spec) + + return output, jvp_fn diff --git a/lib/python3.10/site-packages/torch/_functorch/functional_call.py b/lib/python3.10/site-packages/torch/_functorch/functional_call.py new file mode 100644 index 0000000000000000000000000000000000000000..86c63be17fc9d0390b1e1172bb8e249c7b2a45b5 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_functorch/functional_call.py @@ -0,0 +1,253 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torch._functorch.utils import exposed_in + + +@exposed_in("torch.func") +def functional_call( + module: "torch.nn.Module", + parameter_and_buffer_dicts: Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]], + args: Union[Any, Tuple], + kwargs: Optional[Dict[str, Any]] = None, + *, + tie_weights: bool = True, + strict: bool = False, +): + r"""Performs a functional call on the module by replacing the module parameters + and buffers with the provided ones. + + .. note:: If the module has active parametrizations, passing a value in the + :attr:`parameter_and_buffer_dicts` argument with the name set to the regular parameter + name will completely disable the parametrization. + If you want to apply the parametrization function to the value passed + please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``. + + .. note:: If the module performs in-place operations on parameters/buffers, these will be reflected + in the ``parameter_and_buffer_dicts`` input. + + + Example:: + + >>> a = {'foo': torch.zeros(())} + >>> # xdoctest: +SKIP + >>> mod = Foo() # does self.foo = self.foo + 1 + >>> print(mod.foo) # tensor(0.) + >>> functional_call(mod, a, torch.ones(())) + >>> print(mod.foo) # tensor(0.) + >>> print(a['foo']) # tensor(1.) + + .. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the + tie_weights flag. + + Example:: + + >>> a = {'foo': torch.zeros(())} + >>> # xdoctest: +SKIP + >>> mod = Foo() # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied + >>> print(mod.foo) # tensor(1.) + >>> mod(torch.zeros(())) # tensor(2.) + >>> functional_call(mod, a, torch.zeros(())) # tensor(0.) 
since it will change self.foo_tied too + >>> functional_call(mod, a, torch.zeros(()), tie_weights=False) # tensor(1.)--self.foo_tied is not updated + >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())} + >>> functional_call(mod, new_a, torch.zeros()) # tensor(0.) + + An example of passing multiple dictionaries + + .. code-block:: python + + a = ({'weight': torch.ones(1, 1)}, {'buffer': torch.zeros(1)}) # two separate dictionaries + mod = nn.Bar(1, 1) # return self.weight @ x + self.buffer + print(mod.weight) # tensor(...) + print(mod.buffer) # tensor(...) + x = torch.randn((1, 1)) + print(x) + functional_call(mod, a, x) # same as x + print(mod.weight) # same as before functional_call + + + And here is an example of applying the grad transform over the parameters + of a model. + + .. code-block:: python + + import torch + import torch.nn as nn + from torch.func import functional_call, grad + + x = torch.randn(4, 3) + t = torch.randn(4, 3) + model = nn.Linear(3, 3) + + def compute_loss(params, x, t): + y = functional_call(model, params, x) + return nn.functional.mse_loss(y, t) + + grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t) + + .. note:: If the user does not need grad tracking outside of grad transforms, they can detach all of the + parameters for better performance and memory usage + + Example:: + + >>> detached_params = {k: v.detach() for k, v in model.named_parameters()} + >>> grad_weights = grad(compute_loss)(detached_params, x, t) + >>> grad_weights.grad_fn # None--it's not tracking gradients outside of grad + + This means that the user cannot call ``grad_weight.backward()``. However, if they don't need autograd tracking + outside of the transforms, this will result in less memory usage and faster speeds. + + Args: + module (torch.nn.Module): the module to call + parameters_and_buffer_dicts (Dict[str, Tensor] or tuple of Dict[str, Tensor]): the parameters that will be used in + the module call. If given a tuple of dictionaries, they must have distinct keys so that all dictionaries can + be used together + args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument. + kwargs (dict): keyword arguments to be passed to the module call + tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as + tied in the reparameterized version. Therefore, if True and different values are passed for the tied + parameters and buffers, it will error. If False, it will not respect the originally tied parameters and + buffers unless the values passed for both weights are the same. Default: True. + strict (bool, optional): If True, then the parameters and buffers passed in must match the parameters and + buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will + error. Default: False. + + Returns: + Any: the result of calling ``module``. 
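+
+ As a small illustrative sketch (the module and shapes here are made up),
+ ``strict=True`` errors when the passed dictionaries do not cover every
+ parameter and buffer of the module, while the default ``strict=False``
+ falls back to the module's own tensors for any missing names:
+
+ >>> # xdoctest: +SKIP
+ >>> mod = torch.nn.Linear(3, 3)
+ >>> missing_bias = {'weight': torch.randn(3, 3)}
+ >>> functional_call(mod, missing_bias, torch.randn(1, 3))  # ok, uses mod's own bias
+ >>> functional_call(mod, missing_bias, torch.randn(1, 3), strict=True)  # errors: 'bias' was not provided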
+ """ + if isinstance(parameter_and_buffer_dicts, dict): + parameters_and_buffers = parameter_and_buffer_dicts + elif isinstance(parameter_and_buffer_dicts, Sequence): + if not all(isinstance(d, dict) for d in parameter_and_buffer_dicts): + raise ValueError( + "Expected all elements of parameter_and_buffer_dicts to be dictionaries" + ) + all_keys = [k for d in parameter_and_buffer_dicts for k in d.keys()] + all_keys_counter: Dict[str, int] = {} + for k in all_keys: + v = all_keys_counter.get(k, 0) + all_keys_counter[k] = v + 1 + repeated_keys = [key for key, n in all_keys_counter.items() if n > 1] + if len(repeated_keys) > 0: + raise ValueError( + f"{repeated_keys} appeared in multiple dictionaries; behavior of functional call is ambiguous" + ) + parameters_and_buffers = { + k: v for d in parameter_and_buffer_dicts for k, v in d.items() + } + else: + raise ValueError( + f"Expected parameter_and_buffer_dicts to be a dict, or a list/tuple of dicts, " + f"but got {type(parameter_and_buffer_dicts)}" + ) + + return nn.utils.stateless._functional_call( + module, + parameters_and_buffers, + args, + kwargs, + tie_weights=tie_weights, + strict=strict, + ) + + +@exposed_in("torch.func") +def stack_module_state( + models: List[nn.Module], +) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """stack_module_state(models) -> params, buffers + + Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`. + + Given a list of ``M`` ``nn.Modules`` of the same class, returns two dictionaries + that stack all of their parameters and buffers together, indexed by name. + The stacked parameters are optimizable (i.e. they are new leaf nodes in the + autograd history that are unrelated to the original parameters and can be + passed directly to an optimizer). + + Here's an example of how to ensemble over a very simple model: + + .. code-block:: python + + num_models = 5 + batch_size = 64 + in_features, out_features = 3, 3 + models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] + data = torch.randn(batch_size, 3) + + def wrapper(params, buffers, data): + return torch.func.functional_call(models[0], (params, buffers), data) + + params, buffers = stack_module_state(models) + output = vmap(wrapper, (0, 0, None))(params, buffers, data) + + assert output.shape == (num_models, batch_size, out_features) + + When there's submodules, this follows state dict naming conventions + + .. code-block:: python + + import torch.nn as nn + class Foo(nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + hidden = 4 + self.l1 = nn.Linear(in_features, hidden) + self.l2 = nn.Linear(hidden, out_features) + + def forward(self, x): + return self.l2(self.l1(x)) + + num_models = 5 + in_features, out_features = 3, 3 + models = [Foo(in_features, out_features) for i in range(num_models)] + params, buffers = stack_module_state(models) + print(list(params.keys())) # "l1.weight", "l1.bias", "l2.weight", "l2.bias" + + .. warning:: + All of the modules being stacked together must be the same (except for + the values of their parameters/buffers). For example, they should be in the + same mode (training vs eval). + """ + if len(models) == 0: + raise RuntimeError("stack_module_state: Expected at least one model, got 0.") + if not (all(m.training for m in models) or all(not m.training for m in models)): + raise RuntimeError( + "stack_module_state: Expected all models to have the same training/eval mode." 
+ ) + model0_typ = type(models[0]) + if not all(type(m) == model0_typ for m in models): + raise RuntimeError( + "stack_module_state: Expected all models to be of the same class." + ) + all_params = [dict(model.named_parameters()) for model in models] + params = { + k: construct_stacked_leaf(tuple(params[k] for params in all_params), k) + for k in all_params[0] + } + all_buffers = [dict(model.named_buffers()) for model in models] + buffers = { + k: construct_stacked_leaf(tuple(buffers[k] for buffers in all_buffers), k) + for k in all_buffers[0] + } + + return params, buffers + + +def construct_stacked_leaf( + tensors: Union[Tuple[Tensor, ...], List[Tensor]], name: str +) -> Tensor: + all_requires_grad = all(t.requires_grad for t in tensors) + none_requires_grad = all(not t.requires_grad for t in tensors) + if not all_requires_grad and not none_requires_grad: + raise RuntimeError( + f"Expected {name} from each model to have the same .requires_grad" + ) + result = torch.stack(tensors) + if all_requires_grad: + result = result.detach().requires_grad_() + return result diff --git a/lib/python3.10/site-packages/torch/_functorch/fx_minifier.py b/lib/python3.10/site-packages/torch/_functorch/fx_minifier.py new file mode 100644 index 0000000000000000000000000000000000000000..d908a666d25350ce6af095779eb4f5efc6975962 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_functorch/fx_minifier.py @@ -0,0 +1,501 @@ +# mypy: ignore-errors + +import copy +import math +import os +import sys +from dataclasses import dataclass +from functools import partial, wraps +from typing import Callable, List + +import torch +import torch.fx as fx +from torch.hub import tqdm +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._content_store import ContentStoreWriter + +from .compile_utils import get_outputs, get_placeholders + + +is_tuple = object() + + +@dataclass +class LoadTensorMeta: + size: List[int] + stride: List[int] + dtype: torch.dtype + device: torch.device + + +class ConcreteProp(torch.fx.Interpreter): + def __init__(self, mod, *, writer=None, skip_offload=False): + super().__init__(mod) + self.writer = writer + self.skip_offload = skip_offload + self.seen_storages = set() + + def run_node(self, n): + self.pbar.update(1) + r = super().run_node(n) + name = n.name + + if isinstance(r, torch.Tensor): + if self.writer is None: + n.meta["concrete_value"] = r + else: + if StorageWeakRef(r.untyped_storage()) in self.seen_storages: + # Refuse to offload tensors which alias other live + # tensors, because this will violate operator contracts + n.meta["concrete_value"] = None + else: + if not self.skip_offload: + self.writer.write_tensor(os.path.join("eager", name), r) + n.meta["concrete_value"] = LoadTensorMeta( + r.size(), r.stride(), r.dtype, r.device + ) + self.seen_storages.add(StorageWeakRef(r.untyped_storage())) + else: + n.meta["concrete_value"] = is_tuple + + return r + + def propagate(self, *args): + with tqdm( + desc="Saving intermediates for delta debugging", + total=len(self.module.graph.nodes), + disable=self.writer is None, + ) as pbar: + self.pbar = pbar + r = super().run(*args) + if not self.skip_offload: + pbar.set_description( + "Saved! 
To skip next time, run with --skip-saving-eager-intermediates" + ) + return r + + +def is_load_tensor_node(node): + return ( + node.op == "call_function" + and node.target is torch.ops.debugprims.load_tensor.default + ) + + +# inplace modifies node/inps +def _convert_node_to_placeholder(graph, node, inps): + if node.op == "output" or node.op == "placeholder": + return False + + if is_load_tensor_node(node): + return False + + concrete_val = node.meta.get("concrete_value", None) + + if isinstance(concrete_val, torch.Tensor): + node.op = "placeholder" + node.target = node.name + node.args = () + node.kwargs = {} + + inps.append(concrete_val) + return True + + elif concrete_val is None: + return False + + elif concrete_val is is_tuple: + r = False + for tuple_user in list(node.users): + r = _convert_node_to_placeholder(graph, tuple_user, inps) or r + # NB: We must not erase the node at this point, because + # we are iterating over the nodes and this would change + # the iteration order + # graph.erase_node(node) + return r + + elif isinstance(concrete_val, LoadTensorMeta): + node.op = "call_function" + node.target = torch.ops.debugprims.load_tensor.default + node.args = ( + os.path.join("eager", node.name), + concrete_val.size, + concrete_val.stride, + ) + node.kwargs = { + "device": concrete_val.device, + "dtype": concrete_val.dtype, + } + return True + + return False + + +def create_minified_hlo_graph(minified_fx_graph, inputs): + """ + Takes minified FX graph as primary input, and ports it to HLO via StableHLO + Provides minified HLO graph as output, and archive them to local directory + """ + hlo_dir = f"{os.getcwd()}/hlo_files" + os.makedirs(hlo_dir, exists_ok=True) + + from torch_xla.stablehlo import save_torch_model_as_stablehlo + + save_torch_model_as_stablehlo(minified_fx_graph, inputs, hlo_dir) + + +def dump_state(fx_g, inps): + print( + f""" +# Working Repro with {len(fx_g.graph.nodes)} nodes +inps = {[(i.shape, i.dtype, i.device.type) for i in inps]} +inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device=device) for (shape, dtype, device) in inps] +{fx_g.code} +""" + ) + + +def is_power_of_two(n): + if n == 0: + return False + return (n & (n - 1)) == 0 + + +@dataclass +class ReproState: + graph: fx.Graph + inps: List[torch.Tensor] + + def __post_init__(self): + ph_nodes = get_placeholders(self.graph) + assert len(ph_nodes) == len(self.inps) + + +def minifier( + fail_f: fx.GraphModule, + inps, + module_fails, + dump_state: Callable = dump_state, + *, + save_dir=None, + offload_to_disk=False, + skip_offload=False, + skip_sanity=False, + max_granularity=None, +): + """ + Minimizes a FX graph with given inputs, such that the resulting FX graph still returns True for module_fails. + + Does 2 main strategies: + 1. Truncates suffix: Removes some suffix from the graph and sets a new output. + 2. Delta Debugging: Tries replacing half of the graph with inputs. If fails, + tries replacing quarter of the graph, etc. + + >>> # xdoctest: +SKIP(failing) + >>> failing_function = fx.symbolic_trace(f) + >>> minimize(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps)) + + note: module_fails returns True if it fails. 
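+
+    A sketch of a typical ``module_fails`` callback (illustrative only; treating any
+    raised ``Exception`` as a failure is an assumption about what "failing" means here):
+
+    >>> # xdoctest: +SKIP(illustrative)
+    >>> def module_fails(fx_g, inps):
+    ...     try:
+    ...         fx_g(*inps)
+    ...         return False
+    ...     except Exception:
+    ...         return True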
+ """ + assert isinstance(inps, (tuple, list)) + + failing_graph = fail_f.graph + cur_size = len(failing_graph.nodes) + + if max_granularity is not None and not is_power_of_two(max_granularity): + raise RuntimeError(f"max_granularity {max_granularity} not power of two") + + num_queries = 0 + + def deepcopy_fx_graph(fx_graph): + return fx.GraphModule(fail_f, copy.deepcopy(fx_graph)).graph + + def graph_fails(graph, inps): + nonlocal num_queries + graph = copy.deepcopy(graph) + num_queries += 1 + mod = fx.GraphModule(fail_f, graph) + mod.graph.lint() + return module_fails(mod, inps) + + writer = None + if offload_to_disk: + writer = ContentStoreWriter(save_dir) + + ConcreteProp(fail_f, writer=writer, skip_offload=skip_offload).propagate(*inps) + if not skip_sanity and not graph_fails(failing_graph, inps): + raise RuntimeError("Input graph did not fail the tester") + print(f"Started off with {cur_size} nodes", file=sys.stderr) + + def _register_strategy(strategy: Callable, name: str): + @wraps(strategy) + def new_func(old_state: ReproState, granularity=1): + print(file=sys.stderr) + print( + f"Strategy: {name} (G: {granularity}) " + f"({len(old_state.graph.nodes)} nodes, {len(old_state.inps)} inputs)", + file=sys.stderr, + ) + new_state = strategy( + deepcopy_fx_graph(old_state.graph), list(old_state.inps), granularity + ) + if new_state is not None: + new_nodes = len(new_state.graph.nodes) + old_nodes = len(old_state.graph.nodes) + new_inps = len(new_state.inps) + old_inps = len(old_state.inps) + new_outs = len(get_outputs(new_state.graph)) + old_outs = len(get_outputs(old_state.graph)) + progress_made = False + if new_nodes < old_nodes: + progress_made = True + print( + f"SUCCESS: Went from {old_nodes} to {new_nodes} nodes", + file=sys.stderr, + ) + if new_inps > old_inps: + progress_made = True + print( + f"SUCCESS: Went from {old_inps} to {new_inps} inputs", + file=sys.stderr, + ) + if new_outs < old_outs: + progress_made = True + print( + f"SUCCESS: Went from {old_outs} to {new_outs} outputs", + file=sys.stderr, + ) + + if not progress_made: + raise RuntimeError("Success raised but no progress made?") + + if not graph_fails(new_state.graph, new_state.inps): + print( + "WARNING: Something went wrong, not applying this minification", + file=sys.stderr, + ) + return None + return new_state + else: + print(f"FAIL: {name}", file=sys.stderr) + return None + + return new_func + + def register_strategy(name: str): + return partial(_register_strategy, name=name) + + @register_strategy("Truncate suffix") + def remove_suffix(cur_graph, cur_inps, granularity): + tested = set() + new_graph = fx.Graph() + env = {} + for idx, node in enumerate(cur_graph.nodes): + new_node = new_graph.node_copy(node, lambda x: env[x]) + if node.op not in ["placeholder", "output"]: + # If idx is divisible by (granularity * 2), it would have been checked already. 
+ if ( + idx % granularity == 0 + and (idx % (granularity * 2) != 0) + and idx not in tested + ): + output_node = new_graph.output((new_node,)) + if len(new_graph.nodes) < len(cur_graph.nodes) and graph_fails( + new_graph, cur_inps + ): + return ReproState(new_graph, cur_inps) + else: + tested.add(idx) + new_graph.erase_node(output_node) + env[node] = new_node + return None + + @register_strategy("Remove outputs") + def remove_outputs(cur_graph, cur_inps, granularity): + granularity = max(1, granularity // 2) + for idx, node in enumerate(cur_graph.nodes): + node.idx = idx + if node.op == "output": + output = node + break + + if isinstance(output.args[0], fx.Node): + return None + + output_args = sorted( + output.args[0], key=lambda x: x.idx if isinstance(x, fx.Node) else int(1e9) + ) + if len(output_args) == 1: + return None + + for idx in range(0, len(output_args), granularity): + output.args = (output_args[:idx] + output_args[idx + granularity :],) + if graph_fails(cur_graph, cur_inps): + return ReproState(cur_graph, cur_inps) + return None + + def remove_unused_inputs_unchecked(cur_state: ReproState): + cur_graph = cur_state.graph + cur_inps = cur_state.inps + ph_nodes = get_placeholders(cur_graph) + assert len(ph_nodes) == len(cur_inps) + + new_inps = [] + for idx in range(len(ph_nodes)): + if len(ph_nodes[idx].users) == 0: + cur_graph.erase_node(ph_nodes[idx]) + else: + new_inps.append(cur_inps[idx]) + if len(new_inps) < len(cur_inps): + return ReproState(cur_graph, new_inps) + return None + + def remove_unused_inputs_checked(cur_state: ReproState): + new_state = remove_unused_inputs_unchecked(cur_state) + if new_state is not None and graph_fails(new_state.graph, new_state.inps): + return new_state + return None + + def _remove_unused_wrapper(cur_graph, cur_inps, granularity): + return remove_unused_inputs_checked(ReproState(cur_graph, cur_inps)) + + remove_unused_inputs = register_strategy("Remove unused inputs")( + _remove_unused_wrapper + ) + + @register_strategy("Eliminate dead code") + def eliminate_dead_code(cur_graph, cur_inps, granularity): + if cur_graph.eliminate_dead_code() and graph_fails(cur_graph, cur_inps): + return ReproState(cur_graph, cur_inps) + return None + + def _consolidate_placeholders(cur_graph, inps): + new_graph = fx.Graph() + env = {} + seen_non_placeholder = False + + # Move all placeholders to the front; also, if any load_tensor + # is at the front, convert it into an input (because it can be live + # all the time) + for node in cur_graph.nodes: + if node.op == "placeholder": + new_node = new_graph.node_copy(node, lambda x: env[x]) + env[node] = new_node + elif not seen_non_placeholder and is_load_tensor_node(node): + new_node = new_graph.placeholder(node.name) + env[node] = new_node + inps.append( + torch.ops.debugprims.load_tensor.default(*node.args, **node.kwargs) + ) + else: + seen_non_placeholder = True + + # Move everyone else + for node in cur_graph.nodes: + if node not in env: + new_node = new_graph.node_copy(node, lambda x: env[x]) + env[node] = new_node + return new_graph + + @register_strategy("Delta Debugging") + def delta_debugging(cur_graph: fx.Graph, cur_inps, granularity): + num_nodes = len(cur_graph.nodes) + for start_range in range(0, num_nodes, granularity): + is_removing = False + new_graph = deepcopy_fx_graph(cur_graph) + new_inps = cur_inps[:] + end_range = min(num_nodes, start_range + granularity) + for idx in range(start_range, end_range): + new_node = list(new_graph.nodes)[idx] + if _convert_node_to_placeholder(new_graph, 
new_node, new_inps): + is_removing = True + if not is_removing: + continue + new_graph.eliminate_dead_code() + new_graph = _consolidate_placeholders(new_graph, new_inps) + new_state = remove_unused_inputs_unchecked(ReproState(new_graph, new_inps)) + if new_state is None: + new_state = ReproState(new_graph, new_inps) + if graph_fails(new_state.graph, new_state.inps): + return ReproState(new_state.graph, new_state.inps) + + return None + + @register_strategy("Consolidate Inputs") + def consolidate_inputs(cur_graph, cur_inps, granularity): + old_len = len(cur_inps) + cur_graph = _consolidate_placeholders(cur_graph, cur_inps) + if len(cur_inps) > old_len and graph_fails(cur_graph, cur_inps): + return ReproState(cur_graph, cur_inps) + return None + + failing_state = ReproState(failing_graph, inps) + + def try_granularity(failing_state, granularity, use_non_granular): + print(f"Trying granularity {granularity}", file=sys.stderr) + + strategies = [] + num_nodes = len(failing_state.graph.nodes) + num_outputs = len(get_outputs(failing_state.graph)) + if num_outputs > num_nodes // 2: + strategies += [remove_outputs] + + if use_non_granular: + strategies += [ + eliminate_dead_code, + remove_unused_inputs, + consolidate_inputs, + ] + + strategies += [remove_suffix, delta_debugging] + + for strategy in strategies: + new_state = strategy(failing_state, granularity) + if new_state is not None: + return new_state + return None + + while True: + dump_state(fx.GraphModule(fail_f, failing_state.graph), failing_state.inps) + granularity = int(2 ** (math.floor(math.log2(len(failing_state.graph.nodes))))) + if max_granularity is not None: + granularity = min(max_granularity, granularity) + new_state = try_granularity(failing_state, granularity, use_non_granular=True) + if new_state is not None: + failing_state = new_state + continue + + granularity //= 2 + has_progress = False + while granularity >= 1: + new_state = try_granularity( + failing_state, granularity, use_non_granular=False + ) + if new_state is not None: + failing_state = new_state + has_progress = True + break + granularity //= 2 + if has_progress: + continue + + new_state = remove_outputs(failing_state, 1) + if new_state is not None: + failing_state = new_state + continue + + break + + if not graph_fails(failing_state.graph, failing_state.inps): + raise RuntimeError("Uh oh, something went wrong :( Final graph is not failing") + + print(f"Made {num_queries} queries", file=sys.stderr) + failing_fx = fx.GraphModule(fail_f, failing_state.graph) + + # If XLA debugging environment is enabled, create minified HLO graph as well + if "XLA_HLO_DEBUG" in os.environ: + create_minified_hlo_graph(failing_fx, failing_state.inps) + + dump_state(failing_fx, failing_state.inps) + print("Wrote minimal repro out to repro.py", file=sys.stderr) + return failing_fx, failing_state.inps diff --git a/lib/python3.10/site-packages/torch/_functorch/make_functional.py b/lib/python3.10/site-packages/torch/_functorch/make_functional.py new file mode 100644 index 0000000000000000000000000000000000000000..98a064c46ae1f1d1e695ef32e017c6d78322611a --- /dev/null +++ b/lib/python3.10/site-packages/torch/_functorch/make_functional.py @@ -0,0 +1,617 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import copy +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + NoReturn, + Sequence, + Tuple, + Type, + Union, +) + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn.utils._named_member_accessor import NamedMemberAccessor + + +# Utilities to make nn.Module "functional" +# In particular the goal is to be able to provide a function that takes as input +# the parameters and evaluate the nn.Module using fixed inputs. + + +def raise_parameter_tying_error() -> NoReturn: + raise RuntimeError( + "make_functional(module): we don't yet support models that " + "do parameter tying (also sometimes known as weight sharing). " + "Please try to rewrite your model by replacing all instances of the " + "tied parameter with another and/or comment your support in " + "https://github.com/pytorch/functorch/issues/446" + ) + + +def create_names_map( + named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]], + tied_named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]], +) -> Dict[str, List[str]]: + """ + named_params is a dictionary of tensors: {'A': A, 'B': B} + tied_named_params is another dictionary of tensors {'A': A, 'B': B, 'B_tied': B} + with potentially tied (or 'duplicated') tensors + + This function creates a mapping from the names in named_params to the + names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}. + """ + named_params = dict(named_params) + tied_named_params = dict(tied_named_params) + + tensors_dict_keys = set(named_params.keys()) + tied_tensors_dict_keys = set(tied_named_params.keys()) + assert tensors_dict_keys.issubset(tied_tensors_dict_keys) + + tensor_to_mapping: Dict[Tensor, Tuple[str, List[str]]] = {} + for key, tensor in named_params.items(): + tensor_to_mapping[tensor] = (key, []) + for key, tensor in tied_named_params.items(): + assert tensor in tensor_to_mapping + tensor_to_mapping[tensor][1].append(key) + return dict(tensor_to_mapping.values()) + + +def _extract_members( + mod: nn.Module, + named_members: Callable[..., Iterable[Tuple[str, Tensor]]], + subclass: Callable[[Tensor], Tensor], +) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]: + all_named_members = tuple(named_members(remove_duplicate=False)) + unique_named_members = tuple(named_members(remove_duplicate=True)) + names_map = create_names_map(unique_named_members, all_named_members) + + # Remove all the members in the model + memo = {} + accessor = NamedMemberAccessor(mod) + for name, p in all_named_members: + if p not in memo: + memo[p] = subclass(torch.empty_like(p, device="meta")) + replacement = memo[p] + accessor.set_tensor(name, replacement) + + if len(unique_named_members) == 0: + names, params = (), () + else: + names, params = zip(*unique_named_members) # type: ignore[assignment] + return params, names, names_map + + +def extract_weights( + mod: nn.Module, +) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]: + """ + This function removes all the Parameters from the model and + return them as a tuple as well as their original attribute names. + The weights must be re-loaded with `load_weights` before the model + can be used again. + Note that this function modifies the model in place and after this + call, mod.parameters() will be empty. 
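+
+    A minimal usage sketch (illustrative)::
+
+        >>> mod = nn.Linear(3, 3)
+        >>> params, names, names_map = extract_weights(mod)
+        >>> # ... use `params` functionally, e.g. inside a transform ...
+        >>> load_weights(mod, names, params)  # re-load before calling mod again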
+ """ + return _extract_members(mod, mod.named_parameters, nn.Parameter) + + +def extract_buffers( + mod: nn.Module, +) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]: + return _extract_members(mod, mod.named_buffers, lambda x: x) + + +def load_weights( + mod: nn.Module, + names: Sequence[str], + params: Sequence[Tensor], + as_params: bool = False, +) -> None: + """ + Reload a set of weights so that `mod` can be used again to perform a forward pass. + Note that the `params` are regular Tensors (that can have history) and so are left + as Tensors. This means that mod.parameters() will still be empty after this call. + """ + accessor = NamedMemberAccessor(mod) + if as_params: + params = [nn.Parameter(p) for p in params] + accessor.set_tensors(names, params) + + +def _swap_state( + mod: nn.Module, names_map: Dict[str, List[str]], elems: Iterable[Tensor] +) -> List[Tensor]: + result: List[Tensor] = [] + accessor = NamedMemberAccessor(mod) + for (_, attr_names), elem in zip(names_map.items(), elems): + for i, attr_name in enumerate(attr_names): + if i == 0: + result.append(accessor.swap_tensor(attr_name, elem)) + else: + accessor.set_tensor(attr_name, elem) + return result + + +def load_buffers( + mod: nn.Module, + names: Sequence[str], + buffers: Sequence[Tensor], + as_params: bool = False, +) -> None: + accessor = NamedMemberAccessor(mod) + accessor.set_tensors(names, buffers) + + +def load_state( + model: nn.Module, + weights: Sequence[Tensor], + weight_names: Sequence[str], + buffers: Sequence[Tensor] = (), + buffer_names: Sequence[str] = (), +) -> nn.Module: + """load_state(model, weights, weight_names, buffers=(), buffer_names=()) -> model + + load_state takes `weights` and `buffers` and assigns them to the model. + This is the inverse operation of `make_functional_deprecated_v1`. + """ + assert len(weight_names) == len(weights) + load_weights(model, weight_names, weights) + if len(buffers) > 0: + assert len(buffer_names) == len(buffers) + load_buffers(model, buffer_names, buffers) + return model + + +def make_functional_deprecated_v1(model: nn.Module): + """make_functional_deprecated_v1(model) -> weights, func, weight_names + + Given an nn.Module, make_functional_deprecated_v1 extracts the state (weights) + and returns a functional version of the model, `func`. This makes + it so that it is possible use transforms over the parameters of + `model`. + + `func` can be invoked as follows: + ``` + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + weights, func, _ = make_functional_deprecated_v1(model) + func(weights, (x,)) + ``` + + And here is an example of applying the grad transform: + ``` + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + weights, _, func = make_functional_deprecated_v1(model) + grad_weights = grad(func)(weights, (x,)) + ``` + + To put the state back into a model, use `load_state`. + """ + buffers = list(model.buffers()) + if len(buffers) > 0: + raise RuntimeError( + "make_functional_deprecated_v1(model): `model` has buffers. Please use " + "make_functional_with_buffers_deprecated_v1(model) instead." 
+ ) + weights, descriptors, _ = extract_weights(model) + + def fun(weights, data): + mutable_model = copy.deepcopy(model) + load_weights(mutable_model, descriptors, weights) + return mutable_model(*data) + + return weights, fun, descriptors + + +def make_functional_with_buffers_deprecated_v1(model: nn.Module): + """make_functional_with_buffers_deprecated_v1(model) -> weights, buffers, func, weight_names, buffer_names + + Given an nn.Module, make_functional_with_buffers_deprecated_v1 extracts the state (weights and buffers) + and returns a functional version of the model, `func`. + + `func` can be invoked as follows: + ``` + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model) + func(weights, buffers, (x,)) + ``` + + And here is an example of applying the grad transform: + ``` + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model) + func(weights, buffers, (x,)) + grad_weights = grad(func)(weights, buffers, (x,)) + ``` + + To put the state back into a model, use `load_state`. + """ + weights, weight_descriptors, _ = extract_weights(model) + buffers, buf_descriptors, _ = extract_buffers(model) + + def fun(weights, buffers, data): + mutable_model = copy.deepcopy(model) + load_weights(mutable_model, weight_descriptors, weights) + load_buffers(mutable_model, buf_descriptors, buffers) + return mutable_model(*data) + + return weights, buffers, fun, weight_descriptors, buf_descriptors + + +class FunctionalModuleWithBuffers(nn.Module): + """ + This is the callable object returned by :func:`make_functional_with_buffers`. + """ + + def __init__( + self, + stateless_model: nn.Module, + param_names: Tuple[str, ...], + buffer_names: Tuple[str, ...], + param_names_map: Dict[str, List[str]], + buffer_names_map: Dict[str, List[str]], + ) -> None: + super().__init__() + self.stateless_model = stateless_model + self.param_names = param_names + self.buffer_names = buffer_names + + self.all_names_map = dict(param_names_map) + self.all_names_map.update(buffer_names_map) + + @staticmethod + def _create_from( + model: nn.Module, disable_autograd_tracking: bool = False + ) -> Tuple["FunctionalModuleWithBuffers", Tuple[Tensor, ...], Tuple[Tensor, ...]]: + # TODO: We don't need to copy the model to create a stateless copy + model_copy = copy.deepcopy(model) + params, param_names, param_names_map = extract_weights(model_copy) + buffers, buffer_names, buffer_names_map = extract_buffers(model_copy) + if disable_autograd_tracking: + for param in params: + param.requires_grad_(False) + return ( + FunctionalModuleWithBuffers( + model_copy, param_names, buffer_names, param_names_map, buffer_names_map + ), + params, + buffers, + ) + + def forward( + self, params: Iterable[Tensor], buffers: Iterable[Tensor], *args, **kwargs + ) -> Any: + # Temporarily load the state back onto self.stateless_model + old_state = _swap_state( + self.stateless_model, + self.all_names_map, + tuple(params) + tuple(buffers), + ) + try: + return self.stateless_model(*args, **kwargs) + finally: + # Remove the loaded state on self.stateless_model + _swap_state(self.stateless_model, self.all_names_map, old_state) + + +class FunctionalModule(nn.Module): + """ + This is the callable object returned by :func:`make_functional`. 
+ """ + + def __init__( + self, + stateless_model: nn.Module, + param_names: Tuple[str, ...], + names_map: Dict[str, List[str]], + ) -> None: + super().__init__() + self.stateless_model = stateless_model + self.param_names = param_names + self.names_map = names_map + + @staticmethod + def _create_from( + model: nn.Module, disable_autograd_tracking: bool = False + ) -> Tuple["FunctionalModule", Tuple[Tensor, ...]]: + # TODO: We don't need to copy the model to create a stateless copy + model_copy = copy.deepcopy(model) + params, param_names, names_map = extract_weights(model_copy) + if disable_autograd_tracking: + for param in params: + param.requires_grad_(False) + return FunctionalModule(model_copy, param_names, names_map), params + + def forward(self, params: Iterable[Tensor], *args, **kwargs) -> Any: + # Temporarily load the state back onto self.stateless_model + old_state = _swap_state(self.stateless_model, self.names_map, params) + try: + return self.stateless_model(*args, **kwargs) + finally: + # Remove the loaded state on self.stateless_model + _swap_state(self.stateless_model, self.names_map, old_state) + + +def make_functional( + model: nn.Module, disable_autograd_tracking: bool = False +) -> Tuple[FunctionalModule, Tuple[Tensor, ...]]: + """make_functional(model, disable_autograd_tracking=False) -> func, params + + Given a ``torch.nn.Module``, :func:`make_functional` extracts the state + (params) and returns a functional version of the model, ``func``. This + makes it so that it is possible use transforms over the parameters of + ``model``. + + ``func`` can be invoked as follows: + + .. code-block:: python + + import torch + import torch.nn as nn + from functorch import make_functional + + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + func, params = make_functional(model) + func(params, x) + + And here is an example of applying the grad transform over the parameters + of a model. + + .. code-block:: python + + import torch + import torch.nn as nn + from functorch import make_functional, grad + + x = torch.randn(4, 3) + t = torch.randn(4, 3) + model = nn.Linear(3, 3) + func, params = make_functional(model) + + def compute_loss(params, x, t): + y = func(params, x) + return nn.functional.mse_loss(y, t) + + grad_weights = grad(compute_loss)(params, x, t) + + If the model has any buffers, please use :func:`make_functional_with_buffers` instead. + + Args: + model (torch.nn.Module): Input model. + disable_autograd_tracking (bool): Flag to disable gradients tracking for output parameters. + The returned params are unrelated to the set of params from the original model. If False (default), + the params will have ``requires_grad=True`` on them (aka they will be trackable with regular + PyTorch autograd), matching the requires_grad-ness of the params from the original model. + Otherwise, the returned params will have ``requires_grad=False``. Default, False. + If you plan on using regular PyTorch autograd (e.g., if you want to call ``.backward()`` or + ``torch.autograd.grad()``, then set ``disable_autograd_tracking=False``. + Otherwise, if you're only planning on using functorch's gradient transforms, + then please set ``disable_autograd_tracking=True`` to avoid unnecessarily tracking + history with PyTorch autograd. + + """ + buffers = list(model.buffers()) + if len(buffers) > 0: + raise RuntimeError( + "make_functional(model): `model` has buffers. Please use " + "make_functional_with_buffers(model) instead." 
+ ) + return FunctionalModule._create_from( + model, disable_autograd_tracking=disable_autograd_tracking + ) + + +def make_functional_with_buffers( + model: nn.Module, disable_autograd_tracking: bool = False +) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]: + """make_functional_with_buffers(model, disable_autograd_tracking=False) -> func, params, buffers + + Given a ``torch.nn.Module``, make_functional_with_buffers extracts the + state (params and buffers) and returns a functional version of the model + ``func`` that can be invoked like a function. + + ``func`` can be invoked as follows: + + .. code-block:: python + + import torch + import torch.nn as nn + from functorch import make_functional_with_buffers + + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + func, params, buffers = make_functional_with_buffers(model) + func(params, buffers, x) + + And here is an example of applying the grad transform over the parameters + of a model: + + .. code-block:: python + + import torch + import torch.nn as nn + from functorch import make_functional_with_buffers, grad + + x = torch.randn(4, 3) + t = torch.randn(4, 3) + model = nn.Linear(3, 3) + func, params, buffers = make_functional_with_buffers(model) + + def compute_loss(params, buffers, x, t): + y = func(params, buffers, x) + return nn.functional.mse_loss(y, t) + + grad_weights = grad(compute_loss)(params, buffers, x, t) + + Args: + model (torch.nn.Module): Input model. + disable_autograd_tracking (bool): Flag to disable gradients tracking for output parameters. + The returned params are unrelated to the set of params from the original model. If False (default), + the params will have ``requires_grad=True`` on them (aka they will be trackable with regular + PyTorch autograd), matching the requires_grad-ness of the params from the original model. + Otherwise, the returned params will have ``requires_grad=False``. Default, False. + If you plan on using regular PyTorch autograd (e.g., if you want to call ``.backward()`` or + ``torch.autograd.grad()``, then set ``disable_autograd_tracking=False``. + Otherwise, if you're only planning on using functorch's gradient transforms, + then please set ``disable_autograd_tracking=True`` to avoid unnecessarily tracking + history with PyTorch autograd. + + """ + return FunctionalModuleWithBuffers._create_from( + model, disable_autograd_tracking=disable_autograd_tracking + ) + + +def transpose_stack( + tuple_of_tuple_of_tensors: Tuple[Tuple[Tensor, ...], ...] +) -> Tuple[Tensor, ...]: + tuple_of_tuple_of_tensors = tuple(zip(*tuple_of_tuple_of_tensors)) + results = tuple( + torch.stack(shards).detach() for shards in tuple_of_tuple_of_tensors + ) + return results + + +def combine_state_for_ensemble( + models: Sequence[nn.Module], +) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]: + """combine_state_for_ensemble(models) -> func, params, buffers + + Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`. + + Given a list of ``M`` ``nn.Modules`` of the same class, stacks all of their + parameters and buffers together to make ``params`` and ``buffers``. + Each parameter and buffer in the result will have an additional dimension + of size ``M``. + + :func:`combine_state_for_ensemble` also returns ``func``, a functional + version of one of the models in :attr:`models`. 
One cannot directly run + ``func(params, buffers, *args, **kwargs)`` directly, you probably want to + use ``vmap(func, ...)(params, buffers, *args, **kwargs)`` + + Here's an example of how to ensemble over a very simple model: + + .. code-block:: python + + num_models = 5 + batch_size = 64 + in_features, out_features = 3, 3 + models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] + data = torch.randn(batch_size, 3) + + fmodel, params, buffers = combine_state_for_ensemble(models) + output = vmap(fmodel, (0, 0, None))(params, buffers, data) + + assert output.shape == (num_models, batch_size, out_features) + + .. warning:: + All of the modules being stacked together must be the same (except for + the values of their parameters/buffers). For example, they should be in the + same mode (training vs eval). + + This API is subject to change -- we're investigating better ways to + create ensembles and would love your feedback how to improve this. + """ + if len(models) == 0: + raise RuntimeError( + "combine_state_for_ensemble: Expected at least one model, got 0." + ) + if not (all(m.training for m in models) or all(not m.training for m in models)): + raise RuntimeError( + "combine_state_for_ensemble: Expected all models to " + "have the same training/eval mode." + ) + model0_typ = type(models[0]) + if not all(type(m) == model0_typ for m in models): + raise RuntimeError( + "combine_state_for_ensemble: Expected all models to be of the same class." + ) + funcs, params, buffers = zip( + *[make_functional_with_buffers(model) for model in models] + ) + params = transpose_stack(params) + buffers = transpose_stack(buffers) + return funcs[0], params, buffers + + +def functional_init( + model_class: Type[nn.Module], + ensemble_shape: Union[Tuple[()], Tuple[int]] = (), + device: torch.types.Device = "cpu", +): + def wrapped(*args, **kwargs): + if len(ensemble_shape) >= 2: + raise ValueError("NYI: ensemble_shape with more than 1 element") + if len(ensemble_shape) == 0: + model = model_class(*args, **kwargs).to(device) + return make_functional_deprecated_v1(model) + num_models = ensemble_shape[0] # type: ignore[misc] + if num_models <= 0: + raise ValueError(f"num_models {num_models} should be > 0") + # NB: Not very efficient, more of a POC + models = tuple( + model_class(*args, **kwargs).to(device) for _ in range(num_models) + ) + _, fn, names = make_functional_deprecated_v1(model_class(*args, **kwargs)) + weights = tuple(make_functional_deprecated_v1(model)[0] for model in models) + weights = tuple(zip(*weights)) + weights = tuple(torch.stack(shards).detach() for shards in weights) + return weights, fn, names + + return wrapped + + +def functional_init_with_buffers( + model_class: Type[nn.Module], + ensemble_shape: Union[Tuple[()], Tuple[int]] = (), + device: torch.types.Device = "cpu", +): + def wrapped(*args, **kwargs): + if len(ensemble_shape) >= 2: + raise ValueError("NYI: ensemble_shape with more than 1 element") + if len(ensemble_shape) == 0: + model = model_class(*args, **kwargs).to(device) + return make_functional_deprecated_v1(model) + num_models = ensemble_shape[0] # type: ignore[misc] + if num_models <= 0: + raise ValueError(f"num_models {num_models} should be > 0") + # NB: Not very efficient, more of a POC + models = tuple( + model_class(*args, **kwargs).to(device) for _ in range(num_models) + ) + ( + _, + _, + fn, + weight_names, + buffer_names, + ) = make_functional_with_buffers_deprecated_v1(model_class(*args, **kwargs)) + weights, buffers = zip( + *tuple( + 
make_functional_with_buffers_deprecated_v1(model)[:2] + for model in models + ) + ) + weights = tuple(zip(*weights)) + weights = tuple(torch.stack(shards).detach() for shards in weights) + buffers = tuple(zip(*buffers)) + buffers = tuple(torch.stack(shards).detach() for shards in buffers) + return weights, buffers, fn, weight_names, buffer_names + + return wrapped diff --git a/lib/python3.10/site-packages/torch/_functorch/partitioners.py b/lib/python3.10/site-packages/torch/_functorch/partitioners.py new file mode 100644 index 0000000000000000000000000000000000000000..b0f609dc75615cffbebf1e8191f9b85917e1b3ce --- /dev/null +++ b/lib/python3.10/site-packages/torch/_functorch/partitioners.py @@ -0,0 +1,1933 @@ +# mypy: allow-untyped-defs +import copy +import functools +import heapq +import itertools +import logging +import math +import operator +import os +from collections import defaultdict +from dataclasses import dataclass, replace +from typing import Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union + +import torch +import torch._inductor.inductor_prims +import torch.fx as fx +import torch.utils._pytree as pytree +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types +from torch.fx.experimental.sym_node import magic_methods, method_to_operator +from torch.fx.experimental.symbolic_shapes import ( + find_symbol_binding_fx_nodes, + free_symbols, + hint_int, + is_symbol_binding_fx_node, +) +from torch.fx.passes import graph_drawer +from torch.utils.checkpoint import CheckpointPolicy + +from . import config +from ._aot_autograd.logging_utils import get_aot_graph_name +from ._aot_autograd.utils import is_with_effects +from .compile_utils import fx_graph_cse, get_aten_target + + +if TYPE_CHECKING: + import sympy + + +AOT_PARTITIONER_DEBUG = config.debug_partitioner +log = logging.getLogger(__name__) + +aten = torch.ops.aten +prims = torch.ops.prims + + +@dataclass +class OpTypes: + """Class for keeping track of different operator categories""" + + fusible_ops: Set[Callable] + compute_intensive_ops: Set[Callable] + random_ops: Set[Callable] + view_ops: Set[Callable] + recomputable_ops: Set[Callable] + + def is_fusible(self, node: fx.Node): + return get_aten_target(node) in self.fusible_ops + + def is_compute_intensive(self, node: fx.Node): + return get_aten_target(node) in self.compute_intensive_ops + + def is_random(self, node: fx.Node): + return get_aten_target(node) in self.random_ops + + def is_view(self, node: fx.Node): + return get_aten_target(node) in self.view_ops + + def is_recomputable(self, node: fx.Node): + return get_aten_target(node) in self.recomputable_ops + + +@dataclass +class NodeInfo: + # Be careful about iterating over these explicitly, as their order may not + # be deterministic + inputs: List[fx.Node] + _required_fw_nodes: Set[fx.Node] + required_bw_nodes: Set[fx.Node] + unclaimed_nodes: Set[fx.Node] + fw_order: Dict[fx.Node, int] + + @functools.cached_property + def required_fw_nodes(self) -> List[fx.Node]: + return sorted( + (n for n in self._required_fw_nodes), key=lambda n: self.fw_order[n] + ) + + def is_required_fw(self, n: fx.Node) -> bool: + return n in self._required_fw_nodes + + def is_required_bw(self, n: fx.Node) -> bool: + return n in self.required_bw_nodes + + def is_unclaimed(self, n: fx.Node) -> bool: + return n in self.unclaimed_nodes + + def get_fw_order(self, n: fx.Node) -> int: + assert n in self._required_fw_nodes, f"Node {n} not in fw nodes!" 
+ return self.fw_order[n] + + +@dataclass +class MinCutOptions: + ban_if_used_far_apart: bool + ban_if_long_fusible_chains: bool + ban_if_materialized_backward: bool + ban_if_not_in_allowlist: bool + ban_if_reduction: bool + + +def must_recompute(node: fx.Node) -> bool: + return node.meta.get("recompute", None) in [ + CheckpointPolicy.MUST_RECOMPUTE, + CheckpointPolicy.PREFER_RECOMPUTE, + ] + + +def has_recomputable_ops(fx_g: fx.GraphModule) -> bool: + found = False + for node in fx_g.graph.nodes: + if must_recompute(node): + return True + return False + + +def has_recomputable_rng_ops(fx_g: fx.GraphModule) -> bool: + for node in fx_g.graph.nodes: + if ( + must_recompute(node) + and hasattr(node.target, "tags") + and torch.Tag.nondeterministic_seeded in node.target.tags + ): + return True + return False + + +def sym_node_size(node: fx.Node) -> int: + if isinstance(node.meta["val"], (torch.SymInt, torch.SymBool)): + return 1 + assert isinstance(node.meta["val"], torch.SymFloat) + return 4 + + +class InvalidNodeBase: + def __repr__(self): + return "Invalid Node" + + +InvalidNode = InvalidNodeBase() + + +def _extract_graph_with_inputs_outputs( + joint_graph: fx.Graph, + inputs: List[fx.Node], + outputs: List[fx.Node], + subgraph: Optional[str] = None, +) -> fx.Graph: + """ + Given a graph, extracts out a subgraph that takes the specified nodes as + inputs and returns the specified outputs. + + This includes specifying non-placeholder nodes as inputs. + + The general strategy is to initialize all inputs with proxies as we + encounter them, and trace through the graph, only keeping values which take + in valid proxies. Then, all dead code is eliminated. + """ + new_graph = fx.Graph() + env = {} + + # Add new placeholder nodes in the order specified by the inputs + for node in inputs: + new_node = new_graph.placeholder(node.name) + # Can't use node_copy here as we may be turning previous call_function into placeholders + new_node.meta = node.meta + env[node] = new_node + + for node in joint_graph.nodes: + if _must_be_in_backward(node) and subgraph != "backward": + env[node] = InvalidNode # type: ignore[assignment] + continue + + if node in env: + # Node must be one of our inputs. (Any member of env which wasn't an + # input to start must have been created by this loop and won't be in + # joint_graph.nodes). 
+ continue + elif node.op == "placeholder": + env[node] = InvalidNode # type: ignore[assignment] + elif node.op == "call_function": + all_args = pytree.arg_tree_leaves(*node.args, **node.kwargs) + all_args = [ + isinstance(env[x], InvalidNodeBase) + for x in all_args + if isinstance(x, fx.Node) + ] + if any(all_args): + env[node] = InvalidNode # type: ignore[assignment] + continue + env[node] = new_graph.node_copy(node, lambda x: env[x]) + elif node.op == "get_attr": + env[node] = new_graph.node_copy(node, lambda x: env[x]) + elif node.op == "output": + pass + output_values = [] + for x in outputs: + if isinstance(x, fx.Node): + if x not in env: + raise RuntimeError(f"Node {x} couldn't be found in env") + assert not isinstance( + env[x], InvalidNodeBase + ), f"Node {x} was invalid, but is output" + output_values.append(env[x]) + else: + output_values.append(x) + new_graph.output(tuple(output_values)) + + new_graph.eliminate_dead_code() + new_graph.lint() + return new_graph + + +def _is_primal(node: fx.Node) -> bool: + return ( + node.op == "placeholder" + and "tangents" not in str(node.target) + and not _is_bwd_seed_offset(node) + and not _is_fwd_seed_offset(node) + ) + + +def _is_tangent(node: fx.Node) -> bool: + return node.op == "placeholder" and "tangents" in str(node.target) + + +def _is_bwd_seed_offset(node: fx.Node) -> bool: + return node.op == "placeholder" and ( + "bwd_seed" in str(node.target) or "bwd_base_offset" in str(node.target) + ) + + +def _is_fwd_seed_offset(node: fx.Node) -> bool: + return node.op == "placeholder" and ( + "fwd_seed" in str(node.target) or "fwd_base_offset" in str(node.target) + ) + + +def _is_backward_state(node: fx.Node) -> bool: + return node.op == "placeholder" and isinstance(node.meta.get("val"), BackwardState) + + +def _has_tag_is_backward(node: fx.Node) -> bool: + return node.meta.get("partitioner_tag", None) == "is_backward" + + +def _has_tag_must_be_in_backward(node: fx.Node) -> bool: + return node.meta.get("partitioner_tag", None) == "must_be_in_backward" + + +def _must_be_in_backward(node: fx.Node) -> bool: + return _has_tag_must_be_in_backward(node) or ( + _has_tag_is_backward(node) and is_with_effects(node) + ) + + +def _extract_fwd_bwd_outputs( + joint_module: fx.GraphModule, *, num_fwd_outputs +) -> Tuple[List[fx.Node], List[fx.Node]]: + outputs = pytree.arg_tree_leaves( + *(node.args for node in joint_module.graph.find_nodes(op="output")) + ) + fwd_outputs = outputs[:num_fwd_outputs] + bwd_outputs = outputs[num_fwd_outputs:] + return fwd_outputs, bwd_outputs + + +def _remove_by_name(saved_values: List[fx.Node], name: str): + for saved_value in saved_values: + if saved_value.name == name: + saved_values.remove(saved_value) + break + + +def _extract_fwd_bwd_modules( + joint_module: fx.GraphModule, + saved_values: List[fx.Node], + saved_sym_nodes: List[fx.Node], + *, + num_fwd_outputs: int, +) -> Tuple[fx.GraphModule, fx.GraphModule]: + fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( + joint_module, num_fwd_outputs=num_fwd_outputs + ) + placeholders = joint_module.graph.find_nodes(op="placeholder") + primal_inputs = [*filter(_is_primal, placeholders)] + tangent_inputs = [*filter(_is_tangent, placeholders)] + fwd_seed_offset_inputs = [*filter(_is_fwd_seed_offset, placeholders)] + bwd_seed_offset_inputs = [*filter(_is_bwd_seed_offset, placeholders)] + backward_state_inputs = [*filter(_is_backward_state, placeholders)] + + bwd_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, + saved_sym_nodes + saved_values + tangent_inputs 
+ bwd_seed_offset_inputs, + bwd_outputs, + "backward", + ) + + for node in bwd_graph.find_nodes(op="placeholder"): + # This is to filter out saved values that don't actually end up being used by the backwards pass + if not node.users: + _remove_by_name(saved_values, node.name) + _remove_by_name(saved_sym_nodes, node.name) + elif _is_backward_state(node): + # BackwardState is saved directly + _remove_by_name(saved_values, node.name) + assert backward_state_inputs + + # Now that we have the finalized list of saved values, we need to ensure + # we propagate all symbols which are referenced by backwards inputs. + # These are not directly used in the graph but are required for downstream + # sizevar assignment + saved_symbols: Set[sympy.Symbol] = set() + saved_sym_nodes_binding = [] + saved_sym_nodes_derived = [] + + # Some symbols may already be bound in the directly saved_sym_nodes, + # keep track of them so we don't re-bind them + for node in saved_sym_nodes: + symbol = is_symbol_binding_fx_node(node) + if symbol: + saved_symbols.add(symbol) + saved_sym_nodes_binding.append(node) + else: + saved_sym_nodes_derived.append(node) + + # Now go through all of the prospective backward inputs and track any + # other symbols we need to bind + symbol_bindings = find_symbol_binding_fx_nodes(joint_module.graph) + for node in itertools.chain(saved_sym_nodes_derived, saved_values, tangent_inputs): + if "val" not in node.meta: + continue + new_symbols = free_symbols(node.meta["val"]) - saved_symbols + # NB: Deterministic order please! + for s in sorted(new_symbols, key=lambda s: s.name): + # NB: For well formed graphs, the symbol should always be present, + # but we also have ways to produce ill-formed graphs, e.g., direct + # make_fx usages, so don't choke in this case + if s not in symbol_bindings: + continue + saved_sym_nodes_binding.append(symbol_bindings[s]) + saved_symbols |= new_symbols + + # Update saved_sym_nodes that are now reordered to have all bindings at + # front. This can also be used later on to figure out the position of saved + # sym nodes in the output of fwd graph. + saved_sym_nodes.clear() + saved_sym_nodes.extend(saved_sym_nodes_binding + saved_sym_nodes_derived) + + # Now, we re-generate the fwd/bwd graphs. + # NB: This might increase compilation time, but I doubt it matters + fwd_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, + primal_inputs + fwd_seed_offset_inputs, + fwd_outputs + saved_values + saved_sym_nodes, + "forward", + ) + bwd_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, + saved_sym_nodes + + saved_values + + tangent_inputs + + bwd_seed_offset_inputs + + backward_state_inputs, + bwd_outputs, + "backward", + ) + + fwd_module = fx._lazy_graph_module._make_graph_module(joint_module, fwd_graph) + bwd_module = fx._lazy_graph_module._make_graph_module(joint_module, bwd_graph) + return fwd_module, bwd_module + + +def default_partition( + joint_module: fx.GraphModule, _joint_inputs, *, num_fwd_outputs +) -> Tuple[fx.GraphModule, fx.GraphModule]: + """ + Partitions the :attr:`joint_module` in a manner that closely resembles the + behavior observed in the original ``.forward()`` and ``.backward()`` of the + callable, i.e., the resulting forward graph contains those operators that + are executed in the original ``.forward()`` callable passed to + :func:`aot_function`. + + The default partitioner collects the operators that are between the forward + inputs and the forward outputs. 
This helps in finding the tensors which have + to be stashed for the backward pass. These stashed tensors become the output + of the generated forward graph. The remaining operators are then placed in + the backward graph. + + .. warning:: + This API is experimental and likely to change. + + Args: + joint_module(fx.GraphModule): The joint forward and backward graph. This + is the result of AOT Autograd tracing. + + Returns: + Returns the generated forward and backward Fx graph modules. + """ + if has_recomputable_ops(joint_module): + return min_cut_rematerialization_partition( + joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs + ) + primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) + fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes)) + inputs = primal_inputs + fwd_seed_offset_inputs + fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( + joint_module, num_fwd_outputs=num_fwd_outputs + ) + forward_only_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, inputs, fwd_outputs, "forward" + ) + forward_node_names = { + node.name for node in forward_only_graph.nodes if node.op != "output" + } + saved_values = [] + saved_sym_nodes = [] + + for node in joint_module.graph.nodes: + if node.name not in forward_node_names: + continue + if is_sym_node(node): + # Symints must be kept separate from tensors so that PythonFunction only calls + # save_for_backward on tensors and stashes symints in autograd .ctx + saved_sym_nodes.append(node) + elif "tensor_meta" not in node.meta and node.op == "call_function": + # Since we can't save tuple of tensor values, we need to flatten out what we're saving + users = node.users + assert all(user.target == operator.getitem for user in users) + saved_values.extend(users) + else: + backward_usages = [ + n for n in node.users if n.name not in forward_node_names + ] + if "tensor_meta" in node.meta and all( + is_sym_node(n) for n in backward_usages + ): + # If we have a tensor in the forward, where only its sizes/strides are needed in the backward, + # and not the actual tensor data, + # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor. + # + # Note that saving the tensor could also cause compilation problems: + # If the user mutated an input in the forward and uses its sizes/strides in the backward, + # then we would be obligated to clone the input before saving it to appease autograd. + # (This is how we originally found this bug). 
+ saved_sym_nodes.extend(backward_usages) + else: + saved_values.append(node) + saved_values = list(dict.fromkeys(saved_values).keys()) + saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys()) + + return _extract_fwd_bwd_modules( + joint_module, + saved_values, + saved_sym_nodes=saved_sym_nodes, + num_fwd_outputs=num_fwd_outputs, + ) + + +INT_INF = int(1e6) + + +def _tensor_nbytes(numel: int, dtype) -> int: + return numel * dtype.itemsize + + +def _size_of(node: fx.Node) -> int: + def object_nbytes(x) -> int: + if not isinstance(x, torch.Tensor): + return 0 + return _tensor_nbytes(hint_int(x.numel(), fallback=4096), x.dtype) + + if "val" in node.meta: + val = node.meta["val"] + if isinstance(val, py_sym_types): + return 1 + # NB: The fallback values here are meaningless, maybe we should respect + # torch._inductor.config.unbacked_symint_fallback (but this is a + # layering violation) + elif isinstance(val, (list, tuple)): + return sum(object_nbytes(n) for n in val) + elif isinstance(val, dict): + return sum(object_nbytes(n) for _, n in val.items()) + elif isinstance(val, torch.Tensor): + return object_nbytes(val) + + raise RuntimeError(f"Unknown metadata type {type(val)} on node {node}") + if node.op == "get_attr": + return 0 + raise RuntimeError( + f"Node {node} didn't have `val` metadata; we should always have `val` metadata on the nodes." + ) + + +# Used for some investigative purposes +def _count_ops(graph: fx.Graph): + from collections import defaultdict + + cnt: Dict[str, int] = defaultdict(int) + for node in graph.nodes: + if node.op == "call_function": + cnt[node.target.__name__] += 1 + print(sorted(cnt.items(), key=lambda x: x[1], reverse=True)) + + +@functools.lru_cache(None) +def pointwise_ops(): + ops = [] + for attr_name in dir(torch.ops.aten): + opoverloadpacket = getattr(torch.ops.aten, attr_name) + if not isinstance(opoverloadpacket, torch._ops.OpOverloadPacket): + continue + + for overload in opoverloadpacket.overloads(): + op_overload = getattr(opoverloadpacket, overload) + if torch.Tag.pointwise in op_overload.tags: + # currently aot autograd uses packet not overload + ops.append(opoverloadpacket) + break + + return ops + + +def sort_depths(args, depth_map: Dict[fx.Node, int]) -> List[Tuple[fx.Node, int]]: + arg_depths = { + arg: depth_map[arg] for arg in args if isinstance(arg, torch.fx.node.Node) + } + return sorted(arg_depths.items(), key=lambda x: x[1], reverse=True) + + +def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule: + """ + This pass finds the first bwd node in the graph (by looking at users of + tangents) and then reorders the graph by walking from this node to all the + way to the end of the graph. At each op in this traveral, we insert this op + in a new graph and try to bring only the relevant subgraph from the other + non-bwd edges relevant for this op. This closely mimics the behavior of + autograd engine. + + Why is this pass required in the first place? + + This is an artifact of how partitioners work today. The starting point of + partitioner is a joint graph, which is fwd and then bwd graph. In the case + of checkpointing, we keep portions of fwd graph in their original place in + the joint graph, while obtaining a bwd graph. As a result, the resulting bwd + graph has copies of recomputed fwd subgraphs followed by the original bwd + graph. If we run this naively, this leads to bad memory footprint, because + the fwd subgraphs are live for way longer duration than necessary. 
This pass + reorders the operations such that we prioritize the ops for the original bwd + graph while only realizing those ops from the fwd graph that are necessary + at any given point in the graph. + """ + + new_graph = fx.Graph() + env: Dict[fx.Node, fx.Node] = {} + + # Add new placeholder nodes in the order specified by the inputs + for node in gm.graph.find_nodes(op="placeholder"): + env[node] = new_graph.node_copy(node, lambda x: env[x]) + + order = {} + for idx, node in enumerate(gm.graph.nodes): + order[node] = idx + + def insert_node_in_graph(node): + cur_nodes = [node] + insertable_nodes = set() + while len(cur_nodes) > 0: + node = cur_nodes.pop() + if node in insertable_nodes or node in env: + continue + insertable_nodes.add(node) + + # Bias traversal towards the nodes that have higher depth - prioritizes + # critical path first. + cur_nodes += node.all_input_nodes + + insertable_nodes = sorted(insertable_nodes, key=lambda n: order[n]) + for node in insertable_nodes: + env[node] = new_graph.node_copy(node, lambda x: env[x]) + + # Find first bwd node in the graph + tangent_inputs = list(filter(_is_tangent, gm.graph.nodes)) + first_node_in_bwd = None + minimum_order = math.inf + for tangent in tangent_inputs: + for user in tangent.users: + if order[user] < minimum_order: + minimum_order = order[user] + first_node_in_bwd = user + + # If gradInp does not depend upon gradOut, we may not find any nodes in the "backwards pass" + if first_node_in_bwd is None: + return gm + + # Build the graph op-by-op by starting from the node all the way to the end + for node in list(gm.graph.nodes)[order[first_node_in_bwd] :]: + insert_node_in_graph(node) + + # The output node is already built by the traversal. + new_gm = torch.fx.GraphModule(gm, new_graph) + return new_gm + + +def functionalize_rng_ops( + joint_module: fx.GraphModule, + fw_module: fx.GraphModule, + bw_module: fx.GraphModule, + num_sym_nodes: int, +) -> Tuple[fx.GraphModule, fx.GraphModule]: + # During user-driven activation checkpointing, we have to ensure that a rng + # op in fwd yields the same output as the recomputed rng op in the bwd. To + # do this, we use functionalize wrappers to wrap the random ops and share + # rng state between the fwd and bwd graphs. + + # There are 3 main steps to do this + # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd. + # Step 2 - Modify the fwd pass such that + # 1) Replace rand with run_and_save_rng_state wrapper + # 2) Replace the users of the original op with the output[1] of this op. + # 3) Collect all the rng_state - output[0] of each op, and make them + # output nodes. Special care needs to be taken here because fwd outputs + # has symints at the very end. + # Step 3 - Modify the bwd pass such that + # 1) Add the input nodes just before the tangents for the stashed rng states + # 2) Replace rand with run_with_save_rng_state wrappers + # 3) Use the stashed states as inputs to these ops + + # Unique id to generate name + uid = itertools.count() + + def get_rng_ops(gmod): + random_nodes = {} + for node in gmod.graph.nodes: + if ( + node.op == "call_function" + and hasattr(node.target, "tags") + and torch.Tag.nondeterministic_seeded in node.target.tags + ): + random_nodes[node.name] = node + return random_nodes + + def get_device(node): + """ + Check the example value of the node outputs to find the device type. 
+ """ + if "val" not in node.meta: + return None + + candidates = node.meta["val"] + if not isinstance(candidates, tuple): + candidates = (candidates,) + + for candidate in candidates: + if isinstance(candidate, torch.Tensor): + if candidate.device.type == "cuda": + return "cuda" + + return "cpu" + + def get_sample_rng_state(device): + if device == "cuda": + return torch.cuda.get_rng_state() + return torch.get_rng_state() + + # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd. + joint_graph_rng_ops = get_rng_ops(joint_module) + fw_graph_rng_ops = get_rng_ops(fw_module) + bw_graph_rng_ops = get_rng_ops(bw_module) + recomputable_rng_ops_map = {} + for node in joint_module.graph.nodes: + if ( + must_recompute(node) + and hasattr(node.target, "tags") + and torch.Tag.nondeterministic_seeded in node.target.tags + ): + base_node = joint_graph_rng_ops[node.name] + fw_node = fw_graph_rng_ops[node.name] + bw_node = bw_graph_rng_ops[node.name] + recomputable_rng_ops_map[base_node] = {"fwd": fw_node, "bwd": bw_node} + + run_and_save_rng = torch._prims.rng_prims.run_and_save_rng_state + run_with_rng_state = torch._prims.rng_prims.run_with_rng_state + bw_tangent_start_node = None + for node in bw_module.graph.find_nodes(op="placeholder"): + if "tangent" in node.name: + bw_tangent_start_node = node + break + if bw_tangent_start_node is None: + raise RuntimeError( + "Couldn't find tangent node in graph inputs. This is unexpected, please file a bug if you see this" + ) + + fw_rng_state_outputs = [] + for base_node, node_pair in recomputable_rng_ops_map.items(): + # Step 2 - Modify the fwd pass such that + fw_node = node_pair["fwd"] + bw_node = node_pair["bwd"] + fw_graph = fw_module.graph + with fw_graph.inserting_before(fw_node): + functional_fw_node = fw_graph.create_node( + "call_function", + run_and_save_rng, + args=(fw_node.target, *fw_node.args), + kwargs=fw_node.kwargs, + ) + state = fw_graph.create_node( + "call_function", + operator.getitem, + args=(functional_fw_node, 0), + kwargs={}, + ) + rng_output = fw_graph.create_node( + "call_function", + operator.getitem, + args=( + functional_fw_node, + 1, + ), + kwargs={}, + ) + fw_node.replace_all_uses_with(rng_output) + fw_graph.erase_node(fw_node) + fw_rng_state_outputs.append(state) + + # Step 3 - Modify the bwd pass such that + bw_graph = bw_module.graph + with bw_graph.inserting_before(bw_tangent_start_node): + state_name = f"rng_state_output_{next(uid)}" + bw_rng_state_node = bw_graph.placeholder(state_name) + bw_rng_state_node.meta["val"] = get_sample_rng_state(get_device(fw_node)) + + with bw_graph.inserting_before(bw_node): + rng_output = bw_graph.create_node( + "call_function", + run_with_rng_state, + args=(bw_rng_state_node, bw_node.target, *bw_node.args), + kwargs=bw_node.kwargs, + ) + + bw_node.replace_all_uses_with(rng_output) + bw_graph.erase_node(bw_node) + + # Add the rng states in the output of the fwd graph. AOT Autograd assumes + # that symints are at the end of forward graph outputs. So, insert the new + # rng states accordingly. 
+ fw_output_node = next(iter(fw_module.graph.find_nodes(op="output"))) + fw_outputs = fw_output_node.args[0] + sym_node_start_idx = len(fw_outputs) - num_sym_nodes + outputs = ( + fw_outputs[:sym_node_start_idx] + + tuple(fw_rng_state_outputs) + + fw_outputs[sym_node_start_idx:] + ) + fw_module.graph.output(outputs) + fw_module.graph.erase_node(fw_output_node) + fw_module.recompile() + bw_module.recompile() + return fw_module, bw_module + + +def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: + """ + If there are two consecutive checkpointed blocks with no operator in + between, we would still want to stash the tensor at the boundary of + checkpointed blocks. The following pass makes the last output node + non-recomputable to allow for that. + """ + for node in joint_module.graph.nodes: + if must_recompute(node): + for user in node.users: + if ( + must_recompute(user) + and user.meta["ac_graph_id"] > node.meta["ac_graph_id"] + ): + node.meta["recompute"] = CheckpointPolicy.MUST_SAVE + return joint_module + + +def solve_min_cut( + joint_graph: fx.Graph, + node_info: NodeInfo, + min_cut_options: MinCutOptions, + dont_ban=None, +): + if dont_ban is None: + dont_ban = set() + op_types = get_default_op_list() + + if AOT_PARTITIONER_DEBUG: + joint_module_ops = { + str(node.target._overloadpacket) + for node in joint_graph.nodes + if node.op == "call_function" and hasattr(node.target, "_overloadpacket") + } + ops_ignored = joint_module_ops - {str(i) for i in op_types.recomputable_ops} + print("Ops banned from re-materialization: ", ops_ignored) + print() + + def can_fuse_into_auto_functionalized(a, b): + if b.target != torch.ops.higher_order.auto_functionalized: + return False + mutable_op = b.args[0] + ( + mutable_arg_names, + _, + ) = torch._higher_order_ops.auto_functionalize.get_mutable_args(mutable_op) + for name in mutable_arg_names: + arg = b.kwargs[name] + if a is arg: + return True + if isinstance(arg, list): + if a in arg: + return True + return False + + def can_fuse_into_triton_kernel_wrapper_functional(a, b): + if b.target != torch.ops.higher_order.triton_kernel_wrapper_functional: + return False + mutable_arg_names = b.kwargs["tensors_to_clone"] + for name in mutable_arg_names: + arg = b.kwargs["kwargs"][name] + if a is arg: + return True + return False + + def is_fusible(a, b): + # We can perform "memory fusion" into a cat, but cat cannot be a + # producer to a fusion + if get_aten_target(b) == aten.cat: + return True + if can_fuse_into_auto_functionalized(a, b): + return True + if can_fuse_into_triton_kernel_wrapper_functional(a, b): + return True + return op_types.is_fusible(a) and op_types.is_fusible(b) + + try: + import networkx as nx + except ImportError as e: + raise RuntimeError( + "Need networkx installed to perform smart recomputation " "heuristics" + ) from e + + def is_materialized_backwards(node): + if op_types.is_view(node): + return False + cur_nodes = {node} + while len(cur_nodes) > 0: + cur = cur_nodes.pop() + for user in cur.users: + if not node_info.is_required_fw(user) and not is_fusible(cur, user): + return True + if op_types.is_view(user): + cur_nodes.add(user) + + return False + + def should_ban_recomputation(node): + if node.op != "call_function": + return False + if node.target == operator.getitem: + return False + if node.meta.get("recompute", None) == CheckpointPolicy.MUST_SAVE: + return True + if config.recompute_views and op_types.is_view(node): + return False + if node.target in [aten.lift_fresh_copy.default, 
aten.lift_fresh.default]: + return False + + if min_cut_options.ban_if_not_in_allowlist: + if not op_types.is_recomputable(node): + return True + else: + if op_types.is_random(node) or op_types.is_compute_intensive(node): + return True + + # If a node *must* be materialized in the backwards pass, then we + # should never recompute it. This is a pretty subtle point. In + # general, the assumption we make is that recomputing a node in the + # backwards pass is "free". However, if a node must be materialized + # in the backwards pass, then recomputing it is never free. + if min_cut_options.ban_if_materialized_backward and is_materialized_backwards( + node + ): + log.info("materialized backwards: %s %s", node, tuple(node.users)) + return True + + # Arbitrary hack that sometimes seems to help things. The above + # modification appears to have made this heuristic a lot less critical + # for performance. + # NB: As of PR #121692, this hack no longer seems necessary. + if node.dist_from_bw < 1000 and node.dist_from_bw > config.max_dist_from_bw: + return True + + # If the output of an op is 4x smaller (arbitrary choice), + # then we don't allow recomputation. The idea here is that for + # things like reductions, saving the output of the reduction is very + # cheap/small, and it makes sure we don't do things like recompute + # normalizations in the backwards. + if min_cut_options.ban_if_reduction: + input_tensors_size = sum( + _size_of(i) for i in node.args if isinstance(i, fx.Node) + ) + output_size = _size_of(node) + return output_size * 4 < input_tensors_size + return False + + def is_materialized(node): + if node.op == "placeholder": + return True + + return not all(is_fusible(node, user) for user in node.users) + + def get_node_weight(node) -> float: + mem_sz = _size_of(node) + if config.recompute_views and op_types.is_view(node): + # If `config.recompute_views=True`, we don't save views. This is generally + # a good idea since views are free to recompute, and it makes it a bit simpler + # to analyze. + # NB: If they're not free to recompute (e.g. nested tensors)... I + # think we should modify checks for view_ops to `is_view` and check + # that. Basically, with nested tensors, `aten.view` is not a "view + # op". + return math.inf + + if isinstance(node.meta["val"], py_sym_types): + # We never want to save symfloats + if not isinstance(node.meta["val"], torch.SymInt): + return INT_INF + + # Heuristic to bias towards nodes closer to the backwards pass + # Complete guess about current value + mem_sz = int(mem_sz * (1.1 ** max(min(node.dist_from_bw, 100), 1))) + if is_materialized(node): + return mem_sz + else: + return mem_sz * 2 + + nx_graph = nx.DiGraph() + banned_nodes = set() + + def ban_recomputation_if_allowed(node): + if op_types.is_view(node): + return False + if node in dont_ban: + return False + # This bans recomputation of the node unless we've been forced not to by + # user annotation + if must_recompute(node): + return False + + if "val" in node.meta and isinstance(node.meta["val"], torch.SymFloat): + return False + + banned_nodes.add(node) + # A node will only ever be recomputed if there is a path from an + # ancestor of this node to the backwards path through this node that + # doesn't go through any saved value. If this node is saved, then that + # condition is not possible. 
+ nx_graph.add_edge("source", node.name + "_in", capacity=math.inf) + return True + + for node in joint_graph.nodes: + if node.op == "output": + continue + + if node in node_info.required_bw_nodes: + if node not in node_info.inputs: + nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf) + continue + # If someone saves a input for backward as-is and backward + # returns that tensor as-is as a grad input, then the node x would + # be both a required_bw_node and an input. In this case we + # (1) connect x_in to to the source, (2) x_out to the sink, and + # (3) assign the proper weight to the x_in-x_out edge, so that + # x would be part of cut nodes. A case where this happens is if + # NestedTensor saves a offset tensor as part of the singleton int + # in sizes. + nx_graph.add_edge(node.name + "_out", "sink", capacity=math.inf) + + if must_recompute(node): + # If user explicitly says they want to recompute a node, we honor it + # by adding an inf-capacity edge from X_in to the sink. + # This way, X_in node is guaranteed to be part of the subgraph that contains "sink" + # after the cut, thus guaranteeing that X op will be recomputed. + nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf) + continue + + if _is_primal(node) or _is_fwd_seed_offset(node): + ban_recomputation_if_allowed(node) + + # If a node can't be recomputed (too expensive or involves randomness), + # we prevent it from being recomputed by adding an inf edge to the source + # We only need to ban nodes in the fw pass, as those are the only ones that would be recomputed. + if node_info.is_required_fw(node) and should_ban_recomputation(node): + ban_recomputation_if_allowed(node) + + # Checks if a node is actually a tuple. Can be simplified to just an isinstance check if we always use faketensors. + is_non_tensor_node = ( + "val" not in node.meta and "tensor_meta" not in node.meta + ) or ("val" in node.meta and not isinstance(node.meta["val"], torch.Tensor)) + + if is_sym_node(node): + weight = float(sym_node_size(node)) + elif is_non_tensor_node: + weight = ( + 0.0 if isinstance(node.meta.get("val"), BackwardState) else math.inf + ) + else: + weight = get_node_weight(node) + # Creates the weights on the "node" edge + nx_graph.add_edge(node.name + "_in", node.name + "_out", capacity=weight) + for user in node.users: + nx_graph.add_edge(node.name + "_out", user.name + "_in", capacity=math.inf) + + # todo(chilli): This is the most questionable of the 3 heuristics for banning recompute. + # Some example models to look at where this helps perf: poolformer_m36, + # mixer_b16_224, cait_m36_384 + + # The "rough" idea here is that if you have some node that is used by both a + # node nearby downstream as well as a node far downstream, if we recompute + # both of the downstream nodes, we're unlikely to be able to fuse both + # downstream nodes together. + + # Thus, we shouldn't aim to recompute far downstream nodes that depend on + # this node. That intuition of "far downstream" is captured by whether + # there's an unfusible op along the chain somewhere + + # It could probably be improved by properly analyzing what's going on in the + # backwards pass instead of only relying on whether it's unfusible in the + # forwards. + + def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: + """ + Finds the first unfusible node in the chain of nodes starting from + `start_nodes` and returns its position. 
+ """ + sorted_nodes: List[Tuple[int, fx.Node, bool]] = [] + for n in start_nodes: + heapq.heappush(sorted_nodes, (node_info.get_fw_order(n), n, True)) + + while len(sorted_nodes) > 0: + _, node, node_is_fusible = heapq.heappop(sorted_nodes) + if not node_is_fusible: + return node_info.get_fw_order(node) + for user in node.users: + if node_info.is_required_fw(user): + if node_info.get_fw_order(user) > max_range: + continue + heapq.heappush( + sorted_nodes, + (node_info.get_fw_order(user), user, is_fusible(node, user)), + ) + return max_range + + if min_cut_options.ban_if_used_far_apart: + for used_node in node_info.required_fw_nodes: + orders = [ + node_info.get_fw_order(user) + for user in used_node.users + if node_info.is_required_fw(user) + ] + fw_users = [ + user for user in used_node.users if node_info.is_required_fw(user) + ] + if len(orders) > 0: + first_unfusible_use = find_first_unfusible(fw_users, max(orders)) + for user in tuple(used_node.users): + if ( + node_info.is_required_fw(user) + and node_info.get_fw_order(user) > first_unfusible_use + and is_fusible(used_node, user) + ): + if user in banned_nodes: + continue + log.info( + "used above/below fusible %s:(%s) -> %s -> %s:(%s)", + used_node, + node_info.get_fw_order(used_node), + first_unfusible_use, + user, + node_info.get_fw_order(user), + ) + ban_recomputation_if_allowed(user) + + # This heuristic is fairly straightforward. The idea is that although it is + # cheap to recompute bandwidth-bound ops, we don't want to end up in a situation + # where we have a long chain of pointwise ops from the beginning to the end + # of the model (like say, residual connections) + + # todo: I'm not totally sure why this heuristic matters. It's possible that this is + # working around Inductor fusion decisions, or that it's a patch over + # suboptimal partitioning decisions + + # Some models it improves perf on are cait_m36_384, mixer_b16_224, poolformer_m36 + + if min_cut_options.ban_if_long_fusible_chains: + visited = set() + for start_node in joint_graph.nodes: + if not node_info.is_required_fw(start_node): + continue + fusible = [(node_info.get_fw_order(start_node), start_node)] + start_order = node_info.get_fw_order(start_node) + while len(fusible) > 0: + _, cur = heapq.heappop(fusible) + if cur in visited: + continue + visited.add(cur) + # 100 is arbitrary choice to try and prevent degenerate cases + if ( + node_info.get_fw_order(cur) > start_order + 100 + and len(fusible) == 0 + ): + log.info( + "too long %s %s %s %s", + cur, + start_node, + node_info.get_fw_order(cur), + node_info.get_fw_order(start_node), + ) + ban_recomputation_if_allowed(cur) + break + + for user in cur.users: + if ( + node_info.is_required_fw(user) + and is_fusible(cur, user) + and user not in banned_nodes + ): + heapq.heappush(fusible, (node_info.get_fw_order(user), user)) + + try: + cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink") + except Exception: + print("Failed to compute min-cut on following graph:") + print("\n".join(nx.readwrite.edgelist.generate_edgelist(nx_graph))) + visualize_min_cut_graph(nx_graph) + raise + + reachable, non_reachable = partition + cutset: Set[Tuple[str, str]] = set() + for u, nbrs in ((n, nx_graph[n]) for n in reachable): + cutset.update((u, v) for v in nbrs if v in non_reachable) + + cut_nodes = set() + for node_in, node_out in cutset: + assert node_in[:-3] == node_out[:-4] + node_name = node_in[:-3] + cut_nodes.add(node_name) + + name_to_node = get_name_to_node(joint_graph) + # To make this stuff deterministic 
+ node_idx = {node: idx for idx, node in enumerate(joint_graph.nodes)} + saved_values = sorted( + (name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x] + ) + return saved_values, banned_nodes + + +def visualize_min_cut_graph(nx_graph): + import networkx as nx + import pydot + + dot_format = nx.nx_pydot.to_pydot(nx_graph).to_string() + dot_graph = pydot.graph_from_dot_data(dot_format)[0] + for edge in dot_graph.get_edges(): + weight = nx_graph[edge.get_source()][edge.get_destination()]["capacity"] + # Set edge label to weight + edge.set_label(str(weight)) + # Color edges with weight 'inf' as red + if weight == float("inf"): + edge.set_color("red") + print("Visualizing the failed graph to min_cut_failed.svg") + dot_graph.write_svg("min_cut_failed.svg") + + +def get_default_op_list() -> OpTypes: + default_recomputable_ops: List[Callable] = [ + aten.add, + aten.sub, + aten.div, + aten.atan2, + aten.mul, + aten.max, + aten.min, + aten.pow, + aten.remainder, + aten.fmod, + aten.__and__, + aten.__or__, + aten.__xor__, + aten.__lshift__, + aten.__rshift__, + aten.eq, + aten.ne, + aten.ge, + aten.gt, + aten.le, + aten.lt, + aten.abs, + aten.bitwise_not, + aten.ceil, + aten.floor, + aten.frac, + aten.neg, + aten.relu, + aten.round, + aten.silu, + aten.trunc, + aten.log, + aten.log10, + aten.log1p, + aten.log2, + aten.lgamma, + aten.exp, + aten.expm1, + aten.erf, + aten.erfc, + aten.cos, + aten.acos, + aten.cosh, + aten.sin, + aten.asin, + aten.sinh, + aten.tan, + aten.atan, + aten.tanh, + aten.atanh, + aten.sqrt, + aten.rsqrt, + aten.reciprocal, + aten.sigmoid, + aten.softplus, + aten.threshold, + aten.threshold_backward, + aten.clamp, + aten.where, + aten.lerp, + aten.addcmul, + aten.gelu, + aten.gelu_backward, + aten.sum, + aten.mean, + aten._grad_sum_to_size, + aten.sum_to_size, + aten.amax, + aten.to, + aten.type_as, + operator.getitem, + aten.squeeze, + aten.unsqueeze, + aten.rsub, + aten._to_copy, + ] # noqa: E501,B950 + recomputable_view_ops = [aten.squeeze, aten.unsqueeze, aten.alias] + recomputable_view_ops += [ + aten.view, + aten.slice, + aten.t, + prims.broadcast_in_dim, + aten.expand, + aten.as_strided, + aten.permute, + ] + view_ops = recomputable_view_ops + default_recomputable_ops += [ + prims.div, + prims.convert_element_type, + aten.clone, + aten._to_copy, + aten.full_like, + prims.var, + prims.sum, + aten.var, + aten.std, + prims.broadcast_in_dim, + aten.select, + aten._unsafe_view, + aten.view, + aten.expand, + aten.slice, + aten.reshape, + aten.broadcast_tensors, + aten.scalar_tensor, + aten.ones, + aten.new_zeros, + aten.lift_fresh_copy, + aten.arange, + aten.triu, + aten.var_mean, + aten.isinf, + aten.any, + aten.full, + aten.as_strided, + aten.zeros, + aten.empty, + aten.empty_like, + aten.argmax, + aten.maximum, + prims.iota, + prims._low_memory_max_pool2d_offsets_to_indices, + ] # noqa: E501,B950 + # Natalia said that we should allow recomputing indexing :) + default_recomputable_ops += [aten.index, aten.gather] + default_recomputable_ops += view_ops + + default_recomputable_ops += pointwise_ops() + + default_recomputable_ops += [ + aten.zeros_like, + ] + + default_recomputable_ops += [method_to_operator(m) for m in magic_methods] + recomputable_ops = set(default_recomputable_ops) + + random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like] + compute_intensive_ops = [ + aten.mm, + aten.convolution, + aten.convolution_backward, + aten.bmm, + aten.addmm, + aten._scaled_dot_product_flash_attention, + aten._scaled_dot_product_efficient_attention, + 
aten._flash_attention_forward, + aten._efficient_attention_forward, + aten.upsample_bilinear2d, + aten._scaled_mm, + ] # noqa: E501,B950 + + fusible_ops = recomputable_ops | set(random_ops) + return OpTypes( + set(fusible_ops), + set(compute_intensive_ops), + set(random_ops), + set(view_ops), + set(recomputable_ops), + ) + + +def get_name_to_node(graph: fx.Graph): + name_to_node = {} + for node in graph.nodes: + name_to_node[node.name] = node + return name_to_node + + +def greedy_knapsack( + memory: List[float], runtimes: List[float], max_memory: float +) -> Tuple[float, List[int], List[int]]: + n = len(runtimes) + items = list(range(n)) + + # Sort items based on the ratio of runtime to memory in descending order + items = sorted(items, key=lambda i: runtimes[i] / memory[i], reverse=True) + + total_memory = 0.0 + total_runtime = 0.0 + items_to_save = [] + items_to_allow_recomputing = [] + + for i in items: + if total_memory + memory[i] <= max_memory: + total_memory += memory[i] + total_runtime += runtimes[i] + items_to_save.append(i) + else: + items_to_allow_recomputing.append(i) + return total_runtime, items_to_save, items_to_allow_recomputing + + +def ilp_knapsack( + memory: List[float], runtimes: List[float], max_memory: float +) -> Tuple[float, List[int], List[int]]: + import numpy as np + + try: + from scipy.optimize import Bounds, LinearConstraint, milp + except ImportError: + raise RuntimeError( + "To use the ILP for memory budget checkpointing you need to install scipy" + ) from None + + np_memory = np.array(memory) + np_runtimes = np.array(runtimes) + c = -np_runtimes # type: ignore[operator] + + memory_constraint = LinearConstraint(A=np_memory, ub=np.array(max_memory)) + constraints = [memory_constraint] + + integrality = np.ones_like(c) + res = milp( + c=c, constraints=constraints, integrality=integrality, bounds=Bounds(0, 1) + ) + if not res.success: + raise RuntimeError("Somehow scipy solving failed") + + items_to_save = [] + items_to_allow_recomputing = [] + for idx, i in enumerate(res.x): + if i == 1: + items_to_save.append(idx) + else: + items_to_allow_recomputing.append(idx) + return -res.fun, items_to_save, items_to_allow_recomputing + + +def dp_knapsack( + memory: List[float], runtimes: List[float], max_memory: float +) -> Tuple[float, List[int], List[int]]: + # Scaling factor to convert floating point weights to integers + S = 10000 + + # Quantize the memory weights + quantized_memory = torch.tensor( + [int(round(m * S)) for m in memory], dtype=torch.long, device="cpu" + ) + runtimes = torch.tensor(runtimes, dtype=torch.float32, device="cpu") + + # Quantized pseudopolynomial DP for 0-1 Knapsack + quantized_max_memory = int(round(max_memory * S)) + + n = len(memory) + + # Initialize the DP table + # TODO(chilli): I think if needed, this memory can be optimized with sliding + # window trick + Hirschberg trick: + # https://codeforces.com/blog/entry/47247?#comment-316200 + dp = torch.zeros( + (n + 1, quantized_max_memory + 1), dtype=torch.float32, device="cpu" + ) + + for i in range(1, n + 1): + current_memory = quantized_memory[i - 1] + current_runtime = runtimes[i - 1] + + # Copy the previous row + dp[i, :] = dp[i - 1, :] + + # Update dp[i, j] for all j >= current_memory + if current_memory == 0: + dp[i, :] = dp[i - 1, :] + current_runtime + else: + dp[i, current_memory:] = torch.maximum( + dp[i - 1, current_memory:], + dp[i - 1, :-current_memory] + current_runtime, + ) + + # Backtrack to find the items included in the knapsack + saved_items = [] + recomputable_items = [] 
+ j: int = quantized_max_memory + for i in range(n, 0, -1): + if dp[i][j] != dp[i - 1][j]: + saved_items.append(i - 1) # Include this item (indexing from 0) + j -= int(quantized_memory[i - 1].item()) + else: + recomputable_items.append(i - 1) + + saved_items.reverse() # To get items in the order they were added + + # The maximum runtime that can be achieved within the max_memory constraint + max_runtime = dp[n][quantized_max_memory].item() + + return max_runtime, saved_items, recomputable_items + + +def _optimize_runtime_with_given_memory( + memory: List[float], + runtimes: List[float], + max_memory: float, +) -> Tuple[float, List[int], List[int]]: + SOLVER = config.activation_memory_budget_solver + if SOLVER == "greedy": + return greedy_knapsack(memory, runtimes, max_memory) + elif SOLVER == "ilp": + return ilp_knapsack(memory, runtimes, max_memory) + elif SOLVER == "dp": + return dp_knapsack(memory, runtimes, max_memory) + else: + raise RuntimeError(f"Not aware of memory budget knapsack solver: {SOLVER}") + + +from torch.utils._mode_utils import no_dispatch + + +def estimate_runtime(node): + RUNTIME_MODE = config.activation_memory_budget_runtime_estimator + + def materialize_arg(x): + if isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.Tensor): + shape = list(x.meta["val"].shape) + + def realize_symbol(d): + return hint_int(d, fallback=4096) + + shape = [realize_symbol(s) for s in shape] + return x.meta["val"].new_empty_strided( + shape, stride=x.meta["tensor_meta"].stride + ) + elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymInt): + return hint_int(x.meta["val"], fallback=4096) + elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymFloat): + return 1.0 + elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymBool): + return True + else: + return x + + if RUNTIME_MODE == "testing": + return 1 + + elif RUNTIME_MODE == "profile": + with no_dispatch(): + from torch._inductor.runtime.benchmarking import benchmarker + + args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs)) + ms = benchmarker.benchmark_gpu(lambda: node.target(*args, **kwargs)) + return ms + + elif RUNTIME_MODE == "flops": + # todo(chilli): Normalize this to also return ms + from torch.utils.flop_counter import FlopCounterMode + + args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs)) + with FlopCounterMode(display=False) as mode: + node.target(*args, **kwargs) + counted_flops = mode.get_total_flops() + return max(counted_flops, 1) + else: + raise RuntimeError(f"Not aware of runtime estimator: {RUNTIME_MODE}") + + +def choose_saved_values_set( + joint_graph: fx.Graph, node_info: NodeInfo, memory_budget=1 +) -> List[fx.Node]: + if memory_budget > 1 or memory_budget < 0: + raise RuntimeError( + f"The valid ranges for memory budget are 0 <= m <= 1. 
The provided value is {memory_budget}" + ) + min_cut_options = MinCutOptions( + ban_if_used_far_apart=config.ban_recompute_used_far_apart, + ban_if_long_fusible_chains=config.ban_recompute_long_fusible_chains, + ban_if_materialized_backward=config.ban_recompute_materialized_backward, + ban_if_not_in_allowlist=config.ban_recompute_not_in_allowlist, + ban_if_reduction=config.ban_recompute_reductions, + ) + + if config.aggressive_recomputation: + min_cut_options = replace( + min_cut_options, + ban_if_used_far_apart=False, + ban_if_long_fusible_chains=False, + ban_if_materialized_backward=False, + ban_if_not_in_allowlist=False, + ) + if memory_budget == 0: + return node_info.inputs + + runtime_optimized_saved_values, _ = solve_min_cut( + joint_graph, + node_info, + min_cut_options, + ) + # return runtime_optimized_saved_values + if memory_budget == 1: + return runtime_optimized_saved_values + + def estimate_activations_size(saved_values: List[fx.Node]) -> float: + return sum(map(_size_of, saved_values)) / 1e9 + + min_act_size = estimate_activations_size(node_info.inputs) + max_act_size = estimate_activations_size(runtime_optimized_saved_values) + # The optimized choice is smaller than the inputs anyways + if max_act_size <= min_act_size: + return runtime_optimized_saved_values + + def get_normalized_size(sz): + return (sz / 1e9) / (max_act_size - min_act_size) + + def get_mem_ratio(activations: List[fx.Node]): + return (estimate_activations_size(activations) - min_act_size) / ( + max_act_size - min_act_size + ) + + more_aggressive_options = replace( + min_cut_options, + ban_if_used_far_apart=False, + ban_if_long_fusible_chains=False, + ban_if_materialized_backward=False, + ) + more_aggressive_saved_values, _ = solve_min_cut( + joint_graph, node_info, more_aggressive_options + ) + if get_mem_ratio(more_aggressive_saved_values) < memory_budget: + return more_aggressive_saved_values + + aggressive_options = replace( + more_aggressive_options, + ban_if_not_in_allowlist=False, + ) + aggressive_recomputation_saved_values, banned_nodes = solve_min_cut( + joint_graph, node_info, aggressive_options + ) + + if get_mem_ratio(aggressive_recomputation_saved_values) < memory_budget: + return aggressive_recomputation_saved_values + + from torch._inductor.fx_utils import get_node_storage + + input_storages = {get_node_storage(node) for node in node_info.inputs} + + def get_recomputable_banned_nodes(banned_nodes: List[fx.Node]) -> List[fx.Node]: + return [ + i + for i in banned_nodes + if ( + # Only allow recomputing nodes that are actually required for BW + i.dist_from_bw < int(1e9) # type: ignore[attr-defined] + and get_node_storage(i) not in input_storages + ) + ] + + recomputable_banned_nodes = get_recomputable_banned_nodes(banned_nodes) + + # default: runtime_optimized_saved_values + # more aggressive: more_aggressive_saved_values + # full aggressive: aggressive_recomputation_saved_values + + all_recomputable_banned_nodes = sorted( + recomputable_banned_nodes, key=_size_of, reverse=True + ) + if len(all_recomputable_banned_nodes) == 0: + return node_info.inputs + memories_banned_nodes = [ + get_normalized_size(_size_of(i)) for i in all_recomputable_banned_nodes + ] + runtimes_banned_nodes = [ + estimate_runtime(node) for node in all_recomputable_banned_nodes + ] + from torch.utils._mode_utils import no_dispatch + + def get_saved_values_knapsack(memory_budget): + with no_dispatch(): + ( + expected_runtime, + saved_node_idxs, + recomputable_node_idxs, + ) = _optimize_runtime_with_given_memory( + 
memories_banned_nodes, runtimes_banned_nodes, max(memory_budget, 0) + ) + dont_ban = set() + for idx in recomputable_node_idxs: + dont_ban.add(all_recomputable_banned_nodes[idx]) + assert dont_ban.issubset(all_recomputable_banned_nodes) + + saved_values, _ = solve_min_cut( + joint_graph, + node_info, + aggressive_options, + dont_ban, + ) + return saved_values, expected_runtime + + if config.visualize_memory_budget_pareto: + options = [] + for sweep_memory_budget in range(100, -1, -5): + saved_values, expected_runtime = get_saved_values_knapsack( + sweep_memory_budget / 100 + ) + options.append( + ( + sweep_memory_budget, + sum(runtimes_banned_nodes) - expected_runtime, + get_mem_ratio(saved_values), + ) + ) + + import matplotlib.pyplot as plt + + x_values = [item[2] for item in options] + y_values = [item[1] for item in options] + + # Plotting the values with updated axis labels and chart title + plt.figure(figsize=(10, 6)) + plt.plot(x_values, y_values, marker="o") + + # Adding labels for each point + for i, txt in enumerate(x_values): + plt.annotate( + f"{txt:.2f}", + (txt, y_values[i]), + textcoords="offset points", + xytext=(0, 10), + ha="center", + ) + + plt.xlabel("Memory Budget") + plt.ylabel("Runtime of Recomputed Components") + plt.title("Pareto Frontier of Memory Budget vs. Recomputation Runtime") + plt.grid(True) + fig = plt.gcf() + plt.show() + fig_name = f"memory_budget_pareto_{get_aot_graph_name()}.png" + fig.savefig(fig_name) + log.warning("Generated Pareto frontier curve at %s", fig_name) + + # todo(chilli): Estimated doesn't align exactly with actual - actual is + # usually less memory than estimated. i'm guessing (actually quite + # unsure about this) that's because estimated is just only including + # tensors we actually banned from recompute, but there may be other + # tensors that we choose to save. + + return get_saved_values_knapsack(memory_budget=memory_budget)[0] + + +def min_cut_rematerialization_partition( + joint_module: fx.GraphModule, + _joint_inputs, + compiler="inductor", + *, + num_fwd_outputs, +) -> Tuple[fx.GraphModule, fx.GraphModule]: + """ + Partitions the joint graph such that the backward recomputes the forward. + Recomputing helps in trading off memory bandwidth with computation. + + To create the fwd and bwd graph, we copy the joint graph, manually set the + outputs to just original forward or backward outputs. And then we run the + resulting graphs through dead code elimination. + + .. warning:: + This API is experimental and likely to change. + + Args: + joint_module(fx.GraphModule): The joint forward and backward graph. This + is the result of AOT Autograd tracing. + _joint_inputs: The inputs to the joint graph. This is unused. + compiler: This option determines the default set of recomputable ops. + Currently, there are two options: ``nvfuser`` and ``inductor``. + recomputable_ops: This is an optional set of recomputable ops. If this + is not None, then this set of ops will be used instead of the + default set of ops. + num_fwd_outputs: The number of outputs from the forward graph. + + Returns: + Returns the generated forward and backward Fx graph modules. 
+ """ + + joint_module.graph.eliminate_dead_code() + joint_module.recompile() + + fx_g = joint_module.graph + + # add the CSE pass + if config.cse: + cse_graph = fx_graph_cse(fx_g) + joint_module.graph = cse_graph + joint_graph = joint_module.graph + + graph_has_recomputable_ops = has_recomputable_ops(joint_module) + graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) + if graph_has_recomputable_ops: + joint_module = cleanup_recompute_tags(joint_module) + + def classify_nodes(joint_module): + name_to_node = get_name_to_node(joint_module.graph) + required_bw_nodes = set() + for node in joint_module.graph.nodes: + if node.op == "placeholder" and "tangents" in node.target: + required_bw_nodes.add(node) + elif _must_be_in_backward(node): + required_bw_nodes.add(node) + + if node in required_bw_nodes: + for user in node.users: + required_bw_nodes.add(user) + + primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) + fwd_seed_offset_inputs = list( + filter(_is_fwd_seed_offset, joint_module.graph.nodes) + ) + inputs = primal_inputs + fwd_seed_offset_inputs + fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( + joint_module, num_fwd_outputs=num_fwd_outputs + ) + required_bw_nodes.update( + o for o in bwd_outputs if o is not None and o.op != "output" + ) + forward_only_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, inputs, fwd_outputs, "forward" + ) + required_fw_nodes: Set[fx.Node] = { + name_to_node[node.name] + for node in forward_only_graph.nodes + if node.op != "output" + } + unclaimed_nodes = { + node + for node in joint_module.graph.nodes + if node not in required_fw_nodes and node not in required_bw_nodes + } + fw_cnt = 0 + fw_order = {} + for node in joint_module.graph.nodes: + if node in required_fw_nodes: + fw_order[node] = fw_cnt + fw_cnt += 1 + return NodeInfo( + inputs, required_fw_nodes, required_bw_nodes, unclaimed_nodes, fw_order + ) + + node_info = classify_nodes(joint_module) + + # networkx blows up on graphs with no required backward nodes + # Since there's nothing to partition anyway, and the default partitioner can "handle" + # this case, send our graph over to the default partitioner. 
+ if len(node_info.required_bw_nodes) == 0: + return default_partition( + joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs + ) + + for node in reversed(joint_module.graph.nodes): + if node.op == "output": + node.dist_from_bw = int(1e9) + elif not node_info.is_required_fw(node): + node.dist_from_bw = 0 + else: + node.dist_from_bw = int(1e9) + for user in node.users: + node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1) + + memory_budget = config.activation_memory_budget + for node in joint_graph.nodes: + if isinstance(node.meta.get("memory_budget", None), float): + memory_budget = node.meta["memory_budget"] + break + # print("Memory Budget: ", memory_budget) + saved_values = choose_saved_values_set( + joint_graph, node_info, memory_budget=memory_budget + ) + # save_for_backward on tensors and stashes symints in autograd .ctx + saved_sym_nodes = list(filter(is_sym_node, saved_values)) + saved_values = list(filter(lambda n: not is_sym_node(n), saved_values)) + + # NB: saved_sym_nodes will be mutated to reflect the actual saved symbols + fw_module, bw_module = _extract_fwd_bwd_modules( + joint_module, + saved_values, + saved_sym_nodes=saved_sym_nodes, + num_fwd_outputs=num_fwd_outputs, + ) + + if graph_has_recomputable_ops: + if graph_has_recomputable_rng_ops: + fw_module, bw_module = functionalize_rng_ops( + joint_module, fw_module, bw_module, len(saved_sym_nodes) + ) + bw_module = reordering_to_mimic_autograd_engine(bw_module) + + if AOT_PARTITIONER_DEBUG: + from torch._inductor.fx_utils import get_node_storage + + storages = {get_node_storage(node) for node in saved_values} + print( + "Theoretical Activations Stored: ", + sum(_size_of(i) for i in saved_values) / 1e9, + ) + sorted_sizes = sorted([(_size_of(i), str(i)) for i in saved_values]) + fw_module_nodes = { + node.name for node in fw_module.graph.nodes if node.op == "call_function" + } + bw_module_nodes = { + node.name for node in bw_module.graph.nodes if node.op == "call_function" + } + remat_nodes = fw_module_nodes & bw_module_nodes + + counts: Dict[str, int] = defaultdict(int) + for node in fw_module.graph.nodes: + if node.name in remat_nodes and hasattr(node.target, "_overloadpacket"): + counts[str(node.target._overloadpacket)] += 1 + print( + f"# remat/fw/bw: {len(remat_nodes)}/{len(fw_module_nodes)}/{len(bw_module_nodes)}" + ) + print( + "Count of Ops Rematerialized: ", + sorted(counts.items(), key=lambda x: x[1], reverse=True), + ) + return fw_module, bw_module + + +def draw_graph( + traced: torch.fx.GraphModule, + fname: str, + figname: str = "fx_graph", + clear_meta: bool = True, + prog: Optional[Union[str, List[str]]] = None, + parse_stack_trace: bool = False, + dot_graph_shape: Optional[str] = None, +) -> None: + if clear_meta: + new_graph = copy.deepcopy(traced.graph) + traced = fx.GraphModule(traced, new_graph) + for node in traced.graph.nodes: + node.meta = {} + base, ext = os.path.splitext(fname) + if not ext: + ext = "." 
+ config.torch_compile_graph_format + print(f"Writing FX graph to file: {base}{ext}") + g = graph_drawer.FxGraphDrawer( + traced, + figname, + parse_stack_trace=parse_stack_trace, + dot_graph_shape=dot_graph_shape, + ) + x = g.get_main_dot_graph() + write_method = getattr(x, "write_" + ext.lstrip(".")) + fname = f"{base}{ext}" + if prog is None: + write_method(fname) + else: + write_method(fname, prog=prog) diff --git a/lib/python3.10/site-packages/torch/_functorch/pyfunctorch.py b/lib/python3.10/site-packages/torch/_functorch/pyfunctorch.py new file mode 100644 index 0000000000000000000000000000000000000000..b2dfaa116f72924c68351aadcb9e007886343faa --- /dev/null +++ b/lib/python3.10/site-packages/torch/_functorch/pyfunctorch.py @@ -0,0 +1,294 @@ +# mypy: allow-untyped-defs +import contextlib +from abc import ABC, abstractmethod +from typing import Any, List, Tuple + +import torch +import torch.utils._pytree as pytree +from torch._C._functorch import ( + CFunctionalizeInterpreterPtr, + CGradInterpreterPtr, + CInterpreter, + CJvpInterpreterPtr, + CVmapInterpreterPtr, + pop_dynamic_layer_stack, + push_dynamic_layer_stack, + RandomnessType, + TransformType, +) +from torch.autograd.forward_ad import _set_fwd_grad_enabled + + +""" +This file contains the functorch integration with PyDispatcher. + +PyDispatcher does not understand functorch's DynamicLayerStack dispatching +logic because it is entirely implemented in C++ in the fallbacks for two +dispatch keys, FuncTorchDynamicLayer{Front, Back}Mode (PyDispatcher is unable +to directly reuse C++ boxed fallbacks). + +Instead of trying to hammer PyDispatcher into understanding those fallbacks, +we re-implement the logic of peeking the top of the stack for an interpreter, +selecting the interpreter to dispatch on, etc, in Python. This leads to a +simpler design. + +The main difference between C++ functorch and PyDispatcher's functorch logic +is that: +- C++ functorch needs to manually tweak dispatch keys to ping-pong between + DynamicLayerFrontMode and DynamicLayerBackMode. +- PyDispatcher's functorch logic pops an Interpreter from the top of the stack + and asks it to execute the rule associated with the Interpreter. + +In C++ we do the ping-pong because e.g. vmap rules are associated with the +batched DispatchKey, but in PyDispatcher we are able to avoid this by asking +the user to register a batching rule directly to a transform that an +interpreter then invokes. +""" + + +# FuncTorchInterpreter is the Python version of Interpreter (recall that +# the DynamicLayerStack is a stack of interpreters). +# It is a wrapper around the actual C++ Interpreter object. +# +# Keep the methods in sync with aten/src/ATen/functorch/Interpreter.h +class FuncTorchInterpreter(ABC): + def __init__(self, cptr: Any): + self._cptr = cptr + + # Process an operation. eg for vmap, this is invoking a batching rule. + # Conceptually this is analogous to Interpreter::process in C++ + @abstractmethod + def process(self, op, args, kwargs): + pass + + # lower an operation from this Interpreter to the next Interpreter on the stack. + # Concretely, this involves temporarily popping the current Interpreter. 
+ # Conceptually this is analogous to Interpreter::sendToNextInterpreter in C++ + def lower(self): + return temporarily_pop_interpreter_stack() + + def level(self): + return self._cptr.level() + + def key(self): + return self._cptr.key() + + def get_state(self): + raise NotImplementedError + + def check_state(self, state): + return state == self.get_state() + + +@contextlib.contextmanager +def temporarily_pop_interpreter_stack(): + try: + saved = pop_dynamic_layer_stack() + yield + finally: + push_dynamic_layer_stack(saved) + + +@contextlib.contextmanager +def temporarily_clear_interpreter_stack(): + stack = [] + try: + while torch._C._functorch.peek_interpreter_stack() is not None: + stack.append(pop_dynamic_layer_stack()) + yield list(stack) + finally: + while stack: + push_dynamic_layer_stack(stack.pop()) + + +@contextlib.contextmanager +def temporarily_restore_interpreter_stack(stack): + pushed = [] + try: + for s in reversed(stack): + push_dynamic_layer_stack(s) + pushed.append(s) + yield + finally: + for s in reversed(pushed): + # TODO: would be nice to assert that the layers are the same, but + # Python object identity is not preserved + pop_dynamic_layer_stack() + + +class VmapInterpreter(FuncTorchInterpreter): + def __init__(self, cdata: CInterpreter): + assert cdata.key() == TransformType.Vmap + # NOTE: [Interpreter cdata vs cptr] + # cdata is a generic CInterpreter. We wrap it in a CVmapInterpreterPtr + # so that we can access methods specific to the vmap interpreter + self._cdata = cdata + self._cptr = CVmapInterpreterPtr(cdata) + + def process(self, op, args, kwargs): + kernel = op.functorch_table[TransformType.Vmap] + return kernel(self, *args, **kwargs) + + def batch_size(self): + return self._cptr.batchSize() + + def randomness(self): + typ = self._cptr.randomness() + if typ == RandomnessType.Error: + return "error" + elif typ == RandomnessType.Same: + return "same" + elif typ == RandomnessType.Different: + return "different" + raise RuntimeError(f"Unknown RandomnessType: {typ}") + + def get_state(self): + return (self.key().name, self.level(), self.randomness()) + + +@contextlib.contextmanager +def nested(*contexts): + with contextlib.ExitStack() as stack: + for ctx in contexts: + stack.enter_context(ctx) + yield contexts + + +class GradInterpreter(FuncTorchInterpreter): + def __init__(self, cdata: CInterpreter): + assert cdata.key() == TransformType.Grad + # See NOTE: [Interpreter cdata vs cptr] + self._cdata = cdata + self._cptr = CGradInterpreterPtr(cdata) + + def lift(self, args, kwargs): + args, kwargs = pytree.tree_map_only( + torch.Tensor, self._cptr.lift, [args, kwargs] + ) + return args, kwargs + + def process(self, op, args, kwargs): + kernel = op.functorch_table[TransformType.Grad] + args, kwargs = self.lift(args, kwargs) + return kernel(self, *args, **kwargs) + + # GradInterpreter has custom lower because of the no_grad interaction + # See NOTE [grad and vjp interaction with no_grad] + # This logic is mirrored from C++ GradInterpreterPtr::sendToNextInterpreter + def lower(self): + prev_grad_mode = self.prev_grad_mode() + if not prev_grad_mode: + return nested(torch.no_grad(), super().lower()) + return super().lower() + + def prev_grad_mode(self): + return self._cptr.prevGradMode() + + def get_state(self): + return (self.key().name, self.level(), self.prev_grad_mode()) + + +class JvpInterpreter(FuncTorchInterpreter): + def __init__(self, cdata: CInterpreter): + assert cdata.key() == TransformType.Jvp + # See NOTE: [Interpreter cdata vs cptr] + self._cdata = 
cdata + self._cptr = CJvpInterpreterPtr(cdata) + + def lift(self, args, kwargs): + args, kwargs = pytree.tree_map_only( + torch.Tensor, self._cptr.lift, [args, kwargs] + ) + return args, kwargs + + def process(self, op, args, kwargs): + kernel = op.functorch_table[TransformType.Jvp] + args, kwargs = self.lift(args, kwargs) + return kernel(self, *args, **kwargs) + + # Jvp has custom lower because of the no_fwd_grad interaction + # See NOTE [grad and vjp interaction with no_grad] for related info. + # This logic is mirrored from C++ JvpInterpreterPtr::sendToNextInterpreter + def lower(self): + prev_fwd_grad_mode = self.prev_fwd_grad_mode() + if not prev_fwd_grad_mode: + return nested(_set_fwd_grad_enabled(False), super().lower()) + return super().lower() + + def prev_fwd_grad_mode(self): + return self._cptr.prevFwdGradMode() + + def get_state(self): + return (self.key().name, self.level(), self.prev_fwd_grad_mode()) + + +class FunctionalizeInterpreter(FuncTorchInterpreter): + def __init__(self, cdata: CInterpreter): + assert cdata.key() == TransformType.Functionalize + self._cdata = cdata + self._cptr = CFunctionalizeInterpreterPtr(cdata) + + def process(self, op, args, kwargs): + kernel = op.functorch_table[TransformType.Functionalize] + return kernel(self, *args, **kwargs) + + def functionalize_add_back_views(self): + return self._cptr.functionalizeAddBackViews() + + def get_state(self): + return (self.key().name, self.level()) + + +def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter: + key = cinterpreter.key() + if key == TransformType.Grad: + return GradInterpreter(cinterpreter) + if key == TransformType.Vmap: + return VmapInterpreter(cinterpreter) + if key == TransformType.Jvp: + return JvpInterpreter(cinterpreter) + if key == TransformType.Functionalize: + return FunctionalizeInterpreter(cinterpreter) + raise RuntimeError(f"NYI: PyDispatcher has not implemented support for {key}") + + +def retrieve_current_functorch_interpreter() -> FuncTorchInterpreter: + interpreter = torch._C._functorch.peek_interpreter_stack() + assert interpreter is not None + return coerce_cinterpreter(interpreter) + + +def retrieve_all_functorch_interpreters() -> List[FuncTorchInterpreter]: + cis = torch._C._functorch.get_interpreter_stack() + if cis is None: + return [] + return [coerce_cinterpreter(ci) for ci in cis] + + +def compare_functorch_state(states: List[Tuple[Any, ...]]) -> bool: + # There are four possible cases covered here: + # 1. Current stack empty AND stack when generated not empty -> Invalidate + # 2. Current stack not empty AND stack when generated empty -> Invalidate + # 3. Current stack and generated stack empty -> Valid FX graph + # 4. Current stack and generated stack not empty -> Valid if both states match + peek = torch._C._functorch.peek_interpreter_stack() + if (peek is None and len(states) != 0) or (peek is not None and len(states) == 0): + return False + + cis = retrieve_all_functorch_interpreters() + return len(cis) == len(states) and all( + ci.check_state(state) for ci, state in zip(cis, states) + ) + + +def dispatch_functorch(op, args, kwargs): + interpreter = retrieve_current_functorch_interpreter() + # In traditional PyTorch operators, DispatchKey::FuncTorchTensorWrapper's + # unwrap_dead_tensors fallback handles unwrapping dead tensor wrappers. + # PyDispatcher sidesteps the PyTorch dispatcher when dealing with functorch + # transforms, so we manually unwrap the dead tensors here. + # This logic won't need to exist when we have mode-only functorch. 
+ args, kwargs = pytree.tree_map_only( + torch.Tensor, torch._C._functorch.unwrap_if_dead, (args, kwargs) + ) + return interpreter.process(op, args, kwargs) diff --git a/lib/python3.10/site-packages/torch/_functorch/python_key.py b/lib/python3.10/site-packages/torch/_functorch/python_key.py new file mode 100644 index 0000000000000000000000000000000000000000..557334f68928a057a6a9e036c904c8c0bd3231c1 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_functorch/python_key.py @@ -0,0 +1,15 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +__all__ = ["make_fx", "dispatch_trace", "PythonKeyTracer", "pythonkey_decompose"] +from torch.fx.experimental.proxy_tensor import ( + decompose, + dispatch_trace, + make_fx, + PythonKeyTracer, +) + + +pythonkey_decompose = decompose diff --git a/lib/python3.10/site-packages/torch/_functorch/pytree_hacks.py b/lib/python3.10/site-packages/torch/_functorch/pytree_hacks.py new file mode 100644 index 0000000000000000000000000000000000000000..96dea7ad100705ae53139aa5ae729fd2206182af --- /dev/null +++ b/lib/python3.10/site-packages/torch/_functorch/pytree_hacks.py @@ -0,0 +1,23 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import warnings + +# TODO: remove this file when the migration of the pytree utility is done +from torch.utils._pytree import tree_map_, treespec_pprint + + +__all__ = ["tree_map_", "treespec_pprint"] + + +with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`torch._functorch.pytree_hacks` is deprecated and will be removed in a future release. " + "Please `use torch.utils._pytree` instead.", + DeprecationWarning, + stacklevel=2, + ) diff --git a/lib/python3.10/site-packages/torch/_functorch/top_operators_github_usage.py b/lib/python3.10/site-packages/torch/_functorch/top_operators_github_usage.py new file mode 100644 index 0000000000000000000000000000000000000000..1fcdbe0b41ac00c21198e3a82d58c5a446d19324 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_functorch/top_operators_github_usage.py @@ -0,0 +1,632 @@ +# mypy: ignore-errors + +""" +From https://docs.google.com/spreadsheets/d/12R3nCOLskxPYjjiNkdqy4OdQ65eQp_htebXGODsjSeA/edit#gid=0 +Try to keep this list in sync with that. 
+""" +import operator + + +top_torch = [ + ("t", 6837449), + ("tensor", 585786), + ("mode", 462182), + ("cat", 394818), + ("max", 368038), + ("zeros", 329495), + ("load", 327756), + ("no_grad", 294694), + ("save", 265130), + ("from_numpy", 243063), + ("manual_seed", 165044), + ("ones", 153696), + ("randn", 150796), + ("stack", 133358), + ("sum", 130772), + ("arange", 98087), + ("rand", 94715), + ("mean", 88546), + ("exp", 73883), + ("zeros_like", 72831), + ("min", 72248), + ("sigmoid", 66798), + ("log", 62135), + ("matmul", 47811), + ("clamp", 45304), + ("sqrt", 44911), + ("abs", 43535), + ("tanh", 42793), + ("empty", 40311), + ("argmax", 38435), + ("bmm", 33984), + ("pow", 33571), + ("norm", 31125), + ("mm", 30995), + ("is_tensor", 29546), + ("ones_like", 29512), + ("nonzero", 28681), + ("full", 28373), + ("unsqueeze", 27911), + ("where", 26585), + ("randperm", 26450), + ("eye", 24342), + ("mul", 23236), + ("topk", 22537), + ("as_tensor", 21967), + ("sort", 21412), + ("squeeze", 20863), + ("randint", 20771), + ("linspace", 20041), + ("add", 19201), + ("transpose", 18663), + ("split", 18325), + ("gather", 17904), + ("set_grad_enabled", 16013), + ("sin", 15669), + ("cos", 15562), + ("div", 15513), + ("index_select", 14866), + ("multinomial", 14331), + ("flatten", 14267), + ("isnan", 14170), + ("randn_like", 13096), + ("eq", 12680), + ("einsum", 12480), + ("round", 12367), + ("floor", 11628), + ("allclose", 11000), + ("reshape", 10605), + ("diag", 10167), + ("chunk", 9581), + ("std", 9379), + ("set_default_tensor_type", 9281), + ("triu", 8559), + ("meshgrid", 8292), + ("set_num_threads", 8126), + ("unique", 7964), + ("full_like", 7780), + ("tril", 7538), + ("dot", 7275), + ("sign", 6943), + ("equal", 6916), + ("normal", 6750), + ("cumsum", 6556), + ("dist", 6058), + ("isfinite", 6030), + ("gt", 5935), + ("set_printoptions", 5888), + ("range", 5491), + ("empty_like", 5351), + ("flip", 5342), + ("masked_select", 5341), + ("bernoulli", 5262), + ("atan", 5253), + ("var", 5247), + ("prod", 5200), + ("erf", 5088), + ("inverse", 5072), + ("addmm", 4854), + ("logsumexp", 4582), + ("fft", 4436), + ("lt", 4421), + ("log2", 4316), + ("enable_grad", 4238), + ("rand_like", 4187), + ("argsort", 3972), + ("seed", 3932), + ("mv", 3547), + ("ger", 3309), + ("ge", 3248), + ("atan2", 3210), + ("ceil", 3202), + ("ne", 3075), + ("bincount", 3063), + ("acos", 3055), + ("rsqrt", 3031), + ("svd", 3029), + ("numel", 3003), + ("log1p", 2840), + ("unbind", 2808), + ("le", 2714), + ("isinf", 2707), + ("cross", 2646), + ("set_default_dtype", 2536), + ("argmin", 2535), + ("sparse_coo_tensor", 2489), + ("log10", 2304), + ("kthvalue", 2192), + ("set_rng_state", 2158), + ("get_rng_state", 1996), + ("get_default_dtype", 1879), + ("det", 1868), + ("qr", 1864), + ("histc", 1852), + ("symeig", 1832), + ("trace", 1801), + ("median", 1795), + ("addcmul", 1751), + ("remainder", 1717), + ("baddbmm", 1693), + ("lgamma", 1665), + ("repeat_interleave", 1598), + ("fmod", 1576), + ("reciprocal", 1575), + ("tan", 1560), + ("initial_seed", 1532), + ("take", 1529), + ("stft", 1487), + ("get_num_threads", 1477), + ("real", 1459), + ("cholesky", 1406), + ("quantize_per_tensor", 1392), + ("diag_embed", 1364), + ("lerp", 1363), + ("asin", 1345), + ("eig", 1333), + ("trunc", 1290), + ("diagonal", 1287), + ("cosh", 1279), + ("rfft", 1269), + ("cumprod", 1260), + ("addr", 1211), + ("roll", 1198), + ("narrow", 1188), + ("digamma", 1172), + ("square", 1163), + ("sinh", 1131), + ("logspace", 1084), + ("broadcast_tensors", 1070), + ("irfft", 1013), + 
("frac", 997), + ("hann_window", 994), + ("solve", 989), + ("logdet", 977), + ("expm1", 968), + ("cdist", 946), + ("addmv", 903), + ("randint_like", 888), + ("tensordot", 888), + ("ifft", 877), + ("true_divide", 854), + ("erfinv", 830), + ("addcdiv", 819), + ("addbmm", 813), + ("renorm", 781), + ("pinverse", 753), + ("isclose", 740), + ("erfc", 729), + ("is_storage", 725), + ("triangular_solve", 723), + ("rot90", 709), + ("logical_not", 686), + ("geqrf", 681), + ("slogdet", 677), + ("lu", 665), + ("hamming_window", 659), + ("orgqr", 651), + ("ormqr", 622), + ("is_floating_point", 602), + ("diagflat", 562), + ("cholesky_solve", 559), + ("tril_indices", 552), + ("chain_matmul", 551), + ("triu_indices", 548), + ("angle", 522), + ("poisson", 505), + ("matrix_power", 485), + ("unique_consecutive", 471), + ("quantize_per_channel", 465), + ("std_mean", 458), + ("bartlett_window", 447), + ("var_mean", 428), + ("lstsq", 421), + ("logical_and", 419), + ("mvlgamma", 411), + ("blackman_window", 400), + ("bitwise_not", 395), + ("cholesky_inverse", 388), + ("as_strided", 384), + ("floor_divide", 353), + ("cartesian_prod", 321), + ("lu_solve", 317), + ("set_flush_denormal", 310), + ("empty_strided", 283), + ("logical_xor", 282), + ("polygamma", 282), + ("logical_or", 280), + ("set_num_interop_threads", 278), + ("combinations", 274), + ("trapz", 270), + ("matrix_rank", 260), + ("lu_unpack", 255), + ("result_type", 244), + ("conj", 231), + ("cummax", 230), + ("lobpcg", 229), + ("bitwise_xor", 217), + ("promote_types", 213), + ("get_num_interop_threads", 211), + ("cummin", 205), + ("bitwise_and", 198), + ("dequantize", 192), + ("bitwise_or", 191), + ("imag", 191), + ("can_cast", 184), + ("istft", 180), + ("compiled_with_cxx11_abi", 159), + ("is_complex", 151), + ("block_diag", 136), + ("pca_lowrank", 124), + ("absolute", 122), + ("svd_lowrank", 108), + ("neg", 2), +] + +top_nn_functional = [ + ("nn.functional.softmax", 10522), + ("nn.functional.relu", 8572), + ("nn.functional.interpolate", 7277), + ("nn.functional.pad", 5207), + ("nn.functional.log_softmax", 4699), + ("nn.functional.normalize", 2338), + ("nn.functional.cross_entropy", 2083), + ("nn.functional.grid_sample", 1970), + ("nn.functional.one_hot", 1967), + ("nn.functional.mse_loss", 1920), + ("nn.functional.conv2d", 1593), + ("nn.functional.dropout", 1516), + ("nn.functional.softplus", 1385), + ("nn.functional.sigmoid", 1128), + ("nn.functional.linear", 1036), + ("nn.functional.gelu", 930), + ("nn.functional.avg_pool2d", 899), + ("nn.functional.max_pool2d", 876), + ("nn.functional.nll_loss", 863), + ("nn.functional.embedding", 737), + ("nn.functional.tanh", 664), + ("nn.functional.leaky_relu", 640), + ("nn.functional.adaptive_avg_pool2d", 633), + ("nn.functional.cosine_similarity", 627), + ("nn.functional.unfold", 609), + ("nn.functional.conv1d", 596), + ("nn.functional.binary_cross_entropy_with_logits", 591), + ("nn.functional.l1_loss", 571), + ("nn.functional.binary_cross_entropy", 492), + ("nn.functional.elu", 416), + ("nn.functional.batch_norm", 413), + ("nn.functional.upsample", 413), + ("nn.functional.fold", 305), + ("nn.functional.affine_grid", 298), + ("nn.functional.max_pool1d", 297), + ("nn.functional.torch", 294), + ("nn.functional.threshold", 263), + ("nn.functional.smooth_l1_loss", 262), + ("nn.functional.pairwise_distance", 253), + ("nn.functional.logsigmoid", 243), + ("nn.functional.adaptive_max_pool2d", 235), + ("nn.functional.relu6", 213), + ("nn.functional.pixel_shuffle", 209), + ("nn.functional.avg_pool3d", 203), + 
("nn.functional.bilinear", 203), + ("nn.functional.conv_transpose2d", 201), + ("nn.functional.gumbel_softmax", 197), + ("nn.functional.max_unpool2d", 196), + ("nn.functional.kl_div", 191), + ("nn.functional.hardtanh", 189), + ("nn.functional.ctc_loss", 185), + ("nn.functional.layer_norm", 178), + ("nn.functional.conv3d", 172), + ("nn.functional.max_unpool3d", 167), + ("nn.functional.hardshrink", 165), + ("nn.functional.hardswish", 156), + ("nn.functional.selu", 156), + ("nn.functional.glu", 155), + ("nn.functional.assert_int_or_pair", 150), + ("nn.functional.hardsigmoid", 146), + ("nn.functional.upsample_bilinear", 146), + ("nn.functional.max_pool3d", 140), + ("nn.functional.adaptive_avg_pool3d", 139), + ("nn.functional.instance_norm", 124), + ("nn.functional.embedding_bag", 122), + ("nn.functional.upsample_nearest", 110), + ("nn.functional.avg_pool1d", 105), + ("nn.functional.prelu", 102), + ("nn.functional.celu", 92), + ("nn.functional.dropout2d", 86), + ("nn.functional.hinge_embedding_loss", 82), + ("nn.functional.softsign", 81), + ("nn.functional.max_unpool1d", 74), + ("nn.functional.silu", 74), + ("nn.functional.softshrink", 70), + ("nn.functional.leaky_relu_", 68), + ("nn.functional.softmin", 67), + ("nn.functional.channel_shuffle", 66), + ("nn.functional.multilabel_margin_loss", 66), + ("nn.functional.dropout3d", 65), + ("nn.functional.multi_margin_loss", 65), + ("nn.functional.lp_pool2d", 64), + ("nn.functional.conv_transpose1d", 62), + ("nn.functional.triplet_margin_loss", 62), + ("nn.functional.tanhshrink", 61), + ("nn.functional.adaptive_max_pool1d", 59), + ("nn.functional.cosine_embedding_loss", 58), + ("nn.functional.multi_head_attention_forward", 58), + ("nn.functional.max_pool1d_with_indices", 53), + ("nn.functional.poisson_nll_loss", 53), + ("nn.functional.margin_ranking_loss", 52), + ("nn.functional.soft_margin_loss", 52), + ("nn.functional.adaptive_max_pool3d", 51), + ("nn.functional.group_norm", 51), + ("nn.functional.local_response_norm", 51), + ("nn.functional.multilabel_soft_margin_loss", 51), + ("nn.functional.relu_", 50), + ("nn.functional.alpha_dropout", 49), + ("nn.functional.feature_alpha_dropout", 49), + ("nn.functional.lp_pool1d", 49), + ("nn.functional.adaptive_max_pool1d_with_indices", 48), + ("nn.functional.adaptive_max_pool2d_with_indices", 48), + ("nn.functional.adaptive_max_pool3d_with_indices", 48), + ("nn.functional.fractional_max_pool2d", 48), + ("nn.functional.fractional_max_pool2d_with_indices", 48), + ("nn.functional.fractional_max_pool3d", 48), + ("nn.functional.fractional_max_pool3d_with_indices", 48), + ("nn.functional.max_pool2d_with_indices", 48), + ("nn.functional.max_pool3d_with_indices", 48), + ("nn.functional.handle_torch_function", 47), + ("nn.functional.has_torch_function", 47), + ("nn.functional.adaptive_avg_pool1d", 43), + ("nn.functional.pdist", 43), + ("nn.functional.rrelu_", 37), + ("nn.functional.elu_", 34), + ("nn.functional.boolean_dispatch", 33), + ("nn.functional.hardtanh_", 26), + ("nn.functional.triplet_margin_with_distance_loss", 23), + ("nn.functional.selu_", 20), + ("nn.functional.pixel_unshuffle", 19), + ("nn.functional.conv_transpose3d", 18), + ("nn.functional.gaussian_nll_loss", 15), + ("nn.functional.has_torch_function_unary", 15), + ("nn.functional.has_torch_function_variadic", 15), + ("nn.functional.celu_", 13), + ("nn.functional.huber_loss", 7), + ("nn.functional.mish", 4), + ("nn.functional.threshold_", 3), + ("nn.functional.grad", 2), + ("nn.functional.conv_tbc", 1), + ("nn.functional.math", 1), +] + +top_nn_module 
= [ + ("nn.Module", 927129, None), + ("nn.Linear", 530688, "nn.functional.linear"), + ("nn.Sequential", 384968, None), + ("nn.Conv2d", 383320, "nn.functional.conv2d"), + ("nn.ReLU", 318877, "nn.functional.relu"), + ("nn.BatchNorm2d", 233265, "nn.functional.batch_norm"), + ("nn.Dropout", 179268, "nn.functional.dropout"), + ("nn.ModuleList", 171225, None), + ("nn.Parameter", 153291, None), + ("nn.CrossEntropyLoss", 152696, "nn.functional.cross_entropy"), + ("nn.MaxPool2d", 138619, "nn.functional.max_pool2d"), + ("nn.Embedding", 111844, "nn.functional.embedding"), + ("nn.DataParallel", 104238, None), + ("nn.MSELoss", 82954, "nn.functional.mse_loss"), + ("nn.Sigmoid", 75810, "nn.functional.sigmoid"), + ("nn.LeakyReLU", 65632, "nn.functional.leaky_relu"), + ("nn.BatchNorm1d", 65374, "nn.functional.batch_norm"), + ("nn.Softmax", 65114, "nn.functional.softmax"), + ("nn.Tanh", 59445, "nn.functional.tanh"), + ("nn.AdaptiveAvgPool2d", 59071, "nn.functional.adaptive_avg_pool2d"), + ("nn.AvgPool2d", 58377, "nn.functional.avg_pool2d"), + ("nn.ConvTranspose2d", 57524, "nn.functional.conv_transpose2d"), + ("nn.LSTM", 57411, None), + ("nn.Conv1d", 41108, "nn.functional.conv1d"), + ("nn.LayerNorm", 36089, "nn.functional.layer_norm"), + ("nn.BCELoss", 34005, "nn.functional.binary_cross_entropy"), + ("nn.Upsample", 32527, "nn.functional.interpolate"), + ("nn.BCEWithLogitsLoss", 29944, "nn.functional.binary_cross_entropy_with_logits"), + ("nn.GRU", 25421, None), + ("nn.Dropout2d", 23512, "nn.functional.dropout2d"), + ("nn.LogSoftmax", 22897, "nn.functional.log_softmax"), + ("nn.L1Loss", 22778, "nn.functional.l1_loss"), + ("nn.GroupNorm", 22183, "nn.functional.group_norm"), + ("nn.NLLLoss", 21751, "nn.functional.nll_loss"), + ("nn.Conv3d", 20874, "nn.functional.conv3d"), + ("nn.Identity", 17911, None), + ("nn.InstanceNorm2d", 16426, "nn.functional.instance_norm"), + ("nn.BatchNorm3d", 16378, "nn.functional.batch_norm"), + ("nn.PReLU", 13472, "nn.functional.prelu"), + ("nn.ReLU6", 12622, "nn.functional.relu6"), + ("nn.ELU", 12508, "nn.functional.elu"), + ("nn.LSTMCell", 10885, None), + ("nn.Flatten", 10384, "torch.flatten"), + ("nn.ModuleDict", 10255, None), + ("nn.ReflectionPad2d", 9954, "nn.functional.pad"), + ("nn.MaxPool3d", 9526, "nn.functional.max_pool3d"), + ("nn.MaxPool1d", 9154, "nn.functional.max_pool1d"), + ("nn.RNN", 9154, None), + ("nn.ZeroPad2d", 8847, "nn.functional.pad"), + ("nn.ParameterList", 7702, None), + ("nn.SyncBatchNorm", 6814, None), + ("nn.PixelShuffle", 6571, "nn.functional.pixel_shuffle"), + ("nn.SmoothL1Loss", 6517, "nn.functional.smooth_l1_loss"), + ("nn.Hardswish", 6458, "nn.functional.hardswish"), + ("nn.AdaptiveMaxPool2d", 6071, "nn.functional.adaptive_max_pool2d"), + ("nn.SELU", 6043, "nn.functional.selu"), + ("nn.ConvTranspose3d", 6039, "nn.functional.conv_transpose3d"), + ("nn.GRUCell", 5840, None), + ("nn.ReplicationPad2d", 5600, "nn.functional.pad"), + ("nn.KLDivLoss", 5541, "nn.functional.kl_div"), + ("nn.ConvTranspose1d", 5183, "nn.functional.conv_transpose1d"), + ("nn.Softplus", 5120, "nn.functional.softplus"), + ("nn.SiLU", 4895, "nn.functional.silu"), + ("nn.AvgPool3d", 4523, "nn.functional.avg_pool3d"), + ("nn.CosineSimilarity", 4058, "nn.functional.cosine_similarity"), + ("nn.GELU", 3932, "nn.functional.gelu"), + ("nn.UpsamplingBilinear2d", 3673, "nn.functional.interpolate"), + ("nn.InstanceNorm1d", 3658, "nn.functional.instance_norm"), + ("nn.Transformer", 3604, None), + ("nn.MultiheadAttention", 3435, "nn.functional.multi_head_attention_forward"), + 
("nn.AvgPool1d", 3195, "nn.functional.avg_pool1d"), + ("nn.Dropout3d", 2964, "nn.functional.dropout3d"), + ("nn.AdaptiveAvgPool3d", 2915, "nn.functional.adaptive_avg_pool3d"), + ("nn.InstanceNorm3d", 2893, "nn.functional.instance_norm"), + ("nn.Hardtanh", 2613, "nn.functional.hardtanh"), + ("nn.MarginRankingLoss", 2568, "nn.functional.margin_ranking_loss"), + ("nn.GLU", 2526, "nn.functional.glu"), + ("nn.AdaptiveAvgPool1d", 2481, "nn.functional.adaptive_avg_pool1d"), + ("nn.EmbeddingBag", 2344, "nn.functional.embedding_bag"), + ("nn.TransformerEncoderLayer", 2292, None), + ("nn.TransformerEncoder", 2091, None), + ("nn.MaxUnpool2d", 2031, "nn.functional.max_unpool2d"), + ("nn.UpsamplingNearest2d", 2004, "nn.functional.interpolate"), + ("nn.ConstantPad1d", 1904, "nn.functional.pad"), + ("nn.ConstantPad2d", 1791, "nn.functional.pad"), + ("nn.CTCLoss", 1789, "nn.functional.ctc_loss"), + ("nn.AdaptiveMaxPool1d", 1713, "nn.functional.adaptive_max_pool1d"), + ("nn.AdaptiveLogSoftmaxWithLoss", 1665, None), + ("nn.Bilinear", 1664, "nn.functional.bilinear"), + ("nn.RNNCell", 1653, None), + ("nn.MultiLabelSoftMarginLoss", 1624, "nn.functional.multilabel_soft_margin_loss"), + ("nn.Unfold", 1452, "nn.functional.unfold"), + ("nn.RReLU", 1431, "nn.functional.rrelu"), + ("nn.CosineEmbeddingLoss", 1357, "nn.functional.cosine_embedding_loss"), + ("nn.LocalResponseNorm", 1331, "nn.functional.local_response_norm"), + ("nn.Softmax2d", 1300, "nn.functional.softmax"), + ("nn.PairwiseDistance", 1241, "nn.functional.pairwise_distance"), + ("nn.LogSigmoid", 1235, "nn.functional.logsigmoid"), + ("nn.TripletMarginLoss", 1230, "nn.functional.triplet_margin_loss"), + ("nn.RNNBase", 1133, None), + ("nn.Threshold", 1043, "nn.functional.threshold"), + ("nn.AdaptiveMaxPool3d", 1025, "nn.functional.adaptive_max_pool3d"), + ("nn.CELU", 1018, "nn.functional.celu"), + ("nn.NLLLoss2d", 966, "nn.functional.nll_loss"), + ("nn.Softsign", 877, "nn.functional.softsign"), + ("nn.ReplicationPad1d", 862, "nn.functional.pad"), + ("nn.SoftMarginLoss", 856, "nn.functional.soft_margin_loss"), + ("nn.ParameterDict", 742, None), + ("nn.ReflectionPad1d", 731, "nn.functional.pad"), + ("nn.Softshrink", 713, "nn.functional.softshrink"), + ("nn.AlphaDropout", 710, "nn.functional.alpha_dropout"), + ("nn.Tanhshrink", 681, "nn.functional.tanhshrink"), + ("nn.PoissonNLLLoss", 676, "nn.functional.poisson_nll_loss"), + ("nn.MaxUnpool3d", 660, "nn.functional.max_unpool3d"), + ("nn.Fold", 630, "nn.functional.fold"), + ("nn.MultiMarginLoss", 622, "nn.functional.multi_margin_loss"), + ("nn.TransformerDecoderLayer", 614, None), + ("nn.TransformerDecoder", 607, None), + ("nn.Hardshrink", 592, "nn.functional.hardshrink"), + ("nn.ConstantPad3d", 582, "nn.functional.pad"), + ("nn.MultiLabelMarginLoss", 580, "nn.functional.multilabel_margin_loss"), + ("nn.LPPool2d", 550, "nn.functional.lp_pool2d"), + ("nn.Softmin", 537, "nn.functional.softmin"), + ("nn.MaxUnpool1d", 518, "nn.functional.max_unpool1d"), + ("nn.FractionalMaxPool2d", 484, "nn.functional.fractional_max_pool2d"), + ("nn.Hardsigmoid", 477, "nn.functional.hardsigmoid"), + ("nn.ReplicationPad3d", 470, "nn.functional.pad"), + ("nn.HingeEmbeddingLoss", 442, "nn.functional.hinge_embedding_loss"), + ("nn.LPPool1d", 386, "nn.functional.lp_pool1d"), + ("nn.FractionalMaxPool3d", 252, "nn.functional.fractional_max_pool3d"), + ("nn.Container", 217, None), + ("nn.Unflatten", 206, "nn.functional.unflatten"), + ("nn.FeatureAlphaDropout", 136, "nn.functional.feature_alpha_dropout"), + ( + 
"nn.TripletMarginWithDistanceLoss", + 107, + "nn.functional.triplet_margin_with_distance_loss", + ), + ("nn.ChannelShuffle", 90, "nn.functional.channel_shuffle"), + ("nn.RNNCellBase", 88, None), + ("nn.LazyLinear", 81, "nn.functional.linear"), + ("nn.UninitializedParameter", 60, None), + ("nn.CrossMapLRN2d", 59, None), + ("nn.GaussianNLLLoss", 55, "nn.functional.gaussian_nll_loss"), + ("nn.PixelUnshuffle", 45, "nn.functional.pixel_unshuffle"), + ("nn.Mish", 31, "nn.functional.mish"), + ("nn.ReflectionPad3d", 22, "nn.functional.pad"), + ("nn.HuberLoss", 18, "nn.functional.huber_loss"), + ("nn.LazyConv2d", 15, None), + ("nn.LazyConv1d", 9, None), + ("nn.LazyConv3d", 8, None), + ("nn.LazyConvTranspose1d", 8, None), + ("nn.LazyConvTranspose2d", 8, None), + ("nn.LazyConvTranspose3d", 8, None), + ("nn.LazyBatchNorm1d", 3, None), + ("nn.LazyBatchNorm2d", 3, None), + ("nn.LazyBatchNorm3d", 3, None), + ("nn.UninitializedBuffer", 3, None), +] + +# No rankings because these are a little hard to get rankings for +method_only_ops = [ + "bfloat16", + "bool", + "byte", + "char", + "contiguous", + "cpu", + "cuda", + "detach", + "double", + "expand", + "expand_as", + "float", + "get_device", + "half", + "hardshrink", + "index_add", + "index_copy", + "index_fill", + "index_put", + "int", + "is_contiguous", + "is_pinned", + "is_set_to", + "is_shared", + "is_signed", + "item", + "long", + "masked_scatter", + "masked_fill", + "narrow_copy", + "numpy", + "pin_memory", + "repeat", + "reshape_as", + "select", + "short", + "storage_offset", + "sum_to_size", + "to", + "to_mkldnn", + "tolist", + "type", + "type_as", + "unfold", + "view", + "view_as", +] + + +def get_nn_functional_top_list(): + top_nn_functional_ = dict(top_nn_functional) + for _, count, functional_name in top_nn_module: + if functional_name is None: + continue + if functional_name == "torch.flatten": + continue + if functional_name not in top_nn_functional_: + top_nn_functional_[functional_name] = count + else: + top_nn_functional_[functional_name] += count + + top_nn_functional_ = list(top_nn_functional_.items()) + top_nn_functional_.sort(key=operator.itemgetter(1), reverse=True) + return top_nn_functional_ + + +usage_count = {} +for k, v in get_nn_functional_top_list(): + usage_count[k] = v +for k, v in top_torch: + usage_count[k] = v diff --git a/lib/python3.10/site-packages/torch/_functorch/utils.py b/lib/python3.10/site-packages/torch/_functorch/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..514b2f4e25586d6a432fe257197571dd4b95f11b --- /dev/null +++ b/lib/python3.10/site-packages/torch/_functorch/utils.py @@ -0,0 +1,40 @@ +# mypy: allow-untyped-defs +import contextlib +from typing import Tuple, Union + +import torch +from torch._C._functorch import ( + get_single_level_autograd_function_allowed, + set_single_level_autograd_function_allowed, + unwrap_if_dead, +) +from torch.utils._exposed_in import exposed_in + + +__all__ = [ + "exposed_in", + "argnums_t", + "enable_single_level_autograd_function", + "unwrap_dead_wrappers", +] + + +@contextlib.contextmanager +def enable_single_level_autograd_function(): + try: + prev_state = get_single_level_autograd_function_allowed() + set_single_level_autograd_function_allowed(True) + yield + finally: + set_single_level_autograd_function_allowed(prev_state) + + +def unwrap_dead_wrappers(args): + # NB: doesn't use tree_map_only for performance reasons + result = tuple( + unwrap_if_dead(arg) if isinstance(arg, torch.Tensor) else arg for arg in args + ) + return result + + 
+argnums_t = Union[int, Tuple[int, ...]] diff --git a/lib/python3.10/site-packages/torch/_functorch/vmap.py b/lib/python3.10/site-packages/torch/_functorch/vmap.py new file mode 100644 index 0000000000000000000000000000000000000000..fcb96ad06d24ec39a09576f912aab064c007e7a7 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_functorch/vmap.py @@ -0,0 +1,532 @@ +# mypy: ignore-errors + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import functools +import itertools +import os +import threading +from functools import partial +from typing import Any, Callable, List, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch._C._functorch import ( + _add_batch_dim, + _remove_batch_dim, + _vmap_decrement_nesting, + _vmap_increment_nesting, + is_batchedtensor, +) +from torch.utils._pytree import ( + _broadcast_to_and_flatten, + tree_flatten, + tree_map_, + tree_unflatten, + TreeSpec, +) + + +in_dims_t = Union[int, Tuple] +out_dims_t = Union[int, Tuple[int, ...]] + + +def doesnt_support_saved_tensors_hooks(f): + message = ( + "torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. " + "Please open an issue with your use case." + ) + + @functools.wraps(f) + def fn(*args, **kwargs): + with torch.autograd.graph.disable_saved_tensors_hooks(message): + return f(*args, **kwargs) + + return fn + + +# Checks that all args-to-be-batched have the same batch dim size +def _validate_and_get_batch_size( + flat_in_dims: List[Optional[int]], flat_args: List +) -> int: + batch_sizes = [ + arg.size(in_dim) + for in_dim, arg in zip(flat_in_dims, flat_args) + if in_dim is not None + ] + if len(batch_sizes) == 0: + raise ValueError("vmap: Expected at least one Tensor to vmap over") + if batch_sizes and any(size != batch_sizes[0] for size in batch_sizes): + raise ValueError( + f"vmap: Expected all tensors to have the same size in the mapped " + f"dimension, got sizes {batch_sizes} for the mapped dimension" + ) + return batch_sizes[0] + + +def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int: + if isinstance(batched_outputs, tuple): + return len(batched_outputs) + return 1 + + +# If value is a tuple, check it has length `num_elements`. +# If value is not a tuple, make a tuple with `value` repeated `num_elements` times + + +def _as_tuple( + value: Any, num_elements: int, error_message_lambda: Callable[[], str] +) -> Tuple: + if not isinstance(value, tuple): + return (value,) * num_elements + if len(value) != num_elements: + raise ValueError(error_message_lambda()) + return value + + +def _process_batched_inputs( + in_dims: in_dims_t, args: Tuple, func: Callable +) -> Tuple[int, List[Any], List[Any], TreeSpec]: + if not isinstance(in_dims, int) and not isinstance(in_dims, tuple): + raise ValueError( + f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " + f"expected `in_dims` to be int or a (potentially nested) tuple " + f"matching the structure of inputs, got: {type(in_dims)}." + ) + if len(args) == 0: + raise ValueError( + f"vmap({_get_name(func)})(): got no inputs. Maybe you forgot to add " + f"inputs, or you are trying to vmap over a function with no inputs. " + f"The latter is unsupported." 
+ ) + + flat_args, args_spec = tree_flatten(args) + flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec) + if flat_in_dims is None: + raise ValueError( + f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " + f"in_dims is not compatible with the structure of `inputs`. " + f"in_dims has structure {tree_flatten(in_dims)[1]} but inputs " + f"has structure {args_spec}." + ) + + for i, (arg, in_dim) in enumerate(zip(flat_args, flat_in_dims)): + if not isinstance(in_dim, int) and in_dim is not None: + raise ValueError( + f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " + f"Got in_dim={in_dim} for an input but in_dim must be either " + f"an integer dimension or None." + ) + if isinstance(in_dim, int) and not isinstance(arg, Tensor): + raise ValueError( + f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " + f"Got in_dim={in_dim} for an input but the input is of type " + f"{type(arg)}. We cannot vmap over non-Tensor arguments, " + f"please use None as the respective in_dim" + ) + if in_dim is not None and (in_dim < -arg.dim() or in_dim >= arg.dim()): + raise ValueError( + f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " + f"Got in_dim={in_dim} for some input, but that input is a Tensor " + f"of dimensionality {arg.dim()} so expected in_dim to satisfy " + f"-{arg.dim()} <= in_dim < {arg.dim()}." + ) + if in_dim is not None and in_dim < 0: + flat_in_dims[i] = in_dim % arg.dim() + + return ( + _validate_and_get_batch_size(flat_in_dims, flat_args), + flat_in_dims, + flat_args, + args_spec, + ) + + +# Creates BatchedTensors for every Tensor in arg that should be batched. +# Returns the (potentially) batched arguments and the batch_size. + + +def _create_batched_inputs( + flat_in_dims: List[Any], flat_args: List[Any], vmap_level: int, args_spec +) -> Tuple: + # See NOTE [Ignored _remove_batch_dim, _add_batch_dim] + batched_inputs = [ + arg if in_dim is None else _add_batch_dim(arg, in_dim, vmap_level) + for in_dim, arg in zip(flat_in_dims, flat_args) + ] + return tree_unflatten(batched_inputs, args_spec) + + +def _maybe_remove_batch_dim(name, batched_output, vmap_level, batch_size, out_dim): + if out_dim is None: + if isinstance(batched_output, torch.Tensor) and is_batchedtensor( + batched_output + ): + raise ValueError( + f"vmap({name}, ...): `{name}` can not return a " + f"BatchedTensor when out_dim is None" + ) + return batched_output + + # out_dim is non None + if not isinstance(batched_output, torch.Tensor): + raise ValueError( + f"vmap({name}, ...): `{name}` must only return " + f"Tensors, got type {type(batched_output)}. " + "Did you mean to set out_dims= to None for output?" + ) + + return _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim) + + +# Undos the batching (and any batch dimensions) associated with the `vmap_level`. +def _unwrap_batched( + batched_outputs: Union[Tensor, Tuple[Tensor, ...]], + out_dims: out_dims_t, + vmap_level: int, + batch_size: int, + func: Callable, +) -> Tuple: + flat_batched_outputs, output_spec = tree_flatten(batched_outputs) + + def incompatible_error(): + raise ValueError( + f"vmap({_get_name(func)}, ..., out_dims={out_dims})(): " + f"out_dims is not compatible with the structure of `outputs`. " + f"out_dims has structure {tree_flatten(out_dims)[1]} but outputs " + f"has structure {output_spec}." 
+ ) + + if isinstance(batched_outputs, torch.Tensor): + # Some weird edge case requires us to spell out the following + # see test_out_dims_edge_case + if isinstance(out_dims, int): + flat_out_dims = [out_dims] + elif isinstance(out_dims, tuple) and len(out_dims) == 1: + flat_out_dims = out_dims + elif out_dims is None: + flat_out_dims = [out_dims] + else: + incompatible_error() + else: + flat_out_dims = _broadcast_to_and_flatten(out_dims, output_spec) + if flat_out_dims is None: + incompatible_error() + + flat_outputs = [ + _maybe_remove_batch_dim( + _get_name(func), batched_output, vmap_level, batch_size, out_dim + ) + for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims) + ] + return tree_unflatten(flat_outputs, output_spec) + + +def _check_int_or_none(x, func, out_dims): + if isinstance(x, int): + return + if x is None: + return + raise ValueError( + f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be " + f"an int, None or a python collection of ints representing where in the outputs the " + f"vmapped dimension should appear." + ) + + +def _check_out_dims_is_int_or_int_pytree(out_dims: out_dims_t, func: Callable) -> None: + if isinstance(out_dims, int): + return + tree_map_(partial(_check_int_or_none, func=func, out_dims=out_dims), out_dims) + + +def _get_name(func: Callable): + if hasattr(func, "__name__"): + return func.__name__ + + # Not all callables have __name__, in fact, only static functions/methods do. + # A callable created via functools.partial or an nn.Module, to name some + # examples, don't have a __name__. + return repr(func) + + +DECOMPOSITIONS_LOADED = False +DECOMPOSITIONS_LOCK = threading.Lock() +VMAP_DECOMPOSITIONS_LIB = None + + +# torch.package, Python 3.11, and torch.jit-less environments are unhappy with +# decompositions. Only load them when needed if possible. +def lazy_load_decompositions(): + global DECOMPOSITIONS_LOADED + if DECOMPOSITIONS_LOADED: + return + + with DECOMPOSITIONS_LOCK: + if DECOMPOSITIONS_LOADED: + return + + if not (os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__): + DECOMPOSITIONS_LOADED = True + return + + # use an alternate way to register an operator into the decomposition table + # _register_jit_decomposition doesn't work for some operators, e.g. 
addr, + # because the Tensor types generated cannot be unioned by torchscript + # decomp should be type OpOverload + global VMAP_DECOMPOSITIONS_LIB + VMAP_DECOMPOSITIONS_LIB = torch.library.Library( + "aten", "IMPL", "FuncTorchBatched" + ) + + from torch._decomp import decomposition_table + + def _register_python_decomposition_vmap(decomp): + if decomp in decomposition_table: + VMAP_DECOMPOSITIONS_LIB.impl(decomp, decomposition_table[decomp]) + else: + raise RuntimeError(f"could not find decomposition for {decomp}") + + _register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default) + _register_python_decomposition_vmap( + torch.ops.aten.smooth_l1_loss_backward.default + ) + _register_python_decomposition_vmap(torch.ops.aten.huber_loss_backward.default) + _register_python_decomposition_vmap(torch.ops.aten.nll_loss_forward.default) + _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_forward.default) + _register_python_decomposition_vmap(torch.ops.aten.nll_loss_backward.default) + _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_backward.default) + _register_python_decomposition_vmap(torch.ops.aten.addr.default) + + DECOMPOSITIONS_LOADED = True + + +def vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs): + lazy_load_decompositions() + _check_out_dims_is_int_or_int_pytree(out_dims, func) + batch_size, flat_in_dims, flat_args, args_spec = _process_batched_inputs( + in_dims, args, func + ) + + if chunk_size is not None: + chunks_flat_args = _get_chunked_inputs( + flat_args, flat_in_dims, batch_size, chunk_size + ) + return _chunked_vmap( + func, + flat_in_dims, + chunks_flat_args, + args_spec, + out_dims, + randomness, + **kwargs, + ) + + # If chunk_size is not specified. + return _flat_vmap( + func, + batch_size, + flat_in_dims, + flat_args, + args_spec, + out_dims, + randomness, + **kwargs, + ) + + +def get_chunk_sizes(total_elems, chunk_size): + n_chunks = n_chunks = total_elems // chunk_size + chunk_sizes = [chunk_size] * n_chunks + # remainder chunk + remainder = total_elems % chunk_size + if remainder != 0: + chunk_sizes.append(remainder) + return chunk_sizes + + +def _get_chunked_inputs(flat_args, flat_in_dims, batch_size, chunk_size): + split_idxs = (batch_size,) + if chunk_size is not None: + chunk_sizes = get_chunk_sizes(batch_size, chunk_size) + split_idxs = tuple(itertools.accumulate(chunk_sizes)) + + flat_args_chunks = tuple( + t.tensor_split(split_idxs, dim=in_dim) + if in_dim is not None + else [ + t, + ] + * len(split_idxs) + for t, in_dim in zip(flat_args, flat_in_dims) + ) + + # transpose chunk dim and flatten structure + # chunks_flat_args is a list of flatten args + chunks_flat_args = zip(*flat_args_chunks) + return chunks_flat_args + + +def _flatten_chunks_output(chunks_output_): + # chunks_output is a list of chunked outputs + # flatten chunked outputs: + flat_chunks_output = [] + arg_spec = None + for output in chunks_output_: + flat_output, arg_specs = tree_flatten(output) + flat_chunks_output.append(flat_output) + if arg_spec is None: + arg_spec = arg_specs + + # transpose chunk dim and flatten structure + # flat_output_chunks is flat list of chunks + flat_output_chunks = list(zip(*flat_chunks_output)) + return flat_output_chunks, arg_spec + + +def _concat_chunked_outputs(out_dims, arg_spec, flat_output_chunks): + # concat chunks on out_dim + flat_out_dims = _broadcast_to_and_flatten(out_dims, arg_spec) + assert len(flat_out_dims) == len(flat_output_chunks) + flat_output = [] + for idx, out_dim in 
enumerate(flat_out_dims): + flat_output.append(torch.cat(flat_output_chunks[idx], dim=out_dim)) + # release tensors + flat_output_chunks[idx] = None + + return flat_output + + +# Applies vmap on chunked_input and returns concatenated output over the chunks. +def _chunked_vmap( + func, flat_in_dims, chunks_flat_args, args_spec, out_dims, randomness, **kwargs +): + chunks_output = [] + rs = torch.get_rng_state() if randomness == "same" else None + for flat_args in chunks_flat_args: + batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args) + + # The way we compute split the input in `_get_chunked_inputs`, + # we may get a tensor with `0` batch-size. We skip any computation + # in that case. + # Eg. + # >>> chunk_size = 1 + # >>> batch_size = 6 + # >>> t = torch.zeros(batch_size, 1) + # >>> t.tensor_split([1, 2, 3, 4, 5, 6]) + # (tensor([[0.]]), tensor([[0.]]), tensor([[0.]]), tensor([[0.]]), + # tensor([[0.]]), tensor([[0.]]), tensor([], size=(0, 1))) + if batch_size == 0: + continue + + if rs is not None: + torch.set_rng_state(rs) + chunks_output.append( + _flat_vmap( + func, + batch_size, + flat_in_dims, + flat_args, + args_spec, + out_dims, + randomness, + **kwargs, + ) + ) + + flat_output_chunks, arg_spec = _flatten_chunks_output(chunks_output) + + # chunked output tensors are held by both `flat_output_chunks` and `chunks_output`. + # eagerly remove the reference from `chunks_output`. + del chunks_output + + # concat chunks on out_dim + flat_output = _concat_chunked_outputs(out_dims, arg_spec, flat_output_chunks) + + # finally unflatten the output + return tree_unflatten(flat_output, arg_spec) + + +# Vmap refactored helper functions: +def _check_randomness_arg(randomness): + if randomness not in ["error", "different", "same"]: + raise RuntimeError( + f"Only allowed values for randomness are 'error', 'different', or 'same'. Got {randomness}" + ) + + +@contextlib.contextmanager +def vmap_increment_nesting(batch_size, randomness): + try: + vmap_level = _vmap_increment_nesting(batch_size, randomness) + yield vmap_level + finally: + _vmap_decrement_nesting() + + +def _flat_vmap( + func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs +): + with vmap_increment_nesting(batch_size, randomness) as vmap_level: + batched_inputs = _create_batched_inputs( + flat_in_dims, flat_args, vmap_level, args_spec + ) + batched_outputs = func(*batched_inputs, **kwargs) + return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func) + + +# `restore_vmap` is a private helper function. It is vmap but has the following +# differences: +# - instead of returning outputs, it returns an (outputs, out_dims) tuple. +# out_dims is a pytree of same shape as outputs and contains Optional[int] +# specifying where the vmapped dimension, if it exists, is in the corresponding output. +# - does no validation on in_dims or inputs (vmap expects at least one Tensor to be vmapped). +# restore_vmap allows for no inputs to have the vmap dimension +# - does no validation on outputs (vmap expects only Tensor outputs) +# restore_vmap allows for return of arbitrary outputs (not just Tensors) +# +# The TL;DR is that restore_vmap is more general than vmap and has a slightly +# different API. The relaxations are so that we can "pause" vmap in the middle +# of its execution and then "restore" it later (this is what we do in +# the generate_vmap_rule=True implementation of autograd.Function). 
+# +# restore_vmap can be technically used in the implementation of vmap, but doing +# that refactor is a bit technically challenging because: +# - vmap couples the tensor-wrapping code with error checking +# - vmap's tensor unwrapping code is in C++; we would need to rewrite part of it +# in python because it overlaps with unwrap_batched +def restore_vmap(func, in_dims, batch_size, randomness): + def inner(*args, **kwargs): + with vmap_increment_nesting(batch_size, randomness) as vmap_level: + batched_inputs = wrap_batched(args, in_dims, vmap_level) + batched_outputs = func(*batched_inputs, **kwargs) + return unwrap_batched(batched_outputs, vmap_level) + + return inner + + +def wrap_batched(args, bdims, level): + flat_args, spec = tree_flatten(args) + flat_bdims = _broadcast_to_and_flatten(bdims, spec) + assert flat_bdims is not None + result = _create_batched_inputs(flat_bdims, flat_args, level, spec) + return result + + +def unwrap_batched(args, level): + flat_args, spec = tree_flatten(args) + if len(flat_args) == 0: + return args, () + result = [ + torch._C._functorch._unwrap_batched(arg, level) + if isinstance(arg, torch.Tensor) + else (arg, None) + for arg in flat_args + ] + output, bdims = zip(*result) + return tree_unflatten(output, spec), tree_unflatten(bdims, spec) diff --git a/lib/python3.10/site-packages/torch/_higher_order_ops/__init__.py b/lib/python3.10/site-packages/torch/_higher_order_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..72800cae7fc98ca67e94b7bc420a79c2799791d5 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_higher_order_ops/__init__.py @@ -0,0 +1,16 @@ +from torch._higher_order_ops.cond import cond +from torch._higher_order_ops.flex_attention import ( + flex_attention, + flex_attention_backward, +) +from torch._higher_order_ops.hints_wrap import hints_wrapper +from torch._higher_order_ops.while_loop import while_loop + + +__all__ = [ + "cond", + "while_loop", + "flex_attention", + "flex_attention_backward", + "hints_wrapper", +] diff --git a/lib/python3.10/site-packages/torch/_higher_order_ops/associative_scan.py b/lib/python3.10/site-packages/torch/_higher_order_ops/associative_scan.py new file mode 100644 index 0000000000000000000000000000000000000000..11dcca5eb4959ff3f62efc9e72c923f7a750dd00 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_higher_order_ops/associative_scan.py @@ -0,0 +1,365 @@ +# mypy: allow-untyped-defs +import functools +import itertools +from typing import Callable, List + +import torch +import torch._prims_common as utils +import torch._subclasses.functional_tensor +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._higher_order_ops.utils import ( + _maybe_run_with_interpreter, + _set_compilation_env, + autograd_not_implemented, + reenter_make_fx, + unique_graph_id, +) +from torch._inductor.utils import is_pointwise_use +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) + + +aten = torch._ops.ops.aten + + +def wrap_combine_fn_flat(*args, combine_fn, spec, num_leaves): + assert len(args) == 2 * num_leaves + lhs = pytree.tree_unflatten(args[:num_leaves], spec) + rhs = pytree.tree_unflatten(args[num_leaves:], spec) + combined = combine_fn(lhs, rhs) + combined_leaves = pytree.tree_leaves(combined) + assert num_leaves == len(combined_leaves) + return combined_leaves + + +def 
_interleave(a, b, dim): + # https://stackoverflow.com/questions/60869537/how-can-i-interleave-5-pytorch-tensors + if b_trunc := (a.shape[dim] == b.shape[dim] + 1): + pad = ( + [0] * ((b.ndim - dim - 1) * 2 + 1) + + [1] + + [0] * (b.ndim * 2 - ((b.ndim - dim - 1) * 2 + 2)) + ) + b = torch.nn.functional.pad(b, pad) + + stacked = torch.stack([a, b], dim=dim + 1) + interleaved = torch.flatten(stacked, start_dim=dim, end_dim=dim + 1) + if b_trunc: + # TODO: find torch alternative for slice_along dim for torch.jit.script to work + interleaved = aten.slice(interleaved, dim, 0, b.shape[dim] + a.shape[dim] - 1) + return interleaved + + +def safe_map(f, *args): + args = list(map(list, args)) + n = len(args[0]) + for arg in args[1:]: + if len(arg) != n: + raise ValueError("length mismatch: {list(map(len, args))}") + + def nf(a): + return f(*a) + + return list(map(nf, zip(*args))) + + +class AssociativeScanOp(HigherOrderOperator): + def __init__(self): + super().__init__("associative_scan") + + def __call__(self, combine_fn, input, dim): + return super().__call__(combine_fn, input, dim) + + +associative_scan_op = AssociativeScanOp() + + +def associative_scan( + combine_fn: Callable[[pytree.PyTree, pytree.PyTree], pytree.PyTree], + input: pytree.PyTree, + dim: int, + reverse: bool = False, + combine_mode: str = "pointwise", +) -> torch.Tensor: + r""" + Performs an inclusive scan with an associative pointwise combine function. + + .. warning:: + `torch.associative_scan` is a prototype feature in PyTorch. It currently + does not support autograd and you may run into miscompiles. + Read more about feature classification at: + https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype + + This operator requires runtime code generation and so requires support for + ``torch.compile``. Further, only CUDA device codegen is supported at the moment. + + Args: + combine_fn (Callable): A binary callable with type ``(Tensor, Tensor) -> Tensor``, + or if input is a pytree ``(pytree, pytree) -> pytree``. + This function must be pure, pointwise, and satisfy the associative property. + input (torch.Tensor): The input tensor, or nested pytree of tensors. + All inputs are expected to have the same shape. + dim (int): the dimension to scan over + reverse (bool): A boolean stating if the scan should be reversed with respect to the dimension. + combine_mode (str): A string indicating whether the ``combine_fn`` is ``pointwise`` or ``generic``. + If ``combine_mode=pointwise``, ``combine_fn`` must be pure, may only contain pointwise operations + and ``input`` must be CUDA tensors. + In all other cases ``combine_mode=generic`` should be used. + Note: ``combine_mode=pointwise`` is more efficient than ``combine_mode=generic``. 
+ + + Example:: + + def add(x: torch.Tensor, y: torch.Tensor): + return x + y + + cumsum = associative_scan(add, x, dim) + + """ + assert callable(combine_fn), "combine_fn must be a callable, but got {combine_fn}" + assert isinstance(dim, int), "dim must be an int, but got {type(dim)}" + assert combine_mode in ["pointwise", "generic"] + + if not torch._dynamo.is_compiling(): + with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): + return torch.compile(associative_scan, fullgraph=True)( + combine_fn, input, dim, reverse=reverse, combine_mode=combine_mode + ) + + leaves, spec = pytree.tree_flatten(input) + + if combine_mode == "pointwise" and not all(l.device.type == "cuda" for l in leaves): + raise ValueError( + "For combine_mode='pointwise', all input tensors need to be on CUDA" + ) + + assert len(leaves) >= 1, "expected at least 1 input leaf" + assert all( + isinstance(x, torch.Tensor) for x in leaves + ), "input leaves must be a Tensor" + + if reverse: + leaves = [torch.flip(elem, [dim]) for elem in leaves] + + shape = leaves[0].shape + ndim = len(shape) + dim = utils.canonicalize_dim(ndim, dim) + + for x in leaves[1:]: + assert x.shape == shape, "All input tensors must have the same shape" + + out = combine_fn( + pytree.tree_unflatten(leaves, spec), + pytree.tree_unflatten(leaves, spec), + ) + out_leaves, tree_out = pytree.tree_flatten(out) + assert len(leaves) == len( + out_leaves + ), "The pytree of the output of the operator needs to match the input pytree" + for x in out_leaves: + assert ( + x.shape == shape + ), "The pytree of the output of the operator needs to match the input pytree" + + combine_fn = functools.partial( + wrap_combine_fn_flat, combine_fn=combine_fn, spec=spec, num_leaves=len(leaves) + ) + + if combine_mode == "generic": + result_flat = generic_associative_scan(combine_fn, leaves, dim) + else: + result_flat = associative_scan_op(combine_fn, leaves, dim) + + if reverse: + result_flat = [torch.flip(elem, [dim]) for elem in result_flat] + + return pytree.tree_unflatten(result_flat, spec) + + +def generic_associative_scan(operator, elems_flat, dim=0): + r""" + This function performs the associative_scan operation. + The algorithm works by recursively collecting neighbours of ``elems_flat`` and subsequently + applying the ``operator`` on all pairs in parallel along ``dim``. + The results of the recursive calls are later combined. + + Args: + operator (Callable): A binary callable with type ``(Tensor, Tensor) -> Tensor``, + or if input is a pytree ``(pytree, pytree) -> pytree``. + This function must be pure, pointwise, and satisfy the associative property. + elems_flat (torch.Tensor): A list of torch.Tensors converted from the pytree of + ``input`` provided to ``associative_scan``. + All inputs are expected to have the same shape. + dim (int): the dimension to scan over + + + Example:: + + def add(x: torch.Tensor, y: torch.Tensor): + return x + y + + elems_flat = torch.tensor([0.0, 1.0, 2.0, 3.0]) + + First iteration of _scan -> + # odd_elems -> apply operator on all neighbours + # odd_elems = operator([torch.tensor([0.0, 2.0])], + # [torch.tensor([1.0, 3.0])]) + odd_elems = torch.tensor([1.0, 5.0]) + Second iteration of _scan -> + # odd_elems = operator([torch.tensor([1.0])], + # [torch.tensor([5.0])]) + odd_elems = torch.tensor([6.0]) + # even_elems -> apply operator on all odd_elems and + # every second element of ``elems``, starting from the second element. 
+ # even_elems is expanded with the first element of ``elems`` + even_elems = [1.0] + # Merges odd_elems and even_elems + res = torch.tensor([1.0, 6.0]) + # even_elems -> apply operator on all odd_elems and + # every second element of ``elems``, starting from the second element. + # even_elems is expanded with the first element of ``elems`` + even_elems = [0.0, 3.0] + # Merges odd_elems and even_elems + res = torch.tensor([0.0, 1.0, 3.0, 6.0]) + + """ + + def _scan(elems): + """Perform the actual recursive scan on ``elems``.""" + num_elems = elems[0].shape[dim] + + if num_elems < 2: + return elems + + reduced_elems = operator( + *[aten.slice(elem, dim, 0, -1, 2) for elem in elems], + *[aten.slice(elem, dim, 1, None, 2) for elem in elems], + ) + + # Recursively compute scan for partially reduced tensors. + odd_elems = _scan(reduced_elems) + + if num_elems % 2 == 0: + even_elems = operator( + *[aten.slice(e, dim, 0, -1) for e in odd_elems], + *[aten.slice(e, dim, 2, None, 2) for e in elems], + ) + else: + even_elems = operator( + *odd_elems, + *[aten.slice(e, dim, 2, None, 2) for e in elems], + ) + + # The first element of a scan is the same as the first element + # of the original `elems`. + even_elems = [ + torch.cat([aten.slice(elem, dim, 0, 1), result], dim=dim) + if result.shape.numel() > 0 and elem.shape[dim] > 0 + else result + if result.shape.numel() > 0 + else aten.slice( + elem, dim, 0, 1 + ) # Jax allows/ignores concat with 0-dim, Pytorch does not + for (elem, result) in zip(elems, even_elems) + ] + + return list( + safe_map(functools.partial(_interleave, dim=dim), even_elems, odd_elems) + ) + + scans = _scan(elems_flat) + + return scans + + +def trace_associative_scan( + proxy_mode, func_overload, combine_fn: Callable, input: List[torch.Tensor], dim: int +): + with disable_proxy_modes_tracing(): + sample_inputs = [ + torch.empty_like( + x, + dtype=x.dtype, + device=x.device, + requires_grad=x.requires_grad, + ) + for x in itertools.chain(input, input) + ] + combine_graph = reenter_make_fx(combine_fn)(*sample_inputs) + + outputs = None + for node in combine_graph.graph.nodes: + if node.op == "output": + assert outputs is None + assert len(node.args) == 1 + outputs = node.args[0] + + if not all(is_pointwise_use(use) or use.op == "output" for use in node.users): + raise ValueError( + "For combine_mode='pointwise', the combine_fn needs to be pointwise" + ) + + assert outputs is not None + assert len(outputs) == len( + input + ), f"expected combine_fn to return {len(input)} results but got {len(outputs)}" + + for i, o in zip(input, outputs): + o_meta = o.meta["tensor_meta"] + assert o_meta.dtype == i.dtype, ( + f"combine_fn output type mismatch, expected {i.dtype} " + + f"but got {o_meta.dtype}" + ) + + _, combine_graph_name = unique_graph_id(proxy_mode, prefix="scan_combine_graph") + + proxy_mode.tracer.root.register_module(combine_graph_name, combine_graph) + + args = (combine_graph, input, dim) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", func_overload, proxy_args, {}, name="associative_scan" + ) + + with disable_proxy_modes_tracing(): + out = [aten.clone(x) for x in input] + + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@associative_scan_op.py_impl(DispatchKey.CompositeExplicitAutograd) +def associative_scan_op_dense(combine_fn, input, dim): + raise NotImplementedError("associative_scan is not implemented for eager") + + 
+associative_scan_op.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(associative_scan_op, deferred_error=True) +) + + +@associative_scan_op.py_impl(ProxyTorchDispatchMode) +def associative_scan_proxy_mode(mode, combine_fn, input, dim): + return trace_associative_scan(mode, associative_scan_op, combine_fn, input, dim) + + +@associative_scan_op.py_impl(FakeTensorMode) +def assoiciative_scan_fake_tensor_mode(mode, combine_fn, input, dim): + with mode: + return [x.clone() for x in input] + + +@associative_scan_op.py_functionalize_impl +def associative_scan_functionalize(ctx, combine_fn, input, dim): + unwrapped_input = ctx.unwrap_tensors(input) + with ctx.redispatch_to_next() as m: + functional_combine_fn = ctx.functionalize( + _maybe_run_with_interpreter(combine_fn) + ) + ret = associative_scan_op(functional_combine_fn, unwrapped_input, dim) + return ctx.wrap_tensors(ret) diff --git a/lib/python3.10/site-packages/torch/_higher_order_ops/auto_functionalize.py b/lib/python3.10/site-packages/torch/_higher_order_ops/auto_functionalize.py new file mode 100644 index 0000000000000000000000000000000000000000..232981f1f0192eb8f50b35586298aa645304be30 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_higher_order_ops/auto_functionalize.py @@ -0,0 +1,713 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import warnings +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import torch +import torch.utils._pytree as pytree +from torch import Tensor +from torch._C import DispatchKey +from torch._ops import HigherOrderOperator, OperatorBase, OpOverload +from torch._prims_common import clone_preserve_strides +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) + + +def get_base(tensor): + if torch.is_inference_mode_enabled(): + return tensor._inference_mode_base + else: + return tensor._base + + +@dataclass +class ViewInfo: + base_index: int + size: Optional[Sequence[Union[int, torch.SymInt]]] = None + stride: Optional[Sequence[Union[int, torch.SymInt]]] = None + storage_offset: Optional[int] = None + # When is_view is false, the tensor is the base, and + # size, stride and storage_offset are all None. + is_view: bool = True + + def regenerate_view(self, bases_list: List[Tensor]): + if not self.is_view: + return bases_list[self.base_index] + + assert self.stride is not None + assert self.size is not None + assert self.storage_offset is not None + + return torch.as_strided( + bases_list[self.base_index], + self.size, + self.stride, + self.storage_offset, + ) + + +def write_view_information_to_args( + mutable_arg_names: List[str], + mutable_arg_types: List[torch.Type], + kwargs: Dict[str, Any], + arg_to_base_index: Dict[str, Any], +): + """ + This function writes the view information into kwargs. It reads mutable_args from kwargs. + and uses arg_to_base_index and tensor information to write ViewInfo into kwargs. + mutable_arg_names: mutable custom operator arg names. + mutable_arg_types: mutable custom operator arg types. + kwargs: the original custom operator args. 
+ arg_to_base_index: maps mutable_arg_name to int | [int] that refers to the base tensor that + corresponds to the input tensor + """ + + def write_single_view(prefix: str, tensor: Tensor, base_index: int): + assert f"{prefix}_base_index" not in kwargs + assert f"{prefix}_size" not in kwargs + assert f"{prefix}_stride" not in kwargs + assert f"{prefix}_storage_offset" not in kwargs + + if tensor is None: + kwargs[f"{prefix}_base_index"] = None + elif get_base(tensor) is None: + # if the tensor is the base (not view), for simplicity we do not serialize view meta. + kwargs[f"{prefix}_base_index"] = base_index + else: + kwargs[f"{prefix}_base_index"] = base_index + kwargs[f"{prefix}_size"] = tensor.size() + kwargs[f"{prefix}_stride"] = tensor.stride() + kwargs[f"{prefix}_storage_offset"] = tensor.storage_offset() + + for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types): + arg = kwargs[arg_name] + if isinstance(arg_type, torch.ListType): + if arg is None: + kwargs[f"_{arg_name}_length"] = None + + kwargs[f"_{arg_name}_length"] = len(arg) + for i, elem in enumerate(arg): + write_single_view( + f"_{arg_name}_{i}", elem, arg_to_base_index[arg_name][i] + ) + + elif isinstance(arg_type, (torch.TensorType, torch.OptionalType)): + write_single_view( + f"_{arg_name}", + kwargs[arg_name], + arg_to_base_index.get(arg_name, None), + ) + else: + raise RuntimeError(f"Unsupported type {arg_type}") + + +# Returns a dict of arg_name -> ViewInfo | [ViewInfo] +def read_view_information_from_args( + mutable_arg_names: List[str], + mutable_arg_types: List[torch.Type], + kwargs: Dict[str, Any], + all_bases: List[Tensor], +): + """ + This reads the view information added by `write_view_information_to_args` from kwargs, pop them, + and returns a dict arg_name -> ViewInfo | [ViewInfo](if the input is list). that maps each mutable arg + to its view information. + mutable_arg_names: mutable custom operator arg names. + mutable_arg_types: mutable custom operator arg types. + kwargs : args of auto_functionalize(custom_op, kwargs) + """ + + def get_arg(name): + return kwargs.pop(name) + + def read_single_view(prefix): + base_index = get_arg(f"{prefix}_base_index") + if base_index is None: + return None + elif f"{prefix}_size" not in kwargs: + assert f"{prefix}_stride" not in kwargs + assert f"{prefix}_storage_offset" not in kwargs + + # This means that the argument is the base tensor + return ViewInfo(base_index, all_bases[base_index], is_view=False) + + else: + size = get_arg(f"{prefix}_size") + stride = get_arg(f"{prefix}_stride") + storage_offset = get_arg(f"{prefix}_storage_offset") + return ViewInfo(base_index, size, stride, storage_offset, is_view=True) + + args_view_info: Dict[str, Any] = {} + for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types): + if isinstance(arg_type, torch.ListType): + length = get_arg(f"_{arg_name}_length") + if length is None: + # The whole list is None. + args_view_info[arg_name] = None + else: + args_view_info[arg_name] = [ + read_single_view(f"_{arg_name}_{i}") for i in range(length) + ] + + elif isinstance(arg_type, (torch.TensorType, torch.OptionalType)): + args_view_info[arg_name] = read_single_view(f"_{arg_name}") + else: + raise RuntimeError(f"Unsupported type {arg_type}") + return args_view_info + + +# NOTE: [auto-functionalizing custom ops] +# Users may wish to torch.compile custom ops that mutate their inputs. +# torch.compile will automatically support this op without anyone needing +# to provide a functionalization kernel for it. Here's how. 
+# +# Let's say we have a hypothetical mylib::sin_(Tensor(a!) x) -> () +# op. First, when FakeTensor sees this op: +# - If the schema says it returns nothing, we can generate a trivial +# FakeTensor rule for it (that returns nothing). +# - Otherwise, the user needs to provide a FakeTensor impl (fake impl) +# +# Next, when Python FunctionalTensor sees the op, it will functionalize +# it by emitting a call to an auto_functionalize(op, ["x"], {"x": ...}) +# HOP and replacing the mutated inputs with corresponding outputs of this HOP. +# This HOP effectively runs the functional version of the op when +# called: it clones inputs that will be mutated, runs the op, and +# then returns (output, Tensors with the new values) +# +# auto_functionalize_v2 is an improved version of auto_functionalize that better handle +# re-inplacing views. + + +class AutoFunctionalized(HigherOrderOperator): + """auto_functionalized(_mutable_op, **kwargs) + + This HOP runs a "functional" version of _mutable_op. + + Concretely, it looks at all the arguments that are mutable through + _mutable_op's operator schema, clones those kwargs, runs + `out = _mutable_op(**kwargs)` with the cloned values, and then returns the + operator output concatenated with the cloned values that were mutated. + + We have some restrictions on `_mutable_op`. + See `can_auto_functionalize` for the restrictions. We can likely lift + many of these if users request it. + + The reason why _mutable_op is prefixed with an + underscore is to prevent collisions with kwarg names in **kwargs. + """ + + def __init__(self) -> None: + super().__init__("auto_functionalized") + + def __call__( + self, + /, + _mutable_op: OpOverload, + **kwargs: Any, + ) -> Tuple[Any, Tuple[Tensor, ...]]: + assert can_auto_functionalize(_mutable_op) + assert isinstance(kwargs, dict) + return super().__call__(_mutable_op, **kwargs) + + +auto_functionalized = AutoFunctionalized() +auto_functionalized.__module__ = "torch.ops.higher_order" + +auto_functionalized.fallthrough(DispatchKey.AutogradCPU) +auto_functionalized.fallthrough(DispatchKey.AutogradCUDA) + + +class AutoFunctionalizedV2(HigherOrderOperator): + """auto_functionalized_v2(_mutable_op, **kwargs) + + This HOP runs a "functional" version of _mutable_op. + Unlike AutoFunctionalized, this version is improved to better handle + view tensors. This version is only used in non export mode. + """ + + def __init__(self) -> None: + super().__init__("auto_functionalized_v2") + + def __call__( + self, + /, + _mutable_op: OpOverload, + **kwargs: Any, + ) -> Tuple[Any, Tuple[Tensor, ...]]: + assert can_auto_functionalize(_mutable_op) + assert isinstance(kwargs, dict) + return super().__call__(_mutable_op, **kwargs) + + +auto_functionalized_v2 = AutoFunctionalizedV2() +auto_functionalized_v2.__module__ = "torch.ops.higher_order" + +auto_functionalized_v2.fallthrough(DispatchKey.AutogradCPU) +auto_functionalized_v2.fallthrough(DispatchKey.AutogradCUDA) + + +def can_auto_functionalize(op: OperatorBase) -> bool: + if not isinstance(op, OpOverload): + return False + + if torch._library.utils.is_builtin(op): + # We control the built-ins. 
These may (in rare cases) + # do input metadata mutation (which we have banned on custom ops) + return False + schema = op._schema + if not schema.is_mutable: + return False + schema = op._schema + + for arg in schema.arguments: + if arg.alias_info is None: + continue + if not arg.alias_info.is_write: + continue + if type(arg.type) is torch.TensorType: + continue + if ( + type(arg.type) is torch.OptionalType + and type(arg.type.getElementType()) is torch.TensorType + ): + continue + if ( + type(arg.type) is torch.ListType + and type(arg.type.getElementType()) is torch.TensorType + ): + continue + # Not yet supported: other Tensor types. This includes things like + # Tensor?[], Tensor[]?. + return False + + if len(schema.returns) == 1 and isinstance(schema.returns[0].type, torch.NoneType): + # Skip schema returns -> None + return True + # The returns must not alias anything + for ret in schema.returns: + if ret.alias_info is None and type(ret.type) is torch.TensorType: + continue + # Not yet supported: List[Tensor] return. + return False + if torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), "Functionalize"): + return False + return True + + +def get_mutable_args(op: OpOverload) -> Tuple[List[str], List[torch.Type]]: + """ + Returns the list of argument names that get mutated according to the + schema and their types. + """ + mutable_args_names = [ + arg.name + for arg in op._schema.arguments + if arg.alias_info is not None and arg.alias_info.is_write + ] + + mutable_args_types = [ + arg.type + for arg in op._schema.arguments + if arg.alias_info is not None and arg.alias_info.is_write + ] + return mutable_args_names, mutable_args_types + + +def do_auto_functionalize( + op: OpOverload, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], +) -> Any: + """Functionalizes a call to op(*args, **kwargs) by emitting a call to + `outs = auto_functionalized(op, normalized_kwargs)` + and replacing the mutated (args, kwargs) with the corresponding outputs. + + The normalized_kwargs are just the (args, kwargs), but all in kwarg form. + This makes handling easier for the auto_functionalized HOP. + """ + from torch._subclasses.functional_tensor import PythonFunctionalizeAPI + + ctx = PythonFunctionalizeAPI() + + # All of the (args, kwargs), but all as kwargs. The names for the + # args come from the schema. This makes it easier for us to work with them. + normalized_kwargs = {} + schema = op._schema + for idx, arg in enumerate(schema.arguments): + # NB: torch_dispatch kwargs are the args defined as kwarg-only in the schema + if arg.name in kwargs: + normalized_kwargs[arg.name] = kwargs[arg.name] + elif idx < len(args): + # if its out of bounds we don't need to do anything + # as it means the the optional arg was passed with its default + # value + normalized_kwargs[arg.name] = args[idx] + else: + normalized_kwargs[arg.name] = arg.default_value + + unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs) # type: ignore[arg-type] + if "self" in unwrapped_kwargs or "self_" in unwrapped_kwargs: + warnings.warn( + "Using `self` or `self_` as an argument in the definition of custom ops may lead to ambiguous parsing. " + "Please consider using a different name for this argument to avoid potential issues." 
+ ) + with ctx.redispatch_to_next(): + unwrapped_outs = auto_functionalized( + op, **unwrapped_kwargs # type: ignore[arg-type] + ) + + # List of the name of args that get mutated (according to the schema) + mutable_args_names, _ = get_mutable_args(op) + + unwrapped_actual_out: Union[Any, Tuple[Any]] = unwrapped_outs[ + : -len(mutable_args_names) + ] + unwrapped_mutable_out = unwrapped_outs[-len(mutable_args_names) :] + + if len(op._schema.returns) == 0: + assert unwrapped_actual_out[0] is None + unwrapped_actual_out = None + elif len(op._schema.returns) == 1: + assert len(unwrapped_actual_out) == 1 + unwrapped_actual_out = unwrapped_actual_out[0] + else: + assert len(unwrapped_actual_out) == len(op._schema.returns) + + for name, unwrapped_out in zip(mutable_args_names, unwrapped_mutable_out): + # Can be None if input was `Tensor(a!)?` + if unwrapped_out is None: + continue + + # We only handle Tensor or List[Tensor] here for now. + def sync_update(o, orig_arg): + ctx.replace(orig_arg, o) + ctx.commit_update(orig_arg) + ctx.sync(orig_arg) + + orig_arg = normalized_kwargs[name] + + if isinstance(unwrapped_out, torch.Tensor): + sync_update(unwrapped_out, orig_arg) + elif isinstance(unwrapped_out, list) and all( + isinstance(o, torch.Tensor) for o in unwrapped_out + ): + assert len(orig_arg) == len(unwrapped_out) + for orig_a, o in zip(orig_arg, unwrapped_out): + sync_update(o, orig_a) + else: + raise RuntimeError( + f"unsupported type for auto-functionalization: {unwrapped_out}" + ) + + return ctx.wrap_tensors(unwrapped_actual_out) # type: ignore[arg-type] + + +def do_auto_functionalize_v2( + op: OpOverload, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], +) -> Any: + from torch._subclasses.functional_tensor import PythonFunctionalizeAPI + + ctx = PythonFunctionalizeAPI() + + # All of the (args, kwargs), but all as kwargs. The names for the + # args come from the schema. This makes it easier for us to work with them. + normalized_kwargs = {} + + schema = op._schema + for idx, arg in enumerate(schema.arguments): + # NB: torch_dispatch kwargs are the args defined as kwarg-only in the schema + if arg.name in kwargs: + normalized_kwargs[arg.name] = kwargs[arg.name] + elif idx < len(args): + # if its out of bounds we don't need to do anything + # as it means the the optional arg was passed with its default + # value + normalized_kwargs[arg.name] = args[idx] + else: + normalized_kwargs[arg.name] = arg.default_value + + # List of the name of args that get mutated (according to the schema) + mutable_args_names, mutable_args_types = get_mutable_args(op) + + # A list of all bases of mutable args without duplication + all_bases = [] + all_bases_addresses: list[int] = [] + + # Map arg_name to the index of its base in all_bases. 
+ arg_to_base_index: Dict[str, Any] = {} + + def update_dict(tensor, arg_name, index=None): + base = tensor if get_base(tensor) is None else get_base(tensor) + + def set_result(base_index): + if index is None: + arg_to_base_index[arg_name] = base_index + else: + arg_to_base_index[arg_name][index] = base_index + + if not all_bases_addresses.__contains__(base._cdata): + all_bases_addresses.append(base._cdata) + all_bases.append(base) + set_result(len(all_bases) - 1) + else: + set_result(all_bases_addresses.index(base._cdata)) + + for arg_name in mutable_args_names: + arg = normalized_kwargs[arg_name] + if arg is None: + continue + + if isinstance(arg, list): + arg_to_base_index[arg_name] = {} + for i, tensor in enumerate(arg): + if tensor is None: + arg_to_base_index[arg_name].append(None) + continue + + update_dict(tensor, arg_name, i) + + else: + update_dict(arg, arg_name) + + # add view_meta for each args into unwrapped_kwargs. + write_view_information_to_args( + mutable_args_names, + mutable_args_types, + normalized_kwargs, + arg_to_base_index, + ) + + # remove mutated args from the kwargs (its a function of _all_bases now) + for arg_name in mutable_args_names: + del normalized_kwargs[arg_name] # type: ignore[arg-type] + + unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs) # type: ignore[arg-type] + if "self" in unwrapped_kwargs or "self_" in unwrapped_kwargs: + warnings.warn( + "Using `self` or `self_` as an argument in the definition of custom ops may lead to ambiguous parsing. " + "Please consider using a different name for this argument to avoid potential issues." + ) + all_basis_unwrapped = ctx.unwrap_tensors(all_bases) + + with ctx.redispatch_to_next(): + unwrapped_outs = auto_functionalized_v2( + op, **dict(unwrapped_kwargs, _all_bases=all_basis_unwrapped) # type: ignore[arg-type] + ) + + unwrapped_actual_out: Union[Any, Tuple[Any]] = ( + unwrapped_outs if len(all_bases) == 0 else unwrapped_outs[: -len(all_bases)] + ) + + unwrapped_mutable_out = ( + [] if len(all_bases) == 0 else unwrapped_outs[-len(all_bases) :] + ) + + if len(op._schema.returns) == 0: + assert unwrapped_actual_out[0] is None + unwrapped_actual_out = None + elif len(op._schema.returns) == 1: + assert len(unwrapped_actual_out) == 1 + unwrapped_actual_out = unwrapped_actual_out[0] + else: + assert len(unwrapped_actual_out) == len(op._schema.returns) + + for orig_arg, unwrapped_out in zip(all_bases, unwrapped_mutable_out): + # Can be None if input was `Tensor(a!)?` + if unwrapped_out is None: + continue + + # We only handle Tensor or List[Tensor] here for now. 
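# --- Illustrative example (editor's sketch, not part of the original patch) ---
# The deduplication above keys bases by their `_cdata` address so that two views
# of the same storage contribute a single entry to `_all_bases`. The real code
# goes through the `get_base` helper; this sketch uses the internal
# `_base`/`_cdata` attributes directly to show the idea:
import torch

base = torch.zeros(4)
v1, v2 = base[:2], base[2:]

assert v1._base is base and v2._base is base
# Same base address, so auto_functionalized_v2 would record a single base and
# regenerate both views from it after the op runs.
assert v1._base._cdata == v2._base._cdata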
+ def sync_update(o, orig_arg): + ctx.replace(orig_arg, o) + ctx.commit_update(orig_arg) + ctx.sync(orig_arg) + + if isinstance(unwrapped_out, torch.Tensor): + sync_update(unwrapped_out, orig_arg) + elif isinstance(unwrapped_out, list) and all( + isinstance(o, torch.Tensor) for o in unwrapped_out + ): + assert len(orig_arg) == len(unwrapped_out) + for orig_a, o in zip(orig_arg, unwrapped_out): + sync_update(o, orig_a) + else: + raise RuntimeError( + f"unsupported type for auto-functionalization: {unwrapped_out}" + ) + + return ctx.wrap_tensors(unwrapped_actual_out) # type: ignore[arg-type] + + +# auto_functionalize functions +@auto_functionalized.py_impl(DispatchKey.CompositeExplicitAutograd) +def auto_functionalized_dense( + _mutable_op: OpOverload, + _only_clone_these_tensors: Optional[Tuple[str, ...]] = None, + **kwargs: Any, +) -> Tuple[Any, Tuple[Tensor, ...]]: + new_kwargs = dict(**kwargs) + result = [] + + _mutable_args_names, _ = get_mutable_args(_mutable_op) + for name in _mutable_args_names: + if ( + _only_clone_these_tensors is not None + and name not in _only_clone_these_tensors + ): + new_kwargs[name] = kwargs[name] + else: + new_kwargs[name] = ( + [clone_preserve_strides(x) for x in kwargs[name]] + if kwargs[name] is not None and isinstance(kwargs[name], list) + else clone_preserve_strides(kwargs[name]) + if kwargs[name] is not None + else None + ) + result.append(new_kwargs[name]) + out = _mutable_op(**new_kwargs) + + if isinstance(out, tuple): + return (*out, *result) # type: ignore[return-value] + else: + return (out, *result) # type: ignore[return-value] + + +@auto_functionalized.py_impl(FakeTensorMode) +def auto_functionalized_fake( + mode, + _mutable_op: OpOverload, + **kwargs: Any, +) -> Tuple[Any, Tuple[Tensor, ...]]: + with mode: + result = auto_functionalized_dense(_mutable_op, **kwargs) + return result + + +@auto_functionalized.py_impl(ProxyTorchDispatchMode) +def auto_functionalized_proxy( + mode, + _mutable_op: OpOverload, + **kwargs: Any, +) -> Tuple[Any, Tuple[Tensor, ...]]: + with disable_proxy_modes_tracing(): + out = auto_functionalized(_mutable_op, **kwargs) + + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) + out_proxy = mode.tracer.create_proxy( + "call_function", + auto_functionalized, + (_mutable_op,), + proxy_kwargs, + ) + result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + return result + + +@auto_functionalized.py_functionalize_impl +def auto_functionalized_func(ctx, _mutable_op, **kwargs): + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + with ctx.redispatch_to_next(): + result = auto_functionalized(_mutable_op, **unwrapped_kwargs) + return ctx.wrap_tensors(result) + + +# auto_functionalized_v2 functions +@auto_functionalized_v2.py_impl(DispatchKey.CompositeExplicitAutograd) +def auto_functionalized_v2_dense( + _mutable_op: OpOverload, + _only_clone_these_bases: Optional[Tuple[int, ...]] = None, + **kwargs: Any, +) -> Tuple[Any, Tuple[Tensor, ...]]: + all_bases: List[Tensor] = kwargs.pop("_all_bases", []) + mutable_args_names, mutable_args_types = get_mutable_args(_mutable_op) + args_view_info = read_view_information_from_args( + mutable_args_names, mutable_args_types, kwargs, all_bases + ) + + if _only_clone_these_bases is None: + _only_clone_these_bases = tuple(range(len(all_bases))) + + def maybe_copy(i, t): + if t is None: + return None + if i in _only_clone_these_bases: + return clone_preserve_strides(t) + else: + return t + + all_bases_new = [maybe_copy(i, t) for i, t in enumerate(all_bases)] + + # 
create new args + new_kwargs = dict(**kwargs) + + # re-generate all inputs from all_bases_new using args_view_info and add them to new_kwargs. + for arg_name in mutable_args_names: + if args_view_info[arg_name] is None: + new_kwargs[arg_name] = None + elif isinstance(args_view_info[arg_name], list): + new_kwargs[arg_name] = [] + for i, elem in enumerate(args_view_info[arg_name]): + if elem is None: + new_kwargs[arg_name].append(None) + else: + view_info = args_view_info[arg_name][i] + new_kwargs[arg_name].append( + view_info.regenerate_view(all_bases_new) + ) + else: + new_kwargs[arg_name] = args_view_info[arg_name].regenerate_view( + all_bases_new + ) + + out = _mutable_op(**new_kwargs) + + if isinstance(out, tuple): + return (*out, *all_bases_new) # type: ignore[return-value] + else: + return (out, *all_bases_new) # type: ignore[return-value] + + +@auto_functionalized_v2.py_impl(FakeTensorMode) +def auto_functionalized_v2_fake( + mode, + _mutable_op: OpOverload, + **kwargs: Dict[str, Any], +) -> Tuple[Any, Tuple[Tensor, ...]]: + with mode: + result = auto_functionalized_v2_dense(_mutable_op, **kwargs) + return result + + +@auto_functionalized_v2.py_impl(ProxyTorchDispatchMode) +def auto_functionalized_v2_proxy( + mode, + _mutable_op: OpOverload, + **kwargs: Dict[str, Any], +) -> Tuple[Any, Tuple[Tensor, ...]]: + with disable_proxy_modes_tracing(): + out = auto_functionalized_v2(_mutable_op, **kwargs) + + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) + out_proxy = mode.tracer.create_proxy( + "call_function", + auto_functionalized_v2, + (_mutable_op,), + proxy_kwargs, + ) + result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + return result + + +@auto_functionalized_v2.py_functionalize_impl +def auto_functionalized_v2_func(ctx, _mutable_op, **kwargs): + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + with ctx.redispatch_to_next(): + result = auto_functionalized_v2(_mutable_op, **unwrapped_kwargs) + return ctx.wrap_tensors(result) diff --git a/lib/python3.10/site-packages/torch/_higher_order_ops/cond.py b/lib/python3.10/site-packages/torch/_higher_order_ops/cond.py new file mode 100644 index 0000000000000000000000000000000000000000..dee400d76f5964fc330acfb16459700cba6679e7 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_higher_order_ops/cond.py @@ -0,0 +1,521 @@ +# mypy: allow-untyped-defs +import contextlib +import logging + +import torch +import torch._subclasses.functional_tensor +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._C._functorch import ( + _add_batch_dim, + get_unwrapped, + is_batchedtensor, + maybe_get_bdim, +) +from torch._dispatch.python import suspend_functionalization +from torch._functorch.utils import exposed_in +from torch._guards import detect_fake_mode +from torch._higher_order_ops.utils import ( + _has_potential_branch_input_alias, + _has_potential_branch_input_mutation, + _maybe_run_with_interpreter, + _set_compilation_env, + reenter_make_fx, + unique_graph_id, + UnsupportedAliasMutationException, +) +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch._subclasses.functional_tensor import disable_functional_mode +from torch.fx.experimental.proxy_tensor import ( + _temp_remove_pre_dispatch_torch_function_mode, + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from torch.utils._python_dispatch import 
_get_current_dispatch_mode + +from .utils import _from_fun, create_fw_bw_graph + + +log = logging.getLogger(__name__) + +""" +We're going to define a `cond_op` operation. +In order to do this, we need implementations for each of the dispatch keys. +""" + + +class CondOp(HigherOrderOperator): + def __init__(self): + super().__init__("cond") + + def __call__(self, pred, true_fn, false_fn, operands): + return super().__call__(pred, true_fn, false_fn, operands) + + +cond_op = CondOp() + + +@exposed_in("torch") +def cond(pred, true_fn, false_fn, operands): + r""" + Conditionally applies `true_fn` or `false_fn`. + + .. warning:: + `torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and + doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch. + Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype + + `cond` is structured control flow operator. That is, it is like a Python if-statement, + but has restrictions on `true_fn`, `false_fn`, and `operands` that enable it to be + capturable using torch.compile and torch.export. + + Assuming the constraints on `cond`'s arguments are met, `cond` is equivalent to the following:: + + def cond(pred, true_branch, false_branch, operands): + if pred: + return true_branch(*operands) + else: + return false_branch(*operands) + + Args: + pred (Union[bool, torch.Tensor]): A boolean expression or a tensor with one element, + indicating which branch function to apply. + + true_fn (Callable): A callable function (a -> b) that is within the + scope that is being traced. + + false_fn (Callable): A callable function (a -> b) that is within the + scope that is being traced. The true branch and false branch must + have consistent input and outputs, meaning the inputs have to be + the same, and the outputs have to be the same type and shape. + + operands (Tuple of possibly nested dict/list/tuple of torch.Tensor): A tuple of inputs to the true/false functions. + + Example:: + + def true_fn(x: torch.Tensor): + return x.cos() + def false_fn(x: torch.Tensor): + return x.sin() + return cond(x.shape[0] > 4, true_fn, false_fn, (x,)) + + Restrictions: + - The conditional statement (aka `pred`) must meet one of the following constraints: + + - It's a `torch.Tensor` with only one element, and torch.bool dtype + + - It's a boolean expression, e.g. `x.shape[0] > 10` or `x.dim() > 1 and x.shape[1] > 10` + + - The branch function (aka `true_fn`/`false_fn`) must meet all of the following constraints: + + - The function signature must match with operands. + + - The function must return a tensor with the same metadata, e.g. shape, + dtype, etc. + + - The function cannot have in-place mutations on inputs or global variables. + (Note: in-place tensor operations such as `add_` for intermediate results + are allowed in a branch) + + .. warning:: + Temporal Limitations: + + - The **output** of branches must be a **single Tensor**. Pytree of tensors will be supported in the future. + + """ + if torch.compiler.is_dynamo_compiling(): + return cond_op(pred, true_fn, false_fn, operands) + + if isinstance(pred, (bool, int, float)): + log.warning( + "Pred is a Python constant. When used with torch.cond, it executes only one of the branches." + " If you want torch.cond to perserve two branches, please make the predicate a boolean tensor or a SymBool." 
+ ) + if pred: + return true_fn(*operands) + else: + return false_fn(*operands) + + def _validate_input(pred, true_fn, false_fn, operands): + if not isinstance(pred, (bool, torch.Tensor, torch.SymBool)): + raise RuntimeError(f"Expected pred to be bool or tensor, but got {pred}.") + + if isinstance(pred, torch.Tensor) and pred.numel() != 1: + raise RuntimeError( + f"Expected pred to be bool or single-element tensor, but got {pred}." + ) + + if not callable(true_fn) or not callable(false_fn): + raise RuntimeError("Expect both branches to be callable.") + + if not isinstance(operands, (tuple, list)) or pytree.tree_any( + lambda t: not isinstance(t, torch.Tensor), operands + ): + raise RuntimeError( + "Expect operands to be a tuple of possibly nested dict/list/tuple that only " + f"consists of tensor leaves, but got {operands}." + ) + + _validate_input(pred, true_fn, false_fn, operands) + + if not torch._dynamo.is_dynamo_supported(): + raise RuntimeError("torch.cond requires dynamo support.") + + # Dynamo is expecting a callable with "__code__" attribute. + # We cannot directly pass cond_op to it. So we wrap it in a dummy function. + def _cond_op_wrapper(*args, **kwargs): + return cond_op(*args, **kwargs) + + with _set_compilation_env(): + with torch._dynamo.utils.disable_cache_limit(): + with _temp_remove_pre_dispatch_torch_function_mode(): + return torch.compile(_cond_op_wrapper, backend="eager", fullgraph=True)( + pred, true_fn, false_fn, operands + ) + + +def create_fw_bw_graph_branches(true_fn, false_fn, *operands): + # See Note [HOP create fw_bw graph] in create_fw_bw_graph in utils.py + + with suspend_functionalization(), disable_functional_mode(): + with disable_proxy_modes_tracing(): + fw_inputs = pytree.tree_map(_from_fun, operands) + + fw_outputs_true = pytree.tree_map(_from_fun, true_fn(*fw_inputs)) + if any( + not isinstance(out, torch.Tensor) + for out in fw_outputs_true + if out is not None + ): + raise RuntimeError( + "Expect outputs of true_fn to only contain tensors or None. " + f"Got types {[type(out) for out in fw_outputs_true]}." + ) + fw_outputs_false = pytree.tree_map(_from_fun, false_fn(*fw_inputs)) + if any( + not isinstance(out, torch.Tensor) + for out in fw_outputs_false + if out is not None + ): + raise RuntimeError( + "Expect outputs of false_fn to only contain tensors or None. " + f"Got types {[type(out) for out in fw_outputs_false]}."
+ ) + + # TODO: There is a major issue that the create_fw_bw in the higher_order_op is invoked twice: + # Once in the forward path (as it should) and once in the backward path, where it shouldn't be called + # If we can get rid of the second invokation, it would simplify this function + fw_true_graph, joint_true_graph = create_fw_bw_graph( + true_fn, False, fw_inputs, fw_outputs_true + ) + fw_false_graph, joint_false_graph = create_fw_bw_graph( + false_fn, False, fw_inputs, fw_outputs_false + ) + + return fw_true_graph, fw_false_graph, joint_true_graph, joint_false_graph + + +def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): + assert isinstance( + operands, (list, tuple) + ), "Cond operands must be a list or tuple of tensors" + assert all( + isinstance(o, torch.Tensor) for o in operands + ), "Cond operands must be a list of tensors" + + true_graph = reenter_make_fx(true_fn)(*operands) + false_graph = reenter_make_fx(false_fn)(*operands) + + true_outs = [] + false_outs = [] + for node in true_graph.graph.nodes: + if node.op == "output": + true_outs.extend(node.args) + + for node in false_graph.graph.nodes: + if node.op == "output": + false_outs.extend(node.args) + + flat_true_outs = pytree.arg_tree_leaves(*true_outs) + flat_false_outs = pytree.arg_tree_leaves(*false_outs) + if len(flat_true_outs) != len(flat_false_outs): + raise torch._dynamo.exc.CondOpArgsMismatchError( + f"Expected to return same number of outputs but got:" + f"\n true branch returns {len(flat_true_outs)} item(s)" + f"\n false branch returns {len(flat_false_outs)} item(s)" + ) + + for i in range(0, len(flat_true_outs)): + true_out = flat_true_outs[i] + false_out = flat_false_outs[i] + + # Note that we need skip the check for requires_grad because we're after + # after autograd key during tracing, so the rquires_grad attribute of the tensors + # are no longer. See Note [invariants for node meta 'val'] + def _same_meta_except_requires_grad(true_out, false_out): + if true_out is None and false_out is None: + return True + elif true_out is None or false_out is None: + # Consider the following case: + # def true_fn(x, y): + # return x * y + # + # def false_fn(x, y): + # return x.sin() + # + # We'll get the following graphs for backward: + # def backward_true_fn(x, y, grad_out): + # return grad_out * y, grad_out * x + # + # def backward_false_fn(x, y, grad_out): + # retrun grad_out, None + # + # This suggests that when we make_fx into the backward graph, + # the output graph would produce outputs with metadata, this is undesirable. + # + # Ideally, we should provide an optional type to indicate that one of the branches might + # return None. But we'll just let it pass for now and let downstream/runtime handle. + # + # Note that this corner case should **only** happen when user want to trace backward graph because + # if it's foward, dynamo will error. 
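# --- Illustrative example (editor's sketch, not part of the original patch) ---
# A minimal eager use of torch.cond that satisfies the restrictions documented
# above: the predicate is a single-element bool tensor and both branches return
# tensors with identical metadata, which is what trace_cond and the fake-tensor
# check below verify.
import torch

def true_fn(x: torch.Tensor) -> torch.Tensor:
    return x.sin()

def false_fn(x: torch.Tensor) -> torch.Tensor:
    return x.cos()

x = torch.randn(4)
pred = x.sum() > 0  # 0-dim bool tensor
out = torch.cond(pred, true_fn, false_fn, (x,))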
+ return True + true_meta = true_out.meta.get("tensor_meta", None) + false_meta = false_out.meta.get("tensor_meta", None) + return ( + true_meta.shape == false_meta.shape + and true_meta.dtype == false_meta.dtype + and true_meta.stride == false_meta.stride + ) + + if not _same_meta_except_requires_grad(true_out, false_out): + raise torch._dynamo.exc.CondOpArgsMismatchError( + f"Expected each tensor to have same metadata but got:" + f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}" + f"\n {false_fn.__name__} returns {false_out.meta['tensor_meta']}" + ) + + i, true_name = unique_graph_id(proxy_mode, prefix="true_graph") + + false_name = f"false_graph_{i}" + assert not hasattr(proxy_mode.tracer.root, false_name) + + proxy_mode.tracer.root.register_module(true_name, true_graph) + proxy_mode.tracer.root.register_module(false_name, false_graph) + + args = (pred, true_graph, false_graph, operands) + + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) + + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", func_overload, proxy_args, {} + ) + + # At this point, we're *guaranteed* that whether an output came from the + # true or false branch is indistinguishable. So, as this is just for tracing + # purposes, choose the true branch. + + # TODO: the unbacked symbol allocations MUST NOT leak out, if you want to + # support this we need to arrange for the reenter_make_fx unbacked SymInts + # to be used, AND we need to arrange for some sort of unification between + # the two branches (but not really unification; e.g., if one branch + # returns [u0] and the other returns [5] this is OK but you MUST NOT + # conclude the result is 5. Also if one branch returns [3] and another + # branch returns [5] you can make it work by immediately allocating a new + # unbacked SymInt here). + ignore_fresh_unbacked = contextlib.nullcontext() + if (fake_mode := detect_fake_mode()) and fake_mode.shape_env: + ignore_fresh_unbacked = fake_mode.shape_env.ignore_fresh_unbacked_symbols() + + # TODO: Uhh.... it shouldn't matter, but changing this to true_fn results in + # a FakeTensorMode error : + # `Current active mode not registered` + # TODO Sometimes the operands are not completely FakeTensor, something seems went wrong in + # dynamo? Because of that it runs real computation sometimes and re-triggering downstream dispatch keys. 
+ with ignore_fresh_unbacked: + out = false_fn(*operands) + + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@cond_op.py_impl(DispatchKey.CompositeExplicitAutograd) +def cond_op_dense(pred, true_fn, false_fn, operands): + mode = _get_current_dispatch_mode() + assert mode is None, "Mode should never be enabled for CPU/CUDA key" + if pred: + return true_fn(*operands) + else: + return false_fn(*operands) + + +class CondAutogradOp(torch.autograd.Function): + @staticmethod + def forward( + ctx, + pred, + fw_true_graph, + fw_false_graph, + joint_true_graph, + joint_false_graph, + *operands, + ): + ctx._pred = pred + ctx._joint_true_graph = joint_true_graph + ctx._joint_false_graph = joint_false_graph + ctx.save_for_backward(*operands) + + with torch._C._AutoDispatchBelowAutograd(): + return cond_op(pred, fw_true_graph, fw_false_graph, operands) + + @staticmethod + def backward(ctx, *flat_grads): + operands = ctx.saved_tensors + + grads = cond_op( + ctx._pred, + ctx._joint_true_graph, + ctx._joint_false_graph, + flat_grads + operands, + ) + return None, None, None, None, None, *grads + + +@cond_op.py_impl(DispatchKey.Autograd) +def cond_autograd(pred, true_fn, false_fn, operands): + # A shortcut for the case where all inputs don't require gradient, + # we skip tracing the forward and backward graph. + if pytree.tree_all_only( + torch.Tensor, + lambda t: not t.requires_grad, # type: ignore[union-attr] + (pred, operands), + ): + with torch._C._AutoDispatchBelowAutograd(): + return cond_op(pred, true_fn, false_fn, operands) + + ( + fw_true_graph, + fw_false_graph, + joint_true_graph, + joint_false_graph, + ) = create_fw_bw_graph_branches(true_fn, false_fn, *operands) + flat_out = CondAutogradOp.apply( + pred, + fw_true_graph, + fw_false_graph, + joint_true_graph, + joint_false_graph, + *operands, + ) + return flat_out + + +@cond_op.py_impl(ProxyTorchDispatchMode) +def inner(mode, pred, true_fn, false_fn, operands): + return trace_cond(mode, cond_op, pred, true_fn, false_fn, operands) + + +@cond_op.py_impl(FakeTensorMode) +def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands): + # Ignore here, because if you've gotten here but you're not manually + # tracing the inner graphs, that means that you intend to reuse the graph + # directly. Which means the old unbacked symbol bindings are appropriate. + # This strategy will not work if unbacked symbols can escape. 
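# --- Illustrative example (editor's sketch, not part of the original patch) ---
# One way to see trace_cond's output: export a module that calls torch.cond and
# inspect the graph. The branch GraphModules should appear as submodules named
# with the "true_graph"/"false_graph" prefixes registered above (exact names may
# vary by PyTorch version).
import torch

class CondModule(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return x.cos()

        return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))

ep = torch.export.export(CondModule(), (torch.randn(3),))
print(ep.graph_module.code)  # should contain a higher-order cond call over the two submodules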
+ ignore_fresh_unbacked = contextlib.nullcontext() + if mode.shape_env: + ignore_fresh_unbacked = mode.shape_env.ignore_fresh_unbacked_symbols() + + with mode, ignore_fresh_unbacked: + true_outs = true_fn(*operands) + flat_true_outs = pytree.tree_leaves(true_outs) + flat_false_outs = pytree.tree_leaves(false_fn(*operands)) + if len(flat_true_outs) != len(flat_false_outs): + raise RuntimeError("Unmatched number of outputs from cond() branches.") + + for true_out, false_out in zip(flat_true_outs, flat_false_outs): + true_meta = _extract_tensor_metadata(true_out) + false_meta = _extract_tensor_metadata(false_out) + if true_meta != false_meta: + raise torch._dynamo.exc.CondOpArgsMismatchError( + f"Expected each tensor to have same metadata but got:" + f"\n {true_fn.__name__} returns {true_meta}" + f"\n {false_fn.__name__} returns {false_meta}" + ) + return true_outs + + +@cond_op.py_functionalize_impl +def cond_func(ctx, pred, true_fn, false_fn, inputs): + unwrapped_inputs = ctx.unwrap_tensors(inputs) + unwrapped_pred = ctx.unwrap_tensors(pred) + with ctx.redispatch_to_next() as m: + functional_true = ctx.functionalize(_maybe_run_with_interpreter(true_fn)) + functional_false = ctx.functionalize(_maybe_run_with_interpreter(false_fn)) + pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch + for branch in [functional_true, functional_false]: + if _has_potential_branch_input_mutation( + branch, unwrapped_inputs, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException( + "One of torch.cond branch might be modifying the input!" + ) + for branch in [true_fn, false_fn]: + if _has_potential_branch_input_alias( + branch, unwrapped_inputs, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException( + "One of torch.cond branch might be aliasing the input!" + ) + + cond_return = cond_op( + unwrapped_pred, functional_true, functional_false, unwrapped_inputs + ) + return ctx.wrap_tensors(cond_return) + + +@cond_op.py_impl(torch._C._functorch.TransformType.Vmap) +def cond_batch_rule(interpreter, pred, true_fn, false_fn, inputs): + assert isinstance( + inputs, (list, tuple) + ), "Cond inputs must be a list or tuple of tensors" + assert all( + isinstance(i, torch.Tensor) for i in inputs + ), "Cond inputs must be a list of tensors" + + pred_ = get_unwrapped(pred) if is_batchedtensor(pred) else pred + + # unbatched tensors are not vmapped + tensors, in_dims = zip( + *[ + (get_unwrapped(t), maybe_get_bdim(t)) if is_batchedtensor(t) else (t, None) + for t in inputs + ] + ) + + if is_batchedtensor(pred): + # prepend "pred" and vmap everything + tensors = (pred_,) + tensors + in_dims = (0,) + in_dims + + def fn(p, *args): + t = true_fn(*args) + f = false_fn(*args) + return torch.where(p, t[0], f[0]) + + with interpreter.lower(): + result = torch.vmap(fn, in_dims=in_dims)(*tensors) + + else: + # predicate is known at this stage and it is a boolean expression or a + # tensor with one element. 
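# --- Illustrative example (editor's sketch, not part of the original patch) ---
# When the predicate itself is batched, the vmap rule above evaluates both
# branches and selects elementwise with torch.where. This sketch reproduces
# that semantics directly, without going through vmap:
import torch

def true_fn(x):
    return x + 1

def false_fn(x):
    return x - 1

xs = torch.randn(5, 3)
preds = xs.sum(dim=-1, keepdim=True) > 0  # one predicate per batch element

expected = torch.where(preds, true_fn(xs), false_fn(xs))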
+ true_fn = torch.vmap(true_fn, in_dims=in_dims) + false_fn = torch.vmap(false_fn, in_dims=in_dims) + + with interpreter.lower(): + result = cond_op(pred, true_fn, false_fn, tensors) + + if not isinstance(result, tuple): + result = (result,) + lvl = interpreter.level() + return tuple([_add_batch_dim(r, 0, lvl) for r in result]) diff --git a/lib/python3.10/site-packages/torch/_higher_order_ops/effects.py b/lib/python3.10/site-packages/torch/_higher_order_ops/effects.py new file mode 100644 index 0000000000000000000000000000000000000000..bb3d93bc69eb6b89ae2a3694bd28e24dee38ecbe --- /dev/null +++ b/lib/python3.10/site-packages/torch/_higher_order_ops/effects.py @@ -0,0 +1,289 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +from enum import Enum +from typing import Any, Dict, Optional, Tuple, Union +from weakref import WeakKeyDictionary + +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._higher_order_ops.torchbind import call_torchbind +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) + + +class _EffectType(Enum): + ORDERED = "Ordered" + + +OpType = Union[torch._ops.HigherOrderOperator, torch._ops.OpOverload] + + +SIDE_EFFECTS: "WeakKeyDictionary[OpType, _EffectType]" = WeakKeyDictionary( + { + torch.ops.aten._print.default: _EffectType.ORDERED, + call_torchbind: _EffectType.ORDERED, + } +) + + +def _register_effectful_op(op: OpType, effect: _EffectType): + assert isinstance( + op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator) + ) and not has_aliasing(op) + if op in SIDE_EFFECTS and SIDE_EFFECTS[op] != effect: + raise RuntimeError( + f"Already registered effect type {SIDE_EFFECTS[op]} to op {op}, " + f"trying to register a different effect type {effect}." + ) + SIDE_EFFECTS[op] = effect + + +def _deregister_effectful_op(op: OpType): + if op not in SIDE_EFFECTS: + raise RuntimeError(f"Op {op} is not registered as effectful") + + del SIDE_EFFECTS[op] + + +class WithEffects(HigherOrderOperator): + """ + with_effects(token, op, args, kwargs) -> (new_token, op_results) + + This HOP helps ensure ordering between side effectful ops like prints or ops + using torchbind objects. This is needed to ensure a traced graph from + AOTAutograd is functional so that future optimization passes do not reorder + these operators. This is done through threading "effect tokens" through the + graph to enforce data dependence between side effectful ops. + + The tokens are basically dummy values (torch.tensor([])). We create a token + per "effect type", which are enumerated in the _EffectType enum. 
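# --- Illustrative example (editor's sketch, not part of the original patch) ---
# Sketch of registering an ordered effect for a hypothetical custom op
# `mylib::log_value` using the helpers defined earlier in this file, so that
# with_effects threads a token through every call and passes cannot reorder or
# eliminate it:
import torch
from torch.library import Library
from torch._higher_order_ops.effects import _EffectType, _register_effectful_op

lib = Library("mylib", "FRAGMENT")  # hypothetical namespace
lib.define("log_value(Tensor x) -> ()")

def log_value_impl(x):
    print(float(x.sum()))

lib.impl("log_value", log_value_impl, "CompositeExplicitAutograd")

_register_effectful_op(torch.ops.mylib.log_value.default, _EffectType.ORDERED)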
+ """ + + def __init__(self) -> None: + super().__init__("with_effects") + + def __call__( + self, + token, + op: OpType, + *args: Tuple[Any, ...], + **kwargs: Dict[str, Any], + ) -> Tuple[Any, ...]: + assert isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload)) + assert not has_aliasing(op), "Ops with aliasing is not supported" + assert has_effects(op, args, kwargs) + assert isinstance(kwargs, dict) + return super().__call__(token, op, *args, **kwargs) + + +with_effects = WithEffects() + + +def has_aliasing(op: OpType): + # NOT FOR PUBLIC USE + if isinstance(op, torch._ops.HigherOrderOperator): + return op not in SIDE_EFFECTS + + for arg in op._schema.arguments: + if arg.alias_info is not None: + return True + for arg in op._schema.returns: + if arg.alias_info is not None: + return True + return False + + +def has_effects(op, args, kwargs) -> bool: + # Skip over the profiler's RecordFunction as they should not show up in the graph + _skip_ops = {torch.ops.profiler._record_function_exit._RecordFunction} + if op in _skip_ops: + return False + + return ( + isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload)) + and not has_aliasing(op) + and get_effect_key(op, args, kwargs) is not None + ) + + +def get_effect_key(op, args, kwargs) -> Optional[_EffectType]: + if op in SIDE_EFFECTS: + return SIDE_EFFECTS[op] + + for arg in args: + if isinstance(arg, torch.ScriptObject): + # Add it to the table so that next time we see the same op we don't + # have to parse through the args again + SIDE_EFFECTS[op] = _EffectType.ORDERED + return _EffectType.ORDERED + + return None + + +def new_token_tensor() -> torch.Tensor: + # Use dtype bool to not affect Inductor dtype promotions + return torch.tensor([], dtype=torch.bool) + + +@with_effects.py_impl(DispatchKey.CompositeExplicitAutograd) +def with_effects_dense( + token: torch.Tensor, + op: torch._ops.OpOverload, + *args: Tuple[Any, ...], + **kwargs: Dict[str, Any], +) -> Tuple[torch.Tensor, ...]: + out = op(*args, **kwargs) + new_token = new_token_tensor() + if isinstance(out, tuple): + return (new_token, *out) + return (new_token, out) + + +@with_effects.py_impl(FakeTensorMode) +def with_effects_fake( + mode, + token: torch.Tensor, + op: torch._ops.OpOverload, + *args: Tuple[Any, ...], + **kwargs: Dict[str, Any], +) -> Tuple[torch.Tensor, ...]: + with mode: + result = with_effects_dense(token, op, *args, **kwargs) + return result + + +@with_effects.py_impl(ProxyTorchDispatchMode) +def with_effects_proxy( + mode, + token: torch.Tensor, + op: torch._ops.OpOverload, + *args: Tuple[Any, ...], + **kwargs: Dict[str, Any], +) -> Tuple[torch.Tensor, ...]: + with disable_proxy_modes_tracing(): + out = with_effects(token, op, *args, **kwargs) + + proxy_token = mode.tracer.unwrap_proxy(token) + proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args) + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) + + from torch.fx.node import has_side_effect + + # To avoid the being DCEed by graph.eliminate_dead_code if they. + # don't have output or their outputs are not used. 
+ has_side_effect(op) + + out_proxy = mode.tracer.create_proxy( + "call_function", + with_effects, + (proxy_token, op, *proxy_args), + proxy_kwargs, + ) + result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + return result + + +with_effects.fallthrough(DispatchKey.AutogradCPU) +with_effects.fallthrough(DispatchKey.AutogradCUDA) + + +def _get_schema(op, args) -> torch.FunctionSchema: + if isinstance(op, torch._ops.OpOverload): + return op._schema + elif op == call_torchbind: + return getattr(args[0], args[1]).schema + else: + raise RuntimeError(f"Unable to get schema for op {op}") + + +def handle_effects( + allow_token_discovery: bool, + tokens: Dict[_EffectType, torch.Tensor], + op: OpType, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], +) -> Any: + """ + Args: + allow_token_discovery: Whether or not we are discovering tokens. If this + is true, we will create a token for every side effect type seen that + does not have a token assigned yet. If this is false, the tokens + should've all been created ahead of time, so we will error if there is + no token mapping to every effect type. + + tokens: Map of effect type to tokens. This is to chain operators of the + same effects together so that they do not get reordered in later + optimization passes. + """ + + # Get a token. We can't do `tokens.get(op, torch.tensor([]))` because + # this will create an empty tensor during proxy mode tracing if the token + # doesn't exist. But the tokens should always exist during proxy mode tracing. + key = get_effect_key(op, args, kwargs) + assert key is not None + if key not in tokens: + assert ( + allow_token_discovery + ), f"Could not find a token for effect {key} which came from the function {op}" + proxy_tensor_mode = torch._C._get_dispatch_mode( + torch._C._TorchDispatchModeKey.PROXY + ) + if proxy_tensor_mode is not None: + # If we discovered a new token during tracing, we are in backward. + # Then we patch the graph, adding additional tangents_token as input to the joint graph. + tracer = proxy_tensor_mode.tracer + + from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + track_tensor_tree, + ) + + with disable_proxy_modes_tracing(): + token_tensor = new_token_tensor() + + token_proxy = proxy_tensor_mode.tracer.create_proxy( + "placeholder", "tangents_token", (), {}, name="tangents_token" + ) + track_tensor_tree(token_tensor, token_proxy, constant=None, tracer=tracer) + + tokens[key] = token_tensor + else: + tokens[key] = new_token_tensor() + + token = tokens[key] + + from torch._subclasses.functional_tensor import PythonFunctionalizeAPI + + ctx = PythonFunctionalizeAPI() + + unwrapped_token = ctx.unwrap_tensors([token])[0] # type: ignore[arg-type] + unwrapped_args = ctx.unwrap_tensors(args) # type: ignore[arg-type] + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) # type: ignore[arg-type] + with ctx.redispatch_to_next(): + (new_token, *unwrapped_outs) = with_effects( + unwrapped_token, op, *unwrapped_args, **unwrapped_kwargs # type: ignore[arg-type] + ) + + schema = _get_schema(op, unwrapped_args) + if len(schema.returns) == 0: + assert unwrapped_outs[0] is None + unwrapped_outs = None # type: ignore[assignment] + elif len(schema.returns) == 1: + assert len(unwrapped_outs) == 1 + unwrapped_outs = unwrapped_outs[0] + else: + assert len(unwrapped_outs) == len(schema.returns) + + # Add the newly created token into the tokens map for a following call to + # use this token. 
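# --- Illustrative example (editor's sketch, not part of the original patch) ---
# The token threading handled above is easiest to observe by exporting a module
# that calls an effectful op; the printed graph should contain with_effects
# calls chaining a token through torch.ops.aten._print:
import torch

class PrintModule(torch.nn.Module):
    def forward(self, x):
        torch.ops.aten._print("forward called")
        return x + 1

ep = torch.export.export(PrintModule(), (torch.randn(3),))
print(ep.graph_module.code)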
+ wrapped_token = ctx.wrap_tensors(new_token) + assert isinstance(wrapped_token, torch.Tensor) + tokens[key] = wrapped_token + + return ctx.wrap_tensors(unwrapped_outs) # type: ignore[arg-type] diff --git a/lib/python3.10/site-packages/torch/_higher_order_ops/executorch_call_delegate.py b/lib/python3.10/site-packages/torch/_higher_order_ops/executorch_call_delegate.py new file mode 100644 index 0000000000000000000000000000000000000000..a6ee5205ff4ebbde1240929727ec9e47b0a2433a --- /dev/null +++ b/lib/python3.10/site-packages/torch/_higher_order_ops/executorch_call_delegate.py @@ -0,0 +1,175 @@ +# mypy: allow-untyped-defs + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from __future__ import annotations + +from typing import Any, cast + +import torch +import torch.utils._pytree as pytree +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + get_proxy_slot, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.utils._pytree import tree_flatten + + +class ExecutorchCallDelegate(HigherOrderOperator): + def __init__(self): + super().__init__("executorch_call_delegate") + + def __call__(self, lowered_module, *args): + return super().__call__(lowered_module, *args) + + +executorch_call_delegate = ExecutorchCallDelegate() +executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher) +executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonTLSSnapshot) +executorch_call_delegate.fallthrough(torch._C.DispatchKey.ADInplaceOrView) +executorch_call_delegate.fallthrough(torch._C.DispatchKey.AutocastCPU) + +LOWERED_BACKEND_MODULE_TYPE = "LoweredBackendModule" + + +# pyre-ignore +def trace_call_delegate(proxy_mode, func_overload, lowered_module, *args): + # pyre-ignore + def _unwrap_proxy(e): + if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)): + return e + return get_proxy_slot( + cast(torch.Tensor, e), proxy_mode.tracer, e, lambda e: e.proxy # type: ignore[attr-defined] + ) + + if not is_lowered_module(lowered_module): + raise ValueError( + "executorch_call_delegate()'s first argument must be a LoweredBackendModule" + ) + + with disable_proxy_modes_tracing(): + out = call_delegate_cpu(lowered_module, *args) + + get_lowered_module_name(proxy_mode.tracer.root, lowered_module) + + node_args = (lowered_module, *args) + proxy_args = pytree.tree_map(_unwrap_proxy, node_args) + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", func_overload, proxy_args, {}, name="executorch_call_delegate" + ) + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@executorch_call_delegate.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd) +# pyre-ignore +def call_delegate_cpu(lowered_module, *args): + # FX creates this immutable_dict/list concept. Get rid of this. 
+ map_types = { + torch.fx.immutable_collections.immutable_dict: dict, + torch.fx.immutable_collections.immutable_list: list, + } + new_args = pytree.tree_map_only( + tuple(map_types.keys()), + lambda a: map_types[type(a)](a), + args, + lambda a: isinstance(a, tuple(map_types.keys())), + ) + return lowered_module.original_module.module()(*new_args) + + +@executorch_call_delegate.py_impl(torch._C.DispatchKey.Autograd) +# pyre-ignore +def call_delegate_autograd(lowered_module, *args): + # TODO: support autograd + flat_operands, _ = tree_flatten([lowered_module, *args]) + requires_grad = any( + f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor) + ) + + with torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU) + ): + res = executorch_call_delegate(lowered_module, *args) + + if requires_grad: + # Create aliases of the output that has requires_grad=True. We need + # at least one of the inputs to err_fn to require grad so that the + # output will have a grad_fn. + + # pyre-ignore + def fake_requires_grad(var): + if var is not None: + var = var.detach() + if torch.is_floating_point(var) or torch.is_complex(var): + var.requires_grad = True + return var + + return pytree.tree_map_only(torch.Tensor, fake_requires_grad, res) + + return res + + +@executorch_call_delegate.py_impl(ProxyTorchDispatchMode) +# pyre-ignore +def call_delegate_proxy_torch_dispatch_mode(mode, lowered_module, *args): + res = trace_call_delegate(mode, executorch_call_delegate, lowered_module, *args) + return res + + +@executorch_call_delegate.py_impl(FakeTensorMode) +# pyre-ignore +def call_delegate_fake_tensor_mode(mode, lowered_module, *args): + with mode: + return call_delegate_cpu(lowered_module, *args) + + +@executorch_call_delegate.py_functionalize_impl +# pyre-ignore +def call_delegate_functionalize(ctx, lowered_module, *args): + unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args) + with ctx.redispatch_to_next(): + res = executorch_call_delegate(lowered_module, *unwrapped_args) + return ctx.wrap_tensors(res) + + +# pyre-ignore: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.Pyre +def is_lowered_module(obj: Any) -> bool: + """ + This function is added to avoid using isinstance(obj, + LoweredBackendModule) as it will import LoweredBackendModule, which may + cause a circular import. + """ + return type(obj).__name__ == LOWERED_BACKEND_MODULE_TYPE + + +def get_lowered_module_name( + root: torch.nn.Module, + # pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type. + lowered_module: LOWERED_BACKEND_MODULE_TYPE, # type: ignore[valid-type] +) -> str: + """ + Adds the given lowered_module into the given root module and returns the + name of the module added. 
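# --- Illustrative example (editor's sketch, not part of the original patch) ---
# get_lowered_module_name simply probes lowered_module_0, lowered_module_1, ...
# until it finds a free attribute on the root module; a plain nn.Module stands
# in for the LoweredBackendModule here:
import torch

root = torch.nn.Module()
submodule = torch.nn.Linear(2, 2)  # stand-in for a lowered module

i = 0
while hasattr(root, f"lowered_module_{i}"):
    i += 1
root.add_module(f"lowered_module_{i}", submodule)  # registered as lowered_module_0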
+ """ + # Find a qualifying name for the lowered submodule + qualname = None + i = 0 + while True: + qualname = f"lowered_module_{i}" + if not hasattr(root, qualname): + break + i += 1 + assert qualname is not None + + root.add_module(qualname, lowered_module) + return qualname diff --git a/lib/python3.10/site-packages/torch/_higher_order_ops/flex_attention.py b/lib/python3.10/site-packages/torch/_higher_order_ops/flex_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..2fc62594c50941f8fdf2edef58b059d127e1ad97 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_higher_order_ops/flex_attention.py @@ -0,0 +1,1054 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import math +from typing import Any, Callable, Dict, Sequence, Tuple, Union + +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._higher_order_ops.utils import ( + _has_potential_branch_input_mutation, + autograd_not_implemented, + reenter_make_fx, + UnsupportedAliasMutationException, +) +from torch._ops import HigherOrderOperator +from torch._subclasses import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + make_fx, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.fx.graph_module import GraphModule +from torch.overrides import TorchFunctionMode + + +# Duplicate of _inductor/kernel/flex_attention.py to avoid circular import +def _construct_strides( + sizes: Sequence[int], + fill_order: Sequence[int], +) -> Sequence[int]: + """From a list of sizes and a fill order, construct the strides of the permuted tensor.""" + # Initialize strides + assert len(sizes) == len( + fill_order + ), "Length of sizes must match the length of the fill order" + strides = [0] * len(sizes) + + # Start with stride 1 for the innermost dimension + current_stride = 1 + + # Iterate through the fill order populating strides + for dim in fill_order: + strides[dim] = current_stride + current_stride *= sizes[dim] + + return strides + + +def _permute_strides(out: torch.Tensor, query_strides: Tuple[int, ...]) -> torch.Tensor: + """ + Create a new tensor with the same data and shape as the input, + but with strides permuted based on the input tensor's stride order. + + Args: + out (torch.Tensor): The output tensor of attention. + query_strides (List[int]): The stride order of the input query tensor + + Returns: + torch.Tensor: A new tensor with same shape and data as the input, + but with strides permuted based on the query tensor's stride order. + """ + from torch._inductor.ir import get_stride_order, stride_order2fill_order + + stride_order = get_stride_order(query_strides) + fill_order = stride_order2fill_order(stride_order) + assert out.storage_offset() == 0, "Only support storage_offset == 0" + out_strides = _construct_strides(out.shape, fill_order) + new_out = out.new_empty(out.shape).as_strided(out.shape, out_strides) + new_out.copy_(out) + return new_out + + +class TransformGetItemToIndex(TorchFunctionMode): + # This is needed since we want to support calling + # A[q_idx], where q_idx is a scalar tensor in score_mod. + # Today, when q_idx is a scalar tensor, we implicitly convert it to a python + # scalar and create a view. We do not want that behavior in this case, so we + # use this torchfunctionmode to override that behavior for score_mod + # wherever we're running it. 
+ def __torch_function__(self, func, types, args, kwargs=None): + if func == torch.Tensor.__getitem__: + index_args = pytree.tree_leaves(args[1]) + if all(isinstance(x, torch.Tensor) for x in index_args): + return torch.ops.aten.index(args[0], index_args) + return func(*args, **(kwargs or {})) + + +class FlexAttentionHOP(HigherOrderOperator): + def __init__(self) -> None: + super().__init__("flex_attention") + + def __call__( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + score_mod: Callable, + block_mask: Tuple, + scale: float, + kernel_options: Dict[str, Any], + score_mod_other_buffers: Tuple = (), + mask_mod_other_buffers: Tuple = (), + ) -> Tuple[torch.Tensor, torch.Tensor]: + if not all( + isinstance(buf, torch.Tensor) + for buf in score_mod_other_buffers + mask_mod_other_buffers + ): + raise RuntimeError("Other buffers must be tensors.") + return super().__call__( + query, + key, + value, + score_mod, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + + +flex_attention = FlexAttentionHOP() + + +class FlexAttentionBackwardHOP(HigherOrderOperator): + def __init__(self) -> None: + super().__init__("flex_attention_backward") + + def __call__( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + grad_out: torch.Tensor, + grad_logsumexp: torch.Tensor, + fw_graph: Union[Callable, GraphModule], + joint_graph: GraphModule, + block_mask: Tuple, + scale: float, + kernel_options: Dict[str, Any], + score_mod_other_buffers: Tuple = (), + mask_mod_other_buffers: Tuple = (), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if not all( + isinstance(buf, torch.Tensor) + for buf in score_mod_other_buffers + mask_mod_other_buffers + ): + raise RuntimeError("Other buffers must be tensors.") + return super().__call__( + query, + key, + value, + out, + logsumexp, + grad_out, + grad_logsumexp, + fw_graph, + joint_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + + +flex_attention_backward = FlexAttentionBackwardHOP() + + +def _math_attention_inner( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + score_mod: Callable, + block_mask: Tuple, + scale: float, + kernel_options: Dict[str, Any], + score_mod_other_buffers: Tuple = (), + mask_mod_other_buffers: Tuple = (), +) -> Tuple[torch.Tensor, torch.Tensor]: + working_precision = torch.float64 if query.dtype == torch.float64 else torch.float32 + + scores = (query @ key.transpose(-2, -1)).to(dtype=working_precision) + + b = torch.arange(0, scores.size(0), device=scores.device) + h = torch.arange(0, scores.size(1), device=scores.device) + m = torch.arange(0, scores.size(2), device=scores.device) + n = torch.arange(0, scores.size(3), device=scores.device) + + captured_buffers_in_dim = (None,) * len(score_mod_other_buffers) + from torch.nn.attention.flex_attention import _vmap_for_bhqkv + + # first input is score + score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,), suffix=captured_buffers_in_dim) + + mask_mod = block_mask[-1] + mask_mod_in_dim_buffers = (None,) * len(mask_mod_other_buffers) + mask_mod = _vmap_for_bhqkv(mask_mod, prefix=(), suffix=mask_mod_in_dim_buffers) + + with TransformGetItemToIndex(): + scores = (scores * scale).to(working_precision) + post_mod_scores = torch.where( + mask_mod(b, h, m, n, *mask_mod_other_buffers), + score_mod(scores, b, h, m, n, *score_mod_other_buffers), + torch.tensor(-float("inf"), 
dtype=working_precision, device=scores.device), + ) + + return scores, post_mod_scores + + +def math_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + score_mod: Callable, + block_mask: Tuple, + scale: float, + kernel_options: Dict[str, Any], + score_mod_other_buffers: Tuple = (), + mask_mod_other_buffers: Tuple = (), +) -> Tuple[torch.Tensor, torch.Tensor]: + """Eager implementation + + This implementation uses vmap to vectorize the score_mod function over the batch, head, m, and n dimensions. + We then apply the vectorized score_mod function to the scores matrix. Each wrap of vmap applies one of the + batch, head, m, or n dimensions. We need to apply vmap 4 times to vectorized over all 4 dimensions. + + Args: + query: The query tensor + key: The key tensor + value: The value tensor + score_mod: The score_mod function + other_buffers: Other buffers that are passed to the score_mod function + """ + # broadcast query & key along head dim for GQA + G = query.size(1) // key.size(1) + value = torch.repeat_interleave(value, G, dim=1) + key = torch.repeat_interleave(key, G, dim=1) + + _, post_mod_scores = _math_attention_inner( + query, + key, + value, + score_mod, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + + # Set fully masked rows' sumexp to 0.0 + logsumexp = post_mod_scores.logsumexp(dim=-1) + masked_rows = torch.all(post_mod_scores == -float("inf"), dim=-1) + logsumexp = torch.where(masked_rows, -float("inf"), logsumexp) + + post_mod_scores = torch._safe_softmax(post_mod_scores, dim=-1) + + return post_mod_scores.to(query.dtype) @ value, logsumexp / math.log(2) + + +@flex_attention.py_impl(DispatchKey.CompositeExplicitAutograd) +def sdpa_dense( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + score_mod: Callable, + block_mask: Tuple, + scale: float, + kernel_options: Dict[str, Any], + score_mod_other_buffers: Tuple = (), + mask_mod_other_buffers: Tuple = (), +) -> Tuple[torch.Tensor, torch.Tensor]: + out, lse = math_attention( + query, + key, + value, + score_mod, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + out = _permute_strides(out, query.stride()) + return out, lse + + +def trace_flex_attention( + proxy_mode: ProxyTorchDispatchMode, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + score_mod: Callable, + block_mask: Tuple, + scale: float, + kernel_options: Dict[str, Any], + score_mod_other_buffers: Tuple = (), + mask_mod_other_buffers: Tuple = (), +) -> Tuple[torch.Tensor, torch.Tensor]: + """Traces the flex_attention operator with the given score_mod function and other_buffers. + + Trace SDPA will call make_fx with "fake" example vals and then trace the score_mod function + This will produce a GraphModule that will be stored on the root tracer as "sdpa_score". We + access this graph module in inductor to inline the score_mod function to the triton template. 
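# --- Illustrative example (editor's sketch, not part of the original patch) ---
# The eager math path above is what runs when the public wrapper is called
# outside torch.compile. A hypothetical relative-position score_mod, assuming a
# PyTorch build that ships torch.nn.attention.flex_attention:
import torch
from torch.nn.attention.flex_attention import flex_attention

def rel_bias(score, b, h, q_idx, kv_idx):
    # score is a scalar element; _vmap_for_bhqkv lifts this over (b, h, q, kv).
    return score + (q_idx - kv_idx)

q = torch.randn(2, 4, 128, 64)
k = torch.randn(2, 4, 128, 64)
v = torch.randn(2, 4, 128, 64)
out = flex_attention(q, k, v, score_mod=rel_bias)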
+ """ + example_out = flex_attention( + query, + key, + value, + score_mod, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + example_vals = [ + torch.zeros((), dtype=query.dtype, requires_grad=query.requires_grad) + ] + [torch.zeros((), dtype=torch.int) for _ in range(4)] + mask_example_vals = [torch.zeros((), dtype=torch.int) for _ in range(4)] + mask_mod = block_mask[-1] + with TransformGetItemToIndex(): + score_graph = reenter_make_fx(score_mod)( + *example_vals, *score_mod_other_buffers + ) + mask_graph = reenter_make_fx(mask_mod)( + *mask_example_vals, *mask_mod_other_buffers + ) + assert isinstance(proxy_mode.tracer, torch.fx.Tracer) + block_mask = block_mask[:-1] + (mask_graph,) + qualname = proxy_mode.tracer.get_fresh_qualname("sdpa_score") + proxy_mode.tracer.root.register_module(qualname, score_graph) + mask_qualname = proxy_mode.tracer.get_fresh_qualname("sdpa_mask") + proxy_mode.tracer.root.register_module(mask_qualname, mask_graph) + node_args = ( + query, + key, + value, + score_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", flex_attention, proxy_args, {} + ) + return track_tensor_tree( + example_out, out_proxy, constant=None, tracer=proxy_mode.tracer + ) + + +@flex_attention.py_impl(ProxyTorchDispatchMode) +def flex_attention_proxy_torch_dispatch_mode( + mode: ProxyTorchDispatchMode, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + score_mod: Callable, + block_mask: Tuple, + scale: float, + kernel_options: Dict[str, Any], + score_mod_other_buffers: Tuple = (), + mask_mod_other_buffers: Tuple = (), +) -> Tuple[torch.Tensor, torch.Tensor]: + assert mode is not None, "Mode should always be enabled for python fallback key" + return trace_flex_attention( + mode, + query, + key, + value, + score_mod, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + + +@flex_attention.py_functionalize_impl +def flex_attention_functionalize( + ctx: torch._subclasses.functional_tensor.BaseFunctionalizeAPI, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + score_mod: Callable, + block_mask: Tuple, + scale: float, + kernel_options: Dict[str, Any], + score_mod_other_buffers: Tuple = (), + mask_mod_other_buffers: Tuple = (), +) -> Tuple[torch.Tensor, torch.Tensor]: + """Defines the functionalization rules for the flex_attention operator. + + Write now we are unwrapping each tensor and then redispatching to the next, however we want to + guard against any mutations in the score_mod function, to the other_buffers since those + are free variables. 
+ """ + query_unwrapped = ctx.unwrap_tensors(query) + key_unwrapped = ctx.unwrap_tensors(key) + value_unwrapped = ctx.unwrap_tensors(value) + block_mask_unwrapped = ctx.unwrap_tensors(block_mask) + score_mod_other_buffers_unwrapped = ctx.unwrap_tensors(score_mod_other_buffers) + mask_mod_other_buffers_unwrapped = ctx.unwrap_tensors(mask_mod_other_buffers) + + # Appease the mypy overlords + assert isinstance(query_unwrapped, torch.Tensor) + assert isinstance(key_unwrapped, torch.Tensor) + assert isinstance(value_unwrapped, torch.Tensor) + assert isinstance(block_mask_unwrapped, tuple) + assert isinstance(score_mod_other_buffers_unwrapped, tuple) + assert isinstance(mask_mod_other_buffers_unwrapped, tuple) + assert all( + isinstance(item, torch.Tensor) + for item in score_mod_other_buffers_unwrapped + mask_mod_other_buffers_unwrapped + ) + + example_vals = ( + [torch.zeros((), dtype=query.dtype)] + + [torch.zeros((), dtype=torch.int) for _ in range(4)] + + list(score_mod_other_buffers_unwrapped) + ) + with ctx.redispatch_to_next() as m: + functional_score_mod = ctx.functionalize(score_mod) + pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch + with TransformGetItemToIndex(): + mutates = _has_potential_branch_input_mutation( + functional_score_mod, example_vals, pre_dispatch + ) + # The only care about mutations of existing buffers since we can't replay these. + # However, we can just error if anything is detected + if mutates: + raise UnsupportedAliasMutationException("Mutations detected in score_mod") + + out = flex_attention( + query_unwrapped, + key_unwrapped, + value_unwrapped, + functional_score_mod, + block_mask_unwrapped, + scale, + kernel_options, + score_mod_other_buffers_unwrapped, + mask_mod_other_buffers_unwrapped, + ) + return ctx.wrap_tensors(out) # type: ignore[return-value, arg-type] + + +@flex_attention.py_impl(FakeTensorMode) +def flex_attention_fake_tensor_mode( + mode: FakeTensorMode, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + score_mod: Callable, + block_mask: Tuple, + scale: float, + kernel_options: Dict[str, Any], + score_mod_other_buffers: Tuple = (), + mask_mod_other_buffers: Tuple = (), +) -> Tuple[torch.Tensor, torch.Tensor]: + with mode: + v_head_dim = value.size(-1) + batch_size, num_heads, seq_len_q, q_head_dim = query.shape + logsumexp = query.new_empty( + batch_size, num_heads, seq_len_q, dtype=torch.float32 + ) + out_shape = (batch_size, num_heads, seq_len_q, v_head_dim) + out = query.new_empty(out_shape) + out = _permute_strides(out, query.stride()) + return out, logsumexp + + +# ---------------------------- Autograd Implementation ---------------------------- +def create_fw_bw_graph(score_mod, index_values, other_buffers): + # See Note:[HOP create fw_bw graph] + + # All of these imports need to be here in order to avoid circular dependencies + from torch._dispatch.python import suspend_functionalization + from torch._functorch.aot_autograd import AOTConfig, create_joint + from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode + from torch._subclasses.functional_tensor import disable_functional_mode + from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing + + dummy_aot_config = AOTConfig( + fw_compiler=None, # type: ignore[arg-type] + bw_compiler=None, # type: ignore[arg-type] + partition_fn=None, # type: ignore[arg-type] + decompositions={}, + num_params_buffers=0, + aot_id=0, + keep_inference_input_mutations=False, + ) + + with suspend_functionalization(), disable_functional_mode(): + 
with disable_proxy_modes_tracing(): + + def _from_fun(t): + return torch.empty_strided( + t.size(), + t.stride(), + device=t.device, + dtype=t.dtype, + requires_grad=t.requires_grad, + ) + + # If someone runs this hop under the default compiler backend ("eager") + # Then this path will be run with the actual user inputs. We convert them + # to fake tensors in order to not perform any actual compute. + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(index_values) + if fake_mode is None: + fake_mode = FakeTensorMode(allow_non_fake_inputs=True) + + with fake_mode: + unwrapped_score_mod_indexes = pytree.tree_map(_from_fun, index_values) + unwrapped_other_buffers = pytree.tree_map(_from_fun, other_buffers) + + assert all(isinstance(t, FakeTensor) for t in unwrapped_score_mod_indexes) + assert all(isinstance(t, FakeTensor) for t in unwrapped_other_buffers) + + example_flat_out = pytree.tree_map( + _from_fun, + score_mod(*unwrapped_score_mod_indexes, *unwrapped_other_buffers), + ) + if not isinstance(example_flat_out, torch.Tensor): + raise RuntimeError( + "Expected output of score_mod to be a tensor." + f"Got type {type(example_flat_out)}." + ) + example_grad = _from_fun(example_flat_out) + + def joint_f(score, b, h, m, n, example_grad, *other_buffers): + def fw_with_masks(*args): + fw_out = score_mod(*args) + out_requires_grad = fw_out.requires_grad + return ((fw_out,), (out_requires_grad,)) + + joint = create_joint(fw_with_masks, aot_config=dummy_aot_config) + args = [score, b, h, m, n] + list(other_buffers) + optional_grad = [example_grad] if example_grad.requires_grad else [] + _, grads = joint(args, optional_grad) + + return grads + + joint_graph = make_fx(joint_f)( + *unwrapped_score_mod_indexes, example_grad, *unwrapped_other_buffers + ) + return score_mod, joint_graph + + +class FlexAttentionAutogradOp(torch.autograd.Function): + @staticmethod + def forward( + ctx, + query, + key, + value, + fw_graph, + joint_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) -> Tuple[torch.Tensor, torch.Tensor]: + any_buffer_requires_grad = any( + buffer.requires_grad + for buffer in score_mod_other_buffers + mask_mod_other_buffers + ) + assert ( + not any_buffer_requires_grad + ), "Captured buffers that require grad are not yet supported." 
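To make the joint graph built by create_fw_bw_graph above concrete: joint_f pairs the score_mod forward with its gradient with respect to the scalar score input. A rough standalone analogue using plain autograd, with a toy score_mod (an editor's sketch, not the traced implementation):

import torch

def toy_score_mod(score, b, h, q_idx, kv_idx):
    return score * 2.0 + q_idx

score = torch.zeros((), requires_grad=True)
idx = [torch.zeros((), dtype=torch.int) for _ in range(4)]
out = toy_score_mod(score, *idx)
(grad_score,) = torch.autograd.grad(out, score, torch.ones(()))
# grad_score == 2.0; the traced joint graph returns the analogous gradients for the real score_mod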
+ ctx._fw_graph = fw_graph + ctx._joint_graph = joint_graph + ctx._mask_graph = block_mask[-1] + # KV_BLOCK_SIZE and Q_BLOCK_SIZE are integers, so can't use ctx.save_for_backward + ctx._KV_BLOCK_SIZE = block_mask[8] + ctx._Q_BLOCK_SIZE = block_mask[9] + ctx.scale = scale + ctx.kernel_options = kernel_options + ctx._score_mod_other_buffers_len = len(score_mod_other_buffers) + with torch._C._AutoDispatchBelowAutograd(): + out, logsumexp = flex_attention( + query, + key, + value, + fw_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + + ctx.save_for_backward( + query, + key, + value, + out, + logsumexp, + *block_mask[:8], + *score_mod_other_buffers, + *mask_mod_other_buffers, + ) + return out, logsumexp + + @staticmethod + def backward(ctx, grad_out, grad_logsumexp): + fw_args = ctx.saved_tensors + ( + query, + key, + value, + out, + logsumexp, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + *other_buffers, + ) = fw_args + fw_graph = ctx._fw_graph + joint_graph = ctx._joint_graph + mask_graph = ctx._mask_graph + KV_BLOCK_SIZE = ctx._KV_BLOCK_SIZE + Q_BLOCK_SIZE = ctx._Q_BLOCK_SIZE + scale = ctx.scale + kernel_options = ctx.kernel_options + score_mod_other_buffers = tuple( + other_buffers[: ctx._score_mod_other_buffers_len] + ) + mask_mod_other_buffers = tuple( + other_buffers[ctx._score_mod_other_buffers_len :] + ) + # We have asserted that other_buffers do not require grad in the forward + none_grads = [None] * 7 + grad_query, grad_key, grad_value = flex_attention_backward( + query, + key, + value, + out, + logsumexp, + grad_out, + grad_logsumexp, + fw_graph, + joint_graph, + ( + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + KV_BLOCK_SIZE, + Q_BLOCK_SIZE, + mask_graph, + ), + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + return grad_query, grad_key, grad_value, *none_grads + + +@flex_attention.py_impl(DispatchKey.Autograd) +def flex_attention_autograd( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + score_mod: Callable, + block_mask: Tuple, + scale: float, + kernel_options: Dict[str, Any], + score_mod_other_buffers: Tuple = (), + mask_mod_other_buffers: Tuple = (), +) -> Tuple[torch.Tensor, torch.Tensor]: + with TransformGetItemToIndex(): + input_requires_grad = any(t.requires_grad for t in (query, key, value)) + if torch.is_grad_enabled() and input_requires_grad: + example_vals = [ + torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad) + ] + [torch.zeros((), dtype=torch.int) for _ in range(4)] + fw_graph, bw_graph = create_fw_bw_graph( + score_mod, example_vals, score_mod_other_buffers + ) + else: + fw_graph, bw_graph = score_mod, None + out, logsumexp = FlexAttentionAutogradOp.apply( + query, + key, + value, + fw_graph, + bw_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + return out, logsumexp + + +# ---------------------------- Backward HOP Implementation ---------------------------- + + +@flex_attention_backward.py_impl(DispatchKey.CompositeExplicitAutograd) +def sdpa_dense_backward( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + grad_out: torch.Tensor, + grad_logsumexp: torch.Tensor, + fw_graph: Callable, # GraphModule type hint? 
+ joint_graph: Callable, + block_mask: Tuple, + scale: float, + kernel_options: Dict[str, Any], + score_mod_other_buffers: Tuple, + mask_mod_other_buffers: Tuple, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Get outputs before calling repeat interleave + actual_grad_query = torch.empty_like(query) + actual_grad_key = torch.empty_like(key) + actual_grad_value = torch.empty_like(value) + + G = query.size(1) // key.size(1) + key = torch.repeat_interleave(key, G, dim=1) + value = torch.repeat_interleave(value, G, dim=1) + + # We're undoing the log -> log2 change of base in the forwards + logsumexp = logsumexp * math.log(2) + # The backwards formula for the log -> log2 change of base in the forwards + grad_logsumexp = grad_logsumexp / math.log(2) + scores, post_mod_scores = _math_attention_inner( + query, + key, + value, + fw_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + masked_out_rows = logsumexp == -float("inf") + softmax_scores = torch.exp(post_mod_scores - logsumexp.unsqueeze(-1)) + softmax_scores = torch.where(masked_out_rows.unsqueeze(-1), 0, softmax_scores) + + grad_value = softmax_scores.to(query.dtype).transpose(-2, -1) @ grad_out + + grad_softmax_scores = grad_out @ value.transpose(-2, -1) + + sum_scores = torch.sum(out * grad_out, -1, keepdim=True) + grad_score_mod = softmax_scores * ( + grad_softmax_scores - sum_scores + grad_logsumexp.unsqueeze(-1) + ) + + b = torch.arange(0, scores.size(0), device=scores.device) + h = torch.arange(0, scores.size(1), device=scores.device) + m = torch.arange(0, scores.size(2), device=scores.device) + n = torch.arange(0, scores.size(3), device=scores.device) + + mask_graph = block_mask[-1] + # Gradient of the inline score_mod function, with respect to the scores + captured_buffers_in_dim = (None,) * len(score_mod_other_buffers) + out_dims = [0, None, None, None, None] + [None] * len(score_mod_other_buffers) + from torch.nn.attention.flex_attention import _vmap_for_bhqkv + + # inputs are [score, b, h, q_idx, kv_idx, gradOut, ...] + # score and gradOut are "fully" batched + joint_score_mod = _vmap_for_bhqkv( + joint_graph, + prefix=(0,), + suffix=(0,) + captured_buffers_in_dim, + out_dims=out_dims, + ) + with TransformGetItemToIndex(): + grad_scores, *_ = joint_score_mod( + scores, b, h, m, n, grad_score_mod, *score_mod_other_buffers + ) + grad_scores = grad_scores * scale + grad_scores = grad_scores.to(query.dtype) + + mask_mod = _vmap_for_bhqkv( + mask_graph, prefix=(), suffix=(None,) * len(mask_mod_other_buffers) + ) + with TransformGetItemToIndex(): + mask_scores = mask_mod(b, h, m, n, *mask_mod_other_buffers) + grad_scores = torch.where( + mask_scores, grad_scores, torch.tensor(0, dtype=query.dtype) + ) + + grad_query = grad_scores @ key + grad_key = grad_scores.transpose(-2, -1) @ query + + # Reduce DK, DV along broadcasted heads. 
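A quick numeric check of the change-of-base bookkeeping above (an editor's sketch, not part of the upstream change): the forward kernel stores logsumexp in base 2, so multiplying by math.log(2) recovers the natural-log value used here, and the incoming gradient is rescaled by the inverse factor.

import math

import torch

x = torch.randn(8)
lse_base2 = torch.logsumexp(x, dim=0) / math.log(2)  # roughly what the forward stores
lse_natural = lse_base2 * math.log(2)                # what this backward reconstructs
assert torch.allclose(lse_natural, torch.logsumexp(x, dim=0))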
+ grad_key = grad_key.view( + grad_key.size(0), -1, G, grad_key.size(-2), grad_key.size(-1) + ) + grad_value = grad_value.view( + grad_value.size(0), -1, G, grad_value.size(-2), grad_value.size(-1) + ) + + grad_key = torch.sum(grad_key, 2, keepdim=False) + grad_value = torch.sum(grad_value, 2, keepdim=False) + + actual_grad_query.copy_(grad_query) + actual_grad_key.copy_(grad_key) + actual_grad_value.copy_(grad_value) + + return actual_grad_query, actual_grad_key, actual_grad_value + + +def trace_flex_attention_backward( + proxy_mode: ProxyTorchDispatchMode, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + grad_out: torch.Tensor, + grad_logsumexp: torch.Tensor, + fw_graph: Union[Callable, GraphModule], + joint_graph: GraphModule, + block_mask: Tuple, + scale: float, + kernel_options: Dict[str, Any], + score_mod_other_buffers: Tuple = (), + mask_mod_other_buffers: Tuple = (), +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """We already have the forward graph and joint graph from the forward pass, so we create a proxy attach both graphs""" + example_out = flex_attention_backward( + query, + key, + value, + out, + logsumexp, + grad_out, + grad_logsumexp, + fw_graph, + joint_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + + fw_example_vals = [ + torch.zeros((), dtype=query.dtype, requires_grad=query.requires_grad) + ] + [torch.zeros((), dtype=torch.int) for _ in range(4)] + bw_example_vals = fw_example_vals + [torch.zeros((), dtype=query.dtype)] + mask_example_vals = [torch.zeros((), dtype=torch.int) for _ in range(4)] + mask_graph = block_mask[-1] + with TransformGetItemToIndex(): + fw_graph = reenter_make_fx(fw_graph)(*fw_example_vals, *score_mod_other_buffers) + joint_graph = reenter_make_fx(joint_graph)( + *bw_example_vals, *score_mod_other_buffers + ) + mask_graph = reenter_make_fx(mask_graph)( + *mask_example_vals, *mask_mod_other_buffers + ) + assert isinstance(proxy_mode.tracer, torch.fx.Tracer) + block_mask = block_mask[:-1] + (mask_graph,) + proxy_mode.tracer.root.register_module("fw_graph", fw_graph) # type: ignore[arg-type] + proxy_mode.tracer.root.register_module("joint_graph", joint_graph) + proxy_mode.tracer.root.register_module("mask_graph", mask_graph) + node_args = ( + query, + key, + value, + out, + logsumexp, + grad_out, + grad_logsumexp, + fw_graph, + joint_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", + flex_attention_backward, + proxy_args, + {}, + name="flex_attention_backward", + ) + return track_tensor_tree( + example_out, out_proxy, constant=None, tracer=proxy_mode.tracer + ) + + +@flex_attention_backward.py_impl(ProxyTorchDispatchMode) +def flex_attention_backward_proxy_torch_dispatch_mode( + mode: ProxyTorchDispatchMode, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + grad_out: torch.Tensor, + grad_logsumexp: torch.Tensor, + fw_graph: Union[Callable, GraphModule], + joint_graph: GraphModule, + block_mask: Tuple, + scale: float, + kernel_options: Dict[str, Any], + score_mod_other_buffers: Tuple = (), + mask_mod_other_buffers: Tuple = (), +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert mode is not None, "Mode should always be enabled for python fallback 
key" + return trace_flex_attention_backward( + mode, + query, + key, + value, + out, + logsumexp, + grad_out, + grad_logsumexp, + fw_graph, + joint_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + + +@flex_attention_backward.py_functionalize_impl +def flex_attention_backward_functionalize( + ctx: torch._subclasses.functional_tensor.BaseFunctionalizeAPI, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + grad_out: torch.Tensor, + grad_logsumexp: torch.Tensor, + fw_graph: Union[Callable, GraphModule], + joint_graph: GraphModule, + block_mask: Tuple, + scale: float, + kernel_options: Dict[str, Any], + score_mod_other_buffers: Tuple = (), + mask_mod_other_buffers: Tuple = (), +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Defines the functionalization rules for the flex_attention operator. + + Write now we are unwrapping each tensor and then redispatching to the next, + since we know that the forward score mod function is assured to be free of mutations + to the other_buffers, we skip that mutate check and go straight to redispatching. + """ + query_unwrapped = ctx.unwrap_tensors(query) + key_unwrapped = ctx.unwrap_tensors(key) + value_unwrapped = ctx.unwrap_tensors(value) + out_unwrapped = ctx.unwrap_tensors(out) + logsumexp_unwrapped = ctx.unwrap_tensors(logsumexp) + grad_out_unwrapped = ctx.unwrap_tensors(grad_out) + grad_logsumexp_unwrapped = ctx.unwrap_tensors(grad_logsumexp) + block_mask_unwrapped = ctx.unwrap_tensors(block_mask) + score_mod_other_buffers_unwrapped = ctx.unwrap_tensors(score_mod_other_buffers) + mask_mod_other_buffers_unwrapped = ctx.unwrap_tensors(mask_mod_other_buffers) + + # Appease the mypy overlords + assert isinstance(query_unwrapped, torch.Tensor) + assert isinstance(key_unwrapped, torch.Tensor) + assert isinstance(value_unwrapped, torch.Tensor) + assert isinstance(out_unwrapped, torch.Tensor) + assert isinstance(logsumexp_unwrapped, torch.Tensor) + assert isinstance(grad_out_unwrapped, torch.Tensor) + assert isinstance(grad_logsumexp_unwrapped, torch.Tensor) + assert isinstance(block_mask_unwrapped, tuple) + assert isinstance(score_mod_other_buffers_unwrapped, tuple) + assert isinstance(mask_mod_other_buffers_unwrapped, tuple) + assert all( + isinstance(item, torch.Tensor) + for item in score_mod_other_buffers_unwrapped + mask_mod_other_buffers_unwrapped + ) + + with ctx.redispatch_to_next() as m: + functional_fw_graph = ctx.functionalize(fw_graph) + functional_joint_graph = ctx.functionalize(joint_graph) + + grad_query, grad_key, grad_value = flex_attention_backward( + query_unwrapped, + key_unwrapped, + value_unwrapped, + out_unwrapped, + logsumexp_unwrapped, + grad_out_unwrapped, + grad_logsumexp_unwrapped, + functional_fw_graph, # type: ignore[arg-type] + functional_joint_graph, # type: ignore[arg-type] + block_mask_unwrapped, + scale, + kernel_options, + score_mod_other_buffers_unwrapped, + mask_mod_other_buffers_unwrapped, + ) + + return ctx.wrap_tensors((grad_query, grad_key, grad_value)) # type: ignore[return-value,arg-type] + + +@flex_attention_backward.py_impl(FakeTensorMode) +def flex_attention_backward_fake_tensor_mode( + mode: FakeTensorMode, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + grad_out: torch.Tensor, + grad_logsumexp: torch.Tensor, + fw_graph: Union[Callable, GraphModule], + joint_graph: GraphModule, + block_mask: Tuple, + scale: 
float, + kernel_options: Dict[str, Any], + score_mod_other_buffers: Tuple = (), + mask_mod_other_buffers: Tuple = (), +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + with mode: + grad_query = torch.empty_like(query) + grad_key = torch.empty_like(key) + grad_value = torch.empty_like(value) + return grad_query, grad_key, grad_value + + +flex_attention_backward.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(flex_attention_backward, deferred_error=True) +) diff --git a/lib/python3.10/site-packages/torch/_higher_order_ops/hints_wrap.py b/lib/python3.10/site-packages/torch/_higher_order_ops/hints_wrap.py new file mode 100644 index 0000000000000000000000000000000000000000..c211d405614636085575369e4a3738d3ced3fe02 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_higher_order_ops/hints_wrap.py @@ -0,0 +1,151 @@ +# mypy: allow-untyped-defs +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._higher_order_ops.utils import ( + _has_potential_branch_input_alias, + _has_potential_branch_input_mutation, + autograd_not_implemented, + reenter_make_fx, + unique_graph_id, + UnsupportedAliasMutationException, +) +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree + + +# used for wrapping a function/op with context hints +class HintsWrapper(HigherOrderOperator): + def __init__(self): + super().__init__("hints_wrapper") + + def __call__(self, body_fn, args, kwargs, hints): + r""" + Call implementation of hints_wrapper + + Args: + body_fn (Callable): A callable function that is within the scope + that is being traced. + + args (Tuple of torch.Tensor/int/float/bool): A tuple of inputs to + body_fn. + + kwargs (dict): Keyword argument to the body_fn. + + hints (dict): A dict of context hints which could be passed to + backend compiler. + """ + if not isinstance(args, tuple): + raise RuntimeError(f"args must be a tuple, got {type(args)}") + + if not all(isinstance(t, (torch.Tensor, int, float, bool)) for t in args): + raise RuntimeError( + "args must be a tuple of tensors, ints, floats, or bools, got " + f"{args}" + ) + + if not isinstance(kwargs, dict): + raise RuntimeError(f"kwargs must be a dict, got {type(kwargs)}") + + if len(kwargs) > 0: + raise RuntimeError( + f"kwargs except for hints are not supported, got {kwargs}" + ) + + if not isinstance(hints, dict): + raise RuntimeError(f"hints must be a dict, got {type(hints)}") + + for k, v in hints.items(): + if not isinstance(k, str): + raise RuntimeError(f"hints key must be a str, got {k}.") + + if not isinstance(v, (int, float, bool, str)): + raise RuntimeError( + "hints must be a dict containing int, float, bool or str " + f"value, got value {v} for key {k}." 
+ ) + + return super().__call__(body_fn, args, kwargs, hints) + + +hints_wrapper = HintsWrapper() + + +@hints_wrapper.py_impl(DispatchKey.CompositeExplicitAutograd) +def hints_wrapper_dense(body_fn, args, kwargs, hints): + return body_fn(*args, **kwargs) + + +hints_wrapper.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(hints_wrapper, deferred_error=True) +) + + +@hints_wrapper.py_impl(FakeTensorMode) +def hints_wrapper_fake_tensor_mode(mode, body_func, args, kwargs, hints): + flat_args = pytree.tree_leaves(args) + with mode: + return body_func(*flat_args, **kwargs) + + +@hints_wrapper.py_functionalize_impl +def hints_wrapper_functionalize(ctx, body_fn, args, kwargs, hints): + unwrapped_args = ctx.unwrap_tensors(args) + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + unwrapped_hints = ctx.unwrap_tensors(hints) + with ctx.redispatch_to_next(): + functional_body_fn = ctx.functionalize(body_fn) + pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch + if _has_potential_branch_input_mutation( + functional_body_fn, unwrapped_args, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException( + "body_fn of hints_wrapper might be modifying the input!" + ) + if _has_potential_branch_input_alias( + functional_body_fn, unwrapped_args, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException( + "body_fn of hints_wrapper might be aliasing the input!" + ) + outputs = hints_wrapper( + functional_body_fn, + unwrapped_args, + unwrapped_kwargs, + unwrapped_hints, + ) + return ctx.wrap_tensors(outputs) + + +def trace_hints_wrapper(proxy_mode, hints_wrapper, body_fn, args, kwargs, hints): + flat_args = tuple(pytree.tree_leaves(args)) + body_graph = reenter_make_fx(body_fn)(*flat_args, **kwargs) + + _, body_graph_name = unique_graph_id(proxy_mode, prefix="hints_wrapper_body_graph") + proxy_mode.tracer.root.register_module(body_graph_name, body_graph) + + new_args: tuple = (body_graph, flat_args, {}) + # merge hints into kwargs + new_kwargs = {} + new_kwargs["hints"] = hints + + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, new_args) + proxy_kwargs = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, new_kwargs) + + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", hints_wrapper, proxy_args, proxy_kwargs, name="hints_wrapper" + ) + + out = body_fn(*flat_args, **kwargs) + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@hints_wrapper.py_impl(ProxyTorchDispatchMode) +def inner(proxy_mode, body_fn, args, kwargs, hints): + if proxy_mode.enable_tracing: + return trace_hints_wrapper( + proxy_mode, hints_wrapper, body_fn, args, kwargs, hints + ) + else: + return hints_wrapper(body_fn, args, kwargs, hints) diff --git a/lib/python3.10/site-packages/torch/_higher_order_ops/map.py b/lib/python3.10/site-packages/torch/_higher_order_ops/map.py new file mode 100644 index 0000000000000000000000000000000000000000..d57d68d5e473f722f47e2112ff9a210dd2dc1bc9 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_higher_order_ops/map.py @@ -0,0 +1,264 @@ +# mypy: allow-untyped-defs +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._dispatch.python import suspend_functionalization +from torch._functorch.aot_autograd import AOTConfig, create_joint +from torch._higher_order_ops.utils import ( + _has_potential_branch_input_alias, + _has_potential_branch_input_mutation, + _maybe_run_with_interpreter, + reenter_make_fx, + UnsupportedAliasMutationException, +) +from torch._ops 
import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch._subclasses.functional_tensor import disable_functional_mode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + make_fx, + ProxyTorchDispatchMode, + track_tensor_tree, +) + +from .utils import ( + _from_fun, + _stack_pytree, + _unstack_pytree, + clone_outputs_aliasing_inputs, + prepare_fw_with_masks, +) + + +# TODO: We add this to prevent dymamo from tracing into map_wrapper, +# remove the wrapper call when it's ready. +class MapWrapper(HigherOrderOperator): + def __init__(self): + super().__init__("map") + + def __call__(self, xs, *args): + return map_wrapper(xs, *args) + + +class MapImpl(HigherOrderOperator): + def __init__(self): + super().__init__("map_impl") + + def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) + + +map = MapWrapper() + +map_impl = MapImpl() + +dummy_aot_config = AOTConfig( + fw_compiler=None, # type: ignore[arg-type] + bw_compiler=None, # type: ignore[arg-type] + partition_fn=None, # type: ignore[arg-type] + decompositions={}, + num_params_buffers=0, + aot_id=0, + keep_inference_input_mutations=False, +) + + +def create_fw_bw_graph(f, num_mapped_args, *args): + mapped_xs = args[:num_mapped_args] + pos_args = args[num_mapped_args:] + + # See Note [HOP create fw_bw graph] in create_fw_bw_graph in utils.py + + with suspend_functionalization(), disable_functional_mode(): + with disable_proxy_modes_tracing(): + unwrapped_mapped_xs = pytree.tree_map(_from_fun, mapped_xs) + example_xs = _unstack_pytree(unwrapped_mapped_xs)[0] + + example_pos_args = [ + _from_fun(arg) if isinstance(arg, torch.Tensor) else arg + for arg in pos_args + ] + example_flat_out = pytree.tree_map( + _from_fun, f(*example_xs, *example_pos_args) + ) + if any( + not isinstance(out, torch.Tensor) + for out in example_flat_out + if out is not None + ): + raise RuntimeError( + "Expect outputs of map only contains tensors or None. " + f"Got types {[type(out) for out in example_flat_out]}." + ) + example_grad = [_from_fun(out) for out in example_flat_out] + + fw_graph = make_fx(f)(*example_xs, *example_pos_args) + + def joint_f(*example_args): + joint_mapped_args = example_args[:joint_num_mapped] + args = example_args[joint_num_mapped:] + + mapped_input = joint_mapped_args[:num_mapped_args] + mapped_grads = joint_mapped_args[num_mapped_args:] + + joint = create_joint(prepare_fw_with_masks(f), aot_config=dummy_aot_config) + _, grads = joint( + list(mapped_input) + list(args), + [ + grad + for grad in mapped_grads + if grad is not None and grad.requires_grad + ], + ) + + # In order to keep map functional for backward graph, + # we clone outputs that are aliasing inputs + maybe_clone = clone_outputs_aliasing_inputs(example_args) + + return pytree.tree_map(maybe_clone, grads) + + joint_num_mapped = len(example_grad) + len(example_xs) + joint_graph = make_fx(joint_f)(*example_xs, *example_grad, *example_pos_args) + return fw_graph, joint_graph + + +def map_wrapper(f, xs, *args): + flat_xs, xs_spec = pytree.tree_flatten(xs) + if not all(isinstance(t, torch.Tensor) for t in flat_xs): + raise RuntimeError(f"Mapped xs can only consist of tensors. 
Got xs {flat_xs}.") + + num_mapped_args = len(flat_xs) + shapes = [xs.shape for xs in flat_xs] + leading_dim_size = shapes[0][0] + if leading_dim_size == 0: + raise RuntimeError("Leading dimensions of mapped xs cannot be 0.") + + if any(cur_shape[0] != leading_dim_size for cur_shape in shapes): + raise RuntimeError( + f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}." + ) + + out_spec = None + + def flat_fn(*flat_args): + xs = pytree.tree_unflatten(list(flat_args[:num_mapped_args]), xs_spec) + unflattened_out = f(xs, *flat_args[num_mapped_args:]) + flat_out, tmp_out_spec = pytree.tree_flatten(unflattened_out) + + nonlocal out_spec + out_spec = tmp_out_spec + return flat_out + + return pytree.tree_unflatten( + map_impl(flat_fn, flat_xs, args), out_spec # type: ignore[arg-type] + ) + + +class MapAutogradOp(torch.autograd.Function): + @staticmethod + def forward(ctx, fw_graph, joint_graph, num_mapped_args, *flat_args): + ctx.save_for_backward(*flat_args) + ctx._joint_graph = joint_graph + ctx._num_mapped_args = num_mapped_args + with torch._C._AutoDispatchBelowAutograd(): + return ( + *map_impl( + fw_graph, flat_args[:num_mapped_args], flat_args[num_mapped_args:] + ), + ) + + @staticmethod + def backward(ctx, *flat_grads): + fw_args = ctx.saved_tensors + fw_mapped_args = fw_args[: ctx._num_mapped_args] + pos_args = fw_args[ctx._num_mapped_args :] + + grads = map_impl( + ctx._joint_graph, + fw_mapped_args + flat_grads, + pos_args, + ) + return None, None, None, *grads + + +def trace_map(proxy_mode, func_overload, f, xs, pos_args): + leading_dim_size = xs[0].shape[0] + + example_input = _unstack_pytree(xs)[0] + body_graph = f + + body_graph = reenter_make_fx(body_graph)(*example_input, *pos_args) + + next_name = proxy_mode.tracer.get_fresh_qualname("body_graph_") + + proxy_mode.tracer.root.register_module(next_name, body_graph) + + with disable_proxy_modes_tracing(): + example_outs = body_graph(*example_input, *pos_args) + + def expand_tensor(t): + if isinstance(t, torch.Tensor): + return t.expand(leading_dim_size, *t.shape) + return t + + expanded_outs = pytree.tree_map(expand_tensor, example_outs) + + node_args = (body_graph, list(xs), list(pos_args)) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", func_overload, proxy_args, {}, name="map_impl" + ) + return track_tensor_tree( + expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer + ) + + +@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd) +def map_dense(f, xs, pos_args): + pytrees = [] + for inp in _unstack_pytree(xs): + pytrees.append(f(*inp, *pos_args)) + return _stack_pytree(pytrees) + + +@map_impl.py_impl(DispatchKey.Autograd) +def map_autograd(f, xs, pos_args): + num_mapped_args = len(xs) + fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *xs, *pos_args) + flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *xs, *pos_args) + return flat_out + + +@map_impl.py_impl(ProxyTorchDispatchMode) +def map_proxy_torch_dispatch_mode(mode, f, xs, args): + return trace_map(mode, map_impl, f, xs, args) + + +@map_impl.py_impl(FakeTensorMode) +def map_fake_tensor_mode(mode, f, xs, args): + with mode: + return map_dense(f, xs, args) + + +@map_impl.py_functionalize_impl +def map_functionalize(ctx, f, xs, pos_args): + unwrapped_xs = ctx.unwrap_tensors(xs) + unwrapped_args = ctx.unwrap_tensors(pos_args) + wrapped_fn = ctx.functionalize(_maybe_run_with_interpreter(f)) + + with 
ctx.redispatch_to_next(): + with disable_proxy_modes_tracing(): + example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args) + pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch + if _has_potential_branch_input_mutation( + f, example_inputs, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException("torch.map is mutating the input!") + + if _has_potential_branch_input_alias( + f, example_inputs, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException("torch.map is aliasing the input!") + + map_return = map_impl(wrapped_fn, unwrapped_xs, unwrapped_args) + return ctx.wrap_tensors(map_return) diff --git a/lib/python3.10/site-packages/torch/_higher_order_ops/out_dtype.py b/lib/python3.10/site-packages/torch/_higher_order_ops/out_dtype.py new file mode 100644 index 0000000000000000000000000000000000000000..3a2dfbe8ae2fb5f394d507fce632fe9802261136 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_higher_order_ops/out_dtype.py @@ -0,0 +1,166 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs + +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._higher_order_ops.utils import autograd_not_implemented +from torch._ops import HigherOrderOperator +from torch._prims_common import elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + maybe_handle_decomp, + ProxyTorchDispatchMode, + track_tensor_tree, +) + + +# TODO to figure out a more generic approach +ALLOWABLE_OPS = [ + torch.ops.aten.linear.default, + torch.ops.aten.mm.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.convolution.default, + torch.ops.aten.mul.Tensor, + torch.ops.aten.mul.Scalar, + torch.ops.aten.div.Tensor, + torch.ops.aten.div.Scalar, +] + + +class OutDtypeOperator(HigherOrderOperator): + """ + The out_dtype operator takes an existing ATen functional operator, an + `out_dtype` argument, and arguments to the original operator, and executes + the original operator and returns a Tensor with the `out_dtype` precision. + This operator does not mandate a compute precision so it allows the + representation to not be opinionated about the exact implementation. + + The general implementation for all operators will be the following: + 1. Promote inputs dtypes based on default PyTorch dtype promotion rules, + using the dtypes of all input Tensors/Scalars and the `out_dtype` + arugument. + 2. Execute the operator + 3. Cast the output to `out_dtype` + """ + + def __init__(self) -> None: + super().__init__("out_dtype") + + def __call__(self, op, output_dtype, *args): + if not isinstance(op, torch._ops.OpOverload): + raise ValueError("out_dtype's first argument must be an OpOverload") + if op._schema.is_mutable: + raise ValueError( + "out_dtype's first argument needs to be a functional operator" + ) + if not ( + len(op._schema.returns) == 1 + and isinstance(op._schema.returns[0].type, torch.TensorType) + ): + raise ValueError( + "out_dtype's can only apply to ops that return a single tensor" + f"Instead got {[r.type for r in op._schema.returns]}" + ) + + if op not in ALLOWABLE_OPS: + raise ValueError( + f"out_dtype only allows the following operators: {ALLOWABLE_OPS}." 
+ ) + + res = super().__call__(op, output_dtype, *args) + + return res + + +out_dtype = OutDtypeOperator() + + +def trace_out_dtype(proxy_mode, func_overload, op, output_dtype, *args): + # NB: Long-term we should put the decomposition logic into + # ProxyTorchDispatchMode so that people do not need to call maybe_handle_decomp + # in all HigherOrderOp proxy implementations. + r = maybe_handle_decomp(proxy_mode, func_overload, (op, output_dtype, *args), {}) + if r is not NotImplemented: + return r + + with disable_proxy_modes_tracing(): + # This is a simplified implementation of this operator just for tracing. + # Actual implementation may also first promote the arguments + out = op(*args).to(dtype=output_dtype) + + node_args = (op, output_dtype, *args) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", func_overload, proxy_args, {}, name="out_dtype" + ) + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@out_dtype.py_impl(DispatchKey.CompositeExplicitAutograd) +def out_dtype_dense(op: torch._ops.OpOverload, output_dtype: torch.dtype, *args): + if is_int_mm(op, output_dtype, args): + return torch._int_mm(*args) + return out_dtype_fallback(op, output_dtype, *args) + + +def is_int_mm(op, output_dtype, args): + return ( + op == torch.ops.aten.mm.default + and output_dtype == torch.int32 + and len(args) == 2 + and args[0].dtype == torch.int8 + and args[1].dtype == torch.int8 + and args[0].is_cuda + and args[1].is_cuda + ) + + +def out_dtype_fallback(op, output_dtype, *args): + flat_inputs = pytree.arg_tree_leaves(*args) + [torch.ones(1, dtype=output_dtype)] + promote_dtype: torch.dtype = elementwise_dtypes( + *flat_inputs, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + )[0] + + casted_args = pytree.tree_map_only( + torch.Tensor, lambda arg: arg.to(dtype=promote_dtype), args + ) + res = op(*casted_args).to(dtype=output_dtype) + return res + + +out_dtype.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(out_dtype, deferred_error=True) +) + + +@out_dtype.py_impl(ProxyTorchDispatchMode) +def out_dtype_proxy( + mode: ProxyTorchDispatchMode, + op: torch._ops.OpOverload, + output_dtype: torch.dtype, + *args, +): + return trace_out_dtype(mode, out_dtype, op, output_dtype, *args) + + +@out_dtype.py_impl(FakeTensorMode) +def out_dtype_fake_tensor_mode( + mode: FakeTensorMode, + op: torch._ops.OpOverload, + output_dtype: torch.dtype, + *args, +): + with mode: + return out_dtype_dense(op, output_dtype, *args) + + +@out_dtype.py_functionalize_impl +def out_dtype_func(ctx, op, output_dtype, *args): + unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args) + + with ctx.redispatch_to_next(): + res = out_dtype(op, output_dtype, *unwrapped_args) + return ctx.wrap_tensors(res) diff --git a/lib/python3.10/site-packages/torch/_higher_order_ops/run_const_graph.py b/lib/python3.10/site-packages/torch/_higher_order_ops/run_const_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..1f49ee28394a1c464612d2c5ee8a59abd6be1b49 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_higher_order_ops/run_const_graph.py @@ -0,0 +1,60 @@ +# mypy: allow-untyped-defs +import torch +from torch._C import DispatchKey +from torch._higher_order_ops.utils import autograd_not_implemented +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import 
ProxyTorchDispatchMode, track_tensor_tree +from torch.utils import _pytree as pytree + + +class RunConstGraph(HigherOrderOperator): + def __init__(self): + super().__init__("run_const_graph") + + def __call__(self, *args): + return super().__call__(*args) + + +run_const_graph = RunConstGraph() + + +@run_const_graph.py_impl(ProxyTorchDispatchMode) +def run_const_graph_dispatch_mode(mode, *args): + const_gm, weights = args + p_args = pytree.tree_map(mode.tracer.unwrap_proxy, args) + assert isinstance(const_gm, torch.fx.GraphModule) + assert not hasattr(mode.tracer.root, "_const_graph") + mode.tracer.root.register_module("_const_graph", const_gm) + + proxy = mode.tracer.create_proxy("call_function", run_const_graph, p_args, {}) + + out = const_gm(*weights) + return track_tensor_tree(out, proxy, constant=None, tracer=mode.tracer) + + +@run_const_graph.py_functionalize_impl +def run_const_graph_functional(ctx, *args): + unwrapped_args = ctx.unwrap_tensors(args) + + with ctx.redispatch_to_next(): + out = run_const_graph(*unwrapped_args) + return ctx.wrap_tensors(out) + + +run_const_graph.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(run_const_graph, deferred_error=True) +) + + +@run_const_graph.py_impl(FakeTensorMode) +def run_const_graph_fake_tensor_mode(mode, graph, args): + assert isinstance(graph, torch.fx.GraphModule) + with mode: + return graph(*args) + + +@run_const_graph.py_impl(DispatchKey.CPU) +def run_const_graph_cpu(graph, args): + assert isinstance(graph, torch.fx.GraphModule) + return graph(*args) diff --git a/lib/python3.10/site-packages/torch/_higher_order_ops/strict_mode.py b/lib/python3.10/site-packages/torch/_higher_order_ops/strict_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..7324e20dcd4cd7bb91fe0a94c50028df9a47cc78 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_higher_order_ops/strict_mode.py @@ -0,0 +1,94 @@ +# mypy: allow-untyped-defs +import torch +import torch._subclasses.functional_tensor +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._functorch.utils import exposed_in +from torch._higher_order_ops.utils import _set_compilation_env, autograd_not_implemented +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + make_fx, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.utils._python_dispatch import _get_current_dispatch_mode + + +@exposed_in("torch") +def strict_mode(callable, operands): + if torch.compiler.is_dynamo_compiling(): + return strict_mode_op(callable, operands) + + with _set_compilation_env(): + with torch._dynamo.utils.disable_cache_limit(): + return torch.compile(strict_mode_op, backend="eager", fullgraph=True)( + callable, operands + ) + + +class StrictMode(HigherOrderOperator): + def __init__(self): + super().__init__("strict_mode") + + def __call__(self, callable, operands): + return super().__call__(callable, operands) + + +strict_mode_op = StrictMode() + + +@strict_mode_op.py_impl(DispatchKey.CompositeExplicitAutograd) +def strict_mode_op_dense(callable, operands): + mode = _get_current_dispatch_mode() + assert mode is None, "Mode should never be enabled for CPU/CUDA key" + return callable(*operands) + + +strict_mode_op.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(strict_mode_op, deferred_error=True) +) + + +@strict_mode_op.py_impl(ProxyTorchDispatchMode) +def inner(mode, callable, operands): + return 
trace_strict_mode(mode, strict_mode_op, callable, operands) + + +def trace_strict_mode(mode, strict_mode_op, callable, operands): + pre_dispatch = getattr(mode, "pre_dispatch", False) + + with disable_proxy_modes_tracing(): + graph = make_fx(callable, pre_dispatch=pre_dispatch)(*operands) + + graph_name = mode.tracer.get_fresh_qualname("strict_graph_") + mode.tracer.root.register_module(graph_name, graph) + + args = (graph, operands) + + proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args) + + out_proxy = mode.tracer.create_proxy( + "call_function", strict_mode_op, proxy_args, {}, name="strict_mode" + ) + + out = graph(*operands) + return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + + +@strict_mode_op.py_impl(FakeTensorMode) +def strict_mode_fake_tensor_mode(mode, callable, operands): + with mode: + true_outs = callable(*operands) + return true_outs + + +@strict_mode_op.py_functionalize_impl +def strict_mode_func(ctx, callable, inputs): + unwrapped_inputs = ctx.unwrap_tensors(inputs) + with ctx.redispatch_to_next(): + functional_callable = ctx.functionalize(callable) + + cond_return = strict_mode_op(functional_callable, unwrapped_inputs) + return ctx.wrap_tensors(cond_return) diff --git a/lib/python3.10/site-packages/torch/_higher_order_ops/torchbind.py b/lib/python3.10/site-packages/torch/_higher_order_ops/torchbind.py new file mode 100644 index 0000000000000000000000000000000000000000..b35b8d5b296d118045347c35940454d88da07133 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_higher_order_ops/torchbind.py @@ -0,0 +1,142 @@ +# mypy: allow-untyped-defs +import logging +from contextlib import contextmanager + +import torch +from torch._C import DispatchKey # @manual +from torch._functorch._aot_autograd.utils import KNOWN_TYPES +from torch._higher_order_ops.utils import autograd_not_implemented +from torch._library.fake_class_registry import _ns_and_class_name, FakeScriptObject +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree +from torch.fx.node import has_side_effect +from torch.utils import _pytree as pytree + + +log = logging.getLogger(__name__) + + +# The call_torchbind operator represents a method invocation on a torchbind +# object. The calling convention is: +# call_torchbind(self: ScriptObject, method_name: str, *method_args, **method_kwargs) +# We do not expect users to write this operator directly. Instead it will be +# emitted by Dynamo when tracing encounters a torchbind object. +class CallTorchBind(HigherOrderOperator): + def __init__(self): + super().__init__("call_torchbind") + + def __call__(self, obj, method, *args, **kwargs): + return super().__call__(obj, method, *args, **kwargs) + + +call_torchbind = CallTorchBind() + +# Register this operator as side-effectful with FX. +# TODO: this is not really sufficient. While passes (hopefully) check +# Node.is_impure() and make good decisions, we also assume we can execute the +# graph as many times as we want without changing behavior, which is NOT true of +# ops that mutate torchbind object state. 
+has_side_effect(call_torchbind) + + _orig_scriptmethod_call = torch.ScriptMethod.__call__ + + + def torchbind_method_redispatch(self, *args, **kwargs): + if isinstance(self.raw_owner, torch.ScriptObject): + return call_torchbind(self.raw_owner, self.name, *args, **kwargs) + return _orig_scriptmethod_call(self, *args, **kwargs) + + + @contextmanager + def enable_torchbind_tracing(): + """Context manager that acts as a feature flag to enable torchbind tracing + behavior. Once torchbind tracing has been stabilized, we can remove this and + turn it on unconditionally. + """ + try: + KNOWN_TYPES.append(torch.ScriptObject) + torch.ScriptMethod.__call__ = torchbind_method_redispatch # type: ignore[method-assign] + yield + finally: + assert ( + KNOWN_TYPES.pop() is torch.ScriptObject + ), "Someone else messed with KNOWN_TYPES during tracing, exploding." + torch.ScriptMethod.__call__ = _orig_scriptmethod_call # type: ignore[method-assign] + + + @call_torchbind.py_impl(DispatchKey.CompositeExplicitAutograd) + def call_torchbind_impl(obj, method, *args, **kwargs): + if isinstance(obj, torch.ScriptObject): + return _orig_scriptmethod_call(getattr(obj, method), *args, **kwargs) + elif isinstance(obj, FakeScriptObject): + return getattr(obj.wrapped_obj, method)(*args, **kwargs) + else: + raise RuntimeError(f"Unsupported first arg type {type(obj)} for call_torchbind") + + + @call_torchbind.py_impl(ProxyTorchDispatchMode) + def inner(mode, *args, **kwargs): + proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args) + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) + + out_proxy = mode.tracer.create_proxy( + "call_function", + call_torchbind, + proxy_args, + proxy_kwargs, + ) + out = call_torchbind(*args, **kwargs) + + obj, method, *rest_args = args + if isinstance(obj, torch.ScriptObject): + ns, class_name = _ns_and_class_name( + obj._type().qualified_name() # type: ignore[attr-defined] + ) + log.warning( + "Tracing torchbind method %s.%s with real ScriptObject. This may" + " cause the original object to be mutated. If this is not intended," + ' you can register a fake class with torch._library.register_fake_class("%s::%s").', + class_name, + method, + ns, + class_name, + ) + + ret = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + if "val" not in out_proxy.node.meta: + assert out is None or isinstance( + out, (int, float, bool) + ), "Currently, only these constant types can be returned from torchbind methods." + out_proxy.node.meta["val"] = out + return ret + + + # When tracing with a fake script object, the call_torchbind op will return a fake tensor. + # When tracing with a real script object, the call_torchbind op may return a real tensor; + # we need to convert it to a fake tensor manually. Dynamic shapes are supported.
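For context, enable_torchbind_tracing above is the opt-in switch for tracing or exporting modules that hold ScriptObject attributes. A typical usage sketch, where module_with_script_object and example_args are hypothetical placeholders (an editor's sketch, not part of the upstream change):

import torch
from torch._higher_order_ops.torchbind import enable_torchbind_tracing

with enable_torchbind_tracing():
    exported = torch.export.export(module_with_script_object, example_args)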
+@call_torchbind.py_impl(FakeTensorMode) +def call_torchbind_fake(mode, *args, **kwargs): + with mode: + out = call_torchbind_impl(*args, **kwargs) + return pytree.tree_map_only( + torch.Tensor, + lambda x: mode.from_tensor(x, static_shapes=True) + if not isinstance(x, torch._subclasses.fake_tensor.FakeTensor) + else x, + out, + ) + + +call_torchbind.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(call_torchbind, deferred_error=True) +) + + +@call_torchbind.py_functionalize_impl +def call_torchbind_func(ctx, *args, **kwargs): + from torch._higher_order_ops.effects import handle_effects + + return handle_effects( + ctx.mode._allow_token_discovery, ctx.mode._tokens, call_torchbind, args, kwargs + ) diff --git a/lib/python3.10/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py b/lib/python3.10/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py new file mode 100644 index 0000000000000000000000000000000000000000..5a1ad4405c5ec3d8d7cfd0967cc9a19a11e3c1ce --- /dev/null +++ b/lib/python3.10/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py @@ -0,0 +1,1044 @@ +# mypy: allow-untyped-defs +import collections +import copy +import dataclasses +import inspect +import logging +import threading +from collections import defaultdict +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.fx as fx +import torch.utils._pytree as pytree +from torch import Tensor +from torch._C import DispatchKey +from torch._ops import HigherOrderOperator +from torch._prims_common import clone_preserve_strides +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) + + +log = logging.getLogger("torch._dynamo") + + +############################################################################### +# Kernel Side Table + + +# We cannot put Triton Kernels into the FX graph as the graph nodes +# do not support arbitrary functions. +# Use a side table. +# We use two dicts so that fetching both the kernel and id are O(1) +class KernelSideTable: + id_to_kernel: Dict[int, Any] = {} + kernel_to_id: Dict[Any, int] = {} + constant_args: Dict[int, Any] = {} + lock = threading.Lock() + + # Returns index on the table + def add_kernel(self, kernel) -> int: + with self.lock: + if kernel in self.kernel_to_id: + return self.kernel_to_id[kernel] + + idx = len(self.id_to_kernel) + self.id_to_kernel[idx] = kernel + self.kernel_to_id[kernel] = idx + return idx + + # Returns the triton kernel at the given index + def get_kernel(self, idx: int): + # No need to lock here as fetching from dict is atomic + assert idx in self.id_to_kernel + return self.id_to_kernel[idx] + + # Not every constant arg can be added to the graph. Use this side table + # for constant args. 
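A minimal round trip through the side table defined above (instantiated as kernel_side_table just below); my_kernel stands in for any Triton kernel object and is a hypothetical placeholder (an editor's sketch):

idx = kernel_side_table.add_kernel(my_kernel)
assert kernel_side_table.get_kernel(idx) is my_kernel
assert kernel_side_table.add_kernel(my_kernel) == idx  # repeated adds reuse the same index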
+ def add_constant_args(self, args) -> int: + with self.lock: + idx = len(self.constant_args) + self.constant_args[idx] = args + return idx + + # Returns the constant args + def get_constant_args(self, idx: int): + # No need to lock here as fetching from dict is atomic + assert idx in self.constant_args + return self.constant_args[idx] + + # Resets the table (only meant to be used in unit tests) + # This is only safe assuming single threaded execution + def reset_table(self) -> None: + self.id_to_kernel = {} + self.kernel_to_id = {} + self.constant_args = {} + + +kernel_side_table = KernelSideTable() + + +############################################################################### +# Mutation Tracker + + +@dataclasses.dataclass(frozen=True) +class Param: + idx: int + + +@dataclasses.dataclass(frozen=True) +class Intermediate: + idx: int + + def fake(self): + return self.idx < 0 + + +@dataclasses.dataclass(frozen=True) +class Op: + name: str + fn_call_name: Optional[str] + args: List[Union[Param, Intermediate]] + ret: Intermediate = dataclasses.field(repr=False) + + def __post_init__(self): + if self.name == "tt.call": + assert self.fn_call_name is not None + else: + assert self.fn_call_name is None + + +def generate_ttir(kernel, kwargs): + """ + Uses Triton's internal code generation to create TTIR + """ + import sympy + import triton + from triton.compiler.compiler import ASTSource + from triton.runtime.autotuner import Autotuner + from triton.runtime.jit import JITFunction + + import torch + import torch._inductor.ir + from torch._subclasses.fake_tensor import FakeTensor + + if isinstance(kernel, Autotuner): + if len(kernel.configs) > 0: + # If we are autotuning, then it doesn't matter which version gets + # picked for tracing purposes, so lets pick the first one + kwargs = {**kwargs, **kernel.configs[0].kwargs} + kernel = kernel.fn + + assert isinstance(kernel, JITFunction) + + if len(kwargs) != len(kernel.arg_names): + raise ValueError("Incorrect number of arguments passed to kernel") + + # Replace all SymExprs with a regular value for TTIR generation + # Replace all FakeTensor/TensorBox with real tensors + # These replacements are needed for triton's type, key and config functions + ordered_args: Dict[str, Any] = {} + for name in kernel.arg_names: + a = kwargs[name] + if isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool, sympy.Expr)): + ordered_args[name] = 2 + elif isinstance(a, (FakeTensor, torch._inductor.ir.TensorBox)): + with torch._C._DisableTorchDispatch(): + ordered_args[name] = torch.empty(2, dtype=a.dtype) + else: + ordered_args[name] = a + + ordered_tensor_names = [ + name for name, arg in ordered_args.items() if isinstance(arg, Tensor) + ] + specialization = kernel._get_config(*ordered_args.values()) + constants = { + name: arg for name, arg in ordered_args.items() if not isinstance(arg, Tensor) + } + + # Build kernel signature -- doesn't include constexpr arguments. + signature = { + name: kernel._type_of(kernel._key_of(arg)) + for i, (name, arg) in enumerate(ordered_args.items()) + if i not in kernel.constexprs + } + + context = triton._C.libtriton.ir.context() + target = triton.runtime.driver.active.get_current_target() + backend = triton.compiler.compiler.make_backend(target) + options = backend.parse_options({}) + triton._C.libtriton.ir.load_dialects(context) + backend.load_dialects(context) + + src = ASTSource(kernel, signature, constants, specialization) + + # Triton changes ASTSource.make_ir to take 3/4 arguments. Handle + # backward compatibility here. 
+ make_ir_sig_params = len(inspect.signature(src.make_ir).parameters) + if make_ir_sig_params == 2: + ttir_module = src.make_ir(options, context) + elif make_ir_sig_params == 3: + codegen_fns = backend.get_codegen_implementation() + ttir_module = src.make_ir(options, codegen_fns, context) + else: + codegen_fns = backend.get_codegen_implementation() + module_map = backend.get_module_map() + ttir_module = src.make_ir(options, codegen_fns, module_map, context) + if not ttir_module.verify(): + raise RuntimeError("Verification for TTIR module has failed") + + return ttir_module, ordered_tensor_names + + +def ttir_to_functions(ttir_module) -> Dict[str, Dict[Intermediate, List[Op]]]: + """ + Walk the `ttir_module` bottom up to mine the `functions` from + the structured MLIR entities representing the Triton kernel + (mlir::Operation, mlir::Block, mlir::Region). + """ + functions: Dict[str, Dict[Intermediate, List[Op]]] = {} + + # block id --> op result (Intermediate) --> one or more ops + op_stack: Dict[int, Dict[Intermediate, List[Op]]] = defaultdict( + lambda: defaultdict(list) + ) + region_id_to_block_ids: Dict[int, List[int]] = defaultdict(list) + block_id_to_block_arg_ids: Dict[int, List[int]] = {} + replacements: Dict[int, Union[Intermediate, Param]] = {} + reindex_map: Dict[int, int] = {} + next_fake_intermediate = 0 + + def reindex(idx): + if idx not in reindex_map: + reindex_map[idx] = len(reindex_map) + return reindex_map[idx] + + def mlir_to_functions(op) -> None: + name: str = op.get_name() + if name == "builtin.module": + # this wraps all tt.func ops + return + + operand_ids: List[int] = [ + reindex(op.get_operand(i).id()) for i in range(op.get_num_operands()) + ] + result_ids: List[int] = [ + reindex(op.get_result(i).id()) for i in range(op.get_num_results()) + ] + + child_block_ids: List[int] = [] + for i in [op.get_region(i).id() for i in range(op.get_num_regions())]: + # as the walk is bottom-up, the region_id_to_block_ids[i] + # must be populated by the time we process the enclosing op + child_block_ids.extend(region_id_to_block_ids[i]) + + parent_block_id = -1 + parent_block = op.get_block() + if parent_block is not None: + parent_block_id = parent_block.id() + if parent_block_id not in block_id_to_block_arg_ids: + block_id_to_block_arg_ids[parent_block_id] = [] + for i in range(parent_block.get_num_arguments()): + block_id_to_block_arg_ids[parent_block_id].append( + reindex(parent_block.get_argument(i).id()), + ) + # the region info is collected via ops' parent blocks to be + # used later when the region's encloding op is traversed + parent_region = parent_block.get_parent() + if parent_region is not None: + region_id_to_block_ids[parent_region.id()].append(parent_block_id) + + nonlocal next_fake_intermediate + + if name == "tt.func": + # for function ops: gather and inline + # the ops from all child blocks + fn_ops = defaultdict(list) + for child_block_id in child_block_ids: + for result, block_fn_ops in op_stack.pop(child_block_id).items(): + for block_fn_op in block_fn_ops: + fn_ops[result].append(block_fn_op) + + # replace the corresponding Intermediates in the + # child op args with the function args (Params) + for i, idx in enumerate(block_id_to_block_arg_ids[child_block_ids[0]]): + replacements[idx] = Param(i) + + for fn_op_list in fn_ops.values(): + for fn_op in fn_op_list: + for i in range(len(fn_op.args)): + arg = fn_op.args[i] + seen = set() # to break cycles + # there can be transitive replacements, but likely + # no cycles (we keep the `seen` set just in case) + 
while ( + isinstance(arg, Intermediate) + and arg.idx in replacements + and arg.idx not in seen + ): + seen.add(arg.idx) + arg = fn_op.args[i] = replacements[arg.idx] + + # next function capture starts + # with empty replacements + replacements.clear() + + fn_name = op.get_str_attr("sym_name") + functions[fn_name] = fn_ops + elif child_block_ids: + if name in {"scf.if", "scf.for", "scf.while", "tt.reduce", "tt.scan"}: + # for blocked ops: inline the enclosed ops into + # the parent block + rewire the last op in each + # child block to return the block result + return_ops = [] + for block_id in child_block_ids: + if name == "scf.for": + # example: + # %result = scf.for %iv = %lb to %ub step %step iter_args(%arg = %init) -> (i32) ... + # block args: 2 (%iv, %arg) + # op operands: 4 (%lb, %ub, %step, %init) + # `%arg` is mapping to `%init` + for i, idx in enumerate(block_id_to_block_arg_ids[block_id]): + if i == 0: + next_fake_intermediate -= 1 + replacements[idx] = Intermediate(next_fake_intermediate) + else: + replacements[idx] = Intermediate(operand_ids[i + 2]) + elif name == "scf.while": + # example: + # %3:3 = scf.while (%arg2 = %1, %arg3 = %2, %arg4 = %c0_i32_8) ... + # block args: 3 (%arg2, %arg3, %arg4) + # op operands: 3 (%1, %2, %c0_i32_8) + # `%arg2` is mapping to `%1`, `%arg3` is mapping to `%2`, ... + for i, idx in enumerate(block_id_to_block_arg_ids[block_id]): + replacements[idx] = Intermediate(operand_ids[i]) + elif name == "scf.if": + # the scf block args are ignored by the pass. but, as they + # may be used as operands of the ops inside the block + # (and nested blocks inlined in the current block by now), + # they are replaced by new fake Intermediates to avoid "this + # operand is not returned by any other op in the fn" error + # in the downstream analysis + for idx in block_id_to_block_arg_ids[block_id]: + next_fake_intermediate -= 1 + replacements[idx] = Intermediate(next_fake_intermediate) + else: + assert name in ("tt.reduce", "tt.scan") + # wire the block arguments to the op arguments + num_operands = len(operand_ids) + block_arg_ids = block_id_to_block_arg_ids[block_id] + assert len(block_arg_ids) == 2 * num_operands, ( + f"{name} is expected to have twice as " + "many block arguments as op arguments: " + f"{operand_ids=}, {block_arg_ids=}." + ) + for i, idx in enumerate(block_arg_ids): + # for a tt.reduce/tt.scan op with N arguments, the block + # arguments comprise N reduced values followed by + # N current values corresponding to the N op args + replacements[idx] = Intermediate( + operand_ids[i % num_operands] + ) + + if block_id in op_stack: + block_ops = op_stack.pop(block_id) + if not block_ops: + continue + last_ret, last_ops = block_ops.popitem() + if all( + op.name + in ("scf.yield", "tt.reduce.return", "tt.scan.return") + for op in last_ops + ): + # if last_ops are all return ops, treat them separately + return_ops.extend(last_ops) + else: + # otherwise, return last_ops to the block + block_ops[last_ret] = last_ops + for op_result, child_ops in block_ops.items(): + op_stack[parent_block_id][op_result].extend(child_ops) + + scf_results = [Intermediate(idx) for idx in result_ids] + for scf_result in scf_results: + for return_op in return_ops: + op_stack[parent_block_id][scf_result].append(return_op) + else: + raise RuntimeError( + f"Unknown blocked function: {name}. Can't capture the TTIR." 
+ ) + else: + callee = None + if name == "tt.call": + callee = op.get_flat_symbol_ref_attr("callee") + args: List[Union[Param, Intermediate]] = [ + Intermediate(operand) for operand in operand_ids + ] + block_ops = op_stack[parent_block_id] + if result_ids: + for result_id in result_ids: + res = Intermediate(result_id) + block_ops[res].append(Op(name, callee, args, res)) + else: + next_fake_intermediate -= 1 + fake_res = Intermediate(next_fake_intermediate) + block_ops[fake_res].append(Op(name, callee, args, fake_res)) + + ttir_module.walk(mlir_to_functions) + + return functions + + +class MemoizeWithCycleCheck: + def __init__(self, fn): + self.fn = fn + self.reset() + + def __call__(self, functions, fn_name, num_args): + key = (fn_name, num_args) + if key not in self.cache: + self.cache[key] = None + self.cache[key] = self.fn(functions, fn_name, num_args) + if self.cache[key] is None: + raise RuntimeError("Recursion is not supported") + return self.cache[key] + + def reset(self): + self.cache = {} + + +@MemoizeWithCycleCheck +def analyze_kernel_mutations(functions, fn_name, num_args): + """ + Analyzes the graph to detect all sinks from a predefined list of sinks + by using triton's MemWrite trait list. NOTE: What if triton exposed this? + From each sink, it traverses the CFG backwards to identify all the input + pointers that are mutated. + """ + # Name of mutation op to mutated parameter indices + # List from Triton Github include/triton/Dialect/Triton/IR/TritonOps.td + # All the OPs that have MemWrite trait. + # What if Triton exposed this? + MUTATION_OPS = {"tt.store": [0], "tt.atomic_cas": [0], "tt.atomic_rmw": [0]} + # Ops that we want to bail out on + UNKNOWN_OPS = {"tt.elementwise_inline_asm"} + + stack: List[Union[Param, Intermediate]] = [] + visited = set() + ops = functions[fn_name] + for op_list in ops.values(): + for op in op_list: + if op.name in UNKNOWN_OPS: + raise RuntimeError( + f"ttir analysis hit an op we do not know how to analyze: {op.name}" + ) + + if op.name == "tt.call": + assert op.fn_call_name in functions + mutations = analyze_kernel_mutations( + functions, op.fn_call_name, len(op.args) + ) + stack.extend(arg for arg, mutated in zip(op.args, mutations) if mutated) + else: + for idx in MUTATION_OPS.get(op.name, []): + stack.append(op.args[idx]) + + # The following is an iterative DFS algorithm + mutated = [False] * num_args + while stack: + arg = stack.pop() + if arg in visited: + continue + + visited.add(arg) + + if isinstance(arg, Param): + if arg.idx >= num_args: + # This is an argument defined in the kernel, not passed in + continue + mutated[arg.idx] = True + elif isinstance(arg, Intermediate) and not arg.fake(): + for op in ops[arg]: + # Skip arguments to load + if op.name != "tt.load": + stack.extend(op.args) + return mutated + + +def identify_mutated_tensors(kernel, kwargs): + """ + Given a triton kernel and the arguments for this kernel, this function + 1) Retrieves the TTIR converted version of the kernel from Triton's API. 
+ 2) Parses the TTIR and creates a control flow graph + 3) Analyzes the graph to detect all input tensor mutations + """ + + ttir_module = None + functions = None + try: + ttir_module, ordered_tensor_names = generate_ttir(kernel, kwargs) + + # extract functions from TTIR using MLIR bindings exposed by Triton code + functions = ttir_to_functions(ttir_module) + + assert functions is not None + kernel_name = next(iter(functions.keys())) + # Triton codegen modifies the name + assert kernel.fn.__name__ in kernel_name + # Reset the cache between top level invocations + # The cache for analyze kernel mutations is mainly used for cycle + # detection, so each top level invocation needs a clean cache + analyze_kernel_mutations.reset() + mutations = analyze_kernel_mutations( + functions, kernel_name, len(ordered_tensor_names) + ) + + return [ + ordered_tensor_names[i] for i, mutated in enumerate(mutations) if mutated + ] + except Exception as e: + log.warning( + "Encountered an exception in identify_mutated_tensors, assuming every input is mutated", + exc_info=True, + ) + if ttir_module is not None: + log.debug("TTIR:\n%s", str(ttir_module)) + if functions is not None: + log.debug("functions:") + for name, fn in functions.items(): + log.debug("===\t%s\t===", name) + for ret, ops in fn.items(): + log.debug("%s\t=>\t%s", ret, ops) + return [key for key, value in kwargs.items() if isinstance(value, Tensor)] + + +############################################################################### +# Triton Kernel Wrappers + + +# Used for wrapping a Triton Kernel +class TritonKernelWrapperMutation(HigherOrderOperator): + def __init__(self) -> None: + super().__init__("triton_kernel_wrapper_mutation") + + def __call__(self, kernel_idx, constant_args_idx, grid, kwargs): + return super().__call__( + kernel_idx=kernel_idx, + constant_args_idx=constant_args_idx, + grid=grid, + kwargs=kwargs, + ) + + +triton_kernel_wrapper_mutation = TritonKernelWrapperMutation() + + +# Used for wrapping a Triton Kernel in a functional manner +class TritonKernelWrapperFunctional(HigherOrderOperator): + def __init__(self) -> None: + super().__init__("triton_kernel_wrapper_functional") + + def __call__(self, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone): + return super().__call__( + kernel_idx=kernel_idx, + constant_args_idx=constant_args_idx, + grid=grid, + kwargs=kwargs, + tensors_to_clone=tensors_to_clone, + ) + + +triton_kernel_wrapper_functional = TritonKernelWrapperFunctional() + + +@triton_kernel_wrapper_mutation.py_impl(DispatchKey.CompositeExplicitAutograd) +def triton_kernel_wrapper_mutation_dense( + *, kernel_idx, constant_args_idx, grid, kwargs +): + from torch._inductor.codegen.wrapper import user_defined_kernel_grid_fn_code + + kernel = kernel_side_table.get_kernel(kernel_idx) + constant_args = kernel_side_table.get_constant_args(constant_args_idx) + + if len(grid) == 1: + grid_fn = grid[0] + else: + fn_name, code = user_defined_kernel_grid_fn_code( + kernel.fn.__name__, kernel.configs, grid + ) + namespace: Dict[str, Any] = {} + exec(code, namespace) + grid_fn = namespace[fn_name] + + kernel[grid_fn](**kwargs, **constant_args) + + +@triton_kernel_wrapper_mutation.py_impl(FakeTensorMode) +def triton_kernel_wrapper_mutation_fake_tensor_mode( + mode, *, kernel_idx, constant_args_idx, grid, kwargs +): + with mode: + return None + + +@triton_kernel_wrapper_mutation.py_impl(DispatchKey.Meta) +def _(*, kernel_idx, constant_args_idx, grid, kwargs): + return None + + +def trace_triton_kernel_wrapper(proxy_mode, 
func_overload, node_args): + with disable_proxy_modes_tracing(): + out = func_overload(**node_args) + + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", + func_overload, + (), + proxy_args, + name=func_overload.__name__ + "_proxy", + ) + ret = track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + return ret + + +@triton_kernel_wrapper_mutation.py_impl(ProxyTorchDispatchMode) +def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode( + mode, *, kernel_idx, constant_args_idx, grid, kwargs +): + trace_triton_kernel_wrapper( + mode, + triton_kernel_wrapper_mutation, + { + "kernel_idx": kernel_idx, + "constant_args_idx": constant_args_idx, + "grid": grid, + "kwargs": kwargs, + }, + ) + + return None + + +@triton_kernel_wrapper_mutation.py_functionalize_impl +def triton_kernel_wrapper_mutation_functionalize( + ctx, kernel_idx, constant_args_idx, grid, kwargs +): + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + kernel = kernel_side_table.get_kernel(kernel_idx) + constant_args = kernel_side_table.get_constant_args(constant_args_idx) + # TODO(oulgen): Preexisting bug, if two kernel inputs are views of each + # other, and one gets mutated in kernel, and later another gets mutated, + # they are no longer equal. Fix this by graph breaking on this condition + # earlier in dynamo. + tensors_to_clone = identify_mutated_tensors( + kernel, {**unwrapped_kwargs, **constant_args} + ) + with ctx.redispatch_to_next(): + unwrapped_outputs = triton_kernel_wrapper_functional( + kernel_idx=kernel_idx, + constant_args_idx=constant_args_idx, + grid=grid, + kwargs=unwrapped_kwargs, + tensors_to_clone=tensors_to_clone, + ) + + assert set(unwrapped_outputs.keys()).issubset(set(kwargs.keys())) + for key, output_arg in unwrapped_outputs.items(): + if not isinstance(output_arg, Tensor): + continue + input_arg = kwargs[key] + assert isinstance(input_arg, Tensor) + + ctx.replace(input_arg, output_arg) + # indicate that above replace is hidden from autograd + ctx.mark_mutation_hidden_from_autograd(input_arg) + ctx.commit_update(input_arg) + ctx.sync(input_arg) + return None + + +@triton_kernel_wrapper_functional.py_impl(DispatchKey.CompositeExplicitAutograd) +def triton_kernel_wrapper_functional_dense( + *, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone +): + # TODO(oulgen): For performance reasons, we want to ensure that these + # `clone_preserve_strides` calls are never executed at runtime + # (inductor should always optimize them away). + # Requires https://github.com/pytorch/pytorch/issues/109240 + kwargs = { + key: (clone_preserve_strides(val) if key in tensors_to_clone else val) + for key, val in kwargs.items() + } + triton_kernel_wrapper_mutation( + kernel_idx=kernel_idx, + constant_args_idx=constant_args_idx, + grid=grid, + kwargs=kwargs, + ) + return {key: val for key, val in kwargs.items() if key in tensors_to_clone} + + +@triton_kernel_wrapper_functional.py_impl(FakeTensorMode) +def triton_kernel_wrapper_functional_fake_tensor_mode( + mode, *, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone +): + # TODO(oulgen): For performance reasons, we want to ensure that these + # `clone_preserve_strides` calls are never executed at runtime + # (inductor should always optimize them away). 
+ # Requires https://github.com/pytorch/pytorch/issues/109240 + with mode: + return { + key: clone_preserve_strides(val) + for key, val in kwargs.items() + if key in tensors_to_clone + } + + +@triton_kernel_wrapper_functional.py_impl(ProxyTorchDispatchMode) +def triton_kernel_wrapper_functional_proxy_torch_dispatch_mode( + mode, *, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone +): + return trace_triton_kernel_wrapper( + mode, + triton_kernel_wrapper_functional, + { + "kernel_idx": kernel_idx, + "constant_args_idx": constant_args_idx, + "grid": grid, + "kwargs": kwargs, + "tensors_to_clone": tensors_to_clone, + }, + ) + + +@triton_kernel_wrapper_functional.py_functionalize_impl +def triton_kernel_wrapper_functional_functionalize( + ctx, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone +): + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + with ctx.redispatch_to_next(): + outputs = triton_kernel_wrapper_functional( + kernel_idx=kernel_idx, + constant_args_idx=constant_args_idx, + grid=grid, + kwargs=unwrapped_kwargs, + tensors_to_clone=tensors_to_clone, + ) + return ctx.wrap_tensors(outputs) + + +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined] +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonTLSSnapshot) # type: ignore[attr-defined] +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.ADInplaceOrView) +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.BackendSelect) +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCPU) # type: ignore[attr-defined] +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCUDA) # type: ignore[attr-defined] +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCUDA) +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCPU) + +triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined] +triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonTLSSnapshot) # type: ignore[attr-defined] +triton_kernel_wrapper_functional.fallthrough(DispatchKey.ADInplaceOrView) +triton_kernel_wrapper_functional.fallthrough(DispatchKey.BackendSelect) +triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCPU) # type: ignore[attr-defined] +triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCUDA) # type: ignore[attr-defined] +triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCUDA) +triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCUDA) +triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCPU) + + +############################################################################### +# The "TritonHOPifier": a class that transforms a call to a triton kernel into +# a call to the triton_kernel_wrapper_mutation HOP. + + +class TritonHOPifier: + """Orchestrator for converting a user-defined triton kernel into a call + to the triton_kernel_wrapper_mutation HOP. + + It has two main use cases. + + 1. When Dynamo sees a triton kernel, it wraps it into a TritonKernelVariable + and uses the TritonHOPifier to convert calls to the TritonKernelVariable + into a call to the HOP. + + 2. In order to capture a user-defined triton kernel while performing + tracing (via make_fx or non-strict export), a user must annotate their + triton kernel with the `capture_triton` decorator. The decorator uses + TritonHOPifier to convert calls to the triton kernel into a call + to the HOP (which can then be traced). 
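A minimal sketch of the capture_triton flow described above (illustrative only, not part of this module; it assumes Triton is installed and that `capture_triton` is importable from torch._library.triton, the module this file already references):

import torch
import triton
import triton.language as tl
from torch._library.triton import capture_triton
from torch.fx.experimental.proxy_tensor import make_fx

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x, y):
    out = torch.empty_like(x)
    n_elements = out.numel()
    grid = (triton.cdiv(n_elements, 128),)
    # the wrapped kernel call is recorded as triton_kernel_wrapper_mutation
    # rather than launched opaquely while tracing
    capture_triton(add_kernel)[grid](x, y, out, n_elements, BLOCK_SIZE=128)
    return out

gm = make_fx(add)(torch.randn(64, device="cuda"), torch.randn(64, device="cuda"))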
+ + Because Dynamo has its own calling conventions for e.g. invoking a user-defined function + TritonHOPifier is an abstract class that can be overriden by its subclasses. + """ + + def raise_unsupported(self, msg): + raise NotImplementedError("abstract method") + + def is_callable(self, maybe_callable): + raise NotImplementedError("abstract method") + + def get_value(self, val): + raise NotImplementedError("abstract method") + + def call_grid(self, grid, meta, tx): + raise NotImplementedError("abstract method") + + def call_HOP(self, variable, grids, combined_args, tx): + raise NotImplementedError("abstract method") + + def check_grid(self, grid): + raise NotImplementedError("abstract method") + + def init_variable(self, variable, kernel, kernel_idx, grid): + from triton.runtime.autotuner import Autotuner + + assert kernel is not None + + variable.kernel = kernel + variable.kernel_idx = kernel_side_table.add_kernel(kernel) + + assert kernel_idx is None or variable.kernel_idx == kernel_idx + + variable.grid = grid + + if isinstance(kernel, Autotuner): + import torch + import torch._dynamo + + # We only support configs and keys arguments of triton.autotune + # Make sure other arguments are defaulted + defaults = inspect.signature(Autotuner.__init__).parameters + + # Newer version of triton change attribute name from warmup to num_warmup and rep to num_rep. + # The call to get_first_attr is to maintain backward-compatibility. + if ( + not torch._inductor.config.unsafe_ignore_unsupported_triton_autotune_args + and ( + ( + "warmup" in defaults + and defaults["warmup"].default + != torch._dynamo.utils.get_first_attr( + kernel, "num_warmups", "warmup" + ) + ) + or ( + "rep" in defaults + and defaults["rep"].default + != torch._dynamo.utils.get_first_attr(kernel, "num_reps", "rep") + ) + or ( + "prune_configs_by" in defaults + and defaults["prune_configs_by"].default + != kernel.early_config_prune + ) + # Set via reset_to_zero argument + or len(kernel.reset_idx) != 0 + or len(kernel.restore_idx) != 0 + or ( + "use_cuda_graph" in defaults + and defaults["use_cuda_graph"].default != kernel.use_cuda_graph + ) + ) + ): + self.raise_unsupported( + "Only configs and keys are supported for triton.autotune" + ) + + def call_getitem(self, variable, args): + # __getitem__ should only be called if we don't already have a grid + # Only grid needs to be passed + if variable.grid is not None or len(args) != 1: + self.raise_unsupported( + "Triton kernels should be called with only a single grid" + ) + + return type(variable)( + kernel=variable.kernel, + kernel_idx=variable.kernel_idx, + grid=args[0], + ) + + def call_run(self, variable, args, kwargs, tx): + if "grid" not in kwargs: + self.raise_unsupported("Triton kernel requires to be called with a grid") + grid = kwargs.pop("grid") + kwargs.pop("warmup", None) + # rewrite kernel.run(*args, grid=grid) to kernel[grid](*args) + return self.call_triton_kernel( + type(variable)( + kernel=variable.kernel, kernel_idx=variable.kernel_idx, grid=grid + ), + args, + kwargs, + tx, + ) + + def call_triton_kernel(self, variable, args, kwargs, tx): + from triton.runtime.autotuner import autotune, Autotuner, Config + + if "num_ctas" in kwargs: + self.raise_unsupported( + "Passing num_ctas directly to the Triton kernel is not supported. " + "Please use a Config in @triton.autotune instead." 
+ ) + + special_kwargs = {} + for name in ("num_warps", "num_stages"): + if name in kwargs: + # remove special kwargs from `kwargs` + val = kwargs.pop(name) + special_kwargs[name] = self.get_value(val) + + if special_kwargs: + if isinstance(variable.kernel, Autotuner): + # if there is Autotuner already, set + # special kwargs to each of its configs + new_configs = copy.deepcopy(variable.kernel.configs) + for config in new_configs: + config.__dict__.update(special_kwargs) + new_kernel = autotune(configs=new_configs, key=[])(variable.kernel.fn) + else: + # if there is no Autotuner, wrap the kernel into a + # new one with a single config with special kwargs + new_config = Config(kwargs={}, **special_kwargs) + new_kernel = autotune(configs=[new_config], key=[])(variable.kernel) + + # create a new variable to contain the new (wrapped) kernel; + # skip kernel_idx to get a new record in the kernel side table + new_var = type(variable)(new_kernel, None, variable.grid) + return self.call_triton_kernel(new_var, args, kwargs, tx) + + if variable.grid is None: + self.raise_unsupported("Triton kernels should always be called with a grid") + + # Both for grid's meta as well as for the kernel, we need combined + # args and kwargs combined and normalized + combined_args_raw = {**dict(zip(variable.kernel.arg_names, args)), **kwargs} + + configs = ( + [config.kwargs for config in variable.kernel.configs] + if isinstance(variable.kernel, Autotuner) + else [{}] + ) + grids = [] + for config_args in configs: + # If the grid is a function, then lets execute it and convert it to + # a list + grid = variable.grid + if self.is_callable(grid): + # Populate the special "meta" argument to call the grid function + meta = {**combined_args_raw, **config_args} + grid = self.call_grid(grid, meta, tx) + grids.append(self.check_grid(grid)) + + for i in range(len(grids)): + if not isinstance(grids[i], tuple): + self.raise_unsupported("Only tuple grids are supported") + # inductor expects all grids to be 3-tuple so lets make it + if len(grids[i]) == 1: + grids[i] = (grids[i][0], 1, 1) + elif len(grids[i]) == 2: + grids[i] = (grids[i][0], grids[i][1], 1) + elif len(grids[i]) > 3: + self.raise_unsupported("Grid can have at most rank 3") + + assert len(grids) != 0 + + def intify(x): + if isinstance(x, torch.SymInt): + return int(x) + else: + return x + + if len(set(pytree.tree_map(intify, grids))) == 1: + # If there's only one unique grid, lets simplify + grids = [grids[0]] + + return self.call_HOP(variable, grids, combined_args_raw, tx) + + +############################################################################### +# Helpers for capture_triton API that makes a user-defined triton kernel traceable into +# a graph via make_fx or non-strict export (coming soon) + + +class TracingTritonHOPifier(TritonHOPifier): + def raise_unsupported(self, msg): + raise RuntimeError(msg) + + def is_callable(self, maybe_callable): + return callable(maybe_callable) + + def get_value(self, val): + return val + + def call_grid(self, grid, meta, tx): + assert tx is None + return grid(meta) + + def check_grid(self, grid): + if not isinstance(grid, collections.abc.Sequence): + raise RuntimeError( + "capture_triton can only handle grids that resolve to Sequence[int]." 
+ ) + # normalize to tuple + return tuple(grid) + + def call_HOP(self, variable, grids, combined_args, tx): + assert tx is None + + def is_graphable(val): + return isinstance(val, fx.node.base_types) + + non_graphable_args = { + k: v for k, v in combined_args.items() if not is_graphable(v) + } + graphable_args = {k: v for k, v in combined_args.items() if is_graphable(v)} + + constant_args_idx = kernel_side_table.add_constant_args(non_graphable_args) + return triton_kernel_wrapper_mutation( + kernel_idx=variable.kernel_idx, + constant_args_idx=constant_args_idx, + grid=grids, + kwargs=graphable_args, + ) + + +tracing_triton_hopifier_singleton = TracingTritonHOPifier() + + +class TraceableTritonKernelWrapper: + def __init__(self, kernel, kernel_idx, grid): + self.kernel = None + self.grid = None + tracing_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid) + assert self.kernel is not None + + def __getitem__(self, *args): + return tracing_triton_hopifier_singleton.call_getitem(self, args) + + def run(self, *args, **kwargs): + from torch._library.triton import is_capture_triton_enabled + + if is_capture_triton_enabled(): + return tracing_triton_hopifier_singleton.call_run(self, args, kwargs, None) + else: + assert self.kernel is not None + return self.kernel.run(*args, **kwargs) + + def __call__(self, *args, **kwargs): + from torch._library.triton import is_capture_triton_enabled + + if is_capture_triton_enabled(): + return tracing_triton_hopifier_singleton.call_triton_kernel( + self, args, kwargs, None + ) + else: + assert self.kernel is not None + return self.kernel[self.grid](*args, **kwargs) diff --git a/lib/python3.10/site-packages/torch/_higher_order_ops/utils.py b/lib/python3.10/site-packages/torch/_higher_order_ops/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3b88d5f7dbac8ae0c4cc55d162804fb3fcb9113a --- /dev/null +++ b/lib/python3.10/site-packages/torch/_higher_order_ops/utils.py @@ -0,0 +1,379 @@ +# mypy: allow-untyped-defs +import functools +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Callable + +import torch +import torch.fx.traceback as fx_traceback +import torch.utils._pytree as pytree +from torch._ops import OperatorBase +from torch.fx.experimental.proxy_tensor import make_fx +from torch.multiprocessing.reductions import StorageWeakRef + + +@dataclass +class UnsupportedAliasMutationException(RuntimeError): + reason: str + + +def autograd_not_implemented_inner( + operator: OperatorBase, delayed_error: bool, *args: Any, **kwargs: Any +) -> Any: + """If autograd is enabled and any of the arguments require grad this will either + raise an error or return a DelayedError depending on the value of delayed. 
+ + Args: + operator: The Operator to call with the *args and **kwargs with + op_name: The name of the Operator + delayed_error: If True, return a DelayedError instead of raising an error + args: The flattened operands to the Operator + kwargs: The keyword arguments to the Operator + + Raises: + RuntimeError: If autograd is enabled and any of the arguments to the Operator + """ + with torch._C._AutoDispatchBelowAutograd(): + result = operator(*args, **kwargs) + flat_operands = pytree.arg_tree_leaves(*args) + if torch.is_grad_enabled() and any( + f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor) + ): + if delayed_error: + err_fn = torch._C._functions.DelayedError( + f"Autograd not implemented for {str(operator)}", + 1, + ) + + def fake_requires_grad(tensor): + if torch.is_floating_point(tensor) or torch.is_complex(tensor): + tensor = tensor.detach() + tensor.requires_grad = True + return tensor + + return pytree.tree_map_only( + torch.Tensor, lambda x: err_fn(fake_requires_grad(x)), result + ) + else: + raise RuntimeError(f"Autograd not implemented for {str(operator)}") + return result + + +def autograd_not_implemented(op: OperatorBase, deferred_error: bool) -> Callable: + def inner(*args, **kwargs): + return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs) + + return inner + + +def _maybe_run_with_interpreter(fn): + maybe_interpreted_fn = fn + if isinstance(fn, torch.fx.GraphModule) and fx_traceback.has_preserved_node_meta(): + # Running graph with interpreter is needed for propagating the stack_trace + def graph_with_interpreter(*args): + with fx_traceback.preserve_node_meta(): + return torch.fx.Interpreter(fn).run(*args) + + maybe_interpreted_fn = graph_with_interpreter + return maybe_interpreted_fn + + +def reenter_make_fx(fn): + from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER + + @functools.wraps(fn) + def wrapped(*args): + assert ( + _CURRENT_MAKE_FX_TRACER is not None + ), "Cannot reenter make_fx when we're not under a make_fx tracing session" + return _CURRENT_MAKE_FX_TRACER.trace_subgraph( + _maybe_run_with_interpreter(fn), *args + ) + + return wrapped + + +def _maybe_reenter_make_fx(fn): + from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER + + if _CURRENT_MAKE_FX_TRACER is not None: + return reenter_make_fx(fn) + else: + return make_fx(fn) + + +@contextmanager +def _set_compilation_env(): + _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag + try: + # We need to turn off the is_fx_tracing_flag. Remove this flag check from dyanmo + # once we are confident fx tracing works with dynamo. + torch.fx._symbolic_trace._is_fx_tracing_flag = False + yield + finally: + torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing + + +def _has_potential_branch_input_mutation(branch, inputs, pre_dispatch=False): + """ + Dispatch-trace the branch with inputs and check if + producing graph has mutable op on the input. This is + bit restrictive as the branch must be traceable. 
+ """ + try: + gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs) + except UnsupportedAliasMutationException: + # this can happen when nested cond_op is + # functionalized + return True + except Exception as e: + raise e + + def _detect_input_mutation(gm): + input_nodes = set() + for node in gm.graph.nodes: + if node.op == "placeholder": + input_nodes.add(node) + if node.op == "call_function": + target = node.target + if ( + isinstance(target, torch._ops.OpOverload) + and target._schema.is_mutable + ): + for arg in node.args: + if arg in input_nodes: + return True + + for _, module in gm.named_children(): + if isinstance(module, torch.fx.GraphModule): + if _detect_input_mutation(module): + return True + + return False + + return _detect_input_mutation(gm) + + +def _has_potential_branch_input_alias(branch, inputs, pre_dispatch=False): + """ + Dispatch-trace the branch with inputs and check if + producing graph has output aliasing the branch input. This is + bit restrictive as the branch must be traceable. + """ + try: + gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs) + except UnsupportedAliasMutationException: + # this can happen when nested cond_op is + # functionalized + return True + except Exception as e: + raise e + + def _detect_input_alias(gm): + input_storages = set() + for node in gm.graph.nodes: + # We need to check existence of "val" because we reuse the logic here + # for map operator, where num_mapped_args is a scalar + # and doesn't have a "val" meta. + if node.op == "placeholder" and "val" in node.meta: + input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage())) + if node.op == "output": + + def check_alias(out): + if out is not None and "val" in out.meta: + out_storage = StorageWeakRef(out.meta["val"]._typed_storage()) + return out_storage in input_storages + return False + + if any(pytree.tree_leaves(pytree.tree_map(check_alias, node.args))): + return True + + for _, module in gm.named_children(): + if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module): + return True + + return False + + return _detect_input_alias(gm) + + +def unique_graph_id(proxy_mode, prefix): + """Returns a unique name and id for a graph to be added to a proxy_mode tracer""" + # There are probably better ways - I know that create_arg has some self incrementing name + # magic to it, but since we explicitly have to get the name for register_module, + # I was not sure how to do that. This kinda simulates it. 
+ next_name = None + i = 0 + while not next_name: + candidate = f"{prefix}_{i}" + if hasattr(proxy_mode.tracer.root, candidate): + i += 1 + else: + next_name = candidate + return i, next_name + + +def _from_fun(t): + from torch._functorch.aot_autograd import from_fun + from torch._subclasses.functional_tensor import FunctionalTensor + + if isinstance(t, torch.Tensor): + if t.dtype != torch.bool: + return torch.empty_strided( + t.size(), + t.stride(), + dtype=t.dtype, + requires_grad=t.requires_grad, + ) + else: + # clone of a functional tensor produces a functional tensor + # but we want to avoid it so we clone a non-functional version + maybe_unfunc_t = t + if isinstance(t, FunctionalTensor): + torch._sync(t) + maybe_unfunc_t = from_fun(t) + elif torch._is_functional_tensor(t): + # need to handle both types of functionalization here: + # these are the tensors that came from the user, + # which could be either FunctionalTensorWrapper or FunctionalTensor + torch._sync(t) + maybe_unfunc_t = torch._from_functional_tensor(t) + return maybe_unfunc_t.clone() + return t + + +def clone_outputs_aliasing_inputs(args): + input_storage = { + StorageWeakRef(arg._typed_storage()) + for arg in args + if isinstance(arg, torch.Tensor) + } + + def maybe_clone(t): + if ( + isinstance(t, torch.Tensor) + and StorageWeakRef(t._typed_storage()) in input_storage + ): + return t.clone() + return t + + return maybe_clone + + +def prepare_fw_with_masks(fn): + def fw_with_masks(*args): + fw_out = fn(*args) + return fw_out, [ + True if isinstance(ret, torch.Tensor) and ret.requires_grad else False + for ret in fw_out + ] + + return fw_with_masks + + +# TODO: The parameter use_output_and_grad_bw is required because some operations +# that utilize this function, such as the while_loop, may require (grad, fwd_outputs) +def create_fw_bw_graph(fn, use_output_and_grad_bw, fw_inputs, fw_outputs): + from torch._functorch.aot_autograd import AOTConfig, create_joint + + # Note:[HOP create fw_bw graph] We create "clean" environments for make_fx by suspending all dispatch keys + # between Autograd and Python key. Currently, we only suspend functionalization but more can be + # added when required. Will encounter two problems if we don't suspend functionalization: + # + # 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper, + # but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching. + # However, it's the outside wrapper that tracer creates proxies for. This casuses tracer fail to + # fetch the proxy for the inputs and fail to capture any operations on them. + # + # 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further + # wrapped as FunctionalTensorWrapper in Functionalize key after return. However, the tracer + # only associates the inner tensor with proxy in ProxyTorchDispatchMode. Therefore, + # when creating the output node, it fails to associate the wrapped tensor with its proxy. + # Instead, it will create _tensor_constant as output. 
+ + dummy_aot_config = AOTConfig( + fw_compiler=None, # type: ignore[arg-type] + bw_compiler=None, # type: ignore[arg-type] + partition_fn=None, # type: ignore[arg-type] + decompositions={}, + num_params_buffers=0, + aot_id=0, + keep_inference_input_mutations=False, + ) + + example_grad = [_from_fun(out) for out in fw_outputs] + num_grads = len(example_grad) + fw_graph = _maybe_reenter_make_fx(fn)(*fw_inputs) + + def joint_fn(*joint_operands_grads): + if use_output_and_grad_bw: + grads = joint_operands_grads[0] + inputs = joint_operands_grads[1][-1:] + else: + grads = joint_operands_grads[:num_grads] + inputs = joint_operands_grads[num_grads:] + + joint = create_joint(prepare_fw_with_masks(fn), aot_config=dummy_aot_config) + _, grads = joint( + list(inputs), + [grad for grad in grads if grad is not None and grad.requires_grad], + ) + + # In order to keep map functional for backward graph, + # we clone outputs that are aliasing inputs + maybe_clone = clone_outputs_aliasing_inputs(joint_operands_grads) + + return pytree.tree_map(maybe_clone, grads) + + if use_output_and_grad_bw: + example_xs_out = list(fw_inputs) + list(fw_outputs) + joint_graph = _maybe_reenter_make_fx(joint_fn)( + (list(example_grad), list(example_xs_out)) + ) + else: + example_xs_out = list(fw_inputs) + joint_graph = _maybe_reenter_make_fx(joint_fn)( + *(list(example_grad) + list(example_xs_out)) + ) + + return fw_graph, joint_graph + + +def _unstack_pytree(xs): + flat_xs, inspec = pytree.tree_flatten(xs) + if not all(isinstance(xs, torch.Tensor) for xs in flat_xs): + raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}") + + if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs): + raise RuntimeError( + f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}" + ) + + a = zip(*flat_xs) + + pytrees = [] + for tuple in a: + pytrees.append(pytree.tree_unflatten(tuple, inspec)) + return pytrees + + +def _stack_pytree(pytrees): + flat_out = [] + out_spec = None + for pt in pytrees: + flat_pt, out_spec = pytree.tree_flatten(pt) + flat_out.append(flat_pt) + assert out_spec is not None + b = zip(*flat_out) + stacked_out = [] + for leaves in b: + if all(isinstance(leaf, torch.Tensor) for leaf in leaves): + stacked_out.append(torch.stack(leaves)) + elif all(leaf is None for leaf in leaves): + # Backward graph can return None output when forward inputs doesn't require grad. + # When we eagerly execute backward graph, we need to call _stack_pytree on its output, + # therefore we need to deal with None output. 
+ stacked_out.append(None) # type: ignore[arg-type] + else: + raise RuntimeError(f"Cannot stack {leaves}.") + return pytree.tree_unflatten(stacked_out, out_spec) diff --git a/lib/python3.10/site-packages/torch/_higher_order_ops/while_loop.py b/lib/python3.10/site-packages/torch/_higher_order_ops/while_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..d64f512399d9d5101227876477a9f338309204a1 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_higher_order_ops/while_loop.py @@ -0,0 +1,268 @@ +# mypy: allow-untyped-defs +from typing import Callable, Tuple, Union + +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._higher_order_ops.utils import ( + _has_potential_branch_input_alias, + _has_potential_branch_input_mutation, + _maybe_run_with_interpreter, + _set_compilation_env, + autograd_not_implemented, + reenter_make_fx, + UnsupportedAliasMutationException, +) +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree + + +class WhileLoopOp(HigherOrderOperator): + def __init__(self) -> None: + super().__init__("while_loop") + + def __call__( + self, + cond_fn: Callable, + body_fn: Callable, + carried_inputs: Tuple[Union[torch.Tensor, int, float, bool]], + additional_inputs: Tuple[Union[torch.Tensor, int, float, bool]], + /, + ): + if not isinstance(carried_inputs, tuple): + raise RuntimeError( + f"carried_inputs must be a tuple, got {type(carried_inputs)}" + ) + if not isinstance(additional_inputs, tuple): + raise RuntimeError( + f"additional_inputs must be a tuple, got {type(additional_inputs)}" + ) + if not all( + isinstance(t, (torch.Tensor, int, float, bool)) for t in carried_inputs + ): + raise RuntimeError( + "carried_inputs must be a tuple of tensors, ints, floats, or bools, got " + f"{carried_inputs}" + ) + + if not all( + isinstance(t, (torch.Tensor, int, float, bool)) for t in additional_inputs + ): + raise RuntimeError( + "additional_inputs must be a tuple of tensors, ints, floats, or bools, got " + f"{additional_inputs}" + ) + return super().__call__(cond_fn, body_fn, carried_inputs, additional_inputs) + + +while_loop_op = WhileLoopOp() + + +def while_loop(cond_fn, body_fn, carried_inputs): + r""" + Run body_fn(*carried_inputs) while cond_fn(*carried_inputs) returns a True scalar tensor. Returns the output of body_fn or + initial carried_inputs. + + .. warning:: + `torch.while_loop` is a prototype feature in PyTorch. It has limited support for input and output types and + doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch. + Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype + + `while_loop` is a structured control flow operator. It preserves the loop semantic across the torch.compile and torch.export. + + `while_loop` is equivalent to the following: + + def while_loop(cond_fn, body_fn, carried_inputs): + val = carried_inputs + while cond_fn(*val): + val = body_fn(*val) + return val + + Args: + cond_fn (Callable): A callable function that returns a boolean Scalar tensor. + + body_fn (Callable): A callable function that takes the same inputs as `cond_fn` and returns a tuple of tensors + + carried_inputs (Tuple of possibly nested dict/list/tuple of tensors): A tuple of inputs to cond_fn and body_fn. 
It's also + the initial value of states that are carried across iterations. + + Example: + + def cond_fn(iter, x): + return iter.sum() < 10 + + def body_fn(iter, x): + return iter + 1, x.sin() + + while_loop(cond_fn, body_fn, (torch.zeros(1), torch.randn(3, 4))) + + Restrictions: + + - body_fn must return tensors with the same metadata (e.g.shape, dtype) as inputs. + + - body_fn and cond_fn must not in-place mutate the carried_inputs. A clone before the mutation is required. + + - body_fn and cond_fn must not mutate python varialbles (e.g. list/dict) created outside of the body_fn. + + - body_fn and cond_fn's output cannot aliase any of the inputs. A clone is required. + + .. warning:: + Temporal Limitations: + + - 'while_loop' only supports **inference** right now. Autograd will be supported in the future. + + """ + + # Currently, additional_inputs is not a user-facing input. It will be automatically set in dynamo. + # parameters and buffers accessed in cond_fn or body_fn or tensor closures will become additional_inputs. + additional_inputs: Tuple = () + if torch.compiler.is_dynamo_compiling(): + return while_loop_op(cond_fn, body_fn, carried_inputs, additional_inputs) + + def _validate_input(cond_fn, body_fn, carried_inputs): + if not callable(cond_fn) or not callable(body_fn): + raise RuntimeError("Expect cond_fn and body_fn to be callbale.") + + if not isinstance(carried_inputs, (tuple, list)) or pytree.tree_any( + lambda t: not isinstance(t, torch.Tensor), carried_inputs + ): + raise RuntimeError( + "Expect carried_inputs to be a tuple of possibly nested dict/list/tuple that only" + f"consists of tensor leaves, but got {carried_inputs}." + ) + + _validate_input(cond_fn, body_fn, carried_inputs) + + # Dynamo is expecting a callable with "__code__" attribute. + # We cannot directly pass cond_op to it. So we wrap it in a dummy function. 
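A runnable variant (illustrative only, subject to the prototype restrictions listed above) of the docstring example, showing the boolean-scalar-tensor contract for cond_fn and the metadata-preserving contract for body_fn:

import torch

def cond_fn(it, x):
    return it < 5  # 0-dim bool tensor

def body_fn(it, x):
    return it + 1, x + 1.0  # same shapes and dtypes as the inputs

it, x = while_loop(cond_fn, body_fn, (torch.zeros((), dtype=torch.int64), torch.zeros(3)))
# it == tensor(5), x == tensor([5., 5., 5.])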
+ def _while_loop_op_wrapper(*args, **kwargs): + return while_loop_op(*args, **kwargs) + + with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): + return torch.compile(_while_loop_op_wrapper, backend="eager", fullgraph=True)( + cond_fn, body_fn, carried_inputs, additional_inputs + ) + + +@while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd) +def while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs): + carried_vals = carried_inputs + + def _is_boolean_scalar_tensor(pred): + return ( + isinstance(pred, torch.Tensor) + and pred.size() == torch.Size([]) + and pred.dtype == torch.bool + ) + + if not isinstance(carried_inputs, tuple): + raise RuntimeError( + f"carried_inputs must be a tuple but got {type(carried_inputs)}" + ) + + while pred := cond_fn(*carried_vals, *additional_inputs): + if not _is_boolean_scalar_tensor(pred): + raise RuntimeError( + f"cond_fn must return a boolean scalar tensor but got {pred}" + ) + out = body_fn(*carried_vals, *additional_inputs) + assert isinstance( + out, tuple + ), f"body_fn should return a tuple but got {type(out)}" + assert len(out) == len( + carried_inputs + ), "body_fn should return the same number of elements as carried_inputs" + carried_vals = out + return carried_vals + + +while_loop_op.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(while_loop_op, deferred_error=True) +) + + +@while_loop_op.py_impl(ProxyTorchDispatchMode) +def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs): + def _trace_while_loop( + proxy_mode, while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs + ): + cond_graph = reenter_make_fx(cond_fn)(*carried_inputs, *additional_inputs) + body_graph = reenter_make_fx(body_fn)(*carried_inputs, *additional_inputs) + + next_name = None + i = 0 + while not next_name: + candidate = f"while_loop_cond_graph_{i}" + if hasattr(proxy_mode.tracer.root, candidate): + i += 1 + else: + next_name = candidate + cond_graph_name = next_name + body_graph_name = f"while_loop_body_graph_{i}" + assert not hasattr(proxy_mode.tracer.root, body_graph_name) + + proxy_mode.tracer.root.register_module(cond_graph_name, cond_graph) + proxy_mode.tracer.root.register_module(body_graph_name, body_graph) + + args = (cond_graph, body_graph, carried_inputs, additional_inputs) + + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) + + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", while_loop_op, proxy_args, {}, name="while_loop" + ) + + # body_fn return output with the same pytree and tensor meta data as carried_inputs + # so we could just return the output after one iteration. 
+ out = body_fn(*carried_inputs, *additional_inputs) + return track_tensor_tree( + out, out_proxy, constant=None, tracer=proxy_mode.tracer + ) + + return _trace_while_loop( + mode, while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs + ) + + +@while_loop_op.py_impl(FakeTensorMode) +def while_loop_fake_tensor_mode( + mode, cond_fn, body_fn, carried_inputs, additional_inputs +): + with mode: + return body_fn(*carried_inputs, *additional_inputs) + + +@while_loop_op.py_functionalize_impl +def while_loop_func(ctx, cond_fn, body_fn, carried_inputs, additional_inputs): + unwrapped_carried_inputs = ctx.unwrap_tensors(carried_inputs) + unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs) + unwrapped_inputs = unwrapped_carried_inputs + unwrapped_additional_inputs + with ctx.redispatch_to_next() as m: + functional_cond_fn = ctx.functionalize(_maybe_run_with_interpreter(cond_fn)) + functional_body_fn = ctx.functionalize(_maybe_run_with_interpreter(body_fn)) + pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch + for fn, fn_name in [ + (functional_cond_fn, "cond_fn"), + (functional_body_fn, "body_fn"), + ]: + if _has_potential_branch_input_mutation( + fn, unwrapped_inputs, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException( + f"torch.while_loop's {fn_name} might be modifying the input!" + ) + + if _has_potential_branch_input_alias( + fn, unwrapped_inputs, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException( + f"torch.while_loop's {fn_name} might be aliasing the input!" + ) + ret = while_loop_op( + functional_cond_fn, + functional_body_fn, + unwrapped_carried_inputs, + unwrapped_additional_inputs, + ) + return ctx.wrap_tensors(ret) diff --git a/lib/python3.10/site-packages/torch/_higher_order_ops/wrap.py b/lib/python3.10/site-packages/torch/_higher_order_ops/wrap.py new file mode 100644 index 0000000000000000000000000000000000000000..7327b3114a1d158418f8f6539bf1a52055ecbd0e --- /dev/null +++ b/lib/python3.10/site-packages/torch/_higher_order_ops/wrap.py @@ -0,0 +1,240 @@ +# mypy: allow-untyped-defs +import inspect +import itertools +import logging +from typing import Optional + +from torch._logging import warning_once +from torch._ops import HigherOrderOperator +from torch.types import _dtype +from torch.utils.checkpoint import checkpoint, CheckpointPolicy + + +log = logging.getLogger(__name__) + +uid = itertools.count(1) + + +# Used for testing the HigherOrderOperator mechanism +class Wrap(HigherOrderOperator): + def __init__(self) -> None: + super().__init__("wrap") + + def __call__(self, func, *args, **kwargs): + # Dynamo already traces the body of HigherOrderOp beforehand when it + # so no need to trace into it. + import torch._dynamo # noqa: F401 + from torch._dynamo import disable + + @disable + def wrapper(): + result = func(*args, **kwargs) + return result + + return wrapper() + + +wrap = Wrap() + + +class WrapWithSetGradEnabled(HigherOrderOperator): + def __init__(self) -> None: + super().__init__("wrap_with_set_grad_enabled") + + def __call__(self, enable_grad, wrapped_func, *args, **kwargs): + # Dynamo already traces the body of HigherOrderOp beforehand when it + # so no need to trace into it. 
+ import torch._dynamo # noqa: F401 + from torch._dynamo import disable + + @disable + def wrapper(): + with torch.set_grad_enabled(enable_grad): + return wrapped_func(*args, **kwargs) + + return wrapper() + + +wrap_with_set_grad_enabled = WrapWithSetGradEnabled() + + +class WrapWithAutocast(HigherOrderOperator): + def __init__(self): + super().__init__("wrap_with_autocast") + + def __call__( + self, + device_type: str, + dtype: Optional[_dtype], + enabled: bool, + cache_enabled: Optional[bool], + wrapped_func, + *args, + **kwargs, + ): + # Dynamo already traces the body of HigherOrderOp beforehand when it + # so no need to trace into it. + import torch._dynamo # noqa: F401 + from torch._dynamo import disable + + @disable + def wrapper(): + with torch.autocast(device_type, dtype, enabled, cache_enabled): + return wrapped_func(*args, **kwargs) + + return wrapper() + + +wrap_with_autocast = WrapWithAutocast() + + +class WrapActivationCheckpoint(HigherOrderOperator): + """ + This operator is used to wrap torch.utils.checkpoint. This avoids + TorchDynamo to look into saved tensor hooks and directly passes the control + to AOT Autograd, which is ok with tracing saved tensor hooks. As a result of + AOT tracing torch.utils.checkpoint code, we have a backward graph with + recomputed forward nodes. + + However, we might deprecate this operator soon. The difficulty arises in the + functionalization of rng ops. Today, there are two different + functionalization of rng ops - one at AOT autograd and other at Inductor. + And they are difficult to map to each other. The rng states also complicate + pattern matching in Inductor. Due to the ease of implementation, we are + currently inclined towards functionalization at Inductor level, which means + that duplication/recomputation is done as a compiler pass in the + partitioners. See TagActivationCheckpoint for more information. + """ + + def __init__(self) -> None: + super().__init__("wrap_activation_checkpoint") + + def __call__(self, function, *args, **kwargs): + # use_reentrant is set to False because this op is going to be traced. + # And we ensure that AOT Autograd traces through the non reentrant + # version of checkpointing. + import torch.fx.traceback as fx_traceback + from torch.fx import Interpreter + + kwargs["use_reentrant"] = False + kwargs["preserve_rng_state"] = False + # Using interpreter allows preservation of metadata through torch.compile stack. + with fx_traceback.preserve_node_meta(): + return checkpoint(Interpreter(function).run, *args, **kwargs) + + +wrap_activation_checkpoint = WrapActivationCheckpoint() + + +class TagActivationCheckpoint(HigherOrderOperator): + """ + This operator is supposed to be used only with torch.compile stack. This + accepts a Fx graph module which needs to be checkpointed. This operator adds + "recomputable" tag to the nodes of the Fx graph that should be recomputed. + + The goal is to: + 1. Avoid using Dynamo to trace through saved tensor hooks. + 2. For selective checkpointing case, let AOTAutograd trace through + saved tensor hooks but has special logic with TorchDispatchMode to override + the usual saved_tensor_hooks fn logic in order to tag the nodes. + 3. Rely on the partitioners to actually duplicate the nodes. + This sits well in the torch.compile stack, because by the time graph + reaches partitioner, inductor has already run its functionalization of rng + ops (by setting fixed seed for each random op, see `replace_random_passes`). 
+ Therefore, the duplication of nodes, by design, respects the rng states in + the forward and recomputed forward in backward. + """ + + def __init__(self) -> None: + super().__init__("tag_activation_checkpoint") + + @staticmethod + def divide_kwargs(kwargs): + """ + checkpoint fn can have mixed kwargs between checkpointed fn and + checkpoint fn itself. For example + >> def gn(x, y, z=None): + >> a = torch.matmul(x, y) + >> if z is not None: + >> return torch.matmul(a, z) + >> return a + >> def fn(x, y, z): + >> return torch.cos(checkpoint(gn, x, y, use_reentrant=False, z=z)) + In the above case, z belongs to checkpointed function gn, but + use_reentrant belongs to the checkpoint function. This function splits + the kwargs into checkpoint_kwargs and gmod_kwargs (or + checkpointed_fn_kwargs). + We do sorting to ensure same graph from run to run for better + debuggability. It is not required for correctness. + """ + ckpt_signature = inspect.signature(checkpoint) + checkpoint_keys = set() + for name in ckpt_signature.parameters: + if name in ("function", "args", "kwargs"): + continue + checkpoint_keys.add(name) + + # `preserve_rng_state` is not a regular kwarg + checkpoint_keys.add("preserve_rng_state") + + checkpoint_kwargs = { + name: kwargs[name] for name in kwargs.keys() if name in checkpoint_keys + } + gmod_kwargs = { + name: kwargs[name] for name in kwargs.keys() if name not in checkpoint_keys + } + return checkpoint_kwargs, gmod_kwargs + + def tag_nodes(self, gmod, is_sac): + unique_graph_id = next(uid) + for node in gmod.graph.nodes: + if node.op in ("call_function", "call_method", "call_module"): + node.meta["ac_graph_id"] = unique_graph_id + if is_sac: + # For selective checkpointing, we will populate this tag later in _CachingTorchDispatchMode. + node.meta["recompute"] = None + else: + # Under vanilla activation checkpointing, all nodes should be recomputed. + node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE + return gmod + + def __call__(self, gmod, *args, **kwargs): + import torch.fx.traceback as fx_traceback + from torch.fx import Interpreter + + if "_checkpoint_context_fn" in gmod.meta: + warning_once( + log, + """ +Detected that context_fn is passed to torch.utils.checkpoint under torch.compile. +Please make sure the checkpointed region does not contain in-place ops (e.g. torch.relu_). +""", + ) + # use_reentrant is set to False because this op is going to be traced. + # And we ensure that AOT Autograd traces through the non reentrant + # version of checkpointing. + kwargs["use_reentrant"] = False + # preserve_rng_state is set to False because we want to prevent AOTAutograd from tracing through + # `torch.random.fork_rng` op (which is not supported yet under CUDA). + # This doesn't mean that we don't preserve RNG state. Instead, we will always preserve RNG state + # regardless of this flag (by doing RNG functionalization via `replace_random_passes` in Inductor + # instead of in AOTAutograd). + kwargs["preserve_rng_state"] = False + kwargs["context_fn"] = gmod.meta["_checkpoint_context_fn"] + # We first tag all nodes as "recompute" in this graph, and then we undo the "recompute" tag + # for specific nodes in _CachingTorchDispatchMode in torch/utils/checkpoint.py. + gmod = self.tag_nodes(gmod, is_sac=True) + # Using interpreter allows preservation of metadata through torch.compile stack. 
+ with fx_traceback.preserve_node_meta(): + return checkpoint(Interpreter(gmod).run, *args, **kwargs) + else: + gmod = self.tag_nodes(gmod, is_sac=False) + # Using interpreter allows preservation of metadata through torch.compile stack. + # TODO: We want to use the same `checkpoint(Interpreter(gmod).run, *args, **kwargs)` here + # as the `context_fn != None` case, but that depends on in-place op support in TorchDispatchMode + torch.compile. + # (for details on in-place op issue, run `test_compile_selective_checkpoint_inplace_op` unit test) + with fx_traceback.preserve_node_meta(): + return Interpreter(gmod).run(*args) + + +tag_activation_checkpoint = TagActivationCheckpoint() diff --git a/lib/python3.10/site-packages/torch/_inductor/__init__.py b/lib/python3.10/site-packages/torch/_inductor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f95e7caaf71e95c675d4ea9e467c99d76ebdb842 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/__init__.py @@ -0,0 +1,179 @@ +# mypy: allow-untyped-defs +from typing import Any, Dict, List, Optional, Tuple + +import torch.fx +import torch.utils._pytree as pytree + + +__all__ = ["compile", "list_mode_options", "list_options", "cudagraph_mark_step_begin"] + + +def compile( + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + options: Optional[Dict[str, Any]] = None, +): + """ + Compile a given FX graph with TorchInductor. This allows compiling + FX graphs captured without using TorchDynamo. + + Args: + gm: The FX graph to compile. + example_inputs: List of tensor inputs. + options: Optional dict of config options. See `torch._inductor.config`. + + Returns: + Callable with same behavior as gm but faster. + """ + from .compile_fx import compile_fx + + return compile_fx(gm, example_inputs, config_patches=options) + + +def aot_compile( + gm: torch.fx.GraphModule, + args: Tuple[Any], + kwargs: Optional[Dict[str, Any]] = None, + *, + options: Optional[Dict[str, Any]] = None, +) -> str: + """ + Ahead-of-time compile a given FX graph with TorchInductor into a shared library. + + Args: + gm: The FX graph to compile. + args: Example arguments + kwargs: Example keyword arguments + options: Optional dict of config options. See `torch._inductor.config`. + + Returns: + Path to the generated shared library + """ + from .compile_fx import compile_fx_aot, graph_returns_tuple + + assert graph_returns_tuple(gm), ( + "Graph output must be a tuple(). This is so that we can avoid " + "pytree processing of the outputs. Please change the module to " + "have tuple outputs." 
+ ) + + # We will serialize the pytree info into the .so as constant strings + in_spec = None + out_spec = None + if isinstance(gm.graph._codegen, torch.fx.graph._PyTreeCodeGen): + codegen = gm.graph._codegen + gm.graph._codegen = torch.fx.graph.CodeGen() + gm.recompile() + + if codegen.pytree_info.in_spec is not None: + in_spec = codegen.pytree_info.in_spec + if codegen.pytree_info.out_spec is not None: + out_spec = codegen.pytree_info.out_spec + + else: + if hasattr(gm, "_in_spec"): + in_spec = gm._in_spec + if hasattr(gm, "_out_spec"): + out_spec = gm._out_spec + + serialized_in_spec = pytree.treespec_dumps(in_spec) if in_spec is not None else "" + serialized_out_spec = ( + pytree.treespec_dumps(out_spec) if out_spec is not None else "" + ) + + flat_args_with_path, received_spec = pytree.tree_flatten_with_path( + (args, kwargs or {}) + ) + + # Replace non-tensor (constant) inputs with Nones, since these are not being + # used anyways by the graph + flat_example_inputs = [ + x[1] if isinstance(x[1], torch.Tensor) else None for x in flat_args_with_path + ] + + if in_spec is not None and received_spec != in_spec: + raise ValueError( # noqa: B904 + "Trying to flatten user inputs with exported input tree spec: \n" + f"{in_spec}\n" + "but actually got inputs with tree spec of: \n" + f"{received_spec}" + ) + + options = ( + { + "aot_inductor.serialized_in_spec": serialized_in_spec, + "aot_inductor.serialized_out_spec": serialized_out_spec, + } + if options is None + else { + **options, + "aot_inductor.serialized_in_spec": serialized_in_spec, + "aot_inductor.serialized_out_spec": serialized_out_spec, + } + ) + + return compile_fx_aot( + gm, + flat_example_inputs, # type: ignore[arg-type] + config_patches=options, + ) + + +def list_mode_options( + mode: Optional[str] = None, dynamic: Optional[bool] = None +) -> Dict[str, Any]: + r"""Returns a dictionary describing the optimizations that each of the available + modes passed to `torch.compile()` performs. + + Args: + mode (str, optional): The mode to return the optimizations for. + If None, returns optimizations for all modes + dynamic (bool, optional): Whether dynamic shape is enabled. + + Example:: + >>> torch._inductor.list_mode_options() + """ + + mode_options: Dict[str, Dict[str, bool]] = { + "default": {}, + # enable cudagraphs + "reduce-overhead": { + "triton.cudagraphs": True, + }, + # enable max-autotune + "max-autotune-no-cudagraphs": { + "max_autotune": True, + }, + # enable max-autotune + # enable cudagraphs + "max-autotune": { + "max_autotune": True, + "triton.cudagraphs": True, + }, + } + return mode_options[mode] if mode else mode_options # type: ignore[return-value] + + +def list_options() -> List[str]: + r"""Returns a dictionary describing the optimizations and debug configurations + that are available to `torch.compile()`. + + The options are documented in `torch._inductor.config`. + + Example:: + + >>> torch._inductor.list_options() + """ + + from torch._inductor import config + + current_config: Dict[str, Any] = config.shallow_copy_dict() + + return list(current_config.keys()) + + +def cudagraph_mark_step_begin(): + "Indicates that a new iteration of inference or training is about to begin." 
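For example (illustrative; needs a CUDA device), marking iteration boundaries when running a model compiled with cudagraphs enabled via the "reduce-overhead" mode listed above:

import torch

@torch.compile(mode="reduce-overhead")  # enables triton.cudagraphs
def step(x):
    return x * 2

for _ in range(3):
    torch._inductor.cudagraph_mark_step_begin()
    out = step(torch.randn(8, device="cuda"))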
+ from .cudagraph_trees import mark_step_begin + + mark_step_begin() diff --git a/lib/python3.10/site-packages/torch/_inductor/aoti_eager.py b/lib/python3.10/site-packages/torch/_inductor/aoti_eager.py new file mode 100644 index 0000000000000000000000000000000000000000..f733ce4fbd5a1748c663892c10cf349166bc9461 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/aoti_eager.py @@ -0,0 +1,298 @@ +import json +import logging +import os +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple +from unittest import mock + +import torch +import torch._export +from torch._inductor.utils import is_cpu_device + +from .runtime.runtime_utils import cache_dir + + +log = logging.getLogger(__name__) + + +def aoti_eager_cache_dir(namespace: str, device: str) -> Path: + return Path(cache_dir()) / "aoti_eager" / namespace / device + + +def aoti_eager_op_conf_lock(op_func_name_with_overload: str) -> Any: + from filelock import FileLock + + # Avoid circular import + from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT + + op_conf_lock_file = f"{op_func_name_with_overload}.lock" + lock_dir = get_lock_dir() + return FileLock(os.path.join(lock_dir, op_conf_lock_file), timeout=LOCK_TIMEOUT) + + +def load_aoti_eager_cache( + ns: str, op_func_name_with_overload: str, device_type: str +) -> List[Optional[Dict[str, Any]]]: + device_kernel_cache = aoti_eager_cache_dir(ns, device_type) + op_conf = device_kernel_cache / f"{op_func_name_with_overload}.json" + if not op_conf.exists(): + return [] + + try: + with aoti_eager_op_conf_lock(op_func_name_with_overload): + with open(op_conf) as f: + json_data = json.load(f) + for item in json_data: + # Get absolution path for kernel library + kernel_lib_abs_path = device_kernel_cache / item["kernel_path"] + item["kernel_path"] = kernel_lib_abs_path.as_posix() + + # Check if the kernel library exists + if not kernel_lib_abs_path.exists(): + return [] + + for metadata in item["meta_info"]: + if metadata.get("is_dynamic"): + raise NotImplementedError( + "Only support static shape for now" + ) + if ( + "device_type" in metadata + and metadata["device_type"] == "cpu" + ): + metadata["device_index"] = -1 + for dtype_key in ["dtype", "dtype_value"]: + if dtype_key in metadata: + metadata[dtype_key] = getattr( + torch, metadata[dtype_key].split(".")[-1] + ) + if "layout_value" in metadata: + metadata["layout_value"] = getattr( + torch, metadata["layout_value"].split(".")[-1] + ) + if "memory_format_value" in metadata: + metadata["memory_format_value"] = getattr( + torch, metadata["memory_format_value"].split(".")[-1] + ) + + return json_data + except Exception as e: + err_msg = f"Failed to load aoti eager cache: {e}" + log.exception(err_msg) + return [] + + +def supported_builtin_dtype_torch_dtype() -> Dict[type, torch.dtype]: + return {int: torch.int32, float: torch.float, bool: torch.bool} + + +def supported_scalar_types() -> Tuple[type, ...]: + type_to_torch_dtype = supported_builtin_dtype_torch_dtype() + return tuple(type_to_torch_dtype.keys()) + + +def extract_tensor_metadata(dynamic: bool, input: torch.Tensor) -> Dict[str, Any]: + metadata: Dict[str, Any] = {} + metadata["is_dynamic"] = dynamic + + assert isinstance(input, torch.Tensor) + metadata["device_type"] = f"{input.device.type}" + if is_cpu_device([input]): + metadata["device_index"] = -1 + else: + metadata["device_index"] = input.device.index + metadata["dtype"] = f"{input.dtype}" + metadata["sizes"] = list(input.size()) + metadata["strides"] = list(input.stride()) 
+ metadata["requires_grad"] = input.requires_grad + metadata["dispatch_key_set"] = torch._C._dispatch_keys(input).raw_repr() + return metadata + + +def extract_tensor_list_metadata( + dynamic: bool, + input: List[torch.Tensor], +) -> Dict[str, Any]: + metadata_list = [] + for item in input: + assert isinstance(item, torch.Tensor) + metadata_list.append(extract_tensor_metadata(dynamic, item)) + + metadata: Dict[str, Any] = {} + metadata["tensor_list"] = metadata_list + return metadata + + +def extract_scalar_metadata(device_type: str, input: Any) -> Dict[str, Any]: + assert isinstance(input, supported_scalar_types()) + metadata: Dict[str, Any] = {} + metadata["is_dynamic"] = False + # Scalar tensor + metadata["device_type"] = device_type + metadata["device_index"] = -1 if device_type == "cpu" else 0 + type_to_torch_dtype = supported_builtin_dtype_torch_dtype() + metadata["dtype"] = f"{type_to_torch_dtype[type(input)]}" + metadata["scalar_value"] = input + return metadata + + +def extract_string_metadata(input: str) -> Dict[str, Any]: + assert isinstance(input, str) + metadata: Dict[str, Any] = {} + metadata["string_value"] = input + return metadata + + +def extract_dtype_metadata(input: torch.dtype) -> Dict[str, Any]: + assert isinstance(input, torch.dtype) + metadata: Dict[str, Any] = {} + metadata["dtype_value"] = f"{input}" + return metadata + + +def extract_device_metadata(input: torch.device) -> Dict[str, Any]: + assert isinstance(input, torch.device) + metadata: Dict[str, Any] = {} + metadata["device_type_value"] = f"{input.type}" + metadata["device_index_value"] = input.index + return metadata + + +def extract_layout_metadata(input: torch.layout) -> Dict[str, Any]: + assert isinstance(input, torch.layout) + metadata: Dict[str, Any] = {} + metadata["layout_value"] = f"{input}" + return metadata + + +def aoti_compile_with_persistent_cache( + ns: str, + op_func_name_with_overload: str, + device_type: str, + dynamic: bool, + f: Callable[..., Any], + args: Tuple[Any], + kwargs: Dict[str, Any], + *, + dynamic_shapes: Optional[Dict[str, Any]] = None, + options: Optional[Dict[str, Any]] = None, + remove_runtime_assertions: bool = False, + disable_constraint_solver: bool = False, +) -> str: + """ + Compile the given function with persistent cache for AOTI eager mode. 
+ """ + assert not dynamic, "Only support static shape for now" + flattened_inputs = list(args) + list(kwargs.values()) + if not all( + isinstance( + input, + ( + supported_scalar_types(), + torch.Tensor, + list, + str, + torch.dtype, + torch.device, + torch.layout, + ), + ) + for input in flattened_inputs + ): + err_msg = f"Unsupported input types: {flattened_inputs}" + log.exception(err_msg) + raise NotImplementedError(err_msg) + + for input in flattened_inputs: + if isinstance(input, list) and not all( + isinstance(item, torch.Tensor) for item in input + ): + err_msg = f"_impl_with_aoti_compile encounters unsupported input types: {flattened_inputs}" + log.exception(err_msg) + raise NotImplementedError(err_msg) + + persistent_cache = aoti_eager_cache_dir(ns, device_type) + if not persistent_cache.exists(): + persistent_cache.mkdir(parents=True) + + persistent_cache_lib = persistent_cache / "lib" + if not persistent_cache_lib.exists(): + persistent_cache_lib.mkdir() + + with mock.patch.dict( + os.environ, + {"TORCHINDUCTOR_CACHE_DIR": persistent_cache_lib.absolute().as_posix()}, + ): + try: + kernel_lib_path = torch._export.aot_compile( + f, + args, + kwargs, + dynamic_shapes=dynamic_shapes, + remove_runtime_assertions=remove_runtime_assertions, + disable_constraint_solver=disable_constraint_solver, + # Some operations may have non-Tensor parameters like int, float, bool. These + # non-Tensor parameters will not be the input of the graph. Therefore, we do + # need to keep the same signature. + same_signature=False, + ) + + kernel_metadata_items = [] + + for idx, input in enumerate(flattened_inputs): + if isinstance(input, torch.Tensor): + metadata = extract_tensor_metadata(dynamic, input) + elif isinstance(input, list): + assert all(isinstance(item, torch.Tensor) for item in input) + metadata = extract_tensor_list_metadata(dynamic, input) + elif isinstance(input, supported_scalar_types()): + metadata = extract_scalar_metadata(device_type, input) + elif isinstance(input, str): + metadata = extract_string_metadata(input) + elif isinstance(input, torch.dtype): + metadata = extract_dtype_metadata(input) + elif isinstance(input, torch.device): + metadata = extract_device_metadata(input) + elif isinstance(input, torch.layout): + metadata = extract_layout_metadata(input) + else: + raise NotImplementedError(f"Unsupported input type: {type(input)}") + + metadata["arg_order"] = idx + kernel_metadata_items.append(metadata) + + kernel_meta_info: Dict[str, Any] = {} + kernel_meta_info["meta_info"] = kernel_metadata_items + kernel_meta_info["kernel_path"] = ( + Path(kernel_lib_path).relative_to(persistent_cache).as_posix() + ) + + json_data = [] + update_json = True + op_conf = persistent_cache / f"{op_func_name_with_overload}.json" + mode = "r" if op_conf.exists() else "w" + with aoti_eager_op_conf_lock(op_func_name_with_overload): + with open(op_conf, mode) as op_conf_file: + try: + json_data = json.load(op_conf_file) + except Exception as e: + json_data = [] + + assert isinstance(json_data, list) + for item in json_data: + assert isinstance(item, dict) + # Same kernel meta info already exists in the json file + if item["meta_info"] == kernel_metadata_items: + update_json = False + break + + if update_json: + json_data.append(kernel_meta_info) + with open(op_conf, "w") as op_conf_file: + json.dump(json_data, op_conf_file, indent=4) + + return kernel_lib_path + except Exception as e: + err_msg = f"Failed to compile {op_func_name_with_overload}: {e}" + log.exception(err_msg) + return "" diff --git 
a/lib/python3.10/site-packages/torch/_inductor/async_compile.py b/lib/python3.10/site-packages/torch/_inductor/async_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..0794ceb3eed5719b5037de6ceb1c89c30a52e1cd --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/async_compile.py @@ -0,0 +1,297 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import functools +import logging +import multiprocessing +import os +import sys +from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor +from concurrent.futures.process import BrokenProcessPool +from functools import partial +from time import time +from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING + +import torch +from torch._dynamo.device_interface import get_registered_device_interfaces +from torch._inductor import config +from torch._inductor.codecache import ( + CodeCacheFuture, + CppCodeCache, + CppPythonBindingsCodeCache, + CUDACodeCache, + HalideCodeCache, + LambdaFuture, + ROCmCodeCache, + TritonCodeCache, + TritonFuture, +) +from torch._inductor.compile_worker.subproc_pool import ( + _warm_process_pool, + AnyPool, + SubprocPool, +) +from torch._inductor.compile_worker.watchdog import _async_compile_initializer +from torch._inductor.runtime.compile_tasks import ( + _set_triton_ptxas_path, + _worker_compile_triton, +) +from torch.hub import _Faketqdm, tqdm +from torch.utils._triton import has_triton_package + + +if TYPE_CHECKING: + from torch._inductor.runtime.hints import HalideMeta + +# timing metrics for time spent in the compilation +_cumulative_compile_time = 0.0 +_t0: Optional[float] = None + +kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") + + +def pre_fork_setup(): + """ + Setup that must be done prior to forking with a process pool. + """ + # ensure properties have been calculated before processes + # are forked + caching_device_properties() + + # Computing the triton key can be slow. If we call it before fork, + # it will be cached for the forked subprocesses. + try: + from triton.compiler.compiler import triton_key + + triton_key() + except ImportError: + # Triton might not be installed or might be an old version. + pass + + +def caching_device_properties(): + for _, device_interface in get_registered_device_interfaces(): + if device_interface.is_available(): + device_interface.Worker.get_device_properties() + + +def _compile_start() -> None: + global _t0 + if _t0 is None: + _t0 = time() + + +def _compile_end() -> None: + global _cumulative_compile_time, _t0 + if _t0 is not None: + t1 = time() + _cumulative_compile_time += t1 - _t0 + _t0 = None + # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time) + + +_IS_WINDOWS = sys.platform == "win32" + +log = logging.getLogger(__name__) + + +# Used to keep track of all process pools invoked so far. 
+_pool_set: Set[AnyPool] = set() + + +def shutdown_compile_workers() -> None: + """Shut down all outstanding compile-worker pools.""" + for pool in _pool_set: + pool.shutdown() + after_fork() + + +def after_fork(): + """Reset pools to initial state without shutting them down""" + _pool_set.clear() + AsyncCompile.process_pool.cache_clear() + + +try: + os.register_at_fork(after_in_child=after_fork) +except AttributeError: + pass # register_at_fork does not exists on windows + + +class AsyncCompile: + def __init__(self) -> None: + pass + + @staticmethod + @functools.lru_cache(1) + def pool() -> ThreadPoolExecutor: + assert config.compile_threads > 1 + return ThreadPoolExecutor(config.compile_threads) + + @staticmethod + def _get_ready(): + """No-op function to help mark when the subprocess pool is ready.""" + return "ready" + + @staticmethod + @functools.lru_cache(1) + def process_pool() -> AnyPool: + assert config.compile_threads > 1 + pool: AnyPool + if config.worker_start_method == "subprocess": + # Wrapper around ProcessPoolExecutor forks in a new process we control + pool = SubprocPool(config.compile_threads) + else: + pre_fork_setup() + ctx = multiprocessing.get_context(config.worker_start_method) + pool = ProcessPoolExecutor( + config.compile_threads, + mp_context=ctx, + initializer=partial(_async_compile_initializer, os.getpid()), + ) + # when this pool is created in a subprocess object, the normal exit handler + # doesn't run, and we need to register our own handler. + # exitpriority has to be high, because another one of the finalizers will + # kill the worker thread that sends the shutdown message to the workers... + multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize) + + # Set an attribute we can check to see if the pool is ready. + pool.ready_future = pool.submit(AsyncCompile._get_ready) # type: ignore[union-attr] + _pool_set.add(pool) + return pool + + @classmethod + def warm_pool(cls) -> None: + if config.compile_threads <= 1: + return + _compile_start() + _warm_process_pool(cls.process_pool(), config.compile_threads) + _compile_end() + + @classmethod + def submit(cls, task: Callable[..., Any]) -> Any: + if config.compile_threads <= 1: + return task() + return cls.pool().submit(task) + + def _use_process_pool(self): + return ( + config.compile_threads > 1 + and self.process_pool().ready_future.done() # type: ignore[union-attr] + ) + + def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"): + kernel_code_log.info("Triton Kernel:\n%s", source_code) + _compile_start() + _set_triton_ptxas_path() + + kernel = TritonCodeCache.load(kernel_name, source_code) + if self._use_process_pool(): + # We want to support changing these env vars after (and while) the + # process pool is running, so pass them to the subprocess to reset. 
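
# A minimal illustrative sketch, separate from the vendored file above: the dispatch
# pattern behind `AsyncCompile.submit`. With threading disabled the task runs inline and
# the result comes back directly; otherwise a Future is returned and resolved later (as
# `wait` does over the generated module's globals). `compile_threads` and `fake_compile`
# are local stand-ins, not the real config or compiler.
from concurrent.futures import Future, ThreadPoolExecutor

compile_threads = 4
_pool = ThreadPoolExecutor(compile_threads) if compile_threads > 1 else None


def submit(task):
    if _pool is None:
        return task()          # inline, synchronous
    return _pool.submit(task)  # deferred, returns a Future


def fake_compile(name):
    return f"compiled::{name}"


handles = {name: submit(lambda n=name: fake_compile(n)) for name in ("k0", "k1")}
# Resolve any Futures in place, mirroring the role of AsyncCompile.wait(scope).
results = {k: (v.result() if isinstance(v, Future) else v) for k, v in handles.items()}
print(results)  # {'k0': 'compiled::k0', 'k1': 'compiled::k1'}
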
+ env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"] + extra_env = {v: os.environ[v] for v in env_vars if v in os.environ} + return TritonFuture( + kernel, + self.process_pool().submit( + _worker_compile_triton, + kernel._reload_in_subproc, + extra_env, + ), + ) + else: + kernel.precompile() + return kernel + + def multi_kernel(self, *args, **kwargs) -> Any: + from torch._inductor.codegen.multi_kernel import MultiKernelCall + + # no need to call this in parallel since the sub-kernels are already parallel tasks + return MultiKernelCall(*args, **kwargs) + + def cpp(self, source_code: str): + kernel_code_log.info("CPP Kernel:\n%s", source_code) + if config.compile_threads <= 1: + return CppCodeCache.load(source_code).kernel + else: + get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit) + return LambdaFuture(lambda: get_result().kernel) + + def cpp_pybinding(self, argtypes: List[str], source_code: str): + kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code) + if config.compile_threads <= 1: + return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code) + else: + get_result = CppPythonBindingsCodeCache.load_pybinding_async( + argtypes, source_code, submit_fn=self.submit + ) + return LambdaFuture(get_result) + + def cuda(self, source_code, dst_file_ext): + kernel_code_log.info("CUDA Kernel:\n%s", source_code) + + def task(): + return CUDACodeCache.load(source_code, dst_file_ext)[0] + + return self.submit(task) + + def rocm(self, source_code, dst_file_ext): + kernel_code_log.info("ROCm Kernel:\n%s", source_code) + + def task(): + return ROCmCodeCache.load(source_code, dst_file_ext)[0] + + return self.submit(task) + + def halide(self, meta: HalideMeta, source_code: str): + kernel_code_log.info("Halide Kernel:\n%r\n%s", meta, source_code) + if config.compile_threads <= 1: + return HalideCodeCache.generate_halide(meta, source_code) + else: + get_result = HalideCodeCache.generate_halide_async( + meta, source_code, submit_fn=self.submit + ) + return LambdaFuture(get_result) + + def wait(self, scope: Dict[str, Any]) -> None: + num_kernels = len( + [ + value + for key, value in scope.items() + if isinstance(value, (Future, CodeCacheFuture)) + ] + ) + pbar = tqdm( + total=num_kernels, + desc="Inductor Compilation", + disable=config.disable_progress, + delay=0, + ) + if config.compile_threads > 1: + for key, result in scope.items(): + if config.verbose_progress and not isinstance(pbar, _Faketqdm): + pbar.set_postfix_str(key) + if isinstance(result, (Future, CodeCacheFuture)): + try: + scope[key] = result.result() + except BrokenProcessPool as e: + raise RuntimeError( + "A compilation subprocess exited unexpectedly. This " + "is likely due to a crash. To facilitate debugging, " + "you can re-run with TORCHINDUCTOR_COMPILE_THREADS=1 " + "to cause compilation to occur in the main process." 
+ ) from e + pbar.update(1) + + _compile_end() + + +if ( + os.environ.get("TORCH_TNT_IN_USE", "0") == "1" + or os.environ.get("TORCH_WARM_POOL", "1") != "1" + # The subprocess pool is only used for the Triton backend + or not has_triton_package() +): + pass +else: + AsyncCompile.warm_pool() diff --git a/lib/python3.10/site-packages/torch/_inductor/autotune_process.py b/lib/python3.10/site-packages/torch/_inductor/autotune_process.py new file mode 100644 index 0000000000000000000000000000000000000000..eea4f8d6573d82c5d629382893920c8256953c6a --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/autotune_process.py @@ -0,0 +1,876 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import contextlib +import ctypes +import dataclasses +import functools +import logging +import os +import queue +import time +import warnings +from concurrent.futures import ThreadPoolExecutor +from ctypes import byref, c_size_t, c_void_p, CDLL +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + TYPE_CHECKING, + Union, +) + +import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch import multiprocessing +from torch._dynamo.testing import rand_strided +from torch._inductor import ir +from torch._inductor.codecache import ( + CppCodeCache, + CUDACodeCache, + DLLWrapper, + get_hash, + PyCodeCache, +) + + +if TYPE_CHECKING: + from multiprocessing.process import BaseProcess + from multiprocessing.queues import Queue + from types import ModuleType + + from torch._inductor.select_algorithm import TritonTemplateCaller + +from . import config +from .runtime.benchmarking import benchmarker +from .virtualized import V + + +CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES" +EXIT_HANDLER_REGISTERED = False + +log = logging.getLogger(__name__) + + +# Used to synchronize between parent and child processes +class Ping: + pass + + +class Pong: + pass + + +class NonzeroWorkspaceNotSupportedError(Exception): + pass + + +@contextlib.contextmanager +def set_cuda_visible_device(device: Optional[int]): + """ + Context manager to set the CUDA_VISIBLE_DEVICES environment variable to the + specified single device. If device is None, don't manipulate the environment. + """ + if device is None: + yield + return + + current = os.environ.get(CUDA_VISIBLE_DEVICES) + os.environ[CUDA_VISIBLE_DEVICES] = str(device) + try: + yield + finally: + if current is None: + del os.environ[CUDA_VISIBLE_DEVICES] + else: + os.environ[CUDA_VISIBLE_DEVICES] = current + + +@dataclasses.dataclass +class TuningProcess: + """ + Abstraction for launching a helper process to benchmark kernels. Spawns + the parent process and uses multiprocessing queues to send benchmark + requests and return results. + """ + + device: Optional[int] = None + process: Optional[BaseProcess] = None + request_queue: Optional[Queue[Any]] = None + response_queue: Optional[Queue[Any]] = None + + @staticmethod + def process_main( + request_queue: Queue[Any], + response_queue: Queue[Any], + ) -> None: + """ + Entry point for the child process. + """ + log.debug( + "Entering TuningProcess child. Visible devices = %s", + os.environ.get(CUDA_VISIBLE_DEVICES), + ) + try: + TuningProcess.workloop(request_queue, response_queue) + except Exception as ex: + log.exception("Exception in TuningProcess") + + @staticmethod + def workloop(request_queue: Queue[Any], response_queue: Queue[Any]) -> None: + """ + Work loop for the benchmarking subprocess. 
+ """ + while True: + obj = request_queue.get() + + if obj is None: + break # None is a sentinel for the child to terminate + elif isinstance(obj, Ping): + response_queue.put(Pong()) + elif isinstance(obj, BenchmarkRequest): + response_queue.put(obj.benchmark()) + else: + raise RuntimeError(f"Invalid request type {type(obj)}") + + def valid(self) -> bool: + """ + True if the sub-process has been initialized. + """ + return ( + self.process is not None + and self.request_queue is not None + and self.response_queue is not None + ) + + def clear(self) -> None: + """ + Reset to an uninitialized state. + """ + self.process = self.request_queue = self.response_queue = None + + def initialize(self) -> None: + """ + Create child process, request/response queues, and do the warm up. + Set the environment to make only the provided GPU device visible + to the process. + """ + if self.valid(): + return + + # cuda runtime does not work with "fork", use "spawn" to start processes. + ctx = multiprocessing.get_context("spawn") + self.request_queue = ctx.Queue() + self.response_queue = ctx.Queue() + + self.process = ctx.Process( + target=self.process_main, + args=( + self.request_queue, + self.response_queue, + ), + ) + assert self.process is not None + with set_cuda_visible_device(self.device): + self.process.start() + + def put(self, obj: Any) -> None: + """ + Push a work item to the child process. + """ + # In case of a prior crash, ensure the subprocess is running + self.initialize() + assert self.request_queue is not None + self.request_queue.put(obj) + + def get( + self, result_timeout=120.0, graceful_timeout=3.0, terminate_timeout=1.0 + ) -> Any: + """ + Get a response from the child process. Raises queue.Empty on timeout + or if the process dies. + + This method is (so far) only used by TuningProcessPool, where torch._inductor.config entries are being used + to populate the timeouts: + + Arguments: + + @param result_timeout: Timeout in seconds, defaults to 120.0 or to + config.max_autotune_subproc_result_timeout_seconds when called by TuningProcessPool + @param graceful_timeout: Timeout in seconds to allow graceful shutdown (SIGTERM is sent after this time). + Defaults to 3.0 or to config.max_autotune_subproc_graceful_timeout_seconds + @param terminate_timeout: Timeout in seconds after SIGTERM, until we send SIGKILL if the process + remains alive. Defaults to 1.0 or to + config.max_autotune_subproc_terminate_timeout_seconds. + Returns: + A response from the child process (Any type) + """ + assert self.process is not None + assert self.response_queue is not None + while True: + try: + remaining_timeout = result_timeout + res = None + while remaining_timeout is not None and remaining_timeout >= 1.0: + remaining_timeout -= 0.5 + try: + res = self.response_queue.get(timeout=0.5) + break + except queue.Empty: + if not self.process.is_alive(): + raise # is being caught a few lines below + if res is None: + res = self.response_queue.get(timeout=remaining_timeout) + return res + except queue.Empty: + status = self.process.exitcode + if status is None: + self.kill( + graceful_timeout=graceful_timeout, + terminate_timeout=terminate_timeout, + ) + else: + # child process crashed + self.clear() + raise + + def terminate(self) -> None: + """ + Signal the child process to terminate. + """ + if self.valid(): + assert self.process is not None + assert self.request_queue is not None + self.request_queue.put(None) + + def wait(self) -> None: + """ + Wait for the child process to exit. 
+ """ + if self.process is not None: + self.process.join() + self.clear() + + def kill(self, graceful_timeout=5.0, terminate_timeout=1.0) -> None: + # Tries to kill the process, using a graceful_timeout in which the process + # is allowed to exit gracefully. If the process is still alive, + # it will be terminated. If that is not sufficient to end it + # within terminate_timeout seconds, it will be killed. + if self.process is not None: + self.terminate() + self.process.join(timeout=graceful_timeout) + if self.process.is_alive(): + log.warning( + "Sending SIGTERM to process with PID %d", + self.process.pid, + ) + self.process.terminate() + self.process.join(timeout=terminate_timeout) + if self.process.is_alive(): + log.error( + "Sending SIGKILL to process with PID %d", + self.process.pid, + ) + self.process.kill() # This should definitely end the process + self.clear() + + +@dataclasses.dataclass +class TuningProcessPool: + """ + Maintains a pool of TuningProcesses to benchmark kernels in parallel + across devices. By default, we create one TuningProcess per device and + set the sub-process environment to make only that device visible. + """ + + processes: Optional[queue.Queue[TuningProcess]] = None + executor: Optional[ThreadPoolExecutor] = None + + def initialize(self) -> None: + """ + Start the child processes. + """ + assert (self.processes is None) == (self.executor is None) + if self.processes is not None: + return + + devices = self.get_device_list() + log.debug("Sub-process autotune device list: %s", devices) + + # Launch the child processes and push a msg to "warm up" + self.processes = queue.Queue() + for device in devices: + p = TuningProcess(device=device) + p.initialize() + p.put(Ping()) + self.processes.put(p) + + # Wait for the initialization to finish + for p in self.processes.queue: + assert isinstance(p.get(result_timeout=None), Pong) + + # Use a thread pool to manage distributing work to the subprocesses. + # Threads block on an available process, so it makes sense to match + # the number of threads with the number of devices. + self.executor = ThreadPoolExecutor(max_workers=len(devices)) + + # Register the exit handler for the parent process so it will terminate + # the child processes. + global EXIT_HANDLER_REGISTERED + if not EXIT_HANDLER_REGISTERED: + EXIT_HANDLER_REGISTERED = True + import atexit + + atexit.register(self.terminate) + + def get_device_list(self) -> Sequence[Optional[int]]: + """ + Gather the list of devices to be used in the pool. + """ + if not config.autotune_multi_device: + # Don't use multiple devices + return [None] + + count = torch.cuda.device_count() + + # If the user specified the visible devices in the env, use those. + if CUDA_VISIBLE_DEVICES in os.environ: + devices = [int(d) for d in os.environ[CUDA_VISIBLE_DEVICES].split(",")] + assert len(devices) <= count + return devices + + return list(range(count)) + + def terminate(self) -> None: + """ + Signal all child processes to terminate. + """ + if self.executor is not None: + self.executor.shutdown() + self.executor = None + + if self.processes is not None: + for p in self.processes.queue: + p.terminate() + for p in self.processes.queue: + p.wait() + self.processes = None + + def target(self, choice: TritonTemplateCaller) -> float: + """ + Entry point for the thread-pool helper threads: Wait for an open TuningProcess, + remove it from the queue, execute the benchmark in that subprocess, and return + the TuningProcess to the queue. 
+ """ + assert choice.bmreq is not None + assert self.processes is not None + + process = self.processes.get() + process.put(choice.bmreq) + try: + return process.get( + config.max_autotune_subproc_result_timeout_seconds, + config.max_autotune_subproc_graceful_timeout_seconds, + config.max_autotune_subproc_terminate_timeout_seconds, + ) + except queue.Empty: + warnings.warn( + f"Failed to benchmark choice '{choice}'. It will be ignored. " + "Please debug the root cause in case the choice can bring perf gains." + ) + # set to INF so this choice will be ignored + return float("inf") + finally: + self.processes.put(process) + + def benchmark( + self, + choices: List[TritonTemplateCaller], + ) -> Dict[TritonTemplateCaller, float]: + """ + Benchmark each choice in a separate process. + """ + assert self.processes is not None, "Tuning process pool is not initialized" + assert self.executor is not None + + results = {} + + # Use a ThreadExecutorPool to spread the work across the subprocesses and + # to grab subprocesses as soon as they're free. + for choice, result in zip(choices, self.executor.map(self.target, choices)): + results[choice] = result + + return results + + +tuning_pool = TuningProcessPool() + + +LayoutOrBuffer = Union[ir.Layout, ir.Buffer] + + +@dataclasses.dataclass +class TensorMeta: + device: torch.device + dtype: torch.dtype + sizes: torch._prims_common.ShapeType + strides: torch._prims_common.StrideType + offset: int + name: Optional[str] = None + + @classmethod + def from_irnodes( + cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]] + ) -> Union[TensorMeta, List[TensorMeta]]: + if isinstance(irnodes, Sequence): + result: List[Any] = [cls.from_irnodes(x) for x in irnodes] + assert all(isinstance(x, TensorMeta) for x in result) + return result + + node = irnodes + if isinstance(node, ir.Layout): + node = ir.Buffer("fake", node) + + dtype = node.get_dtype() + assert dtype is not None + + return TensorMeta( + device=node.get_device(), + dtype=dtype, + sizes=V.graph.sizevars.size_hints( + node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + strides=V.graph.sizevars.size_hints( + node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + offset=V.graph.sizevars.size_hint( + node.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ), + name=node.get_name(), + ) + + def to_tensor(self) -> torch.Tensor: + return rand_strided( + self.sizes, + self.strides, + device=self.device, + dtype=self.dtype, + extra_size=self.offset, + ) + + +@dataclasses.dataclass +class BenchmarkRequest: + """ + Only handle triton template benchmark for now. The extern kernel benchmark + can be done inside the same process since they usually don't cause crash. + + Important: Instances of this class and subclasses have to be serializable + across process boundaries. Do not put CUDA Tensors in here! 
+ """ + + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, List[TensorMeta]], + output_tensor_meta: Union[TensorMeta, List[TensorMeta]], + extra_args: Iterable[Any], + ) -> None: + # the kernel name defined in the module + self.kernel_name = kernel_name + + if isinstance(input_tensor_meta, TensorMeta): + input_tensor_meta = [input_tensor_meta] + self.input_tensor_meta = input_tensor_meta + + if isinstance(output_tensor_meta, (tuple, list)): + assert len(output_tensor_meta) == 1 + output_tensor_meta = output_tensor_meta[0] + self.output_tensor_meta = output_tensor_meta + + self.extra_args = extra_args + + def make_run_fn( + self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor + ) -> Callable[[], None]: + raise NotImplementedError + + def cleanup_run_fn(self) -> None: + pass + + def do_bench( + self, + fn, + *input_tensors: torch.Tensor, + output_tensor: Optional[torch.Tensor] = None, + ) -> float: + raise NotImplementedError + + def benchmark( + self, + *input_tensors: torch.Tensor, + output_tensor: Optional[torch.Tensor] = None, + ) -> float: + debug = log.isEnabledFor(logging.DEBUG) + if debug: + start_ts = time.time() + + # create args and out tensor + if output_tensor is None: + assert len(input_tensors) == 0 + input_tensors = tuple(x.to_tensor() for x in self.input_tensor_meta) + output_tensor = self.output_tensor_meta.to_tensor() + + if debug: + create_tensor_elapse = time.time() - start_ts # type: ignore[possibly-undefined] + start_ts = time.time() + try: + fn = self.make_run_fn(*input_tensors, output_tensor=output_tensor) + except NonzeroWorkspaceNotSupportedError: + # Skipping all ops with nonzero workspace requirements + log.info("Skipping op due to nonzero workspace requirement") + return float("inf") + + if debug: + load_elapse = time.time() - start_ts # type: ignore[possibly-undefined] + start_ts = time.time() + + out = self.do_bench(fn, *input_tensors, output_tensor) + + if debug: + bench_elapse = time.time() - start_ts # type: ignore[possibly-undefined] + log.debug( + "InChildProcess %s: load %f, create tensor %f, bench %f", + str(self), + load_elapse, # type: ignore[possibly-undefined] + create_tensor_elapse, # type: ignore[possibly-undefined] + bench_elapse, + ) + self.cleanup_run_fn() + return out + + +class TestBenchmarkRequest(BenchmarkRequest): + """ + Supports unit testing. Defined in this file so that the TuningProcess + sub-process knows how to unpickle these objects. 
+ """ + + def __init__(self, value: Optional[float] = None) -> None: + self.value = value + + def benchmark( + self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None + ) -> float: + if self.value is None: + raise Exception("Failed to run") # noqa: TRY002 + return self.value + + +class GPUDeviceBenchmarkRequest(BenchmarkRequest): + def do_bench( + self, + fn, + *input_tensors: torch.Tensor, + output_tensor: Optional[torch.Tensor] = None, + ) -> float: + device_idx_set = { + tensor.device.index + for tensor in [*input_tensors, output_tensor] + if isinstance(tensor, torch.Tensor) + and tensor.is_cuda + and tensor.device.index is not None + } + assert len(device_idx_set) <= 1, f"Can not mix devices {device_idx_set}" + if len(device_idx_set) == 1: + device_idx = next(iter(device_idx_set)) + else: + device_idx = torch.cuda.current_device() + + with torch.cuda.device(device_idx): + out = benchmarker.benchmark_gpu(fn) + torch.cuda.synchronize() # shake out any CUDA errors + + return out + + +class TritonBenchmarkRequest(GPUDeviceBenchmarkRequest): + # Important: Instances of this class have to be serializable + # across process boundaries. Do not put CUDA Tensors in here! + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, List[TensorMeta]], + output_tensor_meta: Union[TensorMeta, List[TensorMeta]], + extra_args: Iterable[Any], + module_path: str, # the path of the module defining the triton kernel + module_cache_key: str, + grid: List[int], + num_stages: int, + num_warps: int, + matrix_instr_nonkdim: int = 0, # only used for hip to choose the shape of mfma instruction. + ) -> None: + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.module_path = module_path + self.module_cache_key = module_cache_key + self.grid = grid + self.num_stages = num_stages + self.num_warps = num_warps + self.matrix_instr_nonkdim = matrix_instr_nonkdim + + def make_run_fn( + self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor + ) -> Callable[[], None]: + mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path) + log.debug( + "benchmark module key: %s, path: %s", + self.module_cache_key, + self.module_path, + ) + + run_method = getattr(mod, self.kernel_name).run + extra_args = list(self.extra_args) + + # Newer version of triton add warmup argument to JITFunction.run. + # This code handles backward-compatibility. + warmup_arg = {} + import inspect + + if "warmup" in inspect.signature(run_method).parameters: + warmup_arg["warmup"] = False + + from torch._C import _cuda_getCurrentRawStream as get_raw_stream + + if torch.version.hip and self.matrix_instr_nonkdim != 0: + return functools.partial( + run_method, + *input_tensors, + output_tensor, + *self.extra_args, + grid=self.grid, + **warmup_arg, + stream=get_raw_stream(self.output_tensor_meta.device.index), + ) + else: + return functools.partial( + run_method, + *input_tensors, + output_tensor, + *self.extra_args, + grid=self.grid, + **warmup_arg, + stream=get_raw_stream(self.output_tensor_meta.device.index), + ) + + def precompile(self): + mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path) + getattr(mod, self.kernel_name).precompile() + + def __str__(self) -> str: + return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}" + + +class CUDABenchmarkRequest(GPUDeviceBenchmarkRequest): + # Important: Instances of this class have to be serializable + # across process boundaries. Do not put CUDA Tensors in here! 
+ + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, List[TensorMeta]], + output_tensor_meta: Union[TensorMeta, List[TensorMeta]], + extra_args: Iterable[Any], + source_code: str, + ) -> None: + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.source_code = source_code + self.workspace_size: int = 0 + self.workspace: Optional[torch.Tensor] = None + self.DLL: Optional[DLLWrapper] = None + self._workspace_size_updated = False + self.hash_key: str = "" + self.source_file: str = "" + self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so") + + def precompile(self): + # Prepopulate CUDACodeCache + # may happen in separate Threadpool + log.debug("Precompiling %s", self) + CUDACodeCache.compile(self.source_code, "so") + log.debug("Done precompiling %s", self) + + def make_run_fn( + self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor + ) -> Callable[[], None]: + self.ensure_dll_loaded() + self.update_workspace_size() + args = [ + c_void_p(tensor.data_ptr()) + for tensor in list(input_tensors) + [output_tensor] + ] + log.debug( + "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", + self.kernel_name, + self.source_file, + self.hash_key, + self.DLL, + args, + self.extra_args, + ) + stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) + run_method = getattr(self.DLL, self.kernel_name) + workspace_ptr = c_void_p(0) + if self.workspace_size > 0: + self.workspace = torch.zeros( + (self.workspace_size + 7) // 8, + dtype=torch.float64, + device=output_tensor.device, + ) + workspace_ptr = c_void_p(self.workspace.data_ptr()) + + # Generate partial function. + return functools.partial( + run_method, + *args, + *self.extra_args, + None, # null workspace size ptr + workspace_ptr, # set workspace ptr, + stream_ptr, + ) + + def update_workspace_size(self) -> None: + if self._workspace_size_updated: + return + self.ensure_dll_loaded() + unique_input_count = len({meta.name for meta in self.input_tensor_meta}) + args = [c_void_p(None) for _ in range(unique_input_count + 1)] + stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) + + run_method = getattr(self.DLL, self.kernel_name) + # Retrieve workspace_size and initialize workspace. 
+ c_workspace_size = c_size_t() + run_method( + *args, # input ptrs and output ptrs + *self.extra_args, + byref( + c_workspace_size + ), # set workspace size ptr to retrieve workspace size + None, # null workspace ptr + stream_ptr, + ) + torch.cuda.synchronize() # shake out any CUDA errors + self.workspace_size = c_workspace_size.value + log.debug( + "update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", # noqa: B950 + self.workspace_size, + self.kernel_name, + self.source_file, + self.hash_key, + self.DLL, + args, + self.extra_args, + ) + self._workspace_size_updated = True + + def ensure_dll_loaded(self): + if self.DLL is None: + self.DLL, self.hash_key, self.source_file = CUDACodeCache.load( + self.source_code, "so" + ) + + def cleanup_run_fn(self) -> None: + if self.DLL is not None: + self.DLL.close() + self.workspace = None + + def __str__(self) -> str: + return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" + + +class CPUDeviceBenchmarkRequest(BenchmarkRequest): + def do_bench( + self, + fn, + *input_tensors: torch.Tensor, + output_tensor: Optional[torch.Tensor] = None, + ) -> float: + return benchmarker.benchmark_cpu(fn) + + +class CppBenchmarkRequest(CPUDeviceBenchmarkRequest): + # Important: Instances of this class have to be serializable + # across process boundaries. Do not put Tensors in here! + + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, List[TensorMeta]], + output_tensor_meta: Union[TensorMeta, List[TensorMeta]], + extra_args: Iterable[Any], + source_code: str, + ) -> None: + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.source_code = source_code + self.hash_key = get_hash(source_code) + self.DLL: Optional[Union[CDLL, ModuleType]] = None + + def precompile(self): + # Prepopulate CppCodeCache + # may happen in separate Threadpool + log.debug("Precompiling %s", self) + CppCodeCache.load(self.source_code, cuda=False) + log.debug("Done precompiling %s", self) + + def make_run_fn( + self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor + ) -> Callable[[], None]: + # TODO(jgong5): use CppPythonBindingsCodeCache for better binding perf + self.DLL = CppCodeCache.load(self.source_code, cuda=False) + args = [tensor.data_ptr() for tensor in list(input_tensors) + [output_tensor]] + log.debug( + "make_run_fn: self.kernel_name=%s, self.DLL=%s, args=%s, self.extra_args=%s", + self.kernel_name, + self.DLL, + args, + self.extra_args, + ) + run_method = getattr(self.DLL, self.kernel_name) + # Assume only size with type ctypes.c_ulonglong in extra_args + assert all(isinstance(arg, ctypes.c_ulonglong) for arg in self.extra_args) + run_method.argtypes = [ctypes.c_ulonglong] * ( + len(args) + len(list(self.extra_args)) + ) + + # Generate partial function. + return functools.partial( + run_method, + *args, + *self.extra_args, + ) + + def cleanup_run_fn(self) -> None: + if self.DLL is not None: + """ + Check close attr due to it crash on Windows. + """ + if hasattr(self.DLL, "close"): + self.DLL.close() + + def __str__(self) -> str: + return f"{self.kernel_name=}" + + +def benchmark_in_sub_process( + choices: List[TritonTemplateCaller], +) -> Dict[TritonTemplateCaller, float]: + """ + Do benchmarking in a subprocess and return the perf number (latency). 
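
# A minimal illustrative sketch, separate from the vendored file above: a standalone
# stand-in for the `benchmarker.benchmark_cpu(fn)` call that
# CPUDeviceBenchmarkRequest.do_bench delegates to. This is not the real implementation
# from torch._inductor.runtime.benchmarking, only the usual warmup-then-median wall-clock
# pattern; the matrix sizes are arbitrary.
import statistics
import time

import torch


def benchmark_cpu(fn, warmup=3, repeats=20):
    for _ in range(warmup):
        fn()                                   # let caches and thread pools settle
    samples = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - t0) * 1e3)
    return statistics.median(samples)          # milliseconds per call


a, b = torch.randn(256, 256), torch.randn(256, 256)
print(f"matmul: {benchmark_cpu(lambda: a @ b):.3f} ms")
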
+ """ + return tuning_pool.benchmark(choices) diff --git a/lib/python3.10/site-packages/torch/_inductor/bounds.py b/lib/python3.10/site-packages/torch/_inductor/bounds.py new file mode 100644 index 0000000000000000000000000000000000000000..7452f2bb1b62b6862a63bcb8d508032d5e33c7e9 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/bounds.py @@ -0,0 +1,140 @@ +# mypy: allow-untyped-defs +import logging +import operator +from functools import partial +from typing import Any, Callable, Dict + +from sympy import Expr + +import torch +from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges + +from .loop_body import InterpreterShim, LoopBody, LoopBodyBlock +from .utils import cache_on_self, dominated_nodes +from .virtualized import V + + +log = logging.getLogger(__name__) + + +class BoundVars: + """ + Performs Value Range Analysis on LoopBody's fx graph by calling BoundVars.run() + It exposes the ranges of the nodes in the `bounds` variable + + Note. A current limitation of this analysis is that it just works on a per-loop basis. + We should be able to propagate the bounds between across the whole graph. This may benefit + the case a bounded variable is returned by a kernel and fed into another. + """ + + def __init__(self, loop_body: LoopBody) -> None: + def upper_bound(v): + return bound_sympy(v).upper if isinstance(v, Expr) else v + + self.loop_body = loop_body + self.replacement_vals = { + k: ValueRanges[Expr](0, upper_bound(v) - 1) + for k, v in loop_body.var_ranges.items() + } + # avoid computing these values, pessimistically assume that they are unbounded + self.unbounded_vars = dominated_nodes( + node + for node in self.loop_body.get_nodes() + if node.target in ["load", "reduction", operator.getitem] + or "masked_subblock" in node.target + ) + # To access this variable call `get_bounds()` + self._bounds: Dict[torch.fx.Node, ValueRanges[Expr]] = {} + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"loop_body={self.loop_body},\n " + f"replacement_vals={self.replacement_vals}, \n" + f"unbounded_vars={self.unbounded_vars}, \n" + f"_bounds={self._bounds})" + ) + + @cache_on_self + def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]: + submodules = self.swap_submodules(self.loop_body.submodules) + + # Initialize the environment with the unbounded variables + for node in self.unbounded_vars: + # we need to evaluate masked_subblock to recurse, and we need to set indirect values + if not isinstance(node.target, str) or ( + "masked_subblock" not in node.target + and "set_indirect" not in node.target + ): + self._bounds[node] = ValueRanges[Expr].unknown() + + with V.set_ops_handler(ValueRangeAnalysis()): + interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules) + log.debug("get_bounds:\n%s", self.loop_body.root_block.graph) + interpreter.run(V.get_ops_handler(), initial_env=self._bounds) + return self._bounds + + def swap_submodules( + self, submodules: Dict[str, Callable[..., Any]] + ) -> Dict[str, Callable[..., ValueRanges[Expr]]]: + result: Dict[str, Callable[..., ValueRanges[Expr]]] = {} + for key in submodules.keys(): + if key == "get_index": + result[key] = self.get_index + elif "masked_subblock" in key: + subblock = self.loop_body.subblocks[key] + # The result within the lambda will reference to the final + # set of modules at the end of the for-loop as it stores a reference to it + + # bind subblock in a function because python lambdas close over by reference + # moving the lambda out of 
make_fn would close over the reference to subblock, + # so all lambdas would have the same subblock reference that is the final + # subblock in the loop + def make_fn(subblock): + return lambda mask, value: self.masked_subblock( + subblock, self._bounds, mask, value, result + ) + + result[key] = make_fn(subblock) + elif "set_indirect" in key: + idx = int(key[len("set_indirect") :]) + var = self.loop_body.indirect_vars[idx] + indirect = partial(self.set_indirect, var) + result[key] = indirect + else: + assert "scan" in key + result[key] = submodules[key] + + return result + + def masked_subblock( + self, + subblock: LoopBodyBlock, + env: Dict[torch.fx.Node, ValueRanges[Expr]], + mask: Any, + value: Any, + submodules: Dict[str, Callable[..., Any]], + ) -> ValueRanges[Expr]: + interp = InterpreterShim(subblock.graph, submodules) + interp.run(V.get_ops_handler(), initial_env=env) + output = [node for node in subblock.graph.nodes if node.target == "output"] + assert len(output) == 1 + # dont bother unioning with value since the load from buffer will be + # pessimistically assumed to be inf anyway + return interp.env[output[0]] + + def set_indirect(self, old: Expr, new: ValueRanges[Expr]) -> ValueRanges[Expr]: + assert isinstance(new, ValueRanges) + self.replacement_vals[old] = new + return new + + def get_index(self, name: Expr) -> ValueRanges[Expr]: + expr = self.loop_body.indexing_exprs[name] + bound = self.replacement_vals.get(expr) + if bound is None: + bound = bound_sympy(expr, self.replacement_vals) + # The following assertion is true at the time of this writing + # We don't assert is as to not execute bound_sympy when bound is not None + # assert bound is None or bound == bound_sympy(expr, self.replacement_vals) + self.replacement_vals[name] = bound + return bound diff --git a/lib/python3.10/site-packages/torch/_inductor/codecache.py b/lib/python3.10/site-packages/torch/_inductor/codecache.py new file mode 100644 index 0000000000000000000000000000000000000000..59cc47ac06c94e492bed109e60c28bb6b7e61ba0 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/codecache.py @@ -0,0 +1,3353 @@ +from __future__ import annotations + +import base64 +import copyreg +import dataclasses +import functools +import hashlib +import importlib +import io +import json +import logging +import os +import pickle +import pkgutil +import re +import shlex +import shutil +import struct +import subprocess +import sys +import sysconfig +import tempfile +import textwrap +import threading +import warnings +from bisect import bisect_right +from copy import copy +from ctypes import c_void_p, CDLL, cdll +from datetime import timedelta +from functools import partial +from pathlib import Path +from time import time, time_ns +from types import ModuleType +from typing import ( + Any, + Callable, + cast, + Counter, + Dict, + Generator, + List, + NoReturn, + Optional, + Sequence, + Set, + Tuple, + TYPE_CHECKING, + TypeVar, + Union, +) +from typing_extensions import TypeAlias + +import torch +import torch.distributed as dist +from torch import SymInt, Tensor +from torch._dynamo.utils import counters, dynamo_timed, get_chromium_event_logger +from torch._inductor import config, exc, metrics +from torch._inductor.codegen.cuda import cuda_env +from torch._inductor.codegen.rocm.compile_command import ( + rocm_compile_command, + rocm_compiler, +) +from torch._utils_internal import log_cache_bypass + +from .utils import _align + + +T = TypeVar("T") + + +if TYPE_CHECKING: + from collections.abc import KeysView + + from 
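
# A minimal illustrative sketch, separate from the vendored file above: the value-range
# arithmetic BoundVars relies on, applied to a free-standing index expression.
# bound_sympy/ValueRanges are the same utilities bounds.py imports; the symbols and their
# ranges below are arbitrary.
import sympy
from torch.utils._sympy.value_ranges import ValueRanges, bound_sympy

x = sympy.Symbol("x", integer=True, nonnegative=True)
y = sympy.Symbol("y", integer=True, nonnegative=True)
ranges = {x: ValueRanges(0, 127), y: ValueRanges(0, 7)}

# 8*x + y indexes a [128, 8] buffer; its bound should cover [0, 1023].
print(bound_sympy(8 * x + y, ranges))
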
.remote_cache import JsonDataTy, RemoteCache + + +""" +codecache.py, cpp_builder.py and cpu_vec_isa.py import rule: +https://github.com/pytorch/pytorch/issues/124245#issuecomment-2197778902 +""" +from torch._inductor.cpp_builder import ( + _set_gpu_runtime_env, + _transform_cuda_paths, + CppBuilder, + CppOptions, + CppTorchCudaOptions, + get_compiler_version_info, + get_cpp_compiler, + get_name_and_dir_from_output_file_path, + normalize_path_separator, +) +from torch._inductor.cpu_vec_isa import pick_vec_isa +from torch._inductor.cudagraph_utils import ( + BoxedDeviceIndex, + CudagraphCachedInfo, + log_cudagraph_skip_and_bump_counter, +) +from torch._inductor.runtime.compile_tasks import ( + _module_to_triton_kernel, + _reload_python_module, + _reload_python_module_in_subproc, +) +from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir +from torch._inductor.utils import ( + ALIGN_BYTES, + align_inputs_from_check_idxs, + BoxedBool, + clear_on_fresh_inductor_cache, + is_linux, + is_windows, + set_tracing_context_output_strides, +) +from torch._logging import trace_structured +from torch._subclasses.fake_tensor import ( + extract_tensor_metadata, + FakeTensor, + TensorMetadata, +) +from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv + + +if TYPE_CHECKING: + from concurrent.futures import Future + + from torch._inductor.graph import GraphLowering + from torch._inductor.ir import ChoiceCaller + from torch._inductor.runtime.hints import HalideInputSpec, HalideMeta + + +_HERE = os.path.abspath(__file__) +_TORCH_PATH = os.path.dirname(os.path.dirname(_HERE)) +_LINKER_SCRIPT = os.path.join(_TORCH_PATH, "_inductor/script.ld") + +_IS_WINDOWS = sys.platform == "win32" + +if config.is_fbcode(): + from triton.fb import build_paths + from triton.fb.build import _run_build_command + + from torch._inductor.fb.utils import ( + log_global_cache_errors, + log_global_cache_stats, + log_global_cache_vals, + use_global_cache, + ) +else: + + def log_global_cache_errors(*args: Any, **kwargs: Any) -> None: # type: ignore[misc] + pass + + def log_global_cache_stats(*args: Any, **kwargs: Any) -> None: # type: ignore[misc] + pass + + def log_global_cache_vals(*args: Any, **kwargs: Any) -> None: # type: ignore[misc] + pass + + def use_global_cache() -> bool: # type: ignore[misc] + return False + + +output_code_log = torch._logging.getArtifactLogger(__name__, "output_code") + +LOCK_TIMEOUT = 600 + +_IS_WINDOWS = sys.platform == "win32" + + +log = logging.getLogger(__name__) + + +def cpp_wrapper_cache_dir(name: str) -> str: + cu_str = ( + "cpu" + if torch.version.cuda is None + else f'cu{torch.version.cuda.replace(".", "")}' + ) + python_version = f"py{sys.version_info.major}{sys.version_info.minor}" + build_folder = f"{python_version}_{cu_str}" + + cpp_wrapper_dir = os.path.join(cache_dir(), build_folder) + cpp_wrapper_build_directory = os.path.join(cpp_wrapper_dir, name) + os.makedirs(cpp_wrapper_build_directory, exist_ok=True) + return cpp_wrapper_build_directory + + +def get_cpp_wrapper_cubin_path_name() -> str: + return "cubin_path" if torch.version.hip is None else "hsaco_path" + + +class CacheBase: + @staticmethod + @functools.lru_cache(None) + def get_system() -> Dict[str, Any]: + try: + from triton.compiler.compiler import triton_key + + # Use triton_key instead of triton.__version__ as the version + # is not updated with each code change + triton_version = triton_key() + except ModuleNotFoundError: + triton_version = None + + try: + system: Dict[str, Any] = { + 
"device": {"name": None}, + "version": { + "triton": triton_version, + }, + } + device_properties = torch.cuda.get_device_properties( + torch.cuda.current_device() + ) + if torch.version.cuda is not None: + system["device"]["name"] = device_properties.name + system["version"]["cuda"] = torch.version.cuda + else: + system["device"]["name"] = device_properties.gcnArchName + system["version"]["hip"] = torch.version.hip + except (AssertionError, RuntimeError): + # If cuda is not installed, none of the above config is relevant. + system = {} + + system["hash"] = hashlib.sha256( + json.dumps(system, sort_keys=True).encode("utf-8") + ).hexdigest() + + return system + + @staticmethod + @clear_on_fresh_inductor_cache + @functools.lru_cache(None) + def get_local_cache_path() -> Path: + return Path(os.path.join(cache_dir(), "cache", CacheBase.get_system()["hash"])) + + @staticmethod + @functools.lru_cache(None) + def get_global_cache_path() -> Optional[Path]: + return ( + Path(os.path.join(config.global_cache_dir, CacheBase.get_system()["hash"])) + if config.global_cache_dir is not None + else None + ) + + def __init__(self) -> None: + self.system = CacheBase.get_system() + + def get_local_cache(self) -> Dict[str, Any]: + local_cache_path = self.get_local_cache_path() + if not local_cache_path.is_file(): + return {} + with open(local_cache_path) as local_cache_fp: + local_cache = json.load(local_cache_fp) + return local_cache["cache"] + + def update_local_cache(self, local_cache: Dict[str, Any]) -> None: + local_cache_path = self.get_local_cache_path() + write_atomic( + str(local_cache_path), + json.dumps({"system": self.system, "cache": local_cache}, indent=4), + make_dirs=True, + ) + + +class LocalCache(CacheBase): + def lookup(self, *keys: str) -> Optional[Dict[str, Any]]: + cache = self.get_local_cache() + + sub_cache = cache + for key in keys: + if key in cache: + sub_cache = cache[key] + else: + return None + + return sub_cache + + def set_value(self, *keys: str, value: Any) -> None: + cache = self.get_local_cache() + + sub_cache = cache + for key in keys[0:-1]: + sub_cache.setdefault(key, {}) + sub_cache = sub_cache[key] + sub_cache[keys[-1]] = value + + self.update_local_cache(cache) + + +class PersistentCache(CacheBase): + @functools.lru_cache(None) # noqa: B019 + def get_global_cache(self) -> Dict[str, Any]: + global_cache_path = self.get_global_cache_path() + if global_cache_path is None or not global_cache_path.is_file(): + return {} + with open(global_cache_path) as global_cache_fp: + global_cache = json.load(global_cache_fp) + return global_cache["cache"] + + def lookup( + self, + choices: List[ChoiceCaller], + op: str, + inputs: str, + benchmark: Optional[Callable[[Any], Dict[ChoiceCaller, float]]], + ) -> Dict[ChoiceCaller, float]: + """ + Check to see if we have benchmarked the given choice callers. For each + choice caller: + + 1. Check global_cache[op][inputs][choice][precision], return benchmark if cached. + 2. Check local_cache[op][inputs][choice][precision], return benchmark if cached. + 3. If benchmark is not None: + a. `max_autotune_gemm=True`: benchmark the choice, update + local_cache[op][inputs][choice], and return the benchmark. + b. `max_autotune_gemm=False`: don't benchmark the choice, return nothing. 
+ """ + precision = torch.get_float32_matmul_precision() + + log_stats = partial(log_global_cache_stats, self.system, op, inputs, precision) + log_vals = partial(log_global_cache_vals, self.system, op, inputs, precision) + log_errors = partial( + log_global_cache_errors, self.system, op, inputs, precision + ) + timings = {} + + def check_cache(cache: Dict[str, Any], callback: Any = None) -> bool: + """Check if `cache` contains data for all the choices""" + hit = True + for choice in choices: + choice_hash = choice.hash_key() + if choice_hash in cache.get(op, {}).get(inputs, {}).get(precision, {}): + # cache hit + timings[choice] = cache[op][inputs][precision][choice_hash] + else: + # cache miss + hit = False + break + if callback: + callback(cached=hit) + return hit + + if config.max_autotune or config.max_autotune_gemm: + local_cache = self.get_local_cache() if config.autotune_local_cache else {} + # check local cache first since it is data specific to the current machine + if ( + not check_cache(local_cache) + and not ( + use_global_cache() + and check_cache(self.get_global_cache(), callback=log_stats) + ) + and benchmark is not None + ): + try: + # re-benchmark everything to try to get consistent numbers from the same machine + timings = benchmark(choices) + assert all(choice in timings for choice in choices) + local_cache.setdefault(op, {}) + local_cache[op].setdefault(inputs, {}).setdefault(precision, {}) + for choice, timing in timings.items(): + local_cache[op][inputs][precision][choice.hash_key()] = timing + except RuntimeError as e: + # catch and log autotuning failures + log_errors(e) + raise e + + self.update_local_cache(local_cache) + + timings_to_log = { + choice.hash_key(): timings[choice] for choice in choices + } + log_vals(timings_to_log) + elif use_global_cache(): + # only check global cache, not local one + check_cache(self.get_global_cache(), callback=log_stats) + # may have a partial cache hit, where not everything is benchmarked + + return timings + + +def get_lock_dir() -> str: + lock_dir = os.path.join(cache_dir(), "locks") + if not os.path.exists(lock_dir): + os.makedirs(lock_dir, exist_ok=True) + return lock_dir + + +def sha256_hash(data: bytes) -> str: + # [:51] to strip off the "Q====" suffix common to every hash value. 
+ return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower() + + +def code_hash(code: Union[str, bytes], extra: str = "") -> str: + hashing_str = code if isinstance(code, bytes) else code.encode("utf-8") + if extra != "": + hashing_str = hashing_str + b"||" + extra.encode("utf-8") + return "c" + sha256_hash(hashing_str) + + +def get_path( + basename: str, extension: str, specified_dir: str = "" +) -> Tuple[str, str, str]: + if specified_dir: + if os.path.isabs(specified_dir): + subdir = specified_dir + else: + subdir = os.path.join(cache_dir(), specified_dir) + else: + subdir = os.path.join(cache_dir(), basename[1:3]) + path = os.path.join(subdir, f"{basename}.{extension}") + return basename, subdir, path + + +def get_hash( + content: Union[str, bytes], extra: str = "", hash_type: str = "code" +) -> str: + if hash_type == "code": + return code_hash(content, extra) + if hash_type in ["cubin", "hsaco", "spv"]: + return code_hash(repr(content)) + raise AssertionError(f"Unknown hash type {hash_type}") + + +def write( + content: Union[str, bytes], + extension: str, + extra: str = "", + hash_type: str = "code", + specified_dir: str = "", +) -> Tuple[str, str]: + # use striped content to compute hash so we don't end up with different + # hashes just because the content begins/ends with different number of + # spaces. + key: str = get_hash(content.strip(), extra, hash_type) + basename, subdir, path = get_path(key, extension, specified_dir) + encode_utf_8: bool = hash_type == "code" + if not os.path.exists(path): + write_atomic(path, content, make_dirs=True) + return basename, path + + +def write_text(text: str) -> str: + """ + Write the `text` to a file and return the path computed based on the hash. + """ + return write(text, "txt")[1] + + +def write_atomic( + path_: str, + content: Union[str, bytes], + make_dirs: bool = False, + encode_utf_8: bool = False, +) -> None: + # Write into temporary file first to avoid conflicts between threads + # Avoid using a named temporary file, as those have restricted permissions + assert isinstance( + content, (str, bytes) + ), "Only strings and byte arrays can be saved in the cache" + path = Path(path_) + if make_dirs: + path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp" + write_mode = "w" if isinstance(content, str) else "wb" + with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f: + f.write(content) + tmp_path.rename(path) + + +@dataclasses.dataclass +class TensorMetadataAndValues: + """ + TensorMetadata plus the elements as a list of raw values. + Used for hashing inlined constants. + """ + + tensor_metadata: TensorMetadata + values: List[Any] + + +def _ident(x: T) -> T: + return x + + +def extract_tensor_metadata_for_cache_key( + device_map: Dict[torch.device, torch.device], t: Tensor +) -> TensorMetadata: + """ + Extracts the tensor metadata and removes fields of the TensorMetadata + that are not needed for caching + """ + meta = extract_tensor_metadata(t) + if not hasattr(t, "_is_inductor_static"): + meta = dataclasses.replace(meta, storage_offset=0, storage_bytes=None) + + # The pickle implementation avoids serializing the same object more than once. + # That behavior means the byte stream we create to hash will vary if, for example, + # we see two tensor objects with the same device, but the torch.device object is + # actually the same object vs. merely equivalent. 
We want to produce the same hash + # value in either situation, so we memoize the device objects and always reference + # the same object for a given device. It's possible other metadata fields deserve + # the same treatment, but so far we've only observed this issue with the device. + if meta.device not in device_map: + device_map[meta.device] = meta.device + meta = dataclasses.replace(meta, device=device_map[meta.device]) + + return meta + + +def _reduce_fake_tensor( + device_map: Dict[torch.device, torch.device], t: Tensor +) -> Tuple[Callable[[T], T], Tuple[TensorMetadata]]: + """ + See FxGraphCachePickler. Custom reducer to pickle FakeTensors. + """ + metadata = extract_tensor_metadata_for_cache_key(device_map, t) + return (_ident, (metadata,)) + + +def _reduce_tensor( + device_map: Dict[torch.device, torch.device], t: Tensor +) -> Tuple[Callable[[T], T], Tuple[TensorMetadataAndValues]]: + """ + See FxGraphCachePickler. Custom reducer to pickle Tensors. + If we see tensors, we know they're constants stored as attributes on + the GraphModule. Include the values in the key calculation. Small + tensors will be inlined, so we can't serve the same cache entry for + different values anyway. Large constants are treated as parameters, + so we could conceivably reuse a cache entry. To do that, however, + PyCodeCache would need more complexity to create a new module from its + cache, but with the right constants attached as attributes. + """ + if t.is_mkldnn: + # TODO: These tensors don't currently pickle, so we can't cache a + # compiled graph containing them. Just fail now. If mkldnn tensors + # get pickling support, we can remove this. + raise BypassFxGraphCache("mkldnn tensors unpickleable.") + + # Very large tensors could be expensive to copy to cpu and hash. Let's + # at least report if we find slowness. + start = time() + values = t.tolist() + elapsed = time() - start + if elapsed > 1.0: + warnings.warn( + f"FX graph cache handling of a large constant took {elapsed:.1}s. Please file an issue." + ) + + metadata = extract_tensor_metadata_for_cache_key(device_map, t) + return (_ident, (TensorMetadataAndValues(metadata, values),)) + + +def _reduce_symint(s: SymInt) -> Tuple[Callable[[T], T], Tuple[str]]: + """ + See FxGraphCachePickler. Custom reducer to pickle SymInts. + """ + # For hashing purposes, we only care about the name of the symbol and + # not the backed value. We evaluate guards stored with a cached graph + # to ensure a cached entity with SymInt args is safe to reuse. + return (_ident, (str(s),)) + + +def _reduce_unsupported(s: Any) -> NoReturn: + """ + See FxGraphCachePickler. Custom reducer to handle any objects that we don't + support and therefore raise to bypass caching. + """ + raise BypassFxGraphCache("Reduce unsupported.") + + +class FxGraphCachePickler(pickle.Pickler): + """ + Custom pickler to customize the pickling of some objects (Tensors), only for the + purpose of computing a hash for keying into the FxGraphCache. Tensors contain + objects that don't pickle and/or vary between runs, and we want to capture the + data that allow us to compute a stable, but safe hash. + """ + + # See extract_tensor_metadata_for_cache_key. Whenever we extract metadata during + # pickling, we make sure devices always reference the same torch.device object. 
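+ # Note: _device_map and dispatch_table below are class-level attributes, so
+ # the memoized devices are shared across every hash computation in this
+ # process.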
+ _device_map: Dict[torch.device, torch.device] = {} + + dispatch_table = copyreg.dispatch_table.copy() + dispatch_table[FakeTensor] = functools.partial(_reduce_fake_tensor, _device_map) + dispatch_table[torch.Tensor] = functools.partial(_reduce_tensor, _device_map) + dispatch_table[torch.SymInt] = _reduce_symint + dispatch_table[ + torch.fx.experimental._backward_state.BackwardState + ] = _reduce_unsupported + + @classmethod + def dumps(cls, obj: Any) -> bytes: + """ + Pickle an object using the FxGraphCachePickler. + """ + with io.BytesIO() as stream: + pickler = cls(stream) + # TODO: pickler.fast is technically deprecated. Will this work on new python versions? + pickler.fast = True # Run with pickler.fast so it doesn't intern strings, making the hash result more predictable + try: + pickler.dump(obj) + except (TypeError, AttributeError) as e: + # Some configs options are callables, e.g., post_grad_custom_pre_pass, + # and may not pickle. + log.warning("Can't pickle", exc_info=True) + raise BypassFxGraphCache("Config options may be unpickleable.") from e + return stream.getvalue() + + @classmethod + def get_hash(cls, obj: Any) -> str: + """ + Serialize an object using the FxGraphCachePickler and return a hash + of the pickled object. + """ + serialized_data = cls.dumps(obj) + return sha256_hash(serialized_data) + + @classmethod + def debug_lines(cls, inp: FxGraphHashDetails) -> List[str]: + """ + Get a printable string describing in more detail all the attributes + comprising an object. Useful for debugging when one graph hashes + to a different value than another. + """ + + def get_str(obj: Any) -> str: + if isinstance(obj, torch.Tensor): + return str(extract_tensor_metadata_for_cache_key(cls._device_map, obj)) + elif isinstance(obj, bytes): + return "" + elif type(obj) in cls.dispatch_table: + # Run the reducer on the object + return str(cls.dispatch_table[type(obj)](obj)[1]) + else: + return str(obj) + + lines = [] + for attr, obj in vars(inp).items(): + if isinstance(obj, list): + for ii in range(len(obj)): + h = cls.get_hash(obj[ii]) + lines.append(f"[{h}] {attr}[{ii}]: {get_str(obj[ii])}") + elif isinstance(obj, dict): + for k, v in obj.items(): + h = cls.get_hash(v) + lines.append(f"[{h}] {attr}[{k}]: {get_str(v)}") + else: + h = cls.get_hash(obj) + lines.append(f"[{h}] {attr}: {get_str(obj)}") + return lines + + +def build_code_hash( + roots: List[str] | None, prefix: str, hasher: hashlib._Hash +) -> None: + for lib in sorted(pkgutil.iter_modules(roots, prefix), key=lambda x: x.name): + spec = lib.module_finder.find_spec(lib.name, None) + assert spec is not None + module = spec.origin + assert module is not None + with open(module, "rb") as f: + hasher.update(spec.name.encode("utf-8")) + hasher.update(f.read()) + if lib.ispkg: + # need to also hash submodules + build_code_hash(spec.submodule_search_locations, f"{spec.name}.", hasher) + + +@functools.lru_cache(None) +def torch_key() -> bytes: + """ + Compute a key that contains relevant information about torch source files + """ + if not config.is_fbcode(): + + def get_code_hash(root: str) -> bytes: + # This function isn't meant to be used outside of torch_key, just a + # helper for clarity. Instead, use torch_key() directly when you need + # a hash representing the state of the source code. 
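+ # The hash mixes torch.__version__, the modules discovered under `root`
+ # (via build_code_hash), and the extra codegen/runtime files listed here.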
+ extra_files = ( + "codegen/aoti_runtime/interface.cpp", + "codegen/aoti_runtime/implementation.cpp", + "codegen/cpp_prefix.h", + "script.ld", + ) + inductor_root = os.path.dirname(__file__) + extra_files = [os.path.join(inductor_root, x) for x in extra_files] + hasher = hashlib.sha256() + hasher.update(torch.__version__.encode("utf-8")) + build_code_hash([root], "", hasher) + for path in extra_files: + if os.path.exists(path): + with open(path, "rb") as f: + hasher.update(f.read()) + return hasher.digest() + + return get_code_hash(_TORCH_PATH) + + from libfb.py import parutil + + return parutil.get_file_contents("torch/src_hash.txt").rstrip().encode("ascii") + + +def get_inductor_root() -> str: + return os.path.dirname(__file__) + + +@dataclasses.dataclass +class OrderedSetHolder: + """ + See FxGraphHashDetails. Holds a sorted list to support stable hashing + of set kwargs. + """ + + items: List[Any] + + +class BypassFxGraphCache(Exception): + """ + Exception to indicate that the FxGraphCache should be bypassed. + """ + + +class FxGraphHashDetails: + """ + Object to capture all the details for a compiled FX graph relevant to computing + a safe and stable cache key. + """ + + # Excluded kwargs param that are not stable between runs + EXCLUDED_KWARGS = ["graph_id"] + + def __init__( + self, + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + fx_kwargs: Dict[str, Any], + inputs_to_check: Sequence[int], + ) -> None: + self.gm = gm + self.example_inputs = example_inputs + + # Order kwargs so hashing is stable to changes in kwarg order. + self.fx_kwargs = {} + for k in sorted(fx_kwargs): + if k not in self.EXCLUDED_KWARGS: + if type(fx_kwargs[k]) is set: + # Special case to handle set params. Python sets can't be + # ordered, so sort the elements and store them in a proxy. + self.fx_kwargs[k] = OrderedSetHolder(sorted(fx_kwargs[k])) + else: + self.fx_kwargs[k] = fx_kwargs[k] + + # Alignment checks + self.inputs_to_check = inputs_to_check + + # 'Deterministic algorithms' can affect codegen via lowering to cuda kernels. + self.deterministic_algorithms_settings = ( + torch.are_deterministic_algorithms_enabled(), + torch.is_deterministic_algorithms_warn_only_enabled(), + torch.utils.deterministic.fill_uninitialized_memory, # type: ignore[attr-defined] + ) + + # Global settings affecting matmul codegen. + self.cuda_matmul_settings = ( + torch.backends.cuda.matmul.allow_tf32, + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction, + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction, + ) + + # Also hash on various system info (including the triton compiler version). + self.torch_version = torch_key() + self.system_info = CacheBase.get_system() + self.inductor_config = config.save_config_portable() + + def debug_lines(self) -> List[str]: + """ + Get a printable string describing in more detail all the attributes + comprising this object. Useful for debugging when one graph hashes + to a different value than another. + """ + return FxGraphCachePickler.debug_lines(self) + + +def compiled_fx_graph_hash( + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + fx_kwargs: Dict[str, Any], + inputs_to_check: Sequence[int], +) -> Tuple[str, List[str]]: + """ + Generate a unique hash of the FX graph for caching. + """ + details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check) + # The prefix distinguishes among the other kinds of objects we + # cache in this module. 
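+ # (code_hash() keys start with "c"; FX graph cache keys start with "f".)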
+ key = "f" + FxGraphCachePickler.get_hash(details) + debug_lines = details.debug_lines() + debug_str = "\n".join(debug_lines) + log.debug(f"FX graph cache hash details for key {key}:\n{debug_str}") # noqa: G004 + return key, debug_lines + + +def cudagraph_post_compile( + example_inputs: List[Any], + compiled_graph: CompiledFxGraph, + cudagraphs: BoxedBool, +) -> None: + """ + Checks for any reasons not to run cudagraphs and then + runs it on compiled_graph. + Mutates the `compiled_graph.current_callable` and `cudagraphs` + """ + assert compiled_graph.current_callable is not None + assert compiled_graph.cudagraph_info is not None + cached_info = compiled_graph.cudagraph_info + cudagraph_fail_reasons = cached_info.cudagraph_fail_reasons + inputs_to_check = compiled_graph.inputs_to_check + boxed_forward_device_index = compiled_graph.boxed_forward_device_index + is_inference = compiled_graph.fx_kwargs["is_inference"] + is_backward = compiled_graph.fx_kwargs["is_backward"] + + if not cudagraph_fail_reasons: + fx_kwargs = compiled_graph.fx_kwargs + static_input_idxs = fx_kwargs["static_input_idxs"] + + placeholders = cached_info.placeholders + stack_traces = cached_info.stack_traces + if not config.triton.cudagraph_trees: + # Force specialize all inputs so that CUDA graphs will work + for t in example_inputs: + if isinstance(t, torch.SymInt): + int(t) # guard + + if ( + boxed_forward_device_index is not None + and not is_inference + and not is_backward + ): + boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs))) + + from .compile_fx import cudagraphify + + current_callable = compiled_graph.current_callable + assert current_callable is not None + compiled_graph.current_callable = cudagraphify( + current_callable, + static_input_idxs=static_input_idxs, + device_index=next(iter(compiled_graph.device_idxs)), + stack_traces=stack_traces, + is_backward=is_backward, + is_inference=is_inference, + constants=tuple(compiled_graph.constants.values()), + placeholders=placeholders, + mutated_input_idxs=tuple(compiled_graph.mutated_input_idxs), + ) + + else: + BoxedBool.disable(cudagraphs) + + # See [Backward Generation Handling] + # if cudagraph'd the forward and set the device, we need to let the cudagraph manager + # know we are we running the backward even if we will not run it in cudagraphs + if is_backward and config.triton.cudagraph_trees: + assert boxed_forward_device_index is not None + assert boxed_forward_device_index.value is not None + compiled_graph_callable = compiled_graph.current_callable + + manager = torch._inductor.cudagraph_trees.get_manager( + boxed_forward_device_index.value, create_if_none_exists=False + ) + # should already exist from forward + assert manager is not None + + def compiled_artifact(new_inputs: List[Any]) -> Callable[..., Any]: + manager.set_to_running_backward() # type: ignore[union-attr] + return compiled_graph_callable(new_inputs) + + compiled_graph.current_callable = compiled_artifact + + if "cuda" in compiled_graph.device_types: + # prefer better disable_cudagraphs_reason bc stack trace + # TODO: migrate all disable reasons to stack trace, refactor + if compiled_graph.disabled_cudagraphs_reason: + log_cudagraph_skip_and_bump_counter( + compiled_graph.disabled_cudagraphs_reason + ) + else: + log_cudagraph_skip_and_bump_counter( + f"skipping cudagraphs due to {cudagraph_fail_reasons}" + ) + + +def maybe_realign_inputs( + ran_cudagraphs: BoxedBool, + compiled_graph: CompiledFxGraph, + inputs_to_check: Sequence[int], +) -> None: + """ + Realigns input 
strides from inputs_to_check if + we didn't end up running cudagraphs. Mutates + `compiled_graph.current_callable` if cudagraphs + was run. Otherwise, does nothing. + """ + if not ran_cudagraphs: + assert compiled_graph.current_callable is not None + new_callable = align_inputs_from_check_idxs( + compiled_graph.current_callable, inputs_to_check + ) + if new_callable is not compiled_graph.current_callable: + compiled_graph.current_callable = new_callable + + +def add_ephemeral_timeout_increase_for_distributed(time_saved_ns: int) -> int: + """ + Ephemerally increases the NCCL timeout when compiling for a distributed job + Returns amount of seconds increased + """ + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + return 0 + + increased_timeout_sec = int(time_saved_ns // 1e9) # convert to seconds + + if config.is_fbcode(): + fudge_factor = torch._utils_internal.justknobs_getval_int( + "pytorch/remote_cache:ephemeral_timeout_fudge_factor_percentage" + ) + log.info( + "Ephemeral NCCL timeout increase fudge factor %d and original increase value %d", + fudge_factor, + increased_timeout_sec, + ) + increased_timeout_sec += int(increased_timeout_sec * fudge_factor / 100) + + log.info("Increasing NCCL timeout by %d", increased_timeout_sec) + dist.distributed_c10d._add_ephemeral_timeout_for_all_pgs( + timedelta(seconds=increased_timeout_sec) + ) + return increased_timeout_sec + + +class FxGraphCache: + """ + Supports caching and reusing compiled Fx graphs. + + The overall strategy is as follows: + - This cache stores entries on disk. When saving an entry, we can't + serialize callables (that could be C++, Triton, etc.), so we serialize + their own disk cache location. We then recreate the compiled artifact + after fetching from disk. + - For indexing the cache, we gather the fields relevant to identifying an + FxGraph (the graph module, graph inputs, system settings etc.) into an + FxGraphCacheDetails object, pickle it, and compute a hash for the key. + See FxGraphCachePickler. + - Among the metadata we store, we also include a guards expression that's + appropriate for validating any symbols for Tensor arguments that have + symbolic bounds. On cache lookup then, we evaluate those guards in the + current context to validate that a cached entry can be served. + - A given graph could have multiple compiled versions, corresponding to + different sets of guards. Therefore, we store cache entries in the form: + // + - On lookup, we compute the key from the graph details, iterate over all + leaf files in the corresponding subdirectory, deserialize the entry, and + evaluate its guards expression. If the evaluation succeeds, we have a + cache hit. If it fails, we compile the graph and store a new entry. + - Finally, on a cache hit, we need to make sure any guards that would + have been created during compilation are added to the current context. + """ + + # TODO(masnesral): Investigate whether it's beneficial to store compiled graphs + # in an in-memory cache after loading from disk. + @staticmethod + def _get_tmp_dir() -> str: + """ + Get the toplevel temporary directory for storing compiled graphs. + """ + return os.path.join(cache_dir(), "fxgraph") + + @staticmethod + def _get_tmp_dir_for_key(key: str) -> str: + """ + Return the disk location for a given cache key. 
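+
+ The layout is <cache_dir()>/fxgraph/<key[1:3]>/<key>.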
+ """ + return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key) + + @staticmethod + def _filter_backed_symints(inputs: List[Any]) -> List[torch.SymInt]: + """ + Get the backed SymInt objects from the input list. Note that we can never + have guards that depend on unbacked symint. + """ + return [s for s in inputs if isinstance(s, torch.SymInt) and has_hint(s)] + + @staticmethod + def _get_shape_env() -> Optional[ShapeEnv]: + """ + Helper to get the shape env from the tracing context. + """ + ctx = torch._guards.TracingContext.try_get() + if not ctx: + return None + return ctx.fake_mode.shape_env + + @staticmethod + def _lookup_graph( + key: str, + example_inputs: List[torch.Tensor], + local: bool, + remote_cache: Optional[RemoteCache[JsonDataTy]], + ) -> Optional[CompiledFxGraph]: + """ + Lookup a compiled graph in the cache by key. On a hit, return the + deserialized CompiledFxGraph object. On a miss, return None. + """ + shape_env = FxGraphCache._get_shape_env() + assert shape_env is not None + + symints = FxGraphCache._filter_backed_symints(example_inputs) + hints = [hint_int(s) for s in symints] + + def iterate_over_candidates() -> Generator[CompiledFxGraph, None, None]: + if local: + subdir = FxGraphCache._get_tmp_dir_for_key(key) + if os.path.exists(subdir): + for path in sorted(os.listdir(subdir)): + try: + with open(os.path.join(subdir, path), "rb") as f: + yield pickle.load(f) + except Exception: + log.warning( + "fx graph cache unable to load compiled graph", + exc_info=True, + ) + + if remote_cache: + try: + if (cache_data := remote_cache.get(key)) is not None: + assert isinstance(cache_data, dict) + data = cache_data["data"] + assert isinstance(data, (str, bytes)) + content = base64.b64decode(data) + yield pickle.loads(content) + except Exception: + log.warning( + "fx graph cache unable to load compiled graph", exc_info=True + ) + + # Iterate over any entries in the subdir for this key and evaluate + # their guards to determine whether there's a hit. + graph = None + + for candidate in iterate_over_candidates(): + if not candidate.guards_expr: + # No guards to evaluate, so this is a hit. + graph = candidate + break + + # Evaluate the guard expression in the current context. + # If there's not a cache hit, we don't want the evaluation to + # affect the current env, e.g., cause the creation of new guards, + # so we evaluate with the hints instead of the symbols. + hit = bool( + shape_env.evaluate_guards_expression(candidate.guards_expr, hints) + ) + log.debug( + "fx graph cache key %s evaluating guards [%s] with values %s => hit=%s", + key, + candidate.guards_expr, + hints, + hit, + ) + if hit: + graph = candidate + break + + if graph is None: + return None + + # See _save_graph(); we don't store the callable in the cache entry so + # recreate it here from the PyCodeCache disk cache. 
+ artifact_path = get_path(graph.cache_key, "py")[2] + code = graph.source_code + if not os.path.exists(artifact_path): + counters["inductor"]["fxgraph_lookup_write_file"] += 1 + Path(os.path.dirname(artifact_path)).mkdir(parents=True, exist_ok=True) + cpp_pp = cpp_prefix_path() + if os.path.basename(cpp_pp) in code: + if cpp_pp in code: + # Great the name is correct + pass + else: + # Old dir name is included, replace it + pattern = rf'#include\s*"[^"]+{os.path.basename(cpp_pp)}"' + code = re.sub(pattern, f'#include "{cpp_pp}"', code) + + write_atomic(artifact_path, code, make_dirs=True) + + try: + graph.current_callable = PyCodeCache.load_by_key_path( + graph.cache_key, + artifact_path, + graph.cache_linemap, + graph.constants, + ).call + except OSError: + # Not expected, but in case the PyCodeCache entry is removed from + # underneath us, treat it as a cache miss and recompile. + log.error("Failed to load cached artifact: %s", artifact_path) + return None + + # Now re-evaluate with the symints to add any guards to the current env. + if graph.guards_expr: + check = bool( + shape_env.evaluate_guards_expression(graph.guards_expr, symints) + ) + assert check is True + log.debug( + "fx graph cache key %s post-load guards: %s", key, shape_env.guards + ) + + # Increment the cached metrics/counters by the amounts recorded when the FX + # graph was compiled for this cache entry. Pretending these counters + # were incremented normally is useful for testing with the cache enabled. + metrics.CachedMetricsHelper.apply_deltas(graph.metrics_deltas) + counters["inductor"] += graph.counter_deltas + + from .graph import GraphLowering + + GraphLowering.save_output_code(code) + output_code_log.debug("Output code written to: %s", artifact_path) + output_code_log.debug("Output code: \n%s", code) + # On cache hit, use artifact path as filename + trace_structured( + "inductor_output_code", + lambda: {"filename": artifact_path}, + payload_fn=lambda: code, + ) + return graph + + @staticmethod + def post_compile( + compiled_graph: CompiledFxGraph, + example_inputs: List[torch.Tensor], + cudagraphs: BoxedBool, + ) -> CompiledFxGraph: + """ + Run a set of post processing steps after loading from the cache. These involve: + - Setting the tracing context output strides + - Running cudagraphs if enabled + - Realigning inputs + + This runs whether or not we have a cache hit, and always runs directly after we get a CompiledFxGraph. + The results of this function are *not* saved in the cache itself. + """ + set_tracing_context_output_strides(example_inputs, compiled_graph) + + if cudagraphs: + # It's possible that cudagraphs is enabled, but was disabled + # during a previous compilation we're loading from the cache. + # If so, we need to disable it on this new process too. 
+ if compiled_graph.disabled_cudagraphs_reason: + if "cuda" in compiled_graph.device_types: + log_cudagraph_skip_and_bump_counter( + f"skipping cudagraphs due to {compiled_graph.disabled_cudagraphs_reason}" + ) + else: + counters["inductor"]["cudagraph_skips"] += 1 + BoxedBool.disable(cudagraphs) + else: + cudagraph_post_compile( + example_inputs, + compiled_graph, + cudagraphs, + ) + inputs_to_check = compiled_graph.inputs_to_check + # cudagraphs could have been disabled from the earlier conditions + # so we still need to realign inputs if that happens + maybe_realign_inputs( + cudagraphs, + compiled_graph, + inputs_to_check, + ) + + return compiled_graph + + @staticmethod + def _save_graph( + key: str, + compiled_graph: CompiledFxGraph, + example_inputs: List[torch.Tensor], + local: bool, + remote_cache: Optional[RemoteCache[JsonDataTy]], + ) -> None: + """ + Store a serialized CompiledFxGraph on disk. + """ + disk_compiled_graph = copy(compiled_graph) + # We can't really serialize callables that may be C++/Triton/etc., + # so we serialize their PyCodeCache disk cache location instead. + # TODO: This could be better if we're ever able to serialize compiled + # models to disk. + disk_compiled_graph.current_callable = None + + # Before serializing, compute the guard expression that will be used to + # ensure that a CompiledFxGraph is valid when loaded from the cache. It's + # sufficient to consider only the SymInt args to the fx graph since the + # Tensor shapes are already captured in the hash for the cache key. Any + # Tensor arg with a symbolic shape will have a SymInt arg for the graph. + shape_env = FxGraphCache._get_shape_env() + assert shape_env is not None + symints = FxGraphCache._filter_backed_symints(example_inputs) + guards = shape_env.get_pruned_guards(symints) + disk_compiled_graph.guards_expr = shape_env.produce_guards_expression( + placeholders=symints, guards=guards + ) + + try: + content = pickle.dumps(disk_compiled_graph) + except Exception: + log.warning( + "fx graph cache unable to serialize compiled graph", exc_info=True + ) + counters["inductor"]["fxgraph_cache_pickle_error"] += 1 + return + + try: + if local: + subdir = FxGraphCache._get_tmp_dir_for_key(key) + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + + # Use a hash of the serialized CompiledFxGraph to get a unique file + # name. The specific name doesn't matter since a lookup involves + # iterating over all entries in the parent subdir. + path = os.path.join(subdir, sha256_hash(content)) + write_atomic(path, content, make_dirs=True) + + if remote_cache: + time_taken_ms = int((disk_compiled_graph._time_taken_ns or 0) // 1e6) + cache_data: JsonDataTy = { + "data": base64.b64encode(content).decode("ascii"), + "time_taken_ms": time_taken_ms, + } + remote_cache.put(key, cache_data) + except Exception: + log.warning("fx graph unable to write to cache", exc_info=True) + counters["inductor"]["fxgraph_cache_write_error"] += 1 + + @staticmethod + def _check_can_cache(gm: torch.fx.GraphModule) -> None: + """ + Check some conditions that would preclude caching and raise BypassFxGraphCache + to bypass in case caching is not possible. + """ + # Freezing can embed constants that wouldn't be static across runs. + if config.freezing or config.aot_inductor.use_runtime_constant_folding: + raise BypassFxGraphCache( + "Freezing may introduce constants that aren't static across runs." + ) + + # The treatment of guards in the caching implementation requires that + # we have a shape env. 
+ if FxGraphCache._get_shape_env() is None: + log.debug("fx graph cache no shape env") + raise BypassFxGraphCache("No shape env.") + + # HigherOrderOperators should be handled on a case-by-case basis. + # Currently, we just skip caching if we have any. + # We also skip if there are any torchbind objects. + for node in gm.graph.nodes: + if isinstance(node.target, torch._ops.HigherOrderOperator): + raise BypassFxGraphCache("Can't cache HigherOrderOperators.") + if node.op == "getattr" and isinstance( + getattr(gm, node.target), torch._C.ScriptObject + ): + raise BypassFxGraphCache("Can't cache torchbind objects.") + + @staticmethod + def load( # type: ignore[no-untyped-def] + compile_fx_fn: Callable[..., Any], + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + fx_kwargs: Dict[str, Any], + inputs_to_check: Sequence[int], + local: bool, + remote: bool, + ): + """ + Load a compiled graph from the cache. If a cached entry does not exist, + compile the graph and save it to the cache. + """ + assert local or remote, "at least one of them needs to be enabled" + compiled_graph = None + cache_state = None + cache_event_time = None + cache_info: Dict[str, Any] = {} + try: + FxGraphCache._check_can_cache(gm) + key, debug_lines = compiled_fx_graph_hash( + gm, example_inputs, fx_kwargs, inputs_to_check + ) + cache_info["key"] = key + cache_info["components"] = debug_lines + + remote_cache: Optional[RemoteCache[JsonDataTy]] = None + if remote: + cache_id = "fx-graph-v1" + try: + if config.is_fbcode(): + from torch._inductor.fb.remote_cache import FbRemoteFxGraphCache + + remote_cache = FbRemoteFxGraphCache(cache_id) + else: + from torch._inductor.remote_cache import RemoteFxGraphCache + + remote_cache = RemoteFxGraphCache(cache_id) + except ModuleNotFoundError as e: + # No need for a stack trace on this error + remote_cache = None + log.warning("Unable to create a remote cache: %s", e) + except Exception: + remote_cache = None + log.warning("Unable to create a remote cache", exc_info=True) + + compiled_graph = FxGraphCache._lookup_graph( + key, example_inputs, local, remote_cache + ) + + if compiled_graph is None: + log.debug("fx graph cache miss for key %s", key) + counters["inductor"]["fxgraph_cache_miss"] += 1 + cache_state = "miss" + start_time = time_ns() + cache_event_time = start_time + compiled_graph = compile_fx_fn( + gm, example_inputs, inputs_to_check, fx_kwargs + ) + compiled_graph._time_taken_ns = time_ns() - start_time + cache_info["time_taken_ns"] = compiled_graph._time_taken_ns + FxGraphCache._save_graph( + key, + compiled_graph, + example_inputs, + local, + remote_cache, + ) + else: + log.debug("fx graph cache hit for key %s", key) + counters["inductor"]["fxgraph_cache_hit"] += 1 + cache_state = "hit" + cache_event_time = time_ns() + if (time_saved_ns := compiled_graph._time_taken_ns) is not None: + cache_info["time_saved_ns"] = time_saved_ns + if ( + ephemeral_increase := add_ephemeral_timeout_increase_for_distributed( + time_saved_ns + ) + ) != 0: + cache_info["ephemeral_timeout_increase"] = ephemeral_increase + compiled_graph._fx_graph_cache_key = key + except BypassFxGraphCache as e: + counters["inductor"]["fxgraph_cache_bypass"] += 1 + cache_state = "bypass" + log.info("Bypassing FX Graph Cache because '%s'", e) + cache_info["cache_bypass_reason"] = str(e) + if remote: + log_cache_bypass("bypass_fx_graph", str(e)) + cache_event_time = time_ns() + + if not compiled_graph: + compiled_graph = compile_fx_fn( + gm, example_inputs, inputs_to_check, fx_kwargs + ) + assert 
compiled_graph is not None + cache_info["cache_state"] = cache_state + chromium_log = get_chromium_event_logger() + chromium_log.log_instant_event( + f"fx_graph_cache_{cache_state}", cache_event_time, metadata=cache_info + ) + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_graph_cache_hash", + "encoding": "json", + }, + payload_fn=lambda: json.dumps(cache_info), + ) + # Use the passed in cudagraphs so that we mutate the BoxedBool correctly + FxGraphCache.post_compile( + compiled_graph, example_inputs, fx_kwargs["cudagraphs"] + ) + return compiled_graph + + @staticmethod + def clear() -> None: + """ + Clear out the on-disk cache. + """ + try: + shutil.rmtree(FxGraphCache._get_tmp_dir()) + except FileNotFoundError: + pass + + +_StrideExprStr: TypeAlias = str + + +@dataclasses.dataclass +class CompiledFxGraph: + """ + Class holding a compiled FX graph. This is the object serialized on disk + to support FxGraph caching. + """ + + current_callable: Optional[Callable[..., Any]] + cache_key: str + source_code: str = dataclasses.field(repr=False) # Do not display source_code + cache_linemap: Optional[List[Tuple[int, str]]] + device_types: Set[str] + device_idxs: Set[int] + mutated_inputs: Set[str] + mutated_input_idxs: Set[int] + constants: Dict[str, torch.Tensor] + torchbind_constants: Dict[str, torch._C.ScriptObject] + output_strides: Optional[List[Optional[Tuple[_StrideExprStr, ...]]]] + disabled_cudagraphs_reason: Optional[str] + metrics_deltas: metrics.CachedMetricsDeltas + counter_deltas: Counter[str] + # This is a string representation of an expression we serialize + # with the object so the guards can be evaluated in a different + # context in order to verify the validity of serving a cached + # fx graph. The expression must be generated by: + # ShapeEnv.produce_guards_expression() + guards_expr: Optional[str] + + cudagraph_info: Optional[CudagraphCachedInfo] + fx_kwargs: Dict[str, Any] + inputs_to_check: Sequence[int] + boxed_forward_device_index: Optional[BoxedDeviceIndex] + + _time_taken_ns: Optional[int] = None + _boxed_call: Optional[bool] = None + _fx_graph_cache_key: Optional[str] = None + + def __init__( + self, + current_callable: Optional[Callable[..., Any]], + graph: GraphLowering, + output_strides: List[Optional[Tuple[_StrideExprStr, ...]]], + disabled_cudagraphs_reason: Optional[str], + metrics_deltas: metrics.CachedMetricsDeltas, + counter_deltas: Counter[str], + ) -> None: + self.current_callable = current_callable + self.cache_key = graph.cache_key + if graph.cache_path: + with open(graph.cache_path) as f: + self.source_code = f.read() + self.cache_linemap = graph.cache_linemap + # TODO - ordered set + self.device_types = set(graph.device_types) + self.device_idxs = set(graph.device_idxs) + self.mutated_inputs = set(graph.mutated_inputs) + self.mutated_input_idxs = set(graph.mutated_input_idxs) + self.constants = graph.constants + self.torchbind_constants = graph.torchbind_constants + self.output_strides = output_strides + self.disabled_cudagraphs_reason = disabled_cudagraphs_reason + self.metrics_deltas = metrics_deltas + self.counter_deltas = counter_deltas + self.guards_expr = None + self.cudagraph_info = None + self.fx_kwargs = {} + self.inputs_to_check = () + self.boxed_forward_device_index = None + + def __call__(self, inputs: List[Any]) -> Any: + assert self.current_callable is not None + return self.current_callable(inputs) + + +def run_command_and_check(cmd_: str) -> None: + cmd = shlex.split(cmd_) + try: + 
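+ # check_call raises CalledProcessError on a non-zero exit status, which is
+ # converted into a CppCompileError below.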
subprocess.check_call(cmd) + except subprocess.CalledProcessError as e: + raise exc.CppCompileError(cmd, e.output) from e + + +@functools.lru_cache(None) +def split_aot_inductor_output_path(path: str) -> Tuple[str, str]: + """Returns the path where the AOT Inductor compiled kernels are stored.""" + if path.endswith(".so"): + return os.path.split(path) + elif path.endswith(".pt2"): + return os.path.split(path) + else: + return path, "" + + +@clear_on_fresh_inductor_cache +class CudaKernelParamCache: + cache: Dict[str, Dict[str, str]] = {} + cache_clear = staticmethod(cache.clear) + + @classmethod + def set(cls, key: str, params: Dict[str, str], cubin: str, bin_type: str) -> None: + _, path = write( + cubin, + bin_type, + hash_type=bin_type, + specified_dir=split_aot_inductor_output_path( + config.aot_inductor.output_path + )[0], + ) + + params[get_cpp_wrapper_cubin_path_name()] = path + + cls.cache[key] = params + + @classmethod + def get(cls, key: str) -> Optional[Dict[str, str]]: + return cls.cache.get(key, None) + + @classmethod + def get_keys(cls) -> KeysView[str]: + return cls.cache.keys() + + +class AotCodeCompiler: + @classmethod + def compile( + cls, + graph: GraphLowering, + source_code: str, + serialized_extern_kernel_nodes: Optional[str], + cuda: bool, + ) -> str: + if sys.platform == "win32": + raise RuntimeError("AotCodeCompiler not yet supported for inductor") + + _set_gpu_runtime_env() # cpp_extension consults the env + + picked_vec_isa = pick_vec_isa() + vec_isa_cmd_gen = CppBuilder( + name="o", + sources="i", + BuildOption=CppTorchCudaOptions( + vec_isa=picked_vec_isa, + cuda=cuda, + aot_mode=graph.aot_mode, + ), + ) + # write function will calc source_code hash, the same source code with different + # ISA level should be generate different hash. + # So we need get a command_line which contains isa related parameter as a part of hash key. + # And then pass the command_line to below write function as extra parameter to + # guarantee the source code hash contains ISA difference. + cpp_command = repr(vec_isa_cmd_gen.get_command_line()) + + fbcode_aot_cpu_re = False + use_absolute_path = False + if config.is_fbcode(): + ld_command = build_paths.ld() + if not cuda and graph.aot_mode: # Meta internal AOTInductor CPU + objcopy_command = build_paths.objcopy_fallback() + fbcode_aot_cpu_re = True + use_absolute_path = True + else: + objcopy_command = build_paths.objcopy() + else: + ld_command = "ld" + objcopy_command = "objcopy" + + ( + specified_output_path, + specified_so_name, + ) = split_aot_inductor_output_path(config.aot_inductor.output_path) + key, input_path = write( + source_code, + "cpp", + extra=cpp_command, + specified_dir=specified_output_path, + ) + output_code_log.info("Output code written to: %s", input_path) + trace_structured( + "graph_dump", + lambda: { + "name": "inductor_aot_code", + "type": "cpp", + "filename": input_path, + }, + payload_fn=lambda: source_code, + ) + + # We use a file lock below to protect FS operations. 
The lock file + # is scoped to the 'key', so make sure the consts_path is protected + # by the same lock: + consts_specified_dir = os.path.join(os.path.split(input_path)[0], key) + + def _compile_consts_linux(consts: bytes) -> str: + _, consts_path = write( + consts, + "bin", + specified_dir=consts_specified_dir, + ) + + consts_o = os.path.splitext(consts_path)[0] + ".o" + if fbcode_aot_cpu_re: + cmd = f"{ld_command} -r -b binary -o {os.path.basename(consts_o)} {os.path.basename(consts_path)}" + compile_file(consts_path, consts_o, cmd.split()) + os.chmod(consts_o, 0o644) + else: + cmd = f"{ld_command} -r -b binary -o {consts_o} {consts_path}" + run_command_and_check(cmd) + log.debug("aot constant binary command: %s", cmd) + + if graph.mutated_buffers & set(graph.constants.keys()): + # .data section is between .text and .bss. When the size of .data is large, + # during the linking, the relocation of .text against .bss may overflow. + # Rename it to .ldata so that it won't be in between the .text and .bss section + if len(consts) > 2_000_000_000: + raise ValueError( + "Models with buffer mutation included doesn't support constants greater than 2GB!" + ) + rename_data = " .data=.ldata" + else: + # if no buffer mutation is needed, we could instead set the data region + # as read-only (i.e. .lrodata) which could accomodate larger size of data + # to be linked. + rename_data = " .data=.lrodata,alloc,load,readonly,data,contents" + + assert ( + ALIGN_BYTES & (ALIGN_BYTES - 1) + ) == 0 and ALIGN_BYTES >= 64, "must be power of 2 and >= 64" + cmd = ( + f"{objcopy_command} --rename-section" + f"{rename_data}" + f" --set-section-alignment .data={ALIGN_BYTES}" # following the gAlignment of CPU in c10/core/alignment.h + f" {consts_o} {consts_o}" + ) + log.debug("aot constant rename section command: %s", cmd) + run_command_and_check(cmd) + + cmd = f"rm {consts_path}" + log.debug("aot constant bin removal command: %s", cmd) + run_command_and_check(cmd) + + if fbcode_aot_cpu_re: + body = re.sub(r"[\W]", "_", os.path.basename(consts_path)) + else: + body = re.sub(r"[\W]", "_", consts_path) + + symbol_list = [] + symbol_list.append( + f"{objcopy_command} --redefine-sym _binary_{body}_start=_binary_constants_bin_start {consts_o}" + ) + symbol_list.append( + f"{objcopy_command} --redefine-sym _binary_{body}_size=_binary_constants_bin_size {consts_o}" + ) + symbol_list.append( + f"{objcopy_command} --redefine-sym _binary_{body}_end=_binary_constants_bin_end {consts_o}" + ) + log.debug("aot constant binary redefine symbol: %s", " ".join(symbol_list)) + for cmd in symbol_list: + run_command_and_check(cmd) + return consts_o + + def _compile_consts_darwin(consts: bytes) -> str: + if config.aot_inductor.debug_dump_consts_bin: + _, _binary_constants_path = write( + consts, + "bin", + specified_dir=consts_specified_dir, + ) + log.debug("binary constants path: %s", _binary_constants_path) + + is_large_consts = len(consts) > 1024 + consts_asm = "\t.section\t__DATA,__data\n" + consts_asm += "\t.globl\t__binary_constants_bin_start\n" + consts_asm += "__binary_constants_bin_start:\n" + if not is_large_consts: + for c in consts: + consts_asm += f"\t.byte {c}\n" + # Add one element even if constants are empty + # Otherwise assembler will not put them in data section + if not consts: + consts_asm += "\t.space 1\n" + else: + consts_asm += "\t.quad 0x1234567899abcdef\n" + consts_asm += f"\t.space {len(consts) - 8}\n" + consts_asm += ".globl\t__binary_constants_bin_end\n" + consts_asm += "__binary_constants_bin_end:\n" + _, 
consts_path = write( + consts_asm, + "S", + specified_dir=consts_specified_dir, + ) + consts_o = os.path.splitext(consts_path)[0] + ".o" + cmd = f"{get_cpp_compiler()} -c -o {consts_o} {consts_path}" + run_command_and_check(cmd) + if is_large_consts: + with open(consts_o, "r+b") as f: + f.seek(0) + hdr = f.read(1024) + # Search for magic number and write the actual data over it + start_idx = hdr.find(b"\xef\xcd\xab\x99\x78\x56\x34\x12") + assert start_idx != -1 + f.seek(start_idx) + pos = 0 + while pos < len(consts): + rc = f.write(consts[pos:]) + pos += rc + return consts_o + + from filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + with lock: + # Currently, this only support serializing extern nodes in fbcode + # Eventually, we should also have a serializer for OSS. + if serialized_extern_kernel_nodes: + output_json = os.path.splitext(input_path)[0] + ".json" + with open(output_json, "w") as f: + f.write(serialized_extern_kernel_nodes) + + output_so = ( + config.aot_inductor.output_path + if specified_so_name + else os.path.splitext(input_path)[0] + ".so" + ) + + output_o = os.path.splitext(input_path)[0] + ".o" + + all_cuda = all( + graph.get_original_value_of_constant(name).is_cuda + for name in graph.constants.keys() + if name not in graph.folded_constants + ) + + def get_nbytes_of_tensor(tensor: torch.Tensor, all_cuda: bool) -> int: + n_bytes = ( + torch.ops.mkldnn._nbytes(tensor) + if tensor.is_mkldnn + else tensor.untyped_storage().nbytes() + ) + return n_bytes if all_cuda else _align(n_bytes) + + consts_size = sum( + get_nbytes_of_tensor(tensor, all_cuda) + for (name, tensor) in graph.constants.items() + if name not in graph.folded_constants + ) + # TODO: Fix mmap weights with cuda + use_mmap_weights = not config.is_fbcode() and consts_size > 2_000_000_000 + if config.aot_inductor.force_mmap_weights: + use_mmap_weights = True + + if config.aot_inductor.package: + ( + object_output_name, + object_output_dir, + ) = get_name_and_dir_from_output_file_path(input_path) + object_build_options = CppTorchCudaOptions( + vec_isa=picked_vec_isa, + cuda=cuda, + aot_mode=graph.aot_mode, + compile_only=True, + use_absolute_path=use_absolute_path, + use_mmap_weights=use_mmap_weights, + ) + object_builder = CppBuilder( + name=object_output_name, + sources=input_path, + output_dir=object_output_dir, + BuildOption=object_build_options, + ) + compile_cmd = object_builder.get_command_line() + output_o = object_builder.get_target_file_path() + + compile_flags = os.path.splitext(input_path)[0] + "_compile_flags.json" + object_build_options.save_flags_to_file(compile_flags) + + else: + ( + object_output_name, + object_output_dir, + ) = get_name_and_dir_from_output_file_path(input_path) + object_build_options = CppTorchCudaOptions( + vec_isa=picked_vec_isa, + cuda=cuda, + aot_mode=graph.aot_mode, + compile_only=True, + use_absolute_path=use_absolute_path, + use_mmap_weights=use_mmap_weights, + ) + object_builder = CppBuilder( + name=object_output_name, + sources=input_path, + output_dir=object_output_dir, + BuildOption=object_build_options, + ) + compile_cmd = object_builder.get_command_line() + output_o = object_builder.get_target_file_path() + + log.debug("aot compilation command: %s", compile_cmd) + if fbcode_aot_cpu_re: + output_o = os.path.splitext(input_path)[0] + ".o" + compile_file(input_path, output_o, compile_cmd.split()) + os.chmod(output_o, 0o644) + else: + run_command_and_check(compile_cmd) + + def 
_to_bytes(t: torch.Tensor, all_cuda: bool) -> bytes: + def _pad_to_alignment(raw_bytes: bytes) -> bytes: + padded_bytes = raw_bytes.ljust( + (len(raw_bytes) + ALIGN_BYTES - 1) // ALIGN_BYTES * ALIGN_BYTES, + b"\x00", + ) + return padded_bytes + + # This serializes the tensor's untyped_storage to bytes by accessing + # the raw data of the underlying structure. + import ctypes + + if t.numel() == 0: + return b"" + + if t.is_mkldnn: + data_ptr = torch.ops.mkldnn.data_ptr(t) + nbytes = torch.ops.mkldnn._nbytes(t) + else: + t_cpu = t.untyped_storage().cpu() + data_ptr = t_cpu.data_ptr() + nbytes = t_cpu.nbytes() + + raw_array = ctypes.cast( + data_ptr, + ctypes.POINTER(ctypes.c_ubyte * nbytes), + ) + raw_bytes = bytes(raw_array.contents) + return raw_bytes if all_cuda else _pad_to_alignment(raw_bytes) + + serialized_weights = b"".join( + _to_bytes(graph.get_original_value_of_constant(name), all_cuda) + for name in graph.constants.keys() + if name not in graph.folded_constants + ) + if not use_mmap_weights: + aot_constants = serialized_weights + magic_number = 0 + else: + magic_number = cast( + int, torch.randint(0, torch.iinfo(torch.int64).max, (1,)).item() + ) + aot_constants = struct.pack("qq", consts_size + 8, magic_number) + + consts_o = { + "linux": _compile_consts_linux, + "darwin": _compile_consts_darwin, + }[sys.platform](aot_constants) + + if config.aot_inductor.package: + output_name, output_dir = get_name_and_dir_from_output_file_path( + output_so + ) + so_build_options = CppTorchCudaOptions( + vec_isa=picked_vec_isa, + cuda=cuda, + aot_mode=graph.aot_mode, + use_absolute_path=use_absolute_path, + ) + so_builder = CppBuilder( + name=output_name, + sources=[output_o, consts_o], + output_dir=output_dir, + BuildOption=so_build_options, + ) + link_cmd = so_builder.get_command_line() + output_so = so_builder.get_target_file_path() + + linker_flags = os.path.splitext(input_path)[0] + "_linker_flags.json" + so_build_options.save_flags_to_file(linker_flags) + + from torch._inductor.package import package_aoti + + if use_mmap_weights: + weight_file = ( + os.path.splitext(input_path)[0] + "_serialized_weights.bin" + ) + with open(weight_file, "wb") as f_weights: + f_weights.write(serialized_weights) + f_weights.write(struct.pack("q", magic_number)) + + archive_path = package_aoti(os.path.split(input_path)[0]) + return archive_path + else: + output_name, output_dir = get_name_and_dir_from_output_file_path( + output_so + ) + so_build_options = CppTorchCudaOptions( + vec_isa=picked_vec_isa, + cuda=cuda, + aot_mode=graph.aot_mode, + use_absolute_path=use_absolute_path, + ) + so_builder = CppBuilder( + name=output_name, + sources=[output_o, consts_o], + output_dir=output_dir, + BuildOption=so_build_options, + ) + link_cmd = so_builder.get_command_line() + output_so = so_builder.get_target_file_path() + + log.debug("aot linkage command: %s", link_cmd) + if fbcode_aot_cpu_re: + output_so = ( + config.aot_inductor.output_path + if specified_so_name + else os.path.splitext(input_path)[0] + ".so" + ) + compile_file([output_o, consts_o], output_so, link_cmd.split()) + os.chmod(output_so, 0o755) + else: + run_command_and_check(link_cmd) + + if use_mmap_weights: + import resource + + page_size_ = resource.getpagesize() + page_size = max(16384, page_size_) + + with open(output_so, "a+b") as f_so: + so_size = f_so.tell() + # Page align the weights + f_so.write(b" " * (page_size - so_size % page_size)) + f_so.write(serialized_weights) + f_so.write(struct.pack("q", magic_number)) + + # Append cmds to the end 
of codegen-ed wrapper file + with open(input_path, "a") as f: + f.write("\n") + f.write(f"// Compile cmd\n// {compile_cmd}\n") + f.write(f"// Link cmd\n// {link_cmd}\n") + + return output_so + + +# Putting this fn in cpp.py (unfortunately) causes a deadlock, which is why it's in codecache.py. +# Why? importing from cpp.py invokes codecache.pick_vec_isa(), which takes out a lock. +# Cycle goes: +# - CppCodeCache.load() +# - pick_vec_isa() +# - valid_vec_isa_list() +# - VecISA.__bool__() <-- takes out a lock +# - compile_file() <-- imports cpp_prefix_path from cpp, which causes us to try to take out the same lock. +@clear_on_fresh_inductor_cache +@functools.lru_cache +def cpp_prefix_path() -> str: + path = Path(__file__).parent / "codegen/cpp_prefix.h" + with path.open() as f: + content = f.read() + _, filename = write( + content, + "h", + ) + return normalize_path_separator(filename) + + +def cpp_prefix() -> str: + filename = cpp_prefix_path() + if config.is_fbcode(): + # We need relative paths, since we bundle up + # everything that we compile into a folder for remote compilation. + return f'#include "{os.path.basename(filename)}"' + else: + return f'#include "{filename}"' + + +# Given a path to an input cpp file and an output path, +# Attempts to compile the file, storing the output in "output_path" +def compile_file( + input_path: Union[str, List[str]], output_path: str, cmd: List[str] +) -> None: + with dynamo_timed("compile_file"): + return _compile_file(input_path, output_path, cmd) + + +def _compile_file( + input_path: Union[str, List[str]], output_path: str, cmd: List[str] +) -> None: + input_paths = [input_path] if isinstance(input_path, str) else input_path + input_files = [ + os.path.basename(ip) if config.is_fbcode() else ip for ip in input_paths + ] + try: + if config.is_fbcode(): + # Need to copy our header into the same folder as the sourcecode. + header_path = cpp_prefix_path() + header_name = os.path.basename(header_path) + output_name = os.path.basename(output_path) + # When we build remotely, we need to make sure to carefully copy any files + # that are required during the compilation process into our build directly. + # This is where all of the ATen/c10/Torch includes come from. + torch_includes_path = os.path.join(_TORCH_PATH, "include") + with tempfile.TemporaryDirectory() as tmp_dir: + # Copy everything to tmp compilation folder + shutil.copy(header_path, os.path.join(tmp_dir, header_name)) + shutil.copy(_LINKER_SCRIPT, os.path.join(tmp_dir, "script.ld")) + for p, f in zip(input_paths, input_files): + shutil.copy(p, os.path.join(tmp_dir, f)) + dest_include_path = os.path.join(tmp_dir, "include") + shutil.copytree(torch_includes_path, dest_include_path) + # Run the build + output_file_path = _run_build_command(cmd, tmp_dir, output_name) + # Copy output from the build + if os.path.exists(output_path): + os.remove(output_path) + shutil.copy(output_file_path, output_path) + else: + subprocess.check_output(cmd, stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + output = e.output.decode("utf-8") + openmp_problem = "'omp.h' file not found" in output or "libomp" in output + if openmp_problem and sys.platform == "darwin": + instruction = ( + "\n\nOpenMP support not found. 
Please try one of the following solutions:\n" + "(1) Set the `CXX` environment variable to a compiler other than Apple clang++/g++ " + "that has builtin OpenMP support;\n" + "(2) install OpenMP via conda: `conda install llvm-openmp`;\n" + "(3) install libomp via brew: `brew install libomp`;\n" + "(4) manually setup OpenMP and set the `OMP_PREFIX` environment variable to point to a path" + " with `include/omp.h` under it." + ) + output += instruction + raise exc.CppCompileError(cmd, output) from e + + +_libgomp: Optional[CDLL] = None + + +def custom_op_wrapper(op: str, *args: Any) -> Union[list[c_void_p], c_void_p]: + # This function will be called from generated cpp wrapper code in the JIT mode. + # Because tensors will be passed in as AtenTensorHandle, we need to explicitly convert them. + def convert_arg(arg: Any) -> Any: + if str(type(arg)) == "": + # No easy way to do isinstance check on PyCapsule + return torch._C._aoti.alloc_tensor_by_stealing_from_void_ptr(arg) + elif isinstance(arg, (list, tuple)): + return type(arg)(convert_arg(a) for a in arg) + else: + return arg + + converted_args = [convert_arg(arg) for arg in args] + + assert op.startswith("torch.ops."), ( + op + " can not be called through custom_op_wrapper" + ) + func = None + for i, s in enumerate(op.split(".")): + if i == 0: + func = importlib.import_module(s) + func = getattr(func, s) + + assert callable(func), op + " can not be loaded through custom_op_wrapper" + result = func(*converted_args) + if isinstance(result, (list, tuple)): + for r in result: + assert isinstance(r, torch.Tensor), op + " returns a list of non-tensors" + return torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(result) # type: ignore[arg-type] + else: + assert isinstance(result, torch.Tensor), op + " returns a non-tensor" + return torch._C._aoti.unsafe_alloc_void_ptr_from_tensor(result) + + +@clear_on_fresh_inductor_cache +class CppCodeCache: + cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {} + cache_clear = staticmethod(cache.clear) + cpp_compile_command_flags: Dict[str, Any] = {} + + @staticmethod + def _load_library_inner(path: str, key: str) -> Union[CDLL, ModuleType]: + return cdll.LoadLibrary(path) + + @classmethod + def _load_library(cls, path: str, key: str) -> Union[CDLL, ModuleType]: + try: + result = cls._load_library_inner(path, key) + result.key = key # type: ignore[union-attr] + return result + except (ImportError, OSError) as e: + if "gomp" in str(e) and os.path.exists("/usr/lib64/libgomp.so.1"): + # hacky workaround for fbcode/buck + global _libgomp + _libgomp = cdll.LoadLibrary("/usr/lib64/libgomp.so.1") + result = cls._load_library_inner(path, key) + result.key = key # type: ignore[union-attr] + return result + if "failed to map segment from shared object" in str(e): + raise OSError( + f"{e}. The most common reason this may occur is if the {tempfile.gettempdir()} folder " + "is mounted with noexec (e.g., by default Docker mounts tmp file systems " + f"as noexec). Please remount {tempfile.gettempdir()} with exec enabled, or set another " + "temporary directory with TORCHINDUCTOR_CACHE_DIR environment variable." 
+ ) from e + raise + + @classmethod + def load_async( + cls, + source_code: str, + cuda: bool = False, + submit_fn: Any = None, + extra_flags: Sequence[str] = (), + ) -> Any: + compile_command = { + **cls.cpp_compile_command_flags, + "cuda": cuda, + "vec_isa": pick_vec_isa(), + "extra_flags": extra_flags, + } + + _set_gpu_runtime_env() # cpp_extension consults the env + + command_gen = CppBuilder( + name="o", sources="i", BuildOption=CppTorchCudaOptions(**compile_command) + ) + # write function will calc source_code hash, the same source code with different + # ISA level should be generate different hash. + # So we need get a command_line which contains isa related parameter as a part of hash key. + # And then pass the command_line to below write function as extra parameter to + # guarantee the source code hash contains ISA difference. + vec_isa_cmd = repr(command_gen.get_command_line()) + key, input_path = write(source_code, "cpp", extra=vec_isa_cmd) + + if key not in cls.cache: + from filelock import FileLock + + lock_path = os.path.join(get_lock_dir(), key + ".lock") + output_name, output_dir = get_name_and_dir_from_output_file_path(input_path) + """ + If `fb_code` env, it need to be dispatched to original `compile_file` function. + So, we still need to prepare parameters for the function: `input_path` and `fb_output_path`. + """ + fb_output_path = input_path[:-3] + "so" + future: Optional[Future[Any]] = None + lib = None + + cpp_build_option = CppTorchCudaOptions(**compile_command) + cpp_builder = CppBuilder( + name=output_name, + sources=input_path, + output_dir=output_dir, + BuildOption=cpp_build_option, + ) + + worker_fn = functools.partial( + _worker_compile_cpp, + lock_path, + cpp_builder, + input_path, + fb_output_path, + ) + + binary_path = normalize_path_separator( + fb_output_path + if config.is_fbcode() + else cpp_builder.get_target_file_path() + ) + + def load_fn() -> Any: + nonlocal lib + if lib is None: + if future is not None: + future.result() + result = worker_fn() + assert result is None + lib = cls._load_library(binary_path, key) + assert lib is not None + return lib + + if submit_fn is not None: + with FileLock(lock_path, timeout=LOCK_TIMEOUT): + if not os.path.exists(binary_path): + future = submit_fn(worker_fn) + + cls.cache[key] = load_fn + + return cls.cache[key] + + @classmethod + def load(cls, source_code: str, cuda: bool = False) -> Any: + return cls.load_async(source_code, cuda)() + + +def _worker_compile_cpp( + lock_path: str, + cpp_builder: CppBuilder, + fb_input_path: str, + fb_output_path: str, +) -> None: + from filelock import FileLock + + with FileLock(lock_path, timeout=LOCK_TIMEOUT): + binary_path = ( + fb_output_path if config.is_fbcode() else cpp_builder.get_target_file_path() + ) + if not os.path.exists(binary_path): + if config.is_fbcode(): + compile_file( + fb_input_path, + fb_output_path, + shlex.split(cpp_builder.get_command_line()), + ) + else: + cpp_builder.build() + + +# Customized Python binding for cpp kernels +@clear_on_fresh_inductor_cache +class CppPythonBindingsCodeCache(CppCodeCache): + cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {} + cache_clear = staticmethod(cache.clear) + cpp_compile_command_flags = { + # kernels have no dependency on libtorch + "include_pytorch": False, + "shared": True, + } + entry_function = "kernel" + call_entry_function = "kernel(%s);Py_RETURN_NONE;" + extra_parse_arg = "" + suffix_template = textwrap.dedent( + """ + // Python bindings to call %s(): + #define PY_SSIZE_T_CLEAN + #include + 
#include + #include + + #ifndef _MSC_VER + #if __cplusplus < 202002L + // C++20 (earlier) code + // https://en.cppreference.com/w/cpp/language/attributes/likely + #define likely(x) __builtin_expect(!!(x), 1) + #define unlikely(x) __builtin_expect(!!(x), 0) + #endif + #else + #define likely(x) (x) + #define unlikely(x) (x) + #endif + + // This is defined in guards.cpp so we don't need to import PyTorch headers that are slooow. + // We manually link it below to workaround issues with fbcode build. + static void* (*_torchinductor_pyobject_tensor_data_ptr)(PyObject* obj); + + template static inline T parse_arg(PyObject* args, size_t n) { + static_assert(std::is_pointer::value, "arg type must be pointer or long"); + return static_cast(_torchinductor_pyobject_tensor_data_ptr(PyTuple_GET_ITEM(args, n))); + } + template <> inline int64_t parse_arg(PyObject* args, size_t n) { + auto result = PyLong_AsSsize_t(PyTuple_GET_ITEM(args, n)); + if(unlikely(result == -1 && PyErr_Occurred())) + throw std::runtime_error("expected int arg"); + return result; + } + template <> inline uintptr_t parse_arg(PyObject* args, size_t n) { + auto result = PyLong_AsVoidPtr(PyTuple_GET_ITEM(args, n)); + if(unlikely(result == reinterpret_cast(-1) && PyErr_Occurred())) + throw std::runtime_error("expected int arg"); + return reinterpret_cast(result); + } + + %s + + static PyObject* %s_py(PyObject* self, PyObject* args) { + try { + if(unlikely(!PyTuple_CheckExact(args))) + throw std::runtime_error("tuple args required"); + if(unlikely(PyTuple_GET_SIZE(args) != %s)) + throw std::runtime_error("requires %s args"); + %s + } catch(std::exception const& e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + return nullptr; + } catch(...) { + PyErr_SetString(PyExc_RuntimeError, "unhandled error"); + return nullptr; + } + } + + static PyMethodDef py_methods[] = { + {"%s", %s_py, METH_VARARGS, ""}, + {NULL, NULL, 0, NULL}}; + + static struct PyModuleDef py_module = + {PyModuleDef_HEAD_INIT, "%s", NULL, -1, py_methods}; + + PyMODINIT_FUNC PyInit_%s(void) { + const char* str_addr = std::getenv("_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR"); + if(!str_addr) { + PyErr_SetString(PyExc_RuntimeError, "_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR must be set"); + return nullptr; + } + std::istringstream iss(str_addr); + uintptr_t addr = 0; + iss >> addr; + _torchinductor_pyobject_tensor_data_ptr = + reinterpret_cast(addr); + return PyModule_Create(&py_module); + } + """ + ) + + @classmethod + def _load_library_inner(cls, path: str, key: str) -> ModuleType: + os.environ["_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR"] = str( + torch._C._dynamo.guards._torchinductor_pyobject_tensor_data_ptr # type: ignore[attr-defined] + ) + module_name = f"{key}.{cls.entry_function}" + try: + return sys.modules[module_name] + except KeyError: + pass + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore[union-attr] + return module + + @classmethod + def load_pybinding_async( + cls, + argtypes: List[str], + source_code: str, + cuda: bool = False, + num_outputs: int = -1, + submit_fn: Any = None, + extra_flags: Sequence[str] = (), + ) -> Any: + """ + Wrap a C++ function in fast Python bindings. + + Args: + argtypes: The types of args to ENTRY_FUNCTION(), e.g. 
["float*", "long"] + source_code: C++ source code containing a ENTRY_FUNCTION() function + + Returns: + A python version of ENTRY_FUNCTION() + """ + parseargs = ", ".join( + f"parse_arg<{argtype.replace('const ', '')}>(args, {n})" + for n, argtype in enumerate(argtypes) + ) + suffix = cls.suffix_template % ( + cls.entry_function, + cls.extra_parse_arg % num_outputs if cls.extra_parse_arg else "", + cls.entry_function, + len(argtypes), + len(argtypes), + cls.call_entry_function % parseargs, + cls.entry_function, + cls.entry_function, + cls.entry_function, + cls.entry_function, + ) + get_result = cls.load_async( + source_code + suffix, cuda, submit_fn=submit_fn, extra_flags=extra_flags + ) + result = None + + def future() -> Any: + nonlocal result + if result is None: + result = get_result() + assert isinstance(result, ModuleType) + return getattr(result, cls.entry_function) + + return future + + @classmethod + def load_pybinding(cls, *args: Any, **kwargs: Any) -> Any: + return cls.load_pybinding_async(*args, **kwargs)() + + +@clear_on_fresh_inductor_cache +class CppWrapperCodeCache(CppPythonBindingsCodeCache): + cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {} + cache_clear = staticmethod(cache.clear) + cpp_compile_command_flags = { + "include_pytorch": True, + "shared": True, + } + entry_function = "inductor_entry_cpp" + call_entry_function = "return inductor_entry_cpp(%s);" + extra_parse_arg = textwrap.dedent( + """ + #include + + static inline std::vector unpack_tensor_handle_list(PyObject* pyvec) { + std::vector result; + size_t result_len = PyList_GET_SIZE(pyvec); + result.reserve(result_len); + for (size_t i = 0; i < result_len; i++) { + // AtenTensorHandle is essentially a pointer + void* elem = PyCapsule_GetPointer(PyList_GET_ITEM(pyvec, i), NULL); + result.push_back(reinterpret_cast(elem)); + } + return result; + } + + static inline PyObject* pack_tensor_handle_list(const std::vector& cppvec) { + size_t result_len = cppvec.size(); + PyObject* result = PyList_New(static_cast(result_len)); + for (size_t i = 0; i < result_len; i++) { + PyObject *elem = + cppvec[i] == nullptr + ? Py_None + // Store AtenTensorHandle as PyCapsulate + : PyCapsule_New(reinterpret_cast(cppvec[i]), NULL, NULL); + PyList_SET_ITEM(result, i, elem); + } + return result; + } + + template <> inline std::vector parse_arg>(PyObject* args, size_t n) { + return unpack_tensor_handle_list(PyTuple_GET_ITEM(args, n)); + } + + PyObject* inductor_entry_cpp(std::vector&& input_handles) { + // For outputs, we only allocate a vector to hold returned tensor handles, + // not allocating the actual output tensor storage here + std::vector output_handles(%s); + try { + inductor_entry_impl(input_handles.data(), output_handles.data()); + return pack_tensor_handle_list(output_handles); + } catch(std::exception const& e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + return {}; + } catch(...) { + PyErr_SetString(PyExc_RuntimeError, "unhandled error"); + return {}; + } + } + """ + ) + + +@clear_on_fresh_inductor_cache +class HalideCodeCache(CppPythonBindingsCodeCache): + cache: Dict[str, Callable[[], Union[ModuleType, CDLL]]] = {} + cache_clear = staticmethod(cache.clear) + _standalone_runtime_path: Optional[str] = None + prefix = textwrap.dedent( + """ + #include "{halideruntime_h}" + #include "{headerfile}" + #include + #include + + namespace c10 {{ + inline long div_floor_integer(long a, long b) {{ + if ((a<0) != (b<0)) {{ + const auto quot = a / b; + const auto rem = a % b; + return rem ? 
quot - 1 : quot; + }} + return a / b; + }} + }} + """ + ) + glue_template_cpp = prefix + textwrap.dedent( + """ + void kernel({argdefs}) {{ + {buffers} + int err = halide_kernel({buffer_names}); + if(err != 0) throw std::runtime_error("halide_kernel failed"); + }} + """ + ) + glue_template_cuda = prefix + textwrap.dedent( + """ + #include + static const halide_device_interface_t* cuda_interface = halide_cuda_device_interface(); + + void kernel({argdefs}, uintptr_t stream) {{ + {buffers} + int err = halide_kernel(reinterpret_cast(stream), {buffer_names}); + if(err != 0) throw std::runtime_error("halide_kernel failed"); + }} + """ + ) + standalone_runtime_cuda_init = textwrap.dedent( + """ + #include "{}" + #include + + static int acquire_context(void* user_context, + void** cuda_context_out, + bool create) {{ + return cuCtxGetCurrent(reinterpret_cast(cuda_context_out)); + }} + + static int release_context(void* user_context) {{ + return 0; + }} + + static int get_stream(void* user_context, + void* cuda_context, + void** stream_out) {{ + *stream_out = user_context; + return 0; + }} + + static int register_halide_hooks() {{ + halide_set_cuda_acquire_context(&acquire_context); + halide_set_cuda_release_context(&release_context); + halide_set_cuda_get_stream(&get_stream); + return 0; + }} + + int inductor_register_halide_hooks_result = register_halide_hooks(); + """ + ) + + @classmethod + def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> List[str]: + assert arg.shape is not None + assert arg.stride is not None and len(arg.shape) == len(arg.stride) + assert arg.offset is not None + data_ptr = f"{arg.alias_of or arg.name} + {arg.offset}" + if cuda: + device = f"reinterpret_cast({data_ptr})" + device_interface = "cuda_interface" + host = "nullptr" + flags = "halide_buffer_flag_device_dirty" + else: + device = "0" + device_interface = "nullptr" + host = f"reinterpret_cast({data_ptr})" + flags = "halide_buffer_flag_host_dirty" + + dims = [] + for size, stride in zip(arg.shape, arg.stride): + dims.append(f"halide_dimension_t(0, {size}, {stride})") + + return [ + f"halide_buffer_t {name};", + f"halide_dimension_t {name}_dims[] = {{{', '.join(dims)}}};", + f"{name}.device = {device};", + f"{name}.device_interface = {device_interface};", + f"{name}.host = {host};", + f"{name}.flags = {flags};", + f"{name}.type = {arg.halide_type()};", + f"{name}.dimensions = {len(dims)};", + f"{name}.dim = {name}_dims;", + f"{name}.padding = nullptr;", + ] + + @classmethod + def _codegen_glue(cls, meta: HalideMeta, headerfile: object) -> str: + is_cuda = meta.is_cuda() + assert is_cuda is ("user_context" in meta.target) + assert "no_runtime" in meta.target + buffers = [] + buffer_names = [] + for i, arg in enumerate(meta.argtypes): + if arg.is_buffer(): + buffer_names.append(f"&hl_buf_{i}") + buffers.extend(cls._codegen_buffer(f"hl_buf_{i}", arg, is_cuda)) + else: + assert "*" not in arg.ctype + buffer_names.append(arg.name) + buffers = "\n".join([f" {line}" for line in buffers]).lstrip() + + glue_template = cls.glue_template_cuda if is_cuda else cls.glue_template_cpp + glue_code = glue_template.format( + halideruntime_h=cls.find_header( + "HalideRuntimeCuda.h" if is_cuda else "HalideRuntime.h" + ), + headerfile=headerfile, + argdefs=", ".join( + f"{a.bindings_type()} {a.name}" + for a in meta.argtypes + if a.alias_of is None + ), + buffers=buffers, + buffer_names=", ".join(buffer_names), + ) + return glue_code + + @classmethod + @functools.lru_cache(None) + def config_hash(cls) -> str: + 
command_gen = CppBuilder( + name="O", + sources="I", + BuildOption=CppOptions(), + ) + command_line = command_gen.get_command_line() + return sha256_hash( + "\n".join( + [ + cls.glue_template_cpp, + cls.glue_template_cuda, + cls.standalone_runtime_cuda_init, + command_line, + ] + ).encode("utf-8") + ) + + @staticmethod + def _search_for_file(suffix: str, errmsg: str) -> str: + spec = importlib.machinery.PathFinder.find_spec("halide") + if spec is None or not spec.submodule_search_locations: + raise RuntimeError("halide python bindings not installed") + try: + search = spec.submodule_search_locations[0] + for file in os.listdir(search): + if file.endswith(".so"): + try: + out = subprocess.check_output( + ["ldd", os.path.join(search, file)] + ) + except subprocess.SubprocessError: + continue + m = re.search(r"(/.*)/libHalide.so", out.decode("utf-8")) + if m: + path = os.path.join(os.path.abspath(m.group(1)), suffix) + if os.path.exists(path): + return os.path.abspath(path) + except Exception as e: + raise RuntimeError(errmsg) from e + raise RuntimeError(errmsg) + + @staticmethod + @functools.lru_cache(None) + def find_libautoschedule(name: str) -> str: + sofile = f"libautoschedule_{name.lower()}.so" + if "HALIDE_LIB" in os.environ: + path = os.path.join(os.environ["HALIDE_LIB"], sofile) + if os.path.exists(path): + return path + errmsg = ( + f"Can't find {sofile}, set env HALIDE_LIB to the directory containing it" + ) + return HalideCodeCache._search_for_file(sofile, errmsg) + + @staticmethod + @functools.lru_cache(None) + def find_header(name: str) -> str: + if "HALIDE_INCLUDE" in os.environ: + path = os.path.join(os.environ["HALIDE_INCLUDE"], name) + if os.path.exists(path): + return path + if "HALIDE_LIB" in os.environ: + path = os.path.abspath( + os.path.join(os.environ["HALIDE_LIB"], f"../include/{name}") + ) + if os.path.exists(path): + return path + errmsg = ( + f"Can't find {name}, set env HALIDE_INCLUDE to the directory containing it" + ) + return HalideCodeCache._search_for_file(f"../include/{name}", errmsg) + + @classmethod + def generate_halide_async( + cls, meta: HalideMeta, source_code: str, submit_fn: Any = None + ) -> Callable[[], Any]: + dirpath = Path( + get_path( + code_hash( + source_code, + extra=repr((cls.config_hash(), meta)), + ), + "halide", + )[2] + ) + os.makedirs(dirpath, exist_ok=True) + wait_for_compile = None + genfile = str(dirpath / "generate_kernel.py") + libfile = str(dirpath / "halide_kernel.a") + headerfile = str(dirpath / "halide_kernel.h") + donefile = str(dirpath / "done") + lockfile = str(dirpath / "lock") + need_compile = not os.path.exists(donefile) + jobs = [] + if need_compile: + write_atomic(genfile, source_code) + cmd = [ + sys.executable, + genfile, + "-g", + "kernel", + "-o", + f"{dirpath}", + "-f", + "halide_kernel", + "-e", + "static_library,h,schedule", + ] + if meta.scheduler: + cmd.extend(["-p", cls.find_libautoschedule(meta.scheduler)]) + cmd.extend(meta.args()) + jobs.append(functools.partial(subprocess.check_call, cmd)) + + binding_types = [ + arg.bindings_type() for arg in meta.argtypes if arg.alias_of is None + ] + if meta.is_cuda(): + binding_types.append("uintptr_t") # stream + bindings_future = cls.load_pybinding_async( + binding_types, + cls._codegen_glue(meta, headerfile), + extra_flags=(libfile, cls.build_standalone_runtime()), + submit_fn=jobs.append if need_compile else None, + cuda=meta.is_cuda(), + ) + + if need_compile: + jobs.append(functools.partial(touch, donefile)) + task = functools.partial(_worker_task_halide, 
lockfile, jobs) + if submit_fn: + wait_for_compile = submit_fn(task).result + else: + task() + + def load() -> Callable[[], Any]: + if wait_for_compile: + wait_for_compile() + return bindings_future() + + return load + + @classmethod + def generate_halide(cls, *args: Any, **kwargs: Any) -> Callable[[], Any]: + return cls.generate_halide_async(*args, **kwargs)() + + @classmethod + def build_standalone_runtime(cls) -> str: + if cls._standalone_runtime_path and os.path.exists( + cls._standalone_runtime_path + ): + return cls._standalone_runtime_path + is_cuda = torch.cuda.is_available() + libname = "libStandaloneHalideRuntime.so" + target = "host-cuda" if is_cuda else "host" + if cls._standalone_runtime_path: + assert not os.path.exists(cls._standalone_runtime_path) + # We hit this case in unittests when we run with fresh_inductor_cache() + # Generating a fresh runtime over and over causes errors because we initialize + # cuda hundreds of times in the same process and run out of file descriptors. + # Workaround by jail breaking the current fresh_inductor_cache(). + base = default_cache_dir() + else: + base = cache_dir() + dirpath = Path(base) / f"halide-runtime-{target}-{cls.config_hash()}" + os.makedirs(dirpath, exist_ok=True) + donefile = str(dirpath / "done") + lockfile = str(dirpath / "lock") + hookfile = str(dirpath / "hooks.cpp") + afile = str(dirpath / "standalone_halide_runtime.a") + sofile = str(dirpath / libname) + if not os.path.exists(donefile): + import filelock + import halide as hl # type: ignore[import-untyped,import-not-found] + + with filelock.FileLock(lockfile, LOCK_TIMEOUT): + if not os.path.exists(donefile): + with open(hookfile, "w") as f: + if is_cuda: + f.write( + cls.standalone_runtime_cuda_init.format( + cls.find_header("HalideRuntimeCuda.h") + ) + ) + hl.compile_standalone_runtime(afile, hl.Target(target)) + + name, output_dir = get_name_and_dir_from_output_file_path(sofile) + halide_cmd_gen = CppBuilder( + name=name, + sources=[hookfile, afile], + output_dir=output_dir, + BuildOption=CppTorchCudaOptions( + cuda=is_cuda, + ), + ) + + subprocess.check_call( + shlex.split(halide_cmd_gen.get_command_line()) + ) + touch(donefile) + assert os.path.exists(sofile) + cls._standalone_runtime_path = sofile + return sofile + + +def _worker_task_halide(lockfile: str, jobs: List[partial[Any]]) -> None: + from filelock import FileLock + + try: + with FileLock(lockfile, LOCK_TIMEOUT): + for job in jobs: + job() + except subprocess.SubprocessError as e: + if os.environ.get("HALIDE_REPRO") == "1": + python, script, *cmd = getattr(e, "cmd", ("", "", "")) + if os.path.basename(python).startswith("python"): + code = open(script).read() + main = " hl.main()" + assert code.count(main) == 1 + + class Out: + def __repr__(self) -> str: + return "out" + + cmd[cmd.index("-o") + 1] = Out() # type: ignore[call-overload] + repl = textwrap.indent( + textwrap.dedent( + f"""\ + import sys, tempfile + with tempfile.TemporaryDirectory() as out: + sys.argv = {["repro.py", *cmd]!r} + hl.main() + """ + ), + " ", + ) + code = code.replace(main, repl) + with open("repro.py", "w") as fd: + fd.write(code.lstrip()) + raise RuntimeError(f"wrote repro.py: {e}") from e + raise + + +def touch(filename: str): # type: ignore[no-untyped-def] + open(filename, "a").close() + + +@clear_on_fresh_inductor_cache +class PyCodeCache: + cache: Dict[str, ModuleType] = {} + linemaps: Dict[str, List[Tuple[Any, ...]]] = {} + cache_clear = staticmethod(cache.clear) + + @classmethod + def write(cls, source_code: str, extra: str = 
"") -> Tuple[str, str]: + return write(source_code, "py", extra=extra) + + @classmethod + def load( + cls, + source_code: str, + extra: str = "", + linemap: Optional[List[Tuple[int, str]]] = None, + attrs: Optional[Dict[str, Any]] = None, + ) -> ModuleType: + key, path = write(source_code, "py", extra=extra) + return cls.load_by_key_path(key, path, linemap, attrs) + + @classmethod + def load_by_key_path( + cls, + key: str, + path: str, + linemap: Optional[List[Tuple[int, str]]] = None, + attrs: Optional[Dict[str, Any]] = None, + ) -> ModuleType: + if linemap is None: + linemap = [] + if key not in cls.cache: + mod = _reload_python_module(key, path) + + # another thread might set this first + cls.cache.setdefault(key, mod) + # unzip into separate lines/nodes lists + cls.linemaps[path] = list(zip(*linemap)) + + if attrs is not None: + for k, v in attrs.items(): + setattr(mod, k, v) + + if not (linemap or attrs): + mod._reload_in_subproc = functools.partial( # type: ignore[attr-defined] + _reload_python_module_in_subproc, key, path + ) + + return cls.cache[key] + + @classmethod + @functools.lru_cache(None) + def stack_frames_for_code( + cls, path: str, lineno: int + ) -> Optional[List[Dict[str, Any]]]: + if path not in cls.linemaps: + return None + # [(starting_line, ), ...] + lines, nodes = cls.linemaps[path] + p = bisect_right(lines, lineno) + if p == 0: + return None + entry = nodes[p - 1] + if not entry: + return None + + def parse_stack_trace(stack_trace: str) -> List[Dict[str, Any]]: + # ideally fx stores stack traces as data rather than a string + # but this is not along a performance critical path + regex = r'File "(.+)", line (\d+), in (.+)\n' + matches = re.findall(regex, stack_trace) + return [ + {"filename": f, "line": int(l), "name": n} + for f, l, n in reversed(matches) + ] + + return parse_stack_trace(entry) + + +class TritonCodeCache: + @classmethod + def load(cls, kernel_name: str, source_code: str) -> ModuleType: + return _module_to_triton_kernel(PyCodeCache.load(source_code), kernel_name) + + +def _cuda_compiler() -> Optional[str]: + if cuda_env.nvcc_exist(config.cuda.cuda_cxx): + return config.cuda.cuda_cxx + if config.is_fbcode(): + return os.path.join(build_paths.cuda(), "bin", "nvcc") + if cuda_env.nvcc_exist(os.getenv("CUDACXX")): + return os.getenv("CUDACXX", "") + if cuda_env.nvcc_exist(os.getenv("CUDA_HOME")): + return os.path.realpath(os.path.join(os.getenv("CUDA_HOME", ""), "bin/nvcc")) + return "nvcc" + + +def _cutlass_include_paths() -> List[str]: + if config.is_fbcode(): + from libfb.py import parutil + + cutlass_path = parutil.get_dir_path("cutlass-3-headers") + else: + cutlass_path = config.cuda.cutlass_dir + return [ + # Use realpath to get canonical absolute paths, in order not to mess up cache keys + os.path.realpath(os.path.join(cutlass_path, "include")), + os.path.realpath(os.path.join(cutlass_path, "tools/library/include")), + os.path.realpath(os.path.join(cutlass_path, "tools/library/src")), + os.path.realpath(os.path.join(cutlass_path, "tools/util/include")), + ] + + +def _cuda_lib_options() -> List[str]: + _set_gpu_runtime_env() # cpp_extension consults the env + from torch.utils import cpp_extension + + lpaths = cpp_extension.library_paths(cuda=True) + [ + sysconfig.get_config_var("LIBDIR") + ] + extra_ldflags: List[str] = [] + if is_linux(): + _transform_cuda_paths(lpaths) + for path in lpaths: + # -rpath ensures the DLL can find its dependencies when loaded, even + # if the library path is non-standard. 
+ extra_ldflags.extend([f"-L{path}", "-Xlinker", f"-rpath={path}"]) + extra_ldflags.append("-lcuda") + extra_ldflags.append("-lcudart") + else: + raise NotImplementedError( + "Unsupported env, failed to find cuda libs! Currently only Linux is supported." + ) + return extra_ldflags + + +def _nvcc_host_compiler_options() -> List[str]: + return [ + "-fPIC", + "-fno-strict-aliasing", + "-fvisibility=hidden", + "-Wconversion", + ] + + +def _nvcc_compiler_options() -> List[str]: + arch = cuda_env.get_cuda_arch() + if arch == "90": + # Required by cutlass compilation. + arch = "90a" + code = [f"sm_{arch}", f"compute_{arch}"] + if config.cuda.enable_cuda_lto: + code += [f"lto_{arch}"] + options = [ + "-t=0", + "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", + "-w", + f"-gencode=arch=compute_{arch},code=[{','.join(code)}]", + config.cuda.compile_opt_level, + "-std=c++17", + "--expt-relaxed-constexpr", + "-DNDEBUG", + ] + if config.is_fbcode(): + options.extend(["-ccbin", os.path.dirname(build_paths.gcc())]) + if config.cuda.enable_debug_info: + options.extend(["-lineinfo", "-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"]) + if config.cuda.enable_ptxas_info: + options.extend( + [ + "--keep", # Keep the intermediate files for debugging (including ptx, sass, cubin etc.) + "--ptxas-options=--warn-on-local-memory-usage", # warn us if local memory is used in CUDA Kernels + "--ptxas-options=--warn-on-spills", # warn us if register spilling happens in CUDA Kernels + "--resource-usage", # Report on CUDA resource usage (shared mem, registers etc.) + "--source-in-ptx", + ] + ) # Annotate the ptx file with source information + if config.cuda.use_fast_math: + options.extend( + [ + "--use_fast_math", + "-DCUTLASS_USE_TANH_FOR_SIGMOID=1", + ] + ) + return options + + +def cuda_compile_command( + src_files: List[str], + dst_file: str, + dst_file_ext: str, + extra_args: Optional[List[str]] = None, +) -> str: + if extra_args is None: + extra_args = [] + include_paths = _cutlass_include_paths() + cuda_lib_options = _cuda_lib_options() + nvcc_host_compiler_options = _nvcc_host_compiler_options() + nvcc_compiler_options = _nvcc_compiler_options() + options = ( + nvcc_compiler_options + + extra_args + + [ + f"-Xcompiler {opt}" if "=" in opt else f"-Xcompiler={opt}" + for opt in nvcc_host_compiler_options + ] + + ["-I" + path for path in include_paths] + + cuda_lib_options + ) + src_file = " ".join(src_files) + res = "" + if dst_file_ext == "o": + res = f"{_cuda_compiler()} {' '.join(options)} -c -o {dst_file} {src_file}" + elif dst_file_ext == "so": + options.append("-shared") + res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}" + elif dst_file_ext == "exe": + res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}" + else: + raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!") + log.debug("CUDA command: %s", res) + return res + + +class DLLWrapper: + """A wrapper for a dynamic library.""" + + def __init__( + self, + lib_path: str, + ) -> None: + self.lib_path = lib_path + self.is_open = False + self.DLL = cdll.LoadLibrary(lib_path) + self.is_open = True + + def close(self) -> None: + if self.is_open: + self._dlclose() + self.is_open = False + + def _dlclose(self) -> None: + f_dlclose = None + + if is_linux(): + syms = CDLL(None) + if not hasattr(syms, "dlclose"): + # Apline Linux + syms = CDLL("libc.so") + + if hasattr(syms, "dlclose"): + f_dlclose = syms.dlclose + elif is_windows(): + import ctypes + + kernel32 = ctypes.CDLL("kernel32", use_last_error=True) + + f_dlclose = 
kernel32.FreeLibrary + else: + raise NotImplementedError("Unsupported env, failed to do dlclose!") + + if f_dlclose is not None: + if is_linux(): + f_dlclose.argtypes = [c_void_p] + f_dlclose(self.DLL._handle) + elif is_windows(): + import ctypes + from ctypes import wintypes + + f_dlclose.argtypes = [wintypes.HMODULE] + f_dlclose(self.DLL._handle) + else: + log.warning( + "dll unloading function was not found, library may not be unloaded properly!" + ) + + def __getattr__(self, name: str) -> Callable[..., None]: + if not self.is_open: + raise RuntimeError(f"Cannot use closed DLL library: {self.lib_path}") + + method = getattr(self.DLL, name) + + def _wrapped_func(*args: Any) -> None: + err = method(*args) + if err: + raise RuntimeError(f"Error in function: {method.__name__}") + + return _wrapped_func + + def __enter__(self) -> DLLWrapper: # noqa: PYI034 + return self + + def __exit__(self, *args: Any) -> None: + self.close() + + def __del__(self) -> None: + self.close() + + +@clear_on_fresh_inductor_cache +class CUDACodeCache: + @dataclasses.dataclass + class CacheEntry: + input_path: str + output_path: str + + cache: Dict[str, CacheEntry] = {} + cache_clear = staticmethod(cache.clear) + _SOURCE_CODE_SUFFIX = "cu" + + @classmethod + def write(cls, source_code: str, dst_file_ext: str) -> Tuple[str, str]: + """ + Writes source code into a file with dst_file_ext as the file extension. + Returns the hash key of source code, and the path to the file. + """ + + cuda_command = repr( + cuda_compile_command(["dummy_input"], "dummy_output", dst_file_ext) + ) + key, input_path = write( + source_code, cls._SOURCE_CODE_SUFFIX, extra=cuda_command + ) + return key, input_path + + @classmethod + def compile( + cls, source_code: str, dst_file_ext: str, extra_args: Optional[List[str]] = None + ) -> Tuple[str, str, str]: + """ + Compiles CUDA source_code into a file with dst_file_ext extension. + Returns a tuple of dst_file_path, hash_key, source_code_path + """ + key, input_path = cls.write(source_code, dst_file_ext) + if key not in cls.cache: + from filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + with lock: + output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext + if not os.path.exists(output_path): + cmd = cuda_compile_command( + [input_path], output_path, dst_file_ext, extra_args + ) + start_time = time() + log.debug("CUDA Compilation: %s", cmd) + cmd_parts = cmd.split(" ") + try: + subprocess.check_output( + cmd_parts, stderr=subprocess.STDOUT, env=os.environ + ) + except subprocess.CalledProcessError as error: + raise exc.CUDACompileError(cmd_parts, error.output) from error + end_time = time() + log_duration_msg = f"CUDA Compilation took {end_time - start_time} seconds. Compile command: {cmd}" + log.info(log_duration_msg) + else: + log.debug( + "CUDA Compilation skipped: %s since output already exists", + input_path, + ) + cls.cache[key] = CUDACodeCache.CacheEntry(input_path, output_path) + + return (cls.cache[key].output_path, key, input_path) + + @classmethod + def load(cls, source_code: str, dst_file_ext: str) -> Tuple[DLLWrapper, str, str]: + """ + Compiles source code and loads the generated .so file. + Returns a tuple of DLLWrapper, hash_key, source_code_path + """ + + if dst_file_ext != "so": + raise RuntimeError( + f"Only support loading a .so file for now. " + f"Requested file extension: {dst_file_ext}. 
Source code: {source_code}" + ) + dst_file_path, hash_key, source_code_path = cls.compile( + source_code, dst_file_ext + ) + return (DLLWrapper(dst_file_path), hash_key, source_code_path) + + +@clear_on_fresh_inductor_cache +class ROCmCodeCache: + @dataclasses.dataclass + class CacheEntry: + input_path: str + output_path: str + + cache: Dict[str, CacheEntry] = {} + cache_clear = staticmethod(cache.clear) + _SOURCE_CODE_SUFFIX = "cpp" + _logged_compiler_version = False + + @classmethod + def write(cls, source_code: str, dst_file_ext: str) -> Tuple[str, str]: + """ + Writes source code into a file with dst_file_ext as the file extension. + Returns the hash key of source code, and the path to the file. + """ + + cuda_command = repr( + rocm_compile_command(["dummy_input"], "dummy_output", dst_file_ext) + ) + key, input_path = write( + source_code, cls._SOURCE_CODE_SUFFIX, extra=cuda_command + ) + return key, input_path + + @classmethod + def compile( + cls, source_code: str, dst_file_ext: str, extra_args: Optional[List[str]] = None + ) -> Tuple[str, str, str]: + """ + Compiles source_code into a file with dst_file_ext extension, + using the compile command specific for the ROCm platform. + Returns a tuple of dst_file_path, hash_key, source_code_path + """ + if not cls._logged_compiler_version: + cls._logged_compiler_version = True + log.debug(get_compiler_version_info(str(rocm_compiler()))) + + key, input_path = cls.write(source_code, dst_file_ext) + if key not in cls.cache: + from filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + with lock: + output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext + if not os.path.exists(output_path): + cmd = rocm_compile_command( + [input_path], output_path, dst_file_ext, extra_args + ) + start_time = time() + cmd_parts = cmd.split(" ") + try: + output = subprocess.check_output( + cmd_parts, + stderr=subprocess.STDOUT, + text=True, + env=os.environ, + ) + log.debug("Compilation output: %s", output) + except subprocess.CalledProcessError as error: + raise exc.CUDACompileError(cmd_parts, error.output) from error + end_time = time() + log_duration_msg = f"Compilation took {end_time - start_time} seconds. Compile command: {cmd}" + log.info(log_duration_msg) + else: + log.debug( + "Compilation skipped: %s since output already exists", + input_path, + ) + cls.cache[key] = ROCmCodeCache.CacheEntry(input_path, output_path) + + return (cls.cache[key].output_path, key, input_path) + + @classmethod + def load(cls, source_code: str, dst_file_ext: str) -> Tuple[DLLWrapper, str, str]: + """ + Compiles source code and loads the generated .so file. + Returns a tuple of DLLWrapper, hash_key, source_code_path + """ + + if dst_file_ext != "so": + raise RuntimeError( + f"Only support loading a .so file for now. " + f"Requested file extension: {dst_file_ext}. Source code: {source_code}" + ) + dst_file_path, hash_key, source_code_path = cls.compile( + source_code, dst_file_ext + ) + return (DLLWrapper(dst_file_path), hash_key, source_code_path) + + +class CodeCacheFuture: + def result(self) -> None: + raise NotImplementedError + + +class TritonFuture(CodeCacheFuture): + kernel: ModuleType + + def __init__( + self, + kernel: Any, + future: Optional[Future[Any]], + ) -> None: + self.kernel = kernel + self.future = future + + def result(self) -> ModuleType: # type: ignore[override] + if self.future is not None: + # If the worker failed this will throw an exception. 
+ result = self.future.result() + assert result is None + self.future = None + self.kernel.precompile() + return self.kernel + + +class LambdaFuture(CodeCacheFuture): + def __init__(self, result_fn: Callable[..., Any]) -> None: + self.result_fn = result_fn + + def result(self) -> Callable[..., Any]: # type: ignore[override] + return self.result_fn() diff --git a/lib/python3.10/site-packages/torch/_inductor/comm_analysis.py b/lib/python3.10/site-packages/torch/_inductor/comm_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..f8a233a3b9e21eb33d79b6ad60cd3ba4f276ddae --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/comm_analysis.py @@ -0,0 +1,264 @@ +import functools +import math +from enum import IntEnum + +import sympy + +import torch + +from . import ir +from .utils import get_dtype_size, sympy_product +from .virtualized import V + + +class NCCL_COLL(IntEnum): + ALL_REDUCE = 0 + ALL_GATHER = 1 + REDUCE_SCATTER = 2 + + +class NVIDIA_GPU_TYPE(IntEnum): + VOLTA = 0 + AMPERE = 1 + HOPPER = 2 + + +@functools.lru_cache +def get_gpu_type() -> NVIDIA_GPU_TYPE: + gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or "" + if "V100" in gpu_info: + return NVIDIA_GPU_TYPE.VOLTA + elif "A100" in gpu_info: + return NVIDIA_GPU_TYPE.AMPERE + elif "H100" in gpu_info: + return NVIDIA_GPU_TYPE.HOPPER + else: + # for other gpu types, assume Ampere + return NVIDIA_GPU_TYPE.AMPERE + + +def get_collective_type(node: ir.IRNode) -> NCCL_COLL: + if not isinstance(node, ir._CollectiveKernel): + raise ValueError(f"node is not a collective kernel: {node}") + + kernel_name = node.python_kernel_name + assert kernel_name is not None + if "all_reduce" in kernel_name: + return NCCL_COLL.ALL_REDUCE + elif "all_gather" in kernel_name: + return NCCL_COLL.ALL_GATHER + elif "reduce_scatter" in kernel_name: + return NCCL_COLL.REDUCE_SCATTER + else: + raise ValueError(f"Unsupported collective kernel: {kernel_name}") + + +def get_collective_input_size_bytes(node: ir.IRNode) -> int: + sz_bytes = 0 + for inp in node.inputs: # type: ignore[attr-defined] + numel = sympy_product(inp.layout.size) + if isinstance(numel, sympy.Integer): + # For ease of testing + numel = int(numel) + else: + numel = V.graph.sizevars.size_hint(numel, fallback=0) + sz_bytes += numel * get_dtype_size(inp.layout.dtype) + return sz_bytes + + +def get_collective_group_size(node: ir.IRNode) -> int: + if type(node) == ir._CollectiveKernel: + from torch.distributed.distributed_c10d import _get_group_size_by_name + + return _get_group_size_by_name(node.constant_args[-1]) + else: + raise TypeError(f"Unsupported collective type: {node}") + + +#################################################################################################################### +# The following code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc # +#################################################################################################################### + + +class NCCL_HW(IntEnum): + NVLINK = 0 + PCI = 1 + NET = 2 + + +class NCCL_ALGO(IntEnum): + TREE = 0 + RING = 1 + + +class NCCL_PROTO(IntEnum): + # The ordering and enum values here matches original in + # https://github.com/NVIDIA/nccl/blob/0b083e52096c387bad7a5c5c65b26a9dca54de8c/src/include/devcomm.h#L28 + # For difference between these protocols, see https://github.com/NVIDIA/nccl/issues/281#issuecomment-571816990 + LL = 0 # Low-latency + # LL128 = 1 # Low-latency 128-byte + # SIMPLE = 2 + + +# Latencies in 
us +# len(NCCL_ALGO) x len(NCCL_PROTO) +# NOTE: use array instead of tensor to prevent incompatibility with fake mode +baseLat = [ + # Tree + [ + 6.8, # LL + ], + # Ring + [ + 6.6, # LL + ], +] + +# Latencies in us +# len(NCCL_HW) x len(NCCL_ALGO) x len(NCCL_PROTO) +hwLat = [ + # NVLINK + [ + [0.6], # Tree (LL) + [0.6], # Ring (LL) + ], + # PCI + [ + [1.0], # Tree (LL) + [1.0], # Ring (LL) + ], + # NET + [ + [5.0], # Tree (LL) + [2.7], # Ring (LL) + ], +] + + +# LL128 max BW per channel +llMaxBws = [ + # Volta-N1/Intel-N2/Intel-N4 + [ + 39.0, + 39.0, + 20.4, + ], + # Ampere-N1/AMD-N2/AMD-N4 + [ + 87.7, + 22.5, # avg of ring & tree + 19.0, + ], + # Hopper-N1/AMD-N2/AMD-N4 + [ + 87.7, + 22.5, # avg of ring & tree + 19.0, + ], +] + + +def estimate_nccl_collective_runtime(node: ir.IRNode) -> float: + """ + Returns estimated NCCL collective runtime in nanoseconds (ns). + + The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc. + We aim to estimate the runtime as accurately as possible. + + Assumptions: + - only ring algorithm (NCCL_ALGO_RING) is used + - only Low-Latency protocol (NCCL_PROTO_LL) is used, i.e. Simple or LL128 is not used + - 8 gpus per node # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info. + - collective is one of: allreduce, reducescatter, allgather + """ + tensor_storage_size_bytes = get_collective_input_size_bytes(node) + # Convert bytes to GB + tensor_storage_size_GB = tensor_storage_size_bytes / 1024 / 1024 / 1024 + + # Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus. + # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info. + num_gpus_per_node = 8 + group_size = get_collective_group_size(node) + nNodes = math.ceil(group_size / num_gpus_per_node) + nRanks = group_size # this is total # of gpus globally that participate in this collective op + + if nRanks <= 1: + return 0 + + # Assumes ring algorithm + nccl_algo = NCCL_ALGO.RING + nccl_proto = NCCL_PROTO.LL + coll = get_collective_type(node) + + # =============== bandwidth computation =============== + # First compute bandwidth in GB/s; then at the end, convert it to GB/ns + + bwIntra = torch._inductor.config.intra_node_bw + bwInter = torch._inductor.config.inter_node_bw + + compCapIndex = get_gpu_type() + index2 = nNodes - 1 if nNodes <= 2 else 2 + # LL: for single node, we look at GPU type; for multi-node, we look at CPU type + index1 = compCapIndex if nNodes == 1 else 0 + llMaxBw = llMaxBws[index1][index2] + + # NOTE: each step of ring algorithm is synchronized, + # and is bottlenecked by the slowest link which is the inter-node interconnect. + # hence when nNodes >= 2, bw is inter-node bandwidth. + # NOTE: the original code in https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc + # have this as `if nNodes <= 2` which seems wrong. Corrected it here. 
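+    # Worked example with illustrative numbers (the 300 GB/s figure is hypothetical,
+    # not read from config): a single-node (nNodes == 1) all_reduce across 8 ranks on
+    # Ampere with bwIntra = 300 GB/s gives busBw = 2 * 300 = 600 GB/s below, scaled by
+    # 1/4 for all_reduce to 150 GB/s and capped at llMaxBw = 87.7 GB/s; with
+    # nsteps = 2 * (8 - 1) = 14, the algorithm bandwidth is 87.7 * 8 / 14 ~= 50 GB/s.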
+ bw = bwIntra if nNodes == 1 else bwInter + nChannels = 2 # Assume # channels is 2 + busBw = nChannels * bw + + # Various model refinements + busBw = min( + llMaxBw, + busBw + * (1.0 / 4.0 if (nNodes > 1 or coll == NCCL_COLL.ALL_REDUCE) else 1.0 / 3.0), + ) + + if coll == NCCL_COLL.ALL_REDUCE: + nsteps = 2 * (nRanks - 1) + elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER): + nsteps = nRanks - 1 + + # Convert bus BW to algorithm BW (tensor bytes / algoBW = actual execution time) + ratio = (1.0 * nRanks) / nsteps # type: ignore[possibly-undefined] + bandwidth = busBw * ratio + # Convert GB/s to GB/ns + bandwidth_GB_per_ns = bandwidth / 1e9 + + # =============== latency computation =============== + intraHw = NCCL_HW.NVLINK + + if coll == NCCL_COLL.ALL_REDUCE: + if nNodes > 1: + nInterSteps = 2 * nNodes + else: + nInterSteps = 0 + elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER): + nInterSteps = nNodes - 1 + + # First compute latency in us; then at the end, convert it to ns + latency = baseLat[nccl_algo][nccl_proto] + intraLat = hwLat[intraHw][nccl_algo][nccl_proto] + interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto] + + # Inter-node rings still have to launch nsteps * net overhead. + netOverhead = 0.0 + if nNodes > 1: + netOverhead = 1.0 # getNetOverhead(comm); + intraLat = max(intraLat, netOverhead) + latency += (nsteps - nInterSteps) * intraLat + nInterSteps * interLat # type: ignore[possibly-undefined] + # Convert us to ns + latency_ns = latency * 1e3 + + # =============== final result =============== + transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns + return transport_ns + latency_ns + + +################################################################################################################ +# The above code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc # +################################################################################################################ diff --git a/lib/python3.10/site-packages/torch/_inductor/comms.py b/lib/python3.10/site-packages/torch/_inductor/comms.py new file mode 100644 index 0000000000000000000000000000000000000000..dcad1e1bf67c9841937792f06e1eac6cda52ce62 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/comms.py @@ -0,0 +1,640 @@ +# mypy: allow-untyped-defs +# pyre-strict +from __future__ import annotations + +import heapq +import operator +import sys +from collections import defaultdict +from typing import Dict, List, Set, TYPE_CHECKING + +import torch + +from . import config, ir +from .dependencies import WeakDep +from .utils import ( + contains_collective, + contains_wait, + find_recursive_deps_of_node, + find_recursive_users_of_node, + is_collective, + is_fallback_op, + is_wait, +) + + +overlap_log = torch._logging.getArtifactLogger(__name__, "overlap") + +if TYPE_CHECKING: + from .scheduler import BaseSchedulerNode + + +def sink_waits(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]: + """ + Greedily schedules waits as late as possible. + """ + return _schedule_for_comm( + snodes, raise_comms=False, sink_waits=True, reorder_for_overlap=False + ) + + +def raise_comms(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]: + """ + Greedily schedules comms as early as possible. 
+ """ + return _schedule_for_comm( + snodes, raise_comms=True, sink_waits=False, reorder_for_overlap=False + ) + + +def reorder_compute_for_overlap( + snodes: List[BaseSchedulerNode], +) -> List[BaseSchedulerNode]: + """ + This achieves the following overall scheduling procedure: + Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes + that are required for comm N + 1 but do not depend on comm N, to run at the same time with comm N. + Step 2: If all those compute nodes are sufficient to overlap comm N, we're done. + Otherwise, we now need to look elsewhere to find compute that overlaps with comm N. + We prioritize compute nodes that are needed sooner. + Step 3: We schedule the compute nodes dependent on comm N and required for comm N + 1. + Step 4: We schedule comm N + 1. + Repeat this for subsequent comm nodes. + """ + return _schedule_for_comm( + snodes, raise_comms=True, sink_waits=True, reorder_for_overlap=True + ) + + +def _schedule_for_comm( + snodes: List[BaseSchedulerNode], + raise_comms: bool, + sink_waits: bool, + reorder_for_overlap: bool, +) -> List[BaseSchedulerNode]: + """ + Schedule `snodes` for various comm optimization objectives. + + Args: + snodes: the nodes to be scheduled. + raise_comms: whether to greedily schedule collectives as early as possible + sink_wait: whether to greedily schedule waits as late as possible + reorder_compute_for_overlap: whether to reorder compute nodes to + optimize for compute/communication overlapping. + + Returns: + The new schedule order. + + Some notes on the synergy between different options: + - `raise_comms` provides more overlapping oppurtunies for `reorder_compute_for_overlap`. + - When both `raise_comms` and `sink_waits` is `True`, `raise_comms` is prioritized. + """ + # We assign each node a tuple of scores (score_0, score_1, score_2), + # decreasing in importance, with a lower value indicating a higher ranking: + # + # - score_0: the lowest comm_idx among the comm nodes that the node blocks. + # If a node doesn't block any comm nodes, its score_0 is set to + # sys.maxsize. This score ensures that comm nodes get scheduled as early as + # possible. + # - score_1: 1 if the node is a wait node, 0 otherwise. This score ensures + # that wait nodes are deferred as late as possible. + # - score_2: the index of the node in the original topological order. This + # score provides stability in case of ties. + # + # When only raise_comms is True, only score_0 and score_2 are considered. + # When only sink_waits is True, only score_1 and score_2 are considered. + # When neither is True, the original order is yielded. 
+ buf_name_to_snode = {} + name_to_fused_node = {} + scores_0, scores_1, scores_2 = {}, {}, {} + for idx, snode in enumerate(snodes): + for buf_name in snode.get_buffer_names(): + buf_name_to_snode[buf_name] = snode + + for op_name in snode.get_operation_names(): + name_to_fused_node[op_name] = snode + name_to_fused_node[snode.get_name()] = snode + + node_name = snode.get_name() + scores_0[node_name] = sys.maxsize + scores_1[node_name] = 0 + scores_2[node_name] = idx + + comm_idx = 0 + for snode in snodes: + if raise_comms and contains_collective(snode): + scores_0[snode.get_name()] = comm_idx + for anc in snode.ancestors: + anc_fused_name = name_to_fused_node[anc].get_name() + scores_0[anc_fused_name] = min(scores_0[anc_fused_name], comm_idx) + comm_idx += 1 + elif sink_waits and contains_wait(snode): + scores_1[snode.get_name()] = 1 + + class Runnable: + def __init__(self, snode) -> None: + self.snode = snode + name = next(iter(snode.get_operation_names())) + fused_name = name_to_fused_node[name].get_name() + self.score = ( + scores_0[fused_name], + scores_1[fused_name], + scores_2[fused_name], + ) + + def __lt__(self, other): + return self.score < other.score + + unmet_deps: Dict[BaseSchedulerNode, Set[str]] = { + snode: {dep.name for dep in snode.unmet_dependencies} for snode in snodes + } + + ready: List[Runnable] = [] + buffer_users: Dict[str, Set[BaseSchedulerNode]] = defaultdict(set) + snode_to_cost = {snode: estimate_op_runtime(snode) for snode in snodes} + + for snode, deps in unmet_deps.items(): + if len(deps) == 0: + heapq.heappush(ready, Runnable(snode)) + for dep in deps: + buffer_users[dep].add(snode) + + scheduled = [] + + def schedule(snode): + """ + Schedules `snode` and put all unblocked nodes onto the ready queue. + """ + scheduled.append(snode) + for buf_name in snode.get_buffer_names(): + for snode in buffer_users[buf_name]: + unmet_deps[snode].remove(buf_name) + if len(unmet_deps[snode]) == 0: + heapq.heappush(ready, Runnable(snode)) + + def get_overlapping_candidate(): + """ + Return the next node in the ready queue that's neither a collective or + a wait. + """ + candidates = [ + x + for x in ready + if not contains_collective(x.snode) and not contains_wait(x.snode) + ] + if len(candidates) == 0: + return None + return min(candidates, key=lambda x: x.score) + + def schedule_collective_for_overlap(snode): + """ + Schedules collective node `snode`, along with one or more compute nodes + to overlap with it. The strategy is described in the comment of + `reorder_compute_for_overlap`. + """ + assert contains_collective(snode) + schedule(snode) + + collective_cost = snode_to_cost[snode] + while ( + collective_cost > 0 + and (candidate := get_overlapping_candidate()) is not None + ): + ready.remove(candidate) + schedule(candidate.snode) + collective_cost -= snode_to_cost[candidate.snode] + heapq.heapify(ready) + + while len(ready): + snode = heapq.heappop(ready).snode + if reorder_for_overlap and contains_collective(snode): + schedule_collective_for_overlap(snode) + else: + schedule(snode) + + for snode, deps in unmet_deps.items(): + assert len(deps) == 0, ( + "Detected unscheduled nodes. " + f"Nodes with unmet dependencies: {unmet_deps}" + ) + return scheduled + + +def decide_global_ordering_of_comms( + nodes: List[BaseSchedulerNode], name_to_buf, name_to_fused_node +) -> List[BaseSchedulerNode]: + """ + Decide global ordering of comms, by just enforcing the ordering that's in the input graph + (might not be the same ordering as the eager mode program). 
+ TODO: Come up with a better approach + """ + # If FSDP2 is used, we apply FSDP-specific passes. + if any( + is_fallback_op( + x.node, + { + torch.ops.fsdp.all_gather_copy_in.default, + torch.ops.fsdp.chunk_cat.default, + }, + ) + for x in nodes + ): + nodes = enforce_comm_ordering_for_fsdp(nodes, name_to_buf, name_to_fused_node) + + comm_nodes = [n for n in nodes if contains_collective(n)] + + for i in range(1, len(comm_nodes)): + # Enforce ordering by making previous comm a `WeakDep` dependency of the next comm + mutating_buf = next(iter(comm_nodes[i].get_buffer_names())) + for buf in comm_nodes[i - 1].get_buffer_names(): + comm_nodes[i].add_fake_dep(WeakDep(buf, mutating_buf=mutating_buf)) + + return nodes + + +def estimate_op_runtime(snode: BaseSchedulerNode) -> float: + """ + Returns estimated op runtime in nanoseconds (ns) + """ + if config.estimate_op_runtime == "default": + runtime = snode.get_estimated_runtime() + else: + assert callable(config.estimate_op_runtime) + runtime = config.estimate_op_runtime(snode) + return runtime + + +def node_summary(snode): + detail = "" + if isinstance(snode.node, ir.ExternKernelOut): + detail = f" ({snode.node.python_kernel_name})" + out_tensor_info = "" + if ( + hasattr(snode.node, "layout") + and hasattr(snode.node.layout, "size") + and hasattr(snode.node.layout, "stride") + ): + out_tensor_info = ( + f" (size={snode.node.layout.size}, stride={snode.node.layout.stride})" + ) + node_name = "" + if hasattr(snode.node, "name"): + node_name = snode.node.name + return f"{snode.node.__class__.__name__}{detail}{out_tensor_info} ({node_name})" + + +def visualize_overlap(order): + total_est_runtime: float = 0.0 + cur_comm_node = None + for snode in order: + if cur_comm_node is None: + if contains_collective(snode): + total_est_runtime += estimate_op_runtime(snode) + cur_comm_node = snode.node + elif is_wait(snode.node): + raise AssertionError( + "Wait is not expected when there is no collective running" + ) + else: # exposed compute op + total_est_runtime += estimate_op_runtime(snode) + overlap_log.debug(f"{node_summary(snode)}") # noqa: G004 + else: # cur_comm_node is not None + if contains_collective(snode): + raise AssertionError( + "Found two collectives running at the same time. " + "`visualize_overlap` needs to be updated to handle this case" + ) + elif is_wait(snode.node): # end of this comm op + overlap_log.debug(f"{node_summary(snode)}") # noqa: G004 + cur_comm_node = None + else: # overlapped compute op + overlap_log.debug(f"| {node_summary(snode)}") # noqa: G004 + overlap_log.debug( + f"Est. 
runtime (ms): {total_est_runtime / 1000 / 1000}" # noqa: G004 + ) + + +def reorder_compute_and_comm_for_overlap( + snodes: List[BaseSchedulerNode], +) -> List[BaseSchedulerNode]: + order = snodes + + for p in config.reorder_for_compute_comm_overlap_passes: + if isinstance(p, str) and p in globals(): + p = globals()[p] # it is a builtin pass + if torch.distributed.get_rank() == 0: + overlap_log.debug( + f"==== Visualize overlap before reordering pass {p} ====" # noqa: G004 + ) + try: + visualize_overlap(order) + except Exception as e: + overlap_log.debug(str(e)) + order = p(order) # type: ignore[operator] + if torch.distributed.get_rank() == 0: + overlap_log.debug( + f"==== Visualize overlap after reordering pass {p} ====" # noqa: G004 + ) + try: + visualize_overlap(order) + except Exception as e: + overlap_log.debug(str(e)) + return order + + +def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None: + try: + import torch.distributed._composable.fsdp._fsdp_collectives + + assert torch.distributed.is_available() + # Assert existence of these ops + assert ( + torch.ops._c10d_functional.all_gather_into_tensor + and torch.ops._c10d_functional.all_gather_into_tensor_out + ) + except (ImportError, AttributeError, AssertionError): + return + + from .pattern_matcher import ( + CallFunction, + KeywordArg, + Match, + PatternMatcherPass, + register_graph_pattern, + ) + + """ + all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...); + getitem = all_gather_copy_in[0]; + (getitem_1 = all_gather_copy_in[1];) # optional + + all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem, ...); + + -> + + all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...); + getitem = all_gather_copy_in[0]; + getitem_1 = all_gather_copy_in[1]; + + all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor_out.default(getitem, ..., out=getitem_1); + """ + + def remove_unused_getitem(g): + # Remove `getitem_X = all_gather_copy_in[1]` which is never used. 
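+        # A node is erased when its target is operator.getitem, its first argument is
+        # the torch.ops.fsdp.all_gather_copy_in.default call node, and it selects
+        # output index 1 -- i.e. the unused second output described above.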
+ node_list = list(g.nodes) + for n in node_list: + if ( + n.target == operator.getitem + and n.args[0].target is torch.ops.fsdp.all_gather_copy_in.default + and n.args[1] == 1 + ): + g.erase_node(n) + + graph_pass = PatternMatcherPass() + + @register_graph_pattern( + CallFunction( + torch.ops._c10d_functional.all_gather_into_tensor.default, + CallFunction( + operator.getitem, + CallFunction( + torch.ops.fsdp.all_gather_copy_in.default, + KeywordArg("all_gather_inputs"), + KeywordArg("inp_split_sizes"), + KeywordArg("all_gather_input_numel"), + KeywordArg("world_size"), + KeywordArg("rank"), + KeywordArg("dtype"), + KeywordArg("device"), + ), + KeywordArg("item_idx"), + ), + KeywordArg("group_size"), + KeywordArg("group_name"), + ), + pass_dict=graph_pass, + extra_check=lambda match: match.kwargs["item_idx"] == 0, + ) + def reinplace_all_gather(match: Match, *args, **kwargs): + def repl( + *args, + ): + copy_in_args = args[:-2] + group_size = args[-2] + group_name = args[-1] + all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default( + *copy_in_args + ) + getitem = all_gather_copy_in[0] + getitem_1 = all_gather_copy_in[1] + all_gather_into_tensor = ( + torch.ops._c10d_functional.all_gather_into_tensor_out.default( + getitem, group_size, group_name, out=getitem_1 + ) + ) + return all_gather_into_tensor + + match.replace_by_example( + repl, + [ + kwargs["all_gather_inputs"], + kwargs["inp_split_sizes"], + kwargs["all_gather_input_numel"], + kwargs["world_size"], + kwargs["rank"], + kwargs["dtype"], + kwargs["device"], + kwargs["group_size"], + kwargs["group_name"], + ], + ) + + remove_unused_getitem(graph) + graph_pass.apply(graph) # type: ignore[arg-type] + + +def get_op_idx(snode): + assert not isinstance( + snode, + ( + torch._inductor.scheduler.FusedSchedulerNode, + torch._inductor.scheduler.GroupedSchedulerNode, + ), + ) + return int(snode.get_name()[2:]) + + +def enforce_comm_ordering_for_fsdp( + snodes: List[torch._inductor.scheduler.BaseSchedulerNode], + name_to_buf: Dict[str, torch._inductor.scheduler.SchedulerBuffer], + name_to_fused_node: Dict[str, BaseSchedulerNode], +) -> List[torch._inductor.scheduler.BaseSchedulerNode]: + from . 
import scheduler + + new_order: list[BaseSchedulerNode] = [] + scheduled = set() + ag_exists = False + rs_exists = False + ag_grouped_node_to_wait_grouped_node = {} + rs_grouped_node_to_wait_grouped_node = {} + snode_name_to_final_snode = {} + + def _create_group_node(snodes_to_group): + group_node = scheduler.GroupedSchedulerNode.create(snodes_to_group) + for snode in snodes_to_group: + snode_name_to_final_snode[snode.get_name()] = group_node + snode_name_to_final_snode[group_node.get_name()] = group_node + return group_node + + # Create grouped nodes for specific sets of ops + for snode in snodes: + # Case 1: Handle AllGather + if is_collective( + snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor_out.default + ) and any( + is_fallback_op( + name_to_fused_node[x].node, torch.ops.fsdp.all_gather_copy_in.default + ) + for x in snode.ancestors + ): + ag_exists = True + ag_snode = snode + ag_related_snode_set: set[scheduler.BaseSchedulerNode] = set() + + # Find the "cast + copy_in + getitem + all_gather" code block + find_recursive_deps_of_node( + ag_snode, + ag_related_snode_set, + name_to_buf, + name_to_fused_node, + ) + + # Find the "all_gather + all_gather_wait_tensor + copy_out + set_" code block + allowed_ops = { + torch.ops._c10d_functional.all_gather_into_tensor_out.default, + torch.ops._c10d_functional.wait_tensor.default, + torch.ops.fsdp.split_with_sizes_copy.default, + torch.ops.aten.set_.source_Tensor, + } + find_recursive_users_of_node( + ag_snode, + ag_related_snode_set, + name_to_buf, + name_to_fused_node, + criteria_cb=lambda x: not ( + isinstance(x, scheduler.NopKernelSchedulerNode) + or ( + isinstance(x, scheduler.ExternKernelSchedulerNode) + and x.node.op_overload in allowed_ops # type: ignore[union-attr] + ) + ), + ) + + # sort nodes by original operation order + ag_related_snodes = sorted( + ag_related_snode_set, key=lambda x: get_op_idx(x) + ) + + # In the "reuse layer" case, some ops in the 2nd all-gather code block could also + # depend on ops in the 1st all-gather code block, and we don't want to group them together. 
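+ # Cut the sorted node list just before the second split_with_sizes_copy (copy_out), so only the current all-gather block is grouped.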
+ end_idx_of_current_ag_block = len(ag_related_snodes) + copy_out_count = 0 + for i in range(len(ag_related_snodes)): + cur_snode = ag_related_snodes[i] + if is_fallback_op( + cur_snode.node, torch.ops.fsdp.split_with_sizes_copy.default + ): + copy_out_count += 1 + if copy_out_count > 1: + end_idx_of_current_ag_block = i + break + + ag_related_snodes = ag_related_snodes[:end_idx_of_current_ag_block] + + # Group "cast + copy_in + getitem + all_gather" into one GroupedSchedulerNode + wait_node_idx = None + for i in range(len(ag_related_snodes) - 1): + if isinstance(ag_related_snodes[i + 1].node, ir._WaitKernel): + wait_node_idx = i + 1 + break + assert wait_node_idx is not None + ag_group_node = _create_group_node(ag_related_snodes[:wait_node_idx]) + + # Group "all_gather_wait_tensor + copy_out + set_" into one GroupedSchedulerNode + ag_wait_group_node = _create_group_node(ag_related_snodes[wait_node_idx:]) + + ag_grouped_node_to_wait_grouped_node[ag_group_node] = ag_wait_group_node + + # Case 2: Handle ReduceScatter + elif is_fallback_op(snode.node, torch.ops.fsdp.chunk_cat.default): + rs_exists = True + rs_snode = snode + + # Find the "reduce_scatter copy-in + reduce_scatter comm + reduce_scatter wait" code block + rs_related_snode_set: set[scheduler.BaseSchedulerNode] = set() + find_recursive_users_of_node( + rs_snode, + rs_related_snode_set, + name_to_buf, + name_to_fused_node, + ) + + # sort nodes by original operation order + rs_related_snodes = sorted( + rs_related_snode_set, key=lambda x: get_op_idx(x) + ) + + # Group "reduce_scatter copy-in + reduce_scatter comm" into one GroupedSchedulerNode + wait_node_idx = None + for i in range(len(rs_related_snodes) - 1): + if isinstance(rs_related_snodes[i + 1].node, ir._WaitKernel): + wait_node_idx = i + 1 + break + assert wait_node_idx is not None + rs_group_node = _create_group_node(rs_related_snodes[:wait_node_idx]) + + # Group "reduce_scatter wait + related output nodes" into one GroupedSchedulerNode + rs_wait_group_node = _create_group_node(rs_related_snodes[wait_node_idx:]) + + rs_grouped_node_to_wait_grouped_node[rs_group_node] = rs_wait_group_node + + assert len(snode_name_to_final_snode) > 0 + if ag_exists: + assert len(ag_grouped_node_to_wait_grouped_node) > 0 + if rs_exists: + assert len(rs_grouped_node_to_wait_grouped_node) > 0 + + # Build the new node schedule, taking GroupedSchedulerNode into account + for snode in snodes: + if snode.get_name() in snode_name_to_final_snode: + snode = snode_name_to_final_snode[snode.get_name()] + if snode in scheduled: + continue + new_order.append(snode) + scheduled.add(snode) + + # Enforce AllGather ordering: previous AllGather's "wait then copy_out" group node must run + # before next AllGather's "copy_in then AG" group node + prev_ag_wait = None + for ag_group_node, wait_group_node in ag_grouped_node_to_wait_grouped_node.items(): + if prev_ag_wait is not None: + mutating_buf = next(iter(ag_group_node.get_buffer_names())) + for o in prev_ag_wait.get_outputs(): + ag_group_node.add_fake_dep( + WeakDep(o.get_name(), mutating_buf=mutating_buf) + ) + prev_ag_wait = wait_group_node + + # Enforce ReduceScatter ordering: previous ReduceScatter's "wait" group node must run + # before next ReduceScatter's "copy_in then RS" group node + prev_rs_wait = None + for rs_group_node, wait_group_node in rs_grouped_node_to_wait_grouped_node.items(): + if prev_rs_wait is not None: + mutating_buf = next(iter(rs_group_node.get_buffer_names())) + for o in prev_rs_wait.get_outputs(): + rs_group_node.add_fake_dep( + 
WeakDep(o.get_name(), mutating_buf=mutating_buf) + ) + prev_rs_wait = wait_group_node + + return new_order # type: ignore[return-value] diff --git a/lib/python3.10/site-packages/torch/_inductor/compile_fx.py b/lib/python3.10/site-packages/torch/_inductor/compile_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..db4c1e0eddfeceb3b65a7551c9e3615eabebc257 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/compile_fx.py @@ -0,0 +1,1629 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import contextlib +import functools +import io +import itertools +import logging +import os +import sys +import time +import warnings +from itertools import count +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from unittest import mock + +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch.fx +import torch.utils._pytree as pytree +from functorch.compile import min_cut_rematerialization_partition +from torch._dynamo import ( + compiled_autograd, + config as dynamo_config, + logging as dynamo_logging, + utils as dynamo_utils, +) +from torch._dynamo.device_interface import get_interface_for_device +from torch._dynamo.repro.after_aot import wrap_compiler_debug +from torch._dynamo.utils import ( + counters, + detect_fake_mode, + flatten_graph_inputs, + lazy_format_graph_code, +) +from torch._functorch import config as functorch_config +from torch._functorch.aot_autograd import aot_export_module, make_boxed_func +from torch._inductor.codecache import ( + _StrideExprStr, + code_hash, + CompiledFxGraph, + FxGraphCache, +) +from torch._inductor.cudagraph_utils import ( + BoxedDeviceIndex, + CudagraphCachedInfo, + get_placeholder_info, + log_cudagraph_skip_and_bump_counter, + PlaceholderInfo, +) +from torch._inductor.debug import save_args_for_compile_fx_inner +from torch._inductor.runtime.runtime_utils import cache_dir +from torch._inductor.utils import ( + BoxedBool, + count_tangents, + fresh_inductor_cache, + InputType, + is_gpu, + should_assume_input_aligned, + tensor_is_aligned, +) +from torch._logging import trace_structured +from torch._ops import OpOverload +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymExprPrinter +from torch.fx.passes.fake_tensor_prop import FakeTensorProp +from torch.monitor import _WaitCounter + +from .._dynamo.backends.common import aot_autograd +from ..fx._lazy_graph_module import _use_lazy_graph_module # type: ignore[attr-defined] +from ..fx.graph import _PyTreeCodeGen +from . 
import config, metrics +from .debug import DebugContext +from .decomposition import select_decomp_table +from .fx_passes.joint_graph import joint_graph_passes +from .fx_passes.post_grad import post_grad_passes, view_to_reshape +from .fx_passes.pre_grad import pre_grad_passes +from .graph import GraphLowering +from .ir import ExternKernelNode +from .utils import ( + align_inputs_from_check_idxs, + clone_preserve_strides, + copy_misaligned_inputs, + get_cloned_parameter_buffer_name, + has_incompatible_cudagraph_ops, + maybe_get_suppress_shape_guards_ctx, + output_node, + remove_unaligned_input_idxs, + shape_env_from_inputs, +) +from .virtualized import V + + +if config.is_fbcode(): + from torch._inductor.fb.utils import log_optimus_to_scuba, time_and_log +else: + # no-op decorator + def time_and_log(attr: str): + return dynamo_utils.identity + + +log = logging.getLogger(__name__) +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") +post_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "post_grad_graphs") +static_inputs_log = torch._logging.getArtifactLogger( + __name__, "cudagraph_static_inputs" +) + + +# copy_ fails when trying to write to tensors with memory overlap, +# for expanded dimensions (a dimension which used to have size 1 -> ?) +# we can select one element from that dimension and write to it +# to achieve writing to all values of that dimension of the input tensor +def get_expanded_dims(t): + if not isinstance(t, torch.Tensor): + return None + return [i for i in range(t.ndim) if t.stride(i) == 0 and t.size(i) != 1] + + +def index_expanded_dims(t: torch.Tensor, expanded_dims: List[int]) -> torch.Tensor: + for expanded_dim in expanded_dims: + t = torch.ops.aten.slice(t, expanded_dim, 0, 1) + return t + + +def complex_memory_overlap(t: torch.Tensor) -> bool: + # if torch._debug_has_internal_overlap thinks this tensor potentially has + # memory overlap internally, let's dig deeper to find out whether it's true. + # + # Call squeeze() so that dimension with size 1 does not cause false positive. + t = index_expanded_dims(t, get_expanded_dims(t)).squeeze() + if torch._debug_has_internal_overlap(t) != 0: + strides = t.stride() + sizes = t.shape + indices = list(range(len(strides))) + indices = [x for _, x in sorted(zip(strides, indices))] + for i in range(len(strides)): + prev_stride = 1 if i == 0 else strides[indices[i - 1]] + prev_size = 1 if i == 0 else sizes[indices[i - 1]] + if strides[indices[i]] < prev_stride * prev_size: + return True + return False + + +def get_static_input_idxs(num_fixed): + # If we are inlining NNModules, we treat all torch.nn.Parameters as static for the purposes + # of cudagraphs. Rather than copying these into cudagraph-owned memory + # like we do for normal inputs on each run, we will re-record a cudagraph if these + # parameter locations change. + context = torch._guards.TracingContext.try_get() + fixed = list(range(num_fixed)) + if not context or not context.fw_metadata: + return fixed + + return fixed + context.fw_metadata.static_input_indices + + +@functools.lru_cache(None) +def _step_logger(): + return dynamo_logging.get_step_logger(log) + + +@functools.lru_cache(None) +def _warn_tf32_disabled(): + if ( + torch.cuda.is_available() + and not torch.backends.cuda.matmul.allow_tf32 + and torch.cuda.get_device_capability() >= (8, 0) + ): + warnings.warn( + "TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. 
" + "Consider setting `torch.set_float32_matmul_precision('high')` for better performance." + ) + + +def _unlift_graph(mod, gm, graph_signature): + from torch.export.unflatten import _assign_attr, _AttrKind + + state_dict = {} + for name, param in mod.named_parameters(remove_duplicate=False): + state_dict[name] = param + _assign_attr( + param, + gm, + name, + attr_kind=_AttrKind.PARAMETER, + ) + for name, buffer in mod.named_buffers(remove_duplicate=False): + state_dict[name] = buffer + _assign_attr( + buffer, + gm, + name, + attr_kind=_AttrKind.BUFFER, + ) + + placeholder_nodes = gm.graph.find_nodes(op="placeholder") + lifted_inputs = [] + + # In AOTI, module parameters and buffers are not lifted as graph inputs. + # As a result, mutation to buffers has side effect which makes their initial + # values different from Eager. So we clone them here as a copy. + # We are not cloning for parameters, although it will be needed if we want to + # support training. + for node in placeholder_nodes: + node_name = node.name + if node_name in graph_signature.inputs_to_parameters: + parameter_name = graph_signature.inputs_to_parameters[node_name] + lifted_inputs.append(parameter_name) + elif node_name in graph_signature.inputs_to_buffers: + buffer_name = graph_signature.inputs_to_buffers[node_name] + lifted_inputs.append(buffer_name) + gm.meta[ + get_cloned_parameter_buffer_name(buffer_name) + ] = clone_preserve_strides(state_dict[buffer_name]) + else: + assert node_name in graph_signature.user_inputs + lifted_inputs.append(None) + + from torch.export._unlift import _unlift + + outputs = list(gm.graph.nodes)[-1].args[0] + mutated_outputs = [] + buffer_mutations = graph_signature.buffers_to_mutate + user_input_mutations = graph_signature.user_inputs_to_mutate + output_tokens = graph_signature.output_tokens + for idx, out in enumerate(outputs): + value = None + + if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens): + if out.name in buffer_mutations: + value = buffer_mutations[out.name] + elif out.name in user_input_mutations: + value = user_input_mutations[out.name] + + mutated_outputs.append(value) + + unlifted_gm = _unlift( + gm, + lifted_inputs, + mutated_outputs, + pytree.LeafSpec(), + None, + state_dict, + {}, + ) + return unlifted_gm + + +def _get_subgraph_names(gm): + for node in sorted( + itertools.chain( + gm.graph.find_nodes(op="call_function", target=torch.ops.higher_order.cond), + gm.graph.find_nodes( + op="call_function", target=torch.ops.higher_order.while_loop + ), + ) + ): + if node.target == torch.ops.higher_order.cond: + true_subgraph_name = node.args[1].name + false_subgraph_name = node.args[2].name + yield true_subgraph_name + yield false_subgraph_name + elif node.target == torch.ops.higher_order.while_loop: + cond_subgraph_name = node.args[0].name + body_subgraph_name = node.args[1].name + yield cond_subgraph_name + yield body_subgraph_name + + +def _recursive_pre_grad_passes(gm, example_inputs): + for subgraph_name in _get_subgraph_names(gm): + subgraph = getattr(gm, subgraph_name) + # as we don't have recursive example inputs, passing None here + new_subgraph = _recursive_pre_grad_passes(subgraph, example_inputs=None) + setattr(gm, subgraph_name, new_subgraph) + return pre_grad_passes(gm, example_inputs) + + +def _recursive_joint_graph_passes(gm): + for subgraph_name in _get_subgraph_names(gm): + subgraph = getattr(gm, subgraph_name) + _recursive_joint_graph_passes(subgraph) + joint_graph_passes(gm) + + +def _recursive_post_grad_passes(gm, is_inference: 
bool = False): + for subgraph_name in _get_subgraph_names(gm): + subgraph = getattr(gm, subgraph_name) + _recursive_post_grad_passes(subgraph, is_inference) + post_grad_passes(gm, is_inference) + + +def split_const_gm( + gm: torch.fx.GraphModule, + lifted_constants: Optional[Dict[str, Any]] = None, + skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, +) -> Tuple[torch.fx.GraphModule, Dict[str, int]]: + """ + This function takes an GraphModule input "gm". + The gm will be split into 2 components, + 1) const_gm, which consists the subgraph of gm that can be constant folded. + 2) gm (being inplace modified,) which returns the graph after constant folding. + + If an additional "lifted_constants" argument is passed in, we will assume the gm has + been lifted and run the transformation accordingly. + + When a "skip_folding_node_fn" callback is passed, we will skip constant folding on + the nodes for which the callback returns True. + + const_output_index is a mapping of corresponding node name from gm to the + output index of const_gm. + Returns (const_gm, const_output_index) + """ + from torch._inductor.constant_folding import ( + CONST_MODULE_TAG, + META_TAG, + MODULE_TAG, + replace_node_with_constant, + run_and_get_constant_graph, + ) + + const_gm, const_result = run_and_get_constant_graph( + gm, lifted_constants, skip_folding_node_fn + ) + + const_outputs = { + x.name: idx for idx, x in enumerate(tuple(const_gm.graph.nodes)[-1].args[0]) + } + + to_erase_node = [] + to_replace_node = [] + const_output_index = {} + for node in gm.graph.nodes: + if node.name in const_outputs: + to_replace_node.append(node) + elif node.meta[META_TAG] == CONST_MODULE_TAG and node.op != "placeholder": + to_erase_node.append(node) + + for node in to_replace_node: + new_const_name = "_FOLDED_CONST_" + node.name + replace_node_with_constant( + gm, + node, + const_result[const_outputs[node.name]], + new_const_name, + ) + const_output_index[new_const_name] = const_outputs[node.name] + for node in to_erase_node[::-1]: + if node.users: + for n in node.users: + assert n.meta[META_TAG] == MODULE_TAG, f"node: {node} user not empty." + else: + gm.graph.erase_node(node) + gm.recompile() + + return const_gm, const_output_index + + +def is_tf32_warning_applicable(gm: torch.fx.GraphModule): + aten = torch.ops.aten + tf32_ops = { + aten.mm.default, + aten.addmm.default, + aten.bmm.default, + aten.baddbmm.default, + } + for target in tf32_ops: + for node in gm.graph.find_nodes(op="call_function", target=target): + if ( + isinstance(node.meta.get("val", None), torch.Tensor) + and node.meta["val"].dtype == torch.float32 + and node.meta["val"].device.type == "cuda" + ): + return True + return False + + +def maybe_disable_comprehensive_padding(example_inputs: List[torch.Tensor]): + """ + For CPU backend, enable comprehensive padding causes some unit tests + fail due to changing number of generated kernels. Skip for now. + """ + has_gpu = any( + is_gpu(t.device.type) for t in example_inputs if isinstance(t, torch.Tensor) + ) + + if config.disable_padding_cpu and config.comprehensive_padding and not has_gpu: + perf_hint_log.info("Skip comprehensive padding on CPU") + return config.patch(comprehensive_padding=False) + else: + return contextlib.nullcontext() + + +def fake_tensor_prop( + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + force_allow_non_fake_inputs: bool = False, +): + """ + If we can not detect fake mode from the context of inputs, create one. + + The created fake mode will be returned. 
+ """ + fake_mode = detect_fake_mode(example_inputs) + if not fake_mode: + fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True) + FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs) + else: + ctx = ( + contextlib.nullcontext() + if not force_allow_non_fake_inputs + else mock.patch.object(fake_mode, "allow_non_fake_inputs", True) + ) + with ctx: # type: ignore[attr-defined] + FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs( + *example_inputs + ) + + return fake_mode + + +def should_use_remote_fx_graph_cache(): + if config.fx_graph_remote_cache is not None: + return config.fx_graph_remote_cache + if not config.is_fbcode(): + return False + + if torch._utils_internal.is_fb_unit_test(): + return False + + try: + from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION + except ModuleNotFoundError: + return False + + jk_name = "pytorch/remote_cache:fx_graph_memcache_version" + if torch.version.hip is not None: + jk_name = "pytorch/remote_cache:fx_graph_memcache_version_amd" + + return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int(jk_name) + + +# pass config dict back to user +def get_patched_config_dict(config_patches=None) -> Dict[str, Any]: + with config.patch(config_patches): + return config.get_config_copy() + + +@contextlib.contextmanager +def with_fresh_cache_if_config(): + if config.force_disable_caches: + # Don't delete the cache dir because it has to survive beyond the + # compile_fx call. Let's put the temp dirs under the default cache + # dir so they're easier to locate. + with fresh_inductor_cache(dir=cache_dir(), delete=False): + yield + else: + yield + + +def compile_fx_inner(*args, **kwargs): + # Need with_fresh_cache_if_config for compile_fx_inner even if we already have one for + # compile_fx. The reason is the compilation for backward graph may happen after + # compile_fx return and we may want to use the _LazyGraphModule for compiling + # the backward graph as well. + with contextlib.ExitStack() as stack: + stack.enter_context(torch.utils._python_dispatch._disable_current_modes()) + stack.enter_context(_use_lazy_graph_module(dynamo_config.use_lazy_graph_module)) + stack.enter_context( + dynamo_utils.dynamo_timed( + "compile_fx_inner", phase_name="inductor_compile", fwd_only=False + ) + ) + stack.enter_context(with_fresh_cache_if_config()) + stack.enter_context(DebugContext()) + + return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")( + *args, **kwargs + ) + + +@time_and_log(attr="compilation time (in seconds)") +def _compile_fx_inner( + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + cudagraphs: Optional[BoxedBool] = None, + static_input_idxs: Optional[List[int]] = None, + is_backward: bool = False, + graph_id: Optional[int] = None, + cpp_wrapper: bool = False, + aot_mode: bool = False, + is_inference: bool = False, + boxed_forward_device_index: Optional[BoxedDeviceIndex] = None, + user_visible_outputs: Optional[Dict[str, None]] = None, + layout_opt: Optional[bool] = None, + extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None, +) -> Union[CompiledFxGraph, str]: + """ + Inductor API that compiles a single graph. + + If you change the argument list for this function, make sure you + also update the call to save_args_for_compile_fx_inner below accordingly. + """ + if dynamo_utils.count_calls(gm.graph) == 0 and not aot_mode: + # trigger the real recompilation for _LazyGraphModule before returning + # the forward method. 
+ from torch.fx._lazy_graph_module import _LazyGraphModule + + _LazyGraphModule.force_recompile(gm) + return make_boxed_func(gm.forward) + + if static_input_idxs is None: + static_input_idxs = [] + + static_inputs_log.debug("static input idxs compile_fx_inner: %s", static_input_idxs) + + assert isinstance( + next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list) + ), f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}" + + if config.save_args: + save_args_for_compile_fx_inner( + gm, + example_inputs, + cudagraphs=cudagraphs, + static_input_idxs=static_input_idxs, + is_backward=is_backward, + graph_id=graph_id, + cpp_wrapper=cpp_wrapper, + aot_mode=aot_mode, + is_inference=is_inference, + boxed_forward_device_index=boxed_forward_device_index, + user_visible_outputs=user_visible_outputs, + layout_opt=layout_opt, + ) + + if cudagraphs is None: + cudagraphs = BoxedBool(config.triton.cudagraphs) + + # Inputs to fx_codegen_and_compile + # Anything that affects codegen should go here, so if the signature + # of fx_codegen_and_compile changes, the dict should be updated accordingly + graph_kwargs = { + "cudagraphs": cudagraphs, + "static_input_idxs": static_input_idxs, + "is_backward": is_backward, + "graph_id": graph_id, + "cpp_wrapper": cpp_wrapper, + "aot_mode": aot_mode, + "is_inference": is_inference, + "user_visible_outputs": user_visible_outputs, + "layout_opt": layout_opt, + "extern_node_serializer": extern_node_serializer, + } + + start = time.time() + + fx_graph_remote_cache = should_use_remote_fx_graph_cache() + + inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs) # type: ignore[arg-type] + + def codegen_and_compile( + gm, + example_inputs, + inputs_to_check, + fx_kwargs, + ): + """ + This function calls fx_codegen_and_compile and also adds some extra metadata to the resulting + compiled fx graph. The metadata is saved to FXGraphCache. + """ + compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs) + if isinstance(compiled_graph, str): + # We only return a string in aot mode, in which case we don't + # need to do any post-compilation steps: we just return the string, + # which is the filename of the compiled code. 
+ return compiled_graph + cudagraph_info = None + if cudagraphs: + # check cudagraph disabling reasons from inductor lowering + if compiled_graph.disabled_cudagraphs_reason: + if "cuda" in compiled_graph.device_types: + log_cudagraph_skip_and_bump_counter( + f"skipping cudagraphs due to {compiled_graph.disabled_cudagraphs_reason}" + ) + else: + counters["inductor"]["cudagraph_skips"] += 1 + BoxedBool.disable(cudagraphs) + else: + complex_memory_overlap_inputs = any( + complex_memory_overlap(t) + for t in example_inputs + if isinstance(t, torch.Tensor) + ) + + if not config.triton.cudagraph_support_input_mutation: + # Skip supports for cudagraph-managed tensors + from torch._inductor.cudagraph_utils import ( + check_for_mutation_ignore_cuda_graph_managed_tensor, + ) + + has_mutation_str = ( + check_for_mutation_ignore_cuda_graph_managed_tensor( + gm, + compiled_graph, + static_input_idxs, # type:ignore[arg-type] + ) + ) + has_mutation = has_mutation_str is not None + + if has_mutation: + compiled_graph.disabled_cudagraphs_reason = has_mutation_str + else: + # Check mutation later to support cudagraph-managed tensors + has_mutation = None + + cudagraph_tests = [ + (not has_mutation, "mutated inputs"), + (not has_incompatible_cudagraph_ops(gm), "incompatible ops"), + (not complex_memory_overlap_inputs, "complex memory overlap"), + ( + all( + isinstance(t, (torch.Tensor, torch.SymInt)) + for t in example_inputs + ), + "non-Tensor inputs", + ), + ] + output = output_node(gm) + # output args are tuple of first argument + assert len(output.args) == 1 + stack_traces = [ + (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None) + for arg in output.args[0] + ] + cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b] + placeholders = tuple(get_placeholder_info(gm.graph)) + cudagraph_info = CudagraphCachedInfo( + placeholders, stack_traces, cudagraph_fail_reasons + ) + + compiled_graph.cudagraph_info = cudagraph_info + compiled_graph.inputs_to_check = inputs_to_check + compiled_graph.fx_kwargs = fx_kwargs + # TODO: should this be part of fx_kwargs + compiled_graph.boxed_forward_device_index = boxed_forward_device_index + return compiled_graph + + with _WaitCounter("pytorch.wait_counter.fx_codegen_and_compile").guard() as _: + if ( + not config.force_disable_caches + and (config.fx_graph_cache or fx_graph_remote_cache) + and not aot_mode + ): + for i, input in enumerate(example_inputs): + if ( + isinstance(input, torch.Tensor) + and input.device.type == "cuda" + and i in static_input_idxs + ): + input._is_inductor_static = True # type: ignore[attr-defined] + + compiled_graph = FxGraphCache.load( + codegen_and_compile, + gm, + example_inputs, + graph_kwargs, + inputs_to_check, + local=config.fx_graph_cache, + remote=fx_graph_remote_cache, + ) + else: + compiled_graph = codegen_and_compile( + gm, example_inputs, inputs_to_check, graph_kwargs # type: ignore[arg-type] + ) + if aot_mode: + # AOT mode is special because codegen_and_compile returns a string. + # In that case, we don't need to run all post compilation steps, we just need + # to return the string directly. 
+ return compiled_graph + compiled_graph = FxGraphCache.post_compile( + compiled_graph, example_inputs, cudagraphs + ) + + log.debug("FX codegen and compilation took %.3fs", time.time() - start) + + _step_logger()( + logging.INFO, + "torchinductor done compiling " + f"{'BACKWARDS' if is_backward else 'FORWARDS'} " + f"graph {graph_id}", + ) + # aot autograd needs to know to pass in inputs as a list + compiled_graph._boxed_call = True + return compiled_graph + + +def fx_codegen_and_compile( + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + cudagraphs: Optional[BoxedBool] = None, + static_input_idxs: Optional[List[int]] = None, + is_backward: bool = False, + graph_id: Optional[int] = None, + cpp_wrapper: bool = False, + aot_mode: bool = False, + is_inference: bool = False, + # Use a dict with None value rather than a set for deterministic + # iteration order just in case. + user_visible_outputs: Optional[Dict[str, None]] = None, + layout_opt: Optional[bool] = None, + extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None, +) -> Union[CompiledFxGraph, str]: + if (sleep_sec := config.sleep_sec_TESTING_ONLY) is not None: + import time + + log.warning("Sleeping for %s since sleep_sec_TESTING_ONLY is set", sleep_sec) + time.sleep(sleep_sec) + + with dynamo_utils.preserve_rng_state(): + if is_tf32_warning_applicable(gm): + _warn_tf32_disabled() + + inductor_counters = counters["inductor"].copy() + + # lift the maximum depth of the Python interpreter stack + # to adapt large/deep models + sys.setrecursionlimit(max(sys.getrecursionlimit(), 2000)) + + _step_logger()( + logging.INFO, + "torchinductor compiling " + f"{'BACKWARDS' if is_backward else 'FORWARDS'} " + f"graph {graph_id}", + ) + + def log_graph_runnable(): + fd = io.StringIO() + torch._dynamo.repro.after_aot.save_graph_repro( + fd, gm, example_inputs, "inductor", save_dir=None + ) + return fd.getvalue() + + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_graph_runnable", + "encoding": "string", + }, + payload_fn=lambda: log_graph_runnable(), + ) + + V.debug.fx_graph(gm, example_inputs) + # TODO: Should we actually dump this? It should be redundant with the aot + # structured logs... + # trace_structured("inductor_input_graph", payload_fn=lambda: gm.print_readable(print_output=False)) + + shape_env = shape_env_from_inputs(example_inputs) + + # Convert view to reshape in the graph. This is necessary primarily for + # layout optimization. Do it unconditionally for uniformity. + # + # It's needed because when we do layout optimization, an contiguous tensor + # in eager mode may becomes a channels last tensor. A view op previously + # can be applied to the contiguous tensor may not be able to be applied + # on the channels tensor any more. An error like + # RuntimeError: view size is not compatible with input tensor's size and stride + # (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. + # will be printed. + # + # Replace view op to reshape op in this case. + # As an example, timm_resnest/botnet26t_256/convnext_base etc. will fail if we don't do this. + # + # Also this has to be done before FakeTensorProp below to avoid the failed + # .view() call. + view_to_reshape(gm) + + # It is safe to run FakeTensorProp under no_grad because by the time + # we're in inductor, we assume that AOTAutograd has already "taken care" + # of autograd, so there should be no more autograd-related API's in the + # graph. 
+ with torch.no_grad(): + fake_mode = fake_tensor_prop(gm, example_inputs) + + # pattern matcher passes might not preserve striding information + # on node.meta["val"]. if in the future we rely on these being + # correct we will need to fix. + + with V.set_fake_mode(fake_mode): + # has some issues with memory in training + _recursive_post_grad_passes(gm, is_inference=is_inference) + V.debug.fx_graph_transformed(gm, example_inputs) + post_grad_graphs_log.debug( + "%s", + lazy_format_graph_code( + "AFTER POST GRAD", + gm, + include_stride=True, + include_device=True, + colored=True, + ), + ) + trace_structured( + "inductor_post_grad_graph", + payload_fn=lambda: gm.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) + if config.is_fbcode(): + log_optimus_to_scuba( + extra_logging={"pt2_configs": str(get_patched_config_dict())} + ) + + with V.set_fake_mode(fake_mode), maybe_disable_comprehensive_padding( + example_inputs + ): + const_output_index = None + const_graph = None + const_code = None + + if aot_mode and config.aot_inductor.use_runtime_constant_folding: + const_gm, const_output_index = split_const_gm(gm) + + const_graph = GraphLowering( + const_gm, + example_inputs=[], + shape_env=shape_env, + graph_id=graph_id, + cpp_wrapper=cpp_wrapper, + aot_mode=aot_mode, + user_visible_outputs=user_visible_outputs, + extern_node_serializer=extern_node_serializer, + is_inference=is_inference, + is_const_graph=True, + ) + with V.set_graph_handler(const_graph): + assert cpp_wrapper, "AOT mode only supports C++ wrapper" + const_graph.run() + + const_code, _ = const_graph.codegen_with_cpp_wrapper() + + graph = GraphLowering( + gm, + # example_inputs will be used by AOTInductor to dry-run the generated code for Triton kernel tuning. + # For the forward pass, we have the real inputs to be used as example_inputs. For the backward pass, + # we currently use fake tensors and defake them later. 
+ example_inputs=example_inputs, + shape_env=shape_env, + graph_id=graph_id, + cpp_wrapper=cpp_wrapper, + aot_mode=aot_mode, + user_visible_outputs=user_visible_outputs, + extern_node_serializer=extern_node_serializer, + is_inference=is_inference, + const_output_index=const_output_index, + const_code=const_code, + const_module=const_graph, + ) + metrics_helper = metrics.CachedMetricsHelper() + with V.set_graph_handler(graph): + graph.run(*example_inputs) + output_strides: List[Optional[Tuple[_StrideExprStr, ...]]] = [] + if graph.graph_outputs is not None: + # We'll put the output strides in the compiled graph so we + # can later return them to the caller via TracingContext + p = SymExprPrinter() + for out in graph.graph_outputs: + if ( + hasattr(out, "layout") + and len(free_unbacked_symbols(out.layout.stride)) == 0 + ): + # Convert to string for eval on the load path + output_strides.append( + tuple(p.doprint(s) for s in out.layout.stride) + ) + else: + output_strides.append(None) + + _check_triton_bf16_support(graph) + compiled_fn = graph.compile_to_fn() + num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes() + metrics.num_bytes_accessed += num_bytes + metrics.node_runtimes += node_runtimes + metrics.nodes_num_elem += nodes_num_elem + + if ( + cudagraphs + and config.triton.cudagraph_skip_dynamic_graphs + and not V.graph.disable_cudagraphs_reason + and torch._inductor.utils.any_is_symbolic(*example_inputs) + ): + stack_trace = None + for node in gm.graph.nodes: + meta_val = node.meta.get("val", None) + if ( + node.op == "placeholder" + or not isinstance(meta_val, torch.Tensor) + or not torch._inductor.utils.any_is_symbolic(meta_val) + ): + continue + + if stack_trace := node.meta.get("stack_trace", None): + break + disable = "graph with symbolic shapes inputs and config.triton.cudagraph_skip_dynamic_graphs=True." + if stack_trace: + disable = f"{disable} Found from {stack_trace}\n" + else: + disable = f"{disable}\n" + V.graph.disable_cudagraphs_reason = disable + + if V.aot_compilation is True: + return compiled_fn + + if cudagraphs and not V.graph.disable_cudagraphs_reason: + from torch._inductor.cudagraph_utils import ( + check_lowering_disable_cudagraph, + ) + + V.graph.disable_cudagraphs_reason = ( + check_lowering_disable_cudagraph(V.graph.device_node_mapping) + ) + + compiled_graph = CompiledFxGraph( + compiled_fn, + graph, + output_strides, + V.graph.disable_cudagraphs_reason, + metrics_helper.get_deltas(), + counters["inductor"] - inductor_counters, + ) + + return compiled_graph + + +def get_input_idxs_to_check( + inputs: List[InputType], + static_input_idxs: Sequence[int], +) -> Sequence[int]: + """ + This function runs at compile time, and generates a list of indices for which we + might need to do a copy to preserve alignment requirements. + """ + ids_to_check = [] + + for i, input in enumerate(inputs): + if not isinstance(input, torch.Tensor): + # non-tensors don't need alignment + continue + if not is_gpu(input.device.type): + # right now we only care for gpu tensors + continue + with maybe_get_suppress_shape_guards_ctx(): + # suppress guards so that tensor_is_aligned and should_assume_input_aligned + # do not add guards on input's storage offset + if i in static_input_idxs and tensor_is_aligned(input): + continue + if not should_assume_input_aligned(input): + continue + + # if we get here, then + # (a) our triton code assumes that the input is aligned + # (b) we can't be sure ahead of time that the input will actually be aligned. 
+ # therefore, at runtime, we'll need to check that the input is aligned + # (and if not, clone it to make it aligned.) + ids_to_check.append(i) + + return ids_to_check + + +def cudagraphify( + model: Callable[..., Any], + static_input_idxs: Sequence[int] = (), + *, + device_index: int, + stack_traces: List[Optional[str]], + is_backward: bool, + is_inference: bool, + constants: Tuple[torch.Tensor, ...] = (), + placeholders: Sequence[PlaceholderInfo] = (), + mutated_input_idxs: Tuple[int, ...] = (), +) -> Callable[..., Any]: + from torch._inductor.cudagraph_trees import ( + cudagraphify_impl as new_cudagraphify_impl, + ) + + cudagraphify_fn: Callable[..., Any] + if config.triton.cudagraph_trees: + cudagraphify_fn = functools.partial( + new_cudagraphify_impl, + device_index=device_index, + stack_traces=stack_traces, + is_backward=is_backward, + is_inference=is_inference, + constants=constants, + placeholders=placeholders, + mutated_input_idxs=mutated_input_idxs, + ) + else: + cudagraphify_fn = cudagraphify_impl + + compiled_fn = None + + def run(new_inputs): + nonlocal compiled_fn + if compiled_fn is None: + with dynamo_utils.dynamo_timed( + "cudagraphify" + ), dynamo_utils.preserve_rng_state(): + compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs) + return compiled_fn(new_inputs) + + return run + + +def static_input(x: torch.Tensor) -> torch.Tensor: + """ + Copy and input while preserving strides + """ + return torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device) + + +def index_expanded_dims_and_copy_( + dst: torch.Tensor, + src: torch.Tensor, + expanded_dims: List[int], +): + "Index into expanded dimensions of both dst and src then copy_" + dst = index_expanded_dims(dst, expanded_dims) + src = index_expanded_dims(src, expanded_dims) + dst.copy_(src) + + +def cudagraphify_impl( + model: Callable[..., Any], + inputs: List[torch.Tensor], + static_input_idxs: Sequence[int] = (), +): + """ + Assumes inputs[static_input_idxs[i]] are always the same memory address + """ + check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs) # type: ignore[arg-type] + static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs) # type: ignore[arg-type] + copy_misaligned_inputs(inputs, check_input_idxs) # type: ignore[arg-type] + + assert isinstance(inputs, list) + + inps_expanded_dims = [ + get_expanded_dims(x) if idx not in static_input_idxs else [] + for idx, x in enumerate(inputs) + ] + + # allocate static tensor inputs + static_inputs = [ + x + if not isinstance(x, torch.Tensor) + else static_input(x) + if idx not in static_input_idxs + else x.detach() + for idx, x in enumerate(inputs) + ] + + # copy over input values for fresh allocations + for idx, (x, expanded_dims) in enumerate(zip(inputs, inps_expanded_dims)): + if isinstance(x, torch.Tensor) and idx not in static_input_idxs: + index_expanded_dims_and_copy_(static_inputs[idx], x, expanded_dims) + + # warmup + torch.cuda.synchronize() + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + # copy static_inputs because it will be cleared in model + with torch.cuda.stream(stream): + model(list(static_inputs)) + stream.synchronize() + torch.cuda.current_stream().wait_stream(stream) + torch.cuda.synchronize() + + # record + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream, capture_error_mode="thread_local"): + static_outputs = model(list(static_inputs)) + if not isinstance(static_outputs, (list, tuple)): + static_outputs = (static_outputs,) 
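+ # Two replay paths: with size_asserts we also check that static inputs kept their data_ptr; otherwise we only copy the non-static inputs into the recorded static buffers before replaying the graph.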
+ + if config.size_asserts: + + def run(new_inputs): + assert len(static_inputs) == len(new_inputs) + for idx, (dst, src, expanded_dims) in enumerate( + zip(static_inputs, new_inputs, inps_expanded_dims) + ): + if not isinstance(dst, torch.Tensor): + pass + elif idx in static_input_idxs: + assert dst.data_ptr() == src.data_ptr() + else: + # TODO - could make one single op of multiple slices + # and avoid dispatch. + # Could also pre-index the `dst` tensors + index_expanded_dims_and_copy_(dst, src, expanded_dims) + new_inputs.clear() + graph.replay() + return static_outputs + + else: + copy_indices = [ + idx for idx in range(len(static_inputs)) if idx not in static_input_idxs + ] + + def run(new_inputs): + for idx in copy_indices: + expanded_dims = inps_expanded_dims[idx] + index_expanded_dims_and_copy_( + static_inputs[idx], new_inputs[idx], expanded_dims + ) + new_inputs.clear() + graph.replay() + return static_outputs + + return align_inputs_from_check_idxs(run, check_input_idxs) + + +def compile_fx_aot( + model_: torch.fx.GraphModule, + example_inputs_: List[torch.Tensor], + inner_compile: Callable[..., Any] = compile_fx_inner, + config_patches: Optional[Dict[str, Any]] = None, +): + config_patches: Dict[str, Any] = ( + {"cpp_wrapper": True} + if config_patches is None + else {**config_patches, "cpp_wrapper": True} + ) + if ( + "aot_inductor.output_path" not in config_patches + and not config.aot_inductor.output_path + ): + config_patches = { + **config_patches, + "aot_inductor.output_path": code_hash(model_.code), + } + + extern_node_serializer = config_patches.pop("extern_node_serializer", None) + with V.set_aot_compilation(True): + compiled_lib_path = compile_fx( + model_, + example_inputs_, + inner_compile=functools.partial( + inner_compile, + aot_mode=True, + extern_node_serializer=extern_node_serializer, + ), + config_patches=config_patches, + ) + assert os.path.exists( + compiled_lib_path + ), f"AOTInductor compiled library does not exist at {compiled_lib_path}" + return compiled_lib_path + + +_graph_counter = count(0) + + +def fw_compiler_freezing( + aot_autograd_model: torch.fx.GraphModule, + aot_example_inputs: List[torch.Tensor], + dynamo_model: torch.fx.GraphModule, + num_example_inputs: int, + inner_compile: Callable[..., Any], + cudagraphs: BoxedBool, + graph_id: int, + forward_device: BoxedDeviceIndex, +): + from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze + + # partition_fn won't be called + _recursive_joint_graph_passes(aot_autograd_model) + + layout_opt = GraphLowering.decide_layout_opt(aot_autograd_model, is_inference=True) + if layout_opt: + # make sure meta['val'] is properly setup + fake_tensor_prop(aot_autograd_model, aot_example_inputs, True) + convert_conv_weights_to_channels_last(aot_autograd_model) + + opt_model, preserved_arg_indices = freeze( + dynamo_model, + aot_autograd_model, + aot_example_inputs, # type: ignore[arg-type] + ) + + aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices] + num_fixed = len(preserved_arg_indices) - num_example_inputs + + fake_mode = detect_fake_mode(aot_example_inputs) + + # for freezing, all graph outputs should be user visible + *_, model_outputs_node = opt_model.graph.nodes + model_outputs = model_outputs_node.args[0] + user_visible_outputs = dict.fromkeys( + n.name for n in model_outputs if isinstance(n, torch.fx.Node) + ) + + static_input_idxs = list(range(num_fixed)) + # constant params will be real tensors, not fake + tracing_context = 
torch._guards.TracingContext.try_get() + if tracing_context is not None: + params_flat = tracing_context.params_flat + assert params_flat is not None + for i in range(len(params_flat)): + if i not in preserved_arg_indices: + params_flat[i] = None + + if tracing_context.fw_metadata: + static_input_idxs += tracing_context.fw_metadata.static_input_indices + + with mock.patch.object(fake_mode, "allow_non_fake_inputs", True): + optimized_function = inner_compile( + opt_model, + aot_example_inputs, + static_input_idxs=static_input_idxs, + cudagraphs=cudagraphs, + graph_id=graph_id, + is_inference=True, + boxed_forward_device_index=forward_device, + layout_opt=layout_opt, + user_visible_outputs=user_visible_outputs, + ) + + # aot_inductor codegens a call that takes in just the inputs, so we don't return a wrapper + # that drops constant-ified params + if V.aot_compilation is True: + return optimized_function + + def wrapper(args): + args_new = [args[i] for i in preserved_arg_indices] + args.clear() + return optimized_function(args_new) + + wrapper._boxed_call = True # type: ignore[attr-defined] + + return wrapper + + +def compile_fx( + model_: torch.fx.GraphModule, + example_inputs_: List[torch.Tensor], + inner_compile: Callable[..., Any] = compile_fx_inner, + config_patches: Optional[Dict[str, Any]] = None, + decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None, +): + with _use_lazy_graph_module(dynamo_config.use_lazy_graph_module): + """Main entrypoint to a compile given FX graph""" + if config_patches: + with config.patch(config_patches): + return compile_fx( + model_, + example_inputs_, + # need extra layer of patching as backwards is compiled out of scope + inner_compile=config.patch(config_patches)(inner_compile), + decompositions=decompositions, + ) + + if config.cpp_wrapper: + with config.patch( + { + "cpp_wrapper": False, + # For triton.autotune_at_compile_time, disable by default for + # FBCode, but enabled by default for OSS. + "triton.autotune_at_compile_time": config.triton.autotune_at_compile_time + if config.is_fbcode() + else os.environ.get( + "TORCHINDUCTOR_TRITON_AUTOTUNE_AT_COMPILE_TIME", "1" + ) + == "1", + "triton.autotune_cublasLt": False, + "triton.cudagraphs": False, + "triton.store_cubin": True, + } + ), V.set_real_inputs(example_inputs_): + inputs_ = example_inputs_ + if isinstance(model_, torch.fx.GraphModule): + fake_inputs = [ + node.meta.get("val") + for node in model_.graph.nodes + if node.op == "placeholder" + ] + if all(v is not None for v in fake_inputs): + # Validate devices before switching to fake tensors. + for idx, fi, i in zip(count(), fake_inputs, inputs_): + if fi.device != i.device: + raise ValueError( + f"Device mismatch between fake input and example input at position #{idx}: " + f"{fi.device} vs {i.device}. If the model was exported via torch.export(), " + "make sure torch.export() and torch.aot_compile() run on the same device." 
+ ) + inputs_ = fake_inputs + return compile_fx( + model_, + inputs_, + inner_compile=functools.partial(inner_compile, cpp_wrapper=True), + decompositions=decompositions, + ) + + recursive_compile_fx = functools.partial( + compile_fx, + inner_compile=inner_compile, + decompositions=decompositions, + ) + + if not graph_returns_tuple(model_): + return make_graph_return_tuple( + model_, + example_inputs_, + recursive_compile_fx, + ) + + if isinstance(model_, torch.fx.GraphModule): + if isinstance(model_.graph._codegen, _PyTreeCodeGen): + # this graph is the result of dynamo.export() + return handle_dynamo_export_graph( + model_, + example_inputs_, + recursive_compile_fx, + ) + + model_ = _recursive_pre_grad_passes(model_, example_inputs_) + + if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_): + return flatten_graph_inputs( + model_, + example_inputs_, + recursive_compile_fx, + ) + + assert not config._raise_error_for_testing + num_example_inputs = len(example_inputs_) + cudagraphs = BoxedBool(config.triton.cudagraphs) + forward_device = BoxedDeviceIndex(None) + + graph_id = next(_graph_counter) + + decompositions = ( + decompositions if decompositions is not None else select_decomp_table() + ) + + def fw_compiler_base( + model: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + is_inference: bool, + ): + with dynamo_utils.dynamo_timed("compile_fx..fw_compiler_base"): + return _fw_compiler_base(model, example_inputs, is_inference) + + def _fw_compiler_base( + model: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + is_inference: bool, + ): + if is_inference: + # partition_fn won't be called + _recursive_joint_graph_passes(model) + + fixed = torch._inductor.utils.num_fw_fixed_arguments( + num_example_inputs, len(example_inputs) + ) + + user_visible_outputs = {} + + if config.keep_output_stride: + model_outputs_node = output_node(model) + model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args) + num_model_outputs = len(model_outputs) + + context = torch._guards.TracingContext.try_get() + # See Note [User Outputs in the inductor graph] + if context is not None and context.fw_metadata and not is_inference: + original_output_start_index = ( + context.fw_metadata.num_mutated_inp_runtime_indices + ) + else: + original_output_start_index = 0 + + if isinstance(model_, torch.fx.GraphModule): + *_, orig_model_outputs_node = model_.graph.nodes + assert orig_model_outputs_node.op == "output" + orig_model_outputs, _ = pytree.tree_flatten( + orig_model_outputs_node.args + ) + num_orig_model_outputs = len(orig_model_outputs) + else: + num_orig_model_outputs = num_model_outputs + + assert num_orig_model_outputs <= num_model_outputs + + # Note [User Outputs in the inductor graph] + # We makes the following assumption + # For inference + # len(orig_model_outputs) == len(model_outputs) + # For training + # len(orig_model_outputs) <= len(model_outputs) + # During training, most of the time the model_outputs starts with + # original module's outputs followed by saved activations. + # But this can be not true if the model have inplace updated tensors. + # AOTAutograd will make those tensors being returned before the original + # module's output. + # To make things safe, we'll use original_output_start_index field + # set by AOTAutograd to decide where the original module outputs start. 
+ orig_output_end_idx = ( + original_output_start_index + num_orig_model_outputs + ) + # Sanity chec: we are about to splice out the "user" outputs from the full set + # of "graph" outputs. Make sure we're within bounds. + assert orig_output_end_idx <= num_model_outputs + + user_visible_outputs = dict.fromkeys( + n.name + for n in model_outputs[ + original_output_start_index:orig_output_end_idx + ] + if isinstance(n, torch.fx.Node) + ) + + return inner_compile( + model, + example_inputs, + static_input_idxs=get_static_input_idxs(fixed), + cudagraphs=cudagraphs, + graph_id=graph_id, + is_inference=is_inference, + boxed_forward_device_index=forward_device, + user_visible_outputs=user_visible_outputs, + ) + + fw_compiler = functools.partial(fw_compiler_base, is_inference=False) + + if config.freezing and not torch.is_grad_enabled(): + inference_compiler = functools.partial( + fw_compiler_freezing, + dynamo_model=model_, + num_example_inputs=num_example_inputs, + inner_compile=inner_compile, + cudagraphs=cudagraphs, + graph_id=graph_id, + forward_device=forward_device, + ) + else: + inference_compiler = functools.partial(fw_compiler_base, is_inference=True) + + def partition_fn(graph, joint_inputs, **kwargs): + _recursive_joint_graph_passes(graph) + return min_cut_rematerialization_partition( + graph, joint_inputs, **kwargs, compiler="inductor" + ) + + def bw_compiler( + model: torch.fx.GraphModule, example_inputs: List[torch.Tensor] + ): + with dynamo_utils.dynamo_timed("compile_fx..bw_compiler"): + user_visible_outputs = {} + + if config.bw_outputs_user_visible: + model_outputs_node = output_node(model) + model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args) + user_visible_outputs = dict.fromkeys( + n.name for n in model_outputs if isinstance(n, torch.fx.Node) + ) + fixed = count_tangents(model) + return inner_compile( + model, + example_inputs, + static_input_idxs=list(range(fixed)), + cudagraphs=cudagraphs, + is_backward=True, + graph_id=graph_id, + boxed_forward_device_index=forward_device, + user_visible_outputs=user_visible_outputs, + ) + + # TODO: can add logging before/after the call to create_aot_dispatcher_function + # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func + # once torchdynamo is merged into pytorch + + fake_mode = detect_fake_mode( + example_inputs_ + ) or torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True) + tracing_context = ( + torch._guards.TracingContext.try_get() + or torch._guards.TracingContext(fake_mode) + ) + + if V.aot_compilation is True: + with functorch_config.patch(unlift_effect_tokens=True): + gm, graph_signature = aot_export_module( + model_, + example_inputs_, + trace_joint=False, + decompositions=decompositions, + ) + unlifted_gm = _unlift_graph(model_, gm, graph_signature) + if "dynamo_flat_name_to_original_fqn" in model_.meta: + unlifted_gm.meta["dynamo_flat_name_to_original_fqn"] = model_.meta[ + "dynamo_flat_name_to_original_fqn" + ] + + # Disable amp as in aot_dispatch_autograd (https://github.com/pytorch/pytorch/pull/86515) + # In inference_compiler (fw_compiler_base), _recursive_joint_graph_passes will call into + # _sfdp_init() to register patterns. + # When fallback_random is set to True, the sdpa patterns will be traced during runtime. + # If amp is turned on, the traced FP32 patterns will have prims.convert_element_type which + # will be the same as the generated FP16 patterns. 
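+ # AOT compilation path: compile the unlifted inference graph directly, with autocast disabled (if enabled) as described above.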
+ disable_amp = torch._C._is_any_autocast_enabled() + context = ( + torch._C._DisableAutocast if disable_amp else contextlib.nullcontext + ) + with V.set_fake_mode(fake_mode), compiled_autograd.disable(), context(): + return inference_compiler(unlifted_gm, example_inputs_) + + with V.set_fake_mode(fake_mode), torch._guards.tracing( + tracing_context + ), compiled_autograd.disable(), functorch_config.patch( + unlift_effect_tokens=True + ): + return aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + inference_compiler=inference_compiler, + decompositions=decompositions, + partition_fn=partition_fn, + keep_inference_input_mutations=True, + cudagraphs=cudagraphs, + )(model_, example_inputs_) + + +def graph_returns_tuple(gm: torch.fx.GraphModule): + """True if a FX graph returns a tuple""" + if not isinstance(gm, torch.fx.GraphModule): + return True # can't check this, assume true + (rv,) = output_node(gm).args + if isinstance(rv, (list, tuple)): + return True + if ( + isinstance(rv, torch.fx.node.Node) + and hasattr(rv.target, "_schema") + and len(rv.target._schema.returns) > 1 + and all(str(ret.type) == "Tensor" for ret in rv.target._schema.returns) + ): + # for graphs whose result is one node with multiple outputs + return True + return False + + +def make_graph_return_tuple( + gm: torch.fx.GraphModule, + inputs: List[torch.Tensor], + compile_gm: Callable[..., Any], +): + """ + Mutate gm so it returns a tuple. This is only needed for graphs + not created by torchdynamo that return non-tuples. + """ + node = output_node(gm) + (rv,) = node.args + rv, spec = pytree.tree_flatten(rv) + with gm.graph.inserting_before(node): + gm.graph.output(rv) + gm.graph.erase_node(node) + assert graph_returns_tuple(gm) + + compiled_fn = compile_gm(gm, inputs) + + @functools.wraps(compiled_fn) + def wrapper(*args, **kwargs): + return pytree.tree_unflatten(compiled_fn(*args, **kwargs), spec) + + return wrapper + + +def handle_dynamo_export_graph( + gm: torch.fx.GraphModule, + inputs: List[torch.Tensor], + compile_gm: Callable[..., Any], +): + """ + `torch._dynamo.export` embeds pytrees in the FX graph codegen object, + convert that to a normal FX graph so inductor can compile it. 
+ """ + codegen = gm.graph._codegen + gm.graph._codegen = torch.fx.graph.CodeGen() + gm.recompile() + + compiled_fn = compile_gm(gm, codegen.process_inputs(*inputs)) + + @functools.wraps(compiled_fn) + def wrapper(*args): + return codegen.process_outputs(compiled_fn(*codegen.process_inputs(*args))) + + return wrapper + + +def _check_triton_bf16_support(graph: GraphLowering) -> None: + def warn_and_skip(device) -> None: + from torch._dynamo.exc import SkipFrame + + device_interface = get_interface_for_device(device.type) + device_props = device_interface.get_device_properties(device) + warnings.warn( + f"{device_props.name} does not support bfloat16 compilation natively, skipping" + ) + raise SkipFrame("BF16 is not supported") + + for inp in graph.graph_inputs.values(): + device = getattr(inp, "get_device", lambda: torch.device("meta"))() + if (not is_gpu(device.type)) or inp.get_dtype() != torch.bfloat16: + continue + # Print warning and skip frame if attempting to compile for bfloat16 + # on device without hardware support for dtype + device_interface = get_interface_for_device(device.type) + if device_interface.is_bf16_supported(including_emulation=False): + return + warn_and_skip(device) + + for out in graph.graph_outputs: + device = getattr(out, "get_device", lambda: torch.device("meta"))() + if (not is_gpu(device.type)) or out.get_dtype() != torch.bfloat16: + continue + # Print warning and skip frame if attempting to compile for bfloat16 + # on device without hardware support for dtype + device_interface = get_interface_for_device(device.type) + if device_interface.is_bf16_supported(including_emulation=False): + return + warn_and_skip(device) diff --git a/lib/python3.10/site-packages/torch/_inductor/config.py b/lib/python3.10/site-packages/torch/_inductor/config.py new file mode 100644 index 0000000000000000000000000000000000000000..80b65fbc31af9173e89c7005c67718caa3f51c9c --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/config.py @@ -0,0 +1,1241 @@ +import os # noqa: C101 +import sys +from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union + +import torch + + +def is_fbcode() -> bool: + return not hasattr(torch.version, "git_version") + + +def fx_graph_remote_cache_default() -> Optional[bool]: + if os.environ.get("TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE") == "1": + return True + if os.environ.get("TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE") == "0": + return False + return None + + +def autotune_remote_cache_default() -> Optional[bool]: + if os.environ.get("TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE") == "1": + return True + if os.environ.get("TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE") == "0": + return False + return None + + +# Enable auto_functionalized_v2 (enabled by default) +enable_auto_functionalized_v2 = ( + os.environ.get("TORCHDYNAMO_AUTO_FUNCTIONALIZED_V2", "0") == "1" +) + +# add some debug printouts +debug = False + +# Whether to disable a progress bar for autotuning +disable_progress = True + +# Whether to enable printing the source code for each future +verbose_progress = False + +# use fx aot graph codegen cache +fx_graph_cache = ( + os.environ.get("TORCHINDUCTOR_FX_GRAPH_CACHE", "0" if is_fbcode() else "1") == "1" +) + +# use remote fx aot graph codegen cache +# False: Disables the cache +# True: Enables the cache +# None: Not set -- Off for OSS, JustKnobs based for internal +fx_graph_remote_cache: Optional[bool] = fx_graph_remote_cache_default() + +# enable autotune local cache +autotune_local_cache = True + +# enable autotune remote cache +# False: 
Disables the cache +# True: Enables the cache +# None: Not set -- Off for OSS, JustKnobs based for internal +autotune_remote_cache: Optional[bool] = autotune_remote_cache_default() + +# Force disabled all inductor level caching -- This will override any other caching flag +force_disable_caches = os.environ.get("TORCHINDUCTOR_FORCE_DISABLE_CACHES") == "1" + +# sleep in inductor for testing +sleep_sec_TESTING_ONLY: Optional[int] = None + +# The default layout constraint for custom operators. +# This must be the name of one of the layout constraint tags +# (that is, one of {"needs_fixed_stride_order", "flexible_layout"}), +# If the custom op does not have a layout constraint tag already +# then we assume the following applies. +custom_op_default_layout_constraint = "flexible_layout" + +# use cpp wrapper instead of python wrapper +cpp_wrapper = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1" + +# codegen cpp wrapper code in an ABI compatible mode +abi_compatible = ( + os.environ.get("TORCHINDUCTOR_ABI_COMPATIBLE", "1" if is_fbcode() else "0") == "1" +) + +c_shim_version = os.environ.get("TORCHINDUCTOR_C_SHIM_VERSION", "2") + +# dead code elimination +dce = False + +# assume weight tensors are fixed size +static_weight_shapes = True + +# put correctness assertions in generated code +size_asserts = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1" +nan_asserts = os.environ.get("TORCHINDUCTOR_NAN_ASSERTS") == "1" + +# enable loop reordering based on input orders +pick_loop_orders = True + +# reuse a kernel input as the output +inplace_buffers = True + +# reuse a buffer for an unrelated purpose +allow_buffer_reuse = True + +# Enable pooled allocations for non-output tensors +memory_planning = os.environ.get("TORCHINDUCTOR_MEMORY_PLANNING", "0") == "1" + +# How to organize memory under memory_planning=True: +# - "none": do not try to pool storage, just reuse +# - "intermediates": all non-outputs share storage, outputs each get unique storage +# - "outputs": two pools, one for intermediates (freed on return) and one for outputs +# - "combined": a single pool for both intermediates and outputs +memory_pool = os.environ.get("TORCHINDUCTOR_MEMORY_POOL", "intermediates") + +# codegen benchmark harness +benchmark_harness = True + +# fuse pointwise into templates +epilogue_fusion = True + +# do epilogue fusions before other fusions +epilogue_fusion_first = False + +# enable pattern match+replace optimizations +pattern_matcher = True + +# set to True to enable the back-to-back GEMM pass +b2b_gemm_pass = False + +# register custom graph optimization pass hook. so far, pre/post passes are +# only applied before/after pattern_matcher in post_grad_passes. +# +# def my_custom_pre_pass(graph: torch.fx.graph.Graph): +# # my custom graph optimization pass +# ... +# +# def my_custom_post_pass(graph: torch.fx.graph.Graph): +# # my custom graph optimization pass +# ... +# +# torch._inductor.config.post_grad_custom_pre_pass = my_custom_pre_pass +# torch._inductor.config.post_grad_custom_post_pass = my_custom_post_pass +post_grad_custom_pre_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None +post_grad_custom_post_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None + +# Registers a custom joint graph pass. +joint_custom_pre_pass: Optional[Callable[[torch.fx.Graph], None]] = None +joint_custom_post_pass: Optional[Callable[[torch.fx.Graph], None]] = None + +# Registers a custom pregrad pass. Note that the pre-grad IR is 1. +# non-functional, 2. non-normalized, and 3. prone to change. 
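For concreteness, the registration pattern sketched in the comments above can be exercised as follows; this is an illustrative, purely observational pass, not a recommended optimization:

```python
import torch
import torch._inductor.config as inductor_config

def log_graph_size(graph: torch.fx.Graph) -> None:
    # Custom passes mutate the graph in place; doing nothing is a valid no-op pass.
    print("post-grad graph has", len(list(graph.nodes)), "nodes")

inductor_config.post_grad_custom_post_pass = log_graph_size
compiled = torch.compile(lambda x: (x + 1).relu())
compiled(torch.randn(8))  # first call compiles and invokes the pass
inductor_config.post_grad_custom_post_pass = None  # unregister when done
```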
Ideally we should +# use post-grad passes. +pre_grad_custom_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None + +# Registers a custom pass to be run right before fusion in Inductor scheduler. +# WARNING: Inductor scheduler IR is at prototype stage and subject to change, +# hence custom IR passes built on top of it might break in the future. +_pre_fusion_custom_pass: Optional[ + Callable[ + [List["torch._inductor.scheduler.BaseSchedulerNode"]], + List["torch._inductor.scheduler.BaseSchedulerNode"], + ] +] = None + +# Deprecated +split_cat_fx_passes = True + +# Optimize conv-batchnorm if batchnorm is in eval mode. Slightly reduces numerical stability. +efficient_conv_bn_eval_fx_passes = False + +# Enable predispatch aten IR for export +is_predispatch = False + +# Deprecated +group_fusion = False + +# Deprecated +batch_fusion = True + +# Pre grad fusion and options in order, set to empty dict to disable fusion. +# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions()` to see available fusions. +# batch fusion options: +# batch_linear +# batch_linear_lhs +# batch_layernorm +# batch_tanh +# batch_relu +# batch_sigmoid + +# split cat fusion options: +# normalization_pass +# remove_split_with_size_one_pass +# merge_getitem_cat_pass +# merge_stack_tahn_unbind +# merge_splits_pass +# mutate_cat_pass +# split_cat_pass +pre_grad_fusion_options: Dict[str, Dict[str, Any]] = { + "batch_linear": {}, + "batch_linear_lhs": {}, + "batch_layernorm": {}, + "batch_tanh": {}, + "batch_relu": {}, + "batch_sigmoid": {}, +} + +# Post grad fusion and options, set to empty dict to disable fusion. +# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions(False)` to see available fusions. +post_grad_fusion_options: Dict[str, Dict[str, Any]] = {} + +# enable reordering pass for improving memory locality +reorder_for_locality = True + +# Scale down RBLOCK for better occupancy +dynamic_scale_rblock = os.environ.get("TORCHINDUCTOR_DYNAMIC_SCALE_RBLOCK", "1") == "1" + +# this forces fusion for int_mm with mul. Needed when you want to avoid realizing the int32 +# but the mul gets fused with other pointwise ops instead. +force_fuse_int_mm_with_mul = False + +# for pattern torch.mm(a, b.to(dtype)) with cuda tensors, +# enable torch._inductor.kernel.mm.tuned_mixed_mm fused kernel. +# Autotune will compare perf with normal cast->then->mm option +use_mixed_mm = True + +# enable runtime numeric check for pre/post grad fx passes +# floating point provides limited accuracy (about 7 decimal digits for single precision +# floating point numbers,about 16 decimal digits for double precision floating point numbers) +# according to PyTorch documentation. +# https://pytorch.org/docs/stable/notes/numerical_accuracy.html#batched-computations-or-slice-computations +fx_passes_numeric_check: Dict[str, Any] = { + "pre_grad": False, + "precision": 1e-4, + "num_iterations": 1, + "requires_optimizer": True, +} + +# mixed_mm_choice can be used to control the behaviour for pattern torch.mm(a, b.to(dtype)) with cuda tensors. +# The fallback aten implementation is normal cast->then->mm option. +# If mixed_mm_choice is "default": this flag will be ignored. +# If mixed_mm_choice is "triton": +# - Always use torch._inductor.kernel.mm.tuned_mixed_mm's fused kernel. +# - Autotune will not compare with fallback. +# If mixed_mm_choice is "aten": always use the fallback aten implementation. +# If mixed_mm_choice is "heuristic": +# - Enables the heuristic. 
+# - If the heuristic decides to add a config, it will add the config as the first choice. +# - If autotune is disabled, this config will always be chosen. +# - If autotune is enabled, it will also compare with fallback aten implementation and fused kernel. +# The use_mixed_mm flag will be ignored if mixed_mm_choice != "default". +mixed_mm_choice = "heuristic" + +# enable reordering pass for increasing overlap between compute and communication +reorder_for_compute_comm_overlap = False + +# passes (in execution order) for increasing overlap between compute and communication +# for built-in passes, use string name; for user-defined passes, pass in the function handle +# WARNING: Inductor scheduler IR is at prototype stage and subject to change, +# hence custom IR passes built on top of it might break in the future. +reorder_for_compute_comm_overlap_passes = [ + "reorder_compute_for_overlap", + "sink_waits", + "raise_comms", +] + +# runtime estimation function for ops +# for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle +estimate_op_runtime = "default" + +# unit: GB/s, uni-directional P2P bandwidth per card +# default value is NVLink +intra_node_bw = 300 + +# unit: GB/s, uni-directional P2P bandwidth per node +# default value is InfiniBand +inter_node_bw = 25 + +# enable slow autotuning passes to select algorithms +max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1" + +# enable slow autotuning passes to select pointwise/reductions algorithms +max_autotune_pointwise = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE") == "1" + +# enable slow autotuning passes to select gemm algorithms +max_autotune_gemm = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1" + +# force cublas and triton to use the same precision; cublas supports TF32 for matmul operations +# when m, n, k are multiples of 16, 16, 8, whereas triton supports TF32 for matmul operations +# for any combinations of m, n, k, regardless of their alignment. setting this flag will ensure +# that triton does not use TF32 wherever cublas would not use TF32 +force_same_precision = ( + True if is_fbcode() else os.environ.get("TORCHINDUCTOR_FORCE_SAME_PRECISION") == "1" +) + +# Specify candidate backends for gemm autotune. +# Possible choices are combinations of: ATen, Triton, CUTLASS, CK, CPP. +# ATen: default Pytorch ATen kernels. +# Triton: Triton templates defined in torch inductor (AMD and NVidia GPUs). +# CUTLASS: Cutlass templates and kernels (NVidia GPUs only). +# CK: Composable Kernel templates and kernels (AMD Instinct GPUs only). +# CPP: CPP templates and kernels for CPU. +max_autotune_gemm_backends = os.environ.get( + "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON,CPP" +).upper() + +# As above, specify candidate backends for conv autotune. +# NB: in some cases for 1x1 convs we emit as matmul, +# which will use the backends of `max_autotune_gemm_backends` +max_autotune_conv_backends = os.environ.get( + "TORCHINDUCTOR_MAX_AUTOTUNE_CONV_BACKENDS", "ATEN,TRITON" +).upper() + + +# Specify the size of the search space for GEMM autotuning. 
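As a usage note (a sketch assuming the standard `torch.compile` entry points), the GEMM autotuning knobs above are usually reached either through the `max-autotune` compile mode or by scoping individual flags with the `patch()` helper that `install_config_module` adds at the bottom of this file:

```python
import torch
import torch._inductor.config as inductor_config

# Option 1: the documented compile mode, which enables the slow autotuning passes.
fast_mm = torch.compile(lambda a, b: a @ b, mode="max-autotune")

# Option 2: scope individual knobs for a single compilation.
with inductor_config.patch(max_autotune_gemm=True,
                           max_autotune_gemm_backends="ATEN,TRITON"):
    tuned_mm = torch.compile(lambda a, b: a @ b)
```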
+# DEFAULT - balance between compile time overhead and performance +# EXHAUSTIVE - maximize performance +max_autotune_gemm_search_space = os.environ.get( + "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE", "DEFAULT" +).upper() + +# Whether we fall back to ATen or hard error when no matches are found during autotuning +autotune_fallback_to_aten = ( + os.environ.get("TORCHINDUCTOR_AUTOTUNE_FALLBACK_TO_ATEN", "1") == "1" +) + +# the value used as a fallback for the unbacked SymInts +# that can appear in the input shapes (e.g., in autotuning) +unbacked_symint_fallback = 8192 + +# DEPRECATED, DO NOT USE +search_autotune_cache = False + +save_args = os.environ.get("TORCHINDUCTOR_SAVE_ARGS") == "1" + +# We will disable creating subprocess for autotuning if this is False +autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1" + +# The following three timeouts are applicable if autotune_in_subproc is True: + +# Max time that a a valid benchmark result may take during autotuning +max_autotune_subproc_result_timeout_seconds = 60.0 +# Additional time we allow subprocesses to terminate gracefully after the timeout until we send a SIGTERM +max_autotune_subproc_graceful_timeout_seconds = 1.0 +# Additional time that we grant after a SIGTERM until we do a hard SIGKILL of subprocesses +max_autotune_subproc_terminate_timeout_seconds = 2.0 + +# If autotuning in subprocess, whether to use multiple devices +autotune_multi_device = os.environ.get("TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE") == "1" + +coordinate_descent_tuning = ( + os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_TUNING") == "1" +) +coordinate_descent_check_all_directions = ( + os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_CHECK_ALL_DIRECTIONS") == "1" +) +coordinate_descent_search_radius = int( + os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_RADIUS", "1") +) + +# AutoHeuristic is a framework that allows one to collect data from autotuning, use the data to learn a heuristic, and +# generate the learned heursitic to code which is shipped with the compiler +# Specify a list of comma separated optimizations to collect data for +autoheuristic_collect = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_COLLECT", "") +# Specify a list of comma separated optimizations to use learned heuristics for +autoheuristic_use = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_USE", "mixed_mm") + + +def run_autoheuristic(name: str) -> bool: + return collect_autoheuristic(name) or use_autoheuristic(name) + + +def collect_autoheuristic(name: str) -> bool: + return name in torch._inductor.config.autoheuristic_collect.split(",") + + +def use_autoheuristic(name: str) -> bool: + return name in torch._inductor.config.autoheuristic_use.split(",") + + +# If set to "DEFAULT", this will use the default log path specified in autoheuristic.py. +# If set to another path, autoheuristic will instead log results to the given path. +autoheuristic_log_path = os.environ.get( + "TORCHINDUCTOR_AUTOHEURISTIC_LOG_PATH", "DEFAULT" +) + +# Disabled by default on ROCm, opt-in if model utilises NHWC convolutions +layout_opt_default = "1" if not torch.version.hip else "0" +layout_optimization = ( + os.environ.get("TORCHINDUCTOR_LAYOUT_OPTIMIZATION", layout_opt_default) == "1" +) + +force_layout_optimization = os.environ.get("TORCHINDUCTOR_FORCE_LAYOUT_OPT", "0") == "1" + + +# Whether to keep the output strides the same as eager after layout optimization. 
+keep_output_stride = os.environ.get("TORCHINDUCTOR_KEEP_OUTPUT_STRIDE", "1") == "1" + +# Enabling this will let compiler print warning messages if a generated triton +# kernel has inputs with mixed layouts. This is helpful for perf debugging +# since kernel with mixed layout inputs may run much slower then one whose inputs +# have uniform layouts. +warn_mix_layout = os.environ.get("TORCHINDUCTOR_WARN_MIX_LAYOUT") == "1" + +# control store vs recompute heuristic +# For fanouts, rematerialization can lead to exponential blowup. So, have +# smaller threshold +realize_reads_threshold = 4 +realize_opcount_threshold = 30 + +# Threshold to prevent excessive accumulation of ops in one buffer during lowering +realize_acc_reads_threshold = 8 + +# fallback to eager for random/dropout, this is slow but useful for debugging +fallback_random = False + +# automatically create fallbacks when encountering an unhandled op +implicit_fallbacks = True + +# fuse even in cases without common reads +aggressive_fusion = False + +# For each fused kernel in the wrapper, comment with the nodes that get fused. +# Useful for debugging fusion. +debug_fusion = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1" +benchmark_fusion = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1" +enabled_metric_tables = os.environ.get("TORCHINDUCTOR_ENABLED_METRIC_TABLES", "") +loop_ordering_after_fusion = ( + os.environ.get("TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION", "0") == "1" +) + +# For Triton Templates, select fastest of best template + epilogue vs best template + separate epilogue kernel +benchmark_epilogue_fusion = ( + os.environ.get("TORCHINDUCTOR_BENCHMARK_EPILOGUE_FUSION", "1") == "1" +) + +# Take how many of the top triton kernels to benchmark epilogue +max_epilogue_benchmarked_choices = 1 + +# how many nodes to allow into a single fusion +max_fusion_size = 64 + +# max number of inputs to generate cat as a pointwise op with masked laods +max_pointwise_cat_inputs = 8 + +# replace small reductions with pointwise, disable with `= 1` +unroll_reductions_threshold = 8 + +# Add extra comments to output code (causes compile cache misses) +comment_origin = False + +# Convert 1x1 convs into matmuls +conv_1x1_as_mm = False + +# Enable split reductions for better utilization when the dimension +# being reduced over is large (by splitting it) +split_reductions = True + +benchmark_kernel = os.environ.get("TORCHINDUCTOR_BENCHMARK_KERNEL", "0") == "1" + +# Enable constant and index_expr folding +constant_and_index_propagation = True + +# we always add constants into graph.constants without +# performing any constant-inlining optimization +always_keep_tensor_constants = False + +# assert that indirect indexing does not read / write out of bounds +assert_indirect_indexing = True + +# compute CSE bounds on variables that do not appear in the FX graph +compute_all_bounds = False + +# enable the combo kernel that combines data-independent kernels (additional +# to foreach kernels) into a single one (Experimental) +combo_kernels = False +# benchmark combo kernels and only allow ones with perf gains +benchmark_combo_kernel = False +# combo_kernel autotuning options: 0 - disable, 1 - enable except for foreach, +# 2 - enable for all +combo_kernels_autotune = 1 +# Enable masking for combining kernels of mixed sizes: 0 - disable, 1 - enable +# for all except for foreach, 2 - enable for all +combo_kernel_allow_mixed_sizes = 1 +# Enable dynamic shapes for foreach kernels +combo_kernel_foreach_dynamic_shapes = False + +# constant folding on the 
joint graph +joint_graph_constant_folding = True + +# Enable indirect_indexing asserts for decompositions and lowerings +debug_index_asserts = False + +# Mode to emulate pytorch eager numerics for lower precision (fp16, bf16) +# Pytorch eager computes bf16/fp16 by upcasting inputs to fp32 and downcasting after +# For multiple, fused pointwise nodes, inductor will elide the intermediary upcasts and downcasts +# Typically this should be closer to fp64 ref numerics. However, it can be useful for debugging +# to emulate the eager numerics. +emulate_precision_casts = False + +# warnings intended for PyTorch developers, disable for point releases +is_nightly_or_source = "dev" in torch.__version__ or "git" in torch.__version__ +developer_warnings = is_fbcode() or is_nightly_or_source + +# This pattern matches a special usage of scatter +# 1. It's applied to a constant tensor +# 2. The index tensor has size 1 in the scatter dimension +# Such a pattern generates a sparse matrix when the const tensor is all-zero. +# We can lower this pattern to a pointwise kernel for more fusion opportunities +# and to reduce the memory footprint. +optimize_scatter_upon_const_tensor = ( + os.environ.get("TORCHINDUCTOR_OPTIMIZE_SCATTER_UPON_CONST_TENSOR", "1") == "1" +) + + +# The multiprocessing start method to use for inductor workers in the codecache. +# Can be "subprocess" or "fork". +def decide_worker_start_method() -> str: + start_method = os.environ.get( + "TORCHINDUCTOR_WORKER_START", "fork" if is_fbcode() else "subprocess" + ) + assert start_method in ( + "subprocess", + "fork", + ), f"Invalid start method: {start_method}" + return start_method + + +worker_start_method = decide_worker_start_method() + +# Flags to turn on all_reduce fusion. These 2 flags should be automatically turned +# on by DDP and should not be set by the users. +_fuse_ddp_communication = False +_fuse_ddp_bucket_size = 25 + +# Flag to control which fusion passes to apply. Functions in the list will +# be applied in order. There are two different fusion passes +# --"fuse_ddp_with_concat_op" and "fuse_ddp_with_coalesced_op". The default +# one is "fuse_ddp_with_concat_op". Users can also change this to a customized +# fusion function. +# +# The fusion currently does not support multiple DDP instances with different PGs or +# data types. This feature will be added in future PRs. +# +# "schedule_comm_wait" is used to delay the wait ops to maximize comm/comp +# overlapping. At this moment, this pass performs better than +# reorder_for_compute_comm_overlap_passes but we will add the logic of +# "schedule_comm_wait" in the future and remove the one here. +_fuse_ddp_communication_passes: List[Union[Callable[..., None], str]] = [ + "fuse_ddp_with_concat_op", + "schedule_comm_wait", +] + +_micro_pipeline_tp: bool = False + + +def decide_compile_threads() -> int: + """ + Here is the precedence for deciding compile_threads + 1. User can override it by TORCHINDUCTOR_COMPILE_THREADS. One may want to disable async compiling by + setting this to 1 to make pdb happy. + 2. Set to 1 if it's the win32 platform + 3.
decide by the number of CPU cores + """ + if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ: + return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"]) + elif sys.platform == "win32": + return 1 + elif is_fbcode(): + return 1 + else: + cpu_count = ( + len(os.sched_getaffinity(0)) + if hasattr(os, "sched_getaffinity") + else os.cpu_count() + ) + assert cpu_count + return min(32, cpu_count) + + +compile_threads = decide_compile_threads() + +# gemm autotuning global cache dir +if is_fbcode(): + try: + from libfb.py import parutil + + if __package__: + global_cache_dir = parutil.get_dir_path( + os.path.join(__package__.replace(".", os.sep), "fb/cache") + ) + else: + global_cache_dir = parutil.get_dir_path("fb/cache") + except (ValueError, ModuleNotFoundError): + global_cache_dir = None + +else: + global_cache_dir = None + +# If kernel is fused, the name is generated from the origin node op names +# for larger kernels limit this +kernel_name_max_ops = 10 + +# Pad input tensors of matmul/bmm/addmm to leverage Tensor Cores in NVIDIA GPUs +shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "1") == "1" + +# Control if we will do padding for pointwise/reductions +comprehensive_padding = ( + os.environ.get("TORCHINDUCTOR_COMPREHENSIVE_PADDING", "1") == "1" +) +pad_channels_last = False + +# Disable comprehensive padding on the CPU +disable_padding_cpu = True + +# The width of comprehensive padding, in bytes. +# CUDA max memory transaction size is 128 bytes for a warp. +padding_alignment_bytes = 128 + +# Threshold on the minimum stride that will be padded. +# +# Don't align a too small stride since that causes too much memory increase. +# Pad too small stride may also cause perf loss. We may result in many tiny data blocks +# with gaps in between. That causes less coalesced GPU memory access! +# +# Initially we pick 320 as the threshold since for alignement=16, +# that results in at most 5% memory cost. +# +# But later on we raise the threshold to 1024 to avoid interfere with persistent reduction. +# Let's say an inner reduction has a row size 513. Inductor will generate +# persistent reduction code. +# If we do padding, the strides are not contiguous any more. Inductor +# uses a much smaller threshold for persistent reduction in this case and +# generates potentially worse non-persistent reduction code. +# +# This change turns HF AllenaiLongformerBase amp training from a loss of 1.09x to a win of 1.05x. +# (baseline: 71.09ms, padding w/o this change: 77.38ms, padding with this change: 67.77ms) +padding_stride_threshold = 1024 + +# Enable padding outputs, even if they would not be padded in eager mode. +# By default, we use the same strides as eager mode. +pad_outputs = False + +# Whether to treat output of the backward graph as user visible. +# For user visible outputs, inductor will make sure the stride matches with eager. 
+bw_outputs_user_visible = True + +# Whether to always use shape padding if it is enabled and possible +force_shape_pad: bool = False + +# Fx-based linear/matmul/bmm + permute/transpose vertical fusion +permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1" + +# Mark the wrapper call in PyTorch profiler +profiler_mark_wrapper_call = False + +# Generate hook calls to torch._inductor.hooks.run_intermediate_hooks for +# every intermediate for which we can correlate it with an intermediate +# from the original FX graph +generate_intermediate_hooks = False + +# Populate traceback field on IRNode; good for debugging why origin_node is +# not populated, or finding out where an IRNode was constructed +debug_ir_traceback = False + +# used for debugging to make sure config is properly set +_raise_error_for_testing = False + +_profile_var = os.environ.get("TORCHINDUCTOR_PROFILE", "") +profile_bandwidth = _profile_var != "" +profile_bandwidth_regex = "" if _profile_var == "1" else _profile_var +# Specify a file where we print out the profiling results. +# None means we do not dump results to a file. +profile_bandwidth_output = os.environ.get("TORCHINDUCTOR_PROFILE_OUTPUT", None) +# Switch to do_bench_using_profiling to exclude the CPU overheads +profile_bandwidth_with_do_bench_using_profiling = ( + os.environ.get("TORCHINDUCTOR_PROFILE_WITH_DO_BENCH_USING_PROFILING") == "1" +) + + +# TODO: remove later +disable_cpp_codegen = False + + +# Freezing will attempt to inline weights as constants in optimization +# and run constant folding and other optimizations on them. After freezing, weights +# can no longer be updated. +freezing: bool = os.environ.get("TORCHINDUCTOR_FREEZING", "0") == "1" + +# Make freezing invalidate the eager Parameters of nn modules, to avoid memory overhead +# of potentially keeping multiple copies of weights. +freezing_discard_parameters: bool = False + +# Kill switch for allowing temporary tensors to be allocated as stack arrays. Tests +# should be run with this flag both on and off to make sure we have coverage. +allow_stack_allocation: bool = ( + os.environ.get("TORCHINDUCTOR_STACK_ALLOCATION", "1" if is_fbcode() else "0") == "1" +) + +# Enables an alternate DSO interface (the "minimal ArrayRef interface") intended +# to maximize performance for use cases that it can accommodate at the expense of +# generality. In brief: +# - inputs and outputs are ArrayRefTensor (note that strides are required, but the +# tensor must be contiguous) +# - constant handling is unchanged because it is not a per-inference-iteration bottleneck +# +# When the DSO is generated in this mode, the usual interface will also be supported, +# but performance for that interface may be degraded. +use_minimal_arrayref_interface: bool = False + +# decompose some memory bound matmul/bmm to mul +decompose_mem_bound_mm: bool = False + +# assume_aligned_inputs means that we assume that inputs will be aligned; we generate +# code using this assumption, and clone tensors before use if they aren't aligned. +# In the common case, most inputs will be aligned. +assume_aligned_inputs: bool = False + +# For the user-written Triton kernels compiled with the model, ignore the unsupported +# arguments passed to the @triton.autotune in the user's code; this is unsafe, as +# ignoring the unsupported args may lead to unexpected autotuning behavior: don't +# set unless you know what you're doing. 
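A small usage sketch for the freezing flag defined above (the environment variable and the config attribute are equivalent switches):

```python
import torch
import torch._inductor.config as inductor_config

inductor_config.freezing = True  # same effect as TORCHINDUCTOR_FREEZING=1

model = torch.nn.Linear(16, 16).eval()
with torch.no_grad():
    frozen = torch.compile(model)
    frozen(torch.randn(2, 16))  # weights are inlined as constants during this compile
```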
+unsafe_ignore_unsupported_triton_autotune_args: bool = False + +# When True, we will check in scheduler.py _codegen that there are no "loops" +# in the call stack; that is to say, the same frame multiple times. This +# ensures that a cProfile trace to this frame will be a straight line without +# any cycles. +check_stack_no_cycles_TESTING_ONLY: bool = False + + +# config specific to codegen/cpp.py +class cpp: + # set to torch.get_num_threads() + threads = -1 + + # Do not generate loops when the condition doesn't hold, like: + # for(long i0=4096; i0<4096; i0+=1) + no_redundant_loops = ( + os.environ.get("TORCHINDUCTOR_CPP_NO_REDUNDANT_LOOPS", "1") == "1" + ) + + # Assume number of threads is dynamic, don't specialize thread number. + # Kernels don't recompile on thread number changes with this flag on. + # For single-threaded workload, turning it on would incur a slight + # performance degradation. + dynamic_threads = os.environ.get("TORCHINDUCTOR_CPP_DYNAMIC_THREADS", "0") == "1" + + simdlen: Optional[int] = None + min_chunk_size = int(os.environ.get("TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE", "4096")) + cxx = ( + None, # download gcc12 from conda-forge if conda is installed + # "g++-12", + # "g++-11", + # "g++-10", + # "clang++", + os.environ.get("CXX", "clang++" if sys.platform == "darwin" else "g++"), + # "g++.par", + ) + # Allow kernel performance profiling via PyTorch profiler + enable_kernel_profile = ( + os.environ.get("TORCHINDUCTOR_CPP_ENABLE_KERNEL_PROFILE", "0") == "1" + ) + + # enable weight prepacking to get a better performance; may lead to large memory footprint + weight_prepack = os.environ.get("TORCHINDUCTOR_CPP_WEIGHT_PREPACK", "1") == "1" + + # Inject a bug into our relu implementation; useful for testing our repro + # extraction and minification functionality. + # Valid values: "compile_error", "runtime_error", "accuracy" + inject_relu_bug_TESTING_ONLY: Optional[str] = None + inject_log1p_bug_TESTING_ONLY: Optional[str] = None + + # If None, autodetect whether or not AVX512/AVX2 can be used. Otherwise, + # force usage as specified, without testing. + vec_isa_ok: Optional[bool] = None + + # similar to config.triton.descriptive_names + descriptive_names = "original_aten" + + # how many nodes to allow into a single horizontal fusion + max_horizontal_fusion_size = int( + os.environ.get("TORCHINDUCTOR_CPP_MAX_HORIZONTAL_FUSION_SIZE", "16") + ) + + # Make scatter_reduce fallback when reduce is sum to avoid performance regression + # using atomic_add. + fallback_scatter_reduce_sum = ( + os.environ.get("TORCHINDUCTOR_CPP_FALLBACK_SCATTER_REDUCE_SUM", "1") == "1" + ) + + # Use funsafe-math-optimizations when compiling + enable_unsafe_math_opt_flag = ( + os.environ.get("TORCHINDUCTOR_CPP_ENABLE_UNSAFE_MATH_OPT_FLAG", "0") == "1" + ) + + # Use ffp-contract when compiling + enable_floating_point_contract_flag = ( + os.environ.get("TORCHINDUCTOR_CPP_ENABLE_FLOATING_POINT_CONTRACT_FLAG", "0") + == "1" + ) + + # Disable the tiling select heuristic + enable_tiling_heuristics = ( + os.environ.get("TORCHINDUCTOR_CPP_ENABLE_TILING_HEURISTIC", "1") == "1" + ) + + # Maximal allowed number of slices on K-dim for a GEMM kernel. This controls + # the maximal parallelism of K-slicing. Since K-slicing requires extra thread + # synchronization and buffers, the maximal number of slices is limited to + # mitigate the sync overhead and memory usage. + # When set to 0, the number of slices is unlimited. 
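The `cpp` options above are plain attributes of a nested config namespace, so they can be assigned directly or patched by dotted name; the values below are illustrative only:

```python
import torch._inductor.config as inductor_config

# Direct assignment: pin the C++ backend to a fixed thread count and profile kernels.
inductor_config.cpp.threads = 8
inductor_config.cpp.enable_kernel_profile = True

# Dotted names work through the patch() helper installed at the bottom of this file.
with inductor_config.patch({"cpp.weight_prepack": False}):
    pass  # compile something here with weight prepacking disabled
```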
+ gemm_max_k_slices = int(os.environ.get("TORCHINDUCTOR_CPP_GEMM_MAX_K_SLICES", "1")) + + # For perf tuning and debugging purpose, configure the pre-defined cache blocking for + # MxNxK dims respectively. The blockings are separated by comma and the unit is + # the number of register blocks. + # For example, "4,1,10" means 4 register blocks on M, 1 on N and 10 on K respectively. + gemm_cache_blocking = os.environ.get("TORCHINDUCTOR_CPP_GEMM_CACHE_BLOCKING", None) + + # For perf tuning and debugging purpose, configure the pre-defined thread blocking factors for + # MxNxK dims respectively. The factors are separated by comma and their product + # should be the same as the total number of threads. + # For example, if the total number of threads is 56, "7,4,2" means the work is + # decomposed into 7x4x2 thread blocks along MxNxK of a GEMM. + gemm_thread_factors = os.environ.get("TORCHINDUCTOR_CPP_GEMM_THREAD_FACTORS", None) + + # Whether to enable masked vectorization for the tail_loop. + enable_loop_tail_vec = True + + +# config specific to codegen/triton.py +class triton: + # Use cudagraphs on output code + cudagraphs = os.environ.get("TORCHINDUCTOR_CUDAGRAPHS") == "1" + + # Use cudagraph trees for memory pooling if `cudagraphs` is True + cudagraph_trees = True + + # Should we skip cudagraphing graphs with dynamic shape inputs + # If False, we will re-record a graph for each unique set of shape inputs + cudagraph_skip_dynamic_graphs = False + + # assertions not on the fast path, steady state + slow_path_cudagraph_asserts = True + + # TODO - need to debug why this prevents cleanup + cudagraph_trees_history_recording = False + + # Enable cudagraph support for mutated inputs from prior cudagraph pool + cudagraph_support_input_mutation = False if is_fbcode() else True + + # Maximal number of allowed cudagraph re-record for a function and + # a cudagraph node due to static input tensor address changes or + # cudagraph managed tensor data pointer changed. + # i.e., allow num_recording <= cudagraph_unexpected_rerecord_limit + # note: we are conservative here and choose a large limit. + cudagraph_unexpected_rerecord_limit = 128 + + # Warn loudly when the number of cudagraphs due to dynamic shape + # exceeds this limit + cudagraph_dynamic_shape_warn_limit: Optional[int] = 50 + + # synchronize after cudagraph invocation + force_cudagraph_sync = False + + # always run cudagraphs in the eager warmup stage + # instead of recording and executing cudagraphs + force_cudagraphs_warmup = False + + # assertions on the fast path + fast_path_cudagraph_asserts = False + + # skip warmup for cudagraph trees + skip_cudagraph_warmup = False + + # Synchronize before and after every compiled graph. + debug_sync_graph = False + + # Synchronize after every kernel launch, to help pinpoint bugs + debug_sync_kernel = False + + # Always load full blocks (rather than broadcasting inside the block) + dense_indexing = False + + # limit tiling dimensions + max_tiles = 2 + + # Prefer higher dimensional tilings. This simplifies indexing expressions, making + # it easier to identify block pointers. + prefer_nd_tiling: bool = False + + # use triton.autotune for pointwise ops with complex layouts + # this should only be disabled for debugging/testing + autotune_pointwise = True + + # max autotune gemm with cublasLt + autotune_cublasLt = True + + # Tune the generated Triton kernels at compile time instead of first time they run + autotune_at_compile_time = False + + # should we stop a fusion to allow better tiling? 
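The cudagraph settings earlier in this class are what the documented `reduce-overhead` compile mode turns on; a minimal sketch (requires a CUDA device):

```python
import torch

@torch.compile(mode="reduce-overhead")  # enables triton.cudagraphs / cudagraph trees
def step(x):
    return x.sin() + x.cos()

if torch.cuda.is_available():
    x = torch.randn(1024, device="cuda")
    for _ in range(3):  # early calls warm up; later calls replay the captured graph
        step(x)
```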
+ tiling_prevents_pointwise_fusion = True + tiling_prevents_reduction_fusion = True + + # should we give different names to kernels + # Note: This is orthogonal to descriptive_names - this is deciding whether + # our triton kernel names should all be `triton_` (to maximize caching) or + # whether they should be unique. + unique_kernel_names = os.environ.get("TORCHINDUCTOR_UNIQUE_KERNEL_NAMES") == "1" + + # should we put op names in kernel names + # False: No special names (just triton__1, triton__2, etc.) + # "torch": Maps to the fx op in the Dynamo graph (module name, method name, etc.) + # "original_aten": Maps to the highest-level aten op (i.e. pre-decompositions) + # "inductor_node": Maps to the node name in the FX graph passed to Inductor + descriptive_names = "original_aten" + + # use alternate codegen for smaller reductions + persistent_reductions = ( + os.environ.get("TORCHINDUCTOR_PERSISTENT_REDUCTIONS", "1") == "1" + ) + + # 0/False: disable + # 1/True: enable, use tuning to pick between different subkernels + # 2: enable, force using persistent reduction (for debugging) + # 3: enable, force using non-persistent reduction (for debugging) + multi_kernel = int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "0")) + + # hint to Triton when arguments are divisible by 16 + divisible_by_16 = True + + # Minimum RBLOCK to be used for a TritonSplitScanKernel + # NOTE: This also indirectly controls the size of workspace buffer required + min_split_scan_rblock = 256 + + # Store the generated cubin files for cpp wrapper code to load + store_cubin = False + + # the max number of spills we allow for the configs we benchmark. + # Setting this to 0 means we skip a config if it spills even a single + # register. + # Setting it to a larger value allows a config spilling a small amount + # of registers being benchmarked. + # + # NOTE: triton will always report >0 register spills for kernels using sin/cos. + # (check this issue https://github.com/openai/triton/issues/1756 ) + # So far we see a fixed 8 spilled registers for kernels using sin/cos. + # Raise the threshold to 16 to be safe. + # We should revisit this once we understand more of the source of register spills. + spill_threshold: int = 16 + + # Generate code containing the newer tl.make_block_ptr() API for loads/store + use_block_ptr = False + + # Inject a bug into our relu implementation; useful for testing our repro + # extraction and minification functionality. + # Valid values: "compile_error", "runtime_error", "accuracy" + inject_relu_bug_TESTING_ONLY: Optional[str] = None + + # Whether to upcast float16 / bfloat16 to float32 in triton codegen (Experimental) + codegen_upcast_to_fp32 = True + + +class aot_inductor: + # AOTInductor output path + # If an absolute path is specified, the generated lib files will be stored under the directory; + # If a relative path is specified, it will be used as a subdirectory under the default caching path; + # If not specified, a temp directory will be created under the default caching path. + # If the specified path contains something like "model.so", the sub-string will be used + # to name the generated library. 
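One way the `aot_inductor` options introduced above are commonly exercised is via `torch.export` plus `torch._inductor.aot_compile`; the exact entry points have shifted across releases, so treat this as a hedged sketch rather than the canonical API:

```python
import torch
import torch._inductor

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

example = (torch.randn(4, 4),)
ep = torch.export.export(M(), example)

# Lower the exported module with AOTInductor; the shared library lands under
# aot_inductor.output_path when that is set, otherwise in a temp cache directory.
so_path = torch._inductor.aot_compile(ep.module(), example)
print(so_path)
```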
+ output_path = "" + + debug_compile = os.environ.get("AOT_INDUCTOR_DEBUG_COMPILE", "0") == "1" + + debug_dump_consts_bin: bool = ( + os.environ.get("AOT_INDUCTOR_DEBUG_DUMP_CONSTS_BIN", "0") == "1" + ) + + # option for debug printing/saving for intermediate tensor values for aot inductor + # 0: disable debug dumping + # 1: enable saving intermediate tensor values + # 2: enable printing intermediate tensor values + debug_intermediate_value_printer = os.environ.get( + "AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER", "0" + ) + + # filtered nodes to be printed for debug values. Specify this option when debug_intermediate_value_printer is set to 2 + filtered_kernel_names = os.environ.get( + "AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT", None + ) + + # Serialized tree spec for flattening inputs + serialized_in_spec = "" + + # Serialized tree spec for flattening outputs + serialized_out_spec = "" + + # flag to decide whether to create a submodule for constant graph. + use_runtime_constant_folding: bool = False + + # flag to force weight to be appened to the shared library and mmaped by the runtime + # rather than embedded into the data section. Needed to support 1B+ parameter models + force_mmap_weights: bool = False + + package: bool = False + + +class cuda: + # CUDA arch to use for CUDA template kernel compilation. + # e.g. "70", "75", "80", "90", etc. + # When arch is None, Inductor uses torch.cuda.get_device_capability(0). + arch: Optional[str] = None + + # CUDA version to use for CUDA template kernel compilation. + # e.g. "11.4", "12.1", etc. + # When version is None, Inductor uses torch.version.cuda. + version: Optional[str] = None + + # Optimization level for the host compiler. + compile_opt_level = "-O1" + + # Whether to enable device LTO (link-time-optimization). + enable_cuda_lto = False + + # Whether to keep intermediate files dring compilation. + enable_ptxas_info = False + + # Whether to enable debug info, e.g. line number, cutlass debug info. + enable_debug_info = False + + # Whether to use fast math. + use_fast_math = False + + # Path to the CUTLASS repo root directory. + # The default path only works under PyTorch local development environment. + cutlass_dir = os.environ.get( + "TORCHINDUCTOR_CUTLASS_DIR", + os.path.abspath( + os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/") + ), + ) + + # Configures the maximum number of CUTLASS configs to profile in max_autotune. + # By default it's None, so that all CUTLASS configs are tuned. + # This is mainly used to reduce test time in CI. + cutlass_max_profiling_configs: Optional[int] = None + + # Path to CUDA NVCC. + # NVCC search order: + # 1) cuda_cxx set in this config + # 2) CUDACXX environment variable + # 3) CUDA_HOME environment variable + # 4) default system search PATH. + cuda_cxx: Optional[str] = None + + # Minimum value of M*N*K to consider the CUTLASS backend for GEMM ops. + cutlass_backend_min_gemm_size: int = 1 + + # enable generation of inline standalone runner in CUDA CPP generated code + # which allows to compile the generated code into a standalone executable. 
+ generate_test_runner: bool = ( + os.environ.get("INDUCTOR_CUDA_BACKEND_GENERATE_TEST_RUNNER_CODE", "1") == "1" + ) + + # Keep only Cutlass op configs which contain this regular expression pattern + # Set this to "warpspecialized_cooperative_epi_tma" to enable only SM90 TMA Cutlass Kernels for large GEMMs + cutlass_op_allowlist_regex: Optional[str] = None + + # Note: Names of Cutlass ops names can be obtained by calling + # op.configuration_name() on a Cutlass op instance, for example those + # returned from cutlass_utils.gen_ops() or the op argument passed to + # CUTLASSGemmTemplate.render(...) + + # Filter Cutlass configs which contain this regular expression pattern + # Set this to "pingpong" to avoid numerical issues + # caused by the op ordering of the "pingpong" memory access + # pattern used by some Cutlass Kernels. + cutlass_op_denylist_regex: Optional[str] = "pingpong" + + +class rocm: + # Offload arch list for device code compilation, e.g. ["gfx941", "gfx942"]. + # If empty, the `native` arch is used + arch: List[str] = [] + + # Enable the CK backend for CDNA2 and CDNA3 only (for now) + # Processor name reference: https://llvm.org/docs/AMDGPUUsage.html#processors + ck_supported_arch: List[str] = ["gfx90a", "gfx940", "gfx941", "gfx942"] + + # Optimization level, use to balance compilation speed and runtime performance + compile_opt_level = "-O2" + + # Flag to keep debug information in compiled objects + is_debug = False + + # Flag to keep intermediate files (assembly listings, preprocessed sources, etc.) + save_temps = False + + # Flag to add `-ffast-math`` to compile flags + use_fast_math = True + + # Flag to add `-fgpu-flush-denormals-to-zero` to compile flags + flush_denormals = True + + # Flag to print register and LDS usage during compilation + print_kernel_resource_usage = False + + # Path to ROCm installation, if None, use env variable ROCM_HOME + rocm_home: Optional[str] = None + + # Path to Composable Kernel library. + # Install with `pip install git+https://github.com/rocm/composable_kernel@develop`. + ck_dir = os.environ.get("TORCHINDUCTOR_CK_DIR") + + # Number of op instance choices to trade off between runtime perf and compilation time + n_max_profiling_configs: Optional[int] = None + + # Flag to use a short list of CK instances which perform well across a variety of shapes. 
+ # Currently RCR and F16 only + use_preselected_instances: bool = False + + +# Backend to use for CPU codegen either "cpp" or "halide" (experimental) +cpu_backend = "cpp" + +# Backend to use for CUDA codegen either "triton" or "halide" (experimental) +cuda_backend = "triton" + + +class halide: + # Base halide target to use for CPU devices + cpu_target = "host" + + # Base halide target to use for CUDA devices + gpu_target = "host-cuda" + + # Halide autoscheduler to use, choices are: + # "Anderson2021" (gpu-only), "Li2018", "Adams2019" (cpu-only), or "Mullapudi2016" (cpu-only) + scheduler_cuda = "Anderson2021" + scheduler_cpu = "Adams2019" + + # Controls `no_asserts` flag passed to Halide target (warning: can false positive) + asserts = False + + # Controls `debug` flag passed to Halide target + debug = False + + # Enable (or fallback on) scan kernels such as cumsum + # Halide autoschedulers struggle with these kernels + scan_kernels = False + + +# create a directory containing lots of debug information +class trace: + # master switch for all debugging flags below + enabled = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" + + # Save debug information to a temporary directory + # If not specified, a temp directory will be created by system + debug_dir: Optional[str] = None + + # Save python logger call >=logging.DEBUG + debug_log = False + + # Save python logger call >=logging.INFO + info_log = False + + # Save input FX graph (post decomps, pre optimization) + fx_graph = True + + # Save FX graph after transformations + fx_graph_transformed = True + + # Save TorchInductor IR before fusion pass + ir_pre_fusion = True + + # Save TorchInductor IR after fusion pass + ir_post_fusion = True + + # Copy generated code to trace dir + output_code = True + + # SVG figure showing post-fusion graph + graph_diagram = os.environ.get("INDUCTOR_POST_FUSION_SVG", "0") == "1" + + # SVG figure showing fx with fusion + draw_orig_fx_graph = os.environ.get("INDUCTOR_ORIG_FX_SVG", "0") == "1" + + # We draw our fx graphs with the "record" shape attribute by default. + # Sometimes, when the graph is very complex, we may hit dot errors like below: + # "flat edge between adjacent nodes one of which has a record shape - + # replace records with HTML-like labels" + # and thus fail to generate a graph. So, let's give the user an option + # to specify the shape attribute for the dot graph. For example, passing + # INDUCTOR_DOT_GRAPH_SHAPE_SVG = "none" would let us generate HTML-like lables + # to workaround the above failure. 
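A usage sketch for the `trace` switches above; the environment variable has to be read before the config module is first imported, so set it before importing torch in a fresh process:

```python
import os
os.environ["TORCH_COMPILE_DEBUG"] = "1"  # equivalent to trace.enabled above

import torch

fn = torch.compile(lambda x: (x * 2).sum())
fn(torch.randn(8))
# Debug artifacts (input/transformed FX graphs, pre/post-fusion IR, output code)
# are written under a torch_compile_debug/ directory; the exact path is logged.
```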
+ dot_graph_shape = os.environ.get("INDUCTOR_DOT_GRAPH_SHAPE_SVG", None) + + # If not None, this is the URL that saves the SVG files of the input/output + # graph of each pass that changed the graph + # The nodes that are being transformed in each pass will be colored in yellow + # URL only supports local directory for now + log_url_for_graph_xform = os.environ.get("INDUCTOR_LOG_URL_FOR_GRAPH_XFORM", None) + + # Store cProfile (see snakeviz to view) + compile_profile = False + + # Upload the .tar.gz file + # Needs to be overriden based on specific environment needs + upload_tar: Optional[Callable[[str], None]] = None + + log_autotuning_results: bool = False + + +_save_config_ignore = [ + # workaround: "Can't pickle " + "trace.upload_tar", + "post_grad_custom_post_pass", + "post_grad_custom_pre_pass", + "joint_custom_pre_pass", + "joint_custom_post_pass", + "pre_grad_custom_pass", +] + +_cache_config_ignore_prefix = [ + # trace functions are not relevant to config caching + "trace", + # uses absolute path + "cuda.cutlass_dir", + # not relevant + "compile_threads", +] + +if TYPE_CHECKING: + from torch.utils._config_typing import * # noqa: F401, F403 + +from torch.utils._config_module import install_config_module + + +# adds patch, save_config, etc +install_config_module(sys.modules[__name__]) diff --git a/lib/python3.10/site-packages/torch/_inductor/constant_folding.py b/lib/python3.10/site-packages/torch/_inductor/constant_folding.py new file mode 100644 index 0000000000000000000000000000000000000000..72f34d32475f39c94a30308455edbcfcb8e03d08 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/constant_folding.py @@ -0,0 +1,348 @@ +import collections +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +import torch.utils._pytree as pytree + + +aten = torch.ops.aten + +# We would like to split modules into two subgraphs for runtime weight updates to work correctly. 
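To see the folder in isolation, here is a hedged sketch that runs `constant_fold` (defined later in this file) on a toy traced module; it is illustrative only, since Inductor normally invokes it on post-grad graphs during freezing:

```python
import torch
from torch._inductor.constant_folding import constant_fold

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        # self.w.t() depends only on an attribute, so it is a folding candidate.
        return x @ self.w.t()

gm = torch.fx.symbolic_trace(M())
constant_fold(gm)
print(gm.graph)  # the transposed weight now shows up as a _frozen_param* get_attr
```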
+# The use case and more information could be found at: +# https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing +META_TAG = "MODULE_TYPE" +MODULE_TAG = "_MAIN_MODULE" +CONST_MODULE_TAG = "_CONST_MODULE" + + +def replace_node_with_constant( + gm: torch.fx.GraphModule, + node: torch.fx.Node, + constant: torch.Tensor, + name: Optional[str] = None, +) -> None: + g = gm.graph + + if name: + qualname = name + else: + if not hasattr(gm, "_frozen_param_count"): + gm._frozen_param_count = 0 # type: ignore[assignment] + i = gm._frozen_param_count + + while True: + qualname = f"_frozen_param{i}" + if not hasattr(gm, qualname): + break + i += 1 + + gm._frozen_param_count = i + 1 + + with g.inserting_before(node): + new_input_node = g.create_node("get_attr", qualname, (), {}) + node.replace_all_uses_with(new_input_node) + new_input_node.meta.update(node.meta) + g.erase_node(node) + + # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning + gm.register_buffer(qualname, constant) + setattr(gm, qualname, constant) + + +def is_const_source( + node: torch.fx.Node, lifted_constants: Optional[Dict[str, Any]] +) -> bool: + return node.op == "get_attr" or ( + node.op == "placeholder" + and lifted_constants is not None + and node.name in lifted_constants + ) + + +class ConstantFolder(torch.fx.Interpreter): + def __init__( + self, + gm: torch.fx.GraphModule, + skip_constructors: bool = False, + lifted_constants: Optional[Dict[str, torch.Tensor]] = None, + skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, + ) -> None: + super().__init__(gm) + self.node_replacements: Dict[torch.fx.Node, Any] = {} + self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter() + self.unknown_value = object() + self.skip_constructors: bool = skip_constructors + + # overwrite this to deallocate env values if their only remaining use + # is the output + self.user_to_last_uses = self.node_to_last_non_output_use() + self.lifted_constants = lifted_constants + + def _support_dynamic_shape(self) -> bool: + # ConstantFolder not support dynamic shape now + return False + + def _deduce_value(self, node: torch.fx.Node) -> Any: + return super().run_node(node) + + def is_impure(self, node: torch.fx.node.Node) -> bool: + if ( + node.target == torch.ops.prims.convert_element_type.default + and is_const_source(node.args[0], self.lifted_constants) # type: ignore[arg-type] + and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr] + and node.args[1] == torch.bfloat16 + ): + # For int8_weight -> dq -> bf16_weight + return True + if node.target in [ + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + ]: + # For the pattern fp32_weight -> q -> dq + # We only folding fp32_weight -> q + # int8_weight and leave dq in graph to be fused + return True + return False + + def node_to_last_non_output_use(self) -> Dict[torch.fx.Node, List[torch.fx.Node]]: + last_non_output_use = collections.defaultdict(list) + seen_uses = set() + output_node = next(iter(reversed(self.module.graph.nodes))) + + for node in reversed(self.module.graph.nodes): + if node.target == "output": + continue + + def add_use(inp: torch.fx.Node) -> None: + if inp in seen_uses: + return + + seen_uses.add(inp) + last_non_output_use[node].append(inp) + + # In-place is fine since we don't mutate + pytree.tree_map_only_(torch.fx.Node, 
add_use, (node.args, node.kwargs)) + + # if this node is only used in output, we want to gc it right away + if len(node.users) == 1 and output_node in node.users: + last_non_output_use[node].append(node) + + return last_non_output_use + + def run_node(self, node: torch.fx.Node) -> Any: + if node.target == "output": + # because we remove nodes from env on last non output use, + # re-define them now or we'll get error in interpreter + def set_env(arg: torch.fx.Node) -> None: + self.env[arg] = self.unknown_value + + # In-place is fine since we don't mutate + pytree.tree_map_only_(torch.fx.Node, set_env, node.args) + return super().run_node(node) + + args, kwargs = self.fetch_args_kwargs_from_env(node) + flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) + + # We need to do this weird thing because in cases where flattened_inputs + # contains a ScriptObject, equality checking results in a type error if + # the types are different. + if any( + type(self.unknown_value) == type(input_) and self.unknown_value == input_ + for input_ in flattened_inputs + ): + return self.unknown_value + + # TODO - fix errors with this + if ( + node.op == "call_function" + and node.target == aten._efficientzerotensor.default + ): + return self.unknown_value + + # TODO - constant folding triton kernel returns the inputs -- fix this + if ( + node.op == "call_function" + and node.name == "triton_kernel_wrapper_functional_proxy" + ): + return self.unknown_value + + # skip constructors, since inductor generates optimal code for them already + # and turning into tensor would result in an additional global memory read + # TODO - more complicated strategy + if ( + self.skip_constructors + and not is_const_source(node, self.lifted_constants) + and not any(isinstance(e, torch.Tensor) for e in flattened_inputs) + ): + return self.unknown_value + + # All mutations should either be removed or on inputs which we did not make constant + if ( + isinstance(node.target, torch._ops.OpOverload) + and torch.Tag.nondeterministic_seeded in node.target.tags + ): + return self.unknown_value + + out = self._deduce_value(node) + if out == self.unknown_value: + return self.unknown_value + + if not is_const_source(node, self.lifted_constants) and isinstance( + out, torch.Tensor + ): + if out.device.type == "meta": + return out + + if not self.insertable_tensor_check(out): + return out + + if self.is_impure(node): + return self.unknown_value + + self.add_node_replacement(node, out) + + flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs) + + for n in flattened_node_inps: + if not isinstance(n, torch.fx.Node): + continue + + self.replaced_uses[n] += 1 + + for to_delete in self.user_to_last_uses.get(node, []): + if self.replaced_uses[to_delete] == len(to_delete.users): + self.node_replacements.pop(to_delete, None) + + return out + + def insertable_tensor_check(self, tensor: torch.Tensor) -> bool: + return True + + def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None: + self.node_replacements[node] = tensor + + def run(self) -> Any: # type: ignore[override] + env: Dict[torch.fx.Node, Any] = {} + self.insert_placerholder_values(env) + return super().run(initial_env=env) + + def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None: + for n in self.module.graph.find_nodes(op="placeholder"): + if self.lifted_constants is not None and n.name in self.lifted_constants: + env[n] = self.lifted_constants[n.name] + else: + env[n] = self.unknown_value # type: ignore[assignment] + + +def 
constant_fold( + gm: torch.fx.GraphModule, + constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None, +) -> None: + with torch.utils._python_dispatch._disable_current_modes(): + cf = ConstantFolder(gm, skip_constructors=True) + cf.run() + + for node, constant in cf.node_replacements.items(): + if constraint_fn is not None and not constraint_fn(node): + continue + replace_node_with_constant(gm, node, constant) + + erased_params = [] + for node in gm.graph.find_nodes(op="get_attr"): + if len(node.users) == 0: + if hasattr(gm, node.target): + delattr(gm, node.target) + erased_params.append(node) + + for node in erased_params: + gm.graph.erase_node(node) + + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + + +def constant_graph_tag( + gm: torch.fx.GraphModule, + lifted_constants: Optional[Dict[str, Any]], + skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]], +) -> None: + with torch.utils._python_dispatch._disable_current_modes(): + cf = ConstantFolder( + gm, skip_constructors=True, lifted_constants=lifted_constants + ) + cf.run() + + for node in gm.graph.nodes: + if skip_folding_node_fn is not None and skip_folding_node_fn(node): + node.meta[META_TAG] = MODULE_TAG + continue + if ( + is_const_source(node, lifted_constants) + or node in cf.node_replacements + or node in cf.replaced_uses + ): + node.meta[META_TAG] = CONST_MODULE_TAG + else: + node.meta[META_TAG] = MODULE_TAG + + +def run_and_get_constant_graph( + gm: torch.fx.GraphModule, + lifted_constants: Optional[Dict[str, Any]], + skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]], +) -> Tuple[torch.fx.GraphModule, Tuple[torch.Tensor, ...]]: + """ + Construct a GraphModule which corresponds to the part which could be + constant folded in provided gm. + """ + + constant_graph_tag(gm, lifted_constants, skip_folding_node_fn) + + def untag(node: torch.fx.Node) -> bool: + used_to_fold = False + for u in node.users: + if u.meta[META_TAG] == CONST_MODULE_TAG: + used_to_fold = True + break + if not used_to_fold: + node.meta[META_TAG] = MODULE_TAG + return used_to_fold + + const_args = [] + if lifted_constants is not None: + placeholders = list(gm.graph.find_nodes(op="placeholder")) + for node in placeholders: + if node.meta[META_TAG] == MODULE_TAG: + continue + if untag(node): + const_args.append(lifted_constants[node.name]) + + # We rewrite the tags, if it's a constant being directly consumed, without + # any folding opportunity, we keep it in main gm. + for node in gm.graph.find_nodes(op="get_attr"): + untag(node) + + new_graph = torch.fx.Graph() + + node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {} + output_nodes = [] + for node in gm.graph.nodes: + if node.meta[META_TAG] == MODULE_TAG: + continue + + new_node = new_graph.node_copy(node, lambda x: node_remapping[x]) + node_remapping[node] = new_node + + for user in node.users: + if user.meta[META_TAG] == MODULE_TAG: + output_nodes.append(new_node) + break + + new_graph.output(tuple(output_nodes)) + new_graph.lint() + new_gm = torch.fx.GraphModule(gm, new_graph) + + const_result = new_gm(*const_args) + return new_gm, const_result diff --git a/lib/python3.10/site-packages/torch/_inductor/cpp_builder.py b/lib/python3.10/site-packages/torch/_inductor/cpp_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..95a0bff86fd8a4f1963d279d096b82dd33934ab2 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/cpp_builder.py @@ -0,0 +1,1511 @@ +# This CPP builder is designed to support both Windows and Linux OS. 
+# The design document please check this RFC: https://github.com/pytorch/pytorch/issues/124245 + +import copy +import errno +import functools +import json +import logging +import os +import platform +import re +import shlex +import shutil +import subprocess +import sys +import sysconfig +import warnings +from ctypes import cdll +from pathlib import Path +from typing import Any, List, Optional, Sequence, Tuple, Union + +import torch +from torch._dynamo.utils import dynamo_timed +from torch._inductor import config, exc +from torch._inductor.cpu_vec_isa import invalid_vec_isa, VecISA +from torch._inductor.runtime.runtime_utils import cache_dir +from torch.torch_version import TorchVersion + + +if config.is_fbcode(): + from triton.fb import build_paths # noqa: F401 + + from torch._inductor.fb.utils import ( + log_global_cache_errors, + log_global_cache_stats, + log_global_cache_vals, + use_global_cache, + ) +else: + + def log_global_cache_errors(*args: Any, **kwargs: Any) -> None: + pass + + def log_global_cache_stats(*args: Any, **kwargs: Any) -> None: + pass + + def log_global_cache_vals(*args: Any, **kwargs: Any) -> None: + pass + + def use_global_cache() -> bool: + return False + + +# Windows need setup a temp dir to store .obj files. +_BUILD_TEMP_DIR = "CxxBuild" + +# initialize variables for compilation +_IS_LINUX = sys.platform.startswith("linux") +_IS_MACOS = sys.platform.startswith("darwin") +_IS_WINDOWS = sys.platform == "win32" + +SUBPROCESS_DECODE_ARGS = ("utf-8",) if _IS_WINDOWS else () + +log = logging.getLogger(__name__) + + +# =============================== toolchain =============================== +@functools.lru_cache(1) +def cpp_compiler_search(search: str) -> str: + from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT + + for cxx in search: + try: + if cxx is None: + # gxx package is only available for Linux + # according to https://anaconda.org/conda-forge/gxx/ + if sys.platform != "linux": + continue + # Do not install GXX by default + if not os.getenv("TORCH_INDUCTOR_INSTALL_GXX"): + continue + from filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock( + os.path.join(lock_dir, "g++.lock"), timeout=LOCK_TIMEOUT + ) + with lock: + cxx = install_gcc_via_conda() + subprocess.check_output([cxx, "--version"]) + return cxx + except (subprocess.SubprocessError, FileNotFoundError, ImportError): + continue + raise exc.InvalidCxxCompiler + + +def install_gcc_via_conda() -> str: + """On older systems, this is a quick way to get a modern compiler""" + prefix = os.path.join(cache_dir(), "gcc") + cxx_path = os.path.join(prefix, "bin", "g++") + if not os.path.exists(cxx_path): + log.info("Downloading GCC via conda") + conda = os.environ.get("CONDA_EXE", "conda") + if conda is None: + conda = shutil.which("conda") + if conda is not None: + subprocess.check_call( + [ + conda, + "create", + f"--prefix={prefix}", + "--channel=conda-forge", + "--quiet", + "-y", + "python=3.8", + "gxx", + ], + stdout=subprocess.PIPE, + ) + return cxx_path + + +@functools.lru_cache(None) +def check_compiler_exist_windows(compiler: str) -> None: + """ + Check if compiler is ready, in case end user not activate MSVC environment. 
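+    Runs `<compiler> /help` once (results cached by functools.lru_cache) and
+    turns a FileNotFoundError into a RuntimeError.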
+ """ + try: + output_msg = ( + subprocess.check_output([compiler, "/help"], stderr=subprocess.STDOUT) + .strip() + .decode(*SUBPROCESS_DECODE_ARGS) + ) + except FileNotFoundError as exc: + raise RuntimeError(f"Compiler: {compiler} is not found.") from exc + except subprocess.SubprocessError: + # Expected that some compiler(clang, clang++) is exist, but they not support `/help` args. + pass + + +def get_cpp_compiler() -> str: + if _IS_WINDOWS: + compiler = os.environ.get("CXX", "cl") + check_compiler_exist_windows(compiler) + else: + if config.is_fbcode(): + return ( + build_paths.cc() if torch.version.hip is None else build_paths.clang() + ) + if isinstance(config.cpp.cxx, (list, tuple)): + search = tuple(config.cpp.cxx) + else: + search = (config.cpp.cxx,) + compiler = cpp_compiler_search(search) + return compiler + + +@functools.lru_cache(None) +def _is_apple_clang(cpp_compiler: str) -> bool: + version_string = subprocess.check_output([cpp_compiler, "--version"]).decode("utf8") + return "Apple" in version_string.splitlines()[0] + + +def _is_clang(cpp_compiler: str) -> bool: + # Mac OS apple clang maybe named as gcc, need check compiler info. + if sys.platform == "darwin": + return _is_apple_clang(cpp_compiler) + elif _IS_WINDOWS: + # clang suite have many compilers, and only clang-cl is supported. + if re.search(r"((clang$)|(clang\+\+$))", cpp_compiler): + raise RuntimeError( + "Please use clang-cl, due to torch.compile only support MSVC-like CLI (compiler flags syntax)." + ) + return bool(re.search(r"(clang-cl)", cpp_compiler)) + return bool(re.search(r"(clang|clang\+\+)", cpp_compiler)) + + +def _is_gcc(cpp_compiler: str) -> bool: + if sys.platform == "darwin" and _is_apple_clang(cpp_compiler): + return False + return bool(re.search(r"(gcc|g\+\+)", cpp_compiler)) + + +@functools.lru_cache(None) +def _is_msvc_cl(cpp_compiler: str) -> bool: + if not _IS_WINDOWS: + return False + + try: + output_msg = ( + subprocess.check_output([cpp_compiler, "/help"], stderr=subprocess.STDOUT) + .strip() + .decode(*SUBPROCESS_DECODE_ARGS) + ) + return "Microsoft" in output_msg.splitlines()[0] + except FileNotFoundError as exc: + return False + + return False + + +@functools.lru_cache(None) +def _is_intel_compiler(cpp_compiler: str) -> bool: + def _check_minimal_version(compiler_version: TorchVersion) -> None: + """ + On Windows: early version icx has `-print-file-name` issue, and can't preload correctly for inductor. + """ + min_version = "2024.2.1" if _IS_WINDOWS else "0.0.0" + if compiler_version < TorchVersion(min_version): + raise RuntimeError( + f"Intel Compiler error: less than minimal version {min_version}." + ) + + try: + output_msg = ( + subprocess.check_output( + [cpp_compiler, "--version"], stderr=subprocess.DEVNULL + ) + .strip() + .decode(*SUBPROCESS_DECODE_ARGS) + ) + is_intel_compiler = "Intel" in output_msg.splitlines()[0] + if is_intel_compiler: + if _IS_WINDOWS: + if re.search(r"((icx$)|(icx-cc$))", cpp_compiler): + raise RuntimeError( + "Please use icx-cl, due to torch.compile only support MSVC-like CLI (compiler flags syntax)." + ) + + # Version check + icx_ver_search = re.search(r"(\d+[.]\d+[.]\d+[.]\d+)", output_msg) + if icx_ver_search is not None: + icx_ver = icx_ver_search.group(1) + _check_minimal_version(TorchVersion(icx_ver)) + + return is_intel_compiler + except FileNotFoundError as exc: + return False + except subprocess.SubprocessError: + # --version args not support. 
+ return False + + return False + + +@functools.lru_cache(None) +def is_gcc() -> bool: + return _is_gcc(get_cpp_compiler()) + + +@functools.lru_cache(None) +def is_clang() -> bool: + return _is_clang(get_cpp_compiler()) + + +@functools.lru_cache(None) +def is_intel_compiler() -> bool: + return _is_intel_compiler(get_cpp_compiler()) + + +@functools.lru_cache(None) +def is_apple_clang() -> bool: + return _is_apple_clang(get_cpp_compiler()) + + +@functools.lru_cache(None) +def is_msvc_cl() -> bool: + return _is_msvc_cl(get_cpp_compiler()) + + +def get_compiler_version_info(compiler: str) -> str: + env = os.environ.copy() + env["LC_ALL"] = "C" # Don't localize output + try: + version_string = subprocess.check_output( + [compiler, "-v"], stderr=subprocess.STDOUT, env=env + ).decode(*SUBPROCESS_DECODE_ARGS) + except Exception as e: + try: + version_string = subprocess.check_output( + [compiler, "--version"], stderr=subprocess.STDOUT, env=env + ).decode(*SUBPROCESS_DECODE_ARGS) + except Exception as e: + return "" + # Mutiple lines to one line string. + version_string = version_string.replace("\r", "_") + version_string = version_string.replace("\n", "_") + return version_string + + +# =============================== cpp builder =============================== +def _append_list(dest_list: List[str], src_list: List[str]) -> None: + for item in src_list: + dest_list.append(copy.deepcopy(item)) + + +def _remove_duplication_in_list(orig_list: List[str]) -> List[str]: + new_list: List[str] = [] + for item in orig_list: + if item not in new_list: + new_list.append(item) + return new_list + + +def _create_if_dir_not_exist(path_dir: str) -> None: + if not os.path.exists(path_dir): + try: + Path(path_dir).mkdir(parents=True, exist_ok=True) + except OSError as exc: # Guard against race condition + if exc.errno != errno.EEXIST: + raise RuntimeError( # noqa: TRY200 (Use `raise from`) + f"Fail to create path {path_dir}" + ) + + +def _remove_dir(path_dir: str) -> None: + if os.path.exists(path_dir): + for root, dirs, files in os.walk(path_dir, topdown=False): + for name in files: + file_path = os.path.join(root, name) + os.remove(file_path) + for name in dirs: + dir_path = os.path.join(root, name) + os.rmdir(dir_path) + os.rmdir(path_dir) + + +def _run_compile_cmd(cmd_line: str, cwd: str) -> bytes: + cmd = shlex.split(cmd_line) + try: + status = subprocess.check_output(args=cmd, cwd=cwd, stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + output = e.output.decode("utf-8") + openmp_problem = "'omp.h' file not found" in output or "libomp" in output + if openmp_problem and sys.platform == "darwin": + instruction = ( + "\n\nOpenMP support not found. Please try one of the following solutions:\n" + "(1) Set the `CXX` environment variable to a compiler other than Apple clang++/g++ " + "that has builtin OpenMP support;\n" + "(2) install OpenMP via conda: `conda install llvm-openmp`;\n" + "(3) install libomp via brew: `brew install libomp`;\n" + "(4) manually setup OpenMP and set the `OMP_PREFIX` environment variable to point to a path" + " with `include/omp.h` under it." 
+ ) + output += instruction + raise exc.CppCompileError(cmd, output) from e + return status + + +def run_compile_cmd(cmd_line: str, cwd: str) -> bytes: + with dynamo_timed("compile_file"): + return _run_compile_cmd(cmd_line, cwd) + + +def normalize_path_separator(orig_path: str) -> str: + if _IS_WINDOWS: + return orig_path.replace(os.sep, "/") + return orig_path + + +class BuildOptionsBase: + """ + This is the Base class for store cxx build options, as a template. + Acturally, to build a cxx shared library. We just need to select a compiler + and maintains the suitable args. + """ + + def __init__( + self, + compiler: str = "", + definitions: Optional[List[str]] = None, + include_dirs: Optional[List[str]] = None, + cflags: Optional[List[str]] = None, + ldflags: Optional[List[str]] = None, + libraries_dirs: Optional[List[str]] = None, + libraries: Optional[List[str]] = None, + passthrough_args: Optional[List[str]] = None, + aot_mode: bool = False, + use_absolute_path: bool = False, + compile_only: bool = False, + ) -> None: + self._compiler = compiler + self._definations: List[str] = definitions or [] + self._include_dirs: List[str] = include_dirs or [] + self._cflags: List[str] = cflags or [] + self._ldflags: List[str] = ldflags or [] + self._libraries_dirs: List[str] = libraries_dirs or [] + self._libraries: List[str] = libraries or [] + # Some args is hard to abstract to OS compatable, passthough it directly. + self._passthough_args: List[str] = passthrough_args or [] + + self._aot_mode: bool = aot_mode + self._use_absolute_path: bool = use_absolute_path + self._compile_only: bool = compile_only + + def _process_compile_only_options(self) -> None: + if self._compile_only: + self._libraries_dirs = [] + self._libraries = [] + + def _remove_duplicate_options(self) -> None: + self._definations = _remove_duplication_in_list(self._definations) + self._include_dirs = _remove_duplication_in_list(self._include_dirs) + self._cflags = _remove_duplication_in_list(self._cflags) + self._ldflags = _remove_duplication_in_list(self._ldflags) + self._libraries_dirs = _remove_duplication_in_list(self._libraries_dirs) + self._libraries = _remove_duplication_in_list(self._libraries) + self._passthough_args = _remove_duplication_in_list(self._passthough_args) + + def _finalize_options(self) -> None: + self._process_compile_only_options + self._remove_duplicate_options + + def get_compiler(self) -> str: + return self._compiler + + def get_definations(self) -> List[str]: + return self._definations + + def get_include_dirs(self) -> List[str]: + return self._include_dirs + + def get_cflags(self) -> List[str]: + return self._cflags + + def get_ldflags(self) -> List[str]: + return self._ldflags + + def get_libraries_dirs(self) -> List[str]: + return self._libraries_dirs + + def get_libraries(self) -> List[str]: + return self._libraries + + def get_passthough_args(self) -> List[str]: + return self._passthough_args + + def get_aot_mode(self) -> bool: + return self._aot_mode + + def get_use_absolute_path(self) -> bool: + return self._use_absolute_path + + def get_compile_only(self) -> bool: + return self._compile_only + + def save_flags_to_file(self, file: str) -> None: + attrs = { + "compiler": self.get_compiler(), + "definitions": self.get_definations(), + "include_dirs": self.get_include_dirs(), + "cflags": self.get_cflags(), + "ldflags": self.get_ldflags(), + "libraries_dirs": self.get_libraries_dirs(), + "libraries": self.get_libraries(), + "passthrough_args": self.get_passthough_args(), + "aot_mode": 
self.get_aot_mode(), + "use_absolute_path": self.get_use_absolute_path(), + "compile_only": self.get_compile_only(), + } + + with open(file, "w") as f: + json.dump(attrs, f) + + +def _get_warning_all_cflag(warning_all: bool = True) -> List[str]: + if not _IS_WINDOWS: + return ["Wall"] if warning_all else [] + else: + return [] + + +def _get_cpp_std_cflag(std_num: str = "c++17") -> List[str]: + if _IS_WINDOWS: + """ + On Windows, only c++20 can support `std::enable_if_t`. + Ref: https://learn.microsoft.com/en-us/cpp/overview/cpp-conformance-improvements-2019?view=msvc-170#checking-for-abstract-class-types # noqa: B950 + Note: + Only setup c++20 for Windows inductor. I tried to upgrade all project to c++20, but it is failed: + https://github.com/pytorch/pytorch/pull/131504 + """ + std_num = "c++20" + return [f"std:{std_num}"] + else: + return [f"std={std_num}"] + + +def _get_os_related_cpp_cflags(cpp_compiler: str) -> List[str]: + if _IS_WINDOWS: + cflags = [ + "wd4819", + "wd4251", + "wd4244", + "wd4267", + "wd4275", + "wd4018", + "wd4190", + "wd4624", + "wd4067", + "wd4068", + "EHsc", + ] + else: + cflags = ["Wno-unused-variable", "Wno-unknown-pragmas"] + if _is_clang(cpp_compiler): + cflags.append("Werror=ignored-optimization-argument") + return cflags + + +def _get_optimization_cflags() -> List[str]: + if _IS_WINDOWS: + return ["O2"] + else: + cflags = ["O0", "g"] if config.aot_inductor.debug_compile else ["O3", "DNDEBUG"] + cflags.append("ffast-math") + cflags.append("fno-finite-math-only") + + if not config.cpp.enable_unsafe_math_opt_flag: + cflags.append("fno-unsafe-math-optimizations") + if not config.cpp.enable_floating_point_contract_flag: + cflags.append("ffp-contract=off") + + if sys.platform != "darwin": + # https://stackoverflow.com/questions/65966969/why-does-march-native-not-work-on-apple-m1 + # `-march=native` is unrecognized option on M1 + if not config.is_fbcode(): + if platform.machine() == "ppc64le": + cflags.append("mcpu=native") + else: + cflags.append("march=native") + + return cflags + + +def _get_shared_cflag(compile_only: bool) -> List[str]: + if _IS_WINDOWS: + """ + MSVC `/MD` using python `ucrtbase.dll` lib as runtime. + https://learn.microsoft.com/en-us/cpp/c-runtime-library/crt-library-features?view=msvc-170 + """ + SHARED_FLAG = ["DLL", "MD"] + else: + if compile_only: + return ["fPIC"] + if platform.system() == "Darwin" and "clang" in get_cpp_compiler(): + # This causes undefined symbols to behave the same as linux + return ["shared", "fPIC", "undefined dynamic_lookup"] + else: + return ["shared", "fPIC"] + + return SHARED_FLAG + + +def get_cpp_options( + cpp_compiler: str, + compile_only: bool, + warning_all: bool = True, + extra_flags: Sequence[str] = (), +) -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str]]: + definations: List[str] = [] + include_dirs: List[str] = [] + cflags: List[str] = [] + ldflags: List[str] = [] + libraries_dirs: List[str] = [] + libraries: List[str] = [] + passthough_args: List[str] = [] + + cflags = ( + _get_shared_cflag(compile_only) + + _get_optimization_cflags() + + _get_warning_all_cflag(warning_all) + + _get_cpp_std_cflag() + + _get_os_related_cpp_cflags(cpp_compiler) + ) + + passthough_args.append(" ".join(extra_flags)) + + return ( + definations, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthough_args, + ) + + +class CppOptions(BuildOptionsBase): + """ + This class is inherited from BuildOptionsBase, and as cxx build options. 
+ This option need contains basic cxx build option, which contains: + 1. OS related args. + 2. Toolchains related args. + 3. Cxx standard related args. + Note: + 1. This Options is good for assist modules build, such as x86_isa_help. + """ + + def __init__( + self, + compile_only: bool = False, + warning_all: bool = True, + extra_flags: Sequence[str] = (), + use_absolute_path: bool = False, + ) -> None: + super().__init__() + self._compiler = get_cpp_compiler() + self._use_absolute_path = use_absolute_path + self._compile_only = compile_only + + ( + definations, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthough_args, + ) = get_cpp_options( + cpp_compiler=self._compiler, + compile_only=compile_only, + extra_flags=extra_flags, + warning_all=warning_all, + ) + + _append_list(self._definations, definations) + _append_list(self._include_dirs, include_dirs) + _append_list(self._cflags, cflags) + _append_list(self._ldflags, ldflags) + _append_list(self._libraries_dirs, libraries_dirs) + _append_list(self._libraries, libraries) + _append_list(self._passthough_args, passthough_args) + self._finalize_options() + + +def _get_glibcxx_abi_build_flags() -> List[str]: + if not _IS_WINDOWS: + return ["-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))] + else: + return [] + + +def _get_torch_cpp_wrapper_defination() -> List[str]: + return ["TORCH_INDUCTOR_CPP_WRAPPER"] + + +def _use_custom_generated_macros() -> List[str]: + return [" C10_USING_CUSTOM_GENERATED_MACROS"] + + +def _use_fb_internal_macros() -> List[str]: + if not _IS_WINDOWS: + if config.is_fbcode(): + fb_internal_macros = [ + "C10_USE_GLOG", + "C10_USE_MINIMAL_GLOG", + "C10_DISABLE_TENSORIMPL_EXTENSIBILITY", + ] + # TODO: this is to avoid FC breakage for fbcode. 
When using newly + # generated model.so on an older verion of PyTorch, need to use + # the v1 version for aoti_torch_create_tensor_from_blob + create_tensor_from_blob_v1 = "AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1" + + fb_internal_macros.append(create_tensor_from_blob_v1) + return fb_internal_macros + else: + return [] + else: + return [] + + +def _setup_standard_sys_libs( + cpp_compiler: str, + aot_mode: bool, + use_absolute_path: bool, +) -> Tuple[List[str], List[str], List[str]]: + from torch._inductor.codecache import _LINKER_SCRIPT + + cflags: List[str] = [] + include_dirs: List[str] = [] + passthough_args: List[str] = [] + if _IS_WINDOWS: + return cflags, include_dirs, passthough_args + + if config.is_fbcode(): + cflags.append("nostdinc") + # Note that the order of include paths do matter, as a result + # we need to have several branches interleaved here + if torch.version.hip is None: + include_dirs.append(build_paths.sleef()) + include_dirs.append(build_paths.openmp()) + include_dirs.append(build_paths.python()) + if torch.version.hip is not None: + include_dirs.append(build_paths.clang_include()) + include_dirs.append(build_paths.gcc_include()) + include_dirs.append(build_paths.gcc_install_tools_include()) + else: + include_dirs.append(build_paths.cc_include()) + include_dirs.append(build_paths.libgcc()) + include_dirs.append(build_paths.libgcc_arch()) + include_dirs.append(build_paths.libgcc_backward()) + include_dirs.append(build_paths.glibc()) + include_dirs.append(build_paths.linux_kernel()) + include_dirs.append("include") + + if aot_mode and not use_absolute_path: + linker_script = _LINKER_SCRIPT + else: + linker_script = os.path.basename(_LINKER_SCRIPT) + + if _is_clang(cpp_compiler): + passthough_args.append(" --rtlib=compiler-rt") + passthough_args.append(" -fuse-ld=lld") + passthough_args.append(f" -Wl,--script={linker_script}") + passthough_args.append(" -B" + build_paths.glibc_lib()) + passthough_args.append(" -L" + build_paths.glibc_lib()) + + return cflags, include_dirs, passthough_args + + +def _get_build_args_of_chosen_isa(vec_isa: VecISA) -> Tuple[List[str], List[str]]: + macros = [] + build_flags = [] + if vec_isa != invalid_vec_isa: + # Add Windows support later. + for x in vec_isa.build_macro(): + macros.append(copy.deepcopy(x)) + + build_flags = [vec_isa.build_arch_flags()] + + if config.is_fbcode(): + cap = str(vec_isa).upper() + macros = [ + f"CPU_CAPABILITY={cap}", + f"CPU_CAPABILITY_{cap}", + f"HAVE_{cap}_CPU_DEFINITION", + ] + + return macros, build_flags + + +def _get_torch_related_args( + include_pytorch: bool, aot_mode: bool +) -> Tuple[List[str], List[str], List[str]]: + from torch.utils.cpp_extension import _TORCH_PATH, TORCH_LIB_PATH + + include_dirs = [ + os.path.join(_TORCH_PATH, "include"), + os.path.join(_TORCH_PATH, "include", "torch", "csrc", "api", "include"), + # Some internal (old) Torch headers don't properly prefix their includes, + # so we need to pass -Itorch/lib/include/TH as well. 
+ os.path.join(_TORCH_PATH, "include", "TH"), + os.path.join(_TORCH_PATH, "include", "THC"), + ] + libraries_dirs = [TORCH_LIB_PATH] + libraries = [] + if sys.platform != "darwin" and not config.is_fbcode(): + libraries = ["torch", "torch_cpu"] + if not aot_mode: + libraries.append("torch_python") + + if _IS_WINDOWS: + libraries.append("sleef") + + # Unconditionally import c10 for non-abi-compatible mode to use TORCH_CHECK - See PyTorch #108690 + if not config.abi_compatible: + libraries.append("c10") + libraries_dirs.append(TORCH_LIB_PATH) + + return include_dirs, libraries_dirs, libraries + + +def _get_python_include_dirs() -> List[str]: + include_dir = Path(sysconfig.get_path("include")) + # On Darwin Python executable from a framework can return + # non-existing /Library/Python/... include path, in which case + # one should use Headers folder from the framework + if not include_dir.exists() and platform.system() == "Darwin": + std_lib = Path(sysconfig.get_path("stdlib")) + include_dir = (std_lib.parent.parent / "Headers").absolute() + if not (include_dir / "Python.h").exists(): + warnings.warn(f"Can't find Python.h in {str(include_dir)}") + return [str(include_dir)] + + +def _get_python_related_args() -> Tuple[List[str], List[str]]: + python_include_dirs = _get_python_include_dirs() + python_include_path = sysconfig.get_path( + "include", scheme="nt" if _IS_WINDOWS else "posix_prefix" + ) + if python_include_path is not None: + python_include_dirs.append(python_include_path) + + if _IS_WINDOWS: + python_path = os.path.dirname(sys.executable) + python_lib_path = [os.path.join(python_path, "libs")] + else: + python_lib_path = [sysconfig.get_config_var("LIBDIR")] + + if config.is_fbcode(): + python_include_dirs.append(build_paths.python()) + + return python_include_dirs, python_lib_path + + +@functools.lru_cache(None) +def is_conda_llvm_openmp_installed() -> bool: + try: + command = "conda list llvm-openmp --json" + output = subprocess.check_output(command.split()).decode("utf8") + return len(json.loads(output)) > 0 + except subprocess.SubprocessError: + return False + + +@functools.lru_cache(None) +def homebrew_libomp() -> Tuple[bool, str]: + try: + # check if `brew` is installed + subprocess.check_output(["which", "brew"]) + # get the location of `libomp` if it is installed + # this is the location that `libomp` **would** be installed + # see https://github.com/Homebrew/brew/issues/10261#issuecomment-756563567 for details + libomp_path = ( + subprocess.check_output(["brew", "--prefix", "libomp"]) + .decode("utf8") + .strip() + ) + # check if `libomp` is installed + omp_available = os.path.exists(libomp_path) + return omp_available, libomp_path + except subprocess.SubprocessError: + return False, "" + + +@functools.lru_cache(None) +def perload_clang_libomp_win(cpp_compiler: str, omp_name: str) -> None: + try: + output = subprocess.check_output([cpp_compiler, "-print-file-name=bin"]).decode( + "utf8" + ) + omp_path = os.path.join(output.rstrip(), omp_name) + if os.path.isfile(omp_path): + os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + omp_module = cdll.LoadLibrary(omp_path) + except subprocess.SubprocessError: + pass + + +@functools.lru_cache(None) +def perload_icx_libomp_win(cpp_compiler: str) -> None: + def _load_icx_built_in_lib_by_name(cpp_compiler: str, lib_name: str) -> bool: + try: + output = subprocess.check_output( + [cpp_compiler, f"-print-file-name={lib_name}"], + stderr=subprocess.DEVNULL, + ).decode(*SUBPROCESS_DECODE_ARGS) + omp_path = output.rstrip() + if 
os.path.isfile(omp_path): + os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + omp_module = cdll.LoadLibrary(omp_path) + return True + except subprocess.SubprocessError: + pass + return False + + """ + Intel Compiler implenmented more math libraries than clang, for performance proposal. + We need preload them like openmp library. + """ + preload_list = [ + "libiomp5md.dll", # openmp + "svml_dispmd.dll", # svml library + "libmmd.dll", # libm + ] + + for lib_name in preload_list: + _load_icx_built_in_lib_by_name(cpp_compiler, lib_name) + + +def _get_openmp_args( + cpp_compiler: str, +) -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str]]: + cflags: List[str] = [] + ldflags: List[str] = [] + include_dir_paths: List[str] = [] + lib_dir_paths: List[str] = [] + libs: List[str] = [] + passthough_args: List[str] = [] + if _IS_MACOS: + # Per https://mac.r-project.org/openmp/ right way to pass `openmp` flags to MacOS is via `-Xclang` + cflags.append("Xclang") + cflags.append("fopenmp") + + # only Apple builtin compilers (Apple Clang++) require openmp + omp_available = not _is_apple_clang(cpp_compiler) + + # check the `OMP_PREFIX` environment first + omp_prefix = os.getenv("OMP_PREFIX") + if omp_prefix is not None: + header_path = os.path.join(omp_prefix, "include", "omp.h") + valid_env = os.path.exists(header_path) + if valid_env: + include_dir_paths.append(os.path.join(omp_prefix, "include")) + lib_dir_paths.append(os.path.join(omp_prefix, "lib")) + else: + warnings.warn("environment variable `OMP_PREFIX` is invalid.") + omp_available = omp_available or valid_env + + if not omp_available: + libs.append("omp") + + # prefer to use openmp from `conda install llvm-openmp` + conda_prefix = os.getenv("CONDA_PREFIX") + if not omp_available and conda_prefix is not None: + omp_available = is_conda_llvm_openmp_installed() + if omp_available: + conda_lib_path = os.path.join(conda_prefix, "lib") + include_dir_paths.append(os.path.join(conda_prefix, "include")) + lib_dir_paths.append(conda_lib_path) + # Prefer Intel OpenMP on x86 machine + if os.uname().machine == "x86_64" and os.path.exists( + os.path.join(conda_lib_path, "libiomp5.dylib") + ): + libs.append("iomp5") + + # next, try to use openmp from `brew install libomp` + if not omp_available: + omp_available, libomp_path = homebrew_libomp() + if omp_available: + include_dir_paths.append(os.path.join(libomp_path, "include")) + lib_dir_paths.append(os.path.join(libomp_path, "lib")) + + # if openmp is still not available, we let the compiler to have a try, + # and raise error together with instructions at compilation error later + elif _IS_WINDOWS: + """ + On Windows, `clang` and `icx` have their specific openmp implenmention. + And the openmp lib is in compiler's some sub-directory. + For dynamic library(DLL) load, the Windows native APIs are `LoadLibraryA` and `LoadLibraryExA`, and their search + dependencies have some rules: + https://learn.microsoft.com/en-us/windows/win32/api/libloaderapi/nf-libloaderapi-loadlibraryexa#searching-for-dlls-and-dependencies + In some case, the rules may not include compiler's sub-directories. + So, it can't search and load compiler's openmp library correctly. + And then, the whole application would be broken. + + To avoid the openmp load failed, we can automatic locate the openmp binary and preload it. + 1. For clang, the function is `perload_clang_libomp_win`. + 2. For icx, the function is `perload_icx_libomp_win`. 
+ """ + if _is_clang(cpp_compiler): + cflags.append("openmp") + libs.append("libomp") + perload_clang_libomp_win(cpp_compiler, "libomp.dll") + elif _is_intel_compiler(cpp_compiler): + cflags.append("Qiopenmp") + libs.append("libiomp5md") + perload_icx_libomp_win(cpp_compiler) + else: + # /openmp, /openmp:llvm + # llvm on Windows, new openmp: https://devblogs.microsoft.com/cppblog/msvc-openmp-update/ + # msvc openmp: https://learn.microsoft.com/zh-cn/cpp/build/reference/openmp-enable-openmp-2-0-support?view=msvc-170 + cflags.append("openmp") + cflags.append("openmp:experimental") # MSVC CL + else: + if config.is_fbcode(): + include_dir_paths.append(build_paths.openmp()) + + openmp_lib = build_paths.openmp_lib() + fb_openmp_extra_flags = f"-Wp,-fopenmp {openmp_lib}" + passthough_args.append(fb_openmp_extra_flags) + + libs.append("omp") + else: + if _is_clang(cpp_compiler): + # TODO: fix issue, can't find omp.h + cflags.append("fopenmp") + libs.append("gomp") + elif _is_intel_compiler(cpp_compiler): + cflags.append("fiopenmp") + else: + cflags.append("fopenmp") + libs.append("gomp") + + return cflags, ldflags, include_dir_paths, lib_dir_paths, libs, passthough_args + + +def get_mmap_self_macro(use_mmap_weights: bool) -> List[str]: + macros = [] + if use_mmap_weights: + macros.append(" USE_MMAP_SELF") + return macros + + +def get_cpp_torch_options( + cpp_compiler: str, + vec_isa: VecISA, + include_pytorch: bool, + aot_mode: bool, + compile_only: bool, + use_absolute_path: bool, + use_mmap_weights: bool, +) -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str]]: + definations: List[str] = [] + include_dirs: List[str] = [] + cflags: List[str] = [] + ldflags: List[str] = [] + libraries_dirs: List[str] = [] + libraries: List[str] = [] + passthough_args: List[str] = [] + + torch_cpp_wrapper_definations = _get_torch_cpp_wrapper_defination() + use_custom_generated_macros_definations = _use_custom_generated_macros() + + ( + sys_libs_cflags, + sys_libs_include_dirs, + sys_libs_passthough_args, + ) = _setup_standard_sys_libs(cpp_compiler, aot_mode, use_absolute_path) + + isa_macros, isa_ps_args_build_flags = _get_build_args_of_chosen_isa(vec_isa) + + ( + torch_include_dirs, + torch_libraries_dirs, + torch_libraries, + ) = _get_torch_related_args(include_pytorch=include_pytorch, aot_mode=aot_mode) + + python_include_dirs, python_libraries_dirs = _get_python_related_args() + + ( + omp_cflags, + omp_ldflags, + omp_include_dir_paths, + omp_lib_dir_paths, + omp_lib, + omp_passthough_args, + ) = _get_openmp_args(cpp_compiler) + + cxx_abi_passthough_args = _get_glibcxx_abi_build_flags() + fb_macro_passthough_args = _use_fb_internal_macros() + + mmap_self_macros = get_mmap_self_macro(use_mmap_weights) + + definations = ( + torch_cpp_wrapper_definations + + use_custom_generated_macros_definations + + isa_macros + + fb_macro_passthough_args + + mmap_self_macros + ) + include_dirs = ( + sys_libs_include_dirs + + python_include_dirs + + torch_include_dirs + + omp_include_dir_paths + ) + cflags = sys_libs_cflags + omp_cflags + ldflags = omp_ldflags + libraries_dirs = python_libraries_dirs + torch_libraries_dirs + omp_lib_dir_paths + libraries = torch_libraries + omp_lib + passthough_args = ( + sys_libs_passthough_args + + isa_ps_args_build_flags + + cxx_abi_passthough_args + + omp_passthough_args + ) + + return ( + definations, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthough_args, + ) + + +class CppTorchOptions(CppOptions): + """ + This class is 
inherited from CppTorchOptions, which automatic contains + base cxx build options. And then it will maintains torch related build + args. + 1. Torch include_directories, libraries, libraries_directories. + 2. Python include_directories, libraries, libraries_directories. + 3. OpenMP related. + 4. Torch MACROs. + 5. MISC + """ + + def __init__( + self, + vec_isa: VecISA = invalid_vec_isa, + include_pytorch: bool = False, + warning_all: bool = True, + aot_mode: bool = False, + compile_only: bool = False, + use_absolute_path: bool = False, + use_mmap_weights: bool = False, + shared: bool = True, + extra_flags: Sequence[str] = (), + ) -> None: + super().__init__( + compile_only=compile_only, + warning_all=warning_all, + extra_flags=extra_flags, + use_absolute_path=use_absolute_path, + ) + + self._aot_mode = aot_mode + + ( + torch_definations, + torch_include_dirs, + torch_cflags, + torch_ldflags, + torch_libraries_dirs, + torch_libraries, + torch_passthough_args, + ) = get_cpp_torch_options( + cpp_compiler=self._compiler, + vec_isa=vec_isa, + include_pytorch=include_pytorch, + aot_mode=aot_mode, + compile_only=compile_only, + use_absolute_path=use_absolute_path, + use_mmap_weights=use_mmap_weights, + ) + + _append_list(self._definations, torch_definations) + _append_list(self._include_dirs, torch_include_dirs) + _append_list(self._cflags, torch_cflags) + _append_list(self._ldflags, torch_ldflags) + _append_list(self._libraries_dirs, torch_libraries_dirs) + _append_list(self._libraries, torch_libraries) + _append_list(self._passthough_args, torch_passthough_args) + self._finalize_options() + + +def _set_gpu_runtime_env() -> None: + if ( + config.is_fbcode() + and torch.version.hip is None + and "CUDA_HOME" not in os.environ + and "CUDA_PATH" not in os.environ + ): + os.environ["CUDA_HOME"] = build_paths.cuda() + + +def _transform_cuda_paths(lpaths: List[str]) -> None: + # This handles two cases: + # 1. Meta internal cuda-12 where libs are in lib/cuda-12 and lib/cuda-12/stubs + # 2. 
Linux machines may have CUDA installed under either lib64/ or lib/ + for i, path in enumerate(lpaths): + if ( + "CUDA_HOME" in os.environ + and path.startswith(os.environ["CUDA_HOME"]) + and not os.path.exists(f"{path}/libcudart_static.a") + ): + for root, dirs, files in os.walk(path): + if "libcudart_static.a" in files: + lpaths[i] = os.path.join(path, root) + lpaths.append(os.path.join(lpaths[i], "stubs")) + break + + +def get_cpp_torch_cuda_options( + cuda: bool, + aot_mode: bool = False, + compile_only: bool = False, +) -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str]]: + definations: List[str] = [] + include_dirs: List[str] = [] + cflags: List[str] = [] + ldflags: List[str] = [] + libraries_dirs: List[str] = [] + libraries: List[str] = [] + passthough_args: List[str] = [] + if ( + config.is_fbcode() + and "CUDA_HOME" not in os.environ + and "CUDA_PATH" not in os.environ + ): + os.environ["CUDA_HOME"] = ( + build_paths.rocm() if torch.version.hip else build_paths.cuda() + ) + + _set_gpu_runtime_env() + from torch.utils import cpp_extension + + include_dirs = cpp_extension.include_paths(cuda) + libraries_dirs = cpp_extension.library_paths(cuda) + + if cuda: + definations.append(" USE_ROCM" if torch.version.hip else " USE_CUDA") + + if torch.version.hip is not None: + if config.is_fbcode(): + libraries += ["amdhip64"] + else: + libraries += ["c10_hip", "torch_hip"] + definations.append(" __HIP_PLATFORM_AMD__") + else: + if config.is_fbcode(): + libraries += ["cuda"] + else: + libraries += ["c10_cuda", "cuda", "torch_cuda"] + + if aot_mode: + if config.is_fbcode(): + from torch._inductor.codecache import cpp_prefix_path + + cpp_prefix_include_dir = [f"{os.path.dirname(cpp_prefix_path())}"] + include_dirs += cpp_prefix_include_dir + + if cuda and torch.version.hip is None: + _transform_cuda_paths(libraries_dirs) + + if config.is_fbcode(): + if torch.version.hip is not None: + include_dirs.append(os.path.join(build_paths.rocm(), "include")) + else: + include_dirs.append(os.path.join(build_paths.cuda(), "include")) + + if aot_mode and cuda: + if torch.version.hip is None: + if not compile_only: + # Only add link args, when compile_only is false. + passthough_args = ["-Wl,-Bstatic -lcudart_static -Wl,-Bdynamic"] + + return ( + definations, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthough_args, + ) + + +class CppTorchCudaOptions(CppTorchOptions): + """ + This class is inherited from CppTorchOptions, which automatic contains + base cxx build options and torch common build options. And then it will + maintains cuda device related build args. 
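+
+    A minimal, illustrative usage sketch (the kernel name, source file and
+    output directory below are assumptions, not values taken from real callers):
+
+        options = CppTorchCudaOptions(cuda=True)
+        builder = CppBuilder(
+            name="my_kernel",
+            sources=["my_kernel.cpp"],
+            BuildOption=options,
+            output_dir="/tmp/build",
+        )
+        status, lib_path = builder.build()  # /tmp/build/my_kernel.so (.pyd on Windows)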
+ """ + + def __init__( + self, + vec_isa: VecISA = invalid_vec_isa, + include_pytorch: bool = False, + cuda: bool = True, + aot_mode: bool = False, + compile_only: bool = False, + use_absolute_path: bool = False, + use_mmap_weights: bool = False, + shared: bool = True, + extra_flags: Sequence[str] = (), + ) -> None: + super().__init__( + vec_isa=vec_isa, + include_pytorch=include_pytorch, + aot_mode=aot_mode, + compile_only=compile_only, + use_absolute_path=use_absolute_path, + use_mmap_weights=use_mmap_weights, + extra_flags=extra_flags, + ) + + cuda_definations: List[str] = [] + cuda_include_dirs: List[str] = [] + cuda_cflags: List[str] = [] + cuda_ldflags: List[str] = [] + cuda_libraries_dirs: List[str] = [] + cuda_libraries: List[str] = [] + cuda_passthough_args: List[str] = [] + + ( + cuda_definations, + cuda_include_dirs, + cuda_cflags, + cuda_ldflags, + cuda_libraries_dirs, + cuda_libraries, + cuda_passthough_args, + ) = get_cpp_torch_cuda_options( + cuda=cuda, aot_mode=aot_mode, compile_only=compile_only + ) + _append_list(self._definations, cuda_definations) + _append_list(self._include_dirs, cuda_include_dirs) + _append_list(self._cflags, cuda_cflags) + _append_list(self._ldflags, cuda_ldflags) + _append_list(self._libraries_dirs, cuda_libraries_dirs) + _append_list(self._libraries, cuda_libraries) + _append_list(self._passthough_args, cuda_passthough_args) + self._finalize_options() + + +def get_name_and_dir_from_output_file_path( + file_path: str, +) -> Tuple[str, str]: + """ + This function help prepare parameters to new cpp_builder. + Example: + input_code: /tmp/tmpof1n5g7t/5c/c5crkkcdvhdxpktrmjxbqkqyq5hmxpqsfza4pxcf3mwk42lphygc.cpp + name, dir = get_name_and_dir_from_output_file_path(input_code) + Run result: + name = c5crkkcdvhdxpktrmjxbqkqyq5hmxpqsfza4pxcf3mwk42lphygc + dir = /tmp/tmpof1n5g7t/5c/ + + put 'name' and 'dir' to CppBuilder's 'name' and 'output_dir'. + CppBuilder --> get_target_file_path will format output path accoding OS: + Linux: /tmp/tmppu87g3mm/zh/czhwiz4z7ca7ep3qkxenxerfjxy42kehw6h5cjk6ven4qu4hql4i.so + Windows: [Windows temp path]/tmppu87g3mm/zh/czhwiz4z7ca7ep3qkxenxerfjxy42kehw6h5cjk6ven4qu4hql4i.dll + """ + name_and_ext = os.path.basename(file_path) + name, ext = os.path.splitext(name_and_ext) + dir = os.path.dirname(file_path) + + return name, dir + + +class CppBuilder: + """ + CppBuilder is a cpp jit builder, and it supports both Windows, Linux and MacOS. + Args: + name: + 1. Build target name, the final target file will append extension type automatically. + 2. Due to the CppBuilder is supports mutliple OS, it will maintains ext for OS difference. + sources: + Source code file list to be built. + BuildOption: + Build options to the builder. + output_dir: + 1. The output_dir the taget file will output to. + 2. The default value is empty string, and then the use current dir as output dir. + 3. 
Final target file: output_dir/name.ext + """ + + def __get_python_module_ext(self) -> str: + SHARED_LIB_EXT = ".pyd" if _IS_WINDOWS else ".so" + return SHARED_LIB_EXT + + def __get_object_ext(self) -> str: + EXT = ".obj" if _IS_WINDOWS else ".o" + return EXT + + def __init__( + self, + name: str, + sources: Union[str, List[str]], + BuildOption: BuildOptionsBase, + output_dir: str = "", + ) -> None: + self._compiler = "" + self._cflags_args = "" + self._definations_args = "" + self._include_dirs_args = "" + self._ldflags_args = "" + self._libraries_dirs_args = "" + self._libraries_args = "" + self._passthough_parameters_args = "" + + self._output_dir = "" + self._target_file = "" + + self._use_absolute_path: bool = False + self._aot_mode: bool = False + + self._name = name + + # Code start here, initial self internal veriables firstly. + self._compiler = BuildOption.get_compiler() + self._use_absolute_path = BuildOption.get_use_absolute_path() + self._aot_mode = BuildOption.get_aot_mode() + + self._output_dir = output_dir + + self._compile_only = BuildOption.get_compile_only() + file_ext = ( + self.__get_object_ext() + if self._compile_only + else self.__get_python_module_ext() + ) + self._target_file = os.path.join(self._output_dir, f"{self._name}{file_ext}") + + if isinstance(sources, str): + sources = [sources] + + if config.is_fbcode(): + if self._aot_mode and not self._use_absolute_path: + inp_name = sources + # output process @ get_name_and_dir_from_output_file_path + else: + # We need to copy any absolute-path torch includes + inp_name = [os.path.basename(i) for i in sources] + self._target_file = os.path.basename(self._target_file) + + self._sources_args = " ".join(inp_name) + else: + self._sources_args = " ".join(sources) + + for cflag in BuildOption.get_cflags(): + if _IS_WINDOWS: + self._cflags_args += f"/{cflag} " + else: + self._cflags_args += f"-{cflag} " + + for defination in BuildOption.get_definations(): + if _IS_WINDOWS: + self._definations_args += f"/D {defination} " + else: + self._definations_args += f"-D {defination} " + + for inc_dir in BuildOption.get_include_dirs(): + if _IS_WINDOWS: + self._include_dirs_args += f"/I {inc_dir} " + else: + self._include_dirs_args += f"-I{inc_dir} " + + for ldflag in BuildOption.get_ldflags(): + if _IS_WINDOWS: + self._ldflags_args += f"/{ldflag} " + else: + self._ldflags_args += f"-{ldflag} " + + for lib_dir in BuildOption.get_libraries_dirs(): + if _IS_WINDOWS: + self._libraries_dirs_args += f'/LIBPATH:"{lib_dir}" ' + else: + self._libraries_dirs_args += f"-L{lib_dir} " + + for lib in BuildOption.get_libraries(): + if _IS_WINDOWS: + self._libraries_args += f'"{lib}.lib" ' + else: + self._libraries_args += f"-l{lib} " + + for passthough_arg in BuildOption.get_passthough_args(): + self._passthough_parameters_args += f"{passthough_arg} " + + def get_command_line(self) -> str: + def format_build_command( + compiler: str, + sources: str, + include_dirs_args: str, + definations_args: str, + cflags_args: str, + ldflags_args: str, + libraries_args: str, + libraries_dirs_args: str, + passthougn_args: str, + target_file: str, + ) -> str: + if _IS_WINDOWS: + # https://learn.microsoft.com/en-us/cpp/build/walkthrough-compile-a-c-program-on-the-command-line?view=msvc-1704 + # https://stackoverflow.com/a/31566153 + cmd = ( + f"{compiler} {include_dirs_args} {definations_args} {cflags_args} {sources} " + f"{passthougn_args} /LD /Fe{target_file} /link {libraries_dirs_args} {libraries_args} {ldflags_args} " + ) + cmd = normalize_path_separator(cmd) + 
else: + compile_only_arg = "-c" if self._compile_only else "" + cmd = re.sub( + r"[ \n]+", + " ", + f""" + {compiler} {sources} {definations_args} {cflags_args} {include_dirs_args} + {passthougn_args} {ldflags_args} {libraries_args} {libraries_dirs_args} {compile_only_arg} -o {target_file} + """, + ).strip() + return cmd + + command_line = format_build_command( + compiler=self._compiler, + sources=self._sources_args, + include_dirs_args=self._include_dirs_args, + definations_args=self._definations_args, + cflags_args=self._cflags_args, + ldflags_args=self._ldflags_args, + libraries_args=self._libraries_args, + libraries_dirs_args=self._libraries_dirs_args, + passthougn_args=self._passthough_parameters_args, + target_file=self._target_file, + ) + return command_line + + def get_target_file_path(self) -> str: + return normalize_path_separator(self._target_file) + + def build(self) -> Tuple[bytes, str]: + """ + It is must need a temperary directory to store object files in Windows. + After build completed, delete the temperary directory to save disk space. + """ + _create_if_dir_not_exist(self._output_dir) + _build_tmp_dir = os.path.join( + self._output_dir, f"{self._name}_{_BUILD_TEMP_DIR}" + ) + _create_if_dir_not_exist(_build_tmp_dir) + + build_cmd = self.get_command_line() + + status = run_compile_cmd(build_cmd, cwd=_build_tmp_dir) + + _remove_dir(_build_tmp_dir) + return status, self._target_file diff --git a/lib/python3.10/site-packages/torch/_inductor/cpu_vec_isa.py b/lib/python3.10/site-packages/torch/_inductor/cpu_vec_isa.py new file mode 100644 index 0000000000000000000000000000000000000000..bc4838e5f168550656e4e0024d8fb2003ccf06a9 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/cpu_vec_isa.py @@ -0,0 +1,373 @@ +# mypy: allow-untyped-defs +import dataclasses +import functools +import os +import platform +import re +import subprocess +import sys +from typing import Any, Callable, Dict, List + +import torch +from torch._inductor import config + + +_IS_WINDOWS = sys.platform == "win32" + + +def _get_isa_dry_compile_fingerprint(isa_flags: str) -> str: + # ISA dry compile will cost about 1 sec time each startup time. + # Please check the issue: https://github.com/pytorch/pytorch/issues/100378 + # Actually, dry compile is checking compile capability for ISA. + # We just record the compiler version, isa options and pytorch version info, + # and generated them to output binary hash path. + # It would optimize and skip compile existing binary. + from torch._inductor.cpp_builder import get_compiler_version_info, get_cpp_compiler + + compiler_info = get_compiler_version_info(get_cpp_compiler()) + torch_version = torch.__version__ + fingerprint = f"{compiler_info}={isa_flags}={torch_version}" + return fingerprint + + +class VecISA: + _bit_width: int + _macro: List[str] + _arch_flags: str + _dtype_nelements: Dict[torch.dtype, int] + + # Note [Checking for Vectorized Support in Inductor] + # TorchInductor CPU vectorization reuses PyTorch vectorization utility functions + # Hence, TorchInductor would depend on Sleef* to accelerate mathematical functions + # like exp, pow, sin, cos and etc. + # But PyTorch and TorchInductor might use different compilers to build code. If + # PyTorch uses gcc-7/g++-7 to build the release package, the libtorch_cpu.so + # will not expose the Sleef* AVX512 symbols since gcc-7/g++-7 cannot pass + # avx512 check in CMake - FindAVX.cmake. But TorchInductor install the latest + # gcc/g++ compiler by default while it could support the AVX512 compilation. 
+ # Therefore, there would be a conflict sleef version between PyTorch and + # TorchInductor. Hence, we dry-compile the following code to check whether current + # HW platform and PyTorch both could support AVX512 or AVX2. And suppose ARM + # also needs the logic + # In fbcode however, we are using the same compiler for pytorch and for inductor codegen, + # making the runtime check unnecessary. + _avx_code = """ +#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX) +#include +#include +#endif + +alignas(64) float in_out_ptr0[16] = {0.0}; + +extern "C" void __avx_chk_kernel() { + auto tmp0 = at::vec::Vectorized(1); + auto tmp1 = tmp0.exp(); + tmp1.store(in_out_ptr0); +} +""" # noqa: B950 + + _avx_py_load = """ +import torch +from ctypes import cdll +cdll.LoadLibrary("__lib_path__") +""" + + def bit_width(self) -> int: + return self._bit_width + + def nelements(self, dtype: torch.dtype = torch.float) -> int: + return self._dtype_nelements[dtype] + + def build_macro(self) -> List[str]: + return self._macro + + def build_arch_flags(self) -> str: + return self._arch_flags + + def __hash__(self) -> int: + return hash(str(self)) + + def check_build(self, code: str) -> bool: + from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT, write + from torch._inductor.cpp_builder import ( + CppBuilder, + CppTorchOptions, + normalize_path_separator, + ) + + key, input_path = write( + code, + "cpp", + extra=_get_isa_dry_compile_fingerprint(self._arch_flags), + ) + from filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + with lock: + output_dir = os.path.dirname(input_path) + buid_options = CppTorchOptions(vec_isa=self, warning_all=False) + x86_isa_help_builder = CppBuilder( + key, + [input_path], + buid_options, + output_dir, + ) + try: + # Check if the output file exist, and compile when not. 
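+                # Note: the artifact written above is keyed on
+                # _get_isa_dry_compile_fingerprint (compiler version, ISA flags,
+                # torch version), so once an ISA has been verified the cached
+                # binary is found here and the dry compile is skipped.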
+ output_path = normalize_path_separator( + x86_isa_help_builder.get_target_file_path() + ) + if not os.path.isfile(output_path): + status, target_file = x86_isa_help_builder.build() + + # Check build result + subprocess.check_call( + [ + sys.executable, + "-c", + VecISA._avx_py_load.replace("__lib_path__", output_path), + ], + cwd=output_dir, + stderr=subprocess.DEVNULL, + env={**os.environ, "PYTHONPATH": ":".join(sys.path)}, + ) + except Exception as e: + return False + + return True + + @functools.lru_cache(None) # noqa: B019 + def __bool__(self) -> bool: + if config.cpp.vec_isa_ok is not None: + return config.cpp.vec_isa_ok + + if config.is_fbcode(): + return True + + return self.check_build(VecISA._avx_code) + + +@dataclasses.dataclass +class VecNEON(VecISA): + _bit_width = 256 # This is required to leverage the compute implemented in aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h + _macro = ["CPU_CAPABILITY_NEON"] + if sys.platform == "darwin" and platform.processor() == "arm": + _macro.append("AT_BUILD_ARM_VEC256_WITH_SLEEF") + _arch_flags = "" # Unused + _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} + + def __str__(self) -> str: + return "asimd" # detects the presence of advanced SIMD on armv8-a kernels + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +@dataclasses.dataclass +class VecAVX512(VecISA): + _bit_width = 512 + _macro = ["CPU_CAPABILITY_AVX512"] + _arch_flags = ( + "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma" + if not _IS_WINDOWS + else "/arch:AVX512" + ) # TODO: use cflags + _dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32} + + def __str__(self) -> str: + return "avx512" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +@dataclasses.dataclass +class VecAMX(VecAVX512): + _arch_flags = VecAVX512._arch_flags + " -mamx-tile -mamx-bf16 -mamx-int8" + + def __str__(self) -> str: + return super().__str__() + " amx_tile" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + _amx_code = """ +#include +#include + +struct amx_tilecfg { + uint8_t palette_id; + uint8_t start_row; + uint8_t reserved_0[14]; + uint16_t colsb[16]; + uint8_t rows[16]; +}; + +extern "C" void __amx_chk_kernel() { + amx_tilecfg cfg = {0}; + _tile_loadconfig(&cfg); + _tile_zero(0); + _tile_dpbf16ps(0, 1, 2); + _tile_dpbusd(0, 1, 2); +} +""" + + @functools.lru_cache(None) # noqa: B019 + def __bool__(self) -> bool: + if super().__bool__(): + if config.is_fbcode(): + return False + if self.check_build(VecAMX._amx_code) and torch.cpu._init_amx(): + return True + return False + + +@dataclasses.dataclass +class VecAVX2(VecISA): + _bit_width = 256 + _macro = ["CPU_CAPABILITY_AVX2"] + _arch_flags = ( + "-mavx2 -mfma -mf16c" if not _IS_WINDOWS else "/arch:AVX2" + ) # TODO: use cflags + _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} + + def __str__(self) -> str: + return "avx2" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +@dataclasses.dataclass +class VecZVECTOR(VecISA): + _bit_width = 256 + _macro = [ + "CPU_CAPABILITY_ZVECTOR", + "CPU_CAPABILITY=ZVECTOR", + "HAVE_ZVECTOR_CPU_DEFINITION", + ] + _arch_flags = "-mvx -mzvector" + _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} + + def __str__(self) -> str: + return "zvector" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +@dataclasses.dataclass +class VecVSX(VecISA): + _bit_width = 256 # VSX simd supports 128 bit_width, but aten is emulating it as 256 + _macro = ["CPU_CAPABILITY_VSX"] + _arch_flags = 
"-mvsx" + _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} + + def __str__(self) -> str: + return "vsx" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +class InvalidVecISA(VecISA): + _bit_width = 0 + _macro = [""] + _arch_flags = "" + _dtype_nelements = {} + + def __str__(self) -> str: + return "INVALID_VEC_ISA" + + def __bool__(self) -> bool: # type: ignore[override] + return False + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +def x86_isa_checker() -> List[str]: + supported_isa: List[str] = [] + + def _check_and_append_supported_isa( + dest: List[str], isa_supported: bool, isa_name: str + ) -> None: + if isa_supported: + dest.append(isa_name) + + Arch = platform.machine() + """ + Arch value is x86_64 on Linux, and the value is AMD64 on Windows. + """ + if Arch != "x86_64" and Arch != "AMD64": + return supported_isa + + avx2 = torch.cpu._is_avx2_supported() + avx512 = torch.cpu._is_avx512_supported() + amx_tile = torch.cpu._is_amx_tile_supported() + + _check_and_append_supported_isa(supported_isa, avx2, "avx2") + _check_and_append_supported_isa(supported_isa, avx512, "avx512") + _check_and_append_supported_isa(supported_isa, amx_tile, "amx_tile") + + return supported_isa + + +invalid_vec_isa = InvalidVecISA() +supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON()] + + +# Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content +# might have too much redundant content that is useless for ISA check. Hence, +# we only cache some key isa information. +@functools.lru_cache(None) +def valid_vec_isa_list() -> List[VecISA]: + isa_list: List[VecISA] = [] + if sys.platform == "darwin" and platform.processor() == "arm": + isa_list.append(VecNEON()) + + if sys.platform not in ["linux", "win32"]: + return isa_list + + arch = platform.machine() + if arch == "s390x": + with open("/proc/cpuinfo") as _cpu_info: + while True: + line = _cpu_info.readline() + if not line: + break + # process line + featuresmatch = re.match(r"^features\s*:\s*(.*)$", line) + if featuresmatch: + for group in featuresmatch.groups(): + if re.search(r"[\^ ]+vxe[\$ ]+", group): + isa_list.append(VecZVECTOR()) + break + elif arch == "ppc64le": + isa_list.append(VecVSX()) + elif arch == "aarch64": + isa_list.append(VecNEON()) + elif arch in ["x86_64", "AMD64"]: + """ + arch value is x86_64 on Linux, and the value is AMD64 on Windows. 
+ """ + _cpu_supported_x86_isa = x86_isa_checker() + for isa in supported_vec_isa_list: + if all(flag in _cpu_supported_x86_isa for flag in str(isa).split()) and isa: + isa_list.append(isa) + + return isa_list + + +def pick_vec_isa() -> VecISA: + if config.is_fbcode() and (platform.machine() in ["x86_64", "AMD64"]): + return VecAVX2() + + _valid_vec_isa_list: List[VecISA] = valid_vec_isa_list() + if not _valid_vec_isa_list: + return invalid_vec_isa + + # If the simdlen is None, it indicates determine the vectorization length automatically + if config.cpp.simdlen is None: + assert _valid_vec_isa_list + return _valid_vec_isa_list[0] + + for isa in _valid_vec_isa_list: + if config.cpp.simdlen == isa.bit_width(): + return isa + + return invalid_vec_isa diff --git a/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py b/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py new file mode 100644 index 0000000000000000000000000000000000000000..5a33de0e3668998ac6f6d71ab16c696938ab1265 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py @@ -0,0 +1,2441 @@ +""" +CUDA graph trees are a safety abstraction over CUDAGraphs, similar to make_graph_callables, +which share the same memory pool. Sharing a memory pool is an extremely +important optimization when chaining multiple CUDA graphs together, as it +prevents you from needing to copy intermediate tensors from one graph to the +next, and reduces overall memory usage by allowing dead memory from the first +pool to be reused in the second. + +The standard graph/make_graph_callables support sharing memory pool, but +with a lot of caveats. CUDA graph trees remove these restrictions: + +* Previously, if you recorded graphs A, B, you had to replay A, B in that + order. With CUDA graph trees, after replaying A, you can change your + mind and record/replay a different graph B'; we will support efficient + execution of both A, B and A, B', using only max(mem(A, B), mem(A, B')). In + other words: we support arbitrary trees of CUDA graph operations, not just + sequences (this is why this feature is called CUDA graph trees.) + +* Previously, if you executed graph A, some non-CUDA graph code, and then + graph B, after executing graph B, it was not safe to retain any references + to intermediates produced by A. With CUDA graph trees, we track if any +outputs of graph A are still live by the time graph B is run, and make + sure graph B doesn't clobber there memory when reusing the CUDA graphs + pool. You'll get a separate recording of B depending on what tensors + stay live or dead. + +CUDA graph trees are flexible enough to be used in Dynamo across graph breaks, +which is their primary use case. + +The ability to switch from replay to record is fairly nontrivial: remember that +when you replay a CUDA graph, you only replay CUDA operations; no CPU side state +is updated. In particular, the CPU-side book-keeping for the allocator is not +reconstructed. However, to record a new child CUDA graph, we must restore this +book-keeping. This is what checkpoint pool state is used for. 
+""" + +from __future__ import annotations + +import contextlib +import dataclasses +import functools +import gc +import itertools +import operator +import sys +import threading +import traceback +import warnings +import weakref +from collections import defaultdict +from enum import auto, Enum +from typing import ( + Any, + Callable, + cast, + ContextManager, + Dict, + Generator, + Iterator, + List, + Optional, + Sequence, + Set, + Tuple, + Type, + TYPE_CHECKING, + TypeVar, + Union, +) + +import torch.fx +from torch import Tensor +from torch._dynamo.mutation_guard import GenerationTracker +from torch._dynamo.utils import counters, preserve_rng_state +from torch._inductor.compile_fx import ( + align_inputs_from_check_idxs, + copy_misaligned_inputs, + get_expanded_dims, + get_input_idxs_to_check, + index_expanded_dims, + remove_unaligned_input_idxs, + static_input, +) +from torch._inductor.cudagraph_utils import ( + check_for_mutation, + CheckInvariantStatus, + FunctionID, + log_cudagraph_skip_and_bump_counter, + log_data_ptr_mismatch, + maybe_warning_due_to_dynamic_shape, + ModelType, + OutputType, + PlaceholderInfo, + WrappedFunction, +) +from torch.multiprocessing.reductions import StorageWeakRef +from torch.storage import UntypedStorage +from torch.utils import _pytree as pytree +from torch.utils.weak import TensorWeakRef + + +if TYPE_CHECKING: + from torch._inductor.utils import InputType + from torch.types import _bool + +StorageWeakRefPointer = int +StorageDataPtr = int +NBytes = int +S = TypeVar("S", bound="StorageWeakRefWrapper") + + +if torch.backends.cuda.is_built(): + from torch._C import ( + _cuda_CUDAAllocator_AllocatorState as AllocatorState, + _set_cached_tensors_enabled as _set_cached_tensors_enabled, + ) +else: + + class AllocatorState: # type: ignore[no-redef] + pass + + def _set_cached_tensors_enabled(enabled: _bool) -> None: + pass + + +log = torch._logging.getArtifactLogger(__name__, "cudagraphs") + + +from . import config + + +@dataclasses.dataclass(frozen=True) +class GraphID: + "Unique counter of a cuda graph recording" + id: int + + +def clear_cublass_cache() -> None: + """ + Cublas keeps a persistent workspace allocation for running matmuls. This poses a problem for + doing warmup within a CUDAGraph private pool because we do not want persistent allocations from + one one run to the next. When we begin a new run of a cudagraphs path (generation), all tensors + from the previous generation are freed. This frees them the memory pool, but not elsewhere. + A tensor in the cublas workspace would continue to be in use the workspace but would also get allocated + in the next run. The memory would be in use in two places. + + To solve this, we clear cublas caches before and after warming up or recording. If a workspace is required + it will be allocated to the cudagraph private pool and accounted for in the allocator for the duration of the + program. There is no overhead to this on replay since cudagraphs removes allocation overhead. 
+ """ + torch._C._cuda_clearCublasWorkspaces() + + +@contextlib.contextmanager +def clear_cublas_manager() -> Generator[None, None, None]: + "Context manager around clearing cublas caches that will clear on enter and exit" + clear_cublass_cache() + try: + yield + finally: + clear_cublass_cache() + + +@contextlib.contextmanager +def disable_conv_cache_emptying() -> Generator[None, None, None]: + prev = torch._C._cuda_get_conv_benchmark_empty_cache() + torch._C._cudnn_set_conv_benchmark_empty_cache(False) + try: + yield + finally: + torch._C._cudnn_set_conv_benchmark_empty_cache(prev) + + +@contextlib.contextmanager +def enable_history_recording() -> Generator[None, None, None]: + "Turns on history recording in the CUDA Caching Allocator" + enabled = torch._C._cuda_isHistoryEnabled() + try: + if not enabled: + torch.cuda.memory._record_memory_history() + yield + finally: + if not enabled: + torch.cuda.memory._record_memory_history(None) + + +def get_history_recording() -> ContextManager[None]: + # TODO - remove, prevents cleanup + if not config.triton.cudagraph_trees_history_recording: + return contextlib.nullcontext() + return enable_history_recording() + + +class TreeManagerContainer: + """ + Manages the lifetime of the tree manager. Like `PrivatePool` in cuda caching allocator, + the tree and its corresponding memory pool should be kept alive as long as any outstanding + graph or tensor which is an output of a graph remains alive. + + There is a single tree manager container per device. + + The lifecycle of a tree_manager is: + - Is constructed, no graph, no fns, no tensors + - Tree manager is fetched, resulting in tree manager being allocated + - We generate a bunch of functions, calling add_strong_reference + - These functions die, calling finalize_reference + - When all the functions die, we finalize_tree_manager. + + TODO: in the future, we would like to do the following once storage weak refs land + - We look for all the live storages and add references to THOSE + - We count as storages die + - All the storages are dead, we deallocate the tree manager + """ + + def __init__(self, device_index: int) -> None: + # This class keeps a strong reference to tree_manager, + # but upon all other strong references to the tree_manager will reset it to None. + # We need a strong reference so that we can still access its attributes upon cleanup. + self.tree_manager: Optional[CUDAGraphTreeManager] = None + + # Number of outstanding references to the current tree manager + self.live_cudagraphify_fns = 0 + + self.device_index = device_index + + # Following two objects are only set in the case that Tensor outputs outlive + # the cudagraphify_fns. Reference to the Graph is needed to keep the private pool from + # deallocation. 
+ self.live_storages_count = 0 + self.graph: Optional[torch.cuda.CUDAGraph] = None + + self.lock = threading.Lock() + + def _finalize_tensor(self) -> None: + with self.lock: + self.live_storages_count -= 1 + if self.live_storages_count == 0: + self.graph = None + + # manager was used again after existing cleanup, + # we shouldnt set it to None + if self.live_cudagraphify_fns == 0: + self.tree_manager = None + + def finalize_cudagraphify_fn(self) -> None: + with self.lock: + self.live_cudagraphify_fns -= 1 + if self.live_cudagraphify_fns == 0: + self._finalize_tree_manager() + + def _finalize_tree_manager(self) -> None: + assert self.lock.locked() + self.tree_manager = None + + # TODO - when issue #91395 is landed, we can set a weakref on + # storages and trigger a deallocation when all outputs of the + # cudagraph are dead. + + # live_storages = list( + # tree_manager.live_cudagraph_pool_storages_in_curr_execution() + # ) + + # # Maintain reference to graph to keep tensors alive + # assert len(tree_manager.roots) > 0, "expected at least one use" + # root = next(tree_manager.get_roots()) + # self.graph = root.graph + # seen_storages = set() + # for stor in live_storages: + # if stor in seen_storages: + # continue + # seen_storages.add(stor) + # self.live_storages_count += 1 + # . weakref.finalize(stor, self._finalize_tensor) + + def add_strong_reference(self, fn: Callable[..., Any]) -> None: + with self.lock: + self.live_cudagraphify_fns += 1 + + weakref.finalize(fn, self.finalize_cudagraphify_fn) + + def get_tree_manager(self) -> CUDAGraphTreeManager: + with self.lock: + if self.tree_manager is None: + self.tree_manager = CUDAGraphTreeManager(self.device_index) + return self.tree_manager + + +local = threading.local() + +# one tree manager per device +local.tree_manager_containers = {} +local.tree_manager_locks = defaultdict(threading.Lock) + + +# only incremented by user call of mark_step_begin +class MarkStepBox: + mark_step_counter = 0 + + +# We need to register this as an object that will be copied over as TLS when new +# threads are created in autograd +torch._C._stash_obj_in_tls("tree_manager_containers", local.tree_manager_containers) +torch._C._stash_obj_in_tls("tree_manager_locks", local.tree_manager_locks) + + +def mark_step_begin() -> None: + "Indicates that a new iteration of inference or training is about to begin." 
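+    # Illustrative call site (the public entry point torch.compiler.cudagraph_mark_step_begin()
+    # is assumed to route here): invoke it at the top of each iteration so outputs of the
+    # previous iteration are no longer considered live, e.g.
+    #
+    #   for batch in loader:
+    #       torch.compiler.cudagraph_mark_step_begin()
+    #       out = compiled_model(batch)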
+ + # iterate down to distinguish from GenerationTracking counter + MarkStepBox.mark_step_counter -= 1 + + +def reset_cudagraph_trees() -> None: + "Clear all cudagraph trees" + # see shutdown below for why this is necessary + container_dict = get_obj(local, "tree_manager_containers") + locks_dict = get_obj(local, "tree_manager_locks") + for device, lock in locks_dict.items(): + with lock: + container = container_dict.get(device) + if not container or not container.tree_manager: + continue + + container.tree_manager.shutdown() + + _set_cached_tensors_enabled(False) + container_dict.clear() + + MarkStepBox.mark_step_counter = 0 + + +def get_obj(local: Any, attr_name: str) -> Any: + if hasattr(local, attr_name): + return getattr(local, attr_name) + else: + assert torch._C._is_key_in_tls(attr_name) + return torch._C._get_obj_in_tls(attr_name) + + +def get_container(device_index: int) -> TreeManagerContainer: + container_dict = get_obj(local, "tree_manager_containers") + lock = get_obj(local, "tree_manager_locks")[device_index] + + with lock: + if device_index not in container_dict: + container_dict[device_index] = TreeManagerContainer(device_index) + + return container_dict[device_index] + + +def get_manager( + device_index: int, create_if_none_exists: bool = True +) -> Optional[CUDAGraphTreeManager]: + if create_if_none_exists: + return get_container(device_index).get_tree_manager() + return get_container(device_index).tree_manager + + +def cudagraphify_impl( + model: ModelType, + inputs: List[InputType], + static_input_idxs: Sequence[int], + *args: Any, + **kwargs: Any, +) -> ModelType: + fn_cache: Dict[Tuple[int, ...], Callable[..., Any]] = {} + + # Detect int inputs: we need to index on these + int_key = [i for i, v in enumerate(inputs) if isinstance(v, int)] + get_ints: Any = operator.itemgetter(*int_key) if int_key else lambda _: None + + has_warn = False + + del inputs + + def deferred_cudagraphify(inputs: List[InputType]) -> OutputType: + nonlocal has_warn + + int_key = get_ints(inputs) + fn = fn_cache.get(int_key) + if fn is not None: + return fn(inputs) + + if int_key is None: + log.info("recording cudagraph tree for graph without symints") + else: + log.info("recording cudagraph tree for symint key %s", int_key) + + if not has_warn: + has_warn = maybe_warning_due_to_dynamic_shape(fn_cache, int_key) + + # first get indices we need to check to align, then update our static inputs, + # and finally copy + check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs) + new_static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs) + copy_misaligned_inputs(inputs, check_input_idxs) + + fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs) + fn = align_inputs_from_check_idxs(fn, inputs_to_check=check_input_idxs) + fn_cache[int_key] = fn + + return out + + return deferred_cudagraphify + + +def cudagraphify( + model: ModelType, + inputs: List[InputType], + static_input_idxs: Sequence[int] = (), + *, + device_index: int, + is_backward: bool, + is_inference: bool, + stack_traces: Optional[StackTraces] = None, + constants: Tuple[torch.Tensor, ...] = (), + placeholders: Tuple[PlaceholderInfo, ...] = (), + mutated_input_idxs: Tuple[int, ...] 
= (), +) -> Tuple[ModelType, OutputType]: + manager = get_container(device_index).get_tree_manager() + assert not (is_backward and is_inference) + mode = ( + CompilationMode.BACKWARD + if is_backward + else (CompilationMode.INFERENCE if is_inference else CompilationMode.FORWARD) + ) + + return manager.add_function( + model, + inputs, + static_input_idxs, + stack_traces, + mode, + constants, + placeholders, + mutated_input_idxs, + ) + + +class StorageWeakRefWrapper: + """ + Wrapper around a storage weak ref. Will deallocate it upon expiration if invoked. + """ + + __slots__ = ["ref", "_data_ptr", "extra_ref_check"] + + storage_ref: Optional[StorageWeakRef] + + def __init__( + self, + inp: Union[Tensor, UntypedStorage], + extra_ref_check: Optional[Callable[[], bool]] = None, + ) -> None: + """ + extra_ref_check is an additional check we need to run to check if the + weak ref has expired. in checking storage use count we assume extra_ref_check + will hold an additional reference to the storage. + """ + if isinstance(inp, Tensor): + stor = inp.untyped_storage() + else: + assert isinstance(inp, UntypedStorage) + stor = inp + self.ref = StorageWeakRef(stor) + self._data_ptr = stor.data_ptr() + self.extra_ref_check = extra_ref_check + + @classmethod + def from_weakref_and_data_ptr( + cls: Type[S], + cdata: Any, + data_ptr: int, + extra_ref_check: Optional[Callable[[], bool]] = None, + ) -> StorageWeakRefWrapper: + instance = cls.__new__(cls) + instance._data_ptr = data_ptr + instance.ref = StorageWeakRef.from_weakref(cdata) + instance.extra_ref_check = extra_ref_check + return instance + + def __call__(self) -> Optional[StorageWeakRefPointer]: + if self.expired(): + return None + + return self.ref.cdata + + def swap_weakref(self, cdata: Any) -> None: + self.ref.__del__() + self.ref.cdata = cdata + + def data_ptr(self) -> int: + "NB: returns the data ptr even if the storage has expired" + return self._data_ptr + + def remove_extra_reference(self) -> None: + self.extra_ref_check = None + + def expired(self) -> bool: + if self.extra_ref_check is not None and not self.extra_ref_check(): + return False + + # if extra_ref_check is not None we expect an additional reference + stor_count = torch._C._storage_Use_Count(self.ref.cdata) + return (stor_count - (self.extra_ref_check is not None)) == 0 + + def __repr__(self) -> str: + if self.ref is None or self.ref.expired(): + return f"StorageWeakRefWrapper to {self.data_ptr()}; dead" + else: + return f"StorageWeakRefWrapper to {self.data_ptr()}; alive" + + +def is_live(weak_ref: Optional[StorageWeakRefWrapper]) -> bool: + return maybe_deref(weak_ref) is not None + + +def maybe_deref( + weak_ref: Optional[StorageWeakRefWrapper], +) -> Optional[Tuple[StorageWeakRefPointer, int]]: + if weak_ref is None: + return None + r = weak_ref() + if r is None: + return None + # NB: r.data_ptr() does not necessarily equal weak_ref.data_ptr() + return r, weak_ref.data_ptr() + + +@contextlib.contextmanager +def _use_cuda_memory_pool_manager( + device: int, mem_pool: Tuple[int, int], stream: torch.cuda.Stream +) -> Generator[None, None, None]: + """ + Context manager to use cuda graph pool for new allocations. If you use this manager + all cudagraph tensors in use should be reflected in the allocator or they will be overwritten. + existing_graph should already have been used in a capture, and the mem_pool must already exist, + because this manager will not preserve a reference to the pool which keeps it alive. 
+ """ + torch.cuda.synchronize() + stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(stream), torch.device(device): + torch._C._cuda_beginAllocateCurrentStreamToPool(device, mem_pool) + try: + yield + finally: + torch._C._cuda_endAllocateCurrentStreamToPool(device, mem_pool) + torch._C._cuda_releasePool(device, mem_pool) + + torch.cuda.current_stream().wait_stream(stream) + + +def map_to_ref(t: Optional[Tensor]) -> Optional[StorageWeakRefWrapper]: + if not isinstance(t, torch.Tensor): + assert t is None + return None + return StorageWeakRefWrapper(t) + + +# A path index of (depth, offset) indices into a graph that is `depth`` number of nodes from the root +# at graph output offset +PathOutputIndex = Tuple[int, int] + +# For each node in the path, for each output, is the output alive +PathLiveness = List[List[bool]] + +StackTraces = List[Optional[str]] + + +class CUDAWarmupNode: + """ + Simplified Wrapper around A CUDA Model that wraps outputs in storage refs and exposes + apis to get the live storages in the current chain of warmup. + + A CUDAWarmupNode may have either CUDAGraphNode or CUDAWarmupNode as a parent, but may only have + CUDAWarmupNode as children, because we cannot record or execute with tensors which do not have stable + memory addresses. + + CUDAWarmupNode and CUDAGraphNode have a number of differences that make it easier to use separate classes. + - Much of the CUDAGraphNode logic & initialization is based on the tensor properties of first recording. In the + first instance of warmup, these are not finalized yet. + - All Inputs to the RecordedFunction must be copied over to the cuda graph memory pool, this is unnecessary in warmup. + - CUDAWarmup is only used once and so does not need to optimize as much bookkeeping. It is much simpler. + + NB: this class and CUDAGraphNode need to expose `path_live_weakrefs`, `all_outputs_are_dead`, and + `self.outputs_weakrefs`, `stack_traces`, and `tensor_weakrefs` for compatibility. + """ + + def __init__( + self, + wrapped_function: WrappedFunction, + parent: Optional[Union[CUDAGraphNode, CUDAWarmupNode]], + cuda_graphs_pool: Tuple[int, int], + existing_cuda_graph: Optional[torch.cuda.CUDAGraph], + device_index: int, + stack_traces: Optional[StackTraces], + stream: torch.cuda.Stream, + already_warm: bool, + id: GraphID, + ) -> None: + self.wrapped_function = wrapped_function + self.parent: Optional[Union[CUDAGraphNode, CUDAWarmupNode]] = parent + self.cuda_graphs_pool = cuda_graphs_pool + self.outputs_weakrefs: List[Optional[StorageWeakRefWrapper]] = [] + self.tensor_weakrefs: List[Optional[TensorWeakRef]] = [] + self.existing_cuda_graph = existing_cuda_graph + self.has_run = False + self.device_index = device_index + self.stack_traces = stack_traces + self.stream = stream + self.already_warm = already_warm + self.id = id + + def run(self, new_inputs: Any) -> OutputType: + assert not self.has_run, "Wrapped function should never be run twice" + + # See: output_is_alias_of_persistent_static_inputs below. We should only be returning freshly created + # storages in path_live_weakrefs. 
+ existing_path_data_ptrs = { + t.data_ptr() for t in self.path_live_weakrefs() if t() + } + + def get_non_cudagraph_inps() -> List[weakref.ReferenceType[UntypedStorage]]: + non_cudagraph_inps = [] + for t in itertools.chain(new_inputs, self.wrapped_function.constants): + if ( + isinstance(t, torch.Tensor) + and t.untyped_storage().data_ptr() not in existing_path_data_ptrs + ): + non_cudagraph_inps.append(weakref.ref(t.untyped_storage())) + return non_cudagraph_inps + + non_cudagraph_inps_storages = get_non_cudagraph_inps() + + if config.triton.slow_path_cudagraph_asserts and not self.already_warm: + refs = list(self.path_live_weakrefs()) + check_memory_pool(self.device_index, self.cuda_graphs_pool, refs) + + with torch.cuda.device( + self.device_index + ), disable_conv_cache_emptying(), clear_cublas_manager(), _use_cuda_memory_pool_manager( + self.device_index, self.cuda_graphs_pool, self.stream + ), get_history_recording(): + out = self.wrapped_function.model(new_inputs) + + # We need to know which outputs are allocated within the cudagraph pool + # so that we can deallocate them at the beginning of the next cudagraph step, + # and set their access to error. + # We use a weakref to the inputs storage, in case a block which was previously + # allocated to the general caching allocator pool gets reallocated to a private pool. + + non_cudagraph_inps_storage_ptrs = set() + for storage in non_cudagraph_inps_storages: + s = storage() + if s is not None: + non_cudagraph_inps_storage_ptrs.add(s._cdata) + + assert len(new_inputs) == 0 + + # sdpa returns cpu tensors when not recording cuda graph + def add_ref(o: Any) -> bool: + return ( + isinstance(o, torch.Tensor) + and o.is_cuda + and o.untyped_storage()._cdata not in non_cudagraph_inps_storage_ptrs + and o.untyped_storage().data_ptr() != 0 + ) + + self.outputs_weakrefs.extend( + [map_to_ref(o) if add_ref(o) else None for o in out] + ) + self.tensor_weakrefs.extend( + [TensorWeakRef(o) if add_ref(o) else None for o in out] + ) + + if config.triton.slow_path_cudagraph_asserts and not self.already_warm: + out_refs = list(self.path_live_weakrefs()) + check_memory_pool(self.device_index, self.cuda_graphs_pool, out_refs) + + return out + + @property + def _path_from_root( + self, + ) -> Generator[Union[CUDAGraphNode, CUDAWarmupNode], None, None]: + nodes = [] + node: Union[CUDAGraphNode, CUDAWarmupNode] = self + while node: + nodes.append(node) + node = node.parent # type: ignore[assignment] + + yield from reversed(nodes) + + def path_live_weakrefs(self) -> Iterator[StorageWeakRefWrapper]: + "Returns all live storages weakrefs that created by nodes in this path" + for node in self._path_from_root: + for output in node.outputs_weakrefs: + if is_live(output): + yield output # type: ignore[misc] + + def all_outputs_are_dead(self) -> bool: + return not list(self.path_live_weakrefs()) + + def _is_cuda_graph_recorded_tensor(self, t: torch.Tensor) -> bool: + for storage_weak_ref in self.path_live_weakrefs(): + if t.untyped_storage().data_ptr() == storage_weak_ref.data_ptr(): + return True + return False + + +# Aliases for List that say what the indices denote +InputList = List # input indexes +OutputList = List # output indexes +LevelList = List # levels (distance from root of tree) + + +class OutputAliasInfo: + pass + + +class _UnaliasedStorage(OutputAliasInfo): + "Singleton to mark that the graph output constructs a new alias or is None" + + +UnaliasedStorage = _UnaliasedStorage() + + +class AliasesPriorGraphOutput(OutputAliasInfo): + "Marks that the 
graph output aliases an output of a prior graph" + __slots__ = ["index"] + + index: PathOutputIndex + + def __init__(self, index: PathOutputIndex) -> None: + assert isinstance(index, tuple) + self.index = index + + +class AliasesNewOutput(OutputAliasInfo): + "Marks that the graph output aliases an index in the new, returned outputs" + + __slots__ = ["index"] + + index: int + + def __init__(self, index: int) -> None: + assert isinstance(index, int) + self.index = index + + +class CUDAGraphNode: + """ + A single recording of a function into a CUDA Graph. Recordings of CUDA Graphs share a single memory pool + and are structured into a tree, where there is a single recording that can precede it (parent) and multiple + subsequent recordings that may follow (children). A node will have no parent if it is the first recording + in a tree; i.e., when it is first recorded, there are no live tensors from a previous recording which + would force a dependency. + + On first recording, all of the live tensors in the current CUDA Graph Node path will be + reflected in the corresponding private pool. On subsequent executions, the caching allocator + is unaffected when the graph is replayed. + + In order to support recording a subsequent cuda graph recording after execution of this graph, + we checkpoint the state of the memory pool so that it may later be resumed. + + WrappedFunction should have already been warmed up prior to invocation. + + See [setCheckpointPoolState] for further explanation, as well as + https://user-images.githubusercontent.com/13564/222815509-374f3400-f83d-4f7d-8fa6-4a092b3250bb.png + """ + + def __init__( + self, + wrapped_function: WrappedFunction, + id: GraphID, + parent: Optional[CUDAGraphNode], + inputs: List[InputType], + cuda_graphs_pool: Tuple[int, int], + device_index: int, + stack_traces: Optional[StackTraces], + stream: torch.cuda.Stream, + ) -> None: + assert isinstance(inputs, (list, tuple)) + + self.wrapped_function = wrapped_function + self.id = id + self.device = device_index + self.stack_traces = stack_traces + self.stream = stream + + # Enable re-record a cudagraph when static tensor address changed. + # if not we should error when it changed. + self.rerecord_if_static_inputs_change = ( + torch._dynamo.config.inline_inbuilt_nn_modules + or torch._inductor.config.triton.cudagraph_support_input_mutation + ) + + # if this is a root parent will be None. use weakref to prevent reference cycle + self._parent = weakref.ref(parent) if parent is not None else None + # reference to the shared memory pool for the entire cuda graphs tree + self.cuda_graphs_pool = cuda_graphs_pool + + # A single wrapped function may be recorded multiple times if memory patterns or + # invariants change from one execution to the next + self.children: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list) + + # StorageWeakRef maintains whether the Storage C++ object remains allocated, + # not whether the corresponding memory has been deallocated. In order + # to use them to track memory deallocations we must maintain a single StorageWeakRef + # for all Storages that reference that memory (even if we are constructing Storages + # that do not have a deallocator function). We maintain one single storage_cache + # as we execute any tree path. When we retrieve a storage from the cache we + # check that it is still alive, and we hash based on observed recording data ptr + # and storage cdata. 
+ + # we preserve a single reference to executed outputs that is then referenced + # in children to avoid children having to chase parent pointers in the hot path + # DO NOT reassign output_weakrefs, only call `clear()` + # Path is a series of nodes from root to the current node + self.outputs_weakrefs: OutputList[Optional[StorageWeakRefWrapper]] = [] + self.path_weakrefs: LevelList[OutputList[Optional[StorageWeakRefWrapper]]] = [ + node.outputs_weakrefs for node in self._path_from_root + ] + self.path_stacktraces: LevelList[Optional[StackTraces]] = [ + node.stack_traces for node in self._path_from_root + ] + self.tensor_weakrefs: OutputList[Optional[TensorWeakRef]] = [] + + # tensors which are outputs of previous graphs in the tree + self.cudagraph_managed_idxs: List[int] = [ + idx + for idx, t in enumerate(inputs) + if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t) + ] + + self.static_input_idxs: List[int] = list( + set(wrapped_function.static_input_idxs) | set(self.cudagraph_managed_idxs) + ) + + self.non_static_input_idx: LevelList[int] = [ + i for i in range(len(inputs)) if i not in self.static_input_idxs + ] + + counters["inductor"]["cudagraph_recorded_non_static_inputs"] += len( + self.non_static_input_idx + ) + + self.non_managed_static_input_idxs: LevelList[int] = [ + i + for i in wrapped_function.static_input_idxs + if i not in self.cudagraph_managed_idxs + ] + + def maybe_get_static_data_ptr( + idx: int, + inputs: List[Union[torch.Tensor, int]], + static_input_idxs: List[int], + ) -> Optional[int]: + inp = inputs[idx] + if isinstance(inp, torch.Tensor) and idx in static_input_idxs: + return inp.data_ptr() + return None + + self.static_input_data_ptrs: InputList[Optional[int]] = [ + maybe_get_static_data_ptr(i, inputs, self.static_input_idxs) + for i in range(len(inputs)) + ] + + # When we checkpoint, and free generations, we will be manually freeing the outputs + # of CUDAGraphNodes. We should not be freeing parameters, not do we need to account for + # their liveness (they are static), so we need to compute which outputs are aliases of + # parameters. Some static inputs are saved tensors from the forward that die in the backward. + # Their locations are static but lifetimes are not. We only include the persistent static + # data ptrs below because the non persistent data ptrs may be outputs of this record and + # fresh allocations. + + # precompute expanded dims to avoid computing in the hot path + self.expanded_dims: List[List[int]] = [ + get_expanded_dims(x) + if isinstance(x, torch.Tensor) and idx not in self.static_input_idxs + else [] + for idx, x in enumerate(inputs) + ] + + # For each node in path, which outputs were observed to be live + # before invoking graph recording, and after graph recording + self.recorded_liveness_before_graph: LevelList[OutputList[bool]] = [] + self.recorded_liveness_after_graph: LevelList[OutputList[bool]] = [] + + # List of Tuples of (depth, output_index) that index into node at depth + # number of nodes from root and output_index of outputs. Will index into + # path_weakrefs. 
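+        # e.g. (2, 0) refers to output 0 of the node two edges below the root.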
+ self.expected_dead_indices_before_graph: List[PathOutputIndex] = [] + self.expected_dead_indices_after_graph: List[PathOutputIndex] = [] + + # all live indices after graph recording + self.live_indices_after_graph: List[PathOutputIndex] = [] + + if self.parent is not None: + previous_liveness = self.parent.recorded_liveness_after_graph + curr_liveness = self._get_liveness(self.path_weakrefs) + + different_indices = self._get_different_indices( + previous_liveness, curr_liveness + ) + + self.recorded_liveness_before_graph = curr_liveness + self.expected_dead_indices_before_graph = different_indices + + recording_inputs = self._allocate_and_copy_recording_inputs(inputs) + # recording inputs will copy over memory, so we can free non recording inputs + inputs.clear() + del inputs + + # graph used for recording model invocation + self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph() + + # we allocate non-static inputs within the same memory pool as the CUDAGraph + # which we will record the model with. For memory efficiency, it is important + # to reclaim the input memory when the inputs are no longer live. To accomplish this, + # we reconstruct tensors at the correct data pointers of our inputs which are + # non owning and do not prevent deallocation. On subsequent executions, input values + # will be copied over to these tensors. + self.reconstructed_inputs: List[InputType] = [ + self._reconstruct_from_tensor_metadata(self._tensor_metadata(x)) + if isinstance(x, torch.Tensor) + else x + for x in recording_inputs + ] + + # DO THE RECORDING!!! + # We record the CUDA graph in the constructor of CUDAGraphNode, which + # gives you what the CPU side compute of the function would do. We + # don't throw the recording outputs away: their memory is + # correctly accounted for in the CUDAGraphs caching allocator. This + # means on the very FIRST run of the CUDA graph node, we can directly + # do more recording, because we have a valid caching allocator state. + # NB: This relies on run() being called immediately after the + # constructor, otherwise this optimization would not be valid. + + # initialized below in _record + + self.checkpointed_caching_state: Optional[AllocatorState] = None + + # Output Storage Alias information, can be: + # - A new, unaliased storage, or the output is None + # - An alias of an output of a prior graph + # - An alias of an output already created in the reconstructed outputs + # This is None if the output in question is an int + self.output_storage_alias: OutputList[Optional[OutputAliasInfo]] = [] + + # is the output Storage unaliased in subsequent outputs, of all subsequent paths + # if it is, we cached the output tensor and adjust storage liveness tracking to also + # check if the output tensor does not have an additional python reference. + # If a descendent node discovers it has an alias of a prior output, then the output + # will no longer be cached in the ancestor. + # The large majority of tensors are unaliased, and preserving aliased output tensors would add + # significant additional complexity with marginal gains + # The cached tensor outputs are added on the first execution, and cleared whenever we need + # to do subsequent recording + self.unaliased_in_all_paths: OutputList[bool] = [] + self.cached_tensor_outputs: OutputList[Optional[Tensor]] = [] + + # if an output aliases a static, persistent input then the corresponding Tensor will + # be set here. 
These are different than cached tensors, because they are tensors that + # are aliases of parameters that are always live. + self.static_output_tensors: OutputList[Optional[Tensor]] = [] + + # Cleared after recording + self.recording_outputs: Optional[OutputType] = self._record( + wrapped_function.model, recording_inputs + ) + self.outputs_metadata: OutputList[Union[Dict[str, Any], int, None]] = [] + + # As with inputs, we do not want to keep the outputs permanently alive because that would prevent + # their memory being reclaimed in subsequent cuda graph recordings. We record the tensor metadata + # needed to reconstruct instead. + assert self.recording_outputs is not None + for out in self.recording_outputs: + if isinstance(out, torch.Tensor): + self.outputs_metadata.append( + self._tensor_metadata(out, ignore_storage_offset=False) + ) + else: + assert isinstance(out, (int, type(None))), type(out) + self.outputs_metadata.append(out) + + self.graph.replay() + + def _copy_inputs_and_remove_from_src( + self, dsts: List[InputType], srcs: List[InputType] + ) -> None: + dst_tensors = [] + src_tensors = [] + for idx in self.non_static_input_idx: + if not isinstance(srcs[idx], torch.Tensor): + continue + expanded_dims = self.expanded_dims[idx] + dst_tensors.append(index_expanded_dims(dsts[idx], expanded_dims)) # type: ignore[arg-type] + src_tensors.append(index_expanded_dims(srcs[idx], expanded_dims)) # type: ignore[arg-type] + srcs[idx] = None # type: ignore[call-overload] + # Fails on empty lists + if dst_tensors: + torch._foreach_copy_(dst_tensors, src_tensors) + + def check_static_inputs_are_stable(self, new_inputs: List[InputType]) -> None: + # avoid checking managed tensor static points since we already checked those in check_invariants + if ( + not self.rerecord_if_static_inputs_change + and not torch._C._tensors_data_ptrs_at_indices_equal( + new_inputs, # type: ignore[arg-type] + self.static_input_data_ptrs, + self.non_managed_static_input_idxs, + ) + ): + # this should error + error_msg = log_data_ptr_mismatch( + self.wrapped_function.placeholders, + new_inputs, + self.static_input_data_ptrs, + self.non_managed_static_input_idxs, + CheckInvariantStatus.StaticInputIdxMismatch, + ) + torch._check(False, lambda: error_msg) + + def run_first_inputs(self, new_inputs: List[InputType]) -> OutputType: + if config.triton.fast_path_cudagraph_asserts: + self.debug_check_invariants_before_invocation() + + # graph is already invoked in the __init__ + # inputs are copied over in _allocate_recording_inputs and subsequently cleared + assert len(new_inputs) == 0 + outputs = self.recording_outputs + self.recording_outputs = None + assert outputs is not None + return outputs + + def run(self, new_inputs: List[InputType]) -> OutputType: + self.check_static_inputs_are_stable(new_inputs) + + self._copy_inputs_and_remove_from_src(self.reconstructed_inputs, new_inputs) + new_inputs.clear() + + self.run_graph() + + outputs = self.reconstruct_outputs() + + if config.triton.fast_path_cudagraph_asserts: + self.debug_check_invariants_after_invocation() + + if config.triton.force_cudagraph_sync: + torch.cuda.synchronize() + + # Reset this to run the check in the future + self.static_inputs_stable = False + + return outputs + + def reconstruct_outputs(self) -> OutputType: + "Reconstruct output tensors according to their saved metadata and alias information" + + # Cached tensors will not yet be set on the first execution + # They are also cleared in checkpointing, so if we checkpoint this node + # and then execute it 
again we will need to repopulate cached tensors + if not self.cached_tensor_outputs: + self._initialize_cached_tensors() + + outputs: OutputType = [] + + for i, (storage_info, metadata) in enumerate( + zip(self.output_storage_alias, self.outputs_metadata) + ): + if not isinstance(metadata, dict): # tensor metadata + assert isinstance(metadata, (int, type(None))) + outputs.append(metadata) + continue + + cached_t = self.cached_tensor_outputs[i] + if cached_t is not None: + # this output represents a fresh allocated tensor. + # We return the same TensorImpl from run to run to avoid overhead. + # autograd.Function will reset the Autograd meta of output tensors + # as part of aot_autograd, but _backward_hooks are stored on tensors separately, + # so we need to manually reset hooks. + if cached_t._backward_hooks is not None: + cached_t._backward_hooks = None + + # No need to update weakrefs, already correctly initialized + outputs.append(cached_t) + continue + + static_t = self.static_output_tensors[i] + if static_t is not None: + assert self.outputs_weakrefs[i] is None + outputs.append(static_t) + continue + + storage = self.prepare_alias_info_for_tensor_construction( + storage_info, metadata + ) + + if isinstance(storage, UntypedStorage) or storage is None: + out = self._reconstruct_from_tensor_metadata(metadata, storage) + else: + assert isinstance(storage, int) + out = self._reconstruct_from_tensor_metadata( + metadata, cast(torch.Tensor, outputs[storage]).untyped_storage() + ) + + outputs.append(out) + w = self.outputs_weakrefs[i] + assert w is not None + w.swap_weakref(out.untyped_storage()._weak_ref()) + + return outputs + + def prepare_alias_info_for_tensor_construction( + self, + out_alias_info: Optional[OutputAliasInfo], + metadata: Union[Dict[str, Any], int, None], + ) -> Union[UntypedStorage, None, int]: + if ( + isinstance(metadata, (int, type(None))) + or out_alias_info is UnaliasedStorage + ): + return None + + if isinstance(out_alias_info, AliasesPriorGraphOutput): + depth, existing_output_index = out_alias_info.index + ref = self.path_weakrefs[depth][existing_output_index] + assert ref is not None + return torch.UntypedStorage._new_with_weak_ptr(ref()) + + assert isinstance(out_alias_info, AliasesNewOutput) + return out_alias_info.index + + def prepare_storages_for_construction( + self, + ) -> List[Union[UntypedStorage, None, int]]: + output_storages = [] + for output_storage_alias, metadata in zip( + self.output_storage_alias, self.outputs_metadata + ): + output_storages.append( + self.prepare_alias_info_for_tensor_construction( + output_storage_alias, metadata + ) + ) + + return output_storages + + def run_graph(self) -> None: + assert self.graph is not None + self.graph.replay() + + def all_outputs_are_dead(self) -> bool: + "All outputs of the path from this node to its root are dead" + for depth, output_index in self.live_indices_after_graph: + if is_live(self.path_weakrefs[depth][output_index]): + return False + return True + + def _record(self, model: ModelType, inputs: List[InputType]) -> OutputType: + "Record the model" + + def static_input_iter() -> Generator[torch.Tensor, None, None]: + for i in self.wrapped_function.static_input_idxs: + _inp = inputs[i] + if isinstance( + _inp, torch.Tensor + ) and not self._is_cuda_graph_recorded_tensor(_inp): + yield _inp + + # see: output_is_alias_of_persistent_static_inputs above + static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper] = { + inp.untyped_storage().data_ptr(): StorageWeakRefWrapper(inp) + for inp in 
itertools.chain( + static_input_iter(), self.wrapped_function.constants + ) + } + + if config.triton.slow_path_cudagraph_asserts: + # need to use parent live weakrefs because live_indices isnt set yet + memory = ( + [] if self.parent is None else list(self.parent.path_live_weakrefs()) + ) + memory += [ + StorageWeakRefWrapper(elem) + for i, elem in enumerate(inputs) + if isinstance(elem, torch.Tensor) + and i not in self.wrapped_function.static_input_idxs + and elem.untyped_storage().data_ptr() != 0 + ] + check_memory_pool(self.device, self.cuda_graphs_pool, memory) + + with preserve_rng_state(), torch.cuda.device( + self.device + ), clear_cublas_manager(), torch.cuda.graph( + self.graph, + stream=self.stream, + pool=self.cuda_graphs_pool, + capture_error_mode="thread_local", + ), get_history_recording(): + static_outputs = model(inputs) + + # running model should reclaim memory + assert len(inputs) == 0 + + if not isinstance(static_outputs, (list, tuple)): + static_outputs = (static_outputs,) + + self._add_first_outputs(static_outputs, static_input_persistent_storage_ptrs) + + return static_outputs + + def _add_first_outputs( + self, + outputs: OutputType, + static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper], + ) -> None: + "Add the outputs from the first invocation of the node and set up metadata" + + # getting liveness before we have added the outputs to path, so the length + # of the two lists is equal + prev_liveness = self.recorded_liveness_before_graph + curr_liveness = self._get_liveness(self.path_weakrefs) + + delta = self._get_different_indices(prev_liveness, curr_liveness) + self.expected_dead_indices_after_graph = delta + + assert len(self.outputs_weakrefs) == 0 + # index from data pointer to index in outputs + output_new_storages_index: Dict[StorageDataPtr, int] = {} + + self.unaliased_in_all_paths = [False for _ in range(len(outputs))] + self.static_output_tensors = [None for _ in range(len(outputs))] + + for i, o in enumerate(outputs): + if o is None or not isinstance(o, torch.Tensor): + self.output_storage_alias.append(UnaliasedStorage) + continue + + torch._check( + o.is_cuda or o.untyped_storage().data_ptr() == 0, + lambda: ( + "Expected all cuda outputs in cuda graph recording. 
Non cuda output " + f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}" + ), + ), + + ref = static_input_persistent_storage_ptrs.get( + o.untyped_storage().data_ptr(), None + ) + # also treat empty storages as static outputs because we do not need to manage their lifetime + # and they should not participate in checkpointing + is_empty_storage = o.untyped_storage().data_ptr() == 0 + if (ref and ref() is not None) or is_empty_storage: + self.output_storage_alias.append(None) + self.static_output_tensors[i] = o + continue + + path_ref = self._is_alias_of_live_recorded_tensor(o) + if path_ref is not None: + self._mark_prior_graph_output_as_aliased(path_ref) + self.output_storage_alias.append(AliasesPriorGraphOutput(path_ref)) + continue + + if o.untyped_storage().data_ptr() in output_new_storages_index: + index = output_new_storages_index[o.untyped_storage().data_ptr()] + self.unaliased_in_all_paths[index] = False + self.output_storage_alias.append(AliasesNewOutput(index)) + continue + + output_new_storages_index[o.untyped_storage().data_ptr()] = i + self.output_storage_alias.append(UnaliasedStorage) + self.unaliased_in_all_paths[i] = True + + if self.stack_traces is None: + self.stack_traces = [None for _ in range(len(outputs))] + else: + assert len(self.stack_traces) == len( + outputs + ), "Wrong number of stack traces passed in" + + assert not self.outputs_weakrefs + for out, static_output_tensor in zip(outputs, self.static_output_tensors): + if not isinstance(out, torch.Tensor) or static_output_tensor is not None: + self.outputs_weakrefs.append(None) + self.tensor_weakrefs.append(None) + else: + self.outputs_weakrefs.append(StorageWeakRefWrapper(out)) + self.tensor_weakrefs.append(TensorWeakRef(out)) + + self.recorded_liveness_after_graph = self._get_liveness(self.path_weakrefs) + self.checkpointed_caching_state = torch._C._cuda_getCheckpointState( + self.device, self.cuda_graphs_pool + ) + + # now, get liveness with outputs added + for depth in range(len(self.path_weakrefs)): + for output_index in range(len(self.path_weakrefs[depth])): + if is_live(self.path_weakrefs[depth][output_index]): + self.live_indices_after_graph.append((depth, output_index)) + + self.debug_check_invariants_after_invocation() + if config.triton.slow_path_cudagraph_asserts: + check_memory_pool( + self.device, self.cuda_graphs_pool, list(self.path_live_weakrefs()) + ) + + def _mark_prior_graph_output_as_aliased(self, index: PathOutputIndex) -> None: + "Remove a graph output from the unaliased, cached tensors in an ancestor node" + depth, output_index = index + node = list(self._path_from_root)[depth] + node.unaliased_in_all_paths[output_index] = False + x = self.path_weakrefs[depth][output_index] + assert x is not None + x.remove_extra_reference() + + def _initialize_cached_tensors(self) -> None: + # we should not be clearing output_weakrefs, and they should be set in the first + # record run + assert len(self.outputs_weakrefs) == len(self.outputs_metadata) + + for i, (storage_info, metadata, make_cached) in enumerate( + zip( + self.output_storage_alias, + self.outputs_metadata, + self.unaliased_in_all_paths, + ) + ): + if not make_cached: + self.cached_tensor_outputs.append(None) + continue + + assert storage_info is UnaliasedStorage + assert isinstance(metadata, dict) + s = self.create_storage(metadata) + out = self._reconstruct_from_tensor_metadata(metadata, storage=s) # type: ignore[arg-type] + + # XXX: let autograd know that there will be an additional reference to the tensor + # that can be 
ignored when deciding whether to do gradient buffer inplacing. + # Otherwise, inplacing could differ between tracing and subsequent execution. + # For some models we tested this led to inputs no longer being in cudagraph pools, + # leading to spurious re-recordings. + # It also tells AMP cache that even though the tensor impls cannot be cached + # in dtype conversions. + + torch._C._add_cached_tensor(out) + + self_ref = weakref.ref(self) + + # one reference in our array, and calling sys.getrefcount bumps the refcount by one + def check_refcount(i: int) -> bool: + self_loc = self_ref() + if self_loc is None: + return False + return self_loc.get_output_refcount(i) == 2 + + check = functools.partial(check_refcount, i=i) + + self.outputs_weakrefs[i] = StorageWeakRefWrapper(out, extra_ref_check=check) + self.cached_tensor_outputs.append(out) + + def get_output_refcount(self, index: int) -> int: + return sys.getrefcount(self.cached_tensor_outputs[index]) + + @property + def parent(self) -> Optional[CUDAGraphNode]: + "unwraps the weakref to _parent" + return self._parent() if self._parent is not None else None + + @property + def _path_to_root(self) -> Generator[CUDAGraphNode, None, None]: + "Returns all nodes in the path starting at self and ending at root" + node = self + while node: + yield node + node = node.parent # type: ignore[assignment] + + @property + def _path_from_root(self) -> Generator[CUDAGraphNode, None, None]: + "Returns all nodes in the path starting at the root and ending at self" + nodes = reversed(list(self._path_to_root)) + yield from nodes + + def _is_cuda_graph_recorded_tensor(self, t: torch.Tensor) -> bool: + "Is this tensor an output of a node in this path" + for output_refs in self.path_weakrefs: + for storage_weak_ref in output_refs: + if storage_weak_ref is None: + continue + # don't need to check liveness of storage since the cuda graph managed + # memory is never released. + data_ptr = storage_weak_ref.data_ptr() + if t.untyped_storage().data_ptr() == data_ptr: + return True + + return False + + def _is_alias_of_live_recorded_tensor( + self, t: torch.Tensor + ) -> Optional[PathOutputIndex]: + for depth, output_refs in enumerate(self.path_weakrefs): + for output_index, storage_ref in enumerate(output_refs): + if (storage_and_ptr := maybe_deref(storage_ref)) is not None: + storage, ptr = storage_and_ptr + if ptr == t.untyped_storage().data_ptr(): + return (depth, output_index) + + return None + + @staticmethod + def _check_liveness( + indices: List[PathOutputIndex], + output_refs: List[List[Optional[StorageWeakRefWrapper]]], + ) -> bool: + "Check that all of the indices specified are dead references" + for depth, output_index in indices: + w = output_refs[depth][output_index] + assert w is not None + if w() is not None: + return False + return True + + def add_child(self, function_id: FunctionID, node: CUDAGraphNode) -> None: + "Adds node as a a child of self" + self.children[function_id].append(node) + + @staticmethod + def _get_different_indices( + prev: List[List[bool]], curr: List[List[bool]] + ) -> List[PathOutputIndex]: + "Find indices where the two lists differ." 
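+        # Worked example: prev = [[True, True]], curr = [[True, False]] -> [(0, 1)]
+        # (output 1 of the node at depth 0 was live before and is now dead).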
+ dead_indices = [] + assert len(prev) <= len(curr) + for i, (outputs1, outputs2) in enumerate(zip(prev, curr)): + assert len(outputs1) == len(outputs2) + for j, (output1, output2) in enumerate(zip(outputs1, outputs2)): + if output1 != output2: + dead_indices.append((i, j)) + + return dead_indices + + @staticmethod + def _get_liveness( + weakrefs: List[List[Optional[StorageWeakRefWrapper]]], + ) -> List[List[bool]]: + "Maps weakrefs to true if the reference is alive and false otherwise" + if len(weakrefs) == 0: + return [] + + return [pytree.tree_map(is_live, outputs) for outputs in weakrefs] + + def debug_assert_invariants( + self, expected_liveness: List[List[bool]], newly_dead: List[PathOutputIndex] + ) -> None: + if not config.triton.fast_path_cudagraph_asserts: + return + + for i, node in enumerate(self._path_from_root): + assert self.path_weakrefs[i] is node.outputs_weakrefs + + nodes = list(self._path_from_root) + + live_blocks = get_block_addrs(self.cuda_graphs_pool) + + live_storage_data_ptrs = set() + live_storage_weak_ptrs = set() + + for depth, outputs_liveness in enumerate(expected_liveness): + for output_idx, output_liveness in enumerate(outputs_liveness): + # tensor can die early, but it can't be alive when it should be dead + w = self.path_weakrefs[depth][output_idx] + if (stor_weak_ptr_and_data_ptr := maybe_deref(w)) is not None: + assert output_liveness + stor_weak_ptr, stor_data_ptr = stor_weak_ptr_and_data_ptr + assert (stor_data_ptr in live_storage_data_ptrs) == ( + stor_weak_ptr in live_storage_weak_ptrs + ) + live_storage_data_ptrs.add(stor_data_ptr) + live_storage_weak_ptrs.add(stor_weak_ptr) + + is_persistent_alias = ( + nodes[depth].static_output_tensors[output_idx] is not None + ) + + if is_persistent_alias: + assert stor_data_ptr not in live_blocks + + for depth, output_index in newly_dead: + assert not is_live(self.path_weakrefs[depth][output_index]) + + def debug_check_invariants_before_invocation(self) -> None: + self.debug_assert_invariants( + self.recorded_liveness_before_graph, self.expected_dead_indices_before_graph + ) + + def debug_check_invariants_after_invocation(self) -> None: + self.debug_assert_invariants( + self.recorded_liveness_before_graph, self.expected_dead_indices_after_graph + ) + + def data_ptrs_dead_since_invocation(self) -> List[int]: + """ + Since this node was invoked, return data ptrs of all tensor outputs that have died + in the current executing tree path. 
+ """ + curr_liveness = self._get_liveness(self.path_weakrefs) + _get_different_indices = self._get_different_indices( + self.recorded_liveness_after_graph, curr_liveness + ) + + path = list(self._path_from_root) + ptrs_to_deallocate = [] + for depth, output_index in _get_different_indices: + ptrs_to_deallocate.append( + path[depth].outputs_metadata[output_index]["data_ptr"] # type: ignore[index] + ) + + return ptrs_to_deallocate + + def path_live_weakrefs(self) -> Iterator[StorageWeakRefWrapper]: + for i, j in self.live_indices_after_graph: + out = self.path_weakrefs[i][j] + if out is not None and is_live(out): + yield out + + def remove_node_cached_tensors(self) -> None: + for t in self.cached_tensor_outputs: + if t is not None: + torch._C._remove_cached_tensor(t) + self.cached_tensor_outputs.clear() + + for i, unaliased in enumerate(self.unaliased_in_all_paths): + if unaliased: + n = self.outputs_weakrefs[i] + assert n is not None + n.remove_extra_reference() + + def remove_path_cached_tensors(self) -> None: + for node in self._path_from_root: + node.remove_node_cached_tensors() + + def clear_path_state(self) -> None: + "Clear the path state in this current executing node" + # this doesnt actually do anything right now, leaving it as placeholder + + @staticmethod + def _tensor_metadata( + x: torch.Tensor, ignore_storage_offset: bool = True + ) -> Dict[str, Any]: + assert isinstance(x, torch.Tensor) + # We ignore the storage offset for inputs, but not for outputs + # TODO: - should we make the storage resizable ? + return { + "nbytes": x.untyped_storage().nbytes(), + "data_ptr": x.untyped_storage().data_ptr(), + "size": x.shape, + "stride": x.stride(), + "dtype": x.dtype, + "device": x.device, + "storage_offset": x.storage_offset() if not ignore_storage_offset else 0, + } + + def _reconstruct_from_tensor_metadata( + self, metadata: Dict[str, Any], storage: Optional[UntypedStorage] = None + ) -> Tensor: + s = self.create_storage(metadata) if storage is None else storage + return torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata(metadata, s) # type: ignore[arg-type] + + def create_storage(self, metadata: Dict[str, Any]) -> torch.types.Storage: + return torch._C._construct_storage_from_data_pointer( + metadata["data_ptr"], metadata["device"], metadata["nbytes"] + ) + + def _allocate_and_copy_recording_inputs( + self, inputs: List[InputType] + ) -> List[Union[torch.Tensor, int]]: + """ + Allocate inputs for non static, non cudagraph managed tensors in the memory pool + and copy over the tensor values. + """ + + torch.cuda.synchronize() + self.stream.wait_stream(torch.cuda.current_stream()) + recording_inputs: List[InputType] = [] + + with warnings.catch_warnings(record=True), torch.cuda.device( + self.device + ), _use_cuda_memory_pool_manager( + self.device, + mem_pool=self.cuda_graphs_pool, + stream=self.stream, + ): + for i, inp in enumerate(inputs): + if not isinstance(inp, torch.Tensor): + assert isinstance(inp, int) + recording_inputs.append(inp) + elif i not in self.static_input_idxs: + # static_input does an allocation! + recording_inputs.append(static_input(inp)) + else: + recording_inputs.append(inp) + + self._copy_inputs_and_remove_from_src(recording_inputs, inputs) + + return recording_inputs + + def check_invariants( + self, inputs: List[InputType] + ) -> Tuple[CheckInvariantStatus, Callable[..., str]]: + """ + Checks if this node can be run. The same pattern of tensor liveness, static inputs, + and tensors managed in the cudagraph private pool must remain stable. 
+ """ + + _logger = functools.partial( + log_data_ptr_mismatch, + self.wrapped_function.placeholders, + inputs, + self.static_input_data_ptrs, + ) + + # previously managed data pointers remain stable + # this is on the hot path so moved to C++. equivalent to: + # return all(t.data_ptr() == data_ptr for (t, data_ptr) in zip(tensors, data_ptrs)) + if not torch._C._tensors_data_ptrs_at_indices_equal( + inputs, # type: ignore[arg-type] + self.static_input_data_ptrs, + self.cudagraph_managed_idxs, + ): + status = CheckInvariantStatus.CudagraphManagedIdxMismatch + _logger = functools.partial( + _logger, + self.cudagraph_managed_idxs, + status, + ) + return status, _logger + + if not self._check_liveness( + self.expected_dead_indices_before_graph, self.path_weakrefs + ): + status = CheckInvariantStatus.ExpectedDeadIndicesBeforeGraphMismatch + return status, lambda: f"{status}" + + # static input data pointers should remain stable + # if we are inlining builtin nn modules we re-record in this case + # if we are not inlining builtin nn modules, we check this in check_static_inputs_are_stable + # and error if they are not stable + if ( + self.rerecord_if_static_inputs_change + and not torch._C._tensors_data_ptrs_at_indices_equal( + inputs, # type: ignore[arg-type] + self.static_input_data_ptrs, + self.static_input_idxs, + ) + ): + status = CheckInvariantStatus.StaticInputIdxMismatch + _logger = functools.partial( + _logger, + self.static_input_idxs, + status, + ) + return status, _logger + + # the cudagraph managed tensors which died upon recording must also die upon + # this invocation. it is too late to check after we've replayed the graph, + # because we would have already written over their memory. + for idx in self.cudagraph_managed_idxs: + inputs[idx] = None # type: ignore[call-overload] + + torch._check( + self._check_liveness( + self.expected_dead_indices_after_graph, self.path_weakrefs + ), + lambda: "TODO: graph recording observed an input tensor deallocate during graph " + " recording that did not occur during replay. 
Please file an issue.", + ) + return CheckInvariantStatus.SUCCESS, lambda: f"{CheckInvariantStatus.SUCCESS}" + + def num_descendants(self) -> int: + "Total number of descendents of this node" + num_desc = 0 + for children in self.children.values(): + for child in children: + num_desc += 1 + num_desc += child.num_descendants() + return num_desc + + +def get_cudagraph_segments(pool_id: Tuple[int, int]) -> Any: + segments = torch.cuda.memory_snapshot() + return [segment for segment in segments if segment["segment_pool_id"] == pool_id] + + +def get_block_addrs(pool_id: Tuple[int, int], live_only: bool = True) -> List[int]: + blocks = [] + + for segment in get_cudagraph_segments(pool_id): + addr = segment["address"] + for block in segment["blocks"]: + if block["state"] == "active_allocated" or not live_only: + blocks.append(addr) + + addr += block["size"] + + return blocks + + +def format_tb(frames: List[Any]) -> str: + formatted_traceback = [] + + for entry in frames: + formatted_traceback.append( + traceback.FrameSummary(entry["filename"], entry["line"], entry["name"]) + ) + + return "".join(traceback.format_list(formatted_traceback)) + + +def check_memory_pool( + device: int, + pool_id: Tuple[int, int], + live_storages_ptrs: List[StorageWeakRefWrapper], +) -> None: + assert all( + isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs + ) # noqa: C419 + unique_storages = {stor.data_ptr() for stor in live_storages_ptrs if stor()} + + # check if there is a divergence first, then do the expensive snapshot call after + # we know it will error + if torch._C._cuda_checkPoolLiveAllocations(device, pool_id, unique_storages): + return + + # at this point we are past the fast-path. we have seen rare cases where a dead tensor is dead, + # but hasn't been gc'd yet, and gives false positive for allocated_not_in_live_storages + gc.collect() + + segments = get_cudagraph_segments(pool_id) + + allocated_not_in_live_storages = {} + + for segment in segments: + addr = segment["address"] + for block in segment["blocks"]: + if block["state"] == "active_allocated": + if addr not in unique_storages: + allocated_not_in_live_storages[addr] = block + else: + unique_storages.remove(addr) + + addr += block["size"] + + torch._check( + len(unique_storages) == 0, + lambda: f"These storage data ptrs are not allocated in pool {pool_id} but should be {unique_storages}", + ) + + if len(allocated_not_in_live_storages) != 0: + formatted = [] + for dp, block in allocated_not_in_live_storages.items(): + trace = format_tb(block.get("frames", [])) + formatted.append(f"Data Pointer: {dp}, history: \n{trace}") + formatted_s = "\n".join(formatted) + msg = ( + f"These live storage data ptrs are in the cudagraph pool but not " + f"accounted for as an output of cudagraph trees: \n\n{formatted_s}" + ) + raise RuntimeError(msg) + + +class ExecutionState(Enum): + """ + Represents the state of the CUDAGraph Tree. Will be None if there is no live current memory allocated + in the cuda graph pool. Otherwise will reflect the state of the most recently executed node. + """ + + NONE = auto() + WARMUP = auto() + RECORDING = auto() + EXECUTION = auto() + + +class CompilationMode(Enum): + FORWARD = auto() + BACKWARD = auto() + INFERENCE = auto() + + +class CUDAGraphTreeManager: + """ + Groups individual recordings or executions of cuda graphs into a tree of recordings, + and checks required invariants, and manages warmups of graphs. 
+ + When graphs are recorded in the same tree, it enforces subsequent execution + to follow the same order and have the same output tensor lifespans. To remove + unnecessary coupling of cuda graphs (and additional imposed invariants), + the tree manager will end a currently recording tree whenever it is valid - when + the memory pool no longer has any live allocations. + + We ignore outputs from a previous generation that correspond to prior model outputs. + Currently this is hardcoded as `GenerationTracker.generation`, tracked in torch dynamo. + # TODO: make generation increment configurable, warn on overwrite. + + We run graph warmups in the cudagraph memory pool and return the result on the first invocation + of a function. For many models it is important to reclaim activations as you run the backward. + If we were to warm up the model and keep an extra copy of the inputs around to subsequently + use for recording, we would incur a memory penalty. Additionally, if we are part way through training + your model and need to recompile, memory will be allocated to the cuda graph pool, so we run this + warmup in the cuda graph memory pool. As for recording, warm up needs the state of live tensors + to be accurately reflected, so we checkpoint the allocator state if we need to warm up following graph + replay. + """ + + def __init__(self, device_index: int) -> None: + # roots are functions which have no dependencies on another node. I.e., + # when they are first invoked, none of their inputs are outputs + # of another node, nor are there any live outputs of another node whose + # liveness would create a dependency. + self.roots: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list) + + # mapping from function id to wrapped function + self.ids_to_funcs: Dict[FunctionID, WrappedFunction] = {} + + self.ids_to_stack_traces: Dict[FunctionID, Optional[StackTraces]] = {} + + self.warmed_up_functions: Set[FunctionID] = set() + # if we fail to increment generation, and are stuck warming up, + # only warn on each function once + self.warned_functions: Set[FunctionID] = set() + torch._C._set_cached_tensors_enabled(True) + + # warn only once if a function mutates inputs + self.warned_mutation: Set[FunctionID] = set() + + # NB: cuda caching allocator will remember the stream a segment is allocated to + # and only allocate that segment to the same stream. we need to use a single stream + # for all allocations to the memory pool, otherwise the allocations to separate streams + # will not be reused; separate recordings would use the same memory pool, but not + # the same memory. + + with torch.cuda.device(device_index): + torch.cuda.synchronize() + self.stream = torch.cuda.Stream() + self.stream.wait_stream(torch.cuda.current_stream()) + + # Keeps Memory Pool Alive + self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph() + self.cuda_graphs_thread_pool = torch.cuda.graph_pool_handle() + + with warnings.catch_warnings(record=True), torch.cuda.graph( + self.graph, + pool=self.cuda_graphs_thread_pool, + stream=self.stream, + capture_error_mode="thread_local", + ): + pass + + self.graph_counter = itertools.count(0) + self.func_counter = itertools.count(0) + + # mapping from graph_id to (function id to mutation type hint) since we are + # specializing on a particular combination of Parent Node -> Function ID. 
+ self.non_cudagraph_managed_mutation_hint: Dict[ + Optional[GraphID], Dict[FunctionID, bool] + ] = defaultdict(dict) + self.warmup_node_counter = itertools.count(start=-1, step=-1) + + # mapping from graph_id to (function id to re-record count). We fall back to + # the eager function if a function is re-recorded frequently on a node. + self.num_rerecord: Dict[Optional[GraphID], Dict[FunctionID, int]] = defaultdict( + lambda: defaultdict(lambda: 0) + ) + + # whether the current node is in a state of warmup, recording, or execution. If + # there is no current node the state will be ExecutionState.NONE. + self.path_state = ExecutionState.NONE + self.device_index = device_index + + # the most recently invoked cudagraph wrapping of a function. Will be None + # when there is no output from a previous recording or execution whose memory + # we need to respect in the cuda caching allocator. If you incremented the generation, + # this will also be None, as we ignore those allocations. + self.current_node: Optional[Union[CUDAGraphNode, CUDAWarmupNode]] = None + + # current generation of cudagraph invocations. when torch.compile is run + # we increment the current generation. We are willing to ignore live outputs + # of a previous generation in checking liveness. + self.current_gen: int = -1 + + # number of instances we are in execution and failed to match to an + # existing child + self.debug_fail_counter = 0 + # number of instances we had to checkpoint the function + self.debug_checkpointing_counter = 0 + + self.id_to_mode: Dict[FunctionID, CompilationMode] = {} + + # Note: [Backward Generation Handling] + # We generally perform a sequence of forward executions followed by backward executions. + # If multiple torch.compile wrapped forwards are executed with their backwards pending, + # we should not disregard the outputs from a prior torch.compile since the entire training + # loop hasn't completed. Occasionally, a backward pass corresponding to a forward pass may + # not be executed, so we cannot wait for all pending backwards + # to have been invoked. Instead we wait for a single backward + # invocation. Triggering a backward pass typically doesn't lead to another torch.compile + # invocation, making it less likely for the generation to increase between multiple + # backward calls. The following use case is covered by this approach: + # mod1 = torch.compile(...) + # mod2 = torch.compile(...) 
+ # mod2(mod1(x)).sum().backward() + + self.running_forwards_with_pending_backwards = False + + def run(self, new_inputs: List[InputType], function_id: FunctionID) -> OutputType: + assert self.graph is not None, "Running CUDAGraph after shutdown" + out = self._run(new_inputs, function_id) + + # The forwards are only pending following invocation, not before + mode = self.id_to_mode[function_id] + if mode == CompilationMode.FORWARD: + self.running_forwards_with_pending_backwards = True + elif mode == CompilationMode.BACKWARD: + self.running_forwards_with_pending_backwards = False + + return out + + def set_to_running_backward(self) -> None: + self.running_forwards_with_pending_backwards = False + + def _get_cuda_graph_recorded_tensor_checker(self) -> Callable[[Tensor], bool]: + return ( + self.current_node._is_cuda_graph_recorded_tensor + if isinstance(self.current_node, (CUDAGraphNode, CUDAWarmupNode)) + else lambda _: False + ) + + def new_warmup_node_id(self) -> GraphID: + return GraphID(next(self.warmup_node_counter)) + + def _update_non_cudagraph_managed_mutation( + self, function_id: FunctionID, inputs: List[InputType] + ) -> None: + node_id = self._get_node_id() + if maybe_mutation_str := check_for_mutation( + self.ids_to_funcs[function_id], + inputs, + self._get_cuda_graph_recorded_tensor_checker(), + ): + self.non_cudagraph_managed_mutation_hint[node_id][function_id] = True + # warn once per function_id + if function_id in self.warned_mutation: + return + self.warned_mutation.add(function_id) + log_cudagraph_skip_and_bump_counter(maybe_mutation_str) + else: + self.non_cudagraph_managed_mutation_hint[node_id][function_id] = False + + def _get_node_id(self) -> Optional[GraphID]: + if self.current_node is None: + return None + elif isinstance(self.current_node, (CUDAGraphNode, CUDAWarmupNode)): + return self.current_node.id + else: + raise RuntimeError(f"Unknown node type {type(self.current_node)}") + + def exceed_rerecord_limit( + self, node_id: Optional[GraphID], function_id: FunctionID + ) -> bool: + if torch._dynamo.config.inline_inbuilt_nn_modules: + return False + + return ( + self.num_rerecord[node_id][function_id] + > torch._inductor.config.triton.cudagraph_unexpected_rerecord_limit + ) + + def _run(self, new_inputs: List[InputType], function_id: FunctionID) -> OutputType: + # we will try to end the current execution lazily, since + # we dont want to do unnecessary checking of the existing outputs + # on the hot path, but both recording and warmup only happen once + # so we check up front + if self.in_recording: + self.try_end_curr_recording(function_id) + + if self.in_warmup: + self.try_end_curr_warmup(function_id) + + node_id = self._get_node_id() + if function_id not in self.non_cudagraph_managed_mutation_hint[node_id]: + self._update_non_cudagraph_managed_mutation(function_id, new_inputs) + + # Early exit if the function mutates inputs which are neither parameters/buffers nor + # cudagraph recorded tensors. This check should happen after `try_end_curr_recording` + # and `try_end_curr_warmup` which may change self.current_node. + if self.non_cudagraph_managed_mutation_hint[node_id][ + function_id + ] or self.exceed_rerecord_limit(node_id, function_id): + return self.ids_to_funcs[function_id].model(new_inputs) + + # warming up a function and subsequentally recording may use different memory addresses + # because both depend on the state of the caching allocator. 
if we warm up graph A, + # then warm up graph B and make more allocations, the subsequent recording of A will not + # necessarily use the same addresses as in the warm up. Thus any warm up of a node can only + # be followed by warm up runs. + if ( + ( + not ( + function_id in self.warmed_up_functions + or config.triton.skip_cudagraph_warmup + ) + ) + or self.in_warmup + or config.triton.force_cudagraphs_warmup + ): + # If we are in the middle of executing cuda graphs, then we need to checkpoint memory state. + # Both Recording and Warmup will be reflected in the allocator and dont need changes + if self.path_state == ExecutionState.EXECUTION: + self.apply_checkpoint_execution_state_in_allocator() + + return self.run_eager(new_inputs, function_id) + + assert not isinstance(self.current_node, CUDAWarmupNode) + child_nodes = ( + self.roots if self.current_node is None else self.current_node.children + ) + + if not self.in_recording: + unexpected_rerecord, unexpected_rerecord_reason = False, lambda: "" + for child in child_nodes[function_id]: + # here we are checking memory consistency between recording and execution, + # as well as things like stability of tensor locations, etc + # and other + status, status_logger = child.check_invariants(new_inputs) + if status == CheckInvariantStatus.SUCCESS: + return self.execute_node(child, new_inputs) + + if ( + status == CheckInvariantStatus.StaticInputIdxMismatch + or status == CheckInvariantStatus.CudagraphManagedIdxMismatch + ): + unexpected_rerecord = True + unexpected_rerecord_reason = status_logger + + # now that we know the new function can't be run as a child of the + # current node, if it is a root, try to end the current execution. + # as noted above, we want to do this lazily to avoid having to + # check all existing outputs + if self.current_node is not None and function_id in self.roots: + self.try_end_curr_execution() + + # run again to hit the root matching case which must succeed + if self.current_node is None: + return self.run(new_inputs, function_id) + + if len(self.ids_to_funcs[function_id].mutated_input_idxs) > 0: + self._update_non_cudagraph_managed_mutation(function_id, new_inputs) + if self.non_cudagraph_managed_mutation_hint[self._get_node_id()][ + function_id + ]: + return self.ids_to_funcs[function_id].model(new_inputs) + + # nb: run before checkpointing because checkpointing is slow, and we will + # be using the eager caching allocator pool which does not require live + # accounting of tensors in cudagraph allocator + if unexpected_rerecord: + curr_node_id = self._get_node_id() + self.num_rerecord[curr_node_id][function_id] += 1 + if self.exceed_rerecord_limit(curr_node_id, function_id): + _id = curr_node_id.id if curr_node_id else None + log_cudagraph_skip_and_bump_counter( + f"skipping cudagraph due to function {function_id.id} exceeding max " + f"re-recording limit " + f"(={torch._inductor.config.triton.cudagraph_unexpected_rerecord_limit}) " + f"on cudagraph node {_id} due to {unexpected_rerecord_reason()}." + ) + return self.ids_to_funcs[function_id].model(new_inputs) + + # at this point, we necessarily will do a new recording + self.debug_fail_counter += 1 + + self.try_end_curr_execution() + if self.current_node is not None: + self.apply_checkpoint_execution_state_in_allocator() + + # now, we are in a recording state ! + return self.record_function(new_inputs, function_id) + + def shutdown(self) -> None: + """ + Remove all cached tensors in all nodes. 
Because cached tensors can hold gradients which in turn + might reference a backward which invokes a CUDA Graph Node, we have to manually clear them on shutdown + to avoid a reference cycle. + """ + nodes = [] + for roots in self.roots.values(): + nodes.extend(roots) + + while nodes: + node = nodes.pop() + for children in node.children.values(): + nodes.extend(children) + node.remove_node_cached_tensors() + node.graph = None + + self.graph = None + self.roots = None # type: ignore[assignment] + self.current_node = None + + def record_function( + self, new_inputs: List[InputType], function_id: FunctionID + ) -> OutputType: + assert not isinstance(self.current_node, CUDAWarmupNode) + graph_id = self.new_graph_id() + log.debug( + "Recording function %d of graph recording id %d", + function_id.id, + graph_id.id, + ) + torch.cuda.synchronize() + node = CUDAGraphNode( + self.ids_to_funcs[function_id], + graph_id, + self.current_node, + new_inputs, + self.cuda_graphs_thread_pool, + self.device_index, + self.ids_to_stack_traces[function_id], + self.stream, + ) + if self.current_node is None: + self.roots[function_id].append(node) + else: + self.current_node.add_child(function_id, node) + self.current_node = node + self.path_state = ExecutionState.RECORDING + self.update_generation() + torch.cuda.synchronize() + return node.run_first_inputs(new_inputs) + + def execute_node( + self, node: CUDAGraphNode, new_inputs: List[InputType] + ) -> OutputType: + self.current_node = node + self.path_state = ExecutionState.EXECUTION + self.update_generation() + return node.run(new_inputs) + + def run_eager( + self, new_inputs: List[InputType], function_id: FunctionID + ) -> OutputType: + # this is only stored on current node, because when we start a new path, + # we will deallocate it + already_warm = function_id in self.warmed_up_functions + if not already_warm: + log.debug("Running warmup of function %d", function_id.id) + else: + log.debug( + "Running eager of function %d because ancestor needed to warm up", + function_id.id, + ) + self.warmed_up_functions.add(function_id) + node = CUDAWarmupNode( + self.ids_to_funcs[function_id], + self.current_node, + self.cuda_graphs_thread_pool, + self.graph, + self.device_index, + self.ids_to_stack_traces[function_id], + self.stream, + already_warm, + self.new_warmup_node_id(), + ) + self.current_node = node + self.path_state = ExecutionState.WARMUP + self.update_generation() + return node.run(new_inputs) + + def new_graph_id(self) -> GraphID: + return GraphID(next(self.graph_counter)) + + def new_func_id(self) -> FunctionID: + return FunctionID(next(self.func_counter)) + + def add_function( + self, + model: ModelType, + inputs: List[InputType], + static_input_idxs: Sequence[int], + stack_traces: Optional[StackTraces], + mode: CompilationMode, + constants: Tuple[torch.Tensor, ...], + placeholders: Tuple[PlaceholderInfo, ...], + mutated_input_idxs: Tuple[int, ...], + ) -> Tuple[ModelType, OutputType,]: + id = self.new_func_id() + self.ids_to_stack_traces[id] = stack_traces + self.ids_to_funcs[id] = WrappedFunction( + model, + list(static_input_idxs), + id, + tuple(t for t in constants if isinstance(t, torch.Tensor) and t.is_cuda), + placeholders, + mutated_input_idxs, + ) + self.id_to_mode[id] = mode + fn = functools.partial(self.run, function_id=id) + + # container needs to set clean up when fn dies + get_container(self.device_index).add_strong_reference(fn) + return fn, fn(inputs) + + @property + def in_recording(self) -> bool: + return self.path_state == 
ExecutionState.RECORDING + + @property + def in_warmup(self) -> bool: + return self.path_state == ExecutionState.WARMUP + + def get_roots(self) -> Iterator[CUDAGraphNode]: + for nodes in self.roots.values(): + yield from nodes + + @property + def current_node(self) -> Optional[Union[CUDAGraphNode, CUDAWarmupNode]]: + return self._current_node + + @current_node.setter + def current_node( + self, value: Optional[Union[CUDAGraphNode, CUDAWarmupNode]] + ) -> None: + self._current_node = value + if value is None: + self.path_state = ExecutionState.NONE + + def update_generation(self) -> None: + self.current_gen = self.get_curr_generation() + + @staticmethod + def get_curr_generation() -> int: + if MarkStepBox.mark_step_counter != 0: + return MarkStepBox.mark_step_counter + + return GenerationTracker.generation + + @staticmethod + def user_invoked_mark_step() -> bool: + return MarkStepBox.mark_step_counter != 0 + + def can_start_new_generation(self) -> bool: + if not self.in_new_torch_compile_invocation(): + return False + + if self.user_invoked_mark_step(): + return True + + return not self.running_forwards_with_pending_backwards + + def in_new_torch_compile_invocation(self) -> bool: + return self.current_gen != self.get_curr_generation() + + def try_end_curr_recording(self, function_id: FunctionID) -> None: + """ + Check if the current recording can be terminated, either because all outputs of the + previously recorded node are dead or because it was executed in a different + generation. Will set current_node to None and in_recording to False if successful. + """ + assert self.in_recording + assert self.current_node is not None + + # multiple invocations, allow overwriting the previous generation + if self.can_start_new_generation(): + self.dealloc_current_path_weakrefs() + self.clear_current_path_state_and_set_to_none() + return + + if self.current_node.all_outputs_are_dead(): + self.clear_current_path_state_and_set_to_none() + return + + self.check_warn_on_unable_to_start_executing(function_id) + + def try_end_curr_execution(self) -> None: + """ + Check if the current executing node can be terminated, either because all outputs of the + previously executed node are dead or because it was executed in a different generation. + Will set current_node to None if successful. 
+ """ + + assert not self.in_recording + if self.current_node is None: + return + + if self.can_start_new_generation(): + self.clear_current_path_state_and_set_to_none() + return + + if self.current_node.all_outputs_are_dead(): + self.clear_current_path_state_and_set_to_none() + + def try_end_curr_warmup(self, function_id: FunctionID) -> None: + if self.can_start_new_generation(): + self.dealloc_current_path_weakrefs() + self.current_node = None + return + + assert self.current_node is not None + if self.current_node.all_outputs_are_dead(): + self.current_node = None + return + + self.check_warn_on_unable_to_start_executing(function_id) + + def check_warn_on_unable_to_start_executing(self, function_id: FunctionID) -> None: + "Warn if we are in a potential loop where we are unable to hit the fast path" + if ( + function_id in self.warned_functions + or not self.in_new_torch_compile_invocation() + ): + return + + assert self.current_node is not None + existing_nodes = [ + node + for node in self.current_node._path_from_root + if node.wrapped_function.id == function_id + ] + + if len(existing_nodes) <= 1: + return + + # repeated same pattern + parents = { + n.parent.wrapped_function.id + for n in itertools.chain(existing_nodes, (self.current_node,)) + if n.parent is not None + } + if len(parents) == len(existing_nodes): + return + + self.warned_functions.add(function_id) + warnings.warn( + "Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards. " + "Consider running with torch.no_grad() or using torch.compiler.cudagraph_mark_step_begin() " + "before each model invocation" + ) + + def dealloc_current_path_weakrefs(self) -> None: + assert self.current_node is not None + # TODO: we could also allow these weak refs to continue to be allocated, + # but that adds some complications. + for node in self.current_node._path_from_root: + assert node.stack_traces is not None + assert len(node.tensor_weakrefs) == len(node.stack_traces) + for t, stack_trace in zip(node.tensor_weakrefs, node.stack_traces): + ten = None if t is None else t() + if ten is None: + continue + + stack_trace = ( + stack_trace.strip() + if stack_trace + else "[Could not find stack trace]" + ) + msg = ( + "Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. " + f"Stack trace: {stack_trace}. " + "To prevent overwriting, clone the tensor outside of torch.compile() " + "or call torch.compiler.cudagraph_mark_step_begin() before each model invocation." + ) + torch._C._set_storage_access_error_msg(ten, msg) + + deleted = set() + for storage_ref in self.current_node.path_live_weakrefs(): + _storage_deref = storage_ref() + if _storage_deref and storage_ref.data_ptr() not in deleted: + deleted.add(storage_ref.data_ptr()) + torch._C._free_And_Remove_DeleterFn(_storage_deref) + + def clear_current_path_state_and_set_to_none(self) -> None: + assert isinstance(self.current_node, CUDAGraphNode) + self.current_node.clear_path_state() + self.current_node = None + + def apply_checkpoint_execution_state_in_allocator(self) -> None: + """ + Checkpoint the current execution state in the caching allocator so that + additional cudagraph recordings can be made respecting existing live storages. + """ + assert isinstance(self.current_node, CUDAGraphNode) + self.debug_checkpointing_counter += 1 + log.debug( + "Checkpointing cuda caching allocator state. 
Number of checkpoints %d", + self.debug_checkpointing_counter, + ) + + state = self.current_node.checkpointed_caching_state + device = self.current_node.device + assert state is not None and device is not None + + # currently we deallocate on instead of allowing stale recordings + stale_storages: List[int] = [] + + # remove cached tensors, otherwise they would prevent memory from being + # reclaimed in subsequent recordings + self.current_node.remove_path_cached_tensors() + live_storages_wrappers = list(self.current_node.path_live_weakrefs()) + + # path_live_weakrefs guarantees that t() will not be None + live_storages_weak_refs: list[int] = [t() for t in live_storages_wrappers] # type: ignore[misc] + ptrs_to_deallocate = self.current_node.data_ptrs_dead_since_invocation() + torch._C._cuda_setCheckpointPoolState( + device, state, stale_storages, live_storages_weak_refs + ) + + # NB: deduplicate aliased outputs + for ptr in set(ptrs_to_deallocate): + torch._C._cuda_cudaCachingAllocator_raw_delete(ptr) + + # Now the live blocks should be exactly equal to the live storages in private pool + if config.triton.slow_path_cudagraph_asserts: + check_memory_pool( + self.device_index, self.cuda_graphs_thread_pool, live_storages_wrappers + ) + for wrapper in live_storages_wrappers: + storage_ptr = wrapper() + assert storage_ptr is not None + assert torch._C._has_Standard_Deleter(storage_ptr) + assert wrapper.data_ptr() not in ptrs_to_deallocate + + def live_cudagraph_pool_storages_in_curr_execution( + self, + ) -> List[StorageWeakRefPointer]: + if self.current_node is None: + return [] + # explicitly ignoring previous recorded outputs from past path + # path_live_weakrefs() guarantees that t() will not be None + return [t() for t in self.current_node.path_live_weakrefs()] # type: ignore[misc] diff --git a/lib/python3.10/site-packages/torch/_inductor/cudagraph_utils.py b/lib/python3.10/site-packages/torch/_inductor/cudagraph_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4f97e1daf60f623f3baea38f8b60f0466af81385 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/cudagraph_utils.py @@ -0,0 +1,330 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import dataclasses +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import torch +from torch._dynamo.utils import counters +from torch._inductor.utils import InputType + + +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") +static_inputs_log = torch._logging.getArtifactLogger( + __name__, "cudagraph_static_inputs" +) + + +OutputType = List[Optional[Union[int, torch.Tensor]]] +ModelType = Callable[[List[InputType]], OutputType] + + +@dataclasses.dataclass(frozen=True) +class FunctionID: + "Unique counter of a function wrapped in cudagraphify_impl" + id: int + + +@dataclasses.dataclass(frozen=True) +class PlaceholderInfo: + """ + A serializable version of torch.fx.Node that contains information + pertinent to placeholder stack traces. We use these in logging and error messages + related to cudagraphs, and will cache these results. 
+ """ + + name: str + stack_trace: Optional[str] + # This field is recursive, but never cyclic (since a node never uses itself) + users: List[PlaceholderInfo] + mutating_use_stack_trace: Optional[str] + + +@dataclasses.dataclass(frozen=True) +class WrappedFunction: + """ + Represents a function that you want to record for CUDA graph replay, + with a little more metadata so we can identify if we have an applicable + CUDA graph in our CUDA graph tree for it. + """ + + model: Callable[..., Any] + static_input_idxs: Sequence[int] + id: FunctionID + constants: Tuple[torch.Tensor, ...] + placeholders: Sequence[PlaceholderInfo] + mutated_input_idxs: Sequence[int] + + +def get_mutating_use_stack_trace_from_node( + placeholder_node: torch.fx.Node, +) -> Optional[str]: + # reinplaced uses might have a single, non-copy_ use + if len(placeholder_node.users) == 1: + return next(iter(placeholder_node.users)).meta.get("stack_trace", None) + + for use in placeholder_node.users: + if use.target == torch.ops.aten.copy_.default: + if stack_trace := use.meta.get("stack_trace", None): + return stack_trace + + return None + + +def get_mutating_use_stack_trace(placeholder_info: PlaceholderInfo) -> Optional[str]: + return placeholder_info.mutating_use_stack_trace + + +def to_placeholder_info(placeholder_node: torch.fx.Node) -> PlaceholderInfo: + name = placeholder_node.name + stack_trace = placeholder_node.meta.get("stack_trace", None) + users = [] + mutating_use_stack_trace = None + # Only recurse to users once, since we only care about user's stack traces + if placeholder_node.op == "placeholder": + users = [to_placeholder_info(i) for i in placeholder_node.users] + mutating_use_stack_trace = get_mutating_use_stack_trace_from_node( + placeholder_node + ) + + return PlaceholderInfo(name, stack_trace, users, mutating_use_stack_trace) + + +def get_placeholder_info(graph: torch.fx.Graph) -> List[PlaceholderInfo]: + return [ + to_placeholder_info(node) for node in graph.nodes if node.op == "placeholder" + ] + + +def format_default_skip_message(reason: str) -> str: + return f"skipping cudagraphs due to {reason}" + + +def get_mutation_stack_trace( + placeholders: Sequence[PlaceholderInfo], mutation_indices: Sequence[int] +) -> str: + stack_trace: Optional[str] = "" + + for idx in mutation_indices: + placeholder = placeholders[idx] + if stack_trace := get_mutating_use_stack_trace(placeholder): + break + + msg = format_default_skip_message( + f"mutated inputs ({len(mutation_indices)} instances)" + ) + if stack_trace: + return f"{msg}. 
Found from : \n {stack_trace}" + + return msg + + +def check_for_mutation( + func: WrappedFunction, + inputs: List[InputType], + is_cuda_graph_recorded_tensor: Callable[[torch.Tensor], bool], +) -> Optional[str]: + # doesnt work for non-trees because the warmup run would apply mutation twice + if torch._inductor.config.triton.cudagraph_trees: + # checking if mutation is only on parameters/static inputs + mutation_indices: Sequence[int] = [ + idx + for idx in func.mutated_input_idxs + if not ( + idx in func.static_input_idxs + or is_cuda_graph_recorded_tensor(inputs[idx]) # type: ignore[arg-type] + ) + ] + else: + mutation_indices = func.mutated_input_idxs + + static_inputs_log.debug( + "check mutation static input indices: %s", func.static_input_idxs + ) + static_inputs_log.debug("check mutation mutation indices: %s", mutation_indices) + + return ( + get_mutation_stack_trace(func.placeholders, mutation_indices) + if mutation_indices + else None + ) + + +def _get_use_stack_trace(node) -> Optional[str]: + for use in node.users: + if stack_trace := use.meta.get("stack_trace", None): + return stack_trace + return None + + +def check_multiple_devices_or_any_cpu_nodes( + device_node_mapping: Dict[torch.device, torch.fx.Node] +) -> Optional[str]: + if cpu_node := device_node_mapping.get(torch.device("cpu")): + msg = f"cpu device ({cpu_node.name})" + if stack_trace := _get_use_stack_trace(cpu_node): + return format_default_skip_message(f"{msg}. Found from : \n {stack_trace}") + + return format_default_skip_message(msg) + + if ( + len(device_node_mapping) == 1 + and next(iter(device_node_mapping.keys())).type == "cuda" + ): + return None + + keys_repr = (repr(key) for key in device_node_mapping.keys()) + return format_default_skip_message(f"multiple devices: {', '.join(keys_repr)}") + + +def check_lowering_disable_cudagraph( + device_node_mapping: Dict[torch.device, torch.fx.Node] +): + return check_multiple_devices_or_any_cpu_nodes(device_node_mapping) + + +def log_cudagraph_skip_and_bump_counter(msg): + perf_hint_log.warning(msg) + counters["inductor"]["cudagraph_skips"] += 1 + + +@dataclasses.dataclass +class BoxedDeviceIndex: + value: Optional[int] + + def set(self, device_idx: Optional[int]): + assert device_idx is None or isinstance(device_idx, int) + self.value = device_idx + + +def check_for_mutation_ignore_cuda_graph_managed_tensor( + gm: torch.fx.GraphModule, compiled_graph, static_input_idxs: Sequence[int] +) -> Optional[str]: + default_msg = format_default_skip_message("mutated inputs") + + # doesnt work for non-trees because the warmup run would apply mutation twice + if torch._inductor.config.triton.cudagraph_trees: + unique_idxs = set(static_input_idxs) + # checking if mutation is only on parameters/static inputs + mutation_indices = [ + idx for idx in compiled_graph.mutated_input_idxs if idx not in unique_idxs + ] + has_mutation = len(mutation_indices) != 0 + if not has_mutation: + return None + placeholders = get_placeholder_info(gm.graph) + return get_mutation_stack_trace(placeholders, mutation_indices) + + else: + has_mutation = len(compiled_graph.mutated_inputs) != 0 + return None if not has_mutation else default_msg + + +def get_placeholder_stack_trace(placeholder: PlaceholderInfo) -> Optional[str]: + """ + Gets the first non-empty stack trace of a placeholder or its users. 
+ """ + if placeholder.stack_trace: + return placeholder.stack_trace + + for user in placeholder.users: + if user.stack_trace: + return user.stack_trace + + return None + + +class CheckInvariantStatus(Enum): + # Check invariant succeeded + SUCCESS = 1 + + # Previously managed data pointers are not stable + CudagraphManagedIdxMismatch = 2 + + # Static tensor input addresses are not stable + StaticInputIdxMismatch = 3 + + # Expected dead indices before graph are live + ExpectedDeadIndicesBeforeGraphMismatch = 4 + + def __str__(self) -> str: + if self.name == "CudagraphManagedIdxMismatch": + return "cudagraph managed tensor data pointer changed" + elif self.name == "StaticInputIdxMismatch": + return "static input data pointer changed" + elif self.name == "ExpectedDeadIndicesBeforeGraphMismatch": + return "expected dead indices before graph are live" + else: + return f"{self.name}: {self.value}" + + +def log_data_ptr_mismatch( + placeholders: Sequence[PlaceholderInfo], + inputs: List[InputType], + recorded_data_ptr: Sequence[Optional[int]], + target_idxs: Sequence[int], + mismatch: CheckInvariantStatus, +) -> str: + """ + Logs the mismatch between input data pointers and recorded data pointers. + This checks only idxs in target_idxs. + """ + assert len(inputs) == len(recorded_data_ptr) and len(inputs) == len( + placeholders + ), "length mismatch between inputs, recorded_data_ptr, and placeholders" + + t_tensors = [inputs[i] for i in target_idxs] + t_data_ptrs = [recorded_data_ptr[i] for i in target_idxs] + error_msg = f"{mismatch}.\n" + for i, (tensor, data_ptr) in enumerate(zip(t_tensors, t_data_ptrs)): + assert isinstance(tensor, torch.Tensor) + index = target_idxs[i] + if tensor.data_ptr() != data_ptr: + placeholder = placeholders[index] + error_msg = ( + f"{error_msg}input name: {placeholder.name}. " + f"data pointer changed from {data_ptr} to {tensor.data_ptr()}. " + f"input stack trace: {get_placeholder_stack_trace(placeholder)}\n" + ) + return error_msg + + +def maybe_warning_due_to_dynamic_shape( + fn_cache: Dict[Tuple[int, ...], Callable[..., Any]], + new_int_key: Any, +) -> bool: + num_cudagraphs = len(fn_cache.keys()) + 1 + + def warn_msg(): + return ( + "CUDAGraph supports dynamic shapes by recording a new graph for each " + "distinct input size. Recording too many CUDAGraphs may lead to " + f"extra overhead. We have observed {num_cudagraphs} distinct sizes. " + "Please consider the following options for better performance: " + "a) padding inputs to a few fixed number of shapes; or b) set " + "torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True. " + "Set torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit=None " + "to silence this warning." 
+ ) + + if ( + torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit + and num_cudagraphs + > torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit + ): + perf_hint_log.warning(warn_msg()) + return True + + return False + + +@dataclasses.dataclass(frozen=True) +class CudagraphCachedInfo: + """ + Info needed to realign inputs + """ + + placeholders: Sequence[PlaceholderInfo] + stack_traces: List[Optional[str]] + cudagraph_fail_reasons: List[str] diff --git a/lib/python3.10/site-packages/torch/_inductor/debug.py b/lib/python3.10/site-packages/torch/_inductor/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..868833a425be4b13d76dbca5318498ee3b7e5a46 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/debug.py @@ -0,0 +1,693 @@ +import collections +import contextlib +import dataclasses +import functools +import itertools +import logging +import os +import os.path +import pickle +import pstats +import shutil +import subprocess +from typing import Any, Callable, Dict, IO, Iterator, List, Optional, Type, Union +from unittest.mock import patch + +import torch +from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled +from torch import fx as fx +from torch._dynamo.repro.after_aot import save_graph_repro +from torch._dynamo.utils import get_debug_dir +from torch.fx.graph_module import GraphModule +from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata +from torch.fx.passes.tools_common import legalize_graph +from torch.utils._pytree import tree_map + +from . import config, ir # noqa: F811, this is needed +from .scheduler import ( + BaseSchedulerNode, + FusedSchedulerNode, + NopKernelSchedulerNode, + OutputNode, + SchedulerNode, +) +from .virtualized import V + + +log = logging.getLogger(__name__) + +SchedulerNodeList = List[Any] +BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"]) +GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"] + + +@functools.lru_cache(None) +def has_dot() -> bool: + try: + subprocess.check_output(["which", "dot"], stderr=subprocess.PIPE) + return True + except subprocess.SubprocessError: + return False + + +def draw_buffers( + nodes: List[BaseSchedulerNode], + print_graph: bool = False, + fname: Optional[str] = None, +) -> None: + """ + Draw a graph in fname.svg. + """ + if not has_dot(): + log.warning("draw_buffers() requires `graphviz` package") + return + + if fname is None: + fname = get_graph_being_compiled() + + graph = create_fx_from_snodes(nodes) + + for node in graph.nodes: + if "fusion_meta" not in node.meta: + continue + group = node.meta["fusion_meta"].group + if isinstance(group, tuple): + if isinstance(group[1], int): + group = (group[1],) + else: + group = group[1] + + # gather meta data + dtype = None + if isinstance(node, ir.ComputedBuffer): + dtype = node.data.dtype + + metadata = TensorMetadata(group, dtype, None, None, None, None, None) # type: ignore[arg-type] + node.meta["tensor_meta"] = metadata + + if print_graph: + print(graph) + + gm = GraphModule({}, graph) + legalize_graph(gm) + gm.graph.lint() + draw_graph( + gm, fname, clear_meta=False, dot_graph_shape=config.trace.dot_graph_shape + ) + + +def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph: + """ + Creates a FX Graph from a list of SchedulerNode objects. 
+ """ + + def get_fake_func(name: str) -> Callable[..., int]: + def func1(*args: Any) -> int: + return 0 + + func1.__name__ = name + return func1 + + FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"]) + + buf_to_fx_node = {} + node_to_fx_node = {} + graph = torch.fx.Graph() + first_node = None + + outputs = [] + group: Any = None + # create call_function node for each Buffer and Kernel + for snode in snodes: + if snode.is_extern(): + node_type = "extern" + group = node_type + elif snode.is_template(): + node_type = "template" + group = node_type + elif isinstance(snode, NopKernelSchedulerNode): + node_type = "nop" + group = node_type + elif isinstance(snode, SchedulerNode): + node_type = "compute" + group = snode.group + elif isinstance(snode, FusedSchedulerNode): + node_type = "fused" + group = snode.group + else: + raise RuntimeError("Unknown node type") + + fused_name = torch._inductor.utils.get_fused_kernel_name( + snode.get_nodes(), "original_aten" + ) + func_name = f"{node_type}: {fused_name}" + node_func = get_fake_func(func_name) + kwargs = {} + if hasattr(snode, "get_device"): + kwargs = {"device": snode.get_device()} + fx_node = graph.call_function(node_func, args=(), kwargs=kwargs) # type: ignore[arg-type] + + def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool: + if isinstance(snode, FusedSchedulerNode): + return any(in_output(x) for x in snode.snodes) + return any( + isinstance(user.node, OutputNode) + for buf in snode.get_outputs() + for user in buf.users + ) + + if in_output(snode): + outputs.append(fx_node) + name = snode.get_name() + fx_node.name = name + + fx_node.meta["fusion_meta"] = FusionMeta(group, snode, node_type) + + node_to_fx_node[name] = fx_node + for buf in snode.get_outputs(): + buf_to_fx_node[buf.get_name()] = fx_node + + if first_node is None: + first_node = fx_node + + # create edges between nodes + for snode in snodes: + name = snode.get_name() + deps = snode.read_writes.reads + + fx_node = node_to_fx_node[name] + new_args = [] + for dep in deps: + if dep.name in buf_to_fx_node: + dep_node = buf_to_fx_node[dep.name] + else: + with graph.inserting_before(first_node): + dep_node = graph.placeholder(dep.name) + buf_to_fx_node[dep.name] = dep_node + if dep_node == fx_node: # to avoid cycles + continue + new_args.append(dep_node) + + fx_node.args = tuple(new_args) + + graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs)) + return graph + + +def update_orig_fx_node_name_to_buf_name( + nodes: Optional[SchedulerNodeList], + node_name_to_buf_name: Dict[str, str], + parent_buf_name: Optional[str] = None, + n_origins: int = 0, +) -> None: + if nodes is None: + return + for node in nodes: + # for FusedSchedulerNode, traverse recursively into get_nodes() + buf_name = node.get_name() + children_nodes = node.get_nodes() + if children_nodes is not None and len(children_nodes) > 1: + update_orig_fx_node_name_to_buf_name( + children_nodes, + node_name_to_buf_name, + buf_name if parent_buf_name is None else parent_buf_name, + ) + continue + else: + assert len(children_nodes) == 1 and children_nodes[0] == node + + ir_node = node.node + if ir_node is None or ir_node.origins is None: + continue + for origin in ir_node.origins: + node_name = origin.name + # when buf1 and buf2 both have origin=node1 + # we draw node1 according to buf1 + if node_name not in node_name_to_buf_name: + node_name_to_buf_name[node_name] = ( + buf_name if parent_buf_name is None else parent_buf_name + ) + + +def get_node_name_to_buf_meta( + 
node_name_to_buf_name: Dict[str, str] +) -> Dict[str, BufMeta]: + buf_name_to_n_node = {} + for node_name, buf_name in node_name_to_buf_name.items(): + if buf_name not in buf_name_to_n_node: + buf_name_to_n_node[buf_name] = {node_name} + else: + buf_name_to_n_node[buf_name].add(node_name) + + node_name_to_buf_meta = {} + for node_name, buf_name in node_name_to_buf_name.items(): + n_node = len(buf_name_to_n_node[buf_name]) + node_name_to_buf_meta[node_name] = BufMeta(buf_name, n_node) + return node_name_to_buf_meta + + +def annotate_orig_fx_with_snodes( + gm: torch.fx.GraphModule, + snodes: SchedulerNodeList, +) -> None: + """ + Creates a FX Graph from a list of SchedulerNode objects. + """ + node_name_to_buf_name: Dict[str, str] = {} + update_orig_fx_node_name_to_buf_name(snodes, node_name_to_buf_name) + if node_name_to_buf_name is None: + return + node_name_to_buf_meta = get_node_name_to_buf_meta(node_name_to_buf_name) + for node in gm.graph.nodes: + if node.name in node_name_to_buf_meta: + node.meta["buf_meta"] = node_name_to_buf_meta.get(node.name) + + +@contextlib.contextmanager +def enable_aot_logging() -> Iterator[None]: + compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" + + import torch._functorch.aot_autograd + + log = logging.getLogger(torch._functorch.aot_autograd.__name__) + + stack = contextlib.ExitStack() + if not compile_debug: + try: + yield + finally: + stack.close() + return + + # Enable all graphs to be logged to a file by setting the flags to True + # and the log level of the file logger to DEBUG + stack.enter_context(patch("functorch.compile.config.debug_partitioner", True)) + + path = os.path.join(get_debug_dir(), "torchinductor") + os.makedirs(path, exist_ok=True) + + fh = logging.FileHandler( + os.path.join( + path, + f"aot_{get_aot_graph_name()}_debug.log", + ) + ) + fh.setLevel(logging.DEBUG) + fh.setFormatter( + logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s") + ) + log.addHandler(fh) + try: + yield + finally: + log.removeHandler(fh) + stack.close() + + +class DebugContext: + _counter = itertools.count() + + @staticmethod + def create_debug_dir(folder_name: str) -> Optional[str]: + debug_dir = config.trace.debug_dir or get_debug_dir() + for n in DebugContext._counter: + dirname = os.path.join( + debug_dir, + "torchinductor", + f"{folder_name}.{n}", + ) + if not os.path.exists(dirname): + os.makedirs(dirname) + return dirname + return None + + def __init__(self) -> None: + self._prof = None + self._path = None + self._stack = contextlib.ExitStack() + + def copy(self, new_path: str) -> None: + if not self._path: + return + assert new_path.endswith(".debug"), new_path + from filelock import FileLock + + try: + with FileLock(f"{new_path}.lock"): + if os.path.exists(new_path): + shutil.rmtree(new_path) + shutil.copytree(self._path, new_path) + except OSError: + log.warning( + "Failed to copy debug files from %s to %s", self._path, new_path + ) + + def fopen( + self, + filename: str, + write_mode: str = "w", + *args: Any, + **kwargs: Any, + ) -> IO[Any]: + assert self._path + return open(os.path.join(self._path, filename), write_mode, *args, **kwargs) + + @contextlib.contextmanager + def fopen_context( + self, + filename: str, + write_mode: str = "w", + *args: Any, + **kwargs: Any, + ) -> Iterator[IO[Any]]: + assert self._path + with open(os.path.join(self._path, filename), write_mode, *args, **kwargs) as f: + yield f + + def filename(self, suffix: str) -> str: + assert self._path + return os.path.join(self._path, suffix) + + def 
upload_tar(self) -> None: + if config.trace.upload_tar is not None: + import tarfile + + assert self._path + tar_file = os.path.join( + self._path, f"{os.path.basename(self._path)}.tar.gz" + ) + with tarfile.open(tar_file, "w:gz") as tar: + tar.add(self._path, arcname=os.path.basename(self._path)) + config.trace.upload_tar(tar_file) + + def __enter__(self) -> None: + if config.debug: + log = logging.getLogger("torch._dynamo") + prev_level = log.level + log.setLevel(logging.DEBUG) + + def reset_log_level(level: Any) -> None: + log.setLevel(level) + + self._stack.callback(reset_log_level, prev_level) + + self._stack.enter_context(V.set_debug_handler(self)) + + if not config.trace.enabled: + return + + self._path = self.create_debug_dir(get_aot_graph_name()) # type: ignore[assignment] + + if config.trace.debug_log: + self._setup_log_capture("debug.log", logging.DEBUG) + if config.trace.info_log: + self._setup_log_capture("info.log", logging.INFO) + + def _setup_log_capture( + self, + filename: str, + level: int, + ) -> None: + log = logging.getLogger("torch._inductor") + fd = self._stack.enter_context(self.fopen(filename)) + ch = logging.StreamHandler(fd) + ch.setLevel(level) + ch.setFormatter( + logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s") + ) + log.addHandler(ch) + log.setLevel(min(log.level, level)) + self._stack.callback(log.removeHandler, ch) + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[Any], + ) -> None: + if self._prof: + self._prof.disable() + self._save_profile_data() + + if self._path: + self.upload_tar() + log.warning("%s debug trace: %s", get_graph_being_compiled(), self._path) + self._stack.close() + + def _save_profile_data(self) -> None: + assert self._prof + self._prof.dump_stats(self.filename("compile.prof")) + with self.fopen("compile.stats") as fd: + stats = pstats.Stats(self._prof, stream=fd) + stats.strip_dirs() + stats.sort_stats("cumtime") + stats.print_stats(100) + stats.sort_stats("tottime") + stats.print_stats(100) + + def __getattr__(self, name: str) -> Optional[Callable[..., None]]: + if config.trace.enabled and getattr(config.trace, name): + try: + return getattr(DebugFormatter(self), name) + except Exception: + log.warning("Ignoring exception in debug code", exc_info=True) + return None + else: + + def ignored(*args: Any, **kwargs: Any) -> None: + pass + + return ignored + + +class DebugFormatter: + def __init__(self, handler: DebugContext) -> None: + self.fopen = handler.fopen + self.fopen_context = handler.fopen_context + self.filename = handler.filename + self.handler = handler + + def fx_graph( + self, + gm: torch.fx.GraphModule, + inputs: List[torch.Tensor], + ) -> None: + with self.fopen("fx_graph_runnable.py") as fd: + save_graph_repro(fd, gm, inputs, "inductor") + + with self.fopen("fx_graph_readable.py") as fd: + fd.write(gm.print_readable(print_output=False)) + + def fx_graph_transformed( + self, + gm: torch.fx.GraphModule, + inputs: List[torch.Tensor], + ) -> None: + with self.fopen("fx_graph_transformed.py") as fd: + fd.write(gm.print_readable(print_output=False)) + + def ir_pre_fusion(self, nodes: SchedulerNodeList) -> None: + self._write_ir("ir_pre_fusion.txt", nodes) + + def ir_post_fusion(self, nodes: SchedulerNodeList) -> None: + self._write_ir("ir_post_fusion.txt", nodes) + + def _write_ir( + self, + filename: str, + nodes: SchedulerNodeList, + ) -> None: + with self.fopen(filename) as fd: + log.info("Writing debug ir to %s", fd.name) + for 
node in nodes: + fd.write(node.debug_str()) + fd.write("\n\n\n") + + def graph_diagram(self, nodes: SchedulerNodeList) -> None: + draw_buffers(nodes, fname=self.filename("graph_diagram.svg")) + + def draw_orig_fx_graph( + self, + gm: torch.fx.GraphModule, + nodes: SchedulerNodeList, + ) -> None: + annotate_orig_fx_with_snodes(gm, nodes) + draw_graph( + gm, + fname=self.filename("orig_fx_graph_diagram.svg"), + clear_meta=False, + prog=GRAPHVIZ_COMMAND_SCALABLE, + parse_stack_trace=True, + dot_graph_shape=config.trace.dot_graph_shape, + ) + + def output_code(self, filename: str) -> None: + shutil.copy(filename, self.filename("output_code.py")) + + def log_autotuning_results( + self, + name: str, + input_nodes: List[ir.IRNode], + timings: Dict["ChoiceCaller", float], # type: ignore[name-defined] # noqa: F821 + elapse: float, + precompile_elapse: float, + ) -> None: + import json + + from .ir import FixedLayout + + def build_node_info(node: ir.IRNode) -> Dict[str, str]: + if hasattr(node, "name"): + node_name = node.name + else: + node_name = "" + node_info = { + "name": node_name, + "type": type(node).__name__, + } + try: + layout = node.get_layout() + if isinstance(layout, FixedLayout): + offset = 0 + try: + offset = int(layout.offset) + except Exception: + try: + offset = V.graph.sizevars.size_hint( + layout.offset, fallback=0 + ) + except Exception: + pass + static_layout = FixedLayout( + layout.device, + dtype=layout.dtype, + size=list(V.graph.sizevars.size_hints(layout.size)), + stride=list(V.graph.sizevars.size_hints(layout.stride)), + offset=offset, + ) + node_info["layout"] = str(static_layout) + else: + node_info["layout"] = str(node.get_layout()) + except Exception as e: + pass + try: + node_info["dtype"] = str(node.get_dtype()) + except Exception as e: + pass + try: + node_info["device"] = str(node.get_device()) + except Exception as e: + pass + try: + node_info["stride"] = str( + V.graph.sizevars.size_hints(node.get_stride()) + ) + except Exception as e: + pass + try: + node_info["size"] = str(V.graph.sizevars.size_hints(node.get_size())) + except Exception as e: + pass + try: + node_info["numel"] = str(V.graph.sizevars.size_hint(node.get_numel())) + except Exception as e: + pass + if hasattr(node, "data") and isinstance(node.data, ir.IRNode): + node_info["data"] = build_node_info(node.data) + return node_info + + general_properties = { + "op_name": name, + "cuda_device_name": torch.cuda.get_device_name(), + "cuda_device_count": torch.cuda.device_count(), + "input_nodes": [build_node_info(node) for node in input_nodes], + "autotuning_time": elapse, + "precompile_time": precompile_elapse, + } + with self.fopen_context( + "autotuning_result_json_list.txt", "at", encoding="utf-8" + ) as fd: + for caller, time in timings.items(): + info_dict = dict(caller.info_dict()) + info_dict.update(general_properties) + info_dict["benchmark_result"] = time + json.dump(info_dict, fd) + fd.write("\n") + + +@dataclasses.dataclass +class TensorMetadataHolder: + tensor_metadata: TensorMetadata + device: torch.device + + +save_args_cnt = itertools.count() + + +def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None: + """ + This function is used to save arguments for a compile_fx_inner function call + to the file system. Later on one can replay the compile_fx_inner call + with the saved arguments using load_args_and_run_compile_fx_inner. 
+ """ + + folder = "/tmp/inductor_saved_args" + if not os.path.exists(folder): + os.mkdir(folder) + + def handle_tensor(x: Any) -> Any: + """ + Pickle FakeTensor will result in error: + AttributeError: Can't pickle local object 'WeakValueDictionary.__init__..remove' + + Convert all Tensor to metadata. This may also makes pickle faster. + """ + if isinstance(x, torch.Tensor): + return TensorMetadataHolder(_extract_tensor_metadata(x), x.device) + else: + return x + + args_to_save, kwargs_to_save = tree_map(handle_tensor, (args, kwargs)) + + fn_name = "compile_fx_inner" + path = f"{folder}/{fn_name}_{next(save_args_cnt)}.pkl" + with open(path, "wb") as f: + pickle.dump((args_to_save, kwargs_to_save), f) + + if log.isEnabledFor(logging.DEBUG): + message = f""" +Arguments for a compile_fx_inner call is saved to {path}. To replay the call, +run the following: + +from torch._inductor.debug import load_args_and_run_compile_fx_inner +load_args_and_run_compile_fx_inner({path!r}) + """ + # call print rather than log.debug. log.debug will print message + # prefix for each line which makes the code snippet harder to be + # copied. + # Not a big deal since the code is already been guarded by checking + # the log level. + print(message) + + +def load_args_and_run_compile_fx_inner(path: str) -> Any: + from torch._inductor.compile_fx import compile_fx_inner + + with open(path, "rb") as f: + args, kwargs = pickle.load(f) + + def handle_tensor(x: Any) -> Any: + if isinstance(x, TensorMetadataHolder): + return torch._dynamo.testing.rand_strided( + x.tensor_metadata.shape, + x.tensor_metadata.stride, + x.tensor_metadata.dtype, + x.device, + ) + else: + return x + + fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True) + with fake_mode, config.patch("save_args", False): + args, kwargs = tree_map(handle_tensor, (args, kwargs)) + return compile_fx_inner(*args, **kwargs) diff --git a/lib/python3.10/site-packages/torch/_inductor/decomposition.py b/lib/python3.10/site-packages/torch/_inductor/decomposition.py new file mode 100644 index 0000000000000000000000000000000000000000..0e067395f807121789f2a51fe4919f122f03f170 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/decomposition.py @@ -0,0 +1,980 @@ +# mypy: allow-untyped-decorators +import functools +import logging +import math +import sys +import typing +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch._decomp as decomp +import torch._prims_common as utils +import torch.ao.quantization.fx._decomposed +from torch._decomp import ( + core_aten_decompositions, + get_decompositions, + remove_decompositions, +) +from torch._decomp.decompositions import ( + _grid_sampler_2d as decomp_grid_sampler_2d, + pw_cast_for_opmath, +) +from torch._decomp.decompositions_for_rng import extra_random_decomps +from torch._dynamo.utils import counters +from torch._higher_order_ops.out_dtype import out_dtype +from torch._inductor.utils import pad_listlike +from torch._prims_common import ( + elementwise_dtypes, + ELEMENTWISE_TYPE_PROMOTION_KIND, + type_to_dtype, +) +from torch.fx.experimental.symbolic_shapes import definitely_true, guard_size_oblivious + +from . 
import config, inductor_prims +from .utils import ( + is_gpu, + needs_fallback_due_to_atomic_add_limitations, + use_scatter_fallback, +) + + +log = logging.getLogger(__name__) +aten = torch.ops.aten +prims = torch.ops.prims +quantized = torch.ops.quantized +_quantized = torch.ops._quantized +quantized_decomposed = torch.ops.quantized_decomposed + +inductor_decompositions = get_decompositions( + [ + aten._adaptive_avg_pool2d_backward, + aten.addmv, + aten.arange, + aten.bitwise_and_, + aten.bitwise_or_, + aten.clamp_min_, + aten.dist, + aten.empty_like, + aten.flip, + aten.gelu, + aten.hardtanh, + aten.index_select, + aten.lcm, + aten.leaky_relu, + aten.linalg_vector_norm, + aten._log_softmax, + aten.max_pool2d_with_indices_backward, + aten._native_batch_norm_legit, + aten._native_batch_norm_legit_functional, + aten._native_batch_norm_legit_no_training, + aten._batch_norm_with_update, + aten._batch_norm_with_update_functional, + aten._batch_norm_no_update, + aten.batch_norm_backward, + aten.native_batch_norm, + aten.native_group_norm, + aten.native_layer_norm, + aten.nll_loss2d_backward, + aten._softmax, + aten.sin_, + aten.sqrt_, + out_dtype, + aten._to_copy, + aten.tril_indices, + aten.triu_indices, + aten.upsample_bilinear2d.vec, + quantized.linear_dynamic_fp16_unpacked_weight, + _quantized.wrapped_quantized_linear, + ] +) +decompositions = {**core_aten_decompositions(), **inductor_decompositions} + +# Remove unwanted decompositions included via the core ATen decompositions from +# the Inductor decomp table. +decomps_to_exclude = [ + aten._unsafe_index, + aten._unsafe_masked_index, + aten._unsafe_masked_index_put_accumulate, + aten._scaled_dot_product_flash_attention_for_cpu.default, # See comments in torch/_decomp/decompositions.py + aten._softmax_backward_data, + aten.clamp_max, + aten.clamp_min, + aten.glu, # inductor lowers this directly + aten.select_scatter, # need to be in the ATen graph in order for it to work with the re-inplacing pass + aten.slice_scatter, # need to be in the ATen graph in order for it to work with the re-inplacing pass + aten.split.Tensor, # inductor lowers this directly + aten.squeeze, # inductor lowers this directly + aten.sum, # inductor lowers this directly + aten.unbind, # inductor lowers this directly +] + +remove_decompositions(decompositions, decomps_to_exclude) + + +def register_decomposition( + ops: List[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]] +) -> Callable[..., Any]: + for op in [ops] if callable(ops) else ops: # type: ignore[attr-defined] + if op in decompositions: + log.warning("duplicate decomp: %s", ops) + return decomp.register_decomposition(ops, decompositions) + + +# TODO: for now, inductor doesn't handle asserts +# because the condition is symbol -> tensor in the graph. +@register_decomposition([aten._assert_async.msg]) +def assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None: + return + + +# Following `assert_async_msg_decomp` and implement as non-op. 
+@register_decomposition([aten._functional_assert_async.msg]) +def functional_assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None: + return + + +@register_decomposition([aten.sym_constrain_range_for_size.default]) +def sym_constrain_range_for_size( + symbol: torch.SymInt, + *, + min: Optional[torch.types.Number] = None, + max: Optional[torch.types.Number] = None, +) -> None: + return + + +@register_decomposition([aten.clamp]) +@pw_cast_for_opmath +def clamp( + x: torch.Tensor, + min: Optional[torch.types.Number] = None, + max: Optional[torch.types.Number] = None, +) -> torch.Tensor: + if min is not None: + x = x.clamp_min(min) + if max is not None: + x = x.clamp_max(max) + return x + + +@register_decomposition([aten.full]) +def full( + size: List[Union[int, torch.SymInt]], + fill_value: torch.types.Number, + **kwargs: Any, +) -> torch.Tensor: + dtype = kwargs.get("dtype") + if dtype is None: + kwargs["dtype"] = type_to_dtype(type(fill_value)) + return torch.full(size, fill_value, **kwargs) + return NotImplemented + + +# Not really sure how to put this into the main library. PrimTorch wants +# empty_permuted to go to the prim, and typically users don't really want +# to decompose to empty_strided (but inductor is OK with it, because we are +# cool with strides and everything goes to empty_strided) +@register_decomposition([aten.empty_permuted.default]) +def empty_permuted( + size: List[Union[int, torch.SymInt]], + physical_layout: List[int], + **kwargs: Any, +) -> torch.Tensor: + perm = [0] * len(size) + for p, l in enumerate(physical_layout): + perm[l] = p + return torch.empty([size[l] for l in physical_layout], **kwargs).permute(perm) + + +@register_decomposition([aten.convolution_backward]) +def convolution_backward( + grad_output: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + bias_sizes: List[int], + stride: Union[int, List[int]], + padding: Union[int, List[int]], + dilation: Union[int, List[int]], + transposed: bool, + output_padding: List[int], + groups: int, + output_mask: List[bool], +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if not output_mask[2] or not is_gpu(grad_output.device.type): + return NotImplemented + grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim()))) + grad_inp, grad_weight, _ = aten.convolution_backward( + grad_output, + input, + weight, + bias_sizes, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + [output_mask[0], output_mask[1], False], + ) + return (grad_inp, grad_weight, grad_bias) + + +@register_decomposition([aten.round.decimals]) +def round_dec(x: torch.Tensor, decimals: int = 0) -> torch.Tensor: + ten_pow_decimals = 10.0**decimals + return aten.round(x * ten_pow_decimals) * (1.0 / ten_pow_decimals) + + +@register_decomposition([aten.bmm]) +@pw_cast_for_opmath +def bmm( + self: torch.Tensor, + batch2: torch.Tensor, +) -> torch.Tensor: + if config.coordinate_descent_tuning: + if guard_size_oblivious(self.shape[1] == 1) or guard_size_oblivious( + batch2.shape[2] == 1 + ): + out = (self.unsqueeze(-1) * batch2.unsqueeze(1)).sum(dim=2) + return out + if self.device.type == "cpu": + if guard_size_oblivious(self.size(1) == 1) and guard_size_oblivious( + batch2.size(-1) == 1 + ): + counters["inductor"]["decompose_bmm"] += 1 + return torch.sum( + self.squeeze(1) * batch2.squeeze(-1), dim=1, keepdim=True + ).unsqueeze(1) + return NotImplemented + + +@register_decomposition([aten.addmm]) +@pw_cast_for_opmath +def addmm( + self: torch.Tensor, + mat1: torch.Tensor, + mat2: 
torch.Tensor, + beta: torch.types.Number = 1, + alpha: torch.types.Number = 1, +) -> torch.Tensor: + if self.device.type == "cpu": + if guard_size_oblivious(mat1.size(0) == 1) and guard_size_oblivious( + mat2.size(-1) == 1 + ): + counters["inductor"]["decompose_addmm"] += 1 + out = torch.sum( + mat1.squeeze(0) * mat2.squeeze(-1), dim=0, keepdim=True + ).unsqueeze(0) + return alpha * out + beta * self + if ( + guard_size_oblivious(mat1.size(0) == 1) + and definitely_true(mat2.size(0) <= 16) + and definitely_true(mat2.size(1) <= 16) + ): + counters["inductor"]["decompose_addmm"] += 1 + out = (mat1.T * mat2).sum(dim=0, keepdim=True) + return alpha * out + beta * self + return NotImplemented + + +@register_decomposition([aten.mm]) +@pw_cast_for_opmath +def mm( + self: torch.Tensor, + input2: torch.Tensor, +) -> torch.Tensor: + # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning. + # todo: Look into why and fix it (hopefully) + if config.coordinate_descent_tuning: + if guard_size_oblivious(self.shape[0] == 1) or guard_size_oblivious( + input2.shape[1] == 1 + ): + return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1) + if self.device.type == "cpu": + if ( + guard_size_oblivious(self.size(-1) == 1) + and guard_size_oblivious(self.size(0) > 0) + and guard_size_oblivious(input2.size(0) == 1) + and (self.dtype == input2.dtype) + and definitely_true((torch.numel(self) + torch.numel(input2)) <= 32) + ): + counters["inductor"]["decompose_mm"] += 1 + return torch.cat([self[i, :] * input2 for i in range(self.size(0))]) + if guard_size_oblivious(self.size(0) == 1) and guard_size_oblivious( + input2.size(-1) == 1 + ): + counters["inductor"]["decompose_mm"] += 1 + return torch.sum( + self.squeeze(0) * input2.squeeze(-1), dim=0, keepdim=True + ).unsqueeze(0) + return NotImplemented + + +# This pass does two things: +# - Eliminate cat when there is only one tensor input +# - Normalize cat calls, so that legacy empty 1-D tensors are removed (NB: we +# don't remove ALL empty tensors, only the naughty ones) +@register_decomposition([aten.cat.default]) +def cat( + tensors: List[torch.Tensor], + dim: int = 0, +) -> torch.Tensor: + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + def non_empty_tensor(x: torch.Tensor) -> bool: + # For better or worse, this is a valid cat: + # + # torch.cat([torch.randn(2, 2, 4), torch.randn(0), torch.randn(3, 2, 4)]) + # + # We'd like to eliminate naughtiness like this for downstream passes + # like split_cat. The easiest way is to just drop such inputs + # (guarding that they are non-zero). + # + # Is it permissible for this filtering to be size-oblivious? A case + # where this could matter is cat([(2, 2), (u0,)], dim=0); if u0 + # happened to be zero, we would have liked to have filtered it out. + # But actually, the ONLY way this could have passed is if u0 == 0, + # so by the time we get here we have already installed a deferred + # runtime assert forcing u0 to be zero. So if this hasn't happened, + # we know that the unbacked SymInt has appropriate size and there are + # no problems. 
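# As a concrete illustration of the filtering described above: for
#   torch.cat([torch.randn(2, 4), torch.randn(0), torch.randn(3, 4)])
# the legacy empty 1-D tensor is dropped and cat is re-dispatched on the two
# remaining inputs; if only a single input survives the filtering, the cat is
# rewritten as a clone of that input (see the branches below).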
+ if len(x.shape) == 1 and guard_size_oblivious(x.shape[0] == 0): + return False + + if dim < len(x.shape) and guard_size_oblivious(x.shape[dim] == 0): + return False + + return True + + filtered_tensors = list(filter(non_empty_tensor, tensors)) + + if len(filtered_tensors) == 1: + return filtered_tensors[0].clone() + elif 1 < len(filtered_tensors) < len(tensors): + # on the first call, when we remove empty tensors, we redispatch recursively + return aten.cat.default(filtered_tensors, dim) + + # optimization, avoid concat for single, repeated input + if len(filtered_tensors) > 1 and all( + t is filtered_tensors[0] for t in filtered_tensors + ): + inp = filtered_tensors[0] + shape = list(inp.shape) + dim = dim + len(inp.shape) if dim < 0 else dim + shape.insert(dim, len(filtered_tensors)) + return inp.unsqueeze(dim).expand(*shape).flatten(dim, dim + 1).clone() + + # when no 'filtering' has occurred, we raise to prevent infinite recursion (no more decomposition needed) + return NotImplemented + + +@register_decomposition([aten.angle]) +def angle(x: torch.Tensor) -> torch.Tensor: + if x.is_complex(): + return torch.where( + torch.isnan(x.real), float("nan"), torch.atan2(x.imag, x.real) + ) + + # when x is real number + # if x >= 0, return 0 + # if x < 0, return pi + # if x is nan, return nan + _, dtype = elementwise_dtypes( + x, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ) + pi = torch.scalar_tensor(math.pi, dtype=dtype, device=x.device) + ret = torch.where(x < 0, pi, 0.0) + return torch.where(torch.isnan(x), float("nan"), ret) + + +@register_decomposition([aten.add]) +def add( + x: torch.Tensor, + y: torch.Tensor, + *, + alpha: Optional[torch.types.Number] = None, +) -> torch.Tensor: + # Require both x and y to be complex tensors. + x_is_complex_tensor = torch.is_tensor(x) and x.is_complex() + y_is_complex_tensor = torch.is_tensor(y) and y.is_complex() + if not x_is_complex_tensor or not y_is_complex_tensor: + return NotImplemented + z = y + if alpha is not None: + z = alpha * y + complex_type = torch.promote_types(x.dtype, y.dtype) + + # For complex typed `x`, `x.view(x.real.dtype)` doubles the last dimension and can cause problem + # when broadcasting the add. + def reshape_tensor_complex(tensor: torch.Tensor) -> torch.Tensor: + """Reshape tensor from [*initial_dims, last_dim] to *initial_dims, last_dim/2, 2]""" + # Get the current shape of the tensor + *initial_dims, last_dim = tensor.shape + + # Check if the last dimension is even. We should never reach here since `x.view(x.real.dtype)` + # doubles the last dimension for complex numbers. 
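# Illustration of the layout handled here: a complex64 tensor of shape [B, N]
# viewed as float32 has shape [B, 2*N]; reshaping it to [B, N, 2] keeps each
# (real, imag) pair together in the last dimension, so the element-wise add
# below broadcasts the same way the original complex add would.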
+ if last_dim % 2 != 0: + raise AssertionError( + "The size of the last dimension must be even to reshape it to [..., last_dim/2, 2]" + ) + + # Reshape the tensor + new_shape = (*initial_dims, last_dim // 2, 2) + reshaped_tensor = tensor.view(new_shape) + return reshaped_tensor + + x_reshaped = reshape_tensor_complex(x.view(x.real.dtype)) + z_reshaped = reshape_tensor_complex(z.view(y.real.dtype)) + result = torch.flatten(x_reshaped + z_reshaped, start_dim=-2).view(complex_type) + return result + + +@register_decomposition([aten.conj_physical]) +def conj_physical(self: torch.Tensor) -> torch.Tensor: + assert not self.is_complex(), "TODO: implement this" + return self + + +@register_decomposition([aten.lift, aten.detach_]) +def lift(self: torch.Tensor) -> torch.Tensor: + return self + + +@register_decomposition([aten.bernoulli.default]) +def bernoulli( + self: torch.Tensor, + *, + generator: Optional[torch.Generator] = None, +) -> torch.Tensor: + assert generator is None + return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype) + + +@register_decomposition([aten.fmin, prims.fmin]) +def fmin(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor: + return torch.where(torch.isnan(other) | (other > self), self, other) + + +@register_decomposition([aten.fmax, prims.fmax]) +def fmax(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor: + return torch.where(torch.isnan(other) | (other < self), self, other) + + +@register_decomposition(aten.amax) +def amax( + self: torch.Tensor, + dim: Optional[int] = None, + keepdim: bool = False, +) -> torch.Tensor: + if self.dtype == torch.bool: + return torch.any(self, dim=dim, keepdim=keepdim) + return NotImplemented + + +@register_decomposition(aten.amin) +def amin( + self: torch.Tensor, + dim: Optional[int] = None, + keepdim: bool = False, +) -> torch.Tensor: + if self.dtype == torch.bool: + return torch.all(self, dim=dim, keepdim=keepdim) + return NotImplemented + + +@register_decomposition([aten.narrow_copy]) +def narrow_copy( + self: torch.Tensor, + dim: int, + start: int, + length: int, +) -> torch.Tensor: + return torch.narrow(self, dim, start, length).clone() + + +@register_decomposition([aten.view_copy.default]) +def view_copy_default( + self: torch.Tensor, + size: List[Union[int, torch.SymInt]], +) -> torch.Tensor: + return aten.view(self, size).clone() + + +@register_decomposition([aten.view_copy.dtype]) +def view_copy_dtype( + self: torch.Tensor, + dtype: torch.dtype, +) -> torch.Tensor: + return self.to(dtype).clone() + + +def get_like_layout( + tensor: torch.Tensor, + memory_format: Optional[torch.memory_format] = None, +) -> torch.memory_format: + # TODO: _to_copy tensor to stride permutation + if memory_format is torch.preserve_format or memory_format is None: + return utils.suggest_memory_format(tensor) + else: + return memory_format + + +@register_decomposition(aten.rand_like) +def rand_like( + self: torch.Tensor, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + memory_format: Optional[torch.memory_format] = None, + **kwargs: Any, +) -> torch.Tensor: + return torch.rand( + [*self.size()], + dtype=dtype or self.dtype, + device=device or self.device, + **kwargs, + ).to(memory_format=get_like_layout(self, memory_format)) + + +@register_decomposition(aten.randn_like) +def randn_like( + self: torch.Tensor, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + memory_format: Optional[torch.memory_format] = None, + **kwargs: Any, +) -> torch.Tensor: + 
return torch.randn( + [*self.size()], + dtype=dtype or self.dtype, + device=device or self.device, + **kwargs, + ).to(memory_format=get_like_layout(self, memory_format)) + + +@register_decomposition(aten.full_like) +def full_like( + self: torch.Tensor, + fill_value: Union[int, float], + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[torch.device] = None, + pin_memory: bool = False, + requires_grad: bool = False, + memory_format: torch.memory_format = torch.preserve_format, +) -> torch.Tensor: + return torch.full( + [*self.size()], + fill_value, + dtype=dtype or self.dtype, + layout=layout or self.layout, + device=device or self.device, + requires_grad=requires_grad, + ).to(memory_format=get_like_layout(self, memory_format)) + + +@register_decomposition(aten.randint_like.default) +def randint_like( + self: torch.Tensor, + high: int, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + memory_format: Optional[torch.memory_format] = None, + **kwargs: Any, +) -> torch.Tensor: + return aten.randint.low( + 0, + high, + [*self.size()], + dtype=dtype or self.dtype, + device=device or self.device, + **kwargs, + ).to(memory_format=get_like_layout(self, memory_format)) + + +@register_decomposition(aten.randint_like.low_dtype) +def randint_like_low( + self: torch.Tensor, + low: int, + high: int, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + memory_format: Optional[torch.memory_format] = None, + **kwargs: Any, +) -> torch.Tensor: + return aten.randint.low( + low, + high, + [*self.size()], + dtype=dtype or self.dtype, + device=device or self.device, + **kwargs, + ).to(memory_format=get_like_layout(self, memory_format)) + + +@register_decomposition(aten.randint.default) +def randint( + high: int, + size: List[Union[int, torch.SymInt]], + **kwargs: Any, +) -> torch.Tensor: + return aten.randint.low(0, high, size, **kwargs) + + +@register_decomposition(quantized.linear_dynamic_fp16_unpacked_weight.default) +def linear_dynamic_fp16_unpacked_weight( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: + packed_weight = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(weight) + return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight( + input, packed_weight, bias, weight.size()[0] + ) + + +@register_decomposition(_quantized.wrapped_quantized_linear.default) +def wrapped_quantized_linear( + input: torch.Tensor, + input_scale: torch.Tensor, + input_zero_point: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + bias: torch.Tensor, + out_scale: torch.Tensor, + out_zero_point: torch.Tensor, + out_channel: int, +) -> torch.Tensor: + packed_weight = torch.ops._quantized._wrapped_linear_prepack( + weight, weight_scale, weight_zero_point, bias + ) + return torch.ops._quantized._wrapped_quantized_linear_prepacked( + input, + input_scale, + input_zero_point, + packed_weight, + out_scale, + out_zero_point, + out_channel, + ) + + +@register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack) +def q_embedding_bag_byte_unpack_decomp(packed: torch.Tensor) -> torch.Tensor: + def bitcast_u8_to_f32(u8: torch.Tensor) -> torch.Tensor: + x, y, z, w = (u8[..., n].to(torch.int32) for n in (0, 1, 2, 3)) + if sys.byteorder == "little": + return (x + (y << 8) + (z << 16) + (w << 24)).view(torch.float32)[..., None] + else: + return ((x << 24) + (y << 16) + (z << 8) + w).view(torch.float32)[..., 
None] + + scales = bitcast_u8_to_f32(packed[..., -8:-4]) + offsets = bitcast_u8_to_f32(packed[..., -4:]) + return packed[..., :-8].to(torch.float32) * scales + offsets + + +@register_decomposition([aten.grid_sampler_2d]) +@pw_cast_for_opmath +def grid_sampler_2d( + a: torch.Tensor, + grid: torch.Tensor, + interpolation_mode: int = 0, + padding_mode: int = 0, + align_corners: bool = False, +) -> torch.Tensor: + # We do not expand the grid (_expand_grid=False) on cpu for performance reasons + # Experimenting locally it was found that compiled CUDA code is accelerated by ~5x + # and CPU code by ~2x on bicubic mode, if we expand the grid from (N, H, W, 2) into (N, C, H, W, 2) + # However, this leads to a slowdown around ~0.8x on CPU bilinear mode, channels first. + # Thus we apply this hack to not expand the grid for this case. + _expand_grid = not ( + a.device == torch.device("cpu") + and interpolation_mode == 0 + and a.is_contiguous(memory_format=torch.contiguous_format) + ) + + output = decomp_grid_sampler_2d( + a, + grid=grid, + interpolation_mode=interpolation_mode, + padding_mode=padding_mode, + align_corners=align_corners, + _expand_grid=_expand_grid, + ) + return output + + +@register_decomposition(aten._foreach_addcmul.Scalar) +def _foreach_addcmul_scalar( + self: List[torch.Tensor], + left_tensors: List[torch.Tensor], + right_tensors: List[torch.Tensor], + scalar: float = 1, +) -> List[torch.Tensor]: + return aten._foreach_add.List( + self, aten._foreach_mul.List(left_tensors, right_tensors), alpha=scalar + ) + + +@register_decomposition(aten._foreach_addcdiv.Scalar) +def _foreach_addcdiv_scalar( + self: List[torch.Tensor], + left_tensors: List[torch.Tensor], + right_tensors: List[torch.Tensor], + scalar: float = 1, +) -> List[torch.Tensor]: + return aten._foreach_add.List( + self, aten._foreach_div.List(left_tensors, right_tensors), alpha=scalar + ) + + +@register_decomposition(aten._foreach_lerp.Scalar) +def _foreach_lerp_scalar( + start_tensors: List[torch.Tensor], + end_tensors: List[torch.Tensor], + weight: torch.types.Number, +) -> List[torch.Tensor]: + return aten._foreach_add.List( + start_tensors, + aten._foreach_mul.Scalar( + aten._foreach_sub.List(end_tensors, start_tensors), weight + ), + ) + + +@aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd) +@register_decomposition(aten.miopen_batch_norm) +def miopen_batch_norm( + input: torch.Tensor, + weight: torch.Tensor, + bias: typing.Optional[torch.Tensor], + running_mean: typing.Optional[torch.Tensor], + running_var: typing.Optional[torch.Tensor], + training: bool, + exponential_average_factor: float, + epsilon: float, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + a, b, c = aten.native_batch_norm( + input, + weight, + bias, + running_mean, + running_var, + training, + exponential_average_factor, + epsilon, + ) + + if training: + return (a, b, c) + return ( + a, + weight.new_zeros((0,)), + weight.new_zeros((0,)), + ) + + +@functools.lru_cache(None) +def fast_random_decomps() -> Dict[Any, Callable[..., Any]]: + return {**decompositions, **extra_random_decomps} + + +# TODO(aakhundov): replace this (and the above) Any by more +# specific type and fix all the cascading mypy errors +def select_decomp_table() -> Dict[Any, Callable[..., Any]]: + """decomps can change based on config""" + if config.fallback_random: + return decompositions + return fast_random_decomps() + + +@register_decomposition(aten.masked_scatter) +def masked_scatter( + self: torch.Tensor, + mask: torch.Tensor, + source: torch.Tensor, 
+) -> torch.Tensor: + from .codegen.common import BackendFeature, has_backend_feature + + if has_backend_feature(self.device, BackendFeature.MASKED_SCATTER_WITH_INDEX): + # This two-step algorithm is the same as eager CUDA, for eager CPU we + # use a 1-shot serial iteration. + self, mask = aten.broadcast_tensors([self, mask]) + source_idx = mask.reshape(-1).cumsum(0) - 1 + self_flat, mask_flat, source_flat = (x.flatten() for x in (self, mask, source)) + result = aten._unsafe_masked_index(source_flat, mask_flat, [source_idx], 0) + return torch.where(mask_flat, result, self_flat).view(self.shape) + return NotImplemented + + +@register_decomposition(quantized_decomposed.choose_qparams.tensor) +def choose_qparams_tensor( + input: torch.Tensor, + quant_min: int, + quant_max: int, + eps: float, + dtype: torch.dtype, +) -> Tuple[torch.Tensor, torch.Tensor]: + min_val, max_val = torch.aminmax(input) + scale = (max_val - min_val) / float(quant_max - quant_min) + scale = torch.max(scale, torch.Tensor([eps])) + zero_point = quant_min - torch.round(min_val / scale).to(torch.int) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + return scale.to(torch.float64), zero_point.to(torch.int64) + + +@register_decomposition(aten.put) +def put( + self: torch.Tensor, + index: torch.Tensor, + source: torch.Tensor, + accumulate: bool = False, +) -> torch.Tensor: + flattened = self.flatten() + flattened = torch.index_put( + flattened, [index], source.reshape(index.shape), accumulate + ) + return flattened.reshape(self.shape) + + +@register_decomposition(aten.put_) +def put_( + self: torch.Tensor, + index: torch.Tensor, + source: torch.Tensor, + accumulate: bool = False, +) -> torch.Tensor: + out = aten.put(self, index, source, accumulate=accumulate) + return self.copy_(out) + + +@register_decomposition(aten._softmax_backward_data.default) +@pw_cast_for_opmath +def _softmax_backward_data( + grad_output: torch.Tensor, + output: torch.Tensor, + dim: int, + input_dtype: torch.dtype, +) -> torch.Tensor: + new_grad_output = grad_output * output + sum_new_grad = torch.sum(new_grad_output, dim=dim, keepdim=True) + # grad_input = new_grad_output - output * sum_new_grad + grad_input = inductor_prims.fma(-output, sum_new_grad, new_grad_output) + + # CPU kernel doesn't respect input_dtype, but following check doesn't work for meta tensor + # if grad_output.device == torch.device("cpu"): + # return grad_input.contiguous() + + if grad_output.dtype != input_dtype: + grad_input = grad_input.to(input_dtype) + return grad_input.contiguous() + + +@register_decomposition(aten.index_reduce) +def index_reduce( + self: torch.Tensor, + dim: int, + index: torch.Tensor, + src: torch.Tensor, + reduction_type: str, + *, + include_self: bool = True, +) -> torch.Tensor: + if reduction_type == "mean" and not needs_fallback_due_to_atomic_add_limitations( + self.dtype + ): + true_division = self.dtype.is_floating_point or self.dtype.is_complex + ones = torch.ones_like(src) + if include_self: + out = self + counts = torch.ones_like(self).index_add(dim, index, ones) + else: + out = self.index_fill(dim, index, 0) + counts = torch.zeros_like(self).index_add(dim, index, ones) + counts = counts.masked_fill(counts < 1, 1) + out = out.index_add(dim, index, src) + return out / counts if true_division else out // counts + + if use_scatter_fallback( + aten.scatter_reduce_.two, + reduction_type, + self.dtype, + src.dtype, + src.device.type, + True, + ): + return NotImplemented + + repeats = self.shape[dim + 1 :].numel() * self.shape[:dim].numel() 
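# The reshape/permute below expands `index` into a tensor shaped like `self`
# except that dimension `dim` holds index.numel() entries: each index value is
# repeated `repeats` times, reshaped with the indexed dimension first, and then
# permuted so that it lands back at position `dim` for scatter_reduce.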
+ index_shape = (index.numel(), *self.shape[dim + 1 :], *self.shape[:dim]) + perm = (*range(self.ndim - dim, self.ndim), 0, *range(1, self.ndim - dim)) + scatter_index = ( + index.to(torch.int64) + .repeat_interleave(repeats) + .reshape(index_shape) + .permute(perm) + ) + return self.scatter_reduce( + dim, + scatter_index, + src, + reduction_type, + include_self=include_self, + ) + + +@register_decomposition(aten.max_pool2d_with_indices) +def max_pool2d_with_indices( + x: torch.Tensor, + kernel_size: List[int], + stride: Optional[Union[int, List[int]]] = None, + padding: Union[int, List[int]] = 0, + dilation: Union[int, List[int]] = 1, + ceil_mode: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + if dilation == 1: + dilation = [1, 1] + + if padding == 0: + padding = [0, 0] + + if not stride: + stride = kernel_size + + kernel_size = pad_listlike(kernel_size, 2) + dilation = pad_listlike(dilation, 2) + padding = pad_listlike(padding, 2) + stride = pad_listlike(stride, 2) + + window_size = kernel_size[0] * kernel_size[1] + # We fallback when using non-default dilation or when the window size is too large + if ( + torch._inductor.lowering.should_fallback_max_pool2d_with_indices( + kernel_size, dilation + ) + or window_size > torch.iinfo(torch.int8).max + ): + return NotImplemented + + vals, offsets = prims._low_memory_max_pool2d_with_offsets( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + ) + indices = prims._low_memory_max_pool2d_offsets_to_indices( + offsets, + kernel_size[1], + x.size(-1), + stride, + padding, + ) + return vals, indices diff --git a/lib/python3.10/site-packages/torch/_inductor/dependencies.py b/lib/python3.10/site-packages/torch/_inductor/dependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..95a48281977c87002e9ee486ae80e0b58b706d33 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/dependencies.py @@ -0,0 +1,745 @@ +# mypy: allow-untyped-defs +import abc +import dataclasses +import itertools +import logging +import re +import typing +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from unittest.mock import patch + +import sympy + +import torch +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols +from torch.utils._ordered_set import OrderedSet + +from .codegen.common import index_prevent_reordering +from .utils import ( + get_dtype_size, + reduction_num_outputs, + sympy_index_symbol, + sympy_str, + sympy_subs, + VarRanges, +) +from .virtualized import OpsHandler, ReductionType, V + + +log = logging.getLogger(__name__) +is_indirect = re.compile(r"indirect|tmp").search + + +class Dep(abc.ABC): + name: str + index: sympy.Expr + + @abc.abstractmethod + def rename(self, renames: Dict[str, str]) -> "Dep": + pass + + @abc.abstractmethod + def get_numel(self) -> sympy.Expr: + pass + + @abc.abstractmethod + def numbytes_hint(self): + pass + + @abc.abstractmethod + def has_unbacked_symbols(self) -> bool: + pass + + @abc.abstractmethod + def is_contiguous(self) -> bool: + pass + + def normalize_with_stride_order(self, prefix="t"): + return self + + +@dataclasses.dataclass(frozen=True) +class MemoryDep(Dep): + name: str + index: sympy.Expr + var_names: Tuple[sympy.Symbol, ...] + size: Tuple[sympy.Expr, ...] 
+ mode: Optional[str] = None + + def __repr__(self) -> str: + return f"MemoryDep({self.name!r}, {self.index}, {self.ranges}, {self.mode})" + + @property + def num_vars(self): + return len(self.var_names) + + def decide_loop_order_to_match(self, other): + """ + Can return None if unable to decide the loop order. + """ + assert self.num_vars == other.num_vars + + # ignore broadcast for now since broadcast causes extra 0 strides + # which makes it hard to decide the correct loop orders. + if self.num_vars != len(self.index.free_symbols): + return None + if other.num_vars != len(other.index.free_symbols): + return None + + # bail out if any size is 0 or 1 + # For size == 0, it's an empty tensor, any strides for that dimension + # are equivalent. Skip for simplicity; it may not matter that much. + # + # For size == 1, it can cause ties between strides of different dimensions. + # Also, when we first create the LoopBody in ComputedBuffer.simplify_and_reorder, + # we call dependencies.index_vars_squeeze, which should already squeeze + # the size == 1 dimensions. + if any(s == 0 or s == 1 for s in itertools.chain(self.size, other.size)): + return None + + # Extract strides for both expressions + self_strides = V.graph.sizevars.stride_hints(self.index, self.var_names) + other_strides = V.graph.sizevars.stride_hints(other.index, other.var_names) + + # Even if the shape contains no 0/1, some complex index expression may + # still have duplicate stride values. Here is an example: + # https://gist.github.com/shunting314/511a7e1ec88aa2e1a8ec85d8445ab129 + # We don't reorder the loop for these cases for now, but in theory + # we could improve the algorithm to detect the correct loop orders. + if len(set(self_strides)) != len(self_strides) or len( + set(other_strides) + ) != len(other_strides): + log.debug( + "unable to decide loop order. self_dep=%s v.s. other_dep=%s, self_strides=%s v.s. other_strides=%s", + self, + other, + self_strides, + other_strides, + ) + return None + + # May happen if self and other are as follows + # MemoryDep('addmm_6', 393216*d0 + 768*d1 + d2, {d0: 16, d1: 512, d2: 768}, None) + # MemoryDep('addmm_6', 98304*d0 + d1 + 768*d2, {d0: 64, d1: 768, d2: 128}, None) + if set(self_strides) != set(other_strides): + return None + + stride_to_index = {s: i for i, s in enumerate(self_strides)} + order = [] + for s in other_strides: + order.append(stride_to_index[s]) + + assert set(order) == set(range(0, self.num_vars)) + return order + + def get_offset(self): + """ + Return the offset obtained by setting every variable to 0. + """ + return sympy_subs(self.index, dict.fromkeys(self.var_names, 0)) + + def normalize(self) -> "MemoryDep": + """ + Normalize by merging loops. The difference from normalize_with_stride_order is + that this method does not reorder loops, while normalize_with_stride_order reorders + loops based on stride order. + """ + return MemoryDep( + self.name, + *_RecordLoadStoreInner._normalize(self.index, self.ranges), # type: ignore[arg-type] + self.mode, + ) + + def normalize_with_stride_order(self, prefix="t"): + r""" + Used to decide whether two MemoryDeps are unequal only because of different loop orders. + More specifically, when dep1 and dep2 are not equal, we can normalize + both and check if they are equal after that. If yes, then the mismatch is + caused by different loop orders.
+ """ + # import here to avoid circular import + from torch._inductor import ir + + strides = V.graph.sizevars.stride_hints(self.index, self.var_names) + + # pick a loop order with stride ordered decreasingly + order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True) + stride_reorder = ir.same_reorder(order) + sizes = self.size + var_names = self.var_names + + new_reordered_sizes = stride_reorder(sizes) + new_reordered_var_names = stride_reorder(var_names) + + new_simplified_sizes, reindex, prune = V.graph.sizevars._simplify_loops( + new_reordered_var_names, + new_reordered_sizes, + index_prevent_reordering( + [self.index], new_reordered_var_names, new_reordered_sizes + ), + ) + + # now let's create new symbols with the passed in prefix + var_ranges, add_var = var_builder(prefix) + replacement = dict( + zip( + new_reordered_var_names, + reindex([add_var(x) for x in new_simplified_sizes]), + ) + ) + new_index = sympy_subs(sympy.expand(self.index), replacement) # type: ignore[arg-type] # next PR + + out = MemoryDep(self.name, new_index, tuple(var_ranges.keys()), tuple(var_ranges.values())) # type: ignore[arg-type] + return out + + @property + def ranges(self) -> Dict[sympy.Symbol, sympy.Expr]: + """{c0: 128, c1: 512, ...}""" + return dict(zip(self.var_names, self.size)) + + def get_numel(self) -> sympy.Expr: + if self.is_indirect(): + numel = V.graph.get_numel(self.name) + else: + vars: OrderedSet[sympy.Basic] = OrderedSet(self.index.free_symbols) + numel = sympy.Integer(1) + for var, size in zip(self.var_names, self.size): + if var in vars: + numel = numel * size + return numel # type: ignore[return-value] + + def rename(self, renames: Dict[str, str]) -> "MemoryDep": + if self.name in renames: + return MemoryDep( + renames[self.name], + self.index, + var_names=self.var_names, + size=self.size, + mode=self.mode, + ) + return self + + def numbytes_hint(self): + return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size( + V.graph.get_dtype(self.name) + ) + + def has_unbacked_symbols(self): + return len(free_unbacked_symbols(self.get_numel())) > 0 + + def is_contiguous(self) -> bool: + return isinstance(self.index, sympy.Symbol) and self.index in self.var_names + + def stride1_for_last_dim(self, result_for_complex_expression=True) -> bool: + """ + Whether the stride for the last dimension is 1. + """ + # python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_masked_scatter_cuda_float16 + # will exercise thru this corner case. + if len(self.var_names) == 0: + return True + + terms = self.index.args if isinstance(self.index, sympy.Add) else [self.index] + + last_sym = self.var_names[-1] + for term in terms: + if term is last_sym: + return True + + # Having a >1 stride for the last dimension is bad for perf + # return False. 
+ if ( + isinstance(term, sympy.Mul) + and len(term.args) == 2 + and term.args[1] is last_sym + and isinstance(term.args[0], (int, sympy.Integer)) + and term.args[0] > 1 + ): + return False + + return result_for_complex_expression + + def is_scalar(self) -> bool: + if isinstance(self.index, sympy.Symbol): + return self.index not in self.var_names and not self.is_indirect() + return isinstance(self.index, (int, sympy.Integer)) + + def is_indirect(self) -> bool: + return any(is_indirect(v.name) for v in self.index.free_symbols) # type: ignore[attr-defined] + + +@dataclasses.dataclass(frozen=True) +class StarDep(Dep): + name: str + mode: Optional[str] = None + + # depends on the entire buffer + @property + def index(self): + raise NotImplementedError("StarDep does not have an index") + + def get_numel(self) -> sympy.Expr: + return V.graph.get_numel(self.name) # type: ignore[return-value] + + def rename(self, renames: Dict[str, str]) -> "StarDep": + if self.name in renames: + return StarDep(renames[self.name], self.mode) + return self + + def numbytes_hint(self): + return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size( + V.graph.get_dtype(self.name) + ) + + def has_unbacked_symbols(self): + return len(free_unbacked_symbols(self.get_numel())) > 0 + + def is_contiguous(self) -> bool: + return False + + def is_scalar(self) -> bool: + return False + + def is_indirect(self) -> bool: + return False + + +# Used for tracking mutation ordering +# if A reads a buffer and B mutates it +# B must be ordered after A +# +# This is useful for a variety of reasons. +# For example, if A's read is never actually used, we can eliminate it. +# Another case is if A's buffer ends up being fused away, we never need to +# materialize that buffer +@dataclasses.dataclass(frozen=True) +class WeakDep(Dep): + # Fake dependency on unused buffer + name: str + # Buffer that is doing the mutation + mutating_buf: str + + @property + def index(self): + raise NotImplementedError("WeakDep does not have an index") + + def get_numel(self) -> sympy.Expr: + return sympy.Integer(1) + + def rename(self, renames: Dict[str, str]) -> "WeakDep": + if self.name in renames: + return WeakDep(renames[self.name], self.mutating_buf) + return self + + def numbytes_hint(self): + return 1 # Purely inserted for ordering, not an actual dep + + def has_unbacked_symbols(self): + return False + + def is_contiguous(self) -> bool: + return False + + +@dataclasses.dataclass(frozen=True) +class IndexExprDep: + index: sympy.Expr # type: ignore[assignment] + var_names: Tuple[sympy.Symbol, ...] + size: Tuple[sympy.Expr, ...] 
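To make the dependency classes above more concrete, here is a minimal standalone sketch (assuming the file lands at torch._inductor.dependencies as in this diff); it only exercises MemoryDep methods that do not need an active graph, so it can be read in isolation:

import sympy

from torch._inductor.dependencies import MemoryDep

d0, d1 = sympy.symbols("d0 d1", integer=True)

# A read of buf0[64*d0 + d1] over a 128 x 64 iteration space.
dep = MemoryDep(
    "buf0",
    64 * d0 + d1,
    var_names=(d0, d1),
    size=(sympy.Integer(128), sympy.Integer(64)),
)

print(dep.ranges)        # {d0: 128, d1: 64}
print(dep.get_offset())  # 0 -- the index with every loop variable set to 0
print(dep.rename({"buf0": "buf3"}).name)  # "buf3"
print(dep.is_contiguous())                # False: the index is not a bare loop variable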
+ + +@dataclasses.dataclass +class ReadWrites: + reads: OrderedSet[Dep] + writes: OrderedSet[Dep] + index_exprs: OrderedSet[IndexExprDep] + range_vars: Optional[List[sympy.Expr]] = None + var_ranges: Optional[VarRanges] = None + + def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites": + return ReadWrites( + OrderedSet(dep.rename(renames) for dep in self.reads), + OrderedSet(dep.rename(renames) for dep in self.writes), + self.index_exprs, + self.range_vars, + self.var_ranges, + ) + + def with_read(self, dep: Union[Dep, Set[Dep]]) -> "ReadWrites": + assert isinstance(dep, (WeakDep, StarDep, set)) + if not isinstance(dep, set): + dep = {dep} + return ReadWrites( + OrderedSet.union(self.reads, dep), + self.writes, + self.index_exprs, + self.range_vars, + self.var_ranges, + ) + + def merge(self, other: "ReadWrites"): + reads = OrderedSet.union(self.reads, other.reads) + writes = OrderedSet.union(self.writes, other.writes) + index_exprs = OrderedSet.union(self.index_exprs, other.index_exprs) + return ReadWrites(reads - writes, writes, index_exprs) + + @staticmethod + def merge_list(read_writes: List["ReadWrites"]): + all_writes = OrderedSet.union(*[rw.writes for rw in read_writes]) + all_reads = OrderedSet.union(*[rw.reads for rw in read_writes]) - all_writes + all_index_exprs = OrderedSet.union(*[rw.index_exprs for rw in read_writes]) + return ReadWrites(all_reads, all_writes, all_index_exprs) + + def remove_reads(self, rem_reads): + return ReadWrites( + self.reads - rem_reads, + self.writes, + self.index_exprs, + self.range_vars, + self.var_ranges, + ) + + def reads_and_writes(self): + return itertools.chain(self.reads, self.writes) + + def buffer_names(self, ignore_integer_index=True): + """ + Integer index is used for load_seed. + """ + names: OrderedSet[str] = OrderedSet() + for dep in self.reads_and_writes(): + if not isinstance(dep, MemoryDep): + continue + if not ignore_integer_index or not isinstance( + dep.index, (int, sympy.Integer) + ): + names.add(dep.name) + return names + + +class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined] + def __init__(self, var_ranges: VarRanges, normalize: bool) -> None: + super().__init__() + self._reads: OrderedSet[Dep] = OrderedSet() + self._writes: OrderedSet[MemoryDep] = OrderedSet() + self._index_exprs: OrderedSet[IndexExprDep] = OrderedSet() + self._var_ranges: VarRanges = var_ranges + self._should_normalize: bool = normalize + + @staticmethod + def drop_unused_symbols(index, var_names, sizes): + """ + Reduction has last (reduced) dim in its sizes, but + downstream users won't. Normalize this away. + """ + if not isinstance(index, sympy.Expr): + # index can be an int + return + free_symbols = index.free_symbols + while var_names and var_names[-1] not in free_symbols: + var_names.pop() + sizes.pop() + + @classmethod + def _normalize( + cls, index: sympy.Expr, var_ranges: VarRanges + ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]: + # Try to further simplify the indexes even if simplify_loops didn't + # convert it to the simplest form because of the interference from + # different indexing formulas. 
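# For example, an index 8*d0 + d1 iterated over ranges {d0: 4, d1: 8} walks a
# contiguous range of 32 elements, so _simplify_loops can merge the two loops
# and the canonicalized dep ends up with a single variable of range 32.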
+ index_vars = [*var_ranges.keys()] + sizes = tuple(var_ranges.values()) # type: ignore[assignment] + new_sizes, reindex, prune = V.graph.sizevars._simplify_loops( + index_vars, + sizes, + index_prevent_reordering([index], index_vars, sizes), + ) + + # assign new variables each dimension to deal with numbering mismatches + # d0, d1, d2 could become d0, d2 -- which won't match d0, d1 + new_vars, add_var = var_builder(canonicalization_prefix()) + replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes]))) + index = sympy_subs(sympy.expand(index), replacement) + + new_vars = [*new_vars.keys()] + new_sizes = [*new_sizes] + cls.drop_unused_symbols(index, new_vars, new_sizes) + return index, tuple(new_vars), tuple(new_sizes) # type: ignore[arg-type] + + def canonicalize( + self, index: sympy.Expr + ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]: + if not self._should_normalize: + sizes = [V.graph.sizevars.simplify(x) for x in self._var_ranges.values()] + var_names = [k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1] + sizes = [v for v in sizes if v != 1] + + self.drop_unused_symbols(index, var_names, sizes) + + return index, tuple(var_names), tuple(sizes) # type: ignore[return-value, arg-type] + var_ranges = { + k: V.graph.sizevars.simplify(v) + for k, v in self._var_ranges.items() + # TODO(jansel): explore this further normalization + # if k in free_symbols + } + return self._normalize(index, var_ranges) + + def load(self, name: str, index: sympy.Expr) -> str: + self._reads.add(MemoryDep(name, *self.canonicalize(index))) + return f"load({name}, {sympy_str(index)})" + + def load_seed(self, name: str, index: int): + assert isinstance(index, int) + return self.load(name, sympy.Integer(index)) + + def store(self, name: str, index: sympy.Expr, value: str, mode=None) -> str: + self._writes.add(MemoryDep(name, *self.canonicalize(index), mode=mode)) + return f"store({name}, {sympy_str(index)}, {value}, {mode})" + + def store_reduction(self, name: str, index, value) -> str: + return self.store(name, index, f"store_reduction({value})") + + def index_expr(self, index: sympy.Expr, dtype) -> str: + self._index_exprs.add(IndexExprDep(*self.canonicalize(index))) + return f"index_expr({sympy_str(index)}, {dtype})" + + def bucketize( + self, + values, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ): + self._reads.add(StarDep(offsets_name)) + return f"bucketize({values}, {offsets_name}, {sympy_str(offsets_size)}, {indexing_dtype}, {right})" + + +class RecordLoadStore(V.KernelFormatterHandler): # type: ignore[name-defined] + def __init__(self, var_ranges: VarRanges, normalize: bool) -> None: + parent_handler = _RecordLoadStoreInner( + var_ranges=var_ranges, normalize=normalize + ) + super().__init__(parent_handler=parent_handler) + + +# TODO: check call sites +def var_builder(prefix: str) -> Tuple[VarRanges, Callable[[sympy.Expr], sympy.Symbol]]: + cnt = itertools.count() + var_ranges: VarRanges = {} + + def add_var(length: sympy.Expr) -> sympy.Symbol: + v = sympy_index_symbol(f"{prefix}{next(cnt)}") + var_ranges[v] = length + return v + + return var_ranges, add_var + + +def index_vars_no_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str): + var_ranges, add_var = var_builder(prefix) + args: List[List[sympy.Symbol]] = [] + for size in argsizes: + args.append(list(map(add_var, size))) + return args, var_ranges + + +def index_vars_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str = "d"): + from .ir 
import SqueezeView + + var_ranges, add_var = var_builder(prefix) + args: List[List[sympy.Expr]] = [] + new_sizes: List[List[sympy.Expr]] = [] + for size in argsizes: + new_size, reindex = SqueezeView.squeezer(size) + new_sizes.append(new_size) + args.append(reindex(list(map(add_var, new_size)))) + return args, var_ranges + + +def extract_read_writes( + fn: Callable[..., Any], + *argsizes: Tuple[sympy.Expr, ...], + normalize: bool = False, + prefix: str = "d", + hidden_args=(), +): + args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix) + + from .loop_body import LoopBody, MemoryUsageType + + if isinstance(fn, LoopBody): + # Fast path to avoid tracing when we already have a LoopBody + inner = _RecordLoadStoreInner(var_ranges=var_ranges, normalize=normalize) + name_to_index = fn.indexing_from_args([*args, *hidden_args]) + if fn.indirect_vars: + # mimic the `tmpX` naming tracing gives us + repl = {v: sympy.Symbol(f"tmp{i}") for i, v in enumerate(fn.indirect_vars)} + name_to_index = {k: sympy_subs(v, repl) for k, v in name_to_index.items()} + for entry in fn.memory_usage[MemoryUsageType.LOAD]: + inner.load(entry.buffer_name, name_to_index[entry.index_name]) + for entry in fn.memory_usage[MemoryUsageType.LOAD_SEED]: + inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name])) + for entry in fn.memory_usage[MemoryUsageType.STORE]: + inner.store( + entry.buffer_name, name_to_index[entry.index_name], None, entry.mode + ) + for entry in fn.memory_usage[MemoryUsageType.STORE_REDUCTION]: + inner.store_reduction( + entry.buffer_name, name_to_index[entry.index_name], None + ) + for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]: + inner.index_expr(name_to_index[entry.index_name], None) + for entry in fn.memory_usage[MemoryUsageType.BUCKETIZE]: + inner.bucketize( + None, entry.buffer_name, name_to_index[entry.index_name], None, None + ) + # fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped + else: + # Slow path tracing the function + rw = RecordLoadStore(var_ranges, normalize=normalize) + with V.set_ops_handler(rw): + fn(*args, *hidden_args) + inner = rw.parent_handler + + if normalize: + range_vars = [] # Number of vars could differ due to normalization + else: + range_vars = [*itertools.chain.from_iterable(args)] + + return ReadWrites( + OrderedSet(inner._reads), + OrderedSet(inner._writes), + inner._index_exprs, + range_vars, + var_ranges, + ) + + +def extract_input_node_reduction_ranges( + input_node: "torch._inductor.ir.TensorBox", +) -> Tuple[Optional[List[sympy.Expr]], Optional[List[sympy.Expr]]]: + """ + Returns the size and reduction size of all inputs, if the sizes and reduction_sizes (if exist) are all the same. + It's possible that a node has multiple inputs, some are Reduction nodes and others are Pointwise nodes. + In this case, reduction_sizes of the Reduction nodes need to be the same. + Otherwise returns (None, None). + """ + + from .ir import ComputedBuffer, Loops + + if isinstance(input_node.data, ComputedBuffer): + # Input node has already been realized. Return its size and reduction_size. + size = input_node.get_size() + reduction_size = input_node.get_reduction_size() + if len(reduction_size) > 0: + return (size, reduction_size) + else: + return (None, None) + + if not isinstance(input_node.data.data, Loops): # type: ignore[attr-defined] + # Other IRNodes do not have reduction_ranges. + return (None, None) + + # There is one issue: what if there are views / permutations between the input node and its dependent realized nodes? 
+ # The current method still uses reduction ranges from the dependent realized node, which is not ideal. + # Is there a way to check whether there are permutations inbetween? + reads = input_node.get_reads() + reduction_size = None + size = None + while reduction_size is None and len(reads) > 0: + seen: OrderedSet[str] = OrderedSet() + new_reads = [] + for read in reads: + if not isinstance(read, MemoryDep): + continue + if read.name in seen: + continue + seen.add(read.name) + buffer = V.graph.try_get_buffer(read.name) + if buffer is None: + continue + op = buffer.get_defining_op() + if op is None: + continue + + if isinstance(op, ComputedBuffer) and len(op.get_reduction_size()) > 0: + if reduction_size is None: + reduction_size = op.get_reduction_size() + size = op.get_size() + elif reduction_size != op.get_reduction_size() or size != op.get_size(): + return (None, None) + else: + new_reads.extend(op.get_reads()) + if reads == new_reads: + return (size, reduction_size) + else: + reads = new_reads + return (size, reduction_size) + + +def canonicalization_prefix(): + return "c" + + +# ops handler which computes all the free unbacked symbols for an IR +class FreeUnbackedSymbolsOpsHandler: + symbols: OrderedSet[sympy.Symbol] + + def __init__(self) -> None: + self.symbols = OrderedSet() + + def __getattr__(self, name: str) -> Callable[..., Any]: + def inner(*args, **kwargs): + for a in itertools.chain(args, kwargs.values()): + if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)): + self.symbols |= free_unbacked_symbols(a) + + return inner + + def indirect_indexing( + self, index_var, size, check=True, wrap_neg=True + ) -> sympy.Symbol: + assert not isinstance(index_var, (sympy.Expr, sympy.logic.boolalg.Boolean)) + self.symbols |= free_unbacked_symbols(size) + return sympy_index_symbol(f"({str(index_var)})") + + def frexp(self, x): + return (None,) * 2 + + def scan(self, dtypes, combine_fn, values): + return (None,) * len(values) + + def sort(self, dtypes, values, stable, descending): + return (None,) * len(values) + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[None, Tuple[None, ...]], + ) -> Union[None, Tuple[None, ...]]: + num_values = reduction_num_outputs(reduction_type) + return (None,) * num_values if num_values > 1 else None + + +def _typecheck_FreeUnbackedSymbolsOpsHandler( + h: FreeUnbackedSymbolsOpsHandler, +) -> OpsHandler[None]: + return h + + +def extract_free_unbacked_symbols(fn: Callable[..., Any], index, rindex=None): + from .ir import FlexibleLayout + + args = [index, rindex] if rindex is not None else [index] + handler = FreeUnbackedSymbolsOpsHandler() + # NB: I cargo culted the allow_indexing patch here, I don't understand why + # people do this all over + with V.set_ops_handler(handler), patch.object( + FlexibleLayout, "allow_indexing", True + ): + fn(*args) + return handler.symbols diff --git a/lib/python3.10/site-packages/torch/_inductor/exc.py b/lib/python3.10/site-packages/torch/_inductor/exc.py new file mode 100644 index 0000000000000000000000000000000000000000..728f652032d52affff0bfe216e920a4a97cd51d3 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/exc.py @@ -0,0 +1,104 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import os +import tempfile +import textwrap +from functools import lru_cache + + +if os.environ.get("TORCHINDUCTOR_WRITE_MISSING_OPS") == "1": + + @lru_cache(None) + def _record_missing_op(target): + with 
open(f"{tempfile.gettempdir()}/missing_ops.txt", "a") as fd: + fd.write(str(target) + "\n") + +else: + + def _record_missing_op(target): # type: ignore[misc] + pass + + +class OperatorIssue(RuntimeError): + @staticmethod + def operator_str(target, args, kwargs): + lines = [f"target: {target}"] + [ + f"args[{i}]: {arg}" for i, arg in enumerate(args) + ] + if kwargs: + lines.append(f"kwargs: {kwargs}") + return textwrap.indent("\n".join(lines), " ") + + +class MissingOperatorWithoutDecomp(OperatorIssue): + def __init__(self, target, args, kwargs) -> None: + _record_missing_op(target) + super().__init__(f"missing lowering\n{self.operator_str(target, args, kwargs)}") + + +class MissingOperatorWithDecomp(OperatorIssue): + def __init__(self, target, args, kwargs) -> None: + _record_missing_op(target) + super().__init__( + f"missing decomposition\n{self.operator_str(target, args, kwargs)}" + + textwrap.dedent( + f""" + + There is a decomposition available for {target} in + torch._decomp.get_decompositions(). Please add this operator to the + `decompositions` list in torch._inductor.decomposition + """ + ) + ) + + +class LoweringException(OperatorIssue): + def __init__(self, exc: Exception, target, args, kwargs) -> None: + super().__init__( + f"{type(exc).__name__}: {exc}\n{self.operator_str(target, args, kwargs)}" + ) + + +class SubgraphLoweringException(RuntimeError): + pass + + +class InvalidCxxCompiler(RuntimeError): + def __init__(self) -> None: + from . import config + + super().__init__( + f"No working C++ compiler found in {config.__name__}.cpp.cxx: {config.cpp.cxx}" + ) + + +class CppWrapperCodeGenError(RuntimeError): + def __init__(self, msg: str) -> None: + super().__init__(f"C++ wrapper codegen error: {msg}") + + +class CppCompileError(RuntimeError): + def __init__(self, cmd: list[str], output: str) -> None: + if isinstance(output, bytes): + output = output.decode("utf-8") + + super().__init__( + textwrap.dedent( + """ + C++ compile error + + Command: + {cmd} + + Output: + {output} + """ + ) + .strip() + .format(cmd=" ".join(cmd), output=output) + ) + + +class CUDACompileError(CppCompileError): + pass diff --git a/lib/python3.10/site-packages/torch/_inductor/extern_node_serializer.py b/lib/python3.10/site-packages/torch/_inductor/extern_node_serializer.py new file mode 100644 index 0000000000000000000000000000000000000000..6ed505d3e60ee9896aed5953ab114779ecd54dcd --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/extern_node_serializer.py @@ -0,0 +1,25 @@ +import json +from typing import List + +from torch._export.serde.aoti_schema import ExternKernelNode, ExternKernelNodes, Node +from torch._export.serde.serialize import _dataclass_to_dict, EnumEncoder +from torch._inductor.ir import ExternKernelNode as inductor_ExternKernelNode + + +def serialize_extern_kernel_node( + extern_kernel_node: inductor_ExternKernelNode, +) -> ExternKernelNode: + assert isinstance(extern_kernel_node.node, Node) + return ExternKernelNode( + name=extern_kernel_node.name, + node=extern_kernel_node.node, + ) + + +def extern_node_json_serializer( + extern_kernel_nodes: List[inductor_ExternKernelNode], +) -> str: + serialized_nodes = ExternKernelNodes( + nodes=[serialize_extern_kernel_node(node) for node in extern_kernel_nodes] + ) + return json.dumps(_dataclass_to_dict(serialized_nodes), cls=EnumEncoder) diff --git a/lib/python3.10/site-packages/torch/_inductor/freezing.py b/lib/python3.10/site-packages/torch/_inductor/freezing.py new file mode 100644 index 
0000000000000000000000000000000000000000..9b936bf7bb1b1b4ef00195874a8d51fc737d0b16 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/freezing.py @@ -0,0 +1,269 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import itertools +import logging +import weakref +from typing import Any, List, Optional, Tuple + +import torch +import torch.utils._pytree as pytree +from torch._dynamo.utils import dynamo_timed, lazy_format_graph_code +from torch._functorch.aot_autograd import MutationType +from torch._functorch.compile_utils import fx_graph_cse +from torch._inductor.constant_folding import constant_fold, replace_node_with_constant +from torch._inductor.fx_passes.freezing_patterns import freezing_passes +from torch._inductor.fx_passes.post_grad import view_to_reshape + +from . import config + + +aten = torch.ops.aten +prims = torch.ops.prims + +log = logging.getLogger(__name__) + + +def replace_params_with_constants( + gm: torch.fx.GraphModule, + flat_params: list[Any], + fw_metadata: torch._functorch.aot_autograd.ViewAndMutationMeta, +) -> List[int]: + """ + Replaces the parameters of a PyTorch GraphModule with constants wherever possible. + Returns a list of indices representing the input parameters that were not converted to constants. + """ + params = gm.graph.find_nodes(op="placeholder") + fake_inp_nodes = params[: len(params)] + preserved_arg_indices = [] + aliased_input_args = [ + out_info.base_idx + for out_info in fw_metadata.output_info + if out_info.base_idx is not None + ] + + # TODO (tmanlaibaatar) figure out why this is different + # from mutated_inp_runtime_indices + mutated_inps = [ + i + for i, m in enumerate(fw_metadata.input_info) + if m.mutation_type + in (MutationType.MUTATED_IN_GRAPH, MutationType.MUTATED_OUT_GRAPH) + ] + + for i, (real_input, node) in enumerate(zip(flat_params, fake_inp_nodes)): + if i in mutated_inps or i in aliased_input_args: + preserved_arg_indices.append(i) + continue + replace_node_with_constant(gm, node, real_input) + # add on non param inputs + preserved_arg_indices.extend(range(len(flat_params), len(params))) + # is this necessary ? + gm.recompile() + return preserved_arg_indices + + +def freeze( + dynamo_gm: torch.fx.GraphModule, + aot_autograd_gm: torch.fx.GraphModule, + example_inputs: List[torch._subclasses.FakeTensor], +) -> Tuple[torch.fx.GraphModule, List[int]]: + """ + Inlines parameters that are not mutated into constants and optimizes the graph through constant propagation + and other techniques. If enabled, the function also discards the original parameters of the module for memory efficiency. + + Assumes that this function is run in dynamo tracing post aot_autograd. + + Args: + dynamo_gm (torch.fx.GraphModule): The Dynamo constructed GraphModule. + aot_autograd_gm (torch.fx.GraphModule): The aot_autograd constructed GraphModule to be frozen. + example_inputs (List[torch.Tensor]): A list of example input tensors to be used in the freezing process. + + Returns: + Tuple[torch.fx.GraphModule, List[int]]: A tuple containing the frozen GraphModule and a list of indices + of the inputs that were preserved (not turned into constants). + """ + # We have convert conv's weight to channels last which may meet error for .view + # when doing fake_tensor_prop. So we need to convert view to reshape first. + # See the details in fx_codegen_and_compile of compile_fx.py. 
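# At a high level, the steps below: rewrite view ops to reshape, replace
# unmutated parameters with graph constants, run CSE and the freezing pattern
# passes, constant-fold the result, and (when config.freezing_discard_parameters
# is set) discard the original eager parameters so their memory can be reclaimed.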
+ view_to_reshape(aot_autograd_gm) + + if tracing_context := torch._guards.TracingContext.try_get(): + fw_metadata = tracing_context.fw_metadata + params_flat = tracing_context.params_flat + assert fw_metadata is not None and params_flat is not None + + preserved_arg_indices = replace_params_with_constants( + aot_autograd_gm, params_flat, fw_metadata + ) + else: + inputs = aot_autograd_gm.graph.find_nodes(op="placeholder") + preserved_arg_indices = list(range(len(inputs))) + + # TODO - further restrict cse ? right now needed to dedup aliasing ops + cse_graph = fx_graph_cse(aot_autograd_gm.graph) + aot_autograd_gm.graph = cse_graph + aot_autograd_gm.recompile() + + aot_example_inputs = [example_inputs[ind] for ind in preserved_arg_indices] + freezing_passes(aot_autograd_gm, aot_example_inputs) + + constant_fold(aot_autograd_gm) + # invalidate nn Modules + if config.freezing_discard_parameters: + invalidate_eager_modules() + discard_traced_gm_params(dynamo_gm) + + log.debug( + "%s", lazy_format_graph_code("FROZEN GRAPH", aot_autograd_gm, colored=True) + ) + + return aot_autograd_gm, preserved_arg_indices + + +class ErasedTensor(torch.Tensor): + @staticmethod + def __new__(cls, elem, name, owning_mod): + return super().__new__(cls, elem.to(device="meta")) + + def __init__(self, elem, name: Optional[str], mod) -> None: + self.erased_name = name + self.owning_mod_ref = weakref.ref(mod) + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + erased_tensors = [ + e + for e in pytree.arg_tree_leaves(*args, **kwargs) + if isinstance(e, ErasedTensor) + ] + assert len(erased_tensors) > 0 + e = erased_tensors[0] + + raise RuntimeError( + f"Trying to run Pytorch Eager Module after Dynamo Freezing. " + "The original parameters have been discarded for memory efficiency. " + f"Found in op {func} for erased parameter {e.erased_name} of {e.owning_mod_ref()}" + ) + + +def invalidate_eager_modules(): + with torch.utils._python_dispatch._disable_current_modes(): + for ( + mod + ) in torch._guards.TracingContext.get().module_context.nn_modules.values(): + if not isinstance(mod, torch.nn.Module): + continue + + for attr_name, tensor in list( + itertools.chain( + mod.named_parameters(recurse=False), + mod.named_buffers(recurse=False), + ) + ): + with torch._dispatch.python.no_python_dispatcher(): + e_t = ErasedTensor(tensor, attr_name, mod) + if isinstance(tensor, torch.nn.Parameter): + e_t.requires_grad_(True) + e_t._is_param = True # type: ignore[attr-defined] + setattr(mod, attr_name, e_t) + + +def discard_traced_gm_params(mod: torch.fx.GraphModule): + with torch.utils._python_dispatch._disable_current_modes(): + for attr_name, tensor in list( + itertools.chain( + mod.named_parameters(recurse=False), mod.named_buffers(recurse=False) + ) + ): + with torch._dispatch.python.no_python_dispatcher(): + e_t = ErasedTensor(tensor, attr_name, mod) + if isinstance(tensor, torch.nn.Parameter): + e_t.requires_grad_(True) + e_t._is_param = True # type: ignore[attr-defined] + setattr(mod, attr_name, e_t) + + +def enforce_output_layout(gm: torch.fx.GraphModule): + """ + Make sure the output node's layout does not change due to compiler optimizations + by adding aten.as_strided nodes with the expected strides. + + Only used for inference so we can assume all graph outputs are model outputs. 
+ """ + *_, output_node = gm.graph.nodes + out_list = output_node.args[0] + with gm.graph.inserting_before(output_node): + for n in out_list: + if not isinstance( + n.meta["val"], torch.Tensor + ) or not torch._prims_common.is_non_overlapping_and_dense(n.meta["val"]): + continue + + # add a node to enforce eager layout + ft = n.meta["val"] + new_node = gm.graph.call_function( + prims.inductor_force_stride_order.default, (n, ft.stride()) + ) + + # can not call + # n.replace_all_uses_with(new_node) + # since it will replace the usage of n in new_node itself. + output_node.replace_input_with(n, new_node) + + gm.graph.lint() + gm.recompile() + + +def enforce_as_strided_input_layout(gm: torch.fx.GraphModule): + """ + Make sure the as_strided node's input's layout does not change due to compiler + optimizations, because the as_strided strides info depends on input tensor stride info. + """ + + as_strided_ops = [ + torch.ops.aten.as_strided.default, + torch.ops.aten.as_strided_.default, + torch.ops.aten.as_strided_scatter.default, + ] + strided_nodes = [n for n in gm.graph.nodes if n.target in as_strided_ops] + for n in strided_nodes: + with gm.graph.inserting_before(n): + # add a node to enforce eager layout + ft = n.args[0].meta["val"] + new_node = gm.graph.call_function( + prims.inductor_force_stride_order.default, (n.args[0], ft.stride()) + ) + n.replace_input_with(n.args[0], new_node) + + gm.graph.lint() + gm.recompile() + + +def convert_conv_weights_to_channels_last(gm: torch.fx.GraphModule): + """ + Convert 4d convolution weight tensor to channels last format. + + This pass is performed before freezing so the added nodes can be constant + folded by freezing. + """ + with dynamo_timed("convert_conv_weights_to_channels_last"): + convs = [n for n in gm.graph.nodes if n.target == aten.convolution.default] + for conv in convs: + weight_node = conv.args[1] + if len(weight_node.meta["val"].size()) != 4 or weight_node.meta[ + "val" + ].is_contiguous(memory_format=torch.channels_last): + # not a 4d tensor or already channels last, skip + continue + + with gm.graph.inserting_before(conv): + new_node = gm.graph.call_function( + aten.clone.default, + (weight_node,), + {"memory_format": torch.channels_last}, + ) + conv.replace_input_with(weight_node, new_node) + + enforce_as_strided_input_layout(gm) + enforce_output_layout(gm) diff --git a/lib/python3.10/site-packages/torch/_inductor/fx_utils.py b/lib/python3.10/site-packages/torch/_inductor/fx_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f1791004996204a0e2cb201c4a706f2e99408086 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/fx_utils.py @@ -0,0 +1,251 @@ +# mypy: allow-untyped-defs +import operator +from collections import defaultdict +from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Type + +import sympy + +import torch +import torch.fx +from torch.fx.experimental.symbolic_shapes import ( + compute_unbacked_bindings, + rebind_unbacked, + statically_known_true, + sym_eq, +) +from torch.utils import _pytree as pytree +from torch.utils._pytree import tree_map + +from .virtualized import V + + +# Check the pattern: (nn.module, F.function/torch.Tensor.method) matched. +# Works for length 2 patterns with 1 module and 1 function/method. 
+def matches_module_function_pattern( + pattern: Tuple[Type[torch.nn.modules.Module], Callable[..., Any]], + node: torch.fx.node.Node, + modules: Dict[str, torch.nn.modules.Module], +) -> bool: + if len(node.args) == 0: + return False + if not isinstance(node.args[0], torch.fx.Node) or not isinstance( + node, torch.fx.Node + ): + return False + # the first node is call_module + if node.args[0].op != "call_module": + return False + if not isinstance(node.args[0].target, str): + return False + if node.args[0].target not in modules: + return False + if type(modules[node.args[0].target]) is not pattern[0]: + return False + # the second node is call_function or call_method + if node.op != "call_function" and node.op != "call_method": + return False + if node.target != pattern[1]: + return False + # make sure node.args[0] output is only used by current node. + if len(node.args[0].users) > 1: + return False + return True + + +class FakeTensorUpdater: + """ + The main idea here is that it's difficult to maintain accurate fake + tensors (our primary form of metadata) for each node in our graph as we + transform it. + + The most reliable way to obtain this information is by rerunning + faketensor propagation. However, in general, faketensor propagation is + fairly expensive. So, instead we'd like to only rerun faketensor + propagation on nodes that have changed. + + In order to detect which nodes have changed, we first hash its node, + target, and argument lists (which are immutable in FX). + + Then, whenever we call incremental_update, we check which FX nodes have a + new hash, and recompute the faketensor metadata for that node. Then, we + continue to recursively compute the faketensors for all users until the + fake tensors stop changing. + """ + + def __init__(self, graph: torch.fx.Graph) -> None: + self.processed_hashes = set() + self.graph = graph + + for node in self.graph.nodes: + self.processed_hashes.add(self.hash_node(node)) + + def hash_node(self, node: torch.fx.Node): + # todo(chilli): Not a great hash function + return (node, node.target, id(node.args), id(node.kwargs)) + + def incremental_update(self): + processed = set() + existing_storages: DefaultDict[Optional[int], int] = defaultdict(int) + for node in self.graph.nodes: + existing_storages[get_node_storage(node)] += 1 + + def is_intlist_same(new, old): + return statically_known_true(sym_eq(new, old)) + + def is_fake_tensor_same(new, old): + if type(new) != type(old): + return False + if isinstance(new, (list, tuple)): + if len(new) != len(old): + return False + return all( + is_fake_tensor_same(new_i, old_i) for new_i, old_i in zip(new, old) + ) + if new is None: + return old is None + if not isinstance(new, torch.Tensor): + assert isinstance( + new, (torch.SymInt, torch.SymBool, torch.SymFloat) + ), f"Unknown type {type(new)} in {self.graph}" + return ( + new.node.shape_env._maybe_evaluate_static( + sympy.Eq(new.node.expr, old.node.expr) + ) + == sympy.true + ) + if not is_intlist_same(new.shape, old.shape) or new.layout != old.layout: + return False + if new.layout == torch.strided and ( + not is_intlist_same(new.stride(), old.stride()) + or not statically_known_true( + new.storage_offset() == old.storage_offset() + ) + ): + return False + + if new.device != old.device: + return False + + if get_storage(new) == get_storage(old): + return True + + # This is the case where it returns a completely fresh storage that's used nowhere else. 
+ if ( + existing_storages[get_storage(old)] == 1 + and get_storage(new) not in existing_storages + ): + return True + return False + + def should_process_node(node): + # node.target for nodes returning true from this function + # are called under fake mode and does not work for inductor + # lowerings. We check if the node.target is an aten operator + # or operator.getitem which is used when returning multiple + # tensors from an op. + return node.op == "call_function" and ( + isinstance(node.target, torch._ops.OpOverload) + or node.target == operator.getitem + ) + + to_process = set() + for node in self.graph.nodes: + if ( + self.hash_node(node) in self.processed_hashes + and id(node) not in to_process + ): + continue + + if not should_process_node(node): + continue + + is_valid, args, kwargs = get_fake_args_kwargs(node) + if not is_valid: + continue + with V.fake_mode: + new_fake_tensor = node.target(*args, **kwargs) + if "val" in node.meta and is_fake_tensor_same( + new_fake_tensor, node.meta["val"] + ): + continue + + rebind_unbacked(V.fake_mode.shape_env, node, new_fake_tensor) + + node.meta["val"] = new_fake_tensor + if (shape_env := V.fake_mode.shape_env) and ( + symbol_to_path := compute_unbacked_bindings(shape_env, new_fake_tensor) + ): + # Refresh the bindings to the new symbols + node.meta["unbacked_bindings"] = symbol_to_path + + existing_storages[get_node_storage(node)] += 1 + + to_process.update([id(user) for user in node.users]) + + self.processed_hashes.add(self.hash_node(node)) + + +def get_storage(t: torch.Tensor) -> int: + return t.untyped_storage()._cdata + + +def get_node_storage(node: torch.fx.Node) -> Optional[int]: + if "val" not in node.meta: + return None + if not isinstance(node.meta["val"], torch.Tensor): + return None + if not torch._C._has_storage(node.meta["val"]): + return None + return get_storage(node.meta["val"]) + + +def get_fake(x): + if isinstance(x, torch.fx.Node): + if "val" not in x.meta: + return x + return x.meta["val"] + return x + + +def get_fake_args_kwargs(x: torch.fx.Node) -> Tuple[bool, Tuple[Any], Dict[str, Any]]: + """ + First value returns a boolean if any of the input nodes don't have a faketensor. + """ + args, kwargs = tree_map(get_fake, (x.args, x.kwargs)) + if any( + isinstance(a, torch.fx.Node) for a in pytree.arg_tree_leaves(*args, **kwargs) + ): + return False, args, kwargs + return True, args, kwargs + + +def is_node_realized(node: torch.fx.Node) -> bool: + """Returns true if a node is always realized when lowered to inductor IR. + + NOTE: This may return some false negatives. e.g. it doesn't + handle buffers realized heuristically during lowering, or + buffers realized indirectly through view ops. + """ + from torch._inductor.lowering import fallbacks, needs_realized_inputs + + def is_buffer(node: torch.fx.Node) -> bool: + if node.op == "call_function" and node.target is operator.getitem: + # For nodes with multiple outputs, we get the fx graph: + # foo = torch.ops.aten.foo(...) 
+ # getitem = foo[0] + # getitem_1 = foo[1] + # where we need to check if foo is a fallback kernel + return is_buffer(node.args[0]) # type: ignore[arg-type] + return node.op in ("placeholder", "output") or node.target in fallbacks + + if is_buffer(node): + return True + + def realizes_inputs(node: torch.fx.Node) -> bool: + return node.op == "output" or node.target in needs_realized_inputs + + if any(realizes_inputs(user) for user in node.users): + return True + + # Otherwise, assume node isn't realized + return False diff --git a/lib/python3.10/site-packages/torch/_inductor/graph.py b/lib/python3.10/site-packages/torch/_inductor/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..3448ea6eb9620ad0094a7c86aa35244574f1ff2b --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/graph.py @@ -0,0 +1,1930 @@ +import functools +import itertools +import logging +import operator +import os +import re +import sys +import time +from collections import defaultdict +from contextlib import contextmanager +from types import ModuleType +from typing import ( + Any, + Callable, + DefaultDict, + Dict, + Iterable, + List, + NoReturn, + Optional, + Sequence, + Tuple, + TYPE_CHECKING, + Union, +) + +import sympy +from sympy import Expr + +import torch +import torch._logging +import torch.fx +from torch import device, Tensor +from torch._decomp import get_decompositions +from torch._dynamo.utils import defake, dynamo_timed +from torch._logging import LazyString, trace_structured +from torch._prims_common import make_channels_last_strides_for +from torch._subclasses.fake_tensor import FakeTensor +from torch.fx import GraphModule +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.sym_node import magic_methods, method_to_operator +from torch.fx.experimental.symbolic_shapes import ( + free_unbacked_symbols, + has_free_symbols, + resolve_unbacked_bindings, + RuntimeAssert, + ShapeEnv, + SymTypes, +) +from torch.fx.graph import Graph +from torch.fx.node import Node +from torch.utils._mode_utils import no_dispatch +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.numbers import int_oo + +from . 
import config, ir +from .codegen.common import ( + BackendFeature, + DeviceOpOverrides, + get_backend_features, + get_device_op_overrides, + get_wrapper_codegen_for_device, + init_backend_registration, +) +from .exc import ( + CppWrapperCodeGenError, + LoweringException, + MissingOperatorWithDecomp, + MissingOperatorWithoutDecomp, +) +from .ir import ( + Constant, + FixedLayout, + get_device_type, + InputBuffer, + Pointwise, + Reduction, + StorageBox, + TensorBox, + TorchBindObject, +) +from .lowering import ( + FALLBACK_ALLOW_LIST, + fallback_handler, + fallback_node_due_to_unsupported_type, + lowerings, + make_fallback, + maybe_layout_constraints, + needs_realized_inputs, + unsupported_output_tensor, +) +from .scheduler import BaseSchedulerNode +from .sizevars import SizeVarAllocator +from .utils import ( + convert_shape_to_inductor, + gather_origins, + get_cloned_parameter_buffer_name, + get_sympy_Expr_dtype, + maybe_get_suppress_shape_guards_ctx, + should_assume_input_aligned, +) +from .virtualized import NullHandler, V + + +if TYPE_CHECKING: + from torch._higher_order_ops.effects import _EffectType + from .codegen.wrapper import WrapperCodeGen + +from torch._inductor.codecache import output_code_log + + +log = logging.getLogger(__name__) +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") + +aten = torch.ops.aten + +_post_grad_graph_counter = itertools.count() + +if config.is_fbcode(): + from torch._inductor.fb.utils import log_module_code +else: + + def log_module_code(*args: Any, **kwargs: Any) -> None: + pass + + +def supported_dtype_of_cpp_wrapper(dtype: torch.device, cuda: bool) -> bool: + supported_dtype = { + torch.float32, + torch.float64, + torch.int64, + torch.int32, + torch.int16, + torch.int8, + torch.uint8, + torch.bool, + torch.bfloat16, + torch.complex32, + torch.complex64, + torch.complex128, + torch.float16, + } + if cuda: + supported_dtype.add(torch.float8_e4m3fn) + supported_dtype.add(torch.float8_e5m2) + supported_dtype.add(torch.float8_e4m3fnuz) + supported_dtype.add(torch.float8_e5m2fnuz) + + return dtype in supported_dtype + + +def may_get_constant_buffer_dtype(constant_buffer: sympy.Expr) -> Optional[torch.dtype]: + assert isinstance( + constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer) + ), "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer" + if isinstance(constant_buffer, sympy.core.numbers.Integer): + return torch.int64 + + if isinstance(constant_buffer, sympy.Expr): + return get_sympy_Expr_dtype(constant_buffer) + + if constant_buffer.is_integer: + return torch.int64 + elif constant_buffer.is_float: + return torch.float32 + else: + return None + + +def is_magic_method(op: Any) -> bool: + magic_ops = {method_to_operator(m) for m in magic_methods} + return op in magic_ops + + +def getattr_recursive( + obj: GraphModule, target: str +) -> Union[Tensor, torch._C.ScriptObject, GraphModule]: + target_atoms = target.split(".") + attr_itr = obj + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +def mark_nodes_dislike_padding( + g: Graph, user_visible_outputs: Optional[Dict[str, None]] +) -> None: + """ + Nodes like convolution/convolution_backward want its input to be dense. + If we pad their inputs, we result in extra calls to copy kernels! On the other hand, padding usually helps reduction. 
+ + The pass finds nodes that dislike padding. These are nodes that can be reached + from a convolution/convolution_backward in the backward direction without + going thru a reduction. + """ + if not config.comprehensive_padding: + return + ops_dislike_padding = { + aten.convolution, + aten.convolution_backward, + } + # what's a better way to collect the reduction ops? + ops_like_padding = { + aten.var_mean, + aten.sum, + aten.mean, + aten.prod, + aten.any, + aten.amin, + aten.amax, + aten.min, + aten.max, + aten.argmin, + aten.argmax, + aten.scatter_reduce, + } + + def _get_overload_packet( + node: torch.fx.Node, + ) -> Optional[torch._ops.OpOverloadPacket]: + return ( + node.target._overloadpacket + if node.op == "call_function" + # hasattr on OpOverloadPacket is slow, do isinstance first + and isinstance(node.target, torch._ops.OpOverload) + and hasattr(node.target, "_overloadpacket") + else None + ) + + for cur in reversed(g.nodes): + op = _get_overload_packet(cur) + if not op: + continue + if op in ops_dislike_padding: + cur.meta["dislike_padding"] = True + + if cur.meta.get("dislike_padding", False): + # propagate + for prior in cur.all_input_nodes: + prior_op = _get_overload_packet(prior) + if not prior_op: + continue + if prior_op not in ops_like_padding: + prior.meta["dislike_padding"] = True + # We only want to mark output nodes. So, move it after the above prior nodes process. + if ( + not config.pad_outputs + and user_visible_outputs + and cur.name in user_visible_outputs + ): + cur.meta["dislike_padding"] = True + + +class GraphLowering(torch.fx.Interpreter): + graph_outputs: List[ir.IRNode] + + def symbolic_sizes_strides( + self, ex: torch.Tensor + ) -> Tuple[Union[List[int], List[Expr]], Union[List[int], List[Expr]]]: + """ + Support dynamic shapes and dynamic strides by assigning variables + to each dimension. We duck-shape tensors, so if two tensors + have the same size they get assigned the same symbolic variable. + """ + if self.reuse_shape_env: + return convert_shape_to_inductor(ex.size()), convert_shape_to_inductor( + ex.stride() + ) + else: + from torch._dynamo.source import ConstantSource + + # TODO: this should not be needed once #93059 lands + # https://github.com/pytorch/pytorch/pull/94031#discussion_r1096044816 + # TODO: make a dedicated UnknownSource for this? 
+ # NB: This is using the legacy default behavior from + # create_symbolic_sizes_strides_storage_offset but we hope we can + # just delete this entirely + source = ConstantSource( + f"__inductor_unknown_tensor_{len(self._shape_env.var_to_val)}" + ) + ( + size, + stride, + _, + ) = self._shape_env.create_symbolic_sizes_strides_storage_offset( + ex, + source, + ) + + size = [i.node.expr if isinstance(i, torch.SymInt) else i for i in size] + stride = [i.node.expr if isinstance(i, torch.SymInt) else i for i in stride] + return size, stride + + def static_sizes_strides( + self, ex: torch.Tensor + ) -> Tuple[List[sympy.Expr], List[sympy.Expr]]: + """ + Primarily used to weights + """ + size = [sympy.Integer(i) for i in ex.size()] + stride = [sympy.Integer(i) for i in ex.stride()] + return size, stride + + def __init__( + self, + gm: torch.fx.GraphModule, + example_inputs: Optional[List[torch.Tensor]] = None, + shape_env: Optional[ShapeEnv] = None, + graph_id: Optional[int] = None, + cpp_wrapper: bool = False, + aot_mode: bool = False, + user_visible_outputs: Optional[Dict[str, None]] = None, + layout_opt: Optional[bool] = None, + extern_node_serializer: Optional[ + Callable[[List[ir.ExternKernelNode]], Any] + ] = None, + is_inference: bool = False, + is_const_graph: bool = False, + const_output_index: Optional[Dict[str, int]] = None, + const_code: Optional[str] = None, + const_module: Optional["GraphLowering"] = None, + name: Optional[str] = None, + ) -> None: + super().__init__(gm) + self.example_inputs = example_inputs + self.layout_opt = ( + layout_opt + if layout_opt is not None + else self.decide_layout_opt(gm, is_inference=is_inference) + ) + self.num_channels_last_conv = 0 + self.is_inference = is_inference + self.is_const_graph = is_const_graph + self.const_code = const_code + self.const_module = const_module + + self.extra_traceback = False # we do our own error wrapping + if shape_env is None: + shape_env = ShapeEnv() + self.reuse_shape_env = False + else: + self._shape_env = shape_env + self.reuse_shape_env = True + self._shape_env = shape_env + # We are going to start code generating runtime asserts, so make sure + # you don't start adding new ones in the lowering process + shape_env.freeze_runtime_asserts() + # We're going to mutate ras_by_symbol as we finish generating them + self.ras_by_symbol: Dict[ + sympy.Symbol, List[RuntimeAssert] + ] = shape_env.deferred_runtime_asserts.copy() + self.bound_unbacked_symbols: OrderedSet[sympy.Symbol] = OrderedSet() + self.sizevars = SizeVarAllocator(shape_env) + self.graph_input_names: List[str] = [] + self.graph_inputs: Dict[str, TensorBox] = {} + self.graph_inputs_original: Dict[str, InputBuffer] = {} + self.zero_dim_cpu_tensor_list: OrderedSet[str] = OrderedSet() + self.device_types: OrderedSet[str] = ( + const_module.device_types if const_module else OrderedSet() + ) + self.device_idxs: OrderedSet[int] = ( + const_module.device_idxs if const_module else OrderedSet() + ) + self.cuda = False + self.buffers: List[ir.Buffer] = [] + self.operations: List[ir.Operation] = [] + self.const_output_index: Dict[str, int] = ( + const_output_index if const_output_index else {} + ) + self.folded_constants: OrderedSet[str] = ( + OrderedSet(const_output_index.keys()) + if const_output_index + else OrderedSet() + ) + self.constants: Dict[str, torch.Tensor] = ( + const_module.constants if const_module else {} + ) + self.torchbind_constants: Dict[str, torch._C.ScriptObject] = {} + self.constant_reprs: Dict[str, str] = {} + self.removed_operations: 
OrderedSet[str] = OrderedSet() + self.removed_buffers: OrderedSet[str] = OrderedSet() + self.removed_inplace_buffers: OrderedSet[str] = OrderedSet() + self.mutated_buffers: OrderedSet[str] = OrderedSet() + self.never_reuse_buffers: OrderedSet[str] = OrderedSet() + self.inplaced_to_remove: OrderedSet[str] = OrderedSet() + self.device_ops: DeviceOpOverrides = None # type: ignore[assignment] + self.wrapper_code: WrapperCodeGen = None # type: ignore[assignment] + # See `ProxyExecutor Design Note` in ir.py for more details + self.extern_kernel_nodes: List[ir.ExternKernelNode] = [] + + from torch._inductor.extern_node_serializer import extern_node_json_serializer + + self.extern_node_serializer: Callable[[List[ir.ExternKernelNode]], Any] = ( + extern_node_serializer + if config.is_fbcode() and extern_node_serializer + else extern_node_json_serializer + ) + + self.current_node: torch.fx.Node = None # type: ignore[assignment] + self.lists: Dict[str, List[str]] = {} + self.mutated_inputs: OrderedSet[str] = OrderedSet() + self.mutated_input_idxs: List[int] = [] + self.name_to_buffer: Dict[str, ir.Buffer] = {} + self.name_to_users: DefaultDict[str, List[ir.IRNode]] = defaultdict(list) + self.name_to_op: Dict[str, ir.Operation] = {} + self.creation_time = time.time() + self.name = name # type: ignore[assignment] + self.cpp_wrapper = cpp_wrapper + + # record multi_kernel choice for cpp_wrapper so the second pass knows + # which sub-kernel is picked. Copy cpp_wrapper to another variable + # since cpp_wrapper flag is OrderedSet to false for the first pass of codegen. + self.record_multi_kernel_choice = cpp_wrapper + self.multi_kernel_to_choice: Dict[str, int] = {} + + self.aot_mode = aot_mode + self.graph_id = graph_id + self.post_grad_graph_id = next(_post_grad_graph_counter) + self.scheduler: torch._inductor.scheduler.Scheduler = None # type: ignore[assignment] + self.nodes_prefer_channels_last = ( + self.find_nodes_prefer_channels_last() if self.layout_opt else OrderedSet() + ) + self._warned_fallback = {"aten.convolution_backward"} + self.user_visible_outputs = ( + user_visible_outputs if user_visible_outputs is not None else {} + ) + mark_nodes_dislike_padding(gm.graph, user_visible_outputs) + self.cache_key: str = "" # This is the cache key for the compiled artifact + self.cache_path: str = "" # This is the path in the filesystem where the compiled artifact is stored + self.cache_linemap: List[ + Tuple[int, str] + ] = ( + [] + ) # This is the linemap used by the profiler to mark custom compiled kernels getting run + # Used if lowering encounters cases where cudagraphs are not supported + self.disable_cudagraphs_reason: Optional[str] = None + + # only keeping one node per device for stack trace purposes + self.device_node_mapping: Dict[torch.device, torch.fx.Node] = {} + self.orig_gm: torch.fx.GraphModule = gm.__copy__() + self.dynamo_flat_name_to_original_fqn = self.module.meta.get( + "dynamo_flat_name_to_original_fqn", {} + ) + self.allocated_constant_name: Dict[str, str] = ( + const_module.allocated_constant_name if const_module is not None else {} + ) + init_backend_registration() + self.get_backend_features = functools.lru_cache(None)(get_backend_features) + + self.effectful_ops: Dict[_EffectType, ir.Buffer] = {} + self.aligned_inputs: OrderedSet[str] = OrderedSet() + self.no_fuse_buffer_names: OrderedSet[str] = OrderedSet() + + # Below field is related to printing debug intermediate tensor values info for debugging + self.all_codegen_kernel_names: OrderedSet[str] = OrderedSet() + + def 
has_feature( + self, device: Union[torch._inductor.ir.IRNode, device], feature: BackendFeature + ) -> bool: + assert isinstance(feature, BackendFeature), feature + return feature in self.get_backend_features(get_device_type(device)) + + @staticmethod + def decide_layout_opt(gm: GraphModule, *, is_inference: bool) -> bool: + """ + Decide if we should enable layout optimization for this graph based on + heuristics. + """ + if not config.layout_optimization: + return False + + if config.force_layout_optimization: + return True + + conv_nodes = [ + n for n in gm.graph.nodes if n.target == torch.ops.aten.convolution.default + ] + nconv = len(conv_nodes) + + if nconv == 0: + return False + + # For cpu backend and mkldnn enabled, we always use channels_last for better performance. + if ( + torch.backends.mkldnn.enabled + and torch.backends.mkldnn.is_available() + and all( + n.args[idx].meta["val"].device == torch.device("cpu") + for n in conv_nodes + for idx in [0, 1] + ) + ): + return True + + # Following models are skipped due to this: + # jx_nest_base + # volo_d1_224 + if len(list(gm.graph.nodes)) >= 300 * nconv: + log.debug("Skipped layout opt because only a few conv") + return False + + if any( + has_free_symbols(n.args[idx].meta["val"]) + for n in conv_nodes + for idx in [0, 1] + ): + log.debug( + "See perf regression with dynamic shape. Follow up in https://github.com/pytorch/pytorch/issues/102670" + ) + return False + + def is_grouped(n: Any) -> bool: + meta_val = n.args[1].meta["val"] # type: ignore[union-attr, operator] + assert isinstance(meta_val, torch.Tensor) + return n.args[-1] > 1 and meta_val.size(1) > 1 # type: ignore[union-attr, operator] + + def is_in_out_channel(n: torch.fx.Node) -> bool: + return ( + n.args[1].meta["val"].size(0) * 2 <= n.args[1].meta["val"].size(1) # type: ignore[union-attr, operator] + and n.args[1].meta["val"].size(2) > 1 # type: ignore[union-attr, operator] + ) + + def is_small_channel(n: torch.fx.Node) -> bool: + return ( + n.args[1].meta["val"].size(0) <= 64 # type: ignore[union-attr, operator] + and n.args[1].meta["val"].size(1) <= 64 # type: ignore[union-attr, operator] + ) + + # only grouped convolutions benchmarked as slower in conv samples for inference only + if is_inference: + from torch.utils.flop_counter import FlopCounterMode + + flop_counts: Dict[str, float] = defaultdict(float) + for node in conv_nodes: + success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs( + node + ) + + if success: + with FlopCounterMode(display=False) as flop_counter_mode: + with V.fake_mode: + node.target(*args, **kwargs) + + counted_flops = flop_counter_mode.get_total_flops() + if is_grouped(node): + node_type = "grouped" + elif is_small_channel(node): + node_type = "small" + elif is_in_out_channel(node): + node_type = "in_out" + else: + node_type = "default" + + flop_counts[node_type] += counted_flops + else: + log.debug("Conv inputs meta not found") + + # average benchmarked channels last speedup / slowdown, < 1 is speedup. 
+ # taken from the set of convolution inputs in benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/ + # To regenerate these numbers follow https://gist.github.com/eellison/55d7a6ed6f39829d68ac56f95f4df5bb + GROUPED_MULTIPLIER = 1.358 + DEFAULT_MULTIPLIER = 0.823 + IN_OUT_MULTIPLIER = 0.725 + SMALL_MULTIPLIER = 0.783 + + total_flops = sum(flop_counts.values()) + # TODO - get different values per hardware + weighted_flops = ( + flop_counts["grouped"] * GROUPED_MULTIPLIER + + flop_counts["small"] * SMALL_MULTIPLIER + + flop_counts["in_out"] * IN_OUT_MULTIPLIER + + flop_counts["default"] * DEFAULT_MULTIPLIER + ) + do_layout_opt = weighted_flops <= total_flops + if not do_layout_opt: + log.debug( + "Skipped layout opt in inference because weighted flops indicate slowdown, default: %d, channels last: %d", + total_flops, + weighted_flops, + ) + return do_layout_opt + + # Channels last layout can dramatically hurt grouped conv perf. E.g. + # Conv with arguments like + # {"input_shape": [32, 224, 112, 112], "weight_shape": [224, 112, 3, 3], + # "stride": [2, 2], "padding": [1, 1], "groups": 2} + # slows down 31x using channels last.. + + # But a lot of timm models use depthwise separable convolution which will + # result in grouped convolution with in-channel size == 1. + # For those grouped convolution, channels last still helps a lot. + # E.g. + # Conv with arguments + # {"input_shape": [128, 58, 56, 56], "weight_shape": [58, 1, 3, 3], + # "stride": [2, 2], "padding": [1, 1], "groups": 58} + # get 1.86x speedup with channels last layout. + # + # The following heuristics skip using channels-last if the model contains + # grouped convolution with in-channels > 1. + if any(map(is_grouped, conv_nodes)): + log.debug( + "Skip layout opt because found grouped convolution with >1 in_channels!" + ) + return False + + # For some models that contain convolution with larger in-channel than out-channel, applying + # channels last hurts performance. + # Following models are skipped due to this: + # - pytorch_unet + # - phlippe_densenet (slightly worse) + # - Background_Matting (1.22x -> 0.821x) + # - pytorch_CycleGAN_and_pix2pix (1.597x -> 1.294x) + if any(map(is_in_out_channel, conv_nodes)): + log.debug( + "Skip layout opt because some convolutions have smaller out_channel" + ) + return False + + # Following models are skipped due to this: + # - functorch_maml_omniglot + if all(map(is_small_channel, conv_nodes)): + log.debug("Skip layout opt because all convolution channels are too small") + return False + + return True + + def qualify_name(self, name: str) -> str: + """Prepend the given name with the graph name if any.""" + if self.name is not None: + return f"{self.name}_{name}" + return name + + def make_subgraph( + self, + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + subgraph_name: str, + ) -> "GraphLowering": + """ + Make a subgraph of the current graph with all inherited + parts, except the graph module (`gm`) and `example_inputs`. + The subgraphs are lowered separately, but intended to be + inlined in the parent graph's codegening. Hence the need + for maintaining the same `shape_env` and other properties. + The subgraph name is qualified by the parent graph's name. 
+ """ + return GraphLowering( + gm=gm, + example_inputs=example_inputs, + shape_env=self._shape_env, + cpp_wrapper=self.cpp_wrapper, + aot_mode=self.aot_mode, + extern_node_serializer=self.extern_node_serializer, + is_inference=self.is_inference, + name=self.qualify_name(subgraph_name), + ) + + def find_nodes_prefer_channels_last(self) -> OrderedSet[Node]: + """ + The rule to decide if an node prefer channels last is simple. + 1. if it's input/output of a convolution + 2. if one of its user prefers channels last + + We have rule 1 because cudnn runs a faster convolution kernel for channels last inputs; + Rule 2 is also important. It makes sure that indirect inputs to convolution also prefers + channels last. + + Consider the scenario: conv -> batch-norm -> relu -> conv + Without rule 2, batch-norm output may use a contiguous layout. That will cause 2 extra copies: + 1. the output of batch-norm should be channels last initially since its input is a conv's output. + Forcing the batch-norm's output to be contiguous results in the first copy + 2. The second conv's input is initially contiguous. This layout is propagated from the batch-norm's output. + We need convert it to channels last layout which results in the second copy. + With rule 2, we makes sure all the tensors in the chain uses channels last layout. So both copies + can be saved. + """ + output_set: OrderedSet[Node] = OrderedSet() + for n in reversed(self.module.graph.nodes): + if n.target == torch.ops.aten.convolution.default: + output_set.add(n) + continue + + for user in n.users: + if user in output_set: + output_set.add(n) + break + + # need a second pass to add downstream nodes of those channel last nodes to the sets. + # This pass is especially needed to avoid mix-layout kernel inputs in backward pass. + # + # Let's say a conv-batchnorm 's output is passed to relu whose output is in turn returned + # from the fwd graph. Without this second pass, we will force relu's output to be contiguous. + # Then in the kernel in backward pass, the contiguous output of relu may be mix with other channels last + # tensors and passed to a kernel. + # + # This pass improve yolov3 training speedup from 1.116x (worse than disabling layout optimization speedup 1.196x) to 1.457x. + # It also improves dla102 training speedup from 1.240x (worse than disabling layout optimization speedup 1.523x) to 1.835x . 
+ # This also helps the following models: + # - res2net101_26w_4s + # - res2net50_14w_8s + # - sebotnet33ts_256 + for n in self.module.graph.nodes: + if n in output_set: + output_set.update(n.users) + + return output_set + + def warn_fallback(self, name: str) -> None: + if name not in self._warned_fallback: + self._warned_fallback.add(name) + perf_hint_log.info("Using FallbackKernel: %s", name) + + def add_device_info(self, device: torch.device) -> None: + self.device_types.add(device.type) + if device.index is not None: + self.device_idxs.add(device.index) + if V.graph.current_node and device not in self.device_node_mapping: + self.device_node_mapping[device] = V.graph.current_node + + @property + def fake_mode(self) -> torch._subclasses.fake_tensor.FakeTensorMode: + return V.fake_mode + + def try_get_buffer( + self, buffer_name: str + ) -> Optional[Union[ir.TensorBox, ir.Buffer]]: + if buffer_name in self.name_to_buffer: + return self.name_to_buffer[buffer_name] + if buffer_name in self.graph_inputs: + return self.graph_inputs[buffer_name] + if buffer_name in self.constants: + data = V.graph.constants[buffer_name] + return ir.ConstantBuffer( + buffer_name, + ir.FixedLayout( + data.device, data.dtype, *V.graph.static_sizes_strides(data) + ), + ) + + return None + + def get_buffer(self, buffer_name: str) -> Union[ir.TensorBox, ir.Buffer]: + buf = self.try_get_buffer(buffer_name) + if buf is not None: + return buf + raise RuntimeError(f"Failed to find buffer matching name {buffer_name}") + + def get_dtype(self, buffer_name: str) -> torch.dtype: + if buffer_name in self.constants: + return self.constants[buffer_name].dtype + if buffer_name in self.name_to_buffer: + return self.name_to_buffer[buffer_name].get_dtype() + if buffer_name in self.graph_inputs: + return self.graph_inputs[buffer_name].get_dtype() + m = re.match(r"(as_strided|reinterpret_tensor)\(([a-zA-Z0-9_]+),", buffer_name) + if m: + return self.get_dtype(m.group(1)) + raise KeyError(f"could not find {buffer_name}") + + def get_numel(self, buffer_name: str) -> Union[int, Expr]: + from .ir import MultiOutputLayout + + if buffer_name in self.constants: + return self.constants[buffer_name].numel() + if buffer_name in self.name_to_buffer: + buf = self.name_to_buffer[buffer_name] + if isinstance(getattr(buf, "layout", None), MultiOutputLayout): + return 1 + return buf.get_numel() + if buffer_name in self.graph_inputs: + return self.graph_inputs[buffer_name].get_numel() + raise KeyError(f"could not find {buffer_name}") + + def run(self, *args: Any) -> Any: # type: ignore[override] + with dynamo_timed("GraphLowering.run"): + return super().run(*args) + + def register_operation(self, op: ir.Operation) -> str: + assert op.operation_name is None, f"Operation registered twice: {op}" + assert isinstance(op, ir.Operation) + name = self.qualify_name(f"op{len(self.operations)}") + self.operations.append(op) + self.name_to_op[name] = op + op.operation_name = name + return name + + def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False) -> str: + name = self.qualify_name(f"buf{len(self.buffers)}") + self.buffers.append(buffer) + self.name_to_buffer[name] = buffer + if ( + # Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144 + not (isinstance(buffer, ir.ComputedBuffer) and buffer.is_zero_elements()) + and buffer.get_device() is not None + ): + self.add_device_info(buffer.get_device()) + + if set_name: + buffer.name = name + return name + + def register_operation_list(self, 
operation_names: List[str]) -> str: + name = self.qualify_name("list_" + "_".join(operation_names)) + self.lists[name] = operation_names + return name + + def register_users_of( + self, node_output: Union[Iterable[ir.IRNode], ir.IRNode] + ) -> None: + def register(value: Union[Iterable[ir.IRNode], ir.IRNode]) -> None: + if isinstance(value, (list, tuple)): + for x in value: + register(x) + if isinstance(value, ir.TensorBox): + for read_name in value.get_read_names(): + self.name_to_users[read_name].append(value) + + register(node_output) + + def mark_buffer_mutated(self, name: str) -> None: + """ + When a buffer is mutated we need to make sure all the reads to + the old version are realized before the mutation happens. + """ + assert isinstance(name, str) + self.mutated_buffers.add(name) + + if name not in self.name_to_users: + return + + for user in self.name_to_users[name]: + user.realize() + + def get_original_value_of_constant(self, name: str) -> torch.Tensor: + """ + In AOTI, module buffers may have been mutated during the tracing and compilation. + Thus we need to read from previously stored original buffers, to make sure the + generated model.so uses correct initial values. + """ + assert name in self.allocated_constant_name and name in self.constants, ( + "Can not find the original value for " + name + ) + orig_name = get_cloned_parameter_buffer_name(self.allocated_constant_name[name]) + return ( + self.module.meta[orig_name] + if orig_name in self.module.meta + else self.constants[name] + ) + + def allocate_non_dup_const_name( + self, name: Optional[str], data: Union[Tensor] + ) -> str: + orig_name = name + if not config.aot_inductor.use_runtime_constant_folding: + for constant_name, value in self.constants.items(): + if ( + not data.is_mkldnn + and data.size() == value.size() + and data.stride() == value.stride() + and data.dtype == value.dtype + and data.device == value.device + and data.untyped_storage().data_ptr() + == value.untyped_storage().data_ptr() + and data.storage_offset() == value.storage_offset() + ): + return constant_name + + if name is None: + name = f"constant{len(self.constants)}" + assert name is not None + if name[0].isdigit(): + name = f"constant_{name}" + name = self.qualify_name(name) + # We may generate a var name for each constant in the codegen. + # Let's only keep sane characters. + prefix = re.sub(r"[^a-zA-Z0-9_]", "_", name) + name = prefix + cnt = 0 + while name in self.constants: + name = f"{prefix}_{cnt}" + cnt += 1 + self.constants[name] = data + self.constant_reprs[name] = ( + f"{data.device!r} {data.dtype!r} " + f"{tuple(data.size())!r} {tuple(data.stride())!r} " + f"{hash(data):x}" + ) + self.allocated_constant_name[name] = orig_name # type: ignore[assignment] + return name + + def add_tensor_constant( + self, data: Tensor, name: Optional[str] = None + ) -> TensorBox: + new_name = self.allocate_non_dup_const_name(name, data) + return TensorBox.create( + ir.ConstantBuffer( + new_name, + FixedLayout(data.device, data.dtype, *self.static_sizes_strides(data)), + ) + ) + + def constant_name(self, name: str, device_override: Optional[torch.device]) -> str: + """ + We AOT copy constants to the devices they are needed on. + If device_override doesn't match the constant's device, then + copy it and return a different name. 
+ """ + if self.constants[name].device == device_override or device_override is None: + return name + with torch.utils._python_dispatch._disable_current_modes(): + # caller might have OrderedSet fake tensor mode which will create a fake tensor + # when calling .to, so unset modes here + return self.allocate_non_dup_const_name( + f"{name}_{device_override.type}{device_override.index or 0}", + self.constants[name].to(device_override), + ) + + def placeholder( + self, target: str, args: Tuple[object], kwargs: Dict[str, object] # type: ignore[override] + ) -> Union[Expr, TensorBox, None]: + example = super().placeholder(target, args, kwargs) # type: ignore[arg-type] + self.graph_input_names.append(target) + if isinstance(example, SymTypes): + expr = example.node.expr + self.graph_inputs[target] = expr + return expr + elif isinstance(example, (int, bool, float)): + expr = sympy.sympify(example) + self.graph_inputs[target] = expr + return expr + elif example is None: + return None + if isinstance(example, BackwardState): + # Ignored arg, must be unused + # Alternately we could filter this out in AotAutograd + return None + assert isinstance(example, torch.Tensor), example + # todo(chilli): We can remove the last check once we turn buffers into + # static shape tensors. That's a hack to workaround Inductor believing + # the buffer should be static but us passing in a fake tensor with + # symbolic shapes. + if not example._has_symbolic_sizes_strides: + # the first N inputs are weights + sizes, strides = self.static_sizes_strides(example) + else: + sizes, strides = self.symbolic_sizes_strides(example) # type: ignore[assignment] + # TODO(jansel): handle input aliasing + target = self.qualify_name(target) + tensor = TensorBox.create( + InputBuffer( + target, + FixedLayout(example.device, example.dtype, sizes, strides), + ) + ) + self.graph_inputs[target] = tensor + self.graph_inputs_original[target] = tensor.data.data + if self.current_node.users: # cudagraphs should work with an unused CPU input + self.add_device_info(example.device) + + # Note: [Input Alignment handling in Inductor] + # Alignment matters for generating efficient code. Some operations, + # e.g. vectorized loads, can only be performed on aligned inputs. + # + # But if we codegen assuming aligned inputs and then get unaligned + # inputs at runtime, then we are forced to clone - which is bad for + # both perf and memory usage. + # + # One option would be to guard on storage_offset%ALIGNMENT, and then + # codegen based on this. But storage_offset guards turned out to be + # expensive and cause recompiles; Instead, we're generating code + # based on the alignment of the example input without guarding. 
+ with maybe_get_suppress_shape_guards_ctx(): + if should_assume_input_aligned(example): + self.aligned_inputs.add(target) + return tensor + + def call_function(self, target: Callable, args: Any, kwargs: Dict[str, Any]) -> Any: # type: ignore[type-arg, override] + if target is operator.getitem and isinstance(args[0], (list, tuple, dict)): + return super().call_function(target, args, kwargs) + + # hasattr on OpOverloadPacket is slow, check isinstance first + if not isinstance(target, torch._ops.OpOverloadPacket) and hasattr( + target, "_inductor_lowering_function" + ): + # passthrough lowerings from .pattern_matcher + return target(*args, **kwargs) + + if target not in lowerings: + assert isinstance( + target, torch._ops.OpOverload + ), f"{target} is not an OpOverload" + base_name = target.name().split(".")[0] + if base_name in FALLBACK_ALLOW_LIST: + make_fallback(target) + elif config.implicit_fallbacks: + error = ( + MissingOperatorWithDecomp + if get_decompositions([target]) + else MissingOperatorWithoutDecomp + ) + log.info( + "Creating implicit fallback for:\n%s", + error.operator_str(target, args, kwargs), + ) + make_fallback(target) + + elif get_decompositions([target]): + # There isn't a good way to dynamically patch this in + # since AOT Autograd already ran. The error message tells + # the user how to fix it. + raise MissingOperatorWithDecomp(target, args, kwargs) + else: + raise MissingOperatorWithoutDecomp(target, args, kwargs) + + try: + log.debug(" via %s", lowerings[target]) # type: ignore[index] + out = lowerings[target](*args, **kwargs) # type: ignore[index] + return out + except Exception as e: + raise LoweringException(e, target, args, kwargs).with_traceback( + e.__traceback__ + ) from None + + @staticmethod + def can_inline_constant(t: torch.Tensor) -> bool: + """ + True if this is a small constant attr that will be inlined. 
+ """ + return len(t.shape) == 1 and t.shape[0] <= 8 + + def get_attr( + self, target: str, args: Tuple[()], kwargs: Dict[str, object] # type: ignore[override] + ) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]: + # this is a constant + value = getattr_recursive(self.module, target) # type: ignore[arg-type] + + if isinstance(value, torch.fx.GraphModule): + return ir.Subgraph(name=target, graph_module=value) + + if isinstance(value, torch._C.ScriptObject): + self.torchbind_constants[target] = value + self.constant_reprs[target] = "" + return TorchBindObject(target, value) + + assert isinstance(value, torch.Tensor) + if ( + config.aot_inductor.use_runtime_constant_folding + or config.always_keep_tensor_constants + or unsupported_output_tensor(value) + ): + return self.add_tensor_constant(value, target) + + with no_dispatch(): + if value.shape == (): + return Constant(value.item(), value.dtype, value.device) + if self.can_inline_constant(value): + log.debug("Inlining constant: %s ", str(target)) + # tensor lowering has constant inlining logic + from .lowering import tensor + + return tensor(value.tolist(), dtype=value.dtype, device=value.device) + + return self.add_tensor_constant(value, target) + + def call_module(self, target: Any, args: Any, kwargs: Any) -> NoReturn: + raise AssertionError + + def call_method(self, target: Any, args: Any, kwargs: Any) -> NoReturn: + raise AssertionError + + def output( + self, target: str, args: Tuple[object], kwargs: Dict[str, object] # type: ignore[override] + ) -> None: + result = super().output(target, args, kwargs) # type: ignore[arg-type] + if not isinstance(result, (tuple, list)): + # nested subgraphs can have singleton outputs + result = (result,) + assert isinstance(result, (tuple, list)), type(result) + assert all( + isinstance( + x, + ( + TensorBox, + ir.Constant, + type(None), + ir.ConstantBuffer, + sympy.Expr, + sympy.logic.boolalg.Boolean, + int, + ir.EffectfulKernel, + ), + ) + for x in result + ), result + + fx_node_args = V.graph.current_node.args[0] # type: ignore[arg-type] + if not isinstance(fx_node_args, (tuple, list)): + # nested subgraphs can have singleton outputs + fx_node_args = (fx_node_args,) + result = [ir.ExternKernel.realize_input(x) for x in result] + result_correct_strides = [] + + assert len(fx_node_args) == len(result) + for r, fx_node in zip(result, fx_node_args): + if not isinstance(r, (ir.TensorBox, ir.BaseView)): + result_correct_strides.append(r) + else: + # AOT Autograd tries to detect stride divergence of inductor from output metadata. 
+ # Here, we try to avoid spurious divergence by matching insignificant strides such as + result_correct_strides.append( + self.try_match_insignificant_strides( + r, fx_node.meta["val"].stride() + ) + ) + + self.graph_outputs = result_correct_strides + value: ir.IRNode + for name, value in self.graph_inputs.items(): + assert isinstance( + value, (TensorBox, sympy.Expr) + ), f"Unsupported inductor graph input type: {type(value)}" + if not isinstance(value, TensorBox): + continue + value.realize() + assert isinstance(value, TensorBox) + value = value.data + assert isinstance(value, ir.StorageBox) + value_storage_box = value + value = value.data + if not isinstance(value, InputBuffer) or value.get_name() != name: + # one of our inputs was mutated, need to turn that into a copy + ir.MutationLayoutSHOULDREMOVE.realize_into( + value, self.graph_inputs_original[name] + ) + # replace output with mutated input + try: + ind = self.graph_outputs.index(value_storage_box) + self.graph_outputs[ind] = self.graph_inputs_original[name] + except ValueError: + pass + + self.finalize() + log.debug( + "Force channels last inputs for %d conv for the current graph with id %d", + self.num_channels_last_conv, + self.graph_id if self.graph_id is not None else -1, + ) + + def finalize(self) -> None: + for buf in self.buffers: + buf.decide_layout() + + @contextmanager + def set_current_node(self, node: torch.fx.Node): # type: ignore[no-untyped-def] + old = self.current_node + try: + self.current_node = node + yield + finally: + self.current_node = old + + def try_match_insignificant_strides( + self, + tensor: Union[ir.TensorBox, ir.BaseView], + meta_strides_inp: Tuple[Union[int, torch.SymInt], ...], + ) -> Union[ir.TensorBox, ir.BaseView]: + """ + Tries to match the strides of the tensor to those in the meta_strides. Strides of insignificant + dimensions - size 0 or 1 - will be updated. + + If there are real stride differences (NHWC vs NCHW) then the input will be returned. 
+ """ + + # should have already been realized + assert torch._inductor.ir.is_storage_and_layout(tensor) + + meta_strides = [ + s.node.expr if isinstance(s, torch.SymInt) else s for s in meta_strides_inp + ] + + if all( + self.sizevars.statically_known_equals(s1, s2) + for s1, s2 in zip(meta_strides, tensor.get_stride()) + ): + return tensor # type: ignore[arg-type] + + def significant_strides_equal( + shape: Sequence[Union[Expr, int]], + meta_strides: Sequence[Union[Expr, int]], + tensor_strides: Sequence[Union[Expr, int]], + ) -> bool: + for dim, s1, s2 in zip(shape, meta_strides, tensor_strides): + if self.sizevars.statically_known_leq(dim, 1): # type: ignore[arg-type] + continue + + if not self.sizevars.statically_known_equals(s1, s2): + return False + + return True + + if not significant_strides_equal( + tensor.get_size(), meta_strides, tensor.get_stride() + ): + return tensor + + storage, old_layout = torch._inductor.ir.as_storage_and_layout(tensor) + new_stride = list(old_layout.stride) + for i, s in enumerate(tensor.get_size()): + if self.sizevars.statically_known_leq(s, 1): # type: ignore[arg-type] + new_stride[i] = meta_strides[i] + + new_layout = torch._inductor.ir.FixedLayout( + old_layout.device, + old_layout.dtype, + old_layout.size, + new_stride, + old_layout.offset, + ) + return ir.TensorBox(torch._inductor.ir.ReinterpretView(storage, new_layout)) + + def propagate_mutation( + self, + fx_node: torch.fx.Node, + old_args: Tuple[Any], + old_kwargs: Dict[str, Any], + new_args: Tuple[Any], + new_kwargs: Dict[str, Any], + ) -> None: + """Propagate mutations on new_args/new_kwargs back to old_args/old_kwargs. + + Assumes we may have cloned old_args/old_kwargs into new_args/new_kwargs + and then called fx_node(*new_args, **new_kwargs). + + If fx_node mutates any of new_args/new_kwargs, and they are different from + old_args/old_kwargs, then we need to update the original tensor. + """ + assert isinstance(fx_node.target, torch._ops.OpOverload) + assert len(old_args) == len(new_args) + assert len(old_kwargs) == len(new_kwargs) + + def maybe_propagate( + schema_arg: torch._C.Argument, old_arg: ir.IRNode, new_arg: ir.IRNode + ) -> None: + if old_arg is new_arg: + return + if schema_arg.alias_info is not None and schema_arg.alias_info.is_write: + # The lowering for copy_ is smart enough to "replace" old_arg with + # new_arg in all future uses so a copy_ kernel never gets emitted. 
+ self.call_function(torch.ops.aten.copy_.default, (old_arg, new_arg), {}) + + schema = fx_node.target._schema + for idx, (old_arg, new_arg) in enumerate(zip(old_args, new_args)): + schema_arg = schema.arguments[idx] + maybe_propagate(schema_arg, old_arg, new_arg) + + schema_kwargs = {arg.name: arg for arg in schema.arguments} + + for key in old_kwargs.keys(): + old_arg = old_kwargs[key] + new_arg = new_kwargs[key] + schema_arg = schema_kwargs[key] + maybe_propagate(schema_arg, old_arg, new_arg) + + def run_node(self, n: torch.fx.Node) -> object: + def debug(msg: str) -> None: + log.debug("lowering %s %s", LazyString(n.format_node), msg) + + buffer_watermark = len(self.buffers) + operation_watermark = len(self.operations) + + origins = {n} + is_call_function = n.op == "call_function" + if is_call_function: + args, kwargs = self.fetch_args_kwargs_from_env(n) + origins |= gather_origins(args, kwargs) + with ir.IRNode.current_origins(origins), self.set_current_node( # type: ignore[arg-type] + n + ), V.set_current_node( + n + ): + if ( + n.op == "call_function" + and n.target is not operator.getitem + and fallback_node_due_to_unsupported_type(n) + ): + debug("fallback_handler") + result = fallback_handler(n.target, add_to_fallback_set=False)( + *args, **kwargs # type: ignore[possibly-undefined] + ) + elif n.op == "call_function" and ( + layout_constraints := maybe_layout_constraints(n.target) # type: ignore[arg-type] + ): + debug("layout_constraints") + old_args = args # type: ignore[possibly-undefined] + old_kwargs = kwargs # type: ignore[possibly-undefined] + args, kwargs = layout_constraints(n, *args, **kwargs) # type: ignore[index] + result = self.call_function(n.target, args, kwargs) # type: ignore[arg-type] + # layout_constraints are allowed to make new copies of the inputs. + # if they do, and if the target is mutable, then we need to + # write the new values back into the original inputs. + self.propagate_mutation(n, old_args, old_kwargs, args, kwargs) # type: ignore[possibly-undefined] + elif is_magic_method(n.target): + # TODO: this is sus, it probably should be handled in the + # lowerings themselves similarly to sym_size/sym-stride + # https://github.com/pytorch/pytorch/issues/127789 + debug("is_magic_method") + if isinstance( + n.meta["val"], (torch.SymInt, torch.SymFloat, torch.SymBool) + ): + result = n.meta["val"].node.expr + else: + result = super().run_node(n) + else: + debug("") + result = super().run_node(n) + + # require the same stride order for dense outputs, + # 1. user-land view() will not throw because inductor + # output different strides than eager + # long term the solution is to make view() always succeed + # with infallible strides. + # 2: as_strided ops, we need make sure its input has same size/stride with + # eager model to align with eager behavior. 
+ as_strided_ops = [ + torch.ops.aten.as_strided.default, + torch.ops.aten.as_strided_.default, + torch.ops.aten.as_strided_scatter.default, + torch.ops.aten.resize.default, + torch.ops.aten.resize_as.default, + ] + is_output = any(user.op == "output" for user in n.users) + is_input_for_as_strided = any( + user.target in as_strided_ops for user in n.users + ) + + if n.meta.get("inductor_realize_to_strides", False) and isinstance( + result, TensorBox + ): + result.realize() + strides = n.meta["val"].stride() + sym_strides = torch._inductor.utils.any_is_symbolic(*strides) + if ( + not hasattr(result, "get_stride") + or result.get_stride() != strides + and not sym_strides + ): + stride_order = ir.get_stride_order(strides) + result = ir.ExternKernel.require_stride_order(result, stride_order) + if ( + is_output + and isinstance(result, TensorBox) + and isinstance(result.data, ir.BaseView) + ): + # Realize so that outputs are correctly aliased + result.realize() + + if (is_output or is_input_for_as_strided) and isinstance( + n.meta["val"], torch.Tensor + ): + strides = n.meta["val"].stride() + if len(strides): + allow_padding = ( + config.pad_outputs or n.name not in self.user_visible_outputs + ) and not is_input_for_as_strided + dense = torch._prims_common.is_non_overlapping_and_dense( + n.meta["val"] + ) + unbacked_symbols_in_strides = ( + len(free_unbacked_symbols(strides)) > 0 + ) + if ( + not unbacked_symbols_in_strides + and dense + and len(result.get_size()) == 4 + and n in self.nodes_prefer_channels_last + and n.name not in self.user_visible_outputs + and not is_input_for_as_strided + ): + strides = ir.FlexibleLayout.stride_ordered_for_memory_format( + result.get_size(), torch.channels_last + ) + if not unbacked_symbols_in_strides and len(strides): + # To avoid converting possible view ops to a copy kernel, we use the previous + # require_exact_strides to handle views. But ultimately it's better to require + # the right strides at the tensor definition. + if n.meta["val"]._is_view() or isinstance( + result.data, ir.BaseView + ): + result = ir.ExternKernel.require_stride_order( + result, + ir.get_stride_order(strides), + allow_padding=allow_padding, + ) + else: + strides = [ + s.node.expr if isinstance(s, torch.SymInt) else s + for s in strides + ] + result = ir.ExternKernel.require_exact_strides( + result, strides, allow_padding=allow_padding + ) + + # Realize if (1) any user need inputs realized, or (2) there is + # already too many reads and rematerializing can be bad. + num_users = len(OrderedSet(n.users)) + if num_users > 1 and isinstance(result, TensorBox): + for user in n.users: + if user.target in needs_realized_inputs: + result.realize_hint() + # This inclusion is somewhat controversial (from + # discussion between Horace, Natalia, and Elias). + # Currently, it's not very clear why this is helpful. + # The general idea here is that even though a node may + # have FlexibleLayout, we still often *treat* it as if + # it was contiguous. This appears to sometimes result in + # suboptimal behavior. + # + # When we do a better job selecting layout, we should + # revisit this. 
+ need_fixed_layout = [ + torch.ops.aten.convolution_backward.default, + torch.ops.aten.mm.default, + torch.ops.aten._int_mm.default, + ] + need_fixed_channels_last_layout = [] + if not self.layout_opt: + need_fixed_layout.append(torch.ops.aten.convolution.default) + if torch._C._has_mkldnn: + need_fixed_layout += [ + torch.ops.mkldnn._linear_pointwise.default, + torch.ops.mkldnn._linear_pointwise.binary, + torch.ops.aten.mkldnn_rnn_layer.default, + torch.ops.onednn.qlinear_pointwise.default, + torch.ops.onednn.qlinear_pointwise.tensor, + torch.ops.onednn.qlinear_pointwise.binary, + torch.ops.onednn.qlinear_pointwise.binary_tensor, + ] + need_fixed_channels_last_layout += [ + torch.ops.mkldnn._convolution_pointwise.default, + torch.ops.mkldnn._convolution_pointwise.binary, + torch.ops.mkldnn._convolution_pointwise_.binary, + torch.ops.mkldnn._convolution_transpose_pointwise.default, + torch.ops.onednn.qconv2d_pointwise.default, + torch.ops.onednn.qconv2d_pointwise.binary, + ] + if torch._C.has_mkl: + need_fixed_layout += [torch.ops.mkl._mkl_linear.default] + if user.target in need_fixed_layout: + result = ir.ExternKernel.require_stride_order( + result, + ir.get_stride_order(n.meta["val"].stride()), + allow_padding=True, + ) + if ( + user.target in need_fixed_channels_last_layout + and n is user.args[0] + ): + result = ir.ExternKernel.require_stride_order( + result, + ir.get_stride_order( + make_channels_last_strides_for(n.meta["val"].shape) + ), + ) + if user.op == "output": + if isinstance(result.data.data, (Pointwise, Reduction)): + result.realize() + + # TODO(jansel): introduce a store vs inline choice + result.mark_reuse(len(n.users)) + + # Realize if the IRNode already has accumulated lots of reads + if isinstance(result, TensorBox) and result.has_exceeded_max_reads(): + # Prevent excessive accumulation in a computed buffer, when + # there are multiple branches each with small number of memory + # reads, but they converge to a user. + result.realize_hint() + + # Realize if a Pointwise has too much stuff to be inlined. + # As this may cause RecursionError during Inductor's evaluation. + if isinstance(result, TensorBox) and isinstance(result.data, StorageBox): + curr = result.data.data + if isinstance(curr, Pointwise): + # Use inner fn as a rough proxy. Good enough. + if curr.has_large_inner_fn(): + result.realize() + + # This is not complete, but it doesn't have to be: origin_node + # tracking is best effort. The logic here critically relies on direct + # TensorBox -> StorageBox denoting a non-view; we don't bother trying + # to get views to work. Feel free to add any extra cases as needed. + # + # Note: we can't YOLO tree_map over this result, because if there are + # buffers or a view involved, we might not be able to validly assign + # the origin_node here. 
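+ # Rough shape of the cases handled below (summary comment added for clarity): + # TensorBox -> StorageBox -> Loops/Buffer, possibly a ComputedBuffer wrapping + # a Loops body, or a MultiOutput whose first input is a Buffer; in each case + # the innermost node is tagged with the fx node that produced it.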
+ if isinstance(result, TensorBox) and isinstance(result.data, ir.StorageBox): + if isinstance(result.data.data, ir.Loops): + result.data.data.origin_node = n + elif isinstance(result.data.data, ir.Buffer): + result.data.data.origin_node = n + if isinstance(result.data.data, ir.ComputedBuffer) and isinstance( + result.data.data.data, ir.Loops + ): + result.data.data.data.origin_node = n + # Not really multi-output, can straightforwardly recurse in + elif ( + isinstance(result.data.data, ir.MultiOutput) + and not result.data.data.indices + ): + if isinstance(result.data.data.inputs[0], ir.Buffer): + result.data.data.inputs[0].origin_node = n + + self.register_users_of(result) + + new_unbacked_defs: OrderedSet[sympy.Symbol] = OrderedSet() + for buf in self.buffers[buffer_watermark:]: + new_unbacked_defs |= buf.get_unbacked_symbol_defs() + for op in self.operations[operation_watermark:]: + new_unbacked_defs |= op.get_unbacked_symbol_defs() + + def format_new_defs() -> str: + r = [] + for buf in self.buffers[buffer_watermark:]: + r.append( + f"unbacked_symbol_defs={buf.get_unbacked_symbol_defs()} in:\n{buf}\n" + ) + for op in self.operations[operation_watermark:]: + r.append( + f"unbacked_symbol_defs={op.get_unbacked_symbol_defs()} in:\n{op}\n" + ) + return "***\n".join(r) + + if n.op != "placeholder": + # Note [Backwards runtime asserts] + # Backwards poses an interesting problem for deferred runtime + # asserts. In the easy case, we may solely close over data + # dependent sized tensors, and there are no binding sites for + # unbacked SymInts. In this case, we can just drop all the + # runtime asserts on the floor: no non-placeholder bindings, no + # problem. + # + # However, it is *possible* for a fresh runtime assert to show up + # between forwards and backwards. Right now, the freezing process + # that happens when we lower forwards means that we will freeze + # runtime asserts, and then the moment the backwards lowering + # process attempts to add a new deferred runtime assert, we will + # fail. Let's say you remove that assert. Now when we get here, + # we need to make sure we actually emit these asserts (because we + # can't emit them in forwards, we already compiled it). So we + # have to do something here. But we don't want to reemit ALL + # deferred runtime asserts, we only want to emit the NEW ones. + # Therefore needing some sort of stratification in the ShapeEnv. + # This is all doable, it just hasn't been done yet. 
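+ # Summary of the block below (comment added for readability): for each + # unbacked symbol newly defined by this node, emit AssertScalar checks for + # its known value range, then re-emit any deferred runtime asserts whose free + # unbacked symbols are now all bound; asserts that still reference unbound + # symbols are parked in ras_by_symbol until those symbols are defined.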
+ shape_env = V.graph.sizevars.shape_env + + def make_assert(expr: Expr, msg: str) -> None: + assert_op = ir.AssertScalar(expr, msg) + self.register_buffer(assert_op, set_name=True) + self.register_operation(assert_op) + + for i0 in new_unbacked_defs: + ras = self.ras_by_symbol.pop(i0, []) + # NB: size-like not needed, we won't retrace + vr = shape_env.var_to_range[i0] + if not shape_env._default_unspecified_value_range().issubset(vr): + + def is_convertible(s: Expr) -> bool: + if s in (int_oo, -int_oo): + return False + try: + int(s) + return True + except TypeError: + return False + + if is_convertible(vr.lower): + make_assert(i0 >= vr.lower, f"{i0} >= {vr.lower}") + if is_convertible(vr.upper): + make_assert(i0 <= vr.upper, f"{i0} <= {vr.upper}") + + for ra in ras: + fvs = free_unbacked_symbols(ra.expr) + missing = fvs - self.bound_unbacked_symbols + if missing: + i1 = min(missing, key=str) + self.ras_by_symbol.setdefault(i1, []).append(ra) + else: + make_assert(ra.expr, f"{ra.expr}") + + self.bound_unbacked_symbols |= new_unbacked_defs + + unbacked_bindings = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, n.meta.get("unbacked_bindings", {}) + ) + # When we do lowering, it is possible we reallocate unbacked SymInts. + # So we need to line up the unbacked SymInts when performing the test + # here + # + # In principle, we could permit lowering to introduce MORE unbacked + # SymInts: as long as all the old unbacked ones are accounted for, + # it's fine for inductor to introduce extra calls to item()/unbacked() + # whatever. This actually happens in practice when an unbacked SymInt + # gets memoized away; naively, when Inductor reprocesses a kernel, it + # doesn't know that the memo still applies, and ends up allocating a + # new symbol. However, this is generally a bad thing: we may still + # end up needing to test equalities on the symbols, and a fresh + # symbol is likely to hit lots of GuardOnDataDependent errors that + # we already know facts for. + renamed_unbacked_bindings = OrderedSet( + V.fake_mode.shape_env.unbacked_renamings.get(s, s) + for s in unbacked_bindings.keys() + ) + assert new_unbacked_defs >= renamed_unbacked_bindings, ( + f"failed {new_unbacked_defs} >= {renamed_unbacked_bindings} (inductor >= fx)\n" + f"fx node is: {n.format_node()}\n" + f"new operations are:\n\n{format_new_defs()}" + ) + + return result + + def validate_can_generate_cpp_wrapper(self) -> None: + if config.disable_cpp_codegen: + raise CppWrapperCodeGenError("C++ codegen is disabled") + + if sys.platform not in ["linux", "darwin", "win32"]: + raise CppWrapperCodeGenError(f"Unsupported platform {sys.platform}") + + for value in self.graph_inputs.values(): + dtype = None + if isinstance(value, TensorBox): + dtype = value.get_dtype() + elif isinstance( + value, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer) + ): + dtype = may_get_constant_buffer_dtype(value) + + if not supported_dtype_of_cpp_wrapper(dtype, self.cuda): + raise CppWrapperCodeGenError(f"Unsupported input dtype {dtype}") + + def init_wrapper_code(self) -> None: + self.cuda = "cuda" in self.device_types + if self.cpp_wrapper: + self.validate_can_generate_cpp_wrapper() + + device_types = self.device_types.copy() + device_types.discard("cpu") + device_types.discard("meta") + # TODO(Eikan): Only support mixing cpu and other device now. 
+ assert len(device_types) <= 1, "Does not support mixing {}".format( + "+".join(device_types) + ) + only_cpu = len(device_types) == 0 + device_type = "cpu" if only_cpu else device_types.pop() + + self.device_ops = get_device_op_overrides(device_type) + wrapper_code_gen_cls = get_wrapper_codegen_for_device( + device_type, self.cpp_wrapper + ) + assert wrapper_code_gen_cls is not None, f"Device {device_type} not supported" + self.wrapper_code = wrapper_code_gen_cls() + + if self.const_module: + # If we have const module, we could reuse the kernels + # This could avoid duplication and save time on doing recompilation (if Triton.) + self.wrapper_code._names_iter = self.const_module.wrapper_code._names_iter + self.wrapper_code.src_to_kernel = ( + self.const_module.wrapper_code.src_to_kernel + ) + + def codegen_with_cpp_wrapper(self) -> Tuple[str, List[Tuple[int, Node]]]: + """ + For CPU, the cpp wrapper codegen is done in one pass. + For GPU, the cpp wrapper codegen is done in two steps: JIT-compile the model with python + wrapper code and run it to generate autotuned kernel binaries in the first pass; and then + generate cpp wrapper code and compile it to a dynamic library in the second pass. + """ + if "cuda" in self.device_types: + # first pass + self.cpp_wrapper = False + # Although triton.store_cubin was OrderedSet in compile_fx, the backward pass didn't pick + # that up. In theory it should work by only setting triton.store_cubin to True here, + # but that will cause a problem when use_runtime_constant_folding is OrderedSet. + with config.patch({"triton.store_cubin": True}): + compiled = self.compile_to_module().call + + if not config.triton.autotune_at_compile_time: + + def materialize( + x: Union[torch.SymInt, torch.SymFloat, torch.Tensor] + ) -> Union[int, float, torch.Tensor]: + if x is None: + return None + elif isinstance(x, (torch.SymInt, torch.SymFloat)): + # Need concrete value to run dynamic shapes and tune the result + return x.node.hint + elif isinstance(x, FakeTensor): + return defake(x) + else: + assert isinstance( + x, torch.Tensor + ), "Unknown type when creating real inputs" + str(type(x)) + return x + + tracing_context = torch._guards.TracingContext.try_get() + if tracing_context is not None and not isinstance( + V.real_inputs, NullHandler + ): + if tracing_context.output_strides: + tracing_context.output_strides.clear() + + params_flat = [ + param + for param in tracing_context.params_flat # type: ignore[union-attr] + if param is not None + ] + real_inputs = [ + materialize(x) + for x in itertools.chain(params_flat, V.real_inputs) + ] + else: + # In the backward pass, V.real_inputs is not OrderedSet. + # Generating random inputs based on self.example_inputs sometimes can be problematic, + # e.g. illegal memory access. A comprehensive fix is to autotune in a separate process. + real_inputs = [ + materialize(x) + for x in ( + self.example_inputs + if isinstance(V.real_inputs, NullHandler) + else V.real_inputs + ) + ] + + if self.mutated_inputs: + from .compile_fx import clone_preserve_strides + + mutated_input_idxs = [ + idx + for idx, name in enumerate(self.graph_inputs) + if name in self.mutated_inputs + and isinstance(real_inputs[idx], torch.Tensor) + ] + for idx in mutated_input_idxs: + # clone mutated Tensor inputs to avoid mutating them in + # the first pass of the CPP wrapper-based compilation, as + # this will lead to a side effect on the example inputs: + # e.g. 
if torch.compile(f)(x) if called on input-mutating + # f, the inputs x will be mutated twice in the process: + # once here, and again when running the compiled model; + # this will also lead to a numerically incorrect output + mutated_inp = real_inputs[idx] + assert isinstance(mutated_inp, torch.Tensor) + real_inputs[idx] = clone_preserve_strides(mutated_inp) + del mutated_inp + + with torch.utils._python_dispatch._disable_current_modes(): + compiled(real_inputs) + del real_inputs + + # second pass + self.cpp_wrapper = True + self.removed_buffers.clear() + self.removed_operations.clear() + self.inplaced_to_remove.clear() + V.graph.sizevars.precomputed_replacements.clear() + V.graph.sizevars.inv_precomputed_replacements.clear() + with config.patch({"triton.autotune_at_compile_time": False}): + return self.codegen() + else: + # cpu + return self.codegen() + + def codegen(self) -> Tuple[str, List[Tuple[int, Node]]]: + from .scheduler import Scheduler + + self.init_wrapper_code() + + self.scheduler = Scheduler(self.operations) + V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes) + + self.wrapper_code.push_codegened_graph(self) + self.scheduler.codegen() + + log.debug( + "Finished codegen for all nodes. The list of kernel names available: %s", + V.graph.all_codegen_kernel_names, + ) + + result = self.wrapper_code.generate(self.is_inference) + self.wrapper_code.pop_codegened_graph() + return result + + def codegen_subgraph(self, parent_graph: "GraphLowering") -> None: + """ + This is a more compact version of the `codegen()` above + where we codegen this graph as a subgraph of some parent + graph. The parent graph is passed as an argument: the + intention is to inline codegening of the subgraph in + the parent graph's wrapper code (including the generated + kerenls). The wrapper code is not finalized (via `.generate()` + call), as this will be done in the parent graph's `codegen()`. + """ + from .scheduler import Scheduler + + self.wrapper_code = parent_graph.wrapper_code + self.device_ops = parent_graph.device_ops + self.cpp_wrapper = parent_graph.cpp_wrapper + + self.scheduler = Scheduler(self.operations) + self.scheduler.codegen() + + def count_bytes( + self, + ) -> Tuple[ + int, List[Tuple[BaseSchedulerNode, int]], List[Tuple[BaseSchedulerNode, float]] + ]: + total_bytes = 0 + node_counts = [] + node_runtimes = [] + for node in self.scheduler.nodes: + num_bytes = node.get_read_write_buffers_sizes() + total_bytes += num_bytes + node_counts.append((node, num_bytes // 4)) + node_runtimes.append((node, node.get_estimated_runtime())) + + return total_bytes, node_counts, node_runtimes + + @staticmethod + def save_output_code(code: str) -> None: + # No-op to be patched for unit tests + pass + + def compile_to_module(self) -> ModuleType: + with dynamo_timed( + "GraphLowering.compile_to_module", phase_name="code_gen", fwd_only=False + ): + return self._compile_to_module() + + def _compile_to_module(self) -> ModuleType: + from .codecache import PyCodeCache + + code, linemap = ( + self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() + ) + + GraphLowering.save_output_code(code) + output_code_log.debug("Output code: \n%s", code) + try: + linemap = [(line_no, node.stack_trace) for line_no, node in linemap] # type: ignore[misc] + key, path = PyCodeCache.write(code) + except Exception: + trace_structured( + "inductor_output_code", + # Just omit the filename, I still want the code though! 
+ payload_fn=lambda: code, + ) + raise + else: + trace_structured( + "inductor_output_code", + lambda: {"filename": path}, + payload_fn=lambda: code, + ) + + mod = PyCodeCache.load_by_key_path( + key, + path, + linemap=linemap, # type: ignore[arg-type] + attrs={**self.constants, **self.torchbind_constants}, + ) + self.cache_key = key + self.cache_path = path + self.cache_linemap = linemap # type: ignore[assignment] + + # Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029 + # TODO. Revisit this once the logging API is more mature + assert mod.__file__ is not None + + log_module_code(mod.__file__) + log.debug("Output code written to: %s", mod.__file__) + output_code_log.info("Output code written to: %s", mod.__file__) + if config.benchmark_kernel: + print(f"Compiled module path: {mod.__file__}", file=sys.stderr) + V.debug.output_code(mod.__file__) + V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug") + return mod + + def compile_to_fn(self) -> Any: + if self.aot_mode: + from .codecache import AotCodeCompiler + + assert self.cpp_wrapper, "AOT mode only supports C++ wrapper" + code, linemap = self.codegen_with_cpp_wrapper() + output_code_log.debug("Output code: \n%s", code) + + serialized_extern_kernel_nodes = None + if self.extern_kernel_nodes: + serialized_extern_kernel_nodes = self.extern_node_serializer( + self.extern_kernel_nodes + ) + output_code_log.debug( + "Serialized Extern Kernel Nodes: \n%s", + serialized_extern_kernel_nodes, + ) + + # Directly return the file path with the compiled code + return AotCodeCompiler.compile( + self, code, serialized_extern_kernel_nodes, cuda=self.cuda + ) + else: + return self.compile_to_module().call + + def get_output_names(self) -> List[str]: + return [ + node.get_name() + for node in self.graph_outputs + if not isinstance(node, ir.NoneAsConstantBuffer) + and not isinstance(node, ir.ShapeAsConstantBuffer) + ] + + def is_unspec_arg(self, name: str) -> bool: + # dynamo wraps unspec variable as 0d CPU tensor, + # need to convert to scalar during codegen (triton only) + return ( + name in self.graph_inputs.keys() + and self.graph_inputs[name].get_numel() == 1 + and self.graph_inputs[name].get_device().type == "cpu" + ) or name in self.zero_dim_cpu_tensor_list diff --git a/lib/python3.10/site-packages/torch/_inductor/hooks.py b/lib/python3.10/site-packages/torch/_inductor/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..9d8aeecd283185608260f0ffd16a19270f9e96b1 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/hooks.py @@ -0,0 +1,30 @@ +# mypy: allow-untyped-defs +import contextlib +from typing import Callable, List, TYPE_CHECKING + + +if TYPE_CHECKING: + import torch + +# Executed in the order they're registered +INTERMEDIATE_HOOKS: List[Callable[[str, "torch.Tensor"], None]] = [] + + +@contextlib.contextmanager +def intermediate_hook(fn): + INTERMEDIATE_HOOKS.append(fn) + try: + yield + finally: + INTERMEDIATE_HOOKS.pop() + + +def run_intermediate_hooks(name, val): + global INTERMEDIATE_HOOKS + hooks = INTERMEDIATE_HOOKS + INTERMEDIATE_HOOKS = [] + try: + for hook in hooks: + hook(name, val) + finally: + INTERMEDIATE_HOOKS = hooks diff --git a/lib/python3.10/site-packages/torch/_inductor/index_propagation.py b/lib/python3.10/site-packages/torch/_inductor/index_propagation.py new file mode 100644 index 0000000000000000000000000000000000000000..f4384d51b7d4adea0d55804090d684f9b7573b05 --- /dev/null +++ 
b/lib/python3.10/site-packages/torch/_inductor/index_propagation.py @@ -0,0 +1,373 @@ +# mypy: allow-untyped-defs +"""This file implements the IndexPropagation ops handler, which wraps an +underlying handler to add a limited form of constant propagation, as well as +propagation of sympy expressions downstream of ops.index_expr calls. + +For example, say we have the IR: + + tmp0 = ops.index_expr(x, torch.int32) + tmp1 = ops.constant(2, torch.int32) + tmp2 = ops.mul(tmp0, tmp1) + tmp3 = ops.indirect_indexing(tmp2, x_size) + tmp4 = ops.load("buf0", tmp3) + +The underlying handler would just see: + + ops.load("buf0", x * 2) + +This is limited by the set of operators handled in the sympy expression +printers. So simple operations like minimum and maximum cannot be translated to +SymPy expressions yet, despite sympy.Min and sympy.Max existing. + +""" +import itertools +from dataclasses import dataclass +from typing import Any, Callable, Dict, Literal, Optional, overload, Tuple, Union +from typing_extensions import TypeAlias + +import sympy + +import torch +from torch._prims_common import dtype_to_type, is_integer_dtype +from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where +from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges + +from .sizevars import evaluate_expr +from .utils import generate_assert +from .virtualized import V + + +_ExprType = Union[sympy.Expr, float, int, bool] + + +def _is_constant(val: _ExprType): + if isinstance(val, sympy.Basic): + return val.is_number + return isinstance(val, (int, float, bool)) + + +def upper_bound(val: _ExprType): + return bound_sympy(val).upper if isinstance(val, sympy.Expr) else val + + +@dataclass +class TypedExpr: + """A SymPy expression with associated type""" + + expr: _ExprType + dtype: torch.dtype + + def is_constant(self): + return _is_constant(self.expr) + + def __post_init__(self): + if _is_constant(self.expr): + self.expr = dtype_to_type(self.dtype)(self.expr) + + +class SymPyOps: + """An ops handler where all IR values are SymPy expressions + + When a value cannot be represented as a SymPy expression, the method is + either not defined, or returns NotImplemented + + """ + + @staticmethod + def identity(value: Any) -> Any: + return value + + @staticmethod + def constant(value: Union[int, float, bool], dtype: torch.dtype) -> TypedExpr: + return TypedExpr(value, dtype) + + @staticmethod + def index_expr(value: Union[sympy.Expr, int], dtype: torch.dtype) -> TypedExpr: + return TypedExpr(value, dtype) + + @staticmethod + def to_dtype( + value: TypedExpr, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types: bool = False, + ) -> TypedExpr: + return TypedExpr(value.expr, dtype) + + @staticmethod + def abs(x: TypedExpr) -> TypedExpr: + return TypedExpr(abs(x.expr), x.dtype) # type: ignore[arg-type] + + @staticmethod + def square(x: TypedExpr) -> TypedExpr: + return TypedExpr(x.expr * x.expr, x.dtype) + + @staticmethod + def add(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(x.expr + y.expr, result_type) + + @staticmethod + def sub(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(x.expr - y.expr, result_type) + + @staticmethod + def mul(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(x.expr * y.expr, result_type) + + @staticmethod + def neg(x: TypedExpr) -> TypedExpr: + return 
TypedExpr(-x.expr, x.dtype) + + @staticmethod + def floordiv(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + if not is_integer_dtype(result_type): + return NotImplemented + + return TypedExpr(FloorDiv(x.expr, y.expr), result_type) + + @staticmethod + def mod(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]: + result_type = torch.promote_types(x.dtype, y.dtype) + if not is_integer_dtype(result_type): + return NotImplemented + + result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr) + return TypedExpr(result_expr, result_type) + + @staticmethod + def remainder(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]: + result_type = torch.promote_types(x.dtype, y.dtype) + if not is_integer_dtype(result_type): + return NotImplemented + + x_expr = sympy.sympify(x.expr) + y_expr = sympy.sympify(y.expr) + # In these cases, remainder in Python == remainder in C++, so this transformation + # is sound + if ( + x_expr.is_nonnegative is not None + and x_expr.is_nonnegative == y_expr.is_positive + ): + result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr) + return TypedExpr(result_expr, result_type) + return NotImplemented + + @staticmethod + def minimum(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(sympy.Min(x.expr, y.expr), result_type) + + @staticmethod + def maximum(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(sympy.Max(x.expr, y.expr), result_type) + + +@dataclass +class IndexPropVar: + value: Any # Either an IR value, or TypedExpr if is_symbolic is true + is_symbolic: bool = False + + @staticmethod + def new_symbolic(expr: TypedExpr) -> "IndexPropVar": + return IndexPropVar(expr, is_symbolic=True) + + def __post_init__(self): + assert not self.is_symbolic or isinstance( + self.value, TypedExpr + ), "Symbolic IndexPropVar must contain a TypedExpr" + + +IndexPropResult: TypeAlias = Union[IndexPropVar, Tuple["IndexPropResult", ...]] + + +class IndexPropagation: + """Ops wrapper that tries to propagate constant and index_expr values through the computation. + + This aims to maximize the compile time simplification possible, and convert + indirect indexing from arange into normal static indexing. 
+ + """ + + def __init__( + self, + inner: Any, + iter_ranges: Dict[sympy.Symbol, sympy.Expr], + indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr], + ) -> None: + self._inner = inner + self.shape_env = V.graph.sizevars.shape_env + + var_to_range = { + k: ValueRanges(0, upper_bound(v) - 1) for k, v in iter_ranges.items() + } + self.var_to_range = tuple( + itertools.chain(self.shape_env.var_to_range.items(), var_to_range.items()) + ) + # NOTE: this is intentionally kept as a reference so the caller can + # update it in-place + self.indirect_var_ranges = indirect_var_ranges + + axioms = [] + for x, s in iter_ranges.items(): + axioms.append(0 <= x) + axioms.append(x < s) + self.axioms = tuple(axioms) + self.shape_env.get_axioms() + + def materialize_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> Any: + # Construct a new constant/index_expr from the SymPy expression + if _is_constant(expr): + val = dtype_to_type(dtype)(expr) + return self._inner.constant(val, dtype) + return self._inner.index_expr(expr, dtype) + + def unwrap(self, a: Union[Any, IndexPropVar]) -> Any: + if isinstance(a, (list, tuple)): + return tuple(self.unwrap(v) for v in a) + + if not isinstance(a, IndexPropVar): + return a + + # Prefer the sympy representation if possible + if a.is_symbolic: + return self.materialize_expr(a.value.expr, a.value.dtype) + + return a.value + + def wrap(self, a) -> IndexPropResult: + if isinstance(a, (list, tuple)): + return tuple(self.wrap(v) for v in a) + return IndexPropVar(a) + + @overload + def fallback( + self, + name: Literal["indirect_indexing"], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + ) -> IndexPropVar: + ... + + @overload + def fallback( + self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> IndexPropResult: + ... 
+ + def fallback( + self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> IndexPropResult: + # Fallback to the wrapped handler + new_args = [self.unwrap(a) for a in args] + new_kwargs = {k: self.unwrap(v) for k, v in kwargs.items()} + return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs)) + + def propagate_sympy( + self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> IndexPropResult: + # Build a new SymPy expression from this ops call + def unwrap(a: Union[Any, IndexPropVar]) -> Any: + if not isinstance(a, IndexPropVar): + return a + return a.value + + new_args = [unwrap(a) for a in args] + new_kwargs = {k: unwrap(v) for k, v in kwargs.items()} + new_expr = getattr(SymPyOps, name)(*new_args, **new_kwargs) + is_valid_expr = new_expr is not NotImplemented and ( + # Inductor doesn't expect floating point in sympy expressions, but + # allow floating point constants to be propagated + new_expr.is_constant() + or new_expr.expr.is_integer + ) + if not is_valid_expr: + return self.fallback(name, args, kwargs) + return IndexPropVar.new_symbolic(new_expr) + + def __getattr__(self, name: str) -> Callable[..., IndexPropResult]: + def inner(*args: Any, **kwargs: Any) -> IndexPropResult: + if not hasattr(SymPyOps, name): + return self.fallback(name, args, kwargs) + + var_arguments = [ + a + for a in itertools.chain(args, kwargs.values()) + if isinstance(a, IndexPropVar) + ] + if not all(v.is_symbolic for v in var_arguments): + return self.fallback(name, args, kwargs) + + return self.propagate_sympy(name, args, kwargs) + + return inner + + def statically_true(self, e): + """ + Given some iter_ranges, return a function that given an expression, returns whether + it is true or false using value ranges, guard knowledge and runtime_asserts. + + FIXME I think this may not be entirely right, as we may not be able to use all runtime_asserts + If this is an issue, just use guards in `self.axioms`. + + The proper way of handling this would be to have a global shape_env that adds + runtime_asserts as they happen in the code. 
Then, it shuld be used in SimplifyIndexing + to perform wrap_expr and in CSEProxy.check_bounds to elide upper / lower bounds also + for indirect_indexing + """ + var_to_range = ( + *self.var_to_range, + *( + (k, ValueRanges(0, upper_bound(v) - 1)) + for k, v in self.indirect_var_ranges.items() + ), + ) + return evaluate_expr(self.shape_env, e, self.axioms, var_to_range) + + def indirect_indexing( + self, + index: Union[Any, IndexPropVar], + size: Any, + check: bool = True, + wrap_neg=True, + ) -> Any: + if isinstance(index, IndexPropVar) and index.is_symbolic: + # If we find something we can convert into a direct indexing we do so + # We still need to (perhaps) wrap the expression and add bound checks + # We want to do this "constant folding", as we don't allow to fuse + # kernels into indirect indexing + + expr = sympy.sympify(index.value.expr) + + # TODO Perhaps move this logic to the simplify indexing pass + def wrap_expr(expr): + # Positive, negative, mixed + if self.statically_true(0 <= expr): + return expr + elif self.statically_true(expr < 0): + return expr + size + else: + return Where(expr < 0, expr + size, expr) + + # Sometimes it's easier to prove 0 <= expr than the weaker -size <= expr + can_prove_lower = self.statically_true(0 <= expr) or self.statically_true( + -size <= expr + ) + can_prove_upper = self.statically_true(expr < size) + if wrap_neg: + expr = wrap_expr(expr) + if generate_assert(check): + self.fallback( + "check_bounds", + (expr, size), + dict(lower=not can_prove_lower, upper=not can_prove_upper), + ) + return expr + + indirect_var = self.fallback( + "indirect_indexing", (index, size, check, wrap_neg), {} + ).value + return indirect_var diff --git a/lib/python3.10/site-packages/torch/_inductor/inductor_prims.py b/lib/python3.10/site-packages/torch/_inductor/inductor_prims.py new file mode 100644 index 0000000000000000000000000000000000000000..82a23d3e60cf10eac46fc5f30b2a8d811ebe7360 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/inductor_prims.py @@ -0,0 +1,179 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import logging +from typing import Optional, Sequence + +import torch +from torch import _prims, Tensor + + +log = logging.getLogger(__name__) + + +def make_prim( + schema: str, + impl_aten, + return_type=_prims.RETURN_TYPE.NEW, + doc: str = "", + tags: Optional[Sequence[torch.Tag]] = None, +): + if isinstance(return_type, tuple): + + def meta(*args, **kwargs): + return tuple(_prims.TensorMeta(o) for o in impl_aten(*args, **kwargs)) + + else: + + def meta(*args, **kwargs): + return _prims.TensorMeta(impl_aten(*args, **kwargs)) + + return _prims._make_prim( + schema=schema, + return_type=return_type, + meta=meta, + impl_aten=impl_aten, + doc=doc, + tags=tags, + ) + + +def eager_force_stride(input_tensor: Tensor, stride) -> Tensor: + if input_tensor.stride() == stride: + return input_tensor + new_tensor = input_tensor.clone().as_strided( + input_tensor.shape, + stride, + ) + new_tensor.copy_(input_tensor) + return new_tensor + + +# Custom prims used for handling randomness +seed = make_prim( + "inductor_seed(Device device) -> Tensor", + lambda device: torch.randint(2**63 - 1, [], device=device), + doc="create a fresh seed (one per call) for use with inductor_rand", + tags=(torch.Tag.nondeterministic_seeded,), +) +seeds = make_prim( + "inductor_seeds(int count, Device device) -> Tensor", + lambda count, device: torch.randint(2**63 - 1, [count], device=device), + doc="Horizontal fusion of many inductor_seed() calls", + 
tags=(torch.Tag.nondeterministic_seeded,), +) +lookup_seed = make_prim( + # if inductor_lookup_seed changes, update partitioners.py + "inductor_lookup_seed(Tensor seeds, int index) -> Tensor", + lambda seeds, index: seeds[index], + doc="Extract a single seed from the result of inductor_seeds()", +) +random = make_prim( + "inductor_random(SymInt[] size, Tensor seed, str mode) -> Tensor", + lambda size, seed, mode: getattr(torch, mode)(size, device=seed.device), + doc="torch.rand()/torch.randn() using backend-specific RNG that can be fused", +) +randint = make_prim( + "inductor_randint(SymInt low, SymInt high, SymInt[] size, Tensor seed) -> Tensor", + lambda low, high, size, seed: torch.randint(low, high, size, device=seed.device), + doc="torch.randint() using backend-specific RNG that can be fused", +) +force_stride_order = make_prim( + "inductor_force_stride_order(Tensor input, SymInt[] stride) -> Tensor", + eager_force_stride, + doc="Force the stride order for input tensor. No-op if the input tensor already has the stride. Do a copy otherwise", +) +_unsafe_index_put_ = make_prim( + "_unsafe_index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)", + lambda self, indices, values, accumulate=False: torch.ops.aten.index_put_( + self, indices, values, accumulate + ), + doc="Unsafe index_put_ (doesn't issue device asserts)", +) +fma = make_prim( + "fma(Tensor a, Tensor b, Tensor c) -> Tensor", + lambda a, b, c: (a * b) + c, + doc="Fused multiply add: fma(a, b, c) -> (a * b) + c without rounding after the multiplication", +) + + +def _low_memory_max_pool2d_with_offsets_aten( + self, + kernel_size, + stride, + padding, + dilation, + ceil_mode, +): + vals, indices = torch.ops.aten.max_pool2d_with_indices( + self, kernel_size, stride, padding, dilation, ceil_mode + ) + + input_width = self.shape[-1] + kernel_width = kernel_size[1] + + bh_shape = [1] * self.ndim + bh_shape[-2] = -1 + bh = torch.arange(indices.shape[-2], dtype=torch.int64, device=self.device).view( + bh_shape + ) + + bw_shape = [1] * self.ndim + bw_shape[-1] = -1 + bw = torch.arange(indices.shape[-1], dtype=torch.int64, device=self.device).view( + bw_shape + ) + + hbase = bh * stride[0] - padding[0] + wbase = bw * stride[1] - padding[1] + + ih = indices // input_width + iw = indices - (ih * input_width) + + h_inc = ih - hbase + w_inc = iw - wbase + + offsets = h_inc * kernel_width + w_inc + + return vals, offsets.to(torch.int8) + + +def _low_memory_max_pool2d_offsets_to_indices_aten( + offsets, kernel_width, input_width, stride, padding +): + offsets = offsets.to(torch.int64) + h_inc = offsets // kernel_width + w_inc = offsets - (h_inc * kernel_width) + + bh_shape = [1] * offsets.ndim + bh_shape[-2] = -1 + bh = torch.arange(offsets.shape[-2], dtype=torch.int64, device=offsets.device).view( + bh_shape + ) + + bw_shape = [1] * offsets.ndim + bw_shape[-1] = -1 + bw = torch.arange(offsets.shape[-1], dtype=torch.int64, device=offsets.device).view( + bw_shape + ) + + hbase = bh * stride[0] - padding[0] + wbase = bw * stride[1] - padding[1] + + ih = hbase + h_inc + iw = wbase + w_inc + return ih * input_width + iw + + +_low_memory_max_pool2d_with_offsets = make_prim( + "_low_memory_max_pool2d_with_offsets(Tensor self, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, bool ceil_mode) -> (Tensor, Tensor)", # noqa: B950 + _low_memory_max_pool2d_with_offsets_aten, + return_type=(_prims.RETURN_TYPE.NEW, _prims.RETURN_TYPE.NEW), + doc="Instead of returning indices, returns 
indices offsets.", +) + +_low_memory_max_pool2d_offsets_to_indices = make_prim( + "_low_memory_max_pool2d_offsets_to_indices(Tensor self, SymInt kernel_w, SymInt input_w, SymInt[2] stride, SymInt[2] padding) -> Tensor", # noqa: B950 + _low_memory_max_pool2d_offsets_to_indices_aten, + doc="Convert small int offsets to regular indices.", +) diff --git a/lib/python3.10/site-packages/torch/_inductor/ir.py b/lib/python3.10/site-packages/torch/_inductor/ir.py new file mode 100644 index 0000000000000000000000000000000000000000..0929996fe174523fe0b3aad44e2d20130ea7ba86 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/ir.py @@ -0,0 +1,6953 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import contextlib +import dataclasses +import functools +import itertools +import logging +import textwrap +import traceback +from contextlib import nullcontext +from functools import partial +from typing import ( + Any, + Callable, + ClassVar, + ContextManager, + Dict, + Iterable, + List, + Literal, + Optional, + overload, + Sequence, + Tuple, + TYPE_CHECKING, + TypeVar, + Union, +) +from typing_extensions import TypeAlias +from unittest.mock import patch + +import sympy +from sympy import Expr, Integer, Symbol + +import torch._export.serde.schema as export_schema +import torch._logging +import torch.fx +import torch.utils._pytree as pytree +from torch._dynamo.device_interface import get_interface_for_device +from torch._dynamo.utils import identity +from torch._export.serde.serialize import GraphModuleSerializer +from torch._higher_order_ops.auto_functionalize import can_auto_functionalize +from torch._inductor import metrics +from torch._prims_common import ( + compute_required_storage_length, + is_boolean_dtype, + is_float_dtype, + make_channels_last_strides_for, + StrideType, +) +from torch._subclasses.fake_tensor import get_schema_info +from torch.fx.experimental.symbolic_shapes import ( + CallMethodKey, + compute_unbacked_bindings, + DivideByKey, + free_unbacked_symbols, + rebind_unbacked, + resolve_unbacked_bindings, + SymTypes, +) +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import CleanDiv, FloorDiv, ModularIndexing +from torch.utils._sympy.symbol import SymT + +from . import config, dependencies +from .codegen.common import BackendFeature, index_prevent_reordering +from .dependencies import ( + extract_free_unbacked_symbols, + extract_input_node_reduction_ranges, + extract_read_writes, + var_builder, +) +from .loop_body import LoopBody +from .ops_handler import OpCounterCSE, OpCountResult +from .runtime.benchmarking import benchmarker +from .runtime.hints import ReductionHint +from .utils import ( + argsort, + cache_on_self, + ceildiv, + convert_shape_to_inductor, + convert_shape_to_symint, + developer_warning, + get_kernel_metadata, + is_dynamic, + is_gpu, + sympy_dot, + sympy_index_symbol, + sympy_index_symbol_with_prefix, + sympy_product, + sympy_subs, +) +from .virtualized import ops, OpsValue, V + + +if TYPE_CHECKING: + from .graph import GraphLowering + +_T = TypeVar("_T") +_U = TypeVar("_U") +_V = TypeVar("_V") + +_IntLike: TypeAlias = Union[int, Expr] + +log = logging.getLogger(__name__) +indent = functools.partial(textwrap.indent, prefix=" ") +aten = torch.ops.aten + +""" [Note: Inductor IR] + +Inductor's IR is produced by executing 'lowering' code (see lowering.py). Each +lowering is registered to a particular aten operator, and expects inputs that +correspond to the aten schema. 
However, in place of torch Tensor inputs, lowerings +expect Inductor TensorBox inputs. + +TensorBox IR represents torch tensors. Tensors are sometimes single objects owning +storage, and sometimes views of another Tensor's storage. Mutating tensor operations +(such as add_()) affect the underlying storage and any associated views. Other operations +(such as .t_()) update metadata about the current view but don't modify the underlying storage. + +To model this in Inductor, the IR distinguishes between TensorBox, View, StorageBox and Buffer. + +TensorBox is the top level IR construct that any lowering should produce and maps to a torch.Tensor +output from an operation. But just as torch.Tensors take different forms, TensorBox IR can +reference View IR or directly reference StorageBox IRs. + +Some Inductor lowerings produce new sets of 'Box'es, while others (such as .t() or other view ops) +may take an existing TensorBox and point it to a new underlying View IR. + +Tensors that directly own storage are represented as a chain of: +TensorBox -> StorageBox -> Buffer +where Buffer is a simple (1D) allocation, and StorageBox introduces the concept of a Layout. + +If you mutate the data of such a tensor, we swing the StorageBox pointer to point to a new buffer +(leaving the old buffer unmodified and functionalizing the operation). + +Tensors backed by views add one more indirection to the IR. +TensorBox -> View -> StorageBox -> Buffer +In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox. + +Computation is represented by Operation nodes, with each operation producing 1 +or more output Buffers. In the case of mutations, these will be new Buffers that have the +mutated buffer listed in its get_mutation_names(). + +It is also possible to have an InputBuffer for which there is no corresponding Operation, +e.g. it may be a graph input or compile time constant. + +""" + + +_NodeOrNodes: TypeAlias = Union[ + int, + "TensorBox", + Dict[str, "TensorBox"], + "Symbol", + "IRNode", + Sequence[ + Optional[Union[int, Dict[str, "TensorBox"], "TensorBox", "Symbol", "IRNode"]] + ], +] + + +def validate_ir(node_or_nodes: Optional[_NodeOrNodes]) -> None: + def _check_tensorbox(nodes: Optional[_NodeOrNodes]) -> None: + # Could expand this to check deeper properties + # (e.g. TensorBox points to View or StorageBox) + if nodes is None: + pass + elif isinstance(nodes, (list, tuple)): + for node in nodes: + _check_tensorbox(node) + elif isinstance(nodes, dict): + for node in nodes.values(): + _check_tensorbox(node) + else: + assert isinstance( + nodes, + ( + torch._inductor.ir.ExpandView, + DynamicScalar, + AssertScalar, + TensorBox, + sympy.logic.boolalg.Boolean, + Expr, + int, + EffectfulKernel, + ), + ), f"Found {type(nodes)}, which is not a supported top level IR node. 
See [Note: Inductor IR]" + + # Be picky about the accepted data structure (don't use pytree here) + _check_tensorbox(node_or_nodes) + + +def ops_wrapper(name: str) -> Callable[..., OpsValue]: + assert isinstance(name, str) + + def fn(*args: object, **kwargs: object) -> OpsValue: + return getattr(ops, name)(*args, **kwargs) + + return fn + + +def inverse_reorder(order: Sequence[int]) -> Callable[[Sequence[_T]], Sequence[_T]]: + inv_order = dict(zip(order, range(len(order)))) + + def reindex(index: Sequence[_T]) -> Sequence[_T]: + assert len(index) == len(inv_order) + return [index[inv_order[i]] for i in range(len(index))] + + return reindex + + +def same_reorder(order: Sequence[int]) -> Callable[[Sequence[_T]], Sequence[_T]]: + def reindex(index: Sequence[_T]) -> Sequence[_T]: + assert len(index) == len(order) + return [index[order[i]] for i in range(len(index))] + + return reindex + + +def fuse_reindexing( + reindex1: Callable[[Sequence[_U]], Sequence[_V]], + reindex2: Callable[[Sequence[_T]], Sequence[_U]], +) -> Callable[[Sequence[_T]], Sequence[_V]]: + def reindex(index: Sequence[_T]) -> Sequence[_V]: + return reindex1(reindex2(index)) + + return reindex + + +NHWC_STRIDE_ORDER = [3, 0, 2, 1] +NHWDC_STRIDE_ORDER = [4, 0, 3, 2, 1] + + +def stride_order2fill_order( + order: Sequence[Union[int, Integer]] +) -> Sequence[Union[int, Integer]]: + """ + Convert stride order to fill order + For channel last format, + + stride order = [3, 0, 2, 1] and fill order = [1, 3, 2, 0] + """ + lookup = {pos: idx for idx, pos in enumerate(order)} + fill_order = [lookup[i] for i in range(len(order))] + return fill_order + + +def get_stride_order(seq: Sequence[Union[int, torch.SymInt, Expr]]) -> Sequence[int]: + """ + Convert strides to stride order + """ + sorted_idx: List[int] = argsort(seq) + out = [0 for _ in range(len(seq))] + for i, elem in enumerate(sorted_idx): + out[elem] = i + return out + + +@overload +def ir_node_to_tensor(x: Literal[None], guard_shape: bool = True) -> None: + ... + + +@overload +def ir_node_to_tensor(x: IRNode, guard_shape: bool = True) -> torch.Tensor: + ... 
+ + +def ir_node_to_tensor( + x: Optional[IRNode], guard_shape: bool = True +) -> Optional[torch.Tensor]: + if x is None: + return None + + shape_fn: Callable[[Union[int, Expr]], Union[int, Expr]] + if not guard_shape: + shape_fn = V.graph.sizevars.size_hint + else: + shape_fn = identity + size = [shape_fn(s) for s in x.get_size()] + stride: StrideType + if is_storage_and_layout(x): + stride = [shape_fn(s) for s in x.get_layout().stride] # type: ignore[misc, union-attr] + else: + stride = FlexibleLayout.contiguous_strides(size) # type: ignore[assignment] + dtype = x.get_dtype() + device = x.get_device() + size = convert_shape_to_symint(size) + stride = convert_shape_to_symint(stride) + with V.graph.sizevars.shape_env.suppress_guards(): + t = torch.empty_strided( + size=size, stride=stride, dtype=dtype, device=device + ).zero_() + return t + + +def may_convert_to_optional( + value: Optional[Sequence[_T]], +) -> Optional[Sequence[Optional[_T]]]: + if isinstance(value, list) and not value: + # [None] makes sure the cpp wrapper codegen will generate something like + # {std::nullopt} instead of {} + return [None] + return value + + +def get_device_type(x: object) -> Optional[str]: + if get_device := getattr(x, "get_device", None): + return get_device_type(get_device()) + if isinstance(x, torch.device): + return x.type + return None + + +def is_triton(x: object) -> bool: + dtype = get_device_type(x) + return bool(dtype and is_gpu(dtype)) + + +def is_cpu(x: object) -> bool: + return get_device_type(x) == "cpu" + + +class IRNode: + _current_origins: ClassVar[OrderedSet[Any]] = OrderedSet() + + @staticmethod + @contextlib.contextmanager + def current_origins(origins: OrderedSet[torch.fx.Node]): + old = IRNode._current_origins + IRNode._current_origins = old | origins + try: + yield + finally: + IRNode._current_origins = old + + def __post_init__(self): + self.origins = OrderedSet(self._current_origins) + self.traceback = traceback.format_stack() if config.debug_ir_traceback else None + + def get_read_names(self) -> OrderedSet[str]: + raise NotImplementedError(f"NYI on {type(self)}") + + def get_traceback(self): + return self.traceback + + def get_defining_op(self): + raise NotImplementedError + + def common_repr(self, shorten=True): + origins = f"origins={getattr(self, 'origins', '')}" + if shorten and len(origins) > 64: + # this can get *very* long + origins = f"{origins[:61]}..." 
+ return [origins] + + def str_helper(self, lines, shorten=True, multiline=True): + lines = lines + self.common_repr(shorten) + lines = list(map(str, lines)) + if multiline: + new_lines = indent(",\n".join(lines)) + return f"{type(self).__name__}(\n{new_lines}\n)" + else: + return f"{type(self).__name__}({lines})" + + def get_dtype(self): + return self.dtype + + def get_layout(self): + raise NotImplementedError(f"get_layout() is not implemented by {type(self)}!") + + def get_size(self): + raise NotImplementedError(f"get_size() is not implemented by {type(self)}!") + + @property + def shape(self): + return self.get_size() + + def get_numel(self): + return sympy_product(self.get_size()) + + def is_zero_elements(self): + return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) # type: ignore[arg-type] + + def realize(self): + """ + If the IRNode refers to data which has not been materialized (e.g., + it is a Pointwise/Reduction that could potentially have more + compute fused into it), realize the IRNode into physical memory, + ending the possibility of fusing into it, but allowing, e.g., multiple + users to access the data without having to recompute. + + Check StorageBox.realize for a particularly notable implementation. + + TODO(ezyang): I think, in principle, every IRNode should have an + implementation of this, and most of the time no-op is OK, but you + really do have to audit each IRNode for this, so for now, raise + an error if it's not implemented. Note that some code in graph.py + will catch this thrown error and suppress it with a warning. + """ + raise NotImplementedError(f"realize NYI on {type(self)}") + + def codegen_reference(self, writer=None): + raise NotImplementedError(f"codegen_reference NYI on {type(self)}") + + # The abstract method declarations below serve to convince mypy that all IRNode instances have these functions + # defined, while having no effect at runtime. We cannot create stub implementations here because other parts of + # the code dynamically check for defined attributes. 
+ get_device: Callable[[], torch.device] + dtype: torch.dtype + get_name: Callable[[], str] + get_reads: Callable[[], Any] + num_reads: Callable[[], int] + get_stride: Callable[[], Any] + get_storage_numel: Callable[[], Any] + has_exceeded_max_reads: Callable[[], bool] + make_loader: Callable[[], Callable[[Any], Any]] + make_indexer: Callable[[], Callable[[Any], Any]] + mark_reuse: Callable[[int], None] + realize_hint: Callable[[], None] + get_unbacked_symbol_uses: Callable[[], OrderedSet[sympy.Symbol]] + + +@dataclasses.dataclass +class Operation: + def __post_init__(self): + self.operation_name: Optional[str] = None + + def get_device(self): + raise NotImplementedError + + def get_origin_node(self): + assert hasattr(self, "origin_node") + return self.origin_node + + def get_origins(self): + assert hasattr(self, "origins") + return self.origins + + def get_operation_name(self) -> str: + assert self.operation_name is not None + return self.operation_name + + def is_extern(self): + return False + + def is_no_op(self): + return False + + def get_read_writes(self): + raise NotImplementedError + + def is_user_of(self, name): + return name in self.get_read_names() + + def get_read_names(self) -> OrderedSet[str]: + return OrderedSet(dep.name for dep in self.get_reads()) + + def get_reads(self): + return self.get_read_writes().reads + + def get_outputs(self) -> List[Buffer]: + raise NotImplementedError + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + """ + Returns the unbacked symbols which are required to be in scope in + order to successfully perform codegen for this buffer. For example, + a buffer that corresponds to an extern kernel call that takes i0 as + an argument would return {i0} here. This is used to generate necessary + dependencies that ensure we actually bind i0 in codegen before you + try to use it. + + Note that this is NOT transitive; in particular, if this buffer takes + in as input another buffer with dynamic shape (e.g., (i0,)), we will + not report it here, because you will already have a dependency + on that buffer, which will eventually have a dependency on i0 if + necessary. + """ + return OrderedSet() + + def get_workspace_size(self): + """ + Gets extra global memory size needed by this buffer. + Some algorithms (e.g. group gemm) may require extra global memory in the generated code. 
+ """ + return 0 + + +@dataclasses.dataclass +class Loops(IRNode): + device: torch.device + dtype: torch.dtype + inner_fn: Callable[..., Any] + ranges: List[Expr] + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet().union( + *(free_unbacked_symbols(e) for e in self.ranges), + self.inner_fn_free_unbacked_symbols(), + ) + + def __str__(self, names=("ranges",)): + return self.str_helper( + [ + f"'{self.device.type}'", + str(self.dtype), + self.inner_fn_str(), + ] + + [f"{name}={getattr(self, name)}" for name in names] + + [f"origin_node={self.origin_node!r}"] + ) + + def __post_init__(self): + super().__post_init__() + self.origin_node = None + + __repr__ = __str__ + + def get_device(self): + return self.device + + def get_origin_node(self): + return self.origin_node + + def get_size(self): + return self.ranges + + def get_pointwise_size(self): + return self.ranges + + def is_extern(self): + return False + + @classmethod + def create(cls, *args, **kwargs): + origin_node = kwargs.pop("origin_node", None) + tb = kwargs.pop("traceback", None) + r = cls(*args, **kwargs) + r.origin_node = origin_node + r.traceback = ( + tb or traceback.format_stack() if config.debug_ir_traceback else None + ) + return TensorBox.create(r) + + @staticmethod + def _index(ranges, prefix=SymT.INDEX): + return [ + sympy.Integer(0) if s == 1 else sympy_index_symbol_with_prefix(prefix, n) + for n, s in enumerate(ranges) + ] + + @cache_on_self + def inner_fn_opcount(self) -> OpCountResult: + opcounter = OpCounterCSE(V.MockHandler()) + with V.set_ops_handler(opcounter), patch.object( + FlexibleLayout, "allow_indexing", True + ): + self.inner_fn(*self.inner_fn_args()) + return opcounter.getvalue() + + def inner_fn_args(self): + return (self._index(self.ranges),) + + @cache_on_self + def inner_fn_str(self): + return V.KernelFormatterHandler.ir_to_string( + self.inner_fn, *self.inner_fn_args() + ) + + def has_large_inner_fn(self): + return self.inner_fn_opcount().num_ops > config.realize_opcount_threshold + + def inner_fn_free_unbacked_symbols(self): + index = self._index(self.ranges) + return extract_free_unbacked_symbols(self.inner_fn, index) + + def get_reads(self): + with patch.object(FlexibleLayout, "allow_indexing", True): + if self.get_reduction_type(): + return extract_read_writes( + self.make_loader(), + self.get_size(), + self.get_reduction_size(), + ).reads + else: + return extract_read_writes( + self.make_loader(), + self.get_size(), + ).reads + + def get_read_names(self) -> OrderedSet[str]: + return OrderedSet(self.inner_fn_opcount().read_buffers) + + def num_reads(self): + return len(self.inner_fn_opcount().read_buffers) + + def get_reduction_size(self): + raise NotImplementedError( + f"get_reduction_size() is not implemented by {type(self)}!" + ) + + def get_reduction_type(self): + raise NotImplementedError( + f"get_reduction_type() is not implemented by {type(self)}!" + ) + + def constant_to_device(self, device): + raise NotImplementedError( + f"constant_to_device() is not implemented by {type(self)}!" 
+ ) + + +def nop_loader_fn(idx: Union[Expr, Sequence[Expr]], *, dtype: torch.dtype) -> OpsValue: + if dtype.is_floating_point: + return ops.constant(float("nan"), dtype) + else: + return ops.constant(0, dtype) + + +class Pointwise(Loops): + def make_loader(self): + # Make zero-element loops into a no-op + if self.is_zero_elements(): + return partial(nop_loader_fn, dtype=self.dtype) + + return self.inner_fn + + def get_reduction_size(self): + return [] + + def get_reduction_type(self): + return None + + def store_output(self, output_name, indexer, vars): + loader = self.make_loader() + return ops.store(output_name, indexer(vars), loader(vars)) + + def constant_to_device(self, device): + """Move this to a given device. Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Pointwise(device, self.dtype, loader, self.ranges) + + +@dataclasses.dataclass +class Scatter(Pointwise): + output_indexer: Callable[[List[Expr]], Expr] + scatter_mode: Optional[str] = None + + def constant_to_device(self, device): + """Move this to a given device. Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Scatter( + device, + self.dtype, + loader, + self.ranges, + self.output_indexer, + self.scatter_mode, + ) + + def store_output(self, output_name, indexer, vars): + loader = self.make_loader() + return ops.store( + output_name, + indexer(self.output_indexer(vars)), + loader(vars), + mode=self.scatter_mode, + ) + + +REDUCTION_COMBINE_FN: Dict[str, Callable[..., OpsValue]] = { + "any": ops_wrapper("logical_or"), + "max": ops_wrapper("maximum"), + "min": ops_wrapper("minimum"), + "prod": ops_wrapper("mul"), + "sum": ops_wrapper("add"), + "xor_sum": ops_wrapper("bitwise_xor"), +} + + +def get_reduction_combine_fn( + reduction_type: str, dtype: torch.dtype, arg_break_ties_left: bool = True +) -> Callable[..., object]: + if reduction_type in REDUCTION_COMBINE_FN: + return REDUCTION_COMBINE_FN[reduction_type] + + elif reduction_type in ("argmax", "argmin"): + + def argmax_combine_fn( + a: Tuple[object, object], b: Tuple[object, object] + ) -> Tuple[OpsValue, OpsValue]: + a_value, a_index = a + b_value, b_index = b + + if reduction_type == "argmin": + mask = ops.lt(a_value, b_value) + else: + mask = ops.gt(a_value, b_value) + + equal = ops.eq(a_value, b_value) + if is_float_dtype(dtype): + a_isnan = ops.ne(a_value, a_value) + b_isnan = ops.ne(b_value, b_value) + mask = ops.logical_or(mask, ops.gt(a_isnan, b_isnan)) + equal = ops.logical_or(equal, ops.logical_and(a_isnan, b_isnan)) + + tie = ( + ops.lt(a_index, b_index) + if arg_break_ties_left + else ops.gt(a_index, b_index) + ) + mask = ops.logical_or(mask, ops.logical_and(equal, tie)) + return ( + ops.where(mask, a_value, b_value), + ops.where(mask, a_index, b_index), + ) + + return argmax_combine_fn + + elif reduction_type == "welford_combine": + + def welford_combine_fn( + a: Tuple[OpsValue, OpsValue, OpsValue], + b: Tuple[OpsValue, OpsValue, OpsValue], + ) -> Tuple[OpsValue, OpsValue, OpsValue]: + a_mean, a_m2, a_weight = a + b_mean, b_m2, b_weight = b + + delta = b_mean - a_mean + new_weight = a_weight + b_weight + w2_over_w = b_weight / new_weight + return ( + a_mean + delta * w2_over_w, + a_m2 + b_m2 + delta * delta * a_weight * w2_over_w, + new_weight, + ) + + return welford_combine_fn + + else: + raise NotImplementedError(f"unknown 
reduction_type={reduction_type}") + + +def significant_strides_equal( + strides1: Sequence[_IntLike], strides2: Sequence[_IntLike], size: Sequence[_IntLike] +) -> bool: + """ + Returns true if the strides are equal, ignoring dimensions of size 1 . + """ + non_1_indices = [ + i + for i, dim in enumerate(size) + if V.graph.sizevars.size_hint(dim, fallback=2) != 1 + ] + strides1 = [V.graph.sizevars.size_hint(strides1[i]) for i in non_1_indices] + strides2 = [V.graph.sizevars.size_hint(strides2[i]) for i in non_1_indices] + return strides1 == strides2 + + +@dataclasses.dataclass +class Reduction(Loops): + reduction_ranges: List[Expr] + reduction_type: str + # self.dtype represents the dst dtype + src_dtype: torch.dtype + reduction_hint: ReductionHint + + def __str__(self) -> str: # type: ignore[override] + return Loops.__str__( # type: ignore[call-arg] + self, names=("ranges", "reduction_ranges", "reduction_type") + ) + + def __repr__(self) -> str: # type: ignore[override] + return self.__str__() + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + return super().get_unbacked_symbol_uses() | OrderedSet().union( + *(free_unbacked_symbols(e) for e in self.reduction_ranges) + ) + + def get_reduction_size(self): + return self.reduction_ranges + + def get_reduction_type(self): + return self.reduction_type + + def store_reduction(self, output_name, indexer, vars, reduction_vars): + value = ops.reduction( + self.dtype, + self.src_dtype, + self.reduction_type, + self.inner_fn(vars, reduction_vars), + ) + return ops.store_reduction(output_name, indexer(vars), value) + + def index_length(self): + return len(self.ranges) + len(self.reduction_ranges) + + def inner_fn_args(self): + index = self._index(self.ranges) + rindex = self._index(self.reduction_ranges, SymT.RINDEX) + return (index, rindex) + + def inner_fn_free_unbacked_symbols(self): + index = self._index(self.ranges) + rindex = self._index(self.reduction_ranges, SymT.RINDEX) + return extract_free_unbacked_symbols(self.inner_fn, index, rindex) + + def constant_to_device(self, device): + """Move this to a given device. 
Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Reduction( + device, + self.dtype, + loader, + self.ranges, + self.reduction_ranges, + self.reduction_type, + self.src_dtype, + ReductionHint.DEFAULT, + ) + + @staticmethod + def num_splits( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + reduction_numel, + input_node: Optional[IRNode] = None, + ): + def _is_static(x): + return isinstance(x, (int, sympy.Integer)) + + reduction_numel_hint = V.graph.sizevars.symbolic_hint(reduction_numel) + numel_hint = V.graph.sizevars.symbolic_hint(sympy_product(ranges)) + + should_split = ( + not V.graph.has_feature(device, BackendFeature.REDUCE_TO_SINGLE_ELEMENT) + and reduction_type + not in ( + "argmax", + "argmin", + ) + and config.split_reductions + # We don't support unbacked symints + and _is_static(reduction_numel_hint) + and _is_static(numel_hint) + ) + if not should_split: + return ReductionHint.DEFAULT, 1 + + device_interface = get_interface_for_device(get_device_type(device)) # type: ignore[arg-type] # next PR + device_properties = device_interface.Worker.get_device_properties(device) + if get_device_type(device) == "xpu": + num_sm = device_properties.gpu_subslice_count + else: + # default is cuda behavior + num_sm = device_properties.multi_processor_count + + min_elements_per_thread = 32 + max_elements_per_thread = 512 + threads_per_sm = 2048 + min_elements_per_device = min_elements_per_thread * num_sm * threads_per_sm + max_elements_per_device = max_elements_per_thread * num_sm * threads_per_sm + + def inner_reduction_splits(reduction_numel_hint, numel_hint): + # do heuristics that's close to eager mode for split inner reduction + # we leak reduction autotune configs here, and will need to refactor to avoid this later + num_warps = 8 + num_threads = 32 * num_warps + if numel_hint >= 2 * num_sm: # don't split if there are enough outputs + return 1 + if reduction_numel_hint <= 8192: + return 1 + if reduction_numel_hint * numel_hint <= min_elements_per_device: + split_size = min_elements_per_thread + elif reduction_numel_hint * numel_hint < max_elements_per_device: + target_blocks = num_sm * threads_per_sm // (2 * num_threads) + blocks_per_output = (target_blocks + numel_hint - 1) // numel_hint + tmp_split_size = ( + reduction_numel_hint + num_threads * blocks_per_output - 1 + ) // (num_threads * blocks_per_output) + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - tmp_split_size)) + if abs(closest - tmp_split_size) < 30: + # prefer even splits, but never smalle than min_elements_per_thread + split_size = max(closest, min_elements_per_thread) + else: + split_size = tmp_split_size + else: + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread)) + if abs(closest - max_elements_per_thread) < 50: + # prefer even splits + split_size = closest + else: + split_size = max_elements_per_thread + return (reduction_numel_hint + split_size * num_threads - 1) // ( + split_size * num_threads + ) + + def outer_reduction_splits(reduction_numel_hint, numel_hint): + # TODO the best heuristic currently has XBLOCK (corresponding to numel_hint) 128 + # extend to even smaller number of outputs + num_warps = 8 + num_threads = num_warps * 32 + rvals_per_thread = 4 # comes from heuristics, refactor to not leak here + xvals_per_block = 128 + xblocks = 
(numel_hint + xvals_per_block - 1) // xvals_per_block + if reduction_numel_hint * numel_hint < min_elements_per_device: + split_size = min_elements_per_thread + elif reduction_numel_hint * numel_hint < max_elements_per_device: + target_blocks = num_sm * threads_per_sm // (num_threads) + target_blocks = (target_blocks + xblocks - 1) // xblocks + tmp_split_size = ( + reduction_numel_hint + rvals_per_thread * target_blocks - 1 + ) // (rvals_per_thread * target_blocks) + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - tmp_split_size)) + if abs(tmp_split_size - closest) < 20: + split_size = max(closest, min_elements_per_thread) + else: + split_size = tmp_split_size + else: + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread)) + if abs(closest - max_elements_per_thread) < 50: + # prefer even splits + split_size = closest + else: + split_size = max_elements_per_thread + + return (reduction_numel_hint + rvals_per_thread * split_size - 1) // ( + rvals_per_thread * split_size + ) + + # easy cases + if numel_hint == 1: + split = inner_reduction_splits(reduction_numel_hint, numel_hint) + if split == 1: + # No need to split. + return ReductionHint.INNER, split + if input_node is not None and isinstance(input_node, TensorBox): + new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges( + input_node + ) + if new_ranges is not None and new_reduction_ranges is not None: + extracted_numel_hint = V.graph.sizevars.symbolic_hint( + sympy_product(new_ranges + new_reduction_ranges) + ) + if reduction_numel_hint == extracted_numel_hint: + log.debug( + "Use previous IRNode's range and reduction_ranges instead of split. " + "current ranges: %s, current reduction ranges: %s, current split: %d, " + "new ranges: %s, new reduction ranges: %s", + ranges, + reduction_ranges, + split, + new_ranges, + new_reduction_ranges, + ) + # If the input_node or its dependent nodes are also Reduction nodes, + # use reduction_sizes of this node or its dependent nodes directly. 
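+ # Note: a split of -1 acts as a sentinel here; Reduction.create checks for it
+ # and routes to create_multilayer_existing_ranges, reusing the ranges extracted
+ # above instead of computing a fresh split factor.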
+ return ReductionHint.INNER, -1 + return ReductionHint.INNER, split + if ( + reduction_numel_hint <= min_elements_per_thread + or numel_hint >= num_sm * 2 * 32 + ): + return ReductionHint.DEFAULT, 1 + + r = Reduction( + device, + dst_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + src_dtype, + ReductionHint.DEFAULT, + ) + + def get_read_indices(r): + cb = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=r.get_device(), + dtype=r.get_dtype(), + size=r.get_size(), + ), + data=r, + ) + read_writes = cb.get_read_writes() + # try finding the full size producer + # TODO this will fail for something like ((1, N) * (N, 1)).sum() + # this would also possibly be wrong for producers with the different contiguity but we hope those cases are rare + range_vars = [ + r + for r in read_writes.range_vars + if isinstance(r, sympy.Expr) and not isinstance(r, sympy.Number) + ] + indices = [] + changed = False + for md in sorted(read_writes.reads, key=lambda x: x.name): + if all(r in md.index.free_symbols for r in range_vars): + indices.append(md.index) + if md.name in V.graph.name_to_buffer: + buf = V.graph.name_to_buffer[md.name] + original_stride = buf.layout.stride + buf.decide_layout() + if buf.layout.stride != original_stride: + changed = True + return indices, changed + + indices, changed = get_read_indices(r) + if changed: + indices, _ = get_read_indices(r) + + if len(indices) == 0: + # TODO determine splits when all inputs are broadcast + return ReductionHint.DEFAULT, 1 + + (_, reduction_vars), ranges = dependencies.index_vars_squeeze( + r.get_size(), r.get_reduction_size() + ) + num_outer = 0 + num_inner = 0 + for i in indices: + i = V.graph.sizevars.simplify_with_ranges(i, ranges) + strides = V.graph.sizevars.stride_hints(i, reduction_vars, ranges.keys()) + outer = all(s > 1 for s in strides) + if outer: + num_outer += 1 + else: + num_inner += 1 + if num_inner > num_outer: + return ReductionHint.INNER, inner_reduction_splits( + reduction_numel_hint, numel_hint + ) + else: + return ReductionHint.OUTER, outer_reduction_splits( + reduction_numel_hint, numel_hint + ) + + @staticmethod + def _unroll_reduction_fn(inner_fn, reduction_ranges, reduction_type, src_dtype): + """Convert inner_fn from a reduction to an pointwise""" + reduction_ranges = [ + V.graph.sizevars.evaluate_static_shape(x) for x in reduction_ranges + ] + + combine_fn = get_reduction_combine_fn(reduction_type, src_dtype) + + def fn(index): + return functools.reduce( + combine_fn, + ( + value_fn(index, rindex) + for rindex in itertools.product( + *[range(x) for x in reduction_ranges] + ) + ), + ) + + if reduction_type in ("argmin", "argmax"): + flatten_index = FixedLayout( + None, # type: ignore[arg-type] + None, # type: ignore[arg-type] + reduction_ranges, + FlexibleLayout.contiguous_strides(reduction_ranges), + ).make_indexer() + + def value_fn(index, rindex): + rindex = [sympy.expand(i) for i in rindex] + return ( + inner_fn(index, rindex), + ops.index_expr(flatten_index(rindex), torch.int64), + ) + + return lambda index: fn(index)[1] + else: + value_fn = inner_fn + return fn + + @classmethod + def create( # type: ignore[override] + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., Any], + ranges: List[Expr], + reduction_ranges: List[Expr], + reduction_type: str, + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + input_node: Optional[IRNode] = None, + ): + reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) + + 
if reduction_numel == 0: + # N.B. This is a hack to generate the literal of the given type + # Ideally, we should be fixing `def constant` in triton.py + # but it breaks due to hardcoded dtypes in other places + def py_cnst(val): + return ( + bool(val) + if dst_dtype == torch.bool + else float(val) + if dst_dtype.is_floating_point + else int(val) + ) + + rtypes_to_inits = { + "sum": py_cnst(0), + "xor_sum": py_cnst(0), + "prod": py_cnst(1), + "any": py_cnst(0), + # "all" is desugared to `!any(!val)` + } + + assert ( + reduction_type in rtypes_to_inits.keys() + ), f"{reduction_type} not supported for zero-dimension tensors!" + + def const_fn(index): + return ops.constant(rtypes_to_inits[reduction_type], dst_dtype) + + return Pointwise.create( + device=device, + dtype=src_dtype, + inner_fn=const_fn, + ranges=list(ranges), + ) + + if reduction_numel == 1: + # this reduction is actually a pointwise op + if reduction_type in ("argmin", "argmax"): + + def fn(index): + return ops.constant(0, dst_dtype) + + else: + + def fn(index): + reduction_index = [sympy.Integer(0) for _ in reduction_ranges] + return inner_fn(index, reduction_index) + + return Pointwise.create(device, dst_dtype, fn, ranges) + + if ( + isinstance(reduction_numel, sympy.Integer) + and V.graph.sizevars.size_hint(reduction_numel) + < config.unroll_reductions_threshold + and sympy_product(ranges) != 1 + ): + return Pointwise.create( + device, + dst_dtype, + cls._unroll_reduction_fn( + inner_fn, reduction_ranges, reduction_type, src_dtype + ), + ranges, + ) + + # triton doesn't support reduce to single element well, so break it up + hint, split = cls.num_splits( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + reduction_numel, + input_node, + ) + # intermediate reduction in split can contain complex indexing, + # and num_splits will fail to correctly set the hint + # reuse the passed hint if available + if reduction_hint == ReductionHint.DEFAULT: + reduction_hint = hint + if split == -1: + assert input_node is not None + new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges( + input_node # type: ignore[arg-type] + ) + assert new_ranges is not None + assert new_reduction_ranges is not None + return cls.create_multilayer_existing_ranges( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + new_ranges, + new_reduction_ranges, + reduction_type, + reduction_hint, + ) + elif split > 1: + # triton doesn't support reduce to single element well, so break it up + return cls.create_multilayer( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + split, + reduction_hint, + ) + + return TensorBox.create( + Reduction( + device, + dst_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + src_dtype, + reduction_hint, + ) + ) + + @staticmethod + def default_accumulator(reduction_type, dtype): + if reduction_type in ("max", "argmax"): + if is_float_dtype(dtype): + return float("-inf") + elif is_boolean_dtype(dtype): + return 0 + else: + return torch.iinfo(dtype).min + if reduction_type in ("min", "argmin"): + if is_float_dtype(dtype): + return float("inf") + elif is_boolean_dtype(dtype): + return 1 + else: + return torch.iinfo(dtype).max + + return { + "sum": 0, + "prod": 1, + "xor_sum": 0, + "any": 0, + "welford_reduce": (0, 0, 0), + "welford_combine": (0, 0, 0), + }[reduction_type] + + @staticmethod + def default_value(reduction_type, dtype): + if reduction_type == "welford_reduce": + return 0 + 
return Reduction.default_accumulator(reduction_type, dtype) + + @staticmethod + def _multilayer_second_step_hint( + split: int, numel_hint: int, reduction_hint: ReductionHint + ) -> ReductionHint: + if split == -1: + return reduction_hint + if split <= 512 and numel_hint <= 512 and reduction_hint == ReductionHint.OUTER: + return ReductionHint.OUTER_TINY + if ( + split <= 1024 + and numel_hint <= 256 + and reduction_hint == ReductionHint.OUTER + ): + return ReductionHint.OUTER_TINY + + return reduction_hint + + @classmethod + def _multilayer_wrap_loader( + cls, + loader, + reduction_ranges, + reduction_numel, + split, + block_size, + default, + ): + reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel]) + need_mask = not V.graph.sizevars.is_expr_static_and_true( + sympy.Eq(reduction_numel % split, 0) # type: ignore[arg-type] + ) + + def wrapper_fn(index, reduction_index): + (reduction_index,) = reduction_index + *new_index, reduction_block = index + indices = block_size * reduction_block + reduction_index + + def body(): + return loader(new_index, reindex([indices])) + + if need_mask: + mask = ops.lt( + ops.index_expr(indices, torch.int32), + ops.index_expr(reduction_numel, torch.int32), + ) + return ops.masked(mask, body, default) + else: + return body() + + return wrapper_fn + + @classmethod + def _multilayer_wrap_loader_existing_ranges( + cls, + loader, + original_ranges, + original_reduction_ranges, + new_ranges, + new_reduction_ranges, + default, + ): + assert all( + r == 1 for r in original_ranges + ), f"Only enabled for numel_hint == 1, found {original_ranges=}" + reindex = View.dynamic_reshape_indexer( + original_reduction_ranges, tuple(new_ranges) + tuple(new_reduction_ranges) + ) + + def wrapper_fn(merged_index, new_reduction_index): + original_idx = merged_index[: len(original_ranges)] + new_index = merged_index[len(original_ranges) :] + return loader( + original_idx, + reindex(tuple(new_index) + tuple(new_reduction_index)), + ) + + return wrapper_fn + + @classmethod + def create_multilayer_helper( + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + wrapper_fn: Callable[..., Any], + original_ranges: List[Expr], + original_reduction_ranges: List[Expr], + new_ranges: List[Expr], + new_reduction_ranges: List[Expr], + reduction_type: str, + split: int, + reduction_hint: ReductionHint, + ): + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + # triton will automatically compute reductions in fp32 if reducing over fp16/bf16 + # within the kernel. 
keep the intermediate in fp32 so as to keep the whole reduction + # in fp32 and not reduce precision by breaking up the kernel into multiple layers + intermediate_dtype = ( + dst_dtype + if dst_dtype not in (torch.float16, torch.bfloat16) + else torch.float + ) + intermediate = Reduction.create( + device, + intermediate_dtype, + src_dtype, + wrapper_fn, + new_ranges, + new_reduction_ranges, + reduction_type, + reduction_hint, + ) + intermediate.realize() + intermediate_loader = intermediate.make_loader() + + def intermediate_fn(index, reduction_index): + return intermediate_loader([*index, *reduction_index]) + + numel_hint = V.graph.sizevars.size_hint(sympy_product(original_ranges)) + reduction_hint = cls._multilayer_second_step_hint( + split, numel_hint, reduction_hint + ) + + assert original_ranges == new_ranges[: len(original_ranges)] + return TensorBox.create( + Reduction( + device, + dst_dtype, + intermediate_fn, + original_ranges, + new_ranges[len(original_ranges) :], + reduction_type, + src_dtype, + reduction_hint, + ) + ) + + @classmethod + def create_multilayer( + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., Any], + ranges: List[Expr], + reduction_ranges: List[Expr], + reduction_type: str, + split: int, + reduction_hint: ReductionHint, + ): + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + # TODO(jansel): realize the reduction so we can do dynamic indexing + reduction_numel = sympy_product(reduction_ranges) + block_size = FloorDiv(reduction_numel + (split - 1), split) + default = cls.default_value(reduction_type, dst_dtype) + wrapper_fn = cls._multilayer_wrap_loader( + inner_fn, reduction_ranges, reduction_numel, split, block_size, default + ) + + return cls.create_multilayer_helper( + device, + dst_dtype, + src_dtype, + wrapper_fn, + ranges, + reduction_ranges, + [*ranges, split], # type: ignore[list-item] + [block_size], + reduction_type, + split, + reduction_hint, + ) + + @classmethod + def create_multilayer_existing_ranges( + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., Any], + original_ranges: List[Expr], + original_reduction_ranges: List[Expr], + new_ranges: List[Expr], + new_reduction_ranges: List[Expr], + reduction_type: str, + reduction_hint: ReductionHint, + ): + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + default = cls.default_value(reduction_type, dst_dtype) + wrapper_fn = cls._multilayer_wrap_loader_existing_ranges( + inner_fn, + original_ranges, + original_reduction_ranges, + new_ranges, + new_reduction_ranges, + default, + ) + return cls.create_multilayer_helper( + device, + dst_dtype, + src_dtype, + wrapper_fn, + original_ranges, + original_reduction_ranges, + [*original_ranges, *new_ranges], + new_reduction_ranges, + reduction_type, + -1, + reduction_hint, + ) + + +class WelfordReduction(Reduction): + output_index: int + + def __init__( + self, + device, + dtype, + inner_fns, + ranges, + reduction_ranges, + reduction_type, + reduction_hint, + output_index, + ): + if len(inner_fns) == 1: + loader = inner_fns[0] + else: + + def loader(idx, reduction_idx): + return tuple(fn(idx, reduction_idx) for fn in inner_fns) + + super().__init__( + device, + dtype, + loader, + ranges, + reduction_ranges, + reduction_type, + dtype, + reduction_hint, + ) + self.output_index = output_index + + def store_reduction(self, output_name, indexer, vars, reduction_vars): + values = 
ops.reduction( + self.dtype, + self.src_dtype, + self.reduction_type, + self.inner_fn(vars, reduction_vars), + ) + value = values[self.output_index] + return ops.store_reduction(output_name, indexer(vars), value) + + @classmethod + def create( # type: ignore[override] + cls, + device: torch.device, + dtype: torch.dtype, + inner_fns: Sequence[Callable[..., Any]], + ranges: List[Expr], + reduction_ranges: List[Expr], + reduction_type: str, + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + ): + assert reduction_type in ("welford_reduce", "welford_combine") + + reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) + + def const(val): + def inner_fn(idx): + return ops.constant( + val, + dtype, + ) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(ranges), + ) + + if reduction_numel == 0: + mean = const(0) + m2 = const(0) + weight = const(0) + return mean, m2, weight + + if reduction_numel == 1: + + def copy(loader): + def inner_fn(idx): + reduction_index = [sympy.Integer(0) for _ in reduction_ranges] + return loader(idx, reduction_index) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(ranges), + ) + + if reduction_type == "welford_reduce": + return copy(inner_fns[0]), const(0), const(1) + else: + return tuple(copy(fn) for fn in inner_fns) + + # TODO: Unrolled reduction + # if ( + # isinstance(reduction_numel, sympy.Integer) + # and V.graph.sizevars.size_hint(reduction_numel) + # < config.unroll_reductions_threshold + # and sympy_product(ranges) != 1 + # ): + # return Pointwise.create( + # device, + # dst_dtype, + # cls._unroll_reduction_fn( + # inner_fn, reduction_ranges, reduction_type, src_dtype + # ), + # ranges, + # ) + + # triton doesn't support reduce to single element well, so break it up + hint, split = Reduction.num_splits( + device, + dtype, + dtype, + inner_fns[0], + ranges, + reduction_ranges, + reduction_type=reduction_type, + reduction_numel=reduction_numel, + ) + # intermediate reduction in split can contain complex indexing, + # and num_splits will fail to correctly set the hint + # reuse the passed hint if available + if reduction_hint == ReductionHint.DEFAULT: + reduction_hint = hint + if split > 1: + # triton doesn't support reduce to single element well, so break it up + return cls.create_multilayer( + device, + dtype, + inner_fns, + ranges, + reduction_ranges, + reduction_type, + split, + reduction_hint, + ) + + results = [ + TensorBox.create( + WelfordReduction( + device, + dtype, + inner_fns, + ranges, + reduction_ranges, + reduction_type, + reduction_hint, + output_idx, + ) + ) + for output_idx in range(3) + ] + for t in results: + t.realize() + return results + + @staticmethod + def default_value(reduction_type, dtype): + return (0, 0, 0) + + @classmethod + def create_multilayer( # type: ignore[override] + cls, + device: torch.device, + dtype: torch.dtype, + inner_fns: Sequence[Callable[..., Any]], + ranges: List[Expr], + reduction_ranges: List[Expr], + reduction_type: str, + split: int, + reduction_hint: ReductionHint, + ): + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + reduction_numel = sympy_product(reduction_ranges) + need_mask = not V.graph.sizevars.is_expr_static_and_true( + sympy.Eq(reduction_numel % split, 0) # type: ignore[arg-type] + ) + + if need_mask and reduction_type != "welford_combine": + # If we need mask, then "welford_reduce" doesn't work because + # masked inputs shouldn't count 
towards the welford weight + + def constant(idx, reduction_idx, value): + return ops.constant(value, dtype) + + return cls.create_multilayer( + device=device, + dtype=dtype, + inner_fns=( + inner_fns[0], + partial(constant, value=0), + partial(constant, value=1), + ), + ranges=ranges, + reduction_ranges=reduction_ranges, + reduction_type="welford_combine", + split=split, + reduction_hint=reduction_hint, + ) + + block_size = FloorDiv(reduction_numel + (split - 1), split) + intermediates = WelfordReduction.create( + device, + dtype, + tuple( + cls._multilayer_wrap_loader( + loader, + reduction_ranges, + reduction_numel, + split, + block_size, + default=0, + ) + for loader in inner_fns + ), + [*ranges, split], # type: ignore[list-item] + [block_size], + reduction_type, + reduction_hint, + ) + for i in intermediates: + i.realize() + + i_loaders = [i.make_loader() for i in intermediates] + + def intermediate_loader_fn(index, reduction_index, loader): + return loader([*index, *reduction_index]) + + numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges)) + reduction_hint = cls._multilayer_second_step_hint( + split, numel_hint, reduction_hint + ) + return WelfordReduction.create( + device, + dtype, + tuple( + partial(intermediate_loader_fn, loader=i.make_loader()) + for i in intermediates + ), + ranges, + [split], # type: ignore[list-item] + # welford_reduce turns one input into three outputs, which are combined with welford_combine + "welford_combine", + reduction_hint, + ) + + +@dataclasses.dataclass +class Scan(Loops): + scan_ranges: List[Expr] + size: List[Expr] + combine_fn: Callable[[Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...]] + reindex: Callable[[List[Expr], List[Expr]], List[Expr]] + reduction_hint: ReductionHint + output_index: int + # output_index indexes the following tuples + dtypes: Tuple[torch.dtype, ...] + inner_fns: Tuple[Callable[..., Any], ...] + + # HACK we mimick reduction + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + # TODO: Can combine_fn/reindex close over unbacked symbols? 
If so, we + # need to explicitly represent the closure so we can pull out unbacked + # symbols here + return ( + super().get_unbacked_symbol_uses() + | OrderedSet().union(*(free_unbacked_symbols(e) for e in self.scan_ranges)) + | OrderedSet().union(*(free_unbacked_symbols(e) for e in self.size)) + ) + + def __post_init__(self): + assert len(self.ranges) + len(self.scan_ranges) == len(self.size) + super().__post_init__() + + def store_reduction(self, output_name, indexer, vars, scan_vars): + idx = self.reindex(vars, scan_vars) + values = [inner_fn(idx) for inner_fn in self.inner_fns] + result = ops.scan(self.dtypes, self.combine_fn, values) + return ops.store(output_name, indexer(idx), result[self.output_index]) + + def get_reduction_type(self): + # return self.scan_op + return "custom" + + def get_reduction_size(self): + return self.scan_ranges + + def get_size(self): + return self.size + + def get_pointwise_size(self): + return self.ranges + + def index_length(self): + return len(self.ranges) + len(self.scan_ranges) + + def inner_fn_args(self): + index = self._index(self.ranges) + rindex = self._index(self.scan_ranges, SymT.RINDEX) + idx = self.reindex(index, rindex) + return (idx,) + + def inner_fn_free_unbacked_symbols(self): + index = self._index(self.ranges) + rindex = self._index(self.scan_ranges, SymT.RINDEX) + idx = self.reindex(index, rindex) + return extract_free_unbacked_symbols(self.inner_fn, idx) + + @classmethod + def create( + cls, + device: torch.device, + dtypes: Tuple[torch.dtype, ...], + inner_fns: Tuple[Callable[[List[Expr]], Any], ...], + size: List[Expr], + axis: int, + combine_fn: Callable[[Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...]], + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + *, + # Whether we have the option to fallback to aten + can_fallback_to_aten: bool = True, + **kwargs, + ) -> List[Optional[TensorBox]]: + pointwise_ranges = [*size[:axis], *size[axis + 1 :]] + scan_ranges = [size[axis]] + + if not V.graph.has_feature(device, BackendFeature.SCAN): + return [None] * len(dtypes) + + if len(dtypes) > 1 and not V.graph.has_feature( + device, BackendFeature.TUPLE_REDUCTION + ): + return [None] * len(dtypes) + + sizevars = V.graph.sizevars + scan_numel = sizevars.simplify(sympy_product(scan_ranges)) + + assert len(dtypes) == len(inner_fns) + + # Scan with a single element is just a copy + if sizevars.is_expr_static_and_true(sympy.Le(scan_numel, 1)): # type: ignore[arg-type] + return [ + Pointwise.create( + device=device, + dtype=dtypes[output_index], + inner_fn=inner_fns[output_index], + ranges=size, + ) + for output_index in range(len(dtypes)) + ] + + reduction_hint, num_splits = cls.num_splits( + device=device, + dtype=dtypes[0], + inner_fn=inner_fns[0], + axis=axis, + pointwise_ranges=pointwise_ranges, + scan_ranges=scan_ranges, + combine_fn=combine_fn, + scan_numel=scan_numel, + ) + scan_type = Scan + if num_splits > 1: + supports_split = torch.version.hip is None and len(dtypes) == 1 + if not supports_split: + if can_fallback_to_aten: + # Fallback to ATen + return [None] * len(dtypes) + else: + num_splits = 1 + else: + scan_type = SplitScan + + def reindex(index, scan_index): + assert len(scan_index) == len(scan_ranges) + assert len(index) == len(pointwise_ranges) + return [*index[:axis], *scan_index, *index[axis:]] + + results = [ + TensorBox.create( + scan_type( + device=device, + dtype=dtypes[output_index], + dtypes=dtypes, + inner_fn=inner_fns[output_index], + inner_fns=inner_fns, + size=size, + ranges=pointwise_ranges, + 
scan_ranges=scan_ranges, + combine_fn=combine_fn, + reindex=reindex, + reduction_hint=reduction_hint, + output_index=output_index, + **kwargs, + ) + ) + for output_index in range(len(dtypes)) + ] + + for result in results: + result.realize() + + return results + + @classmethod + def num_splits( + cls, + device: torch.device, + dtype: torch.dtype, + inner_fn: Callable[[List[Expr]], Any], + axis: int, + pointwise_ranges: List[Expr], + scan_ranges: List[Expr], + combine_fn: Callable[[Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...]], + scan_numel: Expr, + ): + # TODO: custom splitting heuristic for scan + def wrapper_fn(idx, reduction_idx): + return inner_fn([*idx[:axis], *reduction_idx, *idx[axis:]]) + + return Reduction.num_splits( + device=device, + dst_dtype=dtype, + src_dtype=dtype, + inner_fn=wrapper_fn, + ranges=pointwise_ranges, + reduction_ranges=scan_ranges, + reduction_type="sum", + reduction_numel=scan_numel, + ) + + +# This signifies a scan op that should go through TritonSplitScanKernel codegen on CUDA. +@dataclasses.dataclass +class SplitScan(Scan): + pass + + +@dataclasses.dataclass +class Sort(Loops): + # Sorts a tuple of key, value pairs + sort_ranges: List[Expr] + size: List[Expr] + reindex: Callable[[List[Expr], List[Expr]], List[Expr]] + reduction_hint: ReductionHint + output_index: int + # output_index indexes the following tuples + dtypes: Tuple[torch.dtype, ...] + inner_fns: Tuple[Callable[..., Any], ...] + + stable: bool + descending: bool + + # HACK we mimick reduction + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + return ( + super().get_unbacked_symbol_uses() + | OrderedSet().union(*(free_unbacked_symbols(e) for e in self.sort_ranges)) + | OrderedSet().union(*(free_unbacked_symbols(e) for e in self.size)) + ) + + def __post_init__(self): + assert len(self.ranges) + len(self.sort_ranges) == len(self.size) + super().__post_init__() + + def store_reduction(self, output_name, indexer, vars, sort_vars): + idx = self.reindex(vars, sort_vars) + values = [inner_fn(idx) for inner_fn in self.inner_fns] + result = ops.sort(self.dtypes, values, self.stable, self.descending) + return ops.store(output_name, indexer(idx), result[self.output_index]) + + def get_reduction_type(self): + return "sort" + + def get_reduction_size(self): + return self.sort_ranges + + def get_size(self): + return self.size + + def get_pointwise_size(self): + return self.ranges + + def index_length(self): + return len(self.ranges) + len(self.sort_ranges) + + def inner_fn_args(self): + index = self._index(self.ranges) + rindex = self._index(self.sort_ranges, SymT.RINDEX) + idx = self.reindex(index, rindex) + return (idx,) + + def inner_fn_free_unbacked_symbols(self): + index = self._index(self.ranges) + rindex = self._index(self.sort_ranges, SymT.RINDEX) + idx = self.reindex(index, rindex) + return extract_free_unbacked_symbols(self.inner_fn, idx) + + @classmethod + def create( + cls, + device: torch.device, + dtypes: Tuple[torch.dtype, ...], + inner_fns: Tuple[Callable[[List[Expr]], Any], ...], + size: List[Expr], + axis: int, + stable: bool, + descending: bool, + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + **kwargs, + ) -> List[Optional[TensorBox]]: + pointwise_ranges = [*size[:axis], *size[axis + 1 :]] + sort_ranges = [size[axis]] + + if not V.graph.has_feature(device, BackendFeature.SORT): + return [None] * len(dtypes) + + sizevars = V.graph.sizevars + sort_numel = sizevars.simplify(sympy_product(sort_ranges)) + + # Heuristic, smallest rblock where triton usually 
outperforms aten.sort + # It also isn't bandwidth bound so fusion is unlikely to help. + max_rblock = 512 + is_persistent_kernel = ( + config.triton.persistent_reductions + and sizevars.is_expr_static_and_true(sympy.Le(sort_numel, max_rblock)) + ) + if not is_persistent_kernel: + # We only support persistent triton kernels + return [None] * len(dtypes) + + assert len(dtypes) == len(inner_fns) + + # Sort with a single element is just a copy + if sizevars.is_expr_static_and_true(sympy.Le(sort_numel, 1)): # type: ignore[arg-type] + return [ + Pointwise.create( + device=device, + dtype=dtypes[output_index], + inner_fn=inner_fns[output_index], + ranges=size, + ) + for output_index in range(len(dtypes)) + ] + + def reindex(index, sort_index): + assert len(sort_index) == len(sort_ranges) + assert len(index) == len(pointwise_ranges) + return [*index[:axis], *sort_index, *index[axis:]] + + results = [ + TensorBox.create( + Sort( + device=device, + dtype=dtypes[output_index], + dtypes=dtypes, + inner_fn=inner_fns[output_index], + inner_fns=inner_fns, + size=size, + ranges=pointwise_ranges, + sort_ranges=sort_ranges, + reindex=reindex, + reduction_hint=reduction_hint, + output_index=output_index, + stable=stable, + descending=descending, + **kwargs, + ) + ) + for output_index in range(len(dtypes)) + ] + + for result in results: + result.realize() + + return results + + +def is_storage_and_layout(x: IRNode) -> bool: + try: + as_storage_and_layout(x, freeze=False) + return True + except NotImplementedError: + return False + + +def is_contiguous_storage_and_layout(x: IRNode) -> bool: + try: + buffer, layout = as_storage_and_layout(x, freeze=False) + # pad the stride here so we will NOT claim an tensor as contiguous + # if a padding is gonna happen. + if layout.should_pad_strides(): + layout.pad_strides() + return layout.is_contiguous() + except NotImplementedError: + return False + + +def as_storage_and_layout( + x: IRNode, + freeze: bool = True, + want_contiguous: bool = False, + stride_order: Optional[Sequence[Union[int, Integer]]] = None, + allow_padding: bool = False, + exact_strides: Optional[Sequence[Union[int, Integer]]] = None, +) -> Tuple[StorageBox, Layout]: + """ + Try to simplify x into a StorageBox and a Layout. + + allow_padding only affect how we apply stride_order. When allow_padding + is True, we have the freedom to add padding when applying the stride_order. 
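+
+ A minimal usage sketch (illustrative): for an IR node x,
+ as_storage_and_layout(x, freeze=False) only inspects the current layout, while
+ as_contiguous_storage_and_layout(x) (a partial with want_contiguous=True)
+ freezes the underlying buffer's layout to contiguous strides. Both return a
+ (StorageBox, Layout) pair, or raise NotImplementedError when x cannot be
+ reduced to storage plus layout.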
+ """ + if isinstance(x, TensorBox): + return as_storage_and_layout( + x.data, + freeze=freeze, + want_contiguous=want_contiguous, + stride_order=stride_order, + allow_padding=allow_padding, + exact_strides=exact_strides, + ) + if isinstance(x, StorageBox) and isinstance(x.data, Buffer): + if freeze: + if want_contiguous: + x.data.freeze_layout() + assert x.data.layout.is_contiguous() + elif stride_order is not None: + x.data.freeze_layout_with_stride_order( + stride_order, allow_padding=allow_padding + ) + elif exact_strides is not None: + x.data.freeze_layout_with_exact_strides( + exact_strides, allow_padding=allow_padding + ) + else: + x.data.decide_layout() + return x, x.data.layout + if isinstance(x, ReinterpretView): + # making the base of x contiguous or stride_ordered will not necessarily make + # the ReinterpretView either, so don't pass along those arguments + buffer, _ = as_storage_and_layout( + x.data, + freeze=freeze, + ) + return buffer, x.layout + raise NotImplementedError + + +as_contiguous_storage_and_layout = functools.partial( + as_storage_and_layout, want_contiguous=True +) + + +def is_stride_order_storage_and_layout( + x: IRNode, stride_order: Sequence[Union[int, Integer]] +) -> bool: + try: + buffer, layout = as_storage_and_layout(x, freeze=False) + return layout.is_stride_ordered(stride_order) + except NotImplementedError: + return False + + +@dataclasses.dataclass +class BaseView(IRNode): + data: IRNode + + def get_unbacked_symbol_uses(self): + return self.data.get_unbacked_symbol_uses() + + def make_reindexer(self): + raise NotImplementedError(f"make_reindexer NYI on {self}") + + def make_indexer(self): + inner = self.data.make_indexer() + reindex = self.make_reindexer() + + def indexer(idx): + return inner(reindex(idx)) + + return indexer + + def make_loader(self): + inner = self.data.make_loader() + reindex = self.make_reindexer() + + def loader(idx): + return inner(reindex(idx)) + + return loader + + @property + def dtype(self): + return self.data.dtype + + def get_layout(self): + return self.data.get_layout() + + def get_device(self): + return self.data.get_device() + + def get_origin_node(self): + return None + + def get_name(self): + return self.data.get_name() + + def get_pointwise_size(self): + return self.get_size() + + def mark_reuse(self, users): + return self.data.mark_reuse(users) + + def has_exceeded_max_reads(self): + return self.data.has_exceeded_max_reads() + + def realize(self): + return self.data.realize() + + def realize_hint(self): + return self.data.realize_hint() + + def get_storage_numel(self): + return self.data.get_storage_numel() + + def is_extern(self): + return self.data.is_extern() # type: ignore[attr-defined] + + def is_module_buffer(self): + return self.data.is_module_buffer() # type: ignore[attr-defined] + + def get_read_names(self) -> OrderedSet[str]: + return self.data.get_read_names() + + def get_reads(self): + with patch.object(FlexibleLayout, "allow_indexing", True): + return extract_read_writes( + self.make_loader(), + self.get_size(), + ).reads + + def unwrap_view(self): + x: IRNode = self + while isinstance(x, BaseView): + x = x.data + return x + + def constant_to_device(self, device): + """Move this to a given device. 
Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Pointwise(device, self.get_dtype(), loader, self.get_size()) + + +@dataclasses.dataclass +class ExpandView(BaseView): + size: List[Expr] + + @staticmethod + def _normalize_size(x, new_size): + """Replace `-1` with correct sizes""" + sizevars = V.graph.sizevars + new_size = list(map(sympy.expand, new_size)) + old_size = x.get_size() + old_size = [None] * (len(new_size) - len(old_size)) + list(old_size) + assert len(new_size) == len(old_size) + for i in range(len(new_size)): + if new_size[i] == -1: + assert old_size[i] is not None + new_size[i] = old_size[i] + elif old_size[i] is None or old_size[i] == 1: + pass + else: + # Sanity check: Expect broadcast compatibility + # + # NB: new_size[i] == old_size[i] is expected to already be + # guarded because the meta formula was expected to have taught + # us this equality. + assert ( + sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0 + ), "Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}" + return new_size + + @classmethod + def create(cls, x, new_size): + new_size = cls._normalize_size(x, new_size) + + if is_storage_and_layout(x): + storage, old_layout = as_storage_and_layout(x) + skip = len(new_size) - len(old_layout.size) + assert skip >= 0 + new_stride = [sympy.Integer(0)] * skip + for stride, size in zip(old_layout.stride, old_layout.size): + new_stride.append(stride if size != 1 else sympy.Integer(0)) + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + list(new_size), + new_stride, + old_layout.offset, + ) + return ReinterpretView(storage, new_layout) + + return ExpandView(x, new_size) + + def get_size(self): + return self.size + + def make_reindexer(self): + target = self.get_size() + actual = self.data.get_size() + skip = len(target) - len(actual) + + def reindex(index): + index = list(index[skip:]) + assert len(index) == len(actual) + for i in range(len(actual)): + if actual[i] == 1: + # zero out broadcast dimension + index[i] = sympy.Integer(0) + return index + + return reindex + + +@dataclasses.dataclass +class PermuteView(BaseView): + dims: List[Expr] + + @classmethod + def create(cls, x, dims): + dims = cls._map_neg_dims(dims) + assert OrderedSet(dims) == OrderedSet(range(len(dims))) + + if is_storage_and_layout(x): + storage, old_layout = as_storage_and_layout(x) + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + [old_layout.size[i] for i in dims], + [old_layout.stride[i] for i in dims], + old_layout.offset, + ) + return ReinterpretView(storage, new_layout) + + return PermuteView(x, dims) + + @classmethod + def _map_neg_dims(cls, dims): + return [dim if dim >= 0 else len(dims) + dim for dim in dims] + + def get_size(self): + assert OrderedSet(self._map_neg_dims(self.dims)) == OrderedSet( + range(len(self.dims)) + ) + size = self.data.get_size() + return [size[i] for i in self.dims] + + def make_reindexer(self): + inv = {j: i for i, j in enumerate(self.dims)} + inv = [inv[i] for i in range(len(self.dims))] # type: ignore[index] + assert OrderedSet(inv) == OrderedSet(range(len(self.dims))) + + def reindex(index): + return [index[i] for i in inv] + + return reindex + + +class SqueezeView(BaseView): + @classmethod + def create(cls, x, *, dim=None): + if is_storage_and_layout(x): + storage, old_layout = as_storage_and_layout(x) + new_size = [] + new_stride = [] + if dim is not None: + assert 
isinstance(dim, int), "expected integer dim argument" + assert 0 <= dim and dim < len(old_layout.size) + + for i, (size, stride) in enumerate(zip(old_layout.size, old_layout.stride)): + if dim is None: + if size != 1: + new_size.append(size) + new_stride.append(stride) + else: + if i != dim: + new_size.append(size) + new_stride.append(stride) + else: + assert size == 1, "expected squeezed size to be 1" + + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + new_stride, + old_layout.offset, + ) + return ReinterpretView(storage, new_layout) + + if dim is None: + # redirect to a generic view + return View.create(x, [s for s in x.get_size() if s != 1]) + else: + assert x.get_size()[dim] == 1 + return View.create(x, [s for i, s in enumerate(x.get_size()) if i != dim]) + + @staticmethod + def squeezer(size: Tuple[sympy.Expr, ...]): + new_size = [s for s in size if s != 1] + not_one = [i for i, s in enumerate(size) if s != 1] + length = len(size) + + def reindex(index: List[sympy.Expr]) -> Tuple[sympy.Expr, ...]: + assert len(index) == len(not_one), f"{index} {not_one}" + new_index = [sympy.Integer(0)] * length + for idx, s in zip(not_one, index): + new_index[idx] = s + return tuple(new_index) + + return new_size, reindex + + def __init__(self, data): + raise AssertionError("use SqueezeView.create()") + + +@dataclasses.dataclass +class GenericView(BaseView): + size: List[Expr] + reindex: Callable[..., Any] + + def make_reindexer(self): + return self.reindex + + def reindex_str(self): + index_old = [ + sympy_index_symbol_with_prefix(SymT.INDEX, n) for n in range(len(self.size)) + ] + index_new = list(self.reindex(index_old)) + return f"lambda {', '.join(map(str, index_old))}: {index_new}" + + def __str__(self) -> str: + return self.str_helper( + [self.data, f"size={self.size}", f"reindex={self.reindex_str()}"] + ) + + __repr__ = __str__ + + @classmethod + def create(cls, x, new_size, reindex): + return cls(x, list(new_size), reindex) + + def get_size(self): + return self.size + + +@dataclasses.dataclass +class View(GenericView): + @staticmethod + def handle_negative_index(idx, size): + idx = sympy.expand(idx) + size = sympy.expand(size) + evaluate_expr = V.graph.sizevars.shape_env.evaluate_expr + if evaluate_expr(sympy.Lt(idx, 0)): + idx = idx + size + return idx + + @classmethod + def create(cls, x, new_size): + assert isinstance(new_size, (tuple, list)) + old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size) + + # Skip pointless views + if V.graph.sizevars.statically_known_list_equals(old_size, new_size): + return x + + unbacked_symbols_in_sizes = False + if ( + len(free_unbacked_symbols(old_size)) > 0 + or len(free_unbacked_symbols(new_size)) > 0 + ): + unbacked_symbols_in_sizes = True + + if 0 in new_size: + + def fake_reindex(index): + return tuple([0] * len(old_size)) + + return cls(x, list(new_size), fake_reindex) + # TODO: a new class for FixedTransferLayout that output layout is constrained by input layout + elif is_contiguous_storage_and_layout(x) or unbacked_symbols_in_sizes: + if unbacked_symbols_in_sizes and (not is_contiguous_storage_and_layout(x)): + # realize x; otherwise, the dynamic_reshape_indexer below will fail + # due to the size_hint's inability to process unbacked SymInts + x = ExternKernel.realize_input(x) + + storage, old_layout = as_contiguous_storage_and_layout(x) + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + FlexibleLayout.contiguous_strides(new_size), + old_layout.offset, + ) + return 
ReinterpretView(storage, new_layout) + + reindex = cls.dynamic_reshape_indexer(old_size, new_size) + return cls(x, list(new_size), reindex) + + @staticmethod + def resolve_negative_size(old_size, new_size): + new_size = [V.graph.sizevars.simplify(x) for x in new_size] + old_size = [V.graph.sizevars.simplify(x) for x in old_size] + + new_size = list(new_size) + for i in range(len(new_size)): + if new_size[i] == -1: + new_size[i] = sympy.Integer(1) + new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size)) + break + + V.graph.sizevars.guard_equals(sympy_product(old_size), sympy_product(new_size)) + return old_size, new_size + + @classmethod + def dynamic_reshape_indexer(cls, old_size, new_size): + try: + reindex = cls._dynamic_reshape_indexer(old_size, new_size) + except (AssertionError, IndexError): + # optimistic algorithm failed, lets do a fallback + flat = [sympy_product(old_size)] + reindex1 = cls._dynamic_reshape_indexer(old_size, flat) + reindex2 = cls._dynamic_reshape_indexer(flat, new_size) + reindex = fuse_reindexing(reindex1, reindex2) + return reindex + + @staticmethod + def _dynamic_reshape_indexer(old_size, new_size): + """ + Perform a reshape entirely by modifying indexing math + """ + size_hint = V.graph.sizevars.size_hint + # TODO: These symbols may not escape, if they don't assert so and + # treat them as temporary + vars = [ + sympy_index_symbol_with_prefix(SymT.VIEW, i) for i in range(len(new_size)) + ] + + stack_new = list(zip(vars, new_size)) + stack_old = list(old_size) + + view_expr = [] + while stack_new and stack_old: + size_old = stack_old.pop() + var, size_new = stack_new.pop() + if size_old == 1: + view_expr.append(sympy.Integer(0)) + stack_new.append((var, size_new)) # re-add + elif size_new == 1: + stack_old.append(size_old) # re-add + elif size_hint(size_new) == size_hint(size_old): + view_expr.append(var) + V.graph.sizevars.guard_equals(size_new, size_old) + elif size_hint(size_new) < size_hint(size_old): + while size_hint(size_new) < size_hint(size_old): + var2, size_new2 = stack_new.pop() + var = var2 * size_new + var + size_new = size_new * size_new2 + view_expr.append(var) + V.graph.sizevars.guard_equals(size_new, size_old) + elif size_hint(size_new) > size_hint(size_old): + divisor = sympy.Integer(1) + modulus = size_old + view_expr.append(ModularIndexing(var, divisor, modulus)) + divisor = divisor * modulus + while size_hint(size_new) > size_hint(size_old): + modulus = stack_old.pop() + view_expr.append(ModularIndexing(var, divisor, modulus)) + divisor = divisor * modulus + size_old = size_old * modulus + V.graph.sizevars.guard_equals(size_new, size_old) + else: + raise AssertionError + + while stack_old: + size_old = stack_old.pop() + V.graph.sizevars.guard_equals(size_old, 1) # type: ignore[arg-type] + view_expr.append(sympy.Integer(0)) + + while stack_new: + var, size_new = stack_new.pop() + V.graph.sizevars.guard_equals(size_new, 1) # type: ignore[arg-type] + + view_expr.reverse() + assert len(view_expr) == len(old_size) + + def reindex(index): + assert len(index) == len(vars), (len(index), len(vars)) + replacements = dict(zip(vars, index)) + return tuple(sympy_subs(x, replacements) for x in view_expr) # type: ignore[arg-type] + + return reindex + + +@dataclasses.dataclass +class ReinterpretView(BaseView): + """Pretend our storage has a different layout""" + + layout: Layout + + def __post_init__(self): + super().__post_init__() + if isinstance(self.data, BaseView): + self.data = self.data.unwrap_view() + + def __str__(self) -> str: 
+ return self.str_helper( + [ + self.data, + self.layout, + ] + ) + + __repr__ = __str__ + + def get_name(self): + return self.data.get_name() + + def get_device(self): + return self.layout.device + + def get_origin_node(self): + return None + + @property + def dtype(self): + return self.layout.dtype + + def get_size(self): + return list(self.layout.size) + + def get_stride(self): + return list(self.layout.stride) + + def make_loader(self): + def loader(index): + indexer = self.layout.make_indexer() + tmp_loader = ops.load(self.get_name(), indexer(index)) + if self.layout.dtype != self.data.dtype: + return ops.to_dtype_bitcast(tmp_loader, self.dtype, self.data.dtype) + else: + return tmp_loader + + return loader + + def make_indexer(self): + return self.layout.make_indexer() + + def get_layout(self): + return self.layout + + def freeze_layout(self): + pass + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + return ( + free_unbacked_symbols(self.layout.size) + | free_unbacked_symbols(self.layout.stride) + | free_unbacked_symbols(self.layout.offset) + ) + + def codegen_reference(self, writer=None): + # reinterpret_tensor is similar to as_strided except: + # - offset is added to the existing offset (rather than replacing it) + # - view tracking is disabled similar to unsafe_view + return V.graph.wrapper_code.codegen_reinterpret_view( + self.data, + self.layout.size, + self.layout.stride, + self.layout.offset, + writer, + dtype=self.layout.dtype, + ) + + def num_reads(self): + return 1 + + +@dataclasses.dataclass +class DtypeView(BaseView): + """Pretend our storage has a different type""" + + target_dtype: torch.dtype + + @classmethod + def create(cls, x, new_dtype): + if is_storage_and_layout(x): + storage, old_layout = as_storage_and_layout(x) + new_layout = FixedLayout( + old_layout.device, + new_dtype, + old_layout.size, + old_layout.stride, + old_layout.offset, + ) + return ReinterpretView(storage, new_layout) + return DtypeView(x, new_dtype) + + def __str__(self) -> str: + return self.str_helper([self.data, self.target_dtype]) + + __repr__ = __str__ + + @property + def dtype(self): + return self.target_dtype + + def get_size(self): + return self.data.get_size() + + def make_loader(self): + inner = self.data.make_loader() + + def loader(idx): + return ops.to_dtype_bitcast(inner(idx), self.target_dtype, self.data.dtype) + + return loader + + +class SliceView(View): + @classmethod + def normalize_start_end(cls, x, dim, start, end): + """ + Normalize start and end such that both are in the range + [0, x.get_size()[dim]] and start <= end. 
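+
+ For example (illustrative values): with x.get_size()[dim] == 10, start=-3 and
+ end=None normalize to start=7, end=10; start=4, end=100 clamps to start=4,
+ end=10.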
+ """ + sizevars = V.graph.sizevars + dim_size = x.get_size()[dim] + + if any(free_unbacked_symbols(x) for x in (start, end, dim_size)): + + def clamp(x, lower, upper): + return sympy.Min(sympy.Max(x, lower), upper) + + else: + + def clamp(x, lower, upper): + return sizevars.evaluate_min(sizevars.evaluate_max(x, lower), upper) + + def clamp_wrap(val, lower, upper, default): + if val is None: + return default + val = cls.handle_negative_index(val, dim_size) + return clamp(val, lower, upper) + + start = clamp_wrap(start, 0, dim_size, 0) + end = clamp_wrap(end, start, dim_size, dim_size) + return start, end + + @classmethod + def create(cls, x, dim, start, end, step=1, clamp=True): + step = sympy.expand(step) + assert isinstance(step, sympy.Expr) or step > 0 + try: + if start == 0 and end >= 2**63 - 1 and step == 1: + return x + except TypeError: + pass + + sizevars = V.graph.sizevars + new_size = list(x.get_size()) + + # NB: Ordinarily we default to clamping. + # We only don't clamp for split_with_sizes. For split_with_sizes, sizes should be already valid + # failing in this situation is ok, since invalid sizes could trigger silent errors. + if clamp: + start, end = cls.normalize_start_end(x, dim, start, end) + + new_size[dim] = FloorDiv(end - start + (step - 1), step) + + if is_storage_and_layout(x): + # Fast path + storage, old_layout = as_storage_and_layout(x) + new_stride = list(old_layout.stride) + new_stride[dim] = new_stride[dim] * step + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + new_stride, + old_layout.offset + old_layout.stride[dim] * start, + ) + return ReinterpretView(storage, new_layout) + + def reindex(index): + assert len(index) == len(new_size), f"wrong ndim {index} {new_size}" + index = list(index) + index[dim] = index[dim] * step + start + return index + + # redirect to a generic view + return SliceView(x, size=new_size, reindex=reindex) + + +class BaseConstant(IRNode): + dtype: torch.dtype + device: torch.device + + def get_size(self): + return () + + def get_device(self): + return self.device + + def get_origin_node(self): + return None + + def mark_reuse(self, users): + pass + + def has_exceeded_max_reads(self): + return False + + def get_reads(self): + return () + + def is_extern(self): + return False + + +@dataclasses.dataclass +class Constant(BaseConstant): + value: Any + dtype: torch.dtype + device: torch.device + + def make_loader(self): + def loader(index): + return ops.constant(self.value, self.dtype) + + return loader + + def realize(self): + pass + + def constant_to_device(self, device): + return Constant(self.value, self.dtype, device) + + +@dataclasses.dataclass +class IndexingConstant(BaseConstant): + index: Any + dtype: torch.dtype + device: torch.device + + def make_loader(self): + def loader(index): + return ops.index_expr(self.index, self.dtype) + + return loader + + def constant_to_device(self, device): + return IndexingConstant(self.index, self.dtype, device) + + +def is_contiguous_strides_for_shape( + stride: Sequence[_IntLike], shape: Sequence[_IntLike] +) -> bool: + return all( + size == 1 or left == right + for left, right, size in zip( + stride, FlexibleLayout.contiguous_strides(shape), shape + ) + ) + + +def get_align_for_dtype(dtype: torch.dtype) -> int: + return config.padding_alignment_bytes // dtype.itemsize + + +@dataclasses.dataclass +class Layout(IRNode): + def __init__( + self, + device: torch.device, + dtype: torch.dtype, + size: List[Expr], + stride: Optional[Sequence[Union[Expr, int]]], + offset: 
Expr = Integer(0), + ): + assert stride is None or len(size) == len( + stride + ), f"size={size}, stride={stride}" + self.device = device + self.dtype = dtype + assert all(isinstance(s, (Expr, int)) for s in size) + self.size = size + self._stride = stride + self.offset = offset + + @property + def stride(self): + return self._stride + + def __str__(self) -> str: + offset = "" + if self.offset != 0: + offset = f", offset={self.offset}" + return ( + f"{type(self).__name__}('{self.device.type}', {self.dtype}, " + f"size={self.size}, stride={self.stride}{offset})" + ) + + __repr__ = __str__ + + def is_contiguous(self): + return is_contiguous_strides_for_shape(self.stride, self.size) + + @staticmethod + def is_channels_last_contiguous(shape, strides): + ndim = len(shape) + if ndim not in [4, 5] or shape[1] == 1: + return False + for left, right, size in zip( + strides, make_channels_last_strides_for(shape), shape # type: ignore[arg-type] + ): + if size != 1 and left != right: + return False + return True + + def is_transposed(self): + for left, right, size in zip( + self.stride, + reversed(FlexibleLayout.contiguous_strides(list(reversed(self.size)))), + self.size, + ): + if size != 1 and left != right: + return False + return True + + def is_stride_ordered(self, order): + assert len(self.stride) == len(order) + + # ignore dimensions of size 1, they dont affect layout + non_1_indices = [ + i + for i, dim in enumerate(self.size) + if V.graph.sizevars.size_hint(dim, fallback=2) != 1 + ] + + stride = [self.stride[i] for i in non_1_indices] + order = [order[i] for i in non_1_indices] + + def sorted_indices(arr): + sorted_arr = sorted(arr) + return [sorted_arr.index(element) for element in arr] + + # since we may have removed dimensions, need to re-sort & re-index order + order = sorted_indices(order) + + # reorder the stride given order + stride_ordered = [-1] * len(order) + for i in range(len(order)): + stride_ordered[order[i]] = V.graph.sizevars.size_hint(stride[i]) + # check if it is in ascending order + for i in range(len(order) - 1): + if stride_ordered[i] > stride_ordered[i + 1]: + return False + return True + + def is_channels_last_stride_ordered(self): + # create channels_last order(NCHW, NCDHW, the C is the first order). + order = [0] + list(reversed(range(1, len(self.stride) - 1))) + order = [len(order)] + order + return self.is_stride_ordered(order) + + @staticmethod + def _pad_strides(in_strides, size, dtype): + """ + The padding does not change stride order but makes sure all strides larger + than the threshold are multiple of align. + """ + align = get_align_for_dtype(dtype) + if len(in_strides) == 0: + return in_strides + + if not config.pad_channels_last and Layout.is_channels_last_contiguous( + size, in_strides + ): + return in_strides + + current_fx_node = V.get_current_node() + if hasattr(current_fx_node, "meta") and current_fx_node.meta.get( + "dislike_padding", False + ): + return in_strides + + # get_stride_order does not work with dynamic shape. Also we can not + # statically decide if a padding is needed or how much padding we should + # do for dynamic shape. + # + # Skip padding the strides for dynamic shape for now. 
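+ # Illustrative example (assumed config values): for float32 with 128-byte
+ # alignment, get_align_for_dtype gives align = 32. A shape of [128, 100] with
+ # contiguous in_strides [100, 1] keeps the innermost stride at 1; if the outer
+ # stride 100 exceeds the padding threshold, it is rounded up to
+ # ceildiv(100, 32) * 32 = 128, giving padded strides [128, 1].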
+ if not all( + isinstance(s, (int, sympy.Integer)) + for s in itertools.chain(in_strides, size) + ): + return in_strides + + stride_order = get_stride_order(in_strides) + fill_order = stride_order2fill_order(stride_order) + + new_strides = [0 for _ in range(len(in_strides))] + # since we pad when the layout is flexible, we can decide the + # smallest stride to be 1. + new_strides[fill_order[0]] = 1 + + padded = False + for rank, idx in enumerate(fill_order[1:], start=1): + prev_idx = fill_order[rank - 1] + stride = new_strides[prev_idx] * size[prev_idx] + + if stride > config.padding_stride_threshold and stride % align != 0: + stride = ceildiv(stride, align) * align + padded = True + new_strides[idx] = stride + + if not padded: + # Consider a tensor with shape [256, 1, 5, 5] + # Avoid strides like [25, 5, 5, 1] being padded to equivalent strides + # [25, 25, 5, 1]. + return in_strides + + metrics.num_comprehensive_padding += 1 + return new_strides + + def pad_strides(self): + assert isinstance(self, FlexibleLayout) + assert self._stride is not None + self._stride = self._pad_strides(self._stride, self.size, self.dtype) + + def should_pad_strides(self): + return config.comprehensive_padding and isinstance(self, FlexibleLayout) + + def as_fixed(self): + if isinstance(self, FixedLayout): + return self + + if self.should_pad_strides(): + self.pad_strides() + return FixedLayout( + self.device, + self.dtype, + self.size, + self.stride, + self.offset, + ) + + def make_indexer(self): + assert ( + FlexibleLayout.allow_indexing + ), f"convert {type(self).__name__} to FixedLayout first" + return self.as_fixed().make_indexer() + + def __eq__(self, other) -> bool: + return ( + self.device == other.device + and self.dtype == other.dtype + and self.size == other.size + and self.stride == other.stride + and self.offset == other.offset + ) + + def storage_size(self) -> sympy.Expr: + return compute_required_storage_length(self.size, self.stride, self.offset) # type: ignore[arg-type, return-value] + + +class FixedLayout(Layout): + """A Tensor layout we cannot change""" + + def __init__( + self, + device: torch.device, + dtype: torch.dtype, + size: Union[List[Expr], List[int]], + stride: Optional[Sequence[Union[Expr, int]]] = None, + offset: Union[Expr, int] = Integer(0), + ): + if stride is None: + stride = FlexibleLayout.contiguous_strides(size) + super().__init__( + device, + dtype, + size, # type: ignore[arg-type] + stride, + offset, # type: ignore[arg-type] + ) + + def make_indexer(self): + """A closure containing math to read a given element""" + + def indexer(index): + assert len(index) == len(self.stride) + assert len(index) == len(self.size) + result = self.offset + for idx, stride, sz in zip(index, self.stride, self.size): + if sz != 1: + result = result + idx * stride + return result + + return indexer + + +class FlexibleLayout(Layout): + """A Tensor layout we are allowed to change""" + + allow_indexing = False + + # WARNING! This doesn't handle zero size tensors correctly + @staticmethod + def contiguous_strides(sizes): + if len(sizes) == 0: + return [] + reversed_strides = [sympy.Integer(1)] + for size in reversed(sizes[1:]): + reversed_strides.append(size * reversed_strides[-1]) + return list(reversed(reversed_strides)) + + @staticmethod + def fill_ordered(sizes, order): + """ + Create a stride based on the order the dimensions should be filled in. 
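+
+        For example, with sizes = [N, C, H, W] and order = [1, 3, 2, 0], the
+        C dimension gets stride 1 and N is filled last, giving strides
+        [C*H*W, 1, C*W, C].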
+ + In this format, channels last would be: + [1, 3, 2, 0] + """ + assert OrderedSet(range(len(sizes))) == OrderedSet(order), (sizes, order) + next_stride = sympy.Integer(1) + strides = [None] * len(order) + + for i in order: + strides[i] = next_stride + next_stride = next_stride * sizes[i] + return strides + + @staticmethod + def stride_ordered(sizes, order): + """ + Create a stride based on the sorted order of a permuted range. + + In this format, channels last would be: + [3, 0, 2, 1] + """ + assert OrderedSet(range(len(sizes))) == OrderedSet(order) + fill_order = stride_order2fill_order(order) + return FlexibleLayout.fill_ordered(sizes, fill_order) + + @staticmethod + def stride_ordered_for_memory_format(sizes, memory_format): + """ + Create a stride based on a memory format. + + Memory format is translated into a stride order, + so channels_last is the same as: + FlexibleLayout.stride_ordered(sizes, [3, 0, 2, 1]) + + This interface does not support memory_format `torch.preserve_format` + which should be used to deduce a format from another source + """ + if memory_format == torch.channels_last: + return FlexibleLayout.stride_ordered(sizes, NHWC_STRIDE_ORDER) + elif memory_format == torch.channels_last_3d: + return FlexibleLayout.stride_ordered(sizes, NHWDC_STRIDE_ORDER) + elif memory_format == torch.contiguous_format: + return FlexibleLayout.contiguous_strides(sizes) + else: + log.debug( + "stride_ordered_for_memory_format, unsupported memory_format: %s", + memory_format, + ) + raise NotImplementedError + + @staticmethod + def same_ordered(sizes, stride): + """ + Create a stride that has the same stride order as the given stride + + For example, if the given stride is [1000, 1, 100, 10], + the fill order should be [1, 3, 2, 0] + """ + assert len(sizes) == len(stride) + stride = [V.graph.sizevars.size_hint(x) for x in stride] + fill_order = sorted(range(len(stride)), key=stride.__getitem__) + return FlexibleLayout.fill_ordered(sizes, fill_order) + + def as_stride_order(self, order, allow_padding=False): + new_stride = self.stride_ordered(self.size, order) + if self.should_pad_strides() and allow_padding: + new_stride = self._pad_strides(new_stride, self.size, self.dtype) + + return FixedLayout( + self.device, + self.dtype, + self.size, + new_stride, + self.offset, + ) + + def as_exact_strides(self, exact_strides, allow_padding=False): + new_stride = exact_strides + if self.should_pad_strides() and allow_padding: + new_stride = self._pad_strides(new_stride, self.size, self.dtype) + + return FixedLayout( + self.device, + self.dtype, + self.size, + new_stride, + self.offset, + ) + + def as_fill_order(self, order): + new_stride = self.fill_ordered(self.size, order) + if self.should_pad_strides(): + new_stride = self._pad_strides(new_stride, self.size, self.dtype) + return FixedLayout( + self.device, + self.dtype, + self.size, + new_stride, + self.offset, + ) + + def as_same_order(self, stride): + new_stride = self.same_ordered(self.size, stride) + if self.should_pad_strides(): + new_stride = self._pad_strides(new_stride, self.size, self.dtype) + return FixedLayout( + self.device, + self.dtype, + self.size, + new_stride, + self.offset, + ) + + def __init__(self, device, dtype, size, stride_order=None): + if stride_order: + strides = FlexibleLayout.fill_ordered(size, stride_order) + else: + strides = FlexibleLayout.contiguous_strides(size) + super().__init__(device, dtype, size, strides) + + +class NonOwningLayout(Layout): + """Is a view into the storage of another tensor""" + + def __init__(self,
view: Union[BaseView, TensorBox]): + layout = view.get_layout() + super().__init__( + layout.device, + layout.dtype, + layout.size, + layout.stride, + ) + self.view = view + + def make_indexer(self): + return self.as_fixed().make_indexer() + + def maybe_guard_aligned(self): + offset = self.view.get_layout().offset + if offset == 0: + return True + from .utils import ALIGNMENT + + return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT) # type: ignore[arg-type] + + +class NoneLayout(IRNode): + # This is janky, I figured out what fields to populate by just running + # the model I was interested in and adding properties/methods as needed. + # This doesn't inherit from Layout because Layout assumes you have stuff + # like sizes, but I don't really have anything here. + # + # If you have an ir.Node with NoneLayout, you probably need to setup + # dependencies manually in scheduler + + def __init__(self, device): + self.device = device + self.size = [0] + self.stride = [0] + + def storage_size(self): + return 0 + + def as_fixed(self): + return self + + +class MutationLayoutSHOULDREMOVE(Layout): + def __init__(self, target: IRNode): + super().__init__( + target.get_device(), + target.get_dtype(), + target.get_size(), + None, + ) + self.target = target + name = self.get_buffer().get_name() + V.graph.mark_buffer_mutated(name) + + @Layout.stride.getter # type: ignore[attr-defined] + def stride(self): + return self.real_layout().stride + + def storage_size(self) -> sympy.Expr: + return self.real_layout().storage_size() + + def get_buffer(self) -> Buffer: + def unwrap_views(target): + if isinstance(target, MutationLayoutSHOULDREMOVE): + return unwrap_views(target.target) + if isinstance(target, BaseView): + return unwrap_views(target.unwrap_view()) + if isinstance(target, MutableBox): + return unwrap_views(target.data) + return target + + result = unwrap_views(self.target) + assert isinstance( + result, Buffer + ), "MutationLayoutSHOULDREMOVE must refer to a buffer" + return result + + def real_layout(self): + return self.get_buffer().layout + + @classmethod + def realize_into(cls, src, dst, unsafe_alias=False): + dst.realize() + # NOTE: We must realize users of `dst` before we realize `src`, since + # realization order determines scheduling order. Otherwise, src's + # mutation would be scheduled before the existing users of dst! + V.graph.mark_buffer_mutated(dst.get_name()) + + if isinstance(src, TensorBox): + src = src.data + + # We copy the contents of src into dst. In most cases this should + # be fused into a single kernel by the scheduler. + # NOTE: We cannot change src's layout to mutate dst directly as this + # would alias src to dst, which is not correct as further mutations to + # dst would effect users of src. However if there are no more users of + # dst, we can alias src to dst. 
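+        # Sketch of the non-aliasing path below: a Pointwise node is created
+        # that reads src via src.make_loader(), its realized buffer is given
+        # MutationLayoutSHOULDREMOVE(dst), and the store is then treated as a
+        # mutation of dst (roughly an in-place dst.copy_(src)).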
+ src.realize_hint() + + if not unsafe_alias: + src = Pointwise.create( + device=src.get_device(), + dtype=src.get_dtype(), + inner_fn=src.make_loader(), + ranges=[ + V.graph.sizevars.guard_equals(a, b) + for a, b in zip(src.get_size(), dst.get_size()) + ], + ).data + + src.realize() + assert isinstance(src.data.layout, FlexibleLayout) + src.data.layout = MutationLayoutSHOULDREMOVE(dst) + return src.data + + def as_fixed(self): + return self + + def make_indexer(self): + return self.target.make_indexer() + + +@dataclasses.dataclass +class Buffer(IRNode): + # Name is sometimes None; e.g., ForceInPlace, where there isn't + # a meaningful name + name: Optional[str] + layout: Layout + + # Multi-output buffers will define 'outputs: List[Buffer]'. Confusingly, + # MultiOutput does NOT define this! + + def __post_init__(self): + super().__post_init__() + self.origin_node = None + + def make_indexer(self): + return self.layout.make_indexer() + + def get_name(self) -> str: + assert self.name, self + return self.name + + def get_device(self): + return self.layout.device + + def get_origin_node(self): + return self.origin_node + + def get_defining_op(self) -> Optional[Operation]: + return None + + @property + def dtype(self): + return getattr(self.layout, "dtype", None) + + def get_size(self): + return list(self.layout.size) + + def get_stride(self): + return list(self.layout.stride) + + def get_offset(self): + return self.layout.offset + + def get_layout(self): + return self.layout + + def get_storage_numel(self): + return self.get_numel() + + def is_extern(self): + return False + + def freeze_layout(self): + if not isinstance(self.layout, (MultiOutputLayout, NonOwningLayout)): + self.layout = self.layout.as_fixed() + + def freeze_layout_with_stride_order(self, order, allow_padding=False): + assert isinstance(self.layout, FlexibleLayout) + self.layout = self.layout.as_stride_order(order, allow_padding=allow_padding) + + def freeze_layout_with_fill_order(self, order): + assert isinstance(self.layout, FlexibleLayout) + self.layout = self.layout.as_fill_order(order) + + def freeze_layout_with_same_order(self, stride): + assert isinstance(self.layout, FlexibleLayout) + self.layout = self.layout.as_same_order(stride) + + def freeze_layout_with_exact_strides(self, exact_strides, allow_padding=False): + assert isinstance(self.layout, FlexibleLayout) + self.layout = self.layout.as_exact_strides( + exact_strides, allow_padding=allow_padding + ) + + def is_zero_elements(self): + return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) # type: ignore[arg-type] + + def make_loader(self): + # Loading from a zero-element buffer is a no-op + if self.is_zero_elements(): + return partial(nop_loader_fn, dtype=self.get_dtype()) + + def loader(index): + indexer = self.layout.make_indexer() + return ops.load(self.name, indexer(index)) + + return loader + + def codegen_reference(self, writer=None): + return self.get_name() + + def decide_layout(self): + pass + + def get_inputs_that_alias_output(self): + if isinstance(self.layout, NonOwningLayout): + return [self.layout.view.get_name()] + return () + + def get_mutation_names(self): + if isinstance(self.layout, MutationLayoutSHOULDREMOVE): + return [self.layout.target.get_name()] + return () + + def get_read_names(self) -> OrderedSet[str]: + return OrderedSet([self.get_name()]) + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return 
OrderedSet() + + def realize(self): + pass + + def should_allocate(self): + # Returns False by default. + return False + + +@dataclasses.dataclass +class OperationBuffer(Buffer, Operation): + # An operation that produces a single output buffer + def get_outputs(self) -> List[Buffer]: + return [self] + + def get_defining_op(self) -> Operation: + return self + + def __post_init__(self): + Buffer.__post_init__(self) + Operation.__post_init__(self) + + +class InputBuffer(Buffer): + def num_reads(self): + return 1 + + +class ConstantBuffer(InputBuffer): + override_device: Optional[torch.device] = None + + def make_loader(self): + def loader(index): + indexer = self.layout.make_indexer() + return ops.load( + V.graph.constant_name(self.get_name(), self.override_device), + indexer(index), + ) + + return loader + + def constant_to_device(self, device): + return ConstantBuffer( + V.graph.constant_name(self.get_name(), device), self.layout + ) + + +class NoneAsConstantBuffer(IRNode): + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def codegen_reference(self, writer=None): + return V.graph.wrapper_code.none_str + + +class ShapeAsConstantBuffer(IRNode): + def __init__(self, shape): + super().__init__() + self._shape = shape + + @property + def shape(self): + return self._shape + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + return free_unbacked_symbols(self.shape) + + def codegen_reference(self, writer=None): + return V.graph.wrapper_code.expr_printer(V.graph.sizevars.simplify(self.shape)) + + +@dataclasses.dataclass +class ComputedBuffer(OperationBuffer): + data: Loops + + def get_computed_buffer_name(self): + """ + Returns self.name if it exists, otherwise returns the name of the data node if that exists. + If neither exist, returns None. + """ + if self.name is not None: + return self.name + if hasattr(self.data, "name"): + return self.data.name + return None + + def num_reads(self): + return self.data.num_reads() + + def get_read_names(self) -> OrderedSet[str]: + return self.data.get_read_names() + + def get_read_writes(self): + with patch.object(FlexibleLayout, "allow_indexing", True): + if self.data.get_reduction_type(): + return extract_read_writes( + self.get_store_function(), + self.data.get_pointwise_size(), + self.data.get_reduction_size(), + ) + else: + return extract_read_writes( + self.get_store_function(), + self.data.get_size(), + ) + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + # Ordinarily, we'd like to just peek at the arguments list, + # but ComputedBuffers have no argument list. + # + # Morally, this logic needs to be synchronized with the + # KernelArgs.size calls, which are responsible for making symbols make + # there way as kernel arguments (and it is precisely passing in one of + # those symbols that establishes a dependency). However, we haven't + # started codegen yet so we can't directly reuse that logic. + # + # For now, I'm just yoloing with the size of the buffer. Not sure if + # it is enough. + # + # One thing you might wonder is if this is enough for a ComputedBuffer + # denoting a reduction over i0. Empirically, it is enough, but for an + # unusual reason: we only need accurate dependencies for item() call, + # but it's impossible to end up with a reduction over i0 from an + # item() call without a regular non-reduction buffer first. 
+ return ( + free_unbacked_symbols(self.get_size()) + | free_unbacked_symbols(self.get_stride()) + | free_unbacked_symbols(self.get_offset()) + | self.data.get_unbacked_symbol_uses() + ) + + def make_loader(self): + # Inline constants and index_expressions + if ( + hasattr(self.data, "make_loader") + and self.name not in V.graph.mutated_buffers + and self.num_reads() == 0 + ): + # can be inlined + return self.data.make_loader() + return super().make_loader() + + def get_store_function(self): + indexer = self.layout.as_fixed().make_indexer() + if isinstance(self.data, (Reduction, Scan, Sort)): + return partial(self.data.store_reduction, self.name, indexer) + else: + assert isinstance(self.data, Pointwise) + return partial(self.data.store_output, self.name, indexer) + + def get_fill_order(self): + """ + If our layout is still flexible, try to determine the stride order based on stride orders of reads. + + TODO(jansel): A better algorithm here would look at downstream consumers of this + value and try to do global graph-level layout optimization. + This is also something just begging to be autotuned. + """ + if isinstance(self.layout, FlexibleLayout): + (index_vars, reduction_vars), _ = dependencies.index_vars_squeeze( + self.data.get_pointwise_size(), self.data.get_reduction_size() + ) + reads = self.get_read_writes().reads + # only consider reads to buffer of same size + # ignore StarDeps because they don't contribute stride information + assert all( + isinstance(r, (dependencies.StarDep, dependencies.MemoryDep)) + for r in reads + ) + reads = [ + sympy_subs( + r.index, {v: sympy.Integer(0) for v in reduction_vars if v != 0} + ) + for r in reads + if isinstance(r, dependencies.MemoryDep) + ] + + if reads: + if isinstance(self.data, (Scan, Sort)): + indices = self.data.reindex(index_vars, reduction_vars) + else: + indices = index_vars + stride_lengths = [ + V.graph.sizevars.stride_hints(expr, indices) for expr in reads # type: ignore[arg-type] + ] + from .scheduler import pick_loop_order + + return pick_loop_order(stride_lengths, self.get_size()) + + return None + + def decide_layout(self): + if isinstance(self.layout, FlexibleLayout): + order = self.get_fill_order() + if order: + self.freeze_layout_with_fill_order(order) + else: + self.freeze_layout() + + @cache_on_self + def get_default_sizes_body(self): + args, var_ranges = dependencies.index_vars_squeeze( + self.data.get_pointwise_size(), self.data.get_reduction_size(), prefix="q" + ) + with patch.object(ConstantBuffer, "override_device", self.get_device()): + body = LoopBody( + self.get_store_function(), + (args if self.get_reduction_type() else args[:1]), + var_ranges, + *args, + ) + index_vars = [] + reduce_vars: List[Any] = [] + index_size = [] + reduce_size = [] + for v, s in var_ranges.items(): + if v in args[0]: + assert not reduce_vars + index_vars.append(v) + index_size.append(s) + else: + assert v in args[1] + reduce_vars.append(v) + reduce_size.append(s) + return (index_size, reduce_size), body, (index_vars, reduce_vars) + + def simplify_and_reorder( + self, + extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, + ): + """ + This is a main place where we do loop transformations in a + backend-agnostic way. 
+ + Here we: + 1) Remove any 1 dimensions + 2) Fuse contiguous dimensions together + 3) Reorder dimensions based on stride orders + + Optional argument extra_indexing_constraints can be used to append additional + indexing expressions to existing ones derived from buffer's body. This can be useful + to fuse scheduler nodes with compatible ranges, e.g. (s0*s1*...,) and (s0, s1, s2, ...) + on CPU by preventing indexing simplifications and obtaining index/reduce ranges for + the scheduler node compatible with other nodes. + Optional argument recompute_sizes_body_func can be used to recompute sizes and body + on the default body. This can be useful to append additional loop transformations. + """ + ( + (index_size, reduce_size), + body, + (index_vars, reduce_vars), + ) = self.get_default_sizes_body() + + if recompute_sizes_body_func: + ( + (index_size, reduce_size), + body, + (index_vars, reduce_vars), + ) = recompute_sizes_body_func( + (index_size, reduce_size), body, (index_vars, reduce_vars) + ) + + index_formulas = [*body.indexing_exprs.values()] + if extra_indexing_constraints is not None: + assert ( + isinstance(extra_indexing_constraints, tuple) + and len(extra_indexing_constraints) == 2 + ) + extra_indexing_ranges, extra_indexing_expr = extra_indexing_constraints + assert isinstance(extra_indexing_ranges, dict) + assert isinstance(extra_indexing_expr, list) + assert all(isinstance(f, Expr) for f in extra_indexing_expr) + + expected_var_ranges = body.var_ranges + assert expected_var_ranges == extra_indexing_ranges, ( + expected_var_ranges, + extra_indexing_ranges, + ) + # remove already existing expressions + extra_indexing_expr = [ + e for e in extra_indexing_expr if e not in index_formulas + ] + index_formulas += extra_indexing_expr + + memory_addrs = [*body.get_write_exprs()] + if not V.graph.has_feature(self, BackendFeature.PREFER_STORE_LOOP_ORDER): + memory_addrs.extend(body.get_read_exprs()) + + def simplify_and_reorder(x_vars, support_vars, sizes, simplify_loops): + sizes, reindex0, reindex1 = self._apply_loop_reordering( + x_vars, support_vars, sizes, memory_addrs + ) + # for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1] + x_vars = reindex0(x_vars) + + if simplify_loops: + sizes, reindex2, prune = V.graph.sizevars._simplify_loops( + x_vars, + sizes, + index_prevent_reordering(index_formulas, x_vars, sizes), + ) + reindex = fuse_reindexing(reindex1, reindex2) + else: + reindex = reindex1 + return sizes, reindex, reindex1 + + support_vars = index_vars + reduce_vars + should_merge_loops = ( + self.get_device().type != "cuda" or not config.loop_ordering_after_fusion + ) + iter_ranges, iter_reindex, _ = simplify_and_reorder( + index_vars, + support_vars, + index_size, + should_merge_loops, + ) + + # Like iteration dimensions, we may also want to delay merging reduction dimensions. + # E.g., if we reduce a tensor [M, N, K] for its M and N dimensions followed by a pointwise + # kernel, merging M and N dimension too early makes it hard to decide what loop order + # we should pick for the piontwise kernel so that it is fusible with the reduction. 
+ reduce_ranges, reduce_reindex, _ = simplify_and_reorder( + reduce_vars, support_vars, reduce_size, should_merge_loops + ) + + # retrace the loop body with simplification and reordering applied + (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( + iter_ranges, + reduce_ranges, + prefix="z", + ) + body = LoopBody( + body, + [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], + var_ranges, + iter_vars, + reduce_vars, + ) + return (iter_ranges, reduce_ranges), body + + @staticmethod + def _apply_loop_reordering( + index_vars, + support_vars, + sizes, + memory_addrs, + priority_idx=None, + ): + """ + Shuffle the order of loops around to hopefully improve performance. + """ + from .scheduler import pick_loop_order + + if priority_idx is None: + priority_idx = [] + + try: + strides = [ + V.graph.sizevars.stride_hints(expr, index_vars, support_vars) + for expr in memory_addrs + ] + assert len(strides) == len(memory_addrs) and len(strides[0]) == len( + index_vars + ) + order = list(reversed(pick_loop_order(strides, sizes, priority_idx))) + except Exception: + if config.debug: + log.warning( + "Did not simplify complex index:\n%s\n%s", + dict(zip(index_vars, sizes)), + memory_addrs, + ) + order = list(range(len(sizes))) + sizes = [sizes[i] for i in order] + return sizes, same_reorder(order), inverse_reorder(order) + + def get_reduction_size(self): + return self.data.get_reduction_size() + + def get_reduction_type(self): + return self.data.get_reduction_type() + + def is_no_op(self): + return self.data.is_zero_elements() + + def should_allocate(self): + return True + + def constant_to_device(self, device): + """Move this to a given device. Requires that all reads are to constants.""" + return self.data.constant_to_device(device) + + +class TemplateBuffer(OperationBuffer): + """ + Represents a Triton (in the future other type) of template operator + that we can fuse an epilogue onto. + """ + + def __init__(self, layout, inputs, make_kernel_render): + super().__init__(name=None, layout=layout) + self.inputs = InputsKernel.unwrap_storage(inputs) + self.make_kernel_render = make_kernel_render + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + def get_read_writes(self): + return self.extract_read_writes(normalize=True) + + def extract_read_writes(self, normalize): + name = self.get_name() + indexer = self.layout.make_indexer() + + def dummy(index, rindex): + assert len(rindex) == 0 + return ops.store(name, indexer(index), "fake") + + deps = dependencies.extract_read_writes( + dummy, self.get_size(), (), normalize=normalize + ) + deps.reads = OrderedSet(dependencies.StarDep(x.get_name()) for x in self.inputs) + return deps + + def get_reduction_size(self): + return 1 + + def get_reduction_type(self): + return None + + def is_no_op(self): + return False + + def should_allocate(self): + return True + + def simplify_and_reorder( + self, + extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, + ): + return ( + ( + self.get_size(), + (), + ), + None, + ) + + +class TritonTemplateBuffer(TemplateBuffer): + def __init__( + self, + layout, + inputs, + make_kernel_render, + debug_extra=None, + mutated_inputs: Optional[Iterable[IRNode]] = None, + ): + """ + NOTE:[TritonTemplates with multiple outputs] + We want the ability for TritonTemplates to output multiple tensors. 
Triton + kernels have no notion of outputs and this is done by creating tensors that + are then mutated by the kernel. Currently our STORE_OUTPUT codegen doesn't + support creating multinode outputs for triton templates. + We work around this by creating extra input buffers during the lowering + and marking them as mutated inputs. + """ + super().__init__(layout, inputs, make_kernel_render) + self.debug_extra = debug_extra + self.mutated_inputs = mutated_inputs + self.outputs: List[Buffer] = [self] + if mutated_inputs is not None: + # Ensure that the mutated inputs are only allowed for certain nodes + allowed_set = ( + torch.ops.higher_order.flex_attention, + torch.ops.higher_order.flex_attention_backward, + ) + current_node = V.graph.current_node.target + assert ( + current_node in allowed_set + ), f"Mutated inputs are only allowed for {allowed_set} but got {current_node}" + device = self.inputs[0].get_device() + self.outputs += [ + MutationOutput(NoneLayout(device), buf, self) for buf in mutated_inputs + ] + + def get_outputs(self) -> List[Buffer]: + return self.outputs + + def __str__(self) -> str: + out = f"TritonTemplateBuffer(layout={self.layout}, {self.debug_extra})" + return out + + +PrimitiveInfoType = Union[int, float, bool, str, List[Union[int, str, float, bool]]] + + +class ChoiceCaller: + """ + Represents a possible choice used in autotune_process.py. + During autotuning, self.benchmark() is first called to get the benchmark result, + and if this choice is selected, self.output_node() is called to get the output_node. + + Child classes: TritonTemplateCaller, CUDATemplateCaller. + """ + + def __init__(self, name, input_nodes, layout): + super().__init__() + self.name = name + self.layout = layout + self.input_nodes = input_nodes + + def benchmark(self, *args, out) -> float: + algo = self.to_callable() + return benchmarker.benchmark(algo, args, {"out": out}) + + def call_name(self) -> str: + raise NotImplementedError + + def to_callable(self): + raise NotImplementedError + + def hash_key(self) -> str: + raise NotImplementedError + + def output_node(self) -> TensorBox: + raise NotImplementedError + + def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return {} + + def autoheuristic_id(self) -> str: + return "unsupported_choice" + + +class TritonTemplateCallerBase(ChoiceCaller): + def get_make_kernel_render(self) -> Any: + raise NotImplementedError + + +class MultiTemplateBuffer(TritonTemplateBuffer): + """ + Represents a Buffer with multiple backing implementation choices. + + Choices can be TritonTemplates or ExternKernels. During scheduling, if there is a potential + epilogue, we will benchmark each of the choices with the epilogue to determine an implementation. + Otherwise, the fastest base choice will be chosen.
+ """ + + def __init__( + self, + layout: Layout, + inputs: List[IRNode], + choice_timings: Callable[[], Dict[ChoiceCaller, float]], + ): + super().__init__(layout=layout, inputs=inputs, make_kernel_render=None) + self._choice_timings_fn = choice_timings + self._choice_timings: Optional[Dict[ChoiceCaller, float]] = None + self.original_inputs = inputs + + @property + def choice_timings(self) -> Dict[ChoiceCaller, float]: + if self._choice_timings is None: + self._choice_timings = self._choice_timings_fn() + return self._choice_timings + + @contextlib.contextmanager + def swap_as_triton_caller(self, caller: TritonTemplateCallerBase): + assert isinstance(caller, torch._inductor.select_algorithm.TritonTemplateCaller) + assert self.layout == caller.layout + + render = self.make_kernel_render + self.make_kernel_render = caller.get_make_kernel_render() + try: + yield + finally: + self.make_kernel_render = render + + def finalize_as_triton_caller(self, caller: TritonTemplateCallerBase): + assert isinstance(caller, torch._inductor.select_algorithm.TritonTemplateCaller) + assert self.layout.size == caller.layout.size + assert self.layout.stride == caller.layout.stride + self.make_kernel_render = caller.get_make_kernel_render() + + def get_min_choice(self) -> Tuple[ChoiceCaller, float]: + min_choice = min(self.choice_timings, key=self.choice_timings.get) # type: ignore[arg-type] + return (min_choice, self.choice_timings[min_choice]) + + +class CUDATemplateBuffer(TemplateBuffer): + def __init__( + self, + layout, + inputs, + make_kernel_render, + workspace_size: int, + template: CUDATemplate, # type: ignore[name-defined] # noqa: F821 + ): + super().__init__(layout, inputs, make_kernel_render) + # Global memory (in bytes) needed for this template. + self.workspace_size = workspace_size + self.template = template + + def get_workspace_size(self): + return self.workspace_size if self.workspace_size is not None else 0 + + +class CppTemplateBuffer(TemplateBuffer): + def __init__(self, layout, inputs, make_kernel_render, template, choice): + super().__init__(layout, inputs, make_kernel_render) + self.template = template + self.choice = choice + + +@dataclasses.dataclass +class InputsKernel(OperationBuffer): + inputs: List[Buffer] + + def get_read_writes(self): + reads: OrderedSet[dependencies.Dep] = OrderedSet() + StarDep = dependencies.StarDep + for input in self.inputs: + if isinstance(input, list): + reads.update(StarDep(x.get_name()) for x in input) + else: + reads.add(StarDep(input.get_name())) + + writes: OrderedSet[dependencies.Dep] = OrderedSet( + StarDep(buf.get_name()) for buf in self.get_outputs() + ) + + return dependencies.ReadWrites( + reads=reads, + writes=writes, + index_exprs=OrderedSet(), + ) + + @classmethod + def unwrap_storage_for_input(cls, x): + if isinstance(x, TensorBox): + x = x.data + if isinstance(x, StorageBox): + x = x.data + if isinstance(x, BaseView) and not isinstance(x, ReinterpretView): + x = ExternKernel.realize_input(x) + if isinstance(x, TensorBox): + # when converting to ReinterpretView fails in the + # realize_input call above, the result will be wrapped + # into TensorBox / StorageBox pair as a result of the + # cls.copy_input call; so we should unwrap recursively + return cls.unwrap_storage_for_input(x) + if isinstance(x, TorchBindObject): + return x + assert isinstance(x, (Buffer, ReinterpretView)), x + return x + + @staticmethod + def unwrap_storage(inputs): + inputs_new = [] + for x in inputs: + if isinstance(x, list): + x = 
[InputsKernel.unwrap_storage_for_input(i) for i in x] + else: + x = InputsKernel.unwrap_storage_for_input(x) + inputs_new.append(x) + return inputs_new + + def is_extern(self): + return True + + def num_reads(self): + return 1 + + +class NopKernel(InputsKernel): + def is_no_op(self): + return True + + +class ConcatKernel(NopKernel): + """ + There isn't actually a real kernel for concat, we just change the + storage for the upstream data. + """ + + @classmethod + def create(cls, inputs, dim): + device = inputs[0].get_device() + dtype = inputs[0].get_dtype() + new_size = list(inputs[0].get_size()) + offsets_start = [0] + offsets_end = [new_size[dim]] + assert 0 <= dim < len(new_size) + for i in range(1, len(inputs)): + input_size = inputs[i].get_size() + offsets_start.append(new_size[dim]) + assert len(input_size) == len(new_size) + assert inputs[i].get_dtype() == dtype + assert inputs[i].get_device() == device + for j in range(len(new_size)): + if j == dim: + new_size[j] = new_size[j] + input_size[j] + else: + new_size[j] = V.graph.sizevars.guard_equals( + new_size[j], input_size[j] + ) + offsets_end.append(new_size[dim]) + + output_stride = FlexibleLayout.contiguous_strides(new_size) + # If any of the inputs is in CL format, use CL format for the output + for i in range(len(inputs)): + x = inputs[i] + if is_storage_and_layout(x): + layout = x.get_layout() + if isinstance( + layout, FixedLayout + ) and Layout.is_channels_last_contiguous(layout.size, layout.stride): + # use CL stride for the output + output_stride = make_channels_last_strides_for(new_size) + break + any_input_is_storage_and_layout = any(is_storage_and_layout(x) for x in inputs) + fx_node_args = V.graph.current_node.args[0] + assert isinstance(fx_node_args, list) + # If any of the inputs has meta tensor and the meta tensor is in CL format, use CL format for the output + if any_input_is_storage_and_layout is False and any( + "val" in arg.meta + and ( + arg.meta["val"].is_contiguous(memory_format=torch.channels_last) + or arg.meta["val"].is_contiguous(memory_format=torch.channels_last_3d) + ) + for arg in fx_node_args + ): + output_stride = make_channels_last_strides_for(new_size) + + concat_kernel = ConcatKernel( + name=None, + layout=FixedLayout( + device=device, + dtype=dtype, + size=new_size, + stride=output_stride, + ), + inputs=[], + ) + kernel = StorageBox(concat_kernel) + op_names = [] + for i in range(len(inputs)): + input_buffer = cls.realize_into( + inputs[i], + SliceView.create( + kernel, dim, offsets_start[i], offsets_end[i], clamp=False + ), + ) + concat_kernel.inputs.append(input_buffer) + + if isinstance(inputs[i].data, BaseView): + input_unwrapped = inputs[i].data.unwrap_view() + else: + input_unwrapped = inputs[i].data + + if ( + input_unwrapped.is_input_buffer() + and is_gpu(inputs[i].get_device().type) + and not is_dynamic(input_buffer) + ): + op_names.append(input_buffer.get_operation_name()) + + if len(op_names) > 1 and V.graph.has_feature(device, BackendFeature.FOREACH): + V.graph.register_operation_list(op_names) + + concat_kernel.name = V.graph.register_buffer(concat_kernel) + concat_kernel.inputs = cls.unwrap_storage(concat_kernel.inputs) + V.graph.register_operation(concat_kernel) + + return kernel + + @classmethod + def can_realize_into_without_copy(cls, src): + if isinstance(src, TensorBox): + # unwrap a TensorBox + return cls.can_realize_into_without_copy(src.data) + + return isinstance(src.data.layout, FlexibleLayout) and not isinstance( + src.data, ExternKernelAlloc + ) + + @classmethod + def 
realize_into(cls, src, dst): + # Attempt to turn this into a ReinterpretView rather than assert. + # This has concessions around layout, as as_storage_and_layout + # can cause us to go from flexible to fixed layout. + if not isinstance(dst, ReinterpretView): + if is_storage_and_layout(dst): + storage, layout = as_storage_and_layout(dst) + dst = ReinterpretView(storage, layout) + assert isinstance(dst, ReinterpretView), dst + if isinstance(src, TensorBox): + # unwrap a TensorBox + return cls.realize_into(src.data, dst) + if isinstance(src, StorageBox): + src.realize() + # ExternKernelAlloc has specific requirements for output layout, should create a copy + assert hasattr(src.data, "layout") + if cls.can_realize_into_without_copy(src): + src.data.layout = NonOwningLayout(dst) + return src.data + # introduce a copy + pw = Pointwise.create( + device=src.get_device(), + dtype=src.get_dtype(), + inner_fn=src.make_loader(), + ranges=[ + V.graph.sizevars.guard_equals(a, b) + for a, b in zip(src.get_size(), dst.get_size()) + ], + ) + return cls.realize_into(pw, dst) + + def should_allocate(self): + return True + + +@dataclasses.dataclass +class ExternKernel(InputsKernel): + constant_args: Tuple[Any, ...] = () + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + output_view: Optional[ReinterpretView] = None + python_kernel_name: Optional[str] = None + cpp_kernel_name: Optional[str] = None + # FIXME: in some cases we sill need to explicitly pass in ordered_kwargs_for_cpp_kernel + # We shouldn't need to do this since the information can be retrieved from op_overload._schema. + ordered_kwargs_for_cpp_kernel: Iterable[str] = dataclasses.field( + default_factory=list + ) + op_overload: Optional[ + Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator] + ] = None + arg_properties: Optional[List[Dict[str, Any]]] = None + kwarg_properties: Optional[Dict[str, Dict[str, Any]]] = None + unbacked_bindings: Dict[sympy.Symbol, pytree.KeyPath] = dataclasses.field( + default_factory=dict + ) + mutation_outputs: List[MutationOutput] = dataclasses.field(default_factory=list) + + def __init__( + self, + name, + layout, + inputs, + constant_args=(), + kwargs=None, + output_view=None, + python_kernel_name=None, + cpp_kernel_name=None, + ordered_kwargs_for_cpp_kernel=(), + op_overload=None, + ): + super().__init__( + name, + layout, + inputs, + ) + self.constant_args = constant_args + self.kwargs = kwargs if kwargs else {} + self.output_view = output_view + self.op_overload = op_overload + self.set_cpp_kernel_name(cpp_kernel_name) + self.set_python_kernel_name(python_kernel_name) + self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel + self.collect_arg_kwarg_properties() + self.unbacked_bindings = {} + self.mutation_outputs = [] + self.fx_node = V.graph.current_node + + def get_outputs(self) -> List[Buffer]: + return [self, *self.mutation_outputs] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def collect_arg_kwarg_properties(self): + # if self.op_overload is torch._ops.OpOverload, we can use its schema to collect additional + # information for args and kwargs, e.g. 
type and default value, to help with the cpp wrapper codegen + self.arg_properties = ( + [ + { + "name": x.name, + "type": x.real_type, + "default_value": x.default_value, + } + for x in self.op_overload._schema.arguments + if not x.kwarg_only + ] + if isinstance(self.op_overload, torch._ops.OpOverload) + else [{} for i in range(len(self.inputs))] + ) + self.allarg_properties = ( + { + x.name: {"type": x.real_type, "default_value": x.default_value} + for x in self.op_overload._schema.arguments + } + if isinstance(self.op_overload, torch._ops.OpOverload) + else {} + ) + # FIXME: self.kwargs does not always match kwargs defined in schema, so sometimes + # ordered_kwargs_for_cpp_kernel is explicilty passed in. + if ( + isinstance(self.op_overload, torch._ops.OpOverload) + and not self.ordered_kwargs_for_cpp_kernel + ): + self.ordered_kwargs_for_cpp_kernel = [ + x.name for x in self.op_overload._schema.arguments if x.kwarg_only + ] + + def fill_non_provided_args(self, args, kwargs, convert_val_to_str=False): + # Previously, we want to maintain forward-compatibility by skipping + # default args in the serialized artifacts in fbcode. However, + # some of our shim interfaces require default values being OrderedSet. + # Discussed with Sherlock offline and we decided to allow serializing + # default args into the C++ wrapper code for now. We will refine this + # part if we see real FC requirement. More details related to FC + # can be found at: + # https://docs.google.com/document/d/1FzWm-sHYwmRi3x_g036kOxd99KaYquUsA-L5JwOn8ys/edit?usp=sharing + assert isinstance(args, (list, tuple)) + if isinstance(args, tuple): + args = list(args) + assert self.arg_properties, "ExternKernel.arg_properties should not be empty" + + n_args = len(args) + n_pos_args = len(self.arg_properties) + # For cpp wrapper, if some positional args are not provided, we need to check + # if they're in the kwargs or use their default value + if n_args < n_pos_args: + log.debug( + "%s has %d unprovided positional arguments. " + "Will check if they are in the keyword arguments or will use default values.", + self.op_overload, + n_pos_args - n_args, + ) + for i in range(n_args, n_pos_args): + arg_name = self.arg_properties[i]["name"] + args.append( + kwargs[arg_name] + if arg_name in kwargs + else self.arg_properties[i]["default_value"] + ) + return args + + def decide_layout(self): + if isinstance(self.layout, FlexibleLayout): + self.apply_constraint() + self.freeze_layout() + + def codegen_comment(self, wrapper): + origin_str, detailed_origin_str = get_kernel_metadata(self, wrapper) + if origin_str: + wrapper.writeline(origin_str) + + def codegen(self, wrapper): + raise NotImplementedError + + def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None): + self.cpp_kernel_name = cpp_kernel_name + self.cpp_kernel_overload_name = None + self.cpp_kernel_key = None + self.cpp_op_schema = None + if not V.graph.cpp_wrapper or not isinstance( + self.op_overload, torch._ops.OpOverload + ): + return + + kernel = self.op_overload + if self.cpp_kernel_name is None: + # Try to construct cpp_kernel_name from op_overload + if kernel.namespace == "aten": + # Calling with the default kernel name can lead to ambiguous behavior like the following example. 
+ # repeat_interleave(const at::Tensor & repeats, c10::optional output_size=std::nullopt) + # repeat_interleave(const at::Tensor & self, int64_t repeats, + # c10::optional dim=std::nullopt, c10::optional output_size=std::nullopt) + opname = ( + kernel.__name__.split(".")[0] + if kernel._overloadname == "default" + else kernel.__name__.replace(".", "_") + ) + self.cpp_kernel_name = f"at::_ops::{opname}::call" + else: + self.cpp_kernel_name = kernel._schema.name + + # Set up info for runtime schema lookup + # TODO: The logics here may be further simplified. + from .codegen.wrapper import get_cpp_op_schema + + self.cpp_kernel_overload_name = kernel._schema.overload_name + self.cpp_kernel_key = f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}" # type: ignore[union-attr] + try: + self.cpp_op_schema = get_cpp_op_schema(kernel) + except Exception: + self.cpp_op_schema = "" + + def set_python_kernel_name(self, python_kernel_name: Optional[str]): + self.python_kernel_name = python_kernel_name + if python_kernel_name is not None: + return + + kernel = self.op_overload + if kernel is None: + pass + elif isinstance(kernel, torch._ops.HigherOrderOperator): + self.python_kernel_name = f"torch.ops.higher_order.{kernel.__name__}" + else: + self.python_kernel_name = ( + f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}" + ) + + def get_kernel_name(self): + return ( + ( + V.graph.wrapper_code.get_c_shim_func_name(self.cpp_kernel_name) # type: ignore[attr-defined] + if config.abi_compatible + else self.cpp_kernel_name + ) + if V.graph.cpp_wrapper + else self.python_kernel_name + ) + + @staticmethod + def copy_input(x): + pw = Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=x.make_loader(), + ranges=x.get_size(), + origin_node=x.get_origin_node(), + traceback=x.get_traceback(), + ) + pw.realize() + return pw + + @classmethod + def process_kernel( + cls, kernel, *args, **kwargs + ) -> Tuple[ + Any, + List[Any], + List[Any], + Callable[[Any, Any], Any], + Optional[Dict[sympy.Symbol, pytree.KeyPath]], + ]: + binded_args = {"args": args, "kwargs": kwargs} + + args_flat, args_spec = pytree.tree_flatten(binded_args) + + is_arg_tensor = [] + tensor_args = [] + non_tensor_args: List[Any] = [] + for arg in args_flat: + is_arg_tensor.append(isinstance(arg, IRNode)) + if is_arg_tensor[-1]: + tensor_args.append(arg) + else: + if isinstance(arg, sympy.Expr): + arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None) + non_tensor_args.append(arg) + + def unflatten_args(new_tensor_args, new_non_tensor_args): + result = [] + it_tensors = iter(new_tensor_args) + it_non_tensors = iter(new_non_tensor_args) + for is_tensor in is_arg_tensor: + if is_tensor: + result.append(next(it_tensors)) + else: + result.append(next(it_non_tensors)) + r = pytree.tree_unflatten(result, args_spec) + return r.get("args", []), r.get("kwargs", {}) + + tensor_args = [cls.realize_input(x) for x in tensor_args] + + # freeze layout otherwise our output stride calculation might + # become incorrect + for x in tensor_args: + if is_storage_and_layout(x): + as_storage_and_layout(x, freeze=True) + + # Rerun fake tensor propagation, because Inductor may have changed the + # strides of inputs and we need to determine accurately what the + # output stride will be. 
+ example_args: List[Union[torch.Tensor, torch._C.ScriptObject]] = [] + + # We need to retain the constant values of fake tensors that we originally + # propagated the graph with, because for some operators running without a + # constant would trigger an error / DataDependentException + for x in tensor_args: + # if x is a view of a constant, we need to realize the view + # (we can't pass the constant into the kernel directly) + if not isinstance(x, BaseView) and x.get_name() in V.graph.constants: + example_args.append(V.graph.constants[x.get_name()]) + elif ( + not isinstance(x, BaseView) + and x.get_name() in V.graph.torchbind_constants + ): + example_args.append(V.graph.torchbind_constants[x.get_name()]) + else: + example_args.append(ir_node_to_tensor(x, guard_shape=True)) + + new_args, new_kwargs = unflatten_args(example_args, non_tensor_args) + example_output = kernel(*new_args, **new_kwargs) + + unbacked_bindings: Optional[Dict[sympy.Symbol, pytree.KeyPath]] = None + if shape_env := V.fake_mode.shape_env: + rebind_unbacked(shape_env, V.current_node, example_output) + unbacked_bindings = compute_unbacked_bindings( + shape_env, example_output, V.current_node.meta.get("val") + ) + + example_out_li = ( + [example_output] + if not isinstance(example_output, (list, tuple)) + else example_output + ) + for t in example_out_li: + if isinstance(t, torch.Tensor) and t.is_sparse: + msg = "sparsity not handled. Please file issue for sparse inference weights." + if stack_trace := V.graph.current_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + return ( + example_output, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings, + ) + + @classmethod + def convert_to_reinterpret_view(cls, x): + """ + In order to pass this to an extern kernel we need a + ReinterpretView not a View. This allows us to avoid some + unneeded copies. + """ + assert isinstance(x, BaseView) + if isinstance(x, ReinterpretView): + return x + + # NOTE: Don't use extract_read_writes here as it fails when + # make_loader() inlines the computation + x_unwrap_view = x.unwrap_view() + buf = V.graph.get_buffer(x_unwrap_view.get_name()) + assert buf is not None + x_unwrap_view_fx_node = buf.get_origin_node() + # Prefer channels last format according to how the format is set from eager. 
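+        # For reference: a channels-last 4D [N, C, H, W] value has strides
+        # like [H*W*C, 1, W*C, C]; when the eager fake value reports that
+        # memory format and the layout is still flexible, it is frozen below
+        # using make_channels_last_strides_for so the view keeps that order.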
+ if ( + x_unwrap_view_fx_node is not None + and "val" in x_unwrap_view_fx_node.meta + and isinstance(x_unwrap_view.layout, FlexibleLayout) + and ( + x_unwrap_view_fx_node.meta["val"].is_contiguous( + memory_format=torch.channels_last + ) + or x_unwrap_view_fx_node.meta["val"].is_contiguous( + memory_format=torch.channels_last_3d + ) + ) + ): + x_unwrap_view.freeze_layout_with_same_order( + make_channels_last_strides_for(x_unwrap_view.get_size()) + ) + else: + x_unwrap_view.freeze_layout() + + index_args, var_ranges = dependencies.index_vars_squeeze( + x.get_size(), prefix="r" + ) + range_vars = index_args[0] + index = x.make_indexer()(range_vars) + + index = V.graph.sizevars.simplify_with_ranges(index, var_ranges) + strides = V.graph.sizevars.stride_vars(index, range_vars) + offset = V.graph.sizevars.offset_var(index, range_vars) + expected = sympy_dot(range_vars, strides) + offset + + if index != expected: + log.debug( + "convert_to_reinterpret_view failed: stride=%s offset=%s index=%s", + strides, + offset, + index, + ) + raise NotImplementedError + + return ReinterpretView( + data=x.data, + layout=FixedLayout( + device=x.get_device(), + dtype=x.get_dtype(), + size=x.get_size(), + stride=strides, + offset=offset, + ), + ) + + @classmethod + def realize_input(cls, x): + if x is None: + return NoneAsConstantBuffer() + if isinstance(x, (sympy.Expr, sympy.logic.boolalg.Boolean, int)): + return ShapeAsConstantBuffer(x) + if isinstance(x, Constant): + return V.graph.add_tensor_constant( + torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device()) + ) + if isinstance(x, ConstantBuffer): + return x + if isinstance(x, TensorBox): + return cls.realize_input(x.data) + if isinstance(x, ReinterpretView): + return ReinterpretView(cls.realize_input(x.data), x.get_layout()) + if isinstance(x, BaseView): + x.realize() + if is_storage_and_layout(x.unwrap_view()): + try: + return cls.convert_to_reinterpret_view(x) + except NotImplementedError: + pass + if isinstance(x, StorageBox): + # TODO(jansel): impose layout preference on realized buffer + x.realize() + return x + if isinstance(x, TorchBindObject): + return x + return cls.copy_input(x) + + @classmethod + def require_stride1(cls, x): + if is_storage_and_layout(x): + if len(x.get_stride()) == 0: + return x + for stride in x.get_stride(): + if stride == 1: + return x + return cls.copy_input(x) + + @classmethod + def require_strides( + cls, + x, + order: Optional[Sequence[int]] = None, + exact_strides: Optional[Sequence[_IntLike]] = None, + allow_padding=False, + ): + assert order is not None or exact_strides is not None + if x.get_numel() == 0: # Layout doesn't matter + return x + # require x to have the layout + if is_storage_and_layout(x): + while isinstance(x.get_layout(), NonOwningLayout): + x = x.get_layout().view + if isinstance(x.get_layout(), FlexibleLayout): + if order: + # If the the FlexibleLayout already has the size and stride in the required order, + # freeze it to a FixedLayout by using its current size and stride. + # The behavior of using its current size and stride or the given order can be different + # if the size and stride has ambiguilty, for example for a 4D input where the iC = 1: + # size=[s0, 1, 28, 28], stride=[784, 784, 28, 1]. If the required order is [3, 0, 2, 1] (channels last), + # the current size and stride already satisfies this order. + # However by freezing it to the required order, the layout will be changed to: + # size=[s0, 1, 28, 28], stride=[784, 1, 28, 1]), which is not actually necessary. 
+ + # fix flexiblelayout to be FixedLayout with stride_order + as_storage_and_layout( + x, + freeze=True, + want_contiguous=False, + stride_order=get_stride_order( + V.graph.sizevars.size_hints(x.get_layout().stride) + ) + if is_stride_order_storage_and_layout(x, order) + else order, + allow_padding=allow_padding, + ) + return x + else: + # If the exact_strides is given, freeze the FlexibleLayout to a FixedLayout with the exact_strides. + as_storage_and_layout( + x, + freeze=True, + want_contiguous=False, + stride_order=None, + allow_padding=allow_padding, + exact_strides=exact_strides, + ) + return x + elif isinstance(x.get_layout(), FixedLayout) and ( + (order and x.get_layout().is_stride_ordered(order)) + or ( + exact_strides + and significant_strides_equal( + exact_strides, x.get_layout().stride, x.get_size() + ) + ) + ): + return x + elif isinstance(x.get_layout(), MutationLayoutSHOULDREMOVE): + if isinstance(x.get_layout().real_layout(), FlexibleLayout): + raise AssertionError( + "the MutationLayoutSHOULDREMOVE's real layout shouldn't be FlexibleLayout" + ) + elif isinstance(x.get_layout().real_layout(), FixedLayout) and ( + (order and x.get_layout().real_layout().is_stride_ordered(order)) + or ( + exact_strides + and significant_strides_equal( + exact_strides, + x.get_layout().real_layout().stride, + x.get_size(), + ) + ) + ): + return x + + # TODO - Storage to InputBuffer + if isinstance(x, InputBuffer) and ( + (order and x.get_layout().is_stride_ordered(order)) + or ( + exact_strides + and significant_strides_equal( + exact_strides, x.get_layout().stride, x.get_size() + ) + ) + ): + return x + if ( + isinstance(x, TensorBox) + and isinstance(x.data, BaseView) + and not isinstance(x.data, ReinterpretView) + and is_storage_and_layout(x.unwrap_view()) + and not isinstance(x.unwrap_view().data, ExternKernelAlloc) + ): + try: + x.data = cls.convert_to_reinterpret_view(x.data) + if order: + return cls.require_stride_order( + x, order, allow_padding=allow_padding + ) + elif exact_strides: + return cls.require_exact_strides( + x, exact_strides, allow_padding=allow_padding + ) + except NotImplementedError: + pass + # Although this is a clone, inductor is good about fusing clones into previous + # operations if they weren't realized and their layouts were flexible. + x = cls.copy_input(x) + as_storage_and_layout( + x, + freeze=True, + want_contiguous=False, + stride_order=order, + allow_padding=allow_padding, + exact_strides=exact_strides, + ) + if order: + assert is_stride_order_storage_and_layout(x, order) + return x + + @classmethod + def require_exact_strides(cls, x, exact_strides, allow_padding=False): + return cls.require_strides( + x, exact_strides=exact_strides, allow_padding=allow_padding + ) + + @classmethod + def require_stride_order(cls, x, order, allow_padding=False): + return cls.require_strides(x, order=order, allow_padding=allow_padding) + + @classmethod + def require_channels_last(cls, x): + return cls.require_stride_order(x, NHWC_STRIDE_ORDER) + + @classmethod + def require_channels_last_3d(cls, x): + return cls.require_stride_order(x, NHWDC_STRIDE_ORDER) + + @classmethod + def require_contiguous(cls, x): + return cls.require_stride_order(x, list(reversed(range(len(x.get_size()))))) + + def apply_constraint(self): + pass + + def codegen_const_args(self, names: Optional[List[str]] = None): + if V.graph.cpp_wrapper: + result = [] + # Aten ops follow the convention that tensor args are before non-tensor args, + # in which case the following 'len(self.inputs) + i' logic works. 
But this + # may not be true for other ops, and if that is the case, caller needs to + # pass in a list of const arg names for arg_properties lookup. + name_to_arg_properties = None + if names and self.arg_properties: + assert len(self.constant_args) == len( + names + ), "names passed to codegen_const_args does not match self.constant_args" + name_to_arg_properties = { + arg.get("name"): arg for arg in self.arg_properties + } + + for i, x in enumerate(self.constant_args): + if name_to_arg_properties is not None: + prop = name_to_arg_properties.get(names[i]) # type: ignore[index] + type_ = prop.get("type") if prop else None + else: + idx = len(self.inputs) + i + type_ = ( + self.arg_properties[idx].get("type") + if self.arg_properties and idx < len(self.arg_properties) + else None + ) + result.append( + V.graph.wrapper_code.val_to_arg_str(x, type_) # type: ignore[arg-type] + ) + return result + else: + return map(V.graph.wrapper_code.val_to_arg_str, self.constant_args) + + def codegen_args(self): + args = [] + for i, x in enumerate(self.inputs): + if isinstance(x, list): + names = [i.codegen_reference() for i in x] + codegen_reference = f'[{", ".join(names)}]' + args.append(codegen_reference) + else: + if V.graph.cpp_wrapper: + assert self.arg_properties and i < len( + self.arg_properties + ), "Invalid access to ExternKernel.arg_properties" + type_ = self.arg_properties[i].get("type") + args.append( + V.graph.wrapper_code.val_to_arg_str( # type: ignore[arg-type] + x, type_ + ) + ) + else: + args.append(x.codegen_reference()) + args.extend(self.codegen_const_args()) + return args + + def get_kwargs_value(self, arg_name): + if arg_name in self.kwargs: + return self.kwargs.get(arg_name) + if self.allarg_properties and self.allarg_properties.get(arg_name): + return self.allarg_properties.get(arg_name).get("default_value") # type: ignore[union-attr] + else: + raise AssertionError(f"{arg_name} not in self.allarg_properties") + + def codegen_kwargs(self, skip_out=False): + if V.graph.cpp_wrapper: + kwargs = [] + for arg_name in self.ordered_kwargs_for_cpp_kernel: + if skip_out and arg_name == "out": + # ExternKernelOut has its own logic for inserting the out parameter + continue + + v = self.get_kwargs_value(arg_name) + if isinstance(v, sympy.Expr): + kwargs.append(v) + else: + type_ = ( + self.allarg_properties.get(arg_name).get("type") # type: ignore[union-attr] + if self.allarg_properties and arg_name in self.allarg_properties + else None + ) + kwargs.append( + V.graph.wrapper_code.val_to_arg_str( # type: ignore[arg-type] + v, type_ + ) + ) + else: + kwargs = [ + f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}" # type: ignore[misc] + for k, v in self.kwargs.items() + ] + return kwargs + + def codegen_size_asserts(self, wrapper): + if config.size_asserts and not V.graph.cpp_wrapper: + # comparing strides for 0 size tensor is tricky. Ignore them for now. 
+ if sympy_product(self.get_size()) == 0: + return + size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size()) + stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride()) + wrapper.writeline( + f"assert_size_stride({self.get_name()}, {size}, {stride})" + ) + + def get_group_stride(self): + """ + get output sizes and strides, for template_codegen + """ + _size = self.get_size() + _stride = self.get_stride() + # iter_ranges = _size of output tensor, reduce_range = [] because no reduction + return [_size, []], _stride + + def canonicalize(self): + """ + Manually get canonicalization of the output index + """ + # manually generate index formula for conv + sizevars = V.graph.sizevars + sizes = self.get_size() + strides = self.get_stride() + strides = [sizevars.size_hint(x) for x in strides] + # TODO: I can't tell if the symbols here are temporary + index_vars = [sympy_index_symbol(f"d{i}") for i in range(len(sizes))] + # reorder index vars according to stride + index_order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True) + lookup = {pos: idx for idx, pos in enumerate(index_order)} + order = [lookup[i] for i in range(len(lookup))] + index_vars = [index_vars[i] for i in order] + indexer = self.make_indexer() + index = indexer(index_vars) + + new_sizes, reindex, prune = V.graph.sizevars._simplify_loops( + index_vars, sizes, [index] + ) + + # assign new variables each dimension to deal with numbering mismatches + # d0, d1, d2 could become d0, d2 -- which won't match d0, d1 + _, add_var = var_builder("c") + replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes]))) + + index = sympy_subs(sympy.expand(index), replacement) # type: ignore[arg-type] + return index, tuple(new_sizes) + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + # NB: It's not necessary to check regular inputs as we automatically + # have dependencies on them + r: OrderedSet[sympy.Symbol] = OrderedSet() + for arg in self.constant_args: + r |= maybe_free_unbacked_symbols(arg) + for arg in self.kwargs.values(): + r |= maybe_free_unbacked_symbols(arg) + return r + + def __str__(self) -> str: + kernel_name = getattr(self, "python_kernel_name", None) + lines = [ + f"python_kernel_name={kernel_name!r}", + ] + lines += [ + f"{field.name}={getattr(self, field.name)}" + for field in dataclasses.fields(self) + ] + lines.append(f"origin_node={self.origin_node!r}") + return self.str_helper(lines) + + __repr__ = __str__ + + +@dataclasses.dataclass +class ExternKernelOut(ExternKernel): + def codegen(self, wrapper): + self.codegen_comment(wrapper) + args = [*self.codegen_args(), *self.codegen_kwargs(skip_out=True)] + kernel_name = self.get_kernel_name() + if ( + V.graph.cpp_wrapper + and self.cpp_kernel_name == "torch::inductor::_mm_plus_mm" + ): + # For https://github.com/pytorch/pytorch/issues/128474 + kernel_name = ( + "aoti_torch__mm_plus_mm_out" + if config.abi_compatible + else "torch::inductor::_mm_plus_mm_out" + ) + else: + kernel_name = self.get_kernel_name() + wrapper.generate_extern_kernel_out( + kernel_name, + self.codegen_reference(), + self.output_view.codegen_reference() if self.output_view else None, + args, + ) + + def __init__( + self, + layout, + inputs, + constant_args=(), + kwargs=None, + output_view=None, + python_kernel_name=None, + cpp_kernel_name=None, + ordered_kwargs_for_cpp_kernel=(), + op_overload=None, + ): + super().__init__( + None, + layout, + self.unwrap_storage(inputs), + constant_args, + kwargs or {}, + None, + python_kernel_name, + 
cpp_kernel_name, + ordered_kwargs_for_cpp_kernel, + op_overload, + ) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + def should_allocate(self): + return True + + +class RandomSeeds(ExternKernelOut): + def __init__(self, count: int, device: torch.device): + limits = torch.iinfo(torch.int64) + super().__init__( + layout=FixedLayout( + device=device, + dtype=torch.int64, + size=[count], + ), + inputs=[], + constant_args=[limits.min, limits.max, [count]], + python_kernel_name="aten.randint.low_out", + # FIXME: Ideally we should only use at::_ops::randint_low_out::call here, + # but the signature is different from is at::randint_out. Again, + # we can simplify the code when only keeping an ABI-compatible version. + cpp_kernel_name="at::_ops::randint_low_out::call" + if config.abi_compatible + else "at::randint_out", + op_overload=aten.randint.low_out, + ) + + +class ExternKernelAlloc(ExternKernel): + def codegen(self, wrapper): + self.codegen_comment(wrapper) + args = [*self.codegen_args(), *self.codegen_kwargs()] + V.graph.wrapper_code.generate_extern_kernel_alloc(self, args) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + def __init__( + self, + layout, + inputs, + constant_args=(), + kwargs=None, + python_kernel_name=None, + cpp_kernel_name=None, + ordered_kwargs_for_cpp_kernel=(), + op_overload=None, + ): + super().__init__( + None, + layout, + self.unwrap_storage(inputs), + constant_args, + kwargs or {}, + None, + python_kernel_name, + cpp_kernel_name, + ordered_kwargs_for_cpp_kernel, + op_overload, + ) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + def should_allocate(self): + return False + + def apply_constraint(self): + raise NotImplementedError + + +class MutationOutput(Buffer): + """ + An output buffer that represents the mutation of a pre-existing buffer + """ + + def __init__(self, layout, mutated_node, mutating_node: Operation): + super().__init__(name=None, layout=layout) + mutated_node_name = mutated_node.get_name() + V.graph.mark_buffer_mutated(mutated_node_name) + self.mutation_names = [mutated_node_name] + self.mutating_node: Operation = mutating_node + self.name = V.graph.register_buffer(self) + + def get_defining_op(self) -> Operation: + return self.mutating_node + + def get_mutation_names(self): + return self.mutation_names + + def should_allocate(self): + return False + + +class UserDefinedTritonKernel(ExternKernel): + def get_kernel_and_configs(self): + from triton.runtime.autotuner import Autotuner + + from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table + + kernel = kernel_side_table.get_kernel(self.kernel_idx) + configs = [] + if isinstance(kernel, Autotuner): + configs = kernel.configs + kernel = kernel.fn + return kernel, configs + + def codegen(self, wrapper): + kernel, configs = self.get_kernel_and_configs() + + # Definition of kernel + new_name, triton_meta = wrapper.define_user_defined_triton_kernel( + kernel, configs, self.kwargs + ) + raw_args = [ + self.get_kwargs_value(k) for k in self.ordered_kwargs_for_cpp_kernel + ] + + # NOTE: raw_args doesn't include autotuned args. + # But, kernel.constexprs includes indices of autotuned args. + # So, let's recalculate constexpr indices wrt to raw_args. 
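+        # Hypothetical illustration (argument names invented): if kernel.arg_names is
+        # ['x_ptr', 'AUTOTUNED_BLOCK', 'USER_CONSTEXPR'] with kernel.constexprs == [1, 2]
+        # and 'AUTOTUNED_BLOCK' is supplied by autotuning (hence absent from raw_args),
+        # then ordered_kwargs_for_cpp_kernel is ['x_ptr', 'USER_CONSTEXPR'] and the loop
+        # below yields constexpr_indices == [1] rather than the kernel-relative index 2.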
+ constexpr_indices = [] + for idx, kwarg in enumerate(self.ordered_kwargs_for_cpp_kernel): + if kernel.arg_names.index(kwarg) in kernel.constexprs: + constexpr_indices.append(idx) + + # Call to kernel + self.codegen_comment(wrapper) + wrapper.generate_user_defined_triton_kernel( + new_name, raw_args, self.grid, configs, triton_meta, constexpr_indices + ) + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + # add unbacked symbols used in the grid to the ones used + # in the kwargs (the latter is generated by ExternKernel) + return super().get_unbacked_symbol_uses() | free_unbacked_symbols(self.grid) + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def __init__(self, *, kernel_idx, grid, kernel_args): + inputs = [] + kwargs = {} + constant_args = [] + for k, v in kernel_args.items(): + if isinstance(v, TensorBox): + t = InputsKernel.unwrap_storage_for_input(self.realize_input(v)) + inputs.append(t) + kwargs[k] = t + else: + constant_args.append(v) + kwargs[k] = v + + assert len(inputs) != 0 + self.device = inputs[0].get_device() + + super().__init__( + None, + NoneLayout(self.device), # type: ignore[arg-type] + inputs, + tuple(constant_args), + kwargs, + ) + self.kernel_idx = kernel_idx + self.grid = grid + + kernel, configs = self.get_kernel_and_configs() + # If we are autotuning, not all arguments will be passed + self.ordered_kwargs_for_cpp_kernel = [ + arg for arg in kernel.arg_names if arg in kernel_args + ] + + from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors + + autotuned_kwargs = configs[0].kwargs if len(configs) > 0 else {} + self.mutable_args = [ + kernel_args[key] + for key in identify_mutated_tensors( + kernel, {**kernel_args, **autotuned_kwargs} + ) + ] + + self.mutation_outputs = [ + MutationOutput(NoneLayout(self.device), buf, self) + for buf in self.mutable_args + ] + V.graph.register_operation(self) + + def get_outputs(self) -> List[Buffer]: + return list(self.mutation_outputs) + + def get_device(self) -> torch.device: + return self.device + + +class InplaceBernoulliFallback(ExternKernel): + """ + This needs to be a custom class to handle mutation properly + """ + + def codegen(self, wrapper): + (x,) = (t.codegen_reference() for t in self.inputs) + + if V.graph.cpp_wrapper and config.abi_compatible: + # Inductor doesn't really support aten Generator, so the Generator kwarg is always NULL here, + # which needs to be explicitly generated for cpp wrapper + wrapper.writeline( + f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}, NULL){wrapper.ending}" + ) + else: + wrapper.writeline( + f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}" + ) + + def should_allocate(self): + return False + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def __init__(self, op_overload, x, *constant_args): + super().__init__( + None, + NoneLayout(x.get_device()), # type: ignore[arg-type] + self.unwrap_storage([x]), + constant_args, + op_overload=op_overload, + ) + V.graph.mark_buffer_mutated(x.get_name()) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + if not config.abi_compatible: + # TODO: this should be simplified once we switch to ABI-compatible only + self.cpp_kernel_name = "at::native::bernoulli_" + + +# Used to deal with torch.complex types +class InplaceCopyFallback(ExternKernel): + """ + 
This needs to be a custom class to handle mutation properly + """ + + def codegen(self, wrapper): + (dst, src, non_blocking) = self.codegen_args() + wrapper.codegen_device_copy(src, dst) + + def should_allocate(self): + return False + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def __init__( + self, + layout, + inputs, + constant_args, + ): + super().__init__( + None, + layout, + inputs, + constant_args, + python_kernel_name="aten.copy_", + cpp_kernel_name=( + "aoti_torch_copy_" if config.abi_compatible else "at::_ops::copy_::call" + ), + ) + V.graph.mark_buffer_mutated(inputs[0].get_name()) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + @classmethod + def create(cls, dst, src, non_blocking: bool = False): + inputs = [cls.realize_input(t) for t in [dst, src]] + constant_args = (non_blocking,) + result = InplaceCopyFallback( + NoneLayout(dst.get_device()), # type: ignore[arg-type] + inputs, + constant_args, + ) + return result + + +class MutatingFirstArgExternKernel(ExternKernel): + """ + This needs to be a custom class to handle mutation properly + """ + + def codegen(self, wrapper): + argrefs = [ + *(t.codegen_reference() for t in self.inputs), + *map(repr, self.constant_args), + ] + wrapper.writeline( + f"{self.get_kernel_name()}({', '.join(argrefs)}){wrapper.ending}" + ) + + def should_allocate(self): + return False + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def has_side_effects(self): + return True + + +class ResizeStorageBytes(MutatingFirstArgExternKernel): + def __init__(self, variable, new_size): + assert isinstance(new_size, int), "TODO: dynamic shapes" + super().__init__( + None, + NoneLayout(variable.get_device()), # type: ignore[arg-type] + self.unwrap_storage([variable]), + constant_args=(new_size,), + ) + V.graph.mark_buffer_mutated(variable.get_name()) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + self.python_kernel_name = "inductor_ops.resize_storage_bytes_" + self.cpp_kernel_name = "torch::inductor::resize_storage_bytes_" + V.graph.never_reuse_buffers.add(variable.data.get_name()) + + +class SetSourceTensorKernel(ExternKernelAlloc): + def __init__(self, self_tensor, storage_tensor): + self_tensor.freeze_layout() + super().__init__( + self_tensor.get_layout(), + [self_tensor, storage_tensor], + python_kernel_name="torch.ops.aten.set_.source_Tensor", + op_overload=torch.ops.aten.set_.source_Tensor, + ) + V.graph.never_reuse_buffers.add(self_tensor.data.get_name()) + V.graph.never_reuse_buffers.add(storage_tensor.get_name()) + V.graph.never_reuse_buffers.add(self.get_name()) + device = storage_tensor.get_device() + self.mutation_outputs = [ + MutationOutput(NoneLayout(device), self_tensor, self), + MutationOutput(NoneLayout(device), storage_tensor, self), + ] + + def get_inputs_that_alias_output(self): + return [self.inputs[0].get_name(), self.inputs[1].get_name()] + + +class ScatterFallback(ExternKernel): + """ + This needs to be a custom class to handle mutation properly. + This class handles both aten.scatter_ and aten.scatter_reduce_. + It also handle the case `src` being a scalar properly. 
+ """ + + def codegen(self, wrapper): + reduce = self.kwargs["reduce"] + if V.graph.cpp_wrapper: + # Follow aten/src/ATen/native/ReductionType.h:get_operator_enum + get_operator_enum = {"add": "sum", "multiply": "prod"} + if reduce in get_operator_enum: + reduce = get_operator_enum[reduce] + + if self.src_is_tensor: + (x, index, src) = (t.codegen_reference() for t in self.inputs) + else: + (x, index) = (t.codegen_reference() for t in self.inputs) + src = self.constant_args[1] + wrapper.generate_scatter_fallback( + x, + [x, self.constant_args[0], index, src], + self.cpp_kernel_name, + self.python_kernel_name, + self.src_is_tensor, + reduce, + self.codegen_kwargs(), + ) + + def should_allocate(self): + return False + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def __init__( + self, + op_overload, + x, + dim: int, + index, + src, + *, + reduce: Optional[str] = None, + include_self: bool = True, + ): + self.src_is_tensor = isinstance(src, TensorBox) + + constant_args: Tuple[Any, ...] + if self.src_is_tensor: + tensors = [self.realize_input(t) for t in [x, index, src]] + constant_args = (dim,) + else: + tensors = [self.realize_input(t) for t in [x, index]] + constant_args = (dim, src) + + super().__init__( + None, + NoneLayout(x.get_device()), # type: ignore[arg-type] + self.unwrap_storage(tensors), + constant_args, + {"reduce": reduce, "include_self": include_self}, + python_kernel_name=str(op_overload), + ordered_kwargs_for_cpp_kernel=["reduce", "include_self"], + op_overload=op_overload, + ) + V.graph.mark_buffer_mutated(x.get_name()) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + +class IndexPutFallback(ExternKernel): + """ + This needs to be a custom class to handle mutation and indices properly + """ + + def codegen(self, wrapper): + (x, values, *valid_indices) = (t.codegen_reference() for t in self.inputs) + indices = [] + iter_valid_indices = iter(valid_indices) + for i, _ in enumerate(self.indices): + if self.indices[i] is not None: + indices.append(next(iter_valid_indices)) + else: + indices.append(V.graph.wrapper_code.none_str) + + wrapper.generate_index_put_fallback( + self.get_kernel_name(), x, indices, values, *self.codegen_const_args() + ) + + def should_allocate(self): + return False + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def __init__(self, op_overload, x, indices, values, accumulate): + self.indices = indices + valid_indices = [i for i in indices if i is not None] + tensors = [self.realize_input(x) for x in [x, values, *valid_indices]] + cpp_kernel_name = ( + "aoti_torch_index_put_out" if config.abi_compatible else "at::index_put_out" + ) + super().__init__( + None, + NoneLayout(x.get_device()), # type: ignore[arg-type] + self.unwrap_storage(tensors), + (accumulate,), + python_kernel_name="aten.index_put_", + cpp_kernel_name=cpp_kernel_name, + op_overload=op_overload, + ) + V.graph.mark_buffer_mutated(self.inputs[0].get_name()) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + +class DeviceCopy(ExternKernelOut): + @classmethod + def create(cls, x, device): + if ( + not x.is_extern() + and all(r in V.graph.constants for r in x.get_read_names()) + and not config.aot_inductor.use_runtime_constant_folding + ): + return x.constant_to_device(device) + + 
V.graph.add_device_info(device) + V.graph.add_device_info(x.get_device()) + + developer_warning("DeviceCopy in input program") + return DeviceCopy( + FlexibleLayout( + device=device, + dtype=x.get_dtype(), + size=x.get_size(), + ), + [cls.realize_input(x)], + ) + + def codegen(self, wrapper): + args = self.codegen_args() + assert len(args) == 1 + if self.output_view: + wrapper.codegen_device_copy(args[0], self.output_view.codegen_reference()) + else: + wrapper.codegen_device_copy(args[0], self.codegen_reference()) + + +class DynamicScalar(ExternKernel): + """ + The result of a call to aten._local_scalar_dense. + """ + + def get_reads(self): + return () + + def should_allocate(self): + return False + + def __init__(self, sym, keypath, data): + data.realize() + super().__init__(None, NoneLayout(torch.device("cpu")), self.unwrap_storage([data])) # type: ignore[arg-type] + self.sym = sym + self.keypath = keypath + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet([self.sym]) + + def codegen(self, wrapper): + wrapper.codegen_dynamic_scalar(self) + + +class AssertScalar(ExternKernel): + """ + The result of a call to aten._assert_scalar + """ + + def get_reads(self): + return () + + def should_allocate(self): + return False + + def __init__(self, scalar, msg): + super().__init__( + # Buffer(name, layotu) + None, + NoneLayout(torch.device("cpu")), # type: ignore[arg-type] + # InputsKernel(inputs) + [], + ) # type: ignore[arg-type] + self.scalar = scalar + self.msg = msg + + def has_side_effects(self): + return True + + def get_unbacked_symbol_uses(self): + return free_unbacked_symbols(self.scalar) + + def codegen(self, wrapper): + if V.graph.cpp_wrapper: + pass + else: + # NB: It is EXTREMELY important not to simplify the scalar under + # assertion here, because simplify is done with respect to + # runtime asserts. So if you have "u0 == 0" in the runtime + # asserts, if you subsequently try to simplify(u0 == 0), you will + # get True (because we've already runtime assert'ed that it's + # true). But we're code generating the actual runtime assert + # here!! + wrapper.writeline( + f"if not {V.graph.wrapper_code.codegen_python_sizevar(self.scalar, simplify=False)}:" + ) + wrapper.writeline(f" raise RuntimeError({repr(self.msg)})") + # No one should ever use this buffer, but for uniformity + # define the variable and assign it None + wrapper.writeline(f"{self.get_name()} = None") + + +@dataclasses.dataclass +class ExternKernelNode: + name: str + node: export_schema.Node + + +has_c_shim = OrderedSet( + [ + aten._embedding_bag.default, + aten._fft_c2c.default, + aten._scaled_dot_product_efficient_attention.default, + aten._scaled_dot_product_flash_attention.default, + aten._scaled_dot_product_cudnn_attention.default, + aten._scaled_mm.default, + aten.addmm.out, + aten.bmm.out, + aten.copy_.default, + aten.mm.out, + aten.repeat_interleave.Tensor, + aten.nonzero.default, + aten.view.dtype, + aten.view_as_real.default, + ] +) + + +class FallbackKernel(ExternKernelAlloc): + def __init__( + self, + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + kwargs=None, + *, + unbacked_bindings=None, + ): + if ( + kernel == aten.mul.Tensor + and len(tensor_args) == 1 + and len(nontensor_args) == 1 + ): + # When aten.mul.Tensor's second arg is constant, cpp wrapper expects + # to call mul_Scalar. A more proper fix is to do it in decomposition. 
+ # See https://github.com/pytorch/pytorch/issues/123478 + kernel = aten.mul.Scalar + + super().__init__( + layout, + tuple(tensor_args), + tuple(nontensor_args), + op_overload=kernel, + ) + + # We need output buffers for generating kernel arguments in the + # abi-compatible mode, where we retrieve outputs by pass each individual + # output through the abi-compatible interface. + self.outputs: Sequence[Any] = [] + self.use_runtime_dispatch = False + self.unbacked_bindings = unbacked_bindings + + assert isinstance( + kernel, + ( + torch._ops.OpOverload, + torch._ops.HigherOrderOperator, + ), + ), f"Fails to create FallbackKernel for {kernel}: {type(kernel)} not supported" + self.op_overload = kernel + self.unflatten_args = unflatten_args + self.kwargs = {} if kwargs is None else kwargs + V.graph.warn_fallback(self.python_kernel_name) # type: ignore[arg-type] + + # args that are aliased + self.alias_names: List[str] = [] + # args that are mutated AND returned from the op + self.mutation_names: List[str] = [] + + if isinstance(self.op_overload, torch._ops.HigherOrderOperator): + # We assume here that HOPs with FallbackKernel are functional. + # This may not always be true! HOPs must individually opt-in to + # FallbackKernel, so please check this if you opt-in. + return + + if "_c10d_functional" in self.op_overload.name(): + # _c10d_functional kernels are lowered into _CollectiveKernel which + # derives from FallbackKernel for the cpp codegen. The kernels + # don't pass the can_auto_functionalize check, but their mutation + # is handled properly by _CollectiveKernel. + return + + schema = self.op_overload._schema + + # NOTE: [FallbackKernel supported operators] + # We only support three types of operators: + # - functional ops + # - view ops + # - inplace aten ops + # - mutating ops that are auto-functionalizable. That is, + # the operator may mutate any number of inputs, but its outputs + # may not alias any of the inputs. + # + # The unsupported cases usually do not show up here (because + # AOTAutograd functionalized them away); the only way for an in-place + # op to show up here is if a lowering or pass introduced it. + if torch._library.utils.mutates_and_returns_first_arg(self.op_overload): + self.mutation_names.append(tensor_args[0].get_name()) + return + + if schema.is_mutable and not can_auto_functionalize(kernel): + raise NotImplementedError( + f"NYI: Can't generate FallbackKernel for {kernel}" + ) + + schema_args = schema.arguments + args, kwargs = self.unflatten_args(self.inputs, self.constant_args) + + def handle_aliasing_and_mutation(info, arg): + # Assertions to make sure we didn't mismatch args + if isinstance(info.type, torch.ListType): + assert isinstance(arg, (list, tuple)) + is_optional_tensor = isinstance( + info.type, torch.OptionalType + ) and isinstance(info.type.getElementType(), torch.TensorType) + is_list_tensor = isinstance(info.type, torch.ListType) and isinstance( + info.type.getElementType(), torch.TensorType + ) + if is_optional_tensor or isinstance(info.type, torch.TensorType): + # PyTorch also accepts None and scalar types for args marked as "Tensor". + # We're not going to check all of them here. 
+ assert not isinstance(arg, (tuple, list)) + + if arg is None: + return + if info.alias_info is None: + return + + def add_alias(t): + self.alias_names.append(t.get_name()) + if info.alias_info.is_write: + self.mutation_outputs.append( + MutationOutput(NoneLayout(t.get_device()), t, self) + ) + + if is_list_tensor: + for tensor_arg in arg: + add_alias(tensor_arg) + else: + assert isinstance(info.type, torch.TensorType) or is_optional_tensor + add_alias(arg) + + for info, arg in torch._library.utils.zip_schema(schema, args, kwargs): + handle_aliasing_and_mutation(info, arg) + + def codegen_unbacked_symbol_defs(self, wrapper): + if not hasattr(self, "unbacked_bindings"): + return + + unbacked_bindings = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, self.unbacked_bindings + ) + + if not unbacked_bindings: + return + + for s, keypath in unbacked_bindings.items(): + + def go(expr, keypath): + if keypath == (): + return expr + + if ( + len(keypath) >= 2 + and isinstance(keypath[0], CallMethodKey) + and isinstance(keypath[1], pytree.SequenceKey) + ): + return go( + f"{expr}.{keypath[0].name}({keypath[1].idx})", keypath[2:] + ) + elif isinstance(keypath[0], CallMethodKey): + return go(f"{expr}.{keypath[0].name}()", keypath[1:]) + elif isinstance(keypath[0], pytree.SequenceKey): + return ( + go(f"std::get<{keypath[0].idx}>({expr})", keypath[1:]) + if V.graph.cpp_wrapper + else go(f"{expr}[{keypath[0].idx}]", keypath[1:]) + ) + elif isinstance(keypath[0], DivideByKey): + # TODO: need to assert divisibility + # TODO: this is invalid C++ codegen + return go(f"{expr}.__floordiv__({keypath[0].divisor})", keypath[1:]) + else: + raise AssertionError(f"unrecognized keypath {keypath}") + + def go_outer(): + if V.graph.cpp_wrapper and config.abi_compatible: + # Special handling for the top level buffer access, + # because self.get_name() is actually never bound; the + # individual output arguments are bound by + # generate_c_shim_fallback_kernel + if len(self.outputs) == 1: + return go(self.outputs[0].get_name(), keypath) + else: + assert isinstance(keypath[0], pytree.SequenceKey) + return go(self.outputs[keypath[0].idx].get_name(), keypath[1:]) + else: + return go(self.get_name(), keypath) + + wrapper.writeline( + f"{wrapper.codegen_unbacked_symbol_decl(s)} = {go_outer()}{wrapper.ending}" + ) + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + if unbacked_bindings := getattr(self, "unbacked_bindings", None): + return resolve_unbacked_bindings( + V.graph.sizevars.shape_env, unbacked_bindings + ).keys() + else: + return OrderedSet() + + def codegen_args(self): + @dataclasses.dataclass + class Shim: + ref: Any + + def __repr__(self) -> str: + return self.ref + + tensor_args = [Shim(x.codegen_reference()) for x in self.inputs] + args, kwargs = self.unflatten_args(tensor_args, self.constant_args) + if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload): + args = self.fill_non_provided_args(args, kwargs) + args = [ + V.graph.wrapper_code.val_to_arg_str(x, param.real_type) + for param, x in zip(self.op_overload._schema.arguments, args) + ] + else: + args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args] + + # let self.codegen_kwargs handle kwargs + self.kwargs.update(kwargs) + return args + + @staticmethod + def find_device(tensor_args, example_output): + if tensor_args: + devices = [arg.get_device() for arg in tensor_args if arg.get_device()] + return devices[0] + if isinstance(example_output, torch.Tensor): + return example_output.device + if 
isinstance(example_output, (list, tuple)): + device_set = OrderedSet( + FallbackKernel.find_device(None, x) for x in example_output + ) + # Remove None + devices = [device for device in device_set if device] + if len(devices) == 1: + return devices[0] + for device in devices: + if is_gpu(device.type): + return device + return devices[0] + return None + + def has_side_effects(self): + if isinstance(self.op_overload, torch._ops.HigherOrderOperator): + return False + return get_schema_info(self.op_overload).is_mutable() + + def get_inputs_that_alias_output(self): + return self.alias_names + + def get_mutation_names(self): + assert len(self.mutation_names) <= 1 + return self.mutation_names + + # ProxyExecutor Design Note + # We export the ExternFallbackNodes (for custom ops) into a serialized file + # and run it with a host side proxy executor to address the ABI problem + # This is currently only implemented for fbcode. Eventually, we will also make this work for OSS. + # Detailed design doc can be found at + # https://docs.google.com/document/d/1wC4DOZFaYym2t1Esz0X5yxlLI3RDnSiyRbUus3bkJ64/edit?usp=sharing + def export_extern_kernel_node(self): + assert isinstance(self, FallbackKernel) + args, kwargs = self.unflatten_args(self.inputs, self.constant_args) + args = self.fill_non_provided_args(args, kwargs) + ordered_kwargs = [ + kwargs.get(key, None) for key in self.ordered_kwargs_for_cpp_kernel + ] + if not V.graph.aot_mode: + # No need to serialize in the cpp wrapper JIT mode + return [*args, *ordered_kwargs] + + serializer = GraphModuleSerializer(None, None) # type: ignore[arg-type] + named_arguments = serializer.serialize_inputs(self.op_overload, args, kwargs) # type: ignore[arg-type] + + # serialize_outputs + def handle_single_output(return_type, output): + if isinstance(return_type, torch.TensorType): + # For single Tensor + out = output + if isinstance(output, (list, tuple)): + assert len(output) == 1 + out = output[0] + return export_schema.Argument.create( + as_tensor=export_schema.TensorArgument(name=out.get_name()) + ) + elif isinstance(return_type, torch.ListType) and isinstance( + return_type.getElementType(), torch.TensorType + ): + # For single TensorList + return export_schema.Argument.create( + as_tensors=[ + export_schema.TensorArgument(name=out.get_name()) + for out in output + ] + ) + else: + raise RuntimeError(f"Unsupported return type {type(return_type)}") + + target = self.op_overload + returns = target._schema.returns # type: ignore[union-attr] + if len(returns) == 1: + return_type = returns[0].real_type + output_arguments = [handle_single_output(return_type, self.outputs)] + else: + # For tuple returns, e.g "-> (Tensor, Tensor)" or "-> (Tesnor, Tensor[])" + assert isinstance(self.outputs, tuple) + assert len(returns) == len(self.outputs) + output_arguments = [ + handle_single_output(return_schema.real_type, output) + for return_schema, output in zip(returns, self.outputs) + ] + + node = ExternKernelNode( + name=self.get_name(), + node=export_schema.Node( + target=self.op_overload.name(), # type: ignore[union-attr] + inputs=named_arguments, + outputs=output_arguments, + metadata={}, + ), + ) + + V.graph.extern_kernel_nodes.append(node) + + return [*args, *ordered_kwargs] + + def codegen(self, wrapper): + kernel = self.op_overload + if kernel.namespace == "aten": # type: ignore[union-attr] + # Aten Fallback Ops + assert isinstance(kernel, torch._ops.OpOverload) + if V.graph.cpp_wrapper: + from torchgen.aoti.fallback_ops import inductor_fallback_ops + + if 
config.abi_compatible and str(kernel) not in inductor_fallback_ops: + # C shim v2 is torchgen-ed, which should cover all aten ops. + # If you do hit a missed op, please update fallback_ops.py. + log.warning( + "%s is missing a c-shim implementation, using proxy executor as fallback", + kernel, + ) + self.use_runtime_dispatch = True + elif kernel.namespace == "_quantized": # type: ignore[union-attr] + # Internal Quantized Fallback Ops + assert isinstance(kernel, torch._ops.OpOverload) + if V.graph.cpp_wrapper: + if not config.abi_compatible: + self.use_runtime_dispatch = True + else: + # For non-aten OpOverload, i.e. custom ops + if V.graph.cpp_wrapper: + self.use_runtime_dispatch = True + + if self.use_runtime_dispatch: + self.codegen_comment(wrapper) + + exported_args = None + args = None + if config.abi_compatible: + exported_args = self.export_extern_kernel_node() + else: + args = [*self.codegen_args(), *self.codegen_kwargs()] + + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + args, + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + self.op_overload, + exported_args, + self.outputs, + ) + else: + self.codegen_comment(wrapper) + args = [*self.codegen_args(), *self.codegen_kwargs()] + V.graph.wrapper_code.generate_fallback_kernel(self, args) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + self.codegen_unbacked_symbol_defs(wrapper) + + @staticmethod + def tensor_to_layout(output: torch.Tensor): + return FixedLayout( + output.device, + output.dtype, + convert_shape_to_inductor(output.size()), + convert_shape_to_inductor(output.stride()), + ) + + @classmethod + def create(cls, kernel, *args, **kwargs): + fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,) + context: ContextManager[None] = ( + V.graph.fake_mode if kernel not in fake_incorrect_kernels else nullcontext() # type: ignore[assignment] + ) + with context: + ( + example_output, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings, + ) = cls.process_kernel(kernel, *args, **kwargs) + + device = cls.find_device(tensor_args, example_output) + if example_output is None: + packed = cls( + NoneLayout(device), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings=unbacked_bindings, + ) + + else: + assert device, "Not sure where to find device info" + packed = cls( + MultiOutputLayout(device), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings=unbacked_bindings, + ) + + def generate_output(output, indices): + if isinstance(output, (list, tuple)): + return type(output)( + generate_output(output[i], indices + [(type(output), i)]) + for i in range(len(output)) + ) + elif isinstance(output, dict): + return { + key: generate_output(val, indices + [(type(output), key)]) + for key, val in output.items() + } + elif isinstance(output, torch.Tensor): + return MultiOutput( + cls.tensor_to_layout(output), + packed, + indices, + ) + elif isinstance(output, int): + return output + elif isinstance(output, torch.SymInt): + return output.node.expr + else: + assert ( + output is None + ), f"FallbackKernel output type {type(output)} is not supported" + return None + + outputs = generate_output(example_output, []) + if isinstance(outputs, (list, tuple, dict)): + packed.outputs = outputs # type: ignore[assignment] + else: + packed.outputs = [outputs] + return outputs + + def apply_constraint(self): + return 
super().apply_constraint() + + +@dataclasses.dataclass +class ComplexView(FallbackKernel): + """View a complex number as two dtyped numbers or vice versa""" + + def should_allocate(self): + return False + + def get_inputs_that_alias_output(self): + # Signal to codegen that our output buffer isn't safe to reuse + return [self.inputs[0].get_name()] + + def __init__( + self, + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + *, + unbacked_bindings=None, + ): + super().__init__( + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + unbacked_bindings=unbacked_bindings, + ) + + +@dataclasses.dataclass +class MultiOutputLayout(IRNode): + device: torch.device + + +class MultiOutput(ExternKernel): + # Given an input MultiOutputLayout buffer, indexes out an actual buffer + # from that result. This doesn't actually produce multiple outputs, + # that's MultiOutputLayout! + def codegen_list_tuple_access(self, basename, indices): + if len(indices) > 0: + itype, i = indices[0] + if issubclass(itype, list): + return self.codegen_list_tuple_access(f"{basename}[{i}]", indices[1:]) + elif issubclass(itype, tuple): + # cpp wrapper code needs to use std::get<> to access a tuple + tuple_access = V.graph.wrapper_code.codegen_tuple_access( + basename, self.get_name(), str(i) + ) + return self.codegen_list_tuple_access(tuple_access, indices[1:]) + elif issubclass(itype, dict): + return self.codegen_list_tuple_access(f"{basename}['{i}']", indices[1:]) + else: + raise AssertionError("non supported index type: ", itype) + else: + return basename + + def codegen(self, wrapper): + wrapper.codegen_multi_output( + self.get_name(), + self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices), + ) + + def __init__(self, layout, input, indices: List[Tuple[Any, ...]]): + super().__init__(None, layout, [input], ()) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + self.indices = indices + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + return self.inputs[0].get_unbacked_symbol_uses() + + def should_allocate(self): + return False + + def get_inputs_that_alias_output(self): + return [ + inp.get_name() + for inp in self.inputs + if isinstance(inp, FallbackKernel) + and len(inp.get_inputs_that_alias_output()) > 0 + ] + + +@dataclasses.dataclass +class MutableBox(IRNode): + """ + TensorBox / StorageBox allow in-place mutation of Tensors + """ + + data: IRNode + + def __getattr__(self, name): + fn = getattr(self.data, name) + if callable(fn): + return fn + raise AttributeError(f"{type(self.data).__name__}.{name} not callable") + + def realize(self): + return self.data.realize() + + def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: + return self.data.get_unbacked_symbol_uses() + + def get_read_names(self) -> OrderedSet[str]: + return self.data.get_read_names() + + def get_defining_op(self): + return self.data.get_defining_op() + + def codegen_reference(self, writer=None): + return self.data.codegen_reference(writer) + + @property + def layout(self): + return self.data.get_layout() + + def get_layout(self): + return self.layout + + def get_size(self): + return self.data.get_size() + + @property + def dtype(self): + return self.data.dtype + + def __str__(self) -> str: + if isinstance(self.data, MutableBox): + line0 = f"{type(self).__name__}({type(self.data).__name__}(" + endl = "))" + inner = self.data.data + else: + line0 = f"{type(self).__name__}(" + inner = self.data + endl = ")" + + lines = [ + line0, + 
indent(str(inner)), + endl, + ] + return "\n".join(lines) + + __repr__ = __str__ + + +class TensorBox(MutableBox): + @staticmethod + def create(data): + return TensorBox(StorageBox(data)) + + +class StorageBox(MutableBox): + def is_input_buffer(self): + if isinstance(self.data, (InputBuffer, ReinterpretView)): + return self.data.get_name() in V.graph.graph_inputs + return False + + def is_module_buffer(self): + return ( + isinstance(self.data, (ConstantBuffer)) + and self.data.get_name() in V.graph.constants + ) + + def realize(self): + if isinstance( + self.data, + ( + ComputedBuffer, + InputsKernel, + InputBuffer, + ReinterpretView, + TemplateBuffer, + ), + ): + return self.data.get_name() + assert isinstance(self.data, (Pointwise, Reduction, Scan, Sort)), type( + self.data + ) + origin_node = self.data.get_origin_node() + traceback = self.data.get_traceback() + self.data = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=self.data.get_device(), + dtype=self.data.get_dtype(), + size=self.data.get_size(), + ), + data=self.data, + ) + self.data.name = V.graph.register_buffer(self.data) + V.graph.register_operation(self.data) + self.data.origins = self.origins + self.data.origin_node = origin_node + self.data.traceback = traceback + return self.data.name + + def realize_hint(self): + """ + Called on buffers we expect to be forced to realize later. + """ + if ( + isinstance(self.data, (Pointwise, Reduction)) + and self.data.inner_fn_opcount().nontrivial_read_count > 1 + ): + self.realize() + + def has_exceeded_max_reads(self): + return isinstance(self.data, Pointwise) and ( + self.num_reads() > config.realize_acc_reads_threshold + or self.has_large_inner_fn() + ) + + def should_realize_on_reuse(self, users): + """ + A heuristic to decide if we should realize a tensor + that is used multiple times. 
+ """ + if users > 1 and isinstance(self.data, (Pointwise, Reduction)): + if is_cpu(self.data): + # Heuristic for realizing reused result of heavy ops on cpu + opcount = self.data.inner_fn_opcount() + heavy_ops = ["exp", "sigmoid"] # a list of heavy ops + if any(x in opcount.used_ops for x in heavy_ops): + return True + return ( + self.num_reads() > config.realize_reads_threshold + or self.has_large_inner_fn() + ) + return False + + def mark_reuse(self, users): + if self.should_realize_on_reuse(users): + self.realize() + + def num_reads(self): + return self.data.num_reads() + + +@dataclasses.dataclass +class Subgraph(IRNode): + name: str + graph_module: torch.fx.GraphModule + graph: Optional[GraphLowering] = None + + +def _has_aliased_buffers(buffers: Sequence[IRNode]) -> bool: + buffers = [ + buffer.unwrap_view() if isinstance(buffer, ReinterpretView) else buffer + for buffer in buffers + ] + # assuming the same buffer is represented by the same IRNode object + return len(OrderedSet(id(buffer) for buffer in buffers)) < len(buffers) + + +@dataclasses.dataclass +class Conditional(ExternKernel): + predicate: Optional[IRNode] = None + operands: Optional[List[TensorBox]] = None + true_subgraph: Optional[Subgraph] = None + false_subgraph: Optional[Subgraph] = None + outputs: Optional[List[MultiOutput]] = None + + def __init__( + self, + predicate: IRNode, + operands: List[TensorBox], + true_subgraph: Subgraph, + false_subgraph: Subgraph, + layout: MultiOutputLayout, + ): + self.predicate = predicate + self.operands = operands + self.true_subgraph = true_subgraph + self.false_subgraph = false_subgraph + + inputs = [] + if not isinstance(predicate, ShapeAsConstantBuffer): + inputs.append(predicate) + inputs.extend(operands) + + super().__init__( + name=None, + layout=layout, # type: ignore[arg-type] + inputs=inputs, # type: ignore[list-item] + ) + + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + @classmethod + def create( + cls, + predicate: TensorBox, + true_fn: Subgraph, + false_fn: Subgraph, + operands: List[TensorBox], + ): + predicate = cls.realize_input(predicate) + operands = [cls.realize_input(x) for x in operands] + + fx_operands = V.graph.current_node.args[-1] + fake_operands = [x.meta["val"] for x in fx_operands] # type: ignore[union-attr] + + for subgraph in (true_fn, false_fn): + if subgraph.graph is None: + # create and lower subgraphs + subgraph.graph = V.graph.make_subgraph( + gm=subgraph.graph_module, + example_inputs=fake_operands, + subgraph_name=subgraph.name, + ) + with V.set_graph_handler(subgraph.graph): + subgraph.graph.run(*fake_operands) + + true_outputs = true_fn.graph.graph_outputs # type: ignore[union-attr] + false_outputs = false_fn.graph.graph_outputs # type: ignore[union-attr] + + for name, outputs in (("true_fn", true_outputs), ("false_fn", false_outputs)): + if _has_aliased_buffers(true_outputs): + raise AssertionError( + "Output aliasing is currently not supported in compiled torch.cond. 
" + f"The outputs of the {name} subgraph of torch.cond are aliased: {outputs}" + ) + + # make sure true and false outputs are structurally equivalent + assert len(true_outputs) == len(false_outputs), (true_outputs, false_outputs) + for i, (to, fo) in enumerate(zip(true_outputs, false_outputs)): + assert to.get_size() == fo.get_size(), (i, to, fo) + assert to.get_stride() == fo.get_stride(), (i, to, fo) + assert to.get_device() == fo.get_device(), (i, to, fo) + assert to.get_dtype() == fo.get_dtype(), (i, to, fo) + assert to.get_layout().offset == fo.get_layout().offset, (i, to, fo) + + if not isinstance(predicate, ShapeAsConstantBuffer): + # use predicate device for consistent codegen-ing + device = predicate.get_device() + else: + # predicate is not a Tensor: use first operand's device + assert ( + len(operands) > 0 + ), "When predicate is not a Tensor, there must be at least one operand in torch.cond." + device = operands[0].get_device() + + conditional = Conditional( + predicate=predicate, + operands=operands, + true_subgraph=true_fn, + false_subgraph=false_fn, + layout=MultiOutputLayout(device), + ) + + outputs = [ + MultiOutput( + FixedLayout( + device=output.get_device(), + dtype=output.get_dtype(), + size=output.get_size(), + stride=output.get_stride(), + offset=output.get_layout().offset, + ), + conditional, + [(list, i)], + ) + # as the true and false outputs are equivalent, + # we can use either of them here as a "template" + for i, output in enumerate(true_outputs) + ] + + conditional.outputs = outputs + return outputs + + def codegen(self, wrapper): + wrapper.codegen_conditional(self) + + +@dataclasses.dataclass +class WhileLoop(ExternKernel): + carried_inputs: Optional[List[TensorBox]] = None + additional_inputs: Optional[List[TensorBox]] = None + cond_subgraph: Optional[Subgraph] = None + body_subgraph: Optional[Subgraph] = None + outputs: Optional[List[MultiOutput]] = None + + def __init__( + self, + carried_inputs: List[TensorBox], + additional_inputs: List[TensorBox], + cond_subgraph: Subgraph, + body_subgraph: Subgraph, + layout: MultiOutputLayout, + ): + self.carried_inputs = carried_inputs + self.additional_inputs = additional_inputs + self.cond_subgraph = cond_subgraph + self.body_subgraph = body_subgraph + + super().__init__( + name=None, + layout=layout, # type: ignore[arg-type] + inputs=carried_inputs + additional_inputs, # type: ignore[list-item] + ) + + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + @classmethod + def create( + cls, + cond_fn: Subgraph, + body_fn: Subgraph, + carried_inputs: List[TensorBox], + additional_inputs: List[TensorBox], + ): + carried_inputs = [cls.realize_input(x) for x in carried_inputs] + additional_inputs = [cls.realize_input(x) for x in additional_inputs] + all_inputs = carried_inputs + additional_inputs + + fx_all_inputs = V.graph.current_node.args[-2] + V.graph.current_node.args[-1] # type: ignore[operator] + fake_all_inputs = [x.meta["val"] for x in fx_all_inputs] # type: ignore[union-attr] + + for subgraph in (cond_fn, body_fn): + if subgraph.graph is None: + # create and lower subgraphs + subgraph.graph = V.graph.make_subgraph( + gm=subgraph.graph_module, + example_inputs=fx_all_inputs, # type: ignore[arg-type] + subgraph_name=subgraph.name, + ) + with V.set_graph_handler(subgraph.graph): + subgraph.graph.run(*fake_all_inputs) + + cond_outputs = cond_fn.graph.graph_outputs # type: ignore[union-attr] + body_outputs = body_fn.graph.graph_outputs # type: ignore[union-attr] + + if 
_has_aliased_buffers(body_outputs): + raise AssertionError( + "Output aliasing is currently not supported in compiled torch.while_loop. " + f"The outputs of the body_fn subgraph of torch.while_loop are aliased: {body_outputs}" + ) + + # make sure cond_fn returns a boolean scalar Tensor + assert len(cond_outputs) == 1, cond_outputs + assert cond_outputs[0].get_dtype() == torch.bool, cond_outputs + assert len(cond_outputs[0].get_size()) == 0, cond_outputs + + assert ( + len(all_inputs) > 0 + ), "torch.while_loop is assumed to have at least one operand." + + device = all_inputs[0].get_device() + + # make sure carried_inputs and body outputs are structurally equivalent + assert len(carried_inputs) == len(body_outputs), (carried_inputs, body_outputs) + for i, (op, bo) in enumerate(zip(carried_inputs, body_outputs)): + assert op.get_size() == bo.get_size(), (i, op, bo) + assert op.get_stride() == bo.get_stride(), (i, op, bo) + # assume all carried_inputs and outputs are on the same device + # as the MultiOutputLayout below requires single device + assert op.get_device() == bo.get_device() == device, (i, op, bo, device) + assert op.get_dtype() == bo.get_dtype(), (i, op, bo) + assert op.get_layout().offset == bo.get_layout().offset, (i, op, bo) + + while_loop = WhileLoop( + carried_inputs=carried_inputs, + additional_inputs=additional_inputs, + cond_subgraph=cond_fn, + body_subgraph=body_fn, + # asserted above that there is at least one operand + layout=MultiOutputLayout(device), + ) + + outputs = [ + MultiOutput( + FixedLayout( + device=output.get_device(), + dtype=output.get_dtype(), + size=output.get_size(), + stride=output.get_stride(), + offset=output.get_layout().offset, + ), + while_loop, + [(list, i)], + ) + for i, output in enumerate(body_outputs) + ] + + for inp, out in zip(carried_inputs, outputs): + if inp.get_name() in V.graph.graph_inputs: + # if a carried input of the while_loop is a graph input, + # it can be returned as is when the number of iterations + # is zero. due to this, we can't (generally) reuse the + # output buffers corresponding to the graph inputs, as + # the inputs may end up being mutated. + V.graph.never_reuse_buffers.add(out.get_name()) + + while_loop.outputs = outputs + return outputs + + def codegen(self, wrapper): + wrapper.codegen_while_loop(self) + + +class EffectfulKernel(FallbackKernel): + def __init__( + self, + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + kwargs=None, + *, + unbacked_bindings=None, + ): + super().__init__( + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + kwargs=None, + unbacked_bindings=unbacked_bindings, + ) + + from torch._higher_order_ops.effects import get_effect_key + + effect_type = get_effect_key(kernel, (*nontensor_args, *tensor_args), kwargs) + assert effect_type is not None + self.effect_type = effect_type + self.prev_effect_buffer = V.graph.effectful_ops.get(effect_type, None) + V.graph.effectful_ops[effect_type] = self + + def get_read_writes(self): + read_writes = super().get_read_writes() + + if self.prev_effect_buffer is not None: + read_writes.reads.add( + dependencies.StarDep(self.prev_effect_buffer.get_name()) + ) + + return read_writes + + def has_side_effects(self): + return True + + +@dataclasses.dataclass +class TorchBindObject(IRNode): + name: str + value: torch._C.ScriptObject + + def get_name(self): + return self.name + + def get_device(self): + return None # is there a device?? 
+ + def codegen_reference(self, writer=None): + return self.name + + +class _CollectiveKernel(FallbackKernel): + def should_allocate(self): + return False + + def has_side_effects(self): + return True + + # This is identical to FallbackKernel.set_cpp_kernel(), minus the + # part that checks against input aliasing and mutation. + def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None): + from .codegen.wrapper import get_cpp_op_schema + + assert ( + type(self.op_overload) is torch._ops.OpOverload + ), "Setting cpp kernel needs a valid op_overload" + kernel = self.op_overload + self.cpp_kernel_name = kernel._schema.name + self.cpp_kernel_overload_name = kernel._schema.overload_name + self.cpp_kernel_key = f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}" # type: ignore[union-attr] + + self.cpp_op_schema = get_cpp_op_schema(kernel) + self.ordered_kwargs_for_cpp_kernel = [ + x.name for x in kernel._schema.arguments if x.kwarg_only + ] + + # NOTE: [In-Place Collective Safety] + # Between the initiation and completion of an in-place collective, the + # input buffers are subject to both volatile reads and volatile writes. + # They must not be read, written to or reused by another kernel. To ensure + # the constraints, we model collective -> wait_tensor as as two-step + # mutation of the input buffers. + @classmethod + def create_inplace( + cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs + ) -> None: + with V.graph.fake_mode: + ( + example_output, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings, + ) = cls.process_kernel(kernel, inputs, *args, **kwargs) + assert not unbacked_bindings, f"{kernel} {unbacked_bindings}" + for tensor_arg in tensor_args: + tensor_arg.realize() + + device = tensor_args[0].get_device() + packed = cls( + NoneLayout(device), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + + inps = pytree.tree_leaves(inputs) + packed.mutation_outputs.extend( + [MutationOutput(NoneLayout(device), buf, packed) for buf in inps] + ) + + # For inplace collective ops, the input is guaranteed to be alias of the returned value of op. + packed.alias_names.extend([inp.get_name() for inp in inps]) + if "out" in kwargs: + packed.mutation_outputs.append( + MutationOutput(NoneLayout(device), kwargs["out"], packed) + ) + # For out-variant collective ops, the `out=` arg is guaranteed to be alias of the returned value of op. + packed.alias_names.append(kwargs["out"].get_name()) + + # NOTE: [Out-of-Place Collective Safety] + # Between the initiation and completion of an out-of-place collective: + # + # Input buffers: + # - Are subject to volatile reads + # - Can be read by another kernel + # - Must not be written to or reused by another kernel + # + # Output buffers: + # - Are subject to volatile writes + # - Must not be read, written to or reused by another kernel + # + # To ensure the safety of input buffers without sacrificing read + # availability, we add input buffers as read deps of wait_tensor kernels. + # + # To ensure the safety of output buffers, we model wait_tensor as a + # mutation to the output buffer. Note we also assumes the user program being + # correct and the output buffer is not consumed by kernels other than + # wait_tensor. + # + # TODO(yifu): add a pre-grad pass to validate the correctness of collective + # usage in the user program. 
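+    # Illustrative example (op and buffer names are hypothetical): for a program like
+    #     out = some_functional_collective(inp); wait_tensor(out)
+    # `inp` is added as a read dep of the wait_tensor node, so it remains readable by
+    # other kernels but is not reused or overwritten before the wait, while `out` is
+    # modeled as mutated by the wait, keeping other kernels from consuming it before
+    # the collective completes.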
+ @classmethod + def create_out_of_place( + cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs + ): + with V.graph.fake_mode: + ( + example_output, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings, + ) = cls.process_kernel(kernel, inputs, *args, **kwargs) + assert not unbacked_bindings, f"{kernel}, {unbacked_bindings}" + for tensor_arg in tensor_args: + tensor_arg.realize() + + if isinstance(example_output, list): + device = cls.find_device(tensor_args, example_output) + packed = cls( + MultiOutputLayout(device), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + packed.outputs = [ + MultiOutput( + cls.tensor_to_layout(tensor), + packed, + [(list, i)], + ) + for i, tensor in enumerate(example_output) + ] + return packed.outputs + else: + packed = cls( + cls.tensor_to_layout(example_output), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + packed.outputs = [packed] + return packed + + +class _WaitKernel(_CollectiveKernel): + def get_volatile_reads(self): + inp = self.inputs[0] + if isinstance(inp, _CollectiveKernel): + # Out-of-place single-output + return [inp.inputs[0]] + elif isinstance(inp, MultiOutput): + # This can be two things: + # 1. Out-of-place multi-output coll + # 2. In-place coll with inputs coming from another MultiOutput + coll = inp.inputs[0] + # Case 1 + if isinstance(coll, _CollectiveKernel): + _, idx = inp.indices[0] + return [coll.inputs[idx]] + # Case 2 + return [] + else: + # In-place requires no additional deps handling for volatile + # reads since the inputs are mutated. + return [] + + @classmethod + def create_wait(cls, kernel, inp: TensorBox) -> None: + with V.graph.fake_mode: + ( + example_output, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings, + ) = cls.process_kernel(kernel, inp) + assert not unbacked_bindings, f"{kernel} {unbacked_bindings}" + packed = cls( + NoneLayout(inp.get_device()), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + packed.mutation_outputs.append( + MutationOutput(NoneLayout(inp.get_device()), inp, packed) + ) + + def get_read_writes(self): + read_writes = super().get_read_writes() + # See [Out-of-Place Collective Safety]. 
+ volatile_reads = self.get_volatile_reads() + for vr in volatile_reads: + read_writes.reads.add(dependencies.StarDep(vr.get_name())) + return read_writes + + +# NB: recursive structure here reflects val_to_arg_str, avoid +# calling free_unbacked_symbols on "exotic" types that don't get pexpr +# treatment +def maybe_free_unbacked_symbols(s: object) -> OrderedSet[Symbol]: + if isinstance(s, (SymTypes, Expr)): + # This branch should be impossible in return position + return free_unbacked_symbols(s) + elif isinstance(s, (tuple, list)): + r: OrderedSet[sympy.Symbol] = OrderedSet() + for t in s: + r |= maybe_free_unbacked_symbols(t) + return r + elif isinstance(s, torch.Tensor): + # This branch is impossible in constant-args position + return free_unbacked_symbols(s) + else: + return OrderedSet() diff --git a/lib/python3.10/site-packages/torch/_inductor/jagged_lowerings.py b/lib/python3.10/site-packages/torch/_inductor/jagged_lowerings.py new file mode 100644 index 0000000000000000000000000000000000000000..c96c9f4ae2d443779f7514e3077287c327c92123 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/jagged_lowerings.py @@ -0,0 +1,264 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +from typing import List, Optional, Tuple, Union + +import sympy + +import torch + +from .ir import Pointwise, TensorBox +from .lowering import fallback_handler, is_integer_type, register_lowering +from .virtualized import ops + + +# pyre-ignore[2,3] +def dense_idx_to_jagged_idx(batch_idx, seq_idx, offsets_loader, jagged_len): + # jagged_len + 1 is used as the upper bound, + # because the last sequence length may be zero + begin_idx = ops.indirect_indexing( + offsets_loader([batch_idx]), + jagged_len + 1, + ) + end_idx = offsets_loader([batch_idx + 1]) + jagged_idx = begin_idx + seq_idx + return jagged_idx, end_idx + + +def get_inverse_offsets( + offsets: TensorBox, + jagged_len: Union[int, sympy.Expr], + realize: bool = True, +) -> TensorBox: + """ + Returns "inverse_offsets" - the inverse of the offsets array. + offsets maps batch index (dense) to jagged index (i.e. offset into jagged tensor). + inverse_offsets maps jagged index to batch index. + + e.g. for offsets [0, 3, 4, 9, 10] this will return + inverse_offsets = [0, 0, 0, 1, 2, 2, 2, 2, 2, 3] + + For the given offsets, the computed inverse_offsets are cached + on the first call and reused in the further calls. + """ + + if hasattr(offsets, "inverse_offsets"): + # inverse_offsets are already computed + # for these offsets: can reuse + return offsets.inverse_offsets + + # ops.bucketize takes offsets.get_name() which doesn't exist on Pointwise + # kernels, i.e. we need to realize it before using. In other words, we need + # offsets to be in global memory so that we can binary search over the + # entire tensor + offsets.realize() + device: torch.device = offsets.get_device() + dtype: torch.dtype = offsets.get_dtype() + + # pyre-ignore[2,3] + def inner_fn(index): + idx = index[0] + bucket = ops.bucketize( + values=ops.index_expr(idx, dtype), + offsets_name=offsets.get_name(), + offsets_size=offsets.get_size()[0], + indexing_dtype=dtype, + right=True, + ) + # ops.bucketize above returns 1-based bucket indices, + # but we need 0-based, hence we subtract 1 from batch + return bucket - 1 + + inverse_offsets = Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=[jagged_len], + ) + + if realize: + # "freeze" the node so that it doesn't get inlined downstream. 
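+        # Realizing here means later consumers (e.g. several jagged lowerings that share
+        # the same offsets tensor) reuse the materialized buffer via the cache below
+        # instead of re-running the bucketize search for every use.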
+ inverse_offsets.realize() + + # cache inverse_offsets for further reuse + offsets.inverse_offsets = inverse_offsets # type: ignore[attr-defined] + + return inverse_offsets + + +def jagged_idx_to_dense_idx( + jagged_idx, # pyre-ignore[2] + inverse_offsets_loader, # pyre-ignore[2] + offsets_loader, # pyre-ignore[2] + batch_size: Union[int, sympy.Expr], + max_seq_len: Union[int, sympy.Expr], + offsets_dtype: torch.dtype, +) -> Tuple[sympy.Expr, sympy.Expr]: + batch_idx = ops.indirect_indexing( + inverse_offsets_loader([jagged_idx]), + batch_size + 1, + ) + batch_start = offsets_loader([batch_idx]) + seq = ops.index_expr(jagged_idx, offsets_dtype) - batch_start + # check=False because there may be sequences longer than max_seq_len + seq_idx = ops.indirect_indexing(seq, max_seq_len, check=False) + return batch_idx, seq_idx + + +def register_jagged_ops(): + # pyre-ignore[56] + @register_lowering(torch.ops.aten._jagged_to_padded_dense_forward.default) + def _jagged_to_padded_dense_forward( + jagged_values: TensorBox, + jagged_offsets: List[TensorBox], + max_lengths: List[int], # list of ints/SymInts + padding_value: float = 0.0, + ) -> TensorBox: + device = jagged_values.get_device() + dtype = jagged_values.get_dtype() + + jagged_values_size = jagged_values.get_size() + + # only handle the common case of a single jagged dimension + if ( + len(jagged_offsets) != 1 + or device.type != "cuda" + or device != jagged_offsets[0].get_device() + or len(jagged_values_size) != 2 + or len(jagged_offsets[0].get_size()) != 1 + or len(max_lengths) != len(jagged_offsets) + or not is_integer_type(jagged_offsets[0]) + ): + return fallback_handler( + torch.ops.aten._jagged_to_padded_dense_forward.default, + add_to_fallback_set=False, + )( + jagged_values, + jagged_offsets, + max_lengths, + padding_value, + ) + + offsets: TensorBox = jagged_offsets[0] + offsets_len = offsets.get_size()[0] + offsets_dtype = offsets.get_dtype() + batch_size = offsets_len - 1 + max_seq_len = max_lengths[0] + embedding_len = jagged_values_size[1] + jagged_len = jagged_values_size[0] + + output_size = [batch_size, max_seq_len, embedding_len] + + values_loader = jagged_values.make_loader() + offsets_loader = offsets.make_loader() + + # pyre-ignore[2,3,53] + def inner_fn(index): + # dense tensor size: [B, N, D] + batch_idx, seq_idx, emb_idx = index + jagged_idx, end_idx = dense_idx_to_jagged_idx( + batch_idx=batch_idx, + seq_idx=seq_idx, + offsets_loader=offsets_loader, + jagged_len=jagged_len, + ) + return ops.masked( + ops.lt( + ops.index_expr(jagged_idx, offsets_dtype), + end_idx, + ), + lambda: values_loader([jagged_idx, emb_idx]), + padding_value, + ) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=output_size, + ) + + def _dense_to_jagged_forward_impl( + fallback_op, # pyre-ignore[2] + dense: TensorBox, + jagged_offsets: List[TensorBox], + jagged_len: Optional[int] = None, + ) -> TensorBox: + device = dense.get_device() + dtype = dense.get_dtype() + + dense_size = dense.get_size() + + # only handle the common case of a single jagged dimension + if ( + len(jagged_offsets) != 1 + or device.type != "cuda" + or device != jagged_offsets[0].get_device() + or len(jagged_offsets[0].get_size()) != 1 + or len(dense_size) != 3 + or jagged_len is None + or not is_integer_type(jagged_offsets[0]) + ): + return fallback_handler(fallback_op, add_to_fallback_set=False)( + dense, + jagged_offsets, + jagged_len, + ) + + offsets: TensorBox = jagged_offsets[0] + offsets_dtype = offsets.get_dtype() + batch_size 
= dense_size[0] + max_seq_len = dense_size[1] + embedding_len = dense_size[-1] + + output_size = [jagged_len, embedding_len] + + dense_loader = dense.make_loader() + offsets_loader = offsets.make_loader() + + inverse_offsets = get_inverse_offsets( + offsets=offsets, + jagged_len=jagged_len, + ) + inverse_offsets_loader = inverse_offsets.make_loader() + + # pyre-ignore[2,3,53] + def inner_fn(index): + # jagged tensor size: [sum_B(N_B), D] + jagged_idx, emb_idx = index + batch_idx, seq_idx = jagged_idx_to_dense_idx( + jagged_idx=jagged_idx, + offsets_loader=offsets_loader, + inverse_offsets_loader=inverse_offsets_loader, + batch_size=batch_size, + max_seq_len=max_seq_len, + offsets_dtype=offsets_dtype, + ) + return ops.masked( + ops.lt( + ops.index_expr(seq_idx, offsets_dtype), + ops.index_expr(max_seq_len, offsets_dtype), + ), + lambda: dense_loader([batch_idx, seq_idx, emb_idx]), + 0.0, # jagged sequence longer than max_seq_len + ) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=output_size, + ) + + # pyre-ignore[56] + @register_lowering(torch.ops.aten._padded_dense_to_jagged_forward) + def _dense_to_jagged_forward( + dense: TensorBox, + jagged_offsets: List[TensorBox], + jagged_len: Optional[int] = None, + ) -> TensorBox: + return _dense_to_jagged_forward_impl( + fallback_op=torch.ops.aten._padded_dense_to_jagged_forward.default, + dense=dense, + jagged_offsets=jagged_offsets, + jagged_len=jagged_len, + ) diff --git a/lib/python3.10/site-packages/torch/_inductor/loop_body.py b/lib/python3.10/site-packages/torch/_inductor/loop_body.py new file mode 100644 index 0000000000000000000000000000000000000000..2f0b00b724e71fb8620d45f6601635f37772baf4 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/loop_body.py @@ -0,0 +1,594 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import functools +import itertools +import re +from enum import auto, Enum +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple + +import sympy + +import torch.fx +from torch._dynamo.utils import identity +from torch.utils._sympy.symbol import SymT + +from . import config, dependencies +from .codegen.common import index_prevent_reordering +from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs +from .virtualized import ops, V + + +class InterpreterShim(torch.fx.Interpreter): + @staticmethod + @functools.lru_cache(None) + def _dummy_gm(): + return torch.fx.symbolic_trace(identity) + + def __init__(self, graph, submodules): + # call super() with a placeholder to avoid constructing a + # GraphModule which is very expensive (it does codegen). 
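+        # _dummy_gm() is an lru_cache'd symbolic trace of the identity function,
+        # so the GraphModule construction cost is paid at most once and shared
+        # by every InterpreterShim instance.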
+ super().__init__(self._dummy_gm(), garbage_collect_values=False) + self.module = self # type: ignore[assignment] + self.graph = graph + self.submodules = submodules + self.extra_traceback = False + self.fetch_attr = submodules.__getitem__ # type: ignore[method-assign] + self.current_node = None + + def run_node(self, n: torch.fx.Node) -> Any: + self.current_node = n + return super().run_node(n) + + def run(self, *args, **kwargs): + with V.set_interpreter_handler(self): + return super().run(*args, **kwargs) + + +class MemoryEntry(NamedTuple): + index_name: str # LoopBody.indexing_exprs[index_name] + buffer_name: Optional[str] + mode: Optional[str] # V.ops.store(..., mode=mode) + + +class MemoryUsageType(Enum): + # These are 1:1 with the opcode generating the usage + LOAD = auto() + LOAD_SEED = auto() + STORE = auto() + STORE_REDUCTION = auto() + INDEX_EXPR = auto() + CHECK_BOUNDS = auto() + BUCKETIZE = auto() + + +class LoopBody: + """ + Captures the body of a Loops subclass into an FX graph. Persists any + indexing simplifications and makes it easier to analyze loop bodies. + """ + + indexing_exprs: Dict[str, sympy.Expr] + indexing_exprs_name: Dict[sympy.Expr, str] + submodules: Dict[str, Any] + subblocks: Dict[str, LoopBodyBlock] + indirect_vars: List[str] + indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] + root_block: LoopBodyBlock + memory_usage: Dict[MemoryUsageType, List[MemoryEntry]] + + def __init__(self, fn, args, var_ranges, iter_vars, reduce_vars): + super().__init__() + + _flat_sizes = tuple(var_ranges.values()) + self.sizes = ( + _flat_sizes[: len(iter_vars)], + _flat_sizes[len(iter_vars) :], + ) + + self.iter_vars = iter_vars + self.reduce_vars = reduce_vars + self.var_ranges = var_ranges + + if isinstance(fn, LoopBody): + self._init_with_copy(fn, args) + else: + self._init_with_tracing(fn, args) + + self.indexing = None + + def _init_with_tracing(self, fn, args): + """Do an FX trace of an arbitrary callable to construct self""" + self.indexing_exprs = {} + self.indexing_exprs_name = {} + self.submodules = {"get_index": self.get_index} + self.subblocks = {} + self.indirect_vars = [] + self.indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] = {} + self.memory_usage = {t: [] for t in MemoryUsageType} + self.root_block = LoopBodyBlock(self, fn, args) # traces + del self.indexing_exprs_name # not used after _init_with_tracing + + def _init_with_copy(self, other: LoopBody, args): + """ + _init_with_tracing() is slow, so this is a fast path in the case + where we are just reordering/merging/splitting the args of an + existing LoopBody. + """ + indexing_exprs = other.indexing_from_args(args) + self.indexing_exprs = { + name: V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges) + for name, expr in indexing_exprs.items() + } + self.subblocks = {k: v.clone(self) for k, v in other.subblocks.items()} + self.indirect_vars = other.indirect_vars + self.indirect_var_ranges = other.indirect_var_ranges + self.memory_usage = other.memory_usage + self.root_block = other.root_block.clone(self) + + submodules = {**other.submodules} + submodules.pop("get_index") + self.submodules = { + "get_index": self.get_index, + **{k: v.clone(self) for k, v in submodules.items()}, # type: ignore[attr-defined] + } + + def merge_loops(self) -> LoopBody: + """ + Merge both iteration and reduction loops and return a new LoopBody. 
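+
+        e.g. an iteration space of sizes (s0, s1) addressed by the contiguous
+        expression x0 * s1 + x1 can be collapsed into a single dimension of
+        size s0 * s1; the reduction dimensions are simplified the same way.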
+ """ + old_body = self + old_sizes = self.sizes + old_iter_vars, old_reduce_vars = old_body.vars + old_iter_sizes, old_reduce_sizes = old_sizes + + index_exprs = [*old_body.indexing_exprs.values()] + + iter_sizes, iter_reindex, _ = V.graph.sizevars._simplify_loops( + old_iter_vars, + old_iter_sizes, + index_prevent_reordering(index_exprs, old_iter_vars, old_iter_sizes), + ) + + reduce_sizes, reduce_reindex, _ = V.graph.sizevars._simplify_loops( + old_reduce_vars, + old_reduce_sizes, + index_prevent_reordering(index_exprs, old_reduce_vars, old_reduce_sizes), + ) + + # if iter_sizes == old_iter_sizes: + # # no dimensions get merged. + # return old_sizes, old_body + + # Note: if no dimension get merges, the symbol prefix will + # remain 'y'. But if we merge dimensions, we change prefix to + # 'z'. If this is an issue, we can always retrace the LoopBody + # to change symbol prefix to 'z'. + # + # There is indeed an issue due to symbol name conflicting. + # y0 maybe reused for the y dimension later. + ( + iter_vars, + reduce_vars, + ), var_ranges = dependencies.index_vars_no_squeeze( + iter_sizes, reduce_sizes, prefix="t" + ) + new_body = LoopBody( + old_body, + [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], + var_ranges, + iter_vars, + reduce_vars, + ) + + # use the original symbol prefix + # Can try to optimize if this is a bottleneck for compilation time + (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze( + iter_sizes, reduce_sizes, prefix="z" + ) + new_body2 = LoopBody( + new_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2 + ) + return new_body2 + + def reorder_iter_loops(self, new_order) -> LoopBody: + """ + Reorder iteration loops and return a new LoopBody. + """ + from .ir import same_reorder + + old_body = self + old_sizes = self.sizes + assert len(old_sizes[0]) == len(new_order) + reorder_fn = same_reorder(new_order) + + iter_size, reduce_size = old_sizes + new_iter_size = reorder_fn(iter_size) + + new_sizes = (new_iter_size, reduce_size) + + (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( + *new_sizes, prefix="t" # type: ignore[arg-type] + ) + + inverse_order = {b: a for a, b in enumerate(new_order)} + inverse_order = [inverse_order[i] for i in range(len(new_order))] + + def new_body(*indices: Sequence[sympy.Expr]) -> Any: + index = list(itertools.chain(*indices)) + assert len(index) == len(iter_size) + len(reduce_size) + iter_idx = index[: len(iter_size)] + reduce_idx = index[len(iter_size) :] + iter_idx = [iter_idx[i] for i in inverse_order] + return old_body(iter_idx, reduce_idx) + + loop_body = LoopBody( + new_body, (iter_vars, reduce_vars), var_ranges, iter_vars, reduce_vars + ) + + # use the original symbol prefix so we can do multiple round of reordering + (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze( + *new_sizes, prefix="z" # type: ignore[arg-type] + ) + new_body = LoopBody( + loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2 + ) + return new_body + + @property + def vars(self): + assert self.iter_vars is not None + assert self.reduce_vars is not None + return self.iter_vars, self.reduce_vars + + @cache_on_self + def get_nodes(self): + all_graphs = itertools.chain( + (self.root_block.graph,), + (block.graph for block in self.subblocks.values()), + ) + return [node for graph in all_graphs for node in graph.nodes] + + @cache_on_self + def bounds(self): + # Doing a local import to avoid dumping all the code here + from .bounds import 
BoundVars + + return BoundVars(self) + + def get_read_expr(self, buffer_name): + # reversed to match old behavior + for entry in reversed(self.memory_usage[MemoryUsageType.LOAD]): + if entry.buffer_name == buffer_name: + return self.indexing_exprs[entry.index_name] + raise KeyError(buffer_name) + + def get_write_expr(self, buffer_name): + for entry in itertools.chain( + self.memory_usage[MemoryUsageType.STORE], + self.memory_usage[MemoryUsageType.STORE_REDUCTION], + ): + if entry.buffer_name == buffer_name: + return self.indexing_exprs[entry.index_name] + raise KeyError(buffer_name) + + def get_read_exprs(self): + return [ + self.indexing_exprs[entry.index_name] + for entry in self.memory_usage[MemoryUsageType.LOAD] + ] + + def get_write_exprs(self): + return [ + self.indexing_exprs[entry.index_name] + for entry in itertools.chain( + self.memory_usage[MemoryUsageType.STORE], + self.memory_usage[MemoryUsageType.STORE_REDUCTION], + ) + ] + + def debug_str(self): + lines = [f"var_ranges = {dict(self.var_ranges)}"] + lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()]) + lines.extend( + [ + block.debug_str(name) + for name, block in itertools.chain( + [("body", self.root_block)], self.subblocks.items() + ) + ] + ) + return "\n".join(lines) + + def is_memory_copy(self) -> bool: + """ + True of this contains only a single loads and store. + Note, this could involve a layout change. + """ + return ( + len(self.memory_usage[MemoryUsageType.LOAD]) == 1 + and len(self.memory_usage[MemoryUsageType.STORE]) == 1 + and len(self.submodules) == 1 # get_index + and self.root_block.contains_only_ops(("load", "store")) + ) + + __repr__ = debug_str + + def add_index_expr( + self, + expr: sympy.Expr, + mtype: MemoryUsageType, + buffer_name: Optional[str] = None, + mode: Optional[str] = None, + ): + name = self.indexing_exprs_name.get(expr) + if not name: + name = f"index{len(self.indexing_exprs)}" + self.indexing_exprs_name[expr] = name + self.indexing_exprs[name] = expr + self.memory_usage[mtype].append(MemoryEntry(name, buffer_name, mode)) + return name + + def add_submodule(self, block, prefix): + """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes""" + if prefix[-1].isnumeric() and prefix not in self.submodules: + name = prefix + else: + name = f"{prefix}{len(self.submodules)}" + self.submodules[name] = block + return name + + def add_indirect(self, size): + var = sympy_index_symbol_with_prefix(SymT.INDIRECT, len(self.indirect_vars)) + assert var not in self.indirect_var_ranges + self.indirect_vars.append(var) + self.indirect_var_ranges[var] = size + return var + + def replace_indirect(self, old, new): + """Swap in a variable used in indirect indexing""" + if str(old) == str(new): + return + assert self.indexing is not None + self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()} + + def get_index(self, name): + assert self.indexing is not None + return self.indexing[name] + + def indexing_from_args(self, indices): + index = [*itertools.chain.from_iterable(indices)] + assert len(index) == len(self.var_ranges), (index, self.var_ranges) + assert all( + v not in self.var_ranges for v in index + ), f"{self.var_ranges=}, {indices=}" + replacements = dict(zip(self.var_ranges.keys(), index)) + return { + name: sympy_subs(expr, replacements) + for name, expr in self.indexing_exprs.items() + } + + def __call__(self, *indices): + self.indexing = self.indexing_from_args(indices) + result = self.root_block() + 
self.indexing = None + return result + + def bind_set_indirect_shim(self, var, size, check, wrap_neg): + def set_indirect(new_var): + self.replace_indirect( + var, V.ops.indirect_indexing(new_var, size, check, wrap_neg) + ) + + set_indirect.clone = functools.partial( # type: ignore[attr-defined] + LoopBody.bind_set_indirect_shim, + var=var, + size=size, + check=check, + wrap_neg=wrap_neg, + ) + return set_indirect + + def bind_scan_shim(self, combine_fn): + def shim(dtypes, values): + return V.ops.scan(dtypes, combine_fn, values) + + shim.clone = functools.partial(LoopBody.bind_scan_shim, combine_fn=combine_fn) # type: ignore[attr-defined] + return shim + + def bind_masked_shim(self, name): + def shim(mask, other): + return V.ops.masked(mask, self.subblocks[name], other) + + shim.clone = functools.partial(LoopBody.bind_masked_shim, name=name) # type: ignore[attr-defined] + return shim + + +class LoopBodyBlock: + """ + Captures the body of a Loops subclass into an FX graph. + In normal cases there will be a 1:1 mapping between LoopBody and + LoopBodyBlock, hower in the case of ops.masked() the masked out + operations will manifest as an extra LoopBodyBlock. + """ + + def __init__(self, body: LoopBody, fn: Callable[..., Any], args: List[Any]): + self.body = body + + def add_index(expr: sympy.Expr, mtype: MemoryUsageType, **kwargs): + return tracer.create_proxy( + "call_module", + "get_index", + (body.add_index_expr(expr, mtype, **kwargs),), + {}, + ) + + class CaptureIndexing(V.WrapperHandler): # type: ignore[name-defined] + self.name = "CaptureIndexing" + + def load(self, name: str, index: sympy.Expr): + index = add_index(index, MemoryUsageType.LOAD, buffer_name=name) + return self._inner.load(name, index) + + def load_seed(self, name: str, index: int): + assert isinstance(index, int) + body.add_index_expr( + sympy.Integer(index), MemoryUsageType.LOAD_SEED, buffer_name=name + ) + return self._inner.load_seed(name, index) + + def store(self, name, index, value, mode=None): + index = add_index( + index, MemoryUsageType.STORE, buffer_name=name, mode=mode + ) + return self._inner.store(name, index, value, mode) + + def store_reduction(self, name, index, value): + index = add_index( + index, MemoryUsageType.STORE_REDUCTION, buffer_name=name + ) + return self._inner.store_reduction(name, index, value) + + def reduction(self, dtype, src_dtype, reduction_type, value): + result = self._inner.reduction(dtype, src_dtype, reduction_type, value) + if "welford" in reduction_type: + return tuple(result[i] for i in range(3)) + return result + + def index_expr(self, index, dtype): + if isinstance(index, (int, sympy.Integer)): + return self._inner.constant(int(index), dtype) + index = add_index(index, MemoryUsageType.INDEX_EXPR) + return self._inner.index_expr(index, dtype) + + def check_bounds(self, index, size, lower, upper): + index = add_index(index, MemoryUsageType.CHECK_BOUNDS) + size = add_index(size, MemoryUsageType.CHECK_BOUNDS) + return self._inner.check_bounds(index, size, lower, upper) + + def bucketize( + self, + values, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ): + offsets_size = add_index( + offsets_size, MemoryUsageType.BUCKETIZE, buffer_name=offsets_name + ) + return self._inner.bucketize( + values, offsets_name, offsets_size, indexing_dtype, right + ) + + @staticmethod + def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy): + """ + Recursively capture the masked out body in another LoopBodyBlock + """ + name = 
self.body.add_submodule(None, "masked_subblock") + self.body.submodules[name] = self.body.bind_masked_shim(name) + self.body.subblocks[name] = LoopBodyBlock(self.body, masked_body, []) + return tracer.create_proxy( + "call_module", name, (mask_proxy, other_proxy), {} + ) + + @staticmethod + def scan( + dtype_proxy, + combine_fn: Callable[ + [Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...] + ], + value_proxy, + ): + shim = self.body.bind_scan_shim(combine_fn) + name = self.body.add_submodule(shim, "scan") + result = tracer.create_proxy( + "call_module", + name, + (dtype_proxy, value_proxy), + {}, + ) + # Proxies are iterable, but some methods expect tuples/lists + return tuple(result[i] for i in range(len(value_proxy))) + + def sort(self, dtypes, values, stable, descending): + result = self._inner.sort(dtypes, values, stable, descending) + # Proxies are iterable, but some methods expect tuples/lists + return tuple(result[i] for i in range(len(values))) + + def frexp(self, value_proxy): + result = self._inner.frexp(value_proxy) + # Proxies are iterable, but some methods expect tuples/lists + return (result[0], result[1]) + + @staticmethod + def indirect_indexing(index_proxy, size, check=True, wrap_neg=True): + """ + Flow data from tensors into indexing formulas. + Introduce a call_module to update the indexing. + """ + + var = self.body.add_indirect(size) + set_indirect = self.body.bind_set_indirect_shim( + var, size, check, wrap_neg + ) + tracer.create_proxy( + "call_module", + self.body.add_submodule(set_indirect, f"set_{var}"), + (index_proxy,), + {}, + ) + return var + + @staticmethod + def output(result): + tracer.create_proxy("output", "output", (result,), {}) + + tracer = torch.fx.Tracer() + tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__) + proxy_ops = tracer.create_proxy("placeholder", "ops", (), {}) + + from .index_propagation import IndexPropagation + from .sizevars import SimplifyIndexing + + handler: Any = SimplifyIndexing( + CaptureIndexing(proxy_ops), self.body.var_ranges + ) + if config.constant_and_index_propagation: + handler = IndexPropagation( + handler, self.body.var_ranges, self.body.indirect_var_ranges + ) + + with V.set_ops_handler(handler): + # This indirection is just a cute way to get IndexPropagation to + # unwrap the return value. 
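+            # fn(*args) yields values wrapped by the active ops handler; passing
+            # them through ops.output lets IndexPropagation (when enabled) unwrap
+            # them before the result is recorded as the graph output.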
+ ops.output(fn(*args)) + self.graph = tracer.graph + + def __call__(self): + graph = self.graph + submodules = self.body.submodules + + return InterpreterShim(graph, submodules).run(V.get_ops_handler()) + + def debug_str(self, name="block"): + code = torch.fx.GraphModule(self.body.submodules, self.graph).code + return re.sub( + # strip `; del var0` suffixes to make output prettier + r";[^\n]*", + "", + code.strip().replace("def forward(", f"def {name}("), + ) + + def contains_only_ops(self, allowed_ops) -> bool: + return all( + node.target in allowed_ops + for node in self.graph.find_nodes(op="call_method") + ) + + def clone(self, body: LoopBody): + """Shallow copy with a new parent LoopBody""" + copy = LoopBodyBlock.__new__(LoopBodyBlock) + copy.__dict__.update({**self.__dict__, "body": body}) + return copy diff --git a/lib/python3.10/site-packages/torch/_inductor/lowering.py b/lib/python3.10/site-packages/torch/_inductor/lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..adf921913edfcdfc902ab70307edae5bf6c46528 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/lowering.py @@ -0,0 +1,6449 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import functools +import itertools +import logging +import math +import operator +import os +import warnings +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from unittest.mock import patch + +import sympy + +import torch +import torch.ao.quantization.fx._decomposed +import torch.fx +import torch.utils._pytree as pytree +from torch._higher_order_ops.associative_scan import associative_scan_op +from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation +from torch._prims_common import ( + canonicalize_dim, + canonicalize_dims, + check, + dtype_to_type, + elementwise_dtypes, + ELEMENTWISE_TYPE_PROMOTION_KIND, + get_computation_dtype, + is_boolean_dtype, + is_float_dtype, + is_integer_dtype, + Number, +) +from torch.fx.experimental.sym_node import magic_methods, method_to_operator +from torch.utils._sympy.functions import ( + CeilDiv, + FloorDiv, + Identity, + IntTrueDiv, + ModularIndexing, +) + +from .._dynamo.utils import import_submodule +from . 
import config, inductor_prims, ir, test_operators # NOQA: F401 +from .decomposition import decompositions, get_decompositions +from .ir import ( + DtypeView, + ExpandView, + IndexingConstant, + is_triton, + ops_wrapper, + PermuteView, + Pointwise, + Reduction, + SqueezeView, + TensorBox, + validate_ir, + View, +) +from .utils import ( + ceildiv, + decode_device, + is_dynamic, + is_gpu, + is_pointwise_use, + needs_fallback_due_to_atomic_add_limitations, + pad_listlike, + sympy_product, + use_scatter_fallback, +) +from .virtualized import ops, V + + +log = logging.getLogger(__name__) +lowerings: Dict[torch._ops.OpOverload, Callable[..., Any]] = {} +# Use maybe_layout_constraints to access this dict, we lazily register tag-based layout constraints +_maybe_layout_constraints: Dict[ + torch._ops.OpOverload, Optional[Callable[..., Any]] +] = {} +fallbacks: Set[torch._ops.OpOverload] = set() +aten = torch.ops.aten +tr_c10d = torch.ops.tr_c10d +prims = torch.ops.prims +needs_realized_inputs: Set[torch._ops.OpOverload] = set() +foreach_ops: Set[torch._ops.OpOverload] = set() +inplace_foreach_ops: Set[torch._ops.OpOverload] = set() +inplaceable_foreach_ops: Dict[torch._ops.OpOverload, torch._ops.OpOverload] = {} +quantized_decomposed = torch.ops.quantized_decomposed + + +def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., Any]]: + """Get layout constraints. Returns None if there are no layout constraints.""" + if not isinstance(fn, torch._ops.OpOverload): + # Only OpOverloads have layout constraints. + return None + if fn in _maybe_layout_constraints: + return _maybe_layout_constraints[fn] + # OpOverload with custom lowerings override tag-based layout constraints + if fn in lowerings: + _maybe_layout_constraints[fn] = None + return None + # We lazily register tag-based layout constraints. 
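+    # The tag on the OpOverload is resolved on first lookup and the resulting
+    # constraint (or None for flexible layouts) is cached in
+    # _maybe_layout_constraints, so subsequent queries are plain dict hits.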
+ + def handle_layout_constraint_tag(tag): + if tag is torch._C.Tag.needs_fixed_stride_order: + _maybe_layout_constraints[fn] = constrain_to_fx_strides + return _maybe_layout_constraints[fn] + elif tag is torch._C.Tag.flexible_layout: + _maybe_layout_constraints[fn] = None + return None + else: + raise AssertionError(f"Unknown layout constraint tag: {tag}") + + tag = get_layout_constraint_tag(fn) + return handle_layout_constraint_tag(tag) + + +def get_layout_constraint_tag(fn): + tags_by_priority = [ + torch._C.Tag.needs_fixed_stride_order, + torch._C.Tag.flexible_layout, + ] + for tag in tags_by_priority: + if tag in fn.tags: + return tag + return getattr(torch._C.Tag, config.custom_op_default_layout_constraint) + + +def assert_nyi(cond, msg): + if not cond: + raise NotImplementedError(f"inductor does not support {msg}") + + +def add_needs_realized_inputs(fn): + if isinstance(fn, (list, tuple, set)): + return [add_needs_realized_inputs(x) for x in fn] + needs_realized_inputs.add(fn) + if isinstance(fn, torch._ops.OpOverloadPacket): + needs_realized_inputs.update( + getattr(fn, overload) for overload in fn.overloads() + ) + + +def add_layout_constraint(fn, constraint): + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + _maybe_layout_constraints[getattr(fn, overload)] = constraint + else: + _maybe_layout_constraints[fn] = constraint + + +add_needs_realized_inputs( + [ + aten.as_strided, + aten.as_strided_copy, + aten.avg_pool2d, + aten.avg_pool2d_backward, + aten.bmm, + aten.convolution, + aten.convolution_backward, + aten.max_pool2d_with_indices, + aten.max_pool2d_with_indices_backward, + aten.mm, + aten.upsample_nearest2d, + aten._upsample_nearest_exact2d, + aten._int_mm, + ] +) + +# TODO(jansel): ezyang says we won't need this in the future, try removing it +# based on https://github.com/pytorch/pytorch/blob/9e3eb329df8f701/c10/core/ScalarType.h#L28 +DTYPE_ID_LOOKUP = { + 0: torch.uint8, + 1: torch.int8, + 2: torch.int16, + 3: torch.int32, + 4: torch.int64, + 5: torch.float16, + 6: torch.float32, + 7: torch.float64, + 8: torch.complex32, + 9: torch.complex64, + 10: torch.complex32, + 11: torch.bool, + 15: torch.bfloat16, + # TODO(jansel): add quantized types? 
+ # _(c10::qint8, QInt8) /* 12 */ + # _(c10::quint8, QUInt8) /* 13 */ + # _(c10::qint32, QInt32) /* 14 */ + # _(c10::quint4x2, QUInt4x2) /* 16 */ + # _(c10::quint2x4, QUInt2x4) /* 17 */ +} + + +def decode_dtype(dtype: int): + if not isinstance(dtype, int): + return dtype + assert dtype in DTYPE_ID_LOOKUP, f"id {dtype} missing from DTYPE_ID_LOOKUP" + dtype = DTYPE_ID_LOOKUP[dtype] + return dtype + + +def is_integer_type(x): + if isinstance(x, TensorBox): + return is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + elif isinstance(x, sympy.Expr): + return x.is_integer is True # type: ignore[attr-defined] + else: + return isinstance(x, int) + + +def is_boolean_type(x): + if isinstance(x, TensorBox): + return is_boolean_dtype(x.get_dtype()) + else: + return isinstance(x, bool) + + +def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND): + def construct_input(inp): + if isinstance(inp, (Number, sympy.Basic)): + return inp + else: + assert hasattr(inp, "get_dtype") + dim = len(inp.get_size()) + # construct a tmp tensor to feed into torch.result_type + return torch.zeros([1] * dim, dtype=inp.get_dtype()) + + inps = [construct_input(arg) for arg in args] + _, dtype = elementwise_dtypes(*inps, type_promotion_kind=type_promotion_kind) + return dtype + + +def get_overloads(aten_fn): + if not isinstance(aten_fn, (list, tuple)): + aten_fn = [aten_fn] + else: + aten_fn = list(aten_fn) + + for fn in list(aten_fn): + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + other_fn = getattr(fn, overload) + if other_fn not in lowerings: + aten_fn.append(other_fn) + + return aten_fn + + +def in_namespace(op, namespace): + if isinstance(op, torch._ops.OpOverloadPacket): + return namespace in op._qualified_op_name + elif isinstance(op, torch._ops.OpOverload): + return namespace in op.name() + return False + + +def transform_args(args, broadcast, type_promotion_kind, convert_input_to_bool): + indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] + if (type_promotion_kind or convert_input_to_bool) and indices: + if convert_input_to_bool: + dtype = torch.bool + else: + # FIXME that's a crude approximation for promoting args + promoting_args = [ + a + for a in args + if isinstance(a, (Number, sympy.Basic)) + or getattr(a, "dtype", None) is not None + ] + dtype = get_promoted_dtype( + *promoting_args, type_promotion_kind=type_promotion_kind + ) + + # sometimes args are an immutable list so we can't mutate them + def promote(arg): + if isinstance(arg, TensorBox): + return to_dtype(arg, dtype) + elif isinstance(arg, ir.Constant): + return ir.Constant(arg.value, dtype, args[indices[0]].get_device()) + else: + return arg + + args = [promote(a) for a in args] + if broadcast and indices: + for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])): + args[i] = x + for i in range(len(args)): + if isinstance(args[i], ir.Constant): + args[i] = ExpandView.create(args[i], list(args[indices[0]].get_size())) + + return args + + +def _register_foreach_lowering(aten_fn, decomp_fn): + """ + Add a foreach lowering to lowerings dict. 
+ + Arguments: + aten_fn: torch.ops.aten.* fn we are lowering + decomp_fn: alternate implementation on our IR + broadcast: True to apply broadcasting to tensor inputs + type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion + convert_input_to_bool: some logical ops require inputs are converted to bool + """ + + @functools.wraps(decomp_fn) + def wrapped(*args, **kwargs): + assert len(args) <= 2 + out = decomp_fn(*args, **kwargs) + validate_ir(out) + return out + + aten_fns = get_overloads(aten_fn) + foreach_ops.update(aten_fns) + lowerings.update(dict.fromkeys(aten_fns, wrapped)) + return wrapped + + +def _register_lowering( + aten_fn, decomp_fn, broadcast, type_promotion_kind, convert_input_to_bool +): + """ + Add a lowering to lowerings dict + + Arguments: + aten_fn: torch.ops.aten.* fn we are lowering + decomp_fn: alternate implementation on our IR + broadcast: True to apply broadcasting to tensor inputs + type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion + convert_input_to_bool: some logical ops require inputs are converted to bool + """ + + @functools.wraps(decomp_fn) + def wrapped(*args, **kwargs): + args: Union[List[Any], Tuple[Any, ...], Dict[Any, Any]] = list(args) + unpacked = False + # TODO maybe we need to use pytrees here + if len(args) == 1 and isinstance(args[0], (list, tuple)): + unpacked = True + args = args[0] + + # kwargs tensors not supported yet unless it's a fallback op + if not all( + (fn in fallbacks or in_namespace(fn, "_c10d_functional")) for fn in aten_fn + ): + assert not any(isinstance(x, TensorBox) for x in kwargs.values()) + # explicitly assert for "out=" ops for better error messages + assert not any( + x == "out" for x in kwargs.keys() + ), "out= ops aren't yet supported" + + args = transform_args( + args, broadcast, type_promotion_kind, convert_input_to_bool + ) + + if unpacked: + args = [args] + + out = decomp_fn(*args, **kwargs) + validate_ir(out) + + return out + + aten_fn = get_overloads(aten_fn) + + lowerings.update(dict.fromkeys(aten_fn, wrapped)) + return wrapped + + +def register_lowering( + aten_fn, + broadcast=False, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + convert_input_to_bool=False, +): + """ + Shim to support decorator syntax. + """ + return functools.partial( + _register_lowering, + aten_fn, + broadcast=broadcast, + type_promotion_kind=type_promotion_kind, + convert_input_to_bool=convert_input_to_bool, + ) + + +def broadcast_symbolic_shapes(a, b): + """ + Broadcasting logic based on symbolic shapes. + + We give the shapes 0 and 1 concrete values, while all other shapes + are symbolic sympy formulas. 
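+
+    e.g. broadcasting (s0, 1, 8) with (1, s1, 8) gives (s0, s1, 8); when both
+    sizes are non-unit symbolic expressions we guard that they are equal and
+    keep the shorter of the two formulas.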
+ """ + output = [] + for x, y in itertools.zip_longest( + reversed(a), reversed(b), fillvalue=sympy.Integer(1) + ): + if y == 1: + output.append(x) + elif x == 1: + output.append(y) + else: + V.graph.sizevars.guard_equals(x, y) + if len(sympy.expand(y).free_symbols) < len(sympy.expand(x).free_symbols): + output.append(y) # prefer shorter formula + else: + output.append(x) + return tuple(reversed(output)) + + +def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=None): + assert ( + override_return_dtype is None or type_promotion_kind is None + ), "only one of override_return_dtype or type_promotion_kind may be given" + + if override_return_dtype is None and type_promotion_kind is None: + type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + + if not any(isinstance(x, (sympy.Basic, int, float)) for x in inputs): + return inputs + if all(isinstance(x, (int, float, sympy.Basic)) for x in inputs): + dtype = override_return_dtype or get_promoted_dtype( + *inputs, type_promotion_kind=type_promotion_kind + ) + + def const_func(x): + if isinstance(x, sympy.Basic): + return ir.IndexingConstant(x, dtype, decode_device(None)) + else: + return ir.Constant(x, dtype, decode_device(None)) + + return [const_func(x) for x in inputs] + ex = next(x for x in inputs if isinstance(x, (TensorBox, ExpandView, ir.Constant))) + out = [] + for x in inputs: + if isinstance(x, (int, float)): + out.append( + ExpandView.create( + ir.Constant(x, ex.get_dtype(), ex.get_device()), list(ex.get_size()) + ) + ) + elif isinstance(x, sympy.Basic): + out.append( + ExpandView.create( + IndexingConstant(x, ex.get_dtype(), ex.get_device()), + list(ex.get_size()), + ) + ) + else: + out.append(x) + + return out + + +def make_pointwise( + fn, + override_return_dtype=None, + override_device=None, + override_fn_when_input_bool=None, + override_fn_when_gpu_float64=None, + allow_alpha=False, + triton_fallback=None, +): + def inner(*inputs: List[TensorBox], alpha=None): + if triton_fallback is not None and any(map(is_triton, inputs)): + assert not allow_alpha # not implemented + return triton_fallback(*inputs) + + inputs = promote_constants(inputs, override_return_dtype) + if allow_alpha: + if alpha is not None and alpha != 1: + inputs = list(inputs) + inputs[-1] = mul(inputs[-1], alpha) + else: + assert alpha is None + loaders = [x.make_loader() for x in inputs] + ranges = inputs[0].get_size() + dtype = override_return_dtype or inputs[0].get_dtype() + is_gpu_device = is_gpu(decode_device(inputs[0].get_device()).type) + + for other in inputs[1:]: + assert isinstance(other, ir.BaseConstant) or len(ranges) == len( + other.get_size() + ), f"ndim mismatch {fn} {ranges} {other.get_size()}" + + # in tracing, we will annotate pointwise nodes that correspond to the output of + # a pointwise node that would have been run in eager. intermediary pointwise nodes + # during decompositions are not annotated. 
+ emulate_precision_casts = ( + V.graph is not None + and getattr(V.graph, "current_node", None) is not None + and V.graph.current_node.meta is not None + and V.graph.current_node.meta.get("low_precision_pointwise_barrier", False) + and dtype in (torch.bfloat16, torch.float16) + ) + + def inner_fn(index): + assert len(index) == len(ranges), f"wrong ndim {index} {ranges}" + if dtype == torch.bool and override_fn_when_input_bool is not None: + return override_fn_when_input_bool(*[load(index) for load in loaders]) + elif ( + override_fn_when_gpu_float64 + and is_gpu_device + and dtype == torch.float64 + ): + return override_fn_when_gpu_float64(*[load(index) for load in loaders]) + else: + inputs_loaded = [] + for load in loaders: + out = load(index) + if emulate_precision_casts: + downcast = ops.to_dtype(out, dtype, use_compute_types=False) + out = ops.to_dtype(downcast, dtype) + inputs_loaded.append(out) + + out = fn(*inputs_loaded) + if emulate_precision_casts: + # fp16/bf16 kernels are computed in fp32. Casting down to fp16/bf16 here, + # then upcasting again, to emulate casts that eager would do. + downcast = ops.to_dtype(out, dtype, use_compute_types=False) + return ops.to_dtype(downcast, dtype) + return out + + if not override_device: + device = None + for i in inputs: + if is_gpu(i.get_device().type): + device = i.get_device() + break + if not device: + device = inputs[0].get_device() + + device = override_device or device + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=ranges, + ) + + return inner + + +def make_foreach_pointwise(pw_fn, allow_alpha=False): + def inner(*inputs: List[List[TensorBox]], alpha=1): + # group by device, whether any of the inputs are dynamic, and whether their types match + # (proxy for type promotion) + def group_args(arg_pairs): + out = defaultdict(list) + for i, args in enumerate(arg_pairs): + use_foreach = ( + not is_dynamic(*args) or config.combo_kernel_foreach_dynamic_shapes + ) + device = None + for t in args: + if isinstance(t, TensorBox): + device = t.data.get_device() + break + assert ( + device is not None + ), "foreach op should have at least one tensor arg" + out[(device, use_foreach)].append((i, args)) + return out + + realize_outputs = ( + len(V.graph.current_node.users) == 0 + or V.graph.current_node.target in inplace_foreach_ops + ) + for node in V.graph.current_node.users: + for user in node.users: + if not (user.op == "call_function" and (user.target in foreach_ops)): + realize_outputs = True + + a_list_input = None + for input in inputs: + if isinstance(input, (list, tuple)): + a_list_input = input + break + assert ( + a_list_input is not None + ), "at least one input must be a list to a foreach op" + + # broadcast scalar inputs to match length of list inputs + broadcast_inputs = [] + for input in inputs: + if not isinstance(input, (list, tuple)): + broadcast_inputs.append([input] * len(a_list_input)) + else: + broadcast_inputs.append(input) + + groups = group_args(zip(*broadcast_inputs)) + + outputs = [None] * len(a_list_input) + for (device, use_foreach), group in groups.items(): + operation_list: List[str] = [] + for ( + output_ind, + args, + ) in group: + if allow_alpha: + output = pw_fn(*args, alpha=alpha) + else: + output = pw_fn(*args) + + outputs[output_ind] = output + + if ( + V.graph.has_feature(device, BackendFeature.FOREACH) + and use_foreach + and realize_outputs + ): + output.realize() + operation_list.append(output.get_operation_name()) + + if operation_list: + 
V.graph.register_operation_list(operation_list) + + assert all(x is not None for x in outputs) + return outputs + + return inner + + +def to_dtype(x: TensorBox, dtype: torch.dtype, copy=False): + src_dtype = x.get_dtype() + if src_dtype == dtype: + return clone(x) if copy else x + + def _to_dtype(x): + return ops.to_dtype(x, dtype, src_dtype=src_dtype) + + return make_pointwise(_to_dtype, override_return_dtype=dtype)(x) + + +@register_lowering(prims.convert_element_type, type_promotion_kind=None) +def _convert_element_type(x: TensorBox, dtype: torch.dtype): + if dtype.is_complex or x.get_dtype().is_complex: + if x.get_size(): + # Decompose since aa aten fallback is more friendly for c++ codegen. + # This decomposition doesn't work for empty tensor, which needs more investigation. + dst = empty_like(x, dtype=dtype) + ir.InplaceCopyFallback.create(dst, x) + return dst + else: + return fallback_handler( + prims.convert_element_type.default, add_to_fallback_set=False + )(x, dtype) + return to_dtype(x, dtype, copy=True) + + +def to_dtype_bitcast(x: TensorBox, dtype: torch.dtype, *, copy=False): + x_dtype = x.get_dtype() + if x_dtype == dtype: + return clone(x) if copy else x + + def _get_primitive_bitwidth(dtype): + if dtype.is_floating_point: + return torch.finfo(dtype).bits + else: + return torch.iinfo(dtype).bits + + src_bits = _get_primitive_bitwidth(x_dtype) + dst_bits = _get_primitive_bitwidth(dtype) + if src_bits != dst_bits: + # fallback to aten eager implementation for differing bitwidths + return fallback_handler(aten.view.dtype)(x, dtype) + else: + return TensorBox(DtypeView.create(x, dtype)) + + +@register_lowering(aten.view.dtype, type_promotion_kind=None) +def _view_dtype(x: TensorBox, dtype: torch.dtype): + if dtype.is_complex or x.get_dtype().is_complex: + return TensorBox.create( + ir.ComplexView.create(torch.ops.aten.view.dtype, x, dtype) + ) + return to_dtype_bitcast(x, dtype) + + +def to_device(x: TensorBox, device: torch.device, *, copy=False): + device = decode_device(device) + if x.get_device() == device: + return clone(x) if copy else x + return TensorBox.create(ir.DeviceCopy.create(x, device)) + + +@register_lowering(prims.device_put, type_promotion_kind=None) +def _device_put(x: TensorBox, device: torch.device): + return to_device(x, device, copy=True) + + +def register_pointwise( + aten_fn, + name=None, + broadcast=True, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + convert_input_to_bool=False, + override_return_dtype=None, + override_fn_when_input_bool=None, + allow_alpha=False, + use_libdevice_for_f64=False, + triton_fallback=None, +): + """A pointwise function that maps ops.{name} to inputs""" + name = name or aten_fn.__name__ + fn = ops_wrapper(name) + if use_libdevice_for_f64: + fn_libdevice = ops_wrapper("libdevice_" + name) + if override_fn_when_input_bool is not None: + override_fn_when_input_bool = ops_wrapper(override_fn_when_input_bool) + + fn = make_pointwise( + fn, + override_return_dtype=override_return_dtype, + override_fn_when_input_bool=override_fn_when_input_bool, + override_fn_when_gpu_float64=fn_libdevice if use_libdevice_for_f64 else None, # type: ignore[possibly-undefined] + allow_alpha=allow_alpha, + triton_fallback=triton_fallback, + ) + fn = register_lowering( + aten_fn, + broadcast=broadcast, + type_promotion_kind=type_promotion_kind, + convert_input_to_bool=convert_input_to_bool, + )(fn) + + if hasattr(prims, name): + register_lowering( + getattr(prims, name), + type_promotion_kind=None, + 
convert_input_to_bool=convert_input_to_bool, + )(fn) + return fn + + +def register_frexp(): + """A pointwise function that maps ops.frexp to inputs""" + name = "frexp" + frexp = ops_wrapper("frexp") + + def frexp0(*args, **kwargs): + return frexp(*args, **kwargs)[0] # type: ignore[index] # next PR + + def frexp1(*args, **kwargs): + return frexp(*args, **kwargs)[1] # type: ignore[index] # next PR + + pw_fns = [ + make_pointwise(frexp0), + make_pointwise(frexp1, override_return_dtype=torch.int32), + ] + + def fn(*args, **kwargs): + return pw_fns[0](*args, **kwargs), pw_fns[1](*args, **kwargs) + + fn = register_lowering( + aten.frexp, + )(fn) + + if hasattr(prims, name): + register_lowering( + getattr(prims, name), + type_promotion_kind=None, + )(fn) + return fn + + +register_frexp() + + +def register_foreach_pointwise( + aten_fn, + pointwise_lowering_fn, + allow_alpha=False, +): + fn = make_foreach_pointwise(pointwise_lowering_fn, allow_alpha=allow_alpha) + fn = _register_foreach_lowering(aten_fn, fn) + return fn + + +@register_lowering(aten.where, broadcast=False, type_promotion_kind=None) +def where(cond, a, b): + def fn(*args): + return ops.where(*args) + + if isinstance(a, (float, int)): + a = constant_like(a)(b) + if isinstance(b, (float, int)): + b = constant_like(b)(a) + + args = [cond, a, b] + dtype = get_promoted_dtype( + args[1], args[2], type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] + for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])): + args[i] = x + for i in range(len(args)): + if isinstance(args[i], ir.Constant): + args[i] = ExpandView.create(args[i], list(args[indices[0]].get_size())) + return make_pointwise(fn, override_return_dtype=dtype)( + args[0], to_dtype(args[1], dtype), to_dtype(args[2], dtype) + ) + + +@register_lowering(aten.broadcast_tensors, broadcast=False, type_promotion_kind=None) +def broadcast_tensors(*inputs): + if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)): + return broadcast_tensors(*inputs[0]) + target: List[sympy.Expr] = functools.reduce( + broadcast_symbolic_shapes, [x.get_size() for x in inputs], [] + ) + outputs = [] + for x in inputs: + sizes = x.get_size() + if len(sizes) != len(target) or any( + ((a == 1 and b != 1) or (a != 1 and b == 1)) for a, b in zip(sizes, target) + ): + x = expand(x, target) + outputs.append(x) + return outputs + + +@register_lowering([aten.alias, aten.detach, aten.detach_, aten.lift, prims.view_of]) +def nop(x): + return x # AOT autograd handles this for us + + +if hasattr(aten, "lift_fresh"): + register_lowering(aten.lift_fresh)(nop) + + +@register_lowering(aten.squeeze, type_promotion_kind=None) +def squeeze(x, dim=None): + assert isinstance(x, TensorBox) + if dim is None: + return TensorBox(SqueezeView.create(x.data)) + + dim = ( + V.graph.sizevars.evaluate_static_shape(dim) + if isinstance(dim, (int, sympy.Expr)) + else tuple(V.graph.sizevars.evaluate_static_shape(d) for d in dim) + ) + dim = canonicalize_dims(len(x.get_size()), dim) # type: ignore[call-overload] + dims = set((dim,) if not isinstance(dim, tuple) else dim) + + new_shape = [] + for d, s in enumerate(x.get_size()): + if not (d in dims and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1))): + new_shape.append(s) + + # squeeze does nothing if the size isn't 1 + return view(x, new_shape) if new_shape != x.get_size() else x + + +@register_lowering(aten.squeeze_copy, type_promotion_kind=None) +def squeeze_copy(x, dim=None): + return 
clone(squeeze(x, dim)) + + +@register_lowering([aten.squeeze_]) +def squeeze_(x, dim=None): + val = squeeze(x, dim) + assert isinstance(x, TensorBox) + assert isinstance(val, TensorBox) + x.data = val.data + return x + + +@register_lowering(aten.isinf) +def isinf(x): + if is_integer_type(x): + return full_like(x, False, dtype=torch.bool) + fn = ops_wrapper("isinf") + return make_pointwise(fn, override_return_dtype=torch.bool)(x) + + +@register_lowering(aten.isnan) +def isnan(x): + if is_integer_type(x): + return full_like(x, False, dtype=torch.bool) + fn = ops_wrapper("isnan") + return make_pointwise(fn, override_return_dtype=torch.bool)(x) + + +@register_lowering(aten.ceil) +def ceil(x): + if is_integer_type(x): + return clone(x) + fn = ops_wrapper("ceil") + return make_pointwise(fn)(x) + + +@register_lowering(aten.floor) +def floor(x): + if is_integer_type(x): + return clone(x) + fn = ops_wrapper("floor") + return make_pointwise(fn)(x) + + +@register_lowering(aten.round.default) +def round(x): + if is_integer_type(x): + return clone(x) + else: + fn = ops_wrapper("round") + return make_pointwise(fn)(x) + + +@register_lowering(aten.trunc) +def trunc(x): + if is_integer_type(x): + return clone(x) + fn = ops_wrapper("trunc") + return make_pointwise(fn)(x) + + +@register_lowering(aten.expand, type_promotion_kind=None) +def expand(x, sizes): + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + (x,) = promote_constants([x]) + if isinstance(x, ir.BaseConstant): + return ExpandView.create(x, tuple(sizes)) + assert isinstance(x, TensorBox) + assert isinstance(sizes, (list, tuple)) + if tuple(x.get_size()) == tuple(sizes): + return x + + if not free_unbacked_symbols(x.get_size()): + x_size_product = V.graph.sizevars.size_hint(sympy_product(x.get_size())) + # TODO: It would be better to realize the input if any of its sizes + # are unbacked, because typically the size will be non-zero. 
However, + # this cannot be done directly as below as we'll choke on the size_hint + # here + if x_size_product > 0 and not free_unbacked_symbols(sizes): + # maybe realize input before broadcasting it + x.mark_reuse( + V.graph.sizevars.size_hint(sympy_product(sizes)) // x_size_product + ) + return TensorBox(ExpandView.create(x.data, tuple(sizes))) + + +@register_lowering(prims.broadcast_in_dim, type_promotion_kind=None) +def broadcast_in_dim(a, shape, broadcast_dimensions): + s = list(shape) + for broadcast_dimension in broadcast_dimensions: + s[broadcast_dimension] = -1 + + v = a + for idx, x in enumerate(s): + if x != -1: + v = unsqueeze(v, idx) + + return expand(v, shape) + + +@register_lowering(aten.expand_as, type_promotion_kind=None) +def expand_as(x, y): + return expand(x, y.get_size()) + + +@register_lowering(aten.repeat) +def repeat(x, repeats): + old_size = list(x.get_size()) + if len(repeats) > len(old_size): + old_size = [sympy.Integer(1)] * (len(repeats) - len(old_size)) + old_size + x = view(x, list(old_size)) + assert len(repeats) == len(x.get_size()) + + new_size = list(x.get_size()) + + zero_tensor = False + for i in range(len(repeats)): + if repeats[i] == 0: + zero_tensor = True + new_size[i] = new_size[i] * repeats[i] + + if zero_tensor: + return empty(new_size, dtype=x.get_dtype(), device=x.get_device()) + if all((a == 1 or b == 1) for a, b in zip(repeats, old_size)): + return clone(expand(x, new_size)) + + x_loader: Callable[[Any], Any] + + def inner_fn(index): + assert len(index) == len(repeats) + index = list(index) + for i in range(len(repeats)): + if repeats[i] != 1: + if old_size[i] == 1: + index[i] = sympy.Integer(0) + else: + index[i] = ModularIndexing(index[i], 1, old_size[i]) + return x_loader(index) + + old_size_product = V.graph.sizevars.size_hint(sympy_product(old_size)) + if old_size_product > 0: + # maybe realize the input + x.mark_reuse( + V.graph.sizevars.size_hint(sympy_product(new_size)) // old_size_product + ) + + x_loader = x.make_loader() + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(new_size), + ) + + +@register_lowering(aten._unsafe_view, type_promotion_kind=None) +@register_lowering(aten.view, type_promotion_kind=None) +@register_lowering(aten.reshape, type_promotion_kind=None) +def view(x, sizes): + assert isinstance(x, TensorBox) + assert isinstance(sizes, (list, tuple)) + return TensorBox(View.create(x.data, sizes)) + + +@register_lowering(aten.permute, type_promotion_kind=None) +def permute(x, dims): + assert isinstance(x, TensorBox) + assert isinstance(dims, (list, tuple)) + return TensorBox(PermuteView.create(x.data, tuple(dims))) + + +@register_lowering(aten.slice, type_promotion_kind=None) +def slice_(x, dim=0, start=0, end=2**63, step=1, clamp=True): + assert isinstance(x, TensorBox) + dim = _validate_dim(x, dim, 0) + return TensorBox(ir.SliceView.create(x.data, dim, start, end, step, clamp=clamp)) + + +@register_lowering(aten.as_strided, type_promotion_kind=None) +def as_strided(x, size, stride, storage_offset=None): + if isinstance(x, TensorBox) and isinstance(x.data, ir.BaseView): + # as_strided ignores views + x = x.data.unwrap_view() + x.realize() + if not ir.is_storage_and_layout(x): + raise NotImplementedError(f"unrealized as_strided({x}, ...)") + storage, old_layout = ir.as_storage_and_layout(x) + new_layout = ir.FixedLayout( + old_layout.device, + old_layout.dtype, + [sympy.expand(s) for s in size], + [sympy.expand(s) for s in stride], + sympy.expand(storage_offset 
or 0), + ) + return TensorBox(ir.ReinterpretView(storage, new_layout)) + + +@register_lowering(aten.as_strided_, type_promotion_kind=None) +def as_strided_(x, size, stride, storage_offset=None): + assert isinstance(x, TensorBox) + x.data = as_strided(x, size, stride, storage_offset).data + return x + + +@register_lowering(aten.as_strided_copy, type_promotion_kind=None) +def as_strided_copy(x, size, stride, storage_offset=None): + result = as_strided(x, size, stride, storage_offset) + return clone(result) + + +def pointwise_cat(inputs, dim=0): + # (inclusive, exclusive) + inputs_ranges: List[Tuple[sympy.Expr, sympy.Expr]] = [] + prev_end = 0 + for inp in inputs: + inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim])) # type: ignore[arg-type] + prev_end = inputs_ranges[-1][-1] # type: ignore[assignment] + + inputs_loaders = [inp.make_loader() for inp in inputs] + + def inner_fn(idx): + idx_dim = ops.index_expr(idx[dim], torch.int64) + + masks = [] + masked_loads = [] + for i in range(len(inputs)): + start = ( + ops.constant(0, torch.int64) + if i == 0 + else ops.index_expr(inputs_ranges[i][0], torch.int64) + ) + end = ops.index_expr(inputs_ranges[i][1], torch.int64) + + start_cond = ops.ge(idx_dim, start) + end_cond = ops.lt(idx_dim, end) + if i == 0: + mask = end_cond + elif i == len(inputs) - 1: + mask = start_cond + else: + mask = ops.and_(start_cond, end_cond) + + masks.append(mask) + idx_load = list(idx) + + # if we're concatting [4], [2] + # when we index the second tensor for 5 we want to index 5 - 4 + # Use Identity to prevent expansion of index * stride to keep expression + # in same int bitwidth as shape + idx_load[dim] = Identity(idx_load[dim] - inputs_ranges[i][0]) + + masked_loads.append( + ops.masked( + mask, + lambda: inputs_loaders[i](idx_load), + 0.0, # this value should be unused + ), + ) + + next_val = masked_loads[-1] + for i in range((len(inputs)) - 2, -1, -1): + next_val = ops.where( + masks[i], + masked_loads[i], + next_val, + ) + return next_val + + new_size = list(inputs[0].get_size()) + new_size[dim] = inputs_ranges[-1][-1] + + return Pointwise.create( + device=inputs[0].get_device(), + dtype=inputs[0].get_dtype(), + inner_fn=inner_fn, + ranges=new_size, + ) + + +@register_lowering(quantized_decomposed.quantize_per_channel, type_promotion_kind=None) +def quantized_decomposed_quantize_per_channel( + input: TensorBox, + scales: TensorBox, + zero_points: TensorBox, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + assert len(scales.get_size()) == 1, "expect scales 1 dim" + assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim" + + if input.get_dtype() == torch.bfloat16: + input = to_dtype(input, torch.float32) + assert ( + input.get_dtype() == torch.float32 + ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" + assert axis < len( + input.get_size() + ), f"Expecting axis to be < {len(input.get_size())}" + + input_loader = input.make_loader() + scales_loader = scales.make_loader() + zero_points_loader = zero_points.make_loader() + + def inner_fn(idx): + channel_idx = (idx[axis],) + + input = input_loader(idx) + scale = scales_loader(channel_idx) + zero_point = zero_points_loader(channel_idx) + qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32) + + if scales.dtype != torch.float32: + scale = ops.to_dtype(scale, torch.float32) + if zero_points.dtype != torch.int32: + zero_point = ops.to_dtype(zero_point, torch.int32) + inv_scale = ops.reciprocal(scale) 
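+        # quantize in fp32: q = clamp(round(x / scale) + zero_point, qmin, qmax),
+        # then cast the clamped value to the target integer dtype.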
+ val = ops.round(input * inv_scale) + zero_point + clamped = ops.maximum(qmin, ops.minimum(qmax, val)) + return ops.to_dtype(clamped, dtype) + + return Pointwise.create( + device=input.get_device(), + dtype=dtype, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.dequantize_per_channel, type_promotion_kind=None +) +def quantized_decomposed_dequantize_per_channel( + input: TensorBox, + scales: TensorBox, + zero_points: TensorBox, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + assert len(scales.get_size()) == 1, "expect scales 1 dim" + assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim" + assert ( + input.get_dtype() == dtype + ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" + assert axis < len( + input.get_size() + ), f"Expecting axis to be < {len(input.get_size())}" + + input_loader = input.make_loader() + scales_loader = scales.make_loader() + zero_points_loader = zero_points.make_loader() + + def inner_fn(idx): + channel_idx = (idx[axis],) + + input = input_loader(idx) + scale = scales_loader(channel_idx) + zero_point = zero_points_loader(channel_idx) + + if scales.dtype != torch.float32: + scale = ops.to_dtype(scale, torch.float32) + if zero_points.dtype != torch.float32: + zero_point = ops.to_dtype(zero_point, torch.float32) + val = ops.sub(ops.to_dtype(input, torch.float32), zero_point) * scale + return val + + return Pointwise.create( + device=input.get_device(), + dtype=torch.float32, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.quantize_per_tensor.default, type_promotion_kind=None +) +def quantized_decomposed_quantize_per_tensor_default( + input: TensorBox, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + if input.get_dtype() == torch.bfloat16: + input = to_dtype(input, torch.float32) + assert ( + input.get_dtype() == torch.float32 + ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" + + input_loader = input.make_loader() + + def inner_fn(idx, scale, zero_point): + input = input_loader(idx) + inv_scale, zero_point = _create_constants( + 1.0 / scale, zero_point, dtype=torch.float32 + ) + val = ops.round(input * inv_scale) + zero_point + qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32) + clamped = ops.minimum(ops.maximum(val, qmin), qmax) + return ops.to_dtype(clamped, dtype) + + return Pointwise.create( + device=input.get_device(), + dtype=dtype, + inner_fn=functools.partial( + inner_fn, scale=float(scale), zero_point=int(zero_point) + ), + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.dequantize_per_tensor.default, type_promotion_kind=None +) +def quantized_decomposed_dequantize_per_tensor_default( + input: TensorBox, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + assert ( + input.get_dtype() == dtype + ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" + + input_loader = input.make_loader() + + def inner_fn(idx, scale, zero_point): + input = input_loader(idx) + scale, zero_point = _create_constants(scale, zero_point, dtype=torch.float32) + val = ops.sub(ops.to_dtype(input, torch.float32), zero_point) * scale + return val + + return Pointwise.create( + device=input.get_device(), + dtype=torch.float32, + inner_fn=functools.partial( + inner_fn, 
scale=float(scale), zero_point=int(zero_point) + ), + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.quantize_per_tensor.tensor, type_promotion_kind=None +) +def quantized_decomposed_quantize_per_tensor_tensor( + input: TensorBox, + scale: TensorBox, + zero_point: TensorBox, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + if input.get_dtype() == torch.bfloat16: + input = to_dtype(input, torch.float32) + assert ( + input.get_dtype() == torch.float32 + ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" + assert len(scale.get_size()) == 0 or ( + len(scale.get_size()) == 1 and scale.get_size()[0] == 1 + ), "expect scale as scalar tensor" + assert len(zero_point.get_size()) == 0 or ( + len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1 + ), "expect zero_point as scalar tensor" + + input_loader = input.make_loader() + scale_loader = scale.make_loader() + zero_point_loader = zero_point.make_loader() + + def inner_fn(idx): + input = input_loader(idx) + _scale = scale_loader((0,) if len(scale.get_size()) == 1 else ()) + _zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ()) + if scale.dtype != torch.float32: + _scale = ops.to_dtype(_scale, torch.float32) + if zero_point.dtype != torch.float32: + _zero_point = ops.to_dtype(_zero_point, torch.float32) + val = ops.round(input * ops.reciprocal(_scale)) + _zero_point + qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32) + clamped = ops.minimum(ops.maximum(val, qmin), qmax) + return ops.to_dtype(clamped, dtype) + + return Pointwise.create( + device=input.get_device(), + dtype=dtype, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.dequantize_per_tensor.tensor, type_promotion_kind=None +) +def quantized_decomposed_dequantize_per_tensor_tensor( + input: TensorBox, + scale: TensorBox, + zero_point: TensorBox, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + assert len(scale.get_size()) == 0 or ( + len(scale.get_size()) == 1 and scale.get_size()[0] == 1 + ), "expect scale as scalar tensor" + assert len(zero_point.get_size()) == 0 or ( + len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1 + ), "expect zero_point as scalar tensor" + assert ( + input.get_dtype() == dtype + ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" + + input_loader = input.make_loader() + scale_loader = scale.make_loader() + zero_point_loader = zero_point.make_loader() + + def inner_fn(idx): + input = input_loader(idx) + _scale = scale_loader((0,) if len(scale.get_size()) == 1 else ()) + _zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ()) + if scale.dtype != torch.float32: + _scale = ops.to_dtype(_scale, torch.float32) + if zero_point.dtype != torch.float32: + _zero_point = ops.to_dtype(_zero_point, torch.float32) + val = ops.sub(ops.to_dtype(input, torch.float32), _zero_point) * _scale + return val + + return Pointwise.create( + device=input.get_device(), + dtype=torch.float32, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +@register_lowering(aten.cat) +def cat(inputs, dim=0): + cpu_device = inputs[0].get_device().type == "cpu" + if cpu_device and all( + input.get_dtype() in [torch.int8, torch.uint8] for input in inputs + ): + # TODO Remove this fallback when we support vectorization + # code gen with uint8 data type directly. 
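+ # For illustration: this branch covers e.g. torch.cat([a, b], dim) with uint8 inputs on CPU.
+ # The inputs are realized (and, for 4-D inputs, moved to channels-last) so the concat is
+ # served by the eager aten.cat kernel rather than by generated code.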
+ for input in inputs: + input.realize() + if all(len(input.get_size()) == 4 for input in inputs): + inputs, _ = require_channels_last(aten.cat, *inputs) + return fallback_handler(aten.cat.default)(inputs, dim) + + if len(inputs) == 1: + return clone(inputs[0]) + + dim = _validate_dim(inputs[0], dim, 0) + dtype = get_promoted_dtype( + *inputs, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + inputs = [to_dtype(inp, dtype) for inp in inputs] + + def unwrap_tensor(x: Union[TensorBox, ir.StorageBox]) -> ir.IRNode: + if isinstance(x, TensorBox): + if isinstance(x.data, ir.BaseView): + return x.data.unwrap_view() + else: + return x.data + + if isinstance(x, ir.StorageBox): + return x.data + + return x + + def is_reduction(t): + return isinstance(t, ir.ComputedBuffer) and isinstance(t.data, ir.Reduction) + + def can_fuse_reduction(t): + if isinstance(t, (TensorBox, ir.StorageBox)): + return can_fuse_reduction(unwrap_tensor(t)) + return ( + is_reduction(t) + or isinstance(t, ir.Pointwise) + and any( + can_fuse_reduction(V.graph.get_buffer(read)) + for read in t.get_read_names() + ) + ) + + # fusing reducutions into computed concat buffer can cause regressions. + fusable_reduction = any(can_fuse_reduction(t) for t in inputs) + + def should_lower_cat_input(x) -> bool: + # Unrealized inputs will not be storage and layouts, and we dont want to realize + # them in case we want to fuse + if ir.is_storage_and_layout(x): + storage, _ = ir.as_storage_and_layout(x, freeze=False) + return not ir.ConcatKernel.can_realize_into_without_copy(storage) + + if isinstance(x, (TensorBox, ir.StorageBox)): + return should_lower_cat_input(unwrap_tensor(x)) + + if isinstance(x, ir.Pointwise): + return True + + return False + + # TODO: We observed negative performance impact of pointwise_cat optimization on CPU so disabled it. + # We will revisit this later after enabling vectorization on index_expr. + if cpu_device: + return TensorBox(ir.ConcatKernel.create(inputs, dim)) + + def op_count(x): + if isinstance(x, (TensorBox, ir.StorageBox)): + return op_count(unwrap_tensor(x)) + + # this will correspond to a direct memory read + if not isinstance(x, ir.Pointwise): + return 0 + + count = x.inner_fn_opcount().num_ops + for read in x.get_read_names(): + count += op_count(V.graph.get_buffer(read)) + + return count + + # as of inputs increase, possibility for register spilling also increases + # past a certain threshold of inputs we only fuse if the if the input kernels + # are simple + # not sure if we want to expose to users via config since logic may change in future + MAX_COMPLEX_POINTWISE_CAT = 8 + MAX_SIMPLE_OP_COUNT = 2 + + def additional_pointwise_ops(op: torch._ops.OpOverload): + return op in (aten.cat.default, aten.constant_pad_nd.default) + + if len(inputs) <= MAX_COMPLEX_POINTWISE_CAT or ( + (len(inputs) <= config.max_pointwise_cat_inputs) + and all(op_count(t) <= MAX_SIMPLE_OP_COUNT for t in inputs) + ): + pointwise_uses = all( + is_pointwise_use(use, additional_pointwise_ops) + for use in V.current_node.users + ) + # fuse in case we will be used in a pointwise node, and there are any inputs we + # we can prevent materialization of. + fuse_pointwise_use = ( + any(should_lower_cat_input(inp) for inp in inputs) and pointwise_uses + ) + + # horizontal fuse in case all inputs will require a copy kernel anyway. 
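+ # (For example, when every input is an unrealized pointwise op, the ConcatKernel path would
+ # materialize each input into its slice of the output with a separate kernel, so fusing them
+ # all into a single pointwise_cat kernel is preferred.)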
+ # only horizontally fuse pointwise kernels + horizontal_fuse_cat = all( + should_lower_cat_input(inp) for inp in inputs + ) and not any(can_fuse_reduction(t) for t in inputs) + if fuse_pointwise_use or (horizontal_fuse_cat and not fusable_reduction): + return pointwise_cat(inputs, dim) + + return TensorBox(ir.ConcatKernel.create(inputs, dim)) + + +@register_lowering(aten.diagonal, type_promotion_kind=None) +def diagonal(input, offset: int = 0, dim1: int = 0, dim2: int = 1): + original_shape = input.get_size() + num_dims = len(original_shape) + dim1 = canonicalize_dim(idx=dim1, rank=num_dims) + dim2 = canonicalize_dim(idx=dim2, rank=num_dims) + + check( + dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}" + ) + + offset_negative = V.graph.sizevars.evaluate_expr(sympy.Lt(offset, 0)) + if offset_negative: + diag_size = V.graph.sizevars.evaluate_max( + V.graph.sizevars.evaluate_min( + original_shape[dim1] + offset, original_shape[dim2] + ), + 0, # type: ignore[arg-type] + ) + else: + diag_size = V.graph.sizevars.evaluate_max( + V.graph.sizevars.evaluate_min( + original_shape[dim1], original_shape[dim2] - offset + ), + 0, # type: ignore[arg-type] + ) + + base_idx = (0, 0) + if offset_negative: + base_idx = (-offset, 0) + else: + base_idx = (0, offset) + + sizes = [s for i, s in enumerate(original_shape) if i not in (dim1, dim2)] + sizes.append(diag_size) + + def reindexer(idx): + diag_idx = idx[-1] + original_idx = [0] * len(original_shape) + cur_dim = 0 + for d in range(num_dims): + if d == dim1: + original_idx[d] = diag_idx + base_idx[0] + elif d == dim2: + original_idx[d] = diag_idx + base_idx[1] + else: + original_idx[d] = idx[cur_dim] + cur_dim += 1 + + assert cur_dim == len(original_shape) - 2 + return original_idx + + return TensorBox(ir.GenericView.create(input, sizes, reindexer)) + + +@register_lowering(aten.diagonal_copy, type_promotion_kind=None) +def diagonal_copy(input, offset: int = 0, dim1: int = 0, dim2: int = 1): + return clone(diagonal(input, offset, dim1, dim2)) + + +@register_lowering(aten.diagonal_scatter, type_promotion_kind=None) +def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1): + output = clone(input) + target = diagonal(output, offset, dim1, dim2) + mutate_to(target, src) + return output + + +@register_lowering(aten.select, type_promotion_kind=None) +def select(x, dim, idx): + idx = View.handle_negative_index(idx, x.get_size()[dim]) + return squeeze(slice_(x, dim, idx, idx + 1), dim) + + +@register_lowering(aten.split, type_promotion_kind=None) +def split(x, sizes, dim=0, clamp=True): + dim = _validate_dim(x, dim, 0) + if isinstance(sizes, sympy.Expr): + # TODO: We don't have to guard on sizes per se, but the number + # of splits must stay constant + sizes = V.graph.sizevars.evaluate_static_shape(sizes) + if isinstance(sizes, (int, sympy.Integer)): + x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) + sizes = [sizes] * ((x_size + sizes - 1) // sizes) + result = [] + start = 0 + for size in sizes: + end = start + size + result.append(slice_(x, dim, start, end, clamp=clamp)) + start = end + return result + + +@register_lowering(aten.split_with_sizes, type_promotion_kind=None) +def split_with_sizes(x, sizes, dim=0): + return split(x, sizes, dim, clamp=False) + + +@register_lowering(aten.unbind, type_promotion_kind=None) +def unbind(x, dim=0): + dim = _validate_dim(x, dim, 0) + x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) + result = [] + for i in range(x_size): + 
result.append(select(x, dim, i)) + return result + + +@register_lowering(aten.unfold, type_promotion_kind=None) +def unfold(x, dimension, size, step): + sizes = x.get_size() + ndim = len(sizes) + dim = canonicalize_dim(ndim, dimension) + + if ndim == 0: + return slice_(unsqueeze(x, 0), end=size) + + dim_size = sizes[dim] + sizevars = V.graph.sizevars + sizevars.guard_leq(size, dim_size) + sizevars.guard_lt(0, step) # type: ignore[arg-type] + + new_dim_size = FloorDiv(dim_size - size, step) + 1 + if sizevars.size_hint(dim_size) > 0: + x.mark_reuse(sizevars.size_hint(CeilDiv(new_dim_size * size, dim_size))) + + out_size = [*sizes[:dim], new_dim_size, *sizes[dim + 1 :], size] + + def reindexer(idx): + dim_idx = idx[-1] + idx[dim] * step + return (*idx[:dim], dim_idx, *idx[dim + 1 : -1]) + + return TensorBox(ir.GenericView.create(x, out_size, reindexer)) + + +@register_lowering(aten.unsqueeze, type_promotion_kind=None) +def unsqueeze(x, dim): + dim = _validate_dim(x, dim, 1) + new_shape = list(x.get_size()) + new_shape.insert(dim, sympy.Integer(1)) + return view(x, new_shape) + + +@register_lowering(aten.unsqueeze_, type_promotion_kind=None) +def unsqueeze_(x, dim): + val = unsqueeze(x, dim) + assert isinstance(x, TensorBox) + assert isinstance(val, TensorBox) + x.data = val.data + return x + + +def _validate_dim(x, dim, offset=0): + dim = V.graph.sizevars.shape_env.evaluate_expr(sympy.sympify(dim)) + ndim = len(x.get_size()) + if dim < 0: + dim += ndim + offset + assert 0 <= dim < ndim + offset + return dim + + +@register_lowering(aten.glu) +def glu(x, dim=-1): + dim = _validate_dim(x, dim, 0) + # TODO: don't guard on static shape here + new_len = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) // 2 + a = slice_(x, dim, 0, new_len) + b = slice_(x, dim, new_len, new_len * 2) + return mul(a, sigmoid(b)) + + +def fallback_handler(kernel, add_to_fallback_set=True): + if add_to_fallback_set: + fallbacks.add(kernel) + + def handler(*args, **kwargs): + def wrap_tensors(x): + return TensorBox.create(x) if isinstance(x, ir.IRNode) else x + + return pytree.tree_map( + wrap_tensors, ir.FallbackKernel.create(kernel, *args, **kwargs) + ) + + return handler + + +@functools.lru_cache(None) +def _warn_complex_not_supported(): + warnings.warn( + "Torchinductor does not support code generation for complex operators. Performance may be worse than eager." + ) + + +# There are some types (CPU) which we accept as input but not as +# output. +def unsupported_input_tensor(t: torch.Tensor, parent=None): + "Do not support reading or writing to this tensor" + if t.is_complex(): + # Complex views are supported with IR ComplexView + if parent and parent.target in ( + torch.ops.aten.view.dtype, + torch.ops.prims.convert_element_type.default, + ): + return False + _warn_complex_not_supported() + return True + return False + + +def unsupported_output_tensor(t: torch.Tensor, parent=None): + "Do not support writing tensor but can read from it" + if unsupported_input_tensor(t, parent): + return True + return t.is_cpu and config.disable_cpp_codegen + + +def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=True): + # Custom fallback lowering + if node.target is aten.view_as_complex.default: + return False + + # We should be able to remove this special case once `disable_cpp_codegen` is killed. 
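+ # (In short: apart from the special cases handled here, a node is sent to the eager fallback
+ # when a complex-valued fake-tensor input reaches it outside of view.dtype /
+ # convert_element_type, or, with config.disable_cpp_codegen set, when it produces a CPU output.)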
+ if node.target is aten.lift_fresh_copy.default: + return False + + def check_skip_condition(node, parent, is_output): + if not isinstance(node, torch.fx.Node): + return False + + if "val" not in node.meta: + return False + + for meta in pytree.tree_leaves(node.meta["val"]): + if not isinstance(meta, torch._subclasses.FakeTensor): + continue + + if is_output: + if unsupported_output_tensor(meta, parent): + return True + else: + if unsupported_input_tensor(meta, parent): + return True + + return False + + # only skip codegen if there is a cpu output, not input + for arg in pytree.arg_tree_leaves(*node.args, **node.kwargs): + if check_skip_condition(arg, node, is_output=False): + return True + + return check_skip_condition(node, node, is_output=True) + + +def make_fallback(op, layout_constraint=None, warn=True): + assert op not in decompositions, f"both a fallback and a decomp for same op: {op}" + if ( + warn + and bool(os.getenv("CI")) + and get_decompositions([op]) + # if fallback_random, we allow not decomposing random + and not ( + config.fallback_random + and op in torch._decomp.decompositions_for_rng.extra_random_decomps + ) + ): + # Note: 'warn' is holdover from when this was a warning, but for ops that previously + # set warn=False we do not want a CI error. + # Ignore the 'suppress errors' configs in CI, as this particular warning happens on startup anyway and is not + # likely to be triggered preferentially on one CI config over another. + if torch._dynamo.config.suppress_errors: + torch._dynamo.config.suppress_errors = False + log.warning( + "A make_fallback error occurred in suppress_errors config," + " and suppress_errors is being disabled to surface it." + ) + raise AssertionError( + f"make_fallback({op}): a decomposition exists, we should switch to it." + " To fix this error, either add a decomposition to core_aten_decompositions (preferred)" + " or inductor_decompositions, and delete the corresponding `make_fallback` line." + " Get help from the inductor team if unsure, don't pick arbitrarily to unblock yourself.", + ) + + def register_fallback(op_overload): + add_needs_realized_inputs(op_overload) + if layout_constraint is not None: + add_layout_constraint(op_overload, layout_constraint) + return register_lowering(op_overload, type_promotion_kind=None)( + fallback_handler(op_overload) + ) + + if isinstance(op, torch._ops.OpOverloadPacket): + for ol in op.overloads(): + op_overload = getattr(op, ol) + register_fallback(op_overload) + elif isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + register_fallback(op) + else: + raise RuntimeError(f"Unsupported fallback {op} with type {type(op)}") + + +def philox_rand_offset(shape): + """ + TorchInductor offset calculation differs from PyTorch eager offset + calculation for random ops (tl.rand vs torch.rand). In future, we should + strive for same impl for tl.rand and torch.rand. + """ + numel = 1 + for s in shape: + numel = numel * s + return tensor(numel, dtype=torch.int64) + + +@register_lowering(torch.ops.rngprims.philox_rand, type_promotion_kind=None) +def philox_rand(size, seed, offset, stride, device, dtype): + # stride arg is optional and will be used in future for distributed random + # ops. Currently, its unused. + random_pos = ir.FixedLayout( + device, + dtype, + size, + ir.FlexibleLayout.contiguous_strides(size), + ).make_indexer() + seed_loader = seed.make_loader() + offset_loader = offset.make_loader() + + def inner_fn(index): + # Both seed and offset in the philox_rand op are tensors. 
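+ # For illustration: for a contiguous size of [4, 5], element (i, j) below draws from the
+ # Philox stream at position offset + 5 * i + j, and the companion philox_rand_offset above
+ # reports an increment of numel = 20 for the next random op.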
+ # torch seed and offsets are of type int64, but tl.rand accepts int32 + seed_index_expr = ops.to_dtype(seed_loader([]), torch.int32) + offset_index_expr = ops.to_dtype(offset_loader([]), torch.int32) + # Get the offset'd position + rand_index_expr = ops.add( + ops.index_expr(random_pos(index), torch.int32), offset_index_expr + ) + result = ops.rand( + seed_index_expr, + rand_index_expr, + ) + return ops.to_dtype(result, dtype) + + random_values_node = Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(size), + ) + + offset_node = philox_rand_offset(size) + return random_values_node, offset_node + + +@register_lowering(aten.native_dropout, type_promotion_kind=None) +def native_dropout(x, p, train): + if config.fallback_random: + return pytree.tree_map( + TensorBox.create, + ir.FallbackKernel.create(aten.native_dropout.default, x, p, train), + ) + else: + raise AssertionError("should be handled in replace_random.py") + + +@register_lowering(aten.bernoulli_, type_promotion_kind=None) +def bernoulli_(x, *args): + assert config.fallback_random or x.get_device() == torch.device( + "cpu" + ), "this should be handled in decomps unless config.fallback_random or the device is CPU" + x.realize() + op_overload = ( + aten.bernoulli_.float + if len(args) == 0 or isinstance(args[0], float) + else aten.bernoulli_.Tensor + ) + ir.InplaceBernoulliFallback(op_overload, x, *args) + return x + + +@register_lowering(aten.bernoulli.p, type_promotion_kind=None) +def bernoulli_p(x, *args): + assert config.fallback_random or x.get_device() == torch.device( + "cpu" + ), "this should be handled in decomps unless config.fallback_random or the device is CPU" + return bernoulli_(clone(x), *args) + + +# This shouldn't be called in general +@register_lowering(aten._foobar) +def _foobar(_): + raise AssertionError + + +@functools.lru_cache(1) +def _warn_triton_random(salt): + log.info("using triton random, expect difference from eager") + + +def warn_triton_random(): + # only warn once per graph + _warn_triton_random(V.graph.creation_time) + + +fallback_rand_default = fallback_handler(aten.rand.default) +fallback_rand_generator = fallback_handler(aten.rand.generator) +fallback_randn_default = fallback_handler(aten.randn.default) +fallback_randn_generator = fallback_handler(aten.randn.generator) +make_fallback(aten.randint) + + +@register_lowering(aten.rand) +def rand(*args, **kwargs): + if kwargs.get("generator", None) is not None: + return fallback_rand_generator(*args, **kwargs) + elif config.fallback_random: + kwargs.pop("generator", None) + return fallback_rand_default(*args, **kwargs) + raise AssertionError("should have been handled in replace_random.py") + + +@register_lowering(aten.randn) +def randn(*args, **kwargs): + if kwargs.get("generator", None) is not None: + return fallback_randn_generator(*args, **kwargs) + elif config.fallback_random: + kwargs.pop("generator", None) + return fallback_randn_default(*args, **kwargs) + raise AssertionError("should have been handled in replace_random.py") + + +@register_lowering(inductor_prims.force_stride_order, type_promotion_kind=None) +def inductor_force_stride_order(input_tensor, stride): + stride_order = ir.get_stride_order(stride) + return ir.ExternKernel.require_stride_order(input_tensor, stride_order) + + +@register_lowering(inductor_prims.seed, type_promotion_kind=None) +def inductor_seed(device: torch.device): + raise AssertionError("should be handled in fuse_seed_creation_pass()") + + +@register_lowering(inductor_prims.seeds, 
type_promotion_kind=None) +def inductor_seeds(count, device): + warn_triton_random() + return TensorBox.create(ir.RandomSeeds(count, decode_device(device))) + + +@register_lowering(inductor_prims.lookup_seed, type_promotion_kind=None) +def inductor_lookup_seed(seeds, index): + def inner_fn(_): + return ops.load_seed(seeds.get_name(), index) + + return Pointwise.create( + device=seeds.get_device(), + dtype=seeds.get_dtype(), + inner_fn=inner_fn, + ranges=[], + ) + + +@register_lowering(inductor_prims.random, type_promotion_kind=None) +def inductor_random(size: List[int], seed: TensorBox, mode: str, *, offset: int = 0): + assert not config.fallback_random + assert mode in ("rand", "randn") + size = [*size] + dtype = torch.float32 + device = seed.get_device() + random_pos = ir.FixedLayout( + device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset + ).make_indexer() + seed_loader = seed.make_loader() + + def inner_fn(index): + return getattr(ops, mode)( + seed_loader([]), + ops.index_expr(random_pos(index), torch.int32), + ) + + result = Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=[*size], + ) + result.realize() + return result + + +@register_lowering(inductor_prims.randint, type_promotion_kind=None) +def inductor_randint( + low: int, high: int, size: List[int], seed: TensorBox, *, offset: int = 0 +): + assert not config.fallback_random + size = [*size] + dtype = torch.int64 + device = seed.get_device() + random_pos = ir.FixedLayout( + device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset + ).make_indexer() + seed_loader = seed.make_loader() + + def inner_fn(index): + return ops.randint64( + seed_loader([]), + ops.index_expr(random_pos(index), torch.int32), + ops.index_expr(low, torch.int64), + ops.index_expr(high, torch.int64), + ) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=[*size], + ) + + +@register_lowering(aten.bucketize, type_promotion_kind=None) +def bucketize( + input: TensorBox, + boundaries: TensorBox, + *, + out_int32: bool = False, + right: bool = False, +): + assert len(boundaries.get_size()) == 1 + + if not ( + V.graph.has_feature(input, BackendFeature.BUCKETIZE) + and V.graph.has_feature(boundaries, BackendFeature.BUCKETIZE) + ): + return fallback_handler(aten.bucketize.Tensor, add_to_fallback_set=False)( + input, boundaries, out_int32=out_int32, right=right + ) + + # The entire boundaries tensor needs to be used by ops.bucketize, so we + # need to realize it into global memory; or in other words, we can't + # guarantee that boundaries.get_name() (used below) will exist unless + # we call boundaries.realize(). 
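+ # Eager-semantics illustration (example values):
+ #   torch.bucketize(torch.tensor([0, 2, 5, 10]), torch.tensor([1, 3, 5, 9])) -> tensor([0, 1, 2, 4])
+ # i.e. with right=False each value maps to the index of the first boundary that is >= it
+ # (or len(boundaries) if none is); right=True instead uses the first boundary that is > it.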
+ boundaries.realize() + boundaries_size = boundaries.get_size()[0] + device = input.get_device() + input_loader = input.make_loader() + + index_dtype = torch.int32 if out_int32 else torch.int64 + + def inner_fn(index): + val = input_loader(index) + indices = ops.bucketize( + val, + boundaries.get_name(), + boundaries_size, + index_dtype, + right, + ) + + return indices + + return Pointwise.create( + device=device, + dtype=index_dtype, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +def require_dense(_, *args, **kwargs): + args, kwargs = pytree.tree_map_only( + ir.IRNode, ir.ExternKernel.require_stride1, (args, kwargs) + ) + return args, kwargs + + +def require_contiguous(_, *args, **kwargs): + args, kwargs = pytree.tree_map_only( + ir.IRNode, ir.ExternKernel.require_contiguous, (args, kwargs) + ) + return args, kwargs + + +def require_channels_last(_, *args, **kwargs): + args, kwargs = pytree.tree_map_only( + ir.IRNode, ir.ExternKernel.require_channels_last, (args, kwargs) + ) + return args, kwargs + + +def constrain_to_fx_strides(fx_node, *args, **kwargs): + def apply_constraint(arg, fx_arg): + if isinstance(arg, ir.IRNode): + stride_order = ir.get_stride_order(fx_arg.meta["val"].stride()) + return ir.ExternKernel.require_stride_order(arg, stride_order) + return arg + + args = tuple( + apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args) + ) + kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()} + return args, kwargs + + +# TODO(jansel): we should implement decomps or lowerings for these +# https://github.com/pytorch/torchdynamo/issues/327 +FALLBACK_ALLOW_LIST = { + "torchvision::roi_align", +} + + +def sdpa_constraint(fx_node, *args, **kwargs): + # sdpa requires dense last dimension] + + def apply_constraint(arg, fx_arg): + if not isinstance(arg, ir.IRNode): + return arg + + meta_val = fx_arg.meta["val"] + meta_stride = meta_val.stride() + + stride_order = ir.get_stride_order(meta_stride) + if stride_order and stride_order[-1] != 0: + # contiguous stride order + stride_order = list(reversed(range(len(arg.get_size())))) + + if not meta_val.is_cuda: + return ir.ExternKernel.require_stride_order(arg, stride_order) + + # This is the minimum alignment required by SDPA kernels for attention_bias. 
+ # This value can be found in pytorch/aten/src/ATen/native/transformers/attention.cpp preprocess_mask + ALIGNMENT = 8 + + assert isinstance(arg, TensorBox) + if len(arg.get_size()) not in (3, 4): + return arg + + def is_aligned_realized_tensor(x): + aligned_strides = all( + (V.graph.sizevars.size_hint(x.get_stride()[i]) % ALIGNMENT) == 0 + for i in range(len(x.get_stride()) - 1) + ) + return ( + V.graph.sizevars.size_hint(x.get_stride()[-1]) + ) == 1 and aligned_strides + + try: + arg.get_stride() + if is_aligned_realized_tensor(arg): + return V.graph.try_match_insignificant_strides( + ir.ExternKernel.realize_input(arg), meta_stride + ) + except AttributeError: + pass + + def is_aligned(x): + return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0 + + if isinstance(arg.data, ir.BaseView): + if not is_aligned(arg): + if is_aligned(arg.unwrap_view()): + return V.graph.try_match_insignificant_strides( + ir.ExternKernel.realize_input(arg), meta_stride + ) + + return ir.ExternKernel.require_stride_order(arg, stride_order) + + args = tuple( + apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args) + ) + kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()} + return args, kwargs + + +# WIP +make_fallback(aten._adaptive_avg_pool3d) # @isuruf +make_fallback(aten.adaptive_max_pool3d) # @isuruf +make_fallback(aten.fractional_max_pool3d) # @isuruf +make_fallback(aten.max_pool3d_with_indices) # @isuruf (can this one be implemented?) + + +# 1) Easy +make_fallback(aten.uniform, warn=False) +make_fallback(aten.exponential.default, warn=False) # (fails accuracy on test_torch.py) +make_fallback(aten._pdist_forward) # Has decomp. Needs benchmarks +make_fallback(aten.soft_margin_loss_backward, warn=False) # py_impl? +make_fallback(aten.searchsorted) # bucketized is implemented (see eager impl) + + +# 1.5) Easy or Impossible +make_fallback(aten._cdist_forward) # p=2 should be feasible +make_fallback(aten._cdist_backward) + +# 2) Medium +make_fallback(aten.max_unpool2d) +make_fallback(aten.max_unpool3d) +make_fallback(aten._trilinear) + + +# 3) Difficult +# Scans +# See the discussion at +# https://dev-discuss.pytorch.org/t/pytorch-sparse-gnn-compiler-rfc/1644/19 +make_fallback(aten.segment_reduce.default) +make_fallback(aten._segment_reduce_backward.default) + +# Histogram (need to implement Histogram IR) +make_fallback(aten.histc) +make_fallback(aten.histogram.bin_ct) +make_fallback(aten._histogramdd_bin_edges.default) +make_fallback(aten._histogramdd_from_bin_cts.default) + +# Need templated kernel +make_fallback(aten.addbmm) +make_fallback(aten._addmm_activation, warn=False) + +# Need templated kernel. 
Probably impossible to write efficiently +make_fallback(aten.convolution_backward, constrain_to_fx_strides) +make_fallback(aten._cudnn_rnn, require_dense) +make_fallback(aten._cudnn_rnn_backward, require_contiguous) + +# Haven't checked but sound difficult / impossible +make_fallback(aten._embedding_bag, require_contiguous) +make_fallback(aten._embedding_bag_forward_only, require_contiguous) +make_fallback(aten._embedding_bag_backward) +make_fallback(aten._embedding_bag_per_sample_weights_backward) +make_fallback(aten._embedding_bag_per_sample_weights_backward) +make_fallback(aten._fused_moving_avg_obs_fq_helper) +make_fallback(aten._fused_moving_avg_obs_fq_helper_functional) + + +# 4) Backwards (try py_impl'ing them) when fwd is written as a decomp +make_fallback(aten.max_pool3d_with_indices_backward) +make_fallback(aten._adaptive_avg_pool2d_backward, require_dense) +make_fallback(aten._adaptive_avg_pool3d_backward) +make_fallback(aten.adaptive_max_pool2d_backward) +make_fallback(aten.adaptive_max_pool3d_backward) +make_fallback(aten.fractional_max_pool2d_backward) +make_fallback(aten.fractional_max_pool3d_backward) +make_fallback(aten.replication_pad1d_backward) +make_fallback(aten.replication_pad2d_backward) +make_fallback(aten.upsample_linear1d_backward) +make_fallback(aten.upsample_bicubic2d_backward, require_contiguous) +make_fallback(aten.upsample_trilinear3d_backward) +make_fallback(aten.grid_sampler_2d_backward, require_dense) +make_fallback(aten._pdist_backward) + + +# 5) Impossible (missing triton/CPU features) + +# Sorting / Sorting-like +make_fallback(aten.sort) +make_fallback(aten.sort.stable) +make_fallback(aten.kthvalue) +make_fallback(aten.topk) +make_fallback(aten.mode) +make_fallback(aten.median) +make_fallback(aten.nanmedian) +make_fallback(aten.randperm) +# see: https://github.com/pytorch/pytorch/pull/121354 +make_fallback(aten.resize_) +make_fallback(aten.resize_as_) + +# Linalg +make_fallback(aten._linalg_det) +make_fallback(aten.linalg_householder_product) +make_fallback(aten.linalg_inv_ex) +make_fallback(aten.linalg_ldl_factor_ex) +make_fallback(aten.linalg_ldl_solve) +make_fallback(aten.linalg_lu) +make_fallback(aten.linalg_lu_factor_ex) +make_fallback(aten.linalg_lu_solve) +make_fallback(aten.linalg_matrix_exp) +make_fallback(aten.linalg_qr) +make_fallback(aten._linalg_slogdet) +make_fallback(aten._linalg_solve_ex) +make_fallback(aten.linalg_solve_triangular) +make_fallback(aten._linalg_svd) +make_fallback(aten.lu_unpack) +make_fallback(aten.ormqr) +make_fallback(aten._linalg_check_errors) +make_fallback(aten.linalg_pinv.atol_rtol_tensor) +make_fallback(aten._linalg_eigh) +make_fallback(aten.triangular_solve) +make_fallback(aten.linalg_cholesky_ex) +make_fallback(aten.cholesky_inverse) +make_fallback(aten.cholesky_solve) +make_fallback(aten.geqrf) +make_fallback(aten._fft_r2c) # needs complex as well + +# Data dependent (are these necessary?) +make_fallback(aten.nonzero.default) + +# Misc +make_fallback(aten.gcd.default, warn=False) +make_fallback(aten._thnn_fused_lstm_cell, require_dense) +make_fallback(torch._prims.rng_prims.run_and_save_rng_state) +make_fallback(torch._prims.rng_prims.run_with_rng_state) + +# Implmented / Half implemented +# Scans. 
Implemented for CUDA, missing CPU +make_fallback(aten.masked_scatter) +make_fallback(aten.masked_scatter_backward) + +# Complex number support +make_fallback(aten.view_as_complex, require_contiguous) +make_fallback(aten.angle) # needs complex + +# Needs efficentzerotensor +make_fallback(aten._efficientzerotensor) + +# Needs Sparse +make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors) +make_fallback(aten.to_sparse) +make_fallback(aten._to_sparse) + +# Needs dimname support +make_fallback(aten.zeros.names) + +# 6) Pattern-matched +make_fallback( + aten._scaled_dot_product_efficient_attention.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_efficient_attention_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_cudnn_attention.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_cudnn_attention_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention_for_cpu.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention_for_cpu_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback(aten._flash_attention_forward.default, sdpa_constraint) +make_fallback(aten._flash_attention_backward.default, sdpa_constraint) +make_fallback(aten._efficient_attention_forward.default, sdpa_constraint) +make_fallback(aten._efficient_attention_backward.default, sdpa_constraint) + +# index_reduce requires fallback when use_scatter_fallback(...) returns True +make_fallback(aten.index_reduce) + + +# Register with type_promotion_kind None. +# For example, fp16.copy_(fp32) should **not** promote the first input's dtype. 
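+# For illustration, the eager behaviour being matched is:
+#   dst = torch.zeros(2, dtype=torch.float16)
+#   dst.copy_(torch.ones(2, dtype=torch.float32))   # dst stays float16
+# so the lowering below casts src to self's dtype and device rather than promoting.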
+@register_lowering(aten.copy, type_promotion_kind=None) +def copy(self, src, non_blocking=False): + x = src + if self.get_device() != src.get_device(): + x = to_device(x, self.get_device()) + if self.get_dtype() != src.get_dtype(): + x = to_dtype(x, self.get_dtype()) + + if self.get_size() != src.get_size(): + out = expand(x, self.get_size()) + return clone(out) + return clone(x) + + +@register_lowering(aten.clone) +def clone(x, *, memory_format=None): + # TODO(jansel): memory format + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=x.make_loader(), + ranges=list(x.get_size()), + ) + + +def clone_preserve_reinterpret_view(x): + reinterpret_view_layouts = [] + if isinstance(x, TensorBox) and isinstance(x.data, ir.ReinterpretView): + x = x.data # unwrap TensorBox + while isinstance(x, ir.ReinterpretView): + reinterpret_view_layouts.append(x.get_layout()) + x = x.data + x = TensorBox(x) + + x = clone(x) + + if reinterpret_view_layouts: + x = x.data # unwrap TensorBox + for layout in reinterpret_view_layouts[::-1]: + x = ir.ReinterpretView(x, layout) + x = TensorBox(x) + + return x + + +if hasattr(aten, "lift_fresh_copy"): + register_lowering(aten.lift_fresh_copy)(clone) + + +@register_lowering(prims.iota) +def iota( + length, + *, + start, + step, + dtype, + device, + requires_grad, +): + def fn(index): + return ops.index_expr(step * index[0] + start, dtype=dtype) + + return Pointwise.create( + device=decode_device(device), + dtype=dtype, + inner_fn=fn, + ranges=[length], + ) + + +@register_lowering(aten.select_scatter, type_promotion_kind=None) +def select_scatter(x, src, dim: int, index: int): + assert x.get_dtype() == src.get_dtype() + x_loader = x.make_loader() + dim = _validate_dim(x, dim, 0) + if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)): + index = index + x.get_size()[dim] + V.graph.sizevars.guard_leq(0, index) # type: ignore[arg-type] + V.graph.sizevars.guard_lt(index, x.get_size()[dim]) # type: ignore[arg-type] + src = expand(unsqueeze(src, dim), x.get_size()) + src_loader = src.make_loader() + + def inner_fn(idx): + return ops.where( + ops.eq( + ops.index_expr(idx[dim], torch.int32), + ops.index_expr(index, torch.int32), + ), + src_loader(idx), + x_loader(idx), + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(x.get_size()), + ) + + +@register_lowering(aten.slice_scatter, type_promotion_kind=None) +def slice_scatter(x, src, dim=0, start=None, end=None, step=1): + assert x.get_dtype() == src.get_dtype() + x_loader = x.make_loader() + dim = _validate_dim(x, dim, 0) + dim_size = x.get_size()[dim] + + start, end = ir.SliceView.normalize_start_end(x, dim, start, end) + + src_size = list(x.get_size()) + src_size[dim] = FloorDiv(end - start + (step - 1), step) + src = expand(src, src_size) + src_loader = src.make_loader() + + def inner_fn(idx): + if start == 0 and end == dim_size and step == 1: + # selecting every element is the same as just src.clone() + return src_loader(idx) + + idx_dim = ops.index_expr(idx[dim], torch.int64) + src_idx = list(idx) + src_idx[dim] = FloorDiv(idx[dim] - start, step) + + mask = [] + if start != 0: + mask.append( + ops.ge( + idx_dim, + ops.index_expr(sympy.expand(start), torch.int64), + ) + ) + if end != dim_size: + mask.append( + ops.lt( + idx_dim, + ops.index_expr(sympy.expand(end), torch.int64), + ) + ) + if step != 1: + mask.append( + ops.eq( + ops.index_expr( + ModularIndexing(idx[dim] - start, 1, step), torch.int64 + ), + ops.constant(0, 
torch.int64), + ) + ) + assert mask + mask = functools.reduce(ops.and_, mask) + src_val = ops.masked( + mask, + lambda: src_loader(src_idx), + 0 if is_integer_type(x) else 0.0, + ) + return ops.where( + mask, + src_val, + x_loader(idx), + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(x.get_size()), + ) + + +def _unwrap(x): + if isinstance(x, (list, tuple)) and len(x) > 0: + return _unwrap(x[0]) + return x + + +@register_lowering([torch.tensor, aten.scalar_tensor]) +def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False): + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + assert_nyi(not pin_memory, "pin_memory") + if isinstance(_unwrap(data), int): + dtype = dtype or torch.int64 + else: + dtype = dtype or torch.get_default_dtype() + + ranges: List[sympy.Expr] = [] + + if isinstance(data, sympy.Basic): + + def inner_fn(index): + return ops.index_expr(data, dtype) + + elif isinstance(data, (float, int)): + + def inner_fn(index): + return ops.constant(data, dtype) + + elif len(data) == 0 or isinstance(data[0], (float, int)) and len(data) <= 8: + # inline small tensors + ranges.append(sympy.Integer(len(data))) + + def inner_fn(index): + def binary_search(start, end): + assert start < end + if end - start == 1: + return ops.constant(data[start], dtype) + mid = (end - start) // 2 + start + return ops.where( + ops.lt( + ops.index_expr(index[0], torch.int64), + ops.constant(mid, torch.int64), + ), + binary_search(start, mid), + binary_search(mid, end), + ) + + if len(data) == 0: + return ops.constant(0, dtype) + return binary_search(0, len(data)) + + else: + return V.graph.add_tensor_constant( + torch.tensor(data, dtype=dtype, device=device) + ) + + return Pointwise.create( + device=decode_device(device), + dtype=dtype, + inner_fn=inner_fn, + ranges=ranges, + ) + + +@register_lowering(torch.as_tensor) +def as_tensor(data, dtype=None, device=None): + if isinstance(data, TensorBox): + if dtype is not None: + data = to_dtype(data, dtype) + if device is not None: + data = to_device(data, device) + return data + return tensor(data, dtype=dtype, device=device) + + +@register_lowering(torch.LongTensor) +def long_tensor(data): + return tensor(data, dtype=torch.int64) + + +@register_lowering(aten._local_scalar_dense) +def _local_scalar_dense(data): + from torch.fx.experimental.symbolic_shapes import resolve_unbacked_bindings + + # This is interesting! Most lowerings return tensors, so you can just + # return the buffer you allocated and it will get used (or not used, if + # it's dead.) But _local_scalar_dense (aka item) returns an int, + # not a Tensor, so you would have a type mismatch if you return a buffer; + # we are obligated to return a sympy expression instead. However, + # we need to actually codegen the .item() call somehow. We do this + # by registering a faux buffer for the DynamicScalar IR node, which is + # solely responsible for generating this .item(). The buffer is + # not used for anything (notice we discard it); at codegen time, + # the "buffer" just gets assigned None. + unbacked_bindings = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"] + ) + assert len(unbacked_bindings) == 1, unbacked_bindings + # NB: Have to be very careful here. V.graph.current_node.meta["val"] + # seemingly also contains a symbol which you want to do binding for, + # but it actually isn't. 
In particular, if we have later performed + # a deferred runtime assert saying that u0 == s0, you will actually + # see s0 from expr! This is bad because we need to actually generate + # the assert that says u0 == s0, so we need to know where to get u0 + # from (this call). In particular, we must use unbacked_bindings, which + # is guaranteed to have the original, unreplaced symbol in question. + # + # NB2: Another thing we have to be very careful about are symbol bindings + # that require nontrivial refinement, e.g., when you have a binding site + # x: Sym(u0 * 4) = y.item(). Here, the code generation must do a division + # in order to appropriately bind u0. This is communicated via the keypath + # in unbacked_bindings, and we need to hold onto it in order to generate + # code appropriately for this case. + binding_sym, keypath = next(iter(unbacked_bindings.items())) + buffer = ir.DynamicScalar(binding_sym, keypath, data) + buffer.name = V.graph.register_buffer(buffer) + V.graph.register_operation(buffer) + # NB: the replaced expr is OK to use directly downstream, we want + # simplifications in this case! + val = V.graph.current_node.meta["val"] + if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)): + return val.node.expr + else: + return sympy.sympify(val) + + +@register_lowering(aten._assert_scalar) +def _assert_scalar(data, msg): + # NB: These will be handled at codegen time + # Not sure if we are guaranteed to be able to serve out truth from the + # deferred_runtime_asserts, TODO: try this assert out + # assert bool(data.scalar), data + return None + + +def _full(fill_value, device, dtype, size): + value = fill_value + if not isinstance(fill_value, (int, float)) and hasattr(value, "value"): + value = value.value + + if isinstance(value, (int, float)): + + def inner_fn(index): + return ops.constant(value, dtype) + + elif isinstance(value, sympy.Basic): + + def inner_fn(index): + return ops.index_expr(value, dtype) + + else: + assert len(value.get_size()) == 0 + value_loader = value.make_loader() + + def inner_fn(index): + return value_loader([]) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(size), + ) + + +@register_lowering(aten.full_like, type_promotion_kind=None) +def full_like(x, fill_value, **kwargs): + return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs) + + +def tensor_constructor(fill_value): + # torch.zeros, torch.ones, etc + def inner( + *size, + names=None, + dtype=None, + device=None, + layout=None, + pin_memory=False, + memory_format=None, + ): + assert_nyi(names is None, "named tensors") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + assert_nyi(not pin_memory, "pin_memory") + device = decode_device(device) + dtype = dtype or torch.get_default_dtype() + if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)): + size = tuple(size[0]) + # See https://github.com/pytorch/pytorch/issues/118102 + # All sizes at lowering time should be sympy.Symbol, not SymInt! 
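+ # (For example, a size coming from a dynamic torch.zeros(n) is expected to arrive here as a
+ # sympy expression; a raw torch.SymInt would mean it escaped that translation, hence the assert.)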
+ for s in size: + assert not isinstance(s, torch.SymInt) + size = [sympy.expand(s) for s in size] + return _full(fill_value, device, dtype, size) + + return inner + + +@register_lowering([torch.empty, aten.empty]) +def empty( + *size, + names=None, + dtype=None, + layout=None, + device=None, + pin_memory=None, + memory_format=None, +): + assert_nyi(names is None, "named tensors") + device = decode_device(device) + if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)): + size = tuple(size[0]) + return empty_strided( + size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +def create_tensor_like(creation_fn): + """ + Shim to convert X_like(...) into X(...). For example zeros_like() into zeros(). + """ + + def _constant_like( + x, *, dtype=None, device=None, layout=None, pin_memory=False, memory_format=None + ): + assert_nyi(not pin_memory, "pin_memory") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + if dtype is None: + dtype = x.get_dtype() + else: + dtype = decode_dtype(dtype) + device = device or x.get_device() + size = list(x.get_size()) + return creation_fn( + size, dtype=dtype, device=device, layout=layout, pin_memory=pin_memory + ) + + return _constant_like + + +def constant_like(fill_value): + return create_tensor_like(tensor_constructor(fill_value)) + + +empty_like = register_lowering(aten.empty_like)(create_tensor_like(empty)) +ones_like = create_tensor_like(tensor_constructor(1)) +zeros_like = create_tensor_like(tensor_constructor(0)) + + +def new_constant(fill_value): + def _new_constant( + x, size, *, dtype=None, layout=None, device=None, pin_memory=None + ): + assert isinstance(size, (list, tuple)) + assert_nyi(not pin_memory, "pin_memory") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + dtype = decode_dtype(dtype) or x.get_dtype() + device = device or x.get_device() + size = [sympy.Integer(s) for s in size] + return _full(fill_value, device, dtype, size) + + return _new_constant + + +@register_lowering(aten.new_empty) +def new_empty(x, size, *, dtype=None, layout=None, device=None, pin_memory=None): + if dtype is None: + dtype = x.get_dtype() + if device is None: + device = x.get_device() + return empty_strided( + size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_lowering(aten.empty_strided) +def empty_strided( + size, stride, *, dtype=None, layout=None, device=None, pin_memory=None +): + assert isinstance(size, (list, tuple)) + assert isinstance(stride, (list, tuple, type(None))) + assert_nyi(not pin_memory, "pin_memory") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + dtype = decode_dtype(dtype) or torch.get_default_dtype() + device = device or torch.tensor(0.0).device + pointwise = _full(fill_value=0, device=device, dtype=dtype, size=size) + pointwise.realize() + buffer = pointwise.data.data + # explicitly set ranges to zeros in order to make a NopKernelSchedulerNode + buffer.data.ranges = [0] * len(size) + assert isinstance(buffer, ir.ComputedBuffer) + size = [sympy.expand(s) for s in size] + stride = ( + [sympy.expand(s) for s in stride] + if stride + else ir.FlexibleLayout.contiguous_strides(size) + ) + buffer.layout = ir.FixedLayout( + device=device, + dtype=dtype, + size=size, + stride=stride, + ) + return pointwise + + +@register_lowering(aten.new_empty_strided) +def new_empty_strided( + x, size, stride, *, dtype=None, layout=None, device=None, pin_memory=None +): + if dtype is None: + dtype = x.get_dtype() + if 
device is None: + device = x.get_device() + return empty_strided( + size, stride, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_lowering(prims.copy_strided.default) +def copy_strided(x, stride): + stride = [V.graph.sizevars.size_hint(s) for s in stride] + stride_order = sorted(range(len(stride)), key=stride.__getitem__) + return ir.ExternKernel.require_stride_order(x, stride_order) + + +@register_lowering([torch.full, aten.full]) +def full(size, fill_value, **kwargs): + assert kwargs.get("dtype") is not None, "dtype should be handled by decomposition" + return tensor_constructor(fill_value)(size, **kwargs) + + +@register_lowering(aten.gather, type_promotion_kind=None) +def gather(x, dim, index, sparse_grad=False): + # sparse_grad doesn't affect forward computation, + # and backward tracing is taken care of by AOT Autograd + assert isinstance(x, TensorBox) + if index.get_numel() == 0: + # Empty index case. Return an empty array with the same shape + return new_empty(x, index.get_size()) + + assert index.get_dtype() == torch.int64 + size = x.get_size() + offset = len(size) == 0 + dim = _validate_dim(x, dim, offset) + + if offset: + x = expand(x, [1]) + size = [1] + + x_loader = x.make_loader() + index_loader = index.make_loader() + + def fn(idx): + idx = list(idx) + gather_idx = ops.indirect_indexing(index_loader(idx), size[dim]) + if len(idx) == 0: + idx = [gather_idx] + else: + idx[dim] = gather_idx + return x_loader(idx) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=fn, + ranges=index.get_size(), + ) + + +@register_lowering(aten.embedding, type_promotion_kind=None) +def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False): + assert not sparse + assert isinstance(weight, TensorBox) + assert isinstance(indices, TensorBox) + assert "int" in str(indices.get_dtype()) + + weight_loader = weight.make_loader() + indices_loader = indices.make_loader() + indices_ndim = len(indices.get_size()) + weight_size = weight.get_size() + new_size = [*indices.get_size(), *weight_size[1:]] + + def fn(idx): + assert len(idx) == len(new_size), f"{idx} != {new_size}" + var_index = indices_loader(idx[:indices_ndim]) + weight_idx = [ops.indirect_indexing(var_index, weight_size[0])] + [ + *idx[indices_ndim:] + ] + return weight_loader(weight_idx) + + return Pointwise.create( + device=weight.get_device(), + dtype=weight.get_dtype(), + inner_fn=fn, + ranges=new_size, + ) + + +def check_and_broadcast_indices(indices, device): + assert all( + i.get_dtype() in (torch.int64, torch.int32, torch.bool, torch.uint8) + for i in indices + if i is not None + ), f"indices must be int64, byte or bool. 
Got {[i.get_dtype() for i in indices if i is not None]}" + if any( + i.get_dtype() in (torch.bool, torch.uint8) for i in indices if i is not None + ): + raise NotImplementedError("Fallback for bool indices") + + valid_idxs = [i for i, x in enumerate(indices) if isinstance(x, TensorBox)] + assert len(valid_idxs) > 0, "requires at least 1 non-None index" + new_indices = [None] * len(indices) + for i, x in zip(valid_idxs, broadcast_tensors(*[indices[i] for i in valid_idxs])): + # Eager allows indices to be CPU tensor when running on CUDA + # FIXME: Calling to_device(x, device) should work but + # test_advancedindex_mixed_cpu_devices still fails + if x.get_device() != device: + raise NotImplementedError("Fallback when indices is on a different device") + new_indices[i] = x + return new_indices, valid_idxs + + +def index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + x_loader, + check, +): + # Note that behavior of indexing differs when there are non consecutive + # tensors. In this case, the tensor index is pulled to the beginning. + # + # Suppose a = torch.arange(3 * 4 * 5 * 6 * 7).view(3, 4, 5, 6, 7) + # x = torch.tensor[1,2] + # Then, a[:,x,:,x,:] will have shape 2,3,5,7 as due to x,:,x then 2 will + # be pulled to the front. + non_consecutive_tensors = False + for previous, current in zip(tensor_indices, tensor_indices[1:]): + if current - previous != 1: + non_consecutive_tensors = True + + output_size = [x_size[i] for i, val in enumerate(indices) if val is None] + output_size = [*output_size, *x_size[len(output_size) + len(tensor_indices) :]] + + first_tensor_index = tensor_indices[0] + if non_consecutive_tensors: + output_size = tensor_size + output_size + else: + output_size = ( + output_size[:first_tensor_index] + + tensor_size + + output_size[first_tensor_index:] + ) + + def fn(idx): + assert len(idx) == len(output_size) + assert len(indices_loaders) == len(indexed_size) + + rank = len(tensor_size) + new_index = [] + first_tensor_index = tensor_indices[0] + start_offset = 0 if non_consecutive_tensors else first_tensor_index + next_idx = 0 + for i in range(tensor_indices[-1] + 1): + if i == start_offset: + next_idx += rank + if indices[i] is None: + assert next_idx < len(idx) + new_index.append(idx[next_idx]) + next_idx += 1 + else: + loader = indices_loaders[i] + assert loader is not None + size = indexed_size[i] + new_index.append( + ops.indirect_indexing( + loader(idx[start_offset : start_offset + rank]), + size, + check=check, + ) + ) + new_index = [ + *new_index, + *idx[next_idx:], + ] + return new_index if x_loader is None else x_loader(new_index) + + return output_size, fn + + +def index_impl(x, indices, check): + output_size, inner_fn, _ = index_impl_helper(x, indices, check) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=output_size, + ) + + +def index_impl_helper(x, indices, check): + assert isinstance(indices, (list, tuple)) + x_loader = x.make_loader() + indices, tensor_indices = check_and_broadcast_indices(indices, x.get_device()) + assert len(tensor_indices) > 0, "Must have at least one valid idx" + + indices_loaders = [i.make_loader() if i is not None else None for i in indices] + # no guards on output size, all the guards are set in broadcast_tensors + + # We can use the first one since they are all required to be the same size + tensor_size = list(indices[tensor_indices[0]].get_size()) + + x_size = x.get_size() + + indexed_size = [x_size[i] 
for i in range(len(indices)) if indices[i] is not None] + if check and 0 in indexed_size and 0 not in tensor_size: + raise IndexError("index is out of bounds for dimension with size 0") + + indexed_size = [x_size[i] for i in range(len(indices))] + output_size, index_inner_fn = index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + None, + check=check, + ) + + def inner_fn(idx): + return x_loader(index_inner_fn(idx)) + + return output_size, inner_fn, index_inner_fn + + +@register_lowering(aten.index, type_promotion_kind=None) +def index(x, indices): + try: + return index_impl(x, indices, check=True) + except NotImplementedError: + # Fallback to ATen for boolean indexing + x.realize() + return fallback_handler(aten.index.Tensor, add_to_fallback_set=False)( + x, indices + ) + + +@register_lowering(aten._unsafe_index, type_promotion_kind=None) +def _unsafe_index(x, indices): + return index_impl(x, indices, check=False) + + +# All the indexing decompositions are written in terms of index, index_put, and index_put_ +# We cannot have this lowering as a decomposition as it introduces +# mutation in the graph, which is bad for Aot Autograd. Aot Autograd runs dead +# code elimination and common subexpression elimination optimizations, which +# assume graphs to be side-effect free. More details at +# https://github.com/pytorch/torchdynamo/issues/1235 +# and +# https://github.com/pytorch/torchdynamo/issues/1863 +@register_lowering(aten.index_put) +def index_put(x, indices, values, accumulate=False): + return index_put_(clone(x), indices, values, accumulate) + + +@register_lowering(aten._unsafe_index_put) +def _unsafe_index_put(x, indices, values, accumulate=False): + return index_put_impl_(clone(x), indices, values, accumulate, check=False) + + +def index_put_as_masked_fill(self, indices, value, accumulate): + if value.get_device() != self.get_device(): + value = to_device(value, self.get_device()) + if accumulate: + value = add(self, value) + return mutate_to(self, where(indices[0], value, self)) + + +def index_put_fallback(self, indices, values, accumulate): + deterministic = torch.are_deterministic_algorithms_enabled() + if is_triton(values) and (accumulate or deterministic): + msg = ( + "index put with accumulate." + if not deterministic + else "deterministic index put." 
+ ) + if stack_trace := V.graph.current_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + ir.IndexPutFallback(V.graph.current_node.target, self, indices, values, accumulate) + return self + + +@register_lowering(aten.index_put_, type_promotion_kind=None) +def index_put_(self, indices, values, accumulate=False): + return index_put_impl_(self, indices, values, accumulate, check=True) + + +@register_lowering(inductor_prims._unsafe_index_put_, type_promotion_kind=None) +def _unsafe_index_put_(self, indices, values, accumulate=False): + return index_put_impl_(self, indices, values, accumulate, check=False) + + +def index_put_impl_(self, indices, values, accumulate, check): + # Dispatch to masked fill for single boolean index with single value + if ( + values.get_numel() == 1 + and len(indices) == 1 + and indices[0].get_dtype() in {torch.bool, torch.uint8} + ): + mask = indices[0] + for _ in range(len(mask.get_size()), len(self.get_size())): + mask = unsqueeze(mask, -1) + return index_put_as_masked_fill(self, [mask], values, accumulate) + + # Fallback in torch deterministic mode + if torch.are_deterministic_algorithms_enabled(): + return index_put_fallback(self, indices, values, accumulate) + + # Fallback if there is a boolean index + for index in indices: + if index is not None and index.get_dtype() in {torch.bool, torch.uint8}: + return index_put_fallback(self, indices, values, accumulate) + + x_size = self.get_size() + x_ndim = len(x_size) + + if accumulate and needs_fallback_due_to_atomic_add_limitations(self.get_dtype()): + # self is an scalar Tensor + if x_ndim == 0: + self = view(self, [1]) + self = index_put_fallback(self, indices, values, accumulate) + if x_ndim == 0: + self = view(self, []) + return self + + values = to_dtype(values, self.get_dtype()) + + try: + # Note that code will only get here when dtype is uint32 + indices, tensor_indices = check_and_broadcast_indices( + indices, self.get_device() + ) + except NotImplementedError: + return index_put_fallback(self, indices, values, accumulate) + + indices_loaders = [i.make_loader() if i is not None else None for i in indices] + + assert isinstance(self, TensorBox) + self.realize() + + # self is an scalar Tensor + if x_ndim == 0: + self = view(self, [1]) + + # We can use the first one since they are all required to be the same size + tensor_size = list(indices[tensor_indices[0]].get_size()) + indexed_size = [x_size[i] for i in range(len(indices))] + + expected_vals_size, inner_fn = index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + None, + check=check, + ) + + values = expand(values, expected_vals_size) + # all guards are set above during broadcast_tensors and expand + + scatter = ir.Scatter( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=values.make_loader(), + ranges=expected_vals_size, # iter_ranges, + output_indexer=inner_fn, + scatter_mode="atomic_add" if accumulate else None, + ) + buffer = ir.ComputedBuffer( + None, + ir.MutationLayoutSHOULDREMOVE(self), + scatter, + ) + buffer.name = V.graph.register_buffer(buffer) + V.graph.register_operation(buffer) + + if x_ndim == 0: + self = view(self, []) + return self + + +fallback__unsafe_masked_index = fallback_handler( + aten._unsafe_masked_index.default, add_to_fallback_set=False +) + +fallback__unsafe_masked_index_put_accumulate = fallback_handler( + aten._unsafe_masked_index_put_accumulate.default, 
add_to_fallback_set=False +) + + +@register_lowering(aten._unsafe_masked_index, type_promotion_kind=None) +def _unsafe_masked_index(self, mask, indices, fill): + ranges, _, _unsafe_index_fn = index_impl_helper(self, indices, check=False) + mask_loader = mask.make_loader() + self_loader = self.make_loader() + + def inner_fn(idx): + if mask.dtype != torch.bool: + mask_val = ops.to_dtype(mask_loader(idx), torch.bool) + else: + mask_val = mask_loader(idx) + return ops.masked(mask_val, lambda: self_loader(_unsafe_index_fn(idx)), fill) + + return Pointwise.create( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=inner_fn, + ranges=ranges, + ) + + +@register_lowering(aten._unsafe_masked_index_put_accumulate, type_promotion_kind=None) +def _unsafe_masked_index_put_accumulate(x, mask, indices, values): + masked_value = where(mask, values, 0) + shape = x.get_size() + clamped_indices = [ + clamp(indices[i], -shape[i], shape[i] - 1) if indices[i] else None + for i in range(len(indices)) + ] + # TODO: use a masked store for this. currently only triton + # supports masked stores and cpp backend does not. + return _unsafe_index_put(x, clamped_indices, masked_value, accumulate=True) + + +@make_pointwise +def clamp(a, min, max): + return ops.maximum(min, ops.minimum(max, a)) + + +@register_lowering(aten.as_strided_scatter, type_promotion_kind=None) +def as_strided_scatter(self, src, size, stride, storage_offset=None): + output = clone(self) + output_view = as_strided(output, size, stride, storage_offset) + copy_(output_view, src) + return output + + +@register_lowering(aten.scatter, type_promotion_kind=None) +def scatter(x, dim: int, index, src, **kwargs): + return scatter_(clone(x), dim, index, src, **kwargs) + + +def scatter_fallback( + op_overload: torch._ops.OpOverload, + self, + dim: int, + index, + src, + *, + reduce: Optional[str] = None, + include_self: bool = True, +): + src_is_tensor = isinstance(src, TensorBox) + if use_scatter_fallback( + op_overload, + reduce, + self.get_dtype(), + src.get_dtype() if src_is_tensor else type(src), + src.get_device().type if src_is_tensor else "not impl", + src_is_tensor, + ): + ir.ScatterFallback( + op_overload, + self, + dim, + index, + src, + reduce=reduce, + include_self=include_self, + ) + return self + + return None + + +@register_lowering(aten.scatter_, type_promotion_kind=None) +def scatter_(self, dim: int, index, src, *, reduce: Optional[str] = None): + assert reduce in {None, "add", "multiply"} + if reduce is None: + op_overload = getattr(aten.scatter_, V.graph.current_node.target._overloadname) # type: ignore[union-attr] + fallback_result = scatter_fallback( + op_overload, self, dim, index, src, reduce=reduce + ) + if fallback_result is not None: + return fallback_result + + if reduce == "add": + reduce = "sum" + elif reduce == "multiply": + reduce = "prod" + return scatter_reduce_(self, dim, index, src, reduce) + + +@register_lowering(aten.scatter_add, type_promotion_kind=None) +def scatter_add(x, dim: int, index, src): + return scatter_add_(clone(x), dim, index, src) + + +@register_lowering(aten.scatter_add_, type_promotion_kind=None) +def scatter_add_(x, dim: int, index, src): + return scatter_reduce_(x, dim, index, src, "sum") + + +@register_lowering(aten.scatter_reduce, type_promotion_kind=None) +def scatter_reduce(x, dim: int, index, src, reduction_type, **kwargs): + return scatter_reduce_(clone(x), dim, index, src, reduction_type, **kwargs) + + +@register_lowering(aten.scatter_reduce_, type_promotion_kind=None) +def 
scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool = True): + assert reduce in {None, "sum", "prod", "mean", "amax", "amin"} + assert ( + len(aten.scatter_reduce_.overloads()) == 1 + and "two" in aten.scatter_reduce_.overloads() + ), "aten.scatter_reduce_.two is not the unique overload of aten.scatter_reduce_" + + if isinstance(src, Number): + src = full_like(self, src) + + fallback_result = scatter_fallback( + aten.scatter_reduce_.two, + self, + dim, + index, + src, + reduce=reduce, + include_self=include_self, + ) + + if fallback_result: + return fallback_result + + assert isinstance(self, TensorBox) + assert "int" in str(index.get_dtype()) + + ndim = len(self.get_size()) + if ndim == 0: + self = view(self, [1]) + + if isinstance(src, TensorBox) and len(src.get_size()) == 0: + src = view(src, [1]) + + if isinstance(index, TensorBox) and len(index.get_size()) == 0: + index = view(index, [1]) + + if index.get_numel() == 0: + return self + + dim = _validate_dim(self, dim) + + self.realize() + index_loader = index.make_loader() + src_loader = src.make_loader() if isinstance(src, TensorBox) else None + + def output_indexer(idx): + # self is captured from the end of the function, so it may have 0 dim + shape = self.get_size() + ndim = len(shape) + indirect_idx = list(idx) + indirect_idx[dim] = ops.indirect_indexing( + index_loader(idx), 1 if ndim == 0 else shape[dim], wrap_neg=False + ) + return indirect_idx + + def fn(idx): + if src_loader: + return src_loader(idx) + else: + # src is a scalar + return ops.constant(src, self.get_dtype()) + + def backend_reduce_str(reduce): + if reduce == "sum": + return "atomic_add" + else: + # TODO: Need to support more reduction type + assert reduce is None + return None + + if not include_self: + # zero out the corresponding elements first + zero_out = ir.Scatter( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=lambda index: ops.constant(0, self.get_dtype()), + ranges=index.get_size(), + output_indexer=output_indexer, + scatter_mode=None, + ) + buffer = ir.ComputedBuffer( + None, + ir.MutationLayoutSHOULDREMOVE(self), + zero_out, + ) + buffer.name = V.graph.register_buffer(buffer) + V.graph.register_operation(buffer) + + # self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0 + # self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1 + # self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2 + scatter = ir.Scatter( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=fn, + ranges=index.get_size(), + output_indexer=output_indexer, + scatter_mode=backend_reduce_str(reduce), + ) + buffer = ir.ComputedBuffer( + None, + ir.MutationLayoutSHOULDREMOVE(self), + scatter, + ) + buffer.name = V.graph.register_buffer(buffer) + V.graph.register_operation(buffer) + + if ndim == 0: + self = view(self, []) + return self + + +def upsample_nearestnd( + x, + output_size, + scales_x: Tuple[Optional[float], ...], + n: int = 2, + exact: bool = False, +): + x.realize_hint() # elements are reused + x_loader = x.make_loader() + i_sizes = x.get_size()[-n:] + batch = x.get_size()[:-n] + i_sizes = [V.graph.sizevars.evaluate_static_shape(i) for i in i_sizes] + + assert len(scales_x) == n + o_sizes = output_size + + inv_scales = [i / o for i, o in zip(i_sizes, o_sizes)] + for i, scale in enumerate(scales_x): + if scale is not None: + inv_scales[i] = 1.0 / scale + + def scale_fn(x, scale, size): + # Nearest Exact: input_index = round(scale * (output_index + 0.5) - 0.5) + # = floor(scale * (output_index + 0.5)) + # Nearest: 
input_index = floor(scale * output_index) + x = ops.index_expr(x, torch.float32) + if exact: + x = ops.add(x, ops.constant(0.5, torch.float32)) + x = ops.mul(x, ops.constant(scale, torch.float32)) + x = ops.to_dtype(x, torch.int32) + return ops.indirect_indexing(x, size, check=False) + + def fn(idx): + x = idx[-n:] + b = idx[:-n] + return x_loader( + [*b, *[scale_fn(i, s, size) for i, s, size in zip(x, inv_scales, i_sizes)]] + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=fn, + ranges=[*batch, *o_sizes], + ) + + +@register_lowering(aten.upsample_nearest1d.default) +def upsample_nearest1d(x, output_size, scales: Optional[float] = None): + return upsample_nearestnd(x, output_size, (scales,), n=1) + + +@register_lowering(aten._upsample_nearest_exact1d.default) +def _upsample_nearest_exact1d(x, output_size, scales: Optional[float] = None): + return upsample_nearestnd(x, output_size, (scales,), n=1, exact=True) + + +@register_lowering(aten.upsample_nearest2d.default) +def upsample_nearest2d( + x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None +): + return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2) + + +@register_lowering(aten._upsample_nearest_exact2d.default) +def _upsample_nearest_exact2d( + x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None +): + return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2, exact=True) + + +@register_lowering(aten.upsample_nearest3d.default) +def upsample_nearest3d( + x, + output_size, + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +): + return upsample_nearestnd(x, output_size, (scales_d, scales_h, scales_w), n=3) + + +@register_lowering(aten._upsample_nearest_exact3d.default) +def _upsample_nearest_exact3d( + x, + output_size, + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +): + return upsample_nearestnd( + x, output_size, (scales_d, scales_h, scales_w), n=3, exact=True + ) + + +def _create_constants(*args, dtype): + return tuple(ops.constant(a, dtype) for a in args) + + +@register_lowering(prims.rev.default) +def rev(x, dims): + # note - dims pre-canonicalized + x_loader = x.make_loader() + sizes = x.get_size() + + def loader(idx): + idx = list(idx) + assert len(idx) == len(sizes) + for dim in dims: + idx[dim] = (sizes[dim] - 1) - idx[dim] + + return x_loader(idx) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=loader, + ranges=sizes, + ) + + +@register_lowering(aten.constant_pad_nd, type_promotion_kind=None) +def constant_pad_nd(x, padding, fill_value=0): + assert (len(padding) % 2) == 0 + if all(p == 0 for p in padding): + return clone(x) + + sizes = x.get_size() + + bounds = list(reversed(list(zip(padding[::2], padding[1::2])))) + n = len(sizes) - len(bounds) + + # if padding is a complicated expression, hoist it + bounds_precomp: List[Tuple[sympy.Symbol, Any]] = [] + for l, h in bounds: + bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h)) # type: ignore[arg-type] + + output_size = list(sizes[:n]) + mask_sizes = [] + for (low, high), size in zip(bounds, sizes[n:]): + mask_sizes.append(size) + output_size.append(sympy.expand(size + low + high)) + assert len(output_size) == len(sizes) + fill_value = dtype_to_type(x.get_dtype())(fill_value) + + def mask(index): + mask = [] + for idx, (low, high), length in zip(index[n:], bounds, mask_sizes): + 
if low != 0: + mask.append(range_mask_low(idx, 0)) + if high != 0: + mask.append(range_mask_high(idx, length)) + mask = functools.reduce(ops.and_, mask) + return ops.masked(mask, lambda: x_loader(index), fill_value) + + def offset_fn(index): + new_index = list(index[:n]) + for idx, (low, high) in zip(index[n:], bounds_precomp): + new_index.append(idx - low) + assert len(new_index) == len(index) + return mask(new_index) + + x_loader = x.make_loader() + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=offset_fn, + ranges=output_size, + ) + + +def range_mask_low(i: sympy.Expr, low: Union[sympy.Expr, int]): + return ops.ge( + ops.index_expr(i, torch.int64), + ops.index_expr(sympy.Integer(low), torch.int64), + ) + + +def range_mask_high(i: sympy.Expr, high: sympy.Expr): + return ops.lt( + ops.index_expr(i, torch.int64), + ops.index_expr(high, torch.int64), + ) + + +def range_mask(i: sympy.Expr, high: sympy.Expr, low: sympy.Expr): + return ops.and_( + range_mask_low(i, low), + range_mask_high(i, high), + ) + + +def constant_boundary_condition( + x, fill_value, padding=None, pad_fill_value=1.0, dim=None +): + h = x.get_size()[-dim:] + x_loader = x.make_loader() + padding_h = padding or [0] * dim + + def load(index): + prefix = index[:-dim] + ih = index[-dim:] + + mask = functools.reduce( + ops.and_, + [range_mask(ih[i], h[i] + padding_h[i], -padding_h[i]) for i in range(dim)], + ) + return ( + ops.masked( + mask, + lambda: constant_boundary_condition(x, pad_fill_value, dim=dim)( + [*prefix, *ih] + ), + fill_value, + ) + if padding + else ops.masked(mask, lambda: x_loader([*prefix, *ih]), fill_value) + ) + + return load + + +def pooling_size(x, i, kernel_size, stride, padding, ceil_mode): + x_out = FloorDiv( + x + 2 * padding[i] - (kernel_size[i] - 1) + (stride[i] - 1), stride[i] + ) + + if ceil_mode: + x_alt = FloorDiv( + x + 2 * padding[i] - (kernel_size[i] - 1) + 2 * (stride[i] - 1), stride[i] + ) + if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0: + # Sliding windows must start within the input or left padding + x_alt -= 1 # type: ignore[assignment] + V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i]) # type: ignore[arg-type] + if V.graph.sizevars.size_hint(x_out - x_alt) == 0: + # ceil mode is actually a no-op, lets guard on that + V.graph.sizevars.guard_equals(x_out, x_alt) + ceil_mode = False + else: + x_out = x_alt + return x_out, ceil_mode + + +def should_fallback_max_pool2d_with_indices(kernel_size, dilation): + kernel_size = pad_listlike(kernel_size, 2) + window_size = kernel_size[0] * kernel_size[1] + return (window_size > 25) or any(d > 1 for d in dilation) + + +def max_pool2d_checks( + x, kernel_size, stride, padding, dilation, *, assert_fallback=None +): + if padding == 0: + padding = [0, 0] + if dilation == 1: + dilation = [1, 1] + if not stride: + stride = kernel_size + + kernel_size = pad_listlike(kernel_size, 2) + stride = pad_listlike(stride, 2) + padding = pad_listlike(padding, 2) + dilation = pad_listlike(dilation, 2) + + assert isinstance(x, TensorBox) + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(dilation) == 2 + assert len(x.get_size()) in (3, 4) + + use_fallback = should_fallback_max_pool2d_with_indices(kernel_size, dilation) + if assert_fallback is not None: + assert use_fallback == assert_fallback + + return kernel_size, stride, padding, dilation, use_fallback + + +@register_lowering(prims._low_memory_max_pool2d_with_offsets, 
type_promotion_kind=None) +def _low_memory_max_pool2d_with_offsets( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode=False, +): + # assert we are not on a fallback path, the inductor decomp should have guaranteed this + kernel_size, stride, padding, dilation, _ = max_pool2d_checks( + x, kernel_size, stride, padding, dilation, assert_fallback=False + ) + + x.realize_hint() + *batch, h, w = x.get_size() + + h_out, ceil_mode1 = pooling_size(h, 0, kernel_size, stride, padding, ceil_mode) + w_out, ceil_mode2 = pooling_size(w, 1, kernel_size, stride, padding, ceil_mode) + + dtype = x.dtype + min_value = ( + False + if dtype is torch.bool + else (float("-inf") if dtype.is_floating_point else torch.iinfo(dtype).min) + ) + + new_size = list(batch) + [h_out, w_out] + if padding[0] or padding[1] or ceil_mode1 or ceil_mode2: + x_loader = constant_boundary_condition(x, min_value, dim=2) + else: + x_loader = x.make_loader() + + def fn(idx, return_index): + *prefix, bh, bw = idx + maxval = None + maxindex = None + for h_inc, w_inc in itertools.product( + range(kernel_size[0]), range(kernel_size[1]) + ): + ih = bh * stride[0] + h_inc - padding[0] + iw = bw * stride[1] + w_inc - padding[1] + val = x_loader([*prefix, ih, iw]) + if return_index: + index = ops.index_expr(h_inc * kernel_size[1] + w_inc, torch.int8) + if maxindex is None: + maxindex = index + else: + maxindex = ops.where(ops.gt(val, maxval), index, maxindex) + if maxval is None: + maxval = val + else: + maxval = ops.maximum(val, maxval) + if return_index: + return maxindex + else: + return maxval + + out = Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=functools.partial(fn, return_index=False), + ranges=new_size, + ) + offsets = Pointwise.create( + device=x.get_device(), + dtype=torch.int8, + inner_fn=functools.partial(fn, return_index=True), + ranges=new_size, + ) + return out, offsets + + +@register_lowering( + prims._low_memory_max_pool2d_offsets_to_indices, type_promotion_kind=None +) +def _low_memory_max_pool2d_offsets_to_indices( + offsets, kernel_width, input_width, stride, padding +): + # TODO: Generalize to other max pooling flavors, and arbitrary dim + + offsets_loader = offsets.make_loader() + + def increments_to_index(h_inc, w_inc, bh, bw): + w_in = ops.index_expr(input_width, torch.int64) + hbase = ops.index_expr(bh * stride[0] - padding[0], torch.int64) + wbase = ops.index_expr(bw * stride[1] - padding[1], torch.int64) + ih = hbase + h_inc + iw = wbase + w_inc + return ih * w_in + iw + + def offsets_to_indices(idx): + *prefix, bh, bw = idx + offset = offsets_loader([*prefix, bh, bw]) + kw_const = ops.constant(kernel_width, torch.int32) + h_inc = offset // kw_const + w_inc = offset - (h_inc * kw_const) + return increments_to_index(h_inc, w_inc, bh, bw) + + indices = Pointwise.create( + device=offsets.get_device(), + dtype=torch.int64, + inner_fn=offsets_to_indices, + ranges=offsets.get_size(), + ) + return indices + + +# Fallback selected when we do not decompose to the low-memory path. 
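+# Per should_fallback_max_pool2d_with_indices above, the low-memory
+# decomposition is skipped once the pooling window has more than 25 elements
+# or any dilation exceeds 1. Rough worked examples (values assumed):
+#   kernel_size=(5, 5), dilation=(1, 1) -> 25 elements, low-memory path
+#   kernel_size=(7, 7), dilation=(1, 1) -> 49 > 25, ATen fallback below
+#   kernel_size=(3, 3), dilation=(2, 1) -> dilation > 1, ATen fallback below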
+make_fallback(aten.max_pool2d_with_indices) + + +fallback_max_pool2d_with_indices_backward = fallback_handler( + aten.max_pool2d_with_indices_backward.default, + add_to_fallback_set=False, +) + + +@register_lowering(aten.max_pool2d_with_indices_backward, type_promotion_kind=None) +def max_pool2d_with_indices_backward( + grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices +): + if padding == 0: + padding = [0, 0] + if dilation == 1: + dilation = [1, 1] + if not stride: + stride = kernel_size + + assert isinstance(x, TensorBox) + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(dilation) == 2 + assert len(x.get_size()) in (3, 4) + + # we will read this many times, so make sure it is computed + grad_output.realize_hint() + try: + gO_stride = grad_output.get_stride() + except AttributeError: + # some classes don't have `get_stride` + # TODO will need a better way of determining if inputs are channels-last + gO_stride = None + if isinstance(x, TensorBox) and isinstance(x.data.data, Pointwise): # type: ignore[attr-defined] + data = x.data.data # type: ignore[attr-defined] + x_buffer = ir.ComputedBuffer( + name=None, + layout=ir.FlexibleLayout( + device=data.get_device(), + dtype=data.get_dtype(), + size=data.get_size(), + ), + data=data, + ) + x_buffer.decide_layout() + x_stride = x_buffer.get_stride() + else: + try: + x_stride = x.get_stride() + except AttributeError: + x_stride = None + + is_channels_last = (x_stride is not None and x_stride[1] == 1) or ( + gO_stride is not None and gO_stride[1] == 1 + ) + if any(d != 1 for d in dilation): + # dilation NYI + return fallback_max_pool2d_with_indices_backward( + grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices + ) + + *batch, height, width = x.get_size() + *_, pooled_height, pooled_width = grad_output.get_size() + + indices_loader = indices.make_loader() + grad_loader = grad_output.make_loader() + new_size = list(x.get_size()) + + h_window_size = max( + max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1) + for h in range(kernel_size[0] * 2) + ) + w_window_size = max( + max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1) + for w in range(kernel_size[1] * 2) + ) + + window_size = h_window_size * w_window_size + + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 
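+        # h_window_size / w_window_size above bound how many pooling windows
+        # can overlap a single input element. Rough worked examples (values
+        # assumed): kernel=3, stride=1 gives at most 3 windows per dim, so
+        # 3 * 3 = 9 <= 25 stays inline; kernel=7, stride=1 gives 7 * 7 = 49
+        # and takes the fallback below.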
+ return fallback_max_pool2d_with_indices_backward( + grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices + ) + + indices_size = indices.get_size() + + def fn(idx): + *prefix, h, w = idx + index_test = ops.index_expr(h * width + w, torch.int32) + h = h + padding[0] + w = w + padding[1] + phstart = ops.index_expr( + FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32 + ) + pwstart = ops.index_expr( + FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32 + ) + phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32) + pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32) + + phstart = ops.maximum(phstart, ops.constant(0, torch.int32)) + pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32)) + phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32)) + pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32)) + + gradient = None + for ph_ in range(h_window_size): + for pw_ in range(w_window_size): + ph = ops.add(phstart, ops.constant(ph_, torch.int32)) + pw = ops.add(pwstart, ops.constant(pw_, torch.int32)) + grad_index = [ + *prefix, + ops.indirect_indexing( + ops.minimum(ph, ops.sub(phend, ops.constant(1, torch.int32))), + indices_size[-2], + check=False, + ), + ops.indirect_indexing( + ops.minimum(pw, ops.sub(pwend, ops.constant(1, torch.int32))), + indices_size[-1], + check=False, + ), + ] + + index_actual = indices_loader(grad_index) + grad_part = grad_loader(grad_index) + check = ops.eq(index_actual, index_test) + + if gradient is None: + # don't need mask for 0, 0 + gradient = ops.where( + check, grad_part, ops.constant(0.0, torch.float32) + ) + else: + mask = ops.and_( + ops.and_( + ops.lt(ph, phend), + ops.lt(pw, pwend), + ), + check, + ) + gradient = ops.where(mask, ops.add(gradient, grad_part), gradient) + assert gradient is not None + return gradient + + out = Pointwise.create( + device=grad_output.get_device(), + dtype=grad_output.get_dtype(), + inner_fn=fn, + ranges=new_size, + ) + if is_channels_last: + return ir.ExternKernel.require_channels_last(out) + else: + return out + + +def pad_adaptive_loader(x, pad_val=0.0): + *_, h, w = x.get_size() + x_loader = x.make_loader() + + def load(prefix, increments, start_indices, end_indices): + ih, iw = increments + h_start_index, w_start_index = start_indices + h_end_index, w_end_index = end_indices + + mask = ops.and_( + ops.lt( + ops.index_expr(h_start_index + ih, torch.int64), + ops.index_expr(h_end_index, torch.int64), + ), + ops.lt( + ops.index_expr(w_start_index + iw, torch.int64), + ops.index_expr(w_end_index, torch.int64), + ), + ) + + return ops.masked( + mask, + lambda: x_loader([*prefix, h_start_index + ih, w_start_index + iw]), + pad_val, + ) + + return load + + +def compute_indices_adaptive_pooling(start_index, end_index, h_in, w_in, h_out, w_out): + h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in) + h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in) + + w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in) + w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in) + + return h_start_index, h_end_index, w_start_index, w_end_index + + +def _adaptive_pooling_fn( + start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn +): + h_in, w_in = in_sizes + h_out, w_out = out_sizes + + ( + h_start_index_fn, + h_end_index_fn, + w_start_index_fn, + w_end_index_fn, + ) = compute_indices_adaptive_pooling( + start_index, end_index, h_in, w_in, 
h_out, w_out + ) + + def fn(idx, loader): + *prefix, bh, bw = idx + + h_start_index = h_start_index_fn(bh) + h_end_index = h_end_index_fn(bh) + + w_start_index = w_start_index_fn(bw) + w_end_index = w_end_index_fn(bw) + + result = None + for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): + val = loader( + prefix, + [ih, iw], + [h_start_index, w_start_index], + [h_end_index, w_end_index], + ) + if result is None: + result = val + else: + result = pooling_fn(val, result) + return result + + return fn + + +def _adaptive_pooling_fn_with_idx( + start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn +): + h_in, w_in = in_sizes + h_out, w_out = out_sizes + + ( + h_start_index_fn, + h_end_index_fn, + w_start_index_fn, + w_end_index_fn, + ) = compute_indices_adaptive_pooling( + start_index, end_index, h_in, w_in, h_out, w_out + ) + + def fn(idx, loader): + *prefix, bh, bw = idx + + h_start_index = h_start_index_fn(bh) + h_end_index = h_end_index_fn(bh) + + w_start_index = w_start_index_fn(bw) + w_end_index = w_end_index_fn(bw) + + maxval = None + maxindex = None + for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): + val = loader( + prefix, + [ih, iw], + [h_start_index, w_start_index], + [h_end_index, w_end_index], + ) + + index = ops.index_expr( + (h_start_index + ih) * w_in + w_start_index + iw, torch.int64 + ) + + if maxindex is None: + maxindex = index + else: + maxindex = ops.where(ops.gt(val, maxval), index, maxindex) + + if maxval is None: + maxval = val + else: + maxval = pooling_fn(val, maxval) + + return maxindex + + return fn + + +fallback_adaptive_avg_pool2d = fallback_handler( + aten._adaptive_avg_pool2d.default, add_to_fallback_set=False +) + + +@register_lowering(aten._adaptive_avg_pool2d) +def _adaptive_avg_pool2d(x, output_size): + assert isinstance(x, TensorBox) + assert len(output_size) == 2 + x.realize_hint() + + *batch, h_in, w_in = x.get_size() + + h_in = V.graph.sizevars.evaluate_static_shape(h_in) + w_in = V.graph.sizevars.evaluate_static_shape(w_in) + + h_out, w_out = output_size + + # no-op if the same input and output + if h_in == h_out and w_in == w_out: + return clone(x) + + if h_out == 0 or w_out == 0: + o_size = [*batch, h_out, w_out] + return empty(o_size, dtype=x.get_dtype(), device=x.get_device()) + if h_in % h_out == 0 and w_in % w_out == 0: + kernel_size = [h_in // h_out, w_in // w_out] + return avg_pool2d(x, kernel_size) + + h_kernel_max = ceildiv((h_in + h_out - 1), h_out) + w_kernel_max = ceildiv((w_in + w_out - 1), w_out) + + new_size = list(batch) + [h_out, w_out] + dtype = x.get_dtype() + + window_size = h_kernel_max * w_kernel_max + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 
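+        # h_kernel_max / w_kernel_max bound the widest adaptive window, i.e.
+        # the largest end_index(i) - start_index(i) produced below. Rough
+        # worked example (sizes assumed): h_in=10, h_out=3 gives
+        # ceildiv(10 + 3 - 1, 3) = 4, and the windows are [0, 4), [3, 7),
+        # [6, 10) -- each at most 4 elements wide.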
+ return fallback_adaptive_avg_pool2d(x, output_size) + + def start_index(index, out_dim, inp_dim): + return FloorDiv((index * inp_dim), out_dim) + + def end_index(index, out_dim, inp_dim): + return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) + + fn_sum = _adaptive_pooling_fn( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[h_in, w_in], + out_sizes=[h_out, w_out], + pooling_fn=ops.add, + ) + + ones_loader = pad_adaptive_loader(ones_like(x)) + + def fn(idx): + return ops.truediv( + fn_sum(idx, pad_adaptive_loader(x)), fn_sum(idx, ones_loader) + ) + + rv = Pointwise.create( + device=x.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + # TODO: should we force these to be realized? + return rv + + +fallback_adaptive_max_pool2d = fallback_handler( + aten.adaptive_max_pool2d.default, add_to_fallback_set=False +) + + +@register_lowering(aten.adaptive_max_pool2d) +def adaptive_max_pool2d(x, output_size): + assert isinstance(x, TensorBox) + assert len(output_size) == 2 + x.realize_hint() + + *batch, h_in, w_in = x.get_size() + + h_in = V.graph.sizevars.evaluate_static_shape(h_in) + w_in = V.graph.sizevars.evaluate_static_shape(w_in) + + h_out, w_out = output_size + + if h_out == 0 or w_out == 0: + o_size = [*batch, h_out, w_out] + return empty(o_size, dtype=x.get_dtype(), device=x.get_device()), empty( + o_size, dtype=torch.int64, device=x.get_device() + ) + if h_in % h_out == 0 and w_in % w_out == 0: + kernel_size = [h_in // h_out, w_in // w_out] + if should_fallback_max_pool2d_with_indices(kernel_size, dilation=[1, 1]): + return max_pool2d_with_indices(x, kernel_size) # type: ignore[name-defined] # noqa: F821 + else: + v, offsets = _low_memory_max_pool2d_with_offsets( + x, + kernel_size, + stride=kernel_size, + padding=[0, 0], + dilation=[1, 1], + ceil_mode=False, + ) + indices = _low_memory_max_pool2d_offsets_to_indices( + offsets, kernel_size[1], w_in, kernel_size, padding=[0, 0] + ) + return v, indices + + h_kernel_max = ceildiv((h_in + h_out - 1), h_out) + w_kernel_max = ceildiv((w_in + w_out - 1), w_out) + + new_size = list(batch) + [h_out, w_out] + dtype = x.get_dtype() + + window_size = h_kernel_max * w_kernel_max + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 
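+        # The index output encodes each argmax as a flat h*w position,
+        # (h_start + ih) * w_in + (w_start + iw), matching the layout of
+        # eager adaptive_max_pool2d's return_indices. Rough example (sizes
+        # assumed): pooling a 4x4 input to 2x2, a max at row 1, col 1 of the
+        # top-left window is reported as 1 * 4 + 1 = 5.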
+ return fallback_adaptive_max_pool2d(x, output_size) + + def start_index(index, out_dim, inp_dim): + return FloorDiv((index * inp_dim), out_dim) + + def end_index(index, out_dim, inp_dim): + return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) + + inner_func_max_val = _adaptive_pooling_fn( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[h_in, w_in], + out_sizes=[h_out, w_out], + pooling_fn=ops.maximum, + ) + + inner_func_max_idx = _adaptive_pooling_fn_with_idx( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[h_in, w_in], + out_sizes=[h_out, w_out], + pooling_fn=ops.maximum, + ) + + def inner_fn_max_val(idx): + return inner_func_max_val(idx, pad_adaptive_loader(x, float("-inf"))) + + def inner_fn_max_idx(idx): + return inner_func_max_idx(idx, pad_adaptive_loader(x, float("-inf"))) + + rv = Pointwise.create( + device=x.get_device(), + dtype=dtype, + inner_fn=inner_fn_max_val, + ranges=new_size, + ) + ri = Pointwise.create( + device=x.get_device(), + dtype=torch.int64, + inner_fn=inner_fn_max_idx, + ranges=new_size, + ) + return rv, ri + + +fallback_fractional_max_pool2d = fallback_handler( + aten.fractional_max_pool2d.default, add_to_fallback_set=False +) + + +def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim): + out_sz = out_sz[dim] + in_sz = in_sz[dim] + kernel_sz = kernel_sz[dim] + alpha = IntTrueDiv(in_sz - kernel_sz, out_sz - 1) + samples_loader = samples.make_loader() + + def load(prefix, i): + sample = samples_loader([*prefix, dim]) + i_expr = ops.index_expr(i, samples.get_dtype()) + alpha_expr = ops.index_expr(alpha, samples.get_dtype()) + seq_i = ops.floor((i_expr + sample) * alpha_expr) - ops.floor( + sample * alpha_expr + ) + seq_i = ops.to_dtype(seq_i, torch.int64) + + mask = ops.lt( + i_expr, + ops.index_expr(out_sz - 1, torch.int64), + ) + return ops.where(mask, seq_i, ops.index_expr(in_sz - kernel_sz, torch.int64)) + + return load + + +@register_lowering(aten.fractional_max_pool2d) +def fractional_max_pool2d(x, kernel_size, output_size, random_samples): + x.realize_hint() + *batch, inp_h, inp_w = x.get_size() + kernel_h, kernel_w = kernel_size + h_out, w_out = output_size + + if kernel_h * kernel_w >= 25: + return fallback_fractional_max_pool2d( + x, kernel_size, output_size, random_samples + ) + + gen_offsets_for_dim = functools.partial( + _fractional_pooling_offsets, + samples=random_samples, + in_sz=[inp_h, inp_w], + out_sz=output_size, + kernel_sz=kernel_size, + ) + + h_index_fn = gen_offsets_for_dim(dim=0) + w_index_fn = gen_offsets_for_dim(dim=1) + x_loader = x.make_loader() + + def fn(idx, return_index): + *prefix, bh, bw = idx + + h_start_index = ops.indirect_indexing(h_index_fn(prefix, bh), inp_h) + w_start_index = ops.indirect_indexing(w_index_fn(prefix, bw), inp_w) + + maxval = None + maxindex = None + for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])): + val = x_loader([*prefix, h_start_index + ih, w_start_index + iw]) + if return_index: + index = ops.index_expr( + (h_start_index + ih) * inp_w + w_start_index + iw, torch.int64 + ) + if maxindex is None: + maxindex = index + else: + maxindex = ops.where( + ops.or_(ops.gt(val, maxval), ops.isnan(val)), index, maxindex + ) + if maxval is None: + maxval = val + else: + maxval = ops.maximum(val, maxval) + if return_index: + return maxindex + else: + return maxval + + new_size = list(batch) + [h_out, w_out] + rv = Pointwise.create( + 
device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=functools.partial(fn, return_index=False), + ranges=new_size, + ) + + ri = Pointwise.create( + device=x.get_device(), + dtype=torch.int64, + inner_fn=functools.partial(fn, return_index=True), + ranges=new_size, + ) + return rv, ri + + +@register_lowering(aten.upsample_nearest2d_backward.default) +def upsample_nearest2d_backward( + x, output_size=None, input_size=None, scales_h=None, scales_w=None +): + x.realize_hint() + + *batch, inp_h, inp_w = x.get_size() + inp_h = V.graph.sizevars.evaluate_static_shape(inp_h) + inp_w = V.graph.sizevars.evaluate_static_shape(inp_w) + + *batch, out_h, out_w = input_size + + if inp_h % out_h == 0 and inp_w % out_w == 0: + return avg_pool2d(x, [inp_h // out_h, inp_w // out_w], divisor_override=1) + + h_kernel_max = ceildiv(inp_h, out_h) + w_kernel_max = ceildiv(inp_w, out_w) + + def start_index(index, out_dim, inp_dim): + return CeilDiv(index * inp_dim, sympy.sympify(out_dim)) + + def end_index(index, out_dim, inp_dim): + return start_index((index + 1), out_dim, inp_dim) + + fn_sum = _adaptive_pooling_fn( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[inp_h, inp_w], + out_sizes=[out_h, out_w], + pooling_fn=ops.add, + ) + + def fn(idx): + return fn_sum(idx, pad_adaptive_loader(x)) + + rv = Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=fn, + ranges=list(input_size), + ) + + return rv + + +fallback_avg_pool2d = fallback_handler( + aten.avg_pool2d.default, add_to_fallback_set=False +) +fallback_avg_pool3d = fallback_handler( + aten.avg_pool3d.default, add_to_fallback_set=False +) + + +@register_lowering(aten.avg_pool2d, type_promotion_kind=None) +def avg_pool2d( + x, + kernel_size, + stride=(), + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + return _avg_poolnd( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + dim=2, + ) + + +@register_lowering(aten.avg_pool3d, type_promotion_kind=None) +def avg_pool3d( + x, + kernel_size, + stride=(), + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + return _avg_poolnd( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + dim=3, + ) + + +def _avg_poolnd( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + dim, +): + if not stride: + stride = kernel_size + if not padding: + padding = [0] * dim + kernel_size = pad_listlike(kernel_size, dim) + stride = pad_listlike(stride, dim) + padding = pad_listlike(padding, dim) + + assert isinstance(x, TensorBox) + assert len(kernel_size) == dim + assert len(stride) == dim + assert len(padding) == dim + assert len(x.get_size()) in (dim + 1, dim + 2) + + x.realize_hint() + batch = x.get_size()[:-dim] + h = x.get_size()[-dim:] + + h_out, ceil_modes = zip( + *[ + pooling_size(h[i], i, kernel_size, stride, padding, ceil_mode) + for i in range(dim) + ] + ) + + if any(padding) or any(ceil_modes): + x_loader = constant_boundary_condition(x, 0.0, dim=dim) + had_padding = True + else: + x_loader = x.make_loader() + had_padding = False + + new_size = list(batch) + list(h_out) + dtype = x.get_dtype() + + window_size = functools.reduce(operator.mul, kernel_size) + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 
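+        # window_size is the product of all kernel dims, so the same threshold
+        # trips much earlier for 3d pooling. Rough examples (values assumed):
+        # kernel (5, 5) -> 25 elements, stays on the inline path;
+        # kernel (3, 3, 3) -> 27 > 25, takes the ATen fallback selected below.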
+ if dim == 2: + fallback = fallback_avg_pool2d + elif dim == 3: + fallback = fallback_avg_pool3d + else: + raise ValueError(f"Unknown dim: {dim}") + + return fallback( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def fn_sum(idx, loader): + prefix = idx[:-dim] + b = idx[-dim:] + total = None + for ih in itertools.product(*[range(kernel_size[i]) for i in range(dim)]): + inp = [b[i] * stride[i] + ih[i] - padding[i] for i in range(dim)] + val = loader([*prefix, *inp]) + if total is None: + total = val + else: + total = ops.add(val, total) + return total + + if not had_padding or divisor_override: + if divisor_override: + scale = 1 / divisor_override + else: + scale = 1.0 / window_size + + def fn(idx): + return ops.mul(fn_sum(idx, x_loader), ops.constant(scale, dtype)) + + else: + + def fn(idx): + prefix = idx[:-dim] + bh = idx[-dim:] + + divide_factors = [] + for i in range(dim): + hstart = bh[i] * stride[i] - padding[i] + hend = sympy.Min(hstart + kernel_size[i], h[i] + padding[i]) + if not count_include_pad: + hstart = sympy.Max(hstart, 0) + hend = sympy.Min(hend, h[i]) + factor = ops.index_expr(hend - hstart, torch.int32) + divide_factors.append(factor) + divide_factor = functools.reduce(ops.mul, divide_factors) + return ops.truediv(fn_sum(idx, x_loader), divide_factor) + + rv = Pointwise.create( + device=x.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + # TODO(jansel): should we force these to be realized? + return rv + + +fallback_avg_pool2d_backward = fallback_handler( + aten.avg_pool2d_backward.default, add_to_fallback_set=False +) + + +@register_lowering(aten.avg_pool2d_backward, type_promotion_kind=None) +def avg_pool2d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override=None, +): + assert divisor_override is None or divisor_override != 0, "divisor must be not zero" + if not stride: + stride = kernel_size + if not padding: + padding = [0, 0] + + assert isinstance(grad_output, TensorBox) + assert isinstance(x, TensorBox) + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(x.get_size()) in (3, 4) + + grad_output.realize_hint() # we will read this many times, so make sure it is computed + + *batch, height, width = x.get_size() + + h_out, ceil_mode1 = pooling_size(height, 0, kernel_size, stride, padding, ceil_mode) + w_out, ceil_mode2 = pooling_size(width, 1, kernel_size, stride, padding, ceil_mode) + + grad_loader = grad_output.make_loader() + + had_padding = padding[0] or padding[1] or ceil_mode1 or ceil_mode2 + + *_, pooled_height, pooled_width = grad_output.get_size() + new_size = list(x.get_size()) + dtype = x.get_dtype() + + h_window_size = max( + max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1) + for h in range(kernel_size[0] * 2) + ) + w_window_size = max( + max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1) + for w in range(kernel_size[1] * 2) + ) + + window_size = h_window_size * w_window_size + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 
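+        # The inline path below gives each input element the sum of
+        # grad_output over every window that covered it, scaled by that
+        # window's divisor. Rough 1d-style example (values assumed):
+        # kernel=2, stride=1, no padding: y0 = (x0 + x1) / 2 and
+        # y1 = (x1 + x2) / 2, so dL/dx1 = (g0 + g1) / 2 and dL/dx0 = g0 / 2.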
+ return fallback_avg_pool2d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def compute_pool_size_without_padding(ph, pw): + """ + This computes the scaling factor that we will divide an element + by when `count_include_pad=False` + """ + stride_h = ops.constant(stride[0], torch.int32) + stride_w = ops.constant(stride[1], torch.int32) + pad_h = ops.constant(padding[0], torch.int32) + pad_w = ops.constant(padding[1], torch.int32) + kernel_h = ops.constant(kernel_size[0], torch.int32) + kernel_w = ops.constant(kernel_size[1], torch.int32) + hstart = ops.sub(ops.mul(ph, stride_h), pad_h) + wstart = ops.sub(ops.mul(pw, stride_w), pad_w) + hend = ops.minimum( + ops.add(hstart, kernel_h), + ops.add(ops.index_expr(height, torch.int32), pad_h), + ) + wend = ops.minimum( + ops.add(wstart, kernel_w), + ops.add(ops.index_expr(width, torch.int32), pad_w), + ) + hstart = ops.maximum(hstart, ops.constant(0, torch.int32)) + wstart = ops.maximum(wstart, ops.constant(0, torch.int32)) + hend = ops.minimum(hend, ops.index_expr(height, torch.int32)) + wend = ops.minimum(wend, ops.index_expr(width, torch.int32)) + divide_factor = ops.mul(ops.sub(hend, hstart), ops.sub(wend, wstart)) + return divide_factor + + def fn(idx): + *prefix, h, w = idx + h = h + padding[0] + w = w + padding[1] + phstart = ops.index_expr( + FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32 + ) + pwstart = ops.index_expr( + FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32 + ) + phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32) + pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32) + + phstart = ops.maximum(phstart, ops.constant(0, torch.int32)) + pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32)) + phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32)) + pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32)) + + gradient = None + for ph_ in range(h_window_size): + for pw_ in range(w_window_size): + ph = ops.add(phstart, ops.constant(ph_, torch.int32)) + pw = ops.add(pwstart, ops.constant(pw_, torch.int32)) + + if divisor_override is not None: + scale = divisor_override + elif count_include_pad or not had_padding: + scale = kernel_size[0] * kernel_size[1] + else: + scale = compute_pool_size_without_padding(ph, pw) + + part = ops.truediv( + grad_loader( + [ + *prefix, + ops.indirect_indexing( + ops.minimum( + ph, ops.sub(phend, ops.constant(1, torch.int32)) + ), + pooled_height, + check=False, + ), + ops.indirect_indexing( + ops.minimum( + pw, ops.sub(pwend, ops.constant(1, torch.int32)) + ), + pooled_width, + check=False, + ), + ] + ), + scale, + ) + + mask = ops.and_( + ops.lt(ph, phend), + ops.lt(pw, pwend), + ) + if gradient is None: + gradient = ops.where(mask, part, ops.constant(0.0, torch.float32)) + else: + gradient = ops.where(mask, ops.add(gradient, part), gradient) + assert gradient is not None + return gradient + + rv = Pointwise.create( + device=grad_output.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + return rv + + +fallback_avg_pool3d_backward = fallback_handler( + aten.avg_pool3d_backward.default, add_to_fallback_set=False +) + + +@register_lowering(aten.avg_pool3d_backward, type_promotion_kind=None) +def avg_pool3d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override=None, +): + assert divisor_override is None or divisor_override != 0, "divisor must be not 
zero" + if not stride: + stride = kernel_size + if not padding: + padding = [0, 0, 0] + + assert isinstance(grad_output, TensorBox) + assert isinstance(x, TensorBox) + assert len(kernel_size) == 3 + assert len(stride) == 3 + assert len(padding) == 3 + assert len(x.get_size()) in (4, 5) + + grad_output.realize_hint() + + *batch, depth, height, width = x.get_size() + + d_out, ceil_mode_d = pooling_size(depth, 0, kernel_size, stride, padding, ceil_mode) + h_out, ceil_mode_h = pooling_size( + height, 1, kernel_size, stride, padding, ceil_mode + ) + w_out, ceil_mode_w = pooling_size(width, 2, kernel_size, stride, padding, ceil_mode) + + grad_loader = grad_output.make_loader() + had_padding = any(padding) or ceil_mode_d or ceil_mode_h or ceil_mode_w + + *_, pooled_depth, pooled_height, pooled_width = grad_output.get_size() + new_size = list(x.get_size()) + dtype = x.get_dtype() + + d_window_size, h_window_size, w_window_size = ( + max( + max(d // stride[i] - max(0, (d - kernel_size[i]) // stride[i]), 1) + for d in range(kernel_size[i] * 2) + ) + for i in range(3) + ) + + window_size = d_window_size * h_window_size * w_window_size + if window_size > 125: + # Kernel size too big. Results in hard-to-optimize Triton code. + return fallback_avg_pool3d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def compute_pool_size_without_padding(pd, ph, pw): + stride_d, stride_h, stride_w = (ops.constant(s, torch.int32) for s in stride) + pad_d, pad_h, pad_w = (ops.constant(p, torch.int32) for p in padding) + kernel_d, kernel_h, kernel_w = ( + ops.constant(k, torch.int32) for k in kernel_size + ) + + dstart, hstart, wstart = ( + ops.sub(ops.mul(p, s), pad) + for p, s, pad in zip( + [pd, ph, pw], [stride_d, stride_h, stride_w], [pad_d, pad_h, pad_w] + ) + ) + dend, hend, wend = ( + ops.minimum( + ops.add(start, k), ops.add(ops.index_expr(dim, torch.int32), pad) + ) + for start, k, dim, pad in zip( + [dstart, hstart, wstart], + [kernel_d, kernel_h, kernel_w], + [depth, height, width], + [pad_d, pad_h, pad_w], + ) + ) + dstart, hstart, wstart = ( + ops.maximum(start, ops.constant(0, torch.int32)) + for start in [dstart, hstart, wstart] + ) + dend, hend, wend = ( + ops.minimum(end, ops.index_expr(dim, torch.int32)) + for end, dim in zip([dend, hend, wend], [depth, height, width]) + ) + divide_factor = ops.mul( + ops.mul(ops.sub(dend, dstart), ops.sub(hend, hstart)), ops.sub(wend, wstart) + ) + return divide_factor + + def fn(idx): + *prefix, d, h, w = idx + d, h, w = (v + pad for v, pad in zip([d, h, w], padding)) + + pdstart, phstart, pwstart = ( + ops.index_expr(FloorDiv(v - k + s, s), torch.int32) + for v, k, s in zip([d, h, w], kernel_size, stride) + ) + + pdend, phend, pwend = ( + ops.index_expr(FloorDiv(v, s) + 1, torch.int32) + for v, s in zip([d, h, w], stride) + ) + + pdstart, phstart, pwstart = ( + ops.maximum(pstart, ops.constant(0, torch.int32)) + for pstart in [pdstart, phstart, pwstart] + ) + pdend, phend, pwend = ( + ops.minimum(pend, ops.index_expr(pooled_dim, torch.int32)) + for pend, pooled_dim in zip( + [pdend, phend, pwend], [pooled_depth, pooled_height, pooled_width] + ) + ) + + gradient = None + # Iterate over the 3D region to accumulate gradients + for pd_ in range(d_window_size): + for ph_ in range(h_window_size): + for pw_ in range(w_window_size): + pd, ph, pw = ( + ops.add(pstart, ops.constant(p_, torch.int32)) + for pstart, p_ in zip( + [pdstart, phstart, pwstart], [pd_, ph_, pw_] + ) + ) + + if 
divisor_override is not None: + scale = divisor_override + elif count_include_pad or not had_padding: + scale = kernel_size[0] * kernel_size[1] * kernel_size[2] + else: + scale = compute_pool_size_without_padding(pd, ph, pw) + + part = ops.truediv( + grad_loader( + [ + *prefix, + ops.indirect_indexing( + ops.minimum( + pd, ops.sub(pdend, ops.constant(1, torch.int32)) + ), + pooled_depth, + check=False, + ), + ops.indirect_indexing( + ops.minimum( + ph, ops.sub(phend, ops.constant(1, torch.int32)) + ), + pooled_height, + check=False, + ), + ops.indirect_indexing( + ops.minimum( + pw, ops.sub(pwend, ops.constant(1, torch.int32)) + ), + pooled_width, + check=False, + ), + ] + ), + scale, + ) + + mask = ops.and_( + ops.and_(ops.lt(pd, pdend), ops.lt(ph, phend)), + ops.lt(pw, pwend), + ) + if gradient is None: + gradient = ops.where( + mask, part, ops.constant(0.0, torch.float32) + ) + else: + gradient = ops.where(mask, ops.add(gradient, part), gradient) + assert gradient is not None + return gradient + + rv = Pointwise.create( + device=grad_output.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + return rv + + +def _validate_reduction_axis(x, axis): + size = x.get_size() + if isinstance(axis, int): + axis = [axis] + elif not axis: + axis = range(len(size)) + if len(size) == 0: + assert tuple(axis) in [(), (0,), (-1,)], f"invalid axis: {axis}" + return [] + axis = list(axis) + for i in range(len(axis)): + if axis[i] < 0: + axis[i] += len(size) if len(size) else 1 + assert 0 <= axis[i] < len(size) or (len(size) == 0 and axis[i] == 0) + assert len(set(axis)) == len(axis), "reduction axis not unique" + return axis + + +def _make_reduction_inner(x, *, axis, keepdims, dtype, override_return_dtype): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = set(_validate_reduction_axis(x, axis)) + + kept_sizes = [] + kept_idx = [] + reduced_sizes = [] + reduced_idx = [] + for i in range(len(size)): + if i in axis: + reduced_idx.append(i) + reduced_sizes.append(size[i]) + else: + kept_idx.append(i) + kept_sizes.append(size[i]) + + def loader(index, reduction_index): + assert len(reduction_index) == len(reduced_idx) + if keepdims: + assert len(index) == len(size) + index = [index[i] for i in kept_idx] + assert len(index) == len(kept_idx) + new_index = [None] * (len(index) + len(reduction_index)) + for idx, var in itertools.chain( + zip(kept_idx, index), zip(reduced_idx, reduction_index) + ): + new_index[idx] = var + return inner_loader(new_index) + + if keepdims: + new_size = list(size) + for i in reduced_idx: + new_size[i] = sympy.Integer(1) + else: + new_size = kept_sizes + + inner_loader = x.make_loader() + return dict( + device=x.get_device(), + dst_dtype=override_return_dtype or x.get_dtype(), + src_dtype=x.get_dtype(), + inner_fn=loader, + ranges=new_size, + reduction_ranges=reduced_sizes, + ) + + +def make_reduction(reduction_type: str, override_return_dtype=None): + def inner(x, axis=None, keepdims=False, *, dtype=None): + kwargs = _make_reduction_inner( + x, + axis=axis, + keepdims=keepdims, + dtype=dtype, + override_return_dtype=override_return_dtype, + ) + result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs) + if isinstance( + result.data.data, Reduction + ): # Only realize if reduction isn't unrolled + result.realize() + return result + + return inner + + +def _make_scan_inner(x, *, axis, dtype): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = _validate_dim(x, axis) + + return dict( + 
device=x.get_device(), + dtypes=(x.get_dtype(),), + inner_fns=(x.make_loader(),), + size=x.get_size(), + axis=axis, + ) + + +@register_lowering(aten.mean) +def mean(x, axis=None, keepdim=False, *, dtype=None): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + # compute in higher-precision until end of mean lowering + output_dtype = x.get_dtype() + if output_dtype in (torch.float16, torch.bfloat16): + x = to_dtype(x, torch.float) + sum_result = sum_(x, axis, keepdim) + denom = sympy_product(size[i] for i in axis) + denom = ir.IndexingConstant(denom, x.get_dtype(), x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + return to_dtype(div(sum_result, denom), output_dtype) + + +def var_mean_sum_(x, axis, correction, keepdim, return_mean): + if correction is None: + correction = 1 + + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + x_mean = mean(x, axis, keepdim=True) + if return_mean: + x_mean.realize() + + diffs = square(sub(x, x_mean)) + sum_result = sum_(diffs, axis, keepdim) + + denom = sympy_product(size[i] for i in axis) + if correction: + denom = sympy.Max(denom - correction, 0) + denom = ir.IndexingConstant(denom, x.get_dtype(), x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + x_var = div(sum_result, denom) + if not return_mean: + return (x_var,) + + x_mean = x_mean if keepdim else squeeze(x_mean, axis) + return x_var, x_mean + + +def use_two_step_variance(x, axis, keepdim): + # Instead of unrolling welford, just unroll the simpler two-step var + axis = _validate_reduction_axis(x, axis) + kwargs = _make_reduction_inner( + x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None + ) + + ranges = kwargs["ranges"] + reduction_numel = sympy_product(kwargs["reduction_ranges"]) + return ( + isinstance(reduction_numel, sympy.Integer) + and int(reduction_numel) < config.unroll_reductions_threshold + and sympy_product(ranges) != 1 + ) + + +def var_mean_welford_(x, axis, *, correction, keepdim, return_mean): + if correction is None: + correction = 1 + + kwargs = _make_reduction_inner( + x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None + ) + loader = kwargs.pop("inner_fn") + kwargs.pop("dst_dtype") + kwargs.pop("src_dtype") + + mean, m2, _ = ir.WelfordReduction.create( + inner_fns=(loader,), + reduction_type="welford_reduce", + dtype=x.get_dtype(), + **kwargs, + ) + m2.realize() + + dtype = x.get_dtype() + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + rnumel = sympy_product(size[i] for i in axis) + + def get_constant_or_index_expr(x, dtype): + if isinstance(x, sympy.Expr) and not x.is_number: + return ops.to_dtype(ops.index_expr(x, torch.int64), dtype) + return ops.constant(x, dtype) + + def scale_fn(data): + c = get_constant_or_index_expr(correction, dtype) + N = get_constant_or_index_expr(rnumel, dtype) + zero = ops.constant(0, dtype) + return data / ops.maximum(zero, N - c) + + var = make_pointwise(scale_fn)(m2) + + if return_mean: + mean.realize() + return var, mean + return (var,) + + +def var_mean_helper_(x, *, axis, correction, keepdim, return_mean): + out_dtype = x.get_dtype() + compute_dtype = get_computation_dtype(out_dtype) + x = to_dtype(x, compute_dtype, copy=False) + kwargs = dict( + x=x, + axis=axis, + correction=correction, + keepdim=keepdim, + return_mean=return_mean, + ) + output = ( + var_mean_sum_(**kwargs) + if use_two_step_variance(x, axis=axis, keepdim=keepdim) + else 
var_mean_welford_(**kwargs) + ) + output = tuple(to_dtype(x, out_dtype, copy=False) for x in output) + return output[0] if not return_mean else output + + +@register_lowering([aten.var, prims.var]) +def var_(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=False + ) + + +@register_lowering(aten.var_mean) +def var_mean(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=True + ) + + +def pow_recursive(x, y, dtype): + if y < 0: + return pow_recursive(ops.reciprocal(x), -y, dtype) + if y == 0: + return ops.constant(1, dtype) + if y == 1: + return x + + result = pow_recursive(x, y // 2, dtype) + result = ops.mul(result, result) + if (y % 2) == 1: + result = ops.mul(result, x) + return result + + +@make_pointwise +def pow_native(a, b): + return ops.pow(a, b) + + +fallback_pow_tensor_tensor = fallback_handler( + aten.pow.Tensor_Tensor, add_to_fallback_set=False +) +fallback_pow_scalar = fallback_handler(aten.pow.Scalar, add_to_fallback_set=False) +fallback_pow_tensor_scalar = fallback_handler( + aten.pow.Tensor_Scalar, add_to_fallback_set=False +) + + +@register_lowering(aten.pow, broadcast=True) +def pow(a, b): + if isinstance(b, float) and b == int(b): + return pow(a, int(b)) + elif isinstance(b, float) and b == 0.5: + return sqrt(a) + elif isinstance(b, int) and b == 1: + return clone(a) + + # Type promotion ensures all tensor arguments have the same type + dtype = next(x.get_dtype() for x in (a, b) if isinstance(x, ir.TensorBox)) + is_integer_pow = is_integer_dtype(dtype) + + # Optimize away small fixed powers, or for integers avoid falling back to ATen + embed_exponent = isinstance(b, int) and ( + -32 < b < 32 or (is_integer_pow and b >= 0) + ) + if embed_exponent: + loader = a.make_loader() + + def fn(idx): + return pow_recursive(loader(idx), b, a.get_dtype()) + + return Pointwise.create( + device=a.get_device(), + dtype=a.get_dtype(), + inner_fn=fn, + ranges=a.get_size(), + ) + + if isinstance(a, Number): + if a == 1: + return full_like(b, 1) + if a == 2 and is_float_dtype(b.get_dtype()): + return exp2(b) + + if is_integer_pow: + # ops.pow doesn't work for integers + if isinstance(a, Number): + return fallback_pow_scalar(a, b) + elif isinstance(b, Number): + return fallback_pow_tensor_scalar(a, b) + else: + return fallback_pow_tensor_tensor(a, b) + + return pow_native(a, b) + + +def mutate_to(changed, val, unsafe_alias=False): + if isinstance(changed, TensorBox): + changed_data = changed.data + else: + changed_data = changed + if isinstance(val, TensorBox): + val = val.data + + if not isinstance(val, ir.StorageBox): + # introduce a copy to handle views + val = Pointwise.create( + device=changed.get_device(), + dtype=changed.get_dtype(), + inner_fn=val.make_loader(), + ranges=changed.get_size(), + ).data + assert isinstance(val, ir.StorageBox) + + if isinstance(changed_data, ir.StorageBox) and not ( + changed_data.is_input_buffer() + # In AOTI, module parameters and buffers are not lifted as graph inputs + or changed_data.is_module_buffer() + or isinstance(changed_data.data, ir.NopKernel) + ): + # Fast path, just swing the data pointer + val.realize() + changed_data.data = val.data + return changed + + ir.MutationLayoutSHOULDREMOVE.realize_into( + val, changed_data, unsafe_alias=unsafe_alias + ) + return changed + + +@register_lowering(aten.fill_) +def fill_(x, fill_value): + return mutate_to(x, 
full_like(x, fill_value)) + + +@register_lowering(aten.copy_, type_promotion_kind=None) +def copy_(dst, src, non_blocking=False): + if dst is src: + # dst.copy_(dst) can happen from the reinplacing pass + return dst + src = to_device(src, dst.get_device()) + src = to_dtype(src, dst.get_dtype()) + src = expand(src, dst.get_size()) + return mutate_to(dst, src) + + +@make_pointwise +def floordiv(a, b): + return ops.floordiv(a, b) + + +@make_pointwise +def truncdiv(a, b): + return ops.truncdiv(a, b) + + +@register_lowering(aten.div, broadcast=True) +def div_mode(a, b, rounding_mode=None): + both_integer = is_integer_type(a) and is_integer_type(b) + both_boolean = is_boolean_type(a) and is_boolean_type(b) + + # floordiv and truncdiv need special handling for integer tensors on Triton, + # see the discussion at https://github.com/openai/triton/issues/605 + if rounding_mode == "floor": + assert not both_boolean, "floordiv operands can not be boolean at the same time" + return floordiv(a, b) if both_integer else floor(div(a, b)) + if rounding_mode == "trunc": + assert not both_boolean, "truncdiv operands can not be boolean at the same time" + return truncdiv(a, b) if both_integer else trunc(div(a, b)) + return div(a, b) + + +@register_lowering([aten.mul], broadcast=True) +def mul(a, b): + both_bool = is_boolean_type(a) and is_boolean_type(b) + if both_bool: + return logical_and(a, b) + else: + fn = ops_wrapper(aten.mul.__name__) + return make_pointwise(fn)(a, b) + + +def get_constant_value(x: ir.IRNode) -> Optional[ir.Constant]: + """Try convert an arbitrary IR node into an ir.Constant value""" + + # First try unwrapping the IRNode to see if it is already an ir.Constant + # Optional step, but avoids unnecessary inner_fn evaluation. + if isinstance(x, ir.MutableBox): + return get_constant_value(x.data) + if isinstance(x, ir.BaseView): + return get_constant_value(x.unwrap_view()) + if isinstance(x, ir.Constant): + return x + + # If the unwrapped node is not an ir.Constant, try evaluating inner_fn + # to see if the returned value is from an `ops.constant` call + if not isinstance(x, ir.Loops): + return None + + handler = torch._inductor.ops_handler.ExtractConstantsHandler(x.get_device()) + with V.set_ops_handler(handler), patch.object( + ir.FlexibleLayout, "allow_indexing", True + ): + out = x.inner_fn(*x.inner_fn_args()) + + assert isinstance(out, torch._inductor.virtualized.OpsValue) + if isinstance(out.value, ir.Constant): + return out.value + return None + + +# NOTE: prims.div maps to a / b in C, so performs truncation division on +# integer inputs and true division for floating and complex inputs. 
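As a plain-Python aside (a standalone sketch, not Inductor IR), the three division behaviors that div_mode and the NOTE above distinguish give different answers only for negative, inexact quotients: true division keeps the fraction, "floor" rounds toward negative infinity, and "trunc" (C-style a / b on integers) rounds toward zero.

    import math

    a, b = -7, 2
    true_div = a / b               # -3.5: true division (float inputs)
    floor_div = math.floor(a / b)  # -4:   rounding_mode="floor", same as a // b
    trunc_div = math.trunc(a / b)  # -3:   rounding_mode="trunc", C-style integer division
    assert (floor_div, trunc_div) == (-4, -3)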
+@register_lowering([prims.div], broadcast=True) +def div_prim(a, b): + is_integral = all(is_boolean_type(x) or is_integer_type(x) for x in [a, b]) + + if is_integral: + return truncdiv(a, b) + + if (divisor := get_constant_value(b)) is not None: + # Replace divide by constant with multiply by reciprocal + if divisor.value == 0: + reciprocal = math.copysign(float("inf"), divisor.value) + else: + reciprocal = 1.0 / divisor.value + return mul(a, reciprocal) + + def fn(*args): + return ops.truediv(*args) + + return make_pointwise(fn)(a, b) + + +@register_lowering( + [aten.true_divide, aten.div.Tensor], + broadcast=True, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def div(a, b): + a, b = promote_constants( + (a, b), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + return div_prim(a, b) + + +@register_lowering([aten.fmod, prims.fmod], broadcast=True) +def fmod(a, b): + is_integral = is_boolean_type(a) or is_integer_type(a) + + if is_integral: + + def fn(a, b): + return ops.mod(a, b) + + else: + + def fn(a, b): + return ops.fmod(a, b) + + return make_pointwise(fn)(a, b) + + +@register_lowering(aten.rsqrt) +def rsqrt(x): + dtype = x.get_dtype() + if is_integer_dtype(dtype) or is_boolean_dtype(dtype): + x = to_dtype(x, torch.get_default_dtype()) + + def _rsqrt(x): + return ops.rsqrt(x) + + return make_pointwise(_rsqrt)(x) + + +@register_lowering([aten.sum, prims.sum]) +def sum_(x, axis=None, keepdims=False, *, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + fn = make_reduction("sum", override_return_dtype=dtype) + return fn(x, axis, keepdims, dtype=dtype) + + +fallback_cumsum = fallback_handler(aten.cumsum.default) +fallback_cumprod = fallback_handler(aten.cumprod.default) +fallback_logcumsumexp = fallback_handler(aten.logcumsumexp.default) +fallback_cummax = fallback_handler(aten.cummax.default) +fallback_cummin = fallback_handler(aten.cummin.default) + + +@register_lowering(aten.cumsum) +def cumsum(x, axis=None, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + if len(x.get_size()) == 0: + assert axis in [0, -1] + dtype = dtype or x.get_dtype() + return to_dtype(x, dtype, copy=True) + + def combine_fn(a_tuple, b_tuple): + (a,) = a_tuple + (b,) = b_tuple + return (ops.add(a, b),) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + (result,) = ir.Scan.create(**kwargs, combine_fn=combine_fn) + if result is None: + return fallback_cumsum(x, dim=axis, dtype=dtype) + return result + + +@register_lowering(aten.cumprod) +def cumprod(x, axis=None, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + if len(x.get_size()) == 0: + assert axis in [0, -1] + dtype = dtype or x.get_dtype() + return to_dtype(x, dtype, copy=True) + + def combine_fn(a_tuple, b_tuple): + (a,) = a_tuple + (b,) = b_tuple + return (ops.mul(a, b),) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + (result,) = ir.Scan.create(**kwargs, combine_fn=combine_fn) + if result is None: + return fallback_cumprod(x, dim=axis, dtype=dtype) + return result + + +@register_lowering(aten.logcumsumexp) +def logcumsumexp(x, dim): + def log_add_exp_helper(a_tuple, b_tuple): + (a,) = a_tuple + (b,) = b_tuple + min_v = ops.minimum(a, b) + max_v = ops.maximum(a, b) + mask = (min_v != max_v) | (~ops.isinf(min_v)) + return 
(ops.where(mask, ops.log1p(ops.exp(min_v - max_v)) + max_v, a),) + + dtype = x.get_dtype() + if len(x.get_size()) == 0: + assert dim in [0, -1] + return clone(x) + + kwargs = _make_scan_inner(x, axis=dim, dtype=dtype) + (result,) = ir.Scan.create(**kwargs, combine_fn=log_add_exp_helper) + if result is None: + return fallback_logcumsumexp(x, dim=dim) + return result + + +@register_lowering(aten.cummax, type_promotion_kind=None) +def cummax(x, axis=None): + if len(x.get_size()) == 0: + assert axis in [0, -1] + return clone(x), empty_like(x, dtype=torch.int64) + + dtype = x.get_dtype() + combine_fn = ir.get_reduction_combine_fn( + "argmax", dtype=dtype, arg_break_ties_left=False + ) + + min_value = ( + False + if dtype is torch.bool + else ( + torch.finfo(dtype).min + if dtype.is_floating_point + else torch.iinfo(dtype).min + ) + ) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + kwargs["dtypes"] = (dtype, torch.int64) + kwargs["inner_fns"] = (x.make_loader(), lambda _: "rindex") + values, indices = ir.Scan.create(**kwargs, combine_fn=combine_fn) # type: ignore[arg-type] # next PR + if values is None: + return fallback_cummax(x, dim=axis) + return values, indices + + +@register_lowering(aten.cummin, type_promotion_kind=None) +def cummin(x, axis=None): + if len(x.get_size()) == 0: + assert axis in [0, -1] + return clone(x), empty_like(x, dtype=torch.int64) + + dtype = x.get_dtype() + combine_fn = ir.get_reduction_combine_fn( + "argmin", dtype=dtype, arg_break_ties_left=False + ) + + max_value = ( + True + if dtype is torch.bool + else ( + torch.finfo(dtype).max + if dtype.is_floating_point + else torch.iinfo(dtype).max + ) + ) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + kwargs["dtypes"] = (dtype, torch.int64) + kwargs["inner_fns"] = (x.make_loader(), lambda _: "rindex") + values, indices = ir.Scan.create(**kwargs, combine_fn=combine_fn) # type: ignore[arg-type] # next PR + if values is None: + return fallback_cummin(x, dim=axis) + return values, indices + + +@register_lowering(aten.prod) +def prod(x, axis=None, keepdims=False, *, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + fn = make_reduction("prod", override_return_dtype=dtype) + return fn(x, axis, keepdims, dtype=dtype) + + +@register_lowering(aten.any) +def reduce_any(x, dim=None, keepdim=False): + x = to_dtype(x, torch.bool) + return make_reduction("any")(x, axis=dim, keepdims=keepdim) + + +@register_lowering(aten.max, type_promotion_kind=None) +def reduce_max(x, dim=None, keepdim=False): + if dim is not None: + return ( + reduce_amax(x, axis=dim, keepdims=keepdim), + reduce_argmax(x, axis=dim, keepdims=keepdim), + ) + + return reduce_amax(x, axis=None, keepdims=keepdim) + + +@register_lowering(aten.min, type_promotion_kind=None) +def reduce_min(x, dim=None, keepdim=False): + if dim is not None: + return ( + reduce_amin(x, axis=dim, keepdims=keepdim), + reduce_argmin(x, axis=dim, keepdims=keepdim), + ) + + return reduce_amin(x, axis=None, keepdims=keepdim) + + +register_lowering(prims.xor_sum)(make_reduction("xor_sum")) +reduce_amax = register_lowering(aten.amax)(make_reduction("max")) +reduce_amin = register_lowering(aten.amin)(make_reduction("min")) +reduce_argmax = register_lowering(aten.argmax)( + make_reduction("argmax", override_return_dtype=torch.int64) +) +reduce_argmin = register_lowering(aten.argmin)( + make_reduction("argmin", override_return_dtype=torch.int64) +) + +add = register_pointwise( + aten.add, 
allow_alpha=True, override_fn_when_input_bool="logical_or" +) + +sort_fallback = fallback_handler(aten.sort.stable, add_to_fallback_set=False) + + +@register_lowering(aten.sort.stable, type_promotion_kind=None) +def sort_stable(x, *, stable=None, dim=-1, descending=False): + if stable is None: + stable = False + + shape = x.get_size() + device = x.get_device() + dim = canonicalize_dim(len(shape), dim) + if len(shape) == 0: + return clone(x), _full(0, device, torch.int64, shape) + + dim_size = shape[dim] if len(shape) else 1 + if not V.graph.sizevars.statically_known_lt(dim_size, torch.iinfo(torch.int16).max): + return sort_fallback(x, stable=stable, dim=dim, descending=descending) + + indices = iota( + dim_size, start=0, step=1, dtype=torch.int16, device=device, requires_grad=False + ) + view_shape = [1] * len(shape) + if len(shape): + view_shape[dim] = dim_size + indices = view(indices, view_shape) + indices = expand(indices, shape) + + values, indices = ir.Sort.create( + device=device, + dtypes=(x.dtype, indices.dtype), + inner_fns=(x.make_loader(), indices.make_loader()), + size=shape, + axis=dim, + stable=stable, + descending=descending, + ) + if values is None: + return sort_fallback(x, stable=stable, dim=dim, descending=descending) + + assert indices is not None + return values, to_dtype(indices, torch.int64) + + +@register_lowering(aten.sort.default, type_promotion_kind=None) +def sort(x, dim=-1, descending=False): + return sort_stable(x, stable=False, dim=dim, descending=descending) + + +def register_pointwise_numeric(op, name=None, triton_fallback=None): + return register_pointwise( + op, + name=name, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + triton_fallback=triton_fallback, + ) + + +def register_pointwise_numeric_ldf64(op): + return register_pointwise( + op, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + use_libdevice_for_f64=True, + ) + + +exp = register_pointwise_numeric_ldf64(aten.exp) +exp2 = register_pointwise_numeric(aten.exp2) +expm1 = register_pointwise_numeric(aten.expm1) +relu = register_pointwise(aten.relu) +sigmoid = register_pointwise_numeric_ldf64(aten.sigmoid) +sqrt = register_pointwise_numeric_ldf64(aten.sqrt) +square = register_pointwise(aten.square) +sub = register_pointwise(aten.sub, allow_alpha=True) +register_pointwise_numeric_ldf64(aten.cos) +register_pointwise_numeric_ldf64(aten.sin) +abs = register_pointwise(aten.abs) +bitwise_and = register_pointwise(aten.bitwise_and) +bitwise_left_shift = register_pointwise(aten.bitwise_left_shift) +bitwise_not = register_pointwise( + aten.bitwise_not, override_fn_when_input_bool="logical_not" +) +bitwise_or = register_pointwise(aten.bitwise_or) +bitwise_right_shift = register_pointwise(aten.bitwise_right_shift) +bitwise_xor = register_pointwise(aten.bitwise_xor) +register_pointwise_numeric(aten.lgamma) +erf = register_pointwise_numeric(aten.erf) +register_lowering( + aten.special_erf, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT +)(erf) + +register_pointwise_numeric(aten.log1p) +register_pointwise_numeric(aten.tan) +register_pointwise_numeric(aten.tanh) +register_pointwise_numeric_ldf64(aten.log) +logical_and = register_pointwise( + aten.logical_and, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +logical_not = register_pointwise( + aten.logical_not, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +logical_or = register_pointwise( + aten.logical_or, + 
type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +logical_xor = register_pointwise( + aten.logical_xor, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +maximum = register_pointwise(aten.maximum) +minimum = register_pointwise(aten.minimum) +register_lowering(aten.clamp_min)(maximum) +register_lowering(aten.clamp_max)(minimum) +neg = register_pointwise(aten.neg) +abs = register_pointwise(aten.abs) +reciprocal = register_pointwise_numeric(aten.reciprocal) +register_pointwise(aten.remainder) +sign = register_pointwise(aten.sign, override_fn_when_input_bool="identity") +register_pointwise(aten.ceil) +register_pointwise(aten.signbit, override_return_dtype=torch.bool) + +register_lowering(aten._neg_view)(neg) + +register_pointwise(aten.le, override_return_dtype=torch.bool) +register_pointwise(aten.lt, override_return_dtype=torch.bool) +register_pointwise(aten.ge, override_return_dtype=torch.bool) +gt = register_pointwise(aten.gt, override_return_dtype=torch.bool) +register_pointwise(aten.eq, override_return_dtype=torch.bool) +register_pointwise(aten.ne, override_return_dtype=torch.bool) + +register_pointwise_numeric(aten.cosh) +register_pointwise_numeric(aten.sinh) +register_pointwise_numeric(aten.acos) +register_pointwise_numeric(aten.acosh) +register_pointwise_numeric(aten.asin) +register_pointwise_numeric(aten.asinh) +register_pointwise_numeric(aten.atan2) +register_pointwise_numeric(aten.atan) +register_pointwise_numeric(aten.atanh) +register_pointwise_numeric(aten.copysign) +register_pointwise_numeric(aten.erfc) +register_pointwise_numeric(aten.erfinv) +register_pointwise_numeric(aten.hypot) +register_pointwise_numeric(aten.log10) +register_pointwise_numeric(aten.log2) +register_pointwise_numeric(aten.nextafter) + +from .codegen.common import BackendFeature, pointwise_overrides_data + + +def _get_pointwise_overrides(ns, name): + data = pointwise_overrides_data[name] + op = getattr(ns, data.name, None) + if op is None: + return + + def make_triton_fallback(op): + if data.triton is None: + return fallback_handler(op) + + if isinstance(op, torch._ops.OpOverloadPacket): + for olname in op.overloads(): + ol = getattr(op, olname) + yield ol, data.type_promotion_kind, make_triton_fallback(ol) + else: + yield op, data.type_promotion_kind, make_triton_fallback(op) + + +for name in pointwise_overrides_data: + for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides( + aten, name + ): + register_pointwise( + op, + name=name, + type_promotion_kind=type_promotion_kind, + triton_fallback=triton_fallback, + ) + + for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides( + prims, name + ): + register_pointwise( + op, + name=name, + type_promotion_kind=type_promotion_kind, + triton_fallback=triton_fallback, + ) + + +foreach_add_list = register_foreach_pointwise( + aten._foreach_add.List, add, allow_alpha=True +) +foreach_add_scalar = register_foreach_pointwise( + aten._foreach_add.Scalar, add, allow_alpha=True +) +register_foreach_pointwise(aten._foreach_add.Tensor, add, allow_alpha=True) +foreach_mul_list = register_foreach_pointwise(aten._foreach_mul.List, mul) +foreach_mul_scalar = register_foreach_pointwise(aten._foreach_mul.Scalar, mul) +register_foreach_pointwise(aten._foreach_sub.List, sub) +register_foreach_pointwise(aten._foreach_sub.Scalar, sub) +register_foreach_pointwise(aten._foreach_neg.default, neg) +register_foreach_pointwise(aten._foreach_abs.default, abs) 
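As a rough eager-mode reference for the _foreach_* registrations in this stretch (an illustration only, not the lowering itself): each foreach op applies its base pointwise op independently to every tensor in its list arguments, which is what mapping e.g. aten._foreach_mul onto the mul lowering relies on.

    import torch

    xs = [torch.ones(2), torch.full((3,), 2.0)]
    ys = [torch.full((2,), 3.0), torch.full((3,), 4.0)]
    out = torch._foreach_mul(xs, ys)  # one call over the whole list of tensors
    assert all(torch.equal(o, a * b) for o, a, b in zip(out, xs, ys))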
+register_foreach_pointwise(aten._foreach_pow.Scalar, pow) +register_foreach_pointwise(aten._foreach_pow.ScalarAndTensor, pow) +foreach_div_list = register_foreach_pointwise(aten._foreach_div.List, div) +foreach_div_scalar = register_foreach_pointwise(aten._foreach_div.Scalar, div) +register_foreach_pointwise(aten._foreach_sqrt, sqrt) +register_foreach_pointwise(aten._foreach_maximum.List, maximum) +register_foreach_pointwise(aten._foreach_maximum.Scalar, maximum) +register_foreach_pointwise(aten._foreach_minimum.List, minimum) +register_foreach_pointwise(aten._foreach_minimum.Scalar, minimum) +register_foreach_pointwise(aten._foreach_clamp_min.List, maximum) +register_foreach_pointwise(aten._foreach_clamp_min.Scalar, maximum) +register_foreach_pointwise(aten._foreach_clamp_max.List, minimum) +register_foreach_pointwise(aten._foreach_clamp_max.Scalar, minimum) +register_foreach_pointwise(aten._foreach_reciprocal, reciprocal) +register_foreach_pointwise(aten._foreach_sign, sign) +register_foreach_pointwise(aten._foreach_copy, copy) + + +# these are only encountered as outputs of the graph +# reinplacing epilogue copies improves compile time +# by removing extra buffers sent to the scheduler. +def register_foreach_inplace(aten_op, outplace_aten_op, outplace_op): + inplaceable_foreach_ops[outplace_aten_op] = aten_op + inplace_foreach_ops.add(aten_op) + + def fn(*args, **kwargs): + results = outplace_op(*args, **kwargs) + mut_results = [] + for arg, result in zip(args[0], results): + mut_results.append(mutate_to(arg, result, unsafe_alias=True)) + + return mut_results + + _register_foreach_lowering(aten_op, fn) + + +register_foreach_inplace( + aten._foreach_add_.List, aten._foreach_add.List, foreach_add_list +) +register_foreach_inplace( + aten._foreach_add_.Scalar, aten._foreach_add.Scalar, foreach_add_scalar +) +register_foreach_inplace( + aten._foreach_mul_.List, aten._foreach_mul.List, foreach_mul_list +) +register_foreach_inplace( + aten._foreach_mul_.Scalar, aten._foreach_mul.Scalar, foreach_mul_scalar +) +register_foreach_inplace( + aten._foreach_div_.List, aten._foreach_div.List, foreach_div_list +) +register_foreach_inplace( + aten._foreach_div_.Scalar, aten._foreach_div.Scalar, foreach_div_scalar +) + + +def register_inplace(aten_op, outplace_op): + @register_lowering(aten_op, type_promotion_kind=None) + def fn(*args, **kwargs): + result = outplace_op(*args, **kwargs) + result = to_dtype(result, args[0].get_dtype()) + return mutate_to(args[0], result) + + return fn + + +register_inplace(aten.add_, add) +register_inplace(aten.bitwise_and_, bitwise_and) +register_inplace(aten.bitwise_left_shift_, bitwise_left_shift) +register_inplace(aten.bitwise_not_, bitwise_not) +register_inplace(aten.bitwise_or_, bitwise_or) +register_inplace(aten.bitwise_right_shift_, bitwise_right_shift) +register_inplace(aten.bitwise_xor_, bitwise_xor) +register_inplace(aten.mul_, mul) +register_inplace(aten.div_.Tensor, div) +register_inplace(aten.div_.Tensor_mode, div_mode) +register_inplace(aten.logical_and_, logical_and) +register_inplace(aten.logical_not_, logical_not) +register_inplace(aten.logical_or_, logical_or) +register_inplace(aten.logical_xor_, logical_xor) +register_inplace(aten.sub_, sub) +register_inplace(aten.relu_, relu) +register_inplace(aten.sigmoid_, sigmoid) + + +register_lowering(aten.__and__)(bitwise_and) +register_lowering(aten.__lshift__)(bitwise_left_shift) +register_lowering(aten.__or__)(bitwise_or) +register_lowering(aten.__rshift__)(bitwise_right_shift) 
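The register_inplace helper defined just above follows a single pattern: run the out-of-place lowering, cast the result back to the destination dtype, then write it into the mutated argument. A rough eager-mode analogue (a sketch only; inplace_from_outplace is an invented name and Tensor.copy_ stands in for mutate_to):

    import torch

    def inplace_from_outplace(outplace):
        def fn(dst, *args, **kwargs):
            result = outplace(dst, *args, **kwargs).to(dst.dtype)
            return dst.copy_(result)  # plays the role of mutate_to(dst, result)
        return fn

    add_ = inplace_from_outplace(torch.add)
    x = torch.tensor([1, 2, 3])
    add_(x, 10)
    assert torch.equal(x, torch.tensor([11, 12, 13]))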
+register_lowering(aten.__xor__)(bitwise_xor) + +register_inplace(aten.__iand__, aten.__and__) +register_inplace(aten.__ilshift__, aten.__lshift__) +register_inplace(aten.__ior__, aten.__or__) +register_inplace(aten.__irshift__, aten.__rshift__) +register_inplace(aten.__ixor__, aten.__xor__) + + +@register_lowering(aten.sym_constrain_range) +def sym_constrain_range(a, min=None, max=None): + return None + + +@register_lowering(aten.sym_size.int) +def sym_size(a, dim): + val = V.graph.current_node.meta["val"] + # Note [Can val be an int?] + # ~~~~~~~~~~~~~~~~~~~~~~~~~ + # In principle, someone could construct an FX graph where + # a call to size/stride has a val that is a plain int (not + # SymInt). However, we will maintain the invariant that + # this is not possible: if you are constructing an FX graph + # where there is a call to size/stride that returns an + # int, but you KNOW that int must always be a constant, + # then you do not need trace that call at all (and just + # constant propagate the integer as is.) + assert isinstance(val, torch.SymInt) + return val.node.expr + + +@register_lowering(aten.sym_stride.int) +def sym_stride(a, dim): + val = V.graph.current_node.meta["val"] + # See Note [Can val be an int?] + assert isinstance(val, torch.SymInt) + return val.node.expr + + +@register_lowering(aten.sym_numel) +def sym_numel(a): + return a.get_numel() + + +for method, func in magic_methods.items(): + register_lowering(method_to_operator(method))(func) + + +@register_lowering(aten._foobar) +def foobar(self, *args, **kwargs): + raise NotImplementedError("Helpful for debugging") + + +@register_lowering(torch.ops._inductor_test.realize) +def _realize(x): + x.realize() + return clone(x) + + +@register_lowering(torch.ops.inductor.resize_storage_bytes_) +def resize_storage_bytes_(variable, new_size): + variable.realize() + ir.ResizeStorageBytes(variable, new_size) + return variable + + +@register_lowering(torch.ops.aten.set_.source_Tensor) +def set__source_tensor(self, source_tensor): + self.realize() + source_tensor.realize() + return TensorBox.create(ir.SetSourceTensorKernel(self, source_tensor)) + + +if hasattr(torch.ops.fsdp, "set_"): + + @register_lowering(torch.ops.fsdp.set_.default) + def fsdp_set_(self, source_tensor): + self.realize() + source_tensor.realize() + ir.SetSourceTensorKernel(self, source_tensor) + + +@register_lowering(torch.ops.aten.resize) +def resize(x, size, *, memory_format=None): + assert isinstance(x, TensorBox) + assert isinstance(size, (list, tuple)) + + if memory_format is None: + memory_format = torch.contiguous_format + if memory_format == torch.preserve_format: + raise RuntimeError(f"unsupported memory format: {memory_format}") + + if memory_format == torch.channels_last: + assert len(size) == 4 + if memory_format == torch.channels_last_3d: + assert len(size) == 5 + + old_numel = x.get_numel() + dtype = x.get_dtype() + device = x.get_device() + + if isinstance(x.data, ir.BaseView): + x.data = x.data.unwrap_view() + + if ( + torch.are_deterministic_algorithms_enabled() + and torch.utils.deterministic.fill_uninitialized_memory # type: ignore[attr-defined] + ): + if is_float_dtype(dtype): + uninitalized_val = float("nan") + elif is_integer_dtype(dtype): + uninitalized_val = torch.iinfo(dtype).max + else: + uninitalized_val = True + else: + # using zero as that is what empty does + uninitalized_val = 0.0 + + if V.graph.sizevars.statically_known_equals(old_numel, 0): # type: ignore[arg-type] + return full(size, uninitalized_val, dtype=dtype, device=device) + + 
x_flat = as_strided( + x, + [ + old_numel, + ], + [ + 1, + ], + ) + flat_loader = x_flat.make_loader() + out_stride = ir.FlexibleLayout.stride_ordered_for_memory_format(size, memory_format) + out_indexer = ir.FixedLayout(device, dtype, size, out_stride).make_indexer() + + def inner_fn(idx): + flat_index = out_indexer(idx) + flat_index_expr = ops.index_expr(flat_index, torch.int64) + limit = ops.index_expr(old_numel, torch.int64) + mask = ops.lt(flat_index_expr, limit) + return ops.masked(mask, lambda: flat_loader([flat_index]), uninitalized_val) + + out = Pointwise.create( + device=device, dtype=dtype, inner_fn=inner_fn, ranges=list(size) + ) + return out + + +from torch._higher_order_ops.auto_functionalize import auto_functionalized + + +make_fallback(auto_functionalized) + + +@register_lowering(triton_kernel_wrapper_mutation) +def triton_kernel_wrap_(*, kernel_idx, constant_args_idx, grid, kwargs): + from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table + + constant_args = kernel_side_table.get_constant_args(constant_args_idx) + ir.UserDefinedTritonKernel( + kernel_idx=kernel_idx, + grid=grid, + kernel_args={**kwargs, **constant_args}, + ) + return {key: val for key, val in kwargs.items() if isinstance(val, TensorBox)} + + +@register_lowering(torch.ops.higher_order.cond) +def cond(pred, true_fn, false_fn, operands): + if is_triton(pred) or any(map(is_triton, operands)): + msg = "control flow operator: torch.cond." + if stack_trace := V.graph.current_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + result = ir.Conditional.create(pred, true_fn, false_fn, operands) + return list(map(TensorBox.create, result)) + + +@register_lowering(torch.ops.higher_order.while_loop) +def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs): + if any(map(is_triton, carried_inputs + additional_inputs)): + msg = "control flow operator: torch.while_loop." 
+ if stack_trace := V.graph.current_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + result = ir.WhileLoop.create(cond_fn, body_fn, carried_inputs, additional_inputs) + return list(map(TensorBox.create, result)) + + +@register_lowering(associative_scan_op, type_promotion_kind=None) +def associative_scan(combine_fn: ir.Subgraph, input, dim: int): + from .subgraph_lowering import InputDescriptor, lower_pointwise_subgraph + + subgraph_inputs = [ + InputDescriptor(dtype=x.get_dtype(), device=x.get_device()) + for x in itertools.chain(input, input) + ] + lowered_combine_fn = lower_pointwise_subgraph(combine_fn, subgraph_inputs) # type: ignore[var-annotated] + + def wrapped_combine_fn(lhs, rhs): + return lowered_combine_fn( + *pytree.tree_leaves(lhs), + *pytree.tree_leaves(rhs), + ) + + kwargs = _make_scan_inner(input[0], axis=dim, dtype=None) + kwargs["dtypes"] = tuple(x.get_dtype() for x in input) + kwargs["inner_fns"] = tuple(x.make_loader() for x in input) + result = ir.Scan.create( + combine_fn=wrapped_combine_fn, + can_fallback_to_aten=False, + **kwargs, + ) + if result[0] is None: + raise RuntimeError("Unable to generate code for associative_scan op") + return result + + +@register_lowering(torch.ops.prims._sink_tokens.default) +def _sink_tokens(tokens): + return None + + +@register_lowering(torch.ops.higher_order.with_effects) +def with_effects(token, op, *args, **kwargs): + result = ir.EffectfulKernel.create(op, *args, **kwargs) + + from torch._higher_order_ops.effects import get_effect_key + + effect_type = get_effect_key(op, args, kwargs) + assert effect_type is not None + effectful_kernel = V.graph.effectful_ops[effect_type] + + if result is None: + return (effectful_kernel,) + + result = pytree.tree_map_only(ir.MultiOutput, TensorBox.create, result) + if not isinstance(result, (list, tuple)): + return (effectful_kernel, result) + else: + return (effectful_kernel, *result) + + +try: + import torch.distributed._functional_collectives + + _c10d_functional = torch.ops._c10d_functional + + @register_lowering(_c10d_functional.all_reduce) + def _all_reduce(inp, reduce_op, group_name): + inp = clone(inp) + if config.reorder_for_compute_comm_overlap: + # The horizontal fusion of this clone often severely delays the + # scheduling of the all_reduce_ node. Horizontally fusing this + # clone can almost never out-perform scheduling the all_reduce_ + # earlier. Also in most cases, this clone is eliminated via + # in-place reuse. Therefore, we tell the scheduler to not fuse it. 
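+            # realize() below materializes the clone into its own named buffer, so the name returned by get_name() can be added to no_fuse_buffer_names.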
+ inp.realize() + V.graph.no_fuse_buffer_names.add(inp.get_name()) + ir._CollectiveKernel.create_inplace( + _c10d_functional.all_reduce_.default, inp, reduce_op, group_name + ) + return inp + + @register_lowering(_c10d_functional.all_reduce_) + def _all_reduce_(inp, reduce_op, group_name): + ir._CollectiveKernel.create_inplace( + _c10d_functional.all_reduce_.default, inp, reduce_op, group_name + ) + return inp + + @register_lowering(_c10d_functional.all_reduce_coalesced) + def _all_reduce_coalesced(inputs, reduce_op, group_name): + inputs = [clone(inp) for inp in inputs] + ir._CollectiveKernel.create_inplace( + _c10d_functional.all_reduce_coalesced_.default, + inputs, + reduce_op, + group_name, + ) + return inputs + + @register_lowering(_c10d_functional.all_reduce_coalesced_) + def _all_reduce_coalesced_(inputs, reduce_op, group_name): + ir._CollectiveKernel.create_inplace( + _c10d_functional.all_reduce_coalesced_.default, + inputs, + reduce_op, + group_name, + ) + return inputs + + @register_lowering(_c10d_functional.all_gather_into_tensor) + def _all_gather_into_tensor(inp, group_size, group_name): + return ir.TensorBox.create( + ir._CollectiveKernel.create_out_of_place( + _c10d_functional.all_gather_into_tensor.default, + inp, + group_size, + group_name, + ) + ) + + @register_lowering(_c10d_functional.all_gather_into_tensor_coalesced) + def _all_gather_into_tensor_coalesced(inputs, group_size, group_name): + return pytree.tree_map( + ir.TensorBox.create, + ir._CollectiveKernel.create_out_of_place( + _c10d_functional.all_gather_into_tensor_coalesced.default, + inputs, + group_size, + group_name, + ), + ) + + @register_lowering(_c10d_functional.all_gather_into_tensor_out) + def _all_gather_into_tensor_out(inp, group_size, group_name, *, out): + ir._CollectiveKernel.create_inplace( + _c10d_functional.all_gather_into_tensor_out.default, + inp, + group_size, + group_name, + out=out, + ) + return out + + @register_lowering(_c10d_functional.reduce_scatter_tensor) + def _reduce_scatter_tensor(inp, reduce_op, group_size, group_name): + return ir.TensorBox.create( + ir._CollectiveKernel.create_out_of_place( + _c10d_functional.reduce_scatter_tensor.default, + inp, + reduce_op, + group_size, + group_name, + ) + ) + + @register_lowering(_c10d_functional.reduce_scatter_tensor_coalesced) + def _reduce_scatter_tensor_coalesced(inputs, reduce_op, group_size, group_name): + return pytree.tree_map( + ir.TensorBox.create, + ir._CollectiveKernel.create_out_of_place( + _c10d_functional.reduce_scatter_tensor_coalesced.default, + inputs, + reduce_op, + group_size, + group_name, + ), + ) + + @register_lowering(_c10d_functional.all_to_all_single) + def _all_to_all_single(inp, output_split_sizes, input_split_sizes, group_name): + return ir.TensorBox.create( + ir._CollectiveKernel.create_out_of_place( + _c10d_functional.all_to_all_single.default, + inp, + output_split_sizes, + input_split_sizes, + group_name, + ) + ) + + @register_lowering(_c10d_functional.broadcast) + def _broadcast(inp, src, group_name): + inp = clone(inp) + ir._CollectiveKernel.create_inplace( + _c10d_functional.broadcast_.default, inp, src, group_name + ) + return inp + + @register_lowering(_c10d_functional.broadcast_) + def _broadcast_(inp, src, group_name): + ir._CollectiveKernel.create_inplace( + _c10d_functional.broadcast_.default, inp, src, group_name + ) + return inp + + @register_lowering(_c10d_functional.wait_tensor) + def _wait_tensor(inp): + ir._WaitKernel.create_wait(_c10d_functional.wait_tensor.default, inp) + return inp + + 
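The collective lowerings above share one shape: the functional (out-of-place) ops such as all_reduce and broadcast clone their input and then run the in-place collective kernel on the clone, while the underscore-suffixed variants mutate their input directly, and wait_tensor returns its input after inserting the synchronization. A toy sketch of that wrapping with invented names (fake_all_reduce_ is a stand-in for the real in-place collective, not the actual API):

    import torch

    def fake_all_reduce_(t):
        t.mul_(2)  # pretend the cross-rank reduction doubles every element
        return t

    def fake_all_reduce(t):
        # functional variant: clone first, then reuse the in-place kernel on the clone
        return fake_all_reduce_(t.clone())

    x = torch.ones(3)
    y = fake_all_reduce(x)
    assert torch.equal(x, torch.ones(3)) and torch.equal(y, torch.full((3,), 2.0))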
@register_lowering(torch.ops._dtensor.shard_dim_alltoall) + def _shard_dim_alltoall(inp, gather_dim, shard_dim, group_name): + return ir.TensorBox.create( + ir._CollectiveKernel.create_out_of_place( + torch.ops._dtensor.shard_dim_alltoall.default, + inp, + gather_dim, + shard_dim, + group_name, + ) + ) + +except (AttributeError, ImportError): + log.info( + "Inductor support for distributed collectives depends on building torch.distributed" + ) + +# populate lowerings defined in kernel/* +from . import kernel + + +import_submodule(kernel) + +from . import quantized_lowerings + + +quantized_lowerings.register_quantized_ops() +quantized_lowerings.register_woq_mm_ops() + +from . import mkldnn_lowerings + + +mkldnn_lowerings.register_onednn_fusion_ops() + +from . import jagged_lowerings + + +jagged_lowerings.register_jagged_ops() diff --git a/lib/python3.10/site-packages/torch/_inductor/metrics.py b/lib/python3.10/site-packages/torch/_inductor/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..fe77279800e3da5ae581cb684c340726b9b3827d --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/metrics.py @@ -0,0 +1,436 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import csv +import dataclasses +import inspect +import os +import re +from dataclasses import dataclass +from functools import lru_cache +from typing import Dict, List, Set, Tuple, TYPE_CHECKING + +from torch._inductor import config +from torch._inductor.utils import get_benchmark_name + + +# Prevent circular import +if TYPE_CHECKING: + from torch._inductor.scheduler import BaseSchedulerNode + +# counter for tracking how many kernels have been generated +generated_kernel_count = 0 +generated_cpp_vec_kernel_count = 0 +num_bytes_accessed = 0 +nodes_num_elem: List[ + Tuple[ + BaseSchedulerNode, + int, + ] +] = [] +node_runtimes: List[Tuple[BaseSchedulerNode, float]] = [] + +# counters for tracking fusions +ir_nodes_pre_fusion = 0 + +# counters for tracking to_dtype inserted +cpp_to_dtype_count = 0 + + +@dataclasses.dataclass +class CppOuterLoopFusedCount: + inner_kernel_number: int + local_buffer_number: int = 0 + + +# The length counts the number of outer loop fusions. +cpp_outer_loop_fused_inner_counts: List[CppOuterLoopFusedCount] = [] + +num_comprehensive_padding = 0 +num_matches_for_scatter_upon_const_tensor = 0 + +num_loop_reordering = 0 + + +# reset all counters +def reset(): + global generated_kernel_count + global generated_cpp_vec_kernel_count + global num_bytes_accessed, nodes_num_elem + global ir_nodes_pre_fusion + global cpp_to_dtype_count + global cpp_outer_loop_fused_inner_counts + global num_comprehensive_padding + global num_matches_for_scatter_upon_const_tensor + global num_loop_reordering + + generated_kernel_count = 0 + generated_cpp_vec_kernel_count = 0 + num_bytes_accessed = 0 + nodes_num_elem.clear() + node_runtimes.clear() + ir_nodes_pre_fusion = 0 + cpp_to_dtype_count = 0 + cpp_outer_loop_fused_inner_counts.clear() + num_comprehensive_padding = 0 + num_matches_for_scatter_upon_const_tensor = 0 + num_loop_reordering = 0 + + +@dataclass +class CachedMetricsDeltas: + """ + The subset of metrics we want update across cache hits, e.g., the + FxGraphCache. 
+ """ + + generated_kernel_count: int + generated_cpp_vec_kernel_count: int + ir_nodes_pre_fusion: int + cpp_to_dtype_count: int + num_bytes_accessed: int + num_matches_for_scatter_upon_const_tensor: int + + +def get_metric_fields(): + return [field.name for field in dataclasses.fields(CachedMetricsDeltas)] + + +class CachedMetricsHelper: + """ + A helper class to help calculate and apply counter deltas for those + metrics we want to save with cache entries (e.g., FxGraphCache) and + apply on a cache hit. + """ + + def __init__(self) -> None: + self.cached_metrics = {} + for metric in get_metric_fields(): + self.cached_metrics[metric] = globals()[metric] + + def get_deltas(self) -> CachedMetricsDeltas: + delta_metrics = {} + for metric in get_metric_fields(): + delta_metrics[metric] = globals()[metric] - self.cached_metrics[metric] + + return CachedMetricsDeltas(**delta_metrics) + + @staticmethod + def apply_deltas(delta: CachedMetricsDeltas): + for metric in get_metric_fields(): + globals()[metric] += getattr(delta, metric) + + +REGISTERED_METRIC_TABLES: Dict[str, MetricTable] = {} + + +@dataclass +class MetricTable: + table_name: str + column_names: List[str] + + num_rows_added: int = 0 + + def add_row(self, row_fn): + if self.table_name not in enabled_metric_tables(): + return + + row_dict = row_fn() + assert len(self.column_names) == len( + row_dict + ), f"{len(self.column_names)} v.s. {len(row_dict)}" + assert set(self.column_names) == set( + row_dict.keys() + ), f"{set(self.column_names)} v.s. {set(row_dict.keys())}" + + row = [ + get_benchmark_name(), + ] + row += [row_dict[column_name] for column_name in self.column_names] + self._write_row(row) + + def output_filename(self): + return f"metric_table_{self.table_name}.csv" + + def write_header(self): + filename = self.output_filename() + with open(filename, "w") as fd: + writer = csv.writer(fd, lineterminator="\n") + writer.writerow(["model_name"] + self.column_names) + + def _write_row(self, row): + filename = self.output_filename() + if self.num_rows_added == 0 and not os.path.exists(filename): + self.write_header() + + self.num_rows_added += 1 + + for idx, orig_val in enumerate(row): + if isinstance(orig_val, float): + new_val = f"{orig_val:.6f}" + elif orig_val is None: + new_val = "" + else: + new_val = orig_val + row[idx] = new_val + + with open(filename, "a") as fd: + writer = csv.writer(fd, lineterminator="\n") + writer.writerow(row) + + @staticmethod + def register_table(name, column_names): + table = MetricTable(name, column_names) + REGISTERED_METRIC_TABLES[name] = table + + +MetricTable.register_table( + "slow_fusion", + [ + "kernel1_path", + "kernel1_latency", + "kernel2_path", + "kernel2_latency", + "fused_kernel_path", + "fused_kernel_latency", + "slow_down_ratio", + ], +) + +# track the fusion statistics for each graph +MetricTable.register_table( + "graph_stats", + [ + "graph_id", + "num_nodes_before_fusion", + "num_nodes_after_fusion", + ], +) + +# track the perf difference between persistent reduction and non-persistent +# reductions +MetricTable.register_table( + "persistent_red_perf", + [ + "kernel1_name", + "kernel2_name", + "kernel1_latency", + "kernel2_latency", + "size_hints", + "reduction_hint", + "speedup", + ], +) + +# Log the fusion failures due to indexing mismatch +MetricTable.register_table( + "fusion_failure_due_to_indexing_mismatch", + [ + "pre_grad_graph_id", + "post_grad_graph_id", + "node1_name", + "node2_name", + "node1_debug_str", + "node2_debug_str", + "common_buffer_names", + 
"failure_reason", + ], +) + +# Log metadata for pointwise/reduction kernels. E.g., model name, kernel path, numel, rnumel, reduction hint +MetricTable.register_table( + "kernel_metadata", + [ + "kernel_name", + "kernel_path", + "kernel_category", # pointwise/reduction/foreach etc. + "size_hints", + "reduction_hint", + "line_of_code", + "num_load", + "num_store", + "num_for_loop", + "num_atomic_add", + "num_args", + # xyz numel can be different to size_hints since size_hints are rounded + # up to the nearest power of 2. + # Inductor kernel will burn in the xyz numel in kernel code for static + # shape kernels. + # Logging them will be helpful to find unaligned shape for reduction + "xnumel", + "ynumel", + "rnumel", + "kernel_args_num_gb", + ], +) + + +def _parse_kernel_fn_code(kernel_module_code): + """ + The kernel_module_code is the python module that contains kernel function code. + kernel function is the proper triton kernel function annotated with + @triton.jit + """ + from .codecache import PyCodeCache + from .wrapper_benchmark import get_triton_kernel + + mod = PyCodeCache.load(kernel_module_code) + kernel = get_triton_kernel(mod) + # kernel is a CachingAutotune; kernel.fn is the JITFunction; + # kernel.fn.fn is the function being decorate by triton.jit + return inspect.getsource(kernel.fn.fn) + + +def _parse_kernel_line_of_code(proper_kernel_fn_code): + """ + Return the line of code for the kernel excluding the decorators. + """ + return len(proper_kernel_fn_code.splitlines()) + + +def _parse_size_hints(kernel_module_code, kernel_category): + if kernel_category == "foreach": + # foreach kernel does not have size_hints + return None + m = re.search(r"size_hints=(\[[0-9, ]*\]),", kernel_module_code) + assert m, "size_hints missing!" + return m.group(1) + + +def _parse_reduction_hint(kernel_category, kernel_module_code): + if kernel_category not in ("reduction", "persistent_reduction"): + return None + m = re.search(r"reduction_hint=ReductionHint\.(\w*),", kernel_module_code) + assert m, "reduction_hint not found in kernel source code!" + return m.group(1) + + +def _count_pattern(proper_kernel_fn_code, pattern): + return proper_kernel_fn_code.count(pattern) + + +def _count_args(proper_kernel_fn_code): + def_line = proper_kernel_fn_code.splitlines()[0] + assert def_line.startswith("def ") + start_idx = def_line.index("(") + end_idx = def_line.index("):") + decl_csv = def_line[start_idx + 1 : end_idx] + comps = decl_csv.split(",") + return len(comps) + + +def _parse_proper_kernel_fn_code(kernel_fn_code): + """ + Skip decorators. + """ + start_pos = kernel_fn_code.index("def ") + return kernel_fn_code[start_pos:] + + +def _parse_numel(proper_kernel_fn_code, numel_arg_name): + m = re.search(f"{numel_arg_name} = ([\\d]+)", proper_kernel_fn_code) + if m: + return int(m.group(1)) + else: + return None + + +def _parse_kernel_args_num_gb(kernel_fn_code, kernel_category): + """ + inductor meta looks like: + inductor_meta={... 'mutated_arg_names': [], 'no_x_dim': False, 'kernel_num_gb': 2.0}, + """ + m = re.search(r".kernel_num_gb.:\s*([0-9.]+)", kernel_fn_code) + if m: + return float(m.group(1)) + else: + """ + There are a few cases that kernel_num_gdb field can be missing: + 1. the field will be missing if config.benchmark_kernel and + config.profile_bandwidth are false + 2. even if config.benchmark_kernel or config.profile_bandwidth is true. 
+ foreach kernel does not have kernel_num_gb field in the metadata + """ + return None + + +def log_kernel_metadata(kernel_name, kernel_path, kernel_module_code): + """ + An utility to log kernel metadata. We may parse metadata from kernel source code here. + + It's fine to parse the generated kernel code here since the logging is + disabled by default. It would hurt compilation time. + """ + from .wrapper_benchmark import get_kernel_category_by_source_code + + kernel_category = get_kernel_category_by_source_code(kernel_module_code) + reduction_hint = _parse_reduction_hint(kernel_category, kernel_module_code) + size_hints = _parse_size_hints(kernel_module_code, kernel_category) + kernel_fn_code = _parse_kernel_fn_code(kernel_module_code) + + proper_kernel_fn_code = _parse_proper_kernel_fn_code(kernel_fn_code) + + # the line of code excluding the decortors + kernel_line_of_code = _parse_kernel_line_of_code(proper_kernel_fn_code) + + get_metric_table("kernel_metadata").add_row( + lambda: { + "kernel_name": kernel_name, + "kernel_path": kernel_path, + "kernel_category": kernel_category, + "size_hints": size_hints, + "reduction_hint": reduction_hint, + "line_of_code": kernel_line_of_code, + "num_load": _count_pattern(proper_kernel_fn_code, "tl.load"), + "num_store": _count_pattern(proper_kernel_fn_code, "tl.store"), + "num_for_loop": _count_pattern(proper_kernel_fn_code, "for "), + "num_atomic_add": _count_pattern(proper_kernel_fn_code, "tl.atomic_add"), + "num_args": _count_args(proper_kernel_fn_code), + "xnumel": _parse_numel(proper_kernel_fn_code, "xnumel"), + "ynumel": _parse_numel(proper_kernel_fn_code, "ynumel"), + "rnumel": _parse_numel(proper_kernel_fn_code, "rnumel"), + "kernel_args_num_gb": _parse_kernel_args_num_gb( + kernel_fn_code, kernel_category + ), + } + ) + + +def purge_old_log_files(): + """ + Purge the old log file at the beginning when the benchmark script runs. + Should do it in the parent process rather than the child processes running + each individual model. 
+ """ + for name, table in REGISTERED_METRIC_TABLES.items(): + if name in enabled_metric_tables(): + filename = table.output_filename() + if os.path.exists(filename): + os.unlink(filename) + + table.write_header() + + +@lru_cache +def enabled_metric_tables() -> Set[str]: + config_str = config.enabled_metric_tables + + enabled = set() + for name in config_str.split(","): + name = name.strip() + if not name: + continue + assert ( + name in REGISTERED_METRIC_TABLES + ), f"Metric table name {name} is not registered" + enabled.add(name) + return enabled + + +def is_metric_table_enabled(name): + return name in enabled_metric_tables() + + +def get_metric_table(name): + assert name in REGISTERED_METRIC_TABLES, f"Metric table {name} is not defined" + return REGISTERED_METRIC_TABLES[name] diff --git a/lib/python3.10/site-packages/torch/_inductor/mkldnn_ir.py b/lib/python3.10/site-packages/torch/_inductor/mkldnn_ir.py new file mode 100644 index 0000000000000000000000000000000000000000..12634c632b80a9e6c9dd52ea9b8ddf2b190b1943 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/mkldnn_ir.py @@ -0,0 +1,1881 @@ +# mypy: allow-untyped-defs +from typing import Any, List, Optional + +import sympy + +import torch +from torch._prims_common import make_channels_last_strides_for +from torch.utils._ordered_set import OrderedSet + +from .ir import ( + ExternKernelAlloc, + FixedLayout, + FlexibleLayout, + ir_node_to_tensor, + IRNode, + is_contiguous_storage_and_layout, + Layout, + may_convert_to_optional, + MultiOutput, + MultiOutputLayout, + MutationOutput, + NoneLayout, + TensorBox, +) +from .utils import convert_shape_to_inductor, pad_listlike +from .virtualized import V + + +def _prepare_convolution_fusion_create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding: List[int], + stride: List[int], + dilation: List[int], + groups: int, + transposed: bool = False, + output_padding: Optional[List[int]] = None, +): + """ + This function is a helper function to prepare inputs, layout and constant args + for convolution post-op fusion's create function, including deciding the output + layout (channels first or channels last), realizing inputs and make them etc. The + function only supports the CPU device since conv post-op fusion kernel is only + supported on CPU right now. + """ + + # Port from aten/src/ATen/native/ConvUtils.h: _conv_input_size + def _conv_input_size( + output_size, weight_size, padding, output_padding, stride, dilation, groups + ): + assert len(output_size) == len(weight_size), "Expect input dim == weight dim" + dim = len(output_size) + assert dim > 2, "Expect input dim > 2" + + BATCH_DIM = 0 + WEIGHT_INPUT_CHANNELS_DIM = 1 + input_size = [] + input_size.append(output_size[BATCH_DIM]) + input_size.append(weight_size[WEIGHT_INPUT_CHANNELS_DIM] * groups) + for d in range(2, dim): + kernel = (weight_size[d] - 1) * dilation[d - 2] + 1 + input_size_d = ( + (output_size[d] - 1) * stride[d - 2] + - (padding[d - 2] * 2) + + kernel + + output_padding[d - 2] + ) + input_size.append(input_size_d) + return list(map(int, input_size)) + + # The size of prepacked_weight is the prepacked weight size of deconv: + # Groups > 1: [g*o, i/g, ...] + # Groups == 1: [o, i, ...] + # Returns original weight size in [i, o, ...] 
+ def _original_deconv_weight_size( + prepacked_weight, + groups, + ): + prepacked_weight_size = prepacked_weight.size() + dim = len(prepacked_weight_size) + assert dim > 2, "Expect weight dim > 2" + if groups > 1: + weight_size = [] + weight_size.append(prepacked_weight_size[1] * groups) + weight_size.append(prepacked_weight_size[0] / groups) + for d in range(2, dim): + weight_size.append(prepacked_weight_size[d]) + else: + weight_size = prepacked_weight.transpose(0, 1).size() + return weight_size + + x.realize() + weight.realize() + if bias is not None: + bias.realize() + with V.graph.fake_mode: + # TODO cleaned up the fake_tensor trace as Linear implementation + x_fake = ir_node_to_tensor(x, guard_shape=True) + weight_fake = ir_node_to_tensor(weight, guard_shape=True) + dims = len(x_fake.size()) - 2 + assert 0 < len(padding) <= dims + assert 0 < len(dilation) <= dims + assert 0 < len(stride) <= dims + padding = pad_listlike(padding, dims) + dilation = pad_listlike(dilation, dims) + stride = pad_listlike(stride, dims) + if output_padding is None: + output_padding = pad_listlike([0], dims) + else: + assert 0 < len(output_padding) <= dims + output_padding = pad_listlike(output_padding, dims) + assert isinstance(groups, (int, sympy.core.numbers.Integer)) + if transposed: + # When transposed, the size of the prepacked oneDNN weight is different + # from the PyTorch weight. We're not able to run aten conv with such + # size. We infer the output size from the input params here: + weight_size = _original_deconv_weight_size(weight_fake, groups) + input_size = x_fake.size() + output_size = _conv_input_size( + input_size, + weight_size, + padding, + output_padding, + stride, + dilation, + groups, + ) + else: + bias_fake = ( + ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias + ) + output = torch.ops.aten.convolution( + x_fake, + weight_fake, + bias_fake, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + output_size = output.size() + + req_stride_order = [0] + list(reversed(range(1, len(stride) + 1))) + req_stride_order = [len(req_stride_order)] + req_stride_order + + x = cls.require_stride_order(x, req_stride_order) + + # We won't do weight prepack for Conv if dynamic_shapes. + # In static shape cases, since weight is prepacked, we'll always force output to be channels last in the Conv kernel. + # In dynamic shape cases, for input with channels = 1, like tensor of size (s0, 1, 28, 28) and stride (784, 784, 28, 1), + # x = cls.require_stride_order(x, req_stride_order) where req_stride_order is in the channels last order + # won't change the stride of this tensor since stride for dimensions of size 1 is ignored. While in Conv kernel, + # this tensor is considered as channels first and the output will be in contiguous format. + # To align the behavior of the Conv kernel, we set the output_stride in such case to be contiguous instead of channels last. 
+ dynamic_shapes = not all(isinstance(i, int) for i in (output_size)) + if dynamic_shapes and is_contiguous_storage_and_layout(x): + output_stride = FlexibleLayout.contiguous_strides(output_size) + else: + output_stride = make_channels_last_strides_for(output_size) + + assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" + inputs = [x, weight] + + kernel_layout = FixedLayout( + x.get_device(), + x.get_dtype(), + convert_shape_to_inductor(output_size), + convert_shape_to_inductor(output_stride), + ) + constant_args = [padding, stride, dilation, groups] + if transposed: + constant_args.insert(1, output_padding) + + if bias is not None: + inputs.append(bias) + else: + constant_args.insert(0, bias) + return inputs, constant_args, kernel_layout, req_stride_order + + +def _prepare_linear_fusion_create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", +): + """ + This function is a helper function to prepare inputs, layout and constant args + for linear post-op fusion's create function. The function only supports the CPU device + since linear post-op fusion kernel is only supported on CPU right now. + """ + x.realize() + weight.realize() + if bias is not None: + bias.realize() + + *m, _ = x.get_size() + # The weight has been transposed during the qlinear weight prepack process. + # https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/ + # aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp#L291 + _, oc = weight.get_size() + output_size = list(m) + [oc] + req_stride_order = list(reversed(range(len(x.get_size())))) + + x = cls.require_stride_order(x, req_stride_order) + assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" + inputs = [x, weight] + + output_stride = FlexibleLayout.contiguous_strides(output_size) + kernel_layout = FixedLayout( + x.get_device(), + x.get_dtype(), + output_size, + output_stride, + ) + constant_args: List[Any] = [] + + if bias is not None: + inputs.append(bias) + else: + constant_args.insert(0, bias) + return inputs, constant_args, kernel_layout, req_stride_order + + +class ConvolutionUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._convolution_pointwise.default, + ) + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& weight_t, + const std::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view attr, + torch::List> scalars, + std::optional algorithm)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + op_overload=self.op_overload, + raw_args=[*self.inputs, *self.constant_args], + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, + attr, + scalars: Optional[List[Any]], + algorithm, + ): + (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + constant_args = constant_args + [ + attr, + may_convert_to_optional(scalars), + algorithm, 
+ ] + return ConvolutionUnary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class ConvolutionBinary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + cpp_constant_args=(), + ) -> None: + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._convolution_pointwise.binary, + ) + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& other_t, + const at::Tensor& weight_t, + const std::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view binary_attr, + std::optional alpha, + std::optional unary_attr, + torch::List> unary_scalars, + std::optional unary_algorithm)""" + self.cpp_constant_args = cpp_constant_args + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + self.op_overload, + [*self.inputs, *self.constant_args], + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + other: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, + binary_attr: str, + binary_alpha: Optional[float], + unary_attr: Optional[str], + unary_scalars: Optional[List[Any]], + unary_algorithm: Optional[str], + ): + ( + inputs, + constant_args, + kernel_layout, + req_stride_order, + ) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + other = cls.require_stride_order(other, req_stride_order) + inputs.insert(1, other) + constant_args = constant_args + [ + binary_attr, + binary_alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + return ConvolutionBinary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class ConvolutionBinaryInplace(ExternKernelAlloc): + def __init__( + self, + kernel_layout, + inputs, + constant_args=(), + ) -> None: + # Due to constrain of op.call, other (Tensor&) should be at input[0] + reordered_inputs = [inputs[1], inputs[0]] + inputs[2:] + + super().__init__( + kernel_layout, + reordered_inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._convolution_pointwise_.binary, + ) + # TODO: op.call: input[0] should be at::Tensor& + self.cpp_op_schema = """ + at::Tensor&( + at::Tensor& other_t, + const at::Tensor& input_t, + const at::Tensor& weight_t, + const std::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view binary_attr, + std::optional alpha, + std::optional unary_attr, + torch::List> unary_scalars, + std::optional unary_algorithm)""" + + self.mutation_outputs = [ + MutationOutput(NoneLayout(inputs[0].get_device()), inputs[0], self), + MutationOutput(NoneLayout(inputs[1].get_device()), inputs[1], self), + ] + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + self.op_overload, + [*self.inputs, *self.constant_args], + ) + + def get_unbacked_symbol_defs(self) -> 
OrderedSet[sympy.Symbol]: + return OrderedSet() + + @classmethod + def create( + cls, + x: "TensorBox", + other: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, + binary_attr: str, + binary_alpha: Optional[float], + unary_attr: Optional[str], + unary_scalars: Optional[List[Any]], + unary_algorithm: Optional[str], + ): + ( + inputs, + constant_args, + _, + req_stride_order, + ) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + other = cls.require_stride_order(other, req_stride_order) + inputs.insert(1, other) + constant_args = constant_args + [ + binary_attr, + binary_alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + packed = ConvolutionBinaryInplace( + kernel_layout=NoneLayout(inputs[1].get_device()), # type: ignore[arg-type] + inputs=inputs, + constant_args=constant_args, + ) + # This op mutates in place which means that the result is not the + # target but rather the input that is being mutated + # init reorders the inputs, so inputs[1] becomes packed.inputs[0] + return packed.inputs[0] + + +class ConvolutionTransposeUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._convolution_transpose_pointwise.default, + ) + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& weight_t, + const std::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef output_padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view attr, + torch::List> scalars, + std::optional algorithm)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + ) + + @classmethod + def create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + output_padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups_: int, + attr, + scalars: Optional[List[Any]], + algorithm, + ): + transposed = True + ( + inputs, + constant_args, + kernel_layout, + _, + ) = _prepare_convolution_fusion_create( + cls, + x, + weight, + bias, + padding_, + stride_, + dilation_, + groups_, + transposed, + output_padding_, + ) + constant_args = constant_args + [ + attr, + may_convert_to_optional(scalars), + algorithm, + ] + return ConvolutionTransposeUnary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class QConvPointWisePT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + """ + if bias is not None + - inputs = [x, w, b, weight_scale, weight_zp] + - const_args is: [stride, padding, dilation, groups, x_scale, x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, weight_scale, weight_zp] + - const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = len(inputs) == 5 + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.onednn.qconv2d_pointwise.default, + ) + self.cpp_op_schema = """ + at::Tensor( + at::Tensor act, + double 
act_scale, + int64_t act_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + c10::string_view attr, + torch::List> scalars, + std::optional algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + # The raw_args setup can be skipped if there is a C shim implementation + args = [x.codegen_reference() for x in self.inputs] + const_arg_names = [ + "x_scale", + "x_zero_point", + "stride", + "padding", + "dilation", + "groups", + "output_scale", + "output_zero_point", + "output_dtype", + "attr", + "scalars", + "algorithm", + ] + if not self.has_bias: + const_arg_names.insert(2, "bias") + const_args = list(self.codegen_const_args(const_arg_names)) + + x = args[0] + x_raw = self.inputs[0] + packed_weight = args[1] + packed_weight_raw = self.inputs[1] + bias = args[2] if self.has_bias else const_args[2] + bias_raw = self.inputs[2] if self.has_bias else self.constant_args[2] + w_scale, w_zp = args[-2], args[-1] + w_scale_raw, w_zp_raw = self.inputs[-2], self.inputs[-1] + ( + x_scale, + x_zp, + ) = const_args[:2] + ( + x_scale_raw, + x_zp_raw, + ) = self.constant_args[:2] + ( + stride, + padding, + dilation, + groups, + o_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-10:] + ( + stride_raw, + padding_raw, + dilation_raw, + groups_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) = self.constant_args[-10:] + codegen_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + o_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) + raw_args = ( + x_raw, + x_scale_raw, + x_zp_raw, + packed_weight_raw, + w_scale_raw, + w_zp_raw, + bias_raw, + stride_raw, + padding_raw, + dilation_raw, + groups_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + codegen_args, + self.cpp_op_schema, + self.cpp_kernel_key, + op_overload=self.op_overload, + raw_args=raw_args, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale: float, + x_zero_point: int, + qw: "TensorBox", # qw + w_scale: "TensorBox", + w_zero_point: "TensorBox", + bias: "TensorBox", + stride: List[int], + padding: List[int], + dilation: List[int], + groups: int, + output_scale: float, + output_zero_point: int, + output_dtype, + attr, + scalars, + algorithm, + ): + transposed = False + output_padding = None + (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create( + cls, + qx, + qw, + bias, + padding, + stride, + dilation, + groups, + transposed, + output_padding, + ) + # swap padding and stride to align with functional conv arg order + if bias is None: + constant_args[1], constant_args[2] = constant_args[2], constant_args[1] + else: + constant_args[0], constant_args[1] = constant_args[1], constant_args[0] + + w_scale.realize() + w_zero_point.realize() + inputs = inputs + [w_scale, w_zero_point] + + constant_args = ( + [ + x_scale, + x_zero_point, + ] + + constant_args + + [ 
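+                # output qparams and the unary post-op description are appended last,
+                # matching the tail of the qconv2d_pointwise schema above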
+ output_scale, + output_zero_point, + output_dtype, + attr, + may_convert_to_optional(scalars), + algorithm, + ] + ) + + assert output_dtype is not None + if output_dtype in [torch.float32, torch.bfloat16]: + # in _prepare_convolution_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set output_dtype is not None, the output buf should be output_dtype instead of uint8. + kernel_layout.dtype = output_dtype + + return QConvPointWisePT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class QConvPointWiseBinaryPT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + """ + Needs input/weight/output qparams + if bias is not None + - inputs = [x, w, b, accum, w_scale, w_zp] + - const_args = [stride, padding, dilation, groups, x_scale, x_zp, accum_scale, accum_zp, o_scale, o_zp, + fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, accum, w_scale, w_zp] + - const_args = const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, accum_scale, + accum_zp, o_scale, o_zp, fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = len(inputs) == 6 + self.idx_for_inplace_sum = 3 if self.has_bias else 2 + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.onednn.qconv2d_pointwise.binary, + ) + self.cpp_op_schema = """ + at::Tensor( + at::Tensor act, + double act_scale, + int64_t act_zero_point, + at::Tensor accum, + double accum_scale, + int64_t accum_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + c10::string_view binary_attr, + std::optional alpha, + std::optional attr, + torch::List> scalars, + std::optional algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + # The raw_args setup can be skipped if there is a C shim implementation + args = [x.codegen_reference() for x in self.inputs] + const_arg_names = [ + "x_scale", + "x_zero_point", + "accum_scale", + "accum_zero_point", + "stride", + "padding", + "dilation", + "groups", + "output_scale", + "output_zero_point", + "output_dtype", + "binary_attr", + "alpha", + "unary_attr", + "unary_scalars", + "unary_algorithm", + ] + if not self.has_bias: + const_arg_names.insert(4, "bias") + const_args = list(self.codegen_const_args(const_arg_names)) + + x = args[0] + x_raw = self.inputs[0] + packed_weight = args[1] + packed_weight_raw = self.inputs[1] + bias = args[2] if self.has_bias else const_args[4] + bias_raw = self.inputs[2] if self.has_bias else self.constant_args[4] + accum, w_scale, w_zp = args[-3], args[-2], args[-1] + accum_raw, w_scale_raw, w_zp_raw = ( + self.inputs[-3], + self.inputs[-2], + self.inputs[-1], + ) + ( + x_scale, + x_zp, + accum_scale, + accum_zp, + ) = const_args[:4] + ( + x_scale_raw, + x_zp_raw, + accum_scale_raw, + accum_zp_raw, + ) = self.constant_args[:4] + ( + stride, + padding, + dilation, + groups, + o_scale, + o_zp, + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-12:] + ( + stride_raw, + padding_raw, + dilation_raw, + groups_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + binary_attr_raw, + alpha_raw, + unary_attr_raw, + unary_scalars_raw, + 
unary_algorithm_raw, + ) = self.constant_args[-12:] + conv_args = ( + x, + x_scale, + x_zp, + accum, + accum_scale, + accum_zp, + packed_weight, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + o_scale, + o_zp, + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + raw_args = ( + x_raw, + x_scale_raw, + x_zp_raw, + accum_raw, + accum_scale_raw, + accum_zp_raw, + packed_weight_raw, + w_scale_raw, + w_zp_raw, + bias_raw, + stride_raw, + padding_raw, + dilation_raw, + groups_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + binary_attr_raw, + alpha_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + conv_args, + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + op_overload=self.op_overload, + raw_args=raw_args, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + def get_mutation_names(self): + return [self.inputs[self.idx_for_inplace_sum].get_name()] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale, + x_zero_point, + qaccum: "TensorBox", + accum_scale, + accum_zero_point, + qw: "TensorBox", # packed_weight + w_scale, + w_zero_point, + bias: "TensorBox", + stride: List[int], + padding: List[int], + dilation: List[int], + groups: int, + output_scale: "TensorBox", + output_zero_point: "TensorBox", + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + transposed = False + output_padding = None + ( + inputs, + constant_args, + kernel_layout, + req_stride_order, + ) = _prepare_convolution_fusion_create( + cls, + qx, + qw, + bias, + padding, + stride, + dilation, + groups, + transposed, + output_padding, + ) + + qaccum = cls.require_stride_order(qaccum, req_stride_order) + inputs.append(qaccum) + + # swap padding and stride to align with functional conv arg order + if bias is None: + constant_args[1], constant_args[2] = constant_args[2], constant_args[1] + else: + constant_args[0], constant_args[1] = constant_args[1], constant_args[0] + + w_scale.realize() + w_zero_point.realize() + inputs = inputs + [w_scale, w_zero_point] + constant_args = ( + [ + x_scale, + x_zero_point, + accum_scale, + accum_zero_point, + ] + + constant_args + + [ + output_scale, + output_zero_point, + output_dtype, + binary_attr, + alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + ) + + assert ( + binary_attr == "sum" + ), "For now, only post op sum is supported in QConvPointWiseBinaryPT2E." + + V.graph.mark_buffer_mutated(qaccum.get_name()) + packed = QConvPointWiseBinaryPT2E( + layout=NoneLayout(qaccum.get_device()), + inputs=inputs, + constant_args=constant_args, + ) + + # Return accum since it has been inplace changed. 
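+        # inputs were packed as [x, w, (bias), qaccum, w_scale, w_zp], so
+        # idx_for_inplace_sum points at the mutated qaccum buffer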
+ return packed.inputs[packed.idx_for_inplace_sum] + + +class MKLPackedLinear(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkl._mkl_linear.default, + ) + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& self, + const at::Tensor& mkl_weight_t, + const at::Tensor& origin_weight_t, + const std::optional& bias_opt, + const int64_t prepack_batch_size)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + ) + + @classmethod + def create(cls, x, packed_w, orig_w, B, batch_size): + x = cls.require_stride1(cls.realize_input(x)) + orig_w = cls.require_stride1(cls.realize_input(orig_w)) + *m, _ = x.get_size() + oc, _ = orig_w.get_size() + output_size = list(m) + [oc] + output_stride = FlexibleLayout.contiguous_strides(output_size) + inputs = [x, packed_w, orig_w] + constant_args = [batch_size] + if B is not None: + inputs += [B] + else: + constant_args.insert(0, None) + + return MKLPackedLinear( + layout=FixedLayout( + x.get_device(), x.get_dtype(), output_size, output_stride + ), + inputs=inputs, + constant_args=constant_args, + ) + + +class LinearUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._linear_pointwise.default, + ) + self.cpp_kernel_key = "linear_pointwise" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& weight_t, + const std::optional& bias_opt, + c10::string_view attr, + torch::List> scalars, + std::optional algorithm)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + ) + + @classmethod + def create(cls, x, w, B, attr, scalars, algorithm): + x = cls.require_contiguous(cls.realize_input(x)) + w = cls.require_contiguous(cls.realize_input(w)) + + *m, ic = x.get_size() + oc, ic = w.get_size() + inputs = [x, w] + constant_args = [attr, scalars if scalars else [-1], algorithm] + if B is not None: + B = cls.require_contiguous(cls.realize_input(B)) + inputs.append(B) + else: + constant_args.insert(0, None) + + return LinearUnary( + layout=FlexibleLayout( + device=x.get_device(), + dtype=x.get_dtype(), + size=list(m) + [oc], + ), + inputs=inputs, + constant_args=constant_args, + ) + + def apply_constraint(self): + pass + + +class LinearBinary(ExternKernelAlloc): + kernel = "torch.ops.mkldnn._linear_pointwise.binary" + + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._linear_pointwise.binary, + ) + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& other_t, + const at::Tensor& weight_t, + const std::optional& bias_opt, + c10::string_view attr) + """ + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + + 
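+    # Rough usage sketch (illustrative only): this node is normally built by the
+    # mkldnn lowering of torch.ops.mkldnn._linear_pointwise.binary, e.g.
+    #   LinearBinary.create(x, y, w, B=bias, attr="add")
+    # which packs inputs as [x, y, w, bias]; the bias slot moves into
+    # constant_args when bias is None.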
@classmethod + def create(cls, x, y, w, B, attr): + x = cls.require_contiguous(cls.realize_input(x)) + y = cls.require_contiguous(cls.realize_input(y)) + w = cls.require_contiguous(cls.realize_input(w)) + + *m, ic = x.get_size() + oc, ic = w.get_size() + + inputs = [x, y, w] + constant_args = [attr] + if B is not None: + B = cls.require_contiguous(cls.realize_input(B)) + inputs.append(B) + else: + constant_args.insert(0, B) + + return LinearBinary( + layout=FlexibleLayout( + device=x.get_device(), + dtype=x.get_dtype(), + size=list(m) + [oc], + ), + inputs=inputs, + constant_args=constant_args, + ) + + def apply_constraint(self): + pass + + +class QLinearPointwisePT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + has_bias=True, + x_scale_zp_are_tensors=False, + ) -> None: + """ + if bias is not None + - inputs = [x, w, b, weight_scale, weight_zp] + - const_args is: [x_scale, x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, weight_scale, weight_zp] + - const_args is: [bias, x_scale, x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = has_bias + self.x_scale_zp_are_tensors = x_scale_zp_are_tensors + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.onednn.qlinear_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.default, + ) + x_scale_type_str, x_zp_type_str = ( + ("at::Tensor", "at::Tensor") + if x_scale_zp_are_tensors + else ("double", "int64_t") + ) + self.cpp_op_schema = f""" + at::Tensor( + at::Tensor act, + {x_scale_type_str} act_scale, + {x_zp_type_str} act_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + std::optional bias, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + c10::string_view post_op_name, + torch::List> post_op_args, + c10::string_view post_op_algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + # The raw_args setup can be skipped if there is a C shim implementation + args = [x.codegen_reference() for x in self.inputs] + const_args = [] + const_args.extend(self.codegen_const_args()) + + x = args[0] + x_raw = self.inputs[0] + packed_weight = args[1] + packed_weight_raw = self.inputs[1] + bias = args[2] if self.has_bias else const_args[0] + bias_raw = self.inputs[2] if self.has_bias else self.constant_args[0] + w_scale, w_zp = args[-2], args[-1] + w_scale_raw, w_zp_raw = self.inputs[-2], self.inputs[-1] + if self.x_scale_zp_are_tensors: + assert len(args) >= 4 + x_scale, x_zp = args[-4], args[-3] + x_scale_raw, x_zp_raw = self.inputs[-4], self.inputs[-3] + ( + o_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-6:] + ( + o_scale_raw, + o_zp_raw, + output_dtype_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) = self.constant_args[-6:] + else: + assert len(const_args) >= 8 + ( + x_scale, + x_zp, + o_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-8:] + ( + x_scale_raw, + x_zp_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) = self.constant_args[-8:] + + codegen_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + o_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) + raw_args = ( + x_raw, + 
x_scale_raw, + x_zp_raw, + packed_weight_raw, + w_scale_raw, + w_zp_raw, + bias_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + codegen_args, + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + self.op_overload, + raw_args, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale: float, + x_zero_point: int, + qw: "TensorBox", # packed_weight + w_scale: "TensorBox", + w_zero_point: "TensorBox", + bias: "TensorBox", + output_scale: float, + output_zero_point: int, + output_dtype, + post_op_name, + post_op_args, + post_op_algorithm, + ): + (inputs, constant_args, kernel_layout, _) = _prepare_linear_fusion_create( + cls, + qx, + qw, + bias, + ) + + if isinstance(x_scale, TensorBox) and isinstance(x_zero_point, TensorBox): + x_scale.realize() + x_zero_point.realize() + inputs = inputs + [x_scale, x_zero_point] + x_scale_zp_are_tensors = True + else: + assert isinstance(x_scale, float) and isinstance(x_zero_point, int) + constant_args = constant_args + [x_scale, x_zero_point] + x_scale_zp_are_tensors = False + w_scale.realize() + w_zero_point.realize() + inputs = inputs + [w_scale, w_zero_point] + constant_args = constant_args + [ + output_scale, + output_zero_point, + output_dtype, + post_op_name, + may_convert_to_optional(post_op_args), + post_op_algorithm, + ] + + assert output_dtype is not None + if output_dtype in [torch.float32, torch.bfloat16]: + # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set fp32_output, the output buf should be dtype float32 instead of uint8. 
+ kernel_layout.dtype = output_dtype + + return QLinearPointwisePT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + x_scale_zp_are_tensors=x_scale_zp_are_tensors, + ) + + +class QLinearPointwiseBinaryPT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + has_bias=True, + x_scale_zp_are_tensors=False, + ) -> None: + """ + if bias is not None + - inputs = [x, w, b, weight_scale, weight_zp, x2] + - const_args is: [x_scale, x_zp, o_scale, o_zp, + fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, weight_scale, weight_zp, x2] + - const_args is: [bias, x_scale, x_zp, o_scale, o_zp, + fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = has_bias + self.x_scale_zp_are_tensors = x_scale_zp_are_tensors + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.onednn.qlinear_pointwise.binary_tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.binary, + ) + x_scale_type_str, x_zp_type_str = ( + ("at::Tensor", "at::Tensor") + if x_scale_zp_are_tensors + else ("double", "int64_t") + ) + self.cpp_op_schema = f""" + at::Tensor( + at::Tensor act, + {x_scale_type_str} act_scale, + {x_zp_type_str} act_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + std::optional other, + std::optional bias, + double inv_output_scale, + int64_t output_zero_point, + std::optional output_dtype, + double other_scale, + int64_t other_zero_point, + c10::string_view binary_post_op, + double binary_alpha, + c10::string_view unary_post_op, + torch::List> unary_post_op_args, + c10::string_view unary_post_op_algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + # The raw_args setup can be skipped if there is a C shim implementation + args = [x.codegen_reference() for x in self.inputs] + const_args = [] + const_args.extend(self.codegen_const_args()) + + x = args[0] + x_raw = self.inputs[0] + packed_weight = args[1] + packed_weight_raw = self.inputs[1] + bias = args[2] if self.has_bias else const_args[0] + bias_raw = self.inputs[2] if self.has_bias else self.constant_args[0] + w_scale, w_zp, other = args[-3], args[-2], args[-1] + w_scale_raw, w_zp_raw, other_raw = ( + self.inputs[-3], + self.inputs[-2], + self.inputs[-1], + ) + if self.x_scale_zp_are_tensors: + assert len(args) >= 5 + x_scale, x_zp = args[-5], args[-4] + x_scale_raw, x_zp_raw = self.inputs[-5], self.inputs[-4] + ( + o_scale, + o_zp, + output_dtype, + other_scale, + other_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-10:] + ( + o_scale_raw, + o_zp_raw, + output_dtype_raw, + other_scale_raw, + other_zp_raw, + binary_attr_raw, + alpha_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) = self.constant_args[-10:] + else: + assert len(const_args) >= 8 + ( + x_scale, + x_zp, + o_scale, + o_zp, + output_dtype, + other_scale, + other_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-12:] + ( + x_scale_raw, + x_zp_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + other_scale_raw, + other_zp_raw, + binary_attr_raw, + alpha_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) = self.constant_args[-12:] + + codegen_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + other, + bias, + o_scale, + o_zp, + 
output_dtype, + other_scale, + other_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + raw_args = ( + x_raw, + x_scale_raw, + x_zp_raw, + packed_weight_raw, + w_scale_raw, + w_zp_raw, + other_raw, + bias_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + other_scale_raw, + other_zp_raw, + binary_attr_raw, + alpha_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + codegen_args, + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + self.op_overload, + raw_args, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + def get_mutation_names(self): + binary_post_op = self.constant_args[-5] + if binary_post_op == "sum": + return [self.inputs[-1].get_name()] + else: + return [] + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale: float, + x_zero_point: int, + qw: "TensorBox", # packed_weight + w_scale: "TensorBox", + w_zero_point: "TensorBox", + other: "TensorBox", + bias: "TensorBox", + output_scale: float, + output_zero_point: int, + output_dtype, + other_scale, + other_zp, + binary_post_op, + binary_alpha, + unary_post_op, + unary_post_op_args, + unary_post_op_algorithm, + ): + ( + inputs, + constant_args, + kernel_layout, + req_stride_order, + ) = _prepare_linear_fusion_create( + cls, + qx, + qw, + bias, + ) + + if isinstance(x_scale, TensorBox) and isinstance(x_zero_point, TensorBox): + x_scale.realize() + x_zero_point.realize() + inputs = inputs + [x_scale, x_zero_point] + x_scale_zp_are_tensors = True + else: + assert isinstance(x_scale, float) and isinstance(x_zero_point, int) + constant_args = constant_args + [x_scale, x_zero_point] + x_scale_zp_are_tensors = False + w_scale.realize() + w_zero_point.realize() + inputs = inputs + [w_scale, w_zero_point] + if binary_post_op == "sum": + other = cls.require_stride_order(other, req_stride_order) + inputs.append(other) + constant_args = constant_args + [ + output_scale, + output_zero_point, + output_dtype, + other_scale, + other_zp, + binary_post_op, + binary_alpha, + unary_post_op, + may_convert_to_optional(unary_post_op_args), + unary_post_op_algorithm, + ] + + if binary_post_op == "sum": + V.graph.mark_buffer_mutated(other.get_name()) + packed = QLinearPointwiseBinaryPT2E( + layout=NoneLayout(other.get_device()), + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + x_scale_zp_are_tensors=x_scale_zp_are_tensors, + ) + # Return other since it has been inplace changed. + return packed.inputs[-1] + + assert output_dtype is not None + if output_dtype in [torch.float32, torch.bfloat16]: + # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set fp32_output, the output buf should be dtype float32 instead of uint8. 
+ kernel_layout.dtype = output_dtype + + return QLinearPointwiseBinaryPT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + x_scale_zp_are_tensors=x_scale_zp_are_tensors, + ) + + +class MkldnnRnnLayer(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.aten.mkldnn_rnn_layer.default, + ) + + @classmethod + def create( + cls, + x: "TensorBox", + w0: "TensorBox", + w1: "TensorBox", + w2: "TensorBox", + w3: "TensorBox", + hx: "TensorBox", + cx: "TensorBox", + reverse: bool, + batch_sizes: List[int], + mode: int, + hidden_size: int, + num_layers: int, + has_biases: bool, + bidirectional: bool, + batch_first: bool, + train: bool, + ): + x = cls.require_stride1(cls.realize_input(x)) + # If batch_first, x has been permuted in lstm before entering the mkldnn_rnn_layer. + # Make sure x is contiguous in batch_first case. + x.freeze_layout() + w0 = cls.require_stride1(cls.realize_input(w0)) + w1 = cls.require_stride1(cls.realize_input(w1)) + w2 = cls.require_stride1(cls.realize_input(w2)) + w3 = cls.require_stride1(cls.realize_input(w3)) + hx = cls.require_stride1(cls.realize_input(hx)) + hx.freeze_layout() + cx = cls.require_stride1(cls.realize_input(cx)) + cx.freeze_layout() + + input_size = x.get_size() + assert len(input_size) == 3, "Expect lstm input to be 3D" + # batch_first is handled in the lstm OP. When entering + # rnn_layer here, we'll always have batch_first = False + seq_length, mini_batch, input_size = input_size + output_shape = [seq_length, mini_batch, hidden_size] + + hy_shape = hx.get_size() + cy_shape = cx.get_size() + + res: List[IRNode] = [] + + inputs = [x, w0, w1, w2, w3, hx, cx] + constant_args = [ + reverse, + batch_sizes, + mode, + hidden_size, + num_layers, + has_biases, + bidirectional, + batch_first, + train, + ] + + packed = MkldnnRnnLayer( + MultiOutputLayout(x.get_device()), + inputs=inputs, + constant_args=constant_args, + ) + + def get_strides_of_lstm_output(output_shape, batch_first): + assert len(output_shape) == 3, "Expect output_shape to be 3D" + return FlexibleLayout.contiguous_strides(output_shape) + + output_sizes = [output_shape, hy_shape, cy_shape] + output_strides = [ + get_strides_of_lstm_output(output_shape, batch_first), + FlexibleLayout.contiguous_strides(hy_shape), + FlexibleLayout.contiguous_strides(cy_shape), + ] + output_ir = [ + MultiOutput( + FixedLayout( + x.get_device(), + x.get_dtype(), + output_size, + output_stride, + ), + packed, + [(tuple, i)], + ) + for i, (output_size, output_stride) in enumerate( + zip(output_sizes, output_strides) + ) + ] + + return output_ir diff --git a/lib/python3.10/site-packages/torch/_inductor/mkldnn_lowerings.py b/lib/python3.10/site-packages/torch/_inductor/mkldnn_lowerings.py new file mode 100644 index 0000000000000000000000000000000000000000..a9cc0bc8299ebfce4bc67628ac21af27a847a936 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/mkldnn_lowerings.py @@ -0,0 +1,1087 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import functools +from typing import List, Optional + +import torch +import torch.utils._pytree as pytree +from torch._inductor.kernel.mm_common import mm_args + +from . 
import ir +from .codegen.cpp_gemm_template import CppPackedGemmTemplate +from .codegen.cpp_utils import create_epilogue_with_attr +from .ir import TensorBox +from .lowering import ( + add, + add_needs_realized_inputs, + aten, + permute, + register_lowering, + to_dtype, + view, +) +from .select_algorithm import ( + autotune_select_algorithm, + ChoiceCaller, + ExternKernelChoice, +) +from .utils import use_aten_gemm_kernels, use_cpp_packed_gemm_template, use_max_autotune +from .virtualized import ops, V + + +def register_onednn_fusion_ops(): + if torch._C._has_mkldnn: + from . import mkldnn_ir + + aten_mkldnn_linear_unary = ExternKernelChoice( + torch.ops.mkldnn._linear_pointwise, + "mkldnn::_linear_pointwise", + has_out_variant=False, + kernel_creator=mkldnn_ir.LinearUnary.create, + ) + aten_mkldnn_linear_binary = ExternKernelChoice( + torch.ops.mkldnn._linear_pointwise.binary, + "mkldnn::_linear_pointwise", + has_out_variant=False, + kernel_creator=mkldnn_ir.LinearBinary.create, + ) + aten_mkldnn_qlinear_unary = ExternKernelChoice( + torch.ops.onednn.qlinear_pointwise, + "onednn::qlinear_pointwise", + has_out_variant=False, + kernel_creator=mkldnn_ir.QLinearPointwisePT2E.create, + ) + aten_mkldnn_qlinear_binary = ExternKernelChoice( + torch.ops.onednn.qlinear_pointwise.binary, + "onednn::qlinear_pointwise", + has_out_variant=False, + kernel_creator=mkldnn_ir.QLinearPointwiseBinaryPT2E.create, + ) + cpu_needs_realized_inputs = [ + torch.ops.mkldnn._convolution_pointwise, + torch.ops.mkldnn._convolution_pointwise_, + torch.ops.mkldnn._convolution_transpose_pointwise, + torch.ops.mkldnn._linear_pointwise, + aten.mkldnn_rnn_layer.default, + torch.ops.onednn.qconv2d_pointwise, + ] + + @register_lowering(torch.ops.mkldnn._convolution_pointwise) + def convolution_unary( + x: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ): + return TensorBox.create( + mkldnn_ir.ConvolutionUnary.create( + x, + weight, + bias, + padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ) + ) + + @register_lowering(torch.ops.mkldnn._convolution_pointwise.binary) + def convolution_binary( + x: TensorBox, + other: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + return TensorBox.create( + mkldnn_ir.ConvolutionBinary.create( + x, + other, + weight, + bias, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + ) + + @register_lowering(torch.ops.mkldnn._convolution_pointwise_.binary) + def convolution_binary_inplace( + x: TensorBox, + other: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + return TensorBox.create( + mkldnn_ir.ConvolutionBinaryInplace.create( + x, + other, + weight, + bias, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + ) + + @register_lowering(torch.ops.mkldnn._linear_pointwise) + def linear_unary( + x: TensorBox, + w: TensorBox, + b: TensorBox, + attr, + scalars, + algorithm, + layout=None, + ): + x_size = x.get_size() + if len(x_size) > 2: + # GEMM template needs 2D input, normalize input shape here + x = view(x, [-1, x_size[-1]]) + if b is not None: + b = ir.ExternKernel.realize_input(b) 
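+            # Build the autotuning candidates: the CPP packed-GEMM template (when
+            # max-autotune permits it) and/or the ATen mkldnn extern kernel, then
+            # let autotune_select_algorithm below pick the fastest.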
+ choices: List[ChoiceCaller] = [] + if use_max_autotune(): + transposed_w = permute(w, [1, 0]) + *_, layout, x, transposed_w = mm_args(x, transposed_w, layout=layout) + if use_cpp_packed_gemm_template(layout, x, transposed_w): + + def epilogue_creator(buf): + return create_epilogue_with_attr( + buf, attr, scalars=scalars, algorithm=algorithm + ) + + kwargs = dict( + has_bias=b is not None, + trans_w=True, + epilogue_creator=None if attr == "none" else epilogue_creator, + ) + if b is not None: + kwargs["input_indices"] = [2, 0, 1] # type: ignore[assignment] + CppPackedGemmTemplate.add_choices( + choices, + layout, + [x, w] if b is None else [x, w, b], + **kwargs, # type: ignore[arg-type] + ) + if len(choices) == 0 or use_aten_gemm_kernels(): + kwargs = dict(attr=attr, scalars=scalars, algorithm=algorithm) + if b is None: + kwargs["B"] = None + choices.append( + aten_mkldnn_linear_unary.bind( + [x, w] if b is None else [x, w, b], + layout, + **kwargs, + ) + ) + assert w.get_name() in V.graph.constants + input_gen_fns = { + 1: lambda x: V.graph.constants[x.get_name()], + } + result = autotune_select_algorithm( + "linear_unary", + choices, + [x, w] if b is None else [x, w, b], + layout, + input_gen_fns=input_gen_fns, + ) + if len(x_size) > 2: + result = view(result, (*x_size[:-1], result.get_size()[-1])) + return result + + @register_lowering(torch.ops.mkldnn._linear_pointwise.binary) + def linear_binary( + x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr, layout=None + ): + x_size = x.get_size() + if len(x_size) > 2: + # GEMM template needs 2D input, normalize input shape here + x = view(x, [-1, x_size[-1]]) + y_size = y.get_size() + if len(y_size) > 2: + y = view(y, [-1, y_size[-1]]) + if b is not None: + b = ir.ExternKernel.realize_input(b) + choices: List[ChoiceCaller] = [] + if use_max_autotune(): + transposed_w = permute(w, [1, 0]) + *_, layout, x, transposed_w, y = mm_args( + x, transposed_w, y, layout=layout + ) + if use_cpp_packed_gemm_template(layout, x, transposed_w): + + def epilogue_creator(buf): + return create_epilogue_with_attr(buf, attr, other=y) + + kwargs = dict( + has_bias=b is not None, + trans_w=True, + epilogue_creator=epilogue_creator, + ) + kwargs["input_indices"] = [0, 2, 1] if b is None else [3, 0, 2, 1] + CppPackedGemmTemplate.add_choices( + choices, + layout, + [x, y, w] if b is None else [x, y, w, b], + **kwargs, # type: ignore[arg-type] + ) + if len(choices) == 0 or use_aten_gemm_kernels(): + kwargs = dict(attr=attr) + if b is None: + kwargs["B"] = None + choices.append( + aten_mkldnn_linear_binary.bind( + [x, y, w] if b is None else [x, y, w, b], + layout, + **kwargs, + ) + ) + assert w.get_name() in V.graph.constants + input_gen_fns = { + 2: lambda x: V.graph.constants[x.get_name()], + } + result = autotune_select_algorithm( + "linear_binary", + choices, + [x, y, w] if b is None else [x, y, w, b], + layout, + input_gen_fns=input_gen_fns, + ) + if len(x_size) > 2: + result = view(result, (*x_size[:-1], result.get_size()[-1])) + return result + + @register_lowering(torch.ops.mkldnn._convolution_transpose_pointwise) + def convolution_transpose_unary( + x: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + output_padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ): + return TensorBox.create( + mkldnn_ir.ConvolutionTransposeUnary.create( + x, + weight, + bias, + padding, + output_padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ) + ) + + @register_lowering(aten.mkldnn_rnn_layer.default) + 
def mkldnn_rnn_layer( + x: TensorBox, + w0: TensorBox, + w1: TensorBox, + w2: TensorBox, + w3: TensorBox, + hx: TensorBox, + cx: TensorBox, + reverse: bool, + batch_sizes: List[int], + mode: int, + hidden_size: int, + num_layers: int, + has_biases: bool, + bidirectional: bool, + batch_first: bool, + train: bool, + ): + return pytree.tree_map( + TensorBox.create, + mkldnn_ir.MkldnnRnnLayer.create( + x, + w0, + w1, + w2, + w3, + hx, + cx, + reverse, + batch_sizes, + mode, + hidden_size, + num_layers, + has_biases, + bidirectional, + batch_first, + train, + ), + ) + + @register_lowering(torch.ops.onednn.qconv2d_pointwise, type_promotion_kind=None) + def qconvolution_unary( + x: TensorBox, + x_scale, + x_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp: TensorBox, + bias: TensorBox, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + attr, + scalars, + algorithm, + ): + return TensorBox.create( + mkldnn_ir.QConvPointWisePT2E.create( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + attr, + scalars, + algorithm, + ) + ) + + @register_lowering( + torch.ops.onednn.qconv2d_pointwise.binary, type_promotion_kind=None + ) + def qconvolution_binary( + x: TensorBox, + x_scale, + x_zp, + accum: TensorBox, + accum_scale, + accum_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp: TensorBox, + bias: TensorBox, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithmm, + ): + if ( + binary_attr == "sum" + and output_dtype in [torch.float32, torch.bfloat16] + and accum.get_dtype() in [torch.float32, torch.bfloat16] + and accum.get_dtype() != output_dtype + ): + # For int8-mixed-bf16 quantization and inplace add, + # there is case when accum dtype is float32 but output dtype is bfloat16. + # Since the accum will be inplaced changed with post op sum, + # we will do accum dtype convertion here. 
+ accum = to_dtype(accum, output_dtype) + return TensorBox.create( + mkldnn_ir.QConvPointWiseBinaryPT2E.create( + x, + x_scale, + x_zp, + accum, + accum_scale, + accum_zp, + packed_weight, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithmm, + ) + ) + + @register_lowering(torch.ops.onednn.qlinear_pointwise, type_promotion_kind=None) + def qlinear_unary( + x: TensorBox, + x_scale, + x_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp: TensorBox, + bias: TensorBox, + o_scale, + o_zero_point, + output_dtype, + attr, + scalars, + algorithm, + layout=None, + ): + x_size = x.get_size() + if len(x_size) > 2: + # GEMM template needs 2D input, normalize input shape here + x = view(x, [-1, x_size[-1]]) + if not isinstance(x_scale, ir.TensorBox): + assert type(x_scale) == float + x_scale = V.graph.add_tensor_constant( + torch.tensor(x_scale, dtype=torch.float32), name="x_scale" + ) + else: + x_scale.realize() + if not isinstance(x_zp, ir.TensorBox): + assert type(x_zp) == int + x_zp = V.graph.add_tensor_constant( + torch.tensor(x_zp, dtype=torch.int32), name="x_zp" + ) + else: + x_zp.realize() + + # When channels less than 8, w_scale/w_zp is Pointwise instead of ConstantBuffer + # Refer to https://github.com/pytorch/pytorch/blob + # /f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577 + w_scale.realize() + w_zp.realize() + if w_zp.get_dtype() != torch.int32 and isinstance( + ir.InputsKernel.unwrap_storage_for_input(w_zp), + ir.ConstantBuffer, + ): + # W_zp might be a ConstantBuffer with int64, convert it to int32 + w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32) + w_zp = V.graph.add_tensor_constant( + torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name() + ) + + bias_dtype = None if bias is None else bias.get_dtype() + + choices: List[ChoiceCaller] = [] + if use_max_autotune(): + *_, layout, x, packed_weight = mm_args( + x, packed_weight, layout=layout, out_dtype=output_dtype + ) + if ( + isinstance( + ir.InputsKernel.unwrap_storage_for_input(x_zp), + ir.ConstantBuffer, + ) + and len(x_zp.get_layout().size) == 0 # Per tensor quant of act + and isinstance( + ir.InputsKernel.unwrap_storage_for_input(w_zp), + ir.ConstantBuffer, + ) + and torch.equal( + torch.zeros_like(V.graph.constants[w_zp.get_name()]), + V.graph.constants[w_zp.get_name()], + ) # We only compensate MatrixB and assume B_zp is 0 to avoid the compensation of MatrixA + and use_cpp_packed_gemm_template(layout, x, packed_weight) + ): + W_tensor = V.graph.constants[packed_weight.get_name()].to_dense() + weight_compens_tensor = torch.sum(W_tensor.to(torch.float), dim=0) + weight_compens = V.graph.add_tensor_constant( + weight_compens_tensor, + name=packed_weight.get_name() + "_BMatrixCompens", + ) + + def epilogue_creator(input_buffer): + # Epilogue to convert from s32 to f32 for u8s8f32 + assert output_dtype in [ + torch.float32, + torch.bfloat16, + torch.uint8, + ] + input_loader = input_buffer.make_loader() + weight_compens_loader = weight_compens.make_loader() + x_scale_loader = x_scale.make_loader() + w_scale_loader = w_scale.make_loader() + x_zp_loader = x_zp.make_loader() + nonlocal bias + bias_loader = None + if bias is not None: + bias_loader = bias.make_loader() + + def inner_fn(index): + nonlocal bias + input = input_loader(index) + # MicroKernel Output is with int32 + # cvt to FP32 before doing compensation + input = 
ops.to_dtype(input, torch.float32) + weight_compens_index = (index[-1],) + _x_scale = x_scale_loader(()) + _x_zp = x_zp_loader(()) + _w_scale = w_scale_loader(weight_compens_index) + _weight_compo = weight_compens_loader(weight_compens_index) + # Step 1: Doing compensation to cvt fp32 + temp = ops.mul( + ops.mul( + input, + _x_scale, + ), + _w_scale, + ) + temp = ops.sub( + temp, + ops.mul( + ops.mul( + ops.mul( + _x_scale, + _w_scale, + ), + _x_zp, + ), + _weight_compo, + ), + ) + # Step 2: add Bias if applicable + if bias is not None: + _bias = bias_loader(weight_compens_index) + nonlocal bias_dtype + assert bias_dtype in [torch.float32, torch.bfloat16] + if bias_dtype == torch.bfloat16: + _bias = ops.to_dtype(_bias, torch.float32) + temp = ops.add(temp, _bias) + + return temp + + output_buf = ir.Pointwise( + device=input_buffer.get_device(), + dtype=torch.float32, # Hardcode to FP32 for u8s8f32 + inner_fn=inner_fn, + ranges=input_buffer.get_size(), + ) + + # Step 3: Doing the unary post op fusion + if attr != "none": + output_buf = create_epilogue_with_attr( + output_buf, attr, scalars=scalars, algorithm=algorithm + ) + + # Step 4: Cast output to Target Dtype + if output_dtype == torch.bfloat16: + output_cast_loader = output_buf.make_loader() + + def inner_fn_cast_output_to_bf16(index): + input = output_cast_loader(index) + return ops.to_dtype(input, output_dtype) + + output_buf = ir.Pointwise( + device=output_buf.get_device(), + dtype=output_dtype, + inner_fn=inner_fn_cast_output_to_bf16, + ranges=output_buf.get_size(), + ) + elif output_dtype == torch.uint8: + from .lowering import _create_constants + + requant_input_loader = output_buf.make_loader() + + def inner_fn_requant(index, scale, zero_point): + input = requant_input_loader(index) + inv_scale, zero_point = _create_constants( + 1.0 / scale, zero_point, dtype=torch.float32 + ) + val = ops.round(input * inv_scale) + zero_point + qmin, qmax = _create_constants( + 0, 255, dtype=torch.float32 + ) + clamped = ops.minimum(ops.maximum(val, qmin), qmax) + return ops.to_dtype(clamped, torch.uint8) + + output_buf = ir.Pointwise( + device=output_buf.get_device(), + dtype=output_dtype, + inner_fn=functools.partial( + inner_fn_requant, + scale=float(o_scale), + zero_point=int(o_zero_point), + ), + ranges=output_buf.get_size(), + ) + + return output_buf + + assert x.get_dtype() == torch.uint8 + CppPackedGemmTemplate.add_choices( + choices, + layout, + [x, x_scale, x_zp, packed_weight, w_scale, w_zp] + if bias is None + else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias], + has_bias=bias is not None, + epilogue_creator=epilogue_creator, + input_indices=[0, 3, 1, 2, 4, 5] + if bias is None + else [6, 0, 3, 1, 2, 4, 5], + ) + if len(choices) == 0 or use_aten_gemm_kernels(): + kwargs = dict( + output_scale=o_scale, + output_zero_point=o_zero_point, + output_dtype=output_dtype, + post_op_name=attr, + post_op_args=scalars, + post_op_algorithm=algorithm, + ) + if bias is None: + kwargs["bias"] = None + choices.append( + aten_mkldnn_qlinear_unary.bind( + (x, x_scale, x_zp, packed_weight, w_scale, w_zp) + if bias is None + else (x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias), + layout, + **kwargs, + ) + ) + assert packed_weight.get_name() in V.graph.constants + input_gen_fns = { + 3: lambda x: V.graph.constants[x.get_name()], + 4: lambda x: V.graph.constants[x.get_name()], + 5: lambda x: V.graph.constants[x.get_name()], + 6: lambda x: V.graph.constants[x.get_name()], # For bias + } + result = autotune_select_algorithm( + 
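+                # note (assumed behavior): benchmarking of each choice uses the real
+                # constant weight/scale tensors supplied through input_gen_fns above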
"qlinear_unary", + choices, + [x, x_scale, x_zp, packed_weight, w_scale, w_zp] + if bias is None + else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias], + layout, + input_gen_fns=input_gen_fns, + ) + if len(x_size) > 2: + result = view(result, (*x_size[:-1], result.get_size()[-1])) + return result + + @register_lowering( + torch.ops.onednn.qlinear_pointwise.binary, type_promotion_kind=None + ) + @register_lowering( + torch.ops.onednn.qlinear_pointwise.binary_tensor, type_promotion_kind=None + ) + def qlinear_binary( + x: TensorBox, + x_scale, + x_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp: TensorBox, + x2: TensorBox, + bias: TensorBox, + o_scale, + o_zero_point, + output_dtype, + x2_scale, + x2_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithmm, + layout=None, + ): + x_size = x.get_size() + x2_size = x2.get_size() + assert len(x_size) == len(x2_size) + if len(x_size) > 2 and binary_attr == "add": + # GEMM template needs 2D input, normalize input shape here + x = view(x, [-1, x_size[-1]]) + x2 = view(x2, [-1, x2_size[-1]]) + if not isinstance(x_scale, ir.TensorBox): + assert type(x_scale) == float + x_scale = V.graph.add_tensor_constant( + torch.tensor(x_scale, dtype=torch.float32), name="x_scale" + ) + else: + x_scale.realize() + if not isinstance(x_zp, ir.TensorBox): + assert type(x_zp) == int + x_zp = V.graph.add_tensor_constant( + torch.tensor(x_zp, dtype=torch.int32), name="x_zp" + ) + else: + x_zp.realize() + + # When channels less than 8, w_scale/w_zp is Pointwise instead of ConstantBuffer + # Refer to https://github.com/pytorch/pytorch/blob + # /f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577 + w_scale.realize() + w_zp.realize() + if w_zp.get_dtype() != torch.int32 and isinstance( + ir.InputsKernel.unwrap_storage_for_input(w_zp), + ir.ConstantBuffer, + ): + w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32) + w_zp = V.graph.add_tensor_constant( + torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name() + ) + if binary_attr == "sum": + if output_dtype in [ + torch.float32, + torch.bfloat16, + ] and x2.get_dtype() in [torch.float32, torch.bfloat16]: + if x2.get_dtype() != output_dtype: + # For int8-mixed-bf16 quantization and inplace add, + # there is case when accum dtype is float32 but output dtype is bfloat16. + # Since the accum will be inplaced changed with post op sum, + # we will do accum dtype convertion here. 
+ x2 = to_dtype(x2, output_dtype) + else: + assert ( + x2.get_dtype() == output_dtype + ), "dtype of accum for qlinear post op sum should be the same as output" + x2_dtype = x2.get_dtype() + bias_dtype = bias.get_dtype() if bias is not None else None + choices: List[ChoiceCaller] = [] + if ( + use_max_autotune() and binary_attr == "add" + ): # Support inplace sum fusion + *_, layout, x, packed_weight, x2 = mm_args( + x, packed_weight, x2, layout=layout, out_dtype=output_dtype + ) + if ( + isinstance( + ir.InputsKernel.unwrap_storage_for_input(x_zp), + ir.ConstantBuffer, + ) + and len(x_zp.get_layout().size) == 0 # Per tensor quant of act + and isinstance( + ir.InputsKernel.unwrap_storage_for_input(w_zp), + ir.ConstantBuffer, + ) + and torch.equal( + torch.zeros_like(V.graph.constants[w_zp.get_name()]), + V.graph.constants[w_zp.get_name()], + ) # We only compensate MatrixB and assume B_zp is 0 to avoid the compensation of MatrixA + and use_cpp_packed_gemm_template(layout, x, packed_weight) + ): + W_tensor = V.graph.constants[packed_weight.get_name()] + W_tensor = W_tensor.to_dense() + weight_compens_tensor = torch.sum(W_tensor.to(torch.float), dim=0) + weight_compens = V.graph.add_tensor_constant( + weight_compens_tensor, + name=packed_weight.get_name() + "_BMatrixCompens", + ) + + def epilogue_creator(input_buffer): + # Epilogue to convert from s32 to f32 for u8s8f32 + assert output_dtype in [ + torch.float32, + torch.bfloat16, + torch.uint8, + ] + + input_loader = input_buffer.make_loader() + x2_loader = x2.make_loader() + weight_compens_loader = weight_compens.make_loader() + x_scale_loader = x_scale.make_loader() + w_scale_loader = w_scale.make_loader() + x_zp_loader = x_zp.make_loader() + nonlocal bias + bias_loader = None + if bias is not None: + bias_loader = bias.make_loader() + + def inner_fn(index): + nonlocal bias + input = input_loader(index) + _x2 = x2_loader(index) + _x_scale = x_scale_loader(()) + _x_zp = x_zp_loader(()) + + # MicroKernel Output is with int32 + # cvt to FP32 before doing compensation + input = ops.to_dtype(input, torch.float32) + weight_compens_index = (index[-1],) + _w_scale = w_scale_loader(weight_compens_index) + _weight_compens = weight_compens_loader( + weight_compens_index + ) + # Step 1: Doing compensation to cvt fp32 + temp = ops.mul( + ops.mul( + input, + _x_scale, + ), + _w_scale, + ) + temp = ops.sub( + temp, + ops.mul( + ops.mul( + ops.mul( + _x_scale, + _w_scale, + ), + _x_zp, + ), + _weight_compens, + ), + ) + + # Step 2: add Bias if applicable + if bias is not None: + _bias = bias_loader(weight_compens_index) + nonlocal bias_dtype + assert bias_dtype in [torch.float32, torch.bfloat16] + if bias_dtype == torch.bfloat16: + _bias = ops.to_dtype(_bias, torch.float32) + temp = ops.add(temp, _bias) + + # Step 3: Binary add + nonlocal x2_dtype + assert x2_dtype in [torch.float32, torch.bfloat16] + if x2_dtype == torch.bfloat16: + _x2 = ops.to_dtype(_x2, torch.float32) + temp = ops.add(temp, _x2) + + return temp + + output_buf = ir.Pointwise( + device=input_buffer.get_device(), + dtype=torch.float32, # Hardcode to FP32 for u8s8f32 + inner_fn=inner_fn, + ranges=input_buffer.get_size(), + ) + + # Step 4: Unary post op if has + if unary_attr != "none": + output_buf = create_epilogue_with_attr( + output_buf, + unary_attr, + scalars=unary_scalars, + algorithm=unary_algorithmm, + ) + + # Step 5: Cast output to Target Dtype + if output_dtype == torch.bfloat16: + output_cast_loader = output_buf.make_loader() + + def inner_fn_cast_output_to_bf16(index): + input 
= output_cast_loader(index) + return ops.to_dtype(input, output_dtype) + + output_buf = ir.Pointwise( + device=output_buf.get_device(), + dtype=output_dtype, + inner_fn=inner_fn_cast_output_to_bf16, + ranges=output_buf.get_size(), + ) + elif output_dtype == torch.uint8: + from .lowering import _create_constants + + requant_input_loader = output_buf.make_loader() + + def inner_fn_requant(index, scale, zero_point): + input = requant_input_loader(index) + inv_scale, zero_point = _create_constants( + 1.0 / scale, zero_point, dtype=torch.float32 + ) + val = ops.round(input * inv_scale) + zero_point + qmin, qmax = _create_constants( + 0, 255, dtype=torch.float32 + ) + clamped = ops.minimum(ops.maximum(val, qmin), qmax) + return ops.to_dtype(clamped, torch.uint8) + + output_buf = ir.Pointwise( + device=output_buf.get_device(), + dtype=torch.uint8, + inner_fn=functools.partial( + inner_fn_requant, + scale=float(o_scale), + zero_point=int(o_zero_point), + ), + ranges=output_buf.get_size(), + ) + + return output_buf + + CppPackedGemmTemplate.add_choices( + choices, + layout, + [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2] + if bias is None + else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias], + has_bias=bias is not None, + epilogue_creator=epilogue_creator, + # Reorder bias and x2 + input_indices=[0, 3, 1, 2, 4, 5, 6] + if bias is None + else [7, 0, 3, 1, 2, 4, 5, 6], + ) + + if len(choices) == 0 or use_aten_gemm_kernels(): + kwargs = dict( + output_scale=o_scale, + output_zero_point=o_zero_point, + output_dtype=output_dtype, + other_scale=x2_scale, + other_zp=x2_zp, + binary_post_op=binary_attr, + binary_alpha=alpha, + unary_post_op=unary_attr, + unary_post_op_args=unary_scalars, + unary_post_op_algorithm=unary_algorithmm, + ) + if bias is None: + kwargs["bias"] = None + choices.append( + aten_mkldnn_qlinear_binary.bind( + (x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2) + if bias is None + else (x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias), + layout, + **kwargs, + ) + ) + assert packed_weight.get_name() in V.graph.constants + input_gen_fns = { + 3: lambda x: V.graph.constants[x.get_name()], + 4: lambda x: V.graph.constants[x.get_name()], + 5: lambda x: V.graph.constants[x.get_name()], + } + if bias is not None: + input_gen_fns[7] = lambda x: V.graph.constants[x.get_name()] # For bias + result = autotune_select_algorithm( + "qlinear_binary", + choices, + [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2] + if bias is None + else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias], + layout, + input_gen_fns=input_gen_fns, + ) + if len(x_size) > 2 and binary_attr == "add": + result = view(result, (*x_size[:-1], result.get_size()[-1])) + return result + + if torch._C.has_mkl: + aten_mkl_linear = ExternKernelChoice( + torch.ops.mkl._mkl_linear, + "mkl::_mkl_linear", + has_out_variant=False, + kernel_creator=mkldnn_ir.MKLPackedLinear.create, + ) + cpu_needs_realized_inputs.append(torch.ops.mkl._mkl_linear) + + @register_lowering(torch.ops.mkl._mkl_linear) + def mkl_packed_linear( + x: TensorBox, + packed_w: TensorBox, + orig_w: TensorBox, + b: Optional[TensorBox], + batch_size, + *, + layout=None, + ): + choices: List[ChoiceCaller] = [] + if use_max_autotune(): + transposed_w = permute(orig_w, [1, 0]) + *_, layout, x, transposed_w = mm_args( + x, transposed_w, layout=layout + ) + if use_cpp_packed_gemm_template(layout, x, transposed_w): + CppPackedGemmTemplate.add_choices( + choices, + layout, + [x, packed_w, orig_w], + trans_w=True, + input_indices=[0, 2], 
+ ) + + if len(choices) == 0 or use_aten_gemm_kernels(): + choices.append( + aten_mkl_linear.bind( + (x, packed_w, orig_w), layout, B=None, batch_size=batch_size + ) + ) + + assert packed_w.get_name() in V.graph.constants + assert orig_w.get_name() in V.graph.constants + # packed_w is a mkldnn tensor which we can't generate directly + # so we use the weights from the original tensor in autotune. + input_gen_fns = { + 1: lambda x: V.graph.constants[x.get_name()], + 2: lambda x: V.graph.constants[x.get_name()], + } + result: TensorBox = autotune_select_algorithm( + "packed_linear", + choices, + [x, packed_w, orig_w], + layout, + input_gen_fns=input_gen_fns, + ) + if b is not None: + result = add(result, b) + return result + + add_needs_realized_inputs(cpu_needs_realized_inputs) + else: + pass diff --git a/lib/python3.10/site-packages/torch/_inductor/ops_handler.py b/lib/python3.10/site-packages/torch/_inductor/ops_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..c47ee1026ab919f0d4c619a8e89278268f503d52 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/ops_handler.py @@ -0,0 +1,1093 @@ +# mypy: allow-untyped-defs +import itertools +from typing import ( + Any, + Callable, + Dict, + Generic, + List, + Literal, + NamedTuple, + Optional, + Tuple, + TypeVar, + Union, +) +from typing_extensions import Protocol +from unittest.mock import patch + +import sympy + +import torch +import torch.utils._pytree as pytree + +from ..utils._ordered_set import OrderedSet +from .utils import IndentedBuffer, reduction_num_outputs, sympy_index_symbol, sympy_str + + +T = TypeVar("T") +StoreMode = Optional[Literal["atomic_add"]] +ReductionType = Literal[ + "argmax", + "argmin", + "welford_reduce", + "welford_combine", + "any", + "max", + "min", + "prod", + "sum", + "xor_sum", +] + + +def _arg_str(a) -> str: + if isinstance(a, sympy.Expr): + return sympy_str(a) + return str(a) + + +# NB: This is not done as a parent class, because our ops handlers +# implementations make heavy use of __getattr__ magic, and pre-existing +# stubs for methods would interfere with this mechanism. +# +# TODO: A superclass that does desugaring for operations like +# reciprocal/square might be useful. +class OpsHandler(Protocol[T]): + """ + Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``, + as well as the contract for op handlers. The type T signifies the domain + of the abstract analysis AKA what all of the functions return / take as arguments + anywhere compute occurs. + + While these operators are typically dtype polymorphic (e.g., you can use mul + on both integers and floats), they do NOT do promotion and usually return the + same dtype as the input. You are expected to have handled type promotion + during ATen decompositions. Most operators correspond exactly to pointwise + operations as defined by torch, so when in doubt about semantics, check the + corresponding torch documentation. These are all scalar operations (so they + are defined to operate on a single element at a time.) + + For convenience, many operators take a src_dtype which indicates what the dtype + of the input argument is. Although in principle this can be derived by an + analysis, providing this for ops where it is useful helps avoid having to repeatedly + recompute dtype in code generation. + + Note that this often describes a class of static methods, for stateless + ops handlers. 
+ + Handlers are often defined using ``__getattr__`` metaprogramming, which means + that you cannot declare that a type implements a protocol by inheriting from + it (as the type stubs count as attribute declarations and impede the getattr + magic method from being called). Instead, define a function that casts an + argument of your type to the protocol, which is sufficient to induce mypy to + test that the protocol is implemented correctly. Search for ``_typecheck_`` + in this file to see some examples. If you see an obscure error where a + class doesn't implement a Protocol, but mypy doesn't say why, check to see + that ``__getattr__`` is typed correctly (typically, it is not possible to + type ``__getattr__`` without typing it as ``Callable[..., Any]``) + """ + + def constant(self, value: Union[bool, float, int], dtype: torch.dtype) -> T: + """Produces a scalar constant of type dtype.""" + ... + + def load_seed(self, name: str, offset: T): + """Computes inductor_prims.lookup_seed.""" + ... + + def rand(self, seed: T, offset: T) -> T: + """Computes inductor_prims.random with mode="rand". offset has dtype int32.""" + ... + + def randn(self, seed: T, offset: T) -> T: + """Computes inductor_prims.random with mode="randn". offset has dtype int32.""" + ... + + def randint64(self, seed: T, offset: T, low: T, high: T) -> T: + """Computes inductor_prims.randint. offset has dtype int32.""" + ... + + def masked(self, mask: T, body: Callable[[], T], other: T) -> T: + """ + Computes body, but only perform loads/stores if the boolean mask + evaluates to true. For example, you would use this if you needed to + perform an indirect load that may not be valid on some elements; + without masking, invalid accesses can cause IMAs. When mask is true, + the result is the result of body; otherwise it is other. Here, `other` + needs to be a constant. + + Contrast this with ops.where, which can multiplex between two values + that have been unconditionally computed. + """ + ... + + def where(self, condition: T, input: T, other: T) -> T: + """ + Computes torch.where: when condition is true, return input; otherwise return other. + """ + ... + + def index_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> T: + """ + Converts a sympy expression into a scalar of type dtype. expr is typically + an indexing expression, thus the name; however, it can also be used in + non-indexing situations. + """ + ... + + def to_dtype( + self, + x: T, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types=True, + ) -> T: + """ + Convert x to dtype. src_dtype can be optionally set to specify what the original + dtype of x was, which can improve code generation (used by torch to(dtype=dtype)). + """ + ... + + def trunc_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with truncation semantics (similar to how the int + constructor works in Python). In Inductor codegen, this just decays + to trunc and then to_dtype, but this composite operation helps + roundtrips for Sympy evaluation. + + dtype is taken as an explicit parameter because the desired output + dtype is typically the index dtype, which may vary between int32 and + int64 depending on if we've shown that all the indexing operations can + be done in int32. + """ + ... + + def ceil_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with ceiling semantics. See also trunc_to_int. + """ + ... + + def floor_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with ceiling semantics. 
See also trunc_to_int. + """ + ... + + def round_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with round-to-even semantics. See also trunc_to_int. + """ + ... + + def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T: + """ + Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.) + src_dtype must be the original type of x. + """ + ... + + def identity(self, x: T) -> T: + """ + Returns x as is. This is used to trigger CSE. + """ + ... + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # These operations are only available in a "kernel" context. Check + # torch._inductor.codegen.common.CSEProxy for their typical implementation + # in op handler (routing to their respective implementations in the kernel + # handler) + # + # Importantly, inside a kernel, indexing and mask variables are available + # in scope, which are typically used by sympy.Expr indexing. + + def indirect_indexing( + self, x: T, size: sympy.Expr, check: bool = True, wrap_neg=True + ) -> sympy.Expr: + """ + Convert an integral x into a sympy.Expr that can be subsequently used in + indexing computation. 'size' represents an upper bound on the what valid + indexes can be; when 'check' is True, we check that the x is in bounds. + + NB: This is typically mandatory to implement for any analysis, because you + MUST return a valid sympy.Expr of some sort (even if it's a meaningless symbol). + """ + ... + + def load(self, name: str, index: sympy.Expr) -> T: + """ + Load from the memory location 'name', offset by some indexing expression 'index'. + """ + ... + + def store( + self, + name: str, + index: sympy.Expr, + value: T, + mode: StoreMode = None, + ) -> None: + """ + Store 'value' to the memory location 'name' offset by 'expr'. If + specified, 'mode' can require the store to be an atomic addition. + """ + ... + + # TODO: Better explain how the "collective" semantics of these ops; + # remember that the input value is a scalar, you can't reduce on it in the + # traditional sense! + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: T, + ) -> Union[T, Tuple[T, ...]]: + """ + Perform a 'reduction_type' reduction on 'value' of dtype 'src_dtype', + using 'dtype' as the accumulation dtype for the reduction. The result + is an intermediate computation which should be stored to the final + location using 'ops.store_reduction'. + + Valid reduction types are . For Welford reduction types, this + function returns multiple outputs; consult reduction_num_outputs to + determine the amount in metaprogramming applications. + """ + ... + + # TODO: in practice, this seems to actually return None, but not returning + # a T makes common __getattr__ idioms not type correctly. Figure out if + # this should be returning something. + def store_reduction(self, name: str, index: sympy.Expr, value: T) -> T: + """ + Store the fully accumulated result of 'reduction' to the memory + location 'name' offset by 'expr'. + """ + ... + + def scan( + self, + dtypes: Tuple[torch.dtype, ...], + combine_fn: Callable[[Tuple[T, ...], Tuple[T, ...]], Tuple[T, ...]], + values: Tuple[T, ...], + ) -> Tuple[T, ...]: + """ + Perform an associative scan on 'value'. + """ + # TODO: Improve the description with some pseudocode + ... 
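# --- Illustrative sketch, not part of the vendored file: rough reference semantics
# for the associative `scan` op above (its docstring's TODO asks for pseudocode).
# This elides dtypes and assumes a single scanned value per element rather than the
# tuple-of-values form of the real signature; `combine_fn` must be associative, and
# the result is an inclusive scan along the reduction dimension.
def _reference_scan(combine_fn, values):
    out, acc = [], None
    for v in values:
        acc = v if acc is None else combine_fn(acc, v)
        out.append(acc)  # out[i] == combine(values[0], ..., values[i])
    return out

# e.g. _reference_scan(lambda a, b: a + b, [1, 2, 3, 4]) -> [1, 3, 6, 10] (a cumsum)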
+ + def sort( + self, + dtypes: Tuple[torch.dtype, ...], + values: Tuple[T, ...], + stable: bool, + descending: bool, + ) -> Tuple[T, ...]: + """ + Sort values along the reduction dimension. + """ + ... + + def bucketize( + self, + values: T, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ) -> T: + # See [Note: Inductor bucketize op] + ... + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # The following ops have semantics that correspond exactly to the torch + # operation with the same corresponding name. + + def abs(self, x0: T) -> T: + ... + + def exp(self, x0: T) -> T: + ... + + def exp2(self, x0: T) -> T: + ... + + def expm1(self, x0: T) -> T: + ... + + def sqrt(self, x0: T) -> T: + ... + + def relu(self, x0: T) -> T: + ... + + def minimum(self, x0: T, x1: T) -> T: + ... + + def maximum(self, x0: T, x1: T) -> T: + ... + + def cos(self, x0: T) -> T: + ... + + def sin(self, x0: T) -> T: + ... + + def lgamma(self, x0: T) -> T: + ... + + def erf(self, x0: T) -> T: + ... + + def cosh(self, x0: T) -> T: + ... + + def sinh(self, x0: T) -> T: + ... + + def acos(self, x0: T) -> T: + ... + + def acosh(self, x0: T) -> T: + ... + + def asin(self, x0: T) -> T: + ... + + def asinh(self, x0: T) -> T: + ... + + def atan2(self, x0: T, x1: T) -> T: + ... + + def atan(self, x0: T) -> T: + ... + + def atanh(self, x0: T) -> T: + ... + + def copysign(self, x0: T, x1: T) -> T: + ... + + def erfc(self, x0: T) -> T: + ... + + def erfinv(self, x0: T) -> T: + ... + + def frexp(self, x0: T): + ... + + def hypot(self, x0: T, x1: T) -> T: + ... + + def log10(self, x0: T) -> T: + ... + + def log2(self, x0: T) -> T: + ... + + def nextafter(self, x0: T, x1: T) -> T: + ... + + def logical_and(self, x0: T, x1: T) -> T: + ... + + def logical_not(self, x0: T) -> T: + ... + + def logical_or(self, x0: T, x1: T) -> T: + ... + + def logical_xor(self, x0: T, x1: T) -> T: + ... + + def bitwise_and(self, x0: T, x1: T) -> T: + ... + + def bitwise_not(self, x0: T) -> T: + ... + + def bitwise_or(self, x0: T, x1: T) -> T: + ... + + def bitwise_xor(self, x0: T, x1: T) -> T: + ... + + def bitwise_left_shift(self, x0: T, x1: T) -> T: + ... + + def bitwise_right_shift(self, x0: T, x1: T) -> T: + ... + + def rsqrt(self, x0: T) -> T: + ... + + def log1p(self, x0: T) -> T: + ... + + def tan(self, x0: T) -> T: + ... + + def tanh(self, x0: T) -> T: + ... + + def sigmoid(self, x0: T) -> T: + ... + + def signbit(self, x0: T) -> T: + ... + + def fmod(self, x0: T, x1: T) -> T: + ... + + def log(self, x0: T) -> T: + ... + + def isinf(self, x0: T) -> T: + ... + + def isnan(self, x0: T) -> T: + ... + + # NB: this returns a float, like the torch operation + # This rounds half to even to break ties + def round(self, x0: T) -> T: + ... + + # NB: this returns a float, like the torch operation + def floor(self, x0: T) -> T: + ... + + def sign(self, x0: T) -> T: + ... + + # NB: this returns a float, like the torch operation + def trunc(self, x0: T) -> T: + ... + + # NB: this returns a float, like the torch operation + def ceil(self, x0: T) -> T: + ... + + def neg(self, x0: T) -> T: + ... + + def reciprocal(self, x0: T) -> T: + ... + + def eq(self, x0: T, x1: T) -> T: + ... + + def ne(self, x0: T, x1: T) -> T: + ... + + def lt(self, x0: T, x1: T) -> T: + ... + + def gt(self, x0: T, x1: T) -> T: + ... + + def le(self, x0: T, x1: T) -> T: + ... + + def ge(self, x0: T, x1: T) -> T: + ... + + def add(self, x0: T, x1: T) -> T: + ... + + def sub(self, x0: T, x1: T) -> T: + ... 
+ + def mul(self, x0: T, x1: T) -> T: + ... + + # NB: this returns a float, like the torch operation + def pow(self, x0: T, x1: T) -> T: + ... + + def and_(self, x0: T, x1: T) -> T: + ... + + def or_(self, x0: T, x1: T) -> T: + ... + + def xor(self, x0: T, x1: T) -> T: + ... + + # These are metaprogrammed by MockHandler._init_cls + def lshift(self, x0: T, x1: T) -> T: + ... + + def rshift(self, x0: T, x1: T) -> T: + ... + + def getitem(self, x0: T, x1: T) -> T: + # TODO: this is probably just illegal lol + ... + + def matmul(self, x0: T, x1: T) -> T: + # TODO: this is probably just illegal lol + ... + + def invert(self, x0: T) -> T: + ... + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # These are "special" operators. These only exist if the target + # language actually supports the operator. Keep this in sync with + # pointwise_overrides_data. + + def airy_ai(self, x: T) -> T: + ... + + def bessel_j0(self, x: T) -> T: + ... + + def bessel_j1(self, x: T) -> T: + ... + + def bessel_y0(self, x: T) -> T: + ... + + def bessel_y1(self, x: T) -> T: + ... + + def digamma(self, x: T) -> T: + ... + + def erfcx(self, x: T) -> T: + ... + + def fma(self, x: T, y: T, z: T) -> T: + ... + + def igamma(self, x: T, y: T) -> T: + ... + + def igammac(self, x: T, y: T) -> T: + ... + + def gammainc(self, x: T, y: T) -> T: + ... + + def gammaincc(self, x: T, y: T) -> T: + ... + + def i0(self, x: T) -> T: + ... + + def i0e(self, x: T) -> T: + ... + + def i1(self, x: T) -> T: + ... + + def i1e(self, x: T) -> T: + ... + + def log_ndtr(self, x: T) -> T: + ... + + def modified_bessel_i0(self, x: T) -> T: + ... + + def modified_bessel_i1(self, x: T) -> T: + ... + + def modified_bessel_k0(self, x: T) -> T: + ... + + def modified_bessel_k1(self, x: T) -> T: + ... + + def ndtr(self, x: T) -> T: + ... + + def ndtri(self, x: T) -> T: + ... + + def polygamma(self, x: T, y: T) -> T: + ... + + def scaled_modified_bessel_k0(self, x: T) -> T: + ... + + def scaled_modified_bessel_k1(self, x: T) -> T: + ... + + def spherical_bessel_j0(self, x: T) -> T: + ... + + def zeta(self, x: T, y: T) -> T: + ... + + def chebyshev_polynomial_t(self, x: T, y: T) -> T: + ... + + def chebyshev_polynomial_u(self, x: T, y: T) -> T: + ... + + def chebyshev_polynomial_v(self, x: T, y: T) -> T: + ... + + def chebyshev_polynomial_w(self, x: T, y: T) -> T: + ... + + def legendre_polynomial_p(self, x: T, y: T) -> T: + ... + + def shifted_chebyshev_polynomial_t(self, x: T, y: T) -> T: + ... + + def shifted_chebyshev_polynomial_u(self, x: T, y: T) -> T: + ... + + def shifted_chebyshev_polynomial_v(self, x: T, y: T) -> T: + ... + + def shifted_chebyshev_polynomial_w(self, x: T, y: T) -> T: + ... + + def hermite_polynomial_h(self, x: T, y: T) -> T: + ... + + def hermite_polynomial_he(self, x: T, y: T) -> T: + ... + + def laguerre_polynomial_l(self, x: T, y: T) -> T: + ... + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # These operators are a bit special, because they are conventionally + # natively supported in both Python and C, but the semantics differ so + # care must be taken + + def truncdiv(self, x0: T, x1: T) -> T: + """C-style trunc division between integers only. Computes the true + division of two numbers and rounds the result to zero. + """ + ... + + def floordiv(self, x0: T, x1: T) -> T: + """Python-style floor division between integers only. Computes the + true division of two numbers and floors the result. If you want + floor division for floats, do regular truediv and floor the result. 
+ """ + ... + + def truediv(self, x0: T, x1: T) -> T: + """True division between floats. Integer inputs are NOT valid. To + do Python-style (int, int) -> float division, use int_truediv""" + ... + + def int_truediv(self, x0: T, x1: T) -> T: + """True division between integers. This is NOT the same as promoting + to float and doing integer division, there is a bespoke algorithm for + doing the division in higher precision than the above. + """ + ... + + def div(self, x0: T, x1: T) -> T: + """TODO: to be removed. This renders as / no matter what the backend is + which is incoherent.""" + ... + + def mod(self, x0: T, x1: T) -> T: + """C-style modulus, take sign from LHS (x0).""" + ... + + def remainder(self, x0: T, x1: T) -> T: + """Python-style modulus, take sign from RHS (x1).""" + ... + + def round_decimal(self, x0: T, x1: T) -> T: + """Python-style round with decimal argument""" + ... + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # In CUDA, optimized implementations of other mathematical operations are + # offered separately via libdevice for double precision computation (in + # Triton, these go to tl.math rather than tl). We lower to these + # operators when doing FP64 on CUDA. Note that some operators + # unconditional go to tl.math. + # + # TODO(ezyang): Is this really the best way to do this? What if we have + # abs internally route to tl.math automatically when given a double + # precision input? One reason is that when doing codegen, we often don't + # know what the dtype of the inputs are! (In principle we do know, but + # for many analyses it's not conveniently available.) + + def libdevice_abs(self, x0: T) -> T: + ... + + def libdevice_exp(self, x0: T) -> T: + ... + + def libdevice_sqrt(self, x0: T) -> T: + ... + + def libdevice_cos(self, x0: T) -> T: + ... + + def libdevice_sin(self, x0: T) -> T: + ... + + def libdevice_sigmoid(self, x0: T) -> T: + ... + + def libdevice_log(self, x0: T) -> T: + ... 
+ + +class NoopHandler: + def __getattr__(self, name): + if name == "name": + return "NoopHandler" + + def inner(*args, **kwargs): + return None + + return inner + + @staticmethod + def masked(mask, body, other) -> None: + return None + + @staticmethod + def frexp(x) -> Tuple[None, None]: + return (None, None) + + @staticmethod + def scan(dtypes, combine_fn, values) -> Tuple[None, ...]: + return (None,) * len(values) + + @staticmethod + def sort(dtypes, values, stable, descending) -> Tuple[None, ...]: + return (None,) * len(values) + + @staticmethod + def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Symbol: + return sympy.Integer(0) + + +# Use mypy to check protocol implemented correctly +def _typecheck_NoopHandler(h: NoopHandler) -> OpsHandler[None]: + return h + + +class MockHandler: + def __getattr__(self, name): + if name == "name": + return "MockHandler" + + def inner(*args, **kwargs): + fargs = [_arg_str(a) for a in args] + fargs.extend(f"{k}={v}" for k, v in kwargs.items()) + return f"ops.{name}({', '.join(fargs)})" + + return inner + + @staticmethod + def masked(mask, body, other) -> str: + return f"ops.masked({mask}, {body()}, {other})" + + @staticmethod + def frexp(x): + return (f"ops.frexp({x})[0]", f"ops.frexp({x})[1]") + + @staticmethod + def scan(dtypes, combine_fn, values): + return tuple( + f"ops.scan({dtypes}, {combine_fn}, {values})[{i}]" + for i in range(len(values)) + ) + + @staticmethod + def sort(dtypes, values, stable, descending): + return tuple( + f"ops.sort({dtypes}, {values}, stable={stable}, descending={descending})[{i}]" + for i in range(len(values)) + ) + + @staticmethod + def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Symbol: + return sympy_index_symbol(str(index_var)) + + @classmethod + def _init_cls(cls): + def make_handler(format_string): + @staticmethod # type: ignore[misc] + def inner(*args): + return format_string.format(*args) + + return inner + + for name, format_string in { + "add": "{} + {}", + "sub": "{} - {}", + "mul": "{} * {}", + "floordiv": "{} // {}", + "truediv": "{} / {}", + "mod": "{} % {}", # careful, depending on target semantics varies + "pow": "{} ** {}", + "lshift": "{} << {}", + "rshift": "{} >> {}", + "and_": "{} & {}", + "or_": "{} | {}", + "xor": "{} ^ {}", + "eq": "{} == {}", + "ne": "{} != {}", + "lt": "{} < {}", + "gt": "{} > {}", + "le": "{} <= {}", + "ge": "{} >= {}", + "neg": "-{}", + }.items(): + setattr(cls, name, make_handler(format_string)) + + +MockHandler._init_cls() + + +# Use mypy to check protocol implemented correctly +def _typecheck_MockHandler(h: MockHandler) -> OpsHandler[str]: + return h + + +class KernelFormatterHandler: + def __init__(self, parent_handler): + self.parent_handler = parent_handler + self.output = IndentedBuffer(1) + self.var_counter = itertools.count() + + @staticmethod + def ir_to_string(ir_fn, index, rindex=None) -> str: + from .ir import FlexibleLayout + from .virtualized import V + + args = [index, rindex] if rindex is not None else [index] + names = ["index", "rindex"] if rindex is not None else ["index"] + formatter = KernelFormatterHandler(MockHandler()) + + with formatter.output.indent(-1): + formatter.output.writeline(f"def inner_fn({', '.join(names)}):") + for name, arg in zip(names, args): + if arg: + lhs = ", ".join( + [ + str("_" if isinstance(v, (int, sympy.Integer)) else v) + for v in arg + ] + ) + formatter.output.writeline(f"{lhs} = {name}") + + with V.set_ops_handler(formatter), patch.object( + FlexibleLayout, 
"allow_indexing", True + ): + result = ir_fn(*args) + return formatter.getvalue(result) + + def __getattr__(self, name) -> Callable[..., Any]: + def inner(*args, **kwargs): + line = getattr(self.parent_handler, name)(*args, **kwargs) + if name == "indirect_indexing": + return line + + def write(line): + # replace line with a new variable name + varname = f"tmp{next(self.var_counter)}" + self.output.writeline(f"{varname} = {line}") + return varname + + return pytree.tree_map(write, line) + + return inner + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[str, Tuple[str, ...]], + ) -> Union[str, Tuple[str, ...]]: + line = self.parent_handler.reduction(dtype, src_dtype, reduction_type, value) + num_values = reduction_num_outputs(reduction_type) + varnames = [f"tmp{next(self.var_counter)}" for _ in range(num_values)] + self.output.writeline(f"{','.join(varnames)} = {line}") + return tuple(varnames) if num_values > 1 else varnames[0] + + def getvalue(self, result): + self.output.writeline(f"return {result}") + return self.output.getvalue() + + +# Use mypy to check protocol implemented correctly +def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[str]: + return h + + +class WrapperHandler(Generic[T]): + def __init__(self, inner: OpsHandler[T]): + self._inner = inner + + def __getattr__(self, item): + return getattr(self._inner, item) + + +# Use mypy to check protocol implemented correctly +def _typecheck_WrapperHandler(h: WrapperHandler[T]) -> OpsHandler[T]: + return h + + +class AddParenHandler(WrapperHandler[T]): + def __getattr__(self, name): + def inner(*args, **kwargs): + val = getattr(self._inner, name)(*args, **kwargs) + return f"({val})" + + return inner + + +# Use mypy to check protocol implemented correctly +def _typecheck_AddParenHandler(h: AddParenHandler[T]) -> OpsHandler[T]: + return h + + +class OpCountResult(NamedTuple): + num_ops: int + used_ops: OrderedSet[str] + read_buffers: List[str] + nontrivial_read_count: int + + +class OpCounterCSE: + """Shim to count how many ops are used""" + + def __init__(self, inner): + super().__init__() + self.parent_handler = inner + self.op_count = 0 + self.var_names = {} + self._used_ops: OrderedSet[str] = OrderedSet() + self._read_names: List[str] = [] + self._nontrivial_read_count = 0 + + def __getattr__(self, name): + def inner(*args, **kwargs): + return pytree.tree_map( + self._update_count, getattr(self.parent_handler, name)(*args, **kwargs) + ) + + self._used_ops.add(name) + return inner + + def _update_count(self, val): + varname = self.var_names.get(val) + if not varname: + varname = f"tmp{self.op_count}" + self.op_count += 1 + self.var_names[val] = varname + return varname + + def indirect_indexing(self, *args, **kwargs): + self._used_ops.add("indirect_indexing") + return self.parent_handler.indirect_indexing(*args, **kwargs) + + def load(self, name: str, index: sympy.Expr) -> str: + val = self.parent_handler.load(name, index) + if val not in self.var_names: + self._used_ops.add("load") + self._read_names.append(name) + if not isinstance(index, (sympy.Integer, int)): + self._nontrivial_read_count += 1 + return self._update_count(val) + + def load_seed(self, name: str, offset: T): + val = self.parent_handler.load_seed(name, offset) + if val not in self.var_names: + self._used_ops.add("load_seed") + self._read_names.append(name) + return self._update_count(val) + + def bucketize( + self, + values, + offsets_name: str, + offsets_size: 
sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ): + val = self.parent_handler.bucketize( + values, offsets_name, offsets_size, indexing_dtype, right + ) + if val not in self.var_names: + self._used_ops.add("bucketize") + self._read_names.append(offsets_name) + return self._update_count(val) + + def getvalue(self): + return OpCountResult( + self.op_count, self._used_ops, self._read_names, self._nontrivial_read_count + ) + + +def _typecheck_OpCounterCSE(h: OpCounterCSE) -> OpsHandler[str]: + return h + + +class ExtractConstantsHandler(NoopHandler): + def __init__(self, device): + self.device = device + + def constant(self, value: Any, dtype: torch.dtype) -> "torch._inductor.ir.Constant": + from torch._inductor import ir + + return ir.Constant(value=value, dtype=dtype, device=self.device) + + +def _typecheck_ExtractConstantsHandler(h: ExtractConstantsHandler) -> OpsHandler[Any]: + return h + + +class SimpleCSEHandler(WrapperHandler[T]): + """Wraps the underlying handler with a CSE pass + + NOTE: Compared to codegen level CSE this is simplified as it + doesn't support stores which require load cache invalidation. + """ + + def __init__(self, inner: OpsHandler[T]): + super().__init__(inner) + self.cse_cache: Dict[str, Union[T, Tuple[T, ...]]] = {} + self.mock = MockHandler() + + def indirect_indexing(self, *args, **kwargs) -> sympy.Expr: + return super().indirect_indexing(*args, **kwargs) # type: ignore[misc] + + def store(self, *args, **kwargs) -> T: + raise NotImplementedError("store not implemented") + + def store_reduction(self, *args, **kwargs) -> T: + raise NotImplementedError("store not implemented") + + def __getattr__(self, name) -> Callable[..., Any]: + def inner(*args, **kwargs): + key = getattr(self.mock, name)(*args, **kwargs) + val = self.cse_cache.get(key) + if val is not None: + return val + + val = getattr(self._inner, name)(*args, **kwargs) + self.cse_cache[key] = val + return val + + return inner + + +def _typecheck_SimpleCSEHandler(h: SimpleCSEHandler[Any]) -> OpsHandler[Any]: + return h diff --git a/lib/python3.10/site-packages/torch/_inductor/optimize_indexing.py b/lib/python3.10/site-packages/torch/_inductor/optimize_indexing.py new file mode 100644 index 0000000000000000000000000000000000000000..96bf8641f3c9a62b3c61fe769132717ef493cf7f --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/optimize_indexing.py @@ -0,0 +1,120 @@ +# mypy: allow-untyped-defs +import math + +import sympy + +import torch +from torch.utils._sympy.value_ranges import ValueRanges + +from .loop_body import LoopBody +from .utils import dominated_nodes + + +def val_expressable_in_32_bits(val): + if getattr(val, "is_Boolean", False): + return True + + if isinstance(val, sympy.Expr): + assert val.is_number + if val.is_Integer or val.is_Boolean: + val = int(val) + else: + val = float(val) + + # bound within mantissa + if isinstance(val, float): + return val <= (2**24) and val >= -(2**24) + + if isinstance(val, int): + iinfo = torch.iinfo(torch.int32) + return val <= iinfo.max and val >= iinfo.min + + raise TypeError(f"Unexpected value {val}") + + +def range_expressable_in_32_bits(range): + return val_expressable_in_32_bits(range.lower) and val_expressable_in_32_bits( + range.upper + ) + + +def try_to_reduce_precision(node, bounds, indirect_vars, indices, replacement_vals): + # if a downstream use of a node explicitly converts to int32, or float16/float32/float64, + # then it's precision is set for that chain of uses, and we don't need to consider those + # dominated values + def 
skip_filter(node): + return node.target == "to_dtype" and node.args[2] in ( + torch.int32, + torch.float32, + torch.float64, + ) + + # TODO - there are dominated uses whose dtype does not depend on whether + # we reduce the precision here, e.g. add(int64, int64) one of the args can be reduced to + # int32 without changing the output precision of the node. this case hasn't shown up + for dominated in dominated_nodes([node], skip_filter): + if dominated.target in ["store", "output"]: + continue + + if isinstance(dominated.target, str) and "set_indirect" in dominated.target: + idx = int(dominated.target[len("set_indirect") :]) + indirect_var = indirect_vars[idx] + + # We check that we can compute all the indices it's involved in with int32 + for index, expr in indices.items(): + if indirect_var in expr.free_symbols: + index_val = replacement_vals[index] + + if math.isinf(index_val.lower) or math.isinf(index_val.upper): + return + + # all indices are integers, so make sure that we + # use the bounds of integers instead of floats. + # TODO - not sure if we should be doing int/float casts while tracing, + # might interfere with sympy. + + index_val_int = ValueRanges[sympy.Expr]( + int(index_val.lower), int(index_val.upper) + ) + if not range_expressable_in_32_bits(index_val_int): + return + + if not range_expressable_in_32_bits(bounds[dominated]): + return + + args = list(node.args) + args[2] = torch.int32 + node.args = tuple(args) + + +def indexing_dtype_strength_reduction(loop_body: LoopBody): + """ + Performs Value Range Analysis on LoopBody's fx graph to reduce precision of + intermediaries from int64 to int32 + """ + bv = loop_body.bounds() + + int64_dtype_nodes = [ + node + for node in loop_body.get_nodes() + if ( + node.target == "to_dtype" + and node.args[2] == torch.int64 + and node not in bv.unbounded_vars + ) + ] + if not int64_dtype_nodes: + return + + bounds = bv.get_bounds() + + # TODO - if dominated node of one to_dtype is not expressible in int32, + # we should short circuit another to_dtype node if that node also dominates + for node in int64_dtype_nodes: + try_to_reduce_precision( + node, + bounds, + loop_body.indirect_vars, + loop_body.indexing_exprs, + bv.replacement_vals, + ) diff --git a/lib/python3.10/site-packages/torch/_inductor/pattern_matcher.py b/lib/python3.10/site-packages/torch/_inductor/pattern_matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..e43d37fd37b1a757949406ce5e906b8b08cf719b --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/pattern_matcher.py @@ -0,0 +1,2005 @@ +# mypy: allow-untyped-decorators +""" +# Inductor Pattern Matcher + +The pattern matcher enables search/replace within an FX graph. + +The main entrypoint to the pattern matcher is register_replacement(). Given a +search function and a replacement function this will register a replacement with +a pass (such as torch._inductor.fx_passes.joint_graph.patterns). + +Internally the pattern matcher represents patterns as a graph (a DAG). Creating +new patterns manually as a graph is cumbersome and error-prone so the standard +way to create patterns (using register_replacement()) is to provide a search +function and a replacement function which is traced and converted into a graph. + +Because the search functions are built somewhat generic (they tend to ignore +tensor sizes, for example) register_replacement() allows you to specify an +`extra_check` function which performs additional checks to verify that the +matched pattern fully matches before returning it. 
+ +## Precompiled Patterns + +New patterns are added using register_replacement(). Patterns added in this way +can have a compile-time overhead because they need to be traced before +use. Patterns can be precompiled and added using gen_register_replacement() +instead. To do this you call gen_register_replacement() instead of +register_replacement(). The arguments are the same except for an additional +unique name which is used as a lookup key. + +## Internals + +The match DAG is represented by a graph of `PatternExpr` nodes. Each PatternExpr +implements a `_match` method which returns either a `Match` object for a +successful match or a `FailedMatch` object for a failure to match. +""" + +from __future__ import annotations + +import contextlib +import dataclasses +import functools +import importlib +import inspect +import itertools +import logging +import operator +import os +import re +import textwrap +import typing +from abc import ABC, abstractmethod +from collections import defaultdict +from pathlib import Path +from typing import ( + Any, + Callable, + DefaultDict, + Dict, + Generator, + Iterable, + List, + Mapping, + NoReturn, + Optional, + Protocol, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, +) +from typing_extensions import Self, TypeGuard + +import torch +import torch._guards +import torch.fx +import torch.utils._pytree as pytree +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.utils import counters +from torch._inductor.config import trace as trace_config +from torch._prims_common import is_integer_dtype +from torch._subclasses.fake_tensor import unset_fake_temporarily +from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.experimental.symbolic_shapes import guard_size_oblivious +from torch.fx.immutable_collections import immutable_dict, immutable_list +from torch.fx.passes.graph_transform_observer import GraphTransformObserver + +from .._functorch import config as functorch_config +from .._functorch.aot_autograd import aot_function, make_boxed_func +from .._functorch.partitioners import default_partition +from .._subclasses import FakeTensor, FakeTensorMode +from ..fx import Transformer +from . import config +from .decomposition import select_decomp_table +from .lowering import fallback_node_due_to_unsupported_type + + +log = logging.getLogger(__name__) +aten = torch.ops.aten +prims = torch.ops.prims + +Constant = Any +NodeOrConstant = Union[Constant, torch.fx.Node] + + +class SearchFn(Protocol): + __name__: str + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + ... + + +class ReplaceFn(Protocol): + def __call__(self, *args: Any, **kwargs: Any) -> Any: + ... + + +class TraceFn(Protocol): + def __call__( + self, fn: Union[SearchFn, ReplaceFn], *args: Any, **kwargs: Any + ) -> torch.fx.GraphModule: + ... + + +T = TypeVar("T") + +# What's a better name for this? +FnsType = Union[torch.fx.node.Target, str] + + +class Multiple: + def __init__(self) -> None: + # Ensure we're really a singleton. + assert "MULTIPLE" not in globals() or self is MULTIPLE + + +# Sentinel indicating multiple quantities can be matched +MULTIPLE = Multiple() + + +class Match: + """ + Represents a successfully matched pattern. + + The `Match` object is returned to represent a successfully matched + pattern. Included in the Match are the pattern that was matched, the graph + nodes matched, and any args that were used during the matching. 
+ + The args and kwargs are specific to the type of pattern that was matched and + provide hints about what was matched. + """ + + pattern: PatternExpr + args: List[Any] + kwargs: Dict[str, Any] + nodes: List[torch.fx.Node] + targets: Dict[_TargetExpr, torch.fx.node.Target] + ctx: MatchContext + replacement_graph: Optional[torch.fx.Graph] + + def __init__( + self, + ctx: MatchContext, + pattern: PatternExpr, + args: Optional[Sequence[Any]] = None, + kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + self.pattern = pattern + # The input nodes that must be passed in to the result + self.args = list(args or []) + self.kwargs = kwargs or {} + # The nodes matched in this expression + self.nodes = [] + # Mapping CallFunction to the node.target + self.targets = {} + self.ctx = ctx + self.replacement_graph = None + + @property + def graph(self) -> torch.fx.Graph: + return self.ctx.graph + + def extend(self, other: Match) -> None: + if self.kwargs: + for key in set(self.kwargs.keys()) & set(other.kwargs.keys()): + if self.kwargs[key] != other.kwargs[key]: + raise FailedMatch("kwarg mismatch: {}", key) + self.args.extend(other.args) + self.nodes.extend(other.nodes) + self.kwargs.update(other.kwargs) + self.targets.update(other.targets) + + def bundle(self) -> Match: + # Wrap args in an extra list + self.args = [tuple(self.args)] if self.args else [] + return self + + def __repr__(self) -> str: + return f"Match(..., {self.args}, {self.kwargs})" + + def erase_nodes(self) -> None: + graph = self.graph + for n in reversed(self.nodes): + if not n._erased and not n.users: + graph.erase_node(n) + + def output_nodes(self) -> List[Optional[torch.fx.Node]]: + return [ + (self.ctx.pattern_to_node[p] if p is not None else None) + for p in self.ctx.outputs + ] + + def output_node(self) -> torch.fx.Node: + return next(p for p in self.output_nodes() if p) + + def replace_with_graph( + self, replacement_graph: torch.fx.Graph, args: Sequence[Any] + ) -> None: + ReplacementPatternEntry.replace_with_graph( + self, self.ctx.graph, replacement_graph, args + ) + + def replace_by_example( + self, + replacement_fn: ReplaceFn, + args: Sequence[Any], + trace_fn: Optional[TraceFn] = None, + run_functional_passes: bool = True, + ) -> None: + """Replace with a graph generated by tracing the replacement_fn. + + Args: + run_functional_passes (bool). If we should run passes that + assume functional IR (like DCE, remove_noop_ops), on the + replacement graph. + + """ + from torch._inductor.virtualized import NullHandler, V + + context = ( + V.fake_mode + if (not isinstance(V.fake_mode, NullHandler) or (V.fake_mode is None)) + else contextlib.nullcontext() + ) + + with context: + if trace_fn is None: + trace_fn = functools.partial( + fwd_only, run_functional_passes=run_functional_passes + ) + replacement = trace_fn( + replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"]) # type: ignore[arg-type] + ) + ReplacementPatternEntry.replace_with_graph( + self, + self.ctx.graph, + replacement, + args, + ) + + +class FailedMatch(RuntimeError): + """ + Represents a unsuccessful match. + + The `FailedMatch` object is returned to represent a failure to match a + pattern. + """ + + format_string: str + + def __init__(self, format_string: str, *args: Any, **kwargs: Any) -> None: + self.format_string = format_string + # We want to construct error messages lazily instead of eagerly, as + # constructing them eagerly can significantly worsen compile times. 
+ if len(format_string) > 200: + raise RuntimeError( + f"Format string too long - use lazy construction of strings instead. Format string is\n {format_string}" + ) + self.args = args + self.kwargs = kwargs + + def __str__(self) -> str: + return self.format_string.format(*self.args, **self.kwargs) + + def __bool__(self) -> bool: + return False + + +MatchResult = Union[Match, FailedMatch] + + +def is_match(m: MatchResult) -> TypeGuard[Match]: + """ + TypeGuards cannot act on `self`. Thus this function exists to let mypy + recognize FailedMatch.__bool__ as a TypeGuard. + """ + return bool(m) + + +class MatchContext: + """ + Internal state needed while running PatternExpr._match(). + """ + + outputs: List[Optional[PatternExpr]] + pattern_to_node: Dict[PatternExpr, Optional[torch.fx.Node]] + graph: torch.fx.Graph + exclusive_node_set: List[NodeOrConstant] + + def __init__( + self, + outputs: List[Optional[PatternExpr]], + pattern_to_node: Optional[Dict[PatternExpr, torch.fx.Node]] = None, + *, + graph: torch.fx.Graph, + ) -> None: + self.outputs = outputs + self.pattern_to_node = {} if pattern_to_node is None else dict(pattern_to_node) + self.graph = graph + self.exclusive_node_set = [] + + def match(self, pattern: PatternExpr, node: NodeOrConstant) -> MatchResult: + """wrapper to check reused nodes in patterns""" + if pattern in self.pattern_to_node: + if self.pattern_to_node[pattern] == node: + return Match(self, pattern) # already checked this node + else: + return FailedMatch("repeated pattern differs") + m = pattern._match(node, self) + assert pattern not in self.pattern_to_node + self.pattern_to_node[pattern] = node if m else None + return m + + def filter_multi_user_patterns(self) -> Dict[PatternExpr, torch.fx.Node]: + return { + pattern: node + for pattern, node in self.pattern_to_node.items() + if pattern.has_multiple_users() and node is not None + } + + +class PatternExpr(ABC): + """ + Base class for types of patterns. + """ + + @abstractmethod + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: + ... + + def match(self, node: torch.fx.Node) -> MatchResult: + try: + return MatchContext([self], graph=node.graph).match(self, node) + except FailedMatch as e: + return e + + def has_multiple_users(self) -> bool: + return False + + def __repr__(self) -> str: + return self.__class__.__name__ + "()" + + def find_anchor_nodes( + self, ctx: MatchContext, searched: Set[torch.fx.Node] + ) -> Generator[Optional[torch.fx.Node], None, None]: + if self in ctx.pattern_to_node: + yield ctx.pattern_to_node[self] + + def pattern_eq(self, other: Any) -> bool: + """ + Compare two `PatternExpr`s and return true if they are the + same. Note this is NOT matching a pattern - it is comparing the pattern + structures (for debugging). + """ + return isinstance(other, self.__class__) + + +class Arg(PatternExpr): + """ + Capture an arg which will become an input to the handler. Args are + passed in depth first order. + """ + + def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: + return Match(ctx, self, args=[node]) # matches anything + + +class Ignored(PatternExpr): + """ + Match an arg, but don't pass it to handler + """ + + def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: + return Match(ctx, self) # matches anything + + def __repr__(self) -> str: + return "*" + + def pretty_print(self, pp: PatternPrettyPrinter) -> str: + return "Ignored()" + + +class KeywordArg(PatternExpr): + """ + Capture a kwarg which will become an input to the handler. 
+ """ + + def __init__(self, name: str) -> None: + super().__init__() + self.name = name + + def __repr__(self) -> str: + return f"KeywordArg({self.name!r})" + + def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: + return Match(ctx, self, kwargs={self.name: node}) # matches anything + + def pattern_eq(self, other: Any) -> bool: + other = typing.cast(Self, other) # super makes sure this is true + return super().pattern_eq(other) and self.name == other.name + + +class ExclusiveKeywordArg(PatternExpr): + """ + Capture a kwarg which will become an input to the handler. + """ + + name: str + + def __init__(self, name: str) -> None: + super().__init__() + self.name = name + + def __repr__(self) -> str: + return f"ExclusiveKeywordArg({self.name!r})" + + def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: + if node in ctx.exclusive_node_set: + return FailedMatch("exclusive arg appears twice") + + ctx.exclusive_node_set.append(node) + return Match(ctx, self, kwargs={self.name: node}) # matches anything + + def pattern_eq(self, other: Any) -> bool: + other = typing.cast(Self, other) # super makes sure this is true + return super().pattern_eq(other) and self.name == other.name + + +class _TargetExpr(PatternExpr): + """ + Base class for filtering match by node.target + """ + + fns: List[FnsType] + fns_set: Set[FnsType] + + def __init__( + self, fns: Union[FnsType, Sequence[FnsType]], users: Union[Multiple, int] = 1 + ) -> None: + super().__init__() + fns = [fns] if callable(fns) or isinstance(fns, str) else list(fns) + for fn in fns: + if isinstance(fn, torch._ops.OpOverloadPacket): + fns.extend(getattr(fn, overload) for overload in fn.overloads()) + + self.fns = fns + self.fns_set = set(fns) + self.users = users + + @property + @abstractmethod + def op(self) -> str: + ... + + def fns_repr(self) -> str: + first_repr = self.fns[0] + if not isinstance(first_repr, str): + first_repr = first_repr.__name__ + + if len(self.fns) > 1: + return f"[{first_repr}, ...]" + elif self.fns[0] is getattr(torch, first_repr, None): + return f"torch.{first_repr}" + elif isinstance(self.fns[0], torch._ops.OpOverload): + return str(self.fns[0]) + else: + return first_repr + + def __repr__(self) -> str: + if self.users is MULTIPLE: + comma_users = ", MULTIPLE" + elif self.users != 1: + comma_users = f", {self.users})" + else: + comma_users = "" + return f"{self.__class__.__name__}({self.fns_repr()}{comma_users})" + + def has_multiple_users(self) -> bool: + return isinstance(self.users, Multiple) or self.users > 1 + + def find_anchor_nodes( + self, ctx: MatchContext, searched: Set[torch.fx.Node] + ) -> Generator[Optional[torch.fx.Node], None, None]: + raise NotImplementedError + + def _match_fns(self, node: torch.fx.Node) -> bool: + return ( + isinstance(node, torch.fx.Node) + and node.op == self.op + and extract_target(node) in self.fns_set + ) + + def _match_users(self, node: torch.fx.Node, ctx: MatchContext) -> bool: + return ( + self in ctx.outputs + or self.users is MULTIPLE + or len(node.users) == self.users + ) + + def pattern_eq(self, other: Any) -> bool: + other = typing.cast(Self, other) # super makes sure this is true + return ( + super().pattern_eq(other) + and self.op == other.op + and self.fns == other.fns + and self.users == other.users + ) + + +_SimpleSpec = Tuple[Any, ...] 
+ + +class _TargetArgsExpr(_TargetExpr): + """ + Base class for filtering match by node.{target,args,kwargs} + """ + + def __init__( + self, + fns: Union[torch.fx.node.Target, str, Sequence[Any]], + *args: Any, + _users: Union[int, Multiple] = 1, + **kwargs: Any, + ) -> None: + super().__init__(fns, _users) + self.args = tuple(args) + self.kwargs = dict(kwargs) + if any( + isinstance(x, (dict, list, tuple)) + for x in itertools.chain(args, kwargs.values()) + ): + self.flatten = self.pytree_flatten + else: + self.flatten = self.simple_flatten + self.flat_args_kwargs = self.flatten(self.args, self.kwargs) + + @staticmethod + def simple_flatten( + args: Sequence[Any], kwargs: Mapping[Any, Any] + ) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]: + values = (*args, *kwargs.values()) + spec = (len(args), *kwargs.keys()) + return values, spec + + @staticmethod + def pytree_flatten( + args: Sequence[Any], kwargs: Mapping[Any, Any] + ) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]: + def norm_spec(s: pytree.TreeSpec) -> pytree.TreeSpec: + if s.type is None: + return s + mapping = {immutable_list: list, tuple: list, immutable_dict: dict} + return pytree.TreeSpec( + mapping.get(s.type, s.type), + s.context, + list(map(norm_spec, s.children_specs)), + ) + + flat, spec = pytree.tree_flatten([args, kwargs]) + spec = norm_spec(spec) + return flat, spec + + def __repr__(self) -> str: + args = [ + self.fns_repr(), + *map(repr, self.args), + *[f"{k}={v}" for k, v in self.kwargs.items()], + ] + if self.users is MULTIPLE: + args.append("_users=MULTIPLE") + elif self.users != 1: + args.append(f"_users={self.users}") + return f"{self.__class__.__name__}({', '.join(args)})" + + def pretty_print(self, pp: PatternPrettyPrinter) -> str: + args = [ + self.fns_repr(), + *(pp.pretty_print(x) for x in self.args), + *[f"{k}={pp.pretty_print(v)}" for k, v in self.kwargs.items()], + ] + if self.users is MULTIPLE: + args.append("_users=MULTIPLE") + elif self.users != 1: + args.append(f"_users={self.users}") + + joiner_str = ", " + return f"{self.__class__.__name__}({joiner_str.join(args)})" + + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: + if not self._match_fns(node) or len(node.args) != len(self.args): + return FailedMatch("function_mismatch: node={}, pattern={}", node, self) + + if not self._match_users(node, ctx): + return FailedMatch("multiple_users {}", self) + + _args = node.args + _kwargs = node.kwargs + if len(_kwargs) < len(self.kwargs): + from torch.fx.operator_schemas import normalize_function + + normalized_args_and_kwargs = normalize_function( + node.target, node.args, node.kwargs # type: ignore[arg-type] + ) + + if normalized_args_and_kwargs is None: + return FailedMatch("function_mismatch: node={}, pattern={}", node, self) + else: + _args, _kwargs = normalized_args_and_kwargs + if len(_args) == len(self.args) and len(_kwargs) >= len(self.kwargs): + _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs} + else: + return FailedMatch( + "function_mismatch: node={}, pattern={}", node, self + ) + else: + _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs} + + node_items, node_spec = self.flatten(_args, _kwargs) + self_items, self_spec = self.flat_args_kwargs + if node_spec != self_spec: + return FailedMatch("args_structure {} {}", node_spec, self_spec) + assert len(node_items) == len(self_items) + + m = Match(ctx, self) + for i, pattern, child_node in zip(itertools.count(), self_items, node_items): + if isinstance(pattern, PatternExpr): 
+ child_match = ctx.match(pattern, child_node) + if not is_match(child_match): + return child_match + m.extend(child_match) + elif isinstance(child_node, torch.fx.Node) or child_node != pattern: + return FailedMatch( + "constant_args: {} {!r}!={pattern!r}", node, child_node + ) + m.nodes.append(node) + m.targets[self] = node.target + return m + + def find_anchor_nodes( + self, ctx: MatchContext, searched: Set[torch.fx.Node] + ) -> Generator[Optional[torch.fx.Node], None, None]: + """ + This is used when we are matching a pattern with multiple outputs. + There is a partial match (stored in ctx) and we want to walk + this pattern to find a connection to an already-matched node. + + Yields candidate nodes that `self._match` might like. + """ + if self in ctx.pattern_to_node: + yield ctx.pattern_to_node[self] + return + + for pattern in self.flat_args_kwargs[0]: + if isinstance(pattern, PatternExpr): + for other_node in pattern.find_anchor_nodes(ctx, searched): + if not isinstance(other_node, torch.fx.Node): + continue + for node in other_node.users: + if node not in searched: + if self._match_fns(node): + yield node + searched.add(node) + + def pattern_eq(self, other: Any) -> bool: + other = typing.cast(Self, other) # super makes sure this is true + return ( + super().pattern_eq(other) + and self.flat_args_kwargs[1] == other.flat_args_kwargs[1] + and all( + a.pattern_eq(b) if isinstance(a, PatternExpr) else a == b + for a, b in zip(self.flat_args_kwargs[0], other.flat_args_kwargs[0]) + ) + ) + + +class CallFunction(_TargetArgsExpr): + """ + Matches a call_function node in the FX graphs: `fns[i](*args, **kwargs)` + """ + + op = "call_function" + + +class CallMethod(_TargetArgsExpr): + """ + Matches a call_method node in the FX graphs: `fns[i].method(*args, **kwargs)` + """ + + op = "call_method" + + +class CallModule(_TargetArgsExpr): + """ + Matches a call_module node in the FX graphs: `module(*args, **kwargs)` + """ + + op = "call_module" + + +class _TargetExprVarArgs(_TargetExpr): + """ + Matches a call_function node with any arguments which are passed into the pattern + """ + + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: + if not self._match_fns(node): + return FailedMatch("function_mismatch") + + if not self._match_users(node, ctx): + return FailedMatch("multiple_users") + + m = Match(ctx, self) + m.nodes.append(node) + m.targets[self] = node.target + m.args.extend(node.args) + m.kwargs.update(node.kwargs) + return m + + +class CallFunctionVarArgs(_TargetExprVarArgs): + op = "call_function" + + +class CallMethodVarArgs(_TargetExprVarArgs): + op = "call_method" + + +class CallModuleVarArgs(_TargetExprVarArgs): + op = "call_module" + + +class ListOf(PatternExpr): + """ + Matches a repeated pattern + """ + + def __init__(self, pattern: PatternExpr, partial: bool = False) -> None: + super().__init__() + assert isinstance(pattern, PatternExpr) + self.pattern = pattern + self.partial = partial + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.pattern})" + + def _match(self, node: List[torch.fx.Node], ctx: MatchContext) -> MatchResult: # type: ignore[override] + if not isinstance(node, (list, tuple)) or len(node) == 0: + return FailedMatch("non_list") + m = Match(ctx, self) + # Propagating patterns with multiple users will ensure we don't revisit + # the same nodes + pattern_to_node = ctx.filter_multi_user_patterns() + matched = False + for i, child_node in enumerate(node): + child_ctx = MatchContext( + ctx.outputs, pattern_to_node, 
graph=child_node.graph + ) + child_match = child_ctx.match(self.pattern, child_node) + pattern_to_node = child_ctx.filter_multi_user_patterns() + if not is_match(child_match): + if not self.partial: + return FailedMatch("list[{}]: {}", i, child_match) + continue + matched = True + m.extend(child_match.bundle()) + if not matched: + return FailedMatch("list: no_match") + return m.bundle() + + def pattern_eq(self, other: Any) -> bool: + other = typing.cast(Self, other) # super makes sure this is true + return ( + super().pattern_eq(other) + and self.pattern.pattern_eq(other.pattern) + and self.partial == other.partial + ) + + +class MultiOutputPattern(PatternExpr): + outputs: List[Optional[PatternExpr]] + + def __init__(self, outputs: Sequence[Optional[PatternExpr]]) -> None: + super().__init__() + assert isinstance(outputs[0], _TargetExpr) + assert all(x is None or isinstance(x, PatternExpr) for x in outputs), outputs + self.outputs = list(outputs) + self.op = outputs[0].op + + @property + def fns(self) -> Union[Callable[..., Any], str, Sequence[Any]]: + # This cast is checked above in __init__() + output = typing.cast(_TargetExpr, self.outputs[0]) + return output.fns + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.outputs})" + + def pretty_print(self, pp: PatternPrettyPrinter) -> str: + args = [pp.pretty_print(x) for x in self.outputs] + joiner_str = f",\n{' '}" + str_out = f"{self.__class__.__name__}([{joiner_str.join(args)}" + str_out = f"{str_out}\n])" + return str_out + + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: + output = typing.cast(_TargetExpr, self.outputs[0]) + m = ctx.match(output, node) + if not is_match(m): + return m + + for pattern in self.outputs[1:]: + if pattern is None: + continue + child_match = self._match_from_anchors(pattern, ctx) + if not is_match(child_match): + return child_match + m.extend(child_match) + + return m + + def _match_from_anchors( + self, pattern: PatternExpr, ctx: MatchContext + ) -> MatchResult: + prior = dict(ctx.pattern_to_node) + m: MatchResult = FailedMatch("no anchor found") + for node in pattern.find_anchor_nodes(ctx, set()): + m = ctx.match(pattern, node) + if is_match(m): + return m + # revert any partial matches + ctx.pattern_to_node = dict(prior) + return m + + def match(self, node: torch.fx.Node) -> MatchResult: + try: + return MatchContext(self.outputs, graph=node.graph).match(self, node) + except FailedMatch as e: + return e + + def pattern_eq(self, other: Any) -> bool: + other = typing.cast(Self, other) # super makes sure this is true + return ( + super().pattern_eq(other) + and len(self.outputs) == len(other.outputs) + and all( + a.pattern_eq(b) if isinstance(a, PatternExpr) else a == b + for a, b in zip(self.outputs, other.outputs) + ) + ) + + +class RepeatedExpr(PatternExpr): + """ + Checks for a repeated pattern. 
Useful for repeated operations after a node such as `split` or `unbind` + """ + + def __init__(self, inner_pattern: _TargetExpr) -> None: + super().__init__() + self.inner_pattern = inner_pattern + self.op = inner_pattern.op + + @property + def fns(self) -> Sequence[FnsType]: + return self.inner_pattern.fns + + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: + m = ctx.match(self.inner_pattern, node) + if not is_match(m): + return m + ctx.pattern_to_node.pop( + self.inner_pattern, + ) + # Check all anchor nodes match the pattern + for anchor_node in self.inner_pattern.find_anchor_nodes(ctx, set()): + anchor_m = MatchContext([self], graph=node.graph).match( + self.inner_pattern, anchor_node + ) + if not is_match(anchor_m): + return anchor_m + m.extend(anchor_m) + return m + + def pattern_eq(self, other: Any) -> bool: + other = typing.cast(Self, other) # super makes sure this is true + return super().pattern_eq(other) and self.inner_pattern.pattern_eq( + other.inner_pattern + ) + + +class PatternPrettyPrinter: + """ + Serializes Patterns to executable python. + XXX: currently only used and tested for fuse attention patterns. May not cover + all patterns. + """ + + def __init__(self) -> None: + self.namespace = torch.fx.graph._Namespace() + self.memoized_objs_names: Dict[PatternExpr, str] = {} + self.memoized_objs_pp: Dict[PatternExpr, str] = {} + + @staticmethod + @functools.lru_cache(None) + def run(obj: PatternExpr, output_name: str = "output") -> str: + """ + Serializes obj to python code with obj written out to `output_name` + """ + + pp = PatternPrettyPrinter() + assert hasattr(obj, "pretty_print") + out_str = obj.pretty_print(pp=pp) + + output = [] + for key in pp.memoized_objs_names: + output.append(f"{pp.memoized_objs_names[key]} = {pp.memoized_objs_pp[key]}") + + output.append(f"{output_name} = {out_str}") + + return "\n".join(output) + + def pretty_print(self, obj: Any) -> str: + if isinstance(obj, _TargetArgsExpr): + if memoized_name := self.memoized_objs_names.get(obj): + return memoized_name + else: + return self.memoize(obj) + if hasattr(obj, "pretty_print"): + return obj.pretty_print(self) + + return repr(obj) + + def memoize(self, obj: _TargetArgsExpr) -> str: + obj_str = obj.pretty_print(self) + obj_name = obj.fns_repr() + for prefix in ("aten.", "torch.", "prims."): + obj_name = obj_name.replace(prefix, "") + + tmp_name = self.namespace.create_name(obj_name, None) + self.memoized_objs_names[obj] = tmp_name + self.memoized_objs_pp[obj] = obj_str + return tmp_name + + +class _PassDictsType(Protocol): + def __getitem__(self, k: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]: + ... 
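+ # Usage sketch: the expression classes above compose into patterns, e.g.
+ #
+ #     CallFunction(torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size"))
+ #
+ # and `pattern.match(node)` returns a Match on success, or a FailedMatch
+ # describing why the node did not match. The entries below pair such a
+ # pattern with the action to take once it matches.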
+ + +@dataclasses.dataclass +class PatternEntry: + pattern: PatternExpr + extra_check: Callable[[Match], bool] + + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: + raise NotImplementedError + + def register( + self, + pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]], + target: Union[torch.fx.node.Target, None] = None, + prepend: bool = False, + ) -> None: + if target is None: + assert hasattr(self.pattern, "fns") + for fn in self.pattern.fns: + self.register(pass_dicts, fn, prepend=prepend) + elif isinstance(pass_dicts, (dict, PatternMatcherPass)): + assert hasattr(self.pattern, "op") + if prepend: + pass_dicts[(self.pattern.op, target)].insert(0, self) + else: + pass_dicts[(self.pattern.op, target)].append(self) + else: + pass_dicts = typing.cast(Sequence[_PassDictsType], pass_dicts) + for x in pass_dicts: + self.register(x, target, prepend=prepend) + + +@dataclasses.dataclass +class LoweringPatternEntry(PatternEntry): + handler: Callable[..., Any] + + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: + handler = functools.wraps(self.handler)(functools.partial(self.handler, match)) + with graph.inserting_before(node): + replacement = graph.call_function(handler, tuple(match.args), match.kwargs) + replacement.meta.update(node.meta) + node.replace_all_uses_with(replacement) + assert match.nodes[-1] is node + match.erase_nodes() + + +@dataclasses.dataclass +class GraphPatternEntry(PatternEntry): + """ + A pattern that runs a function on the FX graph + """ + + handler: Callable[..., Any] + + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: + with graph.inserting_before(node): + self.handler(match, *match.args, **match.kwargs) + + +@dataclasses.dataclass +class ReplacementPatternEntry(PatternEntry): + normalize_args: Callable[..., List[Any]] + + @staticmethod + def replace_with_graph( + match: Match, + graph: torch.fx.Graph, + replacement_graph: Union[torch.fx.Graph, torch.fx.GraphModule], + args: Sequence[torch.fx.Node], + ) -> None: + class Replacer(torch.fx.Interpreter): + call_method = None # type: ignore[assignment] + call_module = None # type: ignore[assignment] + get_attr = None # type: ignore[assignment] + + def run_node(self, node: torch.fx.Node) -> Any: + if node.op in ("placeholder", "output"): + return super().run_node(node) + if node.op == "call_function": + target = node.target + args, kwargs = self.fetch_args_kwargs_from_env(node) + result = graph.call_function(target, args, kwargs) # type: ignore[arg-type] + if "val" in node.meta and "val" not in result.meta: + result.meta["val"] = node.meta["val"] + if isinstance(node.meta["val"], torch.Tensor): + assert "tensor_meta" in node.meta + result.meta["tensor_meta"] = node.meta["tensor_meta"] + return result + raise NotImplementedError(f"unhandled {node}") + + output_nodes = match.output_nodes() + + if len(output_nodes) == 1: + last_node = output_nodes[0] + else: + assert output_nodes[0] + nodes = list(output_nodes[0].graph.nodes) + indices = [ + (nodes.index(n), n) + for n in output_nodes + if isinstance(n, torch.fx.Node) + ] + last_node = min(indices, key=operator.itemgetter(0))[1] + + def percolate_tags( + node: torch.fx.Node, + tag_name: str, + tag_value: str, + input_stops: Set[torch.fx.Node], + ) -> None: + queue = [node] + visited = set() + + while queue: + arg = queue.pop() + if ( + arg not in visited + and arg not in input_stops + and hasattr(arg, "meta") + ): + visited.add(arg) + arg.meta[tag_name] = tag_value + 
queue.extend(arg.all_input_nodes) + + with graph.inserting_before(last_node): + replacement = Replacer(replacement_graph).run(*args) # type: ignore[arg-type] + if isinstance(replacement, torch.fx.Node): + replacement = [replacement] + + def maybe_getitem(node: torch.fx.Node) -> Any: + if node.op != "call_function": + return None + if node.target != operator.getitem: + return None + assert len(node.args) == 2 + return node.args[1] + + def replace( + old: Union[torch.fx.Node, None], + new: Union[torch.fx.Node, Sequence[torch.fx.Node], None], + ) -> None: + if old is None: + assert new is None + return + assert isinstance(old, torch.fx.Node) + if new is None: + old.replace_all_uses_with(None) # type: ignore[arg-type] + graph.erase_node(old) + return + if isinstance(new, torch.fx.Node): + if "val" not in new.meta: + new.meta.update(old.meta) + + # Preserve the recompute tags in the replacement graph. We + # look at the recompute tags of the original output node to + # propagate the tag from the output all the way to the input + # args (named as args in the replace_with_graph). + # Note that this is best effort. Since patterns are from + # many to many, there is no easy way to correctly map the + # recomputable tags. It is possible in some scenarios that we + # incorrectly tag some nodes as recomputables. + for tag_name in ["recompute", "ac_graph_id"]: + if tag_name in old.meta: + percolate_tags(new, tag_name, old.meta[tag_name], set(args)) + + old.replace_all_uses_with(new) + graph.erase_node(old) + return + + # `new` is not a node: it's a list of nodes. + # + # This happens when we want to replace a node that has a single + # packed return with multiple unpacked returns. We need to do + # some graph surgery here. + # + # Example: + # def original_graph(x): + # a = op(x) + # b = a[0] + # c = a[1] + # ... + # + # Assume that we want to replace op(x) with the graph + # def new_op(x): + # w = x + 1 + # z = x + 2 + # return (w, z) + # + # We need to replace `op` with the contents of `new_op`, + # and then rewrite a[0] to be w and a[1] to be z, as so: + # def new_graph(x): + # w = x + 1 + # z = x + 2 + # b = w + # c = z + # ... 
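+ # Rewrite each getitem user of `old` to point at the matching element
+ # of `new`, then drop `old` itself.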
+ old_uses = list(old.users.keys()) + for user in old_uses: + idx = maybe_getitem(user) + if idx is None: + raise AssertionError("can't handle") + replace(user, new[idx]) # type: ignore[index] + graph.erase_node(old) + + if len(output_nodes) == len(replacement): + for old, new in zip(output_nodes, replacement): + replace(old, new) + else: + assert len(output_nodes) == 1 + replace(output_nodes[0], replacement) + + match.erase_nodes() + + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: + assert match.replacement_graph is not None + self.replace_with_graph( + match, + graph, + match.replacement_graph, + self.normalize_args(*match.args, **match.kwargs), + ) + + +def _return_true(match: Match) -> bool: + return True + + +def log_trace_failure(search_fn: Callable[..., Any], e: RuntimeError) -> None: + log.info( + "Replacement pattern %s failed to apply due to shape mismatch: %s", + search_fn.__name__, + e, + ) + + +def register_replacement( + search_fn: SearchFn, + replace_fn: ReplaceFn, + example_inputs: Iterable[Any], + trace_fn: TraceFn, + pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]], + extra_check: Callable[[Match], bool] = _return_true, + scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), + search_fn_pattern: Union[PatternExpr, None] = None, +) -> bool: + """ + Create a replacement rule based on example functions that get traced + to create patterns. This supports both training and inference when + run on a joint forward+backward graph. + + Args: + search_fn: traced to give original pattern + replace_fn: traced to give replacement graph + example_inputs: example inputs for initial trace + trace_fn: fwd_only or joint_fwd_bwd + pass_dict: dict of passes to register to + extra_check: additional check to run on match(using real shapes) + """ + argnames_static = [*inspect.signature(search_fn).parameters.keys()] + + def check_fn(match: Match) -> bool: + """ + Often shapes get burned into the pattern, so our initial match ran with + `ignore_types=(int, ...)`. + + Recheck the match with the correct shapes. + """ + argnames = list(argnames_static) + for name in argnames: + if name not in match.kwargs: + raise RuntimeError( + f"Not all inputs to pattern found in match.kwargs. Perhaps one " + f"of the inputs is unused? argnames={argnames}, match.kwargs={match.kwargs}" + ) + + args = list( + torch.fx.map_arg( # type: ignore[arg-type] + [match.kwargs[name] for name in argnames], lambda n: n.meta["val"] + ) + ) + sym_args: List[torch.SymInt] = [] + with torch._dynamo.utils.detect_fake_mode(args): + for i, grad in enumerate(requires_grad): + if isinstance(args[i], torch.Tensor): + if grad and is_integer_dtype(args[i].dtype): + return False + + args[i] = torch.empty_strided( + args[i].size(), + args[i].stride(), + dtype=args[i].dtype, + device=args[i].device, + requires_grad=grad, + ) + for v in itertools.chain(args[i].shape, args[i].stride()): + if isinstance(v, torch.SymInt) and all( + guard_size_oblivious(v != a) for a in sym_args + ): + sym_args.append(v) + + # If we were given a pre-traced pattern then use that instead of + # retracing. Note that this means the pattern has to be independent + # of its args. 
+ specific_pattern = search_fn_pattern + + if not specific_pattern: + if sym_args: + # AOT Autograd and make fx will dedupe symbolic shape size + # accesses of sym ints that appear as inputs + # We don't want the sym_size uses to interfere with pattern matching + # so we provide them as inputs. + # Later, when we actually do the replacement, the symbolic shape + # sizes will get re-traced and added to the graph. + + def search_fn_new(*args_new: Any) -> Any: + return search_fn(*args_new[len(args_new) - len(args) :]) + + try: + specific_graph = trace_fn(search_fn_new, sym_args + args) + except RuntimeError as e: + log_trace_failure(search_fn, e) + return False + + # correct argnames in the graph + sym_arg_names = [] + for i, placeholder in zip( + range(len(sym_args) + len(args)), + specific_graph.graph.nodes, + ): + if i < len(sym_args): + sym_arg_names.append(placeholder.target) + continue + + with specific_graph.graph.inserting_after(placeholder): + new_node = specific_graph.graph.placeholder( + argnames[i - len(sym_args)] + ) + new_node.target = new_node.name + placeholder.replace_all_uses_with(new_node) + specific_graph.graph.erase_node(placeholder) + + argnames = sym_arg_names + argnames + else: + try: + specific_graph = trace_fn(search_fn, args) + except RuntimeError as e: + log_trace_failure(search_fn, e) + return False + + specific_pattern = fx_to_pattern( + specific_graph, + argnames=argnames, + exclusive_arg_names=exclusive_arg_names, + scalar_workaround=scalar_workaround, + ) + + node = match.output_nodes()[0] + assert node is not None + specific_pattern_match = specific_pattern.match(node) + + if is_match(specific_pattern_match) and extra_check(specific_pattern_match): + # trace the pattern using the shapes from the user program + match.replacement_graph = trace_fn(replace_fn, args) # type: ignore[assignment] + return True + return False + + def normalize_args(**kwargs: Any) -> List[Any]: + args = [] + for name in argnames_static: + args.append(kwargs.pop(name)) + for i in range(1, len(kwargs) + 1): + if f"tangents_{i}" not in kwargs: + break + args.append(kwargs.pop(f"tangents_{i}")) + assert not kwargs, f"leftover kwargs: {kwargs!r}" + return args + + if trace_fn is joint_fwd_bwd: + # If inference mode is enabled during compilation, assume that we don't + # want to match on any training graph patterns + if torch.is_inference_mode_enabled(): + return False + + # TODO: Revisit the functionalize_rng_ops for lowmem dropout + with functorch_config.patch(functionalize_rng_ops=False): + requires_grad: List[bool] = [ + isinstance(x, torch.Tensor) and x.requires_grad for x in example_inputs + ] + if search_fn_pattern is None: + pattern = gen_pattern( + search_fn, + example_inputs, + trace_fn, + scalar_workaround, + exclusive_arg_names, + ) + else: + pattern = search_fn_pattern + + pattern_repr = PatternPrettyPrinter.run(pattern) + assert pattern_repr not in _seen_patterns + _seen_patterns.add(pattern_repr) + pattern = ReplacementPatternEntry( + pattern=pattern, + extra_check=check_fn, + normalize_args=normalize_args, + ) + pattern.register(pass_dicts) + return pattern.pattern + + +_serialized_patterns: Set[str] = set() + + +def _serialize_pattern( + unique_name: str, + search_fn: SearchFn, + example_inputs: Iterable[Any], + trace_fn: TraceFn, + scalar_workaround: Union[Dict[str, Union[float, int]], None], +) -> PatternExpr: + def get_file_template() -> str: + auto_generated_msg = textwrap.dedent( + """\ + # This is an auto-generated file. Please do not modify it by hand. 
+ # To re-generate, run: + # cd ~/pytorch && python torchgen/fuse/gen_patterns.py + """ + ) + + file_template = textwrap.dedent( + """\ + # mypy: ignore-errors + + # noqa: F401, E501 + {msg} + import torch + import torch._inductor + + aten = torch.ops.aten + prims = torch.ops.prims + + """ + ).format(msg=auto_generated_msg) + + pattern_matcher_imports = [] + for name in dir(torch._inductor.pattern_matcher): + attr = getattr(torch._inductor.pattern_matcher, name) + if isinstance(attr, type) and issubclass(attr, (PatternExpr, _TargetExpr)): + pattern_matcher_imports.append(name) + + formatted_imports = ",\n ".join(pattern_matcher_imports) + formatted_imports = f"from torch._inductor.pattern_matcher import (\n {formatted_imports},\n)\n" + return f"{file_template}{formatted_imports}" + + if not SERIALIZED_PATTERN_PATH.is_dir(): + raise RuntimeError( + f"Could not find serialized patterns directory at {SERIALIZED_PATTERN_PATH}" + ) + + pattern_name = search_fn.__name__ + + from torch._functorch import config as functorch_config + + with functorch_config.patch(functionalize_rng_ops=False): + pattern = gen_pattern(search_fn, example_inputs, trace_fn, scalar_workaround) + + serialized_pattern = PatternPrettyPrinter.run(pattern, output_name=unique_name) + if pattern_name not in _serialized_patterns: + write_mode = "w" + _serialized_patterns.add(pattern_name) + else: + write_mode = "a" + + file_template = get_file_template() + + with open(SERIALIZED_PATTERN_PATH / f"{pattern_name}.py", write_mode) as f: + if write_mode == "w": + f.write(file_template) + else: + f.write("\n\n") + f.write(serialized_pattern) + f.write("\n") + + return pattern + + +SERIALIZED_PATTERN_PATH = Path(__file__).parent / "fx_passes" / "serialized_patterns" + +# This is the set of serialized patterns that we've registered. Used by +# test_serialized_patterns_up_to_date() to ensure the patterns are up +# to date. +_known_precompiled_patterns: List[ + Tuple[ + Any, + Iterable[Any], + Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule], + Any, + PatternExpr, + ] +] = [] + + +def gen_register_replacement( + unique_name: str, + search_fn: SearchFn, + replace_fn: ReplaceFn, + example_inputs: Iterable[Any], + trace_fn: TraceFn, + pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]], + extra_check: Callable[[Match], bool] = _return_true, + scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), + skip_duplicates: bool = False, +) -> None: + # Make sure the example_inputs is materialized. + example_inputs = tuple(example_inputs) + + if "PYTORCH_GEN_PATTERNS" in os.environ: + pat = _serialize_pattern( + unique_name, search_fn, example_inputs, trace_fn, scalar_workaround + ) + else: + pattern_name = search_fn.__name__ + m = importlib.import_module( + f"torch._inductor.fx_passes.serialized_patterns.{pattern_name}" + ) + if not m or not hasattr(m, unique_name): + log.warning( + "Precompiled pattern %r not found. Run torchgen/fuse/gen_patterns.py.", + unique_name, + ) + pat = getattr(m, unique_name) + + for arg in pytree.tree_iter(example_inputs): + if isinstance(arg, FakeTensor) and arg.constant is not None: + # This can be a problem - small fake tensors (e.g. `tensor(2)`) will + # hold onto their original constant value - and by stashing it here + # will cause a memory leak if the constant value is on GPU. + # Since this is just an optimization we can clear it out. 
+ arg.constant = None + + if PatternPrettyPrinter.run(pat) in _seen_patterns and skip_duplicates: + return + _known_precompiled_patterns.append( + (search_fn, example_inputs, trace_fn, scalar_workaround, pat) + ) + register_replacement( + search_fn, + replace_fn, + example_inputs, + trace_fn, + pass_dicts, + extra_check, + scalar_workaround, + exclusive_arg_names, + search_fn_pattern=pat, + ) + + +@functorch_config.patch(functionalize_rng_ops=False) +def gen_pattern( + search_fn: SearchFn, + example_inputs: Sequence[Any], + trace_fn: TraceFn, + scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), +) -> PatternExpr: + argnames = [*inspect.signature(search_fn).parameters.keys()] + + if scalar_workaround is None: + scalar_workaround = {} + flat_inputs = [] + input_idx = 0 # Positional arguments index + + for argname in argnames: + if argname in scalar_workaround: + flat_inputs.append(scalar_workaround[argname]) + else: + flat_inputs.append(example_inputs[input_idx]) + input_idx += 1 + + search_gm = trace_fn(search_fn, flat_inputs) + return fx_to_pattern( + search_gm, + ignore_types=(int, float, list, torch.device, torch.dtype), + argnames=argnames, + scalar_workaround=scalar_workaround, + exclusive_arg_names=exclusive_arg_names, + ) + + +def register_lowering_pattern( + pattern: PatternExpr, + extra_check: Callable[[Match], bool] = _return_true, + *, + pass_dict: _PassDictsType, + prepend: bool = False, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """ + Register an aten to inductor IR replacement pattern. The decorated + function is saved and then called a lowering time allowing direct + pattern to inductor IR conversion. + """ + + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + assert callable(handler) + LoweringPatternEntry( + pattern=pattern, extra_check=extra_check, handler=handler + ).register(pass_dict, prepend=prepend) + handler._inductor_lowering_function = True # type: ignore[attr-defined] + return handler + + return decorator + + +def register_graph_pattern( + pattern: PatternExpr, + extra_check: Callable[[Match], bool] = _return_true, + *, + pass_dict: _PassDictsType, + prepend: bool = False, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """ + Register a pattern that runs a function on the FX graph, allowing + custom transformation code. + """ + + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + assert callable(handler) + GraphPatternEntry( + pattern=pattern, extra_check=extra_check, handler=handler + ).register(pass_dict, prepend=prepend) + return handler + + return decorator + + +def is_start_of_fx_graph(graph: torch.fx.Graph, node: torch.fx.Node) -> bool: + # first node in the graph + return node is next(iter(graph.nodes)) + + +# match: copy_, relu_, _set_grad_enabled, manual_seed, _enter_autocast, etc +# doesn't match: __rshift__, etc +_mutation_op_re = re.compile(r"(? 
<!_)(_$|_[.]|(\b|_)(set|enter|exit|seed)(\b|_))(?!_)")
+
+
+def is_mutation_op(node: torch.fx.Node) ->
bool: + if node.op == "call_function": + if _mutation_op_re.search(node.target.__name__): # type: ignore[union-attr] + return True + elif node.op == "call_method": + if _mutation_op_re.search(node.target): # type: ignore[union-attr, arg-type] + return True + return node.kwargs.get("out") is not None + + +def same_mutation_regions(a: torch.fx.Node, b: torch.fx.Node) -> bool: + assert "mutation_region_id" in a.meta + assert "mutation_region_id" in b.meta + return a.meta["mutation_region_id"] == b.meta["mutation_region_id"] + + +def get_mutation_region_id(graph: torch.fx.Graph, node: torch.fx.Node) -> int: + n = node + while "mutation_region_id" not in n.meta and not is_start_of_fx_graph(graph, n): + n = n.prev + mutation_region_id = n.meta.get("mutation_region_id", 0) + while n is not node: + n = n.next + if is_mutation_op(n): + mutation_region_id += 1 + n.meta["mutation_region_id"] = mutation_region_id + return mutation_region_id + + +def should_compute_mutation_region_ids(graph: torch.fx.GraphModule) -> bool: + return "mutation_region_id" not in next(iter(graph.nodes)).meta + + +def compute_mutation_region_ids(graph: torch.fx.GraphModule) -> None: + mutation_region_id = 0 + for nd in graph.nodes: + if is_mutation_op(nd): + mutation_region_id += 1 + nd.meta["mutation_region_id"] = mutation_region_id + + +class PatternMatcherPass: + def __init__( + self, + pass_name: Optional[str] = None, + ) -> None: + super().__init__() + self.patterns: DefaultDict[ + Tuple[str, torch.fx.node.Target], List[PatternEntry] + ] = defaultdict(list) + self.pass_name = pass_name + + def __getitem__(self, item: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]: + return self.patterns[item] + + def apply(self, gm: torch.fx.GraphModule) -> int: + if not self.patterns: + return 0 + if isinstance(gm, torch.fx.GraphModule): + graph = gm.graph + elif isinstance(gm, torch.fx.Graph): + graph = gm + gm = graph.owning_module + else: + raise RuntimeError( + f"The input to PatternMatcherPass must be a GraphModule or a Graph, but got {type(gm)}" + ) + if should_compute_mutation_region_ids(graph): # type: ignore[arg-type] + compute_mutation_region_ids(graph) # type: ignore[arg-type] + get_mutation_region_id_partial = functools.partial( + get_mutation_region_id, graph + ) + count = 0 + nodes = [] + has_call_module = False + for op, target in self.patterns: + if op == "call_module": + has_call_module = True + else: + nodes.append(graph.find_nodes(op=op, target=target, sort=False)) + if has_call_module: + nodes.append(graph.find_nodes(op="call_module", sort=False)) + pass_name = self.pass_name if self.pass_name is not None else "pattern_matcher" + with GraphTransformObserver( + gm, pass_name, trace_config.log_url_for_graph_xform + ): + for node in sorted(itertools.chain.from_iterable(nodes), reverse=True): + target = extract_target(node) + if node.op == "call_module": + if (node.op, target) not in self.patterns: + continue + + # conservatively not applying pattern for cpu input, + # since some of the patterns induce codegen and split nodes. 
+ # Note: we will only skip cpu compute if disable_cpp_codegen=True + if fallback_node_due_to_unsupported_type(node, allow_cpu_inputs=False): + continue + + for entry in self.patterns[(node.op, target)]: + if node._erased: + break + m = entry.pattern.match(node) + # pattern match crosses mutation barrier - discard + if ( + is_match(m) + and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1 # type: ignore[possibly-undefined] + ): + continue + if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name: + log.warning("%s%s %s %s", node, node.args, m, entry.pattern) + if is_match(m) and entry.extra_check(m): + count += 1 + entry.apply(m, graph, node) # type: ignore[arg-type] + counters["inductor"]["pattern_matcher_count"] += 1 + counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes) + return count + + def clear(self) -> None: + self.patterns.clear() + + +def _not_implemented(*args: Any, **kwargs: Any) -> NoReturn: + raise NotImplementedError + + +def fx_to_pattern( + gm: Union[torch.fx.GraphModule, torch.fx.Graph], + ignore_types: Sequence[Type[Any]] = (), + argnames: Sequence[str] = (), + scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), +) -> PatternExpr: + """ + Convert an FX graph into a PatternExpr. This is useful for simple + patterns that can only match single functions and fixed-length lists. + """ + # scalar_workaround is a hack to capture dropout_p + # see https://github.com/pytorch/pytorch/issues/97894 + scalar_workaround = scalar_workaround or {} + inv_scalar_workaround = {v: k for k, v in scalar_workaround.items()} + assert len(inv_scalar_workaround) == len(scalar_workaround) + + def process_arg(x: T) -> Union[T, KeywordArg, Ignored]: + if isinstance(x, (float, int)) and x in inv_scalar_workaround: + return KeywordArg(inv_scalar_workaround[x]) + if type(x) in ignore_types: + return Ignored() + if isinstance(x, list) and all(isinstance(y, Ignored) for y in x) and x: + return Ignored() + return x + + argnum = itertools.count() + + class Converter(torch.fx.Interpreter): + call_method = _not_implemented + call_module = _not_implemented + get_attr = _not_implemented + + def placeholder( + self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] # type: ignore[override] + ) -> Union[ExclusiveKeywordArg, KeywordArg]: + n = next(argnum) + if n < len(argnames): + name = argnames[n] + elif argnames: + assert target.startswith("tangent") + name = target + else: + target = re.sub(r"_\d+$", "", target) # de-mangle arg name + name = target + if name in exclusive_arg_names: + return ExclusiveKeywordArg(name) + else: + return KeywordArg(name) + + def call_function( + self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] # type: ignore[override] + ) -> PatternExpr: + args, kwargs = pytree.tree_map(process_arg, (args, kwargs)) + if list in ignore_types: + # Handle a burned in tensor size which are now [Ignored(), Ignored(), ...] 
+ args = [process_arg(a) for a in args] + kwargs = {k: process_arg(a) for k, a in kwargs.items()} + return CallFunction(target, *args, **kwargs) + + def run_node(self, n: torch.fx.Node) -> Any: + rv = super().run_node(n) + if n.op == "output" and isinstance(rv, tuple): + assert len(rv) == len(n.args[0]) # type: ignore[arg-type] + for r, arg in zip(rv, n.args[0]): # type: ignore[arg-type] + r.users = len(arg.users) + else: + rv.users = len(n.users) + return rv + + pattern = Converter(gm).run() # type: ignore[arg-type] + if not isinstance(pattern, PatternExpr): + return MultiOutputPattern(pytree.tree_leaves(pattern)) + return pattern + + +@torch.no_grad() +def fwd_only( + fn: Callable[..., Any], + args: Sequence[Any], + *, + run_functional_passes: bool = True, + get_decomp_fn: Optional[Callable[..., Any]] = None, +) -> torch.fx.GraphModule: + """Build a normalized inference graph, for use with fx_to_pattern""" + # TODO - look into using aot autograd, asserting no mutating ops here + with enable_python_dispatcher(): + decompositions = ( + get_decomp_fn() if get_decomp_fn is not None else select_decomp_table() + ) + gm = make_fx(fn, decompositions, tracing_mode="real")(*args) + + from .fx_passes.post_grad import remove_noop_ops + + if run_functional_passes: + remove_noop_ops(gm.graph) + gm.graph.eliminate_dead_code() + + gm.recompile() + return gm + + +@torch.enable_grad() +def joint_fwd_bwd(fn: Callable[..., Any], args: Sequence[Any]) -> torch.fx.GraphModule: + """Build a normalized training graph, for use with fx_to_pattern""" + gm: Optional[torch.fx.GraphModule] = None + + def record_joint_graph( + joint_graph: torch.fx.GraphModule, inputs: Sequence[Any], **kwargs: Any + ) -> Tuple[torch.fx.GraphModule, torch.fx.GraphModule]: + nonlocal gm + assert not gm + gm = clone_graph(joint_graph) + return default_partition(joint_graph, inputs, **kwargs) + + with torch._guards.tracing(None): + aot_function( + fn, + lambda g, i: make_boxed_func(g), + partition_fn=record_joint_graph, + decompositions=select_decomp_table(), + keep_inference_input_mutations=True, + enable_log=False, + )(*args) + assert gm + + from .fx_passes.post_grad import remove_noop_ops + + remove_noop_ops(gm.graph) + + from .fx_passes.joint_graph import pointless_view + + matcher_pass = PatternMatcherPass() + + pattern = CallFunction( + torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size") + ) + GraphPatternEntry( + pattern=pattern, handler=pointless_view, extra_check=_return_true + ).register(matcher_pass.patterns) + matcher_pass.apply(gm.graph) # type: ignore[arg-type] + + # remove in/out specs + gm.graph._codegen = torch.fx.graph.CodeGen() + gm.graph.eliminate_dead_code() + gm.recompile() + return gm + + +def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]: + args: List[torch.fx.node.Argument] = [] + torch.fx.map_arg((n.args, n.kwargs), args.append) + return args + + +def stable_topological_sort(graph: torch.fx.Graph) -> None: + # Nodes are in exactly one of these three collections: + + # - Nodes in `pending` are waiting to be processed (in reverse order): + pending = list(reversed(graph.nodes)) + + # - Nodes in `ready` have been processed and are already in the correct + # order. + ready = set() + + # - `waiting` is a mapping from a dependency to nodes which depend on that + # dependency. + waiting = defaultdict(list) + + # The cursor indicates the last processed node so we can add new nodes + # after it. 
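+ # Main loop: pop a node; if all of its inputs are ready, place it after the
+ # cursor, otherwise park it in `waiting` keyed by its last unready input and
+ # revisit it once that input has been placed.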
+ cursor = None + while pending: + node = pending.pop() + waiting_for = [x for x in _args(node) if x not in ready] + if waiting_for: + # We have unprocessed input nodes. Might as well wait for the last + # arg so an already sorted list will only recheck this node once. + waiting[waiting_for[-1]].append(node) + else: + ready.add(node) + if cursor and cursor.next is not node: + cursor.append(node) + cursor = node + # Mark the nodes that have been waiting for this node to finish as + # ready to check again. + pending.extend(reversed(waiting.pop(node, ()))) + + assert not waiting and len(ready) == len(graph.nodes) + + +def init_once_fakemode(fn: Callable[..., Any]) -> Callable[[], Any]: + """Wrapper around lazy init functions in fx_passes/""" + + @functools.lru_cache(None) + @functools.wraps(fn) + def lazy_init() -> Any: + counters_ref = counters["inductor"].copy() + + with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode(): + result = fn() + + # clear view matches encountered during tracing + counters["inductor"] = counters_ref + + return result + + return lazy_init + + +def config_flag(name: str) -> Callable[[Match], Any]: + """Function for extra_check to put pass behind a flag""" + + def flag_check(match: Match) -> Any: + return getattr(config, name) + + return flag_check + + +def clone_graph(input_graph: torch.fx.GraphModule) -> torch.fx.GraphModule: + class CopyGraph(Transformer): + def run_node(self, old_node: torch.fx.Node) -> torch.fx.Node: + new_node = super().run_node(old_node) + if isinstance(new_node, torch.fx.Proxy): + new_node.node.meta.update(old_node.meta) + new_node.node.name = self.new_graph._graph_namespace.create_name( + old_node.name, None + ) + return new_node + + return CopyGraph(input_graph).transform() + + +_seen_patterns: Set[str] = set() + + +def get_arg_value( + node: torch.fx.Node, arg_number: int, kwarg_name: Optional[str] = None +) -> Any: + return ( + node.args[arg_number] + if len(node.args) > arg_number + else node.kwargs.get(kwarg_name) # type: ignore[arg-type] + ) + + +def filter_nodes(nodes: Iterable[torch.fx.Node], fn: Any) -> List[torch.fx.Node]: + fns = [fn] + if isinstance(fn, torch._ops.OpOverloadPacket): + fns.extend([getattr(fn, overload) for overload in fn.overloads()]) + + return [node for node in nodes if node.target in fns] + + +def extract_target(node: torch.fx.Node) -> torch.fx.node.Target: + """For call_function and call_method, we directly use the target function; + For call_module, the target is string, and we treat the module class + as a function. + """ + if node.op == "call_module": + return getattr(node.graph.owning_module, node.target).__class__ # type: ignore[arg-type] + return node.target diff --git a/lib/python3.10/site-packages/torch/_inductor/quantized_lowerings.py b/lib/python3.10/site-packages/torch/_inductor/quantized_lowerings.py new file mode 100644 index 0000000000000000000000000000000000000000..80910e67d3a61b7ba1ae115371c6b59786eff205 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/quantized_lowerings.py @@ -0,0 +1,92 @@ +# mypy: allow-untyped-defs +import logging + +import torch +from torch._inductor.kernel.mm_common import mm_args + +from . 
import config as inductor_config, lowering +from .codegen.cpp_gemm_template import CppPackedGemmTemplate +from .codegen.cpp_utils import create_epilogue_with_attr +from .lowering import expand, register_lowering +from .select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + realize_inputs, +) +from .utils import use_aten_gemm_kernels, use_cpp_packed_gemm_template + + +log = logging.getLogger(__name__) + +aten__weight_int8pack_mm = ExternKernelChoice( + torch._weight_int8pack_mm, "at::_weight_int8pack_mm", has_out_variant=False +) + + +quantized = torch.ops.quantized +_quantized = torch.ops._quantized +aten = torch.ops.aten + + +def register_quantized_ops(): + lowering.add_needs_realized_inputs( + [ + quantized.max_pool2d, + _quantized.wrapped_fbgemm_pack_gemm_matrix_fp16, + _quantized.wrapped_fbgemm_linear_fp16_weight, + ] + ) + + lowering.make_fallback(quantized.max_pool2d) + lowering.make_fallback(_quantized.wrapped_fbgemm_pack_gemm_matrix_fp16) + lowering.make_fallback(_quantized.wrapped_fbgemm_linear_fp16_weight) + + +def register_woq_mm_ops(): + @register_lowering(aten._weight_int8pack_mm, type_promotion_kind=None) + def int8pack_mm(input, weight, scale, *, layout=None): + _, _, _, layout, mat1, mat2 = mm_args( + input, weight, layout=layout, mat2_transposed=True + ) + assert ( + mat1.get_dtype() in [torch.bfloat16, torch.float16, torch.float] + and mat2.get_dtype() == torch.int8 + ) + aten_layout = layout + + # options to tune from + choices = ( + [aten__weight_int8pack_mm.bind((mat1, mat2, scale), aten_layout)] + if use_aten_gemm_kernels() + else [] + ) + + # scale is applied as an epilogue, and the scale tensor is expanded (with a view op) + # for broadcasting, as it's 1D. + def _mul_epilogue(buf): + return create_epilogue_with_attr( + buf, "mul", other=realize_inputs(expand(scale, layout.size)) + ) + + if use_cpp_packed_gemm_template(aten_layout, mat1, mat2, mat2_transposed=True): + CppPackedGemmTemplate.add_choices( + choices, + aten_layout, + [mat1, mat2, scale], + trans_w=True, + epilogue_creator=_mul_epilogue, + ) + + if ( + len(choices) == 0 + and inductor_config.autotune_fallback_to_aten + and not use_aten_gemm_kernels() + ): + log.warning("No choices for GEMM, using ATen backend as fallback") + return aten__weight_int8pack_mm.bind( + (mat1, mat2, scale), aten_layout + ).output_node() + + return autotune_select_algorithm( + "_weight_int8pack_mm", choices, [mat1, mat2, scale], aten_layout + ) diff --git a/lib/python3.10/site-packages/torch/_inductor/remote_cache.py b/lib/python3.10/site-packages/torch/_inductor/remote_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..14963a1bf5d4358b012f69bb97411f3d610f6bbe --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/remote_cache.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +import json +import os +import typing +from abc import abstractmethod +from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union +from typing_extensions import override, TypeAlias + +from torch._inductor import config + + +try: + import redis +except ImportError: + redis = None # type: ignore[assignment] + + +if config.is_fbcode(): + from rfe.scubadata.scubadata_py3 import ( # type: ignore[import-not-found] + Sample as Sample_, + ) + + Sample: TypeAlias = Sample_ +else: + Sample: TypeAlias = Type[object] # type: ignore[misc,no-redef] + + +_T = TypeVar("_T") +_U = TypeVar("_U") + + +class RemoteCacheBackend(Generic[_T]): + """ + A backend implementation for 
accessing a remote/distributed cache. Only + works with bytes in/out. For structured data use a RemoteCache. + """ + + @abstractmethod + def get(self, key: str) -> Optional[_T]: + pass + + @abstractmethod + def put(self, key: str, data: _T) -> None: + pass + + +# Serde that encodes from _T to _U and decodes from _U to _T. +class RemoteCacheSerde(Generic[_T, _U]): + @abstractmethod + def encode(self, data: _T) -> _U: + pass + + @abstractmethod + def decode(self, data: _U) -> _T: + pass + + +JsonDataTy = Optional[ + Union[int, float, str, bool, Dict[str, "JsonDataTy"], List["JsonDataTy"]] +] + + +class RemoteCacheJsonSerde(RemoteCacheSerde[JsonDataTy, bytes]): + def encode(self, data: JsonDataTy) -> bytes: + return bytes(json.dumps(data), "ascii") + + def decode(self, data: bytes) -> JsonDataTy: + return json.loads(data) + + +class RemoteCachePassthroughSerde(RemoteCacheSerde[_T, _T]): + def encode(self, data: _T) -> _T: + return data + + def decode(self, data: _T) -> _T: + return data + + +class RemoteCache(Generic[_T]): + backend_override_cls: Optional[Callable[[], RemoteCacheBackend[Any]]] = None + + def __init__( + self, backend: RemoteCacheBackend[_U], serde: RemoteCacheSerde[_T, _U] + ) -> None: + # Support for testing. + if (override_cls := self.__class__.backend_override_cls) is not None: + self.backend = override_cls() + else: + self.backend = backend + self.serde = serde + + def get(self, key: str) -> Optional[_T]: + sample = self._create_sample() + result = self._get(key, sample) + self._log_sample(sample) + return result + + def put(self, key: str, value: _T) -> None: + sample = self._create_sample() + self._put(key, value, sample) + self._log_sample(sample) + + def _decode(self, data: _U, sample: Optional[Sample]) -> _T: + return self.serde.decode(data) + + def _encode(self, value: _T, sample: Optional[Sample]) -> Any: # returns _U + return self.serde.encode(value) + + def _get(self, key: str, sample: Optional[Sample]) -> Optional[_T]: + if data := self.backend.get(key): + return self._decode(data, sample) + return None + + def _put(self, key: str, value: _T, sample: Optional[Sample]) -> None: + data = self._encode(value, sample) + self.backend.put(key, data) + + def _create_sample(self) -> Optional[Sample]: + return None + + def _log_sample(self, sample: Optional[Sample]) -> None: + pass + + +class RedisRemoteCacheBackend(RemoteCacheBackend[bytes]): + """ + A Redis implementation of a remote/distributed cache. + """ + + _key_fmt: str + _redis: Optional[redis.Redis] = None + + def __init__(self, cache_id: str) -> None: + if not redis: + # We had trouble importing redis - just skip init. + return + + self._key_fmt = f"pt2:{cache_id}:{{key}}" + self._redis = redis.Redis( + host=os.environ.get("TORCHINDUCTOR_REDIS_HOST", "localhost"), + port=int(os.environ.get("TORCHINDUCTOR_REDIS_PORT", 6379)), + ) + + def __get_key(self, key: str) -> str: + return self._key_fmt.format(key=key) + + @override + def get(self, key: str) -> Optional[bytes]: + if not self._redis: + # Either redis wasn't found or we already had some trouble... + return None + + try: + value = self._redis.get(self.__get_key(key)) + except redis.exceptions.ConnectionError: + # Redis is lazy and doesn't actually attempt to connect until the + # first use. Mark is as unavailable now. + self._redis = None + return None + + # In theory redis.get() can return an Awaitable as well... 
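+ # The synchronous client used here returns bytes or None, so narrow the
+ # type before returning it.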
+ assert value is None or isinstance(value, bytes) + return value + + @override + def put(self, key: str, data: bytes) -> None: + if not self._redis: + # Either redis wasn't found or we already had some trouble... + return + + try: + self._redis.set(self.__get_key(key), data) + except redis.exceptions.ConnectionError: + # Redis is lazy and doesn't actually attempt to connect until the + # first use. Mark is as unavailable now. + self._redis = None + + +class RedisRemoteCache(RemoteCache[JsonDataTy]): + def __init__(self, key: str) -> None: + # Special test handling: If we're just going to override the backend + # anyway don't require redis + if self.__class__.backend_override_cls: + # This is totally bogus but it works for now... + backend = typing.cast(RemoteCacheBackend[bytes], None) + else: + backend = RedisRemoteCacheBackend(key) + serde = RemoteCacheJsonSerde() + super().__init__(backend, serde) + + +class RemoteAutotuneCache(RedisRemoteCache): + pass + + +class RemoteFxGraphCache(RedisRemoteCache): + pass diff --git a/lib/python3.10/site-packages/torch/_inductor/scheduler.py b/lib/python3.10/site-packages/torch/_inductor/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..564a9b4ccfd8337e13c52c005c6ba5149fd85825 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/scheduler.py @@ -0,0 +1,3727 @@ +# mypy: disallow-untyped-defs +from __future__ import annotations + +import collections +import dataclasses +import functools +import itertools +import logging +import math +import operator +import os +import pprint +import textwrap +import traceback +import typing +from typing import ( + Any, + Callable, + Counter, + DefaultDict, + Dict, + Generic, + List, + Optional, + Sequence, + Set, + Tuple, + TypeVar, + Union, +) + +import sympy + +import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch._dynamo.utils import counters, dynamo_timed +from torch._inductor.metrics import get_metric_table, is_metric_table_enabled +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.symbol import free_symbol_is_type, SymT +from torch.utils._triton import has_triton + +from . 
import comms, config, dependencies, ir, metrics +from .codecache import write_text +from .codegen.common import BackendFeature, get_scheduling_for_device, Kernel +from .comm_analysis import estimate_nccl_collective_runtime +from .dependencies import Dep, MemoryDep, StarDep, WeakDep +from .ir import ComputedBuffer, MultiOutput, MultiOutputLayout +from .loop_body import LoopBody +from .runtime.runtime_utils import green_text, red_text +from .sizevars import SimplifyIndexing +from .utils import ( + cache_on_self, + cmp, + device_need_guard, + get_device_tflops, + get_dtype_size, + get_gpu_dram_gbps, + IndentedBuffer, + is_collective, + is_gpu, + is_wait, + sympy_product, +) +from .virtualized import V + + +log = logging.getLogger(__name__) +fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") +loop_ordering_log = torch._logging.getArtifactLogger(__name__, "loop_ordering") + + +@dataclasses.dataclass +class SchedulerBuffer: + scheduler: Scheduler + node: ir.Buffer + defining_op: BaseSchedulerNode + users: List[NodeUser] = dataclasses.field(default_factory=list) + + def __hash__(self) -> int: + return hash(self.node.name) + + def debug_str(self) -> str: + result = IndentedBuffer() + name = self.get_name() + result.writeline(f"{name}: {type(self.node).__name__}") + result.writeline(f"{name}.layout = {self.node.layout}") + if self.get_aliases(): + result.writeline(f"{name}.aliases = {pformat(self.get_aliases())}") + if self.get_mutations(): + result.writeline(f"{name}.mutations = {pformat(self.get_mutations())}") + + if len(self.users) <= 1: + result.writeline(f"{name}.users = {self.users}") + else: + result.writeline(f"{name}.users = [") + with result.indent(1): + for user in self.users: + result.writeline(f"{user},") + result.writeline("]") + return result.getrawvalue() + + def get_name(self) -> str: + return self.node.get_name() + + def allocate(self) -> None: + assert self.node is not None + if not self.node.should_allocate(): + return + + if self.node.get_inputs_that_alias_output() or self.node.get_mutation_names(): + V.graph.wrapper_code.codegen_allocation(self.node) + return + + # hacky check for if V.kernel is a real kernel or NullHandler + if ( + hasattr(V.kernel, "args") + and self.get_name() in V.kernel.inplace_update_buffers + ): + V.graph.wrapper_code.codegen_inplace_reuse( + self.scheduler.name_to_buf[ + V.kernel.inplace_update_buffers[self.get_name()] + ].node, + self.node, + ) + else: + V.graph.wrapper_code.codegen_allocation(self.node) + + def can_free(self) -> bool: + # There's no real allocated buffer, no need to free it + assert self.node is not None + if isinstance(self.node.layout, ir.NoneLayout): + return False + for use in self.users: + if isinstance(use.node, OutputNode): + return False + return True + + def set_users(self, users: List[NodeUser]) -> None: + # deduplicate + result: Dict[int, NodeUser] = {} + for use in users: + if id(use.node) in result: + result[id(use.node)] = use.merge(result[id(use.node)]) + else: + result[id(use.node)] = use + self.users = list(result.values()) + + def get_aliases(self) -> Sequence[str]: + assert self.node is not None + return self.node.get_inputs_that_alias_output() + + def get_mutations(self) -> List[str]: + assert self.node is not None + return self.node.get_mutation_names() + + +class BaseSchedulerNode: + group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]] + read_writes: dependencies.ReadWrites + unmet_dependencies: OrderedSet[Dep] + # .min_order and .max_order are only relevant for "grouped" nodes such as 
FusedSchedulerNode. + # e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node + # in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3. + # For non-"grouped" nodes (i.e. regular SchedulerNode), + # .min_order = .max_order = X if this node is X-th node in `self.scheduler.nodes`. + min_order: int + max_order: int + + def __init__(self, scheduler: Scheduler) -> None: + self.scheduler: Scheduler = scheduler + + def _init_from_node(self, node: ir.Operation) -> None: + self.node: Optional[ir.Operation] = node + self.ancestors: OrderedSet[str] = OrderedSet() + self.last_usage: OrderedSet[ + str + ] = OrderedSet() # buffers that won't be used after this kernel + self.written = False + self.outputs: List[SchedulerBuffer] = [ + SchedulerBuffer( + scheduler=self.scheduler, + node=output, + defining_op=self, + ) + for output in node.get_outputs() + ] + self.outputs_by_name: Dict[str, SchedulerBuffer] = { + buf.get_name(): buf for buf in self.outputs + } + + def __repr__(self) -> str: + return f"{type(self).__name__}(name={self.get_name()!r})" + + def debug_str(self) -> str: + """Longer form printout for trace logs""" + name = self.get_name() + buf = IndentedBuffer() + buf.splice( + f"""\ +{name}: {type(self).__name__}({type(getattr(self, 'node', None)).__name__}) +{name}.writes = {pformat(self.read_writes.writes)} +{name}.unmet_dependencies = {pformat(self.unmet_dependencies)} +{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)} +{name}.outputs = [ + """ + ) + with buf.indent(): + for out in self.get_outputs(): + buf.splice(out.debug_str()) + buf.writeline("]") + + try: + buf.splice(self.debug_str_extra()) + except Exception: + log.warning("Ignoring error in debug_str()", exc_info=True) + + return buf.getrawvalue().rstrip() + + def debug_str_extra(self) -> str: + return "" + + def debug_str_short(self) -> str: + maybe_data = getattr(self.node, "data", None) + data_str = "" + if isinstance(maybe_data, torch._inductor.ir.Pointwise): + data_str = ", " + maybe_data.str_helper( + [maybe_data.get_size()], shorten=False, multiline=False + ) + elif isinstance(maybe_data, torch._inductor.ir.Reduction): + data_str = ", " + maybe_data.str_helper( + [maybe_data.get_reduction_size(), maybe_data.get_reduction_type()], + shorten=False, + multiline=False, + ) + return f"{self}{data_str}" + + def log_details(self) -> None: + log.info( + "%s: unmet_dependencies = %s, writes = %s", + self, + self.unmet_dependencies, + self.read_writes.writes, + ) + + def reorder_loops_by_dep_pair( + self, self_dep: MemoryDep, other_dep: MemoryDep + ) -> None: + return + + def update_mutated_names(self, renames: Dict[str, str]) -> None: + self.set_read_writes(self.read_writes.rename(renames)) + + def add_fake_dep(self, dep: Dep) -> None: + self.set_read_writes(self.read_writes.with_read(dep)) + + def has_aliasing_or_mutation(self) -> bool: + return any( + buf.get_aliases() or buf.get_mutations() for buf in self.get_outputs() + ) + + def set_read_writes(self, rw: dependencies.ReadWrites) -> None: + self.read_writes = rw + self.unmet_dependencies = self.read_writes.reads + self.prune_deps() + + def set_last_usage( + self, future_used_buffers: OrderedSet[str], mutation_real_name: Dict[str, str] + ) -> None: + used_buffers = self.used_or_aliased_buffer_names() + used_buffers = OrderedSet([mutation_real_name.get(k, k) for k in used_buffers]) + self.last_usage = used_buffers - future_used_buffers + + def mark_run(self) -> None: + for 
buf in self.outputs: + buf.allocate() + + def used_buffer_names(self) -> OrderedSet[str]: + return OrderedSet( + dep.name + for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes) + ) + + def used_or_aliased_buffer_names(self) -> OrderedSet[str]: + used_names: OrderedSet[str] = OrderedSet() + + deps = [ + dep.name + for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes) + ] + while len(deps) > 0: + dep = deps.pop() + used_names.add(dep) + if V.graph.name_to_buffer.get(dep): + for alias in V.graph.name_to_buffer[dep].get_inputs_that_alias_output(): + if alias not in used_names: + deps.append(alias) + return used_names + + def prune_deps(self) -> None: + self.unmet_dependencies = OrderedSet( + dep + for dep in self.unmet_dependencies + if dep.name not in self.scheduler.available_buffer_names + ) + + def prune_weak_deps(self) -> None: + # Prune weak dependencies on operations that have been removed + def should_prune(dep: Dep) -> bool: + if not isinstance(dep, WeakDep): + return False + op = self.scheduler.name_to_buf[dep.name].defining_op + return op.get_name() in V.graph.removed_operations + + to_remove = OrderedSet( + dep for dep in self.read_writes.reads if should_prune(dep) + ) + self.set_read_writes(self.read_writes.remove_reads(to_remove)) + + def prune_redundant_deps( + self, name_to_fused_node: Dict[str, BaseSchedulerNode] + ) -> None: + _prune_redundant_deps(self, name_to_fused_node, self.scheduler.name_to_buf) + + def get_name(self) -> str: + assert self.node is not None + return self.node.get_operation_name() + + def get_first_name(self) -> str: + return self.get_name() + + def get_operation_names(self) -> OrderedSet[str]: + return OrderedSet(node.get_name() for node in self.get_nodes()) + + def get_buffer_names(self) -> OrderedSet[str]: + return OrderedSet(out.get_name() for out in self.outputs) + + def get_nodes(self) -> Sequence[BaseSchedulerNode]: + return [self] + + def get_outputs(self) -> Sequence[SchedulerBuffer]: + return self.outputs + + def get_output(self, buf_name: str) -> SchedulerBuffer: + return self.outputs_by_name[buf_name] + + def get_device(self) -> torch.device: + assert self.node is not None + return self.node.get_device() + + def is_reduction(self) -> bool: + return False + + def is_split_scan(self) -> bool: + return False + + def is_template(self) -> bool: + return False + + def is_extern(self) -> bool: + return False + + def is_foreach(self) -> bool: + return False + + def can_inplace(self, read_dep: dependencies.Dep) -> bool: + return False + + def has_side_effects(self) -> bool: + return False + + def decide_inplace_update(self) -> None: + """ + Decide if there should be inplace updates for the node + and record the decision in the active kernel. 
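+ An inplace update reuses an input buffer's storage for one of this
+ node's outputs when this node is the input's last remaining user and
+ the two buffers share the same buffer_reuse_key.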
+ """ + from .codegen.wrapper import buffer_reuse_key + + if not ( + isinstance(self, (SchedulerNode,)) + and config.inplace_buffers + and V.graph.has_feature(self.get_device(), BackendFeature.INPLACE_BUFFERS) + and ( + not isinstance(V.kernel, torch._inductor.codegen.simd.SIMDKernel) + or getattr(V.kernel, "mutations", None) is not None + ) + # hacky check for if V.kernel is a real kernel or NullHandler + and hasattr(V.kernel, "args") + ): + return + + ordered_reads = sorted(self.read_writes.reads, key=lambda x: x.name) + + for buf in self.get_outputs(): + buf_node = buf.node + assert buf_node is not None + if ( + not buf_node.should_allocate() + or buf_node.get_inputs_that_alias_output() + or buf_node.get_mutation_names() + or buf.get_name() in V.graph.removed_buffers + ): + continue + + for read in ordered_reads: + input_buf: Optional[SchedulerBuffer] = self.scheduler.name_to_buf.get( + read.name + ) + if ( + input_buf + and V.graph.wrapper_code.can_reuse(input_buf, self) + and not isinstance(input_buf.defining_op, NopKernelSchedulerNode) + ): + assert input_buf.users is not None + remaining_uses = [ + x + for x in input_buf.users + if x.node.get_name() not in self.scheduler.completed_operations + ] + if ( + len(remaining_uses) == 1 + and remaining_uses[0].can_inplace + and remaining_uses[0].node is self + and input_buf.node is not None + and not isinstance( + input_buf.node.get_layout(), + ( + ir.MultiOutputLayout, + ir.MutationLayoutSHOULDREMOVE, + ), + ) + and not ( + isinstance( + input_buf.defining_op.node, + (ir.FallbackKernel, ir.MultiOutput), + ) + and len(input_buf.node.get_inputs_that_alias_output()) > 0 + ) + and buffer_reuse_key(input_buf.node) + == buffer_reuse_key(buf.node) + ): + # if there isn't a triton kernel, then we don't need to call triton-specific things. + # but TODO this might be a convenient place to signal to the Collective kernels to inplace + # (and, can we make "kernel" less generic of a name?) + V.kernel.args.make_inplace(input_buf.get_name(), buf.get_name()) + # mutations not tracked in cpp kernels + if isinstance( + V.kernel, torch._inductor.codegen.simd.SIMDKernel + ): + V.kernel.mutations.add(input_buf.get_name()) + V.kernel.mutations.add(buf.get_name()) + + # update last usage of reused node + self.last_usage.discard(input_buf.get_name()) + + V.kernel.inplace_update_buffers[ + buf.get_name() + ] = input_buf.get_name() + break + + def codegen_originating_info( + self, buffer: IndentedBuffer, only_once: bool = True + ) -> None: + if not config.comment_origin: + return + + if only_once and self.written: + return + assert self.node is not None + origins = self.node.get_origins() + out_lines = [] + + for o in origins: + if o.op == "output": + # These are boring and samey + continue + + out_lines.append("") + # TODO(voz): Should the pragma be constant somewhere? + out_lines.append("#pragma CMT ORIGIN:") + op_info_str = f"#pragma CMT {o.op} {o.target}" + if "seq_nr" in o.meta: + op_info_str = op_info_str + f" seq_nr:{o.meta['seq_nr']}" + out_lines.append(op_info_str) + if "stack_trace" in o.meta: + stack_trace = f"{o.meta['stack_trace']}" + stack_trace_last_line = stack_trace.split("|")[-1] + out_lines.append( + "#pragma CMT " + + stack_trace_last_line.replace("{", "{{") + .replace("}", "}}") + .replace("\n", "\\") + ) + out_lines.append("#pragma CMT END ORIGIN") + out_lines.append("") + + if len(out_lines) == 0: + return + + # TODO(voz): Ostensibly, we should not need this. 
But there are cases where C++ codegen does + # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. + buffer.writelines(out_lines) + self.written = True + + def get_read_write_buffers_sizes(self) -> int: + """ + Counting the number of bytes accessed for a kernel is + surprisingly tricky. In particular, there is a differentiation + between 'theoretical' memory accesses and practical memory + accesses. For example, a layernorm kernel may actually access an + input 3 times, but in theory, it only needs to access its input + once (and may be optimized to do so through say, persistent + reductions) + + Another example is that even though a buffer is passed in, we may + not access the entire buffer. This may occur if we are accessing + a slice of the buffer. Another tricky case is for indirect + indexing, where the amount of bytes accessed depends on the + values of the input. + + What this function aims to compute is the memory accesses for + worst-case inputs, best-case optimization. What this means is + that for each buffer we compute the amount of potential accesses in two ways and take the minimum. + + 1. Numel in ranges multiplied by number of deps the buffer has + 2. The buffer size + """ + if isinstance(self, NopKernelSchedulerNode): + return 0 + if isinstance(self, ExternKernelSchedulerNode) and isinstance( + self.node, MultiOutput + ): + # todo: Calculate this - it's kinda annoying. + return 0 + + def try_size_hint(s: sympy.Expr) -> int: + return V.graph.sizevars.size_hint(s, fallback=0) + + if isinstance(self, SchedulerNode): + node_numel = try_size_hint( + sympy_product(self.get_ranges()[0]) + * sympy_product(self.get_ranges()[1]), + ) + else: + node_numel = int(1e9) + buf_accesses = collections.defaultdict(list) + for dep in self.read_writes.reads | self.read_writes.writes: + buf_accesses[dep.name].append(dep) + + reads = OrderedSet(dep.name for dep in self.read_writes.reads) + writes = OrderedSet(dep.name for dep in self.read_writes.writes) + + def is_materialized(buf: str, snodes: Sequence[BaseSchedulerNode]) -> bool: + users = self.scheduler.name_to_buf[buf].users + buf_uses = OrderedSet(user.node for user in users) + return len(buf_uses - OrderedSet(snodes)) > 0 + + if isinstance(self, FusedSchedulerNode): + removed_buffers = OrderedSet( + dep for dep in writes if not is_materialized(dep, self.snodes) + ) + writes = writes - removed_buffers + reads = reads - removed_buffers + node_bytes = 0 + + for buf_name in reads | writes: + buf_accessed_elems = sum(node_numel for dep in buf_accesses[buf_name]) + buf: Union[ir.Buffer, ir.TensorBox] + if buf_name in V.graph.name_to_buffer: + buf = V.graph.name_to_buffer[buf_name] + elif buf_name in V.graph.graph_inputs: + buf = V.graph.graph_inputs[buf_name] + else: + continue + + def get_buf_bytes(buf: Optional[Union[ir.Buffer, ir.TensorBox]]) -> int: + if not buf: + return 0 + # Kind of a lazy way to get the MultiOutput nodes corresponding to + # a MultiOutputLayout + if isinstance(buf.layout, MultiOutputLayout): + users = self.scheduler.name_to_buf[buf.get_name()].users + tot = 0 + for user in users: + assert isinstance(user.node, BaseSchedulerNode) + if isinstance(user.node.node, MultiOutput): + for sched_buf in user.node.get_outputs(): + tot += get_buf_bytes(sched_buf.node) + else: + # Buf is a MultiOutputLayout but not all of its + # users are MultiOutputs... 
+ # TODO: Figure out what's going on + return 0 + return tot + elif isinstance(buf.layout, ir.NoneLayout): + return sum( + get_buf_bytes(V.graph.get_buffer(mut_name)) + for mut_name in buf.get_mutation_names() + ) + else: + buf_elems = try_size_hint(sympy_product(buf.get_size())) + return get_dtype_size(buf.get_dtype()) * min( + buf_accessed_elems, buf_elems + ) + + node_bytes += get_buf_bytes(buf) + + return node_bytes + + def get_estimated_runtime(self) -> float: + """ + Returns estimated op runtime in nanoseconds (ns) + """ + buf = self.get_nodes()[0].get_outputs()[0] + layout = buf.node.get_layout() + dtype = buf.node.get_dtype() + + if layout.device is not None and not is_gpu(layout.device.type): + # default to no reordering based on runtime + return 0 + + # Collective kernels + if is_collective(self.node): + assert isinstance(self.node, ir.IRNode) + try: + return estimate_nccl_collective_runtime(self.node) + except ValueError as e: + # We don't know how to estimate runtime for this collective, + # falling back to 0 + log.info(e) + return 0 + + elif is_wait(self.node): + # ir.Wait is only used for collective ops. + # The time needed for the collective op is already estimated and considered + # when we are processing the collective op IR node, so ir.Wait takes 0 time + # since it doesn't take extra time to get the result after the collective is completed. + return 0 + + try: + gpu_memory_bandwidth = get_gpu_dram_gbps() + gpu_flops = get_device_tflops(dtype) * 10**12 + except Exception: + return 0 + + if isinstance(self, ExternKernelSchedulerNode): + assert isinstance(self.node, ir.ExternKernel), f"{type(self.node)=}" + op = kernel_name_to_op.get( + getattr(self.node, "python_kernel_name", ""), None + ) + + # if there is a resolved op, dry-run using fake mode and record flop count + if op is not None: + from torch._subclasses.fake_tensor import FakeTensorMode + from torch.utils.flop_counter import FlopCounterMode + + if any( + len(free_unbacked_symbols(n.get_numel())) > 0 + for n in self.node.inputs + ): + # Tensor has unbacked symints, we don't know how to estimate + # runtime for that today + return 0 + + with FakeTensorMode() as fake_mode, FlopCounterMode( + display=False + ) as flop_counter_mode, V.set_current_node( + self.node.fx_node + ), V.set_fake_mode( + fake_mode + ): + from .ir import ir_node_to_tensor + + fake_inputs = [ + ir_node_to_tensor(input, guard_shape=False) + for input in self.node.inputs + ] + cls = self.node.__class__ + cls.process_kernel(op, *fake_inputs, **self.node.kwargs) + + # TODO(xmfan): find a better heuristic to model FLOPS/latency relationship + factor = 1.0 + counted_flops = flop_counter_mode.get_total_flops() + counted_bytes = self.get_read_write_buffers_sizes() + compute_time = (factor * counted_flops / gpu_flops) * 1e9 + transfer_time = counted_bytes / gpu_memory_bandwidth + + # Return estimated runtime in nanoseconds + return max(compute_time, transfer_time) + + elif isinstance(self, FusedSchedulerNode) or isinstance( + self.node, ComputedBuffer + ): + # Return estimated runtime in nanoseconds (bytes / gbps) + return self.get_read_write_buffers_sizes() / gpu_memory_bandwidth + + return 0 + + def get_template_node(self) -> Optional[ir.TemplateBuffer]: + return None + + +class WhyNoFuse: + # TODO when we drop support for Python < 3.10, we can use + # @dataclass(slots=True) instead of manually specifying __slots__. + __slots__ = ["node1", "node2", "reason", "args"] + reason: str + args: Tuple[Any, ...] 
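+ # Usage sketch (the message below is illustrative): instances are called like
+ # a logger, e.g.
+ #   why = WhyNoFuse(node1, node2)
+ #   why("nodes are on different devices: %s vs %s", dev1, dev2)
+ # which records the reason/args and emits the formatted message on fusion_log.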
+ + def __init__(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> None: + self.node1 = node1 + self.node2 = node2 + + def __call__(self, reason: str, *args: Any) -> None: + self.reason = reason + self.args = args + fusion_log.debug(self) + + def __str__(self) -> str: + return f"cannot fuse {self.node1.get_name()} with {self.node2.get_name()}: " + ( + self.reason % self.args + ) + + +def pformat(obj: Any) -> str: + if isinstance(obj, OrderedSet): + # pformat has trouble with sets of sympy exprs + obj = sorted(obj, key=str) + result = pprint.pformat(obj, indent=4) + if "\n" in result: + return f"\n{textwrap.indent(result, ' ' * 4)}" + return result + + +class OutputNode: + def __init__(self, dep: StarDep) -> None: + self.unmet_dependencies = OrderedSet([dep]) + + def is_reduction(self) -> bool: + return False + + def get_inputs_that_alias_output(self) -> Sequence[str]: + return () + + def get_name(self) -> str: + return "OUTPUT" + + __repr__ = get_name + + +def _prune_redundant_deps( + node: BaseSchedulerNode, + name_to_fused_node: Dict[str, BaseSchedulerNode], + name_to_buf: Dict[str, SchedulerBuffer], +) -> None: + """ + Prunes weakdeps intended for mutation ordering + on an upstream fused node if after fusion there is another dependency + on the fused upstream node, making the weakdep redundant + + In essence this enforces an ordering on fusions. As fusions occur, weakdeps will + be incrementally removed, enabling other fusions, ensuring they are fused in order. + """ + name_to_dep_count: Counter[str] = collections.Counter() + + for dep in node.unmet_dependencies: + if not isinstance(dep, WeakDep): + op = name_to_buf[dep.name].defining_op + name_to_dep_count[name_to_fused_node[op.get_name()].get_name()] += 1 + + def should_prune(dep: Dep) -> bool: + if isinstance(dep, WeakDep): + op_name = name_to_buf[dep.name].defining_op.get_name() + is_redundant = name_to_dep_count[name_to_fused_node[op_name].get_name()] > 0 + # These can occur because fused nodes always gather deps from their snodes + # If B has a weakdep on A + # B gets fused with C, then any time BC is fused, the weakdep will reappear + is_self_dep = name_to_fused_node[op_name] == node + return is_redundant or is_self_dep + else: + return False + + deps_to_prune = OrderedSet( + dep for dep in node.unmet_dependencies if should_prune(dep) + ) + + if deps_to_prune: + node.unmet_dependencies = node.unmet_dependencies - deps_to_prune + node.set_read_writes(node.read_writes.remove_reads(deps_to_prune)) + + +# TODO(xmfan): reuse: an existing mapping for this if it exists, or formalize this into ir.py:ExternKernel +kernel_name_to_op = { + "extern_kernels.convolution": torch.ops.aten.convolution, + "extern_kernels.mm": torch.ops.aten.mm, + "extern_kernels.bmm": torch.ops.aten.bmm, + "extern_kernels.addmm": torch.ops.aten.addmm, +} + + +class ExternKernelSchedulerNode(BaseSchedulerNode): + def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None: + super().__init__(scheduler) + self._init_from_node(node) + self.set_read_writes(node.get_read_writes()) + + def debug_str_extra(self) -> str: + return f"{self.get_name()}.node.kernel = {getattr(self.node, 'python_kernel_name', None)}" + + def is_extern(self) -> bool: + return True + + def has_side_effects(self) -> bool: + assert self.node is not None + return hasattr(self.node, "has_side_effects") and self.node.has_side_effects() + + +class NopKernelSchedulerNode(BaseSchedulerNode): + def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None: + 
super().__init__(scheduler) + self._init_from_node(node) + self.set_read_writes(node.get_read_writes()) + + +class SchedulerNode(BaseSchedulerNode): + def __init__( + self, + scheduler: Scheduler, + node: Union[ir.ComputedBuffer, ir.TemplateBuffer], + ) -> None: + super().__init__(scheduler) + self._init_from_node(node) + self._compute_attrs() + + def _compute_attrs( + self, + extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, + ) -> None: + assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)) + self._sizes, self._body = self.node.simplify_and_reorder( + extra_indexing_constraints=extra_indexing_constraints, + recompute_sizes_body_func=recompute_sizes_body_func, + ) + + group_fn = self.scheduler.get_backend(self.node.get_device()).group_fn + self.group = (self.node.get_device(), group_fn(self._sizes)) + + # Don't normalize since normalization will merge loops which + # makes it hard to decide new loop orders. + should_normalize = ( + not config.loop_ordering_after_fusion + or self.node.get_device().type != "cuda" + ) + + if isinstance(self.node, ir.TemplateBuffer): + self.set_read_writes( + self.node.extract_read_writes(normalize=should_normalize) + ) + else: + self.set_read_writes( + dependencies.extract_read_writes( + self._body, *self._sizes, normalize=should_normalize + ) + ) + + def recompute_size_and_body( + self, + extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, + ) -> None: + self._compute_attrs( + extra_indexing_constraints=extra_indexing_constraints, + recompute_sizes_body_func=recompute_sizes_body_func, + ) + + def refresh_dependencies(self, normalize: bool) -> None: + # Fake dependencies are added manually. They can not be analyzed from + # extract_read_writes. Find them out and apply manually. 
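+ # ("Fake" deps are the WeakDep/StarDep edges added through add_fake_dep, e.g.
+ # the mutation-ordering and unbacked-symint dependencies created in
+ # Scheduler.compute_dependencies; they must be re-applied after recomputing
+ # read_writes below.)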
+ fake_deps = { + dep for dep in self.read_writes.reads if isinstance(dep, (WeakDep, StarDep)) + } + + # don't normalize since the loop order may need to be further changed + # later + self.set_read_writes( + dependencies.extract_read_writes( + self._body, *self._sizes, normalize=normalize + ).with_read(fake_deps) + ) + + def apply_new_loop_order(self, new_order: Sequence[int]) -> None: + self._body = self._body.reorder_iter_loops( + new_order, + ) + self._sizes = self._body.sizes + + self.refresh_dependencies(normalize=False) + + def reorder_loops_by_dep_pair( + self, self_dep: MemoryDep, other_dep: MemoryDep + ) -> None: + new_order = None + self_sizes = self._sizes[0] + if len(self_sizes) == self_dep.num_vars == other_dep.num_vars: + new_order = self_dep.decide_loop_order_to_match(other_dep) + + if new_order: + metrics.num_loop_reordering += 1 + loop_ordering_log.debug( + "Reorder loops for %s with order %s", self.get_name(), new_order + ) + self.apply_new_loop_order(new_order) + else: + loop_ordering_log.debug( + "Don't reordering %s because we can not decide the suitable loop order", + self.get_name(), + ) + + def debug_str_extra(self) -> str: + name = self.get_name() + lines = [ + f"{name}.group.device = {self.group[0]}", + f"{name}.group.iteration = {self.group[1]}", + f"{name}.sizes = {self._sizes}", + ] + for dep in self.read_writes.reads_and_writes(): + if not isinstance(dep, WeakDep): + buf_name = dep.name + buf = V.graph.get_buffer(buf_name) + lines.append(f"{buf_name}_layout = {pformat(buf.layout)}") + if isinstance(self._body, LoopBody): + lines.append(f"class {name}_loop_body:") + lines.append(textwrap.indent(self._body.debug_str(), " ")) + + assert self.node is not None + if ir.is_triton(self.node.get_device()): + lines.extend(debug_triton_code(self)) + + return "\n".join(lines) + + def get_ranges(self) -> Sequence[Sequence[sympy.Expr]]: + return self._sizes + + def is_reduction(self) -> bool: + assert isinstance( + self.node, (ir.ComputedBuffer, ir.TemplateBuffer) + ), f"{type(self.node)=}" + return bool(self.node.get_reduction_type()) + + def is_split_scan(self) -> bool: + assert isinstance( + self.node, (ir.ComputedBuffer, ir.TemplateBuffer) + ), f"{type(self.node)=}" + return isinstance(self.node, ir.ComputedBuffer) and isinstance( + self.node.data, ir.SplitScan + ) + + def is_template(self) -> bool: + return isinstance(self.node, ir.TemplateBuffer) + + def get_template_node(self) -> Optional[ir.TemplateBuffer]: + return self.node if isinstance(self.node, ir.TemplateBuffer) else None + + def run(self, *index_vars: Sequence[sympy.Expr]) -> None: + self.decide_inplace_update() + self.mark_run() + self.codegen(index_vars) + + def ranges_from_index_vars( + self, index_vars: Sequence[Sequence[sympy.Expr]] + ) -> Dict[sympy.Expr, sympy.Expr]: + sizes = self._sizes + assert sum(map(len, sizes)) == sum(map(len, index_vars)) + var_ranges = dict( + zip( + itertools.chain.from_iterable(index_vars), + itertools.chain.from_iterable(sizes), + ) + ) + return var_ranges + + def codegen(self, index_vars: Sequence[Sequence[sympy.Expr]]) -> None: + var_ranges = self.ranges_from_index_vars(index_vars) + try: + with V.set_ops_handler( + SimplifyIndexing(V.get_ops_handler(), var_ranges) + ), V.kernel.set_current_node(self): + self._body(*index_vars) + except Exception: + log.fatal("Error in codegen for %s", self.node) + raise + + @cache_on_self + def pointwise_read_writes(self) -> dependencies.ReadWrites: + """ + Get the memory dependencies in the non-reduction axis. 
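+ The reduction index variables are pinned to zero via hidden_args below, so the
+ returned ReadWrites reflects indexing along the pointwise axes only.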
+ """ + sizes, reduction_sizes = self._sizes + return dependencies.extract_read_writes( + self._body, sizes, hidden_args=[[sympy.Integer(0)] * len(reduction_sizes)] + ) + + def can_inplace(self, read_dep: dependencies.Dep) -> bool: + if self.is_template(): + return False + if any(out.get_aliases() for out in self.get_outputs()): + return False + if len(self.read_writes.writes) == 1 and isinstance( + read_dep, dependencies.MemoryDep + ): + write_dep = next(iter(self.read_writes.writes)) + assert isinstance(write_dep, dependencies.MemoryDep), f"{type(write_dep)=}" + return read_dep.index == write_dep.index and read_dep.size == write_dep.size + return False + + @cache_on_self + def _get_atomic_add_buffers(self) -> OrderedSet[str]: + buffers_store_as_atomic_add: OrderedSet[str] = OrderedSet() + if isinstance(self._body, LoopBody): + for node in self._body.get_nodes(): + if ( + node.op == "call_method" + and node.target == "store" + and ( + ("mode" in node.kwargs and node.kwargs["mode"] == "atomic_add") + or (len(node.args) == 5 and node.args[4] == "atomic_add") + ) + ): + buffers_store_as_atomic_add.add( + node.kwargs["name"] + if "name" in node.kwargs + else (node.args[1] if len(node.args) >= 2 else "") + ) + return buffers_store_as_atomic_add + + +def refresh_group_node_dependencies(group_snode: BaseSchedulerNode) -> None: + snodes = group_snode.snodes # type: ignore[attr-defined] + group_snode.set_read_writes( + dependencies.ReadWrites.merge_list([x.read_writes for x in snodes]) + ) + + group_snode.unmet_dependencies = ( + OrderedSet( + dep + for dep in OrderedSet.union(*[x.unmet_dependencies for x in snodes]) + if dep.name not in group_snode.get_buffer_names() + ) + - group_snode.read_writes.writes + ) + + +def init_group_node( + group_snode: BaseSchedulerNode, + scheduler: Scheduler, + snodes: List[BaseSchedulerNode], +) -> None: + assert isinstance(group_snode, (FusedSchedulerNode, GroupedSchedulerNode)) + group_snode.snodes = snodes + group_snode.scheduler = scheduler + group_snode.node = None + group_snode.ancestors = OrderedSet.union( + *[x.ancestors for x in snodes if x.ancestors is not None] + ) + + refresh_group_node_dependencies(group_snode) + + group_snode.min_order = min(x.min_order for x in group_snode.snodes) + group_snode.max_order = max(x.max_order for x in group_snode.snodes) + group_snode.outputs_by_name = { + buf.get_name(): buf for buf in group_snode.get_outputs() + } + + +class FusedSchedulerNode(BaseSchedulerNode): + """ + This is a "fake" scheduler node that represents a group of scheduler nodes + that are meant to be fused together. The way it does this is by maintaining + its unmet dependencies as the union of its constituent nodes. 
+ """ + + snodes: List[BaseSchedulerNode] + + @classmethod + def fuse( + cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> FusedSchedulerNode: + assert node1.scheduler is node2.scheduler + assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) + assert isinstance(node2, (SchedulerNode, FusedSchedulerNode)) + nodes = list(itertools.chain(node1.get_nodes(), node2.get_nodes())) + return cls(node1.scheduler, nodes) + + def reorder_loops_by_dep_pair( + self, self_dep: MemoryDep, other_dep: MemoryDep + ) -> None: + if self.is_template(): + # We can not really reorder loops for a triton template + return + self_sizes = None + for snode in self.snodes: + assert isinstance(snode, SchedulerNode) + if self_sizes is not None and self_sizes != snode._sizes[0]: + loop_ordering_log.debug( + "Can not reorder fused node due to different sizes" + ) + return + self_sizes = snode._sizes[0] + new_order = None + + assert self_sizes is not None + if len(self_sizes) == self_dep.num_vars == other_dep.num_vars: + new_order = self_dep.decide_loop_order_to_match(other_dep) + + if not new_order: + loop_ordering_log.debug( + "Dont reordering fused node %s because we can not decide the suitable loop order", + self.get_name(), + ) + return + metrics.num_loop_reordering += 1 + loop_ordering_log.debug( + "Reorder loops for fused node %s with order %s", self.get_name(), new_order + ) + for snode in self.snodes: + assert isinstance(snode, SchedulerNode) + snode.apply_new_loop_order(new_order) # type: ignore[arg-type] + + refresh_group_node_dependencies(self) + + def __init__(self, scheduler: Scheduler, snodes: List[BaseSchedulerNode]) -> None: + super().__init__(scheduler) + init_group_node(self, scheduler, snodes) + self.users: List[NodeUser] = [] + self.group = max(snodes, key=lambda x: int(x.is_reduction())).group + + @cache_on_self + def get_name(self) -> str: + return "_".join([x.get_name() for x in self.snodes]) + + def get_first_name(self) -> str: + return self.snodes[0].get_name() + + @cache_on_self + def get_buffer_names(self) -> OrderedSet[str]: + return OrderedSet.union(*[x.get_buffer_names() for x in self.snodes]) + + def get_outputs(self) -> List[SchedulerBuffer]: + result: List[SchedulerBuffer] = [] + for node in self.snodes: + result.extend(node.get_outputs()) + return result + + def debug_str_extra(self) -> str: + lines = [ + f"{self.get_name()}.snodes[{i}] =\n{node.debug_str()}" + for i, node in enumerate(self.snodes) + ] + node = self.snodes[0].node + if node is not None: + device = node.get_device() + if ir.is_triton(device): + lines.extend(debug_triton_code(self)) + + return textwrap.indent("\n".join(lines).rstrip(), " ") + + def debug_str_short(self) -> str: + snodes_str = [node.debug_str_short() for node in self.snodes] + return f"{self}, snodes: {snodes_str}" + + def set_last_usage( + self, future_used_buffers: OrderedSet[str], mutation_real_name: Dict[str, str] + ) -> None: + # Set self.last_usage using the global information + # This will be used for inter-kernel optimisations + super().set_last_usage(future_used_buffers, mutation_real_name) + # Set self.last_usage on the snodes + # This will be used for optimisations within the kernel + future_used_buffers: OrderedSet[str] = OrderedSet() + for node in reversed(self.snodes): + node.set_last_usage(future_used_buffers, mutation_real_name) + future_used_buffers.update(node.last_usage) + + @cache_on_self + def used_buffer_names(self) -> OrderedSet[str]: + return OrderedSet.union(*[x.used_buffer_names() for x in self.snodes]) + + 
@cache_on_self + def used_or_aliased_buffer_names(self) -> OrderedSet[str]: + return OrderedSet.union( + *[x.used_or_aliased_buffer_names() for x in self.snodes] + ) + + def get_nodes(self) -> Sequence[BaseSchedulerNode]: + return self.snodes + + def __repr__(self) -> str: + return f"{type(self).__name__}(nodes={self.get_name()})" + + @cache_on_self + def is_reduction(self) -> bool: + return any(x.is_reduction() for x in self.snodes) + + @cache_on_self + def is_split_scan(self) -> bool: + return any(x.is_split_scan() for x in self.snodes) + + @cache_on_self + def is_template(self) -> bool: + return any(x.is_template() for x in self.snodes) + + @cache_on_self + def get_template_node(self) -> Optional[ir.TemplateBuffer]: + for node in self.snodes: + if node.is_template(): + return node.get_template_node() + return None + + def get_device(self) -> torch.device: + return self.group[0] + + @cache_on_self + def has_aliasing_or_mutation(self) -> bool: + return any(x.has_aliasing_or_mutation() for x in self.snodes) + + # None of these need to be implemented, as a FusedSchedulerNode is just an + # abstraction for scheduling purposes + def update_mutated_names(self, renames: Dict[str, str]) -> None: + raise NotImplementedError + + def add_fake_dep(self, name: Dep) -> None: + raise NotImplementedError + + def can_inplace(self, read_dep: dependencies.Dep) -> bool: + raise NotImplementedError + + def debug_str(self) -> str: + """Longer form printout for trace logs""" + name = self.get_name() + node_typestr = ",".join(type(n).__name__ for n in self.snodes) + buf = IndentedBuffer() + buf.splice( + f"""\ +{name}: {type(self).__name__}({node_typestr}) +{name}.writes = {pformat(self.read_writes.writes)} +{name}.unmet_dependencies = {pformat(self.unmet_dependencies)} +{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)} +{name}.outputs = [ + """ + ) + with buf.indent(): + for out in self.get_outputs(): + buf.splice(out.debug_str()) + buf.writeline("]") + + try: + buf.splice(self.debug_str_extra()) + except Exception: + log.warning("Ignoring error in debug_str()", exc_info=True) + + return buf.getrawvalue().rstrip() + + +class ForeachKernelSchedulerNode(FusedSchedulerNode): + """ + This is a schedular node that consists of a set of scheduler nodes that + has no data dependencies among them and can be executed in parallel. 
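+ Fusion involving a foreach node happens per sub-kernel: two foreach nodes of
+ equal length are fused element-wise, while a regular producer or consumer is
+ fused into the single matching subnode (see can_fuse / fuse below).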
+ """ + + def get_consumer_subnode_for( + self, producer: BaseSchedulerNode + ) -> Optional[BaseSchedulerNode]: + for buf in producer.get_outputs(): + if buf.get_name() in self.read_to_node: + return self.read_to_node[buf.get_name()] + + return None + + def get_producer_subnode_for( + self, consumer: BaseSchedulerNode + ) -> Optional[BaseSchedulerNode]: + producers = set() + for rd in consumer.read_writes.reads: + if rd.name not in self.scheduler.name_to_buf: + continue + + node_name = self.scheduler.name_to_buf[rd.name].defining_op.get_name() + if node_name in self.name_to_node: + producers.add(self.name_to_node[node_name]) + + # Don't permit fusion if there are multiple subnodes + # that this consumer reads from + if len(producers) == 1: + return next(iter(producers)) + else: + return None + + @classmethod + def can_fuse(cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode) -> bool: + why = WhyNoFuse(producer, consumer) + if producer.is_foreach() and consumer.is_foreach(): + producer = typing.cast(ForeachKernelSchedulerNode, producer) + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) + foreach_match = len(producer.snodes) == len(consumer.snodes) + if not foreach_match: + why("foreach do not have same length") + return foreach_match and all( + producer.scheduler.can_fuse(l, r) + for l, r in zip(producer.snodes, consumer.snodes) + ) + elif consumer.is_foreach(): + if producer.is_reduction(): + why( + "candidate producer is a reduction, foreach ops cannot be fused with reductions currently" + ) + return False + + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) + consumer_subnode = consumer.get_consumer_subnode_for(producer) + if consumer_subnode is not None: + return consumer.scheduler.can_fuse(producer, consumer_subnode) + + why("candidate producer is not dep of any foreach consumer") + return False + + elif producer.is_foreach(): + if consumer.is_reduction(): + why( + "candidate consumer is a reduction, foreach ops cannot be fused with reductions currently" + ) + return False + + producer = typing.cast(ForeachKernelSchedulerNode, producer) + producer_subnode = producer.get_producer_subnode_for(consumer) + if producer_subnode is not None: + return producer.scheduler.can_fuse(producer_subnode, consumer) + + why("candidate consumer has no dep in any foreach producer") + return False + + raise AssertionError( + "At least one node passed to ForeachKernelSchedulerNode.can_fuse should be a foreach node" + ) + + @classmethod + def fuse( + cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode + ) -> ForeachKernelSchedulerNode: + assert producer.is_foreach() or consumer.is_foreach() + if producer.is_foreach(): + producer = typing.cast(ForeachKernelSchedulerNode, producer) + use_custom_partition_algo = producer.use_custom_partition_algo + enable_autotune = producer.enable_autotune + else: + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) + use_custom_partition_algo = consumer.use_custom_partition_algo + enable_autotune = consumer.enable_autotune + prev_node_1 = None + prev_node_2 = None + fused_nodes: List[BaseSchedulerNode] + if producer.is_foreach() and consumer.is_foreach(): + producer = typing.cast(ForeachKernelSchedulerNode, producer) + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) + fused_nodes = [ + FusedSchedulerNode.fuse(l, r) + for l, r in zip(producer.snodes, consumer.snodes) + ] + elif producer.is_foreach(): + producer = typing.cast(ForeachKernelSchedulerNode, producer) + producer_subnode = 
producer.get_producer_subnode_for(consumer) + fused_nodes = [] + prev_node_1 = producer + prev_node_2 = None + for node in producer.snodes: + if node is producer_subnode: + new_node = FusedSchedulerNode.fuse(node, consumer) + prev_node_2 = new_node + fused_nodes.append(new_node) + else: + fused_nodes.append(node) + + elif consumer.is_foreach(): + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) + consumer_subnode = consumer.get_consumer_subnode_for(producer) + fused_nodes = [] + prev_node_1 = consumer + prev_node_2 = None + + for node in consumer.snodes: + if node is consumer_subnode: + new_node = FusedSchedulerNode.fuse(producer, node) + prev_node_2 = new_node + fused_nodes.append(new_node) + else: + fused_nodes.append(node) + else: + raise AssertionError( + "At least one node passed to ForeachKernelSchedulerNode.fuse should be a foreach node" + ) + + return cls( + producer.scheduler, + fused_nodes, + use_custom_partition_algo=use_custom_partition_algo, + prev_node_1=prev_node_1, + prev_node_2=prev_node_2, + enable_autotune=enable_autotune, + ) + + def __init__( + self, + scheduler: Scheduler, + snodes: List[BaseSchedulerNode], + use_custom_partition_algo: bool, + prev_node_1: Optional[BaseSchedulerNode] = None, + prev_node_2: Optional[BaseSchedulerNode] = None, + enable_autotune: bool = False, + ) -> None: + self.read_to_node = {} + self.name_to_node = {} + + if prev_node_1 is None or prev_node_2 is None: + super().__init__(scheduler, snodes) + + for node in snodes: + for read in node.read_writes.reads: + self.read_to_node[read.name] = node + + for name in node.get_operation_names(): + self.name_to_node[name] = node + else: + self.scheduler = scheduler + self.snodes = snodes + self.node = None + self.users: List[NodeUser] = [] + + self.set_read_writes( + dependencies.ReadWrites.merge_list( + [prev_node_1.read_writes, prev_node_2.read_writes] + ) + ) + + self.unmet_dependencies = ( + OrderedSet( + dep + for dep in OrderedSet.union( + prev_node_1.unmet_dependencies, prev_node_2.unmet_dependencies + ) + if dep.name not in self.get_buffer_names() + ) + - self.read_writes.writes + ) + + self.min_order = min([prev_node_1.min_order, prev_node_2.min_order]) + self.max_order = max([prev_node_1.max_order, prev_node_2.max_order]) + + if prev_node_1.is_foreach(): + assert isinstance(prev_node_1, ForeachKernelSchedulerNode) + foreach_node, other_node = prev_node_1, prev_node_2 + else: + assert isinstance(prev_node_2, ForeachKernelSchedulerNode) + foreach_node, other_node = prev_node_2, prev_node_1 + + self.ancestors = foreach_node.ancestors + self.ancestors.update(other_node.ancestors) + + self.name_to_node = foreach_node.name_to_node + for name in other_node.get_operation_names(): + self.name_to_node[name] = other_node + + self.use_custom_partition_algo = use_custom_partition_algo + self.group = (snodes[0].get_device(), ((sympy.Expr("combo_kernel"),),)) + self.origins: OrderedSet[torch.fx.Node] = OrderedSet() + self.enable_autotune = enable_autotune + + @classmethod + def combinable_nodes( + cls, nodes: List[BaseSchedulerNode] + ) -> List[BaseSchedulerNode]: + extern = [x for x in nodes if isinstance(x, ExternKernelSchedulerNode)] + if extern: + log.debug( + "ComboKernels: %d external nodes are filtered %s", + len(extern), + [node.node.get_origins() for node in extern if node.node is not None], + ) + filtered_nodes = [ + x + for x in nodes + if not isinstance(x, (NopKernelSchedulerNode, ExternKernelSchedulerNode)) + ] + foreach_nodes = [ + x for x in filtered_nodes if isinstance(x, 
ForeachKernelSchedulerNode) + ] + if foreach_nodes: + log.debug("ComboKernels: %d foreach nodes are filtered", len(foreach_nodes)) + filtered_nodes = [ + x for x in filtered_nodes if not isinstance(x, ForeachKernelSchedulerNode) + ] + template_nodes = [x for x in filtered_nodes if x.is_template()] + if template_nodes: + log.debug( + "ComboKernels: %d template nodes are filtered", {len(template_nodes)} + ) + filtered_nodes = [x for x in filtered_nodes if x not in template_nodes] + return filtered_nodes + + @staticmethod + def _default_group_nodes_for_combo_kernels( + scheduler: Scheduler, + ) -> List[List[BaseSchedulerNode]]: + """ + Returns a list of lists of nodes that are to be grouped together. + """ + sorted_nodes = scheduler._topological_sort_nodes() + grouped_nodes = [] + max_num_nodes = 8 + for nodes in sorted_nodes: + grouped_nodes.extend( + [ + nodes[i : i + max_num_nodes] + for i in range(0, len(nodes), max_num_nodes) + ] + ) + + return grouped_nodes + + group_algorithm_for_combo_kernels: Callable[ + [Scheduler], List[List[BaseSchedulerNode]] + ] = _default_group_nodes_for_combo_kernels + + @staticmethod + def set_group_algorithm_for_combo_kernels( + custom_group_algorithm: Callable[[Scheduler], List[List[BaseSchedulerNode]]] + ) -> None: + ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels = ( + custom_group_algorithm + ) + + @staticmethod + def group_nodes_for_combo_kernels( + scheduler: Scheduler, + ) -> List[List[BaseSchedulerNode]]: + return ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels(scheduler) + + def mark_run(self) -> None: + raise NotImplementedError + + def codegen(self) -> None: + assert isinstance(self.node, ir.ComputedBuffer), f"{type(self.node)=}" + self.node.get_store_function()(self.node.make_loader()()) + + def is_foreach(self) -> bool: + return True + + def get_subkernel_nodes(self) -> List[BaseSchedulerNode]: + """Returns a list of nodes which comprise the combo kernel. + These nodes may be vertically fused.""" + return list(self.snodes) + + def get_nodes(self) -> Sequence[BaseSchedulerNode]: + """Returns all nodes contained in this kernel, unpacking fused nodes + into their constituent scheduler nodes.""" + return list(itertools.chain.from_iterable(x.get_nodes() for x in self.snodes)) + + def get_first_name(self) -> str: + return self.snodes[0].get_first_name() + + def prune_redundant_deps( + self, name_to_fused_node: Dict[str, BaseSchedulerNode] + ) -> None: + _prune_redundant_deps(self, name_to_fused_node, self.scheduler.name_to_buf) + + for node in self.snodes: + node.prune_redundant_deps(name_to_fused_node) + + +class GroupedSchedulerNode(BaseSchedulerNode): + """ + This is a "fake" scheduler node that represents a group of scheduler nodes + that are meant to be *grouped* together (it does not allow another node to be scheduled + in between its constituent nodes, nor does it allow another node to fuse into any of its constituent nodes). + The way it does this is by maintaining its unmet dependencies as the union of its constituent nodes. + Fusion will still happen among the nodes within each GroupedSchedulerNode. + At codegen time, this scheduler node will be unpacked and codegen is called on each constituent node. 
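+ Typical flow: GroupedSchedulerNode.create(snodes) registers the group in the
+ scheduler's name_to_fused_node map; later, unpack() fuses among the
+ constituent nodes and hands them back to the scheduler as regular nodes.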
+ """ + + snodes: List[BaseSchedulerNode] + + @classmethod + def create(cls, snodes: List[BaseSchedulerNode]) -> GroupedSchedulerNode: + scheduler = snodes[0].scheduler + assert all(node.scheduler is scheduler for node in snodes) + grouped_snode = cls(scheduler, snodes) # type: ignore[arg-type] + for snode in snodes: + scheduler.name_to_fused_node[snode.get_name()] = grouped_snode + scheduler.name_to_fused_node[grouped_snode.get_name()] = grouped_snode + return grouped_snode + + def __init__(self, scheduler: Scheduler, snodes: List[BaseSchedulerNode]) -> None: + super().__init__(scheduler) + init_group_node(self, scheduler, snodes) + + def unpack(self) -> List[BaseSchedulerNode]: + """ + Do fusion among nodes within this GroupedSchedulerNode, + and then unpack this GroupedSchedulerNode into regular nodes. + """ + for snode in self.snodes: + self.scheduler.name_to_fused_node[snode.get_name()] = snode + del self.scheduler.name_to_fused_node[self.get_name()] + return self.scheduler.fuse_nodes(self.snodes) + + def add_fake_dep(self, fake_dep: Dep) -> None: + self.set_read_writes(self.read_writes.with_read(fake_dep)) + self.unmet_dependencies.add(fake_dep) + + @cache_on_self + def get_name(self) -> str: + return "_".join([x.get_name() for x in self.snodes]) + + def get_first_name(self) -> str: + return self.snodes[0].get_name() + + @cache_on_self + def get_buffer_names(self) -> OrderedSet[str]: + return OrderedSet.union(*[x.get_buffer_names() for x in self.snodes]) + + def get_outputs(self) -> List[SchedulerBuffer]: + result: List[SchedulerBuffer] = [] + for node in self.snodes: + result.extend(node.get_outputs()) + return result + + def get_nodes(self) -> Sequence[BaseSchedulerNode]: + return self.snodes + + @classmethod + def can_fuse(cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode) -> bool: + # GroupedSchedulerNode cannot be fused with another node + return False + + +def pick_loop_order( + stride_lengths: List[List[int]], + sizes: List[sympy.Expr], + priority_idx: Tuple[int, ...] = (), +) -> List[int]: + """ + A heuristic to decide loop iteration orders. This has not been well + tuned and may be something we should autotune. 
+ """ + + @functools.cmp_to_key + def index_cmp(a: int, b: int) -> int: + if sizes[a] == 1 or sizes[b] == 1: + # 1-sizes don't matter, just move them to the end + return cmp(sizes[a] == 1, sizes[b] == 1) + + # Take abs, otherwise flipped dimensions are treated as smaller + # strides than contiguous dims + stride_len_a = [abs(sl[a]) for sl in stride_lengths] + stride_len_b = [abs(sl[b]) for sl in stride_lengths] + + # equivalent to + # np.logical_or(stride_lengths[:, b] == 0, stride_lengths[:, a] < stride_lengths[:, b]).all() + a_first = sum( + sl_b == 0 or sl_a < sl_b for sl_a, sl_b in zip(stride_len_a, stride_len_b) + ) + b_first = sum( + sl_a == 0 or sl_b < sl_a for sl_a, sl_b in zip(stride_len_a, stride_len_b) + ) + if a_first > b_first: + return -1 + if b_first > a_first: + return 1 + + # otherwise contiguous + return cmp(b, a) + + order = list(reversed(range(len(stride_lengths[0])))) + if len(priority_idx) > 0: + # if we have priority node, only use that node's order + stride_lengths = [stride_lengths[pi] for pi in priority_idx] + if config.pick_loop_orders: + order.sort(key=index_cmp) + return order + + +@dataclasses.dataclass +class NodeUser: + node: Union[BaseSchedulerNode, OutputNode] + can_inplace: bool = False + + # A weak user must be scheduled after a given node, but doesn't actually + # use the result + is_weak: bool = False + + def __hash__(self) -> int: + return hash((self.node.get_name(), self.can_inplace, self.is_weak)) + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, NodeUser) + and self.get_name() == other.get_name() + and self.can_inplace == other.can_inplace + and self.is_weak == other.is_weak + ) + + def get_name(self) -> str: + return self.node.get_name() + + def merge(self, other: NodeUser) -> NodeUser: + assert self.node is other.node + return NodeUser( + self.node, + self.can_inplace and other.can_inplace, + self.is_weak and other.is_weak, + ) + + +_post_grad_graph_counter = itertools.count() + + +class Scheduler: + __dep_size_hint_cache: Dict[Dep, int] + + def __init__(self, nodes: List[ir.Operation]) -> None: + with dynamo_timed("Scheduler.__init__"): + self._init(nodes) + + def _init(self, nodes: List[ir.Operation]) -> None: + super().__init__() + self.__dep_size_hint_cache = {} + V.graph.scheduler = self + self.backends: Dict[torch.device, BaseScheduling] = {} + self.post_grad_graph_id = next(_post_grad_graph_counter) + + self.completed_operations: OrderedSet[str] = OrderedSet() + self.available_buffer_names = OrderedSet( + [ + *V.graph.graph_inputs.keys(), + *V.graph.constants.keys(), + *V.graph.torchbind_constants.keys(), + ] + ) + + self.nodes = [self.create_scheduler_node(n) for n in nodes] + self.update_zero_dim_cpu_tensor() + # some new constants could have been created above + self.available_buffer_names.update(V.graph.constants.keys()) + for node in self.nodes: + node.prune_deps() + + self.name_to_node: Dict[str, BaseSchedulerNode] = { + n.get_name(): n for n in self.nodes + } + self.name_to_buf: Dict[str, SchedulerBuffer] = { + buf.get_name(): buf for node in self.nodes for buf in node.get_outputs() + } + self.name_to_fused_node: Dict[str, BaseSchedulerNode] = self.name_to_node.copy() + + # mutation_real_name: Maps back to the original name for codegen + # Example: + # If you mutate buf0 inside of buf1's kernel, then: + # mutation_real_name = {"buf0" : "buf1"} + # all subsequent uses of buf0 become buf1's usage in dependency graph + self.mutation_real_name: Dict[str, str] = {} + + # We handle mutation by renaming modified 
versions of the same + # buffer in the dependency graph to prevent cycles. + # mutation_renames: tracks the current name for a given buffer + # (changed once per mutation) + # Example: + # If you mutate buf0 inside of buf1's kernel, then: + # mutation_renames = {"buf1" : "buf0"} + # in codegen we only use buf0, never buf1 + self.mutation_renames: Dict[str, str] = {} + + self.compute_dependencies() + self.nodes = self.topological_sort_schedule(self.nodes) + self.dead_node_elimination() + self.name_to_fused_node = {n.get_name(): n for n in self.nodes} + self.compute_ancestors() + if config.reorder_for_compute_comm_overlap: + self.nodes = comms.decide_global_ordering_of_comms( + self.nodes, + self.name_to_buf, + self.name_to_fused_node, + ) + + metrics.ir_nodes_pre_fusion += len(self.nodes) + V.debug.ir_pre_fusion(self.nodes) + self.num_orig_nodes = len(self.nodes) + self.create_foreach_nodes() + self.nodes = self.topological_sort_schedule(self.nodes) + self.logged_slow_fusion: OrderedSet[Tuple[str, str]] = OrderedSet() + if config._pre_fusion_custom_pass is not None: + self.nodes = config._pre_fusion_custom_pass(self.nodes) + self.nodes = self.fuse_nodes(self.nodes) + self.merge_loops() + self.finalize_multi_template_buffers() + if config.reorder_for_compute_comm_overlap: + self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes) + if config.combo_kernels: + self.create_combo_kernel_nodes(num_ck_nodes=None) + self.process_grouped_nodes() + self.compute_last_usage() + V.debug.ir_post_fusion(self.nodes) + V.debug.graph_diagram(self.nodes) + self.debug_draw_graph() + + # used during codegen: + self.current_device: Optional[torch.device] = None + self.buffer_names_to_free: OrderedSet[str] = OrderedSet() + + # fx graph node to the position it appears in the graph + # for debug attribution + self.origin_to_index: Dict[torch.fx.Node, int] = {} + + get_metric_table("graph_stats").add_row( + lambda: { + "graph_id": self.post_grad_graph_id, + "num_nodes_before_fusion": self.num_orig_nodes, + "num_nodes_after_fusion": len(self.nodes), + } + ) + + def get_current_device_or_throw(self) -> torch.device: + if device := self.current_device: + return device + else: + raise RuntimeError("No current device") + + def debug_draw_graph(self) -> None: + """Generate an image of the graph for debugging""" + if os.environ.get("INDUCTOR_WRITE_SCHEDULER_GRAPH", None) == "1": + from .debug import draw_buffers + + draw_buffers(self.nodes, print_graph=True) + + def debug_print_nodes(self, label: str) -> None: + if log.isEnabledFor(logging.INFO): + log.info("%s:", label) + for node in self.nodes: + node.log_details() + + def create_scheduler_node(self, node: ir.Operation) -> BaseSchedulerNode: + assert ( + node.get_origins() is not None + ), "All nodes passed to scheduling must have an origin" + if node.is_no_op(): + return NopKernelSchedulerNode(self, node) + elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)): + return SchedulerNode(self, node) + elif isinstance(node, ir.ExternKernel): + return ExternKernelSchedulerNode(self, node) + else: + raise NotImplementedError(node) + + def create_foreach_nodes(self) -> None: + removed_node_names: OrderedSet[str] = OrderedSet() + fe_nodes = [] + kept_node_names = self.name_to_fused_node.keys() + + for names in V.graph.lists.values(): + names = [ + name + for name in names + if name in kept_node_names + and not isinstance(self.name_to_node[name], NopKernelSchedulerNode) + ] + if not names: + # All nodes eliminated + continue + + removed_node_names.update(names) 
+ snodes = [self.name_to_node[name] for name in names] + + enable_autotune = config.combo_kernels_autotune > 1 + fe_node = ForeachKernelSchedulerNode( + self, + snodes, + use_custom_partition_algo=False, + enable_autotune=enable_autotune, + ) + + fe_nodes.append(fe_node) + + for name in names: + self.name_to_fused_node[name] = fe_node + + self.nodes = [ + node for node in self.nodes if node.get_name() not in removed_node_names + ] + list(fe_nodes) + + def compute_dependencies(self) -> None: + """ + Create dependency edges between nodes, handling aliasing and + mutation properly. + """ + + T = TypeVar("T") + + class DedupList(Generic[T]): + """ + This data structure behaves like a list except it makes sure the + elements remain unique. + Normally one could use a OrderedSet/dict for this purpose however + the list in question gets elements appended as it is being + iterated over which means that we need to keep the list + semantics. + """ + + def __init__( + self, + items: Optional[List[T]] = None, + membership: Optional[OrderedSet[T]] = None, + ) -> None: + self.items = items or [] + self.membership = membership or OrderedSet() + + def append(self, node_user: T) -> None: + if node_user in self.membership: + return + self.items.append(node_user) + self.membership.add(node_user) + + def __add__(self, other: DedupList[T]) -> DedupList[T]: + new_membership = OrderedSet.union(self.membership, other.membership) + new_items = self.items + [ + x for x in other.items if x not in self.membership + ] + return DedupList(new_items, new_membership) + + name_to_users: DefaultDict[str, DedupList[NodeUser]] = collections.defaultdict( + DedupList + ) + + # handle aliasing by using python aliasing in name_to_users + # if foo aliases bar then we will make name_to_users["foo"] point + # to the same python list as name_to_users["bar"] + for node in self.nodes: + for buf1 in node.get_outputs(): + buf1_name = buf1.get_name() + for buf2_name in buf1.get_aliases(): + if buf1_name in name_to_users and buf2_name in name_to_users: + # merge the two + list1 = name_to_users[buf1_name] + list2 = name_to_users[buf2_name] + combined = list1 + list2 + for key in name_to_users.keys(): + if ( + name_to_users[key] is list1 + or name_to_users[key] is list2 + ): + name_to_users[key] = combined + elif buf1_name in name_to_users: + name_to_users[buf2_name] = name_to_users[buf1_name] + else: + name_to_users[buf1_name] = name_to_users[buf2_name] + + def rename(n: str) -> str: + if n in self.mutation_renames: + return rename(self.mutation_renames[n]) + return n + + def add_user( + used_by_name: str, + user_node: Union[BaseSchedulerNode, OutputNode], + can_inplace: bool = False, + is_weak: bool = False, + ) -> None: + name_to_users[rename(used_by_name)].append( + NodeUser(user_node, can_inplace, is_weak) + ) + + unbacked_symbol_to_origin_node: Dict[sympy.Symbol, Optional[str]] = {} + + # NB: None means that the dependency is on an input. 
Don't actually + # generate a dependency because if we do, Inductor will start trying + # to free the unbacked int but that's pointless + for name, val in V.graph.graph_inputs.items(): + if isinstance(val, sympy.Expr): + for fs in val.free_symbols: + unbacked_symbol_to_origin_node[fs] = None + + for node in self.nodes: + log.debug("scheduling %s", node.node) + + # unbacked symbols don't follow ordinary buffer dependencies, so + # we track their def/uses separately + assert node.node is not None + unbacked_symbol_defs = sorted( + node.node.get_unbacked_symbol_defs(), key=lambda x: x.name + ) + for s in unbacked_symbol_defs: + assert isinstance(s, sympy.Symbol) + # Pick the first definer as canonical. There may be multiple + # because if a MultiOutputLayout buffer propagates an unbacked + # symint to multiple outputs, they will all claim to def it. + if s not in unbacked_symbol_to_origin_node: + unbacked_symbol_to_origin_node[s] = node.get_name() + + unbacked_symbol_uses = sorted( + node.node.get_unbacked_symbol_uses(), key=lambda x: x.name + ) + # if a kernel takes unbacked symints, register dependencies + for s in unbacked_symbol_uses: + assert ( + s in unbacked_symbol_to_origin_node + ), f"{s} not in {unbacked_symbol_to_origin_node}" + if (r := unbacked_symbol_to_origin_node[s]) is not None: + for buf in self.name_to_node[r].get_outputs(): + node.add_fake_dep(StarDep(buf.get_name())) + + if ( + len(node.read_writes.writes) == 1 + and (dep := next(iter(node.read_writes.writes))) + and isinstance(dep, MemoryDep) + ): + node_mode = dep.mode + else: + node_mode = None + + # Handle output mutations + for buf in node.get_outputs(): + # a node will mutate either 0 or 1 buffers + assert len(buf.get_mutations()) <= 1 + for alt_name in buf.get_mutations(): + alt_name = rename(alt_name) + # this node must run after the prior writer + add_user(alt_name, node) + node.add_fake_dep(StarDep(alt_name, mode=node_mode)) + for user in name_to_users[alt_name].items: + if user.get_name() == node.get_name(): + continue + + assert isinstance(user.node, BaseSchedulerNode) + for other_name in user.node.get_buffer_names(): + # this node must run after all prior readers + other_name = rename(other_name) + node.add_fake_dep( + WeakDep(other_name, mutating_buf=buf.get_name()) + ) + add_user(other_name, node, is_weak=True) + + # add normal non-mutation dependencies + for read in node.read_writes.reads: + if not isinstance(read, WeakDep): + add_user(read.name, node, node.can_inplace(read)) + + node.update_mutated_names(self.mutation_renames) + + # update our renaming scheme for the next iteration + for buf in node.get_outputs(): + for alt_name in buf.get_mutations(): + self.mutation_renames[rename(alt_name)] = buf.get_name() + self.mutation_renames[alt_name] = buf.get_name() + self.mutation_real_name[ + buf.get_name() + ] = self.mutation_real_name.get(alt_name, alt_name) + + # make sure outputs aren't dead-code-eliminated + for buf_name in V.graph.get_output_names(): + log.debug("scheduling output %s", buf_name) + add_user(buf_name, OutputNode(StarDep(buf_name))) + + # make sure unbacked symints aren't dead-code-eliminated + for out in V.graph.graph_outputs: + for s in out.get_unbacked_symbol_uses(): + assert ( + s in unbacked_symbol_to_origin_node + ), f"{s} not in {unbacked_symbol_to_origin_node.keys()}" + if r := unbacked_symbol_to_origin_node[s]: + for buf_name in self.name_to_node[r].get_buffer_names(): + log.debug( + "scheduling output %s for unbacked symint %s", buf_name, s + ) + add_user(buf_name, 
OutputNode(StarDep(buf_name))) + + # make sure input mutation isn't dead-code-eliminated + for name in self.mutation_renames: + if name in V.graph.graph_inputs: + add_user(name, OutputNode(StarDep(name))) + V.graph.mutated_inputs.add(name) + elif name in V.graph.constants: + # In AOTI, module parameters and buffers are not lifted as graph inputs + add_user(name, OutputNode(StarDep(name))) + + inp_names = { + name: index for index, name in enumerate(V.graph.graph_inputs.keys()) + } + V.graph.mutated_input_idxs = [ + inp_names[name] for name in V.graph.mutated_inputs + ] + + # copy users information onto the nodes + for node in self.nodes: + for buf in node.get_outputs(): + buf.set_users(name_to_users[buf.get_name()].items) + + def dead_node_elimination(self) -> None: + """ + Remove any nodes without users + """ + # self.nodes is in topological order, so by iterating in reverse order + # we have visited (and potentially removed) all users before visiting a + # given node. + updated_nodes = [] + for node in reversed(self.nodes): + + def can_eliminate_user(user: NodeUser) -> bool: + return user.is_weak or user.get_name() in V.graph.removed_operations + + active_buffers = False + for buf in node.get_outputs(): + can_eliminate = all(can_eliminate_user(u) for u in buf.users) + if can_eliminate: + log.debug("removed dead buffer: %s", buf.get_name()) + V.graph.removed_buffers.add(buf.get_name()) + else: + active_buffers = True + + can_eliminate = not node.has_side_effects() and not active_buffers + + if not can_eliminate: + updated_nodes.append(node) + else: + # dead code + log.debug("removed dead operation: %s", node.get_name()) + V.graph.removed_operations.add(node.get_name()) + + self.nodes = list(reversed(updated_nodes)) + + # Prune any WeakDeps no longer needed + for node in self.nodes: + node.prune_weak_deps() + + def topological_sort_schedule( + self, nodes: List[BaseSchedulerNode] + ) -> List[BaseSchedulerNode]: + """ + Ensure nodes is in topologically sorted order + """ + seen: OrderedSet[BaseSchedulerNode] = OrderedSet() + name_to_node: Dict[str, BaseSchedulerNode] = dict() + result: List[BaseSchedulerNode] = [] + + def visit(n: BaseSchedulerNode) -> None: + if n not in seen: + seen.add(n) + for dep in sorted(n.unmet_dependencies, key=lambda d: d.name): + # We only care about doing toposort within `nodes` + if dep.name not in name_to_node: + continue + visit(name_to_node[dep.name]) + result.append(n) + + for node in nodes: + for name in node.get_buffer_names(): + name_to_node[name] = node + for node in nodes: + visit(node) + return result + + def _get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> List[BaseSchedulerNode]: + unmet_deps = set() + if isinstance( + snode, + ( + SchedulerNode, + ExternKernelSchedulerNode, + NopKernelSchedulerNode, + FusedSchedulerNode, + ), + ): + for dep in snode.unmet_dependencies: + unmet_deps.add(dep.name) + else: + raise RuntimeError( + f"get_unmet_dep_nodes is not implemented for {type(snode)}." + ) + unmet_dep_ops = (self.name_to_buf[dep].defining_op for dep in unmet_deps) + return list({self.name_to_fused_node[n.get_name()] for n in unmet_dep_ops}) + + def _topological_sort_nodes(self) -> List[List[BaseSchedulerNode]]: + """ + Sort nodes by their topological order, return a list of node lists. 
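+ Each inner list is one "level": all of its nodes have their unmet dependencies
+ satisfied by earlier levels, so nodes within a level are mutually independent.
+ This levelled order is what the combo-kernel grouping consumes.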
+ """ + order = [] + nodes = dict.fromkeys(self.nodes, 0) + children: Dict[Any, Any] = {} + for node in self.nodes: + deps = self._get_unmet_dep_nodes(node) + nodes[node] = len(deps) + for dep in deps: + c = children.get(dep, []) + c.append(node) + children[dep] = c + + zero_deg_nodes = [n for n, v in nodes.items() if v == 0] + while zero_deg_nodes: + order.append(zero_deg_nodes) + for n in zero_deg_nodes: + for user in children.get(n, []): + nodes[user] -= 1 + nodes.pop(n) + zero_deg_nodes = [n for n, v in nodes.items() if v == 0] + assert not nodes, "Topological sort failed!" + return order + + def compute_ancestors(self) -> None: + """ + Populate each node.ancestors + """ + # note self.nodes is topologically sorted + name_to_ancestors: Dict[str, OrderedSet[str]] = {} + for node in self.nodes: + ancestors: OrderedSet[str] = OrderedSet() + for dep in node.unmet_dependencies: + dep_node_name = self.name_to_buf[dep.name].defining_op.get_name() + ancestors.add(dep_node_name) + ancestors |= name_to_ancestors[dep_node_name] + name_to_ancestors[node.get_name()] = ancestors + node.ancestors = ancestors + + for order, node in enumerate(self.nodes): + node.min_order = order + node.max_order = order + + def merge_loops(self) -> None: + for node in self.nodes: + if not config.loop_ordering_after_fusion: + continue + + # Even for CPU, if we are using the halide backend, we still need + # the merge loops steps below + if not isinstance(node, (SchedulerNode, FusedSchedulerNode)) or ( + node.get_device().type != "cuda" and config.cpu_backend != "halide" + ): + continue + for snode in node.get_nodes(): + # merge loops for the scheduler node + if not isinstance(snode, SchedulerNode) or snode.is_template(): + continue + + snode._body = snode._body.merge_loops() + snode._sizes = snode._body.sizes + + # merge_loops is called after loop reordering. + # We still need retain fake dependencies since codegen the + # estimated amount of memory access rely on them. + snode.refresh_dependencies(normalize=True) + + # Note that for CPU backend, merging loops will change + # snode.group. It's fine for Triton backend. + # But if we simplify update snode.group like this: + # group_fn = self.get_backend(snode.node.get_device()).group_fn + # snode.group = (snode.node.get_device(), group_fn(snode._sizes)) + # There is still an issue due to different snode in a + # FusedSchedulerNode having different merged loops. + # Skip CPU backend for now. + + def fuse_nodes(self, nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]: + """ + Combine eligible nodes into FusedSchedulerNodes. + """ + for i in range(10): + old_len = len(nodes) + fusion_log.debug( + "===== attempting fusion (%d/10): %d nodes =====", + i + 1, + old_len, + ) + nodes = self.fuse_nodes_once(nodes) + new_len = len(nodes) + fusion_log.debug( + "completed fusion round (%d/10): fused %d nodes into %d nodes\n", + i + 1, + old_len, + new_len, + ) + if new_len == old_len or new_len == 1: + fusion_log.debug("===== fusion complete (%d iterations) =====", i + 1) + break + return nodes + + def process_grouped_nodes(self) -> None: + """ + Unpack GroupedSchedulerNode into regular nodes. 
+ """ + new_nodes: List[BaseSchedulerNode] = [] + for node in self.nodes: + new_nodes.extend( + node.unpack() if isinstance(node, GroupedSchedulerNode) else [node] + ) + self.nodes = new_nodes + + def benchmark_fused_nodes( + self, nodes: Sequence[BaseSchedulerNode] + ) -> Tuple[float, str]: + """ + Benchmark fused list of nodes and return the execution time + in milliseconds on randomly generated inputs. + """ + assert len(nodes) > 0 + device = nodes[0].get_device() + self.current_device = device + backend = self.get_backend(device) + return backend.benchmark_fused_nodes(nodes) + + def finalize_multi_template_buffers(self) -> None: + def replace_operation_buffer( + orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer + ) -> None: + replaced_buf_name = new_node.get_name() + orig_buf_name = orig_node.get_name() + assert isinstance(orig_buf_name, str) and isinstance(replaced_buf_name, str) + + replaced_op_name = new_node.get_operation_name() + orig_op_name = orig_node.get_operation_name() + assert isinstance(orig_op_name, str) and isinstance(replaced_op_name, str) + + del V.graph.name_to_buffer[replaced_buf_name] + new_node.name = orig_buf_name + + del V.graph.name_to_op[replaced_op_name] + new_node.operation_name = orig_op_name + + orig = V.graph.buffers.index(orig_node) + V.graph.buffers.remove(new_node) + V.graph.buffers[orig] = new_node + V.graph.name_to_buffer[orig_buf_name] = new_node + + orig = V.graph.operations.index(orig_node) + V.graph.operations.remove(new_node) + V.graph.operations[orig] = new_node + V.graph.name_to_op[orig_op_name] = new_node + + for i, node in enumerate(self.nodes): + if isinstance(node, SchedulerNode) and isinstance( + node.node, ir.MultiTemplateBuffer + ): + multi_node = node.node + min_node_unfused, _ = multi_node.get_min_choice() + + if isinstance( + min_node_unfused, + torch._inductor.ir.TritonTemplateCallerBase, + ): + node.node.finalize_as_triton_caller(min_node_unfused) + continue + + out_tensorbox = min_node_unfused.output_node() + out_storage = out_tensorbox.data + assert isinstance(out_storage, ir.StorageBox) + out_buffer = out_storage.data + assert isinstance(out_buffer, ir.OperationBuffer) + + out_buffer.layout = multi_node.layout + replace_operation_buffer(multi_node, out_buffer) + new_scheduler_node = self.create_scheduler_node(out_buffer) + + self.nodes[i] = new_scheduler_node + self.name_to_node[node.get_name()] = new_scheduler_node + self.name_to_fused_node[node.get_name()] = new_scheduler_node + + for new_out, old_out in zip( + new_scheduler_node.get_outputs(), node.get_outputs() + ): + self.name_to_buf[old_out.get_name()] = new_out + new_out.users = old_out.users + + new_scheduler_node.min_order = node.min_order + new_scheduler_node.max_order = node.max_order + new_scheduler_node.last_usage = node.last_usage + + def _any_atomic_add(self, node_list: Sequence[BaseSchedulerNode]) -> bool: + return any( + hasattr(n.node, "data") + and n.node is not None + and hasattr(n.node.data, "scatter_mode") + and n.node.data.scatter_mode == "atomic_add" + for n in node_list + ) + + def speedup_by_fusion( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + If config.benchmark_fusion is False, always return True. + Otherwise, return True if fusion can brings speedup. 
+ """ + + is_multi_template = node1.is_template() and isinstance( + node1.get_template_node(), ir.MultiTemplateBuffer + ) + if not config.benchmark_fusion and not is_multi_template: + return True + + if ( + node1.is_template() + and not isinstance(node1.get_template_node(), ir.TritonTemplateBuffer) + or node1.is_foreach() + or node2.is_foreach() + ): + # TODO support benchmarking epilogue fusion + return True + + node_list_1 = node1.get_nodes() + device = node_list_1[0].get_device() + + # don't support benchmark fusion for CPU right now. + if device.type == "cpu": + return True + + node_list_2 = node2.get_nodes() + node_list_fused = list(itertools.chain(node_list_1, node_list_2)) + + # We can not accurately benchmark kernel using atomic_add + # due to how we generate random integer inputs. + # Skip benchmarking them by allowing fusion. + if self._any_atomic_add(node_list_fused): + return True + + from triton.compiler.errors import CompilationError + + why = WhyNoFuse(node1, node2) + + def log_fusion(ms_fused: float, ms1: float, ms2: float) -> None: + if fusion_log.isEnabledFor(logging.DEBUG): + if ms_fused < ms1 + ms2: + fusion_log.debug( + "can fuse (benchmark): fusing %s with %s cause %sx speedup", + node1.get_buffer_names(), + node2.get_buffer_names(), + green_text(f"{(ms1 + ms2) / ms_fused:.3f}"), + ) + else: + fusion_log.debug( + "cannot fuse (benchmark): fusing %s with %s cause %sx slowdown", + node1.get_buffer_names(), + node2.get_buffer_names(), + red_text(f"{ms_fused / (ms1 + ms2):.3f}"), + ) + + if isinstance(node1, SchedulerNode) and isinstance( + node1.node, ir.MultiTemplateBuffer + ): + multi_node = node1.node + choice_timings = multi_node.choice_timings + + _, ms1 = multi_node.get_min_choice() + ms2, path2 = self.benchmark_fused_nodes(node_list_2) + + min_ms_fused = float("inf") + ms_fused_choice = None + + triton_choices = 0 + + for choice, unfused_time in sorted( + choice_timings.items(), key=lambda x: x[1] + ): + if not isinstance(choice, torch._inductor.ir.TritonTemplateCallerBase): + continue + + if unfused_time >= ms1 + ms2: + break + + triton_choices += 1 + if triton_choices > config.max_epilogue_benchmarked_choices: + break + + # TODO - parallel compile triton templates + # TODO - should prune/skip choices that are not within certain % of best choice + with node1.node.swap_as_triton_caller(choice): + ms_fused, _ = self.benchmark_fused_nodes(node_list_fused) + + if ms_fused < min_ms_fused: + min_ms_fused = ms_fused + ms_fused_choice = choice + + log_fusion(min_ms_fused, ms1, ms2) + + # after we do a fusion, we finalize a triton template. 
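# ---------------------------------------------------------------------------
# Illustrative sketch (editorial, not part of the patched file above): the
# choice-selection loop above plus the decision just below boil down to "fuse
# only if the best benchmarked fused time beats the two kernels run
# separately".  All names and timings here are hypothetical stand-ins, not
# Inductor APIs.
def _pick_fused_choice_sketch(choice_timings_ms, benchmark_fused_ms, ms1, ms2,
                              max_choices=3):
    best_choice, best_ms = None, float("inf")
    for i, (choice, unfused_ms) in enumerate(
        sorted(choice_timings_ms.items(), key=lambda kv: kv[1])
    ):
        if unfused_ms >= ms1 + ms2 or i >= max_choices:
            break  # remaining choices cannot beat leaving the nodes unfused
        fused_ms = benchmark_fused_ms(choice)
        if fused_ms < best_ms:
            best_choice, best_ms = choice, fused_ms
    # fuse only when the winner is faster than running the kernels separately
    return (best_choice, best_ms) if best_ms < ms1 + ms2 else (None, ms1 + ms2)

# hypothetical timings: choice "a" fuses to 1.1ms, beating 0.9ms + 0.4ms unfused
assert _pick_fused_choice_sketch(
    {"a": 0.9, "b": 1.2}, lambda c: {"a": 1.1, "b": 1.5}[c], 0.9, 0.4
) == ("a", 1.1)
# --------------------------------------------------------------------- end sketch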
+ # TODO - could preserve multi template and choices for subsequent fusions + if min_ms_fused < (ms1 + ms2) and ms_fused_choice is not None: + node1.node.finalize_as_triton_caller(ms_fused_choice) + return True + else: + return False + else: + try: + ms1, path1 = self.benchmark_fused_nodes(node_list_1) + if math.isinf(ms1): + why("register spilling of the first kernel") + return False + ms2, path2 = self.benchmark_fused_nodes(node_list_2) + if math.isinf(ms2): + why("register spilling of the second kernel") + return False + ms_fused, path_fused = self.benchmark_fused_nodes(node_list_fused) + if math.isinf(ms_fused): + why("register spilling of the fused kernel") + return False + except CompilationError as e: + # workaround triton issue: https://github.com/openai/triton/issues/2151 + if "Loop-carried variable" in str(e): + return True # allow fusion + else: + raise + + log_fusion(ms_fused, ms1, ms2) + if ( + is_metric_table_enabled("slow_fusion") + and ms_fused >= ms1 + ms2 + and (path1, path2) not in self.logged_slow_fusion + ): + self.logged_slow_fusion.add((path1, path2)) + get_metric_table("slow_fusion").add_row( + lambda: { + "kernel1_path": path1, + "kernel1_latency": ms1, + "kernel2_path": path2, + "kernel2_latency": ms2, + "fused_kernel_path": path_fused, + "fused_kernel_latency": ms_fused, + "slow_down_ratio": ms_fused / (ms1 + ms2), + } + ) + return ms_fused < ms1 + ms2 + + def fuse_nodes_once( + self, nodes: List[BaseSchedulerNode] + ) -> List[BaseSchedulerNode]: + """ + Combine eligible nodes into FusedSchedulerNodes. + + This relies on two key functions to control the logic: + - self.can_fuse(): checks if a fusion is legal + - self.score_fusion(): assigns priority to a given fusion + """ + fused_nodes = OrderedSet(nodes) + if fusion_log.isEnabledFor(logging.DEBUG): + fusion_log.debug("fuse_nodes_once, candidates:") + for node in fused_nodes: + fusion_log.debug(" " + node.debug_str_short()) # noqa: G003 + for node1, node2 in self.get_possible_fusions(nodes): + node1 = self.name_to_fused_node[node1.get_first_name()] + node2 = self.name_to_fused_node[node2.get_first_name()] + if self.can_fuse(node1, node2) and not self.will_fusion_create_cycle( + node1, node2 + ): + if not self.speedup_by_fusion(node1, node2): + continue + fusion_log.debug( + "fusing %s with %s", node1.get_name(), node2.get_name() + ) + + # above can_fuse asserts that node2 has the same device + device = node1.get_device() + node3 = self.get_backend(device).fuse(node1, node2) + fused_nodes.remove(node1) + fused_nodes.remove(node2) + fused_nodes.add(node3) + self.name_to_fused_node.update( + {n.get_name(): node3 for n in node3.get_nodes()} + ) + nodes = sorted(fused_nodes, key=lambda x: x.min_order) + nodes = self.topological_sort_schedule(nodes) + self.prune_redundant_deps(nodes) + return nodes + + def create_combo_kernel_nodes(self, num_ck_nodes: Optional[int] = None) -> None: + """ + Groups parallel nodes + """ + fused_nodes = set(self.nodes) + count = 0 + num_nodes_orig = len(self.nodes) + log.debug("ComboKernels: Generating with num_ck_nodes = %d...", num_ck_nodes) + for num, node_list in enumerate( + ForeachKernelSchedulerNode.group_nodes_for_combo_kernels(self) + ): + node_list = ForeachKernelSchedulerNode.combinable_nodes(node_list) + if len(node_list) < 2: + continue + if num_ck_nodes is not None and count > num_ck_nodes: + break + if not self.speedup_by_combo_kernel(node_list): + log.debug("ComboKernels: Not speeding up %d-th group", num) + continue + count += 1 + enable_autotune = 
config.combo_kernels_autotune > 0 + group_snode = ForeachKernelSchedulerNode( + node_list[0].scheduler, + node_list, + use_custom_partition_algo=True, + enable_autotune=enable_autotune, + ) + log.info( + "ComboKernels: Combining %d nodes for %d-th group", + len(node_list), + num, + ) + for node in node_list: + fused_nodes.remove(node) + fused_nodes.add(group_snode) + self.name_to_fused_node.update( + {n.get_name(): group_snode for n in group_snode.get_nodes()} + ) + self.nodes = sorted(fused_nodes, key=lambda x: x.min_order) + self.nodes = self.topological_sort_schedule(self.nodes) + log.info( + "Generated ComboKernel nodes: %d ComboKernels, totally %d -> %d nodels", + count, + num_nodes_orig, + len(self.nodes), + ) + self.prune_redundant_deps(self.nodes) + + def prune_redundant_deps(self, nodes: List[BaseSchedulerNode]) -> None: + for node in nodes: + node.prune_redundant_deps(self.name_to_fused_node) + + def get_possible_fusions( + self, nodes: List[BaseSchedulerNode] + ) -> List[Tuple[BaseSchedulerNode, BaseSchedulerNode]]: + """ + Helper to find all legal fusion opportunities, sorted by self.score_fusion() + """ + possible_fusions = [] + seen: OrderedSet[Tuple[BaseSchedulerNode, BaseSchedulerNode]] = OrderedSet() + + def check_all_pairs(nodes: List[BaseSchedulerNode]) -> None: + for node1_index, node1 in enumerate(nodes): + for node2 in nodes[node1_index + 1 :]: + key = (node1, node2) + if key in seen: + continue + seen.add(key) + + if self.can_fuse(node1, node2): + possible_fusions.append(key) + elif (node2.is_template() or node2.is_foreach()) and self.can_fuse( + node2, node1 + ): + # foreach fusions and epilogue fusions are order dependent + possible_fusions.append((node2, node1)) + + buffer_names_grouping = collections.defaultdict(list) + for node in nodes: + for buf in node.used_buffer_names(): + buffer_names_grouping[buf].append(node) + for node_grouping in buffer_names_grouping.values(): + check_all_pairs(node_grouping) + + if config.aggressive_fusion: + group_grouping = collections.defaultdict(list) + for node in nodes: + group = getattr(node, "group", None) + if group: + group_grouping[group].append(node) + for node_grouping in group_grouping.values(): + check_all_pairs(node_grouping) + + possible_fusions = self.get_possible_fusions_with_highest_priority( + possible_fusions + ) + possible_fusions.sort(key=self.score_fusion_key, reverse=True) + fusion_log.debug("found %d possible fusions", len(possible_fusions)) + return possible_fusions + + def will_fusion_create_cycle( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + Finds whether there's a path from node1 to node2 (or vice-versa) + caused indirectly by other fusions. + """ + # since we are just returning boolean here, use slightly faster, unordered set + visited: Set[FusedSchedulerNode] = set() + + def found_path(node: BaseSchedulerNode) -> bool: + # only fused nodes can introduce new ancestors. + if isinstance(node, FusedSchedulerNode) and node not in visited: + visited.add(node) + if node.get_operation_names().issubset(combined_ancestors): + # All fusion outputs are in ancestors of node1 and node2, thus + # cannot introduce new path: + # + # 1. if output is neither descendent of node1 or node2, the + # output cannot introduce a path + # 2. due to [can_fuse]: if WLOG output is descendent of node1, it cannot be + # on path(node1->node2), hence it cannot be ancestor of node2 + # 3. 
due to [acyclic]: if WLOG output is descendent of node1, it cannot be + # ancestor of node1 + return False + else: + # continue DFS of new ancestors introduced by the fusion + return bool(combined_names & node.ancestors) or any( + found_path(self.name_to_fused_node[n]) + for n in node.ancestors - combined_ancestors + ) + return False + + # as above - use slightly faster, unordered set + combined_names = ( + node1.get_operation_names()._dict.keys() + | node2.get_operation_names()._dict.keys() + ) + combined_ancestors = ( + node1.ancestors._dict.keys() | node2.ancestors._dict.keys() + ) - combined_names + cycle = any(found_path(self.name_to_fused_node[n]) for n in combined_ancestors) + if cycle: + WhyNoFuse(node1, node2)("will create cycle") + return cycle + + def can_fusion_increase_peak_memory( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + This function prevents fusion for nodes that can increase memory + footprint. This problem is more common in horizontal fusion, where nodes + that are far apart in the original order get fused, lengthening the live + intervals of tensors. This is very evident in models with activation + checkpointing, where the recomputed nodes from different checkpointed + regions get fused and significantly increase the memory footprint. + + The current attempt is a quick, possibly hacky, heuristic to prevent the + fusion of nodes that are far away in the original order. + + A better but difficult to implement heurisitic would be to use live + intervals of the buffers, find region of peak pressure in the original + program and prevent fusion that crosses that peak region. We might need + special care or good approximation in this implementation, as fusion of + node changes live intervals, and re-computing live intervals and peak + memory after each fusion can introduce large compilation overhead. + """ + proximity_score = max( + abs(node1.min_order - node2.max_order), + abs(node2.min_order - node1.max_order), + ) + return proximity_score > 64 + + def decide_fusion_fail_reason( + self, + node1: BaseSchedulerNode, + node2: BaseSchedulerNode, + common_buf_names: Tuple[str, ...], + ) -> str: + """ + Try to decide reasons why fusion fail due to no shared memory even though + there are common buffers. + """ + reasons = {} + node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()} + node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()} + + for buf_name in common_buf_names: + buf = V.graph.get_buffer(buf_name) + lhs_dep = node1_name2dep[buf_name] + rhs_dep = node2_name2dep[buf_name] + + if lhs_dep.get_numel() != rhs_dep.get_numel(): + reasons[ + buf_name + ] = f"different numel: {lhs_dep.get_numel()} v.s. {rhs_dep.get_numel()}" + continue + + # same numel but different MemoryDep.size. Should be broadcasting + if sympy_product(lhs_dep.size) != sympy_product(rhs_dep.size): + reasons[buf_name] = "broadcast" + continue + + if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep): + reasons[ + buf_name + ] = f"not MemoryDep: {type(lhs_dep)} v.s. {type(rhs_dep)}" + continue + + lhs_off = lhs_dep.get_offset() + rhs_off = rhs_dep.get_offset() + if lhs_off != rhs_off: + # One example is in transformer, we use a concatenated linear layer + # to project Q/K/V and then split the result. The 3 splits will + # point to the same buffer with different offsets. + reasons[buf_name] = f"different offset: {lhs_off} v.s. 
{rhs_off}" + continue + + if ( + lhs_dep.normalize_with_stride_order() + == rhs_dep.normalize_with_stride_order() + ): + reasons[buf_name] = f"Mismatch loop orders: {lhs_dep} v.s. {rhs_dep}" + continue + + # Add more rules here + reasons[ + buf_name + ] = f"Unknown reason: {lhs_dep} v.s. {rhs_dep}. Layout: {buf.layout}" + + return str(reasons) + + def has_shared_data_after_reordering_loop( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + Right now just greedily reorder the loop of node1 to be compatible with node2, + but ideally we should have some heuristics to reorder the loop for node2 + to be compatibile with node1 if that's more efficient. + """ + + # TODO Don't do loop reordering for CPU for now. + # Should debug more why it does not work for CPU codegen + if not config.loop_ordering_after_fusion or any( + n.get_device().type == "cpu" for n in [node1, node2] + ): + return False + + node1_buffer_names = node1.read_writes.buffer_names() + node2_buffer_names = node2.read_writes.buffer_names() + # Fast path: no common buffers. + common_buffer_names = node1_buffer_names & node2_buffer_names + if not common_buffer_names: + return False + + node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()} + node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()} + + # Find the commons buffers that has different loop orders + candidates = [] + for buffer_name in common_buffer_names: + lhs_dep = node1_name2dep[buffer_name] + rhs_dep = node2_name2dep[buffer_name] + if ( + lhs_dep.normalize_with_stride_order() + == rhs_dep.normalize_with_stride_order() + ): + candidates.append( + ( + V.graph.sizevars.size_hint(lhs_dep.get_numel(), fallback=0), + lhs_dep, + rhs_dep, + ) + ) + + if len(candidates) == 0: + return False + + # Pick the largest buffer to guide the loop reordering + numel, lhs_dep, rhs_dep = sorted(candidates, reverse=True, key=lambda x: x[0])[ + 0 + ] + + if lhs_dep.num_vars != rhs_dep.num_vars: + # this can happen due to we don't merge loops. + # We can not do loop reordering in this case right now + # Simply returning true if the two Deps are the same after + # normalization (merging loops) + return lhs_dep.normalize() == rhs_dep.normalize() + + # Only reorder loops for pointwise for now + if not node1.is_reduction(): + node1.reorder_loops_by_dep_pair(lhs_dep, rhs_dep) + elif not node2.is_reduction(): + node2.reorder_loops_by_dep_pair(rhs_dep, lhs_dep) + else: + loop_ordering_log.debug( + "Don't reorder loops since both nodes are reductions: %s v.s. %s", + node1.get_name(), + node2.get_name(), + ) + + return self.score_fusion_memory(node1, node2) > 0 + + def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: + """ + Determine if it is possible to combine node1 and node2 into a + single fused node. 
+ """ + + if node1 is node2: + return False + + why = WhyNoFuse(node1, node2) + + if isinstance(node1, GroupedSchedulerNode) or isinstance( + node2, GroupedSchedulerNode + ): + why("grouped node must not be fused with other nodes") + return False + if ( + isinstance(node1, (ExternKernelSchedulerNode, NopKernelSchedulerNode)) + and not node1.is_template() + ): + why("node1 is extern or nop") + return False + if ( + isinstance(node2, (ExternKernelSchedulerNode, NopKernelSchedulerNode)) + and not node2.is_template() + ): + why("node2 is extern or nop") + return False + + if node2.get_operation_names() & node1.ancestors: + why("node1 must go before node2") + return False + + if node2.is_template(): + why("templates can only fuse epilogues") + return False + if node1.is_template() and ( + node2.has_aliasing_or_mutation() + or node2.is_reduction() + or not config.epilogue_fusion + ): + why("template epilogue not satisfied") + return False + + if ( + node1.get_buffer_names() | node2.get_buffer_names() + ) & V.graph.no_fuse_buffer_names: + why("fusion for buffer explicit disabled") + return False + + device = node1.get_device() + device2 = node2.get_device() + if device != device2: + why("device mismatch (%s vs %s)", device, device2) + return False + del device2 + + no_shared_data = self.score_fusion_memory(node1, node2) == 0 + if no_shared_data: + no_shared_data = not self.has_shared_data_after_reordering_loop( + node1, node2 + ) + + loop_ordering_log.debug( + "%s and %s has%s shared data", + node1.get_name(), + node2.get_name(), + " no" if no_shared_data else "", + ) + if no_shared_data and ( + not config.aggressive_fusion or node1.is_reduction() or node2.is_reduction() + ): + if is_metric_table_enabled("fusion_failure_due_to_indexing_mismatch"): + common_buf_names = ( + node1.read_writes.buffer_names() & node2.read_writes.buffer_names() + ) + if len(common_buf_names) > 0: + get_metric_table("fusion_failure_due_to_indexing_mismatch").add_row( + lambda: { + "pre_grad_graph_id": V.graph.graph_id, + "post_grad_graph_id": V.graph.post_grad_graph_id, + "node1_name": node1.get_name(), + "node2_name": node2.get_name(), + "node1_debug_str": write_text(node1.debug_str()), + "node2_debug_str": write_text(node2.debug_str()), + "common_buffer_names": list(common_buf_names), + "failure_reason": self.decide_fusion_fail_reason( + node1, node2, common_buf_names + ), + } + ) + + why("no shared data due to indexing mismatch") + return False + why("no shared data") + return False # heuristic not needed for correctness + + if ( + not node1.is_foreach() + and not node2.is_foreach() + and len(node1.get_nodes()) + len(node2.get_nodes()) > config.max_fusion_size + ): + why("exceeds max fusion") + return False # heuristic not needed for correctness + + if node1.get_operation_names() & node2.ancestors: + # node2 depends on node1 outputs + if not self.can_fuse_vertical(node1, node2): + return False + return self.get_backend(device).can_fuse_vertical(node1, node2) + else: # nodes don't depend on each other, but may have common reads + if self.can_fusion_increase_peak_memory(node1, node2): + why("will increase peak memory") + return False + return self.get_backend(device).can_fuse_horizontal(node1, node2) + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + Check if it is legal to fuse a consumer (node2) into a producer (node1). 
+ + We can fuse them if all the reads of node2 either match + corresponding writes in node1, or are written by nodes that can + be scheduled before the fusion of node1 and node2. + """ + node1_buf_names = node1.get_buffer_names() + node1_op_names = node1.get_operation_names() + computed_deps: OrderedSet[Dep] = OrderedSet() + why = WhyNoFuse(node1, node2) + + for cd in node1.read_writes.writes: + if not isinstance(cd, MemoryDep): + continue + for rd in node2.unmet_dependencies: + if self.fusable_read_and_write(rd, cd): + computed_deps.add(rd) + + for dep in node2.unmet_dependencies: + if isinstance(dep, WeakDep) and self.fusable_weak_dep(dep, node1, node2): + computed_deps.add(dep) + + remaining_deps = OrderedSet( + dep.name for dep in node2.unmet_dependencies - computed_deps + ) + if remaining_deps & node1_buf_names: + # MemoryDeps didn't match and read different locations of the same buffer. + # Examples here include: + # - MemoryDep("foo", x) != MemoryDep("foo", x + 1) + # - MemoryDep("foo", x) != StarDep("foo") + why("memory deps did not match") + return False + for name in remaining_deps: + op_name = self.name_to_buf[name].defining_op.get_name() + if node1_op_names & self.name_to_fused_node[op_name].ancestors: + why("intermediate nodes between node1 & node2") + return False + + return True + + def fusable_weak_dep( + self, weak_dep: WeakDep, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + if weak_dep.name not in node1.get_buffer_names(): + return False + + # A weak dep can be fused if and only if the fused operation acts inplace + # on the buffer being mutated. i.e. the same index is being read then mutated + mutating_writes = [ + write + for write in node2.read_writes.writes + if write.name == weak_dep.mutating_buf + ] + if len(mutating_writes) != 1: + return False + write = mutating_writes[0] + assert isinstance(write, MemoryDep) + + if free_symbol_is_type(write.index, SymT.TMP): + return False + + real_name = self.mutation_real_name[weak_dep.mutating_buf] + relevant_reads = [ + read for read in node1.read_writes.reads if read.name == real_name + ] + return all( + isinstance(read, MemoryDep) + and not free_symbol_is_type(read.index, SymT.TMP) + and read.index == write.index + and read.size == write.size + for read in relevant_reads + ) + + # StarDep doesn't match MemoryDep, different indices don't match + # However, broadcasting sometimes strips dimensions, and if that's the case + # we still can match unmet dep + # if there's indirect indexing, don't match it + def fusable_read_and_write(self, read: Dep, write: MemoryDep) -> bool: + if isinstance(read, MemoryDep): + if read.mode == write.mode and write.mode is not None: + return True + read_name = self.mutation_renames.get(read.name, read.name) + + if ( + read_name != write.name + or free_symbol_is_type(read.index, SymT.TMP) + or free_symbol_is_type(write.index, SymT.TMP) + ): + return False + + if config.loop_ordering_after_fusion and read.num_vars != write.num_vars: + # Need merge loops if we do loop ordering after fusion since + # we have not merged the loops yet when creating the scheduler + # nodes. 
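# ---------------------------------------------------------------------------
# Illustrative sketch (editorial, not part of the patched file): the MemoryDep
# matching rule this function implements reduces to "same index expression, and
# the producer's store sizes are a prefix of the consumer's read sizes" (the
# read may carry extra trailing dims that broadcasting strips).  The tuples and
# strings below are hypothetical stand-ins for real dependency objects.
def _reads_match_write_sketch(read_index, read_size, write_index, write_size):
    return (
        read_index == write_index
        and len(read_size) >= len(write_size)
        and tuple(read_size[: len(write_size)]) == tuple(write_size)
    )

# hypothetical deps: identical index, read has one extra broadcast dimension
assert _reads_match_write_sketch("x0", (8, 16, 1), "x0", (8, 16)) is True
assert _reads_match_write_sketch("x0", (8, 16), "x0 + 1", (8, 16)) is False
# --------------------------------------------------------------------- end sketch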
+ read = read.normalize() + write = write.normalize() + + return ( + read.index == write.index + and len(read.size) >= len(write.size) + and read.size[: len(write.size)] == write.size + ) + elif isinstance(read, StarDep): + read_name = self.mutation_renames.get(read.name, read.name) + write_name = self.mutation_renames.get(write.name, write.name) + if ( + read.mode == write.mode + and write.mode is not None + and read_name == write_name + ): + return True + return False + + def score_fusion( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> Tuple[bool, bool, int, int]: + """ + Assign a score (higher comes first) to the fusion of node1 + and node2. When different fusions conflict with each other, + this is the way we decide what order to run them in. + + Our current score is based on: + - Estimate of the saved memory operations + - Fusions closer together in original order + """ + memory_score = self.score_fusion_memory(node1, node2) + proximity_score = -max( + abs(node1.min_order - node2.max_order), + abs(node2.min_order - node1.max_order), + ) + return ( + node1.is_template() == config.epilogue_fusion_first and memory_score > 0, + node1.is_reduction() == node2.is_reduction() and memory_score > 0, + memory_score, + proximity_score, + ) + + def dep_size_hint(self, dep: Dep) -> int: + res = 0 + if dep not in self.__dep_size_hint_cache: + try: + if not dep.has_unbacked_symbols(): + res = dep.numbytes_hint() + except KeyError: + # In at least one test (test/inductor/test_torchbind.py) we + # create a StarDep that doesn't exist in the graph and calling + # `has_unbacked_symbols()` throws an error. + pass + self.__dep_size_hint_cache[dep] = res + else: + res = self.__dep_size_hint_cache[dep] + return res + + def score_fusion_memory( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> int: + """ + The first term in our fusion score that estimates number of saved + memory operations. + """ + node1_dep_len = len(node1.read_writes.reads) + len(node1.read_writes.writes) + node2_dep_len = len(node1.read_writes.reads) + len(node2.read_writes.writes) + + # optimization: iter over smaller set + if max(node1_dep_len, node2_dep_len) * 4 > min(node1_dep_len, node2_dep_len): + if node1_dep_len > node2_dep_len: + tmp = node1 + node1 = node2 + node2 = tmp + + deps = [] + for dep in node1.read_writes.reads | node1.read_writes.writes: + if dep in node2.read_writes.reads or dep in node2.read_writes.writes: + deps.append(dep) + + return sum(self.dep_size_hint(dep) for dep in deps) + + common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & ( + node2.read_writes.reads | node2.read_writes.writes + ) + return sum(self.dep_size_hint(dep) for dep in common_memory_deps) + + def get_possible_fusions_with_highest_priority( + self, possible_fusions: List[Tuple[BaseSchedulerNode, BaseSchedulerNode]] + ) -> List[Tuple[BaseSchedulerNode, BaseSchedulerNode]]: + # Group the possible fusions based on their priority from the backend. + # Only return the group of possible fusions with highest priority. 
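# ---------------------------------------------------------------------------
# Illustrative sketch (editorial, not part of the patched file): the grouping
# the method body below performs, i.e. bucket candidate pairs by backend
# priority and keep only the lowest-valued bucket (smaller priority value means
# higher priority).  `priority_of` is a hypothetical stand-in for
# get_fusion_pair_priority.
from collections import defaultdict

def _keep_highest_priority_sketch(pairs, priority_of):
    if not pairs:
        return pairs
    groups = defaultdict(list)
    for pair in pairs:
        groups[priority_of(pair)].append(pair)
    return groups[min(groups)]

# hypothetical priorities: pairs "a" and "c" share priority 0 and survive
assert _keep_highest_priority_sketch(
    ["a", "b", "c"], {"a": 0, "b": 1, "c": 0}.get
) == ["a", "c"]
# --------------------------------------------------------------------- end sketch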
+ if len(possible_fusions) == 0: + return possible_fusions + possible_fusions_group_by_priority: Dict[ + int, List[Tuple[BaseSchedulerNode, BaseSchedulerNode]] + ] = {} + + for node1, node2 in possible_fusions: + assert node1.get_device() == node2.get_device() + device = node1.get_device() + fusion_pair_priority = int( + self.get_backend(device).get_fusion_pair_priority(node1, node2) + ) + if fusion_pair_priority not in possible_fusions_group_by_priority: + possible_fusions_group_by_priority[fusion_pair_priority] = [ + (node1, node2), + ] + else: + possible_fusions_group_by_priority[fusion_pair_priority].append( + (node1, node2) + ) + # return the possible fusions with highest priority + possible_fusions_with_highest_priority = min( + possible_fusions_group_by_priority.items(), key=operator.itemgetter(0) + )[1] + assert len(possible_fusions_with_highest_priority) > 0 + return possible_fusions_with_highest_priority + + def score_fusion_key( + self, nodes: Tuple[BaseSchedulerNode, BaseSchedulerNode] + ) -> Tuple[bool, bool, int, int]: + """ + Shim for list.sort(key=...) + """ + node1, node2 = nodes + return self.score_fusion(node1, node2) + + def compute_last_usage(self) -> None: + """ + Populate node.last_usage recursively (also for the nodes within a FusedSchedulerNode) + """ + + future_used_buffers: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) + + for node in reversed(self.nodes): + node.set_last_usage(future_used_buffers, self.mutation_real_name) + future_used_buffers.update(node.last_usage) + + def free_buffers(self) -> None: + """Free any buffers that are no longer needed""" + for name in sorted( + self.buffer_names_to_free + - V.graph.removed_buffers + - V.graph.wrapper_code.freed + ): + if name in self.name_to_buf: + buf = self.name_to_buf[name] + if buf.can_free(): + V.graph.wrapper_code.codegen_free(buf.node) + elif name in V.graph.graph_inputs: + storage = V.graph.graph_inputs[name].data + assert isinstance(storage, ir.StorageBox) and storage.is_input_buffer() + V.graph.wrapper_code.codegen_free(storage.data) + + self.buffer_names_to_free.clear() + + def remove_kernel_local_buffers(self) -> None: + """ + Any buffers that are both created and have a last use in the + same kernel can be removed. 
+ """ + + fused_node_names = OrderedSet( + self.name_to_buf[buf].defining_op.get_name() + for buf in V.kernel.store_buffer_names + if buf in self.name_to_buf + ) + names_to_remove = [] + for out_buf in V.kernel.store_buffer_names: + if out_buf not in self.name_to_buf: + # Aux buffers created during kernel codegen + names_to_remove.append(out_buf) + continue + users = self.name_to_buf[out_buf].users + assert users is not None + users = OrderedSet(user.get_name() for user in users if not user.is_weak) + if users.issubset(fused_node_names): + names_to_remove.append(out_buf) + + def remove_filter(n: str) -> bool: + return ( + n not in V.kernel.must_keep_buffers + and n not in V.kernel.args.input_buffers + and n not in self.mutation_renames + and n not in self.mutation_real_name + ) + + names_to_remove = list(filter(remove_filter, names_to_remove)) + + for name in names_to_remove: + if name in V.kernel.args.inplace_buffers: + buf = V.kernel.args.inplace_buffers[name] + if isinstance(buf, str) and buf.startswith("REMOVED"): + continue + remove = all(n in names_to_remove for n in buf.other_names) + if remove: + self.remove_inplace_buffer(name) + V.kernel.inplaced_to_remove.add(name) + else: + self.remove_buffer(name) + + def remove_buffer(self, name: str) -> None: + # Assign a special value instead of deleting the entry + # because we still rely on output_buffers's length to + # generate unique arg name. + log.debug("remove_buffer(%r)", name) + V.kernel.args.output_buffers[name] = "REMOVED" + V.kernel.removed_buffers.add(name) + + def remove_inplace_buffer(self, name: str) -> None: + log.debug("removing_inplace_buffer(%r)", name) + inner_name = V.kernel.args.inplace_buffers[name].inner_name + V.kernel.args.inplace_buffers[name] = inner_name.replace( + "in_out_ptr", "REMOVED" + ) + V.kernel.removed_buffers.add(name) + + def flush(self) -> None: + for backend in self.backends.values(): + backend.flush() + self.free_buffers() + + def codegen_extern_call(self, scheduler_node: ExternKernelSchedulerNode) -> None: + assert isinstance(scheduler_node, ExternKernelSchedulerNode) + # 'decide_inplace_update' stores the inplace update decisions in + # the current kernel from where 'allocate' retrieve those decisions. + # We have to make sure there is a non-NULL kernel handler to store + # those inplace update decisions. + counters["inductor"]["extern_calls"] += 1 + with V.set_kernel_handler(Kernel(increase_kernel_count=False)): + scheduler_node.decide_inplace_update() + scheduler_node.mark_run() + node = scheduler_node.node + assert isinstance(node, ir.ExternKernel), f"{type(node)=}" + node.codegen(V.graph.wrapper_code) + self.free_buffers() + + def create_backend(self, device: torch.device) -> BaseScheduling: + assert ( + not is_gpu(device.type) or device.index is not None + ), f"{device} should have been normalized in lowering" + V.graph.add_device_info(device) + + device_scheduling = get_scheduling_for_device(device.type) + if device_scheduling is None: + raise RuntimeError(f"Unsupported device type: {device.type}") + + if not has_triton(): + if ( + device.type == "cuda" + and (device_props := torch.cuda.get_device_properties(device)).major < 7 + ): + raise RuntimeError( + f"Found {device_props.name} which is too old to be supported by the triton GPU compiler, which is used as the backend. 
Triton only supports devices of CUDA Capability >= 7.0, but your device is of CUDA capability {device_props.major}.{device_props.minor}" # noqa: B950 + ) + elif is_gpu(device.type): + raise RuntimeError( + "Cannot find a working triton installation. Either the package is not installed or it is too old. More information on installing Triton can be found at https://github.com/openai/triton" # noqa: B950 + ) + + return device_scheduling(self) + + def get_backend(self, device: torch.device) -> BaseScheduling: + if device not in self.backends: + self.backends[device] = self.create_backend(device) + return self.backends[device] + + def enter_context(self, node: BaseSchedulerNode) -> None: + def get_order(n: torch.fx.Node) -> int: + if n not in self.origin_to_index: + self.origin_to_index.update({n: i for i, n in enumerate(n.graph.nodes)}) + return self.origin_to_index[n] + + # Use a dict to have ordering + origins = { + (get_order(e), e): None + for n in node.get_nodes() + if n.node is not None + for e in n.node.get_origins() + } + origins = list(origins.keys()) + if origins: + _, last = max(origins, key=operator.itemgetter(0)) + V.graph.wrapper_code.enter_context(last) + + def codegen(self) -> None: + with dynamo_timed("Scheduler.codegen"): + return self._codegen() + + def _codegen(self) -> None: + if config.check_stack_no_cycles_TESTING_ONLY: + import torch._dynamo.convert_frame + + stack = traceback.extract_stack() + seen = set() + for frame in reversed(stack): + # This is where maybe_cprofile is + if ( + frame.name == "_compile_inner" + and frame.filename == torch._dynamo.convert_frame.__file__ + ): + break + key = (frame.filename, frame.lineno) + assert key not in seen, ( + f"Duplicate stack frame {frame.filename}:{frame.lineno}; " + "did you add a decorator to one of the functions in this stack " + "trace? If so, try using a context manager instead." 
+ ) + seen.add(key) + + for node in self.nodes: + try: + log.debug( + "Generating code for node %s with estimated runtime %f", + node.get_name(), + node.get_estimated_runtime(), + ) + except Exception as e: + log.debug( + "Generating code for node %s with estimated runtime 0.0", + node.get_name(), + ) + + self.enter_context(node) + + if not isinstance(node, NopKernelSchedulerNode) and ( + device := node.get_device() + ): + if ( + device != self.current_device + or node.is_extern() + or node.is_template() + ): + self.flush() + if device != self.current_device: + if self.current_device and device_need_guard( + self.current_device.type + ): + V.graph.wrapper_code.codegen_device_guard_exit() + if device_need_guard(device.type): + assert device.index is not None, "device should have an index" + V.graph.wrapper_code.codegen_device_guard_enter(device.index) + + self.current_device = device + + self.buffer_names_to_free.update(node.last_usage) + + if node.is_template(): + node, *epilogue = node.get_nodes() + self.get_backend(device).codegen_template(node, epilogue) + elif node.is_extern(): + node = typing.cast(ExternKernelSchedulerNode, node) + self.codegen_extern_call(node) + elif node.is_foreach(): + node = typing.cast(ForeachKernelSchedulerNode, node) + backend_ = self.get_backend(device) + from .codegen.cuda_combined_scheduling import CUDACombinedScheduling + from .codegen.simd import SIMDScheduling + + if isinstance(backend_, (SIMDScheduling, CUDACombinedScheduling)): + backend = backend_ + else: + raise AssertionError(f"{type(self)=}") + backend.codegen_combo_kernel(node) + elif isinstance(node, (FusedSchedulerNode, SchedulerNode)): + self.get_backend(device).codegen_node(node) + else: + assert isinstance(node, NopKernelSchedulerNode) + node.mark_run() + + if config.triton.debug_sync_kernel: + self.get_backend(device).codegen_sync() + + self.available_buffer_names.update(node.get_buffer_names()) + self.completed_operations.update(node.get_operation_names()) + + if not isinstance(node, NopKernelSchedulerNode): + device = node.get_device() + if device is not None and self.get_backend(device).ready_to_flush(): + self.flush() + + if self.current_device and device_need_guard(self.current_device.type): + # exit the outermost CUDA device guard. this is + # important for nested indentation codegen-ing. + V.graph.wrapper_code.codegen_device_guard_exit() + + self.flush() + + def benchmark_combo_kernel( + self, node_list: Sequence[BaseSchedulerNode] + ) -> Tuple[float, float, str]: + """ + Benchmark fused list of nodes and return the execution time + in milliseconds on randomly generated inputs. + """ + device = node_list[0].get_device() + V.graph.scheduler = self + self.current_device = device + backend = self.get_backend(device) + return backend.benchmark_combo_kernel(node_list) + + def speedup_by_combo_kernel(self, nodes: List[BaseSchedulerNode]) -> bool: + """ + If config.benchmark_fusion is False, always return True. + Otherwise, return True if fusion can brings speedup. + """ + if not config.benchmark_combo_kernel: + return True + + subkernel_nodes = nodes + device = subkernel_nodes[0].get_device() + + # don't support benchmark fusion for CPU right now. + if device.type == "cpu": + return True + + from triton.compiler.errors import CompilationError + + ms1, path1_list = 0.0, [] + for i, snode in enumerate(subkernel_nodes): + node_list = snode.get_nodes() + # We can not accurately benchmark kernel using atomic_add + # due to how we generate random integer inputs. 
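# ---------------------------------------------------------------------------
# Illustrative sketch (editorial, not part of the patched file): the accept
# test this benchmarking builds toward later in the function compares the
# combo kernel's compute time (its benchmark minus the input-clone overhead)
# against the summed per-subkernel times, and always accepts very small
# kernels because they are hard to time reliably.  Numbers are hypothetical.
def _accept_combo_kernel_sketch(ms_subkernels, ms_combo, ms_combo_clone,
                                small_kernel_threshold_ms=0.3):
    combo_compute = ms_combo - ms_combo_clone
    small_kernel = (combo_compute < small_kernel_threshold_ms
                    or ms_subkernels < small_kernel_threshold_ms)
    return combo_compute < ms_subkernels or small_kernel

# hypothetical timings: 1.0ms of separate subkernels vs 0.8ms combo compute
assert _accept_combo_kernel_sketch(1.0, 1.1, 0.3) is True
# --------------------------------------------------------------------- end sketch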
+ if self._any_atomic_add(node_list): + fusion_log.debug( + "ComboKernel: benchmarking may not accurate due to atomic_add" + ) + + try: + ms, path = self.benchmark_fused_nodes(node_list) + if math.isinf(ms): + fusion_log.debug( + "ComboKernel benchmark: register spilling of %d-th subkernel", + i, + ) + return False + except CompilationError as e: + # workaround triton issue: https://github.com/openai/triton/issues/2151 + if "Loop-carried variable" in str(e): + fusion_log.debug( + "ComboKernel benchmark: return True because of loop-carried variable" + ) + return True # allow fusion + else: + raise + ms1 += ms + path1_list.append(path) + + try: + ms2, ms2_clone, path2_list = self.benchmark_combo_kernel(subkernel_nodes) + except CompilationError as e: + # workaround triton issue: https://github.com/openai/triton/issues/2151 + if "Loop-carried variable" in str(e): + fusion_log.debug( + "ComboKernel benchmark: return True because of loop-carried variable" + ) + return True # allow fusion + else: + raise + + # small kernels are very likely to have speedup but hard to benchmark. So we skip benchmarking. + small_kernel = ms2 - ms2_clone < 0.3 or ms1 < 0.3 + if fusion_log.isEnabledFor(logging.DEBUG): + if ms1 > ms2 or small_kernel: + fusion_log.debug( + "can fuse (benchmark): fusing causes %sx speedup", + green_text(f"{ms1 / ms2:.3f}"), + ) + else: + fusion_log.debug( + "cannot fuse (benchmark): fusing causes %sx slowdown", + red_text(f"{ms1 / ms2:.3f}"), + ) + # ms1 returned by benchmark_fused_nodes discounted clone time + return ms2 - ms2_clone < ms1 or small_kernel + + def get_buffer_layout(self, buf_name: str) -> ir.Layout: + buf = self.name_to_buf[buf_name] + assert buf.node is not None + return buf.node.get_layout() + + def update_zero_dim_cpu_tensor(self) -> None: + for node in self.nodes: + if node.get_device() and is_gpu(node.get_device().type): + for read in node.read_writes.reads: + buffer = V.graph.name_to_buffer.get(read.name) + if ( + buffer + and buffer.get_device() + and buffer.get_device().type == "cpu" + and not isinstance(buffer.layout, MultiOutputLayout) + and buffer.get_size() == [] + ): + V.graph.zero_dim_cpu_tensor_list.add(read.name) + + +class BaseScheduling: + @classmethod + def get_backend_features(cls, device: torch.device) -> Sequence[BackendFeature]: + """Return a set of .codegen.common.BackendFeature()""" + return () + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + Check whether node1 and node2 can be vertically fused or not. + """ + raise NotImplementedError + + def can_fuse_horizontal( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + Check whether node1 and node2 can be horizontally fused or not. + """ + raise NotImplementedError + + def fuse( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> FusedSchedulerNode: + """ + Fuse two nodes + """ + if node1.is_foreach() or node2.is_foreach(): + return ForeachKernelSchedulerNode.fuse(node1, node2) + else: + return FusedSchedulerNode.fuse(node1, node2) + + def group_fn( + self, sizes: Sequence[Sequence[sympy.Expr]] + ) -> Tuple[Tuple[sympy.Expr, ...], ...]: + """ + Process the iteration sizes in case a transformation needs to be applied. + """ + raise NotImplementedError + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + ) -> Optional[str]: + """ + Given a template node, generate a kernel. + + This function is only available for triton now. 
If the third-party backend behaves as a sub-class + of TritonScheduling, it can override it or reuse it. + """ + raise NotImplementedError + + def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]) -> None: + """ + Generate a kernel given a list of pre-fused nodes. + """ + raise NotImplementedError + + def codegen_sync(self) -> None: + """ + Generate synchronization code for the kernel. This method depends on the hardware characteristics. + """ + raise NotImplementedError + + def ready_to_flush(self) -> bool: + """ + Check whether the backend is requesting the scheduler to flush the generated kernel. + If not supported, please return False. + """ + return False + + def flush(self) -> None: + """ + Flush the generated kernel and python wrapper code to the source code file. + """ + raise NotImplementedError + + def benchmark_fused_nodes( + self, nodes: Sequence[BaseSchedulerNode] + ) -> Tuple[float, str]: + """ + Benchmark fused list of nodes and return the execution time + in milliseconds on randomly generated inputs. + """ + raise NotImplementedError + + def get_fusion_pair_priority( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> int: + """ + Return an unsigned integer which represents the priority of this fusion pair. + The smaller is with higher priority. + """ + return 0 + + def benchmark_combo_kernel( + self, node_list: Sequence[BaseSchedulerNode] + ) -> Tuple[float, float, str]: + """ + Benchmark the list of nodes to combine and return the execution time + and memory copy time in milliseconds on randomly generated inputs. + """ + raise NotImplementedError + + +def debug_triton_code(node: Union[SchedulerNode, FusedSchedulerNode]) -> List[str]: + lines = [] + multi_template = node.get_template_node() + assert multi_template is None or isinstance(multi_template, ir.MultiTemplateBuffer) + if multi_template and multi_template.make_kernel_render is None: + lines.append(f"{node.get_name()} Unfinalized multi template buffer") + else: + from torch._inductor.codegen.cuda_combined_scheduling import ( + CUDACombinedScheduling, + ) + + from .codegen.simd import SIMDScheduling + + snodes = (node,) if isinstance(node, SchedulerNode) else node.snodes + device = snodes[0].get_device() + backend = node.scheduler.get_backend(device) + assert isinstance(backend, (SIMDScheduling, CUDACombinedScheduling)) + V.graph.scheduler.current_device = device + + # Don't increment kernel count when generating debug string. + # This will confuse some unit tests that check the number of + # generated kernels. 
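# ---------------------------------------------------------------------------
# Illustrative sketch (editorial, not part of the patched file): the rough
# shape of a backend filling in the BaseScheduling interface above.  The class
# below is a toy stand-in, not how real backends register with Inductor, and
# the fusion policies shown are placeholders.
class _ToyBackendScheduling:
    def __init__(self, scheduler):
        self.scheduler = scheduler

    def can_fuse_vertical(self, node1, node2):
        # producer/consumer fusion: a conservative placeholder policy
        return node1.get_device() == node2.get_device()

    def can_fuse_horizontal(self, node1, node2):
        # independent nodes that only share reads: decline by default
        return False

    def group_fn(self, sizes):
        return tuple(tuple(s) for s in sizes)

    def codegen_node(self, node):
        raise NotImplementedError("emit a kernel for `node` here")

    def ready_to_flush(self):
        return False

    def flush(self):
        pass
# --------------------------------------------------------------------- end sketch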
+ old_generated_kernel_count = metrics.generated_kernel_count + triton_code = backend.generate_kernel_code_from_nodes(snodes).strip() + metrics.generated_kernel_count = old_generated_kernel_count + + lines.append(f"{node.get_name()} Triton code:") + lines.append(textwrap.indent(triton_code, " ")) + return lines diff --git a/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py b/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py new file mode 100644 index 0000000000000000000000000000000000000000..67d9651a1e39fae326a190102fe8a23aee29fcd0 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py @@ -0,0 +1,1743 @@ +# mypy: allow-untyped-defs +import builtins +import contextlib +import functools +import inspect +import itertools +import json +import logging +import math +import operator +import os +import sys +import textwrap +import time +from collections import namedtuple +from concurrent.futures import as_completed, ThreadPoolExecutor +from io import StringIO +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from unittest.mock import patch + +import sympy +from filelock import FileLock + +import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch._dynamo.testing import rand_strided +from torch._dynamo.utils import counters, identity, preserve_rng_state + +from . import config, ir +from .autotune_process import TensorMeta, TritonBenchmarkRequest +from .codecache import code_hash, PersistentCache, PyCodeCache +from .codegen.common import IndentedBuffer, KernelTemplate +from .codegen.triton import ( + gen_common_triton_imports, + texpr, + TritonKernel, + TritonPrinter, + TritonScheduling, +) +from .codegen.triton_utils import config_of, signature_to_meta +from .exc import CUDACompileError +from .ir import ChoiceCaller, PrimitiveInfoType +from .runtime.benchmarking import benchmarker +from .runtime.hints import DeviceProperties +from .utils import ( + FakeIndentedBuffer, + get_dtype_size, + Placeholder, + restore_stdout_stderr, + sympy_dot, + sympy_index_symbol, + sympy_product, + unique, +) +from .virtualized import V + + +log = logging.getLogger(__name__) + +# correctness checks struggle with fp16/tf32 +VERIFY: Dict[str, Any] = {} +PRINT_AUTOTUNE = True +DEBUG = False + + +class KernelNamespace: + pass + + +# these objects are imported from the generated wrapper code +extern_kernels = KernelNamespace() + + +class PartialRender: + """ + Some parts of a template need to be generated at the end, but + inserted into the template at the start. This allows doing a bunch + of replacements after the initial render. 
+ """ + + def __init__(self, code, replacement_hooks) -> None: + super().__init__() + self.code = code + self.replacement_hooks = replacement_hooks + + def finalize_hook(self, hook_key: str, strict=True) -> None: + if hook_key not in self.replacement_hooks: + if strict: + raise RuntimeError( + f"{hook_key} not registered in self.replacement_hooks" + ) + else: + return + assert ( + self.replacement_hooks[hook_key] is not None + ), "hook_key can only be called once" + self.code = self.code.replace(hook_key, self.replacement_hooks[hook_key]()) + self.replacement_hooks[hook_key] = None + + def finalize_all(self) -> str: + for key, fn in self.replacement_hooks.items(): + self.code = self.code.replace(key, fn()) + return self.code + + +# This is used to store info needed for lowering each subgraph in triton +# templates +SubgraphInfo = namedtuple( + "SubgraphInfo", + [ + "body", + "template_mask", + "template_out", + ], +) + + +class TritonTemplateKernel(TritonKernel): + def __init__( + self, + kernel_name, + input_nodes, + output_node, + defines, + num_stages, + num_warps, + grid_fn, + meta, + call_sizes, + use_jit=False, + prefix_args=0, + suffix_args=0, + epilogue_fn=identity, + subgraphs: Optional[List[ir.ComputedBuffer]] = None, + *, + index_dtype, + ) -> None: + super().__init__( + sympy_product(output_node.get_size()), + sympy.Integer(1), + index_dtype=index_dtype, + ) + self.input_nodes = input_nodes + self.output_node = output_node + self.named_input_nodes = {} # type: ignore[var-annotated] + self.defines = defines + self.kernel_name = kernel_name + self.use_jit = use_jit + self.num_stages = num_stages + self.num_warps = num_warps + self.grid_fn = grid_fn + self.meta = meta + self.call_sizes = call_sizes + # for templates with fixed epilogues + self.prefix_args = prefix_args + self.suffix_args = suffix_args + self.epilogue_fn = epilogue_fn + self.render_hooks = {} # type: ignore[var-annotated] + self.triton_meta: Optional[Dict[str, object]] = None + # For Templated Attention this can be a list of ir.Subgraph + self.subgraphs: Optional[List[ir.ComputedBuffer]] = subgraphs + + # The following attributes (body, template_mask, output_val) are all + # used for triton kernel codegen. + # They are swapped onto the TritonTemplateKernel object by + # `set_subgraph_body` + self.subgraph_bodies: Dict[str, SubgraphInfo] = {} + + self.body: IndentedBuffer = FakeIndentedBuffer() + self.template_mask: Optional[str] = None + self.template_out: Optional[str] = None + + @contextlib.contextmanager + def set_subgraph_body(self, body_name: str): + old_body, old_mask, old_out = self.body, self.template_mask, self.template_out + assert body_name in self.subgraph_bodies, body_name + self.body, self.template_mask, self.template_out = self.subgraph_bodies[ + body_name + ] + yield + self.subgraph_bodies[body_name] = SubgraphInfo( + self.body, self.template_mask, self.template_out + ) + self.body, self.template_mask, self.template_out = old_body, old_mask, old_out + + @contextlib.contextmanager + def create_subgraph_body(self, body_name: str): + assert body_name not in self.subgraph_bodies + self.subgraph_bodies[body_name] = SubgraphInfo(IndentedBuffer(), None, None) + with self.set_subgraph_body(body_name): + yield + + def need_numel_args(self): + return False + + def estimate_kernel_num_bytes(self): + """ + Estimate the total number of bytes this kernel takes. + For in/out nodes, sizes are counted twice: once for reading and + once for writing. 
+ """ + ninplace_args = len(unique(self.args.inplace_buffers.values())) + num_bytes = [] + for i, inp in enumerate(itertools.chain(self.input_nodes, (self.output_node,))): + size = V.graph.sizevars.size_hints(inp.get_size()) + numel = functools.reduce(operator.mul, size, 1) + dtype_size = get_dtype_size(inp.get_dtype()) + num_bytes.append(numel * dtype_size * (1 + int(i < ninplace_args))) + return sum(num_bytes) + + def jit_lines(self): + if self.use_jit: + return "@triton.jit" + + argdefs, _, signature, _ = self.args.python_argdefs() + triton_meta = { + "signature": signature_to_meta(signature, size_dtype=self.index_dtype), + "device": DeviceProperties.create(self.output_node.get_device()), + "constants": {}, + } + triton_meta["configs"] = [config_of(signature)] + for arg_num in triton_meta["configs"][0].equal_to_1: # type: ignore[index] + triton_meta["constants"][arg_num] = 1 # type: ignore[index] + matrix_instr_nonkdim = self.meta.get("matrix_instr_nonkdim", 0) + if matrix_instr_nonkdim != 0: + triton_meta["matrix_instr_nonkdim"] = matrix_instr_nonkdim + + self.triton_meta = triton_meta + + inductor_meta = { + "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), + **TritonKernel.inductor_meta_common(), + } + if config.profile_bandwidth or config.benchmark_kernel: + num_gb = self.estimate_kernel_num_bytes() / 1e9 + inductor_meta["kernel_num_gb"] = num_gb + return f""" + @triton_heuristics.template( + num_stages={self.num_stages}, + num_warps={self.num_warps}, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r}, + ) + @triton.jit + """ + + def gen_argdefs(self): + def hook(): + # python_argdefs() cannot be run until after the rest of the template lazily adds more args + arg_defs, *_ = self.args.python_argdefs() + return f"{', '.join(arg_defs)}" + + self.render_hooks[""] = hook + return "" + + def gen_defines(self): + return self.defines + + def def_kernel(self, *argnames): + """ + Hook called from template code to generate function def and + needed args. + """ + assert all(isinstance(x, str) for x in argnames) + renames = IndentedBuffer(initial_indent=1) + + named_args = self.input_nodes[ + self.prefix_args : len(self.input_nodes) - self.suffix_args + ] + + assert len(argnames) == len(named_args), ( + len(argnames), + len(named_args), + self.prefix_args, + len(self.input_nodes), + ) + + for input_node in self.input_nodes[: self.prefix_args]: + # get args in correct order + self.args.input(input_node.get_name()) + + for name, input_node in zip(argnames, named_args): + arg_name = f"arg_{name}" + self.named_input_nodes[name] = input_node + self.args.input_buffers[input_node.get_name()] = arg_name + + # The args may be duplicated, so renaming must be after args are de-duplicated. 
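# ---------------------------------------------------------------------------
# Illustrative sketch (editorial, not part of the patched file): the byte
# estimate computed by estimate_kernel_num_bytes above; every tensor is read
# once, and in-place tensors (assumed to come first) are counted again for the
# write.  Shapes and dtype sizes below are hypothetical.
def _estimate_num_bytes_sketch(shapes_and_dtype_sizes, num_inplace_args=0):
    total = 0
    for i, (shape, dtype_size) in enumerate(shapes_and_dtype_sizes):
        numel = 1
        for dim in shape:
            numel *= dim
        total += numel * dtype_size * (1 + int(i < num_inplace_args))
    return total

# hypothetical fp16 matmul: two 1024x1024 inputs and one output, nothing in-place
assert _estimate_num_bytes_sketch([((1024, 1024), 2)] * 3) == 3 * 1024 * 1024 * 2
# --------------------------------------------------------------------- end sketch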
+ for name in argnames: + input_node = self.named_input_nodes[name] + arg_name = self.args.input_buffers[input_node.get_name()] + if input_node.get_layout().offset == 0: + renames.writeline(f"{name} = {arg_name}") + else: + offset = texpr(self.rename_indexing(input_node.get_layout().offset)) + renames.writeline(f"{name} = {arg_name} + {offset}") + + for input_node in self.input_nodes[len(self.input_nodes) - self.suffix_args :]: + # get args in correct order + self.args.input(input_node.get_name()) + + def hook(): + # python_argdefs() cannot be run until after the rest of the template lazily adds more args + arg_defs, *_ = self.args.python_argdefs() + code = IndentedBuffer() + code.splice(gen_common_triton_imports()) + code.splice(self.jit_lines()) + code.writeline(f"def {self.kernel_name}({', '.join(arg_defs)}):") + with code.indent(): + code.splice(self.defines) + code.splice(renames.getvalue()) + return code.getvalue() + + assert "" not in self.render_hooks + self.render_hooks[""] = hook + return "" + + def size(self, name: str, index: int): + """ + Hook called from template code to get the size of an arg. + Will add needed args to pass it in if it is dynamic. + """ + assert isinstance(index, int) + if name is None: + val = self.output_node.get_size()[index] + else: + assert isinstance(name, str) + val = self.named_input_nodes[name].get_size()[index] + return texpr(self.rename_indexing(val)) + + def stride(self, name, index=None): + """ + Hook called from template code to get the stride of an arg. + Will add needed args to pass it in if it is dynamic. + """ + if name is None: + val = self.output_node.get_stride() + else: + assert isinstance(name, str) + val = self.named_input_nodes[name].get_stride() + + if isinstance(index, int): + return texpr(self.rename_indexing(val[index])) + else: + return ", ".join([texpr(self.rename_indexing(i)) for i in val]) + + def modification( + self, subgraph_number: int, output_name: str, **fixed_inputs + ) -> str: + """This creates a modification function for a subgraph. 
+ To use this inside a template, the first argument should specify which subgraph to codegen for + + Args: + subgraph_number (int): The index of the subgraph in self.subgraphs + """ + num = 0 + while f"mod_{subgraph_number}_{num}" in self.subgraph_bodies: + num += 1 + with self.create_subgraph_body(f"mod_{subgraph_number}_{num}"): + assert isinstance(subgraph_number, int) + assert isinstance(self.subgraphs, list) + assert ( + self.body.getvalue() == "" + ), "Body should be clear before adding a modification" + assert subgraph_number < len( + self.subgraphs + ), f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}" + + subgraph = self.subgraphs[subgraph_number] + + def add_input(name): + return self.args.input(name) + + name = f"PlaceholderSubstitution_{subgraph_number}" + + class PlaceholderSubstitution(V.WrapperHandler): # type: ignore[name-defined] + self.name = name + + def load(self, name: str, index: sympy.Expr): + if name not in fixed_inputs: + # If it's not a fixed input, it's a load from a captured + # tensor + var = add_input(name) + return f"tl.load({var} + {index})" + + return f"({fixed_inputs[name]})" + + def indirect_indexing(self, index_var, size, check, wrap_neg=True): + return sympy_index_symbol(str(index_var)) + + with V.set_ops_handler(PlaceholderSubstitution(V.ops)): + assert isinstance( + subgraph, ir.ComputedBuffer + ), f"Expected the subgraph to be a ComputedBuffer, got {type(subgraph)}" + if isinstance(subgraph.data, ir.InputBuffer): + out = subgraph.data.make_loader()(()) + else: + out = subgraph.data.inner_fn(()) + + self.codegen_body() + self.body.writeline(f"{output_name} = {out.value}") + + body_val = self.body.getvalue() + self.cse.invalidate(set()) # type: ignore[arg-type] + return body_val + + def store_output( + self, + indices: Union[List[Any], Tuple[Any]], + val: str, + mask: Optional[str] = None, + indent_width: int = 4, + ): + """Stores the final output and appends any epilogue fusions if the buffer hasn't been optimized away. + + Args: + indices (Union[List, Tuple]): The index for each dimension of the output. The dot product of + these indices and output strides must match `val`. + val (str): The value to store. + mask (Optional[str]): An optional mask to use for the store operation. If provided, this mask + will be applied to the store. + indent_width (int): The number of spaces to use for indentation. This is used when the call to + store_output is indented in the kernel definition. 
+ """ + with self.create_subgraph_body(""): + assert isinstance(indices, (list, tuple)) + assert isinstance(val, str) + assert isinstance(mask, (str, type(None))) + assert self.template_mask is None + indices = list(map(TritonPrinter.paren, indices)) + index_symbols = [sympy.Symbol(x, integer=True) for x in indices] + lengths = [ + V.graph.sizevars.simplify(s) for s in self.output_node.get_size() + ] + assert len(indices) == len(lengths) + + # glue to make generated code use same indexing from template + for name, range_tree_entry in zip( + indices, self.range_trees[0].construct_entries(lengths) + ): + range_tree_entry.set_name(name) + contiguous_index = sympy_dot( + ir.FlexibleLayout.contiguous_strides(lengths), index_symbols + ) + contiguous_index = self.rename_indexing(contiguous_index) + self.body.writeline("xindex = " + texpr(contiguous_index)) + self.range_trees[0].lookup( + sympy.Integer(1), sympy_product(lengths) + ).set_name("xindex") + self.template_mask = mask + self.template_out = val + self.template_indices = indices + output_index = self.output_node.get_layout().make_indexer()(index_symbols) + output_index = self.rename_indexing(output_index) + if output_index == contiguous_index: + output_index = sympy.Symbol("xindex", integer=True) + + epilogue_args = [val] + for input_node in itertools.chain( + self.input_nodes[: self.prefix_args], + self.input_nodes[len(self.input_nodes) - self.suffix_args :], + ): + input_node.freeze_layout() + epilogue_args.append(input_node.make_loader()(index_symbols)) + + V.ops.store( + self.output_node.get_name(), + output_index, + self.epilogue_fn(*epilogue_args), + ) + self.codegen_body() + + def hook(): + # more stuff might have been added since the codegen_body above + self.codegen_body() + + return textwrap.indent(self.body.getvalue(), " " * indent_width).strip() + + assert "" not in self.render_hooks + self.render_hooks[""] = hook + return "" + + def render(self, template, kwargs): + return PartialRender( + template.render(**self.template_env(), **kwargs), + self.render_hooks, + ) + + def make_load(self, name, indices, mask): + """ + Optional helper called from template code to generate the code + needed to load from an tensor. + """ + assert isinstance(indices, (list, tuple)) + assert isinstance(name, str) + assert isinstance(mask, str) + stride = self.named_input_nodes[name].get_stride() + indices = list(map(TritonPrinter.paren, indices)) + assert len(indices) == len(stride) + index = " + ".join( + f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices) + ) + return f"tl.load({name} + ({index}), {mask}, other=0.0)" + + def template_env(self): + """ + Generate the namespace visible in the template. + """ + return { + fn.__name__: fn + for fn in [ + self.def_kernel, + self.size, + self.stride, + self.store_output, + self.make_load, + self.modification, + self.gen_argdefs, + self.gen_defines, + ] + } + + def indexing( + self, + index: sympy.Expr, + *, + dense_indexing=False, + copy_shape=None, + override_mask=None, + block_ptr=False, + ): + """ + Override the default indexing to use our custom mask and force + dense indexing. 
+ """ + return super().indexing( + index, + dense_indexing=False, + # We pass template_out as the shape to broadcast the indexing to as + # the mask might be broadcast to the output shape + copy_shape=self.template_out, + override_mask=self.template_mask, + block_ptr=block_ptr, + ) + + def codegen_range_tree(self): + pass # ignore default codegen + + def call_kernel(self, name: str, node: Optional[ir.IRNode] = None): + wrapper = V.graph.wrapper_code + _, call_args, _, arg_types = self.args.python_argdefs() + if V.graph.cpp_wrapper: + # In the cpp_wrapper case, we have to compute CUDA launch grid at runtime + # if any dynamic dimension is involved. We rely on the Python version + # of the grid function to generate those grid configs, which may contain + # symbolic values. The wrapper will use cexpr to print out C++ code + # appropriately for the grid configs. + grid = self.call_sizes + [self.meta] + wrapper.generate_kernel_call( + name, + call_args, + grid=self.grid_fn(*grid), + arg_types=arg_types, + triton_meta=self.triton_meta, + ) + else: + wrapper.add_import_once(f"import {self.grid_fn.__module__}") + meta = wrapper.add_meta_once(self.meta) + grid = self.call_sizes + [meta] + wrapper.generate_kernel_call( + name, + call_args, + grid=grid, + grid_fn=f"{self.grid_fn.__module__}.{self.grid_fn.__name__}", + arg_types=arg_types, + triton_meta=self.triton_meta, + ) + + +@functools.lru_cache(None) +def _jinja2_env(): + try: + import jinja2 + + return jinja2.Environment( + undefined=jinja2.StrictUndefined, + ) + except ImportError: + return None + + +class TritonTemplate(KernelTemplate): + index_counter = itertools.count() + all_templates: Dict[str, "TritonTemplate"] = {} + + def __init__(self, name: str, grid: Any, source: str, debug=False) -> None: + super().__init__(name) + self.grid = grid + self.template = self._template_from_string(source) + assert name not in self.all_templates, "duplicate template name" + self.all_templates[name] = self + self.debug = debug + + def generate( # type: ignore[override] + self, + input_nodes, + layout, + num_stages, + num_warps, + prefix_args=0, + suffix_args=0, + epilogue_fn=identity, + subgraphs=None, + mutated_inputs=None, + call_sizes=None, + **kwargs, + ): + """This function generates a TritonTemplateCaller + + Args: + input_nodes: List of input nodes + layout: Output layout + num_stages: Number of stages for triton launch + num_warps: Number of warps for triton launch + prefix_args: Number of input nodes to be passed as arguments + suffix_args: Number of input nodes to be passed as arguments + epilogue_fn: Optional epilogue function to be called on the output + subgraphs: Optional subgraphs to be passed as arguments, these will be inlined + into the triton template string + mutated_inputs: Optional list of input nodes that are mutated by the kernel, this is helpful + if you need to return multiple outputs. You can pass them as inputs and mark them as + being mutated by the kernel. 
+ """ + assert self.template, "requires jinja2" + defines = StringIO() + for name, val in kwargs.items(): + defines.write(f"{name} : tl.constexpr = {val}\n") + defines = defines.getvalue() + + fake_out = ir.Buffer("buf_out", layout) + kernel_name = f"triton_{self.name}" + + numel = sympy_product(layout.size) + buffers = itertools.chain(input_nodes, (fake_out,)) + if not TritonScheduling.can_use_32bit_indexing(numel, buffers): + raise NotImplementedError( + "64-bit indexing is not yet implemented for triton templates" + ) + + if call_sizes is None: + call_sizes = layout.size + + kernel_options = dict( + input_nodes=input_nodes, + defines=defines, + num_stages=num_stages, + num_warps=num_warps, + grid_fn=self.grid, + meta=kwargs, + call_sizes=call_sizes, + prefix_args=prefix_args, + suffix_args=suffix_args, + epilogue_fn=epilogue_fn, + index_dtype="tl.int32", + subgraphs=subgraphs, + ) + + with patch.object( + V.graph, "get_dtype", self._fake_get_dtype(fake_out) + ), TritonTemplateKernel( + kernel_name=kernel_name, + output_node=fake_out, + use_jit=False, + **kernel_options, + ) as kernel: + try: + template = kernel.render(self.template, kwargs) + with kernel.set_subgraph_body(""): + code = template.finalize_all() + except ZeroDivisionError: + # TODO(nmacchioni): fix sympy division by zero + return None + if self.debug: + print("Generated Code:\n", code) + extra = ( + "-".join( + [ + *[ + f"{kwarg}={repr(kwargs[kwarg])}" + for kwarg in sorted(kwargs.keys()) + ], + f"num_stages={num_stages}", + f"num_warps={num_warps}", + ] + ) + + "-" + ) + mod = PyCodeCache.load(code, extra) + + input_call_args = tuple(kernel.args.input_buffers.keys()) + output_call_args = tuple(kernel.args.output_buffers.keys()) + + # We expect the input_buffer order to be [*input_nodes, *captured_buffers] + expected_input_args = tuple(unique(x.get_name() for x in input_nodes)) + expected_output_args = (fake_out.get_name(),) + assert input_call_args[: len(expected_input_args)] == expected_input_args, ( + input_call_args, + expected_input_args, + ) + assert output_call_args == expected_output_args, ( + output_call_args, + expected_output_args, + ) + + full_input_nodes = tuple([V.graph.get_buffer(k) for k in input_call_args]) + extra_args = V.graph.sizevars.size_hints( + map(sympy.expand, tuple(kernel.args.sizevars.keys())), + fallback=config.unbacked_symint_fallback, + ) + + kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}" + + def make_kernel_render(out_node): + kernel = TritonTemplateKernel( + kernel_name=str(Placeholder.KERNEL_NAME), + output_node=out_node, + use_jit=False, + **kernel_options, + ) + render = functools.partial( + kernel.render, + self.template, + kwargs, + ) + return kernel, render + + # create the BenchmarkRequest + assert mod.__file__ is not None + grid = self.grid( + *V.graph.sizevars.size_hints( + call_sizes, + fallback=config.unbacked_symint_fallback, + ), + kwargs, + ) + bmreq = TritonBenchmarkRequest( + module_path=mod.__file__, + module_cache_key=mod.key, + kernel_name=kernel_name, + grid=grid, + extra_args=extra_args, + num_stages=num_stages, + num_warps=num_warps, + matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0), + input_tensor_meta=TensorMeta.from_irnodes(full_input_nodes), # type: ignore[arg-type] + output_tensor_meta=TensorMeta.from_irnodes(layout), + ) + + return TritonTemplateCaller( + kernel_hash_name, + full_input_nodes, + layout, + make_kernel_render, + extra.strip("-").replace("-", ", "), + bmreq, + log_info={ + "tile_shape": str( + ( + 
kwargs.get("BLOCK_M", -1), + kwargs.get("BLOCK_K", -1), + kwargs.get("BLOCK_N", -1), + ) + ), + "num_stages": num_stages, + "num_warps": num_warps, + "allow_tf32": str(kwargs.get("ALLOW_TF32", None)), + "acc_type": str(kwargs.get("ACC_TYPE", None)), + }, + mutated_inputs=mutated_inputs, + ) + + +class ExternKernelChoice: + def __init__( + self, + kernel, + cpp_kernel=None, + *, + name=None, + has_out_variant=True, + op_overload=None, + use_fallback_kernel=False, + kernel_creator=None, + ) -> None: + super().__init__() + name = name or kernel.__name__ + assert callable(kernel) + assert not hasattr(extern_kernels, name), f"duplicate extern kernel: {name}" + self.name = name + self.cpp_kernel_name = cpp_kernel + self.has_out_variant = has_out_variant + setattr(extern_kernels, name, kernel) + self.op_overload = op_overload + self.use_fallback_kernel = use_fallback_kernel + self.kernel_creator = kernel_creator + + def to_callable(self): + return getattr(extern_kernels, self.name) + + def call_name(self): + return f"extern_kernels.{self.name}" + + @functools.lru_cache(None) # noqa: B019 + def hash_key(self): + fn = self.to_callable() + parts = [ + self.name, + getattr(fn, "__name__", ""), + getattr(fn, "__module__", ""), + ] + try: + parts.append(inspect.getsource(fn)) + except Exception: + pass + return code_hash("-".join(parts)) + + def bind( + self, + input_nodes, + layout, + ordered_kwargs_for_cpp_kernel=(), + **kwargs, + ): + self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel + return ExternKernelCaller( + self, input_nodes, layout, kwargs, has_out_variant=self.has_out_variant + ) + + +class TritonTemplateCaller(ir.TritonTemplateCallerBase): + def __init__( + self, + name, + input_nodes, + layout, + make_kernel_render, + debug_extra, + bmreq, + log_info: Optional[ + Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]] + ] = None, + mutated_inputs=None, + ) -> None: + super().__init__(name, input_nodes, layout) + self.make_kernel_render = make_kernel_render + self.debug_extra = debug_extra + self.bmreq: TritonBenchmarkRequest = bmreq + if log_info is None: + log_info = {} + self.log_info: Dict[str, Any] = log_info + self.log_info.update( + { + "backend": "Triton", + "grid": str(self.bmreq.grid), + "num_stages": self.bmreq.num_stages, + "num_warps": self.bmreq.num_warps, + } + ) + self.mutated_inputs = mutated_inputs + + def benchmark(self, *args, out): + assert self.bmreq is not None + return self.bmreq.benchmark(*args, output_tensor=out) + + def precompile(self): + assert self.bmreq is not None + self.bmreq.precompile() + + def __str__(self) -> str: + return f"TritonTemplateCaller({self.bmreq.module_path}, {self.debug_extra})" + + def call_name(self): + return f"template_kernels.{self.name}" + + def hash_key(self): + return "-".join( + [ + self.name.rsplit("_", 1)[0], + self.bmreq.module_cache_key, + ] + ) + + def output_node(self): + return ir.TensorBox.create( + ir.TritonTemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + debug_extra=self.debug_extra, + mutated_inputs=self.mutated_inputs, + ) + ) + + def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return self.log_info + + def get_make_kernel_render(self): + return self.make_kernel_render + + def autoheuristic_id(self): + type_name = "triton" + info = self.info_dict() + # TODO(AlnisM): Does tile_shape always exist? 
+ tile = info["tile_shape"] + tile_vals = eval(tile) # type: ignore[arg-type] + BLOCK_M = tile_vals[0] + BLOCK_K = tile_vals[1] + BLOCK_N = tile_vals[2] + num_stages = info["num_stages"] + num_warps = info["num_warps"] + return f"type={type_name}_BLOCK-M={BLOCK_M}_BLOCK-K={BLOCK_K}_BLOCK-N={BLOCK_N}_numstages={num_stages}_numwarps={num_warps}" + + +class ExternKernelCaller(ChoiceCaller): + def __init__( + self, + choice: ExternKernelChoice, + input_nodes, + layout, + kwargs=None, + *, + has_out_variant=True, + ) -> None: + super().__init__(choice.name, input_nodes, layout) + self.choice = choice + self.kwargs = kwargs or {} + self.has_out_variant = has_out_variant + + def __str__(self) -> str: + return f"ExternKernelCaller({self.choice.call_name()})" + + def benchmark(self, *args, out): + if out.numel() == 0: + # no need to run the kerrnel of do benchmarking + return 0.0 + if self.has_out_variant: + return super().benchmark(*args, out=out) + else: + algo = self.to_callable() + out_new = algo(*args) + torch._C._dynamo.guards.assert_size_stride( + out_new, tuple(out.size()), tuple(out.stride()) + ) + out.copy_(out_new) # for correctness checking + return benchmarker.benchmark(algo, args, {}) + + def to_callable(self): + fn = self.choice.to_callable() + if self.kwargs: + return functools.partial(fn, **self.kwargs) + else: + return fn + + def hash_key(self): + return "-".join( + [ + self.choice.name, + *[ + f"{kwarg}={repr(self.kwargs[kwarg])}" + for kwarg in sorted(self.kwargs.keys()) + ], + self.choice.hash_key(), + ] + ) + + def output_node(self): + if config.abi_compatible and self.choice.use_fallback_kernel: + assert ( + self.choice.op_overload is not None + ), "Please provide an op_overload to use ir.FallbackKernel" + inner = ir.FallbackKernel.create( + self.choice.op_overload, *self.input_nodes, **self.kwargs + ) + elif self.choice.kernel_creator is not None: + inner = self.choice.kernel_creator(*self.input_nodes, **self.kwargs) + else: + cls = ir.ExternKernelOut if self.has_out_variant else ir.ExternKernelAlloc + inner = cls( + layout=self.layout, + inputs=self.input_nodes, + python_kernel_name=self.choice.call_name(), + cpp_kernel_name=self.choice.cpp_kernel_name, + ordered_kwargs_for_cpp_kernel=self.choice.ordered_kwargs_for_cpp_kernel, + op_overload=self.choice.op_overload, + kwargs=self.kwargs, + ) + + return ir.TensorBox.create(inner) + + def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return { + "backend": "extern", + "kernel_call_name": self.choice.call_name(), + } + + def autoheuristic_id(self): + return f"extern_{self.choice.name}" + + +@functools.lru_cache(None) +def get_mm_log_filename() -> Optional[str]: + mm_file_name = os.environ.get("TORCHINDUCTOR_MM_LOGGING_FILE", None) + if not mm_file_name: + return None + + if "json" not in mm_file_name: + mm_file_name = f"{mm_file_name}.json" + + return mm_file_name + + +def append_to_log(filename, data): + lock_file = filename.replace(".json", ".lock") + lock = FileLock(lock_file) + with lock: + try: + with open(filename) as f: + log_data = json.load(f) + except (FileNotFoundError, json.JSONDecodeError): + log_data = [] + + log_data.append(data) + + with open(filename, "w") as f: + json.dump(log_data, f, indent=4) + + +class DataProcessorChoiceCallerWrapper: + def __init__(self, wrapped, preprocessor, postprocessor) -> None: + self._wrapped = wrapped + if preprocessor is not None: + self._preprocessor = 
preprocessor + else: + self._preprocessor = lambda x, y: (x, y) + if postprocessor is not None: + self._postprocessor = postprocessor + else: + self._postprocessor = lambda x: x + + def __getattr__(self, name): + return getattr(self._wrapped, name) + + def benchmark(self, *args, out) -> float: + new_args, new_out = self._preprocessor(args, out) + result = self._wrapped.benchmark(*new_args, out=new_out) + new_out = self._postprocessor(new_out) + if out is not new_out: + out.copy_(new_out) + return result + + def output_node(self) -> ir.TensorBox: + result = self._wrapped.output_node() + return self._postprocessor(result) + + def __repr__(self) -> str: + return f"DataProcessorChoiceCallerWrapper({self._wrapped})" + + +class DataProcessorTemplateWrapper: + """ + A wrapper class for a kernel template. + + This class together with `DataProcessorChoiceCallerWrapper` provides a convenient way to + preprocess and postprocess data before and after using the wrapped template. A typical + usage is to reorder or filter the input nodes in order to match the expected input of other + kernel choices like a ATen kernel. A more complicated usage is to prepack the weights. + See the example from :mod:`cpp_gemm_template` for more details. + """ + + def __init__( + self, + wrapped_template_cls, + preprocessor, + postprocessor, + **kwargs, + ) -> None: + if preprocessor is not None: + self._preprocessor = preprocessor + else: + self._preprocessor = lambda x, y: (x, y) + if postprocessor is not None: + self._postprocessor = postprocessor + else: + self._postprocessor = lambda x: x + assert "input_nodes" in kwargs + assert "layout" in kwargs + kwargs["input_nodes"], kwargs["layout"] = preprocessor( + kwargs["input_nodes"], kwargs["layout"] + ) + self._wrapped = wrapped_template_cls(**kwargs) + + def __getattr__(self, name): + return getattr(self._wrapped, name) + + def maybe_append_choice(self, choices, **kwargs): + return type(self._wrapped).maybe_append_choice(self, choices, **kwargs) + + def generate(self, **kwargs): + choice_caller = self._wrapped.generate(**kwargs) + return DataProcessorChoiceCallerWrapper( + choice_caller, self._preprocessor, self._postprocessor + ) + + def __repr__(self) -> str: + return f"DataProcessorTemplateWrapper({self._wrapped})" + + +class ErrorFromChoice(RuntimeError): + def __init__(self, msg, choice: ChoiceCaller, inputs_str) -> None: + msg += f"\nFrom choice {choice}\n{inputs_str}" + super().__init__(msg) + self.choice = choice + + +class NoValidChoicesError(RuntimeError): + pass + + +@functools.lru_cache(None) +def get_env_num_workers() -> Optional[int]: + if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ: + return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"]) + return None + + +def create_inputs_key(input_nodes) -> str: + return repr([AlgorithmSelectorCache.key_of(x) for x in input_nodes]) + + +def create_precompile_key( + name: str, inputs_key: str, choices: List[ChoiceCaller] +) -> str: + return ":".join( + [ + name, + inputs_key, + torch.get_float32_matmul_precision(), + ] + + [choice.hash_key() for choice in choices] + ) + + +class AlgorithmSelectorCache(PersistentCache): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + # the autotuning will get occur in the scheduler, so there is + # no guarantee that the first lowering for a given key will also be the + # first to benchmark it. 
share a single precompilation function for all lowerings + # of a particular key + self.precompile_cache: Dict[str, Callable[[], None]] = {} + # list of callbacks that are called after benchmarking + self.feedback_saver_fns: List[ + Callable[ + [Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None + ] + ] = [] + + def __call__( + self, + name, + choices: List[ChoiceCaller], + input_nodes, + layout, + # optional dict mapping arg indices to the functions + # generating a torch.Tensor for that input from the + # corresponding ir.Buffer. if passed for a given + # arg, the function will be called instead of + # generating a random torch.Tensor for benchmarking. + input_gen_fns: Optional[Dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None, + precompilation_timeout_seconds: int = 60 * 60, + return_multi_template=False, + ): + from .codegen.cuda.cuda_kernel import CUDATemplateCaller + + # Templates selected with input_gen_fns require specific input data to avoid IMA + # Passing custom input gen fns to benchmark_fusion NYI, so skip deferred template selection + # TODO(jgong5): support multi-template on CPU + if input_gen_fns is not None or layout.device.type == "cpu": + return_multi_template = False + + # TODO - assert that we have not mutating kernels here + + # TODO(nmacchioni): remove once CI tests are fixed + choices = [choice for choice in choices if choice is not None] + + if mm_file_name := get_mm_log_filename(): + M, K = input_nodes[-2].get_size()[:2] + N = input_nodes[-1].get_size()[-1] + append_to_log(mm_file_name, {"invoke": str((M, K, N))}) + + if len(choices) == 0: + backend_config = ( + "max_autotune_gemm_backends" + if name != "convolution" + else "max_autotune_conv_backends" + ) + raise NoValidChoicesError( + f"No choices to select, please consider adding ATEN into {backend_config} " + "config (defined in torch/_inductor/config.py) to allow at least one choice. " + ) + log.debug("Max autotune selects from %s choices.", str(len(choices))) + + if len(choices) == 1: + if not isinstance(choices[0], CUDATemplateCaller): + # CUDATemplateCaller still needs to go through autotuning process to retrieve workspace size. 
+ return choices[0].output_node() + + @functools.lru_cache(None) + def make_benchmark_fn(): + return self.make_benchmark_fn(choices, input_nodes, layout, input_gen_fns) + + inputs_key = create_inputs_key(input_nodes) + + def precompile(choices) -> Callable[[], None]: + def no_op(*args, **kwargs): + return + + if ( + precompilation_timeout_seconds is None + or precompilation_timeout_seconds <= 0 + ): + return no_op + + env_workers = get_env_num_workers() + num_workers = env_workers if env_workers is not None else (len(choices)) + + if num_workers <= 0: + return no_op + + # https://github.com/python/cpython/issues/106905 + if ( + sys.version_info.major == 3 + and sys.version_info.minor == 11 + and sys.version_info.micro <= 8 + ): + return no_op + + # check local and global cache before precompiling + timings = self.lookup( + choices, + name, + inputs_key, + benchmark=None, + ) + + if timings: + return no_op + + precompile_key = create_precompile_key(name, inputs_key, choices) + if precompile_func := self.precompile_cache.get(precompile_key): + return precompile_func + + log.info( + "Multithreaded precompilation for %d choices using %d worker threads", + len(choices), + num_workers, + ) + + # In rare circumstances, because python threads inherit global state, + # thread pool executor can race and leave stdout/stderr in a state + # different than the original values. we explicitly restore the state + # here to avoid this issue. + + initial_stdout = sys.stdout + initial_stderr = sys.stderr + + def precompile_with_captured_stdout(choice): + with restore_stdout_stderr(initial_stdout, initial_stderr): + return choice.precompile() + + executor = ThreadPoolExecutor(max_workers=num_workers) + + futures = {} + for c in choices: + if hasattr(c, "precompile"): + future = executor.submit(precompile_with_captured_stdout, c) + futures[future] = c + + @functools.lru_cache(None) + @restore_stdout_stderr(initial_stdout, initial_stderr) + def wait_on_futures(): + counters["inductor"]["select_algorithm_precompile"] += 1 + for future in as_completed( + futures, + timeout=precompilation_timeout_seconds, + ): + if e := future.exception(): + log.error( + "Exception %s for benchmark choice %s", e, futures[future] + ) + + executor.shutdown(wait=True) + + self.precompile_cache[precompile_key] = wait_on_futures + + return wait_on_futures + + def autotune(choices): + return make_benchmark_fn()(choices) + + if config.autotune_in_subproc: + from .autotune_process import tuning_pool + + # do the optional warmup + tuning_pool.initialize() + + def do_autotuning(precompile_fn): + precompile_start_ts = time.time() + precompile_fn() + precompile_elapse = time.time() - precompile_start_ts + + autotune_start_ts = time.time() + timings = self.lookup( + choices, + name, + inputs_key, + autotune, + ) + autotune_elapse = time.time() - autotune_start_ts + + if timings and all( + not math.isfinite(timing) for timing in timings.values() + ): + raise NoValidChoicesError + + if make_benchmark_fn.cache_info().currsize: + counters["inductor"]["select_algorithm_autotune"] += 1 + + if ( + make_benchmark_fn.cache_info().currsize + or log.getEffectiveLevel() == logging.DEBUG + or config.trace.log_autotuning_results + ): + self.log_results( + name, input_nodes, timings, autotune_elapse, precompile_elapse + ) + + for feedback_fn in self.feedback_saver_fns: + feedback_fn(timings, name, input_nodes, choices) + + return timings + + precompile_fn = precompile(choices) + + if return_multi_template and (config.max_autotune or config.max_autotune_gemm): 
+ + def get_timings(): + timings = do_autotuning(precompile_fn) + min_extern_choice = float("inf") + for choice, timing in timings.items(): + if isinstance(choice, ExternKernelCaller): + min_extern_choice = min(min_extern_choice, timing) + + timings = { + choice: time + for choice, time in timings.items() + if ( + time <= min_extern_choice + or not isinstance(choice, ExternKernelCaller) + ) + } + + return timings + + return torch._inductor.ir.TensorBox.create( + torch._inductor.ir.MultiTemplateBuffer( + layout, + input_nodes, + get_timings, + ) + ) + + # TODO - dont want to precompile if we have a cache hit + timings = do_autotuning(precompile_fn) + if timings == {} or choices[0] not in timings: + return choices[0].output_node() + + selected_key = builtins.min(timings, key=timings.__getitem__) + selected_time = timings[selected_key] + selected_choice = selected_key.output_node() + log.debug("selected choice: %s", str(selected_choice)) + return selected_choice + + @classmethod + def make_benchmark_fn( + cls, + choices, + input_nodes, + layout, + input_gen_fns=None, + ): + if input_gen_fns is None: + input_gen_fns = {} + + def get_inputs(): + # de-duplicate args + unique_example_inputs = { + x.get_name(): input_gen_fns.get(i, cls.benchmark_example_value)(x) + for i, x in enumerate(input_nodes) + } + example_inputs = list(unique_example_inputs.values()) + example_inputs_extern = [ + unique_example_inputs[input_node.get_name()] + if unique_example_inputs[input_node.get_name()].is_mkldnn + else torch.as_strided( + unique_example_inputs[input_node.get_name()], + V.graph.sizevars.size_hints( + input_node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + V.graph.sizevars.size_hints( + input_node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + V.graph.sizevars.size_hint( + input_node.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ), + ) + for input_node in input_nodes + ] + + out = cls.benchmark_example_value(layout) + out_extern = torch.as_strided( + out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset) + ) + expected = None + if VERIFY: + choices[0].benchmark(*example_inputs_extern, out=out_extern) + expected = out_extern.clone() + + return example_inputs, example_inputs_extern, out, out_extern, expected + + if DEBUG: + print(f"{len(choices)} tuning requests:") + + def debug_str(example_inputs, out): + def tensor_repr(x): + return ( + f"torch.empty_strided({tuple(x.size())!r}, {tuple(x.stride())!r}, " + f"dtype={x.dtype!r}, device={x.device.type!r})" + ) + + lines = [ + "inputs = [", + ] + for x in example_inputs: + lines.append(f" {tensor_repr(x)},") + lines += ["]", f"out = {tensor_repr(out)}", ""] + return "\n".join(lines) + + def benchmark_choice_in_current_process( + choice, example_inputs, example_inputs_extern, out, out_extern, expected + ): + out.zero_() + if isinstance(choice, ExternKernelCaller): + # aten kernels want the offset baked in for sliced tensors + result = choice.benchmark(*example_inputs_extern, out=out_extern) + else: + # triton templates want the base pointer for sliced tensors + result = choice.benchmark(*example_inputs, out=out) + if VERIFY and expected is not None: + torch.testing.assert_close(out_extern, expected, **VERIFY) + if torch.cuda.is_available(): + torch.cuda.synchronize() # shake out any CUDA errors + return result + + def benchmark_in_current_process(choices): + inputs = get_inputs() + example_inputs, _, out, _, _ = inputs + timings = {} + for choice in choices: + try: + timing = 
benchmark_choice_in_current_process(choice, *inputs) + except CUDACompileError as e: + log.error( + "CUDA compilation error during autotuning: \n%s. \nIgnoring this choice.", + str(e), + ) + timing = float("inf") + except NotImplementedError as e: + log.warning("Not yet implemented: %s", e) + timing = float("inf") + except RuntimeError as e: + msg = str(e) + if "invalid argument" in msg: + msg += "\n\nThis may mean this GPU is too small for max_autotune mode.\n\n" + else: + if "illegal memory access" in msg: + msg += "\n\nEither error in template or triton bug.\n" + log.error( + "Runtime error during autotuning: \n%s. \nIgnoring this choice.", + msg, + ) + timing = float("inf") + except AssertionError as e: + raise AssertionError( # noqa: B904 + f"Incorrect result from choice {choice}\n\n{e}" + ) + except Exception as e: + try: + from triton.runtime.autotuner import OutOfResources + + if isinstance(e, OutOfResources): + log.warning(e) + timing = float("inf") + else: + raise e + except ImportError: + raise e from None + + timings[choice] = timing + + return timings + + def benchmark_in_sub_process(choices): + from . import autotune_process + + # only benchmark triton kernel in sub process for now. + # ATen/Extern kernel are still benchmarked in the current process. + extern = [c for c in choices if isinstance(c, ExternKernelCaller)] + triton = [c for c in choices if not isinstance(c, ExternKernelCaller)] + + timings = benchmark_in_current_process(extern) + timings.update(autotune_process.benchmark_in_sub_process(triton)) + return timings + + benchmark = ( + benchmark_in_sub_process + if config.autotune_in_subproc + else benchmark_in_current_process + ) + + return benchmark + + @staticmethod + def log_results( + name: str, + input_nodes: List[ir.IRNode], + timings: Dict[ChoiceCaller, float], + elapse: float, + precompile_elapse: float, + ): + V.debug.log_autotuning_results( + name, input_nodes, timings, elapse, precompile_elapse + ) + if not (config.max_autotune or config.max_autotune_gemm) or not PRINT_AUTOTUNE: + return + sizes = ", ".join( + [ + "x".join( + map( + str, + V.graph.sizevars.size_hints( + n.get_size(), fallback=config.unbacked_symint_fallback + ), + ) + ) + for n in input_nodes + ] + ) + + n = None if log.getEffectiveLevel() == logging.DEBUG else 10 + top_k = sorted(timings, key=timings.__getitem__)[:n] + best = top_k[0] + + def get_choice_info(choice): + if isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller): + return {"type": "cublas", "time": timings[choice]} + + assert isinstance( + choice, torch._inductor.select_algorithm.TritonTemplateCaller + ) + + info = choice.info_dict() + tile = info["tile_shape"] + + tile_vals = eval(tile) # type: ignore[arg-type] + BLOCK_M = tile_vals[0] + BLOCK_K = tile_vals[1] + BLOCK_N = tile_vals[2] + + return { + "type": "triton", + "time": timings[choice], + "BLOCK_M": BLOCK_M, + "BLOCK_K": BLOCK_K, + "BLOCK_N": BLOCK_N, + "num_stages": info["num_stages"], + "num_warps": info["num_warps"], + } + + mm_filename = get_mm_log_filename() + if mm_filename and "mm" in name: + M, K = input_nodes[-2].get_size()[:2] + N = input_nodes[-1].get_size()[-1] + + out_dict = { + str((M, K, N)): [get_choice_info(choice) for choice in timings.keys()] + } + + append_to_log(mm_filename, out_dict) + + best_time = timings[best] + sys.stderr.write(f"AUTOTUNE {name}({sizes})\n") + for choice in top_k: + result = timings[choice] + if result: + kernel_info = ( + choice.debug_extra if hasattr(choice, "debug_extra") else "" + ) + sys.stderr.write( + 
f" {choice.name} {result:.4f} ms {best_time / result:.1%} {kernel_info}\n" + ) + else: + sys.stderr.write( + f" {choice.name} {result:.4f} ms \n" + ) + + autotune_type_str = ( + "SubProcess" if config.autotune_in_subproc else "SingleProcess" + ) + sys.stderr.write( + f"{autotune_type_str} AUTOTUNE benchmarking takes {elapse:.4f} seconds and {precompile_elapse:.4f}" + " seconds precompiling\n" + ) + + @staticmethod + def benchmark_example_value(node): + """ + Convert an ir.Buffer into a concrete torch.Tensor we can use for + benchmarking. + """ + if isinstance(node, ir.Layout): + node = ir.Buffer("fake", node) + # triton templates want the base tensor. + if isinstance(node, ir.BaseView): + node = node.unwrap_view() + return AlgorithmSelectorCache.generate_example_value( + V.graph.sizevars.size_hints( + node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + V.graph.sizevars.size_hints( + node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + node.get_device(), + node.get_dtype(), + node.layout.offset, + ) + + @staticmethod + def generate_example_value(size, stride, device, dtype, extra_size): + # preserve rng states to avoid the rand_strided call below changes + # the rng states for the real model code. + with preserve_rng_state(): + return rand_strided( + size, + stride, + device=device, + dtype=dtype, + extra_size=extra_size, + ) + + @staticmethod + def key_of(node): + """ + Extract the pieces of an ir.Buffer that we should invalidate cached + autotuning results on. + """ + sizevars = V.graph.sizevars + return ( + node.get_device().type, + str(node.get_dtype()), + *sizevars.size_hints( + node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + *sizevars.size_hints( + node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + sizevars.size_hint( + node.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ), + ) + + def add_feedback_saver( + self, + fn: Callable[ + [Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None + ], + ): + self.feedback_saver_fns.append(fn) + + +_ALGORITHM_SELECTOR_CACHE: Optional[AlgorithmSelectorCache] = None + + +def autotune_select_algorithm(*args, **kwargs): + global _ALGORITHM_SELECTOR_CACHE + if _ALGORITHM_SELECTOR_CACHE is None: + _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache() + + if "return_multi_template" not in kwargs: + kwargs[ + "return_multi_template" + ] = torch._inductor.config.benchmark_epilogue_fusion + + return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs) + + +def add_feedback_saver( + fn: Callable[[Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None] +): + global _ALGORITHM_SELECTOR_CACHE + if _ALGORITHM_SELECTOR_CACHE is None: + _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache() + _ALGORITHM_SELECTOR_CACHE.add_feedback_saver(fn) + + +def realize_inputs(*args): + if len(args) == 1: + return ir.ExternKernel.require_stride1(ir.ExternKernel.realize_input(args[0])) + return [realize_inputs(x) for x in args] + + +# ensure lowering is imported so that `extern_kernels.*` is populated +from . 
import lowering # noqa: F401 diff --git a/lib/python3.10/site-packages/torch/_inductor/sizevars.py b/lib/python3.10/site-packages/torch/_inductor/sizevars.py new file mode 100644 index 0000000000000000000000000000000000000000..8d3c6d411d278c2ef893ad06f8195dbd2096572c --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/sizevars.py @@ -0,0 +1,892 @@ +# mypy: allow-untyped-defs +import functools +import itertools +import logging +from typing import ( + Any, + Callable, + cast, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +import sympy +from sympy import Expr + +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, ShapeEnv +from torch.utils._sympy.functions import FloorDiv, ModularIndexing +from torch.utils._sympy.symbol import symbol_is_type, SymT +from torch.utils._sympy.value_ranges import bound_sympy, IntInfinity, ValueRanges + +from .runtime.runtime_utils import is_power_of_2 +from .utils import ( + has_free_symbols, + sympy_index_symbol, + sympy_index_symbol_with_prefix, + sympy_subs, + VarRanges, +) +from .virtualized import V + + +log = logging.getLogger(__name__) + + +def evaluate_expr( + shape_env: ShapeEnv, + expr: Union[sympy.Basic, bool], + axioms: Optional[Tuple[sympy.Expr]] = None, + var_to_range: Optional[Tuple[Tuple[sympy.Symbol, ValueRanges[Any]]]] = None, +) -> bool: + if expr in (True, False): + return bool(expr) + + try: + simplified = shape_env._maybe_evaluate_static( + expr, + axioms=axioms, + var_to_range=var_to_range, + ) + if simplified is not None: + return bool(simplified) + except Exception: + log.debug("Could not simplify %s", expr, exc_info=True) + + return False + + +# This class is a little awkward, because ShapeEnv is doing most of the heavy +# lifting and in some cases we should be directly passing through to ShapeEnv, +# but there is some extra inductor logic that needs to be handled here +class SizeVarAllocator: + def __init__(self, shape_env=None) -> None: + super().__init__() + if shape_env is None: + shape_env = ShapeEnv() + self.shape_env = shape_env + self.var_to_val = self.shape_env.var_to_val + self.replacements: Dict[sympy.Symbol, Expr] = self.shape_env.replacements + # Maps of dynamic sizes that have to be precomputed on the host to the kernel args. + # The basic idea is if we have some complicated sympy expression + # f(s0), we may choose to precompute it on the host and then replace + # all occurrences of that sympy expression with ps0, so that when we + # codegen we simply reference ps0 directly without repeating + # f(s0). Unlike regular size variables, ps variables cannot be + # guarded upon; so if we are asked to guard on a Sympy expression + # which potentially could have already had a precomputed replacement + # on it, we are obligated to invert the precomputed replacements + # (inv_precomputed_replacements). 
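+        # Illustrative example (added for clarity; the expression is hypothetical): if a kernel
+        # needs FloorDiv(s0 * 127, 8), lookup_precomputed_size() below hands back a fresh symbol
+        # such as ps0, recording precomputed_replacements[FloorDiv(s0*127, 8)] = ps0 and
+        # inv_precomputed_replacements[ps0] = FloorDiv(s0*127, 8); the host then computes the
+        # value once and passes it to the kernel as an extra argument.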
+ self.precomputed_replacements: Dict[Expr, sympy.Symbol] = {} + self.inv_precomputed_replacements: Dict[sympy.Symbol, Expr] = {} + self.stride_vars = self.make_stride_vars_cache() + self.simplify_with_ranges = self.make_simplify_with_ranges_cache() + self._simplify_loops = self.make_simplify_loops_cache() + + def simplify(self, expr: Expr): + return sympy.expand(expr).xreplace(self.replacements) + + def make_simplify_with_ranges_cache(self) -> Callable[[Expr, VarRanges], Expr]: + """ + self._simplify_with_ranges() can be expensive, cache its results + """ + cache: Dict[Tuple[Any, ...], Expr] = {} + replacement_count = len(self.replacements) + + def simplify_with_ranges(expr: Expr, var_ranges: VarRanges) -> Expr: + nonlocal replacement_count + if replacement_count != len(self.replacements): + # new replacements invalidates cached results + cache.clear() + replacement_count = len(self.replacements) + key = (expr, *var_ranges.items()) + result = cache.get(key, None) + if result is None: + result = self._simplify_with_ranges(expr, var_ranges) + cache[key] = result + return result + + return simplify_with_ranges + + def make_simplify_loops_cache(self): + """ + self._simplify_with_ranges() can be expensive, cache its results + """ + cache: Dict[Tuple[Any, ...], Any] = {} + replacement_count = len(self.replacements) + + def simplify_loops(index_vars, sizes, index_formulas): + nonlocal replacement_count + if replacement_count != len(self.replacements): + # new replacements invalidates cached results + cache.clear() + replacement_count = len(self.replacements) + key = (*index_vars, *sizes, *index_formulas) + result = cache.get(key, None) + if result is None: + result = self._simplify_loops_impl(index_vars, sizes, index_formulas) + cache[key] = result + return result + + return simplify_loops + + def _simplify_with_ranges(self, expr: Expr, var_ranges: VarRanges) -> Expr: + """ + Simplify indexing expression with knowledge of the ranges of + iteration variables. 
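+
+        Example (illustrative; added for clarity): with var_ranges = {x0: 16}, the expression
+        ModularIndexing(x0, 1, 32) should simplify to x0, because 0 <= x0 < 16 < 32 makes the
+        modulus a no-op; similarly FloorDiv(x0 + 32*x1, 32) should simplify to x1, since the
+        x0 term is provably smaller than the divisor.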
+ """ + + expr = join_dimensions(self.simplify(expr)) + original_expr = expr + + var_to_range = dict(self.shape_env.var_to_range) + var_to_range.update( + { + k: ValueRanges( + 0, max(0, v - 1) if not has_free_symbols([v]) else IntInfinity() + ) + for k, v in var_ranges.items() + } + ) + for var in expr.free_symbols: + if var not in var_to_range: + var_to_range[var] = ValueRanges(0, IntInfinity()) + + var_to_range_tuple = cast( + Tuple[Tuple[sympy.Symbol, ValueRanges[sympy.Expr]]], + tuple(var_to_range.items()), + ) + + axioms = [] + for var, upper_bound in var_ranges.items(): + axioms.append(0 <= var) + axioms.append(var < upper_bound) + axioms = tuple(axioms) + self.shape_env.get_axioms() + + def statically_known(expr): + evaluated = self.shape_env._maybe_evaluate_static( + expr, + axioms=axioms, + var_to_range=var_to_range_tuple, + ) + return bool(evaluated) + + def remove_zero_terms(base, divisor): + """Symbols smaller than the divisor are zero""" + if not statically_known(base >= 0): + return base + + for v in base.free_symbols: + if v in var_ranges: + # var smaller than divisor can be removed + # if the rest is guaranteed to be multiple of divisor + rest = sympy.Wild("_rest", exclude=[v]) + m = base.match(v + rest) + if m and v not in m[rest].free_symbols: + gcd = sympy.gcd(m[rest], divisor) + if gcd == divisor: + if statically_known(v < divisor): + base = m[rest] + return base + + def visit_indexing_div(base, divisor): + return FloorDiv(remove_zero_terms(base, divisor), divisor) + + def visit_modular_indexing(base, divisor, modulus): + base = remove_zero_terms(base, divisor) + + can_remove_mod = statically_known(base >= 0) and statically_known( + base < modulus * divisor + ) + + if can_remove_mod: + return FloorDiv(base, divisor) + return ModularIndexing(base, divisor, modulus) + + if expr.has(ModularIndexing): + expr = expr.replace( + ModularIndexing( + sympy.Wild("base", integer=True), + sympy.Wild("divisor", integer=True), + sympy.Wild("modulus", integer=True), + ), + visit_modular_indexing, + ) + + if expr.has(FloorDiv): + expr = expr.replace( + FloorDiv( + sympy.Wild("base", integer=True), + sympy.Wild("divisor", integer=True), + ), + visit_indexing_div, + ) + + if expr != original_expr: + return self._simplify_with_ranges(expr, var_ranges) + return expr + + def _simplify_loops_impl( + self, index_vars: List[sympy.Symbol], sizes, index_formulas + ): + """ + Try to remove as many axis from loop iterations as possible, by: + 1) removing size==1 dimensions + 2) fuse contiguous dimensions into a single loop + If channel_last = True, we will prevent the last dim fused with other dims + """ + sizes = list(map(self.simplify, sizes)) + + strides = [ + # index_formulas may contain boolean expressions (e.g. s0 < 10), + # for which "strides" don't make sense so we ignore them here. + # NOTE: These expressions may still block merging dims in the sound + # substitution test performed in can_merge_dims. 
+ self.stride_vars(x, index_vars) + if isinstance(x, sympy.Expr) + else [0] * len(index_vars) + for x in index_formulas + ] + assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0])) + + for i in range(len(sizes)): + if sizes[i] == 1: + # remove dim + sizes[i] = None + + def can_merge_dims(a, b): + for k in range(len(strides)): + if self.simplify(strides[k][a] * sizes[a]) == self.simplify( + strides[k][b] + ): + # approximate test passed, try sound version + va = index_vars[a] + vb = index_vars[b] + m1 = sympy_index_symbol("_merge_tester1") + m2 = sympy_index_symbol("_merge_tester2") + # NOTE: can't sub vb=0 here in case va * vb appears in the expression, + # in which case both expr1 and expr2 would be zero! + expr1 = sympy_subs(index_formulas[k], {va: m1 * sizes[a], vb: m2}) + expr2 = sympy_subs(index_formulas[k], {va: 0, vb: (m1 + m2)}) + if self.simplify(expr1) == self.simplify(expr2): + continue + return False + return True + + changed = True + while changed: + changed = False + for i, j in itertools.product( + reversed(range(len(sizes))), reversed(range(len(sizes))) + ): + if i == j or sizes[i] is None or sizes[j] is None: + continue + if can_merge_dims(i, j): + changed = True + sizes[i] = sizes[i] * sizes[j] + sizes[j] = None + + def reindex(index): + it = list(reversed(index)) + new_index = [] + for size in sizes: + if size is None: + new_index.append(sympy.Integer(0)) + else: + new_index.append(it.pop()) + assert not it + return new_index + + def prune(index): + assert len(index) == len(sizes) + return [i for i, s in zip(index, sizes) if s is not None] + + return [x for x in sizes if x is not None], reindex, prune + + # Note - [On Statically Known] + # + # The statically_known_* family of functions below replaces a prior system, called maybe_guard_*. The prior system + # operated by providing essentially a question, where the size hinted values were evaluated. If the condition was + # true, we add a guard and return True, otherwise, False. + # + # def maybe_guard_foo(args): + # if size_hinted_check(args): + # return False # No guard, no optim + # guard(args) # Make a guard + # return True # Safe to apply optimization + # + # The prior system incurred a guard, and green lit an optimization. + # + # The new system works in reverse - in the new system, if we know that the inputs are static, and evaluate the + # condition as true, we green light the optimization, and we do not incur a guard. If we cannot prove that, we + # return False. + # + # def maybe_guard_foo(args): + # if all_static(args): + # return True # Safe to apply optimization + # else: + # return False # No guard, no optim + + # See Note - [On Statically Known] + + def is_expr_static_and_true(self, expr: Union[sympy.Basic, bool]) -> bool: + return evaluate_expr(self.shape_env, expr) + + def statically_known_equals( + self, left: Union[Expr, int], right: Union[Expr, int] + ) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left and right are equal. + """ + return self.is_expr_static_and_true(sympy.Eq(left, right)) # type: ignore[arg-type] + + # See Note - [On Statically Known] + def statically_known_list_equals(self, left: List[Expr], right: List[Expr]) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left and right lists are equal. 
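+
+        Example (illustrative; added for clarity): for [s0, 4] and [s0, 4] this returns True
+        without introducing a guard, while for [s0] and [4] it returns True only if s0 is
+        already statically known to equal 4, and False otherwise (again without guarding).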
+ """ + return len(left) == len(right) and all( + self.statically_known_equals(l, r) for l, r in zip(left, right) + ) + + # See Note - [On Statically Known] + def statically_known_leq(self, left: Expr, right: Union[Expr, int]) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is less than or equal to right. + """ + expr = left <= right + return self.is_expr_static_and_true(expr) + + # See Note - [On Statically Known] + def statically_known_geq(self, left: Expr, right: Union[Expr, int]) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is greater than or equal to right. + """ + expr = left >= right + return self.is_expr_static_and_true(expr) + + # See Note - [On Statically Known] + def statically_known_lt(self, left: Expr, right: Union[Expr, int]) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is less than right. + """ + expr = left < right + return self.is_expr_static_and_true(expr) + + # See Note - [On Statically Known] + def statically_known_gt(self, left: Expr, right: Union[Expr, int]) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is greater than right. + """ + expr = left > right + return self.is_expr_static_and_true(expr) + + # See Note - [On Statically Known] + def statically_known_multiple_of( + self, numerator: Expr, denominator: Union[Expr, int] + ) -> bool: + """ + Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator. + """ + if free_unbacked_symbols(numerator) or free_unbacked_symbols(denominator): + return False + expr = sympy.Eq(numerator % denominator, 0) + return self.is_expr_static_and_true(expr) # type: ignore[arg-type] + + # See Note - [On Statically Known] + def statically_known_power_of_2(self, expr: Expr) -> bool: + """ + Returns a bool indicating if x is known to be a power of 2. + """ + return isinstance(expr, sympy.Integer) and is_power_of_2(int(expr)) + + # The guard functions require you to ALREADY KNOW that a particular + # condition holds. If you don't know (you want to guard on an expression + # being a particular value, and then get access to that value), use + # the evaluate functions. + + def guard_equals(self, left: Expr, right: Expr) -> Expr: + if isinstance(left, Expr): + left = sympy_subs(left, self.inv_precomputed_replacements) # type: ignore[arg-type] + if isinstance(right, Expr): + right = sympy_subs(right, self.inv_precomputed_replacements) # type: ignore[arg-type] + assert self.shape_env.evaluate_expr(sympy.Eq(left, right)) + return left + + def guard_leq(self, left: Expr, right: Expr) -> None: + return self.guard_lt(left, right + 1) + + def guard_lt(self, left: Expr, right: Expr) -> None: + assert self.shape_env.evaluate_expr(sympy.Lt(left, right)) + + def guarded_order(self, seq): + """ + Return the order of a sequence as a permutation of range(len(seq)) and guard on that order not changing. 
+ """ + seq = [*map(self.remove_precomputed_replacements, seq)] + seq = [(self.size_hint(var), orig_idx, var) for orig_idx, var in enumerate(seq)] + seq.sort() + order = [-1] * len(seq) + last_var = None + for new_index, (_, orig_index, var) in enumerate(seq): + order[orig_index] = new_index + if last_var is not None: + self.guard_leq(last_var, var) + last_var = var + return order + + # The evaluate functions evaluate some symbolic sympy expression + # (NB: not necessarily an Expr) and return what the concrete result + # is, guarding on the expression being that result + + # NB: write evaluate_expr(sympy.Lt(a, b)) rather than evaluate_expr(a < b) + # as this will ensure that you actually have a sympy'ified expression, + # and will prevent you from incorrectly writing evaluate_expr(a == b) + # which does the wrong thing if a or b is a sympy expression + def evaluate_expr(self, left: Union[Expr, sympy.logic.boolalg.Boolean]) -> bool: + assert isinstance(left, (Expr, sympy.logic.boolalg.Boolean)), type(left) + return self.shape_env.evaluate_expr(sympy.sympify(left)) + + def evaluate_min(self, left: Expr, right: Expr) -> Expr: + """return the smaller of left and right, and guard on that choice""" + if isinstance(left, Expr): + left = sympy_subs(left, self.inv_precomputed_replacements) # type: ignore[arg-type] + if isinstance(right, Expr): + right = sympy_subs(right, self.inv_precomputed_replacements) # type: ignore[arg-type] + try: + lv = self.size_hint(left) + rv = self.size_hint(right) + except TypeError: # unbacked symints + if left == right or self.statically_known_leq(left, right): + return left + if self.statically_known_leq(right, left): + return right + gcd = sympy.gcd(left, right) + if left == gcd: # handle `min(10*u0, u0)` etc + return left + if right == gcd: + return right + raise TypeError( + f"evaluate_min({left}, {right}) with unbacked symints" + ) from None + if lv <= rv: + self.guard_leq(left, right) + return left + else: + self.guard_leq(right, left) + return right + + def evaluate_max(self, left: Expr, right: Expr) -> Expr: + """return the larger of left and right, and guard on that choice""" + # Always choose the opposite of eval min for consistency + # This means min(a, b) and max(a, b) produce the same guards + min_val = self.evaluate_min(left, right) + return right if min_val is left else left + + def evaluate_static_shape(self, left: Union[Expr, int]) -> int: + if isinstance(left, int): + return left + right = self.size_hint(left) + self.guard_equals(left, sympy.Integer(right)) + return int(right) + + def evaluate_static_shapes(self, left: Sequence[Union[Expr, int]]) -> List[int]: + return [self.evaluate_static_shape(x) for x in left] + + def remove_precomputed_replacements(self, expr: Expr) -> Expr: + if any(symbol_is_type(s, SymT.PRECOMPUTED_SIZE) for s in expr.free_symbols): # type: ignore[attr-defined] + return sympy_subs(expr, self.inv_precomputed_replacements) # type: ignore[arg-type] + return expr + + def symbolic_hint(self, expr: Union[Expr, int]) -> Union[Expr, int]: + if isinstance(expr, int): + return expr + # Substitute all hints into expr, but leave unbacked symints alone + expr = self.simplify(expr) + if not isinstance(expr, Expr): + assert isinstance(expr, int) + return expr + free_symbols = expr.free_symbols + if not free_symbols: + try: + return int(expr) # type: ignore[return-value] + except TypeError: + return expr # inf/nan/I + expr = self.remove_precomputed_replacements(expr) + return sympy_subs(expr, self.var_to_val) + + def size_hint( + self, expr: 
Union[Expr, int], *, fallback: Optional[int] = None + ) -> int: + out = self.symbolic_hint(expr) + if not isinstance(out, (int, sympy.Integer)) and fallback is not None: + # Use the provided heuristic fallback hint + unbacked_sym_vrs = { + s: self.shape_env.var_to_range.get(s, None) for s in out.free_symbols + } + if all(vr is not None for vr in unbacked_sym_vrs.values()): + hint_vr = bound_sympy(out, unbacked_sym_vrs) # type: ignore[arg-type] + if isinstance(hint_vr.lower, (int, sympy.Integer)): + fallback = max(fallback, int(hint_vr.lower)) + if isinstance(hint_vr.upper, (int, sympy.Integer)): + fallback = min(fallback, int(hint_vr.upper)) + return fallback + + try: + return int(out) + except Exception: + log.debug("failed on: %s", out) + raise + + def size_hints( + self, + exprs: Iterable[Expr], + *, + fallback: Optional[int] = None, + ) -> Tuple[int, ...]: + return tuple(self.size_hint(x, fallback=fallback) for x in exprs) + + def _lru_cache(self, fn, maxsize=None): + """ + Wrapper around functools.lru_cache that clears when replacements + has been invalidated. + """ + fn_cache = functools.lru_cache(maxsize)(fn) + prior_len = len(self.replacements) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + nonlocal prior_len + if prior_len != len(self.replacements): + prior_len = len(self.replacements) + fn_cache.cache_clear() + return fn_cache(*args, **kwargs) + + return wrapper + + def make_stride_vars_cache(self): + cache = self._lru_cache(self._stride_vars) + + def stride_vars( + index: Expr, + vars: Sequence[sympy.Symbol], + support_vars: Optional[Sequence[sympy.Symbol]] = None, + ) -> List[Expr]: + if not support_vars: + support_vars = vars + return cache(index, tuple(vars), tuple(support_vars)) + + return stride_vars + + def _stride_vars( + self, + index: Expr, + vars: Sequence[sympy.Symbol], + support_vars: Sequence[sympy.Symbol], + ) -> List[Expr]: + """Convert an indexing expression back into strides + + NOTE: This is only valid if the index is a standard strided offset + calculation. e.g. 10 * ModularIndexing(i0 + 1, 1, 2) would give a + stride of -10 because the index wraps around after the first element + + """ + strides = [] + index = self.simplify(index) + # remove any offset + index = index - sympy_subs( + index, {v: sympy.Integer(0) for v in support_vars if v != 0} + ) + for i in range(len(vars)): + # drop all the other dims + index_dim = sympy_subs( + index, + { + support_vars[j]: sympy.Integer(0) + for j in range(len(support_vars)) + if vars[i] != support_vars[j] and support_vars[j] != 0 + }, + ) + v = vars[i] + if v == 0: + strides.append(sympy.Integer(0)) + else: + # TODO(jansel): should we use sympy.diff here? 
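+            # Illustrative note (added for clarity): the stride of `v` is recovered by a finite
+            # difference, index_dim(v=1) - index_dim(v=0). For example, with vars (x0, x1) and
+            # index 5*x0 + x1 + 7, this yields strides [5, 1]; the constant offset cancels.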
+ strides.append( + sympy_subs(index_dim, {v: sympy.Integer(1)}) + - sympy_subs(index_dim, {v: sympy.Integer(0)}) + ) + return strides + + def offset_var(self, index: Expr, vars: List[sympy.Symbol]) -> Expr: + """Extract offset part of an indexing expression""" + index = self.simplify(index) + return sympy_subs(index, {v: sympy.Integer(0) for v in vars if v != 0}) + + def stride_hints( + self, + index: Expr, + vars: Sequence[sympy.Symbol], + support_vars: Optional[Sequence[sympy.Symbol]] = None, + ) -> List[int]: + for v in index.free_symbols: + if symbol_is_type(v, SymT.INDIRECT): # type: ignore[attr-defined] + index = sympy_subs(index, {v: 0}) # type: ignore[dict-item] + result = [] + for s in self.stride_vars(index, vars, support_vars): + try: + result.append(self.size_hint(s)) + except TypeError: + result.append(0) + return result + + def stride_order(self, index: Expr, vars: List[sympy.Symbol]) -> List[int]: + strides = tuple(map(abs, self.stride_hints(index, vars))) + order = list(range(len(strides))) + order.sort(key=lambda x: (strides[x] == 0, strides[x])) + return order + + def lookup_precomputed_size(self, expr: Expr) -> Expr: + if ( + isinstance(expr, (int, sympy.Symbol, sympy.Number)) + or expr.is_number + or expr.is_symbol + ): + return expr + expr = self.remove_precomputed_replacements(expr) + if expr not in self.precomputed_replacements: + sym = sympy_index_symbol_with_prefix( + SymT.PRECOMPUTED_SIZE, len(self.precomputed_replacements) + ) + self.precomputed_replacements[expr] = sym + self.inv_precomputed_replacements[sym] = expr + return self.precomputed_replacements[expr] + + def free_symbols(self) -> Set[sympy.Symbol]: + return set(self.var_to_val.keys()) - set(self.replacements.keys()) + + def combine_modular_indexing_pairs(self, index: sympy.Expr) -> sympy.Expr: + """ + A pair of special ModularIndexing can be combined. + + E.g. ModularIndexing(ModularIndexing(x, 1, a), 1, b) + We can simplify this to ModuleIndexing(x, 1, b), if + 1. x is non negative integer + 2. a and b are positive integers + 3. a is a multiple of b. + """ + + def _check_args(x, div, mod, is_first): + if not isinstance(div, sympy.Integer) or not isinstance(mod, sympy.Integer): + return False + if div != 1: + return False + if mod <= 0: + return False + + if is_first: + # first ModularIndexing should conatins a nested ModularIndex + if not isinstance(x, ModularIndexing): + return False + else: + # second ModularIndexing should constains a non-negative + # symbol + if not isinstance(x, sympy.Symbol) or not self.statically_known_geq( + x, 0 + ): + return False + return True + + if isinstance(index, ModularIndexing): + x, div, mod = index.args + + if not _check_args(x, div, mod, True): + return index + + x2, div2, mod2 = x.args + + if not _check_args(x2, div2, mod2, False): + return index + + if mod2 % mod != 0: + return index + + return ModularIndexing(x2, 1, mod) + + return index + + def expand_floor_div( + self, index: sympy.Expr + ) -> Union[bool, Tuple[sympy.Expr, sympy.Expr]]: + """ + Expand the FloorDiv to the entire expression so that the expression may + be simplfied. + + E.g., for a 2D contiguous tensor with shape [a, 2 * b], and index variables + x1, x2, index expression 'x1 * 2b + x2' can be easily combined. + But index expression 'x1 * b + x2 // 2' can not. + By expanding the FloorDiv to the entire expression, we get + '(x1 * 2b + x2) // 2'. This transformation allows us to merge loops + for the numerator! 
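+
+        Concrete illustration (added for clarity; values hypothetical, assuming x1 and x2 are
+        known non-negative index variables): for index = 8*x1 + FloorDiv(x2, 4) this returns
+        (32*x1 + x2, 4), i.e. the original expression equals FloorDiv(32*x1 + x2, 4).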
+ + Return False if this optimization cannot be applied; + Return the new expression and the denominator otherwise. + The original expression will be equivalent to 'new_expression // denominator' + """ + if not isinstance(index, sympy.Add): + return False + terms = index.args + + if len(terms) < 2: + return False + floor_div_index = -1 + varlist = [] + factorlist = [] + for idx, term in enumerate(terms): + if isinstance(term, sympy.Mul): + # For dynamic shape, a term like '2*s1*x1' has 3 child nodes. + # - An integer for 2 + # - A symbol for s1 + # - A symbol for x1 + # Skip for now. + if len(term.args) != 2: + return False + factor, var = term.args + varlist.append(var) + factorlist.append(factor) + if not isinstance(factor, sympy.Integer) or not isinstance( + var, sympy.Symbol + ): + return False + # It's easier to reason about the correctness of the transformation + # for non-negative integers. + if not self.statically_known_geq(var, 0): + return False + elif isinstance(term, FloorDiv): + var, factor = term.args + if not isinstance(factor, sympy.Integer) or not isinstance( + var, sympy.Symbol + ): + return False + if not self.statically_known_geq(var, 0): + return False + if floor_div_index >= 0: + # cannot handle multiple FloorDiv terms yet + return False + + floor_div_index = idx + varlist.append(var) + # this factor is the denominator + factorlist.append(factor) + else: + return False + + if floor_div_index < 0: + return False + + # Construct the new expression and remember the denominator + denominator = factorlist[floor_div_index] + new_index = sympy.Integer(0) + + for var, factor, idx in zip(varlist, factorlist, itertools.count()): + if idx == floor_div_index: + new_index += var + else: + new_index += (factor * denominator) * var + + return new_index, denominator + + +def join_dimensions(expr: Expr) -> Expr: + if not isinstance(expr, sympy.Add) or not expr.has(ModularIndexing): + return expr # fast exit path + return _join_dimensions_cached(expr) + + +@functools.lru_cache(256) +def _join_dimensions_cached(expr: Expr) -> Expr: + """ + ModularIndexing(i0, 1, 32) + 32 * ModularIndexing(i0, 32, 4) + becomes + ModularIndexing(i0, 1, 128) + ModularIndexing(i0, 1, 32) + 32 * FloorDiv(i0, 32) + becomes i0 + + + This type of pattern can come from view operations + """ + assert isinstance(expr, sympy.Add) + + scale = sympy.Wild("scale", exclude=[0], integer=True) + base = sympy.Wild("base", integer=True) + divisor = sympy.Wild("divisor", integer=True) + mod1 = sympy.Wild("modulus", integer=True) + mod2 = sympy.Wild("modulus2", integer=True) + for term1 in expr.args: + m1 = term1.match(scale * ModularIndexing(base, divisor, mod1)) + if m1: + for term2 in expr.args: + m2 = term2.match( + m1[scale] + * m1[mod1] + * ModularIndexing(m1[base], m1[divisor] * m1[mod1], mod2) + ) + if m2 and term1 != term2: + expr = join_dimensions( + expr + - term1 + - term2 + + m1[scale] + * ModularIndexing(m1[base], m1[divisor], m1[mod1] * m2[mod2]) + ) + return expr + for term1 in expr.args: + m1 = term1.match(scale * ModularIndexing(base, divisor, mod1)) + if m1: + for term2 in expr.args: + m2 = term2.match( + m1[scale] * m1[mod1] * FloorDiv(m1[base], m1[divisor] * m1[mod1]) + ) + if m2 is not None: # in case of success we get an empty dict here + expr = join_dimensions( + expr + - term1 + - term2 + + m1[scale] * FloorDiv(m1[base], m1[divisor]) + ) + return expr + return expr + + +class SimplifyIndexing(V.WrapperHandler): # type: ignore[name-defined] + """ + A wrapper around .virtualize.ops that uses var range information to +
simplify ModularIndexing/FloorDiv. + """ + + def __init__(self, inner, var_ranges: VarRanges) -> None: + super().__init__(inner) + self.name = "SimplifyIndexing" + self._simplify: Callable[ + [Expr], Expr + ] = lambda index: V.graph.sizevars.simplify_with_ranges(index, var_ranges) + + def load(self, name: str, index: sympy.Expr): + return self._inner.load(name, self._simplify(index)) + + def store(self, name, index, value, mode=None): + return self._inner.store(name, self._simplify(index), value, mode=mode) + + def store_reduction(self, name, index, value): + return self._inner.store_reduction(name, self._simplify(index), value) + + def index_expr(self, index, dtype): + return self._inner.index_expr(self._simplify(index), dtype) + + def check_bounds(self, index, size, lower, upper): + return self._inner.check_bounds(self._simplify(index), size, lower, upper) diff --git a/lib/python3.10/site-packages/torch/_inductor/subgraph_lowering.py b/lib/python3.10/site-packages/torch/_inductor/subgraph_lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..21145375ede8fa1eb803955c435564238148fb1c --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/subgraph_lowering.py @@ -0,0 +1,155 @@ +"""Utilities for lowering subgraphs used by higher order operators + +""" + +import functools +import operator +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union +from typing_extensions import ParamSpec + +import torch + +from . import ir +from .exc import SubgraphLoweringException +from .ops_handler import SimpleCSEHandler +from .sizevars import SizeVarAllocator +from .virtualized import ops, V, WrapperHandler + + +T = TypeVar("T") +_P = ParamSpec("_P") + + +class PointwiseSubgraphLowering(torch.fx.Interpreter): + graph_outputs: Optional[List[ir.IRNode]] + + def __init__( + self, + gm: torch.fx.GraphModule, + root_graph_lowering: "torch._inductor.graph.GraphLowering", + ) -> None: + super().__init__(gm) + self.graph_outputs = None + self.root_graph = root_graph_lowering + + @property + def sizevars(self) -> SizeVarAllocator: + return self.root_graph.sizevars + + def mark_buffer_mutated(self, name: str) -> None: + raise SubgraphLoweringException("Mutations are not supported in this context") + + def register_buffer(self, buffer: ir.Buffer) -> str: + raise SubgraphLoweringException( + "Buffer creation is not supported in this context" + ) + + def call_function( + self, + target: Callable[[Any], Any], # type: ignore[override] + args: Any, + kwargs: Dict[str, Any], + ) -> Any: + from .lowering import lowerings + + if target is operator.getitem and isinstance(args[0], (list, tuple, dict)): + return super().call_function(target, args, kwargs) + + assert isinstance(target, torch._ops.OpOverload) + + if target not in lowerings: + raise SubgraphLoweringException( + f"{target} not supported in subgraph, (missing lowering)" + ) + + if torch.Tag.pointwise not in target.tags: + raise SubgraphLoweringException( + f"Only pointwise operators are supported in this context, but got {target}" + ) + + return lowerings[target](*args, **kwargs) + + def output(self, target: str, args: Tuple[Any], kwargs: Dict[str, Any]) -> None: # type: ignore[override] + assert len(args) == 1 + self.graph_outputs = args[0] + + +@dataclass +class InputDescriptor: + dtype: torch.dtype + device: torch.device + + +class TracingOpsHandler(WrapperHandler[T]): + def __init__(self, tracer: torch.fx.Tracer, num_inputs: int) -> None: + parent = 
tracer.create_proxy("placeholder", "ops", (), {}) + super().__init__(parent) + self.tracer = tracer + + self.placeholders = [ + self.tracer.create_proxy("placeholder", f"input{i}", (), {}) + for i in range(num_inputs) + ] + + def placeholder(self, idx: int) -> torch.fx.Proxy: + return self.placeholders[idx] + + def output(self, *args: Tuple[object]) -> torch.fx.Node: + return self.tracer.create_node( + "output", "output", (tuple(self.tracer.create_arg(a) for a in args),), {} + ) + + +def lower_pointwise_subgraph( + subgraph: ir.Subgraph, inputs: List[InputDescriptor] +) -> Callable[_P, Any]: + # Lower subgraph to ir.Pointwise nodes + def fake_inner_fn( + loop_idx: int, input_idx: int + ) -> Union[ir.Expr, ir.TensorBox, None]: + return ops.placeholder(input_idx) + + graph_inputs = [ + ir.Pointwise.create( + device=desc.device, + dtype=desc.dtype, + inner_fn=functools.partial(fake_inner_fn, input_idx=i), + ranges=[], + ) + for i, desc in enumerate(inputs) + ] + gm = subgraph.graph_module + pw_subgraph = PointwiseSubgraphLowering(gm, root_graph_lowering=V.graph) + with V.set_graph_handler(pw_subgraph): # type: ignore[arg-type] + pw_subgraph.run(*graph_inputs) + + # Combine multiple pointwise computations into a single graph module + # Do this by tracing through each individually and doing CSE + tracer = torch.fx.Tracer() + tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__) + trace_ops = SimpleCSEHandler(TracingOpsHandler(tracer, len(inputs))) + assert pw_subgraph.graph_outputs is not None + + with V.set_ops_handler(trace_ops): + output_irs = [] + + for out_var in pw_subgraph.graph_outputs: + assert isinstance(out_var, ir.TensorBox), type(out_var) + assert out_var.get_size() == [] + assert isinstance(out_var.data, ir.StorageBox) + assert isinstance(out_var.data.data, ir.Pointwise) + + idx = () + ir_out = out_var.data.data.inner_fn(idx) + + output_irs.append(ir_out) + + ops.output(*output_irs) + + lowered_gm = torch.fx.GraphModule({}, tracer.graph) + + def inner_fn(*args: _P.args, **kwargs: _P.kwargs) -> Any: + return lowered_gm(V.get_ops_handler(), *args, **kwargs) + + return inner_fn diff --git a/lib/python3.10/site-packages/torch/_inductor/test_case.py b/lib/python3.10/site-packages/torch/_inductor/test_case.py new file mode 100644 index 0000000000000000000000000000000000000000..53a791685c67ed005f8d88f8d64661187e63e0bb --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/test_case.py @@ -0,0 +1,35 @@ +# mypy: allow-untyped-defs +import contextlib +import os + +from torch._dynamo.test_case import ( + run_tests as dynamo_run_tests, + TestCase as DynamoTestCase, +) +from torch._inductor import config +from torch._inductor.utils import fresh_inductor_cache + + +def run_tests(needs=()): + dynamo_run_tests(needs) + + +class TestCase(DynamoTestCase): + """ + A base TestCase for inductor tests. Enables FX graph caching and isolates + the cache directory for each test. 
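+ + Illustrative usage (the test module, `MyTest`, and `test_add` are hypothetical): + + import torch + from torch._inductor.test_case import TestCase, run_tests + + class MyTest(TestCase): + def test_add(self): + compiled = torch.compile(lambda x: x + 1) + self.assertEqual(compiled(torch.ones(4)), torch.ones(4) + 1) + + if __name__ == "__main__": + run_tests()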
+ """ + + def setUp(self): + super().setUp() + self._inductor_test_stack = contextlib.ExitStack() + self._inductor_test_stack.enter_context(config.patch({"fx_graph_cache": True})) + if ( + os.environ.get("INDUCTOR_TEST_DISABLE_FRESH_CACHE") != "1" + and os.environ.get("TORCH_COMPILE_DEBUG") != "1" + ): + self._inductor_test_stack.enter_context(fresh_inductor_cache()) + + def tearDown(self): + super().tearDown() + self._inductor_test_stack.close() diff --git a/lib/python3.10/site-packages/torch/_inductor/test_operators.py b/lib/python3.10/site-packages/torch/_inductor/test_operators.py new file mode 100644 index 0000000000000000000000000000000000000000..a5c1d401f2d0207503273dc709b01e50a288fe1c --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/test_operators.py @@ -0,0 +1,27 @@ +# mypy: allow-untyped-defs +import torch.library +from torch import Tensor +from torch.autograd import Function + + +if not torch._running_with_deploy(): + _test_lib_def = torch.library.Library("_inductor_test", "DEF") + _test_lib_def.define( + "realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag + ) + + _test_lib_impl = torch.library.Library("_inductor_test", "IMPL") + for dispatch_key in ("CPU", "CUDA", "Meta"): + _test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key) + + class Realize(Function): + @staticmethod + def forward(ctx, x): + return torch.ops._inductor_test.realize(x) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + def realize(x: Tensor) -> Tensor: + return Realize.apply(x) diff --git a/lib/python3.10/site-packages/torch/_inductor/utils.py b/lib/python3.10/site-packages/torch/_inductor/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b17f3a68559a60d0d84f814739cc3b325b07594c --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/utils.py @@ -0,0 +1,2037 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections +import contextlib +import dataclasses +import enum +import functools +import inspect +import io +import itertools +import logging +import math +import operator +import os +import platform +import shutil +import sys +import tempfile +import textwrap +import time +import unittest +from datetime import datetime +from io import StringIO +from typing import ( + Any, + Callable, + Dict, + Generic, + Iterable, + List, + NamedTuple, + Optional, + Protocol, + Sequence, + Set, + TypeVar, + Union, + ValuesView, +) +from typing_extensions import Concatenate, ParamSpec +from unittest import mock + +import sympy + +import torch + + +GPU_TYPES = ["cuda", "xpu"] + + +# defines here before import torch._dynamo is for avoiding circular import +# when get_gpu_type is imported from dynamo +@functools.lru_cache(None) +def get_gpu_type(): + avail_gpus = [x for x in GPU_TYPES if getattr(torch, x).is_available()] + assert len(avail_gpus) <= 1 + gpu_type = "cuda" if len(avail_gpus) == 0 else avail_gpus.pop() + return gpu_type + + +from torch._dynamo.device_interface import get_interface_for_device +from torch._dynamo.utils import detect_fake_mode +from torch.autograd import DeviceType +from torch.autograd.profiler_util import EventList +from torch.fx.passes.graph_transform_observer import GraphTransformObserver +from torch.fx.passes.shape_prop import ShapeProp +from torch.utils._sympy.functions import ( + CeilDiv, + CleanDiv, + FloorDiv, + Identity, + ModularIndexing, +) +from torch.utils._sympy.symbol import make_symbol, SymT +from torch.utils._sympy.value_ranges import bound_sympy, 
ValueRanges + +from . import config +from .runtime.runtime_utils import ceildiv as runtime_ceildiv + + +_IS_WINDOWS = sys.platform == "win32" + +log = logging.getLogger(__name__) + +_T = TypeVar("_T") +VarRanges = Dict[sympy.Expr, sympy.Expr] +InputType = Union[torch.Tensor, int] + + +GPU_ALIGN_BYTES = 16 +ALIGNMENT = 16 + +ALIGN_BYTES = 64 +assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be power of 2" + + +def _align(nbytes): + """Round up to the nearest multiple of ALIGN_BYTES""" + return (nbytes + ALIGN_BYTES - 1) & -ALIGN_BYTES + + +def _is_aligned(v: sympy.Expr): + """v can be statically proven to be a multiple of ALIGN_BYTES""" + if isinstance(v, (sympy.Add, sympy.Max)): + return all(map(_is_aligned, v.args)) + return isinstance(v, align) or sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES + + +class align(sympy.Function): + """Symbolically round up to the nearest multiple of ALIGN_BYTES""" + + nargs = (1,) + is_integer = True + + @classmethod + def eval(cls, value): + if isinstance(value, (int, sympy.Integer)): + return _align(int(value)) + if _is_aligned(value): + return value + + +def do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> float: + """ + Returns benchmark results by examining torch profiler events. + This could be more accurate as it doesn't count CPU side overhead. + However, this also requires manually excluding irrelevant event, e.g. + vectorized_elementwise_kernel which is used to fill L2 cache, + various CUDA events, etc, so could also be fragile. + """ + + fn() + torch.cuda.synchronize() + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda") + + # Estimate the runtime of the function + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + + # Warm-up + for _ in range(n_warmup): + fn() + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CUDA, + ] + ) as p: + # Benchmark + for i in range(n_repeat): + # we clear the L2 cache before each run + cache.zero_() + # record time of `fn` + fn() + # Record clocks + torch.cuda.synchronize() + + log.debug("raw events") + log.debug(p.key_averages().table(sort_by="self_device_time_total", row_limit=-1)) + + filtered_events = EventList( + [ + event + for event in p.events() + if event.device_type == DeviceType.CUDA and event.name != "Context Sync" + ] + ) + if len(filtered_events) % n_repeat != 0: + raise RuntimeError( + "Failed to divide all profiling events into #repeat groups. 
" + "#CUDA events: %d, #repeats: %s", + len(filtered_events), + n_repeat, + ) + num_event_per_group = len(filtered_events) / n_repeat + actual_events = EventList( + [ + event + for i, event in enumerate(filtered_events) + if i % num_event_per_group != 0 + ] + ) + actual_events._build_tree() + actual_events = actual_events.key_averages() + + log.debug("profiling time breakdown") + log.debug(actual_events.table(row_limit=-1)) + + res = sum(event.device_time_total for event in actual_events) / 1000.0 / n_repeat + log.debug("profiling results: %s ms", res) + return res + + +@functools.lru_cache(None) +def has_torchvision_roi_align() -> bool: + try: + from torchvision.ops import roi_align # noqa: F401 + + torch._C._dispatch_has_kernel_for_dispatch_key("torchvision::nms", "Meta") + return roi_align is not None and hasattr( + getattr(torch.ops, "torchvision", None), "roi_align" + ) + except ImportError: + return False + except RuntimeError as e: + assert "torchvision::nms does not exist" in str(e) + return False + + +def decode_device(device: Union[Optional[torch.device], str]) -> torch.device: + if device is None: + return torch.tensor(0.0).device # default device + if isinstance(device, str): + device = torch.device(device) + if device.type not in ("cpu", "meta") and device.index is None: + device_interface = get_interface_for_device(device.type) + return torch.device(device.type, index=device_interface.Worker.current_device()) + return device + + +def sympy_product(it): + return functools.reduce(operator.mul, it, sympy.Integer(1)) + + +def sympy_dot(seq1, seq2): + assert len(seq1) == len(seq2) + return sympy.expand(sum(a * b for a, b in zip(seq1, seq2))) + + +def unique(it: Iterable[_T]) -> ValuesView[_T]: + return {id(x): x for x in it}.values() + + +def ceildiv( + numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr] +) -> Union[int, sympy.Expr]: + if isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr): + return CeilDiv(sympy.sympify(numer), sympy.sympify(denom)) + # TODO: There is a bug in a call to this function, to repro: + # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy + # --amp --only YituTechConvBert --dynamic-shapes + assert isinstance(numer, int) and isinstance( + denom, int + ), f"{numer}: {type(numer)}, {denom}: {type(denom)}" + return runtime_ceildiv(numer, denom) + + +def _type_of(key): + # Use the function here to get rid of dependencies on the Triton during the codegen. + # Refer to Triton implementation here: + # https://github.com/openai/triton/blob/98b5945d2aef679e00ebca8e07c35c3658ec76de/python/triton/runtime/jit.py#L238 + # `None` is nullptr. Implicitly convert to *i8. + if key is None: + return "*i8" + dtype_str = str(key).split(".")[-1] + tys = { + "bool": "i1", + "float8e4nv": "fp8e4nv", + "float8e5": "fp8e5", + "float8e4b15": "fp8e4b15", + "float8e4b15x4": "fp8e4b15x4", + "float8_e4m3fn": "fp8e4nv", + "float8_e5m2": "fp8e5", + "float16": "fp16", + "bfloat16": "bf16", + "float32": "fp32", + "float64": "fp64", + "int8": "i8", + "int16": "i16", + "int32": "i32", + "int64": "i64", + "uint8": "u8", + "uint16": "u16", + "uint32": "u32", + "uint64": "u64", + } + # reinterpret can create triton type + for v in list(tys.values()): + tys[v] = v + return key if isinstance(key, str) else f"*{tys[dtype_str]}" + + +def convert_shape_to_inductor( + lst: Iterable[Union[int, torch.SymInt]] +) -> List[sympy.Expr]: + """ + Gets the shape and stride of a tensor. For non-symbolic tensors, this is + trivial. 
But for symbolic tensors, we need to map from SymIntNode into + sympy.Expr. + """ + return [sympy.sympify(i) for i in lst] + + +def convert_shape_to_symint( + lst: Iterable[Union[int, sympy.Expr]] +) -> List[Union[int, torch.SymInt]]: + """ + Takes a list of shapes from Inductor and converts them into symints (or just + ints if all shapes are static). + """ + from .virtualized import V + + return [ + i + if isinstance(i, int) + else int(i) + if isinstance(i, sympy.Integer) + else V.graph.sizevars.shape_env.create_symintnode(i, hint=None) + for i in lst + ] + + +def is_view(op: torch._ops.OpOverload): + """ + Does this op overload have aliasing + """ + assert isinstance(op, torch._ops.OpOverload) + return any(a.alias_info is not None for a in op._schema.arguments) + + +def is_pointwise_use( + use, is_pointwise_fn: Optional[Callable[[torch._ops.OpOverload], bool]] = None +): + """ + Do all uses of this op have torch.Tag.pointwise or return True for optional `is_pointwise_fn` + + Uses in views ops will follow the views uses + """ + + if not use.op == "call_function": + return False + + if not ( + isinstance(use.target, torch._ops.OpOverload) or use.target is operator.getitem + ): + return False + + if use.target is operator.getitem or is_view(use.target): + return all(is_pointwise_use(u, is_pointwise_fn) for u in use.users) + + return torch.Tag.pointwise in use.target.tags or ( + is_pointwise_fn is not None and is_pointwise_fn(use.target) + ) + + +def gen_gm_and_inputs(target, args, kwargs): + g = torch.fx.Graph() + g_args = [] + a_args = [] + for n, arg in enumerate(args): + if isinstance(arg, torch.Tensor): + g_args.append(g.placeholder(f"arg{n}")) + a_args.append(arg) + else: + g_args.append(arg) + assert all(not isinstance(x, torch.Tensor) for x in kwargs.values()) + node = g.call_function(target, tuple(g_args), kwargs) + if ( + len(target._schema.returns) == 1 + and str(target._schema.returns[0].type) == "Tensor" + ): + node = (node,) # type: ignore[assignment] + g.output(node) + + gm = torch.fx.GraphModule({}, g) + return gm, a_args + + +def synchronize(device: str = "cuda"): + if device == "cpu": + return + device_interface = get_interface_for_device(device) + if device_interface.is_available(): + device_interface.synchronize() + + +def timed( + model: Callable[..., Any], example_inputs, times: int = 1, device: str = "cuda" +) -> float: + synchronize(device) + torch.manual_seed(1337) + t0 = time.perf_counter() + for _ in range(times): + result = model(*example_inputs) + synchronize(device) + t1 = time.perf_counter() + # GC the result after timing + assert result is not None # type: ignore[possibly-undefined] + return t1 - t0 + + +def print_performance( + fn, args=(), times=10, repeat=10, baseline=1.0, device: str = "cuda" +): + timings = torch.tensor([timed(fn, args, times, device) for _ in range(repeat)]) + took = torch.median(timings) / times + print(f"{took / baseline:.6f}") + return took + + +def precompute_method(obj: Any, method: str): + """Replace obj.method() with a new method that returns a precomputed constant.""" + result = getattr(obj, method)() + setattr(obj, method, lambda: result) + + +def precompute_methods(obj: Any, methods: List[str]): + """Replace methods with new methods that returns a precomputed constants.""" + for method in methods: + precompute_method(obj, method) + + +def cmp(a, b) -> int: + return int(a > b) - int(a < b) + + +def pad_listlike(x, size): + if len(x) == 1: + return type(x)([x[0]]) * size + else: + return x + + +# Used to ensure that iterating 
over a set is deterministic +def tuple_sorted(x): + if len(x) == 0: + return [] + + def sort_func(elem): + if isinstance(elem, str): + return elem + else: + # We expect `elem` to be `scheduler.BaseSchedulerNode` type here, + # but we are not able to do isinstance assert because of circular dependency + return elem.get_name() + + return sorted(x, key=sort_func) + + +P = ParamSpec("P") +RV = TypeVar("RV", covariant=True) + + +class CachedMethod(Protocol, Generic[P, RV]): + @staticmethod + def clear_cache(self) -> None: + ... + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: + ... + + +# See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature +def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]: + key = f"__{fn.__name__}_cache" + + @functools.wraps(fn) + def wrapper(self): + if not hasattr(self, key): + setattr(self, key, fn(self)) + return getattr(self, key) + + def clear_cache(self): + if hasattr(self, key): + delattr(self, key) + + wrapper.clear_cache = clear_cache # type: ignore[attr-defined] + return wrapper # type: ignore[return-value] + + +def aggregate_origins(node_schedule): + from . import ir + + if isinstance(node_schedule, list): + return functools.reduce( + operator.or_, + [ + node.node.origins + for node in node_schedule + if hasattr(node, "node") and node.node + ], + set(), + ) + elif isinstance(node_schedule, ir.ExternKernel): + return node_schedule.origins + else: + return set() + + +def get_fused_kernel_name(node_schedule, descriptive_names): + all_origins = aggregate_origins(node_schedule) + if descriptive_names == "original_aten": + # Bases the kernel name off of the top-level aten operator (i.e. pre-decompositions) + sources = [ + origin.meta["original_aten"]._overloadpacket.__name__ + for origin in all_origins + if origin.op == "call_function" + and "original_aten" in origin.meta + and origin.meta["original_aten"] is not None + ] + sources = sorted(set(sources)) + elif descriptive_names == "torch": + # Bases the kernel name off of the top-level "torch" operator (i.e. post-dynamo graph) + sources = [] + for origin in all_origins: + if origin.op == "call_function" and "source_fn_stack" in origin.meta: + source_fn = origin.meta["source_fn_stack"][-1] + if isinstance(source_fn[1], str): + sources.append(source_fn[1]) + else: + sources.append(source_fn[1].__name__) + sources = sorted(set(sources)) + elif descriptive_names == "inductor_node": + sources = [ + origin.name for origin in all_origins if origin.op == "call_function" + ] + else: + raise NotImplementedError + sources = sources + return "_".join(["fused"] + sources) + + +def get_kernel_metadata(node_schedule, wrapper): + all_origins = aggregate_origins(node_schedule) + inductor_nodes = [origin for origin in all_origins if origin.op == "call_function"] + + from_node_dict = collections.defaultdict(list) + original_aten_dict = collections.defaultdict(list) + + # Attempt to sort `inductor_nodes` topologically. Note that the case + # where `inductor_nodes` contains nodes from multiple graph instances + # is not supported. An example of this is conditional statements. 
+ single_graph = None + if len(inductor_nodes): + unique_graphs = {n.graph for n in inductor_nodes} + if len(unique_graphs) == 1: + single_graph = inductor_nodes[0].graph + # create a map of idx -> node and cache it + if not hasattr(single_graph, "_inductor_kernel_metadata_node_to_idx_map"): + node_to_idx_map = {} + for idx, n in enumerate(single_graph.nodes): + node_to_idx_map[n] = idx + single_graph._inductor_kernel_metadata_node_to_idx_map = node_to_idx_map + inductor_nodes.sort( + key=lambda n: single_graph._inductor_kernel_metadata_node_to_idx_map[n] + ) + + for node in inductor_nodes: + if "original_aten" in node.meta and node.meta["original_aten"] is not None: + key = str(node.meta["original_aten"]._overloadpacket) + original_aten_dict[key].append(node.name) + if "from_node" in node.meta: + key = node.meta["from_node"][0][0] + from_node_dict[key].append(node.name) + sort_str = "Topologically Sorted" if single_graph is not None else "Unsorted" + metadata = ( + f"{wrapper.comment} {sort_str} Source Nodes: [{', '.join(from_node_dict.keys())}], " + f"Original ATen: [{', '.join(original_aten_dict.keys())}]" + ) + + # trace back to original node here + detailed_metadata = [f"{wrapper.comment} Source node to ATen node mapping:"] + for original_node, nodes in sorted(from_node_dict.items()): + detailed_metadata.append( + f"{wrapper.comment} {original_node} => {', '.join(sorted(nodes))}" + ) + + # print the aot_autograd graph fragment + if single_graph is not None: + detailed_metadata.append(f"{wrapper.comment} Graph fragment:") + for n in inductor_nodes: + # TODO(future): maybe refactor torch/fx/graph.py to make it easy to + # generate python code for graph fragments + detailed_metadata.append(f"{wrapper.comment} {n.format_node()}") + + return metadata, "\n".join(detailed_metadata) + + +def dominated_nodes( + initial_queue: Iterable[torch.fx.Node], skip_filter=None +) -> Set[torch.fx.Node]: + """Returns the set of nodes whose values depend on those within initial_queue""" + initial_queue = list(initial_queue) + dominated_set = set(initial_queue) + + while initial_queue: + node = initial_queue.pop() + for user in node.users: + if skip_filter and skip_filter(user): + continue + if user not in dominated_set: + dominated_set.add(user) + initial_queue.append(user) + + return dominated_set + + +def gather_origins(args, kwargs): + import itertools + + from . import ir + + def is_unrealized_node(n): + if isinstance(n, ir.TensorBox): + return is_unrealized_node(n.data) + if isinstance(n, ir.StorageBox): + return is_unrealized_node(n.data) + return isinstance(n, ir.IRNode) and isinstance(n, ir.Pointwise) + + kwarg_origins = [val.origins for val in kwargs.values() if is_unrealized_node(val)] + arg_origins = [arg.origins for arg in args if is_unrealized_node(arg)] + return set(itertools.chain(*arg_origins, *kwarg_origins)) + + +def sympy_str(expr: sympy.Expr) -> str: + """ + Normal sympy str is very slow, this is a lot faster. The result are + somewhat worse, as it doesn't do as much simplification. So don't + use this for final codegen. 
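+ + Illustrative example (assuming integer symbols s0 and s1): sympy_str(s0 + 2*s1) + yields something like 's0 + 2 * s1', and ModularIndexing/FloorDiv terms print as + 'FuncName(arg0, arg1, ...)'.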
+ """ + if isinstance(expr, sympy.Symbol): + return expr.name + if isinstance(expr, sympy.Add): + return " + ".join(map(sympy_str, expr.args)) + if isinstance(expr, sympy.Mul): + return " * ".join(map(sympy_str, expr.args)) + + if isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv, Identity)): + return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})" + return str(expr) + + +def get_bounds_index_expr(index): + from .virtualized import V + + # If this expression does not come from an FX node, we compute its bounds + if ( + config.compute_all_bounds + and (fx_node := getattr(V.interpreter, "current_node", None)) + and fx_node.target != "index_expr" + ): + return bound_sympy(index) + else: + return ValueRanges.unknown() + + +def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol: + """ + Used to generate an integer-nonnegative symbol. + """ + # This should never be used for creating shape/stride symbols, as those + # should all be allocated before Inductor. + assert prefix != SymT.SIZE + # NOTE: shape symbols are positive (> 0), but index variables are only + # non-negative (>= 0). + return make_symbol(prefix, idx, integer=True, nonnegative=True) + + +def generate_assert(check): + return (check or config.debug_index_asserts) and config.assert_indirect_indexing + + +def sympy_index_symbol(name: str) -> sympy.Symbol: + """ + Used to generate an integer-nonnegative symbol. + """ + # This should never be used for creating shape/stride symbols, as those + # should all be allocated before Inductor. + assert name[0] != "s" + # NOTE: shape symbols are positive (> 0), but index variables are only + # non-negative (>= 0). + return sympy.Symbol(name, integer=True, nonnegative=True) + + +def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.Expr: + """ + When the passed replacement symbol v is a string, it is converted to a symbol with name v that + have the same replaced expression integer and nonnegative properties. 
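+ + Illustrative example: sympy_subs(x + y, {x: "z"}) substitutes for x a fresh symbol + named "z" carrying x's integer/nonnegative assumptions, whereas + sympy_subs(x + y, {x: sympy.Integer(0)}) is a plain substitution yielding y.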
+ """ + + def to_symbol(replaced, replacement): + assert isinstance(replaced, sympy.Expr) + if isinstance(replacement, str): + return sympy.Symbol( + replacement, + integer=replaced.is_integer, # type: ignore[attr-defined] + nonnegative=replaced.is_nonnegative, # type: ignore[attr-defined] + ) + else: + return replacement + + # xreplace is faster than subs, but is way more picky + return sympy.sympify(expr).xreplace( + {k: to_symbol(k, v) for k, v in replacements.items()} + ) + + +def is_symbolic(a: Any) -> bool: + return isinstance(a, torch.SymInt) or ( + isinstance(a, torch.Tensor) + and any(is_symbolic(x) for x in itertools.chain(a.size(), a.stride())) + ) + + +def any_is_symbolic(*args: Any) -> bool: + return any(is_symbolic(a) for a in args) + + +def get_first_incompatible_cudagraph_node(gm): + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + forbidden_set = { + "aten._fused_moving_avg_obs_fq_helper.default", + "aten._fused_moving_avg_obs_fq_helper_functional.default", + "aten.multinomial.default", + "fbgemm.dense_to_jagged.default", + "fbgemm.jagged_to_padded_dense.default", + "run_and_save_rng_state", + "run_with_rng_state", + "aten._local_scalar_dense", + # Technically, it's not necessary to ban this, because an + # assert_scalar with constant arguments can be validly run + # with CUDA graphs, but the operator is also pointless with + # constant arguments, so might as well ban + "aten._assert_scalar", + } + if torch.are_deterministic_algorithms_enabled(): + forbidden_set.update( + { + "aten._unsafe_index_put.default", + "aten._unsafe_masked_index_put_accumulate.default", + "aten.index_put.default", + "aten.index_put_.default", + "aten.scatter.src", + "aten.scatter.reduce", + "aten.scatter.value_reduce", + "aten.scatter_add_", + "aten.scatter_add.default", + "aten.scatter_reduce.two", + "aten.scatter_reduce_.two", + "aten.scatter_reduce.two_out", + } + ) + for node in gm.graph.nodes: + if str(node.target) in forbidden_set: + return node + if (val := node.meta.get("val")) is not None and free_unbacked_symbols(val): + return node + return None + + +def has_incompatible_cudagraph_ops(gm): + return get_first_incompatible_cudagraph_node(gm) is not None + + +def output_node(gm: torch.fx.GraphModule): + """Get the output node from an FX graph""" + last_node = next(iter(reversed(gm.graph.nodes))) + assert last_node.op == "output" + return last_node + + +_registered_caches: List[Any] = [] + + +def clear_on_fresh_inductor_cache(obj: Any): + """ + Use this decorator to register any caches that should be cache_clear'd + with fresh_inductor_cache(). + """ + if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear): + raise AttributeError(f"{obj} does not have a cache_clear method") + + _registered_caches.append(obj) + return obj + + +def clear_inductor_caches(): + """ + Clear all registered caches. + """ + for obj in _registered_caches: + obj.cache_clear() + + +@contextlib.contextmanager +def fresh_inductor_cache(cache_entries=None, dir=None, delete=True): + """ + Contextmanager that provides a clean tmp cachedir for inductor. + + Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes + generated with this cache instance. 
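+ + Illustrative usage (compiled_fn and inputs are hypothetical): + + cache_info: dict = {} + with fresh_inductor_cache(cache_entries=cache_info): + compiled_fn(*inputs) + # cache_info now maps generated Triton cache file names to their sizes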
+ """ + clear_inductor_caches() + + inductor_cache_dir = tempfile.mkdtemp(dir=dir) + try: + with mock.patch.dict( + os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir} + ): + log.debug("Using inductor cache dir %s", inductor_cache_dir) + triton_cache_dir = os.path.join(inductor_cache_dir, "triton") + with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_cache_dir}): + yield + if isinstance(cache_entries, dict): + assert len(cache_entries) == 0, "expected empty cache_entries dict" + if os.path.exists(triton_cache_dir): + files = os.listdir(triton_cache_dir) + cache_entries.update( + { + f: os.path.getsize(os.path.join(triton_cache_dir, f)) + for f in files + if ".lock" not in f + } + ) + if delete: + shutil.rmtree(inductor_cache_dir) + except Exception: + if not _IS_WINDOWS: + """ + Windows can't delete the loaded modules, because the modules binaries are opened. + TODO: discuss if have better solution to handle this issue. + """ + log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir) + raise + finally: + clear_inductor_caches() + + +def argsort(seq) -> List[int]: + # preserve original order for equal strides + getter = seq.__getitem__ + a_r = range(len(seq)) + return list(reversed(sorted(a_r, key=getter, reverse=True))) # noqa: C413 + + +@functools.lru_cache(8) +def get_dtype_size(dtype): + return torch.empty((), dtype=dtype).element_size() + + +class LineContext(NamedTuple): + context: Any + + +class IndentedBuffer: + tabwidth = 4 + + def __init__(self, initial_indent=0): + self._lines = [] + self._indent = initial_indent + + def getvaluewithlinemap(self) -> tuple[str, list[tuple[int, LineContext]]]: + buf = StringIO() + p = 1 + linemap = [] + for line in self._lines: + if isinstance(line, DeferredLineBase): + line = line() + if line is None: + continue + elif isinstance(line, LineContext): + linemap.append((p, line.context)) + continue + assert isinstance(line, str) + buf.write(line) + buf.write("\n") + p += 1 + line.count("\n") + return buf.getvalue(), linemap + + def getvalue(self) -> str: + v, _ = self.getvaluewithlinemap() + return v + + def getrawvalue(self) -> str: + buf = StringIO() + for line in self._lines: + if isinstance(line, DeferredLineBase): + line = line() + if line is None: + continue + elif isinstance(line, LineContext): + continue + assert isinstance(line, str) + # backslash implies line continuation + if line.endswith("\\"): + buf.write(line[:-1]) + else: + buf.write(line) + buf.write("\n") + return buf.getvalue() + + def clear(self): + self._lines.clear() + + def __bool__(self): + return bool(self._lines) + + def prefix(self): + return " " * (self._indent * self.tabwidth) + + def newline(self): + self.writeline("\n") + + def writeline(self, line): + if isinstance(line, LineContext): + self._lines.append(line) + elif isinstance(line, DeferredLineBase): + self._lines.append(line.with_prefix(self.prefix())) + elif line.strip(): + self._lines.append(f"{self.prefix()}{line}") + else: + self._lines.append("") + + def writelines(self, lines): + for line in lines: + self.writeline(line) + + def indent(self, offset=1): + @contextlib.contextmanager + def ctx(): + self._indent += offset + try: + yield + finally: + self._indent -= offset + + return ctx() + + def do_indent(self, offset=1): + self._indent += offset + + def do_unindent(self, offset=1): + self._indent -= offset + + def splice(self, other_code, strip=False): + if isinstance(other_code, IndentedBuffer): + dedent = float("inf") + for line in other_code._lines: + if not 
isinstance(line, LineContext) and line: + dedent = min(dedent, len(line) - len(line.lstrip())) + if math.isinf(dedent): + dedent = 0 + for line in other_code._lines: + if isinstance(line, LineContext): + self._lines.append(line) + else: + IndentedBuffer.writeline(self, line[int(dedent) :]) + else: + other_code = textwrap.dedent(other_code) + if strip: + other_code = other_code.lstrip() + if not other_code: + return + other_code = other_code.rstrip() + for line in other_code.split("\n"): + self.writeline(line) + + def map(self, func: Callable[[Any], Any]) -> IndentedBuffer: + res = IndentedBuffer(initial_indent=self._indent) + res._lines = [func(line) for line in self._lines] + return res + + def __repr__(self): + return f"{type(self)}({self.getvalue()})" + + def __add__(self, other): + assert self._indent == other._indent + res = IndentedBuffer(initial_indent=self._indent) + res.writelines(self._lines) + res.writelines(other._lines) + return res + + +class FakeIndentedBuffer(IndentedBuffer): + def __init__(self) -> None: + super().__init__() + + def __getattribute__(self, name): + if name == "__class__": # Allow access to the class attribute + return object.__getattribute__(self, name) + raise RuntimeError( + f"Tried to call self.{name} on FakeIndentedBuffer. This buffer" + "is currently used on TritonTemplateKernel to prevent actual" + "writes to the body without explicitly specifying the body with" + "`TritonTemplateKernel.set_subgraph_body(name)`" + ) + + +@contextlib.contextmanager +def restore_stdout_stderr(initial_stdout, initial_stderr): + try: + yield + finally: + sys.stdout = initial_stdout + sys.stderr = initial_stderr + + +class DeferredLineBase: + """A line that can be 'unwritten' at a later time""" + + def __init__(self, line): + if not line.strip(): + line = "" + self.line = line + + def __call__(self) -> Optional[str]: + """Returns either self.line or None to indicate the line has been 'unwritten'""" + raise NotImplementedError + + def _new_line(self, line: str) -> DeferredLineBase: + """Returns a new deferred line with the same condition""" + raise NotImplementedError + + def with_prefix(self, prefix): + return self._new_line(f"{prefix}{self.line}") + + def lstrip(self): + return self._new_line(self.line.lstrip()) + + def __getitem__(self, index): + return self._new_line(self.line[index]) + + def __bool__(self): + return bool(self.line) + + def __len__(self): + return len(self.line) + + +@functools.lru_cache(None) +def is_big_gpu(index) -> bool: + min_sms = 68 # 3080 + avail_sms = torch.cuda.get_device_properties(index).multi_processor_count + if avail_sms < min_sms: + log.warning( + "Not enough SMs to use max_autotune_gemm mode", + extra={"min_sms": min_sms, "avail_sms": avail_sms}, + ) + return False + return True + + +def use_max_autotune() -> bool: + return config.max_autotune or config.max_autotune_gemm + + +def _use_template_for_cuda(layout, allowed_layout_dtypes: List[torch.dtype]) -> bool: + return ( + use_max_autotune() + and layout.device.type == "cuda" + and layout.dtype in allowed_layout_dtypes + and is_big_gpu(layout.device.index or 0) + ) + + +def _use_autotune_backend(backend: str) -> bool: + return backend.upper() in [ + x.strip() for x in config.max_autotune_gemm_backends.upper().split(",") + ] + + +def _use_conv_autotune_backend(backend: str) -> bool: + return backend.upper() in [ + x.strip() for x in config.max_autotune_conv_backends.upper().split(",") + ] + + +def use_triton_template(layout, *, enable_int32=False, enable_float8=False): + from 
.codegen.common import BackendFeature, has_backend_feature + + layout_dtypes = [torch.float16, torch.bfloat16, torch.float32] + if enable_int32: + layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32] + if enable_float8: + layout_dtypes.extend([torch.float8_e4m3fn, torch.float8_e5m2]) + return ( + _use_template_for_cuda(layout, layout_dtypes) + and _use_autotune_backend("TRITON") + and has_backend_feature(layout.device, BackendFeature.TRITON_TEMPLATES) + ) + + +def use_cutlass_template(layout, m, n, k): + from .virtualized import V + + gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1) + if gemm_size <= 0 or gemm_size < config.cuda.cutlass_backend_min_gemm_size: + return False + from .codegen.cuda.cutlass_utils import try_import_cutlass + + # Do not use cutlass template on ROCm + if torch.version.hip: + return False + + layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32] + res = _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend( + "CUTLASS" + ) + + if res: + if not try_import_cutlass(): + log.warning( + "Failed to import CUTLASS lib. Please check whether " + "_inductor.config.cuda.cutlass_dir is set correctly. " + "Skipping CUTLASS backend for now." + ) + return False + return res + + +@functools.lru_cache(None) +def _rocm_native_device_arch_name(device): + return torch.cuda.get_device_properties(device).gcnArchName + + +@functools.lru_cache(None) +def try_import_ck_lib(): + try: + import ck4inductor # type: ignore[import] + from ck4inductor.universal_gemm.gen_instances import ( # type: ignore[import] + gen_ops_library, + gen_ops_preselected, + ) + from ck4inductor.universal_gemm.op import ( # type: ignore[import] + CKGemmOperation, + ) + + package_dirname = os.path.dirname(ck4inductor.__file__) + except ImportError: + + def gen_ops_library(): + return [] + + def gen_ops_preselected(): + return [] + + class CKGemmOperation: # type: ignore[no-redef] + pass + + package_dirname = None + return package_dirname, gen_ops_library, gen_ops_preselected, CKGemmOperation + + +def use_ck_template(layout, m, n, k): + # config knobs check 1 + if not use_max_autotune(): + return False + # config knobs check 2 + if not _use_autotune_backend("CK"): + return False + # platform check + if not torch.version.hip: + return False + # tensors must be on GPU + if not layout.device.type == "cuda": + return False + # hardware check + # if config arch list is not specified, get the native arch from the device properties + native_arch = _rocm_native_device_arch_name(layout.device) + requested_archs = {k.split(":")[0]: k for k in config.rocm.arch} or { + native_arch.split(":")[0]: native_arch + } + requested_supported_archs = [ + requested_archs[k] + for k in requested_archs.keys() & config.rocm.ck_supported_arch + ] + if not requested_supported_archs: + return False + # supported input dtypes + if layout.dtype not in [torch.float16, torch.bfloat16]: + return False + # TBD: investigate if we need to disable backend based on number of available CUs similar to `is_big_gpu` + # check if shape is static and gemm size is not 0 + from .virtualized import V + + gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1) + if gemm_size <= 0: + return False + # TBD: investigate if backend needs to be disabled for small gemms similar to CUTLASS + + ck_package_dirname, _, _, _ = try_import_ck_lib() + + if not ck_package_dirname: + log.warning("Please pip install Composable Kernel package") + return False + + if not config.rocm.ck_dir: + 
log.warning("Please set TORCHINDUCTOR_CK_DIR env variable") + return False + + if ck_package_dirname != config.rocm.ck_dir: + log.warning("Invalid path to CK library") + return False + + return True + + +def _use_template_for_cpu(layout): + return use_max_autotune() and layout.device.type == "cpu" + + +def use_cpp_packed_gemm_template(layout, mat1, mat2, mat2_transposed=False): + from . import ir + from .codegen.cpp_micro_gemm import create_micro_gemm + from .codegen.cpp_utils import get_gemm_template_output_and_compute_dtype + from .kernel.mm_common import mm_args + + if not _use_template_for_cpu(layout) or not _use_autotune_backend("CPP"): + return False + + if not config.cpp.weight_prepack: + return False + + int8_gemm = mat1.get_dtype() == torch.uint8 + layout_dtypes = [torch.float32, torch.bfloat16, torch.half, torch.uint8] + m, n, k, layout, mat1, mat2 = mm_args( + mat1, + mat2, + out_dtype=layout.dtype if int8_gemm else None, + mat2_transposed=mat2_transposed, + ) + + # TODO(jgong5): support dynamic shapes for n or k + if has_free_symbols((n, k)): + return False + if isinstance(mat2, ir.BaseView): + mat2 = mat2.unwrap_view() + + output_dtype, _ = get_gemm_template_output_and_compute_dtype(mat1.get_dtype()) + micro_gemm = create_micro_gemm( + "micro_gemm", + m, + n, + k, + input_dtype=mat1.get_dtype(), + input2_dtype=mat2.get_dtype(), + output_dtype=output_dtype, + num_threads=parallel_num_threads(), + ) + + def is_last_dim_stride1(x): + x.freeze_layout() + return x.get_stride()[-1] == 1 + + return ( + layout.dtype in layout_dtypes + and micro_gemm is not None + and is_last_dim_stride1(mat1) # TODO(jgong5): support transposed input + and isinstance(mat2, ir.StorageBox) + and mat2.is_module_buffer() + ) + + +def use_aten_gemm_kernels(): + return not use_max_autotune() or _use_autotune_backend("ATEN") + + +class DebugDirManager: + counter = itertools.count(0) + prev_debug_name: str + + def __init__(self) -> None: + self.id = next(DebugDirManager.counter) + + def __enter__(self): + self.prev_debug_name = torch._dynamo.config.debug_dir_root + self.new_name = f"{self.prev_debug_name}_tmp_{self.id}" + torch._dynamo.config.debug_dir_root = self.new_name + + def __exit__(self, *args): + shutil.rmtree(self.new_name) + torch._dynamo.config.debug_dir_root = self.prev_debug_name + + +def run_and_get_code(fn, *args, **kwargs): + from .graph import GraphLowering + + source_codes: List[str] = [] + + def save_output_code(code: str): + source_codes.append(code) + + with mock.patch.object(GraphLowering, "save_output_code", save_output_code): + torch._dynamo.reset() + result = fn(*args, **kwargs) + return result, source_codes + + +def run_fw_bw_and_get_code(fn): + def run_with_backward(): + result = fn() + result.sum().backward() + return result + + return run_and_get_code(run_with_backward) + + +def get_code(fn, *args, **kwargs): + """Get the inductor-generated code, but skip any actual compilation or running.""" + from .graph import GraphLowering + + source_codes: List[str] = [] + + def save_output_code(code: str): + source_codes.append(code) + + def patched_compile_to_module(self: GraphLowering): + class DummyModule: + """This is empty to replace the generated triton module""" + + def __init__(self) -> None: + pass + + def call(self, *args, **kwargs): + # Don't do anything when called + pass + + code, _ = ( + self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() + ) + # Skip all the actual compiling. 
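+ # Note: returning DummyModule below means nothing is compiled or executed; + # only the generated source is captured via save_output_code.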
+ nonlocal save_output_code + save_output_code(code) + + return DummyModule() + + with mock.patch.object( + GraphLowering, "compile_to_module", patched_compile_to_module + ), mock.patch.object(GraphLowering, "save_output_code", save_output_code): + torch._dynamo.reset() + # Note the return here is None + _ = fn(*args, **kwargs) + + return source_codes + + +def get_triton_code(fn, *args, **kwargs): + source_codes = get_code(fn, *args, **kwargs) + # Can have two outputs if backwards was eagerly compiled + assert ( + 1 <= len(source_codes) <= 2 + ), f"expected one or two code outputs got {len(source_codes)}" + return source_codes[0] + + +def run_and_get_triton_code(fn, *args, **kwargs): + _, source_codes = run_and_get_code(fn, *args, **kwargs) + # Can have two outputs if backwards was eagerly compiled + assert ( + 1 <= len(source_codes) <= 2 + ), f"expected one or two code outputs got {len(source_codes)}" + return source_codes[0] + + +def run_and_get_graph_lowering(fn, *args, **kwargs): + from torch._inductor.codecache import CompiledFxGraph + from torch._inductor.graph import GraphLowering + + real_init = CompiledFxGraph.__init__ + graph_lowerings = [] + + def fake_init(*args, **kwargs): + real_init(*args, **kwargs) + graph = args[2] + assert isinstance(graph, GraphLowering) + graph_lowerings.append(graph) + + with mock.patch.object(CompiledFxGraph, "__init__", fake_init): + result = fn(*args, **kwargs) + + return result, graph_lowerings + + +@contextlib.contextmanager +def override_lowering(aten_op, override_fn): + """ + Override the lowering of aten_op with override_fn. + The first argument of override_fn is the original lowering fn. + """ + from torch._inductor import lowering + + orig_fn = lowering.lowerings[aten_op] + try: + lowering.lowerings[aten_op] = functools.partial(override_fn, orig_fn) + yield + finally: + lowering.lowerings[aten_op] = orig_fn + + +def add_scheduler_init_hook(pre_fn, post_fn=None): + """ + Add hook functions to be called at the beginning and end of Scheduler.__init__. + Used for unit tests. + """ + from torch._inductor.scheduler import Scheduler + + orig_fn = Scheduler.__init__ + + def wrapper(scheduler, nodes): + pre_fn(scheduler, nodes) + out = orig_fn(scheduler, nodes) + if post_fn: + post_fn(scheduler, nodes) + return out + + return unittest.mock.patch.object(Scheduler, "__init__", wrapper) + + +def developer_warning(msg): + """ + Warnings that will be actionable for PyTorch developers, but not + end users. Allows us to easily disable them in stable releases but + keep them on for nightly builds. + """ + if config.developer_warnings: + log.warning(msg) + else: + log.info(msg) + + +def get_benchmark_name(): + """ + An experimental API used only when config.benchmark_kernel is true. + + The benchmark name is only available at codegen time. So we can not + directly call it in benchmark_all_kernels which is run after codegen. + + The function assumes the argument after --only is the benchmark name. + It works for torchbench.py/hugginface.py/timm_models.py. But for ad-hoc + scripts, this function may return None. + + There are 2 flavors of --only argument we need handle: + 1. --only model_name + 2. 
--only=model_name + """ + try: + idx = sys.argv.index("--only") + if ( + idx + 1 < len(sys.argv) + and len(sys.argv[idx + 1]) > 0 + and sys.argv[idx + 1][0] != "-" + ): + return sys.argv[idx + 1] + except ValueError: + pass + + for arg in sys.argv: + if arg.startswith("--only="): + return arg[len("--only=") :] + + +def is_ones(items): + return all(x == 1 for x in items) + + +def is_zeros(items): + return all(x == 0 for x in items) + + +def is_cpu_device(inputs): + return all( + item.device == torch.device("cpu") + for item in inputs + if isinstance(item, torch.Tensor) + ) + + +def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype: + assert isinstance( + val, sympy.Expr + ), "only support sympy.Expr as input to get_sympy_Expr_dtype" + if val.is_integer: # type: ignore[attr-defined] + return torch.int64 + else: + return torch.float64 + + +@contextlib.contextmanager +def maybe_profile(should_profile, *args, **kwargs): + if should_profile: + with torch.profiler.profile(*args, **kwargs) as p: + yield p + else: + yield + + +def parallel_num_threads(): + threads = config.cpp.threads + if threads < 1: + threads = torch.get_num_threads() + return threads + + +@functools.lru_cache(None) +def get_device_tflops(dtype): + from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops + + assert dtype in (torch.float16, torch.bfloat16, torch.float32) + + if inspect.signature(get_max_simd_tflops).parameters.get("clock_rate"): + # Triton API change in https://github.com/openai/triton/pull/2293 + from torch._utils_internal import max_clock_rate + + sm_clock = max_clock_rate() + if dtype in (torch.float16, torch.bfloat16): + return get_max_tensorcore_tflops(dtype, sm_clock) + + if torch.backends.cuda.matmul.allow_tf32: + return get_max_tensorcore_tflops(torch.float32, sm_clock) + else: + return get_max_simd_tflops(torch.float32, sm_clock) + else: + if dtype in (torch.float16, torch.bfloat16): + return get_max_tensorcore_tflops(dtype) + + if torch.backends.cuda.matmul.allow_tf32: + return get_max_tensorcore_tflops(torch.float32) + else: + return get_max_simd_tflops(torch.float32) + + +@functools.lru_cache(None) +def get_gpu_dram_gbps(): + from triton.testing import get_dram_gbps + + return get_dram_gbps() + + +def get_gpu_shared_memory(): + from triton.runtime import driver + + return driver.active.utils.get_device_properties(0).get("max_shared_mem", 0) + + +def is_welford_reduction(reduction_type): + return reduction_type.startswith("welford") + + +def reduction_num_outputs(reduction_type): + return 3 if is_welford_reduction(reduction_type) else 1 + + +def is_linux() -> bool: + return platform.system() == "Linux" + + +def is_windows(): + return sys.platform == "win32" + + +def has_free_symbols(itr: Iterable[Any]): + return any(isinstance(x, sympy.Expr) and not x.is_number for x in itr) + + +def is_dynamic(*args): + from . import ir + + for t in args: + if isinstance(t, ir.TensorBox): + if has_free_symbols(t.data.get_size()) or ( + hasattr(t.data, "get_stride") and has_free_symbols(t.data.get_stride()) + ): + return True + elif isinstance(t, (ir.StorageBox, ir.BaseView, ir.ComputedBuffer)): + assert hasattr(t, "get_size") and hasattr(t, "get_stride") + if has_free_symbols(t.get_size()) or has_free_symbols(t.get_stride()): + return True + elif not isinstance(t, ir.IRNode): + continue + else: + raise TypeError(f"unexpected type for is_dynamic {type(t)}") + + return False + + +# Placeholder strings used in triton codegen. 
+class Placeholder(enum.Enum): + # The placeholder for the actual name of a triton kernel. + # e.g. for "def triton_" it would be "triton_" + KERNEL_NAME = "KERNEL_NAME" + + # The descriptive name of the triton kernel; when unique_kernel_names = False, this + # placeholder will be replaced with a string with more information. + DESCRIPTIVE_NAME = "DESCRIPTIVE_NAME" + + +def pass_execution_and_save(func, gm, inp, msg): + from .pattern_matcher import stable_topological_sort + + with tempfile.NamedTemporaryFile( + mode="w", + encoding="utf-8", + delete=False, + ) as f: + before_io = io.StringIO() + after_io = io.StringIO() + ShapeProp(gm=gm, fake_mode=detect_fake_mode(inp)).propagate(*inp) + print(f"Before:\n{gm.graph}", file=f) + print(gm.graph, file=before_io) + start_time = datetime.now() + with GraphTransformObserver(gm, msg, config.trace.log_url_for_graph_xform): + func(gm.graph) + time_elapsed = datetime.now() - start_time + # recompile graph + stable_topological_sort(gm.graph) + gm.graph.lint() + gm.recompile() + + print(f"After:\n{gm.graph}", file=f) + print(gm.graph, file=after_io) + t = before_io.getvalue() == after_io.getvalue() + log.info( + "%s, save before/after graph to %s, graph before/after are the same = %s, time elapsed = %s", + msg, + f.name, + t, + time_elapsed, + ) + + +def is_collective(node, op=None): + from . import ir + + return type(node) == ir._CollectiveKernel and (op is None or node.op_overload is op) + + +def is_wait(node): + from . import ir + + return type(node) == ir._WaitKernel + + +def contains_collective(snode): + from torch._inductor.scheduler import BaseSchedulerNode, GroupedSchedulerNode + + assert isinstance(snode, BaseSchedulerNode) + if isinstance(snode, GroupedSchedulerNode): + return any(contains_collective(x) for x in snode.snodes) + else: + return is_collective(snode.node) + + +def contains_wait(snode): + from torch._inductor.scheduler import BaseSchedulerNode, GroupedSchedulerNode + + assert isinstance(snode, BaseSchedulerNode) + if isinstance(snode, GroupedSchedulerNode): + return any(contains_wait(x) for x in snode.snodes) + else: + return is_wait(snode.node) + + +def is_fallback_op(node, op): + from . 
import ir + + if isinstance(op, torch._ops.OpOverload): + op = {op} + return isinstance(node, ir.FallbackKernel) and node.op_overload in op + + +def buf_name_to_fused_snode(buf_name, name_to_buf, name_to_fused_node): + return name_to_fused_node[name_to_buf[buf_name].defining_op.get_name()] + + +def find_recursive_deps_of_node( + snode, collected_node_set, name_to_buf, name_to_fused_node, criteria_cb=None +): + if criteria_cb and criteria_cb(snode): + return + collected_node_set.add(snode) + for dep in snode.unmet_dependencies: + defining_op_for_dep = buf_name_to_fused_snode( + dep.name, name_to_buf, name_to_fused_node + ) + if defining_op_for_dep in collected_node_set: + continue + find_recursive_deps_of_node( + defining_op_for_dep, + collected_node_set, + name_to_buf, + name_to_fused_node, + criteria_cb=criteria_cb, + ) + + +def find_recursive_users_of_node( + snode, collected_node_set, name_to_buf, name_to_fused_node, criteria_cb=None +): + if criteria_cb and criteria_cb(snode): + return + collected_node_set.add(snode) + for o in snode.get_outputs(): + for user in o.users: + assert user.node is not None + if user.node.get_name() == "OUTPUT": + continue + if user.node.get_name() not in name_to_fused_node: + continue + user_op = name_to_fused_node[user.node.get_name()] + if user_op in collected_node_set: + continue + find_recursive_users_of_node( + user_op, + collected_node_set, + name_to_buf, + name_to_fused_node, + criteria_cb=criteria_cb, + ) + + +def num_fw_fixed_arguments(dynamo_gm_num_inputs: int, aot_fw_gm_num_inputs: int): + "Computes the number of inputs to the aot fw graph which have fixed addresses (params and buffers)" + num_rng_seed_offset_inputs = ( + 2 if torch._functorch.config.functionalize_rng_ops else 0 + ) + # AOT won't lift any parameters if we're inlining NN Modules + # however desugaring subclasses will still add arguments + # resulted in extra fixed inputs https://github.com/pytorch/pytorch/issues/130502 + if ( + torch._dynamo.config.inline_inbuilt_nn_modules + and not torch._dynamo.utils.is_parameter_freezing() + ): + return 0 + + return aot_fw_gm_num_inputs - dynamo_gm_num_inputs - num_rng_seed_offset_inputs + + +def count_tangents(fx_g: torch.fx.GraphModule): + """ + Infers which inputs are static for a backwards graph + """ + + def is_saved_tensor(x): + return ( + "tangents" not in x.name + and "bwd_seed" not in x.name + and "bwd_base_offset" not in x.name + ) + + arg_count = 0 + static_arg_idxs = [] + for n in fx_g.graph.nodes: + if n.op == "placeholder": + if is_saved_tensor(n): + static_arg_idxs.append(arg_count) + arg_count += 1 + + assert static_arg_idxs == list(range(len(static_arg_idxs))) + return len(static_arg_idxs) + + +@dataclasses.dataclass +class BoxedBool: + value: bool + + def __bool__(self): + return self.value + + @staticmethod + def disable(obj): + if isinstance(obj, BoxedBool): + obj.value = False + return obj + return False + + +@contextlib.contextmanager +def collect_defined_kernels(kernel_list): + from .codegen.wrapper import WrapperCodeGen + + orig_define_kernel = WrapperCodeGen.define_kernel + + def new_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs): + nonlocal kernel_list + kernel_list.append(kernel_code) + return orig_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs) + + with unittest.mock.patch.object(WrapperCodeGen, "define_kernel", new_define_kernel): + yield + + +def get_cloned_parameter_buffer_name(name: str): + return name + "__original__" + + +def is_gpu(device: str): + assert 
isinstance(device, str) or device is None, device + return device in ["cuda", "xpu"] + + +def device_need_guard(device: str): + assert isinstance(device, str) + return is_gpu(device) + + +def needs_fallback_due_to_atomic_add_limitations(dtype): + # tl.atomic_add does NOT support the following types + return dtype in {torch.int64, torch.bool, torch.bfloat16} + + +def use_scatter_fallback( + op_overload: torch._ops.OpOverload, + reduction_type, + self_dtype, + src_dtype, + src_device_type, + src_is_tensor, +): + if ( + op_overload.overloadpacket + in (torch.ops.aten.scatter_reduce_, torch.ops.aten.scatter_reduce) + and reduction_type is None + ): + return False + + reduce_ty = ( + "add" if op_overload.overloadpacket == torch.ops.aten.scatter_ else "sum" + ) + + return ( + reduction_type not in {None, reduce_ty} + or ( + src_is_tensor + and is_gpu(src_device_type) + and needs_fallback_due_to_atomic_add_limitations(src_dtype) + ) + or ( + op_overload.overloadpacket == torch.ops.aten.scatter_reduce_ + and reduction_type == "sum" + and src_is_tensor + and src_device_type == "cpu" + and config.cpp.fallback_scatter_reduce_sum + and (config.cpp.dynamic_threads or parallel_num_threads() != 1) + ) + or (reduction_type == reduce_ty and self_dtype in {torch.bool, torch.int64}) + or torch.are_deterministic_algorithms_enabled() + ) + + +def dump_node_schedule(node_schedule): + """ + An API that can be used in pdb to dump a node_schedule. + Right mainly dump the read/write dependencies but can add more as needed. + """ + from torch._inductor.codegen.simd import DisableReduction, EnableReduction + from torch._inductor.scheduler import SchedulerNode + + print(f"Node schedule with {len(node_schedule)} nodes") + for idx, node in enumerate(node_schedule): + print(f" {idx:3}:") + if node is EnableReduction: + print("enable reduction") + elif node is DisableReduction: + print("disable reduction") + elif isinstance(node, SchedulerNode): + is_red = node.is_reduction() + print(f"{'red' if is_red else 'pw'} scheduler node") + if is_red: + assert node.node is not None + print(f"original reduction hint {node.node.data.reduction_hint}") # type: ignore[attr-defined] + print("ReadDep:") + for dep in node.read_writes.reads: + print(dep) + print("WriteDep:") + for dep in node.read_writes.writes: + print(dep) + else: + raise RuntimeError(f"Unrecognized node type: {type(node)}") + + +def tensor_is_aligned(tensor: torch.Tensor): + # See Note: [Input Alignment handling in Inductor] + # Right now, we don't try to guard on the alignment of the storage offset. + # When this comment was written, non-symbolic storage_offsets are not guarded on + # but symbolic storage_offsets are. For consistency, we suppress guard creation + # upon performing this check: that ensures that we don't add recompiles when we + # add this logic. + from torch.fx.experimental.symbolic_shapes import statically_known_true + + return statically_known_true( + (tensor.storage_offset() * get_dtype_size(tensor.dtype)) % GPU_ALIGN_BYTES == 0 + ) + + +def should_assume_input_aligned(example_input: torch.Tensor): + # See Note: [Input Alignment handling in Inductor] + + # right now, we only care about alignment for cuda tensors. + if not is_gpu(example_input.device.type): + return False + return config.assume_aligned_inputs or tensor_is_aligned(example_input) + + +def maybe_get_suppress_shape_guards_ctx(): + # Try to get TracingContext.try_get().fake_mode.shape_env.suppress_guards() + # If it's not available, return a nullcontext. 
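+    #
+    # Illustrative usage (not from the original source): callers are expected to
+    # wrap guard-sensitive size/stride manipulations, e.g.
+    #     with maybe_get_suppress_shape_guards_ctx():
+    #         ...  # symbolic shape queries that should not create new guards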
+ + # If we're dealing with cudagraphs, we might not have a tracing_context + tracing_context = torch._guards.TracingContext.try_get() + if not tracing_context: + return contextlib.nullcontext() + + # In standalone inductor compile mode, we might not have a shape_env attached to the fake mode + shape_env = tracing_context.fake_mode.shape_env + if not shape_env: + return contextlib.nullcontext() + + return shape_env.suppress_guards() + + +def run_and_get_cpp_code(fn, *args, **kwargs): + # We use the patch context manager instead of using it as a decorator. + # In this way, we can ensure that the attribute is patched and unpatched correctly + # even if this run_and_get_cpp_code function is called multiple times. + with unittest.mock.patch.object(config, "debug", True): + torch._dynamo.reset() + import io + import logging + + log_capture_string = io.StringIO() + ch = logging.StreamHandler(log_capture_string) + from torch._inductor.codecache import output_code_log + + output_code_log.addHandler(ch) + prev_level = output_code_log.level + output_code_log.setLevel(logging.DEBUG) + result = fn(*args, **kwargs) + s = log_capture_string.getvalue() + output_code_log.setLevel(prev_level) + output_code_log.removeHandler(ch) + return result, s + + +def shape_env_from_inputs(inputs: List[torch.Tensor]): + shape_env = None + fake_mode = detect_fake_mode(inputs) + + # TODO(voz): It would be nice to enable this assert, but there are lots of tests that + # pass in real inputs for now. + # if len(inputs) > 0: + # assert fake_mode is not None, breakpoint() + + if fake_mode is not None: + return fake_mode.shape_env + + # When there are no tensor inputs, get shape_env from the first SymInt. + for input in inputs: + if isinstance(input, torch.SymInt): + return input.node.shape_env + + # TODO(voz): Should we always have one anyway? + return None + + +def align_inputs_from_check_idxs( + model: Callable[[List[InputType]], Any], + inputs_to_check: Sequence[int], +) -> Callable[[List[InputType]], Any]: + if len(inputs_to_check) == 0: + return model + + def run(new_inputs: List[InputType]): + copy_misaligned_inputs(new_inputs, inputs_to_check) + return model(new_inputs) + + return run + + +def clone_preserve_strides(x: torch.Tensor): + needed_size = ( + sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1 + ) + buffer = torch.as_strided(x, (needed_size,), (1,)).clone() + return torch.as_strided(buffer, x.size(), x.stride()) + + +def copy_misaligned_inputs( + new_inputs: List[InputType], check_inputs_idxs: Sequence[int] +) -> None: + for i in check_inputs_idxs: + _inp = new_inputs[i] + assert isinstance(_inp, torch.Tensor) + if _inp.data_ptr() % ALIGNMENT: + new_inputs[i] = clone_preserve_strides(_inp) + + +def remove_unaligned_input_idxs( + inputs: List[InputType], + static_input_idxs: Sequence[int], +): + """ + We require all inputs to be aligned, so introduce a copy for any + that aren't. 
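+    (The copies themselves are made by ``copy_misaligned_inputs`` above; this
+    helper only drops indices whose current inputs are unaligned from
+    ``static_input_idxs``.)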
+ """ + aligned_static_input_idxs = [] + for idx in static_input_idxs: + input = inputs[idx] + if isinstance(input, torch.Tensor) and (input.data_ptr() % ALIGNMENT) == 0: + aligned_static_input_idxs.append(idx) + if len(aligned_static_input_idxs) != len(static_input_idxs): + return aligned_static_input_idxs + return static_input_idxs + + +def set_tracing_context_output_strides(example_inputs, compiled_graph): + # Return the output strides to the caller via TracingContext + context = torch._guards.TracingContext.try_get() + if context is not None and context.output_strides is not None: + assert len(context.output_strides) == 0 + shape_env = shape_env_from_inputs(example_inputs) + for exprs in compiled_graph.output_strides: + if exprs is None: + context.output_strides.append(None) + else: + context.output_strides.append( + tuple( + ( + shape_env.evaluate_symexpr(e) + if shape_env is not None + else int(e) + ) + for e in exprs + ) + ) diff --git a/lib/python3.10/site-packages/torch/_inductor/virtualized.py b/lib/python3.10/site-packages/torch/_inductor/virtualized.py new file mode 100644 index 0000000000000000000000000000000000000000..b00a94c5f2cee32534f94730b34dc7bffa7a9c56 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/virtualized.py @@ -0,0 +1,361 @@ +# mypy: allow-untyped-defs +""" +This file provides a number of "global" variables/handlers that are actually +thread local and dynamically scoped, with Inductor patching them to various +implementations depending on the situation. + +These handlers are interacted with in a fairly stylized way. Typically, +we will import V from this module:: + + from .virtualized import V + +Various handlers are accessible as attributes on this module; for example, +you might access ``V.graph.sizevars.size_hint`` to resolve a size hint associated with +a number. + +There are a few distinct usage patterns for virtualized global variables: + +1. Implicit argument passing. Examples: ``V.current_node``, ``V.aot_compilation``. + Use ``V.set_current_node`` to change what the current node is while we're + executing some region of code, so code inside that region can query ``V.current_node`` + to find out what it is. This is often more convenient than manually threading + the current node as an argument through all call stacks. + +2. Per-compilation global state. Examples: ``V.fake_mode``, ``V.graph``. For a + given ``compile_fx`` invocation, these typically don't change, but they are + associated with some internal state so they cannot just be global functions. + We install these objects at the beginning of compilation and then you can + conveniently access them without having to pass them around. + +3. Alternate define-by-run interpretations. Examples: ``V.ops``, ``V.kernel``. + A commonly used IR in Inductor is define-by-run: instead of maintaining + explicit syntax data structures, we instead represent loop bodies as + callable functions, which internally invoke operations defined on + ``V.ops``. To perform semantic analysis, print or code generate these + operations, we dynamically patch ``V.ops`` with an alternate handler with + the intended semantics and then run the callable function. For example, to + extract out a traditional (FX) graph representation of the define-by-run + IR, simply install a handler that records each ``ops`` call to a graph. + + TODO: Define a parent class / protocol that defines all of the operations + V.ops is expected to support. 
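+
+As a rough illustration of pattern 3 (a sketch only, not taken from the code
+below), a define-by-run loop body can be reinterpreted simply by swapping the
+installed ``ops`` handler around the call::
+
+    def body(a, b):
+        return ops.add(a, b)
+
+    with V.set_ops_handler(MockHandler()):
+        body("a", "b")  # with MockHandler the call is rendered as a string
+                        # instead of being executed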
+ +It is typically an error to access a virtualized global without having installed +an appropriate handler (you will get a NullHandler), although in some cases we +provide a default implementation. + +One last thing: although most virtualized globals are accessed via ``V``, ``ops`` is +ubiquitous enough to have its own top level variable, so you will typically see +``ops.constant(...)`` rather than ``V.ops.constant(...)``. In fact, these are not +equivalent; the former interface supports arithmetic overloads like ``x + y`` +instead of forcing ``ops.add(x, y)``, so it should be preferred. + +Some operators are seemingly unused, but they are implicitly used by ops_wrapper. +In particular, we typically have an operator for every basic pointwise PyTorch operation +supported. +""" + +from __future__ import annotations + +from contextlib import AbstractContextManager, contextmanager +from threading import local +from typing import Any, Callable, Generic, List, Type, TYPE_CHECKING, TypeVar, Union + +from .ops_handler import ( # noqa: F401 + KernelFormatterHandler, + MockHandler, + OpsHandler, + ReductionType, + StoreMode, + WrapperHandler, +) + + +if TYPE_CHECKING: + import torch + from torch._inductor.codegen.cpp_utils import LocalBufferContext + from torch._inductor.debug import DebugContext + from torch._inductor.graph import GraphLowering + from torch._inductor.loop_body import InterpreterShim + from torch._subclasses import FakeTensorMode + +threadlocal = local() + +T = TypeVar("T") + + +class NullHandler: + """ + Sentinel indicating that a global variable is unset ala None. Typically, + attempting to access the global variable before it's set is an error, but with + NullHandler it won't fail until you try to access an attribute on it. + """ + + +class Virtualized(Generic[T]): + """ + Implements a global variable that redirects via thread local variable + (NB: construct this class to create the global variable; this is not + a singleton class!) + + This allows us to swap in different op implementations in codegen. + + NB: Despite the fact that we typically call these "handlers" (e.g., NullHandler is + the default value of the variable), we sometimes use these variables to + store other things, like booleans. + """ + + def __init__(self, vname: str, default: Union[Callable[[], T], Type[NullHandler]]): + self._key: str = f"__torchinductor_{vname}" + self._default = default + + def _set_handler(self, value: T) -> AbstractContextManager[None]: + prior = self._get_handler() + setattr(threadlocal, self._key, value) + + @contextmanager + def ctx(): + try: + yield + finally: + self._set_handler(prior) + + return ctx() + + def _get_handler(self) -> T: + try: + return getattr(threadlocal, self._key) + except AttributeError: + # TODO: To be honest, I feel we probably should just error in this + # case, instead of making a null handler that will probably error + # when you getattr on it + return self._default() # type: ignore[return-value] + + def __getattr__(self, name: str) -> Any: + return getattr(self._get_handler(), name) + + +class NullKernelHandler(NullHandler): + """ + We need access `V.kernel.removed_buffers` in DeferredLine class when there + is no kernel in the context. This happens when codegening the wrapper. + Initialize `removed_buffers` and `inplaced_to_remove` explicitly so we don't + need call 'getattr' with default value which is error prone to typo in + attribute name. 
+ """ + + def __init__(self): + super().__init__() + self.removed_buffers = set() + self.inplaced_to_remove = set() + self.index_dtype = "tl.int64" + + +_ops: Virtualized[OpsHandler[Any]] = Virtualized("ops", MockHandler) +_graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler) +_real_inputs: Virtualized[List[torch.Tensor]] = Virtualized("real_inputs", NullHandler) +_fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler) +_kernel: Virtualized[NullKernelHandler] = Virtualized( + "kernel", NullKernelHandler +) # TODO: improve type +_debug: Virtualized[DebugContext] = Virtualized("debug", NullHandler) +_interpreter: Virtualized[InterpreterShim] = Virtualized("interpreter", NullHandler) +_aot_compilation: Virtualized[bool] = Virtualized("aot_compilation", NullHandler) +_current_node: Virtualized[torch.fx.Node] = Virtualized("current_node", NullHandler) +_local_buffer_context: Virtualized[LocalBufferContext] = Virtualized( + "local_buffer_context", NullHandler +) + + +class OpsValue: + """The return type of most ops calls. + + This exists so we can overload magic methods, and write mathematical + expressions much more fluently. So instead of + + ops.add(ops.mul(ops.mul(ops.sub(ops.mul(_Ap2, x), _Ap3), x), x), _1) + + we can write + + (_Ap2 * x - _Ap3) * x * x + _1 + + """ + + value: Any + + def __init__(self, value): + self.value = value + + def __str__(self): + return str(self.value) + + def __repr__(self): + return f"OpsValue({self.value!r})" + + def __add__(self, other): + return ops.add(self, other) + + def __mul__(self, other): + return ops.mul(self, other) + + def __sub__(self, other): + return ops.sub(self, other) + + def __neg__(self): + return ops.neg(self) + + def __truediv__(self, other): + return ops.truediv(self, other) + + def __floordiv__(self, other): + return ops.floordiv(self, other) + + def __mod__(self, other): + return ops.mod(self, other) + + def __pow__(self, other): + return ops.pow(self, other) + + def __lt__(self, other): + return ops.lt(self, other) + + def __le__(self, other): + return ops.le(self, other) + + def __eq__(self, other): + return ops.eq(self, other) + + def __ne__(self, other): + return ops.ne(self, other) + + def __gt__(self, other): + return ops.gt(self, other) + + def __ge__(self, other): + return ops.ge(self, other) + + def __and__(self, other): + return ops.bitwise_and(self, other) + + def __or__(self, other): + return ops.bitwise_or(self, other) + + def __xor__(self, other): + return ops.bitwise_xor(self, other) + + def __invert__(self): + return ops.bitwise_not(self) + + def __rshfit__(self, n): + return ops.bitwise_right_shift(self, n) + + def __lshift__(self, n): + return ops.bitwise_left_shift(self, n) + + +class OpsWrapper: + """This wraps any returned IR values into an `OpsValue` instance, so that we + can overload the magic methods for writing mathematical expressions fluently. 
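+
+    A rough sketch of the flow (illustrative, names as defined in this file)::
+
+        ops.mul(a, b)   # OpsWrapper.__getattr__("mul") unwraps any OpsValue args,
+                        # forwards the call to the installed ops handler, and wraps
+                        # the result back into an OpsValue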
+ """ + + def __getattr__(self, name): + def inner(*args, **kwargs): + new_args = [OpsWrapper._unwrap(a) for a in args] + new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()} + return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs)) + + return inner + + @staticmethod + def _unwrap(x): + if isinstance(x, (list, tuple)): + return tuple(OpsWrapper._unwrap(v) for v in x) + if isinstance(x, OpsValue): + return x.value + return x + + @staticmethod + def _wrap(x): + if isinstance(x, (list, tuple)): + return tuple(OpsValue(v) for v in x) + return OpsValue(x) + + @staticmethod + def indirect_indexing(index, size, check=True, wrap_neg=True): + # Returns a sympy value, not IR value + index = OpsWrapper._unwrap(index) + return _ops.indirect_indexing(index, size, check, wrap_neg) + + +ops = OpsWrapper() + + +class _V: + MockHandler = MockHandler + KernelFormatterHandler = KernelFormatterHandler + WrapperHandler = WrapperHandler + + set_ops_handler: Callable[[Any], Any] = _ops._set_handler + get_ops_handler: Callable[[], Any] = _ops._get_handler + set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler + set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler + get_real_inputs: Callable[[], Any] = _real_inputs._get_handler + set_fake_mode: Callable[[Any], Any] = _fake_mode._set_handler + get_fake_mode: Callable[[], Any] = _fake_mode._get_handler + set_kernel_handler: Callable[[Any], Any] = _kernel._set_handler + set_debug_handler: Callable[[Any], Any] = _debug._set_handler + set_interpreter_handler: Callable[[Any], Any] = _interpreter._set_handler + set_aot_compilation: Callable[[bool], Any] = _aot_compilation._set_handler + get_aot_compilation: Callable[[], Any] = _aot_compilation._get_handler + set_current_node: Callable[[Any], Any] = _current_node._set_handler + get_current_node: Callable[[], Any] = _current_node._get_handler + set_local_buffer_context: Callable[[Any], Any] = _local_buffer_context._set_handler + get_local_buffer_context: Callable[[], Any] = _local_buffer_context._get_handler + + @property + def ops(self) -> OpsHandler[Any]: + """The operator handler specific to the current codegen task""" + return _ops._get_handler() + + @property + def graph(self) -> GraphLowering: + """The graph currently being generated""" + return _graph._get_handler() + + @property + def real_inputs(self): + """non-fake example inputs""" + return _real_inputs._get_handler() + + @property + def fake_mode(self): + """The graph currently being generated""" + return _fake_mode._get_handler() + + @property + def kernel(self): + """The kernel currently being generated""" + return _kernel._get_handler() + + @property + def debug(self): + return _debug._get_handler() + + @property + def interpreter(self): + return _interpreter._get_handler() + + @property + def aot_compilation(self): + return _aot_compilation._get_handler() + + @property + def current_node(self): + return _current_node._get_handler() + + @property + def local_buffer_context(self): + return _local_buffer_context._get_handler() + + +V = _V() diff --git a/lib/python3.10/site-packages/torch/_inductor/wrapper_benchmark.py b/lib/python3.10/site-packages/torch/_inductor/wrapper_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..bdd1f0fc95b7f905eab8457a733d4b036849c443 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_inductor/wrapper_benchmark.py @@ -0,0 +1,315 @@ +# mypy: allow-untyped-defs +import dataclasses +import tempfile +from collections import defaultdict + 
+import torch +from torch.autograd import DeviceType + +from .runtime.benchmarking import benchmarker +from .runtime.runtime_utils import create_bandwidth_info_str, get_num_bytes + + +_kernel_category_choices = [ + "foreach", + "persistent_reduction", + "pointwise", + "reduction", + "split_scan", + "template", +] + + +def get_kernel_category_by_source_code(src_code): + """ + Similar to get_kernel_category but use the source code. Call this API + if we have not compile the src_code to module yet. + """ + choices = [ + ch for ch in _kernel_category_choices if f"@triton_heuristics.{ch}" in src_code + ] + if len(choices) == 1: + return choices[0] + else: + return "unknown" + + +def get_kernel_category(kernel_mod): + """ + Given the module defining a triton kernel, return the category of the kernel. + Category can be one of: + - pointwise + - reduction + - persistent_reduction + + Currently we simply decide the category depending on what decorator is imported + by the kernel. + """ + choices = [ch for ch in _kernel_category_choices if ch in kernel_mod.__dict__] + if len(choices) == 1: + return choices[0] + else: + return "unknown" + + +def get_triton_kernel(mod): + from torch._inductor.runtime.triton_heuristics import CachingAutotuner + + cand_list = [ + v + for k, v in mod.__dict__.items() + if k.startswith("triton_") and isinstance(v, CachingAutotuner) + ] + assert len(cand_list) == 1 + return cand_list[0] + + +def benchmark_all_kernels(benchmark_name, benchmark_all_configs): + """ + An experimental API used only when config.benchmark_kernel is true. + + Run the kernel benchmarks for all the kernels cached in PyCodeCache. + Used in the compiled modules. + + Put this method here rather than codegen it for convenience since its implementation + does not change based on different graph modules being compiled. 
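+
+    In practice this is reached through ``compiled_module_main`` below, i.e. by
+    running the compiled module with ``--benchmark-kernels`` / ``-k``.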
+ """ + from torch._inductor.codecache import PyCodeCache + + nfound = 0 + for kernel_key, kernel_mod in PyCodeCache.cache.items(): + if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"): + continue + + triton_kernel = get_triton_kernel(kernel_mod) + kernel_category = get_kernel_category(kernel_mod) + args = kernel_mod.get_args() + num_in_out_ptrs = len( + [ + arg_name + for arg_name in triton_kernel.fn.arg_names + if arg_name.startswith("in_out_ptr") + ] + ) + num_gb = triton_kernel.inductor_meta.get("kernel_num_gb", None) + if num_gb is None: + num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9 + + def get_info_str(ms, n_regs, n_spills, shared, prefix=""): + if not any(x is None for x in [n_regs, n_spills, shared]): + kernel_detail_str = ( + f" {n_regs:3} regs {n_spills:3} spills {shared:8} shared mem" + ) + else: + kernel_detail_str = "" + + gb_per_s = num_gb / (ms / 1e3) + return create_bandwidth_info_str( + ms, num_gb, gb_per_s, prefix=prefix, suffix=kernel_detail_str + ) + + kernel_desc = ( + f"{benchmark_name:20} {kernel_category[:3].upper()} {kernel_key[:10]}" + ) + if benchmark_all_configs: + assert hasattr(kernel_mod, "benchmark_all_configs") + bench_result = kernel_mod.benchmark_all_configs(args) + print(kernel_desc) + for launcher, ms in bench_result.items(): + print( + f" {get_info_str(ms, launcher.n_regs, launcher.n_spills, launcher.shared)} @ {launcher.config}" + ) + else: + ms = benchmarker.benchmark_gpu( + lambda: kernel_mod.call(args), rep=40, fast_flush=True + ) + assert ( + len(triton_kernel.launchers) == 1 + ), "Autotuner should have selected the best config" + launcher = triton_kernel.launchers[0] + print( + get_info_str( + ms, + launcher.n_regs, + launcher.n_spills, + launcher.shared, + prefix=f"{kernel_desc} ", + ) + ) + + nfound += 1 + if nfound == 0: + print( + "No kernel with benchmark functionality found. Make sure you run inductor with config.benchmark_kernel being True" + ) + + +@dataclasses.dataclass +class ProfileEvent: + category: str + key: str + self_device_time_ms: float + # the benchmark is run multiple times and we average the count across all the + # runs. It should be an integer but define a float just in case. + count: float + + +def parse_profile_event_list( + benchmark_name, event_list, wall_time_ms, nruns, device_name +): + def get_self_device_time(ev): + """ + ev.self_device_time_total is in microsecond. Convert to millisecond. 
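+        For example (illustrative numbers): with ``nruns=10`` and
+        ``self_device_time_total=2000`` (microseconds), this returns 0.2 ms per run.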
+ """ + return ev.self_device_time_total / 1000 / nruns + + all_events = defaultdict(list) + + def add_event(ev, category): + profile_ev = ProfileEvent( + category=category, + key=ev.key, + self_device_time_ms=get_self_device_time(ev), + count=ev.count / nruns, # average across all runs + ) + all_events[category].append(profile_ev) + + for ev in event_list: + assert not ev.is_legacy, "Don't support the legacy profiler" + if ev.device_type == DeviceType.CPU: + # ignore the event on CPU side + continue + + category = "unknown" + if ev.key.startswith("triton_"): + if ev.key.startswith("triton_poi"): + category = "triton_pointwise" + elif ev.key.startswith("triton_red"): + category = "triton_reduction" + elif ev.key.startswith("triton_per"): + category = "triton_persistent_reduction" + else: + category = "triton_unknown" + + add_event(ev, category) + + def report_category(category, profile_events): + from tabulate import tabulate + + profile_events.sort(key=lambda ev: ev.self_device_time_ms, reverse=True) + + rows = [] + total_time = 0.0 + print(f"\n == {category} category kernels == ") + for ev in profile_events: + total_time += ev.self_device_time_ms + percent = f"{ev.self_device_time_ms / wall_time_ms * 100:.2f}%" + rows.append([ev.key[:120], ev.self_device_time_ms, ev.count, percent]) + rows.append( + ["Total", total_time, "", f"{total_time / wall_time_ms * 100:.2f}%"] + ) + print( + tabulate( + rows, + headers=[ + "Kernel", + f"Self {device_name.upper()} TIME (ms)", + "Count", + "Percent", + ], + ) + ) + return total_time + + def report(): + category_list = [ + "triton_pointwise", + "triton_reduction", + "triton_persistent_reduction", + "triton_unknown", + "unknown", + ] + assert set(all_events.keys()).issubset( + set(category_list) + ), f"{list(all_events.keys())}" + + per_category_wall_time = {} + total_device_ms = 0.0 + for category in category_list: + if category in all_events: + _time = report_category(category, all_events[category]) + per_category_wall_time[category] = _time + total_device_ms += _time + + device_busy_percent = f"{total_device_ms / wall_time_ms * 100:.2f}%" + print( + f"\nPercent of time when {device_name.upper()} is busy: {device_busy_percent}" + ) + print(f"Total wall time {wall_time_ms:.3f} ms") + + # output such a line so we can gather such line from all compiled modules from all + # benchmarks and tabulate it! + # Columns: benchmark_name, pointwise_percent, reduction_percent, persistent_reduction_percent, + # unknown_category_percent, device_busy_percent, wall_time_ms + tabulate_line = f"Output for tabulate: {benchmark_name}" + for category in category_list: + percent = ( + f"{per_category_wall_time.get(category, 0.0) / wall_time_ms * 100:.2f}%" + ) + tabulate_line += f", {percent}" + tabulate_line += f", {device_busy_percent}, {wall_time_ms:.3f}ms" + + print(tabulate_line) + + report() + + +def compiled_module_main(benchmark_name, benchmark_compiled_module_fn): + """ + This is the function called in __main__ block of a compiled module. 
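+
+    For example, assuming the compiled module was written out as
+    ``compiled_module.py`` (an illustrative path), it can be driven via the flags
+    parsed below::
+
+        python compiled_module.py -k   # benchmark each individual kernel
+        python compiled_module.py -p   # profile the module and dump a chrome trace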
+ """ + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--benchmark-kernels", + "-k", + action="store_true", + help="Whether to benchmark each individual kernels", + ) + parser.add_argument( + "--benchmark-all-configs", + "-c", + action="store_true", + help="Whether to benchmark each individual config for a kernel", + ) + parser.add_argument( + "--profile", + "-p", + action="store_true", + help="Whether to profile the compiled module", + ) + args = parser.parse_args() + + if args.benchmark_kernels: + benchmark_all_kernels(benchmark_name, args.benchmark_all_configs) + else: + times = 10 + repeat = 10 + wall_time_ms = benchmark_compiled_module_fn(times=times, repeat=repeat) * 1000 + + if not args.profile: + return + + with torch.profiler.profile(record_shapes=True) as p: + benchmark_compiled_module_fn(times=times, repeat=repeat) + + path = f"{tempfile.gettempdir()}/compiled_module_profile.json" + p.export_chrome_trace(path) + print(f"Profiling result for a compiled module of benchmark {benchmark_name}:") + print(f"Chrome trace for the profile is written to {path}") + event_list = p.key_averages(group_by_input_shape=True) + print(event_list.table(sort_by="self_device_time_total", row_limit=10)) + parse_profile_event_list( + benchmark_name, event_list, wall_time_ms, times * repeat, p.use_device + ) diff --git a/lib/python3.10/site-packages/torch/_lazy/__init__.py b/lib/python3.10/site-packages/torch/_lazy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8d90efa40e58841a11a25569ca6722b791894999 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_lazy/__init__.py @@ -0,0 +1,55 @@ +# mypy: allow-untyped-defs + +import torch._C._lazy +from torch.utils._pytree import tree_flatten, tree_unflatten + +from .closure import add_step_closure, run_step_closures + + +def mark_step(device: str = "", wait=False): + """Triggers a mark step, which amounts to + - collecting a group of 'live' lazy tensors to index into the compilation cache + (lowering/compiling their IR graphs if not cached) + - kicking off execution of the compiled function + - (optionally, wait=True) waiting for cpu-side execution to complete (does not sync the accelerator) + """ + # TODO(whc) expand this to include backend hooks and align with XLA backend needs + torch._C._lazy._mark_step(device, [], wait=wait) + + run_step_closures() + + +def wait_device_ops(devices=None): + """Waits for all the async operations on the given devices to complete. + Args: + devices (string..., optional): The devices whose async ops need to be waited + for. If empty, all the local devices will be waited for. + """ + if devices is None: + devices = [] + torch._C._lazy._wait_device_ops(devices=devices) + + +def sync_multi(tensors, devices): + """ + Sync the list of lazy tensors so there IR get lowered for the activate backend + and the compiled computation graph get cached. 
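+    For example, ``sync_multi([t], ["lazy"])`` (an illustrative call) syncs a single
+    lazy tensor, mirroring how ``to_cpu`` below uses it.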
+ """ + torch._C._lazy._sync_multi(tensors, devices) + + +def get_tensor_id(tensor): + """Return a unique id of the lazy tensor maintained by LTC""" + return torch._C._lazy._get_tensor_id(tensor) + + +def to_cpu(tensors, devices=None): + devices = devices or ["lazy"] + + flattened, spec = tree_flatten(tensors) + sync_multi(flattened, devices) + return tree_unflatten([t.to("cpu") for t in flattened], spec) + + +def save(tensors, *args, **kwargs): + torch.save(to_cpu(tensors), *args, **kwargs) diff --git a/lib/python3.10/site-packages/torch/_lazy/closure.py b/lib/python3.10/site-packages/torch/_lazy/closure.py new file mode 100644 index 0000000000000000000000000000000000000000..94c12c075a092b9f70db02e5f280f38c6f94f050 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_lazy/closure.py @@ -0,0 +1,135 @@ +# mypy: allow-untyped-defs +import os +import threading +from queue import Empty as EmptyQueue, Queue + +from torch._lazy.device_context import get_device_context + + +class ClosureHandler: + def __init__(self) -> None: + pass + + def run(self, closure): + """Run closure function + + Args: + closure: callable function to run + """ + closure() + + def __call__(self, closures): + for closure in closures: + self.run(closure) + + +class AsyncClosureHandler(ClosureHandler): + """Handler for Asynchronous Step Closures + Args: + max_queue_size: The maximum length of the closure queue after which + the training loop will block until closures are evaluated. + By default, a reasonable limit of a maximum of 100 on the queue. + This value can be set using the `XLA_MAX_ASYNC_QUEUE` environment + variable. + """ + + def __init__(self, max_queue_size=100): + super().__init__() + self._closure_queue: Queue = Queue( + int(os.environ.get("LTC_MAX_ASYNC_QUEUE", max_queue_size)) + ) + self._closure_exception: Queue = Queue() + self._closure_lock = threading.Lock() + self._closure_event_loop_finished = threading.Event() + self._closure_event_loop = None + + def start_event_loop(self): + """Start closure event loop if not started""" + if self._closure_event_loop is None: + + def event_loop(): + # Run loop until closure event is set and closure queue is empty + while True: + try: + closure = self._closure_queue.get(block=True, timeout=3) + closure() + self._closure_queue.task_done() + except EmptyQueue: + with self._closure_lock: + if self._closure_queue.empty(): + self._closure_event_loop_finished.set() + return + except Exception as e: + self._closure_exception.put(e) + return + + self._closure_event_loop = threading.Thread(target=event_loop) + self._closure_event_loop.start() + + def run(self, closure): + with self._closure_lock: + self._closure_queue.put(closure, block=True) + if ( + self._closure_event_loop is None + or not self._closure_event_loop.is_alive() + ): + try: + e = self._closure_exception.get(block=False) + raise RuntimeError( + "Cannot run asynchronous closure due to previously raised exception" + ) from e + except EmptyQueue: + self._closure_event_loop = None + self.start_event_loop() + + +def add_step_closure(closure, args=(), run_async=False): + """Adds a closure to the list of the ones to be run at the end of the step. + Many times during model training there is the need to print/report (print to + console, post to tensorboard, etc...) information which require the content of + intermediary tensors to be inspected. + Inspecting different tensors content in different points of the model code + requires many executions and typically causes performance issues. 
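+    (A typical example is reporting a running loss to the console: reading the
+    loss tensor inline would otherwise force an extra execution at each such point.)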
+ Adding a step closure will ensure that it will be run after the barrier, when + all the live tensors will be already materialized to device data. + Live tensors which will include the ones captured by the closure arguments. + So using `add_step_closure()` will ensure a single execution will be + performed, even when multiple closures are queued, requiring multiple tensors + to be inspected. + Step closures will be run sequentially in the order they have been queued. + Note that even though using this API the execution will be optimized, it is + advised to throttle the printing/reporting events once every N steps. + Args: + closure (callable): The function to be called. + args (tuple): The arguments to be passed to the closure. + run_async: If True, run the closure asynchronously. + """ + devctx = get_device_context() + closures_type = "async_step_closures" if run_async else "step_closures" + step_closures = getattr(devctx, closures_type, None) + if step_closures is None: + step_closures = [] + setattr(devctx, closures_type, step_closures) + step_closures.append(lambda a=args: closure(*a)) + + +def run_step_closures(): + devctx = get_device_context() + async_step_closures = getattr(devctx, "async_step_closures", None) + if async_step_closures is not None: + devctx.async_step_closures = [] + async_closure_handler = getattr(devctx, "async_closure_handler", None) + if async_closure_handler is None: + async_closure_handler = AsyncClosureHandler() + devctx.async_closure_handler = async_closure_handler + async_closure_handler(async_step_closures) + + step_closures = getattr(devctx, "step_closures", None) + if step_closures is not None: + devctx.step_closures = [] + closure_handler = getattr(devctx, "closure_handler", None) + if closure_handler is None: + closure_handler = ClosureHandler() + devctx.closure_handler = closure_handler + closure_handler(step_closures) + return devctx diff --git a/lib/python3.10/site-packages/torch/_lazy/computation.py b/lib/python3.10/site-packages/torch/_lazy/computation.py new file mode 100644 index 0000000000000000000000000000000000000000..17a61e36cb9f2a46461d14caa3c1a3ff6e8c9094 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_lazy/computation.py @@ -0,0 +1,27 @@ +# mypy: allow-untyped-defs +import torch._C._lazy +import torch._C._lazy_ts_backend + + +def get_tensors_ts_device_data_node(tensors): + """Return tensor ids and eager tensors for DeviceData nodes in the + IR for the passed in lazy tensors. + + TODO: This API is currently ts backend specific. We are working on + generalizing it to all backends including XLA. + """ + return torch._C._lazy_ts_backend._get_tensors_ts_device_data_node(tensors) + + +def get_graph_hash(tensors): + """Return the graph hash for the passed in lazy tensors""" + return torch._C._lazy._get_graph_hash(tensors) + + +def run_cached_graph(hash_str, graph_inputs): + """Running the cached computation graph with the given inputs + + TODO: This API is currently ts backend specific. We are working on + generalizing it to all backends including XLA. 
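+
+    (See ``torch._lazy.extract_compiled_graph`` for the intended call pattern: the
+    hash comes from ``get_graph_hash`` and the inputs are assembled by
+    ``GraphInputMatcher``.)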
+ """ + return torch._C._lazy_ts_backend._run_cached_graph(hash_str, graph_inputs) diff --git a/lib/python3.10/site-packages/torch/_lazy/config.py b/lib/python3.10/site-packages/torch/_lazy/config.py new file mode 100644 index 0000000000000000000000000000000000000000..f7ebca12de7fc44c27a2b3ae7c2ed1c7d8097c99 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_lazy/config.py @@ -0,0 +1,17 @@ +# mypy: allow-untyped-defs +import torch._C._lazy + + +def get_force_fallback(): + """Get the config used to force LTC fallback""" + return torch._C._lazy._get_force_fallback() + + +def set_force_fallback(configval): + """Set the config used to force LTC fallback""" + torch._C._lazy._set_force_fallback(configval) + + +def set_reuse_ir(val: bool): + """Set the config to reuse IR nodes for faster tracing""" + torch._C._lazy._set_reuse_ir(val) diff --git a/lib/python3.10/site-packages/torch/_lazy/debug.py b/lib/python3.10/site-packages/torch/_lazy/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..84534fb232509f0c9bbe722820bd1ae649d53e07 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_lazy/debug.py @@ -0,0 +1,22 @@ +# mypy: allow-untyped-defs +import torch._C._lazy + + +def render_ir_graph(tensors): + """Return a text dump of the LTC IR graph in dot format for the tensors. + The text can be processed by tools like dot to be rendered in pdf,png etc.""" + return torch._C._lazy._get_tensors_dot(tensors) + + +def dump_ir(tensors, ir_format): + """Return a dump of the tensors in the specified format. + Valid format are + - text: for LTC IR + - backend: for the activate backend IR + """ + if ir_format == "text": + return torch._C._lazy._get_tensors_text(tensors) + elif ir_format == "backend": + return torch._C._lazy._get_tensors_backend(tensors) + else: + raise RuntimeError(f"Unrecognized IR format: {ir_format}") diff --git a/lib/python3.10/site-packages/torch/_lazy/device_context.py b/lib/python3.10/site-packages/torch/_lazy/device_context.py new file mode 100644 index 0000000000000000000000000000000000000000..e09fdab3f7458cc6a410a1736b89e68b4a4eef17 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_lazy/device_context.py @@ -0,0 +1,26 @@ +# mypy: allow-untyped-defs +import threading +from typing import Any, Dict + +import torch._C._lazy + + +class DeviceContext: + _CONTEXTS: Dict[str, Any] = {} + _CONTEXTS_LOCK = threading.Lock() + + def __init__(self, device): + self.device = device + + +def get_device_context(device=None): + if device is None: + device = torch._C._lazy._get_default_device_type() + else: + device = str(device) + with DeviceContext._CONTEXTS_LOCK: + devctx = DeviceContext._CONTEXTS.get(device, None) + if devctx is None: + devctx = DeviceContext(device) + DeviceContext._CONTEXTS[device] = devctx + return devctx diff --git a/lib/python3.10/site-packages/torch/_lazy/extract_compiled_graph.py b/lib/python3.10/site-packages/torch/_lazy/extract_compiled_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..f46eea4eee9b79033aa22ce2bcc77ba9f650c622 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_lazy/extract_compiled_graph.py @@ -0,0 +1,225 @@ +# mypy: allow-untyped-defs +import copy +import dataclasses +import itertools +import os +from typing import Any, Callable, Dict, List + +import torch +import torch._lazy as lazy +import torch._lazy.metrics as metrics +from torch import fx +from torch._lazy import computation, debug as lazy_debug +from torch._lazy.tensor_factory_functions import tensor_factory_functions + + +debug = 
os.environ.get("debug_extract_compiled_graph") is not None + + +@dataclasses.dataclass +class GraphInputMatcher: + """ + The GraphInputMatcher class setup the graph inputs for future calls after lazy tracing. + Specifically, those graph inputs corresponding to method parameters should be replaced with the + arguments for the current call. + + tensor_id_to_arg_idx maps the tensor id to the parameter index. + graph_input_tensor_ids, graph_input_ivalues list the tensor_id and ivalue for each of the + TS/XLA graph inputs. + """ + + tensor_id_to_arg_idx: Dict[int, int] + graph_input_tensor_ids: List[int] + # there are 2 categories of graph_input_tensors. + # Category 1: those whose id are not found in tensor_id_to_arg_idx. These are + # most likely const tensors and we can get its content from graph_input_tensors + # Category 2: those whose id are found in tensor_id_to_arg_idx. We should get + # the tensor from method arguments + graph_input_ivalues: List[Any] + + # get the real graph input tensors + def __call__(self, args): + real_input = [] + for tensor_id, traced_ivalue in zip( + self.graph_input_tensor_ids, self.graph_input_ivalues + ): + arg_idx = self.tensor_id_to_arg_idx.get(tensor_id, None) + if arg_idx is None: + inp = traced_ivalue + else: + inp = args[arg_idx] + real_input.append(inp) + return real_input + + +class ReturnValueHandler: + r""" + When ltc_sync_multi is called on multi tensors, the compiled graph + will contain output only for unique tensors - if a tensor appears multiple + times in the input to _ltc_sync_multi, only the first occurance matters. + + However from python level, we still expect multi tensors returned with duplciation + even if the TS graph dedup the output. e.g. for method: + + def forward(self, a): + return a, a + + the TS graph captured by LTC will return a single tensor, but Python method expects 2. + + This class dedup the lazy tensors first to get the index that will be used + to duplicate the eager tensors later. + """ + + def __init__(self, lazy_out_list): + self.index: List[List[int]] = [] + self.total_count = len(lazy_out_list) + + tensor_id_to_idx: Dict[int, int] = {} + for dup_idx, lazy_tensor in enumerate(lazy_out_list): + uniq_idx = tensor_id_to_idx.get(id(lazy_tensor), None) + if uniq_idx is not None: + self.index[uniq_idx].append(dup_idx) + else: + uniq_idx = len(self.index) + self.index.append([dup_idx]) + tensor_id_to_idx[id(lazy_tensor)] = uniq_idx + + def duplicate_eager_tensors(self, eager_tensor_list): + duplicated_list = [None] * self.total_count + assert len(eager_tensor_list) == len(self.index) + + for uniq_idx, eager_tensor in enumerate(eager_tensor_list): + for dup_idx in self.index[uniq_idx]: + duplicated_list[dup_idx] = eager_tensor + return duplicated_list + + +def force_lazy_device(model: fx.GraphModule): + """ + Factory methods in a Fx graph may create tensors for a specific eager devices. + If we take no actions, those eager tensors will be mixed with lazy tensors and + cause crash. This method overwrite those eager device to lazy device. 
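+
+    For example (illustrative), a traced ``torch.ones(3, device="cuda:0")`` call
+    would be rewritten to create its tensor on ``torch.device("lazy", 0)`` instead.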
+ """ + + def tolazydevice(dev): + if isinstance(dev, torch.device): + return torch.device("lazy", index=dev.index) + return dev + + def hasDeviceArg(args, kwargs): + return any( + isinstance(arg, torch.device) + for arg in itertools.chain(args, kwargs.values()) + ) + + for nd in model.graph.nodes: + nd.args = tuple(tolazydevice(arg) for arg in nd.args) + nd.kwargs = {k: tolazydevice(v) for k, v in nd.kwargs.items()} + + # For torchbench like yolov3, hf_Bart, dynamo generates Fx graph that return + # eager tensors on the default device + # (check https://gist.github.com/shunting314/eabdf6c769c59bc384469717b8f9bb7f for yolove, + # and https://gist.github.com/shunting314/8d5e2d9348a3258959d3954186c48814 for hf_Bart). + # To force those tensors on the lazy device, we can not simply override + # the device argument since there is no explicit device argument. + # What we are doing here is, for the list of covered tensor factory methods + # we add a lazy device argument explicity. + # + # TODO: This solution is no ideal since we may miss some factory methods. In future + # when we support lazy mode, this method can be replaced by that. + if nd.target in tensor_factory_functions and not hasDeviceArg( + nd.args, nd.kwargs + ): + kwargs = dict(nd.kwargs) # nd.kwargs is immutable. make a mutable copy. + kwargs["device"] = torch.device("lazy") + nd.kwargs = kwargs + + model.recompile() + + +def get_fallback_ops(): + fallback_ops = [] + for opname in metrics.counter_names(): + if "aten::" not in opname: + continue + val = int(metrics.counter_value(opname)) + if val > 0: + fallback_ops.append(f"{opname}={val}") + + return fallback_ops + + +def extract_compiled_graph(model: fx.GraphModule, example_inputs) -> Callable: + """ + Optimize an eager model with LTC and returns a wrapper to execute the + compiled graph directly without retracing. It depends on other mechanisms + like TorchDynamo guards to guarantee the returned wrapper is only called + when it's safe. 
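+
+    A minimal usage sketch (assuming ``model`` is an ``fx.GraphModule``, e.g. from
+    ``torch.fx.symbolic_trace``, and ``example_inputs`` is a list of eager
+    tensors)::
+
+        optimized = extract_compiled_graph(model, example_inputs)
+        out = optimized(*example_inputs)  # replays the cached LTC graph, no retracing
+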
+ """ + lazy_args = [arg.to(device="lazy") for arg in example_inputs] + args_tensor_ids = [lazy.get_tensor_id(lazy_arg) for lazy_arg in lazy_args] + tensor_id_to_arg_idx = {tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)} + lazy_model = copy.deepcopy(model).to(device=torch.device("lazy")) + force_lazy_device(lazy_model) + + # This line executes lazy tracing and enable us extracting compiled graph later + metrics.reset() + lazy_out = lazy_model(*lazy_args) + fallback_ops = get_fallback_ops() + metrics.reset() + + if len(fallback_ops) > 0: + raise RuntimeError( + f"Fail to extact the compiled graph because of fallback: {','.join(fallback_ops)}" + ) + + if not isinstance(lazy_out, (tuple, list)): + lazy_out = (lazy_out,) + + args_and_out = tuple(lazy_args) + tuple(lazy_out) + return_value_handler = ReturnValueHandler(args_and_out) + if debug: + print("Fx code:\n", model.code) + print("LTC IR:", lazy_debug.dump_ir(args_and_out, "text")) + + # TODO: this part is TS backend specific for now and will be generalized to + # support XLA + ( + graph_input_tensor_ids, + graph_input_ivalues, + ) = computation.get_tensors_ts_device_data_node(args_and_out) + assert len(graph_input_tensor_ids) == len(graph_input_ivalues) + graph_input_matcher = GraphInputMatcher( + tensor_id_to_arg_idx, graph_input_tensor_ids, graph_input_ivalues + ) + + graph_hash = computation.get_graph_hash(args_and_out) + + if debug: + print("graph_hash", graph_hash) + print(f"args_tensor_ids {args_tensor_ids}") + print("tensor ids from device data:", graph_input_tensor_ids) + + # sync the list of output tensors so the computation graph for these + # tensors will be cached. Those computation graphs can be retrieved + # by graph hash later. + lazy.sync_multi(args_and_out, []) + + def optimized_mod(*args): + if len(args_and_out) == 0: + return () + graph_input = graph_input_matcher(args) + res = return_value_handler.duplicate_eager_tensors( + computation.run_cached_graph(graph_hash, graph_input) + ) + + assert len(res) == len(args_and_out) + for i, arg in enumerate(args): + # only copy those tensors that get inplace updated + if arg is not res[i]: + arg.copy_(res[i]) + + # skip the args + return res[len(args) :] + + return optimized_mod diff --git a/lib/python3.10/site-packages/torch/_lazy/ir_cache.py b/lib/python3.10/site-packages/torch/_lazy/ir_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..a6e654566f29bce166eb52e721b694f3b1f7862b --- /dev/null +++ b/lib/python3.10/site-packages/torch/_lazy/ir_cache.py @@ -0,0 +1,14 @@ +# mypy: allow-untyped-defs +import torch._C._lazy + + +def dump(dot_file_name: str): + """Dump TrieCache in the dot format""" + return torch._C._lazy._dump_ir_cache(dot_file_name) + + +def reset(): + """Clear TrieCache. This is needed in testing to avoid + node reusing between different tests. 
+ """ + return torch._C._lazy._clear_ir_cache() diff --git a/lib/python3.10/site-packages/torch/_lazy/metrics.py b/lib/python3.10/site-packages/torch/_lazy/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..a77981feb90dbd74eb0a31ae86fe661a758a494a --- /dev/null +++ b/lib/python3.10/site-packages/torch/_lazy/metrics.py @@ -0,0 +1,22 @@ +# mypy: allow-untyped-defs +import torch._C._lazy + + +def reset(): + """Resets all metric counters.""" + torch._C._lazy._reset_metrics() + + +def counter_names(): + """Retrieves all the currently active counter names.""" + return torch._C._lazy._counter_names() + + +def counter_value(name: str): + """Return the value of the counter with the speficied name""" + return torch._C._lazy._counter_value(name) + + +def metrics_report(): + """Return the combined (lazy core and backend) metric report""" + return torch._C._lazy._metrics_report() diff --git a/lib/python3.10/site-packages/torch/_lazy/tensor_factory_functions.py b/lib/python3.10/site-packages/torch/_lazy/tensor_factory_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..3b8ddc8b11c7e036ba6beac440d04eb1835b26d4 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_lazy/tensor_factory_functions.py @@ -0,0 +1,49 @@ +import torch + + +""" +tensor_factory_functions defines the list of torch functions that create tensors. +The list is grabbed by searching thru native_functions.yaml by the following +regular expression: + + cat native_functions.yaml | grep 'func:' | grep -v "Tensor.*->" | grep "[-]>.*Tensor" + +It's possible that new tensor factory functions are added making this list stale. +Use at your own risk or regenerate the list. +""" +tensor_factory_functions = ( + torch._cudnn_init_dropout_state, + torch.arange, + torch.bartlett_window, + torch.blackman_window, + torch._empty_affine_quantized, + torch.empty_strided, + torch.eye, + torch.full, + torch.from_file, + torch.hann_window, + torch.hamming_window, + torch.kaiser_window, + torch.linspace, + torch.logspace, + torch.ones, + torch.scalar_tensor, + torch.rand, + torch.randint, + torch.randn, + torch.randperm, + torch.range, + torch._efficientzerotensor, + torch.zeros, + torch.tril_indices, + torch.triu_indices, + # Note: the following functions match the regular expression search above but + # they are not available in the torch module. Comment out. 
+ # torch._sparse_coo_tensor_with_dims, + # torch.fft_fftfreq, + # torch.fft_rfftfreq, +) + ( + # torch.tensor is special since it's not in native_functions.yaml + # add it separately + torch.tensor, +) diff --git a/lib/python3.10/site-packages/torch/_lazy/ts_backend.py b/lib/python3.10/site-packages/torch/_lazy/ts_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..5c6ce13746e913db8e27081b8b0dcf8f4e0d4c88 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_lazy/ts_backend.py @@ -0,0 +1,7 @@ +# mypy: allow-untyped-defs +import torch._C._lazy_ts_backend + + +def init(): + """Initializes the lazy Torchscript backend""" + torch._C._lazy_ts_backend._init() diff --git a/lib/python3.10/site-packages/torch/_library/__init__.py b/lib/python3.10/site-packages/torch/_library/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5db97310dec0c482eee3a4af85c765f5832599cb --- /dev/null +++ b/lib/python3.10/site-packages/torch/_library/__init__.py @@ -0,0 +1,6 @@ +import torch._library.autograd +import torch._library.fake_impl +import torch._library.simple_registry +import torch._library.utils +from torch._library.fake_class_registry import register_fake_class +from torch._library.triton import capture_triton, triton_op diff --git a/lib/python3.10/site-packages/torch/_library/autograd.py b/lib/python3.10/site-packages/torch/_library/autograd.py new file mode 100644 index 0000000000000000000000000000000000000000..75997ec63eb19243143c1e1d824405ed85bc10d1 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_library/autograd.py @@ -0,0 +1,241 @@ +# mypy: allow-untyped-defs +import dataclasses +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional, Protocol + +from torch import _C, _ops, autograd, Tensor +from torch.utils import _pytree + +from . import utils + + +class InfoProtocol(Protocol): + _backward_fn: Optional[Callable] + _setup_context_fn: Optional[Callable] + + +@dataclasses.dataclass +class Info: + _backward_fn: Optional[Callable] + _setup_context_fn: Optional[Callable] + + +def make_autograd_impl(op: _ops.OpOverload, info: InfoProtocol) -> Callable: + name: str = f"GeneratedBackwardFor_{op._namespace}_{op._opname}_{op._overloadname}" + + has_kwarg_only_args = utils.has_kwarg_only_args(op._schema) + + @dataclass + class Metadata: + keyset: _C.DispatchKeySet + keyword_only_args: Dict[str, Any] + + def forward_no_grad(*args): + metadata = args[-1] + args = args[:-1] + + with _C._AutoDispatchBelowAutograd(): + keyset = metadata.keyset + kwargs = metadata.keyword_only_args + result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs) + return result + + def forward(ctx, *args): + metadata = args[-1] + args = args[:-1] + + with _C._AutoDispatchBelowAutograd(): + keyset = metadata.keyset + kwargs = metadata.keyword_only_args + result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs) + if info._setup_context_fn: + # The Dispatcher will remove args that are equal to their default + # values from (args, kwargs). We're going to add it back so that + # the user can access them. + # + # This is OK to do: The Dispatcher removed the args for serialization + # FC/BC reasons (that is, a graph will not store args that are equal + # to their default values), but that doesn't matter here. 
If the user + # adds a new default arg, then they must update + # their setup_context (along with the rest of their operator + # registrations) + args, kwargs = utils.fill_defaults(op._schema, args, kwargs) + + if has_kwarg_only_args: + info._setup_context_fn( + ctx=ctx, inputs=args, keyword_only_inputs=kwargs, output=result + ) + else: + info._setup_context_fn(ctx=ctx, inputs=args, output=result) + return result + + def backward(ctx, *grads): + if info._backward_fn: + try: + prev_needs_input_grad = ctx.needs_input_grad + ctx.needs_input_grad = ctx.needs_input_grad[:-1] + result = info._backward_fn(ctx, *grads) + finally: + ctx.needs_input_grad = prev_needs_input_grad + if isinstance(result, tuple): + return (*result, None) + return result, None + raise RuntimeError( + f"Trying to backward through {op} but no autograd " + f"formula was registered. " + f"Please use register_autograd to add one." + ) + + Generated = type( + name, + (autograd.Function,), + { + "forward": staticmethod(forward), + "backward": staticmethod(backward), + }, + ) + + schema = op._schema + if any( + utils.is_tensorlist_like_type(a.type) + for a in (*schema.arguments, *schema.returns) + ): + Generated = supports_tensorlist(Generated) + + # The dispatcher passes any keyword-only-args as kwargs and the + # rest of the args (even if specified as kwargs) as args. + def autograd_impl(keyset, *args, **keyword_only_args): + if _C.is_grad_enabled() and _pytree.tree_any_only( + Tensor, lambda x: x.requires_grad, args, not_list_of_tensor + ): + result = Generated.apply(*args, Metadata(keyset, keyword_only_args)) # type: ignore[attr-defined] + else: + result = forward_no_grad(*args, Metadata(keyset, keyword_only_args)) + return result + + return autograd_impl + + +def supports_tensorlist(cls: Any) -> Any: + """Allows a given autograd.Function class to support List[Tensor] inputs/outputs. + + Regular autograd.Function has a constraint that it only directly supports autograd for + Tensors. Applying @supports_tensorlist enables an autograd.Function to support + autograd for List[Tensor] inputs and outputs. + """ + orig_forward = cls.forward + orig_backward = cls.backward + orig_apply = cls.apply + + @dataclass + class Metadata: + input_spec: spec_t + output_spec: Optional[spec_t] = None + result_is_tuple: Optional[bool] = None + + def new_forward(ctx, *args): + metadata = args[-1] + args = args[:-1] + if not isinstance(metadata, Metadata): + raise NotImplementedError( + "NYI: calling supports_tensorlist autograd.Function.forward directly. " + "You should probably be calling .apply instead. " + "Please file an issue if not." + ) + args = unflatten(list(args), metadata.input_spec) + result = orig_forward(ctx, *args) + metadata.result_is_tuple = isinstance(result, tuple) + if not metadata.result_is_tuple: + result = (result,) + flat_result, output_spec = flatten(result, not_list_of_tensor) + metadata.output_spec = output_spec + + if hasattr(ctx, "_pt_metadata"): + raise RuntimeError( + "Please don't set ctx._pt_metadata; PyTorch uses it to store info" + ) + ctx._pt_metadata = metadata + + return tuple(flat_result) + + def new_backward(ctx, *grads): + if not hasattr(ctx, "_pt_metadata"): + raise NotImplementedError( + "NYI: calling supports_tensorlist autograd.Function.backward directly. " + "This will automatically get called by PyTorch autograd. " + "Please file an issue if you need this." 
+ ) + + metadata = ctx._pt_metadata + grads = unflatten(list(grads), metadata.output_spec) + + # If the user's input is ([x, y, z], w), + # then needs_input_grad is (bool, bool, bool, bool, bool). + # We need to + # 1. get rid of the additional bool (which comes from the extra + # `metadata input`) + # 2. unflatten to get the right structure. + prev_needs_input_grad = ctx.needs_input_grad + try: + ctx.needs_input_grad = unflatten( + list(ctx.needs_input_grad[:-1]), metadata.input_spec + ) + grad_inputs = orig_backward(ctx, *grads) + finally: + ctx.needs_input_grad = prev_needs_input_grad + + if not isinstance(grad_inputs, tuple): + grad_inputs = (grad_inputs,) + # Assume that any Nones in the backward are Tensors. + # If the forward has an arg that is [1, 2, 3], the backward should + # return None as the grad. + # If the forward has an arg that is [tensor, tensor], the backward + # may return [None, None], [grad, None], [None, grad], or [grad, grad]. + flat_grad_inputs, grad_inputs_spec = flatten( + grad_inputs, not_list_of_optional_tensor + ) + if grad_inputs_spec != metadata.input_spec: + raise RuntimeError( + f"Expected the return from backward to be of the same structure " + f"as the inputs. Got: {grad_inputs_spec} (return from backward), " + f"{metadata.input_spec} (inputs)" + ) + return tuple(flat_grad_inputs + [None]) + + def new_apply(*args): + flat_args, input_spec = flatten(args, is_leaf=not_list_of_tensor) + metadata = Metadata(input_spec) + result = orig_apply(*flat_args, metadata) # type: ignore[misc] + assert metadata.output_spec is not None + result = unflatten(list(result), metadata.output_spec) + if not metadata.result_is_tuple: + assert isinstance(result, tuple) + assert len(result) == 1 + return result[0] + return result + + cls.forward = new_forward + cls.backward = new_backward + cls.apply = new_apply + return cls + + +def not_list_of_tensor(tree): + if isinstance(tree, tuple): + return False + if isinstance(tree, list): + return any(not isinstance(l, Tensor) for l in tree) + return True + + +def not_list_of_optional_tensor(tree): + if isinstance(tree, tuple): + return False + if isinstance(tree, list): + return any(l is not None and not isinstance(l, Tensor) for l in tree) + return True + + +flatten = _pytree.tree_flatten +unflatten = _pytree.tree_unflatten +spec_t = _pytree.TreeSpec diff --git a/lib/python3.10/site-packages/torch/_library/custom_ops.py b/lib/python3.10/site-packages/torch/_library/custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..2eb45a78a5edce949c2946d27019e8322d73b0ae --- /dev/null +++ b/lib/python3.10/site-packages/torch/_library/custom_ops.py @@ -0,0 +1,835 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import inspect +import logging +import weakref +from contextlib import contextmanager +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +import torch +from torch import _C, _ops, Tensor +from torch.utils._exposed_in import exposed_in + +from . import autograd, utils + + +device_types_t = Optional[Union[str, Sequence[str]]] +log = logging.getLogger(__name__) + + +@exposed_in("torch.library") +def custom_op( + name: str, + fn: Optional[Callable] = None, + /, + *, + mutates_args: Union[str, Iterable[str]], + device_types: device_types_t = None, + schema: Optional[str] = None, +) -> Callable: + """Wraps a function into custom operator. 
+ + Reasons why you may want to create a custom op include: + - Wrapping a third-party library or custom kernel to work with PyTorch + subsystems like Autograd. + - Preventing torch.compile/export/FX tracing from peeking inside your function. + + This API is used as a decorator around a function (please see examples). + The provided function must have type hints; these are needed to interface + with PyTorch's various subsystems. + + Args: + name (str): A name for the custom op that looks like "{namespace}::{name}", + e.g. "mylib::my_linear". The name is used as the op's stable identifier + in PyTorch subsystems (e.g. torch.export, FX graphs). + To avoid name collisions, please use your project name as the namespace; + e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace. + mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates. + This MUST be accurate, otherwise, the behavior is undefined. If "unknown", + it pessimistically assumes that all inputs to the operator are being mutated. + device_types (None | str | Sequence[str]): The device type(s) the function + is valid for. If no device type is provided, then the function + is used as the default implementation for all device types. + Examples: "cpu", "cuda". + When registering a device-specific implementation for an operator that accepts no Tensors, + we require the operator to have a "device: torch.device argument". + schema (None | str): A schema string for the operator. If None + (recommended) we'll infer a schema for the operator from its type + annotations. We recommend letting us infer a schema unless you + have a specific reason not to. + Example: "(Tensor x, int y) -> (Tensor, Tensor)". + + .. note:: + We recommend not passing in a ``schema`` arg and instead letting us infer + it from the type annotations. It is error-prone to write your own schema. + You may wish to provide your own schema if our interpretation of + the type annotation is not what you want. + For more info on how to write a schema string, see + `here `_ + + Examples:: + >>> import torch + >>> from torch import Tensor + >>> from torch.library import custom_op + >>> import numpy as np + >>> + >>> @custom_op("mylib::numpy_sin", mutates_args=()) + >>> def numpy_sin(x: Tensor) -> Tensor: + >>> x_np = x.cpu().numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np).to(device=x.device) + >>> + >>> x = torch.randn(3) + >>> y = numpy_sin(x) + >>> assert torch.allclose(y, x.sin()) + >>> + >>> # Example of a custom op that only works for one device type. 
+ >>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu") + >>> def numpy_sin_cpu(x: Tensor) -> Tensor: + >>> x_np = x.numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np) + >>> + >>> x = torch.randn(3) + >>> y = numpy_sin_cpu(x) + >>> assert torch.allclose(y, x.sin()) + >>> + >>> # Example of a custom op that mutates an input + >>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu") + >>> def numpy_sin_inplace(x: Tensor) -> None: + >>> x_np = x.numpy() + >>> np.sin(x_np, out=x_np) + >>> + >>> x = torch.randn(3) + >>> expected = x.sin() + >>> numpy_sin_inplace(x) + >>> assert torch.allclose(x, expected) + >>> + >>> # Example of a factory function + >>> @torch.library.custom_op("mylib::bar", mutates_args={}, device_types="cpu") + >>> def bar(device: torch.device) -> Tensor: + >>> return torch.ones(3) + >>> + >>> bar("cpu") + + """ + + def inner(fn): + import torch + + if schema is None: + schema_str = torch.library.infer_schema(fn, mutates_args=mutates_args) + else: + schema_str = schema + + namespace, opname = name.split("::") + result = CustomOpDef(namespace, opname, schema_str, fn) + if schema is not None: + # Check that schema's alias annotations match those of `mutates_args`. + expected = set() + for arg in result._opoverload._schema.arguments: + if arg.alias_info is not None and arg.alias_info.is_write: + expected.add(arg.name) + if expected != set(mutates_args): + raise ValueError( + f"Attempted to create a custom op with `mutates_args={mutates_args}` " + f"and `schema={schema}. The schema suggests that the op mutates {expected}" + f"which is different from what was provided to us in `mutates_args`. " + f"Please make these consistent." + ) + result.register_kernel(device_types)(fn) + return result + + if fn is None: + return inner + return inner(fn) + + +class CustomOpDef: + """CustomOpDef is a wrapper around a function that turns it into a custom op. + + It has various methods for registering additional behavior for this + custom op. + + You should not instantiate CustomOpDef directly; instead, use the + :func:`torch.library.custom_op` API. + """ + + def __init__(self, namespace: str, name: str, schema: str, fn: Callable) -> None: + # Fields used to interface with the PyTorch dispatcher + self._namespace = namespace + self._name = name + self._schema = schema + + self._init_fn = fn + + self._backend_fns: Dict[Union[str, None], Callable] = {} + self._abstract_fn: Optional[Callable] = None + self._setup_context_fn: Optional[Callable] = None + self._backward_fn: Optional[Callable] = None + self._torch_dispatch_fns: Dict[type, Callable] = {} + self._vmap_fn: Optional[Callable] = None + + self._lib = get_library_allowing_overwrite(self._namespace, self._name) + self._register_to_dispatcher() + self._disabled_kernel: Set = set() + OPDEFS[self._qualname] = self + + @property + def _qualname(self) -> str: + return f"{self._namespace}::{self._name}" + + def __repr__(self) -> str: + return f"" + + @contextmanager + def set_kernel_enabled(self, device_type: str, enabled: bool = True): + """ + Disable or re-enable an already registered kernel for this custom operator. + + If the kernel is already disabled/enabled, this is a no-op. + + Note: + If a kernel is first disabled and then registered, it is disabled until enabled again. + + Args: + device_type (str): The device type to disable/enable the kernel for. + disable (bool): Whether to disable or enable the kernel. 
+ + Example: + >>> inp = torch.randn(1) + >>> + >>> # define custom op `f`. + >>> @custom_op("mylib::f", mutates_args=()) + >>> def f(x: Tensor) -> Tensor: + >>> return torch.zeros(1) + >>> + >>> print(f(inp)) # tensor([0.]), default kernel + >>> + >>> @f.register_kernel("cpu") + >>> def _(x): + >>> return torch.ones(1) + >>> + >>> print(f(inp)) # tensor([1.]), CPU kernel + >>> + >>> # temporarily disable the CPU kernel + >>> with f.set_kernel_enabled("cpu", enabled = False): + >>> print(f(inp)) # tensor([0.]) with CPU kernel disabled + + """ + action = "enable" if enabled else "disable" + originally_disabled = device_type in self._disabled_kernel + if device_type not in self._backend_fns: + log.warning( + "Attempted to %s kernel for %s but no kernel was registered for this device type.", + action, + device_type, + ) + + if not enabled: + if originally_disabled: + log.warning( + "Attempted to disable kernel for %s but it was already disabled.", + device_type, + ) + else: + self._disabled_kernel.add(device_type) + else: # enable the kernel + if not originally_disabled: + log.warning( + "Attempted to enable kernel for %s but it was already enabled.", + device_type, + ) + else: + self._disabled_kernel.remove(device_type) + + try: + yield + finally: + # restore original state + if originally_disabled: + self._disabled_kernel.add(device_type) + else: + self._disabled_kernel.discard(device_type) + + def register_kernel( + self, device_types: device_types_t, fn: Optional[Callable] = None, / + ) -> Callable: + """Register an implementation for a device type for this operator. + + Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu". + This API may be used as a decorator. + + Args: + fn (Callable): The function to register as the implementation for + the given device types. + device_types (str | Sequence[str]): The device device_types to register an impl to. + + Examples:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import torch + >>> from torch import Tensor + >>> from torch.library import custom_op + >>> import numpy as np + >>> + >>> # Create a custom op that works on cpu + >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu") + >>> def numpy_sin(x: Tensor) -> Tensor: + >>> x_np = x.numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np) + >>> + >>> # Add implementations for the cuda device + >>> @numpy_sin.register_kernel("cuda") + >>> def _(x): + >>> x_np = x.cpu().numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np).to(device=x.device) + >>> + >>> x_cpu = torch.randn(3) + >>> x_cuda = x_cpu.cuda() + >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin()) + >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin()) + + """ + + def inner(fn): + if device_types is None or isinstance(device_types, str): + dtypes: List[Union[str, None]] = [device_types] + else: + dtypes = list(device_types) + for device_type in dtypes: + if device_type not in self._backend_fns: + + def backend_impl(*args, **kwargs): + # Checks the assumption that outputs cannot alias + # inputs or other outputs. 
+ storages = { + id(tensor.untyped_storage()) + for tensor in iter_tensors(args, kwargs) + } + + result = self._backend_fns[device_type](*args, **kwargs) + + tuple_result = result + if not isinstance(result, tuple): + tuple_result = (result,) + for tensor in iter_tensors(tuple_result, {}): + key = id(tensor.untyped_storage()) + if id(tensor.untyped_storage()) in storages: + fn = self._backend_fns[device_type] + module = inspect.getmodule(fn) + raise RuntimeError( + f"{self._name} (with implementation in {module}): " + f"The output of this custom operator (1) must not " + f"also be an input to this custom operator and " + f"(2) may not alias any inputs to this custom operator " + f"or other returns. " + f"The most common way to trigger this error is if " + f"we have y = custom_op(x) and y and x are the same Tensor. " + f"Please instead return a clone of the offending output " + f"tensor(s) (e.g. return x.clone()) or refactor the custom " + f"operator to not return y." + ) + storages.add(key) + return result + + if device_type is None: + self._lib.impl( + self._name, backend_impl, "CompositeExplicitAutograd" + ) + else: + self._lib.impl( + self._name, + backend_impl, + _C._dispatch_key_for_device(device_type), + ) + + # Wrap function to choose between the default implementation or the device-specific + # implementation depending on if the kernel is disabled. + @torch._disable_dynamo + def wrapped_fn(*args, **kwargs): + if device_type in self._disabled_kernel: + return self._init_fn(*args, **kwargs) + else: + return fn(*args, **kwargs) + + self._backend_fns[device_type] = wrapped_fn + return fn + + if device_types is not None and not utils.has_tensor_arg( + self._opoverload._schema + ): + device_arg_index = utils.get_device_arg_index(self._opoverload._schema) + if device_arg_index is None: + raise ValueError( + "Functions without tensor inputs are required to have a `device: torch.device` argument" + ) + self._register_backend_select_dispatcher(device_arg_index) + + # See NOTE: [Supporting decorator and non-decorator usage] + if fn is None: + return inner + return inner(fn) + + def register_fake(self, fn: Callable, /) -> Callable: + r"""Register a FakeTensor implementation for this custom op. + + This is necessary to get the operator to work efficiently with torch.compile. + + The Fake impl (sometimes also known as a meta kernel or abstract impl) + specifies the behavior of this operator on Tensors that carry no data. + Given some input Tensors with certain properties + (sizes/strides/storage_offset/device), it specifies what the properties of + the output Tensors are. + + Please see :func:`torch.library.impl_abstract` for more details. + + Args: + fn (Callable): The function to register as the FakeTensor + implementation. 
+ + Examples: + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> # Example 1: an operator without data-dependent output shape + >>> @torch.library.custom_op("mylib::linear", mutates_args=()) + >>> def linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor: + >>> return (x @ weight.t()) + bias + >>> + >>> @linear.register_fake + >>> def _(x, weight, bias): + >>> assert x.dim() == 2 + >>> assert weight.dim() == 2 + >>> assert bias.dim() == 1 + >>> assert x.shape[1] == weight.shape[1] + >>> assert weight.shape[0] == bias.shape[0] + >>> assert x.device == weight.device + >>> return x.new_empty(x.size(0), weight.size(0)) + >>> + >>> x = torch.randn(2, 2) + >>> weight = torch.randn(2, 2) + >>> bias = torch.randn(2) + >>> # xdoctest: +SKIP("Requires Python <= 3.11") + >>> out = torch.compile(linear, fullgraph=True)(x, weight, bias) + >>> # xdoctest: +SKIP("Requires Python <= 3.11") + >>> assert torch.allclose(out, torch.nn.functional.linear(x, weight, bias)) + >>> + >>> # Example 2: an operator with data-dependent output shape + >>> @torch.library.custom_op("mylib::nonzero", mutates_args=()) + >>> def nonzero(x: Tensor) -> Tensor: + >>> x_np = x.cpu().numpy() + >>> res = np.stack(np.nonzero(x_np), axis=1) + >>> return torch.tensor(res, device=x.device) + >>> + >>> @nonzero.register_fake + >>> def _(x): + >>> # Number of nonzero-elements is data-dependent. + >>> # Since we cannot peek at the data in an abstract impl, + >>> # we use the ctx object to construct a new symint that + >>> # represents the data-dependent size. + >>> ctx = torch.library.get_ctx() + >>> nnz = ctx.new_dynamic_size() + >>> shape = [nnz, x.dim()] + >>> result = x.new_empty(shape, dtype=torch.int64) + >>> return result + >>> + >>> x = torch.tensor([0, 1, 2, 0, 0, 1]) + >>> # xdoctest: +SKIP("Requires Python <= 3.11") + >>> out = torch.compile(nonzero, fullgraph=True)(x) + >>> # xdoctest: +SKIP("Requires Python <= 3.11") + >>> assert torch.allclose(out, x.nonzero()) + + """ + self._abstract_fn = fn + return fn + + def register_torch_dispatch( + self, torch_dispatch_class: Any, fn: Optional[Callable] = None, / + ) -> Callable: + r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``. + + This allows for open registration to specify the behavior between the operator + and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class`` + or the operator directly. + + Please see :func:`torch.library.register_torch_dispatch` for examples and more details. + """ + + def register(fn): + if torch_dispatch_class not in self._torch_dispatch_fns: + + def inner(*args, **kwargs): + return self._torch_dispatch_fns[torch_dispatch_class]( + *args, **kwargs + ) + + self._lib._register_torch_dispatch_rule( + self._name, torch_dispatch_class, inner + ) + self._torch_dispatch_fns[torch_dispatch_class] = fn + return fn + + if fn is None: + return register + else: + return register(fn) + + def register_autograd( + self, + backward: Callable, + /, + *, + setup_context: Optional[Callable] = None, + ) -> None: + r"""Register a backward formula for this custom op. + + In order for an operator to work with autograd, you need to register + a backward formula: + 1. You must tell us how to compute gradients during the backward pass + by providing us a "backward" function. + 2. If you need any values from the forward to compute gradients, you can + use `setup_context` to save values for backward. + + ``backward_fn`` runs during the backward pass. 
It accepts ``(ctx, *grads)``: + - ``grads`` is one or more gradients. The number of gradients matches + the number of outputs of the operator. + The ``ctx`` object is `the same ctx object `_ used by + :class:`torch.autograd.Function`. The semantics of ``backward_fn`` are the + same as :meth:`torch.autograd.Function.backward`. + + ``setup_context(ctx, inputs, output)`` runs during the forward pass. + Please save quantities needed for backward onto the ``ctx`` object via + either :meth:`torch.autograd.function.FunctionCtx.save_for_backward` + or assigning them as attributes of ``ctx``. If your custom op has + kwarg-only arguments, we expect the signature of ``setup_context`` + to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``. + + Both ``setup_context_fn`` and ``backward_fn`` must be traceable. That is, + they may not directly access :meth:`torch.Tensor.data_ptr` and they must + not depend on or mutate global state. If you need a non-traceable backward, + you can make it a separate custom_op that you call inside ``backward_fn``. + + Examples: + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=()) + >>> def numpy_sin(x: Tensor) -> Tensor: + >>> x_np = x.cpu().numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np).to(device=x.device) + >>> + >>> def setup_context(ctx, inputs, output) -> Tensor: + >>> x, = inputs + >>> ctx.save_for_backward(x) + >>> + >>> def backward(ctx, grad): + >>> x, = ctx.saved_tensors + >>> return grad * x.cos() + >>> + >>> numpy_sin.register_autograd(backward, setup_context=setup_context) + >>> + >>> x = torch.randn(3, requires_grad=True) + >>> y = numpy_sin(x) + >>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y)) + >>> assert torch.allclose(grad_x, x.cos()) + >>> + >>> # Example with a keyword-only arg + >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) + >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor: + >>> x_np = x.cpu().numpy() + >>> y_np = x_np * val + >>> return torch.from_numpy(y_np).to(device=x.device) + >>> + >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor: + >>> ctx.val = keyword_only_inputs["val"] + >>> + >>> def backward(ctx, grad): + >>> return grad * ctx.val + >>> + >>> numpy_mul.register_autograd(backward, setup_context=setup_context) + >>> + >>> x = torch.randn(3, requires_grad=True) + >>> y = numpy_mul(x, val=3.14) + >>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y)) + >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14)) + + """ + schema = self._opoverload._schema + if not utils.is_functional_schema(schema): + raise RuntimeError( + f"Cannot register autograd formula for non-functional operator " + f"{self} with schema {schema}. Please create " + f"a functional operator and register an autograd formula for that." + ) + + self._backward_fn = backward + self._setup_context_fn = setup_context + + def _register_to_dispatcher(self) -> None: + lib = self._lib + schema_str = self._name + self._schema + cpp_schema = _C.parse_schema(schema_str) + if utils.has_kwarg_only_tensors(cpp_schema): + # If you want to support this, the progression is: + # - supporting kwarg-only Tensors that are non-differentiable + # - supporting kwarg-only Tensors (regardless of differentiability) + raise NotImplementedError( + f"custom_op with kwarg-only Tensor args. Please make your " + f"tensors not kwarg-only. 
Got: {schema_str}" + ) + + lib.define( + schema_str, + tags=[_C.Tag.pt2_compliant_tag, _C.Tag.needs_fixed_stride_order], + ) + self._opoverload = utils.lookup_op(self._qualname) + + def fake_impl(*args, **kwargs): + if self._abstract_fn is None: + if utils.can_generate_trivial_fake_impl(self._opoverload): + return None + raise RuntimeError( + f"There was no fake impl registered for {self}. " + f"This is necessary for torch.compile/export/fx tracing to work. " + f"Please use `{self._init_fn.__name__}.register_fake` to add an " + f"fake impl." + ) + return self._abstract_fn(*args, **kwargs) + + lib._register_fake(self._name, fake_impl, _stacklevel=4) + + autograd_impl = autograd.make_autograd_impl(self._opoverload, self) + lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True) + + schema = self._opoverload._schema + if schema.is_mutable: + + def adinplaceorview_impl(keyset, *args, **kwargs): + for arg, val in utils.zip_schema(schema, args, kwargs): + if not arg.alias_info: + continue + if not arg.alias_info.is_write: + continue + if isinstance(val, Tensor): + torch.autograd.graph.increment_version(val) + elif isinstance(val, (tuple, list)): + for v in val: + if isinstance(v, Tensor): + torch.autograd.graph.increment_version(v) + with _C._AutoDispatchBelowADInplaceOrView(): + return self._opoverload.redispatch( + keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs + ) + + lib.impl( + self._name, + adinplaceorview_impl, + "ADInplaceOrView", + with_keyset=True, + ) + + def _register_backend_select_dispatcher(self, device_arg_index: int): + """ + Switch on the device argument to select the correct backend to dispatch to. + """ + + def backend_select(keyset, *args, **kwargs): + device = args[device_arg_index].type + if device not in self._backend_fns: + raise RuntimeError( + f"{self._name} does not have a kernel registered for {device}. " + "Please use register_kernel to do so." + ) + dispatch_key = _C._dispatch_key_for_device(device) + dispatch_key = getattr(_C.DispatchKey, dispatch_key) + return self._opoverload.redispatch( + _C.DispatchKeySet(dispatch_key), *args, **kwargs + ) + + self._lib.impl(self._name, backend_select, "BackendSelect", with_keyset=True) + + def __call__(self, *args, **kwargs): + return self._opoverload(*args, **kwargs) + + def register_vmap( + self, + func: Optional[Callable] = None, + ): + r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op. + + This API may be used as a decorator. + + In order for an operator to work with :func:`torch.vmap`, you may need to register a + vmap implementation in the following signature: + + ``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``, + + where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``. + + It specifies how do we compute the batched version of ``op`` given inputs with an additional + dimension (specified by ``in_dims``). + + For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None`` + if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer + specifying what dimension of the Tensor is being vmapped over. + + ``info`` is a collection of additional metadata that may be helpful: + ``info.batch_size`` specifies the size of the dimension being vmapped over, while + ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`. + + The return of the function ``func`` is a tuple of ``(output, out_dims)``. 
Similar to ``in_dims``, + ``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim`` + per output that specifies if the output has the vmapped dimension and what index it is in. + + Examples: + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> from typing import Tuple + >>> + >>> def to_numpy(tensor): + >>> return tensor.cpu().numpy() + >>> + >>> lib = torch.library.Library("mylib", "FRAGMENT") + >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=()) + >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]: + >>> x_np = to_numpy(x) + >>> dx = torch.tensor(3 * x_np ** 2, device=x.device) + >>> return torch.tensor(x_np ** 3, device=x.device), dx + >>> + >>> def numpy_cube_vmap(info, in_dims, x): + >>> result = numpy_cube(x) + >>> return result, (in_dims[0], in_dims[0]) + >>> + >>> numpy_cube.register_vmap(numpy_cube_vmap) + >>> + >>> x = torch.randn(3) + >>> torch.vmap(numpy_cube)(x) + >>> + >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) + >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor: + >>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) + >>> + >>> @numpy_mul.register_vmap + >>> def numpy_mul_vmap(info, in_dims, x, y): + >>> x_bdim, y_bdim = in_dims + >>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + >>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) + >>> result = x * y + >>> result = result.movedim(-1, 0) + >>> return result, 0 + >>> + >>> + >>> x = torch.randn(3) + >>> y = torch.randn(3) + >>> torch.vmap(numpy_mul)(x, y) + """ + from torch._functorch.autograd_function import custom_function_call_vmap_helper + from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter + + def register(func): + need_register = self._vmap_fn is None + self._vmap_fn = func + + if need_register: + + def wrapped_func(keyset, *args, **kwargs): + interpreter = retrieve_current_functorch_interpreter() + return custom_function_call_vmap_helper( + interpreter, self._vmap_fn, self._opoverload, *args, **kwargs + ) + + self._lib.impl( + self._name, wrapped_func, "FuncTorchBatched", with_keyset=True + ) + + if func is None: + return register + else: + return register(func) + + +# NOTE: [Supporting decorator and non-decorator usage] +# +# Some APIs may be both used as a decorator and not as a decorator. +# For example: +# +# >>> def fn(x): +# >>> return x.sin() +# >>> +# >>> # Usage 1: not as a decorator +# >>> numpy_sin.register_kernel("cuda", fn) +# >>> +# >>> # Usage 2: as a decorator +# >>> @numpy_sin.register_kernel("cuda") +# >>> def fn2(x): +# >>> return x.sin +# +# The way we support this is that `register_kernel` accepts an optional `fn`. +# If `fn` is provided (Usage 1), then we know that the user is using it not +# as a decorator. +# If `fn` is not provided (Usage 2), then `register_kernel` needs to return a +# decorator. 
+ + +OPDEF_TO_LIB: Dict[str, "torch.library.Library"] = {} +OPDEFS: weakref.WeakValueDictionary = weakref.WeakValueDictionary() + + +def get_library_allowing_overwrite( + namespace: str, name: str +) -> "torch.library.Library": + qualname = f"{namespace}::{name}" + + if qualname in OPDEF_TO_LIB: + OPDEF_TO_LIB[qualname]._destroy() + del OPDEF_TO_LIB[qualname] + + lib = torch.library.Library(namespace, "FRAGMENT") # noqa: TOR901 + OPDEF_TO_LIB[qualname] = lib + return lib + + +def iter_tensors( + args: Tuple[Any], kwargs: Dict[str, Any], allowed_nesting: int = 1 +) -> Iterator[Tensor]: + def check(arg): + if isinstance(arg, Tensor): + yield arg + elif allowed_nesting > 0 and isinstance(arg, (tuple, list)): + yield from iter_tensors(tuple(arg), {}, allowed_nesting - 1) + + for arg in args: + yield from check(arg) + for kwarg in kwargs.values(): + yield from check(kwarg) + + +def _maybe_get_opdef( + op: Union[CustomOpDef, _ops.OpOverload, str] +) -> Optional[CustomOpDef]: + if isinstance(op, CustomOpDef): + return op + if isinstance(op, _ops.OpOverload): + op = op._name + assert isinstance(op, str) + if op in OPDEFS: + return OPDEFS[op] + return None diff --git a/lib/python3.10/site-packages/torch/_library/fake_class_registry.py b/lib/python3.10/site-packages/torch/_library/fake_class_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..3c2689f5328f55dfbf37f1a2f6584aa328c6ce83 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_library/fake_class_registry.py @@ -0,0 +1,320 @@ +# mypy: allow-untyped-defs +import logging +from typing import Any, Dict, Optional, Protocol, Tuple, Union + +import torch +from torch._library.utils import parse_namespace + + +log = logging.getLogger(__name__) + + +class FakeScriptObject: + def __init__(self, wrapped_obj: Any, script_class_name: str, x: torch.ScriptObject): + self.wrapped_obj = wrapped_obj + + # The fully qualified name of the class of original script object + self.script_class_name = script_class_name + self.real_obj = x + + +class FakeScriptMethod: + def __init__( + self, + self_fake_obj: FakeScriptObject, + method_name: str, + schema: Optional[torch.FunctionSchema], + ): + self.self_fake_obj = self_fake_obj + self.method_name = method_name + self.schema = schema + + def __call__(self, *args, **kwargs): + from torch._higher_order_ops.torchbind import call_torchbind + + return call_torchbind(self.self_fake_obj, self.method_name, *args, **kwargs) + + +class HasStaticMethodFromReal(Protocol): + @classmethod + def from_real(cls, real_obj: torch.ScriptObject): + pass + + +class FakeClassRegistry: + def __init__(self) -> None: + self._registered_class: Dict[str, Any] = {} + + def has_impl(self, full_qualname: str) -> bool: + return full_qualname in self._registered_class + + def get_impl(self, full_qualname: str) -> Any: + self._check_registered(full_qualname) + return self._registered_class[full_qualname] + + def register(self, full_qualname: str, fake_class=None) -> None: + if self.has_impl(full_qualname): + log.warning( + "%s is already registered. Previous fake class is overridden with %s.", + full_qualname, + fake_class, + ) + self._registered_class[full_qualname] = fake_class + + def deregister(self, full_qualname: str) -> Any: + if not self.has_impl(full_qualname): + log.warning( + "Cannot deregister %s. Please use register_fake_class to register it first." 
+ " Or do you dereigster it twice?", + full_qualname, + ) + else: + return self._registered_class.pop(full_qualname) + + def clear(self) -> None: + self._registered_class.clear() + + def _check_registered(self, full_qualname: str) -> None: + if full_qualname not in self._registered_class: + raise RuntimeError( + f"{full_qualname} is not registered. Please use register_fake_class to register it first." + ) + + +global_fake_class_registry = FakeClassRegistry() + + +# TODO: add this check at compile time for __obj_flatten__. +def _check_valid_flat_script_obj(flat_x): + if not isinstance(flat_x, tuple): + raise RuntimeError("Expect flat x to be a tuple.") + + for tp in flat_x: + if not isinstance(tp, tuple): + raise RuntimeError("Expect flat x to be a tuple of tuples.") + + if not len(tp) == 2 or not isinstance(tp[0], str): + raise RuntimeError( + "Expect element of flat x to be a tuple of two elements with first element being a string" + ) + + +def tracing_with_real(x: torch.ScriptObject) -> bool: + if not hasattr(x, "tracing_mode"): + return False + + assert x.tracing_mode() in [ + "real", + "fake", + ], f"tracing_mode can be either real or fake but got {x.tracing_mode()}" + return x.tracing_mode() == "real" + + +def maybe_to_fake_obj( + fake_mode, x: torch.ScriptObject +) -> Union[FakeScriptObject, torch.ScriptObject]: + import torch.utils._pytree as pytree + from torch.utils._python_dispatch import _disable_current_modes + + # When tracing with real mode, people should implement meta kernels that can + # handle the case of real script object + fake tensor inputs. + if tracing_with_real(x): + return x + + # x.__obj_flatten__() could be calling some tensor operations inside but we don't + # want to call these ops in surrounding dispatch modes when executing it. + # Otherwise, for example, the fake tensor modes will error out when the tensors inside + # script obeject execute some operations like clone if allow_non_fake_input flag is set. + with _disable_current_modes(): + flat_x = x.__obj_flatten__() # type: ignore[attr-defined] + + _check_valid_flat_script_obj(flat_x) + + fake_flattened = pytree.tree_map_only( + torch.Tensor, + lambda t: fake_mode.from_tensor(t), + flat_x, + ) + + fake_x = _find_fake_class_for_script_object(x).__obj_unflatten__(fake_flattened) + + fake_x_wrapped = FakeScriptObject(fake_x, x._type().qualified_name(), x) # type: ignore[attr-defined] + + for name in x._method_names(): # type: ignore[attr-defined] + attr = getattr(fake_x, name, None) + if attr: + if not callable(attr): + raise RuntimeError(f"Expect {name} to be a callable but got {attr}.") + + real_attr = getattr(x, name) # type: ignore[attr-defined] + + # real attr sometimes is not torch.ScriptMethod thus doesn't have schema e.g. __init___ or __eq__ + method_schema: Optional[torch.FunctionSchema] = None + if isinstance(real_attr, torch.ScriptMethod): + method_schema = real_attr.schema # type: ignore[attr-defined] + + setattr( + fake_x_wrapped, + name, + FakeScriptMethod(fake_x_wrapped, name, method_schema), + ) + else: + override_skip_list = {"__obj_flatten__", "__get_state__", "__set_state__"} + if name not in override_skip_list: + log.warning("fake object of %s doesn't implement method %s.", x, name) + return fake_x_wrapped + + +def register_fake_class(qualname, fake_class: Optional[HasStaticMethodFromReal] = None): + r"""Register a fake implementation for this class. 
+ + It's in the same spirit as registering a fake implementation for + an operator but with the difference that it + associates a fake class with the original torch bind class (registered + with torch::class_). In this way, torch.compile can handle them properly + in components such as Dynamo and AOTAutograd. + + This API may be used as a decorator (see example). For the fake class, users + are required to provide a from_real classmethod that takes a real object and + returns an instance of the fake class. All tensors in the fake object should also + be properly fakified with to_fake_tensor() in from_real. + + + Examples: + # For a custom class Foo defined in test_custom_class_registration.cpp: + + TORCH_LIBRARY(_TorchScriptTesting, m) { + m.class_("_TensorQueue") + .def(torch::init()) + .def("push", &TensorQueue::push) + .def("pop", &TensorQueue::pop) + .def("top", &TensorQueue::top) + .def("size", &TensorQueue::size) + .def("clone_queue", &TensorQueue::clone_queue) + .def("__obj_flatten__", &TensorQueue::__obj_flatten__) + .def_pickle( + // __getstate__ + [](const c10::intrusive_ptr& self) + -> c10::Dict { + return self->serialize(); + }, + // __setstate__ + [](c10::Dict data) + -> c10::intrusive_ptr { + return c10::make_intrusive(std::move(data)); + }); + }; + # We could register a fake class FakeTensorQueue in Python as follows: + import torch + + @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue") + class FakeTensorQueue: + def __init__(self, queue): + self.queue = queue + + @classmethod + def __obj_unflatten__(cls, flattened_ctx): + return cls(**dict(flattened_ctx)) + + def push(self, x): + self.queue.append(x) + + def pop(self): + return self.queue.pop(0) + + def size(self): + return len(self.queue) + + In this example, the original TensorQueue needs to add a __obj_flatten__ method + to the class TensorQueue, and the flattened result is passed into FakeTensorQueue's + __obj_unflatten__ as inputs to create a fake class. This protocol allows PyTorch to look + at the contents of the script object and properly handle them in subsystems + like Dynamo, AOTAutograd, and more. + """ + + def inner(fake_class: HasStaticMethodFromReal): + ns, name = parse_namespace(qualname) + + # This also checks whether the referenced torch::class_ exists. + torchbind_class = torch._C._get_custom_class_python_wrapper(ns, name) + + from_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None) + if not from_method: + raise RuntimeError( + f"{fake_class} doesn't define a classmethod {_CONVERT_FROM_REAL_NAME}." + ) + + if not isinstance(fake_class.__dict__[_CONVERT_FROM_REAL_NAME], classmethod): + raise RuntimeError( + f"{_CONVERT_FROM_REAL_NAME} method is not a classmethod." + ) + + global_fake_class_registry.register(_full_qual_class_name(qualname), fake_class) + return fake_class + + if fake_class is None: + return inner + return inner(fake_class) + + +def deregister_fake_class(qualname): + return global_fake_class_registry.deregister(_full_qual_class_name(qualname)) + + +def has_fake_class(full_qualname) -> bool: + return global_fake_class_registry.has_impl(full_qualname) + + +def find_fake_class(full_qualname) -> Optional[Any]: + if not has_fake_class(full_qualname): + return None + return global_fake_class_registry.get_impl(full_qualname) + + +def _full_qual_class_name(qualname: str) -> str: + ns, name = parse_namespace(qualname) + return "__torch__.torch.classes." + ns + "." + name + + +# Return the namespace and class name from a fully qualified name.
+def _ns_and_class_name(full_qualname: str) -> Tuple[str, str]: + splits = full_qualname.split(".") + assert len(splits) == 5 + _torch, torch_ns, classes, ns, class_name = splits + return ns, class_name + + +def _find_fake_class_for_script_object(x: torch.ScriptObject) -> Any: + full_qualname = x._type().qualified_name() # type: ignore[attr-defined] + ns, class_name = _ns_and_class_name(full_qualname) + fake_class = find_fake_class(full_qualname) + if fake_class is None: + raise RuntimeError( + f" ScriptObject's {full_qualname} haven't registered a fake class." + f" Please use register_fake_class({ns}::{class_name}) to annotate a fake class for the script obj." + f" Specifically, create a python class that implements a fake version for all the methods" + f" that're used in the program and put annotated class in the program e.g. after loading the library." + f" The fake methods can be written in the same way as a meta kernel for an operator but need to additionally" + f" simulate the object's states. Be sure to add a {_CONVERT_FROM_REAL_NAME} classmethod" + f" to enable creating a fake obj from a real one." + ) + return fake_class + + +_CONVERT_FROM_REAL_NAME = "__obj_unflatten__" + + +def _fake_obj_from_real(fake_mode, x) -> Any: + fake_class = _find_fake_class_for_script_object(x) + + from_real_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None) + if not from_real_method: + raise RuntimeError( + f"{fake_class} must define a classmethod {_CONVERT_FROM_REAL_NAME}" + f" that converts the real object to the fake object." + ) + + # from_real defined by user need the ctx to fakify the tensor states. + ctx = torch._library.fake_impl.FakeImplCtx(fake_mode, None) + with torch._library.fake_impl.set_ctx_getter(lambda: ctx): + return fake_class.from_real(x) diff --git a/lib/python3.10/site-packages/torch/_library/fake_impl.py b/lib/python3.10/site-packages/torch/_library/fake_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..a972e8da89eb796db7bdd4e48043cff0cab38bcc --- /dev/null +++ b/lib/python3.10/site-packages/torch/_library/fake_impl.py @@ -0,0 +1,207 @@ +# mypy: allow-untyped-defs +import contextlib +import functools +from typing import Callable, Optional +from typing_extensions import deprecated + +import torch +from torch._library.utils import Kernel, RegistrationHandle + + +class FakeImplHolder: + """A holder where one can register an fake impl to.""" + + def __init__(self, qualname: str): + self.qualname: str = qualname + self.kernel: Optional[Kernel] = None + self.lib: Optional[torch.library.Library] = None + + def register(self, func: Callable, source: str) -> RegistrationHandle: + """Register an fake impl. + + Returns a RegistrationHandle that one can use to de-register this + fake impl. + """ + if self.kernel is not None: + raise RuntimeError( + f"register_fake(...): the operator {self.qualname} " + f"already has an fake impl registered at " + f"{self.kernel.source}." + ) + if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"): + raise RuntimeError( + f"register_fake(...): the operator {self.qualname} " + f"already has an DispatchKey::Meta implementation via a " + f"pre-existing torch.library or TORCH_LIBRARY registration. " + f"Please either remove that registration or don't call " + f"register_fake." 
+ ) + + if torch._C._dispatch_has_kernel_for_dispatch_key( + self.qualname, "CompositeImplicitAutograd" + ): + raise RuntimeError( + f"register_fake(...): the operator {self.qualname} " + f"already has an implementation for this device type via a " + f"pre-existing registration to " + f"DispatchKey::CompositeImplicitAutograd." + f"CompositeImplicitAutograd operators do not need an fake " + f"impl; " + f"instead, the operator will decompose into its constituents " + f"and those " + f"can have fake impls defined on them." + ) + + # Store the kernel in this holder + self.kernel = Kernel(func, source) + + # Also register the fake impl to Meta key + if self.lib is None: + ns = self.qualname.split("::")[0] + self.lib = torch.library.Library(ns, "FRAGMENT") # noqa: TOR901 + meta_kernel = construct_meta_kernel(self.qualname, self) + self.lib.impl(self.qualname, meta_kernel, "Meta") + + def deregister_fake_class(): + if self.lib: + self.lib._destroy() + self.lib = None + self.kernel = None + + return RegistrationHandle(deregister_fake_class) + + +def construct_meta_kernel(qualname: str, fake_impl_holder: FakeImplHolder) -> Callable: + assert fake_impl_holder.kernel is not None + + @functools.wraps(fake_impl_holder.kernel.func) + def meta_kernel(*args, **kwargs): + assert fake_impl_holder.kernel is not None + source = fake_impl_holder.kernel.source + + def error_on_ctx(): + raise RuntimeError( + f"Attempted to call get_ctx() for the meta implementation " + f"for {qualname} (implemented at {source})" + f"You have presumably called get_ctx() because the operator " + f"has a data-dependent output shape; if so, there is no " + f"such meta implementation and this error is the correct " + f"behavior." + ) + + with set_ctx_getter(error_on_ctx): + return fake_impl_holder.kernel(*args, **kwargs) + + return meta_kernel + + +def get_none(): + return None + + +global_ctx_getter: Callable = get_none + + +@contextlib.contextmanager +def set_ctx_getter(ctx_getter): + global global_ctx_getter + prev = global_ctx_getter + try: + global_ctx_getter = ctx_getter + yield + finally: + global_ctx_getter = prev + + +class FakeImplCtx: + """ + Context object for writing fake implementations for custom operators. + """ + + def __init__(self, _fake_mode, _op): + self._fake_mode = _fake_mode + self._shape_env = _fake_mode.shape_env + self._op = _op + + @deprecated( + "`create_unbacked_symint` is deprecated, please use `new_dynamic_size` instead", + category=FutureWarning, + ) + def create_unbacked_symint(self, *, min=2, max=None) -> torch.SymInt: + return self.new_dynamic_size(min=min, max=max) + + def new_dynamic_size(self, *, min=0, max=None) -> torch.SymInt: + """Constructs a new symint (symbolic int) representing a data-dependent value. + + This is useful for writing the fake implementation (which is necessary + for torch.compile) for a CustomOp where an output Tensor has a size + that depends on the data of the input Tensors. + + Args: + min (int): A statically known inclusive lower bound for this symint. Default: 0 + max (Optional[int]): A statically known inclusive upper bound for this + symint. Default: None + + .. warning: + + It is important that the ``min`` and ``max`` (if not None) values are set + correctly, otherwise, there will be undefined behavior under + torch.compile. The default value of ``min`` is 2 due to torch.compile + specializing on 0/1 sizes. + + You must also verify that your implementation on concrete Tensors + (e.g. 
CPU/CUDA) only returns Tensors where the size that corresponds + to the symint also has respects these constraint. + The easiest way to do this is to add an assertion in the CPU/CUDA/etc + implementation that the size follows these bounds. + + Example:: + + >>> # An operator with data-dependent output shape + >>> lib = torch.library.Library("mymodule", "FRAGMENT") + >>> lib.define("mymodule::custom_nonzero(Tensor x) -> Tensor") + >>> + >>> @torch.library.register_fake("mymodule::custom_nonzero") + >>> def _(x): + >>> # Number of nonzero-elements is data-dependent. + >>> # Since we cannot peek at the data in an fake impl, + >>> # we use the ctx object to construct a new symint that + >>> # represents the data-dependent size. + >>> ctx = torch.library.get_ctx() + >>> nnz = ctx.new_dynamic_size() + >>> shape = [nnz, x.dim()] + >>> result = x.new_empty(shape, dtype=torch.int64) + >>> return result + >>> + >>> @torch.library.impl(lib, "custom_nonzero", "CPU") + >>> def _(x): + >>> x_np = x.numpy() + >>> res = np.stack(np.nonzero(x_np), axis=1) + >>> return torch.tensor(res, device=x.device) + + """ + if ( + self._shape_env is None + or not self._shape_env.allow_dynamic_output_shape_ops + ): + raise torch._subclasses.fake_tensor.DynamicOutputShapeException(self._op) + + if isinstance(min, torch.SymInt) or isinstance(max, torch.SymInt): + raise ValueError( + f"ctx.new_dynamic_size(min={min}, max={max}): expected " + f"min and max to be statically known ints but got SymInt. " + f"This is not supported." + ) + + if min < 0: + raise ValueError( + f"ctx.new_dynamic_size(min={min}, ...): expected min to be " + f"greater than or equal to 0: this API can only create " + f"non-negative sizes." + ) + + result = self._shape_env.create_unbacked_symint() + torch.fx.experimental.symbolic_shapes._constrain_range_for_size( + result, min=min, max=max + ) + return result diff --git a/lib/python3.10/site-packages/torch/_library/infer_schema.py b/lib/python3.10/site-packages/torch/_library/infer_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..b2eeb24521d382516830cd4ebe289672eecfeae8 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_library/infer_schema.py @@ -0,0 +1,271 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import inspect +import typing +from typing import List, Optional, Sequence, Union # noqa: F401 + +import torch +from torch import device, dtype, Tensor, types +from torch.utils._exposed_in import exposed_in + + +@exposed_in("torch.library") +def infer_schema( + prototype_function: typing.Callable, + /, + *, + mutates_args, + op_name: Optional[str] = None, +) -> str: + r"""Parses the schema of a given function with type hints. The schema is inferred from the + function's type hints, and can be used to define a new operator. + + We make the following assumptions: + + * None of the outputs alias any of the inputs or each other. + * | String type annotations "device, dtype, Tensor, types" without library specification are + | assumed to be torch.*. Similarly, string type annotations "Optional, List, Sequence, Union" + | without library specification are assumed to be typing.*. + * | Only the args listed in ``mutates_args`` are being mutated. If ``mutates_args`` is "unknown", + | it assumes that all inputs to the operator are being mutates. + + Callers (e.g. the custom ops API) are responsible for checking these assumptions. + + Args: + prototype_function: The function from which to infer a schema for from its type annotations. 
+ op_name (Optional[str]): The name of the operator in the schema. If ``name`` is None, then the + name is not included in the inferred schema. Note that the input schema to + ``torch.library.Library.define`` requires a operator name. + mutates_args ("unknown" | Iterable[str]): The arguments that are mutated in the function. + + Returns: + The inferred schema. + + Example: + >>> def foo_impl(x: torch.Tensor) -> torch.Tensor: + >>> return x.sin() + >>> + >>> infer_schema(foo_impl, op_name="foo", mutates_args={}) + foo(Tensor x) -> Tensor + >>> + >>> infer_schema(foo_impl, mutates_args={}) + (Tensor x) -> Tensor + """ + UNKNOWN_MUTATES = "unknown" + sig = inspect.signature(prototype_function) + + def error_fn(what): + raise ValueError( + f"infer_schema(func): {what} " f"Got func with signature {sig})" + ) + + def convert_type_string(annotation_type: str): + try: + return eval(annotation_type) + except Exception as e: + error_fn( + f"Unsupported type annotation {annotation_type}. It is not a type." + ) + + params = [] + seen_args = set() + saw_kwarg_only_arg = False + for idx, (name, param) in enumerate(sig.parameters.items()): + if not supported_param(param): + error_fn("We do not support positional-only args, varargs, or varkwargs.") + + if param.kind == inspect.Parameter.KEYWORD_ONLY: + # The first time we see a kwarg-only arg, add "*" to the schema. + if not saw_kwarg_only_arg: + params.append("*") + saw_kwarg_only_arg = True + + if param.annotation is inspect.Parameter.empty: + error_fn(f"Parameter {name} must have a type annotation.") + + # The annotation might be converted to a string by annotation, + # we convert it to the actual type. + annotation_type = param.annotation + if type(annotation_type) == str: + annotation_type = convert_type_string(annotation_type) + + if annotation_type not in SUPPORTED_PARAM_TYPES.keys(): + if annotation_type.__origin__ is tuple: + list_type = tuple_to_list(annotation_type) + example_type_str = "\n\n" + # Only suggest the list type if this type is supported. + if list_type in SUPPORTED_PARAM_TYPES.keys(): + example_type_str = f"For example, {list_type}.\n\n" + error_fn( + f"Parameter {name} has unsupported type {param.annotation}. " + f"We do not support Tuple inputs in schema. As a workaround, please try to use List instead. " + f"{example_type_str}" + f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}." + ) + else: + error_fn( + f"Parameter {name} has unsupported type {param.annotation}. " + f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}." + ) + + schema_type = SUPPORTED_PARAM_TYPES[annotation_type] + if type(mutates_args) == str: + if mutates_args != UNKNOWN_MUTATES: + raise ValueError( + "mutates_args must either be a sequence of the names of " + "the arguments that are mutated or the string 'unknown'. 
" + ) + if schema_type.startswith("Tensor"): + schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}" + elif name in mutates_args: + if not schema_type.startswith("Tensor"): + error_fn( + f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated" + ) + schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}" + seen_args.add(name) + if param.default is inspect.Parameter.empty: + params.append(f"{schema_type} {name}") + else: + default_repr = None + if param.default is None or isinstance(param.default, (int, float, bool)): + default_repr = str(param.default) + elif isinstance(param.default, (str, torch.device)): + default_repr = f'"{param.default}"' + elif isinstance(param.default, torch.dtype): + dtype_repr = str(param.default) + torch_dot = "torch." + assert dtype_repr.startswith(torch_dot) + default_repr = dtype_repr[len(torch_dot) :] + else: + error_fn( + f"Parameter {name} has an unsupported default value type {type(param.default)}. " + f"Please file an issue on GitHub so we can prioritize this." + ) + params.append(f"{schema_type} {name}={default_repr}") + if mutates_args != UNKNOWN_MUTATES: + mutates_args_not_seen = set(mutates_args) - seen_args + if len(mutates_args_not_seen) > 0: + error_fn( + f"{mutates_args_not_seen} in mutates_args were not found in " + f"the custom op's signature. " + f"mutates_args should contain the names of all args that the " + f"custom op mutates, or just the string 'unknown' if you don't know." + ) + return_annotation = sig.return_annotation + if type(return_annotation) == str: + return_annotation = convert_type_string(return_annotation) + ret = parse_return(return_annotation, error_fn) + if op_name is not None: + return f"{op_name}({', '.join(params)}) -> {ret}" + return f"({', '.join(params)}) -> {ret}" + + +def derived_types( + base_type, cpp_type, list_base, optional_base_list, optional_list_base +): + result = [ + (base_type, cpp_type), + (typing.Optional[base_type], f"{cpp_type}?"), + ] + + def derived_seq_types(typ): + return [ + typing.Sequence[typ], # type: ignore[valid-type] + typing.List[typ], # type: ignore[valid-type] + ] + + if list_base: + for seq_typ in derived_seq_types(base_type): + result.append((seq_typ, f"{cpp_type}[]")) # type: ignore[valid-type] + if optional_base_list: + for seq_typ in derived_seq_types(typing.Optional[base_type]): + result.append((seq_typ, f"{cpp_type}?[]")) # type: ignore[valid-type] + if optional_list_base: + for seq_typ in derived_seq_types(base_type): # type: ignore[valid-type] + result.append((typing.Optional[seq_typ], f"{cpp_type}[]?")) # type: ignore[valid-type] + return result + + +def get_supported_param_types(): + data = [ + # (python type, schema type, type[] variant, type?[] variant, type[]? 
variant + (Tensor, "Tensor", True, True, False), + (int, "SymInt", True, False, True), + (float, "float", True, False, True), + (bool, "bool", True, False, True), + (str, "str", False, False, False), + (types.Number, "Scalar", True, False, False), + (dtype, "ScalarType", False, False, False), + (device, "Device", False, False, False), + ] + result = [] + for line in data: + result.extend(derived_types(*line)) + return dict(result) + + +SUPPORTED_RETURN_TYPES = { + Tensor: "Tensor", + typing.List[Tensor]: "Tensor[]", + int: "SymInt", + float: "float", + bool: "bool", + types.Number: "Scalar", +} + + +def parse_return(annotation, error_fn): + if annotation is None: + return "()" + + if annotation is inspect.Parameter.empty: + error_fn("No return type annotation was provided. Please add one.") + + origin = typing.get_origin(annotation) + if origin is not tuple: + if annotation not in SUPPORTED_RETURN_TYPES.keys(): + error_fn( + f"Return has unsupported type {annotation}. " + f"The valid types are: {SUPPORTED_RETURN_TYPES}." + ) + return SUPPORTED_RETURN_TYPES[annotation] + + args = typing.get_args(annotation) + for arg in args: + if arg not in SUPPORTED_RETURN_TYPES: + error_fn( + f"Return has unsupported type {annotation}. " + f"The valid types are: {SUPPORTED_RETURN_TYPES}." + ) + + return "(" + ", ".join([SUPPORTED_RETURN_TYPES[arg] for arg in args]) + ")" + + +SUPPORTED_PARAM_TYPES = get_supported_param_types() + + +def supported_param(param: inspect.Parameter) -> bool: + return param.kind in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + + +def tuple_to_list(tuple_type: typing.Type[typing.Tuple]) -> typing.Type[typing.List]: + """ + Convert `tuple_type` into a list type with the same type arguments. Assumes that `tuple_type` is typing.Tuple type. + """ + type_args = getattr(tuple_type, "__args__", None) + # Account for different python versions, e.g. python 3.8 would give () + # but python 3.12 would give None. + if tuple_type is typing.Tuple or type_args == () or type_args is None: + # Handle the case of an empty tuple type + return typing.List + elif len(type_args) == 1: + # General case: create a List with the same type arguments + return typing.List[type_args[0]] # type: ignore[valid-type] + elif len(type_args) == 2 and type_args[1] is Ellipsis: # type: ignore[valid-type] + return typing.List[type_args[0]] # type: ignore[valid-type] + else: + return typing.List[typing.Union[tuple(type_args)]] # type: ignore[misc] diff --git a/lib/python3.10/site-packages/torch/_library/simple_registry.py b/lib/python3.10/site-packages/torch/_library/simple_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..cfef278679ea56f4d5d15589467ffd47c3edaef4 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_library/simple_registry.py @@ -0,0 +1,85 @@ +# mypy: allow-untyped-defs +from typing import Callable, Optional + +from .fake_impl import FakeImplHolder +from .utils import RegistrationHandle + + +__all__ = ["SimpleLibraryRegistry", "SimpleOperatorEntry", "singleton"] + + +class SimpleLibraryRegistry: + """Registry for the "simple" torch.library APIs + + The "simple" torch.library APIs are a higher-level API on top of the + raw PyTorch DispatchKey registration APIs that includes: + - fake impl + + Registrations for these APIs do not go into the PyTorch dispatcher's + table because they may not directly involve a DispatchKey. For example, + the fake impl is a Python function that gets invoked by FakeTensor. + Instead, we manage them here. 
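+
+    A minimal lookup sketch (illustrative only; the operator name below is
+    hypothetical, and ``find`` lazily creates the entry on first use):
+
+    >>> # xdoctest: +SKIP
+    >>> entry = singleton.find("mylib::my_op.default")
+    >>> entry.fake_impl  # the FakeImplHolder that fake impls register into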
+ + SimpleLibraryRegistry is a mapping from a fully qualified operator name + (including the overload) to SimpleOperatorEntry. + """ + + def __init__(self): + self._data = {} + + def find(self, qualname: str) -> "SimpleOperatorEntry": + if qualname not in self._data: + self._data[qualname] = SimpleOperatorEntry(qualname) + return self._data[qualname] + + +singleton: SimpleLibraryRegistry = SimpleLibraryRegistry() + + +class SimpleOperatorEntry: + """This is 1:1 to an operator overload. + + The fields of SimpleOperatorEntry are Holders where kernels can be + registered to. + """ + + def __init__(self, qualname: str): + self.qualname: str = qualname + self.fake_impl: FakeImplHolder = FakeImplHolder(qualname) + self.torch_dispatch_rules: GenericTorchDispatchRuleHolder = ( + GenericTorchDispatchRuleHolder(qualname) + ) + + # For compatibility reasons. We can delete this soon. + @property + def abstract_impl(self): + return self.fake_impl + + +class GenericTorchDispatchRuleHolder: + def __init__(self, qualname): + self._data = {} + self.qualname = qualname + + def register( + self, torch_dispatch_class: type, func: Callable + ) -> RegistrationHandle: + if self.find(torch_dispatch_class): + raise RuntimeError( + f"{torch_dispatch_class} already has a `__torch_dispatch__` rule registered for {self.qualname}" + ) + self._data[torch_dispatch_class] = func + + def deregister(): + del self._data[torch_dispatch_class] + + return RegistrationHandle(deregister) + + def find(self, torch_dispatch_class): + return self._data.get(torch_dispatch_class, None) + + +def find_torch_dispatch_rule(op, torch_dispatch_class: type) -> Optional[Callable]: + return singleton.find(op.__qualname__).torch_dispatch_rules.find( + torch_dispatch_class + ) diff --git a/lib/python3.10/site-packages/torch/_library/triton.py b/lib/python3.10/site-packages/torch/_library/triton.py new file mode 100644 index 0000000000000000000000000000000000000000..d2caa4924529ef67b3ed7729129fc38e6120c0f9 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_library/triton.py @@ -0,0 +1,233 @@ +import contextlib +import threading +from typing import Callable, Generator, Iterable, Optional, Union + +from .custom_ops import custom_op +from .infer_schema import infer_schema + + +def triton_op( + name: str, + fn: Optional[Callable] = None, + /, + *, + mutates_args: Union[str, Iterable[str]], + schema: Optional[str] = None, +) -> Callable: + """Create a custom operator whose implementation is backed by 1+ triton kernels. + + Use this instead of :func:`torch.library.custom_op` when the implementation + consists of 1+ triton kernels. :func:`torch.library.custom_op` treats + custom operators as opaque (:func:`torch.compile` and + :func:`torch.export.export` will never trace into them), but ``triton_op`` + makes the implementation visible to these subsystems, allowing them + to optimize the triton kernel(s). + + Note that ``fn`` must only consist of calls to PyTorch-understood + operators and triton kernels. Any triton kernels called inside ``fn`` + must be wrapped in a call to :func:`torch._library.capture_triton``. + + Args: + name (str): A name for the custom op that looks like "{namespace}::{name}", + e.g. "mylib::my_linear". The name is used as the op's stable identifier + in PyTorch subsystems (e.g. torch.export, FX graphs). + To avoid name collisions, please use your project name as the namespace; + e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace. 
+ mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates. + This MUST be accurate, otherwise, the behavior is undefined. If "unknown", + it pessimistically assumes that all inputs to the operator are being mutated. + schema (None | str): A schema string for the operator. If None + (recommended) we'll infer a schema for the operator from its type + annotations. We recommend letting us infer a schema unless you + have a specific reason not to. + Example: "(Tensor x, int y) -> (Tensor, Tensor)". + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import torch + >>> from torch._library import triton_op, capture_triton + >>> + >>> import triton + >>> from triton import language as tl + >>> + >>> @triton.jit + >>> def add_kernel( + >>> in_ptr0, + >>> in_ptr1, + >>> out_ptr, + >>> n_elements, + >>> BLOCK_SIZE: "tl.constexpr", + >>> ): + >>> pid = tl.program_id(axis=0) + >>> block_start = pid * BLOCK_SIZE + >>> offsets = block_start + tl.arange(0, BLOCK_SIZE) + >>> mask = offsets < n_elements + >>> x = tl.load(in_ptr0 + offsets, mask=mask) + >>> y = tl.load(in_ptr1 + offsets, mask=mask) + >>> output = x + y + >>> tl.store(out_ptr + offsets, output, mask=mask) + >>> + >>> @triton_op("mylib::add", mutates_args={}) + >>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + >>> output = torch.empty_like(x) + >>> n_elements = output.numel() + >>> + >>> def grid(meta): + >>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + >>> + >>> # NB: we need to wrap the triton kernel in a call to capture_triton + >>> capture_triton(add_kernel)[grid](x, y, output, n_elements, 16) + >>> return output + >>> + >>> @torch.compile + >>> def f(x, y): + >>> return add(x, y) + >>> + >>> x = torch.randn(3, device="cuda") + >>> y = torch.randn(3, device="cuda") + >>> + >>> z = f(x, y) + >>> assert torch.allclose(z, x + y) + + """ + + def dec(fn: Callable) -> Callable: + def backend_fn(*args, **kwargs): # type: ignore[no-untyped-def] + # Optimization: we're passing regular Tensors into the triton kernel, so + # no need to go through HOP dispatch + with set_capture_triton_enabled(False): + return fn(*args, **kwargs) + + result = custom_op( + name, + backend_fn, + mutates_args=mutates_args, + schema=infer_schema(fn, mutates_args=mutates_args), + ) + from .._subclasses.functional_tensor import FunctionalTensorMode + + # We require that the user pass us a function that is make_fx traceable, + # so we can just register it as the Fake/meta kernel. + result.register_fake(fn) + + # We decompose the operator when FunctionalTensorMode is active. + # The goal is to decompose the operator in AOTDispatcher. + # - With torch.compile, this means that the backend (usually Inductor) + # can see a call to the triton kernel(s) and so it can directly optimize + # them by inlining them into the lowering process. + # - With post-dispatch torch.export, this means that there will + # be a call(s) to the triton_kernel_wrapper_functional HOP in the + # graph (that we have yet to figure out how to serialize). 
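+        # The rule below is invoked with the active mode as its first argument;
+        # it simply re-runs the user function under FunctionalTensorMode, so the
+        # triton calls inside `fn` are traced rather than kept opaque.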
+ def functional_decomp( # type: ignore[no-untyped-def] + mode, _, types, args, kwargs + ): + with mode: + return fn(*args, **kwargs) + + result.register_torch_dispatch(FunctionalTensorMode, functional_decomp) + return result + + if fn is None: + return dec + else: + return dec(fn) + + +capture_triton_enabled = threading.local() +capture_triton_enabled_default = True + + +@contextlib.contextmanager +def set_capture_triton_enabled(enabled: bool) -> Generator[None, None, None]: + """If triton kernels annotated with @capture_triton should dispatch via HOP + or go straight to the triton kernel execution. + + We have this switch because eager-mode performance of HOP dispatch is slow + enough to matter (~1ms) and we know that capture_triton isn't necessary in + some situations (eager-mode with regular Tensors) + """ + try: + prev = is_capture_triton_enabled() + capture_triton_enabled.value = enabled + yield + finally: + capture_triton_enabled.value = prev + + +def is_capture_triton_enabled() -> bool: + return getattr(capture_triton_enabled, "value", capture_triton_enabled_default) + + +def capture_triton(triton_kernel: Callable, /) -> Callable: + """Allows capture of a triton kernel into a graph via make_fx or + non-strict export (coming soon). + + These technologies perform Dispatcher-based tracing (via + ``__torch_dispatch__``) and cannot see calls to raw triton kernels. + The ``capture_triton`` API returns a new callable that can actually + be traced into a graph. + + Examples: + + >>> # xdoctest: +SKIP + >>> import torch + >>> import triton + >>> from triton import language as tl + >>> from torch.fx.experimental.proxy_tensor import make_fx + >>> from torch._higher_order_ops.triton_kernel_wrap import capture_triton + >>> + >>> @triton.jit + >>> def add_kernel( + >>> in_ptr0, + >>> in_ptr1, + >>> out_ptr, + >>> n_elements, + >>> BLOCK_SIZE: "tl.constexpr", + >>> ): + >>> pid = tl.program_id(axis=0) + >>> block_start = pid * BLOCK_SIZE + >>> offsets = block_start + tl.arange(0, BLOCK_SIZE) + >>> mask = offsets < n_elements + >>> x = tl.load(in_ptr0 + offsets, mask=mask) + >>> y = tl.load(in_ptr1 + offsets, mask=mask) + >>> output = x + y + >>> tl.store(out_ptr + offsets, output, mask=mask) + >>> + >>> def add(x, y): + >>> output = torch.empty_like(x) + >>> n_elements = output.numel() + >>> + >>> def grid_fn(meta): + >>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + >>> + >>> capture_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16) + >>> return output + >>> + >>> x = torch.randn(3, device="cuda") + >>> y = torch.randn(3, device="cuda") + >>> gm = make_fx(add)(x, y) + >>> print(gm.code) + >>> # def forward(self, x_1, y_1): + >>> # empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False) + >>> # triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation( + >>> # kernel_idx = 0, constant_args_idx = 0, + >>> # grid = [(1, 1, 1)], kwargs = { + >>> # 'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like, + >>> # 'n_elements': 3, 'BLOCK_SIZE': 16 + >>> # }) + >>> # return empty_like + + """ + from triton.runtime.autotuner import Autotuner + from triton.runtime.jit import JITFunction + + from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper + + if not isinstance(triton_kernel, (JITFunction, Autotuner)): + raise RuntimeError( + "capture_triton only works on functions annotated with triton.jit or triton.autotune" + ) + if not is_capture_triton_enabled(): + return triton_kernel + return TraceableTritonKernelWrapper(triton_kernel, 
None, None) diff --git a/lib/python3.10/site-packages/torch/_library/utils.py b/lib/python3.10/site-packages/torch/_library/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7b8eec899d77e445428b278dc96c765d7b0f9ae6 --- /dev/null +++ b/lib/python3.10/site-packages/torch/_library/utils.py @@ -0,0 +1,318 @@ +# mypy: allow-untyped-defs +import dataclasses +import inspect +import sys +from typing import Any, Callable, Dict, Iterable, Tuple, Union + +import torch +from torch import _C, _utils_internal +from torch._ops import OpOverload + + +@dataclasses.dataclass +class Kernel: + """Models a (function, source location)""" + + func: Callable + source: str + + def __call__(self, *args, **kwargs): + return self.func(*args, **kwargs) + + +class RegistrationHandle: + """Does something when someone calls .destroy() on it""" + + def __init__(self, on_destroy: Callable): + self._on_destroy = on_destroy + + def destroy(self) -> None: + self._on_destroy() + + +def get_source(stacklevel: int) -> str: + """Get a string that represents the caller. + + Example: "/path/to/foo.py:42" + + Use stacklevel=1 to get the caller's source + Use stacklevel=2 to get the caller's caller's source + etc. + """ + frame = inspect.getframeinfo(sys._getframe(stacklevel)) + source = f"{frame.filename}:{frame.lineno}" + return source + + +def parse_namespace(qualname: str) -> Tuple[str, str]: + splits = qualname.split("::") + if len(splits) != 2: + raise ValueError( + f"Expected `qualname` to be of the form " + f'"namespace::name", but got {qualname}. ' + f"The qualname passed to the torch.library APIs must consist " + f"of a namespace and a name, e.g. aten::sin" + ) + return splits[0], splits[1] + + +def lookup_op(qualname: str) -> OpOverload: + namespace, name = parse_namespace(qualname) + if "." in name: + name, overload = name.split(".") + else: + overload = "default" + ns = getattr(torch.ops, namespace) + packet = getattr(ns, name) + return getattr(packet, overload) + + +def is_builtin(op: OpOverload) -> bool: + assert isinstance(op, OpOverload) + return op.namespace in {"aten", "prim", "prims"} + + +def is_functional_schema(schema: Any) -> bool: + """Check if the schema is functional. 
+ + An operator is functional if: + - it does not mutate any of its inputs + - it does not return a view on any of its inputs + - it has at least one return + """ + + def is_functional(schema): + if schema.is_mutable: + return False + rets = schema.returns + is_non_mutating_view = len(rets) > 0 and any( + r.alias_info is not None and not r.alias_info.is_write for r in rets + ) + if is_non_mutating_view: + return False + if not schema.returns: + return False + return True + + if isinstance(schema, torch._C.FunctionSchema): + return is_functional(schema) + + # Lazy import because not all PyTorch builds have torchgen + from torchgen.model import FunctionSchema + + if isinstance(schema, str): + schema = FunctionSchema.parse(schema) + assert isinstance(schema, FunctionSchema) + return is_functional(schema) + + +# should be torch._C.JitType but that annotation is busted +def is_tensorlist_like_type(typ: Any) -> bool: + return ( + typ == _C.ListType(_C.TensorType.get()) + or typ == _C.ListType(_C.OptionalType(_C.TensorType.get())) + or typ == _C.OptionalType(_C.ListType(_C.TensorType.get())) + or typ == _C.OptionalType(_C.ListType(_C.OptionalType(_C.TensorType.get()))) + ) + + +# should be torch._C.JitType but that annotation is busted +def is_tensor_like_type(typ: Any) -> bool: + return typ == _C.TensorType.get() or typ == _C.OptionalType(_C.TensorType.get()) + + +def mutates_and_returns_first_arg(op: OpOverload): + """Check if an op is an inplace aten op, i.e. it mutates and returns the first arg. + + TODO: torchgen/model.py's FunctionSchema.parse is the source of truth for this, + but not all PyTorch builds have torchgen (due to the yaml dependency being weird). + Figure this out. + + Example: add_(Tensor(a!) x, Tensor y) -> Tensor(a) + """ + if op.namespace != "aten": + return False + schema = op._schema + if not len(schema.returns) == 1: + return False + if schema.returns[0].alias_info is None: + return False + alias_set = schema.returns[0].alias_info.after_set + if len(alias_set) != 1: + return False + loc = next(iter(alias_set)) + if len(schema.arguments) < 1: + return False + first_arg = schema.arguments[0] + if first_arg.alias_info is None: + return False + if not first_arg.alias_info.is_write: + return False + alias_set = first_arg.alias_info.after_set + if len(alias_set) != 1: + return False + if loc != next(iter(alias_set)): + return False + for arg in schema.arguments[1:]: + if arg.alias_info is not None: + return False + return True + + +def fill_defaults(schema, args, kwargs): + new_args = [] + new_kwargs = {} + for i in range(len(schema.arguments)): + info = schema.arguments[i] + if info.kwarg_only: + if info.name in kwargs: + new_kwargs[info.name] = kwargs[info.name] + else: + new_kwargs[info.name] = info.default_value + else: + if i < len(args): + new_args.append(args[i]) + else: + new_args.append(info.default_value) + return tuple(new_args), new_kwargs + + +def zip_schema( + schema: _C.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any] +) -> Iterable[Tuple[_C.Argument, Any]]: + """zips schema.arguments and (args, kwargs) together. + + Assumes that (args, kwargs) were the inputs to some torch._ops.OpOverload: + that is, kwargs must be keyword-only arguments and default values may be omitted. 
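+
+    For example (hypothetical schema): for ``foo(Tensor x, int n=1) -> Tensor``
+    called as ``foo(t)``, this yields the ``x`` argument paired with ``t`` and
+    skips ``n``, since trailing arguments left at their defaults are absent
+    from ``args``.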
+ """ + assert len(schema.arguments) >= len(args) + len(kwargs) + for i in range(len(schema.arguments)): + info = schema.arguments[i] + if info.kwarg_only: + if info.name in kwargs: + yield info, kwargs[info.name] + continue + if i >= len(args): + # args that are equal to their default values are not populated + # if they are followed by args that are equal to their defaults. + # Skip these. + continue + yield info, args[i] + return + + +def hop_schema_from_fx_node(node): + from torchgen.gen_schema_utils import FunctionSchemaGen + + hop = node.target + if not isinstance(hop, torch._ops.HigherOrderOperator): + raise RuntimeError("fx_node's target must be a hop.") + + def _collect_example_val(node): + meta_val = node.meta.get("val", None) + if meta_val is None: + assert node.op == "get_attr" + meta_val = getattr(node.graph.owning_module, node.target) + return meta_val + + example_inputs = [] + for arg in node.args: + if isinstance(arg, (torch.fx.Node, torch.fx.node.Node)): + example_inputs.append(_collect_example_val(arg)) + elif isinstance( + arg, (torch.fx.immutable_collections.immutable_list, list, tuple) + ): + example_inputs.append([_collect_example_val(x) for x in arg]) + else: + raise RuntimeError(f"Unsupported arg type {type(arg)}") + + # Bound the arguments to make sure number of inputs are correct + bound_args: inspect.BoundArguments = inspect.signature(hop.__call__).bind( + *example_inputs + ) + + # We treat example_output as a single value in return. This is to differentiate 1. return a single val + # vs 2. return a tuple with one element. + example_output = _collect_example_val(node) + return FunctionSchemaGen.from_example( + hop._name, tuple(bound_args.arguments.items()), (list(example_output),) + ) + + +def can_generate_trivial_fake_impl(op: OpOverload) -> bool: + assert isinstance(op, OpOverload) + if is_builtin(op): + # We control the built-ins. These may (in rare cases) + # do input metadata mutation (which we have banned on custom ops) + return False + schema = op._schema + # It's suspicious if the op is not mutable but returns nothing, so we return False out of an abundance of caution + if not schema.is_mutable: + return False + if len(schema.returns) > 0: + return False + # If the op returns nothing, then it has a trivial fake impl. + return True + + +def requires_set_python_module() -> bool: + """If an op was defined in C++ and extended from Python using the + torch.library APIs, returns if we require that there have been a + m.set_python_module("mylib.ops") call from C++ that associates + the C++ op with a python module. + """ + return getattr(_utils_internal, "REQUIRES_SET_PYTHON_MODULE", True) + + +def handle_dispatch_mode(curr_mode, op_overload, *args, **kwargs): + assert isinstance(curr_mode, torch.utils._python_dispatch.TorchDispatchMode) + overload_types = [] + args_flattened, _ = torch.utils._pytree.tree_flatten((args, kwargs.values())) + for a in args_flattened: + # TODO: need to double check the semantics of the "types" argument to torch_dispatch. + # It's generated in PyInterpreter.cpp, but seems to be generated in two places, + # where in one case we only include tensors with the python key, and in another + # we include **all** tensors. + if isinstance(a, torch.Tensor) and torch._C._dispatch_keys(a).has( + torch._C.DispatchKey.Python + ): + overload_types.append(type(a)) + # TODO: check that I got these args correct (in C++, we pass in "0000"??) 
+ + return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs) + + +def has_kwarg_only_args(schema: _C.FunctionSchema): + return any(a.kwarg_only for a in schema.arguments) + + +def has_kwarg_only_tensors(schema: _C.FunctionSchema): + for a in schema.arguments: + if not (is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type)): + continue + if not a.kwarg_only: + continue + return True + return False + + +def has_tensor_arg(schema: _C.FunctionSchema) -> bool: + """ + Given a schema, returns True if the schema has a Tensor arg. + A Tensor arg is any arg with a type annotation that might involve Tensor. + """ + return any( + (is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type)) + for a in schema.arguments + ) + + +def get_device_arg_index(schema: _C.FunctionSchema) -> Union[int, None]: + """ + Given a schema, returns the id of the `device: torch.device` argument. + If it does not exist, returns None. + """ + for index, arg in enumerate(schema.arguments): + if arg.type is _C.DeviceObjType.get() and arg.name == "device": + return index + return None diff --git a/lib/python3.10/site-packages/torch/distributed/__init__.py b/lib/python3.10/site-packages/torch/distributed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..181fab713febe834bd42bbd0881e61dec15938db --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributed/__init__.py @@ -0,0 +1,157 @@ +# mypy: allow-untyped-defs +import logging +import pdb +import sys +import traceback +import typing + +import torch + + +log = logging.getLogger(__name__) + + +def is_available() -> bool: + """ + Return ``True`` if the distributed package is available. + + Otherwise, + ``torch.distributed`` does not expose any other APIs. Currently, + ``torch.distributed`` is available on Linux, MacOS and Windows. Set + ``USE_DISTRIBUTED=1`` to enable it when building PyTorch from source. + Currently, the default value is ``USE_DISTRIBUTED=1`` for Linux and Windows, + ``USE_DISTRIBUTED=0`` for MacOS. + """ + return hasattr(torch._C, "_c10d_init") + + +if is_available() and not torch._C._c10d_init(): + raise RuntimeError("Failed to initialize torch.distributed") + +# Custom Runtime Errors thrown from the distributed package +DistError = torch._C._DistError +DistBackendError = torch._C._DistBackendError +DistNetworkError = torch._C._DistNetworkError +DistStoreError = torch._C._DistStoreError + +if is_available(): + from torch._C._distributed_c10d import ( + _broadcast_coalesced, + _compute_bucket_assignment_by_size, + _ControlCollectives, + _DEFAULT_FIRST_BUCKET_BYTES, + _make_nccl_premul_sum, + _register_builtin_comm_hook, + _register_comm_hook, + _StoreCollectives, + _test_python_store, + _verify_params_across_processes, + Backend as _Backend, + BuiltinCommHookType, + DebugLevel, + FileStore, + get_debug_level, + GradBucket, + Logger, + PrefixStore, + ProcessGroup as ProcessGroup, + Reducer, + set_debug_level, + set_debug_level_from_env, + Store, + TCPStore, + Work as _Work, + ) + + class _DistributedPdb(pdb.Pdb): + """ + Supports using PDB from inside a multiprocessing child process. + + Usage: + _DistributedPdb().set_trace() + """ + + def interaction(self, *args, **kwargs): + _stdin = sys.stdin + try: + sys.stdin = open("/dev/stdin") + pdb.Pdb.interaction(self, *args, **kwargs) + finally: + sys.stdin = _stdin + + _breakpoint_cache: typing.Dict[int, typing.Any] = {} + + def breakpoint(rank: int = 0, skip: int = 0): + """ + Set a breakpoint, but only on a single rank. 
All other ranks will wait for you to be + done with the breakpoint before continuing. + + Args: + rank (int): Which rank to break on. Default: ``0`` + skip (int): Skip the first ``skip`` calls to this breakpoint. Default: ``0``. + """ + if skip > 0: + key = hash(str(traceback.format_exc())) + counter = _breakpoint_cache.get(key, 0) + 1 + _breakpoint_cache[key] = counter + if counter <= skip: + log.warning("Skip the breakpoint, counter=%d", counter) + return + + if get_rank() == rank: + pdb = _DistributedPdb() + pdb.message( + "\n!!! ATTENTION !!!\n\n" + f"Type 'up' to get to the frame that called dist.breakpoint(rank={rank})\n" + ) + pdb.set_trace() + # If Meta/Python keys are in the TLS, we want to make sure that we ignore them + # and hit the (default) CPU/CUDA implementation of barrier. + meta_in_tls = torch._C._meta_in_tls_dispatch_include() + guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined] + torch._C._set_meta_in_tls_dispatch_include(False) + try: + barrier() + finally: + torch._C._set_meta_in_tls_dispatch_include(meta_in_tls) + del guard + + if sys.platform != "win32": + from torch._C._distributed_c10d import HashStore + + from .device_mesh import DeviceMesh, init_device_mesh + + # Variables prefixed with underscore are not auto imported + # See the comment in `distributed_c10d.py` above `_backend` on why we expose + # this. + from .distributed_c10d import * # noqa: F403 + from .distributed_c10d import ( + _all_gather_base, + _coalescing_manager, + _CoalescingManager, + _create_process_group_wrapper, + _get_process_group_name, + _rank_not_in_group, + _reduce_scatter_base, + get_node_local_rank, + ) + from .remote_device import _remote_device + from .rendezvous import ( + _create_store_from_options, + register_rendezvous_handler, + rendezvous, + ) + + set_debug_level_from_env() + +else: + # This stub is sufficient to get + # python test/test_public_bindings.py -k test_correct_module_names + # working even when USE_DISTRIBUTED=0. Feel free to add more + # stubs as necessary. + # We cannot define stubs directly because they confuse pyre + + class _ProcessGroupStub: + pass + + sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub # type: ignore[attr-defined] diff --git a/lib/python3.10/site-packages/torch/distributed/_functional_collectives.py b/lib/python3.10/site-packages/torch/distributed/_functional_collectives.py new file mode 100644 index 0000000000000000000000000000000000000000..77b962bf3df47c996e87afec70c0efda5a9fabd3 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributed/_functional_collectives.py @@ -0,0 +1,1150 @@ +# mypy: allow-untyped-defs +import sys +import warnings +from typing import cast, List, Optional, Tuple, TYPE_CHECKING, Union + +import torch +import torch.distributed as dist +import torch.distributed.distributed_c10d as c10d +from torch.distributed.device_mesh import DeviceMesh +from torch.fx.experimental.proxy_tensor import get_proxy_mode + +from . 
import _functional_collectives_impl as fun_col_impl + + +try: + from torch.utils._cxx_pytree import tree_map_only +except ImportError: + from torch.utils._pytree import tree_map_only # type: ignore[no-redef] + + +if torch._running_with_deploy(): + + def is_torchdynamo_compiling(): + """Can't import torchdynamo in torchdeploy builds currently.""" + return False + +else: + try: + from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling + except Exception: + warnings.warn( + "Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly" + ) + + def is_torchdynamo_compiling(): + return False + + +""" +New traceable, functional collectives. +RFC: https://github.com/pytorch/pytorch/issues/93173 + + compiler: trace these ops with plain-old-data schemas, then choose how to lower them. + eager: execute these 'functional' ops which in eager return AsyncCollectiveTensor subclasses, + automatically calling .wait() on underlying/hidden async 'work' obj only when fed to + a downstream op. + +Issues: +* Where should these ops live? Couldn't `import torch` if putting these ops in existing torch.distributed files +* Proper support for eager requires inplace ops. We should explore having it as an option for the API. +""" + +""" +Functional collectives are asynchronous only and we perform implicit stream synchronization +on behalf of the user. + +We use AsyncCollectiveTensor to wrap the result tensor of a collective and it lets us witness +first usage of the tensor and insert cross stream sync at the right place. + +The above are the easy bits, the hard one is how we match the Work object returned by +c10d and the tensor AsyncCollectiveTensor wraps. We alloc the tensor inside the collective +op implementation (see ``clone()`` call in ``_all_reduce``) and then it's handled by the +dispatcher which might call other implementations that are allowed to change the returned +tensor - even return a tensor with a different shape (see ``torch.vmap``). + +This means the caller of our ops receives a Tensor that is not guaranteed to be the same +allocated by our implementations and that makes pairing The AsyncTensor to the original +tensor a lot harder. This pairing is needed so we can lookup the Work object to use. + +Originally, we tried WeakKeyDictionary to map from Tensor to Work, but because Tensor's +identity is not stable across dispatch, the op caller would end up with a different Tensor +instance that would not match any in the dictionary. + +With Tensor identity out of the question, we decided use the tensor data pointer, which +should be stable across all the Tensor changes done during dispatch. + +We have a dictionary of tensor::data_ptr -> Work that we insert right after we call into c10d. + +We use this dictionary when AsyncCollectiveTensor is used to invoke Work::wait() + +Finally, we setup a finalizer against the tensor wrapper to observe it getting collected so we +can clean up stale entries in the dictionary. + +To eliminate the possibility of races we have a global version counter that is used by the finalizer. + +As a wise man said once: Don't cross the streams (https://www.youtube.com/watch?v=wyKQe_i9yyo) + +""" + +""" +Functional collectives can accept any of these types to describe the ranks participating in collectives. 
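+For example, a plain list of ranks such as [0, 1, 2, 3], a ProcessGroup, a 1-D
+DeviceMesh, a (DeviceMesh, dim) tuple, or a group-name string are all accepted.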
+ +The different types will be desugared to a canonical format +""" +RANK_TYPES = Union[ + List[int], + List[List[int]], + dist.ProcessGroup, + DeviceMesh, + Tuple["dist.tensor.DeviceMesh", int], + str, +] + + +""" +User facing APIs for functional collectives +------------------------------------------- + +These apis are called by user code and expected to work both in eager execution and compilation, +but there are significant differences to how the two modes are implemented underneath. + +Eager execution is 'optimized' using a tensor subclass that schedules the synchronization (via wait_tensor() op) +just before the tensor is first used. Compiled tracing currently relies on the compiler to perform this optimization, +and cannot yet correctly trace the AsyncTensor wrapper class. In the future, these paths may be unified +if sufficient subclass support is added in dynamo. + +Example: all_reduce is an entrypoint API, and other collectives follow a similar pattern. + +Here's how it works under torch.compile/dynamo: +all_reduce(...) + |--> _expand_group(...) - desugars processgroup into canonical/traceable format + |--> c10d_functional.all_reduce(...) - dynamo captures this op call, doesn't trace deeper + |--> _maybe_wrap_tensor(...) - wait_tensor() op is immediately called, no AsyncTensor subclass needed + +And under eager execution: +all_reduce(...) + |--> _expand_group(...) - same as above, but less critical for eager + |--> c10d_functional.all_reduce(...) - dispatches to real kernel OR records op in trace + |--> _maybe_wrap_tensor(...) - AsyncTensor wrapper applied to returned tensor, + which issues wait_tensor() at the time of first use +""" + + +def wait_tensor(tensor): + """ + Wait on a tensor returned by the collectives ops. + + Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA. + """ + return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] + + +def broadcast(self: torch.Tensor, src: int, group: RANK_TYPES, tag: str = ""): + """ + Broadcasts the tensor to all processes in the given process group. + + Args: + src (int): Source rank + group (ProcessGroup or List[int]): The process group to work on. + tag (str, optional): A unique identifier for the collective. Default: empty string + """ + group_name = _resolve_group_name(group, tag) + tensor = torch.ops._c10d_functional.broadcast(self, src, group_name) + return _maybe_wrap_tensor(tensor) + + +def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""): + """ + Reduces the tensor data across all machines in such a way that all get + the final result. + + The input tensor is left unmodified. + + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. 
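+
+    A minimal usage sketch (illustrative only; assumes the default process
+    group has been initialized and ``t`` is this rank's local tensor):
+
+    >>> # xdoctest: +SKIP
+    >>> import torch.distributed as dist
+    >>> import torch.distributed._functional_collectives as funcol
+    >>> out = funcol.all_reduce(t, "sum", group=dist.group.WORLD)
+    >>> total = out.wait()  # or just use `out`; it is waited on at first use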
+ """ + group_name = _resolve_group_name(group, tag) + tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name) + return _maybe_wrap_tensor(tensor) + + +def all_gather_tensor( + self: torch.Tensor, + gather_dim: int, + group: RANK_TYPES, + tag: str = "", +): + """ + Gather tensor data across from all machines and concatenate over ``gather_dim``. + + Note that it currently only supports gather_dim = 0. + + The input tensor is left unmodified. + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. + """ + assert self.is_contiguous() + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + tensor = torch.ops._c10d_functional.all_gather_into_tensor( + self, group_size, group_name + ) + res = _maybe_wrap_tensor(tensor) + # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call + if gather_dim != 0: + # torch.cat access the data so we already need to wait here, first do wait + # and then chunk + cat avoid us going through ACT dispatching logic again + if isinstance(res, AsyncCollectiveTensor): + res = res.wait() # type: ignore[attr-defined] + res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim) + return res + + +def all_gather_tensor_autograd( + self: torch.Tensor, + gather_dim: int, + group: RANK_TYPES, + tag: str = "", +): + """ + Gather tensor data across from all machines and concatenate over ``gather_dim``. + + Note that it currently only supports gather_dim = 0. + + This function is the same as all_gather_tensor but will propagate the + backwards gradient across workers. + + See all_gather_tensor for more details on usage. + """ + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + + tensor = torch.ops._c10d_functional_autograd.all_gather_into_tensor( + self, group_size, group_name + ) + res = _FromTorchTensor.apply(tensor) + # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call + if gather_dim != 0: + # torch.cat access the data so we already need to wait here, first do wait + # and then chunk + cat avoid us going through ACT dispatching logic again + if isinstance(res, AsyncCollectiveTensor): + res = res.wait() # type: ignore[attr-defined] + res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim) + return res + + +def reduce_scatter_tensor( + self: torch.Tensor, + reduceOp: str, + scatter_dim: int, + group: RANK_TYPES, + tag: str = "", +): + """ + Reduces the tensor data across all machines in such a way that all get + the final result, then scatter the results to corresponding ranks. + + + The input tensor is left unmodified. + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. 
+ DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. + """ + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + + assert ( + self.size(scatter_dim) % group_size == 0 + ), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}" + if scatter_dim != 0: + tensor_list = torch.chunk(self, group_size, dim=scatter_dim) + self = torch.cat(tensor_list) + + tensor = torch.ops._c10d_functional.reduce_scatter_tensor( + self, + reduceOp.lower(), + group_size, + group_name, # type: ignore[possibly-undefined] + ) + res = _maybe_wrap_tensor(tensor) + return res + + +def reduce_scatter_tensor_autograd( + self: torch.Tensor, + reduceOp: str, + scatter_dim: int, + group: RANK_TYPES, + tag: str = "", +): + """ + Reduces the tensor data across all machines in such a way that all get + the final result, then scatter the results to corresponding ranks. + + This function is the same as reduce_scatter_tensor but will propagate the + backwards gradient across workers. + + Currently only the "sum" reduceOp is supported. + + See reduce_scatter_tensor for more details on usage. + """ + + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + + assert ( + self.size(scatter_dim) % group_size == 0 + ), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}" + if scatter_dim != 0: + tensor_list = torch.chunk(self, group_size, dim=scatter_dim) + self = torch.cat(tensor_list) + + tensor = torch.ops._c10d_functional_autograd.reduce_scatter_tensor( + self, + reduceOp.lower(), + group_size, + group_name, # type: ignore[possibly-undefined] + ) + res = _FromTorchTensor.apply(tensor) + return res + + +def all_reduce_coalesced( + self: List[torch.Tensor], reduceOp: str, group: RANK_TYPES, tag: str = "" +) -> List[torch.Tensor]: + """ + Reduces a list of tensors across all machines in such a way that all get + the final result. + + The all tensors in the input list are left unmodified. + + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. + """ + group_name = _resolve_group_name(group, tag) + tensor_list = torch.ops._c10d_functional.all_reduce_coalesced( # type: ignore[attr-defined] + self, + reduceOp.lower(), + group_name, + ) + return list(map(_maybe_wrap_tensor, tensor_list)) + + +def all_gather_into_tensor_coalesced( + self: List[torch.Tensor], group: RANK_TYPES, tag: str = "" +) -> List[torch.Tensor]: + """ + Gather a list of tensors across from all machines. + + Note that it currently only supports gather_dim = 0. + + The input tensor is left unmodified. + Group can be one of: + List[int]: ranks participating in the collective. 
+ List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. + """ + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + tensor_list = torch.ops._c10d_functional.all_gather_into_tensor_coalesced( # type: ignore[attr-defined] + self, + group_size, + group_name, + ) + return list(map(_maybe_wrap_tensor, tensor_list)) + + +def reduce_scatter_tensor_coalesced( + inputs: List[torch.Tensor], + reduceOp: str, + scatter_dim: List[int], + group: RANK_TYPES, + tag: str = "", +) -> List[torch.Tensor]: + """ + Reduces a list of tensors across all machines in such a way that all get + the final result, then scatter the results to corresponding ranks. + + The input tensors are left unmodified. + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. + """ + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + + assert len(scatter_dim) == len(inputs) + for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)): + assert ( + tensor.size(dim) % group_size == 0 + ), f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}" + if dim != 0: + tensor_list = torch.chunk(tensor, group_size, dim=dim) + inputs[idx] = torch.cat(tensor_list) + + tensor_list = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced( # type: ignore[attr-defined] + inputs, + reduceOp.lower(), + group_size, + group_name, # type: ignore[possibly-undefined] + ) + + return list(map(_maybe_wrap_tensor, tensor_list)) + + +# This is a bit unsafe: it checks if the first argument in the schema reports as a non-mutable alias. +# Today, this maps 1:1 with "aten ops that are views". +def _is_view_op(tgt): + assert isinstance(tgt, torch._ops.OpOverload) + schema = tgt._schema + if len(schema.arguments) > 0: + first_arg = schema.arguments[0] + # check if op is a view + return first_arg.alias_info is not None and not first_arg.alias_info.is_write + + +def all_to_all_single( + self: torch.Tensor, + output_split_sizes: Optional[List[int]], + input_split_sizes: Optional[List[int]], + group: RANK_TYPES, + tag: str = "", +) -> torch.Tensor: + """ + Each process splits input tensor and then scatters the split list + to all processes in a group. Then concatenate the received tensors from all + the processes in the group and return single output tensor. + + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. 
+ ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. + """ + if output_split_sizes is not None: + assert all( + isinstance(size, (int, torch.SymInt)) for size in output_split_sizes + ), output_split_sizes + if input_split_sizes is not None: + assert all( + isinstance(size, (int, torch.SymInt)) for size in input_split_sizes + ), input_split_sizes + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + if output_split_sizes is None or input_split_sizes is None: + assert output_split_sizes is None and input_split_sizes is None, ( + "output_split_sizes and input_split_sizes must either be " + "specified together or both set to None" + ) + output_split_sizes = [self.shape[0] // group_size] * group_size + input_split_sizes = output_split_sizes + tensor = torch.ops._c10d_functional.all_to_all_single( # type: ignore[attr-defined] + self, + output_split_sizes, + input_split_sizes, + group_name, + ) + return _maybe_wrap_tensor(tensor) + + +def all_to_all_single_autograd( + self: torch.Tensor, + output_split_sizes: Optional[List[int]], + input_split_sizes: Optional[List[int]], + group: RANK_TYPES, + tag: str = "", +) -> torch.Tensor: + """ + Same as all_to_all_single but supports autograd. + """ + if output_split_sizes is not None: + assert all( + isinstance(size, (int, torch.SymInt)) for size in output_split_sizes + ), output_split_sizes + if input_split_sizes is not None: + assert all( + isinstance(size, (int, torch.SymInt)) for size in input_split_sizes + ), input_split_sizes + + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + if output_split_sizes is None or input_split_sizes is None: + assert output_split_sizes is None and input_split_sizes is None, ( + "output_split_sizes and input_split_sizes must either be " + "specified together or both set to None" + ) + output_split_sizes = [self.shape[0] // group_size] * group_size + input_split_sizes = output_split_sizes + tensor = torch.ops._c10d_functional_autograd.all_to_all_single( # type: ignore[attr-defined] + self, + output_split_sizes, + input_split_sizes, + group_name, + ) + return _FromTorchTensor.apply(tensor) + + +def permute_tensor( + self: torch.Tensor, + src_dst: List[int], + group: RANK_TYPES, + tag: str = "", +) -> torch.Tensor: + """ + Permutes the elements of the tensor according to the given source/destination pairs. `src_dst` should + be defined such that src_dst[m] == n means m sends to n. + + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. 
+ DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one + """ + t, rankset, group_size = _expand_group(group, tag) + local_pg = c10d._find_or_create_pg_by_ranks_and_tag(t, rankset, group_size) + + output_split_sizes = [0] * group_size + input_split_sizes = [0] * group_size + for src, dst in enumerate(src_dst): + if src == dist.get_rank(local_pg): + input_split_sizes[dst] = self.numel() + if dst == dist.get_rank(local_pg): + output_split_sizes[src] = self.numel() + + return all_to_all_single(self, output_split_sizes, input_split_sizes, group, tag) + + +class AsyncCollectiveTensor(torch.Tensor): + r""" + A Tensor wrapper subclass that is used to trigger a call to wait + prior to first use of the underlying tensor. + Use it inside functional collective pytorch wrappers like the following: + def functional_collective(self, group, tag): + tag, rankset, group_size = _expand_group(group, tag) + tensor = torch.ops.c10d_functional.{collective}(self, tag, rankset, group_size) + return _maybe_wrap_tensor(tensor) + """ + elem: torch.Tensor + completed: bool + + __slots__ = ["elem", "completed"] + + @staticmethod + def __new__(cls, elem: torch.Tensor): + r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + elem.size(), + strides=elem.stride(), + storage_offset=elem.storage_offset(), + dtype=elem.dtype, + layout=elem.layout, + device=elem.device, + requires_grad=elem.requires_grad, + ) + r.elem = elem + r.completed = False + return r + + def __tensor_flatten__(self): + return ["elem"], None + + def tolist(self): + return self.trigger_wait().tolist() + + @staticmethod + def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): + assert meta is None + elem = inner_tensors["elem"] + return AsyncCollectiveTensor(elem) + + def __repr__(self): + return f"AsyncCollectiveTensor({self.trigger_wait()})" + + def trigger_wait(self): + if not self.completed: + out = wait_tensor(self.elem) + self.completed = True + return out + else: + return self.elem + + def wait(self) -> torch.Tensor: + return wait_tensor(self.elem) + + def _get_acs_underlying_tensor(self): + """This method enables _functional_collectives_impl to test if a tensor is an ACS""" + return self.elem + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + if func == torch.ops.aten.view.default: + # Fast handle aten.view as a lot of view related op goes to aten.view + # eventually, this avoids pytree slowdown + res = func(args[0].elem, args[1]) + wrapper_res = AsyncCollectiveTensor(res) + return wrapper_res + + is_view_op = _is_view_op(func) + + def unwrap(e: AsyncCollectiveTensor): + # wait_tensor is idepotent and will do stream sync only once + if not is_view_op: + return e.trigger_wait() + return e.elem + + def wrap(e: torch.Tensor): + # wait_tensor is idepotent and will do stream sync only once + assert not isinstance(e, AsyncCollectiveTensor) + res = AsyncCollectiveTensor(e) + return res + + unwrapped_args = tree_map_only(AsyncCollectiveTensor, unwrap, args) + unwrapped_kwargs = tree_map_only(AsyncCollectiveTensor, unwrap, kwargs) + + # we don't wrap the result as it doesn't need to be waited on. + out = func(*unwrapped_args, **unwrapped_kwargs) + + # View ops dont require a sync, so we should re-wrap the outputs. 
+ if is_view_op: + out = tree_map_only(torch.Tensor, wrap, out) + + return out + + def numpy(self): + return self.wait().numpy() + + +""" +Utils and infrastructure for tracing support +""" + + +def _expand_group(group: RANK_TYPES, tag: str = "") -> Tuple[str, List[int], int]: + """ + _expand_group desugars the different RANK_TYPES types into a canonical format that is traceable. + + By having this be part of the explicit eager codepath, we avoid having to specialize behavior inside + torchdynamo and can still interoperate with processgroup objects or other untraceable forms. + """ + # had to define this hack _inside_ expand_group to avoid + # graph_break [('torch.* op returned non-Tensor int + # caused by 'cast_*` functions being treated as 'torch.*' ops (iiuc) + if TYPE_CHECKING: + + def cast_listlistint(x): + return cast(List[List[int]], x) + + def cast_listint(x): + return cast(List[int], x) + + else: + # fake cast op for use at runtime since dynamo doesn't support real cast + # also, dynamo didn't like encountering 'typing' objects () + # NotImplementedError: argument of type: + def cast_listlistint(x): + return x + + def cast_listint(x): + return x + + rankset: List[int] + if isinstance(group, list): + if isinstance(group[0], list): + nested_list = cast_listlistint(group) + rankset = [] + group_size = -1 + for rs in nested_list: + rankset.extend(rs) + if group_size != -1 and group_size != len(rs): + raise ValueError( + f"group sizes must be identical found {group_size} and {len(rs)}" + ) + group_size = len(rs) + else: + rankset = cast_listint(group) + group_size = len(rankset) + elif isinstance(group, dist.ProcessGroup): + rankset = dist.get_process_group_ranks(group) + group_size = len(rankset) + tag = tag or c10d._get_group_tag(group) + elif isinstance(group, DeviceMesh): + assert ( + group.ndim == 1 + ), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" + # TODO: it should run collective in the whole mesh instead of dim 0 + tag, rankset, _ = group._dim_group_infos[0] + group_size = len(rankset) + elif isinstance(group, tuple): + if ( + len(group) == 2 + and isinstance(group[0], DeviceMesh) + and isinstance(group[1], int) + ): + dmesh = group[0] + dim = group[1] + tag, rankset, _ = dmesh._dim_group_infos[dim] + group_size = len(rankset) + else: + raise ValueError("Invalid tuple for group must be (DeviceMesh, int)") + else: + raise ValueError( + "Invalid type for group, must be one of List, Processgroup, DeviceMesh or (DeviceMesh, int)." + ) + + return (tag, rankset, group_size) + + +def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str: + """ + Given group in RANK_TYPES, return the group name. + """ + # `tag` will be deprecated. 
See details in: + # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208 + if isinstance(group, dist.ProcessGroup): + return group.group_name + elif isinstance(group, str): + return group + elif isinstance(group, DeviceMesh): + assert ( + group.ndim == 1 + ), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" + return group._dim_group_infos[0][2] + elif isinstance(group, tuple): + if ( + len(group) == 2 + and isinstance(group[0], DeviceMesh) + and isinstance(group[1], int) + ): + dmesh = group[0] + dim = group[1] + return dmesh._dim_group_infos[dim][2] + else: + raise ValueError("Invalid tuple for group must be (DeviceMesh, int)") + elif isinstance(group, list): + if not is_torchdynamo_compiling(): + warnings.warn( + "The combination of ranks + tag as process group " + "identifier has been deprecated. Please switch to " + "using ProcessGroup, DeviceMesh, or group name instead.", + FutureWarning, + stacklevel=3, + ) + return c10d._resolve_group_name_by_ranks_and_tag(cast(List[int], group), tag) + else: + raise ValueError(f"Unsupported group type: {type(group)}, {group}") + + +class _FromTorchTensor(torch.autograd.Function): + """ + _FromTorchTensor allows autograd to propagate from a normal Tensor to an + AsyncCollectiveTensor. + """ + + @staticmethod + def forward( # type: ignore[override] + ctx, # pyre-ignore[2]: Parameter must be annotated. + input: torch.Tensor, + ) -> torch.Tensor: + return _maybe_wrap_tensor(input) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: # type: ignore[override] + return grad_output + + +def _are_we_tracing() -> bool: + if is_torchdynamo_compiling(): + return True + # If functionalization is turned on, we are almost definitely compiling/tracing. + # (In particular, AOTAutograd traces a model once with functionalization on + # but proxy tracing turned of, so this is how we detect it). 
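+    # The check below looks for an active FunctionalTensorMode on the dispatch
+    # mode stack; plain proxy tracing is handled separately by the
+    # get_proxy_mode() check at the end of this function.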
+ if ( + torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL) + is not None + ): + return True + return get_proxy_mode() is not None + + +def _maybe_wrap_tensor(self) -> torch.Tensor: + if _are_we_tracing(): + return wait_tensor(self) + res = AsyncCollectiveTensor(self) + return cast(torch.Tensor, res) + + +def _all_gather_into_tensor_coalesced_meta(self, tag, rankset, group_size): + def mk_out_tensor(shard): + out_size = list(shard.size()) + out_size[0] *= group_size + out_tensor = shard.new_empty(out_size) + return out_tensor + + return [mk_out_tensor(t) for t in self] + + +# We now register meta kernels to deal with tracing +def _broadcast_meta(self, *args): + return torch.empty_like(self) + + +def _all_reduce_meta(self, *args): + return torch.empty_like(self) + + +def _wait_tensor_meta(self, *args): + return torch.empty_like(self) + + +def _all_gather_into_tensor_meta(shard, tag, rankset, group_size): + out_size = list(shard.size()) + out_size[0] *= group_size + return shard.new_empty(out_size) + + +def _reduce_scatter_tensor_meta(input, reduce_op, tag, rankset, group_size): + out_size = list(input.size()) + out_size[0] //= group_size + return input.new_empty(out_size) + + +def _all_reduce_coalesced_meta(self, *args): + return [torch.empty_like(t) for t in self] + + +def _all_reduce__meta(inp, *args): + return inp + + +def _broadcast__meta(inp, *args): + return inp + + +def _all_reduce_coalesced__meta(inputs, *args): + return inputs + + +def _reduce_scatter_tensor_coalesced_meta(inputs, reduceOp, tag, rankset, group_size): + def mk_out_tensor(input): + out_size = list(input.size()) + out_size[0] //= group_size + out_tensor = input.new_empty(out_size) + return out_tensor + + return [mk_out_tensor(t) for t in inputs] + + +# NB: We often say all_to_all has dynamic output size, but this is not +# technically true: instead, what typically happens is you manually +# communicate the output_split_sizes ahead of time (which is dynamic), +# but then you pass those sizes explicitly, and the all to all itself +# isn't dynamic, it just follows the specified output splits +def _all_to_all_single_meta( + input, output_split_sizes, input_split_sizes, *args, **kwargs +): + if output_split_sizes is None: + return input.new_empty(input.size()) + else: + for s in output_split_sizes: + torch._check_is_size(s) + out_size = list(input.size()) + out_size[0] = sum(output_split_sizes) + return input.new_empty(out_size) + + +def _all_gather_into_tensor_out_native_meta(input, group_size, group_name, *, out): + shape = list(input.size()) + shape[0] *= group_size + return input.new_empty(shape) + + +def _all_gather_into_tensor_native_meta(input, group_size, group_name): + shape = list(input.size()) + shape[0] *= group_size + return input.new_empty(shape) + + +def _all_gather_into_tensor_coalesced_native_meta(inputs, group_size, group_name): + return [ + _all_gather_into_tensor_native_meta(input, group_size, group_name) + for input in inputs + ] + + +def _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name): + shape = list(inp.size()) + shape[0] //= group_size + return inp.new_empty(shape) + + +def _reduce_scatter_tensor_coalesced_native_meta( + inputs, reduce_op, group_size, group_name +): + return [ + _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name) + for inp in inputs + ] + + +if not torch._running_with_deploy(): + # Library MUST be defined at module scope or it doesn't work + # Creating a "DEF" Library always crashes torch::deploy so we create our 
+ # Library instances here guarded against running inside it + lib_impl = torch.library.Library("_c10d_functional", "IMPL") + lib_impl.impl("all_reduce", _all_reduce_meta, "Meta") + lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta") + lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta") + lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta") + lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta") + lib_impl.impl( + "all_gather_into_tensor_out", _all_gather_into_tensor_out_native_meta, "Meta" + ) + lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta") + lib_impl.impl( + "all_gather_into_tensor_coalesced", + _all_gather_into_tensor_coalesced_native_meta, + "Meta", + ) + lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta") + lib_impl.impl( + "reduce_scatter_tensor_coalesced", + _reduce_scatter_tensor_coalesced_native_meta, + "Meta", + ) + lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta") + lib_impl.impl("broadcast", _broadcast_meta, "Meta") + lib_impl.impl("broadcast_", _broadcast__meta, "Meta") + + # mark these ops has side effect so that they won't be removed by DCE + torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default) + torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor) + + # Register legacy ops for backward compatibility + # TODO(yifu): remove these in functional collective beta release + legacy_lib = torch.library.Library("c10d_functional", "DEF") + legacy_lib_impl = torch.library.Library("c10d_functional", "IMPL") + ops_defs = [ + "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor", + "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", + "all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", + "wait_tensor(Tensor self) -> Tensor", + "all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor", + "all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]", + "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", + "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", + "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor", # noqa: B950 + ] + + my_module = sys.modules[__name__] + for op_def in ops_defs: + op_name = op_def[0 : op_def.index("(")] + backend_impl = getattr(fun_col_impl, f"_{op_name}") + legacy_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag) + legacy_lib_impl.impl(op_name, backend_impl, "CompositeImplicitAutograd") + +else: + warnings.warn( + "PyTorch Distributed functional collectives do not work with torch::deploy." + ) + + +""" +Dynamo Remappings allow seamless translation from non-functional collectives of supportable form into +functional collective calls followed by inplace copy ops, allowing them to be traced into a functional graph. + +We implement this by writing a decomposition and teaching dynamo how to associate it to a corresponding op via +the mapping dict below. 
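+
+For example (an illustrative sketch, with ``t`` a tensor and ``pg`` a process group):
+a call to ``torch.distributed.all_reduce(t, group=pg)`` can be remapped to ``all_reduce_inplace``
+below, which lowers to ``t.copy_(all_reduce(t, "sum", pg))``, i.e. a functional all_reduce
+followed by an inplace copy that can be captured in a functional graph.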
+ +These schemas intentionally match torch.distributed.distributed_c10d.* ops that we are trying to remap from +""" + + +def all_gather_tensor_inplace( + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + group, # TODO add a type, + async_op: bool = False, + tag: str = "", + gather_dim: int = 0, +): + assert ( + not async_op + ), "Can't remap async version of inplace op to functional collective" + + group = group or dist.group.WORLD + assert group is not None + + return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag)) + + +def reduce_scatter_tensor_inplace( + output: torch.Tensor, + input: torch.Tensor, + op: str = "sum", # TODO type is actually c10d ReduceOp. is this ok? + group=None, # TODO add a type + async_op: bool = False, + scatter_dim: int = 0, + tag: str = "", +): + assert ( + not async_op + ), "Can't remap async version of inplace op to functional collective" + + group = group or dist.group.WORLD + assert group is not None + + return output.copy_(reduce_scatter_tensor(input, op, scatter_dim, group, tag)) + + +REDUCE_OP_TO_STR = { + dist.ReduceOp.SUM: "sum", + dist.ReduceOp.AVG: "avg", + dist.ReduceOp.PRODUCT: "product", + dist.ReduceOp.MIN: "min", + dist.ReduceOp.MAX: "max", + dist.ReduceOp.BAND: "band", + dist.ReduceOp.BOR: "bor", + dist.ReduceOp.BXOR: "bxor", +} + + +def all_reduce_inplace( + tensor: torch.Tensor, + op: str = "sum", + group=None, + async_op: bool = False, + tag: str = "", +): + assert ( + not async_op + ), "Can't remap async version of inplace op to functional collective" + + group = group or dist.group.WORLD + assert group is not None + + return tensor.copy_(all_reduce(tensor, op, group, tag)) + + +def all_to_all_inplace( + output: torch.Tensor, + input: torch.Tensor, + output_split_sizes=None, + input_split_sizes=None, + group=None, + async_op=False, + tag: str = "", +): + assert ( + not async_op + ), "Can't remap async version of inplace op to functional collective" + + group = group or dist.group.WORLD + assert group is not None + + return output.copy_( + all_to_all_single( + input, + output_split_sizes, + input_split_sizes, + group, + tag, + ) + ) + + +def all_gather_inplace( + tensor_list: List[torch.Tensor], + tensor: torch.Tensor, + group=None, + async_op=False, + tag: str = "", +): + assert ( + not async_op + ), "Can't remap async version of inplace op to functional collective" + assert all( + t.size(0) == tensor.size(0) for t in tensor_list + ), "Remapping variable size all_gather is not yet supported" + + group = group or dist.group.WORLD + assert group is not None + + output = all_gather_tensor(tensor, 0, group, tag) + + # Use aten.slice instead of aten.split because the latter causes + # tensor.shape(0) to be unnecessarily baked in when it's a SymInt. + output_splits = [] + offset = 0 + for t in tensor_list: + output_splits.append(output[offset : offset + t.size(0)]) + offset += t.size(0) + for dst, src in zip(tensor_list, output_splits): + dst.copy_(src) + return tensor_list + + +from torch.distributed.distributed_c10d import ( + _all_gather_base as legacy_all_gather_base, + _reduce_scatter_base as legacy_reduce_scatter_base, + all_gather as legacy_all_gather, + all_gather_into_tensor as legacy_allgather, + all_reduce as legacy_allreduce, + all_to_all_single as legacy_all_to_all_single, + reduce_scatter_tensor as legacy_reducescatter, +) + + +# This dict should contain sets of functions that dynamo is allowed to remap. +# Functions in this set should accept the same args/kwargs 1:1 as their mapping. 
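+# (Illustrative note: REDUCE_OP_TO_STR above maps ``c10d.ReduceOp`` values to the string names
+# the functional wrappers in this file accept, so a remapped call that received a ``ReduceOp``
+# can be translated to e.g. "sum".)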
+traceable_collective_remaps = { + legacy_allgather: all_gather_tensor_inplace, + legacy_reducescatter: reduce_scatter_tensor_inplace, + legacy_allreduce: all_reduce_inplace, + legacy_all_to_all_single: all_to_all_inplace, + legacy_all_gather: all_gather_inplace, + legacy_reduce_scatter_base: reduce_scatter_tensor_inplace, + legacy_all_gather_base: all_gather_tensor_inplace, +} diff --git a/lib/python3.10/site-packages/torch/distributed/_functional_collectives_impl.py b/lib/python3.10/site-packages/torch/distributed/_functional_collectives_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..4bd193d662bd6de80e87707455d0d1916410bf84 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributed/_functional_collectives_impl.py @@ -0,0 +1,117 @@ +# mypy: allow-untyped-defs +from typing import List, Optional + +import torch +import torch.distributed.distributed_c10d as c10d + + +""" +This file contains the op impls for the legacy (c10d_functional) functional collectives. +These impls simply call into the native (_c10d_functional) functional collectives. +""" + + +def _broadcast(input, src, tag, ranks, group_size): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return torch.ops._c10d_functional.broadcast( + input, + src, + group_name, + ) + + +def _all_reduce(input, reduce_op, tag, ranks, group_size): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return torch.ops._c10d_functional.all_reduce( + input, + reduce_op, + group_name, + ) + + +def _all_reduce_coalesced(inputs, reduce_op, tag, ranks, group_size): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return torch.ops._c10d_functional.all_reduce_coalesced( + inputs, + reduce_op, + group_name, + ) + + +def _all_gather_into_tensor(input, tag, ranks, group_size): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return torch.ops._c10d_functional.all_gather_into_tensor( + input, + group_size, + group_name, + ) + + +def _all_gather_into_tensor_coalesced(input, tag, ranks, group_size): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return torch.ops._c10d_functional.all_gather_into_tensor_coalesced( + input, + group_size, + group_name, + ) + + +def _reduce_scatter_tensor( + input: torch.Tensor, + reduce_op: str, + tag: str, + ranks: List[int], + group_size: int, +): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return torch.ops._c10d_functional.reduce_scatter_tensor( + input, + reduce_op, + group_size, + group_name, + ) + + +def _reduce_scatter_tensor_coalesced( + inputs: List[torch.Tensor], + reduce_op: str, + tag: str, + ranks: List[int], + group_size: int, +): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return torch.ops._c10d_functional.reduce_scatter_tensor_coalesced( + inputs, + reduce_op, + group_size, + group_name, + ) + + +def _all_to_all_single( + input: torch.Tensor, + output_split_sizes: Optional[List[int]], + input_split_sizes: Optional[List[int]], + tag: str, + ranks: List[int], + group_size: int, +): + if output_split_sizes is None or input_split_sizes is None: + assert output_split_sizes is None and input_split_sizes is None, ( + "output_split_sizes and input_split_sizes must either be " + "specified together or both set to None" + ) + output_split_sizes = [input.shape[0] // group_size] * group_size + input_split_sizes = output_split_sizes + + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return 
torch.ops._c10d_functional.all_to_all_single( + input, + output_split_sizes, + input_split_sizes, + group_name, + ) + + +def _wait_tensor(tensor: torch.Tensor) -> torch.Tensor: + return torch.ops._c10d_functional.wait_tensor(tensor) diff --git a/lib/python3.10/site-packages/torch/distributed/_state_dict_utils.py b/lib/python3.10/site-packages/torch/distributed/_state_dict_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2c3a021c0a318796b89cb7021f95b8371762e847 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributed/_state_dict_utils.py @@ -0,0 +1,753 @@ +# mypy: allow-untyped-defs +import copy +import io +import math +import weakref +from typing import ( + Any, + Callable, + cast, + Dict, + List, + Mapping, + MutableMapping, + NamedTuple, + Optional, + Tuple, + TYPE_CHECKING, + Union, +) + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.distributed._functional_collectives import AsyncCollectiveTensor + + +if dist.is_available() or TYPE_CHECKING: + from torch.distributed import distributed_c10d + from torch.distributed._shard.sharded_tensor import ShardedTensor + from torch.distributed.tensor import distribute_tensor, DTensor, Replicate + from torch.distributed.tensor._utils import compute_local_shape_and_global_offset + + +def _identity_func( + obj: torch.Tensor, + pg: Optional[dist.ProcessGroup], + device: Optional[torch.device], + companion_obj: Any, +) -> torch.Tensor: + return obj + + +def _all_gather_sharded_tensor( + sharded_tensor: "ShardedTensor", + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, +) -> torch.Tensor: + if pg is None: + pg = distributed_c10d._get_default_group() + world_size = dist.get_world_size(pg) + shards = sharded_tensor.local_shards() + dim_0_size = sharded_tensor.size()[0] # type: ignore[index] + tensor_numel = sharded_tensor.size().numel() # type: ignore[union-attr] + chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size + pg_device = ( + distributed_c10d._get_pg_default_device(pg) if device is None else device + ) + if shards: + local_tensor = shards[0].tensor.flatten() + if local_tensor.device.type != pg_device.type: + local_tensor = local_tensor.to(pg_device) + num_padding = chunk_size - local_tensor.numel() + if num_padding > 0: + local_tensor = F.pad(local_tensor, [0, num_padding]) + else: + local_tensor = torch.zeros( + chunk_size, dtype=sharded_tensor.dtype, device=pg_device + ) + + tensor = torch.empty( + chunk_size * world_size, + dtype=local_tensor.dtype, + device=pg_device, + ) + dist.all_gather_into_tensor(tensor, local_tensor, group=pg) + + tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size()) + return tensor + + +class CompanionMismatch(Exception): + ... + + +def _iterate_state_dict( + iter_object: Any, + sharded_tensor_func: Callable, + dtensor_func: Callable, + tensor_func: Callable, + *, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + cpu_offload: bool = False, + companion_obj: Any = None, + ranks_only: Tuple[int, ...] = (), + type_check: bool = True, + non_blocking: bool = True, +) -> Dict[str, Any]: + """Iterate through the state dict, applying the given functions to each tensor type. + + Args: + iter_object (Any): the target state_dict. 
+ sharded_tensor_func (Callable): the function to apply to ShardedTensor + dtensor_func (Callable): the function to apply to DTensor + tensor_func (Callable): the function to apply to Tensor + pg (Optional[dist.ProcessGroup]): process group passed to tensor functions + device (Optional[torch.device]): device passed to tensor functions + cpu_offload (bool): whether to offload the tensors to CPU memory. This option is ignored + if a companion_obj is supplied. + companion_obj (Any): A companion object to the state dict. If this object + is supplied, we attempt to copy the tensor to the companion object. + ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will + have the same state_dicts. Otherwise only ranks that in ``ranks_only`` + have the same state_dicts. Other ranks will get empty state_dicts. + type_check (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + non_blocking (bool): whether to use non-blocking copy when copying to the companion object. + """ + # TODO: should we use pytree? + cpu_device = torch.device("cpu") + if isinstance(iter_object, ShardedTensor): + ret = sharded_tensor_func(iter_object, pg, device, companion_obj) + elif isinstance(iter_object, DTensor): + ret = dtensor_func(iter_object, pg, device, companion_obj) + elif isinstance(iter_object, torch.Tensor): + ret = tensor_func(iter_object, pg, device, companion_obj) + elif ( + isinstance(iter_object, (int, float, str, bytes, io.BytesIO)) + or iter_object is None + ): + ret = iter_object + elif isinstance(iter_object, dict): + if companion_obj is not None and ( + not isinstance(companion_obj, dict) + or set(companion_obj.keys()) != set(iter_object.keys()) + ): + msg = ( + "" + if isinstance(companion_obj, dict) + else f"{set(companion_obj.keys())=} {set(iter_object.keys())=}" + ) + raise CompanionMismatch(msg) + + ret = { + key: _iterate_state_dict( + value, + sharded_tensor_func, + dtensor_func, + tensor_func, + pg=pg, + device=device, + cpu_offload=cpu_offload, + companion_obj=companion_obj[key] if companion_obj is not None else None, + ranks_only=ranks_only, + type_check=type_check, + non_blocking=non_blocking, + ) + for key, value in iter_object.items() + } + elif isinstance(iter_object, (list, tuple)): + if companion_obj is not None and ( + not isinstance(companion_obj, (list, tuple)) + or len(companion_obj) != len(iter_object) + ): + raise CompanionMismatch + + ret = [ + _iterate_state_dict( + v, + sharded_tensor_func, + dtensor_func, + tensor_func, + pg=pg, + device=device, + cpu_offload=cpu_offload, + companion_obj=companion_obj[idx] if companion_obj is not None else None, + ranks_only=ranks_only, + type_check=type_check, + non_blocking=non_blocking, + ) + for idx, v in enumerate(iter_object) + ] + if isinstance(iter_object, tuple): + ret = tuple(ret) + elif not type_check: + ret = copy.deepcopy(iter_object) + else: + raise ValueError(f"Unexpected value type {type(iter_object)}") + + if not ranks_only or dist.get_rank(pg) in ranks_only: + if isinstance(ret, torch.Tensor): + if cpu_offload and companion_obj is None: + ret = ret.to(cpu_device) + + if companion_obj is not None: + # TODO: support DTensor + companion_obj.copy_(ret, non_blocking=non_blocking) + ret = companion_obj + else: + ret = {} if isinstance(ret, dict) else None + + return ret + + +def _gather_state_dict( + state_dict: Dict[str, Any], + *, + pg: Optional[dist.ProcessGroup] = None, + device: 
Optional[torch.device] = None, + cpu_offload: bool = False, + ranks_only: Tuple[int, ...] = (), + type_check: bool = True, +) -> Dict[str, Any]: + """ + Given a state_dict, this API gathers all the ShardedTensors or DTensors in + the state_dict. + + + Args: + state_dict (Dict[str, Any]): the target sharded state_dict. + pg (Optional[dist.ProcessGroup]): the process group that is used to + gather ShardedTensor. Note that gathering a DTensor will use + the DeviceMesh. So this argument will be ignored when gathering a + DTensor. + device: (Optional[torch.device]): the device that is used to + perform allgather for ShardedTensor. Note that gathering a DTensor + will use the DeviceMesh. So this argument will be ignored when + gathering a DTensor. + cpu_offload (bool): whether to offload the tensors to CPU memory. The + default value is False. + ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will + have the same state_dicts. Otherwise only ranks that in ``ranks_only`` + have the same state_dicts. Other ranks will get empty state_dicts. + type_check: (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + + Returns: + The gathered state dictionary. + """ + + def sharded_tensor_func(value, pg, device, companion_obj): + # ShardedTensor does not seem to record the original device type. + # So if the tensor is moved to CPU, we won't know the original type. + # As a result, we have to rely on the user to tell us the correct one. + cpu_device = torch.device("cpu") + output_tensor = _all_gather_sharded_tensor(value, pg, device) + local_shard_device = ( + value.local_shards()[0].tensor.device + if value.local_shards() + else cpu_device + ) + if output_tensor.device != local_shard_device: + value = output_tensor.to(local_shard_device) + else: + value = output_tensor + return value + + def dtensor_func(value, pg, device, companion_obj): + if value.device != value.device_mesh.device_type: + value = value.to(value.device_mesh.device_type) + # FSDP all_gather: [Shard(0)] -> [Replicate()] + # HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()] + # 2D FSDP + TP all_gather: + # - [Shard(0), Shard(n)] -> [Replicate(), Replicate()] + # - [Shard(0), Replicate()] -> [Replicate(), Replicate()] + placements = [Replicate() for _ in value.placements] + value = value.redistribute( + device_mesh=value.device_mesh, + placements=placements, + ) + # Call `wait()` to force the tensor to be synchronous with respect + # to the main stream. + # See the discussion in https://github.com/pytorch/pytorch/pull/117799. + value = value.to_local() + if isinstance(value, AsyncCollectiveTensor): + value = value.wait() + return value + + return _iterate_state_dict( + state_dict, + sharded_tensor_func, + dtensor_func, + _identity_func, + pg=pg, + device=device, + cpu_offload=cpu_offload, + ranks_only=ranks_only, + type_check=type_check, + ) + + +def _offload_state_dict_to_cpu( + state_dict: Dict[str, Any], + *, + ranks_only: Tuple[int, ...] = (), + type_check: bool = True, +) -> Dict[str, Any]: + """ + Given a state_dict, this API offload all the tensors to CPU memory. + + Args: + state_dict (Dict[str, Any]): the target state_dict. + pg (Optional[dist.ProcessGroup]): the process group that is used to + gather ShardedTensor. Note that gathering a DTensor will use + the DeviceMesh. So this argument will be ignored when gathering a + DTensor. 
+ ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will + have the same state_dicts. Otherwise only ranks that in ``ranks_only`` + have the same state_dicts. Other ranks will get empty state_dicts. + type_check: (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + + Returns: + The gathered state dictionary. + """ + + ret = _iterate_state_dict( + state_dict, + _identity_func, + _identity_func, + _identity_func, + pg=None, + device=None, + cpu_offload=True, + ranks_only=ranks_only, + type_check=type_check, + ) + return ret + + +def _copy_state_dict( + state_dict: Dict[str, Any], + copy_state_dict: Dict[str, Any], + non_blocking: bool = False, + type_check: bool = True, +) -> Dict[str, Any]: + """ + Copies all tensors in a given state dict into a different state_dict with the + same structure. Additionally, a copied state dict with the same value references + is returned. Editing the keys on this state dict will not affect the + passed in copy_state_dict (but the value references are the same). + + .. warning:: + It is expected by this function that state_dict and copy_state_dict share + the same structure and data types. + + .. warning:: + The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + + Args: + state_dict (Dict[str, Any]): the target state_dict. + copy_state_dict (Dict[str, Any]): + The state dict we are copying into. This state_dict must have exactly + the same structure as the source `state_dict`. + non_blocking: (bool): Whether copy ops should be performed asynchronously + type_check (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + + Returns: + State Dict copy + """ + + return _iterate_state_dict( + state_dict, + _identity_func, + _identity_func, + _identity_func, + pg=None, + device=None, + cpu_offload=False, + ranks_only=(), + companion_obj=copy_state_dict, + type_check=type_check, + non_blocking=non_blocking, + ) + + +def _create_cpu_state_dict( + state_dict: Dict[str, Any], pin_memory: bool = False, share_memory: bool = False +) -> Dict[str, Any]: + """ + Given a state_dict, create another state_dict with the same structure and elements. + However, all tensors in the returned state_dict are new tensors on CPU. These + tensors can be placed on pin_memory or share_memory based on the provided arguments. + + .. warning:: + Setting both `pin_memory` and `share_memory` to True significantly increases the + latency of this method because of the nuances which require us to register memory + as pinned directly as opposed to relying on the pin_memory cache allocator. This + option should only be used for long lived tensors which are required to be shared. + This is not the case as long as at least one of `pin_memory` or `share_memory` is + set to False. 
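+
+    Example (an illustrative sketch; ``model`` is assumed to be an ``nn.Module``)::
+
+        cpu_state_dict = _create_cpu_state_dict(model.state_dict(), pin_memory=True)
+        # the preallocated CPU buffers can then be filled with _copy_state_dict
+        _copy_state_dict(model.state_dict(), cpu_state_dict, non_blocking=False)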
+ + """ + + def tensor_func( + obj: torch.Tensor, + pg: Optional[dist.ProcessGroup], + device: Optional[torch.device], + _: Any, + ) -> torch.Tensor: + if len(obj.size()) == 0: + return torch.tensor(0, dtype=obj.dtype) + + if share_memory: + t = torch.empty(*tuple(obj.size()), dtype=obj.dtype) + t = t.share_memory_() + if pin_memory: + + def unpin_memory(t): + succ = int(torch.cuda.cudart().cudaHostUnregister(t.data_ptr())) + assert ( + succ == 0 + ), f"Unpinning shared memory failed with error-code: {succ}" + + weakref.finalize(t, unpin_memory, t) + succ = int( + torch.cuda.cudart().cudaHostRegister( + t.data_ptr(), + t.numel() * t.element_size(), + 1, # lines up with 'cudaHostRegisterPortable' + ) + ) + assert ( + succ == 0 + ), f"Pinning shared memory failed with error-code: {succ}" + return t + elif pin_memory: + return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory() + else: + return torch.empty(*tuple(obj.size()), dtype=obj.dtype) + + ret = _iterate_state_dict( + state_dict, + _identity_func, + _identity_func, + tensor_func, + pg=None, + device=None, + cpu_offload=False, + ranks_only=(), + type_check=False, + ) + return ret + + +def _check_state_dict_similarity( + state_dict: Dict[str, Any], + compared_state_dict: Dict[str, Any], +) -> bool: + """ + Given two state_dicts, check if the structures are the same. And + if a [key, tensor] pair exist in one state_dict there must be + the a corresponding pait, [key, other_tensor], in the other state_dict, + where tensor and other_tensor have the same size and dtype. + + Return the check result. + """ + + def tensor_func( + obj: torch.Tensor, + pg: Optional[dist.ProcessGroup], + device: Optional[torch.device], + companion_obj: Any, + ) -> torch.Tensor: + if companion_obj.dtype != obj.dtype or companion_obj.size() != obj.size(): + raise CompanionMismatch + return obj + + try: + _iterate_state_dict( + state_dict, + _identity_func, + _identity_func, + tensor_func, + pg=None, + device=None, + cpu_offload=False, + ranks_only=(), + companion_obj=compared_state_dict, + type_check=False, + ) + except CompanionMismatch: + return False + + return True + + +class _TensorInfo(NamedTuple): + size: torch.Size + dtype: torch.dtype + + +def _broadcast_tensors( + full_state_dict: Dict[str, Any], + local_state_dict: Dict[str, Any], + keys: List[str], + device: torch.device, + pg: Optional[dist.ProcessGroup] = None, +) -> None: + tensors = [] + for key in keys: + if dist.get_rank() == 0: + full_state = full_state_dict[key] + assert isinstance(full_state, torch.Tensor) + full_tensor = full_state.detach().to(device) + else: + tensor_info = full_state_dict[key] + full_tensor = torch.empty( + size=tensor_info.size, + device=device, + dtype=tensor_info.dtype, + ) + + tensors.append(full_tensor) + local_state = local_state_dict.get(key, None) + if local_state is None: + continue + elif isinstance(local_state, DTensor): + local_state_dict[key] = (local_state, full_tensor) + else: + local_state_dict[key] = full_tensor + + if pg is None: + pg = dist.distributed_c10d._get_default_group() + + if len(tensors) > 1: + dist._broadcast_coalesced(pg, tensors, 500, 0) + else: + dist.broadcast(tensors[0], src=0, group=pg) + + _distribute_tensors(local_state_dict, keys, device, pg) + + +def _distribute_tensors( + local_state_dict: Dict[str, Any], + keys: List[str], + device: torch.device, + pg: Optional[dist.ProcessGroup] = None, +) -> None: + if pg is None: + pg = dist.distributed_c10d._get_default_group() + for key in keys: + _local_state = 
local_state_dict.get(key, None) + if _local_state is None or torch.is_tensor(_local_state): + continue + + local_state = _local_state[0] + full_tensor = _local_state[1] + + shape, offset = compute_local_shape_and_global_offset( + full_tensor.shape, local_state.device_mesh, local_state.placements + ) + slices = [slice(offset[i], shape[i] + offset[i]) for i in range(len(shape))] + local_tensor = full_tensor[slices] + # TODO: currently, we cannot handle strided sharding if the dp dimension is not even. For example, + # one of the case that is not yet supported is when placements = (Shard(0), _StridedShard(0, sf=2)). + local_state_dict[key] = DTensor.from_local( + local_tensor, + local_state.device_mesh, + local_state.placements, + shape=local_state.shape, + stride=local_state.stride(), + ) + + +def _broadcast_state_dict( + full_state_dict: Dict[str, Any], + local_state_dict: Dict[str, Any], + device: torch.device, + pg: Optional[dist.ProcessGroup] = None, + strict: bool = False, +) -> None: + # Broadcast from rank0's `full_state_dict` to all ranks' `local_state_dict`. + # If strict is True, any keys in `local_state_dict` but not in `full_state_dict` + # will be removed from `local_state_dict`. + ret = {} + if dist.get_rank() == 0: + for key, value in full_state_dict.items(): + if not torch.is_tensor(value): + ret[key] = value + elif value.dim() == 0: + ret[key] = value.cpu() + else: + ret[key] = _TensorInfo(value.size(), value.dtype) + + broadcast_list = [ret] + dist.broadcast_object_list(broadcast_list, src=0, group=pg) + ret = broadcast_list[0] + + # Gather values + keys = [] + local_state_dict_keys = set(local_state_dict.keys()) + global_keys = set() + for key, value in ret.items(): + global_keys.add(key) + if not isinstance(value, _TensorInfo): + if key in local_state_dict: + local_state_dict[key] = value + continue + + if dist.get_rank() == 0: + ret[key] = full_state_dict[key] + + keys.append(key) + # Broadcast every tensor to avoid OOM for now. + if len(keys) >= 1: + _broadcast_tensors(ret, local_state_dict, keys, device, pg) + keys.clear() + + if strict: + if missing_keys := (local_state_dict_keys - global_keys): + for key in missing_keys: + local_state_dict.pop(key) + + if keys: + _broadcast_tensors(ret, local_state_dict, keys, device, pg) + + +def _distribute_state_dict( + full_state_dict: Dict[str, Any], + local_state_dict: Dict[str, Any], + device: torch.device, + pg: Optional[dist.ProcessGroup] = None, +) -> None: + # Full_state_dict = True, broadcast_from_rank0 = False here. Each rank has + # full_state_dict. Skip the broadcast in ``_broadcast_state_dict`` and + # distribute tensors in each rank + for key, value in full_state_dict.items(): + if key not in full_state_dict: + continue + if not torch.is_tensor(value): + local_state_dict[key] = value + elif value.dim() == 0: + local_state_dict[key] = value.cpu() + else: + assert isinstance(value, torch.Tensor) + local_state = local_state_dict.get(key, None) + if local_state is None: + continue + elif isinstance(local_state, DTensor): + local_state_dict[key] = distribute_tensor( + value.detach().to(device), + local_state.device_mesh, + local_state.placements, + ) + else: + local_state_dict[key] = value.detach().to(device) + + +# These APIs are from torch.distributed.checkpoint. +# TODO: We should consolidate the code here as some not all modules can depend on +# DCP. +PATH_ITEM = Union[str, int] +OBJ_PATH = Tuple[PATH_ITEM, ...] 
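+# For example (illustrative): the nested state_dict ``{"a": {"b": ...}}`` is addressed by the
+# object path ``("a", "b")``, which ``_flatten_state_dict`` below joins into the flat key ``"a.b"``.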
+FLATTEN_MAPPING = Dict[str, OBJ_PATH] +STATE_DICT_TYPE = Dict[str, Any] +CONTAINER_TYPE = MutableMapping[PATH_ITEM, Any] + + +def _traverse_state_dict( + state_dict: STATE_DICT_TYPE, + visitor: Callable[[OBJ_PATH, Any], None], +) -> None: + """ + Invoke ``visitor`` for each value recursively in ``state_dict``. + Mapping, list, and tuple will be flattened and other value types are treated + as the terminal values and will invoke ``visitor``. + """ + + def _traverse_obj(path: OBJ_PATH, value: Any) -> None: + if isinstance(value, Mapping): + for k, v in value.items(): + _traverse_obj(path + (str(k),), v) + elif isinstance(value, (list, tuple)): + for i, v in enumerate(value): + _traverse_obj(path + (i,), v) + else: + visitor(path, value) + + for key, value in state_dict.items(): + _traverse_obj((str(key),), value) + + +def _flatten_state_dict( + state_dict: STATE_DICT_TYPE, +) -> Tuple[STATE_DICT_TYPE, FLATTEN_MAPPING]: + """ + Flatten ``state_dict`` made of nested dicts and lists into a top level dictionary. + + Use ``unflatten_state_dict`` to revert this process. + Returns: + A tuple with the flatten state_dict and a mapping from original to new state_dict. + N.B. The new keys are derived from the object paths, joined by dot. + For example: ``{ 'a': {'b':...}}`` results in the key `a.b`. + """ + flattened: STATE_DICT_TYPE = {} + mappings: FLATTEN_MAPPING = {} + + def flat_copy(path: OBJ_PATH, value: Any) -> None: + new_fqn = ".".join(map(str, path)) + if new_fqn in flattened: + raise ValueError(f"duplicated flatten key {new_fqn}") + flattened[new_fqn] = value + mappings[new_fqn] = path + + _traverse_state_dict(state_dict, flat_copy) + return flattened, mappings + + +def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None: + """Set ``value`` in ``root_dict`` along the ``path`` object path.""" + cur_container = cast(CONTAINER_TYPE, root_dict) + + def extend_list(lst: List[Any], idx: int) -> None: + while len(lst) <= idx: + lst.append(None) + + for i in range(1, len(path)): + prev_key = path[i - 1] + key = path[i] + def_val: Union[CONTAINER_TYPE, List[Any]] = {} if type(key) == str else [] + + if isinstance(cur_container, Mapping): + cur_container = cast( + CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val) + ) + else: + extend_list(cur_container, prev_key) + if cur_container[prev_key] is None: + cur_container[prev_key] = def_val + cur_container = cur_container[prev_key] + + key = path[-1] + if type(key) == int: + extend_list(cast(List[Any], cur_container), key) + + cur_container[key] = value + + +def _unflatten_state_dict( + state_dict: STATE_DICT_TYPE, mapping: FLATTEN_MAPPING +) -> STATE_DICT_TYPE: + """Restore the original nested state_dict according to ``mapping`` and the flattened ``state_dict``.""" + nested: STATE_DICT_TYPE = {} + for key, value in state_dict.items(): + _set_element(nested, mapping[key], value) + return nested diff --git a/lib/python3.10/site-packages/torch/distributed/argparse_util.py b/lib/python3.10/site-packages/torch/distributed/argparse_util.py new file mode 100644 index 0000000000000000000000000000000000000000..c475eebf21273abb53ab99e3edcbdef18e9f0c8f --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributed/argparse_util.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+import os +from argparse import Action + + +class env(Action): + """ + Get argument values from ``PET_{dest}`` before defaulting to the given ``default`` value. + + For flags (e.g. ``--standalone``) + use ``check_env`` instead. + + .. note:: when multiple option strings are specified, ``dest`` is + the longest option string (e.g. for ``"-f", "--foo"`` + the env var to set is ``PET_FOO`` not ``PET_F``) + + Example: + :: + + parser.add_argument("-f", "--foo", action=env, default="bar") + + ./program -> args.foo="bar" + ./program -f baz -> args.foo="baz" + ./program --foo baz -> args.foo="baz" + PET_FOO="env_bar" ./program -f baz -> args.foo="baz" + PET_FOO="env_bar" ./program --foo baz -> args.foo="baz" + PET_FOO="env_bar" ./program -> args.foo="env_bar" + + parser.add_argument("-f", "--foo", action=env, required=True) + + ./program -> fails + ./program -f baz -> args.foo="baz" + PET_FOO="env_bar" ./program -> args.foo="env_bar" + PET_FOO="env_bar" ./program -f baz -> args.foo="baz" + """ + + def __init__(self, dest, default=None, required=False, **kwargs) -> None: + env_name = f"PET_{dest.upper()}" + default = os.environ.get(env_name, default) + + # ``required`` means that it NEEDS to be present in the command-line args + # rather than "this option requires a value (either set explicitly or default" + # so if we found default then we don't "require" it to be in the command-line + # so set it to False + if default: + required = False + + super().__init__(dest=dest, default=default, required=required, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, values) + + +class check_env(Action): + """ + Check whether the env var ``PET_{dest}`` exists before defaulting to the given ``default`` value. + + Equivalent to + ``store_true`` argparse built-in action except that the argument can + be omitted from the commandline if the env var is present and has a + non-zero value. + + .. note:: it is redundant to pass ``default=True`` for arguments + that use this action because a flag should be ``True`` + when present and ``False`` otherwise. + + Example: + :: + + parser.add_argument("--verbose", action=check_env) + + ./program -> args.verbose=False + ./program --verbose -> args.verbose=True + PET_VERBOSE=1 ./program -> args.verbose=True + PET_VERBOSE=0 ./program -> args.verbose=False + PET_VERBOSE=0 ./program --verbose -> args.verbose=True + + Anti-pattern (don't do this): + + :: + + parser.add_argument("--verbose", action=check_env, default=True) + + ./program -> args.verbose=True + ./program --verbose -> args.verbose=True + PET_VERBOSE=1 ./program -> args.verbose=True + PET_VERBOSE=0 ./program -> args.verbose=False + + """ + + def __init__(self, dest, default=False, **kwargs) -> None: + env_name = f"PET_{dest.upper()}" + default = bool(int(os.environ.get(env_name, "1" if default else "0"))) + super().__init__(dest=dest, const=True, default=default, nargs=0, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, self.const) diff --git a/lib/python3.10/site-packages/torch/distributed/c10d_logger.py b/lib/python3.10/site-packages/torch/distributed/c10d_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..162cb62f992fdc14248e6d2b6e318e4701de4002 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributed/c10d_logger.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. 
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import functools +import logging +import time +from typing import Any, Callable, Dict, List, Tuple, TypeVar +from typing_extensions import ParamSpec + +import torch +import torch.distributed as dist +from torch.distributed.logging_handlers import _log_handlers + + +__all__: List[str] = [] + +_DEFAULT_DESTINATION = "default" + + +def _get_or_create_logger(destination: str = _DEFAULT_DESTINATION) -> logging.Logger: + logging_handler, log_handler_name = _get_logging_handler(destination) + logger = logging.getLogger(f"c10d-{log_handler_name}") + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter( + "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s" + ) + logging_handler.setFormatter(formatter) + logger.propagate = False + logger.addHandler(logging_handler) + return logger + + +def _get_logging_handler( + destination: str = _DEFAULT_DESTINATION, +) -> Tuple[logging.Handler, str]: + log_handler = _log_handlers[destination] + log_handler_name = f"{type(log_handler).__name__}-{destination}" + return (log_handler, log_handler_name) + + +global _c10d_logger +_c10d_logger = _get_or_create_logger() + + +def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]: + if dist.is_initialized(): + group = kwargs.get("group") or kwargs.get("process_group") + msg_dict = { + "func_name": f"{func_name}", + "args": f"{args}, {kwargs}", + "pg_name": f"{dist._get_process_group_name(kwargs.get('pg'))}", # type: ignore[arg-type] + "backend": f"{dist.get_backend(group)}", + "world_size": f"{dist.get_world_size()}", + "group_size": f"{dist.get_world_size(group)}", + "global_rank": f"{dist.get_rank()}", + "local_rank": f"{dist.get_rank(group)}", + } + if msg_dict["backend"] == "nccl": + nccl_version = torch.cuda.nccl.version() + msg_dict["nccl_version"] = ".".join(str(v) for v in nccl_version) + else: + msg_dict = { + "func_name": f"{func_name}", + "args": f"{args}, {kwargs}", + } + return msg_dict + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + + +def _exception_logger(func: Callable[_P, _T]) -> Callable[_P, _T]: + @functools.wraps(func) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + try: + return func(*args, **kwargs) + except Exception as error: + msg_dict = _get_msg_dict(func.__name__, *args, **kwargs) + msg_dict["error"] = f"{error}" + _c10d_logger.debug(msg_dict) + raise + + return wrapper + + +def _time_logger(func: Callable[_P, _T]) -> Callable[_P, _T]: + @functools.wraps(func) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + t1 = time.time_ns() + func_return = func(*args, **kwargs) + time_spent = time.time_ns() - t1 + + msg_dict = _get_msg_dict(func.__name__, *args, **kwargs) + msg_dict["time_spent"] = f"{time_spent}ns" + _c10d_logger.debug(msg_dict) + + return func_return + + return wrapper diff --git a/lib/python3.10/site-packages/torch/distributed/collective_utils.py b/lib/python3.10/site-packages/torch/distributed/collective_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..78199e7a26f22cee28564fb66fd8bbdbc8a1defa --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributed/collective_utils.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 + + +""" +A set of primitive functions for performing collective ops. + +Each should also handle single rank scenario. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, cast, Generic, List, Optional, Tuple, TypeVar, Union + +import torch.distributed as dist + + +T = TypeVar("T") + + +@dataclass +class SyncPayload(Generic[T]): + stage_name: Optional[str] + success: bool + payload: T + exception: Optional[Exception] = None + + +def broadcast( + data_or_fn: Union[T, Callable[[], T]], + *, + success: bool = True, + stage_name: Optional[str] = None, + rank: int = 0, + pg: Optional[dist.ProcessGroup] = None, +) -> T: + """ + Broadcasts the data payload from rank 0 to all other ranks. + Or if a function is passed, execute it in rank 0 and broadcast result to all other ranks. + + Can be used to broadcast a failure signal to stop all ranks. + + If the function raises an exception, all ranks will raise. + + Args: + data_or_fn: the data to broadcast or function to execute and broadcast result. + success: False to stop all ranks. + stage_name: the name of the logical stage for synchronization and debugging + rank: rank to broadcast data or execute function and broadcast resutls. + pg: the process group for sync + Throws: + RuntimeError from original exception trace + Returns: + the value after synchronization + + Example usage: + >> id = broadcast(data_or_fn=allocate_id, rank=0, pg=ext_pg.my_pg) + """ + + if not success and data_or_fn is not None: + raise AssertionError( + "Data or Function is expected to be None if not successful" + ) + + payload: Optional[T] = None + exception: Optional[Exception] = None + # if no pg is passed then execute if rank is 0 + if (pg is None and rank == 0) or (pg is not None and pg.rank() == rank): + # determine if it is an executable function or data payload only + if callable(data_or_fn): + try: + payload = data_or_fn() + except Exception as e: + success = False + exception = e + else: + payload = data_or_fn + + # broadcast the exception type if any to all ranks for failure categorization + sync_obj = SyncPayload( + stage_name=stage_name, + success=success, + payload=payload, + exception=exception, + ) + + if pg is not None: + broadcast_list = [sync_obj] + dist.broadcast_object_list(broadcast_list, src=rank, group=pg) + assert len(broadcast_list) == 1 + sync_obj = broadcast_list[0] + + # failure in any rank will trigger a throw in every rank. + if not sync_obj.success: + error_msg = f"Rank {rank} failed" + if stage_name is not None: + error_msg += f": stage {sync_obj.stage_name}" + if sync_obj.exception is not None: + error_msg += f": exception {sync_obj.exception}" + raise RuntimeError(error_msg) from sync_obj.exception + + return cast(T, sync_obj.payload) + + +def all_gather( + data_or_fn: Union[T, Callable[[], T]], + stage_name: Optional[str] = None, + pg: Optional[dist.ProcessGroup] = None, +) -> List[T]: + """ + A simple all_gather primitive with basic synchronization guard logic, + by checking payload from all ranks has the same stage name. 
+ + Args: + data_or_fn: the data to be all gathered across ranks or function to be executed + stage_name: the sync stage name for out-of-sync protection + pg: the process group for sync + Throws: + RuntimeError from original exception trace + Returns: + a list of synced data from all ranks + + Example usage: + >> all_ids = all_gather(data_or_fn=allocate_id, pg=ext_pg.my_pg) + """ + payload: Optional[T] = None + exception: Optional[Exception] = None + success = True + # determine if it is an executable function or data payload only + if callable(data_or_fn): + try: + payload = data_or_fn() + except Exception as e: + success = False + exception = e + else: + payload = data_or_fn + + sync_obj = SyncPayload( + stage_name=stage_name, + success=success, + payload=payload, + exception=exception, + ) + + if pg is not None: + # List of success/failure across all ranks. + total_list = [None] * dist.get_world_size(pg) + all_gather_object_enforce_type(pg, total_list, sync_obj) + # Each rank will throw RuntimeError in case of failure on any rank. + stage_name = cast(SyncPayload[T], total_list[0]).stage_name + exception_list: List[Tuple[int, Exception]] = [] + ret_list: List[T] = [] + error_msg: str = "" + + for i, sp in enumerate(cast(List[SyncPayload[T]], total_list)): + if sp.stage_name != stage_name: + error_msg += ( + f"Unexpected stage name received from rank {i}: {sp.stage_name} " + ) + continue + if not sp.success and sp.exception is not None: + exception_list.append((i, sp.exception)) + continue + ret_list.append(sp.payload) + + if len(exception_list) > 0: + raise RuntimeError( # type: ignore[misc] + error_msg, exception_list + ) from exception_list[0] + return ret_list + else: + if not sync_obj.success: + raise RuntimeError( + f"all_gather failed with exception {sync_obj.exception}", + ) from sync_obj.exception + return [sync_obj.payload] # type: ignore[list-item] + + +# Note: use Any for typing for now so users can pass in +# either a list of None or target type placeholders +# otherwise pyre would complain +def all_gather_object_enforce_type( + pg: dist.ProcessGroup, + # pyre-fixme[2]: Parameter must have a type that does not contain `Any` + object_list: List[Any], + # pyre-fixme[2]: Parameter must have a type other than `Any` + obj: Any, + # pyre-fixme[2]: Parameter must have a type that does not contain `Any` + type_checker: Callable[[Any, Any], bool] = lambda x, y: type(x) == type(y), +) -> None: + """ + Similar to plain all_gather_object but with additional type checking + AFTER gather is done to ensure basic consistency. + If check does not pass, all ranks will fail with exception. + + This is generally to prevent conditional logic leading to + unexpected messages being received. This is considered fatal code error, + but due to logic stacks this might happen implicitly in practice. + + The default check does not check sub type (considered different) + or covariance (considered same) but users can pass in custom checker + if more complicated check is needed. 
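+
+    Example (an illustrative sketch of a custom checker)::
+
+        # treat subclasses as compatible instead of requiring exact type equality
+        all_gather_object_enforce_type(
+            pg, object_list, obj, type_checker=lambda x, y: isinstance(y, type(x))
+        )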
+ """ + dist.all_gather_object(object_list, obj, group=pg) + + # conservative check + list_len = len(object_list) + if list_len == 0: + return + first_obj = object_list[0] + for i in range(1, list_len): + if not type_checker(first_obj, object_list[i]): + raise TypeError( + f"Object type at index {i} is {type(object_list[i])}, " + f"while first object type is {type(first_obj)}" + ) diff --git a/lib/python3.10/site-packages/torch/distributed/constants.py b/lib/python3.10/site-packages/torch/distributed/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..b3754043644b8cc96feee06a4fc1cc9fa824b648 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributed/constants.py @@ -0,0 +1,26 @@ +from datetime import timedelta +from typing import Optional + +from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT + + +__all__ = ["default_pg_timeout", "default_pg_nccl_timeout"] + +# Default process group wide timeout, if applicable. +# This only applies to the non-nccl backends +# To make an attempt at backwards compatibility with THD, we use an +# extraordinarily high default timeout, given that THD did not have timeouts. +default_pg_timeout: timedelta = _DEFAULT_PG_TIMEOUT +# Separate timeout for PGNCCL mainly becuase it's always been that way in the C++ layer, but until recently +# there was one default that applied across all backends in the python layer. +# Later, we could consider merging them back together at the c++ layer if we can align on a same value. +# (only if TORCH_NCCL_BLOCKING_WAIT or TORCH_NCCL_ASYNC_ERROR_HANDLING is set to 1). + +try: + from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT + + default_pg_nccl_timeout: Optional[timedelta] = _DEFAULT_PG_NCCL_TIMEOUT +except ImportError: + # if C++ NCCL support is not compiled, we don't have access to the default nccl value. + # if anyone is actually trying to use nccl in this state, it should error. + default_pg_nccl_timeout = None diff --git a/lib/python3.10/site-packages/torch/distributed/device_mesh.py b/lib/python3.10/site-packages/torch/distributed/device_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..d8d33ab7ebfbea58b56b9e86b92934f9c1ff9fa9 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributed/device_mesh.py @@ -0,0 +1,953 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import logging +import math +import threading +from functools import reduce +from itertools import chain +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union + +import torch +from torch.distributed import is_available +from torch.utils._typing_utils import not_none + + +__all__ = ["init_device_mesh", "DeviceMesh"] + + +if not is_available(): + import sys + + # We need to create the stubs when distributed is not available. + # Otherwise, we would fail the doc tests (```./.ci/pytorch/docs-test.sh```), + # since it would try to import ``torch.distributed.device_mesh`` or + # ``torch.distributed.init_device_mesh`` but cannot find them. 
+ + class _DeviceMeshStub: + pass + + def _init_device_mesh_stub(): + pass + + sys.modules["torch.distributed.device_mesh"].DeviceMesh = _DeviceMeshStub # type: ignore[attr-defined] + sys.modules[ + "torch.distributed.device_mesh" + ].init_device_mesh = _init_device_mesh_stub # type: ignore[attr-defined] + + +else: + from torch.distributed.distributed_c10d import ( + _find_pg_by_ranks_and_tag, + _get_default_group, + _get_group_tag, + get_backend, + get_process_group_ranks, + get_rank, + get_world_size, + init_process_group, + is_initialized, + new_group, + ProcessGroup, + ) + + logger = logging.getLogger(__name__) + + # only import numpy typing when type checking + if TYPE_CHECKING: + try: + from numpy.typing import ArrayLike + except ImportError: + logger.warning( + "DeviceMesh requires numpy >= 1.21 to be installed for type checking" + ) + + class _MeshEnv(threading.local): + def __init__(self) -> None: + self.mesh_stack: List[DeviceMesh] = [] + self.child_to_root_mapping: Dict[DeviceMesh, DeviceMesh] = {} + self.mesh_dim_group_options: Dict[ + int, Tuple[str, Optional[ProcessGroup.Options]] + ] = {} + self.root_to_flatten_mapping: Dict[DeviceMesh, Dict[str, DeviceMesh]] = {} + # Record flatten mesh name to its mesh dim index in root mesh. + self.flatten_name_to_root_dims: Dict[ + DeviceMesh, Dict[str, Tuple[int, ...]] + ] = {} + + def get_current_mesh(self) -> "DeviceMesh": + if len(self.mesh_stack) == 0: + raise RuntimeError("No device mesh is currently active!") + return self.mesh_stack[-1] + + def create_sub_mesh( + self, + device_mesh: "DeviceMesh", + submesh_dim_names: Tuple[str, ...], + submesh_dims: List[Tuple[int, ...]], + ) -> "DeviceMesh": + # Get the submesh dim size from the submesh_dims. + # For example, if we have a 3D mesh with mesh_shape (2, 2, 2) mesh_dim_names ("dp", "cp", "tp") and we want + # to slice out mesh["dp_cp"], then submesh_dims = [(0, 1), (2,)] and submesh_dim_size = [2 * 2, 2] = [4, 2]. + # If we want to slice out mesh["dp", "cp"], then submesh_dims = [(0,), (1,)] and submesh_dim_size = [2, 2]. + slice_dim_size = [ + reduce( + lambda x, y: device_mesh.mesh.size(x) * device_mesh.mesh.size(y), + mesh_dim, + ) + if len(mesh_dim) > 1 + else device_mesh.mesh.size(mesh_dim[0]) + for mesh_dim in submesh_dims + ] + + mesh_tensor = device_mesh.mesh + # slice_dim_idx could be differnt from submesh_dims, as we may need to flatten out some dims. + slice_dim_idx = [] + slice_dim_group_info = [] + # keep track of the number of dims that have been flattened so we can get the correct slice_dim_idx in the + # flattened mesh tensor. + num_dims_flatten = 0 + for mesh_dim_indices, mesh_dim_name in zip(submesh_dims, submesh_dim_names): + # Currently, this only allows slicing out a contiguous flattened dim. + # TODO: we need to handle reconstructing a non-contiguous flattened dim. + if len(mesh_dim_indices) > 1: + # We need to move the start_dim and end_dim to the left if some dims are already flattened. + mesh_tensor = mesh_tensor.flatten( + start_dim=mesh_dim_indices[0] - num_dims_flatten, + end_dim=mesh_dim_indices[-1] - num_dims_flatten, + ) + # If some dims are already flattened, we need to adjust the slice_dim_idx accordingly. + # For example, if the submesh_dims = [(0, 1), (2,), (3, 4)] with 0-1 flattened and 3-4 flattened, + # then the final slice_dim_idx should be [0, 1, 2]. 
+ slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten) + num_dims_flatten += len(mesh_dim_indices) - 1 + slice_dim_group_info.append( + self.root_to_flatten_mapping[device_mesh][ + mesh_dim_name + ]._dim_group_infos[0] + ) + else: + slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten) + slice_dim_group_info.append( + device_mesh._dim_group_infos[mesh_dim_indices[0]] + ) + + # mesh_tensor has already been flattened if needed. So mesh_tensor.ndim <= device_mesh.mesh.ndim now. + mesh_dims_remained_idx = list(range(mesh_tensor.ndim)) + for idx in slice_dim_idx: + mesh_dims_remained_idx.remove(idx) + + # pg_ranks_by_dim is the size of [number of local ranks of the outermost submesh dimension, *slice_dim_idx] + # This means on each local rank of the outermost slice mesh dim, we have a tensor of submesh size with + # the pg ranks of the submesh. From this, we can extract the submesh mesh tensor contains the current rank. + pg_ranks_by_dim = mesh_tensor.permute( + *mesh_dims_remained_idx, *slice_dim_idx + ).reshape(-1, *slice_dim_size) + + cur_rank = device_mesh.get_rank() + for mesh_nd in pg_ranks_by_dim: + submesh = DeviceMesh( + device_mesh.device_type, + mesh_nd, + mesh_dim_names=submesh_dim_names, + _init_backend=False, + ) + if cur_rank in mesh_nd: + res_submesh = submesh + + res_submesh._dim_group_infos = slice_dim_group_info # type: ignore[possibly-undefined] + self.child_to_root_mapping[res_submesh] = device_mesh + + return res_submesh + + def create_flatten_mesh( + self, device_mesh: "DeviceMesh", mesh_dim_name: Optional[str] = None + ) -> "DeviceMesh": + root_mesh = _mesh_resources.get_root_mesh(device_mesh) + + flatten_dims_in_root = [ + not_none(root_mesh.mesh_dim_names).index(flattened_mesh_dim_name) + for flattened_mesh_dim_name in not_none(device_mesh.mesh_dim_names) + ] + + if not mesh_dim_name: + mesh_dim_name = "_".join( + [ + not_none(root_mesh.mesh_dim_names)[dim] + for dim in flatten_dims_in_root + ] + ) + + # Check whether the mesh_dim_name for flattened mesh is valid. + self.flatten_name_to_root_dims.setdefault(root_mesh, {}) + invalid_dim_names = chain( + *list(not_none(root_mesh.mesh_dim_names)), + *self.flatten_name_to_root_dims[root_mesh].keys(), + ) + if mesh_dim_name in invalid_dim_names: + raise RuntimeError( + f"{mesh_dim_name} already exists for submesh of the {root_mesh}. ", + f"The mesh_dim_names of submesh and flattened mesh are {invalid_dim_names}. " + f"Please specify another valid mesh_dim_name.", + ) + + # Quick return if the flatten mesh has been created before. + # TODO: If we decide to restrict flatten initialization once, we should remove + # this check and throw an error if the flatten mesh is already created before. + if ( + root_mesh in self.root_to_flatten_mapping + and mesh_dim_name in self.root_to_flatten_mapping[root_mesh] + ): + return self.root_to_flatten_mapping[root_mesh][mesh_dim_name] + + flattened_mesh_dim_size = math.prod(device_mesh.mesh.size()) + + remained_dims_in_root = list(range(root_mesh.mesh.ndim)) + for flatten_dim_in_root in flatten_dims_in_root: + remained_dims_in_root.remove(flatten_dim_in_root) + + pg_ranks_by_dim = root_mesh.mesh.permute( + *remained_dims_in_root, *flatten_dims_in_root + ).reshape(-1, flattened_mesh_dim_size) + + cur_rank = root_mesh.get_rank() + for mesh_nd in pg_ranks_by_dim: + # need to init backend here since the flattened pg doesn't exist in root mesh. 
+ flattened_mesh = DeviceMesh( + root_mesh.device_type, + mesh_nd, + mesh_dim_names=(mesh_dim_name,), + ) + if cur_rank in mesh_nd: + res_flattened_mesh = flattened_mesh + self.child_to_root_mapping[res_flattened_mesh] = root_mesh # type: ignore[possibly-undefined] + self.root_to_flatten_mapping.setdefault(root_mesh, {})[mesh_dim_name] = res_flattened_mesh # type: ignore[possibly-undefined] + self.flatten_name_to_root_dims[root_mesh][mesh_dim_name] = tuple(flatten_dims_in_root) # type: ignore[possibly-undefined] + + return res_flattened_mesh + + def get_root_mesh(self, device_mesh: "DeviceMesh") -> "DeviceMesh": + # If a mesh could not be found in the child_to_root_mapping, it is a root mesh itself. + # A root mesh is not created through slicing. + # We considers the root mesh of a root mesh is itself. + root_mesh = self.child_to_root_mapping.get(device_mesh, None) + return device_mesh if not root_mesh else root_mesh + + def get_root_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]: + """ + Returns the index of the mesh dim in the root mesh. + The device_mesh passed in needs to be sliced out from the root mesh + or submesh of the root mesh. + """ + root_mesh = self.get_root_mesh(device_mesh) + child_mesh_dim_names = device_mesh.mesh_dim_names + if root_mesh and child_mesh_dim_names: + assert ( + len(child_mesh_dim_names) == 1 + ), "The submesh can only be a 1D mesh." + child_mesh_dim_name = child_mesh_dim_names[0] + return self.get_mesh_dim_by_name(root_mesh, child_mesh_dim_name) + return None + + @staticmethod + def num_devices_per_host(device_type: str) -> int: + return _get_device_handle(device_type).device_count() + + @staticmethod + def num_hosts(device_type: str) -> int: + # ProcessGroup can't tell us this info so we have to infer it, assume + # homogeneous hardware for now + return get_world_size() // _MeshEnv.num_devices_per_host(device_type) + + def get_mesh_dim_by_name( + self, device_mesh: "DeviceMesh", mesh_dim_name: str + ) -> int: + if ( + device_mesh.mesh_dim_names is None + or len(device_mesh.mesh_dim_names) == 0 + ): + raise KeyError( + "No `mesh_dim_names` found.", + ) + if mesh_dim_name not in device_mesh.mesh_dim_names: + raise KeyError( + f"Mesh dimension '{mesh_dim_name}' does not exist.", + f"Available mesh dimensions are: mesh_dim_names={device_mesh.mesh_dim_names}", + ) + return not_none(device_mesh.mesh_dim_names.index(mesh_dim_name)) + + def _set_mesh_dim_group_options( + self, + dim: int, + backend: str, + pg_options: Optional[ProcessGroup.Options] = None, + ) -> None: + self.mesh_dim_group_options[dim] = (backend, pg_options) + + def _get_slice_mesh_dims( + self, device_mesh, mesh_dim_names + ) -> List[Tuple[int, ...]]: + """ + Validate whether the mesh_dim_names is valid for slicing the given device_mesh. + If valid, return dim indexes of the slice mesh in the device mesh. + """ + if device_mesh != self.get_root_mesh(device_mesh): + raise RuntimeError("Cannot create a submesh from a submesh.") + + # The slice mesh_dim_names should consist either the device_mesh's mesh_dim_names + # or its flattened mesh's mesh_dim_names. + self.flatten_name_to_root_dims.setdefault(device_mesh, {}) + flatten_name_to_root_dims = self.flatten_name_to_root_dims[device_mesh] + valid_mesh_dim_names = [ + *device_mesh.mesh_dim_names, + *flatten_name_to_root_dims, + ] + + if not all( + mesh_dim_name in valid_mesh_dim_names + for mesh_dim_name in mesh_dim_names + ): + raise KeyError( + f"Invalid mesh_dim_names {mesh_dim_names} specified. 
" + f"Valid mesh_dim_names are {valid_mesh_dim_names}." + ) + + # Validate the order of the slice mesh dim indices. + # This needs to be in ascending order. + curr_idx = -1 + slice_mesh_dims = [] + for mesh_dim_name in mesh_dim_names: + if mesh_dim_name in flatten_name_to_root_dims: + mesh_indices = flatten_name_to_root_dims[mesh_dim_name] + # TODO: this doesn't allow non-contiguous slicing with flatten dim yet. next_idx + # should be mesh_indices[0] once we support non-contiguous slicing with flatten dim. + next_idx = mesh_indices[-1] + slice_mesh_dims.append(mesh_indices) + else: + next_idx = device_mesh.mesh_dim_names.index(mesh_dim_name) + slice_mesh_dims.append((next_idx,)) + if next_idx <= curr_idx: + raise KeyError( + f"Invalid mesh_dim_names {mesh_dim_names} specified. ", + f"Found mesh dim indices to slice: {slice_mesh_dims}. ", + "Mesh dim indices should be in ascending order.", + ) + curr_idx = next_idx + + return slice_mesh_dims + + def _get_all_submeshes( + self, device_mesh: "DeviceMesh", mesh_dim_name: str + ) -> List["DeviceMesh"]: + """ + Return all the submeshes of a given mesh dimension of the device mesh. + """ + mesh_dim = self.get_mesh_dim_by_name(device_mesh, mesh_dim_name) + pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape( + -1, device_mesh.mesh.size(mesh_dim) + ) + + cur_rank = device_mesh.get_rank() + res_submeshes = [] + for mesh_1d in pg_ranks_by_dim: + submesh = DeviceMesh( + device_mesh.device_type, + mesh_1d, + mesh_dim_names=(mesh_dim_name,), + _init_backend=False, + ) + submesh._dim_group_infos = ( + [device_mesh._dim_group_infos[mesh_dim]] + if cur_rank in mesh_1d + else [] + ) + res_submeshes.append(submesh) + + return res_submeshes + + _mesh_resources: _MeshEnv = _MeshEnv() + + def _get_device_handle(device_type: str = "cuda"): + """ + Get the module corresponding to the device_type which is cuda or cuda-like device. + For example, when the device_type is cuda, the module `torch.cuda` is returned. + Return None when there is no corresponding module for device_type, otherwise + return the corresponding module. + """ + return getattr(torch, device_type, None) + + class DeviceMesh: + """ + DeviceMesh represents a mesh of devices, where layout of devices could be + represented as a n-d dimension array, and each value of the n-d dimensional + array is the global id of the default process group ranks. + + DeviceMesh could be used to describe the layout of devices across the cluster, + and serves as a proxy for communication among the device lists within the cluster. + + DeviceMesh can be used as a context manager. + + .. note:: + DeviceMesh follows SPMD programming model, which means the same PyTorch Python program + is running on all processes/ranks in the cluster. Therefore, users need to make sure the + `mesh` array (which describes the layout of devices) should be identical across all ranks. + Inconsistent `mesh` will lead to silent hang. + + Args: + device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like". + mesh (ndarray): A multi-dimensional array or an integer tensor describing the layout + of devices, where the IDs are global IDs of the default process group. + + Returns: + DeviceMesh: A :class:`DeviceMesh` object representing the device layout. + + The following program runs on each process/rank in an SPMD manner. In this example, we have 2 + hosts with 4 GPUs each. + A reduction over the first dimension of mesh will reduce across + columns (0, 4), .. 
and (3, 7), a reduction over the second dimension + of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7). + + Example:: + >>> # xdoctest: +SKIP("no rank") + >>> from torch.distributed.device_mesh import DeviceMesh + >>> + >>> # Initialize device mesh as (2, 4) to represent the topology + >>> # of cross-host(dim 0), and within-host (dim 1). + >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]]) + """ + + device_type: str + mesh: torch.Tensor + mesh_dim_names: Optional[Tuple[str, ...]] + + def __init__( + self, + device_type: str, + mesh: Union[torch.Tensor, "ArrayLike"], + *, + mesh_dim_names: Optional[Tuple[str, ...]] = None, + _init_backend: bool = True, + ) -> None: + self.device_type = device_type + if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu": + raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}") + self.mesh = ( + mesh.detach().to(dtype=torch.int) + if isinstance(mesh, torch.Tensor) + else torch.tensor(mesh, device="cpu", dtype=torch.int) + ) + self.mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None + + # private field to pre-generate DeviceMesh's hash + self._flatten_mesh_list = tuple(self.mesh.flatten().tolist()) + self._thread_id = None + + # Skip process group initialization if xla device or init backend is False + # TODO(yeounoh) implement DeviceMesh backend and register XLA backend. + if device_type != "xla": + # always try to create default (world) pg, even if it is not initialized + # already. The world pg is used for device mesh identity (rank) on each + # process (we need to know if the current global rank is in the mesh or not). + if _init_backend: + self._get_or_create_default_group() + self._init_process_groups() + + if is_initialized() and get_backend() == "threaded": + self._thread_id = threading.get_ident() + + # calculate the coordinates of the current global rank on the mesh + rank_coords = (self.mesh == get_rank()).nonzero() + assert rank_coords.size(0) in (0, 1) + self._coordinate_on_dim: Optional[List[int]] = ( + rank_coords[0].tolist() if rank_coords.size(0) > 0 else None + ) + + def _get_or_create_default_group(self): + default_initialized = is_initialized() + if not default_initialized: + init_process_group() + + world_size = get_world_size() + if self.mesh.numel() > world_size: + raise RuntimeError( + f"Mesh should not be bigger than default world size {world_size}, but found {self.mesh.numel()} ranks!" + ) + + device_handle = _get_device_handle(self.device_type) + # TODO: if user want to pass pg_options, offer a way to do it + if not default_initialized and device_handle: + # automatically set the current cuda/cuda-like device base on num of gpu devices available in each host + # NOTE: This device selection would only work for homogeneous hardware. + num_devices_per_host = device_handle.device_count() + if ( + world_size > num_devices_per_host + and world_size % num_devices_per_host != 0 + ): + raise RuntimeError( + f"DeviceMesh only support homogeneous hardware, but found " + f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!" + ) + device_handle.set_device(get_rank() % num_devices_per_host) + + return _get_default_group() + + def _init_process_groups(self): + # tag/ranks/group_name associated with each mesh dimension, each + # mesh dimension should have one sub-group per rank + # + # TODO(yifu): remove tag and ranks once we fully migrate to native + # functional collectives. 
See details in: + # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208 + dim_group_infos: List[Tuple[str, List[int], str]] = [] + + if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size(): + # Append the default pg to the first dim groups only if the default pg is compatible with `self.device_type`. + # Otherwise, create new pg. + default_group = _get_default_group() + ranks = list(range(get_world_size())) + dim_group = ( + new_group(backend="cpu:gloo,cuda:nccl", ranks=ranks) + if torch.cuda.is_available() + and get_backend(default_group) == "gloo" + else default_group + ) + dim_group_infos.append( + ( + _get_group_tag(dim_group), + ranks, + dim_group.group_name, + ) + ) + else: + # create sub pgs base on the mesh argument specified + for dim in range(self.mesh.ndim): + # swap the current dim to the last dim + # then reshape to flatten out other dims + pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape( + -1, self.mesh.size(dim) + ) + # multi-dim mesh, create subgroups by looping over the pg_ranks + # for each dim and append the groups + for dim_mesh in pg_ranks_by_dim: + subgroup_ranks = dim_mesh.tolist() + + # Respect dim group options specified via _MeshEnv.set_dim_group_options(). + # Inherit from the parent group if no options are specified for the group. + if dim in _mesh_resources.mesh_dim_group_options: + ( + backend, + pg_options, + ) = _mesh_resources.mesh_dim_group_options[dim] + else: + backend, pg_options = None, None + + # We temporarily revert the re-use subgroup, since it breaks two internal tests. + # Temporarily reverting to resolve test timeout while root-causing. + # TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists. + dim_group = new_group( + ranks=subgroup_ranks, + backend=backend, + pg_options=pg_options, + ) + + # only add to dim_groups if the current rank in the subgroup + if self.get_rank() in subgroup_ranks: + if len(dim_group_infos) > dim: + raise RuntimeError( + f"Each device mesh dimension should get only one process group, but got {self.get_rank()} " + f"in {subgroup_ranks}!" + ) + dim_group_infos.append( + ( + _get_group_tag(not_none(dim_group)), + subgroup_ranks, + dim_group.group_name, + ) + ) + self._dim_group_infos = dim_group_infos + + def __enter__(self) -> "DeviceMesh": + # set this mesh as the current mesh in mesh env + _mesh_resources.mesh_stack.append(self) + return self + + # pyre-fixme[2]: Parameter must be annotated. 
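Before `__exit__` completes the context-manager pair started by `__enter__` above, here is a short usage sketch tying together the pieces defined so far: construction (which triggers `_get_or_create_default_group` and `_init_process_groups`), per-dimension groups, and the mesh stack. The 2-host x 4-GPU layout and NCCL backend are assumptions matching the class docstring, not requirements imposed by this code.

```python
# Editor's illustration: minimal DeviceMesh usage on an 8-rank job
# (e.g. launched with torchrun across 2 hosts with 4 GPUs each).
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh

dist.init_process_group(backend="nccl")

# Same 2 x 4 layout as in the docstring above: dim 0 spans hosts ("dp"),
# dim 1 spans GPUs within a host ("tp").
mesh_2d = DeviceMesh(
    "cuda",
    [[0, 1, 2, 3], [4, 5, 6, 7]],
    mesh_dim_names=("dp", "tp"),
)

# Each mesh dimension is backed by its own subgroup created in
# _init_process_groups(); ranks 0-3 and ranks 4-7 each share a "tp" group.
tp_group = mesh_2d.get_group("tp")

# __enter__/__exit__ push and pop the mesh on the thread-local mesh stack,
# so the mesh can also be used as a context manager.
with mesh_2d:
    coord = mesh_2d.get_coordinate()  # e.g. [0, 2] on global rank 2
```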
+ def __exit__(self, exc_type, exc_value, exc_traceback) -> None: + # pop this mesh from mesh env + _mesh_resources.mesh_stack.pop() + + def __repr__(self) -> str: + device_mesh_repr = ( + f"DeviceMesh('{self.device_type}', {self.mesh.tolist()})" + if not self.mesh_dim_names + else f"DeviceMesh('{self.device_type}', {self.mesh.tolist()}, mesh_dim_names={self.mesh_dim_names})" + ) + return device_mesh_repr + + def __hash__(self): + # lazily compute hash + self._hash = getattr(self, "_hash", None) + if not self._hash: + self._hash = hash( + ( + self._flatten_mesh_list, + self.mesh.shape, + self.device_type, + self.mesh_dim_names, + self._thread_id, + ) + ) + return self._hash + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DeviceMesh): + return False + if id(self) == id(other): + return True + else: + return ( + self._flatten_mesh_list == other._flatten_mesh_list + and self.mesh.shape == other.mesh.shape + and self.device_type == other.device_type + and self.mesh_dim_names == other.mesh_dim_names + and self._thread_id == other._thread_id + ) + + def __getitem__( + self, mesh_dim_names: Union[str, Tuple[str, ...]] + ) -> "DeviceMesh": + """ + Slice the current DeviceMesh based on the mesh_dim_names given to create a submesh. + The submesh created consists of the dimensions and the communicators indicated by + ``mesh_dim_names`` + + Args: + mesh_dim_names (Union[str, Tuple[str]]): the name or the tuple of names of the + mesh dimension of the DeviceMesh to create the submesh for. + Returns: + A :class:`DeviceMesh` object + + The following program runs on each process/rank in an SPMD manner in a world size of 8. + In the first example: + Calling mesh_2d["tp"] on rank 0, 1, 2, 3 returns a 1D submesh of DeviceMesh:([0, 1, 2, 3]). + Calling mesh_2d["tp"] on rank 4, 5, 6, 7 returns a 1D submesh of DeviceMesh:([4, 5, 6, 7]). + Calling mesh_2d["dp"] on rank 0, 4 returns a 1D submesh of DeviceMesh:([0, 4]). + Calling mesh_2d["dp"] on rank 1, 5 returns a 1D submesh of DeviceMesh:([1, 5]). + Calling mesh_2d["dp"] on rank 2, 6 returns a 1D submesh of DeviceMesh:([2, 6]). + Calling mesh_2d["dp"] on rank 3, 7 returns a 1D submesh of DeviceMesh:([3, 7]). + + In the second example: + Calling mesh_3d["dp", "cp"] on rank 0, 1, 4, 5 returns a 2D submesh of DeviceMesh:([[0, 1], [4, 5]]). + Calling mesh_3d["dp", "cp"] on rank 2, 3, 6, 7 returns a 2D submesh of DeviceMesh:([[2, 3], [6, 7]]). + Calling mesh_3d["cp", "dp"] on rank 0, 1, 4, 5 returns a 2D submesh of DeviceMesh:([[0, 4], [1, 5]]). + Calling mesh_3d["cp", "dp"] on rank 2, 3, 6, 7 returns a 2D submesh of DeviceMesh:([[2, 6], [3, 7]]). + + Example:: + >>> # xdoctest: +SKIP("no rank") + >>> from torch.distributed.device_mesh import DeviceMesh + >>> + >>> # Initialize a 2D device mesh as (2, 4) to represent the topology + >>> # of cross-host(dim 0), and within-host (dim 1). + >>> mesh_2d = init_device_mesh(device_type="cuda", (2,4), mesh_dim_names=("dp", "tp")) + >>> tp_mesh = mesh_2d["tp"] + >>> dp_mesh = mesh_2d["dp"] + >>> + >>> # Initialize a 3D mesh. + >>> mesh_3d = init_device_mesh(device_type="cuda", (2,2,2), mesh_dim_names=("dp", "pp", "cp")) + >>> # The order of the mesh_dim_names provided deteremines the order of dimensions in the submesh. 
+ >>> dp_cp_mesh = mesh_3d["dp", "cp"] + >>> cp_dp_mesh = mesh_3d["cp", "dp"] + """ + if not self.mesh_dim_names: + raise RuntimeError("Cannot slice a DeviceMesh without mesh_dim_names!") + + mesh_dim_names = ( + (mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names + ) + + if mesh_dim_names == self.mesh_dim_names: + return self + else: + slice_mesh_dims = _mesh_resources._get_slice_mesh_dims( + self, mesh_dim_names + ) + submesh = _mesh_resources.create_sub_mesh( + self, mesh_dim_names, slice_mesh_dims + ) + return submesh + + def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> ProcessGroup: + """ + Returns the single ProcessGroup specified by mesh_dim, or, if mesh_dim is not specified and the + DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh. + + Args: + mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index + of the mesh dimension. Default is None. + + Returns: + A :class:`ProcessGroup` object. + """ + if not hasattr(self, "_dim_group_infos"): + raise RuntimeError("DeviceMesh process groups not initialized!") + + if self.mesh.ndim > 1 and mesh_dim is None: + raise RuntimeError( + f"Found the DeviceMesh have {self.mesh.ndim} dimensions", + "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", + "If you want to get the list of all the ProcessGroups in the DeviceMesh," + "please use `get_all_groups()` instead.", + ) + + # Quick return if the current device_mesh is a 1D mesh. + if self.mesh.ndim == 1 and mesh_dim is None: + return not_none( + _find_pg_by_ranks_and_tag(*self._dim_group_infos[0][:2]) # type: ignore[index] + ) + + root_mesh = _mesh_resources.get_root_mesh(self) + root_to_flatten_mapping = _mesh_resources.root_to_flatten_mapping.get( + root_mesh, None + ) + if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping.keys(): + dim_group_infos = root_to_flatten_mapping[mesh_dim]._dim_group_infos[0][:2] # type: ignore[index] + return not_none(_find_pg_by_ranks_and_tag(*dim_group_infos)) + else: + mesh_dim = ( + _mesh_resources.get_mesh_dim_by_name(self, mesh_dim) + if isinstance(mesh_dim, str) + else mesh_dim + ) + return not_none( + _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim][:2]) # type: ignore[index] + ) + + def get_all_groups(self) -> List[ProcessGroup]: + """ + Returns a list of ProcessGroups for all mesh dimensions. + + Returns: + A list of :class:`ProcessGroup` object. + """ + return [self.get_group(i) for i in range(self.mesh.ndim)] + + @staticmethod + def from_group( + group: Union[ProcessGroup, List[ProcessGroup]], + device_type: str, + mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None, + *, + mesh_dim_names: Optional[Tuple[str, ...]] = None, + ) -> "DeviceMesh": + """ + Constructs a :class:`DeviceMesh` with ``device_type`` from an + existing :class:`ProcessGroup`. + + The constructed device mesh has number of dimensions equal to the + number of groups passed. If more than one group is passed, then the + ``mesh`` argument is required. 
+ """ + if isinstance(group, ProcessGroup): + group_ranks = get_process_group_ranks(group) + if ( + isinstance(mesh, torch.Tensor) and mesh.tolist() != group_ranks + ) or (mesh is not None and mesh != group_ranks): + raise ValueError( + f"Invalid mesh {str(mesh)} for ProcessGroup with ranks {group_ranks}" + ) + mesh = torch.tensor(group_ranks, device="cpu", dtype=torch.int) + device_mesh = DeviceMesh( + device_type, + mesh, + mesh_dim_names=mesh_dim_names, + _init_backend=False, + ) + device_mesh._dim_group_infos = [ + (_get_group_tag(group), group_ranks, group.group_name) + ] + return device_mesh + groups = list(group) + if len(groups) == 0: + raise ValueError("Expects at least one ProcessGroup to be passed") + if mesh is None: + raise ValueError("Must pass mesh if passing multiple ProcessGroups") + mesh = ( + mesh.detach().to(dtype=torch.int, device="cpu") + if isinstance(mesh, torch.Tensor) + else torch.tensor(mesh, device="cpu", dtype=torch.int) + ) + if mesh.ndim != len(groups): + raise ValueError( + "Expects mesh with ndim equal to number of ProcessGroups but got " + f"mesh {mesh.tolist()} and {len(groups)} ProcessGroups" + ) + device_mesh = DeviceMesh( + device_type, mesh, mesh_dim_names=mesh_dim_names, _init_backend=False + ) + device_mesh._dim_group_infos = [ + ( + _get_group_tag(group), + get_process_group_ranks(group), + group.group_name, + ) + for group in groups + ] + return device_mesh + + def size(self, mesh_dim: Optional[int] = None) -> int: + return self.mesh.numel() if mesh_dim is None else self.mesh.size(mesh_dim) + + @property + def ndim(self) -> int: + return self.mesh.ndim + + @property + def shape(self) -> Tuple[int, ...]: + return tuple(self.mesh.shape) + + def get_rank(self) -> int: + """ + Returns the current global rank. + """ + return get_rank() + + def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int: + """ + Returns the local rank of the given mesh_dim of the DeviceMesh. + + Args: + mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index + of the mesh dimension. Default is None. + + Returns: + An integer denotes the local rank. + + The following program runs on each process/rank in an SPMD manner. In this example, we have 2 + hosts with 4 GPUs each. + Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0. + Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1. + Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0. + Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1. + Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2. + Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3. + + Example:: + >>> # xdoctest: +SKIP("no rank") + >>> from torch.distributed.device_mesh import DeviceMesh + >>> + >>> # Initialize device mesh as (2, 4) to represent the topology + >>> # of cross-host(dim 0), and within-host (dim 1). + >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]]) + """ + if self.ndim > 1 and mesh_dim is None: + raise RuntimeError( + f"Found the DeviceMesh have {self.mesh.ndim} dimensions", + "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", + ) + elif mesh_dim is None: + mesh_dim = 0 + + mesh_dim_group = not_none(self.get_group(mesh_dim)) + assert isinstance( + mesh_dim_group, ProcessGroup + ), "We expect ProcessGroup before calling `get_rank`!" 
+ return not_none(get_rank(mesh_dim_group)) + + def get_coordinate(self) -> Optional[List[int]]: + """ + Return the relative indices of this rank relative to all + dimensions of the mesh. If this rank is not part of the mesh, return None. + """ + return self._coordinate_on_dim if self._coordinate_on_dim else None + + def _flatten(self, mesh_dim_name: Optional[str] = None) -> "DeviceMesh": + """ + Returns a 1D DeviceMesh by flattening the current DeviceMesh. + + If no mesh_dim_name is provided, the default is a string concatentaing the mesh_dim_names of the + given submesh with each mesh_dim_name separated by "_". For example, if we have a 3D mesh + DeviceMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], mesh_dim_names=("dp", "cp", "tp")), calling + mesh_3d["dp", "cp"]._flatten() will create a 1D submesh DeviceMesh([0, 1, 2, 3], mesh_dim_names=("dp_cp",)) + on rank 0, 1, 2, 3 and a 1D submesh DeviceMesh([4, 5, 6, 7], mesh_dim_names=("dp_cp",)) on rank 4, 5, 6, 7. + + After the flattened dimension is created, to access the flattened dimesnion in mesh_3d, one can use the + existing slicing method to obtain the flattened mesh through calling mesh_3d["dp_cp"]. + """ + if not self.mesh_dim_names: + raise RuntimeError( + "Cannot flatten a DeviceMesh without mesh_dim_names!" + ) + + return _mesh_resources.create_flatten_mesh(self, mesh_dim_name) + + def init_device_mesh( + device_type: str, + mesh_shape: Tuple[int, ...], + *, + mesh_dim_names: Optional[Tuple[str, ...]] = None, + ) -> DeviceMesh: + """ + Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters. + + This creates a DeviceMesh with an n-dimensional array layout, where `n` is the length of `mesh_shape`. + If `mesh_dim_names` is provided, each dimension is labeled as `mesh_dim_names[i]`. + + .. note:: + `init_device_mesh` follows SPMD programming model, meaning the same PyTorch Python program + runs on all processes/ranks in the cluster. Ensure `mesh_shape` (the dimensions of the nD array + describing device layout) is identical across all ranks. Inconsistent `mesh_shape` may lead to hanging. + + .. note:: + If no process group is found, init_device_mesh will initialize distributed process group/groups + required for distributed communications behind the scene. + + Args: + device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like". + Passing in a device type with a GPU index, such as "cuda:0", is not allowed. + mesh_shape (Tuple[int]): A tuple defining the dimensions of the multi-dimensional array + describing the layout of devices. + mesh_dim_names (Tuple[str], optional): A tuple of mesh dimension names to assign to each dimension + of the multi-dimensional array describing the layout of devices. Its length must match the length + of `mesh_shape`. Each string in `mesh_dim_names` must be unique. + + Returns: + DeviceMesh: A :class:`DeviceMesh` object representing the device layout. 
+ + Example:: + >>> # xdoctest: +SKIP("no rank") + >>> from torch.distributed.device_mesh import init_device_mesh + >>> + >>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,)) + >>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp")) + + """ + if mesh_dim_names is not None: + if len(set(mesh_dim_names)) != len(mesh_dim_names): + raise RuntimeError( + "Each mesh_dim_name must be unique.", + f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}", + ) + + if len(mesh_shape) != len(mesh_dim_names): + raise RuntimeError( + "mesh_shape and mesh_dim_names should have same length!", + f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.", + ) + + # assume valid device types are all letters + if device_type and not device_type.isalpha(): + raise RuntimeError( + f"Device type with GPU index is not supported but got {device_type}. ", + "If you maintained a 'torch.device' object, it's recommended to pass in 'device.type'.", + ) + + # Always initialize the mesh's tensor on CPU, regardless of what the + # external device type has been set to be (e.g. meta) + with torch.device("cpu"): + mesh = torch.arange(math.prod(mesh_shape), dtype=torch.int).view(mesh_shape) + device_mesh = DeviceMesh( + device_type=device_type, + mesh=mesh, + mesh_dim_names=mesh_dim_names, + ) + + return device_mesh diff --git a/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py b/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py new file mode 100644 index 0000000000000000000000000000000000000000..45e096985143a372932f523a46cf869e7f862310 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py @@ -0,0 +1,4983 @@ +# mypy: allow-untyped-defs +"""Distributed Collective Communication (c10d).""" + +import collections.abc +import contextlib +import hashlib +import io +import itertools +import logging +import os +import pickle +import sys +import time +import warnings +from collections import namedtuple +from datetime import timedelta +from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from typing_extensions import deprecated + +import torch +from torch._C import _DistStoreError as DistStoreError +from torch._C._distributed_c10d import ( + _DistributedBackendOptions, + _register_process_group, + _resolve_process_group, + _unregister_all_process_groups, + _unregister_process_group, + AllgatherOptions, + AllreduceCoalescedOptions, + AllreduceOptions, + AllToAllOptions, + BarrierOptions, + BroadcastOptions, + DebugLevel, + GatherOptions, + get_debug_level, + PrefixStore, + ProcessGroup, + ReduceOp, + ReduceOptions, + ReduceScatterOptions, + ScatterOptions, + Store, + Work, +) +from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs +from torch.utils._typing_utils import not_none + +from .c10d_logger import _exception_logger, _time_logger +from .constants import default_pg_nccl_timeout, default_pg_timeout +from .rendezvous import register_rendezvous_handler, rendezvous # noqa: F401 + + +__all__ = [ + "Backend", + "BackendConfig", + "GroupMember", + "P2POp", + "all_gather", + "all_gather_coalesced", + "all_gather_object", + "all_reduce", + "all_reduce_coalesced", + "all_to_all", + "all_to_all_single", + "barrier", + "batch_isend_irecv", + "broadcast", + "send_object_list", + "recv_object_list", + "broadcast_object_list", + "destroy_process_group", + "gather", + "gather_object", + "get_backend_config", + "get_backend", + "get_rank", + 
"get_world_size", + "get_pg_count", + "group", + "init_process_group", + "irecv", + "is_gloo_available", + "is_initialized", + "is_mpi_available", + "is_backend_available", + "is_nccl_available", + "is_torchelastic_launched", + "is_ucc_available", + "isend", + "monitored_barrier", + "new_group", + "new_subgroups", + "new_subgroups_by_enumeration", + "recv", + "reduce", + "reduce_scatter", + "scatter", + "scatter_object_list", + "send", + "supports_complex", + "AllreduceCoalescedOptions", + "AllreduceOptions", + "AllToAllOptions", + "BarrierOptions", + "BroadcastOptions", + "GatherOptions", + "PrefixStore", + "ProcessGroup", + "ReduceOp", + "ReduceOptions", + "ReduceScatterOptions", + "ScatterOptions", + "Store", + "DebugLevel", + "get_debug_level", + "Work", + "default_pg_timeout", + "get_group_rank", + "get_global_rank", + "get_process_group_ranks", + "reduce_op", + "all_gather_into_tensor", + "reduce_scatter_tensor", + "get_node_local_rank", + "split_group", +] + +_MPI_AVAILABLE = True +_NCCL_AVAILABLE = True +_GLOO_AVAILABLE = True +_UCC_AVAILABLE = True + +_pickler = pickle.Pickler +_unpickler = pickle.Unpickler + + +# Change __module__ of all imported types from torch._C._distributed_c10d that are public +def _export_c_types() -> None: + _public_types_to_change_module = [ + AllreduceCoalescedOptions, + AllreduceOptions, + AllToAllOptions, + BarrierOptions, + BroadcastOptions, + GatherOptions, + PrefixStore, + ProcessGroup, + ReduceOp, + ReduceOptions, + ReduceScatterOptions, + ScatterOptions, + Store, + DebugLevel, + get_debug_level, + Work, + ] + for type in _public_types_to_change_module: + type.__module__ = "torch.distributed.distributed_c10d" + + +_export_c_types() + +try: + from torch._C._distributed_c10d import ProcessGroupMPI + + ProcessGroupMPI.__module__ = "torch.distributed.distributed_c10d" + __all__ += ["ProcessGroupMPI"] +except ImportError: + _MPI_AVAILABLE = False + +try: + from torch._C._distributed_c10d import ProcessGroupNCCL + + ProcessGroupNCCL.__module__ = "torch.distributed.distributed_c10d" + __all__ += ["ProcessGroupNCCL"] +except ImportError: + _NCCL_AVAILABLE = False + +try: + from torch._C._distributed_c10d import _ProcessGroupWrapper, ProcessGroupGloo + + ProcessGroupGloo.__module__ = "torch.distributed.distributed_c10d" + __all__ += ["ProcessGroupGloo"] +except ImportError: + _GLOO_AVAILABLE = False + +try: + from torch._C._distributed_c10d import ProcessGroupUCC + + ProcessGroupUCC.__module__ = "torch.distributed.distributed_c10d" + __all__ += ["ProcessGroupUCC"] +except ImportError: + _UCC_AVAILABLE = False + +logger = logging.getLogger(__name__) + +PG_WRAPPER_STORE_PREFIX = "pg_wrapper" + + +# Some reduce ops are not supported by complex numbers and will result in an error. +# We currently provide complex support to the distributed API by viewing +# complex tensors as real (torch.view_as_real), meaning that calling +# these unsupported ops will return garbage values rather than error out. +# (e.g. max(2+3i, 3+2i) = 3+3i) +# We'd like calls to unsupported ops to error out accordingly, +# rather than returning garbage values. +def supports_complex(reduceOp: ReduceOp) -> bool: + """Return true if reduce ops is supported. False otherwise.""" + denyList = [ + ReduceOp.MAX, + ReduceOp.MIN, + ReduceOp.PRODUCT, + ReduceOp.BAND, + ReduceOp.BOR, + ReduceOp.BXOR, + ] + return reduceOp not in denyList + + +class Backend(str): + """ + An enum-like class for backends. + + Available backends: GLOO, NCCL, UCC, MPI, and other registered backends. 
+ + The values of this class are lowercase strings, e.g., ``"gloo"``. They can + be accessed as attributes, e.g., ``Backend.NCCL``. + + This class can be directly called to parse the string, e.g., + ``Backend(backend_str)`` will check if ``backend_str`` is valid, and + return the parsed lowercase string if so. It also accepts uppercase strings, + e.g., ``Backend("GLOO")`` returns ``"gloo"``. + + .. note:: The entry ``Backend.UNDEFINED`` is present but only used as + initial value of some fields. Users should neither use it directly + nor assume its existence. + """ + + UNDEFINED = "undefined" + GLOO = "gloo" + NCCL = "nccl" + UCC = "ucc" + MPI = "mpi" + + _BackendPlugin = namedtuple("_BackendPlugin", ["creator_fn", "extended_api"]) + + _plugins: Dict[str, _BackendPlugin] = {} + + backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI] + + default_device_backend_map: Dict[str, str] = { + "cpu": GLOO, + "cuda": NCCL, + } + + backend_capability: Dict[str, List[str]] = { + GLOO: ["cpu", "cuda"], + NCCL: ["cuda"], + UCC: ["cpu", "cuda"], + MPI: ["cpu", "cuda"], + } + + backend_type_map: Dict[str, ProcessGroup.BackendType] = { + UNDEFINED: ProcessGroup.BackendType.UNDEFINED, + GLOO: ProcessGroup.BackendType.GLOO, + NCCL: ProcessGroup.BackendType.NCCL, + UCC: ProcessGroup.BackendType.UCC, + } + + def __new__(cls, name: str): + """Create and return a new instance of the class.""" + if not isinstance(name, str): + raise ValueError("Backend constructor parameter must be string-ish") + value = getattr(Backend, name.upper(), Backend.UNDEFINED) + + if value == Backend.UNDEFINED: + value = name.lower() + return value + + @classmethod + def register_backend( + cls, + name, + func, + extended_api=False, + devices: Optional[Union[str, List[str]]] = None, + ) -> None: + """ + Register a new backend with the given name and instantiating function. + + This class method is used by 3rd party ``ProcessGroup`` extension to + register new backends. + + Args: + name (str): Backend name of the ``ProcessGroup`` extension. It + should match the one in ``init_process_group()``. + func (function): Function handler that instantiates the backend. + The function should be implemented in the backend + extension and takes four arguments, including + ``store``, ``rank``, ``world_size``, and ``timeout``. + extended_api (bool, optional): Whether the backend supports extended argument structure. + Default: ``False``. If set to ``True``, the backend + will get an instance of ``c10d::DistributedBackendOptions``, and + a process group options object as defined by the backend implementation. + device (str or list of str, optional): device type this backend + supports, e.g. "cpu", "cuda", etc. If `None`, + assuming both "cpu" and "cuda" + + .. note:: This support of 3rd party backend is experimental and subject to change. + + """ + # Allow UCC plugin if Pytorch is not built with native support. + # TODO: remove this exception once UCC plugin is fully deprecated. 
+ if name != Backend.UCC or (name == Backend.UCC and is_ucc_available()): + assert not hasattr( + Backend, name.upper() + ), f"{name.upper()} c10d backend already exist" + assert ( + name.upper() not in Backend._plugins + ), f"{name.upper()} c10d backend creator function already exist" + + setattr(Backend, name.upper(), name.lower()) + Backend.backend_list.append(name.lower()) + if devices is not None: + for device in devices: + if device != "cpu" and device != "cuda": + Backend.default_device_backend_map[device] = name.lower() + Backend.backend_type_map[name.lower()] = ProcessGroup.BackendType.CUSTOM + + # Update device capability matrix in Backend class + if devices is None: + # This is more of a backward support for groups like `threaded`: + # assume default devices "cpu" and "cuda", but warn + warnings.warn( + f"Device capability of {name} unspecified, assuming `cpu` and " + "`cuda`. Please specify it via the `devices` argument of " + "`register_backend`." + ) + Backend.backend_capability[name.lower()] = ["cpu", "cuda"] + elif isinstance(devices, str): + # Single device string specified. Simply convert to list. + Backend.backend_capability[name.lower()] = [devices] + else: + Backend.backend_capability[name.lower()] = devices + + Backend._plugins[name.upper()] = Backend._BackendPlugin(func, extended_api) + + +class BackendConfig: + """Backend configuration class.""" + + def __init__(self, backend: Backend): + """Init.""" + self.device_backend_map: Dict[str, Backend] = {} + backend = str(backend) + + if backend == Backend.UNDEFINED: + # default config when backend is not specified + # supported since PyTorch 2.0 + for device, default_backend in Backend.default_device_backend_map.items(): + if is_backend_available(default_backend): + if ( + default_backend == Backend.NCCL + and not torch.cuda.is_available() + ): + continue + self.device_backend_map[device] = Backend(default_backend) + elif backend.lower() in Backend.backend_list: + # Cases for when backend is a single string (without device types) + # e.g. "nccl", "gloo", "ucc", "mpi" + supported_devices = Backend.backend_capability[backend.lower()] + backend_val = Backend(backend) + self.device_backend_map = dict.fromkeys(supported_devices, backend_val) + elif ":" in backend.lower(): + # Backend specified in "device:backend" format + # make sure the backend string is in the correct format + # "{device_type1}:{backend1},{device_type2}:{backend2}" + # e.g. "cpu:gloo,cuda:nccl" + backend_str_error_message = f"""The custom backend string argument is invalid: {backend}. + Custom backend string is an experimental feature where the backend string must be in the format: + ":,:...". e.g. 'cpu:gloo,cuda:nccl'""" + + # parse the backend string and populate the device_backend_map + for device_backend_pair_str in backend.lower().split(","): + device_backend_pair = device_backend_pair_str.split(":") + if len(device_backend_pair) != 2: + raise ValueError( + f"Invalid device:backend pairing: \ + {device_backend_pair_str}. {backend_str_error_message}" + ) + device, backend = device_backend_pair + if device in self.device_backend_map: + raise ValueError( + f"Duplicate device type {device} \ + in backend string: {backend}. 
{backend_str_error_message}" + ) + self.device_backend_map[device] = Backend(backend) + else: + # User specified a single backend name whose device capability is + # unknown, assuming it can support the default devices of PyTorch + # (cpu and cuda) + warnings.warn( + f"Device capability of {backend} unknown, assuming `cpu` and " + "`cuda`. You can specify it in `device:backend` format in " + "`init_process_group` call." + ) + backend_val = Backend(backend) + self.device_backend_map = { + "cpu": backend_val, + "cuda": backend_val, + "xpu": backend_val, + } + + logger.info("Using backend config: %s", self.device_backend_map) + + def __repr__(self): + """Return all the device:backend pairs separated by commas.""" + return ",".join( + f"{device}:{backend}" for device, backend in self.device_backend_map.items() + ) + + def get_device_backend_map(self) -> Dict[str, Backend]: + """Return backend map of the device.""" + return self.device_backend_map + + +class _reduce_op: + r""" + Deprecated enum-like class. + + For reduction operations: ``SUM``, ``PRODUCT``, ``MIN``, and ``MAX``. + + :class:`~torch.distributed.ReduceOp` is recommended to use instead. + """ + + def __init__(self) -> None: + # __members__ is a dict storing key-value pairs for enum classes + for k, v in ReduceOp.RedOpType.__members__.items(): + setattr(self, k, v) + self.__members__ = ReduceOp.RedOpType.__members__ + + @deprecated( + "`torch.distributed.reduce_op` is deprecated, " + "please use `torch.distributed.ReduceOp` instead", + category=FutureWarning, + ) + def __getattribute__(self, key): + return object.__getattribute__(self, key) + + +reduce_op = _reduce_op() + + +class P2POp: + """ + A class to build point-to-point operations for ``batch_isend_irecv``. + + This class builds the type of P2P operation, communication buffer, peer rank, + Process Group, and tag. Instances of this class will be passed to + ``batch_isend_irecv`` for point-to-point communications. + + Args: + op (Callable): A function to send data to or receive data from a peer process. + The type of ``op`` is either ``torch.distributed.isend`` or + ``torch.distributed.irecv``. + tensor (Tensor): Tensor to send or receive. + peer (int): Destination or source rank. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + tag (int, optional): Tag to match send with recv. + """ + + def __init__( + self, + op: Callable, + tensor: torch.Tensor, + peer: int, + group: Optional[ProcessGroup] = None, + tag: int = 0, + ): + """Init.""" + self.op = op + self.tensor = tensor + self.peer = peer + self.group = group + self.tag = tag + + def __new__( + cls, + op: Callable, + tensor: torch.Tensor, + peer: int, + group: Optional[ProcessGroup] = None, + tag: int = 0, + ): + """Create and return a new instance of the class.""" + _check_op(op) + _check_single_tensor(tensor, "tensor") + return object.__new__(cls) + + def __repr__(self): + my_group_rank = get_rank(self.group) + peer_group_rank = ( + get_group_rank(self.group, self.peer) if self.group else self.peer + ) + op_name = self.op.__name__ + group_name = self.group.group_name if self.group else "default_pg" + if "send" in op_name: + s = my_group_rank + d = peer_group_rank + elif "recv" in op_name: + s = peer_group_rank + d = my_group_rank + else: + return super().__repr__() + + return f"P2POp({op_name} pg={group_name}, s={s}, d={d}, {self.tensor.shape}, {self.tensor.dtype})" + + +class _CollOp: + """ + A class to capture collective operations. 
+ + Args: + op (Callable): A collective function, e.g. ``torch.distributed.all_reduce``. + tensor (Tensor): Tensor to operate on. + dst_tensor (Tensor, optional): Provided when source and destinaton tensors are not the same. + redop (ReduceOp, optional): reduce operation. + root (int, optional): root of broadcast or reduce. + """ + + def __init__( + self, + op: Callable, + tensor: torch.Tensor, + dst_tensor: Optional[torch.Tensor] = None, + redop: Optional[ReduceOp] = None, + root: Optional[int] = None, + ): + self.op = op + self.tensor = tensor + self.dst_tensor = dst_tensor + self.redop = redop + self.root = root + + +# DO NOT USE THESE FIELDS DIRECTLY. +# Use them through the _world object to make sure the _world override mechanism +_pg_map: Dict[ProcessGroup, Tuple[str, Store]] = {} +_pg_names: Dict[ProcessGroup, str] = {} +_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {} +# For a pg, it is a map from ProcessGroup to BackendConfig +_pg_backend_config: Dict[ProcessGroup, str] = {} +_group_count = 0 +_tags_to_pg: Dict[str, List[ProcessGroup]] = {} +_pg_to_tag: Dict[ProcessGroup, str] = {} +_backend: Optional[str] = None + + +class _World: + """ + Container class for c10d process group state. + + This is used during registration and lookup of PG state. + + .. warning:: This is an experimental API intended to expose the inner workings + of c10d and is subject to change.. + """ + + def __init__(self) -> None: + self._default_pg = None + self._pg_coalesce_state: Dict[ProcessGroup, List[_CollOp]] = {} + self._pg_default_device: Dict[ProcessGroup, torch.device] = {} + + @property + def default_pg(self) -> Optional[ProcessGroup]: + """ + Process group that includes all ranks of the cluster. + + This default ProcessGroup is used by c10d APIs when a ProcessGroup is needed + but None is provided. + """ + return self._default_pg + + @default_pg.setter + def default_pg(self, value) -> None: + self._default_pg = value + + @property + def pg_map(self) -> Dict[ProcessGroup, Tuple[str, Store]]: + """ + Provide Mapping from ProcessGroup to backend name and store. + + For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store) + For MPI pg, it is a map from ProcessGroup to (Backend, None) + + TODO don't expose the map, expose fine grained ops + """ + global _pg_map + return _pg_map + + @property + def pg_names(self) -> Dict[ProcessGroup, str]: + """ + Process group's names, map from ProcessGroup to str. + + TODO don't expose the map, expose fine grained ops + """ + global _pg_names + return _pg_names + + @property + def pg_group_ranks(self) -> Dict[ProcessGroup, Dict[int, int]]: + """ + Process group's global rank to local rank mapping. + + TODO don't expose the map, expose fine grained ops + """ + global _pg_group_ranks + return _pg_group_ranks + + @property + def pg_backend_config(self) -> Dict[ProcessGroup, str]: + """ + Process group's backend config. + + TODO don't expose the map, expose fine grained ops + """ + global _pg_backend_config + return _pg_backend_config + + @property + def group_count(self) -> int: + """ + Process group count for default naming. 
+ + TODO don't expose group_count, use something else instead + """ + global _group_count + return _group_count + + @group_count.setter + def group_count(self, value: int) -> None: + """Use to compute the name of ProcessGroups when using global synchronization.""" + global _group_count + _group_count = value + + @property + def tags_to_pg(self) -> Dict[str, List[ProcessGroup]]: + global _tags_to_pg + return _tags_to_pg + + @property + def pg_to_tag(self) -> Dict[ProcessGroup, str]: + global _pg_to_tag + return _pg_to_tag + + @property + def pg_coalesce_state(self) -> Dict[ProcessGroup, List[_CollOp]]: + return self._pg_coalesce_state + + @property + def pg_default_device(self) -> Dict[ProcessGroup, torch.device]: + return self._pg_default_device + + @property + def pg_config_info(self) -> List[Dict[str, Any]]: + """ + Return a list of dict with process groups and backends. + + Along with their unique IDs and configurations (types and ranks). + """ + config_info: List[Dict[str, Any]] = [] + default_pg_size = _get_group_size(None) + for pg in self.pg_map.keys(): + ranks = self.pg_group_ranks[pg] + config_info.append( + { + "pg_name": self.pg_names[pg], + "pg_desc": pg.group_desc, + "backend_config": self.pg_backend_config[pg], + "ranks": list(ranks.keys()) + if len(ranks) != default_pg_size + else [], # 'ranks' is an empty list when all ranks are involved in a pg + "group_size": len(ranks), + "group_count": self.group_count, + } + ) + return config_info + + +_world = _World() +"""Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it""" + + +class _WorldMeta(type): + """ + Meta class of ``group`` and ``GroupMember``. + + Allows them to have the class property ``WORLD``. + """ + + # Points to the default PG once initialized. + @property + def WORLD(cls) -> Optional[ProcessGroup]: + return _world.default_pg + + @WORLD.setter + def WORLD(cls, pg: Optional[ProcessGroup]): + _world.default_pg = pg + + +class group(metaclass=_WorldMeta): + """Group class. Placeholder.""" + + +class GroupMember(metaclass=_WorldMeta): + """Group member class.""" + + NON_GROUP_MEMBER = -100 + + +def _get_default_timeout(backend: Backend) -> timedelta: + # see note on nccl vs other backend timeout (constants.py) + if backend == Backend.NCCL: + if not isinstance(default_pg_nccl_timeout, timedelta): + # TODO moco benchmark on CPU initializes pgnccl backend today, triggered this assert in CI before it was + # changed to be a warning. We should fix the moco model. + warnings.warn( + "Attempted to get default timeout for nccl backend, but NCCL support is not compiled" + ) + return default_pg_timeout + return default_pg_nccl_timeout + else: + return default_pg_timeout + + +def _check_valid_timeout(timeout: Any) -> None: + if not isinstance(timeout, timedelta): + raise TypeError( + f"Expected timeout argument to be of type datetime.timedelta, got {timeout}" + ) + + +# Default process group state +_default_pg_init_method: Optional[str] = None + +STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key" + + +def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device: + """ + Return the device to use with ``group`` for control flow usage (object collectives, barrier). + + There are selection rules: + 1. If user specifies exactly one backend in ``init_process_group`` call: + use that backend + 2. 
Else if user specifies multiple "device:backend" pairs in init_process_group: + If "cpu" is among those pairs, use "cpu" (because the object is in cpu memory); + Otherwise, use the first backend (sort of a random pick). + + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + Returns: + torch.device: The device to use with ``group``. + + """ + group = group or _get_default_group() + if group in _world.pg_default_device: + # Previously searched and cached; just return + return _world.pg_default_device[group] + + if not isinstance(group, ProcessGroup): + # Provide backward compatibility to cases where `group` passed in is + # actually a Backend (like `ProcessGroupGloo`) rather than a + # `ProcessGroup` in PT 2.0 sense + warnings.warn( + f"You are using a Backend {type(group)} as a ProcessGroup. " + "This usage is deprecated since PyTorch 2.0. Please use a public API " + "of PyTorch Distributed instead.", + FutureWarning, + stacklevel=3, + ) + # Most users create Gloo with private API for object collectives + _world.pg_default_device[group] = torch.device("cpu") + return _world.pg_default_device[group] + + """ + ``group._device_types`` is a property pybind that returns the devices + ("cpu", "cuda", etc) supported by ``group``. Can be multiple if the + ``group`` supports multiple devices. + """ + devices = group._device_types + + if len(devices) == 1: + # User fixed exactly one backend in `init_process_group` + _world.pg_default_device[group] = devices[0] + elif len(devices) == 0: + # No backend has been registered with this PG (maybe because no + # collective has been run?) We pick cpu as the default and hopefully + # this would lazily init Gloo or other available cpu backend. + _world.pg_default_device[group] = torch.device("cpu") + elif torch.device("cpu") in devices: + # There are multiple backends in this PG and cpu is among them. + # cpu is preferred as the object is in cpu memory. No need for device + # copy. + _world.pg_default_device[group] = torch.device("cpu") + else: + # No cpu in the backend list. Randomly pick the first backend + _world.pg_default_device[group] = devices[0] + + logger.info( + "Using device %s for object " "collectives.", _world.pg_default_device[group] + ) + return _world.pg_default_device[group] + + +@_time_logger +def _store_based_barrier( + rank, + store, + group_name, + rendezvous_count, + timeout, + logging_interval=timedelta(seconds=10), +) -> None: + """ + Store based barrier for synchronizing processes. + + Barrier based on store which is used for synchronizing processes after + ``init_process_group`` or ``new_group``. Intended to be used only with + those two methods and is not a generic alternative to ``barrier()``. + """ + store_key = f"{STORE_BASED_BARRIER_PREFIX}:{group_name}" + store.add(store_key, 1) + logger.debug("Added key: %s to store for rank: %s", store_key, rank) + + # Now wait for all workers to check in with the store. + world_size = rendezvous_count + worker_count = store.add(store_key, 0) + + last_worker_key = f"{store_key}:last_worker" + if worker_count == world_size: + store.set(last_worker_key, "1") + + # adjust the timeout to be at least 10secs + 1sec per thousand ranks to reduce the odds of timeout + # this value was empirically found while scale testing. 
+ logging_interval = max(logging_interval, timedelta(seconds=10 + world_size / 1000)) + + start = time.time() + while True: + try: + # This will throw an exception after the logging_interval in which we print out + # the status of the group or time out officially, throwing runtime error + store.wait([last_worker_key], logging_interval) + break + except RuntimeError as e: + worker_count = store.add(store_key, 0) + # Print status periodically to keep track. + logger.debug( + "Waiting in store based barrier to initialize process group for " + "rank: %s, key: %s (world_size=%s, num_workers_joined=%s, timeout=%s error=%s)", + rank, + store_key, + world_size, + worker_count, + timeout, + e, + ) + + if timedelta(seconds=(time.time() - start)) > timeout: + raise DistStoreError( # noqa: B904 + "Timed out initializing process group in store based barrier on " + f"rank {rank}, for key: {store_key} (world_size={world_size}, " + f"num_workers_joined={worker_count}, timeout={timeout} error={e})" + ) + + logger.info( + "Rank %s: Completed store-based barrier for key:%s with %s nodes.", + rank, + store_key, + world_size, + ) + + +def _rank_not_in_group(group: Optional[ProcessGroup]) -> bool: + """Check if the current process's rank is not in a given group.""" + if group is None: + return False + return group == GroupMember.NON_GROUP_MEMBER + + +def _warn_not_in_group(op_name) -> None: + global_rank = -1 if GroupMember.WORLD is None else GroupMember.WORLD.rank() + warnings.warn( + f"Running {op_name} on global rank {global_rank} which does not " + "belong to the given group." + ) + + +def get_group_rank(group: ProcessGroup, global_rank: int) -> int: + """ + Translate a global rank into a group rank. + + ``global_rank`` must be part of ``group`` otherwise this raises RuntimeError. + + Args: + group (ProcessGroup): ProcessGroup to find the relative rank. + global_rank (int): Global rank to query. + + Returns: + Group rank of ``global_rank`` relative to ``group`` + + N.B. calling this function on the default process group returns identity + """ + if group is GroupMember.WORLD: + return global_rank + if group not in _world.pg_group_ranks: + raise ValueError( + f"Group {group} is not registered, please create group with torch.distributed.new_group API" + ) + group_ranks = _world.pg_group_ranks[group] + if global_rank not in group_ranks: + raise ValueError(f"Global rank {global_rank} is not part of group {group}") + + return group_ranks[global_rank] + + +def get_global_rank(group: ProcessGroup, group_rank: int) -> int: + """ + Translate a group rank into a global rank. + + ``group_rank`` must be part of `group` otherwise this raises RuntimeError. + + Args: + group (ProcessGroup): ProcessGroup to find the global rank from. + group_rank (int): Group rank to query. + + Returns: + Global rank of ``group_rank`` relative to ``group`` + + N.B. calling this function on the default process group returns identity + """ + if group is GroupMember.WORLD: + return group_rank + if group not in _world.pg_group_ranks: + raise ValueError( + f"Group {group} is not registered, please create group with torch.distributed.new_group API" + ) + for rank, grp_rank in _world.pg_group_ranks[group].items(): + if grp_rank == group_rank: + return rank + raise ValueError(f"Group rank {group_rank} is not part of group {group}") + + +# TODO: remove this once the ecosystem moves away from it. 
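Before the deprecated `_get_global_rank` shim below, a small round-trip sketch of the two public translation helpers just defined, `get_group_rank` and `get_global_rank`. The 4-rank world size and the choice of subgroup ranks are arbitrary assumptions for illustration.

```python
# Editor's illustration: translating between global ranks and group-local ranks.
# Assumes a 4-rank job launched with a distributed launcher (e.g. torchrun).
import torch.distributed as dist

dist.init_process_group(backend="gloo")

# new_group must be called collectively by all ranks; the resulting subgroup
# contains global ranks 1 and 3, which become group ranks 0 and 1 inside it.
pg = dist.new_group(ranks=[1, 3])

if dist.get_rank() in (1, 3):
    group_rank = dist.get_group_rank(pg, dist.get_rank())
    # get_global_rank inverts the mapping; both raise for ranks outside `pg`,
    # hence the membership guard above.
    assert dist.get_global_rank(pg, group_rank) == dist.get_rank()
```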
+@deprecated( + "`torch.distributed.distributed_c10d._get_global_rank` is deprecated, " + "please use `torch.distributed.distributed_c10d.get_global_rank` instead", + category=FutureWarning, +) +def _get_global_rank(group, rank) -> int: + """Use get_global_rank as this method is deprecated.""" + return get_global_rank(group, rank) + + +def get_process_group_ranks(group: ProcessGroup) -> List[int]: + """ + Get all ranks associated with ``group``. + + Args: + group (ProcessGroup): ProcessGroup to get all ranks from. + + Returns: + List of global ranks ordered by group rank. + """ + return list(_world.pg_group_ranks[group].keys()) + + +def _get_group_size(group) -> int: + """Get a given group's world size.""" + if group is GroupMember.WORLD or group is None: + default_pg = _get_default_group() + return default_pg.size() + return group.size() + + +def _get_group_size_by_name(group_name: str) -> int: + group = _resolve_process_group(group_name) + return group.size() + + +def _resolve_group_name_by_ranks_and_tag(ranks: List[int], tag: str) -> str: + # TODO(yifu): remove this function once ranks + tag is not a supported + # identifier for process group for functional collectives. + group = _find_pg_by_ranks_and_tag(tag, ranks) + if group is None: + raise ValueError("") + return group.group_name + + +def _check_single_tensor(param, param_name) -> None: + """Check that the parameter ``param_name`` is a single tensor.""" + if not isinstance(param, torch.Tensor): + raise TypeError( + f"""Invalid function argument. Expected parameter `{param_name}` of type torch.Tensor + but got {type(param)} instead.""" + ) + + +def _check_tensor_list(param, param_name) -> None: + """Check that the parameter ``param_name`` is a list of tensors.""" + if not isinstance(param, list): + raise TypeError( + f"""Invalid function argument. Expected parameter `{param_name}` of type List[torch.Tensor] + but got {type(param)} instead.""" + ) + elif not all(isinstance(p, torch.Tensor) for p in param): + raise TypeError( + f"""Invalid function argument. Expected parameter `{param_name}` of type List[torch.Tensor] + but got {type(param)} with elements of type {[type(p) for p in param]}.""" + ) + + +def _as_iterable(obj) -> collections.abc.Iterable: + return obj if isinstance(obj, list) else (obj,) + + +def _ensure_all_tensors_same_dtype(*tensors) -> None: + last_dtype = None + for tensor in itertools.chain.from_iterable(map(_as_iterable, tensors)): + tensor_dtype = tensor.dtype + # Mixing complex and its element type is allowed + if tensor_dtype.is_complex: + tensor_dtype = ( + torch.float32 if tensor_dtype == torch.complex64 else torch.complex128 + ) + + if last_dtype is None: + last_dtype = tensor_dtype + else: + if last_dtype != tensor_dtype: + raise ValueError( + "Invalid usage of tensors with different dtypes" + f"Found {last_dtype} and {tensor.dtype}" + ) + + +def _check_op(op) -> None: + """Check that the ``op`` is either isend or irecv.""" + if op not in [isend, irecv]: + raise ValueError( + "Invalid ``op``. Expected ``op`` " + "to be of type ``torch.distributed.isend`` or " + "``torch.distributed.irecv``." + ) + + +def _check_p2p_op_list(p2p_op_list) -> None: + """ + Check that the ``p2p_op_list`` is a list of P2POp instances. + + Also, check that all ops use the same group. + """ + if not isinstance(p2p_op_list, list) or not all( + isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list + ): + raise ValueError( + "Invalid ``p2p_op_list``. Each op is expected to " + "to be of type ``torch.distributed.P2POp``." 
+ ) + + group = p2p_op_list[0].group + if not all(group == p2p_op.group for p2p_op in p2p_op_list): + raise ValueError("All ops need to use the same group.") + + +def is_mpi_available() -> bool: + """Check if the MPI backend is available.""" + return _MPI_AVAILABLE + + +def is_nccl_available() -> bool: + """Check if the NCCL backend is available.""" + return _NCCL_AVAILABLE + + +def is_gloo_available() -> bool: + """Check if the Gloo backend is available.""" + return _GLOO_AVAILABLE + + +def is_ucc_available() -> bool: + """Check if the UCC backend is available.""" + return _UCC_AVAILABLE + + +def is_backend_available(backend: str) -> bool: + """ + Check backend availability. + + Checks if the given backend is available and supports the built-in backends or + third-party backends through function ``Backend.register_backend``. + + Args: + backend (str): Backend name. + Returns: + bool: Returns true if the backend is available otherwise false. + """ + # If the backend has an ``is_backend_available`` function, return the result of that function directly + available_func = getattr(torch.distributed, f"is_{backend.lower()}_available", None) + if available_func: + return available_func() + + return backend.lower() in Backend.backend_list + + +def is_initialized() -> bool: + """Check if the default process group has been initialized.""" + return GroupMember.WORLD is not None + + +def is_torchelastic_launched() -> bool: + """ + Check whether this process was launched with ``torch.distributed.elastic`` (aka torchelastic). + + The existence of ``TORCHELASTIC_RUN_ID`` environment + variable is used as a proxy to determine whether the current process + was launched with torchelastic. This is a reasonable proxy since + ``TORCHELASTIC_RUN_ID`` maps to the rendezvous id which is always a + non-null value indicating the job id for peer discovery purposes.. + """ + return os.getenv("TORCHELASTIC_RUN_ID") is not None + + +def _is_barrier_after_init() -> int: + # Environment variable to control whether process group should perform a + # barrier after its init. Default value is 0, i.e. no barrier. If you + # experience issue with this setting, you may set + # `TORCH_DIST_INIT_BARRIER=1` to add the barrier. + return int(os.getenv("TORCH_DIST_INIT_BARRIER", "0")) + + +def _get_default_group() -> ProcessGroup: + """Get the default process group created by init_process_group.""" + if not is_initialized(): + raise ValueError( + "Default process group has not been initialized, " + "please make sure to call init_process_group." + ) + if TYPE_CHECKING: + return not_none(GroupMember.WORLD) + else: + return GroupMember.WORLD + + +def _get_default_store() -> Store: + """Get the default store created by init_process_group.""" + if not is_initialized(): + raise ValueError( + "Default process group has not been initialized, " + "please make sure to call init_process_group." + ) + default_pg = _get_default_group() + _, default_store = _world.pg_map[default_pg] + return default_store + + +def _update_default_pg(pg) -> None: + _world.default_pg = pg + rank = pg.rank() if pg is not None and pg != GroupMember.NON_GROUP_MEMBER else -1 + torch._C._distributed_c10d._set_global_rank(rank) + + +def get_backend_config(group: Optional[ProcessGroup] = None) -> str: + """ + Return the backend configuration of the given process group. + + Args: + group (ProcessGroup, optional): The process group to work on. The + default is the general main process group. 
If another specific group + is specified, the calling process must be part of :attr:`group`. + + Returns: + The backend configuration of the given process group as a lower case string. + + """ + pg = group or _get_default_group() + if _rank_not_in_group(pg): + raise ValueError("Invalid process group specified") + backend_config = _world.pg_backend_config.get(pg) + return str(not_none(backend_config)) + + +def get_backend(group: Optional[ProcessGroup] = None) -> Backend: + """ + Return the backend of the given process group. + + Args: + group (ProcessGroup, optional): The process group to work on. The + default is the general main process group. If another specific group + is specified, the calling process must be part of :attr:`group`. + + Returns: + The backend of the given process group as a lower case string. + + """ + pg = group or _get_default_group() + if _rank_not_in_group(pg): + raise ValueError("Invalid process group specified") + pg_store = _world.pg_map[pg] if pg in _world.pg_map else None + return Backend(not_none(pg_store)[0]) + + +def _get_process_group_uid(pg: ProcessGroup) -> int: + backend = None + try: + backend = pg._get_backend(torch.device("cuda")) + except RuntimeError: + pass + if is_nccl_available() and isinstance(backend, ProcessGroupNCCL): + return backend.uid + return -1 + + +def _get_pg_config(group: Optional[ProcessGroup] = None) -> Dict[str, Any]: + """ + Return the pg configuration of the given process group. + + """ + pg = group or _get_default_group() + return { + "pg_name": _get_process_group_name(pg), + "pg_desc": pg.group_desc, + "backend_config": get_backend_config(pg), + "pg_size": _get_group_size(pg), + "ranks": get_process_group_ranks(pg), + } + + +def _get_all_pg_configs() -> List[Dict[str, Any]]: + """ + Return the pg configuration of all the process groups. + + """ + config_info: List[Dict[str, Any]] = [] + for pg in _world.pg_map.keys(): + config_info.append(_get_pg_config(pg)) + return config_info + + +def get_pg_count() -> int: + """ + Return the number of process groups. + + """ + return _world.group_count + + +def get_node_local_rank(fallback_rank: Optional[int] = None) -> int: + """ + Return the local rank of the current process relative to the node. + + Semantically, this is a useful concept for mapping processes to devices. + For example, on a node with 8 accelerator you could use the node local rank to decide + which accelerator device to bind the process to. + + In practice, the actual assignment of node local ranks is handled by the process launcher outside of pytorch, + and communicated via the `LOCAL_RANK` environment variable. + + Torchrun will automatically populate `LOCAL_RANK`, but other launchers may not. If `LOCAL_RANK` is unspecified, + this API will fall back to the provided kwarg 'fallback_rank' if specified, otherwise it will raise an error. The + intent is to allow writing an application that runs either in single or multi device contexts without error. + + """ + if "LOCAL_RANK" in os.environ: + return int(os.environ["LOCAL_RANK"]) + elif fallback_rank is not None: + return int(fallback_rank) + raise RuntimeError( + "LOCAL_RANK is not in the environment. Consider passing fallback_rank to allow `get_node_local_rank` to work, " + "assuming you are not running in a multi-device context and want the code to run locally instead." + ) + + +def _add_ephemeral_timeout_for_all_pgs(timeout: timedelta) -> None: + """ + This API adds an ephemeral timeout extension for all PGs locally + on one rank. 
The timeout gets reset when the first collective issued
+    after this API call finishes.
+    NOTE: Setting this timeout is currently only supported for CUDA backends.
+    NOTE: While this feature
+    provides flexibility in specific scenarios, it introduces statefulness
+    to timeout setting. Therefore, it is advisable to use this API sparingly
+    and consider alternative approaches, such as directly setting the timeout
+    or utilizing a barrier collective (one can set any timeout to the barrier),
+    whenever feasible.
+
+    Args:
+        timeout (timedelta): The delta of timeout to extend.
+
+    Returns:
+        None.
+    """
+    for pg in _world.pg_map.keys():
+        devices = pg._device_types
+        if torch.device("cuda") in devices:
+            backend = pg._get_backend(torch.device("cuda"))
+            if is_nccl_available() and isinstance(backend, ProcessGroupNCCL):
+                backend._add_ephemeral_timeout(timeout)
+
+
+def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) -> None:
+    """
+    Set the timeout for the given process group when users want to use a different timeout instead of
+    default values.
+
+    Args:
+        timeout (timedelta): Timeout for operations executed against the process group which
+            users want to set. Default value is 10 minutes for NCCL and 30 minutes for other backends.
+            This is the duration after which collectives will be aborted asynchronously and the process will crash.
+            This is done since CUDA execution is async and it is no longer safe to continue executing user code since
+            failed async NCCL operations might result in subsequent CUDA operations running on corrupted data.
+            When TORCH_NCCL_BLOCKING_WAIT is set, the process will block and wait for this timeout.
+
+        group (ProcessGroup, optional): The process group to work on. The
+            default is the general main process group. If another specific group
+            is specified, the calling process must be part of :attr:`group`.
+
+    Returns:
+        None
+    """
+    if group is None:
+        group = _get_default_group()
+    if _rank_not_in_group(group):
+        raise ValueError("Invalid process group specified")
+    assert isinstance(group, ProcessGroup)
+    devices = group._device_types
+    backends = set()
+    if torch.device("cpu") in devices and is_gloo_available():
+        backend = group._get_backend(torch.device("cpu"))
+        if isinstance(backend, ProcessGroupGloo):
+            backends.add(backend)
+    if torch.device("cuda") in devices:
+        backend = group._get_backend(torch.device("cuda"))
+        if is_nccl_available() and isinstance(backend, ProcessGroupNCCL):
+            backends.add(backend)  # type: ignore[arg-type]
+        elif is_gloo_available() and isinstance(backend, ProcessGroupGloo):
+            backends.add(backend)  # type: ignore[arg-type]
+    if len(backends) == 0:
+        warnings.warn("Setting a timeout is currently only supported for the nccl and gloo backends.")
+    for backend in backends:
+        backend._set_default_timeout(timeout)
+
+
+@_exception_logger
+@_time_logger
+def init_process_group(
+    backend: Optional[str] = None,
+    init_method: Optional[str] = None,
+    timeout: Optional[timedelta] = None,
+    world_size: int = -1,
+    rank: int = -1,
+    store: Optional[Store] = None,
+    group_name: str = "",
+    pg_options: Optional[Any] = None,
+    device_id: Optional[torch.device] = None,
+) -> None:
+    """
+    Initialize the default distributed process group.
+
+    This will also initialize the distributed package.
+
+    There are 2 main ways to initialize a process group:
+        1. Specify ``store``, ``rank``, and ``world_size`` explicitly.
+        2. Specify ``init_method`` (a URL string) which indicates where/how
+           to discover peers. Optionally specify ``rank`` and ``world_size``,
+           or encode all required parameters in the URL and omit them.
+
+    If neither is specified, ``init_method`` is assumed to be "env://".
+
+
+    Args:
+        backend (str or Backend, optional): The backend to use. Depending on
+            build-time configurations, valid values include ``mpi``, ``gloo``,
+            ``nccl``, and ``ucc``. If the backend is not provided, then both a ``gloo``
+            and ``nccl`` backend will be created, see notes below for how multiple
+            backends are managed. This field can be given as a lowercase string
+            (e.g., ``"gloo"``), which can also be accessed via
+            :class:`Backend` attributes (e.g., ``Backend.GLOO``). If using
+            multiple processes per machine with ``nccl`` backend, each process
+            must have exclusive access to every GPU it uses, as sharing GPUs
+            between processes can result in deadlocks. The ``ucc`` backend is
+            experimental.
+        init_method (str, optional): URL specifying how to initialize the
+            process group. Default is "env://" if no
+            ``init_method`` or ``store`` is specified.
+            Mutually exclusive with ``store``.
+        world_size (int, optional): Number of processes participating in
+            the job. Required if ``store`` is specified.
+        rank (int, optional): Rank of the current process (it should be a
+            number between 0 and ``world_size``-1).
+            Required if ``store`` is specified.
+        store (Store, optional): Key/value store accessible to all workers, used
+            to exchange connection/address information.
+            Mutually exclusive with ``init_method``.
+        timeout (timedelta, optional): Timeout for operations executed against
+            the process group. Default value is 10 minutes for NCCL and 30 minutes for other backends.
+            This is the duration after which collectives will be aborted asynchronously and the process will crash.
+            This is done since CUDA execution is async and it is no longer safe to continue executing user code since
+            failed async NCCL operations might result in subsequent CUDA operations running on corrupted data.
+            When TORCH_NCCL_BLOCKING_WAIT is set, the process will block and wait for this timeout.
+
+        group_name (str, optional, deprecated): Group name. This argument is ignored.
+        pg_options (ProcessGroupOptions, optional): process group options
+            specifying what additional options need to be passed in during
+            the construction of specific process groups. As of now, the only
+            option we support is ``ProcessGroupNCCL.Options`` for the ``nccl``
+            backend; ``is_high_priority_stream`` can be specified so that
+            the nccl backend can pick up high-priority CUDA streams when
+            there are compute kernels waiting. For other available options to configure NCCL,
+            see https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
+        device_id (torch.device, optional): a single, specific device
+            to "bind" this process to, allowing for backend-specific
+            optimizations. Currently this has two effects, only under
+            NCCL: the communicator is immediately formed (calling
+            ``ncclCommInit*`` immediately rather than the normal lazy
+            call) and sub-groups will use ``ncclCommSplit`` when
+            possible to avoid unnecessary overhead of group creation. If you
+            want to detect NCCL initialization errors early, you can also use this
+            field.
+
+    .. note:: To enable ``backend == Backend.MPI``, PyTorch needs to be built from source
+        on a system that supports MPI.
+
+    .. note:: Support for multiple backends is experimental. Currently when no backend is
+        specified, both ``gloo`` and ``nccl`` backends will be created. The ``gloo`` backend
+        will be used for collectives with CPU tensors and the ``nccl`` backend will be used
+        for collectives with CUDA tensors. A custom backend can be specified by passing in
+        a string with format "<device_type>:<backend_type>,<device_type>:<backend_type>", e.g.
+        "cpu:gloo,cuda:custom_backend".
+
+    """
+
+    global _world
+
+    global _backend
+    global _default_pg_init_method
+
+    if GroupMember.WORLD is not None:
+        raise ValueError("trying to initialize the default process group twice!")
+
+    set_pytorch_distributed_envs_from_justknobs()
+
+    # Depending on the import order, some trace_rules functions may be evaluated
+    # during the import phase. In such a case, these functions may not correctly
+    # add the distributed related rules due to import circular dependency.
+    # We need to clear the lru_cache during the runtime to ensure the correctness
+    # of these trace_rules.
+    #
+    # Since this API must be called before any distributed code is compiled,
+    # clearing the cache here should be safe.
+    if "torch._dynamo" in sys.modules:
+        torch._dynamo.trace_rules.clear_lru_cache()
+
+    assert (store is None) or (
+        init_method is None
+    ), "Cannot specify both init_method and store."
+
+    if store is not None:
+        assert world_size > 0, "world_size must be positive if using store"
+        assert rank >= 0, "rank must be non-negative if using store"
+    elif init_method is None:
+        init_method = "env://"
+
+    if backend:
+        backend = Backend(backend)
+    else:
+        backend = Backend("undefined")
+
+    if timeout is None:
+        timeout = _get_default_timeout(backend)
+
+    _check_valid_timeout(timeout)
+
+    """
+    Group name is not visible to users unless they access
+    internals of c10d. This means we can ignore the value
+    they provide as it is not exposed in a public way.
+    """
+    group_name = _process_group_name([], use_hashed_name=False)
+    if backend == Backend.MPI:
+        if world_size != -1 or rank != -1:
+            warnings.warn(
+                f"For MPI backend, world_size ({world_size}) and rank ({rank}) "
+                "are ignored since they are assigned by the "
+                "MPI runtime."
+            )
+
+        default_pg, _ = _new_process_group_helper(
+            -1,
+            -1,
+            [],
+            backend,
+            None,
+            group_name,
+            timeout=timeout,
+            group_desc="default_pg",
+        )
+        _update_default_pg(default_pg)
+    else:
+        # backward compatible API
+        if store is None:
+            rendezvous_iterator = rendezvous(
+                not_none(init_method), rank, world_size, timeout=timeout
+            )
+            store, rank, world_size = next(rendezvous_iterator)
+            store.set_timeout(timeout)
+
+            # Use a PrefixStore to avoid accidental overrides of keys used by
+            # different systems (e.g. RPC) in case the store is multi-tenant.
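+            # All keys written through this store are therefore namespaced under
+            # the "default_pg" prefix.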
+ store = PrefixStore("default_pg", store) + + default_pg, _ = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name, + pg_options=pg_options, + timeout=timeout, + device_id=device_id, + group_desc="default_pg", + ) + _update_default_pg(default_pg) + + _world.pg_group_ranks[GroupMember.WORLD] = {i: i for i in range(GroupMember.WORLD.size())} # type: ignore[attr-defined, index] + _backend = _world.pg_map[not_none(GroupMember.WORLD)][0] + _default_pg_init_method = init_method + + old_hook = sys.excepthook + excepthook_prefix = f"[rank{get_rank()}]" + + def _distributed_excepthook(*args): + old_stderr = sys.stderr + sys.stderr = buf = io.StringIO() + try: + old_hook(*args) + finally: + sys.stderr = old_stderr + msg = buf.getvalue() + msg = "\n".join( + f"{excepthook_prefix}: {s}" if s != "" else "" for s in msg.split("\n") + ) + sys.stderr.write(msg) + sys.stderr.flush() + + sys.excepthook = _distributed_excepthook + + if _is_barrier_after_init() == 1: + # barrier at the end to ensure that once we return from this method, all + # process groups including global variables (if any) are updated + # correctly on all ranks. + # Update 04/2023: for large-scale runs, this barrier (esp. store-based + # barrier) may be costly and/or unscalable. Also, in a lot of cases, + # these barriers may be unnecessary, as proven by a green CI after + # removal. An environment variable `TORCH_DIST_INIT_BARRIER` has been + # added which enables this barrier only when set to 1. + logger.debug( + "Performing barrier after ProcessGroup initialization since " + "TORCH_DIST_INIT_BARRIER = 1" + ) + if backend == Backend.MPI: + # MPI backend doesn't use store. + barrier() + else: + # Use store based barrier here since barrier() used a bunch of + # default devices and messes up NCCL internal state. + _store_based_barrier(rank, store, group_name, world_size, timeout) + + +def _get_split_source(pg): + split_from = None + if pg.bound_device_id: + split_from = pg._get_backend(pg.bound_device_id) + elif pg is _world.default_pg: + try: + split_from = pg._get_backend(torch.device("cuda")) + except RuntimeError: + # no cuda device associated with this backend + pass + + if not split_from or not split_from.supports_splitting: + return None + + # If necessary, find a backend to split from by peeling process + # group wrappers from our potentially wrapped process group. + while _GLOO_AVAILABLE and isinstance(split_from, _ProcessGroupWrapper): + split_from = split_from.wrapped_pg + + return split_from + + +def _shutdown_backend(pg): + """ + Try to shut down the backend of a process group. + Currently, only ProcessGroupNCCL backend is supported. + No op for other backends. + """ + backend = None + try: + backend = pg._get_backend(torch.device("cuda")) + except RuntimeError: + pass + if is_nccl_available() and isinstance(backend, ProcessGroupNCCL): + # explictly call shutdown to ensure that NCCL resources are released + backend._shutdown() + + +def _new_process_group_helper( + group_size, + group_rank, + global_ranks_in_group, + backend, + store, + group_name, + pg_options=None, + timeout=None, + pg_tag=None, + device_id=None, + group_desc=None, +): + """ + Create a new distributed process group. + + This function must be called by ALL processes in the global group, even if + the calling process is not part of the newly created group. In that case, + this function returns GroupMember.NON_GROUP_MEMBER. + + This function is called with ``global_ranks_in_group == []`` for the default group. 
+ """ + global _world + + if group_name in _world.pg_names.values(): + raise ValueError( + "The specified group name has already been " + "created, please use a different group name" + ) + + if device_id is not None and (device_id.index is None or device_id.type != "cuda"): + raise ValueError( + "init_process_group device_id parameter must be a cuda device with an " + "id, e.g. cuda:0, not just cuda or cpu" + ) + + # Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value + _check_valid_timeout(timeout) + + if pg_tag not in [None, ""]: + # creating with the same tag and rank set results in the same underlying PG + existing_group = _find_pg_by_ranks_and_tag(pg_tag, global_ranks_in_group) + if existing_group: + _, prefix_store = _world.pg_map[existing_group] + return existing_group, prefix_store + + group_desc = "undefined" if group_desc is None else group_desc + + # The list of group ranks is empty if we're creating the default group. + is_default_group = len(global_ranks_in_group) == 0 + + # nccl and potentially other backends allow creation of + # communicators based on pre-existing ones, which can save + # initialization time. Due to lazy initialization of + # communicators in some backends, we have to be careful and only + # split when we *know* the backends already are connected _on all + # ranks_. We can only know this if the group we are making is the + # entire world or if we have bound a device id to the world (which + # causes early connection initialization). + if is_initialized() and ( + len(global_ranks_in_group) == _get_default_group().size() + or _get_default_group().bound_device_id + ): + split_from = _get_split_source(_get_default_group()) + else: + split_from = None + + # If this is a subgroup (which means group_ranks is specified), + # we check if the current process is a member of the new group. + if not is_default_group: + global_rank = _get_default_group().rank() + if global_rank not in global_ranks_in_group: + # If we are using `ncclCommSplit` (or similar split from + # other APIs) to create the communicator, we will need to + # call `ncclCommSplit` on *all* ranks in this new group's + # parent group, even those not in the new group. This is + # a requirement of the NCCL API as otherwise we would get + # out of sync. + if split_from: + split_from.perform_nocolor_split(_get_default_group().bound_device_id) + return GroupMember.NON_GROUP_MEMBER, None + + prefix_store = PrefixStore(f"{group_name}/", store) + base_pg_options = ProcessGroup.Options(backend=str(backend)) + base_pg_options._timeout = timeout + pg: ProcessGroup = ProcessGroup( + prefix_store, group_rank, group_size, base_pg_options + ) + if device_id: + pg.bound_device_id = device_id + backend_config = BackendConfig(backend) + backend_class: torch._C._distributed_c10d.Backend + for device, backend_str in backend_config.get_device_backend_map().items(): + # Use the group name as prefix in the default store, such that + # a single store can be reused by multiple groups. + backend_prefix_store = PrefixStore(f"{device}/", prefix_store) + + if backend_str == Backend.MPI: + if not is_mpi_available(): + raise RuntimeError( + "Distributed package doesn't have MPI built in." + " MPI is only included if you build PyTorch from" + " source on a host that has MPI installed." 
+ ) + backend_class = ProcessGroupMPI.create(global_ranks_in_group) + backend_type = ProcessGroup.BackendType.MPI + if not backend_class: + return GroupMember.NON_GROUP_MEMBER, None + # create new process group with accurate rank and size + if pg.rank() == -1 and pg.size() == -1: + pg = ProcessGroup( + backend_prefix_store, + backend_class.rank(), + backend_class.size(), + base_pg_options, + ) + elif backend_str == Backend.GLOO: + # TODO: remove this check after lazy initialization is supported + # if pg_options is not None: + # raise RuntimeError("GLOO options not supported") + backend_class = ProcessGroupGloo( + backend_prefix_store, group_rank, group_size, timeout=timeout + ) + backend_type = ProcessGroup.BackendType.GLOO + elif backend_str == Backend.NCCL: + if not is_nccl_available(): + raise RuntimeError("Distributed package doesn't have NCCL built in") + if pg_options is not None: + assert isinstance( + pg_options, ProcessGroupNCCL.Options + ), "Expected pg_options argument to be of type ProcessGroupNCCL.Options" + if pg_options._timeout != timeout: + warnings.warn( + "pg_options._timeout was specified, " + "but timeout kwarg has a default value that will always override it. " + ) + else: + # default pg_options for NCCL + pg_options = ProcessGroupNCCL.Options() + pg_options.is_high_priority_stream = False + pg_options._timeout = timeout + + if split_from: + pg_options.split_from = split_from + pg_options.split_color = _process_group_color(global_ranks_in_group) + pg_options.global_ranks_in_group = global_ranks_in_group + pg_options.group_name = group_name + backend_class = ProcessGroupNCCL( + backend_prefix_store, group_rank, group_size, pg_options + ) + backend_type = ProcessGroup.BackendType.NCCL + elif backend_str == Backend.UCC and is_ucc_available(): + # TODO: once UCC plugin is fully deprecated, remove + # is_ucc_available() from above elif-condition and raise + # RuntimeError if is_ucc_available() returns false. + + backend_class = ProcessGroupUCC( + backend_prefix_store, group_rank, group_size, timeout=timeout + ) + backend_type = ProcessGroup.BackendType.UCC + else: + assert ( + backend_str.upper() in Backend._plugins + ), f"Unknown c10d backend type {backend_str.upper()}" + + backend_plugin = Backend._plugins[backend_str.upper()] + creator_fn = backend_plugin.creator_fn + extended_api = backend_plugin.extended_api + backend_type = ProcessGroup.BackendType.CUSTOM + + if not extended_api: + backend_class = creator_fn( + backend_prefix_store, group_rank, group_size, timeout + ) + else: + dist_backend_opts = _DistributedBackendOptions() + dist_backend_opts.store = backend_prefix_store + dist_backend_opts.group_rank = group_rank + dist_backend_opts.group_size = group_size + dist_backend_opts.timeout = timeout + dist_backend_opts.group_id = group_name + dist_backend_opts.global_ranks_in_group = global_ranks_in_group + + backend_class = creator_fn(dist_backend_opts, pg_options) + + # Set sequence numbers for gloo and nccl backends. 
+ if backend_str == Backend.GLOO: + assert isinstance(backend_class, ProcessGroupGloo) + backend_class._set_sequence_number_for_group() + elif backend_str == Backend.NCCL: + assert isinstance(backend_class, ProcessGroupNCCL) + backend_class._set_sequence_number_for_group() + + # If the type is a subclass of ProcessGroup then return this process group immediately + # TODO: This defaults to the old behavior for PythonProcessGroups which overwrites the + # ProcessGroup instance + if issubclass(type(backend_class), ProcessGroup): + pg = backend_class # type: ignore[assignment] + break + + # Process group wrapper initialization for supported PGs when TORCH_DISTRIBUTED_DEBUG is set + if ( + backend_str in [Backend.GLOO, Backend.NCCL, Backend.UCC] + or backend_str.upper() in Backend._plugins + ): + # In debug mode and if GLOO is available, wrap in a wrapper PG that + # enables enhanced collective checking for debuggability. + if get_debug_level() == DebugLevel.DETAIL: + if not _GLOO_AVAILABLE: + logger.info( + """TORCH_DISTRIBUTED_DEBUG was set to DETAIL, but + GLOO is not available. Build with Gloo to + create a wrapper process group in debug mode + to aid collective desynchronization debugging.""" + ) + else: + backend_class = _create_process_group_wrapper( + wrapped_pg=backend_class, + store_prefix=group_name, + store=backend_prefix_store, + rank=group_rank, + world_size=group_size, + timeout=timeout, + ) + + # register only a single backend when all get_device_backend_map values are the same + if len(set(backend_config.get_device_backend_map().values())) == 1: + for device in backend_config.get_device_backend_map().keys(): + pg._register_backend(torch.device(device), backend_type, backend_class) + + # break out of outer loop to not create any more backends + break + + pg._register_backend(torch.device(device), backend_type, backend_class) + + # set group_name and group_dsec to backend + assert group_name is not None + assert group_desc is not None + pg._set_group_name(group_name) + pg._set_group_desc(group_desc) + + if device_id and pg._get_backend(device_id).supports_splitting: + eager_backend = pg._get_backend(device_id) + eager_backend.eager_connect_single_device(device_id) + + # update global state + _world.pg_map[pg] = (backend, prefix_store) + _world.pg_names[pg] = group_name + _register_process_group(group_name, pg) + + _world.pg_backend_config[pg] = str(backend_config) + # "" is the default tag for user PGs + if pg_tag in [None, ""]: + pg_tag = f"ptd:{group_name}" + _world.tags_to_pg.setdefault("", []).append(pg) + else: + pg_tag = f"user:{pg_tag}" + + _world.tags_to_pg.setdefault(pg_tag, []).append(pg) + _world.pg_to_tag[pg] = pg_tag + return pg, prefix_store + + +def destroy_process_group(group: Optional[ProcessGroup] = None): + """ + Destroy a given process group, and deinitialize the distributed package. + + Args: + group (ProcessGroup, optional): The process group to be destroyed, if + group.WORLD is given, all process + groups including the default one will + be destroyed. + """ + global _world + + if group == GroupMember.NON_GROUP_MEMBER: + return + + if group is None: + pg = GroupMember.WORLD + else: + pg = group + + assert pg is not None + if _world.pg_map.get(pg, None) is None: + raise ValueError("Invalid process group specified") + + # When users register Python onCompletion hooks, those hooks will run on a + # different thread than the main thread. Today, the ProcessGroup dtor does + # wait for that thread. 
However, the dtor might finish after the Python + # Interpreter exits. After that grabbing the GIL for the Python hook will crash. + # We can either revive the interpreter when running hooks or keep the main one + # alive until all works and hooks are done. The current implementation does the + # latter. Therefore, we explicitly call _wait_for_pending_works() here to wait + # for the pending hooks to finish. + if pg.name().lower() == "nccl" and pg._has_hooks(): + pg._wait_for_pending_works() + + if group is None or group == GroupMember.WORLD: + # shutdown all backends in the order of pg names. shutting down in order because + # ncclCommAbort() was a 'collective' call in some versions of NCCL. + for pg_to_shutdown in sorted( + _world.pg_names, key=lambda x: _world.pg_names[x], reverse=True + ): + _shutdown_backend(pg_to_shutdown) + + _update_default_pg(None) + _world.pg_map.clear() + _world.pg_names.clear() + _world.pg_group_ranks.clear() + _world.pg_backend_config.clear() + _world.pg_to_tag.clear() + _world.tags_to_pg.clear() + _world.pg_coalesce_state.clear() + _world.pg_default_device.clear() + _unregister_all_process_groups() + + # when process group doesn't have an explicit name (only WORLD (default) + # process group can have an explicit name), we use global _world.group_count + # to generate the name. We need to reset the counter on destruction to + # allow consistent value to be generated when we re-create process + # groups after some trainers recover from failure + # + # We only reset this when WORLD is being destroyed because if this + # process group is in good state, we aren't dealing with failures. + _world.group_count = 0 + else: + _shutdown_backend(pg) + del _world.pg_map[pg] + del _world.pg_names[pg] + del _world.pg_group_ranks[pg] + del _world.pg_backend_config[pg] + if pg in _world.pg_default_device: + del _world.pg_default_device[pg] + if pg in _world.pg_coalesce_state.keys(): + warnings.warn( + "Some coalesced collectives haven't been launched when " + "ProcessGroup is destroyed. They will be cleaned." + ) + del _world.pg_coalesce_state[pg] + + tag = _world.pg_to_tag.get(pg) + del _world.pg_to_tag[pg] + if tag is not None: + try: + _world.tags_to_pg[tag].remove(pg) + if tag.startswith("ptd:"): + _world.tags_to_pg[""].remove(pg) + except Exception: + pass + _unregister_process_group(pg.group_name) + + +def get_rank(group: Optional[ProcessGroup] = None) -> int: + """ + Return the rank of the current process in the provided ``group``, default otherwise. + + Rank is a unique identifier assigned to each process within a distributed + process group. They are always consecutive integers ranging from 0 to + ``world_size``. + + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + Returns: + The rank of the process group + -1, if not part of the group + + """ + if _rank_not_in_group(group): + return -1 + + default_pg = _get_default_group() + if group is None or group is GroupMember.WORLD: + return default_pg.rank() + + return get_group_rank(group, default_pg.rank()) + + +def get_world_size(group: Optional[ProcessGroup] = None) -> int: + """ + Return the number of processes in the current process group. + + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. 
+ + Returns: + The world size of the process group + -1, if not part of the group + + """ + if _rank_not_in_group(group): + return -1 + + return _get_group_size(group) + + +def isend( + tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0 +) -> Optional[Work]: + """ + Send a tensor asynchronously. + + .. warning:: + Modifying ``tensor`` before the request completes causes undefined + behavior. + + .. warning:: + ``tag`` is not supported with the NCCL backend. + + Args: + tensor (Tensor): Tensor to send. + dst (int): Destination rank on global process group (regardless of ``group`` argument) + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + tag (int, optional): Tag to match send with remote recv + + Returns: + A distributed request object. + None, if not part of the group + + """ + _check_single_tensor(tensor, "tensor") + if _rank_not_in_group(group): + _warn_not_in_group("isend") + return None + + if tensor.is_complex(): + tensor = torch.view_as_real(tensor) + + if group is None or group is GroupMember.WORLD: + pg = _get_default_group() + else: + pg = group + dst = get_group_rank(pg, dst) + + return pg.send([tensor], dst, tag) + + +def irecv( + tensor: torch.Tensor, + src: Optional[int] = None, + group: Optional[ProcessGroup] = None, + tag: int = 0, +) -> Optional[Work]: + """ + Receives a tensor asynchronously. + + .. warning:: + ``tag`` is not supported with the NCCL backend. + + Args: + tensor (Tensor): Tensor to fill with received data. + src (int, optional): Source rank on global process group (regardless of ``group`` argument). + Will receive from any process if unspecified. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + tag (int, optional): Tag to match recv with remote send + + Returns: + A distributed request object. + None, if not part of the group + + """ + _check_single_tensor(tensor, "tensor") + if _rank_not_in_group(group): + _warn_not_in_group("irecv") + return None + + if tensor.is_complex(): + tensor = torch.view_as_real(tensor) + + if group is None or group is GroupMember.WORLD: + pg = _get_default_group() + else: + pg = group + + if src is None: + return pg.recv_anysource([tensor], tag) + else: + if pg is GroupMember.WORLD: + return pg.recv([tensor], src, tag) + else: + group_src_rank = get_group_rank(pg, src) + return pg.recv([tensor], group_src_rank, tag) + + +@_exception_logger +def send( + tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0 +) -> None: + """ + Send a tensor synchronously. + + .. warning:: + ``tag`` is not supported with the NCCL backend. + + Args: + tensor (Tensor): Tensor to send. + dst (int): Destination rank on global process group (regardless of ``group`` argument). + Destination rank should not be the same as the rank of the current process. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + tag (int, optional): Tag to match send with remote recv + + """ + if get_rank() == dst: + raise ValueError( + "Invalid destination rank: destination rank should not be the same as " + "the rank of the current process." 
+ ) + + _check_single_tensor(tensor, "tensor") + if _rank_not_in_group(group): + _warn_not_in_group("send") + return None + + if tensor.is_complex(): + tensor = torch.view_as_real(tensor) + + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() + default_pg.send([tensor], dst, tag).wait() + else: + group_dst_rank = get_group_rank(group, dst) + group.send([tensor], group_dst_rank, tag).wait() + + +@_exception_logger +def recv( + tensor: torch.Tensor, + src: Optional[int] = None, + group: Optional[ProcessGroup] = None, + tag: int = 0, +) -> int: + """ + Receives a tensor synchronously. + + .. warning:: + ``tag`` is not supported with the NCCL backend. + + Args: + tensor (Tensor): Tensor to fill with received data. + src (int, optional): Source rank on global process group (regardless of ``group`` argument). + Will receive from any process if unspecified. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + tag (int, optional): Tag to match recv with remote send + + Returns: + Sender rank + -1, if not part of the group + + """ + _check_single_tensor(tensor, "tensor") + if _rank_not_in_group(group): + _warn_not_in_group("recv") + return -1 + + if tensor.is_complex(): + tensor = torch.view_as_real(tensor) + + pg = group or _get_default_group() + + if src is None: + work = pg.recv_anysource([tensor], tag) + work.wait() + src_rank = work._source_rank() + if group is None or group is GroupMember.WORLD: + return src_rank + else: + return get_global_rank(pg, src_rank) + else: + if group is None or group is GroupMember.WORLD: + pg.recv([tensor], src, tag).wait() + else: + group_src_rank = get_group_rank(pg, src) + pg.recv([tensor], group_src_rank, tag).wait() + return src + + +class _IllegalWork(Work): + def __getattribute__(self, name): + if name in [ + "is_success", + "exception", + "wait", + "source_rank", + "_source_rank", + "result", + "synchronize", + ]: + raise ValueError(f"Illegal to call {name} on IllegalWork object") + + +class _CoalescingManager: + def __init__(self) -> None: + self.works: List[Work] = [] + + def append(self, work: Work): + if work: + self.works.append(work) + + def wait(self): + for work in self.works: + work.wait() + + +@contextlib.contextmanager +def _coalescing_manager( + group: Optional[ProcessGroup] = None, + device: Optional[torch.device] = None, + async_ops: Optional[bool] = False, +): + """ + Context manager used to coalesce collectives or P2P operations when possible. + + Args: + group (`ProcessGroup`, optional): The process group to work on. If None, + the default process group will be used. + device (`torch.device`, optional): Default is None, set to a device if + there isn't a `**_coalesced` implementation by the backend. + async_ops (`bool`, optional): whether the coalesced ops are async ops. + + Examples: + >>> # xdoctest: +SKIP("no rank") + >>> # Synchronous ops + >>> with _coalescing_manager(): + >>> for i in range(num_colls): + >>> dist.all_reduce(tensors[i]) + >>> # Asynchronous ops + >>> with _coalescing_manager(async_ops=True) as cm: + >>> for i in range(num_colls): + >>> dist.all_reduce(tensors[i]) + >>> cm.wait() + + .. warning:: + :func:`_coalescing_manager` currently do not support coalescing + all-reduces with different reduce operators, e.g. `ReduceOp.SUM` mixed + with `ReduceOp.PRODUCT`. 
+ """ + group = group or _get_default_group() + op_list = _world.pg_coalesce_state.setdefault(group, []) + if op_list: + raise ValueError( + "ProcessGroup has non-empty op list at the start of coalescing" + ) + if device: + group._start_coalescing(device) + cm = _CoalescingManager() + yield cm + op_list = _world.pg_coalesce_state.pop(group) + if op_list: + # Collectives supporting "Fast Path" coalescing are captured. + # See implementation in corresponding collective APIs. + # Currently supported: + # - coalesced `all_reduce` + # - coalesced `all_gather_into_tensor` + # - coalesced `reduce_scatter_tensor` + op0 = op_list[0].op + if op0 == all_reduce: + tensors = [] + for op in op_list: + tensors.append(op.tensor) + all_reduce_opts = AllreduceCoalescedOptions() + all_reduce_opts.reduceOp = not_none(op_list[0].redop) + work = group.allreduce_coalesced(tensors, all_reduce_opts) + elif op0 == all_gather_into_tensor: + inputs = [] + outputs = [] + for op in op_list: + inputs.append(op.tensor) + outputs.append(not_none(op.dst_tensor)) + work = group.allgather_into_tensor_coalesced(outputs, inputs) + elif op0 == reduce_scatter_tensor: + inputs = [] + outputs = [] + for op in op_list: + inputs.append(op.tensor) + outputs.append(not_none(op.dst_tensor)) + reduce_opts = ReduceScatterOptions() + reduce_opts.reduceOp = not_none(op_list[0].redop) + work = group.reduce_scatter_tensor_coalesced(outputs, inputs, reduce_opts) + else: + raise AssertionError( + f"Coalescing manager does not support fast-path coalescing of {op0}, " + f"yet {op0} is still recorded in op list. This is an internal error of c10d." + ) + + if device: + # Old style of letting each coll inside the context manager to call into C++ counterpart via python binding + work = group._end_coalescing(device) + + if async_ops: + cm.append(work) # type: ignore[possibly-undefined] + else: + work.wait() # type: ignore[possibly-undefined] + + +def batch_isend_irecv(p2p_op_list): + """ + Send or Receive a batch of tensors asynchronously and return a list of requests. + + Process each of the operations in ``p2p_op_list`` and return the corresponding + requests. NCCL, Gloo, and UCC backend are currently supported. + + Args: + p2p_op_list: A list of point-to-point operations(type of each operator is + ``torch.distributed.P2POp``). The order of the isend/irecv in the list + matters and it needs to match with corresponding isend/irecv on the + remote end. + + Returns: + A list of distributed request objects returned by calling the corresponding + op in the op_list. + + Examples: + >>> # xdoctest: +SKIP("no rank") + >>> send_tensor = torch.arange(2, dtype=torch.float32) + 2 * rank + >>> recv_tensor = torch.randn(2, dtype=torch.float32) + >>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1)%world_size) + >>> recv_op = dist.P2POp(dist.irecv, recv_tensor, (rank - 1 + world_size)%world_size) + >>> reqs = batch_isend_irecv([send_op, recv_op]) + >>> for req in reqs: + >>> req.wait() + >>> recv_tensor + tensor([2, 3]) # Rank 0 + tensor([0, 1]) # Rank 1 + + .. note:: Note that when this API is used with the NCCL PG backend, users must set + the current GPU device with `torch.cuda.set_device`, otherwise it will + lead to unexpected hang issues. + + In addition, if this API is the first collective call in the ``group`` + passed to ``dist.P2POp``, all ranks of the ``group`` must participate in + this API call; otherwise, the behavior is undefined. 
If this API call is + not the first collective call in the ``group``, batched P2P operations + involving only a subset of ranks of the ``group`` are allowed. + """ + _check_p2p_op_list(p2p_op_list) + group = p2p_op_list[0].group + device = p2p_op_list[0].tensor.device + if device.type == "cuda": + # NCCL style coalescing + with _coalescing_manager(group, device, async_ops=True) as cm: + for p2p_op in p2p_op_list: + p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag) + return cm.works + else: + # Backward support for Gloo + reqs = [] + for p2p_op in p2p_op_list: + work = p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag) + if work: + reqs.append(work) + return reqs + + +@_exception_logger +def broadcast(tensor, src, group=None, async_op=False): + """ + Broadcasts the tensor to the whole group. + + ``tensor`` must have the same number of elements in all processes + participating in the collective. + + Args: + tensor (Tensor): Data to be sent if ``src`` is the rank of current + process, and tensor to be used to save received data otherwise. + src (int): Source rank on global process group (regardless of ``group`` argument). + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + """ + _check_single_tensor(tensor, "tensor") + if _rank_not_in_group(group): + _warn_not_in_group("broadcast") + return + + opts = BroadcastOptions() + opts.rootRank = src + opts.rootTensor = 0 + opts.asyncOp = async_op + + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() + work = default_pg.broadcast([tensor], opts) + else: + group_src_rank = get_group_rank(group, src) + opts.rootRank = group_src_rank + work = group.broadcast([tensor], opts) + if async_op: + return work + else: + work.wait() + + +@_exception_logger +def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): + """ + Reduces the tensor data across all machines in a way that all get the final result. + + After the call ``tensor`` is going to be bitwise identical in all processes. + + Complex tensors are supported. + + Args: + tensor (Tensor): Input and output of the collective. The function + operates in-place. + op (optional): One of the values from + ``torch.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + Examples: + >>> # xdoctest: +SKIP("no rank") + >>> # All tensors below are of torch.int64 type. + >>> # We have 2 process groups, 2 ranks. + >>> device = torch.device(f'cuda:{rank}') + >>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank + >>> tensor + tensor([1, 2], device='cuda:0') # Rank 0 + tensor([3, 4], device='cuda:1') # Rank 1 + >>> dist.all_reduce(tensor, op=ReduceOp.SUM) + >>> tensor + tensor([4, 6], device='cuda:0') # Rank 0 + tensor([4, 6], device='cuda:1') # Rank 1 + + >>> # All tensors below are of torch.cfloat type. + >>> # We have 2 process groups, 2 ranks. 
+ >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat, device=device) + 2 * rank * (1+1j) + >>> tensor + tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0 + tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1 + >>> dist.all_reduce(tensor, op=ReduceOp.SUM) + >>> tensor + tensor([4.+4.j, 6.+6.j], device='cuda:0') # Rank 0 + tensor([4.+4.j, 6.+6.j], device='cuda:1') # Rank 1 + + """ + _check_single_tensor(tensor, "tensor") + if _rank_not_in_group(group): + _warn_not_in_group("all_reduce") + return + + if tensor.is_complex(): + if not supports_complex(op): + raise ValueError(f"all_reduce does not support {op} on complex tensors") + tensor = torch.view_as_real(tensor) + + opts = AllreduceOptions() + opts.reduceOp = op + if group is None: + group = _get_default_group() + + if group in _world.pg_coalesce_state.keys(): + # We are in coalescing context, do not issue single operation, just append a collective representation + coll = _CollOp(all_reduce, tensor, None, op, None) + _world.pg_coalesce_state[group].append(coll) + if async_op: + return _IllegalWork() + else: + return None + + work = group.allreduce([tensor], opts) + + if async_op: + return work + else: + work.wait() + + +@_exception_logger +@deprecated( + "`torch.distributed.all_reduce_coalesced` will be deprecated. If you must " + "use it, please revisit our documentation later at " + "https://pytorch.org/docs/main/distributed.html#collective-functions", + category=FutureWarning, +) +def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False): + """ + WARNING: at this time individual shape checking is not implemented across nodes. + + For example, if the rank 0 node passes [torch.rand(4), torch.rand(2)] and the + rank 1 node passes [torch.rand(2), torch.rand(2), torch.rand(2)], the allreduce + operation will proceed without complaint and return erroneous outputs. This lack + of shape checking results in significant performance improvements but users of this + function should take extra care to ensure that each node passes in tensors whose + shapes match across nodes. + + Reduces each tensor in tensors (residing on the same device) across all machines + in such a way that all get the final result. + + After the call each tensor in tensors is going to bitwise identical + in all processes. + + Complex tensors are supported. + + Args: + tensors (Union[List[Tensor], Tensor]): Input and output of the collective. + The function operates in-place. + op (Optional[ReduceOp]): One of the values from + ``torch.distributed.ReduceOp`` enum. Specifies an operation used for + element-wise reductions. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (Optional[bool]): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group. 
+ + """ + if isinstance(tensors, torch.Tensor): + tensors = [tensors] + _check_tensor_list(tensors, "tensor") + _ensure_all_tensors_same_dtype(tensors) + if _rank_not_in_group(group): + _warn_not_in_group("all_reduce_coalesced") + return + + if any(t.is_complex() for t in tensors) and not supports_complex(op): + raise ValueError(f"all_reduce does not support {op} on complex tensors") + + tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors] + + opts = AllreduceCoalescedOptions() + opts.reduceOp = op + group = group or _get_default_group() + work = group.allreduce_coalesced(tensors, opts) + + if async_op: + return work.get_future() + else: + work.wait() + + +@_exception_logger +def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False): + """ + Reduces the tensor data across all machines. + + Only the process with rank ``dst`` is going to receive the final result. + + Args: + tensor (Tensor): Input and output of the collective. The function + operates in-place. + dst (int): Destination rank on global process group (regardless of ``group`` argument) + op (optional): One of the values from + ``torch.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + """ + _check_single_tensor(tensor, "tensor") + if _rank_not_in_group(group): + _warn_not_in_group("reduce") + return + + opts = ReduceOptions() + opts.reduceOp = op + opts.rootRank = dst + + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() + work = default_pg.reduce([tensor], opts) + else: + group_dst_rank = get_group_rank(group, dst) + opts.rootRank = group_dst_rank + work = group.reduce([tensor], opts) + + if async_op: + return work + else: + work.wait() + + +def _object_to_tensor(obj, device, group): + f = io.BytesIO() + _pickler(f).dump(obj) + byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) # type: ignore[attr-defined] + # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype. + # Otherwise, it will casue 100X slowdown. + # See: https://github.com/pytorch/pytorch/issues/65696 + byte_tensor = torch.ByteTensor(byte_storage).to(device) + if get_debug_level() == DebugLevel.DETAIL and is_nccl_available(): + backend = get_backend(group) + if backend == Backend.NCCL: + hash = torch._C._distributed_c10d._hash_tensors([byte_tensor]) + logger.warning( + "_object_to_tensor size: %s hash value: %s", byte_tensor.numel(), hash + ) + local_size = torch.LongTensor([byte_tensor.numel()]).to(device) + return byte_tensor, local_size + + +def _tensor_to_object(tensor, tensor_size, group): + if get_debug_level() == DebugLevel.DETAIL and is_nccl_available(): + backend = get_backend(group) + if backend == Backend.NCCL: + hash = torch._C._distributed_c10d._hash_tensors([tensor]) + logger.warning( + "_tensor_to_object size: %s hash value: %s", tensor.numel(), hash + ) + tensor = tensor.cpu() + buf = tensor.numpy().tobytes()[:tensor_size] + return _unpickler(io.BytesIO(buf)).load() + + +@_exception_logger +def all_gather_object(object_list, obj, group=None): + """ + Gathers picklable objects from the whole group into a list. + + Similar to :func:`all_gather`, but Python objects can be passed in. 
+    Note that the object must be picklable in order to be gathered.
+
+    Args:
+        object_list (list[Any]): Output list. It should be correctly sized as the
+            size of the group for this collective and will contain the output.
+        obj (Any): Picklable Python object to be broadcast from current process.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Default is ``None``.
+
+    Returns:
+        None. If the calling rank is part of this group, the output of the
+        collective will be populated into the input ``object_list``. If the
+        calling rank is not part of the group, the passed in ``object_list`` will
+        be unmodified.
+
+    .. note:: This API differs slightly from the :func:`all_gather`
+        collective since it does not provide an ``async_op`` handle and thus
+        will be a blocking call.
+
+    .. note:: For NCCL-based process groups, internal tensor representations
+        of objects must be moved to the GPU device before communication takes
+        place. In this case, the device used is given by
+        ``torch.cuda.current_device()`` and it is the user's responsibility to
+        ensure that this is set so that each rank has an individual GPU, via
+        ``torch.cuda.set_device()``.
+
+    .. warning::
+        :func:`all_gather_object` uses the ``pickle`` module implicitly, which is
+        known to be insecure. It is possible to construct malicious pickle data
+        which will execute arbitrary code during unpickling. Only call this
+        function with data you trust.
+
+    .. warning::
+        Calling :func:`all_gather_object` with GPU tensors is not well supported
+        and is inefficient as it incurs GPU -> CPU transfer since tensors would be
+        pickled. Please consider using :func:`all_gather` instead.
+
+    Example::
+        >>> # xdoctest: +SKIP("need process group init")
+        >>> # Note: Process group initialization omitted on each rank.
+        >>> import torch.distributed as dist
+        >>> # Assumes world_size of 3.
+        >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
+        >>> output = [None for _ in gather_objects]
+        >>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
+        >>> output
+        ['foo', 12, {1: 2}]
+    """
+    if _rank_not_in_group(group):
+        _warn_not_in_group("all_gather_object")
+        return
+
+    current_device = _get_pg_default_device(group)
+    input_tensor, local_size = _object_to_tensor(obj, current_device, group)
+
+    # Gather all local sizes. This is so that we can find the max size, and index
+    # until the correct size when deserializing the tensors.
+    group_size = get_world_size(group=group)
+    object_sizes_tensor = torch.zeros(
+        group_size, dtype=torch.long, device=current_device
+    )
+    object_size_list = [
+        object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
+    ]
+    # Allgather tensor sizes
+    all_gather(object_size_list, local_size, group=group)
+    max_object_size = int(max(object_size_list).item())  # type: ignore[type-var]
+    # Resize tensor to max size across all ranks.
+    input_tensor.resize_(max_object_size)
+    coalesced_output_tensor = torch.empty(
+        max_object_size * group_size, dtype=torch.uint8, device=current_device
+    )
+    # Output tensors are nonoverlapping views of coalesced_output_tensor
+    output_tensors = [
+        coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
+        for i in range(group_size)
+    ]
+    all_gather(output_tensors, input_tensor, group=group)
+    # Deserialize outputs back to object.
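+    # Each rank's payload was padded to max_object_size above, so each output is
+    # truncated back to its true serialized size (from object_size_list) before unpickling.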
+ for i, tensor in enumerate(output_tensors): + tensor = tensor.type(torch.uint8) + tensor_size = object_size_list[i] + object_list[i] = _tensor_to_object(tensor, tensor_size, group) + + +@_exception_logger +def gather_object(obj, object_gather_list=None, dst=0, group=None): + """ + Gathers picklable objects from the whole group in a single process. + + Similar to :func:`gather`, but Python objects can be passed in. Note that the + object must be picklable in order to be gathered. + + Args: + obj (Any): Input object. Must be picklable. + object_gather_list (list[Any]): Output list. On the ``dst`` rank, it + should be correctly sized as the size of the group for this + collective and will contain the output. Must be ``None`` on non-dst + ranks. (default is ``None``) + dst (int, optional): Destination rank on global process group (regardless of ``group`` argument). (default is 0) + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + + Returns: + None. On the ``dst`` rank, ``object_gather_list`` will contain the + output of the collective. + + .. note:: Note that this API differs slightly from the gather collective + since it does not provide an async_op handle and thus will be a blocking + call. + + .. note:: For NCCL-based process groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsibility to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + + .. warning:: + :func:`gather_object` uses ``pickle`` module implicitly, which is + known to be insecure. It is possible to construct malicious pickle data + which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`gather_object` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`gather` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> # Assumes world_size of 3. + >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object + >>> output = [None for _ in gather_objects] + >>> dist.gather_object( + ... gather_objects[dist.get_rank()], + ... output if dist.get_rank() == 0 else None, + ... dst=0 + ... ) + >>> # On rank 0 + >>> output + ['foo', 12, {1: 2}] + """ + if _rank_not_in_group(group): + _warn_not_in_group("gather_object") + return + + # Ensure object_gather_list is specified appropriately. + my_rank = get_rank() + _validate_output_list_for_rank(my_rank, dst, object_gather_list) + current_device = _get_pg_default_device(group) + input_tensor, local_size = _object_to_tensor(obj, current_device, group) + + # Gather all local sizes. This is so that we can find the max size, and index + # until the correct size when deserializing the tensors. + group_size = get_world_size(group=group) + object_sizes_tensor = torch.zeros( + group_size, dtype=torch.long, device=current_device + ) + object_size_list = [ + object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) + ] + # Allgather tensor sizes. An all-gather is needed here despite this being a + # gather, since each rank needs to broadcast a tensor of the same (maximal) + # size.
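+ # Illustration (hypothetical sizes, not from the surrounding code): if three
+ # ranks serialize to 5, 9 and 7 bytes, every rank must learn the maximum (9)
+ # so that the later gather() is called with identically sized 9-byte tensors
+ # on all ranks; only the dst rank then trims each result back to its true size.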
+ all_gather(object_size_list, local_size, group=group) + max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] + # Resize tensor to max size across all ranks. + input_tensor.resize_(max_object_size) + # Avoid populating output tensors if the result won't be gathered on this rank. + if my_rank == dst: + coalesced_output_tensor = torch.empty( + max_object_size * group_size, dtype=torch.uint8, device=current_device + ) + # Output tensors are nonoverlapping views of coalesced_output_tensor + output_tensors = [ + coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] + for i in range(group_size) + ] + # All ranks call gather with equal-sized tensors. + gather( + input_tensor, + gather_list=output_tensors if my_rank == dst else None, # type: ignore[possibly-undefined] + dst=dst, + group=group, + ) + if my_rank != dst: + return + for i, tensor in enumerate(output_tensors): + tensor = tensor.type(torch.uint8) + tensor_size = object_size_list[i] + object_gather_list[i] = _tensor_to_object(tensor, tensor_size, group) + + +@_exception_logger +def send_object_list(object_list, dst, group=None, device=None): + """ + Sends picklable objects in ``object_list`` synchronously. + + Similar to :func:`send`, but Python objects can be passed in. + Note that all objects in ``object_list`` must be picklable in order to be + sent. + + Args: + object_list (List[Any]): List of input objects to be sent. + Each object must be picklable. Receiver must provide lists of equal sizes. + dst (int): Destination rank to send ``object_list`` to. + Destination rank is based on global process group (regardless of ``group`` argument) + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + device (``torch.device``, optional): If not None, the objects are + serialized and converted to tensors which are moved to the + ``device`` before sending. Default is ``None``. + + Returns: + ``None``. + + .. note:: For NCCL-based process groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsibility to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + + .. warning:: + :func:`send_object_list` uses ``pickle`` module implicitly, which + is known to be insecure. It is possible to construct malicious pickle + data which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`send_object_list` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`send` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> # Assumes backend is not NCCL + >>> device = torch.device("cpu") + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 2.
+ >>> objects = ["foo", 12, {1: 2}] # any picklable object + >>> dist.send_object_list(objects, dst=1, device=device) + >>> else: + >>> objects = [None, None, None] + >>> dist.recv_object_list(objects, src=0, device=device) + >>> objects + ['foo', 12, {1: 2}] + """ + if get_rank() == dst: + raise ValueError( + "Invalid destination rank: destination rank should not be the same as " + "the rank of the current process." + ) + + if _rank_not_in_group(group): + _warn_not_in_group("send_object_list") + return + + # Current device selection. + # To preserve backwards compatibility, ``device`` is default to ``None`` + # in which case we run current logic of device selection, i.e. + # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the + # case it is not ``None`` we move the size and object tensors to be + # sent to this device. + current_device = device or _get_pg_default_device(group) + # Serialize object_list elements to tensors on src rank. + tensor_list, size_list = zip( + *[_object_to_tensor(obj, current_device, group) for obj in object_list] + ) + object_sizes_tensor = torch.cat(size_list) + + # Send object sizes + send(object_sizes_tensor, dst=dst, group=group) + + # Concatenate and send serialized object tensors + # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list + # has only one element, we can skip the copy. + if len(tensor_list) == 1: # type: ignore[possibly-undefined] + object_tensor = tensor_list[0] + else: + object_tensor = torch.cat(tensor_list) + + send(object_tensor, dst=dst, group=group) + + +@_exception_logger +def recv_object_list(object_list, src=None, group=None, device=None): + """ + Receives picklable objects in ``object_list`` synchronously. + + Similar to :func:`recv`, but can receive Python objects. + + Args: + object_list (List[Any]): List of objects to receive into. + Must provide a list of sizes equal to the size of the list being sent. + src (int, optional): Source rank from which to recv ``object_list``. + Source rank is based on global process group (regardless of ``group`` argument) + Will receive from any rank if set to None. Default is ``None``. + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + device (``torch.device``, optional): If not None, receives on this device. + Default is ``None``. + + Returns: + Sender rank. -1 if rank is not part of the group. If rank is part of the group, + ``object_list`` will contain the sent objects from ``src`` rank. + + .. note:: For NCCL-based process groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsibility to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + + .. warning:: + :func:`recv_object_list` uses ``pickle`` module implicitly, which + is known to be insecure. It is possible to construct malicious pickle + data which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`recv_object_list` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`recv` instead. 
+ + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> # Assumes backend is not NCCL + >>> device = torch.device("cpu") + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 2. + >>> objects = ["foo", 12, {1: 2}] # any picklable object + >>> dist.send_object_list(objects, dst=1, device=device) + >>> else: + >>> objects = [None, None, None] + >>> dist.recv_object_list(objects, src=0, device=device) + >>> objects + ['foo', 12, {1: 2}] + """ + if _rank_not_in_group(group): + _warn_not_in_group("recv_object_list") + return -1 + + # Current device selection. + # To preserve backwards compatibility, ``device`` is default to ``None`` + # in which case we run current logic of device selection, i.e. + # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the + # case it is not ``None`` we move the size and object tensors to be + # received to this device. + current_device = device or _get_pg_default_device(group) + object_sizes_tensor = torch.empty( + len(object_list), dtype=torch.long, device=current_device + ) + + # Receive object sizes + rank_sizes = recv(object_sizes_tensor, src=src, group=group) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] + dtype=torch.uint8, + device=current_device, + ) + + rank_objects = recv(object_tensor, src=src, group=group) + assert ( + rank_sizes == rank_objects + ), "Mismatch in return ranks for object sizes and objects." + # Deserialize objects using their stored sizes. + offset = 0 + for i, obj_size in enumerate(object_sizes_tensor): + obj_view = object_tensor[offset : offset + obj_size] + obj_view = obj_view.type(torch.uint8) + offset += obj_size + object_list[i] = _tensor_to_object(obj_view, obj_size, group) + return rank_objects + + +@_exception_logger +def broadcast_object_list(object_list, src=0, group=None, device=None): + """ + Broadcasts picklable objects in ``object_list`` to the whole group. + + Similar to :func:`broadcast`, but Python objects can be passed in. + Note that all objects in ``object_list`` must be picklable in order to be + broadcasted. + + Args: + object_list (List[Any]): List of input objects to broadcast. + Each object must be picklable. Only objects on the ``src`` rank will + be broadcast, but each rank must provide lists of equal sizes. + src (int): Source rank from which to broadcast ``object_list``. + Source rank is based on global process group (regardless of ``group`` argument) + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + device (``torch.device``, optional): If not None, the objects are + serialized and converted to tensors which are moved to the + ``device`` before broadcasting. Default is ``None``. + + Returns: + ``None``. If rank is part of the group, ``object_list`` will contain the + broadcasted objects from ``src`` rank. + + .. note:: For NCCL-based process groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsibility to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + + .. 
note:: Note that this API differs slightly from the :func:`broadcast` + collective since it does not provide an ``async_op`` handle and thus + will be a blocking call. + + .. warning:: + :func:`broadcast_object_list` uses ``pickle`` module implicitly, which + is known to be insecure. It is possible to construct malicious pickle + data which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`broadcast_object_list` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`broadcast` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 3. + >>> objects = ["foo", 12, {1: 2}] # any picklable object + >>> else: + >>> objects = [None, None, None] + >>> # Assumes backend is not NCCL + >>> device = torch.device("cpu") + >>> dist.broadcast_object_list(objects, src=0, device=device) + >>> objects + ['foo', 12, {1: 2}] + """ + if _rank_not_in_group(group): + _warn_not_in_group("broadcast_object_list") + return + + # Current device selection. + # To preserve backwards compatibility, ``device`` is default to ``None`` + # in which case we run current logic of device selection, i.e. + # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the + # case it is not ``None`` we move the size and object tensors to be + # broadcasted to this device. + current_device = device or _get_pg_default_device(group) + my_rank = get_rank() + # Serialize object_list elements to tensors on src rank. + if my_rank == src: + tensor_list, size_list = zip( + *[_object_to_tensor(obj, current_device, group) for obj in object_list] + ) + object_sizes_tensor = torch.cat(size_list) + else: + object_sizes_tensor = torch.empty( + len(object_list), dtype=torch.long, device=current_device + ) + + # Broadcast object sizes + broadcast(object_sizes_tensor, src=src, group=group) + + # Concatenate and broadcast serialized object tensors + # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list + # has only one element, we can skip the copy. + if my_rank == src: + if len(tensor_list) == 1: # type: ignore[possibly-undefined] + object_tensor = tensor_list[0] + else: + object_tensor = torch.cat(tensor_list) + else: + object_tensor = torch.empty( # type: ignore[call-overload] + torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] + dtype=torch.uint8, + device=current_device, + ) + + broadcast(object_tensor, src=src, group=group) + # Deserialize objects using their stored sizes. + offset = 0 + if my_rank != src: + for i, obj_size in enumerate(object_sizes_tensor): + obj_view = object_tensor[offset : offset + obj_size] + obj_view = obj_view.type(torch.uint8) + offset += obj_size + object_list[i] = _tensor_to_object(obj_view, obj_size, group) + + +@_exception_logger +def scatter_object_list( + scatter_object_output_list, scatter_object_input_list, src=0, group=None +): + """ + Scatters picklable objects in ``scatter_object_input_list`` to the whole group. + + Similar to :func:`scatter`, but Python objects can be passed in. On + each rank, the scattered object will be stored as the first element of + ``scatter_object_output_list``. Note that all objects in + ``scatter_object_input_list`` must be picklable in order to be scattered. 
+ + Args: + scatter_object_output_list (List[Any]): Non-empty list whose first + element will store the object scattered to this rank. + scatter_object_input_list (List[Any]): List of input objects to scatter. + Each object must be picklable. Only objects on the ``src`` rank will + be scattered, and the argument can be ``None`` for non-src ranks. + src (int): Source rank from which to scatter ``scatter_object_input_list``. + Source rank is based on global process group (regardless of ``group`` argument). + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + + Returns: + ``None``. If rank is part of the group, ``scatter_object_output_list`` + will have its first element set to the scattered object for this rank. + + .. note:: Note that this API differs slightly from the scatter collective + since it does not provide an ``async_op`` handle and thus will be a + blocking call. + + .. warning:: + :func:`scatter_object_list` uses ``pickle`` module implicitly, which + is known to be insecure. It is possible to construct malicious pickle + data which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`scatter_object_list` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`scatter` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 3. + >>> objects = ["foo", 12, {1: 2}] # any picklable object + >>> else: + >>> # Can be any list on non-src ranks, elements are not used. + >>> objects = [None, None, None] + >>> output_list = [None] + >>> dist.scatter_object_list(output_list, objects, src=0) + >>> # Rank i gets objects[i]. For example, on rank 2: + >>> output_list + [{1: 2}] + """ + if _rank_not_in_group(group): + _warn_not_in_group("scatter_object_list") + return + + if ( + not isinstance(scatter_object_output_list, list) + or len(scatter_object_output_list) < 1 + ): + raise ValueError( + "Expected argument scatter_object_output_list to be a list of size at least 1." + ) + + my_rank = get_rank() + pg_device = _get_pg_default_device(group) + if my_rank == src: + tensor_list, tensor_sizes = zip( + *[ + _object_to_tensor(obj, pg_device, group) + for obj in scatter_object_input_list + ] + ) + tensor_list, tensor_sizes = list(tensor_list), list(tensor_sizes) + + # Src rank broadcasts the maximum tensor size. This is because all ranks are + # expected to call into scatter() with equal-sized tensors. 
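+ # Illustration (hypothetical sizes, not from the surrounding code): if the
+ # src rank's objects pickle to 5, 9 and 7 bytes, src pads every tensor to 9
+ # bytes and broadcasts 9, so each rank can allocate a 9-byte output_tensor
+ # for the scatter below; the per-object size scattered afterwards is what
+ # lets _tensor_to_object drop the padding when unpickling.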
+ if my_rank == src: + max_tensor_size = max(tensor_sizes) # type: ignore[possibly-undefined] + for tensor in tensor_list: # type: ignore[possibly-undefined] + tensor.resize_(max_tensor_size) + else: + max_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device) + broadcast(max_tensor_size, src=src, group=group) + + # Scatter actual serialized objects + output_tensor = torch.empty( + max_tensor_size.item(), dtype=torch.uint8, device=pg_device + ) + scatter( + output_tensor, + scatter_list=None if my_rank != src else tensor_list, # type: ignore[possibly-undefined] + src=src, + group=group, + ) + + # Scatter per-object sizes to trim tensors when deserializing back to object + obj_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device) + scatter( + obj_tensor_size, + scatter_list=None if my_rank != src else tensor_sizes, # type: ignore[possibly-undefined] + src=src, + group=group, + ) + + # Deserialize back to object + scatter_object_output_list[0] = _tensor_to_object( + output_tensor, obj_tensor_size, group + ) + + +@_exception_logger +def all_gather(tensor_list, tensor, group=None, async_op=False): + """ + Gathers tensors from the whole group in a list. + + Complex and uneven sized tensors are supported. + + Args: + tensor_list (list[Tensor]): Output list. It should contain + correctly-sized tensors to be used for output of the collective. + Uneven sized tensors are supported. + tensor (Tensor): Tensor to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + Examples: + >>> # xdoctest: +SKIP("need process group init") + >>> # All tensors below are of torch.int64 dtype. + >>> # We have 2 process groups, 2 ranks. + >>> device = torch.device(f'cuda:{rank}') + >>> tensor_list = [torch.zeros(2, dtype=torch.int64, device=device) for _ in range(2)] + >>> tensor_list + [tensor([0, 0], device='cuda:0'), tensor([0, 0], device='cuda:0')] # Rank 0 + [tensor([0, 0], device='cuda:1'), tensor([0, 0], device='cuda:1')] # Rank 1 + >>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank + >>> tensor + tensor([1, 2], device='cuda:0') # Rank 0 + tensor([3, 4], device='cuda:1') # Rank 1 + >>> dist.all_gather(tensor_list, tensor) + >>> tensor_list + [tensor([1, 2], device='cuda:0'), tensor([3, 4], device='cuda:0')] # Rank 0 + [tensor([1, 2], device='cuda:1'), tensor([3, 4], device='cuda:1')] # Rank 1 + + >>> # All tensors below are of torch.cfloat dtype. + >>> # We have 2 process groups, 2 ranks. 
+ >>> tensor_list = [torch.zeros(2, dtype=torch.cfloat, device=device) for _ in range(2)] + >>> tensor_list + [tensor([0.+0.j, 0.+0.j], device='cuda:0'), tensor([0.+0.j, 0.+0.j], device='cuda:0')] # Rank 0 + [tensor([0.+0.j, 0.+0.j], device='cuda:1'), tensor([0.+0.j, 0.+0.j], device='cuda:1')] # Rank 1 + >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat, device=device) + 2 * rank * (1+1j) + >>> tensor + tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0 + tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1 + >>> dist.all_gather(tensor_list, tensor) + >>> tensor_list + [tensor([1.+1.j, 2.+2.j], device='cuda:0'), tensor([3.+3.j, 4.+4.j], device='cuda:0')] # Rank 0 + [tensor([1.+1.j, 2.+2.j], device='cuda:1'), tensor([3.+3.j, 4.+4.j], device='cuda:1')] # Rank 1 + + """ + _check_tensor_list(tensor_list, "tensor_list") + _check_single_tensor(tensor, "tensor") + _ensure_all_tensors_same_dtype(tensor_list, tensor) + if _rank_not_in_group(group): + _warn_not_in_group("all_gather") + return + + tensor_list = [ + t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list + ] + tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) + + group = group or _get_default_group() + work = group.allgather([tensor_list], [tensor]) + + if async_op: + return work + else: + work.wait() + + +@_exception_logger +def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=False): + """ + Gather tensors from all ranks and put them in a single output tensor. + + This function requires all tensors to be the same size on each process. + + Args: + output_tensor (Tensor): Output tensor to accommodate tensor elements + from all ranks. It must be correctly sized to have one of the + following forms: + (i) a concatenation of all the input tensors along the primary + dimension; for definition of "concatenation", see ``torch.cat()``; + (ii) a stack of all the input tensors along the primary dimension; + for definition of "stack", see ``torch.stack()``. + Examples below may better explain the supported output forms. + input_tensor (Tensor): Tensor to be gathered from current rank. + Different from the ``all_gather`` API, the input tensors in this + API must have the same size across all ranks. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + Examples: + >>> # xdoctest: +SKIP("need process group init") + >>> # All tensors below are of torch.int64 dtype and on CUDA devices. + >>> # We have two ranks. + >>> device = torch.device(f'cuda:{rank}') + >>> tensor_in = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank + >>> tensor_in + tensor([1, 2], device='cuda:0') # Rank 0 + tensor([3, 4], device='cuda:1') # Rank 1 + >>> # Output in concatenation form + >>> tensor_out = torch.zeros(world_size * 2, dtype=torch.int64, device=device) + >>> dist.all_gather_into_tensor(tensor_out, tensor_in) + >>> tensor_out + tensor([1, 2, 3, 4], device='cuda:0') # Rank 0 + tensor([1, 2, 3, 4], device='cuda:1') # Rank 1 + >>> # Output in stack form + >>> tensor_out2 = torch.zeros(world_size, 2, dtype=torch.int64, device=device) + >>> dist.all_gather_into_tensor(tensor_out2, tensor_in) + >>> tensor_out2 + tensor([[1, 2], + [3, 4]], device='cuda:0') # Rank 0 + tensor([[1, 2], + [3, 4]], device='cuda:1') # Rank 1 + + .. 
warning:: + The Gloo backend does not support this API. + + """ + _check_single_tensor(input_tensor, "input_tensor") + _check_single_tensor(output_tensor, "output_tensor") + if _rank_not_in_group(group): + _warn_not_in_group("all_gather_into_tensor") + return + + output_tensor = ( + output_tensor + if not output_tensor.is_complex() + else torch.view_as_real(output_tensor) + ) + input_tensor = ( + input_tensor + if not input_tensor.is_complex() + else torch.view_as_real(input_tensor) + ) + + opts = AllgatherOptions() + opts.asyncOp = async_op + + group = group or _get_default_group() + + if group in _world.pg_coalesce_state.keys(): + # We are in coalescing context, do not issue single operation, just append a collective representation + coll = _CollOp(all_gather_into_tensor, input_tensor, output_tensor) + _world.pg_coalesce_state[group].append(coll) + if async_op: + return _IllegalWork() + else: + return None + + work = group._allgather_base(output_tensor, input_tensor, opts) + + if async_op: + return work + else: + work.wait() + + +@_exception_logger +@deprecated( + "`torch.distributed._all_gather_base` is a private function and will be deprecated. " + "Please use `torch.distributed.all_gather_into_tensor` instead.", + category=FutureWarning, +) +def _all_gather_base(output_tensor, input_tensor, group=None, async_op=False): + """ + Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor. + + Args: + output_tensor (Tensor): Output tensor. It should contain + correctly-sized tensors to be used for output of the collective. + input_tensor (Tensor): Tensor to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + .. warning:: + `_all_gather_base` is a private function. Users should use + `all_gather_into_tensor` instead. + + """ + return all_gather_into_tensor(output_tensor, input_tensor, group, async_op) + + +@_exception_logger +@deprecated( + "`torch.distributed.all_gather_coalesced` will be deprecated. If you must use it, " + "please revisit our documentation later at " + "https://pytorch.org/docs/main/distributed.html#collective-functions", + category=FutureWarning, +) +def all_gather_coalesced( + output_tensor_lists, input_tensor_list, group=None, async_op=False +): + """ + Gathers input tensors from the whole group in a list in a coalesced manner. + + Complex tensors are supported. + + Args: + output_tensor_lists (list[list[Tensor]]): Output list. It should contain + correctly-sized tensors to be used for output of the collective. + input_tensor_list (list[Tensor]): Tensors to be broadcast from + current process. At least one tensor has to be non empty. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + Example: + we have 2 process groups, 2 ranks. 
+ rank 0 passes: + input_tensor_list = [[[1, 1], [1, 1]], [2], [3, 3]] + output_tensor_lists = + [[[[-1, -1], [-1, -1]], [-1], [-1, -1]], + [[[-1, -1], [-1, -1]], [-1], [-1, -1]]] + rank 1 passes: + input_tensor_list = [[[3, 3], [3, 3]], [5], [1, 1]] + output_tensor_lists = + [[[[-1, -1], [-1, -1]], [-1], [-1, -1]], + [[[-1, -1], [-1, -1]], [-1], [-1, -1]]] + both rank 0 and 1 get: + output_tensor_lists = + [[[1, 1], [1, 1]], [2], [3, 3]], + [[3, 3], [3, 3]], [5], [1, 1]]]. + + WARNING: at this time individual shape checking is not implemented across nodes. + For example, if the rank 0 node passes [torch.rand(4), torch.rand(2)] and the + rank 1 node passes [torch.rand(2), torch.rand(2), torch.rand(2)], the + all_gather_coalesced operation will proceed without complaint and return + erroneous outputs. This lack of shape checking results in significant + performance improvements but users of this function should take extra care + to ensure that each node passes in tensors whose shapes match across nodes. + """ + # We only check basic compatibility with C++ params here, C++ code will + # do shape and type checking. + if _rank_not_in_group(group): + _warn_not_in_group("all_gather_coalesced") + return + _check_tensor_list(input_tensor_list, "input_tensor_list") + _ensure_all_tensors_same_dtype(input_tensor_list) + if not isinstance(output_tensor_lists, list): + raise TypeError( + "Invalid function argument: output_tensor_lists should be a list" + ) + for output_tensor_list in output_tensor_lists: + _check_tensor_list(output_tensor_list, "output_tensor_lists") + _ensure_all_tensors_same_dtype(output_tensor_list) + + output_tensor_lists = [ + [t if not t.is_complex() else torch.view_as_real(t) for t in l] + for l in output_tensor_lists + ] + input_tensor_list = [ + t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list + ] + + group = group or _get_default_group() + work = group.allgather_coalesced(output_tensor_lists, input_tensor_list) + + if async_op: + return work.get_future() + else: + work.wait() + + +def _validate_output_list_for_rank(my_rank, dst, gather_list): + if dst == my_rank: + if not gather_list: + raise ValueError( + "Argument ``gather_list`` must be specified on destination rank." + ) + elif gather_list: + raise ValueError( + "Argument ``gather_list`` must NOT be specified " + "on non-destination ranks." + ) + + +@_exception_logger +def gather(tensor, gather_list=None, dst=0, group=None, async_op=False): + """ + Gathers a list of tensors in a single process. + + This function requires all tensors to be the same size on each process. + + Args: + tensor (Tensor): Input tensor. + gather_list (list[Tensor], optional): List of appropriately, + same-sized tensors to use for gathered data + (default is None, must be specified on the destination rank) + dst (int, optional): Destination rank on global process group (regardless of ``group`` argument). (default is 0) + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + """ + _check_single_tensor(tensor, "tensor") + + # Parameter ``gather_list`` may be left unspecified on non-dst ranks. 
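+ # Rough usage sketch (illustrative only; assumes an initialized default
+ # process group with world_size 2):
+ #   t = torch.tensor([dist.get_rank()])
+ #   out = [torch.zeros(1, dtype=torch.int64) for _ in range(2)] if dist.get_rank() == 0 else None
+ #   dist.gather(t, gather_list=out, dst=0)  # rank 0 now holds [tensor([0]), tensor([1])]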
+ if gather_list: + _check_tensor_list(gather_list, "gather_list") + else: + gather_list = [] + _ensure_all_tensors_same_dtype(tensor, gather_list) + + if _rank_not_in_group(group): + _warn_not_in_group("gather") + return + + my_rank = get_rank() + _validate_output_list_for_rank(my_rank, dst, gather_list) + output_tensors = [gather_list] if dst == my_rank else [] + input_tensors = [tensor] + + opts = GatherOptions() + opts.rootRank = dst + + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() + work = default_pg.gather(output_tensors, input_tensors, opts) + else: + group_dst_rank = get_group_rank(group, dst) + opts.rootRank = group_dst_rank + work = group.gather(output_tensors, input_tensors, opts) + + if async_op: + return work + else: + work.wait() + + +@_exception_logger +def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False): + """ + Scatters a list of tensors to all processes in a group. + + Each process will receive exactly one tensor and store its data in the + ``tensor`` argument. + + Complex tensors are supported. + + Args: + tensor (Tensor): Output tensor. + scatter_list (list[Tensor]): List of tensors to scatter (default is + None, must be specified on the source rank) + src (int): Source rank on global process group (regardless of ``group`` argument). + Default is 0 + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + .. note:: Note that all Tensors in scatter_list must have the same size. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> tensor_size = 2 + >>> t_ones = torch.ones(tensor_size) + >>> t_fives = torch.ones(tensor_size) * 5 + >>> output_tensor = torch.zeros(tensor_size) + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 2. + >>> # Only tensors, all of which must be the same size. + >>> scatter_list = [t_ones, t_fives] + >>> else: + >>> scatter_list = None + >>> dist.scatter(output_tensor, scatter_list, src=0) + >>> # Rank i gets scatter_list[i]. For example, on rank 1: + >>> output_tensor + tensor([5., 5.]) + + """ + _check_single_tensor(tensor, "tensor") + + # Parameter ``scatter_list`` may be left unspecified on non-src ranks. + if scatter_list: + _check_tensor_list(scatter_list, "scatter_list") + else: + scatter_list = [] + _ensure_all_tensors_same_dtype(tensor, scatter_list) + + if _rank_not_in_group(group): + _warn_not_in_group("scatter") + return + scatter_list = [ + t if not t.is_complex() else torch.view_as_real(t) for t in scatter_list + ] + tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) + + my_rank = get_rank() + if src == my_rank: + if not scatter_list: + raise ValueError( + "Argument ``scatter_list`` must be specified on source rank." + ) + input_tensors = [scatter_list] + output_tensors = [tensor] + else: + if scatter_list: + raise ValueError( + "Argument ``scatter_list`` must NOT be specified " + "on non-source ranks." 
+ ) + input_tensors = [] + output_tensors = [tensor] + + opts = ScatterOptions() + opts.rootRank = src + opts.asyncOp = async_op + + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() + work = default_pg.scatter(output_tensors, input_tensors, opts) + else: + group_src_rank = get_group_rank(group, src) + opts.rootRank = group_src_rank + work = group.scatter(output_tensors, input_tensors, opts) + + if async_op: + return work + else: + work.wait() + + +@_exception_logger +def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False): + """ + Reduces, then scatters a list of tensors to all processes in a group. + + Args: + output (Tensor): Output tensor. + input_list (list[Tensor]): List of tensors to reduce and scatter. + op (optional): One of the values from + ``torch.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group. + + """ + _check_single_tensor(output, "output") + _check_tensor_list(input_list, "input_list") + _ensure_all_tensors_same_dtype(output, input_list) + if _rank_not_in_group(group): + _warn_not_in_group("reduce_scatter") + return + + opts = ReduceScatterOptions() + opts.reduceOp = op + + group = group or _get_default_group() + work = group.reduce_scatter([output], [input_list], opts) + + if async_op: + return work + else: + work.wait() + + +@_exception_logger +def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=False): + """ + Reduces, then scatters a tensor to all ranks in a group. + + Args: + output (Tensor): Output tensor. It should have the same size across all + ranks. + input (Tensor): Input tensor to be reduced and scattered. Its size + should be output tensor size times the world size. The input tensor + can have one of the following shapes: + (i) a concatenation of the output tensors along the primary + dimension, or + (ii) a stack of the output tensors along the primary dimension. + For definition of "concatenation", see ``torch.cat()``. + For definition of "stack", see ``torch.stack()``. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group. + + Examples: + >>> # xdoctest: +SKIP("need process group init") + >>> # All tensors below are of torch.int64 dtype and on CUDA devices. + >>> # We have two ranks. 
+ >>> device = torch.device(f'cuda:{rank}') + >>> tensor_out = torch.zeros(2, dtype=torch.int64, device=device) + >>> # Input in concatenation form + >>> tensor_in = torch.arange(world_size * 2, dtype=torch.int64, device=device) + >>> tensor_in + tensor([0, 1, 2, 3], device='cuda:0') # Rank 0 + tensor([0, 1, 2, 3], device='cuda:1') # Rank 1 + >>> dist.reduce_scatter_tensor(tensor_out, tensor_in) + >>> tensor_out + tensor([0, 2], device='cuda:0') # Rank 0 + tensor([4, 6], device='cuda:1') # Rank 1 + >>> # Input in stack form + >>> tensor_in = torch.reshape(tensor_in, (world_size, 2)) + >>> tensor_in + tensor([[0, 1], + [2, 3]], device='cuda:0') # Rank 0 + tensor([[0, 1], + [2, 3]], device='cuda:1') # Rank 1 + >>> dist.reduce_scatter_tensor(tensor_out, tensor_in) + >>> tensor_out + tensor([0, 2], device='cuda:0') # Rank 0 + tensor([4, 6], device='cuda:1') # Rank 1 + + .. warning:: + The Gloo backend does not support this API. + + """ + _check_single_tensor(output, "output") + _check_single_tensor(input, "input") + + if _rank_not_in_group(group): + _warn_not_in_group("reduce_scatter_tensor") + return + + opts = ReduceScatterOptions() + opts.reduceOp = op + opts.asyncOp = async_op + + group = group or _get_default_group() + + # Check if we are in coalescing context + # If we are, do not issue single operation, just append a collective representation + if group in _world.pg_coalesce_state.keys(): + coll = _CollOp(reduce_scatter_tensor, input, output, op, None) + _world.pg_coalesce_state[group].append(coll) + if async_op: + return _IllegalWork() + else: + return None + + work = group._reduce_scatter_base(output, input, opts) + + if async_op: + return work + else: + work.wait() + + +@deprecated( + "`torch.distributed._reduce_scatter_base` is a private function and will be deprecated. " + "Please use `torch.distributed.reduce_scatter_tensor` instead.", + category=FutureWarning, +) +def _reduce_scatter_base(output, input, op=ReduceOp.SUM, group=None, async_op=False): + """ + Reduces, then scatters a flattened tensor to all processes in a group. + + Args: + output (Tensor): Output tensor. + input (Tensor): Input tensor that is of size output tensor size times world size + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group. + + .. warning:: + `_reduce_scatter_base` is a private function. Users should use + `reduce_scatter_tensor` instead. + + """ + return reduce_scatter_tensor(output, input, op, group, async_op) + + +@_exception_logger +def all_to_all_single( + output, + input, + output_split_sizes=None, + input_split_sizes=None, + group=None, + async_op=False, +): + """ + Split input tensor and then scatter the split list to all processes in a group. + + Later the received tensors are concatenated from all the processes in the group + and returned as a single output tensor. + + Complex tensors are supported. + + Args: + output (Tensor): Gathered concatenated output tensor. + input (Tensor): Input tensor to scatter. + output_split_sizes: (list[Int], optional): Output split sizes for dim 0 + if specified None or empty, dim 0 of ``output`` tensor must divide + equally by ``world_size``. + input_split_sizes: (list[Int], optional): Input split sizes for dim 0 + if specified None or empty, dim 0 of ``input`` tensor must divide + equally by ``world_size``. 
+ group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group. + + .. warning:: + `all_to_all_single` is experimental and subject to change. + + Examples: + >>> # xdoctest: +SKIP("Undefined rank") + >>> input = torch.arange(4) + rank * 4 + >>> input + tensor([0, 1, 2, 3]) # Rank 0 + tensor([4, 5, 6, 7]) # Rank 1 + tensor([8, 9, 10, 11]) # Rank 2 + tensor([12, 13, 14, 15]) # Rank 3 + >>> output = torch.empty([4], dtype=torch.int64) + >>> dist.all_to_all_single(output, input) + >>> output + tensor([0, 4, 8, 12]) # Rank 0 + tensor([1, 5, 9, 13]) # Rank 1 + tensor([2, 6, 10, 14]) # Rank 2 + tensor([3, 7, 11, 15]) # Rank 3 + + >>> # Essentially, it is similar to following operation: + >>> scatter_list = list(input.chunk(world_size)) + >>> gather_list = list(output.chunk(world_size)) + >>> for i in range(world_size): + >>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src = i) + + >>> # Another example with uneven split + >>> input + tensor([0, 1, 2, 3, 4, 5]) # Rank 0 + tensor([10, 11, 12, 13, 14, 15, 16, 17, 18]) # Rank 1 + tensor([20, 21, 22, 23, 24]) # Rank 2 + tensor([30, 31, 32, 33, 34, 35, 36]) # Rank 3 + >>> input_splits + [2, 2, 1, 1] # Rank 0 + [3, 2, 2, 2] # Rank 1 + [2, 1, 1, 1] # Rank 2 + [2, 2, 2, 1] # Rank 3 + >>> output_splits + [2, 3, 2, 2] # Rank 0 + [2, 2, 1, 2] # Rank 1 + [1, 2, 1, 2] # Rank 2 + [1, 2, 1, 1] # Rank 3 + >>> output = ... + >>> dist.all_to_all_single(output, input, output_splits, input_splits) + >>> output + tensor([ 0, 1, 10, 11, 12, 20, 21, 30, 31]) # Rank 0 + tensor([ 2, 3, 13, 14, 22, 32, 33]) # Rank 1 + tensor([ 4, 15, 16, 23, 34, 35]) # Rank 2 + tensor([ 5, 17, 18, 24, 36]) # Rank 3 + + + >>> # Another example with tensors of torch.cfloat type. + >>> input = torch.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=torch.cfloat) + 4 * rank * (1+1j) + >>> input + tensor([1+1j, 2+2j, 3+3j, 4+4j]) # Rank 0 + tensor([5+5j, 6+6j, 7+7j, 8+8j]) # Rank 1 + tensor([9+9j, 10+10j, 11+11j, 12+12j]) # Rank 2 + tensor([13+13j, 14+14j, 15+15j, 16+16j]) # Rank 3 + >>> output = torch.empty([4], dtype=torch.int64) + >>> dist.all_to_all_single(output, input) + >>> output + tensor([1+1j, 5+5j, 9+9j, 13+13j]) # Rank 0 + tensor([2+2j, 6+6j, 10+10j, 14+14j]) # Rank 1 + tensor([3+3j, 7+7j, 11+11j, 15+15j]) # Rank 2 + tensor([4+4j, 8+8j, 12+12j, 16+16j]) # Rank 3 + """ + if _rank_not_in_group(group): + _warn_not_in_group("all_to_all_single") + return + + opts = AllToAllOptions() + _check_single_tensor(output, "output") + _check_single_tensor(input, "input") + _ensure_all_tensors_same_dtype(output, input) + + if input.is_complex(): + input = torch.view_as_real(input) + if output.is_complex(): + output = torch.view_as_real(output) + + output_split_sizes = [] if output_split_sizes is None else output_split_sizes + input_split_sizes = [] if input_split_sizes is None else input_split_sizes + + group = group or _get_default_group() + work = group.alltoall_base( + output, input, output_split_sizes, input_split_sizes, opts + ) + + if async_op: + return work + else: + work.wait() + + +@_exception_logger +def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False): + """ + Scatters list of input tensors to all processes in a group and return gathered list of tensors in output list. + + Complex tensors are supported. 
+ + Args: + output_tensor_list (list[Tensor]): List of tensors to be gathered one + per rank. + input_tensor_list (list[Tensor]): List of tensors to scatter one per rank. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group. + + .. warning:: + `all_to_all` is experimental and subject to change. + + Examples: + >>> # xdoctest: +SKIP("Undefined rank") + >>> input = torch.arange(4) + rank * 4 + >>> input = list(input.chunk(4)) + >>> input + [tensor([0]), tensor([1]), tensor([2]), tensor([3])] # Rank 0 + [tensor([4]), tensor([5]), tensor([6]), tensor([7])] # Rank 1 + [tensor([8]), tensor([9]), tensor([10]), tensor([11])] # Rank 2 + [tensor([12]), tensor([13]), tensor([14]), tensor([15])] # Rank 3 + >>> output = list(torch.empty([4], dtype=torch.int64).chunk(4)) + >>> dist.all_to_all(output, input) + >>> output + [tensor([0]), tensor([4]), tensor([8]), tensor([12])] # Rank 0 + [tensor([1]), tensor([5]), tensor([9]), tensor([13])] # Rank 1 + [tensor([2]), tensor([6]), tensor([10]), tensor([14])] # Rank 2 + [tensor([3]), tensor([7]), tensor([11]), tensor([15])] # Rank 3 + + >>> # Essentially, it is similar to following operation: + >>> scatter_list = input + >>> gather_list = output + >>> for i in range(world_size): + >>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i) + + >>> input + tensor([0, 1, 2, 3, 4, 5]) # Rank 0 + tensor([10, 11, 12, 13, 14, 15, 16, 17, 18]) # Rank 1 + tensor([20, 21, 22, 23, 24]) # Rank 2 + tensor([30, 31, 32, 33, 34, 35, 36]) # Rank 3 + >>> input_splits + [2, 2, 1, 1] # Rank 0 + [3, 2, 2, 2] # Rank 1 + [2, 1, 1, 1] # Rank 2 + [2, 2, 2, 1] # Rank 3 + >>> output_splits + [2, 3, 2, 2] # Rank 0 + [2, 2, 1, 2] # Rank 1 + [1, 2, 1, 2] # Rank 2 + [1, 2, 1, 1] # Rank 3 + >>> input = list(input.split(input_splits)) + >>> input + [tensor([0, 1]), tensor([2, 3]), tensor([4]), tensor([5])] # Rank 0 + [tensor([10, 11, 12]), tensor([13, 14]), tensor([15, 16]), tensor([17, 18])] # Rank 1 + [tensor([20, 21]), tensor([22]), tensor([23]), tensor([24])] # Rank 2 + [tensor([30, 31]), tensor([32, 33]), tensor([34, 35]), tensor([36])] # Rank 3 + >>> output = ... + >>> dist.all_to_all(output, input) + >>> output + [tensor([0, 1]), tensor([10, 11, 12]), tensor([20, 21]), tensor([30, 31])] # Rank 0 + [tensor([2, 3]), tensor([13, 14]), tensor([22]), tensor([32, 33])] # Rank 1 + [tensor([4]), tensor([15, 16]), tensor([23]), tensor([34, 35])] # Rank 2 + [tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])] # Rank 3 + + >>> # Another example with tensors of torch.cfloat type. 
+ >>> input = torch.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=torch.cfloat) + 4 * rank * (1+1j) + >>> input = list(input.chunk(4)) + >>> input + [tensor([1+1j]), tensor([2+2j]), tensor([3+3j]), tensor([4+4j])] # Rank 0 + [tensor([5+5j]), tensor([6+6j]), tensor([7+7j]), tensor([8+8j])] # Rank 1 + [tensor([9+9j]), tensor([10+10j]), tensor([11+11j]), tensor([12+12j])] # Rank 2 + [tensor([13+13j]), tensor([14+14j]), tensor([15+15j]), tensor([16+16j])] # Rank 3 + >>> output = list(torch.empty([4], dtype=torch.int64).chunk(4)) + >>> dist.all_to_all(output, input) + >>> output + [tensor([1+1j]), tensor([5+5j]), tensor([9+9j]), tensor([13+13j])] # Rank 0 + [tensor([2+2j]), tensor([6+6j]), tensor([10+10j]), tensor([14+14j])] # Rank 1 + [tensor([3+3j]), tensor([7+7j]), tensor([11+11j]), tensor([15+15j])] # Rank 2 + [tensor([4+4j]), tensor([8+8j]), tensor([12+12j]), tensor([16+16j])] # Rank 3 + + """ + if _rank_not_in_group(group): + _warn_not_in_group("all_to_all") + return + + opts = AllToAllOptions() + _check_tensor_list(output_tensor_list, "output_tensor_list") + _check_tensor_list(input_tensor_list, "input_tensor_list") + _ensure_all_tensors_same_dtype(output_tensor_list, input_tensor_list) + + input_tensor_list = [ + t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list + ] + output_tensor_list = [ + t if not t.is_complex() else torch.view_as_real(t) for t in output_tensor_list + ] + + group = group or _get_default_group() + work = group.alltoall(output_tensor_list, input_tensor_list, opts) + + if async_op: + return work + else: + work.wait() + + +@_exception_logger +def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None): + """ + Synchronize all processes. + + This collective blocks processes until the whole group enters this function, + if async_op is False, or if async work handle is called on wait(). + + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + device_ids ([int], optional): List of device/GPU ids. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + .. note:: `ProcessGroupNCCL` now relies on stream synchronization instead of + device synchronization to block the CPU. Thus, please do not assume that + `barrier()` would perform a device synchronization. + """ + if _rank_not_in_group(group): + _warn_not_in_group("barrier") + return + + opts = BarrierOptions() + opts.device = _get_pg_default_device(group) + if device_ids is not None: + if isinstance(device_ids, list): + opts.device_ids = device_ids + else: + raise TypeError( + "Invalid function argument: device_ids type should be List[int]" + ) + + group = group or _get_default_group() + work = group.barrier(opts=opts) + + if async_op: + return work + else: + work.wait() + + +def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=False): + """ + Synchronize processes similar to ``torch.distributed.barrier``, but consider a configurable timeout. + + It is able to report ranks that did not pass this barrier within the provided timeout. + Specifically, for non-zero ranks, will block until a send/recv is processed from rank 0. + Rank 0 will block until all send /recv from other ranks are processed, and will report + failures for ranks that failed to respond in time. 
Note that if one rank does not reach the + monitored_barrier (for example due to a hang), all other ranks would fail in monitored_barrier. + + This collective will block all processes/ranks in the group, until the + whole group exits the function successfully, making it useful for debugging + and synchronizing. However, it can have a performance impact and should only + be used for debugging or scenarios that require full synchronization points + on the host-side. For debugging purposes, this barrier can be inserted + before the application's collective calls to check if any ranks are + desynchronized. + + .. note:: Note that this collective is only supported with the GLOO backend. + + Args: + group (ProcessGroup, optional): The process group to work on. If + ``None``, the default process group will be used. + timeout (datetime.timedelta, optional): Timeout for monitored_barrier. + If ``None``, the default process group timeout will be used. + wait_all_ranks (bool, optional): Whether to collect all failed ranks or + not. By default, this is ``False`` and ``monitored_barrier`` on rank 0 + will throw on the first failed rank it encounters in order to fail + fast. By setting ``wait_all_ranks=True`` ``monitored_barrier`` will + collect all failed ranks and throw an error containing information + about all failed ranks. + + Returns: + ``None``. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> if dist.get_rank() != 1: + >>> dist.monitored_barrier() # Raises exception indicating that + >>> # rank 1 did not call into monitored_barrier. + >>> # Example with wait_all_ranks=True + >>> if dist.get_rank() == 0: + >>> dist.monitored_barrier(wait_all_ranks=True) # Raises exception + >>> # indicating that ranks 1, 2, ... world_size - 1 did not call into + >>> # monitored_barrier. + """ + # Need to call rank not in group before using the group, otherwise + # "Invalid process group" error is raised. + if _rank_not_in_group(group): + _warn_not_in_group("monitored_barrier") + return + + if get_backend(group) != Backend.GLOO: + raise ValueError("monitored_barrier is only implemented for GLOO backend.") + + if timeout is None: + timeout = _get_default_timeout(get_backend(group)) + elif isinstance(timeout, float): + # TODO(whc) apparently some existing test case for monitored_barrier passes in a timeout in float format? + warnings.warn( + "Please specify timeout arg as a timedelta. " + f"Converting current value of {timeout} assuming it represents seconds", + ) + timeout = timedelta(seconds=timeout) + + _check_valid_timeout(timeout) + + group_to_use = _get_default_group() if group is None else group + return group_to_use.monitored_barrier(timeout, wait_all_ranks=wait_all_ranks) + + +def _create_process_group_wrapper( + wrapped_pg: torch._C._distributed_c10d.Backend, + store_prefix: str, + store: Store, + rank: int, + world_size: int, + timeout: timedelta = default_pg_timeout, +): + assert _GLOO_AVAILABLE, "ProcessGroupWrapper unsupported without GLOO backend." + + # (whc) this appears to be just for the gloo backend? if so, `default_pg_timeout` is appropriate... + + # Create a separate prefix store for the helper process group. + prefix = f"{PG_WRAPPER_STORE_PREFIX}:{store_prefix}" + store = PrefixStore(prefix, store) + helper_pg = ProcessGroupGloo(store, rank, world_size, timeout=timeout) + # Wrap the underlying pg with ProcessGroupWrapper.
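+ # The gloo helper group gives the wrapper a side channel to check that all
+ # ranks are issuing matching collectives (same op type and tensor metadata)
+ # before each call is forwarded to ``wrapped_pg``, so mismatches surface as
+ # errors rather than hangs; this is also why it gets its own prefix store above.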
+ wrapped_pg = _ProcessGroupWrapper(wrapped_pg, helper_pg) + return wrapped_pg + + +# helper function for deterministically hashing a list of ranks +def _hash_ranks(ranks: List[int]): + return hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest() + + +# Takes a list of ranks and computes an integer color +def _process_group_color(ranks: List[int]) -> int: + # Convert our hash to an int, but avoid negative numbers by shifting a bit. + return int(_hash_ranks(ranks), 16) % (sys.maxsize >> 1) + + +def _process_group_name(ranks, use_hashed_name): + global _world + if use_hashed_name: + pg_name = _hash_ranks(ranks) + while pg_name in _world.pg_names.values(): + pg_name = hashlib.sha1(bytes(pg_name + "_", "utf-8")).hexdigest() + else: + pg_name = str(_world.group_count) + _world.group_count += 1 + return pg_name + + +def _get_backend_from_str(backend: Optional[str] = None) -> Backend: + # Default to the same backend as the global process group + # if backend is not specified. + if not backend: + backend = get_backend(_get_default_group()) + return Backend(backend) + + +def _is_safe_to_split() -> bool: + """ + Checks if it is safe to split any process group in the world. + This is only safe if the default pg has a bound device id, otherwise + users must be aware that a pg is only splittable after the first collective is + issued. + """ + return False if _get_default_group().bound_device_id is None else True + + +@_time_logger +def split_group( + parent_pg: Optional[ProcessGroup] = None, + split_ranks: Optional[list] = None, + timeout: Optional[timedelta] = None, + pg_options: Optional[Any] = None, + group_desc: Optional[str] = None, +) -> Optional[ProcessGroup]: + """ + Create a new process group split from the given parent process group. + + .. warning:: This is an experimental API and only the ``NCCL`` backend supports this API. + Other backends will raise an error. + Users of this API must guarantee that all ranks in the parent group enter this API call, + and the split of the sub groups is the same across all ranks in the parent group. + + Args: + parent_pg (ProcessGroup, optional): The parent process group. If None, + the default process group will be used. Users need to guarantee that + the parent group is fully initialized (e.g., communicators are initialized) + split_ranks (list[list[int]]): the split ranks, which is a list of lists of ranks. + Users need to make sure the validity of the split ranks such that one + split (represented by one inner list of ints) does not overlap with any other split. + Note that the ranks in each split are group ranks (instead of global ranks) + in the parent pg. For example, if the parent group has 4 ranks, split_ranks can be + [[0, 1], [2, 3]]. Note [[0, 1]] is also a valid split, in which case ranks 2 and 3 would + not belong to any subgroup and would receive ``None``. + timeout (timedelta, optional): see `init_process_group` for details and default value. + pg_options (ProcessGroupOptions, optional): only ProcessGroupNCCL.Options is supported now, + specifying what additional options need to be passed in during + the construction of specific process groups. i.e. ``is_high_priority_stream`` + can be specified so that the process group can pick up high priority cuda streams. + For other available options to configure NCCL, + see https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t + group_desc (str, optional): a string to describe the process group.
+ + Returns: + ProcessGroup if the current rank is within one split/subgroup given by split_ranks, + or None if the current rank is not part of any split_ranks`. + + """ + # check inputs + if split_ranks is None: + raise ValueError("split_ranks cannot be None") + + global _world + default_pg = _get_default_group() + device_id = default_pg.bound_device_id + if not device_id: + raise RuntimeError( + "No device associated with the default pg, not safe to split any process groups" + ) + default_backend, default_store = _world.pg_map[default_pg] + global_rank = default_pg.rank() + global_world_size = default_pg.size() + + if not parent_pg: + parent_pg = default_pg + if parent_pg not in _world.pg_group_ranks: + raise ValueError(f"Group {parent_pg} is not registered") + + parent_global_to_group_ranks = _world.pg_group_ranks[parent_pg] + parent_group_to_global_ranks = { + group_rank: global_rank + for global_rank, group_rank in parent_global_to_group_ranks.items() + } + + if global_rank not in parent_global_to_group_ranks: + raise ValueError( + f"Global rank {global_rank} is not part of the parent group {parent_pg}" + ) + + parent_group_rank = parent_global_to_group_ranks[global_rank] + parent_backend = parent_pg._get_backend(torch.device("cuda")) + + # if the parent backend does not support splitting, raise error + # currently this API only support NCCL backend + if ( + not parent_backend + or not parent_backend.supports_splitting + or not isinstance(parent_backend, ProcessGroupNCCL) + ): + raise RuntimeError( + "No backend for the parent process group or its backend does not support splitting" + ) + + # set the group_desc before the color or no_cloor split + group_desc = ( + f"{parent_pg.group_desc}:split:{parent_backend.comm_split_count()}" + if group_desc is None + else group_desc + ) + + parent_backend_str, _ = _world.pg_map[parent_pg] + # same type of backend as the parent process group + backend = Backend(parent_backend_str) + backend_config = BackendConfig(backend) + + if pg_options is not None: + assert isinstance( + pg_options, ProcessGroupNCCL.Options + ), "Expected pg_options argument to be of type ProcessGroupNCCL.Options" + else: + # default pg_options same as the parent process group + pg_options = parent_backend.options + + # this timeout defaulting/validation is used for all the new_groups/new_subgroups variants, + # which may just pass their timeout value (or None) + if timeout is None: + timeout = _get_default_timeout(backend) + _check_valid_timeout(timeout) + + # find my group of ranks and my group local rank in split_ranks + my_group = None + group_rank = -1 + + for split_group in split_ranks: + if len(split_group) == 0: + raise ValueError("the split group cannot be empty") + if len(split_group) > global_world_size: + raise ValueError( + "the split group's size should be less or equal to the world_size set by init_process_group" + ) + if len(split_group) != len(set(split_group)): + raise ValueError("the split group cannot have duplicate ranks") + split_group = sorted(split_group) + if parent_group_rank in split_group: + my_group = split_group + group_rank = split_group.index(parent_group_rank) + break + # if my rank does not belong to any sub group, + # no_color split should be called + if my_group is None or group_rank == -1: + parent_backend.perform_nocolor_split(device_id) + return None + + group_name = _process_group_name(my_group, use_hashed_name=False) + global_ranks_in_my_group = [parent_group_to_global_ranks[rank] for rank in my_group] + + prefix_store = 
PrefixStore(f"{group_name}/", default_store) + base_pg_options = ProcessGroup.Options(backend=str(backend)) + base_pg_options._timeout = timeout + pg: ProcessGroup = ProcessGroup( + prefix_store, group_rank, len(my_group), base_pg_options + ) + pg.bound_device_id = device_id + + pg_options._timeout = timeout + pg_options.split_from = parent_backend + pg_options.split_color = _process_group_color(my_group) + pg_options.global_ranks_in_group = global_ranks_in_my_group + pg_options.group_name = group_name + backend_class = ProcessGroupNCCL( + prefix_store, group_rank, len(my_group), pg_options + ) + backend_type = ProcessGroup.BackendType.NCCL + backend_class._set_sequence_number_for_group() + + pg._register_backend(torch.device("cuda"), backend_type, backend_class) + + # set group_name and group_desc to backend + assert group_name is not None + assert group_desc is not None + pg._set_group_name(group_name) + pg._set_group_desc(group_desc) + + # always eagerly initialize the backend in split_group + eager_backend = pg._get_backend(device_id) + eager_backend.eager_connect_single_device(device_id) + + # update global state + _world.pg_map[pg] = (backend, prefix_store) + _world.pg_names[pg] = group_name + _register_process_group(group_name, pg) + _world.pg_backend_config[pg] = str(backend_config) + pg_tag = f"ptd:{group_name}" + _world.tags_to_pg.setdefault(pg_tag, []).append(pg) + _world.pg_to_tag[pg] = pg_tag + + # Create the global rank to group rank mapping + _world.pg_group_ranks[pg] = { + global_rank: group_rank + for group_rank, global_rank in enumerate(global_ranks_in_my_group) + } + + return pg + + +@_time_logger +def new_group( + ranks=None, + timeout=None, + backend=None, + pg_options=None, + use_local_synchronization=False, + group_desc=None, +): + """ + Create a new distributed group. + + This function requires that all processes in the main group (i.e. all + processes that are part of the distributed job) enter this function, even + if they are not going to be members of the group. Additionally, groups + should be created in the same order in all processes. + + .. warning:: + Safe concurrent usage: + When using multiple process groups with the ``NCCL`` backend, the user + must ensure a globally consistent execution order of collectives across + ranks. + + If multiple threads within a process issue collectives, explicit + synchronization is necessary to ensure consistent ordering. + + When using async variants of torch.distributed communication APIs, + a work object is returned and the communication kernel is + enqueued on a separate CUDA stream, allowing overlap of communication + and computation. Once one or more async ops have been issued on one process + group, they must be synchronized with other cuda streams by calling `work.wait()` + before using another process group. + + See `Using multiple NCCL communicators concurrently `_ for more details. + + Args: + ranks (list[int]): List of ranks of group members. If ``None``, will be + set to all ranks. Default is ``None``. + timeout (timedelta, optional): see `init_process_group` for details and default value. + backend (str or Backend, optional): The backend to use. Depending on + build-time configurations, valid values are ``gloo`` and ``nccl``. + By default uses the same backend as the global group. This field + should be given as a lowercase string (e.g., ``"gloo"``), which can + also be accessed via :class:`Backend` attributes (e.g., + ``Backend.GLOO``). 
If ``None`` is passed in, the backend + corresponding to the default process group will be used. Default is + ``None``. + pg_options (ProcessGroupOptions, optional): process group options + specifying what additional options need to be passed in during + the construction of specific process groups. i.e. for the ``nccl`` + backend, ``is_high_priority_stream`` can be specified so that + process group can pick up high priority cuda streams. For other available options to configure NCCL, + see https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t + use_local_synchronization (bool, optional): perform a group-local + barrier at the end of the process group creation. This is different + in that non-member ranks don't need to call into the API and don't + join the barrier. + group_desc (str, optional): a string to describe the process group. + + Returns: + A handle of the distributed group that can be given to collective calls or + GroupMember.NON_GROUP_MEMBER if the rank is not part of ``ranks``. + + N.B. use_local_synchronization doesn't work with MPI. + + N.B. While use_local_synchronization=True can be significantly faster with larger + clusters and small process groups, care must be taken since it changes cluster behavior + as non-member ranks don't join the group barrier(). + + N.B. use_local_synchronization=True can lead to deadlocks when each rank creates + multiple overlapping process groups. To avoid that, make sure all ranks follow the + same global creation order. + """ + return _new_group_with_tag( + ranks, + timeout, + backend, + pg_options, + None, + use_local_synchronization=use_local_synchronization, + group_desc=group_desc, + ) + + + def _new_group_with_tag( + ranks=None, + timeout=None, + backend=None, + pg_options=None, + pg_tag=None, + use_local_synchronization=False, + group_desc=None, + ): + """ + Variant of ``new_group`` that exposes tag creation. + + :: N.B. The mechanism is experimental and tied to the functional collectives effort, see + ``torch.distributed._functional_collectives`` for reference on how to use it. + """ + global _world + + default_pg = _get_default_group() + device_id = default_pg.bound_device_id + default_backend, default_store = _world.pg_map[default_pg] + global_rank = default_pg.rank() + global_world_size = default_pg.size() + + # Default to the same backend as the global process group + # if the backend is not specified.
+ if not backend: + backend = default_backend + backend = Backend(backend) + + # this timeout defaulting/validation is used for all the new_groups/new_subgroups variants, + # which may just pass their timeout value (or None) + if timeout is None: + timeout = _get_default_timeout(backend) + _check_valid_timeout(timeout) + + if use_local_synchronization: + # MPI backend doesn't have have a way for us to perform a partial sync + if backend == Backend.MPI: + raise ValueError( + "MPI backend doesn't support use_local_synchronization=True" + ) + if ranks is not None and get_rank() not in ranks: + return None + + # checks the input ranks + if ranks is not None: + ranks = sorted(ranks) + group_world_size = len(ranks) + if group_world_size > global_world_size: + raise ValueError( + "the new group's world size should be less or " + "equal to the world size set by " + "init_process_group" + ) + # check ranks' sanity + for rank in ranks: + if rank < 0 or rank >= global_world_size: + raise ValueError( + "The new group's rank should be within " + "the world_size set by init_process_group" + ) + if global_rank in ranks: + group_rank = ranks.index(global_rank) + else: + group_rank = None + else: + ranks = list(range(global_world_size)) + group_world_size = global_world_size + group_rank = global_rank + + group_name = _process_group_name(ranks, use_hashed_name=use_local_synchronization) + + pg, pg_store = _new_process_group_helper( + group_world_size, + group_rank, + ranks, + backend, + default_store, + group_name, + pg_options=pg_options, + timeout=timeout, + pg_tag=pg_tag, + device_id=device_id, + group_desc=group_desc, + ) + + # Create the global rank to group rank mapping + _world.pg_group_ranks[pg] = { + global_rank: group_rank for group_rank, global_rank in enumerate(ranks) + } + + if _is_barrier_after_init() == 1: + # barrier at the end to ensure that once we return from this method, all + # process groups including global variables (if any) are updated + # correctly on all ranks. + # Update 04/2023: for large-scale runs, this barrier (esp. store-based + # barrier) may be costly and/or unscalable. Also, in a lot of cases, + # these barriers may be unnecessary, as proven by a green CI after + # removal. An environment variable `TORCH_DIST_INIT_BARRIER` has been + # added which enables this barrier only when set to 1. + logger.info( + "Performing barrier after ProcessGroup initialization since " + "TORCH_DIST_INIT_BARRIER = 1" + ) + if backend == Backend.MPI: + # MPI doesn't have store. + barrier() + else: + barrier_store = pg_store if use_local_synchronization else default_store + world_size = len(ranks) if use_local_synchronization else get_world_size() + # Use store based barrier here since barrier() used a bunch of + # default devices and messes up NCCL internal state. + _store_based_barrier( + global_rank, barrier_store, group_name, world_size, timeout + ) + + return pg + + +def new_subgroups( + group_size=None, + group=None, + timeout=None, + backend=None, + pg_options=None, + group_desc=None, +): + """ + Create subgroups of equal size. + + By default, it creates intra-machine subgroups, + where each of which contains all the ranks of a machine, based on the assumption + that each machine has the same number of devices. + + This is a convenience API that calls ``new_group`` to generate multiple subgroups. + It requires that all processes in the main group (i.e. all + processes that are part of the distributed job) enter this function, even + if they are not going to be members of the group. 
+ + .. warning:: + If ``group_size`` is passed in, the world size must be divisible by ``group_size``. + If no ``group_size`` is passed in, it is assumed that you are creating a group based + on CUDA, with the group size determined by the number of CUDA devices; if not all + the machines have the same number of devices, the subgroup division will be + different across nodes and can cause unexpected behavior. Therefore, if you are + creating a subgroup that does not depend on CUDA (such as Gloo on CPU), please + pass in ``group_size`` correctly. + + .. warning:: + See warning `Safe concurrent usage` for `new_group` API for important details about + using multiple process groups concurrently in a safe manner. + + Args: + group_size (int, optional): The size of each subgroup. If ``None``, + the default subgroup size is equal to the number of devices on each machine, + based on the assumption that each machine has exactly the same + number of devices. Default is ``None``. + timeout (timedelta, optional): see `init_process_group` for details and default value. + backend (str or Backend, optional): The backend to use. Depending on + build-time configurations, valid values are ``gloo`` and ``nccl``. + By default uses the same backend as the global group. This field + should be given as a lowercase string (e.g., ``"gloo"``), which can + also be accessed via :class:`Backend` attributes (e.g., + ``Backend.GLOO``). If ``None`` is passed in, the backend + corresponding to the default process group will be used. Default is + ``None``. + pg_options (ProcessGroupOptions, optional): process group options + specifying what additional options need to be passed in during + the construction of specific process groups. i.e. for the ``nccl`` + backend, ``is_high_priority_stream`` can be specified so that + process group can pick up high priority cuda streams. + group_desc (str, optional): A string describing the group. Each subgroup will + inherit its group_desc. + + Returns: + The subgroup containing the current rank, and all the subgroups used for cleanup. + + Examples: + >>> # Create intra-machine subgroups. + >>> # xdoctest: +SKIP("need process group init") + >>> cur_subgroup, subgroups = dist.new_subgroups() + >>> # Allreduce within the machine. + >>> rank = dist.get_rank() + >>> tensor = torch.ones(1, device=rank) * rank + >>> dist.all_reduce(tensor, group=cur_subgroup) + >>> tensor + tensor([28]) # Assume 8 CUDA devices per machine. 28 is sum(range(8)). + >>> # Cleanup. + >>> for subgroup in subgroups: + >>> dist.destroy_process_group(subgroup) + """ + if group_size is None: + if not torch.cuda.is_available(): + raise ValueError( + "Default group size only takes effect when CUDA is available. " + "If your subgroup uses a backend that does not depend on CUDA, " + "please pass in 'group_size' correctly."
+ ) + group_size = torch.cuda.device_count() + if group_size <= 0: + raise ValueError(f"The arg 'group_size' ({group_size}) must be positive") + + world_size = get_world_size() + if world_size < group_size: + raise ValueError( + f"The arg 'group_size' ({group_size}) must not exceed the world size ({world_size})" + ) + if world_size % group_size != 0: + raise ValueError("The world size must be divisible by 'group_size'") + + subgroups = [] + cur_subgroup = None + + for subgroup_id in range(world_size // group_size): + start_rank = subgroup_id * group_size + end_rank = start_rank + group_size + ranks_in_subgroup = list(range(start_rank, end_rank)) + subgroup = new_group( + ranks=ranks_in_subgroup, + timeout=timeout, + backend=backend, + pg_options=pg_options, + group_desc=group_desc, + ) + subgroups.append(subgroup) + + rank = get_rank() + if rank in ranks_in_subgroup: + cur_subgroup = subgroup + logger.info("Rank %s is assigned to subgroup %s", rank, ranks_in_subgroup) + + return cur_subgroup, subgroups + + +def new_subgroups_by_enumeration( + ranks_per_subgroup_list, + timeout=None, + backend=None, + pg_options=None, + group_desc=None, +): + """ + Create subgroups by dividing the global world. + + The division is specified by a nested list of ranks. The subgroups cannot have + overlap, and some ranks may not have to be in any subgroup. + + This is a convenience API that calls ``new_group`` to generate multiple subgroups. + It requires that all processes in the main group (i.e. all + processes that are part of the distributed job) enter this function, even + if they are not going to be members of the group. + + .. warning:: + See warning `Safe concurrent usage` for `new_group` API for important details about + using multiple process groups concurrently in a safe manner. + + Args: + ranks_per_subgroup_list (list[list[int]]): A nested list of ranks of + group members. + timeout (timedelta, optional): see `init_process_group` for details and default value. + backend (str or Backend, optional): The backend to use. Depending on + build-time configurations, valid values are ``gloo`` and ``nccl``. + By default uses the same backend as the global group. This field + should be given as a lowercase string (e.g., ``"gloo"``), which can + also be accessed via :class:`Backend` attributes (e.g., + ``Backend.GLOO``). If ``None`` is passed in, the backend + corresponding to the default process group will be used. Default is + ``None``. + pg_options (ProcessGroupOptions, optional): process group options + specifying what additional options need to be passed in during + the construction of specific process groups. i.e. for the ``nccl`` + backend, ``is_high_priority_stream`` can be specified so that + process group can pick up high priority cuda streams. + group_desc (str, optional): A string describing the group. Each subgroup will + inherit its group_desc. + + Returns: + The subgroup containing the current rank, and all the subgroups used for cleanup. + + Examples: + >>> # Create two subgroups, where each has 2 processes. 
+ >>> # xdoctest: +SKIP("need process group init") + >>> cur_subgroup, subgroups = dist.new_subgroups_by_enumeration([[0, 2], [1, 3]]) + >>> rank = dist.get_rank() + >>> tensor = torch.ones(1, device=rank) * rank + >>> dist.all_reduce(tensor, group=cur_subgroup) + >>> tensor + tensor([2]) # Subgroup 0: ranks 0 and 2 + tensor([4]) # Subgroup 1: ranks 1 and 3 + """ + if ranks_per_subgroup_list is None or len(ranks_per_subgroup_list) == 0: + raise ValueError("The arg 'ranks_per_subgroup_list' cannot be empty") + + subgroups = [] + cur_subgroup = None + # Create a mapping from rank to subgroup to check if there is any subgroup overlap. + rank_to_ranks_dict = {} # type: ignore[var-annotated] + for ranks in ranks_per_subgroup_list: + subgroup = new_group( + ranks=ranks, + timeout=timeout, + backend=backend, + pg_options=pg_options, + group_desc=group_desc, + ) + subgroups.append(subgroup) + my_rank = get_rank() + for rank in ranks: + if rank in rank_to_ranks_dict: + raise ValueError( + f"Rank {rank} has appeared in both subgroup {rank_to_ranks_dict[rank]} and {ranks}" + ) + rank_to_ranks_dict[rank] = ranks + if my_rank == rank: + cur_subgroup = subgroup + logger.info("Rank %s is assigned to subgroup %s", rank, ranks) + + return cur_subgroup, subgroups + + +def _find_pg_by_ranks_and_tag(tag: str, ranks: List[int]) -> Optional[ProcessGroup]: + if len(tag) > 0 and not tag.startswith("ptd:") and not tag.startswith("user:"): + tag = f"user:{tag}" + + for group in _world.tags_to_pg.get(tag, []): + if group.size() != len(ranks): + continue + + group_ranks = get_process_group_ranks(group) + good = all(r in group_ranks for r in ranks) + if good: + return group + return None + + +def _find_or_create_pg_by_ranks_and_tag( + tag: str, ranks: List[int], stride: int +) -> ProcessGroup: + assert ( + len(ranks) % stride == 0 + ), f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})" + + my_rank = get_rank() + my_ranks = None + + if stride == len(ranks): + my_ranks = ranks.copy() + assert my_rank in my_ranks, "rankset doesn't include the current node" + else: + for i in range(0, len(ranks), stride): + rank_set = ranks[i : i + stride] + if my_rank in rank_set: + my_ranks = rank_set + assert my_ranks is not None, "rankset doesn't include the current node" + + my_ranks = sorted(my_ranks) + + pg = _find_pg_by_ranks_and_tag(tag, my_ranks) + if pg is not None: + return pg + if tag == "": + raise ValueError("Cannot automatically create PG with empty tag") + # TODO copy settings and timeout from default PG + return _new_group_with_tag(my_ranks, pg_tag=tag) + + +def _get_group_tag(pg: ProcessGroup) -> str: + """Return the tag associated with ``pg``.""" + tag = _world.pg_to_tag[pg] + if tag.startswith("user:"): + tag = tag[5:] + return tag + + +def _get_process_group_name(pg: ProcessGroup) -> str: + return _world.pg_names.get(pg, "None") + + +def _get_process_group_store(pg: ProcessGroup) -> Store: + return _world.pg_map[pg][1] + + +# These ops are not friendly to TorchDynamo, so we disallow them in the FX graph, +# letting them run eagerly under torch.compile.
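+# For example, a tracing frontend that encounters a call such as
+# dist.all_reduce(t) can treat it as a graph break and run the call eagerly;
+# the list below enumerates the c10d entry points handled that way.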
+dynamo_unsupported_distributed_c10d_ops = [ + recv, + all_gather_object, + all_gather_coalesced, + all_to_all_single, + all_reduce, + gather_object, + all_to_all, + all_reduce_coalesced, + gather, + send_object_list, + recv_object_list, + broadcast_object_list, + barrier, + scatter, + scatter_object_list, + reduce, + all_gather, + reduce_scatter, + all_gather_into_tensor, + broadcast, + reduce_scatter_tensor, + send, +] diff --git a/lib/python3.10/site-packages/torch/distributed/launch.py b/lib/python3.10/site-packages/torch/distributed/launch.py new file mode 100644 index 0000000000000000000000000000000000000000..a9e35c36db7fbc585351128b9bbd712924e809af --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributed/launch.py @@ -0,0 +1,208 @@ +# mypy: allow-untyped-defs +r""" +Module ``torch.distributed.launch``. + +``torch.distributed.launch`` is a module that spawns up multiple distributed +training processes on each of the training nodes. + +.. warning:: + + This module is going to be deprecated in favor of :ref:`torchrun `. + +The utility can be used for single-node distributed training, in which one or +more processes per node will be spawned. The utility can be used for either +CPU training or GPU training. If the utility is used for GPU training, +each distributed process will be operating on a single GPU. This can achieve +well-improved single-node training performance. It can also be used in +multi-node distributed training, by spawning up multiple processes on each node +for well-improved multi-node distributed training performance as well. +This will especially be beneficial for systems with multiple Infiniband +interfaces that have direct-GPU support, since all of them can be utilized for +aggregated communication bandwidth. + +In both cases of single-node distributed training or multi-node distributed +training, this utility will launch the given number of processes per node +(``--nproc-per-node``). If used for GPU training, this number needs to be less +or equal to the number of GPUs on the current system (``nproc_per_node``), +and each process will be operating on a single GPU from *GPU 0 to +GPU (nproc_per_node - 1)*. + +**How to use this module:** + +1. Single-Node multi-process distributed training + +:: + + python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE + YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other + arguments of your training script) + +2. Multi-Node multi-process distributed training: (e.g. two nodes) + + +Node 1: *(IP: 192.168.1.1, and has a free port: 1234)* + +:: + + python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE + --nnodes=2 --node-rank=0 --master-addr="192.168.1.1" + --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 + and all other arguments of your training script) + +Node 2: + +:: + + python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE + --nnodes=2 --node-rank=1 --master-addr="192.168.1.1" + --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 + and all other arguments of your training script) + +3. To look up what optional arguments this module offers: + +:: + + python -m torch.distributed.launch --help + + +**Important Notices:** + +1. This utility and multi-process distributed (single-node or +multi-node) GPU training currently only achieves the best performance using +the NCCL distributed backend. Thus NCCL backend is the recommended backend to +use for GPU training. + +2. 
In your training program, you must parse the command-line argument: +``--local-rank=LOCAL_PROCESS_RANK``, which will be provided by this module. +If your training program uses GPUs, you should ensure that your code only +runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by: + +Parsing the local_rank argument + +:: + + >>> # xdoctest: +SKIP + >>> import argparse + >>> parser = argparse.ArgumentParser() + >>> parser.add_argument("--local-rank", "--local_rank", type=int) + >>> args = parser.parse_args() + +Set your device to the local rank using either + +:: + + >>> torch.cuda.set_device(args.local_rank) # before your code runs + +or + +:: + + >>> with torch.cuda.device(args.local_rank): + >>> # your code to run + >>> ... + +.. versionchanged:: 2.0.0 + + The launcher will pass the ``--local-rank=`` argument to your script. + From PyTorch 2.0.0 onwards, the dashed ``--local-rank`` is preferred over the + previously used underscored ``--local_rank``. + + For backward compatibility, it may be necessary for users to handle both + cases in their argument parsing code. This means including both ``"--local-rank"`` + and ``"--local_rank"`` in the argument parser. If only ``"--local_rank"`` is + provided, the launcher will trigger an error: "error: unrecognized arguments: + --local-rank=". For training code that only supports PyTorch 2.0.0+, + including ``"--local-rank"`` should be sufficient. + +3. In your training program, you are supposed to call the following function +at the beginning to start the distributed backend. It is strongly recommended +that ``init_method=env://``. Other init methods (e.g. ``tcp://``) may work, +but ``env://`` is the one that is officially supported by this module. + +:: + + >>> torch.distributed.init_process_group(backend='YOUR BACKEND', + >>> init_method='env://') + +4. In your training program, you can either use regular distributed functions +or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your +training program uses GPUs for training and you would like to use +:func:`torch.nn.parallel.DistributedDataParallel` module, +here is how to configure it. + +:: + + >>> model = torch.nn.parallel.DistributedDataParallel(model, + >>> device_ids=[args.local_rank], + >>> output_device=args.local_rank) + +Please ensure that the ``device_ids`` argument is set to be the only GPU device id +that your code will be operating on. This is generally the local rank of the +process. In other words, the ``device_ids`` needs to be ``[args.local_rank]``, +and ``output_device`` needs to be ``args.local_rank`` in order to use this +utility. + +5. Another way to pass ``local_rank`` to the subprocesses is via the environment variable +``LOCAL_RANK``. This behavior is enabled when you launch the script with +``--use-env=True``. You must adjust the subprocess example above to replace +``args.local_rank`` with ``os.environ['LOCAL_RANK']``; the launcher +will not pass ``--local-rank`` when you specify this flag. + +.. warning:: + + ``local_rank`` is NOT globally unique: it is only unique per process + on a machine. Thus, don't use it to decide if you should, e.g., + write to a networked filesystem. See + https://github.com/pytorch/pytorch/issues/12042 for an example of + how things can go wrong if you don't do this correctly.
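+
+A minimal sketch of the ``--use-env`` flow from point 5 above (illustrative only;
+it reads the local rank from the environment instead of parsing ``--local-rank``):
+
+::
+
+    >>> # xdoctest: +SKIP
+    >>> import os
+    >>> import torch
+    >>> local_rank = int(os.environ["LOCAL_RANK"])
+    >>> torch.cuda.set_device(local_rank)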
+ + + +""" + +from typing_extensions import deprecated as _deprecated + +from torch.distributed.run import get_args_parser, run + + +def parse_args(args): + parser = get_args_parser() + parser.add_argument( + "--use-env", + "--use_env", + default=False, + action="store_true", + help="Use environment variable to pass " + "'local rank'. For legacy reasons, the default value is False. " + "If set to True, the script will not pass " + "--local-rank as argument, and will instead set LOCAL_RANK.", + ) + return parser.parse_args(args) + + +def launch(args): + if args.no_python and not args.use_env: + raise ValueError( + "When using the '--no-python' flag," + " you must also set the '--use-env' flag." + ) + run(args) + + +@_deprecated( + "The module torch.distributed.launch is deprecated\n" + "and will be removed in future. Use torchrun.\n" + "Note that --use-env is set by default in torchrun.\n" + "If your script expects `--local-rank` argument to be set, please\n" + "change it to read from `os.environ['LOCAL_RANK']` instead. See \n" + "https://pytorch.org/docs/stable/distributed.html#launch-utility for \n" + "further instructions\n", + category=FutureWarning, +) +def main(args=None): + args = parse_args(args) + launch(args) + + +if __name__ == "__main__": + main() diff --git a/lib/python3.10/site-packages/torch/distributed/logging_handlers.py b/lib/python3.10/site-packages/torch/distributed/logging_handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..021ad100f06a89fa944d7e6dba18f1a9558d88f5 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributed/logging_handlers.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Dict, List + + +__all__: List[str] = [] + +_log_handlers: Dict[str, logging.Handler] = { + "default": logging.NullHandler(), +} diff --git a/lib/python3.10/site-packages/torch/distributed/remote_device.py b/lib/python3.10/site-packages/torch/distributed/remote_device.py new file mode 100644 index 0000000000000000000000000000000000000000..ab5215e2f83a71e56c6f638bd1453341da671d37 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributed/remote_device.py @@ -0,0 +1,120 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import torch + + +class _remote_device: + """ + Represents a device on a remote worker. + + Args: + remote_device (str or torch.device): Represents a device on a remote worker. + The string format should be one of the following: + + 1. "<workername>/<device>", where the device field can be parsed as torch.device type. + E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0". + In addition, the device field can be optional and the default value is "cpu". + 2. "rank:<rank>/<device>", where <rank> is the rank of the + process and device can be parsed as torch.device type. + E.g., "rank:0/cpu", "rank:0", "rank:0/cuda:0" + 3. <workername> and <device> are optional and formats like "cpu" + and "cuda:1", just represent local devices. + """ + + def __init__(self, remote_device: Union[str, torch.device]): + PARSE_ERROR = ( + f"Could not parse remote_device: {remote_device}.
The valid format is " + "'/' or 'rank:/' or ''" + ) + self._worker_name = None + self._rank = None + self._device: Optional[Union[str, int, torch.device]] = None + + if isinstance(remote_device, torch.device): + self._device = remote_device + elif isinstance(remote_device, str): + fields = remote_device.split("/") + if len(fields) == 2: + self._worker_name, self._device = fields + elif len(fields) == 1: + # Check if this is a valid device. + if _remote_device._is_valid_local_device(fields[0]): + self._device = fields[0] + else: + self._worker_name = fields[0] + self._device = "cpu" + else: + raise ValueError(PARSE_ERROR) + else: + raise TypeError(f"Invalid type for remote_device: {type(remote_device)}") + + # Do some basic sanity check (no empty string) + if self._worker_name is not None and not self._worker_name: + raise ValueError(PARSE_ERROR) + + # Validate the device. + self._device = torch.device(self._device) + + # Check for rank based format. + if self._worker_name is not None: + fields = self._worker_name.split(":") + if len(fields) == 2: + # rank:/device format, extract rank + if fields[0] == "rank" and fields[1].isdigit(): + self._rank = int(fields[1]) # type: ignore[assignment] + self._worker_name = None + else: + raise ValueError(PARSE_ERROR) + elif len(fields) > 2: + raise ValueError(PARSE_ERROR) + + @staticmethod + def _is_valid_local_device(device): + # Check for torch.device + try: + torch.device(device) + return True + except Exception: + return False + + def worker_name(self) -> Optional[str]: + """Return the name of remote worker representing the remote device and ``None`` if no worker name is available.""" + return self._worker_name + + def rank(self) -> Optional[int]: + """ + Returns the rank of remote worker representing the remote device. + Returns ``None`` if no rank is available. + """ + return self._rank + + def device(self) -> torch.device: + """Return the local device on the remote worker.""" + return self._device # type: ignore[return-value] + + def __repr__(self): + if self._device is not None: + if self._worker_name is not None: + return f"{self._worker_name}/{self._device}" + elif self._rank is not None: + return f"rank:{self._rank}/{self._device}" + else: + return str(self._device) + else: + if self._worker_name is not None: + return f"{self._worker_name}" + elif self._rank is not None: + return f"{self._rank}" + else: + raise RuntimeError("Invalid state!") + + def __eq__(self, other): + return isinstance(other, _remote_device) and ( + self._worker_name == other._worker_name + and self._device == other._device + and self._rank == other._rank + ) + + def __hash__(self): + return hash(self._worker_name) ^ hash(self._device) ^ hash(self._rank) diff --git a/lib/python3.10/site-packages/torch/distributions/__init__.py b/lib/python3.10/site-packages/torch/distributions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fbc2775468d540113beae24a21219778e55f4dad --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/__init__.py @@ -0,0 +1,172 @@ +r""" +The ``distributions`` package contains parameterizable probability distributions +and sampling functions. This allows the construction of stochastic computation +graphs and stochastic gradient estimators for optimization. This package +generally follows the design of the `TensorFlow Distributions`_ package. + +.. _`TensorFlow Distributions`: + https://arxiv.org/abs/1711.10604 + +It is not possible to directly backpropagate through random samples. 
However, +there are two main methods for creating surrogate functions that can be +backpropagated through. These are the score function estimator/likelihood ratio +estimator/REINFORCE and the pathwise derivative estimator. REINFORCE is commonly +seen as the basis for policy gradient methods in reinforcement learning, and the +pathwise derivative estimator is commonly seen in the reparameterization trick +in variational autoencoders. Whilst the score function only requires the value +of samples :math:`f(x)`, the pathwise derivative requires the derivative +:math:`f'(x)`. The next sections discuss these two in a reinforcement learning +example. For more details see +`Gradient Estimation Using Stochastic Computation Graphs`_ . + +.. _`Gradient Estimation Using Stochastic Computation Graphs`: + https://arxiv.org/abs/1506.05254 + +Score function +^^^^^^^^^^^^^^ + +When the probability density function is differentiable with respect to its +parameters, we only need :meth:`~torch.distributions.Distribution.sample` and +:meth:`~torch.distributions.Distribution.log_prob` to implement REINFORCE: + +.. math:: + + \Delta\theta = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta} + +where :math:`\theta` are the parameters, :math:`\alpha` is the learning rate, +:math:`r` is the reward and :math:`p(a|\pi^\theta(s))` is the probability of +taking action :math:`a` in state :math:`s` given policy :math:`\pi^\theta`. + +In practice we would sample an action from the output of a network, apply this +action in an environment, and then use ``log_prob`` to construct an equivalent +loss function. Note that we use a negative because optimizers use gradient +descent, whilst the rule above assumes gradient ascent. With a categorical +policy, the code for implementing REINFORCE would be as follows:: + + probs = policy_network(state) + # Note that this is equivalent to what used to be called multinomial + m = Categorical(probs) + action = m.sample() + next_state, reward = env.step(action) + loss = -m.log_prob(action) * reward + loss.backward() + +Pathwise derivative +^^^^^^^^^^^^^^^^^^^ + +The other way to implement these stochastic/policy gradients would be to use the +reparameterization trick from the +:meth:`~torch.distributions.Distribution.rsample` method, where the +parameterized random variable can be constructed via a parameterized +deterministic function of a parameter-free random variable. The reparameterized +sample therefore becomes differentiable. The code for implementing the pathwise +derivative would be as follows:: + + params = policy_network(state) + m = Normal(*params) + # Any distribution with .has_rsample == True could work based on the application + action = m.rsample() + next_state, reward = env.step(action) # Assuming that reward is differentiable + loss = -reward + loss.backward() +""" + +from . 
import transforms +from .bernoulli import Bernoulli +from .beta import Beta +from .binomial import Binomial +from .categorical import Categorical +from .cauchy import Cauchy +from .chi2 import Chi2 +from .constraint_registry import biject_to, transform_to +from .continuous_bernoulli import ContinuousBernoulli +from .dirichlet import Dirichlet +from .distribution import Distribution +from .exp_family import ExponentialFamily +from .exponential import Exponential +from .fishersnedecor import FisherSnedecor +from .gamma import Gamma +from .geometric import Geometric +from .gumbel import Gumbel +from .half_cauchy import HalfCauchy +from .half_normal import HalfNormal +from .independent import Independent +from .inverse_gamma import InverseGamma +from .kl import _add_kl_info, kl_divergence, register_kl +from .kumaraswamy import Kumaraswamy +from .laplace import Laplace +from .lkj_cholesky import LKJCholesky +from .log_normal import LogNormal +from .logistic_normal import LogisticNormal +from .lowrank_multivariate_normal import LowRankMultivariateNormal +from .mixture_same_family import MixtureSameFamily +from .multinomial import Multinomial +from .multivariate_normal import MultivariateNormal +from .negative_binomial import NegativeBinomial +from .normal import Normal +from .one_hot_categorical import OneHotCategorical, OneHotCategoricalStraightThrough +from .pareto import Pareto +from .poisson import Poisson +from .relaxed_bernoulli import RelaxedBernoulli +from .relaxed_categorical import RelaxedOneHotCategorical +from .studentT import StudentT +from .transformed_distribution import TransformedDistribution +from .transforms import * # noqa: F403 +from .uniform import Uniform +from .von_mises import VonMises +from .weibull import Weibull +from .wishart import Wishart + + +_add_kl_info() +del _add_kl_info + +__all__ = [ + "Bernoulli", + "Beta", + "Binomial", + "Categorical", + "Cauchy", + "Chi2", + "ContinuousBernoulli", + "Dirichlet", + "Distribution", + "Exponential", + "ExponentialFamily", + "FisherSnedecor", + "Gamma", + "Geometric", + "Gumbel", + "HalfCauchy", + "HalfNormal", + "Independent", + "InverseGamma", + "Kumaraswamy", + "LKJCholesky", + "Laplace", + "LogNormal", + "LogisticNormal", + "LowRankMultivariateNormal", + "MixtureSameFamily", + "Multinomial", + "MultivariateNormal", + "NegativeBinomial", + "Normal", + "OneHotCategorical", + "OneHotCategoricalStraightThrough", + "Pareto", + "RelaxedBernoulli", + "RelaxedOneHotCategorical", + "StudentT", + "Poisson", + "Uniform", + "VonMises", + "Weibull", + "Wishart", + "TransformedDistribution", + "biject_to", + "kl_divergence", + "register_kl", + "transform_to", +] +__all__.extend(transforms.__all__) diff --git a/lib/python3.10/site-packages/torch/distributions/bernoulli.py b/lib/python3.10/site-packages/torch/distributions/bernoulli.py new file mode 100644 index 0000000000000000000000000000000000000000..8bfcb500b4eb505ffcb43aad9fb124881935e71c --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/bernoulli.py @@ -0,0 +1,132 @@ +# mypy: allow-untyped-defs +from numbers import Number + +import torch +from torch import nan +from torch.distributions import constraints +from torch.distributions.exp_family import ExponentialFamily +from torch.distributions.utils import ( + broadcast_all, + lazy_property, + logits_to_probs, + probs_to_logits, +) +from torch.nn.functional import binary_cross_entropy_with_logits + + +__all__ = ["Bernoulli"] + + +class Bernoulli(ExponentialFamily): + r""" + Creates a Bernoulli distribution 
parameterized by :attr:`probs` + or :attr:`logits` (but not both). + + Samples are binary (0 or 1). They take the value `1` with probability `p` + and `0` with probability `1 - p`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Bernoulli(torch.tensor([0.3])) + >>> m.sample() # 30% chance 1; 70% chance 0 + tensor([ 0.]) + + Args: + probs (Number, Tensor): the probability of sampling `1` + logits (Number, Tensor): the log-odds of sampling `1` + """ + arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} + support = constraints.boolean + has_enumerate_support = True + _mean_carrier_measure = 0 + + def __init__(self, probs=None, logits=None, validate_args=None): + if (probs is None) == (logits is None): + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) + if probs is not None: + is_scalar = isinstance(probs, Number) + (self.probs,) = broadcast_all(probs) + else: + is_scalar = isinstance(logits, Number) + (self.logits,) = broadcast_all(logits) + self._param = self.probs if probs is not None else self.logits + if is_scalar: + batch_shape = torch.Size() + else: + batch_shape = self._param.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Bernoulli, _instance) + batch_shape = torch.Size(batch_shape) + if "probs" in self.__dict__: + new.probs = self.probs.expand(batch_shape) + new._param = new.probs + if "logits" in self.__dict__: + new.logits = self.logits.expand(batch_shape) + new._param = new.logits + super(Bernoulli, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._param.new(*args, **kwargs) + + @property + def mean(self): + return self.probs + + @property + def mode(self): + mode = (self.probs >= 0.5).to(self.probs) + mode[self.probs == 0.5] = nan + return mode + + @property + def variance(self): + return self.probs * (1 - self.probs) + + @lazy_property + def logits(self): + return probs_to_logits(self.probs, is_binary=True) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits, is_binary=True) + + @property + def param_shape(self): + return self._param.size() + + def sample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + with torch.no_grad(): + return torch.bernoulli(self.probs.expand(shape)) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + logits, value = broadcast_all(self.logits, value) + return -binary_cross_entropy_with_logits(logits, value, reduction="none") + + def entropy(self): + return binary_cross_entropy_with_logits( + self.logits, self.probs, reduction="none" + ) + + def enumerate_support(self, expand=True): + values = torch.arange(2, dtype=self._param.dtype, device=self._param.device) + values = values.view((-1,) + (1,) * len(self._batch_shape)) + if expand: + values = values.expand((-1,) + self._batch_shape) + return values + + @property + def _natural_params(self): + return (torch.logit(self.probs),) + + def _log_normalizer(self, x): + return torch.log1p(torch.exp(x)) diff --git a/lib/python3.10/site-packages/torch/distributions/beta.py b/lib/python3.10/site-packages/torch/distributions/beta.py new file mode 100644 index 0000000000000000000000000000000000000000..f660d80326e3a9f8e8fd648de5df21f57a9f398d --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/beta.py @@ 
-0,0 +1,110 @@ +# mypy: allow-untyped-defs +from numbers import Number, Real + +import torch +from torch.distributions import constraints +from torch.distributions.dirichlet import Dirichlet +from torch.distributions.exp_family import ExponentialFamily +from torch.distributions.utils import broadcast_all +from torch.types import _size + + +__all__ = ["Beta"] + + +class Beta(ExponentialFamily): + r""" + Beta distribution parameterized by :attr:`concentration1` and :attr:`concentration0`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5])) + >>> m.sample() # Beta distributed with concentration concentration1 and concentration0 + tensor([ 0.1046]) + + Args: + concentration1 (float or Tensor): 1st concentration parameter of the distribution + (often referred to as alpha) + concentration0 (float or Tensor): 2nd concentration parameter of the distribution + (often referred to as beta) + """ + arg_constraints = { + "concentration1": constraints.positive, + "concentration0": constraints.positive, + } + support = constraints.unit_interval + has_rsample = True + + def __init__(self, concentration1, concentration0, validate_args=None): + if isinstance(concentration1, Real) and isinstance(concentration0, Real): + concentration1_concentration0 = torch.tensor( + [float(concentration1), float(concentration0)] + ) + else: + concentration1, concentration0 = broadcast_all( + concentration1, concentration0 + ) + concentration1_concentration0 = torch.stack( + [concentration1, concentration0], -1 + ) + self._dirichlet = Dirichlet( + concentration1_concentration0, validate_args=validate_args + ) + super().__init__(self._dirichlet._batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Beta, _instance) + batch_shape = torch.Size(batch_shape) + new._dirichlet = self._dirichlet.expand(batch_shape) + super(Beta, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + @property + def mean(self): + return self.concentration1 / (self.concentration1 + self.concentration0) + + @property + def mode(self): + return self._dirichlet.mode[..., 0] + + @property + def variance(self): + total = self.concentration1 + self.concentration0 + return self.concentration1 * self.concentration0 / (total.pow(2) * (total + 1)) + + def rsample(self, sample_shape: _size = ()) -> torch.Tensor: + return self._dirichlet.rsample(sample_shape).select(-1, 0) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + heads_tails = torch.stack([value, 1.0 - value], -1) + return self._dirichlet.log_prob(heads_tails) + + def entropy(self): + return self._dirichlet.entropy() + + @property + def concentration1(self): + result = self._dirichlet.concentration[..., 0] + if isinstance(result, Number): + return torch.tensor([result]) + else: + return result + + @property + def concentration0(self): + result = self._dirichlet.concentration[..., 1] + if isinstance(result, Number): + return torch.tensor([result]) + else: + return result + + @property + def _natural_params(self): + return (self.concentration1, self.concentration0) + + def _log_normalizer(self, x, y): + return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y) diff --git a/lib/python3.10/site-packages/torch/distributions/binomial.py b/lib/python3.10/site-packages/torch/distributions/binomial.py new file mode 100644 index 
0000000000000000000000000000000000000000..18b267ea27fd9b93dbda2b9a809c3a135ca3f1f0 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/binomial.py @@ -0,0 +1,167 @@ +# mypy: allow-untyped-defs +import torch +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import ( + broadcast_all, + lazy_property, + logits_to_probs, + probs_to_logits, +) + + +__all__ = ["Binomial"] + + +def _clamp_by_zero(x): + # works like clamp(x, min=0) but has grad at 0 is 0.5 + return (x.clamp(min=0) + x - x.clamp(max=0)) / 2 + + +class Binomial(Distribution): + r""" + Creates a Binomial distribution parameterized by :attr:`total_count` and + either :attr:`probs` or :attr:`logits` (but not both). :attr:`total_count` must be + broadcastable with :attr:`probs`/:attr:`logits`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Binomial(100, torch.tensor([0 , .2, .8, 1])) + >>> x = m.sample() + tensor([ 0., 22., 71., 100.]) + + >>> m = Binomial(torch.tensor([[5.], [10.]]), torch.tensor([0.5, 0.8])) + >>> x = m.sample() + tensor([[ 4., 5.], + [ 7., 6.]]) + + Args: + total_count (int or Tensor): number of Bernoulli trials + probs (Tensor): Event probabilities + logits (Tensor): Event log-odds + """ + arg_constraints = { + "total_count": constraints.nonnegative_integer, + "probs": constraints.unit_interval, + "logits": constraints.real, + } + has_enumerate_support = True + + def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): + if (probs is None) == (logits is None): + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) + if probs is not None: + ( + self.total_count, + self.probs, + ) = broadcast_all(total_count, probs) + self.total_count = self.total_count.type_as(self.probs) + else: + ( + self.total_count, + self.logits, + ) = broadcast_all(total_count, logits) + self.total_count = self.total_count.type_as(self.logits) + + self._param = self.probs if probs is not None else self.logits + batch_shape = self._param.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Binomial, _instance) + batch_shape = torch.Size(batch_shape) + new.total_count = self.total_count.expand(batch_shape) + if "probs" in self.__dict__: + new.probs = self.probs.expand(batch_shape) + new._param = new.probs + if "logits" in self.__dict__: + new.logits = self.logits.expand(batch_shape) + new._param = new.logits + super(Binomial, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._param.new(*args, **kwargs) + + @constraints.dependent_property(is_discrete=True, event_dim=0) + def support(self): + return constraints.integer_interval(0, self.total_count) + + @property + def mean(self): + return self.total_count * self.probs + + @property + def mode(self): + return ((self.total_count + 1) * self.probs).floor().clamp(max=self.total_count) + + @property + def variance(self): + return self.total_count * self.probs * (1 - self.probs) + + @lazy_property + def logits(self): + return probs_to_logits(self.probs, is_binary=True) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits, is_binary=True) + + @property + def param_shape(self): + return self._param.size() + + def sample(self, sample_shape=torch.Size()): + shape = 
self._extended_shape(sample_shape) + with torch.no_grad(): + return torch.binomial( + self.total_count.expand(shape), self.probs.expand(shape) + ) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + log_factorial_n = torch.lgamma(self.total_count + 1) + log_factorial_k = torch.lgamma(value + 1) + log_factorial_nmk = torch.lgamma(self.total_count - value + 1) + # k * log(p) + (n - k) * log(1 - p) = k * (log(p) - log(1 - p)) + n * log(1 - p) + # (case logit < 0) = k * logit - n * log1p(e^logit) + # (case logit > 0) = k * logit - n * (log(p) - log(1 - p)) + n * log(p) + # = k * logit - n * logit - n * log1p(e^-logit) + # (merge two cases) = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|) + normalize_term = ( + self.total_count * _clamp_by_zero(self.logits) + + self.total_count * torch.log1p(torch.exp(-torch.abs(self.logits))) + - log_factorial_n + ) + return ( + value * self.logits - log_factorial_k - log_factorial_nmk - normalize_term + ) + + def entropy(self): + total_count = int(self.total_count.max()) + if not self.total_count.min() == total_count: + raise NotImplementedError( + "Inhomogeneous total count not supported by `entropy`." + ) + + log_prob = self.log_prob(self.enumerate_support(False)) + return -(torch.exp(log_prob) * log_prob).sum(0) + + def enumerate_support(self, expand=True): + total_count = int(self.total_count.max()) + if not self.total_count.min() == total_count: + raise NotImplementedError( + "Inhomogeneous total count not supported by `enumerate_support`." + ) + values = torch.arange( + 1 + total_count, dtype=self._param.dtype, device=self._param.device + ) + values = values.view((-1,) + (1,) * len(self._batch_shape)) + if expand: + values = values.expand((-1,) + self._batch_shape) + return values diff --git a/lib/python3.10/site-packages/torch/distributions/categorical.py b/lib/python3.10/site-packages/torch/distributions/categorical.py new file mode 100644 index 0000000000000000000000000000000000000000..717cf74ba7e56831a174ab4fc68bc8d23d554e84 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/categorical.py @@ -0,0 +1,157 @@ +# mypy: allow-untyped-defs +import torch +from torch import nan +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import lazy_property, logits_to_probs, probs_to_logits + + +__all__ = ["Categorical"] + + +class Categorical(Distribution): + r""" + Creates a categorical distribution parameterized by either :attr:`probs` or + :attr:`logits` (but not both). + + .. note:: + It is equivalent to the distribution that :func:`torch.multinomial` + samples from. + + Samples are integers from :math:`\{0, \ldots, K-1\}` where `K` is ``probs.size(-1)``. + + If `probs` is 1-dimensional with length-`K`, each element is the relative probability + of sampling the class at that index. + + If `probs` is N-dimensional, the first N-1 dimensions are treated as a batch of + relative probability vectors. + + .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum, + and it will be normalized to sum to 1 along the last dimension. :attr:`probs` + will return this normalized value. + The `logits` argument will be interpreted as unnormalized log probabilities + and can therefore be any real number. It will likewise be normalized so that + the resulting probabilities sum to 1 along the last dimension. :attr:`logits` + will return this normalized value. 
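+
+ For instance, a small illustrative sketch of the normalization described in the
+ note above::
+
+ >>> # xdoctest: +SKIP
+ >>> m = Categorical(torch.tensor([1.0, 1.0, 2.0]))
+ >>> m.probs # probs are normalized to sum to 1 along the last dimension
+ tensor([0.2500, 0.2500, 0.5000])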
+ + See also: :func:`torch.multinomial` + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Categorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ])) + >>> m.sample() # equal probability of 0, 1, 2, 3 + tensor(3) + + Args: + probs (Tensor): event probabilities + logits (Tensor): event log probabilities (unnormalized) + """ + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} + has_enumerate_support = True + + def __init__(self, probs=None, logits=None, validate_args=None): + if (probs is None) == (logits is None): + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) + if probs is not None: + if probs.dim() < 1: + raise ValueError("`probs` parameter must be at least one-dimensional.") + self.probs = probs / probs.sum(-1, keepdim=True) + else: + if logits.dim() < 1: + raise ValueError("`logits` parameter must be at least one-dimensional.") + # Normalize + self.logits = logits - logits.logsumexp(dim=-1, keepdim=True) + self._param = self.probs if probs is not None else self.logits + self._num_events = self._param.size()[-1] + batch_shape = ( + self._param.size()[:-1] if self._param.ndimension() > 1 else torch.Size() + ) + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Categorical, _instance) + batch_shape = torch.Size(batch_shape) + param_shape = batch_shape + torch.Size((self._num_events,)) + if "probs" in self.__dict__: + new.probs = self.probs.expand(param_shape) + new._param = new.probs + if "logits" in self.__dict__: + new.logits = self.logits.expand(param_shape) + new._param = new.logits + new._num_events = self._num_events + super(Categorical, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._param.new(*args, **kwargs) + + @constraints.dependent_property(is_discrete=True, event_dim=0) + def support(self): + return constraints.integer_interval(0, self._num_events - 1) + + @lazy_property + def logits(self): + return probs_to_logits(self.probs) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits) + + @property + def param_shape(self): + return self._param.size() + + @property + def mean(self): + return torch.full( + self._extended_shape(), + nan, + dtype=self.probs.dtype, + device=self.probs.device, + ) + + @property + def mode(self): + return self.probs.argmax(axis=-1) + + @property + def variance(self): + return torch.full( + self._extended_shape(), + nan, + dtype=self.probs.dtype, + device=self.probs.device, + ) + + def sample(self, sample_shape=torch.Size()): + if not isinstance(sample_shape, torch.Size): + sample_shape = torch.Size(sample_shape) + probs_2d = self.probs.reshape(-1, self._num_events) + samples_2d = torch.multinomial(probs_2d, sample_shape.numel(), True).T + return samples_2d.reshape(self._extended_shape(sample_shape)) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + value = value.long().unsqueeze(-1) + value, log_pmf = torch.broadcast_tensors(value, self.logits) + value = value[..., :1] + return log_pmf.gather(-1, value).squeeze(-1) + + def entropy(self): + min_real = torch.finfo(self.logits.dtype).min + logits = torch.clamp(self.logits, min=min_real) + p_log_p = logits * self.probs + return -p_log_p.sum(-1) + + def enumerate_support(self, expand=True): + num_events = self._num_events + values = 
torch.arange(num_events, dtype=torch.long, device=self._param.device) + values = values.view((-1,) + (1,) * len(self._batch_shape)) + if expand: + values = values.expand((-1,) + self._batch_shape) + return values diff --git a/lib/python3.10/site-packages/torch/distributions/cauchy.py b/lib/python3.10/site-packages/torch/distributions/cauchy.py new file mode 100644 index 0000000000000000000000000000000000000000..436cc727baa10b9f95d2eaf0647bce5032709afc --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/cauchy.py @@ -0,0 +1,93 @@ +# mypy: allow-untyped-defs +import math +from numbers import Number + +import torch +from torch import inf, nan +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import broadcast_all +from torch.types import _size + + +__all__ = ["Cauchy"] + + +class Cauchy(Distribution): + r""" + Samples from a Cauchy (Lorentz) distribution. The distribution of the ratio of + independent normally distributed random variables with means `0` follows a + Cauchy distribution. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Cauchy(torch.tensor([0.0]), torch.tensor([1.0])) + >>> m.sample() # sample from a Cauchy distribution with loc=0 and scale=1 + tensor([ 2.3214]) + + Args: + loc (float or Tensor): mode or median of the distribution. + scale (float or Tensor): half width at half maximum. + """ + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} + support = constraints.real + has_rsample = True + + def __init__(self, loc, scale, validate_args=None): + self.loc, self.scale = broadcast_all(loc, scale) + if isinstance(loc, Number) and isinstance(scale, Number): + batch_shape = torch.Size() + else: + batch_shape = self.loc.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Cauchy, _instance) + batch_shape = torch.Size(batch_shape) + new.loc = self.loc.expand(batch_shape) + new.scale = self.scale.expand(batch_shape) + super(Cauchy, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + @property + def mean(self): + return torch.full( + self._extended_shape(), nan, dtype=self.loc.dtype, device=self.loc.device + ) + + @property + def mode(self): + return self.loc + + @property + def variance(self): + return torch.full( + self._extended_shape(), inf, dtype=self.loc.dtype, device=self.loc.device + ) + + def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: + shape = self._extended_shape(sample_shape) + eps = self.loc.new(shape).cauchy_() + return self.loc + eps * self.scale + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + return ( + -math.log(math.pi) + - self.scale.log() + - (((value - self.loc) / self.scale) ** 2).log1p() + ) + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + return torch.atan((value - self.loc) / self.scale) / math.pi + 0.5 + + def icdf(self, value): + return torch.tan(math.pi * (value - 0.5)) * self.scale + self.loc + + def entropy(self): + return math.log(4 * math.pi) + self.scale.log() diff --git a/lib/python3.10/site-packages/torch/distributions/chi2.py b/lib/python3.10/site-packages/torch/distributions/chi2.py new file mode 100644 index 0000000000000000000000000000000000000000..a44035b897edc133a928c41945c823203cbb30f8 --- /dev/null +++ 
b/lib/python3.10/site-packages/torch/distributions/chi2.py @@ -0,0 +1,35 @@ +# mypy: allow-untyped-defs +from torch.distributions import constraints +from torch.distributions.gamma import Gamma + + +__all__ = ["Chi2"] + + +class Chi2(Gamma): + r""" + Creates a Chi-squared distribution parameterized by shape parameter :attr:`df`. + This is exactly equivalent to ``Gamma(alpha=0.5*df, beta=0.5)`` + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Chi2(torch.tensor([1.0])) + >>> m.sample() # Chi2 distributed with shape df=1 + tensor([ 0.1046]) + + Args: + df (float or Tensor): shape parameter of the distribution + """ + arg_constraints = {"df": constraints.positive} + + def __init__(self, df, validate_args=None): + super().__init__(0.5 * df, 0.5, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Chi2, _instance) + return super().expand(batch_shape, new) + + @property + def df(self): + return self.concentration * 2 diff --git a/lib/python3.10/site-packages/torch/distributions/constraint_registry.py b/lib/python3.10/site-packages/torch/distributions/constraint_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..ce73b1a4df2eb1fe1ebf38865a061cdb99f14efc --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/constraint_registry.py @@ -0,0 +1,294 @@ +# mypy: allow-untyped-defs +r""" +PyTorch provides two global :class:`ConstraintRegistry` objects that link +:class:`~torch.distributions.constraints.Constraint` objects to +:class:`~torch.distributions.transforms.Transform` objects. These objects both +input constraints and return transforms, but they have different guarantees on +bijectivity. + +1. ``biject_to(constraint)`` looks up a bijective + :class:`~torch.distributions.transforms.Transform` from ``constraints.real`` + to the given ``constraint``. The returned transform is guaranteed to have + ``.bijective = True`` and should implement ``.log_abs_det_jacobian()``. +2. ``transform_to(constraint)`` looks up a not-necessarily bijective + :class:`~torch.distributions.transforms.Transform` from ``constraints.real`` + to the given ``constraint``. The returned transform is not guaranteed to + implement ``.log_abs_det_jacobian()``. + +The ``transform_to()`` registry is useful for performing unconstrained +optimization on constrained parameters of probability distributions, which are +indicated by each distribution's ``.arg_constraints`` dict. These transforms often +overparameterize a space in order to avoid rotation; they are thus more +suitable for coordinate-wise optimization algorithms like Adam:: + + loc = torch.zeros(100, requires_grad=True) + unconstrained = torch.zeros(100, requires_grad=True) + scale = transform_to(Normal.arg_constraints['scale'])(unconstrained) + loss = -Normal(loc, scale).log_prob(data).sum() + +The ``biject_to()`` registry is useful for Hamiltonian Monte Carlo, where +samples from a probability distribution with constrained ``.support`` are +propagated in an unconstrained space, and algorithms are typically rotation +invariant.:: + + dist = Exponential(rate) + unconstrained = torch.zeros(100, requires_grad=True) + sample = biject_to(dist.support)(unconstrained) + potential_energy = -dist.log_prob(sample).sum() + +.. 
note:: + + An example where ``transform_to`` and ``biject_to`` differ is + ``constraints.simplex``: ``transform_to(constraints.simplex)`` returns a + :class:`~torch.distributions.transforms.SoftmaxTransform` that simply + exponentiates and normalizes its inputs; this is a cheap and mostly + coordinate-wise operation appropriate for algorithms like SVI. In + contrast, ``biject_to(constraints.simplex)`` returns a + :class:`~torch.distributions.transforms.StickBreakingTransform` that + bijects its input down to a one-fewer-dimensional space; this a more + expensive less numerically stable transform but is needed for algorithms + like HMC. + +The ``biject_to`` and ``transform_to`` objects can be extended by user-defined +constraints and transforms using their ``.register()`` method either as a +function on singleton constraints:: + + transform_to.register(my_constraint, my_transform) + +or as a decorator on parameterized constraints:: + + @transform_to.register(MyConstraintClass) + def my_factory(constraint): + assert isinstance(constraint, MyConstraintClass) + return MyTransform(constraint.param1, constraint.param2) + +You can create your own registry by creating a new :class:`ConstraintRegistry` +object. +""" + +import numbers + +from torch.distributions import constraints, transforms + + +__all__ = [ + "ConstraintRegistry", + "biject_to", + "transform_to", +] + + +class ConstraintRegistry: + """ + Registry to link constraints to transforms. + """ + + def __init__(self): + self._registry = {} + super().__init__() + + def register(self, constraint, factory=None): + """ + Registers a :class:`~torch.distributions.constraints.Constraint` + subclass in this registry. Usage:: + + @my_registry.register(MyConstraintClass) + def construct_transform(constraint): + assert isinstance(constraint, MyConstraint) + return MyTransform(constraint.arg_constraints) + + Args: + constraint (subclass of :class:`~torch.distributions.constraints.Constraint`): + A subclass of :class:`~torch.distributions.constraints.Constraint`, or + a singleton object of the desired class. + factory (Callable): A callable that inputs a constraint object and returns + a :class:`~torch.distributions.transforms.Transform` object. + """ + # Support use as decorator. + if factory is None: + return lambda factory: self.register(constraint, factory) + + # Support calling on singleton instances. + if isinstance(constraint, constraints.Constraint): + constraint = type(constraint) + + if not isinstance(constraint, type) or not issubclass( + constraint, constraints.Constraint + ): + raise TypeError( + f"Expected constraint to be either a Constraint subclass or instance, but got {constraint}" + ) + + self._registry[constraint] = factory + return factory + + def __call__(self, constraint): + """ + Looks up a transform to constrained space, given a constraint object. + Usage:: + + constraint = Normal.arg_constraints['scale'] + scale = transform_to(constraint)(torch.zeros(1)) # constrained + u = transform_to(constraint).inv(scale) # unconstrained + + Args: + constraint (:class:`~torch.distributions.constraints.Constraint`): + A constraint object. + + Returns: + A :class:`~torch.distributions.transforms.Transform` object. + + Raises: + `NotImplementedError` if no transform has been registered. + """ + # Look up by Constraint subclass. 
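+        # (Factories are keyed by the Constraint subclass, so a singleton
+        # instance such as `constraints.positive` resolves through
+        # `type(constraint)`; unregistered constraint types fall through to
+        # the NotImplementedError below.)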
+ try: + factory = self._registry[type(constraint)] + except KeyError: + raise NotImplementedError( + f"Cannot transform {type(constraint).__name__} constraints" + ) from None + return factory(constraint) + + +biject_to = ConstraintRegistry() +transform_to = ConstraintRegistry() + + +################################################################################ +# Registration Table +################################################################################ + + +@biject_to.register(constraints.real) +@transform_to.register(constraints.real) +def _transform_to_real(constraint): + return transforms.identity_transform + + +@biject_to.register(constraints.independent) +def _biject_to_independent(constraint): + base_transform = biject_to(constraint.base_constraint) + return transforms.IndependentTransform( + base_transform, constraint.reinterpreted_batch_ndims + ) + + +@transform_to.register(constraints.independent) +def _transform_to_independent(constraint): + base_transform = transform_to(constraint.base_constraint) + return transforms.IndependentTransform( + base_transform, constraint.reinterpreted_batch_ndims + ) + + +@biject_to.register(constraints.positive) +@biject_to.register(constraints.nonnegative) +@transform_to.register(constraints.positive) +@transform_to.register(constraints.nonnegative) +def _transform_to_positive(constraint): + return transforms.ExpTransform() + + +@biject_to.register(constraints.greater_than) +@biject_to.register(constraints.greater_than_eq) +@transform_to.register(constraints.greater_than) +@transform_to.register(constraints.greater_than_eq) +def _transform_to_greater_than(constraint): + return transforms.ComposeTransform( + [ + transforms.ExpTransform(), + transforms.AffineTransform(constraint.lower_bound, 1), + ] + ) + + +@biject_to.register(constraints.less_than) +@transform_to.register(constraints.less_than) +def _transform_to_less_than(constraint): + return transforms.ComposeTransform( + [ + transforms.ExpTransform(), + transforms.AffineTransform(constraint.upper_bound, -1), + ] + ) + + +@biject_to.register(constraints.interval) +@biject_to.register(constraints.half_open_interval) +@transform_to.register(constraints.interval) +@transform_to.register(constraints.half_open_interval) +def _transform_to_interval(constraint): + # Handle the special case of the unit interval. 
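+    # If the bounds are literally the numbers 0 and 1, a bare SigmoidTransform
+    # already maps the reals onto the constraint; otherwise the sigmoid output
+    # is rescaled to [lower_bound, upper_bound] by the AffineTransform below.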
+ lower_is_0 = ( + isinstance(constraint.lower_bound, numbers.Number) + and constraint.lower_bound == 0 + ) + upper_is_1 = ( + isinstance(constraint.upper_bound, numbers.Number) + and constraint.upper_bound == 1 + ) + if lower_is_0 and upper_is_1: + return transforms.SigmoidTransform() + + loc = constraint.lower_bound + scale = constraint.upper_bound - constraint.lower_bound + return transforms.ComposeTransform( + [transforms.SigmoidTransform(), transforms.AffineTransform(loc, scale)] + ) + + +@biject_to.register(constraints.simplex) +def _biject_to_simplex(constraint): + return transforms.StickBreakingTransform() + + +@transform_to.register(constraints.simplex) +def _transform_to_simplex(constraint): + return transforms.SoftmaxTransform() + + +# TODO define a bijection for LowerCholeskyTransform +@transform_to.register(constraints.lower_cholesky) +def _transform_to_lower_cholesky(constraint): + return transforms.LowerCholeskyTransform() + + +@transform_to.register(constraints.positive_definite) +@transform_to.register(constraints.positive_semidefinite) +def _transform_to_positive_definite(constraint): + return transforms.PositiveDefiniteTransform() + + +@biject_to.register(constraints.corr_cholesky) +@transform_to.register(constraints.corr_cholesky) +def _transform_to_corr_cholesky(constraint): + return transforms.CorrCholeskyTransform() + + +@biject_to.register(constraints.cat) +def _biject_to_cat(constraint): + return transforms.CatTransform( + [biject_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths + ) + + +@transform_to.register(constraints.cat) +def _transform_to_cat(constraint): + return transforms.CatTransform( + [transform_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths + ) + + +@biject_to.register(constraints.stack) +def _biject_to_stack(constraint): + return transforms.StackTransform( + [biject_to(c) for c in constraint.cseq], constraint.dim + ) + + +@transform_to.register(constraints.stack) +def _transform_to_stack(constraint): + return transforms.StackTransform( + [transform_to(c) for c in constraint.cseq], constraint.dim + ) diff --git a/lib/python3.10/site-packages/torch/distributions/constraints.py b/lib/python3.10/site-packages/torch/distributions/constraints.py new file mode 100644 index 0000000000000000000000000000000000000000..3c510bd32abc62cbe3be05eeeb591965c3db7861 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/constraints.py @@ -0,0 +1,681 @@ +# mypy: allow-untyped-defs +r""" +The following constraints are implemented: + +- ``constraints.boolean`` +- ``constraints.cat`` +- ``constraints.corr_cholesky`` +- ``constraints.dependent`` +- ``constraints.greater_than(lower_bound)`` +- ``constraints.greater_than_eq(lower_bound)`` +- ``constraints.independent(constraint, reinterpreted_batch_ndims)`` +- ``constraints.integer_interval(lower_bound, upper_bound)`` +- ``constraints.interval(lower_bound, upper_bound)`` +- ``constraints.less_than(upper_bound)`` +- ``constraints.lower_cholesky`` +- ``constraints.lower_triangular`` +- ``constraints.multinomial`` +- ``constraints.nonnegative`` +- ``constraints.nonnegative_integer`` +- ``constraints.one_hot`` +- ``constraints.positive_integer`` +- ``constraints.positive`` +- ``constraints.positive_semidefinite`` +- ``constraints.positive_definite`` +- ``constraints.real_vector`` +- ``constraints.real`` +- ``constraints.simplex`` +- ``constraints.symmetric`` +- ``constraints.stack`` +- ``constraints.square`` +- ``constraints.symmetric`` +- ``constraints.unit_interval`` +""" + 
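+# A minimal usage sketch (illustrative; this is how the constraints are meant
+# to be consumed from user code, not something executed inside this module):
+#
+#     import torch
+#     from torch.distributions import constraints
+#
+#     constraints.positive.check(torch.tensor([1.0, -1.0]))     # tensor([ True, False])
+#     constraints.unit_interval.check(torch.tensor(0.5))        # tensor(True)
+#     constraints.simplex.check(torch.tensor([0.2, 0.3, 0.5]))  # tensor(True)
+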
+import torch + + +__all__ = [ + "Constraint", + "boolean", + "cat", + "corr_cholesky", + "dependent", + "dependent_property", + "greater_than", + "greater_than_eq", + "independent", + "integer_interval", + "interval", + "half_open_interval", + "is_dependent", + "less_than", + "lower_cholesky", + "lower_triangular", + "multinomial", + "nonnegative", + "nonnegative_integer", + "one_hot", + "positive", + "positive_semidefinite", + "positive_definite", + "positive_integer", + "real", + "real_vector", + "simplex", + "square", + "stack", + "symmetric", + "unit_interval", +] + + +class Constraint: + """ + Abstract base class for constraints. + + A constraint object represents a region over which a variable is valid, + e.g. within which a variable can be optimized. + + Attributes: + is_discrete (bool): Whether constrained space is discrete. + Defaults to False. + event_dim (int): Number of rightmost dimensions that together define + an event. The :meth:`check` method will remove this many dimensions + when computing validity. + """ + + is_discrete = False # Default to continuous. + event_dim = 0 # Default to univariate. + + def check(self, value): + """ + Returns a byte tensor of ``sample_shape + batch_shape`` indicating + whether each event in value satisfies this constraint. + """ + raise NotImplementedError + + def __repr__(self): + return self.__class__.__name__[1:] + "()" + + +class _Dependent(Constraint): + """ + Placeholder for variables whose support depends on other variables. + These variables obey no simple coordinate-wise constraints. + + Args: + is_discrete (bool): Optional value of ``.is_discrete`` in case this + can be computed statically. If not provided, access to the + ``.is_discrete`` attribute will raise a NotImplementedError. + event_dim (int): Optional value of ``.event_dim`` in case this + can be computed statically. If not provided, access to the + ``.event_dim`` attribute will raise a NotImplementedError. + """ + + def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented): + self._is_discrete = is_discrete + self._event_dim = event_dim + super().__init__() + + @property + def is_discrete(self): + if self._is_discrete is NotImplemented: + raise NotImplementedError(".is_discrete cannot be determined statically") + return self._is_discrete + + @property + def event_dim(self): + if self._event_dim is NotImplemented: + raise NotImplementedError(".event_dim cannot be determined statically") + return self._event_dim + + def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented): + """ + Support for syntax to customize static attributes:: + + constraints.dependent(is_discrete=True, event_dim=1) + """ + if is_discrete is NotImplemented: + is_discrete = self._is_discrete + if event_dim is NotImplemented: + event_dim = self._event_dim + return _Dependent(is_discrete=is_discrete, event_dim=event_dim) + + def check(self, x): + raise ValueError("Cannot determine validity of dependent constraint") + + +def is_dependent(constraint): + """ + Checks if ``constraint`` is a ``_Dependent`` object. + + Args: + constraint : A ``Constraint`` object. + + Returns: + ``bool``: True if ``constraint`` can be refined to the type ``_Dependent``, False otherwise. 
+ + Examples: + >>> import torch + >>> from torch.distributions import Bernoulli + >>> from torch.distributions.constraints import is_dependent + + >>> dist = Bernoulli(probs = torch.tensor([0.6], requires_grad=True)) + >>> constraint1 = dist.arg_constraints["probs"] + >>> constraint2 = dist.arg_constraints["logits"] + + >>> for constraint in [constraint1, constraint2]: + >>> if is_dependent(constraint): + >>> continue + """ + return isinstance(constraint, _Dependent) + + +class _DependentProperty(property, _Dependent): + """ + Decorator that extends @property to act like a `Dependent` constraint when + called on a class and act like a property when called on an object. + + Example:: + + class Uniform(Distribution): + def __init__(self, low, high): + self.low = low + self.high = high + @constraints.dependent_property(is_discrete=False, event_dim=0) + def support(self): + return constraints.interval(self.low, self.high) + + Args: + fn (Callable): The function to be decorated. + is_discrete (bool): Optional value of ``.is_discrete`` in case this + can be computed statically. If not provided, access to the + ``.is_discrete`` attribute will raise a NotImplementedError. + event_dim (int): Optional value of ``.event_dim`` in case this + can be computed statically. If not provided, access to the + ``.event_dim`` attribute will raise a NotImplementedError. + """ + + def __init__( + self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented + ): + super().__init__(fn) + self._is_discrete = is_discrete + self._event_dim = event_dim + + def __call__(self, fn): + """ + Support for syntax to customize static attributes:: + + @constraints.dependent_property(is_discrete=True, event_dim=1) + def support(self): + ... + """ + return _DependentProperty( + fn, is_discrete=self._is_discrete, event_dim=self._event_dim + ) + + +class _IndependentConstraint(Constraint): + """ + Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many + dims in :meth:`check`, so that an event is valid only if all its + independent entries are valid. + """ + + def __init__(self, base_constraint, reinterpreted_batch_ndims): + assert isinstance(base_constraint, Constraint) + assert isinstance(reinterpreted_batch_ndims, int) + assert reinterpreted_batch_ndims >= 0 + self.base_constraint = base_constraint + self.reinterpreted_batch_ndims = reinterpreted_batch_ndims + super().__init__() + + @property + def is_discrete(self): + return self.base_constraint.is_discrete + + @property + def event_dim(self): + return self.base_constraint.event_dim + self.reinterpreted_batch_ndims + + def check(self, value): + result = self.base_constraint.check(value) + if result.dim() < self.reinterpreted_batch_ndims: + expected = self.base_constraint.event_dim + self.reinterpreted_batch_ndims + raise ValueError( + f"Expected value.dim() >= {expected} but got {value.dim()}" + ) + result = result.reshape( + result.shape[: result.dim() - self.reinterpreted_batch_ndims] + (-1,) + ) + result = result.all(-1) + return result + + def __repr__(self): + return f"{self.__class__.__name__[1:]}({repr(self.base_constraint)}, {self.reinterpreted_batch_ndims})" + + +class _Boolean(Constraint): + """ + Constrain to the two values `{0, 1}`. + """ + + is_discrete = True + + def check(self, value): + return (value == 0) | (value == 1) + + +class _OneHot(Constraint): + """ + Constrain to one-hot vectors. 
+ """ + + is_discrete = True + event_dim = 1 + + def check(self, value): + is_boolean = (value == 0) | (value == 1) + is_normalized = value.sum(-1).eq(1) + return is_boolean.all(-1) & is_normalized + + +class _IntegerInterval(Constraint): + """ + Constrain to an integer interval `[lower_bound, upper_bound]`. + """ + + is_discrete = True + + def __init__(self, lower_bound, upper_bound): + self.lower_bound = lower_bound + self.upper_bound = upper_bound + super().__init__() + + def check(self, value): + return ( + (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound) + ) + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += ( + f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})" + ) + return fmt_string + + +class _IntegerLessThan(Constraint): + """ + Constrain to an integer interval `(-inf, upper_bound]`. + """ + + is_discrete = True + + def __init__(self, upper_bound): + self.upper_bound = upper_bound + super().__init__() + + def check(self, value): + return (value % 1 == 0) & (value <= self.upper_bound) + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += f"(upper_bound={self.upper_bound})" + return fmt_string + + +class _IntegerGreaterThan(Constraint): + """ + Constrain to an integer interval `[lower_bound, inf)`. + """ + + is_discrete = True + + def __init__(self, lower_bound): + self.lower_bound = lower_bound + super().__init__() + + def check(self, value): + return (value % 1 == 0) & (value >= self.lower_bound) + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += f"(lower_bound={self.lower_bound})" + return fmt_string + + +class _Real(Constraint): + """ + Trivially constrain to the extended real line `[-inf, inf]`. + """ + + def check(self, value): + return value == value # False for NANs. + + +class _GreaterThan(Constraint): + """ + Constrain to a real half line `(lower_bound, inf]`. + """ + + def __init__(self, lower_bound): + self.lower_bound = lower_bound + super().__init__() + + def check(self, value): + return self.lower_bound < value + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += f"(lower_bound={self.lower_bound})" + return fmt_string + + +class _GreaterThanEq(Constraint): + """ + Constrain to a real half line `[lower_bound, inf)`. + """ + + def __init__(self, lower_bound): + self.lower_bound = lower_bound + super().__init__() + + def check(self, value): + return self.lower_bound <= value + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += f"(lower_bound={self.lower_bound})" + return fmt_string + + +class _LessThan(Constraint): + """ + Constrain to a real half line `[-inf, upper_bound)`. + """ + + def __init__(self, upper_bound): + self.upper_bound = upper_bound + super().__init__() + + def check(self, value): + return value < self.upper_bound + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += f"(upper_bound={self.upper_bound})" + return fmt_string + + +class _Interval(Constraint): + """ + Constrain to a real interval `[lower_bound, upper_bound]`. 
+ """ + + def __init__(self, lower_bound, upper_bound): + self.lower_bound = lower_bound + self.upper_bound = upper_bound + super().__init__() + + def check(self, value): + return (self.lower_bound <= value) & (value <= self.upper_bound) + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += ( + f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})" + ) + return fmt_string + + +class _HalfOpenInterval(Constraint): + """ + Constrain to a real interval `[lower_bound, upper_bound)`. + """ + + def __init__(self, lower_bound, upper_bound): + self.lower_bound = lower_bound + self.upper_bound = upper_bound + super().__init__() + + def check(self, value): + return (self.lower_bound <= value) & (value < self.upper_bound) + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += ( + f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})" + ) + return fmt_string + + +class _Simplex(Constraint): + """ + Constrain to the unit simplex in the innermost (rightmost) dimension. + Specifically: `x >= 0` and `x.sum(-1) == 1`. + """ + + event_dim = 1 + + def check(self, value): + return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6) + + +class _Multinomial(Constraint): + """ + Constrain to nonnegative integer values summing to at most an upper bound. + + Note due to limitations of the Multinomial distribution, this currently + checks the weaker condition ``value.sum(-1) <= upper_bound``. In the future + this may be strengthened to ``value.sum(-1) == upper_bound``. + """ + + is_discrete = True + event_dim = 1 + + def __init__(self, upper_bound): + self.upper_bound = upper_bound + + def check(self, x): + return (x >= 0).all(dim=-1) & (x.sum(dim=-1) <= self.upper_bound) + + +class _LowerTriangular(Constraint): + """ + Constrain to lower-triangular square matrices. + """ + + event_dim = 2 + + def check(self, value): + value_tril = value.tril() + return (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] + + +class _LowerCholesky(Constraint): + """ + Constrain to lower-triangular square matrices with positive diagonals. + """ + + event_dim = 2 + + def check(self, value): + value_tril = value.tril() + lower_triangular = ( + (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] + ) + + positive_diagonal = (value.diagonal(dim1=-2, dim2=-1) > 0).min(-1)[0] + return lower_triangular & positive_diagonal + + +class _CorrCholesky(Constraint): + """ + Constrain to lower-triangular square matrices with positive diagonals and each + row vector being of unit length. + """ + + event_dim = 2 + + def check(self, value): + tol = ( + torch.finfo(value.dtype).eps * value.size(-1) * 10 + ) # 10 is an adjustable fudge factor + row_norm = torch.linalg.norm(value.detach(), dim=-1) + unit_row_norm = (row_norm - 1.0).abs().le(tol).all(dim=-1) + return _LowerCholesky().check(value) & unit_row_norm + + +class _Square(Constraint): + """ + Constrain to square matrices. + """ + + event_dim = 2 + + def check(self, value): + return torch.full( + size=value.shape[:-2], + fill_value=(value.shape[-2] == value.shape[-1]), + dtype=torch.bool, + device=value.device, + ) + + +class _Symmetric(_Square): + """ + Constrain to Symmetric square matrices. + """ + + def check(self, value): + square_check = super().check(value) + if not square_check.all(): + return square_check + return torch.isclose(value, value.mT, atol=1e-6).all(-2).all(-1) + + +class _PositiveSemidefinite(_Symmetric): + """ + Constrain to positive-semidefinite matrices. 
+ """ + + def check(self, value): + sym_check = super().check(value) + if not sym_check.all(): + return sym_check + return torch.linalg.eigvalsh(value).ge(0).all(-1) + + +class _PositiveDefinite(_Symmetric): + """ + Constrain to positive-definite matrices. + """ + + def check(self, value): + sym_check = super().check(value) + if not sym_check.all(): + return sym_check + return torch.linalg.cholesky_ex(value).info.eq(0) + + +class _Cat(Constraint): + """ + Constraint functor that applies a sequence of constraints + `cseq` at the submatrices at dimension `dim`, + each of size `lengths[dim]`, in a way compatible with :func:`torch.cat`. + """ + + def __init__(self, cseq, dim=0, lengths=None): + assert all(isinstance(c, Constraint) for c in cseq) + self.cseq = list(cseq) + if lengths is None: + lengths = [1] * len(self.cseq) + self.lengths = list(lengths) + assert len(self.lengths) == len(self.cseq) + self.dim = dim + super().__init__() + + @property + def is_discrete(self): + return any(c.is_discrete for c in self.cseq) + + @property + def event_dim(self): + return max(c.event_dim for c in self.cseq) + + def check(self, value): + assert -value.dim() <= self.dim < value.dim() + checks = [] + start = 0 + for constr, length in zip(self.cseq, self.lengths): + v = value.narrow(self.dim, start, length) + checks.append(constr.check(v)) + start = start + length # avoid += for jit compat + return torch.cat(checks, self.dim) + + +class _Stack(Constraint): + """ + Constraint functor that applies a sequence of constraints + `cseq` at the submatrices at dimension `dim`, + in a way compatible with :func:`torch.stack`. + """ + + def __init__(self, cseq, dim=0): + assert all(isinstance(c, Constraint) for c in cseq) + self.cseq = list(cseq) + self.dim = dim + super().__init__() + + @property + def is_discrete(self): + return any(c.is_discrete for c in self.cseq) + + @property + def event_dim(self): + dim = max(c.event_dim for c in self.cseq) + if self.dim + dim < 0: + dim += 1 + return dim + + def check(self, value): + assert -value.dim() <= self.dim < value.dim() + vs = [value.select(self.dim, i) for i in range(value.size(self.dim))] + return torch.stack( + [constr.check(v) for v, constr in zip(vs, self.cseq)], self.dim + ) + + +# Public interface. 
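+# Parameter-free constraints are exported as ready-made singleton instances
+# (e.g. `real`, `positive`, `simplex`), while parameterized constraints are
+# exported as their classes (e.g. `interval`, `greater_than`, `cat`) to be
+# instantiated with the relevant bounds or sub-constraints.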
+dependent = _Dependent() +dependent_property = _DependentProperty +independent = _IndependentConstraint +boolean = _Boolean() +one_hot = _OneHot() +nonnegative_integer = _IntegerGreaterThan(0) +positive_integer = _IntegerGreaterThan(1) +integer_interval = _IntegerInterval +real = _Real() +real_vector = independent(real, 1) +positive = _GreaterThan(0.0) +nonnegative = _GreaterThanEq(0.0) +greater_than = _GreaterThan +greater_than_eq = _GreaterThanEq +less_than = _LessThan +multinomial = _Multinomial +unit_interval = _Interval(0.0, 1.0) +interval = _Interval +half_open_interval = _HalfOpenInterval +simplex = _Simplex() +lower_triangular = _LowerTriangular() +lower_cholesky = _LowerCholesky() +corr_cholesky = _CorrCholesky() +square = _Square() +symmetric = _Symmetric() +positive_semidefinite = _PositiveSemidefinite() +positive_definite = _PositiveDefinite() +cat = _Cat +stack = _Stack diff --git a/lib/python3.10/site-packages/torch/distributions/continuous_bernoulli.py b/lib/python3.10/site-packages/torch/distributions/continuous_bernoulli.py new file mode 100644 index 0000000000000000000000000000000000000000..0eb49f951f1c06b35438cdb4022a6b9003857df2 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/continuous_bernoulli.py @@ -0,0 +1,238 @@ +# mypy: allow-untyped-defs +import math +from numbers import Number + +import torch +from torch.distributions import constraints +from torch.distributions.exp_family import ExponentialFamily +from torch.distributions.utils import ( + broadcast_all, + clamp_probs, + lazy_property, + logits_to_probs, + probs_to_logits, +) +from torch.nn.functional import binary_cross_entropy_with_logits +from torch.types import _size + + +__all__ = ["ContinuousBernoulli"] + + +class ContinuousBernoulli(ExponentialFamily): + r""" + Creates a continuous Bernoulli distribution parameterized by :attr:`probs` + or :attr:`logits` (but not both). + + The distribution is supported in [0, 1] and parameterized by 'probs' (in + (0,1)) or 'logits' (real-valued). Note that, unlike the Bernoulli, 'probs' + does not correspond to a probability and 'logits' does not correspond to + log-odds, but the same names are used due to the similarity with the + Bernoulli. See [1] for more details. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = ContinuousBernoulli(torch.tensor([0.3])) + >>> m.sample() + tensor([ 0.2538]) + + Args: + probs (Number, Tensor): (0,1) valued parameters + logits (Number, Tensor): real valued parameters whose sigmoid matches 'probs' + + [1] The continuous Bernoulli: fixing a pervasive error in variational + autoencoders, Loaiza-Ganem G and Cunningham JP, NeurIPS 2019. + https://arxiv.org/abs/1907.06845 + """ + arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} + support = constraints.unit_interval + _mean_carrier_measure = 0 + has_rsample = True + + def __init__( + self, probs=None, logits=None, lims=(0.499, 0.501), validate_args=None + ): + if (probs is None) == (logits is None): + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." 
+ ) + if probs is not None: + is_scalar = isinstance(probs, Number) + (self.probs,) = broadcast_all(probs) + # validate 'probs' here if necessary as it is later clamped for numerical stability + # close to 0 and 1, later on; otherwise the clamped 'probs' would always pass + if validate_args is not None: + if not self.arg_constraints["probs"].check(self.probs).all(): + raise ValueError("The parameter probs has invalid values") + self.probs = clamp_probs(self.probs) + else: + is_scalar = isinstance(logits, Number) + (self.logits,) = broadcast_all(logits) + self._param = self.probs if probs is not None else self.logits + if is_scalar: + batch_shape = torch.Size() + else: + batch_shape = self._param.size() + self._lims = lims + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(ContinuousBernoulli, _instance) + new._lims = self._lims + batch_shape = torch.Size(batch_shape) + if "probs" in self.__dict__: + new.probs = self.probs.expand(batch_shape) + new._param = new.probs + if "logits" in self.__dict__: + new.logits = self.logits.expand(batch_shape) + new._param = new.logits + super(ContinuousBernoulli, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._param.new(*args, **kwargs) + + def _outside_unstable_region(self): + return torch.max( + torch.le(self.probs, self._lims[0]), torch.gt(self.probs, self._lims[1]) + ) + + def _cut_probs(self): + return torch.where( + self._outside_unstable_region(), + self.probs, + self._lims[0] * torch.ones_like(self.probs), + ) + + def _cont_bern_log_norm(self): + """computes the log normalizing constant as a function of the 'probs' parameter""" + cut_probs = self._cut_probs() + cut_probs_below_half = torch.where( + torch.le(cut_probs, 0.5), cut_probs, torch.zeros_like(cut_probs) + ) + cut_probs_above_half = torch.where( + torch.ge(cut_probs, 0.5), cut_probs, torch.ones_like(cut_probs) + ) + log_norm = torch.log( + torch.abs(torch.log1p(-cut_probs) - torch.log(cut_probs)) + ) - torch.where( + torch.le(cut_probs, 0.5), + torch.log1p(-2.0 * cut_probs_below_half), + torch.log(2.0 * cut_probs_above_half - 1.0), + ) + x = torch.pow(self.probs - 0.5, 2) + taylor = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * x) * x + return torch.where(self._outside_unstable_region(), log_norm, taylor) + + @property + def mean(self): + cut_probs = self._cut_probs() + mus = cut_probs / (2.0 * cut_probs - 1.0) + 1.0 / ( + torch.log1p(-cut_probs) - torch.log(cut_probs) + ) + x = self.probs - 0.5 + taylor = 0.5 + (1.0 / 3.0 + 16.0 / 45.0 * torch.pow(x, 2)) * x + return torch.where(self._outside_unstable_region(), mus, taylor) + + @property + def stddev(self): + return torch.sqrt(self.variance) + + @property + def variance(self): + cut_probs = self._cut_probs() + vars = cut_probs * (cut_probs - 1.0) / torch.pow( + 1.0 - 2.0 * cut_probs, 2 + ) + 1.0 / torch.pow(torch.log1p(-cut_probs) - torch.log(cut_probs), 2) + x = torch.pow(self.probs - 0.5, 2) + taylor = 1.0 / 12.0 - (1.0 / 15.0 - 128.0 / 945.0 * x) * x + return torch.where(self._outside_unstable_region(), vars, taylor) + + @lazy_property + def logits(self): + return probs_to_logits(self.probs, is_binary=True) + + @lazy_property + def probs(self): + return clamp_probs(logits_to_probs(self.logits, is_binary=True)) + + @property + def param_shape(self): + return self._param.size() + + def sample(self, sample_shape=torch.Size()): + shape = 
self._extended_shape(sample_shape) + u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device) + with torch.no_grad(): + return self.icdf(u) + + def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: + shape = self._extended_shape(sample_shape) + u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device) + return self.icdf(u) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + logits, value = broadcast_all(self.logits, value) + return ( + -binary_cross_entropy_with_logits(logits, value, reduction="none") + + self._cont_bern_log_norm() + ) + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + cut_probs = self._cut_probs() + cdfs = ( + torch.pow(cut_probs, value) * torch.pow(1.0 - cut_probs, 1.0 - value) + + cut_probs + - 1.0 + ) / (2.0 * cut_probs - 1.0) + unbounded_cdfs = torch.where(self._outside_unstable_region(), cdfs, value) + return torch.where( + torch.le(value, 0.0), + torch.zeros_like(value), + torch.where(torch.ge(value, 1.0), torch.ones_like(value), unbounded_cdfs), + ) + + def icdf(self, value): + cut_probs = self._cut_probs() + return torch.where( + self._outside_unstable_region(), + ( + torch.log1p(-cut_probs + value * (2.0 * cut_probs - 1.0)) + - torch.log1p(-cut_probs) + ) + / (torch.log(cut_probs) - torch.log1p(-cut_probs)), + value, + ) + + def entropy(self): + log_probs0 = torch.log1p(-self.probs) + log_probs1 = torch.log(self.probs) + return ( + self.mean * (log_probs0 - log_probs1) + - self._cont_bern_log_norm() + - log_probs0 + ) + + @property + def _natural_params(self): + return (self.logits,) + + def _log_normalizer(self, x): + """computes the log normalizing constant as a function of the natural parameter""" + out_unst_reg = torch.max( + torch.le(x, self._lims[0] - 0.5), torch.gt(x, self._lims[1] - 0.5) + ) + cut_nat_params = torch.where( + out_unst_reg, x, (self._lims[0] - 0.5) * torch.ones_like(x) + ) + log_norm = torch.log(torch.abs(torch.exp(cut_nat_params) - 1.0)) - torch.log( + torch.abs(cut_nat_params) + ) + taylor = 0.5 * x + torch.pow(x, 2) / 24.0 - torch.pow(x, 4) / 2880.0 + return torch.where(out_unst_reg, log_norm, taylor) diff --git a/lib/python3.10/site-packages/torch/distributions/dirichlet.py b/lib/python3.10/site-packages/torch/distributions/dirichlet.py new file mode 100644 index 0000000000000000000000000000000000000000..25e7bb9cd7c2383409de8ed10fd911e69f3b23c7 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/dirichlet.py @@ -0,0 +1,126 @@ +# mypy: allow-untyped-defs +import torch +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.distributions import constraints +from torch.distributions.exp_family import ExponentialFamily +from torch.types import _size + + +__all__ = ["Dirichlet"] + + +# This helper is exposed for testing. 
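+# It applies the pathwise gradient of a Dirichlet sample with respect to
+# `concentration` (via `torch._dirichlet_grad`), correcting `grad_output` for
+# the sum-to-one constraint of the simplex; this is what lets `rsample` below
+# be differentiated.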
+def _Dirichlet_backward(x, concentration, grad_output): + total = concentration.sum(-1, True).expand_as(concentration) + grad = torch._dirichlet_grad(x, concentration, total) + return grad * (grad_output - (x * grad_output).sum(-1, True)) + + +class _Dirichlet(Function): + @staticmethod + def forward(ctx, concentration): + x = torch._sample_dirichlet(concentration) + ctx.save_for_backward(x, concentration) + return x + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + x, concentration = ctx.saved_tensors + return _Dirichlet_backward(x, concentration, grad_output) + + +class Dirichlet(ExponentialFamily): + r""" + Creates a Dirichlet distribution parameterized by concentration :attr:`concentration`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Dirichlet(torch.tensor([0.5, 0.5])) + >>> m.sample() # Dirichlet distributed with concentration [0.5, 0.5] + tensor([ 0.1046, 0.8954]) + + Args: + concentration (Tensor): concentration parameter of the distribution + (often referred to as alpha) + """ + arg_constraints = { + "concentration": constraints.independent(constraints.positive, 1) + } + support = constraints.simplex + has_rsample = True + + def __init__(self, concentration, validate_args=None): + if concentration.dim() < 1: + raise ValueError( + "`concentration` parameter must be at least one-dimensional." + ) + self.concentration = concentration + batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:] + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Dirichlet, _instance) + batch_shape = torch.Size(batch_shape) + new.concentration = self.concentration.expand(batch_shape + self.event_shape) + super(Dirichlet, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + def rsample(self, sample_shape: _size = ()) -> torch.Tensor: + shape = self._extended_shape(sample_shape) + concentration = self.concentration.expand(shape) + return _Dirichlet.apply(concentration) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + return ( + torch.xlogy(self.concentration - 1.0, value).sum(-1) + + torch.lgamma(self.concentration.sum(-1)) + - torch.lgamma(self.concentration).sum(-1) + ) + + @property + def mean(self): + return self.concentration / self.concentration.sum(-1, True) + + @property + def mode(self): + concentrationm1 = (self.concentration - 1).clamp(min=0.0) + mode = concentrationm1 / concentrationm1.sum(-1, True) + mask = (self.concentration < 1).all(axis=-1) + mode[mask] = torch.nn.functional.one_hot( + mode[mask].argmax(axis=-1), concentrationm1.shape[-1] + ).to(mode) + return mode + + @property + def variance(self): + con0 = self.concentration.sum(-1, True) + return ( + self.concentration + * (con0 - self.concentration) + / (con0.pow(2) * (con0 + 1)) + ) + + def entropy(self): + k = self.concentration.size(-1) + a0 = self.concentration.sum(-1) + return ( + torch.lgamma(self.concentration).sum(-1) + - torch.lgamma(a0) + - (k - a0) * torch.digamma(a0) + - ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1) + ) + + @property + def _natural_params(self): + return (self.concentration,) + + def _log_normalizer(self, x): + return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1)) diff --git a/lib/python3.10/site-packages/torch/distributions/distribution.py 
b/lib/python3.10/site-packages/torch/distributions/distribution.py new file mode 100644 index 0000000000000000000000000000000000000000..1c3bdf9e85cd52aad69ef9342574604059e2e651 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/distribution.py @@ -0,0 +1,340 @@ +# mypy: allow-untyped-defs +import warnings +from typing import Any, Dict, Optional +from typing_extensions import deprecated + +import torch +from torch.distributions import constraints +from torch.distributions.utils import lazy_property +from torch.types import _size + + +__all__ = ["Distribution"] + + +class Distribution: + r""" + Distribution is the abstract base class for probability distributions. + """ + + has_rsample = False + has_enumerate_support = False + _validate_args = __debug__ + + @staticmethod + def set_default_validate_args(value: bool) -> None: + """ + Sets whether validation is enabled or disabled. + + The default behavior mimics Python's ``assert`` statement: validation + is on by default, but is disabled if Python is run in optimized mode + (via ``python -O``). Validation may be expensive, so you may want to + disable it once a model is working. + + Args: + value (bool): Whether to enable validation. + """ + if value not in [True, False]: + raise ValueError + Distribution._validate_args = value + + def __init__( + self, + batch_shape: torch.Size = torch.Size(), + event_shape: torch.Size = torch.Size(), + validate_args: Optional[bool] = None, + ): + self._batch_shape = batch_shape + self._event_shape = event_shape + if validate_args is not None: + self._validate_args = validate_args + if self._validate_args: + try: + arg_constraints = self.arg_constraints + except NotImplementedError: + arg_constraints = {} + warnings.warn( + f"{self.__class__} does not define `arg_constraints`. " + + "Please set `arg_constraints = {}` or initialize the distribution " + + "with `validate_args=False` to turn off validation." + ) + for param, constraint in arg_constraints.items(): + if constraints.is_dependent(constraint): + continue # skip constraints that cannot be checked + if param not in self.__dict__ and isinstance( + getattr(type(self), param), lazy_property + ): + continue # skip checking lazily-constructed args + value = getattr(self, param) + valid = constraint.check(value) + if not valid.all(): + raise ValueError( + f"Expected parameter {param} " + f"({type(value).__name__} of shape {tuple(value.shape)}) " + f"of distribution {repr(self)} " + f"to satisfy the constraint {repr(constraint)}, " + f"but found invalid values:\n{value}" + ) + super().__init__() + + def expand(self, batch_shape: _size, _instance=None): + """ + Returns a new distribution instance (or populates an existing instance + provided by a derived class) with batch dimensions expanded to + `batch_shape`. This method calls :class:`~torch.Tensor.expand` on + the distribution's parameters. As such, this does not allocate new + memory for the expanded distribution instance. Additionally, + this does not repeat any args checking or parameter broadcasting in + `__init__.py`, when an instance is first created. + + Args: + batch_shape (torch.Size): the desired expanded size. + _instance: new instance provided by subclasses that + need to override `.expand`. + + Returns: + New distribution instance with batch dimensions expanded to + `batch_size`. + """ + raise NotImplementedError + + @property + def batch_shape(self) -> torch.Size: + """ + Returns the shape over which parameters are batched. 
+ """ + return self._batch_shape + + @property + def event_shape(self) -> torch.Size: + """ + Returns the shape of a single sample (without batching). + """ + return self._event_shape + + @property + def arg_constraints(self) -> Dict[str, constraints.Constraint]: + """ + Returns a dictionary from argument names to + :class:`~torch.distributions.constraints.Constraint` objects that + should be satisfied by each argument of this distribution. Args that + are not tensors need not appear in this dict. + """ + raise NotImplementedError + + @property + def support(self) -> Optional[Any]: + """ + Returns a :class:`~torch.distributions.constraints.Constraint` object + representing this distribution's support. + """ + raise NotImplementedError + + @property + def mean(self) -> torch.Tensor: + """ + Returns the mean of the distribution. + """ + raise NotImplementedError + + @property + def mode(self) -> torch.Tensor: + """ + Returns the mode of the distribution. + """ + raise NotImplementedError(f"{self.__class__} does not implement mode") + + @property + def variance(self) -> torch.Tensor: + """ + Returns the variance of the distribution. + """ + raise NotImplementedError + + @property + def stddev(self) -> torch.Tensor: + """ + Returns the standard deviation of the distribution. + """ + return self.variance.sqrt() + + def sample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: + """ + Generates a sample_shape shaped sample or sample_shape shaped batch of + samples if the distribution parameters are batched. + """ + with torch.no_grad(): + return self.rsample(sample_shape) + + def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: + """ + Generates a sample_shape shaped reparameterized sample or sample_shape + shaped batch of reparameterized samples if the distribution parameters + are batched. + """ + raise NotImplementedError + + @deprecated( + "`sample_n(n)` will be deprecated. Use `sample((n,))` instead.", + category=FutureWarning, + ) + def sample_n(self, n: int) -> torch.Tensor: + """ + Generates n samples or n batches of samples if the distribution + parameters are batched. + """ + return self.sample(torch.Size((n,))) + + def log_prob(self, value: torch.Tensor) -> torch.Tensor: + """ + Returns the log of the probability density/mass function evaluated at + `value`. + + Args: + value (Tensor): + """ + raise NotImplementedError + + def cdf(self, value: torch.Tensor) -> torch.Tensor: + """ + Returns the cumulative density/mass function evaluated at + `value`. + + Args: + value (Tensor): + """ + raise NotImplementedError + + def icdf(self, value: torch.Tensor) -> torch.Tensor: + """ + Returns the inverse cumulative density/mass function evaluated at + `value`. + + Args: + value (Tensor): + """ + raise NotImplementedError + + def enumerate_support(self, expand: bool = True) -> torch.Tensor: + """ + Returns tensor containing all values supported by a discrete + distribution. The result will enumerate over dimension 0, so the shape + of the result will be `(cardinality,) + batch_shape + event_shape` + (where `event_shape = ()` for univariate distributions). + + Note that this enumerates over all batched tensors in lock-step + `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens + along dim 0, but with the remaining batch dimensions being + singleton dimensions, `[[0], [1], ..`. + + To iterate over the full Cartesian product use + `itertools.product(m.enumerate_support())`. 
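+
+        For instance (illustrative), a Bernoulli distribution with
+        ``batch_shape == (2,)`` enumerates to a tensor of shape ``(2, 2)``,
+        ``[[0., 0.], [1., 1.]]``, with ``expand=True``, and to shape
+        ``(2, 1)``, ``[[0.], [1.]]``, with ``expand=False``.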
+ + Args: + expand (bool): whether to expand the support over the + batch dims to match the distribution's `batch_shape`. + + Returns: + Tensor iterating over dimension 0. + """ + raise NotImplementedError + + def entropy(self) -> torch.Tensor: + """ + Returns entropy of distribution, batched over batch_shape. + + Returns: + Tensor of shape batch_shape. + """ + raise NotImplementedError + + def perplexity(self) -> torch.Tensor: + """ + Returns perplexity of distribution, batched over batch_shape. + + Returns: + Tensor of shape batch_shape. + """ + return torch.exp(self.entropy()) + + def _extended_shape(self, sample_shape: _size = torch.Size()) -> torch.Size: + """ + Returns the size of the sample returned by the distribution, given + a `sample_shape`. Note, that the batch and event shapes of a distribution + instance are fixed at the time of construction. If this is empty, the + returned shape is upcast to (1,). + + Args: + sample_shape (torch.Size): the size of the sample to be drawn. + """ + if not isinstance(sample_shape, torch.Size): + sample_shape = torch.Size(sample_shape) + return torch.Size(sample_shape + self._batch_shape + self._event_shape) + + def _validate_sample(self, value: torch.Tensor) -> None: + """ + Argument validation for distribution methods such as `log_prob`, + `cdf` and `icdf`. The rightmost dimensions of a value to be + scored via these methods must agree with the distribution's batch + and event shapes. + + Args: + value (Tensor): the tensor whose log probability is to be + computed by the `log_prob` method. + Raises + ValueError: when the rightmost dimensions of `value` do not match the + distribution's batch and event shapes. + """ + if not isinstance(value, torch.Tensor): + raise ValueError("The value argument to log_prob must be a Tensor") + + event_dim_start = len(value.size()) - len(self._event_shape) + if value.size()[event_dim_start:] != self._event_shape: + raise ValueError( + f"The right-most size of value must match event_shape: {value.size()} vs {self._event_shape}." + ) + + actual_shape = value.size() + expected_shape = self._batch_shape + self._event_shape + for i, j in zip(reversed(actual_shape), reversed(expected_shape)): + if i != 1 and j != 1 and i != j: + raise ValueError( + f"Value is not broadcastable with batch_shape+event_shape: {actual_shape} vs {expected_shape}." + ) + try: + support = self.support + except NotImplementedError: + warnings.warn( + f"{self.__class__} does not define `support` to enable " + + "sample validation. Please initialize the distribution with " + + "`validate_args=False` to turn off validation." + ) + return + assert support is not None + valid = support.check(value) + if not valid.all(): + raise ValueError( + "Expected value argument " + f"({type(value).__name__} of shape {tuple(value.shape)}) " + f"to be within the support ({repr(support)}) " + f"of the distribution {repr(self)}, " + f"but found invalid values:\n{value}" + ) + + def _get_checked_instance(self, cls, _instance=None): + if _instance is None and type(self).__init__ != cls.__init__: + raise NotImplementedError( + f"Subclass {self.__class__.__name__} of {cls.__name__} that defines a custom __init__ method " + "must also define a custom .expand() method." 
+ ) + return self.__new__(type(self)) if _instance is None else _instance + + def __repr__(self) -> str: + param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__] + args_string = ", ".join( + [ + f"{p}: {self.__dict__[p] if self.__dict__[p].numel() == 1 else self.__dict__[p].size()}" + for p in param_names + ] + ) + return self.__class__.__name__ + "(" + args_string + ")" diff --git a/lib/python3.10/site-packages/torch/distributions/exp_family.py b/lib/python3.10/site-packages/torch/distributions/exp_family.py new file mode 100644 index 0000000000000000000000000000000000000000..33234c47b10282856a94c0b29860db68ac715b18 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/exp_family.py @@ -0,0 +1,64 @@ +# mypy: allow-untyped-defs +import torch +from torch.distributions.distribution import Distribution + + +__all__ = ["ExponentialFamily"] + + +class ExponentialFamily(Distribution): + r""" + ExponentialFamily is the abstract base class for probability distributions belonging to an + exponential family, whose probability mass/density function has the form is defined below + + .. math:: + + p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x)) + + where :math:`\theta` denotes the natural parameters, :math:`t(x)` denotes the sufficient statistic, + :math:`F(\theta)` is the log normalizer function for a given family and :math:`k(x)` is the carrier + measure. + + Note: + This class is an intermediary between the `Distribution` class and distributions which belong + to an exponential family mainly to check the correctness of the `.entropy()` and analytic KL + divergence methods. We use this class to compute the entropy and KL divergence using the AD + framework and Bregman divergences (courtesy of: Frank Nielsen and Richard Nock, Entropies and + Cross-entropies of Exponential Families). + """ + + @property + def _natural_params(self): + """ + Abstract method for natural parameters. Returns a tuple of Tensors based + on the distribution + """ + raise NotImplementedError + + def _log_normalizer(self, *natural_params): + """ + Abstract method for log normalizer function. Returns a log normalizer based on + the distribution and input + """ + raise NotImplementedError + + @property + def _mean_carrier_measure(self): + """ + Abstract method for expected carrier measure, which is required for computing + entropy. + """ + raise NotImplementedError + + def entropy(self): + """ + Method to compute the entropy using Bregman divergence of the log normalizer. 
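+
+        Concretely, with the notation of the class docstring, the value
+        computed below is
+
+        .. math::
+
+            H = F(\theta) - \langle \theta, \nabla F(\theta) \rangle - \mathbb{E}[k(x)],
+
+        where :math:`\nabla F(\theta)` is obtained by differentiating the log
+        normalizer with autograd and :math:`\mathbb{E}[k(x)]` is supplied by
+        ``_mean_carrier_measure``.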
+ """ + result = -self._mean_carrier_measure + nparams = [p.detach().requires_grad_() for p in self._natural_params] + lg_normal = self._log_normalizer(*nparams) + gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True) + result += lg_normal + for np, g in zip(nparams, gradients): + result -= (np * g).reshape(self._batch_shape + (-1,)).sum(-1) + return result diff --git a/lib/python3.10/site-packages/torch/distributions/exponential.py b/lib/python3.10/site-packages/torch/distributions/exponential.py new file mode 100644 index 0000000000000000000000000000000000000000..02e349c11b9cb204279c3f40d80e2e8596b073f2 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/exponential.py @@ -0,0 +1,87 @@ +# mypy: allow-untyped-defs +from numbers import Number + +import torch +from torch.distributions import constraints +from torch.distributions.exp_family import ExponentialFamily +from torch.distributions.utils import broadcast_all +from torch.types import _size + + +__all__ = ["Exponential"] + + +class Exponential(ExponentialFamily): + r""" + Creates a Exponential distribution parameterized by :attr:`rate`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Exponential(torch.tensor([1.0])) + >>> m.sample() # Exponential distributed with rate=1 + tensor([ 0.1046]) + + Args: + rate (float or Tensor): rate = 1 / scale of the distribution + """ + arg_constraints = {"rate": constraints.positive} + support = constraints.nonnegative + has_rsample = True + _mean_carrier_measure = 0 + + @property + def mean(self): + return self.rate.reciprocal() + + @property + def mode(self): + return torch.zeros_like(self.rate) + + @property + def stddev(self): + return self.rate.reciprocal() + + @property + def variance(self): + return self.rate.pow(-2) + + def __init__(self, rate, validate_args=None): + (self.rate,) = broadcast_all(rate) + batch_shape = torch.Size() if isinstance(rate, Number) else self.rate.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Exponential, _instance) + batch_shape = torch.Size(batch_shape) + new.rate = self.rate.expand(batch_shape) + super(Exponential, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: + shape = self._extended_shape(sample_shape) + return self.rate.new(shape).exponential_() / self.rate + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + return self.rate.log() - self.rate * value + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + return 1 - torch.exp(-self.rate * value) + + def icdf(self, value): + return -torch.log1p(-value) / self.rate + + def entropy(self): + return 1.0 - torch.log(self.rate) + + @property + def _natural_params(self): + return (-self.rate,) + + def _log_normalizer(self, x): + return -torch.log(-x) diff --git a/lib/python3.10/site-packages/torch/distributions/fishersnedecor.py b/lib/python3.10/site-packages/torch/distributions/fishersnedecor.py new file mode 100644 index 0000000000000000000000000000000000000000..824a77ad7835fd536c16ba71359f50f0e4cc323c --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/fishersnedecor.py @@ -0,0 +1,101 @@ +# mypy: allow-untyped-defs +from numbers import Number + +import torch +from torch import nan +from torch.distributions import constraints +from 
torch.distributions.distribution import Distribution +from torch.distributions.gamma import Gamma +from torch.distributions.utils import broadcast_all +from torch.types import _size + + +__all__ = ["FisherSnedecor"] + + +class FisherSnedecor(Distribution): + r""" + Creates a Fisher-Snedecor distribution parameterized by :attr:`df1` and :attr:`df2`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = FisherSnedecor(torch.tensor([1.0]), torch.tensor([2.0])) + >>> m.sample() # Fisher-Snedecor-distributed with df1=1 and df2=2 + tensor([ 0.2453]) + + Args: + df1 (float or Tensor): degrees of freedom parameter 1 + df2 (float or Tensor): degrees of freedom parameter 2 + """ + arg_constraints = {"df1": constraints.positive, "df2": constraints.positive} + support = constraints.positive + has_rsample = True + + def __init__(self, df1, df2, validate_args=None): + self.df1, self.df2 = broadcast_all(df1, df2) + self._gamma1 = Gamma(self.df1 * 0.5, self.df1) + self._gamma2 = Gamma(self.df2 * 0.5, self.df2) + + if isinstance(df1, Number) and isinstance(df2, Number): + batch_shape = torch.Size() + else: + batch_shape = self.df1.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(FisherSnedecor, _instance) + batch_shape = torch.Size(batch_shape) + new.df1 = self.df1.expand(batch_shape) + new.df2 = self.df2.expand(batch_shape) + new._gamma1 = self._gamma1.expand(batch_shape) + new._gamma2 = self._gamma2.expand(batch_shape) + super(FisherSnedecor, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + @property + def mean(self): + df2 = self.df2.clone(memory_format=torch.contiguous_format) + df2[df2 <= 2] = nan + return df2 / (df2 - 2) + + @property + def mode(self): + mode = (self.df1 - 2) / self.df1 * self.df2 / (self.df2 + 2) + mode[self.df1 <= 2] = nan + return mode + + @property + def variance(self): + df2 = self.df2.clone(memory_format=torch.contiguous_format) + df2[df2 <= 4] = nan + return ( + 2 + * df2.pow(2) + * (self.df1 + df2 - 2) + / (self.df1 * (df2 - 2).pow(2) * (df2 - 4)) + ) + + def rsample(self, sample_shape: _size = torch.Size(())) -> torch.Tensor: + shape = self._extended_shape(sample_shape) + # X1 ~ Gamma(df1 / 2, 1 / df1), X2 ~ Gamma(df2 / 2, 1 / df2) + # Y = df2 * df1 * X1 / (df1 * df2 * X2) = X1 / X2 ~ F(df1, df2) + X1 = self._gamma1.rsample(sample_shape).view(shape) + X2 = self._gamma2.rsample(sample_shape).view(shape) + tiny = torch.finfo(X2.dtype).tiny + X2.clamp_(min=tiny) + Y = X1 / X2 + Y.clamp_(min=tiny) + return Y + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + ct1 = self.df1 * 0.5 + ct2 = self.df2 * 0.5 + ct3 = self.df1 / self.df2 + t1 = (ct1 + ct2).lgamma() - ct1.lgamma() - ct2.lgamma() + t2 = ct1 * ct3.log() + (ct1 - 1) * torch.log(value) + t3 = (ct1 + ct2) * torch.log1p(ct3 * value) + return t1 + t2 - t3 diff --git a/lib/python3.10/site-packages/torch/distributions/gamma.py b/lib/python3.10/site-packages/torch/distributions/gamma.py new file mode 100644 index 0000000000000000000000000000000000000000..97631683d53b2432a974508dfd1daf3868171553 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/gamma.py @@ -0,0 +1,111 @@ +# mypy: allow-untyped-defs +from numbers import Number + +import torch +from torch.distributions import constraints +from torch.distributions.exp_family import ExponentialFamily +from torch.distributions.utils import 
broadcast_all +from torch.types import _size + + +__all__ = ["Gamma"] + + +def _standard_gamma(concentration): + return torch._standard_gamma(concentration) + + +class Gamma(ExponentialFamily): + r""" + Creates a Gamma distribution parameterized by shape :attr:`concentration` and :attr:`rate`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Gamma(torch.tensor([1.0]), torch.tensor([1.0])) + >>> m.sample() # Gamma distributed with concentration=1 and rate=1 + tensor([ 0.1046]) + + Args: + concentration (float or Tensor): shape parameter of the distribution + (often referred to as alpha) + rate (float or Tensor): rate = 1 / scale of the distribution + (often referred to as beta) + """ + arg_constraints = { + "concentration": constraints.positive, + "rate": constraints.positive, + } + support = constraints.nonnegative + has_rsample = True + _mean_carrier_measure = 0 + + @property + def mean(self): + return self.concentration / self.rate + + @property + def mode(self): + return ((self.concentration - 1) / self.rate).clamp(min=0) + + @property + def variance(self): + return self.concentration / self.rate.pow(2) + + def __init__(self, concentration, rate, validate_args=None): + self.concentration, self.rate = broadcast_all(concentration, rate) + if isinstance(concentration, Number) and isinstance(rate, Number): + batch_shape = torch.Size() + else: + batch_shape = self.concentration.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Gamma, _instance) + batch_shape = torch.Size(batch_shape) + new.concentration = self.concentration.expand(batch_shape) + new.rate = self.rate.expand(batch_shape) + super(Gamma, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: + shape = self._extended_shape(sample_shape) + value = _standard_gamma(self.concentration.expand(shape)) / self.rate.expand( + shape + ) + value.detach().clamp_( + min=torch.finfo(value.dtype).tiny + ) # do not record in autograd graph + return value + + def log_prob(self, value): + value = torch.as_tensor(value, dtype=self.rate.dtype, device=self.rate.device) + if self._validate_args: + self._validate_sample(value) + return ( + torch.xlogy(self.concentration, self.rate) + + torch.xlogy(self.concentration - 1, value) + - self.rate * value + - torch.lgamma(self.concentration) + ) + + def entropy(self): + return ( + self.concentration + - torch.log(self.rate) + + torch.lgamma(self.concentration) + + (1.0 - self.concentration) * torch.digamma(self.concentration) + ) + + @property + def _natural_params(self): + return (self.concentration - 1, -self.rate) + + def _log_normalizer(self, x, y): + return torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal()) + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + return torch.special.gammainc(self.concentration, self.rate * value) diff --git a/lib/python3.10/site-packages/torch/distributions/geometric.py b/lib/python3.10/site-packages/torch/distributions/geometric.py new file mode 100644 index 0000000000000000000000000000000000000000..c5f17802411082a55abe8ecc77c83746f11c9e47 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/geometric.py @@ -0,0 +1,130 @@ +# mypy: allow-untyped-defs +from numbers import Number + +import torch +from torch.distributions import constraints +from 
torch.distributions.distribution import Distribution +from torch.distributions.utils import ( + broadcast_all, + lazy_property, + logits_to_probs, + probs_to_logits, +) +from torch.nn.functional import binary_cross_entropy_with_logits + + +__all__ = ["Geometric"] + + +class Geometric(Distribution): + r""" + Creates a Geometric distribution parameterized by :attr:`probs`, + where :attr:`probs` is the probability of success of Bernoulli trials. + + .. math:: + + P(X=k) = (1-p)^{k} p, k = 0, 1, ... + + .. note:: + :func:`torch.distributions.geometric.Geometric` :math:`(k+1)`-th trial is the first success + hence draws samples in :math:`\{0, 1, \ldots\}`, whereas + :func:`torch.Tensor.geometric_` `k`-th trial is the first success hence draws samples in :math:`\{1, 2, \ldots\}`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Geometric(torch.tensor([0.3])) + >>> m.sample() # underlying Bernoulli has 30% chance 1; 70% chance 0 + tensor([ 2.]) + + Args: + probs (Number, Tensor): the probability of sampling `1`. Must be in range (0, 1] + logits (Number, Tensor): the log-odds of sampling `1`. + """ + arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} + support = constraints.nonnegative_integer + + def __init__(self, probs=None, logits=None, validate_args=None): + if (probs is None) == (logits is None): + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) + if probs is not None: + (self.probs,) = broadcast_all(probs) + else: + (self.logits,) = broadcast_all(logits) + probs_or_logits = probs if probs is not None else logits + if isinstance(probs_or_logits, Number): + batch_shape = torch.Size() + else: + batch_shape = probs_or_logits.size() + super().__init__(batch_shape, validate_args=validate_args) + if self._validate_args and probs is not None: + # Add an extra check beyond unit_interval + value = self.probs + valid = value > 0 + if not valid.all(): + invalid_value = value.data[~valid] + raise ValueError( + "Expected parameter probs " + f"({type(value).__name__} of shape {tuple(value.shape)}) " + f"of distribution {repr(self)} " + f"to be positive but found invalid values:\n{invalid_value}" + ) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Geometric, _instance) + batch_shape = torch.Size(batch_shape) + if "probs" in self.__dict__: + new.probs = self.probs.expand(batch_shape) + if "logits" in self.__dict__: + new.logits = self.logits.expand(batch_shape) + super(Geometric, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + @property + def mean(self): + return 1.0 / self.probs - 1.0 + + @property + def mode(self): + return torch.zeros_like(self.probs) + + @property + def variance(self): + return (1.0 / self.probs - 1.0) / self.probs + + @lazy_property + def logits(self): + return probs_to_logits(self.probs, is_binary=True) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits, is_binary=True) + + def sample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + tiny = torch.finfo(self.probs.dtype).tiny + with torch.no_grad(): + if torch._C._get_tracing_state(): + # [JIT WORKAROUND] lack of support for .uniform_() + u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device) + u = u.clamp(min=tiny) + else: + u = self.probs.new(shape).uniform_(tiny, 1) + return (u.log() / (-self.probs).log1p()).floor() + + def log_prob(self, value): + if 
self._validate_args: + self._validate_sample(value) + value, probs = broadcast_all(value, self.probs) + probs = probs.clone(memory_format=torch.contiguous_format) + probs[(probs == 1) & (value == 0)] = 0 + return value * (-probs).log1p() + self.probs.log() + + def entropy(self): + return ( + binary_cross_entropy_with_logits(self.logits, self.probs, reduction="none") + / self.probs + ) diff --git a/lib/python3.10/site-packages/torch/distributions/gumbel.py b/lib/python3.10/site-packages/torch/distributions/gumbel.py new file mode 100644 index 0000000000000000000000000000000000000000..782aec9b350a52865f7b5fb0ac400640df0c5fe1 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/gumbel.py @@ -0,0 +1,83 @@ +# mypy: allow-untyped-defs +import math +from numbers import Number + +import torch +from torch.distributions import constraints +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import AffineTransform, ExpTransform +from torch.distributions.uniform import Uniform +from torch.distributions.utils import broadcast_all, euler_constant + + +__all__ = ["Gumbel"] + + +class Gumbel(TransformedDistribution): + r""" + Samples from a Gumbel Distribution. + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0])) + >>> m.sample() # sample from Gumbel distribution with loc=1, scale=2 + tensor([ 1.0124]) + + Args: + loc (float or Tensor): Location parameter of the distribution + scale (float or Tensor): Scale parameter of the distribution + """ + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} + support = constraints.real + + def __init__(self, loc, scale, validate_args=None): + self.loc, self.scale = broadcast_all(loc, scale) + finfo = torch.finfo(self.loc.dtype) + if isinstance(loc, Number) and isinstance(scale, Number): + base_dist = Uniform(finfo.tiny, 1 - finfo.eps, validate_args=validate_args) + else: + base_dist = Uniform( + torch.full_like(self.loc, finfo.tiny), + torch.full_like(self.loc, 1 - finfo.eps), + validate_args=validate_args, + ) + transforms = [ + ExpTransform().inv, + AffineTransform(loc=0, scale=-torch.ones_like(self.scale)), + ExpTransform().inv, + AffineTransform(loc=loc, scale=-self.scale), + ] + super().__init__(base_dist, transforms, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Gumbel, _instance) + new.loc = self.loc.expand(batch_shape) + new.scale = self.scale.expand(batch_shape) + return super().expand(batch_shape, _instance=new) + + # Explicitly defining the log probability function for Gumbel due to precision issues + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + y = (self.loc - value) / self.scale + return (y - y.exp()) - self.scale.log() + + @property + def mean(self): + return self.loc + self.scale * euler_constant + + @property + def mode(self): + return self.loc + + @property + def stddev(self): + return (math.pi / math.sqrt(6)) * self.scale + + @property + def variance(self): + return self.stddev.pow(2) + + def entropy(self): + return self.scale.log() + (1 + euler_constant) diff --git a/lib/python3.10/site-packages/torch/distributions/half_cauchy.py b/lib/python3.10/site-packages/torch/distributions/half_cauchy.py new file mode 100644 index 0000000000000000000000000000000000000000..17f48c45cf271ef2c7b4d42ccebf79c4397be09e --- /dev/null +++ 
b/lib/python3.10/site-packages/torch/distributions/half_cauchy.py @@ -0,0 +1,84 @@ +# mypy: allow-untyped-defs +import math + +import torch +from torch import inf +from torch.distributions import constraints +from torch.distributions.cauchy import Cauchy +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import AbsTransform + + +__all__ = ["HalfCauchy"] + + +class HalfCauchy(TransformedDistribution): + r""" + Creates a half-Cauchy distribution parameterized by `scale` where:: + + X ~ Cauchy(0, scale) + Y = |X| ~ HalfCauchy(scale) + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = HalfCauchy(torch.tensor([1.0])) + >>> m.sample() # half-cauchy distributed with scale=1 + tensor([ 2.3214]) + + Args: + scale (float or Tensor): scale of the full Cauchy distribution + """ + arg_constraints = {"scale": constraints.positive} + support = constraints.nonnegative + has_rsample = True + + def __init__(self, scale, validate_args=None): + base_dist = Cauchy(0, scale, validate_args=False) + super().__init__(base_dist, AbsTransform(), validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(HalfCauchy, _instance) + return super().expand(batch_shape, _instance=new) + + @property + def scale(self): + return self.base_dist.scale + + @property + def mean(self): + return torch.full( + self._extended_shape(), + math.inf, + dtype=self.scale.dtype, + device=self.scale.device, + ) + + @property + def mode(self): + return torch.zeros_like(self.scale) + + @property + def variance(self): + return self.base_dist.variance + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + value = torch.as_tensor( + value, dtype=self.base_dist.scale.dtype, device=self.base_dist.scale.device + ) + log_prob = self.base_dist.log_prob(value) + math.log(2) + log_prob = torch.where(value >= 0, log_prob, -inf) + return log_prob + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + return 2 * self.base_dist.cdf(value) - 1 + + def icdf(self, prob): + return self.base_dist.icdf((prob + 1) / 2) + + def entropy(self): + return self.base_dist.entropy() - math.log(2) diff --git a/lib/python3.10/site-packages/torch/distributions/half_normal.py b/lib/python3.10/site-packages/torch/distributions/half_normal.py new file mode 100644 index 0000000000000000000000000000000000000000..4031c34bbb578e633580167b7cdbaefdb6c7e2ac --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/half_normal.py @@ -0,0 +1,76 @@ +# mypy: allow-untyped-defs +import math + +import torch +from torch import inf +from torch.distributions import constraints +from torch.distributions.normal import Normal +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import AbsTransform + + +__all__ = ["HalfNormal"] + + +class HalfNormal(TransformedDistribution): + r""" + Creates a half-normal distribution parameterized by `scale` where:: + + X ~ Normal(0, scale) + Y = |X| ~ HalfNormal(scale) + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = HalfNormal(torch.tensor([1.0])) + >>> m.sample() # half-normal distributed with scale=1 + tensor([ 0.1046]) + + Args: + scale (float or Tensor): scale of the full Normal distribution + """ + arg_constraints = {"scale": constraints.positive} + support = constraints.nonnegative + has_rsample = True + + def __init__(self, scale, 
validate_args=None): + base_dist = Normal(0, scale, validate_args=False) + super().__init__(base_dist, AbsTransform(), validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(HalfNormal, _instance) + return super().expand(batch_shape, _instance=new) + + @property + def scale(self): + return self.base_dist.scale + + @property + def mean(self): + return self.scale * math.sqrt(2 / math.pi) + + @property + def mode(self): + return torch.zeros_like(self.scale) + + @property + def variance(self): + return self.scale.pow(2) * (1 - 2 / math.pi) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + log_prob = self.base_dist.log_prob(value) + math.log(2) + log_prob = torch.where(value >= 0, log_prob, -inf) + return log_prob + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + return 2 * self.base_dist.cdf(value) - 1 + + def icdf(self, prob): + return self.base_dist.icdf((prob + 1) / 2) + + def entropy(self): + return self.base_dist.entropy() - math.log(2) diff --git a/lib/python3.10/site-packages/torch/distributions/independent.py b/lib/python3.10/site-packages/torch/distributions/independent.py new file mode 100644 index 0000000000000000000000000000000000000000..edf740138ef8743ae4833930bfd703eb90197672 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/independent.py @@ -0,0 +1,128 @@ +# mypy: allow-untyped-defs +from typing import Dict + +import torch +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import _sum_rightmost +from torch.types import _size + + +__all__ = ["Independent"] + + +class Independent(Distribution): + r""" + Reinterprets some of the batch dims of a distribution as event dims. + + This is mainly useful for changing the shape of the result of + :meth:`log_prob`. 
For example to create a diagonal Normal distribution with + the same shape as a Multivariate Normal distribution (so they are + interchangeable), you can:: + + >>> from torch.distributions.multivariate_normal import MultivariateNormal + >>> from torch.distributions.normal import Normal + >>> loc = torch.zeros(3) + >>> scale = torch.ones(3) + >>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale)) + >>> [mvn.batch_shape, mvn.event_shape] + [torch.Size([]), torch.Size([3])] + >>> normal = Normal(loc, scale) + >>> [normal.batch_shape, normal.event_shape] + [torch.Size([3]), torch.Size([])] + >>> diagn = Independent(normal, 1) + >>> [diagn.batch_shape, diagn.event_shape] + [torch.Size([]), torch.Size([3])] + + Args: + base_distribution (torch.distributions.distribution.Distribution): a + base distribution + reinterpreted_batch_ndims (int): the number of batch dims to + reinterpret as event dims + """ + arg_constraints: Dict[str, constraints.Constraint] = {} + + def __init__( + self, base_distribution, reinterpreted_batch_ndims, validate_args=None + ): + if reinterpreted_batch_ndims > len(base_distribution.batch_shape): + raise ValueError( + "Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), " + f"actual {reinterpreted_batch_ndims} vs {len(base_distribution.batch_shape)}" + ) + shape = base_distribution.batch_shape + base_distribution.event_shape + event_dim = reinterpreted_batch_ndims + len(base_distribution.event_shape) + batch_shape = shape[: len(shape) - event_dim] + event_shape = shape[len(shape) - event_dim :] + self.base_dist = base_distribution + self.reinterpreted_batch_ndims = reinterpreted_batch_ndims + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Independent, _instance) + batch_shape = torch.Size(batch_shape) + new.base_dist = self.base_dist.expand( + batch_shape + self.event_shape[: self.reinterpreted_batch_ndims] + ) + new.reinterpreted_batch_ndims = self.reinterpreted_batch_ndims + super(Independent, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + @property + def has_rsample(self): + return self.base_dist.has_rsample + + @property + def has_enumerate_support(self): + if self.reinterpreted_batch_ndims > 0: + return False + return self.base_dist.has_enumerate_support + + @constraints.dependent_property + def support(self): + result = self.base_dist.support + if self.reinterpreted_batch_ndims: + result = constraints.independent(result, self.reinterpreted_batch_ndims) + return result + + @property + def mean(self): + return self.base_dist.mean + + @property + def mode(self): + return self.base_dist.mode + + @property + def variance(self): + return self.base_dist.variance + + def sample(self, sample_shape=torch.Size()): + return self.base_dist.sample(sample_shape) + + def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: + return self.base_dist.rsample(sample_shape) + + def log_prob(self, value): + log_prob = self.base_dist.log_prob(value) + return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims) + + def entropy(self): + entropy = self.base_dist.entropy() + return _sum_rightmost(entropy, self.reinterpreted_batch_ndims) + + def enumerate_support(self, expand=True): + if self.reinterpreted_batch_ndims > 0: + raise NotImplementedError( + "Enumeration over cartesian product is not implemented" + ) + return 
self.base_dist.enumerate_support(expand=expand) + + def __repr__(self): + return ( + self.__class__.__name__ + + f"({self.base_dist}, {self.reinterpreted_batch_ndims})" + ) diff --git a/lib/python3.10/site-packages/torch/distributions/inverse_gamma.py b/lib/python3.10/site-packages/torch/distributions/inverse_gamma.py new file mode 100644 index 0000000000000000000000000000000000000000..cff64d0a9e4956c6c2267dd10f921bb9e4fa6f4c --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/inverse_gamma.py @@ -0,0 +1,81 @@ +# mypy: allow-untyped-defs +import torch +from torch.distributions import constraints +from torch.distributions.gamma import Gamma +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import PowerTransform + + +__all__ = ["InverseGamma"] + + +class InverseGamma(TransformedDistribution): + r""" + Creates an inverse gamma distribution parameterized by :attr:`concentration` and :attr:`rate` + where:: + + X ~ Gamma(concentration, rate) + Y = 1 / X ~ InverseGamma(concentration, rate) + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterinistic") + >>> m = InverseGamma(torch.tensor([2.0]), torch.tensor([3.0])) + >>> m.sample() + tensor([ 1.2953]) + + Args: + concentration (float or Tensor): shape parameter of the distribution + (often referred to as alpha) + rate (float or Tensor): rate = 1 / scale of the distribution + (often referred to as beta) + """ + arg_constraints = { + "concentration": constraints.positive, + "rate": constraints.positive, + } + support = constraints.positive + has_rsample = True + + def __init__(self, concentration, rate, validate_args=None): + base_dist = Gamma(concentration, rate, validate_args=validate_args) + neg_one = -base_dist.rate.new_ones(()) + super().__init__( + base_dist, PowerTransform(neg_one), validate_args=validate_args + ) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(InverseGamma, _instance) + return super().expand(batch_shape, _instance=new) + + @property + def concentration(self): + return self.base_dist.concentration + + @property + def rate(self): + return self.base_dist.rate + + @property + def mean(self): + result = self.rate / (self.concentration - 1) + return torch.where(self.concentration > 1, result, torch.inf) + + @property + def mode(self): + return self.rate / (self.concentration + 1) + + @property + def variance(self): + result = self.rate.square() / ( + (self.concentration - 1).square() * (self.concentration - 2) + ) + return torch.where(self.concentration > 2, result, torch.inf) + + def entropy(self): + return ( + self.concentration + + self.rate.log() + + self.concentration.lgamma() + - (1 + self.concentration) * self.concentration.digamma() + ) diff --git a/lib/python3.10/site-packages/torch/distributions/kl.py b/lib/python3.10/site-packages/torch/distributions/kl.py new file mode 100644 index 0000000000000000000000000000000000000000..c94c711cc8b364c68daf5d34830b4756aaced93c --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/kl.py @@ -0,0 +1,972 @@ +# mypy: allow-untyped-defs +import math +import warnings +from functools import total_ordering +from typing import Callable, Dict, Tuple, Type + +import torch +from torch import inf + +from .bernoulli import Bernoulli +from .beta import Beta +from .binomial import Binomial +from .categorical import Categorical +from .cauchy import Cauchy +from .continuous_bernoulli import ContinuousBernoulli +from .dirichlet import Dirichlet +from 
.distribution import Distribution +from .exp_family import ExponentialFamily +from .exponential import Exponential +from .gamma import Gamma +from .geometric import Geometric +from .gumbel import Gumbel +from .half_normal import HalfNormal +from .independent import Independent +from .laplace import Laplace +from .lowrank_multivariate_normal import ( + _batch_lowrank_logdet, + _batch_lowrank_mahalanobis, + LowRankMultivariateNormal, +) +from .multivariate_normal import _batch_mahalanobis, MultivariateNormal +from .normal import Normal +from .one_hot_categorical import OneHotCategorical +from .pareto import Pareto +from .poisson import Poisson +from .transformed_distribution import TransformedDistribution +from .uniform import Uniform +from .utils import _sum_rightmost, euler_constant as _euler_gamma + + +_KL_REGISTRY: Dict[ + Tuple[Type, Type], Callable +] = {} # Source of truth mapping a few general (type, type) pairs to functions. +_KL_MEMOIZE: Dict[ + Tuple[Type, Type], Callable +] = {} # Memoized version mapping many specific (type, type) pairs to functions. + +__all__ = ["register_kl", "kl_divergence"] + + +def register_kl(type_p, type_q): + """ + Decorator to register a pairwise function with :meth:`kl_divergence`. + Usage:: + + @register_kl(Normal, Normal) + def kl_normal_normal(p, q): + # insert implementation here + + Lookup returns the most specific (type,type) match ordered by subclass. If + the match is ambiguous, a `RuntimeWarning` is raised. For example to + resolve the ambiguous situation:: + + @register_kl(BaseP, DerivedQ) + def kl_version1(p, q): ... + @register_kl(DerivedP, BaseQ) + def kl_version2(p, q): ... + + you should register a third most-specific implementation, e.g.:: + + register_kl(DerivedP, DerivedQ)(kl_version1) # Break the tie. + + Args: + type_p (type): A subclass of :class:`~torch.distributions.Distribution`. + type_q (type): A subclass of :class:`~torch.distributions.Distribution`. + """ + if not isinstance(type_p, type) and issubclass(type_p, Distribution): + raise TypeError( + f"Expected type_p to be a Distribution subclass but got {type_p}" + ) + if not isinstance(type_q, type) and issubclass(type_q, Distribution): + raise TypeError( + f"Expected type_q to be a Distribution subclass but got {type_q}" + ) + + def decorator(fun): + _KL_REGISTRY[type_p, type_q] = fun + _KL_MEMOIZE.clear() # reset since lookup order may have changed + return fun + + return decorator + + +@total_ordering +class _Match: + __slots__ = ["types"] + + def __init__(self, *types): + self.types = types + + def __eq__(self, other): + return self.types == other.types + + def __le__(self, other): + for x, y in zip(self.types, other.types): + if not issubclass(x, y): + return False + if x is not y: + break + return True + + +def _dispatch_kl(type_p, type_q): + """ + Find the most specific approximate match, assuming single inheritance. + """ + matches = [ + (super_p, super_q) + for super_p, super_q in _KL_REGISTRY + if issubclass(type_p, super_p) and issubclass(type_q, super_q) + ] + if not matches: + return NotImplemented + # Check that the left- and right- lexicographic orders agree. 
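+ # That is, the most specific match found when ordering on type_p first and the
+ # most specific match found when ordering on type_q first must resolve to the
+ # same registered function; otherwise the registration is ambiguous and a
+ # RuntimeWarning is emitted below.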
+ # mypy isn't smart enough to know that _Match implements __lt__ + # see: https://github.com/python/typing/issues/760#issuecomment-710670503 + left_p, left_q = min(_Match(*m) for m in matches).types # type: ignore[type-var] + right_q, right_p = min(_Match(*reversed(m)) for m in matches).types # type: ignore[type-var] + left_fun = _KL_REGISTRY[left_p, left_q] + right_fun = _KL_REGISTRY[right_p, right_q] + if left_fun is not right_fun: + warnings.warn( + f"Ambiguous kl_divergence({type_p.__name__}, {type_q.__name__}). " + f"Please register_kl({left_p.__name__}, {right_q.__name__})", + RuntimeWarning, + ) + return left_fun + + +def _infinite_like(tensor): + """ + Helper function for obtaining infinite KL Divergence throughout + """ + return torch.full_like(tensor, inf) + + +def _x_log_x(tensor): + """ + Utility function for calculating x log x + """ + return tensor * tensor.log() + + +def _batch_trace_XXT(bmat): + """ + Utility function for calculating the trace of XX^{T} with X having arbitrary trailing batch dimensions + """ + n = bmat.size(-1) + m = bmat.size(-2) + flat_trace = bmat.reshape(-1, m * n).pow(2).sum(-1) + return flat_trace.reshape(bmat.shape[:-2]) + + +def kl_divergence(p: Distribution, q: Distribution) -> torch.Tensor: + r""" + Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions. + + .. math:: + + KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx + + Args: + p (Distribution): A :class:`~torch.distributions.Distribution` object. + q (Distribution): A :class:`~torch.distributions.Distribution` object. + + Returns: + Tensor: A batch of KL divergences of shape `batch_shape`. + + Raises: + NotImplementedError: If the distribution types have not been registered via + :meth:`register_kl`. + """ + try: + fun = _KL_MEMOIZE[type(p), type(q)] + except KeyError: + fun = _dispatch_kl(type(p), type(q)) + _KL_MEMOIZE[type(p), type(q)] = fun + if fun is NotImplemented: + raise NotImplementedError( + f"No KL(p || q) is implemented for p type {p.__class__.__name__} and q type {q.__class__.__name__}" + ) + return fun(p, q) + + +################################################################################ +# KL Divergence Implementations +################################################################################ + +# Same distributions + + +@register_kl(Bernoulli, Bernoulli) +def _kl_bernoulli_bernoulli(p, q): + t1 = p.probs * ( + torch.nn.functional.softplus(-q.logits) + - torch.nn.functional.softplus(-p.logits) + ) + t1[q.probs == 0] = inf + t1[p.probs == 0] = 0 + t2 = (1 - p.probs) * ( + torch.nn.functional.softplus(q.logits) - torch.nn.functional.softplus(p.logits) + ) + t2[q.probs == 1] = inf + t2[p.probs == 1] = 0 + return t1 + t2 + + +@register_kl(Beta, Beta) +def _kl_beta_beta(p, q): + sum_params_p = p.concentration1 + p.concentration0 + sum_params_q = q.concentration1 + q.concentration0 + t1 = q.concentration1.lgamma() + q.concentration0.lgamma() + (sum_params_p).lgamma() + t2 = p.concentration1.lgamma() + p.concentration0.lgamma() + (sum_params_q).lgamma() + t3 = (p.concentration1 - q.concentration1) * torch.digamma(p.concentration1) + t4 = (p.concentration0 - q.concentration0) * torch.digamma(p.concentration0) + t5 = (sum_params_q - sum_params_p) * torch.digamma(sum_params_p) + return t1 - t2 + t3 + t4 + t5 + + +@register_kl(Binomial, Binomial) +def _kl_binomial_binomial(p, q): + # from https://math.stackexchange.com/questions/2214993/ + # kullback-leibler-divergence-for-binomial-distributions-p-and-q + if (p.total_count < q.total_count).any(): 
+ raise NotImplementedError( + "KL between Binomials where q.total_count > p.total_count is not implemented" + ) + kl = p.total_count * ( + p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p() + ) + inf_idxs = p.total_count > q.total_count + kl[inf_idxs] = _infinite_like(kl[inf_idxs]) + return kl + + +@register_kl(Categorical, Categorical) +def _kl_categorical_categorical(p, q): + t = p.probs * (p.logits - q.logits) + t[(q.probs == 0).expand_as(t)] = inf + t[(p.probs == 0).expand_as(t)] = 0 + return t.sum(-1) + + +@register_kl(ContinuousBernoulli, ContinuousBernoulli) +def _kl_continuous_bernoulli_continuous_bernoulli(p, q): + t1 = p.mean * (p.logits - q.logits) + t2 = p._cont_bern_log_norm() + torch.log1p(-p.probs) + t3 = -q._cont_bern_log_norm() - torch.log1p(-q.probs) + return t1 + t2 + t3 + + +@register_kl(Dirichlet, Dirichlet) +def _kl_dirichlet_dirichlet(p, q): + # From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/ + sum_p_concentration = p.concentration.sum(-1) + sum_q_concentration = q.concentration.sum(-1) + t1 = sum_p_concentration.lgamma() - sum_q_concentration.lgamma() + t2 = (p.concentration.lgamma() - q.concentration.lgamma()).sum(-1) + t3 = p.concentration - q.concentration + t4 = p.concentration.digamma() - sum_p_concentration.digamma().unsqueeze(-1) + return t1 - t2 + (t3 * t4).sum(-1) + + +@register_kl(Exponential, Exponential) +def _kl_exponential_exponential(p, q): + rate_ratio = q.rate / p.rate + t1 = -rate_ratio.log() + return t1 + rate_ratio - 1 + + +@register_kl(ExponentialFamily, ExponentialFamily) +def _kl_expfamily_expfamily(p, q): + if not type(p) == type(q): + raise NotImplementedError( + "The cross KL-divergence between different exponential families cannot \ + be computed using Bregman divergences" + ) + p_nparams = [np.detach().requires_grad_() for np in p._natural_params] + q_nparams = q._natural_params + lg_normal = p._log_normalizer(*p_nparams) + gradients = torch.autograd.grad(lg_normal.sum(), p_nparams, create_graph=True) + result = q._log_normalizer(*q_nparams) - lg_normal + for pnp, qnp, g in zip(p_nparams, q_nparams, gradients): + term = (qnp - pnp) * g + result -= _sum_rightmost(term, len(q.event_shape)) + return result + + +@register_kl(Gamma, Gamma) +def _kl_gamma_gamma(p, q): + t1 = q.concentration * (p.rate / q.rate).log() + t2 = torch.lgamma(q.concentration) - torch.lgamma(p.concentration) + t3 = (p.concentration - q.concentration) * torch.digamma(p.concentration) + t4 = (q.rate - p.rate) * (p.concentration / p.rate) + return t1 + t2 + t3 + t4 + + +@register_kl(Gumbel, Gumbel) +def _kl_gumbel_gumbel(p, q): + ct1 = p.scale / q.scale + ct2 = q.loc / q.scale + ct3 = p.loc / q.scale + t1 = -ct1.log() - ct2 + ct3 + t2 = ct1 * _euler_gamma + t3 = torch.exp(ct2 + (1 + ct1).lgamma() - ct3) + return t1 + t2 + t3 - (1 + _euler_gamma) + + +@register_kl(Geometric, Geometric) +def _kl_geometric_geometric(p, q): + return -p.entropy() - torch.log1p(-q.probs) / p.probs - q.logits + + +@register_kl(HalfNormal, HalfNormal) +def _kl_halfnormal_halfnormal(p, q): + return _kl_normal_normal(p.base_dist, q.base_dist) + + +@register_kl(Laplace, Laplace) +def _kl_laplace_laplace(p, q): + # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf + scale_ratio = p.scale / q.scale + loc_abs_diff = (p.loc - q.loc).abs() + t1 = -scale_ratio.log() + t2 = loc_abs_diff / q.scale + t3 = scale_ratio * torch.exp(-loc_abs_diff / p.scale) + return t1 + t2 + t3 - 1 + + 
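+# Notation for the low-rank Gaussian KL terms below: a LowRankMultivariateNormal
+# has covariance W @ W.T + D with W = cov_factor and D = diag(cov_diag), and
+# capacitance C = I + W.T @ inv(D) @ W (stored as its Cholesky factor
+# _capacitance_tril).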
+@register_kl(LowRankMultivariateNormal, LowRankMultivariateNormal) +def _kl_lowrankmultivariatenormal_lowrankmultivariatenormal(p, q): + if p.event_shape != q.event_shape: + raise ValueError( + "KL-divergence between two Low Rank Multivariate Normals with\ + different event shapes cannot be computed" + ) + + term1 = _batch_lowrank_logdet( + q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril + ) - _batch_lowrank_logdet( + p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril + ) + term3 = _batch_lowrank_mahalanobis( + q._unbroadcasted_cov_factor, + q._unbroadcasted_cov_diag, + q.loc - p.loc, + q._capacitance_tril, + ) + # Expands term2 according to + # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ (pW @ pW.T + pD) + # = [inv(qD) - A.T @ A] @ (pD + pW @ pW.T) + qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2) + A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False) + term21 = (p._unbroadcasted_cov_diag / q._unbroadcasted_cov_diag).sum(-1) + term22 = _batch_trace_XXT( + p._unbroadcasted_cov_factor * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1) + ) + term23 = _batch_trace_XXT(A * p._unbroadcasted_cov_diag.sqrt().unsqueeze(-2)) + term24 = _batch_trace_XXT(A.matmul(p._unbroadcasted_cov_factor)) + term2 = term21 + term22 - term23 - term24 + return 0.5 * (term1 + term2 + term3 - p.event_shape[0]) + + +@register_kl(MultivariateNormal, LowRankMultivariateNormal) +def _kl_multivariatenormal_lowrankmultivariatenormal(p, q): + if p.event_shape != q.event_shape: + raise ValueError( + "KL-divergence between two (Low Rank) Multivariate Normals with\ + different event shapes cannot be computed" + ) + + term1 = _batch_lowrank_logdet( + q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril + ) - 2 * p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + term3 = _batch_lowrank_mahalanobis( + q._unbroadcasted_cov_factor, + q._unbroadcasted_cov_diag, + q.loc - p.loc, + q._capacitance_tril, + ) + # Expands term2 according to + # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ p_tril @ p_tril.T + # = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T + qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2) + A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False) + term21 = _batch_trace_XXT( + p._unbroadcasted_scale_tril * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1) + ) + term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril)) + term2 = term21 - term22 + return 0.5 * (term1 + term2 + term3 - p.event_shape[0]) + + +@register_kl(LowRankMultivariateNormal, MultivariateNormal) +def _kl_lowrankmultivariatenormal_multivariatenormal(p, q): + if p.event_shape != q.event_shape: + raise ValueError( + "KL-divergence between two (Low Rank) Multivariate Normals with\ + different event shapes cannot be computed" + ) + + term1 = 2 * q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum( + -1 + ) - _batch_lowrank_logdet( + p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril + ) + term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc)) + # Expands term2 according to + # inv(qcov) @ pcov = inv(q_tril @ q_tril.T) @ (pW @ pW.T + pD) + combined_batch_shape = torch._C._infer_size( + q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_cov_factor.shape[:-2] + ) + n = p.event_shape[0] + q_scale_tril = 
q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n)) + p_cov_factor = p._unbroadcasted_cov_factor.expand( + combined_batch_shape + (n, p.cov_factor.size(-1)) + ) + p_cov_diag = torch.diag_embed(p._unbroadcasted_cov_diag.sqrt()).expand( + combined_batch_shape + (n, n) + ) + term21 = _batch_trace_XXT( + torch.linalg.solve_triangular(q_scale_tril, p_cov_factor, upper=False) + ) + term22 = _batch_trace_XXT( + torch.linalg.solve_triangular(q_scale_tril, p_cov_diag, upper=False) + ) + term2 = term21 + term22 + return 0.5 * (term1 + term2 + term3 - p.event_shape[0]) + + +@register_kl(MultivariateNormal, MultivariateNormal) +def _kl_multivariatenormal_multivariatenormal(p, q): + # From https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback%E2%80%93Leibler_divergence + if p.event_shape != q.event_shape: + raise ValueError( + "KL-divergence between two Multivariate Normals with\ + different event shapes cannot be computed" + ) + + half_term1 = q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum( + -1 + ) - p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + combined_batch_shape = torch._C._infer_size( + q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_scale_tril.shape[:-2] + ) + n = p.event_shape[0] + q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n)) + p_scale_tril = p._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n)) + term2 = _batch_trace_XXT( + torch.linalg.solve_triangular(q_scale_tril, p_scale_tril, upper=False) + ) + term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc)) + return half_term1 + 0.5 * (term2 + term3 - n) + + +@register_kl(Normal, Normal) +def _kl_normal_normal(p, q): + var_ratio = (p.scale / q.scale).pow(2) + t1 = ((p.loc - q.loc) / q.scale).pow(2) + return 0.5 * (var_ratio + t1 - 1 - var_ratio.log()) + + +@register_kl(OneHotCategorical, OneHotCategorical) +def _kl_onehotcategorical_onehotcategorical(p, q): + return _kl_categorical_categorical(p._categorical, q._categorical) + + +@register_kl(Pareto, Pareto) +def _kl_pareto_pareto(p, q): + # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf + scale_ratio = p.scale / q.scale + alpha_ratio = q.alpha / p.alpha + t1 = q.alpha * scale_ratio.log() + t2 = -alpha_ratio.log() + result = t1 + t2 + alpha_ratio - 1 + result[p.support.lower_bound < q.support.lower_bound] = inf + return result + + +@register_kl(Poisson, Poisson) +def _kl_poisson_poisson(p, q): + return p.rate * (p.rate.log() - q.rate.log()) - (p.rate - q.rate) + + +@register_kl(TransformedDistribution, TransformedDistribution) +def _kl_transformed_transformed(p, q): + if p.transforms != q.transforms: + raise NotImplementedError + if p.event_shape != q.event_shape: + raise NotImplementedError + return kl_divergence(p.base_dist, q.base_dist) + + +@register_kl(Uniform, Uniform) +def _kl_uniform_uniform(p, q): + result = ((q.high - q.low) / (p.high - p.low)).log() + result[(q.low > p.low) | (q.high < p.high)] = inf + return result + + +# Different distributions +@register_kl(Bernoulli, Poisson) +def _kl_bernoulli_poisson(p, q): + return -p.entropy() - (p.probs * q.rate.log() - q.rate) + + +@register_kl(Beta, ContinuousBernoulli) +def _kl_beta_continuous_bernoulli(p, q): + return ( + -p.entropy() + - p.mean * q.logits + - torch.log1p(-q.probs) + - q._cont_bern_log_norm() + ) + + +@register_kl(Beta, Pareto) +def _kl_beta_infinity(p, q): + return _infinite_like(p.concentration1) + + +@register_kl(Beta, Exponential) +def 
_kl_beta_exponential(p, q): + return ( + -p.entropy() + - q.rate.log() + + q.rate * (p.concentration1 / (p.concentration1 + p.concentration0)) + ) + + +@register_kl(Beta, Gamma) +def _kl_beta_gamma(p, q): + t1 = -p.entropy() + t2 = q.concentration.lgamma() - q.concentration * q.rate.log() + t3 = (q.concentration - 1) * ( + p.concentration1.digamma() - (p.concentration1 + p.concentration0).digamma() + ) + t4 = q.rate * p.concentration1 / (p.concentration1 + p.concentration0) + return t1 + t2 - t3 + t4 + + +# TODO: Add Beta-Laplace KL Divergence + + +@register_kl(Beta, Normal) +def _kl_beta_normal(p, q): + E_beta = p.concentration1 / (p.concentration1 + p.concentration0) + var_normal = q.scale.pow(2) + t1 = -p.entropy() + t2 = 0.5 * (var_normal * 2 * math.pi).log() + t3 = ( + E_beta * (1 - E_beta) / (p.concentration1 + p.concentration0 + 1) + + E_beta.pow(2) + ) * 0.5 + t4 = q.loc * E_beta + t5 = q.loc.pow(2) * 0.5 + return t1 + t2 + (t3 - t4 + t5) / var_normal + + +@register_kl(Beta, Uniform) +def _kl_beta_uniform(p, q): + result = -p.entropy() + (q.high - q.low).log() + result[(q.low > p.support.lower_bound) | (q.high < p.support.upper_bound)] = inf + return result + + +# Note that the KL between a ContinuousBernoulli and Beta has no closed form + + +@register_kl(ContinuousBernoulli, Pareto) +def _kl_continuous_bernoulli_infinity(p, q): + return _infinite_like(p.probs) + + +@register_kl(ContinuousBernoulli, Exponential) +def _kl_continuous_bernoulli_exponential(p, q): + return -p.entropy() - torch.log(q.rate) + q.rate * p.mean + + +# Note that the KL between a ContinuousBernoulli and Gamma has no closed form +# TODO: Add ContinuousBernoulli-Laplace KL Divergence + + +@register_kl(ContinuousBernoulli, Normal) +def _kl_continuous_bernoulli_normal(p, q): + t1 = -p.entropy() + t2 = 0.5 * (math.log(2.0 * math.pi) + torch.square(q.loc / q.scale)) + torch.log( + q.scale + ) + t3 = (p.variance + torch.square(p.mean) - 2.0 * q.loc * p.mean) / ( + 2.0 * torch.square(q.scale) + ) + return t1 + t2 + t3 + + +@register_kl(ContinuousBernoulli, Uniform) +def _kl_continuous_bernoulli_uniform(p, q): + result = -p.entropy() + (q.high - q.low).log() + return torch.where( + torch.max( + torch.ge(q.low, p.support.lower_bound), + torch.le(q.high, p.support.upper_bound), + ), + torch.ones_like(result) * inf, + result, + ) + + +@register_kl(Exponential, Beta) +@register_kl(Exponential, ContinuousBernoulli) +@register_kl(Exponential, Pareto) +@register_kl(Exponential, Uniform) +def _kl_exponential_infinity(p, q): + return _infinite_like(p.rate) + + +@register_kl(Exponential, Gamma) +def _kl_exponential_gamma(p, q): + ratio = q.rate / p.rate + t1 = -q.concentration * torch.log(ratio) + return ( + t1 + + ratio + + q.concentration.lgamma() + + q.concentration * _euler_gamma + - (1 + _euler_gamma) + ) + + +@register_kl(Exponential, Gumbel) +def _kl_exponential_gumbel(p, q): + scale_rate_prod = p.rate * q.scale + loc_scale_ratio = q.loc / q.scale + t1 = scale_rate_prod.log() - 1 + t2 = torch.exp(loc_scale_ratio) * scale_rate_prod / (scale_rate_prod + 1) + t3 = scale_rate_prod.reciprocal() + return t1 - loc_scale_ratio + t2 + t3 + + +# TODO: Add Exponential-Laplace KL Divergence + + +@register_kl(Exponential, Normal) +def _kl_exponential_normal(p, q): + var_normal = q.scale.pow(2) + rate_sqr = p.rate.pow(2) + t1 = 0.5 * torch.log(rate_sqr * var_normal * 2 * math.pi) + t2 = rate_sqr.reciprocal() + t3 = q.loc / p.rate + t4 = q.loc.pow(2) * 0.5 + return t1 - 1 + (t2 - t3 + t4) / var_normal + + +@register_kl(Gamma, Beta) 
+@register_kl(Gamma, ContinuousBernoulli) +@register_kl(Gamma, Pareto) +@register_kl(Gamma, Uniform) +def _kl_gamma_infinity(p, q): + return _infinite_like(p.concentration) + + +@register_kl(Gamma, Exponential) +def _kl_gamma_exponential(p, q): + return -p.entropy() - q.rate.log() + q.rate * p.concentration / p.rate + + +@register_kl(Gamma, Gumbel) +def _kl_gamma_gumbel(p, q): + beta_scale_prod = p.rate * q.scale + loc_scale_ratio = q.loc / q.scale + t1 = ( + (p.concentration - 1) * p.concentration.digamma() + - p.concentration.lgamma() + - p.concentration + ) + t2 = beta_scale_prod.log() + p.concentration / beta_scale_prod + t3 = ( + torch.exp(loc_scale_ratio) + * (1 + beta_scale_prod.reciprocal()).pow(-p.concentration) + - loc_scale_ratio + ) + return t1 + t2 + t3 + + +# TODO: Add Gamma-Laplace KL Divergence + + +@register_kl(Gamma, Normal) +def _kl_gamma_normal(p, q): + var_normal = q.scale.pow(2) + beta_sqr = p.rate.pow(2) + t1 = ( + 0.5 * torch.log(beta_sqr * var_normal * 2 * math.pi) + - p.concentration + - p.concentration.lgamma() + ) + t2 = 0.5 * (p.concentration.pow(2) + p.concentration) / beta_sqr + t3 = q.loc * p.concentration / p.rate + t4 = 0.5 * q.loc.pow(2) + return ( + t1 + + (p.concentration - 1) * p.concentration.digamma() + + (t2 - t3 + t4) / var_normal + ) + + +@register_kl(Gumbel, Beta) +@register_kl(Gumbel, ContinuousBernoulli) +@register_kl(Gumbel, Exponential) +@register_kl(Gumbel, Gamma) +@register_kl(Gumbel, Pareto) +@register_kl(Gumbel, Uniform) +def _kl_gumbel_infinity(p, q): + return _infinite_like(p.loc) + + +# TODO: Add Gumbel-Laplace KL Divergence + + +@register_kl(Gumbel, Normal) +def _kl_gumbel_normal(p, q): + param_ratio = p.scale / q.scale + t1 = (param_ratio / math.sqrt(2 * math.pi)).log() + t2 = (math.pi * param_ratio * 0.5).pow(2) / 3 + t3 = ((p.loc + p.scale * _euler_gamma - q.loc) / q.scale).pow(2) * 0.5 + return -t1 + t2 + t3 - (_euler_gamma + 1) + + +@register_kl(Laplace, Beta) +@register_kl(Laplace, ContinuousBernoulli) +@register_kl(Laplace, Exponential) +@register_kl(Laplace, Gamma) +@register_kl(Laplace, Pareto) +@register_kl(Laplace, Uniform) +def _kl_laplace_infinity(p, q): + return _infinite_like(p.loc) + + +@register_kl(Laplace, Normal) +def _kl_laplace_normal(p, q): + var_normal = q.scale.pow(2) + scale_sqr_var_ratio = p.scale.pow(2) / var_normal + t1 = 0.5 * torch.log(2 * scale_sqr_var_ratio / math.pi) + t2 = 0.5 * p.loc.pow(2) + t3 = p.loc * q.loc + t4 = 0.5 * q.loc.pow(2) + return -t1 + scale_sqr_var_ratio + (t2 - t3 + t4) / var_normal - 1 + + +@register_kl(Normal, Beta) +@register_kl(Normal, ContinuousBernoulli) +@register_kl(Normal, Exponential) +@register_kl(Normal, Gamma) +@register_kl(Normal, Pareto) +@register_kl(Normal, Uniform) +def _kl_normal_infinity(p, q): + return _infinite_like(p.loc) + + +@register_kl(Normal, Gumbel) +def _kl_normal_gumbel(p, q): + mean_scale_ratio = p.loc / q.scale + var_scale_sqr_ratio = (p.scale / q.scale).pow(2) + loc_scale_ratio = q.loc / q.scale + t1 = var_scale_sqr_ratio.log() * 0.5 + t2 = mean_scale_ratio - loc_scale_ratio + t3 = torch.exp(-mean_scale_ratio + 0.5 * var_scale_sqr_ratio + loc_scale_ratio) + return -t1 + t2 + t3 - (0.5 * (1 + math.log(2 * math.pi))) + + +@register_kl(Normal, Laplace) +def _kl_normal_laplace(p, q): + loc_diff = p.loc - q.loc + scale_ratio = p.scale / q.scale + loc_diff_scale_ratio = loc_diff / p.scale + t1 = torch.log(scale_ratio) + t2 = ( + math.sqrt(2 / math.pi) * p.scale * torch.exp(-0.5 * loc_diff_scale_ratio.pow(2)) + ) + t3 = loc_diff * torch.erf(math.sqrt(0.5) 
* loc_diff_scale_ratio) + return -t1 + (t2 + t3) / q.scale - (0.5 * (1 + math.log(0.5 * math.pi))) + + +@register_kl(Pareto, Beta) +@register_kl(Pareto, ContinuousBernoulli) +@register_kl(Pareto, Uniform) +def _kl_pareto_infinity(p, q): + return _infinite_like(p.scale) + + +@register_kl(Pareto, Exponential) +def _kl_pareto_exponential(p, q): + scale_rate_prod = p.scale * q.rate + t1 = (p.alpha / scale_rate_prod).log() + t2 = p.alpha.reciprocal() + t3 = p.alpha * scale_rate_prod / (p.alpha - 1) + result = t1 - t2 + t3 - 1 + result[p.alpha <= 1] = inf + return result + + +@register_kl(Pareto, Gamma) +def _kl_pareto_gamma(p, q): + common_term = p.scale.log() + p.alpha.reciprocal() + t1 = p.alpha.log() - common_term + t2 = q.concentration.lgamma() - q.concentration * q.rate.log() + t3 = (1 - q.concentration) * common_term + t4 = q.rate * p.alpha * p.scale / (p.alpha - 1) + result = t1 + t2 + t3 + t4 - 1 + result[p.alpha <= 1] = inf + return result + + +# TODO: Add Pareto-Laplace KL Divergence + + +@register_kl(Pareto, Normal) +def _kl_pareto_normal(p, q): + var_normal = 2 * q.scale.pow(2) + common_term = p.scale / (p.alpha - 1) + t1 = (math.sqrt(2 * math.pi) * q.scale * p.alpha / p.scale).log() + t2 = p.alpha.reciprocal() + t3 = p.alpha * common_term.pow(2) / (p.alpha - 2) + t4 = (p.alpha * common_term - q.loc).pow(2) + result = t1 - t2 + (t3 + t4) / var_normal - 1 + result[p.alpha <= 2] = inf + return result + + +@register_kl(Poisson, Bernoulli) +@register_kl(Poisson, Binomial) +def _kl_poisson_infinity(p, q): + return _infinite_like(p.rate) + + +@register_kl(Uniform, Beta) +def _kl_uniform_beta(p, q): + common_term = p.high - p.low + t1 = torch.log(common_term) + t2 = ( + (q.concentration1 - 1) + * (_x_log_x(p.high) - _x_log_x(p.low) - common_term) + / common_term + ) + t3 = ( + (q.concentration0 - 1) + * (_x_log_x(1 - p.high) - _x_log_x(1 - p.low) + common_term) + / common_term + ) + t4 = ( + q.concentration1.lgamma() + + q.concentration0.lgamma() + - (q.concentration1 + q.concentration0).lgamma() + ) + result = t3 + t4 - t1 - t2 + result[(p.high > q.support.upper_bound) | (p.low < q.support.lower_bound)] = inf + return result + + +@register_kl(Uniform, ContinuousBernoulli) +def _kl_uniform_continuous_bernoulli(p, q): + result = ( + -p.entropy() + - p.mean * q.logits + - torch.log1p(-q.probs) + - q._cont_bern_log_norm() + ) + return torch.where( + torch.max( + torch.ge(p.high, q.support.upper_bound), + torch.le(p.low, q.support.lower_bound), + ), + torch.ones_like(result) * inf, + result, + ) + + +@register_kl(Uniform, Exponential) +def _kl_uniform_exponetial(p, q): + result = q.rate * (p.high + p.low) / 2 - ((p.high - p.low) * q.rate).log() + result[p.low < q.support.lower_bound] = inf + return result + + +@register_kl(Uniform, Gamma) +def _kl_uniform_gamma(p, q): + common_term = p.high - p.low + t1 = common_term.log() + t2 = q.concentration.lgamma() - q.concentration * q.rate.log() + t3 = ( + (1 - q.concentration) + * (_x_log_x(p.high) - _x_log_x(p.low) - common_term) + / common_term + ) + t4 = q.rate * (p.high + p.low) / 2 + result = -t1 + t2 + t3 + t4 + result[p.low < q.support.lower_bound] = inf + return result + + +@register_kl(Uniform, Gumbel) +def _kl_uniform_gumbel(p, q): + common_term = q.scale / (p.high - p.low) + high_loc_diff = (p.high - q.loc) / q.scale + low_loc_diff = (p.low - q.loc) / q.scale + t1 = common_term.log() + 0.5 * (high_loc_diff + low_loc_diff) + t2 = common_term * (torch.exp(-high_loc_diff) - torch.exp(-low_loc_diff)) + return t1 - t2 + + +# TODO: 
Uniform-Laplace KL Divergence + + +@register_kl(Uniform, Normal) +def _kl_uniform_normal(p, q): + common_term = p.high - p.low + t1 = (math.sqrt(math.pi * 2) * q.scale / common_term).log() + t2 = (common_term).pow(2) / 12 + t3 = ((p.high + p.low - 2 * q.loc) / 2).pow(2) + return t1 + 0.5 * (t2 + t3) / q.scale.pow(2) + + +@register_kl(Uniform, Pareto) +def _kl_uniform_pareto(p, q): + support_uniform = p.high - p.low + t1 = (q.alpha * q.scale.pow(q.alpha) * (support_uniform)).log() + t2 = (_x_log_x(p.high) - _x_log_x(p.low) - support_uniform) / support_uniform + result = t2 * (q.alpha + 1) - t1 + result[p.low < q.support.lower_bound] = inf + return result + + +@register_kl(Independent, Independent) +def _kl_independent_independent(p, q): + if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims: + raise NotImplementedError + result = kl_divergence(p.base_dist, q.base_dist) + return _sum_rightmost(result, p.reinterpreted_batch_ndims) + + +@register_kl(Cauchy, Cauchy) +def _kl_cauchy_cauchy(p, q): + # From https://arxiv.org/abs/1905.10965 + t1 = ((p.scale + q.scale).pow(2) + (p.loc - q.loc).pow(2)).log() + t2 = (4 * p.scale * q.scale).log() + return t1 - t2 + + +def _add_kl_info(): + """Appends a list of implemented KL functions to the doc for kl_divergence.""" + rows = [ + "KL divergence is currently implemented for the following distribution pairs:" + ] + for p, q in sorted( + _KL_REGISTRY, key=lambda p_q: (p_q[0].__name__, p_q[1].__name__) + ): + rows.append( + f"* :class:`~torch.distributions.{p.__name__}` and :class:`~torch.distributions.{q.__name__}`" + ) + kl_info = "\n\t".join(rows) + if kl_divergence.__doc__: + kl_divergence.__doc__ += kl_info # type: ignore[operator] diff --git a/lib/python3.10/site-packages/torch/distributions/kumaraswamy.py b/lib/python3.10/site-packages/torch/distributions/kumaraswamy.py new file mode 100644 index 0000000000000000000000000000000000000000..367e5d52e44a29cf00d1edb4317ac8248d1e9e7a --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/kumaraswamy.py @@ -0,0 +1,99 @@ +# mypy: allow-untyped-defs +import torch +from torch import nan +from torch.distributions import constraints +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import AffineTransform, PowerTransform +from torch.distributions.uniform import Uniform +from torch.distributions.utils import broadcast_all, euler_constant + + +__all__ = ["Kumaraswamy"] + + +def _moments(a, b, n): + """ + Computes nth moment of Kumaraswamy using using torch.lgamma + """ + arg1 = 1 + n / a + log_value = torch.lgamma(arg1) + torch.lgamma(b) - torch.lgamma(arg1 + b) + return b * torch.exp(log_value) + + +class Kumaraswamy(TransformedDistribution): + r""" + Samples from a Kumaraswamy distribution. 
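+
+    Sampling is an inverse-CDF transform of a uniform draw: the CDF is
+    ``1 - (1 - x**concentration1)**concentration0``, so a sample can be written as
+    ``(1 - u**(1 / concentration0))**(1 / concentration1)`` with ``u ~ Uniform(0, 1)``.
+    A rough manual sketch of that identity (illustrative only; prefer :meth:`sample`)::
+
+        >>> a, b = torch.tensor(2.0), torch.tensor(5.0)  # concentration1, concentration0
+        >>> u = torch.rand(())
+        >>> x = (1 - u.pow(1 / b)).pow(1 / a)  # same law as Kumaraswamy(a, b).sample()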
+ + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Kumaraswamy(torch.tensor([1.0]), torch.tensor([1.0])) + >>> m.sample() # sample from a Kumaraswamy distribution with concentration alpha=1 and beta=1 + tensor([ 0.1729]) + + Args: + concentration1 (float or Tensor): 1st concentration parameter of the distribution + (often referred to as alpha) + concentration0 (float or Tensor): 2nd concentration parameter of the distribution + (often referred to as beta) + """ + arg_constraints = { + "concentration1": constraints.positive, + "concentration0": constraints.positive, + } + support = constraints.unit_interval + has_rsample = True + + def __init__(self, concentration1, concentration0, validate_args=None): + self.concentration1, self.concentration0 = broadcast_all( + concentration1, concentration0 + ) + finfo = torch.finfo(self.concentration0.dtype) + base_dist = Uniform( + torch.full_like(self.concentration0, 0), + torch.full_like(self.concentration0, 1), + validate_args=validate_args, + ) + transforms = [ + PowerTransform(exponent=self.concentration0.reciprocal()), + AffineTransform(loc=1.0, scale=-1.0), + PowerTransform(exponent=self.concentration1.reciprocal()), + ] + super().__init__(base_dist, transforms, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Kumaraswamy, _instance) + new.concentration1 = self.concentration1.expand(batch_shape) + new.concentration0 = self.concentration0.expand(batch_shape) + return super().expand(batch_shape, _instance=new) + + @property + def mean(self): + return _moments(self.concentration1, self.concentration0, 1) + + @property + def mode(self): + # Evaluate in log-space for numerical stability. + log_mode = ( + self.concentration0.reciprocal() * (-self.concentration0).log1p() + - (-self.concentration0 * self.concentration1).log1p() + ) + log_mode[(self.concentration0 < 1) | (self.concentration1 < 1)] = nan + return log_mode.exp() + + @property + def variance(self): + return _moments(self.concentration1, self.concentration0, 2) - torch.pow( + self.mean, 2 + ) + + def entropy(self): + t1 = 1 - self.concentration1.reciprocal() + t0 = 1 - self.concentration0.reciprocal() + H0 = torch.digamma(self.concentration0 + 1) + euler_constant + return ( + t0 + + t1 * H0 + - torch.log(self.concentration1) + - torch.log(self.concentration0) + ) diff --git a/lib/python3.10/site-packages/torch/distributions/laplace.py b/lib/python3.10/site-packages/torch/distributions/laplace.py new file mode 100644 index 0000000000000000000000000000000000000000..e4d33f2638289201a4dd2c580503790ac90256d4 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/laplace.py @@ -0,0 +1,97 @@ +# mypy: allow-untyped-defs +from numbers import Number + +import torch +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import broadcast_all +from torch.types import _size + + +__all__ = ["Laplace"] + + +class Laplace(Distribution): + r""" + Creates a Laplace distribution parameterized by :attr:`loc` and :attr:`scale`. 
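+
+    The log-density has the closed form ``-log(2 * scale) - |value - loc| / scale``,
+    which is what :meth:`log_prob` evaluates; a quick hand check with made-up numbers
+    (illustrative only)::
+
+        >>> d = Laplace(torch.tensor(0.0), torch.tensor(2.0))
+        >>> expected = -torch.log(torch.tensor(4.0)) - 0.5  # -log(2 * scale) - |1 - loc| / scale
+        >>> torch.allclose(d.log_prob(torch.tensor(1.0)), expected)
+        True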
+ + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Laplace(torch.tensor([0.0]), torch.tensor([1.0])) + >>> m.sample() # Laplace distributed with loc=0, scale=1 + tensor([ 0.1046]) + + Args: + loc (float or Tensor): mean of the distribution + scale (float or Tensor): scale of the distribution + """ + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} + support = constraints.real + has_rsample = True + + @property + def mean(self): + return self.loc + + @property + def mode(self): + return self.loc + + @property + def variance(self): + return 2 * self.scale.pow(2) + + @property + def stddev(self): + return (2**0.5) * self.scale + + def __init__(self, loc, scale, validate_args=None): + self.loc, self.scale = broadcast_all(loc, scale) + if isinstance(loc, Number) and isinstance(scale, Number): + batch_shape = torch.Size() + else: + batch_shape = self.loc.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Laplace, _instance) + batch_shape = torch.Size(batch_shape) + new.loc = self.loc.expand(batch_shape) + new.scale = self.scale.expand(batch_shape) + super(Laplace, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: + shape = self._extended_shape(sample_shape) + finfo = torch.finfo(self.loc.dtype) + if torch._C._get_tracing_state(): + # [JIT WORKAROUND] lack of support for .uniform_() + u = torch.rand(shape, dtype=self.loc.dtype, device=self.loc.device) * 2 - 1 + return self.loc - self.scale * u.sign() * torch.log1p( + -u.abs().clamp(min=finfo.tiny) + ) + u = self.loc.new(shape).uniform_(finfo.eps - 1, 1) + # TODO: If we ever implement tensor.nextafter, below is what we want ideally. + # u = self.loc.new(shape).uniform_(self.loc.nextafter(-.5, 0), .5) + return self.loc - self.scale * u.sign() * torch.log1p(-u.abs()) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + return -torch.log(2 * self.scale) - torch.abs(value - self.loc) / self.scale + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + return 0.5 - 0.5 * (value - self.loc).sign() * torch.expm1( + -(value - self.loc).abs() / self.scale + ) + + def icdf(self, value): + term = value - 0.5 + return self.loc - self.scale * (term).sign() * torch.log1p(-2 * term.abs()) + + def entropy(self): + return 1 + torch.log(2 * self.scale) diff --git a/lib/python3.10/site-packages/torch/distributions/lkj_cholesky.py b/lib/python3.10/site-packages/torch/distributions/lkj_cholesky.py new file mode 100644 index 0000000000000000000000000000000000000000..479568bdd428d6e36a78ee8a532f9cd6c35c7ef8 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/lkj_cholesky.py @@ -0,0 +1,144 @@ +# mypy: allow-untyped-defs +""" +This closely follows the implementation in NumPyro (https://github.com/pyro-ppl/numpyro). + +Original copyright notice: + +# Copyright: Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 +""" + +import math + +import torch +from torch.distributions import Beta, constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import broadcast_all + + +__all__ = ["LKJCholesky"] + + +class LKJCholesky(Distribution): + r""" + LKJ distribution for lower Cholesky factor of correlation matrices. 
+ The distribution is controlled by ``concentration`` parameter :math:`\eta` + to make the probability of the correlation matrix :math:`M` generated from + a Cholesky factor proportional to :math:`\det(M)^{\eta - 1}`. Because of that, + when ``concentration == 1``, we have a uniform distribution over Cholesky + factors of correlation matrices:: + + L ~ LKJCholesky(dim, concentration) + X = L @ L' ~ LKJCorr(dim, concentration) + + Note that this distribution samples the + Cholesky factor of correlation matrices and not the correlation matrices + themselves and thereby differs slightly from the derivations in [1] for + the `LKJCorr` distribution. For sampling, this uses the Onion method from + [1] Section 3. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> l = LKJCholesky(3, 0.5) + >>> l.sample() # l @ l.T is a sample of a correlation 3x3 matrix + tensor([[ 1.0000, 0.0000, 0.0000], + [ 0.3516, 0.9361, 0.0000], + [-0.1899, 0.4748, 0.8593]]) + + Args: + dimension (dim): dimension of the matrices + concentration (float or Tensor): concentration/shape parameter of the + distribution (often referred to as eta) + + **References** + + [1] `Generating random correlation matrices based on vines and extended onion method` (2009), + Daniel Lewandowski, Dorota Kurowicka, Harry Joe. + Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008 + """ + arg_constraints = {"concentration": constraints.positive} + support = constraints.corr_cholesky + + def __init__(self, dim, concentration=1.0, validate_args=None): + if dim < 2: + raise ValueError( + f"Expected dim to be an integer greater than or equal to 2. Found dim={dim}." + ) + self.dim = dim + (self.concentration,) = broadcast_all(concentration) + batch_shape = self.concentration.size() + event_shape = torch.Size((dim, dim)) + # This is used to draw vectorized samples from the beta distribution in Sec. 3.2 of [1]. + marginal_conc = self.concentration + 0.5 * (self.dim - 2) + offset = torch.arange( + self.dim - 1, + dtype=self.concentration.dtype, + device=self.concentration.device, + ) + offset = torch.cat([offset.new_zeros((1,)), offset]) + beta_conc1 = offset + 0.5 + beta_conc0 = marginal_conc.unsqueeze(-1) - 0.5 * offset + self._beta = Beta(beta_conc1, beta_conc0) + super().__init__(batch_shape, event_shape, validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(LKJCholesky, _instance) + batch_shape = torch.Size(batch_shape) + new.dim = self.dim + new.concentration = self.concentration.expand(batch_shape) + new._beta = self._beta.expand(batch_shape + (self.dim,)) + super(LKJCholesky, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + def sample(self, sample_shape=torch.Size()): + # This uses the Onion method, but there are a few differences from [1] Sec. 3.2: + # - This vectorizes the for loop and also works for heterogeneous eta. + # - Same algorithm generalizes to n=1. + # - The procedure is simplified since we are sampling the cholesky factor of + # the correlation matrix instead of the correlation matrix itself. As such, + # we only need to generate `w`. 
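+        # Rough shape sketch for one unbatched draw with dim = D (illustrative):
+        #   y             : (D, 1)  per-row Beta draws
+        #   u_hypersphere : (D, D)  unit-norm rows of a strictly lower-triangular Gaussian
+        #   w             : (D, D)  strictly lower-triangular part of the factor
+        # The diagonal entries filled in below make every row of the returned factor
+        # have unit Euclidean norm, as a Cholesky factor of a correlation matrix must.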
+ y = self._beta.sample(sample_shape).unsqueeze(-1) + u_normal = torch.randn( + self._extended_shape(sample_shape), dtype=y.dtype, device=y.device + ).tril(-1) + u_hypersphere = u_normal / u_normal.norm(dim=-1, keepdim=True) + # Replace NaNs in first row + u_hypersphere[..., 0, :].fill_(0.0) + w = torch.sqrt(y) * u_hypersphere + # Fill diagonal elements; clamp for numerical stability + eps = torch.finfo(w.dtype).tiny + diag_elems = torch.clamp(1 - torch.sum(w**2, dim=-1), min=eps).sqrt() + w += torch.diag_embed(diag_elems) + return w + + def log_prob(self, value): + # See: https://mc-stan.org/docs/2_25/functions-reference/cholesky-lkj-correlation-distribution.html + # The probability of a correlation matrix is proportional to + # determinant ** (concentration - 1) = prod(L_ii ^ 2(concentration - 1)) + # Additionally, the Jacobian of the transformation from Cholesky factor to + # correlation matrix is: + # prod(L_ii ^ (D - i)) + # So the probability of a Cholesky factor is propotional to + # prod(L_ii ^ (2 * concentration - 2 + D - i)) = prod(L_ii ^ order_i) + # with order_i = 2 * concentration - 2 + D - i + if self._validate_args: + self._validate_sample(value) + diag_elems = value.diagonal(dim1=-1, dim2=-2)[..., 1:] + order = torch.arange(2, self.dim + 1, device=self.concentration.device) + order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order + unnormalized_log_pdf = torch.sum(order * diag_elems.log(), dim=-1) + # Compute normalization constant (page 1999 of [1]) + dm1 = self.dim - 1 + alpha = self.concentration + 0.5 * dm1 + denominator = torch.lgamma(alpha) * dm1 + numerator = torch.mvlgamma(alpha - 0.5, dm1) + # pi_constant in [1] is D * (D - 1) / 4 * log(pi) + # pi_constant in multigammaln is (D - 1) * (D - 2) / 4 * log(pi) + # hence, we need to add a pi_constant = (D - 1) * log(pi) / 2 + pi_constant = 0.5 * dm1 * math.log(math.pi) + normalize_term = pi_constant + numerator - denominator + return unnormalized_log_pdf - normalize_term diff --git a/lib/python3.10/site-packages/torch/distributions/log_normal.py b/lib/python3.10/site-packages/torch/distributions/log_normal.py new file mode 100644 index 0000000000000000000000000000000000000000..d40d21b9ef4a6a28fd89656c9d3cc66d370492fa --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/log_normal.py @@ -0,0 +1,64 @@ +# mypy: allow-untyped-defs +from torch.distributions import constraints +from torch.distributions.normal import Normal +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import ExpTransform + + +__all__ = ["LogNormal"] + + +class LogNormal(TransformedDistribution): + r""" + Creates a log-normal distribution parameterized by + :attr:`loc` and :attr:`scale` where:: + + X ~ Normal(loc, scale) + Y = exp(X) ~ LogNormal(loc, scale) + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = LogNormal(torch.tensor([0.0]), torch.tensor([1.0])) + >>> m.sample() # log-normal distributed with mean=0 and stddev=1 + tensor([ 0.1046]) + + Args: + loc (float or Tensor): mean of log of distribution + scale (float or Tensor): standard deviation of log of the distribution + """ + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} + support = constraints.positive + has_rsample = True + + def __init__(self, loc, scale, validate_args=None): + base_dist = Normal(loc, scale, validate_args=validate_args) + super().__init__(base_dist, ExpTransform(), validate_args=validate_args) + + def expand(self, 
batch_shape, _instance=None): + new = self._get_checked_instance(LogNormal, _instance) + return super().expand(batch_shape, _instance=new) + + @property + def loc(self): + return self.base_dist.loc + + @property + def scale(self): + return self.base_dist.scale + + @property + def mean(self): + return (self.loc + self.scale.pow(2) / 2).exp() + + @property + def mode(self): + return (self.loc - self.scale.square()).exp() + + @property + def variance(self): + scale_sq = self.scale.pow(2) + return scale_sq.expm1() * (2 * self.loc + scale_sq).exp() + + def entropy(self): + return self.base_dist.entropy() + self.loc diff --git a/lib/python3.10/site-packages/torch/distributions/logistic_normal.py b/lib/python3.10/site-packages/torch/distributions/logistic_normal.py new file mode 100644 index 0000000000000000000000000000000000000000..466afe50f48e0fafe4dda73550748c331e0d0adc --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/logistic_normal.py @@ -0,0 +1,56 @@ +# mypy: allow-untyped-defs +from torch.distributions import constraints +from torch.distributions.normal import Normal +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import StickBreakingTransform + + +__all__ = ["LogisticNormal"] + + +class LogisticNormal(TransformedDistribution): + r""" + Creates a logistic-normal distribution parameterized by :attr:`loc` and :attr:`scale` + that define the base `Normal` distribution transformed with the + `StickBreakingTransform` such that:: + + X ~ LogisticNormal(loc, scale) + Y = log(X / (1 - X.cumsum(-1)))[..., :-1] ~ Normal(loc, scale) + + Args: + loc (float or Tensor): mean of the base distribution + scale (float or Tensor): standard deviation of the base distribution + + Example:: + + >>> # logistic-normal distributed with mean=(0, 0, 0) and stddev=(1, 1, 1) + >>> # of the base Normal distribution + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = LogisticNormal(torch.tensor([0.0] * 3), torch.tensor([1.0] * 3)) + >>> m.sample() + tensor([ 0.7653, 0.0341, 0.0579, 0.1427]) + + """ + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} + support = constraints.simplex + has_rsample = True + + def __init__(self, loc, scale, validate_args=None): + base_dist = Normal(loc, scale, validate_args=validate_args) + if not base_dist.batch_shape: + base_dist = base_dist.expand([1]) + super().__init__( + base_dist, StickBreakingTransform(), validate_args=validate_args + ) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(LogisticNormal, _instance) + return super().expand(batch_shape, _instance=new) + + @property + def loc(self): + return self.base_dist.base_dist.loc + + @property + def scale(self): + return self.base_dist.base_dist.scale diff --git a/lib/python3.10/site-packages/torch/distributions/lowrank_multivariate_normal.py b/lib/python3.10/site-packages/torch/distributions/lowrank_multivariate_normal.py new file mode 100644 index 0000000000000000000000000000000000000000..22dea3ca6bda217f8ad3de15ed771bd8cb1696a3 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/lowrank_multivariate_normal.py @@ -0,0 +1,240 @@ +# mypy: allow-untyped-defs +import math + +import torch +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv +from torch.distributions.utils import _standard_normal, lazy_property +from torch.types 
import _size + + +__all__ = ["LowRankMultivariateNormal"] + + +def _batch_capacitance_tril(W, D): + r""" + Computes Cholesky of :math:`I + W.T @ inv(D) @ W` for a batch of matrices :math:`W` + and a batch of vectors :math:`D`. + """ + m = W.size(-1) + Wt_Dinv = W.mT / D.unsqueeze(-2) + K = torch.matmul(Wt_Dinv, W).contiguous() + K.view(-1, m * m)[:, :: m + 1] += 1 # add identity matrix to K + return torch.linalg.cholesky(K) + + +def _batch_lowrank_logdet(W, D, capacitance_tril): + r""" + Uses "matrix determinant lemma":: + log|W @ W.T + D| = log|C| + log|D|, + where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute + the log determinant. + """ + return 2 * capacitance_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + D.log().sum( + -1 + ) + + +def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril): + r""" + Uses "Woodbury matrix identity":: + inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D), + where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute the squared + Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`. + """ + Wt_Dinv = W.mT / D.unsqueeze(-2) + Wt_Dinv_x = _batch_mv(Wt_Dinv, x) + mahalanobis_term1 = (x.pow(2) / D).sum(-1) + mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x) + return mahalanobis_term1 - mahalanobis_term2 + + +class LowRankMultivariateNormal(Distribution): + r""" + Creates a multivariate normal distribution with covariance matrix having a low-rank form + parameterized by :attr:`cov_factor` and :attr:`cov_diag`:: + + covariance_matrix = cov_factor @ cov_factor.T + cov_diag + + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.ones(2)) + >>> m.sample() # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]` + tensor([-0.2102, -0.5429]) + + Args: + loc (Tensor): mean of the distribution with shape `batch_shape + event_shape` + cov_factor (Tensor): factor part of low-rank form of covariance matrix with shape + `batch_shape + event_shape + (rank,)` + cov_diag (Tensor): diagonal part of low-rank form of covariance matrix with shape + `batch_shape + event_shape` + + Note: + The computation for determinant and inverse of covariance matrix is avoided when + `cov_factor.shape[1] << cov_factor.shape[0]` thanks to `Woodbury matrix identity + `_ and + `matrix determinant lemma `_. 
+ Thanks to these formulas, we just need to compute the determinant and inverse of + the small size "capacitance" matrix:: + + capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor + """ + arg_constraints = { + "loc": constraints.real_vector, + "cov_factor": constraints.independent(constraints.real, 2), + "cov_diag": constraints.independent(constraints.positive, 1), + } + support = constraints.real_vector + has_rsample = True + + def __init__(self, loc, cov_factor, cov_diag, validate_args=None): + if loc.dim() < 1: + raise ValueError("loc must be at least one-dimensional.") + event_shape = loc.shape[-1:] + if cov_factor.dim() < 2: + raise ValueError( + "cov_factor must be at least two-dimensional, " + "with optional leading batch dimensions" + ) + if cov_factor.shape[-2:-1] != event_shape: + raise ValueError( + f"cov_factor must be a batch of matrices with shape {event_shape[0]} x m" + ) + if cov_diag.shape[-1:] != event_shape: + raise ValueError( + f"cov_diag must be a batch of vectors with shape {event_shape}" + ) + + loc_ = loc.unsqueeze(-1) + cov_diag_ = cov_diag.unsqueeze(-1) + try: + loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors( + loc_, cov_factor, cov_diag_ + ) + except RuntimeError as e: + raise ValueError( + f"Incompatible batch shapes: loc {loc.shape}, cov_factor {cov_factor.shape}, cov_diag {cov_diag.shape}" + ) from e + self.loc = loc_[..., 0] + self.cov_diag = cov_diag_[..., 0] + batch_shape = self.loc.shape[:-1] + + self._unbroadcasted_cov_factor = cov_factor + self._unbroadcasted_cov_diag = cov_diag + self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag) + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(LowRankMultivariateNormal, _instance) + batch_shape = torch.Size(batch_shape) + loc_shape = batch_shape + self.event_shape + new.loc = self.loc.expand(loc_shape) + new.cov_diag = self.cov_diag.expand(loc_shape) + new.cov_factor = self.cov_factor.expand(loc_shape + self.cov_factor.shape[-1:]) + new._unbroadcasted_cov_factor = self._unbroadcasted_cov_factor + new._unbroadcasted_cov_diag = self._unbroadcasted_cov_diag + new._capacitance_tril = self._capacitance_tril + super(LowRankMultivariateNormal, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + @property + def mean(self): + return self.loc + + @property + def mode(self): + return self.loc + + @lazy_property + def variance(self): + return ( + self._unbroadcasted_cov_factor.pow(2).sum(-1) + self._unbroadcasted_cov_diag + ).expand(self._batch_shape + self._event_shape) + + @lazy_property + def scale_tril(self): + # The following identity is used to increase the numerically computation stability + # for Cholesky decomposition (see http://www.gaussianprocess.org/gpml/, Section 3.4.3): + # W @ W.T + D = D1/2 @ (I + D-1/2 @ W @ W.T @ D-1/2) @ D1/2 + # The matrix "I + D-1/2 @ W @ W.T @ D-1/2" has eigenvalues bounded from below by 1, + # hence it is well-conditioned and safe to take Cholesky decomposition. 
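+        # In the code below, Dinvsqrt_W = D^{-1/2} @ W, K = I + Dinvsqrt_W @ Dinvsqrt_W.T,
+        # and scale_tril = D^{1/2} @ cholesky(K), matching the identity above.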
+ n = self._event_shape[0] + cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1) + Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze + K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.mT).contiguous() + K.view(-1, n * n)[:, :: n + 1] += 1 # add identity matrix to K + scale_tril = cov_diag_sqrt_unsqueeze * torch.linalg.cholesky(K) + return scale_tril.expand( + self._batch_shape + self._event_shape + self._event_shape + ) + + @lazy_property + def covariance_matrix(self): + covariance_matrix = torch.matmul( + self._unbroadcasted_cov_factor, self._unbroadcasted_cov_factor.mT + ) + torch.diag_embed(self._unbroadcasted_cov_diag) + return covariance_matrix.expand( + self._batch_shape + self._event_shape + self._event_shape + ) + + @lazy_property + def precision_matrix(self): + # We use "Woodbury matrix identity" to take advantage of low rank form:: + # inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D) + # where :math:`C` is the capacitance matrix. + Wt_Dinv = ( + self._unbroadcasted_cov_factor.mT + / self._unbroadcasted_cov_diag.unsqueeze(-2) + ) + A = torch.linalg.solve_triangular(self._capacitance_tril, Wt_Dinv, upper=False) + precision_matrix = ( + torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal()) - A.mT @ A + ) + return precision_matrix.expand( + self._batch_shape + self._event_shape + self._event_shape + ) + + def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: + shape = self._extended_shape(sample_shape) + W_shape = shape[:-1] + self.cov_factor.shape[-1:] + eps_W = _standard_normal(W_shape, dtype=self.loc.dtype, device=self.loc.device) + eps_D = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device) + return ( + self.loc + + _batch_mv(self._unbroadcasted_cov_factor, eps_W) + + self._unbroadcasted_cov_diag.sqrt() * eps_D + ) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + diff = value - self.loc + M = _batch_lowrank_mahalanobis( + self._unbroadcasted_cov_factor, + self._unbroadcasted_cov_diag, + diff, + self._capacitance_tril, + ) + log_det = _batch_lowrank_logdet( + self._unbroadcasted_cov_factor, + self._unbroadcasted_cov_diag, + self._capacitance_tril, + ) + return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det + M) + + def entropy(self): + log_det = _batch_lowrank_logdet( + self._unbroadcasted_cov_factor, + self._unbroadcasted_cov_diag, + self._capacitance_tril, + ) + H = 0.5 * (self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + log_det) + if len(self._batch_shape) == 0: + return H + else: + return H.expand(self._batch_shape) diff --git a/lib/python3.10/site-packages/torch/distributions/mixture_same_family.py b/lib/python3.10/site-packages/torch/distributions/mixture_same_family.py new file mode 100644 index 0000000000000000000000000000000000000000..99e0362cead32b12a0145977cad9d46bd1fe588b --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/mixture_same_family.py @@ -0,0 +1,216 @@ +# mypy: allow-untyped-defs +from typing import Dict + +import torch +from torch.distributions import Categorical, constraints +from torch.distributions.distribution import Distribution + + +__all__ = ["MixtureSameFamily"] + + +class MixtureSameFamily(Distribution): + r""" + The `MixtureSameFamily` distribution implements a (batch of) mixture + distribution where all component are from different parameterizations of + the same distribution type. 
It is parameterized by a `Categorical` + "selecting distribution" (over `k` component) and a component + distribution, i.e., a `Distribution` with a rightmost batch shape + (equal to `[k]`) which indexes each (batch of) component. + + Examples:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> # Construct Gaussian Mixture Model in 1D consisting of 5 equally + >>> # weighted normal distributions + >>> mix = D.Categorical(torch.ones(5,)) + >>> comp = D.Normal(torch.randn(5,), torch.rand(5,)) + >>> gmm = MixtureSameFamily(mix, comp) + + >>> # Construct Gaussian Mixture Model in 2D consisting of 5 equally + >>> # weighted bivariate normal distributions + >>> mix = D.Categorical(torch.ones(5,)) + >>> comp = D.Independent(D.Normal( + ... torch.randn(5,2), torch.rand(5,2)), 1) + >>> gmm = MixtureSameFamily(mix, comp) + + >>> # Construct a batch of 3 Gaussian Mixture Models in 2D each + >>> # consisting of 5 random weighted bivariate normal distributions + >>> mix = D.Categorical(torch.rand(3,5)) + >>> comp = D.Independent(D.Normal( + ... torch.randn(3,5,2), torch.rand(3,5,2)), 1) + >>> gmm = MixtureSameFamily(mix, comp) + + Args: + mixture_distribution: `torch.distributions.Categorical`-like + instance. Manages the probability of selecting component. + The number of categories must match the rightmost batch + dimension of the `component_distribution`. Must have either + scalar `batch_shape` or `batch_shape` matching + `component_distribution.batch_shape[:-1]` + component_distribution: `torch.distributions.Distribution`-like + instance. Right-most batch dimension indexes component. + """ + arg_constraints: Dict[str, constraints.Constraint] = {} + has_rsample = False + + def __init__( + self, mixture_distribution, component_distribution, validate_args=None + ): + self._mixture_distribution = mixture_distribution + self._component_distribution = component_distribution + + if not isinstance(self._mixture_distribution, Categorical): + raise ValueError( + " The Mixture distribution needs to be an " + " instance of torch.distributions.Categorical" + ) + + if not isinstance(self._component_distribution, Distribution): + raise ValueError( + "The Component distribution need to be an " + "instance of torch.distributions.Distribution" + ) + + # Check that batch size matches + mdbs = self._mixture_distribution.batch_shape + cdbs = self._component_distribution.batch_shape[:-1] + for size1, size2 in zip(reversed(mdbs), reversed(cdbs)): + if size1 != 1 and size2 != 1 and size1 != size2: + raise ValueError( + f"`mixture_distribution.batch_shape` ({mdbs}) is not " + "compatible with `component_distribution." 
+ f"batch_shape`({cdbs})" + ) + + # Check that the number of mixture component matches + km = self._mixture_distribution.logits.shape[-1] + kc = self._component_distribution.batch_shape[-1] + if km is not None and kc is not None and km != kc: + raise ValueError( + f"`mixture_distribution component` ({km}) does not" + " equal `component_distribution.batch_shape[-1]`" + f" ({kc})" + ) + self._num_component = km + + event_shape = self._component_distribution.event_shape + self._event_ndims = len(event_shape) + super().__init__( + batch_shape=cdbs, event_shape=event_shape, validate_args=validate_args + ) + + def expand(self, batch_shape, _instance=None): + batch_shape = torch.Size(batch_shape) + batch_shape_comp = batch_shape + (self._num_component,) + new = self._get_checked_instance(MixtureSameFamily, _instance) + new._component_distribution = self._component_distribution.expand( + batch_shape_comp + ) + new._mixture_distribution = self._mixture_distribution.expand(batch_shape) + new._num_component = self._num_component + new._event_ndims = self._event_ndims + event_shape = new._component_distribution.event_shape + super(MixtureSameFamily, new).__init__( + batch_shape=batch_shape, event_shape=event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + @constraints.dependent_property + def support(self): + # FIXME this may have the wrong shape when support contains batched + # parameters + return self._component_distribution.support + + @property + def mixture_distribution(self): + return self._mixture_distribution + + @property + def component_distribution(self): + return self._component_distribution + + @property + def mean(self): + probs = self._pad_mixture_dimensions(self.mixture_distribution.probs) + return torch.sum( + probs * self.component_distribution.mean, dim=-1 - self._event_ndims + ) # [B, E] + + @property + def variance(self): + # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X]) + probs = self._pad_mixture_dimensions(self.mixture_distribution.probs) + mean_cond_var = torch.sum( + probs * self.component_distribution.variance, dim=-1 - self._event_ndims + ) + var_cond_mean = torch.sum( + probs * (self.component_distribution.mean - self._pad(self.mean)).pow(2.0), + dim=-1 - self._event_ndims, + ) + return mean_cond_var + var_cond_mean + + def cdf(self, x): + x = self._pad(x) + cdf_x = self.component_distribution.cdf(x) + mix_prob = self.mixture_distribution.probs + + return torch.sum(cdf_x * mix_prob, dim=-1) + + def log_prob(self, x): + if self._validate_args: + self._validate_sample(x) + x = self._pad(x) + log_prob_x = self.component_distribution.log_prob(x) # [S, B, k] + log_mix_prob = torch.log_softmax( + self.mixture_distribution.logits, dim=-1 + ) # [B, k] + return torch.logsumexp(log_prob_x + log_mix_prob, dim=-1) # [S, B] + + def sample(self, sample_shape=torch.Size()): + with torch.no_grad(): + sample_len = len(sample_shape) + batch_len = len(self.batch_shape) + gather_dim = sample_len + batch_len + es = self.event_shape + + # mixture samples [n, B] + mix_sample = self.mixture_distribution.sample(sample_shape) + mix_shape = mix_sample.shape + + # component samples [n, B, k, E] + comp_samples = self.component_distribution.sample(sample_shape) + + # Gather along the k dimension + mix_sample_r = mix_sample.reshape( + mix_shape + torch.Size([1] * (len(es) + 1)) + ) + mix_sample_r = mix_sample_r.repeat( + torch.Size([1] * len(mix_shape)) + torch.Size([1]) + es + ) + + samples = torch.gather(comp_samples, gather_dim, mix_sample_r) + 
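+            # comp_samples has shape sample_shape + batch_shape + (k,) + event_shape;
+            # gathering along the k dimension with the expanded mixture indices keeps,
+            # for each sample/batch element, only the draw from its selected component.
+            # The leftover singleton k dimension is squeezed out below.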
return samples.squeeze(gather_dim) + + def _pad(self, x): + return x.unsqueeze(-1 - self._event_ndims) + + def _pad_mixture_dimensions(self, x): + dist_batch_ndims = len(self.batch_shape) + cat_batch_ndims = len(self.mixture_distribution.batch_shape) + pad_ndims = 0 if cat_batch_ndims == 1 else dist_batch_ndims - cat_batch_ndims + xs = x.shape + x = x.reshape( + xs[:-1] + + torch.Size(pad_ndims * [1]) + + xs[-1:] + + torch.Size(self._event_ndims * [1]) + ) + return x + + def __repr__(self): + args_string = ( + f"\n {self.mixture_distribution},\n {self.component_distribution}" + ) + return "MixtureSameFamily" + "(" + args_string + ")" diff --git a/lib/python3.10/site-packages/torch/distributions/multinomial.py b/lib/python3.10/site-packages/torch/distributions/multinomial.py new file mode 100644 index 0000000000000000000000000000000000000000..12295a80e1855ceaf39aadd79f6df025ad7882ad --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/multinomial.py @@ -0,0 +1,137 @@ +# mypy: allow-untyped-defs +import torch +from torch import inf +from torch.distributions import Categorical, constraints +from torch.distributions.binomial import Binomial +from torch.distributions.distribution import Distribution +from torch.distributions.utils import broadcast_all + + +__all__ = ["Multinomial"] + + +class Multinomial(Distribution): + r""" + Creates a Multinomial distribution parameterized by :attr:`total_count` and + either :attr:`probs` or :attr:`logits` (but not both). The innermost dimension of + :attr:`probs` indexes over categories. All other dimensions index over batches. + + Note that :attr:`total_count` need not be specified if only :meth:`log_prob` is + called (see example below) + + .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum, + and it will be normalized to sum to 1 along the last dimension. :attr:`probs` + will return this normalized value. + The `logits` argument will be interpreted as unnormalized log probabilities + and can therefore be any real number. It will likewise be normalized so that + the resulting probabilities sum to 1 along the last dimension. :attr:`logits` + will return this normalized value. + + - :meth:`sample` requires a single shared `total_count` for all + parameters and samples. + - :meth:`log_prob` allows different `total_count` for each parameter and + sample. 
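+
+    The log-probability is the usual multinomial mass function,
+    ``lgamma(n + 1) - sum_i lgamma(x_i + 1) + sum_i x_i * log(p_i)``, with ``n`` taken
+    from the value itself. A quick hand check for two categories (illustrative only)::
+
+        >>> m = Multinomial(2, torch.tensor([0.5, 0.5]))
+        >>> m.log_prob(torch.tensor([1.0, 1.0])).exp()  # 2! / (1! * 1!) * 0.5 * 0.5
+        tensor(0.5000)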
+ + Example:: + + >>> # xdoctest: +SKIP("FIXME: found invalid values") + >>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.])) + >>> x = m.sample() # equal probability of 0, 1, 2, 3 + tensor([ 21., 24., 30., 25.]) + + >>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x) + tensor([-4.1338]) + + Args: + total_count (int): number of trials + probs (Tensor): event probabilities + logits (Tensor): event log probabilities (unnormalized) + """ + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} + total_count: int + + @property + def mean(self): + return self.probs * self.total_count + + @property + def variance(self): + return self.total_count * self.probs * (1 - self.probs) + + def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): + if not isinstance(total_count, int): + raise NotImplementedError("inhomogeneous total_count is not supported") + self.total_count = total_count + self._categorical = Categorical(probs=probs, logits=logits) + self._binomial = Binomial(total_count=total_count, probs=self.probs) + batch_shape = self._categorical.batch_shape + event_shape = self._categorical.param_shape[-1:] + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Multinomial, _instance) + batch_shape = torch.Size(batch_shape) + new.total_count = self.total_count + new._categorical = self._categorical.expand(batch_shape) + super(Multinomial, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._categorical._new(*args, **kwargs) + + @constraints.dependent_property(is_discrete=True, event_dim=1) + def support(self): + return constraints.multinomial(self.total_count) + + @property + def logits(self): + return self._categorical.logits + + @property + def probs(self): + return self._categorical.probs + + @property + def param_shape(self): + return self._categorical.param_shape + + def sample(self, sample_shape=torch.Size()): + sample_shape = torch.Size(sample_shape) + samples = self._categorical.sample( + torch.Size((self.total_count,)) + sample_shape + ) + # samples.shape is (total_count, sample_shape, batch_shape), need to change it to + # (sample_shape, batch_shape, total_count) + shifted_idx = list(range(samples.dim())) + shifted_idx.append(shifted_idx.pop(0)) + samples = samples.permute(*shifted_idx) + counts = samples.new(self._extended_shape(sample_shape)).zero_() + counts.scatter_add_(-1, samples, torch.ones_like(samples)) + return counts.type_as(self.probs) + + def entropy(self): + n = torch.tensor(self.total_count) + + cat_entropy = self._categorical.entropy() + term1 = n * cat_entropy - torch.lgamma(n + 1) + + support = self._binomial.enumerate_support(expand=False)[1:] + binomial_probs = torch.exp(self._binomial.log_prob(support)) + weights = torch.lgamma(support + 1) + term2 = (binomial_probs * weights).sum([0, -1]) + + return term1 + term2 + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + logits, value = broadcast_all(self.logits, value) + logits = logits.clone(memory_format=torch.contiguous_format) + log_factorial_n = torch.lgamma(value.sum(-1) + 1) + log_factorial_xs = torch.lgamma(value + 1).sum(-1) + logits[(value == 0) & (logits == -inf)] = 0 + log_powers = (logits * value).sum(-1) + return log_factorial_n - log_factorial_xs + log_powers diff --git 
a/lib/python3.10/site-packages/torch/distributions/multivariate_normal.py b/lib/python3.10/site-packages/torch/distributions/multivariate_normal.py new file mode 100644 index 0000000000000000000000000000000000000000..bece6d0606a8c93af31d78a3be1e5623bf127a2f --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/multivariate_normal.py @@ -0,0 +1,265 @@ +# mypy: allow-untyped-defs +import math + +import torch +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import _standard_normal, lazy_property +from torch.types import _size + + +__all__ = ["MultivariateNormal"] + + +def _batch_mv(bmat, bvec): + r""" + Performs a batched matrix-vector product, with compatible but different batch shapes. + + This function takes as input `bmat`, containing :math:`n \times n` matrices, and + `bvec`, containing length :math:`n` vectors. + + Both `bmat` and `bvec` may have any number of leading dimensions, which correspond + to a batch shape. They are not necessarily assumed to have the same batch shape, + just ones which can be broadcasted. + """ + return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1) + + +def _batch_mahalanobis(bL, bx): + r""" + Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}` + for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`. + + Accepts batches for both bL and bx. They are not necessarily assumed to have the same batch + shape, but `bL` one should be able to broadcasted to `bx` one. + """ + n = bx.size(-1) + bx_batch_shape = bx.shape[:-1] + + # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n), + # we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tri.solve + bx_batch_dims = len(bx_batch_shape) + bL_batch_dims = bL.dim() - 2 + outer_batch_dims = bx_batch_dims - bL_batch_dims + old_batch_dims = outer_batch_dims + bL_batch_dims + new_batch_dims = outer_batch_dims + 2 * bL_batch_dims + # Reshape bx with the shape (..., 1, i, j, 1, n) + bx_new_shape = bx.shape[:outer_batch_dims] + for sL, sx in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]): + bx_new_shape += (sx // sL, sL) + bx_new_shape += (n,) + bx = bx.reshape(bx_new_shape) + # Permute bx to make it have shape (..., 1, j, i, 1, n) + permute_dims = ( + list(range(outer_batch_dims)) + + list(range(outer_batch_dims, new_batch_dims, 2)) + + list(range(outer_batch_dims + 1, new_batch_dims, 2)) + + [new_batch_dims] + ) + bx = bx.permute(permute_dims) + + flat_L = bL.reshape(-1, n, n) # shape = b x n x n + flat_x = bx.reshape(-1, flat_L.size(0), n) # shape = c x b x n + flat_x_swap = flat_x.permute(1, 2, 0) # shape = b x n x c + M_swap = ( + torch.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2) + ) # shape = b x c + M = M_swap.t() # shape = c x b + + # Now we revert the above reshape and permute operators. 
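+    # Net effect per batch element: solve L z = x by a triangular solve and return
+    # z.T @ z; the reshape/permute bookkeeping only exists so that a single batched
+    # solve handles broadcasting between the batch shapes of bL and bx.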
+ permuted_M = M.reshape(bx.shape[:-1]) # shape = (..., 1, j, i, 1) + permute_inv_dims = list(range(outer_batch_dims)) + for i in range(bL_batch_dims): + permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i] + reshaped_M = permuted_M.permute(permute_inv_dims) # shape = (..., 1, i, j, 1) + return reshaped_M.reshape(bx_batch_shape) + + +def _precision_to_scale_tril(P): + # Ref: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril + Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1))) + L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1) + Id = torch.eye(P.shape[-1], dtype=P.dtype, device=P.device) + L = torch.linalg.solve_triangular(L_inv, Id, upper=False) + return L + + +class MultivariateNormal(Distribution): + r""" + Creates a multivariate normal (also called Gaussian) distribution + parameterized by a mean vector and a covariance matrix. + + The multivariate normal distribution can be parameterized either + in terms of a positive definite covariance matrix :math:`\mathbf{\Sigma}` + or a positive definite precision matrix :math:`\mathbf{\Sigma}^{-1}` + or a lower-triangular matrix :math:`\mathbf{L}` with positive-valued + diagonal entries, such that + :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`. This triangular matrix + can be obtained via e.g. Cholesky decomposition of the covariance. + + Example: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = MultivariateNormal(torch.zeros(2), torch.eye(2)) + >>> m.sample() # normally distributed with mean=`[0,0]` and covariance_matrix=`I` + tensor([-0.2102, -0.5429]) + + Args: + loc (Tensor): mean of the distribution + covariance_matrix (Tensor): positive-definite covariance matrix + precision_matrix (Tensor): positive-definite precision matrix + scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal + + Note: + Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or + :attr:`scale_tril` can be specified. + + Using :attr:`scale_tril` will be more efficient: all computations internally + are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or + :attr:`precision_matrix` is passed instead, it is only used to compute + the corresponding lower triangular matrices using a Cholesky decomposition. + """ + arg_constraints = { + "loc": constraints.real_vector, + "covariance_matrix": constraints.positive_definite, + "precision_matrix": constraints.positive_definite, + "scale_tril": constraints.lower_cholesky, + } + support = constraints.real_vector + has_rsample = True + + def __init__( + self, + loc, + covariance_matrix=None, + precision_matrix=None, + scale_tril=None, + validate_args=None, + ): + if loc.dim() < 1: + raise ValueError("loc must be at least one-dimensional.") + if (covariance_matrix is not None) + (scale_tril is not None) + ( + precision_matrix is not None + ) != 1: + raise ValueError( + "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified." 
+ ) + + if scale_tril is not None: + if scale_tril.dim() < 2: + raise ValueError( + "scale_tril matrix must be at least two-dimensional, " + "with optional leading batch dimensions" + ) + batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1]) + self.scale_tril = scale_tril.expand(batch_shape + (-1, -1)) + elif covariance_matrix is not None: + if covariance_matrix.dim() < 2: + raise ValueError( + "covariance_matrix must be at least two-dimensional, " + "with optional leading batch dimensions" + ) + batch_shape = torch.broadcast_shapes( + covariance_matrix.shape[:-2], loc.shape[:-1] + ) + self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1)) + else: + if precision_matrix.dim() < 2: + raise ValueError( + "precision_matrix must be at least two-dimensional, " + "with optional leading batch dimensions" + ) + batch_shape = torch.broadcast_shapes( + precision_matrix.shape[:-2], loc.shape[:-1] + ) + self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1)) + self.loc = loc.expand(batch_shape + (-1,)) + + event_shape = self.loc.shape[-1:] + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + if scale_tril is not None: + self._unbroadcasted_scale_tril = scale_tril + elif covariance_matrix is not None: + self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix) + else: # precision_matrix is not None + self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(MultivariateNormal, _instance) + batch_shape = torch.Size(batch_shape) + loc_shape = batch_shape + self.event_shape + cov_shape = batch_shape + self.event_shape + self.event_shape + new.loc = self.loc.expand(loc_shape) + new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril + if "covariance_matrix" in self.__dict__: + new.covariance_matrix = self.covariance_matrix.expand(cov_shape) + if "scale_tril" in self.__dict__: + new.scale_tril = self.scale_tril.expand(cov_shape) + if "precision_matrix" in self.__dict__: + new.precision_matrix = self.precision_matrix.expand(cov_shape) + super(MultivariateNormal, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + @lazy_property + def scale_tril(self): + return self._unbroadcasted_scale_tril.expand( + self._batch_shape + self._event_shape + self._event_shape + ) + + @lazy_property + def covariance_matrix(self): + return torch.matmul( + self._unbroadcasted_scale_tril, self._unbroadcasted_scale_tril.mT + ).expand(self._batch_shape + self._event_shape + self._event_shape) + + @lazy_property + def precision_matrix(self): + return torch.cholesky_inverse(self._unbroadcasted_scale_tril).expand( + self._batch_shape + self._event_shape + self._event_shape + ) + + @property + def mean(self): + return self.loc + + @property + def mode(self): + return self.loc + + @property + def variance(self): + return ( + self._unbroadcasted_scale_tril.pow(2) + .sum(-1) + .expand(self._batch_shape + self._event_shape) + ) + + def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: + shape = self._extended_shape(sample_shape) + eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device) + return self.loc + _batch_mv(self._unbroadcasted_scale_tril, eps) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + diff = value - self.loc + M = _batch_mahalanobis(self._unbroadcasted_scale_tril, 
diff) + half_log_det = ( + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + ) + return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) - half_log_det + + def entropy(self): + half_log_det = ( + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + ) + H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det + if len(self._batch_shape) == 0: + return H + else: + return H.expand(self._batch_shape) diff --git a/lib/python3.10/site-packages/torch/distributions/negative_binomial.py b/lib/python3.10/site-packages/torch/distributions/negative_binomial.py new file mode 100644 index 0000000000000000000000000000000000000000..c3e1a66639ba101a35c7deca67a7567e938a1608 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/negative_binomial.py @@ -0,0 +1,135 @@ +# mypy: allow-untyped-defs +import torch +import torch.nn.functional as F +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import ( + broadcast_all, + lazy_property, + logits_to_probs, + probs_to_logits, +) + + +__all__ = ["NegativeBinomial"] + + +class NegativeBinomial(Distribution): + r""" + Creates a Negative Binomial distribution, i.e. distribution + of the number of successful independent and identical Bernoulli trials + before :attr:`total_count` failures are achieved. The probability + of success of each Bernoulli trial is :attr:`probs`. + + Args: + total_count (float or Tensor): non-negative number of negative Bernoulli + trials to stop, although the distribution is still valid for real + valued count + probs (Tensor): Event probabilities of success in the half open interval [0, 1) + logits (Tensor): Event log-odds for probabilities of success + """ + arg_constraints = { + "total_count": constraints.greater_than_eq(0), + "probs": constraints.half_open_interval(0.0, 1.0), + "logits": constraints.real, + } + support = constraints.nonnegative_integer + + def __init__(self, total_count, probs=None, logits=None, validate_args=None): + if (probs is None) == (logits is None): + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." 
+ ) + if probs is not None: + ( + self.total_count, + self.probs, + ) = broadcast_all(total_count, probs) + self.total_count = self.total_count.type_as(self.probs) + else: + ( + self.total_count, + self.logits, + ) = broadcast_all(total_count, logits) + self.total_count = self.total_count.type_as(self.logits) + + self._param = self.probs if probs is not None else self.logits + batch_shape = self._param.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(NegativeBinomial, _instance) + batch_shape = torch.Size(batch_shape) + new.total_count = self.total_count.expand(batch_shape) + if "probs" in self.__dict__: + new.probs = self.probs.expand(batch_shape) + new._param = new.probs + if "logits" in self.__dict__: + new.logits = self.logits.expand(batch_shape) + new._param = new.logits + super(NegativeBinomial, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._param.new(*args, **kwargs) + + @property + def mean(self): + return self.total_count * torch.exp(self.logits) + + @property + def mode(self): + return ((self.total_count - 1) * self.logits.exp()).floor().clamp(min=0.0) + + @property + def variance(self): + return self.mean / torch.sigmoid(-self.logits) + + @lazy_property + def logits(self): + return probs_to_logits(self.probs, is_binary=True) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits, is_binary=True) + + @property + def param_shape(self): + return self._param.size() + + @lazy_property + def _gamma(self): + # Note we avoid validating because self.total_count can be zero. + return torch.distributions.Gamma( + concentration=self.total_count, + rate=torch.exp(-self.logits), + validate_args=False, + ) + + def sample(self, sample_shape=torch.Size()): + with torch.no_grad(): + rate = self._gamma.sample(sample_shape=sample_shape) + return torch.poisson(rate) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + + log_unnormalized_prob = self.total_count * F.logsigmoid( + -self.logits + ) + value * F.logsigmoid(self.logits) + + log_normalization = ( + -torch.lgamma(self.total_count + value) + + torch.lgamma(1.0 + value) + + torch.lgamma(self.total_count) + ) + # The case self.total_count == 0 and value == 0 has probability 1 but + # lgamma(0) is infinite. Handle this case separately using a function + # that does not modify tensors in place to allow Jit compilation. + log_normalization = log_normalization.masked_fill( + self.total_count + value == 0.0, 0.0 + ) + + return log_unnormalized_prob - log_normalization diff --git a/lib/python3.10/site-packages/torch/distributions/normal.py b/lib/python3.10/site-packages/torch/distributions/normal.py new file mode 100644 index 0000000000000000000000000000000000000000..5c6f9b7170855d615eb2ee94902e5b40a6facefc --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/normal.py @@ -0,0 +1,112 @@ +# mypy: allow-untyped-defs +import math +from numbers import Number, Real + +import torch +from torch.distributions import constraints +from torch.distributions.exp_family import ExponentialFamily +from torch.distributions.utils import _standard_normal, broadcast_all +from torch.types import _size + + +__all__ = ["Normal"] + + +class Normal(ExponentialFamily): + r""" + Creates a normal (also called Gaussian) distribution parameterized by + :attr:`loc` and :attr:`scale`. 
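+
+    Reparameterized sampling draws ``eps ~ N(0, 1)`` and returns ``loc + eps * scale``,
+    so gradients flow back to both parameters; a minimal sketch (illustrative only)::
+
+        >>> loc = torch.tensor(0.0, requires_grad=True)
+        >>> scale = torch.tensor(1.0, requires_grad=True)
+        >>> Normal(loc, scale).rsample().backward()
+        >>> loc.grad
+        tensor(1.)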
+ + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Normal(torch.tensor([0.0]), torch.tensor([1.0])) + >>> m.sample() # normally distributed with loc=0 and scale=1 + tensor([ 0.1046]) + + Args: + loc (float or Tensor): mean of the distribution (often referred to as mu) + scale (float or Tensor): standard deviation of the distribution + (often referred to as sigma) + """ + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} + support = constraints.real + has_rsample = True + _mean_carrier_measure = 0 + + @property + def mean(self): + return self.loc + + @property + def mode(self): + return self.loc + + @property + def stddev(self): + return self.scale + + @property + def variance(self): + return self.stddev.pow(2) + + def __init__(self, loc, scale, validate_args=None): + self.loc, self.scale = broadcast_all(loc, scale) + if isinstance(loc, Number) and isinstance(scale, Number): + batch_shape = torch.Size() + else: + batch_shape = self.loc.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Normal, _instance) + batch_shape = torch.Size(batch_shape) + new.loc = self.loc.expand(batch_shape) + new.scale = self.scale.expand(batch_shape) + super(Normal, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def sample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + with torch.no_grad(): + return torch.normal(self.loc.expand(shape), self.scale.expand(shape)) + + def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: + shape = self._extended_shape(sample_shape) + eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device) + return self.loc + eps * self.scale + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + # compute the variance + var = self.scale**2 + log_scale = ( + math.log(self.scale) if isinstance(self.scale, Real) else self.scale.log() + ) + return ( + -((value - self.loc) ** 2) / (2 * var) + - log_scale + - math.log(math.sqrt(2 * math.pi)) + ) + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + return 0.5 * ( + 1 + torch.erf((value - self.loc) * self.scale.reciprocal() / math.sqrt(2)) + ) + + def icdf(self, value): + return self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2) + + def entropy(self): + return 0.5 + 0.5 * math.log(2 * math.pi) + torch.log(self.scale) + + @property + def _natural_params(self): + return (self.loc / self.scale.pow(2), -0.5 * self.scale.pow(2).reciprocal()) + + def _log_normalizer(self, x, y): + return -0.25 * x.pow(2) / y + 0.5 * torch.log(-math.pi / y) diff --git a/lib/python3.10/site-packages/torch/distributions/one_hot_categorical.py b/lib/python3.10/site-packages/torch/distributions/one_hot_categorical.py new file mode 100644 index 0000000000000000000000000000000000000000..76cf6137b0c9df0e805b6e5e69992da09da77d1c --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/one_hot_categorical.py @@ -0,0 +1,132 @@ +# mypy: allow-untyped-defs +import torch +from torch.distributions import constraints +from torch.distributions.categorical import Categorical +from torch.distributions.distribution import Distribution +from torch.types import _size + + +__all__ = ["OneHotCategorical", "OneHotCategoricalStraightThrough"] + + +class OneHotCategorical(Distribution): + r""" + Creates a one-hot categorical 
distribution parameterized by :attr:`probs` or + :attr:`logits`. + + Samples are one-hot coded vectors of size ``probs.size(-1)``. + + .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum, + and it will be normalized to sum to 1 along the last dimension. :attr:`probs` + will return this normalized value. + The `logits` argument will be interpreted as unnormalized log probabilities + and can therefore be any real number. It will likewise be normalized so that + the resulting probabilities sum to 1 along the last dimension. :attr:`logits` + will return this normalized value. + + See also: :func:`torch.distributions.Categorical` for specifications of + :attr:`probs` and :attr:`logits`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ])) + >>> m.sample() # equal probability of 0, 1, 2, 3 + tensor([ 0., 0., 0., 1.]) + + Args: + probs (Tensor): event probabilities + logits (Tensor): event log probabilities (unnormalized) + """ + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} + support = constraints.one_hot + has_enumerate_support = True + + def __init__(self, probs=None, logits=None, validate_args=None): + self._categorical = Categorical(probs, logits) + batch_shape = self._categorical.batch_shape + event_shape = self._categorical.param_shape[-1:] + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(OneHotCategorical, _instance) + batch_shape = torch.Size(batch_shape) + new._categorical = self._categorical.expand(batch_shape) + super(OneHotCategorical, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._categorical._new(*args, **kwargs) + + @property + def _param(self): + return self._categorical._param + + @property + def probs(self): + return self._categorical.probs + + @property + def logits(self): + return self._categorical.logits + + @property + def mean(self): + return self._categorical.probs + + @property + def mode(self): + probs = self._categorical.probs + mode = probs.argmax(axis=-1) + return torch.nn.functional.one_hot(mode, num_classes=probs.shape[-1]).to(probs) + + @property + def variance(self): + return self._categorical.probs * (1 - self._categorical.probs) + + @property + def param_shape(self): + return self._categorical.param_shape + + def sample(self, sample_shape=torch.Size()): + sample_shape = torch.Size(sample_shape) + probs = self._categorical.probs + num_events = self._categorical._num_events + indices = self._categorical.sample(sample_shape) + return torch.nn.functional.one_hot(indices, num_events).to(probs) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + indices = value.max(-1)[1] + return self._categorical.log_prob(indices) + + def entropy(self): + return self._categorical.entropy() + + def enumerate_support(self, expand=True): + n = self.event_shape[0] + values = torch.eye(n, dtype=self._param.dtype, device=self._param.device) + values = values.view((n,) + (1,) * len(self.batch_shape) + (n,)) + if expand: + values = values.expand((n,) + self.batch_shape + (n,)) + return values + + +class OneHotCategoricalStraightThrough(OneHotCategorical): + r""" + Creates a reparameterizable :class:`OneHotCategorical` distribution based on the straight- + 
through gradient estimator from [1]. + + [1] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation + (Bengio et al., 2013) + """ + has_rsample = True + + def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: + samples = self.sample(sample_shape) + probs = self._categorical.probs # cached via @lazy_property + return samples + (probs - probs.detach()) diff --git a/lib/python3.10/site-packages/torch/distributions/pareto.py b/lib/python3.10/site-packages/torch/distributions/pareto.py new file mode 100644 index 0000000000000000000000000000000000000000..798330d7bca7a8e497a306606f3d722d3ffbd4f3 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/pareto.py @@ -0,0 +1,62 @@ +# mypy: allow-untyped-defs +from torch.distributions import constraints +from torch.distributions.exponential import Exponential +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import AffineTransform, ExpTransform +from torch.distributions.utils import broadcast_all + + +__all__ = ["Pareto"] + + +class Pareto(TransformedDistribution): + r""" + Samples from a Pareto Type 1 distribution. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Pareto(torch.tensor([1.0]), torch.tensor([1.0])) + >>> m.sample() # sample from a Pareto distribution with scale=1 and alpha=1 + tensor([ 1.5623]) + + Args: + scale (float or Tensor): Scale parameter of the distribution + alpha (float or Tensor): Shape parameter of the distribution + """ + arg_constraints = {"alpha": constraints.positive, "scale": constraints.positive} + + def __init__(self, scale, alpha, validate_args=None): + self.scale, self.alpha = broadcast_all(scale, alpha) + base_dist = Exponential(self.alpha, validate_args=validate_args) + transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)] + super().__init__(base_dist, transforms, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Pareto, _instance) + new.scale = self.scale.expand(batch_shape) + new.alpha = self.alpha.expand(batch_shape) + return super().expand(batch_shape, _instance=new) + + @property + def mean(self): + # mean is inf for alpha <= 1 + a = self.alpha.clamp(min=1) + return a * self.scale / (a - 1) + + @property + def mode(self): + return self.scale + + @property + def variance(self): + # var is inf for alpha <= 2 + a = self.alpha.clamp(min=2) + return self.scale.pow(2) * a / ((a - 1).pow(2) * (a - 2)) + + @constraints.dependent_property(is_discrete=False, event_dim=0) + def support(self): + return constraints.greater_than_eq(self.scale) + + def entropy(self): + return (self.scale / self.alpha).log() + (1 + self.alpha.reciprocal()) diff --git a/lib/python3.10/site-packages/torch/distributions/poisson.py b/lib/python3.10/site-packages/torch/distributions/poisson.py new file mode 100644 index 0000000000000000000000000000000000000000..4f386e9361fdd9da310568b79238127fe2caa5e7 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/poisson.py @@ -0,0 +1,79 @@ +# mypy: allow-untyped-defs +from numbers import Number + +import torch +from torch.distributions import constraints +from torch.distributions.exp_family import ExponentialFamily +from torch.distributions.utils import broadcast_all + + +__all__ = ["Poisson"] + + +class Poisson(ExponentialFamily): + r""" + Creates a Poisson distribution parameterized by :attr:`rate`, the rate parameter. 
+ + Samples are nonnegative integers, with a pmf given by + + .. math:: + \mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!} + + Example:: + + >>> # xdoctest: +SKIP("poisson_cpu not implemented for 'Long'") + >>> m = Poisson(torch.tensor([4])) + >>> m.sample() + tensor([ 3.]) + + Args: + rate (Number, Tensor): the rate parameter + """ + arg_constraints = {"rate": constraints.nonnegative} + support = constraints.nonnegative_integer + + @property + def mean(self): + return self.rate + + @property + def mode(self): + return self.rate.floor() + + @property + def variance(self): + return self.rate + + def __init__(self, rate, validate_args=None): + (self.rate,) = broadcast_all(rate) + if isinstance(rate, Number): + batch_shape = torch.Size() + else: + batch_shape = self.rate.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Poisson, _instance) + batch_shape = torch.Size(batch_shape) + new.rate = self.rate.expand(batch_shape) + super(Poisson, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def sample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + with torch.no_grad(): + return torch.poisson(self.rate.expand(shape)) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + rate, value = broadcast_all(self.rate, value) + return value.xlogy(rate) - rate - (value + 1).lgamma() + + @property + def _natural_params(self): + return (torch.log(self.rate),) + + def _log_normalizer(self, x): + return torch.exp(x) diff --git a/lib/python3.10/site-packages/torch/distributions/relaxed_bernoulli.py b/lib/python3.10/site-packages/torch/distributions/relaxed_bernoulli.py new file mode 100644 index 0000000000000000000000000000000000000000..04f70519e805cdfc47f9546f8b318456a46c62c9 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/relaxed_bernoulli.py @@ -0,0 +1,152 @@ +# mypy: allow-untyped-defs +from numbers import Number + +import torch +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import SigmoidTransform +from torch.distributions.utils import ( + broadcast_all, + clamp_probs, + lazy_property, + logits_to_probs, + probs_to_logits, +) +from torch.types import _size + + +__all__ = ["LogitRelaxedBernoulli", "RelaxedBernoulli"] + + +class LogitRelaxedBernoulli(Distribution): + r""" + Creates a LogitRelaxedBernoulli distribution parameterized by :attr:`probs` + or :attr:`logits` (but not both), which is the logit of a RelaxedBernoulli + distribution. + + Samples are logits of values in (0, 1). See [1] for more details. 
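+
+    An illustrative sketch of the reparameterization used by :meth:`rsample`
+    below (with :math:`U \sim \mathrm{Uniform}(0, 1)` and probability :math:`p`):
+
+    .. math::
+        y = \frac{\operatorname{logit}(U) + \operatorname{logit}(p)}{\mathrm{temperature}}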
+ + Args: + temperature (Tensor): relaxation temperature + probs (Number, Tensor): the probability of sampling `1` + logits (Number, Tensor): the log-odds of sampling `1` + + [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random + Variables (Maddison et al., 2017) + + [2] Categorical Reparametrization with Gumbel-Softmax + (Jang et al., 2017) + """ + arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} + support = constraints.real + + def __init__(self, temperature, probs=None, logits=None, validate_args=None): + self.temperature = temperature + if (probs is None) == (logits is None): + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) + if probs is not None: + is_scalar = isinstance(probs, Number) + (self.probs,) = broadcast_all(probs) + else: + is_scalar = isinstance(logits, Number) + (self.logits,) = broadcast_all(logits) + self._param = self.probs if probs is not None else self.logits + if is_scalar: + batch_shape = torch.Size() + else: + batch_shape = self._param.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(LogitRelaxedBernoulli, _instance) + batch_shape = torch.Size(batch_shape) + new.temperature = self.temperature + if "probs" in self.__dict__: + new.probs = self.probs.expand(batch_shape) + new._param = new.probs + if "logits" in self.__dict__: + new.logits = self.logits.expand(batch_shape) + new._param = new.logits + super(LogitRelaxedBernoulli, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._param.new(*args, **kwargs) + + @lazy_property + def logits(self): + return probs_to_logits(self.probs, is_binary=True) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits, is_binary=True) + + @property + def param_shape(self): + return self._param.size() + + def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: + shape = self._extended_shape(sample_shape) + probs = clamp_probs(self.probs.expand(shape)) + uniforms = clamp_probs( + torch.rand(shape, dtype=probs.dtype, device=probs.device) + ) + return ( + uniforms.log() - (-uniforms).log1p() + probs.log() - (-probs).log1p() + ) / self.temperature + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + logits, value = broadcast_all(self.logits, value) + diff = logits - value.mul(self.temperature) + return self.temperature.log() + diff - 2 * diff.exp().log1p() + + +class RelaxedBernoulli(TransformedDistribution): + r""" + Creates a RelaxedBernoulli distribution, parametrized by + :attr:`temperature`, and either :attr:`probs` or :attr:`logits` + (but not both). This is a relaxed version of the `Bernoulli` distribution, + so the values are in (0, 1), and has reparametrizable samples. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = RelaxedBernoulli(torch.tensor([2.2]), + ... 
torch.tensor([0.1, 0.2, 0.3, 0.99])) + >>> m.sample() + tensor([ 0.2951, 0.3442, 0.8918, 0.9021]) + + Args: + temperature (Tensor): relaxation temperature + probs (Number, Tensor): the probability of sampling `1` + logits (Number, Tensor): the log-odds of sampling `1` + """ + arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} + support = constraints.unit_interval + has_rsample = True + + def __init__(self, temperature, probs=None, logits=None, validate_args=None): + base_dist = LogitRelaxedBernoulli(temperature, probs, logits) + super().__init__(base_dist, SigmoidTransform(), validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(RelaxedBernoulli, _instance) + return super().expand(batch_shape, _instance=new) + + @property + def temperature(self): + return self.base_dist.temperature + + @property + def logits(self): + return self.base_dist.logits + + @property + def probs(self): + return self.base_dist.probs diff --git a/lib/python3.10/site-packages/torch/distributions/relaxed_categorical.py b/lib/python3.10/site-packages/torch/distributions/relaxed_categorical.py new file mode 100644 index 0000000000000000000000000000000000000000..0f9b027f1a500c8bd91bbf07e76fe746b9866d55 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/relaxed_categorical.py @@ -0,0 +1,142 @@ +# mypy: allow-untyped-defs +import torch +from torch.distributions import constraints +from torch.distributions.categorical import Categorical +from torch.distributions.distribution import Distribution +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import ExpTransform +from torch.distributions.utils import broadcast_all, clamp_probs +from torch.types import _size + + +__all__ = ["ExpRelaxedCategorical", "RelaxedOneHotCategorical"] + + +class ExpRelaxedCategorical(Distribution): + r""" + Creates a ExpRelaxedCategorical parameterized by + :attr:`temperature`, and either :attr:`probs` or :attr:`logits` (but not both). + Returns the log of a point in the simplex. Based on the interface to + :class:`OneHotCategorical`. + + Implementation based on [1]. + + See also: :func:`torch.distributions.OneHotCategorical` + + Args: + temperature (Tensor): relaxation temperature + probs (Tensor): event probabilities + logits (Tensor): unnormalized log probability for each event + + [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables + (Maddison et al., 2017) + + [2] Categorical Reparametrization with Gumbel-Softmax + (Jang et al., 2017) + """ + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} + support = ( + constraints.real_vector + ) # The true support is actually a submanifold of this. 
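+    # Sketch of the relaxation implemented in `rsample` below: draw Gumbel
+    # noise g_i = -log(-log(U_i)) with U_i ~ Uniform(0, 1), form
+    # (logits + g) / temperature, and return its log-softmax, i.e. a point on
+    # the log-simplex; `RelaxedOneHotCategorical` then exponentiates it.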
+ has_rsample = True + + def __init__(self, temperature, probs=None, logits=None, validate_args=None): + self._categorical = Categorical(probs, logits) + self.temperature = temperature + batch_shape = self._categorical.batch_shape + event_shape = self._categorical.param_shape[-1:] + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(ExpRelaxedCategorical, _instance) + batch_shape = torch.Size(batch_shape) + new.temperature = self.temperature + new._categorical = self._categorical.expand(batch_shape) + super(ExpRelaxedCategorical, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._categorical._new(*args, **kwargs) + + @property + def param_shape(self): + return self._categorical.param_shape + + @property + def logits(self): + return self._categorical.logits + + @property + def probs(self): + return self._categorical.probs + + def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: + shape = self._extended_shape(sample_shape) + uniforms = clamp_probs( + torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device) + ) + gumbels = -((-(uniforms.log())).log()) + scores = (self.logits + gumbels) / self.temperature + return scores - scores.logsumexp(dim=-1, keepdim=True) + + def log_prob(self, value): + K = self._categorical._num_events + if self._validate_args: + self._validate_sample(value) + logits, value = broadcast_all(self.logits, value) + log_scale = torch.full_like( + self.temperature, float(K) + ).lgamma() - self.temperature.log().mul(-(K - 1)) + score = logits - value.mul(self.temperature) + score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1) + return score + log_scale + + +class RelaxedOneHotCategorical(TransformedDistribution): + r""" + Creates a RelaxedOneHotCategorical distribution parametrized by + :attr:`temperature`, and either :attr:`probs` or :attr:`logits`. + This is a relaxed version of the :class:`OneHotCategorical` distribution, so + its samples are on simplex, and are reparametrizable. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = RelaxedOneHotCategorical(torch.tensor([2.2]), + ... 
torch.tensor([0.1, 0.2, 0.3, 0.4])) + >>> m.sample() + tensor([ 0.1294, 0.2324, 0.3859, 0.2523]) + + Args: + temperature (Tensor): relaxation temperature + probs (Tensor): event probabilities + logits (Tensor): unnormalized log probability for each event + """ + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} + support = constraints.simplex + has_rsample = True + + def __init__(self, temperature, probs=None, logits=None, validate_args=None): + base_dist = ExpRelaxedCategorical( + temperature, probs, logits, validate_args=validate_args + ) + super().__init__(base_dist, ExpTransform(), validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(RelaxedOneHotCategorical, _instance) + return super().expand(batch_shape, _instance=new) + + @property + def temperature(self): + return self.base_dist.temperature + + @property + def logits(self): + return self.base_dist.logits + + @property + def probs(self): + return self.base_dist.probs diff --git a/lib/python3.10/site-packages/torch/distributions/studentT.py b/lib/python3.10/site-packages/torch/distributions/studentT.py new file mode 100644 index 0000000000000000000000000000000000000000..50b2b995bd501dcf0c09d6b4f35aa1910f1a3182 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/studentT.py @@ -0,0 +1,119 @@ +# mypy: allow-untyped-defs +import math + +import torch +from torch import inf, nan +from torch.distributions import Chi2, constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import _standard_normal, broadcast_all +from torch.types import _size + + +__all__ = ["StudentT"] + + +class StudentT(Distribution): + r""" + Creates a Student's t-distribution parameterized by degree of + freedom :attr:`df`, mean :attr:`loc` and scale :attr:`scale`. 
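+
+    A note restating the moment properties implemented below: the mean is
+    defined only for ``df > 1`` (``nan`` otherwise), and the variance is
+    finite only for ``df > 2`` (``inf`` for ``1 < df <= 2``, ``nan`` for
+    ``df <= 1``).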
+ + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = StudentT(torch.tensor([2.0])) + >>> m.sample() # Student's t-distributed with degrees of freedom=2 + tensor([ 0.1046]) + + Args: + df (float or Tensor): degrees of freedom + loc (float or Tensor): mean of the distribution + scale (float or Tensor): scale of the distribution + """ + arg_constraints = { + "df": constraints.positive, + "loc": constraints.real, + "scale": constraints.positive, + } + support = constraints.real + has_rsample = True + + @property + def mean(self): + m = self.loc.clone(memory_format=torch.contiguous_format) + m[self.df <= 1] = nan + return m + + @property + def mode(self): + return self.loc + + @property + def variance(self): + m = self.df.clone(memory_format=torch.contiguous_format) + m[self.df > 2] = ( + self.scale[self.df > 2].pow(2) + * self.df[self.df > 2] + / (self.df[self.df > 2] - 2) + ) + m[(self.df <= 2) & (self.df > 1)] = inf + m[self.df <= 1] = nan + return m + + def __init__(self, df, loc=0.0, scale=1.0, validate_args=None): + self.df, self.loc, self.scale = broadcast_all(df, loc, scale) + self._chi2 = Chi2(self.df) + batch_shape = self.df.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(StudentT, _instance) + batch_shape = torch.Size(batch_shape) + new.df = self.df.expand(batch_shape) + new.loc = self.loc.expand(batch_shape) + new.scale = self.scale.expand(batch_shape) + new._chi2 = self._chi2.expand(batch_shape) + super(StudentT, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: + # NOTE: This does not agree with scipy implementation as much as other distributions. + # (see https://github.com/fritzo/notebooks/blob/master/debug-student-t.ipynb). Using DoubleTensor + # parameters seems to help. 
+ + # X ~ Normal(0, 1) + # Z ~ Chi2(df) + # Y = X / sqrt(Z / df) ~ StudentT(df) + shape = self._extended_shape(sample_shape) + X = _standard_normal(shape, dtype=self.df.dtype, device=self.df.device) + Z = self._chi2.rsample(sample_shape) + Y = X * torch.rsqrt(Z / self.df) + return self.loc + self.scale * Y + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + y = (value - self.loc) / self.scale + Z = ( + self.scale.log() + + 0.5 * self.df.log() + + 0.5 * math.log(math.pi) + + torch.lgamma(0.5 * self.df) + - torch.lgamma(0.5 * (self.df + 1.0)) + ) + return -0.5 * (self.df + 1.0) * torch.log1p(y**2.0 / self.df) - Z + + def entropy(self): + lbeta = ( + torch.lgamma(0.5 * self.df) + + math.lgamma(0.5) + - torch.lgamma(0.5 * (self.df + 1)) + ) + return ( + self.scale.log() + + 0.5 + * (self.df + 1) + * (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df)) + + 0.5 * self.df.log() + + lbeta + ) diff --git a/lib/python3.10/site-packages/torch/distributions/transformed_distribution.py b/lib/python3.10/site-packages/torch/distributions/transformed_distribution.py new file mode 100644 index 0000000000000000000000000000000000000000..a9450accea23057e01aeb7bb9d33e468124f7499 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/transformed_distribution.py @@ -0,0 +1,216 @@ +# mypy: allow-untyped-defs +from typing import Dict + +import torch +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.independent import Independent +from torch.distributions.transforms import ComposeTransform, Transform +from torch.distributions.utils import _sum_rightmost +from torch.types import _size + + +__all__ = ["TransformedDistribution"] + + +class TransformedDistribution(Distribution): + r""" + Extension of the Distribution class, which applies a sequence of Transforms + to a base distribution. Let f be the composition of transforms applied:: + + X ~ BaseDistribution + Y = f(X) ~ TransformedDistribution(BaseDistribution, f) + log p(Y) = log p(X) + log |det (dX/dY)| + + Note that the ``.event_shape`` of a :class:`TransformedDistribution` is the + maximum shape of its base distribution and its transforms, since transforms + can introduce correlations among events. 
+ + An example for the usage of :class:`TransformedDistribution` would be:: + + # Building a Logistic Distribution + # X ~ Uniform(0, 1) + # f = a + b * logit(X) + # Y ~ f(X) ~ Logistic(a, b) + base_distribution = Uniform(0, 1) + transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)] + logistic = TransformedDistribution(base_distribution, transforms) + + For more examples, please look at the implementations of + :class:`~torch.distributions.gumbel.Gumbel`, + :class:`~torch.distributions.half_cauchy.HalfCauchy`, + :class:`~torch.distributions.half_normal.HalfNormal`, + :class:`~torch.distributions.log_normal.LogNormal`, + :class:`~torch.distributions.pareto.Pareto`, + :class:`~torch.distributions.weibull.Weibull`, + :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and + :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical` + """ + arg_constraints: Dict[str, constraints.Constraint] = {} + + def __init__(self, base_distribution, transforms, validate_args=None): + if isinstance(transforms, Transform): + self.transforms = [ + transforms, + ] + elif isinstance(transforms, list): + if not all(isinstance(t, Transform) for t in transforms): + raise ValueError( + "transforms must be a Transform or a list of Transforms" + ) + self.transforms = transforms + else: + raise ValueError( + f"transforms must be a Transform or list, but was {transforms}" + ) + + # Reshape base_distribution according to transforms. + base_shape = base_distribution.batch_shape + base_distribution.event_shape + base_event_dim = len(base_distribution.event_shape) + transform = ComposeTransform(self.transforms) + if len(base_shape) < transform.domain.event_dim: + raise ValueError( + f"base_distribution needs to have shape with size at least {transform.domain.event_dim}, but got {base_shape}." + ) + forward_shape = transform.forward_shape(base_shape) + expanded_base_shape = transform.inverse_shape(forward_shape) + if base_shape != expanded_base_shape: + base_batch_shape = expanded_base_shape[ + : len(expanded_base_shape) - base_event_dim + ] + base_distribution = base_distribution.expand(base_batch_shape) + reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim + if reinterpreted_batch_ndims > 0: + base_distribution = Independent( + base_distribution, reinterpreted_batch_ndims + ) + self.base_dist = base_distribution + + # Compute shapes. 
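+        # Reading of the shape algebra below: the overall event_dim is the
+        # larger of the transform's codomain event_dim and the base event_dim
+        # shifted by how much the transform changes event dimensionality;
+        # whatever leading dims of forward_shape remain after cutting off
+        # event_dim become this distribution's batch_shape.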
+ transform_change_in_event_dim = ( + transform.codomain.event_dim - transform.domain.event_dim + ) + event_dim = max( + transform.codomain.event_dim, # the transform is coupled + base_event_dim + transform_change_in_event_dim, # the base dist is coupled + ) + assert len(forward_shape) >= event_dim + cut = len(forward_shape) - event_dim + batch_shape = forward_shape[:cut] + event_shape = forward_shape[cut:] + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(TransformedDistribution, _instance) + batch_shape = torch.Size(batch_shape) + shape = batch_shape + self.event_shape + for t in reversed(self.transforms): + shape = t.inverse_shape(shape) + base_batch_shape = shape[: len(shape) - len(self.base_dist.event_shape)] + new.base_dist = self.base_dist.expand(base_batch_shape) + new.transforms = self.transforms + super(TransformedDistribution, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + @constraints.dependent_property(is_discrete=False) + def support(self): + if not self.transforms: + return self.base_dist.support + support = self.transforms[-1].codomain + if len(self.event_shape) > support.event_dim: + support = constraints.independent( + support, len(self.event_shape) - support.event_dim + ) + return support + + @property + def has_rsample(self): + return self.base_dist.has_rsample + + def sample(self, sample_shape=torch.Size()): + """ + Generates a sample_shape shaped sample or sample_shape shaped batch of + samples if the distribution parameters are batched. Samples first from + base distribution and applies `transform()` for every transform in the + list. + """ + with torch.no_grad(): + x = self.base_dist.sample(sample_shape) + for transform in self.transforms: + x = transform(x) + return x + + def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: + """ + Generates a sample_shape shaped reparameterized sample or sample_shape + shaped batch of reparameterized samples if the distribution parameters + are batched. Samples first from base distribution and applies + `transform()` for every transform in the list. + """ + x = self.base_dist.rsample(sample_shape) + for transform in self.transforms: + x = transform(x) + return x + + def log_prob(self, value): + """ + Scores the sample by inverting the transform(s) and computing the score + using the score of the base distribution and the log abs det jacobian. + """ + if self._validate_args: + self._validate_sample(value) + event_dim = len(self.event_shape) + log_prob = 0.0 + y = value + for transform in reversed(self.transforms): + x = transform.inv(y) + event_dim += transform.domain.event_dim - transform.codomain.event_dim + log_prob = log_prob - _sum_rightmost( + transform.log_abs_det_jacobian(x, y), + event_dim - transform.domain.event_dim, + ) + y = x + + log_prob = log_prob + _sum_rightmost( + self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape) + ) + return log_prob + + def _monotonize_cdf(self, value): + """ + This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is + monotone increasing. 
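+        A decreasing chain of transforms (overall ``sign == -1``) turns the
+        base CDF into one minus the CDF of the transformed variable, so the
+        affine flip ``sign * (value - 0.5) + 0.5`` below restores monotonicity
+        and also covers the case where ``sign`` is a tensor rather than an int.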
+ """ + sign = 1 + for transform in self.transforms: + sign = sign * transform.sign + if isinstance(sign, int) and sign == 1: + return value + return sign * (value - 0.5) + 0.5 + + def cdf(self, value): + """ + Computes the cumulative distribution function by inverting the + transform(s) and computing the score of the base distribution. + """ + for transform in self.transforms[::-1]: + value = transform.inv(value) + if self._validate_args: + self.base_dist._validate_sample(value) + value = self.base_dist.cdf(value) + value = self._monotonize_cdf(value) + return value + + def icdf(self, value): + """ + Computes the inverse cumulative distribution function using + transform(s) and computing the score of the base distribution. + """ + value = self._monotonize_cdf(value) + value = self.base_dist.icdf(value) + for transform in self.transforms: + value = transform(value) + return value diff --git a/lib/python3.10/site-packages/torch/distributions/transforms.py b/lib/python3.10/site-packages/torch/distributions/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..a1d42983b00db6a27054f54b1d16f9519a3c1ff0 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/transforms.py @@ -0,0 +1,1247 @@ +# mypy: allow-untyped-defs +import functools +import math +import numbers +import operator +import weakref +from typing import List + +import torch +import torch.nn.functional as F +from torch.distributions import constraints +from torch.distributions.utils import ( + _sum_rightmost, + broadcast_all, + lazy_property, + tril_matrix_to_vec, + vec_to_tril_matrix, +) +from torch.nn.functional import pad, softplus + + +__all__ = [ + "AbsTransform", + "AffineTransform", + "CatTransform", + "ComposeTransform", + "CorrCholeskyTransform", + "CumulativeDistributionTransform", + "ExpTransform", + "IndependentTransform", + "LowerCholeskyTransform", + "PositiveDefiniteTransform", + "PowerTransform", + "ReshapeTransform", + "SigmoidTransform", + "SoftplusTransform", + "TanhTransform", + "SoftmaxTransform", + "StackTransform", + "StickBreakingTransform", + "Transform", + "identity_transform", +] + + +class Transform: + """ + Abstract class for invertable transformations with computable log + det jacobians. They are primarily used in + :class:`torch.distributions.TransformedDistribution`. + + Caching is useful for transforms whose inverses are either expensive or + numerically unstable. Note that care must be taken with memoized values + since the autograd graph may be reversed. For example while the following + works with or without caching:: + + y = t(x) + t.log_abs_det_jacobian(x, y).backward() # x will receive gradients. + + However the following will error when caching due to dependency reversal:: + + y = t(x) + z = t.inv(y) + grad(z.sum(), [y]) # error because z is x + + Derived classes should implement one or both of :meth:`_call` or + :meth:`_inverse`. Derived classes that set `bijective=True` should also + implement :meth:`log_abs_det_jacobian`. + + Args: + cache_size (int): Size of cache. If zero, no caching is done. If one, + the latest single value is cached. Only 0 and 1 are supported. + + Attributes: + domain (:class:`~torch.distributions.constraints.Constraint`): + The constraint representing valid inputs to this transform. + codomain (:class:`~torch.distributions.constraints.Constraint`): + The constraint representing valid outputs to this transform + which are inputs to the inverse transform. + bijective (bool): Whether this transform is bijective. 
A transform + ``t`` is bijective iff ``t.inv(t(x)) == x`` and + ``t(t.inv(y)) == y`` for every ``x`` in the domain and ``y`` in + the codomain. Transforms that are not bijective should at least + maintain the weaker pseudoinverse properties + ``t(t.inv(t(x)) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``. + sign (int or Tensor): For bijective univariate transforms, this + should be +1 or -1 depending on whether transform is monotone + increasing or decreasing. + """ + + bijective = False + domain: constraints.Constraint + codomain: constraints.Constraint + + def __init__(self, cache_size=0): + self._cache_size = cache_size + self._inv = None + if cache_size == 0: + pass # default behavior + elif cache_size == 1: + self._cached_x_y = None, None + else: + raise ValueError("cache_size must be 0 or 1") + super().__init__() + + def __getstate__(self): + state = self.__dict__.copy() + state["_inv"] = None + return state + + @property + def event_dim(self): + if self.domain.event_dim == self.codomain.event_dim: + return self.domain.event_dim + raise ValueError("Please use either .domain.event_dim or .codomain.event_dim") + + @property + def inv(self): + """ + Returns the inverse :class:`Transform` of this transform. + This should satisfy ``t.inv.inv is t``. + """ + inv = None + if self._inv is not None: + inv = self._inv() + if inv is None: + inv = _InverseTransform(self) + self._inv = weakref.ref(inv) + return inv + + @property + def sign(self): + """ + Returns the sign of the determinant of the Jacobian, if applicable. + In general this only makes sense for bijective transforms. + """ + raise NotImplementedError + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + if type(self).__init__ is Transform.__init__: + return type(self)(cache_size=cache_size) + raise NotImplementedError(f"{type(self)}.with_cache is not implemented") + + def __eq__(self, other): + return self is other + + def __ne__(self, other): + # Necessary for Python2 + return not self.__eq__(other) + + def __call__(self, x): + """ + Computes the transform `x => y`. + """ + if self._cache_size == 0: + return self._call(x) + x_old, y_old = self._cached_x_y + if x is x_old: + return y_old + y = self._call(x) + self._cached_x_y = x, y + return y + + def _inv_call(self, y): + """ + Inverts the transform `y => x`. + """ + if self._cache_size == 0: + return self._inverse(y) + x_old, y_old = self._cached_x_y + if y is y_old: + return x_old + x = self._inverse(y) + self._cached_x_y = x, y + return x + + def _call(self, x): + """ + Abstract method to compute forward transformation. + """ + raise NotImplementedError + + def _inverse(self, y): + """ + Abstract method to compute inverse transformation. + """ + raise NotImplementedError + + def log_abs_det_jacobian(self, x, y): + """ + Computes the log det jacobian `log |dy/dx|` given input and output. + """ + raise NotImplementedError + + def __repr__(self): + return self.__class__.__name__ + "()" + + def forward_shape(self, shape): + """ + Infers the shape of the forward computation, given the input shape. + Defaults to preserving shape. + """ + return shape + + def inverse_shape(self, shape): + """ + Infers the shapes of the inverse computation, given the output shape. + Defaults to preserving shape. + """ + return shape + + +class _InverseTransform(Transform): + """ + Inverts a single :class:`Transform`. + This class is private; please instead use the ``Transform.inv`` property. 
+ """ + + def __init__(self, transform: Transform): + super().__init__(cache_size=transform._cache_size) + self._inv: Transform = transform + + @constraints.dependent_property(is_discrete=False) + def domain(self): + assert self._inv is not None + return self._inv.codomain + + @constraints.dependent_property(is_discrete=False) + def codomain(self): + assert self._inv is not None + return self._inv.domain + + @property + def bijective(self): + assert self._inv is not None + return self._inv.bijective + + @property + def sign(self): + assert self._inv is not None + return self._inv.sign + + @property + def inv(self): + return self._inv + + def with_cache(self, cache_size=1): + assert self._inv is not None + return self.inv.with_cache(cache_size).inv + + def __eq__(self, other): + if not isinstance(other, _InverseTransform): + return False + assert self._inv is not None + return self._inv == other._inv + + def __repr__(self): + return f"{self.__class__.__name__}({repr(self._inv)})" + + def __call__(self, x): + assert self._inv is not None + return self._inv._inv_call(x) + + def log_abs_det_jacobian(self, x, y): + assert self._inv is not None + return -self._inv.log_abs_det_jacobian(y, x) + + def forward_shape(self, shape): + return self._inv.inverse_shape(shape) + + def inverse_shape(self, shape): + return self._inv.forward_shape(shape) + + +class ComposeTransform(Transform): + """ + Composes multiple transforms in a chain. + The transforms being composed are responsible for caching. + + Args: + parts (list of :class:`Transform`): A list of transforms to compose. + cache_size (int): Size of cache. If zero, no caching is done. If one, + the latest single value is cached. Only 0 and 1 are supported. + """ + + def __init__(self, parts: List[Transform], cache_size=0): + if cache_size: + parts = [part.with_cache(cache_size) for part in parts] + super().__init__(cache_size=cache_size) + self.parts = parts + + def __eq__(self, other): + if not isinstance(other, ComposeTransform): + return False + return self.parts == other.parts + + @constraints.dependent_property(is_discrete=False) + def domain(self): + if not self.parts: + return constraints.real + domain = self.parts[0].domain + # Adjust event_dim to be maximum among all parts. + event_dim = self.parts[-1].codomain.event_dim + for part in reversed(self.parts): + event_dim += part.domain.event_dim - part.codomain.event_dim + event_dim = max(event_dim, part.domain.event_dim) + assert event_dim >= domain.event_dim + if event_dim > domain.event_dim: + domain = constraints.independent(domain, event_dim - domain.event_dim) + return domain + + @constraints.dependent_property(is_discrete=False) + def codomain(self): + if not self.parts: + return constraints.real + codomain = self.parts[-1].codomain + # Adjust event_dim to be maximum among all parts. 
+ event_dim = self.parts[0].domain.event_dim + for part in self.parts: + event_dim += part.codomain.event_dim - part.domain.event_dim + event_dim = max(event_dim, part.codomain.event_dim) + assert event_dim >= codomain.event_dim + if event_dim > codomain.event_dim: + codomain = constraints.independent(codomain, event_dim - codomain.event_dim) + return codomain + + @lazy_property + def bijective(self): + return all(p.bijective for p in self.parts) + + @lazy_property + def sign(self): + sign = 1 + for p in self.parts: + sign = sign * p.sign + return sign + + @property + def inv(self): + inv = None + if self._inv is not None: + inv = self._inv() + if inv is None: + inv = ComposeTransform([p.inv for p in reversed(self.parts)]) + self._inv = weakref.ref(inv) + inv._inv = weakref.ref(self) + return inv + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return ComposeTransform(self.parts, cache_size=cache_size) + + def __call__(self, x): + for part in self.parts: + x = part(x) + return x + + def log_abs_det_jacobian(self, x, y): + if not self.parts: + return torch.zeros_like(x) + + # Compute intermediates. This will be free if parts[:-1] are all cached. + xs = [x] + for part in self.parts[:-1]: + xs.append(part(xs[-1])) + xs.append(y) + + terms = [] + event_dim = self.domain.event_dim + for part, x, y in zip(self.parts, xs[:-1], xs[1:]): + terms.append( + _sum_rightmost( + part.log_abs_det_jacobian(x, y), event_dim - part.domain.event_dim + ) + ) + event_dim += part.codomain.event_dim - part.domain.event_dim + return functools.reduce(operator.add, terms) + + def forward_shape(self, shape): + for part in self.parts: + shape = part.forward_shape(shape) + return shape + + def inverse_shape(self, shape): + for part in reversed(self.parts): + shape = part.inverse_shape(shape) + return shape + + def __repr__(self): + fmt_string = self.__class__.__name__ + "(\n " + fmt_string += ",\n ".join([p.__repr__() for p in self.parts]) + fmt_string += "\n)" + return fmt_string + + +identity_transform = ComposeTransform([]) + + +class IndependentTransform(Transform): + """ + Wrapper around another transform to treat + ``reinterpreted_batch_ndims``-many extra of the right most dimensions as + dependent. This has no effect on the forward or backward transforms, but + does sum out ``reinterpreted_batch_ndims``-many of the rightmost dimensions + in :meth:`log_abs_det_jacobian`. + + Args: + base_transform (:class:`Transform`): A base transform. + reinterpreted_batch_ndims (int): The number of extra rightmost + dimensions to treat as dependent. 
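+
+    A minimal doctest-style sketch (shapes only)::
+
+        >>> # xdoctest: +SKIP("illustrative sketch")
+        >>> t = IndependentTransform(ExpTransform(), 1)
+        >>> x = torch.zeros(3, 4)
+        >>> t(x).shape, t.log_abs_det_jacobian(x, t(x)).shape
+        (torch.Size([3, 4]), torch.Size([3]))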
+ """ + + def __init__(self, base_transform, reinterpreted_batch_ndims, cache_size=0): + super().__init__(cache_size=cache_size) + self.base_transform = base_transform.with_cache(cache_size) + self.reinterpreted_batch_ndims = reinterpreted_batch_ndims + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return IndependentTransform( + self.base_transform, self.reinterpreted_batch_ndims, cache_size=cache_size + ) + + @constraints.dependent_property(is_discrete=False) + def domain(self): + return constraints.independent( + self.base_transform.domain, self.reinterpreted_batch_ndims + ) + + @constraints.dependent_property(is_discrete=False) + def codomain(self): + return constraints.independent( + self.base_transform.codomain, self.reinterpreted_batch_ndims + ) + + @property + def bijective(self): + return self.base_transform.bijective + + @property + def sign(self): + return self.base_transform.sign + + def _call(self, x): + if x.dim() < self.domain.event_dim: + raise ValueError("Too few dimensions on input") + return self.base_transform(x) + + def _inverse(self, y): + if y.dim() < self.codomain.event_dim: + raise ValueError("Too few dimensions on input") + return self.base_transform.inv(y) + + def log_abs_det_jacobian(self, x, y): + result = self.base_transform.log_abs_det_jacobian(x, y) + result = _sum_rightmost(result, self.reinterpreted_batch_ndims) + return result + + def __repr__(self): + return f"{self.__class__.__name__}({repr(self.base_transform)}, {self.reinterpreted_batch_ndims})" + + def forward_shape(self, shape): + return self.base_transform.forward_shape(shape) + + def inverse_shape(self, shape): + return self.base_transform.inverse_shape(shape) + + +class ReshapeTransform(Transform): + """ + Unit Jacobian transform to reshape the rightmost part of a tensor. + + Note that ``in_shape`` and ``out_shape`` must have the same number of + elements, just as for :meth:`torch.Tensor.reshape`. + + Arguments: + in_shape (torch.Size): The input event shape. + out_shape (torch.Size): The output event shape. 
+ """ + + bijective = True + + def __init__(self, in_shape, out_shape, cache_size=0): + self.in_shape = torch.Size(in_shape) + self.out_shape = torch.Size(out_shape) + if self.in_shape.numel() != self.out_shape.numel(): + raise ValueError("in_shape, out_shape have different numbers of elements") + super().__init__(cache_size=cache_size) + + @constraints.dependent_property + def domain(self): + return constraints.independent(constraints.real, len(self.in_shape)) + + @constraints.dependent_property + def codomain(self): + return constraints.independent(constraints.real, len(self.out_shape)) + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return ReshapeTransform(self.in_shape, self.out_shape, cache_size=cache_size) + + def _call(self, x): + batch_shape = x.shape[: x.dim() - len(self.in_shape)] + return x.reshape(batch_shape + self.out_shape) + + def _inverse(self, y): + batch_shape = y.shape[: y.dim() - len(self.out_shape)] + return y.reshape(batch_shape + self.in_shape) + + def log_abs_det_jacobian(self, x, y): + batch_shape = x.shape[: x.dim() - len(self.in_shape)] + return x.new_zeros(batch_shape) + + def forward_shape(self, shape): + if len(shape) < len(self.in_shape): + raise ValueError("Too few dimensions on input") + cut = len(shape) - len(self.in_shape) + if shape[cut:] != self.in_shape: + raise ValueError( + f"Shape mismatch: expected {shape[cut:]} but got {self.in_shape}" + ) + return shape[:cut] + self.out_shape + + def inverse_shape(self, shape): + if len(shape) < len(self.out_shape): + raise ValueError("Too few dimensions on input") + cut = len(shape) - len(self.out_shape) + if shape[cut:] != self.out_shape: + raise ValueError( + f"Shape mismatch: expected {shape[cut:]} but got {self.out_shape}" + ) + return shape[:cut] + self.in_shape + + +class ExpTransform(Transform): + r""" + Transform via the mapping :math:`y = \exp(x)`. + """ + domain = constraints.real + codomain = constraints.positive + bijective = True + sign = +1 + + def __eq__(self, other): + return isinstance(other, ExpTransform) + + def _call(self, x): + return x.exp() + + def _inverse(self, y): + return y.log() + + def log_abs_det_jacobian(self, x, y): + return x + + +class PowerTransform(Transform): + r""" + Transform via the mapping :math:`y = x^{\text{exponent}}`. 
+ """ + domain = constraints.positive + codomain = constraints.positive + bijective = True + + def __init__(self, exponent, cache_size=0): + super().__init__(cache_size=cache_size) + (self.exponent,) = broadcast_all(exponent) + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return PowerTransform(self.exponent, cache_size=cache_size) + + @lazy_property + def sign(self): + return self.exponent.sign() + + def __eq__(self, other): + if not isinstance(other, PowerTransform): + return False + return self.exponent.eq(other.exponent).all().item() + + def _call(self, x): + return x.pow(self.exponent) + + def _inverse(self, y): + return y.pow(1 / self.exponent) + + def log_abs_det_jacobian(self, x, y): + return (self.exponent * y / x).abs().log() + + def forward_shape(self, shape): + return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ())) + + def inverse_shape(self, shape): + return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ())) + + +def _clipped_sigmoid(x): + finfo = torch.finfo(x.dtype) + return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1.0 - finfo.eps) + + +class SigmoidTransform(Transform): + r""" + Transform via the mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`. + """ + domain = constraints.real + codomain = constraints.unit_interval + bijective = True + sign = +1 + + def __eq__(self, other): + return isinstance(other, SigmoidTransform) + + def _call(self, x): + return _clipped_sigmoid(x) + + def _inverse(self, y): + finfo = torch.finfo(y.dtype) + y = y.clamp(min=finfo.tiny, max=1.0 - finfo.eps) + return y.log() - (-y).log1p() + + def log_abs_det_jacobian(self, x, y): + return -F.softplus(-x) - F.softplus(x) + + +class SoftplusTransform(Transform): + r""" + Transform via the mapping :math:`\text{Softplus}(x) = \log(1 + \exp(x))`. + The implementation reverts to the linear function when :math:`x > 20`. + """ + domain = constraints.real + codomain = constraints.positive + bijective = True + sign = +1 + + def __eq__(self, other): + return isinstance(other, SoftplusTransform) + + def _call(self, x): + return softplus(x) + + def _inverse(self, y): + return (-y).expm1().neg().log() + y + + def log_abs_det_jacobian(self, x, y): + return -softplus(-x) + + +class TanhTransform(Transform): + r""" + Transform via the mapping :math:`y = \tanh(x)`. + + It is equivalent to + ``` + ComposeTransform([AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.)]) + ``` + However this might not be numerically stable, thus it is recommended to use `TanhTransform` + instead. + + Note that one should use `cache_size=1` when it comes to `NaN/Inf` values. + + """ + domain = constraints.real + codomain = constraints.interval(-1.0, 1.0) + bijective = True + sign = +1 + + def __eq__(self, other): + return isinstance(other, TanhTransform) + + def _call(self, x): + return x.tanh() + + def _inverse(self, y): + # We do not clamp to the boundary here as it may degrade the performance of certain algorithms. + # one should use `cache_size=1` instead + return torch.atanh(y) + + def log_abs_det_jacobian(self, x, y): + # We use a formula that is more numerically stable, see details in the following link + # https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L69-L80 + return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x)) + + +class AbsTransform(Transform): + r""" + Transform via the mapping :math:`y = |x|`. 
+ """ + domain = constraints.real + codomain = constraints.positive + + def __eq__(self, other): + return isinstance(other, AbsTransform) + + def _call(self, x): + return x.abs() + + def _inverse(self, y): + return y + + +class AffineTransform(Transform): + r""" + Transform via the pointwise affine mapping :math:`y = \text{loc} + \text{scale} \times x`. + + Args: + loc (Tensor or float): Location parameter. + scale (Tensor or float): Scale parameter. + event_dim (int): Optional size of `event_shape`. This should be zero + for univariate random variables, 1 for distributions over vectors, + 2 for distributions over matrices, etc. + """ + bijective = True + + def __init__(self, loc, scale, event_dim=0, cache_size=0): + super().__init__(cache_size=cache_size) + self.loc = loc + self.scale = scale + self._event_dim = event_dim + + @property + def event_dim(self): + return self._event_dim + + @constraints.dependent_property(is_discrete=False) + def domain(self): + if self.event_dim == 0: + return constraints.real + return constraints.independent(constraints.real, self.event_dim) + + @constraints.dependent_property(is_discrete=False) + def codomain(self): + if self.event_dim == 0: + return constraints.real + return constraints.independent(constraints.real, self.event_dim) + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return AffineTransform( + self.loc, self.scale, self.event_dim, cache_size=cache_size + ) + + def __eq__(self, other): + if not isinstance(other, AffineTransform): + return False + + if isinstance(self.loc, numbers.Number) and isinstance( + other.loc, numbers.Number + ): + if self.loc != other.loc: + return False + else: + if not (self.loc == other.loc).all().item(): + return False + + if isinstance(self.scale, numbers.Number) and isinstance( + other.scale, numbers.Number + ): + if self.scale != other.scale: + return False + else: + if not (self.scale == other.scale).all().item(): + return False + + return True + + @property + def sign(self): + if isinstance(self.scale, numbers.Real): + return 1 if float(self.scale) > 0 else -1 if float(self.scale) < 0 else 0 + return self.scale.sign() + + def _call(self, x): + return self.loc + self.scale * x + + def _inverse(self, y): + return (y - self.loc) / self.scale + + def log_abs_det_jacobian(self, x, y): + shape = x.shape + scale = self.scale + if isinstance(scale, numbers.Real): + result = torch.full_like(x, math.log(abs(scale))) + else: + result = torch.abs(scale).log() + if self.event_dim: + result_size = result.size()[: -self.event_dim] + (-1,) + result = result.view(result_size).sum(-1) + shape = shape[: -self.event_dim] + return result.expand(shape) + + def forward_shape(self, shape): + return torch.broadcast_shapes( + shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ()) + ) + + def inverse_shape(self, shape): + return torch.broadcast_shapes( + shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ()) + ) + + +class CorrCholeskyTransform(Transform): + r""" + Transforms an uncontrained real vector :math:`x` with length :math:`D*(D-1)/2` into the + Cholesky factor of a D-dimension correlation matrix. This Cholesky factor is a lower + triangular matrix with positive diagonals and unit Euclidean norm for each row. + The transform is processed as follows: + + 1. First we convert x into a lower triangular matrix in row order. + 2. 
For each row :math:`X_i` of the lower triangular part, we apply a *signed* version of + class :class:`StickBreakingTransform` to transform :math:`X_i` into a + unit Euclidean length vector using the following steps: + - Scales into the interval :math:`(-1, 1)` domain: :math:`r_i = \tanh(X_i)`. + - Transforms into an unsigned domain: :math:`z_i = r_i^2`. + - Applies :math:`s_i = StickBreakingTransform(z_i)`. + - Transforms back into signed domain: :math:`y_i = sign(r_i) * \sqrt{s_i}`. + """ + domain = constraints.real_vector + codomain = constraints.corr_cholesky + bijective = True + + def _call(self, x): + x = torch.tanh(x) + eps = torch.finfo(x.dtype).eps + x = x.clamp(min=-1 + eps, max=1 - eps) + r = vec_to_tril_matrix(x, diag=-1) + # apply stick-breaking on the squared values + # Note that y = sign(r) * sqrt(z * z1m_cumprod) + # = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod) + z = r**2 + z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1) + # Diagonal elements must be 1. + r = r + torch.eye(r.shape[-1], dtype=r.dtype, device=r.device) + y = r * pad(z1m_cumprod_sqrt[..., :-1], [1, 0], value=1) + return y + + def _inverse(self, y): + # inverse stick-breaking + # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html + y_cumsum = 1 - torch.cumsum(y * y, dim=-1) + y_cumsum_shifted = pad(y_cumsum[..., :-1], [1, 0], value=1) + y_vec = tril_matrix_to_vec(y, diag=-1) + y_cumsum_vec = tril_matrix_to_vec(y_cumsum_shifted, diag=-1) + t = y_vec / (y_cumsum_vec).sqrt() + # inverse of tanh + x = (t.log1p() - t.neg().log1p()) / 2 + return x + + def log_abs_det_jacobian(self, x, y, intermediates=None): + # Because domain and codomain are two spaces with different dimensions, determinant of + # Jacobian is not well-defined. We return `log_abs_det_jacobian` of `x` and the + # flattened lower triangular part of `y`. + + # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html + y1m_cumsum = 1 - (y * y).cumsum(dim=-1) + # by taking diagonal=-2, we don't need to shift z_cumprod to the right + # also works for 2 x 2 matrix + y1m_cumsum_tril = tril_matrix_to_vec(y1m_cumsum, diag=-2) + stick_breaking_logdet = 0.5 * (y1m_cumsum_tril).log().sum(-1) + tanh_logdet = -2 * (x + softplus(-2 * x) - math.log(2.0)).sum(dim=-1) + return stick_breaking_logdet + tanh_logdet + + def forward_shape(self, shape): + # Reshape from (..., N) to (..., D, D). + if len(shape) < 1: + raise ValueError("Too few dimensions on input") + N = shape[-1] + D = round((0.25 + 2 * N) ** 0.5 + 0.5) + if D * (D - 1) // 2 != N: + raise ValueError("Input is not a flattend lower-diagonal number") + return shape[:-1] + (D, D) + + def inverse_shape(self, shape): + # Reshape from (..., D, D) to (..., N). + if len(shape) < 2: + raise ValueError("Too few dimensions on input") + if shape[-2] != shape[-1]: + raise ValueError("Input is not square") + D = shape[-1] + N = D * (D - 1) // 2 + return shape[:-2] + (N,) + + +class SoftmaxTransform(Transform): + r""" + Transform from unconstrained space to the simplex via :math:`y = \exp(x)` then + normalizing. + + This is not bijective and cannot be used for HMC. However this acts mostly + coordinate-wise (except for the final normalization), and thus is + appropriate for coordinate-wise optimization algorithms. 
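+
+    A minimal doctest-style sketch::
+
+        >>> # xdoctest: +SKIP("illustrative sketch")
+        >>> t = SoftmaxTransform()
+        >>> t(torch.zeros(3))  # equal logits map to the uniform simplex point
+        tensor([0.3333, 0.3333, 0.3333])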
+ """ + domain = constraints.real_vector + codomain = constraints.simplex + + def __eq__(self, other): + return isinstance(other, SoftmaxTransform) + + def _call(self, x): + logprobs = x + probs = (logprobs - logprobs.max(-1, True)[0]).exp() + return probs / probs.sum(-1, True) + + def _inverse(self, y): + probs = y + return probs.log() + + def forward_shape(self, shape): + if len(shape) < 1: + raise ValueError("Too few dimensions on input") + return shape + + def inverse_shape(self, shape): + if len(shape) < 1: + raise ValueError("Too few dimensions on input") + return shape + + +class StickBreakingTransform(Transform): + """ + Transform from unconstrained space to the simplex of one additional + dimension via a stick-breaking process. + + This transform arises as an iterated sigmoid transform in a stick-breaking + construction of the `Dirichlet` distribution: the first logit is + transformed via sigmoid to the first probability and the probability of + everything else, and then the process recurses. + + This is bijective and appropriate for use in HMC; however it mixes + coordinates together and is less appropriate for optimization. + """ + + domain = constraints.real_vector + codomain = constraints.simplex + bijective = True + + def __eq__(self, other): + return isinstance(other, StickBreakingTransform) + + def _call(self, x): + offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1) + z = _clipped_sigmoid(x - offset.log()) + z_cumprod = (1 - z).cumprod(-1) + y = pad(z, [0, 1], value=1) * pad(z_cumprod, [1, 0], value=1) + return y + + def _inverse(self, y): + y_crop = y[..., :-1] + offset = y.shape[-1] - y.new_ones(y_crop.shape[-1]).cumsum(-1) + sf = 1 - y_crop.cumsum(-1) + # we clamp to make sure that sf is positive which sometimes does not + # happen when y[-1] ~ 0 or y[:-1].sum() ~ 1 + sf = torch.clamp(sf, min=torch.finfo(y.dtype).tiny) + x = y_crop.log() - sf.log() + offset.log() + return x + + def log_abs_det_jacobian(self, x, y): + offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1) + x = x - offset.log() + # use the identity 1 - sigmoid(x) = exp(-x) * sigmoid(x) + detJ = (-x + F.logsigmoid(x) + y[..., :-1].log()).sum(-1) + return detJ + + def forward_shape(self, shape): + if len(shape) < 1: + raise ValueError("Too few dimensions on input") + return shape[:-1] + (shape[-1] + 1,) + + def inverse_shape(self, shape): + if len(shape) < 1: + raise ValueError("Too few dimensions on input") + return shape[:-1] + (shape[-1] - 1,) + + +class LowerCholeskyTransform(Transform): + """ + Transform from unconstrained matrices to lower-triangular matrices with + nonnegative diagonal entries. + + This is useful for parameterizing positive definite matrices in terms of + their Cholesky factorization. + """ + + domain = constraints.independent(constraints.real, 2) + codomain = constraints.lower_cholesky + + def __eq__(self, other): + return isinstance(other, LowerCholeskyTransform) + + def _call(self, x): + return x.tril(-1) + x.diagonal(dim1=-2, dim2=-1).exp().diag_embed() + + def _inverse(self, y): + return y.tril(-1) + y.diagonal(dim1=-2, dim2=-1).log().diag_embed() + + +class PositiveDefiniteTransform(Transform): + """ + Transform from unconstrained matrices to positive-definite matrices. 
+ """ + + domain = constraints.independent(constraints.real, 2) + codomain = constraints.positive_definite # type: ignore[assignment] + + def __eq__(self, other): + return isinstance(other, PositiveDefiniteTransform) + + def _call(self, x): + x = LowerCholeskyTransform()(x) + return x @ x.mT + + def _inverse(self, y): + y = torch.linalg.cholesky(y) + return LowerCholeskyTransform().inv(y) + + +class CatTransform(Transform): + """ + Transform functor that applies a sequence of transforms `tseq` + component-wise to each submatrix at `dim`, of length `lengths[dim]`, + in a way compatible with :func:`torch.cat`. + + Example:: + + x0 = torch.cat([torch.range(1, 10), torch.range(1, 10)], dim=0) + x = torch.cat([x0, x0], dim=0) + t0 = CatTransform([ExpTransform(), identity_transform], dim=0, lengths=[10, 10]) + t = CatTransform([t0, t0], dim=0, lengths=[20, 20]) + y = t(x) + """ + + transforms: List[Transform] + + def __init__(self, tseq, dim=0, lengths=None, cache_size=0): + assert all(isinstance(t, Transform) for t in tseq) + if cache_size: + tseq = [t.with_cache(cache_size) for t in tseq] + super().__init__(cache_size=cache_size) + self.transforms = list(tseq) + if lengths is None: + lengths = [1] * len(self.transforms) + self.lengths = list(lengths) + assert len(self.lengths) == len(self.transforms) + self.dim = dim + + @lazy_property + def event_dim(self): + return max(t.event_dim for t in self.transforms) + + @lazy_property + def length(self): + return sum(self.lengths) + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return CatTransform(self.transforms, self.dim, self.lengths, cache_size) + + def _call(self, x): + assert -x.dim() <= self.dim < x.dim() + assert x.size(self.dim) == self.length + yslices = [] + start = 0 + for trans, length in zip(self.transforms, self.lengths): + xslice = x.narrow(self.dim, start, length) + yslices.append(trans(xslice)) + start = start + length # avoid += for jit compat + return torch.cat(yslices, dim=self.dim) + + def _inverse(self, y): + assert -y.dim() <= self.dim < y.dim() + assert y.size(self.dim) == self.length + xslices = [] + start = 0 + for trans, length in zip(self.transforms, self.lengths): + yslice = y.narrow(self.dim, start, length) + xslices.append(trans.inv(yslice)) + start = start + length # avoid += for jit compat + return torch.cat(xslices, dim=self.dim) + + def log_abs_det_jacobian(self, x, y): + assert -x.dim() <= self.dim < x.dim() + assert x.size(self.dim) == self.length + assert -y.dim() <= self.dim < y.dim() + assert y.size(self.dim) == self.length + logdetjacs = [] + start = 0 + for trans, length in zip(self.transforms, self.lengths): + xslice = x.narrow(self.dim, start, length) + yslice = y.narrow(self.dim, start, length) + logdetjac = trans.log_abs_det_jacobian(xslice, yslice) + if trans.event_dim < self.event_dim: + logdetjac = _sum_rightmost(logdetjac, self.event_dim - trans.event_dim) + logdetjacs.append(logdetjac) + start = start + length # avoid += for jit compat + # Decide whether to concatenate or sum. 
+ dim = self.dim + if dim >= 0: + dim = dim - x.dim() + dim = dim + self.event_dim + if dim < 0: + return torch.cat(logdetjacs, dim=dim) + else: + return sum(logdetjacs) + + @property + def bijective(self): + return all(t.bijective for t in self.transforms) + + @constraints.dependent_property + def domain(self): + return constraints.cat( + [t.domain for t in self.transforms], self.dim, self.lengths + ) + + @constraints.dependent_property + def codomain(self): + return constraints.cat( + [t.codomain for t in self.transforms], self.dim, self.lengths + ) + + +class StackTransform(Transform): + """ + Transform functor that applies a sequence of transforms `tseq` + component-wise to each submatrix at `dim` + in a way compatible with :func:`torch.stack`. + + Example:: + + x = torch.stack([torch.range(1, 10), torch.range(1, 10)], dim=1) + t = StackTransform([ExpTransform(), identity_transform], dim=1) + y = t(x) + """ + + transforms: List[Transform] + + def __init__(self, tseq, dim=0, cache_size=0): + assert all(isinstance(t, Transform) for t in tseq) + if cache_size: + tseq = [t.with_cache(cache_size) for t in tseq] + super().__init__(cache_size=cache_size) + self.transforms = list(tseq) + self.dim = dim + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return StackTransform(self.transforms, self.dim, cache_size) + + def _slice(self, z): + return [z.select(self.dim, i) for i in range(z.size(self.dim))] + + def _call(self, x): + assert -x.dim() <= self.dim < x.dim() + assert x.size(self.dim) == len(self.transforms) + yslices = [] + for xslice, trans in zip(self._slice(x), self.transforms): + yslices.append(trans(xslice)) + return torch.stack(yslices, dim=self.dim) + + def _inverse(self, y): + assert -y.dim() <= self.dim < y.dim() + assert y.size(self.dim) == len(self.transforms) + xslices = [] + for yslice, trans in zip(self._slice(y), self.transforms): + xslices.append(trans.inv(yslice)) + return torch.stack(xslices, dim=self.dim) + + def log_abs_det_jacobian(self, x, y): + assert -x.dim() <= self.dim < x.dim() + assert x.size(self.dim) == len(self.transforms) + assert -y.dim() <= self.dim < y.dim() + assert y.size(self.dim) == len(self.transforms) + logdetjacs = [] + yslices = self._slice(y) + xslices = self._slice(x) + for xslice, yslice, trans in zip(xslices, yslices, self.transforms): + logdetjacs.append(trans.log_abs_det_jacobian(xslice, yslice)) + return torch.stack(logdetjacs, dim=self.dim) + + @property + def bijective(self): + return all(t.bijective for t in self.transforms) + + @constraints.dependent_property + def domain(self): + return constraints.stack([t.domain for t in self.transforms], self.dim) + + @constraints.dependent_property + def codomain(self): + return constraints.stack([t.codomain for t in self.transforms], self.dim) + + +class CumulativeDistributionTransform(Transform): + """ + Transform via the cumulative distribution function of a probability distribution. + + Args: + distribution (Distribution): Distribution whose cumulative distribution function to use for + the transformation. + + Example:: + + # Construct a Gaussian copula from a multivariate normal. 
+ base_dist = MultivariateNormal( + loc=torch.zeros(2), + scale_tril=LKJCholesky(2).sample(), + ) + transform = CumulativeDistributionTransform(Normal(0, 1)) + copula = TransformedDistribution(base_dist, [transform]) + """ + + bijective = True + codomain = constraints.unit_interval + sign = +1 + + def __init__(self, distribution, cache_size=0): + super().__init__(cache_size=cache_size) + self.distribution = distribution + + @property + def domain(self): + return self.distribution.support + + def _call(self, x): + return self.distribution.cdf(x) + + def _inverse(self, y): + return self.distribution.icdf(y) + + def log_abs_det_jacobian(self, x, y): + return self.distribution.log_prob(x) + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return CumulativeDistributionTransform(self.distribution, cache_size=cache_size) diff --git a/lib/python3.10/site-packages/torch/distributions/uniform.py b/lib/python3.10/site-packages/torch/distributions/uniform.py new file mode 100644 index 0000000000000000000000000000000000000000..0fe3678a319c406a1aabcacefa465391cd94e615 --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/uniform.py @@ -0,0 +1,102 @@ +# mypy: allow-untyped-defs +from numbers import Number + +import torch +from torch import nan +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import broadcast_all +from torch.types import _size + + +__all__ = ["Uniform"] + + +class Uniform(Distribution): + r""" + Generates uniformly distributed random samples from the half-open interval + ``[low, high)``. + + Example:: + + >>> m = Uniform(torch.tensor([0.0]), torch.tensor([5.0])) + >>> m.sample() # uniformly distributed in the range [0.0, 5.0) + >>> # xdoctest: +SKIP + tensor([ 2.3418]) + + Args: + low (float or Tensor): lower range (inclusive). + high (float or Tensor): upper range (exclusive). + """ + # TODO allow (loc,scale) parameterization to allow independent constraints. 
+ arg_constraints = { + "low": constraints.dependent(is_discrete=False, event_dim=0), + "high": constraints.dependent(is_discrete=False, event_dim=0), + } + has_rsample = True + + @property + def mean(self): + return (self.high + self.low) / 2 + + @property + def mode(self): + return nan * self.high + + @property + def stddev(self): + return (self.high - self.low) / 12**0.5 + + @property + def variance(self): + return (self.high - self.low).pow(2) / 12 + + def __init__(self, low, high, validate_args=None): + self.low, self.high = broadcast_all(low, high) + + if isinstance(low, Number) and isinstance(high, Number): + batch_shape = torch.Size() + else: + batch_shape = self.low.size() + super().__init__(batch_shape, validate_args=validate_args) + + if self._validate_args and not torch.lt(self.low, self.high).all(): + raise ValueError("Uniform is not defined when low>= high") + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Uniform, _instance) + batch_shape = torch.Size(batch_shape) + new.low = self.low.expand(batch_shape) + new.high = self.high.expand(batch_shape) + super(Uniform, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + @constraints.dependent_property(is_discrete=False, event_dim=0) + def support(self): + return constraints.interval(self.low, self.high) + + def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: + shape = self._extended_shape(sample_shape) + rand = torch.rand(shape, dtype=self.low.dtype, device=self.low.device) + return self.low + rand * (self.high - self.low) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + lb = self.low.le(value).type_as(self.low) + ub = self.high.gt(value).type_as(self.low) + return torch.log(lb.mul(ub)) - torch.log(self.high - self.low) + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + result = (value - self.low) / (self.high - self.low) + return result.clamp(min=0, max=1) + + def icdf(self, value): + result = value * (self.high - self.low) + self.low + return result + + def entropy(self): + return torch.log(self.high - self.low) diff --git a/lib/python3.10/site-packages/torch/distributions/utils.py b/lib/python3.10/site-packages/torch/distributions/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..90c6d98a11ab1bb9c6571aa48293333b0cc3c28e --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/utils.py @@ -0,0 +1,200 @@ +# mypy: allow-untyped-defs +from functools import update_wrapper +from numbers import Number +from typing import Any, Dict + +import torch +import torch.nn.functional as F +from torch.overrides import is_tensor_like + + +euler_constant = 0.57721566490153286060 # Euler Mascheroni Constant + +__all__ = [ + "broadcast_all", + "logits_to_probs", + "clamp_probs", + "probs_to_logits", + "lazy_property", + "tril_matrix_to_vec", + "vec_to_tril_matrix", +] + + +def broadcast_all(*values): + r""" + Given a list of values (possibly containing numbers), returns a list where each + value is broadcasted based on the following rules: + - `torch.*Tensor` instances are broadcasted as per :ref:`_broadcasting-semantics`. + - numbers.Number instances (scalars) are upcast to tensors having + the same size and type as the first tensor passed to `values`. If all the + values are scalars, then they are upcasted to scalar Tensors. 
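A short sketch of these rules (illustrative values): the scalar is upcast to a tensor with the dtype and device of the first tensor argument, and both results share the broadcast shape::

    import torch
    from torch.distributions.utils import broadcast_all

    a, b = broadcast_all(torch.zeros(3, 1, dtype=torch.float64), 2.0)
    assert a.shape == b.shape == (3, 1)
    assert b.dtype == torch.float64      # the scalar picked up the tensor's dtype
    assert bool((b == 2.0).all())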
+ + Args: + values (list of `numbers.Number`, `torch.*Tensor` or objects implementing __torch_function__) + + Raises: + ValueError: if any of the values is not a `numbers.Number` instance, + a `torch.*Tensor` instance, or an instance implementing __torch_function__ + """ + if not all(is_tensor_like(v) or isinstance(v, Number) for v in values): + raise ValueError( + "Input arguments must all be instances of numbers.Number, " + "torch.Tensor or objects implementing __torch_function__." + ) + if not all(is_tensor_like(v) for v in values): + options: Dict[str, Any] = dict(dtype=torch.get_default_dtype()) + for value in values: + if isinstance(value, torch.Tensor): + options = dict(dtype=value.dtype, device=value.device) + break + new_values = [ + v if is_tensor_like(v) else torch.tensor(v, **options) for v in values + ] + return torch.broadcast_tensors(*new_values) + return torch.broadcast_tensors(*values) + + +def _standard_normal(shape, dtype, device): + if torch._C._get_tracing_state(): + # [JIT WORKAROUND] lack of support for .normal_() + return torch.normal( + torch.zeros(shape, dtype=dtype, device=device), + torch.ones(shape, dtype=dtype, device=device), + ) + return torch.empty(shape, dtype=dtype, device=device).normal_() + + +def _sum_rightmost(value, dim): + r""" + Sum out ``dim`` many rightmost dimensions of a given tensor. + + Args: + value (Tensor): A tensor of ``.dim()`` at least ``dim``. + dim (int): The number of rightmost dims to sum out. + """ + if dim == 0: + return value + required_shape = value.shape[:-dim] + (-1,) + return value.reshape(required_shape).sum(-1) + + +def logits_to_probs(logits, is_binary=False): + r""" + Converts a tensor of logits into probabilities. Note that for the + binary case, each value denotes log odds, whereas for the + multi-dimensional case, the values along the last dimension denote + the log probabilities (possibly unnormalized) of the events. + """ + if is_binary: + return torch.sigmoid(logits) + return F.softmax(logits, dim=-1) + + +def clamp_probs(probs): + """Clamps the probabilities to be in the open interval `(0, 1)`. + + The probabilities would be clamped between `eps` and `1 - eps`, + and `eps` would be the smallest representable positive number for the input data type. + + Args: + probs (Tensor): A tensor of probabilities. + + Returns: + Tensor: The clamped probabilities. + + Examples: + >>> probs = torch.tensor([0.0, 0.5, 1.0]) + >>> clamp_probs(probs) + tensor([1.1921e-07, 5.0000e-01, 1.0000e+00]) + + >>> probs = torch.tensor([0.0, 0.5, 1.0], dtype=torch.float64) + >>> clamp_probs(probs) + tensor([2.2204e-16, 5.0000e-01, 1.0000e+00], dtype=torch.float64) + + """ + eps = torch.finfo(probs.dtype).eps + return probs.clamp(min=eps, max=1 - eps) + + +def probs_to_logits(probs, is_binary=False): + r""" + Converts a tensor of probabilities into logits. For the binary case, + this denotes the probability of occurrence of the event indexed by `1`. + For the multi-dimensional case, the values along the last dimension + denote the probabilities of occurrence of each of the events. + """ + ps_clamped = clamp_probs(probs) + if is_binary: + return torch.log(ps_clamped) - torch.log1p(-ps_clamped) + return torch.log(ps_clamped) + + +class lazy_property: + r""" + Used as a decorator for lazy loading of class attributes. This uses a + non-data descriptor that calls the wrapped method to compute the property on + first call; thereafter replacing the wrapped method into an instance + attribute. 
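A minimal sketch of this behavior (class and attribute names are hypothetical): the wrapped method runs once on first access, and its result then shadows the descriptor as a plain instance attribute::

    import torch
    from torch.distributions.utils import lazy_property

    class SumOfSquares:
        def __init__(self, x):
            self.x = x

        @lazy_property
        def total(self):
            print("computing")           # runs only on the first access
            return (self.x ** 2).sum()

    s = SumOfSquares(torch.arange(3.0))
    _ = s.total                          # prints "computing" and caches the result
    _ = s.total                          # served from the instance attribute, no recompute
    assert "total" in s.__dict__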
+ """ + + def __init__(self, wrapped): + self.wrapped = wrapped + update_wrapper(self, wrapped) # type:ignore[arg-type] + + def __get__(self, instance, obj_type=None): + if instance is None: + return _lazy_property_and_property(self.wrapped) + with torch.enable_grad(): + value = self.wrapped(instance) + setattr(instance, self.wrapped.__name__, value) + return value + + +class _lazy_property_and_property(lazy_property, property): + """We want lazy properties to look like multiple things. + + * property when Sphinx autodoc looks + * lazy_property when Distribution validate_args looks + """ + + def __init__(self, wrapped): + property.__init__(self, wrapped) + + +def tril_matrix_to_vec(mat: torch.Tensor, diag: int = 0) -> torch.Tensor: + r""" + Convert a `D x D` matrix or a batch of matrices into a (batched) vector + which comprises of lower triangular elements from the matrix in row order. + """ + n = mat.shape[-1] + if not torch._C._get_tracing_state() and (diag < -n or diag >= n): + raise ValueError(f"diag ({diag}) provided is outside [{-n}, {n-1}].") + arange = torch.arange(n, device=mat.device) + tril_mask = arange < arange.view(-1, 1) + (diag + 1) + vec = mat[..., tril_mask] + return vec + + +def vec_to_tril_matrix(vec: torch.Tensor, diag: int = 0) -> torch.Tensor: + r""" + Convert a vector or a batch of vectors into a batched `D x D` + lower triangular matrix containing elements from the vector in row order. + """ + # +ve root of D**2 + (1+2*diag)*D - |diag| * (diag+1) - 2*vec.shape[-1] = 0 + n = ( + -(1 + 2 * diag) + + ((1 + 2 * diag) ** 2 + 8 * vec.shape[-1] + 4 * abs(diag) * (diag + 1)) ** 0.5 + ) / 2 + eps = torch.finfo(vec.dtype).eps + if not torch._C._get_tracing_state() and (round(n) - n > eps): + raise ValueError( + f"The size of last dimension is {vec.shape[-1]} which cannot be expressed as " + + "the lower triangular part of a square D x D matrix." 
+ ) + n = round(n.item()) if isinstance(n, torch.Tensor) else round(n) + mat = vec.new_zeros(vec.shape[:-1] + torch.Size((n, n))) + arange = torch.arange(n, device=vec.device) + tril_mask = arange < arange.view(-1, 1) + (diag + 1) + mat[..., tril_mask] = vec + return mat diff --git a/lib/python3.10/site-packages/torch/distributions/von_mises.py b/lib/python3.10/site-packages/torch/distributions/von_mises.py new file mode 100644 index 0000000000000000000000000000000000000000..bd8fa87f2619a476d85cab44accac728d7b8fb2a --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/von_mises.py @@ -0,0 +1,211 @@ +# mypy: allow-untyped-defs +import math + +import torch +import torch.jit +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import broadcast_all, lazy_property + + +__all__ = ["VonMises"] + + +def _eval_poly(y, coef): + coef = list(coef) + result = coef.pop() + while coef: + result = coef.pop() + y * result + return result + + +_I0_COEF_SMALL = [ + 1.0, + 3.5156229, + 3.0899424, + 1.2067492, + 0.2659732, + 0.360768e-1, + 0.45813e-2, +] +_I0_COEF_LARGE = [ + 0.39894228, + 0.1328592e-1, + 0.225319e-2, + -0.157565e-2, + 0.916281e-2, + -0.2057706e-1, + 0.2635537e-1, + -0.1647633e-1, + 0.392377e-2, +] +_I1_COEF_SMALL = [ + 0.5, + 0.87890594, + 0.51498869, + 0.15084934, + 0.2658733e-1, + 0.301532e-2, + 0.32411e-3, +] +_I1_COEF_LARGE = [ + 0.39894228, + -0.3988024e-1, + -0.362018e-2, + 0.163801e-2, + -0.1031555e-1, + 0.2282967e-1, + -0.2895312e-1, + 0.1787654e-1, + -0.420059e-2, +] + +_COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL] +_COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE] + + +def _log_modified_bessel_fn(x, order=0): + """ + Returns ``log(I_order(x))`` for ``x > 0``, + where `order` is either 0 or 1. + """ + assert order == 0 or order == 1 + + # compute small solution + y = x / 3.75 + y = y * y + small = _eval_poly(y, _COEF_SMALL[order]) + if order == 1: + small = x.abs() * small + small = small.log() + + # compute large solution + y = 3.75 / x + large = x - 0.5 * x.log() + _eval_poly(y, _COEF_LARGE[order]).log() + + result = torch.where(x < 3.75, small, large) + return result + + +@torch.jit.script_if_tracing +def _rejection_sample(loc, concentration, proposal_r, x): + done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device) + while not done.all(): + u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device) + u1, u2, u3 = u.unbind() + z = torch.cos(math.pi * u1) + f = (1 + proposal_r * z) / (proposal_r + z) + c = concentration * (proposal_r - f) + accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0) + if accept.any(): + x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x) + done = done | accept + return (x + math.pi + loc) % (2 * math.pi) - math.pi + + +class VonMises(Distribution): + """ + A circular von Mises distribution. + + This implementation uses polar coordinates. The ``loc`` and ``value`` args + can be any real number (to facilitate unconstrained optimization), but are + interpreted as angles modulo 2 pi. + + Example:: + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0])) + >>> m.sample() # von Mises distributed with loc=1 and concentration=1 + tensor([1.9777]) + + :param torch.Tensor loc: an angle in radians. 
+ :param torch.Tensor concentration: concentration parameter + """ + + arg_constraints = {"loc": constraints.real, "concentration": constraints.positive} + support = constraints.real + has_rsample = False + + def __init__(self, loc, concentration, validate_args=None): + self.loc, self.concentration = broadcast_all(loc, concentration) + batch_shape = self.loc.shape + event_shape = torch.Size() + super().__init__(batch_shape, event_shape, validate_args) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + log_prob = self.concentration * torch.cos(value - self.loc) + log_prob = ( + log_prob + - math.log(2 * math.pi) + - _log_modified_bessel_fn(self.concentration, order=0) + ) + return log_prob + + @lazy_property + def _loc(self): + return self.loc.to(torch.double) + + @lazy_property + def _concentration(self): + return self.concentration.to(torch.double) + + @lazy_property + def _proposal_r(self): + kappa = self._concentration + tau = 1 + (1 + 4 * kappa**2).sqrt() + rho = (tau - (2 * tau).sqrt()) / (2 * kappa) + _proposal_r = (1 + rho**2) / (2 * rho) + # second order Taylor expansion around 0 for small kappa + _proposal_r_taylor = 1 / kappa + kappa + return torch.where(kappa < 1e-5, _proposal_r_taylor, _proposal_r) + + @torch.no_grad() + def sample(self, sample_shape=torch.Size()): + """ + The sampling algorithm for the von Mises distribution is based on the + following paper: D.J. Best and N.I. Fisher, "Efficient simulation of the + von Mises distribution." Applied Statistics (1979): 152-157. + + Sampling is always done in double precision internally to avoid a hang + in _rejection_sample() for small values of the concentration, which + starts to happen for single precision around 1e-4 (see issue #88443). + """ + shape = self._extended_shape(sample_shape) + x = torch.empty(shape, dtype=self._loc.dtype, device=self.loc.device) + return _rejection_sample( + self._loc, self._concentration, self._proposal_r, x + ).to(self.loc.dtype) + + def expand(self, batch_shape): + try: + return super().expand(batch_shape) + except NotImplementedError: + validate_args = self.__dict__.get("_validate_args") + loc = self.loc.expand(batch_shape) + concentration = self.concentration.expand(batch_shape) + return type(self)(loc, concentration, validate_args=validate_args) + + @property + def mean(self): + """ + The provided mean is the circular one. + """ + return self.loc + + @property + def mode(self): + return self.loc + + @lazy_property + def variance(self): + """ + The provided variance is the circular one. 
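A small sketch of what "circular" means here (illustrative parameters): the circular variance lies in [0, 1], approaching 1 for a nearly uniform distribution on the circle (small concentration) and 0 for a highly concentrated one::

    import torch
    from torch.distributions import VonMises

    loc = torch.tensor(0.0)
    diffuse = VonMises(loc, torch.tensor(1e-3)).variance   # nearly uniform -> close to 1
    peaked = VonMises(loc, torch.tensor(50.0)).variance    # tightly concentrated -> close to 0
    assert diffuse > 0.99
    assert peaked < 0.05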
+ """ + return ( + 1 + - ( + _log_modified_bessel_fn(self.concentration, order=1) + - _log_modified_bessel_fn(self.concentration, order=0) + ).exp() + ) diff --git a/lib/python3.10/site-packages/torch/distributions/weibull.py b/lib/python3.10/site-packages/torch/distributions/weibull.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a5af16968413ebd16bed2d5f94f827294ff93a --- /dev/null +++ b/lib/python3.10/site-packages/torch/distributions/weibull.py @@ -0,0 +1,85 @@ +# mypy: allow-untyped-defs +import torch +from torch.distributions import constraints +from torch.distributions.exponential import Exponential +from torch.distributions.gumbel import euler_constant +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import AffineTransform, PowerTransform +from torch.distributions.utils import broadcast_all + + +__all__ = ["Weibull"] + + +class Weibull(TransformedDistribution): + r""" + Samples from a two-parameter Weibull distribution. + + Example: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Weibull(torch.tensor([1.0]), torch.tensor([1.0])) + >>> m.sample() # sample from a Weibull distribution with scale=1, concentration=1 + tensor([ 0.4784]) + + Args: + scale (float or Tensor): Scale parameter of distribution (lambda). + concentration (float or Tensor): Concentration parameter of distribution (k/shape). + """ + arg_constraints = { + "scale": constraints.positive, + "concentration": constraints.positive, + } + support = constraints.positive + + def __init__(self, scale, concentration, validate_args=None): + self.scale, self.concentration = broadcast_all(scale, concentration) + self.concentration_reciprocal = self.concentration.reciprocal() + base_dist = Exponential( + torch.ones_like(self.scale), validate_args=validate_args + ) + transforms = [ + PowerTransform(exponent=self.concentration_reciprocal), + AffineTransform(loc=0, scale=self.scale), + ] + super().__init__(base_dist, transforms, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Weibull, _instance) + new.scale = self.scale.expand(batch_shape) + new.concentration = self.concentration.expand(batch_shape) + new.concentration_reciprocal = new.concentration.reciprocal() + base_dist = self.base_dist.expand(batch_shape) + transforms = [ + PowerTransform(exponent=new.concentration_reciprocal), + AffineTransform(loc=0, scale=new.scale), + ] + super(Weibull, new).__init__(base_dist, transforms, validate_args=False) + new._validate_args = self._validate_args + return new + + @property + def mean(self): + return self.scale * torch.exp(torch.lgamma(1 + self.concentration_reciprocal)) + + @property + def mode(self): + return ( + self.scale + * ((self.concentration - 1) / self.concentration) + ** self.concentration.reciprocal() + ) + + @property + def variance(self): + return self.scale.pow(2) * ( + torch.exp(torch.lgamma(1 + 2 * self.concentration_reciprocal)) + - torch.exp(2 * torch.lgamma(1 + self.concentration_reciprocal)) + ) + + def entropy(self): + return ( + euler_constant * (1 - self.concentration_reciprocal) + + torch.log(self.scale * self.concentration_reciprocal) + + 1 + ) diff --git a/lib/python3.10/site-packages/torch/distributions/wishart.py b/lib/python3.10/site-packages/torch/distributions/wishart.py new file mode 100644 index 0000000000000000000000000000000000000000..e0a00f0cc6db06412b13d3d9157eb9d8bade1eaf --- /dev/null +++ 
b/lib/python3.10/site-packages/torch/distributions/wishart.py @@ -0,0 +1,339 @@ +# mypy: allow-untyped-defs +import math +import warnings +from numbers import Number +from typing import Optional, Union + +import torch +from torch import nan +from torch.distributions import constraints +from torch.distributions.exp_family import ExponentialFamily +from torch.distributions.multivariate_normal import _precision_to_scale_tril +from torch.distributions.utils import lazy_property +from torch.types import _size + + +__all__ = ["Wishart"] + +_log_2 = math.log(2) + + +def _mvdigamma(x: torch.Tensor, p: int) -> torch.Tensor: + assert x.gt((p - 1) / 2).all(), "Wrong domain for multivariate digamma function." + return torch.digamma( + x.unsqueeze(-1) + - torch.arange(p, dtype=x.dtype, device=x.device).div(2).expand(x.shape + (-1,)) + ).sum(-1) + + +def _clamp_above_eps(x: torch.Tensor) -> torch.Tensor: + # We assume positive input for this function + return x.clamp(min=torch.finfo(x.dtype).eps) + + +class Wishart(ExponentialFamily): + r""" + Creates a Wishart distribution parameterized by a symmetric positive definite matrix :math:`\Sigma`, + or its Cholesky decomposition :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top` + + Example: + >>> # xdoctest: +SKIP("FIXME: scale_tril must be at least two-dimensional") + >>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2)) + >>> m.sample() # Wishart distributed with mean=`df * I` and + >>> # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j + + Args: + df (float or Tensor): real-valued parameter larger than the (dimension of Square matrix) - 1 + covariance_matrix (Tensor): positive-definite covariance matrix + precision_matrix (Tensor): positive-definite precision matrix + scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal + Note: + Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or + :attr:`scale_tril` can be specified. + Using :attr:`scale_tril` will be more efficient: all computations internally + are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or + :attr:`precision_matrix` is passed instead, it is only used to compute + the corresponding lower triangular matrices using a Cholesky decomposition. + 'torch.distributions.LKJCholesky' is a restricted Wishart distribution.[1] + + **References** + + [1] Wang, Z., Wu, Y. and Chu, H., 2018. `On equivalence of the LKJ distribution and the restricted Wishart distribution`. + [2] Sawyer, S., 2007. `Wishart Distributions and Inverse-Wishart Sampling`. + [3] Anderson, T. W., 2003. `An Introduction to Multivariate Statistical Analysis (3rd ed.)`. + [4] Odell, P. L. & Feiveson, A. H., 1966. `A Numerical Procedure to Generate a SampleCovariance Matrix`. JASA, 61(313):199-203. + [5] Ku, Y.-C. & Bloomfield, P., 2010. `Generating Random Wishart Matrices with Fractional Degrees of Freedom in OX`. 
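A small usage sketch (dimension and degrees of freedom are illustrative; df must exceed p - 1): with an identity covariance, the mean is simply df * I::

    import torch
    from torch.distributions import Wishart

    df = torch.tensor(5.0)
    m = Wishart(df=df, covariance_matrix=torch.eye(3))
    S = m.sample()                                    # a 3 x 3 positive semi-definite sample
    assert S.shape == (3, 3)
    assert torch.allclose(m.mean, df * torch.eye(3))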
+ """ + arg_constraints = { + "covariance_matrix": constraints.positive_definite, + "precision_matrix": constraints.positive_definite, + "scale_tril": constraints.lower_cholesky, + "df": constraints.greater_than(0), + } + support = constraints.positive_definite + has_rsample = True + _mean_carrier_measure = 0 + + def __init__( + self, + df: Union[torch.Tensor, Number], + covariance_matrix: Optional[torch.Tensor] = None, + precision_matrix: Optional[torch.Tensor] = None, + scale_tril: Optional[torch.Tensor] = None, + validate_args=None, + ): + assert (covariance_matrix is not None) + (scale_tril is not None) + ( + precision_matrix is not None + ) == 1, "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified." + + param = next( + p + for p in (covariance_matrix, precision_matrix, scale_tril) + if p is not None + ) + + if param.dim() < 2: + raise ValueError( + "scale_tril must be at least two-dimensional, with optional leading batch dimensions" + ) + + if isinstance(df, Number): + batch_shape = torch.Size(param.shape[:-2]) + self.df = torch.tensor(df, dtype=param.dtype, device=param.device) + else: + batch_shape = torch.broadcast_shapes(param.shape[:-2], df.shape) + self.df = df.expand(batch_shape) + event_shape = param.shape[-2:] + + if self.df.le(event_shape[-1] - 1).any(): + raise ValueError( + f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1]-1}." + ) + + if scale_tril is not None: + self.scale_tril = param.expand(batch_shape + (-1, -1)) + elif covariance_matrix is not None: + self.covariance_matrix = param.expand(batch_shape + (-1, -1)) + elif precision_matrix is not None: + self.precision_matrix = param.expand(batch_shape + (-1, -1)) + + self.arg_constraints["df"] = constraints.greater_than(event_shape[-1] - 1) + if self.df.lt(event_shape[-1]).any(): + warnings.warn( + "Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim." 
+ ) + + super().__init__(batch_shape, event_shape, validate_args=validate_args) + self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))] + + if scale_tril is not None: + self._unbroadcasted_scale_tril = scale_tril + elif covariance_matrix is not None: + self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix) + else: # precision_matrix is not None + self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix) + + # Chi2 distribution is needed for Bartlett decomposition sampling + self._dist_chi2 = torch.distributions.chi2.Chi2( + df=( + self.df.unsqueeze(-1) + - torch.arange( + self._event_shape[-1], + dtype=self._unbroadcasted_scale_tril.dtype, + device=self._unbroadcasted_scale_tril.device, + ).expand(batch_shape + (-1,)) + ) + ) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Wishart, _instance) + batch_shape = torch.Size(batch_shape) + cov_shape = batch_shape + self.event_shape + new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril.expand(cov_shape) + new.df = self.df.expand(batch_shape) + + new._batch_dims = [-(x + 1) for x in range(len(batch_shape))] + + if "covariance_matrix" in self.__dict__: + new.covariance_matrix = self.covariance_matrix.expand(cov_shape) + if "scale_tril" in self.__dict__: + new.scale_tril = self.scale_tril.expand(cov_shape) + if "precision_matrix" in self.__dict__: + new.precision_matrix = self.precision_matrix.expand(cov_shape) + + # Chi2 distribution is needed for Bartlett decomposition sampling + new._dist_chi2 = torch.distributions.chi2.Chi2( + df=( + new.df.unsqueeze(-1) + - torch.arange( + self.event_shape[-1], + dtype=new._unbroadcasted_scale_tril.dtype, + device=new._unbroadcasted_scale_tril.device, + ).expand(batch_shape + (-1,)) + ) + ) + + super(Wishart, new).__init__(batch_shape, self.event_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + @lazy_property + def scale_tril(self): + return self._unbroadcasted_scale_tril.expand( + self._batch_shape + self._event_shape + ) + + @lazy_property + def covariance_matrix(self): + return ( + self._unbroadcasted_scale_tril + @ self._unbroadcasted_scale_tril.transpose(-2, -1) + ).expand(self._batch_shape + self._event_shape) + + @lazy_property + def precision_matrix(self): + identity = torch.eye( + self._event_shape[-1], + device=self._unbroadcasted_scale_tril.device, + dtype=self._unbroadcasted_scale_tril.dtype, + ) + return torch.cholesky_solve(identity, self._unbroadcasted_scale_tril).expand( + self._batch_shape + self._event_shape + ) + + @property + def mean(self): + return self.df.view(self._batch_shape + (1, 1)) * self.covariance_matrix + + @property + def mode(self): + factor = self.df - self.covariance_matrix.shape[-1] - 1 + factor[factor <= 0] = nan + return factor.view(self._batch_shape + (1, 1)) * self.covariance_matrix + + @property + def variance(self): + V = self.covariance_matrix # has shape (batch_shape x event_shape) + diag_V = V.diagonal(dim1=-2, dim2=-1) + return self.df.view(self._batch_shape + (1, 1)) * ( + V.pow(2) + torch.einsum("...i,...j->...ij", diag_V, diag_V) + ) + + def _bartlett_sampling(self, sample_shape=torch.Size()): + p = self._event_shape[-1] # has singleton shape + + # Implemented Sampling using Bartlett decomposition + noise = _clamp_above_eps( + self._dist_chi2.rsample(sample_shape).sqrt() + ).diag_embed(dim1=-2, dim2=-1) + + i, j = torch.tril_indices(p, p, offset=-1) + noise[..., i, j] = torch.randn( + torch.Size(sample_shape) + 
self._batch_shape + (int(p * (p - 1) / 2),), + dtype=noise.dtype, + device=noise.device, + ) + chol = self._unbroadcasted_scale_tril @ noise + return chol @ chol.transpose(-2, -1) + + def rsample( + self, sample_shape: _size = torch.Size(), max_try_correction=None + ) -> torch.Tensor: + r""" + .. warning:: + In some cases, sampling algorithm based on Bartlett decomposition may return singular matrix samples. + Several tries to correct singular samples are performed by default, but it may end up returning + singular matrix samples. Singular samples may return `-inf` values in `.log_prob()`. + In those cases, the user should validate the samples and either fix the value of `df` + or adjust `max_try_correction` value for argument in `.rsample` accordingly. + """ + + if max_try_correction is None: + max_try_correction = 3 if torch._C._get_tracing_state() else 10 + + sample_shape = torch.Size(sample_shape) + sample = self._bartlett_sampling(sample_shape) + + # Below part is to improve numerical stability temporally and should be removed in the future + is_singular = self.support.check(sample) + if self._batch_shape: + is_singular = is_singular.amax(self._batch_dims) + + if torch._C._get_tracing_state(): + # Less optimized version for JIT + for _ in range(max_try_correction): + sample_new = self._bartlett_sampling(sample_shape) + sample = torch.where(is_singular, sample_new, sample) + + is_singular = ~self.support.check(sample) + if self._batch_shape: + is_singular = is_singular.amax(self._batch_dims) + + else: + # More optimized version with data-dependent control flow. + if is_singular.any(): + warnings.warn("Singular sample detected.") + + for _ in range(max_try_correction): + sample_new = self._bartlett_sampling(is_singular[is_singular].shape) + sample[is_singular] = sample_new + + is_singular_new = ~self.support.check(sample_new) + if self._batch_shape: + is_singular_new = is_singular_new.amax(self._batch_dims) + is_singular[is_singular.clone()] = is_singular_new + + if not is_singular.any(): + break + + return sample + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + nu = self.df # has shape (batch_shape) + p = self._event_shape[-1] # has singleton shape + return ( + -nu + * ( + p * _log_2 / 2 + + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1) + .log() + .sum(-1) + ) + - torch.mvlgamma(nu / 2, p=p) + + (nu - p - 1) / 2 * torch.linalg.slogdet(value).logabsdet + - torch.cholesky_solve(value, self._unbroadcasted_scale_tril) + .diagonal(dim1=-2, dim2=-1) + .sum(dim=-1) + / 2 + ) + + def entropy(self): + nu = self.df # has shape (batch_shape) + p = self._event_shape[-1] # has singleton shape + V = self.covariance_matrix # has shape (batch_shape x event_shape) + return ( + (p + 1) + * ( + p * _log_2 / 2 + + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1) + .log() + .sum(-1) + ) + + torch.mvlgamma(nu / 2, p=p) + - (nu - p - 1) / 2 * _mvdigamma(nu / 2, p=p) + + nu * p / 2 + ) + + @property + def _natural_params(self): + nu = self.df # has shape (batch_shape) + p = self._event_shape[-1] # has singleton shape + return -self.precision_matrix / 2, (nu - p - 1) / 2 + + def _log_normalizer(self, x, y): + p = self._event_shape[-1] + return (y + (p + 1) / 2) * ( + -torch.linalg.slogdet(-2 * x).logabsdet + _log_2 * p + ) + torch.mvlgamma(y + (p + 1) / 2, p=p) diff --git a/lib/python3.10/site-packages/torch/export/__init__.py b/lib/python3.10/site-packages/torch/export/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..367b3f127a16bba26b6e4191abc00220918c0093 --- /dev/null +++ b/lib/python3.10/site-packages/torch/export/__init__.py @@ -0,0 +1,518 @@ +import builtins +import copy +import dataclasses +import inspect +import io +import os +import sys +import typing +import warnings +import zipfile +from enum import auto, Enum +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Optional, + Tuple, + Type, + TYPE_CHECKING, + Union, +) + +import torch +import torch.utils._pytree as pytree +from torch.fx._compatibility import compatibility +from torch.fx.passes.infra.pass_base import PassResult +from torch.fx.passes.infra.pass_manager import PassManager +from torch.utils._pytree import ( + FlattenFunc, + FromDumpableContextFn, + ToDumpableContextFn, + UnflattenFunc, +) + + +if TYPE_CHECKING: + # Import the following modules during type checking to enable code intelligence features, + # Do not import unconditionally, as they import sympy and importing sympy is very slow + from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint + + +__all__ = [ + "Constraint", + "Dim", + "ExportBackwardSignature", + "ExportGraphSignature", + "ExportedProgram", + "ModuleCallEntry", + "ModuleCallSignature", + "dims", + "export", + "export_for_training", + "load", + "register_dataclass", + "save", + "unflatten", + "FlatArgsAdapter", + "UnflattenedModule", +] + + +from .dynamic_shapes import Constraint, Dim, dims, ShapesCollection +from .exported_program import ExportedProgram, ModuleCallEntry, ModuleCallSignature +from .graph_signature import ExportBackwardSignature, ExportGraphSignature +from .unflatten import FlatArgsAdapter, unflatten, UnflattenedModule + + +PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]] + + +def export_for_training( + mod: torch.nn.Module, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + *, + dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None, + strict: bool = True, + preserve_module_call_signature: Tuple[str, ...] = (), +) -> ExportedProgram: + """ + :func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing + only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, + which can subsequently be executed with different inputs or serialized. The + traced graph (1) produces normalized operators in the all ATen operator set + (as well as any user-specified custom operators), (2) has eliminated all Python control + flow and data structures (with certain exceptions), and (3) records the set of + shape constraints needed to show that this normalization and control-flow elimination + is sound for future inputs. This API is intended for PT2 quantization training use cases + and will soon be the default IR of torch.export.export in the near future. + + **Soundness Guarantee** + + See :func:`export()` docstring for more details. + + Args: + mod: We will trace the forward method of this module. + + args: Example positional inputs. + + kwargs: Optional example keyword inputs. + + dynamic_shapes: + An optional argument where the type should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. 
+ + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + strict: When enabled (default), the export function will trace the program through + TorchDynamo which will ensure the soundness of the resulting graph. Otherwise, the + exported program will not validate the implicit assumptions baked into the graph and + may cause behavior divergence between the original model and the exported one. This is + useful when users need to workaround bugs in the tracer, or simply want incrementally + enable safety in their models. Note that this does not affect the resulting IR spec + to be different and the model will be serialized in the same way regardless of what value + is passed here. + WARNING: This option is experimental and use this at your own risk. + + Returns: + An :class:`ExportedProgram` containing the traced callable. + + **Acceptable input/output types** + + Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include: + + - Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``. + - Dataclasses, but they must be registered by calling :func:`register_dataclass` first. + - (Nested) Data structures comprising of ``dict``, ``list``, ``tuple``, ``namedtuple`` and + ``OrderedDict`` containing all above types. + + """ + from ._trace import _export_for_training + + if not isinstance(mod, torch.nn.Module): + raise ValueError( + f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}." + ) + if isinstance(mod, torch.jit.ScriptModule): + raise ValueError( + "Exporting a ScriptModule is not supported. " + "Maybe try converting your ScriptModule to an ExportedProgram " + "using `TS2EPConverter(mod, args, kwargs).convert()` instead." + ) + return _export_for_training( + mod, + args, + kwargs, + dynamic_shapes, + strict=strict, + preserve_module_call_signature=preserve_module_call_signature, + ) + + +def export( + mod: torch.nn.Module, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + *, + dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None, + strict: bool = True, + preserve_module_call_signature: Tuple[str, ...] = (), +) -> ExportedProgram: + """ + :func:`export` takes an arbitrary Python callable (an nn.Module, a function or + a method) along with example inputs, and produces a traced graph representing + only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, + which can subsequently be executed with different inputs or serialized. The + traced graph (1) produces normalized operators in the functional ATen operator set + (as well as any user-specified custom operators), (2) has eliminated all Python control + flow and data structures (with certain exceptions), and (3) records the set of + shape constraints needed to show that this normalization and control-flow elimination + is sound for future inputs. 
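A minimal call sketch (the module, shapes, and values here are hypothetical): export the forward method of an nn.Module on example inputs, then re-run the resulting program on fresh inputs::

    import torch

    class Mul2(torch.nn.Module):
        def forward(self, x):
            return x * 2 + 1

    ep = torch.export.export(Mul2(), (torch.randn(4, 8),))
    print(ep.graph)                          # the traced ATen-level graph
    out = ep.module()(torch.randn(4, 8))     # run the exported program on new inputs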
+ + **Soundness Guarantee** + + While tracing, :func:`export()` takes note of shape-related assumptions + made by the user program and the underlying PyTorch operator kernels. + The output :class:`ExportedProgram` is considered valid only when these + assumptions hold true. + + Tracing makes assumptions on the shapes (not values) of input tensors. + Such assumptions must be validated at graph capture time for :func:`export` + to succeed. Specifically: + + - Assumptions on static shapes of input tensors are automatically validated without additional effort. + - Assumptions on dynamic shape of input tensors require explicit specification + by using the :func:`Dim` API to construct dynamic dimensions and by associating + them with example inputs through the ``dynamic_shapes`` argument. + + If any assumption can not be validated, a fatal error will be raised. When that happens, + the error message will include suggested fixes to the specification that are needed + to validate the assumptions. For example :func:`export` might suggest the + following fix to the definition of a dynamic dimension ``dim0_x``, say appearing in the + shape associated with input ``x``, that was previously defined as ``Dim("dim0_x")``:: + + dim = Dim("dim0_x", max=5) + + This example means the generated code requires dimension 0 of input ``x`` to be less + than or equal to 5 to be valid. You can inspect the suggested fixes to dynamic dimension + definitions and then copy them verbatim into your code without needing to change the + ``dynamic_shapes`` argument to your :func:`export` call. + + Args: + mod: We will trace the forward method of this module. + + args: Example positional inputs. + + kwargs: Optional example keyword inputs. + + dynamic_shapes: + An optional argument where the type should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. + + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + strict: When enabled (default), the export function will trace the program through + TorchDynamo which will ensure the soundness of the resulting graph. Otherwise, the + exported program will not validate the implicit assumptions baked into the graph and + may cause behavior divergence between the original model and the exported one. This is + useful when users need to workaround bugs in the tracer, or simply want incrementally + enable safety in their models. Note that this does not affect the resulting IR spec + to be different and the model will be serialized in the same way regardless of what value + is passed here. + WARNING: This option is experimental and use this at your own risk. + + Returns: + An :class:`ExportedProgram` containing the traced callable. 
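A sketch of the dynamic_shapes specification described above (module, names, and sizes are hypothetical): dimension 0 of the input x is marked dynamic with a Dim, while the remaining dimension stays static::

    import torch
    from torch.export import Dim, export

    class Scale(torch.nn.Module):
        def forward(self, x):
            return x * 2.0

    batch = Dim("batch", max=1024)
    ep = export(
        Scale(),
        (torch.randn(8, 16),),
        dynamic_shapes={"x": {0: batch}},    # dim 0 dynamic, dim 1 static
    )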
+ + **Acceptable input/output types** + + Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include: + + - Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``. + - Dataclasses, but they must be registered by calling :func:`register_dataclass` first. + - (Nested) Data structures comprising of ``dict``, ``list``, ``tuple``, ``namedtuple`` and + ``OrderedDict`` containing all above types. + + """ + from ._trace import _export + + if not isinstance(mod, torch.nn.Module): + raise ValueError( + f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}." + ) + if isinstance(mod, torch.jit.ScriptModule): + raise ValueError( + "Exporting a ScriptModule is not supported. " + "Maybe try converting your ScriptModule to an ExportedProgram " + "using `TS2EPConverter(mod, args, kwargs).convert()` instead." + ) + return _export( + mod, + args, + kwargs, + dynamic_shapes, + strict=strict, + preserve_module_call_signature=preserve_module_call_signature, + pre_dispatch=True, + ) + + +def save( + ep: ExportedProgram, + f: Union[str, os.PathLike, io.BytesIO], + *, + extra_files: Optional[Dict[str, Any]] = None, + opset_version: Optional[Dict[str, int]] = None, +) -> None: + """ + + .. warning:: + Under active development, saved files may not be usable in newer versions + of PyTorch. + + Saves an :class:`ExportedProgram` to a file-like object. It can then be + loaded using the Python API :func:`torch.export.load `. + + Args: + ep (ExportedProgram): The exported program to save. + + f (Union[str, os.PathLike, io.BytesIO): A file-like object (has to + implement write and flush) or a string containing a file name. + + extra_files (Optional[Dict[str, Any]]): Map from filename to contents + which will be stored as part of f. + + opset_version (Optional[Dict[str, int]]): A map of opset names + to the version of this opset + + + Example:: + + import torch + import io + + class MyModule(torch.nn.Module): + def forward(self, x): + return x + 10 + + ep = torch.export.export(MyModule(), (torch.randn(5),)) + + # Save to file + torch.export.save(ep, 'exported_program.pt2') + + # Save to io.BytesIO buffer + buffer = io.BytesIO() + torch.export.save(ep, buffer) + + # Save with extra files + extra_files = {'foo.txt': b'bar'.decode('utf-8')} + torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files) + + """ + if not isinstance(ep, ExportedProgram): + raise TypeError( + f"The 'ep' parameter must be an instance of 'ExportedProgram', got '{type(ep).__name__}' instead." + ) + + from torch._export.serde.schema import SCHEMA_VERSION + from torch._export.serde.serialize import serialize, SerializedArtifact + + artifact: SerializedArtifact = serialize(ep, opset_version) + + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + with zipfile.ZipFile(f, "w") as zipf: + # Save every field in the SerializedArtifact to a file. 
+ assert isinstance(artifact.exported_program, bytes) + zipf.writestr("serialized_exported_program.json", artifact.exported_program) + zipf.writestr("serialized_state_dict.pt", artifact.state_dict) + zipf.writestr("serialized_constants.pt", artifact.constants) + zipf.writestr("serialized_example_inputs.pt", artifact.example_inputs) + + zipf.writestr("version", ".".join(map(str, SCHEMA_VERSION))) + + # Add extra files if provided + if extra_files: + for extra_file_name, content in extra_files.items(): + encoded_content = content.encode("utf-8") + zipf.writestr(f"extra_files/{extra_file_name}", encoded_content) + + +def load( + f: Union[str, os.PathLike, io.BytesIO], + *, + extra_files: Optional[Dict[str, Any]] = None, + expected_opset_version: Optional[Dict[str, int]] = None, +) -> ExportedProgram: + """ + + .. warning:: + Under active development, saved files may not be usable in newer versions + of PyTorch. + + Loads an :class:`ExportedProgram` previously saved with + :func:`torch.export.save `. + + Args: + ep (ExportedProgram): The exported program to save. + + f (Union[str, os.PathLike, io.BytesIO): A file-like object (has to + implement write and flush) or a string containing a file name. + + extra_files (Optional[Dict[str, Any]]): The extra filenames given in + this map would be loaded and their content would be stored in the + provided map. + + expected_opset_version (Optional[Dict[str, int]]): A map of opset names + to expected opset versions + + Returns: + An :class:`ExportedProgram` object + + Example:: + + import torch + import io + + # Load ExportedProgram from file + ep = torch.export.load('exported_program.pt2') + + # Load ExportedProgram from io.BytesIO object + with open('exported_program.pt2', 'rb') as f: + buffer = io.BytesIO(f.read()) + buffer.seek(0) + ep = torch.export.load(buffer) + + # Load with extra files. + extra_files = {'foo.txt': ''} # values will be replaced with data + ep = torch.export.load('exported_program.pt2', extra_files=extra_files) + print(extra_files['foo.txt']) + print(ep(torch.randn(5))) + """ + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + extra_files = extra_files or {} + + with zipfile.ZipFile(f, "r") as zipf: + # Check the version + version = zipf.read("version").decode().split(".") + from torch._export.serde.schema import SCHEMA_VERSION + + assert len(version) == len(SCHEMA_VERSION) + if version[0] != str(SCHEMA_VERSION[0]): + raise RuntimeError( + f"Serialized version {version} does not match our current " + f"schema version {SCHEMA_VERSION}." 
+            )
+
+        from torch._export.serde.serialize import deserialize, SerializedArtifact
+
+        # Load serialized_ep and serialized_state_dict from the zip file
+
+        serialized_exported_program: Optional[bytes] = None
+        serialized_state_dict: Optional[bytes] = None
+        serialized_constants: Optional[bytes] = None
+        serialized_example_inputs: Optional[bytes] = None
+
+        for file_info in zipf.infolist():
+            file_content = zipf.read(file_info.filename)
+
+            if file_info.filename == "serialized_exported_program.json":
+                serialized_exported_program = file_content
+            elif file_info.filename == "serialized_state_dict.json":
+                warnings.warn("This version of file is deprecated")
+                serialized_state_dict = file_content
+            elif file_info.filename == "serialized_constants.json":
+                warnings.warn("This version of file is deprecated")
+                serialized_constants = file_content
+            elif file_info.filename == "serialized_state_dict.pt":
+                serialized_state_dict = file_content
+            elif file_info.filename == "serialized_constants.pt":
+                serialized_constants = file_content
+            elif file_info.filename == "serialized_example_inputs.pt":
+                serialized_example_inputs = file_content
+            elif file_info.filename.startswith("extra_files"):
+                filename = file_info.filename.split("/", 1)[1]
+                extra_files[filename] = file_content.decode("utf-8")
+
+        assert serialized_exported_program is not None
+        assert serialized_state_dict is not None
+        assert serialized_constants is not None
+        assert serialized_example_inputs is not None
+        artifact: SerializedArtifact = SerializedArtifact(
+            serialized_exported_program,
+            serialized_state_dict,
+            serialized_constants,
+            serialized_example_inputs,
+        )
+
+        # Deserialize ExportedProgram
+        ep = deserialize(artifact, expected_opset_version)
+
+        return ep
+
+
+def register_dataclass(
+    cls: Type[Any],
+    *,
+    serialized_type_name: Optional[str] = None,
+) -> None:
+    """
+    Registers a dataclass as a valid input/output type for :func:`torch.export.export`.
+
+    Args:
+        cls: the dataclass type to register
+        serialized_type_name: The serialized name for the dataclass. This is
+            required if you want to serialize the pytree TreeSpec containing this
+            dataclass.
+
+    Example::
+
+        @dataclass
+        class InputDataClass:
+            feature: torch.Tensor
+            bias: int
+
+        @dataclass
+        class OutputDataClass:
+            res: torch.Tensor
+
+        torch.export.register_dataclass(InputDataClass)
+        torch.export.register_dataclass(OutputDataClass)
+
+        def fn(o: InputDataClass) -> OutputDataClass:
+            res = o.feature + o.bias
+            return OutputDataClass(res=res)
+
+        ep = torch.export.export(fn, (InputDataClass(torch.ones(2, 2), 1), ))
+        print(ep)
+
+    """
+
+    from torch._export.utils import register_dataclass_as_pytree_node
+
+    return register_dataclass_as_pytree_node(
+        cls, serialized_type_name=serialized_type_name
+    )
diff --git a/lib/python3.10/site-packages/torch/export/_remove_auto_functionalized_pass.py b/lib/python3.10/site-packages/torch/export/_remove_auto_functionalized_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..683e89c3d14910b3f753a9eba7093ba53bc631a4
--- /dev/null
+++ b/lib/python3.10/site-packages/torch/export/_remove_auto_functionalized_pass.py
@@ -0,0 +1,52 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
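Once a dataclass has been registered via ``register_dataclass`` above, it behaves like any other pytree container, which is what makes it a legal export input/output. A small sketch of that effect using ``torch.utils._pytree`` directly (the ``Batch`` class is illustrative, not part of the patch)::

    from dataclasses import dataclass

    import torch
    import torch.utils._pytree as pytree

    @dataclass
    class Batch:
        images: torch.Tensor
        labels: torch.Tensor

    torch.export.register_dataclass(Batch)

    leaves, spec = pytree.tree_flatten(Batch(torch.zeros(2, 3), torch.zeros(2)))
    print(len(leaves))  # 2 tensor leaves; `spec` records how to rebuild the Batch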
+ + +import torch +from torch._higher_order_ops.auto_functionalize import ( + auto_functionalized, + auto_functionalized_v2, +) +from torch._inductor.fx_passes.post_grad import decompose_auto_functionalized +from torch.export import ExportedProgram + + +def remove_self_clone(graph: torch.fx.Graph): + for node in graph.nodes: + if node.target == torch.ops.aten.copy_.default and node.args[0] == node.args[1]: + node.replace_all_uses_with(node.args[0]) + graph.erase_node(node) + + +def unsafe_remove_auto_functionalized_pass( + ep: ExportedProgram, +) -> ExportedProgram: + """ + This pass removes an instances of the higher order op 'auto_functionalized', + and modifies the calling EP inplace to have the original mutator op. + This pass doesn't perform safety checks to make sure that this inplace mutation is safe. + """ + + with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()): + for module in ep.graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in ep.graph.nodes: + if ( + node.op == "call_function" and node.target is auto_functionalized + ) or ( + node.op == "call_function" and node.target is auto_functionalized_v2 + ): + func = node.args[0] + assert isinstance(func, torch._ops.OpOverload) + # re-inplace everything + node.meta["only_clone_these_tensors"] = [] + decompose_auto_functionalized(ep.graph) + remove_self_clone(ep.graph) + ep.graph.eliminate_dead_code() + + return ep diff --git a/lib/python3.10/site-packages/torch/export/_remove_effect_tokens_pass.py b/lib/python3.10/site-packages/torch/export/_remove_effect_tokens_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..84adbf3663781124d3112a575895a0b4c051dc25 --- /dev/null +++ b/lib/python3.10/site-packages/torch/export/_remove_effect_tokens_pass.py @@ -0,0 +1,161 @@ +# mypy: allow-untyped-defs +import operator +from typing import List + +import torch +from torch._higher_order_ops.effects import _get_schema, with_effects + +from .exported_program import ExportedProgram +from .graph_signature import ( + CustomObjArgument, + InputKind, + InputSpec, + OutputKind, + OutputSpec, + TokenArgument, +) + + +def _remove_effect_tokens_from_graph_helper( + ep, num_tokens, input_token_names, output_token_names +): + inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs + + output_node = None + with_effect_nodes: List[torch.fx.Node] = [] + + # Output node need to check its args agianst output_token_names (collected from output_spec) + # Therefore, we only need to find the top-levele output node + output_node = next(reversed(ep.graph_module.graph.find_nodes(op="output"))) + for module in ep.graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + + for node in module.graph.nodes: + if not (node.op == "call_function" and node.target is with_effects): + continue + + with_effect_nodes.append(node) + + # Remove tokens from outputs + assert output_node is not None + output_args = output_node.args[0] + assert len(output_args) >= num_tokens + out_token_nodes = output_args[:num_tokens] + output_node.args = (tuple(output_args[num_tokens:]),) + for out_token in out_token_nodes: + assert out_token.name in output_token_names + out_token.users.clear() + ep.graph.erase_node(out_token) + + # Replace with_effects(token, func, args) with just func(args) + for node in reversed(with_effect_nodes): + func = node.args[1] + assert isinstance(func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)) + + if func == 
torch.ops.higher_order.call_torchbind: + custom_obj_meta = node.args[2].meta["val"] + assert isinstance(custom_obj_meta, CustomObjArgument) + if custom_obj_meta.fake_val: + custom_obj = custom_obj_meta.fake_val + elif node.args[2].name in inputs_to_lifted_custom_objs: + custom_obj = ep.constants[ + inputs_to_lifted_custom_objs[node.args[2].name] + ] + else: + raise RuntimeError(f"Unable to find custom obj for node {node}") + schema = _get_schema(func, (custom_obj,) + node.args[3:]) + else: + schema = _get_schema(func, node.args[2:]) + + with ep.graph.inserting_before(node): + new_node = ep.graph.call_function(func, node.args[2:], node.kwargs) + for k, v in node.meta.items(): + new_node.meta[k] = v + + node.replace_all_uses_with(new_node) + + # Update user getitem nodes + for user in list(new_node.users.keys()): + assert user.target == operator.getitem + # getitem(with_effects, 0) == token + if user.args[1] == 0: + ep.graph.erase_node(user) + + if len(schema.returns) == 1: + # If the function has 1 return then it will just directly return the + # result -- we don't need a getitem. So we can replace all the + # getitem(with_effects, 1) with just the note itself. + for user in list(new_node.users.keys()): + assert user.args[1] == 1 + user.replace_all_uses_with(new_node) + + new_node.meta["val"] = node.meta["val"][1] + elif len(schema.returns) > 1: + # If the function has more than 1 return then since we got rid of + # the 1st return value (the token), we need to bump all the other + # getitem calls by 1 down + for user in list(new_node.users.keys()): + assert user.args[1] >= 1 + user.args = (user.args[0], user.args[1] - 1) + + new_node.meta["val"] = node.meta["val"][1:] + else: + assert len(schema.returns) == 0 + assert len(new_node.users) == 0 + new_node.meta["val"] = None + + ep.graph.erase_node(node) + + # Remove tokens from inputs + placeholders = [node for node in ep.graph.nodes if node.op == "placeholder"] + assert len(placeholders) >= num_tokens + inp_token_nodes = placeholders[:num_tokens] + for inp_token in inp_token_nodes: + assert inp_token.name in input_token_names + ep.graph.erase_node(inp_token) + + ep.graph.eliminate_dead_code() + + +def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram: + """ + Removes the existance of tokens from the exported program, including: + - Removes the input and output tokens + - Replaces with_effects(token, func, args) with just func(args) + + This function does an inplace modification on the given ExportedProgram. 
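The getitem renumbering above is the subtle part of the rewrite: once ``with_effects(token, func, *args)`` is replaced by ``func(*args)``, the token output disappears and every ``getitem(node, i)`` user has to become ``getitem(node, i - 1)``. A self-contained sketch of just that renumbering step on a toy FX graph (``two_outputs`` is a hypothetical stand-in for a multi-output call)::

    import operator

    import torch
    import torch.fx as fx

    @fx.wrap  # keep the call as a single node in the traced graph
    def two_outputs(x):
        return x + 1, x * 2

    class M(torch.nn.Module):
        def forward(self, x):
            out = two_outputs(x)
            return out[1]

    gm = fx.symbolic_trace(M())

    # Pretend output 0 of `two_outputs` was a token that has been dropped:
    # shift the remaining getitem users down by one, as the pass above does.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is two_outputs:
            for user in list(node.users):
                if user.target is operator.getitem and user.args[1] >= 1:
                    user.args = (user.args[0], user.args[1] - 1)
    gm.recompile()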
+ """ + num_tokens: int = 0 + input_token_names: List[str] = [] + new_input_specs: List[InputSpec] = [] + for inp in ep.graph_signature.input_specs: + if inp.kind == InputKind.TOKEN: + num_tokens += 1 + assert isinstance(inp.arg, TokenArgument) + input_token_names.append(inp.arg.name) + else: + new_input_specs.append(inp) + + num_out_tokens: int = 0 + new_output_specs: List[OutputSpec] = [] + output_token_names: List[OutputSpec] = [] + for out in ep.graph_signature.output_specs: + if out.kind == OutputKind.TOKEN: + num_out_tokens += 1 + output_token_names.append(out.arg.name) + else: + new_output_specs.append(out) + + # Update graph signature + ep.graph_signature.input_specs = new_input_specs + ep.graph_signature.output_specs = new_output_specs + + assert num_tokens == num_out_tokens + + with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()): + _remove_effect_tokens_from_graph_helper( + ep, num_tokens, input_token_names, output_token_names + ) + + return ep diff --git a/lib/python3.10/site-packages/torch/export/_safeguard.py b/lib/python3.10/site-packages/torch/export/_safeguard.py new file mode 100644 index 0000000000000000000000000000000000000000..76f22f369c566a97062fc60696ad7972dc2b260c --- /dev/null +++ b/lib/python3.10/site-packages/torch/export/_safeguard.py @@ -0,0 +1,44 @@ +# mypy: allow-untyped-defs +import torch +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode +from torch.overrides import TorchFunctionMode + + +class AutogradStateOpsFailSafeguard(TorchFunctionMode): + """ + Detect grad state ops during exporting the graph and fail the process by + raising an error, to avoid unexpected behavior. Those grad mode ops could be: + `torch.no_grad` + `torch.enable_grad` + `torch.set_grad_enabled` + + Export with predispatch mode is exempted. + """ + + def __torch_function__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + unsupported_grad_mode_ops = [ + torch._C._set_grad_enabled, + ] + # It's only enabled while tracing, by confirming the torch dispatch mode is + # any active PROXY. This is to allow the autograd ops out of tracing. + current_state = torch._C.is_grad_enabled() + if func in unsupported_grad_mode_ops: + assert len(args) == 1 + changed_state = args[0] + mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) + # Intend to check if it's not the pre_dispatch mode. It's allowed to use + # autograd ops in pre_dispatch mode, e.g. `torch.no_grad` + if ( + mode + and isinstance(mode, ProxyTorchDispatchMode) + and not mode.pre_dispatch + and changed_state != current_state + ): + raise RuntimeError( + f"Encountered autograd state manager op {func} trying to change global autograd state " + "while exporting. This is unsafe because we don't capture this op in torch.export " + "today, hence we can't reflect the user intention soundly. You can fix this by " + "adding a torch.no_grad() context around the export call." 
+ ) + return func(*args, **kwargs) diff --git a/lib/python3.10/site-packages/torch/export/_trace.py b/lib/python3.10/site-packages/torch/export/_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..abd2b5405fb937371501a06261c5a2fabef29cfb --- /dev/null +++ b/lib/python3.10/site-packages/torch/export/_trace.py @@ -0,0 +1,1943 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import dataclasses +import functools +import inspect +import logging +import re +import time +import warnings +from contextlib import contextmanager, nullcontext +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +import torch +import torch._dynamo +import torch.fx +import torch.utils._pytree as pytree +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.exc import UserError, UserErrorType +from torch._export.db.logging import ( + exportdb_error_message, + get_class_if_classified_error, +) +from torch._export.non_strict_utils import ( + _fakify_script_objects, + _gather_constant_attrs, + _NonStrictTorchFunctionHandler, + make_constraints, + make_fake_inputs, + produce_guards_and_solve_constraints, +) +from torch._export.passes._node_metadata_hook import ( + _node_metadata_hook, + _set_node_metadata_hook, +) +from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass +from torch._export.passes.lift_constants_pass import ( + ConstantAttrMap, + lift_constants_pass, + rewrite_script_object_meta, +) +from torch._export.utils import ( + _collect_param_buffer_metadata, + _get_shape_env_from_gm, + _populate_param_buffer_metadata_to_new_gm, + placeholder_naming_pass, + placeholder_prefixes, +) +from torch._export.verifier import SpecViolationError +from torch._export.wrappers import _wrap_submodules +from torch._functorch._aot_autograd.input_output_analysis import ( + _graph_input_names, + _graph_output_names, +) +from torch._functorch._aot_autograd.traced_function_transforms import ( + create_functional_call, +) +from torch._functorch._aot_autograd.utils import create_tree_flattened_fn +from torch._functorch.aot_autograd import aot_export_module +from torch._guards import detect_fake_mode +from torch._library.fake_class_registry import FakeScriptObject +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch._utils_internal import log_export_usage +from torch.export.dynamic_shapes import ( + _check_dynamic_shapes, + _combine_args, + _transform_shapes_for_default_dynamic, +) +from torch.export.exported_program import OutputKind +from torch.fx._utils import first_call_function_nn_module_stack +from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.experimental.symbolic_shapes import ( + ConstraintViolationError, + free_unbacked_symbols, + GuardOnDataDependentSymNode, + ShapeEnv, +) +from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo +from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts +from torch.utils._pytree import TreeSpec +from torch.utils._sympy.value_ranges import ValueRangeError + +from ._safeguard import AutogradStateOpsFailSafeguard +from .exported_program import ( + _disable_prexisiting_fake_mode, + ExportedProgram, + InputKind, + ModuleCallEntry, + ModuleCallSignature, +) +from .graph_signature import _convert_to_export_graph_signature, ExportGraphSignature + + +log = logging.getLogger(__name__) + + +@dataclasses.dataclass +class ExportDynamoConfig: + """ + Manage Export-specific configurations of Dynamo. 
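When ``AutogradStateOpsFailSafeguard`` above does fire (for example while lowering to inference IR rather than during pre-dispatch export), its error message points at the standard workaround: keep the global grad state fixed around the export call. A minimal sketch of that from the caller's side::

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            with torch.no_grad():  # toggles global grad state while tracing
                return x @ x

    # Suggested mitigation from the error message above: export under no_grad
    # so the global autograd state cannot change mid-trace.
    with torch.no_grad():
        ep = torch.export.export(M(), (torch.randn(4, 4),))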
+ """ + + allow_rnn: bool = True + reorderable_logging_functions: Set[Callable] = dataclasses.field( + default_factory=set + ) + # Emit runtime asserts after AOTAutograd instead. + # This isn't really necessary, and isn't much more efficient since the runtime asserts pass does CSE, + # but if we want to reason more about what guards/runtime asserts to emit, + # this makes it a bit cleaner to do from the export side. Also no real point in running this twice. + do_not_emit_runtime_asserts = True + + +@dataclasses.dataclass +class ATenExportArtifact: + gm: torch.fx.GraphModule + sig: ExportGraphSignature + constants: Dict[ + str, + Union[ + torch.Tensor, + FakeScriptObject, + torch.ScriptObject, + ], + ] + + +@dataclasses.dataclass(frozen=True) +class ExportArtifact: + aten: ATenExportArtifact + out_spec: TreeSpec + fake_mode: FakeTensorMode + module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] + + +DEFAULT_EXPORT_DYNAMO_CONFIG = ExportDynamoConfig() +DEFAULT_EXPORT_DYNAMO_CONFIG.reorderable_logging_functions = { + logging.critical, + logging.debug, + logging.error, + logging.exception, + logging.info, + logging.log, + logging.warning, + print, + warnings.warn, +} + + +@contextmanager +def _ignore_backend_decomps(): + orig_mkldnn_flag = torch.backends.mkldnn.set_flags(False) + orig_nnpack_flag = torch.backends.nnpack.set_flags(False) + try: + yield + finally: + torch.backends.mkldnn.set_flags(*orig_mkldnn_flag) + torch.backends.nnpack.set_flags(*orig_nnpack_flag) + + +def _fixup_key(x): + return "L__self__" + _strip_root(x) + + +def _strip_root(x): + if isinstance(x, str) and x.startswith("_export_root"): + stripped = x[len("_export_root") :] + return stripped[1:] if stripped.startswith(".") else stripped + return x + + +def _rewrite_tracepoint_node(gm: torch.fx.GraphModule): + """ + In-place modifiy input graph module by replacing the export tracepoint with a new node + that has the same target and args, but with the _export_root stripped from path. + """ + for node in gm.graph.nodes: + if node.target == torch.ops.higher_order._export_tracepoint: + if "path" in node.kwargs: + path = _strip_root(node.kwargs["path"]) + with gm.graph.inserting_before(node): + new_node = gm.graph.create_node( + "call_function", + torch.ops.higher_order._export_tracepoint, + args=node.args, + kwargs={ + "path": path, + "kind": node.kwargs["kind"], + }, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + gm.graph.erase_node(node) + + +def _extract_fake_inputs(gm, args, kwargs): + """ + Given a graph module, extract fakified input tensors from the metadata of + its placeholders, and map them to the structure of given args and kwargs. + Also return the fake mode used to fakify those inputs. 
+ """ + + fake_inps: List[torch.Tensor] = [] + fake_vals: List[torch.Tensor] = [] + for node in gm.graph.nodes: + if node.op == "placeholder" and "val" in node.meta: + fake_val = node.meta["val"] + if fake_val is not None and isinstance(fake_val, torch.Tensor): + fake_inps.append(fake_val) + elif "example_value" in node.meta: + fake_val = node.meta["example_value"] + if fake_val is not None and isinstance(fake_val, torch.Tensor): + fake_vals.append(fake_val) + + if detected_fake_mode := detect_fake_mode(fake_inps + fake_vals): + fake_mode = detected_fake_mode + else: + fake_mode = FakeTensorMode(shape_env=ShapeEnv(), export=True) + + count = 0 + + def lookup_fake(x): + nonlocal count + val = fake_inps[count] + count += 1 + return val + + fake_args = pytree.tree_map_only(torch.Tensor, lookup_fake, args) + fake_kwargs = pytree.tree_map_only(torch.Tensor, lookup_fake, kwargs) + + return fake_args, fake_kwargs, fake_mode + + +def _replace_param_buffer_names(param_buffer_table, sig): + for spec in sig.input_specs: + if spec.kind in ( + InputKind.PARAMETER, + InputKind.BUFFER, + ): + spec.target = param_buffer_table[spec.target] + for spec in sig.output_specs: + if spec.kind in ( + OutputKind.BUFFER_MUTATION, + OutputKind.GRADIENT_TO_PARAMETER, + ): + spec.target = param_buffer_table[spec.target] + + +def _convert_to_positional_args(orig_arg_names, args, kwargs): + assert len(orig_arg_names) == len(args) + len(kwargs), ( + f"Total number of arg names is expected to be {len(orig_arg_names)} " + f"but got {len(args)} positional args, {len(kwargs)} kwargs." + ) + reordered_kwargs = [kwargs[kw_name] for kw_name in orig_arg_names[len(args) :]] + return ( + *args, + *reordered_kwargs, + ) + + +def _normalize_nn_module_stack(gm_torch_level, root_cls): + # Append a root module to every nn_module_stack. + root = "L['self']" + root_key = re.sub(r"[^a-zA-Z0-9]", "_", root) + for gm in gm_torch_level.modules(): + if not isinstance(gm, torch.fx.GraphModule): + continue + for node in gm.graph.nodes: + if node.op in ["placeholder", "output"]: + continue + add_root = True + if nn_module_stack := node.meta.get("nn_module_stack", {}): + path, ty = next(iter(nn_module_stack.values())) + # After deserializing the class `ty` might not exist anymore so + # it could be a string + if inspect.isclass(ty) and issubclass(ty, torch.nn.Module): + # TODO Figure out why sometimes we have root sometimes we don't. + if path == root and ty is root_cls: + add_root = False + else: + assert isinstance(ty, str) + if add_root: + + def normalize_path(path): + try: + parts = [] + + class Path: + def __getattr__(self, name): + parts.append(name) + return self + + def __getitem__(self, idx): + parts.append(str(idx)) + return self + + eval(path, {"L": {"self": Path()}}) + return ".".join(parts) + except Exception: # TODO(zhxchen17) Remove this. + return path + + nn_module_stack = { + root_key: (root, root_cls.__module__ + "." + root_cls.__qualname__), + **nn_module_stack, + } + node.meta["nn_module_stack"] = { + key: (normalize_path(path), ty) + for key, (path, ty) in nn_module_stack.items() + } + + +def _get_param_buffer_mapping( + original_module: torch.nn.Module, + traced_module: torch.nn.Module, +) -> Dict[str, str]: + """ + Returns a mapping of parameter/buffer names from the new module to the + original model. This is to help with restoring the FQN for parameter/buffers + of a traced module to what the original module contains. 
+ """ + + param_lookup: Dict[int, str] = {} + buffer_lookup: Dict[int, str] = {} + for name, param in original_module.named_parameters(remove_duplicate=False): + param_lookup[id(param)] = name + for name, buffer in original_module.named_buffers(remove_duplicate=False): + buffer_lookup[id(buffer)] = name + + param_buffer_table: Dict[str, str] = {} + for dynamo_name, dynamo_param in traced_module.named_parameters( + remove_duplicate=False + ): + assert dynamo_name not in param_buffer_table + if id(dynamo_param) in param_lookup: + param_buffer_table[dynamo_name] = param_lookup[id(dynamo_param)] + + for dynamo_name, dynamo_buffer in traced_module.named_buffers( + remove_duplicate=False + ): + assert dynamo_name not in param_buffer_table + if id(dynamo_buffer) in buffer_lookup: + param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)] + + return param_buffer_table + + +def _preserve_requires_grad_pass( + gm: torch.fx.GraphModule, + sig: ExportGraphSignature, + fake_params_buffers: Dict[str, torch.Tensor], + constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]], + flat_fake_args: List[Any], +): + placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"] + assert len(sig.input_specs) == len(placeholders) + i = 0 + for node, spec in zip(placeholders, sig.input_specs): + if spec.kind in ( + InputKind.PARAMETER, + InputKind.BUFFER, + ): + assert spec.target is not None + node.meta["val"].requires_grad = fake_params_buffers[ + spec.target + ].requires_grad + elif spec.kind == InputKind.USER_INPUT: + fake_arg = flat_fake_args[i] + if isinstance(fake_arg, torch.Tensor): + node.meta["val"].requires_grad = fake_arg.requires_grad + i += 1 + elif spec.kind == InputKind.CONSTANT_TENSOR: + assert spec.target is not None + constant = constants[spec.target] + if isinstance(constant, torch.Tensor): + # If the tensor is not leaf, it should already have a correct requires grad field + if node.meta["val"].is_leaf: + node.meta["val"].requires_grad = constant.requires_grad + else: + assert node.meta["val"].requires_grad == constant.requires_grad + elif spec.kind in (InputKind.CUSTOM_OBJ, InputKind.TOKEN): + continue + else: + raise AssertionError(spec.kind) + + +def _remap_constants( + orig_constant_attrs: ConstantAttrMap, + graph_signature: ExportGraphSignature, + constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]], +) -> None: + """Rewrite the graph signature and constants table to use the FQN from the original module.""" + remap_table: Dict[str, List[str]] = {} + for name, value in constants.items(): + if value in orig_constant_attrs: + remap_table[name] = orig_constant_attrs[value] + + for spec in graph_signature.input_specs: + if spec.kind in ( + InputKind.CONSTANT_TENSOR, + InputKind.CUSTOM_OBJ, + ): + orig_target = spec.target + assert orig_target is not None + targets = remap_table.get(orig_target, [orig_target]) + spec.target = targets[0] + + constant = constants[orig_target] + del constants[orig_target] + for target in targets: + constants[target] = constant + + +def _rename_constants_nodes( + gm: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, +) -> None: + """ + For strict mode, rename constants nodes that were previously annotated as buffers. 
+ """ + # handle name collisions with existing constants + node_names = {node.name for node in gm.graph.nodes} + + def rename_constant(name): + if name in node_names: + n = 1 + while (dup_name := f"{name}_{n}") in node_names: + n += 1 + name = dup_name + node_names.add(name) + return name + + # use input specs to map names from buffers to constants + buffer_prefix = placeholder_prefixes[InputKind.BUFFER] + const_prefix = placeholder_prefixes[InputKind.CONSTANT_TENSOR] + buffer_to_constant = {} + for spec in graph_signature.input_specs: + if spec.kind == InputKind.CONSTANT_TENSOR and not spec.arg.name.startswith( + const_prefix + ): + if spec.arg.name.startswith(buffer_prefix): # map from buffer to constants + c_name = rename_constant( + const_prefix + spec.arg.name[len(buffer_prefix) :] + ) + else: # lifted constant + c_name = rename_constant(const_prefix + spec.arg.name) + buffer_to_constant[spec.arg.name] = c_name + spec.arg.name = c_name + for spec in graph_signature.output_specs: + if spec.arg.name in buffer_to_constant: + spec.arg.name = buffer_to_constant[spec.arg.name] + + # Rename constants nodes for all modules + for mod in gm.modules(): + if not isinstance(mod, torch.fx.GraphModule): + continue + for node in mod.graph.nodes: + if node.name in buffer_to_constant: + node.name = node.target = buffer_to_constant[node.name] + mod.recompile() + + +def _restore_state_dict( + original_module: torch.nn.Module, traced_module: torch.fx.GraphModule +) -> None: + """ + Restores the state dict of the traced module to that of the original module. + """ + param_buffer_table = _get_param_buffer_mapping(original_module, traced_module) + # Since the graph module is flattened (no module heirarchy), we + # need to noramlize the module by replacing "." with "_". If we + # don't, it will try to save the weight to a submodule which no + # longer exists. + for name, fqn in param_buffer_table.items(): + param_buffer_table[name] = fqn.replace(".", "_") + + # Replace state dict attr names with the fqn + for name, fqn in param_buffer_table.items(): + if not hasattr(traced_module, name): + continue + + attr = getattr(traced_module, name) + if isinstance(attr, torch.Tensor) and not isinstance(attr, torch.nn.Parameter): + traced_module.register_buffer(fqn, attr) + else: + setattr(traced_module, fqn, attr) + delattr(traced_module, name) + + # Replace graph getattr nodes with the correct name + for node in traced_module.graph.nodes: + if node.op == "get_attr": + attr_name = node.target + if attr_name in param_buffer_table: + node.target = param_buffer_table[attr_name] + + traced_module.recompile() + + +def _get_module_hierarchy(mod: torch.nn.Module) -> Dict[str, str]: + return { + name: type(m).__name__ for name, m in mod.named_modules(remove_duplicate=False) + } + + +def _make_module_call_graph( + module_hierarchy: Dict[str, str], + in_spec: TreeSpec, + out_spec: TreeSpec, + module_call_signatures: Dict[str, ModuleCallSignature], +) -> List[ModuleCallEntry]: + ret = [ + ModuleCallEntry(fqn=fqn, signature=module_call_signatures.get(fqn)) + for fqn in module_hierarchy + ] + assert ret[0].fqn == "" + ret[0].signature = ModuleCallSignature( + inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec + ) + return ret + + +def _export_to_torch_ir( + f: Callable, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None, + *, + preserve_module_call_signature: Tuple[str, ...] 
= (), + disable_constraint_solver: bool = False, + allow_complex_guards_as_runtime_asserts: bool = False, + restore_fqn: bool = True, + _log_export_usage: bool = True, + same_signature: bool = True, +) -> torch.fx.GraphModule: + """ + Traces either an nn.Module's forward function or just a callable with PyTorch + operations inside and produce a torch.fx.GraphModule in torch IR. + """ + + if _log_export_usage: + log_export_usage(event="export.private_api", flags={"_export_to_torch_ir"}) + + if not isinstance(args, tuple): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}", + ) + + kwargs = kwargs or {} + combined_args = _combine_args(f, args, kwargs) + _check_dynamic_shapes(combined_args, dynamic_shapes) + transformed_dynamic_shapes = _transform_shapes_for_default_dynamic( + combined_args, dynamic_shapes + ) + + with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)): + try: + module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {} + with _wrap_submodules( + f, preserve_module_call_signature, module_call_specs + ), _ignore_backend_decomps(): + gm_torch_level, _ = torch._dynamo.export( + f, + dynamic_shapes=transformed_dynamic_shapes, # type: ignore[arg-type] + tracing_mode="symbolic", + disable_constraint_solver=disable_constraint_solver, + # currently the following 2 flags are tied together for export purposes, + # but untangle for sake of dynamo export api + prefer_deferred_runtime_asserts_over_guards=True, + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, + _log_export_usage=_log_export_usage, + same_signature=same_signature, + )( + *args, + **kwargs, + ) + except (ConstraintViolationError, ValueRangeError) as e: + raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904 + except GuardOnDataDependentSymNode as e: + raise UserError( # noqa: B904 + UserErrorType.ANTI_PATTERN, + f"Consider annotating your code using torch._check*(). {str(e)}", + case_name="constrain_as_size_example", + ) + + gm_torch_level.meta["module_call_specs"] = module_call_specs + + if isinstance(f, torch.nn.Module) and restore_fqn: + _restore_state_dict(f, gm_torch_level) + + return gm_torch_level + + +def _export_to_aten_ir( + mod: torch.nn.Module, + fake_args, + fake_kwargs, + fake_params_buffers, + constant_attrs: ConstantAttrMap, + produce_guards_callback=None, + *, + transform=lambda x: x, # TODO(zhxchen17) Revisit if this is needed later. + pre_dispatch=False, + decomp_table=None, + _check_autograd_state=True, + _is_torch_jit_trace=False, +) -> ATenExportArtifact: + # [NOTE] If the user is exporting under training mode, we want to detect if there is any + # state change in the autograd global state and error. If the user is exporting under inference + # mode, we don't care. At predispatch level, we don't care about the state change. + is_grad_enabled = torch._C.is_grad_enabled() + grad_safe_guard = nullcontext() + # export_to_aten_ir is called when we decompose the ep into inference IR + # In that setting, we actually shouldn't check the state change as at this point, + # because the intention is specalizing to inference. 
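The ``dynamic_shapes`` argument validated by ``_check_dynamic_shapes`` above is the same structure users hand to the public API; for example, marking dimension 0 of an input as symbolic::

    import torch
    from torch.export import Dim

    class M(torch.nn.Module):
        def forward(self, x):
            return x.sum(dim=0)

    batch = Dim("batch")
    # dynamic_shapes mirrors the input structure; dim 0 of `x` becomes symbolic.
    ep = torch.export.export(
        M(), (torch.randn(4, 8),), dynamic_shapes={"x": {0: batch}}
    )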
+ if _check_autograd_state: + if not pre_dispatch and is_grad_enabled: + grad_safe_guard = AutogradStateOpsFailSafeguard() # type: ignore[assignment] + + @contextmanager + def _compiling_state_context(): + old_value = torch.compiler._is_compiling_flag + try: + torch.compiler._is_compiling_flag = True + yield + finally: + torch.compiler._is_compiling_flag = old_value + + # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode, + # otherwise aot_export_module will error out because it sees a mix of fake_modes. + # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about. + with torch.nn.utils.stateless._reparametrize_module( + mod, + fake_params_buffers, + tie_weights=True, + strict=True, + stack_weights=True, + ), grad_safe_guard, _ignore_backend_decomps(), _compiling_state_context(): # type: ignore[attr-defined] + gm, graph_signature = transform(aot_export_module)( + mod, + fake_args, + trace_joint=False, + pre_dispatch=pre_dispatch, + decompositions=decomp_table, + kwargs=fake_kwargs, + ) + + def _maybe_fixup_gm_and_output_node_meta(old_gm, new_gm): + if isinstance(old_gm, torch.fx.GraphModule): + if hasattr(old_gm, "meta"): + new_gm.meta.update(old_gm.meta) + old_output_node = list(old_gm.graph.nodes)[-1] + new_output_node = list(new_gm.graph.nodes)[-1] + assert old_output_node.op == "output" and new_output_node.op == "output" + # make sure we don't override any meta + assert len(new_output_node.meta) == 0 + new_output_node.meta.update(old_output_node.meta) + + # TODO unfortunately preserving graph-level metadata and output node's meta + # is not working well with aot_export. So we manually copy it. + # (The node-level meta is addressed above.) + _maybe_fixup_gm_and_output_node_meta(mod, gm) + + # Run produce guards before we handle runtime asserts. + # This means we run the export solver before the runtime asserts pass. + # Right now this doesn't mean much - the export solver is only there for suggested fixes, + # and we won't even get to constraint solving if that's needed. + # But if in future we want to control what runtime asserts are emitted for export, + # or rely on produce_guards + solver for some simplification on runtime asserts, this probably makes sense. + if produce_guards_callback: + try: + produce_guards_callback(gm) + except (ConstraintViolationError, ValueRangeError) as e: + raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904 + + # Run runtime asserts pass before creating input/output specs, since size-related CSE/DCE might affect output signature. + # Overwrite output specs afterwards. 
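The ``_reparametrize_module`` context above plays a role similar to the public ``torch.func.functional_call``: run a module against an externally supplied (here, fake) set of parameters and buffers instead of its own. A rough sketch of that idea with the public API::

    import torch
    from torch.func import functional_call

    mod = torch.nn.Linear(3, 3)
    params_and_buffers = {
        **dict(mod.named_parameters()),
        **dict(mod.named_buffers()),
    }
    out = functional_call(mod, params_and_buffers, (torch.randn(2, 3),))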
+ flat_fake_args = pytree.tree_leaves((fake_args, fake_kwargs)) + if not torch._dynamo.config.do_not_emit_runtime_asserts: + stack_trace = ( + 'File "torch/fx/passes/runtime_assert.py", line 24, ' + "in insert_deferred_runtime_asserts" + ) + with _set_node_metadata_hook( + gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) + ): + shape_env = _get_shape_env_from_gm(gm) + if shape_env: + insert_deferred_runtime_asserts( + gm, + shape_env, + f"exported program: {first_call_function_nn_module_stack(gm.graph)}", + export=True, + ) + + # update output specs + gm.recompile() + graph_signature.user_outputs = _graph_output_names(gm) + + # NOTE: aot_export adds symint metadata for placeholders with int values; + # since these become specialized, we replace such metadata with the original values + index = 0 + total_non_user_inputs = ( + len(graph_signature.parameters) + + len(graph_signature.buffers) + + len(graph_signature.input_tokens) + ) + for node in gm.graph.nodes: + if node.op == "placeholder": + if index >= total_non_user_inputs: + user_arg = flat_fake_args[index - total_non_user_inputs] + if not isinstance(user_arg, torch.Tensor): + node.meta["val"] = user_arg + index += 1 + + export_graph_signature = _convert_to_export_graph_signature( + graph_signature, gm, _get_non_persistent_buffers(mod) + ) + + constants = rewrite_script_object_meta(gm) + constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs)) + + if pre_dispatch: + from torch._export.passes.replace_autocast_with_hop_pass import ( + replace_autocast_with_hop_pass, + ) + from torch._export.passes.replace_set_grad_with_hop_pass import ( + replace_set_grad_with_hop_pass, + ) + + # Note: replace_set_grad_with_hop_pass need to be after lift_constant_pass because + # a getattr of a constant tensor doesn't have meta["val"] until after lift_constant_pass. + # If replace_set_grad_with_hop_pass is before lift_constant_pass, + # and the constant_tensor is passed as input of the set grad hop, the placeholder's + # meta["val"] will be None and fails our verifier for placeholder. + gm, export_graph_signature = replace_set_grad_with_hop_pass( + gm, export_graph_signature + ) + + gm, export_graph_signature = replace_autocast_with_hop_pass( + gm, export_graph_signature + ) + + # Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes. + for _mod in gm.modules(): + if not isinstance(_mod, torch.fx.GraphModule): + continue + for node in _mod.graph.nodes: + if node.op in ["placeholder", "output"]: + node.meta.pop("nn_module_stack", None) + node.meta.pop("stack_trace", None) + + # Prettify names for placeholder nodes. 
+ placeholder_naming_pass( + gm, + export_graph_signature, + mod, + fake_args, + fake_kwargs, + fake_params_buffers, + constants, + ) + + _preserve_requires_grad_pass( + gm, export_graph_signature, fake_params_buffers, constants, flat_fake_args + ) + + return ATenExportArtifact( + gm, + export_graph_signature, + constants, + ) + + +def _fakify_params_buffers( + fake_mode: FakeTensorMode, + mod: torch.nn.Module, +) -> Dict[str, Union[torch.Tensor, torch.nn.Parameter]]: + params_buffers = { + **dict(mod.named_parameters(remove_duplicate=False)), + **dict(mod.named_buffers(remove_duplicate=False)), + } + + faked_params_buffers = {} + memo: Dict[int, FakeTensor] = {} + for key, value in params_buffers.items(): + if id(value) in memo: + fake_tensor = memo[id(value)] + else: + fake_tensor = fake_mode.from_tensor(value, static_shapes=True) + memo[id(value)] = fake_tensor + faked_params_buffers[key] = fake_tensor + return faked_params_buffers # type: ignore[return-value] + + +def _get_forward_arg_names( + mod: torch.nn.Module, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, +) -> List[str]: + """ + Gets the argument names to forward that are used, for restoring the + original signature when unlifting the exported program module. + - Positional args: retain the original argument names, and enumerate + *args as args_0, args_1, ... + - Keyword args: retain the original kwarg names in the order specified + by the user. This order seems to matter for the current state of + export lifted modules. + """ + sig = inspect.signature(mod.forward) + _args = sig.bind_partial(*args).arguments + + names: List[str] = [] + for name, value in _args.items(): + # handle variable number of positional args + if sig.parameters[name].kind == inspect._ParameterKind.VAR_POSITIONAL: + names.extend([f"{name}_{i}" for i, _ in enumerate(value)]) + else: + names.append(name) + # order of kwargs matters for input spec + if kwargs: + names.extend([kwarg for kwarg, _ in kwargs.items()]) + + return names + + +def _get_non_persistent_buffers(mod: torch.nn.Module) -> Set[str]: + """ + Returns set of non-persistent buffers in a module and its submodules. + """ + result = set() + for name, m in mod.named_modules(): + for b in m._non_persistent_buffers_set: + result.add(f"{name}.{b}" if name else b) + return result + + +def _rewrite_dynamo_tensor_constants( + orig_mod_buffers: Set[torch.Tensor], + traced_mod_buffers: Dict[str, torch.Tensor], + graph_signature: ExportGraphSignature, + constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]], +): + """ + Dynamo erroneously marks tensor attributes on modules as buffers. + Rewrite them to be tensor constants. + """ + for spec in graph_signature.input_specs: + if spec.kind == InputKind.BUFFER: + assert spec.target is not None + value = traced_mod_buffers[spec.target] + if value not in orig_mod_buffers: + # This was a tensor constant erroneously marked as a buffer. + # Convert it into a constant in the graph signature, and add its + # value to the constants table. + spec.kind = InputKind.CONSTANT_TENSOR + constants[spec.target] = value # type: ignore[arg-type] + + +def _move_non_persistent_buffers_to_tensor_constants( + orig_mod: torch.nn.Module, + graph_signature: ExportGraphSignature, + constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]], +): + """ + Moves non-persistent buffers to tensor constants. 
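To make the naming rule in ``_get_forward_arg_names`` above concrete, here is the result it produces for a forward with ``*args`` and a keyword argument (tensor values ``t0``..``t3`` are placeholders)::

    import torch

    class M(torch.nn.Module):
        def forward(self, x, *scale, bias=None):
            ...

    # With args = (t0, t1, t2) and kwargs = {"bias": t3}, the helper yields
    # ["x", "scale_0", "scale_1", "bias"]: *args entries are enumerated and
    # kwarg names keep the caller's ordering.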
+ """ + for spec in graph_signature.input_specs: + if spec.kind == InputKind.BUFFER and not spec.persistent: + assert spec.target is not None + assert spec.target not in constants + constants[spec.target] = orig_mod.get_buffer(spec.target) # type: ignore[arg-type] + + +def _verify_nn_module_stack(graph_module: torch.fx.GraphModule) -> None: + """ + Perform nn_module_stack checks on the graph. + Current constraints: + For the top level graph: + - populated for 'call_function', 'get_attr' + - None for 'placeholder', 'output' + For submodule graphs: + - None for 'placeholder', output' + + TODO(pianpwk): make this a consistent node-level check once nn_module_stack is populated for cond submodules. + """ + # Check top-level graph for all nodes, all graphs for placeholder & output nodes + for i, mod in enumerate([graph_module] + list(graph_module.modules())): + if not isinstance(mod, torch.fx.GraphModule): + continue + for node in mod.graph.nodes: + if node.op in ["call_function", "get_attr"]: + if i == 0: + if ( + nn_module_stack := node.meta.get("nn_module_stack", None) + ) is None: + raise SpecViolationError( + f"Node {node} of type {node.op} is missing nn_module_stack metadata" + ) + if not all( + isinstance(k, str) + and isinstance(v, tuple) + and len(v) == 2 + and all(isinstance(x, str) for x in v) + for k, v in nn_module_stack.items() + ): + raise SpecViolationError( + f"Node {node} of type {node.op} has incorrect nn_module_stack metadata format" + f"expected Dict[str, Tuple[str, str]], but got {nn_module_stack}" + ) + elif node.op in ["placeholder", "output"]: + if node.meta.get("nn_module_stack", None): + raise SpecViolationError( + f"Node {node} of type {node.op} contains nn_module_stack metadata, this should be None" + ) + + +def _verify_stack_trace(graph_module: torch.fx.GraphModule) -> None: + """ + Perform stack trace checks on the graph. + Constraints: + - None or non-empty str for 'call_function', 'get_attr' + - None for 'placeholder', 'output' + """ + for i, mod in enumerate([graph_module] + list(graph_module.modules())): + if not isinstance(mod, torch.fx.GraphModule): + continue + for node in graph_module.graph.nodes: + stack_trace = node.meta.get("stack_trace", None) + if node.op in ["call_function", "get_attr"]: + if not (stack_trace is None or isinstance(stack_trace, str)): + raise SpecViolationError( + f"Node {node} of type {node.op} has invalid stack_trace metadata, " + f"expected a string or None but instead found: {stack_trace}" + ) + elif node.op in ["placeholder", "output"]: + if stack_trace: + raise SpecViolationError( + f"Node {node} of type {node.op} contains stack_trace metadata, " + f"expected None but instead found: {stack_trace}" + ) + + +def _verify_placeholder_names(gm: torch.fx.GraphModule, sig: ExportGraphSignature): + """ + Performs a sanity check on the placeholder node names. 
+ - User input nodes: no restrictions, should match the original forward() signature + - Params/buffers/constants/custom_obj/token nodes: should start with prefixes defined in + """ + name_to_kind = {spec.arg.name: spec.kind for spec in sig.input_specs} + for mod in gm.modules(): + if not isinstance(mod, torch.fx.GraphModule): + continue + for node in mod.graph.nodes: + if node.op == "placeholder": + if node.name not in name_to_kind: + continue + node_kind = name_to_kind[node.name] + prefix = placeholder_prefixes[node_kind] + if not node.name.startswith(prefix): + raise SpecViolationError( + f"Placeholder node name {node.name} does not follow spec for {node_kind}, name should have prefix: {prefix}" + ) + + +def get_ep_stats(ep: ExportedProgram) -> Dict[str, Any]: + op_count = 0 + op_set = set() + for m in ep.graph_module.modules(): + if not isinstance(m, torch.fx.GraphModule): + continue + for node in m.graph.nodes: + if node.op != "call_function": + continue + op_count += 1 + assert hasattr(node.target, "__module__") + assert hasattr(node.target, "__name__") + op_set.add(f"{node.target.__module__}.{node.target.__name__}") + return {"op_count": op_count, "op_set": op_set} + + +_EXPORT_FLAGS: Optional[Set[str]] = None +_EXPORT_MODULE_HIERARCHY: Optional[Dict[str, str]] = None + + +def _log_export_wrapper(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + global _EXPORT_FLAGS, _EXPORT_MODULE_HIERARCHY + try: + start = time.time() + ep = fn(*args, **kwargs) + end = time.time() + log_export_usage( + event="export.time", + metrics=end - start, + flags=_EXPORT_FLAGS, + **get_ep_stats(ep), + ) + except Exception as e: + t = type(e) + error_type = t.__module__ + "." + t.__qualname__ + case_name = get_class_if_classified_error(e) + if case_name is not None: + log.error(exportdb_error_message(case_name)) + log_export_usage( + event="export.error.classified", + type=error_type, + message=str(e), + flags=_EXPORT_FLAGS, + ) + else: + log_export_usage( + event="export.error.unclassified", + type=error_type, + message=str(e), + flags=_EXPORT_FLAGS, + ) + raise e + finally: + _EXPORT_FLAGS = None + _EXPORT_MODULE_HIERARCHY = None + + return ep + + return wrapper + + +def _process_jit_trace_inputs_for_export(example_inputs, example_kwarg_inputs): + if not isinstance(example_inputs, (tuple, list, dict)): + example_inputs = (example_inputs,) + + elif isinstance(example_inputs, list): + example_inputs = tuple(example_inputs) + + elif ( + isinstance(example_inputs, (torch.Tensor, dict)) + and example_kwarg_inputs is None + ): + example_inputs = (example_inputs,) + + if example_kwarg_inputs is None: + example_kwarg_inputs = {} + return example_inputs, example_kwarg_inputs + + +def _process_export_inputs(mod, args, kwargs, dynamic_shapes): + original_state_dict = mod.state_dict(keep_vars=True) + + if not isinstance(args, tuple): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}", + ) + kwargs = kwargs if kwargs is not None else {} + _, original_in_spec = pytree.tree_flatten((args, kwargs)) + + if isinstance(dynamic_shapes, torch.export.ShapesCollection): + dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs) + + return args, kwargs, original_in_spec, original_state_dict, dynamic_shapes + + +def _get_module_call_graph( + export_artifact: ExportArtifact, + original_in_spec: TreeSpec, + preserve_module_call_signature: Tuple[str, ...], + strict_mode_export: bool, +): + """ + In-place modify the graph module in 
export_artifact, remove _export_tracepoint nodes and + return module_call_graph. + """ + gm: torch.fx.GraphModule = export_artifact.aten.gm + export_graph_signature: ExportGraphSignature = export_artifact.aten.sig + module_call_specs: Dict[ + str, Dict[str, TreeSpec] + ] = export_artifact.module_call_specs + out_spec: TreeSpec = export_artifact.out_spec + + # Make module signatures. + module_call_signatures = {} + for fqn, specs in module_call_specs.items(): + mod_fqn = _strip_root(fqn) if not strict_mode_export else fqn + module_call_signatures[mod_fqn] = ModuleCallSignature( + inputs=[], outputs=[], **specs + ) + + if len(preserve_module_call_signature) > 0: + if not strict_mode_export: + _rewrite_tracepoint_node(gm) + res = CollectTracepointsPass(module_call_signatures, export_graph_signature)(gm) + assert res is not None + gm = res.graph_module + + assert _EXPORT_MODULE_HIERARCHY is not None + module_call_graph = _make_module_call_graph( + _EXPORT_MODULE_HIERARCHY, + original_in_spec, + out_spec, + module_call_signatures, + ) + return gm, module_call_graph + + +def _get_range_constraints( + export_artifact: ExportArtifact, combined_args: Dict[str, Any], dynamic_shapes +): + gm: torch.fx.GraphModule = export_artifact.aten.gm + export_graph_signature: ExportGraphSignature = export_artifact.aten.sig + fake_mode: FakeTensorMode = export_artifact.fake_mode + num_lifted = next( + ( + i + for i, s in enumerate(export_graph_signature.input_specs) + if s.kind == InputKind.USER_INPUT + ), + len(export_graph_signature.input_specs), + ) + range_constraints = make_constraints( + fake_mode, + gm, + combined_args, + dynamic_shapes, + num_lifted, + ) + return range_constraints + + +def _get_inline_constraints(fake_mode: FakeTensorMode): + assert fake_mode.shape_env is not None + return { + k: v + for k, v in fake_mode.shape_env.var_to_range.items() + if free_unbacked_symbols(k) + } + + +@contextmanager +def patch_forward(obj: torch.nn.Module, new_method): + """Helper method to make it easier to cleanly torch.export() a method on a + module that is not `forward`. 
+ """ + # Save the original method + original_method = obj.forward + + # Patch the method + obj.forward = new_method.__get__(obj, obj.__class__) + + try: + yield + finally: + # Restore the original method + obj.forward = original_method + + +@contextmanager +def _temp_disable_texpr_fuser(): + original_state = torch._C._jit_texpr_fuser_enabled() + torch._C._jit_set_texpr_fuser_enabled(False) + try: + yield + finally: + torch._C._jit_set_texpr_fuser_enabled(original_state) + + +class _WrapperModule(torch.nn.Module): + def __init__(self, f): + super().__init__() + self.f = f + + def forward(self, *args, **kwargs): + return self.f(*args, **kwargs) + + +def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None): + with _temp_disable_texpr_fuser(): + from torch.jit._trace import TopLevelTracedModule + + export_args, export_kwargs = _process_jit_trace_inputs_for_export(args, kwargs) + + if isinstance(traced_callable, (TopLevelTracedModule, torch._C.ScriptModule)): # type: ignore[operator] + return _export( + traced_callable, + export_args, + export_kwargs, + strict=False, + _is_torch_jit_trace=True, + ).module() + + elif isinstance(traced_callable, torch.ScriptMethod) and isinstance( + traced_callable.owner(), (torch._C.ScriptModule, torch.nn.Module) # type: ignore[operator] + ): + with patch_forward(traced_callable.owner(), traced_callable): # type: ignore[operator] + return _export( + traced_callable.owner(), # type: ignore[operator] + export_args, + export_kwargs, + strict=False, + _is_torch_jit_trace=True, + ).module() + + else: + return _export( + _WrapperModule(traced_callable), + export_args, + export_kwargs, + strict=False, + _is_torch_jit_trace=True, + ).module() + + +def _strict_export( + mod: torch.nn.Module, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]], + preserve_module_call_signature: Tuple[str, ...], + pre_dispatch: bool, + original_state_dict: Dict[str, Any], + orig_in_spec: TreeSpec, + allow_complex_guards_as_runtime_asserts: bool, + _is_torch_jit_trace: bool, +) -> ExportArtifact: + lower_to_aten = functools.partial(_export_to_aten_ir, pre_dispatch=pre_dispatch) + return _strict_export_lower_to_aten_ir( + mod=mod, + args=args, + kwargs=kwargs, + dynamic_shapes=dynamic_shapes, + preserve_module_call_signature=preserve_module_call_signature, + pre_dispatch=pre_dispatch, + original_state_dict=original_state_dict, + orig_in_spec=orig_in_spec, + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, + _is_torch_jit_trace=_is_torch_jit_trace, + lower_to_aten_callback=lower_to_aten, + ) + + +def _strict_export_lower_to_aten_ir( + mod: torch.nn.Module, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]], + preserve_module_call_signature: Tuple[str, ...], + pre_dispatch: bool, + original_state_dict: Dict[str, Any], + orig_in_spec: TreeSpec, + allow_complex_guards_as_runtime_asserts: bool, + _is_torch_jit_trace: bool, + lower_to_aten_callback: Callable, +) -> ExportArtifact: + gm_torch_level = _export_to_torch_ir( + mod, + args, + kwargs, + dynamic_shapes, + preserve_module_call_signature=preserve_module_call_signature, + restore_fqn=False, # don't need to restore because we will do it later + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, + _log_export_usage=False, + ) + + # We detect the fake_mode by looking at gm_torch_level's placeholders, this is the fake_mode 
created in dynamo. + ( + fake_args, + fake_kwargs, + dynamo_fake_mode, + ) = _extract_fake_inputs(gm_torch_level, args, kwargs) + + fake_params_buffers = _fakify_params_buffers(dynamo_fake_mode, gm_torch_level) + + # First, we want to pass through the graph to try populating + # val field for getattr if there is anything missing. + # This can happen when quantization adds extra params and forgets + # to update "val" + for node in gm_torch_level.graph.nodes: + if node.op == "get_attr" and "val" not in node.meta: + attr = getattr(gm_torch_level, node.target) + # Checks if it is not a HigherOrderOp branch or a module + if not isinstance(attr, torch.nn.Module): + assert ( + dynamo_fake_mode is not None + ), "Cannot find dynamo_fake_mode. This could be due to the exported graph module have no placeholders." + node.meta["val"] = dynamo_fake_mode.from_tensor( + attr, static_shapes=True + ) + + # Fix the graph output signature to be tuple if scalar + out_spec = orig_out_spec = gm_torch_level._out_spec + + # Used to get rid of lint type error. + assert out_spec is not None + assert orig_out_spec is not None + + # aot_export expect the return type to always be a tuple. + if out_spec.type not in (list, tuple): + out_spec = pytree.TreeSpec(tuple, None, [out_spec]) + + orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args # type: ignore[attr-defined] + + gm_torch_level.graph._codegen = _PyTreeCodeGen( + _PyTreeInfo( + orig_arg_names, + gm_torch_level._in_spec, + out_spec, + ) + ) + gm_torch_level.recompile() + + _normalize_nn_module_stack(gm_torch_level, type(mod)) + + params_buffers_to_node_meta = _collect_param_buffer_metadata(gm_torch_level) + + # When aot_export lifts the params, we lose metadata (e.g. source_fn_stack, stack_trace) + # from the param nodes as they are treated as fresh inputs + # Therefore, we manually extract them before calling into aot_export + # params_buffers_to_node_meta = _collect_param_buffer_metadata(gm_torch_level) + + constant_attrs = _gather_constant_attrs(mod) + param_buffer_table: Dict[str, str] = _get_param_buffer_mapping(mod, gm_torch_level) + + # Dynamo does not track which buffers were registered as non-persistent. This info + # is available in the original module, so we transfer it to the traced module. Also, + # since we didn't restore original param/buffer names yet, we must use traced names. + non_persistent_buffers = _get_non_persistent_buffers(mod) + reverse_name_lookup = {orig: traced for traced, orig in param_buffer_table.items()} + gm_torch_level._non_persistent_buffers_set = { + reverse_name_lookup[name] + for name in non_persistent_buffers + if name in reverse_name_lookup + } + with dynamo_fake_mode: + aten_export_artifact = lower_to_aten_callback( + gm_torch_level, + # NOTE: graph module expects only positional args + _convert_to_positional_args(orig_arg_names, fake_args, fake_kwargs), + {}, + fake_params_buffers, + constant_attrs, + ) + + # Decompose for readability. + gm = aten_export_artifact.gm + export_graph_signature = aten_export_artifact.sig + constants = aten_export_artifact.constants + + _populate_param_buffer_metadata_to_new_gm( + params_buffers_to_node_meta, gm, export_graph_signature + ) + + # Do some cleanups on the graph module to restore the state dict to the + # expected form. Each of these steps should probably get fixed upstream. + # 1. Remove tensor constants that were added as buffers. 
+ _rewrite_dynamo_tensor_constants( + orig_mod_buffers=set(mod.buffers()), + traced_mod_buffers=dict(gm_torch_level.named_buffers()), + graph_signature=export_graph_signature, + constants=constants, + ) + # 2. Restore FQN of param/buffers + _replace_param_buffer_names(param_buffer_table, export_graph_signature) + + # 3. Move non-persistent buffers to tensor constants + _move_non_persistent_buffers_to_tensor_constants( + mod, export_graph_signature, constants + ) + + # 4. Rewrite constants to have the same FQN as the original module. + _remap_constants(constant_attrs, export_graph_signature, constants) + + # 5. Rename constants nodes in graph module from buffers to constants + _rename_constants_nodes(gm, export_graph_signature) + + return ExportArtifact( + aten=aten_export_artifact, + out_spec=orig_out_spec, + fake_mode=dynamo_fake_mode, + module_call_specs=gm_torch_level.meta["module_call_specs"], + ) + + +def _export_to_aten_ir_make_fx( + mod: torch.nn.Module, + fake_args, + fake_kwargs, + fake_params_buffers, + constant_attrs: ConstantAttrMap, + produce_guards_callback=None, + transform=lambda x: x, +) -> ATenExportArtifact: + @contextmanager + def _compiling_state_context(): + old_value = torch.compiler._is_compiling_flag + try: + torch.compiler._is_compiling_flag = True + yield + finally: + torch.compiler._is_compiling_flag = old_value + + def _make_fx_helper(mod, args, kwargs, **flags): + from torch._functorch._aot_autograd.schemas import GraphSignature + + kwargs = kwargs or {} + + named_parameters = dict(mod.named_parameters(remove_duplicate=False)) + named_buffers = dict(mod.named_buffers(remove_duplicate=False)) + + params_and_buffers = {**named_parameters, **named_buffers} + params_and_buffers_flat, params_spec = pytree.tree_flatten(params_and_buffers) + params_and_buffers_flat = tuple(params_and_buffers_flat) + + param_len = len(named_parameters) + buffer_len = len(named_buffers) + params_len = len(params_and_buffers) + + functional_call = create_functional_call( + mod, params_spec, params_len, store_orig_mod=True + ) + + params_buffers_args: List[Any] = [] + params_buffers_args.extend(params_and_buffers_flat) + params_buffers_args.extend(args) + + flat_fn, out_spec = create_tree_flattened_fn( + functional_call, params_buffers_args, kwargs + ) + flat_args, in_spec = pytree.tree_flatten((params_buffers_args, kwargs)) + + @functools.wraps(flat_fn) + def wrapped_fn(*args): + return tuple(flat_fn(*args)) + + with enable_python_dispatcher(): + gm = make_fx( + wrapped_fn, + record_module_stack=True, + pre_dispatch=True, + )(*flat_args) + gm.graph.eliminate_dead_code() + + # create graph signature + input_names = _graph_input_names(gm) + output_names = _graph_output_names(gm) + sig = GraphSignature( + parameters=list(named_parameters), + buffers=list(named_buffers), + user_inputs=input_names[params_len:], + user_outputs=output_names, + inputs_to_parameters=dict(zip(input_names[0:param_len], named_parameters)), + inputs_to_buffers=dict( + zip(input_names[param_len : param_len + buffer_len], named_buffers) + ), + buffers_to_mutate={}, + user_inputs_to_mutate={}, + in_spec=in_spec, + out_spec=out_spec, # type: ignore[arg-type] + backward_signature=None, + input_tokens=[], + output_tokens=[], + ) + return gm, sig + + # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode, + # otherwise aot_export_module will error out because it sees a mix of fake_modes. 
+ # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about. + with torch.nn.utils.stateless._reparametrize_module( + mod, + fake_params_buffers, + tie_weights=True, + strict=True, + stack_weights=True, + ), _ignore_backend_decomps(), _compiling_state_context(): # type: ignore[attr-defined] + param_len = len(dict(mod.named_parameters(remove_duplicate=False))) + buffer_len = len(dict(mod.named_buffers(remove_duplicate=False))) + params_len = param_len + buffer_len + + gm, graph_signature = transform(_make_fx_helper)( + mod, + fake_args, + trace_joint=False, + kwargs=fake_kwargs, + ) + + if isinstance(mod, torch.fx.GraphModule) and hasattr(mod, "meta"): + gm.meta.update(mod.meta) + + flat_args = pytree.tree_leaves((fake_args, fake_kwargs)) + index = 0 + for node in gm.graph.nodes: + if node.op == "placeholder": + if index >= params_len: + user_arg = flat_args[index - params_len] + if not isinstance(user_arg, torch.Tensor): + node.meta["val"] = user_arg + index += 1 + + export_graph_signature = _convert_to_export_graph_signature( + graph_signature, gm, _get_non_persistent_buffers(mod) + ) + + # See comment in _export_to_aten_ir() + if produce_guards_callback: + try: + produce_guards_callback(gm) + except (ConstraintViolationError, ValueRangeError) as e: + raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904 + + fake_mode = detect_fake_mode(flat_args) + + if not torch._dynamo.config.do_not_emit_runtime_asserts: + stack_trace = ( + 'File "torch/fx/passes/runtime_assert.py", line 24, ' + "in insert_deferred_runtime_asserts" + ) + with _set_node_metadata_hook( + gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) + ): + insert_deferred_runtime_asserts( + gm, + fake_mode.shape_env, + f"exported program: {first_call_function_nn_module_stack(gm.graph)}", + export=True, + ) + + # Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes. + for _mod in gm.modules(): + if not isinstance(_mod, torch.fx.GraphModule): + continue + for node in _mod.graph.nodes: + if node.op in ["placeholder", "output"]: + node.meta.pop("nn_module_stack", None) + node.meta.pop("stack_trace", None) + + constants = rewrite_script_object_meta(gm) + constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs)) + + _preserve_requires_grad_pass( + gm, export_graph_signature, fake_params_buffers, constants, flat_args + ) + + # Prettify names for placeholder nodes. + placeholder_naming_pass( + gm, + export_graph_signature, + mod, + fake_args, + fake_kwargs, + fake_params_buffers, + constants, + ) + + return ATenExportArtifact( + gm, + export_graph_signature, + constants, + ) + + +def _non_strict_export( + mod: torch.nn.Module, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]], + preserve_module_call_signature: Tuple[str, ...], + pre_dispatch: bool, + original_state_dict: Dict[str, Any], + orig_in_spec: TreeSpec, + allow_complex_guards_as_runtime_asserts: bool, + _is_torch_jit_trace: bool, + dispatch_tracing_mode: str = "aot_export", +) -> ExportArtifact: + """ + ``dispatch_tracing_mode`` can be either "make_fx” or “aot_export”, corresponding to + _export_to_aten_ir_make_fx and _export_to_aten_ir, respectively. 
+ """ + assert dispatch_tracing_mode in ["make_fx", "aot_export"] + out_spec: Optional[TreeSpec] = None + + module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {} + + def _tuplify_outputs(aot_export): + def _aot_export_non_strict(mod, args, kwargs=None, **flags): + kwargs = kwargs or {} + + class Wrapper(torch.nn.Module): + def __init__(self, mod): + super().__init__() + self._export_root = mod + + def forward(self, *args, **kwargs): + nonlocal out_spec + if isinstance(self._export_root, torch.fx.GraphModule): + with torch.fx.traceback.preserve_node_meta(): + tree_out = torch.fx.Interpreter(self._export_root).run( + *args, **kwargs + ) + else: + tree_out = self._export_root(*args, **kwargs) + flat_outs, out_spec = pytree.tree_flatten(tree_out) + return tuple(flat_outs) + + wrapped_mod = Wrapper(mod) + # Patch export_root to the signatures so that wrapper module correctly populates the + # in/out spec + new_preserved_call_signatures = [ + "_export_root." + i for i in preserve_module_call_signature + ] + with _wrap_submodules( + wrapped_mod, new_preserved_call_signatures, module_call_specs + ): + gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags) + log.debug("Exported program from AOTAutograd:\n%s", gm) + + sig.parameters = pytree.tree_map(_strip_root, sig.parameters) + sig.buffers = pytree.tree_map(_strip_root, sig.buffers) + sig.inputs_to_buffers = pytree.tree_map(_strip_root, sig.inputs_to_buffers) + sig.inputs_to_parameters = pytree.tree_map( + _strip_root, sig.inputs_to_parameters + ) + sig.buffers_to_mutate = pytree.tree_map(_strip_root, sig.buffers_to_mutate) + + for node in gm.graph.nodes: + if "nn_module_stack" in node.meta: + nn_module_stack = node.meta["nn_module_stack"] + node.meta["nn_module_stack"] = { + _fixup_key(key): val + for key, val in pytree.tree_map( + _strip_root, nn_module_stack + ).items() + } + + return gm, sig + + return _aot_export_non_strict + + ( + fake_mode, + fake_args, + fake_kwargs, + equalities_inputs, + original_signature, + transformed_dynamic_shapes, + ) = make_fake_inputs( + mod, + args, + kwargs, + dynamic_shapes, + _is_torch_jit_trace=_is_torch_jit_trace, + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, # for shape env initialization + ) + + fake_params_buffers = _fakify_params_buffers(fake_mode, mod) + + def _produce_guards_callback(gm): + return produce_guards_and_solve_constraints( + fake_mode=fake_mode, + gm=gm, + dynamic_shapes=transformed_dynamic_shapes, + equalities_inputs=equalities_inputs, + original_signature=original_signature, + _is_torch_jit_trace=_is_torch_jit_trace, + ) + + with fake_mode, _NonStrictTorchFunctionHandler(), torch._dynamo.config.patch( + assume_static_by_default=False + ): + with _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) as ( + patched_mod, + new_fake_args, + new_fake_kwargs, + new_fake_constant_attrs, + map_fake_to_real, + ): + _to_aten_func = ( + _export_to_aten_ir_make_fx + if dispatch_tracing_mode == "make_fx" + else functools.partial( + _export_to_aten_ir, + pre_dispatch=pre_dispatch, + _is_torch_jit_trace=_is_torch_jit_trace, + ) + ) + aten_export_artifact = _to_aten_func( # type: ignore[operator] + patched_mod, + new_fake_args, + new_fake_kwargs, + fake_params_buffers, + new_fake_constant_attrs, + produce_guards_callback=_produce_guards_callback, + transform=_tuplify_outputs, + ) + # aten_export_artifact.constants contains only fake script objects, we need to map them back + aten_export_artifact.constants = { + fqn: map_fake_to_real[obj] if 
isinstance(obj, FakeScriptObject) else obj + for fqn, obj in aten_export_artifact.constants.items() + } + + _move_non_persistent_buffers_to_tensor_constants( + mod, aten_export_artifact.sig, aten_export_artifact.constants + ) + + assert out_spec is not None + + return ExportArtifact( + aten=aten_export_artifact, + out_spec=out_spec, + fake_mode=fake_mode, + module_call_specs=module_call_specs, + ) + + +# TODO (tmanlaibaatar) We need to preserve aten.to here somehow +@_log_export_wrapper +@_disable_prexisiting_fake_mode +def _export_for_training( + mod: torch.nn.Module, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None, + *, + strict: bool = True, + preserve_module_call_signature: Tuple[str, ...] = (), +) -> ExportedProgram: + global _EXPORT_MODULE_HIERARCHY + _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod) + + ( + args, + kwargs, + orig_in_spec, + original_state_dict, + dynamic_shapes, + ) = _process_export_inputs(mod, args, kwargs, dynamic_shapes) + + export_func = ( + functools.partial( + _strict_export_lower_to_aten_ir, + lower_to_aten_callback=_export_to_aten_ir_make_fx, + ) + if strict + else functools.partial( + _non_strict_export, + dispatch_tracing_mode="make_fx", + ) + ) + export_artifact = export_func( # type: ignore[operator] + mod=mod, + args=args, + kwargs=kwargs, + dynamic_shapes=dynamic_shapes, + preserve_module_call_signature=preserve_module_call_signature, + pre_dispatch=False, + original_state_dict=original_state_dict, + orig_in_spec=orig_in_spec, + allow_complex_guards_as_runtime_asserts=False, + _is_torch_jit_trace=False, + ) + + export_graph_signature = export_artifact.aten.sig + + forward_arg_names = _get_forward_arg_names(mod, args, kwargs) + inline_constraints = _get_inline_constraints(export_artifact.fake_mode) + # The unbacked symint symbols are updated in aot_export + # so we serialize them here instead of inside dynamo. + # Note: _get_range_constraints depends on "inline_constraints" to be set. + export_artifact.aten.gm.meta["inline_constraints"] = inline_constraints + range_constraints = _get_range_constraints( + export_artifact, + _combine_args(mod, args, kwargs, _is_torch_jit_trace=False), + dynamic_shapes, + ) + # The returned the gm is in-place modified + gm, module_call_graph = _get_module_call_graph( + export_artifact, orig_in_spec, preserve_module_call_signature, strict + ) + + # Add forward args metadata. + gm.meta["forward_arg_names"] = forward_arg_names + + _verify_nn_module_stack(gm) + _verify_stack_trace(gm) + _verify_placeholder_names(gm, export_graph_signature) + + from torch._export.verifier import TrainingIRVerifier + + exported_program = ExportedProgram( + root=gm, + graph=gm.graph, + graph_signature=export_graph_signature, + state_dict=original_state_dict, + range_constraints=range_constraints, + module_call_graph=module_call_graph, + example_inputs=(args, kwargs), + constants=export_artifact.aten.constants, + verifiers=[TrainingIRVerifier], + ) + + return exported_program + + +@_log_export_wrapper +@_disable_prexisiting_fake_mode +def _export( + mod: torch.nn.Module, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None, + *, + strict: bool = True, + preserve_module_call_signature: Tuple[str, ...] 
= (), + pre_dispatch: bool = False, + allow_complex_guards_as_runtime_asserts: bool = False, + _is_torch_jit_trace: bool = False, +) -> ExportedProgram: + """ + Traces either an nn.Module's forward function or just a callable with PyTorch + operations inside and produce a ExportedProgram. + + Args: + f: the `nn.Module` to trace. + + args: example positional inputs. + + kwargs: optional example keyword inputs. + + dynamic_shapes: + An optional argument where the type should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. + + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + preserve_module_call_signature: A list of submodule paths for which the original + calling conventions are preserved as metadata. + + allow_complex_guards_as_runtime_asserts: + With the current dynamic shapes language for dims and derived dims, we can run into constraints + that are not expressible with the language. For example, flattening a matrix and adding to a vector, + both fully dynamic (i.e. x.reshape([-1]) + y) emits a guard s0 * s1 = s2, which is not expressible. + By default, we either raise a constraint violation error or specialize to static values. + If this flag is set to True, we avoid erroring out and instead allow complex constraints to exist as runtime + assertions in the graph. The sympy interpreter (torch/utils/_sympy/interp.py) will produce the math ops + required to compute and assert the value of the guard (e.g. sym_size_int, eq, _assert_scalar). + Additionally, if TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1 is specified, we will allow complex constraints + while not emitting runtime asserts, returning a cleaner graph with lesser guarantees around dynamic shapes. + + Returns: + An ExportedProgram containing the traced method. + """ + + global _EXPORT_FLAGS, _EXPORT_MODULE_HIERARCHY + _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod) + + flags = set() + flags.add("strict" if strict else "non_strict") + flags.add("pre_dispatch" if pre_dispatch else "aot_dispatch") + _EXPORT_FLAGS = flags + + log_export_usage(event="export.enter", flags=_EXPORT_FLAGS) + + ( + args, + kwargs, + original_in_spec, + original_state_dict, + dynamic_shapes, + ) = _process_export_inputs(mod, args, kwargs, dynamic_shapes) + + # Call the appropriate export function based on the strictness of tracing. 
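+ # Illustrative sketch only (not part of this function): a typical user-level
+ # call that eventually reaches this point might look like
+ #     ep = torch.export.export(MyModule(), (torch.randn(2, 3),), strict=True)
+ # where `MyModule` is a placeholder for a user's nn.Module; strict=True routes
+ # through _strict_export and strict=False through _non_strict_export below.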
+ export_func = _strict_export if strict else _non_strict_export + + export_artifact = export_func( # type: ignore[operator] + mod, + args, + kwargs, + dynamic_shapes, + preserve_module_call_signature, + pre_dispatch, + original_state_dict, + original_in_spec, + allow_complex_guards_as_runtime_asserts, + _is_torch_jit_trace, + ) + export_graph_signature: ExportGraphSignature = export_artifact.aten.sig + + forward_arg_names = ( + _get_forward_arg_names(mod, args, kwargs) if not _is_torch_jit_trace else None + ) + inline_constraints = _get_inline_constraints(export_artifact.fake_mode) + # The unbacked symint symbols are updated in aot_export + # so we serialize them here instead of inside dynamo. + # Note: this step must be before _get_range_constraints. + export_artifact.aten.gm.meta["inline_constraints"] = inline_constraints + range_constraints = _get_range_constraints( + export_artifact, + _combine_args(mod, args, kwargs, _is_torch_jit_trace=_is_torch_jit_trace), + dynamic_shapes, + ) + gm, module_call_graph = _get_module_call_graph( + export_artifact, original_in_spec, preserve_module_call_signature, strict + ) + + # Add forward args metadata. + gm.meta["forward_arg_names"] = forward_arg_names + + _verify_nn_module_stack(gm) + _verify_stack_trace(gm) + if not _is_torch_jit_trace: + _verify_placeholder_names(gm, export_graph_signature) + + # Remove Proxy because they cannot be deepcopied or pickled. + torch._export.utils.remove_proxy_from_state_dict(original_state_dict, in_place=True) + + from torch._export.verifier import Verifier + + if ( + isinstance(mod, torch.fx.GraphModule) + and hasattr(mod, "meta") + and "custom" in mod.meta + ): + gm.meta.update({"custom": mod.meta["custom"]}) + + exported_program = ExportedProgram( + root=gm, + graph=gm.graph, + graph_signature=export_graph_signature, + state_dict=original_state_dict, + range_constraints=range_constraints, + module_call_graph=module_call_graph, + example_inputs=(args, kwargs), + constants=export_artifact.aten.constants, + verifiers=[Verifier], + ) + + return exported_program diff --git a/lib/python3.10/site-packages/torch/export/_tree_utils.py b/lib/python3.10/site-packages/torch/export/_tree_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a1615ebd5f586ce5216c65d40162a74ffb7bc5d1 --- /dev/null +++ b/lib/python3.10/site-packages/torch/export/_tree_utils.py @@ -0,0 +1,64 @@ +from typing import Any, Callable, Dict, Optional + +from torch.utils._pytree import Context, TreeSpec + + +def reorder_kwargs(user_kwargs: Dict[str, Any], spec: TreeSpec) -> Dict[str, Any]: + """Reorder user-provided kwargs to match the order in `spec`. `spec` is + expected to be the in_spec of an exported program, i.e. the spec that + results from flattening `(args, kwargs)`. + + We need this to provide consistent input ordering, such so that users can + pass in foo(a=a, b=b) OR foo(b=b, a=a) and receive the same result. 
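+
+ Illustrative sketch, assuming `spec` was produced by flattening
+ `((x,), {"a": a, "b": b})`, so its recorded kwarg order is ["a", "b"]:
+
+     reorder_kwargs({"b": 2, "a": 1}, spec)  # -> {"a": 1, "b": 2}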
+ """ + # Make sure that the spec is actually shaped like (args, kwargs) + assert spec.type is tuple + assert spec.num_children == 2 + kwargs_spec = spec.children_specs[1] + assert kwargs_spec.type is dict + + if set(user_kwargs) != set(kwargs_spec.context): + raise ValueError( + f"kwarg key mismatch: " + f"Got {list(user_kwargs)} but expected {kwargs_spec.context}" + ) + + reordered_kwargs = {} + for kw in kwargs_spec.context: + reordered_kwargs[kw] = user_kwargs[kw] + + return reordered_kwargs + + +def is_equivalent( + spec1: TreeSpec, + spec2: TreeSpec, + equivalence_fn: Callable[[Optional[type], Context, Optional[type], Context], bool], +) -> bool: + """Customizable equivalence check for two TreeSpecs. + + Arguments: + spec1: The first TreeSpec to compare + spec2: The second TreeSpec to compare + equivalence_fn: A function to determine the equivalence of two + TreeSpecs by examining their types and contexts. It will be called like: + + equivalence_fn(spec1.type, spec1.context, spec2.type, spec2.context) + + This function will be applied recursively to all children. + + Returns: + True if the two TreeSpecs are equivalent, False otherwise. + """ + if not equivalence_fn(spec1.type, spec1.context, spec2.type, spec2.context): + return False + + # Recurse on children + if len(spec1.children_specs) != len(spec2.children_specs): + return False + + for child_spec1, child_spec2 in zip(spec1.children_specs, spec2.children_specs): + if not is_equivalent(child_spec1, child_spec2, equivalence_fn): + return False + + return True diff --git a/lib/python3.10/site-packages/torch/export/_unlift.py b/lib/python3.10/site-packages/torch/export/_unlift.py new file mode 100644 index 0000000000000000000000000000000000000000..ad48372e2d9aef93e3e079135f1bd4af73904201 --- /dev/null +++ b/lib/python3.10/site-packages/torch/export/_unlift.py @@ -0,0 +1,361 @@ +# mypy: allow-untyped-defs +import copy +import warnings +from itertools import chain +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.utils._pytree as pytree +from torch._export.utils import _check_input_constraints_for_graph +from torch.export.unflatten import _assign_attr, _AttrKind, _recursive_getattr +from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo + +from ._remove_effect_tokens_pass import _remove_effect_tokens +from .exported_program import ( + ExportedProgram, + ExportGraphSignature, + InputKind, + OutputKind, +) + + +@torch._dynamo.disable +def _check_input_constraints_pre_hook(self, *args, **kwargs): + flat_args_with_path, received_spec = pytree.tree_flatten_with_path(args) + + if received_spec != self._in_spec: + raise ValueError( # noqa: B904 + "Trying to flatten user inputs with exported input tree spec: \n" + f"{self._in_spec}\n" + "but actually got inputs with tree spec of: \n" + f"{received_spec}" + ) + + return _check_input_constraints_for_graph( + [node for node in self.graph.nodes if node.op == "placeholder"], + flat_args_with_path, + self.range_constraints, + ) + + +def _unlift_inputs_as_getattr( + gm: torch.fx.GraphModule, + lifted_inputs: List[Optional[str]], +) -> Tuple[Dict[str, torch.fx.Node], Dict[str, torch.fx.Node]]: + """ + Unlift inputs referring to params/buffers/constants as getattr nodes in the + graph + """ + unlifted_name_to_node = {} + input_name_to_node = {} + + placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] + assert len(lifted_inputs) == len(placeholder_nodes) + for input_node, lifted_node in zip(placeholder_nodes, lifted_inputs): + if lifted_node 
is None: + input_name_to_node[input_node.name] = input_node + + else: + with gm.graph.inserting_after(input_node): + getattr_node = gm.graph.get_attr(lifted_node) + input_node.replace_all_uses_with(getattr_node) + metadata = input_node.meta + gm.graph.erase_node(input_node) + getattr_node.meta = metadata + unlifted_name_to_node[lifted_node] = getattr_node + + return unlifted_name_to_node, input_name_to_node + + +def _insert_copy_for_mutations( + gm: torch.fx.GraphModule, + mutated_outputs: List[Optional[str]], + unlifted_name_to_node: Dict[str, torch.fx.Node], + input_name_to_node: Dict[str, torch.fx.Node], +) -> None: + """ + Find the all the buffers and inputs that were mutated and insert copy_ + operators to reflect mutations. + """ + output_node = None + for node in gm.graph.nodes: + if node.op == "output": + output_node = node + break + assert output_node is not None + outputs = pytree.tree_flatten(output_node.args)[0] + assert len(outputs) == len(mutated_outputs) + + user_output_nodes = [] + return_nodes_to_copy = {} + for return_node, mutated_node_name in zip(outputs, mutated_outputs): + if mutated_node_name is None: + user_output_nodes.append(return_node) + continue + + if mutated_node_name in unlifted_name_to_node: + mutated_node = unlifted_name_to_node[mutated_node_name] + elif mutated_node_name in input_name_to_node: + mutated_node = input_name_to_node[mutated_node_name] + else: + raise RuntimeError( + f"Could not find {mutated_node_name} in either buffer or input nodes" + ) + + with gm.graph.inserting_before(output_node): + copy_node = gm.graph.call_function( + torch.ops.aten.copy_.default, (mutated_node, return_node) + ) + return_nodes_to_copy[return_node] = copy_node + + output_args = [ + return_nodes_to_copy[node] if node in return_nodes_to_copy else node + for node in user_output_nodes + ] + with gm.graph.inserting_before(output_node): + # Only return user outputs + new_output = gm.graph.output(tuple(output_args)) + new_output.meta.update(output_node.meta) + output_node.replace_all_uses_with(new_output) + gm.graph.erase_node(output_node) + + +def _get_codegen( + in_spec: pytree.TreeSpec, + out_spec: Optional[pytree.TreeSpec], + forward_arg_names: Optional[List[str]] = None, +) -> _PyTreeCodeGen: + """ + Create the codegen for the graph module based on the in/out specs + """ + if forward_arg_names: + names = forward_arg_names + else: + if ( + in_spec.type == tuple + and in_spec.num_children == 2 + and in_spec.children_specs[0].type == tuple + and in_spec.children_specs[1].type == dict + ): + # if in_spec contains the args (tuple) and kwargs (dict) + names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)] + # add kwarg names + names.extend(in_spec.children_specs[1].context) + else: + names = [f"arg_{i}" for i in range(in_spec.num_children)] + + return _PyTreeCodeGen( + _PyTreeInfo( + names, + in_spec, + out_spec, + ) + ) + + +def _unlift( + gm: torch.fx.GraphModule, + lifted_inputs: List[Optional[str]], + mutated_outputs: List[Optional[str]], + in_spec: pytree.TreeSpec, + out_spec: Optional[pytree.TreeSpec], + state_dict: Dict[str, Any], + constants: Dict[str, Any], + forward_arg_names: Optional[List[str]] = None, +): + """ + Args: + lifted_inputs: A list matching the graph module's input nodes. For + an input node that is referring to a lifted parameter/buffer, this + list will contain the fqn the corresponding attribute. Otherwise, this + list will contain None. This is used to unlift the lifted parameters as + get_attr nodes. 
+ + mutated_outputs: A list matching the graph module's output nodes. For + an output node that is referring to a mutated buffer or user input, this + list will contain the name of the corresponding buffer or user input + that needs to be mutated. Otherwise, this list will contain None. This + is used to re-insert an inplace copy_ operator to copy the mutated + values back to the original node. + """ + unlifted_name_to_node, input_name_to_node = _unlift_inputs_as_getattr( + gm, lifted_inputs + ) + _insert_copy_for_mutations( + gm, mutated_outputs, unlifted_name_to_node, input_name_to_node + ) + gm.graph._codegen = _get_codegen(in_spec, out_spec, forward_arg_names) + gm.graph.lint() + gm.recompile() + return gm + + +def _register_attrs_to_new_gm( + new_gm: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + state_dict: Dict[str, Any], + constants: Dict[str, Any], +) -> None: + non_persistent_buffers = set(graph_signature.non_persistent_buffers) + for name in graph_signature.buffers: + if name in non_persistent_buffers: + persistent = False + value = constants[name] + else: + persistent = True + value = state_dict[name] + _assign_attr( + value, new_gm, name, attr_kind=_AttrKind.BUFFER, persistent=persistent + ) + for name in graph_signature.parameters: + value = state_dict[name] + _assign_attr( + value, + new_gm, + name, + attr_kind=_AttrKind.PARAMETER, + ) + + for name in chain( + graph_signature.lifted_custom_objs, graph_signature.lifted_tensor_constants + ): + value = constants[name] + _assign_attr( + value, + new_gm, + name, + attr_kind=_AttrKind.CONSTANT, + ) + + +class _StatefulGraphModuleFactory(type): + """ + Metaclass that ensures a private constructor for _StatefulGraphModule + """ + + def __call__(cls, *args, **kwargs): + raise TypeError( + f"{cls.__module__}.{cls.__qualname__} has no public constructor. " + ) + + def _create(cls, root, graph, range_constraints=None): + return super().__call__( + root, + graph, + range_constraints=range_constraints, + ) + + +class _StatefulGraphModule(torch.fx.GraphModule, metaclass=_StatefulGraphModuleFactory): + def __init__(self, root, graph, range_constraints=None): + super().__init__(root, graph) + # Need to fix up non-persistent buffers. + self.range_constraints = range_constraints or [] + + +def _create_stateful_graph_module( + plain_graph_module: torch.fx.GraphModule, + range_constraints, + # TODO(suo) this should not be optional, but is since we still ahve + # capture_pre_autograd_graph grr + graph_signature: Optional[ExportGraphSignature] = None, +): + stateful_gm = _StatefulGraphModule._create( + plain_graph_module, + plain_graph_module.graph, + range_constraints=range_constraints, + ) + + stateful_gm.register_forward_pre_hook( + _check_input_constraints_pre_hook, with_kwargs=True + ) + + if graph_signature is None: + return stateful_gm + + # Fix up lifted tensor constants. + # fx.GraphModule() constructor silently turns a constant attribute of plain_graph_module + # into a buffer in stateful_gm and creates an inconsistency with graph_signature. + # We fix this by de-registering these buffers in lifted_tensor_constants + # and call _assign_attr(attr_kind=CONSTANT) to register them as constants. + for constant_fqn in graph_signature.lifted_tensor_constants: + # Sometimes, the constant can require gradient, this is probably a bug in user code, + # e.g. `self.const = torch.randn(2, 2, requires_grad=True)`. 
+ # We call detach on the constant_val since they're tensor contants and we don't need to + # compute their gradients anyway. + # Users should properly register it as parameter if they want it to require gradient. + buffer = stateful_gm.get_buffer(constant_fqn) + if buffer.requires_grad: + warnings.warn( + f"A model attribute `{constant_fqn}` requires gradient. " + f"but it's not properly registered as a parameter. " + f"torch.export will detach it and treat it as a constant tensor " + f"but please register it as parameter instead." + ) + buffer = buffer.detach() + *prefix, field = constant_fqn.rsplit(".") + submod = _recursive_getattr(stateful_gm, prefix) + delattr(submod, field) + _assign_attr(buffer, stateful_gm, constant_fqn, attr_kind=_AttrKind.CONSTANT) + + # Fix up non-persistent buffers. torch.fx does not distinguish between + # persistent and non-persistent buffers, so we must restore that distinction + # here. + for buffer in graph_signature.non_persistent_buffers: + _assign_attr( + plain_graph_module.get_buffer(buffer), + stateful_gm, + buffer, + attr_kind=_AttrKind.BUFFER, + persistent=False, + ) + + return stateful_gm + + +def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Module: + ep = _remove_effect_tokens(ep) + new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph)) + _register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants) + forward_arg_names = ep.graph_module.meta.get("forward_arg_names") + + lifted_inputs: List[Optional[str]] = [ + ( + in_spec.target + if in_spec.kind + in ( + InputKind.BUFFER, + InputKind.CONSTANT_TENSOR, + InputKind.PARAMETER, + InputKind.CUSTOM_OBJ, + ) + else None + ) + for in_spec in ep.graph_signature.input_specs + ] + + mutated_outputs: List[Optional[str]] = [ + ( + out_spec.target + if out_spec.kind + in (OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION) + else None + ) + for out_spec in ep.graph_signature.output_specs + ] + + new_gm = _unlift( + new_gm, + lifted_inputs, + mutated_outputs, + ep.call_spec.in_spec, + ep.call_spec.out_spec, + ep.state_dict, + ep.constants, + forward_arg_names=forward_arg_names, + ) + unlift_gm = _create_stateful_graph_module( + new_gm, ep.range_constraints, ep.graph_signature + ) + unlift_gm.meta.update(ep.graph_module.meta) + return unlift_gm diff --git a/lib/python3.10/site-packages/torch/export/custom_obj.py b/lib/python3.10/site-packages/torch/export/custom_obj.py new file mode 100644 index 0000000000000000000000000000000000000000..8e7f2080a4ee705a2621386c9b69a089d507544a --- /dev/null +++ b/lib/python3.10/site-packages/torch/export/custom_obj.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass + + +__all__ = ["ScriptObjectMeta"] + + +@dataclass +class ScriptObjectMeta: + """ + Metadata which is stored on nodes representing ScriptObjects. + """ + + # Key into constants table to retrieve the real ScriptObject. 
+ constant_name: str + + class_fqn: str diff --git a/lib/python3.10/site-packages/torch/export/dynamic_shapes.py b/lib/python3.10/site-packages/torch/export/dynamic_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..18fc1e73f0cf05fd4be077834ed4702d78097136 --- /dev/null +++ b/lib/python3.10/site-packages/torch/export/dynamic_shapes.py @@ -0,0 +1,1220 @@ +# mypy: allow-untyped-defs +import dataclasses +import inspect +import logging +import sys +from collections import defaultdict +from enum import auto, Enum +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union + +import torch +from torch.utils._pytree import ( + _get_node_type, + BUILTIN_TYPES, + keystr, + LeafSpec, + MappingKey, + SequenceKey, + SUPPORTED_NODES, + tree_flatten, + tree_map_with_path, +) + +from .exported_program import ExportedProgram + + +if TYPE_CHECKING: + from sympy import Symbol + + from torch._guards import Source + from torch.fx.experimental.symbolic_shapes import ShapeEnv, StrictMinMaxConstraint + +__all__ = [ + "Constraint", + "Dim", + "dims", + "refine_dynamic_shapes_from_suggested_fixes", +] + + +log = logging.getLogger(__name__) + + +class _DimHint(Enum): + """ + Enum for dynamic shape hints. + - AUTO means automatic inference of shape (static or dynamic). + - STATIC means static shape (always specialized). + """ + + AUTO = auto() + STATIC = auto() + + +class _Dim(type): + """ + Metaclass for :func:`Dim` types. + """ + + @staticmethod + def readable(name, min_, max_): + from torch.utils._sympy.numbers import int_oo + + if min_ == 2: + min_ = None + if max_ == int_oo: + max_ = None + if min_ is None and max_ is None: + return f"Dim('{name}')" + if min_ is None: + return f"Dim('{name}', max={max_})" + if max_ is None: + return f"Dim('{name}', min={min_})" + return f"Dim('{name}', min={min_}, max={max_})" + + def __add__(cls, other): + # e.g., dim + 1 + if type(other) is not int: + raise NotImplementedError( + f"Attempted to add {other} to {cls.__name__}, where an integer was expected. " + "(Only increasing linear operations with integer coefficients are supported.)" + ) + return cls._derive(lambda x: x + other) + + def __radd__(cls, other): + return cls + other + + def __sub__(cls, other): + # e.g., dim - 1 + if type(other) is not int: + raise NotImplementedError( + f"Attempted to subtract {other} from {cls.__name__}, where an integer was expected. " + "(Only increasing linear operations with integer coefficients are supported.)" + ) + return cls._derive(lambda x: x - other) + + def __rsub__(cls, other): + raise NotImplementedError( + f"Attempted to negate {cls.__name__}. " + "(Only increasing linear operations with integer coefficients are supported.)" + ) + + def __mul__(cls, other): + # e.g., dim * 2 + if type(other) is not int or other <= 0: + raise NotImplementedError( + f"Attempted to multiply {other} with {cls.__name__}, where a positive integer was expected. " + "(Only increasing linear operations with integer coefficients are supported.)" + ) + return cls._derive(lambda x: x * other) + + def __rmul__(cls, other): + return cls * other + + def _derived_name(cls, fn): + from sympy import sympify + + return str(fn(sympify(cls.__name__))) + + def _derive(cls, fn): + return _DerivedDim(cls._derived_name(fn), (int,), {"root": cls, "fn": fn}) + + +class _StaticDim(_Dim): + """ + Meta class for static :func:`Dim` types. + + This class is only for setting and checking static dim constraints, + and the user should never interact with it. 
+ """ + + @property + def min(self): + return self.value # type: ignore[attr-defined] + + @property + def max(self): + return self.value # type: ignore[attr-defined] + + +class _DerivedDim(_Dim): + """ + Metaclass for derived :func:`Dim` types. + + Currently we only support increasing linear expressions with integer coefficients. + In other words, a derived Dim can always be written in the form Ax + B, where + x is a regular Dim (i.e., non-derived Dim), A and B are integers, and A is positive. + (In particular, the latter ensures that x < y => Ax + B < Ay + B.) + These restrictions on the form of derived Dims makes the metatheory simpler: e.g., + it simplifies computing ranges for derived Dims, solving for underlying regular Dims, + deciding equalities between derived Dims, and so on. + + The function lambda x: Ax + B is expressed by `fn`, where x is a normal Dim, `root`. + The range of a derived Dim is computed by mapping `fn` over the range of its `root`. + """ + + @property + def min(self): + # assume that self.fn is an increasing function + # TODO(avik): use sympy value range analysis instead? + from sympy import Integer + + from torch.utils._sympy.numbers import int_oo + + if self.root.min is -int_oo: # type: ignore[attr-defined] + return -int_oo # fn not needed cuz increasing + + _min_symint = self.fn(Integer(self.root.min)) # type: ignore[attr-defined] + root = self.root # type: ignore[attr-defined] + assert _min_symint >= 0, ( + f"Expected derived min value of {self.__name__} to be >= 0. " + f"Please specify an appropriate min value for {root.__name__} " + f"(currently {root.min})." + ) + return int(_min_symint) + + @property + def max(self): + # assume that self.fn is an increasing function + # TODO(avik): use sympy value range analysis instead? + from sympy import Integer + + from torch.utils._sympy.numbers import int_oo + + if self.root.max is int_oo: # type: ignore[attr-defined] + return int_oo # fn not needed cuz increasing + + _max_symint = self.fn(Integer(self.root.max)) # type: ignore[attr-defined] + root = self.root # type: ignore[attr-defined] + assert _max_symint <= sys.maxsize - 1, ( + f"Expected derived max value of {self.__name__} to be <= {sys.maxsize - 1}. " + f"Please specify an appropriate max value for {root.__name__} " + f"(currently {root.max})." + ) + return int(_max_symint) + + def _derive(self, fn): + # We support nesting, e.g., 2*dim + 1. + # This is implemented by composing operations on the same root. + # As a consequence, roots are always regular Dims (i.e., not derived Dims). + return _DerivedDim( + self._derived_name(fn), + (int,), + {"root": self.root, "fn": lambda x: fn(self.fn(x))}, # type: ignore[attr-defined] + ) + + +def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None): + """ + :func:`Dim` constructs a type analogous to a named symbolic integer with a range. + It can be used to describe multiple possible values of a dynamic tensor dimension. + Note that different dynamic dimensions of the same tensor, or of different tensors, + can be described by the same type. + + Args: + name (str): Human-readable name for debugging. + min (Optional[int]): Minimum possible value of given symbol (inclusive) + max (Optional[int]): Maximum possible value of given symbol (inclusive) + + Returns: + A type that can be used in dynamic shape specifications for tensors. 
+ """ + + from torch.utils._sympy.numbers import int_oo + + _min = 0 if min is None else min + _max = int_oo if max is None else max + assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}" + assert name.isidentifier(), f"Dim name must be a valid identifier, got {name}" + dim = _Dim(name, (int,), {"min": _min, "max": _max}) + dim.__module__ = getattr( + inspect.getmodule(inspect.stack()[1][0]), "__name__", "__main__" + ) + return dim + + +Dim.AUTO = _DimHint.AUTO # type: ignore[attr-defined] +Dim.STATIC = _DimHint.STATIC # type: ignore[attr-defined] + + +def dims(*names: str, min: Optional[int] = None, max: Optional[int] = None): + """ + Util to create multiple :func:`Dim` types. + """ + return tuple(Dim(name, min=min, max=max) for name in names) + + +@dataclasses.dataclass +class _ConstraintTarget: + """ + This represents input tensor dimensions. + """ + + t_id: int + dim: int + + +@dataclasses.dataclass +class _Constraint(_ConstraintTarget): + """ + This represents a Dim describing a constraint target. + + `name` is the name of the Dim. + `constraint_range` contains the min/max bounds of the Dim. + """ + + name: str + constraint_range: "StrictMinMaxConstraint" + + def _clone_with_range(self, lower=0, upper=None): + # Import sympy locally + from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint + from torch.utils._sympy.numbers import int_oo + from torch.utils._sympy.value_ranges import ValueRanges + + if upper is None: + upper = int_oo + + constraint_range = StrictMinMaxConstraint( + vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper), + warn_only=False, + ) + return _Constraint( + self.t_id, + self.dim, + self.name, + constraint_range, + ) + + def __ge__(self, lower): + return self._clone_with_range(lower=lower) + + def __gt__(self, lower): + return self._clone_with_range(lower=lower + 1) + + def __le__(self, upper): + return self._clone_with_range(upper=upper) + + def __lt__(self, upper): + return self._clone_with_range(upper=upper - 1) + + def __bool__(self): + # NOTE(avik): We do not support compound expressions like a <= x <= b. + # This is because Python implicitly desugars them into bool(a <= x) and bool(x <= b), + # and moreover, enforces that any overload of __bool__ must return True or False. + # FWIW, sympy also raises TypeError in this case. + raise TypeError( + "Cannot determine truth value of _Constraint. " + "If you are trying to combine _Constraint's with logical connectives, " + "you can specify them separately instead." + ) + + @property + def serializable_spec(self): + # We need a serialization compatible format of the constraint so that it + # can be savedin the graph module w/o breaking the module serialization. + # The saved constraints will be used directly for the post-exporting pass + # that converts constraints to runtime assertion. The saved constraints + # will not be saved in the serialized module. + # TODO: A better way is needed. Currently we use 't_id' to map the constraint, + # which is not reliable + return { + "t_id": self.t_id, + "dim": self.dim, + "min": self.constraint_range.vr.lower, + "max": self.constraint_range.vr.upper, + } + + +@dataclasses.dataclass +class _PhantomRoot: + """ + This represents the root of a derived Dim where the root does not directly + specify the shape of any input dimension, but the derived Dim does. + + e.g., the input shapes 2*dim and dim + 1 are related via a "phantom" dim. 
+ + The fields `name`, `constraint_range`, and `val` carried by a phantom root + help create a symbol for it. Any derived dims with this phantom root are + backed by expressions over this symbol. + """ + + name: str + constraint_range: "StrictMinMaxConstraint" + val: int + + +@dataclasses.dataclass +class _DerivedConstraint(_ConstraintTarget): + """ + This represents a derived Dim, whose root is either a regular constraint target + (which directly specifies the shape of some input dimension) or a phantom root + (which does so indirectly). + + It can be thought of as a subclass of `_Constraint`, except that it does not + support <, <=, >, >= operations. + """ + + name: str + constraint_range: "StrictMinMaxConstraint" + root: Union[_ConstraintTarget, _PhantomRoot] + fn: Callable + + @property + def serializable_spec(self): + # same as _Constraint.serializable_spec + return { + "t_id": self.t_id, + "dim": self.dim, + "min": self.constraint_range.vr.lower, + "max": self.constraint_range.vr.upper, + } + + +Constraint = Union[_Constraint, _DerivedConstraint] + + +def _process_equalities( + constraint: Constraint, + get_sources: Callable[[int, int], List["Source"]], + shape_env: "ShapeEnv", + names: Dict[str, Tuple[int, int]], + source_pairs: List[Tuple["Source", "Source"]], + derived_equalities: List[Tuple["Source", Union["Source", "Symbol"], Callable]], + phantom_symbols: Dict[str, "Symbol"], +): + """ + Updates `source_pairs`, `derived_equalities`, and `phantom_symbols` (which become + fields of `EqualityConstraint`) based on a given input `constraint`. + """ + + sources = get_sources(constraint.t_id, constraint.dim) + if not sources: # empty sources due to unused shapes + return + + source, *other_sources = sources + # When t.size()[dim] maps to src0, src1, ..., srcN, we add + # constraints that make src0 "equal" to src1, ..., srcN. + source_pairs.extend((source, other_source) for other_source in other_sources) + if not isinstance(constraint, _DerivedConstraint): + if constraint.name in names: + shared_t_id, shared_dim = names[constraint.name] + other_sources = get_sources(shared_t_id, shared_dim) + source_pairs.extend( + (source, other_source) for other_source in other_sources + ) + else: + names[constraint.name] = (constraint.t_id, constraint.dim) + else: + # branch based on the root of the _DerivedConstraint + if not isinstance(constraint.root, _PhantomRoot): + # either root points to an input source + root = get_sources(constraint.root.t_id, constraint.root.dim)[0] # type: ignore[assignment] + else: + # or root points to a phantom symbol + if constraint.root.name in phantom_symbols: + root = phantom_symbols[constraint.root.name] # type: ignore[assignment] + else: + # create a phantom symbol in the shape env based on the _PhantomRoot + root = shape_env.create_symbol( + val=constraint.root.val, + source=torch._dynamo.source.ConstantSource(constraint.root.name), + dynamic_dim=torch.fx.experimental.symbolic_shapes.DimDynamic.DYNAMIC, + constraint_dim=constraint.root.constraint_range, + ) + phantom_symbols[constraint.root.name] = root # type: ignore[assignment] + + fn = constraint.fn + # A derived equality (source, root, fn) informally corresponds to source = fn(root). + # Here source describes an input and root might describe another input or a phantom symbol. 
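+ # For example (illustrative): if dx = Dim("dx") sizes x.shape[0] and y.shape[0] was
+ # specified as dx + 1, this records (source_of_y_dim0, source_of_x_dim0, lambda v: v + 1),
+ # i.e. y.shape[0] == x.shape[0] + 1.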
+ derived_equalities.append((source, root, fn)) + + +def _tree_map_with_path( + func: Callable[..., Any], + tree: Any, + *dynamic_shapes: Any, + tree_name: Optional[str] = None, +) -> Any: + """ + Customized tree_map for mapping pytrees to dynamic_shapes. + + For built-in types (e.g., standard collections) this behaves exactly like tree_map. + + OTOH for a user-defined class C registered with pytree, we cannot assume that a C + containing tensors can be mapped to a C containing dynamic shapes (i.e., C may not + be a polymorphic container). In that case we use the flattened form of C instead. + Thus a C(**tensors) that flattens to (**tensors) will map to (**dynamic_shapes). + + Args: + func: function to apply to each (int, float, str, bool, None, torch.Tensor) + tree: input pytree + dynamic_shapes: zero or more (typically one) dynamic_shapes to match + + Returns: + output pytree mapping func to each (int, float, str, bool, None, torch.Tensor) + """ + + def is_leaf(t): + # BUILTIN_TYPES is a subset of SUPPORTED_NODES, the latter being all types + # registered with pytree. Types *not* in BUILTIN_TYPES include primitive types + # (int, float, str, bool, None, torch.Tensor), which are not in SUPPORTED_NODES, + # as well as user-defined classes registered with pytree, which are. + return _get_node_type(t) not in BUILTIN_TYPES + + def f(path, t, *dynamic_shapes): + typ = _get_node_type(t) + # typ is not in BUILTIN_TYPES + if typ in SUPPORTED_NODES: + # thus typ is a user-defined class registered with pytree, + # in which case flatten and recurse + return tree_map_with_path( + f, + SUPPORTED_NODES[typ].flatten_fn(t)[0], + *dynamic_shapes, + is_leaf=is_leaf, + ) + else: + return func(path, t, *dynamic_shapes) + + try: + return tree_map_with_path(f, tree, *dynamic_shapes, is_leaf=is_leaf) + except ValueError as e: + if "mismatch" in e.args[0]: + # When PyTree finds a structural mismatch between tree and dynamic_shapes, + # the error message is unfortunately quite horrible. Let's fix that. 
+ assert dynamic_shapes, "Cannot be a mismatch if there is no dynamic_shapes" + assert tree_name, "Must provide a tree_name when there might be a mismatch" + + def _key(type_, context, i): + # derive a PyTree key given the type, context, and child # of a TreeSpec + if type_ is dict: + return MappingKey(context[i]) + if type_ in (list, tuple): + assert context is None + return SequenceKey(i) + raise AssertionError(f"Did not expect type {type_}") + + def raise_mismatch_error(msg): + from torch._dynamo.exc import UserError, UserErrorType + + raise UserError( + UserErrorType.INVALID_INPUT, + f"Detected mismatch between the structure of `{tree_name}` and `dynamic_shapes`: {msg}", + case_name="dynamic_shapes_validation", + ) + + def _compare(tree, dynamic_shapes, path): + # raise an error at the point where tree and dynamic_shapes differ, + # including the path to that point and the reason for the difference + rendered_path = keystr(path) + if isinstance(tree, LeafSpec): + return + if isinstance(dynamic_shapes, LeafSpec): + raise_mismatch_error( + f"`{tree_name}{rendered_path}` is a {tree.type}, " + f"but `dynamic_shapes{rendered_path}` is not" + ) + if tree.type != dynamic_shapes.type: + raise_mismatch_error( + f"`{tree_name}{rendered_path}` is a {tree.type}, " + f"but `dynamic_shapes{rendered_path}` is a {dynamic_shapes.type}" + ) + if len(tree.children_specs) != len(dynamic_shapes.children_specs): + raise_mismatch_error( + f"`{tree_name}{rendered_path}` has {len(tree.children_specs)} elements, " + f"but `dynamic_shapes{rendered_path}` has {len(dynamic_shapes.children_specs)} elements" + ) + if tree.type is dict: + # context, children could be out of order + if sorted(tree.context) != sorted(dynamic_shapes.context): + raise_mismatch_error( + f"`{tree_name}{rendered_path}` has keys {tree.context}, " + f"but `dynamic_shapes{rendered_path}` has keys {dynamic_shapes.context}" + ) + _remap = dict( + zip(dynamic_shapes.context, dynamic_shapes.children_specs) + ) + dynamic_shapes_children_specs = [_remap[k] for k in tree.context] + else: + dynamic_shapes_children_specs = dynamic_shapes.children_specs + for i, (tree_, dynamic_shapes_) in enumerate( + zip(tree.children_specs, dynamic_shapes_children_specs) + ): + _compare( + tree_, + dynamic_shapes_, + path + [_key(tree.type, tree.context, i)], + ) + + _, tree_spec = tree_flatten(tree, is_leaf=is_leaf) + for other_tree in dynamic_shapes: + _, other_tree_spec = tree_flatten(other_tree, is_leaf) + _compare(tree_spec, other_tree_spec, []) + raise + + +def _combine_args(f, args, kwargs, _is_torch_jit_trace=False) -> Dict[str, Any]: + # combine args and kwargs following the signature of f, as it happens + # in the body of f when called with *args, **kwargs + if isinstance(f, ExportedProgram): + f = f.module() + if not _is_torch_jit_trace: + signature = ( + inspect.signature(f.forward) + if isinstance(f, torch.nn.Module) + else inspect.signature(f) + ) + kwargs = kwargs if kwargs is not None else {} + return signature.bind(*args, **kwargs).arguments + return args + + +class ShapesCollection: + """ + Builder for dynamic_shapes. + Used to assign dynamic shape specifications to tensors that appear in inputs. + + Example:: + args = ({"x": tensor_x, "others": [tensor_y, tensor_z]}) + + dim = torch.export.Dim(...) 
+ dynamic_shapes = torch.export.ShapesCollection() + dynamic_shapes[tensor_x] = (dim, dim + 1, 8) + dynamic_shapes[tensor_y] = {0: dim * 2} + # This is equivalent to the following (now auto-generated): + # dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]} + + torch.export(..., args, dynamic_shapes=dynamic_shapes) + """ + + def __init__(self): + self._shapes = {} + + def __setitem__(self, t, shape): + assert isinstance( + t, torch.Tensor + ), f"Cannot assign shape to non-tensor type {type(t)}" + # TODO(avik): check that shape is indeed a Shape + t_id = id(t) + if t_id in self._shapes: + _shape = self._shapes[t_id] + assert ( + shape == _shape + ), f"Shapes assigned to tensor do not match: expected {_shape}, got {shape}" + else: + self._shapes[id(t)] = shape + + def __getitem__(self, t): + t_id = id(t) + if t_id in self._shapes: + return self._shapes[t_id] + else: + return None + + def __len__(self): + return len(self._shapes) + + def dynamic_shapes(self, m, args, kwargs=None): + """ + Generate dynamic_shapes. + """ + + t_ids = set() + + def find_shape(path, t): + t_id = id(t) + if t_id in self._shapes: + t_ids.add(t_id) + return self._shapes[t_id] + else: + return None + + combined_args = _combine_args(m, args, kwargs) + dynamic_shapes = _tree_map_with_path(find_shape, combined_args) + if any(t_id not in t_ids for t_id in self._shapes): + raise ValueError( + "Some tensors that were assigned shapes were not found in args. " + "Maybe such tensors were copied when passing them as args? " + "Maybe such tensors are contained in classes that were not registered with pytree?" + ) + return dynamic_shapes + + +def _warn_on_None_dynamic_shape_dimension(): + msg = ( + "Using None as a dynamic shape dimension is deprecated. " + "Please use Dim.STATIC instead" + ) + # TODO(avik): raise an error in the future + log.warning(msg) + + +def _check_dynamic_shapes( + combined_args: Dict[str, Any], + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], +): + """ + Checks the dynamic_shapes specification for correctness, + using combined args + kwargs as reference for inputs structure. 
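+
+ For example (illustrative), with ``combined_args = {"x": <rank-2 tensor>}``, a spec
+ like ``{"x": {0: Dim("b")}}`` passes, while ``{"x": "not-a-shape"}`` raises a
+ ``UserError``, since each per-tensor spec must be a dict, a list/tuple, or None.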
+ """ + from torch._dynamo.exc import UserError, UserErrorType + from torch._export.non_strict_utils import _flatten_dynamic_shapes + + if dynamic_shapes is None or len(dynamic_shapes) == 0: + return + if isinstance(dynamic_shapes, (tuple, list)): + combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc] + + bounds: Dict[str, Tuple[int, int]] = {} + + def check_same_bounds(dim): + if dim.__name__ in bounds: + min_, max_ = bounds[dim.__name__] + if dim.min != min_ or dim.max != max_: + this_ = _Dim.readable(dim.__name__, min_, max_) + that_ = _Dim.readable(dim.__name__, dim.min, dim.max) + raise UserError( + UserErrorType.INVALID_INPUT, + f"Found different definitions {this_} and {that_} " + f"for the same symbolic dimension {dim}!", + ) + else: + bounds[dim.__name__] = (dim.min, dim.max) + + def check_symbols(path, tensor, shape): + if isinstance(shape, dict): + for i, dim in shape.items(): + if isinstance(dim, _Dim): + check_same_bounds(dim) + elif dim is None: + _warn_on_None_dynamic_shape_dimension() + elif not (isinstance(dim, (int, _DimHint))): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Unexpected dimension mapped to index {i} in input tensor shape {shape} " + f"specified at `dynamic_shapes{keystr(path)}` " + f"(expected None, an int, a Dim, Dim.AUTO, or Dim.STATIC, but got {dim} instead)", + case_name="dynamic_shapes_validation", + ) + elif isinstance(shape, (tuple, list)): + for i, dim in enumerate(shape): + if isinstance(dim, _Dim): + check_same_bounds(dim) + elif dim is None: + _warn_on_None_dynamic_shape_dimension() + elif not (isinstance(dim, (int, _DimHint))): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Unexpected dimension #{i} in input tensor shape {shape} " + f"specified at `dynamic_shapes{keystr(path)}` " + f"(expected None, an int, a Dim, Dim.AUTO, or Dim.STATIC, but got {dim} instead)", + case_name="dynamic_shapes_validation", + ) + elif shape is not None: + raise UserError( + UserErrorType.INVALID_INPUT, + f"Unexpected input tensor shape {shape} specified at `dynamic_shapes{keystr(path)}` " + f"(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions," + f" where each dimension is an int, a Dim, Dim.AUTO, or Dim.STATIC)", + case_name="dynamic_shapes_validation", + ) + + assert isinstance(dynamic_shapes, (dict, tuple, list)) + if isinstance(dynamic_shapes, dict): + got_keys = list(dynamic_shapes.keys()) + expected_arg_names = list(combined_args.keys()) + if sorted(got_keys) != sorted(expected_arg_names): + msg = ( + f"When `dynamic_shapes` is specified as a dict, its top-level keys " + f"must be the arg names {expected_arg_names} of `inputs`, but " + f"here they are {got_keys}. " + ) + if ( + len(combined_args) == 1 + and expected_arg_names[0] not in got_keys + and isinstance(combined_args[expected_arg_names[0]], dict) + ): + msg += ( + "Since here `inputs` is a list/tuple enclosing a single dict, " + "maybe you just forgot to enclose `dynamic_shapes` in a list/tuple?" + ) + else: + msg += ( + "Alternatively, you could also ignore arg names entirely " + "and specify `dynamic_shapes` as a list/tuple matching `inputs`." 
+ ) + raise UserError( + UserErrorType.INVALID_INPUT, msg, case_name="dynamic_shapes_validation" + ) + + def check_shape(path, t, dynamic_shape): + if isinstance(t, torch.Tensor): + check_symbols(path, t, dynamic_shape) + else: + if dynamic_shape is not None: + rendered_path = keystr(path) + raise UserError( + UserErrorType.INVALID_INPUT, + f"Cannot associate shape {dynamic_shape} specified at `dynamic_shapes{rendered_path}` " + f"to non-tensor type {type(t)} at `inputs{rendered_path}` (expected None)", + case_name="dynamic_shapes_validation", + ) + + _tree_map_with_path(check_shape, combined_args, dynamic_shapes, tree_name="inputs") + + # raise user warning if both Dim.AUTO & Dims are specified in dynamic_shapes + flat_dynamic_shapes = _flatten_dynamic_shapes(combined_args, dynamic_shapes) + flatter_dynamic_shapes, _ = tree_flatten(flat_dynamic_shapes) + if any(isinstance(s, _Dim) for s in flatter_dynamic_shapes) and any( + s == _DimHint.AUTO for s in flatter_dynamic_shapes + ): + raise UserError( + UserErrorType.INVALID_INPUT, + "Specifying both `Dim.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, " + "and can easily lead to constraint violation errors or obscure errors in torch.export. Dim/DerivedDims " + "expect all equal or related dimensions to be specified, and does not yet compose well with `Dim.AUTO`. " + "We suggest using `Dim.AUTO` mixed with `None` for auto-dynamic + static shapes, plus torch._check(dim >= min), " + "torch._check(dim <= max) calls in your program to specify min/max ranges, or `Dim`/`DerivedDim` mixed with `None` " + "if you want to assert on the exact specification of your program's dynamic shapes behavior.", + case_name="dynamic_shapes_validation", + ) + + +def _transform_shapes_for_default_dynamic( + combined_args: Dict[str, Any], + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], +) -> Union[Dict[str, Any], Tuple[Any], List[Any], None]: + """ + In the long run this might not be needed, but this exists because export.export() and _dynamo.export() + historically have different semantics for how dynamic_shapes are specified, but go through the same + process of producing constraints, and now both use assume_static_by_default=False. + + For _dynamo.export(), the semantics for dynamic_shapes are: + - None: dynamic, allocated a symbol + - Dim/DerivedDim: a strict assertion on the min/max range for this symbol, and require a specification + for all dims governed by this symbol (i.e. relations, equality, linear relations, etc.) + + For export.export(), historically dynamism for unspecified dims has been undesirable, so the semantics are: + - Dim.AUTO: dynamic, allocated a symbol + - None/unspecified/Dim.STATIC: static + - Dim/DerivedDims: also a strict assertion + + To allow both APIs to follow the same process for producing constraints, this function converts dynamic_shapes + for export.export() to be compatible with _process_dynamic_shapes() and assume_static_by_default=False, turning them + into essentially what they'd look like for _dynamo.export(). 
+ + An example conversion might look like, for a 3-d input tensor: + + input spec: { + 0: Dim.AUTO, + 1: None, # or Dim.STATIC + 2: Dim("dx"), + } + output spec: { + 0: None, # None: dynamic by default + 1: 32, # explicitly provide static shape + 2: Dim("dx"), # remains the same + } + """ + + def _tree_map_helper(tree, val): + """ + If the user generally specifies dynamic_shapes=None for a pytree input, + we'd like to convert this into a tree of Nones following the input spec, + so we can explicitly specify static dims for all tensor dimensions. + Non-builtin types for pytree (e.g. custom dataclasses) creates some difficulty, + in which case the correct format is a list containing specs for each child attribute. + """ + if (node_type := _get_node_type(tree)) not in SUPPORTED_NODES: # is_leaf + return val + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, context = flatten_fn(tree) # flatten from whatever original type + unflatten_fn = SUPPORTED_NODES[ + node_type if node_type in BUILTIN_TYPES else list + ].unflatten_fn + children = [_tree_map_helper(child, val) for child in child_pytrees] + return unflatten_fn( + children, context + ) # unflatten into original type, or list if not built-in type + + if ( + dynamic_shapes is None or len(dynamic_shapes) == 0 + ): # create pytree structure of static dim + dynamic_shapes = _tree_map_helper(combined_args, None) + if isinstance(dynamic_shapes, (tuple, list)): + combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc] + + def transform_shapes(path, tensor, shape): + def _marked_dynamic(tensor, i): + # TODO(pianpwk): deprecate mark_dynamic() usage for export + return i in getattr(tensor, "_dynamo_dynamic_indices", set()) + + out: Union[None, List[Any], Dict[int, Any]] = None + if isinstance(shape, dict): + out = {} + for i, val in enumerate(tensor.shape): + dim = shape.get(i, _DimHint.STATIC) + if _marked_dynamic(tensor, i) or dim == _DimHint.AUTO: + # don't have to specify anything if dynamic + # None also works, since assume_static_by_default=False + if dim == _DimHint.AUTO: + torch._dynamo.maybe_mark_dynamic(tensor, i) # avoid duck sizing + continue + elif isinstance(dim, _Dim): + out[i] = dim + elif isinstance(dim, int): + # important that this is dim and not val, + # so we can raise error if user-specified dim != val + out[i] = dim + elif dim is None: + _warn_on_None_dynamic_shape_dimension() + out[i] = val + else: + # make explicitly static + assert dim == _DimHint.STATIC + out[i] = val + elif isinstance(shape, (tuple, list)): + out = [] + for i, val in enumerate(tensor.shape): + dim = shape[i] + if _marked_dynamic(tensor, i) or dim == _DimHint.AUTO: + if dim == _DimHint.AUTO: + torch._dynamo.maybe_mark_dynamic(tensor, i) # avoid duck sizing + out.append(None) + elif isinstance(dim, _Dim): + out.append(dim) + elif isinstance(dim, int): + out.append(dim) + elif dim is None: + _warn_on_None_dynamic_shape_dimension() + out.append(val) + else: + assert dim == _DimHint.STATIC + out.append(val) + out = type(shape)(out) # type: ignore[assignment] + else: + assert shape is None + if isinstance(tensor, torch.Tensor): + out = [] + for i, val in enumerate(tensor.shape): + out.append(None if _marked_dynamic(tensor, i) else val) + out = out or None + else: + out = None + return out + + def transform_shape(path, t, dynamic_shape): + if isinstance(t, torch.Tensor): + return transform_shapes(path, t, dynamic_shape) + + result = _tree_map_with_path( + transform_shape, combined_args, dynamic_shapes, 
tree_name="inputs" + ) + return result + + +def _process_dynamic_shapes( + combined_args: Dict[str, Any], + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], +) -> List[Constraint]: + """ + Reads the dynamic_shapes specification and produces a list of constraints. + """ + from torch._dynamo.exc import UserError, UserErrorType + + if dynamic_shapes is None or len(dynamic_shapes) == 0: + # we run with dynamic by default, so no need to produce constraints + return [] + if isinstance(dynamic_shapes, (tuple, list)): + combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc] + + # map of Dim names representing input shape dimensions to constraints on them + symbols: Dict[str, List[Constraint]] = defaultdict(list) + # track roots that do not directly represent input shape dimensions + phantom_roots: Dict[str, _PhantomRoot] = {} + derived_constraints_with_phantom_root: List[_DerivedConstraint] = [] + + def to_constraint(dim, tensor, i): + import sympy + + from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint + from torch.utils._sympy.solve import try_solve + from torch.utils._sympy.value_ranges import ValueRanges + + def root_value(): + # given tensor.shape[i] is the value of dim = fn(root), + # find the value of root + symbol = sympy.Symbol(dim.root.__name__, integer=True) + expr = dim.fn(symbol) + solution = try_solve(sympy.Eq(expr, tensor.shape[i]), symbol) + if solution is not None: + return int(solution[1]) # type: ignore[call-overload] + else: + raise UserError( # noqa: B904 + UserErrorType.CONSTRAINT_VIOLATION, + f"Expected shape[{i}] = {tensor.shape[i]} of input Tensor to be " + f"of the form {expr}, where {symbol} is an integer", + ) + + if isinstance(dim, _DerivedDim): + # generate a _DerivedConstraint where the root is: + # - either a _ConstraintTarget (if dim.root directly describes an input shape) + # - or a _PhantomRoot (otherwise) + dim_root = dim.root # type: ignore[attr-defined] + if dim_root.__name__ in symbols: + # root represents an input shape dimension + root_constraint = symbols[dim_root.__name__][0] + root = _ConstraintTarget( + root_constraint.t_id, + root_constraint.dim, + ) + elif dim_root.__name__ not in phantom_roots: + # create a phantom root + root = _PhantomRoot( # type: ignore[assignment] + name=dim_root.__name__, + constraint_range=StrictMinMaxConstraint( + vr=ValueRanges(lower=dim_root.min, upper=dim_root.max), + warn_only=False, + ), + val=root_value(), + ) + phantom_roots[dim_root.__name__] = root # type: ignore[assignment] + else: + root = phantom_roots[dim_root.__name__] # type: ignore[assignment] + constraint = _DerivedConstraint( + id(tensor), + i, + dim.__name__, + StrictMinMaxConstraint( + vr=ValueRanges(lower=dim.min, upper=dim.max), + warn_only=False, + ), + root, + dim.fn, # type: ignore[attr-defined] + ) + if isinstance(root, _PhantomRoot): + # NOTE(avik): since we have not processed all inputs yet, we may replace this + # with a root that does represent an input shape dimension later (see below) + derived_constraints_with_phantom_root.append(constraint) + elif isinstance(dim, _StaticDim): + constraint = _Constraint( # type: ignore[assignment] + id(tensor), + i, + dim.__name__, + StrictMinMaxConstraint( + vr=ValueRanges(lower=dim.value, upper=dim.value), warn_only=False # type: ignore[attr-defined] + ), + ) + else: + constraint = _Constraint( # type: ignore[assignment] + id(tensor), + i, + dim.__name__, + StrictMinMaxConstraint( + vr=ValueRanges(lower=dim.min, upper=dim.max), 
warn_only=False # type: ignore[attr-defined] + ), + ) + return constraint + + def update_symbols(path, tensor, shape): + def _create_static_dim(tensor, i, value): + return _StaticDim(str(value), (int,), {"value": value}) + + if isinstance(shape, dict): + for i, dim in shape.items(): + if isinstance(dim, (int, _Dim)): + if isinstance(dim, int): + dim = _create_static_dim(tensor, i, dim) + constraint = to_constraint(dim, tensor, i) + symbols[dim.__name__].append(constraint) + elif isinstance(shape, (tuple, list)): + for i, dim in enumerate(shape): + if isinstance(dim, (int, _Dim)): + if isinstance(dim, int): + dim = _create_static_dim(tensor, i, dim) + constraint = to_constraint(dim, tensor, i) + symbols[dim.__name__].append(constraint) + + def assoc_shape(path, t, dynamic_shape): + if isinstance(t, torch.Tensor): + update_symbols(path, t, dynamic_shape) + + _tree_map_with_path(assoc_shape, combined_args, dynamic_shapes, tree_name="inputs") + + constraints = [] + for derived_constraint_with_phantom_root in derived_constraints_with_phantom_root: + phantom_root_name = derived_constraint_with_phantom_root.root.name # type: ignore[union-attr] + if phantom_root_name in symbols: + # We found an input shape dimension corresponding to this name, so we + # do not need a phantom symbol for it after all. + # NOTE(avik): Overall we want to maintain the invariant that roots that + # are phantom symbols are really "phantom," i.e., they cannot be represented + # by any input source. This is important when we are deciding derived equalities, + # since we can focus our attention exclusively on input sources: deciding + # derived equalities involving phantom symbols are, in comparison, trivial. + derived_constraint_with_phantom_root.root = symbols[phantom_root_name][0] + + for dynamic_dims in symbols.values(): + constraints.extend(dynamic_dims) + + return constraints # type: ignore[return-value] + + +def _get_dim_name_mapping( + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None] +): + name_to_dim = {} + for dim in tree_flatten( + dynamic_shapes, + is_leaf=lambda x: isinstance(x, _Dim), + )[0]: + if dim is None: + # NOTE: this must denote a non-Tensor or automatic at this point. + continue + if isinstance(dim, int): + continue + assert isinstance(dim, _Dim) # dim hints should have boiled away + name_to_dim[dim.__name__] = dim + if isinstance(dim, _DerivedDim): + name_to_dim[dim.root.__name__] = dim.root # type: ignore[attr-defined] + return name_to_dim + + +def refine_dynamic_shapes_from_suggested_fixes( + msg: str, + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]], +) -> Union[Dict[str, Any], Tuple[Any], List[Any]]: + """ + For working with export's dynamic shapes suggested fixes, and/or automatic dynamic shapes. + Refines the given dynamic shapes spec, given a ConstraintViolation error message and the original dynamic shapes. + + For most cases behavior is straightforward - i.e. for suggested fixes that specialize or refine a Dim's range, + or fixes that suggest a derived relation, the new dynamic shapes spec will be updated as such. + + e.g. + Suggested fixes: + + dim = Dim('dim', min=3, max=6) -> this just refines the dim's range + dim = 4 -> this specializes to a constant + dy = dx + 1 -> dy was specified as an independent dim, but is actually tied to dx with this relation + + However, suggested fixes associated with derived dims can be more complicated. + For example, if a suggested fix is provided for a root dim, the new derived dim value is evaluated based on the root. 
+ + e.g. + dx = Dim('dx') + dy = dx + 2 + dynamic_shapes = {"x": (dx,), "y": (dy,)} + + Suggested fixes: + + dx = 4 # specialization will lead to dy also specializing = 6 + dx = Dim('dx', max=6) # dy now has max = 8 + + Derived dims suggested fixes can also be used to express divisibility constraints. + This involves creating new root dims that aren't tied to a particular input shape. + In this case the root dims won't appear directly in the new spec, but as a root of + one of the dims. + + e.g. + Suggested fixes: + + _dx = Dim('_dx', max=1024) # this won't appear in the return result, but dx will + dx = 4*_dx # dx is now divisible by 4, with a max value of 4096 + """ + + import re + + import sympy + + from torch._dynamo.exc import UserError, UserErrorType + from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence + + try: + shape_fixes_msg = msg.split("Suggested fixes:")[1].strip() + except Exception as exc: + raise UserError( + UserErrorType.INVALID_INPUT, + "Suggested fixes not found in error message given to refine_dynamic_shapes_from_suggested_fixes()", + ) from exc + + # build shape_fixes dictionary + shape_fixes = {} + for fix in shape_fixes_msg.split("\n"): + fix = fix.strip() + if match := re.match(r"(.*) = Dim\('(.*)'.*\)", fix): + name = match.group(1) + _min, _max = None, None + if match_min := re.match(r".* = Dim\('.*', min\=([0-9]+).*\)", fix): + _min = int(match_min.group(1)) + if match_max := re.match(r".* = Dim\('.*'.*max\=([0-9]+)\)", fix): + _max = int(match_max.group(1)) + shape_fixes[name] = Dim(name, min=_min, max=_max) + else: + name, expr = fix.split(" = ") + expr = sympy.sympify(expr) + if isinstance(expr, sympy.Number): + # static, integer + shape_fixes[name] = int(expr) # type: ignore[assignment] + else: + # relation or derived dim + shape_fixes[name] = expr + + name_to_dim = _get_dim_name_mapping(dynamic_shapes) + + # track derived dim roots + roots: Set[str] = set() + for k, c in shape_fixes.items(): + assert isinstance(c, (int, _Dim, _DerivedDim, sympy.Expr)) + if isinstance(c, sympy.Expr): # check dim/derived dim expression + assert _is_supported_equivalence(c) + shape_fixes[k] = c + roots.add(str(next(iter(c.free_symbols)))) + if isinstance(c, _DerivedDim): + roots.add(c.root.__name__) # type: ignore[attr-defined] + + # check keys are existing dims or new roots + for k, c in shape_fixes.items(): + assert k in name_to_dim or k in roots + + # cache so we don't produce multiple derived dim objects + derived_dim_cache: Dict[str, _DerivedDim] = {} + + def apply_fixes(path, dim, dummy): + if dim is None or isinstance(dim, int): # not dynamic + return dim + elif dim.__name__ in shape_fixes: # directly fix + fix = shape_fixes[dim.__name__] + if isinstance(fix, sympy.Expr): # now derived or related + if str(fix) in derived_dim_cache: + return derived_dim_cache[str(fix)] + else: + symbol = next(iter(fix.free_symbols)) + # try to locate symbol + if symbol.name in shape_fixes: # type: ignore[attr-defined] + root = shape_fixes[symbol.name] # type: ignore[attr-defined] + else: + assert symbol.name in name_to_dim # type: ignore[attr-defined] + root = name_to_dim[symbol.name] # type: ignore[attr-defined] + # figure out value of fix + modulus, remainder = sympy.polys.polytools.div(fix, symbol) + dim = root + if modulus != 1: + dim = int(modulus) * dim + if remainder != 0: + dim = dim + int(remainder) + derived_dim_cache[str(fix)] = dim + return dim + else: + return fix + elif isinstance(dim, _DerivedDim) and dim.root.__name__ in shape_fixes: # type: 
ignore[attr-defined] + if dim.__name__ in derived_dim_cache: + return derived_dim_cache[dim.__name__] + else: # evaluate new derived value based on root + _dim = dim.fn(shape_fixes[dim.root.__name__]) # type: ignore[attr-defined] + derived_dim_cache[dim.__name__] = _dim + return _dim + return dim # unchanged dim + + return _tree_map_with_path(apply_fixes, dynamic_shapes, dynamic_shapes) diff --git a/lib/python3.10/site-packages/torch/export/exported_program.py b/lib/python3.10/site-packages/torch/export/exported_program.py new file mode 100644 index 0000000000000000000000000000000000000000..b9c71bdc91deca8a34516d12929f97bc89b9efa1 --- /dev/null +++ b/lib/python3.10/site-packages/torch/export/exported_program.py @@ -0,0 +1,1202 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import contextlib +import copy +import dataclasses +import functools +import operator +import types +import warnings +from collections import namedtuple +from contextlib import contextmanager +from typing import ( + Any, + Callable, + Dict, + final, + Iterator, + List, + Optional, + Tuple, + Type, + TYPE_CHECKING, + Union, +) + +from torch._higher_order_ops.utils import autograd_not_implemented +from torch._library.fake_class_registry import FakeScriptObject +from torch.fx._utils import first_call_function_nn_module_stack +from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo +from torch.fx.immutable_collections import immutable_dict, immutable_list +from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts + + +if TYPE_CHECKING: + # Import the following modules during type checking to enable code intelligence features, + # such as auto-completion in tools like pylance, even when these modules are not explicitly + # imported in user code. + + import sympy + + from torch.utils._sympy.value_ranges import ValueRanges + +import torch +import torch.utils._pytree as pytree +from torch._export.utils import ( + _collect_and_set_constant_attrs, + _collect_param_buffer_metadata, + _detect_fake_mode_from_gm, + _name_hoo_subgraph_placeholders, + _overwrite_signature_for_non_persistent_buffers, + _populate_param_buffer_metadata_to_new_gm, + _rename_without_collisions, +) +from torch._export.verifier import Verifier +from torch._guards import detect_fake_mode +from torch._subclasses.fake_tensor import unset_fake_temporarily +from torch._subclasses.functional_tensor import FunctionalTensor +from torch.export._tree_utils import is_equivalent, reorder_kwargs +from torch.fx._compatibility import compatibility +from torch.fx.passes.infra.pass_base import PassResult +from torch.fx.passes.infra.pass_manager import PassManager + +from .graph_signature import ( # noqa: F401 + ArgumentSpec, + ConstantArgument, + CustomObjArgument, + ExportGraphSignature, + InputKind, + InputSpec, + OutputKind, + OutputSpec, + SymIntArgument, + TensorArgument, + TokenArgument, +) + + +__all__ = [ + "ExportedProgram", + "ModuleCallEntry", + "ModuleCallSignature", +] + + +PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]] + + +@dataclasses.dataclass +class ModuleCallSignature: + inputs: List[ArgumentSpec] + outputs: List[ArgumentSpec] + in_spec: pytree.TreeSpec + out_spec: pytree.TreeSpec + + def replace_all_uses_with(self, original_node, new_node): + for i in self.inputs: + if i.name == original_node.name: + i.name = new_node.name + for o in self.outputs: + if o.name == original_node.name: + o.name = new_node.name + + +@dataclasses.dataclass +class ModuleCallEntry: + fqn: str + signature: 
Optional[ModuleCallSignature] = None + + +def _disable_prexisiting_fake_mode(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + with unset_fake_temporarily(): + return fn(*args, **kwargs) + + return wrapper + + +def _fx_collection_equivalence_fn( + spec1_type: Optional[type], + spec1_context: pytree.Context, + spec2_type: Optional[type], + spec2_context: pytree.Context, +) -> bool: + """Treat containers and their immutable variants as the same type. Otherwise + compare as normal. + """ + if spec1_type is None or spec2_type is None: + return spec1_type is spec2_type and spec1_context == spec2_context + + if issubclass(spec1_type, (dict, immutable_dict)) and issubclass( + spec2_type, (dict, immutable_dict) + ): + return spec1_context == spec2_context + + if issubclass(spec1_type, (list, immutable_list)) and issubclass( + spec2_type, (list, immutable_list) + ): + return spec1_context == spec2_context + + return spec1_type is spec2_type and spec1_context == spec2_context + + +def _register_cia_to_meta(*args, **kwargs): + kernel = kwargs["kernel"] + del kwargs["kernel"] + + assert torch._C._dispatch_has_kernel_for_dispatch_key( + kernel.name(), torch._C.DispatchKey.CompositeImplicitAutograd + ) + + return kernel._op_dk( + torch._C.DispatchKey.CompositeImplicitAutograd, *args, **kwargs + ) + + +# This list is compiled from DispatchKey.cpp. +# The idea is that we use these keys to override +# CIA decomp in export +_AUTOGRAD_ALIAS_BACKEND_KEYS_TO_OVERRIDE = [ + torch._C.DispatchKey.AutogradCPU, + torch._C.DispatchKey.AutogradCUDA, + torch._C.DispatchKey.AutogradMeta, + torch._C.DispatchKey.AutogradXLA, + torch._C.DispatchKey.AutogradLazy, + torch._C.DispatchKey.AutogradIPU, + torch._C.DispatchKey.AutogradXPU, + torch._C.DispatchKey.AutogradMPS, + torch._C.DispatchKey.AutogradHPU, + torch._C.DispatchKey.AutogradPrivateUse1, + torch._C.DispatchKey.AutogradPrivateUse2, + torch._C.DispatchKey.AutogradPrivateUse3, +] + + +@contextmanager +def _override_composite_implicit_decomp(ops_to_preserve, decomp_table, safe=True): + # This function overrides CompositeImplicitAutograd decomp for + # functional composite ops that user specified. Ideally we want to not-decompose + # ALL composite ops but today's C++ functinalization relies on + # the fact that it is working with the opset after decomp is run. + # Hence we can only do it for functional ops. One caveat is that + # there are some composite ops that lie about their schema (claimed to be + # functional but not really aka dropout), for these cases, we just decompose. + + # When safe=False, we will assume that ops_to_preserve can be mutating/aliasing + # and their usual decompositions need to be shadowed rather than overridden. + # Thus we will avoid asserting that they are valid to preserve, and will not + # replace their CompositeImplicitAutograd kernels with NotImplemented. + # The only current users of this mode are variants of aten::to that we will + # replace with aten::_to_copy in FunctionalTensorMode.__torch_dispatch__. + + saved_tables = {} + patched_ops = set() + removed_decomps = {} + for op_overload in ops_to_preserve: + # Our strategy for deciding if we can preserve CIA is following: + # 1. The op should be known statically that it is functional + # 2. If it is maybe aliasing, we decompose because we must know if an op + # is mutating or aliasing. + # TODO (tmanlaibaatar) make this utility function and share it with functional_tensor + # decomp part. 
(https://github.com/pytorch/pytorch/issues/129431) + def assert_valid_to_preserve(op_overload): + if op_overload in FunctionalTensor.maybe_aliasing_or_mutating_ops: + raise RuntimeError( + f"We can't detect {op_overload} as a functional op statically, so we can't preserve it" + ) + if op_overload in FunctionalTensor.metadata_fns: + raise RuntimeError( + f"{op_overload} is a metadata query function, " + "it will be preserved implicitly in our tracing system. " + "Please file an issue on github if you see otherwise" + ) + + alias_info = len( + [i for i in op_overload._schema.arguments if i.alias_info is not None] + ) + + is_mutating_or_aliasing = alias_info != 0 or op_overload._schema.is_mutable + + if is_mutating_or_aliasing: + raise RuntimeError( + f"{op_overload} is a mutating/aliasing op, we can't preserve it as is" + ) + + if not torch._C._dispatch_has_kernel(op_overload.name()): + raise RuntimeError( + f"{op_overload} is a TorchScript op, we can't preserve it as is" + ) + + return True + + if safe: + # If we didn't error, it means we can go ahead + assert_valid_to_preserve(op_overload) + + saved_tables[op_overload] = op_overload.py_kernels.copy() + patched_ops.add(op_overload) + + for override_dispatch_key in _AUTOGRAD_ALIAS_BACKEND_KEYS_TO_OVERRIDE: + if override_dispatch_key not in op_overload.py_kernels: + # TODO (tmanlaibaatar)https://github.com/pytorch/pytorch/issues/129430 + op_overload.py_impl(override_dispatch_key)( + autograd_not_implemented(op_overload, deferred_error=True) + ) + if torch._C.DispatchKey.CompositeImplicitAutograd in op_overload.py_kernels: + del op_overload.py_kernels[torch._C.DispatchKey.CompositeImplicitAutograd] + + if safe: + + def _(*args, **kwargs): + return NotImplemented + + op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)(_) + + # For fake tensor prop, we do want to register meta kernel directly + if torch._C.DispatchKey.Meta not in op_overload.py_kernels: + op_overload.py_impl(torch._C.DispatchKey.Meta)( + functools.partial(_register_cia_to_meta, kernel=op_overload) + ) + + if op_overload in decomp_table: + removed_decomps[op_overload] = decomp_table[op_overload] + del decomp_table[op_overload] + + try: + yield + finally: + for op in patched_ops: + op.py_kernels.clear() + op.py_kernels.update(saved_tables[op]) + op._dispatch_cache.clear() + + for op, decomp in removed_decomps.items(): + decomp_table[op] = decomp + + +@contextmanager +def _override_decomp_aten_to_variants(): + # Preserve variants of aten::to understanding that they are mutating/aliasing + # and their CompositeImplicitAutograd kernels will not become NotImplemented. + # We will later replace them with aten._to_copy when functionalizing. 
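+    # Concretely (an illustrative note, not extra behavior): a user-level call
+    # like `x.to(dtype=torch.float16)` keeps its aten.to.dtype kernel while
+    # tracing under this context, and is only rewritten to aten._to_copy later,
+    # in FunctionalTensorMode, as described in _override_composite_implicit_decomp.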
+ with _override_composite_implicit_decomp( + (torch.ops.aten.to.dtype_layout, torch.ops.aten.to.dtype), + {}, + safe=False, + ): + yield + + +def _decompose_and_get_gm_with_new_signature_constants( + ep, + *, + decomp_table: Dict[torch._ops.OperatorBase, Callable], + _preserve_ops: Tuple[torch._ops.OpOverload], + joint_loss_index: Optional[int], +): + from torch._functorch.aot_autograd import aot_export_module + from torch._subclasses.fake_tensor import FakeTensorMode + from torch.export._trace import ( + _export_to_aten_ir, + _fakify_params_buffers, + _ignore_backend_decomps, + _verify_nn_module_stack, + _verify_placeholder_names, + _verify_stack_trace, + ) + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + # TODO Merge this path with inference IR decomp, but it will require some additional work + # so I will leave it for now. T200307782 + if ep.verifier.dialect == "TRAINING": + mod = ep.module() + + fake_args = [] + for node in mod.graph.nodes: + if node.op == "placeholder": + fake_args.append(node.meta["val"]) + + fake_args_unwrapped = pytree.tree_unflatten(fake_args, mod._in_spec) + fake_mode = _detect_fake_mode_from_gm(mod) + if fake_mode is None: + fake_mode = FakeTensorMode(shape_env=ShapeEnv(), export=True) + + # Fix the graph output signature to be tuple if scalar + out_spec = mod._out_spec + + orig_arg_names = mod.graph._codegen.pytree_info.orig_args # type: ignore[attr-defined] + + # aot_export expect the return type to always be a tuple. + if out_spec.type not in (list, tuple): + out_spec = pytree.TreeSpec(tuple, None, [out_spec]) + + mod.graph._codegen = _PyTreeCodeGen( + _PyTreeInfo( + orig_arg_names, + mod._in_spec, + out_spec, + ) + ) + + mod.recompile() + + # the exported module will store constants & non-persistent buffers such that + # retracing treats them as persistent buffers, so we inform the constants lifting pass + # and overwrite the new graph signature using the previous program. 
+ constant_attrs = _collect_and_set_constant_attrs( + ep.graph_signature, ep.constants, mod + ) + + # get params & buffers after excluding constants + fake_params_buffers = _fakify_params_buffers(fake_mode, mod) + + params_buffers_to_node_meta = _collect_param_buffer_metadata(mod) + + with _ignore_backend_decomps(), ( + fake_mode + ), _override_decomp_aten_to_variants(), _override_composite_implicit_decomp( + _preserve_ops, + decomp_table, + ): + aten_export_artifact = _export_to_aten_ir( + mod, + # this requires empty kwargs, but not in pytree.flattened format + ( + *fake_args_unwrapped[0], + *fake_args_unwrapped[1].values(), + ), + {}, + fake_params_buffers, + constant_attrs, + decomp_table=decomp_table, + _check_autograd_state=False, + ) + + gm = aten_export_artifact.gm + new_graph_signature = aten_export_artifact.sig + + _populate_param_buffer_metadata_to_new_gm( + params_buffers_to_node_meta, gm, new_graph_signature + ) + + # overwrite signature for non-persistent buffers + new_graph_signature = _overwrite_signature_for_non_persistent_buffers( + ep.graph_signature, new_graph_signature + ) + + _verify_nn_module_stack(gm) + _verify_stack_trace(gm) + _verify_placeholder_names(gm, new_graph_signature) + + return _remove_unneccessary_copy_op_pass(gm, new_graph_signature) + + old_placeholders = [ + node for node in ep.graph_module.graph.nodes if node.op == "placeholder" + ] + fake_args = [node.meta["val"] for node in old_placeholders] + + buffers_to_remove = [name for name, _ in ep.graph_module.named_buffers()] + for name in buffers_to_remove: + delattr(ep.graph_module, name) + + # TODO(zhxhchen17) Return the new graph_signature directly. + fake_mode = detect_fake_mode(fake_args) + fake_mode = contextlib.nullcontext() if fake_mode is None else fake_mode + with _ignore_backend_decomps(), fake_mode, _override_composite_implicit_decomp( + _preserve_ops, + decomp_table, + ): + gm, graph_signature = aot_export_module( + ep.graph_module, + fake_args, + decompositions=decomp_table, + trace_joint=True if joint_loss_index is not None else False, + output_loss_index=joint_loss_index + if joint_loss_index is not None + else None, + ) + + # Update the signatures with the new placeholder names in case they + # changed when calling aot_export + def update_arg(old_arg, new_ph): + if isinstance(old_arg, ConstantArgument): + return old_arg + elif isinstance(old_arg, TensorArgument): + return TensorArgument(name=new_ph.name) + elif isinstance(old_arg, SymIntArgument): + return SymIntArgument(name=new_ph.name) + raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}") + + new_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"] + new_outputs = list(gm.graph.nodes)[-1].args[0] + + # rename the placeholders + assert len(new_placeholders) == len(old_placeholders) + for old_ph, new_ph in zip(old_placeholders, new_placeholders): + new_ph.name = new_ph.target = old_ph.name + + # handle name collisions with newly decomposed graph nodes + name_map = {ph.name: ph.name for ph in new_placeholders} + for node in gm.graph.nodes: + if node.op == "placeholder": + continue + node.name = _rename_without_collisions(name_map, node.name, node.name) + + # propagate names to higher order op subgraphs + _name_hoo_subgraph_placeholders(gm) + + # Run this pass before creating input/output specs, since size-related CSE/DCE might affect output signature. + # Overwrite output specs afterwards. 
+ from torch._export.passes._node_metadata_hook import ( + _node_metadata_hook, + _set_node_metadata_hook, + ) + from torch._functorch._aot_autograd.input_output_analysis import _graph_output_names + + if not torch._dynamo.config.do_not_emit_runtime_asserts: + stack_trace = ( + 'File "torch/fx/passes/runtime_assert.py", line 24, ' + "in insert_deferred_runtime_asserts" + ) + shape_env = _get_shape_env(gm) + if shape_env is not None: + with _set_node_metadata_hook( + gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) + ): + insert_deferred_runtime_asserts( + gm, + shape_env, + f"exported program: {first_call_function_nn_module_stack(gm.graph)}", + export=True, + ) + + # update output specs + gm.recompile() + for i, name in enumerate(_graph_output_names(gm)): + if isinstance(new_outputs[i], torch.fx.Node): + new_outputs[i].name = name + + # To match the output target with correct input for input mutations + # need to find the old to new placeholder map + old_new_placeholder_map = { + spec.arg.name: new_placeholders[i].name + for i, spec in enumerate(ep.graph_signature.input_specs) + if not isinstance(spec.arg, ConstantArgument) + } + + input_specs = [ + InputSpec( + spec.kind, + update_arg(spec.arg, new_placeholders[i]), + spec.target, + spec.persistent, + ) + for i, spec in enumerate(ep.graph_signature.input_specs) + ] + output_specs = [ + OutputSpec( + spec.kind, + update_arg(spec.arg, new_outputs[i]), + old_new_placeholder_map.get(spec.target, spec.target), + ) + for i, spec in enumerate(ep.graph_signature.output_specs) + ] + + if joint_loss_index is not None: + assert graph_signature.backward_signature is not None + gradients = graph_signature.backward_signature.gradients_to_user_inputs + assert len(graph_signature.user_inputs) == len(ep.graph_signature.input_specs) + specs = { + graph_signature.user_inputs[i]: spec + for i, spec in enumerate(ep.graph_signature.input_specs) + if isinstance(spec.arg, TensorArgument) + } + for i, node in enumerate(new_outputs[len(output_specs) :]): + source = gradients[node.name] + spec = specs[source] # type: ignore[index] + if spec.kind == InputKind.PARAMETER: + kind = OutputKind.GRADIENT_TO_PARAMETER + target = spec.target + elif spec.kind == InputKind.USER_INPUT: + kind = OutputKind.GRADIENT_TO_USER_INPUT + target = source + else: + raise AssertionError(f"Unknown input kind: {spec.kind}") + output_specs.append( + OutputSpec( + kind, + TensorArgument(name=node.name), + target, + ) + ) + + assert len(new_placeholders) == len(old_placeholders) + + new_graph_signature = ExportGraphSignature( + input_specs=input_specs, output_specs=output_specs + ) + # NOTE: aot_export adds symint metadata for placeholders with int + # values; since these become specialized, we replace such metadata with + # the original values. + # Also, set the param/buffer metadata back to the placeholders. + for old_node, new_node in zip(old_placeholders, new_placeholders): + if not isinstance(old_node.meta["val"], torch.Tensor): + new_node.meta["val"] = old_node.meta["val"] + + if ( + new_node.target in new_graph_signature.inputs_to_parameters + or new_node.target in new_graph_signature.inputs_to_buffers + ): + for k, v in old_node.meta.items(): + new_node.meta[k] = v + return gm, new_graph_signature + + +def _remove_unneccessary_copy_op_pass( + gm: torch.fx.GraphModule, new_graph_signature: ExportGraphSignature +) -> Tuple[torch.fx.GraphModule, ExportGraphSignature]: + """ + Removes redundant copy_ node that was introduced due to mutated buffer. 
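+
+    Sketch of the pattern being removed (node names here are illustrative): a
+    mutated buffer reaches the graph output as ``copy_ = torch.ops.aten.copy.default(buf, new_val)``;
+    the pass reroutes users of that ``copy_`` node to ``new_val`` and erases it.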
+ """ + with gm._set_replace_hook(new_graph_signature.get_replace_hook()): + for node in gm.graph.nodes: + if node.op == "output": + args, _ = pytree.tree_flatten(node.args) + for out in args: + if ( + isinstance(out, torch.fx.Node) + and out.name in new_graph_signature.buffers_to_mutate + ): + if ( + out.op == "call_function" + and out.target == torch.ops.aten.copy.default + ): + out.replace_all_uses_with(out.args[1]) # type: ignore[arg-type] + gm.graph.erase_node(out) + gm.recompile() + return gm, new_graph_signature + + +def _common_getitem_elimination_pass( + gm: torch.fx.GraphModule, graph_signature, module_call_graph +): + with gm._set_replace_hook(graph_signature.get_replace_hook()): + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + + node_id: Dict[torch.fx.Node, str] = {} + getitems: Dict[str, torch.fx.Node] = {} + for node in list(module.graph.nodes): + if node.op == "call_function" and node.target == operator.getitem: + source, idx = node.args + new_id = f"{node_id[source]}.{idx}" + if new_id in getitems: + node.replace_all_uses_with(getitems[new_id]) + for entry in module_call_graph: + if entry.signature is not None: + entry.signature.replace_all_uses_with( + node, getitems[new_id] + ) + module.graph.erase_node(node) + else: + getitems[new_id] = node + node_id[node] = new_id + else: + node_id[node] = node.name + + +def _decompose_exported_program( + ep, + *, + decomp_table: Dict[torch._ops.OperatorBase, Callable], + _preserve_ops: Tuple[torch._ops.OpOverload], + joint_loss_index: Optional[int], +): + gm, new_graph_signature = _decompose_and_get_gm_with_new_signature_constants( + ep, + decomp_table=decomp_table, + _preserve_ops=_preserve_ops, + joint_loss_index=joint_loss_index, + ) + + # TODO unfortunately preserving graph-level metadata is not + # working well with aot_export. So we manually copy it. + # (The node-level meta is addressed above.) + gm.meta.update(ep.graph_module.meta) + + new_range_constraints = _get_updated_range_constraints( + gm, + ep.range_constraints, + ) + + exported_program = ExportedProgram( + root=gm, + graph=gm.graph, + graph_signature=new_graph_signature, + state_dict=ep.state_dict, + range_constraints=new_range_constraints, + module_call_graph=copy.deepcopy(ep.module_call_graph), + example_inputs=ep.example_inputs, + constants=ep.constants, + ) + return exported_program + + +class ExportedProgram: + """ + Package of a program from :func:`export`. It contains + an :class:`torch.fx.Graph` that represents Tensor computation, a state_dict containing + tensor values of all lifted parameters and buffers, and various metadata. + + You can call an ExportedProgram like the original callable traced by + :func:`export` with the same calling convention. + + To perform transformations on the graph, use ``.module`` property to access + an :class:`torch.fx.GraphModule`. You can then use + `FX transformation `_ + to rewrite the graph. Afterwards, you can simply use :func:`export` + again to construct a correct ExportedProgram. 
+ """ + + def __init__( + self, + root: Union[torch.nn.Module, Dict[str, Any]], + graph: torch.fx.Graph, + graph_signature: ExportGraphSignature, + state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]], + range_constraints: "Dict[sympy.Symbol, Any]", + module_call_graph: List[ModuleCallEntry], + example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None, + constants: Optional[ + Dict[str, Union[torch.Tensor, FakeScriptObject, torch._C.ScriptObject]] + ] = None, + *, + verifiers: Optional[List[Type[Verifier]]] = None, + ): + # Remove codegen related things from the graph. It should just be a flat graph. + graph._codegen = torch.fx.graph.CodeGen() + self._graph_module = _create_graph_module_for_export(root, graph) + if isinstance(root, torch.fx.GraphModule): + self._graph_module.meta.update(root.meta) + + _common_getitem_elimination_pass( + self._graph_module, graph_signature, module_call_graph + ) + self._graph_signature: ExportGraphSignature = graph_signature + self._state_dict: Dict[str, Any] = state_dict + self._range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints + assert module_call_graph is not None + self._module_call_graph: List[ModuleCallEntry] = module_call_graph + self._example_inputs = example_inputs + + self._constants = constants or {} + + verifiers = verifiers or [Verifier] + assert all(issubclass(v, Verifier) for v in verifiers) + self._verifiers = verifiers + # Validate should be always the last step of the constructor. + self.validate() + + @property + @compatibility(is_backward_compatible=False) + def graph_module(self): + return self._graph_module + + @property + @compatibility(is_backward_compatible=False) + def graph(self): + return self.graph_module.graph + + @property + @compatibility(is_backward_compatible=False) + def graph_signature(self): + return self._graph_signature + + @property + @compatibility(is_backward_compatible=False) + def state_dict(self): + return self._state_dict + + @compatibility(is_backward_compatible=False) + def parameters(self) -> Iterator[torch.nn.Parameter]: + """ + Returns an iterator over original module's parameters. + """ + for _, param in self.named_parameters(): + yield param + + @compatibility(is_backward_compatible=False) + def named_parameters(self) -> Iterator[Tuple[str, torch.nn.Parameter]]: + """ + Returns an iterator over original module parameters, yielding + both the name of the parameter as well as the parameter itself. + """ + for param_name in self.graph_signature.parameters: + yield param_name, self.state_dict[param_name] + + @compatibility(is_backward_compatible=False) + def buffers(self) -> Iterator[torch.Tensor]: + """ + Returns an iterator over original module buffers. + """ + for _, buf in self.named_buffers(): + yield buf + + @compatibility(is_backward_compatible=False) + def named_buffers(self) -> Iterator[Tuple[str, torch.Tensor]]: + """ + Returns an iterator over original module buffers, yielding + both the name of the buffer as well as the buffer itself. 
+ """ + non_persistent_buffers = set(self.graph_signature.non_persistent_buffers) + for buffer_name in self.graph_signature.buffers: + if buffer_name in non_persistent_buffers: + yield buffer_name, self.constants[buffer_name] + else: + yield buffer_name, self.state_dict[buffer_name] + + @property + @compatibility(is_backward_compatible=False) + def range_constraints(self): + return self._range_constraints + + @property + @compatibility(is_backward_compatible=False) + def module_call_graph(self): + return self._module_call_graph + + @property + @compatibility(is_backward_compatible=False) + def example_inputs(self): + return self._example_inputs + + @property + @compatibility(is_backward_compatible=False) + def call_spec(self): + CallSpec = namedtuple("CallSpec", ["in_spec", "out_spec"]) + + if len(self.module_call_graph) == 0: + return CallSpec(in_spec=None, out_spec=None) + assert self.module_call_graph[0].fqn == "" + return CallSpec( + in_spec=self.module_call_graph[0].signature.in_spec, + out_spec=self.module_call_graph[0].signature.out_spec, + ) + + @property + @compatibility(is_backward_compatible=False) + def verifier(self) -> Any: + return self._verifiers[0] + + @property + @compatibility(is_backward_compatible=False) + def dialect(self) -> str: + assert self._verifiers is not None + return self._verifiers[0].dialect + + @property + @compatibility(is_backward_compatible=False) + def verifiers(self): + return self._verifiers + + @property + @compatibility(is_backward_compatible=False) + def tensor_constants(self): + return self._constants + + @property + @compatibility(is_backward_compatible=False) + def constants(self): + return self._constants + + def _get_flat_args_with_check(self, args, kwargs): + """Flatten args, kwargs using pytree, then, check specs. + + Args: + args: List[Any] original args passed to __call__ + kwargs: Dict[str, Any] original kwargs passed to __call + + Returns: + A tuple of (flat_args, received_spec) + flat_args is flattend args / kwargs + received_spec is the pytree spec produced while flattening the + tuple (args, kwargs) + """ + in_spec = self.call_spec.in_spec + if in_spec is not None: + kwargs = reorder_kwargs(kwargs, in_spec) + flat_args_with_path, received_spec = pytree.tree_flatten_with_path( + (args, kwargs) + ) # type: ignore[possibly-undefined] + self._check_input_constraints(flat_args_with_path) + flat_args = tuple(x[1] for x in flat_args_with_path) + return flat_args, received_spec + + def _graph_module_flat_inputs(self, args: Any, kwargs: Any) -> Any: + """Transform args, kwargs of __call__ to args for graph_module. + + self.graph_module takes stuff from state dict as inputs. + The invariant is for ep: ExportedProgram is + ep(args, kwargs) == + ep.postprocess(ep.graph_module(ep.graph_module_flat_inputs(args, kwargs))) + """ + + in_spec = self.call_spec.in_spec + flat_args, received_spec = self._get_flat_args_with_check(args, kwargs) + if in_spec is not None and not is_equivalent( + received_spec, in_spec, _fx_collection_equivalence_fn + ): + raise ValueError( + "Trying to flatten user inputs with exported input tree spec: \n" + f"{in_spec}\n" + "but actually got inputs with tree spec of: \n" + f"{received_spec}" + ) + + additional_inputs = [] + for input_ in self.graph_signature.input_specs: + if input_.kind == InputKind.USER_INPUT: + continue + elif input_.kind in ( + InputKind.PARAMETER, + InputKind.BUFFER, + ): + if input_.persistent is False: + # This is a non-persistent buffer, grab it from our + # constants instead of the state dict. 
+ additional_inputs.append(self.constants[input_.target]) + else: + additional_inputs.append(self.state_dict[input_.target]) + elif input_.kind in ( + InputKind.CONSTANT_TENSOR, + InputKind.CUSTOM_OBJ, + ): + additional_inputs.append(self.constants[input_.target]) + additional_inputs = tuple(additional_inputs) + + # NOTE: calling convention is first params, then buffers, then args as user supplied them. + # See: torch/_functorch/aot_autograd.py#L1034 + return additional_inputs + flat_args + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + raise RuntimeError( + "Unable to call ExportedProgram directly. " + "You should use `exported_program.module()` instead." + ) + + def _postprocess_graph_module_outputs(self, res, orig_args, orig_kwargs): + """Process potential mutations to the input. + + Because self.graph_module is functional, so mutations has to be written + back after execution of graph_module. + """ + import torch._export.error as error + + flat_args, _ = self._get_flat_args_with_check(orig_args, orig_kwargs) + if self.call_spec.out_spec is not None: + buffer_mutation = self.graph_signature.buffers_to_mutate + user_input_mutation = self.graph_signature.user_inputs_to_mutate + num_mutated = len(buffer_mutation) + len(user_input_mutation) + mutated_values = res[:num_mutated] + + # Exclude dependency token from final result. + assertion_dep_token = self.graph_signature.assertion_dep_token + if assertion_dep_token is not None: + assertion_dep_token_index = next(iter(assertion_dep_token.keys())) + res = res[:assertion_dep_token_index] + + res = res[num_mutated:] + try: + res = pytree.tree_unflatten(res, self.call_spec.out_spec) + except Exception: + _, received_spec = pytree.tree_flatten(res) + raise error.InternalError( # noqa: B904 + "Trying to flatten user outputs with exported output tree spec: \n" + f"{self.call_spec.out_spec}\n" + "but actually got outputs with tree spec of: \n" + f"{received_spec}" + ) + finally: + user_inputs = [ + spec + for spec in self.graph_signature.input_specs + if spec.kind == InputKind.USER_INPUT + ] + for i, value in enumerate(mutated_values): + output_spec = self.graph_signature.output_specs[i] + if output_spec.kind == OutputKind.BUFFER_MUTATION: + assert output_spec.target is not None + self.state_dict[output_spec.target] = value + elif output_spec.kind == OutputKind.USER_INPUT_MUTATION: + assert output_spec.target is not None + index = next( + i + for i, spec in enumerate(user_inputs) + if spec.arg.name == output_spec.target + ) + flat_args[index].copy_(value) + else: + raise AssertionError(f"Unexpected kind: {output_spec.kind}") + return res + + def __str__(self) -> str: + graph_module = self.graph_module.print_readable( + print_output=False, colored=False + ).replace("\n", "\n ") + string = ( + "ExportedProgram:\n" + f" {graph_module}\n" + f"Graph signature: {self.graph_signature}\n" + f"Range constraints: {self.range_constraints}\n" + ) + return string + + def module(self) -> torch.nn.Module: + """ + Returns a self contained GraphModule with all the parameters/buffers inlined. 
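+
+        A short sketch of typical use (names here are illustrative)::
+
+            ep = torch.export.export(mod, example_args)
+            unlifted = ep.module()          # plain nn.Module; params/buffers are attributes
+            out = unlifted(*example_args)   # same calling convention as the original module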
+ """ + from ._unlift import _unlift_exported_program_lifted_states + + module = _unlift_exported_program_lifted_states(self) + + def _train(self, mode: bool = True): + raise NotImplementedError("Calling train() is not supported yet.") + + def _eval(self, mode: bool = True): + raise NotImplementedError("Calling eval() is not supported yet.") + + module.train = types.MethodType(_train, module) # type: ignore[method-assign] + module.eval = types.MethodType(_eval, module) # type: ignore[method-assign] + return module + + def _num_lifted_params_buffers(self): + return next( + ( + i + for i, s in enumerate(self._graph_signature.input_specs) + if s.kind == InputKind.USER_INPUT + ), + len(self._graph_signature.input_specs), + ) + + @_disable_prexisiting_fake_mode + def run_decompositions( + self, + decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None, + _preserve_ops: Tuple[torch._ops.OpOverload, ...] = (), + ) -> "ExportedProgram": + """ + Run a set of decompositions on the exported program and returns a new + exported program. By default we will run the Core ATen decompositions to + get operators in the + `Core ATen Operator Set `_. + + For now, we do not decompose joint graphs. + """ + from torch._decomp import core_aten_decompositions + + if decomp_table is None: + decomp_table = core_aten_decompositions() + + return _decompose_exported_program( + self, + decomp_table=decomp_table, + _preserve_ops=_preserve_ops, # type: ignore[arg-type] + joint_loss_index=None, + ) + + def _transform_do_not_use(self, *passes: PassType) -> "ExportedProgram": + pm = PassManager(list(passes)) + # Since we abstractly run the passes, we need to disable backend decomp here + # again. + from torch.export._trace import _ignore_backend_decomps + + with _ignore_backend_decomps(): + res = pm(self.graph_module) + transformed_gm = res.graph_module if res is not None else self.graph_module + assert transformed_gm is not None + + if transformed_gm is self.graph_module and not res.modified: + return self + + # TODO(zhxchen17) Remove this. + def _get_updated_graph_signature( + old_signature: ExportGraphSignature, + new_gm: torch.fx.GraphModule, + ) -> ExportGraphSignature: + """ + Update the graph signature's user_input/user_outputs. 
+ """ + new_input_specs = [] + for i, node in enumerate(new_gm.graph.nodes): + if node.op != "placeholder": + break + + assert i < len( + old_signature.input_specs + ), "Number of inputs changed after transformation" + old_input_spec = old_signature.input_specs[i] + arg = ( + old_input_spec.arg + if isinstance( + old_input_spec.arg, (ConstantArgument, CustomObjArgument) + ) + else type(old_input_spec.arg)(node.name) + ) + new_input_specs.append( + InputSpec( + old_input_spec.kind, + arg, + old_input_spec.target, + old_input_spec.persistent, + ) + ) + + output_node = list(new_gm.graph.nodes)[-1] + assert output_node.op == "output" + + new_output_specs = [] + for i, node in enumerate(output_node.args[0]): + assert i < len( + old_signature.output_specs + ), "Number of outputs changed after transformation" + old_output_spec = old_signature.output_specs[i] + arg = ( + old_output_spec.arg + if isinstance( + old_output_spec.arg, (ConstantArgument, CustomObjArgument) + ) + else type(old_output_spec.arg)(node.name) + ) + new_output_specs.append( + OutputSpec(old_output_spec.kind, arg, old_output_spec.target) + ) + + new_signature = ExportGraphSignature( + input_specs=new_input_specs, output_specs=new_output_specs + ) + return new_signature + + transformed_ep = ExportedProgram( + root=transformed_gm, + graph=transformed_gm.graph, + graph_signature=_get_updated_graph_signature( + self.graph_signature, transformed_gm + ), + state_dict=self.state_dict, + range_constraints=_get_updated_range_constraints( + transformed_gm, + self.range_constraints, + ), + module_call_graph=copy.deepcopy(self._module_call_graph), + example_inputs=self.example_inputs, + constants=self.constants, + verifiers=self.verifiers, + ) + transformed_ep.graph_module.meta.update(self.graph_module.meta) + transformed_ep.graph_module.meta.update(res.graph_module.meta) + return transformed_ep + + def _check_input_constraints(self, flat_args_with_path): + from torch._export.utils import _check_input_constraints_for_graph + + placeholders = [p for p in self.graph.nodes if p.op == "placeholder"] + input_placeholders = [ + p + for p, s in zip(placeholders, self.graph_signature.input_specs) + if s.kind == InputKind.USER_INPUT + ] + _check_input_constraints_for_graph( + input_placeholders, flat_args_with_path, self.range_constraints + ) + + @compatibility(is_backward_compatible=False) + def validate(self): + self._validate() + + # TODO: remove this + @final + def _validate(self): + assert ( + len(self.verifiers) > 0 + ), "ExportedProgram must have at least one verifier." + for v in self.verifiers: + v().check(self) + + # TODO(zhxchen17) Formalize this. 
+ def _update( + self, graph_module, graph_signature, *, state_dict=None, verifiers=None + ) -> "ExportedProgram": + return ExportedProgram( + root=graph_module, + graph=graph_module.graph, + graph_signature=graph_signature, + state_dict=state_dict if state_dict is not None else self.state_dict, + range_constraints=copy.deepcopy(self.range_constraints), + module_call_graph=copy.deepcopy(self._module_call_graph), + example_inputs=self.example_inputs, + constants=self.constants, + verifiers=verifiers if verifiers is not None else self.verifiers, + ) + + +def _get_shape_env(gm): + vals = [ + node.meta["val"] + for node in gm.graph.nodes + if node.meta.get("val", None) is not None + ] + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(vals) + if fake_mode is not None: + return fake_mode.shape_env + for v in vals: + if isinstance(v, torch.SymInt): + return v.node.shape_env + + +def _get_updated_range_constraints( + gm: torch.fx.GraphModule, + old_range_constraints: "Optional[Dict[sympy.Symbol, Any]]" = None, +) -> "Dict[sympy.Symbol, Any]": + assert old_range_constraints is not None + + shape_env = _get_shape_env(gm) + if shape_env is None: + return {} + + range_constraints = copy.copy(old_range_constraints) + range_constraints = { + k: v for k, v in range_constraints.items() if k not in shape_env.replacements + } + # Only when we have an unbacked symint, and it's used as constructor inputs, + # runtime_var_to_range will make a difference compated to var_to_range. + # e.g. [2, oo) -> [0, oo) + for k, v in shape_env.var_to_range.items(): + if k not in shape_env.replacements and k not in range_constraints: + range_constraints[k] = v + return range_constraints + + +def _create_graph_module_for_export(root, graph): + try: + gm = torch.fx.GraphModule(root, graph) + except SyntaxError: + # If custom objects stored in memory are being used in the graph, + # the generated python code will result in a syntax error on the custom + # object, since it is unable to parse the in-memory object. However + # we can still run the graph eagerly through torch.fx.Interpreter, + # so we will bypass this error. + warnings.warn( + "Unable to execute the generated python source code from " + "the graph. The graph module will no longer be directly callable, " + "but you can still run the ExportedProgram, and if needed, you can " + "run the graph module eagerly using torch.fx.Interpreter." 
+ ) + gm = torch.fx.GraphModule(root, torch.fx.Graph()) + gm._graph = graph + + return gm diff --git a/lib/python3.10/site-packages/torch/export/graph_signature.py b/lib/python3.10/site-packages/torch/export/graph_signature.py new file mode 100644 index 0000000000000000000000000000000000000000..4730cf6febcddfdfbfa0407ec0cd8895507d36a0 --- /dev/null +++ b/lib/python3.10/site-packages/torch/export/graph_signature.py @@ -0,0 +1,593 @@ +# mypy: allow-untyped-defs +import dataclasses +from enum import auto, Enum +from typing import Collection, Dict, List, Mapping, Optional, Set, TYPE_CHECKING, Union + +from torch._library.fake_class_registry import FakeScriptObject + + +if TYPE_CHECKING: + import torch + from torch._functorch._aot_autograd.schemas import GraphSignature + +__all__ = [ + "ConstantArgument", + "CustomObjArgument", + "ExportBackwardSignature", + "ExportGraphSignature", + "InputKind", + "InputSpec", + "OutputKind", + "OutputSpec", + "SymIntArgument", + "TensorArgument", +] + + +@dataclasses.dataclass +class TensorArgument: + name: str + + +@dataclasses.dataclass +class TokenArgument: + name: str + + +@dataclasses.dataclass +class SymIntArgument: + name: str + + +@dataclasses.dataclass +class CustomObjArgument: + name: str + class_fqn: str + fake_val: Optional[FakeScriptObject] = None + + +@dataclasses.dataclass +class ConstantArgument: + name: str + value: Union[int, float, bool, str, None] + + +ArgumentSpec = Union[ + TensorArgument, + SymIntArgument, + ConstantArgument, + CustomObjArgument, + TokenArgument, +] + + +class InputKind(Enum): + USER_INPUT = auto() + PARAMETER = auto() + BUFFER = auto() + CONSTANT_TENSOR = auto() + CUSTOM_OBJ = auto() + TOKEN = auto() + + +@dataclasses.dataclass +class InputSpec: + kind: InputKind + arg: ArgumentSpec + target: Optional[str] + persistent: Optional[bool] = None + + def __post_init__(self): + if self.kind == InputKind.BUFFER: + assert ( + self.persistent is not None + ), "Failed to specify persistent flag on BUFFER." + assert isinstance( + self.arg, + ( + TensorArgument, + SymIntArgument, + ConstantArgument, + CustomObjArgument, + TokenArgument, + ), + ), f"got {type(self.arg)}" + + +class OutputKind(Enum): + USER_OUTPUT = auto() + LOSS_OUTPUT = auto() + BUFFER_MUTATION = auto() + GRADIENT_TO_PARAMETER = auto() + GRADIENT_TO_USER_INPUT = auto() + USER_INPUT_MUTATION = auto() + TOKEN = auto() + + +@dataclasses.dataclass +class OutputSpec: + kind: OutputKind + arg: ArgumentSpec + target: Optional[str] + + def __post_init__(self): + assert isinstance( + self.arg, + ( + TensorArgument, + SymIntArgument, + ConstantArgument, + TokenArgument, + CustomObjArgument, + ), + ), self.arg + + +@dataclasses.dataclass +class ExportBackwardSignature: + gradients_to_parameters: Dict[str, str] + gradients_to_user_inputs: Dict[str, str] + loss_output: str + + +@dataclasses.dataclass +class ExportGraphSignature: + """ + :class:`ExportGraphSignature` models the input/output signature of Export Graph, + which is a fx.Graph with stronger invariants gurantees. + + Export Graph is functional and does not access "states" like parameters + or buffers within the graph via ``getattr`` nodes. Instead, :func:`export` + gurantees that parameters, buffers, and constant tensors are lifted out of + the graph as inputs. Similarly, any mutations to buffers are not included + in the graph either, instead the updated values of mutated buffers are + modeled as additional outputs of Export Graph. 
+ + The ordering of all inputs and outputs are:: + + Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] + Outputs = [*mutated_inputs, *flattened_user_outputs] + + e.g. If following module is exported:: + + class CustomModule(nn.Module): + def __init__(self) -> None: + super(CustomModule, self).__init__() + + # Define a parameter + self.my_parameter = nn.Parameter(torch.tensor(2.0)) + + # Define two buffers + self.register_buffer('my_buffer1', torch.tensor(3.0)) + self.register_buffer('my_buffer2', torch.tensor(4.0)) + + def forward(self, x1, x2): + # Use the parameter, buffers, and both inputs in the forward method + output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 + + # Mutate one of the buffers (e.g., increment it by 1) + self.my_buffer2.add_(1.0) # In-place addition + + return output + + Resulting Graph would be:: + + graph(): + %arg0_1 := placeholder[target=arg0_1] + %arg1_1 := placeholder[target=arg1_1] + %arg2_1 := placeholder[target=arg2_1] + %arg3_1 := placeholder[target=arg3_1] + %arg4_1 := placeholder[target=arg4_1] + %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {}) + %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {}) + %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {}) + %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {}) + %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {}) + return (add_tensor_2, add_tensor_1) + + Resulting ExportGraphSignature would be:: + + ExportGraphSignature( + input_specs=[ + InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'), + InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'), + InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'), + InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None), + InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None) + ], + output_specs=[ + OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'), + OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None) + ] + ) + """ + + input_specs: List[InputSpec] + output_specs: List[OutputSpec] + + # A list of parameters uniquely identified by mangled fully qualified name + @property + def parameters(self) -> Collection[str]: + return tuple( + s.target + for s in self.input_specs + if s.kind == InputKind.PARAMETER + if isinstance(s.target, str) + ) + + # A list of buffers uniquely identified by mangled fully qualified name + @property + def buffers(self) -> Collection[str]: + return tuple( + s.target + for s in self.input_specs + if s.kind == InputKind.BUFFER + if isinstance(s.target, str) + ) + + @property + def non_persistent_buffers(self) -> Collection[str]: + return tuple( + s.target + for s in self.input_specs + if s.kind == InputKind.BUFFER + if s.persistent is False + if isinstance(s.target, str) + ) + + # A list of lifted constant tensors + @property + def lifted_tensor_constants(self) -> Collection[str]: + return tuple( + s.target + for s in self.input_specs + if s.kind == InputKind.CONSTANT_TENSOR + if isinstance(s.target, str) + ) + + @property + def lifted_custom_objs(self) -> Collection[str]: + return tuple( + s.target + for s in self.input_specs + if s.kind == InputKind.CUSTOM_OBJ + if isinstance(s.target, str) + ) + + # Graph node names of
pytree-flattened inputs of original program + @property + def user_inputs(self) -> Collection[Union[int, float, bool, None, str]]: + user_inputs: List[Union[int, float, bool, None, str]] = [] + for s in self.input_specs: + if s.kind != InputKind.USER_INPUT: + continue + + if isinstance(s.arg, (TensorArgument, SymIntArgument, CustomObjArgument)): + user_inputs.append(s.arg.name) + elif isinstance(s.arg, ConstantArgument): + user_inputs.append(s.arg.value) + else: + raise RuntimeError(f"{s.arg} is not a valid user inputs") + return tuple(user_inputs) + + # Graph node names of pytree-flattened outputs of original program + @property + def user_outputs(self) -> Collection[Union[int, float, bool, None, str]]: + user_outputs: List[Union[int, float, bool, None, str]] = [] + for s in self.output_specs: + if s.kind != OutputKind.USER_OUTPUT: + continue + + if isinstance(s.arg, (TensorArgument, SymIntArgument)): + user_outputs.append(s.arg.name) + elif isinstance(s.arg, ConstantArgument): + user_outputs.append(s.arg.value) + elif isinstance(s.arg, CustomObjArgument): + user_outputs.append(s.arg.name) + else: + raise RuntimeError(f"{s.arg} is not a valid user output") + return tuple(user_outputs) + + # A dictionary mapping graph input node names to parameters. If a graph input + # name is found in this dictionary, it is guranteed to be a lifted parameter. + @property + def inputs_to_parameters(self) -> Mapping[str, str]: + return _immutable_dict( + (s.arg.name, s.target) + for s in self.input_specs + if s.kind == InputKind.PARAMETER + and isinstance(s.arg, TensorArgument) + and isinstance(s.target, str) + ) + + # A dictionary mapping graph input node names to buffers. If a graph input + # name is found in this dictionary, it is guranteed to be a lifted buffer. + @property + def inputs_to_buffers(self) -> Mapping[str, str]: + return _immutable_dict( + (s.arg.name, s.target) # type: ignore[union-attr, misc] + for s in self.input_specs + if s.kind == InputKind.BUFFER + and isinstance(s.arg, TensorArgument) + and isinstance(s.target, str) + ) + + # A dictionary mapping graph output node names to buffers that are mutated in the + # original program. Buffers that are not mutated will not be found in this dictionary. + @property + def buffers_to_mutate(self) -> Mapping[str, str]: + return _immutable_dict( + (s.arg.name, s.target) + for s in self.output_specs + if s.kind == OutputKind.BUFFER_MUTATION + and isinstance(s.arg, TensorArgument) + and isinstance(s.target, str) + ) + + @property + def user_inputs_to_mutate(self) -> Mapping[str, str]: + return _immutable_dict( + (s.arg.name, s.target) + for s in self.output_specs + if s.kind == OutputKind.USER_INPUT_MUTATION + and isinstance(s.arg, TensorArgument) + and isinstance(s.target, str) + ) + + # A dictionary mapping graph input node names to lifted tensor constants. 
+ @property + def inputs_to_lifted_tensor_constants(self) -> Mapping[str, str]: + return _immutable_dict( + (s.arg.name, s.target) + for s in self.input_specs + if s.kind == InputKind.CONSTANT_TENSOR + and isinstance(s.arg, TensorArgument) + and isinstance(s.target, str) + ) + + @property + def inputs_to_lifted_custom_objs(self) -> Mapping[str, str]: + return _immutable_dict( + (s.arg.name, s.target) + for s in self.input_specs + if s.kind == InputKind.CUSTOM_OBJ + and isinstance(s.arg, CustomObjArgument) + and isinstance(s.target, str) + ) + + @property + def backward_signature(self) -> Optional[ExportBackwardSignature]: + loss_output = None + gradients_to_parameters: Dict[str, str] = {} + gradients_to_user_inputs: Dict[str, str] = {} + for spec in self.output_specs: + if spec.kind == OutputKind.LOSS_OUTPUT: + assert loss_output is None + assert isinstance(spec.arg, TensorArgument) + loss_output = spec.arg.name + elif spec.kind == OutputKind.GRADIENT_TO_PARAMETER: + assert isinstance(spec.target, str) + assert isinstance(spec.arg, TensorArgument) + gradients_to_parameters[spec.arg.name] = spec.target + elif spec.kind == OutputKind.GRADIENT_TO_USER_INPUT: + assert isinstance(spec.target, str) + assert isinstance(spec.arg, TensorArgument) + gradients_to_user_inputs[spec.arg.name] = spec.target + + if loss_output is None: + return None + + return ExportBackwardSignature( + loss_output=loss_output, + gradients_to_parameters=gradients_to_parameters, + gradients_to_user_inputs=gradients_to_user_inputs, + ) + + # Map from assertion dependency token index to assertion dep token output + # name in output. The shape of output after aot_autograd will be like: + # (updated_inputs, user_outputs, dep_token). + @property + def assertion_dep_token(self) -> Optional[Mapping[int, str]]: + return None + + @property + def input_tokens(self) -> Collection[str]: + input_tokens = [] + for s in self.input_specs: + if s.kind == InputKind.TOKEN: + assert isinstance(s.arg, TokenArgument) + input_tokens.append(s.arg.name) + return tuple(input_tokens) + + @property + def output_tokens(self) -> Collection[str]: + output_tokens = [] + for s in self.output_specs: + if s.kind == OutputKind.TOKEN: + assert isinstance(s.arg, TokenArgument) + output_tokens.append(s.arg.name) + return tuple(output_tokens) + + def __post_init__(self) -> None: + assertion_dep_token = self.assertion_dep_token + if assertion_dep_token is None: + return + assert len(assertion_dep_token) == 1 + assertion_dep_token_index = next(iter(assertion_dep_token.keys())) + assert ( + len(self.user_outputs) + len(self.buffers_to_mutate) + == assertion_dep_token_index + ) + + def replace_all_uses(self, old: str, new: str): + """ + Replace all uses of the old name with new name in the signature. + """ + assert isinstance(old, str) + assert isinstance(new, str) + arg_types = (TensorArgument, SymIntArgument, CustomObjArgument, TokenArgument) + for o in self.output_specs: + if isinstance(o.arg, arg_types): + if o.arg.name == old: + o.arg.name = new + for i in self.input_specs: + if isinstance(i.arg, arg_types): + if i.arg.name == old: + i.arg.name = new + + def get_replace_hook(self): + def _(old, new, user): + if user.op in ("output", "input"): + self.replace_all_uses(old.name, new) + + return _ + + +def _immutable_dict(items): + """ + Creates a mapping where items cannot be added, deleted, or updated. + NOTE: The immutability is shallow (like tuple is an immutable collection). 
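+
+    A minimal sketch of the intended behavior (illustrative only)::
+
+        d = _immutable_dict([("x", 1)])
+        d["x"]      # 1
+        d["y"] = 2  # raises TypeError: mappingproxy does not support item assignment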
+ """ + from types import MappingProxyType + + return MappingProxyType(dict(items)) + + +def _make_argument_spec(node, token_names) -> ArgumentSpec: + from torch import ScriptObject, SymInt + from torch._library.fake_class_registry import FakeScriptObject + from torch._subclasses.fake_tensor import FakeTensor + + if isinstance(node, (int, bool, float, type(None), str)): + # For const outputs we just directly return this + return ConstantArgument(name="", value=node) + + assert ( + "val" in node.meta + ), f"{node} is not a constant or a node with a 'val' metadata field" + val = node.meta["val"] + if node.name in token_names: + return TokenArgument(name=node.name) + elif isinstance(val, FakeTensor): + return TensorArgument(name=node.name) + elif isinstance(val, SymInt): + return SymIntArgument(name=node.name) + elif isinstance(val, ScriptObject): + return CustomObjArgument(name=node.name, class_fqn=val._type().qualified_name()) # type: ignore[attr-defined] + elif isinstance(val, FakeScriptObject): + return CustomObjArgument( + name=node.name, class_fqn=val.script_class_name, fake_val=val + ) + elif isinstance(val, (int, bool, str, float, type(None))): + return ConstantArgument(name=node.name, value=val) + else: + raise AssertionError( + f"Encountered an unsupported object of type {type(val)} " + f"while writing the metadata for exported program" + ) + + +def _convert_to_export_graph_signature( + graph_signature: "GraphSignature", + gm: "torch.fx.GraphModule", + non_persistent_buffers: Set[str], +) -> "ExportGraphSignature": + from torch.utils import _pytree as pytree + + is_joint = graph_signature.backward_signature is not None + + # unpack objects + user_inputs = set(graph_signature.user_inputs) + inputs_to_parameters = graph_signature.inputs_to_parameters + inputs_to_buffers = graph_signature.inputs_to_buffers + user_outputs = set(graph_signature.user_outputs) + buffer_mutations = graph_signature.buffers_to_mutate + user_input_mutations = graph_signature.user_inputs_to_mutate + grad_params = graph_signature.backward_signature.gradients_to_parameter if is_joint else {} # type: ignore[union-attr] + grad_user_inputs = graph_signature.backward_signature.gradients_to_user_inputs if is_joint else {} # type: ignore[union-attr] + loss_output = graph_signature.backward_signature.loss_output if is_joint else None # type: ignore[union-attr] + input_tokens = graph_signature.input_tokens + output_tokens = graph_signature.output_tokens + + inputs = [ + _make_argument_spec(node, input_tokens) + for node in gm.graph.nodes + if node.op == "placeholder" + ] + outputs = [ + _make_argument_spec(node, output_tokens) + for node in pytree.tree_leaves(next(iter(reversed(gm.graph.nodes))).args) + ] + + def to_input_spec(inp: ArgumentSpec) -> InputSpec: + if isinstance(inp, TokenArgument): + return InputSpec(kind=InputKind.TOKEN, arg=inp, target=None) + + if not isinstance(inp, TensorArgument): + return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None) + name = inp.name + if name in user_inputs: + return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None) + elif name in inputs_to_parameters: + return InputSpec( + kind=InputKind.PARAMETER, + arg=inp, + target=inputs_to_parameters[name], # type: ignore[index] + ) + elif name in inputs_to_buffers: + return InputSpec( + kind=InputKind.BUFFER, + arg=inp, + target=inputs_to_buffers[name], # type: ignore[index] + persistent=(inputs_to_buffers[name] not in non_persistent_buffers), # type: ignore[index] + ) + else: + raise AssertionError(f"Unknown tensor 
input kind: {name}") + + def to_output_spec(idx: int, o: ArgumentSpec) -> OutputSpec: + if isinstance(o, TokenArgument): + return OutputSpec(kind=OutputKind.TOKEN, arg=o, target=None) + + if not isinstance(o, TensorArgument): + return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None) + name = o.name + if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens): + if name in buffer_mutations: + return OutputSpec( + kind=OutputKind.BUFFER_MUTATION, + arg=o, + target=buffer_mutations[name], # type: ignore[index] + ) + elif name in user_input_mutations: + return OutputSpec( + kind=OutputKind.USER_INPUT_MUTATION, + arg=o, + target=user_input_mutations[name], # type: ignore[index] + ) + else: + raise AssertionError(f"Unknown tensor mutation kind: {name}") + else: + if name in user_outputs: + return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None) + + elif name in grad_params: + return OutputSpec( + kind=OutputKind.GRADIENT_TO_PARAMETER, + arg=o, + target=grad_params[name], + ) + elif name in grad_user_inputs: + return OutputSpec( + kind=OutputKind.GRADIENT_TO_USER_INPUT, + arg=o, + target=grad_user_inputs[name], + ) + elif name == loss_output: + return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None) + + else: + raise AssertionError(f"Unknown tensor output kind: {name}") + + input_specs = [to_input_spec(inp) for inp in inputs] + output_specs = [to_output_spec(idx, o) for idx, o in enumerate(outputs)] + return ExportGraphSignature(input_specs=input_specs, output_specs=output_specs) diff --git a/lib/python3.10/site-packages/torch/export/unflatten.py b/lib/python3.10/site-packages/torch/export/unflatten.py new file mode 100644 index 0000000000000000000000000000000000000000..992818f1758048ae7f3c0de41971ff5317c0b55d --- /dev/null +++ b/lib/python3.10/site-packages/torch/export/unflatten.py @@ -0,0 +1,1258 @@ +# mypy: allow-untyped-defs +import abc +import copy +import operator +from collections import defaultdict +from contextlib import contextmanager +from copy import deepcopy +from enum import Enum +from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union + +import torch +import torch.fx._pytree as fx_pytree +import torch.utils._pytree as pytree +from torch._library.fake_class_registry import FakeScriptObject +from torch.export._tree_utils import reorder_kwargs +from torch.export.exported_program import ( + ConstantArgument, + ExportedProgram, + InputKind, + ModuleCallSignature, + SymIntArgument, + TensorArgument, +) +from torch.fx._symbolic_trace import is_fx_tracing +from torch.fx.graph_module import _print_readable +from torch.utils._pytree import GetAttrKey, SequenceKey + +from ._remove_effect_tokens_pass import _remove_effect_tokens + + +__all__ = ["InterpreterModule", "UnflattenedModule", "unflatten", "FlatArgsAdapter"] + + +class _AttrKind(Enum): + PARAMETER = "parameter" + BUFFER = "buffer" + CONSTANT = "constant" + + +RUN_WITH_INTERPRETER = True + + +@contextmanager +def _disable_interpreter(): + global RUN_WITH_INTERPRETER + old_flag = RUN_WITH_INTERPRETER + RUN_WITH_INTERPRETER = False + try: + yield + finally: + RUN_WITH_INTERPRETER = old_flag + + +# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module +# This installs empty Modules where none exist yet if they are subpaths of target +def _assign_attr( + from_obj: Union[torch.Tensor, torch.ScriptObject], + to_module: torch.nn.Module, + target: str, + attr_kind: _AttrKind, + persistent: bool = True, +): + *prefix, field = target.split(".") + for item in 
prefix: + t = getattr(to_module, item, None) + + if t is None: + t = torch.nn.Module() + setattr(to_module, item, t) + to_module = t + + if attr_kind == _AttrKind.PARAMETER: + assert isinstance(from_obj, torch.nn.Parameter) + to_module.register_parameter(field, from_obj) + elif attr_kind == _AttrKind.BUFFER: + assert isinstance(from_obj, torch.Tensor) + to_module.register_buffer(field, from_obj, persistent=persistent) + elif attr_kind == _AttrKind.CONSTANT: + assert not isinstance( + from_obj, FakeScriptObject + ), "FakeScriptObject should only exist during tracing." + assert isinstance( + from_obj, + ( + torch.Tensor, + torch.ScriptObject, + ), + ) + setattr(to_module, field, from_obj) + + +class InterpreterModule(torch.nn.Module): + """A module that uses torch.fx.Interpreter to execute instead of the usual + codegen that GraphModule uses. This provides better stack trace information + and makes it easier to debug execution. + """ + + def __init__( + self, + graph: torch.fx.Graph, + ): + super().__init__() + self.graph = graph + self.graph.owning_module = self + self._run_with_interpeter = RUN_WITH_INTERPRETER + + def forward(self, *args, **kwargs): + assert self.graph_module is not None, "Didn't finalize this InterpreterModule" + if not is_fx_tracing() and ( + torch.compiler.is_dynamo_compiling() or not self._run_with_interpeter + ): + # Dynamo cannot trace through torch.fx.Interpreter, so fall back to + # GraphModule codegen in this instance. + # Patch the codegened forward to run with this InterpreterModule, + # so attribute accesses, etc. are on this module instead. + return type(self.graph_module).forward(self, *args, **kwargs) + else: + if kwargs: + # Handle **kwargs. FX only natively supports positional + # arguments (through placeholders). So in order to pass in + # kwargs, we must correspond the names of the placeholders with + # the keys in the kwarg dict. + arg_list = list(args) + kwarg_names = self.arg_names[len(arg_list) :] + for kwarg_name in kwarg_names: + if kwarg_name in kwargs: + arg_list.append(kwargs[kwarg_name]) + + # Assert that the kwargs passed in exactly match the positional + # arguments specified by the GraphModule. This should be + # guaranteed by the unflattening process. + assert len(kwarg_names) == len(kwargs) + assert len(arg_list) == len(self.arg_names) + args = tuple(arg_list) + + return torch.fx.Interpreter(self, graph=self.graph).run( + *args, enable_io_processing=False + ) + + def finalize(self): + # We need to "finalize" because GraphModule populates its own state_dict + # based on the get_attrs observed in the graph. So we need to fully + # construct the graph and call _sink_params before generating this + # GraphModule. + + # need to set `graph_module` directly on the dict to avoid it getting + # registered as a submodule. + self.__dict__["graph_module"] = torch.fx.GraphModule(self, self.graph) + self.graph.lint() + + # Cache arg names for kwarg handling (see forward()) + self.arg_names = [] + for node in self.graph.nodes: + if node.op == "placeholder": + self.arg_names.append(node.target) + + def print_readable( + self, + print_output=True, + include_stride=False, + include_device=False, + colored=False, + ): + return _print_readable( + self, + "InterpreterModule", + print_output, + include_stride, + include_device, + colored, + ) + + +class FlatArgsAdapter(abc.ABC): + """ + Adapts input arguments with ``input_spec`` to align ``target_spec``. 
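+
+    A minimal subclass sketch (illustrative; ``KeepLeadingLeavesAdapter`` is a
+    hypothetical name, and simply truncating args is only one possible policy)::
+
+        class KeepLeadingLeavesAdapter(FlatArgsAdapter):
+            def adapt(self, target_spec, input_spec, input_args):
+                # drop trailing leaves until the count matches the exported TreeSpec
+                return input_args[: target_spec.num_leaves]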
+ """ + + @abc.abstractmethod + def adapt( + self, + target_spec: pytree.TreeSpec, + input_spec: pytree.TreeSpec, + input_args: List[Any], + ) -> List[Any]: + """NOTE: This adapter may mutate given ``input_args_with_path``.""" + ... + + +class UnflattenedModule(torch.nn.Module): + def __init__( + self, + export_module: ExportedProgram, + flat_args_adapter: Optional[FlatArgsAdapter] = None, + ): + super().__init__() + if export_module.graph_signature.backward_signature is not None: + raise ValueError("Unflattening on JointExportModule NYI") + + fqn_list = [entry.fqn for entry in export_module.module_call_graph] + assert fqn_list[0] == "" + export_graph = deepcopy(export_module.graph) + self.graph_signature = deepcopy(export_module.graph_signature) + self.graph = torch.fx.Graph() + self.module_call_graph = deepcopy(export_module.module_call_graph) + self.flat_args_adapter = flat_args_adapter + # Flag to indicate whether args have been adapted. + self.adapted = False + self._run_with_interpeter = RUN_WITH_INTERPRETER + + _inplace_buffer_mutations(export_graph, self.graph_signature) + _outline_submodules(export_graph, self) + + self.range_constraints = export_module.range_constraints + self.equality_constraints: List = [] + + # aliasing/unused param or buffer issues: + # in strict-mode export, dynamo export will deduplicate aliased tensors, + # and ignore unused tensors. For aliasing, this causes issues when some aliases + # are unused, and we're unable to match the placeholder node to the correct FQN. + # This leads to the graph signature potentially having the wrong target FQN, + # and downstream issues where parameters are assigned to the wrong target attribute, + # mismatching the relevant placeholder node in the unflattened module. + # To resolve this we restore (_assign_attr) all aliased/unused tensors in + # the state_dict as module attributes, but only keep the used tensors in the + # graph's forward pass (_sink_params). 
+ state_dict = export_module.state_dict + assigned_params: Set[str] = set() # tracking unused params + id_to_param: Dict[int, torch.nn.Parameter] = {} # handling weight-sharing + for name in self.graph_signature.parameters: # this loop adds used params + param = state_dict[name] + if id(param) not in id_to_param: + id_to_param[id(param)] = torch.nn.Parameter( + param.clone(), requires_grad=param.requires_grad + ) + + _assign_attr( + id_to_param[id(param)], + self, + name, + attr_kind=_AttrKind.PARAMETER, + ) + assigned_params.add(name) + + non_persistent_buffers = set(self.graph_signature.non_persistent_buffers) + assigned_buffers: Set[str] = set() # tracking unused buffers + id_to_buffer: Dict[ + int, Tuple[torch.nn.Parameter, bool] + ] = {} # handle weight-sharing + for name in self.graph_signature.buffers: # this loop adds used buffers + if name in non_persistent_buffers: + persistent = False + buffer = export_module.constants[name] + else: + persistent = True + buffer = state_dict[name] + + if id(buffer) not in id_to_buffer: + id_to_buffer[id(buffer)] = (buffer.clone(), persistent) + + _assign_attr( + id_to_buffer[id(buffer)][0], + self, + name, + attr_kind=_AttrKind.BUFFER, + persistent=persistent, + ) + assigned_buffers.add(name) + + # restore aliased/unused params and buffers + # these appear in state dict but not graph signature + for name, tensor in state_dict.items(): + if name in assigned_params or name in assigned_buffers: # already assigned + continue + + is_buffer = False + if id(tensor) in id_to_buffer or not isinstance( + tensor, torch.nn.Parameter + ): # aliased buffer + is_buffer = True + + if is_buffer: + if ( + id(tensor) not in id_to_buffer + ): # this is completely unused (not weight-sharing) + id_to_buffer[id(tensor)] = ( + tensor, + True, + ) # assign to respect original model + _assign_attr( + id_to_buffer[id(tensor)][0], + self, + name, + attr_kind=_AttrKind.BUFFER, + persistent=True, + ) + else: + if id(tensor) not in id_to_param: # this is unused + id_to_param[id(tensor)] = tensor + _assign_attr( + id_to_param[id(tensor)], + self, + name, + attr_kind=_AttrKind.PARAMETER, + ) + + # use id map so we don't double-clone aliased constants + id_to_const: Dict[int, Union[torch.Tensor, torch._C.ScriptObject]] = {} + for fqn, constant in export_module.constants.items(): + if id(constant) not in id_to_const: + if isinstance(constant, torch.Tensor): + constant = constant.clone() + id_to_const[id(constant)] = constant + _constant = id_to_const[id(constant)] + _assign_attr( + _constant, + self, + fqn, + attr_kind=_AttrKind.CONSTANT, + ) + + # This is to handle parameters/buffers that point to the same tensor + # object id -> list of (node_name, target_name) + consts_map: Dict[int, List[Tuple[str, str]]] = defaultdict(list) + consts_targets: Set[str] = set() + + def add_to_consts_map(obj_id, node_name, target_name): + name_list = consts_map[obj_id] + name_list.append((node_name, target_name)) + + added_params_buffers: Set[str] = set() # track aliased/unused params, buffers + for s in self.graph_signature.input_specs: + if s.kind == InputKind.PARAMETER or ( + s.kind == InputKind.BUFFER and s.persistent + ): + assert hasattr(s.arg, "name") + assert isinstance(s.target, str) + add_to_consts_map( + id(export_module.state_dict[s.target]), s.arg.name, s.target + ) + consts_targets.add(s.target) + added_params_buffers.add(s.target) + elif ( + (s.kind == InputKind.BUFFER and not s.persistent) + or s.kind == InputKind.CONSTANT_TENSOR + or s.kind == InputKind.CUSTOM_OBJ + ): + assert 
hasattr(s.arg, "name") + assert isinstance(s.target, str) + add_to_consts_map( + id(export_module.constants[s.target]), s.arg.name, s.target + ) + consts_targets.add(s.target) + + # add constants that are aliased and don't appear in graph signature + for const_name, const in export_module.constants.items(): + if const_name not in consts_targets: + assert ( + id(const) in consts_map + ), "Constants should be either aliased or appear in graph signature" + ph_name, _ = consts_map[id(const)][0] + add_to_consts_map(id(const), ph_name, const_name) + added_params_buffers.add(s.target) + + # add aliased/unused params and buffers that don't appear in graph signature + for fqn, tensor in export_module.state_dict.items(): + if fqn not in added_params_buffers: + if id(tensor) not in consts_map: + # completely unused (no weight-sharing), ignore. + # this weight doesn't appear in graph module, + # so won't cause FQN assignment issues + continue + ph_name, _ = consts_map[id(tensor)][0] + add_to_consts_map(id(tensor), ph_name, fqn) + + # node name -> list of possible targets + inputs_to_state: Dict[str, List[str]] = {} + for node_target in consts_map.values(): + targets = [t[1] for t in node_target] + for n, _ in node_target: + inputs_to_state[n] = targets + + _sink_params(self, inputs_to_state, []) + + # Helper function to check input nodes of `module` has been processed. + def check_module_inputs(module, scope): + if hasattr(module, "graph"): + for node in module.graph.nodes: + # sink_params() should turn placeholders into get_attr nodes + # for attributes that are within scope of the current + # module. We allow attributes to remain as placeholders if + # they are inputs in the original module signature, meaning + # they are a parent module's attribute, and therefore out of + # scope of the current module. + if ( + node.op == "placeholder" + and node.name in inputs_to_state + and any( + fqn.split(".")[: len(scope)] == scope + for fqn in inputs_to_state[node.name] + ) # matching scope to avoid wrong assert + ): + raise AssertionError( + f"{node.name} was not sunk into the module {scope} which has the graph: {module.graph}" + ) + # Recursively check the submodules. + for name, submod in module.named_children(): + scope.append(name) + check_module_inputs(submod, scope) + + # Recurively check all input nodes have been processed. + check_module_inputs(self, []) + + # Cache so we don't have to compute this every time. + # NOTE: this needs to be kept in sync with the placeholders in + # self.graph, but currently we have no way to guarantee that. + self.input_placeholders = [ + node for node in self.graph.nodes if node.op == "placeholder" + ] + self.check_input_constraints = True + # TODO(zhxchen17) We can register modules ahead of time instead of reorder later. + fqn_order = {fqn: i for i, fqn in enumerate(fqn_list)} + # In the case of legacy IR, we might be missing some modules from metadata. 
+ for name, _ in self.named_modules(remove_duplicate=False): + if name not in fqn_order: + fqn_order[name] = len(fqn_order) + _reorder_submodules(self, fqn_order) + assert [fqn for fqn, _ in self.named_modules(remove_duplicate=False)] == list( + fqn_order.keys() + ) + self.graph.lint() + + def _print_graph(self): + for fqn, mod in self.named_modules(): + print(fqn + ":") + if hasattr(mod, "graph") and isinstance(mod.graph, torch.fx.Graph): + print(mod.graph) + + def forward(self, *args, **kwargs): + signature = self.module_call_graph[0].signature + + reordered_kwargs = reorder_kwargs(kwargs, signature.in_spec) + + flat_args_with_path, in_spec = pytree.tree_flatten_with_path( + (args, reordered_kwargs) + ) + flat_args = [x[1] for x in flat_args_with_path] + if is_fx_tracing(): + return_val = torch.fx.Interpreter(self, graph=self.graph).run( + *flat_args, enable_io_processing=False + ) + # For scalar return value, fx.Graph wraps in a tuple + if isinstance(return_val, tuple) and len(return_val) == 1: + return return_val[0] + return return_val + + if in_spec != signature.in_spec: + if not self.adapted: + print( + "Input treespec does not match with exported module's: \n" + f"Input treespec: {in_spec}. ", + f"Exported module treespec: {signature.in_spec}", + ) + if self.flat_args_adapter is None: + raise TypeError( + "There is no flat args adapter sepcified. " + "Are you sure you are calling this with the right arguments? " + ) + else: + if not self.adapted: + print("Adapting flat arg to match exported module's treespec") + flat_args = self.flat_args_adapter.adapt( + target_spec=signature.in_spec, + input_spec=in_spec, + input_args=flat_args, + ) + self.adapted = True + if len(flat_args) != signature.in_spec.num_leaves: + raise TypeError( + f"Flat args adaption failed, number of args mismatch " + f"Adatped: {len(flat_args)} \n" + f"Exported module: {signature.in_spec.num_leaves}" + ) + + if self.check_input_constraints: + # Import here to avoid an unfortunate circular dependency. + # TODO(suo): untangle this. + from torch._export.utils import _check_input_constraints_for_graph + + if self.adapted is True: + # TODO(suo): The FlatArgsAdapter returns a list of flat args, + # which we don't have keypaths for. For now, just create a dummy + # keypath to associate with the arg. + new_flat_args_with_path = [ # type: ignore[var-annotated] + ((SequenceKey(idx=0), GetAttrKey(name="")), arg) + for arg in flat_args + ] + else: + new_flat_args_with_path = flat_args_with_path # type: ignore[assignment] + + _check_input_constraints_for_graph( + self.input_placeholders, new_flat_args_with_path, self.range_constraints + ) + if torch.compiler.is_dynamo_compiling() and not self._run_with_interpreter: + tree_out = torch.fx.GraphModule(self, self.graph)(*flat_args) + else: + tree_out = torch.fx.Interpreter(self, graph=self.graph).run( + *flat_args, enable_io_processing=False + ) + return pytree.tree_unflatten(tree_out, signature.out_spec) + + def print_readable( + self, + print_output=True, + include_stride=False, + include_device=False, + colored=False, + ): + return _print_readable( + self, + "UnflattenedModule", + print_output, + include_stride, + include_device, + colored, + ) + + +def unflatten( + module: ExportedProgram, flat_args_adapter: Optional[FlatArgsAdapter] = None +) -> UnflattenedModule: + """Unflatten an ExportedProgram, producing a module with the same module + hierarchy as the original eager module. 
This can be useful if you are trying + to use :mod:`torch.export` with another system that expects a module + hierachy instead of the flat graph that :mod:`torch.export` usually produces. + + .. note:: The args/kwargs of unflattened modules will not necessarily match + the eager module, so doing a module swap (e.g. :code:`self.submod = + new_mod`) will not necessarily work. If you need to swap a module out, you + need to set the :code:`preserve_module_call_signature` parameter of + :func:`torch.export.export`. + + Args: + module (ExportedProgram): The ExportedProgram to unflatten. + flat_args_adapter (Optional[FlatArgsAdapter]): Adapt flat args if input TreeSpec does not match with exported module's. + + Returns: + An instance of :class:`UnflattenedModule`, which has the same module + hierarchy as the original eager module pre-export. + """ + module = _remove_effect_tokens(module) + return UnflattenedModule(module, flat_args_adapter) + + +def _inplace_buffer_mutations(graph: torch.fx.Graph, graph_signature) -> None: + """Transform buffer mutations from their functionalized form into a copy_ + node in the graph. + + Functionalization represents buffer mutation by passing the buffer as an input and output. So for example, the eager code: + def forward(self, x): + self.buffer += x + return x * x + + Will become a graph that looks like: + def forward(self, buffer, x): + mutated_buffer = aten.add(buffer, x) + mul = aten.mul(x, x) + return (mutated_buffer, mul) + + We want to inplace this into something that looks like the original eager code: + def forward(self, buffer, x): + mutated_buffer = aten.add(buffer, x) + buffer.copy_(mutated_buffer) + mul = aten.mul(x, x) + return (mul,) + """ + output_node = next(iter(reversed(graph.nodes))) + assert output_node.op == "output" and len(output_node.args) == 1 + return_args = output_node.args[0] + + mutation_node_to_buffer = graph_signature.buffers_to_mutate + mutations = return_args[: len(mutation_node_to_buffer)] + buffers_to_inputs = {v: k for k, v in graph_signature.inputs_to_buffers.items()} + input_name_to_node = { + node.name: node for node in graph.nodes if node.op == "placeholder" + } + + for mutation in mutations: + buffer_name = mutation_node_to_buffer[mutation.name] + input_name = buffers_to_inputs[buffer_name] + input_node = input_name_to_node[input_name] + + with graph.inserting_after(mutation): + new_node = graph.create_node( + "call_function", torch.ops.aten.copy_, (input_node, mutation) + ) + for k, v in mutation.meta.items(): + new_node.meta[k] = v + # Replace all uses of the previously functional mutation with our copy_ output. + mutation.replace_all_uses_with(new_node, lambda x: x is not new_node) + + # Remove the mutated buffer from the graph outputs, since we don't need to + # thread it through anymore. We don't need to handle the inputs, which will + # be handled by _sink_params. + user_outputs = tuple( + return_args[len(mutation_node_to_buffer) :], + ) + output_node.args = ((user_outputs),) + + +def _is_prefix(candidate, target): + """Check whether `candidate` is a prefix of `target`.""" + return len(candidate) < len(target) and target[: len(candidate)] == candidate + + +def _compute_accessor(parent_fqn: str, child_fqn: str) -> str: + if parent_fqn == "": + # Handle the root module correctly. + return child_fqn + + parent_split = parent_fqn.split(".") + child_split = child_fqn.split(".") + + # TODO: support skip connection by inlining the child module. 
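+    # For reference (illustrative): _compute_accessor("a.b", "a.b.c.d") returns
+    # "c.d", and _compute_accessor("", "a.b.c.d") returns "a.b.c.d".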
+ if child_split[: len(parent_split)] != parent_split: + raise RuntimeError( + f"Child module '{child_fqn}' is not a descendant of parent mldule '{parent_fqn}'." + "This is currently unsupported." + "Please try to make child module attach to parent module direclty." + ) + return ".".join(child_split[len(parent_split) :]) + + +def _verify_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module): + def graph_dump(graph: torch.fx.Graph) -> str: + ret = [] + nodes_idx: Dict[int, int] = {} + + def arg_dump(arg) -> str: + if isinstance(arg, torch.fx.Node): + return "%" + str(nodes_idx[id(arg)]) + return str(arg) + + for i, node in enumerate(graph.nodes): + args_dump = [str(arg) for arg in pytree.tree_map(arg_dump, node.args)] + args_dump += [ + f"{key}={value}" + for key, value in pytree.tree_map(arg_dump, node.kwargs).items() + ] + target = node.target if node.op == "call_function" else "" + ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})") + nodes_idx[id(node)] = i + return "\n".join(ret) + + assert graph_dump(x.graph) == graph_dump(y.graph) + + +def _add_spec(gm: torch.nn.Module, spec) -> str: + i = 0 + while hasattr(gm, f"_spec_{i}"): + i += 1 + name = f"_spec_{i}" + setattr(gm, name, spec) + return name + + +def _generate_flatten(gm: torch.nn.Module, node, spec) -> torch.fx.Node: + name = _add_spec(gm, spec) + spec_node = gm.graph.get_attr(name) + return gm.graph.call_function(fx_pytree.tree_flatten_spec, (node, spec_node)) + + +def _generate_unflatten(gm: torch.nn.Module, nodes, spec) -> torch.fx.Node: + name = _add_spec(gm, spec) + spec_node = gm.graph.get_attr(name) + return gm.graph.call_function(pytree.tree_unflatten, (nodes, spec_node)) + + +def _get_submodule(mod: torch.nn.Module, target: str): + *prefix, field = target.split(".") + + for item in prefix: + submod = getattr(mod, item, None) + + if submod is None: + return None + + if not isinstance(submod, torch.nn.Module): + return None + + mod = submod + + return getattr(mod, field, None) + + +def _add_submodule(mod: torch.nn.Module, target: str, module_to_add: torch.nn.Module): + *prefix, field = target.split(".") + + for item in prefix: + submod = getattr(mod, item, None) + + if submod is None: + submod = torch.nn.Module() + setattr(mod, item, submod) + + if not isinstance(submod, torch.nn.Module): + return False + + mod = submod + + mod.add_module(field, module_to_add) + + +class _ModuleFrame: + def __init__( + self, + flat_graph: torch.fx.Graph, + nodes: Tuple[torch.fx.Node, ...], + seen_nodes, + seen_modules, + parent, + module_stack: List[str], + module_id, + module_call_graph: Dict[str, ModuleCallSignature], + module: Optional[torch.nn.Module] = None, + ): + self.flat_graph = flat_graph + self.nodes = nodes + self.seen_nodes = seen_nodes + self.seen_modules = seen_modules + self.parent = parent + self.module_stack = module_stack + self.module_id = module_id + + self.module_call_graph = module_call_graph + self.verbose = False + + self.fqn = self.module_stack[-1] + if module is not None: + self.module = module + else: + self.module = InterpreterModule(torch.fx.Graph()) + if self.module_id in self.seen_modules: + self.cached_graph_module = self.seen_modules[self.module_id] + else: + self.cached_graph_module = None + self.seen_modules[self.module_id] = self.module + + self.graph = self.module.graph + + # Mapping of nodes in the flat graph to nodes in this graph. 
+ self.node_map: Dict[torch.fx.Node, torch.fx.Node] = {} + self.node_to_placeholder = {} + + self.parent_call_module: Optional[torch.fx.Node] = None + if parent is not None: + accessor = _compute_accessor(parent.fqn, self.fqn) + _add_submodule( + parent.module, + accessor, + ( + self.module + if self.cached_graph_module is None + else self.cached_graph_module + ), + ) + self.parent_call_module = parent.graph.call_module(accessor) + + signature = module_call_graph.get(self.fqn) + if signature is not None and self.parent is not None: + assert signature.in_spec.num_children == 2 + args_spec = signature.in_spec.children_specs[0] + kwargs_spec = signature.in_spec.children_specs[1] + assert args_spec.context is None + assert kwargs_spec.context is not None + + with self.graph.inserting_after(None): + arg_nodes = [] + for idx in range(args_spec.num_children): + arg_nodes.append(self.graph.placeholder(f"_positional_arg_{idx}")) + kwarg_nodes = {} + for name in kwargs_spec.context: + kwarg_nodes[name] = self.graph.placeholder(name) + flat_args = _generate_flatten( + self.module, + (tuple(arg_nodes), kwarg_nodes), + signature.in_spec, + ) + for idx, arg in enumerate(signature.inputs): + flat_arg_node = self.graph.create_node( + op="call_function", + target=operator.getitem, + args=(flat_args, idx), + name=( + arg.name + if not isinstance(arg, ConstantArgument) + else f"_constant_{idx}" + ), + ) + if isinstance(arg, ConstantArgument): + continue + + if arg.name in self.seen_nodes: + flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta) + self.node_to_placeholder[ + self.seen_nodes[arg.name] + ] = flat_arg_node + + with self.parent.graph.inserting_before(self.parent_call_module): + input_nodes: List[Optional[torch.fx.Node]] = [] + for input in signature.inputs: + if isinstance(input, ConstantArgument) and input.value is None: + input_nodes.append(None) + elif input.name not in self.seen_nodes: + input_nodes.append(None) + else: + assert isinstance(input, (TensorArgument, SymIntArgument)) + input_nodes.append( + self.parent.remap_input(self.seen_nodes[input.name]) + ) + + inputs_node = _generate_unflatten( + self.parent.module, + input_nodes, + signature.in_spec, + ) + + args_node = self.parent.graph.call_function( + operator.getitem, (inputs_node, 0) + ) + kwargs_node = self.parent.graph.call_function( + operator.getitem, (inputs_node, 1) + ) + arg_nodes = [ + self.parent.graph.call_function(operator.getitem, (args_node, i)) + for i in range(args_spec.num_children) + ] + kwarg_nodes = { + k: self.parent.graph.call_function( + operator.getitem, (kwargs_node, k) + ) + for k in kwargs_spec.context + } + assert self.parent_call_module is not None + self.parent_call_module.args = tuple(arg_nodes) + self.parent_call_module.kwargs = kwarg_nodes + + def add_placeholder(self, x): + assert self.fqn != "", f"Cannot add placeholder {x} to root module" + assert x.graph is self.flat_graph + # x is not in subgraph, create a new placeholder for subgraph + with self.graph.inserting_before(None): + placeholder_node = self.graph.placeholder(x.name, type_expr=x.type) + # copy all meta fields, even if some fields might be irrelvant for + # the placeholder node + placeholder_node.meta = copy.copy(x.meta) + self.node_to_placeholder[x] = placeholder_node + + def copy_sym_call_function(self, x): + # This only exists because we deduplicate sym_size nodes in the flat export graph, + # and if preserve_module_call_signature is set, we may not be able to pass sym_size + # nodes, or their downstream users, as inputs to 
submodule calls. + # To avoid this we copy these call_function nodes with sym_type results. + # This should however only be done for sym_type nodes - call_function nodes on tensors + # should not be deduplicated in the first place. + args = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.args) + kwargs = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.kwargs) + node = self.graph.call_function(x.target, args, kwargs) + node.meta = copy.copy(x.meta) + self.node_map[x] = node + return node + + def remap_input(self, x): + assert x.graph is self.flat_graph + if x in self.node_map: + return self.node_map[x] + self.print(f"remap_input({x})") + if x in self.node_to_placeholder: + return self.node_to_placeholder[x] + elif ( + x.op == "placeholder" + or self.module_call_graph.get(self.fqn) is None + # allow placeholder creation if we are not preserving module call signature + ): + self.add_placeholder(x) + if self.parent_call_module is not None: + # Important to *prepend* the output to match how we are + # inserting placeholder nodes. + with self.parent.graph.inserting_before(self.parent_call_module): + self.parent_call_module.insert_arg(0, self.parent.remap_input(x)) + return self.node_to_placeholder[x] + elif x.op == "call_function" and ( + x.target + in ( + torch.ops.aten.sym_size.int, + torch.ops.aten.item.default, + torch.ops.aten.unbind.int, + torch.ops.aten.sum.dim_IntList, + torch.ops.aten.view.default, + torch.ops.aten.diff.default, + ) + or (hasattr(x.target, "__module__") and x.target.__module__ == "_operator") + ): + # export deduplicates sym_size nodes, and may need to re-copy them + # if module call signature needs to be preserved + self.copy_sym_call_function(x) + return self.node_map[x] + else: + raise RuntimeError( + f"Could not run remap_input() on op type: {x.op} for node {x}" + ) + + def finalize_outputs(self): + orig_outputs = [] + + signature = self.module_call_graph.get(self.fqn) + if signature is not None and self.parent is not None: + for output in signature.outputs: + if isinstance(output, (TensorArgument, SymIntArgument)): + if output.name in self.seen_nodes: + orig_outputs.append(self.seen_nodes[output.name]) + else: + orig_outputs.append(None) + else: + raise RuntimeError( + f"Unsupported data type for output node: {output}" + ) + + def get_actual_output_node(output): + if output is None: + return None + + seen_node = self.seen_nodes[output.name] + if seen_node in self.node_map: + return self.node_map[seen_node] + elif seen_node in self.node_to_placeholder: + return self.node_to_placeholder[seen_node] + else: + raise RuntimeError( + f"Could not find output node {output}. Graph: {self.graph}" + ) + + tree_out_node = _generate_unflatten( + self.module, + tuple(get_actual_output_node(output) for output in orig_outputs), + signature.out_spec, + ) + parent_out: Optional[torch.fx.Node] = _generate_flatten( + self.parent.module, self.parent_call_module, signature.out_spec + ) + graph_outputs: Union[torch.fx.Node, List[torch.fx.Node]] = tree_out_node + else: + graph_outputs = [] + # Iterate through nodes we have copied into self.graph. 
+ for orig_node in self.node_map.keys(): + for user_node in orig_node.users: + if user_node.name not in self.seen_nodes: + # external user node, need to expose as an output + orig_outputs.append(orig_node) + graph_outputs.append(self.node_map[orig_node]) + break + + parent_out = self.parent_call_module + if len(graph_outputs) == 1: + graph_outputs = graph_outputs[0] + + assert isinstance(graph_outputs, (list, torch.fx.Node)) + + self.graph.output(graph_outputs) + + # Rewrite outputs in parent module + if parent_out is None: + return + + parent_out.meta["val"] = ( + graph_outputs.meta.get("val") + if isinstance(graph_outputs, torch.fx.Node) + else [o.meta.get("val") for o in graph_outputs] + ) + + if len(orig_outputs) == 1 and signature is None: + self.parent.node_map[orig_outputs[0]] = parent_out + else: + for i, orig_output in enumerate(orig_outputs): + if orig_output is None: + continue + # Use Proxy to record getitem access. + proxy_out = torch.fx.Proxy(parent_out)[i].node # type: ignore[index] + proxy_out.meta["val"] = orig_output.meta.get("val") + self.parent.node_map[orig_output] = proxy_out + + if self.cached_graph_module is not None: + _verify_graph_equivalence(self.cached_graph_module, self.module) + + def copy_node(self, node): + self.print("copying", node.format_node()) + self.node_map[node] = self.graph.node_copy(node, self.remap_input) + self.seen_nodes[node.name] = node + + def run_outer(self): + i = 0 + for node in self.flat_graph.nodes: + self.print(i, node.meta.get("nn_module_stack"), node.format_node()) + i += 1 + + # Copy all graph inputs + node_idx: int = 0 + node = self.nodes[node_idx] + while node.op == "placeholder": + self.copy_node(node) + node_idx += 1 + node = self.nodes[node_idx] + + self.run_from(node_idx) + + # Copy graph outputs + for node in self.flat_graph.nodes: + if node.op == "output": + self.copy_node(node) + + def print(self, *args, **kwargs): + if self.verbose: + print(*args, **kwargs) + + def run_from(self, node_idx): + module_idx = 0 + # Walk through the graph, building up a new graph with the right submodules + while node_idx < len(self.nodes): + node = self.nodes[node_idx] + assert node.op != "placeholder" + + self.print() + self.print("STEP", node_idx, node.format_node()) + self.print(self.module_stack) + if node.op == "output": + if len(self.module_stack) == 1: + # We want the output node of the original graph to be handled + # specially by the outermost stack frame (in run_outer). So + # skip finalization here. + return node_idx + + # We've reached the end of the graph. Wrap up all the existing stack frames. + self.finalize_outputs() + return node_idx + + if len(node.meta.get("nn_module_stack", {})) == 0: + raise RuntimeError(f"Unable to find nn_module_stack for node {node}") + + nn_module_stack = node.meta["nn_module_stack"] + from torch._export.passes._node_metadata_hook import ( + _EMPTY_NN_MODULE_STACK_KEY, + ) + + if ( + len(nn_module_stack) == 1 + and _EMPTY_NN_MODULE_STACK_KEY in nn_module_stack + ): + # Empty case from the node_metadata_hook + node_module_stack = self.module_stack + else: + node_module_stack = [ + path for path, ty in node.meta["nn_module_stack"].values() + ] + + if node_module_stack[: len(self.module_stack)] != self.module_stack: + # This means that the current module is done executing and the + # current node is the beginning of a new module. + # + # In this case, we should finalize this module and return without + # incrementing the node counter. 
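+                # e.g. (assumed): module_stack == ["", "sub1"] while the node's
+                # nn_module_stack resolves to ["", "sub2"]; "sub1" has finished,
+                # so emit its outputs and let the parent frame pick up "sub2".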
+ self.finalize_outputs() + self.print("outlining", self.fqn) + self.print(self.graph) + return node_idx + + assert node_module_stack is not None + + if _is_prefix(self.module_stack, node_module_stack): + # This means that the current node represents the execution of a new + # module. + next_module = node_module_stack[len(self.module_stack)] + self.print("Creating new stack frame for", next_module) + # Run a nested version of module outliner from the current node + # counter. Once it is complete, continue from that point. + node_idx = _ModuleFrame( + self.flat_graph, + self.nodes, + self.seen_nodes, + self.seen_modules, + self, + self.module_stack + [next_module], + list(node.meta["nn_module_stack"].keys())[len(self.module_stack)], + self.module_call_graph, + ).run_from(node_idx) + module_idx += 1 + continue + + # The only remaining possibility is that we are in the right stack + # frame. Copy the node into this frame's graph and increment the node counter. + assert node_module_stack == self.module_stack + self.copy_node(node) + node_idx += 1 + + +def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModule): + seen_nodes: Dict[str, torch.fx.Node] = {} + seen_modules: Dict[int, torch.nn.Module] = {} + _ModuleFrame( + orig_graph, + tuple(orig_graph.nodes), + seen_nodes, + seen_modules, + None, + [""], + "", + { + entry.fqn: entry.signature + for entry in root_module.module_call_graph + if entry.signature + }, + module=root_module, + ).run_outer() + + +def _reorder_submodules( + parent: torch.nn.Module, fqn_order: Dict[str, int], prefix: str = "" +): + # TODO Can be optimized by adding submodules ahead of time. + if prefix == "": + for fqn in list(fqn_order.keys())[1:]: + if _get_submodule(parent, fqn) is None: + _add_submodule(parent, fqn, torch.nn.Module()) + + children = [] + for name, child in list(parent._modules.items()): + if child is None: + continue + fqn = prefix + name + _reorder_submodules(child, fqn_order, prefix=fqn + ".") + delattr(parent, name) + children.append((fqn_order[fqn], name, child)) + children.sort(key=operator.itemgetter(0)) + for _, name, child in children: + parent.register_module(name, child) + + +def _sink_params( + module: torch.nn.Module, + inputs_to_state: Dict[str, List[str]], + scope: List[str], +): + """Sink params, buffers, and constants from graph inputs into get_attr nodes. + + Exported modules are purely functional, so they pass their parameters and + buffers in as inputs to the graph. + + To replicate eager's semantics, we need to get them from the module state + via get_attr instead. + + module: GraphModule, potentially containining nested submodules. + inputs_to_state: mapping graph input names to the corresponding key in the state_dict. + scope: tracks where we are in the module hierarchy, so that we can emit the + right `getattr(self, "foo.bar")` calls, etc. + """ + # This dict records inputs removed by child modules. + # Maps the module object id to the list of placeholder node names + # in the child module that were removed. + module_id_to_inputs_removed: Dict[int, List[str]] = defaultdict(list) + + # We need to use _modules here instead of named_children(), because we + # explicitly want duplicate modules to show up in the traversal. 
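+    # Illustrative case (assumed): if one submodule instance is registered under
+    # two attribute names (``self.a = shared; self.b = shared``), named_children()
+    # would yield it only once, but both the "a" and "b" scopes need their
+    # placeholder inputs rewritten, so we iterate over _modules directly.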
+ for name, submodule in module._modules.items(): + submod_id_to_inputs_removed = _sink_params( + cast(torch.nn.Module, submodule), inputs_to_state, scope + [name] + ) + for k, v in submod_id_to_inputs_removed.items(): + module_id_to_inputs_removed[k].extend(v) + + if not hasattr(module, "graph"): + # Not all modules have graphs defined, if they are empty modules with no operations (like ParameterList) + return module_id_to_inputs_removed + + graph = module.graph + inputs = list(filter(lambda n: n.op == "placeholder", graph.nodes)) + the_last_input = inputs[-1] + + # Also remove from call_module nodes + call_module_nodes = filter(lambda n: n.op == "call_module", graph.nodes) + for node in call_module_nodes: + submodule = _recursive_getattr(module, node.target.split(".")) + # remove placeholder from call_module node arguments, only if we've + # erased the placeholder node in the corresponding _sink_params() call + if submodule is not None and id(submodule) in module_id_to_inputs_removed: + node.args = tuple( + filter( + lambda n: n.name not in module_id_to_inputs_removed[id(submodule)], + node.args, + ) + ) + + # Filter out inputs_to_state corresponding to current scope. + inputs_to_state_of_scope: Dict[torch.fx.Node, list[str]] = {} + for node in inputs: + if node.name not in inputs_to_state: + continue + + state_name = None + for sn in inputs_to_state[node.name]: + sn_split = sn.split(".") + if sn_split[: len(scope)] == scope: + state_name = sn_split + break + + # If there's a mismatch beteewn scope name and state name, then + # there must be multuple scopes pointing to the same state name, + # meaning some modules are shared. In such case, we can simply skip + # updating the current node because another later iteration will + # take care of this input node when the unique match between scope + # and state name occurs. To make sure this always happen, we should + # enforce the invariant that no placeholder node in the unflattened + # graph appears in inputs_to_state dict, which means all the extra + # input nodes have been handled. + if state_name is None: + continue + + inputs_to_state_of_scope[node] = state_name + + # Record name of remove inputs for return purpose. 
+ inputs_removed: List[str] = [] + + for node, state_name in inputs_to_state_of_scope.items(): + if len(node.users) > 0: + attr_path = state_name[len(scope) :] + state_attr = _recursive_getattr(module, attr_path) + assert isinstance(state_attr, (torch.Tensor, torch.ScriptObject)) + + # Make sure the newly created get_attr node is placed after the last placeholder node + with graph.inserting_after(the_last_input): + new_node = graph.create_node("get_attr", ".".join(attr_path)) + + node.replace_all_uses_with(new_node, propagate_meta=True) + + graph.erase_node(node) + inputs_removed.append(node.name) + + if isinstance(module, InterpreterModule): + module.finalize() + + return {id(module): inputs_removed} + + +def _recursive_getattr(obj, attr_path): + for attr in attr_path: + if not hasattr(obj, attr): + return None + obj = getattr(obj, attr) + + return obj diff --git a/lib/python3.10/site-packages/torch/fft/__init__.py b/lib/python3.10/site-packages/torch/fft/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3bc5191c7b57de89817e5401d0db24aac1c6df5e --- /dev/null +++ b/lib/python3.10/site-packages/torch/fft/__init__.py @@ -0,0 +1,1360 @@ +import sys + +import torch +from torch._C import _add_docstr, _fft # type: ignore[attr-defined] +from torch._torch_docs import factory_common_args, common_args + +__all__ = ['fft', 'ifft', 'fft2', 'ifft2', 'fftn', 'ifftn', + 'rfft', 'irfft', 'rfft2', 'irfft2', 'rfftn', 'irfftn', + 'hfft', 'ihfft', 'fftfreq', 'rfftfreq', 'fftshift', 'ifftshift', + 'Tensor'] + +Tensor = torch.Tensor + +# Note: This not only adds the doc strings for the spectral ops, but +# connects the torch.fft Python namespace to the torch._C._fft builtins. + +fft = _add_docstr(_fft.fft_fft, r""" +fft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor + +Computes the one dimensional discrete Fourier transform of :attr:`input`. + +Note: + The Fourier domain representation of any real signal satisfies the + Hermitian property: `X[i] = conj(X[-i])`. This function always returns both + the positive and negative frequency terms even though, for real inputs, the + negative frequencies are redundant. :func:`~torch.fft.rfft` returns the + more compact one-sided representation where only the positive frequencies + are returned. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimension. + +Args: + input (Tensor): the input tensor + n (int, optional): Signal length. If given, the input will either be zero-padded + or trimmed to this length before computing the FFT. + dim (int, optional): The dimension along which to take the one dimensional FFT. + norm (str, optional): Normalization mode. For the forward transform + (:func:`~torch.fft.fft`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal) + + Calling the backward transform (:func:`~torch.fft.ifft`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.ifft` + the exact inverse. + + Default is ``"backward"`` (no normalization). 
+ +Keyword args: + {out} + +Example: + + >>> t = torch.arange(4) + >>> t + tensor([0, 1, 2, 3]) + >>> torch.fft.fft(t) + tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j]) + + >>> t = torch.tensor([0.+1.j, 2.+3.j, 4.+5.j, 6.+7.j]) + >>> torch.fft.fft(t) + tensor([12.+16.j, -8.+0.j, -4.-4.j, 0.-8.j]) +""".format(**common_args)) + +ifft = _add_docstr(_fft.fft_ifft, r""" +ifft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor + +Computes the one dimensional inverse discrete Fourier transform of :attr:`input`. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimension. + +Args: + input (Tensor): the input tensor + n (int, optional): Signal length. If given, the input will either be zero-padded + or trimmed to this length before computing the IFFT. + dim (int, optional): The dimension along which to take the one dimensional IFFT. + norm (str, optional): Normalization mode. For the backward transform + (:func:`~torch.fft.ifft`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal) + + Calling the forward transform (:func:`~torch.fft.fft`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.ifft` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + +Keyword args: + {out} + +Example: + + >>> t = torch.tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j]) + >>> torch.fft.ifft(t) + tensor([0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j]) +""".format(**common_args)) + +fft2 = _add_docstr(_fft.fft_fft2, r""" +fft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor + +Computes the 2 dimensional discrete Fourier transform of :attr:`input`. +Equivalent to :func:`~torch.fft.fftn` but FFTs only the last two dimensions by default. + +Note: + The Fourier domain representation of any real signal satisfies the + Hermitian property: ``X[i, j] = conj(X[-i, -j])``. This + function always returns all positive and negative frequency terms even + though, for real inputs, half of these values are redundant. + :func:`~torch.fft.rfft2` returns the more compact one-sided representation + where only the positive frequencies of the last dimension are returned. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: last two dimensions. + norm (str, optional): Normalization mode. For the forward transform + (:func:`~torch.fft.fft2`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal) + + Where ``n = prod(s)`` is the logical FFT size. + Calling the backward transform (:func:`~torch.fft.ifft2`) with the same + normalization mode will apply an overall normalization of ``1/n`` + between the two transforms. 
This is required to make + :func:`~torch.fft.ifft2` the exact inverse. + + Default is ``"backward"`` (no normalization). + +Keyword args: + {out} + +Example: + + >>> x = torch.rand(10, 10, dtype=torch.complex64) + >>> fft2 = torch.fft.fft2(x) + + The discrete Fourier transform is separable, so :func:`~torch.fft.fft2` + here is equivalent to two one-dimensional :func:`~torch.fft.fft` calls: + + >>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1) + >>> torch.testing.assert_close(fft2, two_ffts, check_stride=False) + +""".format(**common_args)) + +ifft2 = _add_docstr(_fft.fft_ifft2, r""" +ifft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor + +Computes the 2 dimensional inverse discrete Fourier transform of :attr:`input`. +Equivalent to :func:`~torch.fft.ifftn` but IFFTs only the last two dimensions by default. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the IFFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: last two dimensions. + norm (str, optional): Normalization mode. For the backward transform + (:func:`~torch.fft.ifft2`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal) + + Where ``n = prod(s)`` is the logical IFFT size. + Calling the forward transform (:func:`~torch.fft.fft2`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.ifft2` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + +Keyword args: + {out} + +Example: + + >>> x = torch.rand(10, 10, dtype=torch.complex64) + >>> ifft2 = torch.fft.ifft2(x) + + The discrete Fourier transform is separable, so :func:`~torch.fft.ifft2` + here is equivalent to two one-dimensional :func:`~torch.fft.ifft` calls: + + >>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1) + >>> torch.testing.assert_close(ifft2, two_iffts, check_stride=False) + +""".format(**common_args)) + +fftn = _add_docstr(_fft.fft_fftn, r""" +fftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor + +Computes the N dimensional discrete Fourier transform of :attr:`input`. + +Note: + The Fourier domain representation of any real signal satisfies the + Hermitian property: ``X[i_1, ..., i_n] = conj(X[-i_1, ..., -i_n])``. This + function always returns all positive and negative frequency terms even + though, for real inputs, half of these values are redundant. + :func:`~torch.fft.rfftn` returns the more compact one-sided representation + where only the positive frequencies of the last dimension are returned. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. 
+ If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given. + norm (str, optional): Normalization mode. For the forward transform + (:func:`~torch.fft.fftn`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal) + + Where ``n = prod(s)`` is the logical FFT size. + Calling the backward transform (:func:`~torch.fft.ifftn`) with the same + normalization mode will apply an overall normalization of ``1/n`` + between the two transforms. This is required to make + :func:`~torch.fft.ifftn` the exact inverse. + + Default is ``"backward"`` (no normalization). + +Keyword args: + {out} + +Example: + + >>> x = torch.rand(10, 10, dtype=torch.complex64) + >>> fftn = torch.fft.fftn(x) + + The discrete Fourier transform is separable, so :func:`~torch.fft.fftn` + here is equivalent to two one-dimensional :func:`~torch.fft.fft` calls: + + >>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1) + >>> torch.testing.assert_close(fftn, two_ffts, check_stride=False) + +""".format(**common_args)) + +ifftn = _add_docstr(_fft.fft_ifftn, r""" +ifftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor + +Computes the N dimensional inverse discrete Fourier transform of :attr:`input`. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the IFFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given. + norm (str, optional): Normalization mode. For the backward transform + (:func:`~torch.fft.ifftn`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal) + + Where ``n = prod(s)`` is the logical IFFT size. + Calling the forward transform (:func:`~torch.fft.fftn`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.ifftn` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). 
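
As a quick illustration of the normalization contract described above, a minimal sketch, assuming a working ``torch`` install (variable names are illustrative only):

    import torch

    x = torch.randn(8, 8, dtype=torch.complex64)

    # ifftn undoes fftn as long as both calls use the same norm mode.
    for norm in ("backward", "forward", "ortho"):
        X = torch.fft.fftn(x, norm=norm)
        torch.testing.assert_close(torch.fft.ifftn(X, norm=norm), x)

    # With norm="ortho" the transform is unitary, so signal energy is preserved.
    X = torch.fft.fftn(x, norm="ortho")
    torch.testing.assert_close(X.abs().square().sum(), x.abs().square().sum())
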
+ +Keyword args: + {out} + +Example: + + >>> x = torch.rand(10, 10, dtype=torch.complex64) + >>> ifftn = torch.fft.ifftn(x) + + The discrete Fourier transform is separable, so :func:`~torch.fft.ifftn` + here is equivalent to two one-dimensional :func:`~torch.fft.ifft` calls: + + >>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1) + >>> torch.testing.assert_close(ifftn, two_iffts, check_stride=False) + +""".format(**common_args)) + +rfft = _add_docstr(_fft.fft_rfft, r""" +rfft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor + +Computes the one dimensional Fourier transform of real-valued :attr:`input`. + +The FFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])`` so +the output contains only the positive frequencies below the Nyquist frequency. +To compute the full output, use :func:`~torch.fft.fft` + +Note: + Supports torch.half on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimension. + +Args: + input (Tensor): the real input tensor + n (int, optional): Signal length. If given, the input will either be zero-padded + or trimmed to this length before computing the real FFT. + dim (int, optional): The dimension along which to take the one dimensional real FFT. + norm (str, optional): Normalization mode. For the forward transform + (:func:`~torch.fft.rfft`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal) + + Calling the backward transform (:func:`~torch.fft.irfft`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.irfft` + the exact inverse. + + Default is ``"backward"`` (no normalization). + +Keyword args: + {out} + +Example: + + >>> t = torch.arange(4) + >>> t + tensor([0, 1, 2, 3]) + >>> torch.fft.rfft(t) + tensor([ 6.+0.j, -2.+2.j, -2.+0.j]) + + Compare against the full output from :func:`~torch.fft.fft`: + + >>> torch.fft.fft(t) + tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j]) + + Notice that the symmetric element ``T[-1] == T[1].conj()`` is omitted. + At the Nyquist frequency ``T[-2] == T[2]`` is it's own symmetric pair, + and therefore must always be real-valued. +""".format(**common_args)) + +irfft = _add_docstr(_fft.fft_irfft, r""" +irfft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor + +Computes the inverse of :func:`~torch.fft.rfft`. + +:attr:`input` is interpreted as a one-sided Hermitian signal in the Fourier +domain, as produced by :func:`~torch.fft.rfft`. By the Hermitian property, the +output will be real-valued. + +Note: + Some input frequencies must be real-valued to satisfy the Hermitian + property. In these cases the imaginary component will be ignored. + For example, any imaginary component in the zero-frequency term cannot + be represented in a real output and so will always be ignored. + +Note: + The correct interpretation of the Hermitian input depends on the length of + the original data, as given by :attr:`n`. This is because each input shape + could correspond to either an odd or even length signal. By default, the + signal is assumed to be even length and odd signals will not round-trip + properly. So, it is recommended to always pass the signal length :attr:`n`. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. 
+ However it only supports powers of 2 signal length in every transformed dimension. + With default arguments, size of the transformed dimension should be (2^n + 1) as argument + `n` defaults to even output size = 2 * (transformed_dim_size - 1) + +Args: + input (Tensor): the input tensor representing a half-Hermitian signal + n (int, optional): Output signal length. This determines the length of the + output signal. If given, the input will either be zero-padded or trimmed to this + length before computing the real IFFT. + Defaults to even output: ``n=2*(input.size(dim) - 1)``. + dim (int, optional): The dimension along which to take the one dimensional real IFFT. + norm (str, optional): Normalization mode. For the backward transform + (:func:`~torch.fft.irfft`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real IFFT orthonormal) + + Calling the forward transform (:func:`~torch.fft.rfft`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.irfft` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + +Keyword args: + {out} + +Example: + + >>> t = torch.linspace(0, 1, 5) + >>> t + tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]) + >>> T = torch.fft.rfft(t) + >>> T + tensor([ 2.5000+0.0000j, -0.6250+0.8602j, -0.6250+0.2031j]) + + Without specifying the output length to :func:`~torch.fft.irfft`, the output + will not round-trip properly because the input is odd-length: + + >>> torch.fft.irfft(T) + tensor([0.1562, 0.3511, 0.7812, 1.2114]) + + So, it is recommended to always pass the signal length :attr:`n`: + + >>> roundtrip = torch.fft.irfft(T, t.numel()) + >>> torch.testing.assert_close(roundtrip, t, check_stride=False) + +""".format(**common_args)) + +rfft2 = _add_docstr(_fft.fft_rfft2, r""" +rfft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor + +Computes the 2-dimensional discrete Fourier transform of real :attr:`input`. +Equivalent to :func:`~torch.fft.rfftn` but FFTs only the last two dimensions by default. + +The FFT of a real signal is Hermitian-symmetric, ``X[i, j] = conj(X[-i, -j])``, +so the full :func:`~torch.fft.fft2` output contains redundant information. +:func:`~torch.fft.rfft2` instead omits the negative frequencies in the last +dimension. + +Note: + Supports torch.half on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the real FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: last two dimensions. + norm (str, optional): Normalization mode. For the forward transform + (:func:`~torch.fft.rfft2`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real FFT orthonormal) + + Where ``n = prod(s)`` is the logical FFT size. 
+ Calling the backward transform (:func:`~torch.fft.irfft2`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.irfft2` + the exact inverse. + + Default is ``"backward"`` (no normalization). + +Keyword args: + {out} + +Example: + + >>> t = torch.rand(10, 10) + >>> rfft2 = torch.fft.rfft2(t) + >>> rfft2.size() + torch.Size([10, 6]) + + Compared against the full output from :func:`~torch.fft.fft2`, we have all + elements up to the Nyquist frequency. + + >>> fft2 = torch.fft.fft2(t) + >>> torch.testing.assert_close(fft2[..., :6], rfft2, check_stride=False) + + The discrete Fourier transform is separable, so :func:`~torch.fft.rfft2` + here is equivalent to a combination of :func:`~torch.fft.fft` and + :func:`~torch.fft.rfft`: + + >>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0) + >>> torch.testing.assert_close(rfft2, two_ffts, check_stride=False) + +""".format(**common_args)) + +irfft2 = _add_docstr(_fft.fft_irfft2, r""" +irfft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor + +Computes the inverse of :func:`~torch.fft.rfft2`. +Equivalent to :func:`~torch.fft.irfftn` but IFFTs only the last two dimensions by default. + +:attr:`input` is interpreted as a one-sided Hermitian signal in the Fourier +domain, as produced by :func:`~torch.fft.rfft2`. By the Hermitian property, the +output will be real-valued. + +Note: + Some input frequencies must be real-valued to satisfy the Hermitian + property. In these cases the imaginary component will be ignored. + For example, any imaginary component in the zero-frequency term cannot + be represented in a real output and so will always be ignored. + +Note: + The correct interpretation of the Hermitian input depends on the length of + the original data, as given by :attr:`s`. This is because each input shape + could correspond to either an odd or even length signal. By default, the + signal is assumed to be even length and odd signals will not round-trip + properly. So, it is recommended to always pass the signal shape :attr:`s`. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + With default arguments, the size of last dimension should be (2^n + 1) as argument + `s` defaults to even output size = 2 * (last_dim_size - 1) + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the real FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Defaults to even output in the last dimension: + ``s[-1] = 2*(input.size(dim[-1]) - 1)``. + dim (Tuple[int], optional): Dimensions to be transformed. + The last dimension must be the half-Hermitian compressed dimension. + Default: last two dimensions. + norm (str, optional): Normalization mode. For the backward transform + (:func:`~torch.fft.irfft2`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real IFFT orthonormal) + + Where ``n = prod(s)`` is the logical IFFT size. + Calling the forward transform (:func:`~torch.fft.rfft2`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. 
This is required to make :func:`~torch.fft.irfft2` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + +Keyword args: + {out} + +Example: + + >>> t = torch.rand(10, 9) + >>> T = torch.fft.rfft2(t) + + Without specifying the output length to :func:`~torch.fft.irfft2`, the output + will not round-trip properly because the input is odd-length in the last + dimension: + + >>> torch.fft.irfft2(T).size() + torch.Size([10, 8]) + + So, it is recommended to always pass the signal shape :attr:`s`. + + >>> roundtrip = torch.fft.irfft2(T, t.size()) + >>> roundtrip.size() + torch.Size([10, 9]) + >>> torch.testing.assert_close(roundtrip, t, check_stride=False) + +""".format(**common_args)) + +rfftn = _add_docstr(_fft.fft_rfftn, r""" +rfftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor + +Computes the N-dimensional discrete Fourier transform of real :attr:`input`. + +The FFT of a real signal is Hermitian-symmetric, +``X[i_1, ..., i_n] = conj(X[-i_1, ..., -i_n])`` so the full +:func:`~torch.fft.fftn` output contains redundant information. +:func:`~torch.fft.rfftn` instead omits the negative frequencies in the +last dimension. + +Note: + Supports torch.half on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the real FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given. + norm (str, optional): Normalization mode. For the forward transform + (:func:`~torch.fft.rfftn`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real FFT orthonormal) + + Where ``n = prod(s)`` is the logical FFT size. + Calling the backward transform (:func:`~torch.fft.irfftn`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.irfftn` + the exact inverse. + + Default is ``"backward"`` (no normalization). + +Keyword args: + {out} + +Example: + + >>> t = torch.rand(10, 10) + >>> rfftn = torch.fft.rfftn(t) + >>> rfftn.size() + torch.Size([10, 6]) + + Compared against the full output from :func:`~torch.fft.fftn`, we have all + elements up to the Nyquist frequency. + + >>> fftn = torch.fft.fftn(t) + >>> torch.testing.assert_close(fftn[..., :6], rfftn, check_stride=False) + + The discrete Fourier transform is separable, so :func:`~torch.fft.rfftn` + here is equivalent to a combination of :func:`~torch.fft.fft` and + :func:`~torch.fft.rfft`: + + >>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0) + >>> torch.testing.assert_close(rfftn, two_ffts, check_stride=False) + +""".format(**common_args)) + +irfftn = _add_docstr(_fft.fft_irfftn, r""" +irfftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor + +Computes the inverse of :func:`~torch.fft.rfftn`. + +:attr:`input` is interpreted as a one-sided Hermitian signal in the Fourier +domain, as produced by :func:`~torch.fft.rfftn`. By the Hermitian property, the +output will be real-valued. 
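
Because the one-sided format cannot distinguish an odd-length from an even-length original last dimension, :func:`~torch.fft.irfftn` defaults to an even output length. A minimal sketch of the resulting pitfall, assuming ``torch`` is importable (names are illustrative only):

    import torch

    t = torch.rand(4, 5)                      # odd-length last dimension
    T = torch.fft.rfftn(t)                    # one-sided spectrum, shape (4, 3)

    # Without the original shape, the inverse assumes an even last dimension.
    print(torch.fft.irfftn(T).shape)          # torch.Size([4, 4])

    # Passing s restores the original signal (up to floating-point error).
    roundtrip = torch.fft.irfftn(T, s=t.shape)
    torch.testing.assert_close(roundtrip, t)
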
+ +Note: + Some input frequencies must be real-valued to satisfy the Hermitian + property. In these cases the imaginary component will be ignored. + For example, any imaginary component in the zero-frequency term cannot + be represented in a real output and so will always be ignored. + +Note: + The correct interpretation of the Hermitian input depends on the length of + the original data, as given by :attr:`s`. This is because each input shape + could correspond to either an odd or even length signal. By default, the + signal is assumed to be even length and odd signals will not round-trip + properly. So, it is recommended to always pass the signal shape :attr:`s`. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + With default arguments, the size of last dimension should be (2^n + 1) as argument + `s` defaults to even output size = 2 * (last_dim_size - 1) + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the real FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Defaults to even output in the last dimension: + ``s[-1] = 2*(input.size(dim[-1]) - 1)``. + dim (Tuple[int], optional): Dimensions to be transformed. + The last dimension must be the half-Hermitian compressed dimension. + Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given. + norm (str, optional): Normalization mode. For the backward transform + (:func:`~torch.fft.irfftn`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real IFFT orthonormal) + + Where ``n = prod(s)`` is the logical IFFT size. + Calling the forward transform (:func:`~torch.fft.rfftn`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.irfftn` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + +Keyword args: + {out} + +Example: + + >>> t = torch.rand(10, 9) + >>> T = torch.fft.rfftn(t) + + Without specifying the output length to :func:`~torch.fft.irfft`, the output + will not round-trip properly because the input is odd-length in the last + dimension: + + >>> torch.fft.irfftn(T).size() + torch.Size([10, 8]) + + So, it is recommended to always pass the signal shape :attr:`s`. + + >>> roundtrip = torch.fft.irfftn(T, t.size()) + >>> roundtrip.size() + torch.Size([10, 9]) + >>> torch.testing.assert_close(roundtrip, t, check_stride=False) + +""".format(**common_args)) + +hfft = _add_docstr(_fft.fft_hfft, r""" +hfft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor + +Computes the one dimensional discrete Fourier transform of a Hermitian +symmetric :attr:`input` signal. + +Note: + + :func:`~torch.fft.hfft`/:func:`~torch.fft.ihfft` are analogous to + :func:`~torch.fft.rfft`/:func:`~torch.fft.irfft`. The real FFT expects + a real signal in the time-domain and gives a Hermitian symmetry in the + frequency-domain. The Hermitian FFT is the opposite; Hermitian symmetric in + the time-domain and real-valued in the frequency-domain. For this reason, + special care needs to be taken with the length argument :attr:`n`, in the + same way as with :func:`~torch.fft.irfft`. 
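
A compact round-trip sketch of that analogy, assuming ``torch`` is importable (names are illustrative only): a real signal is taken to its one-sided Hermitian representation with :func:`~torch.fft.ihfft` and recovered with :func:`~torch.fft.hfft`, passing the original length because the signal here is odd-length:

    import torch

    x = torch.arange(5, dtype=torch.float32)   # real, odd-length signal
    X = torch.fft.ihfft(x)                     # one-sided Hermitian representation, length 3

    # Passing the original length n lets the odd-length signal round-trip.
    torch.testing.assert_close(torch.fft.hfft(X, n=x.numel()), x)
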
+ +Note: + Because the signal is Hermitian in the time-domain, the result will be + real in the frequency domain. Note that some input frequencies must be + real-valued to satisfy the Hermitian property. In these cases the imaginary + component will be ignored. For example, any imaginary component in + ``input[0]`` would result in one or more complex frequency terms which + cannot be represented in a real output and so will always be ignored. + +Note: + The correct interpretation of the Hermitian input depends on the length of + the original data, as given by :attr:`n`. This is because each input shape + could correspond to either an odd or even length signal. By default, the + signal is assumed to be even length and odd signals will not round-trip + properly. So, it is recommended to always pass the signal length :attr:`n`. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimension. + With default arguments, size of the transformed dimension should be (2^n + 1) as argument + `n` defaults to even output size = 2 * (transformed_dim_size - 1) + +Args: + input (Tensor): the input tensor representing a half-Hermitian signal + n (int, optional): Output signal length. This determines the length of the + real output. If given, the input will either be zero-padded or trimmed to this + length before computing the Hermitian FFT. + Defaults to even output: ``n=2*(input.size(dim) - 1)``. + dim (int, optional): The dimension along which to take the one dimensional Hermitian FFT. + norm (str, optional): Normalization mode. For the forward transform + (:func:`~torch.fft.hfft`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian FFT orthonormal) + + Calling the backward transform (:func:`~torch.fft.ihfft`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.ihfft` + the exact inverse. + + Default is ``"backward"`` (no normalization). + +Keyword args: + {out} + +Example: + + Taking a real-valued frequency signal and bringing it into the time domain + gives Hermitian symmetric output: + + >>> t = torch.linspace(0, 1, 5) + >>> t + tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]) + >>> T = torch.fft.ifft(t) + >>> T + tensor([ 0.5000-0.0000j, -0.1250-0.1720j, -0.1250-0.0406j, -0.1250+0.0406j, + -0.1250+0.1720j]) + + Note that ``T[1] == T[-1].conj()`` and ``T[2] == T[-2].conj()`` is + redundant. We can thus compute the forward transform without considering + negative frequencies: + + >>> torch.fft.hfft(T[:3], n=5) + tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]) + + Like with :func:`~torch.fft.irfft`, the output length must be given in order + to recover an even length output: + + >>> torch.fft.hfft(T[:3]) + tensor([0.1250, 0.2809, 0.6250, 0.9691]) +""".format(**common_args)) + +ihfft = _add_docstr(_fft.fft_ihfft, r""" +ihfft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor + +Computes the inverse of :func:`~torch.fft.hfft`. + +:attr:`input` must be a real-valued signal, interpreted in the Fourier domain. +The IFFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])``. +:func:`~torch.fft.ihfft` represents this in the one-sided form where only the +positive frequencies below the Nyquist frequency are included. 
To compute the +full output, use :func:`~torch.fft.ifft`. + +Note: + Supports torch.half on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimension. + +Args: + input (Tensor): the real input tensor + n (int, optional): Signal length. If given, the input will either be zero-padded + or trimmed to this length before computing the Hermitian IFFT. + dim (int, optional): The dimension along which to take the one dimensional Hermitian IFFT. + norm (str, optional): Normalization mode. For the backward transform + (:func:`~torch.fft.ihfft`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal) + + Calling the forward transform (:func:`~torch.fft.hfft`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.ihfft` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + +Keyword args: + {out} + +Example: + + >>> t = torch.arange(5) + >>> t + tensor([0, 1, 2, 3, 4]) + >>> torch.fft.ihfft(t) + tensor([ 2.0000-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j]) + + Compare against the full output from :func:`~torch.fft.ifft`: + + >>> torch.fft.ifft(t) + tensor([ 2.0000-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j, -0.5000+0.1625j, + -0.5000+0.6882j]) +""".format(**common_args)) + +hfft2 = _add_docstr(_fft.fft_hfft2, r""" +hfft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor + +Computes the 2-dimensional discrete Fourier transform of a Hermitian symmetric +:attr:`input` signal. Equivalent to :func:`~torch.fft.hfftn` but only +transforms the last two dimensions by default. + +:attr:`input` is interpreted as a one-sided Hermitian signal in the time +domain. By the Hermitian property, the Fourier transform will be real-valued. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + With default arguments, the size of last dimension should be (2^n + 1) as argument + `s` defaults to even output size = 2 * (last_dim_size - 1) + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the Hermitian FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Defaults to even output in the last dimension: + ``s[-1] = 2*(input.size(dim[-1]) - 1)``. + dim (Tuple[int], optional): Dimensions to be transformed. + The last dimension must be the half-Hermitian compressed dimension. + Default: last two dimensions. + norm (str, optional): Normalization mode. For the forward transform + (:func:`~torch.fft.hfft2`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian FFT orthonormal) + + Where ``n = prod(s)`` is the logical FFT size. + Calling the backward transform (:func:`~torch.fft.ihfft2`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.ihfft2` + the exact inverse. + + Default is ``"backward"`` (no normalization). 
+ +Keyword args: + {out} + +Example: + + Starting from a real frequency-space signal, we can generate a + Hermitian-symmetric time-domain signal: + >>> T = torch.rand(10, 9) + >>> t = torch.fft.ihfft2(T) + + Without specifying the output length to :func:`~torch.fft.hfftn`, the + output will not round-trip properly because the input is odd-length in the + last dimension: + + >>> torch.fft.hfft2(t).size() + torch.Size([10, 10]) + + So, it is recommended to always pass the signal shape :attr:`s`. + + >>> roundtrip = torch.fft.hfft2(t, T.size()) + >>> roundtrip.size() + torch.Size([10, 9]) + >>> torch.allclose(roundtrip, T) + True + +""".format(**common_args)) + +ihfft2 = _add_docstr(_fft.fft_ihfft2, r""" +ihfft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor + +Computes the 2-dimensional inverse discrete Fourier transform of real +:attr:`input`. Equivalent to :func:`~torch.fft.ihfftn` but transforms only the +two last dimensions by default. + +Note: + Supports torch.half on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the Hermitian IFFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: last two dimensions. + norm (str, optional): Normalization mode. For the backward transform + (:func:`~torch.fft.ihfft2`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian IFFT orthonormal) + + Where ``n = prod(s)`` is the logical IFFT size. + Calling the forward transform (:func:`~torch.fft.hfft2`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.ihfft2` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + +Keyword args: + {out} + +Example: + + >>> T = torch.rand(10, 10) + >>> t = torch.fft.ihfft2(t) + >>> t.size() + torch.Size([10, 6]) + + Compared against the full output from :func:`~torch.fft.ifft2`, the + Hermitian time-space signal takes up only half the space. + + >>> fftn = torch.fft.ifft2(t) + >>> torch.allclose(fftn[..., :6], rfftn) + True + + The discrete Fourier transform is separable, so :func:`~torch.fft.ihfft2` + here is equivalent to a combination of :func:`~torch.fft.ifft` and + :func:`~torch.fft.ihfft`: + + >>> two_ffts = torch.fft.ifft(torch.fft.ihfft(t, dim=1), dim=0) + >>> torch.allclose(t, two_ffts) + True + +""".format(**common_args)) + +hfftn = _add_docstr(_fft.fft_hfftn, r""" +hfftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor + +Computes the n-dimensional discrete Fourier transform of a Hermitian symmetric +:attr:`input` signal. + +:attr:`input` is interpreted as a one-sided Hermitian signal in the time +domain. By the Hermitian property, the Fourier transform will be real-valued. + +Note: + :func:`~torch.fft.hfftn`/:func:`~torch.fft.ihfftn` are analogous to + :func:`~torch.fft.rfftn`/:func:`~torch.fft.irfftn`. The real FFT expects + a real signal in the time-domain and gives Hermitian symmetry in the + frequency-domain. 
The Hermitian FFT is the opposite; Hermitian symmetric in + the time-domain and real-valued in the frequency-domain. For this reason, + special care needs to be taken with the shape argument :attr:`s`, in the + same way as with :func:`~torch.fft.irfftn`. + +Note: + Some input frequencies must be real-valued to satisfy the Hermitian + property. In these cases the imaginary component will be ignored. + For example, any imaginary component in the zero-frequency term cannot + be represented in a real output and so will always be ignored. + +Note: + The correct interpretation of the Hermitian input depends on the length of + the original data, as given by :attr:`s`. This is because each input shape + could correspond to either an odd or even length signal. By default, the + signal is assumed to be even length and odd signals will not round-trip + properly. It is recommended to always pass the signal shape :attr:`s`. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + With default arguments, the size of last dimension should be (2^n + 1) as argument + `s` defaults to even output size = 2 * (last_dim_size - 1) + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the real FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Defaults to even output in the last dimension: + ``s[-1] = 2*(input.size(dim[-1]) - 1)``. + dim (Tuple[int], optional): Dimensions to be transformed. + The last dimension must be the half-Hermitian compressed dimension. + Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given. + norm (str, optional): Normalization mode. For the forward transform + (:func:`~torch.fft.hfftn`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian FFT orthonormal) + + Where ``n = prod(s)`` is the logical FFT size. + Calling the backward transform (:func:`~torch.fft.ihfftn`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.ihfftn` + the exact inverse. + + Default is ``"backward"`` (no normalization). + +Keyword args: + {out} + +Example: + + Starting from a real frequency-space signal, we can generate a + Hermitian-symmetric time-domain signal: + >>> T = torch.rand(10, 9) + >>> t = torch.fft.ihfftn(T) + + Without specifying the output length to :func:`~torch.fft.hfftn`, the + output will not round-trip properly because the input is odd-length in the + last dimension: + + >>> torch.fft.hfftn(t).size() + torch.Size([10, 10]) + + So, it is recommended to always pass the signal shape :attr:`s`. + + >>> roundtrip = torch.fft.hfftn(t, T.size()) + >>> roundtrip.size() + torch.Size([10, 9]) + >>> torch.allclose(roundtrip, T) + True + +""".format(**common_args)) + +ihfftn = _add_docstr(_fft.fft_ihfftn, r""" +ihfftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor + +Computes the N-dimensional inverse discrete Fourier transform of real :attr:`input`. + +:attr:`input` must be a real-valued signal, interpreted in the Fourier domain. +The n-dimensional IFFT of a real signal is Hermitian-symmetric, +``X[i, j, ...] 
= conj(X[-i, -j, ...])``. :func:`~torch.fft.ihfftn` represents +this in the one-sided form where only the positive frequencies below the +Nyquist frequency are included in the last signal dimension. To compute the +full output, use :func:`~torch.fft.ifftn`. + +Note: + Supports torch.half on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the Hermitian IFFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given. + norm (str, optional): Normalization mode. For the backward transform + (:func:`~torch.fft.ihfftn`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian IFFT orthonormal) + + Where ``n = prod(s)`` is the logical IFFT size. + Calling the forward transform (:func:`~torch.fft.hfftn`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.ihfftn` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + +Keyword args: + {out} + +Example: + + >>> T = torch.rand(10, 10) + >>> ihfftn = torch.fft.ihfftn(T) + >>> ihfftn.size() + torch.Size([10, 6]) + + Compared against the full output from :func:`~torch.fft.ifftn`, we have all + elements up to the Nyquist frequency. + + >>> ifftn = torch.fft.ifftn(t) + >>> torch.allclose(ifftn[..., :6], ihfftn) + True + + The discrete Fourier transform is separable, so :func:`~torch.fft.ihfftn` + here is equivalent to a combination of :func:`~torch.fft.ihfft` and + :func:`~torch.fft.ifft`: + + >>> two_iffts = torch.fft.ifft(torch.fft.ihfft(t, dim=1), dim=0) + >>> torch.allclose(ihfftn, two_iffts) + True + +""".format(**common_args)) + +fftfreq = _add_docstr(_fft.fft_fftfreq, r""" +fftfreq(n, d=1.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + +Computes the discrete Fourier Transform sample frequencies for a signal of size :attr:`n`. + +Note: + By convention, :func:`~torch.fft.fft` returns positive frequency terms + first, followed by the negative frequencies in reverse order, so that + ``f[-i]`` for all :math:`0 < i \leq n/2`` in Python gives the negative + frequency terms. For an FFT of length :attr:`n` and with inputs spaced in + length unit :attr:`d`, the frequencies are:: + + f = [0, 1, ..., (n - 1) // 2, -(n // 2), ..., -1] / (d * n) + +Note: + For even lengths, the Nyquist frequency at ``f[n/2]`` can be thought of as + either negative or positive. :func:`~torch.fft.fftfreq` follows NumPy's + convention of taking it to be negative. + +Args: + n (int): the FFT length + d (float, optional): The sampling length scale. + The spacing between individual samples of the FFT input. + The default assumes unit spacing, dividing that result by the actual + spacing gives the result in physical frequency units. 
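
To make the ``d`` argument concrete, a small sketch that recovers the frequency of a tone in physical units, assuming ``torch`` is importable (signal parameters are illustrative only):

    import torch

    fs = 100.0                                   # sample rate in Hz
    n = 100
    t = torch.arange(n) / fs                     # one second of samples
    x = torch.sin(2 * torch.pi * 5.0 * t)        # 5 Hz tone

    X = torch.fft.fft(x)
    freqs = torch.fft.fftfreq(n, d=1.0 / fs)     # bin frequencies in Hz

    # The strongest bin sits at +/-5 Hz for this real-valued tone.
    print(freqs[X.abs().argmax()].abs())         # tensor(5.)
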
+ +Keyword Args: + {out} + {dtype} + {layout} + {device} + {requires_grad} + +Example: + + >>> torch.fft.fftfreq(5) + tensor([ 0.0000, 0.2000, 0.4000, -0.4000, -0.2000]) + + For even input, we can see the Nyquist frequency at ``f[2]`` is given as + negative: + + >>> torch.fft.fftfreq(4) + tensor([ 0.0000, 0.2500, -0.5000, -0.2500]) + +""".format(**factory_common_args)) + +rfftfreq = _add_docstr(_fft.fft_rfftfreq, r""" +rfftfreq(n, d=1.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + +Computes the sample frequencies for :func:`~torch.fft.rfft` with a signal of size :attr:`n`. + +Note: + :func:`~torch.fft.rfft` returns Hermitian one-sided output, so only the + positive frequency terms are returned. For a real FFT of length :attr:`n` + and with inputs spaced in length unit :attr:`d`, the frequencies are:: + + f = torch.arange((n + 1) // 2) / (d * n) + +Note: + For even lengths, the Nyquist frequency at ``f[n/2]`` can be thought of as + either negative or positive. Unlike :func:`~torch.fft.fftfreq`, + :func:`~torch.fft.rfftfreq` always returns it as positive. + +Args: + n (int): the real FFT length + d (float, optional): The sampling length scale. + The spacing between individual samples of the FFT input. + The default assumes unit spacing, dividing that result by the actual + spacing gives the result in physical frequency units. + +Keyword Args: + {out} + {dtype} + {layout} + {device} + {requires_grad} + +Example: + + >>> torch.fft.rfftfreq(5) + tensor([0.0000, 0.2000, 0.4000]) + + >>> torch.fft.rfftfreq(4) + tensor([0.0000, 0.2500, 0.5000]) + + Compared to the output from :func:`~torch.fft.fftfreq`, we see that the + Nyquist frequency at ``f[2]`` has changed sign: + >>> torch.fft.fftfreq(4) + tensor([ 0.0000, 0.2500, -0.5000, -0.2500]) + +""".format(**factory_common_args)) + +fftshift = _add_docstr(_fft.fft_fftshift, r""" +fftshift(input, dim=None) -> Tensor + +Reorders n-dimensional FFT data, as provided by :func:`~torch.fft.fftn`, to have +negative frequency terms first. + +This performs a periodic shift of n-dimensional data such that the origin +``(0, ..., 0)`` is moved to the center of the tensor. Specifically, to +``input.shape[dim] // 2`` in each selected dimension. + +Note: + By convention, the FFT returns positive frequency terms first, followed by + the negative frequencies in reverse order, so that ``f[-i]`` for all + :math:`0 < i \leq n/2` in Python gives the negative frequency terms. + :func:`~torch.fft.fftshift` rearranges all frequencies into ascending order + from negative to positive with the zero-frequency term in the center. + +Note: + For even lengths, the Nyquist frequency at ``f[n/2]`` can be thought of as + either negative or positive. :func:`~torch.fft.fftshift` always puts the + Nyquist term at the 0-index. This is the same convention used by + :func:`~torch.fft.fftfreq`. + +Args: + input (Tensor): the tensor in FFT order + dim (int, Tuple[int], optional): The dimensions to rearrange. + Only dimensions specified here will be rearranged, any other dimensions + will be left in their original order. + Default: All dimensions of :attr:`input`. + +Example: + + >>> f = torch.fft.fftfreq(4) + >>> f + tensor([ 0.0000, 0.2500, -0.5000, -0.2500]) + + >>> torch.fft.fftshift(f) + tensor([-0.5000, -0.2500, 0.0000, 0.2500]) + + Also notice that the Nyquist frequency term at ``f[2]`` was moved to the + beginning of the tensor. 
+ + This also works for multi-dimensional transforms: + + >>> x = torch.fft.fftfreq(5, d=1/5) + 0.1 * torch.fft.fftfreq(5, d=1/5).unsqueeze(1) + >>> x + tensor([[ 0.0000, 1.0000, 2.0000, -2.0000, -1.0000], + [ 0.1000, 1.1000, 2.1000, -1.9000, -0.9000], + [ 0.2000, 1.2000, 2.2000, -1.8000, -0.8000], + [-0.2000, 0.8000, 1.8000, -2.2000, -1.2000], + [-0.1000, 0.9000, 1.9000, -2.1000, -1.1000]]) + + >>> torch.fft.fftshift(x) + tensor([[-2.2000, -1.2000, -0.2000, 0.8000, 1.8000], + [-2.1000, -1.1000, -0.1000, 0.9000, 1.9000], + [-2.0000, -1.0000, 0.0000, 1.0000, 2.0000], + [-1.9000, -0.9000, 0.1000, 1.1000, 2.1000], + [-1.8000, -0.8000, 0.2000, 1.2000, 2.2000]]) + + :func:`~torch.fft.fftshift` can also be useful for spatial data. If our + data is defined on a centered grid (``[-(N//2), (N-1)//2]``) then we can + use the standard FFT defined on an uncentered grid (``[0, N)``) by first + applying an :func:`~torch.fft.ifftshift`. + + >>> x_centered = torch.arange(-5, 5) + >>> x_uncentered = torch.fft.ifftshift(x_centered) + >>> fft_uncentered = torch.fft.fft(x_uncentered) + + Similarly, we can convert the frequency domain components to centered + convention by applying :func:`~torch.fft.fftshift`. + + >>> fft_centered = torch.fft.fftshift(fft_uncentered) + + The inverse transform, from centered Fourier space back to centered spatial + data, can be performed by applying the inverse shifts in reverse order: + + >>> x_centered_2 = torch.fft.fftshift(torch.fft.ifft(torch.fft.ifftshift(fft_centered))) + >>> torch.testing.assert_close(x_centered.to(torch.complex64), x_centered_2, check_stride=False) + + +""") + +ifftshift = _add_docstr(_fft.fft_ifftshift, r""" +ifftshift(input, dim=None) -> Tensor + +Inverse of :func:`~torch.fft.fftshift`. + +Args: + input (Tensor): the tensor in FFT order + dim (int, Tuple[int], optional): The dimensions to rearrange. + Only dimensions specified here will be rearranged, any other dimensions + will be left in their original order. + Default: All dimensions of :attr:`input`. 
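
For even lengths :func:`~torch.fft.fftshift` is its own inverse, but for odd lengths the two shifts differ by one sample, so only :func:`~torch.fft.ifftshift` undoes :func:`~torch.fft.fftshift` in general. A minimal sketch, assuming ``torch`` is importable:

    import torch

    f = torch.fft.fftfreq(5)                     # odd length
    shifted = torch.fft.fftshift(f)

    print(torch.equal(torch.fft.ifftshift(shifted), f))   # True
    print(torch.equal(torch.fft.fftshift(shifted), f))    # False: off by one sample
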
+ +Example: + + >>> f = torch.fft.fftfreq(5) + >>> f + tensor([ 0.0000, 0.2000, 0.4000, -0.4000, -0.2000]) + + A round-trip through :func:`~torch.fft.fftshift` and + :func:`~torch.fft.ifftshift` gives the same result: + + >>> shifted = torch.fft.fftshift(f) + >>> torch.fft.ifftshift(shifted) + tensor([ 0.0000, 0.2000, 0.4000, -0.4000, -0.2000]) + +""") diff --git a/lib/python3.10/site-packages/torch/func/__init__.py b/lib/python3.10/site-packages/torch/func/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dd0786456dec0f2346e56e76fbfa18b12acd49d5 --- /dev/null +++ b/lib/python3.10/site-packages/torch/func/__init__.py @@ -0,0 +1,13 @@ +from torch._functorch.eager_transforms import ( + vjp, + jvp, + jacrev, + jacfwd, + hessian, + functionalize, + linearize +) +from torch._functorch.apis import grad, grad_and_value +from torch._functorch.functional_call import functional_call, stack_module_state +from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_ +from torch._functorch.apis import vmap diff --git a/lib/python3.10/site-packages/torch/futures/__init__.py b/lib/python3.10/site-packages/torch/futures/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e1623c44f193d74de2a8d099699cd3439a1f1227 --- /dev/null +++ b/lib/python3.10/site-packages/torch/futures/__init__.py @@ -0,0 +1,319 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import cast, Callable, Generic, List, Optional, Type, TypeVar, Union + +import torch + +__all__ = ['Future', 'collect_all', 'wait_all'] + +T = TypeVar("T") +S = TypeVar("S") + + +class _PyFutureMeta(type(torch._C.Future), type(Generic)): # type: ignore[misc, no-redef] + pass + + +class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta): + r""" + Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous + execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It + also exposes a set of APIs to add callback functions and set results. + + .. warning:: GPU support is a beta feature, subject to changes. + """ + + def __init__(self, *, devices: Optional[List[Union[int, str, torch.device]]] = None): + r""" + Create an empty unset ``Future``. If the future is intended to hold + values containing CUDA tensors, (a superset of) their CUDA devices must + be specified at construction. (This is only supported if + ``torch.cuda.is_available()`` returns ``True``). This is needed to + ensure proper CUDA stream synchronization. The child futures, returned + by the ``then`` method, will inherit these devices. + + Args: + devices(``List[Union[int, str, torch.device]]``, optional): the set + of devices on which tensors contained in this future's value are + allowed to reside and on which callbacks are allowed to operate. + """ + if devices is None: + devices = [] + super().__init__([torch.device(d) for d in devices]) + + def done(self) -> bool: + r""" + Return ``True`` if this ``Future`` is done. A ``Future`` is done if it + has a result or an exception. + + If the value contains tensors that reside on GPUs, ``Future.done()`` + will return ``True`` even if the asynchronous kernels that are + populating those tensors haven't yet completed running on the device, + because at such stage the result is already usable, provided one + performs the appropriate synchronizations (see :meth:`wait`). + """ + return super().done() + + def wait(self) -> T: + r""" + Block until the value of this ``Future`` is ready. 
+ + If the value contains tensors that reside on GPUs, then an additional + synchronization is performed with the kernels (executing on the device) + which may be asynchronously populating those tensors. Such sync is + non-blocking, which means that ``wait()`` will insert the necessary + instructions in the current streams to ensure that further operations + enqueued on those streams will be properly scheduled after the async + kernels but, once that is done, ``wait()`` will return, even if those + kernels are still running. No further synchronization is required when + accessing and using the values, as long as one doesn't change streams. + + Returns: + The value held by this ``Future``. If the function (callback or RPC) + creating the value has thrown an error, this ``wait`` method will + also throw an error. + """ + return super().wait() + + def value(self) -> T: + r""" + Obtain the value of an already-completed future. + + This method should only be called after a call to :meth:`wait` has + completed, or inside a callback function passed to :meth:`then`. In + other cases this ``Future`` may not yet hold a value and calling + ``value()`` could fail. + + If the value contains tensors that reside on GPUs, then this method will + *not* perform any additional synchronization. This should be done + beforehand, separately, through a call to :meth:`wait` (except within + callbacks, for which it's already being taken care of by :meth:`then`). + + Returns: + The value held by this ``Future``. If the function (callback or RPC) + creating the value has thrown an error, this ``value()`` method will + also throw an error. + """ + return super().value() + + def then(self, callback: Callable[[Future[T]], S]) -> Future[S]: + r""" + Append the given callback function to this ``Future``, which will be run + when the ``Future`` is completed. Multiple callbacks can be added to + the same ``Future``, but the order in which they will be executed cannot + be guaranteed (to enforce a certain order consider chaining: + ``fut.then(cb1).then(cb2)``). The callback must take one argument, which + is the reference to this ``Future``. The callback function can use the + :meth:`value` method to get the value. Note that if this ``Future`` is + already completed, the given callback will be run immediately inline. + + If the ``Future``'s value contains tensors that reside on GPUs, the + callback might be invoked while the async kernels that are populating + those tensors haven't yet finished executing on the device. However, the + callback will be invoked with some dedicated streams set as current + (fetched from a global pool) which will be synchronized with those + kernels. Hence any operation performed by the callback on these tensors + will be scheduled on the device after the kernels complete. In other + words, as long as the callback doesn't switch streams, it can safely + manipulate the result without any additional synchronization. This is + similar to the non-blocking behavior of :meth:`wait`. + + Similarly, if the callback returns a value that contains tensors that + reside on a GPU, it can do so even if the kernels that are producing + these tensors are still running on the device, as long as the callback + didn't change streams during its execution. If one wants to change + streams, one must be careful to re-synchronize them with the original + streams, that is, those that were current when the callback was invoked. 
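
A minimal CPU-only sketch of chaining, assuming ``torch`` is importable (the names are illustrative only): each stage receives the completed future of the previous stage, and its return value becomes the value of the future returned by ``then``:

    import torch

    fut = torch.futures.Future()
    doubled = fut.then(lambda f: f.wait() * 2)
    described = doubled.then(lambda f: f"result is {f.wait()}")

    fut.set_result(21)
    print(described.wait())    # result is 42
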
+ + Args: + callback(``Callable``): a ``Callable`` that takes this ``Future`` as + the only argument. + + Returns: + A new ``Future`` object that holds the return value of the + ``callback`` and will be marked as completed when the given + ``callback`` finishes. + + .. note:: Note that if the callback function throws, either + through the original future being completed with an exception and + calling ``fut.wait()``, or through other code in the callback, the + future returned by ``then`` will be marked appropriately with the + encountered error. However, if this callback later completes + additional futures, those futures are not marked as completed with + an error and the user is responsible for handling completion/waiting + on those futures independently. + + Example:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES) + >>> def callback(fut): + ... print(f"RPC return value is {fut.wait()}.") + >>> fut = torch.futures.Future() + >>> # The inserted callback will print the return value when + >>> # receiving the response from "worker1" + >>> cb_fut = fut.then(callback) + >>> chain_cb_fut = cb_fut.then( + ... lambda x : print(f"Chained cb done. {x.wait()}") + ... ) + >>> fut.set_result(5) + RPC return value is 5. + Chained cb done. None + """ + return cast(Future[S], super().then(callback)) + + def add_done_callback(self, callback: Callable[[Future[T]], None]) -> None: + r""" + Append the given callback function to this ``Future``, which will be run + when the ``Future`` is completed. Multiple callbacks can be added to + the same ``Future``, but the order in which they will be executed cannot + be guaranteed. The callback must take one argument, which is the + reference to this ``Future``. The callback function can use the + :meth:`value` method to get the value. Note that if this ``Future`` is + already completed, the given callback will be run inline. + + We recommend that you use the :meth:`then` method as it provides a way + to synchronize after your callback has completed. ``add_done_callback`` + can be cheaper if your callback does not return anything. But both + :meth:`then` and ``add_done_callback`` use the same callback + registration API under the hood. + + With respect to GPU tensors, this method behaves in the same way as + :meth:`then`. + + Args: + callback(``Future``): a ``Callable`` that takes in one argument, + which is the reference to this ``Future``. + + .. note:: Note that if the callback function throws, either + through the original future being completed with an exception and + calling ``fut.wait()``, or through other code in the callback, + error handling must be carefully taken care of. For example, if + this callback later completes additional futures, those futures are + not marked as completed with an error and the user is responsible + for handling completion/waiting on those futures independently. + + Example:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES) + >>> def callback(fut): + ... print("This will run after the future has finished.") + ... print(fut.wait()) + >>> fut = torch.futures.Future() + >>> fut.add_done_callback(callback) + >>> fut.set_result(5) + This will run after the future has finished. + 5 + """ + super().add_done_callback(callback) + + def set_result(self, result: T) -> None: + r""" + Set the result for this ``Future``, which will mark this ``Future`` as + completed and trigger all attached callbacks. Note that a ``Future`` + cannot be marked completed twice. 
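
A small CPU-only sketch of that behaviour, assuming ``torch`` is importable: completing the future runs the callback registered earlier with ``add_done_callback``:

    import torch

    fut = torch.futures.Future()
    fut.add_done_callback(lambda f: print("done with value:", f.value()))

    fut.set_result(42)         # prints: done with value: 42
    print(fut.wait())          # 42
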
+ + If the result contains tensors that reside on GPUs, this method can be + called even if the asynchronous kernels that are populating those + tensors haven't yet completed running on the device, provided that the + streams on which those kernels were enqueued are set as the current ones + when this method is called. Put simply, it's safe to call this method + immediately after launching those kernels, without any additional + synchronization, as long as one doesn't change streams in between. This + method will record events on all the relevant current streams and will + use them to ensure proper scheduling for all the consumers of this + ``Future``. + + Args: + result (object): the result object of this ``Future``. + + Example:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES) + >>> import threading + >>> import time + >>> def slow_set_future(fut, value): + ... time.sleep(0.5) + ... fut.set_result(value) + >>> fut = torch.futures.Future() + >>> t = threading.Thread( + ... target=slow_set_future, + ... args=(fut, torch.ones(2) * 3) + ... ) + >>> t.start() + >>> print(fut.wait()) + tensor([3., 3.]) + >>> t.join() + """ + super().set_result(result) + + def set_exception(self, result: T) -> None: + r""" + Set an exception for this ``Future``, which will mark this ``Future`` as + completed with an error and trigger all attached callbacks. Note that + when calling wait()/value() on this ``Future``, the exception set here + will be raised inline. + + Args: + result (BaseException): the exception for this ``Future``. + + Example:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES) + >>> fut = torch.futures.Future() + >>> fut.set_exception(ValueError("foo")) + >>> fut.wait() + Traceback (most recent call last): + ... + ValueError: foo + """ + assert isinstance(result, Exception), f"{result} is of type {type(result)}, not an Exception." + + def raise_error(fut_result): + raise fut_result + + super()._set_unwrap_func(raise_error) + self.set_result(result) # type: ignore[arg-type] + + +def collect_all(futures: List[Future]) -> Future[List[Future]]: + r""" + Collects the provided :class:`~torch.futures.Future` objects into a single + combined :class:`~torch.futures.Future` that is completed when all of the + sub-futures are completed. + + Args: + futures (list): a list of :class:`~torch.futures.Future` objects. + + Returns: + Returns a :class:`~torch.futures.Future` object to a list of the passed + in Futures. + + Example:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES) + >>> fut0 = torch.futures.Future() + >>> fut1 = torch.futures.Future() + >>> fut = torch.futures.collect_all([fut0, fut1]) + >>> fut0.set_result(0) + >>> fut1.set_result(1) + >>> fut_list = fut.wait() + >>> print(f"fut0 result = {fut_list[0].wait()}") + fut0 result = 0 + >>> print(f"fut1 result = {fut_list[1].wait()}") + fut1 result = 1 + """ + return cast(Future[List[Future]], torch._C._collect_all(cast(List[torch._C.Future], futures))) + + +def wait_all(futures: List[Future]) -> List: + r""" + Waits for all provided futures to be complete, and returns + the list of completed values. If any of the futures encounters an error, + the method will exit early and report the error not waiting for other + futures to complete. + + Args: + futures (list): a list of :class:`~torch.futures.Future` object. + + Returns: + A list of the completed :class:`~torch.futures.Future` results. This + method will throw an error if ``wait`` on any + :class:`~torch.futures.Future` throws. 
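+
+    Example (a minimal illustration with already-completed futures)::
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
+        >>> fut0 = torch.futures.Future()
+        >>> fut1 = torch.futures.Future()
+        >>> fut0.set_result(0)
+        >>> fut1.set_result(1)
+        >>> torch.futures.wait_all([fut0, fut1])
+        [0, 1]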
+ """ + return [fut.wait() for fut in torch._C._collect_all(cast(List[torch._C.Future], futures)).wait()] diff --git a/lib/python3.10/site-packages/torch/fx/__init__.py b/lib/python3.10/site-packages/torch/fx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dd04cdd09d7fa1a877e64d114db6e037849d7ba1 --- /dev/null +++ b/lib/python3.10/site-packages/torch/fx/__init__.py @@ -0,0 +1,89 @@ +r''' +FX is a toolkit for developers to use to transform ``nn.Module`` +instances. FX consists of three main components: a **symbolic tracer,** +an **intermediate representation**, and **Python code generation**. A +demonstration of these components in action: + +:: + + import torch + # Simple module for demonstration + class MyModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + + def forward(self, x): + return self.linear(x + self.param).clamp(min=0.0, max=1.0) + + module = MyModule() + + from torch.fx import symbolic_trace + # Symbolic tracing frontend - captures the semantics of the module + symbolic_traced : torch.fx.GraphModule = symbolic_trace(module) + + # High-level intermediate representation (IR) - Graph representation + print(symbolic_traced.graph) + """ + graph(): + %x : [num_users=1] = placeholder[target=x] + %param : [num_users=1] = get_attr[target=param] + %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {}) + %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {}) + %clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0}) + return clamp + """ + + # Code generation - valid Python code + print(symbolic_traced.code) + """ + def forward(self, x): + param = self.param + add = x + param; x = param = None + linear = self.linear(add); add = None + clamp = linear.clamp(min = 0.0, max = 1.0); linear = None + return clamp + """ + +The **symbolic tracer** performs "symbolic execution" of the Python +code. It feeds fake values, called Proxies, through the code. Operations +on theses Proxies are recorded. More information about symbolic tracing +can be found in the :func:`symbolic_trace` and :class:`Tracer` +documentation. + +The **intermediate representation** is the container for the operations +that were recorded during symbolic tracing. It consists of a list of +Nodes that represent function inputs, callsites (to functions, methods, +or :class:`torch.nn.Module` instances), and return values. More information +about the IR can be found in the documentation for :class:`Graph`. The +IR is the format on which transformations are applied. + +**Python code generation** is what makes FX a Python-to-Python (or +Module-to-Module) transformation toolkit. For each Graph IR, we can +create valid Python code matching the Graph's semantics. This +functionality is wrapped up in :class:`GraphModule`, which is a +:class:`torch.nn.Module` instance that holds a :class:`Graph` as well as a +``forward`` method generated from the Graph. + +Taken together, this pipeline of components (symbolic tracing -> +intermediate representation -> transforms -> Python code generation) +constitutes the Python-to-Python transformation pipeline of FX. In +addition, these components can be used separately. For example, +symbolic tracing can be used in isolation to capture a form of +the code for analysis (and not transformation) purposes. 
Code +generation can be used for programmatically generating models, for +example from a config file. There are many uses for FX! + +Several example transformations can be found at the +`examples `__ +repository. +''' + +from .graph_module import GraphModule +from ._symbolic_trace import symbolic_trace, Tracer, wrap, PH, ProxyableClassMeta +from .graph import Graph, CodeGen +from .node import Node, map_arg, has_side_effect +from .proxy import Proxy +from .interpreter import Interpreter as Interpreter, Transformer as Transformer +from .subgraph_rewriter import replace_pattern diff --git a/lib/python3.10/site-packages/torch/fx/__init__.pyi b/lib/python3.10/site-packages/torch/fx/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..0a263dfc5071ddee675ec517f2cbac13b51ce9e8 --- /dev/null +++ b/lib/python3.10/site-packages/torch/fx/__init__.pyi @@ -0,0 +1,15 @@ +from torch.fx._symbolic_trace import ( + symbolic_trace as symbolic_trace, + Tracer as Tracer, + wrap as wrap, +) +from torch.fx.graph import Graph as Graph +from torch.fx.graph_module import GraphModule as GraphModule +from torch.fx.interpreter import Interpreter as Interpreter, Transformer as Transformer +from torch.fx.node import ( + has_side_effect as has_side_effect, + map_arg as map_arg, + Node as Node, +) +from torch.fx.proxy import Proxy as Proxy +from torch.fx.subgraph_rewriter import replace_pattern as replace_pattern diff --git a/lib/python3.10/site-packages/torch/fx/_compatibility.py b/lib/python3.10/site-packages/torch/fx/_compatibility.py new file mode 100644 index 0000000000000000000000000000000000000000..27c1e600036df46ced25e722ebf493c2131beda0 --- /dev/null +++ b/lib/python3.10/site-packages/torch/fx/_compatibility.py @@ -0,0 +1,36 @@ +from typing import Any, Dict, Callable, TypeVar +import textwrap + +_BACK_COMPAT_OBJECTS : Dict[Any, None] = {} +_MARKED_WITH_COMPATIBILITY : Dict[Any, None] = {} + +_T = TypeVar("_T") + +def compatibility(is_backward_compatible: bool) -> Callable[[_T], _T]: + if is_backward_compatible: + + def mark_back_compat(fn: _T) -> _T: + docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '') + docstring += """ +.. note:: + Backwards-compatibility for this API is guaranteed. +""" + fn.__doc__ = docstring + _BACK_COMPAT_OBJECTS.setdefault(fn) + _MARKED_WITH_COMPATIBILITY.setdefault(fn) + return fn + + return mark_back_compat + else: + + def mark_not_back_compat(fn: _T) -> _T: + docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '') + docstring += """ +.. warning:: + This API is experimental and is *NOT* backward-compatible. 
+""" + fn.__doc__ = docstring + _MARKED_WITH_COMPATIBILITY.setdefault(fn) + return fn + + return mark_not_back_compat diff --git a/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py b/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py new file mode 100644 index 0000000000000000000000000000000000000000..2a14fce3782e9a33a0d2396f2d57fe24d7730ef2 --- /dev/null +++ b/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py @@ -0,0 +1,185 @@ +# mypy: allow-untyped-defs +from contextlib import contextmanager + +from torch.fx import GraphModule +from torch.fx.graph_module import ( + _format_import_block, + reduce_graph_module, + reduce_package_graph_module, +) +from torch.package import PackageExporter, sys_importer + +from ._compatibility import compatibility + + +_use_lazy_graph_module_flag = False +_force_skip_lazy_graph_module_flag = False + + +@compatibility(is_backward_compatible=False) +@contextmanager +def _force_skip_lazy_graph_module(): + """ + Skip using lazy graph module disregarding the setting of _use_lazy_graph_module. + Use to skip _LazyGraphModule when testing inductor torchscript related backend. + + torch.jit.script a _LazyGraphModule results in following error: + https://gist.github.com/shunting314/5143654c8084aed84ecd19b818258a69 + """ + try: + global _force_skip_lazy_graph_module_flag + prior = _force_skip_lazy_graph_module_flag + _force_skip_lazy_graph_module_flag = True + yield + finally: + _force_skip_lazy_graph_module_flag = prior + + +@compatibility(is_backward_compatible=False) +@contextmanager +def _use_lazy_graph_module(should_use: bool): + try: + global _use_lazy_graph_module_flag + prior = _use_lazy_graph_module_flag + _use_lazy_graph_module_flag = ( + should_use and not _force_skip_lazy_graph_module_flag + ) + yield + finally: + _use_lazy_graph_module_flag = prior + + +@compatibility(is_backward_compatible=False) +def _get_graph_module_cls(): + return _LazyGraphModule if _use_lazy_graph_module_flag else GraphModule + + +def _make_graph_module(*args, graph_module_cls=None, **kwargs): + if graph_module_cls is None: + graph_module_cls = _get_graph_module_cls() + + return graph_module_cls(*args, **kwargs) + + +@compatibility(is_backward_compatible=False) +class _LazyGraphModule(GraphModule): + """ + The main difference between _LazyGraphModule and GraphModule is how recompile happens. + GraphModule will do a 'recompile' call to generate python code and the forward method when it's + constructed. Later on if the graph get updated, recompile method can be called again to refresh + the saved python code and forward method. + + However in some cases especially in inductor, the recompilation can be a waste since we never + check the python code for the graph module or call its forward method. A few more concreate + examples regarding pattern matching fx passes in inductor: + 1. some passes will update the graph to be compiled and then call recompile on the GraphModule. + 2. some passes will trace small pattern function to search it in the graph being compiled and + replace the match with the traced graph of a replacement function. The pattern graph and + replacement graph are quite small but there are large amount of them. Doing GraphModule.recompile + for them in GraphModule.__init__ is also a waste of time. + + However simply skip calling GraphModule.recompile in these scenarios is also dangeruous. + People may want to check the python code or call the GraphModule's forward method for debugging purposes. 
+ + The way _LazyGraphModule solves it is, we override the recompile method to just mark the + need for recompilation but does not do the actual recompilation. Later on if people really + access the compiled python code or call the GraphModule's forward method, we do the real + recompilation. + """ + + @classmethod + def from_graphmodule(cls, gm: GraphModule): + if isinstance(gm, _LazyGraphModule): + return gm + else: + return _LazyGraphModule(gm, gm.graph) + + @staticmethod + def force_recompile(gm): + """ + Sometimes we need force a recompile as a workaround + - we want to do the real recompilation before symbolic_trace to avoid error: + https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259 + """ + if isinstance(gm, _LazyGraphModule): + gm.real_recompile() + + def real_recompile(self): + if self._needs_recompile(): + self._real_recompile() + + @classmethod + def _needs_recompile(cls): + return cls.forward is cls._lazy_forward + + def _lazy_forward(self, *args, **kwargs): + # Call self.real_recompile() rather than self._real_recompile() here. + # The _lazy_forward method may be saved and call repeatedly. + # Calling self.real_recompile can make sure we skip recompilation if + # we have already done so. + self.real_recompile() + assert not self._needs_recompile() + + # call `__call__` rather than 'forward' since recompilation may + # install a wrapper for `__call__` to provide a customized error + # message. + return self(*args, **kwargs) + + forward = _lazy_forward + + # TODO: we shold handle __reduce_deploy__ the same way as __reduce_package__, + # or __reduce__ by calling _real_recompile. But I don't find a good way + # to test __reduce_deploy__ out. Also it's very unlikely that LazyGraphModule + # will be used in torch::deploy. So it's skipped for now. + + def __reduce_package__(self, exporter: PackageExporter): + """ + Follow GraphModule.__reduce__ but call 'self._real_recompile' rather + than 'self.recompile' since for a _LazyGraphModule, self.recompile just + mark the need of recompilation and does not return the PythonCode object. + """ + python_code = self._real_recompile() + dict_without_graph = self.__dict__.copy() + dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__ + del dict_without_graph["_graph"] + + generated_module_name = f"fx-generated._{exporter.get_unique_id()}" + import_block = _format_import_block(python_code.globals, exporter.importer) + module_code = import_block + self.code + exporter.save_source_string(generated_module_name, module_code) + return ( + reduce_package_graph_module, + (dict_without_graph, generated_module_name), + ) + + def __reduce__(self): + """ + Follow GraphModule.__reduce__ but call 'self._real_recompile' rather + than 'self.recompile' since for a _LazyGraphModule, self.recompile just + mark the need of recompilation and does not return the PythonCode object. + """ + python_code = self._real_recompile() + dict_without_graph = self.__dict__.copy() + import_block = _format_import_block(python_code.globals, sys_importer) + del dict_without_graph["_graph"] + return (reduce_graph_module, (dict_without_graph, import_block)) + + def _real_recompile(self): + return super().recompile() + + @classmethod + def recompile(cls): + cls.forward = cls._lazy_forward + + @property + def code(self) -> str: + self.real_recompile() + return super().code + + def __str__(self) -> str: + """ + str(GraphModule) will access the _code attribute. Make sure recompile + happens so _code attribute is available. 
+ """ + self.real_recompile() + return super().__str__() diff --git a/lib/python3.10/site-packages/torch/fx/_pytree.py b/lib/python3.10/site-packages/torch/fx/_pytree.py new file mode 100644 index 0000000000000000000000000000000000000000..2ccbfdf048c943e811cd7d9fc092a81d2c30bcea --- /dev/null +++ b/lib/python3.10/site-packages/torch/fx/_pytree.py @@ -0,0 +1,103 @@ +# mypy: allow-untyped-defs +from collections import namedtuple +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type + +import torch.return_types +from torch.utils._pytree import PyTree, TreeSpec + + +FlattenFuncSpec = Callable[[PyTree, TreeSpec], List] +FlattenFuncExactMatchSpec = Callable[[PyTree, TreeSpec], bool] + +SUPPORTED_NODES: Dict[Type[Any], FlattenFuncSpec] = {} +SUPPORTED_NODES_EXACT_MATCH: Dict[Type[Any], Optional[FlattenFuncExactMatchSpec]] = {} + + +def register_pytree_flatten_spec( + cls: Type[Any], + flatten_fn_spec: FlattenFuncSpec, + flatten_fn_exact_match_spec: Optional[FlattenFuncExactMatchSpec] = None, +) -> None: + SUPPORTED_NODES[cls] = flatten_fn_spec + SUPPORTED_NODES_EXACT_MATCH[cls] = flatten_fn_exact_match_spec + + +def tree_flatten_spec( + pytree: PyTree, + spec: TreeSpec, + exact_structural_match=False, +) -> List[Any]: + if spec.is_leaf(): + return [pytree] + if spec.type not in SUPPORTED_NODES: + raise RuntimeError( + f"{type(pytree)} does not have a flatten_fn_spec associated with it. Please register one with " + "torch.fx._pytree.register_pytree_flatten_spec. If you have serialized your model, make " + "sure that any custom pytrees have been registered before loading it.", + ) + flatten_fn_spec = SUPPORTED_NODES[spec.type] + child_pytrees = flatten_fn_spec(pytree, spec) + if exact_structural_match: + flatten_fn_exact_match_spec = SUPPORTED_NODES_EXACT_MATCH[spec.type] + if flatten_fn_exact_match_spec and not flatten_fn_exact_match_spec( + pytree, + spec, + ): + raise RuntimeError(f"Cannot flatten pytree {pytree}, given spec: {spec}") + result = [] + for child, child_spec in zip(child_pytrees, spec.children_specs): + flat = tree_flatten_spec(child, child_spec, exact_structural_match) + result += flat + return result + + +def _dict_flatten_spec(d: Dict[Any, Any], spec: TreeSpec) -> List[Any]: + return [d[k] for k in spec.context] + + +def _list_flatten_spec(d: List[Any], spec: TreeSpec) -> List[Any]: + return [d[i] for i in range(spec.num_children)] + + +def _tuple_flatten_spec(d: Tuple[Any], spec: TreeSpec) -> List[Any]: + return [d[i] for i in range(spec.num_children)] + + +def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> List[Any]: + return [d[i] for i in range(spec.num_children)] + + +def _dict_flatten_spec_exact_match(d: Dict[Any, Any], spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +def _list_flatten_spec_exact_match(d: List[Any], spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +def _tuple_flatten_spec_exact_match(d: Tuple[Any], spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +def _namedtuple_flatten_spec_exact_match(d: NamedTuple, spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +register_pytree_flatten_spec(dict, _dict_flatten_spec, _dict_flatten_spec_exact_match) +register_pytree_flatten_spec(list, _list_flatten_spec, _list_flatten_spec_exact_match) +register_pytree_flatten_spec( + tuple, + _tuple_flatten_spec, + _tuple_flatten_spec_exact_match, +) +for return_type in torch.return_types.all_return_types: + register_pytree_flatten_spec( + return_type, + _tuple_flatten_spec, + 
_tuple_flatten_spec_exact_match, + ) +register_pytree_flatten_spec( + namedtuple, # type: ignore[arg-type] + _namedtuple_flatten_spec, + _namedtuple_flatten_spec_exact_match, +) diff --git a/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py b/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..6693863386513c8fd7bce53b1b3f2cbf8b0bb815 --- /dev/null +++ b/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py @@ -0,0 +1,1290 @@ +# mypy: allow-untyped-defs +import builtins +import copy +import contextlib +import functools +import inspect +import math +import os +import warnings +import collections +from itertools import chain +from types import CodeType, FunctionType, ModuleType +from typing import ( + Any, + Callable, + Dict, + List, + NamedTuple, + Optional, + Set, + Tuple, + Type, + Union, +) + +import torch +import torch.utils._pytree as pytree +from torch._C import ScriptObject # type: ignore[attr-defined] +from torch._library.fake_class_registry import FakeScriptObject + +from ._compatibility import compatibility +from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph +from .graph_module import GraphModule +from ._lazy_graph_module import _make_graph_module +from .node import Argument, base_types, map_aggregate +from .proxy import ParameterProxy, Proxy, TracerBase, Scope, ScopeContextManager + +HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS + +# These need to run in global scope to handle nested calls correctly +_orig_module_call: Callable = torch.nn.Module.__call__ +_orig_module_getattr: Callable = torch.nn.Module.__getattr__ + +_proxyable_classes: Dict[Type, None] = {} + +_is_fx_tracing_flag = False + + +def is_fx_tracing(): + return _is_fx_tracing_flag + +@compatibility(is_backward_compatible=True) +class ProxyableClassMeta(type): + """ + ProxyableClassMeta allows you to make construction of a given Python class + symbolically traceable. For example:: + + import torch + import torch.fx + + class TensorPair(metaclass=torch.fx.ProxyableClassMeta): + def __init__(self, left, right): + self.left, self.right = left, right + + def add(self, other): + l = self.left + other.left + r = self.right + other.right + return TensorPair(l, r) + + def mul(self, other): + l = self.left * other.left + r = self.right * other.right + return TensorPair(l, r) + + def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor): + s = x.add(TensorPair(y, y)) + return s.mul(x) + + x = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) + y = torch.randn(5, 3) + ref_out = use_tensor_pair_ctor(x, y) + + traced = torch.fx.symbolic_trace(use_tensor_pair_ctor) + print(traced.code) + ''' + def forward(self, x : __main___TensorPair, y : torch.Tensor): + tensor_pair = __main___TensorPair(y, y); y = None + add = x.add(tensor_pair); tensor_pair = None + mul = add.mul(x); add = x = None + return mul + ''' + + From this example, we can see that construction of a class (``TensorPair``) + defined with ``ProxyableClassMeta`` as metaclass can be recorded in symbolic + tracing. 
+ """ + + def __init__(cls, name, bases, attrs): + _proxyable_classes.setdefault(cls) + super().__init__(name, bases, attrs) + + def __call__(cls, *args, **kwargs): + instance = cls.__new__(cls) # type: ignore[call-overload] + + if not is_fx_tracing(): + cls.__init__(instance, *args, **kwargs) # type: ignore[misc] + return instance + + found_proxies = [] + + def check_proxy(a): + if isinstance(a, Proxy): + found_proxies.append(a) + + map_aggregate(args, check_proxy) + map_aggregate(kwargs, check_proxy) + + if len(found_proxies) != 0: + tracer = found_proxies[0].tracer + return tracer.create_proxy("call_function", cls, args, kwargs) + else: + cls.__init__(instance, *args, **kwargs) # type: ignore[misc] + return instance + + +def _patch_function(fn: FunctionType, nargs: int) -> FunctionType: + co = fn.__code__ + co_flags = co.co_flags & ~HAS_VARSTUFF + co_args: tuple + if hasattr(co, "co_qualname"): + # Python-3.11+ code signature + co_args = ( + nargs, + 0, + 0, + co.co_nlocals, + co.co_stacksize, + co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_qualname, # type: ignore[attr-defined] + co.co_firstlineno, + co.co_lnotab, + co.co_exceptiontable, # type: ignore[attr-defined] + co.co_freevars, + co.co_cellvars, + ) + elif hasattr(co, "co_posonlyargcount"): + co_args = ( + nargs, + 0, + 0, + co.co_nlocals, + co.co_stacksize, + co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_freevars, + co.co_cellvars, + ) + else: + co_args = ( + nargs, + 0, + co.co_nlocals, + co.co_stacksize, + co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_freevars, + co.co_cellvars, + ) + new_code = CodeType(*co_args) # type: ignore[arg-type] + return FunctionType( + new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__ + ) + + # we need to insert placeholder nodes for *args and **kwargs + # we can't call this function normally, otherwise it would try to unpack them + # instead, let's make python think that args and kwargs are normal variables + + +@compatibility(is_backward_compatible=False) +class PHBase: + """ + Object representing an input placeholder to `concrete_args` + """ + + def __repr__(self): + return "PH" + + +PH = PHBase() + + +@compatibility(is_backward_compatible=False) +class PHWithMeta(PHBase): + """ + Object representing an input placeholder to `concrete_args` + """ + def __init__(self, ph_key: Optional[str] = None): + super().__init__() + + # Provide a hey for user to identify placeholder node during analysis + self.ph_key = ph_key + + +def _transfer_attrs(fr, to): + for attr_name in dir(fr): + attr_val = getattr(fr, attr_name) + if ( + not callable(attr_val) + and not attr_name.startswith("__") + and not hasattr(to, attr_name) + ): + setattr(to, attr_name, attr_val) + + +@compatibility(is_backward_compatible=True) +class Tracer(TracerBase): + # Reference: https://github.com/pytorch/pytorch/issues/54354 + # The first line of this docstring overrides the one Sphinx generates for the + # documentation. We need it so that Sphinx doesn't leak `math`s path from the + # build environment (e.g. ` None: + # This method's signature is overridden by the first line of this class' + # docstring. If this method's signature is modified, the signature that + # overrides it also should be modified accordingly. 
+ + """ + Construct a Tracer object. + + Args: + + autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`, + Python modules whose functions should be wrapped automatically + without needing to use fx.wrap(). Backward-compatibility for + this parameter is guaranteed. + + autowrap_functions (Tuple[Callable, ...]): defaults to `()`, + Python functions that should be wrapped automatically without + needing to use fx.wrap(). Backward compatibility for this + parameter is guaranteed. + + param_shapes_constant (bool): When this flag is set, calls to shape, + size and a few other shape like attributes of a module's parameter + will be evaluated directly, rather than returning a new Proxy value + for an attribute access. Backward compatibility for this parameter + is guaranteed. + """ + + super().__init__() + + # Functions we will eagerly wrap when we see them while tracing + # this captures both `math.sqrt()` and `from math import sqrt` automatically + self._autowrap_function_ids: Set[int] = { + id(value) + for name, value in chain(*[m.__dict__.items() for m in autowrap_modules]) + if not name.startswith("_") and callable(value) + } + self._autowrap_function_ids.update({id(f) for f in autowrap_functions}) + + # Python modules to apply autowrap to at the start, in addition to + # modules we see while tracing + self._autowrap_search: List[ModuleType] = list(autowrap_modules) + self.param_shapes_constant = param_shapes_constant + + self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None + self.root_module_name: str = "" + # Maps the containing module's name to the operator name + self.scope = Scope("", None) + # Records the module call stack + self.module_stack = collections.OrderedDict() + # Mapping of node name to module scope + self.node_name_to_scope: Dict[str, Tuple[str, type]] = {} + + _qualname_counter: Dict[str, int] = collections.defaultdict(int) + + @compatibility(is_backward_compatible=True) + def get_fresh_qualname(self, prefix: str) -> str: + """ + Gets a fresh name for a prefix and returns it. This function ensures + that it will not clash with an existing attribute on the graph. + """ + # The idea here is that if the module doesn't have this prefix at all we + # should reset the counter to start from the beginning + # It's a ... little bit hacky (doesn't cover all cases) but the precise + # naming of the prefixes isn't a correctness issue, just a niceness + # issue + qualname = f"{prefix}0" + if not hasattr(self.root, qualname): + self._qualname_counter[prefix] = 0 + return qualname + + i = self._qualname_counter[prefix] + while True: + qualname = f"{prefix}{i}" + i += 1 + if not hasattr(self.root, qualname): + break + self._qualname_counter[prefix] = i + + return qualname + + @compatibility(is_backward_compatible=True) + def create_arg(self, a: Any) -> "Argument": + """ + A method to specify the behavior of tracing when preparing values to + be used as arguments to nodes in the ``Graph``. + + By default, the behavior includes: + + #. Iterate through collection types (e.g. tuple, list, dict) and recursively + call ``create_args`` on the elements. + #. Given a Proxy object, return a reference to the underlying IR ``Node`` + #. Given a non-Proxy Tensor object, emit IR for various cases: + + * For a Parameter, emit a ``get_attr`` node referring to that Parameter + * For a non-Parameter Tensor, store the Tensor away in a special + attribute referring to that attribute. + + This method can be overridden to support more types. 
+ + Args: + + a (Any): The value to be emitted as an ``Argument`` in the ``Graph``. + + + Returns: + + The value ``a`` converted into the appropriate ``Argument`` + """ + # The base tracer is used to construct Graphs when there is no associated + # module hierarchy, so it can never create parameter references. + # The default tracer adds the ability to refer to parameters when + # tracing modules. + if isinstance(a, torch.nn.Parameter): + for n, p in self.root.named_parameters(): + if a is p: + return self.create_node("get_attr", n, (), {}) + raise NameError("parameter is not a member of this module") + elif isinstance(a, torch.Tensor): + for n_, p_ in self.root.named_buffers(): + if a is p_: + return self.create_node("get_attr", n_, (), {}) + elif isinstance(a, torch.nn.Module): + for n_, p_ in self.root.named_modules(): + if a is p_: + return self.create_node("get_attr", n_, (), {}) + # For NamedTuple instances that appear literally as args, we emit + # a node to construct the NamedTuple and use that Node as the argument. + if isinstance(a, tuple) and hasattr(a, "_fields"): + args = tuple(self.create_arg(elem) for elem in a) + return self.create_node("call_function", a.__class__, args, {}) + + # Tensors do not have a reliable string repr() from which they can be + # constructed (and we probably don't want to rely on that, either), so + # for any constant Tensor values we encounter, first search for if they + # are an attribute of some module in the module hierarchy. If so, emit + # a get_attr to retrieve that tensor. Otherwise, we'll store away the + # tensor value into a special attribute on the Module s.t. we can + # retrieve it with a get_attr. + if isinstance(a, (torch.Tensor, ScriptObject, FakeScriptObject)): + qualname: Optional[str] = self.tensor_attrs.get(a) + + # Tensor was not found in the Module hierarchy, stow it away in a + # special attribute and set the qualname to refer to that + if not qualname: + base_name = "_tensor_constant" if isinstance(a, torch.Tensor) else "_torchbind_obj" + qualname = self.get_fresh_qualname(base_name) + assert isinstance(qualname, str) + self.tensor_attrs[a] = qualname + setattr(self.root, qualname, a) + + return self.create_node("get_attr", qualname, (), {}) + + if type(a) in _proxyable_classes: + # This is an instance of a proxyable class for which we did not + # witness its construction. Intern this as a constant attribute + + # TODO: binary search + qualname = self.get_fresh_qualname(f"_{a.__class__.__name__}_constant_") + assert isinstance(qualname, str) + setattr(self.root, qualname, a) + + return self.create_node("get_attr", qualname, (), {}) + + return super().create_arg(a) + + @compatibility(is_backward_compatible=True) + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + """ + A method to specify whether a given ``nn.Module`` is a "leaf" module. + + Leaf modules are the atomic units that appear in + the IR, referenced by ``call_module`` calls. By default, + Modules in the PyTorch standard library namespace (torch.nn) + are leaf modules. All other modules are traced through and + their constituent ops are recorded, unless specified otherwise + via this parameter. + + Args: + + m (Module): The module being queried about + module_qualified_name (str): The path to root of this module. For example, + if you have a module hierarchy where submodule ``foo`` contains + submodule ``bar``, which contains submodule ``baz``, that module will + appear with the qualified name ``foo.bar.baz`` here. 
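+
+        For example, a custom tracer that traces through every module,
+        including ``torch.nn`` built-ins, could simply return ``False``
+        (a minimal sketch)::
+
+            class InlineEverythingTracer(Tracer):
+                def is_leaf_module(self, m, module_qualified_name):
+                    return False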
+ """ + return ( + (m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn")) + and not isinstance(m, torch.nn.Sequential) + ) + + @compatibility(is_backward_compatible=True) + def path_of_module(self, mod: torch.nn.Module) -> str: + """ + Helper method to find the qualified name of ``mod`` in the Module hierarchy + of ``root``. For example, if ``root`` has a submodule named ``foo``, which has + a submodule named ``bar``, passing ``bar`` into this function will return + the string "foo.bar". + + Args: + + mod (str): The ``Module`` to retrieve the qualified name for. + """ + # Prefer the O(1) algorithm + if self.submodule_paths: + path = self.submodule_paths.get(mod) + if path is None: + raise NameError("module is not installed as a submodule") + assert isinstance(path, str) + return path + # O(N^2) fallback in the case that we didn't store the submodule + # paths. + else: + for n, p in self.root.named_modules(): + if mod is p: + return n + raise NameError("module is not installed as a submodule") + + @compatibility(is_backward_compatible=True) + def call_module( + self, + m: torch.nn.Module, + forward: Callable[..., Any], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + ) -> Any: + """ + Method that specifies the behavior of this ``Tracer`` when it encounters + a call to an ``nn.Module`` instance. + + By default, the behavior is to check if the called module is a leaf module + via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to + ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through + the operations in its ``forward`` function. + + This method can be overridden to--for example--create nested traced + GraphModules, or any other behavior you would want while tracing across + ``Module`` boundaries. + + Args: + + m (Module): The module for which a call is being emitted + forward (Callable): The forward() method of the ``Module`` to be invoked + args (Tuple): args of the module callsite + kwargs (Dict): kwargs of the module callsite + + Return: + + The return value from the Module call. In the case that a ``call_module`` + node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever + value was returned from the ``Module`` invocation. + """ + module_qualified_name = self.path_of_module(m) + with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope: + # module_stack is an ordered dict so writing then deleting the + # entry is equivalent to push/pop on a list + self.module_stack[_scope.module_path] = (module_qualified_name, _scope.module_type) + if not self.is_leaf_module(m, module_qualified_name): + ret_val = forward(*args, **kwargs) + else: + ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs) + key, _ = self.module_stack.popitem(last=True) + assert key == _scope.module_path, f" Unexpected key {key}" + + return ret_val + + @compatibility(is_backward_compatible=False) + def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]): + """ + Method that specifies the behavior of this ``Tracer`` when we call getattr + on a call to an ``nn.Module`` instance. + + By default, the behavior is to return a proxy value for the attribute. It + also stores the proxy value in the ``parameter_proxy_cache``, so that future + calls will reuse the proxy rather than creating a new one. + + This method can be overridden to --for example-- not return proxies when + querying parameters. 
+ + Args: + + attr (str): The name of the attribute being queried + attr_val (Any): The value of the attribute + parameter_proxy_cache (Dict[str, Any]): A cache of attr names to proxies + + Return: + + The return value from the getattr call. + """ + def maybe_get_proxy_for_attr( + attr_val, collection_to_search, parameter_proxy_cache + ): + for n, p in collection_to_search: + if attr_val is p: + if n not in parameter_proxy_cache: + kwargs = {} + if ( + "proxy_factory_fn" + in inspect.signature(self.create_proxy).parameters + ): + kwargs["proxy_factory_fn"] = ( + None + if not self.param_shapes_constant + else lambda node: ParameterProxy( + self, node, n, attr_val + ) + ) + val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] + parameter_proxy_cache[n] = val_proxy + return parameter_proxy_cache[n] + return None + + if isinstance(attr_val, torch.nn.Parameter): + maybe_parameter_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_parameters(), parameter_proxy_cache + ) + if maybe_parameter_proxy is not None: + return maybe_parameter_proxy + + if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): + maybe_buffer_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_buffers(), parameter_proxy_cache + ) + if maybe_buffer_proxy is not None: + return maybe_buffer_proxy + + return attr_val + + # This method will be refactored + @compatibility(is_backward_compatible=False) + def create_args_for_root(self, root_fn, is_module, concrete_args=None): + """ + Create ``placeholder`` nodes corresponding to the signature of the ``root`` + Module. This method introspects root's signature and emits those + nodes accordingly, also supporting ``*args`` and ``**kwargs``. + """ + # In some cases, a function or method has been decorated with a wrapper + # defined via ``functools.wraps``. In this case, the outer code object + # will likely not contain the actual parameters we care about, so unwrap + # the function to get to the innermost callable. + fn_for_analysis = inspect.unwrap(root_fn) + co = fn_for_analysis.__code__ + total_args = co.co_argcount + co.co_kwonlyargcount + orig_args = list(co.co_varnames) + names_iter = iter(co.co_varnames) + args: List[Any] = [] + skip_arg_idx = 0 + if is_module: + if total_args == 0: + raise RuntimeError( + "``self`` argument cannot be part of *args expansion!" + ) + skip_arg_idx = 1 + next(names_iter) # skip self + args.append(self.root) + + sig = inspect.signature(fn_for_analysis) + + + # This covers the very specific case where we are passing in flat + # concrete_args as a tuple, but our traced fn takes (*args, **kwargs). + # In this case, just take the concrete_args and pass them through. + name_idx = 0 + if isinstance(concrete_args, tuple) and \ + len(concrete_args) > 0 and \ + (co.co_flags & HAS_VARSTUFF) and \ + total_args == 1: + for concrete_arg in concrete_args: + out = self.create_proxy("placeholder", f"input_{name_idx}", (), {}) + if isinstance(concrete_arg, PHBase): + if concrete_arg != PH: + # Transfer attrs in the case where you're using a placeholder other + # than the singleton PH (PH has no attributes to transfer). + # Proxies were created out of the placeholders. + # Transfer any metadata (put on the placeholders in the form of + # attributes set by the user) from the placeholder to the + # underlying nodes (the proxy is unwrapped by the user, but + # the metadata should hold). 
+ _transfer_attrs(fr=concrete_arg, to=out.node) + args.append(out) + name_idx += 1 + return root_fn, args + + arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)] + if isinstance(concrete_args, tuple): + if len(arg_names) != len(concrete_args): + raise RuntimeError( + f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments" + ) + concrete_args = dict(zip(arg_names, concrete_args)) + + def proxy_placeholder(name): + return self._proxy_placeholder(name, concrete_args, sig, fn_for_analysis) + + args.extend(proxy_placeholder(names) for names in arg_names) + + if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF: + # TODO: type annotations for *args and **kwargs + if co.co_flags & inspect.CO_VARARGS: + args.append(proxy_placeholder("*" + next(names_iter))) + if co.co_flags & inspect.CO_VARKEYWORDS: + args.append(proxy_placeholder("**" + next(names_iter))) + root_fn = _patch_function(root_fn, len(args)) + + flat_args, in_spec = pytree.tree_flatten(tuple(args)) + if not all(child.is_leaf() for child in in_spec.children_specs): + # In the case that we have pytree-flattened inputs in + # `concrete_args`, generate a flattening wrapper around the + # original root function and return that. + self.graph._codegen = _PyTreeCodeGen( + _PyTreeInfo(orig_args[:total_args], in_spec, None) + ) + + def flatten_fn(*args): + tree_args = pytree.tree_unflatten(list(args), in_spec) + tree_out = root_fn(*tree_args) + out_args, out_spec = pytree.tree_flatten(tree_out) + assert isinstance(self.graph._codegen, _PyTreeCodeGen) + self.graph._codegen.pytree_info = ( + self.graph._codegen.pytree_info._replace(out_spec=out_spec) + ) + return out_args + + return flatten_fn, flat_args + return root_fn, args + + @compatibility(is_backward_compatible=True) + def trace( + self, + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None, + ) -> Graph: + """ + Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root`` + can either be an ``nn.Module`` instance or a Python callable. + + Note that after this call, ``self.root`` may be different from the ``root`` passed + in here. For example, when a free function is passed to ``trace()``, we will + create an ``nn.Module`` instance to use as the root and add embedded constants + to. + + + Args: + + root (Union[Module, Callable]): Either a ``Module`` or a function to be + traced through. Backwards-compatibility for this parameter is + guaranteed. + concrete_args (Optional[Dict[str, any]]): Concrete arguments that should + not be treated as Proxies. This parameter is experimental and + its backwards-compatibility is *NOT* guaranteed. + + Returns: + + A ``Graph`` representing the semantics of the passed-in ``root``. + """ + global _is_fx_tracing_flag + old_is_fx_tracing_flag = _is_fx_tracing_flag + _is_fx_tracing_flag = True + try: + if isinstance(root, torch.nn.Module): + + # do real recompilation for _LazyGraphModule before retracing since the trace + # method can not trace the _lazy_forward method. Got error: + # https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259 + # without this. 
+ from torch.fx._lazy_graph_module import _LazyGraphModule + _LazyGraphModule.force_recompile(root) + + self.root = root + + assert hasattr( + type(root), self.traced_func_name + ), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}" + + fn = getattr(type(root), self.traced_func_name) + self.root_module_name = root._get_name() + self.submodule_paths = {mod: name for name, mod in root.named_modules()} + else: + self.root = torch.nn.Module() + fn = root + + tracer_cls: Optional[Type[Tracer]] = getattr(self, "__class__", None) + self.graph = Graph(tracer_cls=tracer_cls) + if hasattr(fn, '__code__'): + code = fn.__code__ + self.graph._co_fields = { + 'co_name': code.co_name, + 'co_filename': code.co_filename, + 'co_firstlineno': code.co_firstlineno, + } + + # When we encounter a Tensor value that's not a parameter, we look if it + # is some other attribute on the model. Construct a dict mapping Tensor + # values to the qualified name here for efficiency. This is used downstream + # in create_arg + self.tensor_attrs: Dict[ + Union[ + torch.Tensor, + ScriptObject, + FakeScriptObject + ], str + ] = {} + + def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): + for k, v in m.__dict__.items(): + if isinstance(v, (torch.Tensor, ScriptObject, FakeScriptObject)): + self.tensor_attrs[v] = ".".join(prefix_atoms + [k]) + for k, v in m.named_children(): + collect_tensor_attrs(v, prefix_atoms + [k]) + + collect_tensor_attrs(self.root, []) + + assert isinstance(fn, FunctionType) + + fn_globals = fn.__globals__ # run before it gets patched + fn, args = self.create_args_for_root( + fn, isinstance(root, torch.nn.Module), concrete_args + ) + + parameter_proxy_cache: Dict[ + str, Proxy + ] = {} # Reduce number of get_attr calls + + # Method dispatch on parameters is not recorded unless it's directly used. + # Thus, we need to insert a proxy when __getattr__ requests a parameter. + @functools.wraps(_orig_module_getattr) + def module_getattr_wrapper(mod, attr): + attr_val = _orig_module_getattr(mod, attr) + return self.getattr(attr, attr_val, parameter_proxy_cache) + + @functools.wraps(_orig_module_call) + def module_call_wrapper(mod, *args, **kwargs): + def forward(*args, **kwargs): + return _orig_module_call(mod, *args, **kwargs) + + _autowrap_check( + patcher, # type: ignore[has-type] + getattr(getattr(mod, "forward", mod), "__globals__", {}), + self._autowrap_function_ids, + ) + return self.call_module(mod, forward, args, kwargs) + + with _new_patcher() as patcher: + # allow duplicate patches to support the case of nested calls + patcher.patch_method( + torch.nn.Module, + "__getattr__", + module_getattr_wrapper, + deduplicate=False, + ) + patcher.patch_method( + torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False + ) + _patch_wrapped_functions(patcher) + _autowrap_check(patcher, fn_globals, self._autowrap_function_ids) + for module in self._autowrap_search: + _autowrap_check( + patcher, module.__dict__, self._autowrap_function_ids + ) + self.create_node( + "output", + "output", + (self.create_arg(fn(*args)),), + {}, + type_expr=fn.__annotations__.get("return", None), + ) + + self.submodule_paths = None + finally: + _is_fx_tracing_flag = old_is_fx_tracing_flag + return self.graph + + def __deepcopy__(self, memo): + # _autowrap_search contains modules, which cannot be deepcopied. 
+ new_tracer = Tracer.__new__(Tracer) + + for k, v in self.__dict__.items(): + if k in {'_autowrap_search'}: + new_obj = copy.copy(v) + else: + new_obj = copy.deepcopy(v, memo) + + new_tracer.__dict__[k] = new_obj + + return new_tracer + + def _proxy_placeholder(self, name, concrete_args, sig, fn_for_analysis): + if concrete_args is not None and name in concrete_args: + cnt = 0 + + def replace_ph(x): + nonlocal cnt + cnt += 1 + param = sig.parameters[name] + default = ( + () + if param.default is inspect.Parameter.empty + else (param.default,) + ) + out = self.create_proxy( + "placeholder", f"{name}_{str(cnt)}", default, {} + ) + if isinstance(x, PHBase): + if x != PH: + # Transfer attrs in the case where you're using a placeholder other + # than the singleton PH (PH has no attributes to transfer). + # Proxies were created out of the placeholders. + # Transfer any metadata (put on the placeholders in the form of + # attributes set by the user) from the placeholder to the + # underlying nodes (the proxy is unwrapped by the user, but + # the metadata should hold). + _transfer_attrs(fr=x, to=out.node) + + return out + # Union[int, bool] == bool in Python <= 3.6 + if ( + type(x) == bool + or type(x) in base_types + and type(x) != torch.Tensor + ): + torch._assert( + out == x, + f"{name} has been specialized to have value {x} but got another value", + ) + elif x is None: + args = ( + out, + f"{name} has been specialized to have value None but got another value", + ) + self.create_proxy("call_function", _assert_is_none, args, {}) + else: + warnings.warn( + f"Was not able to add assertion to guarantee correct input {name} to " + f"specialized function. It is up to the user to make sure that your inputs match the " + f"inputs you specialized the function with." + ) + + return x + + return pytree.tree_map(replace_ph, concrete_args[name]) + if name[0] == "*": + default = () + else: + param = sig.parameters[name] + default = () if param.default is inspect.Parameter.empty else (param.default,) # type: ignore[assignment] + return self.create_proxy( + "placeholder", + name, + default, + {}, + type_expr=fn_for_analysis.__annotations__.get(name, None) + ) + + +# Dictionary of (id(globals dict), function name) => globals_dict to patch for +# the purposes of the wrap() API. +# We key by the globals dict id and function name to ensure we're wrapping a given +# function only once. +_wrapped_fns_to_patch: Dict[Tuple[int, str], dict] = {} + +# List of methods on classes to wrap (class type, function name) +# this currently only works for Tensor.* methods that aren't traced properly +_wrapped_methods_to_patch: List[Tuple[type, str]] = [] + +if os.environ.get("FX_PATCH_GETITEM") == "1": + # This change is needed to trace models like PositionalEmbedding from BERT: + # https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py + # but causes issues in quantization documented here: + # https://github.com/pytorch/pytorch/issues/50710 + # once that is fixed we can make this the default behavior. + _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__")) + + +def _find_proxy(*objects_to_search): + """ + Recursively search a data structure for a Proxy() and return it, + return None if not found. 
+ """ + proxy = None + + def find_proxy(x): + nonlocal proxy + if isinstance(x, Proxy): + proxy = x + + map_aggregate(objects_to_search, find_proxy) + return proxy + + +def _create_wrapped_func(orig_fn): + @functools.wraps(orig_fn) + def wrapped(*args, **kwargs): + """ + Given an closed-over ``orig_function`` to invoke, search the args and kwargs for + a Proxy object. If there is one, emit a ``call_function`` node to preserve the + call to this leaf function directly. Otherwise, just return the results of + this function call, as this function is not being traced. + """ + proxy = _find_proxy(args, kwargs) + if proxy is not None: + return_proxy = proxy.tracer.create_proxy( + "call_function", orig_fn, args, kwargs + ) + return_proxy.node.meta["is_wrapped"] = True + return return_proxy + return orig_fn(*args, **kwargs) + + return wrapped + + +def _create_wrapped_method(cls, name): + orig_fn = getattr(cls, name) + + @functools.wraps(orig_fn) + def wrapped(*args, **kwargs): + """ + Search the args and kwargs for a Proxy object. If there is one, + emit a ``call_method`` node to preserve the call to this method + directly. Otherwise, just return the results of this function + call, as this function is not being traced. + """ + proxy = _find_proxy(args, kwargs) + if proxy is not None: + return proxy.tracer.create_proxy("call_method", name, args, kwargs) + return orig_fn(*args, **kwargs) + + return wrapped + + +class _PatchedFn(NamedTuple): + frame_dict: Any + fn_name: str + orig_fn: Any + new_fn: Any + + def revert(self): + raise NotImplementedError + + def patch(self): + raise NotImplementedError + + +class _PatchedFnSetItem(_PatchedFn): + def revert(self): + self.frame_dict[self.fn_name] = self.orig_fn + + def patch(self): + self.frame_dict[self.fn_name] = self.new_fn + +class _PatchedFnDel(_PatchedFn): + def revert(self): + del self.frame_dict[self.fn_name] + + def patch(self): + self.frame_dict[self.fn_name] = self.new_fn + + +class _PatchedFnSetAttr(_PatchedFn): + def revert(self): + setattr(self.frame_dict, self.fn_name, self.orig_fn) + + def patch(self): + setattr(self.frame_dict, self.fn_name, self.new_fn) + +class _Patcher: + def __init__(self) -> None: + super().__init__() + self.patches_made: List[_PatchedFn] = [] + self.visited: Set[int] = set() + + def patch( + self, + frame_dict: Dict[str, Any], + name: str, + new_fn: Callable, + deduplicate: bool = True, + ): + """ + Replace frame_dict[name] with new_fn until we exit the context manager. + """ + new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] + if name not in frame_dict and hasattr(builtins, name): + self.patches_made.append(_PatchedFnDel(frame_dict, name, None, new_fn)) + self.patches_made[-1].patch() + elif getattr(frame_dict[name], "__fx_already_patched", False): + return # already patched, no need to do it again + else: + self.patches_made.append( + _PatchedFnSetItem(frame_dict, name, frame_dict[name], new_fn) + ) + self.patches_made[-1].patch() + + def patch_method( + self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True + ): + """ + Replace object_or_dict.name with new_fn until we exit the context manager. 
+ """ + new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] + orig_fn = getattr(cls, name) + if getattr(orig_fn, "__fx_already_patched", False): + return # already patched, no need to do it again + self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn, new_fn)) + self.patches_made[-1].patch() + + def visit_once(self, thing: Any): + """Return True on the first call to with thing, otherwise false""" + idx = id(thing) + if idx in self.visited: + return False + self.visited.add(idx) + return True + + def revert_all_patches(self): + """ + Remove all the stored patcheds. It doesn't modify patches_made. + """ + for patch in self.patches_made: + patch.revert() + return self.patches_made + + def reapply_all_patches(self): + """ + Patch all the stored patcheds. It doesn't modify patches_made. + """ + for patch in self.patches_made: + patch.patch() + return self.patches_made + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Undo all the changes made via self.patch() and self.patch_method() + """ + while self.patches_made: + # unpatch in reverse order to handle duplicates correctly + self.patches_made.pop().revert() + self.visited.clear() + + +CURRENT_PATCHER: Optional[_Patcher] = None + +@contextlib.contextmanager +def _new_patcher(): + global CURRENT_PATCHER + prior_patcher = CURRENT_PATCHER + try: + CURRENT_PATCHER = _Patcher() + yield CURRENT_PATCHER + finally: + # Clear all the patches made by when using current patcher. + assert CURRENT_PATCHER is not None + CURRENT_PATCHER.revert_all_patches() + CURRENT_PATCHER = prior_patcher + + +@contextlib.contextmanager +def _maybe_revert_all_patches(): + current_patcher = CURRENT_PATCHER + patches_made = None + patches_removed = None + try: + if current_patcher is not None: + patches_removed = current_patcher.revert_all_patches() + yield + finally: + if current_patcher is not None: + patches_made = current_patcher.reapply_all_patches() + assert patches_made == patches_removed, "CURRENT_PATCHER was changed during a revert_all_patches" + +def _patch_wrapped_functions(patcher: _Patcher): + """ + Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap + the listed global functions in the `_create_wrapped_func` wrapper. + """ + for (_, name), frame_dict in _wrapped_fns_to_patch.copy().items(): + if name not in frame_dict and hasattr(builtins, name): + orig_fn = getattr(builtins, name) + else: + orig_fn = frame_dict[name] + patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn)) + + for cls, name in _wrapped_methods_to_patch: + patcher.patch_method(cls, name, _create_wrapped_method(cls, name)) + + +def _autowrap_check( + patcher: _Patcher, frame_dict: Dict[str, Any], function_ids: Set[int] +): + """ + Some methods, like `math.sqrt` are common enough we want to automatically wrap them as we see them. + This method searches a scope for them and patches them if found. + """ + if patcher.visit_once(frame_dict): + for name, value in frame_dict.items(): + if ( + not name.startswith("_") + and callable(value) + and id(value) in function_ids + ): + patcher.patch(frame_dict, name, _create_wrapped_func(value)) + + +@compatibility(is_backward_compatible=True) +def wrap(fn_or_name: Union[str, Callable]): + """ + This function can be called at module-level scope to register fn_or_name as a "leaf function". 
+ A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being + traced through:: + + # foo/bar/baz.py + def my_custom_function(x, y): + return x * x + y * y + + torch.fx.wrap('my_custom_function') + + def fn_to_be_traced(x, y): + # When symbolic tracing, the below call to my_custom_function will be inserted into + # the graph rather than tracing it. + return my_custom_function(x, y) + + This function can also equivalently be used as a decorator:: + + # foo/bar/baz.py + @torch.fx.wrap + def my_custom_function(x, y): + return x * x + y * y + + A wrapped function can be thought of a "leaf function", analogous to the concept of + "leaf modules", that is, they are functions that are left as calls in the FX trace + rather than traced through. + + Args: + + fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the + graph when it's called + """ + if not callable(fn_or_name) and not isinstance(fn_or_name, str): + raise RuntimeError( + "Unsupported type for global function! Must be either a callable or " + "string name" + ) + + if callable(fn_or_name): + assert not isinstance(fn_or_name, str) # to make mypy happy + fn_name = fn_or_name.__name__ + else: + assert isinstance( + fn_or_name, str + ), "fn_or_name must be a global function or string name" + fn_name = fn_or_name + + currentframe = inspect.currentframe() + assert currentframe is not None + f = currentframe.f_back + assert f is not None + if f.f_code.co_name != "": + raise NotImplementedError("wrap must be called at the top level of a module") + + # consider implementing Callable version of this via _autowrap_function_ids / _autowrap_search + # semantics would be slightly different, but would add support `from x import wrapped_function` + _wrapped_fns_to_patch[(id(f.f_globals), fn_name)] = f.f_globals + return fn_or_name + + +@compatibility(is_backward_compatible=True) +def symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None, +) -> GraphModule: + """ + Symbolic tracing API + + Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule`` + constructed by recording operations seen while tracing through ``root``. + + ``concrete_args`` allows you to partially specialize your function, whether it's to remove control flow or data structures. + + For example:: + + def f(a, b): + if b == True: + return a + else: + return a*2 + + FX can typically not trace through this due to the presence of control + flow. However, we can use `concrete_args` to specialize on the value of + `b` to trace through this:: + + f = fx.symbolic_trace(f, concrete_args={'b': False}) + assert f(3, False) == 6 + + Note that although you can still pass in different values of `b`, they will be ignored. + + We can also use `concrete_args` to eliminate data-structure handling from + our function. This will use pytrees to flatten your input. To avoid + overspecializing, pass in `fx.PH` for values that shouldn't be + specialized. For example:: + + def f(x): + out = 0 + for v in x.values(): + out += v + return out + f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}}) + assert f({'a': 1, 'b': 2, 'c': 4}) == 7 + + + Args: + root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted + into a Graph representation. 
+ concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized + + Returns: + GraphModule: a Module created from the recorded operations from ``root``. + """ + tracer = Tracer() + graph = tracer.trace(root, concrete_args) + name = ( + root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + ) + return _make_graph_module(tracer.root, graph, name) + + +@wrap +def _assert_is_none(value, msg): + assert value is None, msg diff --git a/lib/python3.10/site-packages/torch/fx/_utils.py b/lib/python3.10/site-packages/torch/fx/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3dd3780fe0bb36cc47f2669a72a2b9cdc0bdcbf6 --- /dev/null +++ b/lib/python3.10/site-packages/torch/fx/_utils.py @@ -0,0 +1,63 @@ +# mypy: allow-untyped-defs +import sys +from typing import Dict, Optional + +import torch +from torch._logging import LazyString + + +def lazy_format_graph_code(name, gm, maybe_id=None, **kwargs): + """ + Returns a LazyString that formats the graph code. + """ + + def format_name(): + if maybe_id is not None: + return f"{name} {maybe_id}" + else: + return name + + if "print_output" not in kwargs: + kwargs["print_output"] = False + + if "colored" in kwargs and not sys.stdout.isatty(): + kwargs["colored"] = False + + return LazyString( + lambda: _format_graph_code( + f"===== {format_name()} =====\n", + gm.forward.__code__.co_filename, + gm.print_readable(**kwargs), + ) + ) + + +def _format_graph_code(name, filename, graph_str): + """ + Returns a string that formats the graph code. + """ + return f"TRACED GRAPH\n {name} {filename} {graph_str}\n" + + +def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[Dict]: + """ + Returns the nn_module_stack of the first call_function node. + """ + for node in graph.nodes: + if node.op == "call_function" and "nn_module_stack" in node.meta: + return node.meta["nn_module_stack"] + return None + + +def get_node_context(node, num_nodes=2) -> str: + """ + Returns a string of the last num_nodes nodes in the graph. + """ + node_contexts = [] + cur = node + for i in range(num_nodes): + node_contexts.append(cur.format_node()) + if cur.op == "root": + break + cur = cur.prev + return "\n".join(node_contexts[::-1]) diff --git a/lib/python3.10/site-packages/torch/fx/annotate.py b/lib/python3.10/site-packages/torch/fx/annotate.py new file mode 100644 index 0000000000000000000000000000000000000000..d1b5b5f2d376115fe25542b3b1260dcd6aaf1aaf --- /dev/null +++ b/lib/python3.10/site-packages/torch/fx/annotate.py @@ -0,0 +1,32 @@ +# mypy: allow-untyped-defs +from torch.fx.proxy import Proxy +from ._compatibility import compatibility + +@compatibility(is_backward_compatible=False) +def annotate(val, type): + """ + Annotates a Proxy object with a given type. + + This function annotates a val with a given type if a type of the val is a torch.fx.Proxy object + Args: + val (object): An object to be annotated if its type is torch.fx.Proxy. + type (object): A type to be assigned to a given proxy object as val. + Returns: + The given val. + Raises: + RuntimeError: If a val already has a type in its node. + """ + if isinstance(val, Proxy): + if val.node.type: + raise RuntimeError(f"Tried to annotate a value that already had a type on it!" + f" Existing type is {val.node.type} " + f"and new type is {type}. 
" + f"This could happen if you tried to annotate a function parameter " + f"value (in which case you should use the type slot " + f"on the function signature) or you called " + f"annotate on the same value twice") + else: + val.node.type = type + return val + else: + return val diff --git a/lib/python3.10/site-packages/torch/fx/config.py b/lib/python3.10/site-packages/torch/fx/config.py new file mode 100644 index 0000000000000000000000000000000000000000..da5120d6edf180f7fbbe88ac342b4d0e4b383e50 --- /dev/null +++ b/lib/python3.10/site-packages/torch/fx/config.py @@ -0,0 +1,6 @@ +# Whether to disable showing progress on compilation passes +# Need to add a new config otherwise wil get a circular import if dynamo config is imported here +disable_progress = True + +# If True this also shows the node names in each pass, for small models this is great but larger models it's quite noisy +verbose_progress = False diff --git a/lib/python3.10/site-packages/torch/fx/graph.py b/lib/python3.10/site-packages/torch/fx/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..e4fdd79fcbb280236c0dd139fa2baffa5c53a061 --- /dev/null +++ b/lib/python3.10/site-packages/torch/fx/graph.py @@ -0,0 +1,1796 @@ +# mypy: allow-untyped-defs +from collections import defaultdict +from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name +import torch.utils._pytree as pytree +from . import _pytree as fx_pytree +from ._compatibility import compatibility +from torch._C import _NodeIter + +import os +import contextlib +from typing import TYPE_CHECKING, Callable, Any, List, Dict, NamedTuple, Optional, Tuple, Set, FrozenSet, Type, Iterable +from dataclasses import dataclass +from contextlib import contextmanager +import copy +import enum +import torch +import keyword +import re +import builtins +import math +import warnings +import inspect + +__all__ = ["PythonCode", "CodeGen", "Graph"] + +if TYPE_CHECKING: + from .graph_module import GraphModule # noqa: F401 + from ._symbolic_trace import Tracer # noqa: F401 + + +# Mapping of builtins to their `typing` equivalent. +_origin_type_map = { + list: List, + dict: Dict, + set: Set, + frozenset: FrozenSet, + tuple: Tuple, +} + + +# Signature for functions thattransforms the body (`list[str]`) of the +# generated code +TransformCodeFunc = Callable[[List[str]], List[str]] + + +class _CustomBuiltin(NamedTuple): + """Additional objs that we add to every graph's globals. + + The repr() for some standard library objects is not valid Python code without + an import. For common objects of this sort, we bundle them in the globals of + every FX graph. + """ + # How to import this object from the standard library. + import_str: str + # The actual object, produced from that import string. 
+ obj: Any + +_custom_builtins: Dict[str, _CustomBuiltin] = {} + + +def _register_custom_builtin(name: str, import_str: str, obj: Any): + _custom_builtins[name] = _CustomBuiltin(import_str, obj) + + +_register_custom_builtin('inf', 'from math import inf', math.inf) +_register_custom_builtin('nan', 'from math import nan', math.nan) +_register_custom_builtin('NoneType', 'NoneType = type(None)', type(None)) +_register_custom_builtin('torch', 'import torch', torch) +_register_custom_builtin('device', 'from torch import device', torch.device) +_register_custom_builtin('fx_pytree', 'import torch.fx._pytree as fx_pytree', fx_pytree) +_register_custom_builtin('pytree', 'import torch.utils._pytree as pytree', pytree) + + +def _is_magic(x: str) -> bool: + return x.startswith('__') and x.endswith('__') + + +def _snake_case(s: str) -> str: + """ + Transforms the given string ``s`` to a Python-style variable name + + Examples: + ``mod.snake_case`` -> ``mod.snake_case`` + ``mod.pascalCase``-> ``mod.pascal_case`` + ``mod.ALL_CAPS`` -> ``mod.all_caps`` + """ + chars = [] + prev_lower = False + for c in s: + if prev_lower and c.isupper(): + chars.append('_') + chars.append(c.lower()) + prev_lower = c.islower() + return ''.join(chars) + + +def _is_from_torch(obj: Any) -> bool: + module_name = getattr(obj, '__module__', None) + if module_name is not None: + base_module = module_name.partition('.')[0] + return ( + base_module == 'torch' and + not module_name.startswith("torch._dynamo.") and + not module_name.startswith("torch._inductor.") + ) + + name = getattr(obj, '__name__', None) + # exclude torch because torch.torch.torch.torch works. idk mang + if name is not None and name != 'torch': + for guess in [torch, torch.nn.functional]: + if getattr(guess, name, None) is obj: + return True + + return False + + +class _Namespace: + """A context for associating names uniquely with objects. + + The following invariants are enforced: + - Each object gets a single name. + - Each name is unique within a given namespace. + - Names generated do not shadow builtins, unless the object is indeed that builtin. + """ + def __init__(self): + self._obj_to_name: Dict[Any, str] = {} + self._unassociated_names = set() + self._used_names: Set[str] = set() + self._base_count: Dict[str, int] = defaultdict(int) + + self._illegal_char_regex = re.compile('[^0-9a-zA-Z_]+') + self._name_suffix_regex = re.compile(r"(.*)_(\d+)$") + + def create_name(self, candidate: str, obj: Optional[Any]) -> str: + """Create a unique name. + + Arguments: + candidate: used as the basis for the unique name, relevant to the user. + obj: If not None, an object that will be associated with the unique name. 
+ """ + if obj is not None and obj in self._obj_to_name: + return self._obj_to_name[obj] + + # delete all characters that are illegal in a Python identifier + candidate = self._illegal_char_regex.sub('_', candidate) + + if not candidate: + candidate = '_unnamed' + + if candidate[0].isdigit(): + candidate = f'_{candidate}' + + match = self._name_suffix_regex.match(candidate) + if match is None: + base = candidate + num = None + else: + base, num_str = match.group(1, 2) + num = int(num_str) + + candidate = base if num is None else f'{base}_{num}' + if not num: + num = self._base_count[base] + + while candidate in self._used_names or self._is_illegal_name(candidate, obj): + num += 1 + candidate = f'{base}_{num}' + + self._used_names.add(candidate) + self._base_count[base] = num + if obj is None: + self._unassociated_names.add(candidate) + else: + self._obj_to_name[obj] = candidate + return candidate + + def associate_name_with_obj(self, name: str, obj: Any): + """Associate a unique name with an object. + + Neither `name` nor `obj` should be associated already. + """ + assert obj not in self._obj_to_name + assert name in self._unassociated_names + self._obj_to_name[obj] = name + self._unassociated_names.remove(name) + + def _is_illegal_name(self, name: str, obj: Any) -> bool: + # 1. keywords are never allowed as names. + if name in keyword.kwlist: + return True + + # 2. Can't shadow a builtin name, unless you *are* that builtin. + if name in builtins.__dict__: + return obj is not builtins.__dict__[name] + + # 3. Can't shadow our custom builtins either + if name in _custom_builtins: + return obj is not _custom_builtins[name].obj + + return False + + def _rename_object(self, obj: Any, name: str): + assert obj in self._obj_to_name + self._obj_to_name[obj] = name + self._used_names.add(name) + +dtype_abbrs = { + torch.bfloat16: 'bf16', + torch.float64: 'f64', + torch.float32: 'f32', + torch.float16: 'f16', + torch.float8_e4m3fn: 'f8e4m3fn', + torch.float8_e5m2: 'f8e5m2', + torch.float8_e4m3fnuz: 'f8e4m3fnuz', + torch.float8_e5m2fnuz: 'f8e5m2fnuz', + torch.complex32: 'c32', + torch.complex64: 'c64', + torch.complex128: 'c128', + torch.int8: 'i8', + torch.int16: 'i16', + torch.int32: 'i32', + torch.int64: 'i64', + torch.bool: 'b8', + torch.uint8: 'u8', + torch.uint16: 'u16', + torch.uint32: 'u32', + torch.uint64: 'u64', + torch.bits16: 'b16', +} + +@compatibility(is_backward_compatible=True) +@dataclass +class PythonCode: + """ + Represents all the information necessary to exec or save a graph as Python code. + """ + # Python source code for the forward function definition. + src: str + # Values in global scope during execution of `src_def`. + globals: Dict[str, Any] + # Optional mapping from the forward function's line number to + # node index. 
+ _lineno_map: Optional[Dict[int, Optional[int]]] + + +def _format_target(base: str, target: str) -> str: + elems = target.split('.') + r = base + for e in elems: + if not e.isidentifier(): + r = f'getattr({r}, "{e}")' + else: + r = f'{r}.{e}' + return r + +class _InsertPoint: + def __init__(self, graph, new_insert): + self.graph = graph + self.orig_insert, graph._insert = graph._insert, new_insert + + def __enter__(self): + pass + + def __exit__(self, type, value, tb): + self.graph._insert = self.orig_insert + +class _node_list: + def __init__(self, graph: 'Graph', direction: str = '_next'): + assert direction in ['_next', '_prev'] + self.graph = graph + self.direction = direction + + def __len__(self): + return self.graph._len + + def __iter__(self): + assert self.direction == "_prev" or self.direction == "_next" + yield from _NodeIter(self.graph._root, self.direction == "_prev") + + def __reversed__(self): + return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev') + +class _PyTreeInfo(NamedTuple): + """ + Contains extra info stored when we're using Pytrees + """ + orig_args: List[str] + in_spec: pytree.TreeSpec + out_spec: Optional[pytree.TreeSpec] + +@dataclass(frozen=True) +class _ParsedStackTrace: + """ + Represents the top-most frame of a parsed stack trace + """ + file: str + lineno: str + name: str + code: str + + def get_summary_str(self): + return f'File: {self.file}:{self.lineno} in {self.name}, code: {self.code}' + +# get File:lineno code from stack_trace +def _parse_stack_trace(stack_trace: str): + if stack_trace is None: + return None + pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$") + lines = stack_trace.strip().split('\n') + # stacktrace should have innermost frame last, so we + # iterate backwards to find the first line that starts + # with 'File ' + summary_str = "" + for idx in range(len(lines) - 2, -1, -1): + line = lines[idx].strip() + matches = pattern.match(line) + if matches: + file = matches.group(1) + lineno = matches.group(2) + name = matches.group(3) + # next line should be the code + code = lines[idx + 1].strip() + return _ParsedStackTrace(file, lineno, name, code) + return None + +@compatibility(is_backward_compatible=False) +class CodeGen: + def __init__(self): + self._body_transformer: Optional[TransformCodeFunc] = None + self._func_name: str = "forward" + + def gen_fn_def(self, free_vars: List[str], maybe_return_annotation: str) -> str: + """ + Given the free variables and a return annotation, generates the beginning of the FX function. + By default, `gen_fn_def(['a', 'b'], '') == 'def {self._func_name}(a, b):'` + """ + # If the original function didn't have self as its first argument, we + # would have added it. + if len(free_vars) == 0 or free_vars[0] != 'self': + free_vars.insert(0, 'self') + return f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:" + + def generate_output(self, output_args: Argument) -> str: + """ + Given the output arguments, generates the return statement of the FX function. + Note: The returned statement should not be indented. + """ + return f'return {repr(output_args)}' + + def process_inputs(self, *args: Any) -> Any: + """ + Transforms the inputs so that the graph can take them as arguments, as + non-default codegen may result in the inputs to the function being + different from the inputs to the graph. 
+ + If the graph was directly runnable, this invariant should hold true + `f.graph.process_outputs(f.graph(*f.graph.process_inputs(*inputs))) == f(*inputs)` + """ + return args + + def process_outputs(self, outputs: Any) -> Any: + """ + Transforms the outputs of the graph to be identical to the codegen. + + See ``process_inputs`` for more details. + """ + return outputs + + def additional_globals(self) -> List[Tuple[str, Any]]: + """ + If your codegen uses extra global values, add tuples of (identifier,reference to the value) here. + For example, return ['List', typing.List] if you need ``List`` in the global context. + """ + return [] + + def _gen_python_code( + self, nodes, root_module: str, namespace: _Namespace, *, + verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False + ) -> PythonCode: + free_vars: List[str] = [] + body: List[str] = [] + globals_: Dict[str, Any] = {} + wrapped_fns: Dict[str, None] = {} + + # Wrap string in list to pass by reference + maybe_return_annotation : List[str] = [''] + include_stride = include_stride or (os.environ.get("FX_GRAPH_SHOW_STRIDE", "0") == "1") + include_device = include_device or (os.environ.get("FX_GRAPH_SHOW_DEVICE", "0") == "1") + + def add_global(name_hint: str, obj: Any): + """Add an obj to be tracked as a global. + + We call this for names that reference objects external to the + Graph, like functions or types. + + Returns: the global name that should be used to reference 'obj' in generated source. + """ + if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device + # HACK: workaround for how torch custom ops are registered. We + # can't import them like normal modules so they must retain their + # fully qualified name. + return _get_qualified_name(obj) + + # normalize the name hint to get a proper identifier + global_name = namespace.create_name(name_hint, obj) + + if global_name in globals_: + assert globals_[global_name] is obj + return global_name + globals_[global_name] = obj + return global_name + + # Pre-fill the globals table with registered builtins. + for name, (_, obj) in _custom_builtins.items(): + add_global(name, obj) + + def type_repr(o : Any): + if o == (): + # Empty tuple is used for empty tuple type annotation Tuple[()] + return '()' + + typename = _type_repr(o) + + if hasattr(o, '__origin__'): + # This is a generic type, e.g. typing.List[torch.Tensor] + origin_type = _origin_type_map.get(o.__origin__, o.__origin__) + origin_typename = add_global(_type_repr(origin_type), origin_type) + + if hasattr(o, '__args__'): + # Assign global names for each of the inner type variables. 
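+ # e.g. typing.List[torch.Tensor] becomes 'List[torch.Tensor]': 'List' is registered in the generated globals via add_global, while torch.Tensor is referenced by its qualified name.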
+ args = [type_repr(arg) for arg in o.__args__] + + if len(args) == 0: + # Bare type, such as `typing.Tuple` with no subscript + # This code-path used in Python < 3.9 + return origin_typename + + return f'{origin_typename}[{",".join(args)}]' + else: + # Bare type, such as `typing.Tuple` with no subscript + # This code-path used in Python 3.9+ + return origin_typename + + # Common case: this is a regular module name like 'foo.bar.baz' + return add_global(typename, o) + + codes = { + "yellow": "\033[33m", + "cyan": "\033[36m", + "green": "\033[32m", + "blue": "\033[34m", + "red": "\033[31m", + "dim": "\033[2m", + "dim_blue": "\033[2m\033[34m", + "dim_green": "\033[2m\033[32m", + "reset": "\033[0m", + } + + def make_wrapper_func(name): + def f(s): + if colored: + return f"{codes[name]}{s}{codes['reset']}" + return s + return f + + yellow = make_wrapper_func("yellow") + cyan = make_wrapper_func("cyan") + red = make_wrapper_func("red") + green = make_wrapper_func("green") + dim_green = make_wrapper_func("dim_green") + dim = make_wrapper_func("dim") + dim_blue = make_wrapper_func("dim_blue") + blue = make_wrapper_func("blue") + + def _get_repr(arg: Any) -> str: + # Handle NamedTuples (if it has `_fields`) via add_global. + if isinstance(arg, tuple) and hasattr(arg, '_fields'): + qualified_name = _get_qualified_name(type(arg)) + global_name = add_global(qualified_name, type(arg)) + return f"{global_name}{repr(tuple(arg))}" + elif isinstance(arg, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + qualified_name = _get_qualified_name(arg) + global_name = add_global(qualified_name, arg) + return f"{global_name}" + elif isinstance(arg, enum.Enum): + cls = arg.__class__ + clsname = add_global(cls.__name__, cls) + return f"{clsname}.{arg.name}" + elif isinstance(arg, Node): + return repr(arg) + elif isinstance(arg, torch.Tensor): + size = list(arg.size()) + dtype = str(arg.dtype).split(".")[-1] + return f"torch.Tensor(size={size}, dtype={dtype})" + else: + return blue(repr(arg)) + + + def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: + args_s = ', '.join(_get_repr(a) for a in args) + kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items()) + if args_s and kwargs_s: + return f'{args_s}, {kwargs_s}' + return args_s or kwargs_s + + # Run through reverse nodes and record the first instance of a use + # of a given node. This represents the *last* use of the node in the + # execution order of the program, which we will use to free unused + # values + node_to_last_use : Dict[Node, Node] = {} + user_to_last_uses : Dict[Node, List[Node]] = {} + + def register_last_uses(n : Node, user : Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + + def delete_unused_values(user : Node): + """ + Delete values after their last use. This ensures that values that are + not used in the remainder of the code are freed and the memory usage + of the code is optimal. + """ + if user.op == 'placeholder': + return + if user.op == 'output': + body.append('\n') + return + nodes_to_delete = user_to_last_uses.get(user, []) + + if len(user.users.keys()) == 0: + # This node is not used by any others. however it's also not + # removed by DCE since side-effect. We want to free it's outputs + # right after its execution done to save memory. 
+ nodes_to_delete.append(user) + + if len(nodes_to_delete): + to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) + body.append(f'; {dim(to_delete_str)}\n') + else: + body.append('\n') + + prev_stacktrace = None + + def append_stacktrace_summary(node : Node): + """ + Append a summary of the stacktrace to the generated code. This is + useful for debugging. + """ + nonlocal prev_stacktrace + + if node.op not in {'placeholder', 'output'}: + if node.stack_trace: + if node.stack_trace != prev_stacktrace: + prev_stacktrace = node.stack_trace + summary_str = "" + + if parsed_stack_trace := _parse_stack_trace(node.stack_trace): + summary_str = parsed_stack_trace.get_summary_str() + + body.append(f'\n {dim("# " + summary_str)}\n') + elif prev_stacktrace != "": + prev_stacktrace = "" + no_stacktrace_msg = "# No stacktrace found for following nodes" + body.append(f'\n{dim(no_stacktrace_msg)}\n') + + def stringify_shape(shape : Iterable) -> str: + return f"[{', '.join(str(x) for x in shape)}]" + + def emit_node(node : Node): + maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' + + if verbose: + # override annotation with more detailed information + from torch.fx.experimental.proxy_tensor import py_sym_types + from torch.fx.passes.shape_prop import TensorMetadata + + meta_val = node.meta.get('val', node.meta.get('tensor_meta', node.meta.get('example_value', None))) + # use string as annotation, to make it valid python code + + if isinstance(meta_val, torch.Tensor): + stride_annotation = f"{stringify_shape(meta_val.stride())}" if include_stride else "" + device_annotation = f"{meta_val.device}" if include_device else "" + maybe_type_annotation = \ + f': "{red(dtype_abbrs[meta_val.dtype])}{blue(stringify_shape(meta_val.shape))}' \ + f'{dim_blue(stride_annotation)}{dim_green(device_annotation)}"' + elif isinstance(meta_val, py_sym_types): + maybe_type_annotation = f': "Sym({meta_val})"' + elif isinstance(meta_val, TensorMetadata): + maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"' + + if node.op == 'placeholder': + assert isinstance(node.target, str) + maybe_default_arg = '' if not node.args else f' = {_get_repr(node.args[0])}' + free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') + raw_name = node.target.replace('*', '') + if raw_name != repr(node): + body.append(f'{repr(node)} = {raw_name}\n') + return + elif node.op == 'call_method': + assert isinstance(node.target, str) + body.append( + f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.target)}' + f'({_format_args(node.args[1:], node.kwargs)})') + return + elif node.op == 'call_function': + assert callable(node.target) + # pretty print operators + if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in magic_methods: + assert isinstance(node.args, tuple) + body.append(f'{repr(node)}{maybe_type_annotation} = ' + f'{magic_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}') + return + + # pretty print inplace operators; required for jit.script to work properly + # not currently supported in normal FX graphs, but generated by torchdynamo + if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in inplace_methods: + body.append(f'{inplace_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}; ' + f'{repr(node)}{maybe_type_annotation} = {_get_repr(node.args[0])}') + return + + qualified_name = 
_get_qualified_name(node.target) + global_name = add_global(qualified_name, node.target) + # special case for getattr: node.args could be 2-argument or 3-argument + # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value + if global_name == 'getattr' and \ + isinstance(node.args, tuple) and \ + isinstance(node.args[1], str) and \ + node.args[1].isidentifier() and \ + len(node.args) == 2: + body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.args[1])}') + return + body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') + if node.meta.get('is_wrapped', False): + wrapped_fns.setdefault(global_name) + return + elif node.op == 'call_module': + assert isinstance(node.target, str) + body.append(f'{repr(node)}{maybe_type_annotation} = ' + f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') + return + elif node.op == 'get_attr': + assert isinstance(node.target, str) + body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') + return + elif node.op == 'output': + if node.type is not None: + maybe_return_annotation[0] = f" -> {type_repr(node.type)}" + body.append(self.generate_output(node.args[0])) + return + raise NotImplementedError(f'node: {node.op} {node.target}') + + for i, node in enumerate(nodes): + # NOTE: emit_node does not emit a string with newline. It depends + # on delete_unused_values to append one + if verbose: + append_stacktrace_summary(node) + # emit a counter comment to keep track of + # node index, which will be deleted later + # after going through _body_transformer + body.append(f"# COUNTER: {i}\n") + emit_node(node) + delete_unused_values(node) + + if len(body) == 0: + # If the Graph has no non-placeholder nodes, no lines for the body + # have been emitted. To continue to have valid Python code, emit a + # single pass statement + body.append('pass\n') + + + + if len(wrapped_fns) > 0: + wrap_name = add_global('wrap', torch.fx.wrap) + wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + else: + wrap_stmts = '' + + if self._body_transformer: + body = self._body_transformer(body) + + for name, value in self.additional_globals(): + add_global(name, value) + + prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) + + # remove counter and generate lineno to node index mapping + lineno_map: Dict[int, Optional[int]] = {} + prologue_len = prologue.count('\n') + 1 + new_lines: List[str] = [] + cur_idx = None + for line in ''.join(body).split('\n'): + counter = re.search(r"# COUNTER: (\d+)", line) + if counter and counter.group(1) is not None: + cur_idx = int(counter.group(1)) + else: + lineno_map[len(new_lines) + prologue_len] = cur_idx + new_lines.append(line) + + code = "\n".join(new_lines).lstrip('\n') + code = '\n'.join(' ' + line for line in code.split('\n')) + + fn_code = f""" +{wrap_stmts} + +{prologue} +{code}""" + return PythonCode(fn_code, globals_, _lineno_map=lineno_map) + + +# Ideally, we'd like to refactor all of the pytree logic into this codegen +# class. Unfortunately, there are 3 areas we currently need extra logic in FX. +# 1. In the initial symbolic trace, the pytree logic is tied up with `concrete_args`. +# 2. In the FX graph, we need to access 2 attributes - in_spec and out_spec. +# Since we can't access .graph within the FX forward, we need to copy the attribute to the module. +# 3. 
We currently can't register the pytree imports with `add_global` - not sure why. +class _PyTreeCodeGen(CodeGen): + def __init__(self, pytree_info: _PyTreeInfo): + super().__init__() + self.pytree_info: _PyTreeInfo = pytree_info + + def process_inputs(self, *inputs: Any) -> Any: + flat_args = pytree.arg_tree_leaves(*inputs) + return flat_args + + def process_outputs(self, out: Any) -> Any: + if self.pytree_info is None or self.pytree_info.out_spec is None: + return out + if not isinstance(out, (list, tuple)): + out = [out] + assert self.pytree_info.out_spec is not None + return pytree.tree_unflatten(out, self.pytree_info.out_spec) + + def gen_fn_def(self, free_vars, maybe_return_annotation): + # Given a user function/model: + # myargs = [myargs0, myargs1] + # mykwargs = {'mykwargs0': ..., 'mykwargs1': ...} + # def forward(self, mypos, *myargs, mykey=None, **mykwargs): + # + # The generated code flattens all keywords into positional arguments for `forward()` + # e.g forward(self, mypos, myargs0, myargs1, mykey, mykwargs0, mykwargs1): + # + # Within `forward`, `tree_flatten_spec``still parses args and kwargs separately + # e.g. tree_flatten_spec(([mypos, myargs0, myargs1], + # {'mykey':mykey, 'mykwargs0':mykwargs0, 'mykwargs1':mykwargs1}), + # self._in_spec) + # + # If the user function/model does not have keywords, the dict is suppressed from tree_flatten_spec + # e.g. tree_flatten_spec([mypos, myargs0, myargs1]), self._in_spec) + if self.pytree_info is None: + return super().gen_fn_def(free_vars, maybe_return_annotation) + + fn_args = self.pytree_info.orig_args + has_orig_self = (fn_args[0] == 'self') if len(fn_args) > 0 else False + if has_orig_self: + free_vars.insert(0, 'self') + fn_definition = super().gen_fn_def(fn_args[:], maybe_return_annotation) + + if len(free_vars) > 0: # pytree has placeholders in it + # when kwargs is present, in_spec is tuple(args, kwargs) + has_args_kwargs_tuple = self.pytree_info.in_spec.type == tuple and \ + self.pytree_info.in_spec.num_children == 2 and \ + self.pytree_info.in_spec.children_specs[0].type == tuple and \ + self.pytree_info.in_spec.children_specs[1].type == dict + fn_kwargs = '{}' + fn_signature = f"[{', '.join(fn_args)}], self._in_spec" + if has_args_kwargs_tuple: + count_args = self.pytree_info.in_spec.children_specs[0].num_children + fn_args = self.pytree_info.orig_args[:count_args] + fn_kwargs = '{' + ', '.join(f"'{k}':{v}" for k, v in zip( + self.pytree_info.in_spec.children_specs[1].context, + self.pytree_info.orig_args[count_args:])) + '}' + fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec" + + # in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid. 
+ # we need to split it to two lines: + # one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon) + # one for code: `var1, var2, = function_call()` + without_annotation = [x.split(":")[0] for x in free_vars] + has_annotation = [x + "; " for x in free_vars if ":" in x] + if len(has_annotation) > 0: + fn_definition += "\n " + "".join(has_annotation) + "\n" + fn_definition += f""" + {', '.join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})""" + return fn_definition + + def generate_output(self, output_args): + if self.pytree_info and self.pytree_info.out_spec: + return f'return pytree.tree_unflatten({repr(output_args)}, self._out_spec)' + else: + return super().generate_output(output_args) + +class _FindNodesLookupTable: + """ + Side table for the graph for the purpose of doing fast queries + """ + def __init__(self): + self.table: Dict[Tuple[str, Optional[Target]], Dict[Node, None]] = defaultdict(dict) + + def _key(self, node) -> Tuple[str, Optional[Target]]: + return (node.op, node.target if node.op == "call_function" else None) + + def __contains__(self, node) -> bool: + return node in self.table[self._key(node)] + + def insert(self, node: Node) -> None: + self.table[self._key(node)][node] = None + + def remove(self, node: Node) -> None: + self.table[self._key(node)].pop(node) + + def find_nodes(self, *, op: str, target: Optional['Target'] = None): + if op == "call_function": + assert target is not None + return dict(self.table[(op, target)]).keys() + + if target is None: + return dict(self.table[(op, None)]).keys() + + # op is call_method, get_attr, call_module + return [node for node in self.table[(op, None)].keys() if node.target == target] + +@compatibility(is_backward_compatible=True) +class Graph: + """ + ``Graph`` is the main data structure used in the FX Intermediate Representation. + It consists of a series of ``Node`` s, each representing callsites (or other + syntactic constructs). The list of ``Node`` s, taken together, constitute a + valid Python function. + + For example, the following code + + .. code-block:: python + + import torch + import torch.fx + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + + def forward(self, x): + return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3) + + m = MyModule() + gm = torch.fx.symbolic_trace(m) + + Will produce the following Graph:: + + print(gm.graph) + + .. code-block:: text + + graph(x): + %linear_weight : [num_users=1] = self.linear.weight + %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {}) + %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) + %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) + %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1}) + %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {}) + return topk_1 + + For the semantics of operations represented in the ``Graph``, please see :class:`Node`. + """ + + @compatibility(is_backward_compatible=True) + def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None, + tracer_extras: Optional[Dict[str, Any]] = None): + """ + Construct an empty Graph. 
+ """ + self._root : Node = Node(self, '', 'root', '', (), {}) + self._used_names : Dict[str, int] = {} # base name -> number + self._insert = self._root.prepend + self._len = 0 + self._graph_namespace = _Namespace() + self._owning_module = owning_module + self._tracer_cls = tracer_cls + self._tracer_extras = tracer_extras + self._codegen = CodeGen() + self._co_fields : Dict[str, Any] = {} + self._find_nodes_lookup_table = _FindNodesLookupTable() + + @property + def owning_module(self): + return self._owning_module + + @owning_module.setter + def owning_module(self, mod: Optional["GraphModule"]): + self._owning_module = mod + + @property + def nodes(self) -> _node_list: + """ + Get the list of Nodes that constitute this Graph. + + Note that this ``Node`` list representation is a doubly-linked list. Mutations + during iteration (e.g. delete a Node, add a Node) are safe. + + Returns: + + A doubly-linked list of Nodes. Note that ``reversed`` can be called on + this list to switch iteration order. + """ + return _node_list(self) + + @compatibility(is_backward_compatible=False) + def find_nodes(self, *, op: str, target: Optional['Target'] = None, sort: bool = True): + """ + Allows for fast query of nodes + + Args: + + op (str): the name of the operation + + target (Optional[Target]): the target of the node. For call_function, + the target is required. For other ops, the target is optional. + + sort (bool): whether to return nodes in the order they appear on + on the graph. + + Returns: + + Iteratable of nodes with the requested op and target. + """ + node_list = self._find_nodes_lookup_table.find_nodes(op=op, target=target) + if sort: + return sorted(node_list) + return node_list + + @compatibility(is_backward_compatible=True) + def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node], return_output_node=False) -> 'Optional[Argument]': + """ + Copy all nodes from a given graph into ``self``. + + Args: + + g (Graph): The source graph from which to copy Nodes. + + val_map (Dict[Node, Node]): a dictionary that will be populated with a mapping + from nodes in ``g`` to nodes in ``self``. Note that ``val_map`` can be passed + in with values in it already to override copying of certain values. + + Returns: + + The value in ``self`` that is now equivalent to the output value in ``g``, + if ``g`` had an ``output`` node. ``None`` otherwise. + """ + for node in g.nodes: + if node in val_map: + continue + if node.op == 'output': + rv = map_arg(node.args[0], lambda n: val_map[n]) + return rv if not return_output_node else (rv, node) + val_map[node] = self.node_copy(node, lambda n : val_map[n]) + return None + + def __deepcopy__(self, memo=None) -> 'Graph': + """ + Explicitly implement __deepcopy__ to prevent excessive recursion depth + from the default implementation. This uses graph_copy to copy the nodes + in an iterative way, rather than recursive. It also populates the + memoization table to prevent unnecessary copies (e.g. references to + nodes or other parts of the Graph from a custom GraphModule implementation. 
+ """ + memo = memo if memo else {} + g = Graph(tracer_cls=self._tracer_cls) + output_vals = g.graph_copy(self, val_map=memo, return_output_node=True) + g._codegen = copy.deepcopy(self._codegen) + assert isinstance(output_vals, tuple) + output_val, old_output_node = output_vals + new_output_node = g.output(output_val, type_expr=getattr(old_output_node, 'type', None)) + new_output_node.meta = copy.copy(old_output_node.meta) + return g + + @compatibility(is_backward_compatible=True) + def create_node(self, op: str, target: 'Target', + args: Optional[Tuple['Argument', ...]] = None, + kwargs: Optional[Dict[str, 'Argument']] = None, + name: Optional[str] = None, + type_expr: Optional[Any] = None) -> Node: + """ + Create a ``Node`` and add it to the ``Graph`` at the current insert-point. + Note that the current insert-point can be set via :meth:`Graph.inserting_before` + and :meth:`Graph.inserting_after`. + + Args: + op (str): the opcode for this Node. One of 'call_function', 'call_method', 'get_attr', + 'call_module', 'placeholder', or 'output'. The semantics of these opcodes are + described in the ``Graph`` docstring. + + args (Optional[Tuple[Argument, ...]]): is a tuple of arguments to this node. + + kwargs (Optional[Dict[str, Argument]]): the kwargs of this Node + + name (Optional[str]): an optional string name for the ``Node``. + This will influence the name of the value assigned to in the + Python generated code. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns: + + The newly-created and inserted node. + """ + assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output') + args = () if args is None else args + kwargs = {} if kwargs is None else kwargs + assert isinstance(args, tuple), "args must be a tuple" + assert isinstance(kwargs, dict), "kwargs must be a dict" + + candidate = name if name is not None else self._target_to_str(target) + name = self._graph_namespace.create_name(candidate, None) + n = Node(self, name, op, target, args, kwargs, type_expr) + + if self.owning_module is not None and getattr(self.owning_module, "_create_node_hooks", None) is not None: + for f in self.owning_module._create_node_hooks: + f(n) + + self._graph_namespace.associate_name_with_obj(name, n) + + self._insert(n) + self._find_nodes_lookup_table.insert(n) + self._len += 1 + return n + + @compatibility(is_backward_compatible=False) + def process_inputs(self, *args): + """ + Processes args so that they can be passed to the FX graph. + """ + return self._codegen.process_inputs(*args) + + @compatibility(is_backward_compatible=False) + def process_outputs(self, out): + return self._codegen.process_outputs(out) + + + @compatibility(is_backward_compatible=True) + def erase_node(self, to_erase : Node) -> None: + """ + Erases a ``Node`` from the ``Graph``. Throws an exception if + there are still users of that node in the ``Graph``. + + Args: + + to_erase (Node): The ``Node`` to erase from the ``Graph``. 
+ """ + if len(to_erase.users) > 0: + raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {len(to_erase.users)} ' + f'users in the graph: {to_erase.users}!') + if to_erase.graph != self: + raise RuntimeError(f"Attempting to remove {to_erase} from wrong graph!") + if to_erase._erased: + warnings.warn(f"erase_node({to_erase}) on an already erased node") + return + + if self.owning_module is not None and getattr(self.owning_module, "_erase_node_hooks", None) is not None: + for f in self.owning_module._erase_node_hooks: + f(to_erase) + + self._find_nodes_lookup_table.remove(to_erase) + to_erase._remove_from_list() + to_erase._erased = True # iterators may retain handles to erased nodes + self._len -= 1 + + # Null out this Node's argument nodes so that the Nodes referred to + # can update their ``users`` accordingly + new_args = map_arg(to_erase.args, lambda n: None) + assert isinstance(new_args, tuple) + to_erase.args = new_args + new_kwargs = map_arg(to_erase.kwargs, lambda n: None) + assert isinstance(new_kwargs, dict) + to_erase.kwargs = new_kwargs + + @compatibility(is_backward_compatible=True) + def inserting_before(self, n: Optional[Node] = None): + """Set the point at which create_node and companion methods will insert into the graph. + When used within a 'with' statement, this will temporary set the insert point and + then restore it when the with statement exits:: + + with g.inserting_before(n): + ... # inserting before node n + ... # insert point restored to what it was previously + g.inserting_before(n) # set the insert point permanently + + Args: + + n (Optional[Node]): The node before which to insert. If None this will insert before + the beginning of the entire graph. + + Returns: + A resource manager that will restore the insert point on ``__exit__``. + """ + if n is None: + return self.inserting_after(self._root) + assert n.graph == self, "Node to insert before is not in graph." + return _InsertPoint(self, n.prepend) + + @compatibility(is_backward_compatible=True) + def inserting_after(self, n: Optional[Node] = None): + """Set the point at which create_node and companion methods will insert into the graph. + When used within a 'with' statement, this will temporary set the insert point and + then restore it when the with statement exits:: + + with g.inserting_after(n): + ... # inserting after node n + ... # insert point restored to what it was previously + g.inserting_after(n) # set the insert point permanently + + Args: + + n (Optional[Node]): The node before which to insert. If None this will insert after + the beginning of the entire graph. + + Returns: + A resource manager that will restore the insert point on ``__exit__``. + """ + if n is None: + return self.inserting_before(self._root) + assert n.graph == self, "Node to insert after is not in graph." + return _InsertPoint(self, n.append) + + @compatibility(is_backward_compatible=True) + def placeholder(self, name: str, type_expr: Optional[Any] = None, + default_value : Any = inspect.Signature.empty) -> Node: + """ + Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents + a function input. + + Args: + + name (str): A name for the input value. This corresponds to the name + of the positional argument to the function this ``Graph`` represents. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. This is needed in some + cases for proper code generation (e.g. 
when the function is used + subsequently in TorchScript compilation). + + default_value (Any): The default value this function argument should take + on. NOTE: to allow for `None` as a default value, `inspect.Signature.empty` + should be passed as this argument to specify that the parameter does _not_ + have a default value. + + .. note:: + The same insertion point and type expression rules apply for this method + as ``Graph.create_node``. + """ + args = () if default_value is inspect.Signature.empty else (default_value,) + return self.create_node('placeholder', name, args=args, type_expr=type_expr) + + @compatibility(is_backward_compatible=True) + def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node: + """ + Insert a ``get_attr`` node into the Graph. A ``get_attr`` ``Node`` represents the + fetch of an attribute from the ``Module`` hierarchy. + + Args: + + qualified_name (str): the fully-qualified name of the attribute to be retrieved. + For example, if the traced Module has a submodule named ``foo``, which has a + submodule named ``bar``, which has an attribute named ``baz``, the qualified + name ``foo.bar.baz`` should be passed as ``qualified_name``. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + + Returns: + + The newly-created and inserted ``get_attr`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as ``Graph.create_node``. + """ + def _get_attr_reference_exists(mod: torch.nn.Module, qualified_name: str) -> bool: + module_path, _, name = qualified_name.rpartition(".") + + try: + submod: torch.nn.Module = mod.get_submodule(module_path) + except AttributeError: + warnings.warn(f"Failed to fetch module {module_path}!") + return False + + if not hasattr(submod, name): + return False + + res = getattr(submod, name) + + if (not isinstance(res, torch.nn.Module) + and not isinstance(res, torch.nn.Parameter) + and name not in submod._buffers): + return False + + return True + + if (self.owning_module and + not _get_attr_reference_exists(self.owning_module, qualified_name)): + warnings.warn("Attempted to insert a get_attr Node with no " + "underlying reference in the owning " + "GraphModule! Call " + "GraphModule.add_submodule to add the " + "necessary submodule, " + "GraphModule.add_parameter to add the " + "necessary Parameter, or " + "nn.Module.register_buffer to add the " + "necessary buffer", stacklevel=2) + return self.create_node('get_attr', qualified_name, type_expr=type_expr) + + @compatibility(is_backward_compatible=True) + def call_module(self, + module_name: str, + args: Optional[Tuple['Argument', ...]] = None, + kwargs: Optional[Dict[str, 'Argument']] = None, + type_expr: Optional[Any] = None) -> Node: + """ + Insert a ``call_module`` ``Node`` into the ``Graph``. A ``call_module`` node + represents a call to the forward() function of a ``Module`` in the ``Module`` + hierarchy. + + Args: + + module_name (str): The qualified name of the ``Module`` in the ``Module`` + hierarchy to be called. For example, if the traced ``Module`` has a + submodule named ``foo``, which has a submodule named ``bar``, the + qualified name ``foo.bar`` should be passed as ``module_name`` to + call that module. + + args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed + to the called method. Note that this should *not* include a ``self`` argument. 
+ + kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed + to the called method + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns: + + The newly-created and inserted ``call_module`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as :meth:`Graph.create_node`. + """ + if (self.owning_module and + self.owning_module.get_submodule(module_name) is None): + warnings.warn("Attempted to insert a call_module Node with " + "no underlying reference in the owning " + "GraphModule! Call " + "GraphModule.add_submodule to add the " + "necessary submodule") + return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr) + + @compatibility(is_backward_compatible=True) + def call_method(self, + method_name: str, + args: Optional[Tuple['Argument', ...]] = None, + kwargs: Optional[Dict[str, 'Argument']] = None, + type_expr: Optional[Any] = None) -> Node: + """ + Insert a ``call_method`` ``Node`` into the ``Graph``. A ``call_method`` node + represents a call to a given method on the 0th element of ``args``. + + Args: + + method_name (str): The name of the method to apply to the self argument. + For example, if args[0] is a ``Node`` representing a ``Tensor``, + then to call ``relu()`` on that ``Tensor``, pass ``relu`` to ``method_name``. + + args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed + to the called method. Note that this *should* include a ``self`` argument. + + kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed + to the called method + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns: + + The newly created and inserted ``call_method`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as :meth:`Graph.create_node`. + """ + return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr) + + @compatibility(is_backward_compatible=True) + def call_function(self, + the_function: Callable[..., Any], + args: Optional[Tuple['Argument', ...]] = None, + kwargs: Optional[Dict[str, 'Argument']] = None, + type_expr: Optional[Any] = None) -> Node: + """ + Insert a ``call_function`` ``Node`` into the ``Graph``. A ``call_function`` node + represents a call to a Python callable, specified by ``the_function``. + + Args: + + the_function (Callable[..., Any]): The function to be called. Can be any PyTorch + operator, Python function, or member of the ``builtins`` or ``operator`` + namespaces. + + args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed + to the called function. + + kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed + to the called function + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns: + + The newly created and inserted ``call_function`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as :meth:`Graph.create_node`. + """ + return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr) + + @compatibility(is_backward_compatible=True) + def node_copy(self, node: Node, arg_transform: Callable[[Node], 'Argument'] = lambda x: x) -> Node: + """ + Copy a node from one graph into another. 
``arg_transform`` needs to transform arguments from + the graph of node to the graph of self. Example:: + + # Copying all the nodes in `g` into `new_graph` + g : torch.fx.Graph = ... + new_graph = torch.fx.Graph() + value_remap = {} + for node in g.nodes: + value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n]) + + Args: + + node (Node): The node to copy into ``self``. + + arg_transform (Callable[[Node], Argument]): A function that transforms + ``Node`` arguments in node's ``args`` and ``kwargs`` into the + equivalent argument in ``self``. In the simplest case, this should + retrieve a value out of a table mapping Nodes in the original + graph to ``self``. + """ + args = map_arg(node.args, arg_transform) + kwargs = map_arg(node.kwargs, arg_transform) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + result_node = self.create_node(node.op, node.target, args, kwargs, node.name, node.type) + result_node.meta = copy.copy(node.meta) + return result_node + + @compatibility(is_backward_compatible=True) + def output(self, result: 'Argument', type_expr: Optional[Any] = None): + """ + Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node represents + a ``return`` statement in Python code. ``result`` is the value that should + be returned. + + Args: + + result (Argument): The value to be returned. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + .. note:: + + The same insertion point and type expression rules apply for this method + as ``Graph.create_node``. + """ + return self.create_node(op='output', target='output', args=(result,), type_expr=type_expr) + + def _target_to_str(self, target : Target) -> str: + if callable(target): + op = target.__name__ + else: + assert isinstance(target, str) + op = target + if _is_magic(op): + op = op[2:-2] + op = _snake_case(op) + return op + + @compatibility(is_backward_compatible=True) + def python_code( + self, root_module: str, *, + verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False + ) -> PythonCode: + """ + Turn this ``Graph`` into valid Python code. + + Args: + + root_module (str): The name of the root module on which to look-up + qualified name targets. This is usually 'self'. + + Returns: + + A PythonCode object, consisting of two fields: + src: the Python source code representing the object + globals: a dictionary of global names in `src` -> the objects that they reference. + """ + # NOTE: [Graph Namespaces] + # + # There are two types of symbols in generated Python source code: + # locals and globals. + # Locals are locally defined by the output of a node in the Graph. + # Globals are references to external objects, like functions or types. + # + # When generating Python code, we need to make sure to name things + # appropriately. In particular: + # - All names should be unique, to avoid weird shadowing bugs. + # - These names need to be consistent, e.g. an object should always be + # referenced by the same name. + # + # To do this, we create a new namespace just for this source. All names + # that get printed must come from this namespace. + # + # Why can't we re-use node.name? Because it was generated within the + # namespace `self._graph_namespace`. In order to provide uniqueness + # over both locals (node.name) *and* globals, we create a completely + # new namespace to put all identifiers in. 
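+ # For example, the pre-registered global 'device' (torch.device) and a node whose name hint is also 'device' receive distinct names here (the node becomes 'device_1'), since both are drawn from this one namespace.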
+ namespace = _Namespace() + + # Override Node's repr to generate a valid name within our namespace. + # Since repr() is designed to produce a valid Python expression, it + # makes sense to re-use it. This way, it's easy to print something like + # Tuple[Node, Node] by simply calling repr() on it. Node's __repr__ is + # implemented cooperatively to allow this. + def node_repr(n: Node): + return namespace.create_name(n.name, n) + + @contextmanager + def override_node_repr(graph: Graph): + orig_repr_fns = {} + for node in graph.nodes: + orig_repr_fns[node] = node._repr_fn + node._repr_fn = node_repr + try: + yield None + finally: + # restore the original repr functions + for node in graph.nodes: + node._repr_fn = orig_repr_fns[node] + + with override_node_repr(self): + return self._python_code( + root_module, namespace, + verbose=verbose, include_stride=include_stride, include_device=include_device, colored=colored + ) + + def _python_code( + self, root_module: str, namespace: _Namespace, *, + verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, + ) -> PythonCode: + return self._codegen._gen_python_code( + self.nodes, root_module, namespace, + verbose=verbose, include_stride=include_stride, include_device=include_device, colored=colored + ) + + + def __str__(self) -> str: + """ + Return a human-readable (not machine-readable) string representation + of this Graph + """ + placeholder_names : List[str] = [] + # This is a one-element array just so ``format_node`` can modify the closed + # over value + maybe_return_typename : List[str] = [''] + + node_strs = [node.format_node(placeholder_names) for node in self.nodes] + param_str = ', '.join(placeholder_names) + s = f'graph({param_str}){maybe_return_typename[0]}:' + for node_str in node_strs: + if node_str: + s += '\n ' + node_str + return s + + @compatibility(is_backward_compatible=True) + def print_tabular(self): + """ + Prints the intermediate representation of the graph in tabular + format. Note that this API requires the ``tabulate`` module to be + installed. + """ + try: + from tabulate import tabulate + except ImportError: + print("`print_tabular` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library.") + raise + + node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] + for n in self.nodes] + print(tabulate(node_specs, + headers=['opcode', 'name', 'target', 'args', 'kwargs'])) + + @compatibility(is_backward_compatible=True) + def lint(self): + """ + Runs various checks on this Graph to make sure it is well-formed. In + particular: + - Checks Nodes have correct ownership (owned by this graph) + - Checks Nodes appear in topological order + - If this Graph has an owning GraphModule, checks that targets + exist in that GraphModule + """ + + # Check topo order + def check_arg(arg : Node, n : Optional[Node] = None) -> None: + context_str = f' of Node \'{n}\' ' if n else ' ' + if arg.graph is not self: + raise RuntimeError(f'Argument \'{arg}\'{context_str}does not belong to this Graph, ' + f'but was used as an argument! If you are copying nodes from another graph, make ' + f'sure to use ``arg_transform`` on node_copy() to remap values\n{self}') + if arg not in seen_values: + raise RuntimeError(f'Argument \'{arg}\'{context_str}was used before it has been ' + f'defined! 
Please check that Nodes in the graph are topologically ordered\n{self}') + + seen_names : Set[str] = set() + seen_values : Set[Node] = set() + for node in self.nodes: + if node.op not in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output']: + raise RuntimeError(f'Node {node} had unknown opcode {node.op}!') + if node.graph is not self: + raise RuntimeError(f'Node \'{node}\' does not belong to this Graph!') + if node not in self._find_nodes_lookup_table: + raise RuntimeError(f"Node '{node}' is not added to the side table") + map_arg(node.args, lambda arg: check_arg(arg, node)) + map_arg(node.kwargs, lambda arg: check_arg(arg, node)) + seen_values.add(node) + + if node.name in seen_names: + raise RuntimeError(f'Node redefined name {node.name}!') + seen_names.add(node.name) + + # Check targets are legit + if self.owning_module: + num_warnings = 0 + MAX_WARNINGS = 5 + for node in self.nodes: + if node.op == 'call_function': + if not callable(node.target): + raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but ' + 'a Callable is expected') + else: + if not isinstance(node.target, str): + raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but ' + 'a str is expected') + if node.op in ['get_attr', 'call_module']: + target_atoms = node.target.split('.') + m_itr = self.owning_module + for i, atom in enumerate(target_atoms): + new_m_itr = getattr(m_itr, atom, None) + seen_qualname = '.'.join(target_atoms[:i]) + if new_m_itr is None: + raise RuntimeError(f'Node {node} target {node.target} references nonexistent attribute ' + f'{atom} of {seen_qualname}') + if (node.op == "call_module" + and not isinstance(new_m_itr, torch.nn.Module)): + raise RuntimeError(f'Node {node} target {node.target} {atom} of {seen_qualname} does ' + 'not reference an nn.Module') + elif (node.op == "get_attr" + and not isinstance(new_m_itr, torch.nn.Module) + and not isinstance(new_m_itr, torch.nn.Parameter) + and atom not in m_itr._buffers): + if num_warnings < MAX_WARNINGS: + # Don't emit this warning too frequently, + # for very large graphs this can become very expensive + # from a performance perspective. + warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does ' + 'not reference an nn.Module, nn.Parameter, or buffer, which is ' + 'what \'get_attr\' Nodes typically target') + num_warnings += 1 + else: + m_itr = new_m_itr + if num_warnings > MAX_WARNINGS: + warnings.warn( + f'Additional {num_warnings - MAX_WARNINGS} warnings ' + 'suppressed about get_attr references' + ) + + @compatibility(is_backward_compatible=True) + def eliminate_dead_code(self, is_impure_node: Optional[Callable[[Node], bool]] = None): + """ + Remove all dead code from the graph, based on each node's number of + users, and whether the nodes have any side effects. The graph must be + topologically sorted before calling. + + Args: + is_impure_node (Optional[Callable[[Node], bool]]): A function that returns + whether a node is impure. If this is None, then the default behavior is to + use Node.is_impure. + + Returns: + bool: Whether the graph was changed as a result of the pass. + + Example: + + Before dead code is eliminated, `a` from `a = x + 1` below has no users + and thus can be eliminated from the graph without having an effect. + + .. 
code-block:: python + + def forward(self, x): + a = x + 1 + return x + self.attr_1 + + After dead code is eliminated, `a = x + 1` has been removed, and the rest + of `forward` remains. + + .. code-block:: python + + def forward(self, x): + return x + self.attr_1 + + .. warning:: + + Dead code elimination has some heuristics to avoid removing + side-effectful nodes (see Node.is_impure) but in general coverage + is very bad, so you should assume that this method is not sound + to call unless you know that your FX graph consists entirely + of functional operations or you supply your own custom + function for detecting side-effectful nodes. + """ + # Lint the graph first to make sure its topologically sorted, otherwise + # DCE below will not behave as expected. + self.lint() + + def has_side_effect(node): + if is_impure_node is not None: + return is_impure_node(node) + return node.is_impure() + + # Reverse iterate so that when we remove a node, any nodes used as an + # input to that node have an updated user count that no longer reflects + # the removed node. + changed = False + for node in reversed(self.nodes): + if not has_side_effect(node) and len(node.users) == 0: + self.erase_node(node) + changed = True + + return changed + + @compatibility(is_backward_compatible=False) + def set_codegen(self, codegen: CodeGen): + self._codegen = codegen + + @compatibility(is_backward_compatible=False) + def on_generate_code( + self, + make_transformer: Callable[[Optional[TransformCodeFunc]], TransformCodeFunc] + ): + """Register a transformer function when python code is generated + + Args: + make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]): + a function that returns a code transformer to be registered. + This function is called by `on_generate_code` to obtain the + code transformer. + + This function is also given as its input the currently + registered code transformer (or None if nothing is registered), + in case it is not desirable to overwrite it. This is useful to + chain code transformers together. + + Returns: + a context manager that when used in a `with` statement, to automatically + restore the previously registered code transformer. + + Example: + + .. code-block:: python + + + gm: fx.GraphModule = ... + + # This is a code transformer we want to register. This code + # transformer prepends a pdb import and trace statement at the very + # beginning of the generated torch.fx code to allow for manual + # debugging with the PDB library. + def insert_pdb(body): + return ["import pdb; pdb.set_trace()\\n", *body] + + # Registers `insert_pdb`, and overwrites the current registered + # code transformer (given by `_` to the lambda): + gm.graph.on_generate_code( + lambda _: insert_pdb + ) + + # Or alternatively, registers a code transformer which first + # runs `body` through existing registered transformer, then + # through `insert_pdb`: + gm.graph.on_generate_code( + lambda current_trans: ( + lambda body: insert_pdb( + current_trans(body) if current_trans + else body + ) + ) + ) + + gm.recompile() + gm(*inputs) # drops into pdb + + + This function can also be used as a context manager, with the benefit to + automatically restores the previously registered code transformer: + + .. code-block:: python + + # ... continue from previous example + + with gm.graph.on_generate_code(lambda _: insert_pdb): + # do more stuff with `gm`... 
+ gm.recompile() + gm(*inputs) # drops into pdb + + # now previous code transformer is restored (but `gm`'s code with pdb + # remains - that means you can run `gm` with pdb here too, until you + # run next `recompile()`). + """ + on_gen_code_old = self._codegen._body_transformer + self._codegen._body_transformer = make_transformer(on_gen_code_old) + + @contextlib.contextmanager + def on_generate_code_context_manager(): + try: + yield + finally: + self._codegen._body_transformer = on_gen_code_old + + return on_generate_code_context_manager() + + +reflectable_magic_methods = { + 'add': '{} + {}', + 'sub': '{} - {}', + 'mul': '{} * {}', + 'floordiv': '{} // {}', + 'truediv': '{} / {}', + 'div': '{} / {}', + 'mod': '{} % {}', + 'pow': '{} ** {}', + 'lshift': '{} << {}', + 'rshift': '{} >> {}', + 'and_': '{} & {}', + 'or_': '{} | {}', + 'xor': '{} ^ {}', + 'getitem': '{}[{}]', + 'matmul': '{} @ {}', +} + +magic_methods = dict({ + 'eq': '{} == {}', + 'ne': '{} != {}', + 'lt': '{} < {}', + 'gt': '{} > {}', + 'le': '{} <= {}', + 'ge': '{} >= {}', + 'pos': '+{}', + 'neg': '-{}', + 'invert': '~{}'}, **reflectable_magic_methods) + +inplace_methods = { + 'iadd': '{} += {}', + 'iand': '{} &= {}', + 'ifloordiv': '{} //= {}', + 'ilshift': '{} <<= {}', + 'imod': '{} %= {}', + 'imul': '{} *= {}', + 'imatmul': '{} @= {}', + 'ior': '{} |= {}', + 'ipow': '{} **= {}', + 'irshift': '{} >>= {}', + 'isub': '{} -= {}', + 'itruediv': '{} /= {}', + 'ixor': '{} ^= {}', + 'setitem': '{}[{}] = {}', +} diff --git a/lib/python3.10/site-packages/torch/fx/graph_module.py b/lib/python3.10/site-packages/torch/fx/graph_module.py new file mode 100644 index 0000000000000000000000000000000000000000..76dac29512bdccd7cb93a5afe8107e124f8961b7 --- /dev/null +++ b/lib/python3.10/site-packages/torch/fx/graph_module.py @@ -0,0 +1,955 @@ +# mypy: allow-untyped-defs +import contextlib +import copy +import itertools +import linecache +import os +import sys +import traceback +import warnings +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Set, Type, Union + +import torch +import torch.nn as nn +import torch.overrides +from torch.nn.modules.module import _addindent +from torch.package import Importer, PackageExporter, PackageImporter, sys_importer + +from ._compatibility import compatibility +from .graph import _custom_builtins, _is_from_torch, _PyTreeCodeGen, Graph, PythonCode + +__all__ = [ + "reduce_graph_module", + "reduce_package_graph_module", + "reduce_deploy_graph_module", + "GraphModule", +] + +_USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes" + +# Normal exec loses the source code, however we can work with +# the linecache module to recover it. +# Using _exec_with_source will add it to our local cache +# and then tools like TorchScript will be able to get source info. +class _EvalCacheLoader: + def __init__(self): + self.eval_cache = {} + self.next_id = 0 + + def cache(self, src: str, globals: Dict[str, Any], co_fields=None): + """Store the source in a private cache, and add a lazy entry in linecache + that allows the source to be retrieved by 'filename'. + + Args: + src (str): The module source to cache + globals (dict): The module globals + + Returns: + str: The cache key (and dummy filename) generated for src. 
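+
+        Example (editorial sketch, not upstream documentation)::
+
+            src = "def forward(self, x): return x"
+            key = _loader.cache(src, {})
+            # tools that consult linecache (e.g. traceback formatting) can now
+            # map `key` back to the cached source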
+ """ + + key = self._get_key() + if co_fields: + key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}" + self.eval_cache[key] = src + + # Don't mutate globals so that this loader is only used + # to populate linecache, and doesn't interact with other modules + # that might check `__loader__` + globals_copy = globals.copy() + globals_copy["__file__"] = key + globals_copy["__name__"] = key + globals_copy["__loader__"] = self + linecache.lazycache(key, globals_copy) + + return key + + # Part of the loader protocol (PEP 302) + # linecache will use this method when trying to find source code + def get_source(self, module_name) -> Optional[str]: + if module_name in self.eval_cache: + return self.eval_cache[module_name] + return None + + def _get_key(self): + key = f".{self.next_id}" + self.next_id += 1 + return key + + +_loader = _EvalCacheLoader() + + +def _exec_with_source(src: str, globals: Dict[str, Any], co_fields=None): + key = _loader.cache(src, globals, co_fields) + exec(compile(src, key, "exec"), globals) + + +def _forward_from_src(src: str, globals: Dict[str, Any], co_fields=None): + return _method_from_src( + method_name="forward", src=src, globals=globals, co_fields=co_fields + ) + + +def _method_from_src( + method_name: str, src: str, globals: Dict[str, Any], co_fields=None +) -> Callable: + # avoid mutating the passed in dict + globals_copy = globals.copy() + _exec_with_source(src, globals_copy, co_fields) + fn = globals_copy[method_name] + del globals_copy[method_name] + return fn + + +def _format_import_statement(name: str, obj: Any, importer: Importer) -> str: + if name in _custom_builtins: + return _custom_builtins[name].import_str + if _is_from_torch(name): + return "import torch" + module_name, attr_name = importer.get_name(obj) + return f"from {module_name} import {attr_name} as {name}" + + +def _format_import_block(globals: Dict[str, Any], importer: Importer): + import_strs: Set[str] = {_format_import_statement(name, obj, importer) for name, obj in globals.items()} + # Sort the imports so we have a stable import block that allows us to + # hash the graph module and get a consistent key for use in a cache. + return "\n".join(sorted(import_strs)) + + +@compatibility(is_backward_compatible=True) +def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module: + # BC: attribute name was changed from `code` to `_code` to facilitate + # making `code` into a property and adding a docstring to it + fn_src = body.get("_code") or body["code"] + forward = _forward_from_src(import_block + fn_src, {}) + return _deserialize_graph_module(forward, body) + + +@compatibility(is_backward_compatible=True) +def reduce_package_graph_module( + importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str +) -> torch.nn.Module: + forward = importer.import_module(generated_module_name).forward + return _deserialize_graph_module(forward, body) + + +@compatibility(is_backward_compatible=True) +def reduce_deploy_graph_module( + importer: PackageImporter, body: Dict[Any, Any], import_block: str +) -> torch.nn.Module: + ns = {} + ns["__builtins__"] = importer.patched_builtins + fn_src = body.get("_code") + assert fn_src is not None + forward = _forward_from_src(import_block + fn_src, ns) + return _deserialize_graph_module(forward, body) + + +# We create a dummy class here because symbolic_trace pulls the forward() +# function off of the class, rather than the instance. 
This class is used +# in _deserialize_graph_module() below. +class _CodeOnlyModule(torch.nn.Module): + def __init__(self, body): + super().__init__() + self.__dict__ = body + + +def _deserialize_graph_module(forward, body: Dict[Any, Any], graph_module_cls=None) -> torch.nn.Module: + """ + Deserialize a GraphModule given the dictionary of the original module, + using the code to reconstruct the graph. We delete the actual graph before + saving the dictionary so that changes to the in-memory graph format do not + get serialized. + """ + + # Try to retrieve the forward source in a backward-compatible way + _CodeOnlyModule.forward = forward + + tracer_cls = body.get("_tracer_cls") + if tracer_cls is None: + from ._symbolic_trace import Tracer + + tracer_cls = Tracer + + graphmodule_cls_name = body.get("_graphmodule_cls_name", "GraphModule") + + # This is a workaround for a mypy linter issue related to + # passing base class as an argument - https://github.com/python/mypy/issues/5865. + cls_tracer: Any = tracer_cls + + class KeepModules(cls_tracer): + # we shouldn't trace into any of the submodules, + # because they were not traced in the original GraphModule + def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool: + return True + + com = _CodeOnlyModule(body) + + tracer_extras = body.get("_tracer_extras", {}) + graph = KeepModules().trace(com, **tracer_extras) + + # Manually set Tracer class on the reconstructed Graph, to avoid + # referencing the private local subclass KeepModules. + graph._tracer_cls = tracer_cls + from ._lazy_graph_module import _make_graph_module + gm = _make_graph_module(com, graph, class_name=graphmodule_cls_name, graph_module_cls=graph_module_cls) + + # The GraphModule constructor only retains attributes referenced by the graph. + # In this case, our goal is return a GraphModule as close to identical as the one + # put into the package. If any additional attributes were present in body, + # we should keep them. + for k, v in body.items(): + if not hasattr(gm, k): + setattr(gm, k, v) + return gm + + +# copy an attribute value with qualified name 'target' from 'from_module' to 'to_module' +# This installs empty Modules where none exist yet if they are subpaths of target +def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str): + *prefix, field = target.split(".") + for item in prefix: + f = getattr(from_module, item) + t = getattr(to_module, item, None) + if f is t: + # we have already installed one of its parents + # (e.g. target = root.linear.weight, but we have already installed root.linear) + # once we install a parent, we no longer need to copy the children + # since all the needed properties will already be present + return + + if t is None: + t = torch.nn.Module() + setattr(to_module, item, t) + from_module, to_module = f, t + + orig = getattr(from_module, field) + # If it is a tensor and not a parameter attribute of a module, it should be a named buffer. + # So, we register it as a named buffer in the target module. 
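+    # Editorial illustration (hypothetical attribute names): if `from_module.scale` is a
+    # plain torch.Tensor, the branch below calls to_module.register_buffer("scale", orig);
+    # an nn.Parameter or nn.Module attribute instead falls through to setattr, which lets
+    # nn.Module register it in the usual way.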
+ if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter): + to_module.register_buffer(field, orig) + else: + setattr(to_module, field, orig) + + +# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module +# This installs empty Modules where none exist yet if they are subpaths of target +def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str): + *prefix, field = target.split(".") + for item in prefix: + t = getattr(to_module, item, None) + + if t is None: + t = torch.nn.Module() + setattr(to_module, item, t) + to_module = t + + # If it is a tensor and not a parameter attribute of a module, it should be a named buffer. + # So, we register it as a named buffer in the target module. + if isinstance(from_obj, torch.Tensor) and not isinstance( + from_obj, torch.nn.Parameter + ): + to_module.register_buffer(field, from_obj) + else: + setattr(to_module, field, from_obj) + + +def _print_readable( + module, + module_name, + print_output=True, + include_stride=False, + include_device=False, + colored=False, +): + graph = module.graph + assert graph is not None and isinstance(graph, torch.fx.Graph), "print_readable must be used on a module with a graph" + + verbose_python_code = graph.python_code( + root_module="self", + verbose=True, + include_stride=include_stride, + include_device=include_device, + colored=colored, + ) + module_code = verbose_python_code.src + module_code = module_code.lstrip("\n") + module_code = f"class {module_name}(torch.nn.Module):\n" + module_code + module_code = _addindent(module_code, 4) + + submodule_code_list = [""] + for submodule_name, submodule in module.named_children(): + if hasattr(submodule, "graph"): + submodule_code_list.append( + _print_readable( + submodule, + submodule_name, + print_output=False, + include_stride=include_stride, + include_device=include_device, + colored=colored, + ) + ) + submodule_code = "\n".join(submodule_code_list) + submodule_code = _addindent(submodule_code, 4) + + output = module_code + submodule_code + if print_output: + print(module_code + submodule_code) + return output + + +class _WrappedCall: + def __init__(self, cls, cls_call): + self.cls = cls + self.cls_call = cls_call + + # Previously, if an error occurred when valid + # symbolically-traced code was run with an invalid input, the + # user would see the source of the error as coming from + # `File "`, where N is some number. We use + # this function to generate a more informative error message. 
We + # return the traceback itself, a message explaining that the + # error occurred in a traced Module's generated forward + # function, and five lines of context surrounding the faulty + # line + @staticmethod + def _generate_error_message(frame_summary: traceback.FrameSummary) -> str: + # auxiliary variables (for readability) + err_lineno = frame_summary.lineno + assert err_lineno is not None + line = frame_summary.line + assert line is not None + err_line_len = len(line) + all_src_lines = linecache.getlines(frame_summary.filename) + + # constituent substrings of the error message + tb_repr = torch._dynamo.disable(traceback.format_exc)() + custom_msg = ( + "Call using an FX-traced Module, " + f"line {err_lineno} of the traced Module's " + "generated forward function:" + ) + before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno]) + marker = "~" * err_line_len + "~~~ <--- HERE" + err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2]) + + # joined message + return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err]) + + def __call__(self, obj, *args, **kwargs): + try: + if self.cls_call is not None: + return self.cls_call(obj, *args, **kwargs) + else: + return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] + except Exception as e: + assert e.__traceback__ + topmost_framesummary: traceback.FrameSummary = ( + traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] + ) # type: ignore[arg-type] + if "eval_with_key" in topmost_framesummary.filename: + print( + _WrappedCall._generate_error_message(topmost_framesummary), + file=sys.stderr, + ) + raise e.with_traceback(None) # noqa: B904 + else: + raise e + +@compatibility(is_backward_compatible=True) +class GraphModule(torch.nn.Module): + """ + GraphModule is an nn.Module generated from an fx.Graph. Graphmodule has a + ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated + from that ``graph``. + + .. warning:: + + When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically + regenerated. However, if you edit the contents of the ``graph`` without reassigning + the ``graph`` attribute itself, you must call ``recompile()`` to update the generated + code. + """ + + def __new__(cls: "Type[GraphModule]", *args, **kwargs): + # each instance of a graph module needs its own forward method + # so create a new singleton class for each instance. + # it is a subclass of the user-defined class, the only difference + # is an extra layer to install the forward method + + # address issue described at https://github.com/pytorch/pytorch/issues/63883 + # in other words, traverse class hierarchy to fix the redundant class definition problem + for t in cls.__mro__: + c = t.__qualname__.split(".")[-1] + if c != "GraphModuleImpl": + cls = t + break + + class GraphModuleImpl(cls): # type: ignore[misc, valid-type] + pass + + return super().__new__(GraphModuleImpl) + + @compatibility(is_backward_compatible=True) + def __init__( + self, + root: Union[torch.nn.Module, Dict[str, Any]], + graph: Graph, + class_name: str = "GraphModule", + ): + """ + Construct a GraphModule. + + Args: + + root (Union[torch.nn.Module, Dict[str, Any]): + ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type. 
+ In the case that ``root`` is a Module, any references to Module-based objects (via qualified + name) in the Graph's Nodes' ``target`` field will be copied over from the respective place + within ``root``'s Module hierarchy into the GraphModule's module hierarchy. + In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be + looked up directly in the dict's keys. The object mapped to by the Dict will be copied + over into the appropriate place within the GraphModule's module hierarchy. + + graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation + + class_name (str): ``name`` denotes the name of this GraphModule for debugging purposes. If it's unset, all + error messages will report as originating from ``GraphModule``. It may be helpful to set this + to ``root``'s original name or a name that makes sense within the context of your transform. + """ + super().__init__() + self.__class__.__name__ = class_name + if isinstance(root, torch.nn.Module): + if hasattr(root, "training"): + self.training = root.training + + # When we pickle/unpickle graph module, we don't want to drop any module or attributes. + if isinstance(root, _CodeOnlyModule): + for k, _ in root.named_children(): + _copy_attr(root, self, k) + + for k, _ in root.named_buffers(): + _copy_attr(root, self, k) + + for k, _ in root.named_parameters(): + _copy_attr(root, self, k) + + for node in graph.nodes: + if node.op in ["get_attr", "call_module"]: + assert isinstance(node.target, str) + _copy_attr(root, self, node.target) + elif isinstance(root, dict): + targets_to_copy = [] + for node in graph.nodes: + if node.op in ["get_attr", "call_module"]: + assert isinstance(node.target, str) + if node.target not in root: + raise RuntimeError( + "Node " + + str(node) + + " referenced target " + + node.target + + " but that target was not provided in ``root``!" + ) + targets_to_copy.append(node.target) + # Sort targets in ascending order of the # of atoms. + # This will ensure that less deeply nested attributes are assigned + # before more deeply nested attributes. For example, foo.bar + # will be assigned before foo.bar.baz. Otherwise, we might assign + # the user-provided ``foo.bar`` and wipe out the previously-assigned + # ``foo.bar.baz`` + targets_to_copy.sort(key=lambda t: t.count(".")) + for target_to_copy in targets_to_copy: + _assign_attr(root[target_to_copy], self, target_to_copy) + else: + raise RuntimeError("Unsupported type " + str(root) + " passed for root!") + + self.graph = graph + + # Store the Tracer class responsible for creating a Graph separately as part of the + # GraphModule state, except when the Tracer is defined in a local namespace. + # Locally defined Tracers are not pickleable. This is needed because torch.package will + # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer + # to re-create the Graph during deserialization. + self._tracer_cls = None + if ( + self.graph._tracer_cls + and "" not in self.graph._tracer_cls.__qualname__ + ): + self._tracer_cls = self.graph._tracer_cls + + self._tracer_extras = {} + if self.graph._tracer_extras: + self._tracer_extras = self.graph._tracer_extras + + # Dictionary to store metadata + self.meta: Dict[str, Any] = {} + self._replace_hook = None + self._create_node_hooks: List[Callable] = [] + self._erase_node_hooks: List[Callable] = [] + + # TorchScript breaks trying to compile the graph setter because of the + # continued string literal. 
Issue here: https://github.com/pytorch/pytorch/issues/44842 + # + # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway + __jit_unused_properties__ = ["graph"] + + @property + def graph(self) -> Graph: + """ + Return the ``Graph`` underlying this ``GraphModule`` + """ + return self._graph + + @graph.setter + def graph(self, g: Graph) -> None: + """ + Set the underlying ``Graph`` for this ``GraphModule``. This will internally + recompile the ``GraphModule`` so that the generated ``forward()`` function + corresponds to ``g`` + """ + assert isinstance(g, Graph), f"Expected a Graph instance, but got {type(g)}" + self._graph = g + g.owning_module = self + self.recompile() + + @compatibility(is_backward_compatible=False) + def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"): + """Dumps out module to ``folder`` with ``module_name`` so that it can be + imported with ``from import `` + + Args: + + folder (Union[str, os.PathLike]): The folder to write the code out to + + module_name (str): Top-level name to use for the ``Module`` while + writing out the code + """ + folder = Path(folder) + Path(folder).mkdir(exist_ok=True) + torch.save(self.state_dict(), folder / "state_dict.pt") + tab = " " * 4 + custom_builtins = "\n".join([v.import_str for v in _custom_builtins.values()]) + model_str = f""" +import torch +{custom_builtins} + +from torch.nn import * +class {module_name}(torch.nn.Module): + def __init__(self): + super().__init__() +""" + + def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: + safe_reprs = [ + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, + ] + if type(module) in safe_reprs: + return f"{module.__repr__()}" + else: + return None + + blobified_modules = [] + for module_name, module in self.named_children(): + module_str = _gen_model_repr(module_name, module) + if module_str is None: + module_file = folder / f"{module_name}.pt" + torch.save(module, module_file) + blobified_modules.append(module_name) + module_repr = module.__repr__().replace("\r", " ").replace("\n", " ") + # weights_only=False as this is legacy code that saves the model + module_str = f"torch.load(r'{module_file}', weights_only=False) # {module_repr}" + model_str += f"{tab*2}self.{module_name} = {module_str}\n" + + for buffer_name, buffer in self._buffers.items(): + if buffer is None: + continue + model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" + + for param_name, param in self._parameters.items(): + if param is None: + continue + model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n" + + model_str += ( + f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" + ) + model_str += f"{_addindent(self.code, 4)}\n" + + module_file = folder / "module.py" + module_file.write_text(model_str) + + init_file = folder / "__init__.py" + init_file.write_text("from .module import *") + + if len(blobified_modules) > 0: + warnings.warn( + "Was not able to save the following children modules as reprs -" + f"saved as pickled files instead: {blobified_modules}" + ) + + @compatibility(is_backward_compatible=True) + def add_submodule(self, target: str, m: torch.nn.Module) -> bool: + """ + Adds the given submodule to ``self``. + + This installs empty Modules where none exist yet if they are + subpaths of ``target``. 
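+
+        Example (editorial sketch; ``gm`` is assumed to be an existing GraphModule)::
+
+            gm.add_submodule("foo.bar.baz", torch.nn.ReLU())
+            # the intermediate empty Modules "foo" and "foo.bar" are created on demand
+            assert isinstance(gm.get_submodule("foo.bar.baz"), torch.nn.ReLU)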
+ + Args: + target: The fully-qualified string name of the new submodule + (See example in ``nn.Module.get_submodule`` for how to + specify a fully-qualified string.) + m: The submodule itself; the actual object we want to + install in the current Module + + Return: + bool: Whether or not the submodule could be inserted. For + this method to return True, each object in the chain + denoted by ``target`` must either a) not exist yet, + or b) reference an ``nn.Module`` (not a parameter or + other attribute) + """ + *prefix, field = target.split(".") + mod: torch.nn.Module = self + + for item in prefix: + + submod = getattr(mod, item, None) + + if submod is None: + submod = torch.nn.Module() + setattr(mod, item, submod) + + if not isinstance(submod, torch.nn.Module): + return False + + mod = submod + + mod.add_module(field, m) + return True + + @compatibility(is_backward_compatible=True) + def delete_submodule(self, target: str) -> bool: + """ + Deletes the given submodule from ``self``. + + The module will not be deleted if ``target`` is not a valid + target. + + Args: + target: The fully-qualified string name of the new submodule + (See example in ``nn.Module.get_submodule`` for how to + specify a fully-qualified string.) + + Returns: + bool: Whether or not the target string referenced a + submodule we want to delete. A return value of ``False`` + means that the ``target`` was not a valid reference to + a submodule. + """ + atoms = target.split(".") + path, target_submod = atoms[:-1], atoms[-1] + mod: torch.nn.Module = self + + # Get the parent module + for item in path: + + if not hasattr(mod, item): + return False + + mod = getattr(mod, item) + + if not isinstance(mod, torch.nn.Module): + return False + + if not hasattr(mod, target_submod): + return False + + if not isinstance(getattr(mod, target_submod), torch.nn.Module): + return False + + delattr(mod, target_submod) + return True + + @compatibility(is_backward_compatible=True) + def delete_all_unused_submodules(self) -> None: + """ + Deletes all unused submodules from ``self``. + + A Module is considered "used" if any one of the following is + true: + 1. It has children that are used + 2. Its forward is called directly via a ``call_module`` node + 3. It has a non-Module attribute that is used from a + ``get_attr`` node + + This method can be called to clean up an ``nn.Module`` without + manually calling ``delete_submodule`` on each unused submodule. + """ + used: List[str] = [] + + for node in self.graph.nodes: + + if node.op == "call_module" or node.op == "get_attr": + + # A list of strings representing the different parts + # of the path. For example, `foo.bar.baz` gives us + # ["foo", "bar", "baz"] + fullpath = node.target.split(".") + + # If we're looking at multiple parts of a path, join + # join them with a dot. Otherwise, return that single + # element without doing anything to it. + def join_fn(x: str, y: str) -> str: + return ".".join([x, y] if y else [x]) + + # Progressively collect all the names of intermediate + # modules. For example, if we have the target + # `foo.bar.baz`, we'll add `foo`, `foo.bar`, and + # `foo.bar.baz` to the list. 
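+                # Editorial note on the call below: with fullpath = ["foo", "bar", "baz"],
+                #     itertools.accumulate(fullpath, join_fn)
+                # yields "foo", "foo.bar", "foo.bar.baz" in order.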
+ used.extend(itertools.accumulate(fullpath, join_fn)) + + # For a `call_module` node, also register all recursive submodules + # as used + if node.op == "call_module": + try: + submod = self.get_submodule(node.target) + + for submod_name, _ in submod.named_modules(): + if submod_name != "": + used.append(".".join([node.target, submod_name])) + except AttributeError: + # Node referenced nonexistent submodule, don't need to + # worry about GCing anything + pass + + to_delete = [name for name, _ in self.named_modules() if name not in used] + + for name in to_delete: + self.delete_submodule(name) + + @property + def code(self) -> str: + """ + Return the Python code generated from the ``Graph`` underlying this + ``GraphModule``. + """ + if not hasattr(self, "_code"): + raise RuntimeError( + "Code has not been generated! Please report a bug to PyTorch" + ) + return self._code + + @compatibility(is_backward_compatible=True) + def recompile(self) -> PythonCode: + """ + Recompile this GraphModule from its ``graph`` attribute. This should be + called after editing the contained ``graph``, otherwise the generated + code of this ``GraphModule`` will be out of date. + """ + if isinstance(self._graph._codegen, _PyTreeCodeGen): + self._in_spec = self._graph._codegen.pytree_info.in_spec + self._out_spec = self._graph._codegen.pytree_info.out_spec + python_code = self._graph.python_code(root_module="self") + self._code = python_code.src + self._lineno_map = python_code._lineno_map + + cls = type(self) + co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {} + cls.forward = _forward_from_src(self._code, python_code.globals, co_fields) + + # Determine whether this class explicitly defines a __call__ implementation + # to wrap. If it does, save it in order to have wrapped_call invoke it. + # If it does not, wrapped_call can use a dynamic call to super() instead. + # In most cases, super().__call__ should be torch.nn.Module.__call__. + # We do not want to hold a reference to Module.__call__ here; doing so will + # bypass patching of torch.nn.Module.__call__ done while symbolic tracing. + cls_call = cls.__call__ if "__call__" in vars(cls) else None + + if "_wrapped_call" not in vars(cls): + cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] + + def call_wrapped(self, *args, **kwargs): + return self._wrapped_call(self, *args, **kwargs) + + cls.__call__ = call_wrapped # type: ignore[method-assign] + + return python_code + + # Passing Tracer as argument allows subclasses extending fx.GraphModule + # define their own Tracer (extending fx.Tracer). 
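+    # Editorial sketch (not upstream code) of the edit-then-recompile workflow that
+    # recompile() above supports, assuming `my_module` is any traceable nn.Module:
+    #
+    #     gm = torch.fx.symbolic_trace(my_module)
+    #     for node in gm.graph.nodes:
+    #         if node.op == "call_function" and node.target is torch.add:
+    #             node.target = torch.mul          # rewrite adds into multiplies
+    #     gm.graph.lint()
+    #     gm.recompile()                           # regenerate `code` and `forward`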
+ def __reduce_deploy__(self, importer: Importer): + dict_without_graph = self.__dict__.copy() + dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__ + del dict_without_graph["_graph"] + + python_code = self.recompile() + import_block = _format_import_block(python_code.globals, importer) + return (reduce_deploy_graph_module, (dict_without_graph, import_block)) + + def __reduce_package__(self, exporter: PackageExporter): + dict_without_graph = self.__dict__.copy() + dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__ + del dict_without_graph["_graph"] + + generated_module_name = f"fx-generated._{exporter.get_unique_id()}" + python_code = self.recompile() + import_block = _format_import_block(python_code.globals, exporter.importer) + module_code = import_block + self.code + exporter.save_source_string(generated_module_name, module_code) + return ( + reduce_package_graph_module, + (dict_without_graph, generated_module_name), + ) + + def __reduce__(self): + """ + Serialization of GraphModule. We serialize only the generated code, not + the underlying ``Graph``. This is because ``Graph`` does not have on-disk + backward-compatibility guarantees, whereas Python source code does. + On the deserialization side, we symbolically trace through the generated + code to regenerate the underlying ``Graph`` + """ + dict_without_graph = self.__dict__.copy() + + python_code = self.recompile() + import_block = _format_import_block(python_code.globals, sys_importer) + del dict_without_graph["_graph"] + return (reduce_graph_module, (dict_without_graph, import_block)) + + def _deepcopy_init(self): + return GraphModule.__init__ + + # because __reduce__ is defined for serialization, + # we need to define deepcopy otherwise it will call __reduce__ + # and cause symbolic tracing to occur every time we try to copy the object + def __deepcopy__(self, memo): + res = type(self).__new__(type(self)) + memo[id(self)] = res + fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo)) + self._deepcopy_init()(res, fake_mod, fake_mod.__dict__["_graph"]) + # hooks are lost during `GraphModule.__init__`, so we need to copy over + # them explicitly, note right now we are only copying state_dict related + # hooks, to reduce bc-related issues, we can copy forward/backward related + # hooks in the future as well if needed + extra_preserved_attrs = [ + "_state_dict_hooks", + "_load_state_dict_pre_hooks", + "_load_state_dict_post_hooks", + "_replace_hook", + "_create_node_hooks", + "_erase_node_hooks" + ] + for attr in extra_preserved_attrs: + if attr in self.__dict__: + setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo)) + res.meta = copy.deepcopy(getattr(self, "meta", {}), memo) + if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta: + for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items(): + setattr(res, attr_name, attr) + return res + + def __copy__(self): + from ._lazy_graph_module import _make_graph_module + res = _make_graph_module(self, self.graph) + res.meta = getattr(self, "meta", {}) + return res + + @compatibility(is_backward_compatible=False) + def print_readable(self, print_output=True, include_stride=False, include_device=False, colored=False): + """ + Return the Python code generated for current GraphModule and its children GraphModules + """ + return _print_readable( + self, + self._get_name(), + print_output, + include_stride, + include_device, + colored, + ) + + def __str__(self) -> str: + orig_str = super().__str__() + print_readable_reminder = ( + "# To 
see more debug info, please use `graph_module.print_readable()`" + ) + return "\n".join([orig_str, self._code, print_readable_reminder]) + + def _replicate_for_data_parallel(self): + new_gm = self.__copy__() + new_gm._is_replica = True + return new_gm + + @contextlib.contextmanager + def _set_replace_hook(self, f): + """ + Takes a callable which will be called everytime when we replace a node + to a new node, or change the node's name. Callable takes three arguments: + the old node we're changing, and NAME of the new node, followed by the + user node which consumes the old node to be replaced. + """ + assert callable(f), "Replace hook must be a callable." + prev, self._replace_hook = self._replace_hook, f + try: + yield + finally: + self._replace_hook = prev + + def _register_create_node_hook(self, f): + """ + Takes a callable which will be called after we create a new node. The + callable takes the newly created node as input and returns None. + """ + assert callable(f), "create_node hook must be a callable." + self._create_node_hooks.append(f) + + def _unregister_create_node_hook(self, f): + """ + Takes a callable which was previously registered to be called after we create a node. + This function will unregister that callable so it is no longer invoked on node creation. + """ + assert callable(f), "create_node hook must be a callable." + self._create_node_hooks.remove(f) + + def _register_erase_node_hook(self, f): + """ + Takes a callable which will be called after we erase a node. The + callable takes the node that is being erased as input and returns None. + """ + assert callable(f), "erase_node hook must be a callable." + self._erase_node_hooks.append(f) + + def _unregister_erase_node_hook(self, f): + """ + Takes a callable which was previously registered to be called after we erase a node. + This function will unregister that callable so it is no longer invoked on node erasure. + """ + assert callable(f), "erase_node hook must be a callable." + self._erase_node_hooks.remove(f) + +# workarounds for issues in __torch_function__ + +# WAR for __torch_function__ not handling tensor lists, +# fix is in https://github.com/pytorch/pytorch/pull/34725 +# orig_cat = torch.cat +# def patched_cat(*args, **kwargs): +# tensors = args[0] +# for t in tensors: +# if isinstance(t, Proxy): +# return t.__torch_function__(patched_cat, (), args, kwargs) +# return orig_cat(*args, **kwargs) +# patched_cat.__module__ = 'torch' +# patched_cat.__name__ = 'cat' +# torch.cat = patched_cat diff --git a/lib/python3.10/site-packages/torch/fx/immutable_collections.py b/lib/python3.10/site-packages/torch/fx/immutable_collections.py new file mode 100644 index 0000000000000000000000000000000000000000..2ff29cba474dcb613cd529b881193e20e1325856 --- /dev/null +++ b/lib/python3.10/site-packages/torch/fx/immutable_collections.py @@ -0,0 +1,117 @@ +# mypy: allow-untyped-defs +from typing import Any, Dict, Iterable, List, Tuple + +from torch.utils._pytree import ( + _dict_flatten, + _dict_flatten_with_keys, + _dict_unflatten, + _list_flatten, + _list_flatten_with_keys, + _list_unflatten, + Context, + register_pytree_node, +) + +from ._compatibility import compatibility + + +__all__ = ["immutable_list", "immutable_dict"] + +_help_mutation = """\ +If you are attempting to modify the kwargs or args of a torch.fx.Node object, +instead create a new copy of it and assign the copy to the node: + new_args = ... 
# copy and mutate args + node.args = new_args +""" + + +def _no_mutation(self, *args, **kwargs): + raise NotImplementedError( + f"'{type(self).__name__}' object does not support mutation. {_help_mutation}", + ) + + +def _create_immutable_container(base, mutable_functions): + container = type("immutable_" + base.__name__, (base,), {}) + for attr in mutable_functions: + setattr(container, attr, _no_mutation) + return container + + +immutable_list = _create_immutable_container( + list, + ( + "__delitem__", + "__iadd__", + "__imul__", + "__setitem__", + "append", + "clear", + "extend", + "insert", + "pop", + "remove", + "reverse", + "sort", + ), +) +immutable_list.__reduce__ = lambda self: (immutable_list, (tuple(iter(self)),)) +immutable_list.__hash__ = lambda self: hash(tuple(self)) + +compatibility(is_backward_compatible=True)(immutable_list) + +immutable_dict = _create_immutable_container( + dict, + ( + "__delitem__", + "__ior__", + "__setitem__", + "clear", + "pop", + "popitem", + "setdefault", + "update", + ), +) +immutable_dict.__reduce__ = lambda self: (immutable_dict, (iter(self.items()),)) +immutable_dict.__hash__ = lambda self: hash(tuple(self.items())) +compatibility(is_backward_compatible=True)(immutable_dict) + + +# Register immutable collections for PyTree operations +def _immutable_dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: + return _dict_flatten(d) + + +def _immutable_dict_unflatten( + values: Iterable[Any], + context: Context, +) -> Dict[Any, Any]: + return immutable_dict(_dict_unflatten(values, context)) + + +def _immutable_list_flatten(d: List[Any]) -> Tuple[List[Any], Context]: + return _list_flatten(d) + + +def _immutable_list_unflatten( + values: Iterable[Any], + context: Context, +) -> List[Any]: + return immutable_list(_list_unflatten(values, context)) + + +register_pytree_node( + immutable_dict, + _immutable_dict_flatten, + _immutable_dict_unflatten, + serialized_type_name="torch.fx.immutable_collections.immutable_dict", + flatten_with_keys_fn=_dict_flatten_with_keys, +) +register_pytree_node( + immutable_list, + _immutable_list_flatten, + _immutable_list_unflatten, + serialized_type_name="torch.fx.immutable_collections.immutable_list", + flatten_with_keys_fn=_list_flatten_with_keys, +) diff --git a/lib/python3.10/site-packages/torch/fx/interpreter.py b/lib/python3.10/site-packages/torch/fx/interpreter.py new file mode 100644 index 0000000000000000000000000000000000000000..c75407583137d30e377a6b4e8c734858b7b5063b --- /dev/null +++ b/lib/python3.10/site-packages/torch/fx/interpreter.py @@ -0,0 +1,520 @@ +# mypy: allow-untyped-defs +from .graph_module import GraphModule +from ._lazy_graph_module import _make_graph_module +from .graph import Graph +from .node import Argument, Node, Target, map_arg, map_aggregate +from .proxy import Proxy +from ._symbolic_trace import Tracer +from ._compatibility import compatibility +from . import config +import torch.fx.traceback as fx_traceback +import torch +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +import inspect +from contextlib import contextmanager +from torch.hub import tqdm + +__all__ = ['Interpreter', 'Transformer'] + +@compatibility(is_backward_compatible=True) +class Interpreter: + """ + An Interpreter executes an FX graph Node-by-Node. This pattern + can be useful for many things, including writing code + transformations as well as analysis passes. + + Methods in the Interpreter class can be overridden to customize + the behavior of execution. 
The map of overrideable methods + in terms of call hierarchy:: + + run() + +-- run_node + +-- placeholder() + +-- get_attr() + +-- call_function() + +-- call_method() + +-- call_module() + +-- output() + + Example: + + Suppose we want to swap all instances of ``torch.neg`` with + ``torch.sigmoid`` and vice versa (including their ``Tensor`` + method equivalents). We could subclass Interpreter like so:: + + class NegSigmSwapInterpreter(Interpreter): + def call_function(self, target : Target, + args : Tuple, kwargs : Dict) -> Any: + if target == torch.sigmoid: + return torch.neg(*args, **kwargs) + return super().call_function(n) + + def call_method(self, target : Target, + args : Tuple, kwargs : Dict) -> Any: + if target == 'neg': + call_self, *args_tail = args + return call_self.sigmoid(*args_tail, **kwargs) + return super().call_method(n) + + def fn(x): + return torch.sigmoid(x).neg() + + gm = torch.fx.symbolic_trace(fn) + input = torch.randn(3, 4) + result = NegSigmSwapInterpreter(gm).run(input) + torch.testing.assert_close(result, torch.neg(input).sigmoid()) + + Args: + module (torch.nn.Module): The module to be executed + garbage_collect_values (bool): Whether to delete values after their last + use within the Module's execution. This ensures optimal memory usage during + execution. This can be disabled to, for example, examine all of the intermediate + values in the execution by looking at the ``Interpreter.env`` attribute. + graph (Optional[Graph]): If passed, the interpreter will execute this + graph instead of `module.graph`, using the provided `module` + argument to satisfy any requests for state. + """ + @compatibility(is_backward_compatible=True) + def __init__(self, module: torch.nn.Module, garbage_collect_values: bool = True, graph: Optional[Graph] = None): + self.module = module + self.submodules = dict(self.module.named_modules()) + if graph is not None: + self.graph = graph + else: + self.graph = self.module.graph + self.env : Dict[Node, Any] = {} + self.name = "Interpreter" + self.garbage_collect_values = garbage_collect_values + self.extra_traceback = True + + if self.garbage_collect_values: + # Run through reverse nodes and record the first instance of a use + # of a given node. This represents the *last* use of the node in the + # execution order of the program, which we will use to free unused + # values + node_to_last_use : Dict[Node, Node] = {} + self.user_to_last_uses : Dict[Node, List[Node]] = {} + + def register_last_uses(n : Node, user : Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + self.user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(self.graph.nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + + @compatibility(is_backward_compatible=True) + def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -> Any: + """ + Run `module` via interpretation and return the result. + + Args: + *args: The arguments to the Module to run, in positional order + initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution. + This is a dict mapping `Node` to any value. This can be used, for example, to + pre-populate results for certain `Nodes` so as to do only partial evaluation within + the interpreter. + enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and + process_outputs function first before using them. 
+ + Returns: + Any: The value returned from executing the Module + """ + self.env = initial_env if initial_env is not None else {} + + # Positional function args are consumed left-to-right by + # `placeholder` nodes. Use an iterator to keep track of + # position and extract those values. + if enable_io_processing: + args = self.graph.process_inputs(*args) + self.args_iter : Iterator[Any] = iter(args) + pbar = tqdm(total=len(self.graph.nodes), + desc=f"{self.name}: {str(list(self.graph.nodes)) if config.verbose_progress else ''}", + initial=0, position=0, leave=True, disable=config.disable_progress, delay=0) + + for node in self.graph.nodes: + pbar.update(1) + if node in self.env: + # Short circuit if we have this value. This could + # be used, for example, for partial evaluation + # where the caller has pre-populated `env` with + # values for a subset of the program. + continue + + try: + self.env[node] = self.run_node(node) + except Exception as e: + if self.extra_traceback: + msg = f"While executing {node.format_node()}" + msg = f'{e.args[0]}\n\n{msg}' if e.args else str(msg) + msg += f"\nOriginal traceback:\n{node.stack_trace}" + e.args = (msg,) + e.args[1:] + if isinstance(e, KeyError): + raise RuntimeError(*e.args) from e + raise + + if self.garbage_collect_values: + for to_delete in self.user_to_last_uses.get(node, []): + del self.env[to_delete] + + if node.op == 'output': + output_val = self.env[node] + return self.graph.process_outputs(output_val) if enable_io_processing else output_val + + @compatibility(is_backward_compatible=True) + def boxed_run(self, args_list): + """ + Run `module` via interpretation and return the result. This uses the "boxed" + calling convention, where you pass a list of arguments, which will be cleared + by the interpreter. This ensures that input tensors are promptly deallocated. + """ + args_iter = iter(args_list) + env = {} + for n in self.graph.nodes: + if n.op == "placeholder": + env[n] = next(args_iter) + args_list.clear() + return self.run(initial_env=env) + + @contextmanager + def _set_current_node(self, node): + with fx_traceback.set_current_meta(node): + yield + + @compatibility(is_backward_compatible=True) + def run_node(self, n : Node) -> Any: + """ + Run a specific node ``n`` and return the result. + Calls into placeholder, get_attr, call_function, + call_method, call_module, or output depending + on ``node.op`` + + Args: + n (Node): The Node to execute + + Returns: + Any: The result of executing ``n`` + """ + with self._set_current_node(n): + args, kwargs = self.fetch_args_kwargs_from_env(n) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + return getattr(self, n.op)(n.target, args, kwargs) + + # Main Node running APIs + @compatibility(is_backward_compatible=True) + def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + """ + Execute a ``placeholder`` node. Note that this is stateful: + ``Interpreter`` maintains an internal iterator over + arguments passed to ``run`` and this method returns + next() on that iterator. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Returns: + Any: The argument value that was retrieved. + """ + assert isinstance(target, str) + if target.startswith('*'): + # For a starred parameter e.g. `*args`, retrieve all + # remaining values from the args list. 
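+            # e.g. for a traced `def forward(self, *args)`, this single starred
+            # placeholder consumes every remaining positional input passed to run()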
+ return list(self.args_iter) + else: + try: + return next(self.args_iter) + except StopIteration as si: + if len(args) > 0: + return args[0] + else: + raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') from si + + @compatibility(is_backward_compatible=True) + def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + """ + Execute a ``get_attr`` node. Will retrieve an attribute + value from the ``Module`` hierarchy of ``self.module``. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return: + Any: The value of the attribute that was retrieved + """ + assert isinstance(target, str) + return self.fetch_attr(target) + + @compatibility(is_backward_compatible=True) + def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + """ + Execute a ``call_function`` node and return the result. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + Any: The value returned by the function invocation + """ + assert not isinstance(target, str) + + # Execute the function and return the result + return target(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + """ + Execute a ``call_method`` node and return the result. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + Any: The value returned by the method invocation + """ + # args[0] is the `self` object for this method call + self_obj, *args_tail = args + + # Execute the method and return the result + assert isinstance(target, str) + return getattr(self_obj, target)(*args_tail, **kwargs) + + @compatibility(is_backward_compatible=True) + def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + """ + Execute a ``call_module`` node and return the result. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + Any: The value returned by the module invocation + """ + # Retrieve executed args and kwargs values from the environment + + # Execute the method and return the result + assert isinstance(target, str) + submod = self.fetch_attr(target) + + return submod(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + """ + Execute an ``output`` node. This really just retrieves + the value referenced by the ``output`` node and returns it. + + Args: + target (Target): The call target for this node. 
See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return: + Any: The return value referenced by the output node + """ + return args[0] + + # Helper methods + @compatibility(is_backward_compatible=True) + def fetch_attr(self, target : str): + """ + Fetch an attribute from the ``Module`` hierarchy of ``self.module``. + + Args: + target (str): The fully-qualified name of the attribute to fetch + + Return: + Any: The value of the attribute. + """ + target_atoms = target.split('.') + attr_itr = self.module + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i+1])}") + attr_itr = getattr(attr_itr, atom) + return attr_itr + + @compatibility(is_backward_compatible=True) + def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]: + """ + Fetch the concrete values of ``args`` and ``kwargs`` of node ``n`` + from the current execution environment. + + Args: + n (Node): The node for which ``args`` and ``kwargs`` should be fetched. + + Return: + Tuple[Tuple, Dict]: ``args`` and ``kwargs`` with concrete values for ``n``. + """ + args = self.map_nodes_to_values(n.args, n) + assert isinstance(args, tuple) + kwargs = self.map_nodes_to_values(n.kwargs, n) + assert isinstance(kwargs, dict) + return args, kwargs + + @compatibility(is_backward_compatible=True) + def map_nodes_to_values(self, args : Argument, n : Node) -> Argument: + """ + Recursively descend through ``args`` and look up the concrete value + for each ``Node`` in the current execution environment. + + Args: + args (Argument): Data structure within which to look up concrete values + + n (Node): Node to which ``args`` belongs. This is only used for error reporting. + """ + def load_arg(n_arg : Node) -> Any: + if n_arg not in self.env: + raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() ' + f'to diagnose such issues') + return self.env[n_arg] + return map_arg(args, load_arg) + +@compatibility(is_backward_compatible=True) +class Transformer(Interpreter): + """ + ``Transformer`` is a special type of interpreter that produces a + new ``Module``. It exposes a ``transform()`` method that returns + the transformed ``Module``. ``Transformer`` does not require + arguments to run, as ``Interpreter`` does. ``Transformer`` works + entirely symbolically. + + Example: + + Suppose we want to swap all instances of ``torch.neg`` with + ``torch.sigmoid`` and vice versa (including their ``Tensor`` + method equivalents). We could subclass ``Transformer`` like so:: + + class NegSigmSwapXformer(Transformer): + def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + if target == torch.sigmoid: + return torch.neg(*args, **kwargs) + return super().call_function(n) + + def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + if target == 'neg': + call_self, *args_tail = args + return call_self.sigmoid(*args_tail, **kwargs) + return super().call_method(n) + + def fn(x): + return torch.sigmoid(x).neg() + + gm = torch.fx.symbolic_trace(fn) + + transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform() + input = torch.randn(3, 4) + torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid()) + + Args: + module (GraphModule): The ``Module`` to be transformed. 
+ """ + + @compatibility(is_backward_compatible=True) + def __init__(self, module): + super().__init__(module) + self.new_graph = Graph() + self.new_graph.set_codegen(module.graph._codegen) + + class TransformerTracer(Tracer): + def __init__(self, graph: Graph): + super().__init__() + self.graph = graph + self.tensor_attrs: Dict[torch.Tensor, str] = {} # type: ignore[assignment] + + def is_leaf_module(self, _, __) -> bool: + return True + + self.tracer = TransformerTracer(self.new_graph) + self.tracer.root = module + + @compatibility(is_backward_compatible=True) + def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy: + """ + Execute a ``placeholder`` node. In ``Transformer``, this is + overridden to insert a new ``placeholder`` into the output + graph. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + """ + assert isinstance(target, str) + default_value = next(iter(args)) if args else inspect.Signature.empty + return Proxy(self.new_graph.placeholder(target, default_value=default_value), self.tracer) + + @compatibility(is_backward_compatible=True) + def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy: + """ + Execute a ``get_attr`` node. In ``Transformer``, this is + overridden to insert a new ``get_attr`` node into the output + graph. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + """ + assert isinstance(target, str) + return self.tracer.create_proxy("get_attr", target, args, kwargs) + + @compatibility(is_backward_compatible=True) + def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + # Override so that the leaf module policy from `self.tracer` is respected. + assert isinstance(target, str) + submod = self.fetch_attr(target) + return self.tracer.call_module(submod, submod.forward, args, kwargs) + + @compatibility(is_backward_compatible=True) + def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + # Override so that functions that were wrapped are still wrapped. + return self.tracer.create_proxy('call_function', target, args, kwargs) + + @compatibility(is_backward_compatible=True) + def transform(self) -> GraphModule: + """ + Transform ``self.module`` and return the transformed + ``GraphModule``. 
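+        For example (a usage sketch building on the class docstring above),
+        ``NegSigmSwapXformer(gm).transform()`` returns a new ``GraphModule`` with the
+        swapped calls while leaving the original ``gm`` unmodified.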
+ """ + with fx_traceback.preserve_node_meta(): + result = super().run(enable_io_processing=False) + if result is not None: + def strip_proxy(a : Union[Argument, Proxy]) -> Any: + return a.node if isinstance(a, Proxy) else a + new_output_node = self.new_graph.output(map_aggregate(result, strip_proxy)) + # also preserve the metadata from the old output node, if it exists + old_output_node = list(self.graph.nodes)[-1] + assert old_output_node.op == "output" + for k, v in old_output_node.meta.items(): + new_output_node.meta[k] = v + + + return _make_graph_module(self.module, self.new_graph) diff --git a/lib/python3.10/site-packages/torch/fx/node.py b/lib/python3.10/site-packages/torch/fx/node.py new file mode 100644 index 0000000000000000000000000000000000000000..f84b23e29ddf855f799b3de6ce1bd4ccf3da3dd4 --- /dev/null +++ b/lib/python3.10/site-packages/torch/fx/node.py @@ -0,0 +1,788 @@ +# Nodes represent a definition of a value in our graph of operators. +from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set +from ._compatibility import compatibility +from .immutable_collections import immutable_dict, immutable_list +import torch +import builtins +import types +import inspect +import warnings +from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair +from .._ops import ops as _ops +from torch._C import _NodeBase + +if TYPE_CHECKING: + from .graph import Graph + +__all__ = ['Node', 'map_arg', 'map_aggregate', "has_side_effect"] + +BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype, + torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, + torch.SymInt, torch.SymBool, torch.SymFloat] +base_types = BaseArgumentTypes.__args__ # type: ignore[attr-defined] + +Target = Union[Callable[..., Any], str] + +Argument = Optional[Union[ + Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types + List[Any], # actually Argument + Dict[str, Any], # actually Argument + slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing + range, + 'Node', + BaseArgumentTypes +]] + +_legal_ops = dict.fromkeys(['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root']) + +_side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = { + torch._C._set_grad_enabled, + torch.amp._enter_autocast, + torch.amp._exit_autocast, +} + +# TODO: Either refactor this into 2 functions 1 dce for functional graphs and 1 dce for all graphs, +# or add logic to correctly mark all inplace ops as side effectful. 
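+# Illustrative note: the set defined just below feeds ``Node.is_impure()``, so calls
+# to these functions are treated as impure and kept by dead-code elimination. A user
+# callable can be registered with the ``has_side_effect`` decorator defined further
+# down, e.g. (a minimal sketch; ``log_shape`` is a hypothetical user function):
+#
+#     @has_side_effect
+#     def log_shape(t):
+#         print(t.shape)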
+_side_effectful_functions: Set[Callable] = { + torch._assert, + torch._assert_async, + _ops.aten._assert_async.msg, + _ops.aten._assert_scalar.default, + _ops.aten.sym_constrain_range.default, + _ops.aten.sym_constrain_range_for_size.default, + _ops.profiler._record_function_enter, + _ops.profiler._record_function_enter_new, + _ops.profiler._record_function_exit, + _ops.inductor.accumulate_grad_.default, +} | _side_effectful_need_to_be_preserved_pre_dispatch +if hasattr(_ops.inductor, "resize_storage_bytes_"): + _side_effectful_functions.add(_ops.inductor.resize_storage_bytes_.default) + + +@compatibility(is_backward_compatible=False) +def has_side_effect(fn: Callable) -> Callable: + _side_effectful_functions.add(fn) + return fn + + +# this is fixed on master, WAR for 1.5 +def _find_module_of_method(orig_method: Callable[..., Any]) -> str: + name = orig_method.__name__ + module = orig_method.__module__ + if module is not None: + return module + for guess in [torch, torch.nn.functional]: + if getattr(guess, name, None) is orig_method: + return guess.__name__ + raise RuntimeError(f'cannot find module for {orig_method}') + +# Borrowed from CPython typing module +# https://github.com/python/cpython/blob/f90dc36c15d7fee0efaf6d39e97be0bdf2683e93/Lib/typing.py#L156 +def _type_repr(obj: object) -> str: + """Return the repr() of an object, special-casing types (internal helper). + If obj is a type, we return a shorter version than the default + type.__repr__, based on the module and qualified name, which is + typically enough to uniquely identify a type. For everything + else, we fall back on repr(obj). + """ + if isinstance(obj, type): + if obj.__module__ == 'builtins': + return obj.__qualname__ + return f'{obj.__module__}.{obj.__qualname__}' + if obj is ...: + return '...' 
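+    # Illustrative examples of the branches above: _type_repr(int) == 'int' and
+    # _type_repr(torch.Tensor) == 'torch.Tensor'.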
+ if isinstance(obj, types.FunctionType): + return obj.__name__ + return repr(obj) + +def _get_qualified_name(func: Callable[..., Any]) -> str: + # things like getattr just appear in builtins + if getattr(builtins, func.__name__, None) is func: + return func.__name__ + # torch.Tensor.{fn} + if (isinstance(func, (types.MethodDescriptorType, types.WrapperDescriptorType)) + and func is getattr(torch.Tensor, func.__name__, None)): + return f"torch.Tensor.{func.__name__}" + name = func.__name__ + if name == "": + # For lambdas, try to get their defining name in the module + try: + name = inspect.getsource(func).split("=")[0].strip() + except Exception as e: + raise RuntimeError("Unable to represent lambda") from e + module = _find_module_of_method(func) + module = module.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module + # Fixup segment_reduce mismatch + if module == "torch" and name == "segment_reduce": + name = "_" + name + return f'{module}.{name}' + +def _format_arg(arg: object, max_list_len: float = float('inf')) -> str: + if hasattr(arg, '_custom_fx_repr_fn'): + return arg._custom_fx_repr_fn() + elif isinstance(arg, list): + items = ', '.join(_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len) + maybe_len = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]' + return f'[{items}{maybe_len}]' + elif isinstance(arg, tuple): + items = ', '.join(_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len) + maybe_len = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]' + maybe_comma = ',' if len(arg) == 1 else '' + return f'({items}{maybe_comma}{maybe_len})' + elif isinstance(arg, dict): + items_str = ', '.join(f'{k}: {_format_arg(v)}' for k, v in arg.items()) + return f'{{{items_str}}}' + + if isinstance(arg, Node): + return '%' + str(arg) + else: + return str(arg) + +@compatibility(is_backward_compatible=True) +class Node(_NodeBase): + """ + ``Node`` is the data structure that represents individual operations within + a ``Graph``. For the most part, Nodes represent callsites to various entities, + such as operators, methods, and Modules (some exceptions include nodes that + specify function inputs and outputs). Each ``Node`` has a function specified + by its ``op`` property. The ``Node`` semantics for each value of ``op`` are as follows: + + - ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on. + ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument + denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to + the function parameters (e.g. ``x``) in the graph printout. + - ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the + fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy. + ``args`` and ``kwargs`` are don't-care + - ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign + to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function, + following the Python calling convention + - ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is + as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call. 
+ ``args`` and ``kwargs`` represent the arguments to invoke the module on, *excluding the self argument*. + - ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method + to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on, + *including the self argument* + - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement + in the Graph printout. + """ + _args: Tuple['Argument', ...] + _kwargs: Dict[str, 'Argument'] + + @compatibility(is_backward_compatible=True) + def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', + args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'], + return_type : Optional[Any] = None) -> None: + """ + Instantiate an instance of ``Node``. Note: most often, you want to use the + Graph APIs, i.e. ``Graph.call_module``, ``Graph.call_method``, etc. rather + than instantiating a ``Node`` directly. + + Args: + graph (Graph): The ``Graph`` to which this ``Node`` should belong. + + name (str): The name to which the output of this ``Node`` should be assigned + + op (str): The opcode for this ``Node``. Can be one of 'placeholder', + 'call_method', 'call_module', 'call_function', 'get_attr', + 'output' + + target ('Target'): The target this op should call. See the broader + ``Node`` docstring for more details. + + args (Tuple['Argument']): The args to be passed to ``target`` + + kwargs (Dict[str, 'Argument']): The kwargs to be passed to ``target`` + + return_type (Optional[Any]): The python type expression representing the + type of the output of this node. This field can be used for + annotation of values in the generated code or for other types + of analyses. + """ + super().__init__() + self.graph = graph + self.name = name # unique name of value being created + assert op in _legal_ops + self.op = op # the kind of operation = placeholder|call_method|call_module|call_function|get_attr + if op == 'call_function': + if not callable(target): + raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' + 'but a Callable is expected') + else: + if not isinstance(target, str): + raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' + 'but a str is expected') + self.target = target # for method/module/function, the name of the method/module/function/attr + # being invoked, e.g add, layer1, or torch.add + + # All `Node`-valued inputs. Key is the Node, value is don't-care. + # The public API for this is `all_input_nodes`, this private attribute + # should not be accessed directly. + self._input_nodes : Dict[Node, None] = {} + self.__update_args_kwargs(args, kwargs) + + # All of the nodes that use the value produced by this Node + # Note one user may correspond to several uses, e.g. the node fo ``x + x`` + # would appear once here, but represents two uses. + # + # Is a dict to act as an "ordered set". Keys are significant, value dont-care + self.users : Dict[Node, None] = {} + # Type expression representing the output value of this node. + # This should contain the same class of Type objects that would appear + # as type annotations for function inputs/outputs. + # + # For placeholder nodes, this value will be used to type-annotate the + # generated function parameters. + # For the return node, this value will be used to type-annotate the + # generated function return type. 
(Note this is a special case. ``return`` + # does not produce a value, it's more of a notation. Thus, this value + # describes the type of args[0] in the ``return`` node. + self.type : Optional[Any] = return_type + self._sort_key: Any = () + + # If set, use this fn to print this node + self._repr_fn : Optional[Callable[[Node], str]] = None + + # Dictionary to store metadata passes need to do their + # transformations. This metadata is preserved across node copies + self.meta : Dict[str, Any] = {} + + def __getstate__(self) -> Dict[str, Any]: + state = self.__dict__.copy() + state["_erased"] = self._erased + state["_prev"] = self._prev + state["_next"] = self._next + return state + + def __setstate__(self, state: Dict[str, Any]) -> None: + _erased = state.pop("_erased") + _prev = state.pop("_prev") + _next = state.pop("_next") + self.__dict__.update(state) + self._erased = _erased + self._prev = _prev + self._next = _next + + @property + def next(self) -> 'Node': + """ + Returns the next ``Node`` in the linked list of Nodes. + + Returns: + + The next ``Node`` in the linked list of Nodes. + """ + return self._next + + @property + def prev(self) -> 'Node': + """ + Returns the previous ``Node`` in the linked list of Nodes. + + Returns: + + The previous ``Node`` in the linked list of Nodes. + """ + return self._prev + + @compatibility(is_backward_compatible=True) + def prepend(self, x: 'Node') -> None: + """ + Insert x before this node in the list of nodes in the graph. Example:: + + Before: p -> self + bx -> x -> ax + After: p -> x -> self + bx -> ax + + Args: + x (Node): The node to put before this node. Must be a member of the same graph. + """ + assert self.graph == x.graph, "Attempting to move a Node into a different Graph" + if self == x: + warnings.warn("Trying to prepend a node to itself. This behavior has no effect on the graph.") + return + x._remove_from_list() + p = self._prev + p._next, x._prev = x, p + x._next, self._prev = self, x + + # compute x._sort_key + psk = x._prev._sort_key + nsk = x._next._sort_key + if len(psk) > len(nsk): + idx: int + *prefix, idx = psk[:len(nsk) + 1] + x._sort_key = (*prefix, idx + 1) + elif len(psk) < len(nsk): + *prefix, idx = nsk[:len(psk) + 1] + x._sort_key = (*prefix, idx - 1) + else: # same length, increase length by 1 + x._sort_key = (*psk, 0) + + def __gt__(self, other: 'Node') -> bool: + return self._sort_key > other._sort_key + + def __lt__(self, other: 'Node') -> bool: + return self._sort_key < other._sort_key + + def __ge__(self, other: 'Node') -> bool: + return self > other or self == other + + def __le__(self, other: 'Node') -> bool: + return self < other or self == other + + @compatibility(is_backward_compatible=True) + def append(self, x: 'Node') -> None: + """ + Insert ``x`` after this node in the list of nodes in the graph. + Equivalent to ``self.next.prepend(x)`` + + Args: + x (Node): The node to put after this node. Must be a member of the same graph. + """ + self._next.prepend(x) + + def _remove_from_list(self) -> None: + p, n = self._prev, self._next + p._next, n._prev = n, p + + @property + def args(self) -> Tuple[Argument, ...]: + """ + The tuple of arguments to this ``Node``. The interpretation of arguments + depends on the node's opcode. See the :class:`Node` docstring for more + information. + + Assignment to this property is allowed. All accounting of uses and users + is updated automatically on assignment. 
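+        For example (a usage sketch), ``node.args = (node.args[0], 2)`` rewrites the
+        second positional argument; the setter refreshes the use/user bookkeeping.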
+ """ + return self._args + + @args.setter + def args(self, a : Tuple[Argument, ...]) -> None: + """ + Set the tuple of arguments to this Node. The interpretation of arguments + depends on the node's opcode. See the ``fx.Graph`` docstring for more + information. + """ + # DO NOT CALL `__update_args_kwargs` directly. The correct way to + # set `args` is via direct assignment, i.e. `node.args = new_args` + self.__update_args_kwargs(a, self._kwargs) + + @property + def kwargs(self) -> Dict[str, Argument]: + """ + The dict of keyword arguments to this ``Node``. The interpretation of arguments + depends on the node's opcode. See the :class:`Node` docstring for more + information. + + Assignment to this property is allowed. All accounting of uses and users + is updated automatically on assignment. + """ + return self._kwargs + + @kwargs.setter + def kwargs(self, k : Dict[str, Argument]) -> None: + """ + Set the dict of kwargs to this Node. The interpretation of arguments + depends on the node's opcode. See the ``fx.Graph`` docstring for more + information. + """ + # DO NOT CALL `__update_args_kwargs` directly. The correct way to + # set `args` is via direct assignment, i.e. `node.kwargs = new_kwargs` + self.__update_args_kwargs(self._args, k) + + @property + def all_input_nodes(self) -> List['Node']: + """ + Return all Nodes that are inputs to this Node. This is equivalent to + iterating over ``args`` and ``kwargs`` and only collecting the values that + are Nodes. + + Returns: + + List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this + ``Node``, in that order. + """ + return list(self._input_nodes.keys()) + + @compatibility(is_backward_compatible=True) + def update_arg(self, idx : int, arg : Argument) -> None: + """ + Update an existing positional argument to contain the new value + ``arg``. After calling, ``self.args[idx] == arg``. + + Args: + + idx (int): The index into ``self.args`` of the element to update + arg (Argument): The new argument value to write into ``args`` + """ + args = list(self.args) + args[idx] = arg + self.args = tuple(args) + + @compatibility(is_backward_compatible=True) + def insert_arg(self, idx : int, arg : Argument) -> None: + """ + Insert an positional argument to the argument list with given index. + + Args: + + idx (int): The index of the element in ``self.args`` to be inserted before. + arg (Argument): The new argument value to insert into ``args`` + """ + assert 0 <= idx <= len(self.args), "insert_args index must be between 0 and len(self.args)" + args_left = self.args[:idx] + args_right = self.args[idx:] + + self._args = args_left + (arg,) + args_right + + _new_input_nodes: Dict[Node, None] = {} + map_arg(arg, _new_input_nodes.setdefault) + + for new_use in _new_input_nodes.keys(): + if new_use not in self._input_nodes: + self._input_nodes.setdefault(new_use) + new_use.users.setdefault(self) + + @compatibility(is_backward_compatible=True) + def update_kwarg(self, key : str, arg : Argument) -> None: + """ + Update an existing keyword argument to contain the new value + ``arg``. After calling, ``self.kwargs[key] == arg``. + + Args: + + key (str): The key in ``self.kwargs`` of the element to update + arg (Argument): The new argument value to write into ``kwargs`` + """ + self.kwargs = {**self.kwargs, key: arg} + + @property + def stack_trace(self) -> Optional[str]: + """ + Return the Python stack trace that was recorded during tracing, if any. + When traced with fx.Tracer, this property is usually populated by + `Tracer.create_proxy`. 
To record stack traces during tracing for debug purposes, + set `record_stack_traces = True` on the `Tracer` instance. + When traced with dynamo, this property will be populated by default by + `OutputGraph.create_proxy`. + + stack_trace would have the innermost frame at the end of the string. + """ + return self.meta.get("stack_trace", None) + + @stack_trace.setter + def stack_trace(self, trace : Optional[str]) -> None: + self.meta["stack_trace"] = trace + + def __update_args_kwargs(self, new_args : Tuple['Argument', ...], new_kwargs : Dict[str, 'Argument']) -> None: + """ + This API is internal. Do *not* call it directly. + """ + def update_users_and_input_nodes(n: Any) -> Any: + if isinstance(n, Node): + self._input_nodes.setdefault(n) + n.users.setdefault(self) + return n + + # Clear prior users and input_nodes + for old_use in self._input_nodes.keys(): + old_use.users.pop(self) + self._input_nodes = {} + + # We do three things in a single pass of the args + # - Normalize list->immutable_list, dict->immutable_dict, etc + # - Populate self._input_nodes + # - Populate arg.users[self] for each arg + self._args = map_aggregate(new_args, update_users_and_input_nodes) # type: ignore[assignment] + self._kwargs = map_aggregate(new_kwargs, update_users_and_input_nodes) # type: ignore[assignment] + + def __repr__(self) -> str: + if self._repr_fn: + return self._repr_fn(self) + return self.name + + def _pretty_print_target(self, target: object) -> str: + """ + Make target printouts more user-friendly. + 1) builtins will be printed as `builtins.xyz` + 2) operators will be printed as `operator.xyz` + 3) other callables will be printed with qualified name, e.g. torch.add + """ + if isinstance(target, str): + return target + if hasattr(target, '__module__'): + name = getattr(target, '__name__', None) + if name is None: + # Just to be defensive, if we don't have `__name__`, get the + # qualname. Not sure if this happens for any members of `operator` + # or `builtins`. This fallback path is not as good, since e.g. + # things in `operator` have `_operator` as their __module__. + # TODO: THIS IS BROKEN: _get_qualified_name calls `__name__` + return _get_qualified_name(target) # type: ignore[arg-type] + if target.__module__ == 'builtins': + return f'builtins.{name}' + elif target.__module__ == '_operator': + return f'operator.{name}' + return _get_qualified_name(target) # type: ignore[arg-type] + + @compatibility(is_backward_compatible=True) + def format_node(self, + placeholder_names: Optional[List[str]] = None, + maybe_return_typename: Optional[List[str]] = None) -> Optional[str]: + """ + Return a descriptive string representation of ``self``. + + This method can be used with no arguments as a debugging + utility. + + This function is also used internally in the ``__str__`` method + of ``Graph``. Together, the strings in ``placeholder_names`` + and ``maybe_return_typename`` make up the signature of the + autogenerated ``forward`` function in this Graph's surrounding + GraphModule. ``placeholder_names`` and ``maybe_return_typename`` + should not be used otherwise. + + Args: + placeholder_names: A list that will store formatted strings + representing the placeholders in the generated + ``forward`` function. Internal use only. + maybe_return_typename: A single-element list that will store + a formatted string representing the output of the + generated ``forward`` function. Internal use only. 
+ + Returns: + str: If 1) we're using ``format_node`` as an internal helper + in the ``__str__`` method of ``Graph``, and 2) ``self`` + is a placeholder Node, return ``None``. Otherwise, + return a descriptive string representation of the + current Node. + """ + if self.op == 'placeholder': + assert isinstance(self.target, str) + arg_str = self.target + arg_str += arg_str + f': {_type_repr(self.type)}' if self.type else '' + if placeholder_names: + placeholder_names.append(arg_str) + return None + maybe_typename = f'{_type_repr(self.type)} ' if self.type else '' + default_val = '(default=' + str(self.args[0]) + ')' if self.args else '' + return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = {self.op}[target={self.target}]{default_val}' + elif self.op == 'get_attr': + maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else '' + return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \ + f'{self.op}[target={self._pretty_print_target(self.target)}]' + elif self.op == 'output': + if self.type and maybe_return_typename: + maybe_return_typename[0] = f' -> {_type_repr(self.type)}' + return f'return {self.args[0]}' + else: + maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else '' + return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \ + f'{self.op}[target={self._pretty_print_target(self.target)}](' \ + f'args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})' + + @compatibility(is_backward_compatible=True) + def replace_all_uses_with(self, + replace_with: 'Node', + delete_user_cb: Callable[['Node'], bool] = lambda user: True, + *, + propagate_meta: bool = False + ) -> List['Node']: + """ + Replace all uses of ``self`` in the Graph with the Node ``replace_with``. + + Args: + + replace_with (Node): The node to replace all uses of ``self`` with. + delete_user_cb (Callable): Callback that is called to determine + whether a given user of the self node should be removed. + propagate_meta (bool): Whether or not to copy all properties + on the .meta field of the original node onto the replacement node. + For safety, this is only valid to do if the replacement node + doesn't already have an existing .meta field. + + Returns: + + The list of Nodes on which this change was made. + """ + if propagate_meta: + assert len(replace_with.meta) == 0, \ + 'Called node.replace_all_uses_with(replace_with, propagate_meta=True), ' \ + 'but replace_with already has .meta keys' + for k, v in self.meta.items(): + replace_with.meta[k] = v + to_process = list(self.users) + skipped = [] + m = self.graph.owning_module + for use_node in to_process: + if not delete_user_cb(use_node): + skipped.append(use_node) + continue + + def maybe_replace_node(n : Node) -> Node: + if n == self: + return replace_with + else: + return n + + if getattr(m, "_replace_hook", None): + m._replace_hook(old=self, new=replace_with.name, user=use_node) + + new_args = map_arg(use_node.args, maybe_replace_node) + new_kwargs = map_arg(use_node.kwargs, maybe_replace_node) + assert isinstance(new_args, tuple) + assert isinstance(new_kwargs, dict) + use_node.__update_args_kwargs(new_args, new_kwargs) + + assert len(self.users) - len(skipped) == 0 + return [n for n in to_process if n not in skipped] + + @compatibility(is_backward_compatible=False) + def is_impure(self) -> bool: + """ + Returns whether this op is impure, i.e. if its op is a placeholder or + output, or if a call_function or call_module which is impure. 
+ + Returns: + + bool: If the op is impure or not. + """ + if self.op in {"placeholder", "output"}: + return True + + # Check if an impure function based on schema. + if self.op == "call_function": + schema = getattr(self.target, "_schema", None) + schema_mutable = schema is not None and schema.is_mutable + return schema_mutable or self.target in _side_effectful_functions + + # Check if an impure module. + if self.op == "call_module": + assert ( + self.graph.owning_module is not None + ), "self.graph.owning_module not set for purity check" + target_mod = self.graph.owning_module.get_submodule(self.target) + assert ( + target_mod is not None + ), f"Did not find expected submodule target {self.target}" + return getattr(target_mod, "_is_impure", False) + + return False + + @compatibility(is_backward_compatible=False) + def normalized_arguments( + self, root : torch.nn.Module, arg_types : Optional[Tuple[Any]] = None, + kwarg_types : Optional[Dict[str, Any]] = None, + normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: + """ + Returns normalized arguments to Python targets. This means that + `args/kwargs` will be matched up to the module/functional's + signature and return exclusively kwargs in positional order + if `normalize_to_only_use_kwargs` is true. + Also populates default values. Does not support positional-only + parameters or varargs parameters. + + Supports module calls. + + May require `arg_types` and `kwarg_types` in order to disambiguate overloads. + + Args: + root (torch.nn.Module): Module upon which to resolve module targets. + arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args + kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs + normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. + + Returns: + + Returns NamedTuple ArgsKwargsPair, or `None` if not successful. + """ + if self.op == 'call_function': + assert callable(self.target) + return normalize_function(self.target, self.args, self.kwargs, arg_types, kwarg_types) # type: ignore[arg-type] + elif self.op == 'call_module': + assert isinstance(self.target, str) + return normalize_module(root, self.target, self.args, self.kwargs) # type: ignore[arg-type] + + return None + + @compatibility(is_backward_compatible=True) + def replace_input_with(self, old_input: 'Node', new_input: 'Node') -> None: + """ + Loop through input nodes of ``self``, and replace all instances of + ``old_input`` with ``new_input``. + + Args: + + old_input (Node): The old input node to be replaced. + new_input (Node): The new input node to replace ``old_input``. 
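+        For example (a usage sketch, with ``node`` a hypothetical user of ``old_input``):
+        ``node.replace_input_with(old_input, new_input)`` rewires only that one user,
+        whereas ``old_input.replace_all_uses_with(new_input)`` rewires every user of
+        ``old_input``.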
+ """ + def maybe_replace_node(n : Node) -> Node: + return new_input if n == old_input else n + + m = self.graph.owning_module + if getattr(m, "_replace_hook", None): + m._replace_hook(old=old_input, new=new_input.name, user=self) + + new_args = map_arg(self.args, maybe_replace_node) + new_kwargs = map_arg(self.kwargs, maybe_replace_node) + assert isinstance(new_args, tuple) + assert isinstance(new_kwargs, dict) + self.__update_args_kwargs(new_args, new_kwargs) + + def _rename(self, candidate: str) -> None: + if candidate == self.name: + return + name = self.graph._graph_namespace.create_name(candidate, None) + self.name = name + self.graph._graph_namespace._rename_object(self, name) + + def __setattr__(self, name: str, value: Any) -> None: + if name == 'name' and hasattr(self, "name"): + m = self.graph.owning_module + if getattr(m, "_replace_hook", None): + assert isinstance(value, str) + for user in self.users: + m._replace_hook(old=self, new=value, user=user) + update = False + if ( + hasattr(self, name) and + hasattr(self.graph, "_find_nodes_lookup_table") and + self in self.graph._find_nodes_lookup_table + ): + update = True + self.graph._find_nodes_lookup_table.remove(self) + object.__setattr__(self, name, value) + if update: + self.graph._find_nodes_lookup_table.insert(self) + +@compatibility(is_backward_compatible=True) +def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument: + """ + Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. + """ + assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable" + return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x) + +@compatibility(is_backward_compatible=True) +def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument: + """ + Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. + """ + if isinstance(a, tuple): + t = tuple([map_aggregate(elem, fn) for elem in a]) + # Support NamedTuple (if it has `_fields`) by repacking into original type. 
+ return t if not hasattr(a, '_fields') else type(a)(*t) # type: ignore[arg-type] + elif isinstance(a, list): + return immutable_list([map_aggregate(elem, fn) for elem in a]) + elif isinstance(a, dict): + rv = immutable_dict() + for k, v in a.items(): + dict.__setitem__(rv, k, map_aggregate(v, fn)) + return rv + elif isinstance(a, slice): + return slice(map_aggregate(a.start, fn), map_aggregate(a.stop, fn), map_aggregate(a.step, fn)) + else: + return fn(a) diff --git a/lib/python3.10/site-packages/torch/fx/operator_schemas.py b/lib/python3.10/site-packages/torch/fx/operator_schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..8a5beed5285d9a208d7a28ebdeeb79e1bdf2c19e --- /dev/null +++ b/lib/python3.10/site-packages/torch/fx/operator_schemas.py @@ -0,0 +1,451 @@ +# mypy: allow-untyped-defs +import torch +import inspect +import numbers +import types +import typing +import enum +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING +from torch._jit_internal import boolean_dispatched +from ._compatibility import compatibility +from torch._ops import OpOverloadPacket, OpOverload + +if TYPE_CHECKING: + from .node import Argument + +__all__ = ["ArgsKwargsPair", "check_for_mutable_operation", "get_signature_for_torch_op", "create_type_hint", + "type_matches", "normalize_function", "normalize_module"] + +@compatibility(is_backward_compatible=False) +class ArgsKwargsPair(NamedTuple): + """ + Simple named tuple for wrapping args/kwargs pairs. + """ + args: Tuple[Any, ...] + kwargs: Dict[str, Any] + +_manual_overrides : Dict[Callable, List[inspect.Signature]] = {} + +def _nonzero_schemas(): + signatures = [] + + def nonzero(self): + pass + signatures.append(inspect.signature(nonzero)) + + def nonzero(self, *, as_tuple : bool): # type: ignore[no-redef] + pass + signatures.append(inspect.signature(nonzero)) + + return signatures + +_manual_overrides[torch.nonzero] = _nonzero_schemas() + +class _FakeGlobalNamespace: + def __getattr__(self, name): + if name == 'torch': + return torch + raise RuntimeError('Expected a torch namespace lookup') + +_type_eval_globals = {'Tensor' : torch.Tensor, 'Device' : torch.device, 'Layout' : torch.layout, + 'number' : numbers.Number, 'Future' : torch.jit.Future, + 'AnyEnumType' : enum.Enum, 'QScheme' : torch.qscheme, + '__torch__': _FakeGlobalNamespace(), 'NoneType': type(None), + 'Storage': torch.UntypedStorage, + 't': typing.TypeVar('t')} +for k in dir(typing): + _type_eval_globals[k] = getattr(typing, k) + +def _torchscript_type_to_python_type(ts_type : 'torch._C.JitType') -> Any: + """ + Convert a TorchScript type to a Python type (including subtypes) via + eval'ing the annotation_str. _type_eval_globals sets up expressions + like "List" and "Future" to map to actual types (typing.List and jit.Future) + """ + return eval(ts_type.annotation_str, _type_eval_globals) + +def _torchscript_schema_to_signature_impl(ts_schema : torch._C.FunctionSchema) -> inspect.Signature: + from inspect import Parameter + parameters : List[Parameter] = [] + for arg in ts_schema.arguments: + arg_type = _torchscript_type_to_python_type(arg.type) + default = arg.default_value if arg.has_default_value() else Parameter.empty + # TODO: Figure out if this is safe. It seems like when generating the type signatures for + # PythonArgParser, we emit signatures with `input` instead of `self` as the first tensor + # argument name. 
Downstream, if someone converts that positional argument to a keyword + # argument, the name mismatch will break things, so here we're going to normalize the + # name to "input" + name = arg.name if arg.name != 'self' else 'input' + kind = Parameter.KEYWORD_ONLY if arg.kwarg_only else Parameter.POSITIONAL_OR_KEYWORD + # "from" is a keyword therefore it must be a POSITIONAL_ONLY argument + if name == "from": + assert kind == Parameter.POSITIONAL_OR_KEYWORD + # ParameterKind type is internal implementation detail to inspec package + # which makes it hard to do type annotation + kind = Parameter.POSITIONAL_ONLY # type: ignore[assignment] + # This renders all previous arguments to positional only + for idx, p in enumerate(parameters): + assert p.kind == Parameter.POSITIONAL_OR_KEYWORD + parameters[idx] = Parameter(name=p.name, kind=Parameter.POSITIONAL_ONLY, default=p.default, annotation=p.annotation) + parameters.append(Parameter(name=name, kind=kind, default=default, annotation=arg_type)) + return_types = [_torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns] + if len(return_types) == 0: + return_type = None + elif len(return_types) == 1: + return_type = return_types[0] + else: + return_type = tuple(return_types) + + return inspect.Signature(parameters, return_annotation=return_type) + +_SCHEMA_TO_SIGNATURE_CACHE : Dict[Tuple[str, str], inspect.Signature] = {} + +def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> inspect.Signature: + # Cached as it's called in the hot path of FakeTensor dispatch + cache_key = ts_schema.name, ts_schema.overload_name + cache_val = _SCHEMA_TO_SIGNATURE_CACHE.get(cache_key) + if cache_val is not None: + return cache_val + + res = _torchscript_schema_to_signature_impl(ts_schema) + _SCHEMA_TO_SIGNATURE_CACHE[cache_key] = res + return res + +@compatibility(is_backward_compatible=False) +def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...], kwargs : Dict[str, 'Argument']): + signatures, schemas = get_signature_for_torch_op(target, return_schemas=True) + + if signatures and schemas: + matched_schemas = [] + + # Iterate through all of the schema until we find one that matches + # If one matches, populate `new_args_and_kwargs` with the new args/kwargs + # values. If none matches, `new_args_and_kwargs` will be None + for candidate_signature, schema in zip(signatures, schemas): + try: + candidate_signature.bind(*args, **kwargs) + matched_schemas.append((candidate_signature, schema)) + except TypeError as e: + continue + + def throw_if_mutable(schema): + if schema.is_mutable: + raise RuntimeError(f'Tried to trace mutable operation {schema}. FX only supports functional ' + f'code, so operations that mutate operands in-place (e.g. via `out` arguments) ' + f'are not supported') + + if len(matched_schemas) == 0: + # Did not match any schema. Cannot check for mutation + pass + elif len(matched_schemas) == 1: + # Matched exactly one schema, unambiguous + _, schema_to_check = matched_schemas[0] + throw_if_mutable(schema_to_check) + else: + # Ambiguous schema match. Since mutability checking is best effort, + # do nothing. + pass + +@compatibility(is_backward_compatible=False) +def get_signature_for_torch_op(op : Callable, return_schemas : bool = False): + """ + Given an operator on the `torch` namespace, return a list of `inspect.Signature` + objects corresponding to the overloads of that op.. May return `None` if a signature + could not be retrieved. 
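+    For example (a sketch), ``get_signature_for_torch_op(torch.add)`` returns one
+    ``inspect.Signature`` per registered overload of ``aten::add``.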
+ + Args: + op (Callable): An operator on the `torch` namespace to look up a signature for + + Returns: + Optional[List[inspect.Signature]]: A list of signatures for the overloads of this + operator, or None if the operator signatures could not be retrieved. If + return_schemas=True, returns a tuple containing the optional Python signatures + and the optional TorchScript Function signature + """ + if isinstance(op, OpOverload): + schemas = [op._schema] + elif isinstance(op, OpOverloadPacket): + schemas = [getattr(op, overload)._schema for overload in op.overloads()] + else: + override = _manual_overrides.get(op) + if override: + return (override, None) if return_schemas else None + + aten_fn = torch.jit._builtins._find_builtin(op) + + if aten_fn is None: + return (None, None) if return_schemas else None + schemas = torch._C._jit_get_schemas_for_operator(aten_fn) + + signatures = [_torchscript_schema_to_signature(schema) for schema in schemas] + return (signatures, schemas) if return_schemas else signatures + +@compatibility(is_backward_compatible=False) +def create_type_hint(x): + """ + Produces a type hint for the given argument. + + The :func:`create_type_hint` looks for a type hint compatible with the input argument `x`. + + If `x` is a `list` or `tuple`, it looks for an object in the list whose type is a superclass + of the rest, and uses that as `base_type` for the `List` or `Tuple` to be returned. + If no such object is found, it defaults to `List[Any]`. + + If `x` is neither a `list` nor a `tuple`, it returns `x`. + """ + try: + if isinstance(x, (list, tuple)): + # todo(chilli): Figure out the right way for mypy to handle this + if isinstance(x, list): + def ret_type(x): + return List[x] # type: ignore[valid-type] + else: + def ret_type(x): + return Tuple[x, ...] + if len(x) == 0: + return ret_type(Any) + base_type = x[0] + for t in x: + if issubclass(t, base_type): + continue + elif issubclass(base_type, t): + base_type = t + else: + return ret_type(Any) + return ret_type(base_type) + except Exception as e: + # We tried to create a type hint for list but failed. + warnings.warn(f"We were not able to successfully create type hint from the type {x}") + return x + +@compatibility(is_backward_compatible=False) +def type_matches(signature_type : Any, argument_type : Any): + sig_origin_type = getattr(signature_type, '__origin__', signature_type) + + if signature_type is argument_type: + return True + + # Union types in signature. Given type needs to match one of the + # contained types in the Union + if sig_origin_type is typing.Union and signature_type != argument_type: + sig_contained = signature_type.__args__ + return any(type_matches(c, argument_type) for c in sig_contained) + + if signature_type is List[int] and argument_type is int: + # int can be promoted to List[int] + return True + + if getattr(signature_type, '__origin__', None) in {list, List}: + sig_el_type = signature_type.__args__[0] + if not inspect.isclass(sig_el_type): + warnings.warn( + f"Does not support nested parametric types, got {signature_type}. 
Please file a bug.") + return False + if getattr(argument_type, '__origin__', None) in {list, List}: + return issubclass(argument_type.__args__[0], sig_el_type) + + def is_homogeneous_tuple(t): + if getattr(t, "__origin__", None) not in {tuple, Tuple}: + return False + contained = t.__args__ + if t.__args__ == ((),): # Tuple[()].__args__ == ((),) for some reason + return True + return all((c is Ellipsis) or issubclass(c, sig_el_type) for c in contained) + + # Tuple[T] is accepted for List[T] parameters + return is_homogeneous_tuple(argument_type) + + # Dtype is an int in schemas + if signature_type is int and argument_type is torch.dtype: + return True + + if signature_type is numbers.Number and argument_type in {int, float}: + return True + if inspect.isclass(argument_type) and inspect.isclass(signature_type): + return issubclass(argument_type, signature_type) + + return False + +@compatibility(is_backward_compatible=False) +def normalize_function( + target: Callable, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, arg_types : Optional[Tuple[Any]] = None, + kwarg_types : Optional[Dict[str, Any]] = None, + normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: + """ + Returns normalized arguments to PyTorch functions. This means that + `args/kwargs` will be matched up to the functional's + signature and return exclusively kwargs in positional order if + `normalize_to_only_use_kwargs` is True. + Also populates default values. Does not support positional-only + parameters or varargs parameters (*args, **kwargs). Does not support modules. + + May require `arg_types` and `kwarg_types` in order to disambiguate overloads. + + Args: + target (Callable): Function that we are normalizing + args (Tuple[Any]): Tuple of args to the function + kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function + arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args + kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs + normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. + + Returns: + + Returns normalized_args_and_kwargs, or `None` if not successful. + """ + if kwargs is None: + kwargs = {} + new_args_and_kwargs = None + if not isinstance(target, types.BuiltinFunctionType) and not ( + isinstance(target, (OpOverloadPacket, OpOverload)) + ): + target_for_analysis = target + if target in boolean_dispatched: + # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have + # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false` + # branches of the dispatch have exactly the same signature. If they do, use the `true` + # branch signature for analysis. Otherwise, leave this un-normalized + assert not isinstance(target, str) + dispatched = boolean_dispatched[target] + if_true, if_false = dispatched['if_true'], dispatched['if_false'] + if inspect.signature(if_true).parameters != inspect.signature(if_false).parameters: + return None + target_for_analysis = if_true + + assert callable(target_for_analysis) + sig = inspect.signature(inspect.unwrap(target_for_analysis)) + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, normalize_to_only_use_kwargs) + else: + assert callable(target) + torch_op_schemas = get_signature_for_torch_op(target) + matched_schemas = [] + if torch_op_schemas: + # Iterate through all of the schema until we find one that matches + # If one matches, populate `new_args_and_kwargs` with the new args/kwargs + # values. 
If none matches, `new_args_and_kwargs` will be None + for candidate_signature in torch_op_schemas: + try: + candidate_signature.bind(*args, **kwargs) + matched_schemas.append(candidate_signature) + except TypeError as e: + continue + + if len(matched_schemas) == 0: + # Did not match any schema. Cannot normalize + pass + elif len(matched_schemas) == 1: + # Matched exactly one schema, unambiguous + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(matched_schemas[0], args, kwargs, + normalize_to_only_use_kwargs) + else: + if arg_types is not None or kwarg_types is not None: + arg_types = arg_types if arg_types else cast(Tuple[Any], ()) + kwarg_types = kwarg_types if kwarg_types else {} + for candidate_signature in torch_op_schemas: + sig_matches = True + try: + bound_types = candidate_signature.bind(*arg_types, **kwarg_types) + for arg_name, arg_type in bound_types.arguments.items(): + param = candidate_signature.parameters[arg_name] + sig_matches = sig_matches and type_matches(param.annotation, arg_type) + except TypeError as e: + sig_matches = False + if sig_matches: + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(candidate_signature, args, kwargs, + normalize_to_only_use_kwargs) + break + else: + # Matched more than one schema. In this situation, the caller must provide the types of + # the arguments of the overload they expect. + schema_printouts = '\n'.join(str(schema) for schema in matched_schemas) + raise RuntimeError(f'Tried to normalize arguments to {torch.typename(target)} but ' + f'the schema match was ambiguous! Please provide argument types to ' + f'the normalize_arguments() call. Available schemas:\n{schema_printouts}') + + return new_args_and_kwargs + +@compatibility(is_backward_compatible=False) +def normalize_module( + root: torch.nn.Module, target: str, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, + normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: + """ + Returns normalized arguments to PyTorch modules. This means that + `args/kwargs` will be matched up to the functional's + signature and return exclusively kwargs in positional order if + `normalize_to_only_use_kwargs` is True. + Also populates default values. Does not support positional-only + parameters or varargs parameters (*args, **kwargs). + + Args: + root (nn.Module): root module upon which we query modules + target (Callable): Function that we are normalizing + args (Tuple[Any]): Tuple of args to the function + kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function + normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. + + Returns: + + Returns normalized_args_and_kwargs, or `None` if not successful. 
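+        For example (a sketch), normalizing a call to a submodule that is an
+        ``nn.Conv2d`` with ``normalize_to_only_use_kwargs=True`` matches ``args`` and
+        ``kwargs`` against ``Conv2d.forward`` and returns them entirely as kwargs.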
+ """ + try: + submod = root.get_submodule(target) + except AttributeError as e: + raise RuntimeError(f"Tried to normalize node with target {target} but root did not " + f"have that target!") from e + if hasattr(submod.__class__, '__name__'): + classname = submod.__class__.__name__ + if getattr(torch.nn, classname, None) == submod.__class__: + sig = inspect.signature(inspect.unwrap(submod.forward)) + if kwargs is None: + kwargs = {} + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, + normalize_to_only_use_kwargs) + return new_args_and_kwargs + return None + +def _args_kwargs_to_normalized_args_kwargs(sig : inspect.Signature, args : Tuple[Any, ...], + kwargs : Dict[str, Any], + normalize_to_only_use_kwargs : bool) -> Optional[ArgsKwargsPair]: + """ + Given a call target, args, and kwargs, return the arguments normalized into + an ArgsKwargsPair, or None if the type signature is not supported by + this normalization. + + Args: + + sig (inspect.Signature): Signature object for the target + args (Tuple): Arguments that appear at the callsite for `target` + kwargs (Dict): Keyword arguments that appear at the callsite for `target` + normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. + + Returns: + + Optional[ArgsKwargsPair]: Normalized args and kwargs for `target`, or `None` if + this target is not supported. + """ + + # Don't currently support positional-only + # or varargs (*args, **kwargs) signatures + supported_parameter_types = { + inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY} + if any(p.kind not in supported_parameter_types for p in sig.parameters.values()): + # Add an exception for one signature, which is common for random/uniform, i.e.: + # Tensor(a!) self, float from=0, float to=1, *, Generator? 
generator=None + # `from` is Python keyword and as such functions with that signature should have + # positional-only args, but at the same time they could be dispatched as kwargs + if list(sig.parameters.keys()) != ['input', 'from', 'to', 'generator']: + return None + + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + + new_kwargs : Dict[str, Any] = {} + new_args : List[Any] = [] + for i, param in enumerate(sig.parameters): + if not normalize_to_only_use_kwargs and i < len(args): + new_args.append(bound_args.arguments[param]) + else: + new_kwargs[param] = bound_args.arguments[param] + + return ArgsKwargsPair(tuple(new_args), new_kwargs) diff --git a/lib/python3.10/site-packages/torch/fx/proxy.py b/lib/python3.10/site-packages/torch/fx/proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..2b86a1c609f918d63dbcddf791cddf3d72b9f924 --- /dev/null +++ b/lib/python3.10/site-packages/torch/fx/proxy.py @@ -0,0 +1,609 @@ +# mypy: ignore-errors + +import enum +import dis +import copy +import sys +import torch +import inspect +import operator +import collections +import logging + +from dataclasses import is_dataclass, fields + + +from .graph import magic_methods, reflectable_magic_methods, Graph +from torch.utils._traceback import CapturedTraceback +from typing import Tuple, Dict, OrderedDict, Optional, Any, Iterator, Callable +from .node import Target, Node, Argument, base_types, map_aggregate +from ._compatibility import compatibility +from .operator_schemas import check_for_mutable_operation +import torch.fx.traceback as fx_traceback + +__all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError', + 'Proxy', 'Attribute', 'ParameterProxy', 'Scope', + 'ScopeContextManager'] + + +log = logging.getLogger(__name__) + + +@compatibility(is_backward_compatible=False) +class Scope: + """ Scope object that records the module path and the module type + of a module. Scope is used to track the information of the module + that contains a Node in a Graph of GraphModule. For example:: + + class Sub(torch.nn.Module): + def forward(self, x): + # This will be a call_method Node in GraphModule, + # scope for this would be (module_path="sub", module_type=Sub) + return x.transpose(1, 2) + + class M(torch.nn.Module): + def __init__(self) -> None: + self.sub = Sub() + + def forward(self, x): + # This will be a call_method Node as well, + # scope for this would be (module_path="", None) + x = x.transpose(1, 2) + x = self.sub(x) + return x + + """ + + def __init__(self, module_path: str, module_type: Any): + super().__init__() + self.module_path = module_path + self.module_type = module_type + + +@compatibility(is_backward_compatible=False) +class ScopeContextManager: + """ A context manager to track the Scope of Node during symbolic tracing. + When entering a forward function of a Module, we'll update the scope information of + the current module, and when we exit, we'll restore the previous scope information. 
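+    Typical use during tracing (a usage sketch; ``module_path`` and ``mod`` are
+    placeholder names) is
+    ``with ScopeContextManager(self.scope, Scope(module_path, type(mod))): ...``
+    wrapped around a submodule's ``forward`` call.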
+ """ + + def __init__( + self, + scope: Scope, + current_scope: Scope, + ): + super().__init__() + # Keep a copy of prev scope to restore on exit + self._prev_scope = copy.copy(scope) + # Update scope to current scope + scope.module_path = current_scope.module_path + scope.module_type = current_scope.module_type + # Save a reference so we can restore it + self._scope = scope + + def __enter__(self): + return self._scope + + def __exit__(self, *args): + self._scope.module_path = self._prev_scope.module_path + self._scope.module_type = self._prev_scope.module_type + return + + +_COPY_META_FIELDS = [ + "nn_module_stack", + "torch_fn", + "source_fn_stack", + "original_aten", + "recompute", + "ac_graph_id", + "from_node", + "quantization_tag", # TODO deprecated + "_numeric_debug_handle", # TODO deprecated + "custom", + "partitioner_tag" +] + + +@compatibility(is_backward_compatible=True) +class TracerBase: + graph: Graph + record_stack_traces : bool = False + # Feature flag for mutable schema checking + # Enableby default in 1.12 + check_mutable_operations : bool = False + # Feature flag for assert tracing + trace_asserts : bool = False + # Feature flag for proxying accesses to buffer values + proxy_buffer_attributes : bool = False + + # Name of the function to be traced. It will only be used when + # ``root`` is an instance of ``nn.Module`` + traced_func_name: str = "forward" + + # Maps the containing module's name to the operator name + scope : Scope + + # Records the module call stack + module_stack: OrderedDict[str, Tuple[str, Any]] + + # Mapping of node name to module scope + node_name_to_scope: Dict[str, Tuple[str, type]] + + @compatibility(is_backward_compatible=True) + def create_node(self, kind : str, target : Target, + args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None, + type_expr : Optional[Any] = None) -> Node: + """ + Inserts a graph node given target, args, kwargs, and name. + + This method can be overridden to do extra checking, validation, or + modification of values used in node creation. For example, one might + want to disallow in-place operations from being recorded. + """ + + if kind == 'call_function' and self.check_mutable_operations: + check_for_mutable_operation(target, args, kwargs) + + node = self.graph.create_node(kind, target, args, kwargs, name, type_expr) + # TODO node_name_to_scope will be depreciated in favor of + # node.meta['nn_module_stack'] + self.node_name_to_scope[node.name] = ( + self.scope.module_path, + self.scope.module_type, + ) + # Optionally set stack trace on the created Node for debugging purposes + if fx_traceback.has_preserved_node_meta(): + current_meta: Dict[str, Any] = fx_traceback.get_current_meta() + + stack_trace = current_meta.get("stack_trace") + if stack_trace: + node.stack_trace = stack_trace + # Explicitly set the stack_trace, nn_module_stack and source_fn on the node.meta + # If other meta fields are needed, they can be added here + for field in _COPY_META_FIELDS: + if field in current_meta: + node.meta[field] = copy.copy(current_meta[field]) + + # Here we decrement to account for the sequence_nr having + # just been incremented while tracing this lowered aten op. + new_seq_nr = torch.autograd._get_sequence_nr() - 1 + # The sequence_nr increments every time a new autograd Node + # is created. During the FWD pass we store the sequence_nr + # corresponding to the last autograd Node created on this fx + # node's meta. 
A single aten op can create multiple autograd + # nodes as is the case with in-place foreach ops. During the + # BWD pass we retrieve the sequence_nr stored on the current + # executing autograd Node. See NOTE [ Sequence Number ]. + if current_meta.get("in_grad_fn", 0) > 0: + new_seq_nr = current_meta["grad_fn_seq_nr"][-1] + node.meta["seq_nr"] = new_seq_nr + + elif self.module_stack: + node.meta['nn_module_stack'] = copy.copy(self.module_stack) + + log.debug("create_node %s", node) + return node + + @compatibility(is_backward_compatible=True) + def proxy(self, node: Node) -> 'Proxy': + return Proxy(node, self) + + @compatibility(is_backward_compatible=True) + def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], + name: Optional[str] = None, type_expr : Optional[Any] = None, + proxy_factory_fn: Callable[[Node], 'Proxy'] = None): + ''' + Create a Node from the given arguments, then return the Node + wrapped in a Proxy object. + + If kind = 'placeholder', then we're creating a Node that + represents the parameter of a function. If we need to encode + a default parameter, we use the ``args`` tuple. ``args`` is + otherwise empty for ``placeholder`` Nodes. + ''' + + args_ = self.create_arg(args) + kwargs_ = self.create_arg(kwargs) + assert isinstance(args_, tuple) + assert isinstance(kwargs_, dict) + + node = self.create_node(kind, target, args_, kwargs_, name, type_expr) + + if not proxy_factory_fn: + proxy = self.proxy(node) + else: + proxy = proxy_factory_fn(node) + + if self.record_stack_traces and not proxy.node.stack_trace: + proxy.node.stack_trace = ''.join(CapturedTraceback.extract().format()) + + + return proxy + + def _find_user_frame(self): + """ + Find the Python stack frame executing the user code during + symbolic tracing. + """ + # We have to do a little dance here. Basically, walk up the callstack and + # record the first frame not in the pytorch source. This is the frame executing + # the user code during tracing. + frame = inspect.currentframe() + + pt_files = ['torch/fx/proxy.py', + 'torch/fx/_symbolic_trace.py', + 'torch/fx/experimental/proxy_tensor.py', + 'torch/_ops.py', + 'torch/_tensor.py', + 'torch/utils/_python_dispatch.py', + 'torch/_prims_common/wrappers.py', + 'torch/_refs/__init__.py', + 'torch/_refs/nn/functional/__init__.py', + 'torch/utils/_stats.py', + ] + while frame: + frame = frame.f_back + if frame and all(not frame.f_code.co_filename.endswith(file) for file in pt_files): + break + + if not frame: + return None + + return frame + + @compatibility(is_backward_compatible=True) + def create_arg(self, a: Any) -> Argument: + """ + A method that lowers the objects seen as arguments during symbolic evaluation + into Argument types that can be stored in IR. + + Can be override to support more trace-specific types. + """ + if not isinstance(a, Proxy) and hasattr(a, '__fx_create_arg__'): + return a.__fx_create_arg__(self) + # aggregates + elif isinstance(a, tuple) and hasattr(a, '_fields'): + # NamedTuple constructors don't seem to like getting a generator + # expression as an argument to their constructor, so build this + # intermediate tuple and unpack it into the NamedTuple constructor + args = tuple(self.create_arg(elem) for elem in a) + return type(a)(*args) # type: ignore[arg-type] + elif isinstance(a, (tuple, list)): + return type(a)(self.create_arg(elem) for elem in a) + elif isinstance(a, dict): + r = {} + for k, v in a.items(): + # Check for invalid dict keys. 
We do not want a Proxy to appear + # anywhere within the key. Since keys can be collection types, + # we iterate through the key with map_aggregate + k = self.create_arg(k) + + def no_node(arg): + if isinstance(arg, Node): + raise RuntimeError("Keys for dictionaries used as an argument cannot contain a " + f"Node. Got key: {k}") + map_aggregate(k, no_node) + + r[k] = self.create_arg(v) + return r + elif isinstance(a, slice): + return slice(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step)) + + elif isinstance(a, range): + return range(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step)) + + elif isinstance(a, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + return a + + if isinstance(a, Proxy): + # base case: we unwrap the Proxy object + return a.node + + if is_dataclass(a): + kwargs = {field.name: self.create_arg(getattr(a, field.name)) for field in fields(a)} + return self.create_node("call_function", a.__class__, (), kwargs) + + elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...: + return a + raise NotImplementedError(f"argument of type: {type(a)}") + + @compatibility(is_backward_compatible=True) + def to_bool(self, obj: 'Proxy') -> bool: + """Called when a proxy object is being converted to a boolean, such as + when used in control flow. Normally we don't know what to do because + we don't know the value of the proxy, but a custom tracer can attach more + information to the graph node using create_node and can choose to return a value. + """ + raise TraceError('symbolically traced variables cannot be used as inputs to control flow') + + @compatibility(is_backward_compatible=True) + def iter(self, obj: 'Proxy') -> Iterator: + """Called when a proxy object is being iterated over, such as + when used in control flow. Normally we don't know what to do because + we don't know the value of the proxy, but a custom tracer can attach more + information to the graph node using create_node and can choose to return an iterator. + """ + raise TraceError('Proxy object cannot be iterated. This can be ' + 'attempted when the Proxy is used in a loop or' + ' as a *args or **kwargs function argument. ' + 'See the torch.fx docs on pytorch.org for a ' + 'more detailed explanation of what types of ' + 'control flow can be traced, and check out the' + ' Proxy docstring for help troubleshooting ' + 'Proxy iteration errors') + + @compatibility(is_backward_compatible=True) + def keys(self, obj: 'Proxy') -> Any: + """Called when a proxy object has the keys() method called. + This is what happens when ** is called on a proxy. This should return an + iterator if ** is supposed to work in your custom tracer. + """ + return Attribute(obj, 'keys')() + + +# used in Proxy object when just appending to the graph while not tracing.
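
As a quick orientation: the TracerBase hooks above come together in the GraphAppendingTracer defined just below, which Proxy falls back to when it is constructed around a raw Node. A minimal, illustrative sketch (not part of the patched file; it only assumes the public torch.fx API vendored in this diff):

    import torch
    import torch.fx as fx

    graph = fx.Graph()
    x = fx.Proxy(graph.placeholder('x'))           # Proxy with no tracer -> GraphAppendingTracer(graph)
    y = torch.relu(x) + 1.0                        # __torch_function__ / magic methods append call_function nodes
    graph.output(y.node)                           # close the hand-built graph with an output node

    gm = fx.GraphModule(torch.nn.Module(), graph)  # compile the graph into a runnable module
    print(gm(torch.randn(4)))
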
+@compatibility(is_backward_compatible=True) +class GraphAppendingTracer(TracerBase): + def __init__(self, graph: Graph): + super().__init__() + self.graph = graph + self.scope = Scope("", None) + self.module_stack = collections.OrderedDict() + self.node_name_to_scope = {} + +@compatibility(is_backward_compatible=False) +def assert_fn(x): + assert x + +@compatibility(is_backward_compatible=True) +class TraceError(ValueError): + pass + +@compatibility(is_backward_compatible=True) +class Proxy: + """ + ``Proxy`` objects are ``Node`` wrappers that flow through the + program during symbolic tracing and record all the operations + (``torch`` function calls, method calls, operators) that they touch + into the growing FX Graph. + + If you're doing graph transforms, you can wrap your own ``Proxy`` + method around a raw ``Node`` so that you can use the overloaded + operators to add additional things to a ``Graph``. + + ``Proxy`` objects cannot be iterated. In other words, the symbolic + tracer will throw an error if a ``Proxy`` is used in a loop or as + an ``*args``/``**kwargs`` function argument. + + There are two main ways around this: + 1. Factor out the untraceable logic into a top-level function and + use ``fx.wrap`` on it. + 2. If the control flow is static (i.e. the loop trip count is + based on some hyperparameter), the code can be kept in its original + position and refactored into something like:: + + for i in range(self.some_hyperparameter): + indexed_item = proxied_value[i] + + For a more detailed description into the Proxy internals, check out + the "Proxy" section in `torch/fx/README.md` + """ + + @compatibility(is_backward_compatible=True) + def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None): + if tracer is None: + # This allows you to create a Proxy object around a raw Node + tracer = GraphAppendingTracer(node.graph) + self.tracer = tracer + self.node = node + + def __repr__(self) -> str: + return f'Proxy({self.node.name})' + + def __getattr__(self, k) -> 'Attribute': + # note: not added to the graph yet, if this is a method call + # we peephole optimize to the method invocation + return Attribute(self, k) + + def __getstate__(self) -> Dict: + return self.__dict__ + + def __deepcopy__(self, memo) -> Dict: + # We have to explicitly override this method, because otherwise deepcopy + # will go to __getattr__(self, "__deepcopy__") and return a + # Attribute(__deepcopy__), and may go into an infinite loop in some cases. + import copy + new_dict = {} + for k, v in self.__dict__.items(): + try: + new_obj = copy.deepcopy(v, memo) + except Exception: + log.warning( + "Shallow copy %s of Proxy because it cannot be deepcopied. " + "Proxy is created for node %s", k, self.node.name) + new_obj = copy.copy(v) + new_dict[k] = new_obj + assert "node" in new_dict + assert "tracer" in new_dict + new_proxy = Proxy(new_dict["node"], new_dict["tracer"]) + for k, v in new_dict.items(): + new_proxy.__dict__[k] = v + return new_proxy + + def __setstate__(self, d): + # This is called when being unpickled/loaded. 
+ self.__dict__ = d + + def __call__(self, *args, **kwargs) -> 'Proxy': + return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs) + + def __iter__(self) -> Iterator['Proxy']: + frame = inspect.currentframe() + assert frame is not None + calling_frame = frame.f_back + assert calling_frame is not None + inst_list = list(dis.get_instructions(calling_frame.f_code)) + if sys.version_info >= (3, 11): + from bisect import bisect_left + inst_idx = bisect_left(inst_list, calling_frame.f_lasti, key=lambda x: x.offset) + else: + inst_idx = calling_frame.f_lasti // 2 + inst = inst_list[inst_idx] + if inst.opname == 'UNPACK_SEQUENCE': + return (self[i] for i in range(inst.argval)) # type: ignore[index] + + return self.tracer.iter(self) + + def __abs__(self): + return self.tracer.create_proxy('call_function', operator.abs, (self,), {}) + + def __bool__(self) -> bool: + if self.tracer.trace_asserts: + # check if this boolean is used in an assertion, bytecode pattern for assertions + # is pretty stable for Python 3.7--3.9 + frame = inspect.currentframe() + assert frame is not None + calling_frame = frame.f_back + assert calling_frame is not None + insts = list(dis.get_instructions(calling_frame.f_code)) + if sys.version_info >= (3, 11): + from bisect import bisect_left + cur = bisect_left(insts, calling_frame.f_lasti, key=lambda x: x.offset) + else: + cur = calling_frame.f_lasti // 2 + inst = insts[cur] + + if inst.opname == 'POP_JUMP_IF_TRUE': + first = insts[cur + 1] + assert inst.arg is not None + last = insts[inst.arg // 2 - 1] + starts_with_assert = (first.opname == 'LOAD_GLOBAL' and first.argval == 'AssertionError' + or first.opname == 'LOAD_ASSERTION_ERROR') + if starts_with_assert and last.opname == 'RAISE_VARARGS': + self.tracer.create_proxy('call_function', assert_fn, (self,), {}) + return True + + return self.tracer.to_bool(self) + + @compatibility(is_backward_compatible=True) + def keys(self): + return self.tracer.keys(self) + + def __len__(self): + raise RuntimeError("'len' is not supported in symbolic tracing by default. 
If you want " + "this call to be recorded, please call torch.fx.wrap('len') at " + "module scope") + + @classmethod + def __torch_function__(cls, orig_method, types, args=None, kwargs=None): + args = args if args else () + kwargs = kwargs if kwargs else {} + + tracers : Dict[Any, None] = {} + + def find_tracer(a): + if isinstance(a, cls): + tracers[a.tracer] = None + torch.fx.node.map_aggregate(args, find_tracer) + torch.fx.node.map_aggregate(kwargs, find_tracer) + + if len(tracers) > 1: + raise RuntimeError(f'Found multiple different tracers {list(tracers.keys())} while ' + f'trying to trace operations {orig_method}') + tracer = next(iter(tracers.keys())) + + if isinstance(orig_method, torch._C.ScriptMethod): + args = (orig_method.owner,) + args + return tracer.create_proxy('call_method', orig_method.name, args, kwargs) + if torch.overrides.is_tensor_method_or_property(orig_method): + return tracer.create_proxy('call_method', orig_method.__name__, args, kwargs) + else: + if isinstance(orig_method, torch._ops.HigherOrderOperator): + # TODO: Define how to symbolically trace HigherOrderOperators + raise RuntimeError("Unable to symbolically trace HigherOrderOperators") + return tracer.create_proxy('call_function', orig_method, args, kwargs, + name=tracer.graph._target_to_str(orig_method.__name__)) + + +@compatibility(is_backward_compatible=True) +class Attribute(Proxy): + @compatibility(is_backward_compatible=True) + def __init__(self, root: Proxy, attr: str): + self.root = root + self.attr = attr + self.tracer = root.tracer + self._node: Optional[Node] = None + + @property + def node(self): + # the node for attributes is added lazily, since most will just be method calls + # which do not rely on the getitem call + if self._node is None: + self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node + return self._node + + def __call__(self, *args, **kwargs): + return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) + + +@compatibility(is_backward_compatible=False) +class ParameterProxy(Proxy): + """ + A special proxy which lets "shape", "size", "dim", and a few other + attribute accesses pass through to the underlying module parameter object, + so that conditional tests on these attributes will not throw exception during tracing + """ + def __init__(self, tracer: TracerBase, node: Node, name, param): + super().__init__(node, tracer) + assert isinstance(param, torch.nn.Parameter) + self.param = param + self.name = name + + def __repr__(self) -> str: + return f'ParameterProxy({self.name})' + + @property + def shape(self): + return self.param.shape + + def size(self): + return self.param.size() + + def dim(self): + return self.param.dim() + + @property + def ndim(self): + return self.param.ndim + + def numel(self): + return self.param.numel() + + def nelement(self): + return self.param.nelement() + + +for method in magic_methods: + def _scope(method): + def impl(*args, **kwargs): + tracer = args[0].tracer + target = getattr(operator, method) + return tracer.create_proxy('call_function', target, args, kwargs) + impl.__name__ = method + as_magic = f'__{method.strip("_")}__' + setattr(Proxy, as_magic, impl) + _scope(method) + +def _define_reflectable(orig_method_name): + method_name = f'__r{orig_method_name.strip("_")}__' + + def impl(self, rhs): + target = getattr(operator, orig_method_name) + return self.tracer.create_proxy('call_function', target, (rhs, self), {}) + impl.__name__ = method_name + impl.__qualname__ = method_name + 
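# Register the reflected variant (e.g. `__radd__`) so that expressions like `1 + proxy` are traced too; + # the (rhs, self) operand order built above follows Python's reflected-operator protocol. +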
setattr(Proxy, method_name, impl) + +for orig_method_name in reflectable_magic_methods: + _define_reflectable(orig_method_name) diff --git a/lib/python3.10/site-packages/torch/fx/subgraph_rewriter.py b/lib/python3.10/site-packages/torch/fx/subgraph_rewriter.py new file mode 100644 index 0000000000000000000000000000000000000000..7f2cb743d2cdd9ba09c605432a144ccba97327da --- /dev/null +++ b/lib/python3.10/site-packages/torch/fx/subgraph_rewriter.py @@ -0,0 +1,348 @@ +from .graph_module import GraphModule +from .graph import Graph +from .node import Node +from ._symbolic_trace import symbolic_trace +from ._compatibility import compatibility + +import copy +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union, TYPE_CHECKING +import torch + +if TYPE_CHECKING: + from .passes.utils.matcher_with_name_node_map_utils import InternalMatch + +__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters', "ReplacedPatterns"] + +@compatibility(is_backward_compatible=True) +class Match(NamedTuple): + # Node from which the match was found + anchor: Node + # Maps nodes in the pattern subgraph to nodes in the larger graph + nodes_map: Dict[Node, Node] + +@compatibility(is_backward_compatible=False) +@dataclass +class ReplacedPatterns: + # Node from which the match was found + anchor: Node + # Maps nodes in the pattern subgraph to nodes in the larger graph + nodes_map: Dict[Node, Node] + # List of nodes that were added into the graph + replacements: List[Node] + +def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None: + gm.delete_all_unused_submodules() + + if isinstance(replacement, GraphModule): + replacement.graph.lint() + + def try_get_attr(gm: torch.nn.Module, target: str) -> Optional[Any]: + module_path, _, attr_name = target.rpartition(".") + try: + mod: torch.nn.Module = gm.get_submodule(module_path) + except AttributeError: + return None + attr = getattr(mod, attr_name, None) + return attr + + for node in gm.graph.nodes: + if node.op == "call_module" or node.op == "get_attr": + + gm_attr = try_get_attr(gm, node.target) + replacement_attr = try_get_attr(replacement, node.target) + + # CASE 1: This target already exists as an attribute in our + # result GraphModule. Whether or not it exists in + # `replacement`, the existing submodule takes precedence. + if gm_attr is not None: + continue + + # CASE 2: The target exists as an attribute in `replacement` + # only, so we need to copy it over. + elif replacement_attr is not None: + new_attr = copy.deepcopy(replacement_attr) + if isinstance(replacement_attr, torch.nn.Module): + gm.add_submodule(node.target, new_attr) + else: + setattr(gm, node.target, new_attr) + + # CASE 3: The target doesn't exist as an attribute in `gm` + # or `replacement` + else: + raise RuntimeError('Attempted to create a "', node.op, + '" node during subgraph rewriting ' + f"with target {node.target}, but " + "the referenced attribute does not " + "exist in the replacement GraphModule") + + gm.graph.lint() + + +@compatibility(is_backward_compatible=True) +def replace_pattern( + gm: GraphModule, + pattern: Union[Callable, GraphModule], + replacement: Union[Callable, GraphModule] +) -> List[Match]: + """ + Matches all possible non-overlapping sets of operators and their + data dependencies (``pattern``) in the Graph of a GraphModule + (``gm``), then replaces each of these matched subgraphs with another + subgraph (``replacement``). 
+ + Args: + ``gm``: The GraphModule that wraps the Graph to operate on + ``pattern``: The subgraph to match in ``gm`` for replacement + ``replacement``: The subgraph to replace ``pattern`` with + + Returns: + List[Match]: A list of ``Match`` objects representing the places + in the original graph that ``pattern`` was matched to. The list + is empty if there are no matches. ``Match`` is defined as: + + .. code-block:: python + + class Match(NamedTuple): + # Node from which the match was found + anchor: Node + # Maps nodes in the pattern subgraph to nodes in the larger graph + nodes_map: Dict[Node, Node] + + Examples: + + .. code-block:: python + + import torch + from torch.fx import symbolic_trace, subgraph_rewriter + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, w1, w2): + m1 = torch.cat([w1, w2]).sum() + m2 = torch.cat([w1, w2]).sum() + return x + torch.max(m1) + torch.max(m2) + + def pattern(w1, w2): + return torch.cat([w1, w2]).sum() + + def replacement(w1, w2): + return torch.stack([w1, w2]) + + traced_module = symbolic_trace(M()) + + subgraph_rewriter.replace_pattern(traced_module, pattern, replacement) + + The above code will first match ``pattern`` in the ``forward`` + method of ``traced_module``. Pattern-matching is done based on + use-def relationships, not node names. For example, if you had + ``p = torch.cat([a, b])`` in ``pattern``, you could match + ``m = torch.cat([a, b])`` in the original ``forward`` function, + despite the variable names being different (``p`` vs ``m``). + + The ``return`` statement in ``pattern`` is matched based on its + value only; it may or may not match to the ``return`` statement in + the larger graph. In other words, the pattern doesn't have to extend + to the end of the larger graph. + + When the pattern is matched, it will be removed from the larger + function and replaced by ``replacement``. If there are multiple + matches for ``pattern`` in the larger function, each non-overlapping + match will be replaced. In the case of a match overlap, the first + found match in the set of overlapping matches will be replaced. + ("First" here being defined as the first in a topological ordering + of the Nodes' use-def relationships. In most cases, the first Node + is the parameter that appears directly after ``self``, while the + last Node is whatever the function returns.) + + One important thing to note is that the parameters of the + ``pattern`` Callable must be used in the Callable itself, + and the parameters of the ``replacement`` Callable must match + the pattern. The first rule is why, in the above code block, the + ``forward`` function has parameters ``x, w1, w2``, but the + ``pattern`` function only has parameters ``w1, w2``. ``pattern`` + doesn't use ``x``, so it shouldn't specify ``x`` as a parameter. + As an example of the second rule, consider replacing + + .. code-block:: python + + def pattern(x, y): + return torch.neg(x) + torch.relu(y) + + with + + .. code-block:: python + + def replacement(x, y): + return torch.relu(x) + + In this case, ``replacement`` needs the same number of parameters + as ``pattern`` (both ``x`` and ``y``), even though the parameter + ``y`` isn't used in ``replacement``. + + After calling ``subgraph_rewriter.replace_pattern``, the generated + Python code looks like this: + + .. 
code-block:: python + + def forward(self, x, w1, w2): + stack_1 = torch.stack([w1, w2]) + sum_1 = stack_1.sum() + stack_2 = torch.stack([w1, w2]) + sum_2 = stack_2.sum() + max_1 = torch.max(sum_1) + add_1 = x + max_1 + max_2 = torch.max(sum_2) + add_2 = add_1 + max_2 + return add_2 + """ + match_and_replacements = _replace_pattern(gm, pattern, replacement) + return [Match(anchor=m.anchor, nodes_map=m.nodes_map) for m in match_and_replacements] + + +# Experimental API, not backward compatible +@compatibility(is_backward_compatible=False) +def replace_pattern_with_filters( + gm: GraphModule, + pattern: Union[Callable, Graph, GraphModule], + replacement: Union[Callable, Graph, GraphModule], + match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, + ignore_literals: bool = False, +) -> List[ReplacedPatterns]: + """ + See replace_pattern for documentation. This function is an overload with an additional match_filter argument. + + Args: + ``match_filters``: A list of functions that take in + (match: InternalMatch, original_graph: Graph, pattern_graph: Graph) and return a boolean indicating + whether the match satisfies the condition. + See matcher_utils.py for definition of InternalMatch. + """ + + return _replace_pattern(gm, pattern, replacement, match_filters, ignore_literals) + + +def _replace_pattern( + gm: GraphModule, + pattern: Union[Callable, Graph, GraphModule], + replacement: Union[Callable, Graph, GraphModule], + match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, + ignore_literals: bool = False, +) -> List[ReplacedPatterns]: + + from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch + + if match_filters is None: + match_filters = [] + + # Get the graphs for `gm`, `pattern`, `replacement` + original_graph: Graph = gm.graph + + if isinstance(pattern, GraphModule): + pattern_graph = pattern.graph + elif isinstance(pattern, Graph): + pattern_graph = pattern + else: + pattern_graph = symbolic_trace(pattern).graph + + if isinstance(replacement, GraphModule): + replacement_graph = replacement.graph + elif isinstance(replacement, Graph): + replacement_graph = replacement + else: + replacement_graph = symbolic_trace(replacement).graph + + matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False, + remove_overlapping_matches=True, ignore_literals=ignore_literals) + _matches: List[InternalMatch] = matcher.match(original_graph) + + # Filter out matches that don't match the filter + _matches = [ + m for m in _matches + if all(match_filter(m, original_graph, pattern_graph) + for match_filter in match_filters) + ] + + replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"] + + # As we progressively replace nodes, we'll need to keep track of how the match results should change + match_changed_node: Dict[Node, Node] = {} + + match_and_replacements = [] + for match in _matches: + + # Build connecting between replacement graph's input and original graph input producer node + + # Initialize `val_map` with mappings from placeholder nodes in + # `replacement` to their corresponding node in `original_graph` + assert len(match.placeholder_nodes) == len(replacement_placeholders) + val_map: Dict[Node, Node] = {} + for rn, gn in zip(replacement_placeholders, match.placeholder_nodes): + if isinstance(gn, Node): + val_map[rn] = match_changed_node.get(gn, gn) + if gn != val_map[rn]: + # Update match.placeholder_nodes and match.nodes_map with the node that replaced 
gn + gn_ind = match.placeholder_nodes.index(gn) + match.placeholder_nodes[gn_ind] = match_changed_node[gn] + map_key = list(match.nodes_map.keys())[list(match.nodes_map.values()).index(gn)] + match.nodes_map[map_key] = match_changed_node[gn] + else: + val_map[rn] = gn + + # Copy the replacement graph over + user_nodes: Set[Node] = set() + for n in match.returning_nodes: + user_nodes.update(n.users) + assert user_nodes, "The returning_nodes should have at least one user node" + + if len(user_nodes) == 1: + first_user_node = next(iter(user_nodes)) + else: + # If there are multiple user nodes, we need to find the first user node + # in the current execution order of the `original_graph` + for n in original_graph.nodes: + if n in user_nodes: + first_user_node = n + break + + with original_graph.inserting_before(first_user_node): # type: ignore[possibly-undefined] + copied_returning_nodes = original_graph.graph_copy(replacement_graph, val_map) + + if isinstance(copied_returning_nodes, Node): + copied_returning_nodes = (copied_returning_nodes, ) + + # Get a list of nodes that have been replaced into the graph + replacement_nodes: List[Node] = [v for v in val_map.values() if v not in match.placeholder_nodes] + + # Hook the output Node of the replacement subgraph in to the + # original Graph at the correct location + assert len(match.returning_nodes) == len(copied_returning_nodes) # type: ignore[arg-type] + for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes): # type: ignore[arg-type] + gn.replace_all_uses_with(copied_node) + match_changed_node[gn] = copied_node + # Remove the original nodes + for node in reversed(pattern_graph.nodes): + if node.op != "placeholder" and node.op != "output": + gn = match.nodes_map[node] + gm.graph.erase_node(gn) + + match_and_replacements.append( + ReplacedPatterns( + anchor=match.anchors[0], + nodes_map=match.nodes_map, + replacements=replacement_nodes + ) + ) + + # Update the passed-in GraphModule to reflect the new state of + # `original_graph` + gm.recompile() + + # If `replacement` was an nn.Module, we'll need to make sure that + # all the submodules have been copied over correctly + if isinstance(replacement, torch.nn.Module): + _replace_attributes(gm, replacement) + + return match_and_replacements diff --git a/lib/python3.10/site-packages/torch/fx/tensor_type.py b/lib/python3.10/site-packages/torch/fx/tensor_type.py new file mode 100644 index 0000000000000000000000000000000000000000..83b5a9f8faf65e86813d55926ef275fa19ef2013 --- /dev/null +++ b/lib/python3.10/site-packages/torch/fx/tensor_type.py @@ -0,0 +1,105 @@ +# mypy: allow-untyped-defs +from torch.fx.experimental.unification import Var # type: ignore[attr-defined] + +from ._compatibility import compatibility + + +@compatibility(is_backward_compatible=False) +class TensorType: + """ + TensorType defines a type for tensors, which consists of a list of dimensions. 
+ Example: + class M(torch.nn.Module): + def forward(self, x:TensorType((1,2,3, Dyn)), y:TensorType((1,2,3, Dyn))): + return torch.add(x, y) + """ + + def __init__(self, dim): + self.__origin__ = TensorType + self.__args__ = dim + + def __repr__(self): + return f'TensorType[{self.__args__}]' + + def __eq__(self, other): + if isinstance(other, self.__class__): + return list(self.__args__) == list(other.__args__) + else: + return False + + @staticmethod + def __class_getitem__(*args): + if len(args) == 1 and isinstance(args[0], tuple): + args = args[0] + return TensorType(tuple(args)) + + +class _DynType: + """ + _DynType defines a type which stands for the absence of type information. + """ + def __init__(self) -> None: + self.__name__ = '_DynType' + + def __eq__(self, other): + return isinstance(other, self.__class__) + + def __str__(self): + return "Dyn" + + def __repr__(self): + return "Dyn" + + +Dyn = _DynType() + +@compatibility(is_backward_compatible=False) +def is_consistent(t1, t2): + """ + A binary relation denoted by ~ that determines if t1 is consistent with t2. + The relation is reflexive, symmetric but not transitive. + returns True if t1 and t2 are consistent and False otherwise. + Example: + Dyn ~ TensorType((1,2,3)) + int ~ Dyn + int ~ int + TensorType((1,Dyn,3)) ~ TensorType((1,2,3)) + """ + + if t1 == t2: + return True + + if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var): + return True + + if isinstance(t1, TensorType) and isinstance(t2, TensorType): + return len(t1.__args__) == len(t2.__args__) and \ + all(is_consistent(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__)) + else: + return False + + +@compatibility(is_backward_compatible=False) +def is_more_precise(t1, t2): + """ + A binary relation denoted by <= that determines if t1 is more precise than t2. + The relation is reflexive and transitive. + returns True if t1 is more precise than t2 and False otherwise. 
+ Example: + Dyn >= TensorType((1,2,3)) + int >= Dyn + int >= int + TensorType((1,Dyn,3)) <= TensorType((1,2,3)) + """ + if t1 == t2: + return True + + if isinstance(t2, _DynType): + return True + + if isinstance(t1, TensorType) and isinstance(t2, TensorType): + return len(t1.__args__) == len(t2.__args__) and \ + all(is_more_precise(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__)) + + else: + return False diff --git a/lib/python3.10/site-packages/torch/fx/traceback.py b/lib/python3.10/site-packages/torch/fx/traceback.py new file mode 100644 index 0000000000000000000000000000000000000000..4e72a8011f63ab44a64b8ece6ccbe4b2875fdda6 --- /dev/null +++ b/lib/python3.10/site-packages/torch/fx/traceback.py @@ -0,0 +1,104 @@ +# mypy: allow-untyped-defs +import traceback +from contextlib import contextmanager +from typing import List, Any, Dict +from ._compatibility import compatibility + +__all__ = ['preserve_node_meta', 'has_preserved_node_meta', + 'set_stack_trace', 'set_grad_fn_seq_nr', 'reset_grad_fn_seq_nr', + 'format_stack', 'set_current_meta', 'get_current_meta'] + +current_meta: Dict[str, Any] = {} +should_preserve_node_meta = False + + +@compatibility(is_backward_compatible=False) +@contextmanager +def preserve_node_meta(): + global should_preserve_node_meta + global current_meta + + saved_should_preserve_node_meta = should_preserve_node_meta + # Shallow copy is OK since fields of current_meta are not mutated + saved_current_meta = current_meta.copy() + try: + should_preserve_node_meta = True + yield + finally: + should_preserve_node_meta = saved_should_preserve_node_meta + current_meta = saved_current_meta + + +@compatibility(is_backward_compatible=False) +def set_stack_trace(stack : List[str]): + global current_meta + + if should_preserve_node_meta and stack: + current_meta["stack_trace"] = "".join(stack) + + +@compatibility(is_backward_compatible=False) +def set_grad_fn_seq_nr(seq_nr): + global current_meta + + if should_preserve_node_meta: + # The seq_nr is captured by eager mode in the grad_fn during forward + current_meta["grad_fn_seq_nr"] = current_meta.get("grad_fn_seq_nr", []) + [seq_nr] + current_meta["in_grad_fn"] = current_meta.get("in_grad_fn", 0) + 1 + + +@compatibility(is_backward_compatible=False) +def reset_grad_fn_seq_nr(): + # NB: reset state properly, this would be helpful towards supporting + # reentrant autograd if we actually wanted to do that. 
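+ # Each earlier set_grad_fn_seq_nr() call pushed one entry onto "grad_fn_seq_nr" and bumped the + # "in_grad_fn" nesting counter; here we unwind one level and drop both keys at the outermost level.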
+ global current_meta + if should_preserve_node_meta: + current_level = current_meta.get("in_grad_fn", 0) + assert current_level > 0 + if current_level == 1: + del current_meta["in_grad_fn"] + del current_meta["grad_fn_seq_nr"] + else: + current_meta["in_grad_fn"] = current_level - 1 + current_meta["grad_fn_seq_nr"] = current_meta["grad_fn_seq_nr"][:-1] + + +@compatibility(is_backward_compatible=False) +def format_stack() -> List[str]: + if should_preserve_node_meta: + return [current_meta.get("stack_trace", "")] + else: + # fallback to traceback.format_stack() + return traceback.format_list(traceback.extract_stack()[:-1]) + + +@compatibility(is_backward_compatible=False) +def has_preserved_node_meta() -> bool: + return should_preserve_node_meta + + +@compatibility(is_backward_compatible=False) +@contextmanager +def set_current_meta(node): + global current_meta + if should_preserve_node_meta and node.meta: + saved_meta = current_meta + try: + current_meta = node.meta.copy() + + # Append (node.name, node.target) onto "from_node" for provenance tracking + if "from_node" not in current_meta: + current_meta["from_node"] = [(node.name, node.target)] + elif current_meta["from_node"][-1][0] != node.name: + current_meta["from_node"] = current_meta["from_node"] + [(node.name, node.target)] + + yield + finally: + current_meta = saved_meta + else: + yield + + +@compatibility(is_backward_compatible=False) +def get_current_meta() -> Dict[str, Any]: + return current_meta diff --git a/lib/python3.10/site-packages/torch/include/clog.h b/lib/python3.10/site-packages/torch/include/clog.h new file mode 100644 index 0000000000000000000000000000000000000000..bf09cd0cb6de4ff632807ad2e58df9e402906878 --- /dev/null +++ b/lib/python3.10/site-packages/torch/include/clog.h @@ -0,0 +1,123 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#define CLOG_NONE 0 +#define CLOG_FATAL 1 +#define CLOG_ERROR 2 +#define CLOG_WARNING 3 +#define CLOG_INFO 4 +#define CLOG_DEBUG 5 + +#ifndef CLOG_VISIBILITY +#if defined(__ELF__) +#define CLOG_VISIBILITY __attribute__((__visibility__("internal"))) +#elif defined(__MACH__) +#define CLOG_VISIBILITY __attribute__((__visibility__("hidden"))) +#else +#define CLOG_VISIBILITY +#endif +#endif + +#ifndef CLOG_ARGUMENTS_FORMAT +#if defined(__GNUC__) +#define CLOG_ARGUMENTS_FORMAT __attribute__((__format__(__printf__, 1, 2))) +#else +#define CLOG_ARGUMENTS_FORMAT +#endif +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +CLOG_VISIBILITY void clog_vlog_debug( + const char* module, + const char* format, + va_list args); +CLOG_VISIBILITY void clog_vlog_info( + const char* module, + const char* format, + va_list args); +CLOG_VISIBILITY void clog_vlog_warning( + const char* module, + const char* format, + va_list args); +CLOG_VISIBILITY void clog_vlog_error( + const char* module, + const char* format, + va_list args); +CLOG_VISIBILITY void clog_vlog_fatal( + const char* module, + const char* format, + va_list args); + +#define CLOG_DEFINE_LOG_DEBUG(log_debug_function_name, module, level) \ + CLOG_ARGUMENTS_FORMAT \ + inline static void log_debug_function_name(const char* format, ...) 
{ \ + if (level >= CLOG_DEBUG) { \ + va_list args; \ + va_start(args, format); \ + clog_vlog_debug(module, format, args); \ + va_end(args); \ + } \ + } + +#define CLOG_DEFINE_LOG_INFO(log_info_function_name, module, level) \ + CLOG_ARGUMENTS_FORMAT \ + inline static void log_info_function_name(const char* format, ...) { \ + if (level >= CLOG_INFO) { \ + va_list args; \ + va_start(args, format); \ + clog_vlog_info(module, format, args); \ + va_end(args); \ + } \ + } + +#define CLOG_DEFINE_LOG_WARNING(log_warning_function_name, module, level) \ + CLOG_ARGUMENTS_FORMAT \ + inline static void log_warning_function_name(const char* format, ...) { \ + if (level >= CLOG_WARNING) { \ + va_list args; \ + va_start(args, format); \ + clog_vlog_warning(module, format, args); \ + va_end(args); \ + } \ + } + +#define CLOG_DEFINE_LOG_ERROR(log_error_function_name, module, level) \ + CLOG_ARGUMENTS_FORMAT \ + inline static void log_error_function_name(const char* format, ...) { \ + if (level >= CLOG_ERROR) { \ + va_list args; \ + va_start(args, format); \ + clog_vlog_error(module, format, args); \ + va_end(args); \ + } \ + } + +#define CLOG_DEFINE_LOG_FATAL(log_fatal_function_name, module, level) \ + CLOG_ARGUMENTS_FORMAT \ + inline static void log_fatal_function_name(const char* format, ...) { \ + if (level >= CLOG_FATAL) { \ + va_list args; \ + va_start(args, format); \ + clog_vlog_fatal(module, format, args); \ + va_end(args); \ + } \ + abort(); \ + } + +#ifdef __cplusplus +} /* extern "C" */ +#endif diff --git a/lib/python3.10/site-packages/torch/include/cpuinfo.h b/lib/python3.10/site-packages/torch/include/cpuinfo.h new file mode 100644 index 0000000000000000000000000000000000000000..8bb1db4e96470883bcc04a0d803770a3fc093015 --- /dev/null +++ b/lib/python3.10/site-packages/torch/include/cpuinfo.h @@ -0,0 +1,2245 @@ +#pragma once +#ifndef CPUINFO_H +#define CPUINFO_H + +#ifndef __cplusplus +#include +#endif + +#ifdef __APPLE__ +#include +#endif + +#include + +/* Identify architecture and define corresponding macro */ + +#if defined(__i386__) || defined(__i486__) || defined(__i586__) || defined(__i686__) || defined(_M_IX86) +#define CPUINFO_ARCH_X86 1 +#endif + +#if defined(__x86_64__) || defined(__x86_64) || defined(_M_X64) || defined(_M_AMD64) +#define CPUINFO_ARCH_X86_64 1 +#endif + +#if defined(__arm__) || defined(_M_ARM) +#define CPUINFO_ARCH_ARM 1 +#endif + +#if defined(__aarch64__) || defined(_M_ARM64) +#define CPUINFO_ARCH_ARM64 1 +#endif + +#if defined(__PPC64__) || defined(__powerpc64__) || defined(_ARCH_PPC64) +#define CPUINFO_ARCH_PPC64 1 +#endif + +#if defined(__asmjs__) +#define CPUINFO_ARCH_ASMJS 1 +#endif + +#if defined(__wasm__) +#if defined(__wasm_simd128__) +#define CPUINFO_ARCH_WASMSIMD 1 +#else +#define CPUINFO_ARCH_WASM 1 +#endif +#endif + +#if defined(__riscv) +#if (__riscv_xlen == 32) +#define CPUINFO_ARCH_RISCV32 1 +#elif (__riscv_xlen == 64) +#define CPUINFO_ARCH_RISCV64 1 +#endif +#endif + +/* Define other architecture-specific macros as 0 */ + +#ifndef CPUINFO_ARCH_X86 +#define CPUINFO_ARCH_X86 0 +#endif + +#ifndef CPUINFO_ARCH_X86_64 +#define CPUINFO_ARCH_X86_64 0 +#endif + +#ifndef CPUINFO_ARCH_ARM +#define CPUINFO_ARCH_ARM 0 +#endif + +#ifndef CPUINFO_ARCH_ARM64 +#define CPUINFO_ARCH_ARM64 0 +#endif + +#ifndef CPUINFO_ARCH_PPC64 +#define CPUINFO_ARCH_PPC64 0 +#endif + +#ifndef CPUINFO_ARCH_ASMJS +#define CPUINFO_ARCH_ASMJS 0 +#endif + +#ifndef CPUINFO_ARCH_WASM +#define CPUINFO_ARCH_WASM 0 +#endif + +#ifndef CPUINFO_ARCH_WASMSIMD +#define CPUINFO_ARCH_WASMSIMD 0 +#endif + 
+#ifndef CPUINFO_ARCH_RISCV32 +#define CPUINFO_ARCH_RISCV32 0 +#endif + +#ifndef CPUINFO_ARCH_RISCV64 +#define CPUINFO_ARCH_RISCV64 0 +#endif + +#if CPUINFO_ARCH_X86 && defined(_MSC_VER) +#define CPUINFO_ABI __cdecl +#elif CPUINFO_ARCH_X86 && defined(__GNUC__) +#define CPUINFO_ABI __attribute__((__cdecl__)) +#else +#define CPUINFO_ABI +#endif + +#define CPUINFO_CACHE_UNIFIED 0x00000001 +#define CPUINFO_CACHE_INCLUSIVE 0x00000002 +#define CPUINFO_CACHE_COMPLEX_INDEXING 0x00000004 + +struct cpuinfo_cache { + /** Cache size in bytes */ + uint32_t size; + /** Number of ways of associativity */ + uint32_t associativity; + /** Number of sets */ + uint32_t sets; + /** Number of partitions */ + uint32_t partitions; + /** Line size in bytes */ + uint32_t line_size; + /** + * Binary characteristics of the cache (unified cache, inclusive cache, + * cache with complex indexing). + * + * @see CPUINFO_CACHE_UNIFIED, CPUINFO_CACHE_INCLUSIVE, + * CPUINFO_CACHE_COMPLEX_INDEXING + */ + uint32_t flags; + /** Index of the first logical processor that shares this cache */ + uint32_t processor_start; + /** Number of logical processors that share this cache */ + uint32_t processor_count; +}; + +struct cpuinfo_trace_cache { + uint32_t uops; + uint32_t associativity; +}; + +#define CPUINFO_PAGE_SIZE_4KB 0x1000 +#define CPUINFO_PAGE_SIZE_1MB 0x100000 +#define CPUINFO_PAGE_SIZE_2MB 0x200000 +#define CPUINFO_PAGE_SIZE_4MB 0x400000 +#define CPUINFO_PAGE_SIZE_16MB 0x1000000 +#define CPUINFO_PAGE_SIZE_1GB 0x40000000 + +struct cpuinfo_tlb { + uint32_t entries; + uint32_t associativity; + uint64_t pages; +}; + +/** Vendor of processor core design */ +enum cpuinfo_vendor { + /** Processor vendor is not known to the library, or the library failed + to get vendor information from the OS. */ + cpuinfo_vendor_unknown = 0, + + /* Active vendors of modern CPUs */ + + /** + * Intel Corporation. Vendor of x86, x86-64, IA64, and ARM processor + * microarchitectures. + * + * Sold its ARM design subsidiary in 2006. The last ARM processor design + * was released in 2004. + */ + cpuinfo_vendor_intel = 1, + /** Advanced Micro Devices, Inc. Vendor of x86 and x86-64 processor + microarchitectures. */ + cpuinfo_vendor_amd = 2, + /** ARM Holdings plc. Vendor of ARM and ARM64 processor + microarchitectures. */ + cpuinfo_vendor_arm = 3, + /** Qualcomm Incorporated. Vendor of ARM and ARM64 processor + microarchitectures. */ + cpuinfo_vendor_qualcomm = 4, + /** Apple Inc. Vendor of ARM and ARM64 processor microarchitectures. */ + cpuinfo_vendor_apple = 5, + /** Samsung Electronics Co., Ltd. Vendir if ARM64 processor + microarchitectures. */ + cpuinfo_vendor_samsung = 6, + /** Nvidia Corporation. Vendor of ARM64-compatible processor + microarchitectures. */ + cpuinfo_vendor_nvidia = 7, + /** MIPS Technologies, Inc. Vendor of MIPS processor microarchitectures. + */ + cpuinfo_vendor_mips = 8, + /** International Business Machines Corporation. Vendor of PowerPC + processor microarchitectures. */ + cpuinfo_vendor_ibm = 9, + /** Ingenic Semiconductor. Vendor of MIPS processor microarchitectures. + */ + cpuinfo_vendor_ingenic = 10, + /** + * VIA Technologies, Inc. Vendor of x86 and x86-64 processor + * microarchitectures. + * + * Processors are designed by Centaur Technology, a subsidiary of VIA + * Technologies. + */ + cpuinfo_vendor_via = 11, + /** Cavium, Inc. Vendor of ARM64 processor microarchitectures. */ + cpuinfo_vendor_cavium = 12, + /** Broadcom, Inc. Vendor of ARM processor microarchitectures. 
*/ + cpuinfo_vendor_broadcom = 13, + /** Applied Micro Circuits Corporation (APM). Vendor of ARM64 processor + microarchitectures. */ + cpuinfo_vendor_apm = 14, + /** + * Huawei Technologies Co., Ltd. Vendor of ARM64 processor + * microarchitectures. + * + * Processors are designed by HiSilicon, a subsidiary of Huawei. + */ + cpuinfo_vendor_huawei = 15, + /** + * Hygon (Chengdu Haiguang Integrated Circuit Design Co., Ltd), Vendor + * of x86-64 processor microarchitectures. + * + * Processors are variants of AMD cores. + */ + cpuinfo_vendor_hygon = 16, + /** SiFive, Inc. Vendor of RISC-V processor microarchitectures. */ + cpuinfo_vendor_sifive = 17, + + /* Active vendors of embedded CPUs */ + + /** Texas Instruments Inc. Vendor of ARM processor microarchitectures. + */ + cpuinfo_vendor_texas_instruments = 30, + /** Marvell Technology Group Ltd. Vendor of ARM processor + * microarchitectures. + */ + cpuinfo_vendor_marvell = 31, + /** RDC Semiconductor Co., Ltd. Vendor of x86 processor + microarchitectures. */ + cpuinfo_vendor_rdc = 32, + /** DM&P Electronics Inc. Vendor of x86 processor microarchitectures. */ + cpuinfo_vendor_dmp = 33, + /** Motorola, Inc. Vendor of PowerPC and ARM processor + microarchitectures. */ + cpuinfo_vendor_motorola = 34, + + /* Defunct CPU vendors */ + + /** + * Transmeta Corporation. Vendor of x86 processor microarchitectures. + * + * Now defunct. The last processor design was released in 2004. + * Transmeta processors implemented VLIW ISA and used binary translation + * to execute x86 code. + */ + cpuinfo_vendor_transmeta = 50, + /** + * Cyrix Corporation. Vendor of x86 processor microarchitectures. + * + * Now defunct. The last processor design was released in 1996. + */ + cpuinfo_vendor_cyrix = 51, + /** + * Rise Technology. Vendor of x86 processor microarchitectures. + * + * Now defunct. The last processor design was released in 1999. + */ + cpuinfo_vendor_rise = 52, + /** + * National Semiconductor. Vendor of x86 processor microarchitectures. + * + * Sold its x86 design subsidiary in 1999. The last processor design was + * released in 1998. + */ + cpuinfo_vendor_nsc = 53, + /** + * Silicon Integrated Systems. Vendor of x86 processor + * microarchitectures. + * + * Sold its x86 design subsidiary in 2001. The last processor design was + * released in 2001. + */ + cpuinfo_vendor_sis = 54, + /** + * NexGen. Vendor of x86 processor microarchitectures. + * + * Now defunct. The last processor design was released in 1994. + * NexGen designed the first x86 microarchitecture which decomposed x86 + * instructions into simple microoperations. + */ + cpuinfo_vendor_nexgen = 55, + /** + * United Microelectronics Corporation. Vendor of x86 processor + * microarchitectures. + * + * Ceased x86 in the early 1990s. The last processor design was released + * in 1991. Designed U5C and U5D processors. Both are 486 level. + */ + cpuinfo_vendor_umc = 56, + /** + * Digital Equipment Corporation. Vendor of ARM processor + * microarchitecture. + * + * Sold its ARM designs in 1997. The last processor design was released + * in 1997. + */ + cpuinfo_vendor_dec = 57, +}; + +/** + * Processor microarchitecture + * + * Processors with different microarchitectures often have different instruction + * performance characteristics, and may have dramatically different pipeline + * organization. 
+ */ +enum cpuinfo_uarch { + /** Microarchitecture is unknown, or the library failed to get + information about the microarchitecture from OS */ + cpuinfo_uarch_unknown = 0, + + /** Pentium and Pentium MMX microarchitecture. */ + cpuinfo_uarch_p5 = 0x00100100, + /** Intel Quark microarchitecture. */ + cpuinfo_uarch_quark = 0x00100101, + + /** Pentium Pro, Pentium II, and Pentium III. */ + cpuinfo_uarch_p6 = 0x00100200, + /** Pentium M. */ + cpuinfo_uarch_dothan = 0x00100201, + /** Intel Core microarchitecture. */ + cpuinfo_uarch_yonah = 0x00100202, + /** Intel Core 2 microarchitecture on 65 nm process. */ + cpuinfo_uarch_conroe = 0x00100203, + /** Intel Core 2 microarchitecture on 45 nm process. */ + cpuinfo_uarch_penryn = 0x00100204, + /** Intel Nehalem and Westmere microarchitectures (Core i3/i5/i7 1st + gen). */ + cpuinfo_uarch_nehalem = 0x00100205, + /** Intel Sandy Bridge microarchitecture (Core i3/i5/i7 2nd gen). */ + cpuinfo_uarch_sandy_bridge = 0x00100206, + /** Intel Ivy Bridge microarchitecture (Core i3/i5/i7 3rd gen). */ + cpuinfo_uarch_ivy_bridge = 0x00100207, + /** Intel Haswell microarchitecture (Core i3/i5/i7 4th gen). */ + cpuinfo_uarch_haswell = 0x00100208, + /** Intel Broadwell microarchitecture. */ + cpuinfo_uarch_broadwell = 0x00100209, + /** Intel Sky Lake microarchitecture (14 nm, including + Kaby/Coffee/Whiskey/Amber/Comet/Cascade/Cooper Lake). */ + cpuinfo_uarch_sky_lake = 0x0010020A, + /** DEPRECATED (Intel Kaby Lake microarchitecture). */ + cpuinfo_uarch_kaby_lake = 0x0010020A, + /** Intel Palm Cove microarchitecture (10 nm, Cannon Lake). */ + cpuinfo_uarch_palm_cove = 0x0010020B, + /** Intel Sunny Cove microarchitecture (10 nm, Ice Lake). */ + cpuinfo_uarch_sunny_cove = 0x0010020C, + + /** Pentium 4 with Willamette, Northwood, or Foster cores. */ + cpuinfo_uarch_willamette = 0x00100300, + /** Pentium 4 with Prescott and later cores. */ + cpuinfo_uarch_prescott = 0x00100301, + + /** Intel Atom on 45 nm process. */ + cpuinfo_uarch_bonnell = 0x00100400, + /** Intel Atom on 32 nm process. */ + cpuinfo_uarch_saltwell = 0x00100401, + /** Intel Silvermont microarchitecture (22 nm out-of-order Atom). */ + cpuinfo_uarch_silvermont = 0x00100402, + /** Intel Airmont microarchitecture (14 nm out-of-order Atom). */ + cpuinfo_uarch_airmont = 0x00100403, + /** Intel Goldmont microarchitecture (Denverton, Apollo Lake). */ + cpuinfo_uarch_goldmont = 0x00100404, + /** Intel Goldmont Plus microarchitecture (Gemini Lake). */ + cpuinfo_uarch_goldmont_plus = 0x00100405, + + /** Intel Knights Ferry HPC boards. */ + cpuinfo_uarch_knights_ferry = 0x00100500, + /** Intel Knights Corner HPC boards (aka Xeon Phi). */ + cpuinfo_uarch_knights_corner = 0x00100501, + /** Intel Knights Landing microarchitecture (second-gen MIC). */ + cpuinfo_uarch_knights_landing = 0x00100502, + /** Intel Knights Hill microarchitecture (third-gen MIC). */ + cpuinfo_uarch_knights_hill = 0x00100503, + /** Intel Knights Mill Xeon Phi. */ + cpuinfo_uarch_knights_mill = 0x00100504, + + /** Intel/Marvell XScale series. */ + cpuinfo_uarch_xscale = 0x00100600, + + /** AMD K5. */ + cpuinfo_uarch_k5 = 0x00200100, + /** AMD K6 and alike. */ + cpuinfo_uarch_k6 = 0x00200101, + /** AMD Athlon and Duron. */ + cpuinfo_uarch_k7 = 0x00200102, + /** AMD Athlon 64, Opteron 64. */ + cpuinfo_uarch_k8 = 0x00200103, + /** AMD Family 10h (Barcelona, Istambul, Magny-Cours). */ + cpuinfo_uarch_k10 = 0x00200104, + /** + * AMD Bulldozer microarchitecture + * Zambezi FX-series CPUs, Zurich, Valencia and Interlagos Opteron CPUs. 
+ */ + cpuinfo_uarch_bulldozer = 0x00200105, + /** + * AMD Piledriver microarchitecture + * Vishera FX-series CPUs, Trinity and Richland APUs, Delhi, Seoul, Abu + * Dhabi Opteron CPUs. + */ + cpuinfo_uarch_piledriver = 0x00200106, + /** AMD Steamroller microarchitecture (Kaveri APUs). */ + cpuinfo_uarch_steamroller = 0x00200107, + /** AMD Excavator microarchitecture (Carizzo APUs). */ + cpuinfo_uarch_excavator = 0x00200108, + /** AMD Zen microarchitecture (12/14 nm Ryzen and EPYC CPUs). */ + cpuinfo_uarch_zen = 0x00200109, + /** AMD Zen 2 microarchitecture (7 nm Ryzen and EPYC CPUs). */ + cpuinfo_uarch_zen2 = 0x0020010A, + /** AMD Zen 3 microarchitecture. */ + cpuinfo_uarch_zen3 = 0x0020010B, + /** AMD Zen 4 microarchitecture. */ + cpuinfo_uarch_zen4 = 0x0020010C, + + /** NSC Geode and AMD Geode GX and LX. */ + cpuinfo_uarch_geode = 0x00200200, + /** AMD Bobcat mobile microarchitecture. */ + cpuinfo_uarch_bobcat = 0x00200201, + /** AMD Jaguar mobile microarchitecture. */ + cpuinfo_uarch_jaguar = 0x00200202, + /** AMD Puma mobile microarchitecture. */ + cpuinfo_uarch_puma = 0x00200203, + + /** ARM7 series. */ + cpuinfo_uarch_arm7 = 0x00300100, + /** ARM9 series. */ + cpuinfo_uarch_arm9 = 0x00300101, + /** ARM 1136, ARM 1156, ARM 1176, or ARM 11MPCore. */ + cpuinfo_uarch_arm11 = 0x00300102, + + /** ARM Cortex-A5. */ + cpuinfo_uarch_cortex_a5 = 0x00300205, + /** ARM Cortex-A7. */ + cpuinfo_uarch_cortex_a7 = 0x00300207, + /** ARM Cortex-A8. */ + cpuinfo_uarch_cortex_a8 = 0x00300208, + /** ARM Cortex-A9. */ + cpuinfo_uarch_cortex_a9 = 0x00300209, + /** ARM Cortex-A12. */ + cpuinfo_uarch_cortex_a12 = 0x00300212, + /** ARM Cortex-A15. */ + cpuinfo_uarch_cortex_a15 = 0x00300215, + /** ARM Cortex-A17. */ + cpuinfo_uarch_cortex_a17 = 0x00300217, + + /** ARM Cortex-A32. */ + cpuinfo_uarch_cortex_a32 = 0x00300332, + /** ARM Cortex-A35. */ + cpuinfo_uarch_cortex_a35 = 0x00300335, + /** ARM Cortex-A53. */ + cpuinfo_uarch_cortex_a53 = 0x00300353, + /** ARM Cortex-A55 revision 0 (restricted dual-issue capabilities + compared to revision 1+). */ + cpuinfo_uarch_cortex_a55r0 = 0x00300354, + /** ARM Cortex-A55. */ + cpuinfo_uarch_cortex_a55 = 0x00300355, + /** ARM Cortex-A57. */ + cpuinfo_uarch_cortex_a57 = 0x00300357, + /** ARM Cortex-A65. */ + cpuinfo_uarch_cortex_a65 = 0x00300365, + /** ARM Cortex-A72. */ + cpuinfo_uarch_cortex_a72 = 0x00300372, + /** ARM Cortex-A73. */ + cpuinfo_uarch_cortex_a73 = 0x00300373, + /** ARM Cortex-A75. */ + cpuinfo_uarch_cortex_a75 = 0x00300375, + /** ARM Cortex-A76. */ + cpuinfo_uarch_cortex_a76 = 0x00300376, + /** ARM Cortex-A77. */ + cpuinfo_uarch_cortex_a77 = 0x00300377, + /** ARM Cortex-A78. */ + cpuinfo_uarch_cortex_a78 = 0x00300378, + + /** ARM Neoverse N1. */ + cpuinfo_uarch_neoverse_n1 = 0x00300400, + /** ARM Neoverse E1. */ + cpuinfo_uarch_neoverse_e1 = 0x00300401, + /** ARM Neoverse V1. */ + cpuinfo_uarch_neoverse_v1 = 0x00300402, + /** ARM Neoverse N2. */ + cpuinfo_uarch_neoverse_n2 = 0x00300403, + /** ARM Neoverse V2. */ + cpuinfo_uarch_neoverse_v2 = 0x00300404, + + /** ARM Cortex-X1. */ + cpuinfo_uarch_cortex_x1 = 0x00300501, + /** ARM Cortex-X2. */ + cpuinfo_uarch_cortex_x2 = 0x00300502, + /** ARM Cortex-X3. */ + cpuinfo_uarch_cortex_x3 = 0x00300503, + /** ARM Cortex-X4. */ + cpuinfo_uarch_cortex_x4 = 0x00300504, + + /** ARM Cortex-A510. */ + cpuinfo_uarch_cortex_a510 = 0x00300551, + /** ARM Cortex-A520. */ + cpuinfo_uarch_cortex_a520 = 0x00300552, + /** ARM Cortex-A710. */ + cpuinfo_uarch_cortex_a710 = 0x00300571, + /** ARM Cortex-A715. 
*/ + cpuinfo_uarch_cortex_a715 = 0x00300572, + /** ARM Cortex-A720. */ + cpuinfo_uarch_cortex_a720 = 0x00300573, + + /** Qualcomm Scorpion. */ + cpuinfo_uarch_scorpion = 0x00400100, + /** Qualcomm Krait. */ + cpuinfo_uarch_krait = 0x00400101, + /** Qualcomm Kryo. */ + cpuinfo_uarch_kryo = 0x00400102, + /** Qualcomm Falkor. */ + cpuinfo_uarch_falkor = 0x00400103, + /** Qualcomm Saphira. */ + cpuinfo_uarch_saphira = 0x00400104, + + /** Nvidia Denver. */ + cpuinfo_uarch_denver = 0x00500100, + /** Nvidia Denver 2. */ + cpuinfo_uarch_denver2 = 0x00500101, + /** Nvidia Carmel. */ + cpuinfo_uarch_carmel = 0x00500102, + + /** Samsung Exynos M1 (Exynos 8890 big cores). */ + cpuinfo_uarch_exynos_m1 = 0x00600100, + /** Samsung Exynos M2 (Exynos 8895 big cores). */ + cpuinfo_uarch_exynos_m2 = 0x00600101, + /** Samsung Exynos M3 (Exynos 9810 big cores). */ + cpuinfo_uarch_exynos_m3 = 0x00600102, + /** Samsung Exynos M4 (Exynos 9820 big cores). */ + cpuinfo_uarch_exynos_m4 = 0x00600103, + /** Samsung Exynos M5 (Exynos 9830 big cores). */ + cpuinfo_uarch_exynos_m5 = 0x00600104, + + /* Deprecated synonym for Cortex-A76 */ + cpuinfo_uarch_cortex_a76ae = 0x00300376, + /* Deprecated names for Exynos. */ + cpuinfo_uarch_mongoose_m1 = 0x00600100, + cpuinfo_uarch_mongoose_m2 = 0x00600101, + cpuinfo_uarch_meerkat_m3 = 0x00600102, + cpuinfo_uarch_meerkat_m4 = 0x00600103, + + /** Apple A6 and A6X processors. */ + cpuinfo_uarch_swift = 0x00700100, + /** Apple A7 processor. */ + cpuinfo_uarch_cyclone = 0x00700101, + /** Apple A8 and A8X processor. */ + cpuinfo_uarch_typhoon = 0x00700102, + /** Apple A9 and A9X processor. */ + cpuinfo_uarch_twister = 0x00700103, + /** Apple A10 and A10X processor. */ + cpuinfo_uarch_hurricane = 0x00700104, + /** Apple A11 processor (big cores). */ + cpuinfo_uarch_monsoon = 0x00700105, + /** Apple A11 processor (little cores). */ + cpuinfo_uarch_mistral = 0x00700106, + /** Apple A12 processor (big cores). */ + cpuinfo_uarch_vortex = 0x00700107, + /** Apple A12 processor (little cores). */ + cpuinfo_uarch_tempest = 0x00700108, + /** Apple A13 processor (big cores). */ + cpuinfo_uarch_lightning = 0x00700109, + /** Apple A13 processor (little cores). */ + cpuinfo_uarch_thunder = 0x0070010A, + /** Apple A14 / M1 processor (big cores). */ + cpuinfo_uarch_firestorm = 0x0070010B, + /** Apple A14 / M1 processor (little cores). */ + cpuinfo_uarch_icestorm = 0x0070010C, + /** Apple A15 / M2 processor (big cores). */ + cpuinfo_uarch_avalanche = 0x0070010D, + /** Apple A15 / M2 processor (little cores). */ + cpuinfo_uarch_blizzard = 0x0070010E, + + /** Cavium ThunderX. */ + cpuinfo_uarch_thunderx = 0x00800100, + /** Cavium ThunderX2 (originally Broadcom Vulkan). */ + cpuinfo_uarch_thunderx2 = 0x00800200, + + /** Marvell PJ4. */ + cpuinfo_uarch_pj4 = 0x00900100, + + /** Broadcom Brahma B15. */ + cpuinfo_uarch_brahma_b15 = 0x00A00100, + /** Broadcom Brahma B53. */ + cpuinfo_uarch_brahma_b53 = 0x00A00101, + + /** Applied Micro X-Gene. */ + cpuinfo_uarch_xgene = 0x00B00100, + + /* Hygon Dhyana (a modification of AMD Zen for Chinese market). */ + cpuinfo_uarch_dhyana = 0x01000100, + + /** HiSilicon TaiShan v110 (Huawei Kunpeng 920 series processors). 
*/ + cpuinfo_uarch_taishan_v110 = 0x00C00100, +}; + +struct cpuinfo_processor { + /** SMT (hyperthread) ID within a core */ + uint32_t smt_id; + /** Core containing this logical processor */ + const struct cpuinfo_core* core; + /** Cluster of cores containing this logical processor */ + const struct cpuinfo_cluster* cluster; + /** Physical package containing this logical processor */ + const struct cpuinfo_package* package; +#if defined(__linux__) + /** + * Linux-specific ID for the logical processor: + * - Linux kernel exposes information about this logical processor in + * /sys/devices/system/cpu/cpu/ + * - Bit in the cpu_set_t identifies this logical processor + */ + int linux_id; +#endif +#if defined(_WIN32) || defined(__CYGWIN__) + /** Windows-specific ID for the group containing the logical processor. + */ + uint16_t windows_group_id; + /** + * Windows-specific ID of the logical processor within its group: + * - Bit in the KAFFINITY mask identifies this + * logical processor within its group. + */ + uint16_t windows_processor_id; +#endif +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + /** APIC ID (unique x86-specific ID of the logical processor) */ + uint32_t apic_id; +#endif + struct { + /** Level 1 instruction cache */ + const struct cpuinfo_cache* l1i; + /** Level 1 data cache */ + const struct cpuinfo_cache* l1d; + /** Level 2 unified or data cache */ + const struct cpuinfo_cache* l2; + /** Level 3 unified or data cache */ + const struct cpuinfo_cache* l3; + /** Level 4 unified or data cache */ + const struct cpuinfo_cache* l4; + } cache; +}; + +struct cpuinfo_core { + /** Index of the first logical processor on this core. */ + uint32_t processor_start; + /** Number of logical processors on this core */ + uint32_t processor_count; + /** Core ID within a package */ + uint32_t core_id; + /** Cluster containing this core */ + const struct cpuinfo_cluster* cluster; + /** Physical package containing this core. 
*/ + const struct cpuinfo_package* package; + /** Vendor of the CPU microarchitecture for this core */ + enum cpuinfo_vendor vendor; + /** CPU microarchitecture for this core */ + enum cpuinfo_uarch uarch; +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + /** Value of CPUID leaf 1 EAX register for this core */ + uint32_t cpuid; +#elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + /** Value of Main ID Register (MIDR) for this core */ + uint32_t midr; +#endif + /** Clock rate (non-Turbo) of the core, in Hz */ + uint64_t frequency; +}; + +struct cpuinfo_cluster { + /** Index of the first logical processor in the cluster */ + uint32_t processor_start; + /** Number of logical processors in the cluster */ + uint32_t processor_count; + /** Index of the first core in the cluster */ + uint32_t core_start; + /** Number of cores on the cluster */ + uint32_t core_count; + /** Cluster ID within a package */ + uint32_t cluster_id; + /** Physical package containing the cluster */ + const struct cpuinfo_package* package; + /** CPU microarchitecture vendor of the cores in the cluster */ + enum cpuinfo_vendor vendor; + /** CPU microarchitecture of the cores in the cluster */ + enum cpuinfo_uarch uarch; +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + /** Value of CPUID leaf 1 EAX register of the cores in the cluster */ + uint32_t cpuid; +#elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + /** Value of Main ID Register (MIDR) of the cores in the cluster */ + uint32_t midr; +#endif + /** Clock rate (non-Turbo) of the cores in the cluster, in Hz */ + uint64_t frequency; +}; + +#define CPUINFO_PACKAGE_NAME_MAX 48 + +struct cpuinfo_package { + /** SoC or processor chip model name */ + char name[CPUINFO_PACKAGE_NAME_MAX]; + /** Index of the first logical processor on this physical package */ + uint32_t processor_start; + /** Number of logical processors on this physical package */ + uint32_t processor_count; + /** Index of the first core on this physical package */ + uint32_t core_start; + /** Number of cores on this physical package */ + uint32_t core_count; + /** Index of the first cluster of cores on this physical package */ + uint32_t cluster_start; + /** Number of clusters of cores on this physical package */ + uint32_t cluster_count; +}; + +struct cpuinfo_uarch_info { + /** Type of CPU microarchitecture */ + enum cpuinfo_uarch uarch; +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + /** Value of CPUID leaf 1 EAX register for the microarchitecture */ + uint32_t cpuid; +#elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + /** Value of Main ID Register (MIDR) for the microarchitecture */ + uint32_t midr; +#endif + /** Number of logical processors with the microarchitecture */ + uint32_t processor_count; + /** Number of cores with the microarchitecture */ + uint32_t core_count; +}; + +#ifdef __cplusplus +extern "C" { +#endif + +bool CPUINFO_ABI cpuinfo_initialize(void); + +void CPUINFO_ABI cpuinfo_deinitialize(void); + +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 +/* This structure is not a part of stable API. Use cpuinfo_has_x86_* functions + * instead. 
*/ +struct cpuinfo_x86_isa { +#if CPUINFO_ARCH_X86 + bool rdtsc; +#endif + bool rdtscp; + bool rdpid; + bool sysenter; +#if CPUINFO_ARCH_X86 + bool syscall; +#endif + bool msr; + bool clzero; + bool clflush; + bool clflushopt; + bool mwait; + bool mwaitx; +#if CPUINFO_ARCH_X86 + bool emmx; +#endif + bool fxsave; + bool xsave; +#if CPUINFO_ARCH_X86 + bool fpu; + bool mmx; + bool mmx_plus; +#endif + bool three_d_now; + bool three_d_now_plus; +#if CPUINFO_ARCH_X86 + bool three_d_now_geode; +#endif + bool prefetch; + bool prefetchw; + bool prefetchwt1; +#if CPUINFO_ARCH_X86 + bool daz; + bool sse; + bool sse2; +#endif + bool sse3; + bool ssse3; + bool sse4_1; + bool sse4_2; + bool sse4a; + bool misaligned_sse; + bool avx; + bool avxvnni; + bool fma3; + bool fma4; + bool xop; + bool f16c; + bool avx2; + bool avx512f; + bool avx512pf; + bool avx512er; + bool avx512cd; + bool avx512dq; + bool avx512bw; + bool avx512vl; + bool avx512ifma; + bool avx512vbmi; + bool avx512vbmi2; + bool avx512bitalg; + bool avx512vpopcntdq; + bool avx512vnni; + bool avx512bf16; + bool avx512fp16; + bool avx512vp2intersect; + bool avx512_4vnniw; + bool avx512_4fmaps; + bool amx_bf16; + bool amx_tile; + bool amx_int8; + bool amx_fp16; + bool avx_vnni_int8; + bool avx_vnni_int16; + bool avx_ne_convert; + bool hle; + bool rtm; + bool xtest; + bool mpx; +#if CPUINFO_ARCH_X86 + bool cmov; + bool cmpxchg8b; +#endif + bool cmpxchg16b; + bool clwb; + bool movbe; +#if CPUINFO_ARCH_X86_64 + bool lahf_sahf; +#endif + bool fs_gs_base; + bool lzcnt; + bool popcnt; + bool tbm; + bool bmi; + bool bmi2; + bool adx; + bool aes; + bool vaes; + bool pclmulqdq; + bool vpclmulqdq; + bool gfni; + bool rdrand; + bool rdseed; + bool sha; + bool rng; + bool ace; + bool ace2; + bool phe; + bool pmm; + bool lwp; +}; + +extern struct cpuinfo_x86_isa cpuinfo_isa; +#endif + +static inline bool cpuinfo_has_x86_rdtsc(void) { +#if CPUINFO_ARCH_X86_64 + return true; +#elif CPUINFO_ARCH_X86 +#if defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.rdtsc; +#endif +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_rdtscp(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.rdtscp; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_rdpid(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.rdpid; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_clzero(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.clzero; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_mwait(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.mwait; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_mwaitx(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.mwaitx; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_fxsave(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.fxsave; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_xsave(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.xsave; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_fpu(void) { +#if CPUINFO_ARCH_X86_64 + return true; +#elif CPUINFO_ARCH_X86 +#if defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.fpu; +#endif +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_mmx(void) { +#if CPUINFO_ARCH_X86_64 + return true; +#elif CPUINFO_ARCH_X86 +#if 
defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.mmx; +#endif +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_mmx_plus(void) { +#if CPUINFO_ARCH_X86_64 + return true; +#elif CPUINFO_ARCH_X86 +#if defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.mmx_plus; +#endif +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_3dnow(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.three_d_now; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_3dnow_plus(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.three_d_now_plus; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_3dnow_geode(void) { +#if CPUINFO_ARCH_X86_64 + return false; +#elif CPUINFO_ARCH_X86 +#if defined(__ANDROID__) + return false; +#else + return cpuinfo_isa.three_d_now_geode; +#endif +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_prefetch(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.prefetch; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_prefetchw(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.prefetchw; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_prefetchwt1(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.prefetchwt1; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_daz(void) { +#if CPUINFO_ARCH_X86_64 + return true; +#elif CPUINFO_ARCH_X86 +#if defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.daz; +#endif +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_sse(void) { +#if CPUINFO_ARCH_X86_64 + return true; +#elif CPUINFO_ARCH_X86 +#if defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.sse; +#endif +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_sse2(void) { +#if CPUINFO_ARCH_X86_64 + return true; +#elif CPUINFO_ARCH_X86 +#if defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.sse2; +#endif +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_sse3(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 +#if defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.sse3; +#endif +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_ssse3(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 +#if defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.ssse3; +#endif +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_sse4_1(void) { +#if CPUINFO_ARCH_X86_64 +#if defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.sse4_1; +#endif +#elif CPUINFO_ARCH_X86 + return cpuinfo_isa.sse4_1; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_sse4_2(void) { +#if CPUINFO_ARCH_X86_64 +#if defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.sse4_2; +#endif +#elif CPUINFO_ARCH_X86 + return cpuinfo_isa.sse4_2; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_sse4a(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.sse4a; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_misaligned_sse(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.misaligned_sse; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx; +#else + return false; 
+#endif +} + +static inline bool cpuinfo_has_x86_avxvnni(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avxvnni; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_fma3(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.fma3; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_fma4(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.fma4; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_xop(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.xop; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_f16c(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.f16c; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx2(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx2; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512f(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512f; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512pf(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512pf; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512er(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512er; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512cd(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512cd; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512dq(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512dq; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512bw(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512bw; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512vl(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512vl; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512ifma(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512ifma; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512vbmi(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512vbmi; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512vbmi2(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512vbmi2; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512bitalg(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512bitalg; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512vpopcntdq(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512vpopcntdq; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512vnni(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512vnni; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512bf16(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512bf16; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512fp16(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512fp16; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512vp2intersect(void) { +#if 
CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512vp2intersect; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512_4vnniw(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512_4vnniw; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512_4fmaps(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512_4fmaps; +#else + return false; +#endif +} + +/* [NOTE] Intel Advanced Matrix Extensions (AMX) detection + * + * I. AMX is a new extensions to the x86 ISA to work on matrices, consists of + * 1) 2-dimentional registers (tiles), hold sub-matrices from larger matrices in memory + * 2) Accelerator called Tile Matrix Multiply (TMUL), contains instructions operating on tiles + * + * II. Platforms that supports AMX: + * +-----------------+-----+----------+----------+----------+----------+ + * | Platforms | Gen | amx-bf16 | amx-tile | amx-int8 | amx-fp16 | + * +-----------------+-----+----------+----------+----------+----------+ + * | Sapphire Rapids | 4th | YES | YES | YES | NO | + * +-----------------+-----+----------+----------+----------+----------+ + * | Emerald Rapids | 5th | YES | YES | YES | NO | + * +-----------------+-----+----------+----------+----------+----------+ + * | Granite Rapids | 6th | YES | YES | YES | YES | + * +-----------------+-----+----------+----------+----------+----------+ + * + * Reference: https://www.intel.com/content/www/us/en/products/docs + * /accelerator-engines/advanced-matrix-extensions/overview.html + */ +static inline bool cpuinfo_has_x86_amx_bf16(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.amx_bf16; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_amx_tile(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.amx_tile; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_amx_int8(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.amx_int8; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_amx_fp16(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.amx_fp16; +#else + return false; +#endif +} + +/* + * Intel AVX Vector Neural Network Instructions (VNNI) INT8 + * Supported Platfroms: Sierra Forest, Arrow Lake, Lunar Lake + */ +static inline bool cpuinfo_has_x86_avx_vnni_int8(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx_vnni_int8; +#else + return false; +#endif +} + +/* + * Intel AVX Vector Neural Network Instructions (VNNI) INT16 + * Supported Platfroms: Arrow Lake, Lunar Lake + */ +static inline bool cpuinfo_has_x86_avx_vnni_int16(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx_vnni_int16; +#else + return false; +#endif +} + +/* + * A new set of instructions, which can convert low precision floating point + * like BF16/FP16 to high precision floating point FP32, as well as convert FP32 + * elements to BF16. This instruction allows the platform to have improved AI + * capabilities and better compatibility. 
+ * + * Supported Platforms: Sierra Forest, Arrow Lake, Lunar Lake + */ +static inline bool cpuinfo_has_x86_avx_ne_convert(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx_ne_convert; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_hle(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.hle; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_rtm(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.rtm; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_xtest(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.xtest; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_mpx(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.mpx; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_cmov(void) { +#if CPUINFO_ARCH_X86_64 + return true; +#elif CPUINFO_ARCH_X86 + return cpuinfo_isa.cmov; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_cmpxchg8b(void) { +#if CPUINFO_ARCH_X86_64 + return true; +#elif CPUINFO_ARCH_X86 + return cpuinfo_isa.cmpxchg8b; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_cmpxchg16b(void) { +#if CPUINFO_ARCH_X86_64 + return cpuinfo_isa.cmpxchg16b; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_clwb(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.clwb; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_movbe(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.movbe; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_lahf_sahf(void) { +#if CPUINFO_ARCH_X86 + return true; +#elif CPUINFO_ARCH_X86_64 + return cpuinfo_isa.lahf_sahf; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_lzcnt(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.lzcnt; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_popcnt(void) { +#if CPUINFO_ARCH_X86_64 +#if defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.popcnt; +#endif +#elif CPUINFO_ARCH_X86 + return cpuinfo_isa.popcnt; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_tbm(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.tbm; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_bmi(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.bmi; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_bmi2(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.bmi2; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_adx(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.adx; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_aes(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.aes; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_vaes(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.vaes; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_pclmulqdq(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.pclmulqdq; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_vpclmulqdq(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.vpclmulqdq; +#else + return false; 
+#endif +} + +static inline bool cpuinfo_has_x86_gfni(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.gfni; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_rdrand(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.rdrand; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_rdseed(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.rdseed; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_sha(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.sha; +#else + return false; +#endif +} + +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 +/* This structure is not a part of stable API. Use cpuinfo_has_arm_* functions + * instead. */ +struct cpuinfo_arm_isa { +#if CPUINFO_ARCH_ARM + bool thumb; + bool thumb2; + bool thumbee; + bool jazelle; + bool armv5e; + bool armv6; + bool armv6k; + bool armv7; + bool armv7mp; + bool armv8; + bool idiv; + + bool vfpv2; + bool vfpv3; + bool d32; + bool fp16; + bool fma; + + bool wmmx; + bool wmmx2; + bool neon; +#endif +#if CPUINFO_ARCH_ARM64 + bool atomics; + bool bf16; + bool sve; + bool sve2; + bool i8mm; + bool sme; + uint32_t svelen; +#endif + bool rdm; + bool fp16arith; + bool dot; + bool jscvt; + bool fcma; + bool fhm; + + bool aes; + bool sha1; + bool sha2; + bool pmull; + bool crc32; +}; + +extern struct cpuinfo_arm_isa cpuinfo_isa; +#endif + +static inline bool cpuinfo_has_arm_thumb(void) { +#if CPUINFO_ARCH_ARM + return cpuinfo_isa.thumb; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_thumb2(void) { +#if CPUINFO_ARCH_ARM + return cpuinfo_isa.thumb2; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_v5e(void) { +#if CPUINFO_ARCH_ARM + return cpuinfo_isa.armv5e; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_v6(void) { +#if CPUINFO_ARCH_ARM + return cpuinfo_isa.armv6; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_v6k(void) { +#if CPUINFO_ARCH_ARM + return cpuinfo_isa.armv6k; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_v7(void) { +#if CPUINFO_ARCH_ARM + return cpuinfo_isa.armv7; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_v7mp(void) { +#if CPUINFO_ARCH_ARM + return cpuinfo_isa.armv7mp; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_v8(void) { +#if CPUINFO_ARCH_ARM64 + return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.armv8; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_idiv(void) { +#if CPUINFO_ARCH_ARM64 + return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.idiv; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_vfpv2(void) { +#if CPUINFO_ARCH_ARM + return cpuinfo_isa.vfpv2; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_vfpv3(void) { +#if CPUINFO_ARCH_ARM64 + return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.vfpv3; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_vfpv3_d32(void) { +#if CPUINFO_ARCH_ARM64 + return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.vfpv3 && cpuinfo_isa.d32; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_vfpv3_fp16(void) { +#if CPUINFO_ARCH_ARM64 + return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.vfpv3 && cpuinfo_isa.fp16; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_vfpv3_fp16_d32(void) { +#if CPUINFO_ARCH_ARM64 + 
return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.vfpv3 && cpuinfo_isa.fp16 && cpuinfo_isa.d32; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_vfpv4(void) { +#if CPUINFO_ARCH_ARM64 + return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.vfpv3 && cpuinfo_isa.fma; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_vfpv4_d32(void) { +#if CPUINFO_ARCH_ARM64 + return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.vfpv3 && cpuinfo_isa.fma && cpuinfo_isa.d32; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_fp16_arith(void) { +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + return cpuinfo_isa.fp16arith; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_bf16(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.bf16; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_wmmx(void) { +#if CPUINFO_ARCH_ARM + return cpuinfo_isa.wmmx; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_wmmx2(void) { +#if CPUINFO_ARCH_ARM + return cpuinfo_isa.wmmx2; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_neon(void) { +#if CPUINFO_ARCH_ARM64 + return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.neon; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_neon_fp16(void) { +#if CPUINFO_ARCH_ARM64 + return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.neon && cpuinfo_isa.fp16; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_neon_fma(void) { +#if CPUINFO_ARCH_ARM64 + return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.neon && cpuinfo_isa.fma; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_neon_v8(void) { +#if CPUINFO_ARCH_ARM64 + return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.neon && cpuinfo_isa.armv8; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_atomics(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.atomics; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_neon_rdm(void) { +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + return cpuinfo_isa.rdm; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_neon_fp16_arith(void) { +#if CPUINFO_ARCH_ARM + return cpuinfo_isa.neon && cpuinfo_isa.fp16arith; +#elif CPUINFO_ARCH_ARM64 + return cpuinfo_isa.fp16arith; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_fhm(void) { +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + return cpuinfo_isa.fhm; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_neon_dot(void) { +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + return cpuinfo_isa.dot; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_neon_bf16(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.bf16; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_jscvt(void) { +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + return cpuinfo_isa.jscvt; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_fcma(void) { +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + return cpuinfo_isa.fcma; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_i8mm(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.i8mm; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_aes(void) { +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + return cpuinfo_isa.aes; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_sha1(void) { +#if 
CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + return cpuinfo_isa.sha1; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_sha2(void) { +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + return cpuinfo_isa.sha2; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_pmull(void) { +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + return cpuinfo_isa.pmull; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_crc32(void) { +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + return cpuinfo_isa.crc32; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_sve(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.sve; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_sve_bf16(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.sve && cpuinfo_isa.bf16; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_sve2(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.sve2; +#else + return false; +#endif +} + +// Function to get the max SVE vector length on ARM CPU's which support SVE. +static inline uint32_t cpuinfo_get_max_arm_sve_length(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.svelen * 8; // bytes * 8 = bit length(vector length) +#else + return 0; +#endif +} + +static inline bool cpuinfo_has_arm_sme(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.sme; +#else + return false; +#endif +} + +#if CPUINFO_ARCH_RISCV32 || CPUINFO_ARCH_RISCV64 +/* This structure is not a part of stable API. Use cpuinfo_has_riscv_* functions + * instead. */ +struct cpuinfo_riscv_isa { + /** + * Keep fields in line with the canonical order as defined by + * Section 27.11 Subset Naming Convention. + */ + /* RV32I/64I/128I Base ISA. */ + bool i; +#if CPUINFO_ARCH_RISCV32 + /* RV32E Base ISA. */ + bool e; +#endif + /* Integer Multiply/Divide Extension. */ + bool m; + /* Atomic Extension. */ + bool a; + /* Single-Precision Floating-Point Extension. */ + bool f; + /* Double-Precision Floating-Point Extension. */ + bool d; + /* Compressed Extension. */ + bool c; + /* Vector Extension. */ + bool v; +}; + +extern struct cpuinfo_riscv_isa cpuinfo_isa; +#endif + +static inline bool cpuinfo_has_riscv_i(void) { +#if CPUINFO_ARCH_RISCV32 || CPUINFO_ARCH_RISCV64 + return cpuinfo_isa.i; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_riscv_e(void) { +#if CPUINFO_ARCH_RISCV32 + return cpuinfo_isa.e; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_riscv_m(void) { +#if CPUINFO_ARCH_RISCV32 || CPUINFO_ARCH_RISCV64 + return cpuinfo_isa.m; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_riscv_a(void) { +#if CPUINFO_ARCH_RISCV32 || CPUINFO_ARCH_RISCV64 + return cpuinfo_isa.a; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_riscv_f(void) { +#if CPUINFO_ARCH_RISCV32 || CPUINFO_ARCH_RISCV64 + return cpuinfo_isa.f; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_riscv_d(void) { +#if CPUINFO_ARCH_RISCV32 || CPUINFO_ARCH_RISCV64 + return cpuinfo_isa.d; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_riscv_g(void) { + // The 'G' extension is simply shorthand for 'IMAFD'. 
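	/*
	 * Editorial note: the 'IMAFD' composition checked below matches the classic
	 * definition of 'G'. The ratified RISC-V specification additionally folds
	 * Zicsr and Zifencei into 'G'; those extensions are not tracked by
	 * struct cpuinfo_riscv_isa, so they are not part of this check.
	 */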
+ return cpuinfo_has_riscv_i() && cpuinfo_has_riscv_m() && cpuinfo_has_riscv_a() && cpuinfo_has_riscv_f() && + cpuinfo_has_riscv_d(); +} + +static inline bool cpuinfo_has_riscv_c(void) { +#if CPUINFO_ARCH_RISCV32 || CPUINFO_ARCH_RISCV64 + return cpuinfo_isa.c; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_riscv_v(void) { +#if CPUINFO_ARCH_RISCV32 || CPUINFO_ARCH_RISCV64 + return cpuinfo_isa.v; +#else + return false; +#endif +} + +const struct cpuinfo_processor* CPUINFO_ABI cpuinfo_get_processors(void); +const struct cpuinfo_core* CPUINFO_ABI cpuinfo_get_cores(void); +const struct cpuinfo_cluster* CPUINFO_ABI cpuinfo_get_clusters(void); +const struct cpuinfo_package* CPUINFO_ABI cpuinfo_get_packages(void); +const struct cpuinfo_uarch_info* CPUINFO_ABI cpuinfo_get_uarchs(void); +const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1i_caches(void); +const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1d_caches(void); +const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l2_caches(void); +const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l3_caches(void); +const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l4_caches(void); + +const struct cpuinfo_processor* CPUINFO_ABI cpuinfo_get_processor(uint32_t index); +const struct cpuinfo_core* CPUINFO_ABI cpuinfo_get_core(uint32_t index); +const struct cpuinfo_cluster* CPUINFO_ABI cpuinfo_get_cluster(uint32_t index); +const struct cpuinfo_package* CPUINFO_ABI cpuinfo_get_package(uint32_t index); +const struct cpuinfo_uarch_info* CPUINFO_ABI cpuinfo_get_uarch(uint32_t index); +const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1i_cache(uint32_t index); +const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1d_cache(uint32_t index); +const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l2_cache(uint32_t index); +const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l3_cache(uint32_t index); +const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l4_cache(uint32_t index); + +uint32_t CPUINFO_ABI cpuinfo_get_processors_count(void); +uint32_t CPUINFO_ABI cpuinfo_get_cores_count(void); +uint32_t CPUINFO_ABI cpuinfo_get_clusters_count(void); +uint32_t CPUINFO_ABI cpuinfo_get_packages_count(void); +uint32_t CPUINFO_ABI cpuinfo_get_uarchs_count(void); +uint32_t CPUINFO_ABI cpuinfo_get_l1i_caches_count(void); +uint32_t CPUINFO_ABI cpuinfo_get_l1d_caches_count(void); +uint32_t CPUINFO_ABI cpuinfo_get_l2_caches_count(void); +uint32_t CPUINFO_ABI cpuinfo_get_l3_caches_count(void); +uint32_t CPUINFO_ABI cpuinfo_get_l4_caches_count(void); + +/** + * Returns upper bound on cache size. + */ +uint32_t CPUINFO_ABI cpuinfo_get_max_cache_size(void); + +/** + * Identify the logical processor that executes the current thread. + * + * There is no guarantee that the thread will stay on the same logical processor + * for any time. Callers should treat the result as only a hint, and be prepared + * to handle NULL return value. + */ +const struct cpuinfo_processor* CPUINFO_ABI cpuinfo_get_current_processor(void); + +/** + * Identify the core that executes the current thread. + * + * There is no guarantee that the thread will stay on the same core for any + * time. Callers should treat the result as only a hint, and be prepared to + * handle NULL return value. + */ +const struct cpuinfo_core* CPUINFO_ABI cpuinfo_get_current_core(void); + +/** + * Identify the microarchitecture index of the core that executes the current + * thread. If the system does not support such identification, the function + * returns 0. 
+ * + * There is no guarantee that the thread will stay on the same type of core for + * any time. Callers should treat the result as only a hint. + */ +uint32_t CPUINFO_ABI cpuinfo_get_current_uarch_index(void); + +/** + * Identify the microarchitecture index of the core that executes the current + * thread. If the system does not support such identification, the function + * returns the user-specified default value. + * + * There is no guarantee that the thread will stay on the same type of core for + * any time. Callers should treat the result as only a hint. + */ +uint32_t CPUINFO_ABI cpuinfo_get_current_uarch_index_with_default(uint32_t default_uarch_index); + +#ifdef __cplusplus +} /* extern "C" */ +#endif + +#endif /* CPUINFO_H */ diff --git a/lib/python3.10/site-packages/torch/include/dnnl.h b/lib/python3.10/site-packages/torch/include/dnnl.h new file mode 100644 index 0000000000000000000000000000000000000000..bc74bf644f4b628018d7a9103ba63320abc466d5 --- /dev/null +++ b/lib/python3.10/site-packages/torch/include/dnnl.h @@ -0,0 +1,22 @@ +/******************************************************************************* +* Copyright 2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef DNNL_H +#define DNNL_H + +#include "oneapi/dnnl/dnnl.h" + +#endif /* DNNL_H */ diff --git a/lib/python3.10/site-packages/torch/include/dnnl_config.h b/lib/python3.10/site-packages/torch/include/dnnl_config.h new file mode 100644 index 0000000000000000000000000000000000000000..48925e1e3ab49ae135c6e9c4c501aa2f5e030913 --- /dev/null +++ b/lib/python3.10/site-packages/torch/include/dnnl_config.h @@ -0,0 +1,22 @@ +/******************************************************************************* +* Copyright 2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
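For readers skimming the cpuinfo.h header added above, a minimal consumer of its query API might look like the sketch below. This is an editorial illustration only, not part of the added file; it assumes the header is reachable as <cpuinfo.h> and that the program links against the cpuinfo library. Only functions and struct fields declared in the header above are used.

#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>
#include <cpuinfo.h>

int main(void) {
	/* Every query below is only valid after a successful initialization. */
	if (!cpuinfo_initialize()) {
		fprintf(stderr, "cpuinfo_initialize() failed\n");
		return 1;
	}

	/* Topology: packages -> cores -> logical processors. */
	for (uint32_t i = 0; i < cpuinfo_get_packages_count(); i++) {
		const struct cpuinfo_package* pkg = cpuinfo_get_package(i);
		printf("package %" PRIu32 ": %s (%" PRIu32 " cores, %" PRIu32 " logical processors)\n",
		       i, pkg->name, pkg->core_count, pkg->processor_count);
	}

	/* Per-microarchitecture summary; on hybrid (big.LITTLE) parts this
	   enumerates each core type separately. */
	for (uint32_t i = 0; i < cpuinfo_get_uarchs_count(); i++) {
		const struct cpuinfo_uarch_info* info = cpuinfo_get_uarch(i);
		printf("uarch 0x%08" PRIx32 ": %" PRIu32 " cores\n",
		       (uint32_t) info->uarch, info->core_count);
	}

	/* The cpuinfo_has_* checks compile to constant false on non-matching
	   architectures, so one dispatch chain can cover x86 and ARM builds. */
	if (cpuinfo_has_x86_avx2() && cpuinfo_has_x86_fma3()) {
		printf("dispatch: AVX2 + FMA3 kernels\n");
	} else if (cpuinfo_has_arm_neon_dot()) {
		printf("dispatch: NEON dot-product kernels\n");
	} else {
		printf("dispatch: generic scalar kernels\n");
	}

	printf("upper bound on cache size: %" PRIu32 "\n", cpuinfo_get_max_cache_size());
	printf("current uarch index (hint only): %" PRIu32 "\n", cpuinfo_get_current_uarch_index());

	cpuinfo_deinitialize();
	return 0;
}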
+*******************************************************************************/ + +#ifndef DNNL_CONFIG_H +#define DNNL_CONFIG_H + +#include "oneapi/dnnl/dnnl_config.h" + +#endif /* DNNL_CONFIG_H */ diff --git a/lib/python3.10/site-packages/torch/include/dnnl_debug.h b/lib/python3.10/site-packages/torch/include/dnnl_debug.h new file mode 100644 index 0000000000000000000000000000000000000000..5044971832bbbe56127920a527508b207a803eea --- /dev/null +++ b/lib/python3.10/site-packages/torch/include/dnnl_debug.h @@ -0,0 +1,22 @@ +/******************************************************************************* +* Copyright 2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef DNNL_DEBUG_H +#define DNNL_DEBUG_H + +#include "oneapi/dnnl/dnnl_debug.h" + +#endif /* DNNL_DEBUG_H */ diff --git a/lib/python3.10/site-packages/torch/lib/libcaffe2_nvrtc.so b/lib/python3.10/site-packages/torch/lib/libcaffe2_nvrtc.so new file mode 100644 index 0000000000000000000000000000000000000000..2b86281620c846c2f42a4ad2e9a261a7a00c6fd7 Binary files /dev/null and b/lib/python3.10/site-packages/torch/lib/libcaffe2_nvrtc.so differ diff --git a/lib/python3.10/site-packages/torch/lib/libshm.so b/lib/python3.10/site-packages/torch/lib/libshm.so new file mode 100644 index 0000000000000000000000000000000000000000..4df70d9356597a8f57fce865e55135bb960a1446 Binary files /dev/null and b/lib/python3.10/site-packages/torch/lib/libshm.so differ diff --git a/lib/python3.10/site-packages/torch/lib/libtorch_global_deps.so b/lib/python3.10/site-packages/torch/lib/libtorch_global_deps.so new file mode 100644 index 0000000000000000000000000000000000000000..c53678bd7bb321b5613a80929e6e627437d649e6 Binary files /dev/null and b/lib/python3.10/site-packages/torch/lib/libtorch_global_deps.so differ diff --git a/lib/python3.10/site-packages/torch/linalg/__init__.py b/lib/python3.10/site-packages/torch/linalg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cef76fec1107d5a47d8fd84857de9f846aa3070d --- /dev/null +++ b/lib/python3.10/site-packages/torch/linalg/__init__.py @@ -0,0 +1,2853 @@ +import torch +from torch._C import _add_docstr, _linalg # type: ignore[attr-defined] + +LinAlgError = torch._C._LinAlgError # type: ignore[attr-defined] + +Tensor = torch.Tensor + +common_notes = { + "experimental_warning": """This function is "experimental" and it may change in a future PyTorch release.""", + "sync_note": "When inputs are on a CUDA device, this function synchronizes that device with the CPU.", + "sync_note_ex": r"When the inputs are on a CUDA device, this function synchronizes only when :attr:`check_errors`\ `= True`.", + "sync_note_has_ex": ("When inputs are on a CUDA device, this function synchronizes that device with the CPU. 
" + "For a version of this function that does not synchronize, see :func:`{}`.") +} + + +# Note: This not only adds doc strings for functions in the linalg namespace, but +# also connects the torch.linalg Python namespace to the torch._C._linalg builtins. + +cross = _add_docstr(_linalg.linalg_cross, r""" +linalg.cross(input, other, *, dim=-1, out=None) -> Tensor + + +Computes the cross product of two 3-dimensional vectors. + +Supports input of float, double, cfloat and cdouble dtypes. Also supports batches +of vectors, for which it computes the product along the dimension :attr:`dim`. +It broadcasts over the batch dimensions. + +Args: + input (Tensor): the first input tensor. + other (Tensor): the second input tensor. + dim (int, optional): the dimension along which to take the cross-product. Default: `-1`. + +Keyword args: + out (Tensor, optional): the output tensor. Ignored if `None`. Default: `None`. + +Example: + >>> a = torch.randn(4, 3) + >>> a + tensor([[-0.3956, 1.1455, 1.6895], + [-0.5849, 1.3672, 0.3599], + [-1.1626, 0.7180, -0.0521], + [-0.1339, 0.9902, -2.0225]]) + >>> b = torch.randn(4, 3) + >>> b + tensor([[-0.0257, -1.4725, -1.2251], + [-1.1479, -0.7005, -1.9757], + [-1.3904, 0.3726, -1.1836], + [-0.9688, -0.7153, 0.2159]]) + >>> torch.linalg.cross(a, b) + tensor([[ 1.0844, -0.5281, 0.6120], + [-2.4490, -1.5687, 1.9792], + [-0.8304, -1.3037, 0.5650], + [-1.2329, 1.9883, 1.0551]]) + >>> a = torch.randn(1, 3) # a is broadcast to match shape of b + >>> a + tensor([[-0.9941, -0.5132, 0.5681]]) + >>> torch.linalg.cross(a, b) + tensor([[ 1.4653, -1.2325, 1.4507], + [ 1.4119, -2.6163, 0.1073], + [ 0.3957, -1.9666, -1.0840], + [ 0.2956, -0.3357, 0.2139]]) +""") + +cholesky = _add_docstr(_linalg.linalg_cholesky, r""" +linalg.cholesky(A, *, upper=False, out=None) -> Tensor + +Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **Cholesky decomposition** of a complex Hermitian or real symmetric positive-definite matrix +:math:`A \in \mathbb{K}^{n \times n}` is defined as + +.. math:: + + A = LL^{\text{H}}\mathrlap{\qquad L \in \mathbb{K}^{n \times n}} + +where :math:`L` is a lower triangular matrix with real positive diagonal (even in the complex case) and +:math:`L^{\text{H}}` is the conjugate transpose when :math:`L` is complex, and the transpose when :math:`L` is real-valued. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +""" + fr""" +.. note:: {common_notes["sync_note_has_ex"].format("torch.linalg.cholesky_ex")} +""" + r""" + +.. seealso:: + + :func:`torch.linalg.cholesky_ex` for a version of this operation that + skips the (slow) error checking by default and instead returns the debug + information. This makes it a faster way to check if a matrix is + positive-definite. + + :func:`torch.linalg.eigh` for a different decomposition of a Hermitian matrix. + The eigenvalue decomposition gives more information about the matrix but it + slower to compute than the Cholesky decomposition. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of symmetric or Hermitian positive-definite matrices. + +Keyword args: + upper (bool, optional): whether to return an upper triangular matrix. 
+ The tensor returned with upper=True is the conjugate transpose of the tensor + returned with upper=False. + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Raises: + RuntimeError: if the :attr:`A` matrix or any matrix in a batched :attr:`A` is not Hermitian + (resp. symmetric) positive-definite. If :attr:`A` is a batch of matrices, + the error message will include the batch index of the first matrix that fails + to meet this condition. + +Examples:: + + >>> A = torch.randn(2, 2, dtype=torch.complex128) + >>> A = A @ A.T.conj() + torch.eye(2) # creates a Hermitian positive-definite matrix + >>> A + tensor([[2.5266+0.0000j, 1.9586-2.0626j], + [1.9586+2.0626j, 9.4160+0.0000j]], dtype=torch.complex128) + >>> L = torch.linalg.cholesky(A) + >>> L + tensor([[1.5895+0.0000j, 0.0000+0.0000j], + [1.2322+1.2976j, 2.4928+0.0000j]], dtype=torch.complex128) + >>> torch.dist(L @ L.T.conj(), A) + tensor(4.4692e-16, dtype=torch.float64) + + >>> A = torch.randn(3, 2, 2, dtype=torch.float64) + >>> A = A @ A.mT + torch.eye(2) # batch of symmetric positive-definite matrices + >>> L = torch.linalg.cholesky(A) + >>> torch.dist(L @ L.mT, A) + tensor(5.8747e-16, dtype=torch.float64) +""") + +cholesky_ex = _add_docstr(_linalg.linalg_cholesky_ex, r""" +linalg.cholesky_ex(A, *, upper=False, check_errors=False, out=None) -> (Tensor, Tensor) + +Computes the Cholesky decomposition of a complex Hermitian or real +symmetric positive-definite matrix. + +This function skips the (slow) error checking and error message construction +of :func:`torch.linalg.cholesky`, instead directly returning the LAPACK +error codes as part of a named tuple ``(L, info)``. This makes this function +a faster way to check if a matrix is positive-definite, and it provides an +opportunity to handle decomposition errors more gracefully or performantly +than :func:`torch.linalg.cholesky` does. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +If :attr:`A` is not a Hermitian positive-definite matrix, or if it's a batch of matrices +and one or more of them is not a Hermitian positive-definite matrix, +then ``info`` stores a positive integer for the corresponding matrix. +The positive integer indicates the order of the leading minor that is not positive-definite, +and the decomposition could not be completed. +``info`` filled with zeros indicates that the decomposition was successful. +If ``check_errors=True`` and ``info`` contains positive integers, then a RuntimeError is thrown. + +""" + fr""" +.. note:: {common_notes["sync_note_ex"]} + +.. warning:: {common_notes["experimental_warning"]} +""" + r""" + +.. seealso:: + :func:`torch.linalg.cholesky` is a NumPy compatible variant that always checks for errors. + +Args: + A (Tensor): the Hermitian `n \times n` matrix or the batch of such matrices of size + `(*, n, n)` where `*` is one or more batch dimensions. + +Keyword args: + upper (bool, optional): whether to return an upper triangular matrix. + The tensor returned with upper=True is the conjugate transpose of the tensor + returned with upper=False. + check_errors (bool, optional): controls whether to check the content of ``infos``. Default: `False`. + out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`. 
+ +Examples:: + + >>> A = torch.randn(2, 2, dtype=torch.complex128) + >>> A = A @ A.t().conj() # creates a Hermitian positive-definite matrix + >>> L, info = torch.linalg.cholesky_ex(A) + >>> A + tensor([[ 2.3792+0.0000j, -0.9023+0.9831j], + [-0.9023-0.9831j, 0.8757+0.0000j]], dtype=torch.complex128) + >>> L + tensor([[ 1.5425+0.0000j, 0.0000+0.0000j], + [-0.5850-0.6374j, 0.3567+0.0000j]], dtype=torch.complex128) + >>> info + tensor(0, dtype=torch.int32) + +""") + +inv = _add_docstr(_linalg.linalg_inv, r""" +linalg.inv(A, *, out=None) -> Tensor + +Computes the inverse of a square matrix if it exists. +Throws a `RuntimeError` if the matrix is not invertible. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +for a matrix :math:`A \in \mathbb{K}^{n \times n}`, +its **inverse matrix** :math:`A^{-1} \in \mathbb{K}^{n \times n}` (if it exists) is defined as + +.. math:: + + A^{-1}A = AA^{-1} = \mathrm{I}_n + +where :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix. + +The inverse matrix exists if and only if :math:`A` is `invertible`_. In this case, +the inverse is unique. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices +then the output has the same batch dimensions. + +""" + fr""" +.. note:: {common_notes["sync_note_has_ex"].format("torch.linalg.inv_ex")} +""" + r""" + +.. note:: + Consider using :func:`torch.linalg.solve` if possible for multiplying a matrix on the left by + the inverse, as:: + + linalg.solve(A, B) == linalg.inv(A) @ B # When B is a matrix + + It is always preferred to use :func:`~solve` when possible, as it is faster and more + numerically stable than computing the inverse explicitly. + +.. seealso:: + + :func:`torch.linalg.pinv` computes the pseudoinverse (Moore-Penrose inverse) of matrices + of any shape. + + :func:`torch.linalg.solve` computes :attr:`A`\ `.inv() @ \ `:attr:`B` with a + numerically stable algorithm. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of invertible matrices. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Raises: + RuntimeError: if the matrix :attr:`A` or any matrix in the batch of matrices :attr:`A` is not invertible. + +Examples:: + + >>> A = torch.randn(4, 4) + >>> Ainv = torch.linalg.inv(A) + >>> torch.dist(A @ Ainv, torch.eye(4)) + tensor(1.1921e-07) + + >>> A = torch.randn(2, 3, 4, 4) # Batch of matrices + >>> Ainv = torch.linalg.inv(A) + >>> torch.dist(A @ Ainv, torch.eye(4)) + tensor(1.9073e-06) + + >>> A = torch.randn(4, 4, dtype=torch.complex128) # Complex matrix + >>> Ainv = torch.linalg.inv(A) + >>> torch.dist(A @ Ainv, torch.eye(4)) + tensor(7.5107e-16, dtype=torch.float64) + +.. _invertible: + https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem +""") + +solve_ex = _add_docstr(_linalg.linalg_solve_ex, r""" +linalg.solve_ex(A, B, *, left=True, check_errors=False, out=None) -> (Tensor, Tensor) + +A version of :func:`~solve` that does not perform error checks unless :attr:`check_errors`\ `= True`. +It also returns the :attr:`info` tensor returned by `LAPACK's getrf`_. + +""" + fr""" +.. note:: {common_notes["sync_note_ex"]} + +.. warning:: {common_notes["experimental_warning"]} +""" + r""" + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. 
+ +Keyword args: + left (bool, optional): whether to solve the system :math:`AX=B` or :math:`XA = B`. Default: `True`. + check_errors (bool, optional): controls whether to check the content of ``infos`` and raise + an error if it is non-zero. Default: `False`. + out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(result, info)`. + +Examples:: + + >>> A = torch.randn(3, 3) + >>> Ainv, info = torch.linalg.solve_ex(A) + >>> torch.dist(torch.linalg.inv(A), Ainv) + tensor(0.) + >>> info + tensor(0, dtype=torch.int32) + +.. _LAPACK's getrf: + https://www.netlib.org/lapack/explore-html/dd/d9a/group__double_g_ecomputational_ga0019443faea08275ca60a734d0593e60.html +""") + +inv_ex = _add_docstr(_linalg.linalg_inv_ex, r""" +linalg.inv_ex(A, *, check_errors=False, out=None) -> (Tensor, Tensor) + +Computes the inverse of a square matrix if it is invertible. + +Returns a namedtuple ``(inverse, info)``. ``inverse`` contains the result of +inverting :attr:`A` and ``info`` stores the LAPACK error codes. + +If :attr:`A` is not an invertible matrix, or if it's a batch of matrices +and one or more of them is not an invertible matrix, +then ``info`` stores a positive integer for the corresponding matrix. +The positive integer indicates the diagonal element of the LU decomposition of +the input matrix that is exactly zero. +``info`` filled with zeros indicates that the inversion was successful. +If ``check_errors=True`` and ``info`` contains positive integers, then a RuntimeError is thrown. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +""" + fr""" +.. note:: {common_notes["sync_note_ex"]} + +.. warning:: {common_notes["experimental_warning"]} +""" + r""" + +.. seealso:: + + :func:`torch.linalg.inv` is a NumPy compatible variant that always checks for errors. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of square matrices. + check_errors (bool, optional): controls whether to check the content of ``info``. Default: `False`. + +Keyword args: + out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`. + +Examples:: + + >>> A = torch.randn(3, 3) + >>> Ainv, info = torch.linalg.inv_ex(A) + >>> torch.dist(torch.linalg.inv(A), Ainv) + tensor(0.) + >>> info + tensor(0, dtype=torch.int32) + +""") + +det = _add_docstr(_linalg.linalg_det, r""" +linalg.det(A, *, out=None) -> Tensor + +Computes the determinant of a square matrix. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +.. seealso:: + + :func:`torch.linalg.slogdet` computes the sign and natural logarithm of the absolute + value of the determinant of square matrices. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. 
+
+Examples::
+
+    >>> A = torch.randn(3, 3)
+    >>> torch.linalg.det(A)
+    tensor(0.0934)
+
+    >>> A = torch.randn(3, 2, 2)
+    >>> torch.linalg.det(A)
+    tensor([1.1990, 0.4099, 0.7386])
+""")
+
+slogdet = _add_docstr(_linalg.linalg_slogdet, r"""
+linalg.slogdet(A, *, out=None) -> (Tensor, Tensor)
+
+Computes the sign and natural logarithm of the absolute value of the determinant of a square matrix.
+
+For complex :attr:`A`, it returns the sign and the natural logarithm of the modulus of the
+determinant, that is, a logarithmic polar decomposition of the determinant.
+
+The determinant can be recovered as `sign * exp(logabsdet)`.
+When a matrix has a determinant of zero, it returns `(0, -inf)`.
+
+Supports input of float, double, cfloat and cdouble dtypes.
+Also supports batches of matrices, and if :attr:`A` is a batch of matrices then
+the output has the same batch dimensions.
+
+.. seealso::
+
+    :func:`torch.linalg.det` computes the determinant of square matrices.
+
+Args:
+    A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions.
+
+Keyword args:
+    out (tuple, optional): output tuple of two tensors. Ignored if `None`. Default: `None`.
+
+Returns:
+    A named tuple `(sign, logabsdet)`.
+
+    `sign` will have the same dtype as :attr:`A`.
+
+    `logabsdet` will always be real-valued, even when :attr:`A` is complex.
+
+Examples::
+
+    >>> A = torch.randn(3, 3)
+    >>> A
+    tensor([[ 0.0032, -0.2239, -1.1219],
+            [-0.6690,  0.1161,  0.4053],
+            [-1.6218, -0.9273, -0.0082]])
+    >>> torch.linalg.det(A)
+    tensor(-0.7576)
+    >>> torch.logdet(A)
+    tensor(nan)
+    >>> torch.linalg.slogdet(A)
+    torch.return_types.linalg_slogdet(sign=tensor(-1.), logabsdet=tensor(-0.2776))
+""")
+
+eig = _add_docstr(_linalg.linalg_eig, r"""
+linalg.eig(A, *, out=None) -> (Tensor, Tensor)
+
+Computes the eigenvalue decomposition of a square matrix if it exists.
+
+Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`,
+the **eigenvalue decomposition** of a square matrix
+:math:`A \in \mathbb{K}^{n \times n}` (if it exists) is defined as
+
+.. math::
+
+    A = V \operatorname{diag}(\Lambda) V^{-1}\mathrlap{\qquad V \in \mathbb{C}^{n \times n}, \Lambda \in \mathbb{C}^n}
+
+This decomposition exists if and only if :math:`A` is `diagonalizable`_.
+This is the case, for example, when all its eigenvalues are different.
+
+Supports input of float, double, cfloat and cdouble dtypes.
+Also supports batches of matrices, and if :attr:`A` is a batch of matrices then
+the output has the same batch dimensions.
+
+The returned eigenvalues are not guaranteed to be in any specific order.
+
+.. note:: The eigenvalues and eigenvectors of a real matrix may be complex.
+
+""" + fr"""
+.. note:: {common_notes["sync_note"]}
+""" + r"""
+
+.. warning:: This function assumes that :attr:`A` is `diagonalizable`_ (for example, when all the
+             eigenvalues are different). If it is not diagonalizable, the returned
+             eigenvalues will be correct but :math:`A \neq V \operatorname{diag}(\Lambda)V^{-1}`.
+
+.. warning:: The returned eigenvectors are normalized to have norm `1`.
+             Even then, the eigenvectors of a matrix are not unique, nor are they continuous with respect to
+             :attr:`A`. Due to this lack of uniqueness, different hardware and software may compute
+             different eigenvectors.
+
+             This non-uniqueness is caused by the fact that multiplying an eigenvector by
+             :math:`e^{i \phi}, \phi \in \mathbb{R}` produces another set of valid eigenvectors
+             of the matrix.
For this reason, the loss function shall not depend on the phase of the + eigenvectors, as this quantity is not well-defined. + This is checked when computing the gradients of this function. As such, + when inputs are on a CUDA device, the computation of the gradients + of this function synchronizes that device with the CPU. + + +.. warning:: Gradients computed using the `eigenvectors` tensor will only be finite when + :attr:`A` has distinct eigenvalues. + Furthermore, if the distance between any two eigenvalues is close to zero, + the gradient will be numerically unstable, as it depends on the eigenvalues + :math:`\lambda_i` through the computation of + :math:`\frac{1}{\min_{i \neq j} \lambda_i - \lambda_j}`. + +.. seealso:: + + :func:`torch.linalg.eigvals` computes only the eigenvalues. + Unlike :func:`torch.linalg.eig`, the gradients of :func:`~eigvals` are always + numerically stable. + + :func:`torch.linalg.eigh` for a (faster) function that computes the eigenvalue decomposition + for Hermitian and symmetric matrices. + + :func:`torch.linalg.svd` for a function that computes another type of spectral + decomposition that works on matrices of any shape. + + :func:`torch.linalg.qr` for another (much faster) decomposition that works on matrices of + any shape. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of diagonalizable matrices. + +Keyword args: + out (tuple, optional): output tuple of two tensors. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(eigenvalues, eigenvectors)` which corresponds to :math:`\Lambda` and :math:`V` above. + + `eigenvalues` and `eigenvectors` will always be complex-valued, even when :attr:`A` is real. The eigenvectors + will be given by the columns of `eigenvectors`. + +Examples:: + + >>> A = torch.randn(2, 2, dtype=torch.complex128) + >>> A + tensor([[ 0.9828+0.3889j, -0.4617+0.3010j], + [ 0.1662-0.7435j, -0.6139+0.0562j]], dtype=torch.complex128) + >>> L, V = torch.linalg.eig(A) + >>> L + tensor([ 1.1226+0.5738j, -0.7537-0.1286j], dtype=torch.complex128) + >>> V + tensor([[ 0.9218+0.0000j, 0.1882-0.2220j], + [-0.0270-0.3867j, 0.9567+0.0000j]], dtype=torch.complex128) + >>> torch.dist(V @ torch.diag(L) @ torch.linalg.inv(V), A) + tensor(7.7119e-16, dtype=torch.float64) + + >>> A = torch.randn(3, 2, 2, dtype=torch.float64) + >>> L, V = torch.linalg.eig(A) + >>> torch.dist(V @ torch.diag_embed(L) @ torch.linalg.inv(V), A) + tensor(3.2841e-16, dtype=torch.float64) + +.. _diagonalizable: + https://en.wikipedia.org/wiki/Diagonalizable_matrix#Definition +""") + +eigvals = _add_docstr(_linalg.linalg_eigvals, r""" +linalg.eigvals(A, *, out=None) -> Tensor + +Computes the eigenvalues of a square matrix. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **eigenvalues** of a square matrix :math:`A \in \mathbb{K}^{n \times n}` are defined +as the roots (counted with multiplicity) of the polynomial `p` of degree `n` given by + +.. math:: + + p(\lambda) = \operatorname{det}(A - \lambda \mathrm{I}_n)\mathrlap{\qquad \lambda \in \mathbb{C}} + +where :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +The returned eigenvalues are not guaranteed to be in any specific order. + +.. 
note:: The eigenvalues of a real matrix may be complex, as the roots of a real polynomial may be complex. + + The eigenvalues of a matrix are always well-defined, even when the matrix is not diagonalizable. + +""" + fr""" +.. note:: {common_notes["sync_note"]} +""" + r""" + +.. seealso:: + + :func:`torch.linalg.eig` computes the full eigenvalue decomposition. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Returns: + A complex-valued tensor containing the eigenvalues even when :attr:`A` is real. + +Examples:: + + >>> A = torch.randn(2, 2, dtype=torch.complex128) + >>> L = torch.linalg.eigvals(A) + >>> L + tensor([ 1.1226+0.5738j, -0.7537-0.1286j], dtype=torch.complex128) + + >>> torch.dist(L, torch.linalg.eig(A).eigenvalues) + tensor(2.4576e-07) +""") + +eigh = _add_docstr(_linalg.linalg_eigh, r""" +linalg.eigh(A, UPLO='L', *, out=None) -> (Tensor, Tensor) + +Computes the eigenvalue decomposition of a complex Hermitian or real symmetric matrix. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **eigenvalue decomposition** of a complex Hermitian or real symmetric matrix +:math:`A \in \mathbb{K}^{n \times n}` is defined as + +.. math:: + + A = Q \operatorname{diag}(\Lambda) Q^{\text{H}}\mathrlap{\qquad Q \in \mathbb{K}^{n \times n}, \Lambda \in \mathbb{R}^n} + +where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex, and the transpose when :math:`Q` is real-valued. +:math:`Q` is orthogonal in the real case and unitary in the complex case. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +:attr:`A` is assumed to be Hermitian (resp. symmetric), but this is not checked internally, instead: + +- If :attr:`UPLO`\ `= 'L'` (default), only the lower triangular part of the matrix is used in the computation. +- If :attr:`UPLO`\ `= 'U'`, only the upper triangular part of the matrix is used. + +The eigenvalues are returned in ascending order. + +""" + fr""" +.. note:: {common_notes["sync_note"]} +""" + r""" + +.. note:: The eigenvalues of real symmetric or complex Hermitian matrices are always real. + +.. warning:: The eigenvectors of a symmetric matrix are not unique, nor are they continuous with + respect to :attr:`A`. Due to this lack of uniqueness, different hardware and + software may compute different eigenvectors. + + This non-uniqueness is caused by the fact that multiplying an eigenvector by + `-1` in the real case or by :math:`e^{i \phi}, \phi \in \mathbb{R}` in the complex + case produces another set of valid eigenvectors of the matrix. + For this reason, the loss function shall not depend on the phase of the eigenvectors, as + this quantity is not well-defined. + This is checked for complex inputs when computing the gradients of this function. As such, + when inputs are complex and are on a CUDA device, the computation of the gradients + of this function synchronizes that device with the CPU. + +.. warning:: Gradients computed using the `eigenvectors` tensor will only be finite when + :attr:`A` has distinct eigenvalues. 
+             Furthermore, if the distance between any two eigenvalues is close to zero,
+             the gradient will be numerically unstable, as it depends on the eigenvalues
+             :math:`\lambda_i` through the computation of
+             :math:`\frac{1}{\min_{i \neq j} \lambda_i - \lambda_j}`.
+
+.. warning:: Running `eigh` on CUDA devices with CUDA versions before 12.1 update 1 may crash PyTorch
+             when the inputs are large, ill-conditioned matrices.
+             Refer to :ref:`Linear Algebra Numerical Stability` for more details.
+             If this is the case, you may (1) make the matrix inputs less ill-conditioned,
+             or (2) use :func:`torch.backends.cuda.preferred_linalg_library` to
+             try other supported backends.
+
+.. seealso::
+
+    :func:`torch.linalg.eigvalsh` computes only the eigenvalues of a Hermitian matrix.
+    Unlike :func:`torch.linalg.eigh`, the gradients of :func:`~eigvalsh` are always
+    numerically stable.
+
+    :func:`torch.linalg.cholesky` for a different decomposition of a Hermitian matrix.
+    The Cholesky decomposition gives less information about the matrix but is much faster
+    to compute than the eigenvalue decomposition.
+
+    :func:`torch.linalg.eig` for a (slower) function that computes the eigenvalue decomposition
+    of a not necessarily Hermitian square matrix.
+
+    :func:`torch.linalg.svd` for a (slower) function that computes the more general SVD
+    decomposition of matrices of any shape.
+
+    :func:`torch.linalg.qr` for another (much faster) decomposition that works on general
+    matrices.
+
+Args:
+    A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions
+                consisting of symmetric or Hermitian matrices.
+    UPLO ('L', 'U', optional): controls whether to use the upper or lower triangular part
+                               of :attr:`A` in the computations. Default: `'L'`.
+
+Keyword args:
+    out (tuple, optional): output tuple of two tensors. Ignored if `None`. Default: `None`.
+
+Returns:
+    A named tuple `(eigenvalues, eigenvectors)` which corresponds to :math:`\Lambda` and :math:`Q` above.
+
+    `eigenvalues` will always be real-valued, even when :attr:`A` is complex.
+    It will also be ordered in ascending order.
+
+    `eigenvectors` will have the same dtype as :attr:`A` and will contain the eigenvectors as its columns.
+
+Examples::
+
+    >>> A = torch.randn(2, 2, dtype=torch.complex128)
+    >>> A = A + A.T.conj()  # creates a Hermitian matrix
+    >>> A
+    tensor([[2.9228+0.0000j, 0.2029-0.0862j],
+            [0.2029+0.0862j, 0.3464+0.0000j]], dtype=torch.complex128)
+    >>> L, Q = torch.linalg.eigh(A)
+    >>> L
+    tensor([0.3277, 2.9415], dtype=torch.float64)
+    >>> Q
+    tensor([[-0.0846+-0.0000j, -0.9964+0.0000j],
+            [ 0.9170+0.3898j, -0.0779-0.0331j]], dtype=torch.complex128)
+    >>> torch.dist(Q @ torch.diag(L.cdouble()) @ Q.T.conj(), A)
+    tensor(6.1062e-16, dtype=torch.float64)
+
+    >>> A = torch.randn(3, 2, 2, dtype=torch.float64)
+    >>> A = A + A.mT  # creates a batch of symmetric matrices
+    >>> L, Q = torch.linalg.eigh(A)
+    >>> torch.dist(Q @ torch.diag_embed(L) @ Q.mH, A)
+    tensor(1.5423e-15, dtype=torch.float64)
+""")
+
+eigvalsh = _add_docstr(_linalg.linalg_eigvalsh, r"""
+linalg.eigvalsh(A, UPLO='L', *, out=None) -> Tensor
+
+Computes the eigenvalues of a complex Hermitian or real symmetric matrix.
+
+Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`,
+the **eigenvalues** of a complex Hermitian or real symmetric matrix :math:`A \in \mathbb{K}^{n \times n}`
+are defined as the roots (counted with multiplicity) of the polynomial `p` of degree `n` given by
+
+.. math::
+
+    p(\lambda) = \operatorname{det}(A - \lambda \mathrm{I}_n)\mathrlap{\qquad \lambda \in \mathbb{R}}
+
+where :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix.
+The eigenvalues of a real symmetric or complex Hermitian matrix are always real.
+
+Supports input of float, double, cfloat and cdouble dtypes.
+Also supports batches of matrices, and if :attr:`A` is a batch of matrices then
+the output has the same batch dimensions.
+
+The eigenvalues are returned in ascending order.
+
+:attr:`A` is assumed to be Hermitian (resp. symmetric), but this is not checked internally, instead:
+
+- If :attr:`UPLO`\ `= 'L'` (default), only the lower triangular part of the matrix is used in the computation.
+- If :attr:`UPLO`\ `= 'U'`, only the upper triangular part of the matrix is used.
+
+""" + fr"""
+.. note:: {common_notes["sync_note"]}
+""" + r"""
+
+.. seealso::
+
+    :func:`torch.linalg.eigh` computes the full eigenvalue decomposition.
+
+Args:
+    A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions
+                consisting of symmetric or Hermitian matrices.
+    UPLO ('L', 'U', optional): controls whether to use the upper or lower triangular part
+                               of :attr:`A` in the computations. Default: `'L'`.
+
+Keyword args:
+    out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.
+
+Returns:
+    A real-valued tensor containing the eigenvalues even when :attr:`A` is complex.
+    The eigenvalues are returned in ascending order.
+
+Examples::
+
+    >>> A = torch.randn(2, 2, dtype=torch.complex128)
+    >>> A = A + A.T.conj()  # creates a Hermitian matrix
+    >>> A
+    tensor([[2.9228+0.0000j, 0.2029-0.0862j],
+            [0.2029+0.0862j, 0.3464+0.0000j]], dtype=torch.complex128)
+    >>> torch.linalg.eigvalsh(A)
+    tensor([0.3277, 2.9415], dtype=torch.float64)
+
+    >>> A = torch.randn(3, 2, 2, dtype=torch.float64)
+    >>> A = A + A.mT  # creates a batch of symmetric matrices
+    >>> torch.linalg.eigvalsh(A)
+    tensor([[ 2.5797,  3.4629],
+            [-4.1605,  1.3780],
+            [-3.1113,  2.7381]], dtype=torch.float64)
+""")
+
+householder_product = _add_docstr(_linalg.linalg_householder_product, r"""
+householder_product(A, tau, *, out=None) -> Tensor
+
+Computes the first `n` columns of a product of Householder matrices.
+
+Let :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, and
+let :math:`A \in \mathbb{K}^{m \times n}` be a matrix with columns :math:`a_i \in \mathbb{K}^m`
+for :math:`i=1,\ldots,n` with :math:`m \geq n`. Denote by :math:`b_i` the vector resulting from
+zeroing out the first :math:`i-1` components of :math:`a_i` and setting to `1` the :math:`i`-th.
+For a vector :math:`\tau \in \mathbb{K}^k` with :math:`k \leq n`, this function computes the
+first :math:`n` columns of the matrix
+
+.. math::
+
+    H_1H_2 ... H_k \qquad\text{with}\qquad H_i = \mathrm{I}_m - \tau_i b_i b_i^{\text{H}}
+
+where :math:`\mathrm{I}_m` is the `m`-dimensional identity matrix and :math:`b^{\text{H}}` is the
+conjugate transpose when :math:`b` is complex, and the transpose when :math:`b` is real-valued.
+The output matrix is the same size as the input matrix :attr:`A`.
+
+See `Representation of Orthogonal or Unitary Matrices`_ for further details.
+
+Supports inputs of float, double, cfloat and cdouble dtypes.
+Also supports batches of matrices, and if the inputs are batches of matrices then
+the output has the same batch dimensions.
+
+.. seealso::
+
+    :func:`torch.geqrf` can be used together with this function to form the `Q` from the
+    :func:`~qr` decomposition.
+
+    :func:`torch.ormqr` is a related function that computes the matrix multiplication
+    of a product of Householder matrices with another matrix.
+    However, that function is not supported by autograd.
+
+.. warning::
+    Gradient computations are only well-defined if :math:`\tau_i \neq \frac{1}{||a_i||^2}`.
+    If this condition is not met, no error will be thrown, but the gradient produced may contain `NaN`.
+
+Args:
+    A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions.
+    tau (Tensor): tensor of shape `(*, k)` where `*` is zero or more batch dimensions.
+
+Keyword args:
+    out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.
+
+Raises:
+    RuntimeError: if :attr:`A` doesn't satisfy the requirement `m >= n`,
+                  or :attr:`tau` doesn't satisfy the requirement `n >= k`.
+
+Examples::
+
+    >>> A = torch.randn(2, 2)
+    >>> h, tau = torch.geqrf(A)
+    >>> Q = torch.linalg.householder_product(h, tau)
+    >>> torch.dist(Q, torch.linalg.qr(A).Q)
+    tensor(0.)
+
+    >>> h = torch.randn(3, 2, 2, dtype=torch.complex128)
+    >>> tau = torch.randn(3, 1, dtype=torch.complex128)
+    >>> Q = torch.linalg.householder_product(h, tau)
+    >>> Q
+    tensor([[[ 1.8034+0.4184j,  0.2588-1.0174j],
+             [-0.6853+0.7953j,  2.0790+0.5620j]],
+
+            [[ 1.4581+1.6989j, -1.5360+0.1193j],
+             [ 1.3877-0.6691j,  1.3512+1.3024j]],
+
+            [[ 1.4766+0.5783j,  0.0361+0.6587j],
+             [ 0.6396+0.1612j,  1.3693+0.4481j]]], dtype=torch.complex128)
+
+.. _Representation of Orthogonal or Unitary Matrices:
+    https://www.netlib.org/lapack/lug/node128.html
+""")
+
+ldl_factor = _add_docstr(_linalg.linalg_ldl_factor, r"""
+linalg.ldl_factor(A, *, hermitian=False, out=None) -> (Tensor, Tensor)
+
+Computes a compact representation of the LDL factorization of a Hermitian or symmetric (possibly indefinite) matrix.
+
+When :attr:`A` is complex valued it can be Hermitian (:attr:`hermitian`\ `= True`)
+or symmetric (:attr:`hermitian`\ `= False`).
+
+The factorization is of the form :math:`A = L D L^T`.
+If :attr:`hermitian` is `True`, then the transpose operation is the conjugate transpose.
+
+:math:`L` (or :math:`U`) and :math:`D` are stored in compact form in ``LD``.
+They follow the format specified by `LAPACK's sytrf`_ function.
+These tensors may be used in :func:`torch.linalg.ldl_solve` to solve linear systems.
+
+Supports input of float, double, cfloat and cdouble dtypes.
+Also supports batches of matrices, and if :attr:`A` is a batch of matrices then
+the output has the same batch dimensions.
+
+""" + fr"""
+.. note:: {common_notes["sync_note_has_ex"].format("torch.linalg.ldl_factor_ex")}
+""" + r"""
+
+Args:
+    A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions
+                consisting of symmetric or Hermitian matrices.
+
+Keyword args:
+    hermitian (bool, optional): whether to consider the input to be Hermitian or symmetric.
+                                For real-valued matrices, this switch has no effect. Default: `False`.
+    out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`.
+
+Returns:
+    A named tuple `(LD, pivots)`.
+
+Examples::
+
+    >>> A = torch.randn(3, 3)
+    >>> A = A @ A.mT  # make symmetric
+    >>> A
+    tensor([[7.2079, 4.2414, 1.9428],
+            [4.2414, 3.4554, 0.3264],
+            [1.9428, 0.3264, 1.3823]])
+    >>> LD, pivots = torch.linalg.ldl_factor(A)
+    >>> LD
+    tensor([[ 7.2079,  0.0000,  0.0000],
+            [ 0.5884,  0.9595,  0.0000],
+            [ 0.2695, -0.8513,  0.1633]])
+    >>> pivots
+    tensor([1, 2, 3], dtype=torch.int32)
+
+.. 
_LAPACK's sytrf: + https://www.netlib.org/lapack/explore-html/d3/db6/group__double_s_ycomputational_gad91bde1212277b3e909eb6af7f64858a.html +""") + +ldl_factor_ex = _add_docstr(_linalg.linalg_ldl_factor_ex, r""" +linalg.ldl_factor_ex(A, *, hermitian=False, check_errors=False, out=None) -> (Tensor, Tensor, Tensor) + +This is a version of :func:`~ldl_factor` that does not perform error checks unless :attr:`check_errors`\ `= True`. +It also returns the :attr:`info` tensor returned by `LAPACK's sytrf`_. +``info`` stores integer error codes from the backend library. +A positive integer indicates the diagonal element of :math:`D` that is zero. +Division by 0 will occur if the result is used for solving a system of linear equations. +``info`` filled with zeros indicates that the factorization was successful. +If ``check_errors=True`` and ``info`` contains positive integers, then a `RuntimeError` is thrown. + +""" + fr""" +.. note:: {common_notes["sync_note_ex"]} + +.. warning:: {common_notes["experimental_warning"]} +""" + r""" + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of symmetric or Hermitian matrices. + +Keyword args: + hermitian (bool, optional): whether to consider the input to be Hermitian or symmetric. + For real-valued matrices, this switch has no effect. Default: `False`. + check_errors (bool, optional): controls whether to check the content of ``info`` and raise + an error if it is non-zero. Default: `False`. + out (tuple, optional): tuple of three tensors to write the output to. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(LD, pivots, info)`. + +Examples:: + + >>> A = torch.randn(3, 3) + >>> A = A @ A.mT # make symmetric + >>> A + tensor([[7.2079, 4.2414, 1.9428], + [4.2414, 3.4554, 0.3264], + [1.9428, 0.3264, 1.3823]]) + >>> LD, pivots, info = torch.linalg.ldl_factor_ex(A) + >>> LD + tensor([[ 7.2079, 0.0000, 0.0000], + [ 0.5884, 0.9595, 0.0000], + [ 0.2695, -0.8513, 0.1633]]) + >>> pivots + tensor([1, 2, 3], dtype=torch.int32) + >>> info + tensor(0, dtype=torch.int32) + +.. _LAPACK's sytrf: + https://www.netlib.org/lapack/explore-html/d3/db6/group__double_s_ycomputational_gad91bde1212277b3e909eb6af7f64858a.html +""") + +ldl_solve = _add_docstr(_linalg.linalg_ldl_solve, r""" +linalg.ldl_solve(LD, pivots, B, *, hermitian=False, out=None) -> Tensor + +Computes the solution of a system of linear equations using the LDL factorization. + +:attr:`LD` and :attr:`pivots` are the compact representation of the LDL factorization and +are expected to be computed by :func:`torch.linalg.ldl_factor_ex`. +:attr:`hermitian` argument to this function should be the same +as the corresponding arguments in :func:`torch.linalg.ldl_factor_ex`. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +""" + fr""" +.. warning:: {common_notes["experimental_warning"]} +""" + r""" + +Args: + LD (Tensor): the `n \times n` matrix or the batch of such matrices of size + `(*, n, n)` where `*` is one or more batch dimensions. + pivots (Tensor): the pivots corresponding to the LDL factorization of :attr:`LD`. + B (Tensor): right-hand side tensor of shape `(*, n, k)`. + +Keyword args: + hermitian (bool, optional): whether to consider the decomposed matrix to be Hermitian or symmetric. + For real-valued matrices, this switch has no effect. Default: `False`. + out (tuple, optional): output tensor. 
`B` may be passed as `out` and the result is computed in-place on `B`. + Ignored if `None`. Default: `None`. + +Examples:: + + >>> A = torch.randn(2, 3, 3) + >>> A = A @ A.mT # make symmetric + >>> LD, pivots, info = torch.linalg.ldl_factor_ex(A) + >>> B = torch.randn(2, 3, 4) + >>> X = torch.linalg.ldl_solve(LD, pivots, B) + >>> torch.linalg.norm(A @ X - B) + >>> tensor(0.0001) +""") + +lstsq = _add_docstr(_linalg.linalg_lstsq, r""" +torch.linalg.lstsq(A, B, rcond=None, *, driver=None) -> (Tensor, Tensor, Tensor, Tensor) + +Computes a solution to the least squares problem of a system of linear equations. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **least squares problem** for a linear system :math:`AX = B` with +:math:`A \in \mathbb{K}^{m \times n}, B \in \mathbb{K}^{m \times k}` is defined as + +.. math:: + + \min_{X \in \mathbb{K}^{n \times k}} \|AX - B\|_F + +where :math:`\|-\|_F` denotes the Frobenius norm. + +Supports inputs of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if the inputs are batches of matrices then +the output has the same batch dimensions. + +:attr:`driver` chooses the backend function that will be used. +For CPU inputs the valid values are `'gels'`, `'gelsy'`, `'gelsd`, `'gelss'`. +To choose the best driver on CPU consider: + +- If :attr:`A` is well-conditioned (its `condition number`_ is not too large), or you do not mind some precision loss. + + - For a general matrix: `'gelsy'` (QR with pivoting) (default) + - If :attr:`A` is full-rank: `'gels'` (QR) + +- If :attr:`A` is not well-conditioned. + + - `'gelsd'` (tridiagonal reduction and SVD) + - But if you run into memory issues: `'gelss'` (full SVD). + +For CUDA input, the only valid driver is `'gels'`, which assumes that :attr:`A` is full-rank. + +See also the `full description of these drivers`_ + +:attr:`rcond` is used to determine the effective rank of the matrices in :attr:`A` +when :attr:`driver` is one of (`'gelsy'`, `'gelsd'`, `'gelss'`). +In this case, if :math:`\sigma_i` are the singular values of `A` in decreasing order, +:math:`\sigma_i` will be rounded down to zero if :math:`\sigma_i \leq \text{rcond} \cdot \sigma_1`. +If :attr:`rcond`\ `= None` (default), :attr:`rcond` is set to the machine precision of the dtype of :attr:`A` times `max(m, n)`. + +This function returns the solution to the problem and some extra information in a named tuple of +four tensors `(solution, residuals, rank, singular_values)`. For inputs :attr:`A`, :attr:`B` +of shape `(*, m, n)`, `(*, m, k)` respectively, it contains + +- `solution`: the least squares solution. It has shape `(*, n, k)`. +- `residuals`: the squared residuals of the solutions, that is, :math:`\|AX - B\|_F^2`. + It has shape equal to the batch dimensions of :attr:`A`. + It is computed when `m > n` and every matrix in :attr:`A` is full-rank, + otherwise, it is an empty tensor. + If :attr:`A` is a batch of matrices and any matrix in the batch is not full rank, + then an empty tensor is returned. This behavior may change in a future PyTorch release. +- `rank`: tensor of ranks of the matrices in :attr:`A`. + It has shape equal to the batch dimensions of :attr:`A`. + It is computed when :attr:`driver` is one of (`'gelsy'`, `'gelsd'`, `'gelss'`), + otherwise it is an empty tensor. +- `singular_values`: tensor of singular values of the matrices in :attr:`A`. + It has shape `(*, min(m, n))`. 
+ It is computed when :attr:`driver` is one of (`'gelsd'`, `'gelss'`), + otherwise it is an empty tensor. + +.. note:: + This function computes `X = \ `:attr:`A`\ `.pinverse() @ \ `:attr:`B` in a faster and + more numerically stable way than performing the computations separately. + +.. warning:: + The default value of :attr:`rcond` may change in a future PyTorch release. + It is therefore recommended to use a fixed value to avoid potential + breaking changes. + +Args: + A (Tensor): lhs tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + B (Tensor): rhs tensor of shape `(*, m, k)` where `*` is zero or more batch dimensions. + rcond (float, optional): used to determine the effective rank of :attr:`A`. + If :attr:`rcond`\ `= None`, :attr:`rcond` is set to the machine + precision of the dtype of :attr:`A` times `max(m, n)`. Default: `None`. + +Keyword args: + driver (str, optional): name of the LAPACK/MAGMA method to be used. + If `None`, `'gelsy'` is used for CPU inputs and `'gels'` for CUDA inputs. + Default: `None`. + +Returns: + A named tuple `(solution, residuals, rank, singular_values)`. + +Examples:: + + >>> A = torch.randn(1,3,3) + >>> A + tensor([[[-1.0838, 0.0225, 0.2275], + [ 0.2438, 0.3844, 0.5499], + [ 0.1175, -0.9102, 2.0870]]]) + >>> B = torch.randn(2,3,3) + >>> B + tensor([[[-0.6772, 0.7758, 0.5109], + [-1.4382, 1.3769, 1.1818], + [-0.3450, 0.0806, 0.3967]], + [[-1.3994, -0.1521, -0.1473], + [ 1.9194, 1.0458, 0.6705], + [-1.1802, -0.9796, 1.4086]]]) + >>> X = torch.linalg.lstsq(A, B).solution # A is broadcasted to shape (2, 3, 3) + >>> torch.dist(X, torch.linalg.pinv(A) @ B) + tensor(1.5152e-06) + + >>> S = torch.linalg.lstsq(A, B, driver='gelsd').singular_values + >>> torch.dist(S, torch.linalg.svdvals(A)) + tensor(2.3842e-07) + + >>> A[:, 0].zero_() # Decrease the rank of A + >>> rank = torch.linalg.lstsq(A, B).rank + >>> rank + tensor([2]) + +.. _condition number: + https://pytorch.org/docs/main/linalg.html#torch.linalg.cond +.. _full description of these drivers: + https://www.netlib.org/lapack/lug/node27.html +""") + +matrix_power = _add_docstr(_linalg.linalg_matrix_power, r""" +matrix_power(A, n, *, out=None) -> Tensor + +Computes the `n`-th power of a square matrix for an integer `n`. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +If :attr:`n`\ `= 0`, it returns the identity matrix (or batch) of the same shape +as :attr:`A`. If :attr:`n` is negative, it returns the inverse of each matrix +(if invertible) raised to the power of `abs(n)`. + +.. note:: + Consider using :func:`torch.linalg.solve` if possible for multiplying a matrix on the left by + a negative power as, if :attr:`n`\ `> 0`:: + + torch.linalg.solve(matrix_power(A, n), B) == matrix_power(A, -n) @ B + + It is always preferred to use :func:`~solve` when possible, as it is faster and more + numerically stable than computing :math:`A^{-n}` explicitly. + +.. seealso:: + + :func:`torch.linalg.solve` computes :attr:`A`\ `.inverse() @ \ `:attr:`B` with a + numerically stable algorithm. + +Args: + A (Tensor): tensor of shape `(*, m, m)` where `*` is zero or more batch dimensions. + n (int): the exponent. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Raises: + RuntimeError: if :attr:`n`\ `< 0` and the matrix :attr:`A` or any matrix in the + batch of matrices :attr:`A` is not invertible. 
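+.. note::
+    As an informal illustration (not one of the curated examples below), the equivalence from the
+    note above can be checked numerically; the test matrices and the exponent are arbitrary
+    assumptions, with :attr:`A` shifted so that it is comfortably invertible::
+
+        A = torch.randn(3, 3) + 3 * torch.eye(3)
+        B = torch.randn(3, 2)
+        lhs = torch.linalg.solve(torch.linalg.matrix_power(A, 2), B)
+        rhs = torch.linalg.matrix_power(A, -2) @ B
+        assert torch.allclose(lhs, rhs, atol=1e-5)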
+ +Examples:: + + >>> A = torch.randn(3, 3) + >>> torch.linalg.matrix_power(A, 0) + tensor([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]]) + >>> torch.linalg.matrix_power(A, 3) + tensor([[ 1.0756, 0.4980, 0.0100], + [-1.6617, 1.4994, -1.9980], + [-0.4509, 0.2731, 0.8001]]) + >>> torch.linalg.matrix_power(A.expand(2, -1, -1), -2) + tensor([[[ 0.2640, 0.4571, -0.5511], + [-1.0163, 0.3491, -1.5292], + [-0.4899, 0.0822, 0.2773]], + [[ 0.2640, 0.4571, -0.5511], + [-1.0163, 0.3491, -1.5292], + [-0.4899, 0.0822, 0.2773]]]) +""") + +matrix_rank = _add_docstr(_linalg.linalg_matrix_rank, r""" +linalg.matrix_rank(A, *, atol=None, rtol=None, hermitian=False, out=None) -> Tensor + +Computes the numerical rank of a matrix. + +The matrix rank is computed as the number of singular values +(or eigenvalues in absolute value when :attr:`hermitian`\ `= True`) +that are greater than :math:`\max(\text{atol}, \sigma_1 * \text{rtol})` threshold, +where :math:`\sigma_1` is the largest singular value (or eigenvalue). + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +If :attr:`hermitian`\ `= True`, :attr:`A` is assumed to be Hermitian if complex or +symmetric if real, but this is not checked internally. Instead, just the lower +triangular part of the matrix is used in the computations. + +If :attr:`rtol` is not specified and :attr:`A` is a matrix of dimensions `(m, n)`, +the relative tolerance is set to be :math:`\text{rtol} = \max(m, n) \varepsilon` +and :math:`\varepsilon` is the epsilon value for the dtype of :attr:`A` (see :class:`.finfo`). +If :attr:`rtol` is not specified and :attr:`atol` is specified to be larger than zero then +:attr:`rtol` is set to zero. + +If :attr:`atol` or :attr:`rtol` is a :class:`torch.Tensor`, its shape must be broadcastable to that +of the singular values of :attr:`A` as returned by :func:`torch.linalg.svdvals`. + +.. note:: + This function has NumPy compatible variant `linalg.matrix_rank(A, tol, hermitian=False)`. + However, use of the positional argument :attr:`tol` is deprecated in favor of :attr:`atol` and :attr:`rtol`. + +""" + fr""" +.. note:: The matrix rank is computed using a singular value decomposition + :func:`torch.linalg.svdvals` if :attr:`hermitian`\ `= False` (default) and the eigenvalue + decomposition :func:`torch.linalg.eigvalsh` when :attr:`hermitian`\ `= True`. + {common_notes["sync_note"]} +""" + r""" + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + tol (float, Tensor, optional): [NumPy Compat] Alias for :attr:`atol`. Default: `None`. + +Keyword args: + atol (float, Tensor, optional): the absolute tolerance value. When `None` it's considered to be zero. + Default: `None`. + rtol (float, Tensor, optional): the relative tolerance value. See above for the value it takes when `None`. + Default: `None`. + hermitian(bool): indicates whether :attr:`A` is Hermitian if complex + or symmetric if real. Default: `False`. + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. 
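+.. note::
+    A rough sketch of the thresholding rule described above, written in terms of
+    :func:`torch.linalg.svdvals`; the test matrix and the hand-picked tolerances are illustrative
+    assumptions rather than the defaults::
+
+        A = torch.randn(4, 6)
+        atol, rtol = 1e-5, 1e-3
+        S = torch.linalg.svdvals(A)                   # singular values in descending order
+        tol = max(atol, S[0].item() * rtol)           # max(atol, sigma_1 * rtol)
+        rank_by_hand = int((S > tol).sum())
+        # this should agree with torch.linalg.matrix_rank(A, atol=atol, rtol=rtol)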
+ +Examples:: + + >>> A = torch.eye(10) + >>> torch.linalg.matrix_rank(A) + tensor(10) + >>> B = torch.eye(10) + >>> B[0, 0] = 0 + >>> torch.linalg.matrix_rank(B) + tensor(9) + + >>> A = torch.randn(4, 3, 2) + >>> torch.linalg.matrix_rank(A) + tensor([2, 2, 2, 2]) + + >>> A = torch.randn(2, 4, 2, 3) + >>> torch.linalg.matrix_rank(A) + tensor([[2, 2, 2, 2], + [2, 2, 2, 2]]) + + >>> A = torch.randn(2, 4, 3, 3, dtype=torch.complex64) + >>> torch.linalg.matrix_rank(A) + tensor([[3, 3, 3, 3], + [3, 3, 3, 3]]) + >>> torch.linalg.matrix_rank(A, hermitian=True) + tensor([[3, 3, 3, 3], + [3, 3, 3, 3]]) + >>> torch.linalg.matrix_rank(A, atol=1.0, rtol=0.0) + tensor([[3, 2, 2, 2], + [1, 2, 1, 2]]) + >>> torch.linalg.matrix_rank(A, atol=1.0, rtol=0.0, hermitian=True) + tensor([[2, 2, 2, 1], + [1, 2, 2, 2]]) +""") + +norm = _add_docstr(_linalg.linalg_norm, r""" +linalg.norm(A, ord=None, dim=None, keepdim=False, *, out=None, dtype=None) -> Tensor + +Computes a vector or matrix norm. + +Supports input of float, double, cfloat and cdouble dtypes. + +Whether this function computes a vector or matrix norm is determined as follows: + +- If :attr:`dim` is an `int`, the vector norm will be computed. +- If :attr:`dim` is a `2`-`tuple`, the matrix norm will be computed. +- If :attr:`dim`\ `= None` and :attr:`ord`\ `= None`, + :attr:`A` will be flattened to 1D and the `2`-norm of the resulting vector will be computed. +- If :attr:`dim`\ `= None` and :attr:`ord` `!= None`, :attr:`A` must be 1D or 2D. + +:attr:`ord` defines the norm that is computed. The following norms are supported: + +====================== ========================= ======================================================== +:attr:`ord` norm for matrices norm for vectors +====================== ========================= ======================================================== +`None` (default) Frobenius norm `2`-norm (see below) +`'fro'` Frobenius norm -- not supported -- +`'nuc'` nuclear norm -- not supported -- +`inf` `max(sum(abs(x), dim=1))` `max(abs(x))` +`-inf` `min(sum(abs(x), dim=1))` `min(abs(x))` +`0` -- not supported -- `sum(x != 0)` +`1` `max(sum(abs(x), dim=0))` as below +`-1` `min(sum(abs(x), dim=0))` as below +`2` largest singular value as below +`-2` smallest singular value as below +other `int` or `float` -- not supported -- `sum(abs(x)^{ord})^{(1 / ord)}` +====================== ========================= ======================================================== + +where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. + +.. seealso:: + + :func:`torch.linalg.vector_norm` computes a vector norm. + + :func:`torch.linalg.matrix_norm` computes a matrix norm. + + The above functions are often clearer and more flexible than using :func:`torch.linalg.norm`. + For example, `torch.linalg.norm(A, ord=1, dim=(0, 1))` always + computes a matrix norm, but with `torch.linalg.vector_norm(A, ord=1, dim=(0, 1))` it is possible + to compute a vector norm over the two dimensions. + +Args: + A (Tensor): tensor of shape `(*, n)` or `(*, m, n)` where `*` is zero or more batch dimensions + ord (int, float, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `None` + dim (int, Tuple[int], optional): dimensions over which to compute + the vector or matrix norm. See above for the behavior when :attr:`dim`\ `= None`. + Default: `None` + keepdim (bool, optional): If set to `True`, the reduced dimensions are retained + in the result as dimensions with size one. 
Default: `False` + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + dtype (:class:`torch.dtype`, optional): If specified, the input tensor is cast to + :attr:`dtype` before performing the operation, and the returned tensor's type + will be :attr:`dtype`. Default: `None` + +Returns: + A real-valued tensor, even when :attr:`A` is complex. + +Examples:: + + >>> from torch import linalg as LA + >>> a = torch.arange(9, dtype=torch.float) - 4 + >>> a + tensor([-4., -3., -2., -1., 0., 1., 2., 3., 4.]) + >>> B = a.reshape((3, 3)) + >>> B + tensor([[-4., -3., -2.], + [-1., 0., 1.], + [ 2., 3., 4.]]) + + >>> LA.norm(a) + tensor(7.7460) + >>> LA.norm(B) + tensor(7.7460) + >>> LA.norm(B, 'fro') + tensor(7.7460) + >>> LA.norm(a, float('inf')) + tensor(4.) + >>> LA.norm(B, float('inf')) + tensor(9.) + >>> LA.norm(a, -float('inf')) + tensor(0.) + >>> LA.norm(B, -float('inf')) + tensor(2.) + + >>> LA.norm(a, 1) + tensor(20.) + >>> LA.norm(B, 1) + tensor(7.) + >>> LA.norm(a, -1) + tensor(0.) + >>> LA.norm(B, -1) + tensor(6.) + >>> LA.norm(a, 2) + tensor(7.7460) + >>> LA.norm(B, 2) + tensor(7.3485) + + >>> LA.norm(a, -2) + tensor(0.) + >>> LA.norm(B.double(), -2) + tensor(1.8570e-16, dtype=torch.float64) + >>> LA.norm(a, 3) + tensor(5.8480) + >>> LA.norm(a, -3) + tensor(0.) + +Using the :attr:`dim` argument to compute vector norms:: + + >>> c = torch.tensor([[1., 2., 3.], + ... [-1, 1, 4]]) + >>> LA.norm(c, dim=0) + tensor([1.4142, 2.2361, 5.0000]) + >>> LA.norm(c, dim=1) + tensor([3.7417, 4.2426]) + >>> LA.norm(c, ord=1, dim=1) + tensor([6., 6.]) + +Using the :attr:`dim` argument to compute matrix norms:: + + >>> A = torch.arange(8, dtype=torch.float).reshape(2, 2, 2) + >>> LA.norm(A, dim=(1,2)) + tensor([ 3.7417, 11.2250]) + >>> LA.norm(A[0, :, :]), LA.norm(A[1, :, :]) + (tensor(3.7417), tensor(11.2250)) +""") + +vector_norm = _add_docstr(_linalg.linalg_vector_norm, r""" +linalg.vector_norm(x, ord=2, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor + +Computes a vector norm. + +If :attr:`x` is complex valued, it computes the norm of :attr:`x`\ `.abs()` + +Supports input of float, double, cfloat and cdouble dtypes. + +This function does not necessarily treat multidimensional :attr:`x` as a batch of +vectors, instead: + +- If :attr:`dim`\ `= None`, :attr:`x` will be flattened before the norm is computed. +- If :attr:`dim` is an `int` or a `tuple`, the norm will be computed over these dimensions + and the other dimensions will be treated as batch dimensions. + +This behavior is for consistency with :func:`torch.linalg.norm`. + +:attr:`ord` defines the vector norm that is computed. The following norms are supported: + +====================== =============================== +:attr:`ord` vector norm +====================== =============================== +`2` (default) `2`-norm (see below) +`inf` `max(abs(x))` +`-inf` `min(abs(x))` +`0` `sum(x != 0)` +other `int` or `float` `sum(abs(x)^{ord})^{(1 / ord)}` +====================== =============================== + +where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. + +:attr:`dtype` may be used to perform the computation in a more precise dtype. +It is semantically equivalent to calling ``linalg.vector_norm(x.to(dtype))`` +but it is faster in some cases. + +.. seealso:: + + :func:`torch.linalg.matrix_norm` computes a matrix norm. + +Args: + x (Tensor): tensor, flattened by default, but this behavior can be + controlled using :attr:`dim`. 
+    ord (int, float, inf, -inf, optional): order of norm. Default: `2`
+    dim (int, Tuple[int], optional): dimensions over which to compute
+        the norm. See above for the behavior when :attr:`dim`\ `= None`.
+        Default: `None`
+    keepdim (bool, optional): If set to `True`, the reduced dimensions are retained
+        in the result as dimensions with size one. Default: `False`
+
+Keyword args:
+    out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.
+    dtype (:class:`torch.dtype`, optional): type used to perform the accumulation and the return.
+        If specified, :attr:`x` is cast to :attr:`dtype` before performing the operation,
+        and the returned tensor's type will be :attr:`dtype` if real and of its real counterpart if complex.
+        :attr:`dtype` may be complex if :attr:`x` is complex, otherwise it must be real.
+        :attr:`x` should be convertible without narrowing to :attr:`dtype`. Default: `None`
+
+Returns:
+    A real-valued tensor, even when :attr:`x` is complex.
+
+Examples::
+
+    >>> from torch import linalg as LA
+    >>> a = torch.arange(9, dtype=torch.float) - 4
+    >>> a
+    tensor([-4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.])
+    >>> B = a.reshape((3, 3))
+    >>> B
+    tensor([[-4., -3., -2.],
+            [-1.,  0.,  1.],
+            [ 2.,  3.,  4.]])
+    >>> LA.vector_norm(a, ord=3.5)
+    tensor(5.4345)
+    >>> LA.vector_norm(B, ord=3.5)
+    tensor(5.4345)
+""")
+
+matrix_norm = _add_docstr(_linalg.linalg_matrix_norm, r"""
+linalg.matrix_norm(A, ord='fro', dim=(-2, -1), keepdim=False, *, dtype=None, out=None) -> Tensor
+
+Computes a matrix norm.
+
+If :attr:`A` is complex valued, it computes the norm of :attr:`A`\ `.abs()`
+
+Supports input of float, double, cfloat and cdouble dtypes.
+Also supports batches of matrices: the norm will be computed over the
+dimensions specified by the 2-tuple :attr:`dim` and the other dimensions will
+be treated as batch dimensions. The output will have the same batch dimensions.
+
+:attr:`ord` defines the matrix norm that is computed. The following norms are supported:
+
+======================   ========================================================
+:attr:`ord`              matrix norm
+======================   ========================================================
+`'fro'` (default)        Frobenius norm
+`'nuc'`                  nuclear norm
+`inf`                    `max(sum(abs(x), dim=1))`
+`-inf`                   `min(sum(abs(x), dim=1))`
+`1`                      `max(sum(abs(x), dim=0))`
+`-1`                     `min(sum(abs(x), dim=0))`
+`2`                      largest singular value
+`-2`                     smallest singular value
+======================   ========================================================
+
+where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object.
+
+Args:
+    A (Tensor): tensor with two or more dimensions. By default its
+        shape is interpreted as `(*, m, n)` where `*` is zero or more
+        batch dimensions, but this behavior can be controlled using :attr:`dim`.
+    ord (int, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `'fro'`
+    dim (Tuple[int, int], optional): dimensions over which to compute the norm. Default: `(-2, -1)`
+    keepdim (bool, optional): If set to `True`, the reduced dimensions are retained
+        in the result as dimensions with size one. Default: `False`
+
+Keyword args:
+    out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.
+    dtype (:class:`torch.dtype`, optional): If specified, the input tensor is cast to
+        :attr:`dtype` before performing the operation, and the returned tensor's type
+        will be :attr:`dtype`. Default: `None`
+
+Returns:
+    A real-valued tensor, even when :attr:`A` is complex.
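+.. note::
+    As an informal cross-check (not part of the curated examples below), two of the orders in the
+    table can be related to other functions in this module; the test matrix here is an arbitrary
+    assumption::
+
+        A = torch.randn(4, 5)
+        # ord=2 is the largest singular value
+        assert torch.allclose(torch.linalg.matrix_norm(A, ord=2), torch.linalg.svdvals(A)[0])
+        # the Frobenius norm is the 2-norm of the flattened matrix
+        assert torch.allclose(torch.linalg.matrix_norm(A, ord='fro'), torch.linalg.vector_norm(A))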
+ +Examples:: + + >>> from torch import linalg as LA + >>> A = torch.arange(9, dtype=torch.float).reshape(3, 3) + >>> A + tensor([[0., 1., 2.], + [3., 4., 5.], + [6., 7., 8.]]) + >>> LA.matrix_norm(A) + tensor(14.2829) + >>> LA.matrix_norm(A, ord=-1) + tensor(9.) + >>> B = A.expand(2, -1, -1) + >>> B + tensor([[[0., 1., 2.], + [3., 4., 5.], + [6., 7., 8.]], + + [[0., 1., 2.], + [3., 4., 5.], + [6., 7., 8.]]]) + >>> LA.matrix_norm(B) + tensor([14.2829, 14.2829]) + >>> LA.matrix_norm(B, dim=(0, 2)) + tensor([ 3.1623, 10.0000, 17.2627]) +""") + +matmul = _add_docstr(_linalg.linalg_matmul, r""" +linalg.matmul(input, other, *, out=None) -> Tensor + +Alias for :func:`torch.matmul` +""") + +diagonal = _add_docstr(_linalg.linalg_diagonal, r""" +linalg.diagonal(A, *, offset=0, dim1=-2, dim2=-1) -> Tensor + +Alias for :func:`torch.diagonal` with defaults :attr:`dim1`\ `= -2`, :attr:`dim2`\ `= -1`. +""") + +multi_dot = _add_docstr(_linalg.linalg_multi_dot, r""" +linalg.multi_dot(tensors, *, out=None) + +Efficiently multiplies two or more matrices by reordering the multiplications so that +the fewest arithmetic operations are performed. + +Supports inputs of float, double, cfloat and cdouble dtypes. +This function does not support batched inputs. + +Every tensor in :attr:`tensors` must be 2D, except for the first and last which +may be 1D. If the first tensor is a 1D vector of shape `(n,)` it is treated as a row vector +of shape `(1, n)`, similarly if the last tensor is a 1D vector of shape `(n,)` it is treated +as a column vector of shape `(n, 1)`. + +If the first and last tensors are matrices, the output will be a matrix. +However, if either is a 1D vector, then the output will be a 1D vector. + +Differences with `numpy.linalg.multi_dot`: + +- Unlike `numpy.linalg.multi_dot`, the first and last tensors must either be 1D or 2D + whereas NumPy allows them to be nD + +.. warning:: This function does not broadcast. + +.. note:: This function is implemented by chaining :func:`torch.mm` calls after + computing the optimal matrix multiplication order. + +.. note:: The cost of multiplying two matrices with shapes `(a, b)` and `(b, c)` is + `a * b * c`. Given matrices `A`, `B`, `C` with shapes `(10, 100)`, + `(100, 5)`, `(5, 50)` respectively, we can calculate the cost of different + multiplication orders as follows: + + .. math:: + + \begin{align*} + \operatorname{cost}((AB)C) &= 10 \times 100 \times 5 + 10 \times 5 \times 50 = 7500 \\ + \operatorname{cost}(A(BC)) &= 10 \times 100 \times 50 + 100 \times 5 \times 50 = 75000 + \end{align*} + + In this case, multiplying `A` and `B` first followed by `C` is 10 times faster. + +Args: + tensors (Sequence[Tensor]): two or more tensors to multiply. The first and last + tensors may be 1D or 2D. Every other tensor must be 2D. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. 
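+.. note::
+    The cost bookkeeping from the note above can be spelled out directly; the shapes are the ones
+    used in that note and the snippet is purely illustrative::
+
+        # shapes (10, 100), (100, 5), (5, 50) from the note above
+        a, b, c, d = 10, 100, 5, 50
+        cost_AB_then_C = a * b * c + a * c * d    # (AB)C: 5000 + 2500 = 7500
+        cost_BC_then_A = b * c * d + a * b * d    # A(BC): 25000 + 50000 = 75000
+        # multi_dot picks the cheaper ordering automatically
+        A, B, C = torch.randn(a, b), torch.randn(b, c), torch.randn(c, d)
+        assert torch.allclose(torch.linalg.multi_dot((A, B, C)), A @ B @ C, atol=1e-4)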
+ +Examples:: + + >>> from torch.linalg import multi_dot + + >>> multi_dot([torch.tensor([1, 2]), torch.tensor([2, 3])]) + tensor(8) + >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([2, 3])]) + tensor([8]) + >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([[2], [3]])]) + tensor([[8]]) + + >>> A = torch.arange(2 * 3).view(2, 3) + >>> B = torch.arange(3 * 2).view(3, 2) + >>> C = torch.arange(2 * 2).view(2, 2) + >>> multi_dot((A, B, C)) + tensor([[ 26, 49], + [ 80, 148]]) +""") + +svd = _add_docstr(_linalg.linalg_svd, r""" +linalg.svd(A, full_matrices=True, *, driver=None, out=None) -> (Tensor, Tensor, Tensor) + +Computes the singular value decomposition (SVD) of a matrix. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **full SVD** of a matrix +:math:`A \in \mathbb{K}^{m \times n}`, if `k = min(m,n)`, is defined as + +.. math:: + + A = U \operatorname{diag}(S) V^{\text{H}} + \mathrlap{\qquad U \in \mathbb{K}^{m \times m}, S \in \mathbb{R}^k, V \in \mathbb{K}^{n \times n}} + +where :math:`\operatorname{diag}(S) \in \mathbb{K}^{m \times n}`, +:math:`V^{\text{H}}` is the conjugate transpose when :math:`V` is complex, and the transpose when :math:`V` is real-valued. +The matrices :math:`U`, :math:`V` (and thus :math:`V^{\text{H}}`) are orthogonal in the real case, and unitary in the complex case. + +When `m > n` (resp. `m < n`) we can drop the last `m - n` (resp. `n - m`) columns of `U` (resp. `V`) to form the **reduced SVD**: + +.. math:: + + A = U \operatorname{diag}(S) V^{\text{H}} + \mathrlap{\qquad U \in \mathbb{K}^{m \times k}, S \in \mathbb{R}^k, V \in \mathbb{K}^{k \times n}} + +where :math:`\operatorname{diag}(S) \in \mathbb{K}^{k \times k}`. +In this case, :math:`U` and :math:`V` also have orthonormal columns. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +The returned decomposition is a named tuple `(U, S, Vh)` +which corresponds to :math:`U`, :math:`S`, :math:`V^{\text{H}}` above. + +The singular values are returned in descending order. + +The parameter :attr:`full_matrices` chooses between the full (default) and reduced SVD. + +The :attr:`driver` kwarg may be used in CUDA with a cuSOLVER backend to choose the algorithm used to compute the SVD. +The choice of a driver is a trade-off between accuracy and speed. + +- If :attr:`A` is well-conditioned (its `condition number`_ is not too large), or you do not mind some precision loss. + + - For a general matrix: `'gesvdj'` (Jacobi method) + - If :attr:`A` is tall or wide (`m >> n` or `m << n`): `'gesvda'` (Approximate method) + +- If :attr:`A` is not well-conditioned or precision is relevant: `'gesvd'` (QR based) + +By default (:attr:`driver`\ `= None`), we call `'gesvdj'` and, if it fails, we fallback to `'gesvd'`. + +Differences with `numpy.linalg.svd`: + +- Unlike `numpy.linalg.svd`, this function always returns a tuple of three tensors + and it doesn't support `compute_uv` argument. + Please use :func:`torch.linalg.svdvals`, which computes only the singular values, + instead of `compute_uv=False`. + +.. note:: When :attr:`full_matrices`\ `= True`, the gradients with respect to `U[..., :, min(m, n):]` + and `Vh[..., min(m, n):, :]` will be ignored, as those vectors can be arbitrary bases + of the corresponding subspaces. + +.. warning:: The returned tensors `U` and `V` are not unique, nor are they continuous with + respect to :attr:`A`. 
+ Due to this lack of uniqueness, different hardware and software may compute + different singular vectors. + + This non-uniqueness is caused by the fact that multiplying any pair of singular + vectors :math:`u_k, v_k` by `-1` in the real case or by + :math:`e^{i \phi}, \phi \in \mathbb{R}` in the complex case produces another two + valid singular vectors of the matrix. + For this reason, the loss function shall not depend on this :math:`e^{i \phi}` quantity, + as it is not well-defined. + This is checked for complex inputs when computing the gradients of this function. As such, + when inputs are complex and are on a CUDA device, the computation of the gradients + of this function synchronizes that device with the CPU. + +.. warning:: Gradients computed using `U` or `Vh` will only be finite when + :attr:`A` does not have repeated singular values. If :attr:`A` is rectangular, + additionally, zero must also not be one of its singular values. + Furthermore, if the distance between any two singular values is close to zero, + the gradient will be numerically unstable, as it depends on the singular values + :math:`\sigma_i` through the computation of + :math:`\frac{1}{\min_{i \neq j} \sigma_i^2 - \sigma_j^2}`. + In the rectangular case, the gradient will also be numerically unstable when + :attr:`A` has small singular values, as it also depends on the computation of + :math:`\frac{1}{\sigma_i}`. + +.. seealso:: + + :func:`torch.linalg.svdvals` computes only the singular values. + Unlike :func:`torch.linalg.svd`, the gradients of :func:`~svdvals` are always + numerically stable. + + :func:`torch.linalg.eig` for a function that computes another type of spectral + decomposition of a matrix. The eigendecomposition works just on square matrices. + + :func:`torch.linalg.eigh` for a (faster) function that computes the eigenvalue decomposition + for Hermitian and symmetric matrices. + + :func:`torch.linalg.qr` for another (much faster) decomposition that works on general + matrices. + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + full_matrices (bool, optional): controls whether to compute the full or reduced + SVD, and consequently, + the shape of the returned tensors + `U` and `Vh`. Default: `True`. + +Keyword args: + driver (str, optional): name of the cuSOLVER method to be used. This keyword argument only works on CUDA inputs. + Available options are: `None`, `gesvd`, `gesvdj`, and `gesvda`. + Default: `None`. + out (tuple, optional): output tuple of three tensors. Ignored if `None`. + +Returns: + A named tuple `(U, S, Vh)` which corresponds to :math:`U`, :math:`S`, :math:`V^{\text{H}}` above. + + `S` will always be real-valued, even when :attr:`A` is complex. + It will also be ordered in descending order. + + `U` and `Vh` will have the same dtype as :attr:`A`. The left / right singular vectors will be given by + the columns of `U` and the rows of `Vh` respectively. 
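+.. note::
+    A small, informal sanity check of the reduced decomposition described above (the test matrix is
+    an arbitrary assumption)::
+
+        A = torch.randn(5, 3)
+        U, S, Vh = torch.linalg.svd(A, full_matrices=False)
+        k = min(A.shape[-2:])
+        assert U.shape == (5, k) and Vh.shape == (k, 3)
+        # U has orthonormal columns and A is recovered from the factors
+        assert torch.allclose(U.mT @ U, torch.eye(k), atol=1e-5)
+        assert torch.allclose(U @ torch.diag(S) @ Vh, A, atol=1e-5)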
+ +Examples:: + + >>> A = torch.randn(5, 3) + >>> U, S, Vh = torch.linalg.svd(A, full_matrices=False) + >>> U.shape, S.shape, Vh.shape + (torch.Size([5, 3]), torch.Size([3]), torch.Size([3, 3])) + >>> torch.dist(A, U @ torch.diag(S) @ Vh) + tensor(1.0486e-06) + + >>> U, S, Vh = torch.linalg.svd(A) + >>> U.shape, S.shape, Vh.shape + (torch.Size([5, 5]), torch.Size([3]), torch.Size([3, 3])) + >>> torch.dist(A, U[:, :3] @ torch.diag(S) @ Vh) + tensor(1.0486e-06) + + >>> A = torch.randn(7, 5, 3) + >>> U, S, Vh = torch.linalg.svd(A, full_matrices=False) + >>> torch.dist(A, U @ torch.diag_embed(S) @ Vh) + tensor(3.0957e-06) + +.. _condition number: + https://pytorch.org/docs/main/linalg.html#torch.linalg.cond +.. _the resulting vectors will span the same subspace: + https://en.wikipedia.org/wiki/Singular_value_decomposition#Singular_values,_singular_vectors,_and_their_relation_to_the_SVD +""") + +svdvals = _add_docstr(_linalg.linalg_svdvals, r""" +linalg.svdvals(A, *, driver=None, out=None) -> Tensor + +Computes the singular values of a matrix. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +The singular values are returned in descending order. + +.. note:: This function is equivalent to NumPy's `linalg.svd(A, compute_uv=False)`. + +""" + fr""" +.. note:: {common_notes["sync_note"]} +""" + r""" + +.. seealso:: + + :func:`torch.linalg.svd` computes the full singular value decomposition. + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + +Keyword args: + driver (str, optional): name of the cuSOLVER method to be used. This keyword argument only works on CUDA inputs. + Available options are: `None`, `gesvd`, `gesvdj`, and `gesvda`. + Check :func:`torch.linalg.svd` for details. + Default: `None`. + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Returns: + A real-valued tensor, even when :attr:`A` is complex. + +Examples:: + + >>> A = torch.randn(5, 3) + >>> S = torch.linalg.svdvals(A) + >>> S + tensor([2.5139, 2.1087, 1.1066]) + + >>> torch.dist(S, torch.linalg.svd(A, full_matrices=False).S) + tensor(2.4576e-07) +""") + +cond = _add_docstr(_linalg.linalg_cond, r""" +linalg.cond(A, p=None, *, out=None) -> Tensor + +Computes the condition number of a matrix with respect to a matrix norm. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **condition number** :math:`\kappa` of a matrix +:math:`A \in \mathbb{K}^{n \times n}` is defined as + +.. math:: + + \kappa(A) = \|A\|_p\|A^{-1}\|_p + +The condition number of :attr:`A` measures the numerical stability of the linear system `AX = B` +with respect to a matrix norm. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +:attr:`p` defines the matrix norm that is computed. 
The following norms are supported: + +========= ================================= +:attr:`p` matrix norm +========= ================================= +`None` `2`-norm (largest singular value) +`'fro'` Frobenius norm +`'nuc'` nuclear norm +`inf` `max(sum(abs(x), dim=1))` +`-inf` `min(sum(abs(x), dim=1))` +`1` `max(sum(abs(x), dim=0))` +`-1` `min(sum(abs(x), dim=0))` +`2` largest singular value +`-2` smallest singular value +========= ================================= + +where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. + +For :attr:`p` is one of `('fro', 'nuc', inf, -inf, 1, -1)`, this function uses +:func:`torch.linalg.norm` and :func:`torch.linalg.inv`. +As such, in this case, the matrix (or every matrix in the batch) :attr:`A` has to be square +and invertible. + +For :attr:`p` in `(2, -2)`, this function can be computed in terms of the singular values +:math:`\sigma_1 \geq \ldots \geq \sigma_n` + +.. math:: + + \kappa_2(A) = \frac{\sigma_1}{\sigma_n}\qquad \kappa_{-2}(A) = \frac{\sigma_n}{\sigma_1} + +In these cases, it is computed using :func:`torch.linalg.svdvals`. For these norms, the matrix +(or every matrix in the batch) :attr:`A` may have any shape. + +.. note :: When inputs are on a CUDA device, this function synchronizes that device with the CPU + if :attr:`p` is one of `('fro', 'nuc', inf, -inf, 1, -1)`. + +.. seealso:: + + :func:`torch.linalg.solve` for a function that solves linear systems of square matrices. + + :func:`torch.linalg.lstsq` for a function that solves linear systems of general matrices. + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions + for :attr:`p` in `(2, -2)`, and of shape `(*, n, n)` where every matrix + is invertible for :attr:`p` in `('fro', 'nuc', inf, -inf, 1, -1)`. + p (int, inf, -inf, 'fro', 'nuc', optional): + the type of the matrix norm to use in the computations (see above). Default: `None` + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Returns: + A real-valued tensor, even when :attr:`A` is complex. + +Raises: + RuntimeError: + if :attr:`p` is one of `('fro', 'nuc', inf, -inf, 1, -1)` + and the :attr:`A` matrix or any matrix in the batch :attr:`A` is not square + or invertible. + +Examples:: + + >>> A = torch.randn(3, 4, 4, dtype=torch.complex64) + >>> torch.linalg.cond(A) + >>> A = torch.tensor([[1., 0, -1], [0, 1, 0], [1, 0, 1]]) + >>> torch.linalg.cond(A) + tensor([1.4142]) + >>> torch.linalg.cond(A, 'fro') + tensor(3.1623) + >>> torch.linalg.cond(A, 'nuc') + tensor(9.2426) + >>> torch.linalg.cond(A, float('inf')) + tensor(2.) + >>> torch.linalg.cond(A, float('-inf')) + tensor(1.) + >>> torch.linalg.cond(A, 1) + tensor(2.) + >>> torch.linalg.cond(A, -1) + tensor(1.) + >>> torch.linalg.cond(A, 2) + tensor([1.4142]) + >>> torch.linalg.cond(A, -2) + tensor([0.7071]) + + >>> A = torch.randn(2, 3, 3) + >>> torch.linalg.cond(A) + tensor([[9.5917], + [3.2538]]) + >>> A = torch.randn(2, 3, 3, dtype=torch.complex64) + >>> torch.linalg.cond(A) + tensor([[4.6245], + [4.5671]]) +""") + +pinv = _add_docstr(_linalg.linalg_pinv, r""" +linalg.pinv(A, *, atol=None, rtol=None, hermitian=False, out=None) -> Tensor + +Computes the pseudoinverse (Moore-Penrose inverse) of a matrix. + +The pseudoinverse may be `defined algebraically`_ +but it is more computationally convenient to understand it `through the SVD`_ + +Supports input of float, double, cfloat and cdouble dtypes. 
+Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +If :attr:`hermitian`\ `= True`, :attr:`A` is assumed to be Hermitian if complex or +symmetric if real, but this is not checked internally. Instead, just the lower +triangular part of the matrix is used in the computations. + +The singular values (or the norm of the eigenvalues when :attr:`hermitian`\ `= True`) +that are below :math:`\max(\text{atol}, \sigma_1 \cdot \text{rtol})` threshold are +treated as zero and discarded in the computation, +where :math:`\sigma_1` is the largest singular value (or eigenvalue). + +If :attr:`rtol` is not specified and :attr:`A` is a matrix of dimensions `(m, n)`, +the relative tolerance is set to be :math:`\text{rtol} = \max(m, n) \varepsilon` +and :math:`\varepsilon` is the epsilon value for the dtype of :attr:`A` (see :class:`.finfo`). +If :attr:`rtol` is not specified and :attr:`atol` is specified to be larger than zero then +:attr:`rtol` is set to zero. + +If :attr:`atol` or :attr:`rtol` is a :class:`torch.Tensor`, its shape must be broadcastable to that +of the singular values of :attr:`A` as returned by :func:`torch.linalg.svd`. + +.. note:: This function uses :func:`torch.linalg.svd` if :attr:`hermitian`\ `= False` and + :func:`torch.linalg.eigh` if :attr:`hermitian`\ `= True`. + For CUDA inputs, this function synchronizes that device with the CPU. + +.. note:: + Consider using :func:`torch.linalg.lstsq` if possible for multiplying a matrix on the left by + the pseudoinverse, as:: + + torch.linalg.lstsq(A, B).solution == A.pinv() @ B + + It is always preferred to use :func:`~lstsq` when possible, as it is faster and more + numerically stable than computing the pseudoinverse explicitly. + +.. note:: + This function has NumPy compatible variant `linalg.pinv(A, rcond, hermitian=False)`. + However, use of the positional argument :attr:`rcond` is deprecated in favor of :attr:`rtol`. + +.. warning:: + This function uses internally :func:`torch.linalg.svd` (or :func:`torch.linalg.eigh` + when :attr:`hermitian`\ `= True`), so its derivative has the same problems as those of these + functions. See the warnings in :func:`torch.linalg.svd` and :func:`torch.linalg.eigh` for + more details. + +.. seealso:: + + :func:`torch.linalg.inv` computes the inverse of a square matrix. + + :func:`torch.linalg.lstsq` computes :attr:`A`\ `.pinv() @ \ `:attr:`B` with a + numerically stable algorithm. + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + rcond (float, Tensor, optional): [NumPy Compat]. Alias for :attr:`rtol`. Default: `None`. + +Keyword args: + atol (float, Tensor, optional): the absolute tolerance value. When `None` it's considered to be zero. + Default: `None`. + rtol (float, Tensor, optional): the relative tolerance value. See above for the value it takes when `None`. + Default: `None`. + hermitian(bool, optional): indicates whether :attr:`A` is Hermitian if complex + or symmetric if real. Default: `False`. + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. 
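+
+.. note::
+    A minimal sketch of the equivalence with :func:`torch.linalg.lstsq` mentioned above; for a
+    full-column-rank matrix the least-squares solution and the explicit pseudoinverse product
+    agree up to rounding error::
+
+        A = torch.randn(5, 3)  # tall matrix, full column rank with probability 1
+        B = torch.randn(5, 2)
+        X_lstsq = torch.linalg.lstsq(A, B).solution
+        X_pinv = torch.linalg.pinv(A) @ B
+        torch.allclose(X_lstsq, X_pinv, atol=1e-5)  # True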
+ +Examples:: + + >>> A = torch.randn(3, 5) + >>> A + tensor([[ 0.5495, 0.0979, -1.4092, -0.1128, 0.4132], + [-1.1143, -0.3662, 0.3042, 1.6374, -0.9294], + [-0.3269, -0.5745, -0.0382, -0.5922, -0.6759]]) + >>> torch.linalg.pinv(A) + tensor([[ 0.0600, -0.1933, -0.2090], + [-0.0903, -0.0817, -0.4752], + [-0.7124, -0.1631, -0.2272], + [ 0.1356, 0.3933, -0.5023], + [-0.0308, -0.1725, -0.5216]]) + + >>> A = torch.randn(2, 6, 3) + >>> Apinv = torch.linalg.pinv(A) + >>> torch.dist(Apinv @ A, torch.eye(3)) + tensor(8.5633e-07) + + >>> A = torch.randn(3, 3, dtype=torch.complex64) + >>> A = A + A.T.conj() # creates a Hermitian matrix + >>> Apinv = torch.linalg.pinv(A, hermitian=True) + >>> torch.dist(Apinv @ A, torch.eye(3)) + tensor(1.0830e-06) + +.. _defined algebraically: + https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Existence_and_uniqueness +.. _through the SVD: + https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Singular_value_decomposition_(SVD) +""") + +matrix_exp = _add_docstr(_linalg.linalg_matrix_exp, r""" +linalg.matrix_exp(A) -> Tensor + +Computes the matrix exponential of a square matrix. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +this function computes the **matrix exponential** of :math:`A \in \mathbb{K}^{n \times n}`, which is defined as + +.. math:: + \mathrm{matrix\_exp}(A) = \sum_{k=0}^\infty \frac{1}{k!}A^k \in \mathbb{K}^{n \times n}. + +If the matrix :math:`A` has eigenvalues :math:`\lambda_i \in \mathbb{C}`, +the matrix :math:`\mathrm{matrix\_exp}(A)` has eigenvalues :math:`e^{\lambda_i} \in \mathbb{C}`. + +Supports input of bfloat16, float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. + +Example:: + + >>> A = torch.empty(2, 2, 2) + >>> A[0, :, :] = torch.eye(2, 2) + >>> A[1, :, :] = 2 * torch.eye(2, 2) + >>> A + tensor([[[1., 0.], + [0., 1.]], + + [[2., 0.], + [0., 2.]]]) + >>> torch.linalg.matrix_exp(A) + tensor([[[2.7183, 0.0000], + [0.0000, 2.7183]], + + [[7.3891, 0.0000], + [0.0000, 7.3891]]]) + + >>> import math + >>> A = torch.tensor([[0, math.pi/3], [-math.pi/3, 0]]) # A is skew-symmetric + >>> torch.linalg.matrix_exp(A) # matrix_exp(A) = [[cos(pi/3), sin(pi/3)], [-sin(pi/3), cos(pi/3)]] + tensor([[ 0.5000, 0.8660], + [-0.8660, 0.5000]]) +""") + + +solve = _add_docstr(_linalg.linalg_solve, r""" +linalg.solve(A, B, *, left=True, out=None) -> Tensor + +Computes the solution of a square system of linear equations with a unique solution. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +this function computes the solution :math:`X \in \mathbb{K}^{n \times k}` of the **linear system** associated to +:math:`A \in \mathbb{K}^{n \times n}, B \in \mathbb{K}^{n \times k}`, which is defined as + +.. math:: AX = B + +If :attr:`left`\ `= False`, this function returns the matrix :math:`X \in \mathbb{K}^{n \times k}` that solves the system + +.. math:: + + XA = B\mathrlap{\qquad A \in \mathbb{K}^{k \times k}, B \in \mathbb{K}^{n \times k}.} + +This system of linear equations has one solution if and only if :math:`A` is `invertible`_. +This function assumes that :math:`A` is invertible. + +Supports inputs of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if the inputs are batches of matrices then +the output has the same batch dimensions. 
+ +Letting `*` be zero or more batch dimensions, + +- If :attr:`A` has shape `(*, n, n)` and :attr:`B` has shape `(*, n)` (a batch of vectors) or shape + `(*, n, k)` (a batch of matrices or "multiple right-hand sides"), this function returns `X` of shape + `(*, n)` or `(*, n, k)` respectively. +- Otherwise, if :attr:`A` has shape `(*, n, n)` and :attr:`B` has shape `(n,)` or `(n, k)`, :attr:`B` + is broadcasted to have shape `(*, n)` or `(*, n, k)` respectively. + This function then returns the solution of the resulting batch of systems of linear equations. + +.. note:: + This function computes `X = \ `:attr:`A`\ `.inverse() @ \ `:attr:`B` in a faster and + more numerically stable way than performing the computations separately. + +.. note:: + It is possible to compute the solution of the system :math:`XA = B` by passing the inputs + :attr:`A` and :attr:`B` transposed and transposing the output returned by this function. + +.. note:: + :attr:`A` is allowed to be a non-batched `torch.sparse_csr_tensor`, but only with `left=True`. + +""" + fr""" +.. note:: {common_notes["sync_note_has_ex"].format("torch.linalg.solve_ex")} +""" + r""" + +.. seealso:: + + :func:`torch.linalg.solve_triangular` computes the solution of a triangular system of linear + equations with a unique solution. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. + B (Tensor): right-hand side tensor of shape `(*, n)` or `(*, n, k)` or `(n,)` or `(n, k)` + according to the rules described above + +Keyword args: + left (bool, optional): whether to solve the system :math:`AX=B` or :math:`XA = B`. Default: `True`. + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Raises: + RuntimeError: if the :attr:`A` matrix is not invertible or any matrix in a batched :attr:`A` + is not invertible. + +Examples:: + + >>> A = torch.randn(3, 3) + >>> b = torch.randn(3) + >>> x = torch.linalg.solve(A, b) + >>> torch.allclose(A @ x, b) + True + >>> A = torch.randn(2, 3, 3) + >>> B = torch.randn(2, 3, 4) + >>> X = torch.linalg.solve(A, B) + >>> X.shape + torch.Size([2, 3, 4]) + >>> torch.allclose(A @ X, B) + True + + >>> A = torch.randn(2, 3, 3) + >>> b = torch.randn(3, 1) + >>> x = torch.linalg.solve(A, b) # b is broadcasted to size (2, 3, 1) + >>> x.shape + torch.Size([2, 3, 1]) + >>> torch.allclose(A @ x, b) + True + >>> b = torch.randn(3) + >>> x = torch.linalg.solve(A, b) # b is broadcasted to size (2, 3) + >>> x.shape + torch.Size([2, 3]) + >>> Ax = A @ x.unsqueeze(-1) + >>> torch.allclose(Ax, b.unsqueeze(-1).expand_as(Ax)) + True + +.. _invertible: + https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem +""") + +solve_triangular = _add_docstr(_linalg.linalg_solve_triangular, r""" +linalg.solve_triangular(A, B, *, upper, left=True, unitriangular=False, out=None) -> Tensor + +Computes the solution of a triangular system of linear equations with a unique solution. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +this function computes the solution :math:`X \in \mathbb{K}^{n \times k}` of the **linear system** +associated to the triangular matrix :math:`A \in \mathbb{K}^{n \times n}` without zeros on the diagonal +(that is, it is `invertible`_) and the rectangular matrix , :math:`B \in \mathbb{K}^{n \times k}`, +which is defined as + +.. math:: AX = B + +The argument :attr:`upper` signals whether :math:`A` is upper or lower triangular. 
+ +If :attr:`left`\ `= False`, this function returns the matrix :math:`X \in \mathbb{K}^{n \times k}` that +solves the system + +.. math:: + + XA = B\mathrlap{\qquad A \in \mathbb{K}^{k \times k}, B \in \mathbb{K}^{n \times k}.} + +If :attr:`upper`\ `= True` (resp. `False`) just the upper (resp. lower) triangular half of :attr:`A` +will be accessed. The elements below the main diagonal will be considered to be zero and will not be accessed. + +If :attr:`unitriangular`\ `= True`, the diagonal of :attr:`A` is assumed to be ones and will not be accessed. + +The result may contain `NaN` s if the diagonal of :attr:`A` contains zeros or elements that +are very close to zero and :attr:`unitriangular`\ `= False` (default) or if the input matrix +has very small eigenvalues. + +Supports inputs of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if the inputs are batches of matrices then +the output has the same batch dimensions. + +.. seealso:: + + :func:`torch.linalg.solve` computes the solution of a general square system of linear + equations with a unique solution. + +Args: + A (Tensor): tensor of shape `(*, n, n)` (or `(*, k, k)` if :attr:`left`\ `= False`) + where `*` is zero or more batch dimensions. + B (Tensor): right-hand side tensor of shape `(*, n, k)`. + +Keyword args: + upper (bool): whether :attr:`A` is an upper or lower triangular matrix. + left (bool, optional): whether to solve the system :math:`AX=B` or :math:`XA = B`. Default: `True`. + unitriangular (bool, optional): if `True`, the diagonal elements of :attr:`A` are assumed to be + all equal to `1`. Default: `False`. + out (Tensor, optional): output tensor. `B` may be passed as `out` and the result is computed in-place on `B`. + Ignored if `None`. Default: `None`. + +Examples:: + + >>> A = torch.randn(3, 3).triu_() + >>> B = torch.randn(3, 4) + >>> X = torch.linalg.solve_triangular(A, B, upper=True) + >>> torch.allclose(A @ X, B) + True + + >>> A = torch.randn(2, 3, 3).tril_() + >>> B = torch.randn(2, 3, 4) + >>> X = torch.linalg.solve_triangular(A, B, upper=False) + >>> torch.allclose(A @ X, B) + True + + >>> A = torch.randn(2, 4, 4).tril_() + >>> B = torch.randn(2, 3, 4) + >>> X = torch.linalg.solve_triangular(A, B, upper=False, left=False) + >>> torch.allclose(X @ A, B) + True + +.. _invertible: + https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem +""") + +lu_factor = _add_docstr(_linalg.linalg_lu_factor, r""" +linalg.lu_factor(A, *, bool pivot=True, out=None) -> (Tensor, Tensor) + +Computes a compact representation of the LU factorization with partial pivoting of a matrix. + +This function computes a compact representation of the decomposition given by :func:`torch.linalg.lu`. +If the matrix is square, this representation may be used in :func:`torch.linalg.lu_solve` +to solve system of linear equations that share the matrix :attr:`A`. + +The returned decomposition is represented as a named tuple `(LU, pivots)`. +The ``LU`` matrix has the same shape as the input matrix ``A``. Its upper and lower triangular +parts encode the non-constant elements of ``L`` and ``U`` of the LU decomposition of ``A``. + +The returned permutation matrix is represented by a 1-indexed vector. `pivots[i] == j` represents +that in the `i`-th step of the algorithm, the `i`-th row was permuted with the `j-1`-th row. + +On CUDA, one may use :attr:`pivot`\ `= False`. In this case, this function returns the LU +decomposition without pivoting if it exists. 
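+
+A minimal sketch of how the compact representation above relates to the explicit factors,
+using :func:`torch.lu_unpack` and assuming a square, well-conditioned input::
+
+    A = torch.randn(3, 3)
+    LU, pivots = torch.linalg.lu_factor(A)
+    P, L, U = torch.lu_unpack(LU, pivots)  # expand the compact form into P, L, U
+    torch.allclose(P @ L @ U, A)           # True, A is recovered up to rounding error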
+ +Supports inputs of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if the inputs are batches of matrices then +the output has the same batch dimensions. + +""" + fr""" +.. note:: {common_notes["sync_note_has_ex"].format("torch.linalg.lu_factor_ex")} +""" + r""" +.. warning:: The LU decomposition is almost never unique, as often there are different permutation + matrices that can yield different LU decompositions. + As such, different platforms, like SciPy, or inputs on different devices, + may produce different valid decompositions. + + Gradient computations are only supported if the input matrix is full-rank. + If this condition is not met, no error will be thrown, but the gradient may not be finite. + This is because the LU decomposition with pivoting is not differentiable at these points. + +.. seealso:: + + :func:`torch.linalg.lu_solve` solves a system of linear equations given the output of this + function provided the input matrix was square and invertible. + + :func:`torch.lu_unpack` unpacks the tensors returned by :func:`~lu_factor` into the three + matrices `P, L, U` that form the decomposition. + + :func:`torch.linalg.lu` computes the LU decomposition with partial pivoting of a possibly + non-square matrix. It is a composition of :func:`~lu_factor` and :func:`torch.lu_unpack`. + + :func:`torch.linalg.solve` solves a system of linear equations. It is a composition + of :func:`~lu_factor` and :func:`~lu_solve`. + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + +Keyword args: + pivot (bool, optional): Whether to compute the LU decomposition with partial pivoting, or the regular LU + decomposition. :attr:`pivot`\ `= False` not supported on CPU. Default: `True`. + out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(LU, pivots)`. + +Raises: + RuntimeError: if the :attr:`A` matrix is not invertible or any matrix in a batched :attr:`A` + is not invertible. + +Examples:: + + >>> A = torch.randn(2, 3, 3) + >>> B1 = torch.randn(2, 3, 4) + >>> B2 = torch.randn(2, 3, 7) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> X1 = torch.linalg.lu_solve(LU, pivots, B1) + >>> X2 = torch.linalg.lu_solve(LU, pivots, B2) + >>> torch.allclose(A @ X1, B1) + True + >>> torch.allclose(A @ X2, B2) + True + +.. _invertible: + https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem +""") + +lu_factor_ex = _add_docstr(_linalg.linalg_lu_factor_ex, r""" +linalg.lu_factor_ex(A, *, pivot=True, check_errors=False, out=None) -> (Tensor, Tensor, Tensor) + +This is a version of :func:`~lu_factor` that does not perform error checks unless :attr:`check_errors`\ `= True`. +It also returns the :attr:`info` tensor returned by `LAPACK's getrf`_. + +""" + fr""" +.. note:: {common_notes["sync_note_ex"]} + +.. warning:: {common_notes["experimental_warning"]} +""" + r""" + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + +Keyword args: + pivot (bool, optional): Whether to compute the LU decomposition with partial pivoting, or the regular LU + decomposition. :attr:`pivot`\ `= False` not supported on CPU. Default: `True`. + check_errors (bool, optional): controls whether to check the content of ``infos`` and raise + an error if it is non-zero. Default: `False`. + out (tuple, optional): tuple of three tensors to write the output to. Ignored if `None`. Default: `None`. 
+ +Returns: + A named tuple `(LU, pivots, info)`. + +.. _LAPACK's getrf: + https://www.netlib.org/lapack/explore-html/dd/d9a/group__double_g_ecomputational_ga0019443faea08275ca60a734d0593e60.html +""") + +lu_solve = _add_docstr(_linalg.linalg_lu_solve, r""" +linalg.lu_solve(LU, pivots, B, *, left=True, adjoint=False, out=None) -> Tensor + +Computes the solution of a square system of linear equations with a unique solution given an LU decomposition. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +this function computes the solution :math:`X \in \mathbb{K}^{n \times k}` of the **linear system** associated to +:math:`A \in \mathbb{K}^{n \times n}, B \in \mathbb{K}^{n \times k}`, which is defined as + +.. math:: AX = B + +where :math:`A` is given factorized as returned by :func:`~lu_factor`. + +If :attr:`left`\ `= False`, this function returns the matrix :math:`X \in \mathbb{K}^{n \times k}` that solves the system + +.. math:: + + XA = B\mathrlap{\qquad A \in \mathbb{K}^{k \times k}, B \in \mathbb{K}^{n \times k}.} + +If :attr:`adjoint`\ `= True` (and :attr:`left`\ `= True`), given an LU factorization of :math:`A` +this function function returns the :math:`X \in \mathbb{K}^{n \times k}` that solves the system + +.. math:: + + A^{\text{H}}X = B\mathrlap{\qquad A \in \mathbb{K}^{k \times k}, B \in \mathbb{K}^{n \times k}.} + +where :math:`A^{\text{H}}` is the conjugate transpose when :math:`A` is complex, and the +transpose when :math:`A` is real-valued. The :attr:`left`\ `= False` case is analogous. + +Supports inputs of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if the inputs are batches of matrices then +the output has the same batch dimensions. + +Args: + LU (Tensor): tensor of shape `(*, n, n)` (or `(*, k, k)` if :attr:`left`\ `= True`) + where `*` is zero or more batch dimensions as returned by :func:`~lu_factor`. + pivots (Tensor): tensor of shape `(*, n)` (or `(*, k)` if :attr:`left`\ `= True`) + where `*` is zero or more batch dimensions as returned by :func:`~lu_factor`. + B (Tensor): right-hand side tensor of shape `(*, n, k)`. + +Keyword args: + left (bool, optional): whether to solve the system :math:`AX=B` or :math:`XA = B`. Default: `True`. + adjoint (bool, optional): whether to solve the system :math:`AX=B` or :math:`A^{\text{H}}X = B`. Default: `False`. + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Examples:: + + >>> A = torch.randn(3, 3) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> B = torch.randn(3, 2) + >>> X = torch.linalg.lu_solve(LU, pivots, B) + >>> torch.allclose(A @ X, B) + True + + >>> B = torch.randn(3, 3, 2) # Broadcasting rules apply: A is broadcasted + >>> X = torch.linalg.lu_solve(LU, pivots, B) + >>> torch.allclose(A @ X, B) + True + + >>> B = torch.randn(3, 5, 3) + >>> X = torch.linalg.lu_solve(LU, pivots, B, left=False) + >>> torch.allclose(X @ A, B) + True + + >>> B = torch.randn(3, 3, 4) # Now solve for A^T + >>> X = torch.linalg.lu_solve(LU, pivots, B, adjoint=True) + >>> torch.allclose(A.mT @ X, B) + True + +.. _invertible: + https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem +""") + +lu = _add_docstr(_linalg.linalg_lu, r""" +lu(A, *, pivot=True, out=None) -> (Tensor, Tensor, Tensor) + +Computes the LU decomposition with partial pivoting of a matrix. 
+ +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **LU decomposition with partial pivoting** of a matrix +:math:`A \in \mathbb{K}^{m \times n}` is defined as + +.. math:: + + A = PLU\mathrlap{\qquad P \in \mathbb{K}^{m \times m}, L \in \mathbb{K}^{m \times k}, U \in \mathbb{K}^{k \times n}} + +where `k = min(m,n)`, :math:`P` is a `permutation matrix`_, :math:`L` is lower triangular with ones on the diagonal +and :math:`U` is upper triangular. + +If :attr:`pivot`\ `= False` and :attr:`A` is on GPU, then the **LU decomposition without pivoting** is computed + +.. math:: + + A = LU\mathrlap{\qquad L \in \mathbb{K}^{m \times k}, U \in \mathbb{K}^{k \times n}} + +When :attr:`pivot`\ `= False`, the returned matrix :attr:`P` will be empty. +The LU decomposition without pivoting `may not exist`_ if any of the principal minors of :attr:`A` is singular. +In this case, the output matrix may contain `inf` or `NaN`. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +.. seealso:: + + :func:`torch.linalg.solve` solves a system of linear equations using the LU decomposition + with partial pivoting. + +.. warning:: The LU decomposition is almost never unique, as often there are different permutation + matrices that can yield different LU decompositions. + As such, different platforms, like SciPy, or inputs on different devices, + may produce different valid decompositions. + +.. warning:: Gradient computations are only supported if the input matrix is full-rank. + If this condition is not met, no error will be thrown, but the gradient + may not be finite. + This is because the LU decomposition with pivoting is not differentiable at these points. + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + pivot (bool, optional): Controls whether to compute the LU decomposition with partial pivoting or + no pivoting. Default: `True`. + +Keyword args: + out (tuple, optional): output tuple of three tensors. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(P, L, U)`. + +Examples:: + + >>> A = torch.randn(3, 2) + >>> P, L, U = torch.linalg.lu(A) + >>> P + tensor([[0., 1., 0.], + [0., 0., 1.], + [1., 0., 0.]]) + >>> L + tensor([[1.0000, 0.0000], + [0.5007, 1.0000], + [0.0633, 0.9755]]) + >>> U + tensor([[0.3771, 0.0489], + [0.0000, 0.9644]]) + >>> torch.dist(A, P @ L @ U) + tensor(5.9605e-08) + + >>> A = torch.randn(2, 5, 7, device="cuda") + >>> P, L, U = torch.linalg.lu(A, pivot=False) + >>> P + tensor([], device='cuda:0') + >>> torch.dist(A, L @ U) + tensor(1.0376e-06, device='cuda:0') + +.. _permutation matrix: + https://en.wikipedia.org/wiki/Permutation_matrix +.. _may not exist: + https://en.wikipedia.org/wiki/LU_decomposition#Definitions +""") + +tensorinv = _add_docstr(_linalg.linalg_tensorinv, r""" +linalg.tensorinv(A, ind=2, *, out=None) -> Tensor + +Computes the multiplicative inverse of :func:`torch.tensordot`. + +If `m` is the product of the first :attr:`ind` dimensions of :attr:`A` and `n` is the product of +the rest of the dimensions, this function expects `m` and `n` to be equal. +If this is the case, it computes a tensor `X` such that +`tensordot(\ `:attr:`A`\ `, X, \ `:attr:`ind`\ `)` is the identity matrix in dimension `m`. +`X` will have the shape of :attr:`A` but with the first :attr:`ind` dimensions pushed back to the end + +.. 
code:: text + + X.shape == A.shape[ind:] + A.shape[:ind] + +Supports input of float, double, cfloat and cdouble dtypes. + +.. note:: When :attr:`A` is a `2`-dimensional tensor and :attr:`ind`\ `= 1`, + this function computes the (multiplicative) inverse of :attr:`A` + (see :func:`torch.linalg.inv`). + +.. note:: + Consider using :func:`torch.linalg.tensorsolve` if possible for multiplying a tensor on the left + by the tensor inverse, as:: + + linalg.tensorsolve(A, B) == torch.tensordot(linalg.tensorinv(A), B) # When B is a tensor with shape A.shape[:B.ndim] + + It is always preferred to use :func:`~tensorsolve` when possible, as it is faster and more + numerically stable than computing the pseudoinverse explicitly. + +.. seealso:: + + :func:`torch.linalg.tensorsolve` computes + `torch.tensordot(tensorinv(\ `:attr:`A`\ `), \ `:attr:`B`\ `)`. + +Args: + A (Tensor): tensor to invert. Its shape must satisfy + `prod(\ `:attr:`A`\ `.shape[:\ `:attr:`ind`\ `]) == + prod(\ `:attr:`A`\ `.shape[\ `:attr:`ind`\ `:])`. + ind (int): index at which to compute the inverse of :func:`torch.tensordot`. Default: `2`. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Raises: + RuntimeError: if the reshaped :attr:`A` is not invertible or the product of the first + :attr:`ind` dimensions is not equal to the product of the rest. + +Examples:: + + >>> A = torch.eye(4 * 6).reshape((4, 6, 8, 3)) + >>> Ainv = torch.linalg.tensorinv(A, ind=2) + >>> Ainv.shape + torch.Size([8, 3, 4, 6]) + >>> B = torch.randn(4, 6) + >>> torch.allclose(torch.tensordot(Ainv, B), torch.linalg.tensorsolve(A, B)) + True + + >>> A = torch.randn(4, 4) + >>> Atensorinv = torch.linalg.tensorinv(A, ind=1) + >>> Ainv = torch.linalg.inv(A) + >>> torch.allclose(Atensorinv, Ainv) + True +""") + +tensorsolve = _add_docstr(_linalg.linalg_tensorsolve, r""" +linalg.tensorsolve(A, B, dims=None, *, out=None) -> Tensor + +Computes the solution `X` to the system `torch.tensordot(A, X) = B`. + +If `m` is the product of the first :attr:`B`\ `.ndim` dimensions of :attr:`A` and +`n` is the product of the rest of the dimensions, this function expects `m` and `n` to be equal. + +The returned tensor `x` satisfies +`tensordot(\ `:attr:`A`\ `, x, dims=x.ndim) == \ `:attr:`B`. +`x` has shape :attr:`A`\ `[B.ndim:]`. + +If :attr:`dims` is specified, :attr:`A` will be reshaped as + +.. code:: text + + A = movedim(A, dims, range(len(dims) - A.ndim + 1, 0)) + +Supports inputs of float, double, cfloat and cdouble dtypes. + +.. seealso:: + + :func:`torch.linalg.tensorinv` computes the multiplicative inverse of + :func:`torch.tensordot`. + +Args: + A (Tensor): tensor to solve for. Its shape must satisfy + `prod(\ `:attr:`A`\ `.shape[:\ `:attr:`B`\ `.ndim]) == + prod(\ `:attr:`A`\ `.shape[\ `:attr:`B`\ `.ndim:])`. + B (Tensor): tensor of shape :attr:`A`\ `.shape[:\ `:attr:`B`\ `.ndim]`. + dims (Tuple[int], optional): dimensions of :attr:`A` to be moved. + If `None`, no dimensions are moved. Default: `None`. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Raises: + RuntimeError: if the reshaped :attr:`A`\ `.view(m, m)` with `m` as above is not + invertible or the product of the first :attr:`ind` dimensions is not equal + to the product of the rest of the dimensions. 
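+
+.. note::
+    A minimal sketch of the relation to :func:`torch.linalg.tensorinv` noted above; an
+    identity-like tensor is used purely for illustration::
+
+        A = torch.eye(2 * 3).reshape(2, 3, 2, 3)
+        B = torch.randn(2, 3)
+        X1 = torch.linalg.tensorsolve(A, B)
+        X2 = torch.tensordot(torch.linalg.tensorinv(A, ind=2), B, dims=B.ndim)
+        torch.allclose(X1, X2)  # True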
+ +Examples:: + + >>> A = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4)) + >>> B = torch.randn(2 * 3, 4) + >>> X = torch.linalg.tensorsolve(A, B) + >>> X.shape + torch.Size([2, 3, 4]) + >>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B) + True + + >>> A = torch.randn(6, 4, 4, 3, 2) + >>> B = torch.randn(4, 3, 2) + >>> X = torch.linalg.tensorsolve(A, B, dims=(0, 2)) + >>> X.shape + torch.Size([6, 4]) + >>> A = A.permute(1, 3, 4, 0, 2) + >>> A.shape[B.ndim:] + torch.Size([6, 4]) + >>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B, atol=1e-6) + True +""") + +qr = _add_docstr(_linalg.linalg_qr, r""" +qr(A, mode='reduced', *, out=None) -> (Tensor, Tensor) + +Computes the QR decomposition of a matrix. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **full QR decomposition** of a matrix +:math:`A \in \mathbb{K}^{m \times n}` is defined as + +.. math:: + + A = QR\mathrlap{\qquad Q \in \mathbb{K}^{m \times m}, R \in \mathbb{K}^{m \times n}} + +where :math:`Q` is orthogonal in the real case and unitary in the complex case, +and :math:`R` is upper triangular with real diagonal (even in the complex case). + +When `m > n` (tall matrix), as `R` is upper triangular, its last `m - n` rows are zero. +In this case, we can drop the last `m - n` columns of `Q` to form the +**reduced QR decomposition**: + +.. math:: + + A = QR\mathrlap{\qquad Q \in \mathbb{K}^{m \times n}, R \in \mathbb{K}^{n \times n}} + +The reduced QR decomposition agrees with the full QR decomposition when `n >= m` (wide matrix). + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +The parameter :attr:`mode` chooses between the full and reduced QR decomposition. +If :attr:`A` has shape `(*, m, n)`, denoting `k = min(m, n)` + +- :attr:`mode`\ `= 'reduced'` (default): Returns `(Q, R)` of shapes `(*, m, k)`, `(*, k, n)` respectively. + It is always differentiable. +- :attr:`mode`\ `= 'complete'`: Returns `(Q, R)` of shapes `(*, m, m)`, `(*, m, n)` respectively. + It is differentiable for `m <= n`. +- :attr:`mode`\ `= 'r'`: Computes only the reduced `R`. Returns `(Q, R)` with `Q` empty and `R` of shape `(*, k, n)`. + It is never differentiable. + +Differences with `numpy.linalg.qr`: + +- :attr:`mode`\ `= 'raw'` is not implemented. +- Unlike `numpy.linalg.qr`, this function always returns a tuple of two tensors. + When :attr:`mode`\ `= 'r'`, the `Q` tensor is an empty tensor. + +.. warning:: The elements in the diagonal of `R` are not necessarily positive. + As such, the returned QR decomposition is only unique up to the sign of the diagonal of `R`. + Therefore, different platforms, like NumPy, or inputs on different devices, + may produce different valid decompositions. + +.. warning:: The QR decomposition is only well-defined if the first `k = min(m, n)` columns + of every matrix in :attr:`A` are linearly independent. + If this condition is not met, no error will be thrown, but the QR produced + may be incorrect and its autodiff may fail or produce incorrect results. + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + mode (str, optional): one of `'reduced'`, `'complete'`, `'r'`. + Controls the shape of the returned tensors. Default: `'reduced'`. + +Keyword args: + out (tuple, optional): output tuple of two tensors. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(Q, R)`. 
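+
+.. note::
+    One way to work around the sign ambiguity mentioned in the warning above is to normalize the
+    factors so that the diagonal of `R` is non-negative (a minimal sketch, assuming a full-rank
+    input)::
+
+        A = torch.randn(4, 4)
+        Q, R = torch.linalg.qr(A)
+        sign = torch.sign(torch.diagonal(R))
+        Q, R = Q * sign, R * sign.unsqueeze(-1)  # flip matching columns of Q and rows of R
+        torch.allclose(Q @ R, A)                 # True, and diag(R) is now non-negative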
+ +Examples:: + + >>> A = torch.tensor([[12., -51, 4], [6, 167, -68], [-4, 24, -41]]) + >>> Q, R = torch.linalg.qr(A) + >>> Q + tensor([[-0.8571, 0.3943, 0.3314], + [-0.4286, -0.9029, -0.0343], + [ 0.2857, -0.1714, 0.9429]]) + >>> R + tensor([[ -14.0000, -21.0000, 14.0000], + [ 0.0000, -175.0000, 70.0000], + [ 0.0000, 0.0000, -35.0000]]) + >>> (Q @ R).round() + tensor([[ 12., -51., 4.], + [ 6., 167., -68.], + [ -4., 24., -41.]]) + >>> (Q.T @ Q).round() + tensor([[ 1., 0., 0.], + [ 0., 1., -0.], + [ 0., -0., 1.]]) + >>> Q2, R2 = torch.linalg.qr(A, mode='r') + >>> Q2 + tensor([]) + >>> torch.equal(R, R2) + True + >>> A = torch.randn(3, 4, 5) + >>> Q, R = torch.linalg.qr(A, mode='complete') + >>> torch.dist(Q @ R, A) + tensor(1.6099e-06) + >>> torch.dist(Q.mT @ Q, torch.eye(4)) + tensor(6.2158e-07) +""") + +vander = _add_docstr(_linalg.linalg_vander, r""" +vander(x, N=None) -> Tensor + +Generates a Vandermonde matrix. + +Returns the Vandermonde matrix :math:`V` + +.. math:: + + V = \begin{pmatrix} + 1 & x_1 & x_1^2 & \dots & x_1^{N-1}\\ + 1 & x_2 & x_2^2 & \dots & x_2^{N-1}\\ + 1 & x_3 & x_3^2 & \dots & x_3^{N-1}\\ + \vdots & \vdots & \vdots & \ddots &\vdots \\ + 1 & x_n & x_n^2 & \dots & x_n^{N-1} + \end{pmatrix}. + +for `N > 1`. +If :attr:`N`\ `= None`, then `N = x.size(-1)` so that the output is a square matrix. + +Supports inputs of float, double, cfloat, cdouble, and integral dtypes. +Also supports batches of vectors, and if :attr:`x` is a batch of vectors then +the output has the same batch dimensions. + +Differences with `numpy.vander`: + +- Unlike `numpy.vander`, this function returns the powers of :attr:`x` in ascending order. + To get them in the reverse order call ``linalg.vander(x, N).flip(-1)``. + +Args: + x (Tensor): tensor of shape `(*, n)` where `*` is zero or more batch dimensions + consisting of vectors. + +Keyword args: + N (int, optional): Number of columns in the output. Default: `x.size(-1)` + +Example:: + + >>> x = torch.tensor([1, 2, 3, 5]) + >>> linalg.vander(x) + tensor([[ 1, 1, 1, 1], + [ 1, 2, 4, 8], + [ 1, 3, 9, 27], + [ 1, 5, 25, 125]]) + >>> linalg.vander(x, N=3) + tensor([[ 1, 1, 1], + [ 1, 2, 4], + [ 1, 3, 9], + [ 1, 5, 25]]) +""") + +vecdot = _add_docstr(_linalg.linalg_vecdot, r""" +linalg.vecdot(x, y, *, dim=-1, out=None) -> Tensor + +Computes the dot product of two batches of vectors along a dimension. + +In symbols, this function computes + +.. math:: + + \sum_{i=1}^n \overline{x_i}y_i. + +over the dimension :attr:`dim` where :math:`\overline{x_i}` denotes the conjugate for complex +vectors, and it is the identity for real vectors. + +Supports input of half, bfloat16, float, double, cfloat, cdouble and integral dtypes. +It also supports broadcasting. + +Args: + x (Tensor): first batch of vectors of shape `(*, n)`. + y (Tensor): second batch of vectors of shape `(*, n)`. + +Keyword args: + dim (int): Dimension along which to compute the dot product. Default: `-1`. + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. 
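+
+.. note::
+    For complex inputs the first argument is conjugated, matching the formula above
+    (a minimal sketch)::
+
+        x = torch.tensor([1 + 1j, 2 - 1j])
+        y = torch.tensor([3 + 0j, 1 + 2j])
+        torch.linalg.vecdot(x, y)  # equals (x.conj() * y).sum()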
+ +Examples:: + + >>> v1 = torch.randn(3, 2) + >>> v2 = torch.randn(3, 2) + >>> linalg.vecdot(v1, v2) + tensor([ 0.3223, 0.2815, -0.1944]) + >>> torch.vdot(v1[0], v2[0]) + tensor(0.3223) +""") diff --git a/lib/python3.10/site-packages/torch/masked/__init__.py b/lib/python3.10/site-packages/torch/masked/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d00ba1e8d5aff6a18490ac7b16a629ac36e3dcb5 --- /dev/null +++ b/lib/python3.10/site-packages/torch/masked/__init__.py @@ -0,0 +1,57 @@ +from torch.masked._ops import ( + _canonical_dim, + _combine_input_and_mask, + _generate_docstring, + _input_mask, + _output_mask, + _reduction_identity, + _where, + amax, + amin, + argmax, + argmin, + cumprod, + cumsum, + log_softmax, + logaddexp, + logsumexp, + mean, + median, + norm, + normalize, + prod, + softmax, + softmin, + std, + sum, + var, +) +from torch.masked.maskedtensor.core import is_masked_tensor, MaskedTensor +from torch.masked.maskedtensor.creation import as_masked_tensor, masked_tensor + + +__all__ = [ + "amax", + "amin", + "argmax", + "argmin", + "as_masked_tensor", + "cumprod", + "cumsum", + "is_masked_tensor", + "log_softmax", + "logaddexp", + "logsumexp", + "masked_tensor", + "MaskedTensor", + "mean", + "median", + "norm", + "normalize", + "prod", + "softmax", + "softmin", + "std", + "sum", + "var", +] diff --git a/lib/python3.10/site-packages/torch/masked/_docs.py b/lib/python3.10/site-packages/torch/masked/_docs.py new file mode 100644 index 0000000000000000000000000000000000000000..fa130bbefbc5caa7373459ef2fc3dc5292239948 --- /dev/null +++ b/lib/python3.10/site-packages/torch/masked/_docs.py @@ -0,0 +1,1177 @@ +# This file is generated, do not modify it! +# +# To update this file, run the update masked docs script as follows: +# +# python tools/update_masked_docs.py +# +# The script must be called from an environment where the development +# version of torch package can be imported and is functional. +# + +amax_docstring = """amax(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor + +Returns maximum of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. + +The identity value of maximum operation, which is used to start the +reduction, depends on input dtype. For instance, for float32, uint8, +and int32 dtypes, the identity values are ``-inf``, ``0``, and ``-2147483648``, respectively. + +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in maximum computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of maximum operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. 
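+
+A rough eager-mode equivalent of the behaviour described above, for rows that keep at least one
+unmasked element (fully masked-out rows remain undefined, as discussed)::
+
+    input = torch.tensor([[-3, -2, -1], [0, 1, 2]])
+    mask = torch.tensor([[True, False, True], [False, True, False]])
+    identity = torch.iinfo(input.dtype).min
+    input.masked_fill(~mask, identity).amax(dim=1)  # masked-out entries cannot win the maximum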
+ +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + Default: None that is equivalent to ``tuple(range(input.ndim))``. + +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]]) + >>> input + tensor([[-3, -2, -1], + [ 0, 1, 2]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.amax(input, 1, mask=mask) + tensor([ -1, -9223372036854775808]) +""" + +amin_docstring = """amin(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor + +Returns minimum of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. + +The identity value of minimum operation, which is used to start the +reduction, depends on input dtype. For instance, for float32, uint8, +and int32 dtypes, the identity values are ``inf``, ``255``, and ``2147483647``, respectively. + +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in minimum computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of minimum operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + Default: None that is equivalent to ``tuple(range(input.ndim))``. 
+ +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]]) + >>> input + tensor([[-3, -2, -1], + [ 0, 1, 2]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.amin(input, 1, mask=mask) + tensor([ -3, 9223372036854775807]) +""" + +argmax_docstring = """argmax(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor +Returns argmax of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. +The identity value of argmax operation, which is used to start the +reduction, depends on input dtype. For instance, for float32, uint8, +and int32 dtypes, the identity values are ``-inf``, ``0``, and ``-2147483648``, respectively. +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in argmax computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of argmax operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int): the dimension along which argmax is computed. + +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. 
+Example:: + + >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]]) + >>> input + tensor([[-3, -2, -1], + [ 0, 1, 2]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.argmax(input, 1, mask=mask) + tensor([2, 0]) +""" + +argmin_docstring = """argmin(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor +Returns argmin of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. +The identity value of argmin operation, which is used to start the +reduction, depends on input dtype. For instance, for float32, uint8, +and int32 dtypes, the identity values are ``inf``, ``255``, and ``2147483647``, respectively. +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in argmin computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of argmin operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int): the dimension along which argmin is computed. + +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. +Example:: + + >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]]) + >>> input + tensor([[-3, -2, -1], + [ 0, 1, 2]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.argmin(input, 1, mask=mask) + tensor([0, 0]) +""" + +cumprod_docstring = """cumprod(input, dim, *, dtype=None, mask=None) -> Tensor + +Returns cumulative_prod of all the slices in the :attr:`input` tensor +along :attr:`dim` while the :attr:`input` elements are masked out +according to the boolean tensor :attr:`mask`. + +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is +defined as ``prod(x[:i])``. 
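+
+For the unmasked positions this behaves like an ordinary cumulative product in which every
+masked-out element is replaced by the multiplicative identity ``1`` (a rough sketch; masked-out
+positions of the output remain undefined)::
+
+    input = torch.tensor([[-3., -2., -1.], [0., 1., 2.]])
+    mask = torch.tensor([[True, False, True], [False, False, False]])
+    torch.cumprod(input.masked_fill(~mask, 1.0), dim=1)
+    # tensor([[-3., -3.,  3.],
+    #         [ 1.,  1.,  1.]])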
+ +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True then +the corresponding element in :attr:`input` tensor will be included in +cumulative_prod computation, otherwise the element is ignored. + +The values of masked-out elements of the output tensor have undefined +value: it may or may not be set to zero or nan; the choice may correspond to +the value that leads to the most efficient storage of :attr:`output` +tensor. + +The mask of the cumulative_prod output tensor can be computed as +``torch.broadcast_to(mask, input.shape)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int): the dimension along which cumulative_prod is computed. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3., -2., -1.], [ 0., 1., 2.]]) + >>> input + tensor([[-3., -2., -1.], + [ 0., 1., 2.]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.cumprod(input, 1, mask=mask) + tensor([[-3., -3., 3.], + [ 1., 1., 1.]]) +""" + +cumsum_docstring = """cumsum(input, dim, *, dtype=None, mask=None) -> Tensor + +Returns cumulative_sum of all the slices in the :attr:`input` tensor +along :attr:`dim` while the :attr:`input` elements are masked out +according to the boolean tensor :attr:`mask`. + +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is +defined as ``sum(x[:i])``. + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True then +the corresponding element in :attr:`input` tensor will be included in +cumulative_sum computation, otherwise the element is ignored. + +The values of masked-out elements of the output tensor have undefined +value: it may or may not be set to zero or nan; the choice may correspond to +the value that leads to the most efficient storage of :attr:`output` +tensor. + +The mask of the cumulative_sum output tensor can be computed as +``torch.broadcast_to(mask, input.shape)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int): the dimension along which cumulative_sum is computed. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. 
+ Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3., -2., -1.], [ 0., 1., 2.]]) + >>> input + tensor([[-3., -2., -1.], + [ 0., 1., 2.]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.cumsum(input, 1, mask=mask) + tensor([[-3., -3., -4.], + [ 0., 0., 0.]]) +""" + +log_softmax_docstring = """log_softmax(input, dim, *, dtype=None, mask=None) -> Tensor + +Returns log_softmax of all the slices in the :attr:`input` tensor +along :attr:`dim` while the :attr:`input` elements are masked out +according to the boolean tensor :attr:`mask`. + +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. LogSoftmax of i-th element in ``x`` is +defined as ``log(exp(x[i])/sum(exp(x)))``. + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True then +the corresponding element in :attr:`input` tensor will be included in +log_softmax computation, otherwise the element is ignored. + +The values of masked-out elements of the output tensor have undefined +value: it may or may not be set to zero or nan; the choice may correspond to +the value that leads to the most efficient storage of :attr:`output` +tensor. + +The mask of the log_softmax output tensor can be computed as +``torch.broadcast_to(mask, input.shape)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int): the dimension along which log_softmax is computed. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3., -2., -1.], [ 0., 1., 2.]]) + >>> input + tensor([[-3., -2., -1.], + [ 0., 1., 2.]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.log_softmax(input, 1, mask=mask) + tensor([[-2.1269, -inf, -0.1269], + [ nan, nan, nan]]) +""" + +logsumexp_docstring = """logsumexp(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor + +Returns logsumexp of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. + +The identity value of logsumexp operation, which is used to start the reduction, is ``-2147483648``. + +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). 
+ +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in logsumexp computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of logsumexp operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + Default: None that is equivalent to ``tuple(range(input.ndim))``. + +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]]) + >>> input + tensor([[-3, -2, -1], + [ 0, 1, 2]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.logsumexp(input, 1, mask=mask) + tensor([ 0, -9223372036854775808]) +""" + +mean_docstring = """mean(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor + +Returns mean of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. + +By definition, the identity value of a mean operation is the mean +value of the tensor. If all elements of the input tensor along given +dimension(s) :attr:`dim` are masked-out, the identity value of the +mean is undefined. Due to this ambiguity, the elements of output +tensor with strided layout, that correspond to fully masked-out +elements, have ``nan`` values. + +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in mean computation, otherwise the element is +ignored. 
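+
+Intuitively, the masked mean of a slice is the plain mean of its masked-in
+elements. On the example data used throughout these docstrings, the documented
+result can be reproduced with ordinary tensor operations (a rough sketch, not
+the actual implementation)::
+
+    >>> input = torch.tensor([[-3., -2., -1.], [0., 1., 2.]])
+    >>> mask = torch.tensor([[True, False, True], [False, False, False]])
+    >>> torch.where(mask, input, torch.zeros_like(input)).sum(1) / mask.sum(1)
+    tensor([-2., nan])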
+ +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of mean operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + Default: None that is equivalent to ``tuple(range(input.ndim))``. + +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]]) + >>> input + tensor([[-3, -2, -1], + [ 0, 1, 2]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.mean(input, 1, mask=mask) + tensor([-2., nan]) +""" + +median_docstring = """median(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor +Returns median of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. +By definition, the identity value of a median operation is the median +value of the tensor. If all elements of the input tensor along given +dimension(s) :attr:`dim` are masked-out, the identity value of the +median is undefined. Due to this ambiguity, the elements of output +tensor with strided layout, that correspond to fully masked-out +elements, have ``nan`` values. +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in median computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of median operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. 
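+
+For the example mask used in these docstrings, this output mask marks only the
+first row as valid, since the second row has no masked-in elements (an
+illustrative sketch)::
+
+    >>> mask = torch.tensor([[True, False, True], [False, False, False]])
+    >>> torch.any(torch.broadcast_to(mask, (2, 3)), 1)
+    tensor([ True, False])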
+ +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int): the dimension along which median is computed. + +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. +Example:: + + >>> input = tensor([[-3., -2., -1.], [ 0., 1., 2.]]) + >>> input + tensor([[-3., -2., -1.], + [ 0., 1., 2.]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.median(input, 1, mask=mask) + tensor([-3., nan]) +""" + +norm_docstring = """norm(input, ord, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor + +Returns norm of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. + +The identity value of norm operation, which is used to start the +reduction, is ``0.0``, except for ``ord=-inf`` it is +``inf``. + +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in norm computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of norm operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + ord (int, float, optional): the order of vector norm. Default: 2. + See :func:`torch.linalg.vector_norm` for a list of supported norms. + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + Default: None that is equivalent to ``tuple(range(input.ndim))``. + +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. 
If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3., -2., -1.], [ 0., 1., 2.]]) + >>> input + tensor([[-3., -2., -1.], + [ 0., 1., 2.]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.norm(input, 2.0, 1, mask=mask) + tensor([3.1623, 0.0000]) +""" + +normalize_docstring = """normalize(input, ord, dim, *, eps=1e-12, dtype=None, mask=None) -> Tensor + +Returns normalize of all the slices in the :attr:`input` tensor +along :attr:`dim` while the :attr:`input` elements are masked out +according to the boolean tensor :attr:`mask`. + +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Normalize of i-th element in ``x`` is +defined as ``x[i]/max(norm(x, p), eps)``. + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True then +the corresponding element in :attr:`input` tensor will be included in +normalize computation, otherwise the element is ignored. + +The values of masked-out elements of the output tensor have undefined +value: it may or may not be set to zero or nan; the choice may correspond to +the value that leads to the most efficient storage of :attr:`output` +tensor. + +The mask of the normalize output tensor can be computed as +``torch.broadcast_to(mask, input.shape)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + ord (int, float): the order of vector norm. Default: 2. + See :func:`torch.linalg.vector_norm` for a list of supported norms. + dim (int): the dimension along which normalize is computed. + +Keyword args: + eps (float, optional): small value to avoid division by zero. Default: 1e-12. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3., -2., -1.], [ 0., 1., 2.]]) + >>> input + tensor([[-3., -2., -1.], + [ 0., 1., 2.]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.normalize(input, 2.0, 1, mask=mask) + tensor([[-0.9487, 0.0000, -0.3162], + [ 0.0000, 0.0000, 0.0000]]) +""" + +prod_docstring = """prod(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor + +Returns product of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. + +The identity value of product operation, which is used to start the reduction, is ``1``. 
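+
+Equivalently, replacing every masked-out element with this identity value and
+then taking an ordinary product along :attr:`dim` yields the documented result
+(an illustrative sketch, not the actual implementation)::
+
+    >>> input = torch.tensor([[-3, -2, -1], [0, 1, 2]])
+    >>> mask = torch.tensor([[True, False, True], [False, False, False]])
+    >>> torch.where(mask, input, torch.ones_like(input)).prod(1)
+    tensor([3, 1])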
+ +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in product computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of product operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + Default: None that is equivalent to ``tuple(range(input.ndim))``. + +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]]) + >>> input + tensor([[-3, -2, -1], + [ 0, 1, 2]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.prod(input, 1, mask=mask) + tensor([3, 1]) +""" + +softmax_docstring = """softmax(input, dim, *, dtype=None, mask=None) -> Tensor + +Returns softmax of all the slices in the :attr:`input` tensor +along :attr:`dim` while the :attr:`input` elements are masked out +according to the boolean tensor :attr:`mask`. + +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Softmax of i-th element in ``x`` is +defined as ``exp(x[i])/sum(exp(x))``. + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True then +the corresponding element in :attr:`input` tensor will be included in +softmax computation, otherwise the element is ignored. + +The values of masked-out elements of the output tensor have undefined +value: it may or may not be set to zero or nan; the choice may correspond to +the value that leads to the most efficient storage of :attr:`output` +tensor. + +The mask of the softmax output tensor can be computed as +``torch.broadcast_to(mask, input.shape)``. 
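+
+For floating-point input, the masking can be pictured as filling the masked-out
+positions with ``-inf`` before a regular softmax, so that they receive zero
+weight and fully masked-out slices produce ``nan`` (an illustrative sketch, not
+necessarily how the operation is implemented)::
+
+    >>> input = torch.tensor([[-3., -2., -1.], [0., 1., 2.]])
+    >>> mask = torch.tensor([[True, False, True], [False, False, False]])
+    >>> torch.softmax(torch.where(mask, input, torch.full_like(input, float("-inf"))), 1)
+    tensor([[0.1192, 0.0000, 0.8808],
+            [   nan,    nan,    nan]])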
+ +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int): the dimension along which softmax is computed. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3., -2., -1.], [ 0., 1., 2.]]) + >>> input + tensor([[-3., -2., -1.], + [ 0., 1., 2.]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.softmax(input, 1, mask=mask) + tensor([[0.1192, 0.0000, 0.8808], + [ nan, nan, nan]]) +""" + +softmin_docstring = """softmin(input, dim, *, dtype=None, mask=None) -> Tensor + +Returns softmin of all the slices in the :attr:`input` tensor +along :attr:`dim` while the :attr:`input` elements are masked out +according to the boolean tensor :attr:`mask`. + +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Softmin of i-th element in ``x`` is +defined as ``exp(-x[i])/sum(exp(-x))``. + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True then +the corresponding element in :attr:`input` tensor will be included in +softmin computation, otherwise the element is ignored. + +The values of masked-out elements of the output tensor have undefined +value: it may or may not be set to zero or nan; the choice may correspond to +the value that leads to the most efficient storage of :attr:`output` +tensor. + +The mask of the softmin output tensor can be computed as +``torch.broadcast_to(mask, input.shape)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int): the dimension along which softmin is computed. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. 
+ +Example:: + + >>> input = tensor([[-3., -2., -1.], [ 0., 1., 2.]]) + >>> input + tensor([[-3., -2., -1.], + [ 0., 1., 2.]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.softmin(input, 1, mask=mask) + tensor([[0.8808, 0.0000, 0.1192], + [ nan, nan, nan]]) +""" + +std_docstring = """std(input, dim, unbiased, *, keepdim=False, dtype=None, mask=None) -> Tensor +Returns standard_deviation of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. +The identity value of sample standard deviation operation is undefined. The +elements of output tensor with strided layout, that correspond to +fully masked-out elements, have ``nan`` values. +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in standard_deviation computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of standard_deviation operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + Default: None that is equivalent to ``tuple(range(input.ndim))``. + unbiased (bool): when True, use Bessel's correction, otherwise, compute + the uncorrected sample variance. + +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. 
+Example:: + + >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]]) + >>> input + tensor([[-3, -2, -1], + [ 0, 1, 2]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.std(input, 1, False, mask=mask) + tensor([1., nan]) +""" + +sum_docstring = """sum(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor + +Returns sum of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. + +The identity value of sum operation, which is used to start the reduction, is ``0``. + +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in sum computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of sum operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + Default: None that is equivalent to ``tuple(range(input.ndim))``. + +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + +Example:: + + >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]]) + >>> input + tensor([[-3, -2, -1], + [ 0, 1, 2]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.sum(input, 1, mask=mask) + tensor([-4, 0]) +""" + +var_docstring = """var(input, dim, unbiased, *, keepdim=False, dtype=None, mask=None) -> Tensor +Returns variance of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`. +The identity value of sample variance operation is undefined. The +elements of output tensor with strided layout, that correspond to +fully masked-out elements, have ``nan`` values. 
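+
+For a slice that does have masked-in elements, the result is simply the
+ordinary sample variance of those elements (uncorrected when ``unbiased`` is
+False), which is where the ``1.`` in the documented example comes from; a fully
+masked-out slice has no such elements, hence ``nan`` (an illustrative sketch)::
+
+    >>> vals = torch.tensor([-3., -1.])  # masked-in elements of the first example row
+    >>> vals.var(unbiased=False)
+    tensor(1.)
+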
+If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in variance computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of variance operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + Default: None that is equivalent to ``tuple(range(input.ndim))``. + unbiased (bool): when True, use Bessel's correction, otherwise, compute + the uncorrected sample variance. + +Keyword args: + keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: False. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: None. + mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. 
+Example:: + + >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]]) + >>> input + tensor([[-3, -2, -1], + [ 0, 1, 2]]) + >>> mask = tensor([[ True, False, True], [False, False, False]]) + >>> mask + tensor([[ True, False, True], + [False, False, False]]) + >>> torch.masked._ops.var(input, 1, False, mask=mask) + tensor([1., nan]) +""" diff --git a/lib/python3.10/site-packages/torch/masked/_ops.py b/lib/python3.10/site-packages/torch/masked/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..3d8151197a207e550647a4bc19a1ed615eff5cdd --- /dev/null +++ b/lib/python3.10/site-packages/torch/masked/_ops.py @@ -0,0 +1,1796 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import warnings +from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union + +import torch +from torch import sym_float, Tensor +from torch._prims_common import corresponding_real_dtype +from torch.masked import _docs +from torch.masked.maskedtensor.core import is_masked_tensor, MaskedTensor +from torch.masked.maskedtensor.creation import as_masked_tensor + + +if TYPE_CHECKING: + from torch.types import _dtype as DType + + DimOrDims = Optional[Union[int, Tuple[int], List[int]]] +else: + # The JIT doesn't understand Union, nor torch.dtype here + DType = int + DimOrDims = Optional[Tuple[int]] + + +__all__: List[str] = [] + +# All masked reduction/normalization operations have the same +# signatures. Here we introduce docstring templates that are applied +# to docstrings of reduction/normalization functions via +# _apply_docstring_templates decorator. + + +def _apply_docstring_templates(func): + """Decorator that applies docstring templates to function docstring + and returns the function instance. + """ + + doc_string = getattr(_docs, f"{func.__name__}_docstring", None) + if doc_string is None: + warnings.warn( + f"No documentation string available for {func.__name__}." + " PyTorch team should run `python tools/update_masked_docs.py`" + " to generate the missing docstrings." + ) + else: + func.__doc__ = doc_string + + # Expose function as public symbol + __all__.append(func.__name__) + + return func + + +def _generate_docstring(func): + """A utility function called from tools/update_masked_docs.py + script to update the module torch.masked._docs.py + """ + docstring_templates = dict( + reduction_signature="""\ +{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""", + reduction_descr="""\ +Returns {operation name} of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`.""", + reduction_args="""\ +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in {operation name} computation, otherwise the element is +ignored. 
+ +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of {operation name} operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + {args_declarations} + +Keyword args: + {kwargs_declarations}""", + reduction_example="""\ +Example:: + + >>> input = {example_input} + >>> input + {indent_example_input} + >>> mask = {example_mask} + >>> mask + {indent_example_mask} + >>> {full_function_name}(input, {example_args}, mask=mask) + {indent_example_output} +""", + reduction_identity="""\ +The identity value of {operation name} operation, which is used to start the reduction, is ``{identity_int32}``.""", + reduction_identity_dtype="""\ +The identity value of {operation name} operation, which is used to start the +reduction, depends on input dtype. For instance, for float32, uint8, +and int32 dtypes, the identity values are ``{identity_float32}``, ``{identity_uint8}``, and ``{identity_int32}``, respectively.""", + normalization_signature="""\ +{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""", + normalization_descr="""\ +Returns {operation name} of all the slices in the :attr:`input` tensor +along :attr:`dim` while the :attr:`input` elements are masked out +according to the boolean tensor :attr:`mask`. + +{definition}""", + normalization_args="""\ +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True then +the corresponding element in :attr:`input` tensor will be included in +{operation name} computation, otherwise the element is ignored. + +The values of masked-out elements of the output tensor have undefined +value: it may or may not be set to zero or nan; the choice may correspond to +the value that leads to the most efficient storage of :attr:`output` +tensor. + +The mask of the {operation name} output tensor can be computed as +``torch.broadcast_to(mask, input.shape)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + {args_declarations} + +Keyword args: + {kwargs_declarations}""", + normalization_example="""\ +Example:: + + >>> input = {example_input} + >>> input + {indent_example_input} + >>> mask = {example_mask} + >>> mask + {indent_example_mask} + >>> {full_function_name}(input, {example_args}, mask=mask) + {indent_example_output} +""", + ) + + args_and_kwargs = dict( + # argument name sufficies separated by double underscore will + # be removed in the final documentation string. 
+ sum=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + prod=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + cumsum=(("dim__as_int",), ("dtype=None", "mask=None")), + cumprod=(("dim__as_int",), ("dtype=None", "mask=None")), + amin=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + amax=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + argmin=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")), + argmax=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")), + mean=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + median=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")), + norm=( + ( + "ord", + "dim", + ), + ("keepdim=False", "dtype=None", "mask=None"), + ), + var=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")), + std=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")), + logsumexp=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + softmax=(("dim__as_int",), ("dtype=None", "mask=None")), + log_softmax=(("dim__as_int",), ("dtype=None", "mask=None")), + softmin=(("dim__as_int",), ("dtype=None", "mask=None")), + normalize=( + ( + "ord__required", + "dim__as_int", + ), + ("eps=1e-12", "dtype=None", "mask=None"), + ), + ) + + argument_declarations = dict( + dim="""\ +dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + Default: None that is equivalent to ``tuple(range(input.ndim))``.""", + dim__as_int="""\ +dim (int): the dimension along which {operation name} is computed.""", + ord="""\ +ord (int, float, optional): the order of vector norm. Default: 2. + See :func:`torch.linalg.vector_norm` for a list of supported norms.""", + ord__required="""\ +ord (int, float): the order of vector norm. Default: 2. + See :func:`torch.linalg.vector_norm` for a list of supported norms.""", + unbiased="""\ +unbiased (bool): when True, use Bessel's correction, otherwise, compute + the uncorrected sample variance.""", + eps="""\ +eps (float, optional): small value to avoid division by zero. Default: {default}.""", + keepdim="""\ +keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: {default}.""", + dtype="""\ +dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: {default}.""", + mask="""\ +mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.""", + ) + + definitions = dict( + softmax="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Softmax of i-th element in ``x`` is +defined as ``exp(x[i])/sum(exp(x))``.""", + log_softmax="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. LogSoftmax of i-th element in ``x`` is +defined as ``log(exp(x[i])/sum(exp(x)))``.""", + softmin="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Softmin of i-th element in ``x`` is +defined as ``exp(-x[i])/sum(exp(-x))``.""", + normalize="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. 
Normalize of i-th element in ``x`` is +defined as ``x[i]/max(norm(x, p), eps)``.""", + cumsum="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is +defined as ``sum(x[:i])``.""", + cumprod="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is +defined as ``prod(x[:i])``.""", + ) + + reduction_names = dict( + sum="sum", + prod="product", + amax="maximum", + amin="minimum", + argmax="argmax", + argmin="argmin", + mean="mean", + median="median", + norm="norm", + var="variance", + std="standard_deviation", + logsumexp="logsumexp", + ) + + normalization_names = dict( + softmax="softmax", + log_softmax="log_softmax", + softmin="softmin", + normalize="normalize", + cumsum="cumulative_sum", + cumprod="cumulative_prod", + ) + + operation_names = {} + operation_names.update(reduction_names) + operation_names.update(normalization_names) + + # Default example data: + example_dim = 1 + example_input = torch.tensor([[-3, -2, -1], [0, 1, 2]]) + example_mask = torch.tensor([[True, False, True], [False, False, False]]) + example_args: Tuple[Any, ...] + if func.__name__ in {"norm", "normalize"}: + example_args = (2.0, example_dim) + example_input = example_input.to(dtype=torch.float32) + elif func.__name__ in {"var", "std"}: + example_args = (example_dim, False) + elif func.__name__ == "median": + example_args = (example_dim,) + example_input = example_input.to(dtype=torch.float32) + else: + example_args = (example_dim,) + + operation_args: Tuple[str, ...] + operation_kwargs: Tuple[str, ...] + operation_args, operation_kwargs = args_and_kwargs[func.__name__] + arg_declarations = [ + "\n ".join( + argument_declarations.get(a, f'{a.split("__", 1)[0]}: TBD.').splitlines() + ) + for a in operation_args + ] + kwarg_declarations = [ + "\n ".join( + argument_declarations.get( + a.split("=", 1)[0], f'{a.split("__", 1)[0]}: TBD.' + ) + .format(default=a.split("=", 1)[1]) + .splitlines() + ) + for a in operation_kwargs + ] + + if func.__name__ in reduction_names: + op_kind = "reduction" + doc_sections = ["signature", "descr", "identity", "args", "example"] + elif func.__name__ in normalization_names: + op_kind = "normalization" + doc_sections = ["signature", "descr", "args", "example"] + example_input = example_input.to(dtype=torch.float32) + else: + assert 0 # add function name to operation names dictionaries + example_output = func(example_input, *example_args, mask=example_mask) + + template_data = { + "function_name": func.__name__, + "full_function_name": func.__module__ + "." 
+ func.__name__, + "operation name": operation_names[func.__name__], + "operation_args": ", ".join(a.split("__", 1)[0] for a in operation_args), + "operation_kwargs": ", ".join(a.split("__", 1)[0] for a in operation_kwargs), + # one-line representation of a tensor: + "example_input": " ".join(str(example_input).split()), + "example_args": ", ".join(map(str, example_args)), + "example_mask": " ".join(str(example_mask).split()), + # multi-line representation of a tensor with indent + "indent_example_input": ("\n ").join(str(example_input).splitlines()), + "indent_example_mask": ("\n ").join(str(example_mask).splitlines()), + "indent_example_output": ("\n ").join(str(example_output).splitlines()), + } + + if func.__name__ in reduction_names: + template_data.update( + identity_uint8=_reduction_identity( + func.__name__, torch.tensor(0, dtype=torch.uint8) + ), + identity_int32=_reduction_identity( + func.__name__, torch.tensor(0, dtype=torch.int32) + ), + identity_float32=_reduction_identity( + func.__name__, torch.tensor(0, dtype=torch.float32) + ), + ) + if func.__name__ == "norm": + template_data.update( + identity_ord_ninf=_reduction_identity( + func.__name__, torch.tensor(0, dtype=torch.float32), float("-inf") + ) + ) + elif func.__name__ in normalization_names: + template_data.update(definition=definitions[func.__name__]) + else: + assert 0 # add function name to operation names dictionaries + template_data.update( + args_declarations=("\n ".join(arg_declarations)).format_map(template_data) + ) + template_data.update( + kwargs_declarations=("\n ".join(kwarg_declarations)).format_map( + template_data + ) + ) + + # Apply function name info to docstring templates: + templates = { + k: v.format_map(template_data) + for k, v in docstring_templates.items() + if k.startswith(op_kind) + } + templates.update( + (k, v.format_map(template_data) if isinstance(v, str) else v) + for k, v in template_data.items() + ) + + # Apply docstring templates to function doctring: + if func.__doc__ is None: + doc_template = "\n\n".join([f"{{{op_kind}_{sec}}}" for sec in doc_sections]) + else: + doc_template = func.__doc__ + return doc_template.format_map(templates) + + +def _reduction_identity(op_name: str, input: Tensor, *args): + """Return identity value as scalar tensor of a reduction operation on + given input, or None, if the identity value cannot be uniquely + defined for the given input. + + The identity value of the operation is defined as the initial + value to reduction operation that has a property ``op(op_identity, + value) == value`` for any value in the domain of the operation. + Or put it another way, including or excluding the identity value in + a list of operands will not change the reduction result. + + See https://github.com/pytorch/rfcs/pull/27 for more information. 
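+
+    For instance, following the rules implemented below, a float32 input yields::
+
+        _reduction_identity("sum", torch.tensor(0.0))   # -> tensor(0.)
+        _reduction_identity("prod", torch.tensor(0.0))  # -> tensor(1.)
+        _reduction_identity("amax", torch.tensor(0.0))  # -> tensor(-inf)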
+ + """ + dtype: DType = input.dtype + device = input.device + op_name = op_name.rsplit(".", 1)[-1] # lstrip module name when present + if op_name in {"sum", "cumsum"}: + return torch.tensor(0, dtype=dtype, device=device) + elif op_name in {"prod", "cumprod"}: + return torch.tensor(1, dtype=dtype, device=device) + elif op_name in {"amax", "argmax", "logaddexp"}: + if torch.is_floating_point(input): + return torch.tensor(-torch.inf, dtype=dtype, device=device) + elif torch.is_signed(input) or dtype == torch.uint8: + return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device) + elif op_name in {"logsumexp"}: + if torch.is_floating_point(input): + return torch.tensor(-torch.inf, dtype=dtype, device=device) + elif torch.is_complex(input): + return torch.tensor(-torch.inf + 0j, dtype=dtype, device=device) + elif torch.is_signed(input) or dtype == torch.uint8: + return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device) + elif op_name in {"amin", "argmin"}: + if torch.is_floating_point(input): + return torch.tensor(torch.inf, dtype=dtype, device=device) + elif torch.is_signed(input) or dtype == torch.uint8: + return torch.tensor(torch.iinfo(dtype).max, dtype=dtype, device=device) + elif op_name == "mean": + # Strictly speaking, the identity value of the mean operation + # is the mean of the input. Since the mean value depends on + # the dim argument and it may be a non-scalar tensor, we + # consider the identity value of the mean operation ambiguous. + # Moreover, the mean value of empty input is undefined. + return None + elif op_name == "norm": + ord = args[0] if args else 2 + if ord == float("-inf"): + assert torch.is_floating_point(input), input.dtype + return torch.tensor(torch.inf, dtype=dtype, device=device) + return torch.tensor(0, dtype=dtype, device=device) + elif op_name == "median": + # We use NaN for now because the implementation is currently using torch.nanmedian + # and NaN is the identity for that function since it gets ignored + dtype = input.dtype if torch.is_floating_point(input) else torch.float + return torch.tensor(torch.nan, dtype=dtype, device=device) + elif op_name in {"var", "std"}: + return None + raise NotImplementedError(f"identity of {op_name} on {dtype} input") + + +def _canonical_dim(dim: DimOrDims, ndim: int) -> Tuple[int, ...]: + """Return dim argument as a tuple of sorted dim values.""" + dims: List[int] = [] + if dim == (): + # Currently, `dim=()` in reductions operations means "reduce + # over all dimensions" while in future, it will read "no + # reduce". See https://github.com/pytorch/pytorch/issues/29137 + # When gh-29137 is resolved, this if-block must be deleted. + dim = None + if dim is None: + return tuple(range(ndim)) + ndim = max(ndim, 1) + dim_ = (dim,) if isinstance(dim, (int, torch.SymInt)) else dim + for d in dim_: + if d in dims: + raise RuntimeError(f"dim={d} appears multiple times in the list of dims") + if d >= ndim or d < -ndim: + raise IndexError( + f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {d})" + ) + dims.append(d % ndim) + return tuple(sorted(dims)) + + +def _sparse_coo_flatten_indices(indices: Tensor, shape: tuple): + # Flatted N-D indices to 1-D indices + flat_indices = indices.new_zeros(indices.size(1)) + for d, sz in enumerate(shape): + flat_indices.mul_(sz) + flat_indices.add_(indices[d]) + return flat_indices + + +def _any(input: Tensor, dim: tuple, keepdim: bool): + # Support torch.any with tuple dim argument. 
+ # Workaround of https://github.com/pytorch/pytorch/issues/56586 + r = input + for d in reversed(dim): + r = r.any(dim=d, keepdim=keepdim) + return r + + +def _sparse_coo_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor: + """Sparse variant of torch.where. Supports sparse COO and hybrid sparse COO tensors. + + _sparse_coo_where implements the following invariant: + + _sparse_coo_where(mask, input, fill_value).to_dense(fill_value) == + torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value)) + + where `a == b` means `assertEqual(a, b)`, mask is boolean sparse + tensor, and `to_dense(fill_value)` is like `to_dense()` except + that the unspecified elements are mapped to `fill_value` rather + than to `0`. + + Returns a sparse COO tensor with the following features: + + - all specified elements correspond to masked-in elements that + have the values of the input tensor. If there exists a masked-in + element (as specified by mask) that is not specified in the + input, in the result tensor, the corresponding element has value + 0. In the dense part of the sparse tensor, the masked-out + elements are replaced with fill_value. + + - all unspecified elements correspond to masked-out elements. + """ + + assert input.layout == torch.sparse_coo + assert mask.layout == input.layout + assert mask.shape == input.shape + assert mask.dense_dim() == input.dense_dim() # TODO: eliminate this restriction + + input = input.coalesce() + + # For set operations on sparse tensor indices, we'll convert + # multi-dimensional indices to 1-D indices for efficiency. + input_flat_indices = _sparse_coo_flatten_indices( + input.indices(), input.shape[: input.sparse_dim()] + ) + mask_flat_indices = _sparse_coo_flatten_indices( + mask.indices(), mask.shape[: mask.sparse_dim()] + ) + + # the set of mask flat indices that define masked-in elements: + if mask.dense_dim() > 0: + mask_values = _any( + mask.values(), tuple(range(1, input.sparse_dim() + 1)), False + ) + else: + mask_values = mask.values() + maskin_flat_indices = mask_flat_indices[mask_values.nonzero()[:, 0]] + + def intersection(i1, i2): + union, counts = torch.cat([i1, i2]).unique(return_counts=True) + return union, torch.where(counts.gt(1)) + + def minus(i1, i2): + union, counts = torch.cat([i1, i2]).unique(return_counts=True) + return intersection(union[torch.where(counts.eq(1))], i1) + + def _apply(a): + obj, w = a + return obj[w] + + # the set of input flat indices of specified and masked-in elements: + maskin_input_flat_indices = _apply( + intersection(maskin_flat_indices, input_flat_indices) + ) + _, w = intersection(input_flat_indices, maskin_input_flat_indices) + + # the indices and values of masked-in elements + where_input_indices = input.indices()[(slice(None),) + w] + where_input_values = input.values()[w] + + if mask.dense_dim() > 0: + # apply mask to the dense part of the input values: + _, w1 = intersection(mask_flat_indices, maskin_input_flat_indices) + where_mask_values = mask.values()[w1] + where_input_values = torch.where( + where_mask_values, where_input_values, fill_value + ) + + # the set of flat indices of unspecified input and masked-in elements: + maskin_zero_flat_indices = _apply( + minus(maskin_flat_indices, maskin_input_flat_indices) + ) + + # the indices of masked-in zero elements + _, w = intersection(mask_flat_indices, maskin_zero_flat_indices) + where_zero_indices = mask.indices()[(slice(None),) + w] + + # construct result + n = where_zero_indices.size(1) + if n == 0: + # the input is 
coalesced, hence input_flat_indices are ordered + # and the result is guaranteed to be coalesced: + result = torch.sparse_coo_tensor( + where_input_indices, where_input_values, input.shape + ) + return result._coalesced_(True) + + where_indices = torch.cat([where_input_indices, where_zero_indices], dim=1) + where_values = torch.cat( + [ + where_input_values, + where_input_values.new_zeros((n,) + where_input_values.shape[1:]), + ] + ) + result = torch.sparse_coo_tensor(where_indices, where_values, input.shape) + + # appending zero elements leads to uncoalesced sparse tensor + return result.coalesce() + + +def _sparse_coo_scatter_reduction_helper( + op, + mask_input: Tensor, + dims: Tuple[int, ...], + keepdim: bool, + dtype: Optional[DType] = None, +) -> Tensor: + reduce = op.__name__ + valid_reductions = ["sum", "prod", "amax", "amin"] + if reduce not in valid_reductions: + raise ValueError( + f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead" + ) + + output_dtype = dtype + values, indices = mask_input._values(), mask_input._indices() + input_dims = mask_input.dim() + num_sparse_dims = mask_input.sparse_dim() + reduced_sparse_dims = [] + retained_sparse_dims = [] + reduced_dense_dims = [] + + # promote dtype if specified + if values.dtype != output_dtype: + values = values.to(output_dtype) + + if keepdim: + output_shape = tuple( + 1 if i in dims else si for (i, si) in enumerate(mask_input.shape) + ) + else: + output_shape = tuple( + si for (i, si) in enumerate(mask_input.shape) if i not in dims + ) + + for d in dims: + if d >= input_dims: + continue + + if d < num_sparse_dims: + reduced_sparse_dims.append(d) + else: + reduced_dense_dims.append(d + 1 - num_sparse_dims) + + # Reduce dense dimensions + if len(reduced_dense_dims) > 0: + if reduce == "sum": + new_values = values + new_values = op(new_values, dim=reduced_dense_dims, keepdim=bool(keepdim)) + else: + # FIXME: Implement reductions for dense dimensions for ops with non-zero reduction identities + return NotImplemented + else: + new_values = values.clone() + + # Reduce sparse dimensions + if len(reduced_sparse_dims) == num_sparse_dims: + if reduce in {"amax", "amin"} and new_values.size(0) == 0: + # IndexError: amax(): Expected reduction dim 0 to have non-zero size. 
+ # sum()/prod() return the reduction identity when dim has size 0 but amax()/amin() do not + # See https://github.com/pytorch/pytorch/issues/61901 + new_values = _reduction_identity(reduce, new_values) + else: + new_values = op(new_values, dim=0) + if keepdim: + for _ in range(num_sparse_dims): + new_values = new_values.unsqueeze(0) + return new_values.to(dtype=output_dtype).to_sparse() + else: + new_indices = indices.clone() + if keepdim: + # zero out reduced sparse dimensions if keepdim = True + # ensures that the call to torch.unique folds duplicated indices together while preserving the dimension + new_indices[reduced_sparse_dims, :] = 0 + else: + # remove reduced sparse dimensions if keepdim = False + if len(reduced_sparse_dims) > 0: + retained_sparse_dims = [ + i + for i in range(num_sparse_dims) + if i not in set(reduced_sparse_dims) + ] + new_indices = new_indices.index_select( + 0, torch.tensor(retained_sparse_dims).to(mask_input.device) + ) + + # Use scatter_reduce to reduce items in the new_values tensor that correspond to the same indices in new_indices + if new_indices.numel() > 0: + # lexsort indices and get index tensor for scatter reduction + new_indices, inverse_indices = torch.unique( + new_indices, return_inverse=True, dim=1 + ) + out_shape = list(new_values.shape) + out_shape[0] = new_indices.shape[1] + for _ in range(new_values.ndim - 1): + inverse_indices = inverse_indices.unsqueeze(-1) + scatter_indices = inverse_indices.expand(new_values.shape) + # FIXME: temporary workaround for issue with bfloat16/float16 remove when acctype is implemented for scatter_reduce + if output_dtype in {torch.bfloat16, torch.float16}: + new_values = new_values.to(torch.float) + out = new_values.new_empty(out_shape) + new_values = out.scatter_reduce_( + 0, scatter_indices, new_values, reduce=reduce, include_self=False + ) + new_values = new_values.to(dtype=output_dtype) + else: + out = new_values.new_empty(out_shape) + new_values = out.scatter_reduce_( + 0, scatter_indices, new_values, reduce=reduce, include_self=False + ) + + return torch.sparse_coo_tensor( + new_indices, + new_values, + output_shape, + dtype=output_dtype, + device=mask_input.device, + ) + + +def _sparse_csr_segment_reduction_helper( + op, + mask_input: Tensor, + dims: Tuple[int, ...], + keepdim: bool, + dtype: Optional[DType] = None, +) -> Tensor: + # Currently, while sparse CSR is always 2D with no dense dimensions keepdim must be True + # FIXME: when dense dimensions are implemented for CSR tensors + assert ( + keepdim + ), "reduction operations on CSR tensors with keepdim=False is unsupported" + reduce = op.__name__ + valid_reductions = ["sum", "prod", "mean", "amax", "amin"] + if reduce not in valid_reductions: + raise ValueError( + f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead" + ) + device = mask_input.device + output_dtype = dtype + values, crow_indices, col_indices = ( + mask_input.values(), + mask_input.crow_indices(), + mask_input.col_indices(), + ) + + # promote dtype if specified + if values.dtype != output_dtype: + values = values.to(output_dtype) + + if len(dims) == 0: + return mask_input + if len(dims) == 1: + if dims[0] == 0: + new_col_indices, scatter_indices = torch.unique( + col_indices, return_inverse=True + ) + new_nnz = new_col_indices.shape[0] + new_crow_indices = torch.tensor([0, new_nnz]) + new_values = values.new_empty(new_col_indices.shape) + new_values.scatter_reduce_( + 0, scatter_indices, values, reduce, include_self=False + ) + new_shape = [1, 
mask_input.size(1)] + else: + assert ( + dims[0] == 1 + ), "Sparse CSR tensors are 2D and only support reduction along dim 0 or 1." + # all intervals new_crow_indices[i] - new_crow_indices[i-1] are 1 + # except for where crow_indices[i] == crow_indices[i-1] where the interval remains as 0 + new_crow_indices = torch.cat( + ( + crow_indices.new_zeros(1), + torch.cumsum(torch.diff(crow_indices) != 0, 0), + ), + 0, + ) + new_nnz = new_crow_indices[-1] + new_col_indices = col_indices.new_zeros(new_nnz) + new_values = torch._segment_reduce(values, reduce, offsets=crow_indices) # type: ignore[attr-defined] + new_shape = [mask_input.size(0), 1] + else: + assert len(dims) == 2 + nnz = min(1, values.numel()) + if nnz == 1: + op_kwargs = {"keepdim": True, "dtype": output_dtype} + # amax and amin do not support dtype kwarg + if reduce in ["amax", "amin"]: + del op_kwargs["dtype"] + new_values = op(values, 0, **op_kwargs) + else: + new_values = torch.empty(0, dtype=output_dtype) + new_col_indices = col_indices.new_zeros(nnz) + new_crow_indices = torch.tensor([0, nnz]) + new_shape = [1, nnz] + + return torch.sparse_csr_tensor( + new_crow_indices, + new_col_indices, + new_values, + new_shape, + dtype=output_dtype, + device=device, + ) + + +def _sparse_csr_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor: + """Sparse variant of torch.where. Supports sparse CSR tensors.""" + # TODO: implement sparse CSR specific where operator for efficiency + return _sparse_coo_where( + mask.to_sparse_coo(), input.to_sparse_coo(), fill_value + ).to_sparse_csr() + + +def _where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor: + """torch.where with sparse inputs support. + + _where implements the following invariant: + + _where(mask, input, fill_value).to_dense(fill_value) == + torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value)) + + where `a == b` means `assertEqual(a, b)`, mask is boolean sparse + tensor, and `to_dense(fill_value)` is like `to_dense()` except + that the unspecified elements are mapped to `fill_value` rather + than to `0`. + + Returns a sparse tensor with the following features: + + - all specified elements correspond to masked-in elements that + have the values of the input tensor. If there exists a masked-in + element (as specified by mask) that is not specified in the + input, in the result tensor, the corresponding element has value + 0. In the dense part of the sparse tensor, the masked-out + elements are replaced with fill_value. + + - all unspecified elements correspond to masked-out elements. + """ + if mask.layout == torch.strided: + return torch.where(mask, input, fill_value) + elif mask.layout == torch.sparse_coo: + return _sparse_coo_where(mask, input, fill_value) + elif mask.layout == torch.sparse_csr: + return _sparse_csr_where(mask, input, fill_value) + else: + raise ValueError( + f"_where expects strided or sparse COO or sparse CSR tensor but got {mask.layout}" + ) + + +def _input_mask(input: Union[Tensor, MaskedTensor], *args, **kwargs) -> Tensor: + """Return canonical input mask. + + A canonical input mask is defined as a boolean mask tensor that + shape and layout matches with the shape and the layout of the + input. + + The canonical input mask is computed from the :attr:`mask` tensor + content to meet the following criteria: + + 1. The shape of the canonical input mask is the same as the shape + of :attr:`input` tensor. 
If the mask tensor has a smaller shape + than the shape of the :attr:`input`, broadcasting rules will be + applied. Downcasting of mask is not supported. + + 2. The layout of the canonical input mask is the same as the + layout of the :attr:`input` tensor. If the mask has different + layout, it will be converted to the expected layout. In the + case of sparse COO layout, the canonical input mask will be + coalesced. + + 3. The dtype of the canonical input mask is torch.bool. If the + mask dtype is not bool then it will be converted to bool dtype + using `.to(dtype=bool)` method call. + + 4. The elements of the canonical input mask have boolean values + copied from the content of the :attr:`mask` tensor (after + possible broadcasting and dtype conversion transforms). In + general, the sparsity pattern of the sparse canonical input + mask need not to be the same as the sparsity pattern of the + sparse :attr:`input` tensor. + + """ + if input.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}: + raise ValueError( + f"_input_mask expects strided or sparse COO or sparse CSR tensor but got {input.layout}" + ) + + mask = kwargs.get("mask") + + # default mask + if mask is None: + raise ValueError("_input_mask requires explicit mask") + + # mask shape must match with input shape + if mask.shape != input.shape: + if mask.ndim > input.ndim: + raise IndexError( + "_input_mask expected broadcastable mask (got mask dimensionality higher than of the input)" + ) + if mask.layout == torch.strided: + mask = torch.broadcast_to(mask.clone(), input.shape).to(dtype=torch.bool) + elif mask.layout == torch.sparse_coo: + mask = torch._sparse_broadcast_to(mask, input.shape) + else: + assert mask.layout == torch.sparse_csr + # Broadcasting of CSR tensors is not implemented. Working + # around by using COO layout. 
+ mask = torch._sparse_broadcast_to( + mask.to_sparse(), input.shape + ).to_sparse_csr() + + # mask layout must match with input layout + if mask.layout != input.layout: + if input.layout == torch.strided: + mask = mask.to_dense() + elif input.layout == torch.sparse_coo: + if mask.layout == torch.strided: + mask = mask.to_sparse(input.sparse_dim()) + else: + mask = mask.to_sparse() + else: + assert input.layout == torch.sparse_csr + mask = mask.to_sparse_csr() + + # sparse mask must be coalesced + if mask.layout == torch.sparse_coo: + mask = mask.coalesce() + + # mask is a boolean tensor + mask = mask.to(dtype=torch.bool) + + return mask + + +def _output_mask(op, input: Tensor, *args, **kwargs) -> Tensor: + """Return output mask of masked operation applied to given arguments.""" + if callable(op): + is_reduction = op.__name__ in { + "sum", + "prod", + "amax", + "amin", + "argmax", + "argmin", + "mean", + "median", + "norm", + "var", + "std", + "logsumexp", + } + is_normalization = op.__name__ in { + "softmax", + "log_softmax", + "softmin", + "normalize", + "cumsum", + "cumprod", + } + if is_reduction: + if op.__name__ == "norm": + if args: + args = args[1:] # lstrip ord argument + dim = args[0] if args else kwargs.get("dim") + outmask = _input_mask(input, *args, **kwargs) + keepdim = kwargs.get("keepdim", False) + dim_ = _canonical_dim(dim, input.ndim) + return _any(outmask, dim_, bool(keepdim)) + elif is_normalization: + return _input_mask(input, *args, **kwargs) + else: + raise ValueError( + f"_output_mask expected masked operation (got callable {op.__module__}.{op.__name__})" + ) + else: + raise ValueError( + f"_output_mask expected masked operation (got {type(op).__name__} object)" + ) + + +def _combine_input_and_mask( + op, input: Union[MaskedTensor, Tensor], mask, *args +) -> Tensor: + def helper(input, mask): + if mask is None: + return input + canonical_mask = _input_mask(input, mask=mask) + if callable(op): + fill_value = _reduction_identity(op.__name__, input, *args) + return _where(canonical_mask, input, fill_value) + else: + raise ValueError( + f"_combine_input_and_mask expected masked operation (got {type(op).__name__} object)" + ) + + class Combine(torch.autograd.Function): + @staticmethod + def forward(ctx, input, mask): + """Return input with masked-out elements eliminated for the given operations.""" + ctx.save_for_backward(mask) + + if mask is not None: + ctx.mark_non_differentiable(mask) + + return helper(input, mask) + + @staticmethod + def backward(ctx, grad_output): + (mask,) = ctx.saved_tensors + grad_data = ( + grad_output.get_data() if is_masked_tensor(grad_output) else grad_output + ) + result = as_masked_tensor(grad_data, mask) + return result, None + + return ( + Combine.apply(input.get_data(), input.get_mask()) # type: ignore[union-attr] + if is_masked_tensor(input) + else helper(input, mask) + ) + + +@_apply_docstring_templates +def sum( + input: Union[Tensor, MaskedTensor], + dim: DimOrDims = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + # __doc__ is generated by _apply_docstring_templates decorator + if dtype is None: + # promote integer types to int64 when output dtype is not specified + if input.layout == torch.sparse_csr: + if input.dtype in { + torch.uint8, + torch.bool, + torch.int8, + torch.int16, + torch.int32, + }: + # csr.to(dtype=torch.int64) is not implemented, so + # using coo.to on input to ensure the promoted dtype + input = 
input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr() + else: + dtype = input.dtype + else: + dtype = input.dtype + if input.dtype in { + torch.uint8, + torch.bool, + torch.int8, + torch.int16, + torch.int32, + }: + dtype = torch.int64 + dim_ = _canonical_dim(dim, input.ndim) + mask_input = _combine_input_and_mask(sum, input, mask) + if mask_input.layout == torch.strided: + return torch.sum(mask_input, dim_, bool(keepdim), dtype=dtype) + elif mask_input.layout == torch.sparse_coo: + return _sparse_coo_scatter_reduction_helper( + torch.sum, mask_input, dim_, bool(keepdim), dtype + ) + elif mask_input.layout == torch.sparse_csr: + return torch._sparse_csr_sum( + mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype + ) + else: + raise ValueError( + f"masked sum expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def prod( + input: Union[Tensor, MaskedTensor], + dim: DimOrDims = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + # __doc__ is generated by _apply_docstring_templates decorator + if dtype is None: + # promote integer types to int64 when output dtype is not specified + if input.layout == torch.sparse_csr: + if input.dtype in { + torch.uint8, + torch.bool, + torch.int8, + torch.int16, + torch.int32, + }: + # csr.to(dtype=torch.int64) is not implemented, so + # using coo.to on input to ensure the promoted dtype + input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr() + else: + dtype = input.dtype + else: + dtype = input.dtype + if input.dtype in { + torch.uint8, + torch.bool, + torch.int8, + torch.int16, + torch.int32, + }: + dtype = torch.int64 + dim_ = _canonical_dim(dim, input.ndim) + mask_input = _combine_input_and_mask(prod, input, mask) + if mask_input.layout == torch.strided: + # Workaround https://github.com/pytorch/pytorch/issues/56586 + result = mask_input + result = result.to(dtype=dtype) + for d in reversed(dim_): + result = result.prod(dim=d, keepdim=bool(keepdim)) + return result + elif mask_input.layout == torch.sparse_coo: + if mask is None: + # See comment in the sparse_csr branch, the same issue arises for sparse_coo tensors + raise ValueError( + "masked prod expects explicit mask for sparse_coo tensor input" + ) + return _sparse_coo_scatter_reduction_helper( + torch.prod, mask_input, dim_, bool(keepdim), dtype + ) + elif mask_input.layout == torch.sparse_csr: + if mask is None: + # mask is None corresponds to all-True mask. The + # unspecified elements in the CSR tensor correspond to + # zero values. Hence, the prod reduction result is + # automatically zero unless all elements are specified. + # A semi-optimal way to take this into account is to use: + # + # masked_prod(csr, ..., mask=None) == torch._sparse_csr_prod(csr, ...) * all(csr.nonzero(), ...) + # + # but that requires implementing `all` and `nonzero` + # support for sparse csr tensors. 
+ raise ValueError( + "masked prod expects explicit mask for sparse_csr tensor input" + ) + return torch._sparse_csr_prod( + mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype + ) + else: + raise ValueError( + f"masked prod expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def cumsum( + input: Tensor, + dim: int, + *, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + if dtype is None: + dtype = input.dtype + dim_ = _canonical_dim(dim, input.ndim)[0] + mask_input = _combine_input_and_mask(sum, input, mask) + if mask_input.layout == torch.strided: + return torch.cumsum(mask_input, dim_, dtype=dtype).to(dtype=dtype) + else: + raise ValueError( + f"masked cumsum expects strided tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def cumprod( + input: Tensor, + dim: int, + *, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + if dtype is None: + dtype = input.dtype + dim_ = _canonical_dim(dim, input.ndim)[0] + mask_input = _combine_input_and_mask(prod, input, mask) + if mask_input.layout == torch.strided: + return torch.cumprod(mask_input, dim_, dtype=dtype).to(dtype=dtype) + else: + raise ValueError( + f"masked cumprod expects strided tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def amax( + input: Union[Tensor, MaskedTensor], + dim: DimOrDims = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + """\ +{reduction_signature} + +{reduction_descr} + +{reduction_identity_dtype} + +{reduction_args} + +{reduction_example}""" + if dtype is None: + dtype = input.dtype + + mask_input = _combine_input_and_mask(amax, input, mask) + dim_ = _canonical_dim(dim, mask_input.ndim) + if mask_input.layout == torch.strided: + return torch.amax(mask_input, dim_, bool(keepdim)).to(dtype=dtype) + elif mask_input.layout == torch.sparse_coo: + if mask is None: + # See comment in the sparse_csr branch of prod, a similar issue arises here + # where unspecified elements along a dimension may need to be reduced with the result + raise ValueError( + "masked amax expects explicit mask for sparse_coo tensor input" + ) + return _sparse_coo_scatter_reduction_helper( + torch.amax, mask_input, dim_, bool(keepdim), dtype + ) + elif mask_input.layout == torch.sparse_csr: + if mask is None: + raise ValueError( + "masked amax expects explicit mask for sparse_csr tensor input" + ) + return _sparse_csr_segment_reduction_helper( + torch.amax, mask_input, dim_, bool(keepdim), dtype + ) + else: + raise ValueError( + f"masked amax expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def amin( + input: Union[Tensor, MaskedTensor], + dim: DimOrDims = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + """\ +{reduction_signature} + +{reduction_descr} + +{reduction_identity_dtype} + +{reduction_args} + +{reduction_example}""" + if dtype is None: + dtype = input.dtype + + mask_input = _combine_input_and_mask(amin, input, mask) + dim_ = _canonical_dim(dim, mask_input.ndim) + if mask_input.layout == torch.strided: + return torch.amin(mask_input, dim_, bool(keepdim)).to(dtype=dtype) + elif mask_input.layout == torch.sparse_coo: + if mask is None: + # See comment in the sparse_csr branch of prod, a similar issue arises here + # 
where unspecified elements along a dimension may need to be reduced with the result + raise ValueError( + "masked amax expects explicit mask for sparse_coo tensor input" + ) + return _sparse_coo_scatter_reduction_helper( + torch.amin, mask_input, dim_, bool(keepdim), dtype + ) + elif mask_input.layout == torch.sparse_csr: + if mask is None: + raise ValueError( + "masked amin expects explicit mask for sparse_csr tensor input" + ) + return _sparse_csr_segment_reduction_helper( + torch.amin, mask_input, dim_, bool(keepdim), dtype + ) + else: + raise ValueError( + f"masked amin expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def argmax( + input: Union[Tensor, MaskedTensor], + dim: Optional[int] = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + """\ +{reduction_signature} +{reduction_descr} +{reduction_identity_dtype} +{reduction_args} +{reduction_example}""" + if dtype is None: + dtype = input.dtype + mask_input = _combine_input_and_mask(argmax, input, mask) + if mask_input.layout == torch.strided: + return torch.argmax(mask_input, dim, bool(keepdim)).to(dtype=dtype) + else: + raise ValueError( + f"masked argmax expects strided tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def argmin( + input: Union[Tensor, MaskedTensor], + dim: Optional[int] = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + """\ +{reduction_signature} +{reduction_descr} +{reduction_identity_dtype} +{reduction_args} +{reduction_example}""" + if dtype is None: + dtype = input.dtype + mask_input = _combine_input_and_mask(argmin, input, mask) + if mask_input.layout == torch.strided: + return torch.argmin(mask_input, dim, bool(keepdim)).to(dtype=dtype) + else: + raise ValueError( + f"masked argmin expects strided tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def mean( + input: Union[Tensor, MaskedTensor], + dim: DimOrDims = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + """\ +{reduction_signature} + +{reduction_descr} + +By definition, the identity value of a mean operation is the mean +value of the tensor. If all elements of the input tensor along given +dimension(s) :attr:`dim` are masked-out, the identity value of the +mean is undefined. Due to this ambiguity, the elements of output +tensor with strided layout, that correspond to fully masked-out +elements, have ``nan`` values. 
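To make this concrete, here is a short illustrative doctest (hypothetical values; it assumes the public ``torch.masked.mean`` entry point that wraps this function). The fully masked-out second row reduces to ``nan``:

>>> x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
>>> m = torch.tensor([[True, False], [False, False]])
>>> torch.masked.mean(x, dim=1, mask=m)
tensor([1., nan])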
+ +{reduction_args} + +{reduction_example}""" + if dtype is None: + dtype = input.dtype + if input.layout == torch.strided: + if mask is None: + # TODO: compute count analytically + count = sum( + torch.ones(input.shape, dtype=torch.int64, device=input.device), + dim, + keepdim=keepdim, + ) + total = sum(input, dim, keepdim=keepdim, dtype=dtype) + else: + inmask = _input_mask(input, mask=mask) + count = inmask.sum(dim=dim, keepdim=bool(keepdim)) + total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask) + return total / count + elif input.layout == torch.sparse_csr: + mask_input = _combine_input_and_mask(mean, input, mask) + dim_ = _canonical_dim(dim, mask_input.ndim) + if mask is None: + raise ValueError( + "masked mean expects explicit mask for sparse_csr tensor input" + ) + return _sparse_csr_segment_reduction_helper( + torch.mean, mask_input, dim_, bool(keepdim), dtype + ) + else: + raise ValueError( + f"masked mean expects strided or sparse_csr tensor (got {input.layout} tensor)" + ) + + +@_apply_docstring_templates +def median( + input: Union[Tensor, MaskedTensor], + dim: int = -1, + *, + keepdim: bool = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + """\ +{reduction_signature} +{reduction_descr} +By definition, the identity value of a median operation is the median +value of the tensor. If all elements of the input tensor along given +dimension(s) :attr:`dim` are masked-out, the identity value of the +median is undefined. Due to this ambiguity, the elements of output +tensor with strided layout, that correspond to fully masked-out +elements, have ``nan`` values. +{reduction_args} +{reduction_example}""" + if dtype is None: + dtype = input.dtype + dim_ = _canonical_dim(dim, input.ndim)[0] + is_float = torch.is_floating_point(input) + if not is_float: + input = input.to(dtype=torch.float) + mask_input = _combine_input_and_mask(median, input, mask) + if mask_input.layout == torch.strided: + output = torch.nanmedian(mask_input, dim_, keepdim).values + if is_float: + return output + elif not is_float and not torch.isnan(output).any(): + return output.to(dtype=dtype) + else: + raise ValueError( + "masked median expects no fully masked out rows if dtype is not floating point" + ) + else: + raise ValueError( + f"masked median expects strided tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def logsumexp( + input: Tensor, + dim: DimOrDims = None, + *, + keepdim: bool = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + if dtype is None: + dtype = input.dtype + dim_ = _canonical_dim(dim, input.ndim) + mask_input = _combine_input_and_mask(logsumexp, input, mask) + if mask_input.layout == torch.strided: + return torch.logsumexp(mask_input, dim_, keepdim=keepdim).to(dtype=dtype) + else: + raise ValueError( + f"masked logsumexp expects strided tensor (got {mask_input.layout} tensor)" + ) + + +# Cannot use _apply_docstring_templates as it is only set up for reductions and normalizations +def logaddexp( + input: Union[Tensor, MaskedTensor], + other: Union[Tensor, MaskedTensor], + *, + dtype: Optional[DType] = None, + input_mask: Optional[Tensor] = None, + other_mask: Optional[Tensor] = None, +) -> Tensor: + """logaddexp(input, other, *, dtype=None, input_mask=None, other_mask=None) -> Tensor + + Returns logaddexp of all the elements in the :attr:`input` and the :attr:`other` + tensor. 
The :attr:`input` elements are masked out according to the boolean tensor + :attr:`input_mask` and the attr:`other` elements are masked out according to the boolean tensor + :attr:`other_mask`. + + The shapes of a mask tensor and the tensor to be masked + don't need to match, but they must be :ref:`broadcastable + ` and the dimensionality of the mask + tensor must not be greater than of the tensor to be masked. + + Args: + input (Tensor): the input tensor + other (Tensor): the second input tensor + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the output tensor is + casted to :attr:`dtype` after the operation is + performed. Default: None. + input_mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of :attr:`input` tensor elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + other_mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of :attr:`other` tensor elements. + Default: None that is equivalent to ``torch.ones(other.shape, dtype=torch.bool)``. + + Example:: + + >>> input = torch.tensor([-100.0, -200, -300]) + >>> input + tensor([-100., -200., -300.]) + >>> other = torch.tensor([-1.0, -2, -3]) + >>> other + tensor([-1., -2., -3.]) + >>> mask = torch.tensor([True, False, True]) + >>> mask + tensor([ True, False, True]) + >>> torch.masked._ops.logaddexp(input, other, input_mask=mask, other_mask=mask) + tensor([-1., -inf, -3.])""" + if dtype is None: + dtype = input.dtype + if input.layout == torch.strided and other.layout == torch.strided: + mask_input = _combine_input_and_mask(logaddexp, input, input_mask) + mask_other = _combine_input_and_mask(logaddexp, other, other_mask) + return torch.logaddexp(mask_input, mask_other).to(dtype=dtype) + else: + raise ValueError( + f"masked logaddexp expects strided tensors (got {input.layout} tensor for input, {other.layout} for other)" + ) + + +@_apply_docstring_templates +def norm( + input: Union[Tensor, MaskedTensor], + ord: Optional[float] = 2.0, + dim: DimOrDims = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + """\ +{reduction_signature} + +{reduction_descr} + +The identity value of norm operation, which is used to start the +reduction, is ``{identity_float32}``, except for ``ord=-inf`` it is +``{identity_ord_ninf}``. 
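As a brief illustrative sketch of this identity (hypothetical values; assumes the public ``torch.masked.norm`` wrapper), a fully masked-out slice reduces to the identity, which is ``0.0`` for ``ord=2``:

>>> x = torch.tensor([[3.0, 4.0], [5.0, 6.0]])
>>> m = torch.tensor([[True, True], [False, False]])
>>> torch.masked.norm(x, 2, dim=1, mask=m)
tensor([5., 0.])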
+ +{reduction_args} + +{reduction_example}""" + if dtype is None: + dtype = input.dtype + mask_input = _combine_input_and_mask(norm, input, mask, ord) + if mask_input.layout == torch.strided: + dim_ = _canonical_dim(dim, input.ndim) + return torch.linalg.vector_norm( + mask_input, ord, dim_, bool(keepdim), dtype=dtype + ) + else: + raise ValueError( + f"masked norm expects strided tensor (got {mask_input.layout} tensor)" + ) + + +def _std_var( + input: Union[Tensor, MaskedTensor], + dim: DimOrDims, + unbiased: Optional[bool], + *, + correction_opt: Optional[Union[int, float]], + keepdim: Optional[bool], + dtype: Optional[DType], + mask: Optional[Tensor], + take_sqrt: Optional[bool], +) -> Tensor: + assert ( + unbiased is None or correction_opt is None + ), "Only one of unbiased and correction may be given" + correction = 1.0 + if unbiased is not None: + correction = 1.0 if unbiased else 0.0 + if correction_opt is not None: + correction = sym_float(correction_opt) + + if dtype is None: + dtype = input.dtype + if not (dtype.is_floating_point or dtype.is_complex): + dtype = torch.float32 + compute_dtype = dtype + if not (compute_dtype.is_floating_point or compute_dtype.is_complex): + compute_dtype = torch.float32 + if input.layout == torch.strided: + if mask is None: + # TODO: compute count analytically + count = sum( + torch.ones(input.shape, dtype=torch.int64, device=input.device), + dim, + keepdim=True, + ) + sample_total = sum(input, dim, keepdim=True, dtype=dtype) + else: + inmask = _input_mask(input, mask=mask) + count = inmask.sum(dim=dim, keepdim=True) + sample_total = sum(input, dim, keepdim=True, dtype=dtype, mask=inmask) + # TODO: replace torch.subtract/divide/square/maximum with + # masked subtract/divide/square/maximum when these will be + # available. + sample_mean = torch.divide(sample_total, count) + x = torch.subtract(input, sample_mean) + if mask is None: + total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype) + else: + total = sum( + x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype, mask=inmask # type: ignore[possibly-undefined] + ) + if not keepdim: + count = count.reshape(total.shape) + if correction != 0: + real_dtype = ( + corresponding_real_dtype(compute_dtype) + if compute_dtype.is_complex + else compute_dtype + ) + count = count.to(real_dtype) + count = torch.subtract(count, correction) + count = torch.maximum(count, count.new_zeros([])) + output = torch.divide(total, count).to(dtype=dtype) + if take_sqrt: + output = torch.sqrt(output) + return output + else: + raise ValueError( + f"masked std/var expects strided tensor (got {input.layout} tensor)" + ) + + +@_apply_docstring_templates +def var( + input: Union[Tensor, MaskedTensor], + dim: DimOrDims = None, + unbiased: Optional[bool] = None, + *, + correction: Optional[Union[int, float]] = None, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + """\ +{reduction_signature} +{reduction_descr} +The identity value of sample variance operation is undefined. The +elements of output tensor with strided layout, that correspond to +fully masked-out elements, have ``nan`` values. 
+{reduction_args} +{reduction_example}""" + return _std_var( + input=input, + dim=dim, + unbiased=unbiased, + correction_opt=correction, + keepdim=keepdim, + dtype=dtype, + mask=mask, + take_sqrt=False, + ) + + +@_apply_docstring_templates +def std( + input: Union[Tensor, MaskedTensor], + dim: DimOrDims = None, + unbiased: Optional[bool] = None, + *, + correction: Optional[int] = None, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + """\ +{reduction_signature} +{reduction_descr} +The identity value of sample standard deviation operation is undefined. The +elements of output tensor with strided layout, that correspond to +fully masked-out elements, have ``nan`` values. +{reduction_args} +{reduction_example}""" + return _std_var( + input=input, + dim=dim, + unbiased=unbiased, + correction_opt=correction, + keepdim=keepdim, + dtype=dtype, + mask=mask, + take_sqrt=True, + ) + + +@_apply_docstring_templates +def softmax( + input: Union[Tensor, MaskedTensor], + dim: int, + *, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + if dtype is None: + dtype = input.dtype + dim_ = _canonical_dim(dim, input.ndim)[0] + mask_input = _combine_input_and_mask(amax, input, mask) + if mask_input.layout == torch.strided: + return torch.nn.functional.softmax(mask_input, dim_, dtype=dtype) + else: + raise ValueError( + f"masked softmax expects strided tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def log_softmax( + input: Union[Tensor, MaskedTensor], + dim: int, + *, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + if dtype is None: + dtype = input.dtype + dim_ = _canonical_dim(dim, input.ndim)[0] + mask_input = _combine_input_and_mask(amax, input, mask) + if mask_input.layout == torch.strided: + return torch.nn.functional.log_softmax(mask_input, dim_, dtype=dtype) + else: + raise ValueError( + f"masked log_softmax expects strided tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def softmin( + input: Union[Tensor, MaskedTensor], + dim: int, + *, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + if dtype is None: + dtype = input.dtype + dim_ = _canonical_dim(dim, input.ndim)[0] + mask_input = _combine_input_and_mask(amin, input, mask) + if mask_input.layout == torch.strided: + return torch.nn.functional.softmin(mask_input, dim_, dtype=dtype) + else: + raise ValueError( + f"masked softmin expects strided tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def normalize( + input: Union[Tensor, MaskedTensor], + ord: float, + dim: int, + *, + eps: float = 1e-12, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + if dtype is None: + dtype = input.dtype + dim_ = _canonical_dim(dim, input.ndim)[0] + # TODO: eliminate mask_input as unnecessary when using masked divide. + mask_input = _combine_input_and_mask(sum, input, mask) + if mask_input.layout == torch.strided: + nrm_ = norm(input, ord, dim, keepdim=True, dtype=dtype, mask=mask) + # TODO: replace torch.maximum with masked maximum when available. + denom = torch.maximum(nrm_, nrm_.new_full([], eps)) + # TODO: replace torch.divide with masked divide when available. 
+ return torch.divide(mask_input, denom) + else: + raise ValueError( + f"masked normalize expects strided tensor (got {mask_input.layout} tensor)" + ) diff --git a/lib/python3.10/site-packages/torch/monitor/__init__.py b/lib/python3.10/site-packages/torch/monitor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36493cd7539c90736a06d08da972eb546645dbda --- /dev/null +++ b/lib/python3.10/site-packages/torch/monitor/__init__.py @@ -0,0 +1,38 @@ +from torch._C._monitor import * # noqa: F403 +from typing import TYPE_CHECKING + +from torch._C._monitor import _WaitCounter # type: ignore[attr-defined] + +if TYPE_CHECKING: + from torch.utils.tensorboard import SummaryWriter + + +STAT_EVENT = "torch.monitor.Stat" + + +class TensorboardEventHandler: + """ + TensorboardEventHandler is an event handler that will write known events to + the provided SummaryWriter. + + This currently only supports ``torch.monitor.Stat`` events which are logged + as scalars. + + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_MONITOR) + >>> # xdoctest: +REQUIRES(module:tensorboard) + >>> from torch.utils.tensorboard import SummaryWriter + >>> from torch.monitor import TensorboardEventHandler, register_event_handler + >>> writer = SummaryWriter("log_dir") + >>> register_event_handler(TensorboardEventHandler(writer)) + """ + def __init__(self, writer: "SummaryWriter") -> None: + """ + Constructs the ``TensorboardEventHandler``. + """ + self._writer = writer + + def __call__(self, event: Event) -> None: + if event.name == STAT_EVENT: + for k, v in event.data.items(): + self._writer.add_scalar(k, v, walltime=event.timestamp.timestamp()) diff --git a/lib/python3.10/site-packages/torch/mps/__init__.py b/lib/python3.10/site-packages/torch/mps/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bc6b00d3c1a719023dd4e10c5a3f62d02b6a2ed4 --- /dev/null +++ b/lib/python3.10/site-packages/torch/mps/__init__.py @@ -0,0 +1,166 @@ +# mypy: allow-untyped-defs +r""" +This package enables an interface for accessing MPS (Metal Performance Shaders) backend in Python. +Metal is Apple's API for programming metal GPU (graphics processor unit). Using MPS means that increased +performance can be achieved, by running work on the metal GPU(s). +See https://developer.apple.com/documentation/metalperformanceshaders for more details. +""" +from typing import Union + +import torch +from torch import Tensor + + +_is_in_bad_fork = getattr(torch._C, "_mps_is_in_bad_fork", lambda: False) +_default_mps_generator: torch._C.Generator = None # type: ignore[assignment] + + +# local helper function (not public or exported) +def _get_default_mps_generator() -> torch._C.Generator: + global _default_mps_generator + if _default_mps_generator is None: + _default_mps_generator = torch._C._mps_get_default_generator() + return _default_mps_generator + + +def device_count() -> int: + r"""Returns the number of available MPS devices.""" + return int(torch._C._has_mps and torch._C._mps_is_available()) + + +def synchronize() -> None: + r"""Waits for all kernels in all streams on a MPS device to complete.""" + return torch._C._mps_deviceSynchronize() + + +def get_rng_state(device: Union[int, str, torch.device] = "mps") -> Tensor: + r"""Returns the random number generator state as a ByteTensor. + + Args: + device (torch.device or int, optional): The device to return the RNG state of. + Default: ``'mps'`` (i.e., ``torch.device('mps')``, the current MPS device). 
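A minimal usage sketch for the RNG state helpers defined in this module (assumes an MPS-enabled build; the seed value is arbitrary):

>>> state = torch.mps.get_rng_state()
>>> torch.mps.manual_seed(42)
>>> torch.mps.set_rng_state(state)  # restore the previous generator state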
+ """ + return _get_default_mps_generator().get_state() + + +def set_rng_state( + new_state: Tensor, device: Union[int, str, torch.device] = "mps" +) -> None: + r"""Sets the random number generator state. + + Args: + new_state (torch.ByteTensor): The desired state + device (torch.device or int, optional): The device to set the RNG state. + Default: ``'mps'`` (i.e., ``torch.device('mps')``, the current MPS device). + """ + new_state_copy = new_state.clone(memory_format=torch.contiguous_format) + _get_default_mps_generator().set_state(new_state_copy) + + +def manual_seed(seed: int) -> None: + r"""Sets the seed for generating random numbers. + + Args: + seed (int): The desired seed. + """ + # the torch.mps.manual_seed() can be called from the global + # torch.manual_seed() in torch/random.py. So we need to make + # sure mps is available (otherwise we just return without + # erroring out) + if not torch._C._has_mps: + return + seed = int(seed) + _get_default_mps_generator().manual_seed(seed) + + +def seed() -> None: + r"""Sets the seed for generating random numbers to a random number.""" + _get_default_mps_generator().seed() + + +def empty_cache() -> None: + r"""Releases all unoccupied cached memory currently held by the caching + allocator so that those can be used in other GPU applications. + """ + torch._C._mps_emptyCache() + + +def set_per_process_memory_fraction(fraction) -> None: + r"""Set memory fraction for limiting process's memory allocation on MPS device. + The allowed value equals the fraction multiplied by recommended maximum device memory + (obtained from Metal API device.recommendedMaxWorkingSetSize). + If trying to allocate more than the allowed value in a process, it will raise an out of + memory error in allocator. + + Args: + fraction(float): Range: 0~2. Allowed memory equals total_memory * fraction. + + .. note:: + Passing 0 to fraction means unlimited allocations + (may cause system failure if out of memory). + Passing fraction greater than 1.0 allows limits beyond the value + returned from device.recommendedMaxWorkingSetSize. + """ + + if not isinstance(fraction, float): + raise TypeError("Invalid type for fraction argument, must be `float`") + if fraction < 0 or fraction > 2: + raise ValueError(f"Invalid fraction value: {fraction}. Allowed range: 0~2") + + torch._C._mps_setMemoryFraction(fraction) + + +def current_allocated_memory() -> int: + r"""Returns the current GPU memory occupied by tensors in bytes. + + .. note:: + The returned size does not include cached allocations in + memory pools of MPSAllocator. + """ + return torch._C._mps_currentAllocatedMemory() + + +def driver_allocated_memory() -> int: + r"""Returns total GPU memory allocated by Metal driver for the process in bytes. + + .. note:: + The returned size includes cached allocations in MPSAllocator pools + as well as allocations from MPS/MPSGraph frameworks. + """ + return torch._C._mps_driverAllocatedMemory() + + +def recommended_max_memory() -> int: + r"""Returns recommended max Working set size for GPU memory in bytes. + + .. note:: + Recommended max working set size for Metal. + returned from device.recommendedMaxWorkingSetSize. + """ + return torch._C._mps_recommendedMaxMemory() + + +def is_available() -> bool: + return device_count() > 0 + + +from . 
import profiler +from .event import Event + + +__all__ = [ + "device_count", + "get_rng_state", + "manual_seed", + "seed", + "set_rng_state", + "synchronize", + "empty_cache", + "set_per_process_memory_fraction", + "current_allocated_memory", + "driver_allocated_memory", + "Event", + "profiler", + "recommended_max_memory", + "is_available", +] diff --git a/lib/python3.10/site-packages/torch/mps/event.py b/lib/python3.10/site-packages/torch/mps/event.py new file mode 100644 index 0000000000000000000000000000000000000000..d619c027480c3ad6c52744afa76f35ff4cba64c0 --- /dev/null +++ b/lib/python3.10/site-packages/torch/mps/event.py @@ -0,0 +1,46 @@ +# mypy: allow-untyped-defs +import torch + + +class Event: + r"""Wrapper around an MPS event. + + MPS events are synchronization markers that can be used to monitor the + device's progress, to accurately measure timing, and to synchronize MPS streams. + + Args: + enable_timing (bool, optional): indicates if the event should measure time + (default: ``False``) + """ + + def __init__(self, enable_timing=False): + self.__eventId = torch._C._mps_acquireEvent(enable_timing) + + def __del__(self): + # checks if torch._C is already destroyed + if hasattr(torch._C, "_mps_releaseEvent") and self.__eventId > 0: + torch._C._mps_releaseEvent(self.__eventId) + + def record(self): + r"""Records the event in the default stream.""" + torch._C._mps_recordEvent(self.__eventId) + + def wait(self): + r"""Makes all future work submitted to the default stream wait for this event.""" + torch._C._mps_waitForEvent(self.__eventId) + + def query(self): + r"""Returns True if all work currently captured by event has completed.""" + return torch._C._mps_queryEvent(self.__eventId) + + def synchronize(self): + r"""Waits until the completion of all work currently captured in this event. + This prevents the CPU thread from proceeding until the event completes. + """ + torch._C._mps_synchronizeEvent(self.__eventId) + + def elapsed_time(self, end_event): + r"""Returns the time elapsed in milliseconds after the event was + recorded and before the end_event was recorded. + """ + return torch._C._mps_elapsedTimeOfEvents(self.__eventId, end_event.__eventId) diff --git a/lib/python3.10/site-packages/torch/mps/profiler.py b/lib/python3.10/site-packages/torch/mps/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..4dcd86b30ecbf765cb13c674e602bdf21058f668 --- /dev/null +++ b/lib/python3.10/site-packages/torch/mps/profiler.py @@ -0,0 +1,61 @@ +# mypy: allow-untyped-defs +import contextlib + +import torch + + +__all__ = ["start", "stop", "profile"] + + +def start(mode: str = "interval", wait_until_completed: bool = False) -> None: + r"""Start OS Signpost tracing from MPS backend. + + The generated OS Signposts could be recorded and viewed in + XCode Instruments Logging tool. + + Args: + mode(str): OS Signpost tracing mode could be "interval", "event", + or both "interval,event". + The interval mode traces the duration of execution of the operations, + whereas event mode marks the completion of executions. + See document `Recording Performance Data`_ for more info. + wait_until_completed(bool): Waits until the MPS Stream complete + executing each encoded GPU operation. This helps generating single + dispatches on the trace's timeline. + Note that enabling this option would affect the performance negatively. + + .. 
_Recording Performance Data: + https://developer.apple.com/documentation/os/logging/recording_performance_data + """ + mode_normalized = mode.lower().replace(" ", "") + torch._C._mps_profilerStartTrace(mode_normalized, wait_until_completed) + + +def stop(): + r"""Stops generating OS Signpost tracing from MPS backend.""" + torch._C._mps_profilerStopTrace() + + +@contextlib.contextmanager +def profile(mode: str = "interval", wait_until_completed: bool = False): + r"""Context Manager to enabling generating OS Signpost tracing from MPS backend. + + Args: + mode(str): OS Signpost tracing mode could be "interval", "event", + or both "interval,event". + The interval mode traces the duration of execution of the operations, + whereas event mode marks the completion of executions. + See document `Recording Performance Data`_ for more info. + wait_until_completed(bool): Waits until the MPS Stream complete + executing each encoded GPU operation. This helps generating single + dispatches on the trace's timeline. + Note that enabling this option would affect the performance negatively. + + .. _Recording Performance Data: + https://developer.apple.com/documentation/os/logging/recording_performance_data + """ + try: + start(mode, wait_until_completed) + yield + finally: + stop() diff --git a/lib/python3.10/site-packages/torch/mtia/__init__.py b/lib/python3.10/site-packages/torch/mtia/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eaa107e973f6d9ce9cbf153d4d544b1dca02e07e --- /dev/null +++ b/lib/python3.10/site-packages/torch/mtia/__init__.py @@ -0,0 +1,332 @@ +# mypy: allow-untyped-defs +r""" +This package enables an interface for accessing MTIA backend in python +""" + +import threading +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from torch import device as _device, Tensor +from torch._utils import _dummy_type, _LazySeedTracker, classproperty +from torch.types import Device + +from ._utils import _get_device_index + + +_device_t = Union[_device, str, int, None] + +# torch.mtia.Event/Stream is alias of torch.Event/Stream +Event = torch.Event +Stream = torch.Stream + +_initialized = False +_queued_calls: List[ + Tuple[Callable[[], None], List[str]] +] = [] # don't invoke these until initialization occurs +_tls = threading.local() +_initialization_lock = threading.Lock() +_lazy_seed_tracker = _LazySeedTracker() + + +def init(): + _lazy_init() + + +def is_initialized(): + r"""Return whether PyTorch's MTIA state has been initialized.""" + return _initialized and not _is_in_bad_fork() + + +def _is_in_bad_fork() -> bool: + return torch._C._mtia_isInBadFork() + + +def _lazy_init() -> None: + global _initialized, _queued_calls + if is_initialized() or hasattr(_tls, "is_initializing"): + return + with _initialization_lock: + # We be double-checking locking, boys! This is OK because + # the above test was GIL protected anyway. The inner test + # is for when a thread blocked on some other thread which was + # doing the initialization; when they get the lock, they will + # find there is nothing left to do. + if is_initialized(): + return + # It is important to prevent other threads from entering _lazy_init + # immediately, while we are still guaranteed to have the GIL, because some + # of the C calls we make below will release the GIL + if _is_in_bad_fork(): + raise RuntimeError( + "Cannot re-initialize MTIA in forked subprocess. 
To use MTIA with " + "multiprocessing, you must use the 'spawn' start method" + ) + if not _is_compiled(): + raise AssertionError( + "Torch not compiled with MTIA enabled. " + "Ensure you have `import mtia.host_runtime.torch_mtia` in your python " + "src file and include `//mtia/host_runtime/torch_mtia:torch_mtia` as " + "your target dependency!" + ) + + torch._C._mtia_init() + # Some of the queued calls may reentrantly call _lazy_init(); + # we need to just return without initializing in that case. + # However, we must not let any *other* threads in! + _tls.is_initializing = True + + for calls in _lazy_seed_tracker.get_calls(): + if calls: + _queued_calls.append(calls) + + try: + for queued_call, orig_traceback in _queued_calls: + try: + queued_call() + except Exception as e: + msg = ( + f"MTIA call failed lazily at initialization with error: {str(e)}\n\n" + f"MTIA call was originally invoked at:\n\n{''.join(orig_traceback)}" + ) + raise DeferredMtiaCallError(msg) from e + finally: + delattr(_tls, "is_initializing") + _initialized = True + + +class DeferredMtiaCallError(Exception): + pass + + +def _is_compiled() -> bool: + r"""Return true if compiled with MTIA support.""" + return torch._C._mtia_isBuilt() + + +def is_available() -> bool: + r"""Return true if MTIA device is available""" + if not _is_compiled(): + return False + # MTIA has to init devices first to know if there is any devices available. + return device_count() > 0 + + +def synchronize(device: Optional[_device_t] = None) -> None: + r"""Waits for all jobs in all streams on a MTIA device to complete.""" + with torch.mtia.device(device): + return torch._C._mtia_deviceSynchronize() + + +def device_count() -> int: + r"""Return the number of MTIA devices available.""" + return torch._C._accelerator_hooks_device_count() + + +def current_device() -> int: + r"""Return the index of a currently selected device.""" + return torch._C._accelerator_hooks_get_current_device() + + +def current_stream(device: Optional[_device_t] = None) -> Stream: + r"""Return the currently selected :class:`Stream` for a given device. + + Args: + device (torch.device or int, optional): selected device. Returns + the currently selected :class:`Stream` for the current device, given + by :func:`~torch.mtia.current_device`, if :attr:`device` is ``None`` + (default). + """ + return torch._C._mtia_getCurrentStream(_get_device_index(device, optional=True)) + + +def default_stream(device: Optional[_device_t] = None) -> Stream: + r"""Return the default :class:`Stream` for a given device. + + Args: + device (torch.device or int, optional): selected device. Returns + the default :class:`Stream` for the current device, given by + :func:`~torch.mtia.current_device`, if :attr:`device` is ``None`` + (default). + """ + return torch._C._mtia_getDefaultStream(_get_device_index(device, optional=True)) + + +def memory_stats(device: Optional[_device_t] = None) -> Dict[str, Any]: + r"""Return a dictionary of MTIA memory allocator statistics for a given device. + + Args: + device (torch.device or int, optional) selected device. Returns + statistics for the current device, given by current_device(), + if device is None (default). + """ + if not is_initialized(): + return {} + return torch._C._mtia_memoryStats(_get_device_index(device, optional=True)) + + +def set_stream(stream: Stream): + r"""Set the current stream.This is a wrapper API to set the stream. + Usage of this function is discouraged in favor of the ``stream`` + context manager. 
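A hedged sketch of the preferred pattern using the ``stream`` context manager defined below (assumes an MTIA build with at least one device):

>>> s = torch.mtia.default_stream()
>>> with torch.mtia.stream(s):
...     pass  # work enqueued here targets stream ``s``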
+ + Args: + stream (Stream): selected stream. This function is a no-op + if this argument is ``None``. + """ + if stream is None: + return + torch._C._mtia_setCurrentStream(stream) + + +def set_device(device: _device_t) -> None: + r"""Set the current device. + + Args: + device (torch.device or int): selected device. This function is a no-op + if this argument is negative. + """ + device = _get_device_index(device) + if device >= 0: + torch._C._accelerator_hooks_set_current_device(device) + + +class device: + r"""Context-manager that changes the selected device. + + Args: + device (torch.device or int): device index to select. It's a no-op if + this argument is a negative integer or ``None``. + """ + + def __init__(self, device: Any): + self.idx = _get_device_index(device, optional=True) + self.prev_idx = -1 + + def __enter__(self): + self.prev_idx = torch._C._accelerator_hooks_maybe_exchange_device(self.idx) + + def __exit__(self, type: Any, value: Any, traceback: Any): + self.idx = torch._C._accelerator_hooks_maybe_exchange_device(self.prev_idx) + return False + + +class StreamContext: + r"""Context-manager that selects a given stream. + + All MTIA kernels queued within its context will be enqueued on a selected + stream. + + Args: + Stream (Stream): selected stream. This manager is a no-op if it's + ``None``. + .. note:: Streams are per-device. + """ + + cur_stream: Optional["torch.mtia.Stream"] + + def __init__(self, stream: Optional["torch.mtia.Stream"]): + self.cur_stream = None + self.stream = stream + self.idx = _get_device_index(None, True) + if not torch.jit.is_scripting(): + if self.idx is None: + self.idx = -1 + + self.src_prev_stream = ( + None if not torch.jit.is_scripting() else torch.mtia.default_stream(None) + ) + self.dst_prev_stream = ( + None if not torch.jit.is_scripting() else torch.mtia.default_stream(None) + ) + + def __enter__(self): + # Local cur_stream variable for type refinement + cur_stream = self.stream + # Return if stream is None or MTIA device not available + if cur_stream is None or self.idx == -1: + return + self.src_prev_stream = torch.mtia.current_stream(None) + + # If the stream is not on the current device, then + # set the current stream on the device + if self.src_prev_stream.device != cur_stream.device: + with device(cur_stream.device): + self.dst_prev_stream = torch.mtia.current_stream(cur_stream.device) + torch.mtia.set_stream(cur_stream) + + def __exit__(self, type: Any, value: Any, traceback: Any): + # Local cur_stream variable for type refinement + cur_stream = self.stream + # If stream is None or no MTIA device available, return + if cur_stream is None or self.idx == -1: + return + + # Reset the stream on the original device + # and destination device + if self.src_prev_stream.device != cur_stream.device: # type: ignore[union-attr] + torch.mtia.set_stream(self.dst_prev_stream) # type: ignore[arg-type] + torch.mtia.set_stream(self.src_prev_stream) # type: ignore[arg-type] + + +def stream(stream: Optional["torch.mtia.Stream"]) -> StreamContext: + r"""Wrap around the Context-manager StreamContext that selects a given stream. + + Arguments: + stream (Stream): selected stream. This manager is a no-op if it's + ``None``. + ..Note:: In eager mode stream is of type Stream class while in JIT it doesn't support torch.mtia.stream + """ + return StreamContext(stream) + + +def get_rng_state(device: Union[int, str, torch.device] = "mtia") -> Tensor: + r"""Returns the random number generator state as a ByteTensor. 
+ + Args: + device (torch.device or int, optional): The device to return the RNG state of. + Default: ``'mtia'`` (i.e., ``torch.device('mtia')``, the current mtia device). + """ + warnings.warn( + "get_rng_state is not implemented in torch.mtia", + UserWarning, + stacklevel=2, + ) + return torch.zeros([1], dtype=torch.uint8, device=device) + + +def set_rng_state( + new_state: Tensor, device: Union[int, str, torch.device] = "mtia" +) -> None: + r"""Sets the random number generator state. + + Args: + new_state (torch.ByteTensor): The desired state + device (torch.device or int, optional): The device to set the RNG state. + Default: ``'mtia'`` (i.e., ``torch.device('mtia')``, the current mtia device). + """ + warnings.warn( + "set_rng_state is not implemented in torch.mtia", + UserWarning, + stacklevel=2, + ) + + +__all__ = [ + "init", + "is_available", + "is_initialized", + "synchronize", + "device_count", + "current_device", + "current_stream", + "default_stream", + "memory_stats", + "set_device", + "set_stream", + "stream", + "device", + "set_rng_state", + "get_rng_state", +] diff --git a/lib/python3.10/site-packages/torch/mtia/_utils.py b/lib/python3.10/site-packages/torch/mtia/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..090e26f321232f9687c2b348ac602dbb6699b03f --- /dev/null +++ b/lib/python3.10/site-packages/torch/mtia/_utils.py @@ -0,0 +1,38 @@ +from typing import Any + +import torch + +# The _get_device_index has been moved to torch.utils._get_device_index +from torch._utils import _get_device_index as _torch_get_device_index + + +def _get_device_index( + device: Any, optional: bool = False, allow_cpu: bool = False +) -> int: + r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``. + + If :attr:`device` is a torch.device object, returns the device index if it + is a MTIA device. Note that for a MTIA device without a specified index, + i.e., ``torch.device('mtia')``, this will return the current default MTIA + device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``, + CPU devices will be accepted and ``-1`` will be returned in this case. + + If :attr:`device` is a Python integer, it is returned as is. + + If :attr:`device` is ``None``, this will return the current default MTIA + device if :attr:`optional` is ``True``. + """ + if isinstance(device, int): + return device + if isinstance(device, str): + device = torch.device(device) + if isinstance(device, torch.device): + if allow_cpu: + if device.type not in ["mtia", "cpu"]: + raise ValueError(f"Expected a mtia or cpu device, but got: {device}") + elif device.type != "mtia": + raise ValueError(f"Expected a mtia device, but got: {device}") + if not torch.jit.is_scripting(): + if isinstance(device, torch.mtia.device): + return device.idx + return _torch_get_device_index(device, optional, allow_cpu) diff --git a/lib/python3.10/site-packages/torch/multiprocessing/__init__.py b/lib/python3.10/site-packages/torch/multiprocessing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..745c180d8c415c4e52472864d26381e0872f3354 --- /dev/null +++ b/lib/python3.10/site-packages/torch/multiprocessing/__init__.py @@ -0,0 +1,100 @@ +# mypy: allow-untyped-defs +"""torch.multiprocessing is a wrapper around the native :mod:`multiprocessing` module. + +It registers custom reducers, that use shared memory to provide shared +views on the same data in different processes. 
Once the tensor/storage is moved +to shared_memory (see :func:`~torch.Tensor.share_memory_`), it will be possible +to send it to other processes without making any copies. + +The API is 100% compatible with the original module - it's enough to change +``import multiprocessing`` to ``import torch.multiprocessing`` to have all the +tensors sent through the queues or shared via other mechanisms, moved to shared +memory. + +Because of the similarity of APIs we do not document most of this package +contents, and we recommend referring to very good docs of the original module. +""" +import multiprocessing +import sys + +import torch + +from .reductions import init_reductions + + +__all__ = ["set_sharing_strategy", "get_sharing_strategy", "get_all_sharing_strategies"] + + +from multiprocessing import * # noqa: F403 + + +__all__ += multiprocessing.__all__ # noqa: PLE0605 type: ignore[attr-defined] + + +# This call adds a Linux specific prctl(2) wrapper function to this module. +# See https://github.com/pytorch/pytorch/pull/14391 for more information. +torch._C._multiprocessing_init() + + +"""Add helper function to spawn N processes and wait for completion of any of +them. This depends `mp.get_context` which was added in Python 3.4.""" +from .spawn import ( + ENV_VAR_PARALLEL_START, + ProcessContext, + ProcessExitedException, + ProcessRaisedException, + spawn, + SpawnContext, + start_processes, +) + + +if sys.platform == "darwin" or sys.platform == "win32": + _sharing_strategy = "file_system" + _all_sharing_strategies = {"file_system"} +else: + _sharing_strategy = "file_descriptor" + _all_sharing_strategies = {"file_descriptor", "file_system"} + + +def set_sharing_strategy(new_strategy): + """Set the strategy for sharing CPU tensors. + + Args: + new_strategy (str): Name of the selected strategy. Should be one of + the values returned by :func:`get_all_sharing_strategies()`. + """ + global _sharing_strategy + assert new_strategy in _all_sharing_strategies + _sharing_strategy = new_strategy + + +def get_sharing_strategy(): + """Return the current strategy for sharing CPU tensors.""" + return _sharing_strategy + + +def get_all_sharing_strategies(): + """Return a set of sharing strategies supported on a current system.""" + return _all_sharing_strategies + + +def _set_thread_name(name: str) -> None: + """Set the name of the current thread. + + Args: + name (str): Name of the current thread. + """ + torch._C._set_thread_name(name) + + +def _get_thread_name() -> str: + """Get the name of the current thread. + + Returns: + str: Name of the current thread. + """ + return torch._C._get_thread_name() + + +init_reductions() diff --git a/lib/python3.10/site-packages/torch/multiprocessing/_atfork.py b/lib/python3.10/site-packages/torch/multiprocessing/_atfork.py new file mode 100644 index 0000000000000000000000000000000000000000..ac4a97c10c07ae680765b0f362ef33c4bfb2308b --- /dev/null +++ b/lib/python3.10/site-packages/torch/multiprocessing/_atfork.py @@ -0,0 +1,35 @@ +# mypy: allow-untyped-defs +import sys + + +__all__ = ["register_after_fork"] + +if sys.platform == "win32": + import multiprocessing.util as _util + + def _register(func): + def wrapper(arg): + func() + + _util.register_after_fork(_register, wrapper) + +else: + import os + + def _register(func): + os.register_at_fork(after_in_child=func) + + +def register_after_fork(func): + """Register a callable to be executed in the child process after a fork. 
+ + Note: + In python < 3.7 this will only work with processes created using the + ``multiprocessing`` module. In python >= 3.7 it also works with + ``os.fork()``. + + Args: + func (function): Function taking no arguments to be called in the child after fork + + """ + _register(func) diff --git a/lib/python3.10/site-packages/torch/multiprocessing/pool.py b/lib/python3.10/site-packages/torch/multiprocessing/pool.py new file mode 100644 index 0000000000000000000000000000000000000000..6915203566469cfaf7170d87894ce03cc8348dd5 --- /dev/null +++ b/lib/python3.10/site-packages/torch/multiprocessing/pool.py @@ -0,0 +1,52 @@ +import multiprocessing.pool +import multiprocessing.util as util + +from .queue import SimpleQueue + + +def clean_worker(*args, **kwargs): + import gc + + multiprocessing.pool.worker(*args, **kwargs) + # Regular multiprocessing workers don't fully clean up after themselves, + # so we have to explicitly trigger garbage collection to make sure that all + # destructors are called... + gc.collect() + + +class Pool(multiprocessing.pool.Pool): + """Pool implementation which uses our version of SimpleQueue. + + This lets us pass tensors in shared memory across processes instead of + serializing the underlying data. + """ + + def _setup_queues(self): + self._inqueue = SimpleQueue() + self._outqueue = SimpleQueue() + self._quick_put = self._inqueue._writer.send + self._quick_get = self._outqueue._reader.recv + + def _repopulate_pool(self): + """Increase the number of pool processes to the specified number. + + Bring the number of pool processes up to the specified number, for use after + reaping workers which have exited. + """ + for i in range(self._processes - len(self._pool)): + # changed worker -> clean_worker + args = ( + self._inqueue, + self._outqueue, + self._initializer, + self._initargs, + self._maxtasksperchild, + ) + if hasattr(self, "_wrap_exception"): + args += (self._wrap_exception,) + w = self.Process(target=clean_worker, args=args) + self._pool.append(w) + w.name = w.name.replace("Process", "PoolWorker") + w.daemon = True + w.start() + util.debug("added worker") diff --git a/lib/python3.10/site-packages/torch/multiprocessing/queue.py b/lib/python3.10/site-packages/torch/multiprocessing/queue.py new file mode 100644 index 0000000000000000000000000000000000000000..876bf8d0e7459b60a41b59b0a093608e515ba455 --- /dev/null +++ b/lib/python3.10/site-packages/torch/multiprocessing/queue.py @@ -0,0 +1,43 @@ +# mypy: allow-untyped-defs +import io +import multiprocessing.queues +import pickle +from multiprocessing.reduction import ForkingPickler + + +class ConnectionWrapper: + """Proxy class for _multiprocessing.Connection which uses ForkingPickler for object serialization.""" + + def __init__(self, conn): + self.conn = conn + + def send(self, obj): + buf = io.BytesIO() + ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(obj) + self.send_bytes(buf.getvalue()) + + def recv(self): + buf = self.recv_bytes() + return pickle.loads(buf) + + def __getattr__(self, name): + if "conn" in self.__dict__: + return getattr(self.conn, name) + raise AttributeError(f"'{type(self).__name__}' object has no attribute 'conn'") + + +class Queue(multiprocessing.queues.Queue): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._reader: ConnectionWrapper = ConnectionWrapper(self._reader) + self._writer: ConnectionWrapper = ConnectionWrapper(self._writer) + self._send = self._writer.send + self._recv = self._reader.recv + + +class 
SimpleQueue(multiprocessing.queues.SimpleQueue): + def _make_methods(self): + if not isinstance(self._reader, ConnectionWrapper): + self._reader: ConnectionWrapper = ConnectionWrapper(self._reader) + self._writer: ConnectionWrapper = ConnectionWrapper(self._writer) + super()._make_methods() # type: ignore[misc] diff --git a/lib/python3.10/site-packages/torch/multiprocessing/reductions.py b/lib/python3.10/site-packages/torch/multiprocessing/reductions.py new file mode 100644 index 0000000000000000000000000000000000000000..fa0818571a93c0e9809c4638446e7ebdb15bd87e --- /dev/null +++ b/lib/python3.10/site-packages/torch/multiprocessing/reductions.py @@ -0,0 +1,647 @@ +# mypy: allow-untyped-defs +import multiprocessing +import os +import threading +from multiprocessing.reduction import ForkingPickler +from multiprocessing.util import register_after_fork +from typing import Union + +import torch +from torch._namedtensor_internals import check_serializing_named_tensor + + +try: + # Early load resource_sharer to prevent a partially initialized instance + # from being inherited in a forked child process. The reduce_storage method + # requires this module indirectly through DupFd(). The built-in mp.Queue + # class pickles arguments in a background thread which may overlap with the + # fork. + import multiprocessing.resource_sharer +except ImportError: + pass + + +class StorageWeakRef: + r"""A weak reference to a Storage. + + The cdata member is a Python number containing the integer representation of + the Storage pointer. + """ + + __slots__ = ["cdata", "_free_weak_ref"] + + def __init__(self, storage): + self.cdata = storage._weak_ref() + # Save a direct reference to _free_weak_ref because the `torch` module + # might be cleared during Python shutdown before this module is cleared. + self._free_weak_ref = torch.Storage._free_weak_ref # type: ignore[attr-defined] + + @classmethod + def from_weakref(cls, cdata): + instance = cls.__new__(cls) + instance.cdata = cdata + instance._free_weak_ref = torch.Storage._free_weak_ref # type: ignore[attr-defined] + return instance + + def expired(self): + return torch.Storage._expired(self.cdata) # type: ignore[attr-defined] + + def __del__(self): + self._free_weak_ref(self.cdata) + + def __hash__(self): + return self.cdata + + def __eq__(self, other): + if id(self) == id(other): + return True + return self.cdata == other.cdata + + +class SharedCache(dict): + """Dictionary from multiprocessing handles to StorageWeakRef.""" + + def __init__(self) -> None: + # free_dead_references() is called if the len exceeds the current + # limit. The limit scales with the number of remaining live objects. + self.limit = 128 + # `fork` inherits lock state, so in case we fork when the lock is held, + # we register a function to reset the lock to a new object to avoid + # possible deadlocks, following python multiprocessing library design. 
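# A standalone sketch of the same fork-safety pattern described above (stdlib
# only; ForkSafeBox is a made-up illustration, not part of this module):
# recreate the lock in the child so a fork taken while the parent holds the
# lock cannot deadlock the child.
import os
import threading


class ForkSafeBox:
    def __init__(self):
        self._lock = threading.Lock()
        # Python >= 3.7: run the callback in the child right after fork().
        os.register_at_fork(after_in_child=self._reset_lock)

    def _reset_lock(self):
        self._lock = threading.Lock()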
+ self._after_fork() + register_after_fork(self, SharedCache._after_fork) + + def _after_fork(self): + self.lock = threading.Lock() + + def get(self, key): + with self.lock: + return dict.get(self, key) + + def __setitem__(self, key, storage_ref): + with self.lock: + dict.__setitem__(self, key, storage_ref) + if len(self) > self.limit: + self.free_dead_references() + + def free_dead_references(self): + live = 0 + for key, storage_ref in list(self.items()): + if storage_ref.expired(): + del self[key] + else: + live += 1 + self.limit = max(128, live * 2) + + +# mapping from handles to StorageWeakRef objects +shared_cache = SharedCache() + + +def rebuild_event(device, handle): + return torch.cuda.Event.from_ipc_handle(device, handle) + + +def reduce_event(event): + handle = event.ipc_handle() + return (rebuild_event, (event.device, handle)) + + +def rebuild_tensor(cls, storage, metadata): + storage_offset, size, stride, requires_grad = metadata + t = torch._utils._rebuild_tensor(storage, storage_offset, size, stride) + if cls == torch.nn.parameter.Parameter: + # we have to pass requires_grad into constructor, rather than set it as an + # attribute later, because it's an important check for Integer Tensors to + # have requires_grad=False (or else they raise an error) + t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad) + else: + t.requires_grad = requires_grad + return t + + +def rebuild_meta_tensor( + tensor_cls, + tensor_size, + tensor_stride, + tensor_offset, + dtype, + storage_size_bytes, + requires_grad, +): + untyped_storage = torch.UntypedStorage(storage_size_bytes, device="meta") + + typed_storage = torch.TypedStorage( + wrap_storage=untyped_storage, dtype=dtype, _internal=True + ) + + t = torch._utils._rebuild_tensor( + typed_storage, + tensor_offset, + tensor_size, + tensor_stride, + ) + + if tensor_cls == torch.nn.parameter.Parameter: + # It is crucial for integer tensors to receive + # the requires_grad=False as an argument in the constructor + t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad) + else: + t.requires_grad = requires_grad + + return t + + +def rebuild_cuda_tensor( + tensor_cls, + tensor_size, + tensor_stride, + tensor_offset, + storage_cls, + dtype, + storage_device, + storage_handle, + storage_size_bytes, + storage_offset_bytes, + requires_grad, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, +): + # If storage_handle is None, storage points to nullptr. + if storage_handle is None or storage_size_bytes == 0: + storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True) + else: + storage = storage_from_cache( + storage_cls, (storage_handle, storage_offset_bytes) + ) + if storage is None: + torch.cuda._lazy_init() + storage = storage_cls._new_shared_cuda( + storage_device, + storage_handle, + storage_size_bytes, + storage_offset_bytes, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, + ) + shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef( + storage + ) + else: + # We already ref counting this Storage, but producer needs new ref-counters to be released. 
+ storage_cls._release_ipc_counter( + ref_counter_handle, ref_counter_offset, device=storage_device + ) + + _storage = ( + storage + if isinstance(storage, torch.UntypedStorage) + else storage._untyped_storage + ) + + t = torch._utils._rebuild_tensor( + torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True), + tensor_offset, + tensor_size, + tensor_stride, + ) + + if tensor_cls == torch.nn.parameter.Parameter: + # It is crucial for integer tensors to receive + # the requires_grad=False as an argument in the constructor + t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad) + else: + t.requires_grad = requires_grad + + return t + + +def reduce_tensor(tensor): + if tensor.requires_grad and not tensor.is_leaf: + raise RuntimeError( + "Cowardly refusing to serialize non-leaf tensor which requires_grad, " + "since autograd does not support crossing process boundaries. " + "If you just want to transfer the data, call detach() on the tensor " + "before serializing (e.g., putting it on the queue)." + ) + + check_serializing_named_tensor(tensor) + torch.utils.hooks.warn_if_has_hooks(tensor) + + # Note [CUDA IPC and the caching allocator] + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # When you send a CUDA tensor over IPC, you might expect that you will + # get out the same storage from the other end. However, the CUDA caching + # allocator makes it difficult to preserve this invariant. Consider + # the following situation: a tensor of size 0x100 points to offset 0x20 of + # a storage at 0xA100 of size 0x100. (For simplicity, all of these + # sizes are given in bytes). HOWEVER, with the caching allocator, this storage + # might be part of a larger cudaMalloc allocation 0xA000 of size 0x4000. + # + # When we want to send this CUDA tensor over IPC, we must send the + # *entire* cudaMalloc allocation, i.e., the 0xA000 region, not just + # the storage 0xA100 (because that is what CUDA supports). So, on the + # other end, there simply isn't any way to say, "Wait, you gave me + # a bigger region (0xA000) than the one I wanted (0xA100)". + # + # OK, so if you sent the cudaMalloc allocation, can you just wrap that up as + # one storage itself? No, because this cudaMalloc allocation might contain + # storages of mixed types: float, bytes, double... If you make the entire + # allocation a single storage of a type A, we'll hit an error when constructing + # a tensor of type B on the storage. + # + # cudaIpcMemHandle is an identifier to access the sender cudaMalloc allocation on the + # receiver side. However, cudaIpcMemHandles from each device in a given process may + # only be opened by one context per device per other process. + # If we open and close a memory handle multiples times in a process, CUDA is allowed + # to give it a different address; similarly, once we close the memory, we're not + # allowed to access it(and the storage/tensor built on top of it), even if it is + # still live in the original process. As we cannot make a cudaMalloc allocation + # to a single storage in one go, this requires us to cache the device pointer for + # each cudaIpcMemHandle on C++ side to reconstruct types of storages, while keep + # the old ones alives. + # See [https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html] + # + # This is fine, because all we need to do is to save our position in the allocation, + # and reconstruct storage and tensor from it. 
+ # 0xA000 -> -------CUDA Allocation------ + # | | + # | | + # | | + # | | + # 0xA100 -> --------storage1 begin------ + # | | + # 0xA120 -> --------tensor1 begin ------ + # | | + # | | + # | | + # | | + # | | + # 0xA160 -> --------tensor1 end--------- + # | | + # | | + # | | + # 0xA200 -> --------storage1 end-------- + # | | + # 0xE000 -> --------CUDA allocation----- + # + # To send tensor1, the following info are required from sender to receiver for + # storage recontruction. + # 1. cudaIpcMemHandle of 0xA000(which can be mapped to a basePtr in receiver process). + # basePtr may not be exactly 0xA000 since it's a different process. + # 2. offset(0xA100) of storage1 in the CUDA allocation. + # 3. size of storage1(0x100). + # + # On receiver side: + # 1. Get the devPtr of the MemHandle to access the memory, reconstruct a storage + # of the same type using (basePtr, offset, size). + # 2. we can reconstruct the tensor on top of the reconstructed storage + # Tensor(size=0x040, offset=0x020, storage=Storage(data=basePtr+0xA100, size=0x0100)) + # + # This strategy has a few implications: + # + # 1. When we serialize a CUDA tensor for IPC, we cannot do it all in one + # go (non-compositionally), and this requires to have a global map + # memHandle -> devPtr for each process. + # + # 2. We MUST NOT let the new IPC tensor be resizable. Originally, a resize + # of the storage beyond 0x100 would merely have caused us to do a + # reallocation. You don't really want to do this, but if you did, + # all that would happen is that you would lose IPC sharing. But if + # you do this in the new world, we will happily let you write out of + # bounds of your "allocation", clobbering unrelated data in the cached + # allocator block. BAD! + # + # By the way, in old versions of PyTorch, we supported this situation + # natively using a "storage view", which permitted multiple storages to be + # views on each other. But this was the *only* use of storage views, so we + # eliminated it so that we could just use tensor views to implement the same + # thing. + # + + # TODO: Handle distinguishing between subclass and non-subclass versions of NT better + # https://github.com/pytorch/pytorch/issues/110543 + from torch.nested._internal.nested_tensor import NestedTensor + + if tensor.is_nested and not isinstance(tensor, NestedTensor): + return reduce_nested_tensor(tensor) + + if tensor.layout in { + torch.sparse_coo, + torch.sparse_csr, + torch.sparse_bsr, + torch.sparse_csc, + torch.sparse_bsc, + }: + return reduce_sparse_tensor(tensor) + + storage = tensor._typed_storage() + + if storage._untyped_storage.device.type == "cuda": + ( + device, + handle, + storage_size_bytes, + storage_offset_bytes, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, + ) = storage._share_cuda_() + tensor_offset = tensor.storage_offset() + shared_cache[handle] = StorageWeakRef(storage) + # _backward_hooks purposely omitted here, see + # Note [Don't serialize hooks] + return ( + rebuild_cuda_tensor, + ( + type(tensor), + tensor.size(), + tensor.stride(), + tensor_offset, # tensor offset in its storage + type(storage), + tensor.dtype, + device, + handle, # identifier which CUDA allocation is the storage in. 
+ storage_size_bytes, # size(in bytes) of the storage + storage_offset_bytes, # offset(in bytes) of the storage in the CUDA allocation + tensor.requires_grad, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, + ), + ) + elif storage._untyped_storage.device.type == "meta": + return ( + rebuild_meta_tensor, + ( + type(tensor), + tensor.size(), + tensor.stride(), + tensor.storage_offset(), + tensor.dtype, + tensor.untyped_storage().size(), + tensor.requires_grad, + ), + ) + + # _backward_hooks purposely omitted here, see Note [Don't serialize hooks] + metadata = ( + tensor.storage_offset(), + tensor.size(), + tensor.stride(), + tensor.requires_grad, + ) + return (rebuild_tensor, (type(tensor), storage, metadata)) + + +def rebuild_nested_tensor( + rebuild_buffer_func, + rebuild_buffer_args, + rebuild_sizes_func, + rebuild_sizes_args, + rebuild_strides_func, + rebuild_strides_args, + rebuild_offsets_func, + rebuild_offsets_args, +): + buffer = rebuild_buffer_func(*rebuild_buffer_args) + sizes = rebuild_sizes_func(*rebuild_sizes_args) + strides = rebuild_strides_func(*rebuild_strides_args) + offsets = rebuild_offsets_func(*rebuild_offsets_args) + return torch._nested_view_from_buffer_copy(buffer, sizes, strides, offsets) + + +def reduce_nested_tensor(nt): + rebuild_buffer_func, rebuild_buffer_args = reduce_tensor(nt.values()) + rebuild_sizes_func, rebuild_sizes_args = reduce_tensor(nt._nested_tensor_size()) + rebuild_strides_func, rebuild_strides_args = reduce_tensor( + nt._nested_tensor_strides() + ) + rebuild_offsets_func, rebuild_offsets_args = reduce_tensor( + nt._nested_tensor_storage_offsets() + ) + + return ( + rebuild_nested_tensor, + ( + rebuild_buffer_func, + rebuild_buffer_args, + rebuild_sizes_func, + rebuild_sizes_args, + rebuild_strides_func, + rebuild_strides_args, + rebuild_offsets_func, + rebuild_offsets_args, + ), + ) + + +def rebuild_sparse_coo_tensor( + rebuild_indices_func, + rebuild_indices_args, + rebuild_values_func, + rebuild_values_args, + shape, + is_coalesced, +): + indices = rebuild_indices_func(*rebuild_indices_args) + values = rebuild_values_func(*rebuild_values_args) + return torch.sparse_coo_tensor(indices, values, shape, is_coalesced=is_coalesced) + + +def rebuild_sparse_compressed_tensor( + rebuild_compressed_indices_func, + rebuild_compressed_indices_args, + rebuild_plain_indices_func, + rebuild_plain_indices_args, + rebuild_values_func, + rebuild_values_args, + shape, + layout, +): + compressed_indices = rebuild_compressed_indices_func( + *rebuild_compressed_indices_args + ) + plain_indices = rebuild_plain_indices_func(*rebuild_plain_indices_args) + values = rebuild_values_func(*rebuild_values_args) + return torch.sparse_compressed_tensor( + compressed_indices, plain_indices, values, shape, layout=layout + ) + + +def reduce_sparse_tensor(sparse): + if sparse.layout is torch.sparse_coo: + rebuild_indices_func, rebuild_indices_args = reduce_tensor(sparse._indices()) + rebuild_values_func, rebuild_values_args = reduce_tensor(sparse._values()) + return ( + rebuild_sparse_coo_tensor, + ( + rebuild_indices_func, + rebuild_indices_args, + rebuild_values_func, + rebuild_values_args, + sparse.shape, + sparse.is_coalesced(), + ), + ) + else: + if sparse.layout in {torch.sparse_csr, torch.sparse_bsr}: + compressed_indices = sparse.crow_indices() + plain_indices = sparse.col_indices() + elif sparse.layout in {torch.sparse_csc, torch.sparse_bsc}: + compressed_indices = sparse.ccol_indices() + plain_indices = sparse.row_indices() + else: + 
raise NotImplementedError(sparse.layout) + ( + rebuild_compressed_indices_func, + rebuild_compressed_indices_args, + ) = reduce_tensor(compressed_indices) + rebuild_plain_indices_func, rebuild_plain_indices_args = reduce_tensor( + plain_indices + ) + rebuild_values_func, rebuild_values_args = reduce_tensor(sparse.values()) + return ( + rebuild_sparse_compressed_tensor, + ( + rebuild_compressed_indices_func, + rebuild_compressed_indices_args, + rebuild_plain_indices_func, + rebuild_plain_indices_args, + rebuild_values_func, + rebuild_values_args, + sparse.shape, + sparse.layout, + ), + ) + + +def fd_id(fd): + # Returns a tuple which uniquely identifies a file descriptor. In Mac OS, + # this doesn't work with shared memory handles, which is why we don't + # support the "file_descriptor" sharing method on that platform. + stat = os.fstat(fd) + return (stat.st_ino, stat.st_dev) + + +def storage_from_cache(cls, key): + storage_ref = shared_cache.get(key) + if storage_ref is None: + return None + return torch.UntypedStorage._new_with_weak_ptr(storage_ref.cdata) + + +def rebuild_storage_fd(cls, df, size): + fd = df.detach() + try: + storage = storage_from_cache(cls, fd_id(fd)) + if storage is not None: + return storage + storage = cls._new_shared_fd_cpu(fd, size) + shared_cache[fd_id(fd)] = StorageWeakRef(storage) + return storage + finally: + os.close(fd) + + +def rebuild_storage_filename(cls, manager, handle, size, dtype=None): + storage: Union[torch.TypedStorage, torch.UntypedStorage] = storage_from_cache( + cls, handle + ) + if storage is not None: + return storage._shared_decref() + if dtype is None: + storage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, size) + else: + byte_size = size * torch._utils._element_size(dtype) + untyped_storage: torch.UntypedStorage = ( + torch.UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size) + ) + storage = torch.TypedStorage( + wrap_storage=untyped_storage, dtype=dtype, _internal=True + ) + shared_cache[handle] = StorageWeakRef(storage) + return storage._shared_decref() + + +def rebuild_storage_empty(cls): + return cls() + + +def rebuild_typed_storage(storage, dtype): + return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype, _internal=True) + + +# Use for torch.storage.TypedStorage +def reduce_typed_storage(storage): + return (rebuild_typed_storage, (storage._untyped_storage, storage.dtype)) + + +def rebuild_typed_storage_child(storage, storage_type): + return storage_type(wrap_storage=storage, _internal=True) + + +# Use for child classes of torch.storage.TypedStorage, like torch.FloatStorage +def reduce_typed_storage_child(storage): + return (rebuild_typed_storage_child, (storage._untyped_storage, type(storage))) + + +def reduce_storage(storage): + from . import get_sharing_strategy + + if storage.is_cuda: + raise RuntimeError( + "Cannot pickle CUDA storage; try pickling a CUDA tensor instead" + ) + elif storage.device.type == "meta": + raise RuntimeError( + "Cannot pickle meta storage; try pickling a meta tensor instead" + ) + elif get_sharing_strategy() == "file_system": + metadata = storage._share_filename_cpu_() + cache_key = metadata[1] + rebuild = rebuild_storage_filename + if isinstance(storage, torch.TypedStorage): + metadata += (storage.dtype,) + storage._shared_incref() + elif storage.size() == 0: + # This is special cased because Empty tensors + # (with size 0) cannot be mmapped. 
+ return (rebuild_storage_empty, (type(storage),)) + else: + fd, size = storage._share_fd_cpu_() + df = multiprocessing.reduction.DupFd(fd) + cache_key = fd_id(fd) + metadata = (df, size) + rebuild = rebuild_storage_fd # type: ignore[assignment] + + shared_cache[cache_key] = StorageWeakRef(storage) + return (rebuild, (type(storage),) + metadata) + + +def init_reductions(): + ForkingPickler.register(torch.cuda.Event, reduce_event) + + for t in torch._storage_classes: + if t.__name__ == "UntypedStorage": + ForkingPickler.register(t, reduce_storage) + else: + ForkingPickler.register(t, reduce_typed_storage_child) + + ForkingPickler.register(torch.storage.TypedStorage, reduce_typed_storage) + + for t in torch._tensor_classes: + ForkingPickler.register(t, reduce_tensor) + + # TODO: Maybe this should be in tensor_classes? :) + ForkingPickler.register(torch.Tensor, reduce_tensor) + + from torch.nn.parameter import Parameter + + ForkingPickler.register(Parameter, reduce_tensor) diff --git a/lib/python3.10/site-packages/torch/multiprocessing/spawn.py b/lib/python3.10/site-packages/torch/multiprocessing/spawn.py new file mode 100644 index 0000000000000000000000000000000000000000..74bdde0fd97b20355686fc49fdb50a8fe02c5006 --- /dev/null +++ b/lib/python3.10/site-packages/torch/multiprocessing/spawn.py @@ -0,0 +1,328 @@ +# mypy: allow-untyped-defs +import logging +import multiprocessing +import multiprocessing.connection +import os +import pickle +import signal +import sys +import tempfile +import time +import warnings +from concurrent.futures import as_completed, ThreadPoolExecutor +from typing import Optional + +from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined] + + +ENV_VAR_PARALLEL_START = "TORCH_MP_PARALLEL_START" + +log = logging.getLogger(__name__) + +__all__ = [ + "ProcessContext", + "ProcessException", + "ProcessExitedException", + "ProcessRaisedException", + "spawn", + "SpawnContext", + "start_processes", +] + + +class ProcessException(Exception): + __slots__ = ["error_index", "error_pid"] + + def __init__(self, msg: str, error_index: int, pid: int): + super().__init__(msg) + self.msg = msg + self.error_index = error_index + self.pid = pid + + def __reduce__(self): + return type(self), (self.msg, self.error_index, self.pid) + + +class ProcessRaisedException(ProcessException): + """Exception raised when a process failed due to an exception raised by the code.""" + + def __init__( + self, + msg: str, + error_index: int, + error_pid: int, + ): + super().__init__(msg, error_index, error_pid) + + +class ProcessExitedException(ProcessException): + """Exception raised when a process failed due to signal or exited with a specific code.""" + + __slots__ = ["exit_code"] + + def __init__( + self, + msg: str, + error_index: int, + error_pid: int, + exit_code: int, + signal_name: Optional[str] = None, + ): + super().__init__(msg, error_index, error_pid) + self.exit_code = exit_code + self.signal_name = signal_name + + def __reduce__(self): + return ( + type(self), + (self.msg, self.error_index, self.pid, self.exit_code, self.signal_name), + ) + + +def _wrap(fn, i, args, error_file): + # prctl(2) is a Linux specific system call. + # On other systems the following function call has no effect. + # This is set to ensure that non-daemonic child processes can + # terminate if their parent terminates before they do. 
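# The ForkingPickler registration pattern used by init_reductions() above can
# be reused for custom types; a minimal sketch (Box is a made-up class, not
# part of this module):
from multiprocessing.reduction import ForkingPickler


class Box:
    def __init__(self, payload):
        self.payload = payload


def _rebuild_box(payload):
    return Box(payload)


def _reduce_box(box):
    # Return (callable, args) so the receiving process can reconstruct the object.
    return (_rebuild_box, (box.payload,))


ForkingPickler.register(Box, _reduce_box)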
+ _prctl_pr_set_pdeathsig(signal.SIGINT) + + try: + fn(i, *args) + except KeyboardInterrupt: + pass # SIGINT; Killed by parent, do nothing + except Exception: + # Propagate exception to parent process, keeping original traceback + import traceback + + with open(error_file, "wb") as fh: + pickle.dump(traceback.format_exc(), fh) + sys.exit(1) + + +class ProcessContext: + def __init__(self, processes, error_files): + self.error_files = error_files + self.processes = processes + self.sentinels = { + process.sentinel: index for index, process in enumerate(processes) + } + + def pids(self): + return [int(process.pid) for process in self.processes] + + def join(self, timeout=None): + r"""Join one or more processes within spawn context. + + Attempt to join one or more processes in this spawn context. + If one of them exited with a non-zero exit status, this function + kills the remaining processes and raises an exception with the cause + of the first process exiting. + + Returns ``True`` if all processes have been joined successfully, + ``False`` if there are more processes that need to be joined. + + Args: + timeout (float): Wait this long before giving up on waiting. + """ + # Ensure this function can be called even when we're done. + if len(self.sentinels) == 0: + return True + + # Wait for any process to fail or all of them to succeed. + ready = multiprocessing.connection.wait( + self.sentinels.keys(), + timeout=timeout, + ) + + error_index = None + for sentinel in ready: + index = self.sentinels.pop(sentinel) + process = self.processes[index] + process.join() + if process.exitcode != 0: + error_index = index + break + + # Return if there was no error. + if error_index is None: + # Return whether or not all processes have been joined. + return len(self.sentinels) == 0 + + # Assume failure. Terminate processes that are still alive. + # Try SIGTERM then SIGKILL if the process isn't going down. + # The reason is related to python signal handling is limited + # to main thread and if that is in c/c++ land and stuck it won't + # to handle it. We have seen processes getting stuck not handling + # SIGTERM for the above reason. + timeout: int = 30 + for process in self.processes: + if process.is_alive(): + log.warning("Terminating process %s via signal SIGTERM", process.pid) + process.terminate() + end = time.monotonic() + timeout + for process in self.processes: + time_to_wait = max(0, end - time.monotonic()) + process.join(time_to_wait) + for process in self.processes: + if process.is_alive(): + log.warning( + "Unable to shutdown process %s via SIGTERM , forcefully exiting via SIGKILL", + process.pid, + ) + process.kill() + process.join() + + # The file will only be created if the process crashed. 
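# The sentinel-based wait used at the top of join() can also be exercised on
# its own; a stdlib-only sketch (_work is a placeholder function):
import multiprocessing as mp
import multiprocessing.connection


def _work(i):
    return i * i


if __name__ == "__main__":
    procs = [mp.Process(target=_work, args=(i,)) for i in range(4)]
    for p in procs:
        p.start()
    remaining = {p.sentinel: p for p in procs}
    while remaining:
        # wait() returns as soon as any child's sentinel becomes ready.
        for s in multiprocessing.connection.wait(list(remaining), timeout=1.0):
            remaining.pop(s).join()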
+ failed_process = self.processes[error_index] + if not os.access(self.error_files[error_index], os.R_OK): + exitcode = self.processes[error_index].exitcode + if exitcode < 0: + try: + name = signal.Signals(-exitcode).name + except ValueError: + name = f"" + raise ProcessExitedException( + "process %d terminated with signal %s" % (error_index, name), + error_index=error_index, + error_pid=failed_process.pid, + exit_code=exitcode, + signal_name=name, + ) + else: + raise ProcessExitedException( + "process %d terminated with exit code %d" % (error_index, exitcode), + error_index=error_index, + error_pid=failed_process.pid, + exit_code=exitcode, + ) + + with open(self.error_files[error_index], "rb") as fh: + original_trace = pickle.load(fh) + msg = "\n\n-- Process %d terminated with the following error:\n" % error_index + msg += original_trace + raise ProcessRaisedException(msg, error_index, failed_process.pid) + + +class SpawnContext(ProcessContext): + def __init__(self, processes, error_files): + warnings.warn("SpawnContext is renamed to ProcessContext since 1.4 release.") + super().__init__(processes, error_files) + + +# Note: [start_processes] +# mp.start_processes handles both start_method='spawn' and 'fork'. It's supposed to be a +# more generalized API than mp.spawn. Currently we only document mp.spawn as it's the +# CUDA compatible start_method. However, in environments like Ipython notebooks, 'fork' +# works better than 'spawn'. Every helper function we created for mp.spawn is indeed +# general enough, and backends like XLA can reuse them in Colab notebooks as well. +# Currently we only add this API first, we can consider adding it to documentation as +# needed in the future. +def start_processes( + fn, + args=(), + nprocs=1, + join=True, + daemon=False, + start_method="spawn", +): + # To speed up performance in certain cases (see https://github.com/pytorch/pytorch/issues/133010), + # this func will start processes in parallel if start_method is 'forkserver'. + # Please opt in to this perf optimization by setting env var (TORCH_MP_PARALLEL_START) to 1. + # todo: investigate why spawn does not work with threadpool and raises SIGINT + if ( + start_method == "forkserver" + and os.environ.get(ENV_VAR_PARALLEL_START, "0") == "1" + ): + log.info("Starting processes in parallel.") + start_parallel = True + else: + # Set env var TORCH_MP_PARALLEL_START to 0 to disable parallel start + start_parallel = False + + mp = multiprocessing.get_context(start_method) + error_files = [None] * nprocs + processes = [None] * nprocs + + def start_process(i): + # Each process is assigned a file to write tracebacks to. We + # use the file being non-empty to indicate an exception + # occurred (vs an expected shutdown). Note: this previously + # used a multiprocessing.Queue but that can be prone to + # deadlocks, so we went with a simpler solution for a one-shot + # message between processes. 
+ tf = tempfile.NamedTemporaryFile( + prefix="pytorch-errorfile-", suffix=".pickle", delete=False + ) + tf.close() + os.unlink(tf.name) + process = mp.Process( + target=_wrap, + args=(fn, i, args, tf.name), + daemon=daemon, + ) + process.start() + return i, process, tf.name + + if not start_parallel: + for i in range(nprocs): + idx, process, tf_name = start_process(i) + error_files[idx] = tf_name + processes[idx] = process + else: + with ThreadPoolExecutor(max_workers=nprocs) as executor: + futures = [executor.submit(start_process, i) for i in range(nprocs)] + for fut in as_completed(futures): + idx, process, tf_name = fut.result() + # idx and process rank needs to be the same. + error_files[idx] = tf_name + processes[idx] = process + context = ProcessContext(processes, error_files) + if not join: + return context + + # Loop on join until it returns True or raises an exception. + while not context.join(): + pass + + +def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"): + r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``. + + If one of the processes exits with a non-zero exit status, the + remaining processes are killed and an exception is raised with the + cause of termination. In the case an exception was caught in the + child process, it is forwarded and its traceback is included in + the exception raised in the parent process. + + Args: + fn (function): Function is called as the entrypoint of the + spawned process. This function must be defined at the top + level of a module so it can be pickled and spawned. This + is a requirement imposed by multiprocessing. + + The function is called as ``fn(i, *args)``, where ``i`` is + the process index and ``args`` is the passed through tuple + of arguments. + + args (tuple): Arguments passed to ``fn``. + nprocs (int): Number of processes to spawn. + join (bool): Perform a blocking join on all processes. + daemon (bool): The spawned processes' daemon flag. If set to True, + daemonic processes will be created. + start_method (str): (deprecated) this method will always use ``spawn`` + as the start method. To use a different start method + use ``start_processes()``. 
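    A minimal usage sketch (illustrative; ``train_fn`` and ``world_size`` are
    placeholder names chosen for this example)::

        import torch.multiprocessing as mp

        def train_fn(rank, world_size):
            print(f"worker {rank} of {world_size}")

        if __name__ == "__main__":
            world_size = 4
            # join=False returns a ProcessContext that can be polled with a timeout.
            ctx = mp.spawn(train_fn, args=(world_size,), nprocs=world_size, join=False)
            while not ctx.join(timeout=5):
                pass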
+ + Returns: + None if ``join`` is ``True``, + :class:`~ProcessContext` if ``join`` is ``False`` + + """ + if start_method != "spawn": + msg = ( + f"This method only supports start_method=spawn (got: {start_method}).\n" + "To use a different start_method use:\n\t\t" + " torch.multiprocessing.start_processes(...)" + ) + warnings.warn(msg, FutureWarning, stacklevel=2) + return start_processes(fn, args, nprocs, join, daemon, start_method="spawn") diff --git a/lib/python3.10/site-packages/torch/nested/__init__.py b/lib/python3.10/site-packages/torch/nested/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38baafa0cf951d7d994c2629a0332b2e1f2496a1 --- /dev/null +++ b/lib/python3.10/site-packages/torch/nested/__init__.py @@ -0,0 +1,465 @@ +# mypy: allow-untyped-defs +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import SymInt, Tensor +from torch._C import _add_docstr, _nested # type: ignore[attr-defined] + +from torch.types import _device as Device, _dtype as DType + +__all__ = [ + "to_padded_tensor", + "as_nested_tensor", + "nested_tensor", + "nested_tensor_from_jagged", + "narrow", + "masked_select", +] + +# Nested Tensor constructor functions + + +def as_nested_tensor( + ts: Union[Tensor, List[Tensor], Tuple[Tensor, ...]], + dtype: Optional[DType] = None, + device: Optional[Device] = None, + layout=None +) -> Tensor: + r""" + Constructs a nested tensor preserving autograd history from a tensor or a list / tuple of + tensors. + + If a nested tensor is passed, it will be returned directly unless the device / dtype / layout + differ. Note that converting device / dtype will result in a copy, while converting layout + is not currently supported by this function. + + If a non-nested tensor is passed, it is treated as a batch of constituents of consistent size. + A copy will be incurred if the passed device / dtype differ from those of the input OR if + the input is non-contiguous. Otherwise, the input's storage will be used directly. + + If a tensor list is provided, tensors in the list are always copied during construction of + the nested tensor. + + Args: + ts (Tensor or List[Tensor] or Tuple[Tensor]): a tensor to treat as a nested tensor OR a + list / tuple of tensors with the same ndim + + Keyword arguments: + dtype (:class:`torch.dtype`, optional): the desired type of returned nested tensor. + Default: if None, same :class:`torch.dtype` as leftmost tensor in the list. + device (:class:`torch.device`, optional): the desired device of returned nested tensor. + Default: if None, same :class:`torch.device` as leftmost tensor in the list + layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor. + Only strided and jagged layouts are supported. Default: if None, the strided layout. 
+ + Example:: + + >>> a = torch.arange(3, dtype=torch.float, requires_grad=True) + >>> b = torch.arange(5, dtype=torch.float, requires_grad=True) + >>> nt = torch.nested.as_nested_tensor([a, b]) + >>> nt.is_leaf + False + >>> fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)]) + >>> nt.backward(fake_grad) + >>> a.grad + tensor([1., 1., 1.]) + >>> b.grad + tensor([0., 0., 0., 0., 0.]) + >>> c = torch.randn(3, 5, requires_grad=True) + >>> nt2 = torch.nested.as_nested_tensor(c) + """ + is_tensor_list = isinstance(ts, (list, tuple)) and all(isinstance(t, Tensor) for t in ts) + if not isinstance(ts, Tensor) and not is_tensor_list: + raise TypeError( + "as_nested_tensor(): Expected first argument to be a tensor or a list / tuple of tensors " + ) + # convert tuple -> list if needed + if is_tensor_list and not isinstance(ts, list): + ts = list(ts) + + if isinstance(ts, Tensor) and ts.dim() < 2: + raise RuntimeError("as_nested_tensor(): Expected tensor argument to have dim() > 1") + + if isinstance(ts, Tensor) and ts.is_nested: + if layout == ts.layout: + # return input directly or input copied to device / dtype + return ts.to(device=device, dtype=dtype) + else: + # TODO: Just use nt.to(layout=layout) when it exists. + raise RuntimeError( + "as_nested_tensor(): Converting between nested tensor layouts is not supported") + + if layout is None: + layout = torch.strided + if layout == torch.strided: + if isinstance(ts, Tensor): + # contiguous() might be necessary to get flattened view. + # we could probably be more precise about when to do this as an optimization + buffer = ts.contiguous().view(-1).to(device=device, dtype=dtype) + nested_sizes = torch.tensor([t.shape for t in ts]) + return torch._nested_view_from_buffer( + buffer, + nested_sizes, + *torch._nested_compute_contiguous_strides_offsets(nested_sizes)) + else: + assert isinstance(ts, list) + return torch._nested_tensor_from_tensor_list(ts, dtype, None, device, None) + elif layout == torch.jagged: + if isinstance(ts, Tensor): + if device is None: + device = ts.device + + # contiguous() might be necessary to get flattened view. + # we could probably be more precise about when to do this as an optimization + values = ts.contiguous().flatten(0, 1).to(device=device, dtype=dtype) + batch_size = ts.shape[0] + seq_len = ts.shape[1] + offsets = torch.arange(0, batch_size * seq_len + 1, seq_len, + device=device, dtype=torch.int64) + + from torch.nested._internal.nested_tensor import nested_view_from_values_offsets + + return nested_view_from_values_offsets( + values, offsets, min_seqlen=seq_len, max_seqlen=seq_len + ) + else: + from torch.nested._internal.nested_tensor import jagged_from_list + + assert isinstance(ts, list) + nt, _ = jagged_from_list(ts, offsets=None, device=device, dtype=dtype) + return nt + else: + raise RuntimeError(f"Specified layout is unsupported for nested tensors: {layout}") + + +# Note: This not only adds doc strings for the nested ops, but +# also connects the torch.nested Python namespace to the torch._C._nested builtins. + +to_padded_tensor = _add_docstr( + _nested.nested_to_padded_tensor, + r""" +to_padded_tensor(input, padding, output_size=None, out=None) -> Tensor + +Returns a new (non-nested) Tensor by padding the :attr:`input` nested tensor. +The leading entries will be filled with the nested data, +while the trailing entries will be padded. + +.. 
warning:: + + :func:`to_padded_tensor` always copies the underlying data, + since the nested and the non-nested tensors differ in memory layout. + +Args: + padding (float): The padding value for the trailing entries. + +Keyword args: + output_size (Tuple[int]): The size of the output tensor. + If given, it must be large enough to contain all nested data; + else, will infer by taking the max size of each nested sub-tensor along each dimension. + out (Tensor, optional): the output tensor. + +Example:: + + >>> nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))]) + nested_tensor([ + tensor([[ 1.6862, -1.1282, 1.1031, 0.0464, -1.3276], + [-1.9967, -1.0054, 1.8972, 0.9174, -1.4995]]), + tensor([[-1.8546, -0.7194, -0.2918, -0.1846], + [ 0.2773, 0.8793, -0.5183, -0.6447], + [ 1.8009, 1.8468, -0.9832, -1.5272]]) + ]) + >>> pt_infer = torch.nested.to_padded_tensor(nt, 0.0) + tensor([[[ 1.6862, -1.1282, 1.1031, 0.0464, -1.3276], + [-1.9967, -1.0054, 1.8972, 0.9174, -1.4995], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]], + [[-1.8546, -0.7194, -0.2918, -0.1846, 0.0000], + [ 0.2773, 0.8793, -0.5183, -0.6447, 0.0000], + [ 1.8009, 1.8468, -0.9832, -1.5272, 0.0000]]]) + >>> pt_large = torch.nested.to_padded_tensor(nt, 1.0, (2, 4, 6)) + tensor([[[ 1.6862, -1.1282, 1.1031, 0.0464, -1.3276, 1.0000], + [-1.9967, -1.0054, 1.8972, 0.9174, -1.4995, 1.0000], + [ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], + [ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]], + [[-1.8546, -0.7194, -0.2918, -0.1846, 1.0000, 1.0000], + [ 0.2773, 0.8793, -0.5183, -0.6447, 1.0000, 1.0000], + [ 1.8009, 1.8468, -0.9832, -1.5272, 1.0000, 1.0000], + [ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]]]) + >>> pt_small = torch.nested.to_padded_tensor(nt, 2.0, (2, 2, 2)) + RuntimeError: Value in output_size is less than NestedTensor padded size. Truncation is not supported. + +""", +) + +def nested_tensor(tensor_list, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False) -> Tensor: + r""" +Constructs a nested tensor with no autograd history (also known as a "leaf tensor", see +:ref:`Autograd mechanics `) from :attr:`tensor_list` a list of tensors. + +Args: + tensor_list (List[array_like]): a list of tensors, or anything that can be passed to torch.tensor, + where each element of the list has the same dimensionality. + +Keyword arguments: + dtype (:class:`torch.dtype`, optional): the desired type of returned nested tensor. + Default: if None, same :class:`torch.dtype` as leftmost tensor in the list. + layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor. + Only strided and jagged layouts are supported. Default: if None, the strided layout. + device (:class:`torch.device`, optional): the desired device of returned nested tensor. + Default: if None, same :class:`torch.device` as leftmost tensor in the list + requires_grad (bool, optional): If autograd should record operations on the + returned nested tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned nested tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. 
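A jagged-layout construction sketch (illustrative; the ragged shapes below are
arbitrary)::

    >>> a = torch.randn(2, 4)
    >>> b = torch.randn(3, 4)
    >>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
    >>> nt.offsets()
    tensor([0, 2, 5])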
+ +Example:: + + >>> a = torch.arange(3, dtype=torch.float, requires_grad=True) + >>> b = torch.arange(5, dtype=torch.float, requires_grad=True) + >>> nt = torch.nested.nested_tensor([a, b], requires_grad=True) + >>> nt.is_leaf + True + """ + if layout is None: + layout = torch.strided + if layout == torch.strided: + return _nested.nested_tensor( + tensor_list, + dtype=dtype, + device=device, + requires_grad=requires_grad, + pin_memory=pin_memory) + elif layout == torch.jagged: + # Need to wrap lists of scalars as tensors + list_of_tensors = [t if isinstance(t, Tensor) else torch.as_tensor(t) for t in tensor_list] + + from torch.nested._internal.nested_tensor import jagged_from_list + + with torch.no_grad(): + nt, _ = jagged_from_list(list_of_tensors, offsets=None, device=device, dtype=dtype) + + nt.requires_grad_(requires_grad) + if pin_memory: + nt = nt.pin_memory() # type: ignore[assignment] + + return nt + else: + raise RuntimeError(f"Specified layout is unsupported for nested tensors: {layout}") + + +def narrow(tensor: Tensor, dim: int, start: Union[int, Tensor], length: Union[int, Tensor], layout=torch.strided) -> Tensor: + r""" +Constructs a nested tensor (which might be a view) from :attr:`tensor`, a strided tensor. This follows +similar semantics to torch.Tensor.narrow, where in the :attr:`dim`-th dimension the new nested tensor +shows only the elements in the interval `[start, start+length)`. As nested representations +allow for a different `start` and `length` at each 'row' of that dimension, :attr:`start` and :attr:`length` +can also be tensors of shape `tensor.shape[0]`. + +There's some differences depending on the layout you use for the nested tensor. If using strided layout, +torch.narrow will do a copy of the narrowed data into a contiguous NT with strided layout, while +jagged layout narrow() will create a non-contiguous view of your original strided tensor. This particular +representation is really useful for representing kv-caches in Transformer models, as specialized +SDPA kernels can deal with format easily, resulting in performance improvements. + + +Args: + tensor (:class:`torch.Tensor`): a strided tensor, which will be used as the underlying data + for the nested tensor if using the jagged layout or will be copied for the strided layout. + dim (int): the dimension where narrow will be applied. Only `dim=1` is supported for the + jagged layout, while strided supports all dim + start (Union[int, :class:`torch.Tensor`]): starting element for the narrow operation + length (Union[int, :class:`torch.Tensor`]): number of elements taken during the narrow op + +Keyword arguments: + layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor. + Only strided and jagged layouts are supported. Default: if None, the strided layout. 
+ +Example:: + + >>> starts = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64) + >>> lengths = torch.tensor([3, 2, 2, 1, 5], dtype=torch.int64) + >>> narrow_base = torch.randn(5, 10, 20) + >>> nt_narrowed = torch.nested.narrow(narrow_base, 1, starts, lengths, layout=torch.jagged) + >>> nt_narrowed.is_contiguous() + False + """ + if not isinstance(start, (int, SymInt, Tensor)): + raise RuntimeError("start must be an integer or a tensor") + + if not isinstance(length, (int, SymInt, Tensor)): + raise RuntimeError("length must be an integer or a tensor") + + if layout == torch.strided: + if isinstance(start, Tensor) or isinstance(length, Tensor): + raise RuntimeError("start and length must be integers for the strided layout NT impl") + # TODO: switch to as_nested_tensor(tensor) when it is available + nt = as_nested_tensor(torch.unbind(tensor), layout=torch.strided).narrow(dim, start, length) + elif layout == torch.jagged: + if dim != 1: + raise RuntimeError("jagged layout only supports dim=1") + + from torch.nested._internal.nested_tensor import jagged_from_tensor_and_lengths + + if isinstance(start, (int, SymInt)): + start = torch.tensor([start], device=tensor.device, dtype=torch.int64) + + if isinstance(length, (int, SymInt)): + length = torch.tensor([length], device=tensor.device, dtype=torch.int64) + + nt, _, _ = jagged_from_tensor_and_lengths(tensor, start, length) + else: + raise RuntimeError(f"Specified layout is unsupported for nested narrow: {layout}") + + return nt + + +def nested_tensor_from_jagged( + values: Tensor, + offsets: Optional[Tensor] = None, + lengths: Optional[Tensor] = None, + jagged_dim: Optional[int] = None, + min_seqlen: Optional[int] = None, + max_seqlen: Optional[int] = None, +) -> Tensor: + r""" +Constructs a jagged layout nested tensor from the given jagged components. The jagged layout +consists of a required values buffer with the jagged dimension packed into a single dimension. +The offsets / lengths metadata determines how this dimension is split into batch elements +and are expected to be allocated on the same device as the values buffer. + +Expected metadata formats: + * offsets: Indices within the packed dimension splitting it into heterogeneously-sized + batch elements. Example: [0, 2, 3, 6] indicates that a packed jagged dim of size 6 + should be conceptually split into batch elements of length [2, 1, 3]. Note that both the + beginning and ending offsets are required for kernel convenience (i.e. shape batch_size + 1). + * lengths: Lengths of the individual batch elements; shape == batch_size. Example: [2, 1, 3] + indicates that a packed jagged dim of size 6 should be conceptually split into batch + elements of length [2, 1, 3]. + +Note that it can be useful to provide both offsets and lengths. This describes a nested tensor +with "holes", where the offsets indicate the start position of each batch item and the length +specifies the total number of elements (see example below). + +The returned jagged layout nested tensor will be a view of the input values tensor. + +Args: + values (:class:`torch.Tensor`): The underlying buffer in the shape of + (sum_B(*), D_1, ..., D_N). The jagged dimension is packed into a single dimension, + with the offsets / lengths metadata used to distinguish batch elements. + offsets (optional :class:`torch.Tensor`): Offsets into the jagged dimension of shape B + 1. + lengths (optional :class:`torch.Tensor`): Lengths of the batch elements of shape B. 
+ jagged_dim (optional int): Indicates which dimension in values is the packed jagged + dimension. If None, this is set to dim=1 (i.e. the dimension immediately following + the batch dimension). Default: None + min_seqlen (optional int): If set, uses the specified value as the cached minimum sequence + length for the returned nested tensor. This can be a useful alternative to computing + this value on-demand, possibly avoiding a GPU -> CPU sync. Default: None + max_seqlen (optional int): If set, uses the specified value as the cached maximum sequence + length for the returned nested tensor. This can be a useful alternative to computing + this value on-demand, possibly avoiding a GPU -> CPU sync. Default: None + +Example:: + + >>> values = torch.randn(12, 5) + >>> offsets = torch.tensor([0, 3, 5, 6, 10, 12]) + >>> nt = nested_tensor_from_jagged(values, offsets) + >>> # 3D shape with the middle dimension jagged + >>> nt.shape + torch.Size([5, j2, 5]) + >>> # Length of each item in the batch: + >>> offsets.diff() + tensor([3, 2, 1, 4, 2]) + + >>> values = torch.randn(6, 5) + >>> offsets = torch.tensor([0, 2, 3, 6]) + >>> lengths = torch.tensor([1, 1, 2]) + >>> # NT with holes + >>> nt = nested_tensor_from_jagged(values, offsets, lengths) + >>> a, b, c = nt.unbind() + >>> # Batch item 1 consists of indices [0, 1) + >>> torch.equal(a, values[0:1, :]) + True + >>> # Batch item 2 consists of indices [2, 3) + >>> torch.equal(b, values[2:3, :]) + True + >>> # Batch item 3 consists of indices [3, 5) + >>> torch.equal(c, values[3:5, :]) + True + """ + from torch.fx._symbolic_trace import is_fx_tracing + if is_fx_tracing(): + raise RuntimeError( + "torch.nested.nested_tensor_from_jagged does not support tracing with fx.symbolic_trace. " + "Use fx.wrap to wrap the function that calls nested_tensor_from_jagged." + ) + + if offsets is None: + if lengths is None: + raise RuntimeError( + "nested_tensor_from_jagged(): At least one of offsets or lengths is required." + ) + else: + # TODO: Truly support offsets=None at some point? + # For now, just convert lengths -> offsets for kernel convenience + offsets = F.pad(lengths.cumsum(0), (1, 0)) + lengths = None + + if jagged_dim is None: + jagged_dim = 1 + + from torch.nested._internal.nested_tensor import nested_view_from_values_offsets_lengths + + return nested_view_from_values_offsets_lengths( + values, offsets, lengths, ragged_idx=jagged_dim, min_seqlen=min_seqlen, max_seqlen=max_seqlen) + +def masked_select(tensor: Tensor, mask: Tensor) -> Tensor: + r""" + Constructs a nested tensor given a strided tensor input and a strided mask, the resulting jagged layout nested tensor + will have values retain values where the mask is equal to True. The dimensionality of the mask is preserved and is + represented with the offsets, this is unlike :func:`masked_select` where the output is collapsed to a 1D tensor. + + Args: + tensor (:class:`torch.Tensor`): a strided tensor from which the jagged layout nested tensor is constructed from. 
+ mask (:class:`torch.Tensor`): a strided mask tensor which is applied to the tensor input + + Example:: + + >>> tensor = torch.randn(3, 3) + >>> mask = torch.tensor([[False, False, True], [True, False, True], [False, False, True]]) + >>> nt = torch.nested.masked_select(tensor, mask) + >>> nt.shape + torch.Size([3, j4]) + >>> # Length of each item in the batch: + >>> nt.offsets().diff() + tensor([1, 2, 1]) + + >>> tensor = torch.randn(6, 5) + >>> mask = torch.tensor([False]) + >>> nt = torch.nested.masked_select(tensor, mask) + >>> nt.shape + torch.Size([6, j5]) + >>> # Length of each item in the batch: + >>> nt.offsets().diff() + tensor([0, 0, 0, 0, 0, 0]) + """ + if tensor.layout != torch.strided: + raise RuntimeError( + f"torch.nested.masked_select requires a strided tensor, given {tensor.layout}" + ) + + if mask.layout != torch.strided: + raise RuntimeError( + f"torch.nested.masked_select requires a strided mask, given: {mask.layout}" + ) + res_values = tensor.masked_select(mask) + expanded_mask = mask.expand(tensor.shape) + res_lengths = expanded_mask.sum(dim=tensor.ndim - 1).view(-1) + + from torch.nested._internal.nested_tensor import ( + nested_view_from_values_offsets, + ) + + return nested_view_from_values_offsets( + values=res_values, + offsets=F.pad(res_lengths.cumsum(dim=0), (1, 0)), + ) diff --git a/lib/python3.10/site-packages/torch/nn/__init__.py b/lib/python3.10/site-packages/torch/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36c3ae609bb0871214138888d29501c803705598 --- /dev/null +++ b/lib/python3.10/site-packages/torch/nn/__init__.py @@ -0,0 +1,62 @@ +# mypy: allow-untyped-defs +from torch.nn.parameter import ( # usort: skip + Buffer as Buffer, + Parameter as Parameter, + UninitializedBuffer as UninitializedBuffer, + UninitializedParameter as UninitializedParameter, +) +from torch.nn.modules import * # usort: skip # noqa: F403 +from torch.nn import ( + attention as attention, + functional as functional, + init as init, + modules as modules, + parallel as parallel, + parameter as parameter, + utils as utils, +) +from torch.nn.parallel import DataParallel as DataParallel + + +def factory_kwargs(kwargs): + r"""Return a canonicalized dict of factory kwargs. + + Given kwargs, returns a canonicalized dict of factory kwargs that can be directly passed + to factory functions like torch.empty, or errors if unrecognized kwargs are present. + + This function makes it simple to write code like this:: + + class MyModule(nn.Module): + def __init__(self, **kwargs): + factory_kwargs = torch.nn.factory_kwargs(kwargs) + self.weight = Parameter(torch.empty(10, **factory_kwargs)) + + Why should you use this function instead of just passing `kwargs` along directly? + + 1. This function does error validation, so if there are unexpected kwargs we will + immediately report an error, instead of deferring it to the factory call + 2. This function supports a special `factory_kwargs` argument, which can be used to + explicitly specify a kwarg to be used for factory functions, in the event one of the + factory kwargs conflicts with an already existing argument in the signature (e.g. 
+ in the signature ``def f(dtype, **kwargs)``, you can specify ``dtype`` for factory + functions, as distinct from the dtype argument, by saying + ``f(dtype1, factory_kwargs={"dtype": dtype2})``) + """ + if kwargs is None: + return {} + simple_keys = {"device", "dtype", "memory_format"} + expected_keys = simple_keys | {"factory_kwargs"} + if not kwargs.keys() <= expected_keys: + raise TypeError(f"unexpected kwargs {kwargs.keys() - expected_keys}") + + # guarantee no input kwargs is untouched + r = dict(kwargs.get("factory_kwargs", {})) + for k in simple_keys: + if k in kwargs: + if k in r: + raise TypeError( + f"{k} specified twice, in **kwargs and in factory_kwargs" + ) + r[k] = kwargs[k] + + return r diff --git a/lib/python3.10/site-packages/torch/nn/_reduction.py b/lib/python3.10/site-packages/torch/nn/_reduction.py new file mode 100644 index 0000000000000000000000000000000000000000..93b00dc6feb43df50a95528ab2cb01d1fcac1609 --- /dev/null +++ b/lib/python3.10/site-packages/torch/nn/_reduction.py @@ -0,0 +1,60 @@ +import warnings +from typing import Optional + + +# NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h + + +def get_enum(reduction: str) -> int: + if reduction == "none": + ret = 0 + elif reduction == "mean": + ret = 1 + elif reduction == "elementwise_mean": + warnings.warn( + "reduction='elementwise_mean' is deprecated. " + "Please use reduction='mean' instead." + ) + ret = 1 + elif reduction == "sum": + ret = 2 + else: + ret = -1 # TODO: remove once JIT exceptions support control flow + raise ValueError(f"{reduction} is not a valid value for reduction") + return ret + + +# In order to support previous versions, accept boolean size_average and reduce +# and convert them into the new constants for now + + +# We use these functions in torch/legacy as well, in which case we'll silence the warning +def legacy_get_string( + size_average: Optional[bool], + reduce: Optional[bool], + emit_warning: bool = True, +) -> str: + warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead." + + if size_average is None: + size_average = True + if reduce is None: + reduce = True + + if size_average and reduce: + ret = "mean" + elif reduce: + ret = "sum" + else: + ret = "none" + if emit_warning: + warnings.warn(warning.format(ret)) + return ret + + +def legacy_get_enum( + size_average: Optional[bool], + reduce: Optional[bool], + emit_warning: bool = True, +) -> int: + return get_enum(legacy_get_string(size_average, reduce, emit_warning)) diff --git a/lib/python3.10/site-packages/torch/nn/common_types.py b/lib/python3.10/site-packages/torch/nn/common_types.py new file mode 100644 index 0000000000000000000000000000000000000000..74661d604c3e60427048b68d0b7cfa733366bd2b --- /dev/null +++ b/lib/python3.10/site-packages/torch/nn/common_types.py @@ -0,0 +1,44 @@ +from typing import Optional, Tuple, TypeVar, Union + +from torch import Tensor + + +# Create some useful type aliases + +# Template for arguments which can be supplied as a tuple, or which can be a scalar which PyTorch will internally +# broadcast to a tuple. +# Comes in several variants: A tuple of unknown size, and a fixed-size tuple for 1d, 2d, or 3d operations. 
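# Illustrative consumer-side sketch (not part of this file): a helper that
# accepts either an int or a 2-tuple, using _pair to broadcast ints, in the
# same way PyTorch's conv/pooling modules handle such arguments.
from torch.nn.common_types import _size_2_t
from torch.nn.modules.utils import _pair


def valid_conv_output_hw(in_hw: _size_2_t, kernel_size: _size_2_t) -> tuple:
    ih, iw = _pair(in_hw)
    kh, kw = _pair(kernel_size)
    # 'valid' convolution output size: input - kernel + 1 per spatial dim.
    return (ih - kh + 1, iw - kw + 1)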
+T = TypeVar("T") +_scalar_or_tuple_any_t = Union[T, Tuple[T, ...]] +_scalar_or_tuple_1_t = Union[T, Tuple[T]] +_scalar_or_tuple_2_t = Union[T, Tuple[T, T]] +_scalar_or_tuple_3_t = Union[T, Tuple[T, T, T]] +_scalar_or_tuple_4_t = Union[T, Tuple[T, T, T, T]] +_scalar_or_tuple_5_t = Union[T, Tuple[T, T, T, T, T]] +_scalar_or_tuple_6_t = Union[T, Tuple[T, T, T, T, T, T]] + +# For arguments which represent size parameters (eg, kernel size, padding) +_size_any_t = _scalar_or_tuple_any_t[int] +_size_1_t = _scalar_or_tuple_1_t[int] +_size_2_t = _scalar_or_tuple_2_t[int] +_size_3_t = _scalar_or_tuple_3_t[int] +_size_4_t = _scalar_or_tuple_4_t[int] +_size_5_t = _scalar_or_tuple_5_t[int] +_size_6_t = _scalar_or_tuple_6_t[int] + +# For arguments which represent optional size parameters (eg, adaptive pool parameters) +_size_any_opt_t = _scalar_or_tuple_any_t[Optional[int]] +_size_2_opt_t = _scalar_or_tuple_2_t[Optional[int]] +_size_3_opt_t = _scalar_or_tuple_3_t[Optional[int]] + +# For arguments that represent a ratio to adjust each dimension of an input with (eg, upsampling parameters) +_ratio_2_t = _scalar_or_tuple_2_t[float] +_ratio_3_t = _scalar_or_tuple_3_t[float] +_ratio_any_t = _scalar_or_tuple_any_t[float] + +_tensor_list_t = _scalar_or_tuple_any_t[Tensor] + +# For the return value of max pooling operations that may or may not return indices. +# With the proposed 'Literal' feature to Python typing, it might be possible to +# eventually eliminate this. +_maybe_indices_t = _scalar_or_tuple_2_t[Tensor] diff --git a/lib/python3.10/site-packages/torch/nn/cpp.py b/lib/python3.10/site-packages/torch/nn/cpp.py new file mode 100644 index 0000000000000000000000000000000000000000..98a61bfb7c428d024753ea3f6c4dc34fdb3bea08 --- /dev/null +++ b/lib/python3.10/site-packages/torch/nn/cpp.py @@ -0,0 +1,89 @@ +# mypy: allow-untyped-defs +"""Functionality for Python <-> C++ frontend inter-op.""" + +from torch import nn + + +class OrderedDictWrapper: + """A wrapper around a C++ OrderedDict. + + It dynamically evaluates the OrderedDict getter on a bound C++ module, such + that new changes on the C++ side are picked up. Otherwise accessing e.g. + ``cpp_module._parameters`` just once would get a frozen copy of the parameters + at the time of access. ``torch.nn.Module`` accesses ``_parameters`` et al. via ``self.__dict__`` + so using properties does not work. + """ + + def __init__(self, cpp_module, attr): + self.cpp_module = cpp_module + self.attr = attr + + @property + def cpp_dict(self): + return getattr(self.cpp_module, self.attr) + + # Magic methods cannot be assigned dynamically and bypass ``getattr``, so we + # must manually override them. + + def items(self): + return self.cpp_dict.items() + + def keys(self): + return self.cpp_dict.keys() + + def values(self): + return self.cpp_dict.values() + + def __iter__(self): + return self.cpp_dict.__iter__() + + def __len__(self): + return self.cpp_dict.__len__() + + def __contains__(self, key): + return self.cpp_dict.__contains__(key) + + def __getitem__(self, key): + return self.cpp_dict.__getitem__(key) + + +class ModuleWrapper(nn.Module): + """A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and delegates all access.""" + + def __init__(self, cpp_module): + # Assign before the super class constructor so ``self.training`` can be + # assigned to in the super class constructor. 
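
# Illustrative sketch (editor's note, not part of the torch sources in this diff):
# how the _size_2_t-style aliases from common_types.py above are typically paired
# with torch.nn.modules.utils._pair, which broadcasts a scalar to a tuple.
# normalize_kernel_size is a hypothetical helper used only for illustration.
from typing import Tuple

from torch.nn.common_types import _size_2_t
from torch.nn.modules.utils import _pair

def normalize_kernel_size(kernel_size: _size_2_t) -> Tuple[int, int]:
    # A scalar is broadcast to (k, k); a 2-tuple passes through unchanged.
    return _pair(kernel_size)

assert normalize_kernel_size(3) == (3, 3)
assert normalize_kernel_size((3, 5)) == (3, 5)
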
+ self.cpp_module = cpp_module + super().__init__() + self._parameters = OrderedDictWrapper(cpp_module, "_parameters") # type: ignore[assignment] + self._buffers: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_buffers") # type: ignore[assignment] + self._modules: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_modules") # type: ignore[assignment] + for attr in dir(cpp_module): + # Skip magic methods and the three attributes above. + if not attr.startswith("_"): + setattr(self, attr, getattr(self.cpp_module, attr)) + + def _apply(self, fn, recurse=True): + for param in self.parameters(): + # Tensors stored in modules are graph leaves, and we don't + # want to create copy nodes, so we have to unpack the data. + param.data = fn(param.data) + if param._grad is not None: + param._grad.data = fn(param._grad.data) + + for buf in self.buffers(): + buf.data = fn(buf.data) + + return self + + # nn.Module defines training as a boolean + @property # type: ignore[override] + def training(self): + return self.cpp_module.training + + @training.setter + def training(self, mode): + self.cpp_module.train(mode) + + def __repr__(self): + return self.cpp_module.__repr__() diff --git a/lib/python3.10/site-packages/torch/nn/functional.py b/lib/python3.10/site-packages/torch/nn/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..3072ac0fef0f299fbb44a030bcc19a253b53e0e0 --- /dev/null +++ b/lib/python3.10/site-packages/torch/nn/functional.py @@ -0,0 +1,6290 @@ +"""Functional interface.""" + +import importlib +import math +import warnings +from typing import Callable, List, Optional, Tuple, TYPE_CHECKING, Union + +import torch +from torch import _VF, sym_int as _sym_int, Tensor +from torch._C import _add_docstr, _infer_size +from torch._jit_internal import ( + _overload, + boolean_dispatch, + BroadcastingList1, + BroadcastingList2, + BroadcastingList3, +) +from torch._torch_docs import reproducibility_notes, sparse_support_notes, tf32_notes +from torch.nn import _reduction as _Reduction, grad # noqa: F401 +from torch.nn.modules.utils import _list_with_default, _pair, _single, _triple +from torch.overrides import ( + handle_torch_function, + has_torch_function, + has_torch_function_unary, + has_torch_function_variadic, +) + + +if TYPE_CHECKING: + from torch.types import _dtype as DType +else: + # The JIT doesn't understand Union, nor torch.dtype here + DType = int + +try: + import numpy as np +except ModuleNotFoundError: + np = None + + +conv1d = _add_docstr( + torch.conv1d, + r""" +conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor + +Applies a 1D convolution over an input signal composed of several input +planes. + +{tf32_note} + +See :class:`~torch.nn.Conv1d` for details and output shape. + +Note: + {cudnn_reproducibility_note} + +Note: + This operator supports complex data types i.e. ``complex32, complex64, complex128``. +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" + +Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)` + weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kW)` + bias: optional bias of shape :math:`(\text{out\_channels})`. Default: ``None`` + stride: the stride of the convolving kernel. Can be a single number or + a one-element tuple `(sW,)`. Default: 1 + padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'}, + single number or a one-element tuple `(padW,)`. 
Default: 0 + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the same shape as the input. However, this mode + doesn't support any stride values other than 1. + + .. warning:: + For ``padding='same'``, if the ``weight`` is even-length and + ``dilation`` is odd in any dimension, a full :func:`pad` operation + may be needed internally. Lowering performance. + dilation: the spacing between kernel elements. Can be a single number or + a one-element tuple `(dW,)`. Default: 1 + groups: split input into groups, :math:`\text{in\_channels}` should be divisible by + the number of groups. Default: 1 + +Examples:: + + >>> inputs = torch.randn(33, 16, 30) + >>> filters = torch.randn(20, 16, 5) + >>> F.conv1d(inputs, filters) +""", +) + +conv2d = _add_docstr( + torch.conv2d, + r""" +conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor + +Applies a 2D convolution over an input image composed of several input +planes. + +{tf32_note} + +See :class:`~torch.nn.Conv2d` for details and output shape. + +Note: + {cudnn_reproducibility_note} + +Note: + This operator supports complex data types i.e. ``complex32, complex64, complex128``. +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" + +Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` + weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)` + bias: optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None`` + stride: the stride of the convolving kernel. Can be a single number or a + tuple `(sH, sW)`. Default: 1 + padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'}, + single number or a tuple `(padH, padW)`. Default: 0 + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the same shape as the input. However, this mode + doesn't support any stride values other than 1. + + .. warning:: + For ``padding='same'``, if the ``weight`` is even-length and + ``dilation`` is odd in any dimension, a full :func:`pad` operation + may be needed internally. Lowering performance. + + dilation: the spacing between kernel elements. Can be a single number or + a tuple `(dH, dW)`. Default: 1 + groups: split input into groups, both :math:`\text{in\_channels}` and :math:`\text{out\_channels}` + should be divisible by the number of groups. Default: 1 + +Examples:: + + >>> # With square kernels and equal stride + >>> filters = torch.randn(8, 4, 3, 3) + >>> inputs = torch.randn(1, 4, 5, 5) + >>> F.conv2d(inputs, filters, padding=1) +""", +) # noqa: E501 + +conv3d = _add_docstr( + torch.conv3d, + r""" +conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor + +Applies a 3D convolution over an input image composed of several input +planes. + +{tf32_note} + +See :class:`~torch.nn.Conv3d` for details and output shape. + +Note: + {cudnn_reproducibility_note} + +Note: + This operator supports complex data types i.e. ``complex32, complex64, complex128``. +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" + +Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)` + weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kT , kH , kW)` + bias: optional bias tensor of shape :math:`(\text{out\_channels})`. 
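
# Illustrative sketch (editor's note, not part of the torch sources in this diff):
# the output length of F.conv1d / F.conv2d follows the usual convolution formula
#   out = floor((in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1,
# which the shapes below illustrate.
import torch
import torch.nn.functional as F

x1 = torch.randn(33, 16, 30)        # (minibatch, in_channels, iW)
w1 = torch.randn(20, 16, 5)         # (out_channels, in_channels/groups, kW)
assert F.conv1d(x1, w1).shape == (33, 20, 26)            # (30 - 5) // 1 + 1 = 26

x2 = torch.randn(1, 4, 5, 5)
w2 = torch.randn(8, 4, 3, 3)
assert F.conv2d(x2, w2, padding=1).shape == (1, 8, 5, 5)        # same-sized output
assert F.conv2d(x2, w2, padding="same").shape == (1, 8, 5, 5)   # string padding, stride 1 only
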
Default: None + stride: the stride of the convolving kernel. Can be a single number or a + tuple `(sT, sH, sW)`. Default: 1 + padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'}, + single number or a tuple `(padT, padH, padW)`. Default: 0 + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the same shape as the input. However, this mode + doesn't support any stride values other than 1. + + .. warning:: + For ``padding='same'``, if the ``weight`` is even-length and + ``dilation`` is odd in any dimension, a full :func:`pad` operation + may be needed internally. Lowering performance. + + dilation: the spacing between kernel elements. Can be a single number or + a tuple `(dT, dH, dW)`. Default: 1 + groups: split input into groups, :math:`\text{in\_channels}` should be divisible by + the number of groups. Default: 1 + +Examples:: + + >>> filters = torch.randn(33, 16, 3, 3, 3) + >>> inputs = torch.randn(20, 16, 50, 10, 20) + >>> F.conv3d(inputs, filters) +""", +) # noqa: E501 + +conv_transpose1d = _add_docstr( + torch.conv_transpose1d, + r""" +conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor + +Applies a 1D transposed convolution operator over an input signal +composed of several input planes, sometimes also called "deconvolution". + +{tf32_note} + +See :class:`~torch.nn.ConvTranspose1d` for details and output shape. + +Note: + {cudnn_reproducibility_note} +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" + +Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)` + weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kW)` + bias: optional bias of shape :math:`(\text{out\_channels})`. Default: None + stride: the stride of the convolving kernel. Can be a single number or a + tuple ``(sW,)``. Default: 1 + padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both + sides of each dimension in the input. Can be a single number or a tuple + ``(padW,)``. Default: 0 + output_padding: additional size added to one side of each dimension in the + output shape. Can be a single number or a tuple ``(out_padW)``. Default: 0 + groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the + number of groups. Default: 1 + dilation: the spacing between kernel elements. Can be a single number or + a tuple ``(dW,)``. Default: 1 + +Examples:: + + >>> inputs = torch.randn(20, 16, 50) + >>> weights = torch.randn(16, 33, 5) + >>> F.conv_transpose1d(inputs, weights) +""", +) + +conv_transpose2d = _add_docstr( + torch.conv_transpose2d, + r""" +conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor + +Applies a 2D transposed convolution operator over an input image +composed of several input planes, sometimes also called "deconvolution". + +{tf32_note} + +See :class:`~torch.nn.ConvTranspose2d` for details and output shape. + +Note: + {cudnn_reproducibility_note} +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" + +Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` + weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kH , kW)` + bias: optional bias of shape :math:`(\text{out\_channels})`. Default: None + stride: the stride of the convolving kernel. 
Can be a single number or a + tuple ``(sH, sW)``. Default: 1 + padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both + sides of each dimension in the input. Can be a single number or a tuple + ``(padH, padW)``. Default: 0 + output_padding: additional size added to one side of each dimension in the + output shape. Can be a single number or a tuple ``(out_padH, out_padW)``. + Default: 0 + groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the + number of groups. Default: 1 + dilation: the spacing between kernel elements. Can be a single number or + a tuple ``(dH, dW)``. Default: 1 + +Examples:: + + >>> # With square kernels and equal stride + >>> inputs = torch.randn(1, 4, 5, 5) + >>> weights = torch.randn(4, 8, 3, 3) + >>> F.conv_transpose2d(inputs, weights, padding=1) +""", +) # noqa: E501 + +conv_transpose3d = _add_docstr( + torch.conv_transpose3d, + r""" +conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor + +Applies a 3D transposed convolution operator over an input image +composed of several input planes, sometimes also called "deconvolution" + +{tf32_note} + +See :class:`~torch.nn.ConvTranspose3d` for details and output shape. + +Note: + {cudnn_reproducibility_note} +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" + +Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)` + weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kT , kH , kW)` + bias: optional bias of shape :math:`(\text{out\_channels})`. Default: None + stride: the stride of the convolving kernel. Can be a single number or a + tuple ``(sT, sH, sW)``. Default: 1 + padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both + sides of each dimension in the input. Can be a single number or a tuple + ``(padT, padH, padW)``. Default: 0 + output_padding: additional size added to one side of each dimension in the + output shape. Can be a single number or a tuple + ``(out_padT, out_padH, out_padW)``. Default: 0 + groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the + number of groups. Default: 1 + dilation: the spacing between kernel elements. Can be a single number or + a tuple `(dT, dH, dW)`. Default: 1 + +Examples:: + + >>> inputs = torch.randn(20, 16, 50, 10, 20) + >>> weights = torch.randn(16, 33, 3, 3, 3) + >>> F.conv_transpose3d(inputs, weights) +""", +) # noqa: E501 + +conv_tbc = _add_docstr( + torch.conv_tbc, + r""" +Applies a 1-dimensional sequence convolution over an input sequence. +Input and output dimensions are (Time, Batch, Channels) - hence TBC. + +Args: + input: input tensor of shape :math:`(\text{sequence length} \times batch \times \text{in\_channels})` + weight: filter of shape (:math:`\text{kernel width} \times \text{in\_channels} \times \text{out\_channels}`) + bias: bias of shape (:math:`\text{out\_channels}`) + pad: number of timesteps to pad. Default: 0 +""", +) + + +# Pooling +avg_pool1d = _add_docstr( + torch.avg_pool1d, + r""" +avg_pool1d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True) -> Tensor + +Applies a 1D average pooling over an input signal composed of several +input planes. + +See :class:`~torch.nn.AvgPool1d` for details and output shape. 
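
# Illustrative sketch (editor's note, not part of the torch sources in this diff):
# conv_transpose2d inverts the conv2d shape arithmetic, so with stride 2 it is the
# standard learned-upsampling building block:
#   out = (in - 1)*stride - 2*padding + dilation*(kernel_size - 1) + output_padding + 1
import torch
import torch.nn.functional as F

x = torch.randn(1, 16, 8, 8)
w = torch.randn(16, 33, 4, 4)       # note the (in_channels, out_channels/groups, kH, kW) layout
y = F.conv_transpose2d(x, w, stride=2, padding=1)
assert y.shape == (1, 33, 16, 16)   # (8 - 1)*2 - 2*1 + (4 - 1) + 0 + 1 = 16
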
+ +Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)` + kernel_size: the size of the window. Can be a single number or a + tuple `(kW,)` + stride: the stride of the window. Can be a single number or a tuple + `(sW,)`. Default: :attr:`kernel_size` + padding: implicit zero paddings on both sides of the input. Can be a + single number or a tuple `(padW,)`. Default: 0 + ceil_mode: when True, will use `ceil` instead of `floor` to compute the + output shape. Default: ``False`` + count_include_pad: when True, will include the zero-padding in the + averaging calculation. Default: ``True`` + +Examples:: + + >>> # pool of square window of size=3, stride=2 + >>> input = torch.tensor([[[1, 2, 3, 4, 5, 6, 7]]], dtype=torch.float32) + >>> F.avg_pool1d(input, kernel_size=3, stride=2) + tensor([[[ 2., 4., 6.]]]) + +""", +) + + +avg_pool2d = _add_docstr( + torch._C._nn.avg_pool2d, + r""" +avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor + +Applies 2D average-pooling operation in :math:`kH \times kW` regions by step size +:math:`sH \times sW` steps. The number of output features is equal to the number of +input planes. + +See :class:`~torch.nn.AvgPool2d` for details and output shape. + +Args: + input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` + kernel_size: size of the pooling region. Can be a single number or a + tuple `(kH, kW)` + stride: stride of the pooling operation. Can be a single number or a + tuple `(sH, sW)`. Default: :attr:`kernel_size` + padding: implicit zero paddings on both sides of the input. Can be a + single number or a tuple `(padH, padW)`. Default: 0 + ceil_mode: when True, will use `ceil` instead of `floor` in the formula + to compute the output shape. Default: ``False`` + count_include_pad: when True, will include the zero-padding in the + averaging calculation. Default: ``True`` + divisor_override: if specified, it will be used as divisor, otherwise + size of the pooling region will be used. Default: None +""", +) + +avg_pool3d = _add_docstr( + torch._C._nn.avg_pool3d, + r""" +avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor + +Applies 3D average-pooling operation in :math:`kT \times kH \times kW` regions by step +size :math:`sT \times sH \times sW` steps. The number of output features is equal to +:math:`\lfloor\frac{\text{input planes}}{sT}\rfloor`. + +See :class:`~torch.nn.AvgPool3d` for details and output shape. + +Args: + input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iT \times iH , iW)` + kernel_size: size of the pooling region. Can be a single number or a + tuple `(kT, kH, kW)` + stride: stride of the pooling operation. Can be a single number or a + tuple `(sT, sH, sW)`. Default: :attr:`kernel_size` + padding: implicit zero paddings on both sides of the input. Can be a + single number or a tuple `(padT, padH, padW)`, Default: 0 + ceil_mode: when True, will use `ceil` instead of `floor` in the formula + to compute the output shape + count_include_pad: when True, will include the zero-padding in the + averaging calculation + divisor_override: if specified, it will be used as divisor, otherwise + size of the pooling region will be used. 
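
# Illustrative sketch (editor's note, not part of the torch sources in this diff):
# how count_include_pad and divisor_override change the divisor used by avg_pool2d.
import torch
import torch.nn.functional as F

x = torch.ones(1, 1, 2, 2)
# With padding=1 each 2x2 window covers one real element and three padded zeros.
y_incl = F.avg_pool2d(x, kernel_size=2, stride=2, padding=1, count_include_pad=True)
y_excl = F.avg_pool2d(x, kernel_size=2, stride=2, padding=1, count_include_pad=False)
assert torch.allclose(y_incl, torch.full((1, 1, 2, 2), 0.25))   # divide by 4, zeros included
assert torch.allclose(y_excl, torch.ones(1, 1, 2, 2))           # divide by the 1 real element
# divisor_override forces a fixed divisor regardless of the window contents.
y_div = F.avg_pool2d(x, kernel_size=2, stride=2, padding=1, divisor_override=2)
assert torch.allclose(y_div, torch.full((1, 1, 2, 2), 0.5))
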
Default: None +""", +) + + +def fractional_max_pool2d_with_indices( + input: Tensor, + kernel_size: BroadcastingList2[int], + output_size: Optional[BroadcastingList2[int]] = None, + output_ratio: Optional[BroadcastingList2[float]] = None, + return_indices: bool = False, + _random_samples: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor]: # noqa: D400 + r""" + fractional_max_pool2d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None) + + Applies 2D fractional max pooling over an input signal composed of several input planes. + + Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham + + The max-pooling operation is applied in :math:`kH \times kW` regions by a stochastic + step size determined by the target output size. + The number of output features is equal to the number of input planes. + + Args: + kernel_size: the size of the window to take a max over. + Can be a single number :math:`k` (for a square kernel of :math:`k \times k`) + or a tuple `(kH, kW)` + output_size: the target output size of the image of the form :math:`oH \times oW`. + Can be a tuple `(oH, oW)` or a single number :math:`oH` for a square image :math:`oH \times oH` + output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given. + This has to be a number or tuple in the range (0, 1) + return_indices: if ``True``, will return the indices along with the outputs. + Useful to pass to :func:`~torch.nn.functional.max_unpool2d`. + + Examples:: + >>> input = torch.randn(20, 16, 50, 32) + >>> # pool of square window of size=3, and target output size 13x12 + >>> F.fractional_max_pool2d(input, 3, output_size=(13, 12)) + >>> # pool of square window and target output size being half of input image size + >>> F.fractional_max_pool2d(input, 3, output_ratio=(0.5, 0.5)) + + .. _Fractional MaxPooling: + http://arxiv.org/abs/1412.6071 + """ + if has_torch_function_variadic(input, _random_samples): + return handle_torch_function( + fractional_max_pool2d_with_indices, + (input, _random_samples), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) + if output_size is None and output_ratio is None: + raise ValueError( + "fractional_max_pool2d requires specifying either an output_size or an output_ratio" + ) + if output_size is None: + assert output_ratio is not None + if len(output_ratio) > 2: + raise ValueError( + "fractional_max_pool2d requires output_ratio to either be a single Int or tuple of Ints." 
+ ) + _output_ratio = _pair(output_ratio) + output_size = [ + int(input.size(-2) * _output_ratio[0]), + int(input.size(-1) * _output_ratio[1]), + ] + + if _random_samples is None: + n_batch = 1 if input.dim() == 3 else input.size(0) + _random_samples = torch.rand( + n_batch, input.size(-3), 2, dtype=input.dtype, device=input.device + ) + return torch._C._nn.fractional_max_pool2d( + input, kernel_size, output_size, _random_samples + ) + + +def _fractional_max_pool2d( + input: Tensor, + kernel_size: BroadcastingList2[int], + output_size: Optional[BroadcastingList2[int]] = None, + output_ratio: Optional[BroadcastingList2[float]] = None, + return_indices: bool = False, + _random_samples: Optional[Tensor] = None, +) -> Tensor: + if has_torch_function_variadic(input, _random_samples): + return handle_torch_function( + fractional_max_pool2d, + (input, _random_samples), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) + return fractional_max_pool2d_with_indices( + input, kernel_size, output_size, output_ratio, return_indices, _random_samples + )[0] + + +fractional_max_pool2d = boolean_dispatch( + arg_name="return_indices", + arg_index=4, + default=False, + if_true=fractional_max_pool2d_with_indices, + if_false=_fractional_max_pool2d, + module_name=__name__, + func_name="fractional_max_pool2d", +) + + +def fractional_max_pool3d_with_indices( + input: Tensor, + kernel_size: BroadcastingList3[int], + output_size: Optional[BroadcastingList3[int]] = None, + output_ratio: Optional[BroadcastingList3[float]] = None, + return_indices: bool = False, + _random_samples: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor]: # noqa: D400 + r""" + fractional_max_pool3d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None) + + Applies 3D fractional max pooling over an input signal composed of several input planes. + + Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham + + The max-pooling operation is applied in :math:`kT \times kH \times kW` regions by a stochastic + step size determined by the target output size. + The number of output features is equal to the number of input planes. + + Args: + kernel_size: the size of the window to take a max over. + Can be a single number :math:`k` (for a square kernel of :math:`k \times k \times k`) + or a tuple `(kT, kH, kW)` + output_size: the target output size of the form :math:`oT \times oH \times oW`. + Can be a tuple `(oT, oH, oW)` or a single number :math:`oH` for a cubic output + :math:`oH \times oH \times oH` + output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given. + This has to be a number or tuple in the range (0, 1) + return_indices: if ``True``, will return the indices along with the outputs. + Useful to pass to :func:`~torch.nn.functional.max_unpool3d`. + + Shape: + - Input: :math:`(N, C, T_{in}, H_{in}, W_{in})` or :math:`(C, T_{in}, H_{in}, W_{in})`. 
+ - Output: :math:`(N, C, T_{out}, H_{out}, W_{out})` or :math:`(C, T_{out}, H_{out}, W_{out})`, where + :math:`(T_{out}, H_{out}, W_{out})=\text{output\_size}` or + :math:`(T_{out}, H_{out}, W_{out})=\text{output\_ratio} \times (T_{in}, H_{in}, W_{in})` + + Examples:: + >>> input = torch.randn(20, 16, 50, 32, 16) + >>> # pool of cubic window of size=3, and target output size 13x12x11 + >>> F.fractional_max_pool3d(input, 3, output_size=(13, 12, 11)) + >>> # pool of cubic window and target output size being half of input size + >>> F.fractional_max_pool3d(input, 3, output_ratio=(0.5, 0.5, 0.5)) + + .. _Fractional MaxPooling: + http://arxiv.org/abs/1412.6071 + """ + if has_torch_function_variadic(input, _random_samples): + return handle_torch_function( + fractional_max_pool3d_with_indices, + (input, _random_samples), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) + if output_size is None and output_ratio is None: + raise ValueError( + "fractional_max_pool3d requires specifying either an output_size or an output_ratio" + ) + if output_size is None: + assert output_ratio is not None + _output_ratio = _triple(output_ratio) + output_size = [ + int(input.size(-3) * _output_ratio[0]), + int(input.size(-2) * _output_ratio[1]), + int(input.size(-1) * _output_ratio[2]), + ] + + if _random_samples is None: + n_batch = 1 if input.dim() == 4 else input.size(0) + _random_samples = torch.rand( + n_batch, input.size(-4), 3, dtype=input.dtype, device=input.device + ) + return torch._C._nn.fractional_max_pool3d( + input, kernel_size, output_size, _random_samples + ) + + +def _fractional_max_pool3d( + input: Tensor, + kernel_size: BroadcastingList3[int], + output_size: Optional[BroadcastingList3[int]] = None, + output_ratio: Optional[BroadcastingList3[float]] = None, + return_indices: bool = False, + _random_samples: Optional[Tensor] = None, +) -> Tensor: + if has_torch_function_variadic(input, _random_samples): + return handle_torch_function( + fractional_max_pool3d, + (input, _random_samples), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) + return fractional_max_pool3d_with_indices( + input, kernel_size, output_size, output_ratio, return_indices, _random_samples + )[0] + + +fractional_max_pool3d = boolean_dispatch( + arg_name="return_indices", + arg_index=4, + default=False, + if_true=fractional_max_pool3d_with_indices, + if_false=_fractional_max_pool3d, + module_name=__name__, + func_name="fractional_max_pool3d", +) + + +def max_pool1d_with_indices( + input: Tensor, + kernel_size: BroadcastingList1[int], + stride: Optional[BroadcastingList1[int]] = None, + padding: BroadcastingList1[int] = 0, + dilation: BroadcastingList1[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False, +) -> Tuple[Tensor, Tensor]: # noqa: D400 + r""" + max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False) + + Applies a 1D max pooling over an input signal composed of several input + planes. + + .. note:: + The order of :attr:`ceil_mode` and :attr:`return_indices` is different from + what seen in :class:`~torch.nn.MaxPool1d`, and will change in a future release. + + See :class:`~torch.nn.MaxPool1d` for details. + + Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`, minibatch dim optional. 
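
# Illustrative sketch (editor's note, not part of the torch sources in this diff):
# a simplified stand-in for the boolean_dispatch pattern used above, where the value
# of `return_indices` selects between the *_with_indices variant and the plain
# variant at call time. The real helper lives in torch._jit_internal, also handles
# the flag being passed positionally, and wires up metadata needed by TorchScript.
def _boolean_dispatch_sketch(arg_name, default, if_true, if_false):
    def dispatched(*args, **kwargs):
        flag = kwargs.get(arg_name, default)
        return if_true(*args, **kwargs) if flag else if_false(*args, **kwargs)
    return dispatched

def _pool_with_indices(x, return_indices=False):   # hypothetical stand-ins
    return ("values", "indices")

def _pool(x, return_indices=False):
    return "values"

pool = _boolean_dispatch_sketch("return_indices", False, _pool_with_indices, _pool)
assert pool([1, 2, 3]) == "values"
assert pool([1, 2, 3], return_indices=True) == ("values", "indices")
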
+ kernel_size: the size of the window. Can be a single number or a + tuple `(kW,)` + stride: the stride of the window. Can be a single number or a tuple + `(sW,)`. Default: :attr:`kernel_size` + padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2. + dilation: The stride between elements within a sliding window, must be > 0. + ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This + ensures that every element in the input tensor is covered by a sliding window. + return_indices: If ``True``, will return the argmax along with the max values. + Useful for :class:`torch.nn.functional.max_unpool1d` later + """ + if has_torch_function_unary(input): + return handle_torch_function( + max_pool1d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + if stride is None: + stride = torch.jit.annotate(List[int], []) + return torch.max_pool1d_with_indices( + input, kernel_size, stride, padding, dilation, ceil_mode + ) + + +def _max_pool1d( + input: Tensor, + kernel_size: BroadcastingList1[int], + stride: Optional[BroadcastingList1[int]] = None, + padding: BroadcastingList1[int] = 0, + dilation: BroadcastingList1[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False, +) -> Tensor: + if has_torch_function_unary(input): + return handle_torch_function( + max_pool1d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + if stride is None: + stride = torch.jit.annotate(List[int], []) + return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode) + + +max_pool1d = boolean_dispatch( + arg_name="return_indices", + arg_index=6, + default=False, + if_true=max_pool1d_with_indices, + if_false=_max_pool1d, + module_name=__name__, + func_name="max_pool1d", +) + + +def max_pool2d_with_indices( + input: Tensor, + kernel_size: BroadcastingList2[int], + stride: Optional[BroadcastingList2[int]] = None, + padding: BroadcastingList2[int] = 0, + dilation: BroadcastingList2[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False, +) -> Tuple[Tensor, Tensor]: # noqa: D400 + r""" + max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False) + + Applies a 2D max pooling over an input signal composed of several input + planes. + + .. note:: + The order of :attr:`ceil_mode` and :attr:`return_indices` is different from + what seen in :class:`~torch.nn.MaxPool2d`, and will change in a future release. + + See :class:`~torch.nn.MaxPool2d` for details. + + Args: + input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`, minibatch dim optional. + kernel_size: size of the pooling region. Can be a single number or a + tuple `(kH, kW)` + stride: stride of the pooling operation. Can be a single number or a + tuple `(sH, sW)`. Default: :attr:`kernel_size` + padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2. + dilation: The stride between elements within a sliding window, must be > 0. + ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This + ensures that every element in the input tensor is covered by a sliding window. + return_indices: If ``True``, will return the argmax along with the max values. 
+ Useful for :class:`torch.nn.functional.max_unpool2d` later + """ + if has_torch_function_unary(input): + return handle_torch_function( + max_pool2d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + if stride is None: + stride = torch.jit.annotate(List[int], []) + return torch._C._nn.max_pool2d_with_indices( + input, kernel_size, stride, padding, dilation, ceil_mode + ) + + +def _max_pool2d( + input: Tensor, + kernel_size: BroadcastingList2[int], + stride: Optional[BroadcastingList2[int]] = None, + padding: BroadcastingList2[int] = 0, + dilation: BroadcastingList2[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False, +) -> Tensor: + if has_torch_function_unary(input): + return handle_torch_function( + max_pool2d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + if stride is None: + stride = torch.jit.annotate(List[int], []) + return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode) + + +max_pool2d = boolean_dispatch( + arg_name="return_indices", + arg_index=6, + default=False, + if_true=max_pool2d_with_indices, + if_false=_max_pool2d, + module_name=__name__, + func_name="max_pool2d", +) + + +def max_pool3d_with_indices( + input: Tensor, + kernel_size: BroadcastingList3[int], + stride: Optional[BroadcastingList3[int]] = None, + padding: BroadcastingList3[int] = 0, + dilation: BroadcastingList3[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False, +) -> Tuple[Tensor, Tensor]: # noqa: D400 + r""" + max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False) + + Applies a 3D max pooling over an input signal composed of several input + planes. + + .. note:: + The order of :attr:`ceil_mode` and :attr:`return_indices` is different from + what seen in :class:`~torch.nn.MaxPool3d`, and will change in a future release. + + See :class:`~torch.nn.MaxPool3d` for details. + + Args: + input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iD, iH , iW)`, minibatch dim optional. + kernel_size: size of the pooling region. Can be a single number or a + tuple `(kT, kH, kW)` + stride: stride of the pooling operation. Can be a single number or a + tuple `(sT, sH, sW)`. Default: :attr:`kernel_size` + padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2. + dilation: The stride between elements within a sliding window, must be > 0. + ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This + ensures that every element in the input tensor is covered by a sliding window. + return_indices: If ``True``, will return the argmax along with the max values. 
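
# Illustrative sketch (editor's note, not part of the torch sources in this diff):
# with return_indices=True, max_pool2d also returns the flattened argmax positions,
# which is exactly what max_unpool2d expects later.
import torch
import torch.nn.functional as F

x = torch.arange(16.0).reshape(1, 1, 4, 4)
values, indices = F.max_pool2d(x, kernel_size=2, return_indices=True)
assert values.shape == (1, 1, 2, 2) and indices.shape == (1, 1, 2, 2)
assert torch.equal(values, torch.tensor([[[[5.0, 7.0], [13.0, 15.0]]]]))
# Indices are offsets into each flattened 4x4 input plane.
assert torch.equal(indices, torch.tensor([[[[5, 7], [13, 15]]]]))
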
+ Useful for :class:`torch.nn.functional.max_unpool3d` later + """ + if has_torch_function_unary(input): + return handle_torch_function( + max_pool3d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + if stride is None: + stride = torch.jit.annotate(List[int], []) + return torch._C._nn.max_pool3d_with_indices( + input, kernel_size, stride, padding, dilation, ceil_mode + ) + + +def _max_pool3d( + input: Tensor, + kernel_size: BroadcastingList3[int], + stride: Optional[BroadcastingList3[int]] = None, + padding: BroadcastingList3[int] = 0, + dilation: BroadcastingList3[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False, +) -> Tensor: + if has_torch_function_unary(input): + return handle_torch_function( + max_pool3d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + if stride is None: + stride = torch.jit.annotate(List[int], []) + return torch.max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode) + + +max_pool3d = boolean_dispatch( + arg_name="return_indices", + arg_index=6, + default=False, + if_true=max_pool3d_with_indices, + if_false=_max_pool3d, + module_name=__name__, + func_name="max_pool3d", +) + + +def _unpool_output_size( + input: Tensor, + kernel_size: List[int], + stride: List[int], + padding: List[int], + output_size: Optional[List[int]], +) -> List[int]: + input_size = input.size() + default_size = torch.jit.annotate(List[int], []) + for d in range(len(kernel_size)): + default_size.append( + (input_size[-len(kernel_size) + d] - 1) * stride[d] + + kernel_size[d] + - 2 * padding[d] + ) + if output_size is None: + ret = default_size + else: + if len(output_size) == len(kernel_size) + 2: + output_size = output_size[2:] + if len(output_size) != len(kernel_size): + raise ValueError( + "output_size should be a sequence containing " + f"{len(kernel_size)} or {len(kernel_size) + 2} elements, but it has a length of '{len(output_size)}'" + ) + for d in range(len(kernel_size)): + min_size = default_size[d] - stride[d] + max_size = default_size[d] + stride[d] + if not (min_size < output_size[d] < max_size): + raise ValueError( + f'invalid output_size "{output_size}" (dim {d} must be between {min_size} and {max_size})' + ) + + ret = output_size + return ret + + +def max_unpool1d( + input: Tensor, + indices: Tensor, + kernel_size: BroadcastingList1[int], + stride: Optional[BroadcastingList1[int]] = None, + padding: BroadcastingList1[int] = 0, + output_size: Optional[BroadcastingList1[int]] = None, +) -> Tensor: + r"""Compute a partial inverse of :class:`MaxPool1d`. + + See :class:`~torch.nn.MaxUnpool1d` for details. 
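
# Illustrative sketch (editor's note, not part of the torch sources in this diff):
# the default unpooled size computed by _unpool_output_size above inverts the
# pooling shape formula: (in - 1)*stride + kernel_size - 2*padding per dimension,
# and an explicit output_size may only deviate from it by less than one stride.
# _default_unpool_size is a hypothetical helper used only for illustration.
def _default_unpool_size(in_size, kernel_size, stride, padding):
    return (in_size - 1) * stride + kernel_size - 2 * padding

# Pooling a length-7 signal with kernel 2 / stride 2 gives length 3; unpooling then
# defaults back to 6, while 7 is still accepted because it lies within one stride.
assert _default_unpool_size(3, kernel_size=2, stride=2, padding=0) == 6
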
+ """ + if has_torch_function_unary(input): + return handle_torch_function( + max_unpool1d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) + kernel_size = _single(kernel_size) + if stride is not None: + _stride = _single(stride) + else: + _stride = kernel_size + padding = _single(padding) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) + if isinstance(output_size, list): + output_size = output_size + [1] + else: + output_size = output_size + (1,) + return torch._C._nn.max_unpool2d( + input.unsqueeze(-1), indices.unsqueeze(-1), output_size + ).squeeze(-1) + + +def max_unpool2d( + input: Tensor, + indices: Tensor, + kernel_size: BroadcastingList2[int], + stride: Optional[BroadcastingList2[int]] = None, + padding: BroadcastingList2[int] = 0, + output_size: Optional[BroadcastingList2[int]] = None, +) -> Tensor: + r"""Compute a partial inverse of :class:`MaxPool2d`. + + See :class:`~torch.nn.MaxUnpool2d` for details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + max_unpool2d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) + kernel_size = _pair(kernel_size) + if stride is not None: + _stride = _pair(stride) + else: + _stride = kernel_size + padding = _pair(padding) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) + return torch._C._nn.max_unpool2d(input, indices, output_size) + + +def max_unpool3d( + input: Tensor, + indices: Tensor, + kernel_size: BroadcastingList3[int], + stride: Optional[BroadcastingList3[int]] = None, + padding: BroadcastingList3[int] = 0, + output_size: Optional[BroadcastingList3[int]] = None, +) -> Tensor: + r"""Compute a partial inverse of :class:`MaxPool3d`. + + See :class:`~torch.nn.MaxUnpool3d` for details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + max_unpool3d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) + kernel_size = _triple(kernel_size) + if stride is not None: + _stride = _triple(stride) + else: + _stride = kernel_size + padding = _triple(padding) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) + return torch._C._nn.max_unpool3d(input, indices, output_size, _stride, padding) + + +def lp_pool3d( + input: Tensor, + norm_type: Union[int, float], + kernel_size: BroadcastingList3[int], + stride: Optional[BroadcastingList3[int]] = None, + ceil_mode: bool = False, +) -> Tensor: + r""" + Apply a 3D power-average pooling over an input signal composed of several input planes. + + If the sum of all inputs to the power of `p` is + zero, the gradient is set to zero as well. + + See :class:`~torch.nn.LPPool3d` for details. 
+ """ + if has_torch_function_unary(input): + return handle_torch_function( + lp_pool3d, + (input,), + input, + norm_type, + kernel_size, + stride=stride, + ceil_mode=ceil_mode, + ) + kd, kw, kh = _triple(kernel_size) + if stride is not None: + out = avg_pool3d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) + else: + out = avg_pool3d( + input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode + ) + + return ( + (torch.sign(out) * relu(torch.abs(out))).mul(kd * kw * kh).pow(1.0 / norm_type) + ) + + +def lp_pool2d( + input: Tensor, + norm_type: Union[int, float], + kernel_size: BroadcastingList2[int], + stride: Optional[BroadcastingList2[int]] = None, + ceil_mode: bool = False, +) -> Tensor: + r""" + Apply a 2D power-average pooling over an input signal composed of several input planes. + + If the sum of all inputs to the power of `p` is + zero, the gradient is set to zero as well. + + See :class:`~torch.nn.LPPool2d` for details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + lp_pool2d, + (input,), + input, + norm_type, + kernel_size, + stride=stride, + ceil_mode=ceil_mode, + ) + kw, kh = _pair(kernel_size) + if stride is not None: + out = avg_pool2d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) + else: + out = avg_pool2d( + input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode + ) + + return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1.0 / norm_type) + + +def lp_pool1d( + input: Tensor, + norm_type: Union[int, float], + kernel_size: int, + stride: Optional[BroadcastingList1[int]] = None, + ceil_mode: bool = False, +) -> Tensor: + r"""Apply a 1D power-average pooling over an input signal composed of several input planes. + + If the sum of all inputs to the power of `p` is + zero, the gradient is set to zero as well. + + See :class:`~torch.nn.LPPool1d` for details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + lp_pool1d, + (input,), + input, + norm_type, + kernel_size, + stride=stride, + ceil_mode=ceil_mode, + ) + if stride is not None: + out = avg_pool1d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) + else: + out = avg_pool1d( + input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode + ) + + return ( + (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1.0 / norm_type) + ) + + +def adaptive_max_pool1d_with_indices( + input: Tensor, + output_size: BroadcastingList1[int], + return_indices: bool = False, +) -> Tuple[Tensor, Tensor]: # noqa: D400 + r""" + adaptive_max_pool1d(input, output_size, return_indices=False) + + Applies a 1D adaptive max pooling over an input signal composed of + several input planes. + + See :class:`~torch.nn.AdaptiveMaxPool1d` for details and output shape. + + Args: + output_size: the target output size (single integer) + return_indices: whether to return pooling indices. 
Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool1d_with_indices, + (input,), + input, + output_size, + return_indices=return_indices, + ) + return torch.adaptive_max_pool1d(input, output_size) + + +def _adaptive_max_pool1d( + input: Tensor, + output_size: BroadcastingList1[int], + return_indices: bool = False, +) -> Tensor: + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool1d, + (input,), + input, + output_size, + return_indices=return_indices, + ) + return adaptive_max_pool1d_with_indices(input, output_size)[0] + + +adaptive_max_pool1d = boolean_dispatch( + arg_name="return_indices", + arg_index=2, + default=False, + if_true=adaptive_max_pool1d_with_indices, + if_false=_adaptive_max_pool1d, + module_name=__name__, + func_name="adaptive_max_pool1d", +) + + +def adaptive_max_pool2d_with_indices( + input: Tensor, + output_size: BroadcastingList2[int], + return_indices: bool = False, +) -> Tuple[Tensor, Tensor]: # noqa: D400 + r"""adaptive_max_pool2d(input, output_size, return_indices=False) + + Applies a 2D adaptive max pooling over an input signal composed of + several input planes. + + See :class:`~torch.nn.AdaptiveMaxPool2d` for details and output shape. + + Args: + output_size: the target output size (single integer or + double-integer tuple) + return_indices: whether to return pooling indices. Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool2d_with_indices, + (input,), + input, + output_size, + return_indices=return_indices, + ) + output_size = _list_with_default(output_size, input.size()) + return torch._C._nn.adaptive_max_pool2d(input, output_size) + + +def _adaptive_max_pool2d( + input: Tensor, + output_size: BroadcastingList2[int], + return_indices: bool = False, +) -> Tensor: + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool2d, + (input,), + input, + output_size, + return_indices=return_indices, + ) + return adaptive_max_pool2d_with_indices(input, output_size)[0] + + +adaptive_max_pool2d = boolean_dispatch( + arg_name="return_indices", + arg_index=2, + default=False, + if_true=adaptive_max_pool2d_with_indices, + if_false=_adaptive_max_pool2d, + module_name=__name__, + func_name="adaptive_max_pool2d", +) + + +def adaptive_max_pool3d_with_indices( + input: Tensor, + output_size: BroadcastingList3[int], + return_indices: bool = False, +) -> Tuple[Tensor, Tensor]: # noqa: D400 + r""" + adaptive_max_pool3d(input, output_size, return_indices=False) + + Applies a 3D adaptive max pooling over an input signal composed of + several input planes. + + See :class:`~torch.nn.AdaptiveMaxPool3d` for details and output shape. + + Args: + output_size: the target output size (single integer or + triple-integer tuple) + return_indices: whether to return pooling indices. 
Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool3d_with_indices, + (input,), + input, + output_size, + return_indices=return_indices, + ) + output_size = _list_with_default(output_size, input.size()) + return torch._C._nn.adaptive_max_pool3d(input, output_size) + + +def _adaptive_max_pool3d( + input: Tensor, + output_size: BroadcastingList3[int], + return_indices: bool = False, +) -> Tensor: + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool3d, + (input,), + input, + output_size, + return_indices=return_indices, + ) + return adaptive_max_pool3d_with_indices(input, output_size)[0] + + +adaptive_max_pool3d = boolean_dispatch( + arg_name="return_indices", + arg_index=2, + default=False, + if_true=adaptive_max_pool3d_with_indices, + if_false=_adaptive_max_pool3d, + module_name=__name__, + func_name="adaptive_max_pool3d", +) + + +adaptive_avg_pool1d = _add_docstr( + torch.adaptive_avg_pool1d, + r""" +adaptive_avg_pool1d(input, output_size) -> Tensor + +Applies a 1D adaptive average pooling over an input signal composed of +several input planes. + +See :class:`~torch.nn.AdaptiveAvgPool1d` for details and output shape. + +Args: + output_size: the target output size (single integer) +""", +) + + +def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor: + r"""Apply a 2D adaptive average pooling over an input signal composed of several input planes. + + See :class:`~torch.nn.AdaptiveAvgPool2d` for details and output shape. + + Args: + output_size: the target output size (single integer or + double-integer tuple) + """ + if has_torch_function_unary(input): + return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size) + _output_size = _list_with_default(output_size, input.size()) + return torch._C._nn.adaptive_avg_pool2d(input, _output_size) + + +def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList3[int]) -> Tensor: + r"""Apply a 3D adaptive average pooling over an input signal composed of several input planes. + + See :class:`~torch.nn.AdaptiveAvgPool3d` for details and output shape. + + Args: + output_size: the target output size (single integer or + triple-integer tuple) + """ + if has_torch_function_unary(input): + return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size) + _output_size = _list_with_default(output_size, input.size()) + return torch._C._nn.adaptive_avg_pool3d(input, _output_size) + + +# Activation functions +def dropout( + input: Tensor, + p: float = 0.5, + training: bool = True, + inplace: bool = False, +) -> Tensor: + r"""During training, randomly zeroes some elements of the input tensor with probability :attr:`p`. + + Uses samples from a Bernoulli distribution. + + See :class:`~torch.nn.Dropout` for details. + + Args: + p: probability of an element to be zeroed. Default: 0.5 + training: apply dropout if is ``True``. Default: ``True`` + inplace: If set to ``True``, will do this operation in-place. 
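
# Illustrative sketch (editor's note, not part of the torch sources in this diff):
# adaptive_avg_pool2d takes a target output size instead of a kernel size; the
# common "global average pooling" idiom is output_size=1.
import torch
import torch.nn.functional as F

x = torch.randn(2, 64, 7, 7)
g = F.adaptive_avg_pool2d(x, output_size=1)                 # -> (2, 64, 1, 1)
assert g.shape == (2, 64, 1, 1)
assert torch.allclose(g, x.mean(dim=(-2, -1), keepdim=True), atol=1e-6)
# A None entry in the tuple keeps that input dimension unchanged.
assert F.adaptive_avg_pool2d(x, (None, 3)).shape == (2, 64, 7, 3)
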
Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function( + dropout, (input,), input, p=p, training=training, inplace=inplace + ) + if p < 0.0 or p > 1.0: + raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") + return ( + _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training) + ) + + +def alpha_dropout( + input: Tensor, + p: float = 0.5, + training: bool = False, + inplace: bool = False, +) -> Tensor: + r"""Apply alpha dropout to the input. + + See :class:`~torch.nn.AlphaDropout` for details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + alpha_dropout, (input,), input, p=p, training=training, inplace=inplace + ) + if p < 0.0 or p > 1.0: + raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") + return ( + _VF.alpha_dropout_(input, p, training) + if inplace + else _VF.alpha_dropout(input, p, training) + ) + + +def dropout1d( + input: Tensor, + p: float = 0.5, + training: bool = True, + inplace: bool = False, +) -> Tensor: + r"""Randomly zero out entire channels (a channel is a 1D feature map). + + For example, the :math:`j`-th channel of the :math:`i`-th sample in the + batched input is a 1D tensor :math:`\text{input}[i, j]` of the input tensor. + Each channel will be zeroed out independently on every forward call with + probability :attr:`p` using samples from a Bernoulli distribution. + + See :class:`~torch.nn.Dropout1d` for details. + + Args: + p: probability of a channel to be zeroed. Default: 0.5 + training: apply dropout if is ``True``. Default: ``True`` + inplace: If set to ``True``, will do this operation in-place. Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function( + dropout1d, (input,), input, p=p, training=training, inplace=inplace + ) + if p < 0.0 or p > 1.0: + raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") + inp_dim = input.dim() + if inp_dim not in (2, 3): + raise RuntimeError( + f"dropout1d: Expected 2D or 3D input, but received a {inp_dim}D input. " + "Note that dropout1d exists to provide channel-wise dropout on inputs with 1 " + "spatial dimension, a channel dimension, and an optional batch dimension " + "(i.e. 2D or 3D inputs)." + ) + + is_batched = inp_dim == 3 + if not is_batched: + input = input.unsqueeze_(0) if inplace else input.unsqueeze(0) + + result = ( + _VF.feature_dropout_(input, p, training) + if inplace + else _VF.feature_dropout(input, p, training) + ) + + if not is_batched: + result = result.squeeze_(0) if inplace else result.squeeze(0) + + return result + + +def dropout2d( + input: Tensor, + p: float = 0.5, + training: bool = True, + inplace: bool = False, +) -> Tensor: + r"""Randomly zero out entire channels (a channel is a 2D feature map). + + For example, the :math:`j`-th channel of the :math:`i`-th sample in the + batched input is a 2D tensor :math:`\text{input}[i, j]` of the input tensor. + Each channel will be zeroed out independently on every forward call with + probability :attr:`p` using samples from a Bernoulli distribution. + + See :class:`~torch.nn.Dropout2d` for details. + + Args: + p: probability of a channel to be zeroed. Default: 0.5 + training: apply dropout if is ``True``. Default: ``True`` + inplace: If set to ``True``, will do this operation in-place. 
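
# Illustrative sketch (editor's note, not part of the torch sources in this diff):
# during training, surviving elements are rescaled by 1/(1-p) so the expected value
# is preserved; with training=False the call is an identity. dropout1d instead
# zeroes whole channels.
import torch
import torch.nn.functional as F

x = torch.ones(1000)
y = F.dropout(x, p=0.25, training=True)
kept = y[y != 0]
assert torch.allclose(kept, torch.full_like(kept, 1.0 / 0.75))   # scaled by 1/(1-p)
assert torch.equal(F.dropout(x, p=0.25, training=False), x)      # eval mode: no-op

xc = torch.ones(4, 3, 10)                        # (N, C, L)
yc = F.dropout1d(xc, p=0.5, training=True)
per_channel = yc.sum(dim=-1)                     # each channel is all-zero or all-kept (scaled by 2)
assert bool(((per_channel == 0) | (per_channel == 20.0)).all())
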
Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function( + dropout2d, (input,), input, p=p, training=training, inplace=inplace + ) + if p < 0.0 or p > 1.0: + raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") + inp_dim = input.dim() + if inp_dim not in (3, 4): + warn_msg = ( + f"dropout2d: Received a {inp_dim}-D input to dropout2d, which is deprecated " + "and will result in an error in a future release. To retain the behavior " + "and silence this warning, please use dropout instead. Note that dropout2d " + "exists to provide channel-wise dropout on inputs with 2 spatial dimensions, " + "a channel dimension, and an optional batch dimension (i.e. 3D or 4D inputs)." + ) + warnings.warn(warn_msg) + + # TODO: Properly support no-batch-dim inputs. For now, these are NOT supported; passing + # a 3D input will perform dropout1d behavior instead. This was done historically and the + # behavior is maintained here for now. + # See https://github.com/pytorch/pytorch/issues/77081 + if inp_dim == 3: + warnings.warn( + "dropout2d: Received a 3D input to dropout2d and assuming that channel-wise " + "1D dropout behavior is desired - input is interpreted as shape (N, C, L), where C " + "is the channel dim. This behavior will change in a future release to interpret the " + "input as one without a batch dimension, i.e. shape (C, H, W). To maintain the 1D " + "channel-wise dropout behavior, please switch to using dropout1d instead." + ) + + result = ( + _VF.feature_dropout_(input, p, training) + if inplace + else _VF.feature_dropout(input, p, training) + ) + + return result + + +def dropout3d( + input: Tensor, + p: float = 0.5, + training: bool = True, + inplace: bool = False, +) -> Tensor: + r"""Randomly zero out entire channels (a channel is a 3D feature map). + + For example, the :math:`j`-th channel of the :math:`i`-th sample in the + batched input is a 3D tensor :math:`\text{input}[i, j]` of the input tensor. + Each channel will be zeroed out independently on every forward call with + probability :attr:`p` using samples from a Bernoulli distribution. + + See :class:`~torch.nn.Dropout3d` for details. + + Args: + p: probability of a channel to be zeroed. Default: 0.5 + training: apply dropout if is ``True``. Default: ``True`` + inplace: If set to ``True``, will do this operation in-place. Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function( + dropout3d, (input,), input, p=p, training=training, inplace=inplace + ) + if p < 0.0 or p > 1.0: + raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") + inp_dim = input.dim() + if inp_dim not in (4, 5): + warn_msg = ( + f"dropout3d: Received a {inp_dim}-D input to dropout3d, which is deprecated " + "and will result in an error in a future release. To retain the behavior " + "and silence this warning, please use dropout instead. Note that dropout3d " + "exists to provide channel-wise dropout on inputs with 3 spatial dimensions, " + "a channel dimension, and an optional batch dimension (i.e. 4D or 5D inputs)." 
+ ) + warnings.warn(warn_msg) + + is_batched = inp_dim == 5 + if not is_batched: + input = input.unsqueeze_(0) if inplace else input.unsqueeze(0) + + result = ( + _VF.feature_dropout_(input, p, training) + if inplace + else _VF.feature_dropout(input, p, training) + ) + + if not is_batched: + result = result.squeeze_(0) if inplace else result.squeeze(0) + return result + + +def feature_alpha_dropout( + input: Tensor, + p: float = 0.5, + training: bool = False, + inplace: bool = False, +) -> Tensor: + r"""Randomly masks out entire channels (a channel is a feature map). + + For example, the :math:`j`-th channel of the :math:`i`-th sample in the batch input + is a tensor :math:`\text{input}[i, j]` of the input tensor. Instead of + setting activations to zero, as in regular Dropout, the activations are set + to the negative saturation value of the SELU activation function. + + Each element will be masked independently on every forward call with + probability :attr:`p` using samples from a Bernoulli distribution. + The elements to be masked are randomized on every forward call, and scaled + and shifted to maintain zero mean and unit variance. + + See :class:`~torch.nn.FeatureAlphaDropout` for details. + + Args: + p: dropout probability of a channel to be zeroed. Default: 0.5 + training: apply dropout if is ``True``. Default: ``True`` + inplace: If set to ``True``, will do this operation in-place. Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function( + feature_alpha_dropout, + (input,), + input, + p=p, + training=training, + inplace=inplace, + ) + if p < 0.0 or p > 1.0: + raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") + return ( + _VF.feature_alpha_dropout_(input, p, training) + if inplace + else _VF.feature_alpha_dropout(input, p, training) + ) + + +def _threshold( + input: Tensor, + threshold: float, + value: float, + inplace: bool = False, +) -> Tensor: + r"""Apply a threshold to each element of the input Tensor. + + See :class:`~torch.nn.Threshold` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + _threshold, (input,), input, threshold, value, inplace=inplace + ) + if inplace: + result = _VF.threshold_(input, threshold, value) + else: + result = _VF.threshold(input, threshold, value) + return result + + +# We define this function as _threshold because it takes an argument +# named threshold, which clobbers the recursive reference to the +# function needed for __torch_function__ support +threshold = _threshold + +threshold_ = _add_docstr( + _VF.threshold_, + r""" +threshold_(input, threshold, value) -> Tensor + +In-place version of :func:`~threshold`. +""", +) + + +def relu(input: Tensor, inplace: bool = False) -> Tensor: # noqa: D400,D402 + r"""relu(input, inplace=False) -> Tensor + + Applies the rectified linear unit function element-wise. See + :class:`~torch.nn.ReLU` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(relu, (input,), input, inplace=inplace) + if inplace: + result = torch.relu_(input) + else: + result = torch.relu(input) + return result + + +relu_ = _add_docstr( + torch.relu_, + r""" +relu_(input) -> Tensor + +In-place version of :func:`~relu`. +""", +) + + +def glu(input: Tensor, dim: int = -1) -> Tensor: # noqa: D400,D402 + r""" + glu(input, dim=-1) -> Tensor + + The gated linear unit. Computes: + + .. 
math :: + \text{GLU}(a, b) = a \otimes \sigma(b) + + where `input` is split in half along `dim` to form `a` and `b`, :math:`\sigma` + is the sigmoid function and :math:`\otimes` is the element-wise product between matrices. + + See `Language Modeling with Gated Convolutional Networks `_. + + Args: + input (Tensor): input tensor + dim (int): dimension on which to split the input. Default: -1 + """ + if has_torch_function_unary(input): + return handle_torch_function(glu, (input,), input, dim=dim) + if input.dim() == 0: + raise RuntimeError( + "glu does not support scalars because halving size must be even" + ) + return torch._C._nn.glu(input, dim) + + +def hardtanh( + input: Tensor, + min_val: float = -1.0, + max_val: float = 1.0, + inplace: bool = False, +) -> Tensor: # noqa: D400,D402 + r""" + hardtanh(input, min_val=-1., max_val=1., inplace=False) -> Tensor + + Applies the HardTanh function element-wise. See :class:`~torch.nn.Hardtanh` for more + details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + hardtanh, (input,), input, min_val=min_val, max_val=max_val, inplace=inplace + ) + if min_val > max_val: + raise ValueError("min_val cannot be greater than max_val") + if inplace: + result = torch._C._nn.hardtanh_(input, min_val, max_val) + else: + result = torch._C._nn.hardtanh(input, min_val, max_val) + return result + + +hardtanh_ = _add_docstr( + torch._C._nn.hardtanh_, + r""" +hardtanh_(input, min_val=-1., max_val=1.) -> Tensor + +In-place version of :func:`~hardtanh`. +""", +) + + +def relu6(input: Tensor, inplace: bool = False) -> Tensor: # noqa: D400,D402 + r"""relu6(input, inplace=False) -> Tensor + + Applies the element-wise function :math:`\text{ReLU6}(x) = \min(\max(0,x), 6)`. + + See :class:`~torch.nn.ReLU6` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(relu6, (input,), input, inplace=inplace) + if inplace: + result = torch._C._nn.relu6_(input) + else: + result = torch._C._nn.relu6(input) + return result + + +def elu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor: + r"""Apply the Exponential Linear Unit (ELU) function element-wise. + + See :class:`~torch.nn.ELU` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(elu, (input,), input, alpha=alpha, inplace=inplace) + if inplace: + result = torch._C._nn.elu_(input, alpha) + else: + result = torch._C._nn.elu(input, alpha) + return result + + +elu_ = _add_docstr( + torch._C._nn.elu_, + r""" +elu_(input, alpha=1.) -> Tensor + +In-place version of :func:`~elu`. +""", +) + + +def selu(input: Tensor, inplace: bool = False) -> Tensor: # noqa: D400,D402 + r"""selu(input, inplace=False) -> Tensor + + Applies element-wise, + :math:`\text{SELU}(x) = scale * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))`, + with :math:`\alpha=1.6732632423543772848170429916717` and + :math:`scale=1.0507009873554804934193349852946`. + + See :class:`~torch.nn.SELU` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(selu, (input,), input, inplace=inplace) + if inplace: + result = torch.selu_(input) + else: + result = torch.selu(input) + return result + + +selu_ = _add_docstr( + torch.selu_, + r""" +selu_(input) -> Tensor + +In-place version of :func:`~selu`. 
+""", +) + + +def celu( + input: Tensor, + alpha: float = 1.0, + inplace: bool = False, +) -> Tensor: # noqa: D400,D402 + r"""celu(input, alpha=1., inplace=False) -> Tensor + + Applies element-wise, + :math:`\text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))`. + + See :class:`~torch.nn.CELU` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + celu, (input,), input, alpha=alpha, inplace=inplace + ) + if inplace: + result = torch.celu_(input, alpha) + else: + result = torch.celu(input, alpha) + return result + + +celu_ = _add_docstr( + torch.celu_, + r""" +celu_(input, alpha=1.) -> Tensor + +In-place version of :func:`~celu`. +""", +) + + +def leaky_relu( + input: Tensor, + negative_slope: float = 0.01, + inplace: bool = False, +) -> Tensor: # noqa: D400,D402 + r""" + leaky_relu(input, negative_slope=0.01, inplace=False) -> Tensor + + Applies element-wise, + :math:`\text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)` + + See :class:`~torch.nn.LeakyReLU` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + leaky_relu, (input,), input, negative_slope=negative_slope, inplace=inplace + ) + if inplace: + result = torch._C._nn.leaky_relu_(input, negative_slope) + else: + result = torch._C._nn.leaky_relu(input, negative_slope) + return result + + +leaky_relu_ = _add_docstr( + torch._C._nn.leaky_relu_, + r""" +leaky_relu_(input, negative_slope=0.01) -> Tensor + +In-place version of :func:`~leaky_relu`. +""", +) + + +prelu = _add_docstr( + torch.prelu, + r"""prelu(input, weight) -> Tensor + +Applies element-wise the function +:math:`\text{PReLU}(x) = \max(0,x) + \text{weight} * \min(0,x)` where weight is a +learnable parameter. + +.. note:: + `weight` is expected to be a scalar or 1-D tensor. If `weight` is 1-D, + its size must match the number of input channels, determined by + `input.size(1)` when `input.dim() >= 2`, otherwise 1. + In the 1-D case, note that when `input` has dim > 2, `weight` can be expanded + to the shape of `input` in a way that is not possible using normal + :ref:`broadcasting semantics`. + +See :class:`~torch.nn.PReLU` for more details. +""", +) + + +def rrelu( + input: Tensor, + lower: float = 1.0 / 8, + upper: float = 1.0 / 3, + training: bool = False, + inplace: bool = False, +) -> Tensor: # noqa: D400,D402 + r"""rrelu(input, lower=1./8, upper=1./3, training=False, inplace=False) -> Tensor + + Randomized leaky ReLU. + + See :class:`~torch.nn.RReLU` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + rrelu, + (input,), + input, + lower=lower, + upper=upper, + training=training, + inplace=inplace, + ) + if inplace: + result = torch.rrelu_(input, lower, upper, training) + else: + result = torch.rrelu(input, lower, upper, training) + return result + + +rrelu_ = _add_docstr( + torch.rrelu_, + r""" +rrelu_(input, lower=1./8, upper=1./3, training=False) -> Tensor + +In-place version of :func:`~rrelu`. +""", +) + +logsigmoid = _add_docstr( + torch._C._nn.log_sigmoid, + r""" +logsigmoid(input) -> Tensor + +Applies element-wise :math:`\text{LogSigmoid}(x_i) = \log \left(\frac{1}{1 + \exp(-x_i)}\right)` + +See :class:`~torch.nn.LogSigmoid` for more details. 
+""", +) + +gelu = _add_docstr( + torch._C._nn.gelu, + r""" +gelu(input, approximate = 'none') -> Tensor + +When the approximate argument is 'none', it applies element-wise the function +:math:`\text{GELU}(x) = x * \Phi(x)` + +where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution. + +When the approximate argument is 'tanh', Gelu is estimated with + +.. math:: + \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3))) + +See `Gaussian Error Linear Units (GELUs) `_. +""", +) + +hardshrink = _add_docstr( + torch.hardshrink, + r""" +hardshrink(input, lambd=0.5) -> Tensor + +Applies the hard shrinkage function element-wise + +See :class:`~torch.nn.Hardshrink` for more details. +""", +) + + +def tanhshrink(input): # noqa: D400,D402 + r"""tanhshrink(input) -> Tensor + + Applies element-wise, :math:`\text{Tanhshrink}(x) = x - \text{Tanh}(x)` + + See :class:`~torch.nn.Tanhshrink` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(tanhshrink, (input,), input) + return input - input.tanh() + + +def softsign(input): # noqa: D400,D402 + r"""softsign(input) -> Tensor + + Applies element-wise, the function :math:`\text{SoftSign}(x) = \frac{x}{1 + |x|}` + + See :class:`~torch.nn.Softsign` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(softsign, (input,), input) + return input / (input.abs() + 1) + + +softplus = _add_docstr( + torch._C._nn.softplus, + r""" +softplus(input, beta=1, threshold=20) -> Tensor + +Applies element-wise, the function :math:`\text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))`. + +For numerical stability the implementation reverts to the linear function +when :math:`input \times \beta > threshold`. + +See :class:`~torch.nn.Softplus` for more details. +""", +) + + +def _get_softmax_dim(name: str, ndim: int, stacklevel: int) -> int: + warnings.warn( + f"Implicit dimension choice for {name} has been deprecated. " + "Change the call to include dim=X as an argument.", + stacklevel=stacklevel, + ) + if ndim == 0 or ndim == 1 or ndim == 3: + ret = 0 + else: + ret = 1 + return ret + + +def softmin( + input: Tensor, + dim: Optional[int] = None, + _stacklevel: int = 3, + dtype: Optional[DType] = None, +) -> Tensor: + r"""Apply a softmin function. + + Note that :math:`\text{Softmin}(x) = \text{Softmax}(-x)`. See softmax definition for mathematical formula. + + See :class:`~torch.nn.Softmin` for more details. + + Args: + input (Tensor): input + dim (int): A dimension along which softmin will be computed (so every slice + along dim will sum to 1). + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + """ + if has_torch_function_unary(input): + return handle_torch_function( + softmin, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype + ) + if dim is None: + dim = _get_softmax_dim("softmin", input.dim(), _stacklevel) + if dtype is None: + ret = (-input).softmax(dim) + else: + ret = (-input).softmax(dim, dtype=dtype) + return ret + + +def softmax( + input: Tensor, + dim: Optional[int] = None, + _stacklevel: int = 3, + dtype: Optional[DType] = None, +) -> Tensor: + r"""Apply a softmax function. 
+ + Softmax is defined as: + + :math:`\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}` + + It is applied to all slices along dim, and will re-scale them so that the elements + lie in the range `[0, 1]` and sum to 1. + + See :class:`~torch.nn.Softmax` for more details. + + Args: + input (Tensor): input + dim (int): A dimension along which softmax will be computed. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + .. note:: + This function doesn't work directly with NLLLoss, + which expects the Log to be computed between the Softmax and itself. + Use log_softmax instead (it's faster and has better numerical properties). + + """ + if has_torch_function_unary(input): + return handle_torch_function( + softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype + ) + if dim is None: + dim = _get_softmax_dim("softmax", input.dim(), _stacklevel) + if dtype is None: + ret = input.softmax(dim) + else: + ret = input.softmax(dim, dtype=dtype) + return ret + + +def gumbel_softmax( + logits: Tensor, + tau: float = 1, + hard: bool = False, + eps: float = 1e-10, + dim: int = -1, +) -> Tensor: + r""" + Sample from the Gumbel-Softmax distribution (`Link 1`_ `Link 2`_) and optionally discretize. + + Args: + logits: `[..., num_features]` unnormalized log probabilities + tau: non-negative scalar temperature + hard: if ``True``, the returned samples will be discretized as one-hot vectors, + but will be differentiated as if it is the soft sample in autograd + dim (int): A dimension along which softmax will be computed. Default: -1. + + Returns: + Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution. + If ``hard=True``, the returned samples will be one-hot, otherwise they will + be probability distributions that sum to 1 across `dim`. + + .. note:: + This function is here for legacy reasons, may be removed from nn.Functional in the future. + + .. note:: + The main trick for `hard` is to do `y_hard - y_soft.detach() + y_soft` + + It achieves two things: + - makes the output value exactly one-hot + (since we add then subtract y_soft value) + - makes the gradient equal to y_soft gradient + (since we strip all other gradients) + + Examples:: + >>> logits = torch.randn(20, 32) + >>> # Sample soft categorical using reparametrization trick: + >>> F.gumbel_softmax(logits, tau=1, hard=False) + >>> # Sample hard categorical using "Straight-through" trick: + >>> F.gumbel_softmax(logits, tau=1, hard=True) + + .. _Link 1: + https://arxiv.org/abs/1611.00712 + .. _Link 2: + https://arxiv.org/abs/1611.01144 + """ + if has_torch_function_unary(logits): + return handle_torch_function( + gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim + ) + if eps != 1e-10: + warnings.warn("`eps` parameter is deprecated and has no effect.") + + gumbels = ( + -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format) + .exponential_() + .log() + ) # ~Gumbel(0,1) + gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau) + y_soft = gumbels.softmax(dim) + + if hard: + # Straight through. + index = y_soft.max(dim, keepdim=True)[1] + y_hard = torch.zeros_like( + logits, memory_format=torch.legacy_contiguous_format + ).scatter_(dim, index, 1.0) + ret = y_hard - y_soft.detach() + y_soft + else: + # Reparametrization trick. 
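+ # Soft case: return the relaxed sample y_soft directly; gradients flow through softmax((logits + gumbel_noise) / tau).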
+ ret = y_soft + return ret + + +def log_softmax( + input: Tensor, + dim: Optional[int] = None, + _stacklevel: int = 3, + dtype: Optional[DType] = None, +) -> Tensor: + r"""Apply a softmax followed by a logarithm. + + While mathematically equivalent to log(softmax(x)), doing these two + operations separately is slower and numerically unstable. This function + uses an alternative formulation to compute the output and gradient correctly. + + See :class:`~torch.nn.LogSoftmax` for more details. + + Args: + input (Tensor): input + dim (int): A dimension along which log_softmax will be computed. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is cast to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + """ + if has_torch_function_unary(input): + return handle_torch_function( + log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype + ) + if dim is None: + dim = _get_softmax_dim("log_softmax", input.dim(), _stacklevel) + if dtype is None: + ret = input.log_softmax(dim) + else: + ret = input.log_softmax(dim, dtype=dtype) + return ret + + +softshrink = _add_docstr( + torch._C._nn.softshrink, + r""" +softshrink(input, lambd=0.5) -> Tensor + +Applies the soft shrinkage function elementwise + +See :class:`~torch.nn.Softshrink` for more details. +""", +) + + +def tanh(input): # noqa: D400,D402 + r"""tanh(input) -> Tensor + + Applies element-wise, + :math:`\text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}` + + See :class:`~torch.nn.Tanh` for more details. + """ + return input.tanh() + + +def sigmoid(input): # noqa: D400,D402 + r"""sigmoid(input) -> Tensor + + Applies the element-wise function :math:`\text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}` + + See :class:`~torch.nn.Sigmoid` for more details. + """ + return input.sigmoid() + + +def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: + r"""Apply the Hardsigmoid function element-wise. + + .. math:: + \text{Hardsigmoid}(x) = \begin{cases} + 0 & \text{if~} x \le -3, \\ + 1 & \text{if~} x \ge +3, \\ + x / 6 + 1 / 2 & \text{otherwise} + \end{cases} + + Args: + inplace: If set to ``True``, will do this operation in-place. Default: ``False`` + + See :class:`~torch.nn.Hardsigmoid` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(hardsigmoid, (input,), input, inplace=inplace) + if inplace: + return torch._C._nn.hardsigmoid_(input) + return torch._C._nn.hardsigmoid(input) + + +linear = _add_docstr( + torch._C._nn.linear, + r""" +linear(input, weight, bias=None) -> Tensor + +Applies a linear transformation to the incoming data: :math:`y = xA^T + b`. + +This operation supports 2-D :attr:`weight` with :ref:`sparse layout` + +{sparse_beta_warning} + +This operator supports :ref:`TensorFloat32`. 
+ +Shape: + + - Input: :math:`(*, in\_features)` where `*` means any number of + additional dimensions, including none + - Weight: :math:`(out\_features, in\_features)` or :math:`(in\_features)` + - Bias: :math:`(out\_features)` or :math:`()` + - Output: :math:`(*, out\_features)` or :math:`(*)`, based on the shape of the weight +""".format( + **sparse_support_notes + ), +) + + +bilinear = _add_docstr( + torch.bilinear, + r""" +bilinear(input1, input2, weight, bias=None) -> Tensor + +Applies a bilinear transformation to the incoming data: +:math:`y = x_1^T A x_2 + b` + +Shape: + + - input1: :math:`(N, *, H_{in1})` where :math:`H_{in1}=\text{in1\_features}` + and :math:`*` means any number of additional dimensions. + All but the last dimension of the inputs should be the same. + - input2: :math:`(N, *, H_{in2})` where :math:`H_{in2}=\text{in2\_features}` + - weight: :math:`(\text{out\_features}, \text{in1\_features}, + \text{in2\_features})` + - bias: :math:`(\text{out\_features})` + - output: :math:`(N, *, H_{out})` where :math:`H_{out}=\text{out\_features}` + and all but the last dimension are the same shape as the input. +""", +) + + +def silu(input: Tensor, inplace: bool = False) -> Tensor: + r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise. + + The SiLU function is also known as the swish function. + + .. math:: + \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.} + + .. note:: + See `Gaussian Error Linear Units (GELUs) `_ + where the SiLU (Sigmoid Linear Unit) was originally coined, and see + `Sigmoid-Weighted Linear Units for Neural Network Function Approximation + in Reinforcement Learning `_ and `Swish: + a Self-Gated Activation Function `_ + where the SiLU was experimented with later. + + See :class:`~torch.nn.SiLU` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(silu, (input,), input, inplace=inplace) + if inplace: + return torch._C._nn.silu_(input) + return torch._C._nn.silu(input) + + +def mish(input: Tensor, inplace: bool = False) -> Tensor: + r"""Apply the Mish function, element-wise. + + Mish: A Self Regularized Non-Monotonic Neural Activation Function. + + .. math:: + \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x)) + + .. note:: + See `Mish: A Self Regularized Non-Monotonic Neural Activation Function `_ + + See :class:`~torch.nn.Mish` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(mish, (input,), input, inplace=inplace) + if inplace: + return torch._C._nn.mish_(input) + return torch._C._nn.mish(input) + + +def hardswish(input: Tensor, inplace: bool = False) -> Tensor: + r"""Apply hardswish function, element-wise. + + Follows implementation as described in the paper: + `Searching for MobileNetV3`_. + + .. math:: + \text{Hardswish}(x) = \begin{cases} + 0 & \text{if~} x \le -3, \\ + x & \text{if~} x \ge +3, \\ + x \cdot (x + 3) /6 & \text{otherwise} + \end{cases} + + See :class:`~torch.nn.Hardswish` for more details. + + .. 
_`Searching for MobileNetV3`: + https://arxiv.org/abs/1905.02244 + """ + if has_torch_function_unary(input): + return handle_torch_function(hardswish, (input,), input, inplace=inplace) + if inplace: + return torch._C._nn.hardswish_(input) + return torch._C._nn.hardswish(input) + + +def _no_grad_embedding_renorm_( + weight: Tensor, + input: Tensor, + max_norm: float, + norm_type: float, +) -> Tuple[Tensor, Tensor]: + torch.embedding_renorm_(weight.detach(), input, max_norm, norm_type) + + +def embedding( + input: Tensor, + weight: Tensor, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, +) -> Tensor: + r"""Generate a simple lookup table that looks up embeddings in a fixed dictionary and size. + + This module is often used to retrieve word embeddings using indices. + The input to the module is a list of indices, and the embedding matrix, + and the output is the corresponding word embeddings. + + See :class:`torch.nn.Embedding` for more details. + + .. note:: + Note that the analytical gradients of this function with respect to + entries in :attr:`weight` at the row specified by :attr:`padding_idx` + are expected to differ from the numerical ones. + + .. note:: + Note that `:class:`torch.nn.Embedding` differs from this function in + that it initializes the row of :attr:`weight` specified by + :attr:`padding_idx` to all zeros on construction. + + Args: + input (LongTensor): Tensor containing indices into the embedding matrix + weight (Tensor): The embedding matrix with number of rows equal to the maximum possible index + 1, + and number of columns equal to the embedding size + padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; + therefore, the embedding vector at :attr:`padding_idx` is not updated during training, + i.e. it remains as a fixed "pad". + max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` + is renormalized to have norm :attr:`max_norm`. + Note: this will modify :attr:`weight` in-place. + norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under + :class:`torch.nn.Embedding` for more details regarding sparse gradients. 
+ + Shape: + - Input: LongTensor of arbitrary shape containing the indices to extract + - Weight: Embedding matrix of floating point type with shape `(V, embedding_dim)`, + where V = maximum index + 1 and embedding_dim = the embedding size + - Output: `(*, embedding_dim)`, where `*` is the input shape + + Examples:: + + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]]) + >>> # an embedding matrix containing 10 tensors of size 3 + >>> embedding_matrix = torch.rand(10, 3) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> F.embedding(input, embedding_matrix) + tensor([[[ 0.8490, 0.9625, 0.6753], + [ 0.9666, 0.7761, 0.6108], + [ 0.6246, 0.9751, 0.3618], + [ 0.4161, 0.2419, 0.7383]], + + [[ 0.6246, 0.9751, 0.3618], + [ 0.0237, 0.7794, 0.0528], + [ 0.9666, 0.7761, 0.6108], + [ 0.3385, 0.8612, 0.1867]]]) + + >>> # example with padding_idx + >>> weights = torch.rand(10, 3) + >>> weights[0, :].zero_() + >>> embedding_matrix = weights + >>> input = torch.tensor([[0, 2, 0, 5]]) + >>> F.embedding(input, embedding_matrix, padding_idx=0) + tensor([[[ 0.0000, 0.0000, 0.0000], + [ 0.5609, 0.5384, 0.8720], + [ 0.0000, 0.0000, 0.0000], + [ 0.6262, 0.2438, 0.7471]]]) + """ + if has_torch_function_variadic(input, weight): + return handle_torch_function( + embedding, + (input, weight), + input, + weight, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + ) + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < weight.size( + 0 + ), "Padding_idx must be within num_embeddings" + elif padding_idx < 0: + assert padding_idx >= -weight.size( + 0 + ), "Padding_idx must be within num_embeddings" + padding_idx = weight.size(0) + padding_idx + else: + padding_idx = -1 + if max_norm is not None: + # Note [embedding_renorm contiguous] + # `embedding_renorm_` will call .contiguous() on input anyways, so we + # call it here and take advantage of the improved locality in the + # `embedding` call below too. + input = input.contiguous() + # Note [embedding_renorm set_grad_enabled] + # XXX: equivalent to + # with torch.no_grad(): + # torch.embedding_renorm_ + # remove once script supports set_grad_enabled + _no_grad_embedding_renorm_(weight, input, max_norm, norm_type) + return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse) + + +def embedding_bag( + input: Tensor, + weight: Tensor, + offsets: Optional[Tensor] = None, + max_norm: Optional[float] = None, + norm_type: float = 2, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + per_sample_weights: Optional[Tensor] = None, + include_last_offset: bool = False, + padding_idx: Optional[int] = None, +) -> Tensor: + r"""Compute sums, means or maxes of `bags` of embeddings. + + Calculation is done without instantiating the intermediate embeddings. + See :class:`torch.nn.EmbeddingBag` for more details. + + Note: + {backward_reproducibility_note} + + Args: + input (LongTensor): Tensor containing bags of indices into the embedding matrix + weight (Tensor): The embedding matrix with number of rows equal to the maximum possible index + 1, + and number of columns equal to the embedding size + offsets (LongTensor, optional): Only used when :attr:`input` is 1D. :attr:`offsets` determines + the starting index position of each bag (sequence) in :attr:`input`. 
+ max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` + is renormalized to have norm :attr:`max_norm`. + Note: this will modify :attr:`weight` in-place. + norm_type (float, optional): The ``p`` in the ``p``-norm to compute for the :attr:`max_norm` option. + Default ``2``. + scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + Note: this option is not supported when ``mode="max"``. + mode (str, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag. + Default: ``"mean"`` + sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under + :class:`torch.nn.Embedding` for more details regarding sparse gradients. + Note: this option is not supported when ``mode="max"``. + per_sample_weights (Tensor, optional): a tensor of float / double weights, or None + to indicate all weights should be taken to be 1. If specified, :attr:`per_sample_weights` + must have exactly the same shape as input and is treated as having the same + :attr:`offsets`, if those are not None. + + include_last_offset (bool, optional): if ``True``, the size of offsets is equal to the number of bags + 1. + The last element is the size of the input, or the ending index position of the last bag (sequence). + + padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the + gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated + during training, i.e. it remains as a fixed "pad". Note that the embedding + vector at :attr:`padding_idx` is excluded from the reduction. + + Shape: + - :attr:`input` (LongTensor) and :attr:`offsets` (LongTensor, optional) + + - If :attr:`input` is 2D of shape `(B, N)`, it will be treated as ``B`` bags (sequences) + each of fixed length ``N``, and this will return ``B`` values aggregated in a way + depending on the :attr:`mode`. :attr:`offsets` is ignored and required to be ``None`` in this case. + + - If :attr:`input` is 1D of shape `(N)`, it will be treated as a concatenation of + multiple bags (sequences). :attr:`offsets` is required to be a 1D tensor containing + the starting index positions of each bag in :attr:`input`. Therefore, for :attr:`offsets` + of shape `(B)`, :attr:`input` will be viewed as having ``B`` bags. + Empty bags (i.e., having 0-length) will have returned vectors filled by zeros. + + - :attr:`weight` (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)` + + - :attr:`per_sample_weights` (Tensor, optional). Has the same shape as :attr:`input`. 
+ + - :attr:`output`: aggregated embedding values of shape `(B, embedding_dim)` + + Examples:: + + >>> # an Embedding module containing 10 tensors of size 3 + >>> embedding_matrix = torch.rand(10, 3) + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]) + >>> offsets = torch.tensor([0, 4]) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> F.embedding_bag(input, embedding_matrix, offsets) + tensor([[ 0.3397, 0.3552, 0.5545], + [ 0.5893, 0.4386, 0.5882]]) + + >>> # example with padding_idx + >>> embedding_matrix = torch.rand(10, 3) + >>> input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9]) + >>> offsets = torch.tensor([0, 4]) + >>> F.embedding_bag(input, embedding_matrix, offsets, padding_idx=2, mode='sum') + tensor([[ 0.0000, 0.0000, 0.0000], + [-0.7082, 3.2145, -2.6251]]) + """ + if has_torch_function_variadic(input, weight, offsets, per_sample_weights): + return handle_torch_function( + embedding_bag, + (input, weight, offsets, per_sample_weights), + input, + weight, + offsets=offsets, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + mode=mode, + sparse=sparse, + per_sample_weights=per_sample_weights, + include_last_offset=include_last_offset, + padding_idx=padding_idx, + ) + # Check for backward compatibility. + # Used to be embedding_bag(weight, input, ...) + # Now is embedding_bag(input, weight, ...) + if weight.dtype == torch.long and input.is_floating_point(): + warnings.warn( + "Argument order of nn.functional.embedding_bag was changed. " + "Usage `embedding_bag(weight, input, ...)` is deprecated, " + "and should now be `embedding_bag(input, weight, ...)`." + ) + weight, input = input, weight + + if per_sample_weights is not None and input.size() != per_sample_weights.size(): + raise ValueError( + f"embedding_bag: If per_sample_weights ({per_sample_weights.shape}) is not None, " + f"then it must have the same shape as the input ({input.shape})" + ) + + if not weight.dim() == 2: + raise ValueError( + f"weight has to be a 2D Tensor, but got Tensor of dimension {weight.dim()}" + ) + + if input.dim() == 2: + if offsets is not None: + type_str = "" + # TODO: Remove this once script supports type() calls + if not torch.jit.is_scripting(): + type_str = str(type(offsets)) + raise ValueError( + "if input is 2D, then offsets has to be None" + ", as input is treated is a mini-batch of" + " fixed length sequences. 
However, found " + f"offsets of type {type_str}" + ) + offsets = torch.arange( + 0, input.numel(), input.size(1), dtype=input.dtype, device=input.device + ) + + input = input.reshape(-1) + if per_sample_weights is not None: + per_sample_weights = per_sample_weights.reshape(-1) + elif input.dim() == 1: + if offsets is None: + raise ValueError("offsets has to be a 1D Tensor but got None") + if offsets.dim() != 1: + raise ValueError("offsets has to be a 1D Tensor") + else: + raise ValueError( + f"input has to be 1D or 2D Tensor, but got Tensor of dimension {input.dim()}" + ) + if mode == "sum": + mode_enum = 0 + elif mode == "mean": + mode_enum = 1 + elif mode == "max": + mode_enum = 2 + + if scale_grad_by_freq: + raise ValueError( + "max mode does not support scaling the gradient by the frequency" + ) + + if sparse: + raise ValueError("max mode does not support sparse weights") + + else: + raise ValueError("mode has to be one of sum, mean or max") + + if max_norm is not None: + # XXX: equivalent to + # with torch.no_grad(): + # torch.nembedding_renorm_ + # remove once script supports set_grad_enabled + _no_grad_embedding_renorm_(weight, input, max_norm, norm_type) + + if per_sample_weights is not None and mode != "sum": + raise NotImplementedError( + "embedding_bag: per_sample_weights was not None. " + "per_sample_weights is only supported for mode='sum' " + f"(got mode='{mode}'). Please open a feature request on GitHub." + ) + + ret, _, _, _ = torch.embedding_bag( + weight, + input, + offsets, + scale_grad_by_freq, + mode_enum, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, + ) + return ret + + +if embedding_bag.__doc__: + embedding_bag.__doc__ = embedding_bag.__doc__.format(**reproducibility_notes) + + +def _verify_batch_size(size: List[int]) -> None: + # XXX: JIT script does not support the reduce from functools, and mul op is a + # builtin, which cannot be used as a value to a func yet, so rewrite this size + # check to a simple equivalent for loop + # + # TODO: make use of reduce like below when JIT is ready with the missing features: + # from operator import mul + # from functools import reduce + # + # if reduce(mul, size[2:], size[0]) == 1 + size_prods = size[0] + for i in range(len(size) - 2): + size_prods *= size[i + 2] + if size_prods == 1: + raise ValueError( + f"Expected more than 1 value per channel when training, got input size {size}" + ) + + +def batch_norm( + input: Tensor, + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + training: bool = False, + momentum: float = 0.1, + eps: float = 1e-5, +) -> Tensor: + r"""Apply Batch Normalization for each channel across a batch of data. + + See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`, + :class:`~torch.nn.BatchNorm3d` for details. + """ + if has_torch_function_variadic(input, running_mean, running_var, weight, bias): + return handle_torch_function( + batch_norm, + (input, running_mean, running_var, weight, bias), + input, + running_mean, + running_var, + weight=weight, + bias=bias, + training=training, + momentum=momentum, + eps=eps, + ) + if training: + _verify_batch_size(input.size()) + + return torch.batch_norm( + input, + weight, + bias, + running_mean, + running_var, + training, + momentum, + eps, + torch.backends.cudnn.enabled, + ) + + +def _verify_spatial_size(size: List[int]) -> None: + # Verify that there is > 1 spatial element for instance norm calculation. 
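+ # size is expected to be (N, C, d1, d2, ...); multiply the spatial dims d1*d2*... and reject inputs with only a single spatial element.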
+ size_prods = 1 + for i in range(2, len(size)): + size_prods *= size[i] + if size_prods == 1: + raise ValueError( + f"Expected more than 1 spatial element when training, got input size {size}" + ) + + +def instance_norm( + input: Tensor, + running_mean: Optional[Tensor] = None, + running_var: Optional[Tensor] = None, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + use_input_stats: bool = True, + momentum: float = 0.1, + eps: float = 1e-5, +) -> Tensor: + r"""Apply Instance Normalization independently for each channel in every data sample within a batch. + + See :class:`~torch.nn.InstanceNorm1d`, :class:`~torch.nn.InstanceNorm2d`, + :class:`~torch.nn.InstanceNorm3d` for details. + """ + if has_torch_function_variadic(input, running_mean, running_var, weight, bias): + return handle_torch_function( + instance_norm, + (input, running_mean, running_var, weight, bias), + input, + running_mean=running_mean, + running_var=running_var, + weight=weight, + bias=bias, + use_input_stats=use_input_stats, + momentum=momentum, + eps=eps, + ) + if use_input_stats: + _verify_spatial_size(input.size()) + return torch.instance_norm( + input, + weight, + bias, + running_mean, + running_var, + use_input_stats, + momentum, + eps, + torch.backends.cudnn.enabled, + ) + + +def layer_norm( + input: Tensor, + normalized_shape: List[int], + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: + r"""Apply Layer Normalization for last certain number of dimensions. + + See :class:`~torch.nn.LayerNorm` for details. + """ + if has_torch_function_variadic(input, weight, bias): + return handle_torch_function( + layer_norm, + (input, weight, bias), + input, + normalized_shape, + weight=weight, + bias=bias, + eps=eps, + ) + return torch.layer_norm( + input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled + ) + + +def rms_norm( + input: Tensor, + normalized_shape: List[int], + weight: Optional[Tensor] = None, + eps: Optional[float] = None, +) -> Tensor: + r"""Apply Root Mean Square Layer Normalization. + + See :class:`~torch.nn.RMSNorm` for details. + """ + if has_torch_function_variadic(input, weight): + return handle_torch_function( + rms_norm, (input, weight), input, normalized_shape, weight=weight, eps=eps + ) + return torch.rms_norm(input, normalized_shape, weight, eps) + + +def group_norm( + input: Tensor, + num_groups: int, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: + r"""Apply Group Normalization for last certain number of dimensions. + + See :class:`~torch.nn.GroupNorm` for details. + """ + if has_torch_function_variadic(input, weight, bias): + return handle_torch_function( + group_norm, + ( + input, + weight, + bias, + ), + input, + num_groups, + weight=weight, + bias=bias, + eps=eps, + ) + if input.dim() < 2: + raise RuntimeError( + f"Expected at least 2 dimensions for input tensor but received {input.dim()}" + ) + _verify_batch_size( + [input.size(0) * input.size(1) // num_groups, num_groups] + + list(input.size()[2:]) + ) + return torch.group_norm( + input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled + ) + + +def local_response_norm( + input: Tensor, + size: int, + alpha: float = 1e-4, + beta: float = 0.75, + k: float = 1.0, +) -> Tensor: + r"""Apply local response normalization over an input signal. + + The input signal is composed of several input planes, where channels occupy the second dimension. + Normalization is applied across channels. 
+ + See :class:`~torch.nn.LocalResponseNorm` for details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k + ) + dim = input.dim() + if dim < 3: + raise ValueError( + f"Expected 3D or higher dimensionality input (got {dim} dimensions)" + ) + + if input.numel() == 0: + return input + + div = input.mul(input) + if dim == 3: + div = div.unsqueeze(1) + div = pad(div, (0, 0, size // 2, (size - 1) // 2)) + div = avg_pool2d(div, (size, 1), stride=1).squeeze(1) + else: + sizes = input.size() + div = div.view(sizes[0], 1, sizes[1], sizes[2], -1) + div = pad(div, (0, 0, 0, 0, size // 2, (size - 1) // 2)) + div = avg_pool3d(div, (size, 1, 1), stride=1).squeeze(1) + div = div.view(sizes) + div = div.mul(alpha).add(k).pow(beta) + return input / div + + +# loss + + +def ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: Tensor, + target_lengths: Tensor, + blank: int = 0, + reduction: str = "mean", + zero_infinity: bool = False, +) -> Tensor: + r"""Apply the Connectionist Temporal Classification loss. + + See :class:`~torch.nn.CTCLoss` for details. + + Note: + {cudnn_reproducibility_note} + + Note: + {backward_reproducibility_note} + + Args: + log_probs: :math:`(T, N, C)` or :math:`(T, C)` where `C = number of characters in alphabet including blank`, + `T = input length`, and `N = batch size`. + The logarithmized probabilities of the outputs + (e.g. obtained with :func:`torch.nn.functional.log_softmax`). + targets: :math:`(N, S)` or `(sum(target_lengths))`. + Targets cannot be blank. In the second form, the targets are assumed to be concatenated. + input_lengths: :math:`(N)` or :math:`()`. + Lengths of the inputs (must each be :math:`\leq T`) + target_lengths: :math:`(N)` or :math:`()`. + Lengths of the targets + blank (int, optional): + Blank label. Default :math:`0`. + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the output losses will be divided by the target lengths and + then the mean over the batch is taken, ``'sum'``: the output will be + summed. Default: ``'mean'`` + zero_infinity (bool, optional): + Whether to zero infinite losses and the associated gradients. + Default: ``False`` + Infinite losses mainly occur when the inputs are too short + to be aligned to the targets. 
+ + Example:: + + >>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_() + >>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long) + >>> input_lengths = torch.full((16,), 50, dtype=torch.long) + >>> target_lengths = torch.randint(10, 30, (16,), dtype=torch.long) + >>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths) + >>> loss.backward() + """ + if has_torch_function_variadic(log_probs, targets, input_lengths, target_lengths): + return handle_torch_function( + ctc_loss, + (log_probs, targets, input_lengths, target_lengths), + log_probs, + targets, + input_lengths, + target_lengths, + blank=blank, + reduction=reduction, + zero_infinity=zero_infinity, + ) + return torch.ctc_loss( + log_probs, + targets, + input_lengths, + target_lengths, + blank, + _Reduction.get_enum(reduction), + zero_infinity, + ) + + +if ctc_loss.__doc__: + ctc_loss.__doc__ = ctc_loss.__doc__.format(**reproducibility_notes) + + +def nll_loss( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + r"""Compute the negative log likelihood loss. + + See :class:`~torch.nn.NLLLoss` for details. + + Args: + input: :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)` + in case of 2D Loss, or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K \geq 1` + in the case of K-dimensional loss. `input` is expected to be log-probabilities. + target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`, + or :math:`(N, d_1, d_2, ..., d_K)` where :math:`K \geq 1` for + K-dimensional loss. + weight (Tensor, optional): a manual rescaling weight given to each + class. If given, has to be a Tensor of size `C` + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when reduce is ``False``. Default: ``True`` + ignore_index (int, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. When :attr:`size_average` is + ``True``, the loss is averaged over non-ignored targets. Default: -100 + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. 
Default: ``'mean'`` + + Example:: + + >>> # input is of size N x C = 3 x 5 + >>> input = torch.randn(3, 5, requires_grad=True) + >>> # each element in target has to have 0 <= value < C + >>> target = torch.tensor([1, 0, 4]) + >>> output = F.nll_loss(F.log_softmax(input, dim=1), target) + >>> output.backward() + """ + if has_torch_function_variadic(input, target, weight): + return handle_torch_function( + nll_loss, + (input, target, weight), + input, + target, + weight=weight, + size_average=size_average, + ignore_index=ignore_index, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + return torch._C._nn.nll_loss_nd( + input, target, weight, _Reduction.get_enum(reduction), ignore_index + ) + + +def poisson_nll_loss( + input: Tensor, + target: Tensor, + log_input: bool = True, + full: bool = False, + size_average: Optional[bool] = None, + eps: float = 1e-8, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + r"""Poisson negative log likelihood loss. + + See :class:`~torch.nn.PoissonNLLLoss` for details. + + Args: + input: expectation of underlying Poisson distribution. + target: random sample :math:`target \sim \text{Poisson}(input)`. + log_input: if ``True`` the loss is computed as + :math:`\exp(\text{input}) - \text{target} * \text{input}`, if ``False`` then loss is + :math:`\text{input} - \text{target} * \log(\text{input}+\text{eps})`. Default: ``True`` + full: whether to compute full loss, i. e. to add the Stirling + approximation term. Default: ``False`` + :math:`\text{target} * \log(\text{target}) - \text{target} + 0.5 * \log(2 * \pi * \text{target})`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when reduce is ``False``. Default: ``True`` + eps (float, optional): Small value to avoid evaluation of :math:`\log(0)` when + :attr:`log_input`\ =\ ``False``. Default: 1e-8 + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. 
Default: ``'mean'`` + + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + poisson_nll_loss, + (input, target), + input, + target, + log_input=log_input, + full=full, + size_average=size_average, + eps=eps, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + if reduction != "none" and reduction != "mean" and reduction != "sum": + ret = input + raise ValueError(reduction + " is not a valid value for reduction") + + ret = torch.poisson_nll_loss( + input, target, log_input, full, eps, _Reduction.get_enum(reduction) + ) + return ret + + +def gaussian_nll_loss( + input: Tensor, + target: Tensor, + var: Tensor, + full: bool = False, + eps: float = 1e-6, + reduction: str = "mean", +) -> Tensor: + r"""Gaussian negative log likelihood loss. + + See :class:`~torch.nn.GaussianNLLLoss` for details. + + Args: + input: expectation of the Gaussian distribution. + target: sample from the Gaussian distribution. + var: tensor of positive variance(s), one for each of the expectations + in the input (heteroscedastic), or a single one (homoscedastic). + full (bool, optional): include the constant term in the loss calculation. Default: ``False``. + eps (float, optional): value added to var, for stability. Default: 1e-6. + reduction (str, optional): specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the output is the average of all batch member losses, + ``'sum'``: the output is the sum of all batch member losses. + Default: ``'mean'``. + """ + if has_torch_function_variadic(input, target, var): + return handle_torch_function( + gaussian_nll_loss, + (input, target, var), + input, + target, + var, + full=full, + eps=eps, + reduction=reduction, + ) + + # Check var size + # If var.size == input.size, the case is heteroscedastic and no further checks are needed. + # Otherwise: + if var.size() != input.size(): + # If var is one dimension short of input, but the sizes match otherwise, then this is a homoscedastic case. + # e.g. input.size = (10, 2, 3), var.size = (10, 2) + # -> unsqueeze var so that var.shape = (10, 2, 1) + # this is done so that broadcasting can happen in the loss calculation + if input.size()[:-1] == var.size(): + var = torch.unsqueeze(var, -1) + + # This checks if the sizes match up to the final dimension, and the final dimension of var is of size 1. + # This is also a homoscedastic case. + # e.g. input.size = (10, 2, 3), var.size = (10, 2, 1) + elif ( + input.size()[:-1] == var.size()[:-1] and var.size(-1) == 1 + ): # Heteroscedastic case + pass + + # If none of the above pass, then the size of var is incorrect. 
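+ # e.g. input.size() = (10, 2, 3) with var.size() = (10, 3) reaches this branch and is rejected.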
+ else: + raise ValueError("var is of incorrect size") + + # Check validity of reduction mode + if reduction != "none" and reduction != "mean" and reduction != "sum": + raise ValueError(reduction + " is not valid") + + # Entries of var must be non-negative + if torch.any(var < 0): + raise ValueError("var has negative entry/entries") + + # Clamp for stability + var = var.clone() + with torch.no_grad(): + var.clamp_(min=eps) + + # Calculate the loss + loss = 0.5 * (torch.log(var) + (input - target) ** 2 / var) + if full: + loss += 0.5 * math.log(2 * math.pi) + + if reduction == "mean": + return loss.mean() + elif reduction == "sum": + return loss.sum() + else: + return loss + + +def kl_div( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + log_target: bool = False, +) -> Tensor: + r"""Compute the KL Divergence loss. + + Refer - The `Kullback-Leibler divergence Loss + `__ + + See :class:`~torch.nn.KLDivLoss` for details. + + Args: + input: Tensor of arbitrary shape in log-probabilities. + target: Tensor of the same shape as input. See :attr:`log_target` for + the target's interpretation. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when reduce is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``. + ``'none'``: no reduction will be applied + ``'batchmean'``: the sum of the output will be divided by the batchsize + ``'sum'``: the output will be summed + ``'mean'``: the output will be divided by the number of elements in the output + Default: ``'mean'`` + log_target (bool): A flag indicating whether ``target`` is passed in the log space. + It is recommended to pass certain distributions (like ``softmax``) + in the log space to avoid numerical issues caused by explicit ``log``. + Default: ``False`` + + .. note:: + :attr:`size_average` and :attr:`reduce` are in the process of being deprecated, + and in the meantime, specifying either of those two args will override :attr:`reduction`. + + .. warning:: + :attr:`reduction` = ``'mean'`` doesn't return the true kl divergence value, please use + :attr:`reduction` = ``'batchmean'`` which aligns with KL math definition. + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + kl_div, + (input, target), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + log_target=log_target, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + if reduction == "mean": + warnings.warn( + "reduction: 'mean' divides the total loss by both the batch size and the support size." + "'batchmean' divides only by the batch size, and aligns with the KL div math definition." 
+ "'mean' will be changed to behave the same as 'batchmean' in the next major release." + ) + + # special case for batchmean + if reduction == "batchmean": + reduction_enum = _Reduction.get_enum("sum") + else: + reduction_enum = _Reduction.get_enum(reduction) + + reduced = torch.kl_div(input, target, reduction_enum, log_target=log_target) + + if reduction == "batchmean" and input.dim() != 0: + reduced = reduced / input.size()[0] + + return reduced + + +def cross_entropy( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", + label_smoothing: float = 0.0, +) -> Tensor: + r"""Compute the cross entropy loss between input logits and target. + + See :class:`~torch.nn.CrossEntropyLoss` for details. + + Args: + input (Tensor) : Predicted unnormalized logits; + see Shape section below for supported shapes. + target (Tensor) : Ground truth class indices or class probabilities; + see Shape section below for supported shapes. + weight (Tensor, optional): a manual rescaling weight given to each + class. If given, has to be a Tensor of size `C` + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when reduce is ``False``. Default: ``True`` + ignore_index (int, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. When :attr:`size_average` is + ``True``, the loss is averaged over non-ignored targets. Note that + :attr:`ignore_index` is only applicable when the target contains class indices. + Default: -100 + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount + of smoothing when computing the loss, where 0.0 means no smoothing. The targets + become a mixture of the original ground truth and a uniform distribution as described in + `Rethinking the Inception Architecture for Computer Vision `__. Default: :math:`0.0`. + + Shape: + - Input: Shape :math:`(C)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` + in the case of `K`-dimensional loss. + - Target: If containing class indices, shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with + :math:`K \geq 1` in the case of K-dimensional loss where each value should be between :math:`[0, C)`. + If containing class probabilities, same shape as the input and each value should be between :math:`[0, 1]`. + + where: + + .. 
math:: + \begin{aligned} + C ={} & \text{number of classes} \\ + N ={} & \text{batch size} \\ + \end{aligned} + + Examples:: + + >>> # Example of target with class indices + >>> input = torch.randn(3, 5, requires_grad=True) + >>> target = torch.randint(5, (3,), dtype=torch.int64) + >>> loss = F.cross_entropy(input, target) + >>> loss.backward() + >>> + >>> # Example of target with class probabilities + >>> input = torch.randn(3, 5, requires_grad=True) + >>> target = torch.randn(3, 5).softmax(dim=1) + >>> loss = F.cross_entropy(input, target) + >>> loss.backward() + """ + if has_torch_function_variadic(input, target, weight): + return handle_torch_function( + cross_entropy, + (input, target, weight), + input, + target, + weight=weight, + size_average=size_average, + ignore_index=ignore_index, + reduce=reduce, + reduction=reduction, + label_smoothing=label_smoothing, + ) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + return torch._C._nn.cross_entropy_loss( + input, + target, + weight, + _Reduction.get_enum(reduction), + ignore_index, + label_smoothing, + ) + + +def binary_cross_entropy( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + r"""Measure Binary Cross Entropy between the target and input probabilities. + + See :class:`~torch.nn.BCELoss` for details. + + Args: + input: Tensor of arbitrary shape as probabilities. + target: Tensor of the same shape as input with values between 0 and 1. + weight (Tensor, optional): a manual rescaling weight + if provided it's repeated to match input tensor shape + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when reduce is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. 
Default: ``'mean'`` + + Examples:: + + >>> input = torch.randn(3, 2, requires_grad=True) + >>> target = torch.rand(3, 2, requires_grad=False) + >>> loss = F.binary_cross_entropy(torch.sigmoid(input), target) + >>> loss.backward() + """ + if has_torch_function_variadic(input, target, weight): + return handle_torch_function( + binary_cross_entropy, + (input, target, weight), + input, + target, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + if target.size() != input.size(): + raise ValueError( + f"Using a target size ({target.size()}) that is different to the input size ({input.size()}) is deprecated. " + "Please ensure they have the same size." + ) + + if weight is not None: + new_size = _infer_size(target.size(), weight.size()) + weight = weight.expand(new_size) + + return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum) + + +def binary_cross_entropy_with_logits( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + pos_weight: Optional[Tensor] = None, +) -> Tensor: + r"""Calculate Binary Cross Entropy between target and input logits. + + See :class:`~torch.nn.BCEWithLogitsLoss` for details. + + Args: + input: Tensor of arbitrary shape as unnormalized scores (often referred to as logits). + target: Tensor of the same shape as input with values between 0 and 1 + weight (Tensor, optional): a manual rescaling weight + if provided it's repeated to match input tensor shape + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when reduce is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + pos_weight (Tensor, optional): a weight of positive examples to be broadcasted with target. + Must be a tensor with equal size along the class dimension to the number of classes. + Pay close attention to PyTorch's broadcasting semantics in order to achieve the desired + operations. For a target of size [B, C, H, W] (where B is batch size) pos_weight of + size [B, C, H, W] will apply different pos_weights to each element of the batch or + [C, H, W] the same pos_weights across the batch. To apply the same positive weight + along all spatial dimensions for a 2D multi-class target [C, H, W] use: [C, 1, 1]. 
+ Default: ``None`` + + Examples:: + + >>> input = torch.randn(3, requires_grad=True) + >>> target = torch.empty(3).random_(2) + >>> loss = F.binary_cross_entropy_with_logits(input, target) + >>> loss.backward() + """ + if has_torch_function_variadic(input, target, weight, pos_weight): + return handle_torch_function( + binary_cross_entropy_with_logits, + (input, target, weight, pos_weight), + input, + target, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + pos_weight=pos_weight, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + + if not (target.size() == input.size()): + raise ValueError( + f"Target size ({target.size()}) must be the same as input size ({input.size()})" + ) + + return torch.binary_cross_entropy_with_logits( + input, target, weight, pos_weight, reduction_enum + ) + + +def smooth_l1_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + beta: float = 1.0, +) -> Tensor: + r"""Compute the Smooth L1 loss. + + Function uses a squared term if the absolute + element-wise error falls below beta and an L1 term otherwise. + + See :class:`~torch.nn.SmoothL1Loss` for details. + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + smooth_l1_loss, + (input, target), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + beta=beta, + ) + if not (target.size() == input.size()): + warnings.warn( + f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.", + stacklevel=2, + ) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + + expanded_input, expanded_target = torch.broadcast_tensors(input, target) + + if beta == 0.0: + return torch._C._nn.l1_loss( + expanded_input, expanded_target, _Reduction.get_enum(reduction) + ) + else: + return torch._C._nn.smooth_l1_loss( + expanded_input, expanded_target, _Reduction.get_enum(reduction), beta + ) + + +def huber_loss( + input: Tensor, + target: Tensor, + reduction: str = "mean", + delta: float = 1.0, +) -> Tensor: + r"""Compute the Huber loss. + + Function uses a squared term if the absolute + element-wise error falls below delta and a delta-scaled L1 term otherwise. + + When delta equals 1, this loss is equivalent to SmoothL1Loss. + In general, Huber loss differs from SmoothL1Loss by a factor of delta (AKA beta in Smooth L1). + + See :class:`~torch.nn.HuberLoss` for details. + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + huber_loss, + (input, target), + input, + target, + reduction=reduction, + delta=delta, + ) + if not (target.size() == input.size()): + warnings.warn( + f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). " + "This will likely lead to incorrect results due to broadcasting. 
" + "Please ensure they have the same size.", + stacklevel=2, + ) + + expanded_input, expanded_target = torch.broadcast_tensors(input, target) + return torch._C._nn.huber_loss( + expanded_input, expanded_target, _Reduction.get_enum(reduction), delta + ) + + +def l1_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r"""l1_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor + + Function that takes the mean element-wise absolute value difference. + + See :class:`~torch.nn.L1Loss` for details. + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + l1_loss, + (input, target), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if not (target.size() == input.size()): + warnings.warn( + f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.", + stacklevel=2, + ) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + + expanded_input, expanded_target = torch.broadcast_tensors(input, target) + return torch._C._nn.l1_loss( + expanded_input, expanded_target, _Reduction.get_enum(reduction) + ) + + +def mse_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r"""mse_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor + + Measures the element-wise mean squared error. + See :class:`~torch.nn.MSELoss` for details. + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + mse_loss, + (input, target), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if not (target.size() == input.size()): + warnings.warn( + f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.", + stacklevel=2, + ) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + + expanded_input, expanded_target = torch.broadcast_tensors(input, target) + return torch._C._nn.mse_loss( + expanded_input, expanded_target, _Reduction.get_enum(reduction) + ) + + +def margin_ranking_loss( + input1: Tensor, + input2: Tensor, + target: Tensor, + margin: float = 0, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r"""margin_ranking_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean') -> Tensor + + See :class:`~torch.nn.MarginRankingLoss` for details. 
+ """ + if has_torch_function_variadic(input1, input2, target): + return handle_torch_function( + margin_ranking_loss, + (input1, input2, target), + input1, + input2, + target, + margin=margin, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + if input1.dim() != input2.dim() or input1.dim() != target.dim(): + raise RuntimeError( + f"margin_ranking_loss : All input tensors should have same dimension but got sizes: " + f"input1: {input1.size()}, input2: {input2.size()}, target: {target.size()} " + ) + return torch.margin_ranking_loss(input1, input2, target, margin, reduction_enum) + + +def hinge_embedding_loss( + input: Tensor, + target: Tensor, + margin: float = 1.0, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r"""hinge_embedding_loss(input, target, margin=1.0, size_average=None, reduce=None, reduction='mean') -> Tensor + + See :class:`~torch.nn.HingeEmbeddingLoss` for details. + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + hinge_embedding_loss, + (input, target), + input, + target, + margin=margin, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + return torch.hinge_embedding_loss(input, target, margin, reduction_enum) + + +def multilabel_margin_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r"""multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor + + See :class:`~torch.nn.MultiLabelMarginLoss` for details. + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + multilabel_margin_loss, + (input, target), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum) + + +def soft_margin_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r""" + soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor + + See :class:`~torch.nn.SoftMarginLoss` for details. 
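+
+    A minimal usage sketch (illustrative shapes; ``target`` entries are expected
+    to be ``-1`` or ``1``)::
+
+        >>> input = torch.randn(3, 5, requires_grad=True)
+        >>> target = torch.randint(0, 2, (3, 5)).float() * 2 - 1
+        >>> loss = F.soft_margin_loss(input, target)
+        >>> loss.backward()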
+ """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + soft_margin_loss, + (input, target), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + return torch._C._nn.soft_margin_loss(input, target, reduction_enum) + + +def multilabel_soft_margin_loss( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r"""multilabel_soft_margin_loss(input, target, weight=None, size_average=None, reduce=None, reduction='mean') -> Tensor + + See :class:`~torch.nn.MultiLabelSoftMarginLoss` for details. + """ + if has_torch_function_variadic(input, target, weight): + return handle_torch_function( + multilabel_soft_margin_loss, + (input, target, weight), + input, + target, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + + loss = -(target * logsigmoid(input) + (1 - target) * logsigmoid(-input)) + + if weight is not None: + loss = loss * weight + + class_dim = input.dim() - 1 + C = input.size(class_dim) + loss = loss.sum(dim=class_dim) / C # only return N loss values + + if reduction == "none": + ret = loss + elif reduction == "mean": + ret = loss.mean() + elif reduction == "sum": + ret = loss.sum() + else: + ret = input + raise ValueError(reduction + " is not valid") + return ret + + +def cosine_embedding_loss( + input1: Tensor, + input2: Tensor, + target: Tensor, + margin: float = 0, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r"""cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean') -> Tensor + + See :class:`~torch.nn.CosineEmbeddingLoss` for details. + """ + if has_torch_function_variadic(input1, input2, target): + return handle_torch_function( + cosine_embedding_loss, + (input1, input2, target), + input1, + input2, + target, + margin=margin, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + return torch.cosine_embedding_loss(input1, input2, target, margin, reduction_enum) + + +def multi_margin_loss( + input: Tensor, + target: Tensor, + p: int = 1, + margin: float = 1.0, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r"""multi_margin_loss(input, target, p=1, margin=1, weight=None, size_average=None, reduce=None, reduction='mean') -> Tensor + + See :class:`~torch.nn.MultiMarginLoss` for details. 
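+
+    A minimal usage sketch (illustrative shapes; ``target`` holds class indices
+    in ``[0, C)``)::
+
+        >>> input = torch.randn(3, 5, requires_grad=True)
+        >>> target = torch.randint(5, (3,), dtype=torch.int64)
+        >>> loss = F.multi_margin_loss(input, target, p=1, margin=1.0)
+        >>> loss.backward()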
+ """ + if has_torch_function_variadic(input, target, weight): + return handle_torch_function( + multi_margin_loss, + (input, target, weight), + input, + target, + p=p, + margin=margin, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + if p != 1 and p != 2: + raise ValueError("only p == 1 and p == 2 supported") + if weight is not None: + if weight.dim() != 1: + raise ValueError("weight must be one-dimensional") + + return torch._C._nn.multi_margin_loss( + input, target, p, margin, weight, reduction_enum + ) + + +pixel_shuffle = _add_docstr( + torch.pixel_shuffle, + r""" +pixel_shuffle(input, upscale_factor) -> Tensor + +Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` to a +tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is the :attr:`upscale_factor`. + +See :class:`~torch.nn.PixelShuffle` for details. + +Args: + input (Tensor): the input tensor + upscale_factor (int): factor to increase spatial resolution by + +Examples:: + + >>> input = torch.randn(1, 9, 4, 4) + >>> output = torch.nn.functional.pixel_shuffle(input, 3) + >>> print(output.size()) + torch.Size([1, 1, 12, 12]) +""", +) + +pixel_unshuffle = _add_docstr( + torch.pixel_unshuffle, + r""" +pixel_unshuffle(input, downscale_factor) -> Tensor + +Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements in a +tensor of shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape +:math:`(*, C \times r^2, H, W)`, where r is the :attr:`downscale_factor`. + +See :class:`~torch.nn.PixelUnshuffle` for details. + +Args: + input (Tensor): the input tensor + downscale_factor (int): factor to increase spatial resolution by + +Examples:: + + >>> input = torch.randn(1, 1, 12, 12) + >>> output = torch.nn.functional.pixel_unshuffle(input, 3) + >>> print(output.size()) + torch.Size([1, 9, 4, 4]) +""", +) + +channel_shuffle = _add_docstr( + torch.channel_shuffle, + r""" +channel_shuffle(input, groups) -> Tensor + +Divide the channels in a tensor of shape :math:`(*, C , H, W)` +into g groups and rearrange them as :math:`(*, C \frac g, g, H, W)`, +while keeping the original tensor shape. + +See :class:`~torch.nn.ChannelShuffle` for details. + +Args: + input (Tensor): the input tensor + groups (int): number of groups to divide channels in and rearrange. + +Examples:: + + >>> input = torch.randn(1, 4, 2, 2) + >>> print(input) + [[[[1, 2], + [3, 4]], + [[5, 6], + [7, 8]], + [[9, 10], + [11, 12]], + [[13, 14], + [15, 16]], + ]] + >>> output = torch.nn.functional.channel_shuffle(input, 2) + >>> print(output) + [[[[1, 2], + [3, 4]], + [[9, 10], + [11, 12]], + [[5, 6], + [7, 8]], + [[13, 14], + [15, 16]], + ]] +""", +) + +native_channel_shuffle = _add_docstr( + torch.native_channel_shuffle, + r""" +native_channel_shuffle(input, groups) -> Tensor + +Native kernel level implementation of the `channel_shuffle`. +This function might become private in future releases, use with caution. + +Divide the channels in a tensor of shape :math:`(*, C , H, W)` +into g groups and rearrange them as :math:`(*, C \frac g, g, H, W)`, +while keeping the original tensor shape. + +See :class:`~torch.nn.ChannelShuffle` for details. + +Args: + input (Tensor): the input tensor + groups (int): number of groups to divide channels in and rearrange. 
+ +Examples:: + + >>> input = torch.randn(1, 4, 2, 2) + >>> print(input) + [[[[1, 2], + [3, 4]], + [[5, 6], + [7, 8]], + [[9, 10], + [11, 12]], + [[13, 14], + [15, 16]], + ]] + >>> output = torch.nn.functional.native_channel_shuffle(input, 2) + >>> print(output) + [[[[1, 2], + [3, 4]], + [[9, 10], + [11, 12]], + [[5, 6], + [7, 8]], + [[13, 14], + [15, 16]], + ]] +""", +) + + +@_overload +def upsample( # noqa: F811 + input: Tensor, + size: Optional[int] = None, + scale_factor: Optional[float] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, +) -> Tensor: # noqa: B950 + pass + + +@_overload +def upsample( # noqa: F811 + input: Tensor, + size: Optional[List[int]] = None, + scale_factor: Optional[float] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, +) -> Tensor: # noqa: B950 + pass + + +def upsample( # noqa: F811 + input, + size=None, + scale_factor=None, + mode="nearest", + align_corners=None, +): + r"""Upsample input. + + Provided tensor is upsampled to either the given :attr:`size` or the given + :attr:`scale_factor` + + .. warning:: + This function is deprecated in favor of :func:`torch.nn.functional.interpolate`. + This is equivalent with ``nn.functional.interpolate(...)``. + + Note: + {backward_reproducibility_note} + + The algorithm used for upsampling is determined by :attr:`mode`. + + Currently temporal, spatial and volumetric upsampling are supported, i.e. + expected inputs are 3-D, 4-D or 5-D in shape. + + The input dimensions are interpreted in the form: + `mini-batch x channels x [optional depth] x [optional height] x width`. + + The modes available for upsampling are: `nearest`, `linear` (3D-only), + `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only) + + Args: + input (Tensor): the input tensor + size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]): + output spatial size. + scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple. + mode (str): algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'``. Default: ``'nearest'`` + align_corners (bool, optional): Geometrically, we consider the pixels of the + input and output as squares rather than points. + If set to ``True``, the input and output tensors are aligned by the + center points of their corner pixels, preserving the values at the corner pixels. + If set to ``False``, the input and output tensors are aligned by the corner + points of their corner pixels, and the interpolation uses edge value padding + for out-of-boundary values, making this operation *independent* of input size + when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode` + is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``. + Default: ``False`` + + .. note:: + With ``mode='bicubic'``, it's possible to cause overshoot, in other words it can produce + negative values or values greater than 255 for images. + Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot + when displaying the image. + + .. warning:: + With ``align_corners = True``, the linearly interpolating modes + (`linear`, `bilinear`, and `trilinear`) don't proportionally align the + output and input pixels, and thus the output values can depend on the + input size. This was the default behavior for these modes up to version + 0.3.1. Since then, the default behavior is ``align_corners = False``. 
+ See :class:`~torch.nn.Upsample` for concrete examples on how this + affects the outputs. + + """ + warnings.warn( + "`nn.functional.upsample` is deprecated. " + "Use `nn.functional.interpolate` instead.", + stacklevel=2, + ) + return interpolate(input, size, scale_factor, mode, align_corners) + + +if upsample.__doc__: + upsample.__doc__ = upsample.__doc__.format(**reproducibility_notes) + + +def _is_integer(x) -> bool: + r"""Type check the input number is an integer. + + Will return True for int, SymInt, Numpy integers and Tensors with integer elements. + """ + if isinstance(x, (int, torch.SymInt)): + return True + if np is not None and isinstance(x, np.integer): + return True + return isinstance(x, Tensor) and not x.is_floating_point() + + +@_overload +def interpolate( # noqa: F811 + input: Tensor, + size: Optional[int] = None, + scale_factor: Optional[List[float]] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, + antialias: bool = False, +) -> Tensor: # noqa: B950 + pass + + +@_overload +def interpolate( # noqa: F811 + input: Tensor, + size: Optional[List[int]] = None, + scale_factor: Optional[List[float]] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, + antialias: bool = False, +) -> Tensor: # noqa: B950 + pass + + +@_overload +def interpolate( # noqa: F811 + input: Tensor, + size: Optional[int] = None, + scale_factor: Optional[float] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, + antialias: bool = False, +) -> Tensor: # noqa: B950 + pass + + +@_overload +def interpolate( # noqa: F811 + input: Tensor, + size: Optional[List[int]] = None, + scale_factor: Optional[float] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, + antialias: bool = False, +) -> Tensor: + pass + + +def interpolate( # noqa: F811 + input: Tensor, + size: Optional[int] = None, + scale_factor: Optional[List[float]] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, + antialias: bool = False, +) -> Tensor: # noqa: B950 + r"""Down/up samples the input. + + Tensor interpolated to either the given :attr:`size` or the given + :attr:`scale_factor` + + The algorithm used for interpolation is determined by :attr:`mode`. + + Currently temporal, spatial and volumetric sampling are supported, i.e. + expected inputs are 3-D, 4-D or 5-D in shape. + + The input dimensions are interpreted in the form: + `mini-batch x channels x [optional depth] x [optional height] x width`. + + The modes available for resizing are: `nearest`, `linear` (3D-only), + `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `area`, `nearest-exact` + + Args: + input (Tensor): the input tensor + size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]): + output spatial size. + scale_factor (float or Tuple[float]): multiplier for spatial size. If `scale_factor` is a tuple, + its length has to match the number of spatial dimensions; `input.dim() - 2`. + mode (str): algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'`` | ``'area'`` | ``'nearest-exact'``. Default: ``'nearest'`` + align_corners (bool, optional): Geometrically, we consider the pixels of the + input and output as squares rather than points. 
+ If set to ``True``, the input and output tensors are aligned by the + center points of their corner pixels, preserving the values at the corner pixels. + If set to ``False``, the input and output tensors are aligned by the corner + points of their corner pixels, and the interpolation uses edge value padding + for out-of-boundary values, making this operation *independent* of input size + when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode` + is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``. + Default: ``False`` + recompute_scale_factor (bool, optional): recompute the scale_factor for use in the + interpolation calculation. If `recompute_scale_factor` is ``True``, then + `scale_factor` must be passed in and `scale_factor` is used to compute the + output `size`. The computed output `size` will be used to infer new scales for + the interpolation. Note that when `scale_factor` is floating-point, it may differ + from the recomputed `scale_factor` due to rounding and precision issues. + If `recompute_scale_factor` is ``False``, then `size` or `scale_factor` will + be used directly for interpolation. Default: ``None``. + antialias (bool, optional): flag to apply anti-aliasing. Default: ``False``. Using anti-alias + option together with ``align_corners=False``, interpolation result would match Pillow + result for downsampling operation. Supported modes: ``'bilinear'``, ``'bicubic'``. + + .. note:: + With ``mode='bicubic'``, it's possible to cause overshoot, in other words it can produce + negative values or values greater than 255 for images. + Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot + when displaying the image. + + .. note:: + Mode ``mode='nearest-exact'`` matches Scikit-Image and PIL nearest neighbours interpolation + algorithms and fixes known issues with ``mode='nearest'``. This mode is introduced to keep + backward compatibility. + Mode ``mode='nearest'`` matches buggy OpenCV's ``INTER_NEAREST`` interpolation algorithm. + + .. note:: + The gradients for the dtype ``float16`` on CUDA may be inaccurate in the upsample operation + when using modes ``['linear', 'bilinear', 'bicubic', 'trilinear', 'area']``. + For more details, please refer to the discussion in + `issue#104157 `_. + + Note: + {backward_reproducibility_note} + """ + if has_torch_function_unary(input): + return handle_torch_function( + interpolate, + (input,), + input, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor, + antialias=antialias, + ) + + if mode in ("nearest", "area", "nearest-exact"): + if align_corners is not None: + raise ValueError( + "align_corners option can only be set with the " + "interpolating modes: linear | bilinear | bicubic | trilinear" + ) + else: + if align_corners is None: + align_corners = False + + dim = input.dim() - 2 # Number of spatial dimensions. + + # Process size and scale_factor. Validate that exactly one is set. + # Validate its length if it is a list, or expand it if it is a scalar. + # After this block, exactly one of output_size and scale_factors will + # be non-None, and it will be a list (or tuple). 
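+    # For example, with a 4-D input (so dim == 2): size=32 becomes
+    # output_size == [32, 32] with scale_factors left as None, while
+    # scale_factor=2.0 becomes scale_factors == [2.0, 2.0] with
+    # output_size left as None.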
+ if size is not None and scale_factor is not None: + raise ValueError("only one of size or scale_factor should be defined") + elif size is not None: + assert scale_factor is None + scale_factors = None + if isinstance(size, (list, tuple)): + if len(size) != dim: + raise ValueError( + "Input and output must have the same number of spatial dimensions, but got " + f"input with spatial dimensions of {list(input.shape[2:])} and output size of {size}. " + "Please provide input tensor in (N, C, d1, d2, ...,dK) format and " + "output size in (o1, o2, ...,oK) format." + ) + if not torch.jit.is_scripting(): + if not all(_is_integer(x) for x in size): + raise TypeError( + "expected size to be one of int or Tuple[int] or Tuple[int, int] or " + f"Tuple[int, int, int], but got size with types {[type(x) for x in size]}" + ) + output_size = size + else: + output_size = [size for _ in range(dim)] + elif scale_factor is not None: + assert size is None + output_size = None + if isinstance(scale_factor, (list, tuple)): + if len(scale_factor) != dim: + raise ValueError( + "Input and scale_factor must have the same number of spatial dimensions, but " + f"got input with spatial dimensions of {list(input.shape[2:])} and " + f"scale_factor of shape {scale_factor}. " + "Please provide input tensor in (N, C, d1, d2, ...,dK) format and " + "scale_factor in (s1, s2, ...,sK) format." + ) + scale_factors = scale_factor + else: + scale_factors = [scale_factor for _ in range(dim)] + else: + raise ValueError("either size or scale_factor should be defined") + + if ( + recompute_scale_factor is not None + and recompute_scale_factor + and size is not None + ): + raise ValueError( + "recompute_scale_factor is not meaningful with an explicit size." + ) + + # "area" mode always requires an explicit size rather than scale factor. + # Re-use the recompute_scale_factor code path. + if mode == "area" and output_size is None: + recompute_scale_factor = True + + if recompute_scale_factor is not None and recompute_scale_factor: + # We compute output_size here, then un-set scale_factors. + # The C++ code will recompute it based on the (integer) output size. 
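+        # For example, an input with spatial size 10 and scale_factor=1.7
+        # gives output_size == [17]; scale_factors is then cleared below, so
+        # the backend derives the effective scale from 17 / 10 instead of 1.7.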
+ assert scale_factors is not None + if not torch.jit.is_scripting() and torch._C._get_tracing_state(): + # make scale_factor a tensor in tracing so constant doesn't get baked in + output_size = [ + ( + torch.floor( + ( + input.size(i + 2).float() + * torch.tensor(scale_factors[i], dtype=torch.float32) + ).float() + ) + ) + for i in range(dim) + ] + elif torch.jit.is_scripting(): + output_size = [ + int(math.floor(float(input.size(i + 2)) * scale_factors[i])) + for i in range(dim) + ] + else: + output_size = [ + _sym_int(input.size(i + 2) * scale_factors[i]) for i in range(dim) + ] + scale_factors = None + + if antialias and not (mode in ("bilinear", "bicubic") and input.ndim == 4): + raise ValueError( + "Anti-alias option is restricted to bilinear and bicubic modes and requires a 4-D tensor as input" + ) + + if input.dim() == 3 and mode == "nearest": + return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors) + if input.dim() == 4 and mode == "nearest": + return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors) + if input.dim() == 5 and mode == "nearest": + return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors) + + if input.dim() == 3 and mode == "nearest-exact": + return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors) + if input.dim() == 4 and mode == "nearest-exact": + return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors) + if input.dim() == 5 and mode == "nearest-exact": + return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors) + + if input.dim() == 3 and mode == "area": + assert output_size is not None + return adaptive_avg_pool1d(input, output_size) + if input.dim() == 4 and mode == "area": + assert output_size is not None + return adaptive_avg_pool2d(input, output_size) + if input.dim() == 5 and mode == "area": + assert output_size is not None + return adaptive_avg_pool3d(input, output_size) + + if input.dim() == 3 and mode == "linear": + assert align_corners is not None + return torch._C._nn.upsample_linear1d( + input, output_size, align_corners, scale_factors + ) + if input.dim() == 4 and mode == "bilinear": + assert align_corners is not None + if antialias: + return torch._C._nn._upsample_bilinear2d_aa( + input, output_size, align_corners, scale_factors + ) + # Two levels are necessary to prevent TorchScript from touching + # are_deterministic_algorithms_enabled. 
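+        # The outer check keeps are_deterministic_algorithms_enabled() out of
+        # scripted graphs; the inner check routes CUDA/XPU inputs to the slow
+        # Python decomposition only when deterministic algorithms are enabled.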
+ if not torch.jit.is_scripting(): + if torch.are_deterministic_algorithms_enabled() and ( + input.is_cuda or input.is_xpu + ): + # Use slow decomp whose backward will be in terms of index_put + # importlib is required because the import cannot be top level + # (cycle) and cannot be nested (TS doesn't support) + return importlib.import_module( + "torch._decomp.decompositions" + )._upsample_linear_vec(input, output_size, align_corners, scale_factors) + return torch._C._nn.upsample_bilinear2d( + input, output_size, align_corners, scale_factors + ) + if input.dim() == 5 and mode == "trilinear": + assert align_corners is not None + return torch._C._nn.upsample_trilinear3d( + input, output_size, align_corners, scale_factors + ) + if input.dim() == 4 and mode == "bicubic": + assert align_corners is not None + if antialias: + return torch._C._nn._upsample_bicubic2d_aa( + input, output_size, align_corners, scale_factors + ) + return torch._C._nn.upsample_bicubic2d( + input, output_size, align_corners, scale_factors + ) + + if input.dim() == 3 and mode == "bilinear": + raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input") + if input.dim() == 3 and mode == "trilinear": + raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input") + if input.dim() == 4 and mode == "linear": + raise NotImplementedError("Got 4D input, but linear mode needs 3D input") + if input.dim() == 4 and mode == "trilinear": + raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input") + if input.dim() == 5 and mode == "linear": + raise NotImplementedError("Got 5D input, but linear mode needs 3D input") + if input.dim() == 5 and mode == "bilinear": + raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input") + + raise NotImplementedError( + "Input Error: Only 3D, 4D and 5D input Tensors supported" + f" (got {input.dim()}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact" + f" (got {mode})" + ) + + +if interpolate.__doc__: + interpolate.__doc__ = interpolate.__doc__.format(**reproducibility_notes) + + +@_overload +def upsample_nearest( # noqa: F811 + input: Tensor, + size: Optional[int] = None, + scale_factor: Optional[float] = None, +) -> Tensor: + pass + + +@_overload +def upsample_nearest( # noqa: F811 + input: Tensor, + size: Optional[List[int]] = None, + scale_factor: Optional[float] = None, +) -> Tensor: + pass + + +def upsample_nearest(input, size=None, scale_factor=None): # noqa: F811 + r"""Upsamples the input, using nearest neighbours' pixel values. + + .. warning:: + This function is deprecated in favor of :func:`torch.nn.functional.interpolate`. + This is equivalent with ``nn.functional.interpolate(..., mode='nearest')``. + + Currently spatial and volumetric upsampling are supported (i.e. expected + inputs are 4 or 5 dimensional). + + Args: + input (Tensor): input + size (int or Tuple[int, int] or Tuple[int, int, int]): output spatia + size. + scale_factor (int): multiplier for spatial size. Has to be an integer. + + Note: + {backward_reproducibility_note} + """ + # DeprecationWarning is ignored by default + warnings.warn( + "`nn.functional.upsample_nearest` is deprecated. 
" + "Use `nn.functional.interpolate` instead.", + stacklevel=2, + ) + return interpolate(input, size, scale_factor, mode="nearest") + + +if upsample_nearest.__doc__: + upsample_nearest.__doc__ = upsample_nearest.__doc__.format(**reproducibility_notes) + + +@_overload +def upsample_bilinear( # noqa: F811 + input: Tensor, + size: Optional[int] = None, + scale_factor: Optional[float] = None, +) -> Tensor: + pass + + +@_overload +def upsample_bilinear( # noqa: F811 + input: Tensor, + size: Optional[List[int]] = None, + scale_factor: Optional[float] = None, +) -> Tensor: + pass + + +@_overload +def upsample_bilinear( # noqa: F811 + input: Tensor, + size: Optional[int] = None, + scale_factor: Optional[List[float]] = None, +) -> Tensor: + pass + + +@_overload +def upsample_bilinear( # noqa: F811 + input: Tensor, + size: Optional[List[int]] = None, + scale_factor: Optional[List[float]] = None, +) -> Tensor: + pass + + +def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 + r"""Upsamples the input, using bilinear upsampling. + + .. warning:: + This function is deprecated in favor of :func:`torch.nn.functional.interpolate`. + This is equivalent with + ``nn.functional.interpolate(..., mode='bilinear', align_corners=True)``. + + Expected inputs are spatial (4 dimensional). Use `upsample_trilinear` fo + volumetric (5 dimensional) inputs. + + Args: + input (Tensor): input + size (int or Tuple[int, int]): output spatial size. + scale_factor (int or Tuple[int, int]): multiplier for spatial size + + Note: + {backward_reproducibility_note} + """ + # DeprecationWarning is ignored by default + warnings.warn( + "`nn.functional.upsample_bilinear` is deprecated. " + "Use `nn.functional.interpolate` instead.", + stacklevel=2, + ) + return interpolate(input, size, scale_factor, mode="bilinear", align_corners=True) + + +if upsample_bilinear.__doc__: + upsample_bilinear.__doc__ = upsample_bilinear.__doc__.format( + **reproducibility_notes + ) + +GRID_SAMPLE_INTERPOLATION_MODES = { + "bilinear": 0, + "nearest": 1, + "bicubic": 2, +} + +GRID_SAMPLE_PADDING_MODES = { + "zeros": 0, + "border": 1, + "reflection": 2, +} + + +def grid_sample( + input: Tensor, + grid: Tensor, + mode: str = "bilinear", + padding_mode: str = "zeros", + align_corners: Optional[bool] = None, +) -> Tensor: + r"""Compute grid sample. + + Given an :attr:`input` and a flow-field :attr:`grid`, computes the + ``output`` using :attr:`input` values and pixel locations from :attr:`grid`. + + Currently, only spatial (4-D) and volumetric (5-D) :attr:`input` are + supported. + + In the spatial (4-D) case, for :attr:`input` with shape + :math:`(N, C, H_\text{in}, W_\text{in})` and :attr:`grid` with shape + :math:`(N, H_\text{out}, W_\text{out}, 2)`, the output will have shape + :math:`(N, C, H_\text{out}, W_\text{out})`. + + For each output location ``output[n, :, h, w]``, the size-2 vector + ``grid[n, h, w]`` specifies :attr:`input` pixel locations ``x`` and ``y``, + which are used to interpolate the output value ``output[n, :, h, w]``. + In the case of 5D inputs, ``grid[n, d, h, w]`` specifies the + ``x``, ``y``, ``z`` pixel locations for interpolating + ``output[n, :, d, h, w]``. :attr:`mode` argument specifies ``nearest`` or + ``bilinear`` interpolation method to sample the input pixels. + + :attr:`grid` specifies the sampling pixel locations normalized by the + :attr:`input` spatial dimensions. Therefore, it should have most values in + the range of ``[-1, 1]``. 
For example, values ``x = -1, y = -1`` is the + left-top pixel of :attr:`input`, and values ``x = 1, y = 1`` is the + right-bottom pixel of :attr:`input`. + + If :attr:`grid` has values outside the range of ``[-1, 1]``, the corresponding + outputs are handled as defined by :attr:`padding_mode`. Options are + + * ``padding_mode="zeros"``: use ``0`` for out-of-bound grid locations, + * ``padding_mode="border"``: use border values for out-of-bound grid locations, + * ``padding_mode="reflection"``: use values at locations reflected by + the border for out-of-bound grid locations. For location far away + from the border, it will keep being reflected until becoming in bound, + e.g., (normalized) pixel location ``x = -3.5`` reflects by border ``-1`` + and becomes ``x' = 1.5``, then reflects by border ``1`` and becomes + ``x'' = -0.5``. + + Note: + This function is often used in conjunction with :func:`affine_grid` + to build `Spatial Transformer Networks`_ . + + Note: + When using the CUDA backend, this operation may induce nondeterministic + behaviour in its backward pass that is not easily switched off. + Please see the notes on :doc:`/notes/randomness` for background. + + Note: + NaN values in :attr:`grid` would be interpreted as ``-1``. + + Args: + input (Tensor): input of shape :math:`(N, C, H_\text{in}, W_\text{in})` (4-D case) + or :math:`(N, C, D_\text{in}, H_\text{in}, W_\text{in})` (5-D case) + grid (Tensor): flow-field of shape :math:`(N, H_\text{out}, W_\text{out}, 2)` (4-D case) + or :math:`(N, D_\text{out}, H_\text{out}, W_\text{out}, 3)` (5-D case) + mode (str): interpolation mode to calculate output values + ``'bilinear'`` | ``'nearest'`` | ``'bicubic'``. Default: ``'bilinear'`` + Note: ``mode='bicubic'`` supports only 4-D input. + When ``mode='bilinear'`` and the input is 5-D, the interpolation mode + used internally will actually be trilinear. However, when the input is 4-D, + the interpolation mode will legitimately be bilinear. + padding_mode (str): padding mode for outside grid values + ``'zeros'`` | ``'border'`` | ``'reflection'``. Default: ``'zeros'`` + align_corners (bool, optional): Geometrically, we consider the pixels of the + input as squares rather than points. + If set to ``True``, the extrema (``-1`` and ``1``) are considered as referring + to the center points of the input's corner pixels. If set to ``False``, they + are instead considered as referring to the corner points of the input's corner + pixels, making the sampling more resolution agnostic. + This option parallels the ``align_corners`` option in + :func:`interpolate`, and so whichever option is used here + should also be used there to resize the input image before grid sampling. + Default: ``False`` + + Returns: + output (Tensor): output Tensor + + .. _`Spatial Transformer Networks`: + https://arxiv.org/abs/1506.02025 + + .. warning:: + When ``align_corners = True``, the grid positions depend on the pixel + size relative to the input image size, and so the locations sampled by + :func:`grid_sample` will differ for the same input given at different + resolutions (that is, after being upsampled or downsampled). + The default behavior up to version 1.2.0 was ``align_corners = True``. + Since then, the default behavior has been changed to ``align_corners = False``, + in order to bring it in line with the default for :func:`interpolate`. + + .. note:: + ``mode='bicubic'`` is implemented using the `cubic convolution algorithm`_ with :math:`\alpha=-0.75`. 
+ The constant :math:`\alpha` might be different from packages to packages. + For example, `PIL`_ and `OpenCV`_ use -0.5 and -0.75 respectively. + This algorithm may "overshoot" the range of values it's interpolating. + For example, it may produce negative values or values greater than 255 when interpolating input in [0, 255]. + Clamp the results with :func:`torch.clamp` to ensure they are within the valid range. + .. _`cubic convolution algorithm`: https://en.wikipedia.org/wiki/Bicubic_interpolation + .. _`PIL`: https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/src/libImaging/Resample.c#L51 + .. _`OpenCV`: https://github.com/opencv/opencv/blob/f345ed564a06178670750bad59526cfa4033be55/modules/imgproc/src/resize.cpp#L908 + """ + if has_torch_function_variadic(input, grid): + return handle_torch_function( + grid_sample, + (input, grid), + input, + grid, + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + ) + if mode != "bilinear" and mode != "nearest" and mode != "bicubic": + raise ValueError( + f"nn.functional.grid_sample(): expected mode to be 'bilinear', 'nearest' or 'bicubic', but got: '{mode}'" + ) + if ( + padding_mode != "zeros" + and padding_mode != "border" + and padding_mode != "reflection" + ): + raise ValueError( + "nn.functional.grid_sample(): expected padding_mode " + "to be 'zeros', 'border', or 'reflection', " + f"but got: '{padding_mode}'" + ) + + if mode == "bilinear": + mode_enum = 0 + elif mode == "nearest": + mode_enum = 1 + else: # mode == 'bicubic' + mode_enum = 2 + + if padding_mode == "zeros": + padding_mode_enum = 0 + elif padding_mode == "border": + padding_mode_enum = 1 + else: # padding_mode == 'reflection' + padding_mode_enum = 2 + + if align_corners is None: + warnings.warn( + "Default grid_sample and affine_grid behavior has changed " + "to align_corners=False since 1.3.0. Please specify " + "align_corners=True if the old behavior is desired. " + "See the documentation of grid_sample for details." + ) + align_corners = False + + return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners) + + +def affine_grid( + theta: Tensor, + size: List[int], + align_corners: Optional[bool] = None, +) -> Tensor: + r"""Generate 2D or 3D flow field (sampling grid), given a batch of affine matrices :attr:`theta`. + + .. note:: + This function is often used in conjunction with :func:`grid_sample` + to build `Spatial Transformer Networks`_ . + + Args: + theta (Tensor): input batch of affine matrices with shape + (:math:`N \times 2 \times 3`) for 2D or + (:math:`N \times 3 \times 4`) for 3D + size (torch.Size): the target output image size. + (:math:`N \times C \times H \times W` for 2D or + :math:`N \times C \times D \times H \times W` for 3D) + Example: torch.Size((32, 3, 24, 24)) + align_corners (bool, optional): if ``True``, consider ``-1`` and ``1`` + to refer to the centers of the corner pixels rather than the image corners. + Refer to :func:`grid_sample` for a more complete description. + A grid generated by :func:`affine_grid` should be passed to :func:`grid_sample` + with the same setting for this option. + Default: ``False`` + + Returns: + output (Tensor): output Tensor of size (:math:`N \times H \times W \times 2`) + + .. _`Spatial Transformer Networks`: + https://arxiv.org/abs/1506.02025 + + .. 
warning:: + When ``align_corners = True``, the grid positions depend on the pixel + size relative to the input image size, and so the locations sampled by + :func:`grid_sample` will differ for the same input given at different + resolutions (that is, after being upsampled or downsampled). + The default behavior up to version 1.2.0 was ``align_corners = True``. + Since then, the default behavior has been changed to ``align_corners = False``, + in order to bring it in line with the default for :func:`interpolate`. + .. warning:: + When ``align_corners = True``, 2D affine transforms on 1D data and + 3D affine transforms on 2D data (that is, when one of the spatial + dimensions has unit size) are ill-defined, and not an intended use case. + This is not a problem when ``align_corners = False``. + Up to version 1.2.0, all grid points along a unit dimension were + considered arbitrarily to be at ``-1``. + From version 1.3.0, under ``align_corners = True`` all grid points + along a unit dimension are considered to be at ``0`` + (the center of the input image). + """ + if has_torch_function_unary(theta): + return handle_torch_function( + affine_grid, (theta,), theta, size, align_corners=align_corners + ) + if align_corners is None: + warnings.warn( + "Default grid_sample and affine_grid behavior has changed " + "to align_corners=False since 1.3.0. Please specify " + "align_corners=True if the old behavior is desired. " + "See the documentation of grid_sample for details." + ) + align_corners = False + + # enforce floating point dtype on theta + if not theta.is_floating_point(): + raise ValueError( + f"Expected theta to have floating point type, but got {theta.dtype}" + ) + # check that shapes and sizes match + if len(size) == 4: + if theta.dim() != 3 or theta.shape[-2] != 2 or theta.shape[-1] != 3: + raise ValueError( + f"Expected a batch of 2D affine matrices of shape Nx2x3 for size {size}. Got {theta.shape}." + ) + spatial_size = size[-2:] # spatial dimension sizes + elif len(size) == 5: + if theta.dim() != 3 or theta.shape[-2] != 3 or theta.shape[-1] != 4: + raise ValueError( + f"Expected a batch of 3D affine matrices of shape Nx3x4 for size {size}. Got {theta.shape}." + ) + spatial_size = size[-3:] # spatial dimension sizes + else: + raise NotImplementedError( + "affine_grid only supports 4D and 5D sizes, " + "for 2D and 3D affine transforms, respectively. " + f"Got size {size}." + ) + # check for empty span + if align_corners and min(spatial_size) == 1: + warnings.warn( + "Since version 1.3.0, affine_grid behavior has changed " + "for unit-size grids when align_corners=True. " + "This is not an intended use case of affine_grid. " + "See the documentation of affine_grid for details." + ) + elif min(size) <= 0: + raise ValueError(f"Expected non-zero, positive output size. Got {size}") + + return torch.affine_grid_generator(theta, size, align_corners) + + +def pad( + input: Tensor, + pad: List[int], + mode: str = "constant", + value: Optional[float] = None, +) -> Tensor: + r""" + pad(input, pad, mode="constant", value=None) -> Tensor + + Pads tensor. + + Padding size: + The padding size by which to pad some dimensions of :attr:`input` + are described starting from the last dimension and moving forward. + :math:`\left\lfloor\frac{\text{len(pad)}}{2}\right\rfloor` dimensions + of ``input`` will be padded. 
+ For example, to pad only the last dimension of the input tensor, then + :attr:`pad` has the form + :math:`(\text{padding\_left}, \text{padding\_right})`; + to pad the last 2 dimensions of the input tensor, then use + :math:`(\text{padding\_left}, \text{padding\_right},` + :math:`\text{padding\_top}, \text{padding\_bottom})`; + to pad the last 3 dimensions, use + :math:`(\text{padding\_left}, \text{padding\_right},` + :math:`\text{padding\_top}, \text{padding\_bottom}` + :math:`\text{padding\_front}, \text{padding\_back})`. + + Padding mode: + See :class:`torch.nn.CircularPad2d`, :class:`torch.nn.ConstantPad2d`, + :class:`torch.nn.ReflectionPad2d`, and :class:`torch.nn.ReplicationPad2d` + for concrete examples on how each of the padding modes works. Constant + padding is implemented for arbitrary dimensions. Circular, replicate and + reflection padding are implemented for padding the last 3 dimensions of a + 4D or 5D input tensor, the last 2 dimensions of a 3D or 4D input tensor, + or the last dimension of a 2D or 3D input tensor. + + Note: + When using the CUDA backend, this operation may induce nondeterministic + behaviour in its backward pass that is not easily switched off. + Please see the notes on :doc:`/notes/randomness` for background. + + Args: + input (Tensor): N-dimensional tensor + pad (tuple): m-elements tuple, where + :math:`\frac{m}{2} \leq` input dimensions and :math:`m` is even. + mode: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + Default: ``'constant'`` + value: fill value for ``'constant'`` padding. Default: ``0`` + + Examples:: + + >>> t4d = torch.empty(3, 3, 4, 2) + >>> p1d = (1, 1) # pad last dim by 1 on each side + >>> out = F.pad(t4d, p1d, "constant", 0) # effectively zero padding + >>> print(out.size()) + torch.Size([3, 3, 4, 4]) + >>> p2d = (1, 1, 2, 2) # pad last dim by (1, 1) and 2nd to last by (2, 2) + >>> out = F.pad(t4d, p2d, "constant", 0) + >>> print(out.size()) + torch.Size([3, 3, 8, 4]) + >>> t4d = torch.empty(3, 3, 4, 2) + >>> p3d = (0, 1, 2, 1, 3, 3) # pad by (0, 1), (2, 1), and (3, 3) + >>> out = F.pad(t4d, p3d, "constant", 0) + >>> print(out.size()) + torch.Size([3, 9, 7, 3]) + """ + if has_torch_function_unary(input): + return handle_torch_function( + torch.nn.functional.pad, (input,), input, pad, mode=mode, value=value + ) + if not torch.jit.is_scripting(): + if torch.are_deterministic_algorithms_enabled() and ( + input.is_cuda or input.is_xpu + ): + if mode == "replicate": + # Use slow decomp whose backward will be in terms of index_put. + # importlib is required because the import cannot be top level + # (cycle) and cannot be nested (TS doesn't support) + return importlib.import_module( + "torch._decomp.decompositions" + )._replication_pad(input, pad) + return torch._C._nn.pad(input, pad, mode, value) + + +# TODO: Fix via https://github.com/pytorch/pytorch/issues/75798 +pad.__module__ = "torch.nn.functional" + +# distance + + +pairwise_distance = _add_docstr( + torch.pairwise_distance, + r""" +pairwise_distance(x1, x2, p=2.0, eps=1e-6, keepdim=False) -> Tensor + +See :class:`torch.nn.PairwiseDistance` for details +""", +) + + +pdist = _add_docstr( + torch.pdist, + r""" +pdist(input, p=2) -> Tensor + +Computes the p-norm distance between every pair of row vectors in the input. +This is identical to the upper triangular portion, excluding the diagonal, of +`torch.norm(input[:, None] - input, dim=2, p=p)`. This function will be faster +if the rows are contiguous. 
+ +If input has shape :math:`N \times M` then the output will have shape +:math:`\frac{1}{2} N (N - 1)`. + +This function is equivalent to ``scipy.spatial.distance.pdist(input, +'minkowski', p=p)`` if :math:`p \in (0, \infty)`. When :math:`p = 0` it is +equivalent to ``scipy.spatial.distance.pdist(input, 'hamming') * M``. +When :math:`p = \infty`, the closest scipy function is +``scipy.spatial.distance.pdist(xn, lambda x, y: np.abs(x - y).max())``. + +Args: + input: input tensor of shape :math:`N \times M`. + p: p value for the p-norm distance to calculate between each vector pair + :math:`\in [0, \infty]`. +""", +) + + +cosine_similarity = _add_docstr( + torch.cosine_similarity, + r""" +cosine_similarity(x1, x2, dim=1, eps=1e-8) -> Tensor + +Returns cosine similarity between ``x1`` and ``x2``, computed along dim. ``x1`` and ``x2`` must be broadcastable +to a common shape. ``dim`` refers to the dimension in this common shape. Dimension ``dim`` of the output is +squeezed (see :func:`torch.squeeze`), resulting in the +output tensor having 1 fewer dimension. + +.. math :: + \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2, \epsilon) \cdot \max(\Vert x_2 \Vert _2, \epsilon)} + +Supports :ref:`type promotion `. + +Args: + x1 (Tensor): First input. + x2 (Tensor): Second input. + dim (int, optional): Dimension along which cosine similarity is computed. Default: 1 + eps (float, optional): Small value to avoid division by zero. + Default: 1e-8 + +Example:: + + >>> input1 = torch.randn(100, 128) + >>> input2 = torch.randn(100, 128) + >>> output = F.cosine_similarity(input1, input2) + >>> print(output) +""", +) + + +one_hot = _add_docstr( + torch._C._nn.one_hot, + r""" +one_hot(tensor, num_classes=-1) -> LongTensor + +Takes LongTensor with index values of shape ``(*)`` and returns a tensor +of shape ``(*, num_classes)`` that have zeros everywhere except where the +index of last dimension matches the corresponding value of the input tensor, +in which case it will be 1. + +See also `One-hot on Wikipedia`_ . + +.. _One-hot on Wikipedia: + https://en.wikipedia.org/wiki/One-hot + +Arguments: + tensor (LongTensor): class values of any shape. + num_classes (int): Total number of classes. If set to -1, the number + of classes will be inferred as one greater than the largest class + value in the input tensor. + +Returns: + LongTensor that has one more dimension with 1 values at the + index of last dimension indicated by the input, and 0 everywhere + else. + +Examples: + >>> F.one_hot(torch.arange(0, 5) % 3) + tensor([[1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [1, 0, 0], + [0, 1, 0]]) + >>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5) + tensor([[1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0]]) + >>> F.one_hot(torch.arange(0, 6).view(3,2) % 3) + tensor([[[1, 0, 0], + [0, 1, 0]], + [[0, 0, 1], + [1, 0, 0]], + [[0, 1, 0], + [0, 0, 1]]]) +""", +) + + +def triplet_margin_loss( + anchor: Tensor, + positive: Tensor, + negative: Tensor, + margin: float = 1.0, + p: float = 2, + eps: float = 1e-6, + swap: bool = False, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + r"""Compute the triplet loss between given input tensors and a margin greater than 0. + + See :class:`~torch.nn.TripletMarginLoss` for details. 
+ """ + if has_torch_function_variadic(anchor, positive, negative): + return handle_torch_function( + triplet_margin_loss, + (anchor, positive, negative), + anchor, + positive, + negative, + margin=margin, + p=p, + eps=eps, + swap=swap, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + if margin <= 0: + raise ValueError(f"margin must be greater than 0, got {margin}") + return torch.triplet_margin_loss( + anchor, positive, negative, margin, p, eps, swap, reduction_enum + ) + + +def triplet_margin_with_distance_loss( + anchor: Tensor, + positive: Tensor, + negative: Tensor, + *, + distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None, + margin: float = 1.0, + swap: bool = False, + reduction: str = "mean", +) -> Tensor: + r"""Compute the triplet margin loss for input tensors using a custom distance function. + + See :class:`~torch.nn.TripletMarginWithDistanceLoss` for details. + """ + if torch.jit.is_scripting(): + raise NotImplementedError( + "F.triplet_margin_with_distance_loss does not support JIT scripting: " + "functions requiring Callables cannot be scripted." + ) + + if has_torch_function_variadic(anchor, positive, negative): + return handle_torch_function( + triplet_margin_with_distance_loss, + (anchor, positive, negative), + anchor, + positive, + negative, + distance_function=distance_function, + margin=margin, + swap=swap, + reduction=reduction, + ) + + # Check validity of reduction mode + if reduction not in ("mean", "sum", "none"): + raise ValueError(f"{reduction} is not a valid value for reduction") + + # Check validity of margin + if margin <= 0: + raise ValueError(f"margin must be greater than 0, got {margin}") + + # Check dimensions + a_dim = anchor.ndim + p_dim = positive.ndim + n_dim = negative.ndim + if not (a_dim == p_dim and p_dim == n_dim): + raise RuntimeError( + f"The anchor, positive, and negative tensors are expected to have " + f"the same number of dimensions, but got: anchor {a_dim}D, " + f"positive {p_dim}D, and negative {n_dim}D inputs" + ) + + # Calculate loss + if distance_function is None: + distance_function = torch.pairwise_distance + + dist_pos = distance_function(anchor, positive) + dist_neg = distance_function(anchor, negative) + # The distance swap is described in the paper "Learning shallow + # convolutional feature descriptors with triplet losses" by V. Balntas, E. + # Riba et al. If True, and if the positive example is closer to the + # negative example than the anchor is, swaps the positive example and the + # anchor in the loss computation. + if swap: + dist_swap = distance_function(positive, negative) + dist_neg = torch.minimum(dist_neg, dist_swap) + loss = torch.clamp_min(margin + dist_pos - dist_neg, 0) + + # Apply reduction + if reduction == "sum": + return torch.sum(loss) + elif reduction == "mean": + return torch.mean(loss) + else: # reduction == "none" + return loss + + +def normalize( + input: Tensor, + p: float = 2.0, + dim: int = 1, + eps: float = 1e-12, + out: Optional[Tensor] = None, +) -> Tensor: + r"""Perform :math:`L_p` normalization of inputs over specified dimension. + + For a tensor :attr:`input` of sizes :math:`(n_0, ..., n_{dim}, ..., n_k)`, each + :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`dim` is transformed as + + .. math:: + v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}. 
+ + With the default arguments it uses the Euclidean norm over vectors along dimension :math:`1` for normalization. + + Args: + input: input tensor of any shape + p (float): the exponent value in the norm formulation. Default: 2 + dim (int or tuple of ints): the dimension to reduce. Default: 1 + eps (float): small value to avoid division by zero. Default: 1e-12 + out (Tensor, optional): the output tensor. If :attr:`out` is used, this + operation won't be differentiable. + """ + if has_torch_function_variadic(input, out): + return handle_torch_function( + normalize, (input, out), input, p=p, dim=dim, eps=eps, out=out + ) + if out is None: + denom = input.norm(p, dim, keepdim=True).clamp_min(eps).expand_as(input) + return input / denom + else: + denom = input.norm(p, dim, keepdim=True).clamp_min_(eps).expand_as(input) + return torch.div(input, denom, out=out) + + +def assert_int_or_pair(arg: List[int], arg_name: str, message: str) -> None: + assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name) + + +def unfold( + input: Tensor, + kernel_size: BroadcastingList2[int], + dilation: BroadcastingList2[int] = 1, + padding: BroadcastingList2[int] = 0, + stride: BroadcastingList2[int] = 1, +) -> Tensor: + r"""Extract sliding local blocks from a batched input tensor. + + .. warning:: + Currently, only 4-D input tensors (batched image-like tensors) are + supported. + + .. warning:: + + More than one element of the unfolded tensor may refer to a single + memory location. As a result, in-place operations (especially ones that + are vectorized) may result in incorrect behavior. If you need to write + to the tensor, please clone it first. + + + See :class:`torch.nn.Unfold` for details + """ + if has_torch_function_unary(input): + return handle_torch_function( + unfold, + (input,), + input, + kernel_size, + dilation=dilation, + padding=padding, + stride=stride, + ) + return torch._C._nn.im2col( + input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride) + ) + + +def fold( + input: Tensor, + output_size: BroadcastingList2[int], + kernel_size: BroadcastingList2[int], + dilation: BroadcastingList2[int] = 1, + padding: BroadcastingList2[int] = 0, + stride: BroadcastingList2[int] = 1, +) -> Tensor: + r"""Combine an array of sliding local blocks into a large containing tensor. + + .. warning:: + Currently, only unbatched (3D) or batched (4D) image-like output tensors are supported. + + See :class:`torch.nn.Fold` for details + """ + if has_torch_function_unary(input): + return handle_torch_function( + fold, + (input,), + input, + output_size, + kernel_size, + dilation=dilation, + padding=padding, + stride=stride, + ) + return torch._C._nn.col2im( + input, + _pair(output_size), + _pair(kernel_size), + _pair(dilation), + _pair(padding), + _pair(stride), + ) + + +# +# multihead attention +# + + +def _in_projection_packed( + q: Tensor, + k: Tensor, + v: Tensor, + w: Tensor, + b: Optional[Tensor] = None, +) -> List[Tensor]: + r"""Perform the in-projection step of the attention operation, using packed weights. + + Output is a triple containing projection tensors for query, key and value. + + Args: + q, k, v: query, key and value tensors to be projected. For self-attention, + these are typically the same tensor; for encoder-decoder attention, + k and v are typically the same tensor. (We take advantage of these + identities for performance if they are present.) Regardless, q, k and v + must share a common embedding dimension; otherwise their shapes may vary. 
+ w: projection weights for q, k and v, packed into a single tensor. Weights + are packed along dimension 0, in q, k, v order. + b: optional projection biases for q, k and v, packed into a single tensor + in q, k, v order. + + Shape: + Inputs: + - q: :math:`(..., E)` where E is the embedding dimension + - k: :math:`(..., E)` where E is the embedding dimension + - v: :math:`(..., E)` where E is the embedding dimension + - w: :math:`(E * 3, E)` where E is the embedding dimension + - b: :math:`E * 3` where E is the embedding dimension + + Output: + - in output list :math:`[q', k', v']`, each output tensor will have the + same shape as the corresponding input tensor. + """ + E = q.size(-1) + if k is v: + if q is k: + # self-attention + proj = linear(q, w, b) + # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk() + proj = ( + proj.unflatten(-1, (3, E)) + .unsqueeze(0) + .transpose(0, -2) + .squeeze(-2) + .contiguous() + ) + return proj[0], proj[1], proj[2] + else: + # encoder-decoder attention + w_q, w_kv = w.split([E, E * 2]) + if b is None: + b_q = b_kv = None + else: + b_q, b_kv = b.split([E, E * 2]) + q_proj = linear(q, w_q, b_q) + kv_proj = linear(k, w_kv, b_kv) + # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk() + kv_proj = ( + kv_proj.unflatten(-1, (2, E)) + .unsqueeze(0) + .transpose(0, -2) + .squeeze(-2) + .contiguous() + ) + return (q_proj, kv_proj[0], kv_proj[1]) + else: + w_q, w_k, w_v = w.chunk(3) + if b is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = b.chunk(3) + return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) + + +def _in_projection( + q: Tensor, + k: Tensor, + v: Tensor, + w_q: Tensor, + w_k: Tensor, + w_v: Tensor, + b_q: Optional[Tensor] = None, + b_k: Optional[Tensor] = None, + b_v: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor, Tensor]: + r"""Perform the in-projection step of the attention operation. + + This is simply a triple of linear projections, + with shape constraints on the weights which + ensure embedding dimension uniformity in the projected outputs. + Output is a triple containing projection tensors for query, key and value. + + Args: + q, k, v: query, key and value tensors to be projected. + w_q, w_k, w_v: weights for q, k and v, respectively. + b_q, b_k, b_v: optional biases for q, k and v, respectively. + + Shape: + Inputs: + - q: :math:`(Qdims..., Eq)` where Eq is the query embedding dimension and Qdims are any + number of leading dimensions. + - k: :math:`(Kdims..., Ek)` where Ek is the key embedding dimension and Kdims are any + number of leading dimensions. + - v: :math:`(Vdims..., Ev)` where Ev is the value embedding dimension and Vdims are any + number of leading dimensions. 
+ - w_q: :math:`(Eq, Eq)` + - w_k: :math:`(Eq, Ek)` + - w_v: :math:`(Eq, Ev)` + - b_q: :math:`(Eq)` + - b_k: :math:`(Eq)` + - b_v: :math:`(Eq)` + + Output: in output triple :math:`(q', k', v')`, + - q': :math:`[Qdims..., Eq]` + - k': :math:`[Kdims..., Eq]` + - v': :math:`[Vdims..., Eq]` + + """ + Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1) + assert w_q.shape == ( + Eq, + Eq, + ), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}" + assert w_k.shape == ( + Eq, + Ek, + ), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}" + assert w_v.shape == ( + Eq, + Ev, + ), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}" + assert b_q is None or b_q.shape == ( + Eq, + ), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}" + assert b_k is None or b_k.shape == ( + Eq, + ), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}" + assert b_v is None or b_v.shape == ( + Eq, + ), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}" + return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) + + +scaled_dot_product_attention = _add_docstr( + torch._C._nn.scaled_dot_product_attention, + r"""scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, + is_causal=False, scale=None, enable_gqa=False) -> Tensor: + + Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed, + and applying dropout if a probability greater than 0.0 is specified. The optional scale argument can only be + specified as a keyword argument. + + .. code-block:: python + + # Efficient implementation equivalent to the following: + def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, + is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor: + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_bias = torch.zeros(L, S, dtype=query.dtype) + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + + if enable_gqa: + key = key.repeat_interleave(query.size(-3)//key.size(-3), -3) + value = value.repeat_interleave(query.size(-3)//value.size(-3), -3) + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + return attn_weight @ value + + .. warning:: + This function is beta and subject to change. + + .. warning:: + This function always applies dropout according to the specified ``dropout_p`` argument. + To disable dropout during evaluation, be sure to pass a value of ``0.0`` when the module + that makes the function call is not in training mode. + + For example: + + .. 
code-block:: python + + class MyModel(nn.Module): + def __init__(self, p=0.5): + super().__init__() + self.p = p + + def forward(self, ...): + return F.scaled_dot_product_attention(..., + dropout_p=(self.p if self.training else 0.0)) + + Note: + + There are currently three supported implementations of scaled dot product attention: + + - `FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning`_ + - `Memory-Efficient Attention`_ + - A PyTorch implementation defined in C++ matching the above formulation + + The function may call optimized kernels for improved performance when using the CUDA backend. + For all other backends, the PyTorch implementation will be used. + + All implementations are enabled by default. Scaled dot product attention attempts to automatically select the + most optimal implementation based on the inputs. In order to provide more fine-grained control over what implementation + is used, the following functions are provided for enabling and disabling implementations. + The context manager is the preferred mechanism: + + - :func:`torch.nn.attention.sdpa_kernel`: A context manager used to enable or disable any of the implementations. + - :func:`torch.backends.cuda.enable_flash_sdp`: Globally enables or disables FlashAttention. + - :func:`torch.backends.cuda.enable_mem_efficient_sdp`: Globally enables or disables Memory-Efficient Attention. + - :func:`torch.backends.cuda.enable_math_sdp`: Globally enables or disables the PyTorch C++ implementation. + + Each of the fused kernels has specific input limitations. If the user requires the use of a specific fused implementation, + disable the PyTorch C++ implementation using :func:`torch.nn.attention.sdpa_kernel`. + In the event that a fused implementation is not available, a warning will be raised with the + reasons why the fused implementation cannot run. + + Due to the nature of fusing floating point operations, the output of this function may be different + depending on what backend kernel is chosen. + The c++ implementation supports torch.float64 and can be used when higher precision is required. + For math backend, all intermediates are kept in torch.float if inputs are in torch.half or torch.bfloat16. + For more information please see :doc:`/notes/numerical_accuracy` + + Grouped Query Attention (GQA) is an experimental feature. It currently works only for Flash_attention + and math kernel on CUDA tensor, and does not support Nested tensor. + Constraints for GQA: + + - number_of_heads_query % number_of_heads_key_value == 0 and, + - number_of_heads_key == number_of_heads_value + + Note: + + {cudnn_reproducibility_note} + """.format( + **reproducibility_notes + ) + + r""" + Args: + query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`. + key (Tensor): Key tensor; shape :math:`(N, ..., H, S, E)`. + value (Tensor): Value tensor; shape :math:`(N, ..., H, S, Ev)`. + attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights, + which is :math:`(N,..., L, S)`. Two types of masks are supported. + A boolean mask where a value of True indicates that the element *should* take part in attention. + A float mask of the same type as query, key, value that is added to the attention score. + dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied + is_causal (bool): If set to true, the attention masking is a lower triangular matrix when the mask is a + square matrix. 
The attention masking has the form of the upper left causal bias due to the alignment + (see :class:`torch.nn.attention.bias.CausalBias`) when the mask is a non-square matrix. + An error is thrown if both attn_mask and is_causal are set. + scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set + to :math:`\frac{1}{\sqrt{E}}`. + enable_gqa (bool): If set to True, Grouped Query Attention (GQA) is enabled, by default it is set to False. + + Returns: + output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`. + + Shape legend: + - :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}` + - :math:`S: \text{Source sequence length}` + - :math:`L: \text{Target sequence length}` + - :math:`E: \text{Embedding dimension of the query and key}` + - :math:`Ev: \text{Embedding dimension of the value}` + - :math:`Hq: \text{Number of heads of query}` + - :math:`H: \text{Number of heads of key and value}` + + Examples: + + >>> # Optionally use the context manager to ensure one of the fused kernels is run + >>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") + >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") + >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") + >>> with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): + >>> F.scaled_dot_product_attention(query,key,value) + + + >>> # Sample for GQA for llama3 + >>> query = torch.rand(32, 32, 128, 64, dtype=torch.float16, device="cuda") + >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") + >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") + >>> with sdpa_kernel(backends=[SDPBackend.MATH]): + >>> F.scaled_dot_product_attention(query,key,value,enable_gqa=True) + + + .. _FlashAttention-2\: Faster Attention with Better Parallelism and Work Partitioning: + https://arxiv.org/abs/2307.08691 + .. _Memory-Efficient Attention: + https://github.com/facebookresearch/xformers + .. _Grouped-Query Attention: + https://arxiv.org/pdf/2305.13245 + """, +) + + +def _mha_shape_check( + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + num_heads: int, +): + # Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask` + # and returns if the input is batched or not. + # Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor. + + # Shape check. 
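    # Shape contract checked below: a 3-D ``query`` means batched inputs
    # (``key``/``value`` must be 3-D, ``key_padding_mask`` 2-D, ``attn_mask``
    # 2-D or 3-D); a 2-D ``query`` means unbatched inputs (``key``/``value``
    # 2-D, ``key_padding_mask`` 1-D, ``attn_mask`` still 2-D or 3-D).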
+ if query.dim() == 3: + # Batched Inputs + is_batched = True + assert key.dim() == 3 and value.dim() == 3, ( + "For batched (3-D) `query`, expected `key` and `value` to be 3-D" + f" but found {key.dim()}-D and {value.dim()}-D tensors respectively" + ) + if key_padding_mask is not None: + assert key_padding_mask.dim() == 2, ( + "For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D" + f" but found {key_padding_mask.dim()}-D tensor instead" + ) + if attn_mask is not None: + assert attn_mask.dim() in (2, 3), ( + "For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D" + f" but found {attn_mask.dim()}-D tensor instead" + ) + elif query.dim() == 2: + # Unbatched Inputs + is_batched = False + assert key.dim() == 2 and value.dim() == 2, ( + "For unbatched (2-D) `query`, expected `key` and `value` to be 2-D" + f" but found {key.dim()}-D and {value.dim()}-D tensors respectively" + ) + + if key_padding_mask is not None: + assert key_padding_mask.dim() == 1, ( + "For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D" + f" but found {key_padding_mask.dim()}-D tensor instead" + ) + + if attn_mask is not None: + assert attn_mask.dim() in (2, 3), ( + "For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D" + f" but found {attn_mask.dim()}-D tensor instead" + ) + if attn_mask.dim() == 3: + expected_shape = (num_heads, query.shape[0], key.shape[0]) + assert ( + attn_mask.shape == expected_shape + ), f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}" + else: + raise AssertionError( + f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor" + ) + + return is_batched + + +def _canonical_mask( + mask: Optional[Tensor], + mask_name: str, + other_type: Optional[DType], + other_name: str, + target_type: DType, + check_other: bool = True, +) -> Optional[Tensor]: + if mask is not None: + _mask_dtype = mask.dtype + _mask_is_float = torch.is_floating_point(mask) + if _mask_dtype != torch.bool and not _mask_is_float: + raise AssertionError( + f"only bool and floating types of {mask_name} are supported" + ) + if check_other and other_type is not None: + if _mask_dtype != other_type: + warnings.warn( + f"Support for mismatched {mask_name} and {other_name} " + "is deprecated. Use same type for both instead." 
+ ) + if not _mask_is_float: + mask = torch.zeros_like(mask, dtype=target_type).masked_fill_( + mask, float("-inf") + ) + return mask + + +def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]: + if input is None: + return None + elif isinstance(input, torch.Tensor): + return input.dtype + raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor") + + +def multi_head_attention_forward( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Optional[Tensor], + in_proj_bias: Optional[Tensor], + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Optional[Tensor], + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, + average_attn_weights: bool = True, + is_causal: bool = False, +) -> Tuple[Tensor, Optional[Tensor]]: + r"""Forward method for MultiHeadAttention. + + See :class:`torch.nn.MultiheadAttention` for details. + + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + bias_k, bias_v: bias of the key and value sequences to be added at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + Default: `True` + Note: `needs_weight` defaults to `True`, but should be set to `False` + For best performance when attention weights are not needed. + *Setting needs_weights to `True` + leads to a significant performance degradation.* + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + is_causal: If specified, applies a causal mask as attention mask, and ignores + attn_mask for computing scaled dot product attention. + Default: ``False``. + .. warning:: + is_causal is provides a hint that the attn_mask is the + causal mask.Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + use_separate_proj_weight: the function accept the proj. weights for query, key, + and value in different forms. If false, in_proj_weight will be used, which is + a combination of q_proj_weight, k_proj_weight, v_proj_weight. + q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. + static_k, static_v: static key and value used for attention operators. + average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads. 
+ Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect + when ``need_weights=True.``. Default: True + + + Shape: + Inputs: + - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a FloatTensor is provided, it will be directly added to the value. + If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + + Outputs: + - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns + attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`. 
+ """ + tens_ops = ( + query, + key, + value, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + out_proj_weight, + out_proj_bias, + ) + if has_torch_function(tens_ops): + return handle_torch_function( + multi_head_attention_forward, + tens_ops, + query, + key, + value, + embed_dim_to_check, + num_heads, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + add_zero_attn, + dropout_p, + out_proj_weight, + out_proj_bias, + training=training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + is_causal=is_causal, + use_separate_proj_weight=use_separate_proj_weight, + q_proj_weight=q_proj_weight, + k_proj_weight=k_proj_weight, + v_proj_weight=v_proj_weight, + static_k=static_k, + static_v=static_v, + average_attn_weights=average_attn_weights, + ) + + is_batched = _mha_shape_check( + query, key, value, key_padding_mask, attn_mask, num_heads + ) + + # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input + # is batched, run the computation and before returning squeeze the + # batch dimension so that the output doesn't carry this temporary batch dimension. + if not is_batched: + # unsqueeze if the input is unbatched + query = query.unsqueeze(1) + key = key.unsqueeze(1) + value = value.unsqueeze(1) + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.unsqueeze(0) + + # set up shape vars + tgt_len, bsz, embed_dim = query.shape + src_len, _, _ = key.shape + + key_padding_mask = _canonical_mask( + mask=key_padding_mask, + mask_name="key_padding_mask", + other_type=_none_or_dtype(attn_mask), + other_name="attn_mask", + target_type=query.dtype, + ) + + if is_causal and attn_mask is None: + raise RuntimeError( + "Need attn_mask if specifying the is_causal hint. " + "You may use the Transformer module method " + "`generate_square_subsequent_mask` to create this mask." + ) + + if is_causal and key_padding_mask is None and not need_weights: + # when we have a kpm or need weights, we need attn_mask + # Otherwise, we use the is_causal hint go as is_causal + # indicator to SDPA. + attn_mask = None + else: + attn_mask = _canonical_mask( + mask=attn_mask, + mask_name="attn_mask", + other_type=None, + other_name="", + target_type=query.dtype, + check_other=False, + ) + + if key_padding_mask is not None: + # We have the attn_mask, and use that to merge kpm into it. + # Turn off use of is_causal hint, as the merged mask is no + # longer causal. 
+ is_causal = False + + assert ( + embed_dim == embed_dim_to_check + ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" + if isinstance(embed_dim, torch.Tensor): + # embed_dim can be a tensor when JIT tracing + head_dim = embed_dim.div(num_heads, rounding_mode="trunc") + else: + head_dim = embed_dim // num_heads + assert ( + head_dim * num_heads == embed_dim + ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" + if use_separate_proj_weight: + # allow MHA to have different embedding dimensions when separate projection weights are used + assert ( + key.shape[:2] == value.shape[:2] + ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + else: + assert ( + key.shape == value.shape + ), f"key shape {key.shape} does not match value shape {value.shape}" + + # + # compute in-projection + # + if not use_separate_proj_weight: + assert ( + in_proj_weight is not None + ), "use_separate_proj_weight is False but in_proj_weight is None" + q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias) + else: + assert ( + q_proj_weight is not None + ), "use_separate_proj_weight is True but q_proj_weight is None" + assert ( + k_proj_weight is not None + ), "use_separate_proj_weight is True but k_proj_weight is None" + assert ( + v_proj_weight is not None + ), "use_separate_proj_weight is True but v_proj_weight is None" + if in_proj_bias is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = in_proj_bias.chunk(3) + q, k, v = _in_projection( + query, + key, + value, + q_proj_weight, + k_proj_weight, + v_proj_weight, + b_q, + b_k, + b_v, + ) + + # prep attention mask + + if attn_mask is not None: + # ensure attn_mask's dim is 3 + if attn_mask.dim() == 2: + correct_2d_size = (tgt_len, src_len) + if attn_mask.shape != correct_2d_size: + raise RuntimeError( + f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}." + ) + attn_mask = attn_mask.unsqueeze(0) + elif attn_mask.dim() == 3: + correct_3d_size = (bsz * num_heads, tgt_len, src_len) + if attn_mask.shape != correct_3d_size: + raise RuntimeError( + f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}." + ) + else: + raise RuntimeError( + f"attn_mask's dimension {attn_mask.dim()} is not supported" + ) + + # add bias along batch dimension (currently second) + if bias_k is not None and bias_v is not None: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." 
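        # Note: ``bias_k``/``bias_v`` add one extra key/value position per batch
        # element, so ``attn_mask`` and ``key_padding_mask`` are each padded by
        # one trailing column below to keep the source-length dimension in sync.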
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + else: + assert bias_k is None + assert bias_v is None + + # + # reshape q, k, v for multihead attention and make them batch first + # + q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) + if static_k is None: + k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) + else: + # TODO finish disentangling control flow so we don't do in-projections when statics are passed + assert ( + static_k.size(0) == bsz * num_heads + ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}" + assert ( + static_k.size(2) == head_dim + ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" + k = static_k + if static_v is None: + v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) + else: + # TODO finish disentangling control flow so we don't do in-projections when statics are passed + assert ( + static_v.size(0) == bsz * num_heads + ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}" + assert ( + static_v.size(2) == head_dim + ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" + v = static_v + + # add zero attention along batch dimension (now first) + if add_zero_attn: + zero_attn_shape = (bsz * num_heads, 1, head_dim) + k = torch.cat( + [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1 + ) + v = torch.cat( + [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1 + ) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + + # update source sequence length after adjustments + src_len = k.size(1) + + # merge key padding and attention masks + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + bsz, + src_len, + ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" + key_padding_mask = ( + key_padding_mask.view(bsz, 1, 1, src_len) + .expand(-1, num_heads, -1, -1) + .reshape(bsz * num_heads, 1, src_len) + ) + if attn_mask is None: + attn_mask = key_padding_mask + else: + attn_mask = attn_mask + key_padding_mask + + # adjust dropout probability + if not training: + dropout_p = 0.0 + + # + # (deep breath) calculate attention and out projection + # + + if need_weights: + B, Nt, E = q.shape + q_scaled = q * math.sqrt(1.0 / float(E)) + + assert not ( + is_causal and attn_mask is None + ), "FIXME: is_causal not implemented for need_weights" + + if attn_mask is not None: + attn_output_weights = torch.baddbmm( + attn_mask, q_scaled, k.transpose(-2, -1) + ) + else: + attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1)) + attn_output_weights = softmax(attn_output_weights, dim=-1) + if dropout_p > 0.0: + attn_output_weights = dropout(attn_output_weights, p=dropout_p) + + attn_output = torch.bmm(attn_output_weights, v) + + attn_output = ( + attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim) + ) + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) + + # optionally average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + if average_attn_weights: + attn_output_weights = 
attn_output_weights.mean(dim=1) + + if not is_batched: + # squeeze the output if input was unbatched + attn_output = attn_output.squeeze(1) + attn_output_weights = attn_output_weights.squeeze(0) + return attn_output, attn_output_weights + else: + # attn_mask can be either (L,S) or (N*num_heads, L, S) + # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S) + # in order to match the input for SDPA of (N, num_heads, L, S) + if attn_mask is not None: + if attn_mask.size(0) == 1 and attn_mask.dim() == 3: + attn_mask = attn_mask.unsqueeze(0) + else: + attn_mask = attn_mask.view(bsz, num_heads, -1, src_len) + + q = q.view(bsz, num_heads, tgt_len, head_dim) + k = k.view(bsz, num_heads, src_len, head_dim) + v = v.view(bsz, num_heads, src_len, head_dim) + + attn_output = scaled_dot_product_attention( + q, k, v, attn_mask, dropout_p, is_causal + ) + attn_output = ( + attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim) + ) + + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) + if not is_batched: + # squeeze the output if input was unbatched + attn_output = attn_output.squeeze(1) + return attn_output, None diff --git a/lib/python3.10/site-packages/torch/nn/functional.pyi b/lib/python3.10/site-packages/torch/nn/functional.pyi new file mode 100644 index 0000000000000000000000000000000000000000..97186614cff810f7df1c4128ac586323fad87861 --- /dev/null +++ b/lib/python3.10/site-packages/torch/nn/functional.pyi @@ -0,0 +1,691 @@ +# @generated by tools/pyi/gen_pyi.py from torch/nn/functional.pyi.in +# mypy: allow-untyped-defs + +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + overload, + Sequence, + Tuple, + Union, +) + +from torch import Tensor +from torch.types import _dtype, _int, _size + +from .common_types import ( + _ratio_any_t, + _size_1_t, + _size_2_opt_t, + _size_2_t, + _size_3_opt_t, + _size_3_t, + _size_any_t, +) + +# 'TypedDict' is a new accepted type that represents a dictionary with a fixed set of allowed keys. +# It is standards-track but not in `typing` yet. We leave this hear to be uncommented once the feature +# is wide-spread. + +# from mypy_extensions import TypedDict + +# GRID_SAMPLE_INTERPOLATION_MODES = TypedDict('GRID_SAMPLE_INTERPOLATION_MODES', {'bilinear': int, 'nearest': int}) +# GRID_SAMPLE_PADDING_MODES = TypedDict('GRID_SAMPLE_PADDING_MODES', {'zeros': int, 'border': int, 'reflection': int}) + +GRID_SAMPLE_INTERPOLATION_MODES = Dict[str, int] +GRID_SAMPLE_PADDING_MODES = Dict[str, int] + +# These stubs were generated by running stubgen (`stubgen --parse-only functional.py`), followed by manual cleaning. +# +# The 'BroadcastingList{1,2,3}' types were replaced by `_size` or _output_ratio, as appropriate. +# This was necessary since the JIT uses BroadcastingList* types but static checking with mypy etc requires a `Sequence` +# type. There is no way to express the expected lengths of these lists in the current Python typing system. +# +# Functions created via `_add_docstr` in `functional.py` where merely typed as `Any` by `stubgen`, so those were +# deleted from the stub and replaced by generated declarations. See `gen_pyi` for the implementation of the code +# generation logic for those functions. In the future, it might be worth looking into using the mypy plugin system +# to encode the type semantics of `_add_docstr`, should that system ever become widespread. 
+def fractional_max_pool2d_with_indices( + input: Tensor, + kernel_size: _size, + output_size: Optional[_size] = ..., + output_ratio: Optional[_ratio_any_t] = ..., + return_indices: bool = ..., + _random_samples: Optional[Tensor] = ..., +) -> Tuple[Tensor, Tensor]: ... +def fractional_max_pool3d_with_indices( + input: Tensor, + kernel_size: _size, + output_size: Optional[_size] = ..., + output_ratio: Optional[_ratio_any_t] = ..., + return_indices: bool = ..., + _random_samples: Optional[Tensor] = ..., +) -> Tuple[Tensor, Tensor]: ... +def max_pool1d_with_indices( + input: Tensor, + kernel_size: _size, + stride: Optional[_size] = ..., + padding: _size = ..., + dilation: _size = ..., + ceil_mode: bool = ..., + return_indices: bool = ..., +) -> Tuple[Tensor, Tensor]: ... +def max_pool2d_with_indices( + input: Tensor, + kernel_size: _size, + stride: Optional[_size] = ..., + padding: _size = ..., + dilation: _size = ..., + ceil_mode: bool = ..., + return_indices: bool = ..., +) -> Tuple[Tensor, Tensor]: ... +def max_pool3d_with_indices( + input: Tensor, + kernel_size: _size, + stride: Optional[_size] = ..., + padding: _size = ..., + dilation: _size = ..., + ceil_mode: bool = ..., + return_indices: bool = ..., +) -> Tuple[Tensor, Tensor]: ... +def max_unpool1d( + input: Tensor, + indices: Tensor, + kernel_size: _size, + stride: Optional[_size] = ..., + padding: _size = ..., + output_size: Optional[_size] = ..., +) -> Tensor: ... +def max_unpool2d( + input: Tensor, + indices: Tensor, + kernel_size: _size, + stride: Optional[_size] = ..., + padding: _size = ..., + output_size: Optional[_size] = ..., +) -> Tensor: ... +def max_unpool3d( + input: Tensor, + indices: Tensor, + kernel_size: _size, + stride: Optional[_size] = ..., + padding: _size = ..., + output_size: Optional[_size] = ..., +) -> Tensor: ... +def lp_pool1d( + input: Tensor, + norm_type: float, + kernel_size: _size_1_t, + stride: Union[Optional[_size], Optional[int]] = ..., + ceil_mode: bool = ..., +) -> Tensor: ... +def lp_pool2d( + input: Tensor, + norm_type: float, + kernel_size: _size_2_t, + stride: Union[Optional[_size], Optional[int]] = ..., + ceil_mode: bool = ..., +) -> Tensor: ... +def lp_pool3d( + input: Tensor, + norm_type: float, + kernel_size: _size_3_t, + stride: Union[Optional[_size], Optional[int]] = ..., + ceil_mode: bool = ..., +) -> Tensor: ... +def adaptive_max_pool1d_with_indices( + input: Tensor, + output_size: _size, + return_indices: bool = ..., +) -> Tuple[Tensor, Tensor]: ... +def adaptive_max_pool2d_with_indices( + input: Tensor, + output_size: _size_2_opt_t, + return_indices: bool = ..., +) -> Tuple[Tensor, Tensor]: ... +def adaptive_max_pool3d_with_indices( + input: Tensor, + output_size: _size_3_opt_t, + return_indices: bool = ..., +) -> Tuple[Tensor, Tensor]: ... +def adaptive_avg_pool2d(input: Tensor, output_size: _size_2_opt_t) -> Tensor: ... +def adaptive_avg_pool3d(input: Tensor, output_size: _size_3_opt_t) -> Tensor: ... +def dropout( + input: Tensor, + p: float = ..., + training: bool = ..., + inplace: bool = ..., +) -> Tensor: ... +def alpha_dropout( + input: Tensor, + p: float = ..., + training: bool = ..., + inplace: bool = ..., +) -> Tensor: ... +def dropout1d( + input: Tensor, + p: float = ..., + training: bool = ..., + inplace: bool = ..., +) -> Tensor: ... +def dropout2d( + input: Tensor, + p: float = ..., + training: bool = ..., + inplace: bool = ..., +) -> Tensor: ... +def dropout3d( + input: Tensor, + p: float = ..., + training: bool = ..., + inplace: bool = ..., +) -> Tensor: ... 
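A rough usage sketch (assuming only a standard ``torch`` install) of how the pooling declarations above fit together: asking ``max_pool2d`` for indices dispatches to the ``*_with_indices`` variant, and those indices feed ``max_unpool2d`` to scatter the pooled values back to the original spatial size.

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
# return_indices=True makes the call resolve to the *_with_indices form
pooled, indices = F.max_pool2d(x, kernel_size=2, return_indices=True)
restored = F.max_unpool2d(pooled, indices, kernel_size=2)
print(pooled.shape)    # torch.Size([1, 3, 4, 4])
print(restored.shape)  # torch.Size([1, 3, 8, 8])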
+def feature_alpha_dropout( + input: Tensor, + p: float = ..., + training: bool = ..., + inplace: bool = ..., +) -> Tensor: ... +def threshold( + input: Tensor, + threshold: float, + value: float, + inplace: bool = ..., +) -> Tensor: ... +def relu(input: Tensor, inplace: bool = ...) -> Tensor: ... +def glu(input: Tensor, dim: int = ...) -> Tensor: ... +def hardtanh( + input: Tensor, + min_val: float = ..., + max_val: float = ..., + inplace: bool = ..., +) -> Tensor: ... +def relu6(input: Tensor, inplace: bool = ...) -> Tensor: ... +def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ... +def selu(input: Tensor, inplace: bool = ...) -> Tensor: ... +def celu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ... +def leaky_relu( + input: Tensor, + negative_slope: float = ..., + inplace: bool = ..., +) -> Tensor: ... +def rrelu( + input: Tensor, + lower: float = ..., + upper: float = ..., + training: bool = ..., + inplace: bool = ..., +) -> Tensor: ... +def tanhshrink(input: Any): ... +def softsign(input: Any): ... +def softmin( + input: Tensor, + dim: Optional[int] = ..., + _stacklevel: int = ..., + dtype: Optional[_dtype] = ..., +) -> Tensor: ... +def softmax( + input: Tensor, + dim: Optional[int] = ..., + _stacklevel: int = ..., + dtype: Optional[_dtype] = ..., +) -> Tensor: ... +def gumbel_softmax( + logits: Tensor, + tau: float = ..., + hard: bool = ..., + eps: float = ..., + dim: int = ..., +) -> Tensor: ... +def log_softmax( + input: Tensor, + dim: Optional[int] = ..., + _stacklevel: int = ..., + dtype: Optional[_dtype] = ..., +) -> Tensor: ... +def tanh(input: Any): ... +def sigmoid(input: Any) -> Tensor: ... +def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: ... +def silu(input: Tensor, inplace: bool = False) -> Tensor: ... +def mish(input: Tensor, inplace: bool = False) -> Tensor: ... +def hardswish(input: Tensor, inplace: bool = False) -> Tensor: ... +def embedding( + input: Tensor, + weight: Tensor, + padding_idx: Optional[int] = ..., + max_norm: Optional[float] = ..., + norm_type: float = ..., + scale_grad_by_freq: bool = ..., + sparse: bool = ..., +) -> Tensor: ... +def embedding_bag( + input: Tensor, + weight: Tensor, + offsets: Optional[Tensor] = ..., + max_norm: Optional[float] = ..., + norm_type: float = ..., + scale_grad_by_freq: bool = ..., + mode: str = ..., + sparse: bool = ..., + per_sample_weights: Optional[Tensor] = ..., + include_last_offset: bool = ..., + padding_idx: Optional[int] = ..., +) -> Tensor: ... +def batch_norm( + input: Tensor, + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + weight: Optional[Tensor] = ..., + bias: Optional[Tensor] = ..., + training: bool = ..., + momentum: float = ..., + eps: float = ..., +) -> Tensor: ... +def instance_norm( + input: Tensor, + running_mean: Optional[Tensor] = ..., + running_var: Optional[Tensor] = ..., + weight: Optional[Tensor] = ..., + bias: Optional[Tensor] = ..., + use_input_stats: bool = ..., + momentum: float = ..., + eps: float = ..., +) -> Tensor: ... +def layer_norm( + input: Tensor, + normalized_shape: Sequence[int], + weight: Optional[Tensor] = ..., + bias: Optional[Tensor] = ..., + eps: float = ..., +) -> Tensor: ... +def rms_norm( + input: Tensor, + normalized_shape: Sequence[int], + weight: Optional[Tensor] = ..., + eps: Optional[float] = ..., +) -> Tensor: ... +def group_norm( + input: Tensor, + num_groups: int, + weight: Optional[Tensor] = ..., + bias: Optional[Tensor] = ..., + eps: float = ..., +) -> Tensor: ... 
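A small sketch (assuming a standard ``torch`` install) for the normalization stubs above: ``layer_norm`` normalizes over the trailing ``normalized_shape`` dimensions using a biased variance estimate and the default ``eps`` of 1e-5, which the manual computation below mirrors.

import torch
import torch.nn.functional as F

x = torch.randn(4, 10, 32)
out = F.layer_norm(x, normalized_shape=(32,))
# equivalent manual computation over the last dimension
manual = (x - x.mean(-1, keepdim=True)) / torch.sqrt(
    x.var(-1, keepdim=True, unbiased=False) + 1e-5
)
print(torch.allclose(out, manual, atol=1e-5))  # True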
+def local_response_norm( + input: Tensor, + size: int, + alpha: float = ..., + beta: float = ..., + k: float = ..., +) -> Tensor: ... +def ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: Tensor, + target_lengths: Tensor, + blank: int = ..., + reduction: str = ..., + zero_infinity: bool = ..., +) -> Tensor: ... +def nll_loss( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = ..., + size_average: Optional[bool] = ..., + ignore_index: int = ..., + reduce: Optional[bool] = ..., + reduction: str = ..., +) -> Tensor: ... +def poisson_nll_loss( + input: Tensor, + target: Tensor, + log_input: bool = ..., + full: bool = ..., + size_average: Optional[bool] = ..., + eps: float = ..., + reduce: Optional[bool] = ..., + reduction: str = ..., +) -> Tensor: ... +def gaussian_nll_loss( + input: Tensor, + target: Tensor, + var: Tensor, + full: Optional[bool] = ..., + eps: Optional[float] = ..., + reduction: Optional[str] = ..., +) -> Tensor: ... +def kl_div( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = ..., + reduce: Optional[bool] = ..., + reduction: str = ..., + log_target: bool = ..., +) -> Tensor: ... +def cross_entropy( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = ..., + size_average: Optional[bool] = ..., + ignore_index: int = ..., + reduce: Optional[bool] = ..., + reduction: str = ..., + label_smoothing: float = ..., +) -> Tensor: ... +def binary_cross_entropy( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = ..., + size_average: Optional[bool] = ..., + reduce: Optional[bool] = ..., + reduction: str = ..., +) -> Tensor: ... +def binary_cross_entropy_with_logits( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = ..., + size_average: Optional[bool] = ..., + reduce: Optional[bool] = ..., + reduction: str = ..., + pos_weight: Optional[Tensor] = ..., +) -> Tensor: ... +def smooth_l1_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = ..., + reduce: Optional[bool] = ..., + reduction: str = ..., + beta: float = ..., +) -> Tensor: ... +def huber_loss( + input: Tensor, + target: Tensor, + reduction: str = ..., + delta: float = ..., +) -> Tensor: ... +def l1_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = ..., + reduce: Optional[bool] = ..., + reduction: str = ..., +) -> Tensor: ... +def mse_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = ..., + reduce: Optional[bool] = ..., + reduction: str = ..., +) -> Tensor: ... +def margin_ranking_loss( + input1: Tensor, + input2: Tensor, + target: Tensor, + margin: float = ..., + size_average: Optional[bool] = ..., + reduce: Optional[bool] = ..., + reduction: str = ..., +) -> Tensor: ... +def hinge_embedding_loss( + input: Tensor, + target: Tensor, + margin: float = ..., + size_average: Optional[bool] = ..., + reduce: Optional[bool] = ..., + reduction: str = ..., +) -> Tensor: ... +def multilabel_margin_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = ..., + reduce: Optional[bool] = ..., + reduction: str = ..., +) -> Tensor: ... +def soft_margin_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = ..., + reduce: Optional[bool] = ..., + reduction: str = ..., +) -> Tensor: ... +def multilabel_soft_margin_loss( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = ..., + size_average: Optional[bool] = ..., + reduce: Optional[bool] = ..., + reduction: str = ..., +) -> Tensor: ... 
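A brief sketch (standard ``torch`` assumed) relating two of the loss stubs above: ``cross_entropy`` applied to raw logits is equivalent to ``log_softmax`` followed by ``nll_loss``.

import torch
import torch.nn.functional as F

logits = torch.randn(8, 5)          # batch of 8 samples, 5 classes
target = torch.randint(0, 5, (8,))  # class indices
loss_a = F.cross_entropy(logits, target)
loss_b = F.nll_loss(F.log_softmax(logits, dim=1), target)
print(torch.allclose(loss_a, loss_b))  # True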
+def cosine_embedding_loss( + input1: Tensor, + input2: Tensor, + target: Tensor, + margin: float = ..., + size_average: Optional[bool] = ..., + reduce: Optional[bool] = ..., + reduction: str = ..., +) -> Tensor: ... +def multi_margin_loss( + input: Tensor, + target: Tensor, + p: int = ..., + margin: float = ..., + weight: Optional[Tensor] = ..., + size_average: Optional[bool] = ..., + reduce: Optional[bool] = ..., + reduction: str = ..., +) -> Tensor: ... +def upsample( + input: Any, + size: Optional[Any] = ..., + scale_factor: Optional[Any] = ..., + mode: str = ..., + align_corners: Optional[Any] = ..., +): ... +def interpolate( + input: Any, + size: Optional[Any] = ..., + scale_factor: Optional[Any] = ..., + mode: str = ..., + align_corners: Optional[Any] = ..., + recompute_scale_factor: Optional[Any] = ..., + antialias: bool = ..., +): ... +def upsample_nearest( + input: Any, + size: Optional[Any] = ..., + scale_factor: Optional[Any] = ..., +): ... +def upsample_bilinear( + input: Any, + size: Optional[Any] = ..., + scale_factor: Optional[Any] = ..., +): ... +def grid_sample( + input: Tensor, + grid: Tensor, + mode: str = ..., + padding_mode: str = ..., + align_corners: Optional[Any] = ..., +) -> Tensor: ... +def affine_grid( + theta: Tensor, + size: List[int], + align_corners: Optional[Any] = ..., +) -> Tensor: ... +def triplet_margin_loss( + anchor: Tensor, + positive: Tensor, + negative: Tensor, + margin: float = ..., + p: float = ..., + eps: float = ..., + swap: bool = ..., + size_average: Optional[bool] = ..., + reduce: Optional[bool] = ..., + reduction: str = ..., +) -> Tensor: ... +def triplet_margin_with_distance_loss( + anchor: Tensor, + positive: Tensor, + negative: Tensor, + *, + distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = ..., + margin: float = ..., + swap: bool = ..., + reduction: str = ..., +) -> Tensor: ... +def normalize( + input: Tensor, + p: float = ..., + dim: int = ..., + eps: float = ..., + out: Optional[Tensor] = ..., +) -> Tensor: ... +def assert_int_or_pair( + arg: Any, + arg_name: Any, + message: Any, +) -> None: ... +def unfold( + input: Tensor, + kernel_size: _size_any_t, + dilation: _size_any_t = ..., + padding: _size_any_t = ..., + stride: _size_any_t = ..., +) -> Tensor: ... +def fold( + input: Tensor, + output_size: _size_any_t, + kernel_size: _size_any_t, + dilation: _size_any_t = ..., + padding: _size_any_t = ..., + stride: _size_any_t = ..., +) -> Tensor: ... +def _canonical_mask( + mask: Optional[Tensor], + mask_name: str, + other_type: Optional[_dtype], + other_name: str, + target_type: _dtype, + check_other: bool = True, +) -> Optional[Tensor]: ... +def _none_or_dtype(input: Optional[Tensor]) -> Optional[_dtype]: ... 
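A short sketch (assuming a standard ``torch`` install) of how the ``affine_grid`` and ``grid_sample`` stubs above are used together; passing ``align_corners`` explicitly to both calls avoids the default-behavior warning raised when it is left as ``None``.

import torch
import torch.nn.functional as F

img = torch.randn(1, 3, 16, 16)
theta = torch.tensor([[[1.0, 0.0, 0.0],
                       [0.0, 1.0, 0.0]]])  # N x 2 x 3 identity affine
grid = F.affine_grid(theta, size=(1, 3, 16, 16), align_corners=False)
out = F.grid_sample(img, grid, mode="bilinear", align_corners=False)
# the identity transform reproduces the input up to floating-point error
print(torch.allclose(out, img, atol=1e-4))  # True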
+def multi_head_attention_forward( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Optional[Tensor], + in_proj_bias: Optional[Tensor], + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Optional[Tensor], + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, + average_attn_weights: bool = True, + is_causal: bool = False, +) -> Tuple[Tensor, Optional[Tensor]]: ... + +from torch import conv1d as conv1d +from torch import conv2d as conv2d +from torch import conv3d as conv3d +from torch import conv_transpose1d as conv_transpose1d +from torch import conv_transpose2d as conv_transpose2d +from torch import conv_transpose3d as conv_transpose3d +from torch import conv_tbc as conv_tbc +from torch import avg_pool1d as avg_pool1d +from torch import adaptive_avg_pool1d as adaptive_avg_pool1d +from torch import relu_ as relu_ +from torch import selu_ as selu_ +from torch import celu_ as celu_ +from torch import prelu as prelu +from torch import rrelu_ as rrelu_ +from torch import hardshrink as hardshrink +from torch import bilinear as bilinear +from torch import pixel_shuffle as pixel_shuffle +from torch import pixel_unshuffle as pixel_unshuffle +from torch import channel_shuffle as channel_shuffle +from torch import native_channel_shuffle as native_channel_shuffle +from torch import pairwise_distance as pairwise_distance +from torch import pdist as pdist +from torch import cosine_similarity as cosine_similarity +from torch._C._nn import avg_pool2d as avg_pool2d +from torch._C._nn import avg_pool3d as avg_pool3d +from torch._C._nn import hardtanh_ as hardtanh_ +from torch._C._nn import elu_ as elu_ +from torch._C._nn import leaky_relu_ as leaky_relu_ +from torch._C._nn import gelu as gelu +from torch._C._nn import softplus as softplus +from torch._C._nn import softshrink as softshrink +from torch._C._nn import linear as linear +from torch._C._nn import pad as pad +from torch._C._nn import one_hot as one_hot +from torch._C._nn import scaled_dot_product_attention as scaled_dot_product_attention +from torch._C._nn import log_sigmoid +logsigmoid = log_sigmoid + +@overload +def adaptive_max_pool1d(input: Tensor, output_size: Union[_int, _size], return_indices: Literal[False] = False) -> Tensor: ... +@overload +def adaptive_max_pool1d(input: Tensor, output_size: Union[_int, _size], return_indices: Literal[True], /) -> Tuple[Tensor, Tensor]: ... +@overload +def adaptive_max_pool1d(input: Tensor, output_size: Union[_int, _size], *, return_indices: Literal[True]) -> Tuple[Tensor, Tensor]: ... +@overload +def adaptive_max_pool2d(input: Tensor, output_size: Union[_int, _size], return_indices: Literal[False] = False) -> Tensor: ... +@overload +def adaptive_max_pool2d(input: Tensor, output_size: Union[_int, _size], return_indices: Literal[True], /) -> Tuple[Tensor, Tensor]: ... +@overload +def adaptive_max_pool2d(input: Tensor, output_size: Union[_int, _size], *, return_indices: Literal[True]) -> Tuple[Tensor, Tensor]: ... 
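A minimal sketch (standard ``torch`` assumed) of the behavior the ``Literal`` overloads above describe to a type checker: the same function returns a single tensor by default and a ``(values, indices)`` pair when ``return_indices=True``.

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 9, 9)
y = F.adaptive_max_pool2d(x, output_size=3)                    # Tensor
vals, idx = F.adaptive_max_pool2d(x, 3, return_indices=True)   # Tuple[Tensor, Tensor]
print(y.shape, vals.shape, idx.shape)
# torch.Size([2, 4, 3, 3]) torch.Size([2, 4, 3, 3]) torch.Size([2, 4, 3, 3])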
+@overload +def adaptive_max_pool3d(input: Tensor, output_size: Union[_int, _size], return_indices: Literal[False] = False) -> Tensor: ... +@overload +def adaptive_max_pool3d(input: Tensor, output_size: Union[_int, _size], return_indices: Literal[True], /) -> Tuple[Tensor, Tensor]: ... +@overload +def adaptive_max_pool3d(input: Tensor, output_size: Union[_int, _size], *, return_indices: Literal[True]) -> Tuple[Tensor, Tensor]: ... +@overload +def fractional_max_pool2d(input: Tensor, kernel_size: Union[_int, _size], output_size: Optional[Union[_int, _size]] = None, output_ratio: Optional[_ratio_any_t] = None, return_indices: Literal[False] = False, _random_samples: Optional[Tensor] = None) -> Tensor: ... +@overload +def fractional_max_pool2d(input: Tensor, kernel_size: Union[_int, _size], output_size: Optional[Union[_int, _size]], output_ratio: Optional[_ratio_any_t], return_indices: Literal[True], /, _random_samples: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: ... +@overload +def fractional_max_pool2d(input: Tensor, kernel_size: Union[_int, _size], output_size: Optional[Union[_int, _size]] = None, output_ratio: Optional[_ratio_any_t] = None, *, return_indices: Literal[True], _random_samples: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: ... +@overload +def fractional_max_pool3d(input: Tensor, kernel_size: Union[_int, _size], output_size: Optional[Union[_int, _size]] = None, output_ratio: Optional[_ratio_any_t] = None, return_indices: Literal[False] = False, _random_samples: Optional[Tensor] = None) -> Tensor: ... +@overload +def fractional_max_pool3d(input: Tensor, kernel_size: Union[_int, _size], output_size: Optional[Union[_int, _size]], output_ratio: Optional[_ratio_any_t], return_indices: Literal[True], /, _random_samples: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: ... +@overload +def fractional_max_pool3d(input: Tensor, kernel_size: Union[_int, _size], output_size: Optional[Union[_int, _size]] = None, output_ratio: Optional[_ratio_any_t] = None, *, return_indices: Literal[True], _random_samples: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: ... +@overload +def max_pool1d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: bool = False, return_indices: Literal[False] = False) -> Tensor: ... +@overload +def max_pool1d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]], padding: Union[_int, _size], dilation: Union[_int, _size], ceil_mode: bool, return_indices: Literal[True], /) -> Tuple[Tensor, Tensor]: ... +@overload +def max_pool1d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: bool = False, *, return_indices: Literal[True]) -> Tuple[Tensor, Tensor]: ... +@overload +def max_pool2d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: bool = False, return_indices: Literal[False] = False) -> Tensor: ... +@overload +def max_pool2d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]], padding: Union[_int, _size], dilation: Union[_int, _size], ceil_mode: bool, return_indices: Literal[True], /) -> Tuple[Tensor, Tensor]: ... 
+@overload +def max_pool2d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: bool = False, *, return_indices: Literal[True]) -> Tuple[Tensor, Tensor]: ... +@overload +def max_pool3d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: bool = False, return_indices: Literal[False] = False) -> Tensor: ... +@overload +def max_pool3d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]], padding: Union[_int, _size], dilation: Union[_int, _size], ceil_mode: bool, return_indices: Literal[True], /) -> Tuple[Tensor, Tensor]: ... +@overload +def max_pool3d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: bool = False, *, return_indices: Literal[True]) -> Tuple[Tensor, Tensor]: ... diff --git a/lib/python3.10/site-packages/torch/nn/grad.py b/lib/python3.10/site-packages/torch/nn/grad.py new file mode 100644 index 0000000000000000000000000000000000000000..61e817dbed612e51f1ff426373fbef1d192e331e --- /dev/null +++ b/lib/python3.10/site-packages/torch/nn/grad.py @@ -0,0 +1,298 @@ +# mypy: allow-untyped-defs +"""Gradient interface.""" + +import torch +from torch.nn.modules.utils import _pair, _single, _triple + + +def conv1d_input( + input_size, + weight, + grad_output, + stride=1, + padding=0, + dilation=1, + groups=1, +): + r"""Compute the gradient of conv1d with respect to the input of the convolution. + + This is same as the 1D transposed convolution operator under the hood but requires + the shape of the gradient w.r.t. input to be specified explicitly. + + Args: + input_size : Shape of the input gradient tensor + weight: weight tensor (out_channels x in_channels/groups x kW) + grad_output : output gradient tensor (minibatch x out_channels x oW) + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + + Examples:: + + >>> input = torch.randn(1, 1, 3, requires_grad=True) + >>> weight = torch.randn(1, 1, 1, requires_grad=True) + >>> output = F.conv1d(input, weight) + >>> grad_output = torch.randn(output.shape) + >>> grad_input = torch.autograd.grad(output, input, grad_output) + >>> F.grad.conv1d_input(input.shape, weight, grad_output) + + """ + input = grad_output.new_empty(1).expand(input_size) + + return torch.ops.aten.convolution_backward( + grad_output, + input, + weight, + None, + _single(stride), + _single(padding), + _single(dilation), + False, + [0], + groups, + (True, False, False), + )[0] + + +def conv1d_weight( + input, + weight_size, + grad_output, + stride=1, + padding=0, + dilation=1, + groups=1, +): + r"""Compute the gradient of conv1d with respect to the weight of the convolution. + + Args: + input: input tensor of shape (minibatch x in_channels x iW) + weight_size : Shape of the weight gradient tensor + grad_output : output gradient tensor (minibatch x out_channels x oW) + stride (int or tuple, optional): Stride of the convolution. 
Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + + Examples:: + + >>> input = torch.randn(1, 1, 3, requires_grad=True) + >>> weight = torch.randn(1, 1, 1, requires_grad=True) + >>> output = F.conv1d(input, weight) + >>> grad_output = torch.randn(output.shape) + >>> # xdoctest: +SKIP + >>> grad_weight = torch.autograd.grad(output, filter, grad_output) + >>> F.grad.conv1d_weight(input, weight.shape, grad_output) + + """ + weight = grad_output.new_empty(1).expand(weight_size) + + return torch.ops.aten.convolution_backward( + grad_output, + input, + weight, + None, + _single(stride), + _single(padding), + _single(dilation), + False, + [0], + groups, + (False, True, False), + )[1] + + +def conv2d_input( + input_size, + weight, + grad_output, + stride=1, + padding=0, + dilation=1, + groups=1, +): + r"""Compute the gradient of conv2d with respect to the input of the convolution. + + This is same as the 2D transposed convolution operator under the hood but requires + the shape of the gradient w.r.t. input to be specified explicitly. + + Args: + input_size : Shape of the input gradient tensor + weight: weight tensor (out_channels x in_channels/groups x kH x kW) + grad_output : output gradient tensor (minibatch x out_channels x oH x oW) + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + + Examples:: + + >>> input = torch.randn(1, 1, 3, 3, requires_grad=True) + >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True) + >>> output = F.conv2d(input, weight) + >>> grad_output = torch.randn(output.shape) + >>> grad_input = torch.autograd.grad(output, input, grad_output) + >>> F.grad.conv2d_input(input.shape, weight, grad_output) + + """ + input = grad_output.new_empty(1).expand(input_size) + + return torch.ops.aten.convolution_backward( + grad_output, + input, + weight, + None, + _pair(stride), + _pair(padding), + _pair(dilation), + False, + [0], + groups, + (True, False, False), + )[0] + + +def conv2d_weight( + input, + weight_size, + grad_output, + stride=1, + padding=0, + dilation=1, + groups=1, +): + r"""Compute the gradient of conv2d with respect to the weight of the convolution. + + Args: + input: input tensor of shape (minibatch x in_channels x iH x iW) + weight_size : Shape of the weight gradient tensor + grad_output : output gradient tensor (minibatch x out_channels x oH x oW) + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input channels to output channels. 
Default: 1 + + Examples:: + + >>> input = torch.randn(1, 1, 3, 3, requires_grad=True) + >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True) + >>> output = F.conv2d(input, weight) + >>> grad_output = torch.randn(output.shape) + >>> # xdoctest: +SKIP + >>> grad_weight = torch.autograd.grad(output, filter, grad_output) + >>> F.grad.conv2d_weight(input, weight.shape, grad_output) + + """ + weight = grad_output.new_empty(1).expand(weight_size) + + return torch.ops.aten.convolution_backward( + grad_output, + input, + weight, + None, + _pair(stride), + _pair(padding), + _pair(dilation), + False, + [0], + groups, + (False, True, False), + )[1] + + +def conv3d_input( + input_size, + weight, + grad_output, + stride=1, + padding=0, + dilation=1, + groups=1, +): + r"""Compute the gradient of conv3d with respect to the input of the convolution. + + This is same as the 3D transposed convolution operator under the hood but requires + the shape of the gradient w.r.t. input to be specified explicitly. + + Args: + input_size : Shape of the input gradient tensor + weight: weights tensor (out_channels x in_channels/groups x kT x kH x kW) + grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW) + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + + Examples:: + + >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True) + >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True) + >>> output = F.conv3d(input, weight) + >>> grad_output = torch.randn(output.shape) + >>> grad_input = torch.autograd.grad(output, input, grad_output) + >>> F.grad.conv3d_input(input.shape, weight, grad_output) + + """ + input = grad_output.new_empty(1).expand(input_size) + + return torch.ops.aten.convolution_backward( + grad_output, + input, + weight, + None, + _triple(stride), + _triple(padding), + _triple(dilation), + False, + [0], + groups, + (True, False, False), + )[0] + + +def conv3d_weight( + input, + weight_size, + grad_output, + stride=1, + padding=0, + dilation=1, + groups=1, +): + r"""Compute the gradient of conv3d with respect to the weight of the convolution. + + Args: + input: input tensor of shape (minibatch x in_channels x iT x iH x iW) + weight_size : Shape of the weight gradient tensor + grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW) + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input channels to output channels. 
Default: 1 + + Examples:: + + >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True) + >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True) + >>> output = F.conv3d(input, weight) + >>> grad_output = torch.randn(output.shape) + >>> grad_weight = torch.autograd.grad(output, weight, grad_output) + >>> F.grad.conv3d_weight(input, weight.shape, grad_output) + + """ + weight = grad_output.new_empty(1).expand(weight_size) + + return torch.ops.aten.convolution_backward( + grad_output, + input, + weight, + None, + _triple(stride), + _triple(padding), + _triple(dilation), + False, + [0], + groups, + (False, True, False), + )[1] diff --git a/lib/python3.10/site-packages/torch/nn/init.py b/lib/python3.10/site-packages/torch/nn/init.py new file mode 100644 index 0000000000000000000000000000000000000000..3d0600b43b68f96e0db09dd330d93582c7fb0ff0 --- /dev/null +++ b/lib/python3.10/site-packages/torch/nn/init.py @@ -0,0 +1,697 @@ +# mypy: allow-untyped-defs +"""This file contains utilities for initializing neural network parameters.""" +import math +import warnings +from typing import Optional as _Optional + +import torch +from torch import Tensor + + +# These no_grad_* functions are necessary as wrappers around the parts of these +# functions that use `with torch.no_grad()`. The JIT doesn't support context +# managers, so these need to be implemented as builtins. Using these wrappers +# lets us keep those builtins small and re-usable. +def _no_grad_uniform_(tensor, a, b, generator=None): + with torch.no_grad(): + return tensor.uniform_(a, b, generator=generator) + + +def _no_grad_normal_(tensor, mean, std, generator=None): + with torch.no_grad(): + return tensor.normal_(mean, std, generator=generator) + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=None): + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1, generator=generator) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def _no_grad_fill_(tensor, val): + with torch.no_grad(): + return tensor.fill_(val) + + +def _no_grad_zero_(tensor): + with torch.no_grad(): + return tensor.zero_() + + +def calculate_gain(nonlinearity, param=None): + r"""Return the recommended gain value for the given nonlinearity function. 
+ + The values are as follows: + + ================= ==================================================== + nonlinearity gain + ================= ==================================================== + Linear / Identity :math:`1` + Conv{1,2,3}D :math:`1` + Sigmoid :math:`1` + Tanh :math:`\frac{5}{3}` + ReLU :math:`\sqrt{2}` + Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` + SELU :math:`\frac{3}{4}` + ================= ==================================================== + + .. warning:: + In order to implement `Self-Normalizing Neural Networks`_ , + you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``. + This gives the initial weights a variance of ``1 / N``, + which is necessary to induce a stable fixed point in the forward pass. + In contrast, the default gain for ``SELU`` sacrifices the normalization + effect for more stable gradient flow in rectangular layers. + + Args: + nonlinearity: the non-linear function (`nn.functional` name) + param: optional parameter for the non-linear function + + Examples: + >>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2 + + .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html + """ + linear_fns = [ + "linear", + "conv1d", + "conv2d", + "conv3d", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + ] + if nonlinearity in linear_fns or nonlinearity == "sigmoid": + return 1 + elif nonlinearity == "tanh": + return 5.0 / 3 + elif nonlinearity == "relu": + return math.sqrt(2.0) + elif nonlinearity == "leaky_relu": + if param is None: + negative_slope = 0.01 + elif ( + not isinstance(param, bool) + and isinstance(param, int) + or isinstance(param, float) + ): + # True/False are instances of int, hence check above + negative_slope = param + else: + raise ValueError(f"negative_slope {param} not a valid number") + return math.sqrt(2.0 / (1 + negative_slope**2)) + elif nonlinearity == "selu": + return ( + 3.0 / 4 + ) # Value found empirically (https://github.com/pytorch/pytorch/pull/50664) + else: + raise ValueError(f"Unsupported nonlinearity {nonlinearity}") + + +def uniform_( + tensor: Tensor, + a: float = 0.0, + b: float = 1.0, + generator: _Optional[torch.Generator] = None, +) -> Tensor: + r"""Fill the input Tensor with values drawn from the uniform distribution. + + :math:`\mathcal{U}(a, b)`. + + Args: + tensor: an n-dimensional `torch.Tensor` + a: the lower bound of the uniform distribution + b: the upper bound of the uniform distribution + generator: the torch Generator to sample from (default: None) + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.uniform_(w) + """ + if torch.overrides.has_torch_function_variadic(tensor): + return torch.overrides.handle_torch_function( + uniform_, (tensor,), tensor=tensor, a=a, b=b, generator=generator + ) + return _no_grad_uniform_(tensor, a, b, generator) + + +def normal_( + tensor: Tensor, + mean: float = 0.0, + std: float = 1.0, + generator: _Optional[torch.Generator] = None, +) -> Tensor: + r"""Fill the input Tensor with values drawn from the normal distribution. + + :math:`\mathcal{N}(\text{mean}, \text{std}^2)`. 
+ + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + generator: the torch Generator to sample from (default: None) + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.normal_(w) + """ + if torch.overrides.has_torch_function_variadic(tensor): + return torch.overrides.handle_torch_function( + normal_, (tensor,), tensor=tensor, mean=mean, std=std, generator=generator + ) + return _no_grad_normal_(tensor, mean, std, generator) + + +def trunc_normal_( + tensor: Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, + generator: _Optional[torch.Generator] = None, +) -> Tensor: + r"""Fill the input Tensor with values drawn from a truncated normal distribution. + + The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + generator: the torch Generator to sample from (default: None) + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator) + + +def constant_(tensor: Tensor, val: float) -> Tensor: + r"""Fill the input Tensor with the value :math:`\text{val}`. + + Args: + tensor: an n-dimensional `torch.Tensor` + val: the value to fill the tensor with + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.constant_(w, 0.3) + """ + if torch.overrides.has_torch_function_variadic(tensor): + return torch.overrides.handle_torch_function( + constant_, (tensor,), tensor=tensor, val=val + ) + return _no_grad_fill_(tensor, val) + + +def ones_(tensor: Tensor) -> Tensor: + r"""Fill the input Tensor with the scalar value `1`. + + Args: + tensor: an n-dimensional `torch.Tensor` + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.ones_(w) + """ + return _no_grad_fill_(tensor, 1.0) + + +def zeros_(tensor: Tensor) -> Tensor: + r"""Fill the input Tensor with the scalar value `0`. + + Args: + tensor: an n-dimensional `torch.Tensor` + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.zeros_(w) + """ + return _no_grad_zero_(tensor) + + +def eye_(tensor): + r"""Fill the 2-dimensional input `Tensor` with the identity matrix. + + Preserves the identity of the inputs in `Linear` layers, where as + many inputs are preserved as possible. + + Args: + tensor: a 2-dimensional `torch.Tensor` + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.eye_(w) + """ + if tensor.ndimension() != 2: + raise ValueError("Only tensors with 2 dimensions are supported") + + with torch.no_grad(): + torch.eye(*tensor.shape, out=tensor, requires_grad=tensor.requires_grad) + return tensor + + +def dirac_(tensor, groups=1): + r"""Fill the {3, 4, 5}-dimensional input `Tensor` with the Dirac delta function. + + Preserves the identity of the inputs in `Convolutional` + layers, where as many input channels are preserved as possible. 
In case + of groups>1, each group of channels preserves identity + + Args: + tensor: a {3, 4, 5}-dimensional `torch.Tensor` + groups (int, optional): number of groups in the conv layer (default: 1) + Examples: + >>> w = torch.empty(3, 16, 5, 5) + >>> nn.init.dirac_(w) + >>> w = torch.empty(3, 24, 5, 5) + >>> nn.init.dirac_(w, 3) + """ + dimensions = tensor.ndimension() + if dimensions not in [3, 4, 5]: + raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported") + + sizes = tensor.size() + + if sizes[0] % groups != 0: + raise ValueError("dim 0 must be divisible by groups") + + out_chans_per_grp = sizes[0] // groups + min_dim = min(out_chans_per_grp, sizes[1]) + + with torch.no_grad(): + tensor.zero_() + + for g in range(groups): + for d in range(min_dim): + if dimensions == 3: # Temporal convolution + tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2] = 1 + elif dimensions == 4: # Spatial convolution + tensor[ + g * out_chans_per_grp + d, + d, + tensor.size(2) // 2, + tensor.size(3) // 2, + ] = 1 + else: # Volumetric convolution + tensor[ + g * out_chans_per_grp + d, + d, + tensor.size(2) // 2, + tensor.size(3) // 2, + tensor.size(4) // 2, + ] = 1 + return tensor + + +def _calculate_fan_in_and_fan_out(tensor): + dimensions = tensor.dim() + if dimensions < 2: + raise ValueError( + "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions" + ) + + num_input_fmaps = tensor.size(1) + num_output_fmaps = tensor.size(0) + receptive_field_size = 1 + if tensor.dim() > 2: + # math.prod is not always available, accumulate the product manually + # we could use functools.reduce but that is not supported by TorchScript + for s in tensor.shape[2:]: + receptive_field_size *= s + fan_in = num_input_fmaps * receptive_field_size + fan_out = num_output_fmaps * receptive_field_size + + return fan_in, fan_out + + +def xavier_uniform_( + tensor: Tensor, + gain: float = 1.0, + generator: _Optional[torch.Generator] = None, +) -> Tensor: + r"""Fill the input `Tensor` with values using a Xavier uniform distribution. + + The method is described in `Understanding the difficulty of training + deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010). + The resulting tensor will have values sampled from + :math:`\mathcal{U}(-a, a)` where + + .. math:: + a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}} + + Also known as Glorot initialization. + + Args: + tensor: an n-dimensional `torch.Tensor` + gain: an optional scaling factor + generator: the torch Generator to sample from (default: None) + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu')) + + Note: + Be aware that ``fan_in`` and ``fan_out`` are calculated assuming + that the weight matrix is used in a transposed manner, + (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``). + This is important for correct initialization. + If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``, + pass in a transposed weight matrix, i.e. ``nn.init.xavier_uniform_(w.T, ...)``. 
+ """ + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) + a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + + return _no_grad_uniform_(tensor, -a, a, generator) + + +def xavier_normal_( + tensor: Tensor, + gain: float = 1.0, + generator: _Optional[torch.Generator] = None, +) -> Tensor: + r"""Fill the input `Tensor` with values using a Xavier normal distribution. + + The method is described in `Understanding the difficulty of training deep feedforward + neural networks` - Glorot, X. & Bengio, Y. (2010). The resulting tensor + will have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where + + .. math:: + \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}} + + Also known as Glorot initialization. + + Args: + tensor: an n-dimensional `torch.Tensor` + gain: an optional scaling factor + generator: the torch Generator to sample from (default: None) + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.xavier_normal_(w) + + Note: + Be aware that ``fan_in`` and ``fan_out`` are calculated assuming + that the weight matrix is used in a transposed manner, + (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``). + This is important for correct initialization. + If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``, + pass in a transposed weight matrix, i.e. ``nn.init.xavier_normal_(w.T, ...)``. + """ + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) + + return _no_grad_normal_(tensor, 0.0, std, generator) + + +def _calculate_correct_fan(tensor, mode): + mode = mode.lower() + valid_modes = ["fan_in", "fan_out"] + if mode not in valid_modes: + raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}") + + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + return fan_in if mode == "fan_in" else fan_out + + +def kaiming_uniform_( + tensor: Tensor, + a: float = 0, + mode: str = "fan_in", + nonlinearity: str = "leaky_relu", + generator: _Optional[torch.Generator] = None, +): + r"""Fill the input `Tensor` with values using a Kaiming uniform distribution. + + The method is described in `Delving deep into rectifiers: Surpassing + human-level performance on ImageNet classification` - He, K. et al. (2015). + The resulting tensor will have values sampled from + :math:`\mathcal{U}(-\text{bound}, \text{bound})` where + + .. math:: + \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} + + Also known as He initialization. + + Args: + tensor: an n-dimensional `torch.Tensor` + a: the negative slope of the rectifier used after this layer (only + used with ``'leaky_relu'``) + mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + nonlinearity: the non-linear function (`nn.functional` name), + recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). + generator: the torch Generator to sample from (default: None) + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu') + + Note: + Be aware that ``fan_in`` and ``fan_out`` are calculated assuming + that the weight matrix is used in a transposed manner, + (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``). 
+ This is important for correct initialization. + If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``, + pass in a transposed weight matrix, i.e. ``nn.init.kaiming_uniform_(w.T, ...)``. + """ + if torch.overrides.has_torch_function_variadic(tensor): + return torch.overrides.handle_torch_function( + kaiming_uniform_, + (tensor,), + tensor=tensor, + a=a, + mode=mode, + nonlinearity=nonlinearity, + generator=generator, + ) + + if 0 in tensor.shape: + warnings.warn("Initializing zero-element tensors is a no-op") + return tensor + fan = _calculate_correct_fan(tensor, mode) + gain = calculate_gain(nonlinearity, a) + std = gain / math.sqrt(fan) + bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + with torch.no_grad(): + return tensor.uniform_(-bound, bound, generator=generator) + + +def kaiming_normal_( + tensor: Tensor, + a: float = 0, + mode: str = "fan_in", + nonlinearity: str = "leaky_relu", + generator: _Optional[torch.Generator] = None, +): + r"""Fill the input `Tensor` with values using a Kaiming normal distribution. + + The method is described in `Delving deep into rectifiers: Surpassing + human-level performance on ImageNet classification` - He, K. et al. (2015). + The resulting tensor will have values sampled from + :math:`\mathcal{N}(0, \text{std}^2)` where + + .. math:: + \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}} + + Also known as He initialization. + + Args: + tensor: an n-dimensional `torch.Tensor` + a: the negative slope of the rectifier used after this layer (only + used with ``'leaky_relu'``) + mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + nonlinearity: the non-linear function (`nn.functional` name), + recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). + generator: the torch Generator to sample from (default: None) + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu') + + Note: + Be aware that ``fan_in`` and ``fan_out`` are calculated assuming + that the weight matrix is used in a transposed manner, + (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``). + This is important for correct initialization. + If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``, + pass in a transposed weight matrix, i.e. ``nn.init.kaiming_normal_(w.T, ...)``. + """ + if 0 in tensor.shape: + warnings.warn("Initializing zero-element tensors is a no-op") + return tensor + fan = _calculate_correct_fan(tensor, mode) + gain = calculate_gain(nonlinearity, a) + std = gain / math.sqrt(fan) + with torch.no_grad(): + return tensor.normal_(0, std, generator=generator) + + +def orthogonal_( + tensor, + gain=1, + generator: _Optional[torch.Generator] = None, +): + r"""Fill the input `Tensor` with a (semi) orthogonal matrix. + + Described in `Exact solutions to the nonlinear dynamics of learning in deep + linear neural networks` - Saxe, A. et al. (2013). The input tensor must have + at least 2 dimensions, and for tensors with more than 2 dimensions the + trailing dimensions are flattened. 
+ + Args: + tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2` + gain: optional scaling factor + generator: the torch Generator to sample from (default: None) + + Examples: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) + >>> w = torch.empty(3, 5) + >>> nn.init.orthogonal_(w) + """ + if tensor.ndimension() < 2: + raise ValueError("Only tensors with 2 or more dimensions are supported") + + if tensor.numel() == 0: + # no-op + return tensor + rows = tensor.size(0) + cols = tensor.numel() // rows + flattened = tensor.new_empty((rows, cols)).normal_(0, 1, generator=generator) + + if rows < cols: + flattened.t_() + + # Compute the qr factorization + q, r = torch.linalg.qr(flattened) + # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf + d = torch.diag(r, 0) + ph = d.sign() + q *= ph + + if rows < cols: + q.t_() + + with torch.no_grad(): + tensor.view_as(q).copy_(q) + tensor.mul_(gain) + return tensor + + +def sparse_( + tensor, + sparsity, + std=0.01, + generator: _Optional[torch.Generator] = None, +): + r"""Fill the 2D input `Tensor` as a sparse matrix. + + The non-zero elements will be drawn from the normal distribution + :math:`\mathcal{N}(0, 0.01)`, as described in `Deep learning via + Hessian-free optimization` - Martens, J. (2010). + + Args: + tensor: an n-dimensional `torch.Tensor` + sparsity: The fraction of elements in each column to be set to zero + std: the standard deviation of the normal distribution used to generate + the non-zero values + generator: the torch Generator to sample from (default: None) + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.sparse_(w, sparsity=0.1) + """ + if tensor.ndimension() != 2: + raise ValueError("Only tensors with 2 dimensions are supported") + + rows, cols = tensor.shape + num_zeros = int(math.ceil(sparsity * rows)) + + with torch.no_grad(): + tensor.normal_(0, std, generator=generator) + for col_idx in range(cols): + row_indices = torch.randperm(rows) + zero_indices = row_indices[:num_zeros] + tensor[zero_indices, col_idx] = 0 + return tensor + + +# for backward compatibility +def _make_deprecate(meth): + new_name = meth.__name__ + old_name = new_name[:-1] + + def deprecated_init(*args, **kwargs): + warnings.warn( + f"`nn.init.{old_name}` is now deprecated in favor of `nn.init.{new_name}`.", + FutureWarning, + stacklevel=2, + ) + return meth(*args, **kwargs) + + deprecated_init.__doc__ = rf""" + {old_name}(...) + + .. warning:: + This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`. 
+ + See :func:`~torch.nn.init.{new_name}` for details.""" + deprecated_init.__name__ = old_name + return deprecated_init + + +uniform = _make_deprecate(uniform_) +normal = _make_deprecate(normal_) +constant = _make_deprecate(constant_) +eye = _make_deprecate(eye_) +dirac = _make_deprecate(dirac_) +xavier_uniform = _make_deprecate(xavier_uniform_) +xavier_normal = _make_deprecate(xavier_normal_) +kaiming_uniform = _make_deprecate(kaiming_uniform_) +kaiming_normal = _make_deprecate(kaiming_normal_) +orthogonal = _make_deprecate(orthogonal_) +sparse = _make_deprecate(sparse_) diff --git a/lib/python3.10/site-packages/torch/nn/parameter.py b/lib/python3.10/site-packages/torch/nn/parameter.py new file mode 100644 index 0000000000000000000000000000000000000000..7d2ad36aeb589cad23e6c78dfa1482d8ad0d02ea --- /dev/null +++ b/lib/python3.10/site-packages/torch/nn/parameter.py @@ -0,0 +1,280 @@ +from collections import OrderedDict + +import torch +from torch._C import _disabled_torch_function_impl + + +# Metaclass to combine _TensorMeta and the instance check override for Parameter. +class _ParameterMeta(torch._C._TensorMeta): + # Make `isinstance(t, Parameter)` return True for custom tensor instances that have the _is_param flag. + def __instancecheck__(self, instance): + if self is Parameter: + if isinstance(instance, torch.Tensor) and getattr( + instance, "_is_param", False + ): + return True + return super().__instancecheck__(instance) + + +class Parameter(torch.Tensor, metaclass=_ParameterMeta): + r"""A kind of Tensor that is to be considered a module parameter. + + Parameters are :class:`~torch.Tensor` subclasses, that have a + very special property when used with :class:`Module` s - when they're + assigned as Module attributes they are automatically added to the list of + its parameters, and will appear e.g. in :meth:`~Module.parameters` iterator. + Assigning a Tensor doesn't have such effect. This is because one might + want to cache some temporary state, like last hidden state of the RNN, in + the model. If there was no such class as :class:`Parameter`, these + temporaries would get registered too. + + Args: + data (Tensor): parameter tensor. + requires_grad (bool, optional): if the parameter requires gradient. Note that + the torch.no_grad() context does NOT affect the default behavior of + Parameter creation--the Parameter will still have `requires_grad=True` in + :class:`~no_grad` mode. See :ref:`locally-disable-grad-doc` for more + details. Default: `True` + """ + + def __new__(cls, data=None, requires_grad=True): + if data is None: + data = torch.empty(0) + if type(data) is torch.Tensor or type(data) is Parameter: + # For ease of BC maintenance, keep this path for standard Tensor. + # Eventually (tm), we should change the behavior for standard Tensor to match. + return torch.Tensor._make_subclass(cls, data, requires_grad) + + # Path for custom tensors: set a flag on the instance to indicate parameter-ness. + t = data.detach().requires_grad_(requires_grad) + if type(t) is not type(data): + raise RuntimeError( + f"Creating a Parameter from an instance of type {type(data).__name__} " + "requires that detach() returns an instance of the same type, but return " + f"type {type(t).__name__} was found instead. To use the type as a " + "Parameter, please correct the detach() semantics defined by " + "its __torch_dispatch__() implementation." + ) + t._is_param = True + return t + + # Note: the 3 methods below only apply to standard Tensor. 
Parameters of custom tensor types + # are still considered that custom tensor type and these methods will not be called for them. + def __deepcopy__(self, memo): + if id(self) in memo: + return memo[id(self)] + else: + result = type(self)( + self.data.clone(memory_format=torch.preserve_format), self.requires_grad + ) + memo[id(self)] = result + return result + + def __repr__(self): + return "Parameter containing:\n" + super().__repr__() + + def __reduce_ex__(self, proto): + state = torch._utils._get_obj_state(self) + + # See Note [Don't serialize hooks] + hooks = OrderedDict() + if not state: + return ( + torch._utils._rebuild_parameter, + (self.data, self.requires_grad, hooks), + ) + + return ( + torch._utils._rebuild_parameter_with_state, + (self.data, self.requires_grad, hooks, state), + ) + + __torch_function__ = _disabled_torch_function_impl + + +class UninitializedTensorMixin: + _allowed_methods = [ + torch.Tensor.__hash__, + torch.Tensor.size, + torch.Tensor.copy_, + torch.Tensor.is_complex, + torch.Tensor.is_floating_point, + torch.Tensor.half, + torch.Tensor.float, + torch.Tensor.double, + torch.Tensor.char, + torch.Tensor.short, + torch.Tensor.int, + torch.Tensor.long, + torch.Tensor.cuda, + torch.Tensor.cpu, + torch.Tensor.to, + torch.Tensor.get_device, + torch._has_compatible_shallow_copy_type, + ] + + def materialize(self, shape, device=None, dtype=None): + r"""Create a Parameter or Tensor with the same properties of the uninitialized one. + + Given a shape, it materializes a parameter in the same device + and with the same `dtype` as the current one or the specified ones in the + arguments. + + Args: + shape : (tuple): the shape for the materialized tensor. + device (:class:`torch.device`): the desired device of the parameters + and buffers in this module. Optional. + dtype (:class:`torch.dtype`): the desired floating point type of + the floating point parameters and buffers in this module. Optional. + """ + if device is None: + device = self.data.device + if dtype is None: + dtype = self.data.dtype + self.data = torch.empty(shape, device=device, dtype=dtype) + self.__class__ = self.cls_to_become + + @property + def shape(self): + raise RuntimeError( + "Can't access the shape of an uninitialized parameter or buffer. " + "This error usually happens in `load_state_dict` when trying to load " + "an uninitialized parameter into an initialized one. " + "Call `forward` to initialize the parameters before accessing their attributes." + ) + + def share_memory_(self): + raise RuntimeError( + "Can't share memory on an uninitialized parameter or buffer. " + "Call `forward` to initialize the parameters before calling " + "`module.share_memory()`." + ) + + def __repr__(self): + return f"<{self.__class__.__name__}>" + + def __reduce_ex__(self, proto): + # See Note [Don't serialize hooks] + return (self.__class__, (self.requires_grad,)) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + # method-wrapper is to detect access to Tensor properties that are + # wrapped in descriptors + if func in cls._allowed_methods or func.__class__.__name__ == "method-wrapper": + if kwargs is None: + kwargs = {} + return super().__torch_function__(func, types, args, kwargs) + raise ValueError( + f"Attempted to use an uninitialized parameter in {func}. " + "This error happens when you are using a `LazyModule` or " + f"explicitly manipulating `torch.nn.parameter.{cls.__name__}` " + "objects. 
When using LazyModules Call `forward` with a dummy batch " + "to initialize the parameters before calling torch functions" + ) + + +def is_lazy(param): + return isinstance(param, UninitializedTensorMixin) + + +class UninitializedParameter(UninitializedTensorMixin, Parameter): + r"""A parameter that is not initialized. + + Uninitialized Parameters are a a special case of :class:`torch.nn.Parameter` + where the shape of the data is still unknown. + + Unlike a :class:`torch.nn.Parameter`, uninitialized parameters + hold no data and attempting to access some properties, like their shape, + will throw a runtime error. The only operations that can be performed on a uninitialized + parameter are changing its datatype, moving it to a different device and + converting it to a regular :class:`torch.nn.Parameter`. + + The default device or dtype to use when the parameter is materialized can be set + during construction using e.g. ``device='cuda'``. + """ + + cls_to_become = Parameter + + def __new__(cls, requires_grad=True, device=None, dtype=None) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + data = torch.empty(0, **factory_kwargs) + return torch.Tensor._make_subclass(cls, data, requires_grad) + + def __deepcopy__(self, memo): + if id(self) in memo: + return memo[id(self)] + else: + result = type(self)(self.requires_grad, self.data.device, self.data.dtype) + memo[id(self)] = result + return result + + +# Metaclass to combine _TensorMeta and the instance check override for Buffer. +class _BufferMeta(torch._C._TensorMeta): + # Make `isinstance(t, Buffer)` return True for custom tensor instances that have the _is_buffer flag. + def __instancecheck__(self, instance): + if self is Buffer: + if isinstance(instance, torch.Tensor) and getattr( + instance, "_is_buffer", False + ): + return True + return super().__instancecheck__(instance) + + +class Buffer(torch.Tensor, metaclass=_BufferMeta): + r"""A kind of Tensor that should not be considered a model + parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. + + Buffers are :class:`~torch.Tensor` subclasses, that have a + very special property when used with :class:`Module` s -- when they're + assigned as Module attributes they are automatically added to the list of + its buffers, and will appear e.g. in :meth:`~torch.nn.Module.buffers` iterator. + Assigning a Tensor doesn't have such effect. One can still assign a Tensor as explicitly by using + the :meth:`~torch.nn.Module.register_buffer` function. + + Args: + data (Tensor): buffer tensor. + persistent (bool, optional): whether the buffer is part of the module's + :attr:`state_dict`. Default: ``True`` + """ + + def __new__(cls, data=None, *, persistent=True): + if data is None: + data = torch.empty(0) + + t = data.detach().requires_grad_(data.requires_grad) + t.persistent = persistent + t._is_buffer = True + return t + + __torch_function__ = _disabled_torch_function_impl + + +class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor): + r"""A buffer that is not initialized. + + Uninitialized Buffer is a a special case of :class:`torch.Tensor` + where the shape of the data is still unknown. + + Unlike a :class:`torch.Tensor`, uninitialized parameters + hold no data and attempting to access some properties, like their shape, + will throw a runtime error. 
The only operations that can be performed on a uninitialized + parameter are changing its datatype, moving it to a different device and + converting it to a regular :class:`torch.Tensor`. + + The default device or dtype to use when the buffer is materialized can be set + during construction using e.g. ``device='cuda'``. + """ + + cls_to_become = torch.Tensor + + def __new__( + cls, requires_grad=False, device=None, dtype=None, persistent=True + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + data = torch.empty(0, **factory_kwargs) + ret = torch.Tensor._make_subclass(cls, data, requires_grad) + ret.persistent = persistent + ret._is_buffer = True + return ret diff --git a/lib/python3.10/site-packages/torch/nn/parameter.pyi b/lib/python3.10/site-packages/torch/nn/parameter.pyi new file mode 100644 index 0000000000000000000000000000000000000000..9c998fb07f2c185c1ad2f59556486a95d7c6a9e5 --- /dev/null +++ b/lib/python3.10/site-packages/torch/nn/parameter.pyi @@ -0,0 +1,44 @@ +# mypy: allow-untyped-defs +from typing_extensions import TypeGuard + +from torch import device, dtype, Tensor + +class Parameter(Tensor): + def __init__(self, data: Tensor = ..., requires_grad: bool = ...) -> None: ... + +def is_lazy( + param: Tensor, +) -> TypeGuard[UninitializedParameter | UninitializedBuffer]: ... + +class UninitializedParameter(Tensor): + def __init__(self, data: Tensor = ..., requires_grad: bool = ...) -> None: ... + def materialize( + self, + shape: tuple[int, ...], + device: device | None = None, + dtype: dtype | None = None, + ) -> None: ... + +class Buffer(Tensor): + persistent: bool + def __init__( + self, + data: Tensor = ..., + requires_grad: bool = ..., + persistent: bool = ..., + ): ... + +class UninitializedBuffer(Tensor): + persistent: bool + def __init__( + self, + data: Tensor = ..., + requires_grad: bool = ..., + persistent: bool = ..., + ): ... + def materialize( + self, + shape: tuple[int, ...], + device: device | None = None, + dtype: dtype | None = None, + ) -> None: ... 
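# Editorial note (not part of the vendored sources): a minimal, hedged sketch of the
# Parameter/Buffer semantics documented in torch/nn/parameter.py above. Assigning a
# Parameter (and, in PyTorch versions that ship the Buffer class shown above, a Buffer)
# as a Module attribute registers it automatically; a plain Tensor attribute is not
# registered. is_lazy() reports whether a lazy module's weights are still uninitialized.
import torch
from torch import nn
from torch.nn.parameter import is_lazy

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))   # appears in .parameters()
        self.running_stat = nn.Buffer(torch.zeros(4))   # appears in .buffers()
        self.scratch = torch.zeros(4)                    # plain Tensor: not registered

m = Toy()
print([name for name, _ in m.named_parameters()])  # ['weight']
print([name for name, _ in m.named_buffers()])     # ['running_stat']

lazy = nn.LazyLinear(8)
print(is_lazy(lazy.weight))   # True: shape unknown until the first forward
lazy(torch.randn(2, 3))
print(is_lazy(lazy.weight))   # False: materialized as an (8, 3) Parameter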
diff --git a/lib/python3.10/site-packages/torch/onnx/__init__.py b/lib/python3.10/site-packages/torch/onnx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..27c8e2c6240c18fe550b9a383c41a3ab2a33182c --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/__init__.py @@ -0,0 +1,553 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + + +__all__ = [ + # Modules + "symbolic_helper", + "utils", + "errors", + # All opsets + "symbolic_caffe2", + "symbolic_opset7", + "symbolic_opset8", + "symbolic_opset9", + "symbolic_opset10", + "symbolic_opset11", + "symbolic_opset12", + "symbolic_opset13", + "symbolic_opset14", + "symbolic_opset15", + "symbolic_opset16", + "symbolic_opset17", + "symbolic_opset18", + "symbolic_opset19", + "symbolic_opset20", + # Enums + "ExportTypes", + "OperatorExportTypes", + "TrainingMode", + "TensorProtoDataType", + "JitScalarType", + # Public functions + "export", + "export_to_pretty_string", + "is_in_onnx_export", + "select_model_mode_for_export", + "register_custom_op_symbolic", + "unregister_custom_op_symbolic", + "disable_log", + "enable_log", + # Base error + "OnnxExporterError", + # Dynamo Exporter + "DiagnosticOptions", + "ExportOptions", + "ONNXProgram", + "ONNXRuntimeOptions", + "OnnxRegistry", + "dynamo_export", + "enable_fake_mode", + # DORT / torch.compile + "is_onnxrt_backend_supported", +] + +from typing import Any, Callable, Collection, Mapping, Sequence, TYPE_CHECKING + +import torch +from torch import _C +from torch._C import _onnx as _C_onnx +from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode + +from ._exporter_states import ExportTypes +from ._internal.onnxruntime import ( + is_onnxrt_backend_supported, + OrtBackend as _OrtBackend, + OrtBackendOptions as _OrtBackendOptions, + OrtExecutionProvider as _OrtExecutionProvider, +) +from ._type_utils import JitScalarType +from .errors import OnnxExporterError +from .utils import ( + _optimize_graph, + _run_symbolic_function, + _run_symbolic_method, + export_to_pretty_string, + is_in_onnx_export, + register_custom_op_symbolic, + select_model_mode_for_export, + unregister_custom_op_symbolic, +) + + +from . import ( # usort: skip. Keep the order instead of sorting lexicographically + errors, + symbolic_caffe2, + symbolic_helper, + symbolic_opset7, + symbolic_opset8, + symbolic_opset9, + symbolic_opset10, + symbolic_opset11, + symbolic_opset12, + symbolic_opset13, + symbolic_opset14, + symbolic_opset15, + symbolic_opset16, + symbolic_opset17, + symbolic_opset18, + symbolic_opset19, + symbolic_opset20, + utils, +) + + +from ._internal._exporter_legacy import ( # usort: skip. 
needs to be last to avoid circular import + DiagnosticOptions, + ExportOptions, + ONNXProgram, + ONNXRuntimeOptions, + OnnxRegistry, + enable_fake_mode, +) + + +if TYPE_CHECKING: + import os + +# Set namespace for exposed private names +DiagnosticOptions.__module__ = "torch.onnx" +ExportOptions.__module__ = "torch.onnx" +ExportTypes.__module__ = "torch.onnx" +JitScalarType.__module__ = "torch.onnx" +ONNXProgram.__module__ = "torch.onnx" +ONNXRuntimeOptions.__module__ = "torch.onnx" +OnnxExporterError.__module__ = "torch.onnx" +OnnxRegistry.__module__ = "torch.onnx" +_OrtBackend.__module__ = "torch.onnx" +_OrtBackendOptions.__module__ = "torch.onnx" +_OrtExecutionProvider.__module__ = "torch.onnx" +enable_fake_mode.__module__ = "torch.onnx" +is_onnxrt_backend_supported.__module__ = "torch.onnx" + +producer_name = "pytorch" +producer_version = _C_onnx.PRODUCER_VERSION + + +def export( + model: torch.nn.Module + | torch.export.ExportedProgram + | torch.jit.ScriptModule + | torch.jit.ScriptFunction, + args: tuple[Any, ...] = (), + f: str | os.PathLike | None = None, + *, + kwargs: dict[str, Any] | None = None, + export_params: bool = True, + verbose: bool | None = None, + input_names: Sequence[str] | None = None, + output_names: Sequence[str] | None = None, + opset_version: int | None = None, + dynamic_axes: Mapping[str, Mapping[int, str]] + | Mapping[str, Sequence[int]] + | None = None, + keep_initializers_as_inputs: bool = False, + dynamo: bool = False, + # Dynamo only options + external_data: bool = True, + dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None, + report: bool = False, + verify: bool = False, + profile: bool = False, + dump_exported_program: bool = False, + artifacts_dir: str | os.PathLike = ".", + fallback: bool = False, + # Deprecated options + training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, + operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX, + do_constant_folding: bool = True, + custom_opsets: Mapping[str, int] | None = None, + export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False, + autograd_inlining: bool = True, + **_: Any, # ignored options +) -> Any | None: + r"""Exports a model into ONNX format. + + Args: + model: The model to be exported. + args: Example positional inputs. Any non-Tensor arguments will be hard-coded into the + exported model; any Tensor arguments will become inputs of the exported model, + in the order they occur in the tuple. + f: Path to the output ONNX model file. E.g. "model.onnx". + kwargs: Optional example keyword inputs. + export_params: If false, parameters (weights) will not be exported. + verbose: Whether to enable verbose logging. + input_names: names to assign to the input nodes of the graph, in order. + output_names: names to assign to the output nodes of the graph, in order. + opset_version: The version of the + `default (ai.onnx) opset `_ + to target. Must be >= 7. + dynamic_axes: + + By default the exported model will have the shapes of all input and output tensors + set to exactly match those given in ``args``. To specify axes of tensors as + dynamic (i.e. known only at run-time), set ``dynamic_axes`` to a dict with schema: + + * KEY (str): an input or output name. Each name must also be provided in ``input_names`` or + ``output_names``. + * VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a + list, each element is an axis index. 
+ + For example:: + + class SumModule(torch.nn.Module): + def forward(self, x): + return torch.sum(x, dim=1) + + + torch.onnx.export( + SumModule(), + (torch.ones(2, 2),), + "onnx.pb", + input_names=["x"], + output_names=["sum"], + ) + + Produces:: + + input { + name: "x" + ... + shape { + dim { + dim_value: 2 # axis 0 + } + dim { + dim_value: 2 # axis 1 + ... + output { + name: "sum" + ... + shape { + dim { + dim_value: 2 # axis 0 + ... + + While:: + + torch.onnx.export( + SumModule(), + (torch.ones(2, 2),), + "onnx.pb", + input_names=["x"], + output_names=["sum"], + dynamic_axes={ + # dict value: manually named axes + "x": {0: "my_custom_axis_name"}, + # list value: automatic names + "sum": [0], + }, + ) + + Produces:: + + input { + name: "x" + ... + shape { + dim { + dim_param: "my_custom_axis_name" # axis 0 + } + dim { + dim_value: 2 # axis 1 + ... + output { + name: "sum" + ... + shape { + dim { + dim_param: "sum_dynamic_axes_1" # axis 0 + ... + + keep_initializers_as_inputs: If True, all the + initializers (typically corresponding to model weights) in the + exported graph will also be added as inputs to the graph. If False, + then initializers are not added as inputs to the graph, and only + the user inputs are added as inputs. + + Set this to True if you intend to supply model weights at runtime. + Set it to False if the weights are static to allow for better optimizations + (e.g. constant folding) by backends/runtimes. + + dynamo: Whether to export the model with ``torch.export`` ExportedProgram instead of TorchScript. + external_data: Whether to save the model weights as an external data file. + This is required for models with large weights that exceed the ONNX file size limit (2GB). + When False, the weights are saved in the ONNX file with the model architecture. + dynamic_shapes: A dictionary of dynamic shapes for the model inputs. Refer to + :func:`torch.export.export` for more details. This is only used (and preferred) when dynamo is True. + Only one parameter `dynamic_axes` or `dynamic_shapes` should be set + at the same time. + report: Whether to generate a markdown report for the export process. + verify: Whether to verify the exported model using ONNX Runtime. + profile: Whether to profile the export process. + dump_exported_program: Whether to dump the :class:`torch.export.ExportedProgram` to a file. + This is useful for debugging the exporter. + artifacts_dir: The directory to save the debugging artifacts like the report and the serialized + exported program. + fallback: Whether to fallback to the TorchScript exporter if the dynamo exporter fails. + + training: Deprecated option. Instead, set the training mode of the model before exporting. + operator_export_type: Deprecated option. Only ONNX is supported. + do_constant_folding: Deprecated option. The exported graph is always optimized. + custom_opsets: Deprecated. + A dictionary: + + * KEY (str): opset domain name + * VALUE (int): opset version + + If a custom opset is referenced by ``model`` but not mentioned in this dictionary, + the opset version is set to 1. Only custom opset domain name and version should be + indicated through this argument. + export_modules_as_functions: Deprecated option. + + Flag to enable + exporting all ``nn.Module`` forward calls as local functions in ONNX. Or a set to indicate the + particular types of modules to export as local functions in ONNX. + This feature requires ``opset_version`` >= 15, otherwise the export will fail. 
This is because + ``opset_version`` < 15 implies IR version < 8, which means no local function support. + Module variables will be exported as function attributes. There are two categories of function + attributes. + + 1. Annotated attributes: class variables that have type annotations via + `PEP 526-style `_ + will be exported as attributes. + Annotated attributes are not used inside the subgraph of ONNX local function because + they are not created by PyTorch JIT tracing, but they may be used by consumers + to determine whether or not to replace the function with a particular fused kernel. + + 2. Inferred attributes: variables that are used by operators inside the module. Attribute names + will have prefix "inferred::". This is to differentiate from predefined attributes retrieved from + python module annotations. Inferred attributes are used inside the subgraph of ONNX local function. + + * ``False`` (default): export ``nn.Module`` forward calls as fine grained nodes. + * ``True``: export all ``nn.Module`` forward calls as local function nodes. + * Set of type of nn.Module: export ``nn.Module`` forward calls as local function nodes, + only if the type of the ``nn.Module`` is found in the set. + autograd_inlining: Deprecated. + Flag used to control whether to inline autograd functions. + Refer to https://github.com/pytorch/pytorch/pull/74765 for more details. + """ + if dynamo is True or isinstance(model, torch.export.ExportedProgram): + from torch.onnx._internal import exporter + + if isinstance(args, torch.Tensor): + args = (args,) + return exporter.export_compat( + model, + args, + f, + kwargs=kwargs, + export_params=export_params, + verbose=verbose, + input_names=input_names, + output_names=output_names, + opset_version=opset_version, + dynamic_axes=dynamic_axes, + keep_initializers_as_inputs=keep_initializers_as_inputs, + external_data=external_data, + dynamic_shapes=dynamic_shapes, + report=report, + verify=verify, + profile=profile, + dump_exported_program=dump_exported_program, + artifacts_dir=artifacts_dir, + fallback=fallback, + ) + else: + from torch.onnx.utils import export + + if dynamic_shapes: + raise ValueError( + "The exporter only supports dynamic shapes " + "through parameter dynamic_axes when dynamo=False." + ) + + export( + model, + args, + f, # type: ignore[arg-type] + kwargs=kwargs, + export_params=export_params, + verbose=verbose is True, + input_names=input_names, + output_names=output_names, + opset_version=opset_version, + dynamic_axes=dynamic_axes, + keep_initializers_as_inputs=keep_initializers_as_inputs, + training=training, + operator_export_type=operator_export_type, + do_constant_folding=do_constant_folding, + custom_opsets=custom_opsets, + export_modules_as_functions=export_modules_as_functions, + autograd_inlining=autograd_inlining, + ) + return None + + +def dynamo_export( + model: torch.nn.Module | Callable | torch.export.ExportedProgram, # type: ignore[name-defined] + /, + *model_args, + export_options: ExportOptions | None = None, + **model_kwargs, +) -> ONNXProgram | Any: + """Export a torch.nn.Module to an ONNX graph. + + Args: + model: The PyTorch model to be exported to ONNX. + model_args: Positional inputs to ``model``. + model_kwargs: Keyword inputs to ``model``. + export_options: Options to influence the export to ONNX. + + Returns: + An in-memory representation of the exported ONNX model. 
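# --- Illustrative usage sketch (editor's aside, not part of the patch) ---
# The `export` API above has two paths: the TorchScript exporter
# (dynamo=False, dynamic shapes via `dynamic_axes`) and the torch.export-based
# exporter (dynamo=True, dynamic shapes via `dynamic_shapes`). A minimal call
# on the dynamo path might look like the sketch below; the model class and
# file name are placeholders.
import torch

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x).sum(dim=1)

batch = torch.export.Dim("batch")  # symbolic batch dimension
torch.onnx.export(
    TinyModel(),
    (torch.randn(4, 8),),
    "tiny_model.onnx",
    dynamo=True,
    dynamic_shapes={"x": {0: batch}},  # keyed by forward() argument name
)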
+ + **Example 1 - Simplest export** + :: + + class MyModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(2, 2) + + def forward(self, x, bias=None): + out = self.linear(x) + out = out + bias + return out + + + model = MyModel() + kwargs = {"bias": 3.0} + args = (torch.randn(2, 2, 2),) + onnx_program = torch.onnx.dynamo_export(model, *args, **kwargs).save( + "my_simple_model.onnx" + ) + + **Example 2 - Exporting with dynamic shapes** + :: + + # The previous model can be exported with dynamic shapes + export_options = torch.onnx.ExportOptions(dynamic_shapes=True) + onnx_program = torch.onnx.dynamo_export( + model, *args, **kwargs, export_options=export_options + ) + onnx_program.save("my_dynamic_model.onnx") + """ + + # NOTE: The new exporter is experimental and is not enabled by default. + import warnings + + from torch.onnx import _flags + from torch.onnx._internal import exporter + from torch.utils import _pytree + + if isinstance(model, torch.export.ExportedProgram): + return exporter.export_compat( + model, # type: ignore[arg-type] + model_args, + f=None, + kwargs=model_kwargs, + opset_version=18, + external_data=True, + export_params=True, + fallback=True, + ) + elif _flags.USE_EXPERIMENTAL_LOGIC: + if export_options is not None: + warnings.warn( + "You are using an experimental ONNX export logic, which currently only supports dynamic shapes. " + "For a more comprehensive set of export options, including advanced features, please consider using " + "`torch.onnx.export(..., dynamo=True)`. ", + category=FutureWarning, + ) + + if export_options is not None and export_options.dynamic_shapes: + # Make all shapes dynamic + def _to_dynamic_shapes_mapper(): + arg_order = 0 + + def _to_dynamic_shape(x): + nonlocal arg_order + if isinstance(x, torch.Tensor): + rank = len(x.shape) + dynamic_shape = {} + for i in range(rank): + dynamic_shape[i] = torch.export.Dim( + f"arg_{arg_order}_dim_{i}" + ) + arg_order += 1 + return dynamic_shape + else: + return None + + return _to_dynamic_shape + + # model_args could be nested + dynamic_shapes = _pytree.tree_map( + _to_dynamic_shapes_mapper(), + model_args, + ) + else: + dynamic_shapes = None + + return exporter.export_compat( + model, # type: ignore[arg-type] + model_args, + f=None, + kwargs=model_kwargs, + dynamic_shapes=dynamic_shapes, + opset_version=18, + external_data=True, + export_params=True, + fallback=True, + ) + else: + from torch.onnx._internal._exporter_legacy import dynamo_export + + return dynamo_export( + model, *model_args, export_options=export_options, **model_kwargs + ) + + +# TODO(justinchuby): Deprecate these logging functions in favor of the new diagnostic module. + +# Returns True iff ONNX logging is turned on. +is_onnx_log_enabled = _C._jit_is_onnx_log_enabled + + +def enable_log() -> None: + r"""Enables ONNX logging.""" + _C._jit_set_onnx_log_enabled(True) + + +def disable_log() -> None: + r"""Disables ONNX logging.""" + _C._jit_set_onnx_log_enabled(False) + + +"""Sets output stream for ONNX logging. + +Args: + stream_name (str, default "stdout"): Only 'stdout' and 'stderr' are supported + as ``stream_name``. +""" +set_log_stream = _C._jit_set_onnx_log_output_stream + + +"""A simple logging facility for ONNX exporter. + +Args: + args: Arguments are converted to string, concatenated together with a newline + character appended to the end, and flushed to output stream. 
+""" +log = _C._jit_onnx_log diff --git a/lib/python3.10/site-packages/torch/onnx/_constants.py b/lib/python3.10/site-packages/torch/onnx/_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..6c91b245ed703f3b539b3baff36c25e278f40134 --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/_constants.py @@ -0,0 +1,25 @@ +"""Constant values used in ONNX.""" + +ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO" + +ONNX_BASE_OPSET = 9 +ONNX_MIN_OPSET = 7 +ONNX_MAX_OPSET = 20 +ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET = 20 +# ONNX_DEFAULT_OPSET generated by tools/onnx/update_default_opset_version.py +ONNX_DEFAULT_OPSET = 17 +ONNX_CONSTANT_FOLDING_MIN_OPSET = 9 + +PYTORCH_GITHUB_ISSUES_URL = "https://github.com/pytorch/pytorch/issues" + +INT64_MAX = 9223372036854775807 +INT32_MAX = 2147483647 +INT16_MAX = 32767 +INT8_MAX = 127 +UINT8_MAX = 255 + +INT64_MIN = -9223372036854775808 +INT32_MIN = -2147483648 +INT16_MIN = -32768 +INT8_MIN = -128 +UINT8_MIN = 0 diff --git a/lib/python3.10/site-packages/torch/onnx/_deprecation.py b/lib/python3.10/site-packages/torch/onnx/_deprecation.py new file mode 100644 index 0000000000000000000000000000000000000000..24fe4ccc54fcdbe2dc8cd99ab802b8fdf25e9821 --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/_deprecation.py @@ -0,0 +1,72 @@ +"""Utility for deprecating functions.""" + +import functools +import textwrap +import warnings +from typing import Callable, TypeVar +from typing_extensions import ParamSpec + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + + +def deprecated( + since: str, removed_in: str, instructions: str +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + """Marks functions as deprecated. + + It will result in a warning when the function is called and a note in the + docstring. + + Args: + since: The version when the function was first deprecated. + removed_in: The version when the function will be removed. + instructions: The action users should take. + """ + + def decorator(function: Callable[_P, _T]) -> Callable[_P, _T]: + @functools.wraps(function) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + warnings.warn( + f"'{function.__module__}.{function.__name__}' " + f"is deprecated in version {since} and will be " + f"removed in {removed_in}. Please {instructions}.", + category=FutureWarning, + stacklevel=2, + ) + return function(*args, **kwargs) + + # Add a deprecation note to the docstring. + docstring = function.__doc__ or "" + + # Add a note to the docstring. + deprecation_note = textwrap.dedent( + f"""\ + .. deprecated:: {since} + Deprecated and will be removed in version {removed_in}. + Please {instructions}. + """ + ) + + # Split docstring at first occurrence of newline + summary_and_body = docstring.split("\n\n", 1) + + if len(summary_and_body) > 1: + summary, body = summary_and_body + + # Dedent the body. We cannot do this with the presence of the summary because + # the body contains leading whitespaces when the summary does not. 
+ body = textwrap.dedent(body) + + new_docstring_parts = [deprecation_note, "\n\n", summary, body] + else: + summary = summary_and_body[0] + + new_docstring_parts = [deprecation_note, "\n\n", summary] + + wrapper.__doc__ = "".join(new_docstring_parts) + + return wrapper + + return decorator diff --git a/lib/python3.10/site-packages/torch/onnx/_experimental.py b/lib/python3.10/site-packages/torch/onnx/_experimental.py new file mode 100644 index 0000000000000000000000000000000000000000..86c035d412fd7da2cd7b03a6e187dfca8a3a6061 --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/_experimental.py @@ -0,0 +1,27 @@ +"""Experimental classes and functions used by ONNX export.""" + +import dataclasses +from typing import Mapping, Optional, Sequence, Set, Type, Union + +import torch +import torch._C._onnx as _C_onnx + + +@dataclasses.dataclass +class ExportOptions: + """Arguments used by :func:`torch.onnx.export`.""" + + # TODO(justinchuby): Deprecate and remove this class. + + export_params: bool = True + verbose: bool = False + training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL + input_names: Optional[Sequence[str]] = None + output_names: Optional[Sequence[str]] = None + operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX + opset_version: Optional[int] = None + do_constant_folding: bool = True + dynamic_axes: Optional[Mapping[str, Union[Mapping[int, str], Sequence[int]]]] = None + keep_initializers_as_inputs: Optional[bool] = None + custom_opsets: Optional[Mapping[str, int]] = None + export_modules_as_functions: Union[bool, Set[Type[torch.nn.Module]]] = False diff --git a/lib/python3.10/site-packages/torch/onnx/_exporter_states.py b/lib/python3.10/site-packages/torch/onnx/_exporter_states.py new file mode 100644 index 0000000000000000000000000000000000000000..2fdf7a7ac95c7b58a5e21145bb11cedc3ad29a88 --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/_exporter_states.py @@ -0,0 +1,12 @@ +from __future__ import annotations + + +class ExportTypes: + """Specifies how the ONNX model is stored.""" + + # TODO(justinchuby): Deprecate and remove this class. + + PROTOBUF_FILE = "Saves model in the specified protobuf file." + ZIP_ARCHIVE = "Saves model in the specified ZIP file (uncompressed)." + COMPRESSED_ZIP_ARCHIVE = "Saves model in the specified ZIP file (compressed)." + DIRECTORY = "Saves model in the specified folder." diff --git a/lib/python3.10/site-packages/torch/onnx/_flags.py b/lib/python3.10/site-packages/torch/onnx/_flags.py new file mode 100644 index 0000000000000000000000000000000000000000..6bbabef61870255fffd4be5ca26782e109a77597 --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/_flags.py @@ -0,0 +1,49 @@ +"""Internal feature flags for torch.onnx. + +NOTE: These flags are experimental only. Any flag here can be removed at any +time without notice. +""" + +import logging +import os + + +logger = logging.getLogger(__name__) + + +def _load_boolean_flag( + name: str, + *, + this_will: str, + deprecated: bool = False, + default: bool = False, +) -> bool: + """Load a boolean flag from environment variable. + + Args: + name: The name of the environment variable. + this_will: A string that describes what this flag will do. + deprecated: Whether this flag is deprecated. + default: The default value if envvar not defined. + """ + undefined = os.getenv(name) is None + state = os.getenv(name) == "1" + if state: + if deprecated: + logger.error( + "Experimental flag %s is deprecated. 
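# --- Illustrative usage sketch (editor's aside, not part of the patch) ---
# How the `deprecated` decorator defined above is meant to be used: calling
# the wrapped function emits a FutureWarning, and a ".. deprecated::" note is
# prepended to its docstring. `old_helper` is a made-up placeholder.
from torch.onnx import _deprecation

@_deprecation.deprecated(
    since="2.2",
    removed_in="a future release",
    instructions="use torch.onnx.export instead",
)
def old_helper(x):
    """Do something the old way."""
    return x

old_helper(1)  # warns: "'...old_helper' is deprecated in version 2.2 ..."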
Please remove it from your environment.", + name, + ) + else: + logger.warning( + "Experimental flag %s is enabled. This will %s.", name, this_will + ) + if undefined: + state = default + return state + + +USE_EXPERIMENTAL_LOGIC: bool = _load_boolean_flag( + "TORCH_ONNX_USE_EXPERIMENTAL_LOGIC", + this_will="use ExportedProgram and the new torch.onnx export logic", +) diff --git a/lib/python3.10/site-packages/torch/onnx/_globals.py b/lib/python3.10/site-packages/torch/onnx/_globals.py new file mode 100644 index 0000000000000000000000000000000000000000..f3dd273386f8f06edd41a1a1b53771c101f2902a --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/_globals.py @@ -0,0 +1,87 @@ +# mypy: allow-untyped-defs +"""Globals used internally by the ONNX exporter. + +Do not use this module outside of `torch.onnx` and its tests. + +Be very judicious when adding any new global variables. Do not create new global +variables unless they are absolutely necessary. +""" + +import torch._C._onnx as _C_onnx + +# This module should only depend on _constants and nothing else in torch.onnx to keep +# dependency direction clean. +from torch.onnx import _constants + + +class _InternalGlobals: + """Globals used internally by ONNX exporter. + + NOTE: Be very judicious when adding any new variables. Do not create new + global variables unless they are absolutely necessary. + """ + + def __init__(self) -> None: + self._export_onnx_opset_version = _constants.ONNX_DEFAULT_OPSET + self._training_mode: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL + self._in_onnx_export: bool = False + # Whether the user's model is training during export + self.export_training: bool = False + self.operator_export_type: _C_onnx.OperatorExportTypes = ( + _C_onnx.OperatorExportTypes.ONNX + ) + self.onnx_shape_inference: bool = True + self._autograd_inlining: bool = True + + @property + def training_mode(self): + """The training mode for the exporter.""" + return self._training_mode + + @training_mode.setter + def training_mode(self, training_mode: _C_onnx.TrainingMode): + if not isinstance(training_mode, _C_onnx.TrainingMode): + raise TypeError( + "training_mode must be of type 'torch.onnx.TrainingMode'. This is " + "likely a bug in torch.onnx." 
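# --- Illustrative usage sketch (editor's aside, not part of the patch) ---
# `_load_boolean_flag` above reads the environment variable once, at import
# time, so the flag has to be set before `torch.onnx` is first imported:
import os

os.environ["TORCH_ONNX_USE_EXPERIMENTAL_LOGIC"] = "1"
import torch.onnx  # noqa: E402 -- _flags.USE_EXPERIMENTAL_LOGIC is now True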
+ ) + self._training_mode = training_mode + + @property + def export_onnx_opset_version(self) -> int: + """Opset version used during export.""" + return self._export_onnx_opset_version + + @export_onnx_opset_version.setter + def export_onnx_opset_version(self, value: int): + supported_versions = range( + _constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1 + ) + if value not in supported_versions: + raise ValueError(f"Unsupported ONNX opset version: {value}") + self._export_onnx_opset_version = value + + @property + def in_onnx_export(self) -> bool: + """Whether it is in the middle of ONNX export.""" + return self._in_onnx_export + + @in_onnx_export.setter + def in_onnx_export(self, value: bool): + if type(value) is not bool: + raise TypeError("in_onnx_export must be a boolean") + self._in_onnx_export = value + + @property + def autograd_inlining(self) -> bool: + """Whether Autograd must be inlined.""" + return self._autograd_inlining + + @autograd_inlining.setter + def autograd_inlining(self, value: bool): + if type(value) is not bool: + raise TypeError("autograd_inlining must be a boolean") + self._autograd_inlining = value + + +GLOBALS = _InternalGlobals() diff --git a/lib/python3.10/site-packages/torch/onnx/_onnx_supported_ops.py b/lib/python3.10/site-packages/torch/onnx/_onnx_supported_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e2707298d6d908639104ffcc36846f2e2430ee7b --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/_onnx_supported_ops.py @@ -0,0 +1,98 @@ +# mypy: allow-untyped-defs +import inspect +from typing import Dict, List, Union + +from torch import _C +from torch.onnx import _constants +from torch.onnx._internal import registration + + +class _TorchSchema: + def __init__(self, schema: Union[_C.FunctionSchema, str]) -> None: + if isinstance(schema, _C.FunctionSchema): + self.name: str = schema.name + self.overload_name: str = schema.overload_name + self.arguments: List[str] = [arg.name for arg in schema.arguments] + self.optional_arguments: List[str] = [] + self.returns: List[str] = [ret.name for ret in schema.returns] + self.opsets: List[int] = [] + else: + self.name = schema + self.overload_name = "" + self.arguments = [] + self.optional_arguments = [] + self.returns = [] + self.opsets = [] + + def __str__(self) -> str: + s = ( + f"{self.name}.{self.overload_name}(" + + ", ".join(self.arguments) + + ") -> (" + + ", ".join(self.returns) + + ")" + + " in opsets " + + ", ".join(str(opset) for opset in self.opsets) + ) + return s + + def __hash__(self): + # TODO(thiagocrepaldi): handle overload_name? + return hash(self.name) + + def __eq__(self, other) -> bool: + if not isinstance(other, _TorchSchema): + return False + # TODO(thiagocrepaldi): handle overload_name? 
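# --- Illustrative sketch (editor's aside, not part of the patch) ---
# GLOBALS is internal to torch.onnx, but the setter above shows the contract:
# opset versions outside [ONNX_MIN_OPSET, ONNX_MAX_OPSET] (7..20 here) are rejected.
from torch.onnx._globals import GLOBALS

GLOBALS.export_onnx_opset_version = 18  # accepted: 7 <= 18 <= 20
try:
    GLOBALS.export_onnx_opset_version = 99
except ValueError as err:
    print(err)  # "Unsupported ONNX opset version: 99"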
+ return self.name == other.name + + def is_aten(self) -> bool: + return self.name.startswith("aten::") + + def is_backward(self) -> bool: + return "backward" in self.name + + +def _symbolic_argument_count(func): + params = [] + signature = inspect.signature(func) + optional_params = [] + for name, parameter in signature.parameters.items(): + if name in {"_outputs", "g"}: + continue + if parameter.default is parameter.empty: + optional_params.append(parameter) + else: + params.append(str(parameter)) + return params + + +def all_forward_schemas() -> Dict[str, _TorchSchema]: + """Returns schemas for all TorchScript forward ops.""" + torch_schemas = [_TorchSchema(s) for s in _C._jit_get_all_schemas()] + return {schema.name: schema for schema in torch_schemas if not schema.is_backward()} + + +def all_symbolics_schemas() -> Dict[str, _TorchSchema]: + """Returns schemas for all onnx supported ops.""" + symbolics_schemas = {} + + for name in registration.registry.all_functions(): + func_group = registration.registry.get_function_group(name) + assert func_group is not None + symbolics_schema = _TorchSchema(name) + func = func_group.get(_constants.ONNX_MAX_OPSET) + if func is not None: + symbolics_schema.arguments = _symbolic_argument_count(func) + symbolics_schema.opsets = list( + range(func_group.get_min_supported(), _constants.ONNX_MAX_OPSET + 1) + ) + else: + # Only support opset < 9 + func = func_group.get(7) + symbolics_schema.arguments = _symbolic_argument_count(func) + symbolics_schema.opsets = list(range(7, _constants.ONNX_BASE_OPSET)) + + symbolics_schemas[name] = symbolics_schema + + return symbolics_schemas diff --git a/lib/python3.10/site-packages/torch/onnx/_type_utils.py b/lib/python3.10/site-packages/torch/onnx/_type_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..81bcaeef1107a1ca3291a4468e78c28ed0a906c8 --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/_type_utils.py @@ -0,0 +1,391 @@ +# mypy: allow-untyped-defs +"""Utilities for converting and operating on ONNX, JIT and torch types.""" + +from __future__ import annotations + +import enum +import typing +from typing import Literal + +import torch +from torch._C import _onnx as _C_onnx +from torch.onnx import errors + + +if typing.TYPE_CHECKING: + # Hack to help mypy to recognize torch._C.Value + from torch import _C # noqa: F401 + +ScalarName = Literal[ + "Byte", + "Char", + "Double", + "Float", + "Half", + "Int", + "Long", + "Short", + "Bool", + "ComplexHalf", + "ComplexFloat", + "ComplexDouble", + "QInt8", + "QUInt8", + "QInt32", + "BFloat16", + "Float8E5M2", + "Float8E4M3FN", + "Float8E5M2FNUZ", + "Float8E4M3FNUZ", + "Undefined", +] + +TorchName = Literal[ + "bool", + "uint8_t", + "int8_t", + "double", + "float", + "half", + "int", + "int64_t", + "int16_t", + "complex32", + "complex64", + "complex128", + "qint8", + "quint8", + "qint32", + "bfloat16", + "float8_e5m2", + "float8_e4m3fn", + "float8_e5m2fnuz", + "float8_e4m3fnuz", +] + + +class JitScalarType(enum.IntEnum): + """Scalar types defined in torch. + + Use ``JitScalarType`` to convert from torch and JIT scalar types to ONNX scalar types. 
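# --- Illustrative sketch (editor's aside, not part of the patch) ---
# The two lookups above can be combined to ask whether an aten op has an ONNX
# symbolic and for which opsets. The module is internal, so this is purely a
# sketch of the intended use.
from torch.onnx import _onnx_supported_ops

symbolics = _onnx_supported_ops.all_symbolics_schemas()
schema = symbolics.get("aten::relu")
if schema is not None:
    print(schema.opsets)  # the opset range this symbolic supports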
+ + Examples: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) + >>> # xdoctest: +IGNORE_WANT("win32 has different output") + >>> JitScalarType.from_value(torch.ones(1, 2)).onnx_type() + TensorProtoDataType.FLOAT + + >>> JitScalarType.from_value(torch_c_value_with_type_float).onnx_type() + TensorProtoDataType.FLOAT + + >>> JitScalarType.from_dtype(torch.get_default_dtype).onnx_type() + TensorProtoDataType.FLOAT + + """ + + # Order defined in https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h + UINT8 = 0 + INT8 = enum.auto() # 1 + INT16 = enum.auto() # 2 + INT = enum.auto() # 3 + INT64 = enum.auto() # 4 + HALF = enum.auto() # 5 + FLOAT = enum.auto() # 6 + DOUBLE = enum.auto() # 7 + COMPLEX32 = enum.auto() # 8 + COMPLEX64 = enum.auto() # 9 + COMPLEX128 = enum.auto() # 10 + BOOL = enum.auto() # 11 + QINT8 = enum.auto() # 12 + QUINT8 = enum.auto() # 13 + QINT32 = enum.auto() # 14 + BFLOAT16 = enum.auto() # 15 + FLOAT8E5M2 = enum.auto() # 16 + FLOAT8E4M3FN = enum.auto() # 17 + FLOAT8E5M2FNUZ = enum.auto() # 18 + FLOAT8E4M3FNUZ = enum.auto() # 19 + UNDEFINED = enum.auto() # 20 + + @classmethod + def _from_name(cls, name: ScalarName | TorchName | str | None) -> JitScalarType: + """Convert a JIT scalar type or torch type name to ScalarType. + + Note: DO NOT USE this API when `name` comes from a `torch._C.Value.type()` calls. + A "RuntimeError: INTERNAL ASSERT FAILED at "../aten/src/ATen/core/jit_type_base.h" can + be raised in several scenarios where shape info is not present. + Instead use `from_value` API which is safer. + + Args: + name: JIT scalar type name (Byte) or torch type name (uint8_t). + + Returns: + JitScalarType + + Raises: + OnnxExporterError: if name is not a valid scalar type name or if it is None. + """ + if name is None: + raise errors.OnnxExporterError("Scalar type name cannot be None") + if valid_scalar_name(name): + return _SCALAR_NAME_TO_TYPE[name] # type: ignore[index] + if valid_torch_name(name): + return _TORCH_NAME_TO_SCALAR_TYPE[name] # type: ignore[index] + + raise errors.OnnxExporterError(f"Unknown torch or scalar type: '{name}'") + + @classmethod + def from_dtype(cls, dtype: torch.dtype | None) -> JitScalarType: + """Convert a torch dtype to JitScalarType. + + Note: DO NOT USE this API when `dtype` comes from a `torch._C.Value.type()` calls. + A "RuntimeError: INTERNAL ASSERT FAILED at "../aten/src/ATen/core/jit_type_base.h" can + be raised in several scenarios where shape info is not present. + Instead use `from_value` API which is safer. + + Args: + dtype: A torch.dtype to create a JitScalarType from + + Returns: + JitScalarType + + Raises: + OnnxExporterError: if dtype is not a valid torch.dtype or if it is None. + """ + if dtype not in _DTYPE_TO_SCALAR_TYPE: + raise errors.OnnxExporterError(f"Unknown dtype: {dtype}") + return _DTYPE_TO_SCALAR_TYPE[dtype] + + @classmethod + def from_onnx_type( + cls, onnx_type: int | _C_onnx.TensorProtoDataType | None + ) -> JitScalarType: + """Convert a ONNX data type to JitScalarType. + + Args: + onnx_type: A torch._C._onnx.TensorProtoDataType to create a JitScalarType from + + Returns: + JitScalarType + + Raises: + OnnxExporterError: if dtype is not a valid torch.dtype or if it is None. 
+ """ + if onnx_type not in _ONNX_TO_SCALAR_TYPE: + raise errors.OnnxExporterError(f"Unknown onnx_type: {onnx_type}") + return _ONNX_TO_SCALAR_TYPE[typing.cast(_C_onnx.TensorProtoDataType, onnx_type)] + + @classmethod + def from_value( + cls, value: None | torch._C.Value | torch.Tensor, default=None + ) -> JitScalarType: + """Create a JitScalarType from an value's scalar type. + + Args: + value: An object to fetch scalar type from. + default: The JitScalarType to return if a valid scalar cannot be fetched from value + + Returns: + JitScalarType. + + Raises: + OnnxExporterError: if value does not have a valid scalar type and default is None. + SymbolicValueError: when value.type()'s info are empty and default is None + """ + + if not isinstance(value, (torch._C.Value, torch.Tensor)) or ( + isinstance(value, torch._C.Value) and value.node().mustBeNone() + ): + # default value of type JitScalarType is returned when value is not valid + if default is None: + raise errors.OnnxExporterError( + "value must be either torch._C.Value or torch.Tensor objects." + ) + elif not isinstance(default, JitScalarType): + raise errors.OnnxExporterError( + "default value must be a JitScalarType object." + ) + return default + + # Each value type has their own way of storing scalar type + if isinstance(value, torch.Tensor): + return cls.from_dtype(value.dtype) + if isinstance(value.type(), torch.ListType): + try: + return cls.from_dtype(value.type().getElementType().dtype()) + except RuntimeError: + return cls._from_name(str(value.type().getElementType())) + if isinstance(value.type(), torch._C.OptionalType): + if value.type().getElementType().dtype() is None: + if isinstance(default, JitScalarType): + return default + raise errors.OnnxExporterError( + "default value must be a JitScalarType object." + ) + return cls.from_dtype(value.type().getElementType().dtype()) + + scalar_type = None + if value.node().kind() != "prim::Constant" or not isinstance( + value.type(), torch._C.NoneType + ): + # value must be a non-list torch._C.Value scalar + scalar_type = value.type().scalarType() + + if scalar_type is not None: + return cls._from_name(scalar_type) + + # When everything fails... 
try to default + if default is not None: + return default + raise errors.SymbolicValueError( + f"Cannot determine scalar type for this '{type(value.type())}' instance and " + "a default value was not provided.", + value, + ) + + def scalar_name(self) -> ScalarName: + """Convert a JitScalarType to a JIT scalar type name.""" + return _SCALAR_TYPE_TO_NAME[self] + + def torch_name(self) -> TorchName: + """Convert a JitScalarType to a torch type name.""" + return _SCALAR_TYPE_TO_TORCH_NAME[self] + + def dtype(self) -> torch.dtype: + """Convert a JitScalarType to a torch dtype.""" + return _SCALAR_TYPE_TO_DTYPE[self] + + def onnx_type(self) -> _C_onnx.TensorProtoDataType: + """Convert a JitScalarType to an ONNX data type.""" + if self not in _SCALAR_TYPE_TO_ONNX: + raise errors.OnnxExporterError( + f"Scalar type {self} cannot be converted to ONNX" + ) + return _SCALAR_TYPE_TO_ONNX[self] + + def onnx_compatible(self) -> bool: + """Return whether this JitScalarType is compatible with ONNX.""" + return ( + self in _SCALAR_TYPE_TO_ONNX + and self != JitScalarType.UNDEFINED + and self != JitScalarType.COMPLEX32 + ) + + +def valid_scalar_name(scalar_name: ScalarName | str) -> bool: + """Return whether the given scalar name is a valid JIT scalar type name.""" + return scalar_name in _SCALAR_NAME_TO_TYPE + + +def valid_torch_name(torch_name: TorchName | str) -> bool: + """Return whether the given torch name is a valid torch type name.""" + return torch_name in _TORCH_NAME_TO_SCALAR_TYPE + + +# https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h +_SCALAR_TYPE_TO_NAME: dict[JitScalarType, ScalarName] = { + JitScalarType.BOOL: "Bool", + JitScalarType.UINT8: "Byte", + JitScalarType.INT8: "Char", + JitScalarType.INT16: "Short", + JitScalarType.INT: "Int", + JitScalarType.INT64: "Long", + JitScalarType.HALF: "Half", + JitScalarType.FLOAT: "Float", + JitScalarType.DOUBLE: "Double", + JitScalarType.COMPLEX32: "ComplexHalf", + JitScalarType.COMPLEX64: "ComplexFloat", + JitScalarType.COMPLEX128: "ComplexDouble", + JitScalarType.QINT8: "QInt8", + JitScalarType.QUINT8: "QUInt8", + JitScalarType.QINT32: "QInt32", + JitScalarType.BFLOAT16: "BFloat16", + JitScalarType.FLOAT8E5M2: "Float8E5M2", + JitScalarType.FLOAT8E4M3FN: "Float8E4M3FN", + JitScalarType.FLOAT8E5M2FNUZ: "Float8E5M2FNUZ", + JitScalarType.FLOAT8E4M3FNUZ: "Float8E4M3FNUZ", + JitScalarType.UNDEFINED: "Undefined", +} + +_SCALAR_NAME_TO_TYPE: dict[ScalarName, JitScalarType] = { + v: k for k, v in _SCALAR_TYPE_TO_NAME.items() +} + +_SCALAR_TYPE_TO_TORCH_NAME: dict[JitScalarType, TorchName] = { + JitScalarType.BOOL: "bool", + JitScalarType.UINT8: "uint8_t", + JitScalarType.INT8: "int8_t", + JitScalarType.INT16: "int16_t", + JitScalarType.INT: "int", + JitScalarType.INT64: "int64_t", + JitScalarType.HALF: "half", + JitScalarType.FLOAT: "float", + JitScalarType.DOUBLE: "double", + JitScalarType.COMPLEX32: "complex32", + JitScalarType.COMPLEX64: "complex64", + JitScalarType.COMPLEX128: "complex128", + JitScalarType.QINT8: "qint8", + JitScalarType.QUINT8: "quint8", + JitScalarType.QINT32: "qint32", + JitScalarType.BFLOAT16: "bfloat16", + JitScalarType.FLOAT8E5M2: "float8_e5m2", + JitScalarType.FLOAT8E4M3FN: "float8_e4m3fn", + JitScalarType.FLOAT8E5M2FNUZ: "float8_e5m2fnuz", + JitScalarType.FLOAT8E4M3FNUZ: "float8_e4m3fnuz", +} + +_TORCH_NAME_TO_SCALAR_TYPE: dict[TorchName, JitScalarType] = { + v: k for k, v in _SCALAR_TYPE_TO_TORCH_NAME.items() +} + +_SCALAR_TYPE_TO_ONNX = { + JitScalarType.BOOL: 
_C_onnx.TensorProtoDataType.BOOL, + JitScalarType.UINT8: _C_onnx.TensorProtoDataType.UINT8, + JitScalarType.INT8: _C_onnx.TensorProtoDataType.INT8, + JitScalarType.INT16: _C_onnx.TensorProtoDataType.INT16, + JitScalarType.INT: _C_onnx.TensorProtoDataType.INT32, + JitScalarType.INT64: _C_onnx.TensorProtoDataType.INT64, + JitScalarType.HALF: _C_onnx.TensorProtoDataType.FLOAT16, + JitScalarType.FLOAT: _C_onnx.TensorProtoDataType.FLOAT, + JitScalarType.DOUBLE: _C_onnx.TensorProtoDataType.DOUBLE, + JitScalarType.COMPLEX64: _C_onnx.TensorProtoDataType.COMPLEX64, + JitScalarType.COMPLEX128: _C_onnx.TensorProtoDataType.COMPLEX128, + JitScalarType.BFLOAT16: _C_onnx.TensorProtoDataType.BFLOAT16, + JitScalarType.UNDEFINED: _C_onnx.TensorProtoDataType.UNDEFINED, + JitScalarType.COMPLEX32: _C_onnx.TensorProtoDataType.UNDEFINED, + JitScalarType.QINT8: _C_onnx.TensorProtoDataType.INT8, + JitScalarType.QUINT8: _C_onnx.TensorProtoDataType.UINT8, + JitScalarType.QINT32: _C_onnx.TensorProtoDataType.INT32, + JitScalarType.FLOAT8E5M2: _C_onnx.TensorProtoDataType.FLOAT8E5M2, + JitScalarType.FLOAT8E4M3FN: _C_onnx.TensorProtoDataType.FLOAT8E4M3FN, + JitScalarType.FLOAT8E5M2FNUZ: _C_onnx.TensorProtoDataType.FLOAT8E5M2FNUZ, + JitScalarType.FLOAT8E4M3FNUZ: _C_onnx.TensorProtoDataType.FLOAT8E4M3FNUZ, +} + +_ONNX_TO_SCALAR_TYPE = {v: k for k, v in _SCALAR_TYPE_TO_ONNX.items()} + +# source of truth is +# https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_dtypes.cpp +_SCALAR_TYPE_TO_DTYPE = { + JitScalarType.BOOL: torch.bool, + JitScalarType.UINT8: torch.uint8, + JitScalarType.INT8: torch.int8, + JitScalarType.INT16: torch.short, + JitScalarType.INT: torch.int, + JitScalarType.INT64: torch.int64, + JitScalarType.HALF: torch.half, + JitScalarType.FLOAT: torch.float, + JitScalarType.DOUBLE: torch.double, + JitScalarType.COMPLEX32: torch.complex32, + JitScalarType.COMPLEX64: torch.complex64, + JitScalarType.COMPLEX128: torch.complex128, + JitScalarType.QINT8: torch.qint8, + JitScalarType.QUINT8: torch.quint8, + JitScalarType.QINT32: torch.qint32, + JitScalarType.BFLOAT16: torch.bfloat16, + JitScalarType.FLOAT8E5M2: torch.float8_e5m2, + JitScalarType.FLOAT8E4M3FN: torch.float8_e4m3fn, + JitScalarType.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz, + JitScalarType.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz, +} + +_DTYPE_TO_SCALAR_TYPE = {v: k for k, v in _SCALAR_TYPE_TO_DTYPE.items()} diff --git a/lib/python3.10/site-packages/torch/onnx/errors.py b/lib/python3.10/site-packages/torch/onnx/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..f6e035a8a85f1d4786e2b08f1a73231c8aec9696 --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/errors.py @@ -0,0 +1,103 @@ +"""ONNX exporter exceptions.""" + +from __future__ import annotations + + +__all__ = [ + "OnnxExporterWarning", + "SymbolicValueError", + "UnsupportedOperatorError", +] + +import textwrap +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from torch import _C + + +class OnnxExporterWarning(UserWarning): + """Warnings in the ONNX exporter.""" + + +class OnnxExporterError(RuntimeError): + """Errors raised by the ONNX exporter. 
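# --- Illustrative sketch (editor's aside, not part of the patch) ---
# The lookup tables above make JitScalarType a hub between torch dtypes, JIT
# scalar-type names and ONNX element types:
import torch
from torch.onnx._type_utils import JitScalarType

st = JitScalarType.from_dtype(torch.float16)
print(st.scalar_name())      # "Half"
print(st.torch_name())       # "half"
print(st.onnx_type())        # TensorProtoDataType.FLOAT16
print(st.dtype())            # torch.float16
print(st.onnx_compatible())  # True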
This is the base class for all exporter errors.""" + + +class UnsupportedOperatorError(OnnxExporterError): + """Raised when an operator is unsupported by the exporter.""" + + # NOTE: This is legacy and is only used by the torchscript exporter + # Clean up when the torchscript exporter is removed + def __init__(self, name: str, version: int, supported_version: int | None): + from torch.onnx import _constants + from torch.onnx._internal import diagnostics + + if supported_version is not None: + diagnostic_rule: diagnostics.infra.Rule = ( + diagnostics.rules.operator_supported_in_newer_opset_version + ) + msg = diagnostic_rule.format_message(name, version, supported_version) + diagnostics.diagnose(diagnostic_rule, diagnostics.levels.ERROR, msg) + else: + if name.startswith(("aten::", "prim::", "quantized::")): + diagnostic_rule = diagnostics.rules.missing_standard_symbolic_function + msg = diagnostic_rule.format_message( + name, version, _constants.PYTORCH_GITHUB_ISSUES_URL + ) + diagnostics.diagnose(diagnostic_rule, diagnostics.levels.ERROR, msg) + else: + diagnostic_rule = diagnostics.rules.missing_custom_symbolic_function + msg = diagnostic_rule.format_message(name) + diagnostics.diagnose(diagnostic_rule, diagnostics.levels.ERROR, msg) + super().__init__(msg) + + +class SymbolicValueError(OnnxExporterError): + """Errors around TorchScript values and nodes.""" + + # NOTE: This is legacy and is only used by the torchscript exporter + # Clean up when the torchscript exporter is removed + def __init__(self, msg: str, value: _C.Value): + message = ( + f"{msg} [Caused by the value '{value}' (type '{value.type()}') in the " + f"TorchScript graph. The containing node has kind '{value.node().kind()}'.] " + ) + + code_location = value.node().sourceRange() + if code_location: + message += f"\n (node defined in {code_location})" + + try: + # Add its input and output to the message. + message += "\n\n" + message += textwrap.indent( + ( + "Inputs:\n" + + ( + "\n".join( + f" #{i}: {input_} (type '{input_.type()}')" + for i, input_ in enumerate(value.node().inputs()) + ) + or " Empty" + ) + + "\n" + + "Outputs:\n" + + ( + "\n".join( + f" #{i}: {output} (type '{output.type()}')" + for i, output in enumerate(value.node().outputs()) + ) + or " Empty" + ) + ), + " ", + ) + except AttributeError: + message += ( + " Failed to obtain its input and output for debugging. " + "Please refer to the TorchScript graph for debugging information." + ) + + super().__init__(message) diff --git a/lib/python3.10/site-packages/torch/onnx/operators.py b/lib/python3.10/site-packages/torch/onnx/operators.py new file mode 100644 index 0000000000000000000000000000000000000000..88ac6779f91ccec30518b5fa302f4d3f16e6b77a --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/operators.py @@ -0,0 +1,47 @@ +# mypy: allow-untyped-defs +r"""This file provides a location for operators that help exporting models via onnx. + +E.g. `shape_as_tensor` and `reshape_from_tensor_shape` +are to make all dynamic sizes operations traceable. + +NOTE: at one point these functions were implemented differently. +Since then we have implemented these directly in ATen, so this +file is kept purely for backward-compatibility. +""" + +import torch +import torch.onnx + + +def shape_as_tensor(x): + """Get the shape of a tensor as a tensor. + + Args: + x (Tensor): The input tensor. + + Returns: + Tensor: A tensor of shape [len(x.shape)] containing the size of each dimension of x. 
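# --- Illustrative sketch (editor's aside, not part of the patch) ---
# Because OnnxExporterError above is the base class for exporter failures,
# callers can catch it around an export call; the model and file name below
# are placeholders.
import torch
from torch.onnx import errors

try:
    torch.onnx.export(torch.nn.ReLU(), (torch.randn(2, 2),), "relu.onnx")
except errors.OnnxExporterError as exc:
    print(f"export failed: {exc}")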
+ + Example: + >>> x = torch.randn(2, 3) + >>> shape_as_tensor(x) + tensor([2, 3]) + + """ + return torch._shape_as_tensor(x) + + +def reshape_from_tensor_shape(x, shape): + """Reshape a tensor to the given shape. + + This function is used to make dynamic size operations traceable when exporting models via ONNX. + This function is kept for backward-compatibility. It is implemented directly in ATen. + + Parameters: + x (Tensor): the tensor to be reshaped. + shape (Tensor): the target shape. + + Returns: + Tensor: the reshaped tensor. + """ + return torch._reshape_from_tensor(x, shape) diff --git a/lib/python3.10/site-packages/torch/onnx/symbolic_caffe2.py b/lib/python3.10/site-packages/torch/onnx/symbolic_caffe2.py new file mode 100644 index 0000000000000000000000000000000000000000..83a2ff6c32ec9511a86bf42b9ad76c7a24a5b2cf --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/symbolic_caffe2.py @@ -0,0 +1,361 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +import importlib +import inspect + +from torch.onnx import symbolic_helper, symbolic_opset9 as opset9 +from torch.onnx._internal import jit_utils, registration + + +def register_quantized_ops(domain: str, version: int): + # Register all quantized ops + module = importlib.import_module("torch.onnx.symbolic_caffe2") + quant_version_ops = inspect.getmembers(module) + aten_q_ops = { + "relu", + "_empty_affine_quantized", + "dequantize", + "quantize_per_tensor", + "upsample_nearest2d", + "avg_pool2d", + "reshape", + "slice", + "cat", + "max_pool2d", + "sigmoid", + } + for op, func in quant_version_ops: + name = f"{domain}::{op}" + if inspect.isfunction(func) and not registration.registry.is_registered_op( + name, version + ): + if op in aten_q_ops: + # Override the builtin aten ops + registration.registry.register( + f"aten::{op}", version, func, custom=True + ) + registration.registry.register(name, version, func) + + +def _permute_helper(g: jit_utils.GraphContext, input, axes): + quant_args = { + "axes_i": axes, + "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), + "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), + } + output = g.op("_caffe2::Int8Transpose", input, **quant_args) + symbolic_helper._quantized_ops.add(output) + return output + + +def nchw2nhwc(g: jit_utils.GraphContext, input): + axes = [0, 2, 3, 1] + return _permute_helper(g, input, axes) + + +def nhwc2nchw(g: jit_utils.GraphContext, input): + axes = [0, 3, 1, 2] + return _permute_helper(g, input, axes) + + +def linear_prepack(g: jit_utils.GraphContext, weight, bias): + # Mapping to a dummy caffe2 prepack node. + # During the onnx -> c2 conversion we can look up original weight and bias + # from this node + output = g.op("_caffe2::WeightPrepack", weight, bias) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "v", "v", "f", "i") +def linear(g: jit_utils.GraphContext, input, weight, bias, scale, zero_point): + kwargs = { + "Y_scale_f": scale, + "Y_zero_point_i": zero_point, + } + output = g.op("_caffe2::Int8FC", input, weight, bias, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +def conv_prepack( + g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups +): + # Mapping to a dummy caffe2 prepack node. 
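# --- Illustrative usage sketch (editor's aside, not part of the patch) ---
# The two helpers above keep size-dependent reshapes traceable by carrying the
# shape as a tensor rather than a Python tuple:
import torch
from torch.onnx import operators

x = torch.randn(2, 3)
shape = operators.shape_as_tensor(x)                          # tensor([2, 3])
restored = operators.reshape_from_tensor_shape(x.reshape(-1), shape)
print(restored.shape)                                         # torch.Size([2, 3])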
+ # During the onnx -> c2 conversion we can look up original weight and bias + # from this node + output = g.op("_caffe2::WeightPrepack", input, weight, bias) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i") +def conv2d( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + scale, + zero_point, +): + kernel_size = weight.node()["shape"][1:3] + kwargs = { + "strides_i": stride, + "pads_i": padding + padding, + "dilations_i": dilation, + "group_i": groups, + "kernels_i": kernel_size, + "order_s": "NHWC", + "Y_scale_f": scale, + "Y_zero_point_i": zero_point, + } + output = g.op("_caffe2::Int8Conv", input, weight, bias, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i") +def conv2d_relu( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + scale, + zero_point, +): + kernel_size = weight.node()["shape"][1:3] + kwargs = { + "strides_i": stride, + "pads_i": padding + padding, + "dilations_i": dilation, + "group_i": groups, + "kernels_i": kernel_size, + "order_s": "NHWC", + "Y_scale_f": scale, + "Y_zero_point_i": zero_point, + } + output = g.op("_caffe2::Int8ConvRelu", input, weight, bias, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "v", "f", "i") +def add(g: jit_utils.GraphContext, input_a, input_b, scale, zero_point): + kwargs = { + "Y_scale_f": scale, + "Y_zero_point_i": zero_point, + } + output = g.op("_caffe2::Int8Add", input_a, input_b, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v") +def relu(g: jit_utils.GraphContext, input): + if input not in symbolic_helper._quantized_ops: + return opset9.relu(g, input) + kwargs = { + "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), + "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), + } + output = g.op("_caffe2::Int8Relu", input, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "f", "i", "t") +def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype): + kwargs = { + "Y_scale_f": scale, + "Y_zero_point_i": zero_point, + } + output = g.op("_caffe2::Int8Quantize", input, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v") +def dequantize(g: jit_utils.GraphContext, input): + return g.op("_caffe2::Int8Dequantize", input) + + +@symbolic_helper.parse_args("v", "t", "t", "t", "t", "t", "t", "t") +def _empty_affine_quantized( + g: jit_utils.GraphContext, + input, + shape, + scale, + zero_point, + dtype, + pin_memory, + memory_format, + layout, +): + return input + + +def upsample_nearest2d( + g: jit_utils.GraphContext, + input, + output_size, + align_corners=None, + scales_h=None, + scales_w=None, +): + if input not in symbolic_helper._quantized_ops: + return opset9.upsample_nearest2d(g, input, output_size, align_corners) # type: ignore[attr-defined] + + output_size = symbolic_helper._parse_arg(output_size, "is") + kwargs = { + "output_size_i": output_size, + "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), + "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), + } + input = nchw2nhwc(g, input) + output = 
g.op("_caffe2::Int8ResizeNearest", input, **kwargs) + output = nhwc2nchw(g, output) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "is", "is", "is", "is", "i") +def max_pool2d( + g: jit_utils.GraphContext, + input, + kernel_size, + stride, + padding, + dilation, + ceil_mode, +): + if input not in symbolic_helper._quantized_ops: + return opset9.max_pool2d( # type: ignore[attr-defined] + g, input, kernel_size, stride, padding, dilation, ceil_mode + ) + kwargs = { + "strides_i": stride, + "pads_i": padding + padding, + "kernel_i": kernel_size[0], + "order_s": "NHWC", + "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), + "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), + } + input = nchw2nhwc(g, input) + output = g.op("_caffe2::Int8MaxPool", input, **kwargs) + output = nhwc2nchw(g, output) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none") +def avg_pool2d( + g: jit_utils.GraphContext, + input, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override=None, +): + if input not in symbolic_helper._quantized_ops: + return opset9.avg_pool2d( # type: ignore[attr-defined] + g, + input, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + kwargs = { + "strides_i": stride, + "pads_i": padding + padding, + "kernel_i": kernel_size[0], + "order_s": "NHWC", + "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), + "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), + } + input = nchw2nhwc(g, input) + output = g.op("_caffe2::Int8AveragePool", input, **kwargs) + output = nhwc2nchw(g, output) + symbolic_helper._quantized_ops.add(output) + return output + + +def reshape(g: jit_utils.GraphContext, input, shape): + if input not in symbolic_helper._quantized_ops: + return opset9.reshape(g, input, shape) + + kwargs = { + "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), + "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), + } + output = g.op("_caffe2::Int8Reshape", input, shape, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "v", "v", "v", "i") +def slice(g: jit_utils.GraphContext, input, dim, start, end, step): + if input not in symbolic_helper._quantized_ops: + return opset9.slice(g, input, dim, start, end, step) + + if step != 1: + raise RuntimeError("ONNX quantized slice export only works for step 1.") + start = symbolic_helper._parse_arg(start, "i") + end = symbolic_helper._parse_arg(end, "i") + dim = symbolic_helper._parse_arg(dim, "i") + + kwargs = { + "start_idx_i": start, + "end_idx_i": end, + "dim_i": dim, + "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), + "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), + } + output = g.op("_caffe2::Int8Slice", input, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +def cat(g: jit_utils.GraphContext, tensor_list, dim, scale=None, zero_point=None): + tensors = symbolic_helper._unpack_list(tensor_list) + input = tensors[0] + if input not in symbolic_helper._quantized_ops: + return opset9.cat(g, tensor_list, dim) + + dim = symbolic_helper._parse_arg(dim, "i") + kwargs = { + "Y_scale_f": tensors[0].node()["Y_scale"], + "Y_zero_point_i": tensors[0].node()["Y_zero_point"], + } + output = g.op("_caffe2::Int8Concat", *tensors, 
axis_i=dim, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v") +def sigmoid(g: jit_utils.GraphContext, input): + if input not in symbolic_helper._quantized_ops: + return opset9.sigmoid(g, input) + # Caffe2 expects the output scale to be 1/2^8 + # and output zero_point to be 0 (quint8 type) + out_scale = 1.0 / 256 + zero_point = 0 + kwargs = { + "Y_scale_f": out_scale, + "Y_zero_point_i": zero_point, + } + output = g.op("_caffe2::Int8Sigmoid", input, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output diff --git a/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py b/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..799f2d6f81a568345fa876dde6cb15faad50053b --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py @@ -0,0 +1,2261 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import functools +import inspect +import math +import sys +import typing +import warnings +from typing import Any, Callable, Literal, NoReturn, Sequence, TypeVar as _TypeVar +from typing_extensions import Concatenate as _Concatenate, ParamSpec as _ParamSpec + +import torch +import torch._C._onnx as _C_onnx +from torch import _C + +# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics +from torch.onnx import _constants, _type_utils, errors, utils +from torch.onnx._globals import GLOBALS +from torch.onnx._internal import jit_utils + + +if typing.TYPE_CHECKING: + from torch.types import Number + +_T = _TypeVar("_T") +_U = _TypeVar("_U") +_P = _ParamSpec("_P") + +# --------------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------------- + +_ValueDescriptor = Literal[ + "v", + "i", + "is", + "f", + "fs", + "b", + "s", + "t", + "none", +] + + +def _parse_arg( + value, + desc: _ValueDescriptor, + arg_name: str | None = None, + node_name: str | None = None, +): + if desc == "none": + return value + if desc == "v" or not _is_value(value): + return value + + node = value.node() + if node.mustBeNone(): + return None + if node.kind() == "onnx::Constant": + node_val = _node_get(node, "value") + if desc == "i": + return int(node_val) + elif desc == "f": + return float(node_val) + elif desc == "b": + return bool(node_val) + elif desc == "s": + return str(node_val) + elif desc == "t": + return node_val + elif desc == "is": + return [int(v) for v in node_val] + elif desc == "fs": + return [float(v) for v in node_val] + else: + raise errors.SymbolicValueError( + f"ONNX symbolic does not understand the Constant node '{node}' " + f"specified with descriptor '{desc}'.", + value, + ) + elif node.kind() == "prim::ListConstruct": + if desc == "is": + for v in node.inputs(): + element_node = v.node() + if element_node.kind() != "onnx::Constant": + raise errors.SymbolicValueError( + f"Failed to export a node '{element_node}' " + f"(in list node {node}) " + f"because it is not constant. " + f"Please try to make things (e.g. 
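# --- Illustrative sketch (editor's aside, not part of the patch) ---
# The Caffe2 symbolics above only take effect once registered; the domain name
# and opset version below are assumptions for the sketch, not taken from the patch.
from torch.onnx import symbolic_caffe2

symbolic_caffe2.register_quantized_ops("caffe2", 9)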
kernel sizes) static if possible.", + value, + ) + return [int(_node_get(v.node(), "value")) for v in value.node().inputs()] + else: + raise errors.SymbolicValueError( + f"ONNX symbolic does not know how to unpack the ListConstruct node that " + f"is not a list of integers: '{node}'", + value, + ) + + if arg_name is None or node_name is None: + raise errors.SymbolicValueError( + f"Expected node type 'onnx::Constant', got '{node.kind()}'.", + value, + ) + + raise errors.SymbolicValueError( + "Expected node type 'onnx::Constant' " + f"for argument '{arg_name}' of node '{node_name}', got '{node.kind()}'.", + value, + ) + + +def _node_get(node: _C.Node, key: str): + """Gets attributes of a node which is polymorphic over return type.""" + assert isinstance(node, _C.Node) + sel = node.kindOf(key) + return getattr(node, sel)(key) + + +def _is_onnx_constant(value: _C.Value): + """Whether a Value is an ONNX constant.""" + return value.node().kind() == "onnx::Constant" + + +def _maybe_get_const( + value: _C.Value | torch.Tensor | Number | Sequence | None, + descriptor: _ValueDescriptor, +): + # NOTE: prim::Constant at this stage usually means something not compatible in ONNX, + # otherwise it'd be converted to onnx::Constant + # TODO(justinchuby): Replace insinstance with _is_value once we figure out mypy + if isinstance(value, _C.Value) and _is_onnx_constant(value): + return _parse_arg(value, descriptor) + return value + + +def _maybe_get_scalar(value): + value_t = _maybe_get_const(value, "t") + if isinstance(value_t, torch.Tensor) and value_t.shape == (): + return value_t + return value + + +def _get_const(value, desc, arg_name): + if not _is_constant(value): + raise errors.SymbolicValueError( + f"ONNX symbolic expected a constant value of the '{arg_name}' argument, " + f"got '{value}'", + value, + ) + return _parse_arg(value, desc) + + +def _unpack_list(list_value: _C.Value) -> list[_C.Value]: + list_node = list_value.node() + if list_node.kind() != "prim::ListConstruct": + raise errors.SymbolicValueError( + f"ONNX symbolic expected node type prim::ListConstruct, " + f"got '{list_node}'.", + list_value, + ) + return list(list_node.inputs()) + + +def _unpack_tuple(tuple_value: _C.Value) -> tuple[_C.Value, ...]: + tuple_node = tuple_value.node() + if not _is_tuple_construct(tuple_value): + raise errors.SymbolicValueError( + f"ONNX symbolic expected node type 'prim::TupleConstruct', " + f"got '{tuple_node.kind()}'.", + tuple_value, + ) + return tuple(tuple_node.inputs()) + + +def _unpack_quantized_tensor(tuple_value: _C.Value) -> tuple[_C.Value, ...]: + """Unpacks a quantized tensor into a tuple of tensor and scale/zero_point. + Args: + tuple_value: A tuple of tensor, scale, zero_point, and optionally axis. + Returns: + A tuple of tensor, scale, zero_point, and optionally axis. + """ + tuple_node = tuple_value.node() + # A quantized tensor is represented as tuple of the form (tensor, scale, zero_point, ) + if not _is_tuple_construct(tuple_value): + raise errors.SymbolicValueError( + f"ONNX symbolic expected the output of `{tuple_node}` to be a quantized " + f"tensor. Is this likely due to missing support for quantized " + f"`{tuple_node.kind()}`. Please create an issue on {_constants.PYTORCH_GITHUB_ISSUES_URL}", + tuple_value, + ) + unpacked = tuple(tuple_node.inputs()) + assert len(unpacked) == 3 or len(unpacked) == 4 + return unpacked + + +# Check if list_value is output from prim::ListConstruct +# This is usually called before _unpack_list to ensure the list can be unpacked. 
+def _is_packed_list(list_value: Any) -> bool: + return _is_value(list_value) and list_value.node().kind() == "prim::ListConstruct" + + +def parse_args( + *arg_descriptors: _ValueDescriptor, +) -> Callable[[Callable[_Concatenate[_U, _P], _T]], Callable[_Concatenate[_U, _P], _T]]: + """A decorator which converts args from torch._C.Value to built-in types. + + For example: + + ``` + @parse_args('v', 'i', 'fs') + foo(g, a, b, c): + assert isinstance(a, torch._C.Value) + assert isinstance(b, int) + assert isinstance(c, list) + assert isinstance(c[0], float) + ``` + + Args: + arg_descriptors: list of str, where each element is + a string that specifies the type to convert to. Valid descriptors: + "v": no conversion, keep torch._C.Value. + "i": int + "is": list of int + "f": float + "fs": list of float + "b": bool + "s": str + "t": torch.Tensor + "none": the variable is unused + """ + + def decorator( + fn: Callable[_Concatenate[_U, _P], _T], + ) -> Callable[_Concatenate[_U, _P], _T]: + fn._arg_descriptors = arg_descriptors # type: ignore[attr-defined] + + @functools.wraps(fn) + def wrapper(g: _U, *args: _P.args, **kwargs: _P.kwargs) -> _T: + # some args may be optional, so the length may be smaller + FILE_BUG_MSG = ( + "If you believe this is not due to custom symbolic implementation within your code or " + "an external library, please file an issue at " + "https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml to report this bug." + ) + assert len(arg_descriptors) >= len(args), ( + f"A mismatch between the number of arguments ({len(args)}) and " + f"their descriptors ({len(arg_descriptors)}) was found at symbolic function '{fn.__name__}'. " + f"{FILE_BUG_MSG}" + ) + + try: + sig = inspect.signature(fn) + arg_names = list(sig.parameters.keys())[1:] + fn_name = fn.__name__ + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + arg_names = [None] * len(args) # type: ignore[list-item] + fn_name = None + args = [ + _parse_arg(arg, arg_desc, arg_name, fn_name) # type: ignore[method-assign] + for arg, arg_desc, arg_name in zip(args, arg_descriptors, arg_names) + ] + # only support _outputs in kwargs + assert len(kwargs) <= 1, ( + f"Symbolic function {fn.__name__}'s '**kwargs' can contain a single " + f"key/value entry. " + f"{FILE_BUG_MSG}" + ) + + if len(kwargs) == 1: + assert "_outputs" in kwargs, ( + f"Symbolic function {fn.__name__}'s '**kwargs' can only contain " + f"'_outputs' key at '**kwargs'. " + f"{FILE_BUG_MSG}" + ) + return fn(g, *args, **kwargs) + + return wrapper + + return decorator + + +def quantized_args( + *arg_q_descriptors: bool, + scale: float | None = None, + zero_point: int | None = None, + quantize_output: bool = True, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + """A decorator which extends support for quantized version of the base operator. + + Quantization is detected by examining the arguments that are annotated by + `arg_q_descriptors`. + + If quantization is detected, the base operator symbolic function will be wrapped with + argument de-quantization and output quantization. + + Otherwise, only the base symbolic function will be invoked. 
+ + For example: + + ``` + @quantized_args(True, False) + def foo(g, x, y): + return x + y + ``` + + is equivalent to + + ``` + def q_foo(g, x, y): + if is_quantized_tensor(x): + x = dequantize(x) + out = foo(g, x, y) + return quantize(out) + else: + return foo(g, x, y) + ``` + + Args: + arg_q_descriptors: A sequence of bool, where each element represents if the + argument is QTensor for quantized version of this operator. It defaults + to False for unspecified (variable length) arguments. + scale: Quantized output scale. If None, derive from + the first quantized input scale. + zero_point: Quantized output zero point. If None, + derive from the first quantized input zero point. + quantize_output: If True, quantize the output of the base operator. Default is True + """ + + def decorator(fn): + @functools.wraps(fn) + def wrapper(g, *args, **kwargs): + nonlocal scale + nonlocal zero_point + if scale is not None: + _scale = g.op("Constant", value_t=torch.tensor(scale)) + else: + _scale = None + if zero_point is not None: + _zero_point = g.op("Constant", value_t=torch.tensor(zero_point)) + else: + _zero_point = None + + # Support variable length arguments by marking unspecified ones as non-quantized + arg_q_descriptors_extended = arg_q_descriptors + (False,) * ( + len(args) - len(arg_q_descriptors) + ) + descriptor_args = tuple(zip(arg_q_descriptors_extended, args)) + + def _is_arg_quantized(descriptor, arg): + return descriptor and _is_value(arg) and _is_tuple_construct(arg) + + # Run regular symbolic function if none of the argument is QTensor. + is_quantized = [] + for descriptor, arg in descriptor_args: + # ListConstruct + if _is_packed_list(arg): + for arg_input in arg.node().inputs(): + is_quantized.append(_is_arg_quantized(descriptor, arg_input)) + else: + is_quantized.append(_is_arg_quantized(descriptor, arg)) + + if not any(is_quantized): + return fn(g, *args, **kwargs) + + # Dequantize arguments that are quantized + non_quantized_args = [] + for descriptor, arg in descriptor_args: + if _is_arg_quantized(descriptor, arg): + # Quantized arg is a tuple of (value, scale, zero_point) + dequantized_arg, arg_scale, arg_zero_point, _ = dequantize_helper( + g, arg + ) + non_quantized_args.append(dequantized_arg) + # Set scale and zero_point to the first quantized input if not already set + if _scale is None: + _scale = arg_scale + if _zero_point is None: + _zero_point = arg_zero_point + # ListConstruct + elif _is_packed_list(arg): + for arg_input in arg.node().inputs(): + if _is_arg_quantized(descriptor, arg_input): + # Quantized arg is a tuple of (value, scale, zero_point) + ( + dequantized_arg, + arg_scale, + arg_zero_point, + _, + ) = dequantize_helper(g, arg_input) + # Set scale and zero_point to the first quantized input if not already set + if _scale is None: + _scale = arg_scale + if _zero_point is None: + _zero_point = arg_zero_point + arg_input.replaceAllUsesWith(dequantized_arg) + non_quantized_args.append(arg) + else: + # Non-quantized arg + non_quantized_args.append(arg) + # TODO(justinchuby): Only single output is supported for now. We may want to + # support multiple outputs in the future. 
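+            # All arguments flagged as quantized have been dequantized above, so the
+            # base symbolic function below operates on regular (fp32) values; its
+            # single output is re-quantized afterwards when quantize_output is True.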
+ output = fn(g, *non_quantized_args, **kwargs) + + assert _scale is not None, "Bug: Scale must be set for quantized operator" + assert ( + _zero_point is not None + ), "Bug: Zero point must be set for quantized operator" + + if quantize_output: + return quantize_helper(g, output, _scale, _zero_point) + return output + + return wrapper + + return decorator + + +def _scalar(x: Any) -> Number | None: + """Convert a scalar tensor into a Python value.""" + if isinstance(x, torch.Tensor) and x.shape == (): + return x.item() + return None + + +def _if_scalar_type_as(self, tensor): + """ + Convert self into the same type of tensor, as necessary. + We only support implicit casting for scalars, so we never + actually need to insert an ONNX cast operator here; just + fix up the scalar. + """ + if isinstance(self, _C.Value): + return self + + scalar_type = _type_utils.JitScalarType.from_value( + tensor, _type_utils.JitScalarType.UNDEFINED + ) + if scalar_type != _type_utils.JitScalarType.UNDEFINED: + ty = scalar_type.scalar_name().lower() + return getattr(self, ty)() + return self + + +def _is_none(x: Any) -> bool: + return x is None or (x.node().mustBeNone() if isinstance(x, _C.Value) else False) + + +def _is_value(x: Any) -> bool: + return isinstance(x, _C.Value) + + +def _is_constant(value: Any) -> bool: + return not _is_value(value) or value.node().kind() in { + "onnx::Constant", + "prim::Constant", + } + + +def _is_tensor(x: _C.Value) -> bool: + return x.type().isSubtypeOf(_C.TensorType.get()) + + +# Note: _C.JitType is not exposed to Python and cannot be checked in runtime. +def _as_list_type(jit_type: _C.JitType) -> _C.ListType | None: + if isinstance(jit_type, _C.ListType): + return jit_type + return None + + +def _is_list(x: _C.Value) -> bool: + return _as_list_type(x.type()) is not None + + +def _is_tensor_list(x: _C.Value) -> bool: + x_type = _as_list_type(x.type()) + if x_type is None: + return False + return isinstance(x_type.getElementType(), _C.TensorType) + + +def _is_scalar_list(x: _C.Value) -> bool: + """Checks if x is a scalar list, for example: List[float], List[int]. + + Besides checking the type is ListType, we also check if the data type is + a valid ONNX data type. + """ + x_type = _as_list_type(x.type()) + if x_type is None: + return False + scalar_type = _type_utils.JitScalarType.from_value(x) + return scalar_type.onnx_compatible() + + +def _is_tuple_construct(x: _C.Value) -> bool: + return x.node().kind() == "prim::TupleConstruct" + + +def is_complex_value(x: _C.Value) -> bool: + assert _is_value(x) + return _type_utils.JitScalarType.from_value( + x, _type_utils.JitScalarType.UNDEFINED + ) in { + _type_utils.JitScalarType.COMPLEX32, + _type_utils.JitScalarType.COMPLEX64, + _type_utils.JitScalarType.COMPLEX128, + } + + +def _get_tensor_rank(x: _C.Value) -> int | None: + if not _is_tensor(x) or x.type() is None: + return None + x_type = x.type() + x_type = typing.cast(_C.TensorType, x_type) + return x_type.dim() + + +def _get_tensor_sizes(x: _C.Value, allow_nonstatic: bool = True): + if not _is_tensor(x) or x.type() is None: + return None + x_type = x.type() + x_type = typing.cast(_C.TensorType, x_type) + if allow_nonstatic: + # Each individual symbol is returned as None. + # e.g. [1, "a", "b"] -> [1, None, None] + return x_type.varyingSizes() + # returns None, if exists any symbol in sizes. + # e.g. 
[1, "a", "b"] -> None + return x_type.sizes() + + +def _get_tensor_dim_size(x: _C.Value, dim: int) -> int | None: + sizes = _get_tensor_sizes(x) + return sizes[dim] if sizes else None + + +def _get_dim_for_cross(x: _C.Value, dim: int | None): + if dim == -1: + tensor_rank = _get_tensor_rank(x) + assert tensor_rank is not None + return dim + tensor_rank + # If dim is not given, it defaults to the first dimension found with the size 3 + if dim is None: + sizes = _get_tensor_sizes(x) + assert sizes is not None + for index, size in enumerate(sizes): + if size is not None and size == 3: + return index + return dim + + +def _unimplemented(op: str, msg: str, value: _C.Value | None = None) -> None: + # For BC reasons, the behavior for Caffe2 does not raise exception for unimplemented operators + if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX: + _onnx_unsupported(f"{op}, {msg}", value) + + +def _onnx_unsupported(op_name: str, value: _C.Value | None = None) -> NoReturn: + message = ( + f"Unsupported: ONNX export of operator {op_name}. " + f"Please feel free to request support or submit a pull request " + f"on PyTorch GitHub: {_constants.PYTORCH_GITHUB_ISSUES_URL}" + ) + if isinstance(value, _C.Value): + raise errors.SymbolicValueError( + message, + value, + ) + raise errors.OnnxExporterError(message) + + +def _onnx_opset_unsupported( + op_name: str, + current_opset: int, + supported_opset: int, + value: _C.Value | None = None, +) -> NoReturn: + message = ( + f"Unsupported: ONNX export of {op_name} in opset {current_opset}. " + f"Please try opset version {supported_opset}." + ) + if isinstance(value, _C.Value): + raise errors.SymbolicValueError( + message, + value, + ) + raise errors.OnnxExporterError(message) + + +def _onnx_opset_unsupported_detailed( + op_name: str, + current_opset: int, + supported_opset: int, + reason: str, + value: _C.Value | None = None, +) -> NoReturn: + message = ( + f"Unsupported: ONNX export of {op_name} in " + f"opset {current_opset}. {reason}. Please try opset version {supported_opset}." + ) + if isinstance(value, _C.Value): + raise errors.SymbolicValueError( + message, + value, + ) + raise errors.OnnxExporterError(message) + + +def _block_list_in_opset(name: str): + def symbolic_fn(*args, **kwargs): + raise errors.OnnxExporterError( + f"ONNX export failed on {name}, which is not implemented for opset " + f"{GLOBALS.export_onnx_opset_version}. " + "Try exporting with other opset versions." 
+ ) + + return symbolic_fn + + +def _try_get_scalar_type(*args) -> _type_utils.JitScalarType | None: + for arg in args: + scalar_type = _type_utils.JitScalarType.from_value( + arg, _type_utils.JitScalarType.UNDEFINED + ) + if scalar_type != _type_utils.JitScalarType.UNDEFINED: + return scalar_type + return None + + +def _type_promote_from_values(*args) -> _type_utils.JitScalarType: + undef = _type_utils.JitScalarType.UNDEFINED + jit_types = [_try_get_scalar_type(arg) for arg in args] + if len(jit_types) == 0: + return undef + if len(jit_types) == 1: + return jit_types[0] # type: ignore[return-value] + new_dtype = jit_types[0].dtype() # type: ignore[union-attr] + for t in jit_types: + new_dtype = torch.promote_types(new_dtype, t.dtype()) # type: ignore[union-attr] + return _type_utils.JitScalarType.from_dtype(new_dtype) + + +def _maybe_cast_to_type( + g: jit_utils.GraphContext, value, jit_type: _type_utils.JitScalarType +): + if ( + _type_utils.JitScalarType.from_value(value, _type_utils.JitScalarType.UNDEFINED) + != jit_type + ): + return g.op( + "Cast", + value, + to_i=jit_type.onnx_type(), + ) + return value + + +def _select_helper(g: jit_utils.GraphContext, self, dim, index, apply_reshape=True): + index_const = _maybe_get_scalar(index) + index_dim = _get_tensor_rank(index) + if not _is_value(index_const): + # Index is a constant scalar. Make it a size 1 constant tensor. + index = g.op("Constant", value_t=torch.LongTensor([index_const])) + elif index_dim is not None and apply_reshape: + if index_dim == 0: + # Index is a scalar. Reshape it to a size 1 tensor. + index = _reshape_helper( + g, index, g.op("Constant", value_t=torch.LongTensor([1])) + ) + + index_scalar_type = _type_utils.JitScalarType.from_value( + index, _type_utils.JitScalarType.UNDEFINED + ) + if index_scalar_type not in { + _type_utils.JitScalarType.INT64, + _type_utils.JitScalarType.INT, + }: + index = g.op("Cast", index, to_i=_C_onnx.TensorProtoDataType.INT64) + return g.op("Gather", self, index, axis_i=dim) + + +def _slice_helper( + g: jit_utils.GraphContext, + input, + axes, + starts, + ends, + steps=None, +): + if g.opset <= 9: + from torch.onnx.symbolic_opset9 import _slice as _slice9 + + return _slice9(g, input, axes, starts, ends) + else: + from torch.onnx.symbolic_opset10 import _slice as _slice10 + + return _slice10(g, input, axes, starts, ends, steps) + + +def _is_fp(value) -> bool: + return _type_utils.JitScalarType.from_value( + value, _type_utils.JitScalarType.UNDEFINED + ) in { + _type_utils.JitScalarType.FLOAT, + _type_utils.JitScalarType.DOUBLE, + _type_utils.JitScalarType.HALF, + _type_utils.JitScalarType.BFLOAT16, + } + + +def _is_bool(value) -> bool: + return _type_utils.JitScalarType.from_value( + value, _type_utils.JitScalarType.UNDEFINED + ) in {_type_utils.JitScalarType.BOOL} + + +def _generate_wrapped_number(g: jit_utils.GraphContext, scalar): + """Creates a wrapped number based on https://github.com/pytorch/pytorch/issues/9515. + + A Tensor is a considered a "wrapped number" if it is + auto-wrapped from a C++ or Python number type. Integer types are + wrapped as 0-dim int64 tensors and floating-point types are + wrapped as 0-dim double tensors. + + The input to this function is constant value. 
If the data type + is a floating point type, it is converted to a 0-dim double + tensor, else it is converted to a 0-dim tensor of its original type + """ + assert not isinstance(scalar, torch.Tensor) + if isinstance(scalar, float): + return g.op("Constant", value_t=torch.tensor(scalar, dtype=torch.double)) + return g.op("Constant", value_t=torch.tensor(scalar)) + + +def _sort_helper(g: jit_utils.GraphContext, input, dim, decending=True, out=None): + if out is not None: + _unimplemented("Sort", "Out parameter is not supported") + shape_ = g.op("Shape", input) + dim_size_ = g.op( + "Gather", + shape_, + g.op("Constant", value_t=torch.tensor([dim], dtype=torch.int64)), + ) + if g.opset <= 10: + if not decending: + _unimplemented("Sort", "Ascending is not supported") + return g.op("TopK", input, dim_size_, axis_i=dim, outputs=2) + else: + return g.op( + "TopK", input, dim_size_, axis_i=dim, largest_i=decending, outputs=2 + ) + + +def _topk_helper( + g: jit_utils.GraphContext, input, k, dim, largest=True, sorted=False, out=None +): + if out is not None: + _unimplemented("TopK", "Out parameter is not supported") + if not _is_value(k): + k = g.op("Constant", value_t=torch.tensor([k], dtype=torch.int64)) + else: + k = _reshape_helper(g, k, g.op("Constant", value_t=torch.tensor([1]))) + if _try_get_scalar_type(k) != _type_utils.JitScalarType.INT64: + k = g.op("Cast", k, to_i=_C_onnx.TensorProtoDataType.INT64) + if g.opset <= 10: + if not largest: + _unimplemented("TopK", "Ascending is not supported") + return g.op("TopK", input, k, axis_i=dim, outputs=2) + else: + return g.op( + "TopK", input, k, axis_i=dim, largest_i=largest, sorted_i=sorted, outputs=2 + ) + + +def _lt_helper(g: jit_utils.GraphContext, input, other): + if g.opset <= 8: + from torch.onnx.symbolic_opset8 import lt as _lt8 + + return _lt8(g, input, other) + else: + from torch.onnx.symbolic_opset9 import lt as _lt9 + + return _lt9(g, input, other) + + +def _interpolate_warning(interpolate_mode): + onnx_op = ( + "onnx:Resize" if GLOBALS.export_onnx_opset_version >= 10 else "onnx:Upsample" + ) + warnings.warn( + "You are trying to export the model with " + + onnx_op + + " for ONNX opset version " + "" + str(GLOBALS.export_onnx_opset_version) + ". " + "This operator might cause results to not match the expected results by PyTorch.\n" + "ONNX's Upsample/Resize operator did not match Pytorch's Interpolation until opset 11. " + "Attributes to determine how to transform the input were added in onnx:Resize in opset 11 " + "to support Pytorch's behavior (like coordinate_transformation_mode and nearest_mode).\n" + "We recommend using opset 11 and above for models using this operator." 
+ ) + + +def _unsqueeze_helper(g: jit_utils.GraphContext, input, axes_i): + if _is_constant(axes_i[0]): + if g.opset >= 13: + axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long)) + return g.op("Unsqueeze", input, axes) + return g.op("Unsqueeze", input, axes_i=axes_i) + # Tensor type + if g.opset < 13: + raise errors.SymbolicValueError( + "Opset version must be >= 13 for Unsqueeze with dynamic axes.", input + ) + return g.op("Unsqueeze", input, axes_i[0]) + + +def _squeeze_helper(g: jit_utils.GraphContext, input, axes_i): + if _is_constant(axes_i[0]): + if g.opset >= 13: + axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long)) + return g.op("Squeeze", input, axes) + return g.op("Squeeze", input, axes_i=axes_i) + # Tensor type + if g.opset < 13: + raise errors.SymbolicValueError( + "Opset version must be >= 13 for Squeeze with dynamic axes.", input + ) + axes_t = axes_i[0] + axes_rank = _get_tensor_rank(axes_t) + assert axes_rank is not None + if axes_rank > 1: + raise errors.SymbolicValueError( + "For Squeeze axses as input, the axes rank must be one in ONNX spec.", input + ) + elif axes_rank == 0: + # The axes is a scalar. Unsqueeze it to a rank 1 tensor. + axes_t = _unsqueeze_helper(g, axes_t, [0]) + return g.op("Squeeze", input, axes_t) + return g.op("Squeeze", input, axes_t) + + +def _reducesum_helper( + g: jit_utils.GraphContext, + input, + axes_i=None, + keepdims_i=1, + noop_with_empty_axes_i=0, +): + keepdims_i = _maybe_get_const(keepdims_i, "i") + if g.opset >= 13: + if axes_i: + if not _is_value(axes_i): + axes_i = g.op( + "Constant", value_t=torch.tensor(axes_i, dtype=torch.long) + ) + return g.op( + "ReduceSum", + input, + axes_i, + keepdims_i=keepdims_i, + noop_with_empty_axes_i=noop_with_empty_axes_i, + ) + return g.op( + "ReduceSum", + input, + keepdims_i=keepdims_i, + noop_with_empty_axes_i=noop_with_empty_axes_i, + ) + else: + return g.op("ReduceSum", input, axes_i=axes_i, keepdims_i=keepdims_i) + + +def _interpolate_size_to_scales(g: jit_utils.GraphContext, input, output_size, dim): + output_size = _maybe_get_const(output_size, "is") + if _is_value(output_size): + offset = 2 + offsets = g.op("Constant", value_t=torch.ones(offset, dtype=torch.float32)) + dividend = g.op("Cast", output_size, to_i=_C_onnx.TensorProtoDataType.FLOAT) + divisor = _slice_helper( + g, g.op("Shape", input), axes=[0], ends=[sys.maxsize], starts=[offset] + ) + divisor = g.op("Cast", divisor, to_i=_C_onnx.TensorProtoDataType.FLOAT) + scale_dims = g.op("Div", dividend, divisor) + scales = g.op("Concat", offsets, scale_dims, axis_i=0) + else: + scales_constant = [ + 1.0 + if i < 2 + else float(output_size[-(dim - i)]) + / float(input.type().sizes()[-(dim - i)]) + for i in range(0, dim) + ] + scales = g.op( + "Constant", value_t=torch.tensor(scales_constant, dtype=torch.float32) + ) + return scales + + +def _interpolate_get_scales_if_available(g: jit_utils.GraphContext, scales): + available_scales = _maybe_get_const(scales[0], "fs") != -1 and not _is_none( + scales[0] + ) + + if not available_scales: + return None + + offsets = g.op("Constant", value_t=torch.ones(2, dtype=torch.float32)) + scales_list = g.op( + "Constant", value_t=torch.tensor(_maybe_get_const(scales[0], "fs")) + ) + scales = g.op("Concat", offsets, scales_list, axis_i=0) + return scales + + +def _get_interpolate_attributes(g: jit_utils.GraphContext, mode, args): + if mode == "nearest": + align_corners = None + scales = args[0:] + else: + align_corners = args[0] + scales = args[1:] + scales = 
_interpolate_get_scales_if_available(g, scales) + return scales, align_corners + + +def _interpolate_get_scales(g: jit_utils.GraphContext, scale_factor, dim): + offsets = g.op("Constant", value_t=torch.ones(2, dtype=torch.float32)) + scale_factor_rank = _get_tensor_rank(scale_factor) + if isinstance(scale_factor.type(), _C.ListType) or ( + scale_factor_rank is not None and scale_factor_rank > 0 + ): + return g.op("Concat", offsets, scale_factor, axis_i=0) + else: + scale_factor = _unsqueeze_helper(g, scale_factor, [0]) + scale_factor = g.op( + "Cast", scale_factor, to_i=_C_onnx.TensorProtoDataType.FLOAT + ) + scales = [scale_factor for i in range(dim - 2)] + scale_factor = g.op("Concat", offsets, *scales, axis_i=0) + return scale_factor + + +def _interpolate_get_scales_and_mode( + g: jit_utils.GraphContext, input, size, scale_factor, mode, align_corners +): + mode = _maybe_get_const(mode, "s") + if "linear" in mode: + mode = "linear" + if "cubic" in mode: + mode = "cubic" + _interpolate_warning(mode) + + align_corners = _maybe_get_const(align_corners, "b") + if isinstance(align_corners, bool) and align_corners: + return _unimplemented("interpolate", "align_corners == True") + + if not input.type().dim(): + return _unimplemented("interpolate", "missing input shape") + dim = input.type().dim() + + if not _is_none(scale_factor): + scale_factor = _interpolate_get_scales(g, scale_factor, dim) + elif not _is_none(size): + if not _is_packed_list(size): + is_scalar = _maybe_get_const(size, "t").dim() == 0 + if is_scalar: + size = _unsqueeze_helper(g, size, [0]) + size = [size for i in range(dim - 2)] + size = g.op("Concat", *size, axis_i=0) + scale_factor = _interpolate_size_to_scales(g, input, size, dim) + else: + return _unimplemented( + "interpolate", "Both size and scales are None in __interpolate" + ) + return scale_factor, mode + + +def _argmin_argmax_helper( + g: jit_utils.GraphContext, + input: torch._C.Value, + dim: torch._C.Value, + keepdim: bool, + op_name: str, +): + def op_wrapper(input, axis_i, keepdims_i): + if g.opset >= 12: + return g.op( + op_name, + input, + axis_i=axis_i, + keepdims_i=keepdims_i, + select_last_index_i=False, + ) + return g.op(op_name, input, axis_i=axis_i, keepdims_i=keepdims_i) + + if _is_none(dim): + flattened = _reshape_helper( + g, input, g.op("Constant", value_t=torch.tensor([-1])) + ) + output = op_wrapper(flattened, axis_i=0, keepdims_i=False) + if keepdim: + input_shape = g.op("Shape", input) + input_shape_shape = g.op("Shape", input_shape) + new_shape = g.op( + "ConstantOfShape", + input_shape_shape, + value_t=torch.tensor([1], dtype=torch.int64), + ) + output = g.op("Reshape", output, new_shape) + return output + + dim = _parse_arg(dim, "i") + return op_wrapper(input, axis_i=dim, keepdims_i=keepdim) + + +def _interpolate_helper(name, dim, interpolate_mode): + @quantized_args(True, False, False) + def symbolic_fn(g, input, output_size, *args): + scales, align_corners = _get_interpolate_attributes(g, interpolate_mode, args) + align_corners = _maybe_get_scalar(align_corners) + coordinate_transformation_mode = ( + "asymmetric" + if interpolate_mode == "nearest" + else "align_corners" + if align_corners + else "half_pixel" + ) + + if scales is None: + input_size = g.op("Shape", input) + input_size_beg = _slice_helper( + g, input_size, axes=[0], ends=[2], starts=[0] + ) + output_size = g.op( + "Cast", output_size, to_i=_C_onnx.TensorProtoDataType.INT64 + ) + output_size = g.op("Concat", input_size_beg, output_size, axis_i=0) + + if g.opset >= 13: + 
empty_roi = _optional_input_placeholder_tensor(g) + empty_scales = _optional_input_placeholder_tensor(g) + else: + empty_roi = g.op( + "Constant", value_t=torch.tensor([], dtype=torch.float32) + ) + empty_scales = g.op( + "Constant", value_t=torch.tensor([], dtype=torch.float32) + ) + + return g.op( + "Resize", + input, + empty_roi, + empty_scales, + output_size, + coordinate_transformation_mode_s=coordinate_transformation_mode, + cubic_coeff_a_f=-0.75, # only valid when mode="cubic" + mode_s=interpolate_mode, # nearest, linear, or cubic + nearest_mode_s="floor", + ) # only valid when mode="nearest" + else: + if g.opset >= 13: + empty_roi = _optional_input_placeholder_tensor(g) + else: + empty_roi = g.op( + "Constant", value_t=torch.tensor([], dtype=torch.float32) + ) + + return g.op( + "Resize", + input, + empty_roi, + scales, + coordinate_transformation_mode_s=coordinate_transformation_mode, + cubic_coeff_a_f=-0.75, # only valid when mode="cubic" + mode_s=interpolate_mode, # nearest, linear, or cubic + nearest_mode_s="floor", + ) # only valid when mode="nearest" + + return symbolic_fn + + +def __interpolate_helper( + g: jit_utils.GraphContext, + input, + size, + scale_factor, + mode, + align_corners, + recompute_scale_factor, +): + mode = _maybe_get_const(mode, "s") + if "linear" in mode: + mode = "linear" + if "cubic" in mode: + mode = "cubic" + align_corners = _maybe_get_const(align_corners, "b") + align_corners = False if not isinstance(align_corners, bool) else align_corners + coordinate_transformation_mode = ( + "asymmetric" + if mode == "nearest" + else "align_corners" + if align_corners + else "half_pixel" + ) + + if not _is_none(size): + input_size = g.op("Shape", input) + input_size = _slice_helper(g, input_size, axes=[0], ends=[2], starts=[0]) + # in some cases size is not a packed list but size is a scalar + # We need to also verify that (_maybe_get_const(size, "t").dim() == 0) + # but this information is not always available. Try to get the dim, + # and if not assume that it is not a scalar. + try: + is_scalar = not _is_packed_list(size) and ( + _maybe_get_const(size, "t").dim() == 0 + ) + except AttributeError: + is_scalar = not _is_packed_list(size) + if not is_scalar: + warnings.warn( + "Cannot verify if the output_size is a scalar " + "while exporting interpolate. Assuming that it is not a scalar." 
+ ) + + if is_scalar: + rank = _get_tensor_rank(input) + if rank is None: + return _unimplemented( + "interpolate (with a scalar output_size)", + "missing input shape (try giving an array of output_size values)", + ) + size = _unsqueeze_helper(g, size, [0]) + size = [size for i in range(rank - 2)] + size = g.op("Concat", *size, axis_i=0) + size = g.op("Cast", size, to_i=_C_onnx.TensorProtoDataType.INT64) + size = g.op("Concat", input_size, size, axis_i=0) + + if g.opset >= 13: + empty_roi = _optional_input_placeholder_tensor(g) + empty_scales = _optional_input_placeholder_tensor(g) + else: + empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) + empty_scales = g.op( + "Constant", value_t=torch.tensor([], dtype=torch.float32) + ) + + return g.op( + "Resize", + input, + empty_roi, + empty_scales, + size, + coordinate_transformation_mode_s=coordinate_transformation_mode, + cubic_coeff_a_f=-0.75, # only valid when mode="cubic" + mode_s=mode, # nearest, linear, or cubic + nearest_mode_s="floor", + ) + else: # if not _is_none(scales) + rank = _get_tensor_rank(input) + if rank is None: + return _unimplemented("interpolate (with scales)", "missing input shape") + + if g.opset >= 13: + empty_roi = _optional_input_placeholder_tensor(g) + else: + empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) + + scales = _interpolate_get_scales(g, scale_factor, rank) + return g.op( + "Resize", + input, + empty_roi, + scales, + coordinate_transformation_mode_s=coordinate_transformation_mode, + cubic_coeff_a_f=-0.75, # only valid when mode="cubic" + mode_s=mode, # nearest, linear, or cubic + nearest_mode_s="floor", + ) # only valid when mode="nearest" + + +def _unbind_helper(g: jit_utils.GraphContext, self, dim, _outputs): + if g.opset < 11: + from torch.onnx.symbolic_opset9 import unbind + elif g.opset <= 12: + from torch.onnx.symbolic_opset11 import unbind # type: ignore[no-redef] + else: + from torch.onnx.symbolic_opset13 import unbind # type: ignore[no-redef] + return unbind(g, self, dim, _outputs) + + +def _scatter_helper(g: jit_utils.GraphContext, self, dim, index, src): + if g.opset <= 10: + from torch.onnx.symbolic_opset9 import scatter + else: + # for mypy, scatter was imported two lines above + from torch.onnx.symbolic_opset11 import scatter # type: ignore[no-redef] + return scatter(g, self, dim, index, src) + + +def _repeat_interleave_split_helper(g: jit_utils.GraphContext, self, reps, dim): + if g.opset <= 12: + split_out = g.op("Split", self, split_i=[1] * reps, axis_i=dim, outputs=reps) + else: + from torch.onnx.symbolic_opset13 import split + + repeats = g.op("Constant", value_t=torch.tensor([1] * reps)) + split_out = split(g, self, repeats, dim, _outputs=reps) + return split_out if reps > 1 else [split_out] + + +def _repeat_interleave_single_value_repeat_helper( + g: jit_utils.GraphContext, self, repeats, dim +): + from torch.onnx.symbolic_opset9 import flatten, unsqueeze + + if not _is_tensor(repeats): + repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) + + const_repeats: bool = _is_constant(repeats) + reps = _maybe_get_const(repeats, "t") + + # Convert 'repeats' to 1-d if it is 0-d. + if _get_tensor_rank(repeats) == 0: + repeats = g.op("Reshape", repeats, g.op("Constant", value_t=torch.tensor([1]))) + + # Create a new dim of size 1, then expand it to be 'repeats' long, and finally collapse it. 
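+    # Illustrative example (assumed shapes, not taken from the original comments):
+    # for `self` of shape (2, 3), dim=1 and repeats=2, the steps below unsqueeze to
+    # (2, 3, 1), tile to (2, 3, 2), and flatten dims 1..2 back into (2, 6), matching
+    # torch.repeat_interleave(self, 2, dim=1).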
+ unsqueezed = unsqueeze(g, self, dim + 1) + + # repeats_per_dim is 1 for all dims except for the new unsqueezed dim, where it has value 'repeats'. + if const_repeats: + # 'Repeats' is a constant, 'repeats_per_dim' can be a constant. + onehot = torch.ones(_get_tensor_rank(unsqueezed), dtype=torch.int64) # type: ignore[arg-type] + onehot[dim + 1] = reps + repeats_per_dim = g.op("Constant", value_t=onehot) + else: + # 'Repeats' is a variable, 'repeats_per_dim' cannot be a constant. + onehot = g.op( + "OneHot", + unsqueeze(g, dim + 1, 0), # indices, must be >= 1-dimensional + g.op( + "Constant", value_t=torch.tensor(_get_tensor_rank(unsqueezed)) + ), # depth + g.op( + "Concat", g.op("Constant", value_t=torch.tensor([1])), repeats, axis_i=0 + ), # on/off values + ) + repeats_per_dim = flatten(g, onehot, 0, 1) + + tiled = g.op("Tile", unsqueezed, repeats_per_dim) + return flatten(g, tiled, dim, dim + 1) + + +def _arange_cast_helper( + g: jit_utils.GraphContext, end, start=None, step=None, dtype=None +) -> tuple[ + _type_utils.JitScalarType, + _C.Value | None, + _C.Value | None, + _C.Value | None, +]: + def _is_all_integral(scalars): + for scalar in scalars: + scalar_type = _type_utils.JitScalarType.from_value( + scalar, _type_utils.JitScalarType.UNDEFINED + ) + if ( + scalar_type != _type_utils.JitScalarType.INT64 + and scalar_type != _type_utils.JitScalarType.UNDEFINED + ): + return False + return True + + # This logic is based on torch.arange docs. If "dtype" is provided, + # infer input types from dtype. If not, then check if any of start, stop, + # or step are floating point, and infer the type from get_default. + # Otherwise, the dtype is inferred to be torch.int64. + if dtype is None or (_is_value(dtype) and _is_none(dtype)): + if _is_all_integral([start, end, step]): + scalar_type = _type_utils.JitScalarType.INT64 + else: + scalar_type = _type_utils.JitScalarType.from_dtype( + torch.get_default_dtype() + ) + else: + assert isinstance(dtype, int) + # TODO(justinchuby): Check if dtype is indeed a int. + scalar_type = _type_utils.JitScalarType(dtype) + + start = g.op("Cast", start, to_i=scalar_type.onnx_type()) if start else None + end = g.op("Cast", end, to_i=scalar_type.onnx_type()) if end else None + step = g.op("Cast", step, to_i=scalar_type.onnx_type()) if step else None + return scalar_type, end, start, step + + +def _arange_helper(g: jit_utils.GraphContext, *args): + if g.opset <= 10: + from torch.onnx.symbolic_opset9 import arange + else: + from torch.onnx.symbolic_opset11 import arange # type: ignore[no-redef] + return arange(g, *args) + + +def _size_helper(g: jit_utils.GraphContext, self, dim): + full_shape = g.op("Shape", self) + from torch.onnx.symbolic_opset9 import select + + return select(g, full_shape, g.op("Constant", value_t=torch.tensor([0])), dim) + + +def _index_fill_reshape_helper(g: jit_utils.GraphContext, self, dim, index): + # 1. reshape index => [1, ..., 1, dim, 1, ..., 1] + # 2. expand index => [..., dim, ...], same shape as self except for dim. + # 3. expand value as well. + # 4. apply onnx::scatter. 
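+    # Worked example (illustrative shapes only): for `self` of shape (4, 5, 6),
+    # dim=1 and `index` of shape (3,), step 1 unsqueezes index to (1, 3, 1) and
+    # step 2 expands it to (4, 3, 6); the caller then expands the fill value to the
+    # same shape before the final scatter (steps 3 and 4).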
+ + from torch.onnx.symbolic_opset9 import expand + + if g.opset <= 10: + from torch.onnx.symbolic_opset9 import scatter + else: + # for mypy, scatter was imported two lines above + from torch.onnx.symbolic_opset11 import scatter # type: ignore[no-redef] + + if self.type().dim() is None: + return _unimplemented("index_fill", "input rank not accessible") + self_dim = self.type().dim() + dim_value = _parse_arg(dim, "i") + if dim_value < 0: + dim_value += self_dim + unsqueezed_index = _unsqueeze_helper( + g, index, [i for i in range(self_dim) if i != dim_value] + ) + expanded_index_shape = scatter( + g, g.op("Shape", self), 0, _unsqueeze_helper(g, dim, [0]), g.op("Shape", index) + ) + expanded_index = expand(g, unsqueezed_index, expanded_index_shape, None) + return expanded_index_shape, expanded_index + + +# By default, when any value in the 'shape' input is equal to zero +# the corresponding dimension value is copied from the input tensor dynamically. +# allowzero=1 indicates that if any value in the 'shape' input is set to zero, +# the zero value is honored, similar to NumPy. +# allowzero=1 is only supported for opset version >= 14. +def _reshape_helper(g: jit_utils.GraphContext, input, shape, allowzero=0): + shape = _maybe_get_const(shape, "is") + if not _is_value(shape): + shape = g.op("Constant", value_t=torch.LongTensor(shape)) + if g.opset <= 13: + if allowzero == 1: + _onnx_opset_unsupported( + "Reshape with allowzero=1", GLOBALS.export_onnx_opset_version, 14, input + ) + return g.op("Reshape", input, shape) + else: + return g.op("Reshape", input, shape, allowzero_i=allowzero) + + +def _batchnorm_helper( + g: jit_utils.GraphContext, input, weight, bias, running_mean, running_var +): + from torch.onnx.symbolic_opset9 import _var_mean + + batch_size = _get_tensor_dim_size(input, 0) + channel_size = _get_tensor_dim_size(input, 1) + + if weight is None or _is_none(weight): + if channel_size is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of batch_norm for unknown channel size.", + input, + ) + weight_value = torch.tensor( + [1.0] * channel_size, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ) + weight = g.op("Constant", value_t=weight_value) + if bias is None or _is_none(bias): + if channel_size is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of batch_norm for unknown channel size.", + input, + ) + bias_value = torch.tensor( + [0.0] * channel_size, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ) + bias = g.op("Constant", value_t=bias_value) + # If track_running_stats is set to False batch statistics are instead used during evaluation time + if ( + running_mean is None + or _is_none(running_mean) + or running_var is None + or _is_none(running_var) + ): + assert batch_size is not None and channel_size is not None + reshape_in = _reshape_helper( + g, + input, + g.op( + "Constant", + value_t=torch.tensor([batch_size, channel_size, -1], dtype=torch.int64), + ), + ) + trans_in = g.op("Transpose", reshape_in, perm_i=[0, 2, 1]) + running_var, running_mean = _var_mean( + g, + trans_in, + g.op("Constant", value_t=torch.tensor([0, 1], dtype=torch.int64)), + False, + False, + ) + return weight, bias, running_mean, running_var + + +def _avgpool_helper( + tuple_fn: Callable[[Any], Sequence[int]], + padding: int | Sequence[int], + kernel_size, + stride, + divisor_override, + name, +) -> tuple[int, ...]: + if divisor_override and divisor_override.node().kind() != "prim::Constant": + _unimplemented(name, "divisor_override") 
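+    # `tuple_fn` is supplied by the calling pooling symbolic (for example a pair or
+    # triple expander that turns a scalar padding such as 1 into (1, 1) for 2-D
+    # pooling); this note is descriptive only, the exact callable depends on the caller.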
+ return tuple(tuple_fn(padding)) + + +def check_training_mode(op_train_mode: int, op_name: str) -> None: + """Warns the user if the model's training mode and the export mode do not agree.""" + if GLOBALS.training_mode == _C_onnx.TrainingMode.PRESERVE: + return + + if op_train_mode: + op_mode_enum = _C_onnx.TrainingMode.TRAINING + else: + op_mode_enum = _C_onnx.TrainingMode.EVAL + if op_mode_enum == GLOBALS.training_mode: + # The modes agree. Do nothing + return + + op_mode_text = f"train={bool(op_train_mode)}" + # Setting the model mode could result in op_mode != GLOBALS.training_mode + # if the model is a FuncModule. In this case we warn the user of + # the state and export depending on op_mode + # This is to support use-cases of fixing certain layer weights + # in training. + warnings.warn( + f"ONNX export mode is set to {GLOBALS.training_mode}, but operator '{op_name}' " + f"is set to {op_mode_text}. Exporting with {op_mode_text}." + ) + + +def _flatten_helper(g: jit_utils.GraphContext, input, start_dim, end_dim, dim): + input_size = g.op("Shape", input) + slice1 = _slice_helper(g, input_size, axes=[0], starts=[0], ends=[start_dim]) + slices = [slice1, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long))] + if end_dim < dim - 1: + slice3 = _slice_helper( + g, input_size, axes=[0], starts=[end_dim + 1], ends=[dim] + ) + slices = [ + slice1, + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), + slice3, + ] + + final_shape = g.op("Concat", *slices, axis_i=0) + from torch.onnx.symbolic_opset9 import _reshape_from_tensor + + return _reshape_from_tensor(g, input, final_shape) + + +def _is_split_static(split_size_or_sizes, _outputs): + if _outputs is None: + return False + if ( + _is_value(split_size_or_sizes) + and split_size_or_sizes.node().kind() != "onnx::Constant" + ): + return False + return True + + +def _optional_input_placeholder_tensor(g): + n = g.op("prim::Constant") + n.setType(_C.OptionalType.ofTensor()) + return n + + +def _handle_reduce_dim_none(g: jit_utils.GraphContext, self, op_name): + rank = _get_tensor_rank(self) + if rank is not None and any( + _get_tensor_dim_size(self, i) == 0 for i in range(rank) + ): + # If input tensor is empty, according to ONNX ReduceSum definition, + # set keepdims=1 so that the resulted tensor has the same rank as the input. + return g.op(op_name, self, keepdims_i=1) + return g.op(op_name, self, keepdims_i=0) + + +def dequantize_helper( + g: jit_utils.GraphContext, + qtensor: _C.Value, + qdtype: _C_onnx.TensorProtoDataType | None = None, +) -> tuple[_C.Value, _C.Value, _C.Value, _C.Value | None]: + """Appends to graph `g` ONNX nodes that dequantizes `qtensor` into `tensor`. + + Args: + g: Graph, the ONNX IR graph that is under construction. + qtensor: torch._C.Value, either a tuple of (quantized_tensor, scale, zero_point) + for per tensor quantization, or + (quantized_tensor, scale, zero_point, axis) for per channel quantization, + representing the quantized tensor. + qdtype: torch.onnx.TensorProtoDataType default None, if not None, represents the + data type of quantized tensor. It must be either + torch.onnx.TensorProtoDataType.UINT8 or torch.onnx.TensorProtoDataType.INT8. 
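+
+    Returns:
+        A tuple of (dequantized tensor, scale, zero_point, axis), where axis is
+        None for per tensor quantization.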
+ """ + unpacked_qtensors = _unpack_quantized_tensor(qtensor) + tensor, scale, zero_point = unpacked_qtensors[:3] + axis = unpacked_qtensors[3] if len(unpacked_qtensors) >= 4 else None + axis_i = _get_const(axis, "i", "axis") + input_qdtype = _type_utils.JitScalarType.from_value(tensor) + if qdtype is None: + if input_qdtype is not None: + qdtype = input_qdtype.onnx_type() + else: + qdtype = _C_onnx.TensorProtoDataType.UINT8 + value = g.op("Cast", tensor, to_i=qdtype) + scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) + zero_point = g.op("Cast", zero_point, to_i=qdtype) + + if axis_i is not None and GLOBALS.export_onnx_opset_version < 13: + _onnx_opset_unsupported_detailed( + "DequantizeLinear", + GLOBALS.export_onnx_opset_version, + 13, + "Attribute axis is not supported.", + qtensor, + ) + + return ( + g.op("DequantizeLinear", value, scale, zero_point, axis_i=axis_i), + scale, + zero_point, + axis, + ) + + +def quantize_helper( + g: jit_utils.GraphContext, + tensor: _C.Value, + scale: _C.Value, + zero_point: _C.Value, + axis: _C.Value | None = None, +) -> _C.Value: + """Appends to graph `g` ONNX nodes that quantizes `tensor` based on `scale`, `zero_point` and `axis`. + + Args: + g: Graph, the ONNX IR graph that is under construction. + tensor: torch._C.Value, representing the tensor to be quantized. + scale: torch._C.Value, quantized scale. + zero_point: torch._C.Value, quantized zero point. + axis: Optional[torch._C.Value] default None, if None, represents per tensor quantization. + Otherwise, represents per channel quantization, along given axis. + + Returns: + A TupleConstruct storing information of the quantized tensor. + """ + if ( + axis is not None + and not _is_none(axis) + and GLOBALS.export_onnx_opset_version < 13 + ): + _onnx_opset_unsupported_detailed( + "QuantizeLinear", + GLOBALS.export_onnx_opset_version, + 13, + "Attribute axis is not supported.", + tensor, + ) + + assert scale is not None + if ( + _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED) + != _type_utils.JitScalarType.FLOAT + ): + scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) + + assert zero_point is not None + if _type_utils.JitScalarType.from_value( + zero_point, _type_utils.JitScalarType.UNDEFINED + ) not in { + _type_utils.JitScalarType.UINT8, + _type_utils.JitScalarType.INT8, + }: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) + output = g.op( + "QuantizeLinear", + tensor, + scale, + zero_point, + axis_i=_get_const(axis, "i", "axis"), + ) + args = [output, scale, zero_point] + if axis is not None and not _is_none(axis): + args.append(axis) + return g.op("prim::TupleConstruct", *args) + + +def requantize_bias_helper( + g: jit_utils.GraphContext, bias, input_scale, weight_scale, axis=None +): + """In PyTorch, bias is float and is quantized to int32 implicitly inside the quantized ATen op kernel. + In ONNX we need to make the quantization explicit because operators expect all of their inputs to be quantized. + Since int32 is not a supported output type by ONNX operator `QuantizeLinear`, quantization is exported using + regular operators. 
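+
+    Returns:
+        A prim::TupleConstruct of (q_bias, bias_scale, bias_zero_point), with the
+        quantization axis appended when one is given.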
+    """
+    bias_scale = g.op("Mul", weight_scale, input_scale)
+    bias_scale_shape = g.op("Shape", bias_scale)
+    bias_zero_point = g.op(
+        "ConstantOfShape", bias_scale_shape, value_t=torch.tensor([0], dtype=torch.int)
+    )
+    q_bias = g.op(
+        "Cast", g.op("Div", bias, bias_scale), to_i=_C_onnx.TensorProtoDataType.INT32
+    )
+    axis_args = []
+    if axis is not None and not _is_none(axis):
+        axis_args.append(axis)
+    return g.op("prim::TupleConstruct", q_bias, bias_scale, bias_zero_point, *axis_args)
+
+
+def args_have_same_dtype(args):
+    assert args
+    base_dtype = _type_utils.JitScalarType.from_value(args[0])
+    has_same_dtype = all(
+        _type_utils.JitScalarType.from_value(elem) == base_dtype for elem in args
+    )
+    return has_same_dtype
+
+
+def _op_with_optional_float_cast(g: jit_utils.GraphContext, op_name, *args, **kwargs):
+    """Some PyTorch operators (e.g., Clip/Min/ReLU/Pad) are a superset of ONNX in terms of data types.
+    This function maximizes the exportability of PyTorch-ONNX by allowing ONNX-unsupported PyTorch
+    operator data types. For example, `Cast(Clip(Cast(INPUT)))` can be used to mimic
+    `Clip(INPUT)` (opset version < 12).
+
+    Args:
+        g (torch._C.Graph): graph to write the ONNX representation into.
+        op_name (str): operator name in ONNX.
+        *args (tuple): operands to the operator.
+        **kwargs (dict): attributes to the operator along with "opset_before" (optional, None by default)
+            indicating the smallest opset version to trigger such casting behavior and "target_float_t"
+            (optional, torch.onnx.JitScalarType.FLOAT by default) indicating the data type of internal operator.
+
+    Returns:
+        torch._C.Value | Tuple[torch._C.Value, ...]: output(s) of the operator.
+    """
+    opset_before = kwargs.pop("opset_before", None)
+    target_float_t = kwargs.pop("target_float_t", _type_utils.JitScalarType.FLOAT)
+
+    inputs = list(args)
+    dtype_0 = _type_utils.JitScalarType.from_value(inputs[0])
+
+    require_cast = not _is_fp(inputs[0]) and (
+        opset_before is None or GLOBALS.export_onnx_opset_version < opset_before
+    )
+
+    if require_cast:
+        for input in inputs:
+            if input.isCompleteTensor():
+                input_scalar_type = _type_utils.JitScalarType.from_value(input)
+                if input_scalar_type != dtype_0:
+                    raise errors.SymbolicValueError(
+                        f"Inputs of {op_name} must have the same dtype. "
+ f"Got {dtype_0.scalar_name()} and {input_scalar_type.scalar_name()}", + input, + ) + for i, input in enumerate(inputs): + if input.isCompleteTensor() and not _is_fp(input): + inputs[i] = g.op( + "Cast", + input, + to_i=target_float_t.onnx_type(), + ) + + self = g.op(op_name, *inputs, **kwargs) + + if require_cast: + self = g.op("Cast", self, to_i=dtype_0.onnx_type()) + + return self + + +def _maybe_cast_reduce_op_input(g: jit_utils.GraphContext, self): + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.UNDEFINED + ) + if scalar_type != _type_utils.JitScalarType.UNDEFINED: + # This check only covers traced modules where dtype is present + # pytorch reduce-ops cast all other integral types to int64 + if not _is_fp(self) and scalar_type != _type_utils.JitScalarType.INT64: + self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.INT64) + return self + + +def _apply_params(*args, **kwargs): + """Returns a decorator that calls the decorated (higher-order) function with the given parameters.""" + + def _apply(fn): + return fn(*args, **kwargs) + + return _apply + + +def _reduce_op_symbolic_helper(onnx_op_name, allow_multi_dim_support=True): + def symbolic(g, self, dim=None, keepdim=None): + self = _maybe_cast_reduce_op_input(g, self) + if dim is None or dim == (): + # Dim can be 0, which will cause (not dim) == True. So we don't want to do + # (not dim) + # all-reduce path + return _handle_reduce_dim_none(g, self, onnx_op_name) + else: + # dim-reduce path + keepdim = _get_const(keepdim, "i", "keepdim") + if g.opset < 18: + desc = "is" if allow_multi_dim_support else "i" + dim = _get_const(dim, desc, "dim") + dim_list = dim if allow_multi_dim_support else [dim] + return g.op(onnx_op_name, self, axes_i=dim_list, keepdims_i=keepdim) + else: + if _is_value(dim): + axes = dim + else: + if allow_multi_dim_support: + axes = g.op( + "Constant", value_t=torch.tensor(dim, dtype=torch.long) + ) + else: + axes = g.op( + "Constant", value_t=torch.tensor([dim], dtype=torch.long) + ) + return g.op(onnx_op_name, self, axes, keepdims_i=keepdim) + + return symbolic + + +def _overload_by_arg_count(fn): + @functools.wraps(fn) + def wrapper(g, *args): + overloads = fn(g, *args) + for overload in overloads: + arg_descriptors = overload._arg_descriptors + if len(arg_descriptors) == len(args): + return overload(g, *args) + return _unimplemented(f"aten::{fn.__name__}", f"with {len(args)} arguments") + + return wrapper + + +def _reduce_with_dtype_helper( + onnx_op: str, name: str, allow_multi_dim_support: bool = True +): + symbolic = _reduce_op_symbolic_helper( + onnx_op, allow_multi_dim_support=allow_multi_dim_support + ) + + @_overload_by_arg_count + def reduce(g, *args, **kwargs): + @quantized_args(True) + @parse_args("v", "none") + def reduce_nodim(g, self, dtype): + dtype_onnx = None + if dtype.node().kind() == "onnx::Constant": + dtype = _get_const(dtype, "i", "dtype") + dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() + self = g.op("Cast", self, to_i=dtype_onnx) + elif dtype.node().kind() != "prim::Constant": + return _unimplemented(name, "dtype", dtype) + result = symbolic(g, self) + if dtype_onnx is not None: + result_dtype_onnx = _type_utils.JitScalarType.from_value( + result + ).onnx_type() + if result_dtype_onnx != dtype_onnx: + result = g.op("Cast", result, to_i=dtype_onnx) + return result + + dim_desc = "is" if allow_multi_dim_support else "i" + + @quantized_args(True) + @parse_args("v", dim_desc, "i", "none") # type: ignore[arg-type] + def reduce_dim(g, 
self, dim, keepdim, dtype): + dtype_onnx = None + if dtype.node().kind() == "onnx::Constant": + dtype = _get_const(dtype, "i", "dtype") + dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() + self = g.op("Cast", self, to_i=dtype_onnx) + elif dtype.node().kind() != "prim::Constant": + return _unimplemented(name, "dtype", dtype) + result = symbolic(g, self, dim, keepdim) + if dtype_onnx is not None: + result_dtype_onnx = _type_utils.JitScalarType.from_value( + result + ).onnx_type() + if result_dtype_onnx != dtype_onnx: + result = g.op("Cast", result, to_i=dtype_onnx) + return result + + return reduce_nodim, reduce_dim + + return reduce + + +def _max_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + # torch.max(input) + if dim_or_y is None and keepdim is None: + return g.op("ReduceMax", self, keepdims_i=0) + # torch.max(input, other) + if keepdim is None: + return _op_with_optional_float_cast(g, "Max", self, dim_or_y, opset_before=12) + # torch.max(input, dim, keepdim) + else: + keepdim = _get_const(keepdim, "i", "keepdim") + dim = _get_const(dim_or_y, "i", "dim") + if g.opset < 18: + max = g.op("ReduceMax", self, axes_i=[dim], keepdims_i=keepdim) + else: + axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + max = g.op("ReduceMax", self, axes, keepdims_i=keepdim) + indices = g.op("ArgMax", self, axis_i=dim, keepdims_i=keepdim) + return max, indices + + +def _min_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + # torch.min(input) + if dim_or_y is None and keepdim is None: + return g.op("ReduceMin", self, keepdims_i=0) + # torch.min(input, other) + if keepdim is None: + return _op_with_optional_float_cast(g, "Min", self, dim_or_y, opset_before=12) + # torch.min(input, dim, keepdim) + else: + keepdim = _get_const(keepdim, "i", "keepdim") + dim = _get_const(dim_or_y, "i", "dim") + if g.opset < 18: + min = g.op("ReduceMin", self, axes_i=[dim], keepdims_i=keepdim) + else: + axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + min = g.op("ReduceMin", self, axes, keepdims_i=keepdim) + indices = g.op("ArgMin", self, axis_i=dim, keepdims_i=keepdim) + return min, indices + + +def _numel_helper(g: jit_utils.GraphContext, self): + shape = g.op("Shape", self) + return g.op("ReduceProd", shape, keepdims_i=0) + + +@parse_args("v", "is", "i", "i") +def _var_mean_helper(g: jit_utils.GraphContext, input, dim, correction, keepdim): + if g.opset < 18: + if dim is None: + mean = g.op("ReduceMean", input, keepdims_i=0) + t_mean = mean + num_elements = _numel_helper(g, input) + else: + mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=keepdim) + t_mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=1) + redudced_dims = g.op("Shape", input) + # dim could contain one or multiple dimensions + redudced_dims = g.op( + "Gather", + redudced_dims, + g.op("Constant", value_t=torch.tensor(dim)), + axis_i=0, + ) + num_elements = g.op("ReduceProd", redudced_dims, keepdims_i=0) + sub_v = g.op("Sub", input, t_mean) + sqr_sub = g.op("Mul", sub_v, sub_v) + keepdim_mean = 0 if dim is None else keepdim + var = g.op("ReduceMean", sqr_sub, axes_i=dim, keepdims_i=keepdim_mean) + # Correct bias in calculating variance, by dividing it over (N - correction) instead on N + if correction is None: + correction = 1 + if correction != 0: + num_elements = g.op( + "Cast", num_elements, to_i=_C_onnx.TensorProtoDataType.FLOAT + ) + one = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float)) + mul = g.op("Mul", var, num_elements) + 
var = g.op("Div", mul, g.op("Sub", num_elements, one)) + return var, mean + else: + axes = None + if dim is None: + mean = g.op("ReduceMean", input, keepdims_i=0) + t_mean = mean + num_elements = _numel_helper(g, input) + else: + axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + mean = g.op("ReduceMean", input, axes, keepdims_i=keepdim) + t_mean = g.op("ReduceMean", input, axes, keepdims_i=1) + redudced_dims = g.op("Shape", input) + # dim could contain one or multiple dimensions + redudced_dims = g.op( + "Gather", + redudced_dims, + g.op("Constant", value_t=torch.tensor(dim)), + axis_i=0, + ) + num_elements = g.op("ReduceProd", redudced_dims, keepdims_i=0) + sub_v = g.op("Sub", input, t_mean) + sqr_sub = g.op("Mul", sub_v, sub_v) + keepdim_mean = 0 if dim is None else keepdim + if axes is None: + var = g.op("ReduceMean", sqr_sub, keepdims_i=keepdim_mean) + else: + var = g.op("ReduceMean", sqr_sub, axes, keepdims_i=keepdim_mean) + # Correct bias in calculating variance, by dividing it over (N - correction) instead on N + if correction is None: + correction = 1 + if correction != 0: + num_elements = g.op( + "Cast", num_elements, to_i=_C_onnx.TensorProtoDataType.FLOAT + ) + one = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float)) + mul = g.op("Mul", var, num_elements) + var = g.op("Div", mul, g.op("Sub", num_elements, one)) + return var, mean + + +def _embedding_bag_helper( + g: jit_utils.GraphContext, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, +): + if scale_grad_by_freq and GLOBALS.export_training: + return _onnx_unsupported( + "embedding_bag with scale_grad_by_freq for training mode" + ) + if padding_idx is not None and padding_idx >= 0: + raise RuntimeError("embedding_bag with padding_idx") + + loop_condition = g.op("Constant", value_t=torch.tensor(1)) + loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL) + zero = g.op("Constant", value_t=torch.tensor([0])) + + indices_len = _unsqueeze_helper( + g, + _size_helper(g, indices, g.op("Constant", value_t=torch.tensor(0))), + [0], + ) + if not include_last_offset: + offsets = [offsets, indices_len] + offsets = g.op("Concat", *offsets, axis_i=0) + + # Offsets holds the starting index position of each bag. So we create a list of the indices slices (determined by + # offsets) and gather those indices in indices_row. Then we use this subset of indices to gather from embeddings. + # The embeddings output is a loop scan output, so we can avoid creating a sequence and inserting elements in. 
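+    # Illustrative example (assumed values): with indices = [0, 1, 2, 3, 4] and
+    # offsets = [0, 2] (include_last_offset=False), the appended indices_len turns
+    # offsets into [0, 2, 5], so the loop below gathers bag 0 from indices[0:2] and
+    # bag 1 from indices[2:5].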
+ offsets_starts = _slice_helper( + g, offsets, axes=[0], starts=[0], ends=[sys.maxsize], steps=[1] + ) + offsets_ends = _slice_helper( + g, offsets, axes=[0], starts=[1], ends=[sys.maxsize], steps=[1] + ) + + loop_len = _size_helper(g, offsets_ends, g.op("Constant", value_t=torch.tensor(0))) + + loop, (loop_context,), _ = jit_utils.add_op_with_blocks( + g, "Loop", loop_len, loop_condition, n_blocks=1 + ) + loop_block = loop_context.block + + # FIXME(justinchuby): We need to handle what happens when we call b.op on a node return + block_input_iter = utils._add_input_to_block(loop_block) + cond = utils._add_input_to_block(loop_block) + + indices_start = loop_context.op( + "Gather", offsets_starts, block_input_iter, axis_i=0 + ) + indices_end = loop_context.op("Gather", offsets_ends, block_input_iter, axis_i=0) + indices_start = _unsqueeze_helper(loop_context, indices_start, [0]) + indices_end = _unsqueeze_helper(loop_context, indices_end, [0]) + + indices_row = loop_context.op("Slice", indices, indices_start, indices_end, zero) + embeddings = loop_context.op("Gather", embedding_matrix, indices_row, axis_i=0) + if not _is_none(per_sample_weights): + per_sample_weights_row = loop_context.op( + "Slice", per_sample_weights, indices_start, indices_end, zero + ) + per_sample_weights_row = _unsqueeze_helper( + loop_context, per_sample_weights_row, [1] + ) + embeddings = loop_context.op("Mul", embeddings, per_sample_weights_row) + if mode == 0: + embeddings = _reducesum_helper( + loop_context, embeddings, axes_i=[0], keepdims_i=0 + ) + elif mode == 1: + if loop_context.opset < 18: + embeddings = loop_context.op( + "ReduceMean", embeddings, axes_i=[0], keepdims_i=0 + ) + else: + axes = loop_context.op( + "Constant", value_t=torch.tensor([0], dtype=torch.long) + ) + embeddings = loop_context.op("ReduceMean", embeddings, axes, keepdims_i=0) + else: + if loop_context.opset < 18: + embeddings = loop_context.op( + "ReduceMax", embeddings, axes_i=[0], keepdims_i=0 + ) + else: + axes = loop_context.op( + "Constant", value_t=torch.tensor([0], dtype=torch.long) + ) + embeddings = loop_context.op("ReduceMax", embeddings, axes, keepdims_i=0) + + cond_out = loop_context.op( + "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL + ) + utils._add_output_to_block(loop_block, cond_out) + utils._add_output_to_block(loop_block, embeddings) + + # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices. + # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag. 
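+    # Only the aggregated embeddings are computed here; offset2bag, bag_size and
+    # max_indices are returned as None placeholders below.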
+ return loop.node().output(), None, None, None + + +def _linalg_vector_norm_helper( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: float, + dim: Sequence[int] | None, + keepdim: bool, + dtype: torch._C.Value, +): + axes = None + # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html + if _is_none(dim): + self = _reshape_helper(g, self, [-1]) + keepdim = False + elif g.opset >= 18: + axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + + if ord == math.inf: + if g.opset < 18: + result = g.op( + "ReduceMax", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim + ) + else: + if axes is None: + result = g.op("ReduceMax", g.op("Abs", self), keepdims_i=keepdim) + else: + result = g.op("ReduceMax", g.op("Abs", self), axes, keepdims_i=keepdim) + elif ord == -math.inf: + if g.opset < 18: + result = g.op( + "ReduceMin", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim + ) + else: + if axes is None: + result = g.op("ReduceMin", g.op("Abs", self), keepdims_i=keepdim) + else: + result = g.op("ReduceMin", g.op("Abs", self), axes, keepdims_i=keepdim) + elif ord == 0: + if g.opset < 11: + return _onnx_opset_unsupported_detailed( + "linalg_vector_norm", 9, 11, "ord=0 not supported", self + ) + else: + if dim is None: + self = _reshape_helper( + g, + self, + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)), + ) + keepdim = False + + cond_op = g.op( + "Not", + g.op("Equal", self, g.op("Constant", value_t=torch.LongTensor([0]))), + ) + cond_op = g.op( + "Cast", + cond_op, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + return _reducesum_helper(g, cond_op, axes_i=dim, keepdims_i=keepdim) + elif ord == 1: + if g.opset < 18: + result = _reduce_op_symbolic_helper("ReduceL1")( + g, self, dim=dim, keepdim=keepdim + ) + else: + if axes is None: + result = _reduce_op_symbolic_helper("ReduceL1")( + g, self, keepdim=keepdim + ) + else: + result = _reduce_op_symbolic_helper("ReduceL1")( + g, self, axes, keepdim=keepdim + ) + elif ord == 2: + if g.opset < 18: + result = _reduce_op_symbolic_helper("ReduceL2")( + g, self, dim=dim, keepdim=keepdim + ) + else: + if axes is None: + result = _reduce_op_symbolic_helper("ReduceL2")( + g, self, keepdim=keepdim + ) + else: + result = _reduce_op_symbolic_helper("ReduceL2")( + g, self, axes, keepdim=keepdim + ) + else: + ord_op = g.op("Constant", value_t=torch.tensor(ord, dtype=torch.float32)) + result = _reducesum_helper( + g, g.op("Pow", g.op("Abs", self), ord_op), axes_i=dim, keepdims_i=keepdim + ) + result = g.op( + "Pow", + result, + g.op( + "Div", + g.op("Constant", value_t=torch.tensor(1, dtype=torch.float32)), + ord_op, + ), + ) + + if not _is_none(dtype): + dtype = _get_const(dtype, "i", "dtype") + result = g.op("Cast", result, to_i=_type_utils.JitScalarType(dtype).onnx_type()) # type: ignore[arg-type] + return result + + +# Deprecated. 
Internally use _type_utils.ScalarType +# TODO: remove these once we support Type's in the JIT IR and we can once again +# use the unified toType operator +cast_pytorch_to_onnx = { + "Byte": _C_onnx.TensorProtoDataType.UINT8, + "Char": _C_onnx.TensorProtoDataType.INT8, + "Double": _C_onnx.TensorProtoDataType.DOUBLE, + "Float": _C_onnx.TensorProtoDataType.FLOAT, + "Half": _C_onnx.TensorProtoDataType.FLOAT16, + "Int": _C_onnx.TensorProtoDataType.INT32, + "Long": _C_onnx.TensorProtoDataType.INT64, + "Short": _C_onnx.TensorProtoDataType.INT16, + "Bool": _C_onnx.TensorProtoDataType.BOOL, + "ComplexFloat": _C_onnx.TensorProtoDataType.COMPLEX64, + "ComplexDouble": _C_onnx.TensorProtoDataType.COMPLEX128, + "BFloat16": _C_onnx.TensorProtoDataType.BFLOAT16, + "Undefined": _C_onnx.TensorProtoDataType.UNDEFINED, +} + +# Deprecated. Internally use _type_utils.ScalarType +scalar_name_to_pytorch = { + "uint8_t": "Byte", + "int8_t": "Char", + "double": "Double", + "float": "Float", + "half": "Half", + "int": "Int", + "int64_t": "Long", + "int16_t": "Short", + "bool": "Bool", + "complex64": "ComplexFloat", + "complex128": "ComplexDouble", + "qint8": "QInt8", + "quint8": "QUInt8", + "qint32": "QInt32", + "bfloat16": "BFloat16", +} + + +# Deprecated. Internally use _type_utils.ScalarType +# This indicates each scalar type's corresponding +# torch type. Related source: +# https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h +scalar_type_to_pytorch_type = [ + torch.uint8, # 0 + torch.int8, # 1 + torch.short, # 2 + torch.int, # 3 + torch.int64, # 4 + torch.half, # 5 + torch.float, # 6 + torch.double, # 7 + torch.complex32, # 8 + torch.complex64, # 9 + torch.complex128, # 10 + torch.bool, # 11 + torch.qint8, # 12 + torch.quint8, # 13 + torch.qint32, # 14 + torch.bfloat16, # 15 +] + +# Deprecated. Internally use _type_utils.ScalarType +# source of truth is +# https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_dtypes.cpp +pytorch_name_to_type = { + "Byte": torch.uint8, + "Char": torch.int8, + "Double": torch.double, + "Float": torch.float, + "Half": torch.half, + "Int": torch.int, + "Long": torch.int64, + "Short": torch.short, + "Bool": torch.bool, + "ComplexFloat": torch.complex64, + "ComplexDouble": torch.complex128, + "QInt8": torch.qint8, + "QUInt8": torch.quint8, + "QInt32": torch.qint32, + "BFloat16": torch.bfloat16, +} + + +# Deprecated. Internally use _type_utils.ScalarType +scalar_type_to_onnx = [ + cast_pytorch_to_onnx["Byte"], # 0 + cast_pytorch_to_onnx["Char"], # 1 + cast_pytorch_to_onnx["Short"], # 2 + cast_pytorch_to_onnx["Int"], # 3 + cast_pytorch_to_onnx["Long"], # 4 + cast_pytorch_to_onnx["Half"], # 5 + cast_pytorch_to_onnx["Float"], # 6 + cast_pytorch_to_onnx["Double"], # 7 + cast_pytorch_to_onnx["Undefined"], # 8 + cast_pytorch_to_onnx["ComplexFloat"], # 9 + cast_pytorch_to_onnx["ComplexDouble"], # 10 + cast_pytorch_to_onnx["Bool"], # 11 + cast_pytorch_to_onnx["Char"], # 12 + cast_pytorch_to_onnx["Byte"], # 13 + cast_pytorch_to_onnx["Int"], # 14 + cast_pytorch_to_onnx["BFloat16"], # 15 +] + +# Global set to store the list of quantized operators in the network. +# This is currently only used in the conversion of quantized ops from PT -> C2 via ONNX. 
+_quantized_ops: set[int] = set() diff --git a/lib/python3.10/site-packages/torch/onnx/symbolic_opset10.py b/lib/python3.10/site-packages/torch/onnx/symbolic_opset10.py new file mode 100644 index 0000000000000000000000000000000000000000..975b6bdbe7d85f78843dfa5d55de643ca01b44fd --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/symbolic_opset10.py @@ -0,0 +1,1184 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +from __future__ import annotations + +import functools +import sys +import warnings +from typing import Sequence + +import torch +import torch._C._onnx as _C_onnx +import torch.onnx +from torch import _C + +# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics +from torch.onnx import ( + _constants, + _type_utils, + errors, + symbolic_helper, + symbolic_opset9 as opset9, +) +from torch.onnx._globals import GLOBALS +from torch.onnx._internal import jit_utils, registration + + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +# This file exports ONNX ops for opset 10 +# Opset 10 is supported by ONNX release 1.5.0 +# release on 04/24/19 + + +__all__ = [ + "dequantize", + "div", + "embedding_bag", + "fake_quantize_per_tensor_affine", + "flip", + "fmod", + "isfinite", + "isinf", + "nan_to_num", + "quantize_per_tensor", + "quantized_add_relu", + "quantized_add", + "quantized_cat", + "quantized_conv1d_relu", + "quantized_conv2d_relu", + "quantized_conv3d_relu", + "quantized_conv1d", + "quantized_conv2d", + "quantized_conv3d", + "quantized_conv_transpose1d", + "quantized_conv_transpose2d", + "quantized_conv_transpose3d", + "quantized_group_norm", + "quantized_hardswish", + "quantized_instance_norm", + "quantized_layer_norm", + "quantized_leaky_relu", + "quantized_linear", + "quantized_linear_relu", + "quantized_mul", + "quantized_sigmoid", + "slice", + "sort", + "topk", +] + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10) + + +@_onnx_symbolic("aten::div") +def div(g: jit_utils.GraphContext, self, other, *args): + if len(args) == 0: + return opset9.true_divide(g, self, other) + else: + return _div_rounding_mode(g, self, other, *args) + + +@symbolic_helper.parse_args("v", "v", "s") +def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode): + if rounding_mode == "floor": + return _floor_divide(g, self, other) + else: + return opset9._div_rounding_mode(g, self, other, rounding_mode) + + +@_onnx_symbolic("aten::_floor_divide") +def _floor_divide(g: jit_utils.GraphContext, self, other): + if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other): + out = opset9.true_divide(g, self, other) + return g.op("Floor", out) + else: + # Integer division does trunction rounding + div = g.op("Div", self, other) + # Division is negative if: self < 0 != other < 0 + zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) + negative = g.op("Xor", g.op("Less", self, zero), g.op("Less", other, zero)) + + # For negative numbers with self % other != 0, subtract 1 to round down instead of up + mod = g.op("Mod", self, other, fmod_i=0) + fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero))) + + one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) + fixup = g.op("Sub", div, one) + return g.op("Where", fixup_mask, fixup, div) + + +@_onnx_symbolic("aten::sort") +@symbolic_helper.parse_args("v", "i", "i", "none") +def sort(g: jit_utils.GraphContext, self, dim, decending, out=None): + return symbolic_helper._sort_helper(g, self, 
dim, decending=decending, out=out) + + +@_onnx_symbolic("aten::topk") +@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none") +def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None): + return symbolic_helper._topk_helper( + g, self, k, dim, largest=largest, sorted=sorted, out=out + ) + + +def _aten_max_pool_onnx( + g: jit_utils.GraphContext, + self: _C.Value, + kernel_shape: Sequence[int], + strides: Sequence[int], + pads: Sequence[int], + dilations: Sequence[int], + ceil_mode: bool, + unbatched_rank: int, +) -> _C.Value: + self_rank = g.op("Size", g.op("Shape", self)) + if self_rank == unbatched_rank: # C,H,W -> N,C,H,W and N=1 + self = g.op( + "Unsqueeze", + self, + g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), + ) + + pool_result, _ = g.op( + "MaxPool", + self, + outputs=2, + ceil_mode_i=ceil_mode, + dilations_i=dilations, + kernel_shape_i=kernel_shape, + pads_i=pads, + strides_i=strides, + ) + + if self_rank == unbatched_rank: + pool_result = g.op( + "Squeeze", + pool_result, + g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), + ) + + return pool_result + + +# For MaxPool +def _adjust_attributes_of_max_pool( + expand_size: int, + kernel_size: Sequence[int] | int, + stride: Sequence[int] | int, + padding: Sequence[int] | int, + dilation: Sequence[int] | int, +) -> tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]: + """Adjust attributes of avg_pool to match ONNX specification.""" + + if isinstance(dilation, int): + dilation = [dilation] * expand_size + + if isinstance(kernel_size, int): + kernel_shape = [kernel_size] * expand_size + else: + kernel_shape = kernel_size # type: ignore[assignment] + + if isinstance(padding, int): + pads = [padding] * expand_size * 2 # type: ignore[operator, assignment] + elif len(padding) == 1: + pads = padding * expand_size * 2 # type: ignore[operator, assignment] + elif len(padding) == 2: + # 2D padding + pads = padding * 2 # type: ignore[operator, assignment] + elif len(padding) == 3: + # 3D padding + pads = padding * 2 # type: ignore[operator, assignment] + else: + # When padding is already done for all dimensions, + # we don't need to double it + # eg: (1, 1, 1, 1, 1, 1) + pads = padding # type: ignore[assignment] + + if isinstance(stride, int): + strides = [stride] * expand_size + elif not stride: + strides = kernel_shape + else: + strides = stride # type: ignore[assignment] + + return (kernel_shape, strides, pads, dilation) + + +def _aten_max_pool_with_indices_onnx( + g: jit_utils.GraphContext, + self: _C.Value, + kernel_shape: Sequence[int], + strides: Sequence[int], + pads: Sequence[int], + dilations: Sequence[int], + ceil_mode: bool, + unbatched_rank: int, + n_dims_one: Sequence[int], + n_dims_zero: Sequence[int], + n_dims_axes: Sequence[int], +) -> tuple[_C.Value, Sequence[int]]: + self_rank = g.op("Size", g.op("Shape", self)) + if self_rank == unbatched_rank: # C,H,W -> N,C,H,W and N=1 + self = g.op( + "Unsqueeze", + self, + g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), + ) + + pool_result, indices = g.op( + "MaxPool", + self, + outputs=2, + ceil_mode_i=ceil_mode, + dilations_i=dilations, + kernel_shape_i=kernel_shape, + pads_i=pads, + strides_i=strides, + ) + _, flatten_indices = g.op( + "MaxPool", + self, + outputs=2, + dilations_i=dilations, + kernel_shape_i=n_dims_one, + strides_i=n_dims_one, + ) + + ends = g.op("Constant", value_t=torch.tensor(n_dims_one)) + starts = g.op("Constant", value_t=torch.tensor(n_dims_zero)) + axes = g.op("Constant", 
value_t=torch.tensor(n_dims_axes)) + + delta = g.op("Slice", flatten_indices, starts, ends, axes) + indices = g.op("Sub", indices, delta) + + if self_rank == unbatched_rank: + pool_result = g.op( + "Squeeze", pool_result, value_t=torch.tensor([0], dtype=torch.int64) + ) + indices = g.op("Squeeze", indices, value_t=torch.tensor([0], dtype=torch.int64)) + + return (pool_result, indices) + + +@_onnx_symbolic( + "aten::max_pool1d", + decorate=[symbolic_helper._apply_params("max_pool1d", 1, return_indices=False)], +) +@_onnx_symbolic( + "aten::max_pool2d", + decorate=[symbolic_helper._apply_params("max_pool2d", 2, return_indices=False)], +) +@_onnx_symbolic( + "aten::max_pool3d", + decorate=[symbolic_helper._apply_params("max_pool3d", 3, return_indices=False)], +) +@_onnx_symbolic( + "aten::max_pool1d_with_indices", + decorate=[ + symbolic_helper._apply_params( + "max_pool1d_with_indices", + 1, + return_indices=True, + ) + ], +) +@_onnx_symbolic( + "aten::max_pool2d_with_indices", + decorate=[ + symbolic_helper._apply_params( + "max_pool2d_with_indices", + 2, + return_indices=True, + ) + ], +) +@_onnx_symbolic( + "aten::max_pool3d_with_indices", + decorate=[ + symbolic_helper._apply_params( + "max_pool3d_with_indices", + 3, + return_indices=True, + ) + ], +) +def _max_pool(name: str, expand_size: int, return_indices: bool): + @symbolic_helper.quantized_args(True, False, False, False, False, False) + @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i") + def symbolic_fn( + g: jit_utils.GraphContext, + input: _C.Value, + kernel_size: Sequence[int], + stride: Sequence[int], + padding: int | Sequence[int], + dilation: Sequence[int], + ceil_mode: bool, + ): + kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool( + expand_size, kernel_size, stride, padding, dilation + ) + + if return_indices: + return _aten_max_pool_with_indices_onnx( + g, + input, + kernel_shape, + strides, + pads, + dilations, + ceil_mode, + expand_size + 1, + ([1] * expand_size), + ([0] * expand_size), + ([2 + i for i in range(expand_size)]), + ) + else: + return _aten_max_pool_onnx( + g, + input, + kernel_shape, + strides, + pads, + dilations, + ceil_mode, + expand_size + 1, + ) + + return symbolic_fn + + +# For AvgPool +def _adjust_attributes_of_avg_pool( + expand_size: int, + kernel_size: Sequence[int] | int, + stride: Sequence[int] | int, + padding: Sequence[int] | int, +) -> tuple[Sequence[int], Sequence[int], Sequence[int]]: + """Adjust attributes of avg_pool to match ONNX specification.""" + + if isinstance(kernel_size, int): + kernel_shape = [kernel_size] * expand_size + else: + kernel_shape = kernel_size # type: ignore[assignment] + + if isinstance(padding, int): + pads = [padding] * expand_size * 2 + elif len(padding) == 1: + pads = padding * expand_size * 2 # type: ignore[operator, assignment] + elif len(padding) == 2: + pads = padding * expand_size # type: ignore[operator, assignment] + else: + pads = padding * 2 # type: ignore[operator, assignment] + + if isinstance(stride, int): + strides = [stride] * expand_size + elif not stride: + strides = kernel_shape + else: + strides = stride # type: ignore[assignment] + + return (kernel_shape, strides, pads) + + +@_onnx_symbolic( + "aten::avg_pool1d", + decorate=[symbolic_helper._apply_params("avg_pool1d", 1)], +) +@_onnx_symbolic( + "aten::avg_pool2d", + decorate=[symbolic_helper._apply_params("avg_pool2d", 2)], +) +@_onnx_symbolic( + "aten::avg_pool3d", + decorate=[symbolic_helper._apply_params("avg_pool3d", 3)], +) +def _avg_pool(name, 
expand_size): + @symbolic_helper.quantized_args(True, False, False, False, False, False, False) + @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none") + def symbolic_fn( + g, + input: _C.Value, + kernel_size: Sequence[int], + stride: Sequence[int], + padding: int | Sequence[int], + ceil_mode: int, + count_include_pad: int, + divisor_override=None, + ): + kernel_shape, strides, pads = _adjust_attributes_of_avg_pool( + expand_size, kernel_size, stride, padding + ) + + result = g.op( + "AveragePool", + input, + ceil_mode_i=ceil_mode, + count_include_pad_i=count_include_pad, + kernel_shape_i=kernel_shape, + pads_i=pads, + strides_i=strides, + ) + + return result + + return symbolic_fn + + +@_onnx_symbolic( + "aten::upsample_nearest1d", + decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest2d", + decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest3d", + decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_linear1d", + decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")], +) +@_onnx_symbolic( + "aten::upsample_bilinear2d", + decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")], +) +@_onnx_symbolic( + "aten::upsample_trilinear3d", + decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")], +) +def _interpolate(name, dim, interpolate_mode): + @symbolic_helper.quantized_args(True, False, False) + def symbolic_fn(g, input, output_size, *args): + scales, align_corners = symbolic_helper._get_interpolate_attributes( + g, interpolate_mode, args + ) + symbolic_helper._interpolate_warning(interpolate_mode) + align_corners = symbolic_helper._maybe_get_scalar(align_corners) + if align_corners: + return symbolic_helper._unimplemented(name, "align_corners == True", input) + if scales is None: + scales = symbolic_helper._interpolate_size_to_scales( + g, input, output_size, dim + ) + return g.op("Resize", input, scales, mode_s=interpolate_mode) + + return symbolic_fn + + +@_onnx_symbolic("aten::__interpolate") +def __interpolate( + g: jit_utils.GraphContext, + input, + size, + scale_factor, + mode, + align_corners, + recompute_scale_factor, + antialias, +): + scales, mode = symbolic_helper._interpolate_get_scales_and_mode( + g, input, size, scale_factor, mode, align_corners + ) + return g.op("Resize", input, scales, mode_s=mode) + + +def _slice( + g: jit_utils.GraphContext, + input: torch._C.Value, + axes: list | torch.Tensor | torch._C.Value, + starts: list | torch.Tensor | torch._C.Value, + ends: list | torch.Tensor | torch._C.Value, + steps: list | torch.Tensor | torch._C.Value | None = None, +): + def is_none_value(value): + if value is None: + return True + return ( + isinstance(value, torch._C.Value) + and value.node().kind() == "prim::Constant" + and isinstance(value.type(), _C.NoneType) + ) + + def to_slice_input(list_or_value, default_value=None): + # Convert input param into a 1D torch.Value. 
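+        # Added commentary (not part of the upstream source): each of
+        # axes/starts/ends/steps is normalised to a rank-1 value so it can feed
+        # the ONNX Slice op directly: a Python list or torch.Tensor becomes a
+        # Constant node, a rank-0 graph value is unsqueezed to rank 1, and a
+        # rank-1 value passes through unchanged.  For example, to_slice_input([0])
+        # produces a Constant holding the 1-D int64 tensor([0]).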
+ if is_none_value(list_or_value) and default_value is not None: + list_or_value = [default_value] + + if isinstance(list_or_value, (list, torch.Tensor)): + return g.op("Constant", value_t=torch.tensor(list_or_value)) + + rank = symbolic_helper._get_tensor_rank(list_or_value) + if rank == 0: + return symbolic_helper._unsqueeze_helper(g, list_or_value, [0]) + if rank == 1: + return list_or_value + raise errors.SymbolicValueError( + f"Rank must be 0 or 1, not {rank}", list_or_value + ) + + def get_const_value(list_or_value): + if isinstance(list_or_value, (list, torch.Tensor)): + if len(list_or_value) == 1: + return list_or_value[0] + return None + return symbolic_helper._maybe_get_const(list_or_value, "i") + + # Check if slice is a no-op + if ( + get_const_value(starts) == 0 + and get_const_value(ends) == _constants.INT64_MAX + and (steps is None or get_const_value(steps) == 1) + ): + return input + + axes = to_slice_input(axes) + starts = to_slice_input(starts, default_value=0) + ends = to_slice_input(ends, default_value=_constants.INT64_MAX) + if steps is None: + return g.op("Slice", input, starts, ends, axes) + steps = to_slice_input(steps, default_value=1) + return g.op("Slice", input, starts, ends, axes, steps) + + +@_onnx_symbolic("aten::slice") +def slice(g: jit_utils.GraphContext, self, *args): + if len(args) == 4: + # aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor + dims, start, end, step = args + elif len(args) == 3: + # aten::slice(t[] l, int? start=None, int? end=None, int step=1) -> t[] + start, end, step = args + dims = [0] + else: + raise errors.SymbolicValueError("Unknown aten::slice signature", self) + + return symbolic_helper._slice_helper( + g, + self, + axes=dims, + starts=start, + ends=end, + steps=step, + ) + + +@_onnx_symbolic("aten::flip") +@symbolic_helper.parse_args("v", "is") +def flip(g: jit_utils.GraphContext, input, dims): + return symbolic_helper._slice_helper( + g, + input, + axes=dims, + starts=[-1] * len(dims), + ends=[-_constants.INT64_MAX] * len(dims), + steps=[-1] * len(dims), + ) + + +@_onnx_symbolic("aten::fmod") +def fmod(g: jit_utils.GraphContext, input, other): + return g.op("Mod", input, other, fmod_i=1) + + +@_onnx_symbolic("aten::embedding_bag") +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") +def embedding_bag( + g: jit_utils.GraphContext, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, +): + if scale_grad_by_freq and GLOBALS.export_training: + return symbolic_helper._onnx_unsupported( + "embedding_bag with scale_grad_by_freq for training mode" + ) + if padding_idx is not None and padding_idx >= 0: + raise RuntimeError("embedding_bag with padding_idx") + + warnings.warn( + "Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. 
" + "Please use opset 11 or higher to export model for dynamic input shape.'" + ) + offsets_dim_0 = symbolic_helper._get_tensor_dim_size(offsets, 0) + if offsets_dim_0 is not None: + if include_last_offset: + offset_len = offsets_dim_0 - 1 + offsets_extended = offsets + else: + offset_len = offsets_dim_0 + offsets_extended = [ + offsets, + g.op("Constant", value_t=torch.tensor([sys.maxsize])), + ] + offsets_extended = g.op("Concat", *offsets_extended, axis_i=0) + list_ = [] + for i in range(offset_len): + start_ = symbolic_helper._unsqueeze_helper( + g, + opset9.select(g, offsets_extended, torch.tensor(0), torch.tensor(i)), + [0], + ) + end_ = symbolic_helper._unsqueeze_helper( + g, + opset9.select( + g, offsets_extended, torch.tensor(0), torch.tensor(i + 1) + ), + [0], + ) + axes_ = g.op("Constant", value_t=torch.tensor([0])) + indices_row = g.op("Slice", indices, start_, end_, axes_) + + embeddings = g.op("Gather", embedding_matrix, indices_row) + if not symbolic_helper._is_none(per_sample_weights): + per_sample_weights_row = g.op( + "Slice", per_sample_weights, start_, end_, axes_ + ) + per_sample_weights_row = symbolic_helper._unsqueeze_helper( + g, per_sample_weights_row, [1] + ) + embeddings = g.op("Mul", embeddings, per_sample_weights_row) + if mode == 0: + embeddings = symbolic_helper._reducesum_helper( + g, embeddings, axes_i=[0], keepdims_i=0 + ) + elif mode == 1: + embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0) + else: + embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0) + + embeddings = symbolic_helper._unsqueeze_helper(g, embeddings, [0]) + list_.append(embeddings) + + output = g.op("Concat", *list_, axis_i=0) + # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices. + # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag. + return output, None, None, None + else: + return symbolic_helper._onnx_unsupported( + "embedding_bag with unknown shape of offsets for opset 10 is not supported. " + "please use opset 11 or higher." + ) + + +@_onnx_symbolic("aten::fake_quantize_per_tensor_affine") +@symbolic_helper.parse_args("v", "v", "v", "i", "i") +def fake_quantize_per_tensor_affine( + g: jit_utils.GraphContext, + inputs, + scale, + zero_point, + quant_min=-128, + quant_max=127, +): + # NOTE: (0, 127) is a special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + if (quant_min, quant_max) == (0, 127): + symbolic_helper._onnx_opset_unsupported_detailed( + "fake_quantize_per_tensor_affine", + 10, + 13, + "Quantize range (0, 127) not supported, requires opset 13 Clip", + inputs, + ) + if (quant_min, quant_max) not in [(0, 255), (-128, 127)]: + raise errors.SymbolicValueError( + f"For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). 
" + f"Got ({quant_min}, {quant_max})", + inputs, + ) + scale = symbolic_helper._maybe_get_scalar(scale) + if scale is None: + symbolic_helper._onnx_opset_unsupported_detailed( + "fake_quantize_per_tensor_affine", + 10, + 13, + "Non-constant scale not supported", + inputs, + ) + scale = scale.float().data # Avoid exporter generating double type + if quant_min == 0: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) + else: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) + return g.op( + "DequantizeLinear", + g.op("QuantizeLinear", inputs, scale, zero_point), + scale, + zero_point, + ) + + +@_onnx_symbolic("aten::isinf") +def isinf(g: jit_utils.GraphContext, input): + return g.op("IsInf", g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE)) + + +@_onnx_symbolic("aten::isfinite") +def isfinite(g: jit_utils.GraphContext, input): + inf_node = isinf(g, input) + nan_node = opset9.isnan(g, input) + return opset9.__not_(g, opset9.__or_(g, inf_node, nan_node)) + + +@_onnx_symbolic("aten::quantize_per_tensor") +def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + # TODO(justinchuby): Extract all the cast ops into a helper function. + zero_point = g.op( + "Cast", zero_point, to_i=_type_utils.JitScalarType(dtype).onnx_type() + ) + scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) + return symbolic_helper.quantize_helper(g, input, scale, zero_point) + + +@_onnx_symbolic("aten::dequantize") +def dequantize(g: jit_utils.GraphContext, input): + return symbolic_helper.dequantize_helper(g, input)[0] + + +@_onnx_symbolic("aten::nan_to_num") +@symbolic_helper.parse_args("v", "f", "f", "f") +def nan_to_num(g: jit_utils.GraphContext, input, nan, posinf, neginf): + # Cannot create a int type tensor with inf/nan values, so we simply + # return the original tensor + if not symbolic_helper._is_fp(input): + return input + input_dtype = _type_utils.JitScalarType.from_value(input).dtype() + if nan is None: + nan = 0.0 + nan_cond = opset9.isnan(g, input) + nan_result = g.op( + "Where", + nan_cond, + g.op("Constant", value_t=torch.tensor([nan], dtype=input_dtype)), + input, + ) + + # For None values of posinf, neginf we use the greatest/lowest finite + # value representable by input's dtype. + finfo = torch.finfo(input_dtype) + if posinf is None: + posinf = finfo.max + posinf_cond = opset9.logical_and( + g, + isinf(g, nan_result), + opset9.gt(g, nan_result, g.op("Constant", value_t=torch.LongTensor([0]))), + ) + nan_posinf_result = g.op( + "Where", + posinf_cond, + g.op("Constant", value_t=torch.tensor([posinf], dtype=input_dtype)), + nan_result, + ) + + if neginf is None: + neginf = finfo.min + neginf_cond = opset9.logical_and( + g, + isinf(g, nan_posinf_result), + opset9.lt( + g, nan_posinf_result, g.op("Constant", value_t=torch.LongTensor([0])) + ), + ) + return g.op( + "Where", + neginf_cond, + g.op("Constant", value_t=torch.tensor([neginf], dtype=input_dtype)), + nan_posinf_result, + ) + + +# Quantized symbolics --------------------------------------------------------- +# https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export +# Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were +# introduced in opset version 10. 
+@_onnx_symbolic("quantized::linear") +def quantized_linear( + g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.linear(g, input, weight, bias) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::linear_relu") +def quantized_linear_relu( + g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.linear(g, input, weight, bias) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::add") +def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + y, _, _, _ = symbolic_helper.dequantize_helper(g, y) + + output = opset9.add(g, x, y) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::add_relu") +def quantized_add_relu(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + y, _, _, _ = symbolic_helper.dequantize_helper(g, y) + + output = opset9.add(g, x, y) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::mul") +def quantized_mul(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + y, _, _, _ = symbolic_helper.dequantize_helper(g, y) + + output = opset9.mul(g, x, y) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::hardswish") +def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = opset9.hardswish(g, x) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::sigmoid") +def quantized_sigmoid(g: jit_utils.GraphContext, x, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = opset9.sigmoid(g, x) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::leaky_relu") +def quantized_leaky_relu( + g: jit_utils.GraphContext, x, negative_slope, inplace, op_scale, op_zero_point +): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = opset9.leaky_relu(g, x, negative_slope, inplace) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::layer_norm") +def quantized_layer_norm( + g: jit_utils.GraphContext, + x, + normalized_shape, + weight, + bias, + eps, + op_scale, + op_zero_point, +): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = opset9.layer_norm(g, x, normalized_shape, weight, bias, eps, False) + + return 
symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::group_norm") +def quantized_group_norm( + g: jit_utils.GraphContext, + x, + num_groups, + weight, + bias, + eps, + op_scale, + op_zero_point, +): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = opset9.group_norm(g, x, num_groups, weight, bias, eps, False) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::instance_norm") +@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v") +def quantized_instance_norm( + g: jit_utils.GraphContext, + q_input, + weight, + bias, + eps, + op_scale, + op_zero_point, +): + input, _, _, _ = symbolic_helper.dequantize_helper(g, q_input) + + output = opset9.instance_norm( + g, input, weight, bias, None, None, False, 0.0, eps, False + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv1d_relu") +def quantized_conv1d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv2d_relu") +def quantized_conv2d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv3d_relu") +def quantized_conv3d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv1d") +def quantized_conv1d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = 
symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv2d") +def quantized_conv2d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv3d") +def quantized_conv3d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose1d") +def quantized_conv_transpose1d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv_transpose2d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose2d") +def quantized_conv_transpose2d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv_transpose2d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose3d") +def quantized_conv_transpose3d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, 
q_bias) + + output = opset9.conv_transpose3d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::cat") +@symbolic_helper.parse_args("v", "i", "v", "v") +def quantized_cat( + g: jit_utils.GraphContext, + q_inputs: _C.Value, + dim: int, + op_scale: _C.Value, + op_zero_point: _C.Value, +) -> _C.Value: + unpacked_inputs = symbolic_helper._unpack_list(q_inputs) + dequantized = [ + symbolic_helper.dequantize_helper(g, input)[0] for input in unpacked_inputs + ] + concatenated = g.op("Concat", *dequantized, axis_i=dim) + return symbolic_helper.quantize_helper(g, concatenated, op_scale, op_zero_point) diff --git a/lib/python3.10/site-packages/torch/onnx/symbolic_opset11.py b/lib/python3.10/site-packages/torch/onnx/symbolic_opset11.py new file mode 100644 index 0000000000000000000000000000000000000000..7bf27b273832f2f12293f565b4598ee1b9d2868a --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/symbolic_opset11.py @@ -0,0 +1,1467 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +"""This file exports ONNX ops for opset 11.""" + +from __future__ import annotations + +import functools +import sys +import warnings +from typing import Sequence + +import torch +from torch import _C +from torch._C import _onnx as _C_onnx +from torch.onnx import ( + _type_utils, + errors, + symbolic_helper, + symbolic_opset10 as opset10, + symbolic_opset9 as opset9, + utils, +) +from torch.onnx._internal import jit_utils, registration + + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +__all__ = [ + "add", + "append", + "arange", + "argsort", + "atleast_1d", + "atleast_2d", + "atleast_3d", + "cat", + "chunk", + "clamp_max", + "clamp_min", + "clamp", + "constant_pad_nd", + "cumsum", + "Delete", + "embedding_bag", + "embedding_renorm", + "flatten", + "gather", + "hardtanh", + "hstack", + "im2col", + "index_fill", + "index", + "index_copy", + "index_put", + "insert", + "linalg_det", + "linalg_vector_norm", + "logdet", + "masked_scatter", + "masked_select", + "mm", + "narrow", + "normal", + "pad", + "pixel_shuffle", + "pop", + "prim_constant_chunk", + "reflection_pad", + "relu6", + "remainder", + "replication_pad", + "round", + "scatter", + "select", + "size", + "sort", + "split_with_sizes", + "split", + "squeeze", + "stack", + "topk", + "unbind", + "unique_dim", + "unsqueeze", + "vstack", +] + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=11) + + +@_onnx_symbolic("aten::hardtanh") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "f", "f") +def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float): + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + min_val = g.op( + "Constant", + value_t=torch.tensor(min_val, dtype=scalar_type.dtype()), + ) + max_val = g.op( + "Constant", + value_t=torch.tensor(max_val, dtype=scalar_type.dtype()), + ) + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min_val, max_val, opset_before=12 + ) + + +@_onnx_symbolic("aten::clamp") +def clamp(g: jit_utils.GraphContext, self, min, max): + def _cast_if_not_none(tensor, dtype): + if tensor is not None and not symbolic_helper._is_none(tensor): + return g.op( + "Cast", + tensor, + to_i=dtype.onnx_type(), + ) + else: + return tensor + + scalar_type = _type_utils.JitScalarType.from_value( + self, 
_type_utils.JitScalarType.UNDEFINED + ) + if scalar_type != _type_utils.JitScalarType.UNDEFINED: + min = _cast_if_not_none(min, scalar_type) + max = _cast_if_not_none(max, scalar_type) + + if symbolic_helper._is_none(min): + return clamp_max(g, self, max) + elif symbolic_helper._is_none(max): + return clamp_min(g, self, min) + else: + if ( + symbolic_helper._get_tensor_rank(min) == 0 + and symbolic_helper._get_tensor_rank(max) == 0 + ): + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min, max, opset_before=12 + ) + else: + return clamp_max(g, clamp_min(g, self, min), max) + + +@_onnx_symbolic("aten::clamp_min") +@symbolic_helper.parse_args("v", "v") +def clamp_min(g: jit_utils.GraphContext, self, min): + min = g.op("Cast", min, to_i=_type_utils.JitScalarType.from_value(self).onnx_type()) + if symbolic_helper._get_tensor_rank(min) == 0: + max = opset9.unused(g) + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min, max, opset_before=12 + ) + else: + return symbolic_helper._op_with_optional_float_cast( + g, "Max", self, min, opset_before=12 + ) + + +@_onnx_symbolic("aten::clamp_max") +@symbolic_helper.parse_args("v", "v") +def clamp_max(g: jit_utils.GraphContext, self, max): + max = g.op("Cast", max, to_i=_type_utils.JitScalarType.from_value(self).onnx_type()) + if symbolic_helper._get_tensor_rank(max) == 0: + min = opset9.unused(g) + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min, max, opset_before=12 + ) + else: + return symbolic_helper._op_with_optional_float_cast( + g, "Min", self, max, opset_before=12 + ) + + +@_onnx_symbolic("aten::relu6") +def relu6(g: jit_utils.GraphContext, input): + scalar_type = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.FLOAT + ) + min_val = g.op( + "Constant", + value_t=torch.tensor(0, dtype=scalar_type.dtype()), + ) + max_val = g.op( + "Constant", + value_t=torch.tensor(6, dtype=scalar_type.dtype()), + ) + return clamp(g, input, min_val, max_val) + + +@_onnx_symbolic("aten::select") +# Opset 11 gather accepts negative indices +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "i", "v") +def select(g: jit_utils.GraphContext, self, dim, index): + return g.op("Gather", self, index, axis_i=dim) + + +@_onnx_symbolic("aten::index_put") +def index_put( + g: jit_utils.GraphContext, self, indices_list_value, values, accumulate=False +): + if symbolic_helper._is_packed_list(indices_list_value): + indices_list = symbolic_helper._unpack_list(indices_list_value) + else: + indices_list = [indices_list_value] + accumulate = symbolic_helper._parse_arg(accumulate, "b") + + if len(indices_list) == 0: + return values + + if len(indices_list) > 1: + for idx_ in range(len(indices_list)): + if symbolic_helper._is_bool(indices_list[idx_]): + indices_list[idx_] = g.op("NonZero", indices_list[idx_]) + index = indices_list[0] + + for ind in indices_list[1:]: + index = opset9.add(g, index, ind) + broadcast_index_shape = g.op("Shape", index) + indices_list = [ + symbolic_helper._unsqueeze_helper( + g, opset9.expand(g, ind, broadcast_index_shape, None), [-1] + ) + for ind in indices_list + ] + index = g.op("Concat", *indices_list, axis_i=-1) + else: + # Replace index_put node with masked_scatter or masked_fill + # when inputs to the index_put node contains a single boolean input. + # + # index_put -> masked_fill + # * input index contains single tensor of Bool type (e.g.: %24 <- %23). + # * input value contains single element (e.g.: %18). 
+ # + # Torch IR + # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6) + # %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = + # aten::to(%8, %26, %27, %11, %12, %28, %29, %15) + # %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]() + # %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22) + # %24 : Tensor?[] = prim::ListConstruct(%23) + # %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = + # aten::index_put(%mask, %24, %18, %30) + # return (%25) + # + # + # index_put -> masked_scatter + # * input index contains single tensor of Bool type (e.g.: %32 <- %31). + # * input value contains multiple elements (e.g.: %28). + # + # Torch IR + # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6) + # %28 : Float(8, strides=[1], requires_grad=0, device=cpu) + # = prim::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]() + # %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) + # = aten::ne(%mask, %some_const) + # %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) + # = aten::to(%15, %34, %35, %18, %19, %36, %37, %22) + # %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %30 : int[] = prim::Constant[value=[-1]]() + # %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30) + # %32 : Tensor?[] = prim::ListConstruct(%31) + # %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) + # = aten::index_put(%mask, %32, %28, %38) + # return (%33) + index = indices_list[0] + bool_inp = index + if symbolic_helper._is_bool(bool_inp): + rank = symbolic_helper._get_tensor_rank(values) + if rank is not None and rank == 0: + return opset9.masked_fill(g, self, bool_inp, values) + mask_rank = symbolic_helper._get_tensor_rank(bool_inp) + self_rank = symbolic_helper._get_tensor_rank(self) + if ( + mask_rank is not None + and self_rank is not None + and self_rank > mask_rank + ): + # Unsqueeze 'bool_inp' to be broadcastable to shape of 'self'. 
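+                # Added example (commentary only): with a rank-1 mask and a
+                # rank-3 `self`, the added dims are list(range(1, 3)) == [1, 2],
+                # so the mask becomes shape (N, 1, 1) and broadcasts against
+                # `self` in the masked_scatter call below.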
+ bool_inp = symbolic_helper._unsqueeze_helper( + g, bool_inp, list(range(mask_rank, self_rank)) + ) + return masked_scatter(g, self, bool_inp, values) + broadcast_index_shape = g.op("Shape", index) + index = symbolic_helper._unsqueeze_helper(g, index, [-1]) + sub_data_shape = symbolic_helper._slice_helper( + g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[sys.maxsize] + ) + values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0) + # Check if values is a singular value and expand accordingly + rank = symbolic_helper._get_tensor_rank(values) + if rank is not None and rank == 0: + values = opset9.expand(g, values, values_shape, None) + values = symbolic_helper._reshape_helper(g, values, values_shape) + + self_scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.UNDEFINED + ) + if self_scalar_type != _type_utils.JitScalarType.UNDEFINED: + values_scalar_type = _type_utils.JitScalarType.from_value( + values, _type_utils.JitScalarType.UNDEFINED + ) + if self_scalar_type != values_scalar_type: + values = g.op("Cast", values, to_i=self_scalar_type.onnx_type()) + elif accumulate: + raise errors.SymbolicValueError("self does not have a valid scalar type.", self) + + if accumulate: + zeros = g.op( + "ConstantOfShape", + g.op("Shape", self), + value_t=torch.tensor([0], dtype=self_scalar_type.dtype()), + ) + result = g.op("ScatterND", zeros, index, values) + result = add(g, self, result) + else: + result = g.op("ScatterND", self, index, values) + + return result + + +@_onnx_symbolic("aten::pixel_shuffle") +@symbolic_helper.parse_args("v", "i") +def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor): + rank = symbolic_helper._get_tensor_rank(self) + if rank is not None and rank != 4: + return symbolic_helper._unimplemented("pixel_shuffle", "only support 4d input") + return g.op("DepthToSpace", self, blocksize_i=upscale_factor, mode_s="CRD") + + +@_onnx_symbolic( + "aten::upsample_nearest1d", + decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest2d", + decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest3d", + decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_linear1d", + decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")], +) +@_onnx_symbolic( + "aten::upsample_bilinear2d", + decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")], +) +@_onnx_symbolic( + "aten::upsample_trilinear3d", + decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")], +) +@_onnx_symbolic( + "aten::upsample_bicubic2d", + decorate=[symbolic_helper._apply_params("upsample_bicubic2d", 4, "cubic")], +) +def _interpolate(name: str, dim: int, interpolate_mode: str): + return symbolic_helper._interpolate_helper(name, dim, interpolate_mode) + + +@_onnx_symbolic("aten::__interpolate") +@symbolic_helper.quantized_args(True, False, False, False, False, False, False) +def __interpolate( + g: jit_utils.GraphContext, + input, + size, + scale_factor, + mode, + align_corners, + recompute_scale_factor, + antialias, +): + return symbolic_helper.__interpolate_helper( + g, input, size, scale_factor, mode, align_corners, recompute_scale_factor + ) + + +@_onnx_symbolic("aten::gather") +@symbolic_helper.parse_args("v", "i", "v", "v") +def gather(g: jit_utils.GraphContext, self, dim, index, 
sparse_grad=False): + if symbolic_helper._maybe_get_const(sparse_grad, "i"): + return symbolic_helper._unimplemented("gather", "sparse_grad == True") + return g.op("GatherElements", self, index, axis_i=dim) + + +@_onnx_symbolic("aten::scatter") +@symbolic_helper.parse_args("v", "i", "v", "v") +def scatter(g: jit_utils.GraphContext, self, dim, index, src): + src_type = _type_utils.JitScalarType.from_value(src) + src = symbolic_helper._maybe_get_scalar(src) + if symbolic_helper._is_value(src): + return g.op("ScatterElements", self, index, src, axis_i=dim) + else: + # Check if scalar "src" has same type as self (PyTorch allows different + # type for scalar src (but not when src is tensor)). If not, insert Cast node. + if _type_utils.JitScalarType.from_value(self) != src_type: + src = g.op( + "Cast", + src, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + return g.op( + "ScatterElements", self, index, opset9.expand_as(g, src, index), axis_i=dim + ) + + +@_onnx_symbolic("aten::cumsum") +@symbolic_helper.parse_args("v", "i", "none") +def cumsum(g: jit_utils.GraphContext, self, dim, dtype=None): + dim_tensor = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int)) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + cast = g.op( + "Cast", self, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() + ) + else: + cast = self + csum = g.op("CumSum", cast, dim_tensor) + return csum + + +@_onnx_symbolic("aten::masked_select") +def masked_select(g: jit_utils.GraphContext, self, mask): + index = opset9.nonzero(g, opset9.expand_as(g, mask, self)) + return g.op("GatherND", self, index) + + +@_onnx_symbolic("aten::masked_scatter") +def masked_scatter(g: jit_utils.GraphContext, self, mask, source): + index = opset9.nonzero(g, opset9.expand_as(g, mask, self)) + # NOTE: source can have more elements than needed. + # It could also have arbitrary shape. + # This is not supported by ONNX::ScatterND, so we need to flatten and slice source tensor. 
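+    # Added example (commentary only): if `mask` selects k positions, `index`
+    # from NonZero has shape (k, rank(self)), so `source` is flattened and
+    # truncated to its first k elements, giving ScatterND exactly one update per
+    # selected position -- the same behaviour as eager masked_scatter_:
+    #
+    #     self = torch.zeros(2, 2)
+    #     mask = torch.tensor([[True, False], [False, True]])
+    #     self.masked_scatter_(mask, torch.tensor([1.0, 2.0, 3.0]))
+    #     # -> [[1., 0.], [0., 2.]]; the extra source element 3.0 is ignored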
+ source = symbolic_helper._reshape_helper(g, source, torch.LongTensor([-1])) + source = symbolic_helper._slice_helper( + g, + source, + axes=torch.LongTensor([0]), + starts=torch.LongTensor([0]), + ends=opset9.size(g, index, torch.LongTensor([0])), + ) + return g.op("ScatterND", self, index, source) + + +@_onnx_symbolic("aten::len") +def _len(g: jit_utils.GraphContext, self): + if ( + symbolic_helper._is_tensor_list(self) + or self.node().kind() == "onnx::SplitToSequence" + ): + return g.op("SequenceLength", self) + sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0]))) + return symbolic_helper._squeeze_helper(g, sz_0, [0]) + + +@_onnx_symbolic("aten::__getitem_") +def __getitem_(g: jit_utils.GraphContext, self, i): + if symbolic_helper._is_tensor_list(self): + # SequenceAt requires that the input be a List of Tensors + return g.op("SequenceAt", self, i) + else: + from torch.onnx.symbolic_opset9 import __getitem_ as getitem + + return getitem(g, self, i) + + +@_onnx_symbolic("aten::_set_item") +def _set_item(g: jit_utils.GraphContext, tensor_list, i, v): + tensor_list = g.op("SequenceErase", tensor_list, i) + return g.op("SequenceInsert", tensor_list, v, i) + + +@_onnx_symbolic("aten::append") +def append(g: jit_utils.GraphContext, self, tensor): + return g.op("SequenceInsert", self, tensor) + + +@_onnx_symbolic("aten::add") +def add(g: jit_utils.GraphContext, self, other, alpha=None): + if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self): + tensor_list_node = other.node() + if tensor_list_node.kind() != "prim::ListConstruct": + return symbolic_helper._unimplemented( + "add", "does not support adding dynamic tensor list to another" + ) + tensors = symbolic_helper._unpack_list(other) + l = self + for t in tensors: + l = g.op("SequenceInsert", l, t) + return l + + return opset9.add(g, self, other, alpha) + + +@_onnx_symbolic("aten::insert") +def insert(g: jit_utils.GraphContext, self, pos, tensor): + return g.op("SequenceInsert", self, tensor, pos) + + +@_onnx_symbolic("aten::pop") +def pop(g: jit_utils.GraphContext, tensor_list, dim): + return g.op("SequenceErase", tensor_list, dim) + + +@_onnx_symbolic("aten::Delete") +def Delete(g: jit_utils.GraphContext, tensor_list, dim): + return g.op("SequenceErase", tensor_list, dim) + + +@_onnx_symbolic("aten::cat") +@symbolic_helper.quantized_args(True) +def cat(g: jit_utils.GraphContext, tensor_list, dim): + if symbolic_helper._is_packed_list(tensor_list): + return opset9.cat(g, tensor_list, dim) + else: + dim = symbolic_helper._get_const(dim, "i", "dim") + return g.op("ConcatFromSequence", tensor_list, axis_i=dim) + + +@_onnx_symbolic("aten::stack") +def stack(g: jit_utils.GraphContext, tensor_list, dim): + if symbolic_helper._is_packed_list(tensor_list): + return opset9.stack(g, tensor_list, dim) + else: + dim = symbolic_helper._get_const(dim, "i", "dim") + return g.op("ConcatFromSequence", tensor_list, axis_i=dim, new_axis_i=1) + + +@_onnx_symbolic("aten::_unique2") +@symbolic_helper.parse_args("v", "i", "i", "i") +def _unique2(g: jit_utils.GraphContext, self, sorted, return_inverse, return_counts): + u, indices, inverse_indices, counts = g.op( + "Unique", self, sorted_i=sorted, outputs=4 + ) + return u, inverse_indices, counts + + +@_onnx_symbolic("aten::unique_dim") +@symbolic_helper.parse_args("v", "i", "i", "i", "i") +def unique_dim( + g: jit_utils.GraphContext, self, dim, sorted, return_inverse, return_counts +): + u, indices, inverse_indices, counts = g.op( + "Unique", self, axis_i=dim, 
sorted_i=sorted, outputs=4 + ) + return u, inverse_indices, counts + + +@_onnx_symbolic("aten::topk") +@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none") +def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None): + return symbolic_helper._topk_helper( + g, self, k, dim, largest=largest, sorted=sorted, out=out + ) + + +@_onnx_symbolic("aten::sort") +@symbolic_helper.parse_args("v", "i", "i", "none") +def sort(g: jit_utils.GraphContext, self, dim, decending, out=None): + return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out) + + +@_onnx_symbolic("aten::argsort") +@symbolic_helper.parse_args("v", "i", "i", "none") +def argsort(g: jit_utils.GraphContext, self, dim, decending, out=None): + _, indices = symbolic_helper._sort_helper( + g, self, dim, decending=decending, out=out + ) + return indices + + +@_onnx_symbolic("aten::round") +@symbolic_helper.parse_args("v", "i") +def round(g: jit_utils.GraphContext, self, decimals=0): + if not symbolic_helper._is_fp(self): + return self + if decimals == 0: + return g.op("Round", self) + mul = g.op("Mul", self, g.op("Constant", value_t=torch.tensor(pow(10, decimals)))) + round = g.op("Round", mul) + return g.op( + "Mul", round, g.op("Constant", value_t=torch.tensor(pow(10, -1 * decimals))) + ) + + +@_onnx_symbolic("aten::remainder") +def remainder(g: jit_utils.GraphContext, input, other): + if symbolic_helper._is_fp(input) or symbolic_helper._is_fp(other): + return opset9.remainder(g, input, other) + return g.op("Mod", input, other, fmod_i=0) + + +@_onnx_symbolic("aten::split") +@symbolic_helper.parse_args("v", "v", "i", "i") +def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None): + if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs): + split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim) + if _outputs is None: + return split_out + # Convert to multiple slice nodes iff number of splits and number of outputs are statically known. 
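+        # Added example (commentary only): for split_size_or_sizes == [2, 3] and
+        # _outputs == 2 with dim == 0, this branch emits Slice(self, 0, 2) and
+        # Slice(self, 2, 5) rather than SplitToSequence + SequenceAt, with the
+        # running `start`/`end` constants acting as prefix sums of the sizes.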
+ if ( + symbolic_helper._is_packed_list(split_size_or_sizes) + and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs + ): + split_sizes = [ + symbolic_helper._unsqueeze_helper(g, v, [0]) + for v in symbolic_helper._unpack_list(split_size_or_sizes) + ] + start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) + axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + res = [] + for i in range(_outputs): + end = g.op( + "Add", start, split_sizes[i] + ) # split_sizes is a list of same length as _outputs + res.append(g.op("Slice", self, start, end, axis)) + start = end + return res + return [ + g.op( + "SequenceAt", + split_out, + g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)), + ) + for i in range(_outputs) + ] + else: + return opset9.split(g, self, split_size_or_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::split_with_sizes") +@symbolic_helper.parse_args("v", "v", "i", "i") +def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None): + return split(g, self, split_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::unbind") +@symbolic_helper.parse_args("v", "i", "i") +def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None): + if _outputs is None: + return g.op( + "SplitToSequence", + self, + g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)), + axis_i=dim, + keepdims_i=0, + ) + else: + return opset9.unbind(g, self, dim, _outputs) + + +def _prepare_onnx_paddings(g: jit_utils.GraphContext, input, pad): + """Generate paddings in ONNX order based on pad in pytorch. + + Args: + input: the input tensor. + pad: the paddings in pytorch. + The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end, + where m is in range [0, n]. + """ + if ( + not symbolic_helper._is_packed_list(pad) + and symbolic_helper._is_list(pad) + and symbolic_helper._is_scalar_list(pad) + ): + pad = g.op("ConcatFromSequence", pad, axis_i=0, new_axis_i=1) + # The desired order of paddings is + # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end. + # n is the dimension of input. + # Assume zero-dimensions in the beginning, pad the "pad" sequence with zeros in the beginning + pad_len = opset9.size(g, pad, g.op("Constant", value_t=torch.tensor([0]))) + # Set extension = [0] * (dim * 2 - len(pad)) + rank = symbolic_helper._get_tensor_rank(input) + if rank is None: + rank = g.op("Size", g.op("Shape", input)) + else: + rank = g.op("Constant", value_t=torch.tensor(rank, dtype=torch.int64)) + extension = g.op( + "Sub", + g.op("Mul", rank, g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64))), + pad_len, + ) + # Concat pad with extension: paddings = [dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, 0, 0, ... 
] + # Currently ONNX only supports int64 type for Pad + pad = g.op("Cast", pad, to_i=_C_onnx.TensorProtoDataType.INT64) + paddings = g.op( + "Concat", + pad, + g.op( + "ConstantOfShape", extension, value_t=torch.tensor([0], dtype=torch.int64) + ), + axis_i=0, + ) + # Reshape and reverse order and collate first beginnings and then ends + # paddings = [[..., 0, dim_n-1_begin, dim_n_begin], + # [..., 0, dim_n-1_end, dim_n_end]] + # Reshape back to 1-D paddings = [..., 0, dim_n - 1_begin, dim_n_begin, ..., 0, dim_n - 1_end, dim_n_end] + paddings = symbolic_helper._reshape_helper( + g, paddings, g.op("Constant", value_t=torch.tensor([-1, 2])) + ) + paddings = g.op("Transpose", opset10.flip(g, paddings, [0]), perm_i=[1, 0]) + paddings = symbolic_helper._reshape_helper( + g, paddings, g.op("Constant", value_t=torch.tensor([-1])) + ) + padding_c = g.op("Cast", paddings, to_i=_C_onnx.TensorProtoDataType.INT64) + return padding_c + + +@_onnx_symbolic("aten::constant_pad_nd") +def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value=None): + mode = "constant" + value = symbolic_helper._maybe_get_scalar(value) + value = symbolic_helper._if_scalar_type_as(value, input) + pad = _prepare_onnx_paddings(g, input, padding) + return g.op("Pad", input, pad, value, mode_s=mode) + + +@_onnx_symbolic("aten::reflection_pad1d") +@_onnx_symbolic("aten::reflection_pad2d") +@_onnx_symbolic("aten::reflection_pad3d") +def reflection_pad(g: jit_utils.GraphContext, input, padding): + mode = "reflect" + paddings = _prepare_onnx_paddings(g, input, padding) + return g.op("Pad", input, paddings, mode_s=mode) + + +@_onnx_symbolic("aten::replication_pad1d") +@_onnx_symbolic("aten::replication_pad2d") +@_onnx_symbolic("aten::replication_pad3d") +def replication_pad(g: jit_utils.GraphContext, input, padding): + mode = "edge" + paddings = _prepare_onnx_paddings(g, input, padding) + return g.op("Pad", input, paddings, mode_s=mode) + + +@_onnx_symbolic("aten::pad") +def pad( + g: jit_utils.GraphContext, + input: _C.Value, + pad: _C.Value, + mode: _C.Value, + value: _C.Value, +): + mode = symbolic_helper._parse_arg(mode, "s") + if mode == "replicate": + return replication_pad(g, input, pad) + elif mode == "reflect": + return reflection_pad(g, input, pad) + elif mode == "constant": + return constant_pad_nd(g, input, pad, value) + elif mode == "circular": + return opset9._pad_circular(g, input, pad) + else: + raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input) + + +@_onnx_symbolic("aten::linalg_det") +def linalg_det(g: jit_utils.GraphContext, self): + return g.op("Det", self) + + +@_onnx_symbolic("aten::logdet") +def logdet(g: jit_utils.GraphContext, input): + return opset9.log(g, linalg_det(g, input)) + + +@_onnx_symbolic("aten::arange") +def arange(g: jit_utils.GraphContext, *args): + def _get_arange_dtype(dtype): + dtype = symbolic_helper._maybe_get_const(dtype, "i") + return dtype + + if len(args) == 2 and all(isinstance(val, int) for val in args): + # aten::arange(Scalar start, Scalar end) + dtype = torch.int64 + # Start index. + start = g.op( + "Constant", + value_t=torch.tensor(args[0], dtype=dtype), + ) + # End (exclusive) index. + end = g.op( + "Constant", + value_t=torch.tensor(args[1], dtype=dtype), + ) + # Step size from start to end indexes. 
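+        # ONNX Range(start, limit, delta) yields start, start + delta, ... up to but
+        # excluding limit, which matches torch.arange(start, end) with the default step of 1.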
+ delta_default = g.op( + "Constant", + value_t=torch.tensor(1, dtype=dtype), + ) + return g.op("Range", start, end, delta_default) + elif len(args) == 2 or len(args) == 5: + if len(args) == 2: + # aten::arange(Scalar end, Tensor out) + dtype = None + else: + # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[1]) + type_, end, start, step = symbolic_helper._arange_cast_helper( + g, end=args[0], dtype=dtype + ) + start_default = g.op( + "Constant", + value_t=torch.tensor(0, dtype=type_.dtype()), + ) + delta_default = g.op( + "Constant", + value_t=torch.tensor(1, dtype=type_.dtype()), + ) + return g.op("Range", start_default, end, delta_default) + elif len(args) == 4 or len(args) == 7: + if len(args) == 4: + # aten::arange(Scalar start, Scalar end, Scalar step, Tensor out) + dtype = None + else: + # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[3]) + _, end, start, step = symbolic_helper._arange_cast_helper( + g, start=args[0], end=args[1], step=args[2], dtype=dtype + ) + return g.op("Range", start, end, step) + elif len(args) == 6: + # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[2]) + type_, end, start, step = symbolic_helper._arange_cast_helper( + g, start=args[0], end=args[1], dtype=dtype + ) + delta_default = g.op( + "Constant", + value_t=torch.tensor(1, dtype=type_.dtype()), + ) + return g.op("Range", start, end, delta_default) + else: + return symbolic_helper._unimplemented( + "aten::arange", f"with {len(args)} arguments" + ) + + +@_onnx_symbolic("aten::_dim_arange") +@symbolic_helper.parse_args("v", "i") +def _dim_arange(g: jit_utils.GraphContext, like, dim): + like_shape = g.op("Shape", like) + stop = g.op( + "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0 + ) + return arange(g, stop, 4, None, None, None) + + +@_onnx_symbolic("aten::size") +@symbolic_helper.quantized_args(True, quantize_output=False) +def size(g: jit_utils.GraphContext, self, dim=None): + if dim is None: + return g.op("Shape", self) + return symbolic_helper._size_helper(g, self, dim) + + +@_onnx_symbolic("aten::squeeze") +def squeeze(g: jit_utils.GraphContext, self, dim=None): + if dim is None: + return g.op("Squeeze", self) + + # dim as a tensor + if not symbolic_helper._is_constant(dim): + return symbolic_helper._squeeze_helper(g, self, [dim]) + + dim = symbolic_helper._get_const(dim, "i", "dim") + + input_rank = symbolic_helper._get_tensor_rank(self) + adjusted_dim = dim + if input_rank is not None and dim < 0: + adjusted_dim += input_rank + dim_size = symbolic_helper._get_tensor_dim_size(self, adjusted_dim) + if (dim < 0 and input_rank is None) or dim_size is None: + # If onnx shape inference is not on, export always as dynamic. + # Because we cannot tell if observed static shape is also static at runtime. + # create "cond" node (condition is shape[i]==1) + dim_constant = g.op("Constant", value_t=torch.tensor([dim])) + size = symbolic_helper._size_helper(g, self, dim_constant) + const_one = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64)) + cond = g.op("Equal", size, const_one) + # create the "If" node and add the "then" and "else" blocks to it. 
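+        # A static Squeeze on a dimension whose runtime size is not 1 would be invalid
+        # in ONNX, so the decision is deferred to runtime: the "then" branch squeezes
+        # the dimension, while the "else" branch returns the input unchanged via Identity.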
+ if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( + g, "If", cond, n_blocks=2 + ) + squeeze_ = symbolic_helper._squeeze_helper(if_context, self, [dim]) + utils._add_output_to_block(if_context.block, squeeze_) + identity_ = else_context.op("Identity", self) + utils._add_output_to_block(else_context.block, identity_) + return if_op + + # For static input shape + dim = adjusted_dim + if dim_size > 1: + warnings.warn( + "This model contains a squeeze operation on dimension " + + str(dim) + + ". The size of " + + "this dimension in the given input is " + + str(dim_size) + + ". The model will " + + "be exported without the squeeze node. If the model is intended to be used with dynamic " + + "input shapes, please export with dynamic_axes argument." + ) + return self + return symbolic_helper._squeeze_helper(g, self, [dim]) + + +@_onnx_symbolic("aten::unsqueeze") +def unsqueeze(g: jit_utils.GraphContext, self, dim): + if symbolic_helper._is_constant(dim): + dim = symbolic_helper._get_const(dim, "i", "dim") + + return symbolic_helper._unsqueeze_helper(g, self, [dim]) + + +@_onnx_symbolic("aten::mm") +def mm(g: jit_utils.GraphContext, self, other): + return g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0) + + +@_onnx_symbolic("aten::index") +def index(g: jit_utils.GraphContext, self, index): + if symbolic_helper._is_packed_list(index): + indices = symbolic_helper._unpack_list(index) + else: + indices = [index] + + # Handle single mask index. + if len(indices) == 1: + index = indices[0] + if not symbolic_helper._is_none(index) and ( + symbolic_helper._is_bool(index) + or _type_utils.JitScalarType.from_value(index) + == _type_utils.JitScalarType.UINT8 + ): + index = opset9.nonzero(g, index) + return g.op("GatherND", self, index) + return opset9.index(g, self, index) + + +@_onnx_symbolic("aten::index_fill") +def index_fill(g: jit_utils.GraphContext, self, dim, index, value): + dim_value = symbolic_helper._parse_arg(dim, "i") + expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( + g, self, dim, index + ) + value = symbolic_helper._maybe_get_scalar(value) + value = symbolic_helper._if_scalar_type_as(value, self) + expanded_value = opset9.expand(g, value, expanded_index_shape, None) + return scatter(g, self, dim, expanded_index, expanded_value) + + +@_onnx_symbolic("aten::index_copy") +def index_copy(g: jit_utils.GraphContext, self, dim, index, source): + dim_value = symbolic_helper._parse_arg(dim, "i") + expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( + g, self, dim, index + ) + return scatter(g, self, dim, expanded_index, source) + + +@_onnx_symbolic("aten::bitwise_right_shift") +@_onnx_symbolic("aten::__rshift_") +def __rshift_(g: jit_utils.GraphContext, self, other): + # make sure to cast other to self's type + # (when self is long, make sure that other is not float) + if _type_utils.JitScalarType.from_value( + other, _type_utils.JitScalarType.UNDEFINED + ) != _type_utils.JitScalarType.from_value(self): + other = g.op( + "Cast", + other, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + + if ( + _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED) + == _type_utils.JitScalarType.UINT8 + ): + return g.op("BitShift", self, other, direction_s="RIGHT") + + two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) + # exponent (same type as self) has to be float or double in onnx::Pow + if not symbolic_helper._is_fp(self): + other = g.op("Cast", other, 
to_i=_C_onnx.TensorProtoDataType.FLOAT) + two_pow = g.op("Pow", two, other) + two_pow = g.op( + "Cast", + two_pow, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + rshift = g.op("Div", self, two_pow) + return rshift + + +@_onnx_symbolic("aten::bitwise_left_shift") +@_onnx_symbolic("aten::__lshift_") +def __lshift_(g: jit_utils.GraphContext, self, other): + # make sure to cast other to self's type + # (when self is long, make sure that other is not float) + if _type_utils.JitScalarType.from_value( + other, _type_utils.JitScalarType.UNDEFINED + ) != _type_utils.JitScalarType.from_value(self): + other = g.op( + "Cast", + other, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + + if ( + _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED) + == _type_utils.JitScalarType.UINT8 + ): + return g.op("BitShift", self, other, direction_s="LEFT") + + two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) + # exponent (same type as self) has to be float or double in onnx::Pow + if not symbolic_helper._is_fp(self): + other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) + two_pow = g.op("Pow", two, other) + two_pow = g.op( + "Cast", + two_pow, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + lshift = g.op("Mul", self, two_pow) + return lshift + + +def _get_im2col_indices_along_dim( + g: jit_utils.GraphContext, input_d, kernel_size_d, dilation_d, padding_d, stride_d +): + # Input is always 4-D (N, C, H, W) + # Calculate indices of sliding blocks along spatial dimension + # Slide kernel over input each dim d: + # each dimension d ranges from 0 to input[d]+2xpadding[d]-dilation[d]x(kernel_size[d]-1) + # with steps = stride + + blocks_d = g.op( + "Add", input_d, g.op("Constant", value_t=torch.tensor(padding_d * 2)) + ) + blocks_d = g.op( + "Sub", + blocks_d, + g.op("Constant", value_t=torch.tensor(dilation_d * (kernel_size_d - 1))), + ) + + # Stride kernel over input and find starting indices along dim d + blocks_d_indices = g.op( + "Range", + g.op("Constant", value_t=torch.tensor(0)), + blocks_d, + g.op("Constant", value_t=torch.tensor(stride_d)), + ) + + # Apply dilation on kernel and find its indices along dim d + kernel_grid = torch.arange(0, kernel_size_d * dilation_d, dilation_d) + kernel_grid = g.op("Constant", value_t=kernel_grid.unsqueeze(0)) + + # Broadcast and add kernel staring positions (indices) with + # kernel_grid along dim d, to get block indices along dim d + blocks_d_indices = symbolic_helper._unsqueeze_helper( + g, blocks_d_indices, [0] + ) # Reshape to [1, -1] + kernel_mask = symbolic_helper._reshape_helper( + g, kernel_grid, g.op("Constant", value_t=torch.tensor([-1, 1])) + ) + block_mask = g.op("Add", blocks_d_indices, kernel_mask) + + return block_mask + + +def _get_im2col_padded_input(g: jit_utils.GraphContext, input, padding_h, padding_w): + # Input is always 4-D tensor (N, C, H, W) + # Padding tensor has the following format: (padding_h, padding_w) + # Reshape the padding to follow ONNX format: (dim1_begin, dim2_begin,...,dim1_end, dim2_end,...) 
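+    # Illustrative example (assumed values): padding_h=1, padding_w=2 yields
+    # pads = [0, 0, 1, 2, 0, 0, 1, 2], i.e. begin values for (N, C, H, W) followed
+    # by the matching end values, so only the spatial dims are padded.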
+ pad = g.op("Constant", value_t=torch.LongTensor([0, 0, padding_h, padding_w] * 2)) + return g.op("Pad", input, pad) + + +def _get_im2col_output_shape(g: jit_utils.GraphContext, input, kernel_h, kernel_w): + batch_dim = size(g, input, g.op("Constant", value_t=torch.tensor(0))) + channel_dim = size(g, input, g.op("Constant", value_t=torch.tensor(1))) + channel_unfolded = g.op( + "Mul", channel_dim, g.op("Constant", value_t=torch.tensor(kernel_h * kernel_w)) + ) + + return g.op( + "Concat", + symbolic_helper._unsqueeze_helper(g, batch_dim, [0]), + symbolic_helper._unsqueeze_helper(g, channel_unfolded, [0]), + g.op("Constant", value_t=torch.tensor([-1])), + axis_i=0, + ) + + +@_onnx_symbolic("aten::im2col") +@symbolic_helper.parse_args("v", "is", "is", "is", "is") +def im2col(g: jit_utils.GraphContext, input, kernel_size, dilation, padding, stride): + # Input is always 4-D tensor (N, C, H, W) + # All other args are int[2] + + input_h = size(g, input, g.op("Constant", value_t=torch.tensor(2))) + input_w = size(g, input, g.op("Constant", value_t=torch.tensor(3))) + + stride_h, stride_w = stride[0], stride[1] + padding_h, padding_w = padding[0], padding[1] + dilation_h, dilation_w = dilation[0], dilation[1] + kernel_h, kernel_w = kernel_size[0], kernel_size[1] + + blocks_row_indices = _get_im2col_indices_along_dim( + g, input_h, kernel_h, dilation_h, padding_h, stride_h + ) + blocks_col_indices = _get_im2col_indices_along_dim( + g, input_w, kernel_w, dilation_w, padding_w, stride_w + ) + + output_shape = _get_im2col_output_shape(g, input, kernel_h, kernel_w) + padded_input = _get_im2col_padded_input(g, input, padding_h, padding_w) + + # For a 4D matrix of size (1, 1, 3, 3) as below with kernel_size=2, stride=1, and dilation=1 + # [[[[1., 2., 3.,], + # [4., 5., 6.,], + # [7., 8., 9.,]]]] + # First gather indices along rows (dim=2) with blocks_row_indices = [[0,1], [1,2]] to get: + # [[[[[1., 2., 3.], + # [4., 5., 6.]], + # [[4., 5., 6.], + # [7., 8., 9.]]]]] + # And then gather along cols (dim=4) with blocks_row_indices = [[0,1], [1,2]] to get: + # [[[[[[1., 2.], + # [4., 5.]], + # [[2., 3.], + # [5., 6]]], + # [[[4., 5.], + # [7., 8.]], + # [[5., 6.], + # [8., 9.]]]]]] + # Transpose dims 3 (depth) and 4 (rows), and then reshape to output shape (1, 1, 4, 4) to get: + # [[[1., 2., 4., 5.], + # [2., 3., 5., 6.], + # [4., 5., 7., 8.], + # [5., 6., 8., 9.]]] + output = g.op("Gather", padded_input, blocks_row_indices, axis_i=2) + output = g.op("Gather", output, blocks_col_indices, axis_i=4) + output = g.op("Transpose", output, perm_i=[0, 1, 2, 4, 3, 5]) + return symbolic_helper._reshape_helper(g, output, output_shape) + + +@_onnx_symbolic("aten::narrow") +def narrow(g: jit_utils.GraphContext, input, dim, start, length): + end = g.op("Add", start, length) + return symbolic_helper._slice_helper(g, input, axes=dim, starts=start, ends=end) + + +@_onnx_symbolic("aten::flatten") +@symbolic_helper.quantized_args(True, False, False) +@symbolic_helper.parse_args("v", "i", "i") +def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim): + dim = symbolic_helper._get_tensor_rank(input) + if dim == 1: + return input + # use ONNX's Flatten operator for cases where the output shape is 2D + if start_dim == 1: + if end_dim == -1 or (dim is not None and end_dim == dim - 1): + return g.op("Flatten", input, axis_i=start_dim) + elif start_dim == 0: + if end_dim == -2 or (dim is not None and end_dim == dim - 2): + return g.op("Flatten", input, axis_i=end_dim + 1) + if dim is None: + return 
symbolic_helper._unimplemented( + "dim", + "ONNX and PyTorch use different strategies to split the input. " + "Input rank must be known at export time.", + ) + # if end_dim is negative add dim + if end_dim < 0: + end_dim = dim + end_dim + + return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim) + + +@_onnx_symbolic("aten::linalg_vector_norm") +@symbolic_helper.parse_args("v", "f", "is", "b", "v") +def linalg_vector_norm( + g: jit_utils.GraphContext, + self, + ord, + dim: Sequence[int] | None, + keepdim: bool, + dtype, +): + return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype) + + +@_onnx_symbolic("aten::embedding_bag") +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") +def embedding_bag( + g: jit_utils.GraphContext, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, +): + return symbolic_helper._embedding_bag_helper( + g, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, + ) + + +@_onnx_symbolic("aten::embedding_renorm") +@symbolic_helper.parse_args("v", "v", "f", "f") +def embedding_renorm(g: jit_utils.GraphContext, weight, indices, max_norm, norm_type): + unique_indices = g.op("Unique", indices) + partial_weight = g.op("Gather", weight, unique_indices) + norm_i = int(norm_type) + if norm_i == 1: + norm_type = "ReduceL1" + elif norm_i == 2: + norm_type = "ReduceL2" + else: + raise errors.SymbolicValueError( + f"Unsupported: ONNX export of embedding_renorm with norm: {norm_i}. " + "Only 1. and 2. are supported.", + weight, + ) + partial_weight_norm = g.op(norm_type, partial_weight, axes_i=[1], keepdims_i=1) + # https://github.com/pytorch/pytorch/blob/0a07488ed2c47765e337e290bd138c0e6e459cbd/aten/src/ATen/native/Embedding.cpp#L177 + # Add 1e-7 to prevent division by zero. + partial_weight_norm_ = g.op( + "Add", partial_weight_norm, g.op("Constant", value_t=torch.tensor(1e-7)) + ) + max_norm = torch.tensor(max_norm) + scales = g.op("Div", max_norm, partial_weight_norm_) + partial_weight_renorm = g.op("Mul", partial_weight, scales) + partial_weight_renorm = g.op( + "Where", + g.op("Greater", partial_weight_norm, max_norm), + partial_weight_renorm, + partial_weight, + ) + return g.op( + "ScatterND", + weight, + symbolic_helper._unsqueeze_helper(g, unique_indices, [1]), + partial_weight_renorm, + ) + + +@_onnx_symbolic("aten::chunk") +def chunk(g: jit_utils.GraphContext, self, chunks, dim): + # Calculate chunk size for dynamic chunk + dim_size = g.op("Gather", g.op("Shape", self), dim, axis_i=0) + chunk_size_s = g.op( + "Sub", chunks, g.op("Constant", value_t=torch.tensor([1], dtype=torch.long)) + ) + chunk_size = g.op("Div", g.op("Add", dim_size, chunk_size_s), chunks) + # Create splits vector + chunk_vec = [ + opset9.expand(g, chunk_size, chunk_size_s, None), + g.op("Sub", dim_size, g.op("Mul", chunk_size, chunk_size_s)), + ] + chunk_vec = g.op("Concat", *chunk_vec, axis_i=0) + return split(g, self, chunk_vec, dim) + + +@_onnx_symbolic("aten::normal") +def normal( + g: jit_utils.GraphContext, + mean, + std, + sizes=None, + generator=None, + dtype=None, + layout=None, + device=None, + pin_memory=None, +): + # If you can sample from a given distribution with mean 0 and variance 1, then you can easily sample from a + # scale-location transformation of that distribution, which has mean mu and variance sigma's square. 
If x is a sample + # from a mean 0 and variance 1 distribution then + # sigma x+mu + # is a sample with mean mu and variance sigma's square. + if sizes is not None and not symbolic_helper._is_none(sizes): + mean = opset9.expand(g, mean, sizes, None) + result = opset9.mul(g, std, g.op("RandomNormalLike", mean)) + return add(g, result, mean) + + +@_onnx_symbolic("aten::atleast_1d") +def atleast_1d(g: jit_utils.GraphContext, self: torch._C.Value): + # NOTE: If it's 0D, reshape to 1D + + # NOTE: self could be a packed list or a tensor + if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self): + tensor_list = symbolic_helper._unpack_list(self) + new_tensor_list = [] + for tensor in tensor_list: + new_tensor = tensor + tensor_rank = symbolic_helper._get_tensor_rank(tensor) + if tensor_rank == 0: + new_tensor = symbolic_helper._reshape_helper( + g, new_tensor, g.op("Constant", value_t=torch.tensor([1])) + ) + new_tensor_list.append(new_tensor) + return g.op("SequenceConstruct", *new_tensor_list) + + tensor_rank = symbolic_helper._get_tensor_rank(self) + if tensor_rank == 0: + self = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([1])) + ) + return self + + +@_onnx_symbolic("aten::atleast_2d") +def atleast_2d(g: jit_utils.GraphContext, self: torch._C.Value): + # NOTE: If it's 0D, reshape to 2D + # If it's 1D, unsqueeze to 2D + + # NOTE: self could be a packed list or a tensor + if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self): + tensor_list = symbolic_helper._unpack_list(self) + new_tensor_list = [] + for tensor in tensor_list: + new_tensor = tensor + tensor_rank = symbolic_helper._get_tensor_rank(tensor) + if tensor_rank == 0: + new_tensor = symbolic_helper._reshape_helper( + g, new_tensor, g.op("Constant", value_t=torch.tensor([1, 1])) + ) + elif tensor_rank == 1: + new_tensor = symbolic_helper._unsqueeze_helper( + g, new_tensor, axes_i=[0] + ) + new_tensor_list.append(new_tensor) + return g.op("SequenceConstruct", *new_tensor_list) + + tensor_rank = symbolic_helper._get_tensor_rank(self) + if tensor_rank == 0: + self = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([1, 1])) + ) + elif tensor_rank == 1: + self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[0]) + return self + + +@_onnx_symbolic("aten::atleast_3d") +def atleast_3d(g: jit_utils.GraphContext, self: torch._C.Value): + # NOTE: If it's 0D, reshape to 3D + # If it's 1D, unsqueeze to 3D + # If it's 2D, unsqueeze to 3D + + # NOTE: self could be a packed list or a tensor + if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self): + tensor_list = symbolic_helper._unpack_list(self) + new_tensor_list = [] + for tensor in tensor_list: + new_tensor = tensor + tensor_rank = symbolic_helper._get_tensor_rank(tensor) + if tensor_rank == 0: + new_tensor = symbolic_helper._reshape_helper( + g, new_tensor, g.op("Constant", value_t=torch.tensor([1, 1, 1])) + ) + elif tensor_rank == 1: + new_tensor = symbolic_helper._unsqueeze_helper( + g, new_tensor, axes_i=[0] + ) + new_tensor = symbolic_helper._unsqueeze_helper( + g, new_tensor, axes_i=[-1] + ) + elif tensor_rank == 2: + new_tensor = symbolic_helper._unsqueeze_helper( + g, new_tensor, axes_i=[-1] + ) + new_tensor_list.append(new_tensor) + return g.op("SequenceConstruct", *new_tensor_list) + + tensor_rank = symbolic_helper._get_tensor_rank(self) + if tensor_rank == 0: + self = symbolic_helper._reshape_helper( + g, self, g.op("Constant", 
value_t=torch.tensor([1, 1, 1])) + ) + elif tensor_rank == 1: + self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[0]) + self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[-1]) + elif tensor_rank == 2: + self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[-1]) + return self + + +@_onnx_symbolic("prim::ConstantChunk") +def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim): + input_shape = g.op("Shape", self) + axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + input_shape_dim = g.op("Gather", input_shape, axis, axis_i=0) + start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) + chunk_size = g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long)) + chunk_size_minus_1 = g.op( + "Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long) + ) + input_shape_dim_shift = g.op("Add", input_shape_dim, chunk_size_minus_1) + chunk_dim = g.op("Div", input_shape_dim_shift, chunk_size) + res = [] + for i in range(chunks): + index = g.op("Constant", value_t=torch.tensor([i + 1], dtype=torch.long)) + end = g.op("Mul", chunk_dim, index) + res.append(g.op("Slice", self, start, end, axis)) + start = end + return res + + +@_onnx_symbolic("aten::hstack") +def hstack(g: jit_utils.GraphContext, tensor_list: _C.Value): + tensor_list = atleast_1d(g, tensor_list) + first_tensor = g.op( + "SequenceAt", + tensor_list, + g.op("Constant", value_t=torch.tensor(0, dtype=torch.long)), + ) + first_tensor_shape = g.op("Shape", first_tensor) + first_tensor_dim = g.op("Size", first_tensor_shape) + + const_one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)) + equal_to_one = g.op("Equal", first_tensor_dim, const_one) + + ( + if_op_greater, + (if_context_equal, else_context_equal), + _, + ) = jit_utils.add_op_with_blocks(g, "If", equal_to_one, n_blocks=2, outputs=1) + result_if = if_context_equal.op( + "ConcatFromSequence", tensor_list, axis_i=0, new_axis_i=0 + ) + utils._add_output_to_block(if_context_equal.block, result_if) + result_else = else_context_equal.op( + "ConcatFromSequence", tensor_list, axis_i=1, new_axis_i=0 + ) + utils._add_output_to_block(else_context_equal.block, result_else) + result = if_op_greater.node().output() + + return result + + +@_onnx_symbolic("aten::vstack") +def vstack(g: jit_utils.GraphContext, tensor_list: _C.Value): + tensor_list = atleast_2d(g, tensor_list) + return g.op("ConcatFromSequence", tensor_list, axis_i=0, new_axis_i=0) diff --git a/lib/python3.10/site-packages/torch/onnx/symbolic_opset12.py b/lib/python3.10/site-packages/torch/onnx/symbolic_opset12.py new file mode 100644 index 0000000000000000000000000000000000000000..7aaefd37201dd4541c520fb138f3e85c436a9ebe --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/symbolic_opset12.py @@ -0,0 +1,465 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +from __future__ import annotations + +import functools +import sys + +import torch +from torch._C import _onnx as _C_onnx +from torch.onnx import ( + _type_utils, + errors, + symbolic_helper, + symbolic_opset9 as opset9, + utils, +) +from torch.onnx._internal import jit_utils, registration + + +# EDITING THIS FILE? READ THIS FIRST! 
+# see Note [Edit Symbolic Files] in README.md + +# This file exports ONNX ops for opset 12 + +__all__ = [ + "argmax", + "argmin", + "binary_cross_entropy_with_logits", + "celu", + "cross_entropy_loss", + "dropout", + "einsum", + "ge", + "le", + "native_dropout", + "nll_loss", + "nll_loss2d", + "nll_loss_nd", + "outer", + "pow", + "tensordot", + "unfold", +] + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12) + + +def _einsum_helper(g: jit_utils.GraphContext, equation, tensors): + if not tensors: + raise RuntimeError("Einsum inputs are empty.") + # ONNX does not support bool for Einsum inputs. + if symbolic_helper._is_bool(tensors[0]): + tensors = [ + g.op("Cast", tensor, to_i=_C_onnx.TensorProtoDataType.INT64) + for tensor in tensors + ] + return g.op( + "Cast", + g.op("Einsum", *tensors, equation_s=equation), + to_i=_C_onnx.TensorProtoDataType.BOOL, + ) + else: + return g.op("Einsum", *tensors, equation_s=equation) + + +@_onnx_symbolic("aten::einsum") +@symbolic_helper.parse_args("s", "v", "is") +def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None): + tensors = symbolic_helper._unpack_list(tensor_list) + return _einsum_helper(g, equation, tensors) + + +@_onnx_symbolic("aten::outer") +@symbolic_helper.parse_args("v", "v") +def outer(g: jit_utils.GraphContext, input, other): + # make sure to cast other to self's type + if _type_utils.JitScalarType.from_value( + other, _type_utils.JitScalarType.UNDEFINED + ) != _type_utils.JitScalarType.from_value(input): + other = g.op( + "Cast", + other, + to_i=_type_utils.JitScalarType.from_value(input).onnx_type(), + ) + return _einsum_helper(g, "i,j->ij", [input, other]) + + +def _dropout_returns_masked_input_and_mask( + g: jit_utils.GraphContext, input: torch._C.Value, p: float, train: bool +) -> tuple[torch._C.Value, torch._C.Value | None]: + symbolic_helper.check_training_mode(train, "dropout") + # In eval mode, dropout is non-op. That is, if the node's + # train param is set to False, dropout just returns its inputs. + if not train: + return input, None + p = g.op("Constant", value_t=torch.tensor(p)) + t = g.op("Constant", value_t=torch.tensor(train, dtype=torch.bool)) + r, mask = g.op("Dropout", input, p, t, outputs=2) + return r, mask + + +@_onnx_symbolic("aten::dropout") +@symbolic_helper.parse_args("v", "f", "b") +def dropout(g: jit_utils.GraphContext, input, p, train): + masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train) + return masked + + +@_onnx_symbolic("aten::native_dropout") +@symbolic_helper.parse_args("v", "f", "b") +def native_dropout(g: jit_utils.GraphContext, input, p, train): + return _dropout_returns_masked_input_and_mask(g, input, p, train) + + +@_onnx_symbolic("aten::nll_loss") +def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index): + # none reduction : onnx::Constant[value={0}] + # mean reduction : onnx::Constant[value={1}] + # sum reduction : onnx::Constant[value={2}] + reduction = symbolic_helper._maybe_get_const(reduction, "i") + reduction_vals = ["none", "mean", "sum"] + reduction = reduction_vals[reduction] + + # in onnx NegativeLogLikelihoodLoss specification, ignore_index is optional without default value. + # therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100). 
+ ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i") + if weight.node().mustBeNone(): + nllloss = g.op( + "NegativeLogLikelihoodLoss", + self, + target, + reduction_s=reduction, + ignore_index_i=ignore_index, + ) + else: + nllloss = g.op( + "NegativeLogLikelihoodLoss", + self, + target, + weight, + reduction_s=reduction, + ignore_index_i=ignore_index, + ) + + return nllloss + + +@_onnx_symbolic("aten::nll_loss2d") +def nll_loss2d( + g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index +): + return nll_loss(g, self, target, weight, reduction, ignore_index) + + +@_onnx_symbolic("aten::nll_loss_nd") +def nll_loss_nd( + g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index +): + return nll_loss(g, self, target, weight, reduction, ignore_index) + + +@_onnx_symbolic("aten::cross_entropy_loss") +def cross_entropy_loss( + g: jit_utils.GraphContext, + self, + target, + weight, + reduction, + ignore_index, + label_smoothing, +): + # none reduction : onnx::Constant[value={0}] + # mean reduction : onnx::Constant[value={1}] + # sum reduction : onnx::Constant[value={2}] + reduction = symbolic_helper._maybe_get_const(reduction, "i") + reduction_vals = ["none", "mean", "sum"] + reduction = reduction_vals[reduction] + + label_smoothing = symbolic_helper._maybe_get_const(label_smoothing, "f") + if label_smoothing is not None and label_smoothing > 0.0: + raise errors.SymbolicValueError( + "Unsupported: ONNX does not support label_smoothing", self + ) + + # in onnx SoftmaxCrossEntropyLoss specification, ignore_index is optional without default value. + # therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100). + ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i") + if weight.node().mustBeNone(): + celoss = g.op( + "SoftmaxCrossEntropyLoss", + self, + target, + reduction_s=reduction, + ignore_index_i=ignore_index, + ) + else: + celoss = g.op( + "SoftmaxCrossEntropyLoss", + self, + target, + weight, + reduction_s=reduction, + ignore_index_i=ignore_index, + ) + + return celoss + + +@_onnx_symbolic("aten::binary_cross_entropy_with_logits") +@symbolic_helper.parse_args("v", "v", "v", "v", "i") +def binary_cross_entropy_with_logits( + g: jit_utils.GraphContext, input, target, weight, pos_weight, reduction +): + p = g.op("Constant", value_t=torch.tensor([1])) + sig_x = opset9.sigmoid(g, input) + log_sig_x = opset9.log(g, sig_x) + sub_1_x = opset9.sub(g, p, sig_x) + sub_1_y = opset9.sub(g, p, target) + log_1_x = opset9.log(g, sub_1_x) + if pos_weight is None or symbolic_helper._is_none(pos_weight): + output = opset9.neg( + g, + opset9.add( + g, opset9.mul(g, target, log_sig_x), opset9.mul(g, sub_1_y, log_1_x) + ), + ) + else: + output = opset9.neg( + g, + opset9.add( + g, + opset9.mul(g, opset9.mul(g, target, log_sig_x), pos_weight), + opset9.mul(g, sub_1_y, log_1_x), + ), + ) + + if weight is not None and not symbolic_helper._is_none(weight): + output = opset9.mul(g, weight, output) + + reduction = symbolic_helper._maybe_get_const(reduction, "i") + if reduction == 0: + return output + elif reduction == 1: + return g.op("ReduceMean", output, keepdims_i=0) + elif reduction == 2: + return g.op("ReduceSum", output, keepdims_i=0) + else: + return symbolic_helper._onnx_unsupported( + "binary_cross_entropy_with_logits with reduction other than none, mean, or sum", + input, + ) + + +@_onnx_symbolic("aten::celu") +def celu(g: jit_utils.GraphContext, self, alpha): + alpha = symbolic_helper._maybe_get_const(alpha, 
"f") + # if the input is of type double cast it to float + if ( + _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED) + == _type_utils.JitScalarType.DOUBLE + ): + self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT) + out = g.op("Celu", self, alpha_f=alpha) + return g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.DOUBLE) + + return g.op("Celu", self, alpha_f=alpha) + + +@_onnx_symbolic("aten::argmax") +@symbolic_helper.parse_args("v", "v", "b") +def argmax( + g: jit_utils.GraphContext, + input: torch._C.Value, + dim: torch._C.Value, + keepdim: bool, +): + return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax") + + +@_onnx_symbolic("aten::argmin") +@symbolic_helper.parse_args("v", "v", "b") +def argmin( + g: jit_utils.GraphContext, + input: torch._C.Value, + dim: torch._C.Value, + keepdim: bool, +): + return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin") + + +@_onnx_symbolic("aten::pow") +def pow(g: jit_utils.GraphContext, self, exponent): + return g.op("Pow", self, exponent) + + +@_onnx_symbolic("aten::ge") +def ge(g: jit_utils.GraphContext, input, other): + return g.op("GreaterOrEqual", input, other) + + +@_onnx_symbolic("aten::le") +def le(g: jit_utils.GraphContext, input, other): + return g.op("LessOrEqual", input, other) + + +@_onnx_symbolic("aten::unfold") +@symbolic_helper.parse_args("v", "i", "v", "v") +def unfold(g: jit_utils.GraphContext, input, dimension, size, step): + const_size = symbolic_helper._maybe_get_const(size, "i") + const_step = symbolic_helper._maybe_get_const(step, "i") + if not symbolic_helper._is_value(const_size) and not symbolic_helper._is_value( + const_step + ): + return opset9.unfold(g, input, dimension, const_size, const_step) + + sizedim = symbolic_helper._get_tensor_dim_size(input, dimension) + if sizedim is not None: + low_start = g.op("Constant", value_t=torch.tensor(0)) + low_end = g.op("Constant", value_t=torch.tensor(sizedim)) + hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1)) + low_indices = g.op("Range", low_start, low_end, step) + hi_indices = g.op("Range", size, hi_end, step) + + low_size = symbolic_helper._size_helper( + g, low_indices, g.op("Constant", value_t=torch.tensor(0)) + ) + hi_size = symbolic_helper._size_helper( + g, hi_indices, g.op("Constant", value_t=torch.tensor(0)) + ) + + ndim = symbolic_helper._get_tensor_rank(input) + assert ndim is not None + perm = list(range(0, ndim)) + perm.append(perm.pop(dimension)) + + unsqueeze_list = [] + loop_condition = g.op("Constant", value_t=torch.tensor(1)) + loop_condition = g.op( + "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL + ) + loop_len = g.op("Min", low_size, hi_size) + + loop, (loop_context,), _ = jit_utils.add_op_with_blocks( + g, "Loop", loop_len, loop_condition, n_blocks=1 + ) + + loop_block = loop_context.block + block_input_iter = utils._add_input_to_block(loop_block) + # FIXME(justinchuby): cond is unused? 
+ cond = utils._add_input_to_block(loop_block) + + starts = loop_context.op("Gather", low_indices, block_input_iter) + ends = loop_context.op("Gather", hi_indices, block_input_iter) + axes = loop_context.op("Constant", value_t=torch.tensor([2])) + starts = symbolic_helper._unsqueeze_helper(loop_context, starts, [0]) + ends = symbolic_helper._unsqueeze_helper(loop_context, ends, [0]) + stack = loop_context.op("Slice", input, starts, ends, axes) + + unsqueeze = symbolic_helper._unsqueeze_helper( + loop_context, loop_context.op("Transpose", stack, perm_i=perm), [dimension] + ) + unsqueeze_list.append(unsqueeze) + concat = loop_context.op("Concat", *unsqueeze_list, axis_i=0) + + cond_out = loop_context.op( + "Cast", loop_condition, _C_onnx.TensorProtoDataType.BOOL + ) + utils._add_output_to_block(loop_block, cond_out) + utils._add_output_to_block(loop_block, concat) + + loop_output = loop.node().output() + perm = [0, 1, 2, 3, 4] + perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0] + transpose = g.op("Transpose", loop_output, perm_i=perm) + squeeze = symbolic_helper._squeeze_helper(g, transpose, [0]) + + return squeeze + + return symbolic_helper._unimplemented("Unfold", "input size not accessible") + + +@_onnx_symbolic("aten::tensordot") +@symbolic_helper.parse_args("v", "v", "is", "is", "v") +def tensordot(g: jit_utils.GraphContext, input_a, input_b, dims_a, dims_b, out=None): + if out is not None: + symbolic_helper._unimplemented( + "Tensordot", "Out parameter is not supported for tensordot." + ) + + dim_count_a = symbolic_helper._get_tensor_rank(input_a) + if dim_count_a is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.", + input_a, + ) + + dim_count_b = symbolic_helper._get_tensor_rank(input_b) + if dim_count_b is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.", + input_b, + ) + + dims_a = [ + (dims_a[i] + dim_count_a) if (dims_a[i] < 0) else dims_a[i] + for i in range(len(dims_a)) + ] + dims_b = [ + (dims_b[i] + dim_count_b) if (dims_b[i] < 0) else dims_b[i] + for i in range(len(dims_b)) + ] + + left_dims_a = [i for i in range(dim_count_a) if (i not in dims_a)] + left_dims_b = [i for i in range(dim_count_b) if (i not in dims_b)] + + new_input_a = opset9.permute(g, input_a, left_dims_a + dims_a) + new_input_b = opset9.permute(g, input_b, dims_b + left_dims_b) + + input_shape = g.op("Shape", new_input_a) + left_sizes_a = symbolic_helper._slice_helper( + g, input_shape, axes=[0], starts=[0], ends=[len(left_dims_a)] + ) + shape_sizes = [ + left_sizes_a, + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), + ] + output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes) + + input_shape = g.op("Shape", output_a) + slices = symbolic_helper._slice_helper( + g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize] + ) + shape_sizes = [ + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), + slices, + ] + output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes) + + input_shape = g.op("Shape", new_input_b) + left_sizes_b = symbolic_helper._slice_helper( + g, input_shape, axes=[0], starts=[len(dims_b)], ends=[sys.maxsize] + ) + slices = symbolic_helper._slice_helper( + g, input_shape, axes=[0], starts=[0], ends=[len(dims_b)] + ) + shape_sizes = [ + slices, + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), + ] + output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes) + + 
input_shape = g.op("Shape", output_b) + slices = symbolic_helper._slice_helper( + g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize] + ) + shape_sizes = [ + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), + slices, + ] + output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes) + + output = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", *[output_a, output_b])) + + shape_sizes = [left_sizes_a, left_sizes_b] + return opset9._reshape_from_tensor(g, output, shape_sizes) diff --git a/lib/python3.10/site-packages/torch/onnx/symbolic_opset13.py b/lib/python3.10/site-packages/torch/onnx/symbolic_opset13.py new file mode 100644 index 0000000000000000000000000000000000000000..e31416ae2bc9060e57f6edaefdb271fc5b40f2f6 --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/symbolic_opset13.py @@ -0,0 +1,1113 @@ +# mypy: allow-untyped-defs +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +# This file exports ONNX ops for opset 13 +import functools + +import torch +import torch._C._onnx as _C_onnx +from torch.onnx import ( + _constants, + _type_utils, + errors, + symbolic_helper, + symbolic_opset11 as opset11, + symbolic_opset9 as opset9, + utils, +) +from torch.onnx._internal import jit_utils, registration + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=13) + + +@_onnx_symbolic("aten::softmax") +@symbolic_helper.parse_args("v", "i", "none") +def softmax(g: jit_utils.GraphContext, input, dim, dtype=None): + softmax = g.op("Softmax", input, axis_i=dim) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + softmax = g.op( + "Cast", softmax, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() + ) + + return softmax + + +@_onnx_symbolic("aten::log_softmax") +@symbolic_helper.parse_args("v", "i", "none") +def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None): + return_op = g.op("LogSoftmax", input, axis_i=dim) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + return_op = g.op( + "Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() + ) + return return_op + + +@_onnx_symbolic("aten::frobenius_norm") +@symbolic_helper.parse_args("v", "v", "i") +def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False): + dim_val = symbolic_helper._maybe_get_const(dim, "is") + if not symbolic_helper._is_value(dim_val) and len(dim_val) == 0: + return g.op("ReduceL2", self, keepdims_i=0) + sqr = g.op("Mul", self, self) + sumsqr = symbolic_helper._reducesum_helper(g, sqr, dim, keepdims_i=keepdim) + return g.op("Sqrt", sumsqr) + + +@_onnx_symbolic("aten::split") +@symbolic_helper.parse_args("v", "v", "i", "i") +def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None): + if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs): + split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim) + if _outputs is None: + return split_out + # Convert to multiple slice nodes iff number of splits and number of outputs are statically known. 
+ if ( + symbolic_helper._is_packed_list(split_size_or_sizes) + and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs + ): + split_sizes = [ + symbolic_helper._unsqueeze_helper(g, v, [0]) + for v in symbolic_helper._unpack_list(split_size_or_sizes) + ] + + start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) + axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + res = [] + for i in range(_outputs): + end = g.op( + "Add", start, split_sizes[i] + ) # split_sizes is a list of same length as _outputs + res.append(g.op("Slice", self, start, end, axis)) + start = end + return res + return [ + g.op( + "SequenceAt", + split_out, + g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)), + ) + for i in range(_outputs) + ] + + split_val = symbolic_helper._node_get(split_size_or_sizes.node(), "value") + if split_val.dim() > 0: + return g.op("Split", self, split_size_or_sizes, axis_i=dim, outputs=_outputs) + split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size") + + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + if _outputs is not None: + size = split_size * _outputs + else: + raise errors.SymbolicValueError( + "Unknown dimension size not supported", self + ) + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + splits = g.op("Constant", value_t=torch.tensor(splits)) + return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::split_with_sizes") +def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None): + return split(g, self, split_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::unsafe_split") +def unsafe_split( + g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None +): + return split(g, self, split_size_or_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::unsafe_split_with_sizes") +def unsafe_split_with_sizes( + g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None +): + return split_with_sizes(g, self, split_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::tensor_split") +@symbolic_helper.parse_args("v", "v", "i", "i") +def tensor_split( + g: jit_utils.GraphContext, self, indices_or_sections, dim, _outputs=None +): + axis = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + axis = opset11.unsqueeze(g, axis, 0) + const_1 = g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)) + + if symbolic_helper._is_split_static(indices_or_sections, _outputs): + split_val = symbolic_helper._node_get(indices_or_sections.node(), "value") + + if split_val.dim() > 0: + start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) + res = [] + assert _outputs is not None + for i in range(_outputs - 1): + end = g.op( + "Gather", + indices_or_sections, + g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)), + axis_i=0, + ) + res.append(g.op("Slice", self, start, end, axis)) + start = end + + end = symbolic_helper._size_helper(g, self, axis) + res.append(g.op("Slice", self, start, end, axis)) + return res + + split_size = symbolic_helper._get_const( + indices_or_sections, "i", "indices_or_sections" + ) + + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + if _outputs is not None: + size = split_size * _outputs + else: + raise errors.SymbolicValueError( + "Unknown dimension size not supported", self + ) + + min_split_size = size // split_size + num_splits_one_extra = size % split_size + 
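+        # Illustrative example (assumed values): size=10 split into 3 sections gives
+        # min_split_size=3 and num_splits_one_extra=1, so the section sizes below
+        # become [4, 3, 3], matching torch.tensor_split semantics.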
+ splits = num_splits_one_extra * [min_split_size + 1] + leftover = (split_size - num_splits_one_extra) * [min_split_size] + + splits = g.op( + "Constant", value_t=torch.tensor(splits + leftover, dtype=torch.long) + ) + return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + + if ( + symbolic_helper._is_tensor(indices_or_sections) + and symbolic_helper._get_tensor_rank(indices_or_sections) == 1 + ): + loop_len = symbolic_helper._size_helper( + g, indices_or_sections, g.op("Constant", value_t=torch.tensor(0)) + ) + loop_len = opset11.unsqueeze(g, loop_len, 0) + loop_condition = g.op("Cast", const_1, to_i=_C_onnx.TensorProtoDataType.BOOL) + + # To make the first slice in the below loop work, + # we pad a zero to the first position so that it will be the initial start of slice. + padding_0 = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) + indices_or_sections = g.op("Concat", padding_0, indices_or_sections, axis_i=0) + + final_splits = g.op("SequenceEmpty") + # Loop inputs + loop, (loop_context,), _ = jit_utils.add_op_with_blocks( + g, "Loop", loop_len, loop_condition, final_splits, outputs=1, n_blocks=1 + ) + + loop_block = loop_context.block + block_input_iter = utils._add_input_to_block(loop_block) + cond = utils._add_input_to_block(loop_block) + final_splits = utils._add_input_to_block(loop_block) + + start = loop_context.op( + "Gather", indices_or_sections, block_input_iter, axis_i=0 + ) + end = loop_context.op( + "Gather", + indices_or_sections, + loop_context.op("Add", block_input_iter, const_1), + axis_i=0, + ) + + slice = loop_context.op("Slice", self, start, end, axis) + final_splits = loop_context.op("SequenceInsert", final_splits, slice) + + # Loop outputs + cond_out = loop_context.op("Identity", loop_condition) + utils._add_output_to_block(loop_block, cond_out) + utils._add_output_to_block(loop_block, final_splits) + + loop_out = loop.node().output() + start = g.op( + "Gather", + indices_or_sections, + g.op("Constant", value_t=torch.tensor(-1, dtype=torch.long)), + axis_i=0, + ) + start = opset11.unsqueeze(g, start, 0) + end = symbolic_helper._size_helper(g, self, axis) + + last_slice = g.op("Slice", self, start, end, axis) + + return g.op("SequenceInsert", loop_out, last_slice) + + else: # scalar tensor + dim_size = symbolic_helper._size_helper(g, self, axis) + min_split_size = g.op("Div", dim_size, indices_or_sections) + min_split_size_plus_1 = g.op( + "Add", + min_split_size, + const_1, + ) + num_splits_one_extra = g.op("Mod", dim_size, indices_or_sections) + splits = g.op("Tile", min_split_size_plus_1, num_splits_one_extra) + leftover = g.op( + "Tile", + min_split_size, + g.op( + "Sub", + opset11.unsqueeze(g, indices_or_sections, 0), + num_splits_one_extra, + ), + ) + + splits = g.op("Concat", splits, leftover, axis_i=0) + if _outputs is None: + return g.op("SplitToSequence", self, splits, axis_i=dim) + return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::unbind") +@symbolic_helper.parse_args("v", "i", "i") +def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None): + if _outputs is None: + return g.op( + "SplitToSequence", + self, + g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)), + axis_i=dim, + keepdims_i=0, + ) + + splits = g.op("Constant", value_t=torch.tensor([1] * _outputs)) + outputs = g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + outputs = [outputs] if _outputs == 1 else outputs + squeezed_outputs = [ + g.op("Squeeze", out, g.op("Constant", 
value_t=torch.tensor([dim]))) + for out in outputs + ] + return squeezed_outputs + + +@_onnx_symbolic("aten::nonzero_numpy") +# Emitted from `torch.nonzero(x, as_tuple=True)` +def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None): + return unbind(g, opset9.nonzero(g, input), 1, _outputs=_outputs) + + +@_onnx_symbolic("aten::where") +@symbolic_helper.parse_args("v", "v", "v", "i") +def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None): + # Assumes that torch.where's first argument takes only Bool and Byte tensors. + if not symbolic_helper._is_bool(condition): + condition = g.op("Cast", condition, to_i=_C_onnx.TensorProtoDataType.BOOL) + if self is None: + condition = opset9.nonzero(g, condition) + return symbolic_helper._unbind_helper( + g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs + ) + return g.op("Where", condition, self, other) + + +@_onnx_symbolic("aten::fake_quantize_per_channel_affine") +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i") +def fake_quantize_per_channel_affine( + g: jit_utils.GraphContext, + inputs, + scale, + zero_point, + axis, + quant_min=-128, + quant_max=127, +): + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: + raise errors.SymbolicValueError( + "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). " + f"Got ({quant_min}, {quant_max})", + inputs, + ) + # ONNX defines zero_point to be int8 or uint8 + if quant_min == 0: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) + else: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) + quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis) + if (quant_min, quant_max) == (0, 127): + quantized = g.op( + "Clip", + quantized, + opset9.unused(g), + g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)), + ) + return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis) + + +@_onnx_symbolic("aten::fake_quantize_per_tensor_affine") +@symbolic_helper.parse_args("v", "v", "v", "i", "i") +def fake_quantize_per_tensor_affine( + g: jit_utils.GraphContext, + inputs, + scale, + zero_point, + quant_min=-128, + quant_max=127, +): + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: + raise errors.SymbolicValueError( + "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). 
" + f"Got ({quant_min}, {quant_max})", + inputs, + ) + if quant_min == 0: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) + else: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) + if ( + _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED) + != _type_utils.JitScalarType.FLOAT + ): + scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) + quantized = g.op("QuantizeLinear", inputs, scale, zero_point) + if (quant_min, quant_max) == (0, 127): + quantized = g.op( + "Clip", + quantized, + opset9.unused(g), + g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)), + ) + return g.op("DequantizeLinear", quantized, scale, zero_point) + + +def _reduce_op_symbolic(onnx_op_name): + def symbolic(g, self, dim=None, keepdim=None): + self = symbolic_helper._maybe_cast_reduce_op_input(g, self) + if dim is None: + # all-reduce path + return symbolic_helper._handle_reduce_dim_none(g, self, onnx_op_name) + else: + keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim") + return g.op(onnx_op_name, self, dim, keepdims_i=keepdim) + + return symbolic + + +@_onnx_symbolic( + "aten::sum", + decorate=[symbolic_helper._apply_params("ReduceSum", "sum")], +) +def _reduce_with_dtype(onnx_op, name): + symbolic = _reduce_op_symbolic(onnx_op) + + @symbolic_helper._overload_by_arg_count + def reduce(g, *args, **kwargs): + @symbolic_helper.parse_args("v", "none") + def reduce_nodim(g, self, dtype): + dtype_onnx = None + if dtype.node().kind() == "onnx::Constant": + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() + self = g.op("Cast", self, to_i=dtype_onnx) + elif dtype.node().kind() != "prim::Constant": + return symbolic_helper._unimplemented(name, "dtype", dtype) + result = symbolic(g, self) + if dtype_onnx is not None: + result_dtype_onnx = _type_utils.JitScalarType.from_value( + result + ).onnx_type() + if result_dtype_onnx != dtype_onnx: + result = g.op("Cast", result, to_i=dtype_onnx) + return result + + @symbolic_helper.parse_args("v", "v", "i", "none") + def reduce_dim(g, self, dim, keepdim, dtype): + dtype_onnx = None + if dtype.node().kind() == "onnx::Constant": + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() + self = g.op("Cast", self, to_i=dtype_onnx) + elif dtype.node().kind() != "prim::Constant": + return symbolic_helper._unimplemented(name, "dtype", dtype) + result = symbolic(g, self, dim, keepdim) + if dtype_onnx is not None: + result_dtype_onnx = _type_utils.JitScalarType.from_value( + result + ).onnx_type() + if result_dtype_onnx != dtype_onnx: + result = g.op("Cast", result, to_i=dtype_onnx) + return result + + return reduce_nodim, reduce_dim + + return reduce + + +# Ported from +# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/core.py#L6097 +# NOTE: Supporting aten::unflatten before opset13 needs helper function to adjust ONNX op changes in Concat, Slice, ... +@_onnx_symbolic("aten::unflatten") +def unflatten(g: jit_utils.GraphContext, input, dim, unflattened_size): + input_dim = symbolic_helper._get_tensor_rank(input) + if input_dim is None: + return symbolic_helper._unimplemented( + "dim", + "ONNX and PyTorch use different strategies to split the input. 
" + "Input rank must be known at export time.", + ) + + # dim could be negative + input_dim = g.op("Constant", value_t=torch.tensor([input_dim], dtype=torch.int64)) + dim = g.op("Add", input_dim, dim) + dim = g.op("Mod", dim, input_dim) + + input_size = g.op("Shape", input) + + head_start_idx = g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)) + head_end_idx = g.op( + "Reshape", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)) + ) + head_part_rank = g.op("Slice", input_size, head_start_idx, head_end_idx) + + dim_plus_one = g.op( + "Add", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)) + ) + tail_start_idx = g.op( + "Reshape", + dim_plus_one, + g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)), + ) + tail_end_idx = g.op( + "Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64) + ) + tail_part_rank = g.op("Slice", input_size, tail_start_idx, tail_end_idx) + + final_shape = g.op( + "Concat", head_part_rank, unflattened_size, tail_part_rank, axis_i=0 + ) + + return symbolic_helper._reshape_helper(g, input, final_shape) + + +@_onnx_symbolic("aten::unsafe_chunk") +@symbolic_helper.parse_args("v", "i", "i", "i") +def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None): + if _outputs is None: + return g.op( + "SplitToSequence", + self, + g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)), + axis_i=dim, + keepdims_i=0, + ) + + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + return symbolic_helper._unimplemented("unsafe_chunk", "unknown dimension size") + split_size = (size + chunks - 1) // chunks + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + + # TODO: So far we don"t have a module using this method. We"ll keep + # this as a constant unless we see a request of dynamics in any + # user's modules. + splits = g.op("Constant", value_t=torch.tensor(splits, dtype=torch.long)) + return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::tile") +def tile(g: jit_utils.GraphContext, self, dims): + self_shape = g.op("Shape", self) + self_rank = g.op("Size", self_shape) + dims_rank = g.op("Size", dims) + diff = g.op("Sub", self_rank, dims_rank) + const_zero = g.op("Constant", value_t=torch.tensor([0])) + + # 1. If dims is shorter than self.shape pad dims with 1 + dims_shorter_than_self_shape = g.op("Greater", diff, const_zero) + ( + if_op_greater, + (if_context_greater, else_context_greater), + _, + ) = jit_utils.add_op_with_blocks( + g, "If", dims_shorter_than_self_shape, n_blocks=2, outputs=1 + ) + const_one = if_context_greater.op("Constant", value_t=torch.LongTensor([1])) + diff_1d_greater = if_context_greater.op("Reshape", diff, const_one) + exapnd_ones_greater = if_context_greater.op("Expand", const_one, diff_1d_greater) + dims_ = if_context_greater.op("Concat", exapnd_ones_greater, dims, axis_i=0) + utils._add_output_to_block(if_context_greater.block, dims_) + identity_dim = else_context_greater.op("Identity", dims) + utils._add_output_to_block(else_context_greater.block, identity_dim) + dims_final = if_op_greater.node().output() + + # 2. 
If dims is longer than self.shape pad self.shape with 1 + dims_longer_than_self_shape = g.op("Less", diff, const_zero) + ( + if_op_less, + (if_context_less, else_context_less), + _, + ) = jit_utils.add_op_with_blocks( + g, "If", dims_longer_than_self_shape, n_blocks=2, outputs=1 + ) + const_one = if_context_less.op("Constant", value_t=torch.LongTensor([1])) + diff_1d_less = if_context_less.op( + "Reshape", + if_context_less.op("Abs", diff), + const_one, + ) + exapnd_ones_less = if_context_less.op("Expand", const_one, diff_1d_less) + self_final_shape = if_context_less.op( + "Concat", exapnd_ones_less, self_shape, axis_i=0 + ) + self_ = if_context_less.op("Reshape", self, self_final_shape) + utils._add_output_to_block(if_context_less.block, self_) + identity_self = else_context_less.op("Identity", self) + utils._add_output_to_block(else_context_less.block, identity_self) + self_final = if_op_less.node().output() + + dims_final = g.op("Cast", dims_final, to_i=_C_onnx.TensorProtoDataType.INT64) + return g.op("Tile", self_final, dims_final) + + +@_onnx_symbolic("aten::repeat_interleave") +def repeat_interleave( + g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None +): + repeats_dim = symbolic_helper._get_tensor_rank(repeats) + repeats_sizes = symbolic_helper._get_tensor_sizes(repeats) + input_sizes = symbolic_helper._get_tensor_sizes(self) + if repeats_dim is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown repeats rank.", + self, + ) + if repeats_sizes is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown repeats size.", + self, + ) + if input_sizes is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown input size.", + self, + ) + + final_dim = dim + # if dim is None flatten + # By default, use the flattened input array, and return a flat output array + if symbolic_helper._is_none(dim): + self = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([-1])) + ) + dim = torch.tensor(0, dtype=torch.int64) + else: + dim = symbolic_helper._maybe_get_scalar(dim) + + # Handle cases where dim is negative + if dim < 0: + dim += len(input_sizes) + + output_sizes = input_sizes.copy() + for idx, input_size in enumerate(input_sizes): + if input_size is None: + output_sizes[idx], input_sizes[idx] = 0, -1 + + # Check if all indices should be repeated the same number of times. 
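The branch just below separates two cases: a single repeat value (sent to a dedicated helper) and per-element repeats (handled by the Loop construction further down). As a point of reference, a minimal eager-mode illustration of the two behaviors the exporter has to reproduce, using only standard torch calls on made-up tensors:

import torch

x = torch.tensor([[1, 2], [3, 4]])

# Single repeat value: every slice along dim is repeated the same number of times.
torch.repeat_interleave(x, 2, dim=0)                     # [[1, 2], [1, 2], [3, 4], [3, 4]]

# Per-element repeats: each slice gets its own count; this is the case that
# needs the Loop-based lowering below.
torch.repeat_interleave(x, torch.tensor([1, 3]), dim=0)  # [[1, 2], [3, 4], [3, 4], [3, 4]]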
+ if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1): + return symbolic_helper._repeat_interleave_single_value_repeat_helper( + g, self, repeats, dim + ) + + cond_dynamic_repeats = repeats_dim == 1 and repeats_sizes[0] is None + # If input size is dynamic or repeats vector is dynamic + if output_sizes[dim] == 0 or cond_dynamic_repeats: + reps = symbolic_helper._size_helper(g, self, dim) + reps = opset11.unsqueeze(g, reps, 0) + + # Check if repeats is dynamic + # As repeats is dynamic, we use a where node as a substitute for the if statement + # If repests_dim = 1, expand repeats otherwise use original tensor + if cond_dynamic_repeats: + repeat_dim = symbolic_helper._size_helper( + g, repeats, g.op("Constant", value_t=torch.LongTensor([0])) + ) + repeat_cond = g.op( + "Equal", repeat_dim, g.op("Constant", value_t=torch.LongTensor([1])) + ) + repeats = where(g, repeat_cond, g.op("Expand", repeats, reps), repeats) + # There are cases when the repeats are 1-d tensor with multiple repeats, but dim + # provided along one of the dynamic axes provided. A simple example would be + # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2 + # Now, repeat interleaving can be performed in pytorch when the value of * matches + # with the number of elements in repeat, for example if * -> 2, number of repeats + # should be 2 as well. + else: + return opset9.repeat_interleave(g, self, repeats, final_dim) + + reps_like = g.op( + "ConstantOfShape", + g.op("Shape", repeats), + value_t=torch.tensor([1], dtype=torch.long), + ) + r_splits = split(g, repeats, reps_like, 0) + i_splits = split(g, self, reps_like, dim) + + output_sizes[dim], input_sizes[dim] = -1, 1 + + # Create a loop to iterate over each value along the dimension + # and perform individual interleaving using the repeats tensor + # Loop is of the following pattern + # input (trip_count, cond) + # int trip_count = ...; + # bool cond = ...; + # for (int i=0; i < trip_count && cond; ++i) { + # cond = ...; + # } + + # Loop conditions + loop_condition = g.op("Constant", value_t=torch.tensor(1)) + loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL) + loop_len = reps + + # Create an empty sequence to store final expansions + final_splits = g.op("SequenceEmpty") + + # Loop inputs + loop, (loop_context,), _ = jit_utils.add_op_with_blocks( + g, "Loop", loop_len, loop_condition, final_splits, n_blocks=1 + ) + + loop_block = loop_context.block + block_input_iter = utils._add_input_to_block(loop_block) + cond = utils._add_input_to_block(loop_block) + final_splits = utils._add_input_to_block(loop_block) + + r_split = loop_context.op("SequenceAt", r_splits, block_input_iter) + i_split = loop_context.op("SequenceAt", i_splits, block_input_iter) + + i_split = opset11.unsqueeze(loop_context, i_split, dim + 1) + r_concat = [ + loop_context.op("Constant", value_t=torch.LongTensor(input_sizes[: dim + 1])), + r_split, + loop_context.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1 :])), + ] + r_concat = loop_context.op("Concat", *r_concat, axis_i=0) + i_split = opset9.expand(loop_context, i_split, r_concat, None) + i_split = symbolic_helper._reshape_helper( + loop_context, i_split, g.op("Constant", value_t=torch.LongTensor(output_sizes)) + ) + final_splits = loop_context.op("SequenceInsert", final_splits, i_split) + + # Loop outputs + cond_out = loop_context.op( + "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL + ) + utils._add_output_to_block(loop_block, cond_out) + 
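An ONNX Loop body returns the continuation condition first and the loop-carried values after it, which is why cond_out is registered before the updated sequence here. For orientation only, a dim=0 eager-mode sketch of what the loop computes, with illustrative tensors that are not part of the exporter:

import torch

x = torch.tensor([[1, 2], [3, 4]])
repeats = torch.tensor([2, 1])

# Split into per-index pieces, expand each piece by its repeat count, then concatenate.
pieces = [
    row.expand(rep, row.shape[1])
    for rep, row in zip(repeats.tolist(), torch.split(x, 1, dim=0))
]
out = torch.cat(pieces, dim=0)
assert torch.equal(out, torch.repeat_interleave(x, repeats, dim=0))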
utils._add_output_to_block(loop_block, final_splits) + + loop_out = loop.node().output() + loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim) + return loop_out + + +@_onnx_symbolic("aten::diagonal") +@symbolic_helper.parse_args("v", "i", "i", "i") +def diagonal(g: jit_utils.GraphContext, self, offset, dim1, dim2): + rank = symbolic_helper._get_tensor_rank(self) + # Replace negative indexing when rank is known + if rank is not None: + dim1 = dim1 if dim1 >= 0 else dim1 + rank + dim2 = dim2 if dim2 >= 0 else dim2 + rank + + dim1_size = opset9.size( + g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim1])) + ) + dim2_size = opset9.size( + g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim2])) + ) + # Create appropriate mask + mask_shape = g.op("Concat", dim1_size, dim2_size, axis_i=0) + mask = opset9.zeros(g, mask_shape, None, None, None) + mask = g.op("EyeLike", mask, k_i=offset) + # dim1 and dim2 appended as a dimension at the end of the shape + + if rank is not None: + axes = list(range(rank)) + axes.remove(dim1) + axes.remove(dim2) + self = g.op("Transpose", self, perm_i=axes + [dim1, dim2]) + else: + return symbolic_helper._unimplemented("diagonal", "unknown input rank") + + # Multiply input and mask to calculate values along diagonal + # The mask consists of one values where diagonal values are to be calculated + # For example: + # [[1.1, 1.2, 1.3], * [[1, 0, 0] = [[1.1, 0, 0], + # [2.1, 2.2, 2.3], [0, 1, 0] [0, 2.2, 0], + # [3.1, 3.2, 3.3]] [0, 0, 1]] [0, 0, 3.3]] + result = g.op("Mul", self, mask) + result = symbolic_helper._reducesum_helper(g, result, axes_i=[-1], keepdims_i=0) + + # Calculate gather indices based on offset and dims + # If offset is greater than zero, set offset to zero as this aids in + # calculation of selection window + offset_op = g.op("Constant", value_t=torch.LongTensor([offset])) + if offset >= 0: + diag_size = g.op( + "Max", + g.op("Min", dim1_size, g.op("Sub", dim2_size, offset_op)), + g.op("Constant", value_t=torch.LongTensor([0])), + ) + offset = 0 + else: + diag_size = g.op( + "Max", + g.op("Min", g.op("Add", dim1_size, offset_op), dim2_size), + g.op("Constant", value_t=torch.LongTensor([0])), + ) + diag_size = g.op("Concat", diag_size, axis_i=0) + + # Calculate which diagonal values to select + # For example, in cases with offsets: + # [[0, 1.1, 0] + # [0, 0, 2.2]] + # we need to select the last two columns, so we create a tensor + # with all columns that are to be selected + # So in this example, it is [1, 2] + select_window_ones_fill = opset9.ones(g, diag_size, 4, None, None) + select_window = g.op( + "CumSum", + select_window_ones_fill, + g.op("Constant", value_t=torch.LongTensor([0])), + ) + select_window = g.op( + "Add", + select_window, + g.op("Constant", value_t=torch.LongTensor([abs(offset) - 1])), + ) + + gather_shape = [ + opset9.size(g, result, dim=g.op("Constant", value_t=torch.LongTensor([axis]))) + for axis in list(range(rank))[:-2] + ] + gather_shape.append(diag_size) + gather_shape = g.op("Concat", *gather_shape, axis_i=0) + gather_indices = opset9.zeros(g, gather_shape, 4, None, None) + + # There might be cases where offset value is greater than number of rows/columns + # and might cause the diagonal to overrun and as a result of this, diag_size would be zero. 
+ # For example, if + # offset = 9, dim1_size = 2 (columns), dim2_size = 4 (rows) + # diag_size = max(min(2, (4-9)), 0) = 0, based on calculation above + # Cases with diagonal overrun always result in diag_size = max(0, -ve value) = 0 + # In cases without diagonal overrun, we select the appropriate rows/columns along which we + # are calculating diagonal values. In cases with diagonal overrun, we return a tensor which has + # the dimension of the row/column where overrun occurred as 0-dim, as we are essentially + # returning an empty tensor + overrun_cond = g.op( + "Not", + g.op( + "Equal", + diag_size, + g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)), + ), + ) + + if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( + g, "If", overrun_cond, n_blocks=2 + ) + + gather_indices_if_block = if_context.op("Add", gather_indices, select_window) + gather_indices_if_block = symbolic_helper._unsqueeze_helper( + if_context, gather_indices_if_block, [rank - 1] + ) + final_non_overrun = if_context.op( + "GatherND", result, gather_indices_if_block, batch_dims_i=rank - 2 + ) + final_overrun = opset9.zeros(else_context, gather_shape, 6, None, None) + utils._add_output_to_block(if_context.block, final_non_overrun) + utils._add_output_to_block(else_context.block, final_overrun) + return if_op + + +# Quantized ops + + +@_onnx_symbolic("quantized::linear") +def quantized_linear( + g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.linear(g, input, weight, bias) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::linear_relu") +def quantized_linear_relu( + g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.linear(g, input, weight, bias) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv1d_relu") +def quantized_conv1d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv2d_relu") +def quantized_conv2d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = 
symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv3d_relu") +def quantized_conv3d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv1d") +def quantized_conv1d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv2d") +def quantized_conv2d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv3d") +def quantized_conv3d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose1d") +def quantized_conv_transpose1d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = 
symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv_transpose2d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose2d") +def quantized_conv_transpose2d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv_transpose2d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose3d") +def quantized_conv_transpose3d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv_transpose3d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) diff --git a/lib/python3.10/site-packages/torch/onnx/symbolic_opset14.py b/lib/python3.10/site-packages/torch/onnx/symbolic_opset14.py new file mode 100644 index 0000000000000000000000000000000000000000..ae33ddf58c6e09245d1419b6d3f5bb749cd61198 --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/symbolic_opset14.py @@ -0,0 +1,283 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +"""This file exports ONNX ops for opset 14. + +Note [ONNX operators that are added/updated in opset 14] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +New operators: + HardSwish, Trilu + +Updated operators: + Reshape + Add, Sub, Mul, Div + GRU, LSTM, RNN + BatchNorm, Cumsum, Relu +""" + +# EDITING THIS FILE? READ THIS FIRST! 
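As a quick way to exercise the opset-14 symbolics defined in this file (HardSwish, Trilu, and the scaled_dot_product_attention lowering), an export along the lines of the sketch below should reach them; the module and output filename are only illustrative:

import torch
import torch.nn.functional as F

class TinyAttention(torch.nn.Module):
    def forward(self, q, k, v):
        # is_causal=True exercises the _causal_attention_mask helper defined below
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

q = k = v = torch.randn(1, 4, 16, 8)
torch.onnx.export(TinyAttention(), (q, k, v), "tiny_attention_opset14.onnx", opset_version=14)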
+# see Note [Edit Symbolic Files] in README.md +from __future__ import annotations + +import functools + +import torch +from torch.onnx import _constants, _type_utils, symbolic_helper +from torch.onnx._globals import GLOBALS +from torch.onnx._internal import jit_utils, registration + + +__all__ = [ + "hardswish", + "tril", + "triu", + "reshape", + "batch_norm", + "quantized_hardswish", + "scaled_dot_product_attention", +] + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=14) + + +@_onnx_symbolic("aten::hardswish") +@symbolic_helper.parse_args("v") +def hardswish(g: jit_utils.GraphContext, self): + return g.op("HardSwish", self) + + +@_onnx_symbolic("aten::tril") +def tril(g: jit_utils.GraphContext, self, diagonal, out=None): + return g.op("Trilu", self, diagonal, upper_i=0) + + +@_onnx_symbolic("aten::triu") +def triu(g: jit_utils.GraphContext, self, diagonal, out=None): + return g.op("Trilu", self, diagonal, upper_i=1) + + +@_onnx_symbolic("aten::reshape") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v") +def reshape(g: jit_utils.GraphContext, self, shape): + # NOTE: Due to bug in ORT https://github.com/microsoft/onnxruntime/issues/10664 + # Reshape export cannot utilize the new allowzero attribute introduced in opset 14. + return symbolic_helper._reshape_helper(g, self, shape, allowzero=0) + + +@_onnx_symbolic("aten::batch_norm") +@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i") +def batch_norm( + g: jit_utils.GraphContext, + input, + weight, + bias, + running_mean, + running_var, + training, + momentum, + eps, + cudnn_enabled, +): + if ( + torch.is_autocast_enabled() + and not symbolic_helper.args_have_same_dtype( + [input, weight, bias, running_mean, running_var] + ) + and GLOBALS.export_onnx_opset_version < 15 + ): + return symbolic_helper._onnx_opset_unsupported_detailed( + "BatchNormalization", + 14, + 15, + "All input tensors must have the same `dtype`." 
+ " Turn off Autocast or export using opset version 15.", + input, + ) + + symbolic_helper.check_training_mode(training, "batch_norm") + weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper( + g, input, weight, bias, running_mean, running_var + ) + out = g.op( + "BatchNormalization", + input, + weight, + bias, + running_mean, + running_var, + epsilon_f=eps, + momentum_f=1 - momentum, + training_mode_i=0 if not training else 1, + outputs=1 if not training else 3, + ) + if not training: + return out + else: + res, new_running_mean, new_running_var = out + new_running_mean.setType(running_mean.type()) + new_running_var.setType(running_var.type()) + return res + + +@_onnx_symbolic("quantized::hardswish") +def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = hardswish(g, x) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +# Ported from +# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/nn.py#L1504 +# aten_scaled_dot_product_attention +# NOTE: Need op.Trilu +@_onnx_symbolic("aten::scaled_dot_product_attention") +@symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "b") +def scaled_dot_product_attention( + g: jit_utils.GraphContext, + query: torch._C.Value, + key: torch._C.Value, + value: torch._C.Value, + attn_mask: torch._C.Value | None = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: torch._C.Value | None = None, + enable_gqa: bool = False, +): + assert (not is_causal) or ( + is_causal and symbolic_helper._is_none(attn_mask) + ), "is_causal and attn_mask cannot be set at the same time" + assert not enable_gqa, "conversion of scaled_dot_product_attention not implemented if enable_gqa is True" + + if symbolic_helper._is_none(scale): + scale = _attention_scale(g, query) + + if is_causal: + attn_mask = _causal_attention_mask(g, query, key) + + # Swap the last two axes of key + # NOTE: onnx-script has different logic here, because the attribute perms in + # transpose needs list of ints + key_shape_builtin = symbolic_helper._get_tensor_rank(key) + key_transposed_axes = list(range(key_shape_builtin)) + key_transposed_axes[-1], key_transposed_axes[-2] = ( + key_transposed_axes[-2], + key_transposed_axes[-1], + ) + key_transposed = g.op("Transpose", key, perm_i=key_transposed_axes) + + # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653 + # Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math + query_scaled = g.op("Mul", query, g.op("Sqrt", scale)) + key_transposed_scaled = g.op("Mul", key_transposed, g.op("Sqrt", scale)) + mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled) + + if symbolic_helper._is_none(attn_mask): + mul_qk_add = mul_qk + elif ( + _type_utils.JitScalarType.from_value(attn_mask) + == _type_utils.JitScalarType.BOOL + ): + # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf')) + const_zero = g.op("Constant", value_t=torch.tensor([0.0])) + const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")])) + attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf) + mul_qk_add = g.op("Add", mul_qk, attn_mask) + elif _type_utils.JitScalarType.from_value(attn_mask) in ( + _type_utils.JitScalarType.FLOAT, + _type_utils.JitScalarType.HALF, + 
_type_utils.JitScalarType.BFLOAT16, + ): + mul_qk_add = g.op("Add", mul_qk, attn_mask) + else: + raise ValueError( + f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}" + ) + + attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) + + if dropout_p != 0: + attn_weight = g.op( + "Dropout", + attn_weight, + g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)), + ) + + return g.op("MatMul", attn_weight, value) + + +def _attention_scale( + g: jit_utils.GraphContext, query: torch._C.Value +) -> torch._C.Value: + """Calculate the scale factor for the attention result. + + Args: + query: Tensor of shape [..., L, E] + + Returns: + Scalar scale factor := 1 / math.sqrt(query.size(-1)) + """ + query_shape = g.op("Shape", query) + query_shape_last = g.op( + "Slice", + query_shape, + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)), + g.op( + "Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64) + ), + ) + embedding_size = g.op( + "Cast", + query_shape_last, + to_i=_type_utils.JitScalarType.from_value(query).onnx_type(), + ) + const_one = g.op("Constant", value_t=torch.tensor([1.0], dtype=torch.float)) + scale = g.op("Div", const_one, g.op("Sqrt", embedding_size)) + # Add a Cast to convert the scale back to original type + scale = g.op( + "Cast", + scale, + to_i=_type_utils.JitScalarType.from_value(query).onnx_type(), + ) + return scale + + +def _causal_attention_mask( + g: jit_utils.GraphContext, query: torch._C.Value, key: torch._C.Value +) -> torch._C.Value: + """Create a causal mask for the given query and key tensors. + + Equivalent to:: + mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_mask = torch.zeros(L, S, dtype=torch.float) + attn_mask = attn_mask.masked_fill(not mask, -float("inf")) + + Args: + query: Tensor of shape [..., L, E] + key: Tensor of shape [..., S, E] + + Returns: + Tensor of shape [L, S] + """ + + query_shape = g.op("Shape", query) + key_shape = g.op("Shape", key) + + last_idx = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) + second_last_idx = g.op("Constant", value_t=torch.tensor([-2], dtype=torch.int64)) + target_length = g.op("Slice", query_shape, second_last_idx, last_idx) + source_length = g.op("Slice", key_shape, second_last_idx, last_idx) + # attn_mask = torch.ones(L, S) := { + size = g.op("Concat", target_length, source_length, axis_i=0) + const_one = g.op("Constant", value_t=torch.tensor([1.0])) + attn_mask = g.op("Expand", const_one, size) + # } + attn_mask = g.op("Trilu", attn_mask, upper_i=0) + # The causal mask has 0s in the lower triangle and -inf in the upper triangle. + const_zero = g.op("Constant", value_t=torch.tensor([0.0])) + const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")])) + attn_mask = g.op( + "Where", g.op("Equal", attn_mask, const_zero), const_neg_inf, const_zero + ) + return attn_mask diff --git a/lib/python3.10/site-packages/torch/onnx/symbolic_opset15.py b/lib/python3.10/site-packages/torch/onnx/symbolic_opset15.py new file mode 100644 index 0000000000000000000000000000000000000000..08f8dcbf5a2266774f878bf2e91693e280da4228 --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/symbolic_opset15.py @@ -0,0 +1,80 @@ +# mypy: allow-untyped-defs +"""This file exports ONNX ops for opset 15. 
+ +Note [ONNX operators that are added/updated in opset 15] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/master/docs/Changelog.md#version-15-of-the-default-onnx-operator-set +New operators: + Bernoulli + CastLike + Optional + OptionalGetElement + OptionalHasElement + +Updated operators: + BatchNormalization https://github.com/onnx/onnx/pull/3545 + Backwards compatible + TODO: test coverage for mixed types inputs. + Pow https://github.com/onnx/onnx/pull/3412 + Backwards compatible + TODO: bfloat16 support. + Shape https://github.com/onnx/onnx/pull/3580 + Backwards compatible + TODO: optional start/end attribute. +""" + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +import functools + +import torch +from torch import _C +from torch.onnx import symbolic_helper, symbolic_opset9 as opset9 +from torch.onnx._internal import jit_utils, registration + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=15) + + +@_onnx_symbolic("aten::__is_") +def aten__is_(g: jit_utils.GraphContext, self, other): + if symbolic_helper._is_none(other): + if isinstance(self.type(), _C.OptionalType): + none = g.op("OptionalHasElement", self) + return g.op("Not", none) + else: + return g.op("Constant", value_t=torch.BoolTensor([0])) + return opset9.eq(g, self, other) + + +@_onnx_symbolic("aten::__isnot_") +@opset9.wrap_logical_op_with_negation # type: ignore[has-type] +def aten__isnot_(g: jit_utils.GraphContext, self, other): + return aten__is_(g, self, other) + + +@_onnx_symbolic("aten::bernoulli") +def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None): + if out is not None and not symbolic_helper._is_none(out): + symbolic_helper._unimplemented( + "Bernoulli", "out parameter is not supported for bernoulli", input + ) + if generator is not None and not symbolic_helper._is_none(generator): + symbolic_helper._unimplemented( + "Bernoulli", "generator is not supported for bernoulli", input + ) + if p is None or symbolic_helper._is_none(p): + return g.op("Bernoulli", input) + return opset9.bernoulli(g, input, p, generator, out) + + +@_onnx_symbolic("prim::unchecked_cast") +def prim_unchecked_cast(g: jit_utils.GraphContext, self): + # exists to refine the type of the Value + # if x is Optional[Tensor], unchecked_cast will cast + # x to Tensor, so the rest of the graph knows that x is a Tensor. + if isinstance(self.type(), _C.OptionalType): + return g.op("OptionalGetElement", self) + + return self diff --git a/lib/python3.10/site-packages/torch/onnx/symbolic_opset16.py b/lib/python3.10/site-packages/torch/onnx/symbolic_opset16.py new file mode 100644 index 0000000000000000000000000000000000000000..d4a7baa78c2d574be7bd12869d9dc6b6a3ca1e87 --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/symbolic_opset16.py @@ -0,0 +1,185 @@ +# mypy: allow-untyped-defs +"""This file exports ONNX ops for opset 16. + +Note [ONNX Operators that are added/updated in opset 16] + +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set +New operators: + GridSample https://github.com/onnx/onnx/pull/3557 + +Updated operators: + Identity + If + LeakyRelu + Loop + PRelu + RoiAlign + Scan + ScatterElements + ScatterND + Where + GreaterOrEqual + LessOrEqual +""" + +# EDITING THIS FILE? READ THIS FIRST! 
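The opset-16 entries below mostly concern GridSample and ScatterElements with a reduction attribute. A minimal sketch of an export that should reach the grid_sampler symbolic; shapes and the output filename are made up:

import torch
import torch.nn.functional as F

class Sampler(torch.nn.Module):
    def forward(self, x, grid):
        return F.grid_sample(x, grid, mode="bilinear", padding_mode="zeros", align_corners=False)

x = torch.randn(1, 3, 8, 8)
grid = torch.rand(1, 4, 4, 2) * 2 - 1   # normalized sampling locations in [-1, 1]
torch.onnx.export(Sampler(), (x, grid), "grid_sample_opset16.onnx", opset_version=16)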
+# see Note [Edit Symbolic Files] in README.md + +import functools + +import torch +from torch.nn.functional import ( + GRID_SAMPLE_INTERPOLATION_MODES, + GRID_SAMPLE_PADDING_MODES, +) +from torch.onnx import _type_utils, errors, symbolic_helper, utils +from torch.onnx._internal import jit_utils, registration + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16) + + +# note (mkozuki): Why `grid_sampler` instead of `grid_sample`? +# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`. +@_onnx_symbolic("aten::grid_sampler") +@symbolic_helper.parse_args("v", "v", "i", "i", "b") +def grid_sampler( + g: jit_utils.GraphContext, + input, + grid, + mode_enum, + padding_mode_enum, + align_corners, +): + # Check the input and grid tensor rank beforehand. + if symbolic_helper._get_tensor_rank(input) == 5: + return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input") + mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg] + padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg] + padding_mode_enum + ] + return g.op( + "GridSample", + input, + grid, + align_corners_i=int(align_corners), + mode_s=mode_s, + padding_mode_s=padding_mode_s, + ) + + +@_onnx_symbolic("aten::scatter_add") +@symbolic_helper.parse_args("v", "i", "v", "v") +def scatter_add(g: jit_utils.GraphContext, self, dim, index, src): + src_type = _type_utils.JitScalarType.from_value( + src, _type_utils.JitScalarType.UNDEFINED + ) + src_sizes = symbolic_helper._get_tensor_sizes(src) + index_sizes = symbolic_helper._get_tensor_sizes(index) + + if len(src_sizes) != len(index_sizes): + return symbolic_helper._unimplemented( + "scatter_add", + f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})", + ) + + # PyTorch only allows index shape <= src shape, so we can only consider + # taking index as subset size to src, like PyTorch does. When sizes for src + # and index are not matched or there are dynamic axes, we take index shape to + # slice src to accommodate. + if src_sizes != index_sizes or None in index_sizes: + adjusted_shape = g.op("Shape", index) + starts = g.op("Constant", value_t=torch.tensor([0] * len(index_sizes))) + src = g.op("Slice", src, starts, adjusted_shape) + + src = symbolic_helper._maybe_get_scalar(src) + if symbolic_helper._is_value(src): + return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="add") + else: + # Check if scalar "src" has same type as self (PyTorch allows different + # type for scalar src (but not when src is tensor)). If not, insert Cast node. 
+ if _type_utils.JitScalarType.from_value(self) != src_type: + src = g.op( + "Cast", + src, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + + return g.op( + "ScatterElements", + self, + index, + src, + axis_i=dim, + reduction_s="add", + ) + + +@_onnx_symbolic("aten::scatter_reduce") +@symbolic_helper.parse_args("v", "i", "v", "v", "s", "b") +def scatter_reduce( + g: jit_utils.GraphContext, + self: torch._C.Value, + dim: int, + index: torch._C.Value, + src: torch._C.Value, + reduce: str, + include_self: bool, +): + if reduce == "mean": + raise errors.OnnxExporterError( + "ONNX does not support mean reduction for scatter_reduce" + ) + if not include_self: + raise errors.OnnxExporterError( + "ONNX does not support include_self=False for scatter_reduce" + ) + + reduce_mode = { # convert torch string name to onnx string name + "mean": "none", # 'mean' doesn't support in ONNX 1.14 definition + "sum": "add", + "prod": "mul", + "amin": "min", + "amax": "max", + } + onnx_reduce = reduce_mode[reduce] + + self_rank = g.op("Size", g.op("Shape", self)) + + # if self_rank == 0: # assert (index_rank == 0 and rank_src == 0) + self_rank_is_zero = g.op( + "Equal", self_rank, g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) + ) + if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( + g, "If", self_rank_is_zero, n_blocks=2, outputs=3 + ) + neg_1 = if_context.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) + + self_reshape = if_context.op("Reshape", self, neg_1) + utils._add_output_to_block(if_context.block, self_reshape) + index_reshape = if_context.op("Reshape", index, neg_1) + utils._add_output_to_block(if_context.block, index_reshape) + src_reshape = if_context.op("Reshape", src, neg_1) + utils._add_output_to_block(if_context.block, src_reshape) + + self_identity = else_context.op("Identity", self) + utils._add_output_to_block(else_context.block, self_identity) + index_identitye = else_context.op("Identity", index) + utils._add_output_to_block(else_context.block, index_identitye) + src_identity = else_context.op("Identity", src) + utils._add_output_to_block(else_context.block, src_identity) + + result = g.op("ScatterElements", *if_op, axis_i=dim, reduction_s=onnx_reduce) + + # if self_rank == 0: + if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( + g, "If", self_rank_is_zero, n_blocks=2, outputs=1 + ) + result_squeezed = if_context.op("Squeeze", result) + utils._add_output_to_block(if_context.block, result_squeezed) + result_identity = else_context.op("Identity", result) + utils._add_output_to_block(else_context.block, result_identity) + result_final = if_op.node().output() + + return result_final diff --git a/lib/python3.10/site-packages/torch/onnx/symbolic_opset17.py b/lib/python3.10/site-packages/torch/onnx/symbolic_opset17.py new file mode 100644 index 0000000000000000000000000000000000000000..0aca6634d2f69be0033fe71af678158a8c7b6130 --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/symbolic_opset17.py @@ -0,0 +1,231 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +"""This file exports ONNX ops for opset 17. 
+ +Note [ONNX Operators that are added/updated in opset 17] + +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-17-of-the-default-onnx-operator-set +New operators: + BlackmanWindow + DFT + HammingWindow + HannWindow + LayerNormalization + MelWeightMatrix + STFT + SequenceMap +""" + +import functools +from typing import Optional, Sequence + +import torch +from torch import _C +from torch.onnx import _type_utils, errors, symbolic_helper +from torch.onnx._internal import jit_utils, registration + + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +__all__ = ["layer_norm", "stft", "quantized_layer_norm"] + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=17) + + +@_onnx_symbolic("aten::layer_norm") +@symbolic_helper.parse_args("v", "is", "v", "v", "f", "none") +def layer_norm( + g: jit_utils.GraphContext, + input: _C.Value, + normalized_shape: Sequence[int], + weight: _C.Value, + bias: _C.Value, + eps: float, + cudnn_enable: bool, +): + # normalized_shape: input shape from an expected input of size + # axis: The first normalization dimension. + # layer_norm normalizes on the last D dimensions, + # where D is the size of normalized_shape + axis = -len(normalized_shape) + scalar_type = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.FLOAT + ) + dtype = scalar_type.dtype() + if symbolic_helper._is_none(weight): + weight_value = torch.ones(normalized_shape, dtype=dtype) + weight = g.op("Constant", value_t=weight_value) + if symbolic_helper._is_none(bias): + bias_value = torch.zeros(normalized_shape, dtype=dtype) + bias = g.op("Constant", value_t=bias_value) + return g.op( + "LayerNormalization", + input, + weight, + bias, + epsilon_f=eps, + axis_i=axis, + ) + + +@_onnx_symbolic("quantized::layer_norm") +def quantized_layer_norm( + g: jit_utils.GraphContext, + x, + normalized_shape, + weight, + bias, + eps, + op_scale, + op_zero_point, +): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = layer_norm(g, x, normalized_shape, weight, bias, eps, False) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +def _compute_edge_sizes(n_fft, window_size): + """Helper function to compute the sizes of the edges (left and right) + of a given window centered within an FFT size.""" + left = (n_fft - window_size) // 2 + right = n_fft - left - window_size + return left, right + + +@_onnx_symbolic("aten::stft") +@symbolic_helper.parse_args("v", "i", "i", "i", "v", "b", "b", "b") +def stft( + g: jit_utils.GraphContext, + input: _C.Value, + n_fft: int, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: Optional[_C.Value] = None, + normalized: bool = False, + onesided: Optional[bool] = True, + return_complex: Optional[bool] = False, +) -> _C.Value: + """Associates `torch.stft` with the `STFT` ONNX operator. + Note that torch.stft calls _VF.stft, without centering or padding options. + Hence, this function does not contain these two arguments. + See torch.stft source code for more info. + + Args: + g: Graph to write the ONNX representation into + input: Input tensor for the transformation + n_fft: FFT size + hop_length: Size of the hop. Defaults to `floot(n_fft // 4)` + win_length: Size of the analysis window. Defaults to `n_fft` + window: Analysis window. 
Defaults to a window of all ones + normalized: Whether to return a normalized STFT + onesided: Whether to return only half (+1) of the results, given the + symmetry of the STFT + return_complex: Whether to return the complex value (Note: Must be + `False` or `None`) + + Returns: + op: Operator for torch.stft associated with STFT (ONNX) + """ + # Checks + if return_complex: + raise errors.SymbolicValueError( + msg="STFT does not currently support complex types", value=input + ) + + # Get STFT sizes + frame_step_value = hop_length if hop_length is not None else n_fft // 4 + frame_step_const = g.op( + "Constant", value_t=torch.tensor(frame_step_value, dtype=torch.int64) + ) + frame_length_const = g.op( + "Constant", value_t=torch.tensor(n_fft, dtype=torch.int64) + ) + + # Pre-process input if needed + signal = input + signal_rank = symbolic_helper._get_tensor_rank(signal) + if signal_rank == 1: + # Add batch dimension + signal = g.op( + "Unsqueeze", + signal, + g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), + ) + elif signal_rank is None or signal_rank > 2: + raise errors.SymbolicValueError( + msg="STFT can only take inputs of 1 [signal] or 2 [batch, signal] dimensions. " + f"Current rank of signal is {signal_rank}, please reduce it.", + value=input, + ) + + # Get window and make sure it's the same size as `win_length` or `n_fft` + n_win = symbolic_helper._get_tensor_dim_size(window, dim=0) + if n_win is not None: + win_length_default = win_length if win_length else n_fft + assert n_win == win_length_default, ( + "Analysis window size must equal `win_length` or `n_fft`. " + f"Please, set `win_length` or `n_fft` to match `window` size ({n_win})", + ) + + # Center window around zeros if needed (required by ONNX's STFT) + if n_win < n_fft: + left, right = _compute_edge_sizes(n_fft, n_win) + left_win = g.op("Constant", value_t=torch.zeros(left)) + right_win = g.op("Constant", value_t=torch.zeros(right)) + window = g.op("Concat", left_win, window, right_win, axis_i=0) + + # Create window, if needed + if symbolic_helper._is_none(window): + if win_length: + if win_length > n_fft: + raise errors.SymbolicValueError( + msg="The analysis window can't be longer than the size of the FFT. 
" + f"Please set `win_length` ({win_length}) to `n_fft` ({n_fft}) or less.", + value=input, + ) + + # Center window, if needed + left, right = _compute_edge_sizes(n_fft, win_length) + torch_window = torch.hstack( + (torch.zeros(left), torch.ones(win_length), torch.zeros(right)) + ) + else: + # Rectangle window + torch_window = torch.ones(n_fft) + assert torch_window.shape[0] == n_fft + window = g.op("Constant", value_t=torch_window) + window = g.op( + "Cast", window, to_i=_type_utils.JitScalarType.from_value(signal).onnx_type() + ) + + # Run STFT + result = g.op( + "STFT", + signal, + frame_step_const, + window, + frame_length_const, + onesided_i=1 if onesided is None or onesided else 0, + ) + + # Transpose to mimic torch.stft's behavior + result = g.op("Transpose", result, perm_i=[0, 2, 1, 3]) + + # Remove batch dimension, if needed + if signal_rank == 1: + result = g.op( + "Squeeze", + result, + g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), + ) + + # Normalize, if needed + if normalized: + sqrt_nfft = torch.sqrt(torch.tensor(n_fft, dtype=signal.type().dtype())) + result = g.op("Div", result, g.op("Constant", value_t=sqrt_nfft)) + + return result diff --git a/lib/python3.10/site-packages/torch/onnx/symbolic_opset18.py b/lib/python3.10/site-packages/torch/onnx/symbolic_opset18.py new file mode 100644 index 0000000000000000000000000000000000000000..d28fadc1bab1de5b21366c4bd5aa855cb33938b8 --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/symbolic_opset18.py @@ -0,0 +1,265 @@ +# mypy: allow-untyped-defs +"""This file exports ONNX ops for opset 18. + +Note [ONNX Operators that are added/updated in opset 18] + +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set +New operators: + BitwiseAnd + CenterCropPad + Col2Im + Mish + OptionalGetElement + OptionalHasElement + Pad + Resize + ScatterElements + ScatterND + Split +""" + +import functools +from typing import List, Optional, Sequence, Tuple + +import torch +from torch import _C +from torch.onnx import _type_utils, symbolic_helper, symbolic_opset9 as opset9 +from torch.onnx._internal import jit_utils, registration + + +# EDITING THIS FILE? READ THIS FIRST! 
+# see Note [Edit Symbolic Files] in symbolic_helper.py + +__all__ = [ + "col2im", +] + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18) + + +@_onnx_symbolic("aten::__and_") +@_onnx_symbolic("aten::bitwise_and") +def __and_(g: jit_utils.GraphContext, self, other): + # do type promotion (scalars don't seem to apply) + args = [self, other] + # type promotion doesn't happen with torch.bitwise_and(tensor, scalar) + prom_args = [arg for arg in args if symbolic_helper._get_tensor_rank(arg)] + if len(prom_args) == 0: + prom_args = args + promotion_jit_type = symbolic_helper._type_promote_from_values(*prom_args) + self = symbolic_helper._maybe_cast_to_type(g, self, promotion_jit_type) + other = symbolic_helper._maybe_cast_to_type(g, other, promotion_jit_type) + if promotion_jit_type == _type_utils.JitScalarType.BOOL: + return g.op("And", self, other) + return g.op("BitwiseAnd", self, other) + + +@_onnx_symbolic("aten::col2im") +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is") +def col2im( + g, + input: _C.Value, + output_size: _C.Value, + kernel_size: _C.Value, + dilation: Sequence[int], + padding: Sequence[int], + stride: Sequence[int], +): + # convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in] + adjusted_padding = [] + for pad in padding: + for _ in range(2): + adjusted_padding.append(pad) + + num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0] + if not adjusted_padding: + adjusted_padding = [0, 0] * num_dimensional_axis + + if not dilation: + dilation = [1] * num_dimensional_axis + + if not stride: + stride = [1] * num_dimensional_axis + + return g.op( + "Col2Im", + input, + output_size, + kernel_size, + dilations_i=dilation, + pads_i=adjusted_padding, + strides_i=stride, + ) + + +@_onnx_symbolic( + "aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")] +) +@_onnx_symbolic( + "aten::prod", + decorate=[ + symbolic_helper._apply_params( + "ReduceProd", "prod", allow_multi_dim_support=False + ) + ], +) +def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True): + return symbolic_helper._reduce_with_dtype_helper( + onnx_op, name, allow_multi_dim_support + ) + + +@_onnx_symbolic("aten::native_layer_norm") +@symbolic_helper.quantized_args(True, False, False, False) +@symbolic_helper.parse_args("v", "is", "v", "v", "f") +def _native_layer_norm( + g: jit_utils.GraphContext, + input: _C.Value, + normalized_shape: Sequence[int], + weight: _C.Value, + bias: _C.Value, + eps: float, +) -> Tuple[_C.Value, _C.Value, _C.Value]: + return opset9.native_layer_norm(g, input, normalized_shape, weight, bias, eps) + + +@_onnx_symbolic("aten::glu") +@symbolic_helper.parse_args("v", "i") +def _glu(g: jit_utils.GraphContext, input, dim): + dim_size = symbolic_helper._get_tensor_dim_size(input, dim) + if dim_size is not None: + assert dim_size % 2 == 0 + + first, second = g.op("Split", input, axis_i=dim, num_outputs_i=2, outputs=2) + return g.op("Mul", first, g.op("Sigmoid", second)) + + +@_onnx_symbolic("aten::max") +# torch.max (same for torch.min) actually has two interfaces smashed together: +# torch.max(x, dim, keepdim) and torch.max(x, y) +# TODO(justinchuby): Support multiple quantized args in output +def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + return symbolic_helper._max_helper(g, self, dim_or_y, keepdim) + + +@_onnx_symbolic("aten::maximum") +@symbolic_helper.quantized_args(True, True) +def maximum(g: jit_utils.GraphContext, input, other): + return max(g, input, 
dim_or_y=other) + + +@_onnx_symbolic("aten::min") +# TODO(justinchuby): Support multiple quantized args in output +def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + return symbolic_helper._min_helper(g, self, dim_or_y, keepdim) + + +@_onnx_symbolic("aten::minimum") +@symbolic_helper.quantized_args(True, True) +def minimum(g: jit_utils.GraphContext, input, other): + return min(g, input, dim_or_y=other) + + +@_onnx_symbolic("aten::amax") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "is", "i") +def amax(g: jit_utils.GraphContext, self, dim, keepdim): + axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + return g.op("ReduceMax", self, axes, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::amin") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "is", "i") +def amin(g: jit_utils.GraphContext, self, dim, keepdim): + axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + return g.op("ReduceMin", self, axes, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::aminmax") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v", "i") +def aminmax(g: jit_utils.GraphContext, self, dim, keepdim): + if not symbolic_helper._is_none(dim): + dim = symbolic_helper._get_const(dim, "i", "dim") + axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + return g.op("ReduceMin", self, axes, keepdims_i=keepdim), g.op( + "ReduceMax", self, axes, keepdims_i=keepdim + ) + else: + return g.op("ReduceMin", self, keepdims_i=keepdim), g.op( + "ReduceMax", self, keepdims_i=keepdim + ) + + +@_onnx_symbolic("aten::var_mean") +def _var_mean(g: jit_utils.GraphContext, input, *args): + if len(args) == 1: + return symbolic_helper._var_mean_helper(g, input, None, args[0], None) + else: + return symbolic_helper._var_mean_helper(g, input, *args) + + +@_onnx_symbolic("aten::logsumexp") +@symbolic_helper.parse_args("v", "is", "i") +def _logsumexp(g: jit_utils.GraphContext, input, dim, keepdim): + if dim is None: + return g.op("ReduceLogSumExp", input, keepdims_i=0) + else: + axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + return g.op("ReduceLogSumExp", input, axes, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::linalg_matrix_norm") +@symbolic_helper.parse_args("v", "v", "is", "b", "v") +def _linalg_matrix_norm( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: torch._C.Value, + dim: List[int], + keepdim: bool, + dtype: torch._C.Value, +): + return opset9.linalg_matrix_norm(g, self, ord, dim, keepdim, dtype) + + +@_onnx_symbolic("aten::embedding_bag") +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") +def embedding_bag( + g: jit_utils.GraphContext, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, +): + return symbolic_helper._embedding_bag_helper( + g, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, + ) + + +@_onnx_symbolic("aten::linalg_vector_norm") +@symbolic_helper.parse_args("v", "f", "is", "b", "v") +def linalg_vector_norm( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: float, + dim: Optional[Sequence[int]], + keepdim: bool, + dtype: torch._C.Value, +): + return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype) diff --git a/lib/python3.10/site-packages/torch/onnx/symbolic_opset19.py 
b/lib/python3.10/site-packages/torch/onnx/symbolic_opset19.py new file mode 100644 index 0000000000000000000000000000000000000000..a97dda26f81f5562fc26a1d4f8b54e6460695113 --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/symbolic_opset19.py @@ -0,0 +1,33 @@ +"""This file exports ONNX ops for opset 19. + +Note [ONNX Operators that are added/updated in opset 19] + +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-19-of-the-default-onnx-operator-set +New operators: +AveragePool +Cast +CastLike +Constant +DeformConv +DequantizeLinear +Equal +Identity +If +Loop +Pad +QuantizeLinear +Reshape +Resize +Scan +Shape +Size +""" + +from typing import List + + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in symbolic_helper.py + +__all__: List[str] = [] diff --git a/lib/python3.10/site-packages/torch/onnx/symbolic_opset20.py b/lib/python3.10/site-packages/torch/onnx/symbolic_opset20.py new file mode 100644 index 0000000000000000000000000000000000000000..d96f770ca11e26db582768332b0860c7c94558ae --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/symbolic_opset20.py @@ -0,0 +1,92 @@ +# mypy: allow-untyped-defs +"""This file exports ONNX ops for opset 20. + +Note [ONNX Operators that are added/updated in opset 20] + +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-20-of-the-default-onnx-operator-set +New operators: + AffineGrid + ConstantOfShape + DFT + Gelu + GridSample + ImageDecoder + IsInf + IsNaN + ReduceMax + ReduceMin + RegexFullMatch + StringConcat + StringSplit +""" + +import functools + +import torch.nn.functional as F +from torch import _C +from torch.onnx import symbolic_helper +from torch.onnx._internal import jit_utils, registration + + +# EDITING THIS FILE? READ THIS FIRST! 
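Gelu gains a native ONNX operator at opset 20, including the tanh approximation, so an export like the sketch below should map onto the gelu symbolic defined in this file rather than a decomposition; the module and filename are illustrative:

import torch
import torch.nn.functional as F

class Gelu(torch.nn.Module):
    def forward(self, x):
        return F.gelu(x, approximate="tanh")

torch.onnx.export(Gelu(), (torch.randn(2, 8),), "gelu_opset20.onnx", opset_version=20)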
+# see Note [Edit Symbolic Files] in symbolic_helper.py + +__all__ = ["_grid_sampler", "_affine_grid_generator", "gelu"] + + +def convert_grid_sample_mode(mode_s): + return ( + "linear" if mode_s == "bilinear" else "cubic" if mode_s == "bicubic" else mode_s + ) + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=20) + + +@_onnx_symbolic("aten::grid_sampler") +@symbolic_helper.parse_args("v", "v", "i", "i", "b") +def _grid_sampler( + g: jit_utils.GraphContext, + input: _C.Value, + grid: _C.Value, + mode_enum: int, + padding_mode_enum: int, + align_corners: bool, +): + mode_s = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg, index] + # mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html + mode_s = convert_grid_sample_mode(mode_s) + padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg, index] + padding_mode_enum # type: ignore[index] + ] + return g.op( + "GridSample", + input, + grid, + align_corners_i=int(align_corners), + mode_s=mode_s, + padding_mode_s=padding_mode_s, + ) + + +@_onnx_symbolic("aten::affine_grid_generator") +@symbolic_helper.parse_args("v", "v", "b") +def _affine_grid_generator( + g: jit_utils.GraphContext, + theta: _C.Value, + size: _C.Value, + align_corners: bool, +): + return g.op( + "AffineGrid", + theta, + size, + align_corners_i=int(align_corners), + ) + + +@_onnx_symbolic("aten::gelu") +@symbolic_helper.parse_args("v", "s") +def gelu(g: jit_utils.GraphContext, self: _C.Value, approximate: str = "none"): + return g.op("Gelu", self, approximate_s=approximate) diff --git a/lib/python3.10/site-packages/torch/onnx/symbolic_opset7.py b/lib/python3.10/site-packages/torch/onnx/symbolic_opset7.py new file mode 100644 index 0000000000000000000000000000000000000000..c647ead4e2975e0497a7de9b8b5801617443ab27 --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/symbolic_opset7.py @@ -0,0 +1,67 @@ +# mypy: allow-untyped-defs +""" +Note [ONNX operators that are added/updated from opset 7 to opset 8] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +New operators: + Expand + +Updated operators: + Min, Max, Sum, Mean: supports multidirectional broadcasting. + MaxPool: added optional indices output. + Scan +""" + +import functools +import warnings + +from torch.onnx import symbolic_helper, symbolic_opset9 as opset9 +from torch.onnx._internal import jit_utils, registration + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=7) + +block_listed_operators = ( + "scan", + "expand", + "expand_as", + "meshgrid", + "adaptive_max_pool1d", + "adaptive_max_pool2d", + "adaptive_max_pool3d", + "max_pool1d_with_indices", + "max_pool2d_with_indices", + "max_pool3d_with_indices", +) + + +# NOTE: max, min, sum, mean: broadcasting is not supported in opset 7. +# torch.max (same for torch.min) actually has two interfaces smashed together: +# torch.max(x, dim, keepdim) and torch.max(x, y) +@_onnx_symbolic("aten::max") +def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + # torch.max(input, other) + if keepdim is None and dim_or_y is not None: + warnings.warn( + "Multidirectional broadcasting is not supported in opset 7. 
" + "This might cause the onnx model to be incorrect, if inputs to max operators " + "have different shapes" + ) + return opset9.max(g, self, dim_or_y, keepdim) + + +@_onnx_symbolic("aten::min") +def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + # torch.min(input, other) + if keepdim is None and dim_or_y is not None: + warnings.warn( + "Multidirectional broadcasting is not supported in opset 7. " + "This might cause the onnx model to be incorrect, if inputs to min operators " + "have different shapes" + ) + return opset9.min(g, self, dim_or_y, keepdim) + + +for block_listed_op in block_listed_operators: + _onnx_symbolic(f"aten::{block_listed_op}")( + symbolic_helper._block_list_in_opset(block_listed_op) + ) diff --git a/lib/python3.10/site-packages/torch/onnx/symbolic_opset8.py b/lib/python3.10/site-packages/torch/onnx/symbolic_opset8.py new file mode 100644 index 0000000000000000000000000000000000000000..41abf46be2a0af9facde773d026efacb32ac51d7 --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/symbolic_opset8.py @@ -0,0 +1,463 @@ +# mypy: allow-untyped-defs +""" +Note [ONNX operators that are added/updated from opset 8 to opset 9] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +New operators: + Compress + ConstantOfShape + EyeLike + MaxUnpool + OneHot + Sinh + Cosh + Asinh + Acosh + Atanh + Shrink + IsNaN + Sign + Erf + Scatter + Where + NonZero + TfIdfVectorizer + MeanVarianceNormalization + +Updated operators: + BatchNormalization: removed spatial attribute. + Greater, Less, Constant, MatMul, PRelu, Gemm, Flatten: more data types{integers} supported. + Cast: more data types{string} supported. + Upsample: moved scales from attribute to input. + Scan +""" + +import functools +import warnings + +import torch +from torch._C import _onnx as _C_onnx +from torch.onnx import _type_utils, errors, symbolic_helper, symbolic_opset9 as opset9 +from torch.onnx._internal import jit_utils, registration + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=8) + +block_listed_operators = ( + "nonzero", + "where", + "scatter", + "scatter_add", + "erf", + "sign", + "isnan", + "gather", + "arange", + "masked_fill", + "index_fill", + "index_copy", + "repeat_interleave", + "any", + "all", +) + +for block_listed_op in block_listed_operators: + _onnx_symbolic(f"aten::{block_listed_op}")( + symbolic_helper._block_list_in_opset(block_listed_op) + ) + + +@_onnx_symbolic( + "aten::upsample_nearest1d", + decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest2d", + decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest3d", + decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_linear1d", + decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")], +) +@_onnx_symbolic( + "aten::upsample_bilinear2d", + decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")], +) +@_onnx_symbolic( + "aten::upsample_trilinear3d", + decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")], +) +def _interpolate(name, dim, interpolate_mode): + def symbolic_fn(g, input, output_size, *args): + scales, align_corners = symbolic_helper._get_interpolate_attributes( + g, interpolate_mode, args + ) + symbolic_helper._interpolate_warning(interpolate_mode) + align_corners = 
symbolic_helper._maybe_get_scalar(align_corners) + if align_corners: + return symbolic_helper._unimplemented(name, "align_corners == True", input) + output_size = symbolic_helper._maybe_get_const(output_size, "is") + if symbolic_helper._is_value(output_size): + return symbolic_helper._unimplemented( + name, "torch._C.Value (output_size) indexing" + ) + if scales is None: + scales = [ + 1.0 + if i < 2 + else float(output_size[-(dim - i)]) + / float(input.type().sizes()[-(dim - i)]) + for i in range(0, dim) + ] + return g.op("Upsample", input, mode_s=interpolate_mode, scales_f=scales) + + return symbolic_fn + + +@_onnx_symbolic("aten::__interpolate") +def __interpolate( + g: jit_utils.GraphContext, + input, + size, + scale_factor, + mode, + align_corners, + recompute_scale_factor, + antialias, +): + align_corners = symbolic_helper._maybe_get_const(align_corners, "b") + if not symbolic_helper._is_none(align_corners) and align_corners: + return symbolic_helper._unimplemented("interpolate", "align_corners == True") + + if not symbolic_helper._is_none(scale_factor) and symbolic_helper._is_value( + scale_factor + ): + return symbolic_helper._unimplemented( + "interpolate", "dynamic scales in opset 8" + ) + + if not symbolic_helper._is_none(size) and symbolic_helper._is_value(size): + return symbolic_helper._unimplemented("interpolate", "dynamic size in opset 8") + + scales, mode = symbolic_helper._interpolate_get_scales_and_mode( + g, input, size, scale_factor, mode, align_corners + ) + return g.op("Upsample", input, mode_s=mode, scales_f=scales) + + +# NOTE: We should create a wrapper for this kind of operation, after resolving the shape/type propagation +# issue for "cast" operators. Some symbolic functions depend on shape information of input tensor, which +# is lost after casting. +def _try_cast_integer_to_float(g: jit_utils.GraphContext, *args): + floating_scalar_types = { + _type_utils.JitScalarType.HALF, + _type_utils.JitScalarType.FLOAT, + _type_utils.JitScalarType.DOUBLE, + } + old_type = None + # Cast the input tensor to Float if its scalarType is known and is not floating number. + # If casting is performed, return the old scalarType, otherwise return None. + arg0_type = _type_utils.JitScalarType.from_value( + args[0], _type_utils.JitScalarType.UNDEFINED + ) + if arg0_type != _type_utils.JitScalarType.UNDEFINED: + old_type = arg0_type + if old_type not in floating_scalar_types: + old_type = old_type.scalar_name() # type: ignore[assignment] + args = tuple( + g.op("Cast", arg, to_i=_C_onnx.TensorProtoDataType.FLOAT) + for arg in args + ) + else: + return (None,) + args + else: + warnings.warn( + "Only floating datatype is supported for these operators: " + "{Greater, Less, MatMul, PRelu, Gemm, Flatten}. This might cause " + "the onnx model to be incorrect, if inputs have integer datatypes." + ) + return (old_type,) + args + + +def _cast_to_type(g: jit_utils.GraphContext, input, to_type): + if to_type is None: + return input + return getattr(opset9, f"_cast_{to_type}")(g, input, False) + + +def _comparison_operator(g: jit_utils.GraphContext, input, other, op_name): + other = symbolic_helper._maybe_get_scalar(other) + other = symbolic_helper._if_scalar_type_as(other, input) + _, input, other = _try_cast_integer_to_float(g, input, other) + return g.op(op_name, input, other) + + +# NOTE: For symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten}, +# integer input type not supported in opset8. Cast to float if possible. 
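# Illustrative caller-side sketch (not part of the symbolic registry itself; it assumes
# the standard `torch.onnx.export` entry point with `opset_version=8`): an integer
# comparison exported at opset 8 goes through the cast-to-float workaround noted above,
# i.e. the gt/lt symbolics below cast both int64 operands to float before emitting
# Greater/Less.
#
#     import torch
#
#     class IntCompare(torch.nn.Module):
#         def forward(self, a, b):
#             return a > b  # routed through the opset-8 gt symbolic below
#
#     a = torch.arange(4, dtype=torch.int64)
#     b = torch.full((4,), 2, dtype=torch.int64)
#     torch.onnx.export(IntCompare(), (a, b), "int_compare_opset8.onnx", opset_version=8)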
+@_onnx_symbolic("aten::gt") +def gt(g: jit_utils.GraphContext, input, other): + return _comparison_operator(g, input, other, "Greater") + + +@_onnx_symbolic("aten::lt") +def lt(g: jit_utils.GraphContext, input, other): + return _comparison_operator(g, input, other, "Less") + + +@_onnx_symbolic("aten::bmm") +def bmm(g: jit_utils.GraphContext, self, other): + if symbolic_helper._try_get_scalar_type(self): + old_type, self, other = _try_cast_integer_to_float(g, self, other) + return _cast_to_type(g, g.op("MatMul", self, other), old_type) + else: + return g.op("MatMul", self, other) + + +@_onnx_symbolic("aten::matmul") +def matmul(g: jit_utils.GraphContext, self, other): + return bmm(g, self, other) + + +@_onnx_symbolic("aten::prelu") +def prelu(g: jit_utils.GraphContext, self, weight): + self_rank = symbolic_helper._get_tensor_rank(self) + weight_sizes = symbolic_helper._get_tensor_sizes(weight) + if self_rank is not None and self_rank > 2: + weight = g.op("Unsqueeze", weight, axes_i=list(range(1, self_rank - 1))) + elif self_rank == 0 and weight_sizes == [1]: + # self and weight are both scalar but weight has rank == 1, squeeze weight. + weight = symbolic_helper._squeeze_helper(g, weight, [0]) + if symbolic_helper._try_get_scalar_type(self): + old_type, self, weight = _try_cast_integer_to_float(g, self, weight) + return _cast_to_type(g, g.op("PRelu", self, weight), old_type) + else: + return g.op("PRelu", self, weight) + + +@_onnx_symbolic("aten::mm") +def mm(g: jit_utils.GraphContext, self, other): + # Create a dummy C tensor. Only needed for API purposes, the value is + # since beta = 0 + scalar_type = symbolic_helper._try_get_scalar_type(self, other) + if scalar_type is None: + raise errors.SymbolicValueError( + "mm can only operate on tensors with known types", self + ) + zero_constant = g.op( + "Constant", + value_t=torch.tensor([0], dtype=scalar_type.dtype()), + ) + + if symbolic_helper._try_get_scalar_type(self): + old_type, self, other, zero_constant = _try_cast_integer_to_float( + g, self, other, zero_constant + ) + return _cast_to_type( + g, + g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0), + old_type, + ) + return g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0) + + +@_onnx_symbolic("aten::addmm") +@symbolic_helper.parse_args("v", "v", "v", "t", "t") +def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha): + if symbolic_helper._try_get_scalar_type(self): + old_type, self, mat1, mat2 = _try_cast_integer_to_float(g, self, mat1, mat2) + return _cast_to_type( + g, + g.op( + "Gemm", + mat1, + mat2, + self, + beta_f=symbolic_helper._scalar(beta), + alpha_f=symbolic_helper._scalar(alpha), + ), + old_type, + ) + else: + return g.op( + "Gemm", + mat1, + mat2, + self, + beta_f=symbolic_helper._scalar(beta), + alpha_f=symbolic_helper._scalar(alpha), + ) + + +@_onnx_symbolic("aten::flatten") +def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim): + start_dim_i = symbolic_helper._get_const(start_dim, "i", "start_dim") + end_dim_i = symbolic_helper._get_const(end_dim, "i", "end_dim") + + dim = input.type().dim() + if end_dim_i < 0: + end_dim_i = dim + end_dim_i + # use ONNX's Flatten operator for cases where the output shape is 2D + if start_dim_i == 1 and end_dim_i == dim - 1: + if symbolic_helper._try_get_scalar_type(input): + old_type, input = _try_cast_integer_to_float(g, input) + return _cast_to_type( + g, g.op("Flatten", input, axis_i=start_dim_i), old_type + ) + else: + return g.op("Flatten", input, axis_i=start_dim_i) + if 
start_dim_i == 0 and end_dim_i == dim - 2: + if symbolic_helper._try_get_scalar_type(input): + old_type, input = _try_cast_integer_to_float(g, input) + return _cast_to_type( + g, g.op("Flatten", input, axis_i=end_dim_i + 1), old_type + ) + else: + return g.op("Flatten", input, axis_i=end_dim_i + 1) + + return opset9.flatten(g, input, start_dim, end_dim) + + +def _constant_fill(g: jit_utils.GraphContext, sizes, dtype: int, const_value): + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + if not scalar_type.dtype().is_floating_point: + result = g.op( + "ConstantFill", + sizes, + dtype_i=_type_utils.JitScalarType.FLOAT.onnx_type(), + input_as_shape_i=1, + value_f=const_value, + ) + return g.op("Cast", result, to_i=scalar_type.onnx_type()) + else: + return g.op( + "ConstantFill", + sizes, + dtype_i=scalar_type.onnx_type(), + input_as_shape_i=1, + value_f=const_value, + ) + + +@_onnx_symbolic("aten::empty") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def empty( + g: jit_utils.GraphContext, + sizes, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + return zeros(g, sizes, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::empty_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def empty_like( + g: jit_utils.GraphContext, + input, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + return zeros_like(g, input, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::zeros") +@symbolic_helper.parse_args("v", "i", "v", "v", "v") +def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): + # NOTE: no way to set device and layout in ONNX, so we ignore it + return _constant_fill(g, sizes, dtype, 0) + + +@_onnx_symbolic("aten::zeros_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def zeros_like( + g: jit_utils.GraphContext, + input, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + shape = g.op("Shape", input) + return _constant_fill(g, shape, dtype, 0) + + +@_onnx_symbolic("aten::ones") +@symbolic_helper.parse_args("v", "i", "v", "v", "v") +def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): + return _constant_fill(g, sizes, dtype, 1) + + +@_onnx_symbolic("aten::ones_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def ones_like( + g: jit_utils.GraphContext, + input, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + shape = g.op("Shape", input) + return _constant_fill(g, shape, dtype, 1) + + +@_onnx_symbolic("aten::full") +def full( + g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False +): + const_value = symbolic_helper._maybe_get_const(value, "t") + if symbolic_helper._is_value(const_value): + tmp = zeros(g, sizes, dtype, layout, device) + return opset9.add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1))) + else: + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + return _constant_fill(g, sizes, dtype, const_value) + + +@_onnx_symbolic("aten::full_like") +@symbolic_helper.parse_args("v", "f", "i", "v", "v", "v", "v") +def full_like( + g: jit_utils.GraphContext, + input, + fill_value, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + shape = g.op("Shape", input) + return _constant_fill(g, shape, dtype, fill_value) + + +@_onnx_symbolic("aten::repeat") +def repeat(g: jit_utils.GraphContext, self, repeats): + 
if not symbolic_helper._is_value(repeats): + repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) + if symbolic_helper._is_packed_list(repeats): + repeat_size_len = len(symbolic_helper._unpack_list(repeats)) + else: + const_repeats = symbolic_helper._maybe_get_const(repeats, "is") + repeat_size_len = len(const_repeats) + if self.isCompleteTensor(): + sizes = self.type().sizes() + diff_dims = repeat_size_len - len(sizes) + if diff_dims > 0: + self = opset9.view( + g, self, g.op("Constant", value_t=torch.tensor([1] * diff_dims + sizes)) + ) + return g.op("Tile", self, repeats) diff --git a/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py b/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcba2f93d04ae27709f7f0acc67f846acd0a01d --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py @@ -0,0 +1,6637 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +"""This file exports ONNX ops for opset 9. + +Opset 9 is supported by ONNX release 1.4.1 +release on 01/23/19 +""" + +from __future__ import annotations + +import builtins +import functools +import math +import sys +import warnings +from typing import Callable, Sequence, TYPE_CHECKING + +import torch +import torch._C._onnx as _C_onnx +import torch.nn.modules.utils +import torch.onnx +from torch import _C + +# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics +from torch.onnx import _constants, _deprecation, _type_utils, errors, symbolic_helper +from torch.onnx._globals import GLOBALS +from torch.onnx._internal import jit_utils, registration + + +if TYPE_CHECKING: + from torch.types import Number + +# EDITING THIS FILE? READ THIS FIRST! 
+# see Note [Edit Symbolic Files] in README.md + +__all__ = [ + "abs", + "acos", + "add", + "addcmul", + "addmm", + "alias", + "amax", + "amin", + "aminmax", + "arange", + "argmax", + "argmin", + "as_strided", + "as_tensor", + "asin", + "atan", + "atan2", + "baddbmm", + "batch_norm", + "bernoulli", + "bitwise_not", + "bitwise_or", + "bmm", + "broadcast_tensors", + "broadcast_to", + "bucketize", + "cat", + "cdist", + "ceil", + "clamp_max", + "clamp_min", + "clamp", + "clone", + "constant_pad_nd", + "contiguous", + "conv_tbc", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + "conv1d", + "conv2d", + "conv3d", + "convert_element_type", + "convolution", + "cos", + "cosine_similarity", + "cross", + "cumsum", + "detach", + "dim", + "div", + "dot", + "dropout", + "elu", + "embedding_bag", + "embedding", + "empty_like", + "empty", + "eq", + "erf", + "exp", + "expand_as", + "expand", + "eye", + "fill", + "flatten", + "floor_divide", + "floor", + "floordiv", + "frobenius_norm", + "full_like", + "full", + "gather", + "ge", + "gelu", + "get_pool_ceil_padding", + "glu", + "group_norm", + "gt", + "hann_window", + "hardshrink", + "hardsigmoid", + "hardswish", + "hardtanh", + "index_add", + "index_copy", + "index_fill", + "index_put", + "index_select", + "index", + "instance_norm", + "is_floating_point", + "is_pinned", + "isnan", + "item", + "kl_div", + "layer_norm", + "le", + "leaky_relu", + "lerp", + "lift", + "linalg_cross", + "linalg_matrix_norm", + "linalg_norm", + "linalg_vector_norm", + "linear", + "linspace", + "log_sigmoid", + "log_softmax", + "log", + "log10", + "log1p", + "log2", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "logit", + "logsumexp", + "lstm_cell", + "lstm", + "lt", + "masked_fill", + "masked_fill_", + "matmul", + "max_pool1d_with_indices", + "max_pool2d_with_indices", + "max_pool3d_with_indices", + "max", + "maximum", + "meshgrid", + "min", + "minimum", + "mish", + "mm", + "movedim", + "mse_loss", + "mul", + "multinomial", + "mv", + "narrow", + "native_layer_norm", + "ne", + "neg", + "new_empty", + "new_full", + "new_ones", + "new_zeros", + "nonzero_numpy", + "nonzero", + "norm", + "numel", + "numpy_T", + "one_hot", + "ones_like", + "ones", + "onnx_placeholder", + "pad", + "pairwise_distance", + "permute", + "pixel_shuffle", + "pixel_unshuffle", + "pow", + "prelu", + "prim_constant_chunk", + "prim_constant_split", + "prim_constant", + "prim_data", + "prim_device", + "prim_dtype", + "prim_if", + "prim_layout", + "prim_list_construct", + "prim_list_unpack", + "prim_loop", + "prim_max", + "prim_min", + "prim_shape", + "prim_tolist", + "prim_tuple_construct", + "prim_type", + "prim_unchecked_cast", + "prim_uninitialized", + "rand_like", + "rand", + "randint_like", + "randint", + "randn_like", + "randn", + "reciprocal", + "reflection_pad", + "relu", + "relu6", + "remainder", + "repeat_interleave", + "repeat", + "replication_pad", + "reshape_as", + "reshape", + "roll", + "rrelu", + "rsqrt", + "rsub", + "scalar_tensor", + "scatter_add", + "scatter", + "select", + "selu", + "sigmoid", + "sign", + "silu", + "sin", + "size", + "slice", + "softmax", + "softplus", + "softshrink", + "sort", + "split_with_sizes", + "split", + "sqrt", + "square", + "squeeze", + "stack", + "std_mean", + "std", + "sub", + "t", + "take", + "tan", + "tanh", + "tanhshrink", + "tensor", + "threshold", + "to", + "topk", + "transpose", + "true_divide", + "type_as", + "unbind", + "unfold", + "unsafe_chunk", + "unsafe_split_with_sizes", + "unsafe_split", + "unsqueeze", + 
"unsupported_complex_operators", + "noop_complex_operators", + "unused", + "var_mean", + "var", + "view_as", + "view", + "where", + "wrap_logical_op_with_cast_to", + "wrap_logical_op_with_negation", + "zeros_like", + "zeros", + "zero", +] + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9) + + +def _export(name: str): + """Exports the function in the current global namespace.""" + + def wrapper(func): + globals()[name] = func + __all__.append(name) + return func + + return wrapper + + +def unused(g): + """Represents "missing" optional inputs.""" + n = g.op("prim::Constant") + n.setType(_C.OptionalType.ofTensor()) + return n + + +@_onnx_symbolic("aten::_shape_as_tensor") +def _shape_as_tensor(g: jit_utils.GraphContext, input): + return g.op("Shape", input) + + +@_onnx_symbolic("aten::_reshape_from_tensor") +def _reshape_from_tensor(g: jit_utils.GraphContext, input, shape): + if isinstance(shape, list): + shape = g.op("Concat", *shape, axis_i=0) + return reshape(g, input, shape) + + +@_onnx_symbolic("aten::reshape") +@symbolic_helper.quantized_args(True) +def reshape(g: jit_utils.GraphContext, self, shape): + return symbolic_helper._reshape_helper(g, self, shape) + + +@_onnx_symbolic("aten::reshape_as") +@symbolic_helper.quantized_args(True) +def reshape_as(g: jit_utils.GraphContext, self, other): + shape = g.op("Shape", other) + return reshape(g, self, shape) + + +@_onnx_symbolic("aten::add") +def add(g: jit_utils.GraphContext, self, other, alpha=None): + """ + This function takes the add function and returns the corresponding ONNX operator. + + This function is not meant to be called directly by the user. + + Args: + g (GraphContext): The graph context. + self (Tensor): The first operand. + other (Tensor): The second operand. + alpha (float, optional): The scaling factor for the second operand. Defaults to None. + + Returns: + ONNX operator. + """ + if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self): + return symbolic_helper._onnx_opset_unsupported_detailed( + "Add", 9, 11, "Add between list of tensors not supported", self + ) + if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: + other = g.op("Mul", other, alpha) + return g.op("Add", self, other) + + +@_onnx_symbolic("aten::sub") +def sub(g: jit_utils.GraphContext, self, other, alpha=None): + """ + Consumes sub function and returns the corresponding ONNX operator. + + This function is not meant to be called directly by the user. + + Args: + g (GraphContext): The graph context. + self (Tensor): The first operand. + other (Tensor): The second operand. + alpha (Optional[Tensor]): A scaling factor to apply to the second operand. + If `alpha` is not provided, it defaults to 1. + + Returns: + ONNX operator + """ + if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: + other = g.op("Mul", other, alpha) + return g.op("Sub", self, other) + + +@_onnx_symbolic("aten::rsub") +def rsub(g: jit_utils.GraphContext, self, other, alpha=None): + return sub(g, other, self, alpha=alpha) + + +@_onnx_symbolic("aten::mul") +def mul(g: jit_utils.GraphContext, self, other): + if symbolic_helper._is_bool(self) and symbolic_helper._is_bool(other): + # ONNX Mul doesn't support Boolean, so use And as an equivalent operator. 
+ return g.op("And", self, other) + else: + return g.op("Mul", self, other) + + +@_onnx_symbolic("aten::div") +def div(g: jit_utils.GraphContext, self, other, *args): + if len(args) == 0: + return true_divide(g, self, other) + else: + return _div_rounding_mode(g, self, other, *args) + + +@_onnx_symbolic("aten::addcmul") +@symbolic_helper.parse_args("v", "v", "v", "f") +def addcmul(g: jit_utils.GraphContext, self, tensor1, tensor2, value=1.0): + value_tens = g.op("Constant", value_t=torch.tensor([value])) + return add(g, self, mul(g, mul(g, tensor1, tensor2), value_tens)) + + +@symbolic_helper.parse_args("v", "v", "s") +def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode): + if rounding_mode is None: + return true_divide(g, self, other) + elif rounding_mode == "floor": + return _floor_divide(g, self, other) + elif rounding_mode == "trunc": + return _trunc_divide(g, self, other) + else: + raise errors.SymbolicValueError( + f'Unsupported rounding mode: "{rounding_mode}". Expected None, "floor" or "trunc"', + self, + ) + + +def _trunc_divide(g: jit_utils.GraphContext, self, other): + out = g.op("Div", self, other) + # the correct operation is truncate, which is not supported in ONNX, + # we cannot call floor since it will behave differently for negative numbers + # (eg. -0.1 should become -0 ) + # - if scalar_type information are not available, assume that + # we need to call floor (treat as float) + out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.INT64) + + # Matching PyTorch's behavior: + # - if self is fp the output's type is self's type + # - if self is not fp and other is fp, the output is of type JitScalarType.FLOAT + # - self is not fp and other is not fp, the output's type is self's output type + # - the output type defaults to Float + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.UNDEFINED + ) + if scalar_type != _type_utils.JitScalarType.UNDEFINED: + if not symbolic_helper._is_fp(self) and symbolic_helper._is_fp(other): + out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT) + else: + out = g.op( + "Cast", + out, + to_i=scalar_type.onnx_type(), + ) + else: + out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT) + return out + + +def _floor_divide(g: jit_utils.GraphContext, self, other): + if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other): + out = true_divide(g, self, other) + return g.op("Floor", out) + else: + # Integer division does trunction rounding + div = g.op("Div", self, other) + # Division is negative if: self < 0 != other < 0 + zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) + negative = g.op( + "Xor", + symbolic_helper._lt_helper(g, self, zero), + symbolic_helper._lt_helper(g, other, zero), + ) + + # For negative numbers with self % other != 0, subtract 1 to round down instead of up + mod = g.op("Sub", self, g.op("Mul", div, other)) + fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero))) + + one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) + fixup = g.op("Mul", fixup_mask, one) + return g.op("Sub", div, fixup) + + +@_onnx_symbolic("aten::floor_divide") +def floor_divide(g: jit_utils.GraphContext, self, other): + # Deprecated behavior, floor_divide actually truncates + return _trunc_divide(g, self, other) + + +@_onnx_symbolic("aten::floordiv") +def floordiv(g: jit_utils.GraphContext, self, other): + return floor_divide(g, self, other) + + +@_onnx_symbolic("aten::true_divide") +def true_divide(g: 
jit_utils.GraphContext, self, other): + """Division where both inputs are cast to floating types + + If both inputs are floating, performs div as usual + If only one input is a floating type, the other input is cast to its type + If neither input is a floating type, both inputs are cast to the default scalar type + """ + + # Case 1: either values are floating + # Performs div as usual. + # Implicit casting will be handled in scalar type analysis pass. + if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other): + return g.op("Div", self, other) + + # Case 2: neither is floating + # Casts both inputs to the default scalar type + scalar_type = torch.get_default_dtype() + onnx_scalar_type = _C_onnx.TensorProtoDataType.FLOAT + assert scalar_type is torch.float or scalar_type is torch.double + if torch.get_default_dtype() is torch.double: + onnx_scalar_type = _C_onnx.TensorProtoDataType.DOUBLE + + self = g.op("Cast", self, to_i=onnx_scalar_type) + other = g.op("Cast", other, to_i=onnx_scalar_type) + return g.op("Div", self, other) + + +@_onnx_symbolic("aten::reciprocal") +def reciprocal(g: jit_utils.GraphContext, self): + # torch.reciprocal implicitly casts to float, so we do the same. + if not symbolic_helper._is_fp(self): + self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT) + return g.op("Reciprocal", self) + + +@_onnx_symbolic("aten::cat") +@symbolic_helper.parse_args("v", "i") +def cat(g: jit_utils.GraphContext, tensor_list, dim): + """Implement concatenation of pytorch tensors in ONNX along the specified `dim` dimension. + + Parameters: + g (jit_utils.GraphContext): Graph context. + tensor_list (List[torch.Tensor]): List of tensors to concatenate. + dim (int): Dimension along which to concatenate the tensors. + + Returns: + ONNX graph node representing the concatenated tensor. + """ + tensors = symbolic_helper._unpack_list(tensor_list) + # torch.cat ignores empty tensors such as `torch.Tensor([])` + # These needs to be removed as input from ONNX's concat too, otherwise shape inference + # will likely fail due to inputs with different ranks (0 for empty tensor, > 0 for anything else) + nonempty_tensors = [] + for t in tensors: + if symbolic_helper._is_constant(t) and not symbolic_helper._get_tensor_dim_size( + t, 0 + ): + continue + nonempty_tensors.append(t) + assert len(nonempty_tensors) > 0 + assert all( + symbolic_helper._get_tensor_rank(nonempty_tensors[0]) is None + or symbolic_helper._get_tensor_rank(t) is None + or symbolic_helper._get_tensor_rank(t) + == symbolic_helper._get_tensor_rank(nonempty_tensors[0]) + for t in nonempty_tensors + ) + tensor_list.node().removeAllInputs() + for t in nonempty_tensors: + tensor_list.node().addInput(t) + + tensors = symbolic_helper._unpack_list(tensor_list) + return g.op("Concat", *tensors, axis_i=dim) + + +@_onnx_symbolic("aten::stack") +@symbolic_helper.parse_args("v", "i") +def stack(g: jit_utils.GraphContext, tensor_list, dim): + unsqueezed = [ + symbolic_helper._unsqueeze_helper(g, t, [dim]) + for t in symbolic_helper._unpack_list(tensor_list) + ] + return g.op("Concat", *unsqueezed, axis_i=dim) + + +@_onnx_symbolic("aten::list") +def _list(g: jit_utils.GraphContext, self): + return self + + +@_onnx_symbolic("aten::mm") +def mm(g: jit_utils.GraphContext, self, other): + # Create a dummy C tensor. 
Only needed for API purposes, the value is + # since beta = 0 + C = g.op("Constant", value_t=torch.tensor([1])) + return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0) + + +@_onnx_symbolic("aten::bmm") +def bmm(g: jit_utils.GraphContext, self, other): + return g.op("MatMul", self, other) + + +@_onnx_symbolic("aten::matmul") +def matmul(g: jit_utils.GraphContext, self, other): + return g.op("MatMul", self, other) + + +@_onnx_symbolic("aten::addmm") +@symbolic_helper.parse_args("v", "v", "v", "t", "t") +def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha): + scalar_type = None + self_scalar_type = symbolic_helper._try_get_scalar_type(self) + mat1_scalar_type = symbolic_helper._try_get_scalar_type(mat1) + mat2_scalar_type = symbolic_helper._try_get_scalar_type(mat2) + if self_scalar_type is not None: + scalar_type = self_scalar_type + elif mat1_scalar_type is not None: + scalar_type = mat1_scalar_type + elif mat2_scalar_type is not None: + scalar_type = mat2_scalar_type + + mat1_rank = symbolic_helper._get_tensor_rank(mat1) + mat2_rank = symbolic_helper._get_tensor_rank(mat2) + + def is_not_none_nor(v, u): + return v is not None and v != u + + if scalar_type is not None and ( + is_not_none_nor(mat1_rank, 2) or is_not_none_nor(mat2_rank, 2) + ): + res1 = g.op("MatMul", mat1, mat2) + res2 = self + + alpha = symbolic_helper._scalar(alpha) + beta = symbolic_helper._scalar(beta) + + if alpha != 1: + alpha = g.op( + "Constant", value_t=torch.tensor(alpha, dtype=scalar_type.dtype()) + ) + res1 = g.op("Mul", res1, alpha) + if beta != 1: + beta = g.op( + "Constant", + value_t=torch.tensor( + symbolic_helper._scalar(beta), dtype=scalar_type.dtype() + ), + ) + res2 = g.op("Mul", res2, beta) + + return g.op("Add", res1, res2) + + return g.op( + "Gemm", + mat1, + mat2, + self, + beta_f=symbolic_helper._scalar(beta), + alpha_f=symbolic_helper._scalar(alpha), + ) + + +@_onnx_symbolic("aten::neg") +def neg(g: jit_utils.GraphContext, self): + return g.op("Neg", self) + + +@_onnx_symbolic("aten::sqrt") +def sqrt(g: jit_utils.GraphContext, self): + if _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.UNDEFINED + ) in { + _type_utils.JitScalarType.UINT8, + _type_utils.JitScalarType.INT8, + _type_utils.JitScalarType.INT16, + _type_utils.JitScalarType.INT, + _type_utils.JitScalarType.INT64, + }: + # torch converts all int inputs to sqrt to float + self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT) + + return g.op("Sqrt", self) + + +@_onnx_symbolic("aten::rsqrt") +def rsqrt(g: jit_utils.GraphContext, self): + return g.op( + "Div", symbolic_helper._if_scalar_type_as(torch.ones(1), self), sqrt(g, self) + ) + + +@_onnx_symbolic("aten::tanh") +# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qtanh.cpp +@symbolic_helper.quantized_args(True, scale=2.0 / 256.0, zero_point=128) +def tanh(g: jit_utils.GraphContext, self): + return g.op("Tanh", self) + + +@_onnx_symbolic("aten::sin") +def sin(g: jit_utils.GraphContext, self): + return g.op("Sin", self) + + +@_onnx_symbolic("aten::cos") +def cos(g: jit_utils.GraphContext, self): + return g.op("Cos", self) + + +@_onnx_symbolic("aten::tan") +def tan(g: jit_utils.GraphContext, self): + return g.op("Tan", self) + + +@_onnx_symbolic("aten::asin") +def asin(g: jit_utils.GraphContext, self): + return g.op("Asin", self) + + +@_onnx_symbolic("aten::acos") +def acos(g: jit_utils.GraphContext, self): + return g.op("Acos", self) + + +@_onnx_symbolic("aten::atan") +def atan(g: 
jit_utils.GraphContext, self): + return g.op("Atan", self) + + +@_onnx_symbolic("aten::atan2") +def atan2(g: jit_utils.GraphContext, self, other): + # self is y, and other is x on coordinate + slope = g.op("Div", self, other) + atan = g.op("Atan", slope) + const_zero = g.op("Constant", value_t=torch.tensor(0)) + const_pi = g.op("Constant", value_t=torch.tensor(math.pi)) + + condition_second_or_third_quadrant = g.op("Greater", self, const_zero) + second_third_quadrant = g.op( + "Where", + condition_second_or_third_quadrant, + g.op("Add", atan, const_pi), + g.op("Sub", atan, const_pi), + ) + + condition_14_or_23_quadrant = g.op("Less", other, const_zero) + result = g.op("Where", condition_14_or_23_quadrant, second_third_quadrant, atan) + + return result + + +@_onnx_symbolic("aten::sigmoid") +# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qsigmoid.cpp +@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0) +def sigmoid(g: jit_utils.GraphContext, self): + """Converts the corresponding PyTorch function into ONNX operators. + + It is not meant to be called directly by a user. + + Args: + g (jit_utils.GraphContext): Graph context. + self (Tensor): the input tensor. + Returns: + ONNX operator + """ + return g.op("Sigmoid", self) + + +@_onnx_symbolic("aten::sign") +def sign(g: jit_utils.GraphContext, self): + return g.op("Sign", self) + + +@symbolic_helper.quantized_args(True) +def _slice(g: jit_utils.GraphContext, input, axes, starts, ends): + assert len(starts) == len(ends) + if len(starts) == 1 and starts[0] == 0 and ends[0] == _constants.INT64_MAX: + return input + return g.op("Slice", input, axes_i=axes, starts_i=starts, ends_i=ends) + + +@_onnx_symbolic( + "aten::sum", decorate=[symbolic_helper._apply_params("ReduceSum", "sum")] +) +@_onnx_symbolic( + "aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")] +) +# torch.prod does not support multidimensional "dim" +@_onnx_symbolic( + "aten::prod", + decorate=[ + symbolic_helper._apply_params( + "ReduceProd", "prod", allow_multi_dim_support=False + ) + ], +) +def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True): + return symbolic_helper._reduce_with_dtype_helper( + onnx_op, name, allow_multi_dim_support + ) + + +@_onnx_symbolic("aten::cumsum") +@symbolic_helper.parse_args("v", "i", "none") +def cumsum(g: jit_utils.GraphContext, input, dim, dtype): + symbolic_helper._onnx_opset_unsupported("cumsum", 9, 11, input) + + +@_onnx_symbolic("aten::_sample_dirichlet") +def _sample_dirichlet(g: jit_utils.GraphContext, self, generator): + return symbolic_helper._onnx_unsupported("_sample_dirichlet", self) + + +@_onnx_symbolic("aten::_standard_gamma") +def _standard_gamma(g: jit_utils.GraphContext, self, generator): + return symbolic_helper._onnx_unsupported("_standard_gamma", self) + + +@_onnx_symbolic("aten::t") +def t(g: jit_utils.GraphContext, self): + rank = symbolic_helper._get_tensor_rank(self) + if rank is None or rank < 2: + # The transpose of a 1d or 0d tensor is itself. ONNX does not define the behavior + # clearly and onnxruntime fails on these cases. So we add an Identity node to + # mirror the behavior of eager mode. 
+ return g.op("Identity", self) + return g.op("Transpose", self, perm_i=(1, 0)) + + +@_onnx_symbolic("aten::numpy_T") +@symbolic_helper.quantized_args(True) +def numpy_T(g: jit_utils.GraphContext, input): + ndim = symbolic_helper._get_tensor_rank(input) + assert ndim is not None + perm = list(reversed(range(0, ndim))) + return g.op("Transpose", input, perm_i=perm) + + +@_onnx_symbolic("aten::expand") +@symbolic_helper.quantized_args(True) +def expand(g: jit_utils.GraphContext, self, size, implicit): + """Implement the expand function for a pytorch tensor in ONNX according to specified `size`""" + size = symbolic_helper._maybe_get_const(size, "is") + if not symbolic_helper._is_value(size): + size = g.op("Constant", value_t=torch.LongTensor(size)) + elif symbolic_helper._is_packed_list(size): + # Expand with -1 dim value means dim is unchanged. + # Since onnx::expand supports two-way broadcasting, + # -1 dim value can be exported to onnx as 1 + size = symbolic_helper._reshape_helper( + g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1])) + ) + dtype = _type_utils.JitScalarType.INT64 + ones = ones_like(g, size, dtype) + neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1))) + size = where(g, g.op("Equal", size, neg_ones), ones, size) + return g.op("Expand", self, size) + + +@_onnx_symbolic("aten::broadcast_to") +@symbolic_helper.quantized_args(True) +def broadcast_to(g: jit_utils.GraphContext, self, size): + size = symbolic_helper._maybe_get_const(size, "is") + if not symbolic_helper._is_value(size): + size = g.op("Constant", value_t=torch.LongTensor(size)) + elif symbolic_helper._is_packed_list(size): + # Expand with -1 dim value means dim is unchanged. + # Since onnx::expand supports two-way broadcasting, + # -1 dim value can be exported to onnx as 1 + size = symbolic_helper._reshape_helper( + g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1])) + ) + dtype = _type_utils.JitScalarType.INT64 + ones = ones_like(g, size, dtype) + neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1))) + size = where(g, g.op("Equal", size, neg_ones), ones, size) + return g.op("Expand", self, size) + + +@_onnx_symbolic("aten::expand_as") +@symbolic_helper.quantized_args(True, True) +def expand_as(g: jit_utils.GraphContext, self, other): + self_t = symbolic_helper._maybe_get_const(self, "t") + if isinstance(self_t, torch.Tensor): + orig_type = self_t.dtype + self_t = self_t.to(torch.double) + dims = [] + for d in range(self_t.dim()): + if torch.equal(self_t.mean(d).unsqueeze(d).expand_as(self_t), self_t): + dims.append(d) + self = g.op( + "Constant", value_t=self_t.mean(dims, keepdim=True).to(orig_type) + ) + + shape = g.op("Shape", other) + return g.op("Expand", self, shape) + + +@_onnx_symbolic("aten::embedding") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v", "i", "b", "v") +def embedding( + g: jit_utils.GraphContext, + weight, + indices, + padding_idx, + scale_grad_by_freq, + sparse, +): + if scale_grad_by_freq and GLOBALS.export_training: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of embedding with scale_grad_by_freq=True " + "for training mode. ONNX does not support scaling the gradients.", + weight, + ) + if padding_idx >= 0 and GLOBALS.export_training: + warnings.warn( + "Warning: ONNX export of embedding with padding_idx >= 0 " + "for training mode. " + "ONNX does not support not updating the embedding vector at padding_idx during training." 
+ ) + + return g.op("Gather", weight, indices) + + +@_onnx_symbolic("aten::embedding_bag") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") +def embedding_bag( + g: jit_utils.GraphContext, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, +): + if not symbolic_helper._is_none(per_sample_weights): + return symbolic_helper._onnx_unsupported( + "embedding_bag with per_sample_weights" + ) + + return symbolic_helper._onnx_unsupported("embedding_bag", embedding_matrix) + + +@_onnx_symbolic("aten::size") +@symbolic_helper.quantized_args(True, quantize_output=False) +def size(g: jit_utils.GraphContext, self, dim=None): + if dim is None: + return g.op("Shape", self) + if symbolic_helper._maybe_get_const(dim, "i") < 0: + rank = symbolic_helper._get_tensor_rank(self) + if rank is not None: + dim = symbolic_helper._maybe_get_const(dim, "i") + rank + dim = g.op("Constant", value_t=torch.tensor(dim)) + return symbolic_helper._size_helper(g, self, dim) + + +@_onnx_symbolic("aten::transpose") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "i", "i") +def transpose(g: jit_utils.GraphContext, self, dim0, dim1): + if dim0 == dim1: # micro-optimization + return self + + # NB: Transpose in ONNX is actually a Permute + rank = symbolic_helper._get_tensor_rank(self) + if rank is not None: + axes = list(range(rank)) + axes[dim0], axes[dim1] = axes[dim1], axes[dim0] + return g.op("Transpose", self, perm_i=axes) + else: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of transpose for tensor of unknown rank.", + self, + ) + + +@_onnx_symbolic("aten::permute") +@symbolic_helper.parse_args("v", "is") +def permute(g: jit_utils.GraphContext, self, dims): + if dims == list(range(0, len(dims))): + return self + return g.op("Transpose", self, perm_i=dims) + + +@_onnx_symbolic("aten::view") +@symbolic_helper.quantized_args(True) +def view(g: jit_utils.GraphContext, self, size): + return reshape(g, self, size) + + +@_onnx_symbolic("aten::view_as") +def view_as(g: jit_utils.GraphContext, self, other): + shape = g.op("Shape", other) + return reshape(g, self, shape) + + +@_onnx_symbolic("aten::unsafe_chunk") +@symbolic_helper.parse_args("v", "i", "i", "i") +def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None): + if _outputs is None: + return symbolic_helper._onnx_opset_unsupported_detailed( + "unsafe_chunk", 9, 11, "Dynamic number of outputs not supported", self + ) + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + return symbolic_helper._unimplemented( + "unsafe_chunk", "unknown dimension size", self + ) + split_size = (size + chunks - 1) // chunks + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::split") +@symbolic_helper.parse_args("v", "v", "i", "i") +def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None): + if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs): + return symbolic_helper._onnx_opset_unsupported_detailed( + "split", 9, 11, "Dynamic number of outputs not supported", self + ) + split_val = symbolic_helper._node_get(split_size_or_sizes.node(), "value") + if split_val.dim() > 0: + return split_with_sizes(g, self, split_size_or_sizes, dim, _outputs) + 
split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size") + + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + if _outputs is not None: + size = split_size * _outputs + else: + return symbolic_helper._onnx_opset_unsupported_detailed( + "split", 9, 11, "Unknown dimension size not supported", self + ) + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::unsafe_split") +def unsafe_split( + g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None +): + return split(g, self, split_size_or_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::split_with_sizes") +@symbolic_helper.parse_args("v", "is", "i", "i") +def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None): + if not symbolic_helper._is_split_static(split_sizes, _outputs): + return symbolic_helper._onnx_opset_unsupported_detailed( + "split_with_sizes", 9, 11, "Dynamic number of outputs not supported", self + ) + return g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::unsafe_split_with_sizes") +def unsafe_split_with_sizes( + g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None +): + return split_with_sizes(g, self, split_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::unbind") +@symbolic_helper.parse_args("v", "i", "i") +def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None): + if _outputs is None: + return symbolic_helper._onnx_opset_unsupported_detailed( + "unbind", 9, 11, "Dynamic number of outputs not supported", self + ) + + outputs = g.op("Split", self, split_i=[1] * _outputs, axis_i=dim, outputs=_outputs) + outputs = [outputs] if _outputs == 1 else outputs + squeezed_outputs = [ + symbolic_helper._squeeze_helper(g, out, [dim]) for out in outputs + ] + return squeezed_outputs + + +@_onnx_symbolic("aten::select") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "i", "v") +def select(g: jit_utils.GraphContext, self, dim, index): + """Implement the select functionality for a pytorch tensor in ONNX. + + Selects elements from the input tensor along the specified `dim` dimension based on the `index` tensor. + """ + index = symbolic_helper._maybe_get_scalar(index) + if (not symbolic_helper._is_value(index)) and (index < 0): + if index == -1: + end_index = _constants.INT64_MAX + else: + end_index = index + 1 + slice_node = symbolic_helper._slice_helper( + g, self, axes=[dim], starts=[index], ends=[end_index] + ) + return symbolic_helper._squeeze_helper(g, slice_node, [dim]) + else: + # FIXME(justinchuby): can index be an int and not a value? + return g.op("Gather", self, index, axis_i=dim) + + +@_onnx_symbolic("aten::square") +def square(g: jit_utils.GraphContext, self): + return g.op("Mul", self, self) + + +@_onnx_symbolic("aten::squeeze") +def squeeze(g: jit_utils.GraphContext, self, dim=None): + if dim is None: + return g.op("Squeeze", self) + + squeeze_dim = symbolic_helper._get_const(dim, "i", "dim") + # Handle negative dims + if squeeze_dim < 0: + rank = symbolic_helper._get_tensor_rank(self) + if rank is not None: + warnings.warn( + "ONNX export squeeze with negative axis " + + str(squeeze_dim) + + " might cause the onnx model to be incorrect. " + + "Negative axis is not supported in ONNX. 
" + + "Axis is converted to " + + str(squeeze_dim + rank) + + " based on input shape at export time. " + + "Passing an tensor of different rank in execution will be incorrect." + ) + squeeze_dim += rank + else: + return symbolic_helper._unimplemented( + "squeeze", "negative axis with unknown input rank", self + ) + + dim_size = symbolic_helper._get_tensor_dim_size(self, squeeze_dim) + if dim_size is None: + warnings.warn( + "This model contains a squeeze operation on dimension " + + str(squeeze_dim) + + " on an input " + + "with unknown shape. Note that if the size of dimension " + + str(squeeze_dim) + + " of the input " + + "is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on " + + "non-singleton dimensions, it is recommended to export this model using opset " + + "version 11 or higher." + ) + return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim]) + if dim_size > 1: + warnings.warn( + "This model contains a squeeze operation on dimension " + + str(squeeze_dim) + + ". The size of " + + "this dimension in the given input is " + + str(dim_size) + + ". The model will " + + "be exported without the squeeze node. If the model is intended to be used with dynamic " + + "input shapes, please use opset version 11 to " + + "export the model." + ) + return self + + warnings.warn( + "This model contains a squeeze operation on dimension " + + str(squeeze_dim) + + ". If the model is " + + "intended to be used with dynamic input shapes, please use opset version 11 to export the model." + ) + return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim]) + + +@_onnx_symbolic("aten::prelu") +def prelu(g: jit_utils.GraphContext, self, weight): + self_rank = symbolic_helper._get_tensor_rank(self) + weight_sizes = symbolic_helper._get_tensor_sizes(weight) + weight_rank = len(weight_sizes) + if self_rank is not None: + if self_rank > 2: + # make weight unidirectional broadcastable + weight = symbolic_helper._unsqueeze_helper( + g, weight, list(range(1, self_rank - 1)) + ) + elif self_rank == 0 and weight_sizes == [1]: + # self and weight are both scalar but weight has rank == 1, squeeze weight. 
+ weight = symbolic_helper._squeeze_helper(g, weight, [0]) + weight_rank = 0 + + if self_rank is not None and weight_rank is not None: + assert ( + self_rank >= weight_rank + ), f"rank(x) should be >= rank(slope) but got {self_rank} < {weight_rank}" + return g.op("PRelu", self, weight) + + +@_onnx_symbolic("aten::silu") +def silu(g: jit_utils.GraphContext, input): + return g.op("Mul", input, g.op("Sigmoid", input)) + + +@_onnx_symbolic("aten::mish") +def mish(g: jit_utils.GraphContext, input): + return g.op("Mul", input, g.op("Tanh", g.op("Softplus", input))) + + +@_onnx_symbolic("aten::relu") +@symbolic_helper.quantized_args(True) +def relu(g: jit_utils.GraphContext, input): + return symbolic_helper._op_with_optional_float_cast( + g, "Relu", input, opset_before=14 + ) + + +@_onnx_symbolic("aten::relu6") +@symbolic_helper.quantized_args(True) +def relu6(g: jit_utils.GraphContext, input): + return clamp(g, input, 0, 6) + + +@_onnx_symbolic("aten::ceil") +def ceil(g: jit_utils.GraphContext, input): + return g.op("Ceil", input) + + +@_onnx_symbolic("aten::floor") +def floor(g: jit_utils.GraphContext, input): + return g.op("Floor", input) + + +@_onnx_symbolic("aten::len") +def _len(g: jit_utils.GraphContext, self): + sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0]))) + return symbolic_helper._squeeze_helper(g, sz_0, [0]) + + +@_onnx_symbolic("aten::threshold") +@symbolic_helper.parse_args("v", "t", "t") +def threshold(g: jit_utils.GraphContext, self, threshold, value): + # See Note [Export inplace] + if symbolic_helper._scalar(threshold) != 0: + return symbolic_helper._unimplemented("threshold", "non-zero threshold", self) + if symbolic_helper._scalar(value) != 0: + return symbolic_helper._unimplemented("threshold", "non-zero value", self) + return g.op("Relu", self) + + +@_onnx_symbolic("aten::leaky_relu") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "f", "b") +def leaky_relu( + g: jit_utils.GraphContext, + input: _C.Value, + negative_slope: float, + inplace: bool = False, +): + # See Note [Export inplace] + return g.op("LeakyRelu", input, alpha_f=negative_slope) + + +@_onnx_symbolic("aten::glu") +@symbolic_helper.parse_args("v", "i") +def glu(g: jit_utils.GraphContext, input, dim): + dim_size = symbolic_helper._get_tensor_dim_size(input, dim) + if dim_size is not None: + assert dim_size % 2 == 0 + + first, second = g.op("Split", input, axis_i=dim, outputs=2) + return g.op("Mul", first, g.op("Sigmoid", second)) + + +@_onnx_symbolic("aten::softmax") +@symbolic_helper.parse_args("v", "i", "none") +def softmax(g: jit_utils.GraphContext, input, dim, dtype=None): + # Softmax does normalization at vector level. + # PyTorch and ONNX use different strategies to split the input tensor into vectors. + # Thus dim and axis have different meanings. + # PyTorch slices the input tensor into vectors along the `dim`-th dimension. + # ONNX reshapes the input into a 2-D tensor, and `axis` indicates where the input is coerced. + # If input is a 2 x 3 tensor: + # input = [[1.0, 1.0, 1.0], + # [1.0, 1,0, 1,0]] + # with dim = 0, the result is: + # result = [[0.5, 0.5, 0.5], + # [0.5, 0.5, 0.5]] + # with axis = 0, the result is: + # result = [[0.167, 0.167, 0.167], + # [0.167, 0.167, 0.167]] + # So only when dim and axis both equal to ndim - 1 (the last dimension), + # their semantics are equivalent. + # So use softmax when dim and axis both equal to ndim - 1, + # otherwise transpose the input to put the vectors to be normalized to the last dimension. 
+ # When input rank is not known at export time we compute softmax using a subgraph + # with other operators + input_dim = symbolic_helper._get_tensor_rank(input) + if input_dim is not None: + # TODO: remove this as onnx opset 11 spec allows negative axes + if dim < 0: + dim = input_dim + dim + + is_transpose_required = input_dim != dim + 1 + + if is_transpose_required: + axes = list(range(input_dim)) + axes[dim], axes[-1] = axes[-1], axes[dim] + input = g.op("Transpose", input, perm_i=axes) + dim = input_dim - 1 + + softmax = g.op("Softmax", input, axis_i=dim) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + softmax = g.op( + "Cast", + softmax, + to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type(), + ) + + if is_transpose_required: + softmax = g.op("Transpose", softmax, perm_i=axes) # type: ignore[possibly-undefined] + return softmax + + # Apply max normalization. + input = g.op("Sub", input, g.op("ReduceMax", input, axes_i=[dim], keepdims_i=1)) + + exp = g.op("Exp", input) + sum = symbolic_helper._reducesum_helper(g, exp, axes_i=[dim]) + softmax = g.op("Div", exp, sum) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + softmax = g.op( + "Cast", softmax, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() + ) + return softmax + + +@_onnx_symbolic("aten::softplus") +def softplus(g: jit_utils.GraphContext, self, beta, threshold): + beta_const = symbolic_helper._maybe_get_const(beta, "f") + if beta_const != 1: + return g.op("Div", g.op("Softplus", g.op("Mul", self, beta)), beta) + return g.op("Softplus", self) + + +@_onnx_symbolic("aten::get_pool_ceil_padding") +def get_pool_ceil_padding(input, kernel_size, stride, padding): + # TODO(justinchuby): Looks like this op is deprecated in torch + sizes = symbolic_helper._get_tensor_sizes(input) + dim = sizes[-len(padding) :] if sizes is not None else None + if dim is None or any(i is None for i in dim): + return symbolic_helper._unimplemented( + "get_pool_ceil_padding", "input size not accessible", input + ) + ceiled_output_dim = [ + int(math.ceil((dim[i] + 2 * padding[i] - kernel_size[i]) / float(stride[i]))) + + 1 + for i in range(0, len(padding)) + ] + # ensure last pooling starts inside + ceiled_output_dim = [ + ( + ceiled_output_dim[i] - 1 + if (((ceiled_output_dim[i] - 1) * stride[i]) >= (dim[i] + padding[i])) + else ceiled_output_dim[i] + ) + for i in range(0, len(ceiled_output_dim)) + ] + padding_ceil = [ + ( + 0 + if (stride[i] == 1) + else ( + kernel_size[i] + - ( + dim[i] + + 2 * padding[i] + - ((ceiled_output_dim[i] - 1) * stride[i] + 1) + ) + ) + ) + for i in range(0, len(padding)) + ] + # ensure padding is not > kernel_size + padding_ceil = [ + ( + ( + int(padding_ceil[i]) + if padding_ceil[i] < kernel_size[i] - 1 + else int(kernel_size[i] - 1) + ) + if ((padding_ceil[i] + 2 * padding[i]) >= (kernel_size[i])) + else int(padding_ceil[i]) + ) + for i in range(0, len(padding_ceil)) + ] + return padding_ceil + + +@_onnx_symbolic( + "aten::max_pool1d", + decorate=[ + symbolic_helper._apply_params( + "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False + ), + _export("max_pool1d"), + ], +) +@_onnx_symbolic( + "aten::max_pool2d", + decorate=[ + symbolic_helper._apply_params( + "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False + ), + _export("max_pool2d"), + ], +) +@_onnx_symbolic( + "aten::max_pool3d", + decorate=[ + symbolic_helper._apply_params( + 
"max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False + ), + _export("max_pool3d"), + ], +) +def _max_pool(name, tuple_fn, ndims, return_indices): + @symbolic_helper.quantized_args(True, False, False, False, False, False) + @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i") + def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode): + if set(tuple_fn(dilation)) != {1}: + return symbolic_helper._unimplemented(name, "dilation", input) + if not stride: + stride = kernel_size + padding = tuple(tuple_fn(padding)) + if ceil_mode: + padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding) + padding = padding + tuple(a + b for (a, b) in zip(padding_ceil, padding)) + else: + padding = padding * 2 + kwargs = { + "kernel_shape_i": tuple_fn(kernel_size), + "pads_i": padding, + "strides_i": tuple_fn(stride), + } + # easy but hacky way to get flattened indices values + # to be used to convert the indices values to non-flattened. + # In ONNX the indices are computed as a flatten 1-D tensor, + # so the values in indices are in [0, N x C x D1 x ... x Dn). + # To convert the indices to the same format used by Pytorch, + # we first execute a maxpool with a kernel and stride of 1 on the same input. + # This will result in a tensor of indices in which each index will have it's own value. + # Using this tensor as a reference, we extract the first index of each axis and subtract + # it from each index of this axis in the indices to convert. + # This step will result in a tensor were each dimension has values of indices within + # the dimension it is in. + # For more information : + # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407 + if return_indices: + r, indices = g.op("MaxPool", input, outputs=2, **kwargs) + _, flattened_indices = g.op( + "MaxPool", + input, + outputs=2, + kernel_shape_i=[1 for _ in range(ndims)], + strides_i=[1 for _ in range(ndims)], + ) + # convert indices to have non-flattened indices values + s = symbolic_helper._slice_helper( + g, + flattened_indices, + axes=[2 + i for i in range(ndims)], + starts=list(tuple_fn(0)), + ends=list(tuple_fn(1)), + ) + indices = sub(g, indices, s) + return r, indices + else: + r = g.op("MaxPool", input, outputs=1, **kwargs) + return r + + return symbolic_fn + + +max_pool1d_with_indices = _onnx_symbolic("aten::max_pool1d_with_indices")( + _max_pool( + "max_pool1d_with_indices", + torch.nn.modules.utils._single, + 1, + return_indices=True, + ) +) +max_pool2d_with_indices = _onnx_symbolic("aten::max_pool2d_with_indices")( + _max_pool( + "max_pool2d_with_indices", + torch.nn.modules.utils._pair, + 2, + return_indices=True, + ) +) +max_pool3d_with_indices = _onnx_symbolic("aten::max_pool3d_with_indices")( + _max_pool( + "max_pool3d_with_indices", + torch.nn.modules.utils._triple, + 3, + return_indices=True, + ) +) + + +@_onnx_symbolic( + "aten::avg_pool1d", + decorate=[ + symbolic_helper._apply_params("avg_pool1d", torch.nn.modules.utils._single), + _export("avg_pool1d"), + ], +) +@_onnx_symbolic( + "aten::avg_pool2d", + decorate=[ + symbolic_helper._apply_params("avg_pool2d", torch.nn.modules.utils._pair), + _export("avg_pool2d"), + ], +) +@_onnx_symbolic( + "aten::avg_pool3d", + decorate=[ + symbolic_helper._apply_params("avg_pool3d", torch.nn.modules.utils._triple), + _export("avg_pool3d"), + ], +) +def _avg_pool(name, tuple_fn): + @symbolic_helper.quantized_args(True) + @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none") + def symbolic_fn( + g, + input: 
_C.Value, + kernel_size: Sequence[int], + stride: Sequence[int], + padding: int | Sequence[int], + ceil_mode: int, + count_include_pad: int, + divisor_override=None, + ): + if not stride: + stride = kernel_size + padding = symbolic_helper._avgpool_helper( + tuple_fn, padding, kernel_size, stride, divisor_override, name + ) + assert isinstance(padding, tuple) + adjusted_padding = padding + # Although onnx::AvgPool provides count_include_pad, + # The corner case of Average Pooling with ceil_mode on + # PyTorch allows sliding window go off bound, which leads to + # this accommodation. + # More detail on https://github.com/pytorch/pytorch/issues/57178 + if count_include_pad: + input = symbolic_helper._op_with_optional_float_cast( + g, + "Pad", + input, + pads_i=((0,) * 2 + padding) * 2, + mode_s="constant", + value_f=0.0, + opset_before=11, + ) + adjusted_padding = (0,) * len(padding) + if ceil_mode: + padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding) + adjusted_padding = adjusted_padding + tuple( + a + b for (a, b) in zip(padding_ceil, adjusted_padding) + ) + else: + adjusted_padding = adjusted_padding * 2 + output = g.op( + "AveragePool", + input, + kernel_shape_i=tuple_fn(kernel_size), + strides_i=tuple_fn(stride), + pads_i=adjusted_padding, + ) + return output + + return symbolic_fn + + +@_onnx_symbolic( + "aten::adaptive_avg_pool1d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_avg_pool1d", "AveragePool", torch.nn.modules.utils._single + ), + _export("adaptive_avg_pool1d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_avg_pool2d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_avg_pool2d", "AveragePool", torch.nn.modules.utils._pair + ), + _export("adaptive_avg_pool2d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_avg_pool3d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_avg_pool3d", "AveragePool", torch.nn.modules.utils._triple + ), + _export("adaptive_avg_pool3d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_max_pool1d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_max_pool1d", + "MaxPool", + torch.nn.modules.utils._single, + max_pool1d_with_indices, + ), + _export("adaptive_max_pool1d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_max_pool2d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_max_pool2d", + "MaxPool", + torch.nn.modules.utils._pair, + max_pool2d_with_indices, + ), + _export("adaptive_max_pool2d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_max_pool3d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_max_pool3d", + "MaxPool", + torch.nn.modules.utils._triple, + max_pool3d_with_indices, + ), + _export("adaptive_max_pool3d"), + ], +) +def _adaptive_pool(name, type, tuple_fn, fn=None): + @symbolic_helper.quantized_args(True, False) + def symbolic_fn(g, input, output_size): + # _adaptive_pool is supported for cases where output_size is 1 for all dimensions, + # by executing a GlobalPool. + # It is also supported for cases where the output size is a factor of the input size. + # For these cases the stride and kernel size are uniform along all the indices of + # the same dimension, which makes it possible to export it to ONNX. 
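+        # For example, a 4 x 6 spatial input with output_size=(2, 3) is exported with
+        # kernel_shape=(2, 2) and strides=(2, 2), whereas a 5 x 6 input with the same
+        # output_size is rejected because 5 is not divisible by 2.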
+ # for MaxPool, GlobalMaxPool does not return indices, + # so we try using max_poolxd_with_indices, and if it is not possible + # (input is not a complete tensor or output size not factor of input size) + # then we call GlobalAveragePool and return None for the indices + output_size_value = output_size + try: + output_size = symbolic_helper._parse_arg(output_size, "is") + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + return symbolic_helper._onnx_unsupported( + "adaptive pooling, since output_size is not constant.", input + ) + if output_size == [1] * len(output_size) and type == "AveragePool": + return g.op("GlobalAveragePool", input) + sizes = symbolic_helper._get_tensor_sizes(input) + try: + dim = sizes[2:] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + dim = None + if dim is None or any(i is None for i in dim): + if output_size == [1] * len(output_size): + return g.op("GlobalMaxPool", input), None + return symbolic_helper._unimplemented( + name, "input size not accessible", input + ) + # verify if output size % input size = 0 for all dim + mod = [dim[i] % output_size[i] for i in range(0, len(dim))] + if mod != [0] * len(mod): + if output_size == [1] * len(output_size): + return g.op("GlobalMaxPool", input), None + return symbolic_helper._unimplemented( + name, "output size that are not factor of input size", output_size_value + ) + k = [int(dim[i] / output_size[i]) for i in range(0, len(dim))] + # call max_poolxd_with_indices to get indices in the output + if type == "MaxPool": + return fn(g, input, k, k, (0,) * len(dim), (1,) * len(dim), False) + output = g.op(type, input, kernel_shape_i=tuple_fn(k), strides_i=tuple_fn(k)) + return output + + return symbolic_fn + + +def _prepare_onnx_paddings(dim: int, pad): + """Generate paddings in ONNX order based on pad in pytorch. + Args: + dim: the dimension of the tensor. + pad: the paddings in pytorch. + The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ... + """ + # The desired order of paddings is + # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end. + # n is the dimension of input. + # assume zero-dimensions in the beginning + paddings = list(pad[:]) + [0] * (dim * 2 - len(pad)) + # reverse order and collate first beginnings and then ends + paddings = paddings[-2::-2] + paddings[-1::-2] + return paddings + + +def _convert_padding_node(input): + padding = symbolic_helper._maybe_get_const(input, "is") + if symbolic_helper._is_value(padding) and symbolic_helper._is_packed_list(padding): + input_list = symbolic_helper._unpack_list(padding) + try: + padding = [ + symbolic_helper._get_const(v, "i", "padding") for v in input_list + ] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + return symbolic_helper._onnx_opset_unsupported_detailed( + "Pad", 9, 11, "The sizes of the padding must be constant", input + ) + return padding + + +@_onnx_symbolic("aten::constant_pad_nd") +def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value): + mode = "constant" + try: + value = symbolic_helper._get_const(value, "f", "value") + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. 
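+        # Before opset 11, ONNX Pad takes its pad sizes and fill value as attributes,
+        # so both must be known constants at export time.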
+ return symbolic_helper._onnx_opset_unsupported_detailed( + "Pad", 9, 11, "The value for the padding must be constant", value + ) + + padding = _convert_padding_node(padding) + paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) + return symbolic_helper._op_with_optional_float_cast( + g, "Pad", input, pads_i=paddings, mode_s=mode, value_f=value, opset_before=11 + ) + + +def _pad_circular(g: jit_utils.GraphContext, input: _C.Value, pad: _C.Value): + padding = _convert_padding_node(pad) + assert len(padding) % 2 == 0 + ndim = len(padding) // 2 + + cur = input + for idx in range(ndim): + pad_r = padding[-(2 * idx + 1)] + pad_l = padding[-(2 * idx + 2)] + tensors = [] + if pad_l > 0: + left = symbolic_helper._slice_helper( + g, cur, axes=[2 + idx], starts=[-(pad_l)], ends=[_constants.INT64_MAX] + ) + tensors.append(left) + + if pad_l < 0 or pad_r < 0: + start = builtins.max(0, -pad_l) + end = -(builtins.max(0, -pad_r)) + middle = symbolic_helper._slice_helper( + g, + cur, + axes=[2 + idx], + starts=[start], + ends=[end], + ) + tensors.append(middle) + else: + tensors.append(cur) + + if pad_r > 0: + right = symbolic_helper._slice_helper( + g, cur, axes=[2 + idx], starts=[0], ends=[pad_r] + ) + tensors.append(right) + + cur = g.op("Concat", *tensors, axis_i=(2 + idx)) + + return cur + + +@_onnx_symbolic("aten::reflection_pad1d") +@_onnx_symbolic("aten::reflection_pad2d") +@_onnx_symbolic("aten::reflection_pad3d") +def reflection_pad(g: jit_utils.GraphContext, input, padding): + mode = "reflect" + padding = _convert_padding_node(padding) + paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) + return symbolic_helper._op_with_optional_float_cast( + g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11 + ) + + +@_onnx_symbolic("aten::replication_pad1d") +@_onnx_symbolic("aten::replication_pad2d") +@_onnx_symbolic("aten::replication_pad3d") +def replication_pad(g: jit_utils.GraphContext, input, padding): + mode = "edge" + padding = _convert_padding_node(padding) + paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) + return symbolic_helper._op_with_optional_float_cast( + g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11 + ) + + +@_onnx_symbolic("aten::pad") +def pad( + g: jit_utils.GraphContext, + input: _C.Value, + pad: _C.Value, + mode: _C.Value, + value: _C.Value, +): + mode = symbolic_helper._parse_arg(mode, "s") + if mode == "replicate": + return replication_pad(g, input, pad) + elif mode == "reflect": + return reflection_pad(g, input, pad) + elif mode == "constant": + return constant_pad_nd(g, input, pad, value) + elif mode == "circular": + return _pad_circular(g, input, pad) + else: + raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input) + + +@_onnx_symbolic( + "aten::upsample_nearest1d", + decorate=[ + symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest"), + _export("upsample_nearest1d"), + ], +) +@_onnx_symbolic( + "aten::upsample_nearest2d", + decorate=[ + symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest"), + _export("upsample_nearest2d"), + ], +) +@_onnx_symbolic( + "aten::upsample_nearest3d", + decorate=[ + symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest"), + _export("upsample_nearest3d"), + ], +) +@_onnx_symbolic( + "aten::upsample_linear1d", + decorate=[ + symbolic_helper._apply_params("upsample_linear1d", 3, "linear"), + _export("upsample_linear1d"), + ], +) +@_onnx_symbolic( + 
"aten::upsample_bilinear2d", + decorate=[ + symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear"), + _export("upsample_bilinear2d"), + ], +) +@_onnx_symbolic( + "aten::upsample_trilinear3d", + decorate=[ + symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear"), + _export("upsample_trilinear3d"), + ], +) +def _interpolate(name: str, dim: int, interpolate_mode: str): + def symbolic_fn(g, input, output_size, *args): + scales, align_corners = symbolic_helper._get_interpolate_attributes( + g, interpolate_mode, args + ) + symbolic_helper._interpolate_warning(interpolate_mode) + align_corners = symbolic_helper._maybe_get_scalar(align_corners) + if align_corners: + return symbolic_helper._unimplemented(name, "align_corners == True", input) + if scales is None: + scales = symbolic_helper._interpolate_size_to_scales( + g, input, output_size, dim + ) + return g.op("Upsample", input, scales, mode_s=interpolate_mode) + + return symbolic_fn + + +@_onnx_symbolic("aten::__interpolate") +def __interpolate( + g: jit_utils.GraphContext, + input, + size, + scale_factor, + mode, + align_corners, + recompute_scale_factor, + antialias, +): + scales, mode = symbolic_helper._interpolate_get_scales_and_mode( + g, input, size, scale_factor, mode, align_corners + ) + return g.op("Upsample", input, scales, mode_s=mode) + + +@_onnx_symbolic("aten::bitwise_not") +def bitwise_not(g: jit_utils.GraphContext, input): + if not symbolic_helper._is_bool(input): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise Not " + "for non-boolean input values", + input, + ) + return g.op("Not", input) + + +@_onnx_symbolic("aten::bitwise_or") +def bitwise_or(g, self, other): + if not symbolic_helper._is_bool(self): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise OR " + "for non-boolean input values. self: ", + self, + ) + if not symbolic_helper._is_bool(other): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise OR " + "for non-boolean input values. other: ", + other, + ) + return g.op("Or", self, other) + + +def wrap_logical_op_with_cast_to(to_type): + def decorator(fn): + @functools.wraps(fn) + def wrap_with_cast(g, input, other): + to_cast_func = globals()[f"_cast_{to_type}"] + return fn(g, to_cast_func(g, input, False), to_cast_func(g, other, False)) + + return wrap_with_cast + + return decorator + + +def wrap_logical_op_with_negation(func: Callable) -> Callable: + @functools.wraps(func) + def wrap_with_not(g, input, other): + return g.op("Not", func(g, input, other)) + + return wrap_with_not + + +@_onnx_symbolic("aten::__not_") +def __not_(g: jit_utils.GraphContext, self): + if not symbolic_helper._is_bool(self): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise Not " + "for non-boolean input values", + self, + ) + return g.op("Not", self) + + +@_onnx_symbolic("aten::eq") +@symbolic_helper.quantized_args(True, True) +def eq(g: jit_utils.GraphContext, self, other): + if isinstance(self.type(), _C.DeviceObjType) and isinstance( + other.type(), _C.DeviceObjType + ): + # ONNX doesn't have devices, so consider them all to be equal. + # The no-op check for equality will get constant-folded. 
+ return g.op("Constant", value_t=torch.tensor(True, dtype=torch.bool)) + self_node = self.node() + other_node = other.node() + if self_node.kind() == other_node.kind() == "onnx::Constant": + if self_node.kindOf("value") == other_node.kindOf("value") == "s": + # Exporting strings to ONNX is not supported. + # If both strings are constant, we can compare them directly. + # The no-op check for equality will get constant-folded. + return g.op( + "Constant", + value_t=torch.tensor( + self_node.s("value") == other_node.s("value"), + dtype=torch.bool, + ), + ) + + return g.op("Equal", self, other) + + +@_onnx_symbolic("aten::ne") +@symbolic_helper.quantized_args(True, True) +@wrap_logical_op_with_negation +def ne(g: jit_utils.GraphContext, self, other): + return eq(g, self, other) + + +@_onnx_symbolic("aten::gt") +@symbolic_helper.quantized_args(True, True) +def gt(g: jit_utils.GraphContext, input, other): + return _gt_impl(g, input, other) + + +def _gt_impl(g: jit_utils.GraphContext, input, other): + if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other): + input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) + other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.INT32) + return g.op("Greater", input, other) + + +@_onnx_symbolic("aten::lt") +@symbolic_helper.quantized_args(True, True) +def lt(g: jit_utils.GraphContext, input, other): + return _lt_impl(g, input, other) + + +def _lt_impl(g: jit_utils.GraphContext, input, other): + if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other): + input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) + other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.INT32) + return g.op("Less", input, other) + + +@_onnx_symbolic("aten::ge") +@symbolic_helper.quantized_args(True, True) +@wrap_logical_op_with_negation +def ge(g: jit_utils.GraphContext, input, other): + return _lt_impl(g, input, other) + + +@_onnx_symbolic("aten::le") +@symbolic_helper.quantized_args(True, True) +@wrap_logical_op_with_negation +def le(g: jit_utils.GraphContext, input, other): + return _gt_impl(g, input, other) + + +@_onnx_symbolic("aten::__and_") +def __and_(g: jit_utils.GraphContext, input, other): + if not symbolic_helper._is_bool(input): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise AND " + "for non-boolean input values", + input, + ) + if not symbolic_helper._is_bool(other): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise AND " + "for non-boolean input values", + other, + ) + return g.op("And", input, other) + + +@_onnx_symbolic("aten::__or_") +def __or_(g: jit_utils.GraphContext, input, other): + if not symbolic_helper._is_bool(input): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise OR " + "for non-boolean input values", + input, + ) + if not symbolic_helper._is_bool(other): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise OR " + "for non-boolean input values", + other, + ) + return g.op("Or", input, other) + + +@_onnx_symbolic("aten::__xor_") +def __xor_(g: jit_utils.GraphContext, input, other): + if not symbolic_helper._is_bool(input): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise XOR " + "for non-boolean input values", + input, + ) + if not symbolic_helper._is_bool(other): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise XOR " + "for non-boolean input values", + other, + ) + return 
g.op("Xor", input, other) + + +@_onnx_symbolic("aten::logical_and") +@wrap_logical_op_with_cast_to("Bool") +def logical_and(g: jit_utils.GraphContext, input, other): + return g.op("And", input, other) + + +@_onnx_symbolic("aten::logical_or") +@wrap_logical_op_with_cast_to("Bool") +def logical_or(g: jit_utils.GraphContext, input, other): + return g.op("Or", input, other) + + +@_onnx_symbolic("aten::logical_xor") +@wrap_logical_op_with_cast_to("Bool") +def logical_xor(g: jit_utils.GraphContext, input, other): + return g.op("Xor", input, other) + + +@_onnx_symbolic("aten::logical_not") +def logical_not(g: jit_utils.GraphContext, input): + return g.op("Not", g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL)) + + +@_onnx_symbolic("aten::__rshift_") +def __rshift_(g: jit_utils.GraphContext, self, other): + # make sure to cast other to self's type + # (when self is long, make sure that other is not float) + self_scalar_type = _type_utils.JitScalarType.from_value(self) + if ( + _type_utils.JitScalarType.from_value(other, _type_utils.JitScalarType.UNDEFINED) + != self_scalar_type + ): + other = g.op( + "Cast", + other, + to_i=self_scalar_type.onnx_type(), + ) + + two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) + # exponent (same type as self) has to be float or double in onnx::Pow + if not symbolic_helper._is_fp(self): + other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) + two_pow = g.op("Pow", two, other) + two_pow = g.op( + "Cast", + two_pow, + to_i=self_scalar_type.onnx_type(), + ) + rshift = g.op("Div", self, two_pow) + return rshift + + +@_onnx_symbolic("aten::__lshift_") +def __lshift_(g: jit_utils.GraphContext, self, other): + # make sure to cast other to self's type + # (when self is long, make sure that other is not float) + self_scalar_type = _type_utils.JitScalarType.from_value(self) + if ( + _type_utils.JitScalarType.from_value(other, _type_utils.JitScalarType.UNDEFINED) + != self_scalar_type + ): + other = g.op( + "Cast", + other, + to_i=self_scalar_type.onnx_type(), + ) + + two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) + # exponent (same type as self) has to be float or double in onnx::Pow + if not symbolic_helper._is_fp(self): + other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) + two_pow = g.op("Pow", two, other) + two_pow = g.op( + "Cast", + two_pow, + to_i=self_scalar_type.onnx_type(), + ) + lshift = g.op("Mul", self, two_pow) + return lshift + + +@_onnx_symbolic("aten::where") +@symbolic_helper.parse_args("v", "v", "v", "i") +def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None): + # Assumes that torch.where's first argument takes only Bool and Byte tensors. + if not symbolic_helper._is_bool(condition): + condition = g.op("Cast", condition, to_i=_C_onnx.TensorProtoDataType.BOOL) + if self is None: + condition = nonzero(g, condition) + return symbolic_helper._unbind_helper( + g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs + ) + return g.op("Where", condition, self, other) + + +@_onnx_symbolic("aten::log_softmax") +@symbolic_helper.parse_args("v", "i", "none") +def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None): + # PyTorch dim and ONNX axis have different meanings. + # See Softmax comment for details. 
+ # TODO: remove this as onnx opset 11 spec allows negative axes + input_dim = symbolic_helper._get_tensor_rank(input) + if input_dim is None: + return symbolic_helper._unimplemented( + "dim", + "ONNX and PyTorch use different strategies to split the input. " + "Input rank must be known at export time.", + ) + if dim < 0: + dim = input_dim + dim + is_transpose_required = input_dim != dim + 1 + # ONNX only supports log_softmax with dim = -1. Transpose must be added before and after log_softmax to support other cases. + if is_transpose_required: + axes = list(range(input_dim)) + axes[dim], axes[-1] = axes[-1], axes[dim] + input = g.op("Transpose", input, perm_i=axes) + dim = input_dim - 1 + return_op = g.op("LogSoftmax", input, axis_i=dim) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + return_op = g.op( + "Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() + ) + if is_transpose_required: + return_op = g.op("Transpose", return_op, perm_i=axes) # type: ignore[possibly-undefined] + return return_op + + +@_onnx_symbolic("aten::_log_softmax") +@symbolic_helper.parse_args("v", "i", "i") +def _log_softmax(g: jit_utils.GraphContext, input, dim, half_to_float): + if ( + half_to_float + and _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.UNDEFINED + ) + == _type_utils.JitScalarType.HALF + ): + input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT) + return log_softmax(g, input, dim) + + +@_onnx_symbolic("aten::_convolution") +@symbolic_helper.parse_args( + "v", "v", "v", "is", "is", "is", "i", "is", "i", "i", "i", "i", "i" +) +def _convolution( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + benchmark, + deterministic, + cudnn_enabled, + allow_tf32=None, +): + weight_size = symbolic_helper._get_tensor_sizes(weight) + try: + kernel_shape = weight_size[2:] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + kernel_shape = None + + if kernel_shape is None or any(i is None for i in kernel_shape): + raise errors.SymbolicValueError( + "Unsupported: ONNX export of convolution for kernel of unknown shape.", + input, + ) + + args = [input, weight] + # ONNX only supports 1D bias + if ( + not symbolic_helper._is_none(bias) + and symbolic_helper._get_tensor_rank(bias) == 1 + ): + args.append(bias) + + kwargs = { + "kernel_shape_i": weight_size[2:], + "strides_i": stride, + # NB: ONNX supports asymmetric padding, whereas PyTorch supports only + # symmetric padding + "pads_i": padding + padding, + "dilations_i": dilation, + "group_i": groups, + } + + if any(o != 0 for o in output_padding): + # ONNX supports both output_shape and output_padding. they are equivalent expressive. + # output_padding is more straightforward, so we use it here. 
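+        # For instance, stride=2, input_shape=5, output_padding=1, kernel_shape=3 and
+        # padding=1 give an output size of 2 * 4 + 1 + 3 - 2 = 10 under the relation below: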
+ # output_shape = stride * (input_shape - 1) + output_padding + kernel_shape - padding * 2 + assert transposed + assert len(stride) == len(output_padding) + kwargs["output_padding_i"] = output_padding + + n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs) + + if ( + not symbolic_helper._is_none(bias) + and symbolic_helper._get_tensor_rank(bias) != 1 + ): + return g.op("Add", n, bias) + else: + return n + + +@_onnx_symbolic("aten::_convolution_mode") +@symbolic_helper.parse_args( + "v", + "v", + "v", + "is", + "s", + "is", + "i", +) +def _convolution_mode( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + dilation, + groups, +): + weight_size = symbolic_helper._get_tensor_sizes(weight) + try: + kernel_shape = weight_size[2:] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + kernel_shape = None + + if kernel_shape is None or any(i is None for i in kernel_shape): + raise errors.SymbolicValueError( + "Unsupported: ONNX export of convolution for kernel of unknown shape.", + input, + ) + + args = [input, weight] + # ONNX only supports 1D bias + if ( + not symbolic_helper._is_none(bias) + and symbolic_helper._get_tensor_rank(bias) == 1 + ): + args.append(bias) + + if padding == "valid": + padding = "VALID" + elif padding == "same": + padding = "SAME_UPPER" + kwargs = { + "kernel_shape_i": weight_size[2:], + "strides_i": stride, + "auto_pad_s": padding, + "dilations_i": dilation, + "group_i": groups, + } + + n = g.op("Conv", *args, **kwargs) + + if ( + not symbolic_helper._is_none(bias) + and symbolic_helper._get_tensor_rank(bias) != 1 + ): + return g.op("Add", n, bias) + else: + return n + + +@_onnx_symbolic("aten::convolution") +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is", "i") +def convolution( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, +): + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv1d") +@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i") +def conv1d( + g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups +): + str_padding = symbolic_helper._parse_arg(padding, "s") + if str_padding in ["valid", "same"]: + return _convolution_mode( + g, + input, + weight, + bias, + stride, + str_padding, + dilation, + groups, + ) + else: + padding = symbolic_helper._parse_arg(padding, "is") + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + False, + (), + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv2d") +@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i") +def conv2d( + g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups +): + str_padding = symbolic_helper._parse_arg(padding, "s") + if str_padding in ["valid", "same"]: + return _convolution_mode( + g, + input, + weight, + bias, + stride, + str_padding, + dilation, + groups, + ) + else: + padding = symbolic_helper._parse_arg(padding, "is") + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + False, + (), + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv3d") +@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i") +def conv3d( + g: jit_utils.GraphContext, 
input, weight, bias, stride, padding, dilation, groups +): + str_padding = symbolic_helper._parse_arg(padding, "s") + if str_padding in ["valid", "same"]: + return _convolution_mode( + g, + input, + weight, + bias, + stride, + str_padding, + dilation, + groups, + ) + else: + padding = symbolic_helper._parse_arg(padding, "is") + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + False, + (), + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv_transpose1d") +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") +def conv_transpose1d( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + output_padding, + groups, + dilation, +): + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + True, + output_padding, + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv_transpose2d") +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") +def conv_transpose2d( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + output_padding, + groups, + dilation, +): + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + True, + output_padding, + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv_transpose3d") +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") +def conv_transpose3d( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + output_padding, + groups, + dilation, +): + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + True, + output_padding, + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::batch_norm") +@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i") +def batch_norm( + g: jit_utils.GraphContext, + input, + weight, + bias, + running_mean, + running_var, + training, + momentum, + eps, + cudnn_enabled, +): + symbolic_helper.check_training_mode(training, "batch_norm") + + if ( + torch.is_autocast_enabled() + and not symbolic_helper.args_have_same_dtype( + [input, weight, bias, running_mean, running_var] + ) + and GLOBALS.export_onnx_opset_version < 15 + ): + return symbolic_helper._onnx_opset_unsupported_detailed( + "BatchNormalization", + 9, + 15, + "All input tensors must have the same `dtype`." 
+ " Turn off Autocast or export using opset version 15.", + input, + ) + + weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper( + g, input, weight, bias, running_mean, running_var + ) + out = g.op( + "BatchNormalization", + input, + weight, + bias, + running_mean, + running_var, + epsilon_f=eps, + momentum_f=1 - momentum, + outputs=1 if not training else 5, + ) + if not training: + return out + else: + res, new_running_mean, new_running_var, saved_mean, saved_var = out + new_running_mean.setType(running_mean.type()) + new_running_var.setType(running_var.type()) + saved_mean.setDebugName("batch_norm_dead_output-" + saved_mean.debugName()) + saved_var.setDebugName("batch_norm_dead_output-" + saved_var.debugName()) + return res + + +@_onnx_symbolic("aten::native_layer_norm") +@symbolic_helper.quantized_args(True, False, False, False) +@symbolic_helper.parse_args("v", "is", "v", "v", "f") +def native_layer_norm( + g: jit_utils.GraphContext, + input: _C.Value, + normalized_shape: Sequence[int], + weight: _C.Value, + bias: _C.Value, + eps: float, +) -> tuple[_C.Value, _C.Value, _C.Value]: + axes = [-i for i in range(len(normalized_shape), 0, -1)] + + two_cst = symbolic_helper._generate_wrapped_number(g, 2.0) + eps_cst = symbolic_helper._generate_wrapped_number(g, eps) + + if g.opset < 18: + mean = g.op("ReduceMean", input, axes_i=axes) + else: + mean = g.op( + "ReduceMean", + input, + g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)), + ) + + numerator = sub(g, input, mean) + + # Cast it to eps dtype to avoid precision loss + is_type_half = ( + _type_utils.JitScalarType.from_value(numerator) + == _type_utils.JitScalarType.HALF + ) + if is_type_half: + eps_dtype = _type_utils.JitScalarType.from_value(eps_cst) + numerator = g.op( + "Cast", numerator, to_i=_type_utils.JitScalarType(eps_dtype).onnx_type() + ) + + # variance = e((x - e(x))^2), and (x - e(x)) is the numerator in the layer_norm formula + if g.opset < 18: + variance = g.op("ReduceMean", pow(g, numerator, two_cst), axes_i=axes) + else: + variance = g.op( + "ReduceMean", + pow(g, numerator, two_cst), + g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)), + ) + + denominator = sqrt(g, g.op("Add", variance, eps_cst)) + normalized = g.op("Div", numerator, denominator) + + # Cast back to input type as eps related ops are all done + if is_type_half: + input_dtype = _type_utils.JitScalarType.from_value(input) + normalized = g.op( + "Cast", normalized, to_i=_type_utils.JitScalarType(input_dtype).onnx_type() + ) + + if not (weight is None or symbolic_helper._is_none(weight)): + normalized = mul(g, normalized, weight) + if not (bias is None or symbolic_helper._is_none(bias)): + normalized = add(g, normalized, bias) + + # rdenominator := 1 / sqrt(variance + eps) + # According to aten::native_layer_norm, rdenominator should have the same dtype as input, + # mean and normalized, so we need to Cast it back + if is_type_half: + denominator = g.op( + "Cast", + denominator, + to_i=_type_utils.JitScalarType(input_dtype).onnx_type(), # type: ignore[possibly-undefined] + ) + rdenominator = g.op("Reciprocal", denominator) + else: + rdenominator = reciprocal(g, denominator) + + return normalized, mean, rdenominator + + +@_onnx_symbolic("aten::layer_norm") +@symbolic_helper.quantized_args(True, False, False, False) +@symbolic_helper.parse_args("v", "is", "v", "v", "f", "b") +def layer_norm( + g: jit_utils.GraphContext, + input: _C.Value, + normalized_shape: Sequence[int], + weight: _C.Value, + bias: _C.Value, + 
eps: float, + cudnn_enable: bool, +) -> _C.Value: + normalized, _, _ = native_layer_norm(g, input, normalized_shape, weight, bias, eps) + return normalized + + +@_onnx_symbolic("aten::instance_norm") +@symbolic_helper.parse_args("v", "v", "v", "v", "v", "b", "f", "f", "b") +def instance_norm( + g: jit_utils.GraphContext, + input, + weight, + bias, + running_mean, + running_var, + use_input_stats: bool, + momentum: Number, + eps: Number, + cudnn_enabled: bool, +): + symbolic_helper.check_training_mode(use_input_stats, "instance_norm") + channel_size = symbolic_helper._get_tensor_dim_size(input, 1) + if weight is None or symbolic_helper._is_none(weight): + if channel_size is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of instance_norm for unknown channel size.", + input, + ) + weight_value = torch.tensor( + [1.0] * channel_size, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ) + weight = g.op("Constant", value_t=weight_value) + if bias is None or symbolic_helper._is_none(bias): + if channel_size is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of instance_norm for unknown channel size.", + input, + ) + bias_value = torch.tensor( + [0.0] * channel_size, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ) + bias = g.op("Constant", value_t=bias_value) + if ( + running_mean is None + or symbolic_helper._is_none(running_mean) + or running_var is None + or symbolic_helper._is_none(running_var) + ): + return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps) + else: + input_size = symbolic_helper._get_tensor_sizes(input) + # If input shape is [N, C, H, W], reshape to [1, N * C, H, W] and call batch_norm. + # For more information instance_norm(): + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L542 + input_size_reshape = input_size.copy() + n = input_size[0] + if n is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of instance_norm training for unknown " + "batch size.", + input, + ) + c = input_size[1] + input_size_reshape[0] = 1 + input_size_reshape[1] = n * c + weight_ = repeat( + g, weight, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)) + ) + bias_ = repeat( + g, bias, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)) + ) + running_mean_ = repeat( + g, + running_mean, + g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)), + ) + running_var_ = repeat( + g, + running_var, + g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)), + ) + input_reshaped = g.op( + "Reshape", + input, + g.op("Constant", value_t=torch.LongTensor(input_size_reshape)), + ) + out = batch_norm( + g, + input_reshaped, + weight_, + bias_, + running_mean_, + running_var_, + use_input_stats, + momentum, + eps, + cudnn_enabled, + ) + return view(g, out, g.op("Constant", value_t=torch.tensor(input_size))) + + +@_onnx_symbolic("aten::unfold") +@symbolic_helper.parse_args("v", "i", "i", "i") +def unfold(g: jit_utils.GraphContext, input, dimension, size, step): + sizes = symbolic_helper._get_tensor_sizes(input) + # FIXME(justinchuby): Get rid of the try catch here to improve readability + try: + sizedim = sizes[dimension] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. 
+ sizedim = None + if sizedim is not None: + low_indices = range(0, sizedim, step) + hi_indices = range(size, sizedim + 1, step) + stack = [ + symbolic_helper._slice_helper( + g, input, axes=[dimension], starts=[low], ends=[hi] + ) + for low, hi in zip(low_indices, hi_indices) + ] + ndim = len(sizes) + perm = list(range(0, ndim)) + perm.append(perm.pop(dimension)) + unsqueeze = [ + symbolic_helper._unsqueeze_helper( + g, g.op("Transpose", t, perm_i=perm), [dimension] + ) + for t in stack + ] + return g.op("Concat", *unsqueeze, axis_i=dimension) + else: + return symbolic_helper._unimplemented( + "Unfold", "input size not accessible", input + ) + + +@_onnx_symbolic("aten::elu") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "t", "t", "t") +def elu(g: jit_utils.GraphContext, input, alpha, scale, input_scale): + if scale and scale != 1.0: + return symbolic_helper._unimplemented( + "scale", "does not support scale in Elu", scale + ) + if input_scale and input_scale != 1.0: + return symbolic_helper._unimplemented( + "input_scale", "does not support input_scale in Elu", input_scale + ) + # See Note [Export inplace] + return g.op("Elu", input, alpha_f=symbolic_helper._scalar(alpha)) + + +@_onnx_symbolic("aten::selu") +@symbolic_helper.quantized_args(True) +def selu(g: jit_utils.GraphContext, input): + return g.op("Selu", input) + + +@_onnx_symbolic("aten::index_select") +@symbolic_helper.parse_args("v", "i", "v") +def index_select(g: jit_utils.GraphContext, self, dim, index): + # In case of a scalar index, index_select returns a tensor with the same rank as the input. + # To match this behavior in ONNX, we make index a 1D tensor so that the following gather + # also produces a tensor with the same rank as the input. + return symbolic_helper._select_helper(g, self, dim, index) + + +@_onnx_symbolic("aten::index_put") +def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accumulate): + if symbolic_helper._is_packed_list(indices_list_value): + indices_list = symbolic_helper._unpack_list(indices_list_value) + else: + indices_list = [indices_list_value] + + accumulate = symbolic_helper._parse_arg(accumulate, "b") + + if len(indices_list) == 0: + if accumulate: + return add(g, self, values) + return values + symbolic_helper._onnx_opset_unsupported("index_put", 9, 11, self) + + +@_onnx_symbolic("aten::index_fill") +def index_fill(g: jit_utils.GraphContext, self, dim, index, value): + dim_value = symbolic_helper._parse_arg(dim, "i") + expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( + g, self, dim, index + ) + value = symbolic_helper._maybe_get_scalar(value) + value = symbolic_helper._if_scalar_type_as(value, self) + expanded_value = expand(g, value, expanded_index_shape, None) + + return scatter(g, self, dim, expanded_index, expanded_value) + + +@_onnx_symbolic("aten::index_copy") +def index_copy(g: jit_utils.GraphContext, self, dim, index, source): + dim_value = symbolic_helper._parse_arg(dim, "i") + expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( + g, self, dim, index + ) + return scatter(g, self, dim, expanded_index, source) + + +@_onnx_symbolic("aten::bucketize") +@symbolic_helper.parse_args("v", "v", "b", "b") +def bucketize( + g: jit_utils.GraphContext, self, boundaries, out_int32=False, right=False +): + out_type = _C_onnx.TensorProtoDataType.INT64 + if out_int32: + out_type = _C_onnx.TensorProtoDataType.INT32 + # A tensor expanded_boundaries is created such that it + # contains a 
copy of boundaries for each element of self. + new_shape = g.op("Concat", g.op("Shape", boundaries), g.op("Shape", self), axis_i=0) + # Unsqueeze step is performed to respect ONNX's numpy style broadcasting for comparison ops + # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md + tensor_rank = symbolic_helper._get_tensor_rank(self) + assert tensor_rank is not None + unsqueeze_axes = list(range(1, tensor_rank + 1)) + expanded_boundaries = expand( + g, + symbolic_helper._unsqueeze_helper(g, boundaries, unsqueeze_axes), + new_shape, + None, + ) + # Compare each element of self to boundaries to get a tensor + # with leading 1s and trailing 0s. + # e.g., 4 > [1, 3, 4] = [1, 1, 0] + # The index of the last 1 is the bucket where the element should go. + if right: + cond = ge(g, self, expanded_boundaries) + else: + cond = gt(g, self, expanded_boundaries) + cond_out = g.op("Cast", cond, to_i=out_type) + # Sum to get the number of 1s corresponding to each element, + # which is the same as the bucket index. + # e.g., sum(4 > [1, 3, 4]) = sum([1, 1, 0]) = 2 + return symbolic_helper._reducesum_helper(g, cond_out, axes_i=[0], keepdims_i=0) + + +@_onnx_symbolic("aten::type_as") +def type_as(g: jit_utils.GraphContext, self, other): + self_dtype = symbolic_helper._try_get_scalar_type(self) + other_dtype = symbolic_helper._try_get_scalar_type(other) + if self_dtype == other_dtype and self_dtype is not None: + return self + if other_dtype is not None: + return g.op( + "Cast", + self, + to_i=other_dtype.onnx_type(), + ) + + raise errors.SymbolicValueError( + "Unsupported: ONNX export of type_as for tensor " + "of unknown dtype. Please check if the dtype of the " + "parameter passed to the type_as function is correct.", + other, + ) + + +@_onnx_symbolic("aten::cosine_similarity") +@symbolic_helper.parse_args("v", "v", "i", "f") +def cosine_similarity(g: jit_utils.GraphContext, x1, x2, dim, eps): + cross = symbolic_helper._reducesum_helper( + g, mul(g, x1, x2), axes_i=[dim], keepdims_i=0 + ) + x1_l2 = symbolic_helper._reducesum_helper( + g, mul(g, x1, x1), axes_i=[dim], keepdims_i=0 + ) + x2_l2 = symbolic_helper._reducesum_helper( + g, mul(g, x2, x2), axes_i=[dim], keepdims_i=0 + ) + div_tens = max( + g, sqrt(g, mul(g, x1_l2, x2_l2)), g.op("Constant", value_t=torch.tensor([eps])) + ) + return div(g, cross, div_tens) + + +@_onnx_symbolic("aten::pairwise_distance") +def pairwise_distance(g: jit_utils.GraphContext, input1, input2, p, eps, keepdim): + if not symbolic_helper._is_value(eps): + eps = g.op("Constant", value_t=torch.tensor([eps])) + inv_p = div( + g, + g.op("Constant", value_t=torch.tensor([1], dtype=torch.float)), + add(g, p, eps), + ) + summation = symbolic_helper._reducesum_helper( + g, + pow(g, sub(g, input1, input2), p), + axes_i=[-1], + keepdims_i=symbolic_helper._parse_arg(keepdim, "i"), + ) + return pow(g, summation, inv_p) + + +@_onnx_symbolic("aten::clone") +# ignore clone operators that are inserted by PyTorch autograd +def clone(g: jit_utils.GraphContext, input, unused_memory_format): + return input + + +@_onnx_symbolic("aten::abs") +def abs(g: jit_utils.GraphContext, self): + return g.op("Abs", self) + + +@_onnx_symbolic("aten::log") +def log(g: jit_utils.GraphContext, self): + return g.op("Log", self) + + +@_onnx_symbolic("aten::log1p") +def log1p(g: jit_utils.GraphContext, self): + return log(g, add(g, symbolic_helper._if_scalar_type_as(torch.ones(1), self), self)) + + +@_onnx_symbolic("aten::log10") +def log10(g: jit_utils.GraphContext, self): + _ln10 = 2.30258509299404568401 
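+    # log10(x) = ln(x) / ln(10)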
+ return g.op("Div", log(g, self), g.op("Constant", value_t=torch.tensor([_ln10]))) + + +@_onnx_symbolic("aten::pow") +def pow(g: jit_utils.GraphContext, self, exponent): + f_dtype = _type_utils.JitScalarType.from_value(self) + if not symbolic_helper._is_fp(self): + f_dtype = _type_utils.JitScalarType.FLOAT + self = g.op("Cast", self, to_i=f_dtype.onnx_type()) + if not symbolic_helper._is_fp(exponent): + exponent = g.op( + "Cast", + exponent, + to_i=f_dtype.onnx_type(), + ) + pow = g.op("Pow", self, exponent) + return pow + + +@_onnx_symbolic("aten::clamp") +def clamp(g: jit_utils.GraphContext, self, min, max): + # min or max may be None that we need to dispatch to + # Clip separately, as ONNX does not have None syntax + if symbolic_helper._is_none(min): + return clamp_max(g, self, max) + elif symbolic_helper._is_none(max): + return clamp_min(g, self, min) + else: + if symbolic_helper._is_constant(min) and symbolic_helper._is_constant(max): + return symbolic_helper._op_with_optional_float_cast( + g, + "Clip", + self, + min_f=symbolic_helper._parse_arg(min, "f"), + max_f=symbolic_helper._parse_arg(max, "f"), + opset_before=12, + ) + else: + return clamp_max(g, clamp_min(g, self, min), max) + + +@_onnx_symbolic("aten::clamp_min") +@symbolic_helper.parse_args("v", "v") +def clamp_min(g: jit_utils.GraphContext, self, min): + if symbolic_helper._is_constant(min): + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min_f=symbolic_helper._parse_arg(min, "f"), opset_before=12 + ) + else: + dtype = _type_utils.JitScalarType.from_value(self) + min = g.op("Cast", min, to_i=dtype.onnx_type()) + return symbolic_helper._op_with_optional_float_cast( + g, "Max", self, min, opset_before=12 + ) + + +@_onnx_symbolic("aten::clamp_max") +@symbolic_helper.parse_args("v", "v") +def clamp_max(g: jit_utils.GraphContext, self, max): + if symbolic_helper._is_constant(max): + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, max_f=symbolic_helper._parse_arg(max, "f"), opset_before=12 + ) + else: + dtype = _type_utils.JitScalarType.from_value(self) + max = g.op("Cast", max, to_i=dtype.onnx_type()) + return symbolic_helper._op_with_optional_float_cast( + g, "Min", self, max, opset_before=12 + ) + + +@_onnx_symbolic("aten::max") +# torch.max (same for torch.min) actually has two interfaces smashed together: +# torch.max(x, dim, keepdim) and torch.max(x, y) +# TODO(justinchuby): Support multiple quantized args in output +def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + return symbolic_helper._max_helper(g, self, dim_or_y, keepdim) + + +@_onnx_symbolic("aten::maximum") +@symbolic_helper.quantized_args(True, True) +def maximum(g: jit_utils.GraphContext, input, other): + return max(g, input, dim_or_y=other) + + +@_onnx_symbolic("aten::min") +# TODO(justinchuby): Support multiple quantized args in output +def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + return symbolic_helper._min_helper(g, self, dim_or_y, keepdim) + + +@_onnx_symbolic("aten::minimum") +@symbolic_helper.quantized_args(True, True) +def minimum(g: jit_utils.GraphContext, input, other): + return min(g, input, dim_or_y=other) + + +@_onnx_symbolic("aten::amax") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "is", "i") +def amax(g: jit_utils.GraphContext, self, dim, keepdim): + return g.op("ReduceMax", self, axes_i=dim, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::amin") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", 
"is", "i") +def amin(g: jit_utils.GraphContext, self, dim, keepdim): + return g.op("ReduceMin", self, axes_i=dim, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::aminmax") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v", "i") +def aminmax(g: jit_utils.GraphContext, self, dim, keepdim): + reduce_kwargs = {"keepdims_i": keepdim} + if not symbolic_helper._is_none(dim): + dim = symbolic_helper._get_const(dim, "i", "dim") + reduce_kwargs["axes_i"] = [dim] + + return g.op("ReduceMin", self, **reduce_kwargs), g.op( + "ReduceMax", self, **reduce_kwargs + ) + + +@_onnx_symbolic("aten::exp") +def exp(g: jit_utils.GraphContext, self): + return g.op("Exp", self) + + +@_onnx_symbolic("aten::dropout_") +@_onnx_symbolic("aten::dropout") +@symbolic_helper.parse_args("v", "f", "i") +def dropout(g: jit_utils.GraphContext, input, p, train): + symbolic_helper.check_training_mode(train, "dropout") + # if train is False, dropout is no-op + if not train: + return input + r, _ = g.op("Dropout", input, ratio_f=p, outputs=2) + return r + + +@_onnx_symbolic( + "aten::alpha_dropout_", + decorate=[symbolic_helper._apply_params("aten::alpha_dropout_")], +) # See Note [Export inplace] +@_onnx_symbolic( + "aten::feature_alpha_dropout_", + decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout_")], +) +@_onnx_symbolic( + "aten::feature_dropout_", + decorate=[symbolic_helper._apply_params("aten::feature_dropout_")], +) +@_onnx_symbolic( + "aten::feature_alpha_dropout", + decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout")], +) +@_onnx_symbolic( + "aten::alpha_dropout", + decorate=[symbolic_helper._apply_params("aten::alpha_dropout")], +) +@_onnx_symbolic( + "aten::feature_dropout", + decorate=[symbolic_helper._apply_params("aten::feature_dropout")], +) +def _unsupported_dropout(name: str): + @symbolic_helper.parse_args("v", "none", "b") + def feature_dropout(g, input, p, train): + # NB: In inference mode, FeatureDropout is exported as an identity op. 
+ if train: + return symbolic_helper._unimplemented(name, "training mode", input) + return input + + return feature_dropout + + +@_onnx_symbolic("aten::norm") +@symbolic_helper.parse_args("v", "t", "is", "i", "v") +def norm(g: jit_utils.GraphContext, self, p, dim, keepdim, dtype=None): + if p == 1: + f = symbolic_helper._reduce_op_symbolic_helper("ReduceL1") + elif p == 2: + f = symbolic_helper._reduce_op_symbolic_helper("ReduceL2") + else: + raise errors.SymbolicValueError( + "ONNX export only p-norms with p of 1 or 2", self + ) + result = f(g, self, dim=dim, keepdim=keepdim) + if dtype is not None: + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + result = g.op("Cast", result, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + return result + + +@_onnx_symbolic("aten::conv_tbc") +@symbolic_helper.parse_args("v", "v", "v", "i") +def conv_tbc(g: jit_utils.GraphContext, input, weight, bias, pad): + # input must have 3 dimensions, see: + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10 + # input = (time, batch, in_channels) + # weight = (kernel_width, in_channels, out_channels) + # bias = (out_channels,) + input = g.op("Transpose", input, perm_i=[1, 2, 0]) + weight = g.op("Transpose", weight, perm_i=[2, 1, 0]) + conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1) + return g.op("Transpose", conv, perm_i=[2, 0, 1]) + + +@_onnx_symbolic("aten::_unique") +@symbolic_helper.parse_args("v", "i", "i") +def _unique(g: jit_utils.GraphContext, input, sorted, return_inverse): + return symbolic_helper._onnx_unsupported("_unique", input) + + +@_onnx_symbolic("aten::_unique2") +@symbolic_helper.parse_args("v", "i", "i", "i") +def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts): + symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input) + + +@_onnx_symbolic("aten::_cast_Byte") +@_deprecation.deprecated( + "2.0", + "the future", + "Avoid using this function and create a Cast node instead", +) +def _cast_Byte(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.UINT8) + + +@_onnx_symbolic("aten::_cast_Char") +@_deprecation.deprecated( + "2.0", + "the future", + "Avoid using this function and create a Cast node instead", +) +def _cast_Char(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT8) + + +@_onnx_symbolic("aten::_cast_Short") +@_deprecation.deprecated( + "2.0", + "the future", + "Avoid using this function and create a Cast node instead", +) +def _cast_Short(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT16) + + +@_onnx_symbolic("aten::_cast_Int") +@_deprecation.deprecated( + "2.0", + "the future", + "Avoid using this function and create a Cast node instead", +) +def _cast_Int(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) + + +@_onnx_symbolic("aten::_cast_Long") +@_deprecation.deprecated( + "2.0", + "the future", + "Avoid using this function and create a Cast node instead", +) +def _cast_Long(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64) + + +@_onnx_symbolic("aten::_cast_Half") +@_deprecation.deprecated( + "2.0", + "the future", + "Avoid using this function and create a Cast node instead", +) +def _cast_Half(g: jit_utils.GraphContext, input, non_blocking): + return 
g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT16) + + +@_onnx_symbolic("aten::_cast_Float") +@_deprecation.deprecated( + "2.0", + "the future", + "Avoid using this function and create a Cast node instead", +) +def _cast_Float(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT) + + +@_onnx_symbolic("aten::_cast_Double") +@_deprecation.deprecated( + "2.0", + "the future", + "Avoid using this function and create a Cast node instead", +) +def _cast_Double(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE) + + +@_onnx_symbolic("aten::_cast_Bool") +@_deprecation.deprecated( + "2.0", + "the future", + "Avoid using this function and create a Cast node instead", +) +def _cast_Bool(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL) + + +@_onnx_symbolic("aten::empty") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def empty( + g: jit_utils.GraphContext, + sizes, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + return zeros(g, sizes, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::empty_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def empty_like( + g: jit_utils.GraphContext, + input, + dtype=None, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + return zeros_like(g, input, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::new_empty") +def new_empty( + g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False +): + self_dtype = symbolic_helper._try_get_scalar_type(self) + if symbolic_helper._is_none(dtype) and self_dtype is not None: + dtype = self_dtype + return empty(g, sizes, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::scalar_tensor") +def scalar_tensor(g: jit_utils.GraphContext, scalar, dtype, *options): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + dtype = _type_utils.JitScalarType.FLOAT + scalar = g.op("Cast", scalar, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + return scalar + + +@_onnx_symbolic("aten::tensor") +def tensor( + g: jit_utils.GraphContext, data, dtype=None, device=None, requires_grad=False +): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if symbolic_helper._is_packed_list(data): + if dtype is None: + dtype = _type_utils.JitScalarType.from_value( + symbolic_helper._unpack_list(data)[0] + ) + input_list = [] + for t in symbolic_helper._unpack_list(data): + shape_reference = g.op("Constant", value_t=torch.LongTensor([1])) + t = symbolic_helper._reshape_helper(g, t, shape_reference) + t = g.op("Cast", t, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + input_list.append(t) + return g.op("Concat", *input_list, axis_i=0) + else: + if dtype is None: + dtype = _type_utils.JitScalarType.from_value(data) + if symbolic_helper._is_list(data) and ( + symbolic_helper._is_tensor_list(data) + or symbolic_helper._is_scalar_list(data) + ): + data = g.op("ConcatFromSequence", data, axis_i=0, new_axis_i=1) + return g.op("Cast", data, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + + +@_onnx_symbolic("aten::as_tensor") +def as_tensor(g: jit_utils.GraphContext, data, dtype=None, device=None): + return tensor(g, data, dtype, device) + + +@_onnx_symbolic("aten::zeros") +@symbolic_helper.parse_args("v", "i", "v", "v", "v") +def zeros(g: jit_utils.GraphContext, sizes, dtype, 
layout, device, pin_memory=False): + # NOTE: no way to set device, layout and pin_memory in ONNX, so we ignore it + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + sizes_ = symbolic_helper._maybe_get_const(sizes, "is") + if isinstance(sizes_, list) and len(sizes_) == 0: + sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64)) + return g.op( + "ConstantOfShape", + sizes, + value_t=torch.tensor([0], dtype=scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::zeros_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def zeros_like( + g: jit_utils.GraphContext, + input, + dtype=None, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + shape = g.op("Shape", input) + if symbolic_helper._is_none(dtype): + scalar_type = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.FLOAT + ) + else: + scalar_type = _type_utils.JitScalarType(dtype) + return g.op( + "ConstantOfShape", + shape, + value_t=torch.tensor([0], dtype=scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::new_zeros") +def new_zeros( + g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False +): + self_dtype = symbolic_helper._try_get_scalar_type(self) + + if symbolic_helper._is_none(dtype) and self_dtype is not None: + dtype = self_dtype + return zeros(g, sizes, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::zero") +def zero(g: jit_utils.GraphContext, self): + self_dtype = symbolic_helper._try_get_scalar_type(self) + return zeros_like(g, self, self_dtype) + + +@_onnx_symbolic("aten::ones") +@symbolic_helper.parse_args("v", "i", "v", "v", "v") +def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + sizes_ = symbolic_helper._maybe_get_const(sizes, "is") + if isinstance(sizes_, list) and len(sizes_) == 0: + sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64)) + return g.op( + "ConstantOfShape", + sizes, + value_t=torch.tensor([1], dtype=scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::ones_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def ones_like( + g: jit_utils.GraphContext, + input, + dtype=None, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + shape = g.op("Shape", input) + if symbolic_helper._is_none(dtype): + scalar_type = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.FLOAT + ) + else: + scalar_type = _type_utils.JitScalarType(dtype) + return g.op( + "ConstantOfShape", + shape, + value_t=torch.tensor([1], dtype=scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::new_ones") +def new_ones( + g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False +): + self_dtype = symbolic_helper._try_get_scalar_type(self) + if symbolic_helper._is_none(dtype) and self_dtype is not None: + dtype = self_dtype + return ones(g, sizes, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::full") +def full( + g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False +): + const_value = symbolic_helper._maybe_get_const(value, "t") + if symbolic_helper._is_value(const_value): + dtype = _type_utils.JitScalarType.FLOAT if dtype is None else dtype + tmp = zeros(g, sizes, dtype, layout, device) + return add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1))) + 
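+ # (Runtime fill values are handled above by building a zeros tensor of the requested shape and broadcast-Adding the value with alpha=1; the constant path below bakes the value directly into ConstantOfShape.)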
else: + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + sizes_ = symbolic_helper._maybe_get_const(sizes, "is") + if isinstance(sizes_, list) and len(sizes_) == 0: + sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64)) + return g.op( + "ConstantOfShape", + sizes, + value_t=const_value.view(1).to(scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::full_like") +def full_like( + g: jit_utils.GraphContext, + input, + fill_value, + dtype=None, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + fill_value = symbolic_helper._maybe_get_const(fill_value, "f") + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + scalar_type = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.FLOAT + ) + else: + scalar_type = _type_utils.JitScalarType(dtype) + if symbolic_helper._is_value(fill_value): + tmp = zeros_like(g, input, dtype, layout, device) + fill_value = g.op("Cast", fill_value, to_i=scalar_type.onnx_type()) + return add(g, tmp, fill_value, g.op("Constant", value_t=torch.tensor(1))) + else: + shape = g.op("Shape", input) + return g.op( + "ConstantOfShape", + shape, + value_t=torch.tensor([fill_value], dtype=scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::new_full") +def new_full( + g: jit_utils.GraphContext, + self, + size, + fill_value, + dtype, + layout, + device, + pin_memory=False, +): + self_dtype = symbolic_helper._try_get_scalar_type(self) + if symbolic_helper._is_none(dtype) and self_dtype is not None: + dtype = self_dtype + return full(g, size, fill_value, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::eye") +def eye(g: jit_utils.GraphContext, *args): + if len(args) == 5: + # aten::eye(n, dtype, layout, device, pin_memory) + n, dtype, layout, device, pin_memory = args + dim_size = symbolic_helper._unsqueeze_helper(g, n, [0]) + shape = g.op("Concat", dim_size, dim_size, axis_i=0) + tensor = zeros(g, shape, dtype, layout, device) + return g.op("EyeLike", tensor) + if len(args) == 6: + # aten::eye(n, m, dtype, layout, device, pin_memory) + n, m, dtype, layout, device, pin_memory = args + shape = g.op( + "Concat", + symbolic_helper._unsqueeze_helper(g, n, [0]), + symbolic_helper._unsqueeze_helper(g, m, [0]), + axis_i=0, + ) + tensor = zeros(g, shape, dtype, layout, device) + return g.op("EyeLike", tensor) + + return symbolic_helper._unimplemented("aten::eye", f"with {len(args)} arguments") + + +@_onnx_symbolic("aten::slice") +def slice(g: jit_utils.GraphContext, self, *args): + if len(args) == 4: + # aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor + dim, start, end, step = args + step = symbolic_helper._parse_arg(step, "i") + if step != 1: + raise errors.SymbolicValueError("step!=1 is currently not supported", self) + is_start_none = start.node().kind() == "prim::Constant" and isinstance( + start.type(), _C.NoneType + ) + is_end_none = end.node().kind() == "prim::Constant" and isinstance( + end.type(), _C.NoneType + ) + is_start_onnx_const = start.node().kind() == "onnx::Constant" + is_end_onnx_const = end.node().kind() == "onnx::Constant" + if ( + ((not is_start_none) and (not is_start_onnx_const)) + or ((not is_end_none) and (not is_end_onnx_const)) + or dim.node().kind() != "onnx::Constant" + ): + if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX: + raise errors.SymbolicValueError( + "Unsupported: ONNX export 
of Slice with dynamic inputs. DynamicSlice " + "is a deprecated experimental op. Please use statically allocated " + "variables or export to a higher opset version.", + self, + ) + else: + start_unsqueezed = symbolic_helper._unsqueeze_helper(g, start, [0]) + end_unsqueezed = symbolic_helper._unsqueeze_helper(g, end, [0]) + dim_unsqueezed = symbolic_helper._unsqueeze_helper(g, dim, [0]) + return g.op( + "DynamicSlice", + self, + start_unsqueezed, + end_unsqueezed, + dim_unsqueezed, + ) + else: + start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i") + end = ( + _constants.INT64_MAX + if is_end_none + else symbolic_helper._parse_arg(end, "i") + ) + dim = symbolic_helper._parse_arg(dim, "i") + return symbolic_helper._slice_helper( + g, self, axes=[dim], starts=[start], ends=[end] + ) + elif len(args) == 3: + # aten::slice(t[] l, int start, int end, int step) -> t[] + start, end, step = args + dim = 0 + is_start_none = start.node().kind() == "prim::Constant" and isinstance( + start.type(), _C.NoneType + ) + is_end_none = end.node().kind() == "prim::Constant" and isinstance( + end.type(), _C.NoneType + ) + start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i") + end = ( + _constants.INT64_MAX + if is_end_none + else symbolic_helper._parse_arg(end, "i") + ) + return symbolic_helper._slice_helper( + g, self, axes=[dim], starts=[start], ends=[end] + ) + + return symbolic_helper._unimplemented("aten::slice", f"with {len(args)} arguments") + + +@_onnx_symbolic("aten::hardtanh") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "f", "f") +def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float): + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min_f=min_val, max_f=max_val, opset_before=12 + ) + + +@_onnx_symbolic("aten::hardswish") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v") +def hardswish(g: jit_utils.GraphContext, self): + hs = hardsigmoid(g, self) + return g.op("Mul", self, hs) + + +@_onnx_symbolic("aten::hardsigmoid") +# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp +@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0) +@symbolic_helper.parse_args("v") +def hardsigmoid(g: jit_utils.GraphContext, self): + # Set alpha_f to 1 / 6 to make op equivalent to PyTorch's definition of Hardsigmoid. 
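+ # ONNX HardSigmoid computes max(0, min(1, alpha * x + beta)) with default beta = 0.5, so alpha = 1/6 reproduces PyTorch's clamp(x / 6 + 1 / 2, 0, 1).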
+ # See https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html + return g.op("HardSigmoid", self, alpha_f=1 / 6) + + +@_onnx_symbolic("aten::tanhshrink") +@symbolic_helper.parse_args("v") +def tanhshrink(g: jit_utils.GraphContext, self): + return g.op("Sub", self, tanh(g, self)) + + +@_onnx_symbolic("aten::hardshrink") +@symbolic_helper.parse_args("v", "f") +def hardshrink(g: jit_utils.GraphContext, self, lambd): + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + lambd_op = g.op( + "Constant", + value_t=torch.tensor(lambd, dtype=scalar_type.dtype()), + ) + cond = logical_or(g, gt(g, self, lambd_op), lt(g, self, neg(g, lambd_op))) + return g.op( + "Where", + cond, + self, + g.op( + "Constant", + value_t=torch.tensor(0, dtype=scalar_type.dtype()), + ), + ) + + +@_onnx_symbolic("aten::softshrink") +@symbolic_helper.parse_args("v", "f") +def softshrink(g: jit_utils.GraphContext, self, lambd): + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + lambd_op = g.op( + "Constant", + value_t=torch.tensor(lambd, dtype=scalar_type.dtype()), + ) + gt_cond = gt(g, self, lambd_op) + gt_out = g.op( + "Where", + gt_cond, + sub(g, self, lambd_op), + g.op( + "Constant", + value_t=torch.tensor(0, dtype=scalar_type.dtype()), + ), + ) + lt_cond = lt(g, self, neg(g, lambd_op)) + lt_out = g.op( + "Where", + lt_cond, + add(g, self, lambd_op), + g.op( + "Constant", + value_t=torch.tensor(0, dtype=scalar_type.dtype()), + ), + ) + return add(g, gt_out, lt_out) + + +@_onnx_symbolic("aten::alias") +def alias(g: jit_utils.GraphContext, self): + return self + + +@_onnx_symbolic("aten::unsqueeze") +@symbolic_helper.parse_args("v", "i") +def unsqueeze(g: jit_utils.GraphContext, self, dim): + """Implement unsqueezing a pytorch tensor in ONNX by inserting a new dimension at the specified `dim`""" + # Handle negative dim + if dim < 0: + rank = symbolic_helper._get_tensor_rank(self) + if rank is not None: + warnings.warn( + "ONNX export unsqueeze with negative axis " + + str(dim) + + " might cause the onnx model to be incorrect. " + + "Negative axis is not supported in ONNX. " + + "Axis is converted to " + + str(dim + rank + 1) + + " based on input shape at export time. " + + "Passing an tensor of different rank in execution will be incorrect." + ) + dim = dim + rank + 1 + else: + return symbolic_helper._unimplemented( + "unsqueeze", "negative axis with unknown input rank", self + ) + + return symbolic_helper._unsqueeze_helper(g, self, axes_i=[dim]) + + +@_onnx_symbolic("aten::sort") +# TODO(justinchuby): Support multiple quantized args in output +@symbolic_helper.parse_args("v", "i", "i", "none") +def sort(g: jit_utils.GraphContext, self, dim, decending, out=None): + if out is not None: + symbolic_helper._unimplemented( + "Sort", "Out parameter is not supported for sort", self + ) + self_sizes = symbolic_helper._get_tensor_sizes(self) + try: + dim_size = self_sizes[dim] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. 
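+ # (self_sizes may be None, or dim may be out of range, so TypeError / IndexError are the likely narrower choices here.)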
+ dim_size = None + + if dim_size is None: + return symbolic_helper._unimplemented("Sort", "input size not accessible", self) + + return g.op("TopK", self, k_i=dim_size, axis_i=dim, outputs=2) + + +@_onnx_symbolic("aten::numel") +def numel(g: jit_utils.GraphContext, self): + return symbolic_helper._numel_helper(g, self) + + +@_onnx_symbolic("aten::topk") +# TODO(justinchuby): Support multiple quantized args in output +@symbolic_helper.parse_args("v", "i", "i", "i", "i", "none") +def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None): + if out is not None: + symbolic_helper._unimplemented( + "TopK", "Out parameter is not supported for topk", self + ) + if not largest: + symbolic_helper._unimplemented("TopK", "Ascending TopK is not supported", self) + + return g.op("TopK", self, k_i=k, axis_i=dim, outputs=2) + + +@_onnx_symbolic("prim::convert_element_type") +def convert_element_type(g: jit_utils.GraphContext, self, *args): + dtype = symbolic_helper._get_const(args[0], "i", "dtype") + return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + + +@_onnx_symbolic("aten::to") +def to(g: jit_utils.GraphContext, self, *args): + def is_aten_to_device_only(args): + if len(args) == 4: + # aten::to(Tensor, Device, bool, bool, memory_format) + return ( + args[0].node().kind() == "prim::device" + or args[0].type().isSubtypeOf(_C.ListType.ofInts()) + or isinstance(args[0].type(), _C.DeviceObjType) + ) + elif len(args) == 5: + # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format) + # When dtype is None, this is a aten::to(device) call + dtype = symbolic_helper._get_const(args[1], "i", "dtype") + return dtype is None + elif len(args) in (6, 7): + # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor + # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor + # When dtype is None, this is a aten::to(device) call + dtype = symbolic_helper._get_const(args[0], "i", "dtype") + return dtype is None + return False + + # ONNX doesn't have a concept of a device, so we ignore device-only casts + if is_aten_to_device_only(args): + return self + + if len(args) == 4: + # TestONNXRuntime::test_ones_bool shows args[0] of aten::to() can be onnx::Constant[value=]() + # In this case, the constant value is a tensor not int, + # so symbolic_helper._maybe_get_const(args[0], 'i') would not work. 
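+ # A 0-d constant tensor is unpacked with .item() below to recover the integer dtype index; anything still a Value or a non-scalar tensor is treated as aten::to(Tensor, Tensor, ...) and its scalar type drives the Cast.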
+ dtype = args[0] + if ( + symbolic_helper._is_value(args[0]) + and args[0].node().kind() == "onnx::Constant" + ): + tval = symbolic_helper._node_get(args[0].node(), "value") + if isinstance(tval, torch.Tensor): + if len(tval.shape) == 0: + tval = tval.item() + dtype = int(tval) + else: + dtype = tval + + if symbolic_helper._is_value(dtype) or isinstance(dtype, torch.Tensor): + # aten::to(Tensor, Tensor, bool, bool, memory_format) + dtype = _type_utils.JitScalarType.from_value(args[0]) + return g.op( + "Cast", + self, + to_i=dtype.onnx_type(), + ) + else: + # aten::to(Tensor, ScalarType, bool, bool, memory_format) + # memory_format is ignored + return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + elif len(args) == 5: + # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format) + dtype = symbolic_helper._get_const(args[1], "i", "dtype") + # memory_format is ignored + return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + elif len(args) == 6: + # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor + dtype = symbolic_helper._get_const(args[0], "i", "dtype") + # Layout, device and memory_format are ignored + return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + elif len(args) == 7: + # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor + dtype = symbolic_helper._get_const(args[0], "i", "dtype") + # Layout, device and memory_format are ignored + return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + + return symbolic_helper._onnx_unsupported("Unknown aten::to signature", self) + + +@_onnx_symbolic("aten::repeat") +def repeat(g: jit_utils.GraphContext, self, repeats): + dtype = _type_utils.JitScalarType.INT64 + shape_ = ones_like(g, repeats, dtype) + self = g.op("Expand", self, shape_) + return g.op("Tile", self, repeats) + + +@_onnx_symbolic("aten::repeat_interleave") +def repeat_interleave( + g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None +): + repeats_dim = symbolic_helper._get_tensor_rank(repeats) + repeats_sizes = symbolic_helper._get_tensor_sizes(repeats) + input_sizes = symbolic_helper._get_tensor_sizes(self) + if repeats_dim is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown repeats rank.", + self, + ) + if repeats_sizes is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown repeats size.", + self, + ) + if input_sizes is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown input size.", + self, + ) + + # if dim is None flatten + # By default, use the flattened input array, and return a flat output array + if symbolic_helper._is_none(dim): + self = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([-1])) + ) + dim = torch.tensor(0, dtype=torch.int64) + else: + dim = symbolic_helper._maybe_get_scalar(dim) + + # Handle cases where dim is negative + if dim < 0: + dim += len(input_sizes) + + input_sizes_temp = input_sizes.copy() + for idx, input_size in enumerate(input_sizes): + if input_size is None: + input_sizes[idx], input_sizes_temp[idx] = 0, -1 + + # Cases where repeats is an int or single value tensor + if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1): + if input_sizes[dim] == 0: + return symbolic_helper._onnx_opset_unsupported_detailed( + "repeat_interleave", + 9, + 13, + "Unsupported along 
dimension with unknown input size", + self, + ) + return symbolic_helper._repeat_interleave_single_value_repeat_helper( + g, self, repeats, dim + ) + + # Cases where repeats is a 1 dim Tensor + elif repeats_dim == 1: + if input_sizes[dim] == 0: + return symbolic_helper._onnx_opset_unsupported_detailed( + "repeat_interleave", + 9, + 13, + "Unsupported along dimension with unknown input size", + self, + ) + if repeats_sizes[0] is None: + return symbolic_helper._onnx_opset_unsupported_detailed( + "repeat_interleave", + 9, + 13, + "Unsupported for cases with dynamic repeats", + self, + ) + assert ( + repeats_sizes[0] == input_sizes[dim] + ), "repeats must have the same size as input along dim" + reps = repeats_sizes[0] + else: + raise errors.SymbolicValueError("repeats must be 0-dim or 1-dim tensor", self) + + final_splits = [] + r_splits = symbolic_helper._repeat_interleave_split_helper(g, repeats, reps, 0) + i_splits = symbolic_helper._repeat_interleave_split_helper(g, self, reps, dim) + input_sizes[dim], input_sizes_temp[dim] = -1, 1 + for idx, r_split in enumerate(r_splits): + i_split = unsqueeze(g, i_splits[idx], dim + 1) + r_concat = [ + g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[: dim + 1])), + r_split, + g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[dim + 1 :])), + ] + r_concat = g.op("Concat", *r_concat, axis_i=0) + i_split = expand(g, i_split, r_concat, None) + i_split = symbolic_helper._reshape_helper( + g, + i_split, + g.op("Constant", value_t=torch.LongTensor(input_sizes)), + allowzero=0, + ) + final_splits.append(i_split) + return g.op("Concat", *final_splits, axis_i=dim) + + +@_onnx_symbolic("aten::pixel_shuffle") +@symbolic_helper.parse_args("v", "i") +def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor): + dims = symbolic_helper._get_tensor_sizes(self) + if len(dims) != 4: + return symbolic_helper._unimplemented( + "pixel_shuffle", "only support 4d input", self + ) + if any(i is None for i in dims[1:]): + after_view = symbolic_helper._reshape_helper( + g, + symbolic_helper._unsqueeze_helper(g, self, [2, 3]), + g.op( + "Constant", + value_t=torch.tensor([0, -1, upscale_factor, upscale_factor, 0, 0]), + ), + allowzero=0, + ) + after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3]) + # For dynamic input shapes, two reshapes are performed + reshape_h = symbolic_helper._reshape_helper( + g, + after_transpose, + g.op("Constant", value_t=torch.tensor([0, 0, -1, 1, 0, 0])), + allowzero=0, + ) + reshape_w = symbolic_helper._reshape_helper( + g, + reshape_h, + g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, 1])), + allowzero=0, + ) + return symbolic_helper._squeeze_helper(g, reshape_w, [3, 5]) + else: + output_channel = dims[1] // upscale_factor // upscale_factor + after_view = symbolic_helper._reshape_helper( + g, + self, + g.op( + "Constant", + value_t=torch.tensor( + [ + -1, + output_channel, + upscale_factor, + upscale_factor, + dims[2], + dims[3], + ] + ), + ), + allowzero=0, + ) + after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3]) + return symbolic_helper._reshape_helper( + g, + after_transpose, + g.op( + "Constant", + value_t=torch.tensor( + [ + -1, + output_channel, + dims[2] * upscale_factor, + dims[3] * upscale_factor, + ] + ), + ), + allowzero=0, + ) + + +@_onnx_symbolic("aten::pixel_unshuffle") +@symbolic_helper.parse_args("v", "i") +def pixel_unshuffle(g: jit_utils.GraphContext, self, downscale_factor): + dims = symbolic_helper._get_tensor_sizes(self) + if len(dims) != 4: + return 
symbolic_helper._unimplemented( + "pixel_shuffle", "only support 4d input", self + ) + if any(i is None for i in dims[1:]): + # For dynamic input shapes, two reshapes are performed + reshape_h = symbolic_helper._reshape_helper( + g, + symbolic_helper._unsqueeze_helper(g, self, [3]), + g.op("Constant", value_t=torch.tensor([0, 0, -1, downscale_factor, 0])), + allowzero=0, + ) + reshape_w = symbolic_helper._reshape_helper( + g, + reshape_h, + g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, downscale_factor])), + allowzero=0, + ) + after_transpose = g.op("Transpose", reshape_w, perm_i=[0, 1, 3, 5, 2, 4]) + final_reshape = symbolic_helper._reshape_helper( + g, + after_transpose, + g.op("Constant", value_t=torch.tensor([0, -1, 1, 1, 0, 0])), + allowzero=0, + ) + return symbolic_helper._squeeze_helper(g, final_reshape, [2, 3]) + else: + output_channel = dims[1] * downscale_factor * downscale_factor + after_view = symbolic_helper._reshape_helper( + g, + self, + g.op( + "Constant", + value_t=torch.tensor( + [ + -1, + dims[1], + dims[2] // downscale_factor, + downscale_factor, + dims[3] // downscale_factor, + downscale_factor, + ] + ), + ), + allowzero=0, + ) + after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 3, 5, 2, 4]) + return symbolic_helper._reshape_helper( + g, + after_transpose, + g.op( + "Constant", + value_t=torch.tensor( + [ + -1, + output_channel, + dims[2] // downscale_factor, + dims[3] // downscale_factor, + ] + ), + ), + allowzero=0, + ) + + +def _generic_rnn( + g: jit_utils.GraphContext, + variant, + input, + initial_states, + all_weights, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first=None, + batch_sizes=None, +): + warnings.warn( + "Exporting a model to ONNX with a batch_size other than 1, " + + "with a variable length with " + + variant + + " can cause an error " + + "when running the ONNX model with a different batch size. " + + "Make sure to save the model with a batch size of 1, " + + "or define the initial states (h0/c0) as inputs of the model. 
" + ) + + onnxActivations = [ + "Relu", + "Tanh", + "Sigmoid", + "Affine", + "LeakyRelu", + "ThresholdedRelu", + "ScaledTanh", + "HardSigmoid", + "Elu", + "Softsign", + "Softplus", + ] + variantToOnnxActivationMap = dict( + zip([act_fun.lower() for act_fun in onnxActivations], onnxActivations) + ) + weights_per_layer = 4 if has_biases else 2 + # this means that projections are used inside LSTM, so need to tell user that it's not supported + if variant == "LSTM" and len(all_weights) != num_layers * weights_per_layer * ( + 1 + bidirectional + ): + return symbolic_helper._unimplemented("LSTM", "LSTMs with projections", input) + assert len(all_weights) == num_layers * weights_per_layer * (1 + bidirectional) + layer_weights = [ + all_weights[i : i + weights_per_layer] + for i in range(0, len(all_weights), weights_per_layer) + ] + if batch_first: + # batch, seq, feat -> seq, batch, feat + input = g.op("Transpose", input, perm_i=[1, 0, 2]) + if dropout and train: + return symbolic_helper._unimplemented( + "RNN/GRU/LSTM", "dropout in training mode", input + ) + + if variant.startswith("RNN"): + nonlinearity = variantToOnnxActivationMap[variant[4:].lower()] + variant = "RNN" + + w_hh = all_weights[1] + hidden_size = symbolic_helper._get_tensor_dim_size(w_hh, 1) + if hidden_size is None: + return symbolic_helper._unimplemented( + "RNN/GRU/LSTM", "unknown hidden size", input + ) + + unidirectional = not bidirectional + + prev_output = input + + h_outs = [] + if variant == "RNN" or variant == "GRU": + h0 = initial_states + elif variant == "LSTM": + h0, c0 = initial_states + c_outs = [] + + sequence_lens = unused(g) if batch_sizes is None else batch_sizes + + if variant == "GRU": + # pytorch is reset, input, hidden + # onnx is input, reset, hidden + reform_permutation = [(1, 2), (0, 1), (2, 3)] + elif variant == "LSTM": + # pytorch is input, forget, cell, output. + # onnx is input, output, forget, cell. 
+ reform_permutation = [(0, 1), (3, 4), (1, 3)] + + def reform_weights(g, w, n, intervals): + slices = [ + symbolic_helper._slice_helper(g, w, axes=[0], starts=[x * n], ends=[y * n]) + for x, y in intervals + ] + return g.op("Concat", *slices, axis_i=0) + + def transform_weights_no_bias(layer_index): + weights = layer_weights[layer_index] + if variant == "RNN": + weight_ih, weight_hh = weights + elif variant == "GRU" or variant == "LSTM": + weight_ih, weight_hh = ( + reform_weights(g, w, hidden_size, reform_permutation) for w in weights + ) + return tuple( + symbolic_helper._unsqueeze_helper(g, x, [0]) + for x in (weight_ih, weight_hh) # type: ignore[possibly-undefined] + ) + + def transform_weights(layer_index): + weights = layer_weights[layer_index] + if variant == "RNN": + weight_ih, weight_hh, bias_ih, bias_hh = weights + elif variant == "GRU" or variant == "LSTM": + weight_ih, weight_hh, bias_ih, bias_hh = ( + reform_weights(g, w, hidden_size, reform_permutation) for w in weights + ) + bias_concat = g.op("Concat", bias_ih, bias_hh, axis_i=0) # type: ignore[possibly-undefined] + return tuple( + symbolic_helper._unsqueeze_helper(g, x, [0]) + for x in (weight_ih, weight_hh, bias_concat) # type: ignore[possibly-undefined] + ) + + def retrieve_state(x, start, end): + return ( + x + if num_layers == 1 + else symbolic_helper._slice_helper( + g, x, axes=[0], starts=[start], ends=[end] + ) + ) + + for i in range(num_layers): + if unidirectional: + if weights_per_layer == 4: + weight_ih, weight_hh, bias_concat = transform_weights(i) + else: + weight_ih, weight_hh = transform_weights_no_bias(i) + bias_concat = unused(g) + + state_indices = i, i + 1 + else: + if weights_per_layer == 4: + weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i) + weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1) + bias_concat = g.op("Concat", bias_f, bias_b, axis_i=0) + else: + weight_ih_f, weight_hh_f = transform_weights_no_bias(2 * i) + weight_ih_b, weight_hh_b = transform_weights_no_bias(2 * i + 1) + bias_concat = unused(g) + + weight_ih = g.op("Concat", weight_ih_f, weight_ih_b, axis_i=0) + weight_hh = g.op("Concat", weight_hh_f, weight_hh_b, axis_i=0) + + state_indices = 2 * i, 2 * i + 2 + + inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens] + + inputs.append(retrieve_state(h0, *state_indices)) # type: ignore[possibly-undefined] + if variant == "LSTM": + inputs.append(retrieve_state(c0, *state_indices)) # type: ignore[possibly-undefined] + + extra_kwargs = {} if unidirectional else {"direction_s": "bidirectional"} + if variant == "RNN": + if bidirectional: + activation = [nonlinearity, nonlinearity] # type: ignore[possibly-undefined] + else: + activation = [nonlinearity] # type: ignore[possibly-undefined] + + prev_output, h_out = g.op( + "RNN", + *inputs, + outputs=2, + hidden_size_i=hidden_size, + activations_s=activation, + **extra_kwargs, + ) + elif variant == "GRU": + prev_output, h_out = g.op( + "GRU", + *inputs, + outputs=2, + hidden_size_i=hidden_size, + linear_before_reset_i=1, + **extra_kwargs, + ) + elif variant == "LSTM": + prev_output, h_out, c_out = g.op( + "LSTM", *inputs, outputs=3, hidden_size_i=hidden_size, **extra_kwargs + ) + + if bidirectional: + # The ONNX RNN/GRU/LSTM produce an output of dimensions + # seq_len, num_directions, batch, hidden_size + # We have to convert to match pytorch's expected + # seq_len, batch, num_directions * hidden_size + # by first moving num_directions before hidden_size with + # Transpose, and then combining it with 
hidden_size + # with Reshape. + prev_output = g.op("Transpose", prev_output, perm_i=[0, 2, 1, 3]) + prev_output = symbolic_helper._reshape_helper( + g, + prev_output, + g.op("Constant", value_t=torch.LongTensor([0, 0, -1])), + allowzero=0, + ) + else: + prev_output = symbolic_helper._squeeze_helper(g, prev_output, [1]) + + h_outs.append(h_out) # type: ignore[possibly-undefined] + if variant == "LSTM": + c_outs.append(c_out) # type: ignore[possibly-undefined] + if batch_first: + # seq, batch, num_directions * hidden_size -> batch, seq, num_directions * hidden_size + prev_output = g.op("Transpose", prev_output, perm_i=[1, 0, 2]) + h_outs = h_out if num_layers == 1 else g.op("Concat", *h_outs, axis_i=0) # type: ignore[possibly-undefined] + if variant == "RNN" or variant == "GRU": + return prev_output, h_outs + elif variant == "LSTM": + c_outs = c_out if num_layers == 1 else g.op("Concat", *c_outs, axis_i=0) # type: ignore[possibly-undefined] + return prev_output, h_outs, c_outs + + +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i") +def _lstm_full( + g: jit_utils.GraphContext, + input, + hidden_v, + weight_v, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, +): + hidden, weight = ( + symbolic_helper._unpack_list(hidden_v), + symbolic_helper._unpack_list(weight_v), + ) + return _generic_rnn( + g, + "LSTM", + input, + hidden, + weight, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + ) + + +@symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i") +def _lstm_packed( + g: jit_utils.GraphContext, + input, + batch_sizes, + hidden_v, + weight_v, + has_biases, + num_layers, + dropout, + train, + bidirectional, +): + hidden, weight = ( + symbolic_helper._unpack_list(hidden_v), + symbolic_helper._unpack_list(weight_v), + ) + return _generic_rnn( + g, + "LSTM", + input, + hidden, + weight, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_sizes=batch_sizes, + ) + + +@_onnx_symbolic("aten::lstm") +def lstm(g: jit_utils.GraphContext, *args): + if symbolic_helper._is_tensor_list(args[3]): + return _lstm_packed(g, *args) + else: + return _lstm_full(g, *args) + + +@_onnx_symbolic("aten::lstm_cell") +def lstm_cell(g: jit_utils.GraphContext, self, hidden, w_ih, w_hh, b_ih, b_hh): + input = symbolic_helper._unsqueeze_helper(g, self, [0]) + hidden = symbolic_helper._unpack_list(hidden) + hidden = [symbolic_helper._unsqueeze_helper(g, x, [0]) for x in hidden] + weight = ( + (w_ih, w_hh, b_ih, b_hh) if symbolic_helper._is_tensor(b_ih) else (w_ih, w_hh) + ) + has_biases = True if symbolic_helper._is_tensor(b_ih) else False + _, h_outs, c_outs = _generic_rnn( + g, + "LSTM", + input, + hidden, + weight, + has_biases, + num_layers=1, + dropout=0, + train=0, + bidirectional=False, + batch_first=False, + ) + return symbolic_helper._squeeze_helper( + g, h_outs, [0] + ), symbolic_helper._squeeze_helper(g, c_outs, [0]) + + +@_onnx_symbolic( + "aten::gru", decorate=[symbolic_helper._apply_params("GRU"), _export("gru")] +) +@_onnx_symbolic( + "aten::rnn_tanh", + decorate=[symbolic_helper._apply_params("RNN_TANH"), _export("rnn_tanh")], +) +@_onnx_symbolic( + "aten::rnn_relu", + decorate=[symbolic_helper._apply_params("RNN_RELU"), _export("rnn_relu")], +) +def _one_hidden_rnn(kind: str): + @symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i") + def _rnn_full( + g, + input, + hidden, + weight_v, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + ): + 
weight = symbolic_helper._unpack_list(weight_v) + return _generic_rnn( + g, + kind, + input, + hidden, + weight, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + ) + + @symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i") + def _rnn_packed( + g, + input, + batch_sizes, + hidden, + weight_v, + has_biases, + num_layers, + dropout, + train, + bidirectional, + ): + weight = symbolic_helper._unpack_list(weight_v) + return _generic_rnn( + g, + kind, + input, + hidden, + weight, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_sizes=batch_sizes, + ) + + def symbolic(g, *args): + if symbolic_helper._is_tensor_list(args[3]): + return _rnn_packed(g, *args) + else: + return _rnn_full(g, *args) + + return symbolic + + +@_onnx_symbolic("aten::_dim_arange") +@symbolic_helper.parse_args("v", "i") +def _dim_arange(g: jit_utils.GraphContext, like, dim): + like_shape = g.op("Shape", like) + stop = g.op( + "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0 + ) + # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) + return arange(g, stop, 4, None, None, None) + + +@_onnx_symbolic("aten::detach") +def detach(g: jit_utils.GraphContext, input): + # Erase aten::detach nodes because ONNX is inference only + return input + + +@_onnx_symbolic("aten::contiguous") +@symbolic_helper.parse_args("v", "i") +def contiguous(g: jit_utils.GraphContext, input, memory_format): + if memory_format > 2: # allower values are any, preserve and contiguous_format + raise errors.SymbolicValueError( + "onnx memory_format support is not implemented", input + ) + return input + + +@_onnx_symbolic("aten::_pack_padded_sequence") +@symbolic_helper.parse_args("v", "v", "i") +def _pack_padded_sequence(g: jit_utils.GraphContext, input, lengths, batch_first): + # Currently there is no PackPadded operator in ONNX. We rely on an + # optimization pass to remove this later. It is an error if all + # PackPadded operators cannot be optimized out. + if batch_first: + input = g.op("Transpose", input, perm_i=[1, 0, 2]) + if not lengths.type().isSubtypeOf(torch._C.TensorType.get()): + raise errors.SymbolicValueError( + "'lengths' must be a Tensor for ONNX export", input + ) + # We know it's a TensorType so this check is now safe. + # It's really only necessary because those operators expand to something that + # only works with int32 types in Caffe2... 
+ if ( + _type_utils.JitScalarType.from_value( + lengths, _type_utils.JitScalarType.UNDEFINED + ) + != _type_utils.JitScalarType.INT + ): + lengths = g.op("Cast", lengths, to_i=_C_onnx.TensorProtoDataType.INT32) + return g.op("prim::PackPadded", input, lengths, outputs=2) + + +@_onnx_symbolic("aten::_pad_packed_sequence") +@symbolic_helper.parse_args("v", "v", "i", "t", "v") +def _pad_packed_sequence( + g: jit_utils.GraphContext, + data, + batch_sizes, + batch_first, + padding_value, + total_length, +): + # Ignore total_length as it is not supported in _symbolic_pad_packed_sequence + # It is only useful/used when training using data_parallel model, so + # It shouldn't be relevant for ONNX anyway + data, lengths = g.op("prim::PadPacked", data, batch_sizes, outputs=2) + if batch_first: + data = g.op("Transpose", data, perm_i=[1, 0, 2]) + return data, lengths + + +@_onnx_symbolic("aten::randint") +def randint(g: jit_utils.GraphContext, low, high, shapes, dtype, *options): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + low_i = symbolic_helper._get_const(low, "i", "low") + high_i = symbolic_helper._get_const(high, "i", "high") + if dtype is None: + scalar_type = _type_utils.JitScalarType.INT64 + else: + scalar_type = _type_utils.JitScalarType(dtype) + if low_i is None: + raise symbolic_helper._onnx_unsupported("randint", low) + if high_i is None: + raise symbolic_helper._onnx_unsupported("randint", high) + + shape = symbolic_helper._maybe_get_const(shapes, "is") + if symbolic_helper._is_value(shape): + shape_const = g.op( + "ConstantOfShape", + shapes, + value_t=torch.tensor([0], dtype=torch.float), + ) + randn = g.op( + "RandomUniformLike", + shape_const, + low_f=low_i, + high_f=high_i, + ) + else: + randn = g.op( + "RandomUniform", + shape_i=shape, + low_f=low_i, + high_f=high_i, + ) + + # cast to integer type + int_dtype = _type_utils.JitScalarType.INT64 + randint = g.op("Cast", randn, to_i=int_dtype.onnx_type()) + if int_dtype != scalar_type: + randint = g.op("Cast", randint, to_i=scalar_type.onnx_type()) + return randint + + +@_onnx_symbolic("aten::randint_like") +def randint_like(g: jit_utils.GraphContext, self, low, high, dtype, *options): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + low_i = symbolic_helper._get_const(low, "i", "low") + high_i = symbolic_helper._get_const(high, "i", "high") + if dtype is None: + scalar_type = _type_utils.JitScalarType.INT64 + else: + scalar_type = _type_utils.JitScalarType(dtype) + if low_i is None: + raise symbolic_helper._onnx_unsupported("randint", low) + if high_i is None: + raise symbolic_helper._onnx_unsupported("randint", high) + + randn = g.op( + "RandomUniformLike", + self, + low_f=low_i, + high_f=high_i, + ) + + # cast to integer type + int_dtype = _type_utils.JitScalarType.INT64 + randint = g.op("Cast", randn, to_i=int_dtype.onnx_type()) + if int_dtype != scalar_type: + randint = g.op("Cast", randint, to_i=scalar_type.onnx_type()) + return randint + + +@_onnx_symbolic("aten::randn") +def randn(g: jit_utils.GraphContext, shapes, dtype, *options): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + shape = symbolic_helper._maybe_get_const(shapes, "is") + if symbolic_helper._is_value(shape): + shape_const = g.op( + "ConstantOfShape", + shapes, + value_t=torch.tensor([0], dtype=torch.float), + ) + return g.op( + "RandomNormalLike", + shape_const, + dtype_i=scalar_type.onnx_type(), + ) + 
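+ # Static shape list: RandomNormal takes the shape as an attribute, so it can be emitted directly; the dynamic-shape case above goes through ConstantOfShape + RandomNormalLike instead.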
return g.op( + "RandomNormal", + shape_i=shape, + dtype_i=scalar_type.onnx_type(), + ) + + +@_onnx_symbolic("aten::rand") +def rand(g: jit_utils.GraphContext, shapes, dtype, *options): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + shape = symbolic_helper._maybe_get_const(shapes, "is") + if symbolic_helper._is_value(shape): + shape_const = g.op( + "ConstantOfShape", + shapes, + value_t=torch.tensor([0], dtype=torch.float), + ) + return g.op( + "RandomUniformLike", + shape_const, + dtype_i=scalar_type.onnx_type(), + ) + return g.op( + "RandomUniform", + shape_i=shape, + dtype_i=scalar_type.onnx_type(), + ) + + +@_onnx_symbolic("aten::randn_like") +def randn_like( + g: jit_utils.GraphContext, + self, + dtype, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + else: + scalar_type = _type_utils.JitScalarType(dtype) + return g.op("RandomNormalLike", self, dtype_i=scalar_type.onnx_type()) + + +@_onnx_symbolic("aten::rand_like") +def rand_like( + g: jit_utils.GraphContext, + self, + dtype, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + dtype = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + return g.op( + "RandomUniformLike", self, dtype_i=_type_utils.JitScalarType(dtype).onnx_type() + ) + + +@_onnx_symbolic("aten::rrelu") +@symbolic_helper.parse_args("v", "f", "f", "i", "none") +def rrelu(g: jit_utils.GraphContext, input, lower, upper, training, generator): + if not training: + slope = (upper + lower) / 2.0 + return g.op("LeakyRelu", input, alpha_f=slope) + p = g.op("RandomUniformLike", input, high_f=upper, low_f=lower) + return g.op("PRelu", input, p) + + +@_onnx_symbolic("aten::bernoulli") +def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None): + if out is not None and not symbolic_helper._is_none(out): + symbolic_helper._unimplemented( + "Bernoulli", "out parameter is not supported for bernoulli", input + ) + if generator is not None and not symbolic_helper._is_none(generator): + symbolic_helper._unimplemented( + "Bernoulli", "generator is not supported for bernoulli", input + ) + + dtype = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.UNDEFINED + ) + if dtype == _type_utils.JitScalarType.UNDEFINED: + return symbolic_helper._unimplemented( + "Bernoulli", "input dtype not accessible", input + ) + + rands = g.op( + "RandomUniformLike", + input, + high_f=1.0, + low_f=0.0, + dtype_i=dtype.onnx_type(), + ) + prob = p if p is not None and not symbolic_helper._is_none(p) else input + output = g.op("Less", rands, prob) + return g.op("Cast", output, to_i=dtype.onnx_type()) + + +@_onnx_symbolic("aten::log_sigmoid") +@symbolic_helper.parse_args("v") +def log_sigmoid(g: jit_utils.GraphContext, input): + p = g.op("Sigmoid", input) + return g.op("Log", p) + + +@_onnx_symbolic("aten::erf") +@symbolic_helper.parse_args("v") +def erf(g: jit_utils.GraphContext, input): + return g.op("Erf", input) + + +@_onnx_symbolic("aten::flatten") +@symbolic_helper.quantized_args(True, False, False) +@symbolic_helper.parse_args("v", "i", "i") +def flatten(g: jit_utils.GraphContext, input, 
start_dim, end_dim): + dim = symbolic_helper._get_tensor_rank(input) + if dim is None: + return symbolic_helper._unimplemented( + "dim", + "ONNX and PyTorch use different strategies to split the input. " + "Input rank must be known at export time.", + input, + ) + + if dim == 0: + return symbolic_helper._reshape_helper(g, input, [1]) + if dim == 1: + return g.op("Identity", input) + # TODO: remove this as onnx opset 11 spec allows negative axes + if end_dim < 0: + end_dim = dim + end_dim + # use ONNX's Flatten operator for cases where the output shape is 2D + if start_dim == 1 and end_dim == dim - 1: + return g.op("Flatten", input, axis_i=start_dim) + if start_dim == 0 and end_dim == dim - 2: + return g.op("Flatten", input, axis_i=end_dim + 1) + + return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim) + + +@_onnx_symbolic("aten::nonzero") +@symbolic_helper.parse_args("v") +def nonzero(g: jit_utils.GraphContext, input): + """Emitted from `torch.nonzero(x, as_tuple=False)`""" + return t(g, g.op("NonZero", input)) + + +@_onnx_symbolic("aten::nonzero_numpy") +# Emitted from `torch.nonzero(x, as_tuple=True)` +def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None): + return unbind(g, nonzero(g, input), 1, _outputs=_outputs) + + +@_onnx_symbolic("aten::isnan") +@symbolic_helper.parse_args("v") +def isnan(g: jit_utils.GraphContext, input): + output = g.op("IsNaN", input) + return output + + +@_onnx_symbolic("aten::any") +def _any(g: jit_utils.GraphContext, *args): + # aten::any(Tensor self) + if len(args) == 1: + input = args[0] + dim, keepdim = None, 0 + # aten::any(Tensor self, int[]? dim, bool keepdim) + else: + input, dim, keepdim = args + # Can be int list or single int + dim = symbolic_helper._parse_arg(dim, "t") + dim = [int(d) for d in dim.view(-1)] + keepdim = symbolic_helper._parse_arg(keepdim, "i") + input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64) + input_sum = symbolic_helper._reducesum_helper( + g, input, axes_i=dim, keepdims_i=keepdim + ) + return gt(g, input_sum, g.op("Constant", value_t=torch.tensor(0, dtype=torch.long))) + + +@_onnx_symbolic("aten::all") +def _all(g: jit_utils.GraphContext, *args): + input = g.op("Not", args[0]) + # aten::all(Tensor self) + if len(args) == 1: + return g.op("Not", _any(g, input)) + # aten::all(Tensor self, int[]? 
dim, bool keepdim) + else: + return g.op("Not", _any(g, input, args[1], args[2])) + + +@_onnx_symbolic("aten::narrow") +@symbolic_helper.parse_args("v", "i", "i", "i") +def narrow(g: jit_utils.GraphContext, input, dim, start, length): + return symbolic_helper._slice_helper( + g, input, axes=[dim], starts=[start], ends=[start + length] + ) + + +@_onnx_symbolic("aten::argmax") +@symbolic_helper.parse_args("v", "v", "b") +def argmax( + g: jit_utils.GraphContext, + input: torch._C.Value, + dim: torch._C.Value, + keepdim: bool, +): + return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax") + + +@_onnx_symbolic("aten::argmin") +@symbolic_helper.parse_args("v", "v", "b") +def argmin( + g: jit_utils.GraphContext, + input: torch._C.Value, + dim: torch._C.Value, + keepdim: bool, +): + return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin") + + +@_onnx_symbolic("aten::scatter") +@symbolic_helper.parse_args("v", "i", "v", "v") +def scatter(g: jit_utils.GraphContext, self, dim, index, src): + src_type = _type_utils.JitScalarType.from_value( + src, _type_utils.JitScalarType.UNDEFINED + ) + src = symbolic_helper._maybe_get_scalar(src) + if symbolic_helper._is_value(src): + return g.op("Scatter", self, index, src, axis_i=dim) + else: + # Check if scalar "src" has same type as self (PyTorch allows different + # type for scalar src (but not when src is tensor)). If not, insert Cast node. + self_scalar_type = _type_utils.JitScalarType.from_value(self) + if self_scalar_type != src_type: + src = g.op("Cast", src, to_i=self_scalar_type.onnx_type()) + return g.op("Scatter", self, index, expand_as(g, src, index), axis_i=dim) + + +@_onnx_symbolic("aten::scatter_add") +@symbolic_helper.parse_args("v", "i", "v", "v") +def scatter_add(g: jit_utils.GraphContext, self, dim, index, src): + scalar_type = symbolic_helper._try_get_scalar_type(self) + if scalar_type is None: + return symbolic_helper._unimplemented( + "scatter_add", "input dtype not accessible", self + ) + sizes = symbolic_helper._get_tensor_sizes(self, allow_nonstatic=False) + if sizes: + to_add = g.op("Constant", value_t=torch.zeros(sizes, dtype=scalar_type.dtype())) + else: + to_add = zeros_like(g, self, scalar_type) + to_add = symbolic_helper._scatter_helper(g, to_add, dim, index, src) + return add(g, self, to_add) + + +@_onnx_symbolic("aten::log2") +def log2(g: jit_utils.GraphContext, self): + _ln2 = 0.693147180559945309 + return g.op("Div", log(g, self), g.op("Constant", value_t=torch.tensor(_ln2))) + + +@_onnx_symbolic("aten::is_floating_point") +def is_floating_point(g: jit_utils.GraphContext, self): + if symbolic_helper._is_fp(self): + return g.op("Constant", value_t=torch.BoolTensor([1])) + return g.op("Constant", value_t=torch.BoolTensor([0])) + + +@_onnx_symbolic("aten::__is_") +def __is_(g: jit_utils.GraphContext, self, other): + if symbolic_helper._is_none(other): + if symbolic_helper._is_none(self): + return g.op("Constant", value_t=torch.BoolTensor([1])) + return g.op("Constant", value_t=torch.BoolTensor([0])) + return eq(g, self, other) + + +@_onnx_symbolic("aten::__isnot_") +@wrap_logical_op_with_negation +def __isnot_(g: jit_utils.GraphContext, self, other): + return __is_(g, self, other) + + +@_onnx_symbolic("aten::one_hot") +def one_hot(g: jit_utils.GraphContext, self, num_classes): + values = g.op("Constant", value_t=torch.LongTensor([0, 1])) + # onnxruntime supports limited type combinations for OneHot. 
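+ # The depth (num_classes) input is therefore upcast to INT64 below when it arrives as a narrower integer type.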
+ if _type_utils.JitScalarType.from_value( + num_classes, _type_utils.JitScalarType.UNDEFINED + ) in { + _type_utils.JitScalarType.UINT8, + _type_utils.JitScalarType.INT8, + _type_utils.JitScalarType.INT, + _type_utils.JitScalarType.INT16, + }: + num_classes = g.op("Cast", num_classes, to_i=_C_onnx.TensorProtoDataType.INT64) + return g.op("OneHot", self, num_classes, values, axis_i=-1) + + +@_onnx_symbolic("aten::gather") +@symbolic_helper.parse_args("v", "i", "v", "v") +def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False): + if symbolic_helper._maybe_get_const(sparse_grad, "i"): + return symbolic_helper._unimplemented("gather", "sparse_grad == True", self) + # NOTE: This workaround is needed since GatherElement is only supported + # since opset 11, and Gather in ONNX is not the same as torch.gather. + scalar_type = _type_utils.JitScalarType.from_value(self) + values = g.op("Constant", value_t=torch.LongTensor([0, 1])) + depth = size(g, self, g.op("Constant", value_t=torch.LongTensor([dim]))) + index = g.op( + "Cast", + g.op("OneHot", index, depth, values, axis_i=dim), + to_i=scalar_type.onnx_type(), + ) + mul = g.op("Mul", symbolic_helper._unsqueeze_helper(g, self, [dim + 1]), index) + return symbolic_helper._reducesum_helper(g, mul, axes_i=[dim], keepdims_i=0) + + +@symbolic_helper.parse_args("v", "is", "i", "i") +def _var_mean(g: jit_utils.GraphContext, input, dim, correction, keepdim): + return symbolic_helper._var_mean_helper(g, input, dim, correction, keepdim) + + +@_onnx_symbolic("aten::std") +def std(g: jit_utils.GraphContext, input, *args): + var, _ = var_mean(g, input, *args) + return g.op("Sqrt", var) + + +@_onnx_symbolic("aten::var") +def var(g: jit_utils.GraphContext, input, *args): + var, _ = var_mean(g, input, *args) + return var + + +@_onnx_symbolic("aten::var_mean") +def var_mean(g: jit_utils.GraphContext, input, *args): + if len(args) == 1: + return _var_mean(g, input, None, args[0], None) + else: + return _var_mean(g, input, *args) + + +@_onnx_symbolic("aten::std_mean") +def std_mean(g: jit_utils.GraphContext, input, *args): + var, mean = var_mean(g, input, *args) + return g.op("Sqrt", var), mean + + +@_onnx_symbolic("aten::logsumexp") +@symbolic_helper.parse_args("v", "is", "i") +def logsumexp(g: jit_utils.GraphContext, input, dim, keepdim): + return g.op("ReduceLogSumExp", input, axes_i=dim, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::arange") +def arange(g: jit_utils.GraphContext, *args): + def _get_arange_dtype(dtype): + dtype = symbolic_helper._maybe_get_const(dtype, "i") + return dtype + + def _float_step_convert(range_tensor): + if symbolic_helper._is_fp(range_tensor): + range_tensor = g.op( + "Cast", + g.op("Ceil", range_tensor), + to_i=_type_utils.JitScalarType.INT64.onnx_type(), + ) + return range_tensor + + if len(args) == 2 or len(args) == 5: + if len(args) == 2: + # aten::arange(Scalar end, Tensor out) + dtype = None + else: + # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[1]) + dtype, end, start, step = symbolic_helper._arange_cast_helper( + g, end=args[0], dtype=dtype + ) + end = symbolic_helper._unsqueeze_helper(g, end, [0]) + range_tensor = _float_step_convert(end) + arange_tensor = symbolic_helper._squeeze_helper( + g, nonzero(g, ones(g, range_tensor, dtype, None, None)), [1] + ) + return g.op( + "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() + ) + elif len(args) == 4 or len(args) == 7: + if len(args) == 4: + # aten::arange(Scalar start, 
Scalar end, Scalar step, Tensor out) + dtype = None + else: + # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[3]) + dtype, end, start, step = symbolic_helper._arange_cast_helper( + g, start=args[0], end=args[1], step=args[2], dtype=dtype + ) + step = symbolic_helper._unsqueeze_helper(g, step, [0]) + end = symbolic_helper._unsqueeze_helper(g, end, [0]) + start = symbolic_helper._unsqueeze_helper(g, start, [0]) + range_tensor = _float_step_convert(g.op("Div", g.op("Sub", end, start), step)) + arange_tensor = symbolic_helper._squeeze_helper( + g, nonzero(g, ones(g, range_tensor, None, None, None)), [1] + ) + arange_tensor = g.op("Add", g.op("Mul", arange_tensor, step), start) + return g.op( + "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() + ) + elif len(args) == 6: + # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[2]) + dtype, end, start, step = symbolic_helper._arange_cast_helper( + g, start=args[0], end=args[1], dtype=dtype + ) + end = symbolic_helper._unsqueeze_helper(g, end, [0]) + start = symbolic_helper._unsqueeze_helper(g, start, [0]) + range_tensor = _float_step_convert(g.op("Sub", end, start)) + arange_tensor = g.op( + "Add", + symbolic_helper._squeeze_helper( + g, nonzero(g, ones(g, range_tensor, dtype, *(args[3:]))), [1] + ), + start, + ) + return g.op( + "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() + ) + + return symbolic_helper._unimplemented("aten::arange", f"with {len(args)} arguments") + + +@_onnx_symbolic("aten::linspace") +def linspace( + g: jit_utils.GraphContext, start, end, steps, dtype, layout, device, pin_memory +): + range_tensor = symbolic_helper._arange_helper(g, steps, None) + step = div( + g, + sub(g, end, start), + sub(g, steps, g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))), + ) + return add(g, mul(g, range_tensor, step), start) + + +@_onnx_symbolic("aten::lift") +def lift(g: jit_utils.GraphContext, self): + # at::lift() is a no-op from the perspective of tracing for onnx + return self + + +@_onnx_symbolic("aten::masked_fill") +def masked_fill(g: jit_utils.GraphContext, self, mask, value): + """Implement the masked_fill functionality available for a pytorch tensor in ONNX. + + Fills elements of the input tensor with `value` where `mask` is True. + """ + mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL) + value = symbolic_helper._maybe_get_scalar(value) + return g.op("Where", mask, symbolic_helper._if_scalar_type_as(value, self), self) + + +@_onnx_symbolic("aten::masked_fill_") +def masked_fill_(g: jit_utils.GraphContext, self, mask, value): + return masked_fill(g, self, mask, value) + + +@_onnx_symbolic("aten::index") +def index(g: jit_utils.GraphContext, self, index): + if symbolic_helper._is_packed_list(index): + indices = symbolic_helper._unpack_list(index) + else: + indices = [index] + + def try_mask_to_index(index): + if not symbolic_helper._is_none(index) and ( + _type_utils.JitScalarType.from_value( + index, _type_utils.JitScalarType.UNDEFINED + ) + == _type_utils.JitScalarType.UINT8 + or symbolic_helper._is_bool(index) + ): + if g.opset < 9: + raise errors.SymbolicValueError( + "Exporting masked indices are only supported after ONNX opset 9.", + self, + ) + warnings.warn( + "Exporting aten::index operator with indices of type Byte. " + "Only 1-D indices are supported. 
In any other case, " + "this will produce an incorrect ONNX graph." + ) + index = symbolic_helper._squeeze_helper(g, nonzero(g, index), [1]) + return index + + indices = [try_mask_to_index(idx) for idx in indices] + if len(indices) == 1: + return symbolic_helper._select_helper( + g, self, 0, indices[0], apply_reshape=False + ) + else: + # Multiple tensors as indices. Each tensor could either be + # 1. prim::Constant() + # representing ":" in python indexing. E.g. tensor[:, :] + # 2. prim::Constant[value=...] or tensor output + # representing advanced indexing. E.g. tensor[[0, 1], [2, 0]]. + # For more info on advanced indexing, + # check https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing + + # Consider a general case of + # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] + # where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes for ":". + # Same results can be achieved through transposing t into + # t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n] + # and use gatherND. However ONNX does not have gatherND, to use 1d gather we'll need to flatten t + # and process the tensor indices. + # t: [x_1 * x_2 * ... * x_m, y_1 * y_2 * ... * y_n] + # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)) + # After gather, reshape and transpose back. + adv_idx_indices = [ + i for i, idx in enumerate(indices) if not symbolic_helper._is_none(idx) + ] + + if len(adv_idx_indices) == 0: + return self + elif len(adv_idx_indices) == 1: + return index_select( + g, self, adv_idx_indices[0], indices[adv_idx_indices[0]] + ) + else: + rank = symbolic_helper._get_tensor_rank(self) + if rank is None: + return symbolic_helper._unimplemented( + "aten::index", + "operator of advanced indexing on tensor of unknown rank. ", + self, + ) + # TODO: If indexing is supported natively in ONNX in future opsets, + # update the warning to recommend exporting with higher opset version. + warnings.warn( + "Exporting aten::index operator of advanced indexing in opset " + f"{GLOBALS.export_onnx_opset_version}" + " is achieved by combination of multiple ONNX operators, " + "including Reshape, Transpose, Concat, and Gather. " + "If indices include negative values, the exported graph will produce incorrect results." + ) + adv_idx_count = len(adv_idx_indices) + shape_tensor = _shape_as_tensor(g, self) + dim_tensor_list = [ + g.op( + "Gather", + shape_tensor, + g.op("Constant", value_t=torch.LongTensor([dim])), + axis_i=0, + ) + for dim in range(rank) + ] + + self = g.op( + "Transpose", + self, + perm_i=adv_idx_indices + + [i for i in range(rank) if i not in adv_idx_indices], + ) + self = g.op("Flatten", self, axis_i=adv_idx_count) + + # Note that tensor indices will be broadcasted while accumulating. Thus we get the final subarray shape as well. + cum_adv_index = indices[adv_idx_indices[-1]] + multiplier = dim_tensor_list[adv_idx_indices[-1]] + for i in range(adv_idx_count - 2, -1, -1): + adv_index = g.op("Mul", indices[adv_idx_indices[i]], multiplier) + cum_adv_index = g.op("Add", cum_adv_index, adv_index) + multiplier = g.op( + "Mul", multiplier, dim_tensor_list[adv_idx_indices[i]] + ) + + # perform gather + self = index_select(g, self, 0, cum_adv_index) + + cum_adv_index_shape_tensor = _shape_as_tensor(g, cum_adv_index) + # check if all advanced indices are consecutive. + # Refer to https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#combining-advanced-and-basic-indexing + # to understand how the subarray position is decided. 
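+ # NumPy rule: if the advanced indices are consecutive, the gathered subarray is placed back at the position of the first advanced index (the branch below); otherwise it stays at the front of the result.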
+ if adv_idx_indices == list( + range(adv_idx_indices[0], adv_idx_indices[-1] + 1) + ): + # unfold regular index axes + folded_adv_idx_shape_list = [ + g.op("Constant", value_t=torch.LongTensor([-1])) + ] + [ + dim_tensor_list[i] for i in range(rank) if i not in adv_idx_indices + ] + folded_adv_idx_shape = g.op( + "Concat", *folded_adv_idx_shape_list, axis_i=0 + ) + self = symbolic_helper._reshape_helper(g, self, folded_adv_idx_shape) + + # Transpose folded advanced indexed axis to its original location. + adv_idx_permute = ( + list(range(1, adv_idx_indices[0] + 1)) + + [0] + + list(range(adv_idx_indices[0] + 1, rank - adv_idx_count + 1)) + ) + self = g.op("Transpose", self, perm_i=adv_idx_permute) + + # unfold advanced index axes + final_shape_list = ( + [dim_tensor_list[i] for i in range(adv_idx_indices[0])] + + [cum_adv_index_shape_tensor] + + [ + dim_tensor_list[i] + for i in range(adv_idx_indices[0], rank) + if i not in adv_idx_indices + ] + ) + final_shape = g.op("Concat", *final_shape_list, axis_i=0) + else: + final_shape = g.op( + "Concat", + cum_adv_index_shape_tensor, + *[ + dim_tensor_list[i] + for i in range(rank) + if i not in adv_idx_indices + ], + axis_i=0, + ) + + return symbolic_helper._reshape_helper(g, self, final_shape) + + +@_onnx_symbolic("aten::linalg_norm") +@symbolic_helper.parse_args("v", "v", "is", "b", "v") +def linalg_norm( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: torch._C.Value, + dim: Sequence[int] | None, + keepdim: bool, + dtype: torch._C.Value, +): + # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.norm.html + ord_value = None + if dim is None: + if symbolic_helper._is_none(ord): + self = symbolic_helper._reshape_helper(g, self, [-1]) + ord = g.op("Constant", value_t=torch.LongTensor([2])) + self_dim = symbolic_helper._get_tensor_rank(self) + if self_dim is None: + return symbolic_helper._unimplemented( + "dim", "Input rank must be known at export time.", self + ) + if self_dim == 1: + ord_value = symbolic_helper._parse_arg(ord, "f") + else: + dim = [0, 1] + else: + if len(dim) == 1: + if symbolic_helper._is_none(ord): + ord = g.op("Constant", value_t=torch.LongTensor([2])) + ord_value = symbolic_helper._parse_arg(ord, "f") + if ord_value: + return linalg_vector_norm(g, self, ord_value, dim, keepdim, dtype) + return linalg_matrix_norm(g, self, ord, dim, keepdim, dtype) + + +@_onnx_symbolic("aten::linalg_vector_norm") +@symbolic_helper.parse_args("v", "f", "is", "b", "v") +def linalg_vector_norm( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: float, + dim: Sequence[int] | None, + keepdim: bool, + dtype: torch._C.Value, +): + return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype) + + +@_onnx_symbolic("aten::linalg_matrix_norm") +@symbolic_helper.parse_args("v", "v", "is", "b", "v") +def linalg_matrix_norm( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: torch._C.Value, + dim: list[int], + keepdim: bool, + dtype: torch._C.Value, +): + # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.matrix_norm.html + ord_value = symbolic_helper._parse_arg(ord, "s") + if ord_value == "fro": + return frobenius_norm(g, self, dim, keepdim) + elif ord_value == "nuc": + return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==nuc", self) + else: + ord_value = symbolic_helper._parse_arg(ord, "f") + if ord_value is None: + return frobenius_norm(g, self, dim, keepdim) + if ord_value == 2 or ord_value == -2: + # ord = 2/-2 unimplemented due to lack 
of operators + # used to calculate singular values + return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==2", self) + # Wrap the dim vector to handle negative dim values + self_dim = symbolic_helper._get_tensor_rank(self) + if self_dim is None: + return symbolic_helper._unimplemented( + "linalg.matrix_norm", "Input rank must be known at export time.", self + ) + # Common implementation for cases with + # ord = 1/-1 and ord = inf/-inf + if dim[0] < 0: + dim[0] += self_dim + if dim[1] < 0: + dim[1] += self_dim + + if ord_value == math.inf or ord_value == -math.inf: + dim[0], dim[1] = dim[1], dim[0] + if dim[1] > dim[0] and not keepdim: + dim[1] -= 1 + sum = symbolic_helper._reducesum_helper( + g, g.op("Abs", self), axes_i=[dim[0]], keepdims_i=keepdim + ) + if ord_value > 0: + result, indices = max( + g, + sum, + dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])), + keepdim=keepdim, + ) + else: + result, indices = min( + g, + sum, + dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])), + keepdim=keepdim, + ) + return result + + +@_onnx_symbolic("aten::linalg_cross") +@symbolic_helper.parse_args("v", "v", "i") +def linalg_cross(g: jit_utils.GraphContext, input, other, dim=-1): + return cross(g, input, other, dim) + + +@_onnx_symbolic("aten::frobenius_norm") +@symbolic_helper.parse_args("v", "is", "b") +def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False): + sqr = g.op("Mul", self, self) + sumsqr = symbolic_helper._reducesum_helper(g, sqr, axes_i=dim, keepdims_i=keepdim) + return g.op("Sqrt", sumsqr) + + +@_onnx_symbolic("aten::multinomial") +@symbolic_helper.parse_args("v", "i", "b", "v") +def multinomial( + g: jit_utils.GraphContext, input, num_samples, replacement=False, generator=None +): + if generator is not None and not symbolic_helper._is_none(generator): + symbolic_helper._unimplemented( + "Multinomial", "generator is not supported for multinomial", input + ) + if not replacement and num_samples > 1: + symbolic_helper._unimplemented( + "Multinomial", + "replacement=False when num_samples > 1 is not supported for multinomial", + input, + ) + + log_input = log(g, input) + return g.op( + "Multinomial", + log_input, + dtype_i=_C_onnx.TensorProtoDataType.INT64, + sample_size_i=num_samples, + ) + + +@_onnx_symbolic("aten::baddbmm") +def baddbmm(g: jit_utils.GraphContext, self, batch1, batch2, beta, alpha): + scalar_type = _type_utils.JitScalarType.from_value(self) + batch_mul = matmul(g, batch1, batch2) + mul_a = mul( + g, + batch_mul, + g.op("Cast", alpha, to_i=scalar_type.onnx_type()), + ) + mul_b = mul( + g, + self, + g.op("Cast", beta, to_i=scalar_type.onnx_type()), + ) + return add(g, mul_a, mul_b) + + +@_onnx_symbolic("aten::meshgrid") +@symbolic_helper.parse_args("v", "s") +def meshgrid(g: jit_utils.GraphContext, tensor_list, indexing: str | None = None): + if indexing is None: + indexing = "ij" + elif indexing not in {"ij", "xy"}: + raise errors.SymbolicValueError( + f"Unsupported indexing: {indexing}", tensor_list + ) + unpacked_tensor_list = symbolic_helper._unpack_list(tensor_list) + if indexing == "xy": + unpacked_tensor_list[:2] = unpacked_tensor_list[1::-1] + tensors = [ + symbolic_helper._reshape_helper( + g, t, g.op("Constant", value_t=torch.LongTensor([-1])) + ) + for t in unpacked_tensor_list + ] + tensors_shape = [g.op("Shape", t) for t in tensors] + out_shape = g.op("Concat", *tensors_shape, axis_i=0) + out = [] + for i, t in enumerate(tensors): + shape_i = [g.op("Constant", value_t=torch.ones(1, 
dtype=torch.int64))] * len( + tensors + ) + shape_i[i] = tensors_shape[i] + t_reshaped = _reshape_from_tensor(g, t, g.op("Concat", *shape_i, axis_i=0)) + out.append(g.op("Expand", t_reshaped, out_shape)) + if indexing == "xy": + out[0], out[1] = out[1], out[0] + return g.op("prim::ListConstruct", *out) + + +@_onnx_symbolic("aten::remainder") +def remainder(g: jit_utils.GraphContext, input, other): + div = _floor_divide(g, input, other) + quo = g.op("Mul", div, other) + return g.op("Sub", input, quo) + + +@_onnx_symbolic("aten::gelu") +@symbolic_helper.parse_args("v", "s") +def gelu(g: jit_utils.GraphContext, self: torch._C.Value, approximate: str = "none"): + if approximate == "tanh": + kBeta = math.sqrt(2 / math.pi) + kKappa = 0.044715 + + beta = torch.tensor(kBeta, dtype=torch.double) + kappa = torch.tensor(kKappa, dtype=torch.double) + one = torch.tensor(1.0, dtype=torch.double) + half = torch.tensor(0.5, dtype=torch.double) + + self_cube = mul(g, self, mul(g, self, self)) + inner = mul(g, beta, add(g, self, mul(g, kappa, self_cube))) + return mul(g, half, mul(g, self, add(g, one, g.op("Tanh", inner)))) + else: + _sqrt2 = 1.4142135623730951 + erf = g.op("Erf", g.op("Div", self, torch.tensor(_sqrt2, dtype=torch.double))) + erf_plusone = add( + g, erf, g.op("Constant", value_t=torch.tensor(1, dtype=torch.double)) + ) + return mul( + g, + mul(g, self, erf_plusone), + g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double)), + ) + + +@_onnx_symbolic("aten::group_norm") +@symbolic_helper.quantized_args(True, False, False, False) +@symbolic_helper.parse_args("v", "i", "v", "v", "f", "i") +def group_norm( + g: jit_utils.GraphContext, input, num_groups, weight, bias, eps, cudnn_enabled +): + channel_size = symbolic_helper._get_tensor_dim_size(input, 1) + if channel_size is not None: + assert channel_size % num_groups == 0 + input_rank = symbolic_helper._get_tensor_rank(input) + if input_rank is None: + return symbolic_helper._unimplemented("group_norm", "unknown input rank", input) + # 0 in the shape list keeps dimension value unchanged. + shape = [0, num_groups, -1] + input_reshaped = symbolic_helper._reshape_helper( + g, input, g.op("Constant", value_t=torch.LongTensor(shape)) + ) + + # C is always divisible by num_groups + # Due to shape difference. 
we need to apply weight and bias after + # instance norm computation and reshape + weight_ = g.op( + "Constant", + value_t=torch.tensor( + [1.0] * num_groups, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ), + ) + bias_ = g.op( + "Constant", + value_t=torch.tensor( + [0.0] * num_groups, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ), + ) + + norm_reshaped = g.op( + "InstanceNormalization", input_reshaped, weight_, bias_, epsilon_f=eps + ) + norm = symbolic_helper._reshape_helper(g, norm_reshaped, g.op("Shape", input)) + + if weight is None or weight.node().mustBeNone(): + weight_value = torch.tensor( + [1.0], dtype=_type_utils.JitScalarType.from_value(input).dtype() + ) + weight = g.op("Constant", value_t=weight_value) + if bias is None or bias.node().mustBeNone(): + bias_value = torch.tensor( + [0.0], dtype=_type_utils.JitScalarType.from_value(input).dtype() + ) + bias = g.op("Constant", value_t=bias_value) + + # Norm has shape [N, C, *] so we reshape weight and bias to [C, *] + axes = list(range(1, input_rank - 1)) + return add( + g, + mul(g, norm, symbolic_helper._unsqueeze_helper(g, weight, axes)), + symbolic_helper._unsqueeze_helper(g, bias, axes), + ) + + +@_onnx_symbolic("aten::_weight_norm") +@symbolic_helper.parse_args("v", "v", "i") +def _weight_norm(g: jit_utils.GraphContext, weight_v, weight_g, dim): + rank = symbolic_helper._get_tensor_rank(weight_v) + if rank is not None: + # W = g * ((v) / ||v||) + # Compute norm_except_dim for l2 norm. dim = None means over all dims + # torch's weight_norm module sets dim = -1 if it's None. + # This conflicts the logic for negative axes to access dims backwards + # TODO: Might need a fix in torch group_norm module + axes = list(range(rank)) + if dim is not None: + if dim < -1: + dim += rank + if dim != -1: + axes.remove(dim) + norm_v = norm(g, weight_v, 2, axes, 1) + div = g.op("Div", weight_v, norm_v) + return g.op("Mul", div, weight_g) + raise errors.SymbolicValueError( + "Unsupported: ONNX export of _weight_norm for tensor of unknown rank.", + weight_v, + ) + + +@_onnx_symbolic("aten::dim") +def dim(g: jit_utils.GraphContext, self): + """Implement the dim functionality available for a pytorch tensor in ONNX""" + # ONNX does not support dim directly in this opset so we can use 2 ops to get the info + shape = g.op("Shape", self) + return g.op("Size", shape) + + +@_onnx_symbolic("aten::__contains_") +def __contains_(g: jit_utils.GraphContext, self, element): + unpacked_list = symbolic_helper._unpack_list(self) + if all( + symbolic_helper._is_constant(x) for x in unpacked_list + ) and symbolic_helper._is_constant(element): + return g.op( + "Constant", + value_t=torch.tensor( + symbolic_helper._node_get(element.node(), "value") + in (symbolic_helper._node_get(x.node(), "value") for x in unpacked_list) + ), + ) + + raise errors.SymbolicValueError( + "Unsupported: ONNX export of __contains__ for non-constant list or element.", + self, + ) + + +@_onnx_symbolic("aten::__getitem_") +def __getitem_(g: jit_utils.GraphContext, self, i): + return select(g, self, g.op("Constant", value_t=torch.tensor([0])), i) + + +@_onnx_symbolic("aten::item") +def item(g: jit_utils.GraphContext, self): + return self + + +@_onnx_symbolic("aten::take") +def take(g: jit_utils.GraphContext, self, index): + self_flattened = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) + ) + out = index_select(g, self_flattened, 0, index) + out = reshape_as(g, out, index) + return out + + +def 
_kl_div_log_target_impl(g: jit_utils.GraphContext, input, target): + diff_ = sub(g, target, input) + exp_ = exp(g, target) + output = mul(g, exp_, diff_) + return output + + +def _kl_div_non_log_target_impl(g: jit_utils.GraphContext, input, target): + log_ = log(g, target) + diff_ = sub(g, log_, input) + output_pos = mul(g, target, diff_) + zeros_ = zeros_like(g, output_pos) + mask_ = gt(g, target, g.op("Constant", value_t=torch.tensor(0))) + output = where(g, mask_, output_pos, zeros_) + return output + + +@_onnx_symbolic("aten::kl_div") +@symbolic_helper.parse_args("v", "v", "i", "b") +def kl_div(g: jit_utils.GraphContext, input, target, reduction, log_target): + if log_target: + output = _kl_div_log_target_impl(g, input, target) + else: + output = _kl_div_non_log_target_impl(g, input, target) + + if reduction == 0: + return output + elif reduction == 1: + return g.op("ReduceMean", output, keepdims_i=0) + elif reduction == 2: + return symbolic_helper._reducesum_helper(g, output, keepdims_i=0) + else: + return symbolic_helper._onnx_unsupported( + "kl_div with reduction other than none, mean, or sum.", input + ) + + +@_onnx_symbolic("aten::mse_loss") +@symbolic_helper.parse_args("v", "v", "i") +def mse_loss(g: jit_utils.GraphContext, input, target, reduction): + output = mul(g, sub(g, input, target), sub(g, input, target)) + if reduction == 0: + return output + elif reduction == 1: + return g.op("ReduceMean", output, keepdims_i=0) + elif reduction == 2: + return symbolic_helper._reducesum_helper(g, output, keepdims_i=0) + else: + return symbolic_helper._onnx_unsupported( + "mse_loss with reduction other than none, mean, or sum.", input + ) + + +@_onnx_symbolic("aten::as_strided") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v", "is", "i") +def as_strided(g: jit_utils.GraphContext, self, sizes, strides, offset=None): + sizes = symbolic_helper._maybe_get_const(sizes, "is") + rank = len(strides) + self_1d = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) + ) + ind: torch.Tensor | None + if not symbolic_helper._is_value(sizes): + ind = torch.tensor([0], dtype=torch.long) + for i, (size, stride) in enumerate(zip(sizes, strides)): + r_size = [1] * rank + r_size[i] = -1 + ind = ind + torch.arange(size).view(r_size) * stride + if offset: + ind = ind + offset + return g.op("Gather", self_1d, g.op("Constant", value_t=ind)) + else: + ind = None + for i, stride in enumerate(strides): + r_size = [1] * rank + r_size[i] = -1 + size = select( + g, + sizes, + g.op("Constant", value_t=torch.tensor([0])), + g.op("Constant", value_t=torch.tensor(i)), + ) + tmp_ind = symbolic_helper._reshape_helper( + g, + arange(g, size, 4, None, None, None), + g.op("Constant", value_t=torch.tensor(r_size)), + ) + tmp_ind = g.op( + "Mul", tmp_ind, g.op("Constant", value_t=torch.tensor([stride])) + ) + if ind is None: + ind = tmp_ind + else: + ind = g.op("Add", ind, tmp_ind) + if offset: + ind = g.op("Add", ind, g.op("Constant", torch.tensor([offset]))) + return g.op("Gather", self_1d, ind) + + +@_onnx_symbolic("aten::__derive_index") +def __derive_index(g: jit_utils.GraphContext, index, start, step): + return g.op("Add", start, g.op("Mul", index, step)) + + +@_onnx_symbolic("aten::__range_length") +# Source code for aten op can be found here: pytorch/torch/csrc/jit/runtime/register_prim_ops.cpp +# if (step > 0 && lo < hi) { +# push(stack, 1 + (hi - 1 - lo) / step); +# } else if (step < 0 && lo > hi) { +# push(stack, 1 + (lo - 1 - hi) / (0 
- step)); +# } else { +# push(stack, 0); +# } +def __range_length(g: jit_utils.GraphContext, lo, hi, step): + sub = g.op("Sub", hi, lo) + div = g.op("Ceil", true_divide(g, sub, step)) + return g.op("Cast", div, to_i=_C_onnx.TensorProtoDataType.INT64) + + +@_onnx_symbolic("aten::linear") +def linear(g: jit_utils.GraphContext, input, weight, bias): + rank = symbolic_helper._get_tensor_rank(input) + weight = t(g, weight) + if rank == 2 and not bias.node().mustBeNone(): + alpha = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) + beta = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) + output = addmm(g, bias, input, weight, alpha, beta) + else: + output = matmul(g, input, weight) + if not bias.node().mustBeNone(): + output = add(g, bias, output) + + return output + + +@_onnx_symbolic("aten::hann_window") +@symbolic_helper.parse_args("v", "b", "i", "v", "v", "v", "v") +def hann_window( + g: jit_utils.GraphContext, + window_length, + periodic=True, + dtype: int | None = None, + layout=None, + device=None, + pin_memory=None, + requires_grad=False, +): + if dtype is None: + dtype_ = torch.get_default_dtype() + if not dtype_ or not dtype_.is_floating_point: + dtype_ = torch.float + scalar_type = _type_utils.JitScalarType.from_dtype(dtype_) + else: + scalar_type = _type_utils.JitScalarType(dtype) + + n_array = arange(g, window_length, 4, None, None, None) + output = g.op("Cast", n_array, to_i=_C_onnx.TensorProtoDataType.FLOAT) + output = mul( + g, g.op("Constant", value_t=torch.tensor(math.pi, dtype=torch.float)), output + ) + + if periodic is False: + window_length = sub( + g, window_length, g.op("Constant", value_t=torch.tensor(1, dtype=torch.int)) + ) + output = div(g, output, window_length) + output = g.op( + "Cast", + square(g, sin(g, output)), + to_i=scalar_type.onnx_type(), + ) + + return output + + +@_onnx_symbolic("aten::mv") +def mv(g: jit_utils.GraphContext, self, vec): + return matmul(g, self, vec) + + +@_onnx_symbolic("aten::dot") +def dot(g: jit_utils.GraphContext, self, other): + return matmul(g, self, other) + + +@_onnx_symbolic("aten::movedim") +@symbolic_helper.parse_args("v", "t", "t") +def movedim(g: jit_utils.GraphContext, self, source, destination): + # This is a pythonic implementation mostly taken from aten/src/ATen/native/TensorShape.cpp::movedim + source = source.view(-1) + destination = destination.view(-1) + + assert source.size() == destination.size() + + if (source == destination).all(): + return self + + self_rank = symbolic_helper._get_tensor_rank(self) + assert self_rank is not None + + perm = list(range(self_rank)) + + src_dims = perm.copy() + dst_dims = perm.copy() + + for src, dst in zip(source.tolist(), destination.tolist()): + perm[dst] = src + src_dims[src] = -1 + dst_dims[dst] = -1 + + src_dims = [dim for dim in src_dims if dim != -1] + dst_dims = [dim for dim in dst_dims if dim != -1] + + for src, dst in zip(src_dims, dst_dims): + perm[dst] = src + + return g.op("Transpose", self, perm_i=perm) + + +@_onnx_symbolic("aten::fill") +@symbolic_helper.parse_args("v", "v") +def fill(g: jit_utils.GraphContext, self, value): + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + return full_like(g, self, value, scalar_type) + + +@_onnx_symbolic("aten::index_add") +def index_add(g: jit_utils.GraphContext, self, dim, index, other, alpha=None): + warnings.warn( + "Warning: ONNX export does not support duplicated values in 'index' field, " + + "this will cause the ONNX model to be incorrect." 
+ ) + + # ONNX does not support "alpha" argument, unlike aten index_add + # See: https://github.com/pytorch/pytorch/pull/65993#issuecomment-953151102 for more context + if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: + return symbolic_helper._unimplemented("index_add", "alpha != 1", self) + + dim = symbolic_helper._maybe_get_const(dim, "i") + if dim is None: + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting 'index_add_()' function with " + "unknown 'dim' value.", + self, + ) + + self_dim_rank = symbolic_helper._get_tensor_rank(self) + other_dim_rank = symbolic_helper._get_tensor_rank(other) + + if self_dim_rank is None or other_dim_rank is None: + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting 'index_add_()' function while " + "the rank of self tensor or tensor to be added is unknown.", + self, + ) + + if other_dim_rank != self_dim_rank: + delta = self_dim_rank - other_dim_rank + for i in range(delta): + other = symbolic_helper._unsqueeze_helper( + g, other, [symbolic_helper._get_tensor_rank(other)] + ) + + other_dim_size = symbolic_helper._get_tensor_dim_size(other, dim) + self_dim_size = symbolic_helper._get_tensor_dim_size(self, dim) + + if (other_dim_size is not None) and (self_dim_size is not None): + if other_dim_size > self_dim_size: + raise errors.SymbolicValueError( + "ONNX export does not support exporting 'index_add_()' function with " + "duplicated values in 'index' parameter yet.", + self, + ) + + # Construct a new shape. It's almost as same as self except the size of the 'dim' + # dimension is 1, so that we can expand other dimensions as expected. + new_shape_axes = list(range(self_dim_rank)) + new_shape_starts = [0 for i in range(self_dim_rank)] + new_shape_ends = [sys.maxsize if (i != dim) else 1 for i in range(self_dim_rank)] + + new_shape = symbolic_helper._slice_helper( + g, self, axes=new_shape_axes, starts=new_shape_starts, ends=new_shape_ends + ) + other = expand_as(g, other, new_shape) + + for i in range(dim): + index = symbolic_helper._unsqueeze_helper(g, index, [0]) + + for i in range(self_dim_rank - dim - 1): + index = symbolic_helper._unsqueeze_helper( + g, index, [symbolic_helper._get_tensor_rank(index)] + ) + + return scatter_add(g, self, dim, expand_as(g, index, other), other) + + +@_onnx_symbolic("aten::roll") +@symbolic_helper.parse_args("v", "is", "is") +def roll(g: jit_utils.GraphContext, self, shifts, dims): + assert len(shifts) == len(dims) + + result = self + for i in range(len(shifts)): + shapes = [] + shape = symbolic_helper._slice_helper( + g, result, axes=[dims[i]], starts=[-shifts[i]], ends=[sys.maxsize] + ) + shapes.append(shape) + shape = symbolic_helper._slice_helper( + g, result, axes=[dims[i]], starts=[0], ends=[-shifts[i]] + ) + shapes.append(shape) + result = g.op("Concat", *shapes, axis_i=dims[i]) + + return result + + +@_onnx_symbolic("aten::cross") +@symbolic_helper.parse_args("v", "v", "i") +def cross(g: jit_utils.GraphContext, input, other, dim=None): + dim = symbolic_helper._get_dim_for_cross(input, dim) + # If we have two tensors such that + # A = [a, b, c], B = [d, e, f], we permute the tensor such that we have + # After first roll, + # A' = [b, c, a], B' = [f, d, e], so that we calculate (b*f, c*d, a*e) + roll_x_1 = roll(g, input, [2], [dim]) + roll_y_1 = roll(g, other, [1], [dim]) + # After second roll, + # A' = [c, a, b], B' = [e, f, d], so that we calculate (c*e, a*f, b*d) + roll_x_2 = roll(g, input, [1], [dim]) + roll_y_2 = roll(g, other, 
[2], [dim]) + # cross product is calculated as + # result = [(b*f - c*e), (c*d - a*f), (a*e - b*d)] + return sub(g, mul(g, roll_x_1, roll_y_1), mul(g, roll_x_2, roll_y_2)) + + +@_onnx_symbolic("aten::cdist") +def cdist( + g: jit_utils.GraphContext, + x1, + x2, + p=2.0, + compute_mode="use_mm_for_euclid_dist_if_necessary", +): + # X1.shape = (B * P * D), X2.shape = (B * R * D) + # In order to respect numpy style broadcasting as demonstrated in + # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md + # we unsqueeze both input tensors + # Currently we ignore the 'compute_mode' variable as we use default to + # using matrix multiplication to calculate the euclidean distance + rank = symbolic_helper._get_tensor_rank(x1) + assert rank is not None + broadcasted_x1 = symbolic_helper._unsqueeze_helper(g, x1, [rank - 1]) + broadcasted_x2 = symbolic_helper._unsqueeze_helper(g, x2, [rank - 2]) + return pairwise_distance( + g, broadcasted_x1, broadcasted_x2, p, eps=1e-06, keepdim=False + ) + + +@_onnx_symbolic("aten::lerp") +def lerp(g: jit_utils.GraphContext, self, end, weight): + # Conditional for better numeric. This has been discussed in + # https://github.com/pytorch/pytorch/pull/18871 + diff = g.op("Sub", end, self) + return where( + g, + g.op("Less", weight, g.op("Constant", value_t=torch.tensor(0.5))), + g.op("Add", self, g.op("Mul", weight, diff)), + g.op( + "Sub", + end, + g.op( + "Mul", + diff, + g.op("Sub", g.op("Constant", value_t=torch.tensor(1.0)), weight), + ), + ), + ) + + +@_onnx_symbolic("aten::broadcast_tensors") +def broadcast_tensors(g: jit_utils.GraphContext, self): + all_tensors = symbolic_helper._unpack_list(self) + t_with_final_shape = zeros_like(g, all_tensors[0]) + + # Add operator supports multidirectional broadcasting. So we leverage this function + # to infer the final shape generated by the broadcast. + for t in all_tensors: + t_with_final_shape = add(g, t_with_final_shape, t) + + t_list = [expand_as(g, t, t_with_final_shape) for t in all_tensors] + return g.op("prim::ListConstruct", *t_list) + + +@_onnx_symbolic("aten::is_pinned") +def is_pinned(g: jit_utils.GraphContext, self, device=None): + # Unused by ONNX. + return None + + +@_onnx_symbolic("prim::ConstantSplit") +def prim_constant_split(g: jit_utils.GraphContext, self, split_size, dim): + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + return symbolic_helper._unimplemented( + "prim::ConstantSplit", "unknown dimension size", self + ) + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + return g.op("Split", self, split_i=splits, axis_i=dim, outputs=len(splits)) + + +# TODO: It would be better to export this as a chunk directly, as this is +# less sensitive to changes in input size. 
+# TODO: Once we have proper scoping, stop reimplementing chunk, delete this +# method, and use the desugared version +@_onnx_symbolic("prim::ConstantChunk") +def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim): + dim_size = symbolic_helper._get_tensor_dim_size(self, dim) + if dim_size is None: + return symbolic_helper._unimplemented( + "prim::ConstantChunk", "unknown dimension size", self + ) + split_size = (dim_size + chunks - 1) // chunks + return prim_constant_split(g, self, split_size, dim) + + +@_onnx_symbolic("prim::shape") +def prim_shape(g: jit_utils.GraphContext, self): + return g.op("Shape", self) + + +@_onnx_symbolic("prim::max") +def prim_max(g: jit_utils.GraphContext, self, other): + return symbolic_helper._op_with_optional_float_cast( + g, "Max", self, other, opset_before=12 + ) + + +@_onnx_symbolic("prim::min") +def prim_min(g: jit_utils.GraphContext, self, other=None): + if not other: + if symbolic_helper._is_packed_list(self): + self = stack(g, self, g.op("Constant", value_t=torch.tensor([0]))) + return min(g, self) + return min(g, self, other) + + +@_onnx_symbolic("prim::data") +def prim_data(g: jit_utils.GraphContext, self): + return self + + +@_onnx_symbolic("prim::layout") +def prim_layout(g: jit_utils.GraphContext, self): + # Always return 'torch.strided'. Other layout types are not supported by JIT 'TensorType'. + # Layout class defined in 'c10/core/Layout.h'. + return g.op("Constant", value_t=torch.tensor(0)) + + +@_onnx_symbolic("prim::ListConstruct") +def prim_list_construct(g: jit_utils.GraphContext, *inputs, **kwargs): + return None + + +@_onnx_symbolic("prim::ListUnpack") +def prim_list_unpack( + g: jit_utils.GraphContext, *inputs, **kwargs +) -> list[_C.Value] | None: + if len(inputs) == 1 and inputs[0].node().kind() == "prim::ListConstruct": + # Cancel the previous node if it is ListConstruct by returning its inputs + # TODO(justinchuby): Use a public method in the helper module + return symbolic_helper._unpack_list(inputs[0]) + + return None + + +@_onnx_symbolic("prim::TupleConstruct") +def prim_tuple_construct(g: jit_utils.GraphContext, *inputs, **kwargs): + return None + + +@_onnx_symbolic("prim::Uninitialized") +def prim_uninitialized(g: jit_utils.GraphContext, *inputs, **kwargs): + return None + + +# exists to refine the type of the Value +# if x is an optional Tensor, unchecked_cast will cast +# x to Tensor, so the rest of the graph knows that x is a Tensor +# this doesn't do anything in runtime and is a noop in ONNX +@_onnx_symbolic("prim::unchecked_cast") +def prim_unchecked_cast(g: jit_utils.GraphContext, self): + return self + + +@_onnx_symbolic("prim::dtype") +def prim_dtype(g: jit_utils.GraphContext, self): + scalar_type = symbolic_helper._try_get_scalar_type(self) + if scalar_type is None: + scalar_type = _type_utils.JitScalarType.FLOAT + # This node records a torch dtype as int + return g.op("Constant", value_t=torch.tensor(scalar_type)) + + +@_onnx_symbolic("prim::tolist") +def prim_tolist(g: jit_utils.GraphContext, input, dim_val, elem_ty_val): + """tolist is currently supported only for 1D input tensors. + + dim_val and elem_ty_val represent dimension and type annotations + that need to match dimension and type of the input tensor. 
+ """ + dim = symbolic_helper._maybe_get_const(dim_val, "i") + if dim > 1: + return symbolic_helper._unimplemented("prim::tolist", "dim_val > 1", input) + return input + + +# ----------------------------------------------------------------------------- +# Symbolic functions that need extra context +# ----------------------------------------------------------------------------- +@_onnx_symbolic("prim::device") +def prim_device(g: jit_utils.GraphContext, *inputs, **kwargs) -> None: + output_type = g.original_node.output().type() + if isinstance(output_type, _C.DeviceObjType): + return None + + return symbolic_helper._unimplemented( + "prim::device", + f"output type should be 'DeviceObjType', not '{output_type.kind()}'", + g.original_node.output(), + ) + + +@_onnx_symbolic("prim::Loop") +def prim_loop(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]: + node = g.original_node + env = g.env + values_in_env = g.values_in_env + params_dict = g.params_dict + + operator_export_type = GLOBALS.operator_export_type + opset_version = GLOBALS.export_onnx_opset_version + + old_blocks = tuple(node.blocks()) + new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks( + g, "Loop", *inputs, outputs=node.outputsSize(), n_blocks=len(old_blocks) + ) + + for old_block, new_block_context in zip(old_blocks, new_block_contexts): + # Copy input metadata to subblock + # + # prim::Loop(iter, cond, input_1, ..., input_n) + # block0(iter, input_1, ..., input_n) + # + # For `Loop` node, copy metadata for `iter`, `input_1`, ..., `input_n`. + for i, b_in in enumerate(old_block.inputs()): + if i == 0 and i < len(inputs): + b_in.setType(inputs[i].type()) + # For optional block inputs, they may switch between None not-None inside + # the loop body, so if the loop input is not optional, the block input may + # still need to be optional. + if ( + i > 0 + and (i + 1) < len(inputs) + and not isinstance(b_in.type(), _C.OptionalType) + ): + b_in.setType(inputs[i + 1].type()) + torch._C._jit_pass_onnx_block( + old_block, + new_block_context.block, + operator_export_type, + env, + values_in_env, + False, + ) + fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node( + new_node, opset_version + ) + # Run shape type inference for Loop after subblock is converted. + if GLOBALS.onnx_shape_inference: + torch._C._jit_pass_onnx_node_shape_type_inference( + new_node, params_dict, opset_version + ) + return fixed_outputs + + +@_onnx_symbolic("prim::If") +def prim_if(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]: + n = g.original_node + block = g.block + env = g.env + values_in_env = g.values_in_env + params_dict = g.params_dict + + operator_export_type = GLOBALS.operator_export_type + opset_version = GLOBALS.export_onnx_opset_version + + static_if = inputs[0].node().kind() == "onnx::Constant" + if static_if: + # Fold static if + # + # The torch IR + # graph(%embedding_matrix.1 : Float(10, 15, strides=[15, 1], requires_grad=0, device=cpu), + # %input.1 : Long(6, strides=[1], requires_grad=0, device=cpu), ... 
+ # %65 : Bool(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %21 : Long(device=cpu) = aten::eq(%20, %64) + # %22 : Long(device=cpu) = prim::If(%21) + # block0(): + # %23 : Long(device=cpu) = aten::is_floating_point(%input.1) + # -> (%23) + # block1(): + # -> (%65) + # %input.53 : Tensor, %weight : Tensor = prim::If(%22) + # block0(): + # -> (%embedding_matrix.1, %input.1) + # block1(): + # -> (%input.1, %embedding_matrix.1) + # %26 : int[] = aten::size(%input.53) + # + # The converted ONNX graph + # %10 : Bool(device=cpu) = onnx::Constant[value={0}]() + # %14 : Bool(device=cpu) = onnx::Equal(%13, %8) + # %15 : Bool(requires_grad=0, device=cpu) = onnx::Constant[value={0}]() + # %16 : Long(1, strides=[1], device=cpu) = onnx::Shape(%input.1) + input_flag = symbolic_helper._node_get(inputs[0].node(), "value").tolist() + const_value = ( + all(input_flag) if isinstance(input_flag, list) else bool(input_flag) + ) + block_idx = 0 if const_value else 1 + current_b = list(n.blocks())[block_idx] + env = torch._C._jit_pass_onnx_block( + current_b, + block, + operator_export_type, + env, + values_in_env, + True, + ) + if_output_list = list(n.outputs()) + current_b_list = list(current_b.outputs()) + + final_b_list = [] + for idx in range(len(if_output_list)): + if current_b_list[idx] not in env: + raise errors.SymbolicValueError( + f"The sub block ATen output {current_b_list[idx]} is not in env.", + current_b_list[idx], + ) # type:ignore[operator] + onnx_b = env[current_b_list[idx]] + final_b_list.append(onnx_b) + return final_b_list + else: + old_blocks = tuple(n.blocks()) + new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks( + g, "If", *inputs, outputs=n.outputsSize(), n_blocks=len(old_blocks) + ) + + for old_block, new_block_context in zip(old_blocks, new_block_contexts): + torch._C._jit_pass_onnx_block( + old_block, + new_block_context.block, + operator_export_type, + env, + values_in_env, + False, + ) + fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node( + new_node, opset_version + ) + # Run shape type inference for If after subblock is converted. + if GLOBALS.onnx_shape_inference: + torch._C._jit_pass_onnx_node_shape_type_inference( + new_node, params_dict, opset_version + ) + return fixed_outputs + + +@_onnx_symbolic("prim::Constant") +def prim_constant(g: jit_utils.GraphContext, *inputs, **attrs): + node = g.original_node + + if node.mustBeNone(): + return None + # This must go before checking for string values, because some device constants + # have string values, but we want to keep them as unconverted Device types so + # that eq() can work on them. + if isinstance(node.output().type(), _C.DeviceObjType): + return None + if node.kindOf("value") == "t": + return g.op("Constant", value_t=symbolic_helper._node_get(node, "value")) + if node.kindOf("value") == "s": + return g.op("Constant", value_s=symbolic_helper._node_get(node, "value")) + if node.output().type().isSubtypeOf( + _C.ListType.ofInts() + ) or node.output().type().isSubtypeOf(_C.ListType.ofFloats()): + return g.op( + "Constant", value_t=torch.tensor(symbolic_helper._node_get(node, "value")) + ) + if node.output().type().isSubtypeOf(_C.ListType.ofStrings()): + str_constants = [ + g.op("Constant", value_s=s) + for s in symbolic_helper._node_get(node, "value") + ] + return g.op("prim::ListConstruct", *str_constants) + + raise errors.SymbolicValueError( + f"Unsupported prim::Constant kind: '{node.kindOf('value')}'. 
" + f"Please send a bug report at {_constants.PYTORCH_GITHUB_ISSUES_URL}.", + node.output(), + ) + + +@_onnx_symbolic("prim::type") +def prim_type(g: jit_utils.GraphContext, device_value: _C.Value, *args, **kwargs): + if device_value.node().kind() == "prim::device": + device = jit_utils.get_device_from_value(device_value.node().input()) + if device is not None: + return g.op("Constant", value_s=str(device)) + + return symbolic_helper._unimplemented( + "prim::type", + "Device type cannot be statically determined.", + device_value, + ) + + +@_onnx_symbolic("onnx::Placeholder") +def onnx_placeholder(g: jit_utils.GraphContext, *inputs, **attrs): + node = g.original_node + block = g.block + env = g.env + values_in_env = g.values_in_env + + return torch._C._jit_onnx_convert_pattern_from_subblock( + block, node, env, values_in_env + ) + + +@_onnx_symbolic("aten::resolve_conj") +@_onnx_symbolic("aten::resolve_neg") +def noop_complex_operators(g: jit_utils.GraphContext, input: _C.Value): + # ONNX does not have operators to *directly* manipulate real/imaginary components + # However, a few torch APIs (e.g. .tolist()) use complex operations when input is real, + # which results in failures due to missing operators for complex numbers + + # `aten::resolve_conj` and `aten::resolve_neg` can safely be implemented as no-op + return input + + +@_onnx_symbolic("aten::_conj") +@_onnx_symbolic("aten::conj_physical") +def unsupported_complex_operators(g: jit_utils.GraphContext, input: _C.Value): + # ONNX does not have operators to *directly* manipulate real/imaginary components + # However, a few torch APIs (e.g. .tolist()) use complex operations when input is real, + # which results in failures due to missing operators for complex numbers + + # While `aten::_conj` and `aten::conj_physical` raise exception when input is complex + if symbolic_helper.is_complex_value(input): + # FIXME(justinchuby): report correct name for symbolic being executed + return symbolic_helper._onnx_unsupported( + "aten::_conj, aten::conj_physical", + input, + ) + + # they can safely be implemented as no-op for real numbers only + return noop_complex_operators(g, input) + + +@_onnx_symbolic("aten::logit") +def logit(g: jit_utils.GraphContext, self: torch._C.Value, eps: torch._C.Value): + one = g.op("Constant", value_t=torch.tensor(1.0)) + + if not symbolic_helper._is_none(eps): + eps = g.op( + "Cast", eps, to_i=_type_utils.JitScalarType.from_value(self).onnx_type() + ) + one_sub_eps = g.op("Sub", one, eps) + self_less_equal_one_sub_eps = g.op("Greater", one_sub_eps, self) + temporary_self = g.op("Where", self_less_equal_one_sub_eps, self, one_sub_eps) + + temporary_self_less_eps = g.op("Less", temporary_self, eps) + z = g.op("Where", temporary_self_less_eps, eps, temporary_self) + else: + z = self + + sub = g.op("Sub", one, z) + div = g.op("Div", z, sub) + return g.op("Log", div) diff --git a/lib/python3.10/site-packages/torch/onnx/utils.py b/lib/python3.10/site-packages/torch/onnx/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2b7848e580584519505df117128d5a5a6fff9708 --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/utils.py @@ -0,0 +1,1990 @@ +# mypy: allow-untyped-defs +"""Functions to export models into the ONNX IR format. + +These models can be loaded with the ONNX library and then +converted to models which run on other deep learning frameworks. 
+""" + +from __future__ import annotations + +import contextlib +import copy +import inspect +import re +import typing +import warnings +from typing import Any, Callable, cast, Collection, Mapping, Sequence + +import torch +import torch._C._onnx as _C_onnx +import torch.jit._trace +import torch.serialization +from torch import _C +from torch.onnx import ( # noqa: F401 + _constants, + _deprecation, + _exporter_states, + errors, + symbolic_helper, +) +from torch.onnx._globals import GLOBALS +from torch.onnx._internal import diagnostics, jit_utils, onnx_proto_utils, registration + + +__all__ = [ + "is_in_onnx_export", + "select_model_mode_for_export", + "disable_apex_o2_state_dict_hook", + "setup_onnx_logging", + "exporter_context", + "export", + "model_signature", + "warn_on_static_input_change", + "unpack_quantized_tensor", + "export_to_pretty_string", + "unconvertible_ops", + "register_custom_op_symbolic", + "unregister_custom_op_symbolic", +] + + +def is_in_onnx_export() -> bool: + """Returns whether it is in the middle of ONNX export.""" + return GLOBALS.in_onnx_export + + +# TODO(justinchuby): Remove dependency to this global variable from constant_fold.cpp +# Skip check due to cannot import IValue from torch._C +_params_dict = {} # type: ignore[var-annotated] + + +@contextlib.contextmanager +def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode): + r"""A context manager to temporarily set the training mode of ``model`` + to ``mode``, resetting it when we exit the with-block. + + Args: + model: Same type and meaning as ``model`` arg to :func:`export`. + mode: Same type and meaning as ``training`` arg to :func:`export`. + """ + if not isinstance(mode, _C_onnx.TrainingMode): + raise TypeError( + f"'mode' should be a torch.onnx.TrainingMode enum, but got '{type(mode)}'." + ) + originally_training: bool = False + + if hasattr(model, "training"): + originally_training = model.training + + # ONNX opset 12 has better support for training amenable models, with updated + # versions of the dropout and batch_norm operators + if mode == _C_onnx.TrainingMode.TRAINING or ( + mode == _C_onnx.TrainingMode.PRESERVE and originally_training + ): + GLOBALS.export_training = True + if GLOBALS.export_onnx_opset_version < 12: + warnings.warn( + "You are exporting the model in training mode with onnx opset " + f"version {GLOBALS.export_onnx_opset_version}. " + "Opset versions lower than opset 12 will not be able to export " + "nodes such as Dropout and BatchNorm correctly." + ) + else: + GLOBALS.export_training = False + + GLOBALS.training_mode = mode + if mode == _C_onnx.TrainingMode.TRAINING: + model.train(True) + elif mode == _C_onnx.TrainingMode.EVAL: + model.train(False) + # else mode == _C_onnx.TrainingMode.PRESERVE, do nothing + + try: + yield + finally: + if hasattr(model, "training") and not mode == _C_onnx.TrainingMode.PRESERVE: + model.train(originally_training) + + +@contextlib.contextmanager +def disable_apex_o2_state_dict_hook(model: torch.nn.Module | torch.jit.ScriptFunction): + # Apex O2 hook state_dict to return fp16 weights as fp32. + # Exporter cannot identify them as same tensors. + # Since this hook is only used by optimizer, it is safe to + # remove this hook while exporting. 
+ if not isinstance(model, torch.jit.ScriptFunction): + model_hooks = {} # type: ignore[var-annotated] + for module in model.modules(): + for key, hook in module._state_dict_hooks.items(): + if type(hook).__name__ == "O2StateDictHook": + if module not in model_hooks: + model_hooks[module] = {} + model_hooks[module][key] = hook + if module in model_hooks: + for key in model_hooks[module]: + module._state_dict_hooks.pop(key) + try: + yield + finally: + # Add the hooks back + for module, m_map in model_hooks.items(): + for key, hook in m_map.items(): + module._state_dict_hooks[key] = hook + else: + try: + yield + finally: + pass + + +@contextlib.contextmanager +def setup_onnx_logging(verbose: bool): + is_originally_enabled = torch.onnx.is_onnx_log_enabled() + if is_originally_enabled or verbose: + torch.onnx.enable_log() + try: + yield + finally: + if not is_originally_enabled: + torch.onnx.disable_log() + + +@contextlib.contextmanager +def exporter_context(model, mode: _C_onnx.TrainingMode, verbose: bool): + with select_model_mode_for_export( + model, mode + ) as mode_ctx, disable_apex_o2_state_dict_hook( + model + ) as apex_ctx, setup_onnx_logging( + verbose + ) as log_ctx, diagnostics.create_export_diagnostic_context() as diagnostic_ctx: + yield (mode_ctx, apex_ctx, log_ctx, diagnostic_ctx) + + +def _get_torch_export_args( + args: tuple[Any, ...], + kwargs: dict[str, Any] | None, +) -> tuple[tuple[Any, ...], dict[str, Any] | None]: + """Obtain the arguments for torch.onnx.export from the model and the input arguments.""" + if not kwargs and args and isinstance(args[-1], dict): + kwargs = args[-1] + args = args[:-1] + return args, kwargs + + +def export( + model: torch.nn.Module | torch.jit.ScriptModule | torch.jit.ScriptFunction, + args: tuple[Any, ...] | torch.Tensor, + f: str, + *, + kwargs: dict[str, Any] | None = None, + export_params: bool = True, + verbose: bool = False, + training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, + input_names: Sequence[str] | None = None, + output_names: Sequence[str] | None = None, + operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX, + opset_version: int | None = None, + do_constant_folding: bool = True, + dynamic_axes: Mapping[str, Mapping[int, str]] + | Mapping[str, Sequence[int]] + | None = None, + keep_initializers_as_inputs: bool | None = None, + custom_opsets: Mapping[str, int] | None = None, + export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False, + autograd_inlining: bool = True, +) -> None: + r"""Exports a model into ONNX format. + + If ``model`` is not a :class:`torch.jit.ScriptModule` nor a + :class:`torch.jit.ScriptFunction`, this runs + ``model`` once in order to convert it to a TorchScript graph to be exported + (the equivalent of :func:`torch.jit.trace`). Thus this has the same limited support + for dynamic control flow as :func:`torch.jit.trace`. + + Args: + model: The model to be exported. + args: + + args can be structured either as: + + 1. ONLY A TUPLE OF ARGUMENTS:: + + args = (x, y, z) + + The tuple should contain model inputs such that ``model(*args)`` is a valid + invocation of the model. Any non-Tensor arguments will be hard-coded into the + exported model; any Tensor arguments will become inputs of the exported model, + in the order they occur in the tuple. + + 2. A TENSOR:: + + args = torch.Tensor([1]) + + This is equivalent to a 1-ary tuple of that Tensor. + + 3. 
A TUPLE OF ARGUMENTS ENDING WITH A DICTIONARY OF NAMED ARGUMENTS:: + + args = (x, {"y": input_y, "z": input_z}) + + All but the last element of the tuple will be passed as non-keyword arguments, + and named arguments will be set from the last element. If a named argument is + not present in the dictionary, it is assigned the default value, or None if a + default value is not provided. + + .. warning:: + This behavior will be deprecated in a future release. Please use the + kwargs argument instead. + + .. note:: + If a dictionary is the last element of the args tuple, it will be + interpreted as containing named arguments. In order to pass a dict as the + last non-keyword arg, provide an empty dict as the last element of the args + tuple. For example, instead of:: + + torch.onnx.export( + model, + ( + x, + # WRONG: will be interpreted as named arguments + {y: z}, + ), + "test.onnx.pb", + ) + + Write:: + + torch.onnx.export(model, (x, {y: z}, {}), "test.onnx.pb") + + f: Path to the output ONNX model file. E.g. "model.onnx". + kwargs: Named arguments to the model. + export_params: If True, all parameters will + be exported. Set this to False if you want to export an untrained model. + In this case, the exported model will first take all of its parameters + as arguments, with the ordering as specified by ``model.state_dict().values()`` + verbose: if True, prints a description of the + model being exported to stdout. In addition, the final ONNX graph will include the + field ``doc_string``` from the exported model which mentions the source code locations + for ``model``. If True, ONNX exporter logging will be turned on. + training: + * ``TrainingMode.EVAL``: export the model in inference mode. + * ``TrainingMode.PRESERVE``: export the model in inference mode if model.training is + False and in training mode if model.training is True. + * ``TrainingMode.TRAINING``: export the model in training mode. Disables optimizations + which might interfere with training. + input_names (list of str, default empty list): names to assign to the + input nodes of the graph, in order. + output_names (list of str, default empty list): names to assign to the + output nodes of the graph, in order. + operator_export_type (enum, default OperatorExportTypes.ONNX): + + .. warning:: + This option will be deprecated in a future release. Future exported + graphs will always use the default opset domain. + + * ``OperatorExportTypes.ONNX``: Export all ops as regular ONNX ops + (in the default opset domain). + * ``OperatorExportTypes.ONNX_FALLTHROUGH``: Try to convert all ops + to standard ONNX ops in the default opset domain. If unable to do so + (e.g. because support has not been added to convert a particular torch op to ONNX), + fall back to exporting the op into a custom opset domain without conversion. Applies + to `custom ops `_ + as well as ATen ops. For the exported model to be usable, the runtime must support + these non-standard ops. + * ``OperatorExportTypes.ONNX_ATEN``: All ATen ops (in the TorchScript namespace "aten") + are exported as ATen ops (in opset domain "org.pytorch.aten"). + `ATen `_ is PyTorch's built-in tensor library, so + this instructs the runtime to use PyTorch's implementation of these ops. + + .. warning:: + + Models exported this way are probably runnable only by Caffe2. + + This may be useful if the numeric differences in implementations of operators are + causing large differences in behavior between PyTorch and Caffe2 (which is more + common on untrained models). 
+ + * ``OperatorExportTypes.ONNX_ATEN_FALLBACK``: Try to export each ATen op + (in the TorchScript namespace "aten") as a regular ONNX op. If we are unable to do so + (e.g. because support has not been added to convert a particular torch op to ONNX), + fall back to exporting an ATen op. See documentation on OperatorExportTypes.ONNX_ATEN for + context. + For example:: + + graph(%0 : Float): + %3 : int = prim::Constant[value=0]() + # conversion unsupported + %4 : Float = aten::triu(%0, %3) + # conversion supported + %5 : Float = aten::mul(%4, %0) + return (%5) + + Assuming ``aten::triu`` is not supported in ONNX, this will be exported as:: + + graph(%0 : Float): + %1 : Long() = onnx::Constant[value={0}]() + # not converted + %2 : Float = aten::ATen[operator="triu"](%0, %1) + # converted + %3 : Float = onnx::Mul(%2, %0) + return (%3) + + .. warning:: + + Models exported this way are probably runnable only by Caffe2. + + opset_version (int, default 17): The version of the + `default (ai.onnx) opset `_ + to target. Must be >= 7 and <= 17. + do_constant_folding: Apply the constant-folding optimization. + Constant-folding will replace some of the ops that have all constant inputs + with pre-computed constant nodes. + dynamic_axes: + + By default the exported model will have the shapes of all input and output tensors + set to exactly match those given in ``args``. To specify axes of tensors as + dynamic (i.e. known only at run-time), set ``dynamic_axes`` to a dict with schema: + + * KEY (str): an input or output name. Each name must also be provided in ``input_names`` or + ``output_names``. + * VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a + list, each element is an axis index. + + For example:: + + class SumModule(torch.nn.Module): + def forward(self, x): + return torch.sum(x, dim=1) + + + torch.onnx.export( + SumModule(), + (torch.ones(2, 2),), + "onnx.pb", + input_names=["x"], + output_names=["sum"], + ) + + Produces:: + + input { + name: "x" + ... + shape { + dim { + dim_value: 2 # axis 0 + } + dim { + dim_value: 2 # axis 1 + ... + output { + name: "sum" + ... + shape { + dim { + dim_value: 2 # axis 0 + ... + + While:: + + torch.onnx.export( + SumModule(), + (torch.ones(2, 2),), + "onnx.pb", + input_names=["x"], + output_names=["sum"], + dynamic_axes={ + # dict value: manually named axes + "x": {0: "my_custom_axis_name"}, + # list value: automatic names + "sum": [0], + }, + ) + + Produces:: + + input { + name: "x" + ... + shape { + dim { + dim_param: "my_custom_axis_name" # axis 0 + } + dim { + dim_value: 2 # axis 1 + ... + output { + name: "sum" + ... + shape { + dim { + dim_param: "sum_dynamic_axes_1" # axis 0 + ... + + keep_initializers_as_inputs: If True, all the + initializers (typically corresponding to parameters) in the + exported graph will also be added as inputs to the graph. If False, + then initializers are not added as inputs to the graph, and only + the non-parameter inputs are added as inputs. + This may allow for better optimizations (e.g. constant folding) by + backends/runtimes. + + If True, `deduplicate_initializers` pass will not be executed. This means + initializers with duplicated values will not be deduplicated and + will be treated as distinct inputs to the graph. This allows different + input initializers to be supplied at the runtime following export. 
+ + If ``opset_version < 9``, initializers MUST be part of graph + inputs and this argument will be ignored and the behavior will be + equivalent to setting this argument to True. + + custom_opsets (dict[str, int], default empty dict): A dict with schema: + + * KEY (str): opset domain name + * VALUE (int): opset version + + If a custom opset is referenced by ``model`` but not mentioned in this dictionary, + the opset version is set to 1. Only custom opset domain name and version should be + indicated through this argument. + + export_modules_as_functions: Flag to enable + exporting all ``nn.Module`` forward calls as local functions in ONNX. Or a set to indicate the + particular types of modules to export as local functions in ONNX. + This feature requires ``opset_version`` >= 15, otherwise the export will fail. This is because + ``opset_version`` < 15 implies IR version < 8, which means no local function support. + Module variables will be exported as function attributes. There are two categories of function + attributes. + + 1. Annotated attributes: class variables that have type annotations via + `PEP 526-style `_ + will be exported as attributes. + Annotated attributes are not used inside the subgraph of ONNX local function because + they are not created by PyTorch JIT tracing, but they may be used by consumers + to determine whether or not to replace the function with a particular fused kernel. + + 2. Inferred attributes: variables that are used by operators inside the module. Attribute names + will have prefix "inferred::". This is to differentiate from predefined attributes retrieved from + python module annotations. Inferred attributes are used inside the subgraph of ONNX local function. + + * ``False`` (default): export ``nn.Module`` forward calls as fine grained nodes. + * ``True``: export all ``nn.Module`` forward calls as local function nodes. + * Set of type of nn.Module: export ``nn.Module`` forward calls as local function nodes, + only if the type of the ``nn.Module`` is found in the set. + + autograd_inlining: Flag used to control whether to inline autograd functions. + Refer to https://github.com/pytorch/pytorch/pull/74765 for more details. + + Raises: + :class:`torch.onnx.errors.CheckerError`: If the ONNX checker detects an invalid ONNX graph. + :class:`torch.onnx.errors.UnsupportedOperatorError`: If the ONNX graph cannot be exported because it + uses an operator that is not supported by the exporter. + :class:`torch.onnx.errors.OnnxExporterError`: Other errors that can occur during export. + All errors are subclasses of :class:`errors.OnnxExporterError`. + """ + if operator_export_type != _C_onnx.OperatorExportTypes.ONNX: + warnings.warn( + "Setting `operator_export_type` to something other than default is deprecated. " + "The option will be removed in a future release.", + category=FutureWarning, + ) + if training == _C_onnx.TrainingMode.TRAINING: + warnings.warn( + "Setting `training` to something other than default is deprecated. " + "The option will be removed in a future release. 
Please set the training mode " + "before exporting the model.", + category=FutureWarning, + ) + + args = (args,) if isinstance(args, torch.Tensor) else args + if kwargs is not None: + args = args + (kwargs,) + + _export( + model, + args, + f, + export_params, + verbose, + training, + input_names, + output_names, + operator_export_type=operator_export_type, + opset_version=opset_version, + do_constant_folding=do_constant_folding, + dynamic_axes=dynamic_axes, + keep_initializers_as_inputs=keep_initializers_as_inputs, + custom_opsets=custom_opsets, + export_modules_as_functions=export_modules_as_functions, + autograd_inlining=autograd_inlining, + ) + + return None + + +def _is_constant_tensor_list(node): + if node.kind() != "prim::Constant": + return False + output_type = node.output().type() + if output_type.isSubtypeOf(_C.ListType.ofTensors()): + return True + if output_type.isSubtypeOf(_C.ListType(_C.OptionalType.ofTensor())): + return True + + +# ONNX can't handle constants that are lists of tensors, which can +# get generated in constant prop. So we split them back into prim::ListConstructs + + +def _split_tensor_list_constants(g, block): + for node in block.nodes(): + for subblock in node.blocks(): + _split_tensor_list_constants(g, subblock) + if _is_constant_tensor_list(node): + inputs = [] + for val in node.output().toIValue(): + input = g.insertConstant(val) + input.node().moveBefore(node) + input.node().copyMetadata(node) + inputs.append(input) + + lc = ( + g.create("prim::ListConstruct", inputs) + .insertBefore(node) + .output() + .setType(_C.ListType.ofTensors()) + ) + lc.node().copyMetadata(node) + node.output().replaceAllUsesWith(lc) + + +def _optimize_graph( + graph: _C.Graph, + operator_export_type: _C_onnx.OperatorExportTypes, + _disable_torch_constant_prop: bool = False, + fixed_batch_size: bool = False, + params_dict=None, + dynamic_axes=None, + input_names=None, + module=None, +): + if params_dict is None: + params_dict = {} + + # Inline everything + _C._jit_pass_inline(graph) + + # Remove fork/wait nodes + _C._jit_pass_inline_fork_wait(graph) + _C._jit_pass_lint(graph) + if GLOBALS.autograd_inlining: + _C._jit_pass_onnx_autograd_function_process(graph) + _C._jit_pass_lower_all_tuples(graph) + + # we now record some ops like ones/zeros + # into a trace where we previously recorded constants. + # use constant prop to maintain our current level of onnx support + # without implementing symbolics for all of them + if _disable_torch_constant_prop is False: + _C._jit_pass_constant_propagation(graph) + + _split_tensor_list_constants(graph, graph) + # run dce to eliminate dead parts of the graph that might have been + # left behind by things like symbolic_override + _C._jit_pass_dce(graph) + _C._jit_pass_lint(graph) + + # CSE should improve perf when Autocast is used with disabled cache + # Autocast is disabled due to a limitation on tracer as described at https://github.com/pytorch/pytorch/issues/84092 + # Must run before _C._jit_pass_erase_number_types to prevent type substitution + if _C._jit_pass_cse(graph): + _C._jit_pass_onnx_lint(graph) + + _C._jit_pass_canonicalize_graph_fuser_ops(graph) + _C._jit_pass_lint(graph) + _C._jit_pass_peephole(graph, True) + _C._jit_pass_fuse_addmm(graph) + _C._jit_pass_lint(graph) + + _C._jit_pass_peephole(graph, True) + _C._jit_pass_lower_all_tuples(graph) + # in _jit_pass_onnx, symbolic functions are called for each node for conversion. + # However, there are nodes that cannot be converted without additional context. 
+ # For example, the number of outputs from split (and whether it is static or dynamic) is unknown + # until the point where it is unpacked by listUnpack node. + # This pass does a preprocess, and prepares the nodes such that enough context can be received + # by the symbolic function. + _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module) + _C._jit_pass_onnx_preprocess(graph) + + # onnx does not support tuples, so try to remove them + _C._jit_pass_lint(graph) + + # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0 + _C._jit_pass_prepare_division_for_onnx(graph) + + _C._jit_pass_onnx_remove_print(graph) + _C._jit_pass_onnx_preprocess_caffe2(graph) + + symbolic_helper._quantized_ops.clear() + # Unpack quantized weights for conv and linear ops and insert into graph. + _C._jit_pass_onnx_unpack_quantized_weights(graph, params_dict) + # onnx only supports tensors, so we turn all out number types into tensors + _C._jit_pass_erase_number_types(graph) + if GLOBALS.onnx_shape_inference: + input_names = [] if input_names is None else input_names + dynamic_axes = {} if dynamic_axes is None else dynamic_axes + _C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names) + _C._jit_pass_onnx_lint(graph) + + graph = _C._jit_pass_onnx(graph, operator_export_type) + _C._jit_pass_onnx_lint(graph) + _C._jit_pass_lint(graph) + + _C._jit_pass_onnx_scalar_type_analysis( + graph, True, GLOBALS.export_onnx_opset_version + ) + _C._jit_pass_lint(graph) + + _C._jit_pass_onnx_peephole( + graph, GLOBALS.export_onnx_opset_version, fixed_batch_size + ) + _C._jit_pass_lint(graph) + + # graph is not a valid jit graph anymore because types have been replaced + # (e.g. int with Tensor), so it now contains operators that don't actually + # exist. We can't run normal dead code elimination because it'd fail trying + # to look up if an operator has side effects, but we can run a dead code + # elimination variant that doesn't need to look up if an op has side effects. + _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + _C._jit_pass_lint(graph) + graph = _C._jit_pass_canonicalize(graph) + _C._jit_pass_lint(graph) + if GLOBALS.onnx_shape_inference: + _C._jit_pass_onnx_graph_shape_type_inference( + graph, params_dict, GLOBALS.export_onnx_opset_version + ) + + return graph + + +def warn_on_static_input_change(input_states): + """Warns that changes to input dictionaries and strings won't take effect in the traced ONNX graph. + + We accept dictionaries and strings as ONNX inputs, but they should be only for + configuration use. we detect here if these inputs are modified, and if so we warn + the user that the changes won't take effect in the traced ONNX graph. + """ + for input, traced_input in zip(input_states[0], input_states[1]): + if isinstance(input, dict): + if list(input.keys()) != list(traced_input.keys()): + warning = ( + "We detected that you are modifying a dictionary that is an input to your " + "model. " + "Note that dictionaries are allowed as inputs in ONNX but they should be " + "handled with care. " + "Usages of dictionaries is not recommended, and should not be used except " + "for configuration use. " + "Also note that the order and values of the keys must remain the same. " + ) + warnings.warn(warning) + elif isinstance(input, str): + if input != traced_input: + warning = ( + "The model seems to have string inputs/outputs. " + "Note that strings will not appear as inputs/outputs of the ONNX graph. 
" + ) + warnings.warn(warning) + + +def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type): + """Resolves the arguments that are ignored when export_type != operator_export_type.ONNX.""" + return arg_value + + +def _decide_keep_init_as_input( + keep_initializers_as_inputs: bool | None, + operator_export_type: _C_onnx.OperatorExportTypes, + opset_version: int, +): + """Decides whether the initializers in the graph should be listed as ONNX graph inputs. + + This method encapsulates the logic to decide whether the initializers in the graph + should be listed as ONNX graph inputs (i.e., whether to choose ONNX IR v3 or v4). + If keep_initializers_as_inputs is not specified (None), then we decide whether to keep + initializers as graph inputs (val_keep_init_as_ip) based on export type. If export type + is ONNX, then do not keep initializers as input (val_keep_init_as_ip=False). For all other + export types keep initializers as input (val_keep_init_as_ip=True). + If keep_initializers_as_inputs is specified, then respect it. Unless opset version <= 8, + in which case it must be ignored because for opset version <= 8, all initializers MUST be + part of graph input (only ONNX IR v3 is allowed), i.e. val_keep_init_as_ip=True. + + Special handling is needed for opset version 8 or lower, because irrespective + of user input for keep_initializers_as_inputs, the graph must follow ONNX IR v3 + semantics, i.e. all initializers must be listed as ONNX graph input. + """ + + if opset_version < 9: + if keep_initializers_as_inputs is False: + warnings.warn( + "Setting 'keep_initializers_as_inputs=False' for opset version" + "8 or lower would lead to an invalid ONNX graph. Therefore, " + "'keep_initializers_as_inputs=False' is ignored during export." + "Exported model will have initializers as graph inputs (compliant " + " to ONNX IR v3)." + ) + return True # i.e. True == initializers are part of graph input (ONNX IR v3) + val_keep_init_as_ip = ( + True if keep_initializers_as_inputs is None else keep_initializers_as_inputs + ) + if ( + keep_initializers_as_inputs is None + and operator_export_type is _C_onnx.OperatorExportTypes.ONNX + ): + val_keep_init_as_ip = False + return val_keep_init_as_ip + + +def _decide_add_node_names(add_node_names, operator_export_type): + return _resolve_args_by_export_type( + "add_node_names", add_node_names, operator_export_type + ) + + +def _decide_constant_folding(do_constant_folding, operator_export_type, training): + do_constant_folding = _resolve_args_by_export_type( + "do_constant_folding", do_constant_folding, operator_export_type + ) + if do_constant_folding and ( + training is not None and training is not _C_onnx.TrainingMode.EVAL + ): + warnings.warn( + "It is recommended that constant folding be turned off ('do_constant_folding=False') " + "when exporting the model in training-amenable mode, i.e. with 'training=TrainingMode.TRAIN' " + "or 'training=TrainingMode.PRESERVE' (when model is in training mode). Otherwise, some " + "learnable model parameters may not translate correctly in the exported ONNX model " + "because constant folding mutates model parameters. Please consider " + "turning off constant folding or setting the training=TrainingMode.EVAL." 
+ ) + return do_constant_folding + + +def _signature(model) -> inspect.Signature: + should_be_callable = getattr(model, "forward", model) + if callable(should_be_callable): + return inspect.signature(should_be_callable) + raise ValueError("model has no forward method and is not callable") + + +def _decide_input_format(model, args): + try: + sig = _signature(model) + except ValueError as e: + warnings.warn(f"{e}, skipping _decide_input_format") + return args + try: + ordered_list_keys = list(sig.parameters.keys()) + if ordered_list_keys[0] == "self": + ordered_list_keys = ordered_list_keys[1:] + args_dict: dict = {} + if isinstance(args, list): + args_list = args + elif isinstance(args, tuple): + args_list = list(args) + else: + args_list = [args] + if isinstance(args_list[-1], dict): + args_dict = args_list[-1] + args_list = args_list[:-1] + n_nonkeyword = len(args_list) + for optional_arg in ordered_list_keys[n_nonkeyword:]: + if optional_arg in args_dict: + args_list.append(args_dict[optional_arg]) + # Check if this arg has a default value + else: + param = sig.parameters[optional_arg] + if param.default != param.empty: + args_list.append(param.default) + args = args_list if isinstance(args, list) else tuple(args_list) + # Cases of models with no input args + except IndexError: + warnings.warn("No input args, skipping _decide_input_format") + except Exception as e: + warnings.warn(f"Skipping _decide_input_format\n {e.args[0]}") + return args + + +def _from_dynamic_axes_to_dynamic_shapes( + model, + dynamic_axes: Mapping[str, Mapping[int, str]] + | Mapping[str, Sequence[int]] + | None = None, + input_names: Sequence[str] | None = None, +) -> dict[str, Any] | None: + """ + + dynamic_axes examples: + (1) dynamic_axes = {"x": {0: "my_custom_axis_name_1"}, "y": {1: "my_custom_axis_name_2"}} + (2) dynamic_axes = {"x": [0], "y": [1]} + + these will be converted to dynamic_shapes respectively: + (1) dynamic_shapes = {"x": {0: Dim("my_custom_axis_name_1")}, "y": {1: Dim("my_custom_axis_name_2")}} + (2) dynamic_shapes = {"x": {0: Dim("x_dim_0")}, "y": {1: Dim("y_dim_1")}} # auto-generated dim names + + """ + if dynamic_axes is None: + return None + + if input_names is None: + input_names_set = set() + else: + input_names_set = set(input_names) + + dynamic_shapes: dict[str, Any | None] = {} + for input_name, axes in dynamic_axes.items(): + if input_name in input_names_set: + raise ValueError( + "Assinging new input names is not supported yet. Please use model forward signature " + "to specify input names in dynamix_axes." 
+ ) + if isinstance(axes, dict): + dynamic_shapes[input_name] = { + k: torch.export.Dim(v) for k, v in axes.items() + } + elif isinstance(axes, list): + dynamic_shapes[input_name] = { + k: torch.export.Dim(f"{input_name}_dim_{k}") for k in axes + } + else: + raise TypeError( + f"dynamic_axes value must be either a dict or a list, but got {type(axes)}" + ) + # torch.export.export needs static dim to present in dynamic_shapes + # for all input tensors, so we need to add them with None + try: + sig = _signature(model) + except ValueError as e: + warnings.warn(f"{e}, skipping auto filling None on static axes...") + return dynamic_shapes + for input_name in sig.parameters.keys(): + if input_name not in dynamic_shapes: + dynamic_shapes[input_name] = None + return dynamic_shapes + + +def _trace(func, args, operator_export_type, return_outs=False): + # Special case for common case of passing a single Tensor + if isinstance(args, torch.Tensor): + args = (args,) + + trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph( + func, + args, + strict=False, + _force_outplace=False, + _return_inputs_states=True, + ) + warn_on_static_input_change(inputs_states) + + trace_graph = _optimize_graph(trace_graph, operator_export_type, params_dict={}) + if return_outs: + return trace_graph, torch_out + return trace_graph + + +def _trace_and_get_graph_from_model(model, args): + # A basic sanity check: make sure the state_dict keys are the same + # before and after running the model. Fail fast! + orig_state_dict_keys = torch.jit._unique_state_dict(model).keys() + + # Disable Autocast cache because it replaces kernel's weight and bias + # by (undesired) constants. + # No perf impact for when there are reused weights since https://github.com/pytorch/pytorch/pull/85665 + prev_autocast_cache_enabled = torch.is_autocast_cache_enabled() + torch.set_autocast_cache_enabled(False) + trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph( + model, + args, + strict=False, + _force_outplace=False, + _return_inputs_states=True, + ) + torch.set_autocast_cache_enabled(prev_autocast_cache_enabled) + + warn_on_static_input_change(inputs_states) + + if orig_state_dict_keys != torch.jit._unique_state_dict(model).keys(): + raise RuntimeError( + "state_dict changed after running the tracer; " + "something weird is happening in your model!" + ) + + return trace_graph, torch_out + + +def _get_param_count_list(method_graph, args_params): + param_count_list = [] + for input_, arg_params_ in zip(method_graph.inputs(), args_params): + if "PackedParams" in str(input_.type()): + in_vars, _ = torch.jit._flatten(arg_params_) + param_count_list.append(len(in_vars)) + else: + param_count_list.append(arg_params_ is not None) + + return param_count_list + + +def _check_flatten_did_not_remove(original, jit_flattened): + """torch.jit._flatten removes None. Check if it did so in this case.""" + + def flatten(x): + if isinstance(x, (list, tuple)): + for inner in x: + yield from flatten(inner) + elif isinstance(x, dict): + for inner in x.values(): + yield from flatten(inner) + else: + yield x + + flattened_with_none = list(flatten(original)) + num_none = len(flattened_with_none) - len(jit_flattened) + assert num_none >= 0 + if num_none: + raise ValueError( + f"args contained {num_none} None's after flattening. " + "When exporting a ScriptModule or ScriptFunction, no args may " + "be None because that breaks type propagation." 
+ ) + + +def _create_jit_graph( + model: torch.nn.Module | torch.jit.ScriptFunction, args: Sequence[Any] +) -> tuple[_C.Graph, list[_C.IValue], Any | None, _C.ScriptModule | None]: + if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)): + flattened_args = tuple(torch.jit._flatten(tuple(args))[0]) + _check_flatten_did_not_remove(args, flattened_args) + torch_out = None + + if isinstance(model, torch.jit.ScriptModule): + try: + graph = model.forward.graph # type: ignore[attr-defined] + except AttributeError as e: + raise RuntimeError("'forward' method must be a script method") from e + _C._jit_pass_onnx_function_substitution(graph) + freezed_module = _C._freeze_module( + cast(_C.ScriptModule, model._c), preserveParameters=True + ) + module, params = _C._jit_onnx_list_model_parameters(freezed_module) + method_graph = module._get_method("forward").graph + args_params = tuple(args) + tuple(params) + param_count_list = _get_param_count_list(method_graph, args_params) + in_vars, _ = torch.jit._flatten(args_params) + graph = _C._propagate_and_assign_input_shapes( + method_graph, tuple(in_vars), param_count_list, False, False + ) + return graph, params, torch_out, module + + # torch.jit.ScriptFunction + params = [] + graph = model.graph + _C._jit_pass_onnx_function_substitution(graph) + param_count_list = _get_param_count_list(graph, args) + graph = _C._propagate_and_assign_input_shapes( + graph, flattened_args, param_count_list, False, False + ) + return graph, params, torch_out, None + + graph, torch_out = _trace_and_get_graph_from_model(model, args) + _C._jit_pass_onnx_lint(graph) + state_dict = torch.jit._unique_state_dict(model) + params = list(state_dict.values()) + graph_inputs = list(graph.inputs()) + user_input_num = len(graph_inputs) - len(state_dict) + param_names = list(state_dict.keys()) + for i, inp in enumerate(graph_inputs): + if i >= user_input_num: + inp.setDebugName(param_names[i - user_input_num]) + _C._jit_pass_onnx_function_substitution(graph) + return graph, params, torch_out, None + + +def _get_named_param_dict(graph, params): + input_and_param_names = [val.debugName() for val in graph.inputs()] + param_names = input_and_param_names[len(input_and_param_names) - len(params) :] + _params_dict = dict(zip(param_names, params)) + return _params_dict + + +def _get_example_outputs(model, args): + input_args = copy.deepcopy(args) + input_kwargs = {} + if input_args and isinstance(input_args[-1], dict): + input_kwargs = input_args[-1] + input_args = input_args[:-1] + + example_outputs = model(*input_args, **input_kwargs) + if isinstance(example_outputs, list): + example_outputs = [example_outputs] + elif not isinstance(example_outputs, tuple): + example_outputs = (example_outputs,) + + return example_outputs + + +_qtype_vtype_map = { + torch.quint8: torch.uint8, + torch.qint8: torch.int8, + torch.qint32: torch.int32, + torch.quint4x2: torch.int8, +} + + +def unpack_quantized_tensor(value, cast_onnx_accepted=True): + if isinstance(value, torch.Tensor) and value.dtype in _qtype_vtype_map: + q_value_dequantize = value.dequantize() + q_scale = ( + torch.tensor(value.q_scale(), dtype=torch.double) + if cast_onnx_accepted + else torch.tensor(value.q_scale(), dtype=torch.float32) + ) + q_zero_point = ( + torch.tensor(value.q_zero_point(), dtype=torch.int64) + if cast_onnx_accepted + else torch.tensor(value.q_zero_point(), dtype=_qtype_vtype_map[value.dtype]) + ) + q_value = q_value_dequantize / q_scale + q_zero_point + q_value = 
q_value.to(dtype=_qtype_vtype_map[value.dtype]) + return q_value, q_scale, q_zero_point + else: + return (value,) + + +def _pre_trace_quant_model(model, args): + r"""Returns `torch.jit.trace(model, args)` if model is quantized. Otherwise do nothing and return + original model. + + This is due to https://github.com/pytorch/pytorch/issues/75761. + """ + if any( + hasattr(m, "_packed_params") for m in getattr(model, "modules", list)() + ) or any(getattr(arg, "is_quantized", False) for arg in args): + return torch.jit.trace(model, args) + return model + + +def _model_to_graph( + model, + args, + verbose=False, + input_names=None, + output_names=None, + operator_export_type=_C_onnx.OperatorExportTypes.ONNX, + do_constant_folding=True, + _disable_torch_constant_prop=False, + fixed_batch_size=False, + training=_C_onnx.TrainingMode.EVAL, + dynamic_axes=None, +) -> tuple[ + _C.Graph, + dict[str, torch.Tensor], + torch.Tensor + | tuple[torch.Tensor, ...] + | list[torch.Tensor] + | dict[str, torch.Tensor] + | Any + | None, +]: + """Converts model into an ONNX graph. + + Returns: + graph: A TorchScript IR Graph with ONNX nodes. + params_dict: Dict from input param name to param value. + torch_out: The output tensors resulting from the trace of ``model``. + If ``model`` is a :class:`torch.jit.ScriptModule` or :class:`torch.jit.ScriptFunction`, + this will be None, since we are not doing any tracing. + """ + # TODO: can we simplify this to always return a tuple of Tensor or None? + + # Special case for common case of passing a single Tensor + if isinstance(args, (torch.Tensor, int, float, bool)): + args = (args,) + + model = _pre_trace_quant_model(model, args) + graph, params, torch_out, module = _create_jit_graph(model, args) + params_dict = _get_named_param_dict(graph, params) + + try: + graph = _optimize_graph( + graph, + operator_export_type, + _disable_torch_constant_prop=_disable_torch_constant_prop, + fixed_batch_size=fixed_batch_size, + params_dict=params_dict, + dynamic_axes=dynamic_axes, + input_names=input_names, + module=module, + ) + except Exception as e: + torch.onnx.log("Torch IR graph at exception: ", graph) + raise + + is_script = isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)) + if is_script: + example_outputs = _get_example_outputs(model, args) + example_outputs_final = () + for example_output in example_outputs: + example_outputs_final += unpack_quantized_tensor(example_output) + out_vars, desc = torch.jit._flatten(example_outputs_final) + _C._jit_pass_onnx_assign_output_shape( + graph, + out_vars, + desc, + GLOBALS.onnx_shape_inference, + is_script, + GLOBALS.export_onnx_opset_version, + ) + + # NB: ONNX requires complete information about output types, which might be + # erased by some optimizations, so we need to set it explicitly again. + else: + if not isinstance(torch_out, (list, tuple)): + output_wrapped = [torch_out] + else: + output_wrapped = torch_out # type: ignore[assignment] + + output_tensors, out_desc = torch.jit._flatten(tuple(output_wrapped)) + # assign_output_shape pass is not compatible with quantized outputs. + # Quantized outputs are flattened to 3 values in ONNX, while packed as + # single value in PyTorch. 
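+            # (Illustration, based on unpack_quantized_tensor above: a single
+            # torch.quint8 output is represented in ONNX as a (uint8 values, scale,
+            # zero_point) triple, so the output counts would not line up for the
+            # shape-assignment pass.)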
+ if not any(getattr(out, "is_quantized", False) for out in output_tensors): + _C._jit_pass_onnx_assign_output_shape( + graph, + output_tensors, + out_desc, + GLOBALS.onnx_shape_inference, + is_script, + GLOBALS.export_onnx_opset_version, + ) + + _set_input_and_output_names(graph, input_names, output_names) + params_dict = _get_named_param_dict(graph, params) + + if ( + do_constant_folding + and GLOBALS.export_onnx_opset_version + >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET + ): + if training is None or training == _C_onnx.TrainingMode.EVAL: + params_dict = _C._jit_pass_onnx_eval_peephole(graph, params_dict) + + params_dict = _C._jit_pass_onnx_constant_fold( + graph, params_dict, GLOBALS.export_onnx_opset_version + ) + _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + + if GLOBALS.onnx_shape_inference: + _C._jit_pass_onnx_graph_shape_type_inference( + graph, params_dict, GLOBALS.export_onnx_opset_version + ) + + params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict) + + # For ONNX opset < 9, constants only have three data types: float16, float, double. + # In this pass transform constants of other data types to float/double + cast operator. + if GLOBALS.export_onnx_opset_version < 9: + _C._jit_pass_onnx_cast_all_constant_to_floating(graph) + + params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict) + _C._jit_decay_packed_param_input_types(graph) + + # If output names lack a proper name and are identified only by their unique + # give them a legible name for debugging purposes + _apply_friendly_debug_names(graph, params_dict) + + return graph, params_dict, torch_out + + +@torch._disable_dynamo +@_deprecation.deprecated("2.5", "the future", "use onnx.printer.to_text() instead") +def export_to_pretty_string( + model, + args, + export_params=True, + verbose=False, + training=_C_onnx.TrainingMode.EVAL, + input_names=None, + output_names=None, + operator_export_type=_C_onnx.OperatorExportTypes.ONNX, + export_type=None, + google_printer=False, + opset_version=None, + keep_initializers_as_inputs=None, + custom_opsets=None, + add_node_names=True, + do_constant_folding=True, + dynamic_axes=None, +): + """Similar to :func:`export`, but returns a text representation of the ONNX model. + + Only differences in args listed below. All other args are the same + as :func:`export`. + + Args: + add_node_names (bool, default True): Whether or not to set + NodeProto.name. This makes no difference unless + ``google_printer=True``. + google_printer (bool, default False): If False, will return a custom, + compact representation of the model. If True will return the + protobuf's `Message::DebugString()`, which is more verbose. + + Returns: + A UTF-8 str containing a human-readable representation of the ONNX model. 
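+
+    Example (a minimal sketch; ``MyModel`` is a placeholder for any
+    ``torch.nn.Module`` and is not defined in this module)::
+
+        model = MyModel().eval()
+        dummy_input = torch.randn(1, 3, 224, 224)
+        text = torch.onnx.export_to_pretty_string(model, (dummy_input,))
+        print(text)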
+ """ + if opset_version is None: + opset_version = _constants.ONNX_DEFAULT_OPSET + if custom_opsets is None: + custom_opsets = {} + GLOBALS.export_onnx_opset_version = opset_version + GLOBALS.operator_export_type = operator_export_type + + with exporter_context(model, training, verbose): + val_keep_init_as_ip = _decide_keep_init_as_input( + keep_initializers_as_inputs, operator_export_type, opset_version + ) + val_add_node_names = _decide_add_node_names( + add_node_names, operator_export_type + ) + val_do_constant_folding = _decide_constant_folding( + do_constant_folding, operator_export_type, training + ) + args = _decide_input_format(model, args) + graph, params_dict, torch_out = _model_to_graph( + model, + args, + verbose, + input_names, + output_names, + operator_export_type, + val_do_constant_folding, + training=training, + dynamic_axes=dynamic_axes, + ) + + return graph._pretty_print_onnx( # type: ignore[attr-defined] + params_dict, + opset_version, + False, + operator_export_type, + google_printer, + val_keep_init_as_ip, + custom_opsets, + val_add_node_names, + ) + + +@_deprecation.deprecated("2.5", "the future", "avoid using this function") +def unconvertible_ops( + model, + args, + training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, + opset_version: int | None = None, +) -> tuple[_C.Graph, list[str]]: + """Returns an approximated list of all ops that are yet supported by :mod:`torch.onnx`. + + The list is approximated because some ops may be removed during the conversion + process and don't need to be converted. Some other ops may have partial support + that will fail conversion with particular inputs. Please open a Github Issue + for op support requests. + + Args: + model: Same as the `model` parameter in :func:`torch.onnx.export`. + args: Same as the `args` parameter in :func:`torch.onnx.export`. + training: Same as the `training` parameter in :func:`torch.onnx.export`. + opset_version: Same as the `opset_version` parameter in :func:`torch.onnx.export`. + + Returns: + The JIT graph and a list of unconvertible ops in the format of "domain::op". + """ + + opset_version = opset_version or _constants.ONNX_DEFAULT_OPSET + GLOBALS.export_onnx_opset_version = opset_version + + try: + with exporter_context(model, training, verbose=False): + # Create a mostly clean JIT graph that contains the plain aten and + # other ops we can check with the symbolic registry. + # NOTE: We don't want to actually convert any ops to ONNX or run any + # symbolic functions because there is a higher chance that a pass + # fails or an unconvertible op messes up the graph during ONNX conversion. + # This way we can always generate a list just by looking at the names + # of the ops in the graph. + args = _decide_input_format(model, args) + model = _pre_trace_quant_model(model, args) + graph, _, _, module = _create_jit_graph(model, args) + _C._jit_pass_inline(graph) + _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module) + _C._jit_pass_erase_number_types(graph) + _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + except Exception as e: + raise errors.OnnxExporterError( + "Failed to discover unconvertible ops because of errors during the JIT graph " + "generation process." 
+ ) from e + + unsupported_ops = [] + for node in graph.nodes(): + domain_op = node.kind() + if domain_op.startswith(("onnx::", "prim::")): + # We consider onnx and prim ops as supported ops, even though some "prim" + # ops are not implemented as symbolic functions, because they may be + # eliminated in the conversion passes. Users may still see errors caused + # by prim ops even though they don't show up in the list. + continue + if not registration.registry.is_registered_op( + domain_op.rstrip("_"), opset_version + ): + # We consider all registered ops supported, even though some of them are + # only partially supported, because there is not yet a good way to check + # if an op is fully supported. + # TODO(justinchuby): Create a way to check if an op is fully supported. + unsupported_ops.append(domain_op) + return graph, unsupported_ops + + +def _setup_trace_module_map( + model: torch.nn.Module | torch.jit.ScriptModule, + export_modules_as_functions: bool | Collection[type[torch.nn.Module]], +) -> set[str]: + def __register_attribute_hook(): + attr_name = "_onnx_attrs" + + def _track_module_attributes_forward_pre_hook(module, input): + setattr(module, attr_name, _get_module_attributes(module)) + + def _track_module_attributes_forward_hook(module, input, output): + tracing_state = _C._get_tracing_state() + if not tracing_state: + return + + graph = tracing_state.graph() + onnx_attrs = {} + if hasattr(module, attr_name): + onnx_attrs = getattr(module, attr_name) + delattr(module, attr_name) + + _C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs) + + for m in model.modules(): + m.register_forward_hook(_track_module_attributes_forward_hook) + m.register_forward_pre_hook(_track_module_attributes_forward_pre_hook) + + def _unqualified_variable_name(qualified_name: str) -> str: + """ + Parse qualified variable name and return the unqualified version. + + Pure numeric atoms are considered inadequate, so this function will look past them, + and start from the first non-numeric atom. + + Example: + >>> _unqualified_variable_name("__main__.Foo.bar") + 'bar' + >>> _unqualified_variable_name("__main__.Foo.bar.0") + 'bar.0' + """ + name_atoms = qualified_name.split(".") + for i, atom in reversed(list(enumerate(name_atoms))): + if not atom.isnumeric(): + return ".".join(name_atoms[i:]) + return qualified_name + + trace_module_map = { + _m: torch._C._jit_onnx_create_full_scope_name( + torch.typename(type(_m)), _unqualified_variable_name(_n) + ) + for _n, _m in model.named_modules() + } + torch.jit._trace._trace_module_map = trace_module_map + if isinstance(export_modules_as_functions, bool) and export_modules_as_functions: + module_typenames = {torch.typename(type(module)) for module in trace_module_map} + elif isinstance(export_modules_as_functions, set) and export_modules_as_functions: + + def _find_typename(v): + if isinstance(v, type): + return torch.typename(v) + else: + raise RuntimeError( + "Only type of the `nn.Module` should be " + "passed in the set for argument `export_modules_as_functions`. " + f"Got `{type(v).__name__}`." 
+ ) + + module_typenames = {_find_typename(v) for v in export_modules_as_functions} + else: + module_typenames = set() + + if module_typenames: + __register_attribute_hook() + + return module_typenames + + +def _reset_trace_module_map(): + torch.jit._trace._trace_module_map = None + _C._jit_pass_onnx_clear_scope_records() + + +def _get_module_attributes(module): + annotations = typing.get_type_hints(type(module)) + base_m_annotations = typing.get_type_hints(torch.nn.Module) + [annotations.pop(k, None) for k in base_m_annotations] + # Check whether module attributes can be accessed. Some classes + # define attributes but don't provide access to them in their + # constructor. + # + # For example, torch.nn.Embedding has the `freeze` variable and its + # type specified in the class but the attribute is not created in the + # constructor. In other words, there is no `self.freeze = ` + # in the constructor. + # + # Reference: https://github.com/pytorch/pytorch/blob/92de1d322223fb5584e384971b32c46b93bc2f4b/torch/nn/modules/sparse.py#L120 + attrs = {} + for k in annotations: + try: + attrs[k] = getattr(module, k) + except AttributeError: + torch.onnx.log(f"Skipping module attribute '{k}'") + continue + return attrs + + +def _export( + model, + args, + f, + export_params=True, + verbose=False, + training=_C_onnx.TrainingMode.EVAL, + input_names=None, + output_names=None, + operator_export_type=_C_onnx.OperatorExportTypes.ONNX, + export_type=None, + opset_version=None, + do_constant_folding=True, + dynamic_axes=None, + keep_initializers_as_inputs=None, + fixed_batch_size=False, + custom_opsets=None, + add_node_names=True, + onnx_shape_inference=True, + export_modules_as_functions: Any = False, + autograd_inlining=True, +): + assert GLOBALS.in_onnx_export is False + + if export_type is None: + export_type = _exporter_states.ExportTypes.PROTOBUF_FILE + + if isinstance(model, torch.nn.DataParallel): + raise ValueError( + "torch.nn.DataParallel is not supported by ONNX " + "exporter, please use 'attribute' module to " + "unwrap model from torch.nn.DataParallel. Try " + "torch.onnx.export(model.module, ...)" + ) + + GLOBALS.onnx_shape_inference = onnx_shape_inference + + if opset_version is None: + opset_version = _constants.ONNX_DEFAULT_OPSET + + # torch.onnx.export does not support opset versions >=18 + if opset_version > _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET: + # We do not want to fail because we should still allow users to create + # custom symbolic functions for opset>17 + warnings.warn( + f"Exporting to ONNX opset version {opset_version} is not supported. " + f"by 'torch.onnx.export()'. " + f"The highest opset version supported is {_constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET}. " + f"To use a newer opset version, consider 'torch.onnx.export(..., dynamo=True)'. ", + category=errors.OnnxExporterWarning, + ) + + if export_modules_as_functions and opset_version < 15: + raise ValueError( + "`export_modules_as_functions` is not supported for `opset_version` < 15." + "This is because `opset_version` < 15 implies IR version < 8, which means " + "no local function support. " + ) + if not operator_export_type: + operator_export_type = _C_onnx.OperatorExportTypes.ONNX + + # By default, training=TrainingMode.EVAL, + # which is good because running a model in training mode could result in + # internal buffers getting updated, dropout getting applied, etc. 
+ # If you really know what you're doing, you can turn + # training=TrainingMode.TRAINING or training=TrainingMode.PRESERVE, + # (to preserve whatever the original training mode was.) + GLOBALS.export_onnx_opset_version = opset_version + GLOBALS.operator_export_type = operator_export_type + + try: + GLOBALS.in_onnx_export = True + _autograd_inlining_previous = GLOBALS.autograd_inlining + GLOBALS.autograd_inlining = autograd_inlining + + module_typenames_to_export_as_functions: set[str] = set() + if isinstance(model, (torch.nn.Module, torch.jit.ScriptModule)): + module_typenames_to_export_as_functions = _setup_trace_module_map( + model, export_modules_as_functions + ) + + with exporter_context(model, training, verbose): + val_keep_init_as_ip = _decide_keep_init_as_input( + keep_initializers_as_inputs, + operator_export_type, + opset_version, + ) + val_add_node_names = _decide_add_node_names( + add_node_names, operator_export_type + ) + val_do_constant_folding = _decide_constant_folding( + do_constant_folding, operator_export_type, training + ) + # Normally f can be a file-like object, but for large models, the external data format requires a + # valid `model_file_location`. Code in export.cpp will enforce this. + if isinstance(f, str): + model_file_location = f + else: + model_file_location = "" + args = _decide_input_format(model, args) + if dynamic_axes is None: + dynamic_axes = {} + _validate_dynamic_axes(dynamic_axes, model, input_names, output_names) + + graph, params_dict, torch_out = _model_to_graph( + model, + args, + verbose, + input_names, + output_names, + operator_export_type, + val_do_constant_folding, + fixed_batch_size=fixed_batch_size, + training=training, + dynamic_axes=dynamic_axes, + ) + + # TODO: Don't allocate a in-memory string for the protobuf + defer_weight_export = ( + export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE + ) + if custom_opsets is None: + custom_opsets = {} + + _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + node_attr_to_name = {} # type: ignore[var-annotated] + if module_typenames_to_export_as_functions: + # NOTE: cannot call DCE after this pass. DCE will remove function definition nodes. + node_attr_to_name = _C._jit_pass_onnx_function_extraction( + graph, + module_typenames_to_export_as_functions, + list(params_dict.keys()), + ) + + if keep_initializers_as_inputs is not True: + params_dict = _C._jit_pass_onnx_deduplicate_initializers( # type: ignore[assignment] + graph, + params_dict, # type: ignore[arg-type] + getattr(model, "training", False), # type: ignore[arg-type] + ) + _C._jit_pass_onnx_assign_scoped_names_for_node_and_value(graph) + if export_params: + ( + proto, + export_map, + val_use_external_data_format, + node_names, + ) = graph._export_onnx( # type: ignore[attr-defined] + params_dict, + opset_version, + dynamic_axes, + defer_weight_export, + operator_export_type, + not verbose, + val_keep_init_as_ip, + custom_opsets, + val_add_node_names, + model_file_location, + node_attr_to_name, + ) + else: + ( + proto, + export_map, + val_use_external_data_format, + node_names, + ) = graph._export_onnx( # type: ignore[attr-defined] + {}, + opset_version, + dynamic_axes, + False, + operator_export_type, + not verbose, + val_keep_init_as_ip, + custom_opsets, + val_add_node_names, + model_file_location, + node_attr_to_name, + ) + # insert function_proto into model_proto. 
+ proto = onnx_proto_utils._add_onnxscript_fn( + proto, + custom_opsets, + ) + if verbose: + torch.onnx.log("Exported graph: ", graph) + onnx_proto_utils._export_file(proto, f, export_type, export_map) + finally: + assert GLOBALS.in_onnx_export + GLOBALS.in_onnx_export = False + GLOBALS.autograd_inlining = _autograd_inlining_previous + _reset_trace_module_map() + + return torch_out + + +def _apply_friendly_debug_names(graph, params): + for n in graph.nodes(): + for v in n.inputs(): + old_name = v.debugName() + if old_name != str(v.unique()): + continue + new_name = f"{n.kind()}_{v.unique()}" + v.setDebugName(new_name) + if old_name in params: + params[new_name] = params.pop(old_name) + + +def _set_input_and_output_names(graph, input_names, output_names): + def set_names(node_list, name_list, descriptor): + if name_list is None: + return + if len(name_list) > len(node_list): + raise RuntimeError( + "number of %s names provided (%d) exceeded number of %ss (%d)" + % (descriptor, len(name_list), descriptor, len(node_list)) + ) + + # Mark if the output node DebugName is set before. + output_node_set = set() + for i, (name, node) in enumerate(zip(name_list, node_list)): + # Duplicated output node, insert onnx::Identity to avoid setting the same DebugName after setDebugName(). + if descriptor == "output": + if node in output_node_set: + identity_node = graph.create("onnx::Identity") + identity_node.insertAfter(node.node()) + identity_node.addInput(node) + identity_node.output().setType(node.type()) + graph.return_node().replaceInput(i, identity_node.output()) + node = identity_node.output() + output_node_set.add(node) + + if node.debugName() != name: + node.setDebugName(name) + + set_names(list(graph.inputs()), input_names, "input") + set_names(list(graph.outputs()), output_names, "output") + + +def _run_symbolic_method(g, op_name, symbolic_fn, args): + r""" + This trampoline function gets invoked for every symbolic method + call from C++. + """ + try: + graph_context = jit_utils.GraphContext( + graph=g, + block=g.block(), + opset=GLOBALS.export_onnx_opset_version, + original_node=None, # type: ignore[arg-type] + params_dict=_params_dict, + env={}, + values_in_env=set(), + new_nodes=[], + ) + return symbolic_fn(graph_context, *args) + except TypeError as e: + # Handle the specific case where we didn't successfully dispatch + # to symbolic_fn. Otherwise, the backtrace will have the clues + # you need. 
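+        # Append the op name to the error message before re-raising so it is clear
+        # which operator's symbolic dispatch failed.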
+ e.args = (f"{e.args[0]} (occurred when translating {op_name})",) + raise + + +def _add_block(node: _C.Node) -> _C.Block: + return node.addBlock() + + +def _add_input_to_block(block: _C.Block): + return block.addInputToBlock() # type: ignore[attr-defined] + + +def _add_output_to_block(block: _C.Block, value: _C.Value) -> int: + return block.registerOutput(value) + + +def _should_aten_fallback( + name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes +): + # For all builds, if domain=="aten" and operator_export_type==ONNX_ATEN, + # an aten::ATen operator is created regardless of symbolics existence + + is_exportable_aten_op = registration.registry.is_registered_op(name, opset_version) + is_onnx_aten_export = operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN + is_aten_fallback_export = ( + operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK + ) + + if not name.startswith("aten::"): + return False + + if is_onnx_aten_export or (is_aten_fallback_export and not is_exportable_aten_op): + return True + + return False + + +def _get_aten_op_overload_name(n: _C.Node) -> str: + # Returns `overload_name` attribute to ATen ops on non-Caffe2 builds + schema = n.schema() + if not schema.startswith("aten::"): + return "" + return _C.parse_schema(schema).overload_name + + +def _run_symbolic_function( + graph: _C.Graph, + block: _C.Block, + node: _C.Node, + inputs: Any, + env: dict[_C.Value, _C.Value], + values_in_env: set[_C.Value], + new_nodes: list[_C.Node], + operator_export_type=_C_onnx.OperatorExportTypes.ONNX, +) -> _C.Value | Sequence[_C.Value | None] | None: + """Runs a symbolic function. + + The function is used in C++ to export the node to ONNX. + + Returns: + A single or a tuple of Values. + None when the node gets cloned as is into the new graph. + """ + + opset_version = GLOBALS.export_onnx_opset_version + + # See Note [Export inplace] + node_kind = node.kind() + if node_kind.endswith("_"): + # Treat relu_ -> relu; add_ -> add etc. + ns_op_name = node_kind[:-1] + else: + ns_op_name = node_kind + + namespace, op_name = jit_utils.parse_node_kind(ns_op_name) + + graph_context = jit_utils.GraphContext( + graph=graph, + block=block, + opset=opset_version, + original_node=node, + params_dict=_params_dict, + env=env, + values_in_env=values_in_env, + new_nodes=new_nodes, + ) + + # Direct ATen export requested + if _should_aten_fallback(ns_op_name, opset_version, operator_export_type): + attrs = { + k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) + for k in node.attributeNames() + } + outputs = node.outputsSize() + attrs["outputs"] = outputs + return graph_context.aten_op( + op_name, + *inputs, + overload_name=_get_aten_op_overload_name(node), + **attrs, + ) + + try: + domain = namespace + symbolic_function_name = f"{domain}::{op_name}" + + symbolic_function_group = registration.registry.get_function_group( + symbolic_function_name + ) + if symbolic_function_group is not None: + symbolic_fn = symbolic_function_group.get(opset_version) + if symbolic_fn is not None: + # TODO Wrap almost identical attrs assignment or comment the difference. 
+ attrs = { + k: symbolic_helper._node_get(node, k) for k in node.attributeNames() + } + return symbolic_fn(graph_context, *inputs, **attrs) + + attrs = { + k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) + for k in node.attributeNames() + } + if namespace == "onnx": + # Clone node to trigger ONNX shape inference + return graph_context.op( + op_name, *inputs, **attrs, outputs=node.outputsSize() + ) # type: ignore[attr-defined] + + raise errors.UnsupportedOperatorError( + symbolic_function_name, + opset_version, + symbolic_function_group.get_min_supported() + if symbolic_function_group + else None, + ) + + except RuntimeError: + if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH: + return None + elif operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: + # Emit ATen op for non-Caffe2 builds when `operator_export_type==ONNX_ATEN_FALLBACK` + attrs = { + k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) + for k in node.attributeNames() + } + return graph_context.aten_op( + op_name, + *inputs, + overload_name=_get_aten_op_overload_name(node), + **attrs, + ) + raise + except TypeError as e: + # Handle the specific case where we didn't successfully dispatch. + # Otherwise, the backtrace will have the clues you need. + e.args = (f"{e.args[0]} \n(Occurred when translating {op_name}).",) + raise + + +def _verify_custom_op_name(symbolic_name: str): + if not re.match(r"^[a-zA-Z0-9-_]+::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name): + raise errors.OnnxExporterError( + f"Failed to register operator {symbolic_name}. " + "The symbolic name must match the format domain::name, " + "and should start with a letter and contain only " + "alphanumerical characters" + ) + + ns, _ = jit_utils.parse_node_kind(symbolic_name) + if ns == "onnx": + raise ValueError( + f"Failed to register operator {symbolic_name}. {ns} domain cannot be modified." + ) + + +def register_custom_op_symbolic( + symbolic_name: str, + symbolic_fn: Callable, + opset_version: int, +): + """Registers a symbolic function for a custom operator. + + When the user registers symbolic for custom/contrib ops, + it is highly recommended to add shape inference for that operator via setType API, + otherwise the exported graph may have incorrect shape inference in some extreme cases. + An example of setType is `test_aten_embedding_2` in `test_operators.py`. + + See "Custom Operators" in the module documentation for an example usage. + + Args: + symbolic_name (str): The name of the custom operator in "::" + format. + symbolic_fn (Callable): A function that takes in the ONNX graph and + the input arguments to the current operator, and returns new + operator nodes to add to the graph. + opset_version (int): The ONNX opset version in which to register. + """ + if symbolic_name.startswith("::"): + symbolic_name = f"aten{symbolic_name}" + + _verify_custom_op_name(symbolic_name) + + registration.custom_onnx_symbolic(symbolic_name, opset_version)(symbolic_fn) + + +def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int): + """Unregisters ``symbolic_name``. + + See "Custom Operators" in the module documentation for an example usage. + + Args: + symbolic_name (str): The name of the custom operator in "::" + format. + opset_version (int): The ONNX opset version in which to unregister. 
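+
+    Example (a minimal sketch; the ``custom_ops::silu`` operator and its symbolic
+    function are hypothetical and only illustrate the call sequence)::
+
+        def silu_symbolic(g, x):
+            # Lower the custom op to plain ONNX nodes.
+            return g.op("Mul", x, g.op("Sigmoid", x))
+
+        torch.onnx.register_custom_op_symbolic("custom_ops::silu", silu_symbolic, 9)
+        # ... export models that use custom_ops::silu ...
+        torch.onnx.unregister_custom_op_symbolic("custom_ops::silu", 9)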
+ """ + if symbolic_name.startswith("::"): + symbolic_name = f"aten{symbolic_name}" + + _verify_custom_op_name(symbolic_name) + + registration.registry.unregister(symbolic_name, opset_version) + + +def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names): + """Ensures dynamic axes argument is follows the expected format.""" + if len(dynamic_axes) == 0: + return + + if hasattr(model, "graph"): + # Extracting set of valid input/output names that shall be used for dynamic_axes + if (input_names is None) or len(input_names) == 0: + input_names = [x.debugName() for x in model.graph.inputs()] + if (output_names is None) or len(output_names) == 0: + output_names = [y.debugName() for y in model.graph.outputs()] + + valid_names = set((input_names or []) + (output_names or [])) + + # If dynamic axes are provided as a list rather than dictionary, they should + # first get converted to a dictionary in expected format. If desired axes names + # are not provided for dynamic axes, automatic names shall be generated for + # provided dynamic axes of specified input/output + for key, value in dynamic_axes.items(): + if key not in valid_names: + warnings.warn( + f"Provided key {key} for dynamic axes is not a valid input/output name" + ) + if isinstance(value, list): + warnings.warn( + "No names were found for specified dynamic axes of provided input." + f"Automatically generated names will be applied to each dynamic axes of input {key}" + ) + + value_dict = {} + for i, x in enumerate(value): + if not isinstance(x, int): + raise ValueError( + "The type of axis index is expected to be an integer" + ) + if x in value_dict: + warnings.warn( + f"Duplicate dynamic axis index {x} was provided for input {key}." + ) + else: + value_dict[x] = str(key) + "_dynamic_axes_" + str(i + 1) + dynamic_axes[key] = value_dict + + +def model_signature(model: torch.nn.Module | Callable) -> inspect.Signature: + return inspect.signature( + model.forward if isinstance(model, torch.nn.Module) else model + ) diff --git a/lib/python3.10/site-packages/torch/onnx/verification.py b/lib/python3.10/site-packages/torch/onnx/verification.py new file mode 100644 index 0000000000000000000000000000000000000000..a21f1ffbba778301570b87fc0eb929583ac0f9b0 --- /dev/null +++ b/lib/python3.10/site-packages/torch/onnx/verification.py @@ -0,0 +1,1806 @@ +# mypy: allow-untyped-defs +"""Functions to verify exported ONNX model is functionally equivalent to original PyTorch model. + +ONNX Runtime is required, and is used as the ONNX backend for export verification. 
+""" + +from __future__ import annotations + +import contextlib +import copy +import dataclasses +import datetime +import difflib +import enum +import functools +import io +import itertools +import os +import tempfile +import warnings +from typing import Any, Callable, Collection, Mapping, Sequence, Tuple, Union + +import numpy as np + +import torch +import torch._C._onnx as _C_onnx +from torch import _C +from torch.onnx import _constants, _experimental, _exporter_states, utils +from torch.onnx._globals import GLOBALS +from torch.onnx._internal import onnx_proto_utils +from torch.types import Number + + +_ORT_PROVIDERS = ("CPUExecutionProvider",) + +_NumericType = Union[Number, torch.Tensor, np.ndarray] +_ModelType = Union[torch.nn.Module, torch.jit.ScriptModule] +_InputArgsType = Union[torch.Tensor, Tuple[Any, ...]] +_InputKwargsType = Mapping[str, Any] +_OutputsType = Union[Sequence[_NumericType], Sequence] + + +class OnnxBackend(enum.Enum): + """Enum class for ONNX backend used for export verification.""" + + REFERENCE = "ONNXReferenceEvaluator" + ONNX_RUNTIME_CPU = "CPUExecutionProvider" + ONNX_RUNTIME_CUDA = "CUDAExecutionProvider" + + +@dataclasses.dataclass +class VerificationOptions: + """Options for ONNX export verification. + + Attributes: + flatten: If True, unpack nested list/tuple/dict inputs into a flattened list of + Tensors for ONNX. Set this to False if nested structures are to be preserved + for ONNX, which is usually the case with exporting ScriptModules. Default True. + ignore_none: Whether to ignore None type in torch output, which is usually the + case with tracing. Set this to False, if torch output should keep None type, + which is usually the case with exporting ScriptModules. Default to True. + check_shape: Whether to check the shapes between PyTorch and ONNX Runtime outputs + are exactly the same. Set this to False to allow output shape broadcasting. + Default to True. + check_dtype: Whether to check the dtypes between PyTorch and ONNX Runtime outputs + are consistent. Default to True. + backend: ONNX backend for verification. Default to OnnxBackend.ONNX_RUNTIME_CPU. + rtol: relative tolerance in comparison between ONNX and PyTorch outputs. + atol: absolute tolerance in comparison between ONNX and PyTorch outputs. + remained_onnx_input_idx: If provided, only the specified inputs will be passed + to the ONNX model. Supply a list when there are unused inputs in the model. + Since unused inputs will be removed in the exported ONNX model, supplying + all inputs will cause an error on unexpected inputs. This parameter tells + the verifier which inputs to pass into the ONNX model. + acceptable_error_percentage: acceptable percentage of element mismatches in comparison. + It should be a float of value between 0.0 and 1.0. 
+ """ + + flatten: bool = True + ignore_none: bool = True + check_shape: bool = True + check_dtype: bool = True + backend: OnnxBackend = OnnxBackend.ONNX_RUNTIME_CPU + rtol: float = 1e-3 + atol: float = 1e-7 + remained_onnx_input_idx: Sequence[int] | None = None + acceptable_error_percentage: float | None = None + + +def _flatten_tuples(elem): + flattened = [] + for t in elem: + if isinstance(t, tuple): + flattened.extend(_flatten_tuples(t)) + else: + flattened.append(t) + return flattened + + +# TODO(justinchuby): Add type checking by narrowing down the return type when input is None +def _to_numpy(elem) -> list | np.ndarray: + if isinstance(elem, torch.Tensor): + if elem.requires_grad: + return elem.detach().cpu().numpy() + else: + return elem.cpu().numpy() + elif isinstance(elem, (list, tuple)): + return [_to_numpy(inp) for inp in elem] + elif isinstance(elem, (bool, int, float)): + return np.array(elem) + elif isinstance(elem, dict): + flattened = [] + for k in elem: + flattened.extend([_to_numpy(k), _to_numpy(elem[k])]) + return flattened + return elem + + +def _inline_flatten_list(inputs, res_list) -> list: + for i in inputs: + res_list.append(i) if not isinstance( + i, (list, tuple) + ) else _inline_flatten_list(i, res_list) + return res_list + + +def _unpack_to_numpy(values, cast_onnx_accepted=True) -> list: + value_unpacked = [] + for value in values: + value_unpacked.extend( + utils.unpack_quantized_tensor(value, cast_onnx_accepted=cast_onnx_accepted) + ) + return [_to_numpy(v) for v in value_unpacked] + + +def _run_onnx(onnx_session, inputs) -> _OutputsType: + kw_inputs = {} + if inputs and isinstance(inputs[-1], dict): + kw_inputs = inputs[-1] + inputs = inputs[:-1] + inputs = _unpack_to_numpy(_flatten_tuples(inputs)) + ort_inputs = {} + for input_name, input in kw_inputs.items(): + ort_inputs[input_name] = _to_numpy(input) + inputs = _to_numpy(inputs) + if hasattr(onnx_session, "get_inputs"): + # onnxruntime.InferenceSession + input_names = [i.name for i in onnx_session.get_inputs()] + elif hasattr(onnx_session, "input_names"): + # onnx.reference.ReferenceEvaluator + input_names = onnx_session.input_names + else: + raise ValueError(f"Unknown ONNX backend type: {type(onnx_session)}.") + + for i, input in enumerate(inputs): + if i == len(input_names) or input_names[i] in ort_inputs: + raise ValueError( + f"got too many positional inputs. inputs: {inputs}. kw_inputs: {kw_inputs}. " + f"input names: {input_names}." + ) + ort_inputs[input_names[i]] = input + onnx_outs = onnx_session.run(None, ort_inputs) + return onnx_outs + + +def _ort_session( + model: str | io.BytesIO, ort_providers: Sequence[str] = _ORT_PROVIDERS +): + try: + import onnxruntime # type: ignore[import] + except ImportError as e: + raise ImportError("onnxruntime is required for export verification.") from e + + if ort_providers is None: + ort_providers = _ORT_PROVIDERS + + session_options = onnxruntime.SessionOptions() + # suppress ort warnings. + # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2. 
+ session_options.log_severity_level = 3 + ort_session = onnxruntime.InferenceSession( + model if isinstance(model, str) else model.getvalue(), + session_options, + providers=ort_providers, + ) + return ort_session + + +def _onnx_reference_evaluator_session(model: str | io.BytesIO): + try: + import onnx + from onnx import reference as onnx_reference # type: ignore[attr-defined] + except ImportError as exc: + raise ImportError("onnx >= 1.13 is required for reference evaluator.") from exc + + proto = ( + onnx.load(model) # type: ignore[attr-defined] + if isinstance(model, str) + else onnx.load_model_from_string(model.getvalue()) # type: ignore[attr-defined] + ) + onnx_session = onnx_reference.ReferenceEvaluator(proto) + return onnx_session + + +def _onnx_backend_session(model: str | io.BytesIO, backend: OnnxBackend): + if backend == OnnxBackend.REFERENCE: + onnx_session = _onnx_reference_evaluator_session(model) + elif backend in {OnnxBackend.ONNX_RUNTIME_CPU, OnnxBackend.ONNX_RUNTIME_CUDA}: + onnx_session = _ort_session(model, (backend.value,)) + else: + raise ValueError(f"Unsupported backend: {backend}") + return onnx_session + + +def _compare_onnx_pytorch_outputs_in_np( + onnx_outs: _OutputsType, + pt_outs: _OutputsType, + options: VerificationOptions, +): + assert ( + len(onnx_outs) == len(pt_outs) + ), f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})" + acceptable_error_percentage = options.acceptable_error_percentage + if acceptable_error_percentage and ( + acceptable_error_percentage > 1.0 or acceptable_error_percentage < 0.0 + ): + raise ValueError( + "If set, acceptable_error_percentage should be between 0.0 and 1.0" + ) + + for ort_out, pt_out in zip(onnx_outs, pt_outs): + try: + # TODO: Remove `check_shape` option once every shape inconsistent issue is addressed. + if not options.check_shape: + # Allow different but broadcastable output shapes. + ort_out, pt_out = np.broadcast_arrays(ort_out, pt_out) + torch.testing.assert_close( + ort_out, + pt_out, + rtol=options.rtol, + atol=options.atol, + check_dtype=options.check_dtype, + equal_nan=True, + ) + except AssertionError as e: + if acceptable_error_percentage: + error_percentage = 1 - np.sum( + np.isclose(ort_out, pt_out, rtol=options.rtol, atol=options.atol) + ) / np.prod(ort_out.shape) + if error_percentage <= acceptable_error_percentage: + warnings.warn( + f"Suppressed AssertionError:\n{e}.\n" + f"Error percentage {error_percentage} " + f"within acceptable range {acceptable_error_percentage}." + ) + continue + if ort_out.dtype == np.uint8 or ort_out.dtype == np.int8: + warnings.warn("ONNX output is quantized") + if pt_out.dtype == np.uint8 or pt_out.dtype == np.int8: + warnings.warn("PyTorch output is quantized") + raise + + +def _compare_onnx_pytorch_outputs( + onnx_outs: _OutputsType, + pt_outs: Any, + options: VerificationOptions, +): + """ + Compare ONNX and PyTorch outputs. + + Args: + onnx_outs: outputs from ONNX backend. + pt_outs: outputs from PyTorch. + options: options for verification. + + Raises: + AssertionError: if outputs from ONNX model and PyTorch model are not + equal up to specified precision. + ValueError: if arguments provided are invalid. 
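+
+    Example (a minimal sketch with hand-made outputs; real callers pass the
+    results of an ONNX backend run and a PyTorch forward pass)::
+
+        onnx_outs = [np.array([1.0, 2.0], dtype=np.float32)]
+        pt_outs = (torch.tensor([1.0, 2.0]),)
+        _compare_onnx_pytorch_outputs(onnx_outs, pt_outs, VerificationOptions())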
+ """ + if options.ignore_none: + # torch.jit._flatten filters None type + pt_outs, _ = torch.jit._flatten(pt_outs) + else: + pt_outs = _inline_flatten_list([pt_outs], []) + pt_outs_np = _unpack_to_numpy(pt_outs, cast_onnx_accepted=False) + onnx_outs = _inline_flatten_list(onnx_outs, []) + _compare_onnx_pytorch_outputs_in_np(onnx_outs, pt_outs_np, options) + + +def _prepare_input_for_pytorch(args, kwargs): + """Prepare input for PyTorch model execution. + + Any future changes/formatting to the input before dispatching to the PyTorch + model should be made in this function. + + Args: + args: positional arguments for PyTorch model forward method. + kwargs: keyword arguments for PyTorch model forward method. + + Returns: + args: positional arguments for PyTorch model forward method. + kwargs: keyword arguments for PyTorch model forward method. + """ + if isinstance(args, (torch.Tensor, dict)): + args = (args,) + # In-place operators will update input tensor data as well. + # Thus inputs are replicated before every forward call. + args = copy.deepcopy(args) + if kwargs: + kwargs = copy.deepcopy(kwargs) + else: + kwargs = {} + return args, kwargs + + +def _prepare_input_for_export(args, kwargs): + """Prepare input for ONNX model export. + + Any future changes/formatting to the input before dispatching to the + :func:`torch.onnx.export` api should be made in this function. + + Args: + args: positional arguments for PyTorch model forward method. + kwargs: keyword arguments for PyTorch model forward method. + + Returns: + onnx_inputs: positional arguments for ONNX model export, as `args` in + :func:`torch.onnx.export`. + """ + args, kwargs = _prepare_input_for_pytorch(args, kwargs) + if not kwargs and len(args) > 0 and isinstance(args[-1], dict): + onnx_inputs = args + ({},) + elif kwargs: + onnx_inputs = args + (kwargs,) + else: + onnx_inputs = args + return onnx_inputs + + +def _prepare_input_for_onnx( + args, kwargs, remained_onnx_input_idx: Sequence[int] | None, flatten: bool +): + """Prepare input for ONNX model execution in ONNX backend. + + Any future changes/formatting to the input before dispatching to the ONNX backend + run should be made in this function. + + Args: + args: positional arguments for PyTorch model forward method. + kwargs: keyword arguments for PyTorch model forward method. + remained_onnx_input_idx: indices of inputs to be used for ONNX model execution. + flatten: whether to flatten the input before dispatching to the ONNX model execution. + + Returns: + onnx_inputs: positional arguments for ONNX model execution in ONNX backend. + """ + onnx_inputs = _prepare_input_for_export(args, kwargs) + if flatten: + onnx_inputs, _ = torch.jit._flatten(onnx_inputs) + elif onnx_inputs and onnx_inputs[-1] == {}: + # Handle empty kwargs (normally removed by flatten). + onnx_inputs = onnx_inputs[:-1] + if remained_onnx_input_idx is not None: + return [onnx_inputs[i] for i in remained_onnx_input_idx] + else: + return onnx_inputs + + +def _try_clone_model(model): + """Used for preserving original model in case forward mutates model states.""" + try: + return copy.deepcopy(model) + except Exception: + warnings.warn( + "Failed to clone model. Model state might be mutated during verification." 
+ ) + return model + + +def _compare_onnx_pytorch_model( + pt_model: _ModelType, + onnx_model_f: str | io.BytesIO, + input_args: _InputArgsType, + input_kwargs: _InputKwargsType | None, + additional_test_inputs: Sequence[_InputArgsType] | None, + options: VerificationOptions, +): + """Compare outputs from ONNX model runs with outputs from PyTorch model runs. + + Args: + pt_model: PyTorch model. + onnx_model_f: ONNX model file path or file-like object. + input_args: positional arguments for PyTorch model forward method. + input_kwargs: keyword arguments for PyTorch model forward method. + additional_test_inputs: additional positional arguments for PyTorch model + forward method. + options: options for verification. + + Raises: + AssertionError: if outputs from ONNX model and PyTorch model are not + equal up to specified precision. + """ + onnx_session = _onnx_backend_session(onnx_model_f, options.backend) + + def compare_onnx_pytorch_model_with_input(input_args, input_kwargs): + pt_args, pt_kwargs = _prepare_input_for_pytorch(input_args, input_kwargs) + # TODO: remove this and treat mutating model separately. See #77679 + pt_model_copy = _try_clone_model(pt_model) + pt_outs = pt_model_copy(*pt_args, **pt_kwargs) + + onnx_inputs = _prepare_input_for_onnx( + input_args, input_kwargs, options.remained_onnx_input_idx, options.flatten + ) + + onnx_outs = _run_onnx(onnx_session, onnx_inputs) + + _compare_onnx_pytorch_outputs( + onnx_outs=onnx_outs, + pt_outs=pt_outs, + options=options, + ) + + compare_onnx_pytorch_model_with_input(input_args, input_kwargs) + + if additional_test_inputs: + for test_input_args in additional_test_inputs: + compare_onnx_pytorch_model_with_input(test_input_args, {}) + + +class _GraphDiff: + """A class to represent the difference between two graphs.""" + + def __init__(self, graph_a: _C.Graph, graph_b: _C.Graph): + """Construct a _GraphDiff object. + + Args: + graph_a (_C.Graph): First graph to compare. + graph_b (_C.Graph): Second graph to compare. + """ + self.graph_a = graph_a + self.graph_b = graph_b + + def __str__(self): + """See function :func:`diff_report`.""" + return self.diff_report() + + def _indent(self, lines: str) -> str: + return "\n".join(["\t" + line for line in lines.splitlines()]) + + def diff_report(self) -> str: + """Return a string representation of the graph difference. + + The report shows the first pair of nodes that diverges. It also shows the source + location of the pair of nodes. + + Returns: + graph_diff_report (str): A string representation of the graph difference. 
+ """ + graph_a = self.graph_a + graph_b = self.graph_b + + graph_a_str = str(graph_a) + graph_b_str = str(graph_b) + + if graph_a_str == graph_b_str: + return "" + + graph_diff = difflib.ndiff( + graph_a_str.splitlines(True), graph_b_str.splitlines(True) + ) + graph_diff_report = ["Graph diff:", self._indent("".join(graph_diff))] + + for node_a, node_b in itertools.zip_longest(graph_a.nodes(), graph_b.nodes()): + if str(node_a) != str(node_b): + graph_diff_report.append("First diverging operator:") + node_diff = difflib.ndiff( + str(node_a).splitlines(True), str(node_b).splitlines(True) + ) + source_printout = ["node diff:", self._indent("".join(node_diff))] + + stack_a = node_a.sourceRange() if node_a else None + if stack_a: + source_printout.extend( + ["Former source location:", self._indent(str(stack_a))] + ) + stack_b = node_b.sourceRange() if node_b else None + if stack_b: + source_printout.extend( + ["Latter source location:", self._indent(str(stack_b))] + ) + + graph_diff_report.extend(source_printout) + + break + + return "\n".join(graph_diff_report) + + +def _check_graph_diff( + model: torch.nn.Module | torch.jit.ScriptModule, + test_input_groups: Sequence[tuple[tuple[Any, ...], Mapping[str, Any]]], + export_options: _experimental.ExportOptions, + model_to_graph_func: Callable[ + [ + torch.nn.Module, + tuple[Any, ...], + Mapping[str, Any], + _experimental.ExportOptions, + ], + _C.Graph, + ], +) -> str: + """Check if graph produced by `model_to_graph_func` is the same across `test_input_groups`. + + Args: + model: See :func:`check_export_model_diff`. + test_input_groups: See :func:`check_export_model_diff`. + export_options: See :func:`check_export_model_diff`. + model_to_graph_func: A function to convert a PyTorch model to a JIT IR graph. + + Returns: + graph_diff_report (str): A string representation of the graph difference. + """ + if len(test_input_groups) < 2: + raise ValueError("Need at least two groups of test inputs to compare.") + + ref_jit_graph = None + for args, kwargs in test_input_groups: + jit_graph = model_to_graph_func(model, args, kwargs, export_options) + if ref_jit_graph is None: + ref_jit_graph = jit_graph + continue + + graph_diff_report = _GraphDiff(ref_jit_graph, jit_graph).diff_report() + if graph_diff_report: + return graph_diff_report + return "" + + +def _traced_graph_from_model( + model: torch.nn.Module | torch.jit.ScriptModule, + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + export_options: _experimental.ExportOptions, +) -> _C.Graph: + """As part of the ONNX export steps, create a traced JIT graph from a PyTorch model. + + Args: + model: See :func:`check_export_model_diff`. + args: See :func:`check_export_model_diff`. + kwargs: See :func:`check_export_model_diff`. + export_options: See :func:`check_export_model_diff`. + + Returns: + jit_graph (_C.Graph): A traced JIT graph. + """ + training = export_options.training + verbose = export_options.verbose + + with utils.exporter_context(model, training, verbose): + export_inputs = _prepare_input_for_export(args, kwargs) + model = utils._pre_trace_quant_model(model, export_inputs) + jit_graph, _, _, _ = utils._create_jit_graph(model, export_inputs) + return jit_graph + + +def _onnx_graph_from_model( + model: torch.nn.Module | torch.jit.ScriptModule, + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + export_options: _experimental.ExportOptions, +) -> _C.Graph: + """As part of the ONNX export steps, export an ONNX JIT graph from a PyTorch model. 
+ + Args: + model: See :func:`check_export_model_diff`. + args: See :func:`check_export_model_diff`. + kwargs: See :func:`check_export_model_diff`. + export_options: See :func:`check_export_model_diff`. + + Returns: + onnx_graph (_C.Graph): An ONNX JIT graph. + """ + # TODO: refactor utils.py to remove duplicated code of context setup. See #78834 + opset_version = export_options.opset_version + operator_export_type = export_options.operator_export_type + export_modules_as_functions = export_options.export_modules_as_functions + training = export_options.training + verbose = export_options.verbose + dynamic_axes = export_options.dynamic_axes + input_names = export_options.input_names + output_names = export_options.output_names + + if opset_version is None: + opset_version = _constants.ONNX_DEFAULT_OPSET + + utils._setup_trace_module_map(model, export_modules_as_functions) + + if not operator_export_type: + operator_export_type = _C_onnx.OperatorExportTypes.ONNX + + GLOBALS.export_onnx_opset_version = opset_version + GLOBALS.operator_export_type = operator_export_type + + with utils.exporter_context(model, training, verbose): + do_constant_folding = utils._decide_constant_folding( + export_options.do_constant_folding, operator_export_type, training + ) + + if dynamic_axes is None: + dynamic_axes = {} + utils._validate_dynamic_axes(dynamic_axes, model, input_names, output_names) + + export_inputs = _prepare_input_for_export(args, kwargs) + export_inputs = utils._decide_input_format(model, export_inputs) + onnx_graph, _, _ = utils._model_to_graph( + model, + export_inputs, + verbose, + input_names, + output_names, + operator_export_type, + do_constant_folding, + training=training, + dynamic_axes=dynamic_axes, + ) + + return onnx_graph + + +def _onnx_graph_from_aten_graph( + graph: torch.Graph, + export_options: _experimental.ExportOptions, + params_dict: dict[str, Any] | None = None, +) -> tuple[torch.Graph, dict[str, Any]]: + if params_dict is None: + params_dict = {} + operator_export_type = export_options.operator_export_type + dynamic_axes = export_options.dynamic_axes or {} + input_names = export_options.input_names + training = export_options.training + do_constant_folding = export_options.do_constant_folding + opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET + + GLOBALS.export_onnx_opset_version = opset_version + GLOBALS.operator_export_type = operator_export_type + + do_constant_folding = utils._decide_constant_folding( + do_constant_folding, operator_export_type, training + ) + + # TODO: Below is doing aten graph to onnx. It should be abstracted as a + # function in torch/onnx/utils.py. + graph = graph.copy() + graph = utils._optimize_graph( + graph, + operator_export_type, + params_dict=params_dict, + dynamic_axes=dynamic_axes, + input_names=input_names, + ) + + if training is None or training == _C_onnx.TrainingMode.EVAL: + params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict) + + if ( + do_constant_folding + and opset_version >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET + ): + params_dict = _C._jit_pass_onnx_constant_fold(graph, params_dict, opset_version) + _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + + if GLOBALS.onnx_shape_inference: + _C._jit_pass_onnx_graph_shape_type_inference(graph, params_dict, opset_version) + + params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict) + + # For ONNX opset < 9, constants only have three data types: float16, float, double. 
+ # In this pass transform constants of other data types to float/double + cast operator.
+ if opset_version < 9:
+ _C._jit_pass_onnx_cast_all_constant_to_floating(graph)
+
+ params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict)
+ _C._jit_decay_packed_param_input_types(graph)
+
+ _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
+
+ if export_options.verbose:
+ print("ONNX graph: ", graph)
+
+ return graph, params_dict
+
+
+def _onnx_proto_from_onnx_graph(
+ onnx_graph: torch.Graph,
+ export_options: _experimental.ExportOptions,
+ params_dict: dict[str, Any],
+) -> tuple[bytes, Mapping[str, bytes]]:
+ opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET
+ dynamic_axes = export_options.dynamic_axes or {}
+ operator_export_type = export_options.operator_export_type
+ val_keep_init_as_ip = utils._decide_keep_init_as_input(
+ export_options.keep_initializers_as_inputs,
+ operator_export_type,
+ opset_version,
+ )
+ val_add_node_names = utils._decide_add_node_names(True, operator_export_type)
+ custom_opsets = export_options.custom_opsets or {}
+
+ proto, export_map, _, _ = onnx_graph._export_onnx( # type: ignore[attr-defined]
+ params_dict,
+ opset_version,
+ dynamic_axes,
+ False,
+ operator_export_type,
+ not export_options.verbose,
+ val_keep_init_as_ip,
+ custom_opsets,
+ val_add_node_names,
+ "",
+ {},
+ )
+
+ return proto, export_map
+
+
+def check_export_model_diff(
+ model: torch.nn.Module | torch.jit.ScriptModule,
+ test_input_groups: Sequence[tuple[tuple[Any, ...], Mapping[str, Any]]],
+ export_options: _experimental.ExportOptions | None = None,
+) -> str:
+ """Check the exported model for discrepancies across different groups of inputs.
+
+ A graph is exported for each group of inputs. The exported graphs are then compared
+ to each other, and discrepancies at the first pair of diverging nodes are reported. This
+ function first checks the JIT graph. If no discrepancies are found, it then checks the
+ ONNX graph.
+
+ Unless otherwise specified, the JIT/ONNX graph is expected to be the same, regardless
+ of the inputs used for exporting. A discrepancy implies that the exported graph is
+ not accurate for other groups of inputs, which typically results in
+ runtime errors or mismatched outputs.
+
+ Args:
+ model (torch.nn.Module or torch.jit.ScriptModule): The model to be exported.
+ test_input_groups (Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]]): A sequence
+ of input groups to be used to export the model. Each input group is a pair of
+ (args, kwargs).
+ export_options (_experimental.ExportOptions, optional): An _experimental.ExportOptions
+ object that controls the export behavior.
+
+ Returns:
+ str: A string containing the diff of the exported models.
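+
+ Example (an illustrative sketch, not part of the original docstring; the toy
+ ``Model`` below is hypothetical)::
+
+ >>> import torch
+ >>> import torch.onnx.verification
+ >>> class Model(torch.nn.Module):
+ ... def forward(self, x):
+ ... return torch.nn.functional.relu(x)
+ >>> test_input_groups = [
+ ... ((torch.randn(2, 3),), {}),
+ ... ((torch.randn(4, 5),), {}),
+ ... ]
+ >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
+ >>> diff_report = torch.onnx.verification.check_export_model_diff(
+ ... Model(), test_input_groups
+ ... )
+ >>> # An empty report means no discrepancy was found between the input groups.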
+ """ + export_options = ( + _experimental.ExportOptions() if export_options is None else export_options + ) + + jit_diff_report = _check_graph_diff( + model, test_input_groups, export_options, _traced_graph_from_model + ) + if jit_diff_report: + return jit_diff_report + + return _check_graph_diff( + model, test_input_groups, export_options, _onnx_graph_from_model + ) + + +def verify( + model: _ModelType, + input_args: _InputArgsType, + input_kwargs: _InputKwargsType | None = None, + do_constant_folding: bool = True, + dynamic_axes: Mapping[str, Mapping[int, str] | Mapping[str, Sequence[int]]] + | None = None, + input_names: Sequence[str] | None = None, + output_names: Sequence[str] | None = None, + training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, + opset_version: int | None = None, + keep_initializers_as_inputs: bool = True, + verbose: bool = False, + fixed_batch_size: bool = False, + use_external_data: bool = False, + additional_test_inputs: Sequence[_InputArgsType] | None = None, + options: VerificationOptions | None = None, +): + """Verify model export to ONNX against original PyTorch model. + + Args: + model (torch.nn.Module or torch.jit.ScriptModule): See :func:`torch.onnx.export`. + input_args (tuple): See :func:`torch.onnx.export`. + input_kwargs (dict): See :func:`torch.onnx.export`. + do_constant_folding (bool, optional): See :func:`torch.onnx.export`. + dynamic_axes (dict, optional): See :func:`torch.onnx.export`. + input_names (list, optional): See :func:`torch.onnx.export`. + output_names (list, optional): See :func:`torch.onnx.export`. + training (torch.onnx.TrainingMode): See :func:`torch.onnx.export`. + opset_version (int, optional): See :func:`torch.onnx.export`. + keep_initializers_as_inputs (bool, optional): See :func:`torch.onnx.export`. + verbose (bool, optional): See :func:`torch.onnx.export`. + fixed_batch_size (bool, optional): Legacy argument, used only by rnn test cases. + use_external_data (bool, optional): Explicitly specify whether to export the + model with external data. + additional_test_inputs (list, optional): List of tuples. Each tuple is a group of + input arguments to test. Currently only *args are supported. + options (_VerificationOptions, optional): A _VerificationOptions object that + controls the verification behavior. + + Raises: + AssertionError: if outputs from ONNX model and PyTorch model are not + equal up to specified precision. + ValueError: if arguments provided are invalid. + """ + if options is None: + options = VerificationOptions() + + if training == torch.onnx.TrainingMode.TRAINING: + model.train() + elif training == torch.onnx.TrainingMode.EVAL: + model.eval() + with torch.no_grad(), contextlib.ExitStack() as stack: + model_f: str | io.BytesIO = io.BytesIO() + if use_external_data: + tmpdir_path = stack.enter_context(tempfile.TemporaryDirectory()) + model_f = os.path.join(tmpdir_path, "model.onnx") + + inputs_for_export = _prepare_input_for_export(input_args, input_kwargs) + + # TODO(#77679): remove this and treat mutating model separately. 
+ model_copy = _try_clone_model(model) + utils._export( + model, + inputs_for_export, + model_f, + opset_version=opset_version, + do_constant_folding=do_constant_folding, + keep_initializers_as_inputs=keep_initializers_as_inputs, + dynamic_axes=dynamic_axes, + input_names=input_names, + output_names=output_names, + fixed_batch_size=fixed_batch_size, + training=training, + verbose=verbose, + ) + + _compare_onnx_pytorch_model( + pt_model=model_copy, + onnx_model_f=model_f, + input_args=input_args, + input_kwargs=input_kwargs, + additional_test_inputs=additional_test_inputs, + options=options, + ) + + +def verify_aten_graph( + graph: torch.Graph, + input_args: tuple[Any, ...], + export_options: _experimental.ExportOptions, + params_dict: dict[str, Any] | None = None, + verification_options: VerificationOptions | None = None, +) -> tuple[AssertionError | None, torch.Graph, _OutputsType, _OutputsType]: + if verification_options is None: + verification_options = VerificationOptions() + if params_dict is None: + params_dict = {} + + original_jit_graph = graph + graph = graph.copy() + + # Execute aten graph and get reference torch jit outputs. + graph_inputs = list(graph.inputs()) + jit_inputs = tuple([arg for arg in input_args if arg is not None]) + weights = [params_dict[v.debugName()] for v in graph_inputs[len(jit_inputs) :]] + assert all(w is not None for w in weights) + # TODO: Only copy the argument if mutation is detected in Graph. + jit_inputs = copy.deepcopy(jit_inputs) + jit_input_and_parameters = jit_inputs + tuple(weights) + jit_outs = torch._C._jit_interpret_graph(graph, jit_input_and_parameters) # type: ignore[attr-defined] + if not isinstance(jit_outs, (list, tuple)): + jit_outs = [jit_outs] + + # Convert aten graph to onnx graph. + graph, onnx_params_dict = _onnx_graph_from_aten_graph( + graph, export_options, params_dict + ) + + proto, export_map = _onnx_proto_from_onnx_graph( + graph, export_options, onnx_params_dict + ) + model_f: str | io.BytesIO = io.BytesIO() + export_type = _exporter_states.ExportTypes.PROTOBUF_FILE + onnx_proto_utils._export_file(proto, model_f, export_type, export_map) + + # NOTE: Verification is unstable. Try catch to emit information for debugging. + try: + # NOTE: Input might be dce'ed, so we need to remove those from the input args. 
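+ # Keep only the arguments whose corresponding graph inputs survived dead-code
+ # elimination in the exported graph; the rest are dropped below.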
+ new_input_names = {v.debugName() for v in graph.inputs()} + new_input_args = [] + for v, arg in zip(original_jit_graph.inputs(), input_args): + if v.debugName() in new_input_names: + new_input_args.append(arg) + input_args = tuple(new_input_args) + + onnx_inputs = _prepare_input_for_onnx( + input_args, + {}, + verification_options.remained_onnx_input_idx, + verification_options.flatten, + ) + + onnx_session = _onnx_backend_session(model_f, verification_options.backend) + onnx_outs = _run_onnx(onnx_session, onnx_inputs) + del onnx_session # To free device memory + + try: + _compare_onnx_pytorch_outputs( + onnx_outs=onnx_outs, + pt_outs=jit_outs, + options=verification_options, + ) + except AssertionError as e: + return e, graph, jit_outs, onnx_outs + + return None, graph, jit_outs, onnx_outs + + except Exception as e: + print("Unexpected error during verification.") + print("jit graph: ", original_jit_graph) + print("onnx graph: ", graph) + raise e + + +class GraphInfoPrettyPrinter: + graph_info: GraphInfo | None + upper_printer: GraphInfoPrettyPrinter | None + lower_printer: GraphInfoPrettyPrinter | None + + graph_str_lambdas: Mapping[int, str] + connector_str_lambdas: Mapping[int, str] + children_str_lambdas: Mapping[int, str] + + def __init__(self, graph_info: GraphInfo | None): + self.graph_info = graph_info + if ( + graph_info is not None + and graph_info.upper_graph_info is not None + and graph_info.lower_graph_info is not None + ): + self.upper_printer = GraphInfoPrettyPrinter(graph_info.upper_graph_info) + self.lower_printer = GraphInfoPrettyPrinter(graph_info.lower_graph_info) + else: + self.upper_printer = None + self.lower_printer = None + + def _total_rows(self) -> int: + if self.graph_info is None: + return 1 + if self.upper_printer and self.lower_printer: + return ( + self.upper_printer._total_rows() + self.lower_printer._total_rows() + 1 + ) + return 2 # Two lines: node count + id. + + def _node_count_segment_str(self) -> str: + if self.graph_info is None: + return "..." 
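+ # Segment format: "<essential node count> <mismatch marker> (<node kind>)", where the
+ # marker is "X" on mismatch (a check mark otherwise) and the kind is shown only for
+ # single-node subgraphs that still mismatch.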
+ node_count = self.graph_info.essential_node_count() + has_mismatch = self.graph_info.has_mismatch() + error_node_kind = ( + f"({self.graph_info.essential_node_kinds().pop()})" + if node_count == 1 and has_mismatch + else "" + ) + + return f"{node_count} {'X' if has_mismatch else chr(0x2713)} {error_node_kind}" + + def _graph_id_segment_str(self) -> str: + if self.graph_info is None: + return "" + return f"id: {self.graph_info.id}" + + def _max_segment_columns(self) -> int: + return max( + map(len, (self._node_count_segment_str(), self._graph_id_segment_str())) + ) + + def _graph_segment_str_at_line(self, line: int) -> str: + """Get the string representation of the graph segment at the given line.""" + if line == 0: + result_str = self._node_count_segment_str() + result_str += " " * (self._max_segment_columns() - len(result_str)) + return result_str + if line == 1: + result_str = self._graph_id_segment_str() + result_str += " " * (self._max_segment_columns() - len(result_str)) + return result_str + if 0 <= line < self._total_rows(): + return " " * self._max_segment_columns() + return "" + + def _connector_segment_str_at_line(self, line: int) -> str: + """Get the connector segment string at the given line.""" + if self.upper_printer is None and self.lower_printer is None: + return "" + upper_total_rows = self.upper_printer._total_rows() if self.upper_printer else 1 + lower_total_rows = self.lower_printer._total_rows() if self.lower_printer else 1 + if line == 0: + return " __" + elif line < upper_total_rows + 1: + return " | " + elif line == upper_total_rows + 1: + return " |__" + elif line < upper_total_rows + lower_total_rows + 1: + return " " + return "" + + def _children_str_at_line(self, line: int) -> str: + """Get the string representation of the children at the given line. + + Recursively calls `_str_at_line` on children nodes. + """ + if self.upper_printer is None and self.lower_printer is None: + return "" + upper_total_rows = self.upper_printer._total_rows() if self.upper_printer else 1 + lower_total_rows = self.lower_printer._total_rows() if self.lower_printer else 1 + if 0 <= line < upper_total_rows: + return ( + self.upper_printer._str_at_line(line) if self.upper_printer else "..." + ) + elif upper_total_rows < line < upper_total_rows + lower_total_rows + 1: + return ( + self.lower_printer._str_at_line(line - upper_total_rows - 1) + if self.lower_printer + else "..." + ) + return "" + + def _str_at_line(self, line: int) -> str: + """Get the string representation of the graph at the given line.""" + return ( + self._graph_segment_str_at_line(line) + + self._connector_segment_str_at_line(line) + + self._children_str_at_line(line) + ) + + def pretty_print(self): + if self.graph_info is None: + print(None) + return + # Print tree. + print(" Tree: ".center(80, "=")) + total_rows = self._total_rows() + for line in range(total_rows): + print(self._str_at_line(line).rstrip()) + if self.graph_info.has_mismatch(): + # Summarize leaf subgraphs with mismatch. + print(" Mismatch leaf subgraphs: ".center(80, "=")) + print( + [ + graph_info.id + for graph_info in self.graph_info.all_mismatch_leaf_graph_info() + ] + ) + # Summarize node kinds with mismatch. 
+ mismatch_node_kinds: dict[str, int] = {} + for graph_info in self.graph_info.all_mismatch_leaf_graph_info(): + node_kinds = graph_info.essential_node_kinds() + if len(node_kinds) == 1: + node_kind = node_kinds.pop() + mismatch_node_kinds[node_kind] = ( + mismatch_node_kinds.get(node_kind, 0) + 1 + ) + print(" Mismatch node kinds: ".center(80, "=")) + print(mismatch_node_kinds) + else: + print(" No mismatch found. ".center(80, "=")) + + +class OnnxTestCaseRepro: + def __init__(self, repro_dir): + self.repro_dir = repro_dir + self.proto, self.inputs, self.outputs = onnx_proto_utils.load_test_case( + repro_dir + ) + + @classmethod + def create_test_case_repro( + cls, proto: bytes, inputs, outputs, dir: str, name: str | None = None + ): + """Create a repro under "{dir}/test_{name}" for an ONNX test case. + + The test case contains the model and the inputs/outputs data. The directory + structure is as follows: + + dir + \u251c\u2500\u2500 test_ + \u2502 \u251c\u2500\u2500 model.onnx + \u2502 \u2514\u2500\u2500 test_data_set_0 + \u2502 \u251c\u2500\u2500 input_0.pb + \u2502 \u251c\u2500\u2500 input_1.pb + \u2502 \u251c\u2500\u2500 output_0.pb + \u2502 \u2514\u2500\u2500 output_1.pb + + Args: + proto: ONNX model proto. + inputs: Inputs to the model. + outputs: Outputs of the model. + dir: Directory to save the repro. + name: Name of the test case. If not specified, a name based on current time + will be generated. + Returns: + Path to the repro. + """ + if name is None: + name = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") + return onnx_proto_utils.export_as_test_case( + proto, + _to_numpy(inputs), + _to_numpy(outputs), + name, + dir, + ) + + def validate(self, options: VerificationOptions): + """Run the ONNX test case with options.backend, and compare with the expected outputs. + + Args: + options: Options for validation. + + Raise: + AssertionError: if outputs from options.backend and expected outputs are not + equal up to specified precision. + """ + onnx_session = _onnx_backend_session(io.BytesIO(self.proto), options.backend) + run_outputs = onnx_session.run(None, self.inputs) + if hasattr(onnx_session, "get_outputs"): + output_names = [o.name for o in onnx_session.get_outputs()] + elif hasattr(onnx_session, "output_names"): + output_names = onnx_session.output_names + else: + raise ValueError(f"Unknown onnx session type: {type(onnx_session)}") + expected_outs = [self.outputs[name] for name in output_names] + _compare_onnx_pytorch_outputs_in_np(run_outputs, expected_outs, options) + + +@dataclasses.dataclass +class GraphInfo: + """GraphInfo contains validation information of a TorchScript graph and its converted ONNX graph.""" + + graph: torch.Graph + input_args: tuple[Any, ...] 
+ params_dict: dict[str, Any]
+ export_options: _experimental.ExportOptions = dataclasses.field(
+ default_factory=_experimental.ExportOptions
+ )
+ mismatch_error: AssertionError | None = dataclasses.field(default=None, init=False)
+ pt_outs: Sequence[_NumericType] | None = dataclasses.field(default=None, init=False)
+ upper_graph_info: GraphInfo | None = dataclasses.field(default=None, init=False)
+ lower_graph_info: GraphInfo | None = dataclasses.field(default=None, init=False)
+ id: str = dataclasses.field(default="")
+ _onnx_graph: torch.Graph | None = dataclasses.field(init=False, default=None)
+
+ _EXCLUDED_NODE_KINDS: frozenset[str] = frozenset(
+ {"prim::Constant", "prim::ListConstruct", "aten::ScalarImplicit"}
+ )
+
+ def clear(self):
+ """Clear states and results of previous verification."""
+ self.mismatch_error = None
+ self.pt_outs = None
+ self._onnx_graph = None
+ self.upper_graph_info = None
+ self.lower_graph_info = None
+
+ def pretty_print_tree(self):
+ """Pretty print `GraphInfo` tree.
+
+ Each node represents a subgraph, showing the number of nodes in the subgraph and
+ a check mark if the subgraph has output mismatch between torch and ONNX.
+
+ The id of the subgraph is shown under the node. The `GraphInfo` object for any
+ subgraph can be retrieved by calling `graph_info.find_partition(id)`.
+
+ Example::
+
+ ==================================== Tree: =====================================
+ 5 X __2 X __1 ✓
+ id: | id: 0 | id: 00
+ | |
+ | |__1 X (aten::relu)
+ | id: 01
+ |
+ |__3 X __1 ✓
+ id: 1 | id: 10
+ |
+ |__2 X __1 X (aten::relu)
+ id: 11 | id: 110
+ |
+ |__1 ✓
+ id: 111
+ =========================== Mismatch leaf subgraphs: ===========================
+ ['01', '110']
+ ============================= Mismatch node kinds: =============================
+ {'aten::relu': 2}
+
+ """
+ GraphInfoPrettyPrinter(self).pretty_print()
+
+ def pretty_print_mismatch(self, graph: bool = False):
+ """Pretty print details of the mismatch between torch and ONNX.
+
+ Args:
+ graph: If True, print the ATen JIT graph and ONNX graph.
+ """
+ print(f" Mismatch info for graph partition {self.id}: ".center(80, "="))
+ if graph:
+ print(" ATen JIT graph ".center(80, "="))
+ # TODO: A more compact graph printer.
+ # * Drop stride, grad, device information.
+ # * Show source location on a separate line.
+ print(self.graph) + if self._onnx_graph is not None: + print(" ONNX graph ".center(80, "=")) + print(self._onnx_graph) + if self.has_mismatch(): + print(" Mismatch error ".center(80, "=")) + print(self.mismatch_error) + else: + print(" No mismatch ".center(80, "=")) + + def has_mismatch(self) -> bool: + """Return True if the subgraph has output mismatch between torch and ONNX.""" + return self.mismatch_error is not None + + def essential_node_count(self) -> int: + """Return the number of nodes in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`.""" + return sum( + 1 for n in self.graph.nodes() if n.kind() not in self._EXCLUDED_NODE_KINDS + ) + + def essential_node_kinds(self) -> set[str]: + """Return the set of node kinds in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`.""" + return { + n.kind() + for n in self.graph.nodes() + if n.kind() not in self._EXCLUDED_NODE_KINDS + } + + def all_mismatch_leaf_graph_info(self) -> list[GraphInfo]: + """Return a list of all leaf `GraphInfo` objects that have mismatch.""" + if not self.has_mismatch(): + return [] + + no_mismatch_children = ( + self.upper_graph_info is None or not self.upper_graph_info.has_mismatch() + ) and ( + self.lower_graph_info is None or not self.lower_graph_info.has_mismatch() + ) + + if no_mismatch_children: + return [self] + + results = [] + if self.upper_graph_info is not None: + results += self.upper_graph_info.all_mismatch_leaf_graph_info() + if self.lower_graph_info is not None: + results += self.lower_graph_info.all_mismatch_leaf_graph_info() + + return results + + def find_partition(self, id: str) -> GraphInfo | None: + """Find the `GraphInfo` object with the given id.""" + if id == self.id: + return self + current_length = len(self.id) + if len(id) > current_length: + if id[current_length] == "0" and self.upper_graph_info is not None: + return self.upper_graph_info.find_partition(id) + elif id[current_length] == "1" and self.lower_graph_info is not None: + return self.lower_graph_info.find_partition(id) + return None + + def export_repro( + self, repro_dir: str | None = None, name: str | None = None + ) -> str: + """Export the subgraph to ONNX along with the input/output data for repro. + + The repro directory will contain the following files:: + + dir + \u251c\u2500\u2500 test_ + \u2502 \u251c\u2500\u2500 model.onnx + \u2502 \u2514\u2500\u2500 test_data_set_0 + \u2502 \u251c\u2500\u2500 input_0.pb + \u2502 \u251c\u2500\u2500 input_1.pb + \u2502 \u251c\u2500\u2500 output_0.pb + \u2502 \u2514\u2500\u2500 output_1.pb + + Args: + repro_dir: The directory to export the repro files to. Defaults to current + working directory if None. + name: An optional name for the test case folder: "test_{name}". + + Returns: + The path to the exported repro directory. + """ + + if repro_dir is None: + repro_dir = os.getcwd() + repro_dir = os.path.join(repro_dir, "onnx_debug") + + onnx_graph, onnx_params_dict = _onnx_graph_from_aten_graph( + self.graph, self.export_options, self.params_dict + ) + + proto, _ = _onnx_proto_from_onnx_graph( + onnx_graph, self.export_options, onnx_params_dict + ) + return OnnxTestCaseRepro.create_test_case_repro( + proto, self.input_args, self.pt_outs, repro_dir, name + ) + + def _graph_partition_pivot(self) -> int: + """Find the pivot index to partition the graph. + + The pivot is the node that splits the graph into two parts. Each part should + have the similar amount of nodes, excluding non essential ops, defined in + `_EXCLUDED_NODE_KINDS`, such as `prim::Constant`. 
+ If the graph has an odd number of nodes, the upper part will have one more node. + If the graph does not have any node that can be partitioned, return -1. + + Returns: + The index of the pivot node. + """ + included_node_indices = [ + i + for i, n in enumerate(self.graph.nodes()) + if n.kind() not in self._EXCLUDED_NODE_KINDS + ] + half_idx = len(included_node_indices) // 2 - 1 + if half_idx >= 0 and len(included_node_indices) > half_idx: + return included_node_indices[half_idx] + 1 + return -1 + + def _partition_upper_graph(self) -> torch.Graph: + pivot = self._graph_partition_pivot() + if pivot == -1: + return torch.Graph() + graph = self.graph.copy() # Copy to not mutate parent graph. + original_outputs = list(graph.outputs()) + + def _process_bridge_value_for_upper( + new_outputs: list[torch.Value], bridge_value: torch.Value + ) -> torch.Value: + # Add bridge values as upper graph outputs. + new_outputs.append(bridge_value) + return bridge_value + + new_outputs: list[torch.Value] = [] + process_bridge_value_for_upper = functools.partial( + _process_bridge_value_for_upper, new_outputs + ) + _, dropped_nodes, complete_upper_nodes_set, _ = self._partition_nodes( + graph, pivot, process_bridge_value_for_upper + ) + + for _ in enumerate(original_outputs): + graph.eraseOutput(0) + for output in new_outputs: + graph.registerOutput(output) + + for node in reversed(dropped_nodes): + node.destroy() + + for i, input in reversed(list(enumerate(list(graph.inputs())))): + if ( + not _has_uses_by_nodes(input, complete_upper_nodes_set) + and input not in new_outputs + ): + try: + graph.eraseInput(i) + except RuntimeError as e: + print(input, graph) + raise e + + return graph + + def _partition_lower_graph(self) -> torch.Graph: + pivot = self._graph_partition_pivot() + if pivot == -1: + return torch.Graph() + graph = self.graph.copy() # Copy to not mutate parent graph. + original_outputs = list(graph.outputs()) + original_inputs = list(graph.inputs()) + + new_outputs = [] + + def _process_bridge_value_for_lower( + graph: torch.Graph, bridge_value: torch.Value + ) -> torch.Value: + # Add bridge values as lower graph inputs. 
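+ # A fresh graph input replaces every use of the bridge value and inherits its
+ # metadata (type and shape information).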
+ new_input = graph.addInput() + bridge_value.replaceAllUsesWith(new_input) + new_input.copyMetadata(bridge_value) + return new_input + + process_bridge_value_for_lower = functools.partial( + _process_bridge_value_for_lower, graph + ) + + upper_nodes, lower_nodes, _, complete_lower_nodes_set = self._partition_nodes( + graph, pivot, process_bridge_value_for_lower + ) + + for output in original_outputs: + if _produced_by(output, lower_nodes): + new_outputs.append(output) + for _ in enumerate(original_outputs): + graph.eraseOutput(0) + for output in new_outputs: + graph.registerOutput(output) + + for input in original_inputs: + if _has_uses_by_nodes(input, complete_lower_nodes_set): + new_input = graph.addInput() + input.replaceAllUsesWith(new_input) + new_input.copyMetadata(input) + + for node in reversed(upper_nodes): + if node not in complete_lower_nodes_set: + try: + node.destroy() + except RuntimeError as e: + print(node, graph) + raise e + + for _ in original_inputs: + graph.eraseInput(0) + + return graph + + def _partition_node( + self, + node: torch.Node, + complete_upper_nodes_set: set[torch.Node], + complete_lower_nodes_set: set[torch.Node], + original_graph_outputs: set[torch.Value], + covered_bridge_values: set[torch.Value], + process_bridge_value: Callable[[torch.Value], torch.Value], + ): + if node in complete_lower_nodes_set: + return + + if ( + _node_has_uses_by(node, complete_lower_nodes_set) + and node.kind() in self._EXCLUDED_NODE_KINDS + ): + complete_lower_nodes_set.update(_all_nodes([node])) + for input in node.inputs(): + if input in covered_bridge_values: + continue + self._partition_node( + input.node(), + complete_upper_nodes_set, + complete_lower_nodes_set, + original_graph_outputs, + covered_bridge_values, + process_bridge_value, + ) + else: + for output in node.outputs(): + if output in covered_bridge_values: + continue + if ( + _has_uses_by_nodes(output, complete_lower_nodes_set) + or output in original_graph_outputs + ): + covered_bridge_values.add(process_bridge_value(output)) + + def _partition_nodes( + self, + graph: torch.Graph, + pivot: int, + process_bridge_value: Callable[[torch.Value], torch.Value], + ) -> tuple[list[torch.Node], list[torch.Node], set[torch.Node], set[torch.Node]]: + nodes = list(graph.nodes()) + upper_nodes = nodes[:pivot] + lower_nodes = nodes[pivot:] + # `upper_nodes` and `complete_upper_nodes_set` differs in that the latter + # recursively contains nodes in subblock of `upper_nodes`. + # The same applies for `lower_nodes` and `complete_lower_nodes_set`. + # With addition that `complete_lower_nodes_set` will include nodes that + # are determined to be copied from `upper_nodes` to `lower_nodes`. + complete_upper_nodes_set = _all_nodes(upper_nodes) + complete_lower_nodes_set = _all_nodes(lower_nodes) + original_graph_outputs = set(graph.outputs()) + # Bridge values are values produced from upper graph, and consumed + # by lower graph. These values need to be become upper graph outputs + # and lower graph inputs, to bridge the interaction. + # Start with all graph inputs marked as covered. If any graph input is + # needed by lower graph, just keep it in lower graph inputs later. 
+ covered_bridge_values = set(graph.inputs()) + for node in upper_nodes: + self._partition_node( + node, + complete_upper_nodes_set, + complete_lower_nodes_set, + original_graph_outputs, + covered_bridge_values, + process_bridge_value, + ) + return ( + upper_nodes, + lower_nodes, + complete_upper_nodes_set, + complete_lower_nodes_set, + ) + + def _bridge_kwargs(self): + pt_outs = self.pt_outs + graph_outputs = list(self.graph.outputs()) + assert pt_outs is not None + assert len(graph_outputs) == len( + pt_outs + ), f"{len(graph_outputs)} vs {len(pt_outs)}\nGraph: {self.graph}" + return {v.debugName(): o for v, o in zip(graph_outputs, pt_outs)} + + def _args_and_params_for_partition_graph( + self, + graph: torch.Graph, + bridge_kwargs: Mapping[str, _NumericType | Sequence[_NumericType]], + full_kwargs: Mapping[str, torch.Tensor], + full_params: Mapping[str, torch.Tensor], + ): + input_names = [input.debugName() for input in graph.inputs()] + args = tuple(bridge_kwargs[k] for k in input_names if k in bridge_kwargs) + args += tuple(full_kwargs[k] for k in input_names if k in full_kwargs) + params = {k: full_params[k] for k in input_names if k in full_params} + assert len(args) + len(params) == len( + input_names + ), f"{len(args)} + {len(params)} vs {len(input_names)}: {input_names}" + return args, params + + def verify_export( + self, options: VerificationOptions + ) -> tuple[AssertionError | None, torch.Graph, _OutputsType, _OutputsType]: + """ + Verify the export from TorchScript IR graph to ONNX. + + Export the TorchScript IR graph to ONNX, with the inputs, parameters and export + options recorded in this object. Then verify the exported ONNX graph against + the original TorchScript IR graph under the provided verification options. + + Args: + options: The verification options. + + Returns: + error: The AssertionError raised during the verification. Returns None if no + error is raised. + onnx_graph: The exported ONNX graph in TorchScript IR format. + onnx_outs: The outputs from running exported ONNX model under the onnx + backend in `options`. + pt_outs: The outputs from running the TorchScript IR graph. + """ + return verify_aten_graph( + self.graph, + input_args=self.input_args, + params_dict=self.params_dict, + export_options=self.export_options, + verification_options=options, + ) + + def find_mismatch( + self, + options: VerificationOptions | None = None, + ): + """ + Find all mismatches between the TorchScript IR graph and the exported onnx model. + + Binary searches the model graph to find the minimal subgraph that exhibits the + mismatch. A `GraphInfo` object is created for each subgraph, recording the test + inputs and export options, as well as the validation results. + + Args: + options: The verification options. + """ + self.clear() + + if options is None: + options = VerificationOptions() + + if self.export_options.verbose: + print(self.graph) + + if len(list(self.graph.outputs())) == 0: + return + + assert len(self.input_args) + len(self.params_dict) == len( + list(self.graph.inputs()) + ), ( + f"Number of graph inputs({len(list(self.graph.inputs()))}) does not match " + f"the provided tensor arguments({len(self.input_args)} + {len(self.params_dict)})." + ) + + self.mismatch_error, self._onnx_graph, self.pt_outs, _ = self.verify_export( + options + ) + + if self.mismatch_error is None: + # No mismatch found in graph. + return + + if self.essential_node_count() <= 1: + # Reached leaf node, no more partitioning. 
+ return + + full_kwargs = { + k.debugName(): v for k, v in zip(self.graph.inputs(), self.input_args) + } + full_params = self.params_dict + + upper_graph = self._partition_upper_graph() + upper_args, upper_params = self._args_and_params_for_partition_graph( + upper_graph, {}, full_kwargs, full_params + ) + self.upper_graph_info = GraphInfo( + upper_graph, + upper_args, + upper_params, + self.export_options, + id=self.id + "0", + ) + + self.upper_graph_info.find_mismatch(options) + + bridge_kwargs = self.upper_graph_info._bridge_kwargs() + lower_graph = self._partition_lower_graph() + lower_args, lower_params = self._args_and_params_for_partition_graph( + lower_graph, bridge_kwargs, full_kwargs, full_params + ) + self.lower_graph_info = GraphInfo( + lower_graph, + lower_args, + lower_params, + self.export_options, + id=self.id + "1", + ) + + self.lower_graph_info.find_mismatch(options) + + +def _all_nodes(nodes: Collection[torch.Node]) -> set[torch.Node]: + all_nodes = set(nodes) + for n in nodes: + for b in n.blocks(): + all_nodes.update(_all_nodes(list(b.nodes()))) + return all_nodes + + +def _has_uses_by_nodes(value: torch.Value, nodes: Collection[torch.Node]) -> bool: + return any(use.user in nodes for use in value.uses()) + + +def _node_has_uses_by(node: torch.Node, nodes: Collection[torch.Node]) -> bool: + for output in node.outputs(): + if _has_uses_by_nodes(output, nodes): + return True + return False + + +def _produced_by(value: torch.Value, nodes: Collection[torch.Node]) -> bool: + return value.node() in nodes + + +def find_mismatch( + model: torch.nn.Module | torch.jit.ScriptModule, + input_args: tuple[Any, ...], + do_constant_folding: bool = True, + training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, + opset_version: int | None = None, + keep_initializers_as_inputs: bool = True, + verbose: bool = False, + options: VerificationOptions | None = None, +) -> GraphInfo: + r"""Find all mismatches between the original model and the exported model. + + Experimental. The API is subject to change. + + This tool helps debug the mismatch between the original PyTorch model and exported + ONNX model. It binary searches the model graph to find the minimal subgraph that + exhibits the mismatch. + + Args: + model: The model to be exported. + input_args: The input arguments to the model. + do_constant_folding: Same as `do_constant_folding` in :func:`torch.onnx.export`. + training: Same as `training` in :func:`torch.onnx.export`. + opset_version: Same as `opset_version` in :func:`torch.onnx.export`. + keep_initializers_as_inputs: Same as `keep_initializers_as_inputs` in :func:`torch.onnx.export`. + verbose: Same as `verbose` in :func:`torch.onnx.export`. + options: The options for the mismatch verification. + + Returns: + A GraphInfo object that contains the mismatch information. + + Example:: + + >>> import torch + >>> import torch.onnx.verification + >>> torch.manual_seed(0) + >>> opset_version = 15 + >>> # Define a custom symbolic function for aten::relu. + >>> # The custom symbolic function is incorrect, which will result in mismatches. + >>> def incorrect_relu_symbolic_function(g, self): + ... return self + >>> torch.onnx.register_custom_op_symbolic( + ... "aten::relu", + ... incorrect_relu_symbolic_function, + ... opset_version=opset_version, + ... ) + >>> class Model(torch.nn.Module): + ... def __init__(self) -> None: + ... super().__init__() + ... self.layers = torch.nn.Sequential( + ... torch.nn.Linear(3, 4), + ... torch.nn.ReLU(), + ... torch.nn.Linear(4, 5), + ... 
torch.nn.ReLU(), + ... torch.nn.Linear(5, 6), + ... ) + ... def forward(self, x): + ... return self.layers(x) + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) + >>> graph_info = torch.onnx.verification.find_mismatch( + ... Model(), + ... (torch.randn(2, 3),), + ... opset_version=opset_version, + ... ) + ===================== Mismatch info for graph partition : ====================== + ================================ Mismatch error ================================ + Tensor-likes are not close! + Mismatched elements: 12 / 12 (100.0%) + Greatest absolute difference: 0.2328854203224182 at index (1, 2) (up to 1e-07 allowed) + Greatest relative difference: 0.699536174352349 at index (1, 3) (up to 0.001 allowed) + ==================================== Tree: ===================================== + 5 X __2 X __1 \u2713 + id: | id: 0 | id: 00 + | | + | |__1 X (aten::relu) + | id: 01 + | + |__3 X __1 \u2713 + id: 1 | id: 10 + | + |__2 X __1 X (aten::relu) + id: 11 | id: 110 + | + |__1 \u2713 + id: 111 + =========================== Mismatch leaf subgraphs: =========================== + ['01', '110'] + ============================= Mismatch node kinds: ============================= + {'aten::relu': 2} + + """ + if options is None: + options = VerificationOptions() + if opset_version is None: + opset_version = _constants.ONNX_DEFAULT_OPSET + """From aten graph, do binary search on graph partition to find operator export discrepancy.""" + # TODO: Copied from utils.py `export` until `_optimize_graph`. + if training == torch.onnx.TrainingMode.TRAINING: + model.train() + elif training == torch.onnx.TrainingMode.EVAL: + model.eval() + with torch.no_grad(): + inputs_for_export = _prepare_input_for_export(input_args, {}) + args = utils._decide_input_format(model, inputs_for_export) + + model = utils._pre_trace_quant_model(model, args) + graph, params, torch_out, module = utils._create_jit_graph(model, args) + params_dict = utils._get_named_param_dict(graph, params) + + utils._apply_friendly_debug_names(graph, params_dict) + + graph_info = GraphInfo( + graph, + input_args, + params_dict, + _experimental.ExportOptions( + do_constant_folding=do_constant_folding, + training=training, + opset_version=opset_version, + keep_initializers_as_inputs=keep_initializers_as_inputs, + verbose=verbose, + ), + ) + graph_info.find_mismatch(options) + graph_info.pretty_print_mismatch() + graph_info.pretty_print_tree() + + return graph_info diff --git a/lib/python3.10/site-packages/torch/optim/__init__.py b/lib/python3.10/site-packages/torch/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7354092dda4e02bfa05dd8c71ebd1e0f8408a87d --- /dev/null +++ b/lib/python3.10/site-packages/torch/optim/__init__.py @@ -0,0 +1,63 @@ +""" +:mod:`torch.optim` is a package implementing various optimization algorithms. + +Most commonly used methods are already supported, and the interface is general +enough, so that more sophisticated ones can also be easily integrated in the +future. 
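+
+A minimal usage sketch (illustrative only, not part of the original docstring)::
+
+    import torch
+
+    model = torch.nn.Linear(10, 2)
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+
+    for _ in range(5):
+        optimizer.zero_grad()  # reset gradients accumulated by the previous step
+        loss = model(torch.randn(4, 10)).sum()
+        loss.backward()  # populate .grad on each parameter
+        optimizer.step()  # apply the SGD update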
+""" + +from torch.optim import lr_scheduler as lr_scheduler, swa_utils as swa_utils +from torch.optim._adafactor import Adafactor as Adafactor +from torch.optim.adadelta import Adadelta as Adadelta +from torch.optim.adagrad import Adagrad as Adagrad +from torch.optim.adam import Adam as Adam +from torch.optim.adamax import Adamax as Adamax +from torch.optim.adamw import AdamW as AdamW +from torch.optim.asgd import ASGD as ASGD +from torch.optim.lbfgs import LBFGS as LBFGS +from torch.optim.nadam import NAdam as NAdam +from torch.optim.optimizer import Optimizer as Optimizer +from torch.optim.radam import RAdam as RAdam +from torch.optim.rmsprop import RMSprop as RMSprop +from torch.optim.rprop import Rprop as Rprop +from torch.optim.sgd import SGD as SGD +from torch.optim.sparse_adam import SparseAdam as SparseAdam + + +Adafactor.__module__ = "torch.optim" + + +del adadelta # type: ignore[name-defined] # noqa: F821 +del adagrad # type: ignore[name-defined] # noqa: F821 +del adam # type: ignore[name-defined] # noqa: F821 +del adamw # type: ignore[name-defined] # noqa: F821 +del sparse_adam # type: ignore[name-defined] # noqa: F821 +del adamax # type: ignore[name-defined] # noqa: F821 +del asgd # type: ignore[name-defined] # noqa: F821 +del sgd # type: ignore[name-defined] # noqa: F821 +del radam # type: ignore[name-defined] # noqa: F821 +del rprop # type: ignore[name-defined] # noqa: F821 +del rmsprop # type: ignore[name-defined] # noqa: F821 +del optimizer # type: ignore[name-defined] # noqa: F821 +del nadam # type: ignore[name-defined] # noqa: F821 +del lbfgs # type: ignore[name-defined] # noqa: F821 + +__all__ = [ + "Adafactor", + "Adadelta", + "Adagrad", + "Adam", + "Adamax", + "AdamW", + "ASGD", + "LBFGS", + "lr_scheduler", + "NAdam", + "Optimizer", + "RAdam", + "RMSprop", + "Rprop", + "SGD", + "SparseAdam", + "swa_utils", +] diff --git a/lib/python3.10/site-packages/torch/optim/_adafactor.py b/lib/python3.10/site-packages/torch/optim/_adafactor.py new file mode 100644 index 0000000000000000000000000000000000000000..dc3941008ab8a850ccfa2e01a2f2784ea63b4e92 --- /dev/null +++ b/lib/python3.10/site-packages/torch/optim/_adafactor.py @@ -0,0 +1,656 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _disable_dynamo_if_unsupported, + _get_scalar_dtype, + _maximize_doc, + Optimizer, + ParamsT, + TensorListList, +) + + +__all__ = ["Adafactor", "adafactor"] + + +class Adafactor(Optimizer): + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1e-2, + beta2_decay: float = -0.8, + eps: Tuple[Optional[float], float] = (None, 1e-3), + d: float = 1.0, + weight_decay: float = 0.0, + *, + foreach: Optional[bool] = None, + maximize: bool = False, + ): + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Learning rate should be >= 0 but is: {lr}") + if not 0.0 >= beta2_decay: + raise ValueError(f"beta2_decay should be <= 0 but is: {beta2_decay}") + if eps[0] is not None and not 0.0 <= eps[0]: + raise ValueError(f"epsilon1 should be >= 0 but is: {eps[0]}") + if not 0.0 <= eps[1]: + raise ValueError(f"epsilon2 should be >= 0 but is: {eps[1]}") + if not 1.0 <= d: + raise ValueError(f"Clipping threshold d should be >= 1 but is: {d}") + if not 0.0 <= weight_decay: + raise ValueError(f"weight_decay should be >= 0 but is: {weight_decay}") + 
defaults = dict( + lr=lr, + beta2_decay=beta2_decay, + eps=eps, + d=d, + weight_decay=weight_decay, + foreach=foreach, + maximize=maximize, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("foreach", None) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = torch.tensor(step_val, dtype=_get_scalar_dtype()) + + def _init_group( + self, + group, + params_with_grad, + grads, + row_vars, + col_vars, + variances, + state_steps, + ): + for p in group["params"]: + if p.grad is None: + continue + if torch.is_complex(p): + raise RuntimeError("Adafactor does not support complex parameters") + if p.grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients") + + params_with_grad.append(p) + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off. + # This is because kernel launches are costly on CUDA and XLA. + state["step"] = torch.tensor(0.0, dtype=_get_scalar_dtype()) + + if p.grad.dim() > 1: + row_shape = list(p.grad.shape) + row_shape[-1] = 1 + # Row factor of variance, NOT the same shape as grads (will be reduced along last dim) + state["row_var"] = p.grad.new_zeros(row_shape) + + col_shape = list(p.grad.shape) + col_shape[-2] = 1 + # Col factor of variance, NOT the same shape as grads (will be reduced along penultimate dim) + state["col_var"] = p.grad.new_zeros(col_shape) + else: + state["variance"] = torch.zeros_like( + p.grad, memory_format=torch.preserve_format + ) + + row_vars.append(state.get("row_var", None)) + col_vars.append(state.get("col_var", None)) + variances.append(state.get("variance", None)) + state_steps.append(state["step"]) + return False # has_complex + + @torch.no_grad() + def step(self, closure=None): + r"""Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + row_vars: List[Optional[Tensor]] = [] + col_vars: List[Optional[Tensor]] = [] + variances: List[Optional[Tensor]] = [] + state_steps: List[Tensor] = [] + eps1, eps2 = group["eps"] + + has_complex = self._init_group( + group, + params_with_grad, + grads, + row_vars, + col_vars, + variances, + state_steps, + ) + + adafactor( + params_with_grad, + grads, + row_vars, + col_vars, + variances, + state_steps, + d=group["d"], + lr=group["lr"], + beta2_decay=group["beta2_decay"], + weight_decay=group["weight_decay"], + eps1=eps1, + eps2=eps2, + foreach=group["foreach"], + maximize=group["maximize"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + has_complex=has_complex, + ) + + return loss + + +Adafactor.__doc__ = ( + r"""Implements Adafactor algorithm. + + .. 
math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{(lr)}, \: \tau + \text{(}\beta_2\text{ decay)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)}, \\ + &\hspace{15mm} \: \epsilon_1, \epsilon_2 \text{ (epsilons)}, \: d \text{(clipping threshold)}, \\ + &\hspace{15mm} \: \lambda \text{(weight decay)}, + \: \textit{maximize} \\ + &\textbf{initialize} : \: R_0 \leftarrow 0 \text{ (second moment row factor)}, \\ + &\hspace{23mm} \: C_0 \leftarrow 0 \text{ (second moment col factor)}, \\ + &\hspace{23mm} \: \widehat{V}_0 \leftarrow 0 \text{ (second moment for vectors)} \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + + &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ + &\hspace{10mm}G_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}G_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\widehat{\beta}_{2_t} \leftarrow 1 - t^{\tau} \\ + &\hspace{5mm}\rho_t \leftarrow min(lr, \frac{1}{\sqrt{t}}) \\ + &\hspace{5mm}\alpha_t \leftarrow max(\epsilon_2, + \text{RMS}(\theta_{t-1}))\rho_t \\ + &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\ + &\hspace{5mm}\textbf{if} \: \text{dim}(G_t) > 1: \\ + &\hspace{10mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+ + (1-\widehat{\beta}_{2_t})(G_t \odot G_t) \cdot 1_m \\ + &\hspace{10mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+ + (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t) \\ + &\hspace{10mm}\widehat{V}_t \leftarrow + \frac{R_t \cdot C_t}{max(1^\top_n \cdot R_t, \epsilon_1)} \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}\widehat{V}_t \leftarrow \widehat{\beta}_{2_t}\widehat{V}_{t-1}+ + (1-\widehat{\beta}_{2_t}) \cdot (G_t \odot G_t) \\ + &\hspace{5mm}U_t \leftarrow + \frac{G_t}{max(\sqrt{\widehat{V}_t}, \epsilon_1)} \\ + &\hspace{5mm}\widehat{U}_t \leftarrow \frac{U_t}{max(1, \frac{\text{RMS}(U_t)}{d})} \\ + &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \alpha_t \widehat{U}_t \\ + + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`_. + """ + + rf""" + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, Tensor, optional): unlike other optimizers, Adafactor does not require a + learning rate, and Shazeer, Noam, and Mitchell Stern do not use lr at all. + Deviating from the paper, this implementation uses lr for applying weight + decay and as the maximum value for relative step size rho_t. Note that in + the paper, a constant of 0.01 is used as the maximum value for relative + step size, and so we set 0.01 as the default value. (default: 1e-2) + beta2_decay (float, optional): the decay rate of beta2. beta2 standardly refers + to the coefficient used for computing the running average of the gradient + squared. (default: -0.8) + eps (Tuple[float, float], optional): epsilon1 is the term added to the denominator + of the update calculation to improve numerical stability. This use of epsilon1 + deviates from the algorithm written in the paper! See note below for more details. + epsilon2 is the term used to avoid having too small a weight update when applying + parameter scaling. (default: (None, 1e-3)) + d (float, optional): the clipping threshold, used to avoid larger-than-desired + updates. 
+ weight_decay (float, optional): weight decay coefficient (default: 1e-2) + foreach (bool, optional): whether foreach implementation of optimizer is used. Note + that the foreach implementation uses ~ sizeof(params) more peak memory than the + for-loop version due to the intermediates being a tensorlist vs just one tensor. + As Adafactor is commonly used when memory is prohibitive, Adafactor will default + to the slower single tensor for-loop implementation unless this flag is explicitly + True. This behavior is contrary to other optimizers, which will attempt defaulting + to foreach on CUDA for faster runtime. (default: None) + {_maximize_doc}""" + + r""" + .. Note:: + The implementation of Adafactor subtly differs from Shazeer, Noam, and Mitchell Stern + and implementations in some other frameworks with its use of learning rate and + :math:`\epsilon_1`. + + Regarding the learning rate hyperparameter: Shazeer, Noam, and Mitchell Stern do not + use lr at all, as the stated algorithm uses :math:`\rho_t` and update clipping to + affect the step size. + + This implementation allows `lr` to influence the maximum value for :math:`\rho_t`: + + .. math:: + \begin{aligned} + &\hspace{5mm}\rho_t \leftarrow min(lr, \frac{1}{\sqrt{t}}) + \end{aligned} + + This differs from Shazeer, Noam, and Mitchell Stern, who use a constant of 0.01 as + the maximum value of :math:`\rho_t` + + .. math:: + \begin{aligned} + &\hspace{5mm}\rho_t \leftarrow min(0.01, \frac{1}{\sqrt{t}}) + \end{aligned} + + Shazeer, Noam, and Mitchell Stern do not enforce an opinion on how weight decay should + be computed, and so we use the learning rate as a coefficient for decoupled weight + decay, similar to what is suggested in `Decoupled Weight Decay Regularization`_. + + Regarding the use of :math:`\epsilon_1`: The implementation attempts to replicate the + presumed intention of Shazeer, Noam, and Mitchell Stern to use :math:`\epsilon_1` as + a stabilizing term when the squared gradient becomes small. + + This stabilization can be written as + + .. math:: + \begin{aligned} + &\hspace{5mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+ + (1-\widehat{\beta}_{2_t})(G_t \odot G_t + 1_n \cdot 1^\top_m) \cdot 1_m \\ + &\hspace{5mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+ + (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t + 1_n \cdot 1^\top_m) \\ + &\hspace{5mm}\widehat{V}_t \leftarrow + \frac{R_t \cdot C_t}{max(1^\top_n \cdot R_t, \epsilon_1)} \\ + &\hspace{5mm}U_t \leftarrow \frac{G_t}{max(\sqrt{\widehat{V}_t}, \epsilon_1)} \\ + \end{aligned} + + where the row and column factors of gradient squared :math:`R_t` and :math:`C_t` + are left alone, and we apply :math:`\epsilon_1` at the final calculation of + the variance estimate :math:`\widehat{V}_t` and for the update :math:`U_t`. + + This is in contrast to Shazeer, Noam, and Mitchell Stern and other frameworks which + apply :math:`\epsilon_1` to both row and column factors of the squared gradient, but + not in the calculations after: + + .. math:: + \begin{aligned} + &\hspace{5mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+ + (1-\widehat{\beta}_{2_t})(G_t \odot G_t + \epsilon_1 1_n \cdot 1^\top_m) \cdot 1_m \\ + &\hspace{5mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+ + (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t + \epsilon_1 1_n \cdot 1^\top_m) \\ + &\hspace{5mm}\widehat{V}_t \leftarrow \frac{R_t \cdot C_t}{1^\top_n \cdot R_t} \\ + &\hspace{5mm}U_t \leftarrow \frac{G_t}{\sqrt{\widehat{V}_t}} \\ + \end{aligned} + + + .. 
_Adafactor\: Adaptive Learning Rates with Sublinear Memory Cost: + https://arxiv.org/pdf/1804.04235 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + """ +) + + +def _single_tensor_adafactor( + params: List[Tensor], + grads: List[Tensor], + # If grad is 1-dimensional (aka a vector), there is no factorization necessary + # so row_var and col_var will be None while variance will be filled. + # Contrarily, for a grad with multiple dimensions, we will factor along the last + # 2 dimensions, and so row_var and col_var will be filled and variance will be None. + row_vars: List[Optional[Tensor]], + col_vars: List[Optional[Tensor]], + variances: List[Optional[Tensor]], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + d: float, + lr: Union[Tensor, float], + beta2_decay: float, + weight_decay: float, + eps1: Optional[float], + eps2: float, + maximize: bool, + has_complex: bool, +): + assert ( + grad_scale is None and found_inf is None + ), "Grad scaling should occur outside of optimizer.step()" + + if torch.jit.is_scripting(): + # this assert is due to JIT being dumb and not realizing that the ops below + # have overloads to handle both float and Tensor lrs, so we just assert it's + # a float since most people using JIT are using floats + assert isinstance(lr, float) + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + step_t = state_steps[i] + row_var = row_vars[i] + col_var = col_vars[i] + variance = variances[i] + if eps1 is None: + eps1 = torch.finfo(param.dtype).eps + + # update step + step_t += 1 + step_float = step_t.item() + + one_minus_beta2_t = step_float**beta2_decay + rho_t = min(lr, 1 / (step_float**0.5)) + alpha = max(eps2, param.norm(2).item() / (param.numel() ** 0.5)) * rho_t + + # Perform stepweight decay + if weight_decay != 0: + param.mul_(1 - lr * weight_decay) + + if grad.dim() > 1: + assert ( + row_var is not None and col_var is not None + ), "row_var and col_var should be defined when grad is multidimensional" + # same as (g * g).mean(dim=-1) w/o materializing an intermediate size g + row_mean = ( + torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1)) + ) + row_var.lerp_(row_mean, one_minus_beta2_t) + # same as (g * g).mean(dim=-2) w/o materializing an intermediate size g + col_mean = ( + torch.norm(grad, dim=-2, keepdim=True).square_().div_(grad.size(-2)) + ) + col_var.lerp_(col_mean, one_minus_beta2_t) + var_estimate = row_var @ col_var + var_estimate.div_(row_var.mean(dim=-2, keepdim=True).clamp_(min=eps1)) + else: + assert ( + variance is not None + ), "variance should be defined when grad is a vector" + grad_squared = grad * grad + variance.lerp_(grad_squared, one_minus_beta2_t) + # avoid writing into variance during update + var_estimate = variance.clone() + + # square the eps1 as we sqrt after to keep eps1's magnitude + update = var_estimate.clamp_(min=eps1 * eps1).rsqrt_() + update.mul_(grad) + denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * d)) + param.add_(update, alpha=-alpha / denom) + + +def _group_tensors_by_device_dtype_and_is_multidim( + tensorlists: TensorListList, +) -> Dict[ + Tuple[Optional[torch.device], Optional[torch.dtype], bool], + List[List[Optional[Tensor]]], +]: + """Groups tensors by device, dtype, AND multidimensionality -- whether the tensor + has multiple dims or just one dim (is a vector). 
This allows the foreach impl of + Adafactor to assume that every group of params will either be factored or not.""" + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(tensorlists) + ultra_grouped_tensors: Dict[ + Tuple[Optional[torch.device], Optional[torch.dtype], bool], + List[List[Optional[Tensor]]], + ] = {} + for (device, dtype), (tensorlists, _) in grouped_tensors.items(): + matrix_key = (device, dtype, True) + vector_key = (device, dtype, False) + + # assumes grad is the second tensorlist + for j, tensor in enumerate(tensorlists[1]): + assert tensor is not None, "grad should not be None" + if tensor.dim() > 1: + if matrix_key not in ultra_grouped_tensors: + ultra_grouped_tensors[matrix_key] = [[] for _ in tensorlists] + for i in range(len(tensorlists)): + ultra_grouped_tensors[matrix_key][i].append(tensorlists[i][j]) + else: + if vector_key not in ultra_grouped_tensors: + ultra_grouped_tensors[vector_key] = [[] for _ in tensorlists] + for i in range(len(tensorlists)): + ultra_grouped_tensors[vector_key][i].append(tensorlists[i][j]) + return ultra_grouped_tensors + + +def _multi_tensor_adafactor( + params: List[Tensor], + grads: List[Tensor], + # If grad is 1-dimensional (aka a vector), there is no factorization necessary + # so row_var and col_var will be None while variance will be filled. + # Contrarily, for a grad with multiple dimensions, we will factor along the last + # 2 dimensions, and so row_var and col_var will be filled and variance will be None. + row_vars: List[Optional[Tensor]], + col_vars: List[Optional[Tensor]], + variances: List[Optional[Tensor]], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + d: float, + lr: Union[Tensor, float], + beta2_decay: float, + weight_decay: float, + eps1: Optional[float], + eps2: float, + maximize: bool, + has_complex: bool, +): + if len(params) == 0: + return + + assert ( + grad_scale is None and found_inf is None + ), "Grad scaling should occur outside of optimizer.step()" + + grouped_tensors = _group_tensors_by_device_dtype_and_is_multidim( + [params, grads, row_vars, col_vars, variances, state_steps] # type: ignore[list-item] + ) + for (_, dtype, is_multidim), ( + ( + device_params_, + device_grads_, + device_row_vars_, + device_col_vars_, + device_variances_, + device_state_steps_, + ) + ) in grouped_tensors.items(): + device_params = cast(List[Tensor], device_params_) + device_grads = cast(List[Tensor], device_grads_) + device_state_steps = cast(List[Tensor], device_state_steps_) + if eps1 is None: + assert ( + dtype is not None + ), "dtype is needed to compute eps1 when eps1 is unset" + eps1 = torch.finfo(dtype).eps + + if TYPE_CHECKING: + assert device_state_steps[0] is not None + + if maximize: + device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. 
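+        # Illustrative comment (added; the small loop sketched here is an
+        # assumption about the CPU fallback, not upstream code): on CPU the
+        # branch underneath behaves roughly like
+        #     one = torch.tensor(1.0, device="cpu")
+        #     for s in device_state_steps:
+        #         s.add_(one, alpha=1.0)
+        # i.e. the scalar 1.0 is wrapped into a tensor once instead of once per
+        # step tensor.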
+ if not torch._utils.is_compiling() and device_state_steps[0].is_cpu: + torch._foreach_add_( + device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(device_state_steps, 1.0) + + one_minus_beta2_ts = [] + beta2_ts = [] + rho_ts = [] + for s in device_state_steps: + one_minus_beta2_ts.append(s.item() ** beta2_decay) + beta2_ts.append(1 - s.item() ** beta2_decay) + rho_ts.append(min(lr, 1 / (s.item() ** 0.5))) + + alphas = [ + max(eps2, p.norm(2).item() / (p.numel() ** 0.5)) * r + for p, r in zip(device_params, rho_ts) + ] + + # Perform stepweight decay + if weight_decay != 0: + torch._foreach_mul_(device_params, 1 - lr * weight_decay) + + if is_multidim: + device_row_vars = cast(List[Tensor], device_row_vars_) + device_col_vars = cast(List[Tensor], device_col_vars_) + assert ( + device_row_vars[0] is not None and device_col_vars[0] is not None + ), "row_var and col_var should be defined when grad is multidimensional" + # same as (g * g).mean(dim=-1) w/o materializing an intermediate size g + row_means = [ + torch.norm(grad, dim=-1, keepdim=True) for grad in device_grads + ] + torch._foreach_mul_(row_means, row_means) + torch._foreach_div_(row_means, [grad.size(-1) for grad in device_grads]) + torch._foreach_mul_(device_row_vars, beta2_ts) + torch._foreach_mul_(row_means, one_minus_beta2_ts) + torch._foreach_add_(device_row_vars, row_means) + del row_means + + # same as (g * g).mean(dim=-2) w/o materializing an intermediate size g + col_means = [ + torch.norm(grad, dim=-2, keepdim=True) for grad in device_grads + ] + torch._foreach_mul_(col_means, col_means) + torch._foreach_div_(col_means, [grad.size(-2) for grad in device_grads]) + torch._foreach_mul_(device_col_vars, beta2_ts) + torch._foreach_mul_(col_means, one_minus_beta2_ts) + torch._foreach_add_(device_col_vars, col_means) + del col_means + + var_estimates = [ + row_var @ col_var + for row_var, col_var in zip(device_row_vars, device_col_vars) + ] + row_var_means = [ + row_var.mean(dim=-2, keepdim=True) for row_var in device_row_vars + ] + torch._foreach_clamp_min_(row_var_means, eps1) + torch._foreach_div_(var_estimates, row_var_means) + del row_var_means + else: + device_variances = cast(List[Tensor], device_variances_) + assert ( + device_variances[0] is not None + ), "variance should be defined when grad is a vector" + + grads_squared = torch._foreach_mul(device_grads, device_grads) + torch._foreach_mul_(device_variances, beta2_ts) + torch._foreach_mul_(grads_squared, one_minus_beta2_ts) + torch._foreach_add_(device_variances, grads_squared) + del grads_squared + + # avoid writing into variance during update + var_estimates = [v.clone() for v in device_variances] + + # square the eps1 as we sqrt after to keep eps1's magnitude + torch._foreach_clamp_min_(var_estimates, eps1 * eps1) + torch._foreach_sqrt_(var_estimates) + torch._foreach_reciprocal_(var_estimates) + torch._foreach_mul_(var_estimates, device_grads) + updates = var_estimates + + alphas = [ + -a / (max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * d))) + for a, update in zip(alphas, updates) + ] + torch._foreach_mul_(updates, alphas) + torch._foreach_add_(device_params, updates) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adafactor) +def adafactor( + params: List[Tensor], + grads: List[Tensor], + row_vars: List[Optional[Tensor]], + col_vars: List[Optional[Tensor]], + variances: List[Optional[Tensor]], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions 
compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + has_complex: bool = False, + *, + d: float, + lr: Union[float, Tensor], + beta2_decay: float, + weight_decay: float, + eps1: float, + eps2: float, + maximize: bool, +): + r"""Functional API that performs Adafactor algorithm computation. + + See :class:`~torch.optim.Adafactor` for details. + """ + if not torch._utils.is_compiling() and not all( + isinstance(t, torch.Tensor) for t in state_steps + ): + raise RuntimeError( + "`state_steps` argument must contain a list of singleton tensors" + ) + + if foreach: + func = _multi_tensor_adafactor + else: + func = _single_tensor_adafactor + + func( + params, + grads, + row_vars, + col_vars, + variances, + state_steps, + d=d, + lr=lr, + beta2_decay=beta2_decay, + weight_decay=weight_decay, + eps1=eps1, + eps2=eps2, + maximize=maximize, + grad_scale=grad_scale, + found_inf=found_inf, + has_complex=has_complex, + ) diff --git a/lib/python3.10/site-packages/torch/optim/_functional.py b/lib/python3.10/site-packages/torch/optim/_functional.py new file mode 100644 index 0000000000000000000000000000000000000000..a307cc76846dc2be51a47a1b5b4e70c29aafffc4 --- /dev/null +++ b/lib/python3.10/site-packages/torch/optim/_functional.py @@ -0,0 +1,84 @@ +# mypy: allow-untyped-defs +r"""Functional interface.""" +import math +from typing import List + +from torch import Tensor + +from .adadelta import adadelta # type: ignore[attr-defined] # noqa: F401 +from .adagrad import _make_sparse, adagrad # type: ignore[attr-defined] # noqa: F401 +from .adam import adam # type: ignore[attr-defined] # noqa: F401 +from .adamax import adamax # type: ignore[attr-defined] # noqa: F401 +from .adamw import adamw # type: ignore[attr-defined] # noqa: F401 +from .asgd import asgd # type: ignore[attr-defined] # noqa: F401 +from .nadam import nadam # type: ignore[attr-defined] # noqa: F401 +from .radam import radam # type: ignore[attr-defined] # noqa: F401 +from .rmsprop import rmsprop # type: ignore[attr-defined] # noqa: F401 +from .rprop import rprop # type: ignore[attr-defined] # noqa: F401 +from .sgd import sgd # type: ignore[attr-defined] # noqa: F401 + + +# TODO: use foreach API in optim._functional to do all the computation + + +def sparse_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[int], + *, + eps: float, + beta1: float, + beta2: float, + lr: float, + maximize: bool, +): + r"""Functional API that performs Sparse Adam algorithm computation. + + See :class:`~torch.optim.SparseAdam` for details. 
+ """ + for i, param in enumerate(params): + grad = grads[i] + grad = grad if not maximize else -grad + grad = grad.coalesce() # the update is non-linear so indices must be unique + grad_indices = grad._indices() + grad_values = grad._values() + if grad_values.numel() == 0: + # Skip update for empty grad + continue + size = grad.size() + + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step = state_steps[i] + + def make_sparse(values): + constructor = grad.new + if grad_indices.dim() == 0 or values.dim() == 0: + return constructor().resize_as_(grad) + return constructor(grad_indices, values, size) + + # Decay the first and second moment running average coefficient + # old <- b * old + (1 - b) * new + # <==> old += (1 - b) * (new - old) + old_exp_avg_values = exp_avg.sparse_mask(grad)._values() + exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1) + exp_avg.add_(make_sparse(exp_avg_update_values)) + old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values() + exp_avg_sq_update_values = ( + grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2) + ) + exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values)) + + # Dense addition again is intended, avoiding another sparse_mask + numer = exp_avg_update_values.add_(old_exp_avg_values) + exp_avg_sq_update_values.add_(old_exp_avg_sq_values) + denom = exp_avg_sq_update_values.sqrt_().add_(eps) + del exp_avg_update_values, exp_avg_sq_update_values + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + step_size = lr * math.sqrt(bias_correction2) / bias_correction1 + + param.add_(make_sparse(-step_size * numer.div_(denom))) diff --git a/lib/python3.10/site-packages/torch/optim/adadelta.py b/lib/python3.10/site-packages/torch/optim/adadelta.py new file mode 100644 index 0000000000000000000000000000000000000000..ef45706176a347d2f987a4c6d45eb5bf6b8aebc4 --- /dev/null +++ b/lib/python3.10/site-packages/torch/optim/adadelta.py @@ -0,0 +1,461 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +from typing import Any, cast, Dict, List, Optional, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _foreach_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _maximize_doc, + _use_grad_for_differentiable, + _view_as_real, + Optimizer, + ParamsT, +) + + +__all__ = ["Adadelta", "adadelta"] + + +class Adadelta(Optimizer): + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1.0, + rho: float = 0.9, + eps: float = 1e-6, + weight_decay: float = 0, + foreach: Optional[bool] = None, + *, + capturable: bool = False, + maximize: bool = False, + differentiable: bool = False, + ): + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= rho <= 1.0: + raise ValueError(f"Invalid rho value: {rho}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + maximize=maximize, + capturable=capturable, + foreach=foreach, + differentiable=differentiable, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("foreach", 
None) + group.setdefault("maximize", False) + group.setdefault("differentiable", False) + group.setdefault("capturable", False) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, dtype=_get_scalar_dtype(), device=p.device + ) + if group["capturable"] + else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + + def _init_group( + self, + group: Dict[str, Any], + params_with_grad: List[Tensor], + grads: List[Tensor], + square_avgs: List[Tensor], + acc_deltas: List[Tensor], + state_steps: List[Tensor], + ): + has_complex = False + p: Tensor + for p in group["params"]: + if p.grad is None: + continue + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError("Adadelta does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + + # Lazy state initialization + if len(state) == 0: + state["step"] = ( + torch.zeros((), dtype=_get_scalar_dtype(), device=p.device) + if group["capturable"] + else torch.zeros((), dtype=_get_scalar_dtype()) + ) + + state["square_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["acc_delta"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + square_avgs.append(state["square_avg"]) + acc_deltas.append(state["acc_delta"]) + state_steps.append(state["step"]) + + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + square_avgs: List[Tensor] = [] + acc_deltas: List[Tensor] = [] + state_steps: List[Tensor] = [] + ( + lr, + rho, + eps, + weight_decay, + foreach, + maximize, + differentiable, + capturable, + ) = ( + group["lr"], + group["rho"], + group["eps"], + group["weight_decay"], + group["foreach"], + group["maximize"], + group["differentiable"], + group["capturable"], + ) + + has_complex = self._init_group( + group, params_with_grad, grads, square_avgs, acc_deltas, state_steps + ) + + adadelta( + params_with_grad, + grads, + square_avgs, + acc_deltas, + state_steps, + lr=lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + capturable=capturable, + has_complex=has_complex, + ) + + return loss + + +Adadelta.__doc__ = ( + r"""Implements Adadelta algorithm. + + .. 
math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, + \: f(\theta) \text{ (objective)}, \: \rho \text{ (decay)}, + \: \lambda \text{ (weight decay)} \\ + &\textbf{initialize} : v_0 \leftarrow 0 \: \text{ (square avg)}, + \: u_0 \leftarrow 0 \: \text{ (accumulate variables)} \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}if \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm} v_t \leftarrow v_{t-1} \rho + g^2_t (1 - \rho) \\ + &\hspace{5mm}\Delta x_t \leftarrow \frac{\sqrt{u_{t-1} + + \epsilon }}{ \sqrt{v_t + \epsilon} }g_t \hspace{21mm} \\ + &\hspace{5mm} u_t \leftarrow u_{t-1} \rho + + \Delta x^2_t (1 - \rho) \\ + &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \Delta x_t \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `ADADELTA: An Adaptive Learning Rate Method`_. + """ + + rf""" + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + rho (float, optional): coefficient used for computing a running average + of squared gradients (default: 0.9). A higher value of `rho` will + result in a slower average, which can be helpful for preventing + oscillations in the learning process. + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-6). + lr (float, Tensor, optional): coefficient that scale delta before it is applied + to the parameters (default: 1.0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + {_foreach_doc} + {_capturable_doc} + {_maximize_doc} + {_differentiable_doc} + + .. _ADADELTA\: An Adaptive Learning Rate Method: + https://arxiv.org/abs/1212.5701 + + """ +) + + +def _single_tensor_adadelta( + params: List[Tensor], + grads: List[Tensor], + square_avgs: List[Tensor], + acc_deltas: List[Tensor], + state_steps: List[Tensor], + *, + lr: float, + rho: float, + eps: float, + weight_decay: float, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." 
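+    # Readability note (added): the loop below applies, per parameter, the update
+    # from the Adadelta docstring above:
+    #     v_t  = rho * v_{t-1} + (1 - rho) * g_t ** 2
+    #     dx_t = (sqrt(u_{t-1} + eps) / sqrt(v_t + eps)) * g_t
+    #     u_t  = rho * u_{t-1} + (1 - rho) * dx_t ** 2
+    #     theta_t = theta_{t-1} - lr * dx_t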
+ + for param, grad, square_avg, acc_delta, step in zip( + params, grads, square_avgs, acc_deltas, state_steps + ): + step += 1 + grad = grad if not maximize else -grad + + if weight_decay != 0: + grad = grad.add(param, alpha=weight_decay) + + if torch.is_complex(param): + square_avg = torch.view_as_real(square_avg) + acc_delta = torch.view_as_real(acc_delta) + grad = torch.view_as_real(grad) + + square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho) + std = square_avg.add(eps).sqrt_() + delta = acc_delta.add(eps).sqrt_() + if differentiable: + delta = delta.clone() + delta.div_(std).mul_(grad) + acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho) + + if torch.is_complex(param): + delta = torch.view_as_complex(delta) + param.add_(delta, alpha=-lr) + + +def _multi_tensor_adadelta( + params: List[Tensor], + grads: List[Tensor], + square_avgs: List[Tensor], + acc_deltas: List[Tensor], + state_steps: List[Tensor], + *, + lr: float, + rho: float, + eps: float, + weight_decay: float, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + assert not differentiable, "_foreach ops don't support autograd" + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + if len(params) == 0: + return + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, square_avgs, acc_deltas, state_steps] # type: ignore[list-item] + ) + for ( + device_params_, + device_grads_, + device_square_avgs_, + device_acc_deltas_, + device_state_steps_, + ), _ in grouped_tensors.values(): + device_params = cast(List[Tensor], device_params_) + device_grads = cast(List[Tensor], device_grads_) + device_square_avgs = cast(List[Tensor], device_square_avgs_) + device_acc_deltas = cast(List[Tensor], device_acc_deltas_) + device_state_steps = cast(List[Tensor], device_state_steps_) + if has_complex: + _view_as_real( + device_params, device_grads, device_square_avgs, device_acc_deltas + ) + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. 
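+        # Added observation (not an upstream comment): Adadelta's update itself
+        # never reads `step`; the counter is still advanced here so checkpoints and
+        # the capturable path stay consistent with the other optimizers.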
+ if not torch._utils.is_compiling() and device_state_steps[0].is_cpu: + torch._foreach_add_( + device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(device_state_steps, 1) + + if maximize: + device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] + + if weight_decay != 0: + # Re-use the intermediate memory (device_grads) already allocated for maximize + if maximize: + torch._foreach_add_(device_grads, device_params, alpha=weight_decay) + else: + device_grads = torch._foreach_add( # type: ignore[assignment] + device_grads, device_params, alpha=weight_decay + ) + + torch._foreach_mul_(device_square_avgs, rho) + torch._foreach_addcmul_( + device_square_avgs, device_grads, device_grads, value=1 - rho + ) + + std = torch._foreach_add(device_square_avgs, eps) + torch._foreach_sqrt_(std) + + deltas = torch._foreach_add(device_acc_deltas, eps) + torch._foreach_sqrt_(deltas) + torch._foreach_div_(deltas, std) + torch._foreach_mul_(deltas, device_grads) + + torch._foreach_mul_(device_acc_deltas, rho) + torch._foreach_addcmul_(device_acc_deltas, deltas, deltas, value=1 - rho) + + # If LR is a tensor, the else branch will internally call item() + # which will cause silent incorrectness if we are capturing + if capturable and isinstance(lr, torch.Tensor): + torch._foreach_mul_(deltas, -lr) + torch._foreach_add_(device_params, deltas) + else: + torch._foreach_add_(device_params, deltas, alpha=-lr) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adadelta) +def adadelta( + params: List[Tensor], + grads: List[Tensor], + square_avgs: List[Tensor], + acc_deltas: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + capturable: bool = False, + foreach: Optional[bool] = None, + differentiable: bool = False, + has_complex: bool = False, + *, + lr: float, + rho: float, + eps: float, + weight_decay: float, + maximize: bool, +): + r"""Functional API that performs Adadelta algorithm computation. + + See :class:`~torch.optim.Adadelta` for details. + """ + + # this check is slow during compilation, so we skip it + # if it's strictly needed we can add this check back in dynamo + if not torch._utils.is_compiling() and not all( + isinstance(t, torch.Tensor) for t in state_steps + ): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + # We still respect when the user inputs False for foreach. 
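+    # Added note (hedged): _default_to_fused_or_foreach returns a (fused, foreach)
+    # pair of booleans; with use_fused=False only the foreach default matters here,
+    # and an explicit foreach=False from the caller is never overridden because
+    # this branch runs only when foreach is None.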
+ if foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_adadelta + else: + func = _single_tensor_adadelta + + func( + params, + grads, + square_avgs, + acc_deltas, + state_steps, + lr=lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + maximize=maximize, + differentiable=differentiable, + capturable=capturable, + has_complex=has_complex, + ) diff --git a/lib/python3.10/site-packages/torch/optim/adagrad.py b/lib/python3.10/site-packages/torch/optim/adagrad.py new file mode 100644 index 0000000000000000000000000000000000000000..7427471c1bfd49ea31b1505612245c349f39cf0e --- /dev/null +++ b/lib/python3.10/site-packages/torch/optim/adagrad.py @@ -0,0 +1,564 @@ +# mypy: allow-untyped-defs +from typing import cast, List, Optional, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _default_to_fused_or_foreach, + _device_dtype_check_for_fused, + _differentiable_doc, + _foreach_doc, + _get_scalar_dtype, + _get_value, + _maximize_doc, + _use_grad_for_differentiable, + _view_as_real, + Optimizer, + ParamsT, +) + + +__all__ = ["Adagrad", "adagrad"] + + +class Adagrad(Optimizer): + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1e-2, + lr_decay: float = 0, + weight_decay: float = 0, + initial_accumulator_value: float = 0, + eps: float = 1e-10, + foreach: Optional[bool] = None, + *, + maximize: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + ): + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= lr_decay: + raise ValueError(f"Invalid lr_decay value: {lr_decay}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if not 0.0 <= initial_accumulator_value: + raise ValueError( + f"Invalid initial_accumulator_value value: {initial_accumulator_value}" + ) + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + + defaults = dict( + lr=lr, + lr_decay=lr_decay, + eps=eps, + weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + fused=fused, + ) + super().__init__(params, defaults) + + if fused: + if differentiable: + raise RuntimeError("`fused` does not support `differentiable`") + if foreach: + raise RuntimeError("`fused` and `foreach` cannot be `True` together.") + self._need_device_dtype_check_for_fused = True + + for group in self.param_groups: + for p in group["params"]: + state = self.state[p] + state["step"] = ( + torch.zeros( + (), + dtype=_get_scalar_dtype(is_fused=group["fused"]), + device=p.device, + ) + if group["fused"] + else torch.tensor(0.0, dtype=_get_scalar_dtype()) + ) + init_value = ( + complex(initial_accumulator_value, initial_accumulator_value) + if torch.is_complex(p) + else initial_accumulator_value + ) + state["sum"] = torch.full_like( + p, init_value, memory_format=torch.preserve_format + ) + + def __setstate__(self, state): + super().__setstate__(state) + # define "fused" for + # MYPY error: Name "fused" may be undefined + fused = None + for group in self.param_groups: + group.setdefault("foreach", None) + group.setdefault("maximize", False) + 
group.setdefault("differentiable", False) + fused = group.setdefault("fused", None) + + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]["step"] + ) + if not step_is_tensor: + for s in state_values: + s["step"] = torch.tensor( + float(s["step"]), dtype=_get_scalar_dtype(is_fused=fused) + ) + + def share_memory(self): + for group in self.param_groups: + for p in group["params"]: + state = self.state[p] + state["sum"].share_memory_() + + def _init_group(self, group, params_with_grad, grads, state_sums, state_steps): + has_sparse_grad, has_complex = False, False + for p in group["params"]: + if p.grad is not None: + if group["fused"] and getattr( + self, + "_need_device_dtype_check_for_fused", + True, + ): + _device_dtype_check_for_fused(p, cuda_unsupported=True) + self._need_device_dtype_check_for_fused = False + has_sparse_grad |= p.grad.is_sparse + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + grads.append(p.grad) + state = self.state[p] + state_sums.append(state["sum"]) + state_steps.append(state["step"]) + + return has_sparse_grad, has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + state_sums: List[Tensor] = [] + state_steps: List[Tensor] = [] + + has_sparse_grad, has_complex = self._init_group( + group, params_with_grad, grads, state_sums, state_steps + ) + + adagrad( + params_with_grad, + grads, + state_sums, + state_steps, + lr=group["lr"], + weight_decay=group["weight_decay"], + lr_decay=group["lr_decay"], + eps=group["eps"], + has_sparse_grad=has_sparse_grad, + foreach=group["foreach"], + maximize=group["maximize"], + differentiable=group["differentiable"], + has_complex=has_complex, + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +Adagrad.__doc__ = ( + r"""Implements Adagrad algorithm. + + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta) + \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\ + &\hspace{12mm} \tau \text{ (initial accumulator value)}, \: \eta\text{ (lr decay)}\\ + &\textbf{initialize} : state\_sum_0 \leftarrow \tau \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm} \tilde{\gamma} \leftarrow \gamma / (1 +(t-1) \eta) \\ + &\hspace{5mm} \textbf{if} \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm}state\_sum_t \leftarrow state\_sum_{t-1} + g^2_t \\ + &\hspace{5mm}\theta_t \leftarrow + \theta_{t-1}- \tilde{\gamma} \frac{g_t}{\sqrt{state\_sum_t}+\epsilon} \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `Adaptive Subgradient Methods for Online Learning + and Stochastic Optimization`_. 
+ """ + + rf""" + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, Tensor, optional): learning rate (default: 1e-2) + lr_decay (float, optional): learning rate decay (default: 0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + initial_accumulator_value (float, optional): initial value of the + sum of squares of gradients (default: 0) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-10) + {_foreach_doc} + {_maximize_doc} + {_differentiable_doc} + fused (bool, optional): whether the fused implementation (CPU only) is used. + Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16` + are supported. (default: None). Please note that the fused implementations does not + support sparse or complex gradients. + .. _Adaptive Subgradient Methods for Online Learning and Stochastic + Optimization: http://jmlr.org/papers/v12/duchi11a.html + + """ +) + + +def adagrad( + params: List[Tensor], + grads: List[Tensor], + state_sums: List[Tensor], + state_steps: List[Tensor], + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting these as kwargs for now as functional API is compiled by torch/distributed/optim + has_sparse_grad: bool = False, + foreach: Optional[bool] = None, + differentiable: bool = False, + has_complex: bool = False, + *, + lr: float, + weight_decay: float, + lr_decay: float, + eps: float, + maximize: bool, +): + r"""Functional API that performs Adagrad algorithm computation. + + See :class:`~torch.optim.Adagrad` for details. + """ + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + # Respect when the user inputs False/True for foreach or fused. We only want to change + # the default when neither have been user-specified. Note that we default to foreach + # and pass False to use_fused. This is not a mistake--we want to give the fused impl + # bake-in time before making it the default, even if it is typically faster. 
+ if fused is None and foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + + if fused is None: + fused = False + if foreach is None: + foreach = False + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + if fused and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with fused optimizers") + + if fused and not torch.jit.is_scripting(): + func = _fused_adagrad + elif foreach and not torch.jit.is_scripting(): + func = _multi_tensor_adagrad + else: + func = _single_tensor_adagrad + + func( + params, + grads, + state_sums, + state_steps, + lr=lr, + weight_decay=weight_decay, + lr_decay=lr_decay, + eps=eps, + has_sparse_grad=has_sparse_grad, + maximize=maximize, + differentiable=differentiable, + has_complex=has_complex, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +def _make_sparse(grad, grad_indices, values): + size = grad.size() + return torch.sparse_coo_tensor(grad_indices, values, size) + + +def _single_tensor_adagrad( + params: List[Tensor], + grads: List[Tensor], + state_sums: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + lr: float, + weight_decay: float, + lr_decay: float, + eps: float, + has_sparse_grad: bool, + maximize: bool, + differentiable: bool, + has_complex: bool, +): + assert grad_scale is None and found_inf is None + for param, grad, state_sum, step_t in zip(params, grads, state_sums, state_steps): + # update step + step_t += 1 + step = _get_value(step_t) + grad = grad if not maximize else -grad + + if weight_decay != 0: + if grad.is_sparse: + raise RuntimeError( + "weight_decay option is not compatible with sparse gradients" + ) + grad = grad.add(param, alpha=weight_decay) + + clr = lr / (1 + (step - 1) * lr_decay) + + if grad.is_sparse: + grad = grad.coalesce() # the update is non-linear so indices must be unique + grad_indices = grad._indices() + grad_values = grad._values() + + state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2))) + std = state_sum.sparse_mask(grad) + std_values = std._values().sqrt_().add_(eps) + param.add_( + _make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr + ) + else: + is_complex = torch.is_complex(param) + if is_complex: + grad = torch.view_as_real(grad) + state_sum = torch.view_as_real(state_sum) + param = torch.view_as_real(param) + state_sum.addcmul_(grad, grad, value=1) + if differentiable: + std = state_sum.sqrt() + eps + else: + std = state_sum.sqrt().add_(eps) + param.addcdiv_(grad, std, value=-clr) + if is_complex: + param = torch.view_as_complex(param) + state_sum = torch.view_as_complex(state_sum) + + +def _multi_tensor_adagrad( + params: List[Tensor], + grads: List[Tensor], + state_sums: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + lr: float, + weight_decay: float, + lr_decay: float, + eps: float, + has_sparse_grad: bool, + maximize: bool, + differentiable: bool, + has_complex: bool, +): + assert not differentiable, "_foreach ops don't support autograd" + assert grad_scale is None and found_inf is None + + # Foreach functions will throw errors if given empty lists + if len(params) == 0: + return + + grouped_tensorlists = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, state_sums, state_steps] # type: ignore[list-item] + ) + for ( + device_params_, + device_grads_, + device_state_sums_, + 
device_state_steps_, + ), _ in grouped_tensorlists.values(): + device_params = cast(List[Tensor], device_params_) + device_grads = cast(List[Tensor], device_grads_) + device_state_sums = cast(List[Tensor], device_state_sums_) + device_state_steps = cast(List[Tensor], device_state_steps_) + + device_has_sparse_grad = has_sparse_grad and any( + grad.is_sparse for grad in device_grads + ) + + if device_has_sparse_grad: + _single_tensor_adagrad( + device_params, + device_grads, + device_state_sums, + device_state_steps, + lr=lr, + weight_decay=weight_decay, + lr_decay=lr_decay, + eps=eps, + has_sparse_grad=True, + maximize=maximize, + differentiable=differentiable, + has_complex=has_complex, + grad_scale=grad_scale, + found_inf=found_inf, + ) + continue + + # Handle complex parameters + if has_complex: + _view_as_real(device_params, device_grads, device_state_sums) + + if maximize: + device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if not torch._utils.is_compiling() and device_state_steps[0].is_cpu: + torch._foreach_add_( + device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(device_state_steps, 1) + + if weight_decay != 0: + # Re-use the intermediate memory (device_grads) already allocated for maximize + if maximize: + torch._foreach_add_(device_grads, device_params, alpha=weight_decay) + else: + device_grads = torch._foreach_add( # type: ignore[assignment] + device_grads, device_params, alpha=weight_decay + ) + + minus_clr = [ + -lr / (1 + (_get_value(step) - 1) * lr_decay) for step in device_state_steps + ] + + torch._foreach_addcmul_(device_state_sums, device_grads, device_grads, value=1) + + std = torch._foreach_sqrt(device_state_sums) + torch._foreach_add_(std, eps) + + if weight_decay != 0 or maximize: + # Again, re-use the intermediate memory (device_grads) already allocated + torch._foreach_mul_(device_grads, minus_clr) + numerator = device_grads + else: + numerator = torch._foreach_mul(device_grads, minus_clr) # type: ignore[assignment] + + torch._foreach_addcdiv_(device_params, numerator, std) + + +def _fused_adagrad( + params: List[Tensor], + grads: List[Tensor], + state_sums: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + lr: float, + weight_decay: float, + lr_decay: float, + eps: float, + has_sparse_grad: bool, + maximize: bool, + differentiable: bool, + has_complex: bool, +) -> None: + if not params: + return + if has_sparse_grad or has_complex: + raise RuntimeError("`fused` does not support sparse grad or complex param") + + if differentiable: + raise RuntimeError( + "adagrad with fused=True does not support differentiable=True" + ) + + grad_scale_dict = ( + {grad_scale.device: grad_scale} if grad_scale is not None else None + ) + found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, state_sums, state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_state_sums_, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(List[Tensor], 
device_params_) + device_grads = cast(List[Tensor], device_grads_) + device_state_sums = cast(List[Tensor], device_state_sums_) + device_state_steps = cast(List[Tensor], device_state_steps_) + + device_grad_scale, device_found_inf = None, None + if grad_scale is not None and grad_scale_dict is not None: + if device not in grad_scale_dict: + grad_scale_dict[device] = grad_scale.to(device, non_blocking=True) # type: ignore[index] + device_grad_scale = grad_scale_dict[device] # type: ignore[index] + if found_inf is not None and found_inf_dict is not None: + if found_inf not in found_inf_dict: + found_inf_dict[device] = found_inf.to(device, non_blocking=True) # type: ignore[index] + device_found_inf = found_inf_dict[device] # type: ignore[index] + torch._foreach_add_(device_state_steps, 1) + torch._fused_adagrad_( + device_params, + device_grads, + device_state_sums, + device_state_steps, + lr=lr, + lr_decay=lr_decay, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + grad_scale=device_grad_scale, + found_inf=device_found_inf, + ) + if device_found_inf is not None: + torch._foreach_sub_( + device_state_steps, [device_found_inf] * len(device_state_steps) + ) diff --git a/lib/python3.10/site-packages/torch/optim/adam.py b/lib/python3.10/site-packages/torch/optim/adam.py new file mode 100644 index 0000000000000000000000000000000000000000..cf8c5809ea3c36eca80639b703d6794a9171498f --- /dev/null +++ b/lib/python3.10/site-packages/torch/optim/adam.py @@ -0,0 +1,803 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +from typing import cast, List, Optional, Tuple, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _device_dtype_check_for_fused, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _foreach_doc, + _fused_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _get_value, + _maximize_doc, + _stack_if_compiling, + _use_grad_for_differentiable, + _view_as_real, + DeviceDict, + Optimizer, + ParamsT, +) + + +__all__ = ["Adam", "adam"] + + +class Adam(Optimizer): + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0, + amsgrad: bool = False, + *, + foreach: Optional[bool] = None, + maximize: bool = False, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + ): + if isinstance(lr, Tensor): + if foreach and not capturable: + raise ValueError( + "lr as a Tensor is not supported for capturable=False and foreach=True" + ) + if lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + maximize=maximize, + foreach=foreach, + capturable=capturable, + differentiable=differentiable, + fused=fused, + ) + super().__init__(params, defaults) + + if fused: + if differentiable: + raise RuntimeError("`fused` does not support `differentiable`") + self._step_supports_amp_scaling = True + # 
TODO(crcrpar): [low prec params & their higher prec copy] + # Support AMP with FP16/BF16 model params which would need + # higher prec copy of params to do update math in higher prec to + # alleviate the loss of information. + if foreach: + raise RuntimeError("`fused` and `foreach` cannot be `True` together.") + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("amsgrad", False) + group.setdefault("maximize", False) + group.setdefault("foreach", None) + group.setdefault("capturable", False) + group.setdefault("differentiable", False) + fused = group.setdefault("fused", None) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, + dtype=_get_scalar_dtype(is_fused=fused), + device=p.device, + ) + if group["capturable"] or group["fused"] + else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + + def _init_group( + self, + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ): + has_complex = False + for p in group["params"]: + if p.grad is not None: + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + grads.append(p.grad) + + state = self.state[p] + # Lazy state initialization + if len(state) == 0: + if group["fused"]: + _device_dtype_check_for_fused(p) + # note(crcrpar): [special device hosting for step] + # Deliberately host `step` on CPU if both capturable and fused are off. + # This is because kernel launches are costly on CUDA and XLA. + state["step"] = ( + torch.zeros( + (), + dtype=_get_scalar_dtype(is_fused=group["fused"]), + device=p.device, + ) + if group["capturable"] or group["fused"] + else torch.tensor(0.0, dtype=_get_scalar_dtype()) + ) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + if group["amsgrad"]: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + if group["amsgrad"]: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + if group["differentiable"] and state["step"].requires_grad: + raise RuntimeError( + "`requires_grad` is not supported for `step` in differentiable mode" + ) + + # Foreach without capturable does not support a tensor lr + if ( + group["foreach"] + and torch.is_tensor(group["lr"]) + and not group["capturable"] + ): + raise RuntimeError( + "lr as a Tensor is not supported for capturable=False and foreach=True" + ) + + state_steps.append(state["step"]) + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_avg_sqs: List[Tensor] = [] + max_exp_avg_sqs: List[Tensor] = [] + state_steps: List[Tensor] = [] + beta1, beta2 = group["betas"] + + has_complex = self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group["amsgrad"], + has_complex=has_complex, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +Adam.__doc__ = ( + r"""Implements Adam algorithm. + + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2 + \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\ + &\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad}, + \:\textit{maximize} \\ + &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, + v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + + &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ + &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ + &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ + &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ + &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ + &\hspace{5mm}\textbf{if} \: amsgrad \\ + &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max}, + \widehat{v_t}) \\ + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_. + """ + + rf""" + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR + is not yet supported for all our implementations. Please use a float + LR if you are not also specifying fused=True or capturable=True. 
+ betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (bool, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + {_foreach_doc} + {_maximize_doc} + {_capturable_doc} + {_differentiable_doc} + {_fused_doc} + .. Note:: + A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`. + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + + """ +) + + +def _single_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + has_complex: bool, + beta1: float, + beta2: float, + lr: Union[float, Tensor], + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + assert grad_scale is None and found_inf is None + + if torch.jit.is_scripting(): + # this assert is due to JIT being dumb and not realizing that the ops below + # have overloads to handle both float and Tensor lrs, so we just assert it's + # a float since most people using JIT are using floats + assert isinstance(lr, float) + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type == step_t.device.type + and param.device.type in capturable_supported_devices + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + # update step + step_t += 1 + + if weight_decay != 0: + grad = grad.add(param, alpha=weight_decay) + + if torch.is_complex(param): + grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + exp_avg_sq = torch.view_as_real(exp_avg_sq) + if amsgrad: + max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i]) + param = torch.view_as_real(param) + + # Decay the first and second moment running average coefficient + exp_avg.lerp_(grad, 1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + + if capturable or differentiable: + step = step_t + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + step_size_neg = step_size.neg() + + bias_correction2_sqrt = bias_correction2.sqrt() + + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + if differentiable: + max_exp_avg_sq = max_exp_avg_sqs[i].clone() + else: + max_exp_avg_sq = max_exp_avg_sqs[i] + + max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq)) + + # Uses the max. for normalizing running avg. 
of gradient + # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write + # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor) + denom = ( + max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg) + ).add_(eps / step_size_neg) + else: + denom = ( + exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg) + ).add_(eps / step_size_neg) + + param.addcdiv_(exp_avg, denom) + else: + step = _get_value(step_t) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = bias_correction2**0.5 + + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) + + # Use the max. for normalizing running avg. of gradient + denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps) + else: + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + + # Lastly, switch back to complex view + if amsgrad and torch.is_complex(params[i]): + max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i]) + + +def _multi_tensor_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + has_complex: bool, + beta1: float, + beta2: float, + lr: Union[float, Tensor], + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, +): + if len(params) == 0: + return + + if isinstance(lr, Tensor) and not capturable: + raise RuntimeError( + "lr as a Tensor is not supported for capturable=False and foreach=True" + ) + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." 
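+
+    # A rough sketch of what the (device, dtype) grouping below yields when the
+    # params live on two different devices (illustrative names and devices):
+    #
+    #     {(cuda:0, float32): ([p0], [g0], [m0], [v0], [vmax0], [t0]),
+    #      (cpu,    float32): ([p1], [g1], [m1], [v1], [vmax1], [t1])}
+    #
+    # Each bucket is then updated with horizontally fused torch._foreach_* ops.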
+ + assert grad_scale is None and found_inf is None + + assert not differentiable, "_foreach ops don't support autograd" + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item] + ) + for ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs_, + device_state_steps_, + ), _ in grouped_tensors.values(): + device_params = cast(List[Tensor], device_params_) + device_grads = cast(List[Tensor], device_grads_) + device_exp_avgs = cast(List[Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_) + device_state_steps = cast(List[Tensor], device_state_steps_) + + # Handle complex parameters + if has_complex: + if amsgrad: + device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_) + _view_as_real( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, + ) + else: + _view_as_real( + device_params, device_grads, device_exp_avgs, device_exp_avg_sqs + ) + + if maximize: + device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if not torch._utils.is_compiling() and device_state_steps[0].is_cpu: + torch._foreach_add_( + device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(device_state_steps, 1) + + if weight_decay != 0: + # Re-use the intermediate memory (device_grads) already allocated for maximize + if maximize: + torch._foreach_add_(device_grads, device_params, alpha=weight_decay) + else: + device_grads = torch._foreach_add( # type: ignore[assignment] + device_grads, device_params, alpha=weight_decay + ) + + # Decay the first and second moment running average coefficient + torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1) + + torch._foreach_mul_(device_exp_avg_sqs, beta2) + torch._foreach_addcmul_( + device_exp_avg_sqs, device_grads, device_grads, 1 - beta2 + ) + + # Delete the local intermediate since it won't be used anymore to save on peak memory + del device_grads + + bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]] + bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]] + bias_correction2_sqrt: Union[Tuple[Tensor, ...], List[Tensor]] + + if capturable: + bias_correction1 = torch._foreach_pow(beta1, device_state_steps) + bias_correction2 = torch._foreach_pow(beta2, device_state_steps) + # foreach_sub doesn't allow a scalar as the first arg + torch._foreach_sub_(bias_correction1, 1) + torch._foreach_sub_(bias_correction2, 1) + # we do not negate bias_correction1 as it'll need to be negated later anyway + torch._foreach_neg_(bias_correction2) + + # foreach_div doesn't allow a scalar as the first arg + torch._foreach_div_(bias_correction1, lr) + torch._foreach_reciprocal_(bias_correction1) + + torch._foreach_sqrt_(bias_correction2) + + # Re-assign for clarity as we maintain minimal intermediates: we'll have + # step_size = - lr / (1 - beta1 ^ t) where t = num_steps + # bias_correction2_sqrt = sqrt(1 - beta2 ^ t) + step_size = bias_correction1 + bias_correction2_sqrt = bias_correction2 + + if amsgrad: + device_max_exp_avg_sqs = cast(List[Tensor], 
device_max_exp_avg_sqs_) + # Maintains the maximum of all 2nd moment running avg. till now + torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) # type: ignore[assignment] + + # Set intermediate to the max. for normalizing running avg. of gradient when amsgrad + exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs) + else: + exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) + + torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) + torch._foreach_add_(exp_avg_sq_sqrt, eps) + torch._foreach_div_(exp_avg_sq_sqrt, step_size) + + # at this point, exp_avg_sq_sqrt = - (1 - beta^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr + torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt) + else: + bias_correction1 = [ + 1 - beta1 ** _get_value(step) for step in device_state_steps + ] + bias_correction2 = [ + 1 - beta2 ** _get_value(step) for step in device_state_steps + ] + + step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1]) + + bias_correction2_sqrt = [bc**0.5 for bc in bias_correction2] # type: ignore[arg-type] + + if amsgrad: + device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_) + # Maintains the maximum of all 2nd moment running avg. till now + torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) + + # Use the max. for normalizing running avg. of gradient + exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs) + else: + exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) + + torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) + torch._foreach_add_(exp_avg_sq_sqrt, eps) + torch._foreach_addcdiv_( + device_params, device_exp_avgs, exp_avg_sq_sqrt, step_size # type: ignore[arg-type] + ) + + +def _fused_adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + has_complex: bool, # Needed for consistency. + beta1: float, + beta2: float, + lr: Union[float, Tensor], + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, # Needed for consistency. + differentiable: bool, +) -> None: + if not params: + return + if differentiable: + raise RuntimeError("Adam with fused=True does not support differentiable=True") + + grad_scale_dict: DeviceDict = ( + {grad_scale.device: grad_scale} if grad_scale is not None else {} + ) + found_inf_dict: DeviceDict = ( + {found_inf.device: found_inf} if found_inf is not None else {} + ) + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. 
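+    # e.g. a tensor lr created on cuda:0 is copied to each other device at most
+    # once per call and cached in lr_dict; a plain float lr (or a CPU tensor lr)
+    # bypasses the dict entirely.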
+ lr_dict: Optional[DeviceDict] = ( + {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None + ) + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(List[Tensor], device_params_) + device_grads = cast(List[Tensor], device_grads_) + device_exp_avgs = cast(List[Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_) + device_state_steps = cast(List[Tensor], device_state_steps_) + + if device.type == "mps": # type: ignore[union-attr] + assert found_inf is None and grad_scale is None + + device_grad_scale, device_found_inf = None, None + if grad_scale is not None: + device_grad_scale = grad_scale_dict.setdefault( + device, grad_scale.to(device, non_blocking=True) + ) + if found_inf is not None: + device_found_inf = found_inf_dict.setdefault( + device, found_inf.to(device, non_blocking=True) + ) + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to(device=device, non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + torch._fused_adam_( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + grad_scale=device_grad_scale, + found_inf=device_found_inf, + ) + if device_found_inf is not None: + torch._foreach_sub_( + device_state_steps, [device_found_inf] * len(device_state_steps) + ) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adam) +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + has_complex: bool = False, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: Union[float, Tensor], + weight_decay: float, + eps: float, + maximize: bool, +): + r"""Functional API that performs Adam algorithm computation. + + See :class:`~torch.optim.Adam` for details. + """ + # Respect when the user inputs False/True for foreach or fused. We only want to change + # the default when neither have been user-specified. Note that we default to foreach + # and pass False to use_fused. This is not a mistake--we want to give the fused impl + # bake-in time before making it the default, even if it is typically faster. + if fused is None and foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False. 
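+    # The final dispatch further down is: _fused_adam if fused, else
+    # _multi_tensor_adam if foreach, else the plain _single_tensor_adam loop
+    # (fused/foreach are also rejected under torch.jit.script).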
+ if foreach and isinstance(lr, Tensor) and not capturable: + foreach = False + if fused is None: + fused = False + if foreach is None: + foreach = False + + # this check is slow during compilation, so we skip it + # if it's strictly needed we can add this check back in dynamo + if not torch._utils.is_compiling() and not all( + isinstance(t, torch.Tensor) for t in state_steps + ): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + if fused and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with fused optimizers") + + if fused and not torch.jit.is_scripting(): + func = _fused_adam + elif foreach and not torch.jit.is_scripting(): + func = _multi_tensor_adam + else: + func = _single_tensor_adam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + has_complex=has_complex, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + ) diff --git a/lib/python3.10/site-packages/torch/optim/adamax.py b/lib/python3.10/site-packages/torch/optim/adamax.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c80a2ae3dca957eacc010f80380554735de873 --- /dev/null +++ b/lib/python3.10/site-packages/torch/optim/adamax.py @@ -0,0 +1,473 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +from typing import cast, List, Optional, Tuple, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _foreach_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _get_value, + _maximize_doc, + _use_grad_for_differentiable, + _view_as_real, + Optimizer, + ParamsT, +) + + +__all__ = ["Adamax", "adamax"] + + +class Adamax(Optimizer): + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 2e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0, + foreach: Optional[bool] = None, + *, + maximize: bool = False, + differentiable: bool = False, + capturable: bool = False, + ): + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + capturable=capturable, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("foreach", None) + group.setdefault("maximize", False) + group.setdefault("differentiable", False) + group.setdefault("capturable", False) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not 
torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, dtype=_get_scalar_dtype(), device=p.device + ) + if group["capturable"] + else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + + def _init_group( + self, group, params_with_grad, grads, exp_avgs, exp_infs, state_steps + ): + has_complex = False + for p in group["params"]: + if p.grad is None: + continue + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError("Adamax does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = ( + torch.zeros((), dtype=_get_scalar_dtype(), device=p.device) + if group["capturable"] + else torch.tensor(0.0, dtype=_get_scalar_dtype()) + ) + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["exp_inf"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state["exp_avg"]) + exp_infs.append(state["exp_inf"]) + state_steps.append(state["step"]) + + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_infs: List[Tensor] = [] + state_steps: List[Tensor] = [] + + beta1, beta2 = group["betas"] + eps = group["eps"] + lr = group["lr"] + weight_decay = group["weight_decay"] + foreach = group["foreach"] + maximize = group["maximize"] + differentiable = group["differentiable"] + capturable = group["capturable"] + + has_complex = self._init_group( + group, params_with_grad, grads, exp_avgs, exp_infs, state_steps + ) + + adamax( + params_with_grad, + grads, + exp_avgs, + exp_infs, + state_steps, + eps=eps, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + capturable=capturable, + has_complex=has_complex, + ) + + return loss + + +Adamax.__doc__ = ( + r"""Implements Adamax algorithm (a variant of Adam based on infinity norm). + + .. 
math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2 + \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)}, + \: \lambda \text{ (weight decay)}, \\ + &\hspace{13mm} \epsilon \text{ (epsilon)} \\ + &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, + u_0 \leftarrow 0 \text{ ( infinity norm)} \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}if \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ + &\hspace{5mm}u_t \leftarrow \mathrm{max}(\beta_2 u_{t-1}, |g_{t}|+\epsilon) \\ + &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \frac{\gamma m_t}{(1-\beta^t_1) u_t} \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_. + """ + + rf""" + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, Tensor, optional): learning rate (default: 2e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + {_foreach_doc} + {_maximize_doc} + {_differentiable_doc} + {_capturable_doc} + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + + """ +) + + +def _single_tensor_adamax( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_infs: List[Tensor], + state_steps: List[Tensor], + *, + eps: float, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + for i, param in enumerate(params): + grad = grads[i] + grad = grad if not maximize else -grad + exp_avg = exp_avgs[i] + exp_inf = exp_infs[i] + step_t = state_steps[i] + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type == step_t.device.type + and param.device.type in capturable_supported_devices + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + # update step + step_t += 1 + + if weight_decay != 0: + grad = grad.add(param, alpha=weight_decay) + + if torch.is_complex(param): + param = torch.view_as_real(param) + grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + exp_inf = torch.view_as_real(exp_inf) + + # Update biased first moment estimate. + exp_avg.lerp_(grad, 1 - beta1) + # Update the exponentially weighted infinity norm. + if not differentiable: + torch.maximum( + exp_inf.mul_(beta2), + grad.abs().add_(eps), + out=exp_inf, + ) + else: + norm_buf = torch.cat( + [exp_inf.mul_(beta2).unsqueeze(0), grad.abs().add_(eps).unsqueeze_(0)], + 0, + ) + exp_inf.copy_(torch.amax(norm_buf, 0, keepdim=False)) + + if capturable: + # why jump through extra hoops and negate bias_correction? 
check out #121238 + # once fixed, we should use bias_correction with addcdiv value=-1 for readability + neg_bias_correction = beta1**step_t - 1 + neg_bias_correction.div_(lr) + denom = exp_inf * neg_bias_correction + param.addcdiv_(exp_avg, denom) + else: + bias_correction = 1 - beta1 ** _get_value(step_t) + clr = lr / bias_correction + + param.addcdiv_(exp_avg, exp_inf, value=-clr) + + +def _multi_tensor_adamax( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_infs: List[Tensor], + state_steps: List[Tensor], + *, + eps: float, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + assert not differentiable, "_foreach ops don't support autograd" + + if len(params) == 0: + return + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_infs, state_steps] # type: ignore[list-item] + ) + for ( + grouped_params_, + grouped_grads_, + grouped_exp_avgs_, + grouped_exp_infs_, + grouped_state_steps_, + ), _ in grouped_tensors.values(): + grouped_params = cast(List[Tensor], grouped_params_) + grouped_grads = cast(List[Tensor], grouped_grads_) + grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_) + grouped_exp_infs = cast(List[Tensor], grouped_exp_infs_) + grouped_state_steps = cast(List[Tensor], grouped_state_steps_) + + if has_complex: + _view_as_real( + grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs + ) + + if maximize: + grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu: + torch._foreach_add_( + grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(grouped_state_steps, 1) + + if weight_decay != 0: + if maximize: + # Re-use the intermediate memory (grouped_grads) already allocated for maximize + torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay) + else: + grouped_grads = torch._foreach_add( # type: ignore[assignment] + grouped_grads, grouped_params, alpha=weight_decay + ) + + # Update biased first moment estimate. + torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1) + + # Update the exponentially weighted infinity norm. 
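+        # i.e. u_t = max(beta2 * u_{t-1}, |g_t| + eps): Adamax replaces Adam's
+        # second-moment estimate with this running infinity norm, so no bias
+        # correction of the denominator is needed.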
+ torch._foreach_mul_(grouped_exp_infs, beta2) + + # in this case, we need to introduce a copy of the grads + # since one has not been introduced previously + if not maximize and weight_decay == 0: + grouped_grads = torch._foreach_abs(grouped_grads) # type: ignore[assignment] + else: + torch._foreach_abs_(grouped_grads) + + torch._foreach_add_(grouped_grads, eps) + torch._foreach_maximum_(grouped_exp_infs, grouped_grads) + + bias_corrections: Union[Tuple[Tensor, ...], List[Tensor]] + if capturable: + bias_corrections = torch._foreach_pow(beta1, grouped_state_steps) + # foreach_sub doesn't allow a scalar as the first arg + torch._foreach_sub_(bias_corrections, 1) + torch._foreach_div_(bias_corrections, lr) + + denom = torch._foreach_mul(grouped_exp_infs, bias_corrections) + torch._foreach_addcdiv_(grouped_params, grouped_exp_avgs, denom) + else: + bias_corrections = [ + 1 - beta1 ** _get_value(step) for step in grouped_state_steps + ] + step_size = [(_get_value(lr) / bc) * -1 for bc in bias_corrections] + torch._foreach_addcdiv_( + grouped_params, grouped_exp_avgs, grouped_exp_infs, step_size + ) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamax) +def adamax( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_infs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + maximize: bool = False, + differentiable: bool = False, + capturable: bool = False, + has_complex: bool = False, + *, + eps: float, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, +): + r"""Functional API that performs adamax algorithm computation. + + See :class:`~torch.optim.Adamax` for details. 
+ """ + + if not torch._utils.is_compiling() and not all( + isinstance(t, torch.Tensor) for t in state_steps + ): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + if foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_adamax + else: + func = _single_tensor_adamax + + func( + params, + grads, + exp_avgs, + exp_infs, + state_steps, + eps=eps, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + maximize=maximize, + differentiable=differentiable, + has_complex=has_complex, + capturable=capturable, + ) diff --git a/lib/python3.10/site-packages/torch/optim/adamw.py b/lib/python3.10/site-packages/torch/optim/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..0c49f528e8f136efa12d6b50709276ff5288be2f --- /dev/null +++ b/lib/python3.10/site-packages/torch/optim/adamw.py @@ -0,0 +1,801 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +from typing import cast, List, Optional, Tuple, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _device_dtype_check_for_fused, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _foreach_doc, + _fused_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _get_value, + _maximize_doc, + _stack_if_compiling, + _use_grad_for_differentiable, + _view_as_real, + DeviceDict, + Optimizer, + ParamsT, +) + + +__all__ = ["AdamW", "adamw"] + + +class AdamW(Optimizer): + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + amsgrad: bool = False, + *, + maximize: bool = False, + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + ): + if isinstance(lr, Tensor): + if foreach and not capturable: + raise ValueError( + "lr as a Tensor is not supported for capturable=False and foreach=True" + ) + if lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + foreach=foreach, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + fused=fused, + ) + super().__init__(params, defaults) + + if fused: + if differentiable: + raise RuntimeError("`fused` does not support `differentiable`") + self._step_supports_amp_scaling = True + if foreach: + raise RuntimeError("`fused` and `foreach` cannot be `True` together.") + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("amsgrad", False) + group.setdefault("maximize", False) + group.setdefault("foreach", None) + group.setdefault("capturable", False) + 
group.setdefault("differentiable", False) + fused = group.setdefault("fused", None) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, + dtype=_get_scalar_dtype(is_fused=fused), + device=p.device, + ) + if group["capturable"] or group["fused"] + else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + + def _init_group( + self, + group, + params_with_grad, + grads, + amsgrad, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ): + has_complex = False + for p in group["params"]: + if p.grad is None: + continue + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError("AdamW does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + if group["fused"]: + _device_dtype_check_for_fused(p) + # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off. + # This is because kernel launches are costly on CUDA and XLA. + state["step"] = ( + torch.zeros( + (), + dtype=_get_scalar_dtype(is_fused=group["fused"]), + device=p.device, + ) + if group["capturable"] or group["fused"] + else torch.tensor(0.0, dtype=_get_scalar_dtype()) + ) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + if group["amsgrad"]: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + if group["differentiable"] and state["step"].requires_grad: + raise RuntimeError( + "`requires_grad` is not supported for `step` in differentiable mode" + ) + + # Foreach without capturable does not support a tensor lr + if ( + group["foreach"] + and isinstance(group["lr"], Tensor) + and not group["capturable"] + ): + raise RuntimeError( + "lr as a Tensor is not supported for capturable=False and foreach=True" + ) + + state_steps.append(state["step"]) + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_avg_sqs: List[Tensor] = [] + max_exp_avg_sqs: List[Tensor] = [] + state_steps: List[Tensor] = [] + amsgrad: bool = group["amsgrad"] + beta1, beta2 = cast(Tuple[float, float], group["betas"]) + + has_complex = self._init_group( + group, + params_with_grad, + grads, + amsgrad, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + has_complex=has_complex, + ) + + return loss + + +AdamW.__doc__ = ( + r"""Implements AdamW algorithm. + + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2 + \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)}, + \: \epsilon \text{ (epsilon)} \\ + &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad}, + \: \textit{maximize} \\ + &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0 + \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + + &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ + &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\ + &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ + &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ + &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ + &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ + &\hspace{5mm}\textbf{if} \: amsgrad \\ + &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max}, + \widehat{v_t}) \\ + &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_. + """ + + rf""" + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR + is not yet supported for all our implementations. Please use a float + LR if you are not also specifying fused=True or capturable=True. 
+ betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + amsgrad (bool, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + {_maximize_doc} + {_foreach_doc} + {_capturable_doc} + {_differentiable_doc} + {_fused_doc} + .. Note:: + A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`. + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + + """ +) + + +def _single_tensor_adamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: Union[Tensor, float], + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, + has_complex: bool, +): + assert grad_scale is None and found_inf is None + + if torch.jit.is_scripting(): + # this assert is due to JIT being dumb and not realizing that the ops below + # have overloads to handle both float and Tensor lrs, so we just assert it's + # a float since most people using JIT are using floats + assert isinstance(lr, float) + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type == step_t.device.type + and param.device.type in capturable_supported_devices + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + if torch.is_complex(param): + grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + exp_avg_sq = torch.view_as_real(exp_avg_sq) + if amsgrad: + max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i]) + param = torch.view_as_real(param) + + # update step + step_t += 1 + + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.lerp_(grad, 1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + if capturable or differentiable: + step = step_t + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + step_size_neg = step_size.neg() + + bias_correction2_sqrt = bias_correction2.sqrt() + + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + if differentiable: + max_exp_avg_sq = max_exp_avg_sqs[i].clone() + else: + max_exp_avg_sq = max_exp_avg_sqs[i] + + max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq)) + + # Uses the max. for normalizing running avg. 
of gradient + # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write + # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor) + denom = ( + max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg) + ).add_(eps / step_size_neg) + else: + denom = ( + exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg) + ).add_(eps / step_size_neg) + + param.addcdiv_(exp_avg, denom) + else: + step = _get_value(step_t) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = bias_correction2**0.5 + + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) + + # Use the max. for normalizing running avg. of gradient + denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps) + else: + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + + # Lastly, switch back to complex view + if amsgrad and torch.is_complex(params[i]): + max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i]) + + +def _multi_tensor_adamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: Union[Tensor, float], + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, + differentiable: bool, + has_complex: bool, +): + if len(params) == 0: + return + + if isinstance(lr, Tensor) and not capturable: + raise RuntimeError( + "lr as a Tensor is not supported for capturable=False and foreach=True" + ) + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." 
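+
+    # Apart from argument plumbing, the main difference from the Adam foreach
+    # path is the decoupled weight decay applied below, roughly
+    #     param *= 1 - lr * weight_decay
+    # instead of folding weight_decay into the gradient as an L2 penalty.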
+ + assert not differentiable, "_foreach ops don't support autograd" + + assert grad_scale is None and found_inf is None + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item] + ) + for ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs_, + device_state_steps_, + ), _ in grouped_tensors.values(): + device_params = cast(List[Tensor], device_params_) + device_grads = cast(List[Tensor], device_grads_) + device_exp_avgs = cast(List[Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_) + device_state_steps = cast(List[Tensor], device_state_steps_) + + if has_complex: + if amsgrad: + device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_) + _view_as_real( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, + ) + else: + _view_as_real( + device_params, device_grads, device_exp_avgs, device_exp_avg_sqs + ) + + if maximize: + device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if not torch._utils.is_compiling() and device_state_steps[0].is_cpu: + torch._foreach_add_( + device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(device_state_steps, 1) + + # Perform stepweight decay + if weight_decay != 0: + torch._foreach_mul_(device_params, 1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient + torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1) + + torch._foreach_mul_(device_exp_avg_sqs, beta2) + torch._foreach_addcmul_( + device_exp_avg_sqs, device_grads, device_grads, 1 - beta2 + ) + + # Delete the local intermediate since it won't be used anymore to save on peak memory + del device_grads + + bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]] + bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]] + bias_correction2_sqrt: Union[Tuple[Tensor, ...], List[Tensor]] + + if capturable: + bias_correction1 = torch._foreach_pow(beta1, device_state_steps) + bias_correction2 = torch._foreach_pow(beta2, device_state_steps) + # foreach_sub doesn't allow a scalar as the first arg + torch._foreach_sub_(bias_correction1, 1) + torch._foreach_sub_(bias_correction2, 1) + # we do not negate bias_correction1 as it'll need to be negated later anyway + torch._foreach_neg_(bias_correction2) + + # foreach_div doesn't allow a scalar as the first arg + torch._foreach_div_(bias_correction1, lr) + torch._foreach_reciprocal_(bias_correction1) + + torch._foreach_sqrt_(bias_correction2) + + # Re-assign for clarity as we maintain minimal intermediates: we'll have + # step_size = - lr / (1 - beta1 ^ t) where t = num_steps + # bias_correction2_sqrt = sqrt(1 - beta2 ^ t) + step_size = bias_correction1 + bias_correction2_sqrt = bias_correction2 + + if amsgrad: + device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_) + + # Maintains the maximum of all 2nd moment running avg. till now + torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) + + # Use the max. for normalizing running avg. 
of gradient + exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs) + else: + exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) + + torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) + torch._foreach_add_(exp_avg_sq_sqrt, eps) + torch._foreach_div_(exp_avg_sq_sqrt, step_size) + + # at this point, exp_avg_sq_sqrt = - (1 - beta^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr + torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt) + else: + bias_correction1 = [ + 1 - beta1 ** _get_value(step) for step in device_state_steps + ] + bias_correction2 = [ + 1 - beta2 ** _get_value(step) for step in device_state_steps + ] + + step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1]) + + bias_correction2_sqrt = [ + bc**0.5 for bc in bias_correction2 # type: ignore[arg-type] + ] + + if amsgrad: + device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_) + + # Maintains the maximum of all 2nd moment running avg. till now + torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) + + # Use the max. for normalizing running avg. of gradient + exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs) + else: + exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) + + torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) + torch._foreach_add_(exp_avg_sq_sqrt, eps) + torch._foreach_addcdiv_( + device_params, + device_exp_avgs, + exp_avg_sq_sqrt, + step_size, # type: ignore[arg-type] + ) + + +def _fused_adamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: Union[Tensor, float], + weight_decay: float, + eps: float, + maximize: bool, + capturable: bool, # Needed for consistency. + differentiable: bool, + has_complex: bool, # Needed for consistency. +) -> None: + if not params: + return + if differentiable: + raise RuntimeError("Adam with fused=True does not support differentiable=True") + + grad_scale_dict: DeviceDict = ( + {grad_scale.device: grad_scale} if grad_scale is not None else {} + ) + found_inf_dict: DeviceDict = ( + {found_inf.device: found_inf} if found_inf is not None else {} + ) + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. 
+ lr_dict: Optional[DeviceDict] = ( + {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None + ) + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(List[Tensor], device_params_) + device_grads = cast(List[Tensor], device_grads_) + device_exp_avgs = cast(List[Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_) + device_state_steps = cast(List[Tensor], device_state_steps_) + + if device.type == "mps": # type: ignore[union-attr] + assert found_inf is None and grad_scale is None + + device_grad_scale, device_found_inf = None, None + if grad_scale is not None: + device_grad_scale = grad_scale_dict.setdefault( + device, grad_scale.to(device, non_blocking=True) + ) + if found_inf is not None: + device_found_inf = found_inf_dict.setdefault( + device, found_inf.to(device, non_blocking=True) + ) + if lr_dict is not None and device not in lr_dict: + lr = lr_dict.setdefault( + device, lr.to(device=device, non_blocking=True) # type: ignore[union-attr] + ) + torch._foreach_add_(device_state_steps, 1) + torch._fused_adamw_( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + grad_scale=device_grad_scale, + found_inf=device_found_inf, + ) + if device_found_inf is not None: + torch._foreach_sub_( + device_state_steps, [device_found_inf] * len(device_state_steps) + ) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamw) +def adamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + has_complex: bool = False, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: Union[float, Tensor], + weight_decay: float, + eps: float, + maximize: bool, +): + r"""Functional API that performs AdamW algorithm computation. + + See :class:`~torch.optim.AdamW` for details. + """ + if not torch._utils.is_compiling() and not all( + isinstance(t, torch.Tensor) for t in state_steps + ): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + # Respect when the user inputs False/True for foreach or fused. We only want to change + # the default when neither have been user-specified. Note that we default to foreach + # and pass False to use_fused. This is not a mistake--we want to give the fused impl + # bake-in time before making it the default, even if it is typically faster. 
+ if fused is None and foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False. + if foreach and isinstance(lr, Tensor) and not capturable: + foreach = False + if fused is None: + fused = False + if foreach is None: + foreach = False + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + if fused and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with fused optimizers") + + if fused and not torch.jit.is_scripting(): + func = _fused_adamw + elif foreach and not torch.jit.is_scripting(): + func = _multi_tensor_adamw + else: + func = _single_tensor_adamw + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + grad_scale=grad_scale, + found_inf=found_inf, + has_complex=has_complex, + ) diff --git a/lib/python3.10/site-packages/torch/optim/asgd.py b/lib/python3.10/site-packages/torch/optim/asgd.py new file mode 100644 index 0000000000000000000000000000000000000000..79de96aa86cd2f5b28ad2ed36d57578ec520c1e0 --- /dev/null +++ b/lib/python3.10/site-packages/torch/optim/asgd.py @@ -0,0 +1,465 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +from typing import cast, List, Optional, Tuple, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _foreach_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _get_value, + _maximize_doc, + _use_grad_for_differentiable, + _view_as_real, + Optimizer, + ParamsT, +) + + +__all__ = ["ASGD", "asgd"] + + +class ASGD(Optimizer): + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1e-2, + lambd: float = 1e-4, + alpha: float = 0.75, + t0: float = 1e6, + weight_decay: float = 0, + foreach: Optional[bool] = None, + maximize: bool = False, + differentiable: bool = False, + capturable: bool = False, + ): + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + lambd=lambd, + alpha=alpha, + t0=t0, + weight_decay=weight_decay, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + capturable=capturable, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("foreach", None) + group.setdefault("maximize", False) + group.setdefault("differentiable", False) + group.setdefault("capturable", False) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0: + if not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = torch.tensor( + step_val, dtype=_get_scalar_dtype(), device=p.device + ) + if not torch.is_tensor(p_state["eta"]): + p_state["eta"] = torch.tensor( + p_state["eta"], dtype=_get_scalar_dtype(), device=p.device + ) + if not torch.is_tensor(p_state["mu"]): + p_state["mu"] = torch.tensor( + p_state["mu"], 
dtype=_get_scalar_dtype(), device=p.device + ) + + def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_steps): + has_complex = False + for p in group["params"]: + if p.grad is not None: + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError("ASGD does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + # State initialization + if len(state) == 0: + state["step"] = torch.zeros( + (), device=p.device, dtype=_get_scalar_dtype() + ) + state["eta"] = ( + torch.as_tensor( + group["lr"], device=p.device, dtype=_get_scalar_dtype() + ) + .clone() + .detach() + ) + state["mu"] = torch.ones( + (), device=p.device, dtype=_get_scalar_dtype() + ) + state["ax"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + mus.append(state["mu"]) + axs.append(state["ax"]) + etas.append(state["eta"]) + state_steps.append(state["step"]) + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + mus: List[Tensor] = [] + axs: List[Tensor] = [] + etas: List[Tensor] = [] + state_steps: List[Tensor] = [] + + has_complex = self._init_group( + group, params_with_grad, grads, mus, axs, etas, state_steps + ) + + asgd( + params_with_grad, + grads, + axs, + mus, + etas, + state_steps, + lambd=group["lambd"], + lr=group["lr"], + t0=group["t0"], + alpha=group["alpha"], + weight_decay=group["weight_decay"], + foreach=group["foreach"], + maximize=group["maximize"], + differentiable=group["differentiable"], + capturable=group["capturable"], + has_complex=has_complex, + ) + + return loss + + +ASGD.__doc__ = rf"""Implements Averaged Stochastic Gradient Descent. + + It has been proposed in `Acceleration of stochastic approximation by + averaging`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, Tensor, optional): learning rate (default: 1e-2) + lambd (float, optional): decay term (default: 1e-4) + alpha (float, optional): power for eta update (default: 0.75) + t0 (float, optional): point at which to start averaging (default: 1e6) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + {_foreach_doc} + {_maximize_doc} + {_differentiable_doc} + {_capturable_doc} + + .. 
_Acceleration of stochastic approximation by averaging: + https://dl.acm.org/citation.cfm?id=131098 + + """ + + +def _single_tensor_asgd( + params: List[Tensor], + grads: List[Tensor], + axs: List[Tensor], + mus: List[Tensor], + etas: List[Tensor], + state_steps: List[Tensor], + *, + lambd: float, + lr: float, + t0: float, + alpha: float, + weight_decay: float, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + for i, param in enumerate(params): + grad = grads[i] + grad = grad if not maximize else -grad + mu = mus[i] + ax = axs[i] + eta = etas[i] + step_t = state_steps[i] + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type + == mu.device.type + == eta.device.type + == step_t.device.type + and param.device.type in capturable_supported_devices + ), ( + f"If capturable=True, params, mus, etas, and state_steps must be " + f"on supported devices: {capturable_supported_devices}." + ) + + if torch.is_complex(param): + grad = torch.view_as_real(grad) + param = torch.view_as_real(param) + ax = torch.view_as_real(ax) + + # update step + step_t += 1 + + if weight_decay != 0: + grad = grad.add(param, alpha=weight_decay) + + if capturable: + param.mul_(1 - lambd * eta) + param.addcmul_(grad, eta, value=-1) # update parameter + else: + eta_value = _get_value(eta) + param.mul_(1 - lambd * eta_value) # decay term + param.add_(grad, alpha=-eta_value) # update parameter + + # averaging + if capturable or mu.item() != 1: + ax.add_(param.sub(ax).mul_(mu)) + else: + ax.copy_(param) + + if capturable: + eta.copy_(lr / ((1 + lambd * lr * step_t) ** alpha)) + mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t))) + else: + step = _get_value(step_t) + new_eta = torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha)) + eta.copy_(new_eta) + new_mu = torch.as_tensor(1 / max(1, step - t0)) + mu.copy_(new_mu) + + +def _multi_tensor_asgd( + params: List[Tensor], + grads: List[Tensor], + axs: List[Tensor], + mus: List[Tensor], + etas: List[Tensor], + state_steps: List[Tensor], + *, + lambd: float, + lr: float, + t0: float, + alpha: float, + weight_decay: float, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + if len(params) == 0: + return + + assert not differentiable, "_foreach ops don't support autograd" + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == mu.device.type == eta.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, mu, eta, step in zip(params, mus, etas, state_steps) + ), f"If capturable=True, params, mus, etas, and state_steps must be on supported devices: {capturable_supported_devices}." 
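+
+    # The rest of this function buckets params and optimizer state by
+    # (device, dtype) so each bucket can be updated with one chain of
+    # _foreach_* calls; those kernels expect all of their operands to live
+    # on the same device.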
+ + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, axs, mus, etas, state_steps] # type: ignore[list-item] + ) + for (device, _), ( + ( + grouped_params_, + grouped_grads_, + grouped_axs_, + grouped_mus_, + grouped_etas_, + grouped_state_steps_, + ), + _, + ) in grouped_tensors.items(): + grouped_params = cast(List[Tensor], grouped_params_) + grouped_grads = cast(List[Tensor], grouped_grads_) + grouped_axs = cast(List[Tensor], grouped_axs_) + grouped_mus = cast(List[Tensor], grouped_mus_) + grouped_etas = cast(List[Tensor], grouped_etas_) + grouped_state_steps = cast(List[Tensor], grouped_state_steps_) + + if has_complex: + _view_as_real(grouped_params, grouped_grads, grouped_axs) + + if maximize: + grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu: + torch._foreach_add_( + grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(grouped_state_steps, 1) + + # intermediate = grad + param * lambd + intermediate: Union[Tuple[Tensor, ...], List[Tensor]] + if weight_decay != 0: + if maximize: + torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay) + intermediate = grouped_grads + else: + intermediate = torch._foreach_add( + grouped_grads, grouped_params, alpha=weight_decay + ) + + torch._foreach_add_(intermediate, grouped_params, alpha=lambd) + else: + intermediate = torch._foreach_add( + grouped_grads, grouped_params, alpha=lambd + ) + + # update param + # param * (1 - lambd * eta) - eta * grad + # => param - param * lambd * eta - eta * grad + # => param - eta * intermediate + torch._foreach_addcmul_(grouped_params, intermediate, grouped_etas, value=-1) + del intermediate + + # update grouped_axs + # averaging: ax = ax + mu * (param - ax) + # Note (mlazos): We can't use lerp here since it requires weight to be float64 + # and our grouping code requires dtypes to match for all tensors in a group (and it should, since + # we use the mus in other places) + # all dtypes need to match, so we could introduce a cast in a loop + # but since this only adds one additional kernel launch, this looks like the cleaner + # and faster solution + intermediate = torch._foreach_sub(grouped_params, grouped_axs) + torch._foreach_addcmul_(grouped_axs, intermediate, grouped_mus) + del intermediate + + new_etas: Union[Tuple[Tensor, ...], List[Tensor]] + new_mus: Union[Tuple[Tensor, ...], List[Tensor]] + if capturable: + # update grouped_mus + new_mus = torch._foreach_sub(grouped_state_steps, t0) + torch._foreach_maximum_(new_mus, 1.0) + torch._foreach_reciprocal_(new_mus) + torch._foreach_copy_(grouped_mus, new_mus) + del new_mus + + # update eta = lr / ((1 + lambd * lr * step)^alpha) + new_etas = torch._foreach_mul(grouped_state_steps, lambd) + torch._foreach_mul_(new_etas, lr) + torch._foreach_add_(new_etas, 1) + torch._foreach_pow_(new_etas, alpha) + torch._foreach_reciprocal_(new_etas) + torch._foreach_mul_(new_etas, lr) + torch._foreach_copy_(grouped_etas, new_etas) + else: + new_etas = [ + torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha), device=device) + for step in grouped_state_steps + ] + new_mus = [ + 
torch.as_tensor(1 / max(1, _get_value(step) - t0), device=device) + for step in grouped_state_steps + ] + torch._foreach_copy_(grouped_etas, new_etas) + torch._foreach_copy_(grouped_mus, new_mus) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_asgd) +def asgd( + params: List[Tensor], + grads: List[Tensor], + axs: List[Tensor], + mus: List[Tensor], + etas: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + maximize: bool = False, + differentiable: bool = False, + capturable: bool = False, + has_complex: bool = False, + *, + lambd: float, + lr: float, + t0: float, + alpha: float, + weight_decay: float, +): + r"""Functional API that performs asgd algorithm computation. + + See :class:`~torch.optim.ASGD` for details. + """ + if foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_asgd + else: + func = _single_tensor_asgd + + func( + params, + grads, + axs, + mus, + etas, + state_steps, + lambd=lambd, + lr=lr, + t0=t0, + alpha=alpha, + weight_decay=weight_decay, + maximize=maximize, + differentiable=differentiable, + capturable=capturable, + has_complex=has_complex, + ) diff --git a/lib/python3.10/site-packages/torch/optim/lbfgs.py b/lib/python3.10/site-packages/torch/optim/lbfgs.py new file mode 100644 index 0000000000000000000000000000000000000000..f9c2e13077e3b8409f0048d60971619d2b4da588 --- /dev/null +++ b/lib/python3.10/site-packages/torch/optim/lbfgs.py @@ -0,0 +1,495 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import torch +from torch import Tensor + +from .optimizer import Optimizer, ParamsT + + +__all__ = ["LBFGS"] + + +def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None): + # ported from https://github.com/torch/optim/blob/master/polyinterp.lua + # Compute bounds of interpolation area + if bounds is not None: + xmin_bound, xmax_bound = bounds + else: + xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1) + + # Code for most common case: cubic interpolation of 2 points + # w/ function and derivative values for both + # Solution in this case (where x2 is the farthest point): + # d1 = g1 + g2 - 3*(f1-f2)/(x1-x2); + # d2 = sqrt(d1^2 - g1*g2); + # min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2)); + # t_new = min(max(min_pos,xmin_bound),xmax_bound); + d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2) + d2_square = d1**2 - g1 * g2 + if d2_square >= 0: + d2 = d2_square.sqrt() + if x1 <= x2: + min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2)) + else: + min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2)) + return min(max(min_pos, xmin_bound), xmax_bound) + else: + return (xmin_bound + xmax_bound) / 2.0 + + +def _strong_wolfe( + obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9, max_ls=25 +): + # ported from https://github.com/torch/optim/blob/master/lswolfe.lua + d_norm = d.abs().max() + g = g.clone(memory_format=torch.contiguous_format) + # evaluate objective and gradient using initial step + f_new, g_new = obj_func(x, t, d) + ls_func_evals = 1 + gtd_new = g_new.dot(d) + + # bracket an interval containing a point satisfying 
the Wolfe criteria + t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd + done = False + ls_iter = 0 + while ls_iter < max_ls: + # check conditions + if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev): + bracket = [t_prev, t] + bracket_f = [f_prev, f_new] + bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)] + bracket_gtd = [gtd_prev, gtd_new] + break + + if abs(gtd_new) <= -c2 * gtd: + bracket = [t] + bracket_f = [f_new] + bracket_g = [g_new] + done = True + break + + if gtd_new >= 0: + bracket = [t_prev, t] + bracket_f = [f_prev, f_new] + bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)] + bracket_gtd = [gtd_prev, gtd_new] + break + + # interpolate + min_step = t + 0.01 * (t - t_prev) + max_step = t * 10 + tmp = t + t = _cubic_interpolate( + t_prev, f_prev, gtd_prev, t, f_new, gtd_new, bounds=(min_step, max_step) + ) + + # next step + t_prev = tmp + f_prev = f_new + g_prev = g_new.clone(memory_format=torch.contiguous_format) + gtd_prev = gtd_new + f_new, g_new = obj_func(x, t, d) + ls_func_evals += 1 + gtd_new = g_new.dot(d) + ls_iter += 1 + + # reached max number of iterations? + if ls_iter == max_ls: + bracket = [0, t] + bracket_f = [f, f_new] + bracket_g = [g, g_new] + + # zoom phase: we now have a point satisfying the criteria, or + # a bracket around it. We refine the bracket until we find the + # exact point satisfying the criteria + insuf_progress = False + # find high and low points in bracket + low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0) # type: ignore[possibly-undefined] + while not done and ls_iter < max_ls: + # line-search bracket is so small + if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change: # type: ignore[possibly-undefined] + break + + # compute new trial value + t = _cubic_interpolate( + bracket[0], + bracket_f[0], + bracket_gtd[0], # type: ignore[possibly-undefined] + bracket[1], + bracket_f[1], + bracket_gtd[1], + ) + + # test that we are making sufficient progress: + # in case `t` is so close to boundary, we mark that we are making + # insufficient progress, and if + # + we have made insufficient progress in the last step, or + # + `t` is at one of the boundary, + # we will move `t` to a position which is `0.1 * len(bracket)` + # away from the nearest boundary point. 
+ eps = 0.1 * (max(bracket) - min(bracket)) + if min(max(bracket) - t, t - min(bracket)) < eps: + # interpolation close to boundary + if insuf_progress or t >= max(bracket) or t <= min(bracket): + # evaluate at 0.1 away from boundary + if abs(t - max(bracket)) < abs(t - min(bracket)): + t = max(bracket) - eps + else: + t = min(bracket) + eps + insuf_progress = False + else: + insuf_progress = True + else: + insuf_progress = False + + # Evaluate new point + f_new, g_new = obj_func(x, t, d) + ls_func_evals += 1 + gtd_new = g_new.dot(d) + ls_iter += 1 + + if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]: + # Armijo condition not satisfied or not lower than lowest point + bracket[high_pos] = t + bracket_f[high_pos] = f_new + bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format) # type: ignore[possibly-undefined] + bracket_gtd[high_pos] = gtd_new + low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0) + else: + if abs(gtd_new) <= -c2 * gtd: + # Wolfe conditions satisfied + done = True + elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0: + # old high becomes new low + bracket[high_pos] = bracket[low_pos] + bracket_f[high_pos] = bracket_f[low_pos] + bracket_g[high_pos] = bracket_g[low_pos] # type: ignore[possibly-undefined] + bracket_gtd[high_pos] = bracket_gtd[low_pos] + + # new point becomes new low + bracket[low_pos] = t + bracket_f[low_pos] = f_new + bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format) # type: ignore[possibly-undefined] + bracket_gtd[low_pos] = gtd_new + + # return stuff + t = bracket[low_pos] # type: ignore[possibly-undefined] + f_new = bracket_f[low_pos] + g_new = bracket_g[low_pos] # type: ignore[possibly-undefined] + return f_new, g_new, t, ls_func_evals + + +class LBFGS(Optimizer): + """Implements L-BFGS algorithm. + + Heavily inspired by `minFunc + `_. + + .. warning:: + This optimizer doesn't support per-parameter options and parameter + groups (there can be only one). + + .. warning:: + Right now all parameters have to be on a single device. This will be + improved in the future. + + .. note:: + This is a very memory intensive optimizer (it requires additional + ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory + try reducing the history size, or use a different algorithm. + + Args: + params (iterable): iterable of parameters to optimize. Parameters must be real. + lr (float): learning rate (default: 1) + max_iter (int): maximal number of iterations per optimization step + (default: 20) + max_eval (int): maximal number of function evaluations per optimization + step (default: max_iter * 1.25). + tolerance_grad (float): termination tolerance on first order optimality + (default: 1e-7). + tolerance_change (float): termination tolerance on function + value/parameter changes (default: 1e-9). + history_size (int): update history size (default: 100). + line_search_fn (str): either 'strong_wolfe' or None (default: None). 
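+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> # A minimal sketch of the closure-based step; ``model``, ``loss_fn``,
+        >>> # ``input`` and ``target`` are assumed to be defined elsewhere.
+        >>> optimizer = torch.optim.LBFGS(model.parameters(), lr=1, line_search_fn="strong_wolfe")
+        >>> def closure():
+        >>>     optimizer.zero_grad()
+        >>>     loss = loss_fn(model(input), target)
+        >>>     loss.backward()
+        >>>     return loss
+        >>> optimizer.step(closure)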
+ """ + + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1, + max_iter: int = 20, + max_eval: Optional[int] = None, + tolerance_grad: float = 1e-7, + tolerance_change: float = 1e-9, + history_size: int = 100, + line_search_fn: Optional[str] = None, + ): + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if max_eval is None: + max_eval = max_iter * 5 // 4 + defaults = dict( + lr=lr, + max_iter=max_iter, + max_eval=max_eval, + tolerance_grad=tolerance_grad, + tolerance_change=tolerance_change, + history_size=history_size, + line_search_fn=line_search_fn, + ) + super().__init__(params, defaults) + + if len(self.param_groups) != 1: + raise ValueError( + "LBFGS doesn't support per-parameter options " "(parameter groups)" + ) + + self._params = self.param_groups[0]["params"] + self._numel_cache = None + + def _numel(self): + if self._numel_cache is None: + self._numel_cache = sum( + 2 * p.numel() if torch.is_complex(p) else p.numel() + for p in self._params + ) + + return self._numel_cache + + def _gather_flat_grad(self): + views = [] + for p in self._params: + if p.grad is None: + view = p.new(p.numel()).zero_() + elif p.grad.is_sparse: + view = p.grad.to_dense().view(-1) + else: + view = p.grad.view(-1) + if torch.is_complex(view): + view = torch.view_as_real(view).view(-1) + views.append(view) + return torch.cat(views, 0) + + def _add_grad(self, step_size, update): + offset = 0 + for p in self._params: + if torch.is_complex(p): + p = torch.view_as_real(p) + numel = p.numel() + # view as to avoid deprecated pointwise semantics + p.add_(update[offset : offset + numel].view_as(p), alpha=step_size) + offset += numel + assert offset == self._numel() + + def _clone_param(self): + return [p.clone(memory_format=torch.contiguous_format) for p in self._params] + + def _set_param(self, params_data): + for p, pdata in zip(self._params, params_data): + p.copy_(pdata) + + def _directional_evaluate(self, closure, x, t, d): + self._add_grad(t, d) + loss = float(closure()) + flat_grad = self._gather_flat_grad() + self._set_param(x) + return loss, flat_grad + + @torch.no_grad() + def step(self, closure): + """Perform a single optimization step. + + Args: + closure (Callable): A closure that reevaluates the model + and returns the loss. 
+ """ + assert len(self.param_groups) == 1 + + # Make sure the closure is always called with grad enabled + closure = torch.enable_grad()(closure) + + group = self.param_groups[0] + lr = group["lr"] + max_iter = group["max_iter"] + max_eval = group["max_eval"] + tolerance_grad = group["tolerance_grad"] + tolerance_change = group["tolerance_change"] + line_search_fn = group["line_search_fn"] + history_size = group["history_size"] + + # NOTE: LBFGS has only global state, but we register it as state for + # the first param, because this helps with casting in load_state_dict + state = self.state[self._params[0]] + state.setdefault("func_evals", 0) + state.setdefault("n_iter", 0) + + # evaluate initial f(x) and df/dx + orig_loss = closure() + loss = float(orig_loss) + current_evals = 1 + state["func_evals"] += 1 + + flat_grad = self._gather_flat_grad() + opt_cond = flat_grad.abs().max() <= tolerance_grad + + # optimal condition + if opt_cond: + return orig_loss + + # tensors cached in state (for tracing) + d = state.get("d") + t = state.get("t") + old_dirs = state.get("old_dirs") + old_stps = state.get("old_stps") + ro = state.get("ro") + H_diag = state.get("H_diag") + prev_flat_grad = state.get("prev_flat_grad") + prev_loss = state.get("prev_loss") + + n_iter = 0 + # optimize for a max of max_iter iterations + while n_iter < max_iter: + # keep track of nb of iterations + n_iter += 1 + state["n_iter"] += 1 + + ############################################################ + # compute gradient descent direction + ############################################################ + if state["n_iter"] == 1: + d = flat_grad.neg() + old_dirs = [] + old_stps = [] + ro = [] + H_diag = 1 + else: + # do lbfgs update (update memory) + y = flat_grad.sub(prev_flat_grad) + s = d.mul(t) + ys = y.dot(s) # y*s + if ys > 1e-10: + # updating memory + if len(old_dirs) == history_size: + # shift history by one (limited-memory) + old_dirs.pop(0) + old_stps.pop(0) + ro.pop(0) + + # store new direction/step + old_dirs.append(y) + old_stps.append(s) + ro.append(1.0 / ys) + + # update scale of initial Hessian approximation + H_diag = ys / y.dot(y) # (y*y) + + # compute the approximate (L-BFGS) inverse Hessian + # multiplied by the gradient + num_old = len(old_dirs) + + if "al" not in state: + state["al"] = [None] * history_size + al = state["al"] + + # iteration in L-BFGS loop collapsed to use just one buffer + q = flat_grad.neg() + for i in range(num_old - 1, -1, -1): + al[i] = old_stps[i].dot(q) * ro[i] + q.add_(old_dirs[i], alpha=-al[i]) + + # multiply by initial Hessian + # r/d is the final direction + d = r = torch.mul(q, H_diag) + for i in range(num_old): + be_i = old_dirs[i].dot(r) * ro[i] + r.add_(old_stps[i], alpha=al[i] - be_i) + + if prev_flat_grad is None: + prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format) + else: + prev_flat_grad.copy_(flat_grad) + prev_loss = loss + + ############################################################ + # compute step length + ############################################################ + # reset initial guess for step size + if state["n_iter"] == 1: + t = min(1.0, 1.0 / flat_grad.abs().sum()) * lr + else: + t = lr + + # directional derivative + gtd = flat_grad.dot(d) # g * d + + # directional derivative is below tolerance + if gtd > -tolerance_change: + break + + # optional line search: user function + ls_func_evals = 0 + if line_search_fn is not None: + # perform line search, using user function + if line_search_fn != "strong_wolfe": + raise RuntimeError("only 
'strong_wolfe' is supported") + else: + x_init = self._clone_param() + + def obj_func(x, t, d): + return self._directional_evaluate(closure, x, t, d) + + loss, flat_grad, t, ls_func_evals = _strong_wolfe( + obj_func, x_init, t, d, loss, flat_grad, gtd + ) + self._add_grad(t, d) + opt_cond = flat_grad.abs().max() <= tolerance_grad + else: + # no line search, simply move with fixed-step + self._add_grad(t, d) + if n_iter != max_iter: + # re-evaluate function only if not in last iteration + # the reason we do this: in a stochastic setting, + # no use to re-evaluate that function here + with torch.enable_grad(): + loss = float(closure()) + flat_grad = self._gather_flat_grad() + opt_cond = flat_grad.abs().max() <= tolerance_grad + ls_func_evals = 1 + + # update func eval + current_evals += ls_func_evals + state["func_evals"] += ls_func_evals + + ############################################################ + # check conditions + ############################################################ + if n_iter == max_iter: + break + + if current_evals >= max_eval: + break + + # optimal condition + if opt_cond: + break + + # lack of progress + if d.mul(t).abs().max() <= tolerance_change: + break + + if abs(loss - prev_loss) < tolerance_change: + break + + state["d"] = d + state["t"] = t + state["old_dirs"] = old_dirs + state["old_stps"] = old_stps + state["ro"] = ro + state["H_diag"] = H_diag + state["prev_flat_grad"] = prev_flat_grad + state["prev_loss"] = prev_loss + + return orig_loss diff --git a/lib/python3.10/site-packages/torch/optim/lr_scheduler.py b/lib/python3.10/site-packages/torch/optim/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..57dcbd85a8316444ef59ee6e12d692fdc6f657d9 --- /dev/null +++ b/lib/python3.10/site-packages/torch/optim/lr_scheduler.py @@ -0,0 +1,2151 @@ +# mypy: allow-untyped-defs +r"""Learning Rate Scheduler.""" +import math +import types +import warnings +from bisect import bisect_right +from collections import Counter +from functools import partial, wraps +from typing import ( + Any, + Callable, + cast, + Dict, + Iterable, + List, + Literal, + Optional, + Sequence, + SupportsFloat, + TypedDict, + Union, +) +from weakref import ref + +from torch import inf, Tensor + +from .optimizer import Optimizer + + +__all__ = [ + "LambdaLR", + "MultiplicativeLR", + "StepLR", + "MultiStepLR", + "ConstantLR", + "LinearLR", + "ExponentialLR", + "SequentialLR", + "CosineAnnealingLR", + "ChainedScheduler", + "ReduceLROnPlateau", + "CyclicLR", + "CosineAnnealingWarmRestarts", + "OneCycleLR", + "PolynomialLR", + "LRScheduler", +] + +EPOCH_DEPRECATION_WARNING = ( + "The epoch parameter in `scheduler.step()` was not necessary and is being " + "deprecated where possible. Please use `scheduler.step()` to step the " + "scheduler. During the deprecation, if epoch is different from None, the " + "closed form is used instead of the new chainable form, where available. " + "Please open an issue if you are unable to replicate your use case: " + "https://github.com/pytorch/pytorch/issues/new/choose." +) + + +def _check_verbose_deprecated_warning(verbose): + """Raise a warning when verbose is not the default value.""" + if verbose != "deprecated": + warnings.warn( + "The verbose parameter is deprecated. 
Please use get_last_lr() " + "to access the learning rate.", + UserWarning, + ) + return verbose + return False + + +def _format_param(name: str, optimizer: Optimizer, param): + """Return correctly formatted lr/momentum for each param group.""" + + def _copy(_param): + return _param.clone() if isinstance(_param, Tensor) else _param + + if isinstance(param, (list, tuple)): + if len(param) != len(optimizer.param_groups): + raise ValueError( + f"{name} must have the same length as optimizer.param_groups. " + f"{name} has {len(param)} values, param_groups has {len(optimizer.param_groups)}." + ) + else: + param = [param] * len(optimizer.param_groups) + + return list(map(_copy, param)) + + +class LRScheduler: + r"""Adjusts the learning rate during optimization.""" + + _get_lr_called_within_step: bool = False + + def __init__( + self, optimizer: Optimizer, last_epoch=-1, verbose="deprecated" + ): # noqa: D107 + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") + self.optimizer = optimizer + + # Initialize epoch and base learning rates + if last_epoch == -1: + for group in optimizer.param_groups: + initial_lr = group["lr"] + if isinstance(initial_lr, Tensor): + initial_lr = initial_lr.clone() + group.setdefault("initial_lr", initial_lr) + else: + for i, group in enumerate(optimizer.param_groups): + if "initial_lr" not in group: + raise KeyError( + "param 'initial_lr' is not specified " + f"in param_groups[{i}] when resuming an optimizer" + ) + self.base_lrs: List[float] = [ + group["initial_lr"] for group in optimizer.param_groups + ] + self.last_epoch = last_epoch + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `lr_scheduler.step()` is called after + # `optimizer.step()` + def patch_track_step_called(opt: Optimizer): + if hasattr(opt.step, "_wrapped_by_lr_sched"): + # we've already patched + return opt.step + + def wrap_step(step_fn): + opt_ref = ref(self.optimizer) + func = step_fn.__func__ + + @wraps(func) + def wrapper(*args, **kwargs): + opt = opt_ref() + opt._opt_called = True # type: ignore[union-attr] + return func.__get__(opt, opt.__class__)(*args, **kwargs) + + wrapper._wrapped_by_lr_sched = True # type: ignore[attr-defined] + return wrapper + + opt.step = wrap_step(opt.step) # type: ignore[method-assign] + + patch_track_step_called(self.optimizer) + self.verbose = _check_verbose_deprecated_warning(verbose) + self._initial_step() + + def _initial_step(self): + """Initialize step counts and perform a step.""" + self._step_count = 0 + self.step() + + def state_dict(self): + """Return the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return { + key: value for key, value in self.__dict__.items() if key != "optimizer" + } + + def load_state_dict(self, state_dict: Dict[str, Any]): + """Load the scheduler's state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. 
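+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> # A minimal checkpoint/restore sketch; ``scheduler`` is assumed to be
+            >>> # an already-constructed scheduler instance.
+            >>> torch.save(scheduler.state_dict(), "scheduler.pt")
+            >>> scheduler.load_state_dict(torch.load("scheduler.pt"))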
+ """ + self.__dict__.update(state_dict) + + def get_last_lr(self) -> List[float]: + """Return last computed learning rate by current scheduler.""" + return self._last_lr + + def get_lr(self) -> List[float]: + """Compute learning rate using chainable form of the scheduler.""" + raise NotImplementedError + + def print_lr( + self, + is_verbose: bool, + group: Dict[str, Any], + lr: float, + epoch: Optional[int] = None, + ): + """Display the current learning rate. + + .. deprecated:: 2.4 + ``print_lr()`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + """ + warnings.warn( + "`LRScheduler.print_lr()` is being deprecated. To fetch the learning rate, " + "please use `get_last_lr()` instead. For more details, " + "see https://github.com/pytorch/pytorch/issues/99270.", + UserWarning, + ) + if is_verbose: + if epoch is None: + print(f"Adjusting learning rate of group {group} to {lr:.4e}.") + else: + epoch_str = ("%.2f" if isinstance(epoch, float) else "%.5d") % epoch + print( + f"Epoch {epoch_str}: adjusting learning rate of group {group} to {lr:.4e}." + ) + + def step(self, epoch: Optional[int] = None): + """Perform a step.""" + # Raise a warning if old pattern is detected + # https://github.com/pytorch/pytorch/issues/20124 + if self._step_count == 1: + if not hasattr(self.optimizer.step, "_wrapped_by_lr_sched"): + warnings.warn( + "Seems like `optimizer.step()` has been overridden after learning rate scheduler " + "initialization. Please, make sure to call `optimizer.step()` before " + "`lr_scheduler.step()`. See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", + UserWarning, + ) + + # Just check if there were two first lr_scheduler.step() calls before optimizer.step() + elif not getattr(self.optimizer, "_opt_called", False): + warnings.warn( + "Detected call of `lr_scheduler.step()` before `optimizer.step()`. " + "In PyTorch 1.1.0 and later, you should call them in the opposite order: " + "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this " + "will result in PyTorch skipping the first value of the learning rate schedule. " + "See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", + UserWarning, + ) + self._step_count += 1 + + with _enable_get_lr_call(self): + if epoch is None: + self.last_epoch += 1 + values = self.get_lr() + else: + warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) + self.last_epoch = epoch + if hasattr(self, "_get_closed_form_lr"): + values = cast(List[float], self._get_closed_form_lr()) + else: + values = self.get_lr() + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + if isinstance(param_group["lr"], Tensor): + param_group["lr"].fill_(lr) + else: + param_group["lr"] = lr + + self._last_lr: List[float] = [ + group["lr"] for group in self.optimizer.param_groups + ] + + +def _warn_get_lr_called_within_step(lr_scheduler: LRScheduler): + if not lr_scheduler._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + UserWarning, + stacklevel=2, + ) + + +# Including _LRScheduler for backwards compatibility +# Subclass instead of assign because we want __name__ of _LRScheduler to be _LRScheduler (assigning would make it LRScheduler). 
+class _LRScheduler(LRScheduler): + pass + + +class _enable_get_lr_call: + def __init__(self, o: LRScheduler): + self.o = o + + def __enter__(self): + self.o._get_lr_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_lr_called_within_step = False + + +class LambdaLR(LRScheduler): + """Sets the initial learning rate. + + The learning rate of each parameter group is set to the initial lr + times a given function. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + lr_lambda (function or list): A function which computes a multiplicative + factor given an integer parameter epoch, or a list of such + functions, one for each group in optimizer.param_groups. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer has two groups. + >>> lambda1 = lambda epoch: epoch // 30 + >>> lambda2 = lambda epoch: 0.95 ** epoch + >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2]) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + optimizer: Optimizer, + lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]], + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + self.optimizer = optimizer + + self.lr_lambdas: List[Callable[[int], float]] + if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): + self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) + else: + if len(lr_lambda) != len(optimizer.param_groups): + raise ValueError( + f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}" + ) + self.lr_lambdas = list(lr_lambda) + super().__init__(optimizer, last_epoch, verbose) + + def state_dict(self): + """Return the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The learning rate lambda functions will only be saved if they are callable objects + and not if they are functions or lambdas. + + When saving or loading the scheduler, please make sure to also save or load the state of the optimizer. + """ + state_dict = { + key: value + for key, value in self.__dict__.items() + if key not in ("optimizer", "lr_lambdas") + } + state_dict["lr_lambdas"] = [None] * len(self.lr_lambdas) + + for idx, fn in enumerate(self.lr_lambdas): + if not isinstance(fn, types.FunctionType): + state_dict["lr_lambdas"][idx] = fn.__dict__.copy() + + return state_dict + + def load_state_dict(self, state_dict): + """Load the scheduler's state. + + When saving or loading the scheduler, please make sure to also save or load the state of the optimizer. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. 
+ """ + lr_lambdas = state_dict.pop("lr_lambdas") + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict["lr_lambdas"] = lr_lambdas + + for idx, fn in enumerate(lr_lambdas): + if fn is not None: + self.lr_lambdas[idx].__dict__.update(fn) + + def get_lr(self): + """Compute learning rate.""" + _warn_get_lr_called_within_step(self) + + return [ + base_lr * lmbda(self.last_epoch) + for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs) + ] + + +class MultiplicativeLR(LRScheduler): + """Multiply the learning rate of each parameter group by the factor given in the specified function. + + When last_epoch=-1, set initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + lr_lambda (function or list): A function which computes a multiplicative + factor given an integer parameter epoch, or a list of such + functions, one for each group in optimizer.param_groups. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> lmbda = lambda epoch: 0.95 + >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + optimizer: Optimizer, + lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]], + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + self.optimizer = optimizer + + self.lr_lambdas: List[Callable[[int], float]] + if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): + self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) + else: + if len(lr_lambda) != len(optimizer.param_groups): + raise ValueError( + f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}" + ) + self.lr_lambdas = list(lr_lambda) + super().__init__(optimizer, last_epoch, verbose) + + def state_dict(self): + """Return the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The learning rate lambda functions will only be saved if they are callable objects + and not if they are functions or lambdas. + """ + state_dict = { + key: value + for key, value in self.__dict__.items() + if key not in ("optimizer", "lr_lambdas") + } + state_dict["lr_lambdas"] = [None] * len(self.lr_lambdas) + + for idx, fn in enumerate(self.lr_lambdas): + if not isinstance(fn, types.FunctionType): + state_dict["lr_lambdas"][idx] = fn.__dict__.copy() + + return state_dict + + def load_state_dict(self, state_dict): + """Load the scheduler's state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. 
+ """ + lr_lambdas = state_dict.pop("lr_lambdas") + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict["lr_lambdas"] = lr_lambdas + + for idx, fn in enumerate(lr_lambdas): + if fn is not None: + self.lr_lambdas[idx].__dict__.update(fn) + + def get_lr(self): + """Compute the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + if self.last_epoch > 0: + return [ + group["lr"] * lmbda(self.last_epoch) + for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups) + ] + else: + return [group["lr"] for group in self.optimizer.param_groups] + + +class StepLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma every step_size epochs. + + Notice that such decay can happen simultaneously with other changes to the learning rate + from outside this scheduler. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + step_size (int): Period of learning rate decay. + gamma (float): Multiplicative factor of learning rate decay. + Default: 0.1. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.05 if epoch < 30 + >>> # lr = 0.005 if 30 <= epoch < 60 + >>> # lr = 0.0005 if 60 <= epoch < 90 + >>> # ... + >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + optimizer: Optimizer, + step_size: int, + gamma=0.1, + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + self.step_size = step_size + self.gamma = gamma + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + """Compute the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0): + return [group["lr"] for group in self.optimizer.param_groups] + return [group["lr"] * self.gamma for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [ + base_lr * self.gamma ** (self.last_epoch // self.step_size) + for base_lr in self.base_lrs + ] + + +class MultiStepLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma once the number of epoch reaches one of the milestones. + + Notice that such decay can happen simultaneously with other changes to the learning rate + from outside this scheduler. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + milestones (list): List of epoch indices. Must be increasing. + gamma (float): Multiplicative factor of learning rate decay. + Default: 0.1. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. 
+ + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.05 if epoch < 30 + >>> # lr = 0.005 if 30 <= epoch < 80 + >>> # lr = 0.0005 if epoch >= 80 + >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + optimizer: Optimizer, + milestones: Iterable[int], + gamma=0.1, + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + self.milestones = Counter(milestones) + self.gamma = gamma + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + """Compute the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + if self.last_epoch not in self.milestones: + return [group["lr"] for group in self.optimizer.param_groups] + return [ + group["lr"] * self.gamma ** self.milestones[self.last_epoch] + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + milestones = sorted(self.milestones.elements()) + return [ + base_lr * self.gamma ** bisect_right(milestones, self.last_epoch) + for base_lr in self.base_lrs + ] + + +class ConstantLR(LRScheduler): + """Multiply the learning rate of each parameter group by a small constant factor. + + The multiplication is done until the number of epoch reaches a pre-defined milestone: total_iters. + Notice that such multiplication of the small constant factor can + happen simultaneously with other changes to the learning rate from outside this scheduler. + When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + factor (float): The number we multiply learning rate until the milestone. Default: 1./3. + total_iters (int): The number of steps that the scheduler multiplies the learning rate by the factor. + Default: 5. + last_epoch (int): The index of the last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.025 if epoch == 0 + >>> # lr = 0.025 if epoch == 1 + >>> # lr = 0.025 if epoch == 2 + >>> # lr = 0.025 if epoch == 3 + >>> # lr = 0.05 if epoch >= 4 + >>> scheduler = ConstantLR(optimizer, factor=0.5, total_iters=4) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + optimizer: Optimizer, + factor=1.0 / 3, + total_iters=5, + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + if factor > 1.0 or factor < 0: + raise ValueError( + "Constant multiplicative factor expected to be between 0 and 1." 
+ ) + + self.factor = factor + self.total_iters = total_iters + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + """Compute the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + if self.last_epoch == 0: + return [group["lr"] * self.factor for group in self.optimizer.param_groups] + + if self.last_epoch != self.total_iters: + return [group["lr"] for group in self.optimizer.param_groups] + + return [ + group["lr"] * (1.0 / self.factor) for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + return [ + base_lr + * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor)) + for base_lr in self.base_lrs + ] + + +class LinearLR(LRScheduler): + """Decays the learning rate of each parameter group by linearly changing small multiplicative factor. + + The multiplication is done until the number of epoch reaches a pre-defined milestone: total_iters. + Notice that such decay can happen simultaneously with other changes to the learning rate + from outside this scheduler. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + start_factor (float): The number we multiply learning rate in the first epoch. + The multiplication factor changes towards end_factor in the following epochs. + Default: 1./3. + end_factor (float): The number we multiply learning rate at the end of linear changing + process. Default: 1.0. + total_iters (int): The number of iterations that multiplicative factor reaches to 1. + Default: 5. + last_epoch (int): The index of the last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.025 if epoch == 0 + >>> # lr = 0.03125 if epoch == 1 + >>> # lr = 0.0375 if epoch == 2 + >>> # lr = 0.04375 if epoch == 3 + >>> # lr = 0.05 if epoch >= 4 + >>> scheduler = LinearLR(optimizer, start_factor=0.5, total_iters=4) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + optimizer: Optimizer, + start_factor=1.0 / 3, + end_factor=1.0, + total_iters=5, + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + if start_factor > 1.0 or start_factor <= 0: + raise ValueError( + "Starting multiplicative factor expected to be greater than 0 and less or equal to 1." + ) + + if end_factor > 1.0 or end_factor < 0: + raise ValueError( + "Ending multiplicative factor expected to be between 0 and 1." 
+ ) + + self.start_factor = start_factor + self.end_factor = end_factor + self.total_iters = total_iters + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + """Compute the learning rate.""" + _warn_get_lr_called_within_step(self) + + if self.last_epoch == 0: + return [ + group["lr"] * self.start_factor for group in self.optimizer.param_groups + ] + + if self.last_epoch > self.total_iters: + return [group["lr"] for group in self.optimizer.param_groups] + + return [ + group["lr"] + * ( + 1.0 + + (self.end_factor - self.start_factor) + / ( + self.total_iters * self.start_factor + + (self.last_epoch - 1) * (self.end_factor - self.start_factor) + ) + ) + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + return [ + base_lr + * ( + self.start_factor + + (self.end_factor - self.start_factor) + * min(self.total_iters, self.last_epoch) + / self.total_iters + ) + for base_lr in self.base_lrs + ] + + +class ExponentialLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma every epoch. + + When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + gamma (float): Multiplicative factor of learning rate decay. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + """ + + def __init__( + self, optimizer: Optimizer, gamma: float, last_epoch=-1, verbose="deprecated" + ): # noqa: D107 + self.gamma = gamma + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + """Compute the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + if self.last_epoch == 0: + return [group["lr"] for group in self.optimizer.param_groups] + return [group["lr"] * self.gamma for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [base_lr * self.gamma**self.last_epoch for base_lr in self.base_lrs] + + +class SequentialLR(LRScheduler): + """Contains a list of schedulers expected to be called sequentially during the optimization process. + + Specifically, the schedulers will be called according to the milestone points, which should provide exact + intervals by which each scheduler should be called at a given epoch. + + Args: + optimizer (Optimizer): Wrapped optimizer. + schedulers (list): List of chained schedulers. + milestones (list): List of integers that reflects milestone points. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): Does nothing. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 1. for all groups + >>> # lr = 0.1 if epoch == 0 + >>> # lr = 0.1 if epoch == 1 + >>> # lr = 0.9 if epoch == 2 + >>> # lr = 0.81 if epoch == 3 + >>> # lr = 0.729 if epoch == 4 + >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2) + >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9) + >>> scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[2]) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) 
+ >>> scheduler.step() + """ + + def __init__( + self, + optimizer: Optimizer, + schedulers: List[LRScheduler], + milestones: List[int], + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + if len(schedulers) < 1: + raise ValueError( + f"{self.__class__.__name__} expects at least one scheduler, but got no scheduler." + ) + + for scheduler_idx, scheduler in enumerate(schedulers): + if not hasattr(scheduler, "optimizer"): + raise TypeError( + f"{self.__class__.__name__} at index {scheduler_idx} should have `optimizer` as its attribute." + ) + if isinstance(scheduler, ReduceLROnPlateau): + raise ValueError( + f"{self.__class__.__name__} does not support `ReduceLROnPlateau` scheduler as it " + "requires additional kwargs to be specified when calling `step`, " + f"but got one at index {scheduler_idx} in the given schedulers sequence." + ) + if optimizer != scheduler.optimizer: + raise ValueError( + f"{self.__class__.__name__} expects all schedulers to belong to the same optimizer, but " + f"got scheduler {scheduler.__class__.__name__} at index {scheduler_idx} has {scheduler.optimizer}, " + f"which is different from {optimizer.__class__.__name__}." + ) + + if len(milestones) != len(schedulers) - 1: + raise ValueError( + "Sequential Schedulers expects number of schedulers provided to be one more " + f"than the number of milestone points, but got number of schedulers {len(schedulers)} and the " + f"number of milestones to be equal to {len(milestones)}" + ) + _check_verbose_deprecated_warning(verbose) + self._schedulers = schedulers + self._milestones = milestones + self.last_epoch = last_epoch + 1 + self.optimizer = optimizer + + # Reset learning rates back to initial values + for group in self.optimizer.param_groups: + group["lr"] = group["initial_lr"] + + # "Undo" the step performed by other schedulers + for scheduler in self._schedulers: + scheduler.last_epoch -= 1 + + # Perform the initial step for only the first scheduler + self._schedulers[0]._initial_step() + + self._last_lr = schedulers[0].get_last_lr() + + def step(self): + """Perform a step.""" + self.last_epoch += 1 + idx = bisect_right(self._milestones, self.last_epoch) + scheduler = self._schedulers[idx] + if idx > 0 and self._milestones[idx - 1] == self.last_epoch: + scheduler.step(0) + else: + scheduler.step() + + self._last_lr = scheduler.get_last_lr() + + def state_dict(self): + """Return the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The wrapped scheduler states will also be saved. + """ + state_dict = { + key: value + for key, value in self.__dict__.items() + if key not in ("optimizer", "_schedulers") + } + state_dict["_schedulers"] = [None] * len(self._schedulers) + + for idx, s in enumerate(self._schedulers): + state_dict["_schedulers"][idx] = s.state_dict() + + return state_dict + + def load_state_dict(self, state_dict): + """Load the scheduler's state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. 
+ """ + _schedulers = state_dict.pop("_schedulers") + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict["_schedulers"] = _schedulers + + for idx, s in enumerate(_schedulers): + self._schedulers[idx].load_state_dict(s) + + +class PolynomialLR(LRScheduler): + """Decays the learning rate of each parameter group using a polynomial function in the given total_iters. + + When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5. + power (float): The power of the polynomial. Default: 1.0. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP("undefined vars") + >>> # Assuming optimizer uses lr = 0.001 for all groups + >>> # lr = 0.001 if epoch == 0 + >>> # lr = 0.00075 if epoch == 1 + >>> # lr = 0.00050 if epoch == 2 + >>> # lr = 0.00025 if epoch == 3 + >>> # lr = 0.0 if epoch >= 4 + >>> scheduler = PolynomialLR(optimizer, total_iters=4, power=1.0) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + optimizer: Optimizer, + total_iters=5, + power=1.0, + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + self.total_iters = total_iters + self.power = power + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + """Compute the learning rate.""" + _warn_get_lr_called_within_step(self) + + if self.last_epoch == 0 or self.last_epoch > self.total_iters: + return [group["lr"] for group in self.optimizer.param_groups] + + decay_factor = ( + (1.0 - self.last_epoch / self.total_iters) + / (1.0 - (self.last_epoch - 1) / self.total_iters) + ) ** self.power + return [group["lr"] * decay_factor for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [ + ( + base_lr + * (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters) + ** self.power + ) + for base_lr in self.base_lrs + ] + + +class CosineAnnealingLR(LRScheduler): + r"""Set the learning rate of each parameter group using a cosine annealing schedule. + + The :math:`\eta_{max}` is set to the initial lr and + :math:`T_{cur}` is the number of epochs since the last restart in SGDR: + + .. math:: + \begin{aligned} + \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), + & T_{cur} \neq (2k+1)T_{max}; \\ + \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) + \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), + & T_{cur} = (2k+1)T_{max}. + \end{aligned} + + When last_epoch=-1, sets initial lr as lr. Notice that because the schedule + is defined recursively, the learning rate can be simultaneously modified + outside this scheduler by other operators. If the learning rate is set + solely by this scheduler, the learning rate at each step becomes: + + .. math:: + \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) + + It has been proposed in + `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only + implements the cosine annealing part of SGDR, and not the restarts. + + Args: + optimizer (Optimizer): Wrapped optimizer. 
+ T_max (int): Maximum number of iterations. + eta_min (float): Minimum learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + """ + + def __init__( + self, + optimizer: Optimizer, + T_max: int, + eta_min=0.0, + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + self.T_max = T_max + self.eta_min = eta_min + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + """Retrieve the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + if self.last_epoch == 0: + return [group["lr"] for group in self.optimizer.param_groups] + elif self._step_count == 1 and self.last_epoch > 0: + return [ + self.eta_min + + (base_lr - self.eta_min) + * (1 + math.cos((self.last_epoch) * math.pi / self.T_max)) + / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0: + return [ + group["lr"] + + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + return [ + (1 + math.cos(math.pi * self.last_epoch / self.T_max)) + / (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) + * (group["lr"] - self.eta_min) + + self.eta_min + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + return [ + self.eta_min + + (base_lr - self.eta_min) + * (1 + math.cos(math.pi * self.last_epoch / self.T_max)) + / 2 + for base_lr in self.base_lrs + ] + + +class ChainedScheduler(LRScheduler): + """Chains a list of learning rate schedulers. + + Takes in a sequence of chainable learning rate schedulers and calls their + step() functions consecutively in just one call to step(). + + Args: + schedulers (sequence): sequence of chained schedulers. + optimizer (Optimizer, optional): Wrapped optimizer. Default: None. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 1. for all groups + >>> # lr = 0.09 if epoch == 0 + >>> # lr = 0.081 if epoch == 1 + >>> # lr = 0.729 if epoch == 2 + >>> # lr = 0.6561 if epoch == 3 + >>> # lr = 0.59049 if epoch >= 4 + >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2) + >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9) + >>> scheduler = ChainedScheduler([scheduler1, scheduler2], optimizer=optimizer) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, schedulers: Sequence[LRScheduler], optimizer: Optional[Optimizer] = None + ): # noqa: D107 + if len(schedulers) < 1: + raise ValueError( + f"{self.__class__.__name__} expects at least one scheduler to be chained, but got no scheduler." + ) + + optimizer = optimizer or schedulers[0].optimizer + for scheduler_idx, scheduler in enumerate(schedulers): + if not hasattr(scheduler, "optimizer"): + raise TypeError( + f"{self.__class__.__name__} at index {scheduler_idx} should have `optimizer` as its attribute." 
+ ) + if isinstance(scheduler, ReduceLROnPlateau): + raise ValueError( + f"{self.__class__.__name__} does not support `ReduceLROnPlateau` scheduler as it " + "requires additional kwargs to be specified when calling `step`, " + f"but got one at index {scheduler_idx} in the given schedulers sequence." + ) + if optimizer != scheduler.optimizer: + raise ValueError( + f"{self.__class__.__name__} expects all schedulers to belong to the same optimizer, but " + f"got scheduler {scheduler.__class__.__name__} at index {scheduler_idx} has {scheduler.optimizer}, " + f"which is different from {optimizer.__class__.__name__}." + ) + self._schedulers = schedulers + self.optimizer = optimizer + self._last_lr = [ + group["lr"] for group in self._schedulers[-1].optimizer.param_groups + ] + + def step(self): + """Perform a step.""" + for scheduler in self._schedulers: + scheduler.step() + self._last_lr = [ + group["lr"] for group in self._schedulers[-1].optimizer.param_groups + ] + + def state_dict(self): + """Return the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The wrapped scheduler states will also be saved. + """ + state_dict = { + key: value + for key, value in self.__dict__.items() + if key not in ("optimizer", "_schedulers") + } + state_dict["_schedulers"] = [None] * len(self._schedulers) + + for idx, s in enumerate(self._schedulers): + state_dict["_schedulers"][idx] = s.state_dict() + + return state_dict + + def load_state_dict(self, state_dict): + """Load the scheduler's state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + _schedulers = state_dict.pop("_schedulers") + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict["_schedulers"] = _schedulers + + for idx, s in enumerate(_schedulers): + self._schedulers[idx].load_state_dict(s) + + +class ReduceLROnPlateau(LRScheduler): + """Reduce learning rate when a metric has stopped improving. + + Models often benefit from reducing the learning rate by a factor + of 2-10 once learning stagnates. This scheduler reads a metrics + quantity and if no improvement is seen for a 'patience' number + of epochs, the learning rate is reduced. + + Args: + optimizer (Optimizer): Wrapped optimizer. + mode (str): One of `min`, `max`. In `min` mode, lr will + be reduced when the quantity monitored has stopped + decreasing; in `max` mode it will be reduced when the + quantity monitored has stopped increasing. Default: 'min'. + factor (float): Factor by which the learning rate will be + reduced. new_lr = lr * factor. Default: 0.1. + patience (int): The number of allowed epochs with no improvement after + which the learning rate will be reduced. + For example, consider the case of having no patience (`patience = 0`). + In the first epoch, a baseline is established and is always considered good as there's no previous baseline. + In the second epoch, if the performance is worse than the baseline, + we have what is considered an intolerable epoch. + Since the count of intolerable epochs (1) is greater than the patience level (0), + the learning rate is reduced at the end of this epoch. + From the third epoch onwards, the learning rate continues to be reduced at the end of each epoch + if the performance is worse than the baseline. 
If the performance improves or remains the same, + the learning rate is not adjusted. + Default: 10. + threshold (float): Threshold for measuring the new optimum, + to only focus on significant changes. Default: 1e-4. + threshold_mode (str): One of `rel`, `abs`. In `rel` mode, + dynamic_threshold = best * ( 1 + threshold ) in 'max' + mode or best * ( 1 - threshold ) in `min` mode. + In `abs` mode, dynamic_threshold = best + threshold in + `max` mode or best - threshold in `min` mode. Default: 'rel'. + cooldown (int): Number of epochs to wait before resuming + normal operation after lr has been reduced. Default: 0. + min_lr (float or list): A scalar or a list of scalars. A + lower bound on the learning rate of all param groups + or each group respectively. Default: 0. + eps (float): Minimal decay applied to lr. If the difference + between new and old lr is smaller than eps, the update is + ignored. Default: 1e-8. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> scheduler = ReduceLROnPlateau(optimizer, 'min') + >>> for epoch in range(10): + >>> train(...) + >>> val_loss = validate(...) + >>> # Note that step should be called after validate() + >>> scheduler.step(val_loss) + """ + + def __init__( + self, + optimizer: Optimizer, + mode: Literal["min", "max"] = "min", + factor=0.1, + patience=10, + threshold=1e-4, + threshold_mode: Literal["rel", "abs"] = "rel", + cooldown=0, + min_lr: Union[List[float], float] = 0, + eps=1e-8, + verbose="deprecated", + ): # noqa: D107 + if factor >= 1.0: + raise ValueError("Factor should be < 1.0.") + self.factor = factor + + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") + self.optimizer = optimizer + + if isinstance(min_lr, (list, tuple)): + if len(min_lr) != len(optimizer.param_groups): + raise ValueError( + f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}" + ) + self.min_lrs = list(min_lr) + else: + self.min_lrs = [min_lr] * len(optimizer.param_groups) + + self.patience = patience + + self.verbose = _check_verbose_deprecated_warning(verbose) + self.cooldown = cooldown + self.cooldown_counter = 0 + self.mode = mode + self.threshold = threshold + self.threshold_mode = threshold_mode + self.best: float + self.num_bad_epochs: int + self.mode_worse: float # the worse value for the chosen mode + self.eps = eps + self.last_epoch = 0 + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + self._init_is_better( + mode=mode, threshold=threshold, threshold_mode=threshold_mode + ) + self._reset() + + def _reset(self): + """Reset num_bad_epochs counter and cooldown counter.""" + self.best = self.mode_worse + self.cooldown_counter = 0 + self.num_bad_epochs = 0 + + def step(self, metrics: SupportsFloat, epoch=None): # type: ignore[override] + """Perform a step.""" + # convert `metrics` to float, in case it's a zero-dim Tensor + current = float(metrics) + if epoch is None: + epoch = self.last_epoch + 1 + else: + warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) + self.last_epoch = epoch + + if self.is_better(current, self.best): + self.best = current + self.num_bad_epochs = 0 + else: + self.num_bad_epochs += 1 + + if self.in_cooldown: + self.cooldown_counter -= 1 + 
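        # Illustrative timing of the patience/cooldown interaction (a sketch; values
        # chosen for illustration): with patience=0 and cooldown=2, a bad epoch triggers
        # _reduce_lr immediately; the next two step() calls land in cooldown, where
        # num_bad_epochs is forced back to 0 below, so the earliest possible follow-up
        # reduction is three epochs after the first one.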
self.num_bad_epochs = 0 # ignore any bad epochs in cooldown + + if self.num_bad_epochs > self.patience: + self._reduce_lr(epoch) + self.cooldown_counter = self.cooldown + self.num_bad_epochs = 0 + + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + + def _reduce_lr(self, epoch): + for i, param_group in enumerate(self.optimizer.param_groups): + old_lr = float(param_group["lr"]) + new_lr = max(old_lr * self.factor, self.min_lrs[i]) + if old_lr - new_lr > self.eps: + param_group["lr"] = new_lr + + @property + def in_cooldown(self): # noqa: D102 + return self.cooldown_counter > 0 + + def is_better(self, a, best): # noqa: D102 + if self.mode == "min" and self.threshold_mode == "rel": + rel_epsilon = 1.0 - self.threshold + return a < best * rel_epsilon + + elif self.mode == "min" and self.threshold_mode == "abs": + return a < best - self.threshold + + elif self.mode == "max" and self.threshold_mode == "rel": + rel_epsilon = self.threshold + 1.0 + return a > best * rel_epsilon + + else: # mode == 'max' and epsilon_mode == 'abs': + return a > best + self.threshold + + def _init_is_better(self, mode, threshold, threshold_mode): + if mode not in {"min", "max"}: + raise ValueError("mode " + mode + " is unknown!") + if threshold_mode not in {"rel", "abs"}: + raise ValueError("threshold mode " + threshold_mode + " is unknown!") + + if mode == "min": + self.mode_worse = inf + else: # mode == 'max': + self.mode_worse = -inf + + self.mode = mode + self.threshold = threshold + self.threshold_mode = threshold_mode + + def state_dict(self): # noqa: D102 + return { + key: value for key, value in self.__dict__.items() if key != "optimizer" + } + + def load_state_dict(self, state_dict): + """Load the scheduler's state.""" + self.__dict__.update(state_dict) + self._init_is_better( + mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode + ) + + +class CyclicLR(LRScheduler): + r"""Sets the learning rate of each parameter group according to cyclical learning rate policy (CLR). + + The policy cycles the learning rate between two boundaries with a constant frequency, + as detailed in the paper `Cyclical Learning Rates for Training Neural Networks`_. + The distance between the two boundaries can be scaled on a per-iteration + or per-cycle basis. + + Cyclical learning rate policy changes the learning rate after every batch. + `step` should be called after a batch has been used for training. + + This class has three built-in policies, as put forth in the paper: + + * "triangular": A basic triangular cycle without amplitude scaling. + * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle. + * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}` + at each cycle iteration. + + This implementation was adapted from the github repo: `bckenstler/CLR`_ + + Args: + optimizer (Optimizer): Wrapped optimizer. + base_lr (float or list): Initial learning rate which is the + lower boundary in the cycle for each parameter group. + max_lr (float or list): Upper learning rate boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_lr - base_lr). + The lr at any cycle is the sum of base_lr + and some scaling of the amplitude; therefore + max_lr may not actually be reached depending on + scaling function. + step_size_up (int): Number of training iterations in the + increasing half of a cycle. 
Default: 2000 + step_size_down (int): Number of training iterations in the + decreasing half of a cycle. If step_size_down is None, + it is set to step_size_up. Default: None + mode (str): One of {triangular, triangular2, exp_range}. + Values correspond to policies detailed above. + If scale_fn is not None, this argument is ignored. + Default: 'triangular' + gamma (float): Constant in 'exp_range' scaling function: + gamma**(cycle iterations) + Default: 1.0 + scale_fn (function): Custom scaling policy defined by a single + argument lambda function, where + 0 <= scale_fn(x) <= 1 for all x >= 0. + If specified, then 'mode' is ignored. + Default: None + scale_mode (str): {'cycle', 'iterations'}. + Defines whether scale_fn is evaluated on + cycle number or cycle iterations (training + iterations since start of cycle). + Default: 'cycle' + cycle_momentum (bool): If ``True``, momentum is cycled inversely + to learning rate between 'base_momentum' and 'max_momentum'. + Default: True + base_momentum (float or list): Lower momentum boundaries in the cycle + for each parameter group. Note that momentum is cycled inversely + to learning rate; at the peak of a cycle, momentum is + 'base_momentum' and learning rate is 'max_lr'. + Default: 0.8 + max_momentum (float or list): Upper momentum boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_momentum - base_momentum). + The momentum at any cycle is the difference of max_momentum + and some scaling of the amplitude; therefore + base_momentum may not actually be reached depending on + scaling function. Note that momentum is cycled inversely + to learning rate; at the start of a cycle, momentum is 'max_momentum' + and learning rate is 'base_lr' + Default: 0.9 + last_epoch (int): The index of the last batch. This parameter is used when + resuming a training job. Since `step()` should be invoked after each + batch instead of after each epoch, this number represents the total + number of *batches* computed, not the total number of epochs computed. + When last_epoch=-1, the schedule is started from the beginning. + Default: -1 + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1) + >>> data_loader = torch.utils.data.DataLoader(...) + >>> for epoch in range(10): + >>> for batch in data_loader: + >>> train_batch(...) + >>> scheduler.step() + + + .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186 + .. 
_bckenstler/CLR: https://github.com/bckenstler/CLR + """ + + def __init__( + self, + optimizer: Optimizer, + base_lr: Union[float, List[float]], + max_lr: Union[float, List[float]], + step_size_up=2000, + step_size_down: Optional[int] = None, + mode: Literal["triangular", "triangular2", "exp_range"] = "triangular", + gamma=1.0, + scale_fn: Optional[Callable[[float], float]] = None, + scale_mode: Literal["cycle", "iterations"] = "cycle", + cycle_momentum=True, + base_momentum=0.8, + max_momentum=0.9, + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") + self.optimizer = optimizer + + base_lrs = _format_param("base_lr", optimizer, base_lr) + if last_epoch == -1: + for lr, group in zip(base_lrs, optimizer.param_groups): + if isinstance(group["lr"], Tensor): + lr_val = lr.item() if isinstance(lr, Tensor) else lr + group["lr"].fill_(lr_val) + else: + group["lr"] = lr + + self.max_lrs = _format_param("max_lr", optimizer, max_lr) + + step_size_up = float(step_size_up) + step_size_down = ( + float(step_size_down) if step_size_down is not None else step_size_up + ) + self.total_size = step_size_up + step_size_down + self.step_ratio = step_size_up / self.total_size + + if mode not in ["triangular", "triangular2", "exp_range"] and scale_fn is None: + raise ValueError("mode is invalid and scale_fn is None") + + self.mode = mode + self.gamma = gamma + + self._scale_fn_ref: Callable[[float], float] + self._scale_fn_custom = scale_fn + self.scale_mode = scale_mode + self._init_scale_fn() + + self.cycle_momentum = cycle_momentum + if cycle_momentum: + if ( + "momentum" not in optimizer.defaults + and "betas" not in optimizer.defaults + ): + raise ValueError( + "optimizer must support momentum or beta1 with `cycle_momentum` option enabled" + ) + + self.use_beta1 = "betas" in self.optimizer.defaults + self.base_momentums = _format_param( + "base_momentum", optimizer, base_momentum + ) + self.max_momentums = _format_param("max_momentum", optimizer, max_momentum) + if last_epoch == -1: + for m_momentum, b_momentum, group in zip( + self.max_momentums, self.base_momentums, optimizer.param_groups + ): + if self.use_beta1: + group["betas"] = (m_momentum, *group["betas"][1:]) + else: + group["momentum"] = m_momentum + group["max_momentum"] = m_momentum + group["base_momentum"] = b_momentum + + super().__init__(optimizer, last_epoch, verbose) + self.base_lrs = base_lrs + + def _init_scale_fn(self): + if self._scale_fn_custom is not None: + return + if self.mode == "triangular": + self._scale_fn_ref = self._triangular_scale_fn + self.scale_mode = "cycle" + elif self.mode == "triangular2": + self._scale_fn_ref = self._triangular2_scale_fn + self.scale_mode = "cycle" + elif self.mode == "exp_range": + self._scale_fn_ref = partial(self._exp_range_scale_fn, self.gamma) + self.scale_mode = "iterations" + + def scale_fn(self, x) -> float: + """Get the scaling policy.""" + if self._scale_fn_custom is not None: + return self._scale_fn_custom(x) + else: + return self._scale_fn_ref(x) # static method + + @staticmethod + def _triangular_scale_fn(x: float) -> float: + return 1.0 + + @staticmethod + def _triangular2_scale_fn(x: float) -> float: + return 1 / (2.0 ** (x - 1)) + + @staticmethod + def _exp_range_scale_fn(gamma: float, x: float) -> float: + return gamma**x + + def get_lr(self): + """Calculate the learning rate at batch index. 
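        For example, with ``step_size_up=4`` and ``step_size_down=4`` (so ``total_size=8``)
        in ``'triangular'`` mode, batch index 2 gives ``cycle=1``, ``x=0.25`` and
        ``scale_factor=0.5``, so the returned lr sits halfway between ``base_lr`` and
        ``max_lr`` on the rising edge of the first cycle.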
+ + This function treats `self.last_epoch` as the last batch index. + + If `self.cycle_momentum` is ``True``, this function has a side effect of + updating the optimizer's momentum. + """ + _warn_get_lr_called_within_step(self) + + cycle = math.floor(1 + self.last_epoch / self.total_size) + x = 1.0 + self.last_epoch / self.total_size - cycle + if x <= self.step_ratio: + scale_factor = x / self.step_ratio + else: + scale_factor = (x - 1) / (self.step_ratio - 1) + + lrs = [] + for base_lr, max_lr in zip(self.base_lrs, self.max_lrs): + base_height = (max_lr - base_lr) * scale_factor + if self.scale_mode == "cycle": + lr = base_lr + base_height * self.scale_fn(cycle) + else: + lr = base_lr + base_height * self.scale_fn(self.last_epoch) + lrs.append(lr) + + if self.cycle_momentum: + momentums = [] + for base_momentum, max_momentum in zip( + self.base_momentums, self.max_momentums + ): + base_height = (max_momentum - base_momentum) * scale_factor + if self.scale_mode == "cycle": + momentum = max_momentum - base_height * self.scale_fn(cycle) + else: + momentum = max_momentum - base_height * self.scale_fn( + self.last_epoch + ) + momentums.append(momentum) + for param_group, momentum in zip(self.optimizer.param_groups, momentums): + if self.use_beta1: + param_group["betas"] = (momentum, *param_group["betas"][1:]) + else: + param_group["momentum"] = momentum + + return lrs + + def state_dict(self): # noqa: D102 + state = super().state_dict() + # We are dropping the `_scale_fn_ref` attribute because it is a + # `weakref.WeakMethod` and can't be pickled. + state.pop("_scale_fn_ref", None) + fn = state.pop("_scale_fn_custom") + state["_scale_fn_custom"] = None + if fn is not None and not isinstance(fn, types.FunctionType): + # The _scale_fn_custom will only be saved if it is a callable object + # and not if it is a function or lambda. + state["_scale_fn_custom"] = fn.__dict__.copy() + + return state + + def load_state_dict(self, state_dict): + """Load the scheduler's state.""" + fn = state_dict.pop("_scale_fn_custom") + super().load_state_dict(state_dict) + if fn is not None: + self._scale_fn_custom.__dict__.update(fn) + self._init_scale_fn() + + +class CosineAnnealingWarmRestarts(LRScheduler): + r"""Set the learning rate of each parameter group using a cosine annealing schedule. + + The :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` + is the number of epochs since the last restart and :math:`T_{i}` is the number + of epochs between two warm restarts in SGDR: + + .. math:: + \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right) + + When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`. + When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`. + + It has been proposed in + `SGDR: Stochastic Gradient Descent with Warm Restarts`_. + + Args: + optimizer (Optimizer): Wrapped optimizer. + T_0 (int): Number of iterations until the first restart. + T_mult (int, optional): A factor by which :math:`T_{i}` increases after a restart. Default: 1. + eta_min (float, optional): Minimum learning rate. Default: 0. + last_epoch (int, optional): The index of the last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + .. 
_SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + """ + + def __init__( + self, + optimizer: Optimizer, + T_0: int, + T_mult=1, + eta_min=0.0, + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + if T_0 <= 0 or not isinstance(T_0, int): + raise ValueError(f"Expected positive integer T_0, but got {T_0}") + if T_mult < 1 or not isinstance(T_mult, int): + raise ValueError(f"Expected integer T_mult >= 1, but got {T_mult}") + if not isinstance(eta_min, (float, int)): + raise ValueError( + f"Expected float or int eta_min, but got {eta_min} of type {type(eta_min)}" + ) + self.T_0 = T_0 + self.T_i = T_0 + self.T_mult = T_mult + self.eta_min = eta_min + self.T_cur = last_epoch + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + """Compute the initial learning rate.""" + _warn_get_lr_called_within_step(self) + + return [ + self.eta_min + + (base_lr - self.eta_min) + * (1 + math.cos(math.pi * self.T_cur / self.T_i)) + / 2 + for base_lr in self.base_lrs + ] + + def step(self, epoch=None): + """Step could be called after every batch update. + + Example: + >>> # xdoctest: +SKIP("Undefined vars") + >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult) + >>> iters = len(dataloader) + >>> for epoch in range(20): + >>> for i, sample in enumerate(dataloader): + >>> inputs, labels = sample['inputs'], sample['labels'] + >>> optimizer.zero_grad() + >>> outputs = net(inputs) + >>> loss = criterion(outputs, labels) + >>> loss.backward() + >>> optimizer.step() + >>> scheduler.step(epoch + i / iters) + + This function can be called in an interleaved way. + + Example: + >>> # xdoctest: +SKIP("Undefined vars") + >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult) + >>> for epoch in range(20): + >>> scheduler.step() + >>> scheduler.step(26) + >>> scheduler.step() # scheduler.step(27), instead of scheduler(20) + """ + if epoch is None and self.last_epoch < 0: + epoch = 0 + + if epoch is None: + epoch = self.last_epoch + 1 + self.T_cur = self.T_cur + 1 + if self.T_cur >= self.T_i: + self.T_cur = self.T_cur - self.T_i + self.T_i = self.T_i * self.T_mult + else: + if epoch < 0: + raise ValueError(f"Expected non-negative epoch, but got {epoch}") + if epoch >= self.T_0: + if self.T_mult == 1: + self.T_cur = epoch % self.T_0 + else: + n = int( + math.log( + (epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult + ) + ) + self.T_cur = epoch - self.T_0 * (self.T_mult**n - 1) / ( + self.T_mult - 1 + ) + self.T_i = self.T_0 * self.T_mult ** (n) + else: + self.T_i = self.T_0 + self.T_cur = epoch + self.last_epoch = math.floor(epoch) + + with _enable_get_lr_call(self): + for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())): + param_group, lr = data + param_group["lr"] = lr + + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + + +class _SchedulePhase(TypedDict): + end_step: float + start_lr: str + end_lr: str + start_momentum: str + end_momentum: str + + +class OneCycleLR(LRScheduler): + r"""Sets the learning rate of each parameter group according to the 1cycle learning rate policy. + + The 1cycle policy anneals the learning rate from an initial learning rate to some maximum + learning rate and then from that maximum learning rate to some minimum learning rate much + lower than the initial learning rate. + This policy was initially described in the paper `Super-Convergence: + Very Fast Training of Neural Networks Using Large Learning Rates`_. 
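    A minimal construction sketch (``model`` is assumed; the two equivalent ways of
    fixing the cycle length are detailed below):

    >>> # xdoctest: +SKIP
    >>> optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
    >>> scheduler = OneCycleLR(optimizer, max_lr=0.01, total_steps=500)
    >>> # equivalently: OneCycleLR(optimizer, max_lr=0.01, epochs=10, steps_per_epoch=50)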
+ + The 1cycle learning rate policy changes the learning rate after every batch. + `step` should be called after a batch has been used for training. + + This scheduler is not chainable. + + Note also that the total number of steps in the cycle can be determined in one + of two ways (listed in order of precedence): + + #. A value for total_steps is explicitly provided. + #. A number of epochs (epochs) and a number of steps per epoch + (steps_per_epoch) are provided. + In this case, the number of total steps is inferred by + total_steps = epochs * steps_per_epoch + + You must either provide a value for total_steps or provide a value for both + epochs and steps_per_epoch. + + The default behaviour of this scheduler follows the fastai implementation of 1cycle, which + claims that "unpublished work has shown even better results by using only two phases". To + mimic the behaviour of the original paper instead, set ``three_phase=True``. + + Args: + optimizer (Optimizer): Wrapped optimizer. + max_lr (float or list): Upper learning rate boundaries in the cycle + for each parameter group. + total_steps (int): The total number of steps in the cycle. Note that + if a value is not provided here, then it must be inferred by providing + a value for epochs and steps_per_epoch. + Default: None + epochs (int): The number of epochs to train for. This is used along + with steps_per_epoch in order to infer the total number of steps in the cycle + if a value for total_steps is not provided. + Default: None + steps_per_epoch (int): The number of steps per epoch to train for. This is + used along with epochs in order to infer the total number of steps in the + cycle if a value for total_steps is not provided. + Default: None + pct_start (float): The percentage of the cycle (in number of steps) spent + increasing the learning rate. + Default: 0.3 + anneal_strategy (str): {'cos', 'linear'} + Specifies the annealing strategy: "cos" for cosine annealing, "linear" for + linear annealing. + Default: 'cos' + cycle_momentum (bool): If ``True``, momentum is cycled inversely + to learning rate between 'base_momentum' and 'max_momentum'. + Default: True + base_momentum (float or list): Lower momentum boundaries in the cycle + for each parameter group. Note that momentum is cycled inversely + to learning rate; at the peak of a cycle, momentum is + 'base_momentum' and learning rate is 'max_lr'. + Default: 0.85 + max_momentum (float or list): Upper momentum boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_momentum - base_momentum). + Note that momentum is cycled inversely + to learning rate; at the start of a cycle, momentum is 'max_momentum' + and learning rate is 'base_lr' + Default: 0.95 + div_factor (float): Determines the initial learning rate via + initial_lr = max_lr/div_factor + Default: 25 + final_div_factor (float): Determines the minimum learning rate via + min_lr = initial_lr/final_div_factor + Default: 1e4 + three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the + learning rate according to 'final_div_factor' instead of modifying the second + phase (the first two phases will be symmetrical about the step indicated by + 'pct_start'). + last_epoch (int): The index of the last batch. This parameter is used when + resuming a training job. Since `step()` should be invoked after each + batch instead of after each epoch, this number represents the total + number of *batches* computed, not the total number of epochs computed. 
+ When last_epoch=-1, the schedule is started from the beginning. + Default: -1 + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + + Example: + >>> # xdoctest: +SKIP + >>> data_loader = torch.utils.data.DataLoader(...) + >>> optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9) + >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10) + >>> for epoch in range(10): + >>> for batch in data_loader: + >>> train_batch(...) + >>> optimizer.step() + >>> scheduler.step() + + + .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates: + https://arxiv.org/abs/1708.07120 + """ + + def __init__( + self, + optimizer: Optimizer, + max_lr: Union[float, List[float]], + total_steps: Optional[int] = None, + epochs: Optional[int] = None, + steps_per_epoch: Optional[int] = None, + pct_start=0.3, + anneal_strategy: Literal["cos", "linear"] = "cos", + cycle_momentum=True, + base_momentum: Union[float, List[float]] = 0.85, + max_momentum: Union[float, List[float]] = 0.95, + div_factor=25.0, + final_div_factor=1e4, + three_phase=False, + last_epoch=-1, + verbose="deprecated", + ): # noqa: D107 + # Validate optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") + self.optimizer = optimizer + + # Validate total_steps + if total_steps is not None: + if total_steps <= 0 or not isinstance(total_steps, int): + raise ValueError( + f"Expected positive integer total_steps, but got {total_steps}" + ) + self.total_steps = total_steps + elif epochs is not None and steps_per_epoch is not None: + if not isinstance(epochs, int) or epochs <= 0: + raise ValueError(f"Expected positive integer epochs, but got {epochs}") + if not isinstance(steps_per_epoch, int) or steps_per_epoch <= 0: + raise ValueError( + f"Expected positive integer steps_per_epoch, but got {steps_per_epoch}" + ) + self.total_steps = epochs * steps_per_epoch + else: + raise ValueError( + "You must define either total_steps OR (epochs AND steps_per_epoch)" + ) + + self._schedule_phases: List[_SchedulePhase] + if three_phase: + self._schedule_phases = [ + { + "end_step": float(pct_start * self.total_steps) - 1, + "start_lr": "initial_lr", + "end_lr": "max_lr", + "start_momentum": "max_momentum", + "end_momentum": "base_momentum", + }, + { + "end_step": float(2 * pct_start * self.total_steps) - 2, + "start_lr": "max_lr", + "end_lr": "initial_lr", + "start_momentum": "base_momentum", + "end_momentum": "max_momentum", + }, + { + "end_step": self.total_steps - 1, + "start_lr": "initial_lr", + "end_lr": "min_lr", + "start_momentum": "max_momentum", + "end_momentum": "max_momentum", + }, + ] + else: + self._schedule_phases = [ + { + "end_step": float(pct_start * self.total_steps) - 1, + "start_lr": "initial_lr", + "end_lr": "max_lr", + "start_momentum": "max_momentum", + "end_momentum": "base_momentum", + }, + { + "end_step": self.total_steps - 1, + "start_lr": "max_lr", + "end_lr": "min_lr", + "start_momentum": "base_momentum", + "end_momentum": "max_momentum", + }, + ] + + # Validate pct_start + if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float): + raise ValueError( + f"Expected float between 0 and 1 pct_start, but got {pct_start}" + ) + + # Validate anneal_strategy + if anneal_strategy not in ["cos", 
"linear"]: + raise ValueError( + f"anneal_strategy must be one of 'cos' or 'linear', instead got {anneal_strategy}" + ) + else: + self._anneal_func_type = anneal_strategy + + # Initialize learning rate variables + max_lrs = _format_param("max_lr", self.optimizer, max_lr) + if last_epoch == -1: + for idx, group in enumerate(self.optimizer.param_groups): + group["initial_lr"] = max_lrs[idx] / div_factor + group["max_lr"] = max_lrs[idx] + group["min_lr"] = group["initial_lr"] / final_div_factor + + # Initialize momentum variables + self.cycle_momentum = cycle_momentum + if self.cycle_momentum: + if ( + "momentum" not in self.optimizer.defaults + and "betas" not in self.optimizer.defaults + ): + raise ValueError( + "optimizer must support momentum or beta1 with `cycle_momentum` option enabled" + ) + self.use_beta1 = "betas" in self.optimizer.defaults + max_momentums = _format_param("max_momentum", optimizer, max_momentum) + base_momentums = _format_param("base_momentum", optimizer, base_momentum) + if last_epoch == -1: + for m_momentum, b_momentum, group in zip( + max_momentums, base_momentums, optimizer.param_groups + ): + if self.use_beta1: + group["betas"] = (m_momentum, *group["betas"][1:]) + else: + group["momentum"] = m_momentum + group["max_momentum"] = m_momentum + group["base_momentum"] = b_momentum + + super().__init__(optimizer, last_epoch, verbose) + + def _anneal_func(self, *args, **kwargs): + if hasattr(self, "_anneal_func_type"): + if self._anneal_func_type == "cos": + return self._annealing_cos(*args, **kwargs) + elif self._anneal_func_type == "linear": + return self._annealing_linear(*args, **kwargs) + else: + raise ValueError(f"Unknown _anneal_func_type: {self._anneal_func_type}") + else: + # For BC + return self.anneal_func(*args, **kwargs) # type: ignore[attr-defined] + + @staticmethod + def _annealing_cos(start, end, pct): + """Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0.""" + cos_out = math.cos(math.pi * pct) + 1 + return end + (start - end) / 2.0 * cos_out + + @staticmethod + def _annealing_linear(start, end, pct): + """Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0.""" + return (end - start) * pct + start + + def get_lr(self): + """Compute the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + lrs = [] + step_num = self.last_epoch + + if step_num > self.total_steps: + raise ValueError( + f"Tried to step {step_num} times. 
The specified number of total steps is {self.total_steps}" # noqa: UP032 + ) + + for group in self.optimizer.param_groups: + start_step = 0.0 + for i, phase in enumerate(self._schedule_phases): + end_step = phase["end_step"] + if step_num <= end_step or i == len(self._schedule_phases) - 1: + pct = (step_num - start_step) / (end_step - start_step) + computed_lr = self._anneal_func( + group[phase["start_lr"]], group[phase["end_lr"]], pct + ) + if self.cycle_momentum: + computed_momentum = self._anneal_func( + group[phase["start_momentum"]], + group[phase["end_momentum"]], + pct, + ) + break + start_step = phase["end_step"] + + lrs.append(computed_lr) # type: ignore[possibly-undefined] + if self.cycle_momentum: + if self.use_beta1: + group["betas"] = (computed_momentum, *group["betas"][1:]) # type: ignore[possibly-undefined] + else: + group[ + "momentum" + ] = computed_momentum # type: ignore[possibly-undefined] + + return lrs diff --git a/lib/python3.10/site-packages/torch/optim/nadam.py b/lib/python3.10/site-packages/torch/optim/nadam.py new file mode 100644 index 0000000000000000000000000000000000000000..e26b3bf302587fcdf7df9c1616551cf1687b740c --- /dev/null +++ b/lib/python3.10/site-packages/torch/optim/nadam.py @@ -0,0 +1,649 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +r"""Implementation for the NAdam algorithm.""" +from typing import cast, List, Optional, Tuple, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _foreach_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _get_value, + _maximize_doc, + _stack_if_compiling, + _use_grad_for_differentiable, + _view_as_real, + Optimizer, + ParamsT, +) + + +__all__ = ["NAdam", "nadam"] + + +class NAdam(Optimizer): # noqa: D101 + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 2e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0, + momentum_decay: float = 4e-3, + decoupled_weight_decay: bool = False, + *, + foreach: Optional[bool] = None, + maximize: bool = False, + capturable: bool = False, + differentiable: bool = False, + ): # noqa: D107 + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if not 0.0 <= momentum_decay: + raise ValueError(f"Invalid momentum_decay value: {momentum_decay}") + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + momentum_decay=momentum_decay, + decoupled_weight_decay=decoupled_weight_decay, + maximize=maximize, + foreach=foreach, + capturable=capturable, + differentiable=differentiable, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): # noqa: D105 + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("maximize", False) + group.setdefault("foreach", None) + group.setdefault("capturable", False) + group.setdefault("differentiable", False) + group.setdefault("decoupled_weight_decay", False) + for p in 
group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0: + if not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, dtype=_get_scalar_dtype(), device=p.device + ) + if group["capturable"] + else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + if not torch.is_tensor(p_state["mu_product"]): + mu_prod_val = p_state["mu_product"] + p_state["mu_product"] = ( + torch.tensor( + mu_prod_val, dtype=_get_scalar_dtype(), device=p.device + ) + if group["capturable"] + else torch.tensor(mu_prod_val, dtype=_get_scalar_dtype()) + ) + + def _init_group( + self, + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + mu_products, + state_steps, + ): + has_complex = False + for p in group["params"]: + if p.grad is not None: + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError("NAdam does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + # Lazy state initialization + if len(state) == 0: + # note(crcrpar): [special device hosting for step] + # Deliberately host `step` and `mu_product` on CPU if capturable is False. + # This is because kernel launches are costly on CUDA and XLA. + state["step"] = ( + torch.zeros((), dtype=_get_scalar_dtype(), device=p.device) + if group["capturable"] + else torch.tensor(0.0, dtype=_get_scalar_dtype()) + ) + state["mu_product"] = ( + torch.ones((), dtype=_get_scalar_dtype(), device=p.device) + if group["capturable"] + else torch.tensor(1.0, dtype=_get_scalar_dtype()) + ) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + mu_products.append(state["mu_product"]) + state_steps.append(state["step"]) + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_avg_sqs: List[Tensor] = [] + mu_products: List[Tensor] = [] + state_steps: List[Tensor] = [] + beta1, beta2 = cast(Tuple[float, float], group["betas"]) + + has_complex = self._init_group( + group, + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + mu_products, + state_steps, + ) + + nadam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + mu_products, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + momentum_decay=group["momentum_decay"], + eps=group["eps"], + maximize=group["maximize"], + decoupled_weight_decay=group["decoupled_weight_decay"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + has_complex=has_complex, + ) + + return loss + + +NAdam.__doc__ = ( + r"""Implements NAdam algorithm. + + .. 
math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma_t \text{ (lr)}, \: \beta_1,\beta_2 \text{ (betas)}, + \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\ + &\hspace{13mm} \: \lambda \text{ (weight decay)}, \:\psi \text{ (momentum decay)} \\ + &\hspace{13mm} \: \textit{decoupled\_weight\_decay}, \:\textit{maximize} \\ + &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, + v_0 \leftarrow 0 \text{ ( second moment)} \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ + &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} \\ + &\hspace{5mm} \textbf{if} \: \lambda \neq 0 \\ + &\hspace{10mm}\textbf{if} \: \textit{decoupled\_weight\_decay} \\ + &\hspace{15mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\ + &\hspace{10mm}\textbf{else} \\ + &\hspace{15mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm} \mu_t \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{t \psi} \big) \\ + &\hspace{5mm} \mu_{t+1} \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{(t+1)\psi}\big)\\ + &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ + &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ + &\hspace{5mm}\widehat{m_t} \leftarrow \mu_{t+1} m_t/(1-\prod_{i=1}^{t+1}\mu_i)\\[-1.ex] + & \hspace{11mm} + (1-\mu_t) g_t /(1-\prod_{i=1}^{t} \mu_{i}) \\ + &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ + &\hspace{5mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `Incorporating Nesterov Momentum into Adam`_. + """ + + rf""" + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, Tensor, optional): learning rate (default: 2e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + momentum_decay (float, optional): momentum momentum_decay (default: 4e-3) + decoupled_weight_decay (bool, optional): whether to use decoupled weight + decay as in AdamW to obtain NAdamW (default: False) + {_foreach_doc} + {_maximize_doc} + {_capturable_doc} + {_differentiable_doc} + + .. _Incorporating Nesterov Momentum into Adam: + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ + .. 
_Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + + """ +) + + +def _single_tensor_nadam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + mu_products: List[Tensor], + state_steps: List[Tensor], + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + momentum_decay: float, + eps: float, + decoupled_weight_decay: bool, + maximize: bool, + capturable: bool, + differentiable: bool, + has_complex: bool, +): + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + mu_product = mu_products[i] + step_t = state_steps[i] + + if torch.is_complex(param): + param = torch.view_as_real(param) + grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + exp_avg_sq = torch.view_as_real(exp_avg_sq) + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type == mu_product.device.type == step_t.device.type + and param.device.type in capturable_supported_devices + ), ( + f"If capturable=True, params, mu_products and state_steps must be " + f"on supported devices: {capturable_supported_devices}." + ) + + # update step + step_t += 1 + + if capturable: + step = step_t + else: + step = _get_value(step_t) + + bias_correction2 = 1 - beta2**step + + if weight_decay != 0: + if decoupled_weight_decay: + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + else: + grad = grad.add(param, alpha=weight_decay) + + # calculate the momentum cache \mu^{t} and \mu^{t+1} + mu = beta1 * (1.0 - 0.5 * (0.96 ** (step * momentum_decay))) + mu_next = beta1 * (1.0 - 0.5 * (0.96 ** ((step + 1) * momentum_decay))) + + # update mu_product + mu_product *= mu + + # decay the first and second moment running average coefficient + exp_avg.lerp_(grad, 1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = exp_avg_sq.div(bias_correction2).sqrt() + + if differentiable or capturable: + denom = denom.add(eps) + # Make autograd track the operations + # by updating the grad and exp_avg directly and not using the + # scalar "value" argument of addcdiv. 
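            # Concretely, the two addcdiv_ calls below apply
            #     param -= lr * ( (1 - mu) * grad    / (1 - mu_product)
            #                   + mu_next  * exp_avg / (1 - mu_product * mu_next) ) / denom
            # i.e. the \hat{m_t} / (\sqrt{\hat{v_t}} + eps) update spelled out in the
            # NAdam docstring, split into two terms so autograd can track both.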
+ mu_product_next = mu_product * mu_next + grad = grad * (-lr * (1.0 - mu) / (1.0 - mu_product)) + exp_avg = exp_avg * (-lr * mu_next / (1.0 - mu_product_next)) + param.addcdiv_(grad, denom) + param.addcdiv_(exp_avg, denom) + else: + mu_product_next = _get_value(mu_product) * mu_next + denom.add_(eps) + param.addcdiv_( + grad, denom, value=(-lr * (1.0 - mu) / (1.0 - _get_value(mu_product))) + ) + param.addcdiv_( + exp_avg, denom, value=(-lr * mu_next) / (1.0 - mu_product_next) + ) + + +def _multi_tensor_nadam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + mu_products: List[Tensor], + state_steps: List[Tensor], + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + momentum_decay: float, + eps: float, + decoupled_weight_decay: bool, + maximize: bool, + capturable: bool, + differentiable: bool, + has_complex: bool, +): + if len(params) == 0: + return + + assert not differentiable, "_foreach ops don't support autograd" + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == mp.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, mp, step in zip(params, mu_products, state_steps) + ), f"If capturable=True, params, mu_products, and state_steps must be on supported devices: {capturable_supported_devices}." + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps] # type: ignore[list-item] + ) + for ( + grouped_params_, + grouped_grads_, + grouped_exp_avgs_, + grouped_exp_avg_sqs_, + grouped_mu_products_, + grouped_state_steps_, + ), _ in grouped_tensors.values(): + grouped_params = cast(List[Tensor], grouped_params_) + grouped_grads = cast(List[Tensor], grouped_grads_) + grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_) + grouped_exp_avg_sqs = cast(List[Tensor], grouped_exp_avg_sqs_) + grouped_mu_products = cast(List[Tensor], grouped_mu_products_) + grouped_state_steps = cast(List[Tensor], grouped_state_steps_) + + # handle complex + if has_complex: + _view_as_real( + grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_avg_sqs + ) + + if maximize: + grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. 
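        # (Passing a one-element CPU tensor together with `alpha` selects the
        # _foreach_add_(TensorList, Tensor, alpha) overload, so the constant is wrapped
        # into a tensor once here instead of once per parameter in the fallback loop.)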
+ if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu: + torch._foreach_add_( + grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(grouped_state_steps, 1) + + if weight_decay != 0: + if decoupled_weight_decay: + # Perform stepweight decay + torch._foreach_mul_(grouped_params, 1 - lr * weight_decay) + else: + # Re-use the intermediate memory (grouped_grads) already allocated for maximize + if maximize: + torch._foreach_add_( + grouped_grads, grouped_params, alpha=weight_decay + ) + else: + grouped_grads = torch._foreach_add( # type: ignore[assignment] + grouped_grads, grouped_params, alpha=weight_decay + ) + + # Decay the first and second moment running average coefficient + torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1) + + torch._foreach_mul_(grouped_exp_avg_sqs, beta2) + torch._foreach_addcmul_( + grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2 + ) + + exp_avg_sq_sqrt = torch._foreach_sqrt(grouped_exp_avg_sqs) + + bias_correction_sqrt: Union[Tuple[Tensor, ...], List[Tensor]] + mus: Union[Tuple[Tensor, ...], List[Tensor]] + mu_nexts: Union[Tuple[Tensor, ...], List[Tensor]] + if capturable: + # mus will be beta1 * (1 - 0.5 * 0.96 ** (step * momentum_decay)) + exponent = torch._foreach_mul(grouped_state_steps, momentum_decay) + mus = torch._foreach_pow(0.96, exponent) + torch._foreach_mul_(mus, -0.5) + torch._foreach_add_(mus, 1.0) + torch._foreach_mul_(mus, beta1) + + # mu_nexts will be beta1 * (1 - 0.5 * 0.96 ** ((step + 1) * momentum_decay)) + torch._foreach_add_(exponent, momentum_decay) + mu_nexts = torch._foreach_pow(0.96, exponent) + torch._foreach_mul_(mu_nexts, -0.5) + torch._foreach_add_(mu_nexts, 1.0) + torch._foreach_mul_(mu_nexts, beta1) + + # save peak memory as we don't need exponent anymore + del exponent + + bias_correction_sqrt = torch._foreach_pow(beta2, grouped_state_steps) + # foreach_sub doesn't allow a scalar as the first arg + torch._foreach_sub_(bias_correction_sqrt, 1.0) + torch._foreach_neg_(bias_correction_sqrt) + torch._foreach_sqrt_(bias_correction_sqrt) + else: + bias_correction_sqrt = [ + (1 - beta2 ** _get_value(step)) ** 0.5 for step in grouped_state_steps + ] + mus = [ + beta1 * (1.0 - 0.5 * (0.96 ** (_get_value(step) * momentum_decay))) + for step in grouped_state_steps + ] + mu_nexts = [ + beta1 + * (1.0 - 0.5 * (0.96 ** ((_get_value(step) + 1) * momentum_decay))) + for step in grouped_state_steps + ] + + # update mu_products + torch._foreach_mul_(grouped_mu_products, mus) + + torch._foreach_div_(exp_avg_sq_sqrt, bias_correction_sqrt) + torch._foreach_add_(exp_avg_sq_sqrt, eps) + + # explicitly delete bias_correction refs to save memory + del bias_correction_sqrt + + if capturable: + # Build up the step_size multiplier for grad, reusing mus' memory + torch._foreach_sub_(mus, 1.0) + torch._foreach_mul_(mus, lr) + # foreach_sub doesn't allow a scalar as the first arg + denom = torch._foreach_sub(grouped_mu_products, 1.0) + torch._foreach_neg_(denom) + torch._foreach_div_(mus, denom) + # - lr * (1 - mu) / (1 - mu_product) + step_size_grads = mus + # explicitly delete denom to save memory + del denom + + # Build up the step_size multiplier for exp_avg, reusing mu_nexts' memory + denom = torch._foreach_mul(grouped_mu_products, mu_nexts) + torch._foreach_mul_(mu_nexts, lr) + # foreach_sub doesn't allow a scalar as the first arg, but it's okay because + # we need a negative here anyway + torch._foreach_sub_(denom, 1.0) + torch._foreach_div_(mu_nexts, denom) + # - lr * 
mu_next / (1 - mu_product * mu_next) + step_size_expavg = mu_nexts + # explicitly delete denom to save memory + del denom + + # we cannot inplace into step_size_grads cuz it is a list of ScalarTensors + # and mul'ing with grouped_grads will result in a list of bigger Tensors + numerator = torch._foreach_mul(step_size_grads, grouped_grads) + torch._foreach_addcmul_(numerator, step_size_expavg, grouped_exp_avgs) + + # finally, update params + torch._foreach_addcdiv_(grouped_params, numerator, exp_avg_sq_sqrt) + else: + step_size_grads = _stack_if_compiling( + [ + (_get_value(lr) * (1.0 - mu) / (1.0 - _get_value(mu_product))) * -1 + for mu_product, mu in zip(grouped_mu_products, mus) + ] + ) + step_size_expavg = _stack_if_compiling( + [ + ( + _get_value(lr) + * mu_next + / (1.0 - _get_value(mu_product) * mu_next) + ) + * -1 + for mu_product, mu_next in zip(grouped_mu_products, mu_nexts) + ] + ) + + torch._foreach_addcdiv_( + grouped_params, grouped_grads, exp_avg_sq_sqrt, step_size_grads # type: ignore[arg-type] + ) + torch._foreach_addcdiv_( + grouped_params, grouped_exp_avgs, exp_avg_sq_sqrt, step_size_expavg # type: ignore[arg-type] + ) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_nadam) +def nadam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + mu_products: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + decoupled_weight_decay: bool = False, + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + has_complex: bool = False, + maximize: bool = False, + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + momentum_decay: float, + eps: float, +): + r"""Functional API that performs NAdam algorithm computation. + + See :class:`~torch.optim.NAdam` for details. 
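    A self-contained single-step sketch (shapes and hyperparameters are illustrative only):

    >>> # xdoctest: +SKIP
    >>> w, g = torch.zeros(3), torch.ones(3)
    >>> nadam([w], [g], [torch.zeros(3)], [torch.zeros(3)],
    ...       [torch.tensor(1.0)], [torch.tensor(0.0)],
    ...       beta1=0.9, beta2=0.999, lr=2e-3, weight_decay=0.0,
    ...       momentum_decay=4e-3, eps=1e-8)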
+ """ + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + if not all(isinstance(t, torch.Tensor) for t in mu_products): + raise RuntimeError( + "API has changed, `mu_products` argument must contain a list of singleton tensors" + ) + + if foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_nadam + else: + func = _single_tensor_nadam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + mu_products, + state_steps, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + momentum_decay=momentum_decay, + maximize=maximize, + decoupled_weight_decay=decoupled_weight_decay, + eps=eps, + capturable=capturable, + differentiable=differentiable, + has_complex=has_complex, + ) diff --git a/lib/python3.10/site-packages/torch/optim/optimizer.py b/lib/python3.10/site-packages/torch/optim/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..8f7993842c1009d0565d94a8d6d3187694fa236b --- /dev/null +++ b/lib/python3.10/site-packages/torch/optim/optimizer.py @@ -0,0 +1,1052 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +"""Base optimizer.""" +import functools +import warnings +from collections import defaultdict, OrderedDict +from copy import deepcopy +from itertools import chain +from typing import ( + Any, + Callable, + cast, + DefaultDict, + Dict, + Hashable, + Iterable, + List, + Optional, + overload, + Set, + Tuple, + TypeVar, + Union, +) +from typing_extensions import ParamSpec, Self, TypeAlias + +import torch +import torch.utils.hooks as hooks +from torch._utils import is_compiling +from torch.utils._foreach_utils import ( + _get_foreach_kernels_supported_devices, + _get_fused_kernels_supported_devices, + _group_tensors_by_device_and_dtype, + Indices, + TensorListList, +) +from torch.utils.hooks import RemovableHandle + + +Args: TypeAlias = Tuple[Any, ...] +Kwargs: TypeAlias = Dict[str, Any] +StateDict: TypeAlias = Dict[str, Any] +DeviceDict = Dict[Optional[torch.device], torch.Tensor] + + +GlobalOptimizerPreHook: TypeAlias = Callable[ + ["Optimizer", Args, Kwargs], Optional[Tuple[Args, Kwargs]] +] +GlobalOptimizerPostHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], None] + +__all__ = [ + "Optimizer", + "register_optimizer_step_pre_hook", + "register_optimizer_step_post_hook", +] +_global_optimizer_pre_hooks: Dict[int, GlobalOptimizerPreHook] = OrderedDict() +_global_optimizer_post_hooks: Dict[int, GlobalOptimizerPostHook] = OrderedDict() +_foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter] + + +class _RequiredParameter: + """Singleton class representing a required parameter for an Optimizer.""" + + def __repr__(self) -> str: + return "" + + +required = _RequiredParameter() + + +def _use_grad_for_differentiable(func): + def _use_grad(self, *args, **kwargs): + import torch._dynamo + + prev_grad = torch.is_grad_enabled() + try: + # Note on graph break below: + # we need to graph break to ensure that aot respects the no_grad annotation. 
+ # This is important for perf because without this, functionalization will generate an epilogue + # which updates the mutated parameters of the optimizer which is *not* visible to inductor, as a result, + # inductor will allocate for every parameter in the model, which is horrible. + # With this, aot correctly sees that this is an inference graph, and functionalization will generate + # an epilogue which is appended to the graph, which *is* visible to inductor, as a result, inductor sees that + # step is in place and is able to avoid the extra allocation. + # In the future, we will either 1) continue to graph break on backward, so this graph break does not matter + # or 2) have a fully fused forward and backward graph, which will have no_grad by default, and we can remove this + # graph break to allow the fully fused fwd-bwd-optimizer graph to be compiled. + # see https://github.com/pytorch/pytorch/issues/104053 + torch.set_grad_enabled(self.defaults["differentiable"]) + torch._dynamo.graph_break() + ret = func(self, *args, **kwargs) + finally: + torch._dynamo.graph_break() + torch.set_grad_enabled(prev_grad) + return ret + + functools.update_wrapper(_use_grad, func) + return _use_grad + + +def _get_value(x): + # item is significantly faster than a cpu tensor in eager mode + if not torch.jit.is_scripting() and is_compiling(): + return x + else: + return x.item() if isinstance(x, torch.Tensor) else x + + +def _stack_if_compiling(x): + if not torch.jit.is_scripting() and is_compiling(): + return torch.stack(x) + else: + return x + + +def _disable_dynamo_if_unsupported(single_tensor_fn=None): + # workaround for torchscript BC + # it requires all called functions to be in the + # global environment at the site at which the + # maybe_fallback closure is created + if single_tensor_fn: + globals()[single_tensor_fn.__name__] = single_tensor_fn + + def wrapper(func): + import inspect + + disabled_func = torch._disable_dynamo(func) + ps = inspect.signature(func).parameters + has_state_steps = True + try: + state_steps_ind = list(ps.keys()).index("state_steps") + except ValueError: + has_state_steps = False + + # Today, there are cases where we stack state steps + # and pass them as the value arg of foreach ops. + # Having state steps on cuda as the value arg is not supported in eager, + # but this only occurs in the rare case that the user explicitly deletes + # the capturable flag. If capturable=True, this is not a problem. + @functools.wraps(func) + def maybe_fallback(*args, **kwargs): + if is_compiling() and ( + not kwargs.get("capturable", False) + and has_state_steps + and (args[state_steps_ind] and args[state_steps_ind][0].is_cuda) + or ( + "state_steps" in kwargs + and kwargs["state_steps"] + and kwargs["state_steps"][0].is_cuda + ) + ): + return disabled_func(*args, **kwargs) + else: + return func(*args, **kwargs) + + return maybe_fallback + + return wrapper + + +# For any optimizer with a faster implementation, we attempt to default to the +# fastest + stablest whenever possible. For foreach, the requirements are to have +# native params all on CUDA. For fused, there's currently the additional requirement +# that the tensors' dtypes must be floating point. Neither alternative supports +# torch.jit.script nor differentiable, so we fall back to the single tensor +# implementation in those cases. 
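+# Editorial note (illustrative, not upstream code): optimizers in this diff call
+# the helper below when the user leaves ``foreach=None``, mirroring e.g. nadam():
+#
+#     _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
+#
+# The first return value (fused) is only True when use_fused=True and every
+# param is a floating-point Tensor on a fused-supported device; scripting or
+# differentiable=True forces (False, False), i.e. the single-tensor fallback.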
+def _default_to_fused_or_foreach( + params: List[torch.Tensor], differentiable: bool, use_fused: bool = False +) -> Tuple[bool, bool]: + if torch.jit.is_scripting() or differentiable: + return False, False + + fused_supported_devices = _get_fused_kernels_supported_devices() + foreach_supported_devices = _get_foreach_kernels_supported_devices() + fused = use_fused and all( + p is None + or ( + type(p) in _foreach_supported_types + and p.device.type in fused_supported_devices + and torch.is_floating_point(p) + ) + for p in params + ) + foreach = not fused and all( + p is None + or ( + type(p) in _foreach_supported_types + and p.device.type in foreach_supported_devices + ) + for p in params + ) + return fused, foreach + + +def _device_dtype_check_for_fused( + p: torch.Tensor, cuda_unsupported: bool = False +) -> None: + fused_supported_devices = _get_fused_kernels_supported_devices() + if cuda_unsupported: + fused_supported_devices.remove("cuda") + if not (p.device.type in fused_supported_devices and torch.is_floating_point(p)): + raise RuntimeError( + "`fused=True` requires all the params to be floating point Tensors of " + f"supported devices: {fused_supported_devices} but {p.dtype} and {p.device.type}" + ) + + +def _view_as_real(params, *state_and_grads): + for i, p in enumerate(params): + if torch.is_complex(p): + params[i] = torch.view_as_real(params[i]) + for s in state_and_grads: + s[i] = torch.view_as_real(s[i]) + + +def _get_scalar_dtype(is_fused=None): + if is_fused: + return torch.float32 + return ( + torch.float64 if torch.get_default_dtype() == torch.float64 else torch.float32 + ) + + +def _get_capturable_supported_devices(supports_xla: bool = True) -> List[str]: + r"""Return the device type list that supports capturable optimizer.""" + capturable_supported_devices = ["cuda", "xpu", "hpu"] + if not torch.jit.is_scripting(): + capturable_supported_devices.append(torch._C._get_privateuse1_backend_name()) + if supports_xla: + capturable_supported_devices.append("xla") + return capturable_supported_devices + + +# Common doc strings among optimizers +_foreach_doc = r"""foreach (bool, optional): whether foreach implementation of optimizer + is used. If unspecified by the user (so foreach is None), we will try to use + foreach over the for-loop implementation on CUDA, since it is usually + significantly more performant. Note that the foreach implementation uses + ~ sizeof(params) more peak memory than the for-loop version due to the intermediates + being a tensorlist vs just one tensor. If memory is prohibitive, batch fewer + parameters through the optimizer at a time or switch this flag to False (default: None)""" + +_fused_doc = r"""fused (bool, optional): whether the fused implementation is used. + Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16` + are supported. (default: None) + + .. note:: The foreach and fused implementations are typically faster than the for-loop, + single-tensor implementation, with fused being theoretically fastest with both + vertical and horizontal fusion. As such, if the user has not specified either + flag (i.e., when foreach = fused = None), we will attempt defaulting to the foreach + implementation when the tensors are all on CUDA. Why not fused? Since the fused + implementation is relatively new, we want to give it sufficient bake-in time. + To specify fused, pass True for fused. To force running the for-loop + implementation, pass False for either foreach or fused. 
""" + +_capturable_doc = r"""capturable (bool, optional): whether this instance is safe to + capture in a CUDA graph. Passing True can impair ungraphed performance, + so if you don't intend to graph capture this instance, leave it False + (default: False)""" + +_differentiable_doc = r"""differentiable (bool, optional): whether autograd should + occur through the optimizer step in training. Otherwise, the step() + function runs in a torch.no_grad() context. Setting to True can impair + performance, so leave it False if you don't intend to run autograd + through this instance (default: False)""" + +_maximize_doc = r"""maximize (bool, optional): maximize the objective with respect to the + params, instead of minimizing (default: False)""" + + +def register_optimizer_step_pre_hook(hook: GlobalOptimizerPreHook) -> RemovableHandle: + r"""Register a pre hook common to all optimizers. + + The hook should have the following signature:: + + hook(optimizer, args, kwargs) -> None or modified args and kwargs + + Args: + hook (Callable): A user defined hook which is registered on all optimizers. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(_global_optimizer_pre_hooks) + _global_optimizer_pre_hooks[handle.id] = hook + return handle + + +def register_optimizer_step_post_hook(hook: GlobalOptimizerPostHook) -> RemovableHandle: + r"""Register a post hook common to all optimizers. + + The hook should have the following signature:: + + hook(optimizer, args, kwargs) -> None + + Args: + hook (Callable): A user defined hook which is registered on all optimizers. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(_global_optimizer_post_hooks) + _global_optimizer_post_hooks[handle.id] = hook + return handle + + +ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]] + +_P = ParamSpec("_P") +R = TypeVar("R") +T = TypeVar("T") + + +class Optimizer: + r"""Base class for all optimizers. + + .. warning:: + Parameters need to be specified as collections that have a deterministic + ordering that is consistent between runs. Examples of objects that don't + satisfy those properties are sets and iterators over values of dictionaries. + + Args: + params (iterable): an iterable of :class:`torch.Tensor` s or + :class:`dict` s. Specifies what Tensors should be optimized. + defaults: (dict): a dict containing default values of optimization + options (used when a parameter group doesn't specify them). 
+ """ + + OptimizerPreHook: TypeAlias = Callable[[Self, Args, Kwargs], Optional[Tuple[Args, Kwargs]]] # type: ignore[misc] + OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None] # type: ignore[misc] + + _optimizer_step_pre_hooks: Dict[int, OptimizerPreHook] + _optimizer_step_post_hooks: Dict[int, OptimizerPostHook] + _optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]' + _optimizer_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' + _optimizer_load_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' + _optimizer_load_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]' + + def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None: # noqa: D107 + torch._C._log_api_usage_once("python.optimizer") + self.defaults = defaults + self._optimizer_step_pre_hooks = OrderedDict() + self._optimizer_step_post_hooks = OrderedDict() + self._optimizer_state_dict_pre_hooks = OrderedDict() + self._optimizer_state_dict_post_hooks = OrderedDict() + self._optimizer_load_state_dict_pre_hooks = OrderedDict() + self._optimizer_load_state_dict_post_hooks = OrderedDict() + + self._patch_step_function() + + if isinstance(params, torch.Tensor): + raise TypeError( + "params argument given to the optimizer should be " + "an iterable of Tensors or dicts, but got " + torch.typename(params) + ) + + self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict) + self.param_groups: List[Dict[str, Any]] = [] + + param_groups = list(params) + if len(param_groups) == 0: + raise ValueError("optimizer got an empty parameter list") + if not isinstance(param_groups[0], dict): + param_groups = [{"params": param_groups}] + + for param_group in param_groups: + self.add_param_group(cast(dict, param_group)) + + # Allows _cuda_graph_capture_health_check to rig a poor man's TORCH_WARN_ONCE in python, + # which I don't think exists + # https://github.com/pytorch/pytorch/issues/72948 + self._warned_capturable_if_run_uncaptured = True + + def __getstate__(self) -> Dict[str, Any]: # noqa: D105 + return { + "defaults": self.defaults, + "state": self.state, + "param_groups": self.param_groups, + } + + def __setstate__(self, state: Dict[str, Any]) -> None: # noqa: D105 + self.__dict__.update(state) + if "_optimizer_step_pre_hooks" not in self.__dict__: + self._optimizer_step_pre_hooks = OrderedDict() + if "_optimizer_step_post_hooks" not in self.__dict__: + self._optimizer_step_post_hooks = OrderedDict() + if "_optimizer_state_dict_pre_hooks" not in self.__dict__: + self._optimizer_state_dict_pre_hooks = OrderedDict() + if "_optimizer_state_dict_post_hooks" not in self.__dict__: + self._optimizer_state_dict_post_hooks = OrderedDict() + if "_optimizer_load_state_dict_pre_hooks" not in self.__dict__: + self._optimizer_load_state_dict_pre_hooks = OrderedDict() + if "_optimizer_load_state_dict_post_hooks" not in self.__dict__: + self._optimizer_load_state_dict_post_hooks = OrderedDict() + self._patch_step_function() # To support multiprocessing pickle/unpickle + self.defaults.setdefault("differentiable", False) + + def __repr__(self) -> str: # noqa: D105 + format_string = self.__class__.__name__ + " (" + for i, group in enumerate(self.param_groups): + format_string += "\n" + format_string += f"Parameter Group {i}\n" + for key in sorted(group.keys()): + if key != "params": + format_string += f" {key}: {group[key]}\n" + format_string += ")" + return format_string + + # Currently 
needed by Adam and AdamW + def _cuda_graph_capture_health_check(self) -> None: + # Note [torch.compile x capturable] + # If we are compiling, we try to take the capturable path automatically by + # setting the flag to True during tracing. Due to this, we skip all the checks + # normally required for determining whether we can use CUDA graphs and + # shunt the responsibility to torch.inductor. This saves time during tracing + # since the checks are slow without sacrificing UX since inductor will warn + # later if CUDA graphs cannot be enabled, e.g., + # https://github.com/pytorch/pytorch/blob/d3ba8901d8640eb16f88b2bfef9df7fa383d4b47/torch/_inductor/compile_fx.py#L390. + # Thus, when compiling, inductor will determine if cudagraphs + # can be enabled based on whether there is input mutation or CPU tensors. + if ( + not is_compiling() + and torch.backends.cuda.is_built() + and torch.cuda.is_available() + ): + capturing = torch.cuda.is_current_stream_capturing() + + if capturing and not all( + group["capturable"] for group in self.param_groups + ): + raise RuntimeError( + "Attempting CUDA graph capture of step() for an instance of " + + self.__class__.__name__ + + " but param_groups' capturable is False." + ) + + if ( + (not getattr(self, "_warned_capturable_if_run_uncaptured", False)) + and all(group["capturable"] for group in self.param_groups) + and (not capturing) + ): + warnings.warn( + "This instance was constructed with capturable=True or some of all the param_groups came with capturable=True, " + "but step() is running without CUDA graph capture. If you never intend to graph-capture this " + "instance, capturable=True can impair performance, and you should set capturable=False." + ) + self._warned_capturable_if_run_uncaptured = True + + def _optimizer_step_code(self) -> None: + """Entry point for `torch.profile.profiler`. + + When python tracing is enabled the profiler will hook into this + function at the CPython level to inspect the optimizer's parameters and + param groups. It is called it after `step()` since many optimizers + lazily initialize state. + + This is a workaround due to lack of a proper step hook on the optimizer, + and will be removed if it exists. + """ + + @staticmethod + def profile_hook_step(func: Callable[_P, R]) -> Callable[_P, R]: # noqa: D102 + @functools.wraps(func) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> R: + self, *_ = args + self = cast(Optimizer, self) + profile_name = f"Optimizer.step#{self.__class__.__name__}.step" + with torch.autograd.profiler.record_function(profile_name): + # call optimizer step pre hooks + for pre_hook in chain( + _global_optimizer_pre_hooks.values(), + self._optimizer_step_pre_hooks.values(), + ): + result = pre_hook(self, args, kwargs) + if result is not None: + if isinstance(result, tuple) and len(result) == 2: + args, kwargs = result # type: ignore[assignment] + else: + raise RuntimeError( + f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}." 
+ ) + + out = func(*args, **kwargs) + self._optimizer_step_code() + + # call optimizer step post hooks + for post_hook in chain( + self._optimizer_step_post_hooks.values(), + _global_optimizer_post_hooks.values(), + ): + post_hook(self, args, kwargs) + + return out + + return wrapper + + @staticmethod + def _group_tensors_by_device_and_dtype( + tensorlistlist: TensorListList, + with_indices: bool = False, + ) -> Union[ + Dict[Tuple[None, None], Tuple[TensorListList, Indices]], + Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]], + ]: + """Group a list of lists of tensors by device and dtype. + + Skips this step if we are compiling since this will occur during inductor lowering. + """ + if is_compiling(): + return {(None, None): (tensorlistlist, list(range(len(tensorlistlist[0]))))} + else: + return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices) # type: ignore[return-value, arg-type] + + def _patch_step_function(self) -> None: + self._zero_grad_profile_name = ( + f"Optimizer.zero_grad#{self.__class__.__name__}.zero_grad" + ) + hooked = getattr(self.__class__.step, "hooked", None) + if not hooked: + self.__class__.step = self.profile_hook_step(self.__class__.step) # type: ignore[assignment] + self.__class__.step.hooked = True # type: ignore[attr-defined] + + def register_step_pre_hook(self, hook: OptimizerPreHook) -> RemovableHandle: + r"""Register an optimizer step pre hook which will be called before optimizer step. + + It should have the following signature:: + + hook(optimizer, args, kwargs) -> None or modified args and kwargs + + The ``optimizer`` argument is the optimizer instance being used. If + args and kwargs are modified by the pre-hook, then the transformed + values are returned as a tuple containing the new_args and new_kwargs. + + Args: + hook (Callable): The user defined hook to be registered. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_step_pre_hooks) + self._optimizer_step_pre_hooks[handle.id] = hook + return handle + + def register_step_post_hook(self, hook: OptimizerPostHook) -> RemovableHandle: + r"""Register an optimizer step post hook which will be called after optimizer step. + + It should have the following signature:: + + hook(optimizer, args, kwargs) -> None + + The ``optimizer`` argument is the optimizer instance being used. + + Args: + hook (Callable): The user defined hook to be registered. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_step_post_hooks) + self._optimizer_step_post_hooks[handle.id] = hook + return handle + + def register_state_dict_pre_hook( + self, hook: Callable[["Optimizer"], None], prepend: bool = False + ) -> RemovableHandle: # noqa: D101 + r"""Register a state dict pre-hook which will be called before :meth:`~torch.optim.Optimizer.state_dict` is called. + + It should have the following signature:: + + hook(optimizer) -> None + + The ``optimizer`` argument is the optimizer instance being used. + The hook will be called with argument ``self`` before calling ``state_dict`` on ``self``. + The registered hook can be used to perform pre-processing before the ``state_dict`` + call is made. + + Args: + hook (Callable): The user defined hook to be registered. 
+ prepend (bool): If True, the provided pre ``hook`` will be fired before + all the already registered pre-hooks on ``state_dict``. Otherwise, + the provided ``hook`` will be fired after all the already registered + pre-hooks. (default: False) + + Returns: + :class:`torch.utils.hooks.RemoveableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_state_dict_pre_hooks) + self._optimizer_state_dict_pre_hooks[handle.id] = hook + if prepend: + self._optimizer_state_dict_pre_hooks.move_to_end(handle.id, last=False) + return handle + + def register_state_dict_post_hook( + self, + hook: Callable[["Optimizer", StateDict], Optional[StateDict]], + prepend: bool = False, + ) -> RemovableHandle: + r"""Register a state dict post-hook which will be called after :meth:`~torch.optim.Optimizer.state_dict` is called. + + It should have the following signature:: + + hook(optimizer, state_dict) -> state_dict or None + + The hook will be called with arguments ``self`` and ``state_dict`` after generating + a ``state_dict`` on ``self``. The hook may modify the state_dict inplace or optionally + return a new one. The registered hook can be used to perform post-processing + on the ``state_dict`` before it is returned. + + Args: + hook (Callable): The user defined hook to be registered. + prepend (bool): If True, the provided post ``hook`` will be fired before + all the already registered post-hooks on ``state_dict``. Otherwise, + the provided ``hook`` will be fired after all the already registered + post-hooks. (default: False) + + Returns: + :class:`torch.utils.hooks.RemoveableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_state_dict_post_hooks) + self._optimizer_state_dict_post_hooks[handle.id] = hook + if prepend: + self._optimizer_state_dict_post_hooks.move_to_end(handle.id, last=False) + return handle + + @torch._disable_dynamo + def state_dict(self) -> StateDict: + r"""Return the state of the optimizer as a :class:`dict`. + + It contains two entries: + + * ``state``: a Dict holding current optimization state. Its content + differs between optimizer classes, but some common characteristics + hold. For example, state is saved per parameter, and the parameter + itself is NOT saved. ``state`` is a Dictionary mapping parameter ids + to a Dict with state corresponding to each parameter. + * ``param_groups``: a List containing all parameter groups where each + parameter group is a Dict. Each parameter group contains metadata + specific to the optimizer, such as learning rate and weight decay, + as well as a List of parameter IDs of the parameters in the group. + + NOTE: The parameter IDs may look like indices but they are just IDs + associating state with param_group. When loading from a state_dict, + the optimizer will zip the param_group ``params`` (int IDs) and the + optimizer ``param_groups`` (actual ``nn.Parameter`` s) in order to + match state WITHOUT additional verification. + + A returned state dict might look something like: + + .. code-block:: text + + { + 'state': { + 0: {'momentum_buffer': tensor(...), ...}, + 1: {'momentum_buffer': tensor(...), ...}, + 2: {'momentum_buffer': tensor(...), ...}, + 3: {'momentum_buffer': tensor(...), ...} + }, + 'param_groups': [ + { + 'lr': 0.01, + 'weight_decay': 0, + ... + 'params': [0] + }, + { + 'lr': 0.001, + 'weight_decay': 0.5, + ... 
+ 'params': [1, 2, 3] + } + ] + } + + """ + for pre_hook in self._optimizer_state_dict_pre_hooks.values(): + pre_hook(self) + + # Save order indices instead of Tensors + param_mappings: Dict[int, int] = {} + start_index = 0 + + def pack_group(group: Dict[str, Any]) -> Dict[str, Any]: + nonlocal start_index + packed = {k: v for k, v in group.items() if k != "params"} + param_mappings.update( + { + id(p): i + for i, p in enumerate(group["params"], start_index) + if id(p) not in param_mappings + } + ) + packed["params"] = [param_mappings[id(p)] for p in group["params"]] + start_index += len(packed["params"]) + return packed + + param_groups = [pack_group(g) for g in self.param_groups] + # Remap state to use order indices as keys + packed_state = { + (param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v + for k, v in self.state.items() + } + + state_dict = { + "state": packed_state, + "param_groups": param_groups, + } + + for post_hook in self._optimizer_state_dict_post_hooks.values(): + hook_result = post_hook(self, state_dict) + if hook_result is not None: + state_dict = hook_result + return state_dict + + @staticmethod + def _process_value_according_to_param_policy( + param: torch.Tensor, + value: torch.Tensor, + param_id: int, + param_groups: List[Dict[Any, Any]], + key: Hashable = None, + ) -> torch.Tensor: + # Floating-point types are a bit special here. They are the only ones + # that are assumed to always match the type of params. + # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424 + # UNLESS fused or capturable, see note [special device hosting for step] + fused = False + capturable = False + assert param_groups is not None + for pg in param_groups: + if param_id in pg["params"]: + fused = pg["fused"] if "fused" in pg else False + capturable = pg["capturable"] if "capturable" in pg else False + break + if key == "step": + if capturable or fused: + return value.to(dtype=torch.float32, device=param.device) + else: + return value + else: + if param.is_floating_point(): + return value.to(dtype=param.dtype, device=param.device) + else: + return value.to(device=param.device) + + def register_load_state_dict_pre_hook( + self, + hook: Callable[["Optimizer", StateDict], Optional[StateDict]], + prepend: bool = False, + ) -> RemovableHandle: # noqa: D205 D400 + r"""Register a load_state_dict pre-hook which will be called before + :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the + following signature:: + + hook(optimizer, state_dict) -> state_dict or None + + The ``optimizer`` argument is the optimizer instance being used and the + ``state_dict`` argument is a shallow copy of the ``state_dict`` the user + passed in to ``load_state_dict``. The hook may modify the state_dict inplace + or optionally return a new one. If a state_dict is returned, it will be used + to be loaded into the optimizer. + + The hook will be called with argument ``self`` and ``state_dict`` before + calling ``load_state_dict`` on ``self``. The registered hook can be used to + perform pre-processing before the ``load_state_dict`` call is made. + + Args: + hook (Callable): The user defined hook to be registered. + prepend (bool): If True, the provided pre ``hook`` will be fired before + all the already registered pre-hooks on ``load_state_dict``. Otherwise, + the provided ``hook`` will be fired after all the already registered + pre-hooks. 
(default: False) + + Returns: + :class:`torch.utils.hooks.RemoveableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_load_state_dict_pre_hooks) + self._optimizer_load_state_dict_pre_hooks[handle.id] = hook + if prepend: + self._optimizer_load_state_dict_pre_hooks.move_to_end(handle.id, last=False) + return handle + + def register_load_state_dict_post_hook( + self, hook: Callable[["Optimizer"], None], prepend: bool = False + ) -> RemovableHandle: # noqa: D205 D400 + r"""Register a load_state_dict post-hook which will be called after + :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the + following signature:: + + hook(optimizer) -> None + + The ``optimizer`` argument is the optimizer instance being used. + + The hook will be called with argument ``self`` after calling + ``load_state_dict`` on ``self``. The registered hook can be used to + perform post-processing after ``load_state_dict`` has loaded the + ``state_dict``. + + Args: + hook (Callable): The user defined hook to be registered. + prepend (bool): If True, the provided post ``hook`` will be fired before + all the already registered post-hooks on ``load_state_dict``. Otherwise, + the provided ``hook`` will be fired after all the already registered + post-hooks. (default: False) + + Returns: + :class:`torch.utils.hooks.RemoveableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._optimizer_load_state_dict_post_hooks) + self._optimizer_load_state_dict_post_hooks[handle.id] = hook + if prepend: + self._optimizer_load_state_dict_post_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] + return handle + + @torch._disable_dynamo + def load_state_dict(self, state_dict: StateDict) -> None: + r"""Load the optimizer state. + + Args: + state_dict (dict): optimizer state. Should be an object returned + from a call to :meth:`state_dict`. 
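+
+        Example (editorial sketch, not part of the upstream docstring; SGD is
+        used only for illustration)::
+
+            import torch
+
+            p = torch.randn(3, requires_grad=True)
+            opt = torch.optim.SGD([p], lr=0.1, momentum=0.9)
+            p.sum().backward()
+            opt.step()                    # creates momentum buffers in opt.state
+            saved = opt.state_dict()
+
+            opt2 = torch.optim.SGD([p], lr=0.1, momentum=0.9)
+            opt2.load_state_dict(saved)   # restores state into a matching optimizer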
+ """ + # shallow copy, to be consistent with module API + state_dict = state_dict.copy() + + for pre_hook in self._optimizer_load_state_dict_pre_hooks.values(): + hook_result = pre_hook(self, state_dict) + if hook_result is not None: + state_dict = hook_result + + # Validate the state_dict + groups = self.param_groups + + # Deepcopy as we write into saved_groups later to update state + saved_groups = deepcopy(state_dict["param_groups"]) + + if len(groups) != len(saved_groups): + raise ValueError( + "loaded state dict has a different number of " "parameter groups" + ) + param_lens = (len(g["params"]) for g in groups) + saved_lens = (len(g["params"]) for g in saved_groups) + if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): + raise ValueError( + "loaded state dict contains a parameter group " + "that doesn't match the size of optimizer's group" + ) + + # Update the state + id_map = dict( + zip( + chain.from_iterable(g["params"] for g in saved_groups), + chain.from_iterable(g["params"] for g in groups), + ) + ) + + def _cast(param, value, param_id=None, param_groups=None, key=None): + r"""Make a deep copy of value, casting all tensors to device of param.""" + if isinstance(value, torch.Tensor): + return Optimizer._process_value_according_to_param_policy( + param, value, param_id, param_groups, key + ) + elif isinstance(value, dict): + return { + k: _cast( + param, v, param_id=param_id, param_groups=param_groups, key=k + ) + for k, v in value.items() + } + elif isinstance(value, Iterable): + return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) # type: ignore[call-arg] + else: + return value + + # Copy state assigned to params (and cast tensors to appropriate types). + # State that is not assigned to params is copied as is (needed for + # backward compatibility). + state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict) + for k, v in state_dict["state"].items(): + if k in id_map: + param = id_map[k] + state[param] = _cast( + param, v, param_id=k, param_groups=state_dict["param_groups"] + ) + else: + state[k] = v + + # Update parameter groups, setting their 'params' value + def update_group( + group: Dict[str, Any], new_group: Dict[str, Any] + ) -> Dict[str, Any]: + new_group["params"] = group["params"] + return new_group + + param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] + self.__setstate__({"state": state, "param_groups": param_groups}) + + for post_hook in self._optimizer_load_state_dict_post_hooks.values(): + post_hook(self) + + @torch._disable_dynamo + def zero_grad(self, set_to_none: bool = True) -> None: + r"""Reset the gradients of all optimized :class:`torch.Tensor` s. + + Args: + set_to_none (bool): instead of setting to zero, set the grads to None. + This will in general have lower memory footprint, and can modestly improve performance. + However, it changes certain behaviors. For example: + 1. When the user tries to access a gradient and perform manual ops on it, + a None attribute or a Tensor full of 0s will behave differently. + 2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s + are guaranteed to be None for params that did not receive a gradient. + 3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None + (in one case it does the step with a gradient of 0 and in the other it skips + the step altogether). 
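+
+        Example (editorial sketch, not part of the upstream docstring; SGD is
+        used only for illustration)::
+
+            import torch
+
+            p = torch.randn(3, requires_grad=True)
+            opt = torch.optim.SGD([p], lr=0.1)
+            p.sum().backward()
+            opt.step()
+            opt.zero_grad()                     # default: p.grad becomes None
+            # opt.zero_grad(set_to_none=False)  # would instead zero p.grad in place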
+ """ + foreach = self.defaults.get("foreach", False) or self.defaults.get( + "fused", False + ) + + if not hasattr(self, "_zero_grad_profile_name"): + self._patch_step_function() + + per_device_and_dtype_grads: Optional[ + DefaultDict[torch.device, DefaultDict[torch.dtype, List[torch.Tensor]]] + ] + if foreach: + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) + else: + per_device_and_dtype_grads = None + + with torch.autograd.profiler.record_function(self._zero_grad_profile_name): + for group in self.param_groups: + for p in group["params"]: + if p.grad is not None: + if set_to_none: + p.grad = None + else: + if p.grad.grad_fn is not None: + p.grad.detach_() + else: + p.grad.requires_grad_(False) + if not foreach or p.grad.is_sparse: + p.grad.zero_() + else: + assert per_device_and_dtype_grads is not None + per_device_and_dtype_grads[p.grad.device][ + p.grad.dtype + ].append(p.grad) + if foreach: + assert per_device_and_dtype_grads is not None + for per_dtype_grads in per_device_and_dtype_grads.values(): + for grads in per_dtype_grads.values(): + torch._foreach_zero_(grads) + + @overload + def step(self, closure: None = ...) -> None: + ... + + @overload + def step(self, closure: Callable[[], float]) -> float: + ... + + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: + r"""Perform a single optimization step to update parameter. + + Args: + closure (Callable): A closure that reevaluates the model and + returns the loss. Optional for most optimizers. + + .. note:: + Unless otherwise specified, this function should not modify the + ``.grad`` field of the parameters. + """ + raise NotImplementedError + + @torch._disable_dynamo + def add_param_group(self, param_group: Dict[str, Any]) -> None: + r"""Add a param group to the :class:`Optimizer` s `param_groups`. + + This can be useful when fine tuning a pre-trained network as frozen layers can be made + trainable and added to the :class:`Optimizer` as training progresses. + + Args: + param_group (dict): Specifies what Tensors should be optimized along with group + specific optimization options. + """ + if not isinstance(param_group, dict): + raise TypeError(f"param_group must be a dict, but got {type(param_group)}") + + params = param_group["params"] + if isinstance(params, torch.Tensor): + param_group["params"] = [params] + elif isinstance(params, set): + raise TypeError( + "optimizer parameters need to be organized in ordered collections, but " + "the ordering of tensors in sets will change between runs. Please use a list instead." 
+ ) + else: + param_group["params"] = list(params) + + for param in param_group["params"]: + if not isinstance(param, torch.Tensor): + raise TypeError( + "optimizer can only optimize Tensors, " + "but one of the params is " + torch.typename(param) + ) + if not self.defaults.get("differentiable", None) and not ( + param.is_leaf or param.retains_grad + ): + raise ValueError("can't optimize a non-leaf Tensor") + + for name, default in self.defaults.items(): + if default is required and name not in param_group: + raise ValueError( + f"parameter group didn't specify a value of required optimization parameter {name}" + ) + else: + param_group.setdefault(name, default) + + params = param_group["params"] + if len(params) != len(set(params)): + warnings.warn( + "optimizer contains a parameter group with duplicate parameters; " + "in future, this will cause an error; " + "see github.com/pytorch/pytorch/issues/40967 for more information", + stacklevel=3, + ) + + param_set: Set[torch.Tensor] = set() + for group in self.param_groups: + param_set.update(set(group["params"])) + + if not param_set.isdisjoint(set(param_group["params"])): + raise ValueError("some parameters appear in more than one parameter group") + + self.param_groups.append(param_group) diff --git a/lib/python3.10/site-packages/torch/optim/radam.py b/lib/python3.10/site-packages/torch/optim/radam.py new file mode 100644 index 0000000000000000000000000000000000000000..a2d0c31a91736554e72c0664f5bc1ffbd85c3b75 --- /dev/null +++ b/lib/python3.10/site-packages/torch/optim/radam.py @@ -0,0 +1,608 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +r"""Implementation for the RAdam algorithm.""" +from typing import cast, List, Optional, Tuple, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _foreach_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _get_value, + _maximize_doc, + _use_grad_for_differentiable, + _view_as_real, + Optimizer, + ParamsT, +) + + +__all__ = ["RAdam", "radam"] + + +class RAdam(Optimizer): # noqa: D101 + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0, + decoupled_weight_decay: bool = False, + *, + foreach: Optional[bool] = None, + maximize: bool = False, + capturable: bool = False, + differentiable: bool = False, + ): # noqa: D107 + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + maximize=maximize, + foreach=foreach, + capturable=capturable, + decoupled_weight_decay=decoupled_weight_decay, + differentiable=differentiable, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): # noqa: D105 + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("foreach", None) + group.setdefault("maximize", False) + 
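+            # Editorial note: these setdefault calls backfill options that are
+            # missing from state saved by older versions, so unpickling and
+            # load_state_dict keep working across releases.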
group.setdefault("differentiable", False) + group.setdefault("decoupled_weight_decay", False) + group.setdefault("capturable", False) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, dtype=_get_scalar_dtype(), device=p.device + ) + if group["capturable"] + else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + + def _init_group( + self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps + ): + has_complex = False + for p in group["params"]: + if p.grad is not None: + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError("RAdam does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + # Lazy state initialization + if len(state) == 0: + state["step"] = ( + torch.zeros((), dtype=_get_scalar_dtype(), device=p.device) + if group["capturable"] + else torch.tensor(0.0, dtype=_get_scalar_dtype()) + ) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + state_steps.append(state["step"]) + + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_avg_sqs: List[Tensor] = [] + state_steps: List[Tensor] = [] + beta1, beta2 = cast(Tuple[float, float], group["betas"]) + + has_complex = self._init_group( + group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps + ) + + radam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + foreach=group["foreach"], + capturable=group["capturable"], + differentiable=group["differentiable"], + decoupled_weight_decay=group["decoupled_weight_decay"], + has_complex=has_complex, + ) + + return loss + + +RAdam.__doc__ = ( + r"""Implements RAdam algorithm. + + .. 
math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \: \beta_1, \beta_2 + \text{ (betas)}, \: \theta_0 \text{ (params)}, \:f(\theta) \text{ (objective)}, \: + \lambda \text{ (weightdecay)}, \:\textit{maximize} \\ + &\hspace{13mm} \epsilon \text{ (epsilon)}, \textit{decoupled\_weight\_decay} \\ + &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, + v_0 \leftarrow 0 \text{ ( second moment)}, \\ + &\hspace{18mm} \rho_{\infty} \leftarrow 2/(1-\beta_2) -1 \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{6mm}\textbf{if} \: \textit{maximize}: \\ + &\hspace{12mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{6mm}\textbf{else} \\ + &\hspace{12mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{6mm} \theta_t \leftarrow \theta_{t-1} \\ + &\hspace{6mm} \textbf{if} \: \lambda \neq 0 \\ + &\hspace{12mm}\textbf{if} \: \textit{decoupled\_weight\_decay} \\ + &\hspace{18mm} \theta_t \leftarrow \theta_{t} - \gamma \lambda \theta_{t} \\ + &\hspace{12mm}\textbf{else} \\ + &\hspace{18mm} g_t \leftarrow g_t + \lambda \theta_{t} \\ + &\hspace{6mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ + &\hspace{6mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ + &\hspace{6mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ + &\hspace{6mm}\rho_t \leftarrow \rho_{\infty} - + 2 t \beta^t_2 /\big(1-\beta_2^t \big) \\[0.1.ex] + &\hspace{6mm}\textbf{if} \: \rho_t > 5 \\ + &\hspace{12mm} l_t \leftarrow \frac{\sqrt{ (1-\beta^t_2) }}{ \sqrt{v_t} +\epsilon } \\ + &\hspace{12mm} r_t \leftarrow + \sqrt{\frac{(\rho_t-4)(\rho_t-2)\rho_{\infty}}{(\rho_{\infty}-4)(\rho_{\infty}-2) \rho_t}} \\ + &\hspace{12mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t} r_t l_t \\ + &\hspace{6mm}\textbf{else} \\ + &\hspace{12mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t} \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `On the variance of the adaptive learning rate and beyond`_. + + This implementation provides an option to use either the original weight_decay implementation as in Adam + (where the weight_decay is applied to the gradient) or the one from AdamW (where weight_decay is applied + to the weight) through the decoupled_weight_decay option. When decoupled_weight_decay is set to False + (default), it uses the original Adam style weight decay, otherwise, it uses the AdamW style which + corresponds more closely to the `author's implementation`_ in the RAdam paper. Further information + about decoupled weight decay can be found in `Decoupled Weight Decay Regularization`_. + + """ + + rf""" + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, Tensor, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + decoupled_weight_decay (bool, optional): whether to use decoupled weight + decay as in AdamW to obtain RAdamW (default: False) + {_foreach_doc} + {_maximize_doc} + {_differentiable_doc} + {_capturable_doc} + + .. 
_On the variance of the adaptive learning rate and beyond: + https://arxiv.org/abs/1908.03265 + .. _author's implementation: + https://github.com/LiyuanLucasLiu/RAdam + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + + """ +) + + +def _single_tensor_radam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + decoupled_weight_decay: bool, + differentiable: bool, + maximize: bool, + capturable: bool, + has_complex: bool, +): + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type == step_t.device.type + and param.device.type in capturable_supported_devices + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + if torch.is_complex(param): + param = torch.view_as_real(param) + grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + exp_avg_sq = torch.view_as_real(exp_avg_sq) + + # update step + step_t += 1 + step = step_t if capturable else _get_value(step_t) + + if weight_decay != 0: + if decoupled_weight_decay: + param.mul_(1 - lr * weight_decay) + else: + grad = grad.add(param, alpha=weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.lerp_(grad, 1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + # correcting bias for the first moving moment + bias_corrected_exp_avg = exp_avg / bias_correction1 + + # maximum length of the approximated SMA + rho_inf = 2 / (1 - beta2) - 1 + # compute the length of the approximated SMA + rho_t = rho_inf - 2 * step * (beta2**step) / bias_correction2 + + def _compute_rect(): + return ( + (rho_t - 4) + * (rho_t - 2) + * rho_inf + / ((rho_inf - 4) * (rho_inf - 2) * rho_t) + ) ** 0.5 + + def _compute_adaptive_lr(): + exp_avg_sq_sqrt = exp_avg_sq.sqrt() + if differentiable: + exp_avg_sq_sqrt = exp_avg_sq_sqrt.add(eps) + else: + exp_avg_sq_sqrt = exp_avg_sq_sqrt.add_(eps) + + return (bias_correction2**0.5) / exp_avg_sq_sqrt + + # Compute the variance rectification term and update parameters accordingly + if capturable: + update = torch.where( + rho_t > 5.0, _compute_rect() * _compute_adaptive_lr(), 1.0 + ) + param.add_(bias_corrected_exp_avg * lr * update, alpha=-1.0) + else: + if rho_t > 5.0: + param.add_( + bias_corrected_exp_avg + * lr + * _compute_adaptive_lr() + * _compute_rect(), + alpha=-1.0, + ) + else: + param.add_(bias_corrected_exp_avg * lr, alpha=-1.0) + + +def _multi_tensor_radam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + decoupled_weight_decay: bool, + differentiable: bool, + maximize: bool, + capturable: bool, + has_complex: bool, +): + if len(params) == 0: + return + + assert not differentiable, "_foreach ops don't support autograd" + + # If compiling, the compiler will handle cudagraph checks, see note 
[torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices( + supports_xla=False + ) + assert all( + p.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, exp_avgs, exp_avg_sqs, state_steps] # type: ignore[list-item] + ) + for ( + grouped_params_, + grouped_grads_, + grouped_exp_avgs_, + grouped_exp_avg_sqs_, + grouped_state_steps_, + ), _ in grouped_tensors.values(): + grouped_params = cast(List[Tensor], grouped_params_) + grouped_grads = cast(List[Tensor], grouped_grads_) + grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_) + grouped_exp_avg_sqs = cast(List[Tensor], grouped_exp_avg_sqs_) + grouped_state_steps = cast(List[Tensor], grouped_state_steps_) + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu: + torch._foreach_add_( + grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(grouped_state_steps, 1) + + if has_complex: + _view_as_real( + grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_avg_sqs + ) + + if maximize: + grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment] + + # maximum length of the approximated SMA + rho_inf = 2 / (1 - beta2) - 1 + # compute the length of the approximated SMA + bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]] + bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]] + rho_t_list: Union[Tuple[Tensor, ...], List[Tensor]] + if capturable: + bias_correction1 = torch._foreach_pow(beta2, grouped_state_steps) + torch._foreach_neg_(bias_correction1) + torch._foreach_add_(bias_correction1, 1) + bias_correction2 = torch._foreach_pow(beta2, grouped_state_steps) + torch._foreach_mul_(bias_correction2, grouped_state_steps) + torch._foreach_mul_(bias_correction2, 2) + torch._foreach_div_(bias_correction2, bias_correction1) + torch._foreach_neg_(bias_correction2) + torch._foreach_add_(bias_correction2, rho_inf) + rho_t_list = bias_correction2 + else: + rho_t_list = [ + rho_inf + - 2 + * _get_value(step) + * (beta2 ** _get_value(step)) + / (1 - beta2 ** _get_value(step)) + for step in grouped_state_steps + ] + + if weight_decay != 0: + if decoupled_weight_decay: + torch._foreach_mul_(grouped_params, 1 - lr * weight_decay) + else: + # Re-use the intermediate memory (grouped_grads) already allocated for maximize + if maximize: + torch._foreach_add_( + grouped_grads, grouped_params, alpha=weight_decay + ) + else: + grouped_grads = torch._foreach_add( # type: ignore[assignment] + grouped_grads, grouped_params, alpha=weight_decay + ) + + # Decay the first and second moment running average coefficient + torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1) + + torch._foreach_mul_(grouped_exp_avg_sqs, beta2) + torch._foreach_addcmul_( + grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2 + ) + + # Delete the local intermediate since it won't be used anymore to save on peak 
memory + del grouped_grads + + if capturable: + num = torch._foreach_sub(rho_t_list, 4) + sub2 = torch._foreach_sub(rho_t_list, 2) + torch._foreach_mul_(num, sub2) + del sub2 + torch._foreach_mul_(num, rho_inf) + rho_inf = (rho_inf - 4) * (rho_inf - 2) + denom = torch._foreach_mul(rho_t_list, rho_inf) + torch._foreach_div_(num, denom) + del denom + torch._foreach_sqrt_(num) + + # TODO(mlazos): we should try and get a foreach_where op https://github.com/pytorch/pytorch/issues/117884 + rect = [ + torch.where(rho_t > 5.0, n, 0.0) for n, rho_t in zip(num, rho_t_list) + ] + del num + del rho_t_list + unrect_step_size = [torch.where(rect > 0, 0.0, 1.0) for rect in rect] + torch._foreach_mul_(unrect_step_size, lr) + + bias_correction1 = torch._foreach_pow(beta1, grouped_state_steps) + torch._foreach_neg_(bias_correction1) + torch._foreach_add_(bias_correction1, 1) + + torch._foreach_div_(unrect_step_size, bias_correction1) + torch._foreach_neg_(unrect_step_size) + + bias_correction2 = torch._foreach_pow(beta2, grouped_state_steps) + torch._foreach_neg_(bias_correction2) + torch._foreach_add_(bias_correction2, 1) + torch._foreach_sqrt_(bias_correction2) + torch._foreach_mul_(bias_correction2, lr) + torch._foreach_mul_(bias_correction2, rect) + del rect + torch._foreach_neg_(bias_correction2) + torch._foreach_div_(bias_correction2, bias_correction1) + del bias_correction1 + else: + rect = [ + ( + (rho_t - 4) # type: ignore[arg-type] + * (rho_t - 2) + * rho_inf + / ((rho_inf - 4) * (rho_inf - 2) * rho_t) + ) + ** 0.5 + if rho_t > 5 + else 0 + for rho_t in rho_t_list + ] + unrectified = [0 if rect > 0 else 1.0 for rect in rect] + + bias_correction1 = [ + 1 - beta1 ** _get_value(step) for step in grouped_state_steps + ] + unrect_step_size = [ + (lr * rect / bc) * -1 for rect, bc in zip(unrectified, bias_correction1) + ] + bias_correction2 = [ + ((1 - beta2 ** _get_value(step)) ** 0.5) * (lr * rect / bc) * -1 + for step, rect, bc in zip(grouped_state_steps, rect, bias_correction1) + ] + + buffer = torch._foreach_sqrt(grouped_exp_avg_sqs) + torch._foreach_add_(buffer, eps) + torch._foreach_div_(buffer, bias_correction2) + torch._foreach_reciprocal_(buffer) + torch._foreach_add_(buffer, unrect_step_size) + + # Here, buffer = sqrt(1 - beta2^t) * rect_step_size / (sqrt(v) + eps) + unrect_step_size + torch._foreach_addcmul_(grouped_params, grouped_exp_avgs, buffer) + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_radam) +def radam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + decoupled_weight_decay: bool = False, + foreach: Optional[bool] = None, + differentiable: bool = False, + capturable: bool = False, + has_complex: bool = False, + maximize: bool = False, + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, +): + r"""Functional API that performs RAdam algorithm computation. + + See :class:`~torch.optim.RAdam` for details. 
+ """ + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + if foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_radam + else: + func = _single_tensor_radam + + func( + params, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + decoupled_weight_decay=decoupled_weight_decay, + differentiable=differentiable, + capturable=capturable, + has_complex=has_complex, + ) diff --git a/lib/python3.10/site-packages/torch/optim/rmsprop.py b/lib/python3.10/site-packages/torch/optim/rmsprop.py new file mode 100644 index 0000000000000000000000000000000000000000..9b77ad7fe3eeafcac5afa14c5127afc26d3c1a2d --- /dev/null +++ b/lib/python3.10/site-packages/torch/optim/rmsprop.py @@ -0,0 +1,528 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +r"""Implementation for the RMSprop algorithm.""" +from typing import cast, List, Optional, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _foreach_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _maximize_doc, + _use_grad_for_differentiable, + _view_as_real, + Optimizer, + ParamsT, +) + + +__all__ = ["RMSprop", "rmsprop"] + + +class RMSprop(Optimizer): # noqa: D101 + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1e-2, + alpha: float = 0.99, + eps: float = 1e-8, + weight_decay: float = 0, + momentum: float = 0, + centered=False, + capturable=False, + foreach: Optional[bool] = None, + maximize: bool = False, + differentiable: bool = False, + ): # noqa: D107 + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= momentum: + raise ValueError(f"Invalid momentum value: {momentum}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if not 0.0 <= alpha: + raise ValueError(f"Invalid alpha value: {alpha}") + + defaults = dict( + lr=lr, + momentum=momentum, + alpha=alpha, + eps=eps, + centered=centered, + weight_decay=weight_decay, + capturable=capturable, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): # noqa: D105 + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("momentum", 0) + group.setdefault("centered", False) + group.setdefault("foreach", None) + group.setdefault("maximize", False) + group.setdefault("differentiable", False) + group.setdefault("capturable", False) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, dtype=_get_scalar_dtype(), device=p.device + ) + if group["capturable"] + else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + + def _init_group( + self, + 
group, + params_with_grad, + grads, + square_avgs, + momentum_buffer_list, + grad_avgs, + state_steps, + ): + has_complex = False + for p in group["params"]: + if p.grad is None: + continue + has_complex |= torch.is_complex(p) + params_with_grad.append(p) + + if p.grad.is_sparse: + raise RuntimeError("RMSprop does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = ( + torch.zeros((), dtype=_get_scalar_dtype(), device=p.device) + if group["capturable"] + else torch.zeros((), dtype=_get_scalar_dtype()) + ) + state["square_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + if group["momentum"] > 0: + state["momentum_buffer"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + if group["centered"]: + state["grad_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + square_avgs.append(state["square_avg"]) + state_steps.append(state["step"]) + + if group["momentum"] > 0: + momentum_buffer_list.append(state["momentum_buffer"]) + if group["centered"]: + grad_avgs.append(state["grad_avg"]) + + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + square_avgs: List[Tensor] = [] + grad_avgs: List[Tensor] = [] + momentum_buffer_list: List[Tensor] = [] + state_steps: List[Tensor] = [] + + has_complex = self._init_group( + group, + params_with_grad, + grads, + square_avgs, + momentum_buffer_list, + grad_avgs, + state_steps, + ) + + rmsprop( + params_with_grad, + grads, + square_avgs, + grad_avgs, + momentum_buffer_list, + state_steps, + lr=group["lr"], + alpha=group["alpha"], + eps=group["eps"], + weight_decay=group["weight_decay"], + momentum=group["momentum"], + centered=group["centered"], + foreach=group["foreach"], + maximize=group["maximize"], + differentiable=group["differentiable"], + capturable=group["capturable"], + has_complex=has_complex, + ) + + return loss + + +RMSprop.__doc__ = ( + r"""Implements RMSprop algorithm. + + .. 
math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \alpha \text{ (alpha)},\: \gamma \text{ (lr)}, + \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\ + &\hspace{13mm} \lambda \text{ (weight decay)},\: \mu \text{ (momentum)},\: centered\\ + &\textbf{initialize} : v_0 \leftarrow 0 \text{ (square average)}, \: + \textbf{b}_0 \leftarrow 0 \text{ (buffer)}, \: g^{ave}_0 \leftarrow 0 \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}if \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm}v_t \leftarrow \alpha v_{t-1} + (1 - \alpha) g^2_t + \hspace{8mm} \\ + &\hspace{5mm} \tilde{v_t} \leftarrow v_t \\ + &\hspace{5mm}if \: centered \\ + &\hspace{10mm} g^{ave}_t \leftarrow g^{ave}_{t-1} \alpha + (1-\alpha) g_t \\ + &\hspace{10mm} \tilde{v_t} \leftarrow \tilde{v_t} - \big(g^{ave}_{t} \big)^2 \\ + &\hspace{5mm}if \: \mu > 0 \\ + &\hspace{10mm} \textbf{b}_t\leftarrow \mu \textbf{b}_{t-1} + + g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big) \\ + &\hspace{10mm} \theta_t \leftarrow \theta_{t-1} - \gamma \textbf{b}_t \\ + &\hspace{5mm} else \\ + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - + \gamma g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big) \hspace{3mm} \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to + `lecture notes `_ by G. Hinton. + and centered version `Generating Sequences + With Recurrent Neural Networks `_. + The implementation here takes the square root of the gradient average before + adding epsilon (note that TensorFlow interchanges these two operations). The effective + learning rate is thus :math:`\gamma/(\sqrt{v} + \epsilon)` where :math:`\gamma` + is the scheduled learning rate and :math:`v` is the weighted moving average + of the squared gradient. 
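+
+    As a rough, purely illustrative sense of that ordering difference: with
+    :math:`v = 10^{-16}` and :math:`\epsilon = 10^{-8}`, this implementation
+    divides by :math:`\sqrt{v} + \epsilon = 2 \times 10^{-8}`, whereas the
+    TensorFlow-style ordering would divide by
+    :math:`\sqrt{v + \epsilon} \approx 10^{-4}`.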
+ """ + + rf""" + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, Tensor, optional): learning rate (default: 1e-2) + momentum (float, optional): momentum factor (default: 0) + alpha (float, optional): smoothing constant (default: 0.99) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + centered (bool, optional) : if ``True``, compute the centered RMSProp, + the gradient is normalized by an estimation of its variance + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + {_foreach_doc} + {_maximize_doc} + {_capturable_doc} + {_differentiable_doc} + + """ +) + + +def _single_tensor_rmsprop( + params: List[Tensor], + grads: List[Tensor], + square_avgs: List[Tensor], + grad_avgs: List[Tensor], + momentum_buffer_list: List[Tensor], + state_steps: List[Tensor], + *, + lr: float, + alpha: float, + eps: float, + weight_decay: float, + momentum: float, + centered: bool, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + for i, param in enumerate(params): + step = state_steps[i] + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type == step.device.type + and param.device.type in capturable_supported_devices + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + grad = grads[i] + grad = grad if not maximize else -grad + square_avg = square_avgs[i] + + step += 1 + + if weight_decay != 0: + grad = grad.add(param, alpha=weight_decay) + + is_complex_param = torch.is_complex(param) + if is_complex_param: + param = torch.view_as_real(param) + grad = torch.view_as_real(grad) + square_avg = torch.view_as_real(square_avg) + + square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha) + + if centered: + grad_avg = grad_avgs[i] + if is_complex_param: + grad_avg = torch.view_as_real(grad_avg) + grad_avg.lerp_(grad, 1 - alpha) + avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_() + else: + avg = square_avg.sqrt() + + if differentiable: + avg = avg.add(eps) + else: + avg = avg.add_(eps) + + if momentum > 0: + buf = momentum_buffer_list[i] + if is_complex_param: + buf = torch.view_as_real(buf) + buf.mul_(momentum).addcdiv_(grad, avg) + param.add_(buf, alpha=-lr) + else: + param.addcdiv_(grad, avg, value=-lr) + + +def _multi_tensor_rmsprop( + params: List[Tensor], + grads: List[Tensor], + square_avgs: List[Tensor], + grad_avgs: List[Tensor], + momentum_buffer_list: List[Tensor], + state_steps: List[Tensor], + *, + lr: float, + alpha: float, + eps: float, + weight_decay: float, + momentum: float, + centered: bool, + maximize: bool, + differentiable: bool, + capturable: bool, + has_complex: bool, +): + if len(params) == 0: + return + + assert not differentiable, "_foreach ops don't support autograd" + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert all( + p.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." 
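+    # Grouping sketch (illustrative): _group_tensors_by_device_and_dtype buckets
+    # the input lists by (device, dtype) -- e.g. float32 params on cuda:0 and
+    # float32 params on cpu end up in separate entries -- so each foreach_* call
+    # below only ever sees a homogeneous list of tensors.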
+ + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, square_avgs, grad_avgs, momentum_buffer_list, state_steps] # type: ignore[list-item] + ) + for ( + ( + grouped_params_, + grouped_grads_, + grouped_square_avgs_, + grouped_grad_avgs_, + grouped_momentum_buffer_list_, + grouped_state_steps_, + ) + ), _ in grouped_tensors.values(): + grouped_params = cast(List[Tensor], grouped_params_) + grouped_grads = cast(List[Tensor], grouped_grads_) + grouped_square_avgs = cast(List[Tensor], grouped_square_avgs_) + grouped_state_steps = cast(List[Tensor], grouped_state_steps_) + + if has_complex: + state_and_grads = [grouped_grads, grouped_square_avgs] + if momentum > 0: + grouped_momentum_buffer_list = cast( + List[Tensor], grouped_momentum_buffer_list_ + ) + state_and_grads.append(grouped_momentum_buffer_list) + if centered: + grouped_grad_avgs = cast(List[Tensor], grouped_grad_avgs_) + state_and_grads.append(grouped_grad_avgs) + _view_as_real(grouped_params, *state_and_grads) + + if maximize: + grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment] + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu: + torch._foreach_add_( + grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(grouped_state_steps, 1) + + if weight_decay != 0: + # Re-use the intermediate memory (grouped_grads) already allocated for maximize + if maximize: + torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay) + else: + grouped_grads = torch._foreach_add( # type: ignore[assignment] + grouped_grads, grouped_params, alpha=weight_decay + ) + + torch._foreach_mul_(grouped_square_avgs, alpha) + torch._foreach_addcmul_( + grouped_square_avgs, grouped_grads, grouped_grads, value=1 - alpha + ) + + if centered: + grouped_grad_avgs = cast(List[Tensor], grouped_grad_avgs_) + torch._foreach_lerp_(grouped_grad_avgs, grouped_grads, 1 - alpha) + avg = torch._foreach_addcmul( + grouped_square_avgs, grouped_grad_avgs, grouped_grad_avgs, value=-1 + ) + torch._foreach_sqrt_(avg) + torch._foreach_add_(avg, eps) + else: + avg = torch._foreach_sqrt(grouped_square_avgs) + torch._foreach_add_(avg, eps) + + if momentum > 0: + grouped_momentum_buffer_list = cast( + List[Tensor], grouped_momentum_buffer_list_ + ) + torch._foreach_mul_(grouped_momentum_buffer_list, momentum) + torch._foreach_addcdiv_(grouped_momentum_buffer_list, grouped_grads, avg) + # If LR is a tensor, the else branch will internally call item() + # which will cause silent incorrectness if we are capturing + if capturable and isinstance(lr, torch.Tensor): + momentum_lr = torch._foreach_mul(grouped_momentum_buffer_list, -lr) + torch._foreach_add_(grouped_params, momentum_lr) + else: + torch._foreach_add_( + grouped_params, grouped_momentum_buffer_list, alpha=-lr + ) + else: + # If LR is a tensor, the else branch will internally call item() + # which will cause silent incorrectness if we are capturing + if capturable and isinstance(lr, torch.Tensor): + torch._foreach_div_(avg, -lr) + torch._foreach_addcdiv_(grouped_params, grouped_grads, avg) + else: + torch._foreach_addcdiv_(grouped_params, grouped_grads, avg, value=-lr) + + 
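+# Usage sketch (illustrative; the toy model and data below are placeholders):
+#
+#     model = torch.nn.Linear(4, 1)
+#     opt = RMSprop(model.parameters(), lr=1e-2, alpha=0.99,
+#                   momentum=0.9, centered=True)
+#     loss = model(torch.randn(8, 4)).pow(2).mean()
+#     loss.backward()
+#     opt.step()       # dispatches to the foreach path when it is selected
+#     opt.zero_grad()
+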
+@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rmsprop) +def rmsprop( + params: List[Tensor], + grads: List[Tensor], + square_avgs: List[Tensor], + grad_avgs: List[Tensor], + momentum_buffer_list: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + maximize: bool = False, + differentiable: bool = False, + capturable: bool = False, + has_complex: bool = False, + *, + lr: float, + alpha: float, + eps: float, + weight_decay: float, + momentum: float, + centered: bool, +): + r"""Functional API that performs rmsprop algorithm computation. + + See :class:`~torch.optim.RMSProp` for details. + """ + # this check is slow during compilation, so we skip it + # if it's strictly needed we can add this check back in dynamo + if not torch._utils.is_compiling() and not all( + isinstance(t, torch.Tensor) for t in state_steps + ): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + if foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_rmsprop + else: + func = _single_tensor_rmsprop + + func( + params, + grads, + square_avgs, + grad_avgs, + momentum_buffer_list, + state_steps, + lr=lr, + alpha=alpha, + eps=eps, + weight_decay=weight_decay, + momentum=momentum, + centered=centered, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + has_complex=has_complex, + ) diff --git a/lib/python3.10/site-packages/torch/optim/rprop.py b/lib/python3.10/site-packages/torch/optim/rprop.py new file mode 100644 index 0000000000000000000000000000000000000000..e28f3535a0b99561c15ed3030f72d1ef3dd1616c --- /dev/null +++ b/lib/python3.10/site-packages/torch/optim/rprop.py @@ -0,0 +1,464 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +r"""Implementation for the Resilient backpropagation.""" +from typing import cast, List, Optional, Tuple, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _capturable_doc, + _default_to_fused_or_foreach, + _differentiable_doc, + _disable_dynamo_if_unsupported, + _foreach_doc, + _get_capturable_supported_devices, + _get_scalar_dtype, + _maximize_doc, + _use_grad_for_differentiable, + _view_as_real, + Optimizer, + ParamsT, +) + + +__all__ = ["Rprop", "rprop"] + + +class Rprop(Optimizer): # noqa: D101 + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1e-2, + etas: Tuple[float, float] = (0.5, 1.2), + step_sizes: Tuple[float, float] = (1e-6, 50), + *, + capturable: bool = False, + foreach: Optional[bool] = None, + maximize: bool = False, + differentiable: bool = False, + ): # noqa: D107 + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 < etas[0] < 1.0 < etas[1]: + raise ValueError(f"Invalid eta values: {etas[0]}, {etas[1]}") + + defaults = dict( + lr=lr, + etas=etas, + step_sizes=step_sizes, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + capturable=capturable, + ) + super().__init__(params, defaults) + + def 
__setstate__(self, state): # noqa: D105 + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("foreach", None) + group.setdefault("maximize", False) + group.setdefault("differentiable", False) + group.setdefault("capturable", False) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, dtype=_get_scalar_dtype(), device=p.device + ) + if group["capturable"] + else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + + def _init_group(self, group, params, grads, prevs, step_sizes, state_steps): + has_complex = False + for p in group["params"]: + if p.grad is None: + continue + has_complex |= torch.is_complex(p) + params.append(p) + grad = p.grad + if grad.is_sparse: + raise RuntimeError("Rprop does not support sparse gradients") + + grads.append(grad) + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = ( + torch.zeros((), dtype=_get_scalar_dtype(), device=p.device) + if group["capturable"] + else torch.zeros((), dtype=_get_scalar_dtype()) + ) + + state["prev"] = torch.zeros_like(p, memory_format=torch.preserve_format) + if p.dtype.is_complex: + # Complex Number should be as if they are two independent real numbers. + # Hence the step_size shouldn't be zero for imaginary part. + state["step_size"] = torch.full_like( + grad, complex(group["lr"], group["lr"]) + ) + else: + state["step_size"] = torch.full_like(grad, group["lr"]) + + prevs.append(state["prev"]) + step_sizes.append(state["step_size"]) + state_steps.append(state["step"]) + + return has_complex + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params: List[Tensor] = [] + grads: List[Tensor] = [] + prevs: List[Tensor] = [] + step_sizes: List[Tensor] = [] + state_steps: List[Tensor] = [] + + etaminus, etaplus = group["etas"] + step_size_min, step_size_max = group["step_sizes"] + foreach = group["foreach"] + maximize = group["maximize"] + + has_complex = self._init_group( + group, params, grads, prevs, step_sizes, state_steps + ) + + rprop( + params, + grads, + prevs, + step_sizes, + state_steps, + step_size_min=step_size_min, + step_size_max=step_size_max, + etaminus=etaminus, + etaplus=etaplus, + foreach=foreach, + maximize=maximize, + differentiable=group["differentiable"], + capturable=group["capturable"], + has_complex=has_complex, + ) + + return loss + + +Rprop.__doc__ = ( + r"""Implements the resilient backpropagation algorithm. + + .. 
math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \theta_0 \in \mathbf{R}^d \text{ (params)},f(\theta) + \text{ (objective)}, \\ + &\hspace{13mm} \eta_{+/-} \text{ (etaplus, etaminus)}, \Gamma_{max/min} + \text{ (step sizes)} \\ + &\textbf{initialize} : g^0_{prev} \leftarrow 0, + \: \eta_0 \leftarrow \text{lr (learning rate)} \\ + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm} \textbf{for} \text{ } i = 0, 1, \ldots, d-1 \: \mathbf{do} \\ + &\hspace{10mm} \textbf{if} \: g^i_{prev} g^i_t > 0 \\ + &\hspace{15mm} \eta^i_t \leftarrow \mathrm{min}(\eta^i_{t-1} \eta_{+}, + \Gamma_{max}) \\ + &\hspace{10mm} \textbf{else if} \: g^i_{prev} g^i_t < 0 \\ + &\hspace{15mm} \eta^i_t \leftarrow \mathrm{max}(\eta^i_{t-1} \eta_{-}, + \Gamma_{min}) \\ + &\hspace{15mm} g^i_t \leftarrow 0 \\ + &\hspace{10mm} \textbf{else} \: \\ + &\hspace{15mm} \eta^i_t \leftarrow \eta^i_{t-1} \\ + &\hspace{5mm}\theta_t \leftarrow \theta_{t-1}- \eta_t \mathrm{sign}(g_t) \\ + &\hspace{5mm}g_{prev} \leftarrow g_t \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to the paper + `A Direct Adaptive Method for Faster Backpropagation Learning: The RPROP Algorithm + `_. + """ + + rf""" + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-2) + etas (Tuple[float, float], optional): pair of (etaminus, etaplus), that + are multiplicative increase and decrease factors + (default: (0.5, 1.2)) + step_sizes (Tuple[float, float], optional): a pair of minimal and + maximal allowed step sizes (default: (1e-6, 50)) + {_foreach_doc} + {_capturable_doc} + {_maximize_doc} + {_differentiable_doc} + + """ +) + + +def _single_tensor_rprop( + params: List[Tensor], + grads: List[Tensor], + prevs: List[Tensor], + step_sizes: List[Tensor], + state_steps: List[Tensor], + *, + step_size_min: float, + step_size_max: float, + etaminus: float, + etaplus: float, + maximize: bool, + capturable: bool, + differentiable: bool, + has_complex: bool, +): + for i, param in enumerate(params): + grad = grads[i] + grad = grad if not maximize else -grad + prev = prevs[i] + step_size = step_sizes[i] + step = state_steps[i] + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert ( + param.device.type == step.device.type + and param.device.type in capturable_supported_devices + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." 
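+        # Rprop rule sketch (illustrative summary of the block below): per element,
+        # if the current and previous gradients agree in sign, the step size is
+        # multiplied by etaplus (clamped at step_size_max); if they disagree, it is
+        # multiplied by etaminus (clamped at step_size_min) and that element's
+        # update is skipped this iteration; the parameter then moves by
+        # -sign(grad) * step_size, independent of the gradient's magnitude.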
+ + step += 1 + + if torch.is_complex(param): + grad = torch.view_as_real(grad) + prev = torch.view_as_real(prev) + param = torch.view_as_real(param) + step_size = torch.view_as_real(step_size) + if differentiable: + sign = grad.mul(prev.clone()).sign() + else: + sign = grad.mul(prev).sign() + + if capturable: + sign.copy_(torch.where(sign.gt(0), etaplus, sign)) + sign.copy_(torch.where(sign.lt(0), etaminus, sign)) + sign.copy_(torch.where(sign.eq(0), 1, sign)) + else: + sign[sign.gt(0)] = etaplus + sign[sign.lt(0)] = etaminus + sign[sign.eq(0)] = 1 + + # update stepsizes with step size updates + step_size.mul_(sign).clamp_(step_size_min, step_size_max) + + # for dir<0, dfdx=0 + # for dir>=0 dfdx=dfdx + grad = grad.clone(memory_format=torch.preserve_format) + if capturable: + grad.copy_(torch.where(sign.eq(etaminus), 0, grad)) + else: + grad[sign.eq(etaminus)] = 0 + + # update parameters + param.addcmul_(grad.sign(), step_size, value=-1) + prev.copy_(grad) + + +def _multi_tensor_rprop( + params: List[Tensor], + grads: List[Tensor], + prevs: List[Tensor], + step_sizes: List[Tensor], + state_steps: List[Tensor], + *, + step_size_min: float, + step_size_max: float, + etaminus: float, + etaplus: float, + maximize: bool, + capturable: bool, + differentiable: bool, + has_complex: bool, +): + if len(params) == 0: + return + + assert not differentiable, "_foreach ops don't support autograd" + + # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] + if not torch._utils.is_compiling() and capturable: + capturable_supported_devices = _get_capturable_supported_devices() + assert all( + p.device.type == step.device.type + and p.device.type in capturable_supported_devices + for p, step in zip(params, state_steps) + ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, prevs, step_sizes, state_steps] # type: ignore[list-item] + ) + for ( + grouped_params_, + grouped_grads_, + grouped_prevs_, + grouped_step_sizes_, + grouped_state_steps_, + ), _ in grouped_tensors.values(): + grouped_params = cast(List[Tensor], grouped_params_) + grouped_grads = cast(List[Tensor], grouped_grads_) + grouped_prevs = cast(List[Tensor], grouped_prevs_) + grouped_step_sizes = cast(List[Tensor], grouped_step_sizes_) + grouped_state_steps = cast(List[Tensor], grouped_state_steps_) + + # Update steps + # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over + # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just + # wrapped it once now. The alpha is required to assure we go to the right overload. + if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu: + torch._foreach_add_( + grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 + ) + else: + torch._foreach_add_(grouped_state_steps, 1) + + # Handle complex params + if has_complex: + _view_as_real( + grouped_params, grouped_grads, grouped_prevs, grouped_step_sizes + ) + + signs = torch._foreach_mul(grouped_grads, grouped_prevs) + if maximize: + torch._foreach_neg_(signs) + + # At the end of the step, grouped_prevs will contain the current grads, so we reuse + # grouped_prevs memory instead of creating a new buffer, but, for clarity, we reassign + # to keep referring to the buffer as grouped_grads. 
+ torch._foreach_copy_(grouped_prevs, grouped_grads) + if maximize: + torch._foreach_neg_(grouped_prevs) + grouped_grads = grouped_prevs + + torch._foreach_sign_(signs) + if capturable: + for sign in signs: + sign.copy_(torch.where(sign.gt(0), etaplus, sign)) + sign.copy_(torch.where(sign.lt(0), etaminus, sign)) + sign.copy_(torch.where(sign.eq(0), 1, sign)) + else: + for sign in signs: + sign[sign.gt(0)] = etaplus + sign[sign.lt(0)] = etaminus + sign[sign.eq(0)] = 1 + + # update stepsizes with step size updates + torch._foreach_mul_(grouped_step_sizes, signs) + for step_size in grouped_step_sizes: + step_size.clamp_(step_size_min, step_size_max) + + # for dir<0, dfdx=0 + # for dir>=0 dfdx=dfdx + grouped_grads = list(grouped_grads) + for i in range(len(grouped_grads)): + grouped_grads[i].copy_( + torch.where(signs[i].eq(etaminus), 0, grouped_grads[i]) + ) + + # explicitly del signs as it's not used after here to save memory + del signs + + # update parameters + grad_signs = [grad.sign() for grad in grouped_grads] + torch._foreach_addcmul_( + grouped_params, grad_signs, grouped_step_sizes, value=-1 + ) + + # Logically, you may expect grouped_prevs to get updated to grouped_grads, but that's + # basically already happened since we've been using grouped_prevs' memory to store + # updated grouped_grads! + + +@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rprop) +def rprop( + params: List[Tensor], + grads: List[Tensor], + prevs: List[Tensor], + step_sizes: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + foreach: Optional[bool] = None, + capturable: bool = False, + maximize: bool = False, + differentiable: bool = False, + has_complex: bool = False, + *, + step_size_min: float, + step_size_max: float, + etaminus: float, + etaplus: float, +): + r"""Functional API that performs rprop algorithm computation. + + See :class:`~torch.optim.Rprop` for details. 
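+
+    A minimal sketch of a direct call (toy tensors; shapes and values are
+    illustrative only):
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> p = torch.zeros(3)
+        >>> g = torch.ones(3)
+        >>> prev = torch.zeros(3)
+        >>> step_size = torch.full_like(p, 1e-2)
+        >>> step = torch.tensor(0.0)
+        >>> rprop([p], [g], [prev], [step_size], [step],
+        ...       step_size_min=1e-6, step_size_max=50.0,
+        ...       etaminus=0.5, etaplus=1.2)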
+ """ + # this check is slow during compilation, so we skip it + # if it's strictly needed we can add this check back in dynamo + if not torch._utils.is_compiling() and not all( + isinstance(t, torch.Tensor) for t in state_steps + ): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + if foreach is None: + _, foreach = _default_to_fused_or_foreach( + params, differentiable, use_fused=False + ) + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_rprop + else: + func = _single_tensor_rprop + + func( + params, + grads, + prevs, + step_sizes, + state_steps, + step_size_min=step_size_min, + step_size_max=step_size_max, + etaminus=etaminus, + etaplus=etaplus, + capturable=capturable, + maximize=maximize, + differentiable=differentiable, + has_complex=has_complex, + ) diff --git a/lib/python3.10/site-packages/torch/optim/sgd.py b/lib/python3.10/site-packages/torch/optim/sgd.py new file mode 100644 index 0000000000000000000000000000000000000000..b270afce4d60a2b4ae478f062ea7d481e6985cb2 --- /dev/null +++ b/lib/python3.10/site-packages/torch/optim/sgd.py @@ -0,0 +1,511 @@ +# mypy: allow-untyped-defs +r"""Implementation for Stochastic Gradient Descent optimizer.""" +from typing import cast, List, Optional, Union + +import torch +from torch import Tensor + +from .optimizer import ( + _default_to_fused_or_foreach, + _device_dtype_check_for_fused, + _differentiable_doc, + _foreach_doc, + _fused_doc, + _maximize_doc, + _use_grad_for_differentiable, + DeviceDict, + Optimizer, +) + + +__all__ = ["SGD", "sgd"] + + +class SGD(Optimizer): # noqa: D101 + def __init__( + self, + params, + lr: Union[float, Tensor] = 1e-3, + momentum: float = 0, + dampening: float = 0, + weight_decay: float = 0, + nesterov=False, + *, + maximize: bool = False, + foreach: Optional[bool] = None, + differentiable: bool = False, + fused: Optional[bool] = None, + ): # noqa: D107 + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + maximize=maximize, + foreach=foreach, + differentiable=differentiable, + fused=fused, + ) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super().__init__(params, defaults) + + if fused: + self._step_supports_amp_scaling = True + self._need_device_dtype_check_for_fused = True + if differentiable: + raise RuntimeError("`fused` does not support `differentiable`") + if foreach: + raise RuntimeError("`fused` and `foreach` cannot be `True` together.") + + def __setstate__(self, state): # noqa: D105 + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("nesterov", False) + group.setdefault("maximize", False) + group.setdefault("foreach", None) + group.setdefault("differentiable", False) + group.setdefault("fused", False) + + def _init_group(self, group, params, grads, momentum_buffer_list): + has_sparse_grad = False + + for p in group["params"]: + if p.grad is not None: + if group["fused"] and 
getattr( + self, "_need_device_dtype_check_for_fused", True + ): + _device_dtype_check_for_fused(p) + self._need_device_dtype_check_for_fused = False + params.append(p) + grads.append(p.grad) + if p.grad.is_sparse: + has_sparse_grad = True + + if group["momentum"] != 0: + state = self.state[p] + momentum_buffer_list.append(state.get("momentum_buffer")) + + return has_sparse_grad + + @_use_grad_for_differentiable + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params: List[Tensor] = [] + grads: List[Tensor] = [] + momentum_buffer_list: List[Optional[Tensor]] = [] + + has_sparse_grad = self._init_group( + group, params, grads, momentum_buffer_list + ) + + sgd( + params, + grads, + momentum_buffer_list, + weight_decay=group["weight_decay"], + momentum=group["momentum"], + lr=group["lr"], + dampening=group["dampening"], + nesterov=group["nesterov"], + maximize=group["maximize"], + has_sparse_grad=has_sparse_grad, + foreach=group["foreach"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + if group["momentum"] != 0: + # update momentum_buffers in state + for p, momentum_buffer in zip(params, momentum_buffer_list): + state = self.state[p] + state["momentum_buffer"] = momentum_buffer + + return loss + + +SGD.__doc__ = ( + r"""Implements stochastic gradient descent (optionally with momentum). + + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta) + \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\ + &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)}, + \:\textit{ nesterov,}\:\textit{ maximize} \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm}\textbf{if} \: \mu \neq 0 \\ + &\hspace{10mm}\textbf{if} \: t > 1 \\ + &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t \\ + &\hspace{10mm}\textbf{else} \\ + &\hspace{15mm} \textbf{b}_t \leftarrow g_t \\ + &\hspace{10mm}\textbf{if} \: \textit{nesterov} \\ + &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t \\ + &\hspace{10mm}\textbf{else} \\[-1.ex] + &\hspace{15mm} g_t \leftarrow \textbf{b}_t \\ + &\hspace{5mm}\textbf{if} \: \textit{maximize} \\ + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma g_t \\[-1.ex] + &\hspace{5mm}\textbf{else} \\[-1.ex] + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + Nesterov momentum is based on the formula from + `On the importance of initialization and momentum in deep learning`__. 
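+
+    Concretely (an illustrative restatement of the pseudo-code above): with
+    ``nesterov=True`` the applied update is :math:`g_t + \mu \textbf{b}_t`
+    rather than :math:`\textbf{b}_t`, i.e. the gradient is pushed through the
+    momentum buffer one extra time before the parameter step.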
+ """ + + rf""" + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, Tensor, optional): learning rate (default: 1e-3) + momentum (float, optional): momentum factor (default: 0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + dampening (float, optional): dampening for momentum (default: 0) + nesterov (bool, optional): enables Nesterov momentum (default: False) + {_maximize_doc} + {_foreach_doc} + {_differentiable_doc} + {_fused_doc} + """ + + r""" + + Example: + >>> # xdoctest: +SKIP + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + + __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf + + .. note:: + The implementation of SGD with Momentum/Nesterov subtly differs from + Sutskever et al. and implementations in some other frameworks. + + Considering the specific case of Momentum, the update can be written as + + .. math:: + \begin{aligned} + v_{t+1} & = \mu * v_{t} + g_{t+1}, \\ + p_{t+1} & = p_{t} - \text{lr} * v_{t+1}, + \end{aligned} + + where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the + parameters, gradient, velocity, and momentum respectively. + + This is in contrast to Sutskever et al. and + other frameworks which employ an update of the form + + .. math:: + \begin{aligned} + v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\ + p_{t+1} & = p_{t} - v_{t+1}. + \end{aligned} + + The Nesterov version is analogously modified. + + Moreover, the initial value of the momentum buffer is set to the + gradient value at the first step. This is in contrast to some other + frameworks that initialize it to all zeros. + + """ +) + + +def sgd( + params: List[Tensor], + d_p_list: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + has_sparse_grad: bool = False, + foreach: Optional[bool] = None, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + weight_decay: float, + momentum: float, + lr: float, + dampening: float, + nesterov: bool, + maximize: bool, +): + r"""Functional API that performs SGD algorithm computation. + + See :class:`~torch.optim.SGD` for details. + """ + # Respect when the user inputs False/True for foreach or fused. We only want to change + # the default when neither have been user-specified. Note that we default to foreach + # and pass False to use_fused. This is not a mistake--we want to give the fused impl + # bake-in time before making it the default, even if it is typically faster. + if foreach is None and fused is None: + # why must we be explicit about an if statement for torch.jit.is_scripting here? 
+ # because JIT can't handle Optionals nor fancy conditionals when scripting + if not torch.jit.is_scripting(): + fused, foreach = _default_to_fused_or_foreach( + params, differentiable=False, use_fused=False + ) + else: + foreach = False + fused = False + if foreach is None: + foreach = False + if fused is None: + fused = False + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + if fused and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with fused optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_sgd + elif fused and not torch.jit.is_scripting(): + func = _fused_sgd + else: + func = _single_tensor_sgd + + func( + params, + d_p_list, + momentum_buffer_list, + weight_decay=weight_decay, + momentum=momentum, + lr=lr, + dampening=dampening, + nesterov=nesterov, + has_sparse_grad=has_sparse_grad, + maximize=maximize, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +def _single_tensor_sgd( + params: List[Tensor], + grads: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + weight_decay: float, + momentum: float, + lr: float, + dampening: float, + nesterov: bool, + maximize: bool, + has_sparse_grad: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + + if weight_decay != 0: + grad = grad.add(param, alpha=weight_decay) + + if momentum != 0: + buf = momentum_buffer_list[i] + + if buf is None: + buf = torch.clone(grad).detach() + momentum_buffer_list[i] = buf + else: + buf.mul_(momentum).add_(grad, alpha=1 - dampening) + + if nesterov: + grad = grad.add(buf, alpha=momentum) + else: + grad = buf + + param.add_(grad, alpha=-lr) + + +def _multi_tensor_sgd( + params: List[Tensor], + grads: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + weight_decay: float, + momentum: float, + lr: float, + dampening: float, + nesterov: bool, + maximize: bool, + has_sparse_grad: bool, +): + assert grad_scale is None and found_inf is None + + if len(params) == 0: + return + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, momentum_buffer_list], with_indices=True # type: ignore[list-item] + ) + + for ( + device_params_, + device_grads_, + device_momentum_buffer_list, + ), indices in grouped_tensors.values(): + device_params: List[Tensor] = cast(List[Tensor], device_params_) + device_grads: List[Tensor] = cast(List[Tensor], device_grads_) + + device_has_sparse_grad = has_sparse_grad and any( + grad.is_sparse for grad in device_grads + ) + + if maximize: + device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] + + if weight_decay != 0: + # Re-use the intermediate memory (device_grads) already allocated for maximize + if maximize: + torch._foreach_add_(device_grads, device_params, alpha=weight_decay) + else: + device_grads = torch._foreach_add( # type: ignore[assignment] + device_grads, device_params, alpha=weight_decay + ) + + if momentum != 0: + bufs: List[Tensor] = [] + + all_states_with_momentum_buffer = True + for i in range(len(device_momentum_buffer_list)): + if device_momentum_buffer_list[i] is None: + all_states_with_momentum_buffer = False + break + else: + bufs.append(cast(Tensor, device_momentum_buffer_list[i])) + + if all_states_with_momentum_buffer: + 
torch._foreach_mul_(bufs, momentum) + torch._foreach_add_(bufs, device_grads, alpha=1 - dampening) + else: + bufs = [] + for i in range(len(device_momentum_buffer_list)): + if device_momentum_buffer_list[i] is None: + buf = device_momentum_buffer_list[i] = momentum_buffer_list[ + indices[i] + ] = torch.clone(device_grads[i]).detach() + else: + buf = cast(Tensor, device_momentum_buffer_list[i]) + buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening) + + bufs.append(buf) + + if nesterov: + torch._foreach_add_(device_grads, bufs, alpha=momentum) + else: + device_grads = bufs + + if not device_has_sparse_grad: + # handle internal item() call if lr is a tensor + if isinstance(lr, torch.Tensor) and torch._utils.is_compiling(): + grads_x_lr = torch._foreach_mul(device_grads, -lr) + torch._foreach_add_(device_params, grads_x_lr) + else: + torch._foreach_add_(device_params, device_grads, alpha=-lr) + else: + # foreach APIs don't support sparse + for i in range(len(device_params)): + device_params[i].add_(device_grads[i], alpha=-lr) + + +def _fused_sgd( + params: List[Tensor], + grads: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + weight_decay: float, + momentum: float, + lr: float, + dampening: float, + nesterov: bool, + maximize: bool, + has_sparse_grad: bool, +) -> None: + if not params: + return + if has_sparse_grad: + raise RuntimeError("`_fused_sgd` does not support sparse gradients") + grad_scale_dict: DeviceDict = ( + {grad_scale.device: grad_scale} if grad_scale is not None else {} + ) + found_inf_dict: DeviceDict = ( + {found_inf.device: found_inf} if found_inf is not None else {} + ) + + no_momentum_buffer = momentum == 0 + is_first_step = ( + all(t is None for t in momentum_buffer_list) and not no_momentum_buffer + ) + if is_first_step: + for i, g in enumerate(grads): + momentum_buffer_list[i] = torch.empty_like(g) + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, momentum_buffer_list], with_indices=False # type: ignore[list-item] + ) + for (device, _), ( + (device_params_, device_grads_, device_momentum_buffer_list), + _, + ) in grouped_tensors.items(): + device_params: List[Tensor] = cast(List[Tensor], device_params_) + device_grads: List[Tensor] = cast(List[Tensor], device_grads_) + device_grad_scale, device_found_inf = None, None + if grad_scale is not None: + device_grad_scale = grad_scale_dict.setdefault( + device, grad_scale.to(device) + ) + if found_inf_dict is not None and found_inf is not None: + device_found_inf = found_inf_dict.setdefault(device, found_inf.to(device)) + torch._fused_sgd_( + device_params, + device_grads, + [] + if no_momentum_buffer + else cast(List[Tensor], device_momentum_buffer_list), + weight_decay=weight_decay, + momentum=momentum, + lr=lr, + dampening=dampening, + nesterov=nesterov, + maximize=maximize, + is_first_step=is_first_step, + grad_scale=device_grad_scale, + found_inf=device_found_inf, + ) diff --git a/lib/python3.10/site-packages/torch/optim/sparse_adam.py b/lib/python3.10/site-packages/torch/optim/sparse_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..22ef7841270f6907f04d31a0d47edc1d02e42ff6 --- /dev/null +++ b/lib/python3.10/site-packages/torch/optim/sparse_adam.py @@ -0,0 +1,185 @@ +# mypy: allow-untyped-defs +from typing import List, Tuple, Union + +import torch +from torch import Tensor + +from . 
import _functional as F +from .optimizer import _maximize_doc, Optimizer, ParamsT + + +__all__ = ["SparseAdam"] + + +class SparseAdam(Optimizer): + def __init__( + self, + params: ParamsT, + lr: Union[float, Tensor] = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + maximize: bool = False, + ): + if isinstance(lr, Tensor) and lr.numel() != 1: + raise ValueError("Tensor lr must be 1-element") + if not 0.0 < lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 < eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + + defaults = dict(lr=lr, betas=betas, eps=eps, maximize=maximize) + super().__init__(params, defaults) + + sparse_params = [] + complex_params = [] + for index, param_group in enumerate(self.param_groups): + assert isinstance( + param_group, dict + ), f"param_groups must be a list of dicts, but got {type(param_group)}" + # given param group, convert given params to a list first before iterating + for d_index, d_param in enumerate(param_group["params"]): + if d_param.is_sparse: + sparse_params.append([index, d_index]) + if d_param.is_complex(): + complex_params.append([index, d_index]) + if sparse_params: + raise ValueError( + f"Sparse params at indices {sparse_params}: SparseAdam requires dense parameter tensors" + ) + if complex_params: + raise ValueError( + f"Complex params at indices {complex_params}: SparseAdam does not support complex parameters" + ) + + @torch.no_grad() + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad: List[Tensor] = [] + grads: List[Tensor] = [] + exp_avgs: List[Tensor] = [] + exp_avg_sqs: List[Tensor] = [] + state_steps: List[int] = [] + beta1, beta2 = group["betas"] + maximize = group.get("maximize", False) + + for p in group["params"]: + if p.grad is not None: + params_with_grad.append(p) + if not p.grad.is_sparse: + raise RuntimeError( + "SparseAdam does not support dense gradients, please consider Adam instead" + ) + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + # update the steps for each param group update + state["step"] += 1 + # record the step after step update + state_steps.append(state["step"]) + + F.sparse_adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + eps=group["eps"], + beta1=beta1, + beta2=beta2, + lr=group["lr"], + maximize=maximize, + ) + + return loss + + +SparseAdam.__doc__ = rf"""SparseAdam implements a masked version of the Adam algorithm + suitable for sparse gradients. 
Currently, due to implementation constraints (explained + below), SparseAdam is only intended for a narrow subset of use cases, specifically + parameters of a dense layout with gradients of a sparse layout. This occurs in a + special case where the module backwards produces grads already in a sparse layout. + One example NN module that behaves as such is ``nn.Embedding(sparse=True)``. + + SparseAdam approximates the Adam algorithm by masking out the parameter and moment + updates corresponding to the zero values in the gradients. Whereas the Adam algorithm + will update the first moment, the second moment, and the parameters based on all values + of the gradients, SparseAdam only updates the moments and parameters corresponding + to the non-zero values of the gradients. + + A simplified way of thinking about the `intended` implementation is as such: + + 1. Create a mask of the non-zero values in the sparse gradients. For example, + if your gradient looks like [0, 5, 0, 0, 9], the mask would be [0, 1, 0, 0, 1]. + 2. Apply this mask over the running moments and do computation on only the + non-zero values. + 3. Apply this mask over the parameters and only apply an update on non-zero values. + + In actuality, we use sparse layout Tensors to optimize this approximation, which means the + more gradients that are masked by not being materialized, the more performant the optimization. + Since we rely on using sparse layout tensors, we infer that any materialized value in the + sparse layout is non-zero and we do NOT actually verify that all values are not zero! + It is important to not conflate a semantically sparse tensor (a tensor where many + of its values are zeros) with a sparse layout tensor (a tensor where ``.is_sparse`` + returns ``True``). The SparseAdam approximation is intended for `semantically` sparse + tensors and the sparse layout is only a implementation detail. A clearer implementation + would be to use MaskedTensors, but those are experimental. + + + .. note:: + + If you suspect your gradients are semantically sparse (but do not have sparse + layout), this variant may not be the best for you. Ideally, you want to avoid + materializing anything that is suspected to be sparse in the first place, since + needing to convert all your grads from dense layout to sparse layout may outweigh + the performance gain. Here, using Adam may be the best alternative, unless you + can easily rig up your module to output sparse grads similar to + ``nn.Embedding(sparse=True)``. If you insist on converting your grads, you can do + so by manually overriding your parameters' ``.grad`` fields with their sparse + equivalents before calling ``.step()``. + + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, Tensor, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + {_maximize_doc} + + .. 
_Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + + """ diff --git a/lib/python3.10/site-packages/torch/optim/swa_utils.py b/lib/python3.10/site-packages/torch/optim/swa_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7bea0d355bea34fbbf3bedea1cdcbbcbb6dc4e46 --- /dev/null +++ b/lib/python3.10/site-packages/torch/optim/swa_utils.py @@ -0,0 +1,467 @@ +# mypy: allow-untyped-defs +r"""Implementation for Stochastic Weight Averaging implementation.""" +import itertools +import math +import warnings +from copy import deepcopy +from typing import Any, Callable, Iterable, List, Literal, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.nn import Module +from torch.optim.lr_scheduler import _format_param, LRScheduler +from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices + +from .optimizer import Optimizer + + +__all__ = [ + "AveragedModel", + "update_bn", + "SWALR", + "get_ema_multi_avg_fn", + "get_swa_multi_avg_fn", + "get_ema_avg_fn", + "get_swa_avg_fn", +] + +from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype + + +PARAM_LIST = Union[Tuple[Tensor, ...], List[Tensor]] + + +def get_ema_multi_avg_fn(decay=0.999): + """Get the function applying exponential moving average (EMA) across multiple params.""" + + @torch.no_grad() + def ema_update(ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _): + # foreach lerp only handles float and complex + if torch.is_floating_point(ema_param_list[0]) or torch.is_complex( + ema_param_list[0] + ): + torch._foreach_lerp_(ema_param_list, current_param_list, 1 - decay) + else: + for p_ema, p_model in zip(ema_param_list, current_param_list): + p_ema.copy_(p_ema * decay + p_model * (1 - decay)) + + return ema_update + + +def get_swa_multi_avg_fn(): + """Get the function applying stochastic weight average (SWA) across multiple params.""" + + @torch.no_grad() + def swa_update( + averaged_param_list: PARAM_LIST, + current_param_list: PARAM_LIST, + num_averaged: Union[Tensor, int], + ): + # foreach lerp only handles float and complex + if torch.is_floating_point(averaged_param_list[0]) or torch.is_complex( + averaged_param_list[0] + ): + torch._foreach_lerp_( + averaged_param_list, current_param_list, 1 / (num_averaged + 1) + ) + else: + diffs = torch._foreach_sub(current_param_list, averaged_param_list) + if isinstance(num_averaged, Tensor): + torch._foreach_addcdiv_( + averaged_param_list, + diffs, + [num_averaged + 1] * len(averaged_param_list), + ) + else: + torch._foreach_add_( + averaged_param_list, diffs, alpha=1.0 / (num_averaged + 1) + ) + + return swa_update + + +def get_ema_avg_fn(decay=0.999): + """Get the function applying exponential moving average (EMA) across a single param.""" + + @torch.no_grad() + def ema_update(ema_param: Tensor, current_param: Tensor, num_averaged): + return decay * ema_param + (1 - decay) * current_param + + return ema_update + + +def get_swa_avg_fn(): + """Get the function applying stochastic weight average (SWA) across a single param.""" + + @torch.no_grad() + def swa_update( + averaged_param: Tensor, current_param: Tensor, num_averaged: Union[Tensor, int] + ): + return averaged_param + (current_param - averaged_param) / (num_averaged + 1) + + return swa_update + + +class AveragedModel(Module): + r"""Implements averaged model for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA). 
+ + Stochastic Weight Averaging was proposed in `Averaging Weights Leads to + Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii + Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson + (UAI 2018). + + Exponential Moving Average is a variation of `Polyak averaging`_, + but using exponential weights instead of equal weights across iterations. + + AveragedModel class creates a copy of the provided module :attr:`model` + on the device :attr:`device` and allows to compute running averages of the + parameters of the :attr:`model`. + + Args: + model (torch.nn.Module): model to use with SWA/EMA + device (torch.device, optional): if provided, the averaged model will be + stored on the :attr:`device` + avg_fn (function, optional): the averaging function used to update + parameters; the function must take in the current value of the + :class:`AveragedModel` parameter, the current value of :attr:`model` + parameter, and the number of models already averaged; if None, + an equally weighted average is used (default: None) + multi_avg_fn (function, optional): the averaging function used to update + parameters inplace; the function must take in the current values of the + :class:`AveragedModel` parameters as a list, the current values of :attr:`model` + parameters as a list, and the number of models already averaged; if None, + an equally weighted average is used (default: None) + use_buffers (bool): if ``True``, it will compute running averages for + both the parameters and the buffers of the model. (default: ``False``) + + Example: + >>> # xdoctest: +SKIP("undefined variables") + >>> loader, optimizer, model, loss_fn = ... + >>> swa_model = torch.optim.swa_utils.AveragedModel(model) + >>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, + >>> T_max=300) + >>> swa_start = 160 + >>> swa_scheduler = SWALR(optimizer, swa_lr=0.05) + >>> for i in range(300): + >>> for input, target in loader: + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + >>> if i > swa_start: + >>> swa_model.update_parameters(model) + >>> swa_scheduler.step() + >>> else: + >>> scheduler.step() + >>> + >>> # Update bn statistics for the swa_model at the end + >>> torch.optim.swa_utils.update_bn(loader, swa_model) + + You can also use custom averaging functions with the `avg_fn` or `multi_avg_fn` parameters. + If no averaging function is provided, the default is to compute + equally-weighted average of the weights (SWA). + + Example: + >>> # xdoctest: +SKIP("undefined variables") + >>> # Compute exponential moving averages of the weights and buffers + >>> ema_model = torch.optim.swa_utils.AveragedModel(model, + >>> torch.optim.swa_utils.get_ema_multi_avg_fn(0.9), use_buffers=True) + + .. note:: + When using SWA/EMA with models containing Batch Normalization you may + need to update the activation statistics for Batch Normalization. + This can be done either by using the :meth:`torch.optim.swa_utils.update_bn` + or by setting :attr:`use_buffers` to `True`. The first approach updates the + statistics in a post-training step by passing data through the model. The + second does it during the parameter update phase by averaging all buffers. + Empirical evidence has shown that updating the statistics in normalization + layers increases accuracy, but you may wish to empirically test which + approach yields the best results in your problem. + + .. note:: + :attr:`avg_fn` and `multi_avg_fn` are not saved in the :meth:`state_dict` of the model. + + .. 
note:: + When :meth:`update_parameters` is called for the first time (i.e. + :attr:`n_averaged` is `0`) the parameters of `model` are copied + to the parameters of :class:`AveragedModel`. For every subsequent + call of :meth:`update_parameters` the function `avg_fn` is used + to update the parameters. + + .. _Averaging Weights Leads to Wider Optima and Better Generalization: + https://arxiv.org/abs/1803.05407 + .. _There Are Many Consistent Explanations of Unlabeled Data: Why You Should + Average: + https://arxiv.org/abs/1806.05594 + .. _SWALP: Stochastic Weight Averaging in Low-Precision Training: + https://arxiv.org/abs/1904.11943 + .. _Stochastic Weight Averaging in Parallel: Large-Batch Training That + Generalizes Well: + https://arxiv.org/abs/2001.02312 + .. _Polyak averaging: + https://paperswithcode.com/method/polyak-averaging + """ + + n_averaged: Tensor + + def __init__( + self, + model: Module, + device: Optional[Union[int, torch.device]] = None, + avg_fn: Optional[Callable[[Tensor, Tensor, Union[Tensor, int]], Tensor]] = None, + multi_avg_fn: Optional[ + Callable[[PARAM_LIST, PARAM_LIST, Union[Tensor, int]], None] + ] = None, + use_buffers=False, + ): # noqa: D107 + super().__init__() + assert ( + avg_fn is None or multi_avg_fn is None + ), "Only one of avg_fn and multi_avg_fn should be provided" + self.module = deepcopy(model) + if device is not None: + self.module = self.module.to(device) + self.register_buffer( + "n_averaged", torch.tensor(0, dtype=torch.long, device=device) + ) + self.avg_fn = avg_fn + self.multi_avg_fn = multi_avg_fn + self.use_buffers = use_buffers + + def forward(self, *args, **kwargs): + """Forward pass.""" + return self.module(*args, **kwargs) + + def update_parameters(self, model: Module): + """Update model parameters.""" + self_param = ( + itertools.chain(self.module.parameters(), self.module.buffers()) + if self.use_buffers + else self.parameters() + ) + model_param = ( + itertools.chain(model.parameters(), model.buffers()) + if self.use_buffers + else model.parameters() + ) + self_param_detached: List[Optional[Tensor]] = [] + model_param_detached: List[Optional[Tensor]] = [] + for p_averaged, p_model in zip(self_param, model_param): + p_model_ = p_model.detach().to(p_averaged.device) + self_param_detached.append(p_averaged.detach()) + model_param_detached.append(p_model_) + if self.n_averaged == 0: + p_averaged.detach().copy_(p_model_) + + if self.n_averaged > 0: + if self.multi_avg_fn is not None or self.avg_fn is None: + grouped_tensors = _group_tensors_by_device_and_dtype( + [self_param_detached, model_param_detached] + ) + for (device, _), ( + [self_params, model_params], + _, + ) in grouped_tensors.items(): + if self.multi_avg_fn: + self.multi_avg_fn( + self_params, model_params, self.n_averaged.to(device) # type: ignore[arg-type] + ) + elif ( + device is not None + and device.type in _get_foreach_kernels_supported_devices() + ): + multi_avg_fn = get_swa_multi_avg_fn() + multi_avg_fn( + self_params, model_params, self.n_averaged.to(device) + ) + else: + avg_fn = get_swa_avg_fn() + n_averaged = self.n_averaged.to(device) + for p_averaged, p_model in zip(self_params, model_params): # type: ignore[assignment] + p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged)) + else: + for p_averaged, p_model in zip( # type: ignore[assignment] + self_param_detached, model_param_detached + ): + n_averaged = self.n_averaged.to(p_averaged.device) + p_averaged.detach().copy_( + self.avg_fn(p_averaged.detach(), p_model, n_averaged) + ) + + if not 
self.use_buffers: + # If not apply running averages to the buffers, + # keep the buffers in sync with the source model. + for b_swa, b_model in zip(self.module.buffers(), model.buffers()): + b_swa.detach().copy_(b_model.detach().to(b_swa.device)) + self.n_averaged += 1 + + +@torch.no_grad() +def update_bn( + loader: Iterable[Any], + model: Module, + device: Optional[Union[int, torch.device]] = None, +): + r"""Update BatchNorm running_mean, running_var buffers in the model. + + It performs one pass over data in `loader` to estimate the activation + statistics for BatchNorm layers in the model. + + Args: + loader (torch.utils.data.DataLoader): dataset loader to compute the + activation statistics on. Each data batch should be either a + tensor, or a list/tuple whose first element is a tensor + containing data. + model (torch.nn.Module): model for which we seek to update BatchNorm + statistics. + device (torch.device, optional): If set, data will be transferred to + :attr:`device` before being passed into :attr:`model`. + + Example: + >>> # xdoctest: +SKIP("Undefined variables") + >>> loader, model = ... + >>> torch.optim.swa_utils.update_bn(loader, model) + + .. note:: + The `update_bn` utility assumes that each data batch in :attr:`loader` + is either a tensor or a list or tuple of tensors; in the latter case it + is assumed that :meth:`model.forward()` should be called on the first + element of the list or tuple corresponding to the data batch. + """ + momenta = {} + for module in model.modules(): + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + module.reset_running_stats() + momenta[module] = module.momentum + + if not momenta: + return + + was_training = model.training + model.train() + for module in momenta.keys(): + module.momentum = None + + for input in loader: + if isinstance(input, (list, tuple)): + input = input[0] + if device is not None: + input = input.to(device) + + model(input) + + for bn_module in momenta.keys(): + bn_module.momentum = momenta[bn_module] + model.train(was_training) + + +class SWALR(LRScheduler): + r"""Anneals the learning rate in each parameter group to a fixed value. + + This learning rate scheduler is meant to be used with Stochastic Weight + Averaging (SWA) method (see `torch.optim.swa_utils.AveragedModel`). + + Args: + optimizer (torch.optim.Optimizer): wrapped optimizer + swa_lrs (float or list): the learning rate value for all param groups + together or separately for each group. + annealing_epochs (int): number of epochs in the annealing phase + (default: 10) + annealing_strategy (str): "cos" or "linear"; specifies the annealing + strategy: "cos" for cosine annealing, "linear" for linear annealing + (default: "cos") + last_epoch (int): the index of the last epoch (default: -1) + + The :class:`SWALR` scheduler can be used together with other + schedulers to switch to a constant learning rate late in the training + as in the example below. + + Example: + >>> # xdoctest: +SKIP("Undefined variables") + >>> loader, optimizer, model = ... 
+ >>> lr_lambda = lambda epoch: 0.9 + >>> scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, + >>> lr_lambda=lr_lambda) + >>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, + >>> anneal_strategy="linear", anneal_epochs=20, swa_lr=0.05) + >>> swa_start = 160 + >>> for i in range(300): + >>> for input, target in loader: + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + >>> if i > swa_start: + >>> swa_scheduler.step() + >>> else: + >>> scheduler.step() + + .. _Averaging Weights Leads to Wider Optima and Better Generalization: + https://arxiv.org/abs/1803.05407 + """ + + def __init__( + self, + optimizer: Optimizer, + swa_lr: float, + anneal_epochs=10, + anneal_strategy: Literal["cos", "linear"] = "cos", + last_epoch=-1, + ): # noqa: D107 + swa_lrs = _format_param("swa_lr", optimizer, swa_lr) + for swa_lr, group in zip(swa_lrs, optimizer.param_groups): + group["swa_lr"] = swa_lr + if anneal_strategy not in ["cos", "linear"]: + raise ValueError( + "anneal_strategy must by one of 'cos' or 'linear', " + f"instead got {anneal_strategy}" + ) + elif anneal_strategy == "cos": + self.anneal_func = self._cosine_anneal + elif anneal_strategy == "linear": + self.anneal_func = self._linear_anneal + if not isinstance(anneal_epochs, int) or anneal_epochs < 0: + raise ValueError( + f"anneal_epochs must be equal or greater than 0, got {anneal_epochs}" + ) + self.anneal_epochs = anneal_epochs + super().__init__(optimizer, last_epoch) + + @staticmethod + def _linear_anneal(t): + return t + + @staticmethod + def _cosine_anneal(t): + return (1 - math.cos(math.pi * t)) / 2 + + @staticmethod + def _get_initial_lr(lr, swa_lr, alpha): + if alpha == 1: + return swa_lr + return (lr - alpha * swa_lr) / (1 - alpha) + + def get_lr(self): + """Get learning rate.""" + # `_get_lr_called_within_step` is only available `_enable_get_lr_call`, + # so we ignore the type error here. See `LRScheduler.step()` for more details. 
+ if not self._get_lr_called_within_step: # type: ignore[attr-defined] + warnings.warn( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + UserWarning, + ) + # Set in `LRScheduler._initial_step()` + step = self._step_count - 1 # type: ignore[attr-defined] + if self.anneal_epochs == 0: + step = max(1, step) + prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs))) + prev_alpha = self.anneal_func(prev_t) + prev_lrs = [ + self._get_initial_lr(group["lr"], group["swa_lr"], prev_alpha) + for group in self.optimizer.param_groups + ] + t = max(0, min(1, step / max(1, self.anneal_epochs))) + alpha = self.anneal_func(t) + return [ + group["swa_lr"] * alpha + lr * (1 - alpha) + for group, lr in zip(self.optimizer.param_groups, prev_lrs) + ] diff --git a/lib/python3.10/site-packages/torch/package/__init__.py b/lib/python3.10/site-packages/torch/package/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..66cace5931ac17c548becfddbb0e56dbbdac3d38 --- /dev/null +++ b/lib/python3.10/site-packages/torch/package/__init__.py @@ -0,0 +1,12 @@ +from .analyze.is_from_package import is_from_package +from .file_structure_representation import Directory +from .glob_group import GlobGroup +from .importer import ( + Importer, + ObjMismatchError, + ObjNotFoundError, + OrderedImporter, + sys_importer, +) +from .package_exporter import EmptyMatchError, PackageExporter, PackagingError +from .package_importer import PackageImporter diff --git a/lib/python3.10/site-packages/torch/package/_digraph.py b/lib/python3.10/site-packages/torch/package/_digraph.py new file mode 100644 index 0000000000000000000000000000000000000000..8b753f7ebdc4b162be565eb4e923c4cfa333c597 --- /dev/null +++ b/lib/python3.10/site-packages/torch/package/_digraph.py @@ -0,0 +1,174 @@ +# mypy: allow-untyped-defs +from collections import deque +from typing import List, Set + + +class DiGraph: + """Really simple unweighted directed graph data structure to track dependencies. + + The API is pretty much the same as networkx so if you add something just + copy their API. + """ + + def __init__(self): + # Dict of node -> dict of arbitrary attributes + self._node = {} + # Nested dict of node -> successor node -> nothing. + # (didn't implement edge data) + self._succ = {} + # Nested dict of node -> predecessor node -> nothing. + self._pred = {} + + # Keep track of the order in which nodes are added to + # the graph. + self._node_order = {} + self._insertion_idx = 0 + + def add_node(self, n, **kwargs): + """Add a node to the graph. + + Args: + n: the node. Can we any object that is a valid dict key. + **kwargs: any attributes you want to attach to the node. + """ + if n not in self._node: + self._node[n] = kwargs + self._succ[n] = {} + self._pred[n] = {} + self._node_order[n] = self._insertion_idx + self._insertion_idx += 1 + else: + self._node[n].update(kwargs) + + def add_edge(self, u, v): + """Add an edge to graph between nodes ``u`` and ``v`` + + ``u`` and ``v`` will be created if they do not already exist. 
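+
+        Example (a minimal sketch; the node names are hypothetical)::
+
+            >>> g = DiGraph()
+            >>> g.add_edge("a", "b")
+            >>> list(g.successors("a"))
+            ['b']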
+ """ + # add nodes + self.add_node(u) + self.add_node(v) + + # add the edge + self._succ[u][v] = True + self._pred[v][u] = True + + def successors(self, n): + """Returns an iterator over successor nodes of n.""" + try: + return iter(self._succ[n]) + except KeyError as e: + raise ValueError(f"The node {n} is not in the digraph.") from e + + def predecessors(self, n): + """Returns an iterator over predecessors nodes of n.""" + try: + return iter(self._pred[n]) + except KeyError as e: + raise ValueError(f"The node {n} is not in the digraph.") from e + + @property + def edges(self): + """Returns an iterator over all edges (u, v) in the graph""" + for n, successors in self._succ.items(): + for succ in successors: + yield n, succ + + @property + def nodes(self): + """Returns a dictionary of all nodes to their attributes.""" + return self._node + + def __iter__(self): + """Iterate over the nodes.""" + return iter(self._node) + + def __contains__(self, n): + """Returns True if ``n`` is a node in the graph, False otherwise.""" + try: + return n in self._node + except TypeError: + return False + + def forward_transitive_closure(self, src: str) -> Set[str]: + """Returns a set of nodes that are reachable from src""" + + result = set(src) + working_set = deque(src) + while len(working_set) > 0: + cur = working_set.popleft() + for n in self.successors(cur): + if n not in result: + result.add(n) + working_set.append(n) + return result + + def backward_transitive_closure(self, src: str) -> Set[str]: + """Returns a set of nodes that are reachable from src in reverse direction""" + + result = set(src) + working_set = deque(src) + while len(working_set) > 0: + cur = working_set.popleft() + for n in self.predecessors(cur): + if n not in result: + result.add(n) + working_set.append(n) + return result + + def all_paths(self, src: str, dst: str): + """Returns a subgraph rooted at src that shows all the paths to dst.""" + + result_graph = DiGraph() + # First compute forward transitive closure of src (all things reachable from src). + forward_reachable_from_src = self.forward_transitive_closure(src) + + if dst not in forward_reachable_from_src: + return result_graph + + # Second walk the reverse dependencies of dst, adding each node to + # the output graph iff it is also present in forward_reachable_from_src. + # we don't use backward_transitive_closures for optimization purposes + working_set = deque(dst) + while len(working_set) > 0: + cur = working_set.popleft() + for n in self.predecessors(cur): + if n in forward_reachable_from_src: + result_graph.add_edge(n, cur) + # only explore further if its reachable from src + working_set.append(n) + + return result_graph.to_dot() + + def first_path(self, dst: str) -> List[str]: + """Returns a list of nodes that show the first path that resulted in dst being added to the graph.""" + path = [] + + while dst: + path.append(dst) + candidates = self._pred[dst].keys() + dst, min_idx = "", None + for candidate in candidates: + idx = self._node_order.get(candidate, None) + if idx is None: + break + if min_idx is None or idx < min_idx: + min_idx = idx + dst = candidate + + return list(reversed(path)) + + def to_dot(self) -> str: + """Returns the dot representation of the graph. + + Returns: + A dot representation of the graph. 
+ """ + edges = "\n".join(f'"{f}" -> "{t}";' for f, t in self.edges) + return f"""\ +digraph G {{ +rankdir = LR; +node [shape=box]; +{edges} +}} +""" diff --git a/lib/python3.10/site-packages/torch/package/_directory_reader.py b/lib/python3.10/site-packages/torch/package/_directory_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..f58065f47dc4374689abc7ac1c6f1ca8048ab5c7 --- /dev/null +++ b/lib/python3.10/site-packages/torch/package/_directory_reader.py @@ -0,0 +1,65 @@ +# mypy: allow-untyped-defs +import os.path +from glob import glob +from typing import cast + +import torch +from torch.types import Storage + + +__serialization_id_record_name__ = ".data/serialization_id" + + +# because get_storage_from_record returns a tensor!? +class _HasStorage: + def __init__(self, storage): + self._storage = storage + + def storage(self): + return self._storage + + +class DirectoryReader: + """ + Class to allow PackageImporter to operate on unzipped packages. Methods + copy the behavior of the internal PyTorchFileReader class (which is used for + accessing packages in all other cases). + + N.B.: ScriptObjects are not depickleable or accessible via this DirectoryReader + class due to ScriptObjects requiring an actual PyTorchFileReader instance. + """ + + def __init__(self, directory): + self.directory = directory + + def get_record(self, name): + filename = f"{self.directory}/{name}" + with open(filename, "rb") as f: + return f.read() + + def get_storage_from_record(self, name, numel, dtype): + filename = f"{self.directory}/{name}" + nbytes = torch._utils._element_size(dtype) * numel + storage = cast(Storage, torch.UntypedStorage) + return _HasStorage(storage.from_file(filename=filename, nbytes=nbytes)) + + def has_record(self, path): + full_path = os.path.join(self.directory, path) + return os.path.isfile(full_path) + + def get_all_records( + self, + ): + files = [] + for filename in glob(f"{self.directory}/**", recursive=True): + if not os.path.isdir(filename): + files.append(filename[len(self.directory) + 1 :]) + return files + + def serialization_id( + self, + ): + if self.has_record(__serialization_id_record_name__): + return self.get_record(__serialization_id_record_name__) + else: + return "" diff --git a/lib/python3.10/site-packages/torch/package/_importlib.py b/lib/python3.10/site-packages/torch/package/_importlib.py new file mode 100644 index 0000000000000000000000000000000000000000..609efd294c4c9650d890fd36aafc9f521068ce8b --- /dev/null +++ b/lib/python3.10/site-packages/torch/package/_importlib.py @@ -0,0 +1,95 @@ +# mypy: allow-untyped-defs +import _warnings +import os.path + + +# note: implementations +# copied from cpython's import code + + +# _zip_searchorder defines how we search for a module in the Zip +# archive: we first search for a package __init__, then for +# non-package .pyc, and .py entries. The .pyc entries +# are swapped by initzipimport() if we run in optimized mode. Also, +# '/' is replaced by path_sep there. + +_zip_searchorder = ( + ("/__init__.py", True), + (".py", False), +) + + +# Replace any occurrences of '\r\n?' in the input string with '\n'. +# This converts DOS and Mac line endings to Unix line endings. 
+def _normalize_line_endings(source):
+    source = source.replace(b"\r\n", b"\n")
+    source = source.replace(b"\r", b"\n")
+    return source
+
+
+def _resolve_name(name, package, level):
+    """Resolve a relative module name to an absolute one."""
+    bits = package.rsplit(".", level - 1)
+    if len(bits) < level:
+        raise ValueError("attempted relative import beyond top-level package")
+    base = bits[0]
+    return f"{base}.{name}" if name else base
+
+
+def _sanity_check(name, package, level):
+    """Verify arguments are "sane"."""
+    if not isinstance(name, str):
+        raise TypeError(f"module name must be str, not {type(name)}")
+    if level < 0:
+        raise ValueError("level must be >= 0")
+    if level > 0:
+        if not isinstance(package, str):
+            raise TypeError("__package__ not set to a string")
+        elif not package:
+            raise ImportError("attempted relative import with no known parent package")
+    if not name and level == 0:
+        raise ValueError("Empty module name")
+
+
+def _calc___package__(globals):
+    """Calculate what __package__ should be.
+
+    __package__ is not guaranteed to be defined or could be set to None
+    to represent that its proper value is unknown.
+
+    """
+    package = globals.get("__package__")
+    spec = globals.get("__spec__")
+    if package is not None:
+        if spec is not None and package != spec.parent:
+            _warnings.warn(  # noqa: G010
+                f"__package__ != __spec__.parent ({package!r} != {spec.parent!r})",  # noqa: G004
+                ImportWarning,
+                stacklevel=3,
+            )
+        return package
+    elif spec is not None:
+        return spec.parent
+    else:
+        _warnings.warn(  # noqa: G010
+            "can't resolve package from __spec__ or __package__, "
+            "falling back on __name__ and __path__",
+            ImportWarning,
+            stacklevel=3,
+        )
+        package = globals["__name__"]
+        if "__path__" not in globals:
+            package = package.rpartition(".")[0]
+    return package
+
+
+def _normalize_path(path):
+    """Normalize a path by ensuring it is a string.
+
+    If the resulting string contains path separators, an exception is raised.
+    """
+    parent, file_name = os.path.split(path)
+    if parent:
+        raise ValueError(f"{path!r} must be only a file name")
+    else:
+        return file_name
diff --git a/lib/python3.10/site-packages/torch/package/_mangling.py b/lib/python3.10/site-packages/torch/package/_mangling.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cf3791d160444e5c6ae397b7fccf6cadca96d0a
--- /dev/null
+++ b/lib/python3.10/site-packages/torch/package/_mangling.py
@@ -0,0 +1,64 @@
+# mypy: allow-untyped-defs
+"""Import mangling.
+See mangling.md for details.
+"""
+import re
+
+
+_mangle_index = 0
+
+
+class PackageMangler:
+    """
+    Used on import, to ensure that all modules imported have a shared mangle parent.
+    """
+
+    def __init__(self) -> None:
+        global _mangle_index
+        self._mangle_index = _mangle_index
+        # Increment the global index
+        _mangle_index += 1
+        # Angle brackets are used so that there is almost no chance of
+        # confusing this module for a real module. Plus, it is Python's
+        # preferred way of denoting special modules.
+        self._mangle_parent = f"<torch_package_{self._mangle_index}>"
+
+    def mangle(self, name) -> str:
+        assert len(name) != 0
+        return self._mangle_parent + "." + name
+
+    def demangle(self, mangled: str) -> str:
+        """
+        Note: This only demangles names that were mangled by this specific
+        PackageMangler. It will pass through names created by a different
+        PackageMangler instance.
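+
+        Example (a minimal sketch; the exact mangle parent is generated
+        per :class:`PackageMangler` instance)::
+
+            >>> m = PackageMangler()
+            >>> m.demangle(m.mangle("foo.bar"))
+            'foo.bar'
+            >>> m.demangle("foo.bar")  # not produced by this mangler
+            'foo.bar'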
+ """ + if mangled.startswith(self._mangle_parent + "."): + return mangled.partition(".")[2] + + # wasn't a mangled name + return mangled + + def parent_name(self): + return self._mangle_parent + + +def is_mangled(name: str) -> bool: + return bool(re.match(r"", name)) + + +def demangle(name: str) -> str: + """ + Note: Unlike PackageMangler.demangle, this version works on any + mangled name, irrespective of which PackageMangler created it. + """ + if is_mangled(name): + first, sep, last = name.partition(".") + # If there is only a base mangle prefix, e.g. '', + # then return an empty string. + return last if len(sep) != 0 else "" + return name + + +def get_mangle_prefix(name: str) -> str: + return name.partition(".")[0] if is_mangled(name) else name diff --git a/lib/python3.10/site-packages/torch/package/_mock.py b/lib/python3.10/site-packages/torch/package/_mock.py new file mode 100644 index 0000000000000000000000000000000000000000..44876b1a1d3fb3ef4a485eaf16f26755d5bb00f2 --- /dev/null +++ b/lib/python3.10/site-packages/torch/package/_mock.py @@ -0,0 +1,123 @@ +# mypy: allow-untyped-defs +_magic_methods = [ + "__subclasscheck__", + "__hex__", + "__rmul__", + "__float__", + "__idiv__", + "__setattr__", + "__div__", + "__invert__", + "__nonzero__", + "__rshift__", + "__eq__", + "__pos__", + "__round__", + "__rand__", + "__or__", + "__complex__", + "__divmod__", + "__len__", + "__reversed__", + "__copy__", + "__reduce__", + "__deepcopy__", + "__rdivmod__", + "__rrshift__", + "__ifloordiv__", + "__hash__", + "__iand__", + "__xor__", + "__isub__", + "__oct__", + "__ceil__", + "__imod__", + "__add__", + "__truediv__", + "__unicode__", + "__le__", + "__delitem__", + "__sizeof__", + "__sub__", + "__ne__", + "__pow__", + "__bytes__", + "__mul__", + "__itruediv__", + "__bool__", + "__iter__", + "__abs__", + "__gt__", + "__iadd__", + "__enter__", + "__floordiv__", + "__call__", + "__neg__", + "__and__", + "__ixor__", + "__getitem__", + "__exit__", + "__cmp__", + "__getstate__", + "__index__", + "__contains__", + "__floor__", + "__lt__", + "__getattr__", + "__mod__", + "__trunc__", + "__delattr__", + "__instancecheck__", + "__setitem__", + "__ipow__", + "__ilshift__", + "__long__", + "__irshift__", + "__imul__", + "__lshift__", + "__dir__", + "__ge__", + "__int__", + "__ior__", +] + + +class MockedObject: + _name: str + + def __new__(cls, *args, **kwargs): + # _suppress_err is set by us in the mocked module impl, so that we can + # construct instances of MockedObject to hand out to people looking up + # module attributes. + + # Any other attempt to construct a MockedObject instance (say, in the + # unpickling process) should give an error. + if not kwargs.get("_suppress_err"): + raise NotImplementedError( + f"Object '{cls._name}' was mocked out during packaging " + f"but it is being used in '__new__'. If this error is " + "happening during 'load_pickle', please ensure that your " + "pickled object doesn't contain any mocked objects." + ) + # Otherwise, this is just a regular object creation + # (e.g. `x = MockedObject("foo")`), so pass it through normally. 
+        return super().__new__(cls)
+
+    def __init__(self, name: str, _suppress_err: bool):
+        self.__dict__["_name"] = name
+
+    def __repr__(self):
+        return f"MockedObject({self._name})"
+
+
+def install_method(method_name):
+    def _not_implemented(self, *args, **kwargs):
+        raise NotImplementedError(
+            f"Object '{self._name}' was mocked out during packaging but it is being used in {method_name}"
+        )
+
+    setattr(MockedObject, method_name, _not_implemented)
+
+
+for method_name in _magic_methods:
+    install_method(method_name)
diff --git a/lib/python3.10/site-packages/torch/package/_package_pickler.py b/lib/python3.10/site-packages/torch/package/_package_pickler.py
new file mode 100644
index 0000000000000000000000000000000000000000..7845ffe39a2a2efbb830eebaa24ccb8ef98ad06c
--- /dev/null
+++ b/lib/python3.10/site-packages/torch/package/_package_pickler.py
@@ -0,0 +1,118 @@
+# mypy: allow-untyped-defs
+from pickle import (  # type: ignore[attr-defined]
+    _compat_pickle,
+    _extension_registry,
+    _getattribute,
+    _Pickler,
+    EXT1,
+    EXT2,
+    EXT4,
+    GLOBAL,
+    Pickler,
+    PicklingError,
+    STACK_GLOBAL,
+)
+from struct import pack
+from types import FunctionType
+
+from .importer import Importer, ObjMismatchError, ObjNotFoundError, sys_importer
+
+
+class PackagePickler(_Pickler):
+    """Package-aware pickler.
+
+    This behaves the same as a normal pickler, except it uses an `Importer`
+    to find objects and modules to save.
+    """
+
+    def __init__(self, importer: Importer, *args, **kwargs):
+        self.importer = importer
+        super().__init__(*args, **kwargs)
+
+        # Make sure the dispatch table copied from _Pickler is up-to-date.
+        # Previous issues have been encountered where a library (e.g. dill)
+        # mutate _Pickler.dispatch, PackagePickler makes a copy when this lib
+        # is imported, then the offending library removes its dispatch entries,
+        # leaving PackagePickler with a stale dispatch table that may cause
+        # unwanted behavior.
+        self.dispatch = _Pickler.dispatch.copy()  # type: ignore[misc]
+        self.dispatch[FunctionType] = PackagePickler.save_global  # type: ignore[assignment]
+
+    def save_global(self, obj, name=None):
+        # unfortunately the pickler code is factored in a way that
+        # forces us to copy/paste this function. The only change is marked
+        # CHANGED below.
+        write = self.write  # type: ignore[attr-defined]
+        memo = self.memo  # type: ignore[attr-defined]
+
+        # CHANGED: import module from module environment instead of __import__
+        try:
+            module_name, name = self.importer.get_name(obj, name)
+        except (ObjNotFoundError, ObjMismatchError) as err:
+            raise PicklingError(f"Can't pickle {obj}: {str(err)}") from None
+
+        module = self.importer.import_module(module_name)
+        _, parent = _getattribute(module, name)
+        # END CHANGED
+
+        if self.proto >= 2:  # type: ignore[attr-defined]
+            code = _extension_registry.get((module_name, name))
+            if code:
+                assert code > 0
+                if code <= 0xFF:
+                    write(EXT1 + pack("<B", code))
+                elif code <= 0xFFFF:
+                    write(EXT2 + pack("<H", code))
+                else:
+                    write(EXT4 + pack("<i", code))
+                return
+        lastname = name.rpartition(".")[2]
+        if parent is module:
+            name = lastname
+        # Non-ASCII identifiers are supported only with protocols >= 3.
+ if self.proto >= 4: # type: ignore[attr-defined] + self.save(module_name) # type: ignore[attr-defined] + self.save(name) # type: ignore[attr-defined] + write(STACK_GLOBAL) + elif parent is not module: + self.save_reduce(getattr, (parent, lastname)) # type: ignore[attr-defined] + elif self.proto >= 3: # type: ignore[attr-defined] + write( + GLOBAL + + bytes(module_name, "utf-8") + + b"\n" + + bytes(name, "utf-8") + + b"\n" + ) + else: + if self.fix_imports: # type: ignore[attr-defined] + r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING + r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING + if (module_name, name) in r_name_mapping: + module_name, name = r_name_mapping[(module_name, name)] + elif module_name in r_import_mapping: + module_name = r_import_mapping[module_name] + try: + write( + GLOBAL + + bytes(module_name, "ascii") + + b"\n" + + bytes(name, "ascii") + + b"\n" + ) + except UnicodeEncodeError: + raise PicklingError( + "can't pickle global identifier '%s.%s' using " + "pickle protocol %i" % (module, name, self.proto) # type: ignore[attr-defined] + ) from None + + self.memoize(obj) # type: ignore[attr-defined] + + +def create_pickler(data_buf, importer, protocol=4): + if importer is sys_importer: + # if we are using the normal import library system, then + # we can use the C implementation of pickle which is faster + return Pickler(data_buf, protocol=protocol) + else: + return PackagePickler(importer, data_buf, protocol=protocol) diff --git a/lib/python3.10/site-packages/torch/package/_package_unpickler.py b/lib/python3.10/site-packages/torch/package/_package_unpickler.py new file mode 100644 index 0000000000000000000000000000000000000000..890e6b4e03ba076e30512712d57c4bf715c4c8bb --- /dev/null +++ b/lib/python3.10/site-packages/torch/package/_package_unpickler.py @@ -0,0 +1,27 @@ +# mypy: allow-untyped-defs +import _compat_pickle +import pickle + +from .importer import Importer + + +class PackageUnpickler(pickle._Unpickler): # type: ignore[name-defined] + """Package-aware unpickler. + + This behaves the same as a normal unpickler, except it uses `importer` to + find any global names that it encounters while unpickling. + """ + + def __init__(self, importer: Importer, *args, **kwargs): + super().__init__(*args, **kwargs) + self._importer = importer + + def find_class(self, module, name): + # Subclasses may override this. + if self.proto < 3 and self.fix_imports: # type: ignore[attr-defined] + if (module, name) in _compat_pickle.NAME_MAPPING: + module, name = _compat_pickle.NAME_MAPPING[(module, name)] + elif module in _compat_pickle.IMPORT_MAPPING: + module = _compat_pickle.IMPORT_MAPPING[module] + mod = self._importer.import_module(module) + return getattr(mod, name) diff --git a/lib/python3.10/site-packages/torch/package/_stdlib.py b/lib/python3.10/site-packages/torch/package/_stdlib.py new file mode 100644 index 0000000000000000000000000000000000000000..2d5145b40aa701043badad77889ff50da06ad497 --- /dev/null +++ b/lib/python3.10/site-packages/torch/package/_stdlib.py @@ -0,0 +1,465 @@ +# mypy: allow-untyped-defs +"""List of Python standard library modules. + +Sadly, there is no reliable way to tell whether a module is part of the +standard library except by comparing to a canonical list. + +This is taken from https://github.com/PyCQA/isort/tree/develop/isort/stdlibs, +which itself is sourced from the Python documentation. 
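+
+A minimal usage sketch (the module names below are only illustrative)::
+
+    >>> is_stdlib_module("os.path")
+    True
+    >>> is_stdlib_module("torch")
+    False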
+""" + +import sys + + +def is_stdlib_module(module: str) -> bool: + base_module = module.partition(".")[0] + return base_module in _get_stdlib_modules() + + +def _get_stdlib_modules(): + if sys.version_info.major == 3: + if sys.version_info.minor == 8: + return stdlib3_8 + if sys.version_info.minor == 9: + return stdlib3_9 + if sys.version_info.minor >= 10: + return sys.stdlib_module_names # type: ignore[attr-defined] + elif sys.version_info.major > 3: + return sys.stdlib_module_names # type: ignore[attr-defined] + + raise RuntimeError(f"Unsupported Python version: {sys.version_info}") + + +stdlib3_8 = { + "_dummy_thread", + "_thread", + "abc", + "aifc", + "argparse", + "array", + "ast", + "asynchat", + "asyncio", + "asyncore", + "atexit", + "audioop", + "base64", + "bdb", + "binascii", + "binhex", + "bisect", + "builtins", + "bz2", + "cProfile", + "calendar", + "cgi", + "cgitb", + "chunk", + "cmath", + "cmd", + "code", + "codecs", + "codeop", + "collections", + "colorsys", + "compileall", + "concurrent", + "configparser", + "contextlib", + "contextvars", + "copy", + "copyreg", + "crypt", + "csv", + "ctypes", + "curses", + "dataclasses", + "datetime", + "dbm", + "decimal", + "difflib", + "dis", + "distutils", + "doctest", + "dummy_threading", + "email", + "encodings", + "ensurepip", + "enum", + "errno", + "faulthandler", + "fcntl", + "filecmp", + "fileinput", + "fnmatch", + "formatter", + "fractions", + "ftplib", + "functools", + "gc", + "getopt", + "getpass", + "gettext", + "glob", + "grp", + "gzip", + "hashlib", + "heapq", + "hmac", + "html", + "http", + "imaplib", + "imghdr", + "imp", + "importlib", + "inspect", + "io", + "ipaddress", + "itertools", + "json", + "keyword", + "lib2to3", + "linecache", + "locale", + "logging", + "lzma", + "mailbox", + "mailcap", + "marshal", + "math", + "mimetypes", + "mmap", + "modulefinder", + "msilib", + "msvcrt", + "multiprocessing", + "netrc", + "nis", + "nntplib", + "ntpath", + "numbers", + "operator", + "optparse", + "os", + "ossaudiodev", + "parser", + "pathlib", + "pdb", + "pickle", + "pickletools", + "pipes", + "pkgutil", + "platform", + "plistlib", + "poplib", + "posix", + "posixpath", + "pprint", + "profile", + "pstats", + "pty", + "pwd", + "py_compile", + "pyclbr", + "pydoc", + "queue", + "quopri", + "random", + "re", + "readline", + "reprlib", + "resource", + "rlcompleter", + "runpy", + "sched", + "secrets", + "select", + "selectors", + "shelve", + "shlex", + "shutil", + "signal", + "site", + "smtpd", + "smtplib", + "sndhdr", + "socket", + "socketserver", + "spwd", + "sqlite3", + "sre", + "sre_compile", + "sre_constants", + "sre_parse", + "ssl", + "stat", + "statistics", + "string", + "stringprep", + "struct", + "subprocess", + "sunau", + "symbol", + "symtable", + "sys", + "sysconfig", + "syslog", + "tabnanny", + "tarfile", + "telnetlib", + "tempfile", + "termios", + "test", + "textwrap", + "threading", + "time", + "timeit", + "tkinter", + "token", + "tokenize", + "trace", + "traceback", + "tracemalloc", + "tty", + "turtle", + "turtledemo", + "types", + "typing", + "unicodedata", + "unittest", + "urllib", + "uu", + "uuid", + "venv", + "warnings", + "wave", + "weakref", + "webbrowser", + "winreg", + "winsound", + "wsgiref", + "xdrlib", + "xml", + "xmlrpc", + "zipapp", + "zipfile", + "zipimport", + "zlib", +} + +stdlib3_9 = { + "_thread", + "abc", + "aifc", + "argparse", + "array", + "ast", + "asynchat", + "asyncio", + "asyncore", + "atexit", + "audioop", + "base64", + "bdb", + "binascii", + "binhex", + "bisect", + "builtins", + "bz2", + 
"cProfile", + "calendar", + "cgi", + "cgitb", + "chunk", + "cmath", + "cmd", + "code", + "codecs", + "codeop", + "collections", + "colorsys", + "compileall", + "concurrent", + "configparser", + "contextlib", + "contextvars", + "copy", + "copyreg", + "crypt", + "csv", + "ctypes", + "curses", + "dataclasses", + "datetime", + "dbm", + "decimal", + "difflib", + "dis", + "distutils", + "doctest", + "email", + "encodings", + "ensurepip", + "enum", + "errno", + "faulthandler", + "fcntl", + "filecmp", + "fileinput", + "fnmatch", + "formatter", + "fractions", + "ftplib", + "functools", + "gc", + "getopt", + "getpass", + "gettext", + "glob", + "graphlib", + "grp", + "gzip", + "hashlib", + "heapq", + "hmac", + "html", + "http", + "imaplib", + "imghdr", + "imp", + "importlib", + "inspect", + "io", + "ipaddress", + "itertools", + "json", + "keyword", + "lib2to3", + "linecache", + "locale", + "logging", + "lzma", + "mailbox", + "mailcap", + "marshal", + "math", + "mimetypes", + "mmap", + "modulefinder", + "msilib", + "msvcrt", + "multiprocessing", + "netrc", + "nis", + "nntplib", + "ntpath", + "numbers", + "operator", + "optparse", + "os", + "ossaudiodev", + "parser", + "pathlib", + "pdb", + "pickle", + "pickletools", + "pipes", + "pkgutil", + "platform", + "plistlib", + "poplib", + "posix", + "posixpath", + "pprint", + "profile", + "pstats", + "pty", + "pwd", + "py_compile", + "pyclbr", + "pydoc", + "queue", + "quopri", + "random", + "re", + "readline", + "reprlib", + "resource", + "rlcompleter", + "runpy", + "sched", + "secrets", + "select", + "selectors", + "shelve", + "shlex", + "shutil", + "signal", + "site", + "smtpd", + "smtplib", + "sndhdr", + "socket", + "socketserver", + "spwd", + "sqlite3", + "sre", + "sre_compile", + "sre_constants", + "sre_parse", + "ssl", + "stat", + "statistics", + "string", + "stringprep", + "struct", + "subprocess", + "sunau", + "symbol", + "symtable", + "sys", + "sysconfig", + "syslog", + "tabnanny", + "tarfile", + "telnetlib", + "tempfile", + "termios", + "test", + "textwrap", + "threading", + "time", + "timeit", + "tkinter", + "token", + "tokenize", + "trace", + "traceback", + "tracemalloc", + "tty", + "turtle", + "turtledemo", + "types", + "typing", + "unicodedata", + "unittest", + "urllib", + "uu", + "uuid", + "venv", + "warnings", + "wave", + "weakref", + "webbrowser", + "winreg", + "winsound", + "wsgiref", + "xdrlib", + "xml", + "xmlrpc", + "zipapp", + "zipfile", + "zipimport", + "zlib", + "zoneinfo", +} diff --git a/lib/python3.10/site-packages/torch/package/file_structure_representation.py b/lib/python3.10/site-packages/torch/package/file_structure_representation.py new file mode 100644 index 0000000000000000000000000000000000000000..e1137234ab739b480ae7d5686c7dc313147ba314 --- /dev/null +++ b/lib/python3.10/site-packages/torch/package/file_structure_representation.py @@ -0,0 +1,138 @@ +# mypy: allow-untyped-defs +from typing import Dict, List + +from .glob_group import GlobGroup, GlobPattern + + +__all__ = ["Directory"] + + +class Directory: + """A file structure representation. Organized as Directory nodes that have lists of + their Directory children. Directories for a package are created by calling + :meth:`PackageImporter.file_structure`.""" + + def __init__(self, name: str, is_dir: bool): + self.name = name + self.is_dir = is_dir + self.children: Dict[str, Directory] = {} + + def _get_dir(self, dirs: List[str]) -> "Directory": + """Builds path of Directories if not yet built and returns last directory + in list. 
+ + Args: + dirs (List[str]): List of directory names that are treated like a path. + + Returns: + :class:`Directory`: The last Directory specified in the dirs list. + """ + if len(dirs) == 0: + return self + dir_name = dirs[0] + if dir_name not in self.children: + self.children[dir_name] = Directory(dir_name, True) + return self.children[dir_name]._get_dir(dirs[1:]) + + def _add_file(self, file_path: str): + """Adds a file to a Directory. + + Args: + file_path (str): Path of file to add. Last element is added as a file while + other paths items are added as directories. + """ + *dirs, file = file_path.split("/") + dir = self._get_dir(dirs) + dir.children[file] = Directory(file, False) + + def has_file(self, filename: str) -> bool: + """Checks if a file is present in a :class:`Directory`. + + Args: + filename (str): Path of file to search for. + Returns: + bool: If a :class:`Directory` contains the specified file. + """ + lineage = filename.split("/", maxsplit=1) + child = lineage[0] + grandchildren = lineage[1] if len(lineage) > 1 else None + if child in self.children.keys(): + if grandchildren is None: + return True + else: + return self.children[child].has_file(grandchildren) + return False + + def __str__(self): + str_list: List[str] = [] + self._stringify_tree(str_list) + return "".join(str_list) + + def _stringify_tree( + self, + str_list: List[str], + preamble: str = "", + dir_ptr: str = "\u2500\u2500\u2500 ", + ): + """Recursive method to generate print-friendly version of a Directory.""" + space = " " + branch = "\u2502 " + tee = "\u251c\u2500\u2500 " + last = "\u2514\u2500\u2500 " + + # add this directory's representation + str_list.append(f"{preamble}{dir_ptr}{self.name}\n") + + # add directory's children representations + if dir_ptr == tee: + preamble = preamble + branch + else: + preamble = preamble + space + + file_keys: List[str] = [] + dir_keys: List[str] = [] + for key, val in self.children.items(): + if val.is_dir: + dir_keys.append(key) + else: + file_keys.append(key) + + for index, key in enumerate(sorted(dir_keys)): + if (index == len(dir_keys) - 1) and len(file_keys) == 0: + self.children[key]._stringify_tree(str_list, preamble, last) + else: + self.children[key]._stringify_tree(str_list, preamble, tee) + for index, file in enumerate(sorted(file_keys)): + pointer = last if (index == len(file_keys) - 1) else tee + str_list.append(f"{preamble}{pointer}{file}\n") + + +def _create_directory_from_file_list( + filename: str, + file_list: List[str], + include: "GlobPattern" = "**", + exclude: "GlobPattern" = (), +) -> Directory: + """Return a :class:`Directory` file structure representation created from a list of files. + + Args: + filename (str): The name given to the top-level directory that will be the + relative root for all file paths found in the file_list. + + file_list (List[str]): List of files to add to the top-level directory. + + include (Union[List[str], str]): An optional pattern that limits what is included from the file_list to + files whose name matches the pattern. + + exclude (Union[List[str], str]): An optional pattern that excludes files whose name match the pattern. + + Returns: + :class:`Directory`: a :class:`Directory` file structure representation created from a list of files. 
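+
+    Example (a minimal sketch; the file paths are hypothetical)::
+
+        >>> d = _create_directory_from_file_list("pkg", ["a/x.py", "a/b/y.py"])
+        >>> d.has_file("a/b/y.py")
+        True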
+ """ + glob_pattern = GlobGroup(include, exclude=exclude, separator="/") + + top_dir = Directory(filename, True) + for file in file_list: + if glob_pattern.matches(file): + top_dir._add_file(file) + return top_dir diff --git a/lib/python3.10/site-packages/torch/package/find_file_dependencies.py b/lib/python3.10/site-packages/torch/package/find_file_dependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..80cfccbec50a6c5f4cb3839d1afc5e4ea404efd2 --- /dev/null +++ b/lib/python3.10/site-packages/torch/package/find_file_dependencies.py @@ -0,0 +1,96 @@ +# mypy: allow-untyped-defs +import ast +from typing import List, Optional, Tuple + +from ._importlib import _resolve_name + + +class _ExtractModuleReferences(ast.NodeVisitor): + """ + Extract the list of global variables a block of code will read and write + """ + + @classmethod + def run(cls, src: str, package: str) -> List[Tuple[str, Optional[str]]]: + visitor = cls(package) + tree = ast.parse(src) + visitor.visit(tree) + return list(visitor.references.keys()) + + def __init__(self, package): + super().__init__() + self.package = package + self.references = {} + + def _absmodule(self, module_name: str, level: int) -> str: + if level > 0: + return _resolve_name(module_name, self.package, level) + return module_name + + def visit_Import(self, node): + for alias in node.names: + self.references[(alias.name, None)] = True + + def visit_ImportFrom(self, node): + name = self._absmodule(node.module, 0 if node.level is None else node.level) + for alias in node.names: + # from my_package import foo + # foo may be a module, so we have to add it to the list of + # potential references, if import of it fails, we will ignore it + if alias.name != "*": + self.references[(name, alias.name)] = True + else: + self.references[(name, None)] = True + + def _grab_node_int(self, node): + return node.value + + def _grab_node_str(self, node): + return node.value + + def visit_Call(self, node): + # __import__ calls aren't routed to the visit_Import/From nodes + if hasattr(node.func, "id") and node.func.id == "__import__": + try: + name = self._grab_node_str(node.args[0]) + fromlist = [] + level = 0 + if len(node.args) > 3: + for v in node.args[3].elts: + fromlist.append(self._grab_node_str(v)) + elif hasattr(node, "keywords"): + for keyword in node.keywords: + if keyword.arg == "fromlist": + for v in keyword.value.elts: + fromlist.append(self._grab_node_str(v)) + if len(node.args) > 4: + level = self._grab_node_int(node.args[4]) + elif hasattr(node, "keywords"): + for keyword in node.keywords: + if keyword.arg == "level": + level = self._grab_node_int(keyword.value) + if fromlist == []: + # the top-level package (the name up till the first dot) is returned + # when the fromlist argument is empty in normal import system, + # we need to include top level package to match this behavior and last + # level package to capture the intended dependency of user + self.references[(name, None)] = True + top_name = name.rsplit(".", maxsplit=1)[0] + if top_name != name: + top_name = self._absmodule(top_name, level) + self.references[(top_name, None)] = True + else: + name = self._absmodule(name, level) + for alias in fromlist: + # fromlist args may be submodules, so we have to add the fromlist args + # to the list of potential references. 
If import of an arg fails we + # will ignore it, similar to visit_ImportFrom + if alias != "*": + self.references[(name, alias)] = True + else: + self.references[(name, None)] = True + except Exception as e: + return + + +find_files_source_depends_on = _ExtractModuleReferences.run diff --git a/lib/python3.10/site-packages/torch/package/glob_group.py b/lib/python3.10/site-packages/torch/package/glob_group.py new file mode 100644 index 0000000000000000000000000000000000000000..1c1d31930fd18fa5bbad453d2b55eac07fa20869 --- /dev/null +++ b/lib/python3.10/site-packages/torch/package/glob_group.py @@ -0,0 +1,84 @@ +# mypy: allow-untyped-defs +import re +from typing import Iterable, Union + + +GlobPattern = Union[str, Iterable[str]] + + +class GlobGroup: + """A set of patterns that candidate strings will be matched against. + + A candidate is composed of a list of segments separated by ``separator``, e.g. "foo.bar.baz". + + A pattern contains one or more segments. Segments can be: + - A literal string (e.g. "foo"), which matches exactly. + - A string containing a wildcard (e.g. "torch*", or "foo*baz*"). The wildcard matches + any string, including the empty string. + - A double wildcard ("**"). This matches against zero or more complete segments. + + Examples: + ``torch.**``: matches ``torch`` and all its submodules, e.g. ``torch.nn`` and ``torch.nn.functional``. + ``torch.*``: matches ``torch.nn`` or ``torch.functional``, but not ``torch.nn.functional``. + ``torch*.**``: matches ``torch``, ``torchvision``, and all their submodules. + + A candidates will match the ``GlobGroup`` if it matches any of the ``include`` patterns and + none of the ``exclude`` patterns. + + Args: + include (Union[str, Iterable[str]]): A string or list of strings, + each representing a pattern to be matched against. A candidate + will match if it matches *any* include pattern + exclude (Union[str, Iterable[str]]): A string or list of strings, + each representing a pattern to be matched against. A candidate + will be excluded from matching if it matches *any* exclude pattern. + separator (str): A string that delimits segments in candidates and + patterns. By default this is "." which corresponds to how modules are + named in Python. Another common value for this is "/", which is + the Unix path separator. + """ + + def __init__( + self, include: GlobPattern, *, exclude: GlobPattern = (), separator: str = "." + ): + self._dbg = f"GlobGroup(include={include}, exclude={exclude})" + self.include = GlobGroup._glob_list(include, separator) + self.exclude = GlobGroup._glob_list(exclude, separator) + self.separator = separator + + def __str__(self): + return self._dbg + + def __repr__(self): + return self._dbg + + def matches(self, candidate: str) -> bool: + candidate = self.separator + candidate + return any(p.fullmatch(candidate) for p in self.include) and all( + not p.fullmatch(candidate) for p in self.exclude + ) + + @staticmethod + def _glob_list(elems: GlobPattern, separator: str = "."): + if isinstance(elems, str): + return [GlobGroup._glob_to_re(elems, separator)] + else: + return [GlobGroup._glob_to_re(e, separator) for e in elems] + + @staticmethod + def _glob_to_re(pattern: str, separator: str = "."): + # to avoid corner cases for the first component, we prefix the candidate string + # with '.' so `import torch` will regex against `.torch`, assuming '.' 
is the separator + def component_to_re(component): + if "**" in component: + if component == "**": + return "(" + re.escape(separator) + "[^" + separator + "]+)*" + else: + raise ValueError("** can only appear as an entire path segment") + else: + return re.escape(separator) + ("[^" + separator + "]*").join( + re.escape(x) for x in component.split("*") + ) + + result = "".join(component_to_re(c) for c in pattern.split(separator)) + return re.compile(result) diff --git a/lib/python3.10/site-packages/torch/package/importer.py b/lib/python3.10/site-packages/torch/package/importer.py new file mode 100644 index 0000000000000000000000000000000000000000..2fb2891e076c7baf9dda3e1f92c5527f0f54f71d --- /dev/null +++ b/lib/python3.10/site-packages/torch/package/importer.py @@ -0,0 +1,234 @@ +# mypy: allow-untyped-defs +import importlib +from abc import ABC, abstractmethod +from pickle import ( # type: ignore[attr-defined] + _getattribute, + _Pickler, + whichmodule as _pickle_whichmodule, +) +from types import ModuleType +from typing import Any, Dict, List, Optional, Tuple + +from ._mangling import demangle, get_mangle_prefix, is_mangled + + +__all__ = ["ObjNotFoundError", "ObjMismatchError", "Importer", "OrderedImporter"] + + +class ObjNotFoundError(Exception): + """Raised when an importer cannot find an object by searching for its name.""" + + +class ObjMismatchError(Exception): + """Raised when an importer found a different object with the same name as the user-provided one.""" + + +class Importer(ABC): + """Represents an environment to import modules from. + + By default, you can figure out what module an object belongs by checking + __module__ and importing the result using __import__ or importlib.import_module. + + torch.package introduces module importers other than the default one. + Each PackageImporter introduces a new namespace. Potentially a single + name (e.g. 'foo.bar') is present in multiple namespaces. + + It supports two main operations: + import_module: module_name -> module object + get_name: object -> (parent module name, name of obj within module) + + The guarantee is that following round-trip will succeed or throw an ObjNotFoundError/ObjMisMatchError. + module_name, obj_name = env.get_name(obj) + module = env.import_module(module_name) + obj2 = getattr(module, obj_name) + assert obj1 is obj2 + """ + + modules: Dict[str, ModuleType] + + @abstractmethod + def import_module(self, module_name: str) -> ModuleType: + """Import `module_name` from this environment. + + The contract is the same as for importlib.import_module. + """ + + def get_name(self, obj: Any, name: Optional[str] = None) -> Tuple[str, str]: + """Given an object, return a name that can be used to retrieve the + object from this environment. + + Args: + obj: An object to get the module-environment-relative name for. + name: If set, use this name instead of looking up __name__ or __qualname__ on `obj`. + This is only here to match how Pickler handles __reduce__ functions that return a string, + don't use otherwise. + Returns: + A tuple (parent_module_name, attr_name) that can be used to retrieve `obj` from this environment. + Use it like: + mod = importer.import_module(parent_module_name) + obj = getattr(mod, attr_name) + + Raises: + ObjNotFoundError: we couldn't retrieve `obj by name. + ObjMisMatchError: we found a different object with the same name as `obj`. 
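+
+        Example (a sketch assuming the default ``sys_importer`` and that
+        ``torch`` is importable in the current environment)::
+
+            >>> # xdoctest: +SKIP("requires torch in the environment")
+            >>> module_name, attr_name = sys_importer.get_name(torch.nn.Linear)
+            >>> mod = sys_importer.import_module(module_name)
+            >>> getattr(mod, attr_name) is torch.nn.Linear
+            True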
+ """ + if name is None and obj and _Pickler.dispatch.get(type(obj)) is None: + # Honor the string return variant of __reduce__, which will give us + # a global name to search for in this environment. + # TODO: I guess we should do copyreg too? + reduce = getattr(obj, "__reduce__", None) + if reduce is not None: + try: + rv = reduce() + if isinstance(rv, str): + name = rv + except Exception: + pass + if name is None: + name = getattr(obj, "__qualname__", None) + if name is None: + name = obj.__name__ + + orig_module_name = self.whichmodule(obj, name) + # Demangle the module name before importing. If this obj came out of a + # PackageImporter, `__module__` will be mangled. See mangling.md for + # details. + module_name = demangle(orig_module_name) + + # Check that this name will indeed return the correct object + try: + module = self.import_module(module_name) + obj2, _ = _getattribute(module, name) + except (ImportError, KeyError, AttributeError): + raise ObjNotFoundError( + f"{obj} was not found as {module_name}.{name}" + ) from None + + if obj is obj2: + return module_name, name + + def get_obj_info(obj): + assert name is not None + module_name = self.whichmodule(obj, name) + is_mangled_ = is_mangled(module_name) + location = ( + get_mangle_prefix(module_name) + if is_mangled_ + else "the current Python environment" + ) + importer_name = ( + f"the importer for {get_mangle_prefix(module_name)}" + if is_mangled_ + else "'sys_importer'" + ) + return module_name, location, importer_name + + obj_module_name, obj_location, obj_importer_name = get_obj_info(obj) + obj2_module_name, obj2_location, obj2_importer_name = get_obj_info(obj2) + msg = ( + f"\n\nThe object provided is from '{obj_module_name}', " + f"which is coming from {obj_location}." + f"\nHowever, when we import '{obj2_module_name}', it's coming from {obj2_location}." + "\nTo fix this, make sure this 'PackageExporter's importer lists " + f"{obj_importer_name} before {obj2_importer_name}." + ) + raise ObjMismatchError(msg) + + def whichmodule(self, obj: Any, name: str) -> str: + """Find the module name an object belongs to. + + This should be considered internal for end-users, but developers of + an importer can override it to customize the behavior. + + Taken from pickle.py, but modified to exclude the search into sys.modules + """ + module_name = getattr(obj, "__module__", None) + if module_name is not None: + return module_name + + # Protect the iteration by using a list copy of self.modules against dynamic + # modules that trigger imports of other modules upon calls to getattr. + for module_name, module in self.modules.copy().items(): + if ( + module_name == "__main__" + or module_name == "__mp_main__" # bpo-42406 + or module is None + ): + continue + try: + if _getattribute(module, name)[0] is obj: + return module_name + except AttributeError: + pass + + return "__main__" + + +class _SysImporter(Importer): + """An importer that implements the default behavior of Python.""" + + def import_module(self, module_name: str): + return importlib.import_module(module_name) + + def whichmodule(self, obj: Any, name: str) -> str: + return _pickle_whichmodule(obj, name) + + +sys_importer = _SysImporter() + + +class OrderedImporter(Importer): + """A compound importer that takes a list of importers and tries them one at a time. + + The first importer in the list that returns a result "wins". 
+ """ + + def __init__(self, *args): + self._importers: List[Importer] = list(args) + + def _is_torchpackage_dummy(self, module): + """Returns true iff this module is an empty PackageNode in a torch.package. + + If you intern `a.b` but never use `a` in your code, then `a` will be an + empty module with no source. This can break cases where we are trying to + re-package an object after adding a real dependency on `a`, since + OrderedImportere will resolve `a` to the dummy package and stop there. + + See: https://github.com/pytorch/pytorch/pull/71520#issuecomment-1029603769 + """ + if not getattr(module, "__torch_package__", False): + return False + if not hasattr(module, "__path__"): + return False + if not hasattr(module, "__file__"): + return True + return module.__file__ is None + + def import_module(self, module_name: str) -> ModuleType: + last_err = None + for importer in self._importers: + if not isinstance(importer, Importer): + raise TypeError( + f"{importer} is not a Importer. " + "All importers in OrderedImporter must inherit from Importer." + ) + try: + module = importer.import_module(module_name) + if self._is_torchpackage_dummy(module): + continue + return module + except ModuleNotFoundError as err: + last_err = err + + if last_err is not None: + raise last_err + else: + raise ModuleNotFoundError(module_name) + + def whichmodule(self, obj: Any, name: str) -> str: + for importer in self._importers: + module_name = importer.whichmodule(obj, name) + if module_name != "__main__": + return module_name + + return "__main__" diff --git a/lib/python3.10/site-packages/torch/package/package_exporter.py b/lib/python3.10/site-packages/torch/package/package_exporter.py new file mode 100644 index 0000000000000000000000000000000000000000..7b377b95454da1492ef09e469c7813e55376c2e0 --- /dev/null +++ b/lib/python3.10/site-packages/torch/package/package_exporter.py @@ -0,0 +1,1199 @@ +# mypy: allow-untyped-defs +import collections +import importlib.machinery +import io +import linecache +import pickletools +import platform +import types +from collections import defaultdict, OrderedDict +from dataclasses import dataclass +from enum import Enum +from importlib.machinery import SourceFileLoader +from pathlib import Path +from typing import ( + Any, + BinaryIO, + Callable, + cast, + DefaultDict, + Dict, + List, + Optional, + Sequence, + Set, + Union, +) + +import torch +from torch.serialization import location_tag, normalize_storage_type +from torch.types import Storage +from torch.utils.hooks import RemovableHandle + +from ._digraph import DiGraph +from ._importlib import _normalize_path +from ._mangling import demangle, is_mangled +from ._package_pickler import create_pickler +from ._stdlib import is_stdlib_module +from .find_file_dependencies import find_files_source_depends_on +from .glob_group import GlobGroup, GlobPattern +from .importer import Importer, OrderedImporter, sys_importer + + +__all__ = [ + "PackagingErrorReason", + "EmptyMatchError", + "PackagingError", + "PackageExporter", +] + +_gate_torchscript_serialization = True + +ActionHook = Callable[["PackageExporter", str], None] + + +class _ModuleProviderAction(Enum): + """Represents one of the actions that :class:`PackageExporter` can take on a module. + + See :meth:`PackageExporter.extern` and friends for a description of what the actions do. + """ + + INTERN = 1 + EXTERN = 2 + MOCK = 3 + DENY = 4 + # Special case: when a module is mocked, PackageExporter writes out a + # `_mock` module that implements our mocking stubs. 
If we re-package code, + # we may encounter a `_mock` module from the original package. If we do, + # just ignore it and write a `_mock` module once. + REPACKAGED_MOCK_MODULE = 5 + # Special case: PackageImporter adds a fake module + # (`torch_package_importer`) that allows packaged code to access it. Don't + # re-export this. + SKIP = 6 + + +class PackagingErrorReason(Enum): + """Listing of different reasons a dependency may fail to package. + + This enum is used to provide good error messages when + :class:`PackagingError` is raised. + """ + + def __repr__(self): + return f"<{self.__class__.__name__}.{self.name}>" + + IS_EXTENSION_MODULE = ( + "Module is a C extension module. torch.package supports Python modules only." + ) + NO_DUNDER_FILE = "Module had no __file__ defined." + SOURCE_FILE_NOT_FOUND = ( + "Module had a __file__, but we could not find it in your filesystem." + ) + DEPENDENCY_RESOLUTION_FAILED = "Dependency resolution failed." + NO_ACTION = ( + "Module did not match against any action pattern. Extern, mock, or intern it." + ) + DENIED = "Module was denied by a pattern." + MOCKED_BUT_STILL_USED = ( + "Module was mocked out, but is still being used in the package. " + "Please intern or extern the mocked modules if objects are supposed to be in " + "the package." + ) + + +@dataclass +class _PatternInfo: + """Holds :class:`PackageExporter`-specific info about how to execute matches against""" + + # What action to take on a module that matches this pattern. + action: _ModuleProviderAction + # The value of `allow_empty` the user gave when specifying the pattern. + allow_empty: bool + # Whether this pattern has been matched during packaging. + was_matched: bool + + def __init__(self, action, allow_empty): + self.action = action + self.allow_empty = allow_empty + self.was_matched = False + + +class EmptyMatchError(Exception): + """This is an exception that is thrown when a mock or extern is marked as + ``allow_empty=False``, and is not matched with any module during packaging. + """ + + +class PackagingError(Exception): + """This exception is raised when there is an issue with exporting a package. + ``PackageExporter`` will attempt to gather up all the errors and present + them to you at once. + """ + + def __init__(self, dependency_graph: DiGraph, debug=False): + # Group errors by reason. + broken: Dict[PackagingErrorReason, List[str]] = defaultdict(list) + for module_name, attrs in dependency_graph.nodes.items(): + error = attrs.get("error") + if error is None: + continue + if error == PackagingErrorReason.NO_ACTION: + assert "action" not in attrs + broken[error].append(module_name) + + message = io.StringIO() + message.write("\n") + + for reason, module_names in broken.items(): + message.write(f"* {reason.value}\n") + for module_name in module_names: + message.write(f" {module_name}\n") + + # Print additional context if it's provided. + error_context = dependency_graph.nodes[module_name].get("error_context") + if error_context is not None: + message.write(f" Context: {error_context}\n") + if module_name in _DISALLOWED_MODULES: + message.write( + " Note: While we usually use modules in the python standard library " + f"from the local environment, `{module_name}` has a lot of system " + "level access and therefore can pose a security risk. We heavily " + f"recommend removing `{module_name}` from your packaged code. 
However, if that " + "is not possible, add it to the extern list by calling " + f'PackageExporter.extern("`{module_name}`")\n' + ) + if debug: + module_path = dependency_graph.first_path(module_name) + message.write( + f" A path to {module_name}: {' -> '.join(module_path)}\n" + ) + if not debug: + message.write("\n") + message.write( + "Set debug=True when invoking PackageExporter for a visualization of where " + "broken modules are coming from!\n" + ) + # Save the dependency graph so that tooling can get at it. + self.dependency_graph = dependency_graph + super().__init__(message.getvalue()) + + +class PackageExporter: + """Exporters allow you to write packages of code, pickled Python data, and + arbitrary binary and text resources into a self-contained package. + + Imports can load this code in a hermetic way, such that code is loaded + from the package rather than the normal Python import system. This allows + for the packaging of PyTorch model code and data so that it can be run + on a server or used in the future for transfer learning. + + The code contained in packages is copied file-by-file from the original + source when it is created, and the file format is a specially organized + zip file. Future users of the package can unzip the package, and edit the code + in order to perform custom modifications to it. + + The importer for packages ensures that code in the module can only be loaded from + within the package, except for modules explicitly listed as external using :meth:`extern`. + The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on. + This prevents "implicit" dependencies where the package runs locally because it is importing + a locally-installed package, but then fails when the package is copied to another machine. + + When source code is added to the package, the exporter can optionally scan it + for further code dependencies (``dependencies=True``). It looks for import statements, + resolves relative references to qualified module names, and performs an action specified by the user + (See: :meth:`extern`, :meth:`mock`, and :meth:`intern`). + """ + + """A importer that will be searched in order to find the modules referenced by other modules or by + pickled objects. The default module environment just uses sys_importer, which searches the Python environment. + """ + importer: Importer + + def __init__( + self, + f: Union[str, Path, BinaryIO], + importer: Union[Importer, Sequence[Importer]] = sys_importer, + debug: bool = False, + ): + """ + Create an exporter. + + Args: + f: The location to export to. Can be a ``string``/``Path`` object containing a filename + or a binary I/O object. + importer: If a single Importer is passed, use that to search for modules. + If a sequence of importers are passed, an ``OrderedImporter`` will be constructed out of them. + debug: If set to True, add path of broken modules to PackagingErrors. + """ + torch._C._log_api_usage_once("torch.package.PackageExporter") + self.debug = debug + if isinstance(f, (Path, str)): + f = str(f) + self.buffer: Optional[BinaryIO] = None + else: # is a byte buffer + self.buffer = f + + self.zip_file = torch._C.PyTorchFileWriter(f) + self.zip_file.set_min_version(6) + self._written_files: Set[str] = set() + + self.serialized_reduces: Dict[int, Any] = {} + + # A graph tracking all the modules and pickle objects added to this + # package and the dependencies between them. 
+ # - Each node is a module name (or a pickle name that looks like '') + # - Each directed edge (u, v) means u depends on v. + # - Nodes may contain metadata that describe how to write the thing to the zipfile. + self.dependency_graph = DiGraph() + self.script_module_serializer = torch._C.ScriptModuleSerializer(self.zip_file) + self.storage_context = self.script_module_serializer.storage_context() + + # These are OrderedDicts for compatibility with RemovableHandle. + # Generic OrderedDict type annotations are not present until 3.7. + # The real type signature is OrderedDict[int, Callable[[PackageExporter, str], None]] + self._extern_hooks: OrderedDict = OrderedDict() + self._mock_hooks: OrderedDict = OrderedDict() + self._intern_hooks: OrderedDict = OrderedDict() + + if isinstance(importer, Importer): + self.importer = importer + else: + if not isinstance(importer, collections.abc.Sequence): + raise TypeError( + "importer arg should be an Importer or a sequence of Importers, " + f"got {type(importer)} instead." + ) + self.importer = OrderedImporter(*importer) + + self.patterns: Dict[GlobGroup, _PatternInfo] = {} + self._unique_id = 0 + + def save_source_file( + self, module_name: str, file_or_directory: str, dependencies=True + ): + """Adds the local file system ``file_or_directory`` to the source package to provide the code + for ``module_name``. + + Args: + module_name (str): e.g. ``"my_package.my_subpackage"``, code will be saved to provide code for this package. + file_or_directory (str): the path to a file or directory of code. When a directory, all python files in the directory + are recursively copied using :meth:`save_source_file`. If a file is named ``"/__init__.py"`` the code is treated + as a package. + dependencies (bool, optional): If ``True``, we scan the source for dependencies. + """ + path = Path(file_or_directory) + if path.is_dir(): + to_save = [] # list of tuples with arguments to save_source_string + module_path = module_name.replace(".", "/") + for filename in path.glob("**/*.py"): + relative_path = filename.relative_to(path).as_posix() + archivename = module_path + "/" + relative_path + submodule_name = None + if filename.name == "__init__.py": + submodule_name = archivename[: -len("/__init__.py")].replace( + "/", "." + ) + is_package = True + else: + submodule_name = archivename[: -len(".py")].replace("/", ".") + is_package = False + + # we delay the call to save_source_string so that we record all the source files + # being provided by this directory structure _before_ attempting to resolve the dependencies + # on the source. This makes sure we don't try to copy over modules that will just get + # overwritten by this directory blob + to_save.append( + ( + submodule_name, + _read_file(str(filename)), + is_package, + dependencies, + ) + ) + + for item in to_save: + self.save_source_string(*item) + else: + is_package = path.name == "__init__.py" + self.save_source_string( + module_name, + _read_file(file_or_directory), + is_package, + dependencies, + ) + + def get_unique_id(self) -> str: + """Get an id. This id is guaranteed to only be handed out once for this package.""" + ret = str(self._unique_id) + self._unique_id += 1 + return ret + + def _get_dependencies( + self, src: str, module_name: str, is_package: bool + ) -> List[str]: + """Return all modules that this source code depends on. + + Dependencies are found by scanning the source code for import-like statements. + + Arguments: + src: The Python source code to analyze for dependencies. 
+ module_name: The name of the module that ``src`` corresponds to. + is_package: Whether this module should be treated as a package. + See :py:meth:`save_source_string` for more info. + + Returns: + A list containing modules detected as direct dependencies in + ``src``. The items in the list are guaranteed to be unique. + """ + package_name = ( + module_name if is_package else module_name.rsplit(".", maxsplit=1)[0] + ) + try: + dep_pairs = find_files_source_depends_on(src, package_name) + except Exception as e: + self.dependency_graph.add_node( + module_name, + error=PackagingErrorReason.DEPENDENCY_RESOLUTION_FAILED, + error_context=str(e), + ) + return [] + + # Use a dict to get uniquing but also deterministic order + dependencies = {} + for dep_module_name, dep_module_obj in dep_pairs: + # handle the case where someone did something like `from pack import sub` + # where `sub` is a submodule. In this case we don't have to save pack, just sub. + # this ensures we don't pick up additional dependencies on pack. + # However, in the case where `sub` is not a submodule but an object, then we do have + # to save pack. + if dep_module_obj is not None: + possible_submodule = f"{dep_module_name}.{dep_module_obj}" + if self._module_exists(possible_submodule): + dependencies[possible_submodule] = True + # we don't need to save `pack` + continue + if self._module_exists(dep_module_name): + dependencies[dep_module_name] = True + + return list(dependencies.keys()) + + def save_source_string( + self, + module_name: str, + src: str, + is_package: bool = False, + dependencies: bool = True, + ): + """Adds ``src`` as the source code for ``module_name`` in the exported package. + + Args: + module_name (str): e.g. ``my_package.my_subpackage``, code will be saved to provide code for this package. + src (str): The Python source code to save for this package. + is_package (bool, optional): If ``True``, this module is treated as a package. Packages are allowed to have submodules + (e.g. ``my_package.my_subpackage.my_subsubpackage``), and resources can be saved inside them. Defaults to ``False``. + dependencies (bool, optional): If ``True``, we scan the source for dependencies. + """ + self.dependency_graph.add_node( + module_name, + source=src, + is_package=is_package, + provided=True, + action=_ModuleProviderAction.INTERN, + ) + + if dependencies: + deps = self._get_dependencies(src, module_name, is_package) + + for dep in deps: + self.dependency_graph.add_edge(module_name, dep) + self.add_dependency(dep) + + def _write_source_string( + self, + module_name: str, + src: str, + is_package: bool = False, + ): + """Write ``src`` as the source code for ``module_name`` in the zip archive. + + Arguments are otherwise the same as for :meth:`save_source_string`. + """ + extension = "/__init__.py" if is_package else ".py" + filename = module_name.replace(".", "/") + extension + + self._write(filename, src) + + def _import_module(self, module_name: str): + try: + return self.importer.import_module(module_name) + except ModuleNotFoundError as e: + if not is_mangled(module_name): + raise + msg = ( + f"Module not found: '{module_name}'. 
Make sure the PackageImporter that " + "created this module is present in `self.importer`" + ) + raise ModuleNotFoundError(msg) from None + + def _module_exists(self, module_name: str) -> bool: + try: + self._import_module(module_name) + return True + except Exception: + return False + + def _get_source_of_module(self, module: types.ModuleType) -> Optional[str]: + filename = None + spec = getattr(module, "__spec__", None) + if spec is not None: + loader = getattr(spec, "loader", None) + if loader is not None and isinstance(loader, SourceFileLoader): + try: + filename = loader.get_filename(module.__name__) + except ImportError: + pass + if filename is None: + filename = getattr(module, "__file__", None) + if isinstance(filename, str) and filename.endswith(".py"): + return "".join(linecache.getlines(filename, module.__dict__)) + return None + + def add_dependency(self, module_name: str, dependencies=True): + """Given a module, add it to the dependency graph according to patterns + specified by the user. + """ + if ( + module_name in self.dependency_graph + and self.dependency_graph.nodes[module_name].get("provided") is True + ): + return + + # Special case: PackageImporter provides a special module called + # `torch_package_importer` that allows packaged modules to reference + # their PackageImporter. We don't want to re-export this. + if module_name == "torch_package_importer": + self.dependency_graph.add_node( + module_name, + action=_ModuleProviderAction.SKIP, + provided=True, + ) + return + + if module_name == "_mock": + self.dependency_graph.add_node( + module_name, + action=_ModuleProviderAction.REPACKAGED_MOCK_MODULE, + provided=True, + ) + return + + if self._can_implicitly_extern(module_name): + self.dependency_graph.add_node( + module_name, action=_ModuleProviderAction.EXTERN, provided=True + ) + return + + for pattern, pattern_info in self.patterns.items(): + if pattern.matches(module_name): + pattern_info.was_matched = True + self.dependency_graph.add_node( + module_name, action=pattern_info.action, provided=True + ) + + if pattern_info.action == _ModuleProviderAction.DENY: + # Requiring a denied module just adds an error to the graph. + self.dependency_graph.add_node( + module_name, error=PackagingErrorReason.DENIED + ) + + # If we are interning this module, we need to retrieve its + # dependencies and package those as well. + if pattern_info.action == _ModuleProviderAction.INTERN: + self._intern_module(module_name, dependencies) + return + + # No patterns have matched. Explicitly add this as an error. + self.dependency_graph.add_node( + module_name, error=PackagingErrorReason.NO_ACTION + ) + + def save_module(self, module_name: str, dependencies=True): + """Save the code for ``module`` into the package. Code for the module is resolved using the ``importers`` path to find the + module object, and then using its ``__file__`` attribute to find the source code. + + Args: + module_name (str): e.g. ``my_package.my_subpackage``, code will be saved to provide code + for this package. + dependencies (bool, optional): If ``True``, we scan the source for dependencies. + """ + if not isinstance(module_name, str): + raise TypeError( + "save_module() expects a string input, did you perhaps mean to pass `__name__`?" 
+ ) + + self._intern_module(module_name, dependencies) + + def _intern_module( + self, + module_name: str, + dependencies: bool, + ): + """Adds the module to the dependency graph as an interned module, + along with any metadata needed to write it out to the zipfile at serialization time. + """ + module_obj = self._import_module(module_name) + # Subtle: if the import above succeeded, either: + # 1. The module name is not mangled, and this was just a regular import, or + # 2. The module name is mangled, but one of the importers was able to + # recognize the mangling and import it. + # Either way, it is now safe to demangle this name so that we don't + # serialize the mangled version to the package. + module_name = demangle(module_name) + + # Find dependencies of this module and require them as well. + is_package = hasattr(module_obj, "__path__") + source = self._get_source_of_module(module_obj) + if source is None: + # Couldn't find a source! Add it to our dependency graph as broken + # and continue. + filename = getattr(module_obj, "__file__", None) + error_context = None + if filename is None: + packaging_error = PackagingErrorReason.NO_DUNDER_FILE + elif filename.endswith(tuple(importlib.machinery.EXTENSION_SUFFIXES)): + packaging_error = PackagingErrorReason.IS_EXTENSION_MODULE + else: + packaging_error = PackagingErrorReason.SOURCE_FILE_NOT_FOUND + error_context = f"filename: {filename}" + self.dependency_graph.add_node( + module_name, + action=_ModuleProviderAction.INTERN, + is_package=is_package, + error=packaging_error, + error_context=error_context, + provided=True, + ) + return + + self.dependency_graph.add_node( + module_name, + action=_ModuleProviderAction.INTERN, + is_package=is_package, + source=source, + provided=True, + ) + + if dependencies: + deps = self._get_dependencies(source, module_name, is_package) + for dep in deps: + self.dependency_graph.add_edge(module_name, dep) + self.add_dependency(dep) + + def save_pickle( + self, + package: str, + resource: str, + obj: Any, + dependencies: bool = True, + pickle_protocol: int = 3, + ): + """Save a python object to the archive using pickle. Equivalent to :func:`torch.save` but saving into + the archive rather than a stand-alone file. Standard pickle does not save the code, only the objects. + If ``dependencies`` is true, this method will also scan the pickled objects for which modules are required + to reconstruct them and save the relevant code. + + To be able to save an object where ``type(obj).__name__`` is ``my_module.MyObject``, + ``my_module.MyObject`` must resolve to the class of the object according to the ``importer`` order. When saving objects that + have previously been packaged, the importer's ``import_module`` method will need to be present in the ``importer`` list + for this to work. + + Args: + package (str): The name of module package this resource should go in (e.g. ``"my_package.my_subpackage"``). + resource (str): A unique name for the resource, used to identify it to load. + obj (Any): The object to save, must be picklable. + dependencies (bool, optional): If ``True``, we scan the source for dependencies. 
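+
+        A hedged usage sketch (``MyModel``, the intern pattern, and the archive
+        name are illustrative only)::
+
+            with PackageExporter("model_package.pt") as exporter:
+                exporter.intern("my_module.**")
+                exporter.save_pickle("model", "model.pkl", MyModel())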
+ """ + + assert (pickle_protocol == 4) or ( + pickle_protocol == 3 + ), "torch.package only supports pickle protocols 3 and 4" + + filename = self._filename(package, resource) + # Write the pickle data for `obj` + data_buf = io.BytesIO() + pickler = create_pickler(data_buf, self.importer, protocol=pickle_protocol) + pickler.persistent_id = self._persistent_id + pickler.dump(obj) + data_value = data_buf.getvalue() + mocked_modules = defaultdict(list) + name_in_dependency_graph = f"<{package}.{resource}>" + self.dependency_graph.add_node( + name_in_dependency_graph, + action=_ModuleProviderAction.INTERN, + provided=True, + is_pickle=True, + ) + + def _check_mocked_error(module: Optional[str], field: Optional[str]): + """ + checks if an object (field) comes from a mocked module and then adds + the pair to mocked_modules which contains mocked modules paired with their + list of mocked objects present in the pickle. + + We also hold the invariant that the first user defined rule that applies + to the module is the one we use. + """ + + assert isinstance(module, str) + assert isinstance(field, str) + if self._can_implicitly_extern(module): + return + for pattern, pattern_info in self.patterns.items(): + if pattern.matches(module): + if pattern_info.action == _ModuleProviderAction.MOCK: + mocked_modules[module].append(field) + return + + if dependencies: + all_dependencies = [] + module = None + field = None + memo: DefaultDict[int, str] = defaultdict(None) + memo_count = 0 + # pickletools.dis(data_value) + for opcode, arg, pos in pickletools.genops(data_value): + if pickle_protocol == 4: + if ( + opcode.name == "SHORT_BINUNICODE" + or opcode.name == "BINUNICODE" + or opcode.name == "BINUNICODE8" + ): + assert isinstance(arg, str) + module = field + field = arg + memo[memo_count] = arg + elif ( + opcode.name == "LONG_BINGET" + or opcode.name == "BINGET" + or opcode.name == "GET" + ): + assert isinstance(arg, int) + module = field + field = memo.get(arg, None) + elif opcode.name == "MEMOIZE": + memo_count += 1 + elif opcode.name == "STACK_GLOBAL": + if module is None: + # If not module was passed on in the entries preceeding this one, continue. + continue + assert isinstance(module, str) + if module not in all_dependencies: + all_dependencies.append(module) + _check_mocked_error(module, field) + elif ( + pickle_protocol == 3 and opcode.name == "GLOBAL" + ): # a global reference + assert isinstance(arg, str) + module, field = arg.split(" ") + if module not in all_dependencies: + all_dependencies.append(module) + _check_mocked_error(module, field) + for module_name in all_dependencies: + self.dependency_graph.add_edge(name_in_dependency_graph, module_name) + + """ If an object happens to come from a mocked module, then we collect these errors and spit them + out with the other errors found by package exporter. + """ + if module_name in mocked_modules: + assert isinstance(module_name, str) + fields = mocked_modules[module_name] + self.dependency_graph.add_node( + module_name, + action=_ModuleProviderAction.MOCK, + error=PackagingErrorReason.MOCKED_BUT_STILL_USED, + error_context=f"Object(s) '{fields}' from module `{module_name}` was mocked out during packaging " + f"but is being used in resource - `{resource}` in package `{package}`. ", + provided=True, + ) + else: + self.add_dependency(module_name) + + self._write(filename, data_value) + + def save_text(self, package: str, resource: str, text: str): + """Save text data to the package. 
+ + Args: + package (str): The name of module package this resource should go it (e.g. ``"my_package.my_subpackage"``). + resource (str): A unique name for the resource, used to identify it to load. + text (str): The contents to save. + """ + return self.save_binary(package, resource, text.encode("utf-8")) + + def save_binary(self, package, resource, binary: bytes): + """Save raw bytes to the package. + + Args: + package (str): The name of module package this resource should go it (e.g. ``"my_package.my_subpackage"``). + resource (str): A unique name for the resource, used to identify it to load. + binary (str): The data to save. + """ + filename = self._filename(package, resource) + self._write(filename, binary) + + def register_extern_hook(self, hook: ActionHook) -> RemovableHandle: + """Registers an extern hook on the exporter. + + The hook will be called each time a module matches against an :meth:`extern` pattern. + It should have the following signature:: + + hook(exporter: PackageExporter, module_name: str) -> None + + Hooks will be called in order of registration. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + A handle that can be used to remove the added hook by calling + ``handle.remove()``. + """ + handle = RemovableHandle(self._extern_hooks) + self._extern_hooks[handle.id] = hook + return handle + + def register_mock_hook(self, hook: ActionHook) -> RemovableHandle: + """Registers a mock hook on the exporter. + + The hook will be called each time a module matches against a :meth:`mock` pattern. + It should have the following signature:: + + hook(exporter: PackageExporter, module_name: str) -> None + + Hooks will be called in order of registration. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + A handle that can be used to remove the added hook by calling + ``handle.remove()``. + """ + handle = RemovableHandle(self._mock_hooks) + self._mock_hooks[handle.id] = hook + return handle + + def register_intern_hook(self, hook: ActionHook) -> RemovableHandle: + """Registers an intern hook on the exporter. + + The hook will be called each time a module matches against an :meth:`intern` pattern. + It should have the following signature:: + + hook(exporter: PackageExporter, module_name: str) -> None + + Hooks will be called in order of registration. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + A handle that can be used to remove the added hook by calling + ``handle.remove()``. + """ + handle = RemovableHandle(self._intern_hooks) + self._intern_hooks[handle.id] = hook + return handle + + def intern( + self, + include: "GlobPattern", + *, + exclude: "GlobPattern" = (), + allow_empty: bool = True, + ): + """Specify modules that should be packaged. A module must match some ``intern`` pattern in order to be + included in the package and have its dependencies processed recursively. + + Args: + include (Union[List[str], str]): A string e.g. "my_package.my_subpackage", or list of strings + for the names of the modules to be externed. This can also be a glob-style pattern, as described in :meth:`mock`. + + exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the include string. + + allow_empty (bool): An optional flag that specifies whether the intern modules specified by this call + to the ``intern`` method must be matched to some module during packaging. 
If an ``intern`` module glob + pattern is added with ``allow_empty=False``, and :meth:`close` is called (either explicitly or via ``__exit__``) + before any modules match that pattern, an exception is thrown. If ``allow_empty=True``, no such exception is thrown. + + """ + self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo( + _ModuleProviderAction.INTERN, allow_empty + ) + + def mock( + self, + include: "GlobPattern", + *, + exclude: "GlobPattern" = (), + allow_empty: bool = True, + ): + """Replace some required modules with a mock implementation. Mocked modules will return a fake + object for any attribute accessed from it. Because we copy file-by-file, the dependency resolution will sometimes + find files that are imported by model files but whose functionality is never used + (e.g. custom serialization code or training helpers). + Use this function to mock this functionality out without having to modify the original code. + + Args: + include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list of strings + for the names of the modules to be mocked out. Strings can also be a glob-style pattern + string that may match multiple modules. Any required dependencies that match this pattern + string will be mocked out automatically. + + Examples : + ``'torch.**'`` -- matches ``torch`` and all submodules of torch, e.g. ``'torch.nn'`` + and ``'torch.nn.functional'`` + + ``'torch.*'`` -- matches ``'torch.nn'`` or ``'torch.functional'``, but not + ``'torch.nn.functional'`` + + exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the include string. + e.g. ``include='torch.**', exclude='torch.foo'`` will mock all torch packages except ``'torch.foo'``, + Default: is ``[]``. + + allow_empty (bool): An optional flag that specifies whether the mock implementation(s) specified by this call + to the :meth:`mock` method must be matched to some module during packaging. If a mock is added with + ``allow_empty=False``, and :meth:`close` is called (either explicitly or via ``__exit__``) and the mock has + not been matched to a module used by the package being exported, an exception is thrown. + If ``allow_empty=True``, no such exception is thrown. + + """ + self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo( + _ModuleProviderAction.MOCK, allow_empty + ) + + def extern( + self, + include: "GlobPattern", + *, + exclude: "GlobPattern" = (), + allow_empty: bool = True, + ): + """Include ``module`` in the list of external modules the package can import. + This will prevent dependency discovery from saving + it in the package. The importer will load an external module directly from the standard import system. + Code for extern modules must also exist in the process loading the package. + + Args: + include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list of strings + for the names of the modules to be externed. This can also be a glob-style pattern, as + described in :meth:`mock`. + + exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the + include string. + + allow_empty (bool): An optional flag that specifies whether the extern modules specified by this call + to the ``extern`` method must be matched to some module during packaging. If an extern module glob + pattern is added with ``allow_empty=False``, and :meth:`close` is called (either explicitly or via + ``__exit__``) before any modules match that pattern, an exception is thrown. 
If ``allow_empty=True``, + no such exception is thrown. + + """ + self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo( + _ModuleProviderAction.EXTERN, allow_empty + ) + + def deny(self, include: "GlobPattern", *, exclude: "GlobPattern" = ()): + """Blocklist modules who names match the given glob patterns from the list of modules the package can import. + If a dependency on any matching packages is found, a :class:`PackagingError` is raised. + + Args: + include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list of strings + for the names of the modules to be externed. This can also be a glob-style pattern, as described in :meth:`mock`. + + exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the include string. + """ + self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo( + _ModuleProviderAction.DENY, allow_empty=True + ) + + def _persistent_id(self, obj): + if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage): + storage: Storage + if isinstance(obj, torch.storage.TypedStorage): + # TODO: Once we decide to break serialization FC, we can + # remove this case + untyped_storage = obj._untyped_storage + storage_type_str = obj.pickle_storage_type() + storage_type = getattr(torch, storage_type_str) + storage = cast(Storage, untyped_storage) + storage_numel = obj.size() + + elif isinstance(obj, torch.UntypedStorage): + untyped_storage = obj + storage = cast(Storage, untyped_storage) + storage_type = normalize_storage_type(type(storage)) + storage_numel = storage.nbytes() + else: + raise RuntimeError(f"storage type not recognized: {type(obj)}") + + location = location_tag(storage) + + # serialize storage if not already written + storage_present = self.storage_context.has_storage(storage) + storage_id = self.storage_context.get_or_add_storage(storage) + if not storage_present: + if storage.device.type != "cpu": + storage = storage.cpu() + num_bytes = storage.nbytes() + self.zip_file.write_record( + f".data/{storage_id}.storage", storage, num_bytes + ) + return ("storage", storage_type, storage_id, location, storage_numel) + + if hasattr(obj, "__reduce_package__"): + if _gate_torchscript_serialization and isinstance( + obj, torch.jit.RecursiveScriptModule + ): + raise Exception( # noqa: TRY002 + "Serializing ScriptModules directly into a package is a beta feature. " + "To use, set global " + "`torch.package.package_exporter._gate_torchscript_serialization` to `False`." + ) + if self.serialized_reduces.get(id(obj)) is None: + self.serialized_reduces[id(obj)] = ( + "reduce_package", + id(obj), + *obj.__reduce_package__(self), + ) + + return self.serialized_reduces[id(obj)] + + return None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + # If __exit__ was called because an exception was raised, we do not + # attempt to finalize the package. Instead, control is returned to the + # caller to continue raising the exception. + if exc_type is not None: + # Do the bare minimum to leave the open buffer in a valid state. + self._finalize_zip() + return + + self.close() + + def _write(self, filename, str_or_bytes): + if filename in self._written_files: + raise AssertionError( + f"Tried to write file '{filename}', but it already exists in this archive. " + "Please file a bug." + ) + self._written_files.add(filename) + + if is_mangled(filename): + raise AssertionError( + f"Tried to save a torch.package'd module as '{filename}'. 
" + "Directly saving torch.package'd modules is not allowed." + ) + if isinstance(str_or_bytes, str): + str_or_bytes = str_or_bytes.encode("utf-8") + self.zip_file.write_record(filename, str_or_bytes, len(str_or_bytes)) + + def _validate_dependency_graph(self): + # 1. Check the graph for any errors inserted during dependency analysis. + for attrs in self.dependency_graph.nodes.values(): + if "error" in attrs: + raise PackagingError(self.dependency_graph, debug=self.debug) + + # 2. Check that all patterns for which allow_empty=False have been matched at least once. + for pattern, pattern_info in self.patterns.items(): + if not pattern_info.allow_empty and not pattern_info.was_matched: + raise EmptyMatchError( + f"Exporter did not match any modules to {pattern}, which was marked as allow_empty=False" + ) + + def _write_mock_file(self): + if "_mock.py" not in self._written_files: + mock_file = str(Path(__file__).parent / "_mock.py") + self._write_source_string("_mock", _read_file(mock_file), is_package=False) + + def _execute_dependency_graph(self): + """Takes a finalized dependency graph describing how to package all + modules and executes it, writing to the ZIP archive. + """ + self._validate_dependency_graph() + + extern_modules = [] + for module_name, attrs in self.dependency_graph.nodes.items(): + action = attrs["action"] + + if action == _ModuleProviderAction.EXTERN: + for hook in self._extern_hooks.values(): + hook(self, module_name) + + extern_modules.append(module_name) + + elif action == _ModuleProviderAction.MOCK: + for hook in self._mock_hooks.values(): + hook(self, module_name) + + self._write_mock_file() + + is_package = hasattr(self._import_module(module_name), "__path__") + self._write_source_string(module_name, _MOCK_IMPL, is_package) + + elif action == _ModuleProviderAction.INTERN: + for hook in self._intern_hooks.values(): + hook(self, module_name) + + # The node in the dependency graph contains metadata that tells us + # how to intern the module. + if "provided" not in attrs: + raise AssertionError( + f"Module was marked `intern` but not provided: {module_name}" + ) + + if attrs.get("is_pickle") is True: + # This node came from save_pickle, we don't need to write any source for it. + continue + + is_package = attrs["is_package"] + source = attrs["source"] + self._write_source_string(module_name, source, is_package) + + elif action == _ModuleProviderAction.REPACKAGED_MOCK_MODULE: + self._write_mock_file() + elif action == _ModuleProviderAction.SKIP: + continue + else: + raise AssertionError( + f"Invalid action: {module_name}, {action}. Please report a bug to PyTorch." + ) + + extern_file_contents = "\n".join(extern_modules) + "\n" + self._write(".data/extern_modules", extern_file_contents) + + def _write_python_version(self): + """Writes the python version that the package was created with to .data/python_version""" + self._write(".data/python_version", platform.python_version()) + + def close(self): + """Write the package to the filesystem. Any calls after :meth:`close` are now invalid. + It is preferable to use resource guard syntax instead:: + + with PackageExporter("file.zip") as e: + ... 
+ """ + self._execute_dependency_graph() + self._write_python_version() + + self.script_module_serializer.write_files() + self._finalize_zip() + + def _finalize_zip(self): + """Called at the very end of packaging to leave the zipfile in a closed but valid state.""" + del self.zip_file + if self.buffer: + self.buffer.flush() + + def _filename(self, package, resource): + package_path = package.replace(".", "/") + resource = _normalize_path(resource) + return f"{package_path}/{resource}" + + def _can_implicitly_extern(self, module_name: str): + top_level_package_name = module_name.partition(".")[0] + return top_level_package_name == "torch" or ( + top_level_package_name not in _DISALLOWED_MODULES + and is_stdlib_module(top_level_package_name) + ) + + def dependency_graph_string(self) -> str: + """Returns digraph string representation of dependencies in package. + + Returns: + A string representation of dependencies in package. + """ + return self.dependency_graph.to_dot() + + def _nodes_with_action_type( + self, action: Optional[_ModuleProviderAction] + ) -> List[str]: + result = [] + for name, node_dict in self.dependency_graph.nodes.items(): + node_action = node_dict.get("action", None) + if node_action == action and "is_pickle" not in node_dict: + result.append(name) + result.sort() + return result + + def externed_modules(self) -> List[str]: + """Return all modules that are currently externed. + + Returns: + A list containing the names of modules which will be + externed in this package. + """ + return self._nodes_with_action_type(_ModuleProviderAction.EXTERN) + + def interned_modules(self) -> List[str]: + """Return all modules that are currently interned. + + Returns: + A list containing the names of modules which will be + interned in this package. + """ + return self._nodes_with_action_type(_ModuleProviderAction.INTERN) + + def mocked_modules(self) -> List[str]: + """Return all modules that are currently mocked. + + Returns: + A list containing the names of modules which will be + mocked in this package. + """ + return self._nodes_with_action_type(_ModuleProviderAction.MOCK) + + def denied_modules(self) -> List[str]: + """Return all modules that are currently denied. + + Returns: + A list containing the names of modules which will be + denied in this package. + """ + return self._nodes_with_action_type(_ModuleProviderAction.DENY) + + def get_rdeps(self, module_name: str) -> List[str]: + """Return a list of all modules which depend on the module ``module_name``. + + Returns: + A list containing the names of modules which depend on ``module_name``. + """ + if module_name in self.dependency_graph._pred.keys(): + return list(self.dependency_graph._pred[module_name].keys()) + else: + return [] + + def all_paths(self, src: str, dst: str) -> str: + """Return a dot representation of the subgraph + that has all paths from src to dst. + + Returns: + A dot representation containing all paths from src to dst. + (https://graphviz.org/doc/info/lang.html) + """ + return self.dependency_graph.all_paths(src, dst) + + +# even though these are in the standard library, we do not allow them to be +# automatically externed since they offer a lot of system level access +_DISALLOWED_MODULES = ["sys", "io"] + +_MOCK_IMPL = """\ +from _mock import MockedObject +def __getattr__(attr: str): + return MockedObject(__name__ + '.' 
+ attr, _suppress_err=True) +""" + + +def _read_file(filename: str) -> str: + with open(filename, "rb") as f: + b = f.read() + return b.decode("utf-8") diff --git a/lib/python3.10/site-packages/torch/package/package_importer.py b/lib/python3.10/site-packages/torch/package/package_importer.py new file mode 100644 index 0000000000000000000000000000000000000000..cf557d72bd4f7df0182e9db37b09089c724b291a --- /dev/null +++ b/lib/python3.10/site-packages/torch/package/package_importer.py @@ -0,0 +1,791 @@ +# mypy: allow-untyped-defs +import builtins +import importlib +import importlib.machinery +import inspect +import io +import linecache +import os +import types +from contextlib import contextmanager +from typing import ( + Any, + BinaryIO, + Callable, + cast, + Dict, + Iterable, + List, + Optional, + TYPE_CHECKING, + Union, +) +from weakref import WeakValueDictionary + +import torch +from torch.serialization import _get_restore_location, _maybe_decode_ascii + +from ._directory_reader import DirectoryReader +from ._importlib import ( + _calc___package__, + _normalize_line_endings, + _normalize_path, + _resolve_name, + _sanity_check, +) +from ._mangling import demangle, PackageMangler +from ._package_unpickler import PackageUnpickler +from .file_structure_representation import _create_directory_from_file_list, Directory +from .importer import Importer + + +if TYPE_CHECKING: + from .glob_group import GlobPattern + +__all__ = ["PackageImporter"] + + +# This is a list of imports that are implicitly allowed even if they haven't +# been marked as extern. This is to work around the fact that Torch implicitly +# depends on numpy and package can't track it. +# https://github.com/pytorch/MultiPy/issues/46 +IMPLICIT_IMPORT_ALLOWLIST: Iterable[str] = [ + "numpy", + "numpy.core", + "numpy.core._multiarray_umath", + # FX GraphModule might depend on builtins module and users usually + # don't extern builtins. Here we import it here by default. + "builtins", +] + + +# Compatibility name mapping to facilitate upgrade of external modules. +# The primary motivation is to enable Numpy upgrade that many modules +# depend on. The latest release of Numpy removed `numpy.str` and +# `numpy.bool` breaking unpickling for many modules. +EXTERN_IMPORT_COMPAT_NAME_MAPPING: Dict[str, Dict[str, Any]] = { + "numpy": { + "str": str, + "bool": bool, + }, +} + + +class PackageImporter(Importer): + """Importers allow you to load code written to packages by :class:`PackageExporter`. + Code is loaded in a hermetic way, using files from the package + rather than the normal python import system. This allows + for the packaging of PyTorch model code and data so that it can be run + on a server or used in the future for transfer learning. + + The importer for packages ensures that code in the module can only be loaded from + within the package, except for modules explicitly listed as external during export. + The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on. + This prevents "implicit" dependencies where the package runs locally because it is importing + a locally-installed package, but then fails when the package is copied to another machine. + """ + + """The dictionary of already loaded modules from this package, equivalent to ``sys.modules`` but + local to this importer. 
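+
+    For example (a sketch; the archive and module names are illustrative), after::
+
+        importer = PackageImporter("model_package.pt")
+        mod = importer.import_module("my_module")
+
+    the loaded module is available as ``importer.modules["my_module"]`` and is not
+    inserted into ``sys.modules``.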
+ """ + + modules: Dict[str, types.ModuleType] + + def __init__( + self, + file_or_buffer: Union[str, torch._C.PyTorchFileReader, os.PathLike, BinaryIO], + module_allowed: Callable[[str], bool] = lambda module_name: True, + ): + """Open ``file_or_buffer`` for importing. This checks that the imported package only requires modules + allowed by ``module_allowed`` + + Args: + file_or_buffer: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`), + a string, or an ``os.PathLike`` object containing a filename. + module_allowed (Callable[[str], bool], optional): A method to determine if a externally provided module + should be allowed. Can be used to ensure packages loaded do not depend on modules that the server + does not support. Defaults to allowing anything. + + Raises: + ImportError: If the package will use a disallowed module. + """ + torch._C._log_api_usage_once("torch.package.PackageImporter") + + self.zip_reader: Any + if isinstance(file_or_buffer, torch._C.PyTorchFileReader): + self.filename = "" + self.zip_reader = file_or_buffer + elif isinstance(file_or_buffer, (os.PathLike, str)): + self.filename = os.fspath(file_or_buffer) + if not os.path.isdir(self.filename): + self.zip_reader = torch._C.PyTorchFileReader(self.filename) + else: + self.zip_reader = DirectoryReader(self.filename) + else: + self.filename = "" + self.zip_reader = torch._C.PyTorchFileReader(file_or_buffer) + + torch._C._log_api_usage_metadata( + "torch.package.PackageImporter.metadata", + { + "serialization_id": self.zip_reader.serialization_id(), + "file_name": self.filename, + }, + ) + + self.root = _PackageNode(None) + self.modules = {} + self.extern_modules = self._read_extern() + + for extern_module in self.extern_modules: + if not module_allowed(extern_module): + raise ImportError( + f"package '{file_or_buffer}' needs the external module '{extern_module}' " + f"but that module has been disallowed" + ) + self._add_extern(extern_module) + + for fname in self.zip_reader.get_all_records(): + self._add_file(fname) + + self.patched_builtins = builtins.__dict__.copy() + self.patched_builtins["__import__"] = self.__import__ + # Allow packaged modules to reference their PackageImporter + self.modules["torch_package_importer"] = self # type: ignore[assignment] + + self._mangler = PackageMangler() + + # used for reduce deserializaiton + self.storage_context: Any = None + self.last_map_location = None + + # used for torch.serialization._load + self.Unpickler = lambda *args, **kwargs: PackageUnpickler(self, *args, **kwargs) + + def import_module(self, name: str, package=None): + """Load a module from the package if it hasn't already been loaded, and then return + the module. Modules are loaded locally + to the importer and will appear in ``self.modules`` rather than ``sys.modules``. + + Args: + name (str): Fully qualified name of the module to load. + package ([type], optional): Unused, but present to match the signature of importlib.import_module. Defaults to ``None``. + + Returns: + types.ModuleType: The (possibly already) loaded module. + """ + # We should always be able to support importing modules from this package. + # This is to support something like: + # obj = importer.load_pickle(...) + # importer.import_module(obj.__module__) <- this string will be mangled + # + # Note that _mangler.demangle will not demangle any module names + # produced by a different PackageImporter instance. 
+ name = self._mangler.demangle(name) + + return self._gcd_import(name) + + def load_binary(self, package: str, resource: str) -> bytes: + """Load raw bytes. + + Args: + package (str): The name of module package (e.g. ``"my_package.my_subpackage"``). + resource (str): The unique name for the resource. + + Returns: + bytes: The loaded data. + """ + + path = self._zipfile_path(package, resource) + return self.zip_reader.get_record(path) + + def load_text( + self, + package: str, + resource: str, + encoding: str = "utf-8", + errors: str = "strict", + ) -> str: + """Load a string. + + Args: + package (str): The name of module package (e.g. ``"my_package.my_subpackage"``). + resource (str): The unique name for the resource. + encoding (str, optional): Passed to ``decode``. Defaults to ``'utf-8'``. + errors (str, optional): Passed to ``decode``. Defaults to ``'strict'``. + + Returns: + str: The loaded text. + """ + data = self.load_binary(package, resource) + return data.decode(encoding, errors) + + def load_pickle(self, package: str, resource: str, map_location=None) -> Any: + """Unpickles the resource from the package, loading any modules that are needed to construct the objects + using :meth:`import_module`. + + Args: + package (str): The name of module package (e.g. ``"my_package.my_subpackage"``). + resource (str): The unique name for the resource. + map_location: Passed to `torch.load` to determine how tensors are mapped to devices. Defaults to ``None``. + + Returns: + Any: The unpickled object. + """ + pickle_file = self._zipfile_path(package, resource) + restore_location = _get_restore_location(map_location) + loaded_storages = {} + loaded_reduces = {} + storage_context = torch._C.DeserializationStorageContext() + + def load_tensor(dtype, size, key, location, restore_location): + name = f"{key}.storage" + + if storage_context.has_storage(name): + storage = storage_context.get_storage(name, dtype)._typed_storage() + else: + tensor = self.zip_reader.get_storage_from_record( + ".data/" + name, size, dtype + ) + if isinstance(self.zip_reader, torch._C.PyTorchFileReader): + storage_context.add_storage(name, tensor) + storage = tensor._typed_storage() + loaded_storages[key] = restore_location(storage, location) + + def persistent_load(saved_id): + assert isinstance(saved_id, tuple) + typename = _maybe_decode_ascii(saved_id[0]) + data = saved_id[1:] + + if typename == "storage": + storage_type, key, location, size = data + dtype = storage_type.dtype + + if key not in loaded_storages: + load_tensor( + dtype, + size, + key, + _maybe_decode_ascii(location), + restore_location, + ) + storage = loaded_storages[key] + # TODO: Once we decide to break serialization FC, we can + # stop wrapping with TypedStorage + return torch.storage.TypedStorage( + wrap_storage=storage._untyped_storage, dtype=dtype, _internal=True + ) + elif typename == "reduce_package": + # to fix BC breaking change, objects on this load path + # will be loaded multiple times erroneously + if len(data) == 2: + func, args = data + return func(self, *args) + reduce_id, func, args = data + if reduce_id not in loaded_reduces: + loaded_reduces[reduce_id] = func(self, *args) + return loaded_reduces[reduce_id] + else: + f"Unknown typename for persistent_load, expected 'storage' or 'reduce_package' but got '{typename}'" + + # Load the data (which may in turn use `persistent_load` to load tensors) + data_file = io.BytesIO(self.zip_reader.get_record(pickle_file)) + unpickler = self.Unpickler(data_file) + unpickler.persistent_load = 
persistent_load # type: ignore[assignment] + + @contextmanager + def set_deserialization_context(): + # to let reduce_package access deserializaiton context + self.storage_context = storage_context + self.last_map_location = map_location + try: + yield + finally: + self.storage_context = None + self.last_map_location = None + + with set_deserialization_context(): + result = unpickler.load() + + # TODO from zdevito: + # This stateful weird function will need to be removed in our efforts + # to unify the format. It has a race condition if multiple python + # threads try to read independent files + torch._utils._validate_loaded_sparse_tensors() + + return result + + def id(self): + """ + Returns internal identifier that torch.package uses to distinguish :class:`PackageImporter` instances. + Looks like:: + + + """ + return self._mangler.parent_name() + + def file_structure( + self, *, include: "GlobPattern" = "**", exclude: "GlobPattern" = () + ) -> Directory: + """Returns a file structure representation of package's zipfile. + + Args: + include (Union[List[str], str]): An optional string e.g. ``"my_package.my_subpackage"``, or optional list of strings + for the names of the files to be included in the zipfile representation. This can also be + a glob-style pattern, as described in :meth:`PackageExporter.mock` + + exclude (Union[List[str], str]): An optional pattern that excludes files whose name match the pattern. + + Returns: + :class:`Directory` + """ + return _create_directory_from_file_list( + self.filename, self.zip_reader.get_all_records(), include, exclude + ) + + def python_version(self): + """Returns the version of python that was used to create this package. + + Note: this function is experimental and not Forward Compatible. The plan is to move this into a lock + file later on. + + Returns: + :class:`Optional[str]` a python version e.g. 3.8.9 or None if no version was stored with this package + """ + python_version_path = ".data/python_version" + return ( + self.zip_reader.get_record(python_version_path).decode("utf-8").strip() + if self.zip_reader.has_record(python_version_path) + else None + ) + + def _read_extern(self): + return ( + self.zip_reader.get_record(".data/extern_modules") + .decode("utf-8") + .splitlines(keepends=False) + ) + + def _make_module( + self, name: str, filename: Optional[str], is_package: bool, parent: str + ): + mangled_filename = self._mangler.mangle(filename) if filename else None + spec = importlib.machinery.ModuleSpec( + name, + self, # type: ignore[arg-type] + origin="", + is_package=is_package, + ) + module = importlib.util.module_from_spec(spec) + self.modules[name] = module + module.__name__ = self._mangler.mangle(name) + ns = module.__dict__ + ns["__spec__"] = spec + ns["__loader__"] = self + ns["__file__"] = mangled_filename + ns["__cached__"] = None + ns["__builtins__"] = self.patched_builtins + ns["__torch_package__"] = True + + # Add this module to our private global registry. It should be unique due to mangling. + assert module.__name__ not in _package_imported_modules + _package_imported_modules[module.__name__] = module + + # pre-emptively install on the parent to prevent IMPORT_FROM from trying to + # access sys.modules + self._install_on_parent(parent, name, module) + + if filename is not None: + assert mangled_filename is not None + # pre-emptively install the source in `linecache` so that stack traces, + # `inspect`, etc. work. 
+ assert filename not in linecache.cache # type: ignore[attr-defined] + linecache.lazycache(mangled_filename, ns) + + code = self._compile_source(filename, mangled_filename) + exec(code, ns) + + return module + + def _load_module(self, name: str, parent: str): + cur: _PathNode = self.root + for atom in name.split("."): + if not isinstance(cur, _PackageNode) or atom not in cur.children: + if name in IMPLICIT_IMPORT_ALLOWLIST: + module = self.modules[name] = importlib.import_module(name) + return module + raise ModuleNotFoundError( + f'No module named "{name}" in self-contained archive "{self.filename}"' + f" and the module is also not in the list of allowed external modules: {self.extern_modules}", + name=name, + ) + cur = cur.children[atom] + if isinstance(cur, _ExternNode): + module = self.modules[name] = importlib.import_module(name) + + if compat_mapping := EXTERN_IMPORT_COMPAT_NAME_MAPPING.get(name): + for old_name, new_name in compat_mapping.items(): + module.__dict__.setdefault(old_name, new_name) + + return module + return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent) # type: ignore[attr-defined] + + def _compile_source(self, fullpath: str, mangled_filename: str): + source = self.zip_reader.get_record(fullpath) + source = _normalize_line_endings(source) + return compile(source, mangled_filename, "exec", dont_inherit=True) + + # note: named `get_source` so that linecache can find the source + # when this is the __loader__ of a module. + def get_source(self, module_name) -> str: + # linecache calls `get_source` with the `module.__name__` as the argument, so we must demangle it here. + module = self.import_module(demangle(module_name)) + return self.zip_reader.get_record(demangle(module.__file__)).decode("utf-8") + + # note: named `get_resource_reader` so that importlib.resources can find it. + # This is otherwise considered an internal method. + def get_resource_reader(self, fullname): + try: + package = self._get_package(fullname) + except ImportError: + return None + if package.__loader__ is not self: + return None + return _PackageResourceReader(self, fullname) + + def _install_on_parent(self, parent: str, name: str, module: types.ModuleType): + if not parent: + return + # Set the module as an attribute on its parent. + parent_module = self.modules[parent] + if parent_module.__loader__ is self: + setattr(parent_module, name.rpartition(".")[2], module) + + # note: copied from cpython's import code, with call to create module replaced with _make_module + def _do_find_and_load(self, name): + path = None + parent = name.rpartition(".")[0] + module_name_no_parent = name.rpartition(".")[-1] + if parent: + if parent not in self.modules: + self._gcd_import(parent) + # Crazy side-effects! + if name in self.modules: + return self.modules[name] + parent_module = self.modules[parent] + + try: + path = parent_module.__path__ # type: ignore[attr-defined] + + except AttributeError: + # when we attempt to import a package only containing pybinded files, + # the parent directory isn't always a package as defined by python, + # so we search if the package is actually there or not before calling the error. + if isinstance( + parent_module.__loader__, + importlib.machinery.ExtensionFileLoader, + ): + if name not in self.extern_modules: + msg = ( + _ERR_MSG + + "; {!r} is a c extension module which was not externed. C extension modules \ + need to be externed by the PackageExporter in order to be used as we do not support interning them.}." 
+ ).format(name, name) + raise ModuleNotFoundError(msg, name=name) from None + if not isinstance( + parent_module.__dict__.get(module_name_no_parent), + types.ModuleType, + ): + msg = ( + _ERR_MSG + + "; {!r} is a c extension package which does not contain {!r}." + ).format(name, parent, name) + raise ModuleNotFoundError(msg, name=name) from None + else: + msg = (_ERR_MSG + "; {!r} is not a package").format(name, parent) + raise ModuleNotFoundError(msg, name=name) from None + + module = self._load_module(name, parent) + + self._install_on_parent(parent, name, module) + + return module + + # note: copied from cpython's import code + def _find_and_load(self, name): + module = self.modules.get(name, _NEEDS_LOADING) + if module is _NEEDS_LOADING: + return self._do_find_and_load(name) + + if module is None: + message = f"import of {name} halted; None in sys.modules" + raise ModuleNotFoundError(message, name=name) + + # To handle https://github.com/pytorch/pytorch/issues/57490, where std's + # creation of fake submodules via the hacking of sys.modules is not import + # friendly + if name == "os": + self.modules["os.path"] = cast(Any, module).path + elif name == "typing": + self.modules["typing.io"] = cast(Any, module).io + self.modules["typing.re"] = cast(Any, module).re + + return module + + def _gcd_import(self, name, package=None, level=0): + """Import and return the module based on its name, the package the call is + being made from, and the level adjustment. + + This function represents the greatest common denominator of functionality + between import_module and __import__. This includes setting __package__ if + the loader did not. + + """ + _sanity_check(name, package, level) + if level > 0: + name = _resolve_name(name, package, level) + + return self._find_and_load(name) + + # note: copied from cpython's import code + def _handle_fromlist(self, module, fromlist, *, recursive=False): + """Figure out what __import__ should return. + + The import_ parameter is a callable which takes the name of module to + import. It is required to decouple the function from assuming importlib's + import implementation is desired. + + """ + module_name = demangle(module.__name__) + # The hell that is fromlist ... + # If a package was imported, try to import stuff from fromlist. + if hasattr(module, "__path__"): + for x in fromlist: + if not isinstance(x, str): + if recursive: + where = module_name + ".__all__" + else: + where = "``from list''" + raise TypeError( + f"Item in {where} must be str, " f"not {type(x).__name__}" + ) + elif x == "*": + if not recursive and hasattr(module, "__all__"): + self._handle_fromlist(module, module.__all__, recursive=True) + elif not hasattr(module, x): + from_name = f"{module_name}.{x}" + try: + self._gcd_import(from_name) + except ModuleNotFoundError as exc: + # Backwards-compatibility dictates we ignore failed + # imports triggered by fromlist for modules that don't + # exist. + if ( + exc.name == from_name + and self.modules.get(from_name, _NEEDS_LOADING) is not None + ): + continue + raise + return module + + def __import__(self, name, globals=None, locals=None, fromlist=(), level=0): + if level == 0: + module = self._gcd_import(name) + else: + globals_ = globals if globals is not None else {} + package = _calc___package__(globals_) + module = self._gcd_import(name, package, level) + if not fromlist: + # Return up to the first dot in 'name'. This is complicated by the fact + # that 'name' may be relative. 
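# The "return up to the first dot when fromlist is empty" rule implemented just below
# mirrors standard CPython __import__ semantics, which can be demonstrated with the
# builtin __import__ and stdlib modules (no torch.package archive required):
top = __import__("os.path", fromlist=())         # empty fromlist: returns the top-level package `os`
sub = __import__("os.path", fromlist=("join",))  # non-empty fromlist: returns the submodule itself
assert top.__name__ == "os"
assert sub.__name__ in ("posixpath", "ntpath")   # os.path is an alias for one of these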
+ if level == 0: + return self._gcd_import(name.partition(".")[0]) + elif not name: + return module + else: + # Figure out where to slice the module's name up to the first dot + # in 'name'. + cut_off = len(name) - len(name.partition(".")[0]) + # Slice end needs to be positive to alleviate need to special-case + # when ``'.' not in name``. + module_name = demangle(module.__name__) + return self.modules[module_name[: len(module_name) - cut_off]] + else: + return self._handle_fromlist(module, fromlist) + + def _get_package(self, package): + """Take a package name or module object and return the module. + + If a name, the module is imported. If the passed or imported module + object is not a package, raise an exception. + """ + if hasattr(package, "__spec__"): + if package.__spec__.submodule_search_locations is None: + raise TypeError(f"{package.__spec__.name!r} is not a package") + else: + return package + else: + module = self.import_module(package) + if module.__spec__.submodule_search_locations is None: + raise TypeError(f"{package!r} is not a package") + else: + return module + + def _zipfile_path(self, package, resource=None): + package = self._get_package(package) + assert package.__loader__ is self + name = demangle(package.__name__) + if resource is not None: + resource = _normalize_path(resource) + return f"{name.replace('.', '/')}/{resource}" + else: + return f"{name.replace('.', '/')}" + + def _get_or_create_package( + self, atoms: List[str] + ) -> "Union[_PackageNode, _ExternNode]": + cur = self.root + for i, atom in enumerate(atoms): + node = cur.children.get(atom, None) + if node is None: + node = cur.children[atom] = _PackageNode(None) + if isinstance(node, _ExternNode): + return node + if isinstance(node, _ModuleNode): + name = ".".join(atoms[:i]) + raise ImportError( + f"inconsistent module structure. module {name} is not a package, but has submodules" + ) + assert isinstance(node, _PackageNode) + cur = node + return cur + + def _add_file(self, filename: str): + """Assembles a Python module out of the given file. Will ignore files in the .data directory. + + Args: + filename (str): the name of the file inside of the package archive to be added + """ + *prefix, last = filename.split("/") + if len(prefix) > 1 and prefix[0] == ".data": + return + package = self._get_or_create_package(prefix) + if isinstance(package, _ExternNode): + raise ImportError( + f"inconsistent module structure. package contains a module file {filename}" + f" that is a subpackage of a module marked external." + ) + if last == "__init__.py": + package.source_file = filename + elif last.endswith(".py"): + package_name = last[: -len(".py")] + package.children[package_name] = _ModuleNode(filename) + + def _add_extern(self, extern_name: str): + *prefix, last = extern_name.split(".") + package = self._get_or_create_package(prefix) + if isinstance(package, _ExternNode): + return # the shorter extern covers this extern case + package.children[last] = _ExternNode() + + +_NEEDS_LOADING = object() +_ERR_MSG_PREFIX = "No module named " +_ERR_MSG = _ERR_MSG_PREFIX + "{!r}" + + +class _PathNode: + pass + + +class _PackageNode(_PathNode): + def __init__(self, source_file: Optional[str]): + self.source_file = source_file + self.children: Dict[str, _PathNode] = {} + + +class _ModuleNode(_PathNode): + __slots__ = ["source_file"] + + def __init__(self, source_file: str): + self.source_file = source_file + + +class _ExternNode(_PathNode): + pass + + +# A private global registry of all modules that have been package-imported. 
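# A standalone toy sketch of the tree construction performed by _get_or_create_package
# and _add_file above: plain dictionaries stand in for _PackageNode/_ModuleNode, and the
# file list is hypothetical. This is an illustration, not torch.package itself.
def build_tree(filenames):
    root = {"kind": "package", "source_file": None, "children": {}}
    for filename in filenames:
        *prefix, last = filename.split("/")
        node = root
        for atom in prefix:
            node = node["children"].setdefault(
                atom, {"kind": "package", "source_file": None, "children": {}}
            )
        if last == "__init__.py":
            node["source_file"] = filename
        elif last.endswith(".py"):
            node["children"][last[: -len(".py")]] = {
                "kind": "module",
                "source_file": filename,
            }
    return root

# build_tree(["pkg/__init__.py", "pkg/util.py"]) yields a "pkg" package node whose
# source_file is "pkg/__init__.py" and which has a single module child "util".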
+_package_imported_modules: WeakValueDictionary = WeakValueDictionary() + +# `inspect` by default only looks in `sys.modules` to find source files for classes. +# Patch it to check our private registry of package-imported modules as well. +_orig_getfile = inspect.getfile + + +def _patched_getfile(object): + if inspect.isclass(object): + if object.__module__ in _package_imported_modules: + return _package_imported_modules[object.__module__].__file__ + return _orig_getfile(object) + + +inspect.getfile = _patched_getfile + + +class _PackageResourceReader: + """Private class used to support PackageImporter.get_resource_reader(). + + Confirms to the importlib.abc.ResourceReader interface. Allowed to access + the innards of PackageImporter. + """ + + def __init__(self, importer, fullname): + self.importer = importer + self.fullname = fullname + + def open_resource(self, resource): + from io import BytesIO + + return BytesIO(self.importer.load_binary(self.fullname, resource)) + + def resource_path(self, resource): + # The contract for resource_path is that it either returns a concrete + # file system path or raises FileNotFoundError. + if isinstance( + self.importer.zip_reader, DirectoryReader + ) and self.importer.zip_reader.has_record( + os.path.join(self.fullname, resource) + ): + return os.path.join( + self.importer.zip_reader.directory, self.fullname, resource + ) + raise FileNotFoundError + + def is_resource(self, name): + path = self.importer._zipfile_path(self.fullname, name) + return self.importer.zip_reader.has_record(path) + + def contents(self): + from pathlib import Path + + filename = self.fullname.replace(".", "/") + + fullname_path = Path(self.importer._zipfile_path(self.fullname)) + files = self.importer.zip_reader.get_all_records() + subdirs_seen = set() + for filename in files: + try: + relative = Path(filename).relative_to(fullname_path) + except ValueError: + continue + # If the path of the file (which is relative to the top of the zip + # namespace), relative to the package given when the resource + # reader was created, has a parent, then it's a name in a + # subdirectory and thus we skip it. + parent_name = relative.parent.name + if len(parent_name) == 0: + yield relative.name + elif parent_name not in subdirs_seen: + subdirs_seen.add(parent_name) + yield parent_name diff --git a/lib/python3.10/site-packages/torch/profiler/__init__.py b/lib/python3.10/site-packages/torch/profiler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..073096607afe7ef44a9fe0896ff4b242b37e851b --- /dev/null +++ b/lib/python3.10/site-packages/torch/profiler/__init__.py @@ -0,0 +1,50 @@ +# mypy: allow-untyped-defs +r""" +PyTorch Profiler is a tool that allows the collection of performance metrics during training and inference. +Profiler's context manager API can be used to better understand what model operators are the most expensive, +examine their input shapes and stack traces, study device kernel activity and visualize the execution trace. + +.. note:: + An earlier version of the API in :mod:`torch.autograd` module is considered legacy and will be deprecated. 
+ +""" +import os + +from torch._C._autograd import _supported_activities, DeviceType, kineto_available +from torch._C._profiler import _ExperimentalConfig, ProfilerActivity, RecordScope +from torch.autograd.profiler import KinetoStepTracker, record_function +from torch.optim.optimizer import register_optimizer_step_post_hook + +from .profiler import ( + _KinetoProfile, + ExecutionTraceObserver, + profile, + ProfilerAction, + schedule, + supported_activities, + tensorboard_trace_handler, +) + + +__all__ = [ + "profile", + "schedule", + "supported_activities", + "tensorboard_trace_handler", + "ProfilerAction", + "ProfilerActivity", + "kineto_available", + "DeviceType", + "record_function", + "ExecutionTraceObserver", +] + +from . import itt + + +def _optimizer_post_hook(optimizer, args, kwargs): + KinetoStepTracker.increment_step("Optimizer") + + +if os.environ.get("KINETO_USE_DAEMON", None): + _ = register_optimizer_step_post_hook(_optimizer_post_hook) diff --git a/lib/python3.10/site-packages/torch/profiler/_memory_profiler.py b/lib/python3.10/site-packages/torch/profiler/_memory_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..2095b882f5de9d090ac161ae49d278643919837d --- /dev/null +++ b/lib/python3.10/site-packages/torch/profiler/_memory_profiler.py @@ -0,0 +1,1204 @@ +# mypy: allow-untyped-defs +import collections +import dataclasses +import enum +import itertools as it +import logging +from typing import ( + Any, + cast, + DefaultDict, + Dict, + Iterator, + List, + Optional, + Set, + Tuple, + Union, +) +from typing_extensions import Literal + +import torch +from torch._C import FunctionSchema +from torch._C._autograd import _ProfilerResult +from torch._C._profiler import ( + _EventType, + _ExtraFields_Allocation, + _ExtraFields_TorchOp, + _ProfilerEvent, + _TensorMetadata, + RecordScope, +) +from torch._utils import _element_size +from torch.profiler import _utils + + +KeyAndID = Tuple["Key", int] +TensorAndID = Tuple["TensorKey", int] + +log = logging.getLogger(__name__) + + +class Category(enum.Enum): + INPUT = enum.auto() + TEMPORARY = enum.auto() + ACTIVATION = enum.auto() + GRADIENT = enum.auto() + AUTOGRAD_DETAIL = enum.auto() + PARAMETER = enum.auto() + OPTIMIZER_STATE = enum.auto() + + +_CATEGORY_TO_COLORS = { + Category.PARAMETER: "darkgreen", + Category.OPTIMIZER_STATE: "goldenrod", + Category.INPUT: "black", + Category.TEMPORARY: "mediumpurple", + Category.ACTIVATION: "red", + Category.GRADIENT: "mediumblue", + Category.AUTOGRAD_DETAIL: "royalblue", + None: "grey", +} + +_CATEGORY_TO_INDEX = {c: i for i, c in enumerate(_CATEGORY_TO_COLORS)} + + +class Action(enum.Enum): + PREEXISTING = enum.auto() + CREATE = enum.auto() + INCREMENT_VERSION = enum.auto() + DESTROY = enum.auto() + + +_ACTION_TO_INDEX = {i: i.value for i in Action} + + +@dataclasses.dataclass(eq=True, unsafe_hash=False, frozen=True) +class Key: + device: torch.device + + +@dataclasses.dataclass +class _Storage: + """Bundle storage pointer and id. 
+ + All profiling logic should use `allocation_id`, however it is useful to + print storage pointers for debugging and unit tests sometimes look up + values using the storage data pointer of a live Tensor.""" + + ptr: int + allocation_id: int + + def __repr__(self) -> str: + return f"{hex(self.ptr):>18} ({self.allocation_id})" + + def __eq__(self, other: object) -> bool: + return isinstance(other, _Storage) and self.allocation_id == other.allocation_id + + def __hash__(self) -> int: + return hash(self.allocation_id) + + +@dataclasses.dataclass(eq=True, unsafe_hash=True, frozen=True) +class TensorKey(Key): + """Hashable identifier for a storage which has been asigned an ID. + + A detailed description of Tensor IDs and why they are needed is given in + `torch/csrc/profiler/collection.h` when `TensorID` is declared. To + summarize, multiple Storage buffers can map to the same logical Tensor. + This dataclass is used to refer to a concrete in-memory StorageImpl of + a Tensor. + """ + + id: int + storage: _Storage + + def __repr__(self) -> str: + return f"id={self.id}: {repr(self.storage):<24} ({self.device})" + + def __lt__(self, other: "TensorKey") -> bool: + return self._as_sortable < other._as_sortable + + @staticmethod + def _make( + tensor_id: Optional[int], + storage_ptr: Optional[int], + allocation_id: Optional[int], + device: torch.device, + ) -> Optional["TensorKey"]: + if ( + tensor_id is not None + and storage_ptr is not None + and allocation_id is not None + ): + return TensorKey(device, tensor_id, _Storage(storage_ptr, allocation_id)) + return None + + @classmethod + def from_allocation(cls, alloc: _ExtraFields_Allocation) -> Optional["TensorKey"]: + return cls._make(alloc.id, alloc.ptr, alloc.allocation_id, alloc.device) + + @classmethod + def from_tensor(cls, t: Optional[_TensorMetadata]) -> Optional["TensorKey"]: + if t is not None: + return cls._make(t.id, t.storage_data_ptr, t.allocation_id, t.device) + return None + + @property + def _as_sortable(self) -> Tuple[int, int, str, int]: + return self.id, self.storage.allocation_id, self.device.type, self.device.index + + +def _extract_parameters_and_gradients( + node: _ProfilerEvent, +) -> Iterator[Tuple[Optional[TensorKey], Optional[TensorKey]]]: + children = node.children + + # AccumulateGrad is used in the Autograd engine to handle gradient updates. + # There are two possible cases: + # 1) This is a newly created gradient Tensor. In that case there is nothing + # to accumulate, so autograd simply detaches the Tensor. + # + # 2) There is a preexisting gradient Tensor and we need to add the newly + # computed update. This is done with an in-place add (aten::add_) op. + # (The underscore suffix denotes "in-place".) + if ( + node.typed[0] == _EventType.TorchOp + and node.typed[1].scope == RecordScope.BACKWARD_FUNCTION + # TODO(robieta): Move away from load bearing names + and node.name == "torch::autograd::AccumulateGrad" + and children + and children[0].typed[0] == _EventType.TorchOp + and children[0].name in ("aten::detach", "aten::add_") + and children[0].typed[1].inputs + and isinstance(children[0].typed[1].inputs[0], _TensorMetadata) + ): + yield None, TensorKey.from_tensor(children[0].typed[1].inputs[0]) + + # We directly instrument `torch.nn.Module` and `torch.optim.Optimizer` + # NOTE: The values captured by the python tracer are cached; they can be + # used to build up labels but do not imply that a Tensor was live at + # a particular time. 
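# The two AccumulateGrad cases described above, reproduced as ordinary eager code
# (illustrative only; any tensor shape works):
import torch

p = torch.ones(3, requires_grad=True)
p.sum().backward()   # case 1: no prior gradient, the new grad Tensor is detached into p.grad
p.sum().backward()   # case 2: p.grad already exists, the update is an in-place aten::add_
print(p.grad)        # tensor([2., 2., 2.])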
+ elif node.typed[0] == _EventType.PyCall: + typed_fields = node.typed[1] + assert typed_fields.module is None or typed_fields.optimizer is None + if typed_fields.module is not None: + for _, p, p_grad in typed_fields.module.parameters: + yield TensorKey.from_tensor(p), TensorKey.from_tensor(p_grad) + + if typed_fields.optimizer is not None: + for p, p_grad, _ in typed_fields.optimizer.parameters: + yield TensorKey.from_tensor(p), TensorKey.from_tensor(p_grad) + + +def extract_parameters(node: _ProfilerEvent) -> Iterator[TensorKey]: + for p, p_grad in _extract_parameters_and_gradients(node): + if p is not None: + yield p + + +def extract_gradients( + node: _ProfilerEvent, +) -> Iterator[Tuple[Optional[TensorKey], TensorKey]]: + for p, p_grad in _extract_parameters_and_gradients(node): + if p_grad is not None: + yield p, p_grad + + +def get_scopes(event: Optional[_ProfilerEvent]) -> Tuple[RecordScope, ...]: + scopes = [] + while event: + if event.typed[0] == _EventType.TorchOp: + scopes.append(event.typed[1].scope) + event = event.parent + return tuple(scopes) + + +class SchemaMatcher: + """Lookup operator schema based on profiled name. + + When profiling we record the operator's name but not the schema. However + some analysis requires that information. Fortunately we can look up + registered schema from the recorded name. We do not, however, record the + overload and so we must compare the profiled arguments with all overloads + to determine viable matches. + + Note: Once https://github.com/pytorch/pytorch/issues/78871 is completed + this code will be obsolete. + """ + + @classmethod + def inputs_are_mutable(cls, t: _ExtraFields_TorchOp) -> Tuple[Optional[bool], ...]: + """Determine which inputs may have mutated based on function schema. + + Note that we don't need to resolve down to a single schema to perform + this analysis. An input is mutable if it is mutable in any overload. In + practice, however, it is overwhelmingly common to match a single + overload. If we cannot find any valid schema then we must be + conservative and assume all inputs are mutable. + """ + mutable: Optional[List[bool]] = None + for schema in cls.match_schemas(t): + mutable = mutable or [False for _ in schema.arguments] + for i, arg in enumerate(schema.arguments): + mutable[i] |= getattr(arg.alias_info, "is_write", False) + + return tuple(mutable or (None for _ in t.inputs)) + + @classmethod + def match_schemas(cls, t: _ExtraFields_TorchOp) -> Tuple[FunctionSchema, ...]: + signature = tuple( + # Tensor + TensorKey.from_tensor(i) if isinstance(i, _TensorMetadata) + # + # TensorList + else [TensorKey.from_tensor(j) for j in i] if isinstance(i, list) + # + # Scalar and uncaptured inputs. + else i + for i in t.inputs + ) + + def matches(schema) -> bool: + return len(schema.arguments) == len(signature) and all( + cls._types_match(observed, schema_arg.type) + for observed, schema_arg in zip(signature, schema.arguments) + ) + + return tuple(s for s in cls.lookup_schemas(t.name) or () if matches(s)) + + @classmethod + def _types_match(cls, observed, schema_type) -> bool: + if isinstance(schema_type, torch._C.OptionalType): + schema_type = schema_type.getElementType() + return observed is None or cls._types_match(observed, schema_type) + + if isinstance(schema_type, torch._C.AnyType): + return True + + if schema_type.isSubtypeOf(torch._C.ListType.ofTensors()): + return isinstance(observed, list) and all( + isinstance(i, TensorKey) for i in observed + ) + + type_map: Tuple[Tuple[Any, Union[type, Tuple[type, ...]]], ...] 
= ( + (torch._C.TensorType, TensorKey), + (torch._C.NoneType, type(None)), + (torch._C.BoolType, bool), + (torch._C.IntType, int), + (torch._C.FloatType, float), + (torch._C.ComplexType, complex), + (torch._C.NumberType, (bool, int, float, complex)), + ) + + for jit_type, py_types in type_map: + if isinstance(schema_type, jit_type): + return isinstance(observed, py_types) + + # Profiler only records a subset of possible argument types. If we + # reach this point then the schema must call for a type that profiler + # does not record. Thus, the schema can only be a match if `observed` + # is also None. + return observed is None + + @staticmethod + def lookup_schemas(name: str) -> Optional[Tuple[FunctionSchema, ...]]: + # TODO(robieta): + # _jit_get_schemas_for_operator is quite expensive. (~100us / call) + # Consider adding `functools.lru_cache` if that becomes an issue. + + try: + # Schema lookup will throw if `name` is malformed. (For example, + # schemas must be namespaced and schema lookup will fail if name + # does not include "::".) We simply catch the exception and return + # `None` to denote that `name` cannot be an operator name. + # + # Note that record_function annotations also go through this path, + # so it is expected that some names will not correspond to PyTorch + # operators. + if "::" not in name: + return None + return tuple(torch._C._jit_get_schemas_for_operator(name)) + except RuntimeError: + return None + + +class OpTree: + def __init__(self, result: _ProfilerResult) -> None: + self._root_nodes = result.experimental_event_tree() + self._sorted_nodes = tuple(sorted(self.dfs(), key=lambda x: x.start_time_ns)) + + def dfs(self, *args, **kwargs) -> Iterator[_ProfilerEvent]: + yield from _utils.traverse_dfs(self._root_nodes, *args, **kwargs) + + @property + def sorted_nodes(self) -> Tuple[_ProfilerEvent, ...]: + return self._sorted_nodes + + +class SizeMap: + def __init__(self, op_tree: OpTree) -> None: + self._values: Dict[TensorKey, int] = {} + + for node in op_tree.sorted_nodes: + if node.typed[0] == _EventType.TorchOp: + for t in self._flat_tensor_inputs(node.typed[1]): + self._update_values(t) + + elif node.typed[0] == _EventType.PyCall: + typed_fields = node.typed[1] + assert typed_fields.module is None or typed_fields.optimizer is None + if typed_fields.module is not None: + for _, p, p_grad in typed_fields.module.parameters: + self._update_values(p) + self._update_values(p_grad) + + if typed_fields.optimizer is not None: + for p, p_grad, state in typed_fields.optimizer.parameters: + self._update_values(p) + self._update_values(p_grad) + for _, t in state: + self._update_values(t) + + allocations: Dict[TensorKey, int] = {} + for node in op_tree.sorted_nodes: + if node.typed[0] == _EventType.Allocation: + alloc_fields = node.typed[1] + key = TensorKey.from_allocation(alloc_fields) + if key: + new_size = abs(alloc_fields.alloc_size) + prior_size = allocations.setdefault(key, new_size) + + # It is possible to resize Storage in PyTorch, however we + # key on data pointer so most resizes will be treated as a + # change in storage. The one corner case that cannot be + # handled is `realloc` which successfully resizes the + # storage. At time of writing this is not done anywhere in + # the core PyTorch codebase. + if prior_size != new_size: + delta = f"{prior_size} vs. 
{new_size}" + log.warning("Mismatch between allocation and free: %s", delta) + + self._values.update(allocations) + + def _update_values(self, t: Optional[_TensorMetadata]) -> None: + key = TensorKey.from_tensor(t) + if key is not None and t is not None and t.layout == torch.strided: + # Scalars are represented as zero dim Tensors + n = max(i[0] * i[1] for i in zip(t.sizes or [1], t.strides or [1])) + + num_bytes = n * _element_size(t.dtype) + assert num_bytes >= 0, f"{num_bytes}" + self._values[key] = max(self._values.get(key, 0), num_bytes) + + @staticmethod + def _flat_tensor_inputs(op: _ExtraFields_TorchOp) -> Iterator[_TensorMetadata]: + for i in op.inputs: + if isinstance(i, _TensorMetadata): + yield i + elif isinstance(i, list): + yield from i + + def __getitem__(self, key: TensorKey): + return self._values[key] + + +@dataclasses.dataclass() +class DataFlowEdge: + input_version: Optional[int] = None + mutated: Optional[bool] = False + + @property + def is_allocation(self) -> bool: + return self.input_version is None + + @property + def is_deletion(self) -> bool: + return self.mutated is None + + +class DataFlowNode: + def __init__(self, event: _ProfilerEvent, graph: "DataFlowGraph") -> None: + self._event = event + self._graph = graph + self._edges: Dict[TensorKey, DataFlowEdge] = self._determine_edges() + + for key, edge in self._edges.items(): + if edge.mutated and not edge.is_allocation: + self._graph.bump(key) + + # Make sure the version bumping behavior matches what we expect. + versions = {k: (v, self._graph.lookup(k)) for k, v in self.outputs.items()} + assert all(i == j for i, j in versions.values()), f"{versions}, {self._edges}" + + def _determine_edges(self) -> Dict[TensorKey, DataFlowEdge]: + subtree = tuple(_utils.traverse_dfs([self._event])) + + # Start by populating edges from op inputs and outputs. + mutable_by_key: Dict[Optional[TensorKey], Set[Optional[bool]]] = {} + for op in (i.typed[1] for i in subtree if i.typed[0] == _EventType.TorchOp): + for op_input, mutable in zip( + op.inputs, SchemaMatcher.inputs_are_mutable(op) + ): + # Tensor + if isinstance(op_input, _TensorMetadata): + key = TensorKey.from_tensor(op_input) + mutable_by_key.setdefault(key, set()).add(mutable) + + # TensorList + elif isinstance(op_input, list): + for op_input_i in op_input: + key = TensorKey.from_tensor(op_input_i) + mutable_by_key.setdefault(key, set()).add(mutable) + + edges: DefaultDict[Optional[TensorKey], DataFlowEdge] + edges = collections.defaultdict(DataFlowEdge) + for key, mutable_set in mutable_by_key.items(): + if key is not None: + edges[key].input_version = self._graph.lookup(key) if key else -1 + + # We consider an op to be mutated if we encounter a schema where it + # is a mutable argument OR if it is ambiguous. (We never explicitly + # see it in any schema.) + mutated = (True in mutable_set) or (tuple(mutable_set) == (None,)) + edges[key].mutated = mutated + + # Then handle deletions. Note that deleting a Tensor implicitly adds + # it as an input edge. + for i in subtree: + if i.typed[0] == _EventType.Allocation and i.typed[1].alloc_size < 0: + key = TensorKey.from_allocation(i.typed[1]) + edge = edges[key] + assert key is None or edge.mutated is not None, f"Double delete: {key}" + edge.mutated = None + edge.input_version = self._graph.lookup(key) if key else -1 + + # And finally handle allocations. This step must be last, because the + # previous two steps optimistically add input edges. 
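# The byte-size estimate used in _update_values above, checked against a contiguous
# strided tensor (for a contiguous layout the max(size * stride) extent equals numel();
# for sliced or padded layouts it approximates the storage span addressed):
import torch

t = torch.empty(4, 5, dtype=torch.float32)  # contiguous: sizes (4, 5), strides (5, 1)
n = max(size * stride for size, stride in zip(t.shape, t.stride()))
num_bytes = n * t.element_size()
assert num_bytes == t.numel() * t.element_size() == 80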
+ for i in subtree: + if i.typed[0] == _EventType.Allocation and i.typed[1].alloc_size > 0: + edges[TensorKey.from_allocation(i.typed[1])].input_version = None + + # We don't need to sort the inputs, but it makes debugging and unit tests nicer. + return dict(sorted((k, v) for k, v in edges.items() if k is not None)) + + @property + def inputs(self) -> Dict[TensorKey, Tuple[bool, int]]: + return { + # MyPy can't see through `is_allocation` to know that + # `v.input_version` is not None. + k: (bool(v.mutated), cast(int, v.input_version)) + for k, v in self._edges.items() + if not v.is_allocation + } + + @property + def outputs(self) -> Dict[TensorKey, int]: + return { + k: 0 if v.input_version is None else v.input_version + 1 + for k, v in self._edges.items() + if (v.is_allocation and not v.is_deletion) or v.mutated + } + + @property + def intermediates(self) -> Tuple[TensorKey, ...]: + return tuple( + k for k, v in self._edges.items() if v.is_allocation and v.is_deletion + ) + + @property + def start_time(self) -> int: + return self._event.start_time_ns + + +class DataFlowGraph: + def __init__(self, op_tree: OpTree) -> None: + self._op_tree = op_tree + self._leaf_events = self._extract_leaf_events(op_tree) + self._active_version: Dict[TensorKey, Optional[int]] = {} + self._flow_nodes = [DataFlowNode(e, self) for e in self.leaf_events] + self._flow_nodes.sort(key=lambda x: x.start_time) + self.validate() + + @property + def flow_nodes(self) -> Tuple[DataFlowNode, ...]: + return tuple(self._flow_nodes) + + def validate(self): + # Check that each (Tensor, version) pair has a unique creation node + outputs: Set[Tuple[TensorKey, int]] = set() + for node in self.flow_nodes: + node_outputs = set(node.outputs.items()) + duplicates = outputs & node_outputs + assert not duplicates, f"{node._event.name} {node._edges} {duplicates}" + outputs |= node_outputs + + # And check that `self._nodes` forms a valid topologically sorted DAG. + tensor_versions: Dict[TensorKey, int] = {} + for node in self.flow_nodes: + for key, (_, version) in node.inputs.items(): + expected = tensor_versions.get(key, 0) + assert expected == version, (expected, version) + + for key, version in node.outputs.items(): + prior_version = tensor_versions.get(key, version) + assert version >= prior_version, (version, prior_version) + tensor_versions[key] = version + + @property + def leaf_events(self) -> Tuple[_ProfilerEvent, ...]: + return self._leaf_events + + @staticmethod + def _extract_leaf_events(op_tree: OpTree) -> Tuple[_ProfilerEvent, ...]: + """Partially traverse the op tree and extract top level ops. + + Consider the following code: + ``` + with record_function("My annotation"): + x.zero_() + y.zero_() + ``` + + The op tree (assuming no Autograd) will look like: + + TorchOp: "My annotation" + TorchOp: zero_ + TorchOp: fill_ + TorchOp: zero_ + TorchOp: fill_ + + The recursive structure of operator calls makes data flow unwieldy. + In order to simplify analysis we would like to select the highest level + ops to represent in the graph. In this case those are the `zero_` ops; + the fact that `fill_` is called is an implementation detail. We also + do not want to group everything under "My annotation" as this could + create overly coarse bundles and lose critical semantics. + + To address this issue we walk over the graph and select the topmost + torch ops ** which match at least one operator schema **. These form + the leaves of the first pass through the op tree. 
(As well as any + allocations or frees which do are not part of a kernel.) These events + form the logical nodes in our data flow graph. + """ + + leaf_events: List[_ProfilerEvent] = [] + + def leaf_op(e: _ProfilerEvent) -> bool: + return e.typed[0] == _EventType.TorchOp and ( + e.typed[1].scope == RecordScope.BACKWARD_FUNCTION + or bool(SchemaMatcher.match_schemas(e.typed[1])) + ) + + def children_fn(e: _ProfilerEvent): + if leaf_op(e) or e.tag == _EventType.Allocation: + leaf_events.append(e) + return [] + + return e.children + + for _ in op_tree.dfs(children_fn=children_fn): + pass + + return tuple(sorted(leaf_events, key=lambda x: x.start_time_ns)) + + def lookup(self, key: TensorKey) -> int: + version = self._active_version.setdefault(key, 0) + assert version is not None + return version + + def bump(self, key: TensorKey) -> None: + prior_version = self._active_version.get(key, None) + assert prior_version is not None + self._active_version[key] = prior_version + 1 + + def delete(self, key: TensorKey) -> None: + assert self._active_version.setdefault(key, 0) is not None + self._active_version[key] = None + + +@dataclasses.dataclass +class CategoryElement: + by_id: Optional[Category] = None + by_key: Dict[TensorKey, Category] = dataclasses.field(default_factory=dict) + by_version: Dict[TensorAndID, Category] = dataclasses.field(default_factory=dict) + + # Used by unit tests to check internals. (And consequently by + # MemoryProfile.lookup) This should not be used in any other capacity. + _by_id_keyset: Set[TensorKey] = dataclasses.field(default_factory=set) + + +@dataclasses.dataclass +class CategoryDict: + _values: DefaultDict[int, CategoryElement] = dataclasses.field( + default_factory=lambda: collections.defaultdict(CategoryElement) + ) + + def set_by_id(self, key: TensorKey, category: Category) -> None: + self._values[key.id].by_id = category + self._values[key.id]._by_id_keyset.add(key) + + def set_by_key(self, key: TensorKey, category: Category) -> None: + self._values[key.id].by_key[key] = category + + def set_by_version(self, key: TensorKey, version: int, category: Category) -> None: + self._values[key.id].by_version[(key, version)] = category + + def setdefault_by_version( + self, key: TensorKey, version: int, category: Category + ) -> None: + self._values[key.id].by_version.setdefault((key, version), category) + + def get(self, key: Key, version: int) -> Optional[Category]: + if isinstance(key, Key) and not isinstance(key, TensorKey): + return None + element = self._values[key.id] + return ( + element.by_id + or element.by_key.get(key, None) + or element.by_version.get((key, version), None) + ) + + +class MemoryProfile: + def __init__(self, result: _ProfilerResult) -> None: + self._op_tree = OpTree(result) + self._data_flow_graph = DataFlowGraph(self._op_tree) + self._size_map = SizeMap(self._op_tree) + self._categories = CategoryDict() + + self._set_gradients_and_temporaries() + self._set_parameters_using_python_tracer() + self._set_inputs() + self._set_parameters_using_data_flow() + self._set_activations() + self._set_optimizer_state() + self._set_autograd_detail() + + @property + def timeline(self) -> Tuple[Tuple[int, Action, KeyAndID, int], ...]: + output: List[Tuple[int, Action, KeyAndID, int]] = [] + allocation_times: Dict[Tuple[TensorKey, bool], int] = {} + live_unknown: Dict[Tuple[int, torch.device], Literal[True]] = {} + for event in self._op_tree.dfs(): + if event.typed[0] == _EventType.Allocation: + alloc_fields = event.typed[1] + alloc_size = 
alloc_fields.alloc_size + is_allocation = alloc_size > 0 + t = event.start_time_ns + + tkey = TensorKey.from_allocation(alloc_fields) + if tkey is not None: + allocation_times[(tkey, is_allocation)] = t + + else: + key = Key(alloc_fields.device) + ptr_and_device = (alloc_fields.ptr, key.device) + if is_allocation: + if ptr_and_device in live_unknown: + output.append( + (t, Action.INCREMENT_VERSION, (key, 0), alloc_size) + ) + else: + live_unknown[ptr_and_device] = True + output.append((t, Action.CREATE, (key, 0), alloc_size)) + else: + output.append((t, Action.DESTROY, (key, 0), -alloc_size)) + if not live_unknown.pop(ptr_and_device, False): + output.append( + (-1, Action.PREEXISTING, (key, 0), -alloc_size) + ) + + snapshot = self._category_snapshot() + last_version = dict(sorted(snapshot.keys())) + + events: List[Tuple[int, Action, TensorAndID]] = [ + (-1, Action.PREEXISTING, (key, version)) + for key, version in snapshot.keys() + if (key, True) not in allocation_times and version == 0 + ] + + for node in self._data_flow_graph.flow_nodes: + for key, edge in node._edges.items(): + if edge.is_allocation: + t = allocation_times[(key, True)] + events.append((t, Action.CREATE, (key, 0))) + + elif edge.mutated: + t = node._event.start_time_ns + version = edge.input_version + assert version is not None + events.append((t, Action.INCREMENT_VERSION, (key, version))) + + if edge.is_deletion: + t = allocation_times[(key, False)] + events.append((t, Action.DESTROY, (key, last_version[key]))) + + output.extend( + (time, action, (key, version), self._size_map[key]) + for time, action, (key, version) in events + ) + + output.sort(key=lambda x: (x[0], x[1].value)) + return tuple(output) + + def _is_gradient(self, *args, **kwargs) -> bool: + return self._categories.get(*args, **kwargs) == Category.GRADIENT + + def _category_snapshot(self) -> Dict[TensorAndID, Optional[Category]]: + all_tensor_versions: Set[TensorAndID] = set() + + for node in self._data_flow_graph.flow_nodes: + all_tensor_versions.update(((k, v) for k, (_, v) in node.inputs.items())) + all_tensor_versions.update((key, 0) for key in node.intermediates) + all_tensor_versions.update(node.outputs.items()) + + for i in self._categories._values.values(): + all_tensor_versions.update((key, 0) for key in i._by_id_keyset) + + return { + (key, version): self._categories.get(key, version) + for key, version in sorted(all_tensor_versions) + } + + def _any_version_depends_on_gradient(self) -> Set[int]: + """Extract IDs of Tensors which depend or will depend on a gradient. + + Note that this weakened definition of "depends" requires us to loop + over the data flow graph multiple times because it allows dependency + information to flow backward through edges and removes the guarantee + that nodes are topologically sorted. (Or indeed, even that a valid + topological order exists.) Put another way, we have converted an + acyclic data flow graph into a cyclic graph and we are attempting to + partition cycles involving a gradient from the rest of the graph. 
+ """ + depends_on_gradient: Set[int] = set() + while True: + start_size = len(depends_on_gradient) + for node in self._data_flow_graph.flow_nodes: + ids = tuple( + key.id + for key, (_, version) in node.inputs.items() + if self._categories.get(key, version) + in (Category.GRADIENT, Category.PARAMETER) + or key.id in depends_on_gradient + ) + + if ids: + depends_on_gradient.update(ids) + depends_on_gradient.update(key.id for key in node.outputs) + + # We are guaranteed to exit because there is a finite set of + # TensorAndID pairs. In practice we do not expect to loop more than + # three times: once to identify the core parameter update loop, + # once to fold the first step into that loop, and a third time + # where no new elements are added. + if len(depends_on_gradient) == start_size: + return depends_on_gradient + + def _set_gradients_and_temporaries(self) -> None: + """Mark Tensors which are unambiguous and simple to reason about.""" + + # Gradients are straightforward to detect. We directly check the + # `.grad` property in the Python tracer, and we can detect any new + # gradient Tensors from `AccumulateGrad` ops. + for event in self._op_tree.dfs(): + for _, p_grad in extract_gradients(event): + self._categories.set_by_id(p_grad, Category.GRADIENT) + + # Similarly, temporary Tensors are easy to identify and are useful to + # flag since they can make memory use "spikier" than one would + # otherwise expect. + for node in self._data_flow_graph.flow_nodes: + for i in node.intermediates: + self._categories.set_by_key(i, Category.TEMPORARY) + + def _set_parameters_using_python_tracer(self) -> None: + for event in self._op_tree.dfs(): + for p in extract_parameters(event): + if p is not None: + self._categories.set_by_id(p, Category.PARAMETER) + + def _set_inputs(self) -> None: + """Mark inputs based on which Tensors are updated using gradients. + + The process for differentiating between inputs and activations is more + involved. Most Tensors in a training loop depend on at least one + gradient: parameters depend on them through updates, and activations + and optimizer state depend on them transitively through parameters. + Critically, we do not need to know which Tensors are parameters to + apply this method; we can simply walk the data flow graph to build the + set of all values which depend on a gradient and then obtain the set + of inputs from the conjugate set. + + There is, however, one hiccup. The first time we see a parameter is + generally on the forward pass of the first step. We know from + inspection of the data flow graph that v1 of that Tensor depends on + a gradient (provided we profile an optimizer step), but not v0. To + address this problem we weaken the definition of "depends on a + gradient" to "any version of this Tensor depends on a gradient", + which in turn strengthens the criteria for the input set enough to + filter the activations in the forward pass of the first step.""" + + # All of this analysis is predicated on using at least one training + # step (or parameters from the python tracer) to partition the graph. + # Absent that we cannot determine which Tensors are inputs and which + # ones are part of the model. + depends_on_gradient = self._any_version_depends_on_gradient() + + # We only want to annotate Tensors which actually contribute to the + # model calculation. 
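# A simplified, standalone shape of the fixed-point loop used above: keep propagating
# the "depends on a gradient" mark across (inputs -> outputs) node pairs until the set
# stops growing. Toy integer ids stand in for Tensor ids.
def transitive_closure(nodes, seeds):
    tainted = set(seeds)
    while True:
        before = len(tainted)
        for inputs, outputs in nodes:
            if tainted & inputs:
                tainted |= inputs | outputs
        if len(tainted) == before:
            return tainted

# Two chained ops: id 1 is a gradient, so ids 2 and 3 transitively depend on it.
assert transitive_closure([({1}, {2}), ({2}, {3})], {1}) == {1, 2, 3}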
+ produces_gradient: Set[TensorAndID] = set() + for node in reversed(self._data_flow_graph.flow_nodes): + tensors = {(key, version) for key, (_, version) in node.inputs.items()} + tensors |= node.outputs.items() + if any( + self._categories.get(*i) in (Category.GRADIENT, Category.PARAMETER) + or i in produces_gradient + for i in tensors + ): + produces_gradient |= tensors + + # Don't include Tensors created in the backward pass, as these are + # generally Autograd implementation details rather than proper inputs. + input_candidates = produces_gradient.copy() + for node in self._data_flow_graph.flow_nodes: + if RecordScope.BACKWARD_FUNCTION in get_scopes(node._event): + input_candidates -= set(node.outputs.items()) + + for key, version in input_candidates: + if key.id not in depends_on_gradient: + self._categories.setdefault_by_version(key, version, Category.INPUT) + + def _set_parameters_using_data_flow(self) -> None: + """Deduce which Tensors are parameters. + + Consider the following code for the step of SGD with momentum + (nesterov=False), where `d_p` is the gradient of `param` and `buf` is + the momentum buffer. + ``` + buf.mul_(momentum).add_(d_p, alpha=1 - dampening) + d_p = buf + param.add_(d_p, alpha=-lr) + ``` + Both `param` and `buf` take a gradient and perform an in-place update. + + The python tracer will inspect calls to `nn.Module.forward` and + `optim.Optimizer.step` to extract parameter and optimizer state + respectively (including parameters), so this is generally a non-issue. + + However as a fallback we can also exploit several properties of + parameters to distinguish them from other model state. + + First, they are directly used in the forward pass. (At this point we + haven't established which parts of the graph correspond to the forward + pass but we can deduce enough to suffice.) Some mutable state such as + batch norm moving averages also contribute to the forward pass, but + optimizer state does not. + + Second, a parameter is by definition used to compute at least one + gradient and depends on at least one gradient. + """ + snapshot = self._category_snapshot() + + # Determine which Tensors might be parameters based on forward pass + # data flow. Note this these are only candidates; we filter nodes that + # we know are part of the backward pass but that doesn't guarantee that + # they are part of the forward pass. + candidate_parameters: Set[TensorAndID] = set() + candidate_fwd_tensors: Set[TensorAndID] = { + i for i, category in snapshot.items() if category == Category.INPUT + } + + for node in self._data_flow_graph.flow_nodes: + inputs = {(key, value) for key, (_, value) in node.inputs.items()} + if ( + # Don't check nodes in the backward pass. + RecordScope.BACKWARD_FUNCTION not in get_scopes(node._event) + and not any(self._is_gradient(*i) for i in inputs) + and not any(self._is_gradient(*i) for i in node.outputs.items()) + # + # and only check nodes which depend on an input. 
+ and candidate_fwd_tensors.intersection(inputs) + ): + candidate_fwd_tensors |= node.outputs.items() + candidate_parameters |= inputs.difference(candidate_fwd_tensors) + + # Require that each parameter eventually contributes to the value of a gradient + used_for_gradient: Set[TensorAndID] = set() + for node in reversed(self._data_flow_graph.flow_nodes): + if any( + self._is_gradient(*i) or i in used_for_gradient + for i in node.outputs.items() + ): + used_for_gradient.update( + (key, version) for key, (_, version) in node.inputs.items() + ) + candidate_parameters.intersection_update(used_for_gradient) + + # and depends on a gradient. + parameter_keys = {key.id for key, _ in candidate_parameters} + parameter_keys &= self._any_version_depends_on_gradient() + + for key, _ in snapshot.keys(): + if key.id in parameter_keys: + self._categories.set_by_id(key, Category.PARAMETER) + + def _set_activations(self) -> None: + """Flood the graph to identify activations.""" + + required = {Category.INPUT, Category.ACTIVATION} + also_allowed = {Category.PARAMETER, Category.TEMPORARY} + for node in self._data_flow_graph.flow_nodes: + inputs = {(key, value) for key, (_, value) in node.inputs.items()} + input_categories = {self._categories.get(*i) for i in inputs} + + if ( + (input_categories & required) + and not (input_categories - (required | also_allowed)) + # + # Stop filling when we reach the backward pass. + and RecordScope.BACKWARD_FUNCTION not in get_scopes(node._event) + ): + for i in node.outputs.items(): + self._categories.setdefault_by_version(*i, Category.ACTIVATION) + + def _set_optimizer_state(self) -> None: + for event in self._op_tree.dfs(): + if event.typed[0] == _EventType.PyCall and event.typed[1].optimizer: + parameters = event.typed[1].optimizer.parameters + for _, t in it.chain(*[state for _, _, state in parameters]): + key = TensorKey.from_tensor(t) + if key is not None: + self._categories.set_by_id(key, Category.OPTIMIZER_STATE) + + def _set_autograd_detail(self): + prior = {None, Category.AUTOGRAD_DETAIL} + for node in self._data_flow_graph.flow_nodes: + if RecordScope.BACKWARD_FUNCTION in get_scopes(node._event): + for key, version in node.outputs.items(): + if version == 0 or self._categories.get(key, version - 1) in prior: + self._categories.setdefault_by_version( + key, version, Category.AUTOGRAD_DETAIL + ) + + +class MemoryProfileTimeline: + def __init__(self, memory_profile): + """The minimum representation of the memory profile timeline + includes the memory timeline and categories. The timeline + consists of [timestamp, action, (TensorKey, version), numbytes] + elements, to denote any actions (pre-existing, create, destroy, + or increment_version) that occurred to a specific Tensor for a + chunk of memory. The categories help map each (TensorKey, + version) pair into a category.""" + self.timeline = memory_profile.timeline + self.categories = memory_profile._categories + + def _coalesce_timeline(self, device_str): + """Convert the memory timeline and categories into a memory plot + consisting of timestamps and their respective sizes by category + for a given device. 
+ + Input: device + Output: [timestamps, sizes by category] + """ + device = torch.device(device_str) + times: List[int] = [] + sizes: List[List[int]] = [] + + def update(key, version, delta): + category = ( + self.categories.get(key, version) + if isinstance(key, TensorKey) + else None + ) + index = _CATEGORY_TO_INDEX[category] + 1 + sizes[-1][index] += int(delta) + + t_min = -1 + for t, action, (key, version), numbytes in self.timeline: + if key.device != device: + continue + + # Convert timestamps from ns to us, to match trace events. + if t != -1: + t = int(t / 1000) + + # Save the smallest timestamp to populate pre-existing allocs. + if t_min == -1 or (t < t_min and t > 0): + t_min = t + + # Handle timestep + if len(times) == 0: + times.append(t) + sizes.append([0] + [0 for _ in _CATEGORY_TO_INDEX]) + + elif t != times[-1]: + times.append(t) + sizes.append(sizes[-1].copy()) + + # Handle memory and categories + if action in (Action.PREEXISTING, Action.CREATE): + update(key, version, numbytes) + + elif action == Action.INCREMENT_VERSION: + update(key, version, -numbytes) + update(key, version + 1, numbytes) + + elif action == Action.DESTROY: + update(key, version, -numbytes) + + else: + raise ValueError(f"Unknown action: {action}") + + times = [t_min if t < 0 else t for t in times] + return times, sizes + + def export_memory_timeline(self, path, device_str) -> None: + """Saves the memory timeline as [times, sizes by category] + as a JSON formatted file to the given path for the given + device.""" + times, sizes = self._coalesce_timeline(device_str) + # TODO: Write a faster serialize (orjson not available in CI) + import json + + with open(path, "w") as f: + json.dump([times, sizes], f) + + def export_memory_timeline_raw(self, path, device_str) -> None: + """Saves the memory timeline as raw memory event tuples in the + form of (timestamp, action, numbytes, category) + as a JSON formatted file to the given path for the given + device.""" + device = torch.device(device_str) + raw_events: List[Tuple[int, int, int, int]] = [] + + def get_category_index(key, version): + category = ( + self.categories.get(key, version) + if isinstance(key, TensorKey) + else None + ) + return _CATEGORY_TO_INDEX[category] + + for t, action, (key, version), numbytes in self.timeline: + if key.device != device: + continue + + if action in (Action.PREEXISTING, Action.CREATE): + raw_events.append( + ( + t, + _ACTION_TO_INDEX[action], + numbytes, + get_category_index(key, version), + ) + ) + + elif action == Action.INCREMENT_VERSION: + raw_events.append( + ( + t, + _ACTION_TO_INDEX[action], + -numbytes, + get_category_index(key, version), + ) + ) + raw_events.append( + ( + t, + _ACTION_TO_INDEX[action], + numbytes, + get_category_index(key, version + 1), + ) + ) + + elif action == Action.DESTROY: + raw_events.append( + ( + t, + _ACTION_TO_INDEX[action], + -numbytes, + get_category_index(key, version), + ) + ) + + else: + raise ValueError(f"Unknown action: {action}") + + import json + + with open(path, "w") as f: + json.dump(raw_events, f) + + def export_memory_timeline_html( + self, path, device_str, figsize=(20, 12), title=None + ) -> None: + """Exports the memory timeline as an HTML file which contains + the memory timeline plot embedded as a PNG file.""" + # Check if user has matplotlib installed, return gracefully if not. 
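# A hedged end-to-end sketch of driving these exporters from the public profiler API.
# It assumes a CUDA device and a PyTorch build whose torch.profiler.profile object
# exposes export_memory_timeline; the model, sizes, and file name are hypothetical.
import torch
from torch.profiler import profile, ProfilerActivity

model = torch.nn.Linear(512, 512).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,
    record_shapes=True,
    with_stack=True,
) as prof:
    for _ in range(3):
        loss = model(torch.randn(64, 512, device="cuda")).sum()
        loss.backward()
        opt.step()
        opt.zero_grad()

prof.export_memory_timeline("memory_timeline.html", device="cuda:0")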
+ import importlib.util + + matplotlib_spec = importlib.util.find_spec("matplotlib") + if matplotlib_spec is None: + print( + "export_memory_timeline_html failed because matplotlib was not found." + ) + return + + from base64 import b64encode + from os import remove + from tempfile import NamedTemporaryFile + + import matplotlib.pyplot as plt + import numpy as np + + mt = self._coalesce_timeline(device_str) + times, sizes = np.array(mt[0]), np.array(mt[1]) + # For this timeline, start at 0 to match Chrome traces. + t_min = min(times) + times -= t_min + stacked = np.cumsum(sizes, axis=1) / 1024**3 + device = torch.device(device_str) + max_memory_allocated = torch.cuda.max_memory_allocated(device) + max_memory_reserved = torch.cuda.max_memory_reserved(device) + + # Plot memory timeline as stacked data + fig = plt.figure(figsize=figsize, dpi=80) + axes = fig.gca() + for category, color in _CATEGORY_TO_COLORS.items(): + i = _CATEGORY_TO_INDEX[category] + axes.fill_between( + times / 1e3, stacked[:, i], stacked[:, i + 1], color=color, alpha=0.7 + ) + fig.legend(["Unknown" if i is None else i.name for i in _CATEGORY_TO_COLORS]) + # Usually training steps are in magnitude of ms. + axes.set_xlabel("Time (ms)") + axes.set_ylabel("Memory (GB)") + title = "\n\n".join( + ([title] if title else []) + + [ + f"Max memory allocated: {max_memory_allocated/(1024**3):.2f} GiB \n" + f"Max memory reserved: {max_memory_reserved/(1024**3):.2f} GiB" + ] + ) + axes.set_title(title) + + # Embed the memory timeline image into the HTML file + tmpfile = NamedTemporaryFile("wb", suffix=".png", delete=False) + tmpfile.close() + fig.savefig(tmpfile.name, format="png") + + with open(tmpfile.name, "rb") as tmp: + encoded = b64encode(tmp.read()).decode("utf-8") + html = f""" +GPU Memory Timeline HTML + + + +""" + + with open(path, "w") as f: + f.write(html) + remove(tmpfile.name) diff --git a/lib/python3.10/site-packages/torch/profiler/_pattern_matcher.py b/lib/python3.10/site-packages/torch/profiler/_pattern_matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..a7ec5d05dd68e672157d876e7a91cd2f0caf1f21 --- /dev/null +++ b/lib/python3.10/site-packages/torch/profiler/_pattern_matcher.py @@ -0,0 +1,663 @@ +# mypy: allow-untyped-defs +import json +import math +import os +import re +from typing import Dict, List, Optional, Set + +import torch +import torch.utils.benchmark as benchmark +from torch._C._profiler import ( + _EventType, + _ExtraFields_PyCall, + _ExtraFields_PyCCall, + _ExtraFields_TorchOp, + _ProfilerEvent, +) +from torch.profiler import profile +from torch.profiler._utils import index_of_first_match, traverse_bfs, traverse_dfs + + +class Pattern: + """ + Base class for all patterns, subclass this class and implement match() + to define custom patterns. + + In subclass, define description and skip property. 
+ """ + + def __init__(self, prof: profile, should_benchmark: bool = False): + self.prof = prof + self.should_benchmark = should_benchmark + self.name = "Please specify a name for pattern" + self.description = "Please specify a description for pattern" + self.url = "" + assert prof.profiler is not None and prof.profiler.kineto_results is not None + self.event_tree = prof.profiler.kineto_results.experimental_event_tree() + self.tid_root: Dict[int, List[_ProfilerEvent]] = {} + for event in self.event_tree: + self.tid_root.setdefault(event.start_tid, []).append(event) + + @property + def skip(self): + return False + + def report(self, event: _ProfilerEvent): + msg = ( + f"{self.description}\n[Source Code Location] {source_code_location(event)}" + ) + return msg + + def eventTreeTraversal(self): + """ + Traverse the event tree and yield all events. + Override this method in subclass to customize the traversal. + """ + yield from traverse_dfs(self.event_tree) + + def summary(self, events: List[_ProfilerEvent]): + default_summary = f"{self.name}: {len(events)} events matched." + if self.should_benchmark: + # If benchmark summary is not empty, use it. + return ( + self.benchmark_summary(events) + if hasattr(self, "benchmark") # type: ignore[attr-defined] + else default_summary + ) + return default_summary + + def benchmark_summary(self, events: List[_ProfilerEvent]): + def format_time(time_ns: int): + unit_lst = ["ns", "us", "ms"] + for unit in unit_lst: + if time_ns < 1000: + return f"{time_ns:.2f} {unit}" + time_ns //= 1000 + return f"{time_ns:.2f} s" + + assert hasattr(self, "benchmark"), "Please implement benchmark()" + shapes_factor_map = self.benchmark(events) # type: ignore[attr-defined] + original_time = sum(event.duration_time_ns for event in events) + new_time = sum( + shapes_factor_map[input_shapes(event)] * event.duration_time_ns + for event in events + ) + return ( + f"{self.name}: {len(events)} events matched. " + f"Total Estimated Speedup: {format_time(original_time - new_time)} ({round(original_time/new_time, 2)}X)" + ) + + def match(self, event: _ProfilerEvent): + """ + Return True if the event matches the pattern. + This method should be overriden in subclass. 
+ """ + raise NotImplementedError + + def matched_events(self): + if self.skip: + return [] + matched_events = [] + for event in self.eventTreeTraversal(): + if self.match(event): + matched_events.append(event) + return matched_events + + def root_of(self, event: _ProfilerEvent): + while event.parent: + event = event.parent + return event + + def siblings_of(self, event: _ProfilerEvent): + if event.parent: + children = event.parent.children + else: + children = self.tid_root[event.start_tid] + index = children.index(event) + return children[:index], children[index + 1 :] + + def next_of(self, event: _ProfilerEvent): + _, next_events = self.siblings_of(event) + return next_events[0] if next_events else None + + def prev_of(self, event: _ProfilerEvent): + prev_events, _ = self.siblings_of(event) + return prev_events[-1] if prev_events else None + + def go_up_until(self, event: _ProfilerEvent, predicate): + if not event: + return None + while event.parent and not predicate(event): + event = event.parent + return event + + +# Patterns + + +class NamePattern(Pattern): + def __init__(self, prof: profile, name: str, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.description = f"Matched Name Event: {name}" + self.name = name + + def match(self, event: _ProfilerEvent): + return re.search(self.name, event.name) is not None + + +class ExtraCUDACopyPattern(Pattern): + """ + This pattern identifies if we creates a constant tensor on CPU and immediately moves it to GPU. + example: torch.zeros((100, 100)).to("cuda") + + Pattern: + build-in method |build-in method + ... | aten::to + aten::fill_/aten::zero_ | aten::_to_copy + + Algorithm: + We start at node aten::to, go parent events' previous events, + and check if we have a aten::fill_/aten::zero_ as we keep going down the tree. + We always select the last child in the children list when we go down the tree. + If at any step we failed, it is not a match. + """ + + def __init__(self, prof: profile, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.name = "Extra CUDA Copy Pattern" + self.description = "Filled a CPU tensor and immediately moved it to GPU. Please initialize it on GPU." 
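# The anti-pattern this class flags, next to the preferred form, as plain eager code
# (assumes a CUDA device is available):
import torch

x_slow = torch.zeros((100, 100)).to("cuda")      # filled on CPU, then copied to the GPU
x_fast = torch.zeros((100, 100), device="cuda")  # initialized directly on the GPU
assert torch.equal(x_slow, x_fast)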
+ self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#create-tensors-directly-on-the-target-device" + self.init_ops = { + "aten::fill_", + "aten::zero_", + "aten::normal_", + "aten::uniform_", + } + + @property + def skip(self): + return not self.prof.with_stack or not self.prof.record_shapes + + def match(self, event): + # TODO: We should also check tensor identities + if event.name != "aten::to": + return False + to_event = event + if not event.children: + return False + event = event.children[-1] + if event.name != "aten::_to_copy": + return False + if not event.children: + return False + event = event.children[-1] + if event.name != "aten::copy_": + return False + # aten::copy_ should have the first 2 args dtype the same + dtypes = input_dtypes(event) + if len(dtypes) < 2: + return False + if dtypes[0] is None or dtypes[0] != dtypes[1]: + return False + event = to_event + # Up one level + event = event.parent + if event is None: + return False + # Check if we have a aten::fill_ in previous leaf + event = self.prev_of(event) + if event is None: + return False + while event.children: + event = event.children[-1] + # aten::zero_ is a special optimzation case where fill_ is not called + if event.name in self.init_ops: + return True + return event.name in self.init_ops + # TODO: Check if tensor is reused + + def benchmark(self, events: List[_ProfilerEvent]): + shapes_factor_map = {input_shapes(event): 0.0 for event in events} + for shape in shapes_factor_map: + size = shape[0] + to_timer = benchmark.Timer( + stmt='torch.ones(size).to("cuda")', globals={"size": size} + ) + de_timer = benchmark.Timer( + stmt='torch.ones(size, device="cuda")', globals={"size": size} + ) + to_time = to_timer.timeit(10).mean + de_time = de_timer.timeit(10).mean + shapes_factor_map[shape] = de_time / to_time + return shapes_factor_map + + +class ForLoopIndexingPattern(Pattern): + """ + This pattern identifies if we use a for loop to index a tensor that + can be vectorized. + example: + tensor = torch.empty((100, 100)) + for i in range(100): + tensor[i] = i + + Pattern: + aten::select | ... | aten::select | ... (Repeat) + + Algorithm: + We start at node aten::select, and we check if we can find this alternating patterns. + We also keep a dictionary to avoid duplicate match in the for loop. + """ + + def __init__(self, prof: profile, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.name = "For Loop Indexing Pattern" + self.description = "For loop indexing detected. Vectorization recommended." + self.visited: Set[int] = set() + + def eventTreeTraversal(self): + """ + We need to use BFS traversal order to avoid duplicate match. 
+ """ + yield from traverse_bfs(self.event_tree) + + def match(self, event: _ProfilerEvent): + if event.name != "aten::select": + return False + if event.id in self.visited: + return False + repeat_count = 1 + _, next = self.siblings_of(event) + if len(next) <= 1: + return False + + # Custom event list matching + def same_ops(list1, list2): + if len(list1) != len(list2): + return False + for op1, op2 in zip(list1, list2): + if op1.name != op2.name: + return False + return True + + # Record the ops between two aten::select + next_select_idx = index_of_first_match(next, lambda e: e.name == "aten::select") + if next_select_idx is None: + return False + indexing_ops = [event] + next[:next_select_idx] + next = next[len(indexing_ops) - 1 :] + for i in range(0, len(next), len(indexing_ops)): + if same_ops(indexing_ops, next[i : i + len(indexing_ops)]): + repeat_count += 1 + self.visited.add(next[i].id) + else: + break + return repeat_count >= 10 + + +class FP32MatMulPattern(Pattern): + def __init__(self, prof: profile, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.name = "FP32 MatMul Pattern" + self.description = ( + "You are currently using GPU that supports TF32. " + "Please enable TF32 by setting 'torch.backends.cuda.matmul.allow_tf32 = True'" + ) + self.url = "https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + + @property + def skip(self): + if torch.version.hip is not None: + has_tf32 = False + else: + # Anything less than sm_80 is not Ampere which doesn't support TF32 + has_tf32 = all(int(arch[3:]) >= 80 for arch in torch.cuda.get_arch_list()) + return has_tf32 is False or super().skip or not self.prof.record_shapes + + def match(self, event: _ProfilerEvent): + # If we saw this pattern once, we don't need to match it again + if event.tag != _EventType.TorchOp: + return False + assert isinstance(event.extra_fields, _ExtraFields_TorchOp) + if event.name == "aten::mm": + if event.extra_fields.allow_tf32_cublas is False: + return True + return False + + def report(self, event: _ProfilerEvent): + return self.description + + def benchmark(self, events: List[_ProfilerEvent]): + shapes_factor_map = {input_shapes(event): 0.0 for event in events} + for shape in shapes_factor_map: + matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float32) + matrixB = torch.randn(shape[1], device="cuda", dtype=torch.float32) + fp32_timer = benchmark.Timer( + stmt="torch.mm(matrixA, matrixB)", + globals={"matrixA": matrixA, "matrixB": matrixB}, + ) + tf32_timer = benchmark.Timer( + stmt="torch.mm(matrixA, matrixB)", + setup="torch.backends.cuda.matmul.allow_tf32 = True", + globals={"matrixA": matrixA, "matrixB": matrixB}, + ) + torch.backends.cuda.matmul.allow_tf32 = False + fp32_time = fp32_timer.timeit(10).mean + tf32_time = tf32_timer.timeit(10).mean + shapes_factor_map[shape] = tf32_time / fp32_time + return shapes_factor_map + + +class OptimizerSingleTensorPattern(Pattern): + """ + This pattern identifies if we are using the single-tensor version of an optimizer. + example: + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + By adding foreach=True to enable multi-tensor optimizer, we can gain speedup when + the kernels are relatively small. 
+ + Pattern: + XXXXX: _single_tensor_ + + Algorithm: + String match + """ + + def __init__(self, prof: profile, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.name = "Optimizer Single Tensor Pattern" + self.optimizers_with_foreach = ["adam", "sgd", "adamw"] + self.description = ( + "Detected optimizer running with single tensor implementation. " + "Please enable multi tensor implementation by passing 'foreach=True' to the optimizer." + ) + self.url = "" + + def match(self, event: _ProfilerEvent): + for optimizer in self.optimizers_with_foreach: + if event.name.endswith(f"_single_tensor_{optimizer}"): + return True + return False + + +class SynchronizedDataLoaderPattern(Pattern): + """ + This pattern identifies if we are using num_workers=0 in DataLoader. + example: + torch.utils.data.DataLoader(dataset, batch_size=batch_size) + Add num_workers=N to the arguments. N depends on system configuration. + + Pattern: + dataloader.py(...): __iter__ + dataloader.py(...): _get_iterator + NOT dataloader.py(...): check_worker_number_rationality + + Algorithm: + If we don't see a check_worker_number_rationality call in the dataloader __iter__, + it is not an asynchronous dataloader. + + """ + + def __init__(self, prof: profile, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.name = "Synchronized DataLoader Pattern" + self.description = ( + "Detected DataLoader running with synchronized implementation. " + "Please enable asynchronous dataloading by setting num_workers > 0 when initializing DataLoader." + ) + self.url = ( + "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html" + "#enable-async-data-loading-and-augmentation" + ) + + def match(self, event: _ProfilerEvent): + def is_dataloader_function(name: str, function_name: str): + return name.startswith( + os.path.join("torch", "utils", "data", "dataloader.py") + ) and name.endswith(function_name) + + # TODO: fixme! Due to lifetime issues of the function name, this field might + # actually point to an already freed string when the event is a PyCall. + # Just silently skip this to unblock testing. + try: + event.name + except UnicodeDecodeError: + return False + + if not is_dataloader_function(event.name, "__iter__"): + return False + if not event.children: + return False + event = event.children[0] + if not is_dataloader_function(event.name, "_get_iterator"): + return False + if not event.children: + return False + event = event.children[0] + return not is_dataloader_function(event.name, "check_worker_number_rationality") + # TODO: We should also check if the loader is a bottleneck. + + +class GradNotSetToNonePattern(Pattern): + """ + This pattern identifies if we are not setting grad to None in zero_grad. + example: + optimizer.zero_grad() + By setting set_to_none=True, we can gain a speedup + + Pattern: + XXXXX: _zero_grad + NOT aten::zeros + aten::zero_ + + aten::zero_ is called on each parameter in the model. + We also want to make sure it is not called by aten::zeros. + + Algorithm: + String match + """ + + def __init__(self, prof: profile, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.name = "Gradient Set To Zero Instead of None Pattern" + self.description = ( + "Detected gradient set to zero instead of None. " + "Please add 'set_to_none=True' when calling zero_grad()."
+ ) + self.url = ( + "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html" + "#disable-gradient-calculation-for-validation-or-inference" + ) + + def match(self, event: _ProfilerEvent): + if not event.name.endswith(": zero_grad"): + return False + if not event.children: + return False + + for sub_event in traverse_dfs(event.children): + if ( + sub_event.name == "aten::zero_" + and sub_event.parent.name != "aten::zeros" + ): + return True + # TODO: We should also check if the optimizer's numerical behavior will change. + return False + + +class Conv2dBiasFollowedByBatchNorm2dPattern(Pattern): + """ + This pattern identifies if we are enabling bias in Conv2d which is followed by BatchNorm2d. + Bias doesn't do anything when followed by batchnorm. + Pattern: + nn.Module: Conv2d | nn.Module: BatchNorm2d + ... + aten::conv2d AND dtype of third argument is not null + The third argument is the bias + Algorithm: + String match + """ + + def __init__(self, prof: profile, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.name = "Enabling Bias in Conv2d Followed By BatchNorm Pattern" + self.description = "Detected bias enabled in Conv2d that is followed by BatchNorm2d. Please set 'bias=False' in Conv2d." + self.url = ( + "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html" + "#disable-bias-for-convolutions-directly-followed-by-a-batch-norm" + ) + + @property + def skip(self): + return self.prof.record_shapes is False or super().skip + + def match(self, event: _ProfilerEvent): + if event.name != "aten::conv2d": + return False + if len(input_dtypes(event)) < 3 or input_dtypes(event)[2] is None: + return False + # This means bias=True + event = self.go_up_until( + event, lambda e: e.name.startswith("nn.Module: Conv2d") + ) + if not event: + return False + event = self.next_of(event) + if not event: + return False + return event.name.startswith("nn.Module: BatchNorm2d") + + +class MatMulDimInFP16Pattern(Pattern): + def __init__(self, prof: profile, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.name = "Matrix Multiplication Dimension Not Aligned Pattern" + self.description = "Detected matmul with dimension not aligned. Please use matmul with aligned dimension." 
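+ # Illustrative note (assumption, not from the original source): on tensor cores, fp16/bf16 matmuls + # tend to be fastest when the last two dimensions of each operand are multiples of 8.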
+ self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#use-mixed-precision-and-amp" + + @property + def skip(self): + return not self.prof.with_stack or not self.prof.record_shapes + + def match(self, event: _ProfilerEvent): + def mutiple_of(shapes, multiple): + return all(dim % multiple == 0 for shape in shapes for dim in shape[-2:]) + + if event.name not in ("aten::mm", "aten::bmm", "aten::addmm"): + return False + if not input_dtypes(event): + return False + arg_dtype = input_dtypes(event)[0] + if arg_dtype in (torch.bfloat16, torch.half) and not mutiple_of( + input_shapes(event), 8 + ): + return True + return False + + def benchmark(self, events: List[_ProfilerEvent]): + def closest_multiple(shapes, multiple): + return [multiple * math.ceil(shape / multiple) for shape in shapes] + + shapes_factor_map = {input_shapes(event): 0.0 for event in events} + for shape in shapes_factor_map: + matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float16) + matrixB = torch.randn(shape[1], device="cuda", dtype=torch.float16) + not_aligned_dim_timer = benchmark.Timer( + stmt="torch.mm(matrixA, matrixB)", + globals={"matrixA": matrixA, "matrixB": matrixB}, + ) + matrixA = torch.randn( + closest_multiple(shape[0], 8), device="cuda", dtype=torch.float16 + ) + matrixB = torch.randn( + closest_multiple(shape[1], 8), device="cuda", dtype=torch.float16 + ) + aligned_dim_timer = benchmark.Timer( + stmt="torch.mm(matrixA, matrixB)", + globals={"matrixA": matrixA, "matrixB": matrixB}, + ) + not_aligned_dim_time = not_aligned_dim_timer.timeit(10).mean + aligned_dim_time = aligned_dim_timer.timeit(10).mean + shapes_factor_map[shape] = aligned_dim_time / not_aligned_dim_time + return shapes_factor_map + + +def source_code_location(event: Optional[_ProfilerEvent]): + while event: + if event.tag == _EventType.PyCall or event.tag == _EventType.PyCCall: + assert isinstance( + event.extra_fields, (_ExtraFields_PyCall, _ExtraFields_PyCCall) + ) + if not event.extra_fields.caller.file_name.startswith("torch" + os.sep): + return f"{event.extra_fields.caller.file_name}:{event.extra_fields.caller.line_number}" + event = event.parent + return "No source code location found" + + +def input_shapes(event: _ProfilerEvent): + assert isinstance(event.extra_fields, _ExtraFields_TorchOp) + return tuple(tuple(getattr(i, "sizes", ())) for i in event.extra_fields.inputs) + + +def input_dtypes(event: _ProfilerEvent): + assert isinstance(event.extra_fields, _ExtraFields_TorchOp) + return tuple(getattr(i, "dtype", None) for i in event.extra_fields.inputs) + + +def report_all_anti_patterns( + prof, + should_benchmark: bool = False, + print_enable: bool = True, + json_report_dir: Optional[str] = None, +): + report_dict: Dict = {} + anti_patterns = [ + ExtraCUDACopyPattern(prof, should_benchmark), + # ForLoopIndexingPattern(prof, should_benchmark), + FP32MatMulPattern(prof, should_benchmark), + OptimizerSingleTensorPattern(prof, should_benchmark), + SynchronizedDataLoaderPattern(prof, should_benchmark), + GradNotSetToNonePattern(prof, should_benchmark), + Conv2dBiasFollowedByBatchNorm2dPattern(prof, should_benchmark), + MatMulDimInFP16Pattern(prof, should_benchmark), + ] + reported = set() + summaries = [] + message_list = [f"{'-'*40}TorchTidy Report{'-'*40}"] + message_list.append("Matched Events:") + + for anti_pattern in anti_patterns: + matched_events = anti_pattern.matched_events() + if not matched_events: + continue + summaries.append(anti_pattern.summary(matched_events)) + for event in matched_events: + 
report_msg = anti_pattern.report(event) + if report_msg not in reported: + message_list.append(report_msg) + reported.add(report_msg) + src_location, line_no = source_code_location(event).split(":") + report_dict.setdefault(src_location, []).append( + { + "line_number": int(line_no), + "name": anti_pattern.name, + "url": anti_pattern.url, + "message": anti_pattern.description, + } + ) + + if json_report_dir is not None: + json_report_path = os.path.join(json_report_dir, "torchtidy_report.json") + if os.path.exists(json_report_path): + with open(json_report_path) as f: + exisiting_report = json.load(f) + exisiting_report.update(report_dict) + report_dict = exisiting_report + with open(json_report_path, "w") as f: + json.dump(report_dict, f, indent=4) + + message_list.append("Summary:") + message_list += summaries + message_list.append(f"{'-'*40}TorchTidy Report{'-'*40}") + if print_enable: + print("\n".join(message_list)) diff --git a/lib/python3.10/site-packages/torch/profiler/_utils.py b/lib/python3.10/site-packages/torch/profiler/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..20dfeb80adeb6f5ef583571f20d18d0d4867e7c5 --- /dev/null +++ b/lib/python3.10/site-packages/torch/profiler/_utils.py @@ -0,0 +1,385 @@ +# mypy: allow-untyped-defs +import functools +import operator +import re +from collections import deque +from dataclasses import dataclass +from typing import Dict, List, TYPE_CHECKING + +from torch.autograd.profiler import profile +from torch.profiler import DeviceType + + +if TYPE_CHECKING: + from torch.autograd import _KinetoEvent + + +def _traverse(tree, next_fn, children_fn=lambda x: x.children, reverse: bool = False): + order = reversed if reverse else lambda x: x + remaining = deque(order(tree)) + while remaining: + curr_event = next_fn(remaining) + yield curr_event + for child_event in order(children_fn(curr_event)): + remaining.append(child_event) + + +traverse_dfs = functools.partial(_traverse, next_fn=lambda x: x.pop(), reverse=True) +traverse_bfs = functools.partial( + _traverse, next_fn=lambda x: x.popleft(), reverse=False +) + + +@dataclass +class EventMetrics: + duration_time_ns: int = 0 + self_time_ns: int = 0 + idle_time_ns: int = 0 + queue_depth: int = 0 + + @property + def fraction_idle_time(self): + if self.duration_time_ns == 0: + return 0.0 + return self.idle_time_ns / self.duration_time_ns + + +@dataclass +class Interval: + start: int + end: int + queue_depth: int = 0 + + +class EventKey: + def __init__(self, event): + self.event = event + + def __hash__(self): + return hash(self.event.id) + + def __eq__(self, other): + return self.event.id == other.event.id + + def __repr__(self): + return f"{self.event.name}" + + def intervals_overlap(self, intervals: List[Interval]): + overlap_time = 0 + intervals = sorted(intervals, key=lambda x: x.start) + + if intervals: + overlap_start = max(self.event.start_time_ns, intervals[0].start) + overlap_end = min(self.event.end_time_ns, intervals[0].end) + + if overlap_start < overlap_end: + overlap_time += overlap_end - overlap_start + + i, j = 0, 1 + while j < len(intervals): + prev_interval = intervals[i] + curr_interval = intervals[j] + j += 1 + if prev_interval.end > curr_interval.start: + # Completely subsumed by previous interval + if prev_interval.end > curr_interval.end: + j += 1 + continue + else: + curr_interval.start = prev_interval.end + i = j + + overlap_start = max(self.event.start_time_ns, curr_interval.start) + overlap_end = min(self.event.end_time_ns, curr_interval.end) + if 
overlap_start < overlap_end: + overlap_time += overlap_end - overlap_start + + return overlap_time + + +class BasicEvaluation: + def __init__(self, prof: profile): + self.profile = prof + self.metrics: Dict[EventKey, EventMetrics] = {} + self.compute_self_time() + self.event_keys = sorted( + (e for e in self.metrics.keys()), key=lambda x: x.event.start_time_ns + ) + self.events = [e.event for e in self.event_keys] + self.cuda_events: List[_KinetoEvent] = [] + self.queue_depth_list = self.compute_queue_depth() + self.compute_idle_time() + + def compute_self_time(self): + """ + Computes event's self time(total time - time in child ops). + """ + assert self.profile.kineto_results is not None + stack = deque(self.profile.kineto_results.experimental_event_tree()) + + # standard iterating dfs + while stack: + curr_event = stack.pop() + self_time = curr_event.duration_time_ns + for child_event in curr_event.children: + self_time -= child_event.duration_time_ns + stack.append(child_event) + assert ( + EventKey(curr_event) not in self.metrics + ), f"Duplicate id: {curr_event.id}, {curr_event.name}" + self.metrics[EventKey(curr_event)] = EventMetrics(self_time_ns=self_time) + self.metrics[ + EventKey(curr_event) + ].duration_time_ns = curr_event.duration_time_ns + + def compute_queue_depth(self): + """ + Computes queue_depth at each event. This will calculate the queue depth data for + All the events in the tree. + This will return a list of Interval of queue depth data of cuda launch and kernels. + """ + assert self.profile.kineto_results is not None + cuda_event_list = self.profile.kineto_results.events() + + def is_cuda_launch_kernel(e): + # TODO: find a better way to identify cudaLaunchKernel + return e.name == "cudaLaunchKernel" + + def is_cuda_kernel(e): + # TODO: find a better way to identify CUDA Kernel + return e.device_type() == DeviceType.CUDA and "mem" not in e.name.lower() + + cuda_launch_events = sorted( + (e for e in cuda_event_list if is_cuda_launch_kernel(e)), + key=lambda x: x.start_ns(), + ) + cuda_kernel_events = sorted( + (e for e in cuda_event_list if is_cuda_kernel(e)), + key=lambda x: x.start_ns(), + ) + + self.cuda_events = sorted( + cuda_launch_events + cuda_kernel_events, key=lambda x: x.start_ns() + ) + + kernel_mapping: Dict[_KinetoEvent, int] = {} + last_mapped_kernel = 0 + for cuda_launch_event in cuda_launch_events: + index = index_of_first_match( + cuda_kernel_events, + lambda x: x.linked_correlation_id() + == cuda_launch_event.linked_correlation_id(), + start=last_mapped_kernel, + ) + kernel_mapping[cuda_launch_event] = index + last_mapped_kernel = index if index is not None else last_mapped_kernel + + current_kernel_index = 0 + spawned_kernel_index = -1 + + all_events = cuda_launch_events + cuda_kernel_events + self.events + + def new_old_event_comparator(event): + if hasattr(event, "start_us"): + return event.start_us() * 1000 + if hasattr(event, "start_ns"): + return event.start_ns() + if hasattr(event, "start_time_ns"): + return event.start_time_ns + raise Exception("Unknown Event Type") # noqa: TRY002 + + queue_depth_list: List[Interval] = [] + all_events.sort(key=new_old_event_comparator) + for event in all_events: + # Find latest cuda kernel event + if hasattr(event, "start_us"): + start_time = event.start_us() * 1000 + end_time = (event.start_us() + event.duration_us()) * 1000 + # Find current spawned cuda kernel event + if event in kernel_mapping and kernel_mapping[event] is not None: + spawned_kernel_index = kernel_mapping[event] + if hasattr(event, 
"start_ns"): + start_time = event.start_ns() + end_time = event.start_ns() + event.duration_ns() + # Find current spawned cuda kernel event + if event in kernel_mapping and kernel_mapping[event] is not None: + spawned_kernel_index = kernel_mapping[event] + elif hasattr(event, "start_time_ns"): + start_time = event.start_time_ns # type: ignore[attr-defined] + end_time = event.end_time_ns # type: ignore[attr-defined] + + while ( + current_kernel_index < len(cuda_kernel_events) + and (cuda_kernel_events[current_kernel_index].start_ns()) + <= start_time # type: ignore[possibly-undefined] + ): + current_kernel_index += 1 + current_queue_depth = spawned_kernel_index - current_kernel_index + 1 + current_queue_depth = max(current_queue_depth, 0) + + if hasattr(event, "start_us") or hasattr(event, "start_ns"): + queue_depth_list.append( + Interval(start_time, end_time, current_queue_depth) # type: ignore[possibly-undefined] + ) + elif hasattr(event, "start_time_ns"): + self.metrics[EventKey(event)].queue_depth = current_queue_depth + + return queue_depth_list + + def compute_idle_time(self): + """ + Computes idle time of the profile. + """ + # Based on queue_depth_list, we can calculate idle time for all the events + idle = False + idle_start = 0 + idle_intervals: List[Interval] = [] + if self.queue_depth_list and self.events: + idle_intervals += [ + Interval(self.events[0].start_time_ns, self.queue_depth_list[0].start), + Interval(self.queue_depth_list[-1].end, self.events[-1].end_time_ns), + ] + + for data_point in self.queue_depth_list: + if data_point.queue_depth == 0 and not idle: + idle_start = data_point.end + idle = True + if data_point.queue_depth > 0 and idle: + idle_intervals.append(Interval(idle_start, data_point.start)) + idle = False + + event_list = [e.event for e in self.metrics.keys()] + for event in event_list: + self.metrics[EventKey(event)].idle_time_ns = EventKey( + event + ).intervals_overlap(idle_intervals) + + def rank_events(self, length): + """ + Filter and Rank the events based on some heuristics: + 1) Events that are in the falling phase of the queue depth. + 2) Events that have a high idle_time, self_time difference. + + Parameters: + length: The number of events to return. 
+ """ + + # Find the interval when qd is falling to 0 + import torch + + queue_depth_list = list(reversed(self.queue_depth_list)) + qd_values = [e.queue_depth for e in queue_depth_list] + + bottom_threashold = 0 + top_threashold = 4 + decrease_interval = [] + i = 0 + while i < len(qd_values): + if qd_values[i] > bottom_threashold: + i += 1 + continue + for j in range(i + 1, len(qd_values)): + # Find next zero and if the max value between them exceeds + # the threshold, then we have a falling interval + next_minimum_idx = index_of_first_match( + qd_values, lambda x: x <= bottom_threashold, start=j + ) + peak_idx = argmax(qd_values, start=j, end=next_minimum_idx) + + # if is a valid peak, we add to list and continue + if peak_idx is not None and qd_values[peak_idx] >= top_threashold: + decrease_interval.append( + Interval( + queue_depth_list[peak_idx].start, queue_depth_list[i].start + ) + ) + i = next_minimum_idx if next_minimum_idx is not None else i + break + i += 1 + # Filter out events that are not in the decrease interval + event_list = [ + event + for event in self.metrics.keys() + if event.intervals_overlap(decrease_interval) + ] + if event_list: + self_time = torch.tensor( + [self.metrics[event].self_time_ns for event in event_list], + dtype=torch.float32, + ) + idle_time = torch.tensor( + [self.metrics[event].fraction_idle_time for event in event_list], + dtype=torch.float32, + ) + normalized_gain = (idle_time - torch.mean(idle_time)) / torch.std(idle_time) + normalized_self = (self_time - torch.mean(self_time)) / torch.std(self_time) + heuristic_score_list = normalized_gain + 0.6 * normalized_self + + # Sort events by heuristic + event_list = [ + event + for _, event in sorted( + zip(heuristic_score_list, event_list), + key=operator.itemgetter(0), + reverse=True, + ) + ] + event_list = event_list[:length] + return event_list + + def get_optimizable_events(self, length: int = 1, print_enable: bool = True): + event_list = self.rank_events(length) + if not print_enable: + return event_list + output = "Optimizable events:\n" if event_list else "No events to optimize\n" + + output += "\n".join( + [ + f"""{'-'*80} +Event: {event} +Source code location: {source_code_location(event.event)} +Percentage idle time: {self.metrics[event].fraction_idle_time * 100:.2f}% +{'-'*80}""" + for event in event_list + ] + ) + if print_enable: + print(output) + return event_list + + +def index_of_first_match(seq, predicate, start=0, end=None): + if end is None or end >= len(seq): + end = len(seq) + for i in range(start, end): + if predicate(seq[i]): + return i + return None + + +def argmax(seq, key=lambda x: x, start=0, end=None): + seq = seq[start:end] + if len(seq) == 0: + return None + return seq.index(max(seq, key=key)) + start + + +def source_code_location(event): + while event is not None: + match = re.search(r"\.py\(.*\)", event.name) + if match is None: + event = event.parent + continue + return event.name + return "No source code location found" + + +# Provide an OSS workaround for cudagraphs + CUPTI issue +# https://github.com/pytorch/pytorch/issues/75504 +# TODO(dberard) - deprecate / remove workaround for CUDA >= 12, when +# we stop supporting older CUDA versions. 
+def _init_for_cuda_graphs(): + from torch.autograd.profiler import profile + + with profile(): + pass diff --git a/lib/python3.10/site-packages/torch/profiler/itt.py b/lib/python3.10/site-packages/torch/profiler/itt.py new file mode 100644 index 0000000000000000000000000000000000000000..9d4bda2b3420bdb367033aba2c0ef426bdd2a59a --- /dev/null +++ b/lib/python3.10/site-packages/torch/profiler/itt.py @@ -0,0 +1,80 @@ +# mypy: allow-untyped-defs +from contextlib import contextmanager + + +try: + from torch._C import _itt +except ImportError: + + class _ITTStub: + @staticmethod + def _fail(*args, **kwargs): + raise RuntimeError( + "ITT functions not installed. Are you sure you have a ITT build?" + ) + + @staticmethod + def is_available(): + return False + + rangePush = _fail + rangePop = _fail + mark = _fail + + _itt = _ITTStub() # type: ignore[assignment] + + +__all__ = ["is_available", "range_push", "range_pop", "mark", "range"] + + +def is_available(): + """ + Check if ITT feature is available or not + """ + return _itt.is_available() + + +def range_push(msg): + """ + Pushes a range onto a stack of nested range span. Returns zero-based + depth of the range that is started. + + Arguments: + msg (str): ASCII message to associate with range + """ + return _itt.rangePush(msg) + + +def range_pop(): + """ + Pops a range off of a stack of nested range spans. Returns the + zero-based depth of the range that is ended. + """ + return _itt.rangePop() + + +def mark(msg): + """ + Describe an instantaneous event that occurred at some point. + + Arguments: + msg (str): ASCII message to associate with the event. + """ + return _itt.mark(msg) + + +@contextmanager +def range(msg, *args, **kwargs): + """ + Context manager / decorator that pushes an ITT range at the beginning + of its scope, and pops it at the end. If extra arguments are given, + they are passed as arguments to msg.format(). + + Args: + msg (str): message to associate with the range + """ + range_push(msg.format(*args, **kwargs)) + try: + yield + finally: + range_pop() diff --git a/lib/python3.10/site-packages/torch/profiler/profiler.py b/lib/python3.10/site-packages/torch/profiler/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..939ae73a99afb4d2202aafcdc1209738e643f3df --- /dev/null +++ b/lib/python3.10/site-packages/torch/profiler/profiler.py @@ -0,0 +1,935 @@ +# mypy: allow-untyped-defs +import gzip +import json +import os +import shutil +import tempfile +from abc import ABC, abstractmethod +from enum import Enum +from functools import partial +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +from typing_extensions import Self +from warnings import warn + +import torch +import torch.autograd.profiler as prof +from torch._C import _get_privateuse1_backend_name +from torch._C._profiler import ( + _add_execution_trace_observer, + _disable_execution_trace_observer, + _enable_execution_trace_observer, + _ExperimentalConfig, + _remove_execution_trace_observer, +) +from torch.autograd import kineto_available, ProfilerActivity +from torch.profiler._memory_profiler import MemoryProfile, MemoryProfileTimeline + + +__all__ = [ + "supported_activities", + "ProfilerAction", + "schedule", + "tensorboard_trace_handler", + "profile", + "ExecutionTraceObserver", +] +PROFILER_STEP_NAME = "ProfilerStep" + + +def supported_activities(): + """ + Returns a set of supported profiler tracing activities. + + Note: profiler uses CUPTI library to trace on-device CUDA kernels. 
+ In case when CUDA is enabled but CUPTI is not available, passing + ``ProfilerActivity.CUDA`` to profiler results in using the legacy CUDA + profiling code (same as in the legacy ``torch.autograd.profiler``). + This, in turn, results in including CUDA time in the profiler table output, + but not in the JSON trace. + """ + return torch.autograd._supported_activities() + + +class _ITraceObserver(ABC): + """Abstract interface for a Trace observer. + This requires 3 methods: start, stop and cleanup""" + + @abstractmethod + def start(self): + pass + + @abstractmethod + def stop(self): + pass + + @abstractmethod + def cleanup(self): + pass + + +class _KinetoProfile: + """Low-level profiler that wraps the autograd profiler. + + Args: + activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values: + ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``, + ``torch.profiler.ProfilerActivity.XPU``. + Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA + or (when available) ProfilerActivity.XPU. + record_shapes (bool): save information about operator's input shapes. + profile_memory (bool): track tensor memory allocation/deallocation (see ``export_memory_timeline`` + for more details). + with_stack (bool): record source information (file and line number) for the ops. + with_flops (bool): use formula to estimate the FLOPS of specific operators + (matrix multiplication and 2D convolution). + with_modules (bool): record module hierarchy (including function names) + corresponding to the callstack of the op. e.g. If module A's forward calls + module B's forward, which contains an aten::add op, + then aten::add's module hierarchy is A.B + Note that this support exists, at the moment, only for TorchScript models + and not eager mode models. + experimental_config (_ExperimentalConfig) : A set of experimental options + used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed. + execution_trace_observer (ExecutionTraceObserver) : A PyTorch Execution Trace Observer object. + `PyTorch Execution Traces `__ offer a graph based + representation of AI/ML workloads and enable replay benchmarks, simulators, and emulators. + When this argument is included, the observer start() and stop() will be called for the + same time window as the PyTorch profiler. + acc_events (bool): Enable the accumulation of FunctionEvents across multiple profiling cycles + + + .. note:: + This API is experimental and subject to change in the future. + + Enabling shape and stack tracing results in additional overhead. + When record_shapes=True is specified, profiler will temporarily hold references to the tensors; + that may further prevent certain optimizations that depend on the reference count and introduce + extra tensor copies.
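+ + Example (an illustrative sketch, not part of the original docstring; assumes the class is used directly from this module): + + .. code-block:: python + + p = _KinetoProfile(activities=[torch.profiler.ProfilerActivity.CPU]) + p.start() + run_model_step() # assumed user code + p.stop() + print(p.key_averages().table(sort_by="self_cpu_time_total"))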
+ """ + + def __init__( + self, + *, + activities: Optional[Iterable[ProfilerActivity]] = None, + record_shapes: bool = False, + profile_memory: bool = False, + with_stack: bool = False, + with_flops: bool = False, + with_modules: bool = False, + experimental_config: Optional[_ExperimentalConfig] = None, + execution_trace_observer: Optional[_ITraceObserver] = None, + acc_events: bool = False, + ): + self.activities = set(activities) if activities else supported_activities() + self.record_shapes = record_shapes + self.with_flops = with_flops + self.profile_memory = profile_memory + self.with_stack = with_stack + self.with_modules = with_modules + self.experimental_config = experimental_config + self.execution_trace_observer = execution_trace_observer + self.acc_events = acc_events + self.profiler: Optional[prof.profile] = None + self.mem_tl: Optional[MemoryProfileTimeline] = None + self.use_device = None + if ProfilerActivity.CUDA in self.activities: + self.use_device = "cuda" + elif ProfilerActivity.XPU in self.activities: + self.use_device = "xpu" + elif ProfilerActivity.MTIA in self.activities: + self.use_device = "mtia" + elif ProfilerActivity.PrivateUse1 in self.activities: + self.use_device = _get_privateuse1_backend_name() + + # user-defined metadata to be amended to the trace + self.preset_metadata: Dict[str, str] = {} + + def start(self): + self.prepare_trace() + self.start_trace() + + def stop(self): + self.stop_trace() + + def prepare_trace(self): + if (self.profiler is None) or (not self.acc_events): + self.profiler = prof.profile( + use_cpu=(ProfilerActivity.CPU in self.activities), + use_device=self.use_device, + record_shapes=self.record_shapes, + with_flops=self.with_flops, + profile_memory=self.profile_memory, + with_stack=self.with_stack, + with_modules=self.with_modules, + use_kineto=True, + experimental_config=self.experimental_config, + acc_events=self.acc_events, + ) + self.profiler._prepare_trace() + + def start_trace(self): + if self.execution_trace_observer: + self.execution_trace_observer.start() + assert self.profiler is not None + self.profiler._start_trace() + + if self.profile_memory: + self.add_metadata_json("profile_memory", "1") + if self.with_stack: + self.add_metadata_json("with_stack", "1") + if self.record_shapes: + self.add_metadata_json("record_shapes", "1") + if self.with_modules: + self.add_metadata_json("with_modules", "1") + if self.with_flops: + self.add_metadata_json("with_flops", "1") + + if kineto_available(): + dist_info = self._get_distributed_info() + if dist_info: + self.add_metadata_json("distributedInfo", json.dumps(dist_info)) + + if hasattr(torch, "_inductor"): + import torch._inductor.config as inductor_config + + if inductor_config.triton.cudagraphs: + os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1" + self.add_metadata_json("DISABLE_CUPTI_LAZY_REINIT", "1") + # FIXME: CUDA Graph does not work well with CUPTI teardown. + # 1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11) + # 2) crashes on 2nd non-lazy CUPTI re-init after teardown (CUDA 12) + # Workaround: turn off CUPTI teardown when using CUDA Graphs. 
+ os.environ["TEARDOWN_CUPTI"] = "0" + + # Insert the preset user metadata to the trace + for k, v in self.preset_metadata.items(): + self.add_metadata_json(k, v) + + def stop_trace(self): + if self.execution_trace_observer: + self.execution_trace_observer.stop() + assert self.profiler is not None + self.profiler.__exit__(None, None, None) + + def export_chrome_trace(self, path: str): + """ + Exports the collected trace in Chrome JSON format. If kineto is enabled, only + last cycle in schedule is exported. + """ + assert self.profiler + if path.endswith(".gz"): + fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False) + fp.close() + retvalue = self.profiler.export_chrome_trace(fp.name) + with open(fp.name) as fin: + with gzip.open(path, "wt") as fout: + fout.writelines(fin) + os.remove(fp.name) + return retvalue + else: + return self.profiler.export_chrome_trace(path) + + def export_stacks(self, path: str, metric: str = "self_cpu_time_total"): + """Save stack traces to a file + + Args: + path (str): save stacks file to this location; + metric (str): metric to use: "self_cpu_time_total" or "self_cuda_time_total" + """ + assert self.profiler + return self.profiler.export_stacks(path, metric) + + def toggle_collection_dynamic( + self, enable: bool, activities: Iterable[ProfilerActivity] + ): + """Toggle collection of activities on/off at any point of collection. Currently supports toggling Torch Ops + (CPU) and CUDA activity supported in Kineto + + Args: + activities (iterable): list of activity groups to use in profiling, supported values: + ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA`` + Examples: + + .. code-block:: python + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ] + ) as p: + code_to_profile_0() + // turn off collection of all CUDA activity + p.toggle_collection_dynamic(False, [torch.profiler.ProfilerActivity.CUDA]) + code_to_profile_1() + // turn on collection of all CUDA activity + p.toggle_collection_dynamic(True, [torch.profiler.ProfilerActivity.CUDA]) + code_to_profile_2() + print(p.key_averages().table( + sort_by="self_cuda_time_total", row_limit=-1)) + """ + if not self.profiler: + return + self.profiler.toggle_collection_dynamic(enable, activities) + + def key_averages( + self, group_by_input_shape: bool = False, group_by_stack_n: int = 0 + ): + """Averages events, grouping them by operator name and (optionally) input shapes and + stack. + + .. note:: + To use shape/stack functionality make sure to set record_shapes/with_stack + when creating profiler context manager. 
+ """ + assert self.profiler + return self.profiler.key_averages(group_by_input_shape, group_by_stack_n) + + def events(self): + """ + Returns the list of unaggregated profiler events, + to be used in the trace callback or after the profiling is finished + """ + assert self.profiler + return self.profiler.function_events + + def add_metadata(self, key: str, value: str): + """ + Adds a user defined metadata with a string key and a string value + into the trace file + """ + wrapped_value = '"' + value.replace('"', '\\"') + '"' + torch.autograd._add_metadata_json(key, wrapped_value) + + def add_metadata_json(self, key: str, value: str): + """ + Adds a user defined metadata with a string key and a valid json value + into the trace file + """ + torch.autograd._add_metadata_json(key, value) + + def preset_metadata_json(self, key: str, value: str): + """ + Preset a user defined metadata when the profiler is not started + and added into the trace file later. + Metadata is in the format of a string key and a valid json value + """ + self.preset_metadata[key] = value + + def _get_distributed_info(self): + import torch.distributed as dist + + if not dist.is_available() or not dist.is_initialized(): + return None + + backend = dist.get_backend() + dist_info = { + "backend": backend, + "rank": dist.get_rank(), + "world_size": dist.get_world_size(), + "pg_count": dist.get_pg_count(), + "pg_config": dist.distributed_c10d._get_all_pg_configs(), + } + if backend == "nccl": + nccl_version = torch.cuda.nccl.version() + dist_info["nccl_version"] = ".".join(str(v) for v in nccl_version) + return dist_info + + def _memory_profile(self) -> MemoryProfile: + required = ("record_shapes", "profile_memory", "with_stack") + missing = [f"{i}=True" for i in required if not getattr(self, i)] + if missing: + raise ValueError(f"{', '.join(missing)} required for memory profiling.") + + assert self.profiler is not None and self.profiler.kineto_results is not None + return MemoryProfile(self.profiler.kineto_results) + + def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None: + """Export memory event information from the profiler collected + tree for a given device, and export a timeline plot. There are 3 + exportable files using ``export_memory_timeline``, each controlled by the + ``path``'s suffix. + + - For an HTML compatible plot, use the suffix ``.html``, and a memory timeline + plot will be embedded as a PNG file in the HTML file. + + - For plot points consisting of ``[times, [sizes by category]]``, where + ``times`` are timestamps and ``sizes`` are memory usage for each category. + The memory timeline plot will be saved a JSON (``.json``) or gzipped JSON + (``.json.gz``) depending on the suffix. + + - For raw memory points, use the suffix ``.raw.json.gz``. Each raw memory + event will consist of ``(timestamp, action, numbytes, category)``, where + ``action`` is one of ``[PREEXISTING, CREATE, INCREMENT_VERSION, DESTROY]``, + and ``category`` is one of the enums from + ``torch.profiler._memory_profiler.Category``. + + Output: Memory timeline written as gzipped JSON, JSON, or HTML. + """ + # Default to device 0, if unset. Fallback on cpu. 
+ if device is None and self.use_device and self.use_device != "cuda": + device = self.use_device + ":0" + + if device is None: + device = "cuda:0" if torch.cuda.is_available() else "cpu" + + # Construct the memory timeline plot data + self.mem_tl = MemoryProfileTimeline(self._memory_profile()) + + # Depending on the file suffix, save the data as json.gz or json. + # For html, we can embed the image into an HTML file. + if path.endswith(".html"): + self.mem_tl.export_memory_timeline_html(path, device) + elif path.endswith(".gz"): + fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False) + fp.close() + if path.endswith("raw.json.gz"): + self.mem_tl.export_memory_timeline_raw(fp.name, device) + else: + self.mem_tl.export_memory_timeline(fp.name, device) + with open(fp.name) as fin: + with gzip.open(path, "wt") as fout: + fout.writelines(fin) + os.remove(fp.name) + else: + self.mem_tl.export_memory_timeline(path, device) + + +class ProfilerAction(Enum): + """ + Profiler actions that can be taken at the specified intervals + """ + + NONE = 0 + WARMUP = 1 + RECORD = 2 + RECORD_AND_SAVE = 3 + + +def schedule( + *, wait: int, warmup: int, active: int, repeat: int = 0, skip_first: int = 0 +) -> Callable: + """ + Returns a callable that can be used as profiler ``schedule`` argument. The profiler will skip + the first ``skip_first`` steps, then wait for ``wait`` steps, then do the warmup for the next ``warmup`` steps, + then do the active recording for the next ``active`` steps and then repeat the cycle starting with ``wait`` steps. + The optional number of cycles is specified with the ``repeat`` parameter, the zero value means that + the cycles will continue until the profiling is finished. + """ + + def schedule_fn(step: int) -> ProfilerAction: + assert step >= 0 + if step < skip_first: + return ProfilerAction.NONE + else: + step -= skip_first + num_steps = wait + warmup + active + if repeat > 0 and step / num_steps >= repeat: + return ProfilerAction.NONE + mod_step = step % num_steps + if mod_step < wait: + return ProfilerAction.NONE + elif mod_step < wait + warmup: + return ProfilerAction.WARMUP + else: + return ( + ProfilerAction.RECORD + if mod_step < num_steps - 1 + else ProfilerAction.RECORD_AND_SAVE + ) + + assert ( + wait >= 0 and warmup >= 0 and active > 0 and repeat >= 0 and skip_first >= 0 + ), "Invalid profiler schedule arguments" + if warmup == 0: + warn("Profiler won't be using warmup, this can skew profiler results") + return schedule_fn + + +def _default_schedule_fn(_: int) -> ProfilerAction: + """ + Default profiler behavior - immediately starts recording the events, + keeps doing it on every profiler step. + """ + return ProfilerAction.RECORD + + +def tensorboard_trace_handler( + dir_name: str, worker_name: Optional[str] = None, use_gzip: bool = False +): + """ + Outputs tracing files to directory of ``dir_name``, then that directory can be + directly delivered to tensorboard as logdir. + ``worker_name`` should be unique for each worker in distributed scenario, + it will be set to '[hostname]_[pid]' by default. 
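+ + Example (an illustrative sketch): + + .. code-block:: python + + with torch.profiler.profile( + schedule=torch.profiler.schedule(wait=1, warmup=1, active=2), + on_trace_ready=torch.profiler.tensorboard_trace_handler("./log"), + ) as p: + for _ in range(4): + train_step() # assumed user code + p.step()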
+ """ + import os + import socket + import time + + def handler_fn(prof) -> None: + nonlocal worker_name + if not os.path.isdir(dir_name): + try: + os.makedirs(dir_name, exist_ok=True) + except Exception as e: + raise RuntimeError("Can't create directory: " + dir_name) from e + if not worker_name: + worker_name = f"{socket.gethostname()}_{os.getpid()}" + # Use nanosecond here to avoid naming clash when exporting the trace + file_name = f"{worker_name}.{time.time_ns()}.pt.trace.json" + if use_gzip: + file_name = file_name + ".gz" + prof.export_chrome_trace(os.path.join(dir_name, file_name)) + + return handler_fn + + +class profile(_KinetoProfile): + """Profiler context manager. + + Args: + activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values: + ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``, + ``torch.profiler.ProfilerActivity.XPU``. + Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA + or (when available) ProfilerActivity.XPU. + schedule (Callable): callable that takes step (int) as a single parameter and returns + ``ProfilerAction`` value that specifies the profiler action to perform at each step. + on_trace_ready (Callable): callable that is called at each step when ``schedule`` + returns ``ProfilerAction.RECORD_AND_SAVE`` during the profiling. + record_shapes (bool): save information about operator's input shapes. + profile_memory (bool): track tensor memory allocation/deallocation. + with_stack (bool): record source information (file and line number) for the ops. + with_flops (bool): use formula to estimate the FLOPs (floating point operations) of specific operators + (matrix multiplication and 2D convolution). + with_modules (bool): record module hierarchy (including function names) + corresponding to the callstack of the op. e.g. If module A's forward call's + module B's forward which contains an aten::add op, + then aten::add's module hierarchy is A.B + Note that this support exist, at the moment, only for TorchScript models + and not eager mode models. + experimental_config (_ExperimentalConfig) : A set of experimental options + used for Kineto library features. Note, backward compatibility is not guaranteed. + execution_trace_observer (ExecutionTraceObserver) : A PyTorch Execution Trace Observer object. + `PyTorch Execution Traces `__ offer a graph based + representation of AI/ML workloads and enable replay benchmarks, simulators, and emulators. + When this argument is included the observer start() and stop() will be called for the + same time window as PyTorch profiler. See the examples section below for a code sample. + acc_events (bool): Enable the accumulation of FunctionEvents across multiple profiling cycles + use_cuda (bool): + .. deprecated:: 1.8.1 + use ``activities`` instead. + + .. note:: + Use :func:`~torch.profiler.schedule` to generate the callable schedule. + Non-default schedules are useful when profiling long training jobs + and allow the user to obtain multiple traces at the different iterations + of the training process. + The default schedule simply records all the events continuously for the + duration of the context manager. + + .. note:: + Use :func:`~torch.profiler.tensorboard_trace_handler` to generate result files for TensorBoard: + + ``on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name)`` + + After profiling, result files can be found in the specified directory. 
Use the command: + + ``tensorboard --logdir dir_name`` + + to see the results in TensorBoard. + For more information, see + `PyTorch Profiler TensorBoard Plugin `__ + + .. note:: + Enabling shape and stack tracing results in additional overhead. + When record_shapes=True is specified, profiler will temporarily hold references to the tensors; + that may further prevent certain optimizations that depend on the reference count and introduce + extra tensor copies. + + + Examples: + + .. code-block:: python + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ] + ) as p: + code_to_profile() + print(p.key_averages().table( + sort_by="self_cuda_time_total", row_limit=-1)) + + Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions: + + .. code-block:: python + + # Non-default profiler schedule allows user to turn profiler on and off + # on different iterations of the training loop; + # trace_handler is called every time a new trace becomes available + def trace_handler(prof): + print(prof.key_averages().table( + sort_by="self_cuda_time_total", row_limit=-1)) + # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json") + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + + # In this example with wait=1, warmup=1, active=2, repeat=1, + # profiler will skip the first step/iteration, + # start warming up on the second, record + # the third and the fourth iterations, + # after which the trace will become available + # and on_trace_ready (when set) is called; + # the cycle repeats starting with the next step + + schedule=torch.profiler.schedule( + wait=1, + warmup=1, + active=2, + repeat=1), + on_trace_ready=trace_handler + # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log') + # used when outputting for tensorboard + ) as p: + for iter in range(N): + code_iteration_to_profile(iter) + # send a signal to the profiler that the next iteration has started + p.step() + + The following sample shows how to set up an Execution Trace Observer (`execution_trace_observer`) + + .. code-block:: python + + with torch.profiler.profile( + ... + execution_trace_observer=( + ExecutionTraceObserver().register_callback("./execution_trace.json") + ), + ) as p: + for iter in range(N): + code_iteration_to_profile(iter) + p.step() + + You can also refer to test_execution_trace_with_kineto() in tests/profiler/test_profiler.py. + Note: One can also pass any object satisfying the _ITraceObserver interface.
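+ + A memory timeline can also be exported after profiling (an illustrative sketch; export_memory_timeline + requires record_shapes, profile_memory and with_stack to be enabled): + + .. code-block:: python + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=True, + profile_memory=True, + with_stack=True, + ) as p: + code_to_profile() + p.export_memory_timeline("memory_timeline.html", device="cuda:0")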
+ """ + + def __init__( + self, + *, + activities: Optional[Iterable[ProfilerActivity]] = None, + schedule: Optional[Callable[[int], ProfilerAction]] = None, + on_trace_ready: Optional[Callable[..., Any]] = None, + record_shapes: bool = False, + profile_memory: bool = False, + with_stack: bool = False, + with_flops: bool = False, + with_modules: bool = False, + experimental_config: Optional[_ExperimentalConfig] = None, + execution_trace_observer: Optional[_ITraceObserver] = None, + acc_events: bool = False, + # deprecated: + use_cuda: Optional[bool] = None, + ): + activities_set = set(activities) if activities else supported_activities() + if use_cuda is not None: + warn( + "`use_cuda` is deprecated, use `activities` argument instead", + FutureWarning, + stacklevel=2, + ) + if use_cuda: + activities_set.add(ProfilerActivity.CUDA) + elif ProfilerActivity.CUDA in activities_set: + activities_set.remove(ProfilerActivity.CUDA) + assert len(activities_set) > 0, "No valid profiler activities found" + + super().__init__( + activities=activities, + record_shapes=record_shapes, + profile_memory=profile_memory, + with_stack=with_stack, + with_flops=with_flops, + with_modules=with_modules, + experimental_config=experimental_config, + execution_trace_observer=execution_trace_observer, + acc_events=acc_events, + ) + + if schedule: + self.schedule = schedule + # add step markers into the trace and table view + self.record_steps = True + else: + self.schedule = _default_schedule_fn + self.record_steps = False + self.on_trace_ready = on_trace_ready + self.step_num = 0 + self.current_action = self.schedule(self.step_num) + self.step_rec_fn: Optional[prof.record_function] = None + + self.action_map: Dict[ + Tuple[ProfilerAction, Optional[ProfilerAction]], List[Any] + ] = { + # key is (prev_action, current_action), value is action list corresponding to the state pair. 
+ (ProfilerAction.NONE, ProfilerAction.NONE): [], + (ProfilerAction.NONE, ProfilerAction.WARMUP): [self.prepare_trace], + (ProfilerAction.NONE, ProfilerAction.RECORD): [ + self.prepare_trace, + self.start_trace, + ], + (ProfilerAction.NONE, ProfilerAction.RECORD_AND_SAVE): [ + self.prepare_trace, + self.start_trace, + ], + (ProfilerAction.WARMUP, ProfilerAction.NONE): [ + partial(warn, "Incorrect schedule: WARMUP followed by NONE"), + self.start_trace, + self.stop_trace, + ], + (ProfilerAction.WARMUP, ProfilerAction.WARMUP): [], + (ProfilerAction.WARMUP, ProfilerAction.RECORD): [self.start_trace], + (ProfilerAction.WARMUP, ProfilerAction.RECORD_AND_SAVE): [self.start_trace], + (ProfilerAction.RECORD, ProfilerAction.NONE): [ + partial(warn, "Incorrect schedule: RECORD followed by NONE"), + self.stop_trace, + ], + (ProfilerAction.RECORD, ProfilerAction.WARMUP): [ + partial(warn, "Incorrect schedule: RECORD followed by WARMUP"), + self.stop_trace, + ], + (ProfilerAction.RECORD, ProfilerAction.RECORD): [], + (ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE): [], + (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.NONE): [ + self.stop_trace, + self._trace_ready, + ], + (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.WARMUP): [ + self.stop_trace, + self._trace_ready, + self.prepare_trace, + ], + (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.RECORD): [ + self.stop_trace, + self._trace_ready, + self.prepare_trace, + self.start_trace, + ], + (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.RECORD_AND_SAVE): [ + self.stop_trace, + self._trace_ready, + self.prepare_trace, + self.start_trace, + ], + # used for exit action + (ProfilerAction.WARMUP, None): [self.start_trace, self.stop_trace], + (ProfilerAction.RECORD, None): [self.stop_trace, self._trace_ready], + (ProfilerAction.RECORD_AND_SAVE, None): [ + self.stop_trace, + self._trace_ready, + ], + } + # Start tracking increments to profiler step, this will be used + # by Kineto + prof.KinetoStepTracker.init_step_count(PROFILER_STEP_NAME) + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() + prof.KinetoStepTracker.erase_step_count(PROFILER_STEP_NAME) + if self.execution_trace_observer: + self.execution_trace_observer.cleanup() + + def start(self): + self._transit_action(ProfilerAction.NONE, self.current_action) + if self.record_steps: + self.step_rec_fn = prof.record_function( + "ProfilerStep#" + str(self.step_num) + ) + self.step_rec_fn.__enter__() + + def stop(self): + if self.record_steps and self.step_rec_fn: + self.step_rec_fn.__exit__(None, None, None) + self._transit_action(self.current_action, None) + + def step(self): + """ + Signals the profiler that the next profiling step has started. 
+ """ + if self.record_steps and self.step_rec_fn: + self.step_rec_fn.__exit__(None, None, None) + prev_action = self.current_action + self.step_num += 1 + self.current_action = self.schedule(self.step_num) + + self._transit_action(prev_action, self.current_action) + prof.KinetoStepTracker.increment_step(PROFILER_STEP_NAME) + + if self.record_steps: + self.step_rec_fn = prof.record_function( + "ProfilerStep#" + str(self.step_num) + ) + self.step_rec_fn.__enter__() + + def _trace_ready(self): + if self.on_trace_ready: + self.on_trace_ready(self) + + def _transit_action(self, prev_action, current_action): + action_list = self.action_map.get((prev_action, current_action)) + if action_list: + for action in action_list: + action() + + def _stats(self) -> Optional[prof._ProfilerStats]: + if self.profiler is None: + return None + return self.profiler._stats + + +class ExecutionTraceObserver(_ITraceObserver): + """Execution Trace Observer + + Each process can have a single ExecutionTraceObserver instance. The observer + can be added to record function callbacks via calling register_callback() + explicitly. Without calling unregister_callback(), repeated calls to + register_callback() will not add additional observers to record function + callbacks. Once an ExecutionTraceObserver is created, the start() and stop() + methods control when the event data is recorded. + + Deleting or calling unregister_callback() will remove the observer from the + record function callbacks, finalize the output file, and will stop + incurring any overheads. + """ + + def __init__(self) -> None: + """ + Initializes the default states. + """ + self._registered = False + self._execution_trace_running = False + + def __del__(self): + """ + Calls unregister_callback() to make sure to finalize outputs. + """ + self.unregister_callback() + + def register_callback(self, output_file_path: str) -> Self: + """ + Adds ET observer to record function callbacks. The data will be + written to output_file_path. + """ + if not self._registered: + self._output_file_path = output_file_path + self._registered = _add_execution_trace_observer(output_file_path) + return self + + def unregister_callback(self): + """ + Removes ET observer from record function callbacks. + """ + + def _save_triton_kernels(): + # Save the kernel paths for the generated kernels + from torch._inductor.codecache import PyCodeCache as PyCodeCache + + kernel_files = [ + v.__file__ + for v in PyCodeCache.cache.values() + if getattr(v, "__file__", None) is not None + ] + work_dir, file_name = os.path.split(self._output_file_path) + resource_dir = os.path.join( + work_dir, os.path.splitext(file_name)[0] + "_resources" + ) + if not os.path.exists(resource_dir): + os.mkdir(resource_dir) + + for kernel_file in kernel_files: + if kernel_file is None: + continue + path, name = os.path.split(kernel_file) + dst = os.path.join(resource_dir, name) + shutil.copyfile(kernel_file, dst) + + if self._registered: + self.stop() + try: + _save_triton_kernels() + except Exception as e: + warn(f"Execution trace failed to save kernels: {e}") + _remove_execution_trace_observer() + self._registered = False + + @property + def is_registered(self): + """ + Returns True if the execution trace observer is registered, otherwise False. + """ + return self._registered + + def is_running(self): + """ + Returns True if the observer is running, otherwise False. + """ + return self._execution_trace_running + + def start(self): + """ + Starts to capture. 
+ """ + if self._registered and not self._execution_trace_running: + _enable_execution_trace_observer() + self._execution_trace_running = True + self._record_pg_config() + + def stop(self): + """ + Stops to capture. + """ + if self._execution_trace_running: + _disable_execution_trace_observer() + self._execution_trace_running = False + + def cleanup(self): + """ + Calls unregister_callback() to make sure to finalize outputs. + """ + self.unregister_callback() + + def get_output_file_path(self) -> str: + """ + Returns the output file name. + """ + if self.is_registered: + return self._output_file_path + else: + raise RuntimeError( + "A callback to the ET profiler needs to be registered " + "first before getting the output file path" + ) + + def _record_pg_config(self) -> None: + # Records the PG config info to the trace as node: + # ## process_group:init ## + if ( + self.is_registered + and torch.distributed.is_available() + and torch.distributed.is_initialized() + ): + pg_config_info = torch.distributed.distributed_c10d._world.pg_config_info + torch.autograd._record_function_with_args_enter( + "## process_group:init ##", json.dumps(pg_config_info) + ) diff --git a/lib/python3.10/site-packages/torch/profiler/python_tracer.py b/lib/python3.10/site-packages/torch/profiler/python_tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..b3e624911f95812a523d4dd927a74eec7fe5171b --- /dev/null +++ b/lib/python3.10/site-packages/torch/profiler/python_tracer.py @@ -0,0 +1,20 @@ +import os +import site +import sys +import typing + +import torch + + +def _prefix_regex() -> typing.List[str]: + raw_paths = ( + site.getsitepackages() + + sys.path + + [site.getuserbase()] + + [site.getusersitepackages()] + + [os.path.dirname(os.path.dirname(torch.__file__))] + ) + + path_prefixes = sorted({os.path.abspath(i) for i in raw_paths}, reverse=True) + assert all(isinstance(i, str) for i in path_prefixes) + return [i + os.sep for i in path_prefixes] diff --git a/lib/python3.10/site-packages/torch/quantization/__init__.py b/lib/python3.10/site-packages/torch/quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8789fea17a17ffa8e490a8d744892c5140a70ee2 --- /dev/null +++ b/lib/python3.10/site-packages/torch/quantization/__init__.py @@ -0,0 +1,86 @@ +# mypy: allow-untyped-defs +from .fake_quantize import * # noqa: F403 +from .fuse_modules import fuse_modules +from .fuser_method_mappings import * # noqa: F403 +from .observer import * # noqa: F403 +from .qconfig import * # noqa: F403 +from .quant_type import * # noqa: F403 +from .quantization_mappings import * # noqa: F403 +from .quantize import * # noqa: F403 +from .quantize_jit import * # noqa: F403 +from .stubs import * # noqa: F403 + + +def default_eval_fn(model, calib_data): + r""" + Default evaluation function takes a torch.utils.data.Dataset or a list of + input Tensors and run the model on the dataset + """ + for data, target in calib_data: + model(data) + + +__all__ = [ + "QuantWrapper", + "QuantStub", + "DeQuantStub", + # Top level API for eager mode quantization + "quantize", + "quantize_dynamic", + "quantize_qat", + "prepare", + "convert", + "prepare_qat", + # Top level API for graph mode quantization on TorchScript + "quantize_jit", + "quantize_dynamic_jit", + "_prepare_ondevice_dynamic_jit", + "_convert_ondevice_dynamic_jit", + "_quantize_ondevice_dynamic_jit", + # Top level API for graph mode quantization on GraphModule(torch.fx) + # 'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx + 
# 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx', + "QuantType", # quantization type + # custom module APIs + "get_default_static_quant_module_mappings", + "get_static_quant_module_class", + "get_default_dynamic_quant_module_mappings", + "get_default_qat_module_mappings", + "get_default_qconfig_propagation_list", + "get_default_compare_output_module_list", + "get_quantized_operator", + "get_fuser_method", + # Sub functions for `prepare` and `swap_module` + "propagate_qconfig_", + "add_quant_dequant", + "swap_module", + "default_eval_fn", + # Observers + "ObserverBase", + "WeightObserver", + "HistogramObserver", + "observer", + "default_observer", + "default_weight_observer", + "default_placeholder_observer", + "default_per_channel_weight_observer", + # FakeQuantize (for qat) + "default_fake_quant", + "default_weight_fake_quant", + "default_fixed_qparams_range_neg1to1_fake_quant", + "default_fixed_qparams_range_0to1_fake_quant", + "default_per_channel_weight_fake_quant", + "default_histogram_fake_quant", + # QConfig + "QConfig", + "default_qconfig", + "default_dynamic_qconfig", + "float16_dynamic_qconfig", + "float_qparams_weight_only_qconfig", + # QAT utilities + "default_qat_qconfig", + "prepare_qat", + "quantize_qat", + # module transformations + "fuse_modules", +] diff --git a/lib/python3.10/site-packages/torch/quantization/_numeric_suite.py b/lib/python3.10/site-packages/torch/quantization/_numeric_suite.py new file mode 100644 index 0000000000000000000000000000000000000000..49ccc8e69523f7dbee2335b788a2cb3a7db618a2 --- /dev/null +++ b/lib/python3.10/site-packages/torch/quantization/_numeric_suite.py @@ -0,0 +1,28 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/ns/_numeric_suite.py`, while adding an import statement +here. +""" + +from torch.ao.ns._numeric_suite import ( + _convert_tuple_to_list, + _dequantize_tensor_list, + _find_match, + _get_logger_dict_helper, + _is_identical_module_type, + compare_model_outputs, + compare_model_stub, + compare_weights, + get_logger_dict, + get_matching_activations, + Logger, + NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST, + OutputLogger, + prepare_model_outputs, + prepare_model_with_stubs, + Shadow, + ShadowLogger, +) diff --git a/lib/python3.10/site-packages/torch/quantization/_numeric_suite_fx.py b/lib/python3.10/site-packages/torch/quantization/_numeric_suite_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..55cd7085740d0ce8de79491acbfc4888ebba21f8 --- /dev/null +++ b/lib/python3.10/site-packages/torch/quantization/_numeric_suite_fx.py @@ -0,0 +1,26 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/ns/_numeric_suite_fx.py`, while adding an import statement +here. 
+""" + +from torch.ao.ns._numeric_suite_fx import ( + _add_loggers_impl, + _add_loggers_one_model, + _add_shadow_loggers_impl, + _extract_logger_info_one_model, + _extract_weights_impl, + _extract_weights_one_model, + add_loggers, + add_shadow_loggers, + extend_logger_results_with_comparison, + extract_logger_info, + extract_shadow_logger_info, + extract_weights, + NSTracer, + OutputLogger, + RNNReturnType, +) diff --git a/lib/python3.10/site-packages/torch/quantization/_quantized_conversions.py b/lib/python3.10/site-packages/torch/quantization/_quantized_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..8d930c366c0dd9857e463005474a2d59c04c4ae6 --- /dev/null +++ b/lib/python3.10/site-packages/torch/quantization/_quantized_conversions.py @@ -0,0 +1,133 @@ +# mypy: allow-untyped-defs +import torch + + +# Pack pairs of int4 values into int8, in row major order; first int4 +# value goes into lower order bits, and second int4 value into higher +# order bits of resulting int8 value. +def pack_int4_to_int8(weight): + assert weight.dim() == 2 + assert weight.shape[1] % 2 == 0 + assert weight.dtype == torch.int8 + return ((weight[:, 1::2] & 0xF) << 4) | (weight[:, 0::2] & 0xF) + + +# Unpack quandruples of bits in int8 values into int4 values, in row +# major order; lower 4 bits go into first int4 value goes, and upper 4 +# bits go into second int4 value. +def unpack_int8_to_int4(weight): + assert weight.dim() == 2 + assert weight.dtype == torch.int8 + return torch.stack((weight & 0xF, (weight >> 4) & 0xF), dim=2).view( + weight.shape[0], 2 * weight.shape[1] + ) + + +# Transpose the weight matrix, and then reorder its elements according +# to underlying requirements of CUTLASS library, so that it could be +# used for CUTLASS-based mixed datatypes linear operation. 
+def quantized_weight_reorder_for_mixed_dtypes_linear_cutlass( + weight, dtypeq, transpose=False +): + assert weight.dim() == 2 + assert weight.dtype == torch.int8 + assert dtypeq == torch.int8 or dtypeq == torch.quint4x2 + assert weight.device.type == "cuda" + + device = weight.device + + # subbyte_transpose + if not transpose: + if dtypeq == torch.int8: + outp = weight.T + elif dtypeq == torch.quint4x2: + outp = pack_int4_to_int8(unpack_int8_to_int4(weight.view(torch.int8)).T) + else: + outp = weight + + ncols, nrows = outp.shape # type: ignore[possibly-undefined] + assert nrows % (32 if dtypeq == torch.quint4x2 else 64) == 0 + assert ncols % 64 == 0 + + # permute_B_rows_for_mixed_gemm + # (permute cols actually, as transpose is applied first here) + if dtypeq == torch.quint4x2: + cols_permuted = ( + torch.tensor( + [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15], + device=device, + ) + + (torch.arange(0, nrows // 16, device=device).reshape(-1, 1) * 16).expand( + nrows // 16, 16 + ) + ).view(-1) + else: + cols_permuted = ( + torch.tensor( + [0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15], + device=device, + ) + + (torch.arange(0, nrows // 16, device=device).reshape(-1, 1) * 16).expand( + nrows // 16, 16 + ) + ).view(-1) + outp = outp.index_copy(1, cols_permuted, outp) + + # interleave_column_major_tensor + magic0 = 4 if dtypeq == torch.quint4x2 else 2 + magic1 = 32 // magic0 + + tmp0 = ( + (torch.arange(0, ncols // magic0, device=device) * (nrows // 4 * magic0)) + .view(-1, 1) + .repeat(1, nrows // 4 * magic0) + .view(-1) + ) + tmp1 = ( + (torch.arange(0, nrows // 4 // magic1, device=device) * (magic0 * magic1)) + .view(-1, 1) + .repeat(1, magic1) + .view(-1) + .repeat(ncols) + ) + tmp2 = ( + (torch.arange(0, magic0, device=device) * magic1) + .view(-1, 1) + .repeat(1, nrows // 4) + .view(-1) + .repeat(ncols // magic0) + ) + tmp3 = torch.arange(0, magic1, device=device).repeat(nrows // 4 * ncols // magic1) + + outp_offsets = tmp0 + tmp1 + tmp2 + tmp3 + + tmp = outp.view(-1).view(torch.int32) + outp = torch.zeros_like(tmp) + outp.scatter_(0, outp_offsets, tmp) + outp = outp.view(weight.dtype) + + # add_bias_and_interleave_quantized_tensor_inplace + tmp = outp.view(-1) + + outp = torch.empty_like(tmp) + if dtypeq == torch.int8: + tmp = (tmp.to(torch.int) + 128).to(tmp.dtype) + outp[0::4] = tmp[0::4] + outp[1::4] = tmp[2::4] + outp[2::4] = tmp[1::4] + outp[3::4] = tmp[3::4] + elif dtypeq == torch.quint4x2: + tmp0 = ((tmp & 0xF) + 8) & 0xF + tmp0 = (tmp0[1::2] << 4) | tmp0[0::2] + tmp1 = (((tmp >> 4) & 0xF) + 8) & 0xF + tmp1 = (tmp1[1::2] << 4) | tmp1[0::2] + outp[0::4] = tmp0[0::2] + outp[1::4] = tmp0[1::2] + outp[2::4] = tmp1[0::2] + outp[3::4] = tmp1[1::2] + + if dtypeq == torch.quint4x2: + nrows *= 2 + ncols //= 2 + + return outp.view(nrows, ncols).view(torch.uint8) diff --git a/lib/python3.10/site-packages/torch/quantization/fake_quantize.py b/lib/python3.10/site-packages/torch/quantization/fake_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..69a5d730bfb68e89e24beb04ad13fd3fa5881ae9 --- /dev/null +++ b/lib/python3.10/site-packages/torch/quantization/fake_quantize.py @@ -0,0 +1,32 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/fake_quantize.py`, while adding an import statement +here. 
+""" + +from torch.ao.quantization.fake_quantize import ( + _is_fake_quant_script_module, + _is_per_channel, + _is_per_tensor, + _is_symmetric_quant, + default_fake_quant, + default_fixed_qparams_range_0to1_fake_quant, + default_fixed_qparams_range_neg1to1_fake_quant, + default_fused_act_fake_quant, + default_fused_per_channel_wt_fake_quant, + default_fused_wt_fake_quant, + default_histogram_fake_quant, + default_per_channel_weight_fake_quant, + default_weight_fake_quant, + disable_fake_quant, + disable_observer, + enable_fake_quant, + enable_observer, + FakeQuantize, + FakeQuantizeBase, + FixedQParamsFakeQuantize, + FusedMovingAvgObsFakeQuantize, +) diff --git a/lib/python3.10/site-packages/torch/quantization/fuse_modules.py b/lib/python3.10/site-packages/torch/quantization/fuse_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..6b704fa8094e8b367e9eba47102863ba845415b9 --- /dev/null +++ b/lib/python3.10/site-packages/torch/quantization/fuse_modules.py @@ -0,0 +1,22 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/fuse_modules.py`, while adding an import statement +here. +""" + +# TODO: These functions are not used outside the `fuse_modules.py` +# Keeping here for now, need to remove them later. +from torch.ao.quantization.fuse_modules import ( + _fuse_modules, + _get_module, + _set_module, + fuse_known_modules, + fuse_modules, + get_fuser_method, +) + +# for backward compatiblity +from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn, fuse_conv_bn_relu diff --git a/lib/python3.10/site-packages/torch/quantization/fuser_method_mappings.py b/lib/python3.10/site-packages/torch/quantization/fuser_method_mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..cfb13ac96271fa7b926cc703918984760e6ede15 --- /dev/null +++ b/lib/python3.10/site-packages/torch/quantization/fuser_method_mappings.py @@ -0,0 +1,15 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/fuser_method_mappings.py`, while adding an import statement +here. +""" +from torch.ao.quantization.fuser_method_mappings import ( + _DEFAULT_OP_LIST_TO_FUSER_METHOD, + fuse_conv_bn, + fuse_conv_bn_relu, + fuse_linear_bn, + get_fuser_method, +) diff --git a/lib/python3.10/site-packages/torch/quantization/observer.py b/lib/python3.10/site-packages/torch/quantization/observer.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6c7c1917c83433fc19f016140b25d060284535 --- /dev/null +++ b/lib/python3.10/site-packages/torch/quantization/observer.py @@ -0,0 +1,36 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/observer.py`, while adding an import statement +here. 
+""" +from torch.ao.quantization.observer import ( + _is_activation_post_process, + _is_per_channel_script_obs_instance, + _ObserverBase, + _PartialWrapper, + _with_args, + _with_callable_args, + ABC, + default_debug_observer, + default_dynamic_quant_observer, + default_float_qparams_observer, + default_histogram_observer, + default_observer, + default_per_channel_weight_observer, + default_placeholder_observer, + default_weight_observer, + get_observer_state_dict, + HistogramObserver, + load_observer_state_dict, + MinMaxObserver, + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, + NoopObserver, + ObserverBase, + PerChannelMinMaxObserver, + PlaceholderObserver, + RecordingObserver, +) diff --git a/lib/python3.10/site-packages/torch/quantization/qconfig.py b/lib/python3.10/site-packages/torch/quantization/qconfig.py new file mode 100644 index 0000000000000000000000000000000000000000..6bb7e14110cb9cdc4e9c2c418c6776ea6445f0d3 --- /dev/null +++ b/lib/python3.10/site-packages/torch/quantization/qconfig.py @@ -0,0 +1,30 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/qconfig.py`, while adding an import statement +here. +""" +from torch.ao.quantization.qconfig import ( + _add_module_to_qconfig_obs_ctr, + _assert_valid_qconfig, + default_activation_only_qconfig, + default_debug_qconfig, + default_dynamic_qconfig, + default_per_channel_qconfig, + default_qat_qconfig, + default_qat_qconfig_v2, + default_qconfig, + default_weight_only_qconfig, + float16_dynamic_qconfig, + float16_static_qconfig, + float_qparams_weight_only_qconfig, + get_default_qat_qconfig, + get_default_qconfig, + per_channel_dynamic_qconfig, + QConfig, + qconfig_equals, + QConfigAny, + QConfigDynamic, +) diff --git a/lib/python3.10/site-packages/torch/quantization/quant_type.py b/lib/python3.10/site-packages/torch/quantization/quant_type.py new file mode 100644 index 0000000000000000000000000000000000000000..8555f03792661f39c85c8facf3f911786cc25d0f --- /dev/null +++ b/lib/python3.10/site-packages/torch/quantization/quant_type.py @@ -0,0 +1,10 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/quant_type.py`, while adding an import statement +here. +""" + +from torch.ao.quantization.quant_type import _get_quant_type_to_str, QuantType diff --git a/lib/python3.10/site-packages/torch/quantization/quantization_mappings.py b/lib/python3.10/site-packages/torch/quantization/quantization_mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..8b44a980ce82fbfa5a81ad906499806cf99b876f --- /dev/null +++ b/lib/python3.10/site-packages/torch/quantization/quantization_mappings.py @@ -0,0 +1,29 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/quantization_mappings.py`, while adding an import statement +here. 
+""" +from torch.ao.quantization.quantization_mappings import ( + _get_special_act_post_process, + _has_special_act_post_process, + _INCLUDE_QCONFIG_PROPAGATE_LIST, + DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, + DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS, + DEFAULT_MODULE_TO_ACT_POST_PROCESS, + DEFAULT_QAT_MODULE_MAPPINGS, + DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS, + DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, + get_default_compare_output_module_list, + get_default_dynamic_quant_module_mappings, + get_default_float_to_quantized_operator_mappings, + get_default_qat_module_mappings, + get_default_qconfig_propagation_list, + get_default_static_quant_module_mappings, + get_dynamic_quant_module_class, + get_quantized_operator, + get_static_quant_module_class, + no_observer_set, +) diff --git a/lib/python3.10/site-packages/torch/quantization/quantize.py b/lib/python3.10/site-packages/torch/quantization/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..600d3a46fed0346e3ae8909872cd5bf3c733860c --- /dev/null +++ b/lib/python3.10/site-packages/torch/quantization/quantize.py @@ -0,0 +1,30 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/quantize.py`, while adding an import statement +here. +""" + +from torch.ao.quantization.quantize import ( + _add_observer_, + _convert, + _get_observer_dict, + _get_unique_devices_, + _is_activation_post_process, + _observer_forward_hook, + _propagate_qconfig_helper, + _register_activation_post_process_hook, + _remove_activation_post_process, + _remove_qconfig, + add_quant_dequant, + convert, + prepare, + prepare_qat, + propagate_qconfig_, + quantize, + quantize_dynamic, + quantize_qat, + swap_module, +) diff --git a/lib/python3.10/site-packages/torch/quantization/quantize_fx.py b/lib/python3.10/site-packages/torch/quantization/quantize_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..649142c7a7eee9885d96b37f70e582f3ea9a9f8d --- /dev/null +++ b/lib/python3.10/site-packages/torch/quantization/quantize_fx.py @@ -0,0 +1,26 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/quantize_fx.py`, while adding an import statement +here. +""" + +from torch.ao.quantization.fx.graph_module import ObservedGraphModule +from torch.ao.quantization.quantize_fx import ( + _check_is_graph_module, + _convert_fx, + _convert_standalone_module_fx, + _fuse_fx, + _prepare_fx, + _prepare_standalone_module_fx, + _swap_ff_with_fxff, + convert_fx, + fuse_fx, + prepare_fx, + prepare_qat_fx, + QuantizationTracer, + Scope, + ScopeContextManager, +) diff --git a/lib/python3.10/site-packages/torch/quantization/quantize_jit.py b/lib/python3.10/site-packages/torch/quantization/quantize_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..aa627dc7bb51ef7ea1fde7e2e5da283c9f6c8900 --- /dev/null +++ b/lib/python3.10/site-packages/torch/quantization/quantize_jit.py @@ -0,0 +1,26 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. 
+If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/quantize_jit.py`, while adding an import statement +here. +""" + +from torch.ao.quantization.quantize_jit import ( + _check_forward_method, + _check_is_script_module, + _convert_jit, + _prepare_jit, + _prepare_ondevice_dynamic_jit, + _quantize_jit, + convert_dynamic_jit, + convert_jit, + fuse_conv_bn_jit, + prepare_dynamic_jit, + prepare_jit, + quantize_dynamic_jit, + quantize_jit, + script_qconfig, + script_qconfig_dict, +) diff --git a/lib/python3.10/site-packages/torch/quantization/stubs.py b/lib/python3.10/site-packages/torch/quantization/stubs.py new file mode 100644 index 0000000000000000000000000000000000000000..d3fd5c63683dc572c35cabc202ee4ddb2b0053c6 --- /dev/null +++ b/lib/python3.10/site-packages/torch/quantization/stubs.py @@ -0,0 +1,10 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/stubs.py`, while adding an import statement +here. +""" + +from torch.ao.quantization.stubs import DeQuantStub, QuantStub, QuantWrapper diff --git a/lib/python3.10/site-packages/torch/quantization/utils.py b/lib/python3.10/site-packages/torch/quantization/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7d51d58f38d7462713f84ab62427852c1dd8e52c --- /dev/null +++ b/lib/python3.10/site-packages/torch/quantization/utils.py @@ -0,0 +1,29 @@ +# flake8: noqa: F401 +r""" +Utils shared by different modes of quantization (eager/graph) + +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/utils.py`, while adding an import statement +here. +""" + +from torch.ao.quantization.utils import ( + activation_dtype, + activation_is_int8_quantized, + activation_is_statically_quantized, + calculate_qmin_qmax, + check_min_max_valid, + get_combined_dict, + get_qconfig_dtypes, + get_qparam_dict, + get_quant_type, + get_swapped_custom_module_class, + getattr_from_fqn, + is_per_channel, + is_per_tensor, + weight_dtype, + weight_is_quantized, + weight_is_statically_quantized, +) diff --git a/lib/python3.10/site-packages/torch/signal/__init__.py b/lib/python3.10/site-packages/torch/signal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3684eabe71215fe96689181065a9432f127aa8e7 --- /dev/null +++ b/lib/python3.10/site-packages/torch/signal/__init__.py @@ -0,0 +1,5 @@ +from . 
import windows + +__all__ = [ + 'windows' +] diff --git a/lib/python3.10/site-packages/torch/sparse/__init__.py b/lib/python3.10/site-packages/torch/sparse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7203bf6a6fa4ad7f403abdc98fc2495db4411f32 --- /dev/null +++ b/lib/python3.10/site-packages/torch/sparse/__init__.py @@ -0,0 +1,703 @@ +# mypy: allow-untyped-defs +# The Tensor classes are added to this module by python_tensor.cpp +# A workaround to support both TorchScript and MyPy: +from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union + +import torch +from torch import Tensor +from torch._C import _add_docstr, _sparse # type: ignore[attr-defined] + +# Semi structured sparsity support +from .semi_structured import ( + SparseSemiStructuredTensor, + SparseSemiStructuredTensorCUSPARSELT, + SparseSemiStructuredTensorCUTLASS, + to_sparse_semi_structured, +) + + +if TYPE_CHECKING: + from torch.types import _dtype as DType + + DimOrDims = Optional[Union[int, Tuple[int, ...], List[int]]] +else: + # The JIT doesn't understand Union, nor torch.dtype here + DType = int + DimOrDims = Optional[Tuple[int]] + + +__all__ = [ + "addmm", + "check_sparse_tensor_invariants", + "mm", + "sum", + "softmax", + "solve", + "log_softmax", + "SparseSemiStructuredTensor", + "SparseSemiStructuredTensorCUTLASS", + "SparseSemiStructuredTensorCUSPARSELT", + "to_sparse_semi_structured", + "as_sparse_gradcheck", +] + +addmm = _add_docstr( + _sparse._sparse_addmm, + r""" +sparse.addmm(mat, mat1, mat2, *, beta=1., alpha=1.) -> Tensor + +This function does the exact same thing as :func:`torch.addmm` in the forward, +except that it supports backward for sparse COO matrix :attr:`mat1`. +When :attr:`mat1` is a COO tensor it must have `sparse_dim = 2`. +When inputs are COO tensors, this function also supports backward for both inputs. + +Supports both CSR and COO storage formats. + +.. note:: + This function doesn't support computing derivatives with respect to CSR matrices. + +Args: + mat (Tensor): a dense matrix to be added + mat1 (Tensor): a sparse matrix to be multiplied + mat2 (Tensor): a dense matrix to be multiplied + beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) +""", +) + + +mm = _add_docstr( + _sparse._sparse_mm, + r""" + Performs a matrix multiplication of the sparse matrix :attr:`mat1` + and the (sparse or strided) matrix :attr:`mat2`. Similar to :func:`torch.mm`, if :attr:`mat1` is a + :math:`(n \times m)` tensor, :attr:`mat2` is a :math:`(m \times p)` tensor, out will be a + :math:`(n \times p)` tensor. + When :attr:`mat1` is a COO tensor it must have `sparse_dim = 2`. + When inputs are COO tensors, this function also supports backward for both inputs. + + Supports both CSR and COO storage formats. + +.. note:: + This function doesn't support computing derivatives with respect to CSR matrices. + + This function additionally accepts an optional :attr:`reduce` argument that allows + specification of a reduction operation, which mathematically performs the following operation: + +.. math:: + + z_{ij} = \bigoplus_{k = 0}^{K - 1} x_{ik} y_{kj} + +where :math:`\bigoplus` defines the reduce operator. :attr:`reduce` is implemented only for +CSR storage format on CPU device.
+ +Args: + mat1 (Tensor): the first sparse matrix to be multiplied + mat2 (Tensor): the second matrix to be multiplied, which could be sparse or dense + reduce (str, optional): the reduction operation to apply for non-unique indices + (:obj:`"sum"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`). Default :obj:`"sum"`. + +Shape: + The format of the output tensor of this function follows: + - sparse x sparse -> sparse + - sparse x dense -> dense + +Example:: + + >>> a = torch.tensor([[1., 0, 2], [0, 3, 0]]).to_sparse().requires_grad_() + >>> a + tensor(indices=tensor([[0, 0, 1], + [0, 2, 1]]), + values=tensor([1., 2., 3.]), + size=(2, 3), nnz=3, layout=torch.sparse_coo, requires_grad=True) + >>> b = torch.tensor([[0, 1.], [2, 0], [0, 0]], requires_grad=True) + >>> b + tensor([[0., 1.], + [2., 0.], + [0., 0.]], requires_grad=True) + >>> y = torch.sparse.mm(a, b) + >>> y + tensor([[0., 1.], + [6., 0.]], grad_fn=) + >>> y.sum().backward() + >>> a.grad + tensor(indices=tensor([[0, 0, 1], + [0, 2, 1]]), + values=tensor([1., 0., 2.]), + size=(2, 3), nnz=3, layout=torch.sparse_coo) + >>> c = a.detach().to_sparse_csr() + >>> c + tensor(crow_indices=tensor([0, 2, 3]), + col_indices=tensor([0, 2, 1]), + values=tensor([1., 2., 3.]), size=(2, 3), nnz=3, + layout=torch.sparse_csr) + >>> y1 = torch.sparse.mm(c, b, 'sum') + >>> y1 + tensor([[0., 1.], + [6., 0.]], grad_fn=) + >>> y2 = torch.sparse.mm(c, b, 'max') + >>> y2 + tensor([[0., 1.], + [6., 0.]], grad_fn=) +""", +) + + +sampled_addmm = _add_docstr( + _sparse.sparse_sampled_addmm, + r""" +sparse.sampled_addmm(input, mat1, mat2, *, beta=1., alpha=1., out=None) -> Tensor + +Performs a matrix multiplication of the dense matrices :attr:`mat1` and :attr:`mat2` at the locations +specified by the sparsity pattern of :attr:`input`. The matrix :attr:`input` is added to the final result. + +Mathematically this performs the following operation: + +.. math:: + + \text{out} = \alpha\ (\text{mat1} \mathbin{@} \text{mat2})*\text{spy}(\text{input}) + \beta\ \text{input} + +where :math:`\text{spy}(\text{input})` is the sparsity pattern matrix of :attr:`input`, :attr:`alpha` +and :attr:`beta` are the scaling factors. +:math:`\text{spy}(\text{input})` has value 1 at the positions where :attr:`input` has non-zero values, and 0 elsewhere. + +.. note:: + :attr:`input` must be a sparse CSR tensor. :attr:`mat1` and :attr:`mat2` must be dense tensors. + +Args: + input (Tensor): a sparse CSR matrix of shape `(m, n)` to be added and used to compute + the sampled matrix multiplication + mat1 (Tensor): a dense matrix of shape `(m, k)` to be multiplied + mat2 (Tensor): a dense matrix of shape `(k, n)` to be multiplied + +Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. 
+ +Examples:: + + >>> input = torch.eye(3, device='cuda').to_sparse_csr() + >>> mat1 = torch.randn(3, 5, device='cuda') + >>> mat2 = torch.randn(5, 3, device='cuda') + >>> torch.sparse.sampled_addmm(input, mat1, mat2) + tensor(crow_indices=tensor([0, 1, 2, 3]), + col_indices=tensor([0, 1, 2]), + values=tensor([ 0.2847, -0.7805, -0.1900]), device='cuda:0', + size=(3, 3), nnz=3, layout=torch.sparse_csr) + >>> torch.sparse.sampled_addmm(input, mat1, mat2).to_dense() + tensor([[ 0.2847, 0.0000, 0.0000], + [ 0.0000, -0.7805, 0.0000], + [ 0.0000, 0.0000, -0.1900]], device='cuda:0') + >>> torch.sparse.sampled_addmm(input, mat1, mat2, beta=0.5, alpha=0.5) + tensor(crow_indices=tensor([0, 1, 2, 3]), + col_indices=tensor([0, 1, 2]), + values=tensor([ 0.1423, -0.3903, -0.0950]), device='cuda:0', + size=(3, 3), nnz=3, layout=torch.sparse_csr) +""", +) + + +def sum(input: Tensor, dim: DimOrDims = None, dtype: Optional[DType] = None) -> Tensor: + r"""Return the sum of each row of the given sparse tensor. + + Returns the sum of each row of the sparse tensor :attr:`input` in the given + dimensions :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. When sum over all ``sparse_dim``, this method + returns a dense tensor instead of a sparse tensor. + + All summed :attr:`dim` are squeezed (see :func:`torch.squeeze`), resulting an output + tensor having :attr:`dim` fewer dimensions than :attr:`input`. + + During backward, only gradients at ``nnz`` locations of :attr:`input` + will propagate back. Note that the gradients of :attr:`input` is coalesced. + + Args: + input (Tensor): the input sparse tensor + dim (int or tuple of ints): a dimension or a list of dimensions to reduce. Default: reduce + over all dims. + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: dtype of :attr:`input`. + + Example:: + + >>> nnz = 3 + >>> dims = [5, 5, 2, 3] + >>> I = torch.cat([torch.randint(0, dims[0], size=(nnz,)), + torch.randint(0, dims[1], size=(nnz,))], 0).reshape(2, nnz) + >>> V = torch.randn(nnz, dims[2], dims[3]) + >>> size = torch.Size(dims) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> S = torch.sparse_coo_tensor(I, V, size) + >>> S + tensor(indices=tensor([[2, 0, 3], + [2, 4, 1]]), + values=tensor([[[-0.6438, -1.6467, 1.4004], + [ 0.3411, 0.0918, -0.2312]], + + [[ 0.5348, 0.0634, -2.0494], + [-0.7125, -1.0646, 2.1844]], + + [[ 0.1276, 0.1874, -0.6334], + [-1.9682, -0.5340, 0.7483]]]), + size=(5, 5, 2, 3), nnz=3, layout=torch.sparse_coo) + + # when sum over only part of sparse_dims, return a sparse tensor + >>> torch.sparse.sum(S, [1, 3]) + tensor(indices=tensor([[0, 2, 3]]), + values=tensor([[-1.4512, 0.4073], + [-0.8901, 0.2017], + [-0.3183, -1.7539]]), + size=(5, 2), nnz=3, layout=torch.sparse_coo) + + # when sum over all sparse dim, return a dense tensor + # with summed dims squeezed + >>> torch.sparse.sum(S, [0, 1, 3]) + tensor([-2.6596, -1.1450]) + """ + if dtype is None: + if dim is not None: + return torch._sparse_sum(input, dim) + else: + return torch._sparse_sum(input) + else: + if dim is not None: + return torch._sparse_sum(input, dim, dtype=dtype) + else: + return torch._sparse_sum(input, dtype=dtype) + + +softmax = _add_docstr( + _sparse._sparse_softmax, + r""" +sparse.softmax(input, dim, *, dtype=None) -> Tensor + +Applies a softmax function. 
+ +Softmax is defined as: + +:math:`\text{Softmax}(x_{i}) = \frac{exp(x_i)}{\sum_j exp(x_j)}` + +where :math:`i, j` run over sparse tensor indices and unspecified +entries are ignored. This is equivalent to defining unspecified +entries as negative infinity so that :math:`exp(x_k) = 0` when the +entry with index :math:`k` is not specified. + +It is applied to all slices along `dim`, and will re-scale them so +that the elements lie in the range `[0, 1]` and sum to 1. + +Args: + input (Tensor): input + dim (int): A dimension along which softmax will be computed. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + cast to :attr:`dtype` before the operation is + performed. This is useful for preventing data type + overflows. Default: None +""", +) + + +spsolve = _add_docstr( + _sparse._spsolve, + r""" +sparse.spsolve(input, other, *, left=True) -> Tensor + +Computes the solution of a square system of linear equations with +a unique solution. Its purpose is similar to :func:`torch.linalg.solve`, +except that the system is defined by a sparse CSR matrix with layout +`sparse_csr`. + +Args: + input (Tensor): a sparse CSR matrix of shape `(n, n)` representing the + coefficients of the linear system. + other (Tensor): a dense matrix of shape `(n, )` representing the right-hand + side of the linear system. + left (bool, optional): whether to solve the system for `input @ out = other` + (default) or `out @ input = other`. Only `left=True` is supported. +""", +) + +log_softmax = _add_docstr( + _sparse._sparse_log_softmax, + r""" +sparse.log_softmax(input, dim, *, dtype=None) -> Tensor + +Applies a softmax function followed by logarithm. + +See :class:`~torch.sparse.softmax` for more details. + +Args: + input (Tensor): input + dim (int): A dimension along which softmax will be computed. + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + cast to :attr:`dtype` before the operation is + performed. This is useful for preventing data type + overflows. Default: None +""", +) + + +spdiags = _add_docstr( + _sparse._spdiags, + r""" +sparse.spdiags(diagonals, offsets, shape, layout=None) -> Tensor + +Creates a sparse 2D tensor by placing the values from rows of +:attr:`diagonals` along specified diagonals of the output. + +The :attr:`offsets` tensor controls which diagonals are set. + +- If :attr:`offsets[i]` = 0, it is the main diagonal +- If :attr:`offsets[i]` < 0, it is below the main diagonal +- If :attr:`offsets[i]` > 0, it is above the main diagonal + +The number of rows in :attr:`diagonals` must match the length of :attr:`offsets`, +and an offset may not be repeated. + +Args: + diagonals (Tensor): Matrix storing diagonals row-wise + offsets (Tensor): The diagonals to be set, stored as a vector + shape (2-tuple of ints): The desired shape of the result +Keyword args: + layout (:class:`torch.layout`, optional): The desired layout of the + returned tensor. ``torch.sparse_coo``, ``torch.sparse_csc`` and ``torch.sparse_csr`` + are supported.
Default: ``torch.sparse_coo`` + +Examples: + +Set the main and first two lower diagonals of a matrix:: + + >>> diags = torch.arange(9).reshape(3, 3) + >>> diags + tensor([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3)) + >>> s + tensor(indices=tensor([[0, 1, 2, 1, 2, 2], + [0, 1, 2, 0, 1, 0]]), + values=tensor([0, 1, 2, 3, 4, 6]), + size=(3, 3), nnz=6, layout=torch.sparse_coo) + >>> s.to_dense() + tensor([[0, 0, 0], + [3, 1, 0], + [6, 4, 2]]) + + +Change the output layout:: + + >>> diags = torch.arange(9).reshape(3, 3) + >>> diags + tensor([[0, 1, 2],[3, 4, 5], [6, 7, 8]) + >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3), layout=torch.sparse_csr) + >>> s + tensor(crow_indices=tensor([0, 1, 3, 6]), + col_indices=tensor([0, 0, 1, 0, 1, 2]), + values=tensor([0, 3, 1, 6, 4, 2]), size=(3, 3), nnz=6, + layout=torch.sparse_csr) + >>> s.to_dense() + tensor([[0, 0, 0], + [3, 1, 0], + [6, 4, 2]]) + +Set partial diagonals of a large output:: + + >>> diags = torch.tensor([[1, 2], [3, 4]]) + >>> offsets = torch.tensor([0, -1]) + >>> torch.sparse.spdiags(diags, offsets, (5, 5)).to_dense() + tensor([[1, 0, 0, 0, 0], + [3, 2, 0, 0, 0], + [0, 4, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]]) + +.. note:: + + When setting the values along a given diagonal the index into the diagonal + and the index into the row of :attr:`diagonals` is taken as the + column index in the output. This has the effect that when setting a diagonal + with a positive offset `k` the first value along that diagonal will be + the value in position `k` of the row of :attr:`diagonals` + +Specifying a positive offset:: + + >>> diags = torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]]) + >>> torch.sparse.spdiags(diags, torch.tensor([0, 1, 2]), (5, 5)).to_dense() + tensor([[1, 2, 3, 0, 0], + [0, 2, 3, 0, 0], + [0, 0, 3, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]]) +""", +) + + +class check_sparse_tensor_invariants: + """A tool to control checking sparse tensor invariants. + + The following options exists to manage sparsr tensor invariants + checking in sparse tensor construction: + + 1. Using a context manager: + + .. code:: python + + with torch.sparse.check_sparse_tensor_invariants(): + run_my_model() + + 2. Using a procedural approach: + + .. code:: python + + prev_checks_enabled = torch.sparse.check_sparse_tensor_invariants.is_enabled() + torch.sparse.check_sparse_tensor_invariants.enable() + + run_my_model() + + if not prev_checks_enabled: + torch.sparse.check_sparse_tensor_invariants.disable() + + 3. Using function decoration: + + .. code:: python + + @torch.sparse.check_sparse_tensor_invariants() + def run_my_model(): + ... + + run_my_model() + + 4. Using ``check_invariants`` keyword argument in sparse tensor constructor call. + For example: + + >>> torch.sparse_csr_tensor([0, 1, 3], [0, 1], [1, 2], check_invariants=True) + Traceback (most recent call last): + File "", line 1, in + RuntimeError: `crow_indices[..., -1] == nnz` is not satisfied. + """ + + @staticmethod + def is_enabled(): + r"""Return True if the sparse tensor invariants checking is enabled. + + .. note:: + + Use :func:`torch.sparse.check_sparse_tensor_invariants.enable` or + :func:`torch.sparse.check_sparse_tensor_invariants.disable` to + manage the state of the sparse tensor invariants checks. + """ + return torch._C._check_sparse_tensor_invariants() + + @staticmethod + def enable(): + r"""Enable sparse tensor invariants checking in sparse tensor constructors. + + .. 
note:: + + By default, the sparse tensor invariants checks are disabled. Use + :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled` to + retrieve the current state of sparse tensor invariants checking. + + .. note:: + + The sparse tensor invariants check flag is effective to all sparse + tensor constructors, both in Python and ATen. + + The flag can be locally overridden by the ``check_invariants`` + optional argument of the sparse tensor constructor functions. + """ + torch._C._set_check_sparse_tensor_invariants(True) + + @staticmethod + def disable(): + r"""Disable sparse tensor invariants checking in sparse tensor constructors. + + See :func:`torch.sparse.check_sparse_tensor_invariants.enable` for more information. + """ + torch._C._set_check_sparse_tensor_invariants(False) + + # context manager support + def __init__(self, enable=True): + self.state = enable + self.saved_state: Optional[bool] = None + + def __enter__(self): + if self.saved_state is not None: + raise RuntimeError( + "This context manager instance is already activated." + " Use a different context manager instance for context nesting." + ) + self.saved_state = self.is_enabled() + torch._C._set_check_sparse_tensor_invariants(self.state) + + def __exit__(self, type, value, traceback): + assert self.saved_state is not None + torch._C._set_check_sparse_tensor_invariants(self.saved_state) + self.saved_state = None + + # decorator support + def __call__(self, mth): + def test_mth(*args, **kwargs): + with type(self)(self.state): + return mth(*args, **kwargs) + + return test_mth + + +def as_sparse_gradcheck(gradcheck): + """Decorate function, to extend gradcheck for sparse tensors. + + Decorator for torch.autograd.gradcheck or its functools.partial + variants that extends the gradcheck function with support to input + functions that operate on or/and return sparse tensors. + + The specified gradcheck function itself is guaranteed to operate + on strided tensors only. + + For example: + + >>> gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck) + >>> x = torch.tensor([[0, 1], [2, 3]], dtype=torch.float64).to_sparse_coo().requires_grad_(True) + >>> gradcheck(lambda x: x.to_sparse_csr(), x) + True + """ + + def gradcheck_with_sparse_support(func, inputs, **kwargs): + """ + Create gradcheck with support for sparse tensors. + + Same as :func:`torch.autograd.gradcheck` but with sparse tensors inputs and outputs support. 
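+ Differentiable sparse inputs are converted to a strided representation
+ (their index and value tensors) before the wrapped gradcheck runs, and are
+ reconstructed inside the checked function, so gradcheck itself only ever
+ operates on strided tensors.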
+ """ + masked = kwargs.pop("masked", False) + sparse_layouts = { + torch.sparse_coo, + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + } + sparse_compressed_layouts = { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + } + sparse_block_layouts = {torch.sparse_bsr, torch.sparse_bsc} + STRIDED_REPRESENTATION = "__STRIDED_REPRESENTATION__" + + def convert_to_strided_representation(args): + """Convert differentiable non-strided tensors to a representation containing differentiable strided tensors.""" + if not isinstance(args, (list, tuple)): + args = (args,) + new_args: List[Any] = [] + for obj in args: + if ( + isinstance(obj, torch.Tensor) + and obj.requires_grad + and obj.layout in sparse_layouts + ): + d = dict(layout=obj.layout, shape=obj.shape) + if not masked: + # Materialize unspecified elements with zero values + batch_dim = obj.ndim - obj.dense_dim() - obj.sparse_dim() + blocksize = ( + obj.values().shape[batch_dim + 1 : batch_dim + 3] + if obj.layout in sparse_block_layouts + else None + ) + full_mask = torch.ones( + obj.shape, device=obj.device, dtype=torch.bool + ).to_sparse( + layout=obj.layout, + blocksize=blocksize, + dense_dim=obj.dense_dim(), + ) + obj = obj.to_dense().sparse_mask(full_mask) + if obj.layout is torch.sparse_coo: + d.update( + indices=obj._indices(), is_coalesced=obj.is_coalesced() + ) + values = obj._values() + elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}: + d.update( + compressed_indices=obj.crow_indices(), + plain_indices=obj.col_indices(), + ) + values = obj.values() + else: + d.update( + compressed_indices=obj.ccol_indices(), + plain_indices=obj.row_indices(), + ) + values = obj.values() + new_args.extend( + (STRIDED_REPRESENTATION, d, values.requires_grad_(True)) + ) + else: + new_args.append(obj) + return tuple(new_args) + + def restore_from_strided_representation(args): + """Restore non-strided differentiable tensosr from their strided representations.""" + new_args = [] + args = list(args) + while args: + a = args.pop(0) + if a == STRIDED_REPRESENTATION: + d, values = args.pop(0), args.pop(0) + if d["layout"] is torch.sparse_coo: + a = torch.sparse_coo_tensor( + d["indices"], + values, + size=d["shape"], + is_coalesced=d["is_coalesced"], + ) + elif d["layout"] in sparse_compressed_layouts: + a = torch.sparse_compressed_tensor( + d["compressed_indices"], + d["plain_indices"], + values, + size=d["shape"], + layout=d["layout"], + ) + else: + raise NotImplementedError( + f'conversion of {d["layout"]} strided representation to tensor' + ) + new_args.append(a) + return tuple(new_args) + + def func_wrapper(*args, **kwargs): + restored_args = restore_from_strided_representation(args) + + # convert differentiable output sparse tensors to strided + # tensors: + outputs = func(*restored_args, **kwargs) + + strided_outputs = ( + tuple(outputs) if isinstance(outputs, (list, tuple)) else (outputs,) + ) + strided_outputs = tuple( + ( + o.to_dense(masked_grad=masked) + if isinstance(o, torch.Tensor) + and o.requires_grad + and o.layout in sparse_layouts + else o + ) + for o in strided_outputs + ) + + return ( + strided_outputs + if isinstance(outputs, (list, tuple)) + else strided_outputs[0] + ) + + args = (func_wrapper, convert_to_strided_representation(inputs)) + + return gradcheck(*args, **kwargs) + + return gradcheck_with_sparse_support diff --git a/lib/python3.10/site-packages/torch/sparse/_semi_structured_conversions.py 
b/lib/python3.10/site-packages/torch/sparse/_semi_structured_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..0828355202b576afb37584b8cb2e3b2c02e54367 --- /dev/null +++ b/lib/python3.10/site-packages/torch/sparse/_semi_structured_conversions.py @@ -0,0 +1,356 @@ +# mypy: allow-untyped-defs +import torch + + +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): + """ + This is PyTorch implementation of main part of reorder_meta() + function, from tools/util/include/cutlass/util/host_reorder.h file + of CUTLASS source tree. Furthermore, CUTLASS template for sparse + GEMM decides upon layout of this matrix, and at the moment for the + sparse GEMM executed on tensor cores, this is layout described by + ColumnMajorInterleaved<2> data structure, in + include/cutlass/layout/matrix.h of CUTLASS source tree. The + reordering of meta matrix into meta_reordered matrix calculated + according to these segments of CUTLASS code is re-implemented here. + Note that this calculation produces offsets for scattering metadata + matrix elements into reordered metadata matrix elements (or, + equivalently, for gathering reordered metadata matrix element back + into metadata matrix elements). + """ + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group = 32 if meta_dtype.itemsize == 2 else 16 + interweave = 4 if meta_dtype.itemsize == 2 else 2 + dst_rows = ( + dst_rows // group * group + + (dst_rows % 8) * interweave + + (dst_rows % group) // 8 + ) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) + + +def sparse_semi_structured_from_dense_cutlass(dense): + """ + This function converts dense matrix into sparse semi-structured + representation, producing "compressed" matrix, in the layout used by + CUTLASS backend, and corresponding metadata matrix. 
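+ The compressed matrix keeps half of the input values (shape ``(m, k // 2)``
+ for a dense input of shape ``(m, k)``), and the metadata is returned as an
+ ``(m, meta_ncols)`` tensor of int16 or int32 elements encoding the positions
+ of the kept values.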
+ """ + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError("Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16" + ) + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32" + ) + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. 
+ + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + ) + elif quadbits_per_meta_elem == 8: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28) + ) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty((m * meta_ncols,)) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + """ + This function performs reverse of the function above - it + reconstructs dense matrix from a pair of "compressed" matrix, given + in the layout used by CUTLASS backend, and accompanying metadata + matrix. + """ + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + if sparse.dtype != torch.float: + ksparse = 4 + else: + ksparse = 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " + "expected according to the number of columns of meta matrix" + ) + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. 
Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4 + ).view(-1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + dense.scatter_(0, dense_offsets, sparse.view(-1)) + else: + dense.view(torch.half).scatter_( + 0, dense_offsets, sparse.view(torch.half).view(-1) + ) + + return dense.view(m, 2 * k) + + +def _sparse_semi_structured_tile(dense): + """ + This function computes a 2:4 sparse tile by greedily taking the largest values. + + Since we take the largest values greedily, how the sorting algorithm handles duplicates affects + the ultimate sparsity pattern. + + Note that this function does not have the same sorting semantics as our CUDA backend, + which is exposed via `torch._sparse_semi_structured_tile` and thus returns a different pattern. + """ + + def greedy_prune_tile(tile): + num_kept_row = [0, 0, 0, 0] + num_kept_col = [0, 0, 0, 0] + + for x in tile.flatten().sort(descending=True, stable=True).indices: + r, c = x // 4, x % 4 + if num_kept_row[r] < 2 and num_kept_col[c] < 2: + num_kept_row[r] += 1 + num_kept_col[c] += 1 + else: + tile[r, c] = 0 + + for batch in dense.unfold(0, 4, 4).unfold(1, 4, 4): + for tile in batch: + greedy_prune_tile(tile) + + return dense + + +def _compute_compressed_swizzled_bitmask(dense): + """ + Calculates the compressed swizzled bitmask from a dense tensor + """ + + # first we need to convert the dense tensor to a bitmask + int_bitmask = dense.bool().to(torch.uint8) + + # Each thread is responsible for an 8x8 tile, which contains 4 4x4 tiles: + # A, B, C and D, as displayed in the following schema: + # +---+---+ + # | A | B | + # +---+---+ + # | C | D | + # +---+---+ + + # we first need to split into the 8x8 tiles + bitmask_8x8_chunks = int_bitmask.unfold(0, 8, 8).unfold(1, 8, 8) + + # then we unfold again to get our indivdual 4x4 tiles + bitmask_4x4_chunks = bitmask_8x8_chunks.unfold(2, 4, 4).unfold(3, 4, 4) + + # Each 4x4 bitmask defines two 8-bit integers, which encode the sparsity pattern + # of that tile. 
Note that the least siginificant bit is stored first. + # [1 1 0 0] + # [1 1 0 0] -> 0011 0011 -> 51 + # [0 0 1 1] 1100 1100 204 + # [0 0 1 1] + + # reshape tensor to expand tiles into 8-bit vectors + bitmask_binary_representation = bitmask_4x4_chunks.reshape( + *bitmask_4x4_chunks.shape[:2], 4, 2, 8 + ) + + # to convert from binary representaiton, we can do a matmul with powers of two + powers_of_two = 2 ** torch.arange(8, dtype=torch.float, device="cuda") + # To run on GPU: cast to float to do matmul and then cast back + compressed_swizzled_bitmask = ( + bitmask_binary_representation.to(torch.float) @ powers_of_two + ).to(torch.uint8) + + return compressed_swizzled_bitmask diff --git a/lib/python3.10/site-packages/torch/sparse/_semi_structured_ops.py b/lib/python3.10/site-packages/torch/sparse/_semi_structured_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d7c76d8d8be8a9566a8b98fc38a414d61f1171d3 --- /dev/null +++ b/lib/python3.10/site-packages/torch/sparse/_semi_structured_ops.py @@ -0,0 +1,168 @@ +# mypy: allow-untyped-defs +import contextlib + +import torch + + +__all__ = [ + "fallback_dispatcher", + "semi_sparse_values", + "semi_sparse_indices", + "semi_sparse_t", + "semi_sparse_view", + "semi_sparse_detach", + "semi_sparse_mm", + "semi_sparse_addmm", + "semi_sparse_linear", +] + + +@contextlib.contextmanager +def no_dispatch(): + guard = torch._C._DisableTorchDispatch() + try: + yield + finally: + del guard + + +def fallback_dispatcher(func, types, args, kwargs): + with no_dispatch(): + return func(*args) + + +def semi_sparse_values(func, types, args=(), kwargs=None) -> torch.Tensor: + assert len(args) == 1 + A = args[0] + assert isinstance(A, torch.sparse.SparseSemiStructuredTensor) + assert A.packed is not None + if A.meta is None: + m, k = A.shape + num_kept_elements = m * k // 2 + return A.packed[:num_kept_elements:].view(m, -1) + else: + return A.packed.detach() + + +def semi_sparse_indices(func, types, args=(), kwargs=None) -> torch.Tensor: + assert len(args) == 1 + A = args[0] + assert isinstance(A, torch.sparse.SparseSemiStructuredTensor) + assert A.packed is not None + if A.meta is None: + m, k = A.shape + num_kept_elements = m * k // 2 + metadata = A.packed[num_kept_elements:].view(m, -1) + return metadata.view(torch.int32 if A.dtype == torch.int32 else torch.int16) + else: + return A.meta + + +def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor: + assert len(args) == 1 + self = args[0] + assert isinstance(self, torch.sparse.SparseSemiStructuredTensor) + assert len(self.shape) == 2 + # Because we cannot go from the compressed representation back to the dense representation currently, + # we just keep track of how many times we have been transposed. Depending on whether the sparse matrix + # is the first or second argument, we expect an even / odd number of calls to transpose respectively. 
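+    # A rough illustration of this bookkeeping (see semi_sparse_mm further
+    # below in this file): torch.mm(dense, sparse) is lowered approximately
+    # as sparse.t()._mm(dense.t()).t(), so the sparse operand is transposed
+    # an odd number of times, whereas torch.mm(sparse, dense) uses the
+    # sparse operand directly, i.e. an even (zero) number of transposes.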
+ return self.__class__( + torch.Size([self.shape[-1], self.shape[0]]), + packed=self.packed_t, + meta=self.meta_t, + packed_t=self.packed, + meta_t=self.meta, + compressed_swizzled_bitmask=self.compressed_swizzled_bitmask.transpose(0, 1) + if self.compressed_swizzled_bitmask is not None + else None, + fuse_transpose_cusparselt=args[0].fuse_transpose_cusparselt, + alg_id_cusparselt=args[0].alg_id_cusparselt, + ) + + +def semi_sparse_view(func, types, args=(), kwargs=None) -> torch.Tensor: + assert len(args) == 2 + self, shape = args + if tuple(shape) != self.shape: + raise NotImplementedError( + f"`view` is not implemented for SparseSemiStructuredTensor, except for the dummy case (shape={shape})" + ) + return self + + +def semi_sparse_detach(func, types, args, kwargs) -> torch.Tensor: + assert len(args) == 1 + self = args[0] + return self.__class__( + shape=self.shape, + packed=self.packed, + meta=self.meta, + packed_t=self.packed_t, + meta_t=self.meta_t, + compressed_swizzled_bitmask=self.compressed_swizzled_bitmask, + requires_grad=False, + ) + + +def semi_sparse_mm(func, types, args=(), kwargs=None) -> torch.Tensor: + assert len(args) == 2 + A, B = args + if A.ndim != 2 or B.ndim != 2: + raise NotImplementedError( + "`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented" + ) + if isinstance(A, torch.sparse.SparseSemiStructuredTensor): + row, col = B.shape + B_padded = A._pad_dense_input(B) + res = A._mm(B_padded) + return res[:, :col] + else: + B_t = B.t() + assert isinstance(B_t, torch.sparse.SparseSemiStructuredTensor) + row, col = A.shape + A_padded = B._pad_dense_input(A) + res = B_t._mm(A_padded.t()).t() + return res[:row, :] + + +def semi_sparse_addmm(func, types, args=(), kwargs=None) -> torch.Tensor: + assert len(args) == 3 + bias, A, B = args + if A.ndim != 2 or B.ndim != 2: + raise NotImplementedError( + "`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented" + ) + if bias.ndim != 1: + raise NotImplementedError( + f"`SparseSemiStructuredTensor` matmul: only bias dim=1 supported. 
Shape={bias.shape}" + ) + if isinstance(A, torch.sparse.SparseSemiStructuredTensor): + raise NotImplementedError( + "`SparseSemiStructuredTensor` matmul: only operand B of `addmm` can be sparse" + ) + B_t = B.t() + assert isinstance(B_t, torch.sparse.SparseSemiStructuredTensor) + row, col = A.shape + A_padded = B_t._pad_dense_input(A) + result = B_t._mm(A_padded.t(), bias=bias).t() + return result[:row, :] + + +def semi_sparse_linear(func, types, args=(), kwargs=None) -> torch.Tensor: + assert len(args) in [2, 3] + A, B = args[:2] + bias = args[2] if len(args) == 3 else None + + shape = A.shape + A_2d = A.view(-1, shape[-1]) + + if bias is None: + res = A_2d @ B.t() + else: + res = semi_sparse_addmm( + func=None, + types=None, + args=[bias, A_2d, B.t()], + ) + + return res.view(*shape[:-1], -1) diff --git a/lib/python3.10/site-packages/torch/sparse/_triton_ops.py b/lib/python3.10/site-packages/torch/sparse/_triton_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..091e91d37f604a7435eb2a59e4a2ab88416a8824 --- /dev/null +++ b/lib/python3.10/site-packages/torch/sparse/_triton_ops.py @@ -0,0 +1,2415 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import math +import os +import weakref +from functools import lru_cache +from typing import Optional, Tuple + +import torch +from torch._dynamo.utils import warn_once +from torch.utils._triton import has_triton + +from ._triton_ops_meta import get_meta + + +TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE = int( + os.getenv("TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE", 2) +) + + +def check(cond, msg): + if not cond: + raise ValueError(msg) + + +def check_bsr_layout(f_name, t): + check( + t.layout == torch.sparse_bsr, + f"{f_name}(): only BSR sparse format is supported for the sparse argument.", + ) + + +def check_device(f_name, t, device): + check( + t.device == device and t.device.type == "cuda", + f"{f_name}(): all inputs are expected to be on the same GPU device.", + ) + + +def check_mm_compatible_shapes(f_name, lhs, rhs): + check( + lhs.dim() >= 2 and rhs.dim() >= 2, + f"{f_name}(): all inputs involved in the matrix product are expected to be at least 2D, " + f"but got lhs.dim() == {lhs.dim()} and rhs.dim() == {rhs.dim()}.", + ) + + m, kl = lhs.shape[-2:] + kr, n = rhs.shape[-2:] + + check( + kl == kr, + f"{f_name}(): arguments' sizes involved in the matrix product are not compatible for matrix multiplication, " + f"got lhs.shape[-1] == {kl} which is not equal to rhs.shape[-2] == {kr}.", + ) + + +def check_dtype(f_name, t, dtype, *additional_dtypes): + check( + t.dtype == dtype + and t.dtype + in ((torch.half, torch.bfloat16, torch.float) + tuple(*additional_dtypes)), + f"{f_name}(): all inputs are expected to be of the same dtype " + f"and one of (half, bfloat16, float32) or {additional_dtypes}, " + f"but got dtype == {t.dtype}.", + ) + + +def check_blocksize(f_name, blocksize): + assert len(blocksize) == 2 + + def is_power_of_two(v): + return not (v & (v - 1)) + + def is_compatible_blocksize(b): + res = True + for blocksize in b: + # Triton loads only blocks which are at least 16 and powers of 2. + res = (blocksize >= 16 and is_power_of_two(blocksize)) and res + return res + + check( + is_compatible_blocksize(blocksize), + f"{f_name}(): sparse inputs' blocksize ({blocksize[0]}, {blocksize[1]}) " + "should be at least 16 and a power of 2 in each dimension.", + ) + + +def make_triton_contiguous(t): + """Return input as a triton-contiguous tensor. 
+ + A triton-contiguous tensor is defined as a tensor that has strides + with minimal value equal to 1. + + While triton kernels support triton-non-contiguous tensors (all + strides being greater than 1 or having 0 strides) arguments, a + considerable slow-down occurs because tensor data is copied + element-wise rather than chunk-wise. + """ + if min(t.stride()) != 1: + # TODO: investigate if contiguity along other axes than the + # last one can be beneficial for performance + return t.contiguous() + else: + return t + + +def broadcast_batch_dims(f_name, *tensors): + try: + return torch.broadcast_shapes(*(t.shape[:-2] for t in tensors)) + except Exception: + check(False, f"{f_name}(): inputs' batch dimensions are not broadcastable!") + + +def slicer(dim, slice_range, *tensors): + for t in tensors: + slices = [slice(None)] * t.dim() + slices[dim] = slice_range + yield t[slices] + + +def multidim_slicer(dims, slices, *tensors): + for t in tensors: + s = [slice(None)] * t.dim() + for d, d_slice in zip(dims, slices): + if d is not None: + s[d] = d_slice + yield t[s] + + +def ptr_stride_extractor(*tensors): + for t in tensors: + yield t + yield from t.stride() + + +def grid_partitioner(full_grid, grid_blocks, tensor_dims_map): + assert 0 <= len(full_grid) <= 3 + assert 0 <= len(grid_blocks) <= 3 + + import itertools + + def generate_grid_points(): + for fg, mg in zip(full_grid, grid_blocks): + yield range(0, fg, mg) + + def generate_sliced_tensors(slices): + for t, t_dims in tensor_dims_map.items(): + yield next(multidim_slicer(t_dims, slices, t)) + + for grid_point in itertools.product(*generate_grid_points()): + grid = [ + min(fg - gp, mg) for fg, gp, mg in zip(full_grid, grid_point, grid_blocks) + ] + slices = [slice(gp, gp + g) for gp, g in zip(grid_point, grid)] + # grid_points are iterated in a "contiguous" order, i.e. + # left dimensions traversed slower than right dimensions. + # This order is reversed for CUDA grids. + yield grid[::-1], *generate_sliced_tensors(slices) + + +def launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks=None): + # cuda_max_grid = (2 ** 31 - 1, 2 ** 16 - 1, 2 ** 16 - 1) + cuda_max_grid = (2147483647, 65535, 65535)[::-1] + if grid_blocks is None: + grid_blocks = cuda_max_grid + else: + + def valid_grid_dim(g, mg): + if g is None: + return mg + else: + # grid must be at least 1 and no greater than mg + return max(1, min(g, mg)) + + grid_blocks = tuple( + valid_grid_dim(g, mg) for g, mg in zip(grid_blocks, cuda_max_grid) + ) # type: ignore[assignment] + + for grid, *sliced_tensors in grid_partitioner( + full_grid, grid_blocks, tensor_dims_map + ): + kernel(grid, *sliced_tensors) + + +def prepare_inputs(bsr, *dense_tensors): + # Introduce fake batch dimension if not present for convenience. + crow_indices = bsr.crow_indices().unsqueeze(0) + col_indices = bsr.col_indices().unsqueeze(0) + values = make_triton_contiguous(bsr.values().unsqueeze(0)) + tensors = [make_triton_contiguous(t.unsqueeze(0)) for t in dense_tensors] + + # Compute broadcasted batch dimension + batch_dims_broadcasted = torch.broadcast_shapes( + values.shape[:-3], *(t.shape[:-2] for t in tensors) + ) + + # Broadcast batch dimensions and squash. + # The result can be either a view or a copy. 
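+    # Illustrative shape example (the concrete sizes are assumptions for
+    # exposition): with values of shape (1, nnz, BM, BK) and a dense
+    # operand of shape (2, 3, K, N) before the unsqueeze above,
+    # batch_dims_broadcasted == (1, 2, 3), so after broadcasting and
+    # flattening every operand carries a single squashed batch dimension
+    # of size 6.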
+ def batch_broadcast_and_squash(t, batch_dims, invariant_dims): + return t.broadcast_to(batch_dims + invariant_dims).flatten( + 0, len(batch_dims) - 1 + ) + + crow_indices = batch_broadcast_and_squash( + crow_indices, batch_dims_broadcasted, (-1,) + ) + + col_indices = batch_broadcast_and_squash(col_indices, batch_dims_broadcasted, (-1,)) + values = batch_broadcast_and_squash( + values, batch_dims_broadcasted, values.shape[-3:] + ) + tensors = [ + batch_broadcast_and_squash(t, batch_dims_broadcasted, t.shape[-2:]) + for t in tensors + ] + + return crow_indices, col_indices, values, *tensors + + +def broadcast_batch_dims_bsr(f_name, bsr, *tensors): + batch_shape = broadcast_batch_dims(f_name, bsr, *tensors) + + crow_indices = bsr.crow_indices().broadcast_to(batch_shape + (-1,)) + col_indices = bsr.col_indices().broadcast_to(batch_shape + (-1,)) + values = bsr.values().broadcast_to(batch_shape + bsr.values().shape[-3:]) + size = batch_shape + bsr.shape[-2:] + return torch.sparse_compressed_tensor( + crow_indices, col_indices, values, size=size, layout=bsr.layout + ) + + +# NOTE: this function will ALWAYS create a view +def tile_to_blocksize(t, blocksize): + *rest, m, n = t.shape + new_shape = rest + [ + m // blocksize[0], + blocksize[0], + n // blocksize[1], + blocksize[1], + ] + # using .view instead of .reshape to ensure that the result is + # indeed a view: + return t.view(new_shape).transpose(-3, -2) + + +def as1Dbatch(tensor): + """Return tensor as 3D tensor by either prepending new dimensions to + the tensor shape (when ``tensor.ndim < 3``), or by collapsing + starting dimensions into the first dimension (when ``tensor.ndim > + 3``). + """ + while tensor.ndim < 3: + tensor = tensor.unsqueeze(0) + if tensor.ndim > 3: + tensor = tensor.flatten(0, tensor.ndim - 3) + assert tensor.ndim == 3, tensor.shape + return tensor + + +def scatter_mm(blocks, others, indices_data, *, accumulators=None): + """Scattered matrix multiplication of tensors. + + A scattered matrix multiplication is defined as a series of matrix + multiplications applied to input tensors according to the input + and output mappings specified by indices data. + + The following indices data formats are supported for defining a + scattered matrix multiplication operation (:attr:`indices_data[0]` + holds the name of the indices data format as specified below): + + - ``"scatter_mm"`` - matrix multiplications scattered in batches + of tensors. + + If :attr:`blocks` is a :math:`(* \times M \times K) tensor, + :attr:`others` is a :math:`(* \times K \times N)` tensor, + :attr:`accumulators` is a :math:`(* \times M \times N)` tensor, + and :attr:`indices = indices_data['indices']` is a :math:`(* + \times 3)` tensor, then the operation is equivalent to the + following code:: + + c_offsets, pq = indices_data[1:] + for r in range(len(c_offsets) - 1): + for g in range(c_offsets[r], c_offsets[r + 1]): + p, q = pq[g] + accumulators[r] += blocks[p] @ others[q] + + - ``"bsr_strided_mm"`` - matrix multiplications scattered in + batches of tensors and a tensor. 
+ + If :attr:`blocks` is a :math:`(Ms \times Ks) tensor, + :attr:`others` is a :math:`(* \times K \times N)` tensor, + :attr:`accumulators` is a :math:`(* \times M \times N)` tensor, then + the operation is equivalent to the following code:: + + c_indices, r_offsets, p_offsets, q_offsets, meta = indices_data[1:] + for b in range(nbatches): + for i, r in enumerate(r_offsets): + r0, r1 = divmod(r, N) + acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns] + for g in range(c_indices[i], c_indices[i+1]): + p = p_offsets[g] + q0, q1 = divmod(q_offsets[g], N) + acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns] + + where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are + integer multiples of ``Ms`` and ``Ks``, respectively. + + - ``"bsr_strided_mm_compressed"`` - matrix multiplications + scattered in batches of tensors and a tensor. A memory and + processor efficient version of ``"bsr_strided_mm"`` format. If + :attr:`blocks` is a :math:`(Ms \times Ks) tensor, :attr:`others` + is a :math:`(* \times K \times N)` tensor, :attr:`accumulators` + is a :math:`(* \times M \times N)` tensor, then the operation is + equivalent to the following code:: + + c_indices, r_offsets, q_offsets, meta = indices_data[1:] + for b in range(nbatches): + for r in r_offsets: + m = (r // N) // Ms + n = (r % N) // Ns + r0, r1 = divmod(r, N) + c0, c1 = c_indices[m], c_indices[m + 1] + acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns] + for i, p in enumerate(range(c0, c1)): + q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i] + q0, q1 = divmod(q, N) + acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns] + + where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are + integer multiples of ``Ms`` and ``Ks``, respectively. + + Notice that the order of ``r_offsets`` items can be arbitrary; + this property enables defining swizzle operators via + rearrangements of ``r_offsets`` items.. + + Auxilary functions are provided for pre-computing + :attr:`indices_data`. For example, + :func:`bsr_scatter_mm_indices_data` is used to define indices data + for matrix multiplication of BSR and strided tensors. + + Parameters + ---------- + blocks (Tensor): a 3-D tensor of first matrices to be multiplied + + others (Tensor): a tensor of second matrices to be multiplied. If + ``indices_data[0]=="scatter_mm"``, the tensor is a 1-D batch + tensor of second input matrices to be multiplied. Otherwise, the + second input matrices are slices of the :attr:`others` tensor. + indices_data (tuple): a format data that defines the inputs and + outputs of scattered matrix multiplications. + + Keyword arguments + ----------------- + + accumulators (Tensor, optional): a tensor of matrix product + accumulators. If ``indices_data[0]=="scatter_mm"``, the tensor + is a 1-D batch tensor of output matrices. Otherwise, output + matrices are slices of the :attr:`accumulators` tensor. 
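+
+    A minimal illustrative sketch of the ``"scatter_mm"`` format (the
+    shapes and index values below are made up for exposition; a CUDA
+    device with triton available is assumed)::
+
+        blocks = torch.randn(4, 16, 16, device="cuda")   # P = 4, Ms = Ks = 16
+        others = torch.randn(3, 16, 16, device="cuda")   # Q = 3, Ns = 16
+        c_offsets = torch.tensor([0, 2, 3], device="cuda")
+        pq = torch.tensor([[0, 1], [2, 0], [3, 2]], device="cuda")
+        out = scatter_mm(blocks, others, ("scatter_mm", c_offsets, pq))
+        # out[0] == blocks[0] @ others[1] + blocks[2] @ others[0]
+        # out[1] == blocks[3] @ others[2]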
+ """ + indices_format = indices_data[0] + + assert blocks.ndim == 3 + P, Ms, Ks = blocks.shape + + if indices_format == "scatter_mm": + c_offsets, pq = indices_data[1:] + + assert others.ndim == 3 + Q, Ks_, Ns = others.shape + assert Ks == Ks_ + + if accumulators is None: + R = c_offsets.shape[0] - 1 + accumulators = torch.zeros( + (R, Ms, Ns), dtype=blocks.dtype, device=blocks.device + ) + else: + R, Ms_, Ns_ = accumulators.shape + assert Ms_ == Ms + assert Ns_ == Ns + + if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm2 is None: + for r in range(c_offsets.shape[0] - 1): + g0 = c_offsets[r] + g1 = c_offsets[r + 1] + for g in range(g0, g1): + p, q = pq[g] + accumulators[r] += blocks[p] @ others[q] + else: + _scatter_mm2(blocks, others, c_offsets, pq, accumulators) + return accumulators + + elif indices_format == "bsr_strided_mm": + others_shape = others.shape + others = as1Dbatch(others) + + B, K, N = others.shape + assert K % Ks == 0 + + c_indices, r_offsets, p_offsets, q_offsets, meta = indices_data[1:] + SPLIT_N = meta["SPLIT_N"] + + if accumulators is None: + M = Ms + (r_offsets.max().item() + 1) // N + accumulators = torch.zeros( + (*others_shape[:-2], M, N), dtype=blocks.dtype, device=blocks.device + ) + else: + M, N_ = accumulators.shape[-2:] + assert N_ == N + + accumulators_shape = accumulators.shape + accumulators = as1Dbatch(accumulators) + + Ns = N // SPLIT_N + + if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm6 is None: + accumulators.zero_() + for b in range(B): + for r in range(r_offsets.shape[0]): + r_ = r_offsets[r].item() + g0 = c_indices[r].item() + g1 = c_indices[r + 1].item() + r0, r1 = divmod(r_, N) + acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns] + for g in range(g0, g1): + p, q = p_offsets[g], q_offsets[g] + q0, q1 = divmod(q.item(), N) + acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns] + else: + _scatter_mm6( + blocks, + others, + c_indices, + r_offsets, + p_offsets, + q_offsets, + meta, + accumulators, + ) + return accumulators.view(accumulators_shape) + + elif indices_format == "bsr_strided_mm_compressed": + others_shape = others.shape + others = as1Dbatch(others) + + B, K, N = others.shape + assert K % Ks == 0 + + c_indices, r_offsets, q_offsets, meta = indices_data[1:] + SPLIT_N = meta["SPLIT_N"] + + if accumulators is None: + M = Ms + (r_offsets.max().item() + 1) // N + accumulators = torch.zeros( + (*others_shape[:-2], M, N), dtype=blocks.dtype, device=blocks.device + ) + else: + M, N_ = accumulators.shape[-2:] + assert N_ == N + + accumulators_shape = accumulators.shape + accumulators = as1Dbatch(accumulators) + + Ns = N // SPLIT_N + + if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm6 is None: + for b in range(B): + for j in range(len(r_offsets)): + r0, r1 = divmod(r_offsets[j].item(), N) + m = r0 // Ms + n = r1 // Ns + c0 = c_indices[m].item() + c1 = c_indices[m + 1].item() + acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns] + for i, p in enumerate(range(c0, c1)): + q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i].item() + q0, q1 = divmod(q, N) + acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns] + else: + p_offsets = torch.empty( + (0,), dtype=q_offsets.dtype, device=q_offsets.device + ) + _scatter_mm6( + blocks, + others, + c_indices, + r_offsets, + p_offsets, + q_offsets, + meta, + accumulators, + ) + return accumulators.view(accumulators_shape) + + else: + raise NotImplementedError(indices_format) + + +def scatter_mm_meta( + M, + K, + N, + Ms, + Ks, + GROUP_SIZE=None, + TILE_M=None, + TILE_N=None, + SPLIT_N=None, + num_warps=None, + 
num_stages=None, + **extra, +): + if {TILE_M, TILE_N, SPLIT_N, num_warps, num_stages, GROUP_SIZE} == {None}: + device_name = torch.cuda.get_device_name() + meta = get_meta( + "scatter_mm", + (M, K, N, Ms, Ks), + device_name, + version=(0, torch.float16, 0.5), + ) + if meta is not None: + meta.update(**extra) + return meta + # The following parameters are optimized for the performance + # equilibrium points of bsr-dense and dense-dense matrix + # multiplications when using GPU card NVIDIA GeForce RTX 2060 + # SUPER. For points far from the performance equilibrium + # points as well as for other GPU cards, the optimal + # parameters are likely different from what specified below. + if (M, K, N) == (256,) * 3: + if (Ms, Ks) == (16, 16): + SPLIT_N = 1 + TILE_M = 16 + TILE_N = 16 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + elif (Ms, Ks) == (32, 32): + SPLIT_N = 2 + TILE_M = 32 + TILE_N = 16 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + elif (Ms, Ks) == (64, 64): + SPLIT_N = 1 + TILE_M = 32 + TILE_N = 32 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + elif (Ms, Ks) == (128, 128): + SPLIT_N = 1 + TILE_M = 32 + TILE_N = 32 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + elif (M, K, N) == (512,) * 3: + if (Ms, Ks) == (16, 16): + SPLIT_N = 8 + TILE_M = 16 + TILE_N = 64 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 2 # noqa: E225,E231,E702 + elif (Ms, Ks) == (32, 32): + SPLIT_N = 8 + TILE_M = 32 + TILE_N = 64 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 2 # noqa: E225,E231,E702 + elif (Ms, Ks) == (64, 64): + SPLIT_N = 4 + TILE_M = 32 + TILE_N = 128 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + elif (Ms, Ks) == (128, 128): + SPLIT_N = 8 + TILE_M = 64 + TILE_N = 64 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + elif (M, K, N) == (1024,) * 3: + if (Ms, Ks) == (16, 16): + SPLIT_N = 4 + TILE_M = 16 + TILE_N = 128 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 1 # noqa: E225,E231,E702 + elif (Ms, Ks) == (32, 32): + SPLIT_N = 8 + TILE_M = 32 + TILE_N = 64 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 1 # noqa: E225,E231,E702 + elif (Ms, Ks) == (64, 64): + SPLIT_N = 16 + TILE_M = 64 + TILE_N = 64 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 2 # noqa: E225,E231,E702 + elif (Ms, Ks) == (128, 128): + SPLIT_N = 16 + TILE_M = 64 + TILE_N = 64 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + elif (Ms, Ks) == (256, 256): + SPLIT_N = 16 + TILE_M = 64 + TILE_N = 64 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + elif (M, K, N) == (2048,) * 3: + if (Ms, Ks) == (16, 16): + SPLIT_N = 4 + TILE_M = 16 + TILE_N = 128 + GROUP_SIZE = 8 + num_stages = 1 + num_warps = 1 # noqa: E225,E231,E702 + elif (Ms, Ks) == (32, 32): + SPLIT_N = 4 + TILE_M = 32 + TILE_N = 64 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 1 # noqa: E225,E231,E702 + elif (Ms, Ks) == (64, 64): + SPLIT_N = 4 + TILE_M = 64 + TILE_N = 128 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + elif (Ms, Ks) == (128, 128): + SPLIT_N = 8 + TILE_M = 64 + TILE_N = 64 + GROUP_SIZE = 4 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + elif (Ms, Ks) == (256, 256): + SPLIT_N = 4 + TILE_M = 64 + TILE_N = 64 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + elif (M, K, N) == (4096,) * 3: + if (Ms, Ks) == (16, 16): + SPLIT_N = 2 + TILE_M = 16 + TILE_N = 256 + GROUP_SIZE = 2 + num_stages = 1 
+ num_warps = 2 # noqa: E225,E231,E702 + elif (Ms, Ks) == (32, 32): + SPLIT_N = 2 + TILE_M = 32 + TILE_N = 64 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 1 # noqa: E225,E231,E702 + elif (Ms, Ks) == (64, 64): + SPLIT_N = 2 + TILE_M = 64 + TILE_N = 128 + GROUP_SIZE = 2 + num_stages = 1 + num_warps = 4 # noqa: E225,E231,E702 + + if SPLIT_N is None: + # Assume NVIDIA GeForce RTX 2060 SUPER: + # With the probality of 92% (99.9% when N > 512), the + # performance will not be worse more than 2% from the + # performance when using an optimal value. Otherwise, when N + # <= 512, using the following heuristics may give upto 15% + # lower performance. + SPLIT_N = { + 16: 1, + 32: 2, + 64: 4, + 128: 8, + 256: 16, + 512: 8, + 1024: 16, + 4096: 32, + 8192: 64, + }.get(N, 16) + if Ms >= 512 and N >= 2048: + SPLIT_N = 1 + Ns = N // SPLIT_N + if TILE_M is None: + TILE_M = min(64 if Ns < 512 else 32, Ms) + if TILE_N is None: + TILE_N = min(64 if Ns < 512 else 32, Ns) + num_stages = num_stages or 1 + if num_warps is None: + if min(M, N) > 1024: + num_warps = {16: 1, 32: 1, 64: 2}.get(Ms, 4) + elif min(M, N) == 1024: + num_warps = {16: 1, 32: 1, 64: 2}.get(Ms, 4) + elif min(M, N) == 256: + num_warps = {16: 1, 32: 4}.get(Ms, 4) + else: + num_warps = {16: 1, 32: 2}.get(Ms, 4) + GROUP_SIZE = GROUP_SIZE or 4 + + assert TILE_M <= Ms, dict(TILE_M=TILE_M, Ms=Ms) + assert TILE_N <= Ns, dict(TILE_N=TILE_N, Ns=Ns) + assert Ms <= M, dict(M=M, Ms=Ms) + assert Ns <= N, dict(N=N, Ns=Ns) + assert Ks <= K, dict(K=K, Ks=Ks) + + return dict( + TILE_M=TILE_M, + TILE_N=TILE_N, + GROUP_SIZE=GROUP_SIZE, + num_stages=num_stages, + num_warps=num_warps, + SPLIT_N=SPLIT_N, + **extra, + ) + + +def bsr_dense_addmm_meta( + M, + K, + N, + Ms, + Ks, + beta, + alpha, + SPLIT_N=None, + GROUP_SIZE_ROW=None, + num_warps=None, + num_stages=None, + sparsity=None, + dtype=None, + _version=0, + **extra, +): + # Specifying _version is useful for situations when one wants to + # discard existing triton kernel tuning results, say, in testing + # bsr_dense_addmm_meta functionality. + if dtype is None: + dtype = torch.float16 + if sparsity is None: + sparsity = 0.5 + if {SPLIT_N, num_warps, num_stages, GROUP_SIZE_ROW} == {None}: + device_name = torch.cuda.get_device_name() + key = (M, K, N, Ms, Ks, beta == 0, beta == 1, alpha == 1) + meta = get_meta( + "bsr_dense_addmm", key, device_name, version=(_version, dtype, sparsity) + ) + if meta is None and sparsity != 0.5: + meta = get_meta( + "bsr_dense_addmm", key, device_name, version=(_version, dtype, 0.5) + ) + if meta is None: + # find approximate meta such that N % SPLIT_N == 0. 
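+            # A hypothetical illustration of the reuse rule below: if a
+            # tuned entry exists for n = 4096 with SPLIT_N = 8, its column
+            # tile is c = 4096 // 8 = 512; a request with N = 6144 satisfies
+            # N % c == 0 and n <= N, so that entry is reused with SPLIT_N
+            # adjusted to 6144 // 512 = 12.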
+ matching_meta = get_meta( + "bsr_dense_addmm", + (*key[:2], "*", *key[3:]), + device_name, + version=(_version, dtype, 0.5), + ) + for mkey in sorted(matching_meta or {}): + meta_ = matching_meta[mkey] + n = mkey[2] + split_n = meta_["SPLIT_N"] + c = n // split_n + if N % c == 0 and n <= N: + meta = dict(meta_) + meta["SPLIT_N"] = N // c + if meta is not None: + meta.update(**extra) + return meta + else: + # see [Computing optimal kernel parameters] in + # _triton_ops_meta.py for ways to avoid this warning + # message + warn_once( + f"bsr_dense_addmm uses non-optimal triton kernel parameters for {M=} {K=} {N=} {Ms=}, {Ks=} {beta=} {alpha=}" + ) + + SPLIT_N = SPLIT_N or max(N // Ms, 1) + GROUP_SIZE_ROW = GROUP_SIZE_ROW or 4 + num_stages = num_stages or 1 + num_warps = num_warps or 4 + return dict( + SPLIT_N=SPLIT_N, + GROUP_SIZE_ROW=GROUP_SIZE_ROW, + num_stages=num_stages, + num_warps=num_warps, + **extra, + ) + + +class TensorAsKey: + """A light-weight wrapper of a tensor that enables storing tensors as + keys with efficient memory reference based comparision as an + approximation to data equality based keys. + + Motivation: the hash value of a torch tensor is tensor instance + based that does not use data equality and makes the usage of + tensors as keys less useful. For instance, the result of + ``len({a.crow_indices(), a.crow_indices()})`` is `2`, although, + the tensor results from `crow_indices` method call are equal, in + fact, these share the same data storage. + On the other hand, for efficient caching of tensors we want to + avoid calling torch.equal that compares tensors item-wise. + + TensorAsKey offers a compromise in that it guarantees key equality + of tensors that references data in the same storage in the same + manner and without accessing underlying data. However, this + approach does not always guarantee correctness. For instance, for + a complex tensor ``x``, we have ``TensorAsKey(x) == + TensorAsKey(x.conj())`` while ``torch.equal(x, x.conj())`` would + return False. + """ + + def __init__(self, obj): + def get_tensor_key(obj): + # Warning: TensorAsKey does not track negative nor + # conjugate bits of its input object because in the use + # case of wrapping compressed/plain indices of compressed + # sparse tensors (that are always integer tensors with + # non-negative items) these bits are never set. However, + # when extending the use of TensorAsKey to float or + # complex tensors, the values of these bits (see is_neg + # and is_conj methods) must be included in the key as + # well. 
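+            # For instance (an illustration restating the class docstring):
+            # a = bsr.crow_indices() and b = bsr.crow_indices() are distinct
+            # tensor objects, so hash(a) != hash(b) in general, yet
+            # get_tensor_key(a) == get_tensor_key(b) because both views share
+            # the same data pointer, storage offset, shape, strides and dtype.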
+ assert not (obj.dtype.is_floating_point or obj.dtype.is_complex), obj.dtype + return ( + obj.data_ptr(), + obj.storage_offset(), + obj.shape, + obj.stride(), + obj.dtype, + ) + + self._obj_ref = weakref.ref(obj) + if obj.layout is torch.strided: + self.key = get_tensor_key(obj) + elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}: + self.key = ( + get_tensor_key(obj.crow_indices()), + get_tensor_key(obj.col_indices()), + ) + elif obj.layout in {torch.sparse_csc, torch.sparse_bsc}: + self.key = ( + get_tensor_key(obj.ccol_indices()), + get_tensor_key(obj.row_indices()), + ) + else: + raise NotImplementedError(obj.layout) + self._hash = hash(self.key) + + def __hash__(self): + return self._hash + + def __eq__(self, other): + if not isinstance(other, TensorAsKey): + return False + if self.obj is None or other.obj is None: + # dead objects always compare unequal unless these are + # same objects + return self is other + return self.key == other.key + + @property + def obj(self): + """Return object if alive, otherwise None.""" + return self._obj_ref() + + +@lru_cache(maxsize=TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE) +def _bsr_scatter_mm_indices_data( + indices_format, M, K, N, Ms, Ks, nbatches, SPLIT_N, compressed_sparse_tensor_as_key +): + bsr = compressed_sparse_tensor_as_key.obj + assert bsr is not None + crow_indices, col_indices = bsr.crow_indices(), bsr.col_indices() + device = crow_indices.device + indices_dtype = torch.int32 + + if indices_format == "bsr_strided_mm_compressed": + Ns = N // SPLIT_N + q_offsets_lst = [] + b = torch.arange(SPLIT_N, dtype=indices_dtype, device=device) * Ns + for m in range(M // Ms): + r0 = crow_indices[m].item() + r1 = crow_indices[m + 1].item() + if r1 == r0: + continue + q_offsets_lst.append( + (col_indices[r0:r1] * (Ks * N)).repeat(SPLIT_N) + + b.repeat_interleave(r1 - r0) + ) + q_offsets = torch.cat(q_offsets_lst) + crow_indices_diff = crow_indices.diff() + non_zero_row_indices = crow_indices_diff.nonzero() + a = non_zero_row_indices * (Ms * N) + r_offsets = (a + b).view(-1) + c_indices = crow_indices + # swizzle operation: mm elements with longer sums are computed first: + nnz_per_row = crow_indices_diff[non_zero_row_indices].repeat_interleave(SPLIT_N) + nnz_per_row, indices = nnz_per_row.sort(descending=True, stable=True) + r_offsets = r_offsets[indices] + return (indices_format, c_indices, r_offsets, q_offsets) + + elif indices_format == "bsr_strided_mm": + Ns = N // SPLIT_N + p_offsets_lst = [] + q_offsets_lst = [] + b = torch.arange(SPLIT_N, dtype=indices_dtype, device=device) * Ns + for m in range(M // Ms): + r0 = crow_indices[m].item() + r1 = crow_indices[m + 1].item() + if r1 == r0: + continue + p_offsets_lst.append( + torch.arange(r0, r1, dtype=indices_dtype, device=device).repeat(SPLIT_N) + ) + q_offsets_lst.append( + (col_indices[r0:r1] * (Ks * N)).repeat(SPLIT_N) + + b.repeat_interleave(r1 - r0) + ) + q_offsets = torch.cat(q_offsets_lst) + crow_indices_diff = crow_indices.diff() + non_zero_row_indices = crow_indices_diff.nonzero() + a = non_zero_row_indices * (Ms * N) + r_offsets = (a + b).view(-1) + c_indices = torch.cat( + ( + crow_indices[:1], + torch.cumsum( + crow_indices_diff[non_zero_row_indices].repeat_interleave(SPLIT_N), + 0, + ), + ) + ) + p_offsets = torch.cat(p_offsets_lst) + return (indices_format, c_indices, r_offsets, p_offsets, q_offsets) + + elif indices_format == "scatter_mm": + Ns = Ms + c_indices = [0] + pq_offsets = [] + # todo: eliminate inner for-loops for efficiency + for b in range(nbatches): + for m in 
range(M // Ms): + r0 = crow_indices[m].item() + r1 = crow_indices[m + 1].item() + for n in range(N // Ns): + c_indices.append(c_indices[-1] + r1 - r0) + for t in range(r1 - r0): + p = r0 + t + q = (col_indices[p].item() + b * (K // Ks)) * (N // Ns) + n + pq_offsets.append([p, q]) + + return ( + indices_format, + torch.tensor(c_indices, dtype=indices_dtype, device=device), + torch.tensor(pq_offsets, dtype=indices_dtype, device=device), + ) + + else: + raise ValueError( + f"Invalid {indices_format=}. Expected bsr_strided_mm_compressed|bsr_strided_mm|scatter_mm" + ) + + +def bsr_scatter_mm_indices_data( + bsr, other, indices_format="bsr_strided_mm_compressed", **meta_input +): + """Computes indices data for :func:`scatter_mm` used in BSR and + strided tensor matrix multiplication. + """ + assert bsr.dense_dim() == 0 + assert bsr.ndim == 2 # no batch dims + crow_indices = bsr.crow_indices() + col_indices = bsr.col_indices() + blocksize = bsr.values().shape[-2:] + M, K = bsr.shape + Ms, Ks = blocksize + K_, N = other.shape[-2:] + assert K_ == K + nbatches = other.shape[:-2].numel() + + meta = scatter_mm_meta(M, K, N, Ms, Ks, **meta_input) + if "allow_tf32" not in meta_input: + meta.update(allow_tf32=bsr.dtype in {torch.float16, torch.bfloat16}) + SPLIT_N = meta["SPLIT_N"] + indices_data = _bsr_scatter_mm_indices_data( + indices_format, M, K, N, Ms, Ks, nbatches, SPLIT_N, TensorAsKey(bsr) + ) + + if indices_format == "bsr_strided_mm_compressed": + meta.update(is_compressed=True) + return indices_data + (meta,) + elif indices_format == "bsr_strided_mm": + meta.update(is_compressed=False) + return indices_data + (meta,) + else: + return indices_data + + +def bsr_scatter_mm(bsr, other, indices_data=None, out=None): + """BSR @ strided -> strided""" + + assert bsr.ndim == 2 + assert other.ndim >= 2 + + Ms, Ks, Ns = bsr.shape[-2], bsr.shape[-1], other.shape[-1] + blocksize = bsr.values().shape[-2:] + + if indices_data is None: + indices_data = bsr_scatter_mm_indices_data( + bsr, other, indices_format="bsr_strided_mm_compressed" + ) + + indices_format = indices_data[0] + + if out is None: + out = torch.empty( + (*other.shape[:-2], Ms, Ns), dtype=bsr.dtype, device=bsr.device + ) + out_shape = out.shape + out = as1Dbatch(out) + + if bsr._nnz() == 0: + out.zero_() + elif indices_format in {"bsr_strided_mm_compressed", "bsr_strided_mm"}: + out.zero_() + scatter_mm(bsr.values(), other, indices_data, accumulators=out) + elif indices_format == "scatter_mm": + nbatches = other.shape[:-2].numel() + accumulators = torch.zeros( + ( + nbatches * Ms // blocksize[0] * Ns // blocksize[0], + blocksize[0], + blocksize[0], + ), + dtype=bsr.dtype, + device=bsr.device, + ) + others = ( + as1Dbatch(other) + .transpose(-2, -1) + .view( + nbatches, + Ns // blocksize[0], + blocksize[0], + Ks // blocksize[1], + blocksize[1], + ) + .movedim( + (3, 1, 4, 2), (1, 2, 3, 4) + ) # equivalent to .transpose(-3, -2).transpose(-2, -1).transpose(-4, -3) + .flatten(0, 2) + ) + scatter_mm(bsr.values(), others, indices_data, accumulators=accumulators) + out.copy_( + accumulators.unflatten( + 0, (nbatches, Ms // blocksize[0], Ns // blocksize[0]) + ) + .movedim( + (1, 2, 3, 4), (3, 1, 4, 2) + ) # equivalent to .transpose(-4, -3).transpose(-2, -1).transpose(-3, -2) + .reshape(nbatches, Ns, Ms) + .transpose(-2, -1) + ) + else: + raise NotImplementedError(indices_format) + + return out.view(out_shape) + + +def _int_bsr_dense_addmm( + input: torch.Tensor, + bsr: torch.Tensor, + dense: torch.Tensor, + *, + beta=1, + alpha=1, + out: 
Optional[torch.Tensor] = None, + skip_checks: bool = False, + max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None, + meta: Optional[dict] = None, +): + if out is None and dense.dtype is torch.int8: + f_name = "_int_bsr_dense_addmm" + crow_indices = bsr.crow_indices() + batch_ndim = crow_indices.dim() - 1 + M = bsr.shape[batch_ndim] + N = dense.shape[-1] + original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense) + out = torch.empty( + original_batch_dims_broadcasted + (M, N), + dtype=torch.int32, + device=dense.device, + ) + return bsr_dense_addmm( + input, + bsr, + dense, + beta=beta, + alpha=alpha, + out=out, + skip_checks=skip_checks, + max_grid=max_grid, + meta=meta, + ) + + +def bsr_dense_addmm( + input: torch.Tensor, + bsr: torch.Tensor, + dense: torch.Tensor, + *, + beta=1, + alpha=1, + out: Optional[torch.Tensor] = None, + skip_checks: bool = False, + max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None, + meta: Optional[dict] = None, +): + f_name = "bsr_dense_addmm" + values = bsr.values() + crow_indices = bsr.crow_indices() + col_indices = bsr.col_indices() + batch_ndim = crow_indices.dim() - 1 + M, K = bsr.shape[batch_ndim : batch_ndim + 2] + blocksize = values.shape[batch_ndim + 1 : batch_ndim + 3] + N = dense.shape[-1] + + # todo: implement checks + + if out is None: + original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense) + out = dense.new_empty(original_batch_dims_broadcasted + (M, N)) + + if bsr._nnz() == 0 or alpha == 0 or N == 0 or M == 0 or K == 0: + if beta == 0: + out.zero_() + else: + out.copy_(input) + if beta != 1: + out.mul_(beta) + return out + + if meta is None: + sparsity = round(1 - bsr._nnz() * blocksize[0] * blocksize[1] / (M * K), 2) + meta = bsr_dense_addmm_meta( + M, + K, + N, + blocksize[0], + blocksize[1], + beta, + alpha, + sparsity=sparsity, + dtype=out.dtype, + ) + out_backup = out + + crow_indices, col_indices, values, input, dense, out = prepare_inputs( + bsr, input, dense, out + ) + + BM, BK = blocksize + SPLIT_N = meta.get("SPLIT_N", N // BM) + BN = N // SPLIT_N + + out_untiled = out + out = tile_to_blocksize(out, (BM, BN)) + dense = tile_to_blocksize(dense, (BK, BN)) + input = tile_to_blocksize(input, (BM, BN)) + + dot_out_dtype = { + torch.float16: tl.float32, + torch.bfloat16: tl.float32, + torch.float32: tl.float64, + torch.float64: tl.float64, + torch.int8: tl.int32, + torch.int32: tl.int32, + }[out.dtype] + + n_batches = dense.size(0) + n_block_rows = crow_indices.size(-1) - 1 + n_block_cols = dense.size(-3) + + full_grid = (n_batches, n_block_cols, n_block_rows) + if max_grid is not None: + grid_blocks = tuple(max_grid[:3][::-1]) + (None,) * (3 - len(max_grid[:3])) + else: + grid_blocks = None + + tensor_dims_map = { + values: (0, None, None), + crow_indices: (0, None, -1), + col_indices: (0, None, None), + input: (0, -3, -4), + dense: (0, -3, None), + out: (0, -3, -4), + } + + assert alpha != 0 + + def kernel(grid, *sliced_tensors): + _bsr_strided_addmm_kernel[grid]( + *ptr_stride_extractor(*sliced_tensors), + beta, + alpha, + beta_is_one=beta == 1, + beta_is_nonzero=beta != 0, + alpha_is_one=alpha == 1, + BLOCKSIZE_ROW=BM, + BLOCKSIZE_INNER=BK, + BLOCKSIZE_COL=BN, + allow_tf32=dot_out_dtype == tl.float32, + acc_dtype=dot_out_dtype, + **meta, + ) + + launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks) + + if out.data_ptr() != out_backup.data_ptr(): + # prepare_inputs has made a copy of out, copy its content back + # to out_backup: + 
out_backup.copy_(out_untiled.view(out_backup.shape)) + + return out_backup + + +if has_triton(): + import triton + import triton.language as tl + + @triton.jit + def _sampled_addmm_kernel( + alpha, + beta, + IS_BETA_ZERO: tl.constexpr, + BLOCKSIZE_ROW: tl.constexpr, + BLOCKSIZE_COL: tl.constexpr, + k, + TILE_K: tl.constexpr, + values_ptr, + values_batch_stride, + values_nnz_stride, + values_row_block_stride, + values_col_block_stride, + crow_indices_ptr, + crow_indices_batch_stride, + crow_indices_stride, + col_indices_ptr, + col_indices_batch_stride, + col_indices_stride, + mat1_ptr, + mat1_batch_stride, + mat1_tiled_row_stride, + mat1_tiled_col_stride, + mat1_row_block_stride, + mat1_col_block_stride, + mat2_ptr, + mat2_batch_stride, + mat2_tiled_row_stride, + mat2_tiled_col_stride, + mat2_row_block_stride, + mat2_col_block_stride, + acc_dtype: tl.constexpr, + allow_tf32: tl.constexpr, + ): + batch_pid = tl.program_id(axis=1) + row_block_pid = tl.program_id(axis=0) + + crow_indices_offset_ptr = ( + crow_indices_ptr + + crow_indices_batch_stride * batch_pid + + crow_indices_stride * row_block_pid + ) + nnz_offset = tl.load(crow_indices_offset_ptr) + nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride) + + # Compute nnz for the row with number row_block_pid. + # If it is zero, skip the row. + row_nnz = nnz_offset_next - nnz_offset + if row_nnz == 0: + return + + row_block_arange = tl.arange(0, BLOCKSIZE_ROW) + col_block_arange = tl.arange(0, BLOCKSIZE_COL) + + # Pointers are set to the first block of the current row. + values_block_ptrs = ( + values_ptr + + values_batch_stride * batch_pid + + values_nnz_stride * nnz_offset + + values_row_block_stride * row_block_arange[:, None] + + values_col_block_stride * col_block_arange[None, :] + ) + + col_index_nnz_ptr = ( + col_indices_ptr + + col_indices_batch_stride * batch_pid + + col_indices_stride * nnz_offset + ) + + # Advance mat1 to the current tiled row, ignore columns. + mat1_block_ptrs = ( + mat1_ptr + + mat1_batch_stride * batch_pid + + mat1_tiled_row_stride * row_block_pid + + mat1_row_block_stride * row_block_arange[:, None] + ) + + # Advance mat2 in batch and block col dimension. + mat2_block_ptrs = ( + mat2_ptr + + mat2_batch_stride * batch_pid + + mat2_col_block_stride * col_block_arange[None, :] + ) + + k_tile_arange = tl.arange(0, TILE_K) + for _ in range(row_nnz): + acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype) + + # find column block index + col_block = tl.load(col_index_nnz_ptr) + + for k_tile in range(0, k, TILE_K): + k_offsets = k_tile + k_tile_arange + mask_k = k_offsets < k + + mat1_block = tl.load( + mat1_block_ptrs + mat1_col_block_stride * k_offsets[None, :], + mask=mask_k[None, :], + other=0.0, + ) + + mat2_block = tl.load( + mat2_block_ptrs + + mat2_tiled_col_stride * col_block + + mat2_row_block_stride * k_offsets[:, None], + mask=mask_k[:, None], + other=0.0, + ) + + acc_block += tl.dot( + mat1_block, mat2_block, allow_tf32=allow_tf32, out_dtype=acc_dtype + ) + + if IS_BETA_ZERO: + acc_block *= alpha + else: + acc_block = alpha * acc_block + beta * tl.load(values_block_ptrs) + + # write result + tl.store(values_block_ptrs, acc_block.to(values_ptr.dtype.element_ty)) + + # advance val/col_index ptrs to the next block in the row. 
+ values_block_ptrs += values_nnz_stride + col_index_nnz_ptr += col_indices_stride + + @triton.jit + def _bsr_strided_dense_rowspace_kernel( + # values prologue + values_ptr, + values_batch_stride, + values_nnz_stride, + values_row_block_stride, + values_col_block_stride, + # values epilogue + # crow_indices prologue + crow_indices_ptr, + crow_indices_batch_stride, + crow_indices_stride, + # crow_indices epilogue + # col_indices prologue + col_indices_ptr, + col_indices_batch_stride, + col_indices_stride, + # col_indices epilogue + # dense prologue + dense_ptr, + dense_batch_stride, + dense_tiled_row_stride, + dense_tiled_col_stride, + dense_row_block_stride, + dense_col_block_stride, + # dense epilogue + # output prologue + output_ptr, + output_batch_stride, + output_tiled_row_stride, + output_tiled_col_stride, + output_row_block_stride, + output_col_block_stride, + # output epilogue + # + # gh-113754: Always keep all constexpr arguments at the end of + # triton kernel arguments list because with triton 2.1 or + # earlier non-contiguous outputs will corrupt CUDA state due + # to a triton bug (fixed in openai/triton#2262). + BLOCKSIZE_ROW: tl.constexpr, + BLOCKSIZE_COL: tl.constexpr, + acc_dtype: tl.constexpr, + allow_tf32: tl.constexpr, + GROUP_SIZE_ROW: tl.constexpr, + ): + batch_pid = tl.program_id(axis=2) + row_block_pid = tl.program_id(axis=0) + col_block_pid = tl.program_id(axis=1) + n_block_rows = tl.num_programs(axis=0) + n_block_cols = tl.num_programs(axis=1) + + row_block_pid, col_block_pid = tl.swizzle2d( + row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW + ) + + crow_indices_offset_ptr = ( + crow_indices_ptr + + crow_indices_batch_stride * batch_pid + + crow_indices_stride * row_block_pid + ) + nnz_offset = tl.load(crow_indices_offset_ptr) + nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride) + + # Compute nnz for the row with number row_block_pid. + # If it is zero, skip the row. + row_nnz = nnz_offset_next - nnz_offset + if row_nnz == 0: + return + + row_block_arange = tl.arange(0, BLOCKSIZE_ROW) + col_block_arange = tl.arange(0, BLOCKSIZE_COL) + + # Pointers are set to the first block of the current row. + values_block_ptrs = ( + values_ptr + + values_batch_stride * batch_pid + + values_nnz_stride * nnz_offset + + values_row_block_stride * row_block_arange[:, None] + + values_col_block_stride * col_block_arange[None, :] + ) + + # NOTE: dense is advanced into all dimensions but the tiled row one. + # That will be advanced in the loop according to values in col_indices. + dense_block_ptrs = ( + dense_ptr + + dense_batch_stride * batch_pid + + dense_tiled_col_stride * col_block_pid + + dense_row_block_stride * col_block_arange[:, None] + + dense_col_block_stride * row_block_arange[None, :] + ) + + # Pointers are set to exact write-to locations + output_ptrs = ( + output_ptr + + output_batch_stride * batch_pid + + output_tiled_row_stride * row_block_pid + + output_tiled_col_stride * col_block_pid + + output_row_block_stride * row_block_arange[:, None] + + output_col_block_stride * row_block_arange[None, :] + ) + + # Set pointer to the first nonzero element in the current row + col_index_nnz_ptr = ( + col_indices_ptr + + col_indices_batch_stride * batch_pid + + col_indices_stride * nnz_offset + ) + + output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype) + for _ in range(row_nnz): + values_block = tl.load(values_block_ptrs) + + # find which row of dense needs to get loaded + # for multiplication with values_block. 
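+            # E.g. (illustrative): if the current block row has
+            # col_indices == [0, 3, 7], successive iterations of this loop
+            # load the dense tiles at tiled-row offsets 0, 3 and 7 and
+            # accumulate their products with the corresponding value blocks.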
+ dense_row_idx = tl.load(col_index_nnz_ptr) + dense_block = tl.load( + dense_block_ptrs + dense_tiled_row_stride * dense_row_idx + ) + + # do block mm + output_acc_block += tl.dot( + values_block, dense_block, allow_tf32=allow_tf32, out_dtype=acc_dtype + ) + + # move val/col_index ptrs to the next block in the row + values_block_ptrs += values_nnz_stride + col_index_nnz_ptr += col_indices_stride + + # write back the result + tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty)) + + def _run_sampled_addmm_kernel( + alpha, + beta, + is_beta_zero, + blocksize, + k, + tile_k, + values, + crow_indices, + col_indices, + mat1, + mat2, + max_grid, + ): + n_batches = values.size(0) + n_block_rows = crow_indices.size(-1) - 1 + + full_grid = (n_batches, n_block_rows) + if max_grid is not None: + grid_blocks = tuple(max_grid[:2][::-1]) + (None,) * (2 - len(max_grid[:2])) + else: + grid_blocks = None + tensor_dims_map = { + values: (0, None), + crow_indices: (0, -1), + col_indices: (0, None), + mat1: (0, -4), + mat2: (0, None), + } + if values.dtype in (torch.half, torch.bfloat16): + acc_dtype = tl.float32 + allow_tf32 = True + else: + acc_dtype = tl.float64 + allow_tf32 = False + + def kernel(grid, *sliced_tensors): + _sampled_addmm_kernel[grid]( + alpha, + beta, + is_beta_zero, + *blocksize, + k, + tile_k, + *ptr_stride_extractor(*sliced_tensors), + acc_dtype=acc_dtype, + allow_tf32=allow_tf32, + num_stages=1, + num_warps=4, + ) + + launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks) + + def sampled_addmm( + input: torch.Tensor, + mat1: torch.Tensor, + mat2: torch.Tensor, + *, + beta=1.0, + alpha=1.0, + out: Optional[torch.Tensor] = None, + skip_checks: bool = False, + max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None, + ): + f_name = "sampled_addmm" + + check_bsr_layout(f_name, input) + input_broadcasted = broadcast_batch_dims_bsr(f_name, input, mat1, mat2) + + if not skip_checks: + check_device(f_name, mat1, input.device) + check_device(f_name, mat2, input.device) + if beta != 0.0 and input.dtype is torch.bool: + check( + False, + f"{f_name}(): having beta == {beta} not equal to 0.0 with boolean mask is not allowed.", + ) + if input.dtype is not torch.bool: + check_dtype(f_name, mat1, input.dtype) + check_dtype(f_name, mat2, input.dtype) + else: + check_dtype(f_name, mat1, mat2.dtype) + check_mm_compatible_shapes(f_name, mat1, mat2) + if out is not None: + check_bsr_layout(f_name, out) + check_device(f_name, out, mat1.device) + check_dtype(f_name, out, input.dtype) + check( + out.shape == input_broadcasted.shape and out._nnz() == input._nnz(), + f"{f_name}(): Expects `out` to be of shape {input_broadcasted.shape} " + f"and with nnz equal to {input_broadcasted._nnz()} " + f"but got out.shape = {out.shape} and out.nnz = {out._nnz()}", + ) + + if out is None: + out = input_broadcasted.to(mat1.dtype, copy=True) + else: + out.copy_(input_broadcasted) + + if out.numel() == 0 or out._nnz() == 0: + return out + + blocksize = out.values().shape[-2:] + m = mat1.size(-2) + n = mat2.size(-1) + k = mat1.size(-1) + + # NOTE: (m, 0) @ (0, n) == zeros(m, n) + if alpha == 0.0 or k == 0: + out.values().mul_(beta) + return out + + # prepare inputs by reshaping them to be kernel-compatible + out_backup = out + crow_indices, col_indices, values, mat1, mat2 = prepare_inputs(out, mat1, mat2) + + mat1 = tile_to_blocksize(mat1, (blocksize[0], k)) + mat2 = tile_to_blocksize(mat2, (k, blocksize[1])) + tile_k = max(*blocksize) + + _run_sampled_addmm_kernel( + alpha, + 
beta, + beta == 0.0, + blocksize, + k, + tile_k, + values, + crow_indices, + col_indices, + mat1, + mat2, + max_grid, + ) + + # If nnz x block strides are not the same in out_backup.values and values, + # it means that out_backup.values and values are not the views of each other, + # so we have to copy. + if out_backup.values().stride()[-3:] != values.stride()[-3:]: + out_backup.values().copy_(values.reshape(out_backup.values().shape)) + return out_backup + + def bsr_dense_mm( + bsr: torch.Tensor, + dense: torch.Tensor, + *, + out: Optional[torch.Tensor] = None, + skip_checks: bool = False, + max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None, + meta: Optional[dict] = None, + ): + f_name = "bsr_dense_mm" + m, kl = bsr.shape[-2:] + if not skip_checks: + check_bsr_layout(f_name, bsr) + check_device(f_name, bsr, dense.device) + check_dtype(f_name, bsr, dense.dtype, (torch.int8,)) + check_mm_compatible_shapes(f_name, bsr, dense) + + n = dense.size(-1) + row_block, col_block = bsr.values().shape[-2:] + check_blocksize(f_name, (row_block, col_block)) + check( + not n % 16, + f"{f_name}(): dense.size(-1) == {n} should be divisible by 16", + ) + else: + kr, n = dense.shape[-2:] + + original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense) + + if out is not None and not skip_checks: + expected_out_shape = original_batch_dims_broadcasted + (m, n) + check( + out.shape == expected_out_shape, + "bsr_dense_mm(): `out` argument has wrong shape, " + f"expected {expected_out_shape}, but got {out.shape}.", + ) + check( + out.is_contiguous() or out.transpose(-2, -1).is_contiguous(), + "bsr_dense_mm(): only row-major/col-major `out` arguments are supported, " + "i.e. (out.is_contiguous() or out.transpose(-2, -1).is_contiguous()) " + "should be True.", + ) + + # Allocate out + if out is None: + out = dense.new_empty(original_batch_dims_broadcasted + (m, n)) + + # Short circuit if lhs is zero + if bsr._nnz() == 0: + return out.zero_() + + # with beta==0, addmm ignores input content, so we can use out + # as a placeholder for input because their shapes match: + return bsr_dense_addmm(out, bsr, dense, alpha=1, beta=0, out=out) + + @triton.jit + def _bsr_softmax_kernel( + crow_indices_ptr, + crow_indices_batch_stride, + crow_indices_stride, + values_ptr, + values_batch_stride, + values_row_block_stride, + values_nnz_col_block_stride, + row_block, + col_block, + MAX_ROW_NNZ: tl.constexpr, + TILE: tl.constexpr, + ): + batch_pid = tl.program_id(axis=2) + row_block_offset_pid = tl.program_id(axis=1) + row_block_pid = tl.program_id(axis=0) + + crow_indices_offset_ptr = ( + crow_indices_ptr + + crow_indices_batch_stride * batch_pid + + crow_indices_stride * row_block_pid + ) + nnz_offset = tl.load(crow_indices_offset_ptr) + nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride) + + # Compute nnz for the row with number row_block_pid. + # If it is zero, skip the row. 
+ row_nnz = nnz_offset_next - nnz_offset + if row_nnz == 0: + return + + row_arange = tl.arange(0, TILE) + mask = row_arange < row_nnz * col_block + + curr_row_values_ptrs = ( + values_ptr + + values_batch_stride * batch_pid + + values_row_block_stride * row_block_offset_pid + + nnz_offset * col_block + ) + + # find max in the row + row_tile = tl.load( + curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf") + ).to(tl.float32) + max_row_value = tl.max(row_tile, axis=0) + for _ in range(TILE, MAX_ROW_NNZ, TILE): + row_arange += TILE + mask = row_arange < row_nnz * col_block + row_tile = tl.load( + curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf") + ).to(tl.float32) + curr_max_row_value = tl.max(row_tile, axis=0) + max_row_value = tl.where( + max_row_value > curr_max_row_value, max_row_value, curr_max_row_value + ) + + # find denominator for stable softmax + num = tl.exp(row_tile - max_row_value) + denom = tl.sum(num, axis=0) + for _ in range(TILE, MAX_ROW_NNZ, TILE): + row_arange -= TILE + mask = row_arange < row_nnz * col_block + row_tile = tl.load( + curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf") + ).to(tl.float32) + num = tl.exp(row_tile - max_row_value) + denom += tl.sum(num, axis=0) + + # populate output + tl.store( + curr_row_values_ptrs + row_arange, + (num / denom).to(values_ptr.dtype.element_ty), + mask=mask, + ) + for _ in range(TILE, MAX_ROW_NNZ, TILE): + row_arange += TILE + mask = row_arange < row_nnz * col_block + row_tile = tl.load( + curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf") + ).to(tl.float32) + num = tl.exp(row_tile - max_row_value) + tl.store( + curr_row_values_ptrs + row_arange, + (num / denom).to(values_ptr.dtype.element_ty), + mask=mask, + ) + + def bsr_softmax(input, max_row_nnz=None): + f_name = "bsr_softmax" + + check_bsr_layout(f_name, input) + check_dtype(f_name, input, input.dtype) + + if input._nnz() == 0 or input.numel() == 0: + return input.clone() + + m, n = input.shape[-2:] + nnz = input._nnz() + row_block, col_block = input.values().shape[-2:] + + if max_row_nnz is None: + max_row_nnz = triton.next_power_of_2(n) + else: + max_row_nnz = triton.next_power_of_2(max_row_nnz) + + crow_indices = input.crow_indices().unsqueeze(0).flatten(0, -2) + # reshape values from + # (b1, ..., bn, nnz, row_block, col_block) to + # (b1 * ... * bn, row_block, nnz * col_block). + # This simplifies batch dim manipulation and unlocks + # the possibility to access all nnzs in any given row. + if input.values().transpose(-3, -2).is_contiguous(): + # Need to clone to avoid `contiguous` returning a view. + values = input.values().clone() + else: + values = input.values() + values = ( + values.transpose(-3, -2) + .contiguous() + .unsqueeze(0) + .flatten(0, -4) + .reshape(-1, row_block, nnz * col_block) + ) + full_grid = (values.shape[0], row_block, m // row_block) + grid_blocks = None + tensor_dims_map = { + # We span nnz number of blocks, not nnz + 1, + # hence crow_indices[..., :-1] + crow_indices[..., :-1]: (0, None, -1), + values: (0, None, None), + } + + def kernel(grid, *sliced_tensors): + _bsr_softmax_kernel[grid]( + *ptr_stride_extractor(*sliced_tensors), + row_block, + col_block, + max_row_nnz, + # Triton's max numel is bounded by 2 ** 17. 
+ min(2**17, max_row_nnz), + ) + + launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks) + + values = ( + values.reshape(-1, row_block, nnz, col_block) + .transpose(-3, -2) + .reshape(*input.values().shape) + ) + + return torch.sparse_compressed_tensor( + input.crow_indices().clone(), + input.col_indices().clone(), + values, + size=input.shape, + layout=input.layout, + ) + + def _scaled_dot_product_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor], + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + ): + f_name = "_scaled_dot_product_attention" + check(not is_causal, f"{f_name}(): is_causal == True is not supported.") + check(attn_mask is not None, f"{f_name}(): attn_mask == None is not supported.") + assert attn_mask is not None + + check( + attn_mask.layout == torch.sparse_bsr, + f"{f_name}(): " + f"attn_mask.layout must be {torch.sparse_bsr}, but got " + f"attn_mask.layout == {attn_mask.layout}.", + ) + + check_device(f_name, key, query.device) + check_device(f_name, value, query.device) + check_device(f_name, attn_mask, query.device) + + check_dtype(f_name, key, query.dtype) + check_dtype(f_name, value, query.dtype) + if attn_mask.dtype is not torch.bool: + check_dtype(f_name, attn_mask, query.dtype) + + sdpa = sampled_addmm( + attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False + ) + if scale is None and query.size(-1) == 0 or scale == 0.0: + check( + False, + f"{f_name}(): current value of scale == {scale} " + "results in division by zero.", + ) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + sdpa.values().mul_(scale_factor) + sdpa = bsr_softmax(sdpa) + torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True) + sdpa = bsr_dense_mm(sdpa, value) + return sdpa + + @triton.jit + def _scatter_mm2_kernel( + M: tl.constexpr, + K: tl.constexpr, + N: tl.constexpr, + blocks_ptr, + blocks_stride_P, + blocks_stride_M, + blocks_stride_K, + others_ptr, + others_stride_Q, + others_stride_K, + others_stride_N, + accumulators_ptr, + accumulators_stride_R, + accumulators_stride_M, + accumulators_stride_N, + pq_offsets_ptr, + pq_offsets_stride, + pq_ptr, + pq_stride_T, + pq_stride_1, + dot_out_dtype: tl.constexpr, + TILE_M: tl.constexpr, + TILE_N: tl.constexpr, + allow_tf32: tl.constexpr, + ): + Ms = M // TILE_M + Ns = N // TILE_N + + pid_t = tl.program_id(axis=0) + + pid = tl.program_id(axis=1) + pid_m = pid // Ms + pid_n = pid % Ms + + rm = pid_m * TILE_M + tl.arange(0, TILE_M) + rn = pid_n * TILE_N + tl.arange(0, TILE_N) + rk = tl.arange(0, K) + + A_ptr = blocks_ptr + ( + rm[:, None] * blocks_stride_M + rk[None, :] * blocks_stride_K + ) + B_ptr = others_ptr + ( + rk[:, None] * others_stride_K + rn[None, :] * others_stride_N + ) + + g0 = tl.load(pq_offsets_ptr + pid_t * pq_offsets_stride) + g1 = tl.load(pq_offsets_ptr + (pid_t + 1) * pq_offsets_stride) + + if g0 == g1: + return + + acc_block = tl.zeros((TILE_M, TILE_N), dtype=dot_out_dtype) + + for i in range(g0, g1): + p = tl.load(pq_ptr + i * pq_stride_T) + q = tl.load(pq_ptr + i * pq_stride_T + pq_stride_1) + A = tl.load(A_ptr + p * blocks_stride_P) + B = tl.load(B_ptr + q * others_stride_Q) + acc_block += tl.dot(A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) + + C_ptr = ( + accumulators_ptr + + pid_t * accumulators_stride_R + + ( + rm[:, None] * accumulators_stride_M + + rn[None, :] * accumulators_stride_N + ) + ) + tl.store(C_ptr, 
acc_block.to(accumulators_ptr.dtype.element_ty)) + + def _scatter_mm2( + blocks: torch.Tensor, + others: torch.Tensor, + pq_offsets: torch.Tensor, + pq_indices: torch.Tensor, + accumulators: torch.Tensor, + ): + P, M, K = blocks.shape + Q, _, N = others.shape + R, _, _ = accumulators.shape + + meta = dict( + TILE_M=max(16, M // 4), TILE_N=max(16, N // 4), num_stages=1, num_warps=2 + ) + + def grid(META): + return ( + pq_offsets.shape[0] - 1, + triton.cdiv(M, META["TILE_M"]) * triton.cdiv(N, META["TILE_N"]), + 1, + ) + + dot_out_dtype = { + torch.float16: tl.float32, + torch.bfloat16: tl.float32, + torch.float32: tl.float64, + torch.float64: tl.float64, + }[accumulators.dtype] + if "allow_tf32" not in meta: + meta.update(allow_tf32=dot_out_dtype == tl.float32) + _scatter_mm2_kernel[grid]( + M, + K, + N, + blocks, + blocks.stride(0), + blocks.stride(1), + blocks.stride(2), + others, + others.stride(0), + others.stride(1), + others.stride(2), + accumulators, + accumulators.stride(0), + accumulators.stride(1), + accumulators.stride(2), + pq_offsets, + pq_offsets.stride(0), + pq_indices, + pq_indices.stride(0), + pq_indices.stride(1), + dot_out_dtype=dot_out_dtype, + **meta, + ) + + @triton.jit + def _scatter_mm6_kernel( + nbatches, + Ms, + Ks: tl.constexpr, + N, + blocks_ptr, + blocks_stride_P, + blocks_stride_M, + blocks_stride_K, + others_ptr, + others_stride_B, + others_stride_K, + others_stride_N, + accumulators_ptr, + accumulators_stride_B, + accumulators_stride_M, + accumulators_stride_N, + c_indices_ptr, + r_offsets_ptr, + p_offsets_ptr, + q_offsets_ptr, + is_compressed: tl.constexpr, + dot_out_dtype: tl.constexpr, + SPLIT_N: tl.constexpr, + TILE_M: tl.constexpr, + TILE_N: tl.constexpr, + GROUP_SIZE: tl.constexpr, + allow_tf32: tl.constexpr, + ): + Ns = N // SPLIT_N + BLOCKS_M = Ms // TILE_M + BLOCKS_N = Ns // TILE_N + + pid_t_ = tl.program_id(axis=0) + pid = tl.program_id(axis=1) + pid_b = pid_t_ % nbatches + pid_t = pid_t_ // nbatches + + num_pid_in_group = GROUP_SIZE * BLOCKS_N + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE + group_size_m = min(BLOCKS_M - first_pid_m, GROUP_SIZE) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + rm = pid_m * TILE_M + tl.arange(0, TILE_M) + rn = pid_n * TILE_N + tl.arange(0, TILE_N) + rk = tl.arange(0, Ks) + A_ptr = blocks_ptr + ( + rm[:, None] * blocks_stride_M + rk[None, :] * blocks_stride_K + ) + B_ptr = ( + others_ptr + + pid_b * others_stride_B + + (rk[:, None] * others_stride_K + rn[None, :] * others_stride_N) + ) + + # When is_compressed is True, r is the only variable that + # depends on pid_t. This property allows sorting r values + # before calling the kernel. The sorting of r is equivalent to + # defining swizzle operator outside of the kernel. 
+ r = tl.load(r_offsets_ptr + pid_t) + + if is_compressed: + m = (r // N) // Ms + n = (r % N) // Ns + r0 = tl.load(c_indices_ptr + m) + r1 = tl.load(c_indices_ptr + m + 1) + g0 = n * r1 + (SPLIT_N - n) * r0 + nnz = r1 - r0 + else: + g0 = tl.load(c_indices_ptr + pid_t) + g1 = tl.load(c_indices_ptr + pid_t + 1) + nnz = g1 - g0 + + q_ptr = q_offsets_ptr + g0 + acc_block = tl.zeros((TILE_M, TILE_N), dtype=dot_out_dtype) + + if is_compressed: + A_ptr += r0 * blocks_stride_P # type: ignore[possibly-undefined] + for _ in range(nnz): + q = tl.load(q_ptr) + B = tl.load(B_ptr + q) + A = tl.load(A_ptr) + acc_block += tl.dot( + A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32 + ) + A_ptr += blocks_stride_P + q_ptr += 1 + else: + p_ptr = p_offsets_ptr + g0 + for _ in range(nnz): + q = tl.load(q_ptr) + B = tl.load(B_ptr + q) + p = tl.load(p_ptr) + A = tl.load(A_ptr + p * blocks_stride_P) + p_ptr += 1 + q_ptr += 1 + acc_block += tl.dot( + A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32 + ) + + C_ptr = ( + accumulators_ptr + + r + + pid_b * accumulators_stride_B + + ( + rm[:, None] * accumulators_stride_M + + rn[None, :] * accumulators_stride_N + ) + ) + tl.store(C_ptr, acc_block.to(accumulators_ptr.dtype.element_ty)) + + def _scatter_mm6( + blocks: torch.Tensor, + others: torch.Tensor, + c_indices: torch.Tensor, + r_offsets: torch.Tensor, + p_offsets: torch.Tensor, + q_offsets: torch.Tensor, + meta: dict, + accumulators: torch.Tensor, + force_contiguous: bool = True, + ): + SPLIT_N = meta["SPLIT_N"] + P, Ms, Ks = blocks.shape + B, K_, N = others.shape + B_, M, N_ = accumulators.shape + assert N_ == N + Ns = N // SPLIT_N + assert B_ == B + + def grid(META): + return ( + r_offsets.shape[0] * B, + triton.cdiv(Ms, META["TILE_M"]) * triton.cdiv(Ns, META["TILE_N"]), + ) + + dot_out_dtype = { + torch.float16: tl.float32, + torch.bfloat16: tl.float32, + torch.float32: tl.float64, + torch.float64: tl.float64, + }[accumulators.dtype] + if "allow_tf32" not in meta: + meta.update(allow_tf32=dot_out_dtype == tl.float32) + + assert c_indices.stride(0) == 1 + assert r_offsets.stride(0) == 1 + assert p_offsets.stride(0) == 1 + assert q_offsets.stride(0) == 1 + + # Re non-contiguous tensor arguments. Sometimes triton kernel + # launches may fail with + # + # RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered + # + # that appears to be case when the size of a non-contiguous + # tensor argument is larger than a certain threshold. Could + # this be related to shared memory or L1 cache size of a GPU + # card? In anycase, ensuring that tensor arguments are + # contiguous seems to avoid the above exception. So, in the + # following we'll always convert tensor arguments to + # C-contiguous tensors. 
+ + if force_contiguous: + blocks = blocks.contiguous() + others = others.contiguous() + if not accumulators.is_contiguous(): + accumulators_ = accumulators.contiguous() + else: + accumulators_ = accumulators + else: + accumulators_ = accumulators + + _scatter_mm6_kernel[grid]( + B, + Ms, + Ks, + N, + blocks, + blocks.stride(0), + blocks.stride(1), + blocks.stride(2), + others, + others.stride(0), + others.stride(1), + others.stride(2), + accumulators_, + accumulators_.stride(0), + accumulators_.stride(1), + accumulators_.stride(2), + c_indices, + r_offsets, + p_offsets, + q_offsets, + dot_out_dtype=dot_out_dtype, + **meta, + ) + + if force_contiguous and not accumulators.is_contiguous(): + accumulators.copy_(accumulators_) + + @triton.jit + def _bsr_strided_addmm_kernel( + # values prologue + values_ptr, + values_batch_stride, + values_nnz_stride, + values_row_block_stride, + values_col_block_stride, + # values epilogue + # crow_indices prologue + crow_indices_ptr, + crow_indices_batch_stride, + crow_indices_stride, + # crow_indices epilogue + # col_indices prologue + col_indices_ptr, + col_indices_batch_stride, + col_indices_stride, + # col_indices epilogue + # input prologue + input_ptr, + input_batch_stride, + input_tiled_row_stride, + input_tiled_col_stride, + input_row_block_stride, + input_col_block_stride, + # input epilogue + # dense prologue + dense_ptr, + dense_batch_stride, + dense_tiled_row_stride, + dense_tiled_col_stride, + dense_row_block_stride, + dense_col_block_stride, + # dense epilogue + # output prologue + output_ptr, + output_batch_stride, + output_tiled_row_stride, + output_tiled_col_stride, + output_row_block_stride, + output_col_block_stride, + # output epilogue + beta, + alpha, + beta_is_one: tl.constexpr, + beta_is_nonzero: tl.constexpr, + alpha_is_one: tl.constexpr, + BLOCKSIZE_ROW: tl.constexpr, + BLOCKSIZE_COL: tl.constexpr, + BLOCKSIZE_INNER: tl.constexpr, + acc_dtype: tl.constexpr, + allow_tf32: tl.constexpr, + GROUP_SIZE_ROW: tl.constexpr, + SPLIT_N: tl.constexpr, + ): + batch_pid = tl.program_id(axis=2) + row_block_pid = tl.program_id(axis=0) + col_block_pid = tl.program_id(axis=1) + n_block_rows = tl.num_programs(axis=0) + n_block_cols = tl.num_programs(axis=1) + + row_block_pid, col_block_pid = tl.swizzle2d( + row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW + ) + + crow_indices_offset_ptr = ( + crow_indices_ptr + + crow_indices_batch_stride * batch_pid + + crow_indices_stride * row_block_pid + ) + nnz_offset = tl.load(crow_indices_offset_ptr) + nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride) + + # Compute nnz for the row with number row_block_pid. + row_nnz = nnz_offset_next - nnz_offset + + row_block_arange = tl.arange(0, BLOCKSIZE_ROW) + inner_block_arange = tl.arange(0, BLOCKSIZE_INNER) + col_block_arange = tl.arange(0, BLOCKSIZE_COL) + + if beta_is_nonzero: + # Pointers are set to exact write-to locations + input_ptrs = ( + input_ptr + + input_batch_stride * batch_pid + + input_tiled_row_stride * row_block_pid + + input_tiled_col_stride * col_block_pid + + input_row_block_stride * row_block_arange[:, None] + + input_col_block_stride * col_block_arange[None, :] + ) + + # Pointers are set to the first block of the current row. 
+ values_block_ptrs = ( + values_ptr + + values_batch_stride * batch_pid + + values_nnz_stride * nnz_offset + + values_row_block_stride * row_block_arange[:, None] + + values_col_block_stride * inner_block_arange[None, :] + ) + + # NOTE: dense is advanced into all dimensions but the tiled row one. + # That will be advanced in the loop according to values in col_indices. + dense_block_ptrs = ( + dense_ptr + + dense_batch_stride * batch_pid + + dense_tiled_col_stride * col_block_pid + + dense_row_block_stride * inner_block_arange[:, None] + + dense_col_block_stride * col_block_arange[None, :] + ) + + # Pointers are set to exact write-to locations + output_ptrs = ( + output_ptr + + output_batch_stride * batch_pid + + output_tiled_row_stride * row_block_pid + + output_tiled_col_stride * col_block_pid + + output_row_block_stride * row_block_arange[:, None] + + output_col_block_stride * col_block_arange[None, :] + ) + + # Set pointer to the first nonzero element in the current row + col_index_nnz_ptr = ( + col_indices_ptr + + col_indices_batch_stride * batch_pid + + col_indices_stride * nnz_offset + ) + + # alpha is never 0 + if beta_is_nonzero: + output_acc_block = tl.load(input_ptrs).to(acc_dtype) # type: ignore[possibly-undefined] + if not (beta_is_one and alpha_is_one): + beta_alpha = beta / alpha + output_acc_block *= beta_alpha + else: + output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype) + + for _ in range(row_nnz): + values_block = tl.load(values_block_ptrs) + + # find which row of dense needs to get loaded + # for multiplication with values_block. + dense_row_idx = tl.load(col_index_nnz_ptr) + dense_block = tl.load( + dense_block_ptrs + dense_tiled_row_stride * dense_row_idx + ) + + # do block mm + output_acc_block += tl.dot( + values_block, dense_block, allow_tf32=allow_tf32, out_dtype=acc_dtype + ) + + # move val/col_index ptrs to the next block in the row + values_block_ptrs += values_nnz_stride + col_index_nnz_ptr += col_indices_stride + + if not alpha_is_one: + output_acc_block *= alpha + + # write back the result + tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty)) + +else: + bsr_softmax = None # type: ignore[assignment] + bsr_dense_mm = None # type: ignore[assignment] + sampled_addmm = None # type: ignore[assignment] + _scaled_dot_product_attention = None # type: ignore[assignment] + _scatter_mm2 = None # type: ignore[assignment] + _scatter_mm6 = None # type: ignore[assignment] + _bsr_strided_addmm_kernel = None # type: ignore[assignment] diff --git a/lib/python3.10/site-packages/torch/sparse/_triton_ops_meta.py b/lib/python3.10/site-packages/torch/sparse/_triton_ops_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..a97d9c502dc349533ed778c2e4bbbcf1891ded1c --- /dev/null +++ b/lib/python3.10/site-packages/torch/sparse/_triton_ops_meta.py @@ -0,0 +1,7387 @@ +# mypy: allow-untyped-defs +"""Provides optimal triton kernel parameters. + +Aim +--- + +The usage of optimal triton kernel parameters may increase the +performance of operations several times. For example, for large tensor +shapes, the usage of a bsr tensor as mat1 argument in addmm-based +operations typically outperforms the corresponding operation with +strided-only inputs when the blocked representation of a tensor +provides a better alignement with memory access than what the strided +representation would provide. 
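+
+As an illustration (the shapes, blocksize, and dtype below are
+arbitrary choices made for this sketch only, not recommendations),
+such an operation with a bsr mat1 argument may look like::
+
+  import torch
+  from torch.sparse._triton_ops import bsr_dense_mm
+
+  # a 512 x 512 matrix stored as a BSR tensor with 32 x 32 blocks
+  mat1 = torch.randn(512, 512, dtype=torch.float16, device="cuda").to_sparse_bsr((32, 32))
+  mat2 = torch.randn(512, 512, dtype=torch.float16, device="cuda")
+  result = bsr_dense_mm(mat1, mat2)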
+
+Pre-computed kernel parameters
+------------------------------
+
+This script finds and stores the optimal triton kernel parameters for
+a specific set of shape configurations. For instance, the set of shape
+configurations of the bsr_dense_addmm kernel is defined as
+
+  input, out: M x N strided tensor
+  mat1: M x K bsr tensor with blocksize (BM, BK) and given sparsity
+  mat2: M x N strided tensor
+  dtype = float16, bfloat16, float32
+  sparsity = 0.5
+  M = 256, 512, ..., 16384
+  K = M
+  N = 256, 512, ..., 131072
+  BM = 16, 32, ..., 128
+  BK = BM
+  alpha = 1
+  beta = 0, 1
+  GPUs: NVIDIA A100-SXM4-80GB
+
+Approximations
+--------------
+
+It is practically infeasible to pre-compute optimal kernel parameters
+for all possible shape configurations as well as for all existing
+GPUs. Therefore, we'll assume that the pre-computed optimal parameters
+are good enough approximations when
+1) the GPU in use is any of the NVIDIA A100 Tensor Core GPUs, and
+2) the actual sparsity of mat1 differs from the sparsity value 0.5.
+
+If a particular shape configuration does not fall into the set of
+pre-computed kernel parameters, or it does not match the
+approximations listed above, or the GPU device in use is not an NVIDIA
+A100 GPU, then a reference set of triton kernel parameters will be
+used when executing operations. The reference kernel parameters are
+defined in torch/sparse/_triton_ops.py; see, for instance, the
+bsr_dense_addmm_meta function.
+
+Computing optimal kernel parameters
+-----------------------------------
+
+If the approximations listed above are unacceptable, e.g. when one
+seeks the maximal possible performance, the optimal kernel parameters
+for a particular GPU can be computed by simply running this script in
+the pytorch development tree::
+
+  cd /path/to/pytorch
+  python setup.py develop
+  python torch/sparse/_triton_ops_meta.py
+
+This will compute the optimal kernel parameters for the GPU device
+available in the host system for all shape configurations listed in
+"Pre-computed kernel parameters" above. The results will be stored in
+the database of kernel parameters. Currently, this database is defined
+as this module (see the "BEGIN GENERATED DATA" comment below), which
+will be modified when the script is run. Create a pytorch PR with the
+corresponding modifications in this file to make the computed optimal
+kernel parameters available to other users as pre-computed kernel
+parameters.
+
+Moreover, one can compute the optimal kernel parameters for a specific
+set of shape configurations and specific sparsity patterns. For that,
+use the tuning functions provided by this module:
+
+  tune_bsr_dense_addmm(input, mat1, mat2, beta=1, alpha=1, out=None, verbose=False, store=False) -> meta
+
+The tuning functions return a dictionary of optimal kernel parameters
+that can be passed to the corresponding operation, e.g.
+
+  bsr_dense_addmm(..., meta=meta)
+
+Or, when store==True, the optimal kernel parameters will be stored in
+the database of pre-computed kernel parameters at runtime so that all
+addmm-based operations such as torch.addmm, torch.mm, and
+torch.nn.functional.linear will benefit from using the computed
+optimal set of kernel parameters.
+
+Note that running tune_bsr_dense_addmm can take several minutes. So,
+use it wisely, e.g. by implementing persistent storage of optimized
+kernel parameters. See the source code of get_meta and
+tune_bsr_dense_addmm to learn how to register a custom set of optimal
+kernel parameters for addmm-based operations.
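+
+For example, a minimal tuning sketch (the shapes, blocksize, and dtype
+below are arbitrary illustration choices, not recommendations)::
+
+  import torch
+  from torch.sparse._triton_ops import bsr_dense_addmm
+  from torch.sparse._triton_ops_meta import tune_bsr_dense_addmm
+
+  # mat1 is a 512 x 512 BSR tensor with 32 x 32 blocks
+  mat1 = torch.randn(512, 512, dtype=torch.float16, device="cuda").to_sparse_bsr((32, 32))
+  mat2 = torch.randn(512, 512, dtype=torch.float16, device="cuda")
+  input = torch.randn(512, 512, dtype=torch.float16, device="cuda")
+
+  # tune once (this may take a while) and reuse the resulting meta:
+  meta = tune_bsr_dense_addmm(input, mat1, mat2, beta=1, alpha=1, store=True)
+  out = bsr_dense_addmm(input, mat1, mat2, beta=1, alpha=1, meta=meta)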
+ +""" +__all__ = ["get_meta", "tune_bsr_dense_addmm", "tune__int_bsr_dense_addmm"] + +import inspect +import itertools +import re +import warnings +from typing import Any, Dict + +import torch +from torch.hub import tqdm +from torch.testing import make_tensor + + +def get_meta(op, key, device_name=None, version=(0, torch.float16, 0.5), exact=False): + """Return triton kernel meta parameters of the specified op and its inputs key. + + Parameters + ---------- + op (str): The name of an operation that implementation uses meta parameters. + key (tuple): A tuple of op input parameters, e.g. shapes, etc. + device_name (optional, str): The name of a device for which op + parameters are provided. + version (optional, hashable): Specifies the version of parameters. + exact (optional, bool): When True, the returned data (if + available) corresponds exactly to the specified device_name and + version information. Otherwise, if the corresponding data is not + available but there exists a data set that is computed for a + similar GPU device, then this data set will be returned. + + Returns + ------- + result (dict): The requested mapping of parameter names and + values, or None when no data is available. If the input `key` + contains `"*"`, the result will be a dictionary of keys and + mappings that match with the given `key`. + """ + if device_name is None: + device_name = torch.cuda.get_device_name() + + op_data = _operation_device_version_data.get((op, device_name, version)) + if op_data is None and not exact: + # A lack of op data could be due to using a (slightly) + # different GPU model compared to a model for which optimal + # meta parameters have been computed. In the following we'll + # assume that there is a set of GPU models that all have + # a similar set of optimal meta parameters. 
+ if re.match(r"NVIDIA A100[^\d]", device_name) is not None: + device_name = "NVIDIA A100-SXM4-80GB" + else: + return + op_data = _operation_device_version_data.get((op, device_name, version)) + if op_data is None: + return + + matching_data = {} + if "*" in key: + for op_key in op_data: + if [None for k1, k2 in zip(op_key, key) if k2 != "*" and k1 != k2]: + continue + matching_data[op_key] = op_data[op_key] + else: + values = op_data.get(key) + if values is not None: + matching_data[key] = values + matching_meta = {} + for op_key, values in matching_data.items(): + if op == "scatter_mm": + names = ( + "GROUP_SIZE", + "SPLIT_N", + "TILE_M", + "TILE_N", + "num_stages", + "num_warps", + ) + meta = dict(zip(names, values)) + elif op in {"bsr_dense_addmm", "_int_bsr_dense_addmm"}: + meta = dict( + zip(("GROUP_SIZE_ROW", "SPLIT_N", "num_stages", "num_warps"), values) + ) + else: + raise NotImplementedError(f"names for {op=}") + if "*" not in key: + return meta + + matching_meta[op_key] = meta + + if "*" in key: + return matching_meta + + +def update(op, device_name, version, key, value): + """Update the db of op parameters.""" + # avoid storing possible optimization failures: + assert value, (op, device_name, version, key, value) + if (op, device_name, version) in _operation_device_version_data: + if _operation_device_version_data[op, device_name, version].get(key) == value: + return + _operation_device_version_data[op, device_name, version][key] = value + else: + _operation_device_version_data[op, device_name, version] = {key: value} + + +def dump(): + """Store the current runtime db state to the module file.""" + current_file = inspect.getfile(dump) + f = open(current_file) + current_content = f.read() + f.close() + begin_data_str = "# BEGIN GENERATED DATA\n" + begin_data_index = current_content.find(begin_data_str) + end_data_index = current_content.find(" # END GENERATED DATA\n") + if begin_data_index == -1 or end_data_index == -1: + warnings.warn( + f"{current_file} cannot be updated:" + " BEGIN/END GENERATED DATA comment blocks appear to be corrupted" + ) + return + + def sort_key(key): + op, device_name, version = key + version = tuple( + (str(item) if isinstance(item, torch.dtype) else item) for item in version + ) + return (op, device_name, version) + + part1 = current_content[: begin_data_index + len(begin_data_str)] + part2 = current_content[end_data_index:] + data_part = [] + for op_key in sorted(_operation_device_version_data, key=sort_key): + data_part.append(" " + repr(op_key).replace("'", '"') + ": {") + op_data = _operation_device_version_data[op_key] + for key in sorted(op_data): + data_part.append(f" {key}: {op_data[key]},") + data_part.append(" },") + new_content = part1 + "\n".join(data_part) + "\n" + part2 + if current_content != new_content: + f = open(current_file, "w") + f.write(new_content) + f.close() + + +def minimize( + target_func, + initial_parameters, + reference_parameters, + step_func, + max_step=2, + verbose=False, + all_values=None, +): + """Find a dict of parameters that minimizes the target function using + the initial dict of parameters and a step function that progresses + a specified parameter in a dict of parameters. + + Parameters + ---------- + target_func (callable): a functional with the signature + ``target_func(parameters: dict) -> float`` + initial_parameters (dict): a set of parameters used as an initial + value to the minimization process. 
+ reference_parameters (dict): a set of parameters used as an + reference value with respect to which the speed up is computed. + step_func (callable): a functional with the signature + ``step_func(parameter_name:str, parameter_value:int, direction:int, parameters:dict) -> int`` + that increments or decrements (when ``direction`` is positive or + negative, respectively) the parameter with given name and value. + When return value is equal to ``parameter_value``, it means that + no step along the given direction can be made. + + Returns + ------- + parameters (dict): a set of parameters that minimizes the target + function. + speedup_incr (float): a speedup change given in percentage. + timing (float): the value of the target function at the parameters. + sensitivity_message (str): a message containing sensitivity. + information of parameters around the target function minimizer. + """ + + def to_key(parameters): + return tuple(parameters[k] for k in sorted(parameters)) + + def from_key(key, parameters): + return dict(zip(sorted(parameters), key)) + + if all_values is None: + all_values = {} + + directions = list(range(-max_step, max_step + 1)) + names = sorted(initial_parameters) + all_directions = [] + for d_tuple in itertools.product(*((directions,) * len(names))): + dist = sum(map(abs, d_tuple)) + if dist > 0 and dist <= max_step: + all_directions.append((dist, d_tuple)) + all_directions.sort() + + try: + reference_target = target_func(reference_parameters) + except Exception as msg: + if verbose and "out of resource" not in str(msg): + print(f"{reference_parameters=} lead to failure: {msg}.") + reference_target = None + + if reference_target is not None: + all_values[to_key(reference_parameters)] = reference_target + + parameters = initial_parameters + try: + initial_target = target_func(parameters) + except Exception as msg: + if reference_target is None: + if verbose: + print( + f"{initial_parameters=} lead to failure: {msg}. Optimization failed!" + ) + return {}, -1, -1, f"{msg}" + if verbose and "out of resource" not in str(msg): + print( + f"{initial_parameters=} lead to failure: {msg}. Using reference parameters instead of initial parameters." + ) + parameters = reference_parameters + initial_target = reference_target + + if reference_target is None: + if verbose: + print("Using initial parameters instead of reference parameters.") + reference_target = initial_target + + initial_key = to_key(parameters) + minimal_target = all_values[initial_key] = initial_target + pbar = tqdm( + total=len(all_directions), + desc="Tuning...", + disable=not verbose, + ncols=75, + ) + while True: + for i, (_, d_tuple) in enumerate(all_directions): + pbar.update(1) + next_parameters = parameters.copy() + for name, direction in zip(names, d_tuple): + value = next_parameters[name] + if direction == 0: + continue + next_value = step_func(name, value, direction, parameters) + if next_value == value: + break + next_parameters[name] = next_value + else: + next_key = to_key(next_parameters) + if next_key in all_values: + continue + try: + next_target = target_func(next_parameters) + except Exception as msg: + all_values[next_key] = str(msg) + if verbose and "out of resource" not in str(msg): + print(f"{next_parameters=} lead to failure: {msg}. 
Skipping.") + continue + all_values[next_key] = next_target + + if next_target < minimal_target: + minimal_target = next_target + parameters = next_parameters + pbar.total += i + 1 + break + else: + # ensure stable minimizer: + minimizer_keys = { + k + for k, v in all_values.items() + if isinstance(v, float) and abs(1 - v / minimal_target) < 0.001 + } + minimizer_key = ( + initial_key if initial_key in minimizer_keys else min(minimizer_keys) + ) + minimizer_target = all_values[minimizer_key] + parameters = from_key(minimizer_key, parameters) + speedup_incr = (1 - minimal_target / reference_target) * 100 + if speedup_incr < 0: + if verbose: + print( + f"{speedup_incr=} is negative. Rerunning minimize with reference parameters as initial parameters." + ) + return minimize( + target_func, + reference_parameters, + reference_parameters, + step_func, + max_step=max_step, + verbose=verbose, + all_values=all_values, + ) + sensitivity = [] + for name in parameters: + value = parameters[name] + rel_diffs = [] + for direction in range(-max_step, max_step + 1): + if direction == 0: + continue + next_value = step_func(name, value, direction, parameters) + if next_value == value: + rel_diffs.append(0) + continue + next_parameters = parameters.copy() + next_parameters[name] = next_value + next_key = to_key(next_parameters) + next_target = all_values.get(next_key) + if next_target is None or isinstance(next_target, str): + rel_diffs.append(0) + continue + rel_diff = (next_target / minimal_target - 1) * 100 + rel_diffs.append(rel_diff) + sensitivity.append((max(rel_diffs), rel_diffs, name)) + + sensitivity_message = [f"timing0={initial_target:.3f}"] + for _, rel_diffs, name in sorted(sensitivity, reverse=True): + left_diffs = "|".join( + [f"{rel_diff:.1f}" for rel_diff in rel_diffs[:max_step]] + ) + right_diffs = "|".join( + [f"{rel_diff:.1f}" for rel_diff in rel_diffs[max_step:]] + ) + sensitivity_message.append( + f"{name}={parameters[name]} ({left_diffs}...{right_diffs} %)" + ) + sensitivity_message = ", ".join(sensitivity_message) + return parameters, speedup_incr, minimal_target, sensitivity_message + + +def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device): + assert ( + sparsity <= 1.0 and sparsity >= 0.0 + ), "sparsity should be a value between 0 and 1" + assert M % blocksize[0] == 0 + assert N % blocksize[1] == 0 + shape = (B, M // blocksize[0], N // blocksize[1])[int(B == 0) :] + A = torch.bernoulli( + torch.full(shape, 1 - sparsity, dtype=torch.float32, device=device) + ).to(dtype) + expected_nnz = int((1 - sparsity) * M * N / (blocksize[0] * blocksize[1])) + nonzero_indices = A.flatten().nonzero() + actual_nnz = nonzero_indices.shape[0] + if actual_nnz > expected_nnz: + selected_nonzeros = torch.randperm(actual_nnz)[: actual_nnz - expected_nnz] + A.flatten()[nonzero_indices[selected_nonzeros]] = 0 + elif actual_nnz < expected_nnz: + zero_indices = (A == 0).flatten().nonzero() + selected_zeros = torch.randperm(zero_indices.shape[0])[ + : expected_nnz - actual_nnz + ] + A.flatten()[zero_indices[selected_zeros]] = 1 + A = torch.repeat_interleave(A, blocksize[0], dim=-2) + A = torch.repeat_interleave(A, blocksize[1], dim=-1) + return A + + +def optimize_scatter_mm( + m, k, n, bm, bk, dtype=torch.float16, device="cuda", sparsity=0.5, force=False +): + import triton + + from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data + + key = (m, k, n, bm, bk) + + version = (0, dtype, sparsity) + device_name = torch.cuda.get_device_name() + + reference_meta = dict( + 
GROUP_SIZE=1, + TILE_M=16, + TILE_N=16, + SPLIT_N=n // 16, + num_stages=1, + num_warps=1, + ) + + initial_meta = get_meta( + "scatter_mm", key, device_name=device_name, version=version, exact=True + ) + if initial_meta is None: + initial_meta = get_meta( + "bsr_dense_addmm", + key, + device_name=device_name, + version=(0, dtype, 0.5), + exact=True, + ) + if initial_meta is None: + initial_meta = reference_meta + elif not force: + return + + torch.manual_seed(0) + bsr = create_blocked_tensor( + 0, m, k, (bm, bk), sparsity, dtype, device + ).to_sparse_bsr((bm, bk)) + dense = make_tensor(k, n, dtype=dtype, device=device) + + def bench(meta, bsr=bsr, dense=dense): + indices_data = bsr_scatter_mm_indices_data( + bsr, dense, indices_format="bsr_strided_mm_compressed", **meta + ) + + def test_func(): + return bsr_scatter_mm(bsr, dense, indices_data=indices_data) + + ms_min = triton.testing.do_bench( + test_func, warmup=500, rep=100, fast_flush=False + ) + + return ms_min + + def step_meta_parameter(name, value, direction, meta, m=m, n=n, k=k, bm=bm, bk=bk): + # return next value in positive or negative direction, or + # input value if the step will result an invalid + # value. The input value is assumed to be valid. + + is_log = name in {"SPLIT_N", "TILE_M", "TILE_N", "num_warps"} + min_value = dict( + SPLIT_N=1, TILE_M=16, TILE_N=16, num_warps=1, num_stages=1, GROUP_SIZE=1 + )[name] + max_value = dict( + SPLIT_N=n // meta["TILE_N"], TILE_M=bm, TILE_N=n // meta["SPLIT_N"] + ).get(name) + value_step = dict( + SPLIT_N=2, TILE_M=2, TILE_N=2, num_warps=2, num_stages=1, GROUP_SIZE=1 + )[name] + if is_log: + next_value = ( + value * value_step**direction + if direction > 0 + else value // (value_step ** abs(direction)) + ) + else: + next_value = value + value_step * direction + if min_value is not None: + next_value = max(next_value, min_value) + if max_value is not None: + next_value = min(next_value, max_value) + if name == "SPLIT_N" and n % next_value != 0: + return value + # Hard-skip parameter combinations that break CUDA state for pytorch: + if (dtype, name, next_value, m, n, k, bm, bk) in { + (torch.float32, "num_warps", 32, 256, 256, 256, 16, 16), + (torch.float32, "num_warps", 16, 256, 256, 256, 32, 32), + (torch.float32, "num_warps", 16, 256, 256, 256, 64, 64), + (torch.float32, "num_warps", 16, 256, 256, 256, 128, 128), + (torch.float32, "num_warps", 16, 512, 512, 256, 128, 128), + } and re.match(r"NVIDIA A100[^\d]", device_name) is not None: + return value + return next_value + + meta, speedup, timing, sensitivity_message = minimize( + bench, initial_meta, reference_meta, step_meta_parameter + ) + if initial_meta is not reference_meta and initial_meta == meta and not force: + return + print(f"{meta=} {speedup=:.1f} % {timing=:.3f} ms") + if speedup < 0: + return + device_name = torch.cuda.get_device_name() + + update( + "scatter_mm", device_name, version, key, tuple(meta[k] for k in sorted(meta)) + ) + + +def tune__int_bsr_dense_addmm( + input, + bsr, + dense, + *, + beta=1, + alpha=1, + out=None, + store=False, + verbose=False, + force=False, +): + return tune_bsr_dense_addmm( + input, + bsr, + dense, + beta=beta, + alpha=alpha, + out=out, + store=store, + verbose=verbose, + force=force, + opname="_int_bsr_dense_addmm", + ) + + +def tune_bsr_dense_addmm( + input, + bsr, + dense, + *, + beta=1, + alpha=1, + out=None, + store=False, + verbose=False, + force=False, + opname=None, +): + """Tune bsr_dense_addmm kernel parameters against the given inputs. 
+ + When store is True, the tuning results will be stored in the + database of kernel parameters. + """ + import triton + + if opname is None: + opname = "bsr_dense_addmm" + + if opname == "_int_bsr_dense_addmm": + from torch.sparse._triton_ops import _int_bsr_dense_addmm as bsr_dense_addmm + else: + from torch.sparse._triton_ops import bsr_dense_addmm + + N = dense.shape[-1] + values = bsr.values() + crow_indices = bsr.crow_indices() + batch_ndim = crow_indices.dim() - 1 + M, K = bsr.shape[batch_ndim : batch_ndim + 2] + BM, BK = values.shape[batch_ndim + 1 : batch_ndim + 3] + + # Reference parameters is a set of parameters that leads to a + # successful kernel call and the corresponding timing is used as a + # reference for computing speedups. Avoid changing the reference + # parameters when possible. + reference_meta = dict( + GROUP_SIZE_ROW=1, num_stages=1, num_warps=4, SPLIT_N=max(N // BM, 1) + ) + + # Compute the key of parameters: + sparsity = round(1 - bsr._nnz() * BM * BK / (M * K), 2) + dtype = bsr.dtype + version = (0, dtype, sparsity) + key = (M, K, N, BM, BK, beta == 0, beta == 1, alpha == 1) + + # For tuning, for an initial state, use parameters from the + # database if available, otherwise, use the reference parameters. + initial_meta = get_meta(opname, key, version=version, exact=True) + if initial_meta is None: + may_skip_update = False + initial_meta = get_meta(opname, key, version=(0, dtype, 0.5), exact=True) + if initial_meta is None: + initial_meta = reference_meta + elif not force: + return initial_meta + else: + may_skip_update = True + + # The target function that is minimized in the tuning process: + def bench(meta, input=input, bsr=bsr, dense=dense, alpha=alpha, out=out): + def test_func(): + return bsr_dense_addmm( + input, bsr, dense, beta=beta, alpha=alpha, meta=meta, out=out + ) + + return triton.testing.do_bench(test_func, warmup=500, rep=100, fast_flush=False) + + # The step function that increments a specified meta parameter: + def step_meta_parameter(name, value, direction, meta, M=M, N=N, K=K, BM=BM, BK=BK): + # return next value in positive or negative direction, or + # input value if the step will result an invalid + # value. The input value is assumed to be valid. 
+ is_log = name in {"SPLIT_N", "num_warps"} + min_value = dict(SPLIT_N=1, num_warps=1, num_stages=1, GROUP_SIZE_ROW=1)[name] + max_value = dict(SPLIT_N=max(N // BM, 1)).get(name) + value_step = dict(SPLIT_N=2, num_warps=2, num_stages=1, GROUP_SIZE_ROW=1)[name] + if is_log: + next_value = ( + value * value_step**direction + if direction > 0 + else value // (value_step ** abs(direction)) + ) + else: + next_value = value + value_step * direction + if min_value is not None: + next_value = max(next_value, min_value) + if max_value is not None: + next_value = min(next_value, max_value) + if name == "SPLIT_N" and N % next_value != 0: + return value + return next_value + + # Tune: + meta, speedup, timing, sensitivity_message = minimize( + bench, + initial_meta, + reference_meta, + step_meta_parameter, + max_step=2, + verbose=verbose, + ) + if verbose: + print(f"-> {sensitivity_message}, {speedup=:.1f} %, {timing=:.3f} ms") + + if store and not ( + may_skip_update and meta == initial_meta and initial_meta is not reference_meta + ): + device_name = torch.cuda.get_device_name() + update( + opname, + device_name, + version, + key, + tuple(meta[k] for k in sorted(meta)), + ) + + return meta + + +def optimize_bsr_dense_addmm( + m, + k, + n, + bm, + bk, + beta=1, + alpha=1, + dtype=torch.float16, + device="cuda", + sparsity=0.5, + force=False, + verbose=False, + opname=None, +): + torch.manual_seed(0) + bsr = create_blocked_tensor( + 0, m, k, (bm, bk), sparsity, dtype, device + ).to_sparse_bsr((bm, bk)) + dense = make_tensor(k, n, dtype=dtype, device=device) + input = make_tensor(m, n, dtype=dtype, device=device) + tune_bsr_dense_addmm( + input, + bsr, + dense, + beta=beta, + alpha=alpha, + store=True, + force=force, + verbose=verbose, + opname=opname, + ) + + +def main(op="scatter_mm", force=False, dtype=torch.float16, verbose=True): + import itertools + + sizes_lst = [ + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384, + 32768, + 65536, + 131072, + 50432, + ] + sizes3_lst = [3 * sz for sz in [64, 128] + sizes_lst if sz <= 2048] + shapes_lst = [(sz, sz) for sz in sizes_lst[:-4] + sizes3_lst] + shapes_lst.extend([(3072, 768), (768, 3072)]) + if dtype is torch.int8: + # triton does not support smaller blocks than 32 + blocksize_lst = [(32, 32), (64, 64), (128, 128), (256, 256)] + else: + blocksize_lst = [(16, 16), (32, 32), (64, 64), (128, 128)] + sparsity_lst = [0.5, 0.7, 0.3][:1] + for sparsity in sparsity_lst: + print(f"{op, dtype, sparsity=}") + try: + for (M, K), N, (BM, BK) in itertools.product( + shapes_lst, sizes_lst, blocksize_lst + ): + if not (BM <= M and BK <= K and M % BM == 0 and K % BK == 0): + continue + if op == "scatter_mm": + optimize_scatter_mm( + M, K, N, BM, BK, force=force, sparsity=sparsity, dtype=dtype + ) + elif op in {"bsr_dense_addmm", "_int_bsr_dense_addmm"}: + if M == K and N == 50432: + continue + print(f"{M, K, N, (BM, BK)=}") + for alpha, beta in [(1, 1), (1, 0)]: + optimize_bsr_dense_addmm( + M, + K, + N, + BM, + BK, + beta=beta, + alpha=alpha, + force=force, + sparsity=sparsity, + dtype=dtype, + verbose=verbose, + opname=op, + ) + else: + raise NotImplementedError(op) + except KeyboardInterrupt: + break + except Exception as msg: + dump() + raise + dump() + + if 0: + # Check performance dependence on sparsity and apply + # adjustments when differences are noticable (more than 10%). + # + # When using NVIDIA A100 GPU, the performance dependence on + # sparsity is insignificant (0 % ... 10 %) for majority of + # shapes/blocksizes combinations. 
However, for a very few + # specific size combinations, the effect of sparsity on + # performance can be up to 20 %. + for (M, K), N, (BM, BK) in itertools.product( + shapes_lst, sizes_lst, blocksize_lst + ): + meta_lst: list = [] + key = (M, K, N, BM, BK) + for sparsity1 in sparsity_lst: + torch.manual_seed(0) + bsr = create_blocked_tensor( + 0, M, K, (BM, BK), sparsity1, dtype, device="cuda" + ).to_sparse_bsr((BM, BK)) + dense = make_tensor(K, N, dtype=dtype, device="cuda") + meta_lst = [] + for sparsity in sparsity_lst: + meta = get_meta(op, key, version=(0, dtype, sparsity), exact=True) + if meta is None: + continue + + def bench(meta, bsr=bsr, dense=dense): + import triton + + if op == "scatter_mm": + from torch.sparse._triton_ops import ( + bsr_scatter_mm, + bsr_scatter_mm_indices_data, + ) + + indices_data = bsr_scatter_mm_indices_data( + bsr, + dense, + indices_format="bsr_strided_mm_compressed", + **meta, + ) + + def test_func(): + return bsr_scatter_mm( + bsr, dense, indices_data=indices_data + ) + + else: + raise NotImplementedError(op) + + ms_min = triton.testing.do_bench( + test_func, warmup=500, rep=100, fast_flush=False + ) + + return ms_min + + meta_lst.append( + (bench(meta), sparsity, tuple(meta[k] for k in sorted(meta))) + ) + if not meta_lst: + continue + meta_lst = sorted(meta_lst) + index = next( + i for i, item in enumerate(meta_lst) if item[1] == sparsity1 + ) + if meta_lst[0][2] == meta_lst[index][2]: + continue + speeddiff = (1 - meta_lst[index][0] / meta_lst[0][0]) * 100 + if abs(speeddiff) < 10: + continue + + print(sparsity1, index, key, meta_lst, speeddiff) + + if index > 0: + device_name = torch.cuda.get_device_name() + meta = get_meta( + op, key, version=(0, dtype, meta_lst[0][1]), exact=True + ) + update( + op, + device_name, + (0, dtype, sparsity1), + key, + tuple(meta[k] for k in sorted(meta)), + ) + print("update") + dump() + + +_operation_device_version_data: Dict[Any, Dict] = { + # Warning: the data in between the BEGIN/END DATA comment lines + # below is generated. It can be updated either manually or via + # calling dump function defined above. 
+ # + # Legend [op: key -> data]: + # scatter_mm : M, K, N, Ms, Ks -> GROUP_SIZE, SPLIT_N, TILE_M, TILE_N, num_stages, num_warps + # bsr_dense_addmm : M, K, N, Ms, Ks, beta==0, beta==1, alpha==1 -> GROUP_SIZE_ROW, SPLIT_N, num_stages, num_warps + # + # BEGIN GENERATED DATA + ("_int_bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.int8, 0.5)): { + (192, 192, 256, 32, 32, False, True, True): (2, 8, 1, 4), + (192, 192, 256, 32, 32, True, False, True): (2, 8, 5, 4), + (192, 192, 512, 32, 32, False, True, True): (1, 16, 1, 4), + (192, 192, 512, 32, 32, True, False, True): (1, 16, 5, 4), + (192, 192, 1024, 32, 32, False, True, True): (1, 32, 1, 4), + (192, 192, 1024, 32, 32, True, False, True): (4, 32, 4, 4), + (192, 192, 2048, 32, 32, False, True, True): (2, 64, 1, 4), + (192, 192, 2048, 32, 32, True, False, True): (3, 16, 5, 4), + (192, 192, 4096, 32, 32, False, True, True): (1, 128, 1, 4), + (192, 192, 4096, 32, 32, True, False, True): (1, 128, 1, 4), + (192, 192, 8192, 32, 32, False, True, True): (1, 256, 1, 4), + (192, 192, 8192, 32, 32, True, False, True): (1, 64, 3, 4), + (192, 192, 16384, 32, 32, False, True, True): (2, 512, 1, 4), + (192, 192, 16384, 32, 32, True, False, True): (5, 128, 1, 4), + (192, 192, 32768, 32, 32, False, True, True): (1, 1024, 1, 4), + (192, 192, 32768, 32, 32, True, False, True): (1, 256, 1, 4), + (192, 192, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (192, 192, 65536, 32, 32, True, False, True): (1, 512, 1, 4), + (192, 192, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (192, 192, 131072, 32, 32, True, False, True): (2, 512, 1, 4), + (256, 256, 256, 32, 32, False, True, True): (4, 8, 1, 4), + (256, 256, 256, 32, 32, True, False, True): (1, 8, 6, 4), + (256, 256, 256, 64, 64, False, True, True): (1, 4, 1, 16), + (256, 256, 256, 64, 64, True, False, True): (1, 4, 4, 4), + (256, 256, 256, 128, 128, False, True, True): (3, 2, 1, 16), + (256, 256, 256, 128, 128, True, False, True): (1, 2, 1, 4), + (256, 256, 512, 32, 32, False, True, True): (2, 16, 1, 4), + (256, 256, 512, 32, 32, True, False, True): (2, 16, 4, 4), + (256, 256, 512, 64, 64, False, True, True): (7, 8, 1, 16), + (256, 256, 512, 64, 64, True, False, True): (3, 8, 3, 4), + (256, 256, 512, 128, 128, False, True, True): (1, 4, 1, 32), + (256, 256, 512, 128, 128, True, False, True): (1, 4, 1, 4), + (256, 256, 1024, 32, 32, False, True, True): (1, 32, 1, 4), + (256, 256, 1024, 32, 32, True, False, True): (1, 8, 6, 4), + (256, 256, 1024, 64, 64, False, True, True): (2, 16, 1, 16), + (256, 256, 1024, 64, 64, True, False, True): (1, 16, 5, 4), + (256, 256, 1024, 128, 128, False, True, True): (4, 8, 1, 32), + (256, 256, 1024, 128, 128, True, False, True): (1, 8, 2, 4), + (256, 256, 2048, 32, 32, False, True, True): (1, 64, 1, 4), + (256, 256, 2048, 32, 32, True, False, True): (2, 32, 3, 2), + (256, 256, 2048, 64, 64, False, True, True): (2, 32, 1, 16), + (256, 256, 2048, 64, 64, True, False, True): (1, 16, 3, 4), + (256, 256, 2048, 128, 128, False, True, True): (1, 16, 1, 32), + (256, 256, 2048, 128, 128, True, False, True): (1, 16, 2, 4), + (256, 256, 4096, 32, 32, False, True, True): (2, 128, 1, 4), + (256, 256, 4096, 32, 32, True, False, True): (1, 32, 3, 2), + (256, 256, 4096, 64, 64, False, True, True): (2, 64, 1, 8), + (256, 256, 4096, 64, 64, True, False, True): (1, 64, 3, 2), + (256, 256, 4096, 128, 128, False, True, True): (2, 32, 1, 32), + (256, 256, 4096, 128, 128, True, False, True): (3, 32, 2, 8), + (256, 256, 8192, 32, 32, False, True, True): (1, 256, 1, 4), + (256, 256, 
8192, 32, 32, True, False, True): (1, 64, 1, 4), + (256, 256, 8192, 64, 64, False, True, True): (1, 128, 1, 8), + (256, 256, 8192, 64, 64, True, False, True): (2, 128, 1, 4), + (256, 256, 8192, 128, 128, False, True, True): (4, 64, 1, 32), + (256, 256, 8192, 128, 128, True, False, True): (3, 64, 1, 4), + (256, 256, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (256, 256, 16384, 32, 32, True, False, True): (3, 128, 1, 4), + (256, 256, 16384, 64, 64, False, True, True): (2, 256, 1, 8), + (256, 256, 16384, 64, 64, True, False, True): (2, 256, 1, 4), + (256, 256, 16384, 128, 128, False, True, True): (2, 128, 1, 32), + (256, 256, 16384, 128, 128, True, False, True): (4, 128, 2, 4), + (256, 256, 32768, 32, 32, False, True, True): (2, 512, 1, 8), + (256, 256, 32768, 32, 32, True, False, True): (1, 256, 1, 4), + (256, 256, 32768, 64, 64, False, True, True): (1, 512, 1, 8), + (256, 256, 32768, 64, 64, True, False, True): (1, 512, 1, 4), + (256, 256, 32768, 128, 128, False, True, True): (2, 256, 1, 32), + (256, 256, 32768, 128, 128, True, False, True): (1, 256, 2, 4), + (256, 256, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (256, 256, 65536, 32, 32, True, False, True): (1, 512, 1, 4), + (256, 256, 65536, 64, 64, False, True, True): (1, 1024, 1, 8), + (256, 256, 65536, 64, 64, True, False, True): (1, 512, 1, 4), + (256, 256, 65536, 128, 128, False, True, True): (2, 512, 1, 16), + (256, 256, 65536, 128, 128, True, False, True): (1, 512, 1, 4), + (256, 256, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (256, 256, 131072, 32, 32, True, False, True): (2, 1024, 1, 4), + (256, 256, 131072, 64, 64, False, True, True): (1, 2048, 1, 8), + (256, 256, 131072, 64, 64, True, False, True): (2, 512, 1, 4), + (256, 256, 131072, 128, 128, False, True, True): (2, 1024, 1, 16), + (256, 256, 131072, 128, 128, True, False, True): (4, 1024, 1, 4), + (384, 384, 256, 32, 32, False, True, True): (1, 8, 1, 4), + (384, 384, 256, 32, 32, True, False, True): (5, 8, 5, 4), + (384, 384, 256, 64, 64, False, True, True): (2, 4, 1, 16), + (384, 384, 256, 64, 64, True, False, True): (1, 4, 5, 4), + (384, 384, 512, 32, 32, False, True, True): (2, 16, 1, 4), + (384, 384, 512, 32, 32, True, False, True): (1, 16, 4, 4), + (384, 384, 512, 64, 64, False, True, True): (3, 8, 1, 16), + (384, 384, 512, 64, 64, True, False, True): (3, 8, 3, 4), + (384, 384, 1024, 32, 32, False, True, True): (2, 32, 1, 4), + (384, 384, 1024, 32, 32, True, False, True): (1, 8, 6, 4), + (384, 384, 1024, 64, 64, False, True, True): (2, 16, 1, 16), + (384, 384, 1024, 64, 64, True, False, True): (1, 16, 5, 4), + (384, 384, 2048, 32, 32, False, True, True): (1, 64, 1, 4), + (384, 384, 2048, 32, 32, True, False, True): (3, 16, 4, 4), + (384, 384, 2048, 64, 64, False, True, True): (2, 32, 1, 16), + (384, 384, 2048, 64, 64, True, False, True): (1, 16, 4, 4), + (384, 384, 4096, 32, 32, False, True, True): (4, 64, 1, 8), + (384, 384, 4096, 32, 32, True, False, True): (4, 32, 1, 4), + (384, 384, 4096, 64, 64, False, True, True): (1, 64, 1, 8), + (384, 384, 4096, 64, 64, True, False, True): (1, 64, 1, 4), + (384, 384, 8192, 32, 32, False, True, True): (1, 128, 1, 8), + (384, 384, 8192, 32, 32, True, False, True): (3, 64, 1, 1), + (384, 384, 8192, 64, 64, False, True, True): (2, 128, 1, 8), + (384, 384, 8192, 64, 64, True, False, True): (1, 64, 2, 2), + (384, 384, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (384, 384, 16384, 32, 32, True, False, True): (1, 128, 1, 4), + (384, 384, 16384, 64, 64, False, True, True): (2, 256, 1, 8), + (384, 
384, 16384, 64, 64, True, False, True): (2, 128, 1, 4), + (384, 384, 32768, 32, 32, False, True, True): (1, 512, 1, 8), + (384, 384, 32768, 32, 32, True, False, True): (1, 256, 1, 4), + (384, 384, 32768, 64, 64, False, True, True): (1, 512, 1, 8), + (384, 384, 32768, 64, 64, True, False, True): (1, 256, 3, 2), + (384, 384, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (384, 384, 65536, 32, 32, True, False, True): (1, 512, 3, 2), + (384, 384, 65536, 64, 64, False, True, True): (2, 1024, 1, 8), + (384, 384, 65536, 64, 64, True, False, True): (3, 256, 3, 4), + (384, 384, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (384, 384, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (384, 384, 131072, 64, 64, False, True, True): (1, 2048, 1, 8), + (384, 384, 131072, 64, 64, True, False, True): (2, 512, 3, 4), + (512, 512, 256, 32, 32, False, True, True): (1, 8, 1, 4), + (512, 512, 256, 32, 32, True, False, True): (4, 8, 4, 4), + (512, 512, 256, 64, 64, False, True, True): (3, 4, 1, 16), + (512, 512, 256, 64, 64, True, False, True): (2, 4, 5, 4), + (512, 512, 256, 128, 128, False, True, True): (4, 2, 1, 16), + (512, 512, 256, 128, 128, True, False, True): (1, 2, 3, 4), + (512, 512, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (512, 512, 256, 256, 256, True, False, True): (2, 1, 1, 32), + (512, 512, 512, 32, 32, False, True, True): (3, 16, 1, 4), + (512, 512, 512, 32, 32, True, False, True): (1, 8, 4, 2), + (512, 512, 512, 64, 64, False, True, True): (2, 8, 1, 16), + (512, 512, 512, 64, 64, True, False, True): (2, 8, 5, 4), + (512, 512, 512, 128, 128, False, True, True): (3, 4, 1, 16), + (512, 512, 512, 128, 128, True, False, True): (1, 4, 3, 4), + (512, 512, 512, 256, 256, False, True, True): (1, 2, 1, 32), + (512, 512, 512, 256, 256, True, False, True): (2, 2, 1, 32), + (512, 512, 1024, 32, 32, False, True, True): (2, 32, 1, 4), + (512, 512, 1024, 32, 32, True, False, True): (4, 16, 3, 2), + (512, 512, 1024, 64, 64, False, True, True): (4, 16, 1, 16), + (512, 512, 1024, 64, 64, True, False, True): (1, 8, 4, 4), + (512, 512, 1024, 128, 128, False, True, True): (1, 8, 1, 32), + (512, 512, 1024, 128, 128, True, False, True): (1, 8, 3, 4), + (512, 512, 1024, 256, 256, False, True, True): (4, 4, 1, 32), + (512, 512, 1024, 256, 256, True, False, True): (2, 4, 1, 32), + (512, 512, 2048, 32, 32, False, True, True): (3, 32, 1, 8), + (512, 512, 2048, 32, 32, True, False, True): (1, 16, 3, 4), + (512, 512, 2048, 64, 64, False, True, True): (1, 32, 1, 8), + (512, 512, 2048, 64, 64, True, False, True): (1, 32, 3, 2), + (512, 512, 2048, 128, 128, False, True, True): (4, 16, 1, 32), + (512, 512, 2048, 128, 128, True, False, True): (1, 16, 3, 4), + (512, 512, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (512, 512, 2048, 256, 256, True, False, True): (3, 8, 1, 32), + (512, 512, 4096, 32, 32, False, True, True): (1, 64, 1, 8), + (512, 512, 4096, 32, 32, True, False, True): (5, 32, 1, 4), + (512, 512, 4096, 64, 64, False, True, True): (1, 64, 1, 8), + (512, 512, 4096, 64, 64, True, False, True): (1, 64, 1, 4), + (512, 512, 4096, 128, 128, False, True, True): (5, 32, 1, 32), + (512, 512, 4096, 128, 128, True, False, True): (2, 32, 3, 4), + (512, 512, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (512, 512, 4096, 256, 256, True, False, True): (3, 16, 1, 32), + (512, 512, 8192, 32, 32, False, True, True): (3, 128, 1, 8), + (512, 512, 8192, 32, 32, True, False, True): (3, 64, 1, 4), + (512, 512, 8192, 64, 64, False, True, True): (4, 128, 1, 8), + (512, 512, 8192, 64, 64, True, 
False, True): (1, 64, 3, 2), + (512, 512, 8192, 128, 128, False, True, True): (5, 64, 1, 32), + (512, 512, 8192, 128, 128, True, False, True): (1, 64, 2, 4), + (512, 512, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (512, 512, 8192, 256, 256, True, False, True): (1, 32, 1, 32), + (512, 512, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (512, 512, 16384, 32, 32, True, False, True): (2, 128, 1, 4), + (512, 512, 16384, 64, 64, False, True, True): (2, 256, 1, 8), + (512, 512, 16384, 64, 64, True, False, True): (1, 128, 3, 2), + (512, 512, 16384, 128, 128, False, True, True): (4, 128, 1, 16), + (512, 512, 16384, 128, 128, True, False, True): (2, 128, 1, 4), + (512, 512, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (512, 512, 16384, 256, 256, True, False, True): (2, 64, 1, 32), + (512, 512, 32768, 32, 32, False, True, True): (1, 512, 1, 8), + (512, 512, 32768, 32, 32, True, False, True): (2, 256, 1, 4), + (512, 512, 32768, 64, 64, False, True, True): (1, 512, 1, 8), + (512, 512, 32768, 64, 64, True, False, True): (1, 256, 3, 2), + (512, 512, 32768, 128, 128, False, True, True): (4, 256, 1, 16), + (512, 512, 32768, 128, 128, True, False, True): (2, 256, 1, 4), + (512, 512, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (512, 512, 32768, 256, 256, True, False, True): (2, 128, 1, 32), + (512, 512, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (512, 512, 65536, 32, 32, True, False, True): (2, 512, 1, 2), + (512, 512, 65536, 64, 64, False, True, True): (1, 1024, 1, 8), + (512, 512, 65536, 64, 64, True, False, True): (1, 512, 3, 2), + (512, 512, 65536, 128, 128, False, True, True): (4, 512, 1, 16), + (512, 512, 65536, 128, 128, True, False, True): (1, 512, 1, 4), + (512, 512, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (512, 512, 65536, 256, 256, True, False, True): (1, 256, 1, 32), + (512, 512, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (512, 512, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (512, 512, 131072, 64, 64, False, True, True): (1, 2048, 1, 8), + (512, 512, 131072, 64, 64, True, False, True): (1, 1024, 3, 2), + (512, 512, 131072, 128, 128, False, True, True): (4, 1024, 1, 16), + (512, 512, 131072, 128, 128, True, False, True): (1, 1024, 1, 4), + (512, 512, 131072, 256, 256, False, True, True): (1, 512, 1, 32), + (512, 512, 131072, 256, 256, True, False, True): (2, 512, 1, 32), + (768, 768, 256, 32, 32, False, True, True): (1, 8, 1, 4), + (768, 768, 256, 32, 32, True, False, True): (2, 8, 4, 4), + (768, 768, 256, 64, 64, False, True, True): (3, 4, 1, 16), + (768, 768, 256, 64, 64, True, False, True): (2, 4, 4, 4), + (768, 768, 256, 128, 128, False, True, True): (1, 2, 1, 8), + (768, 768, 256, 128, 128, True, False, True): (1, 2, 3, 4), + (768, 768, 512, 32, 32, False, True, True): (1, 16, 1, 4), + (768, 768, 512, 32, 32, True, False, True): (1, 4, 5, 4), + (768, 768, 512, 64, 64, False, True, True): (1, 8, 3, 32), + (768, 768, 512, 64, 64, True, False, True): (4, 8, 4, 4), + (768, 768, 512, 128, 128, False, True, True): (4, 4, 1, 16), + (768, 768, 512, 128, 128, True, False, True): (4, 4, 3, 4), + (768, 768, 1024, 32, 32, False, True, True): (1, 16, 1, 8), + (768, 768, 1024, 32, 32, True, False, True): (1, 8, 3, 4), + (768, 768, 1024, 64, 64, False, True, True): (3, 16, 1, 16), + (768, 768, 1024, 64, 64, True, False, True): (1, 8, 4, 4), + (768, 768, 1024, 128, 128, False, True, True): (3, 8, 1, 32), + (768, 768, 1024, 128, 128, True, False, True): (1, 8, 3, 4), + (768, 768, 2048, 32, 32, False, True, True): (2, 32, 
1, 8), + (768, 768, 2048, 32, 32, True, False, True): (3, 16, 1, 4), + (768, 768, 2048, 64, 64, False, True, True): (1, 32, 1, 8), + (768, 768, 2048, 64, 64, True, False, True): (4, 8, 3, 4), + (768, 768, 2048, 128, 128, False, True, True): (1, 16, 1, 32), + (768, 768, 2048, 128, 128, True, False, True): (1, 16, 3, 4), + (768, 768, 4096, 32, 32, False, True, True): (1, 64, 1, 8), + (768, 768, 4096, 32, 32, True, False, True): (1, 32, 1, 1), + (768, 768, 4096, 64, 64, False, True, True): (2, 64, 1, 8), + (768, 768, 4096, 64, 64, True, False, True): (1, 32, 2, 2), + (768, 768, 4096, 128, 128, False, True, True): (1, 32, 1, 32), + (768, 768, 4096, 128, 128, True, False, True): (6, 32, 1, 4), + (768, 768, 8192, 32, 32, False, True, True): (1, 128, 1, 8), + (768, 768, 8192, 32, 32, True, False, True): (1, 64, 1, 4), + (768, 768, 8192, 64, 64, False, True, True): (1, 128, 1, 8), + (768, 768, 8192, 64, 64, True, False, True): (4, 32, 3, 4), + (768, 768, 8192, 128, 128, False, True, True): (2, 64, 1, 16), + (768, 768, 8192, 128, 128, True, False, True): (2, 64, 3, 4), + (768, 768, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (768, 768, 16384, 32, 32, True, False, True): (1, 128, 1, 4), + (768, 768, 16384, 64, 64, False, True, True): (1, 256, 1, 8), + (768, 768, 16384, 64, 64, True, False, True): (1, 128, 3, 2), + (768, 768, 16384, 128, 128, False, True, True): (2, 128, 1, 16), + (768, 768, 16384, 128, 128, True, False, True): (2, 128, 1, 4), + (768, 768, 32768, 32, 32, False, True, True): (1, 512, 1, 8), + (768, 768, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (768, 768, 32768, 64, 64, False, True, True): (2, 512, 1, 8), + (768, 768, 32768, 64, 64, True, False, True): (1, 256, 3, 2), + (768, 768, 32768, 128, 128, False, True, True): (2, 256, 1, 16), + (768, 768, 32768, 128, 128, True, False, True): (3, 256, 1, 4), + (768, 768, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (768, 768, 65536, 32, 32, True, False, True): (1, 512, 3, 2), + (768, 768, 65536, 64, 64, False, True, True): (2, 512, 1, 4), + (768, 768, 65536, 64, 64, True, False, True): (1, 512, 3, 2), + (768, 768, 65536, 128, 128, False, True, True): (2, 512, 1, 16), + (768, 768, 65536, 128, 128, True, False, True): (2, 512, 1, 4), + (768, 768, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (768, 768, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (768, 768, 131072, 64, 64, False, True, True): (2, 1024, 1, 4), + (768, 768, 131072, 64, 64, True, False, True): (2, 1024, 3, 2), + (768, 768, 131072, 128, 128, False, True, True): (2, 1024, 1, 16), + (768, 768, 131072, 128, 128, True, False, True): (2, 1024, 1, 4), + (768, 3072, 256, 32, 32, False, True, True): (3, 8, 4, 8), + (768, 3072, 256, 32, 32, True, False, True): (3, 8, 5, 4), + (768, 3072, 256, 64, 64, False, True, True): (1, 4, 4, 16), + (768, 3072, 256, 64, 64, True, False, True): (1, 4, 4, 4), + (768, 3072, 256, 128, 128, False, True, True): (2, 2, 1, 8), + (768, 3072, 256, 128, 128, True, False, True): (2, 2, 4, 4), + (768, 3072, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (768, 3072, 256, 256, 256, True, False, True): (1, 1, 1, 32), + (768, 3072, 512, 32, 32, False, True, True): (1, 16, 1, 4), + (768, 3072, 512, 32, 32, True, False, True): (2, 4, 4, 4), + (768, 3072, 512, 64, 64, False, True, True): (3, 8, 4, 16), + (768, 3072, 512, 64, 64, True, False, True): (1, 8, 4, 4), + (768, 3072, 512, 128, 128, False, True, True): (2, 4, 1, 8), + (768, 3072, 512, 128, 128, True, False, True): (4, 4, 3, 4), + (768, 3072, 512, 256, 256, False, 
True, True): (1, 2, 1, 32), + (768, 3072, 512, 256, 256, True, False, True): (1, 2, 1, 32), + (768, 3072, 1024, 32, 32, False, True, True): (1, 16, 1, 8), + (768, 3072, 1024, 32, 32, True, False, True): (3, 8, 3, 4), + (768, 3072, 1024, 64, 64, False, True, True): (2, 16, 1, 16), + (768, 3072, 1024, 64, 64, True, False, True): (1, 8, 3, 4), + (768, 3072, 1024, 128, 128, False, True, True): (1, 8, 1, 8), + (768, 3072, 1024, 128, 128, True, False, True): (3, 8, 4, 4), + (768, 3072, 1024, 256, 256, False, True, True): (1, 4, 1, 32), + (768, 3072, 1024, 256, 256, True, False, True): (4, 4, 1, 32), + (768, 3072, 2048, 32, 32, False, True, True): (3, 32, 1, 8), + (768, 3072, 2048, 32, 32, True, False, True): (4, 8, 3, 4), + (768, 3072, 2048, 64, 64, False, True, True): (5, 16, 1, 16), + (768, 3072, 2048, 64, 64, True, False, True): (6, 8, 3, 4), + (768, 3072, 2048, 128, 128, False, True, True): (2, 16, 1, 16), + (768, 3072, 2048, 128, 128, True, False, True): (1, 16, 4, 4), + (768, 3072, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (768, 3072, 2048, 256, 256, True, False, True): (1, 8, 1, 32), + (768, 3072, 4096, 32, 32, False, True, True): (1, 64, 1, 8), + (768, 3072, 4096, 32, 32, True, False, True): (1, 32, 3, 4), + (768, 3072, 4096, 64, 64, False, True, True): (1, 64, 1, 8), + (768, 3072, 4096, 64, 64, True, False, True): (2, 16, 3, 4), + (768, 3072, 4096, 128, 128, False, True, True): (1, 32, 1, 8), + (768, 3072, 4096, 128, 128, True, False, True): (2, 32, 2, 4), + (768, 3072, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (768, 3072, 4096, 256, 256, True, False, True): (1, 16, 1, 32), + (768, 3072, 8192, 32, 32, False, True, True): (1, 128, 1, 8), + (768, 3072, 8192, 32, 32, True, False, True): (1, 64, 1, 4), + (768, 3072, 8192, 64, 64, False, True, True): (1, 128, 1, 8), + (768, 3072, 8192, 64, 64, True, False, True): (2, 32, 3, 4), + (768, 3072, 8192, 128, 128, False, True, True): (2, 64, 1, 16), + (768, 3072, 8192, 128, 128, True, False, True): (2, 64, 3, 4), + (768, 3072, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (768, 3072, 8192, 256, 256, True, False, True): (1, 32, 1, 32), + (768, 3072, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (768, 3072, 16384, 32, 32, True, False, True): (1, 128, 1, 4), + (768, 3072, 16384, 64, 64, False, True, True): (1, 256, 1, 8), + (768, 3072, 16384, 64, 64, True, False, True): (2, 64, 3, 4), + (768, 3072, 16384, 128, 128, False, True, True): (2, 128, 1, 16), + (768, 3072, 16384, 128, 128, True, False, True): (2, 128, 3, 4), + (768, 3072, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (768, 3072, 16384, 256, 256, True, False, True): (1, 64, 1, 32), + (768, 3072, 32768, 32, 32, False, True, True): (1, 512, 1, 8), + (768, 3072, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (768, 3072, 32768, 64, 64, False, True, True): (1, 512, 1, 8), + (768, 3072, 32768, 64, 64, True, False, True): (3, 128, 3, 4), + (768, 3072, 32768, 128, 128, False, True, True): (2, 256, 1, 16), + (768, 3072, 32768, 128, 128, True, False, True): (2, 256, 3, 4), + (768, 3072, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (768, 3072, 32768, 256, 256, True, False, True): (1, 128, 1, 32), + (768, 3072, 50432, 32, 32, False, True, True): (1, 788, 1, 8), + (768, 3072, 50432, 32, 32, True, False, True): (1, 394, 3, 2), + (768, 3072, 50432, 64, 64, False, True, True): (1, 788, 1, 8), + (768, 3072, 50432, 64, 64, True, False, True): (2, 197, 3, 4), + (768, 3072, 50432, 128, 128, False, True, True): (2, 394, 1, 16), + (768, 3072, 50432, 
128, 128, True, False, True): (2, 394, 3, 4), + (768, 3072, 50432, 256, 256, False, True, True): (1, 197, 1, 32), + (768, 3072, 50432, 256, 256, True, False, True): (1, 197, 1, 32), + (768, 3072, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (768, 3072, 65536, 32, 32, True, False, True): (1, 512, 3, 2), + (768, 3072, 65536, 64, 64, False, True, True): (1, 1024, 1, 8), + (768, 3072, 65536, 64, 64, True, False, True): (2, 256, 3, 4), + (768, 3072, 65536, 128, 128, False, True, True): (2, 512, 1, 16), + (768, 3072, 65536, 128, 128, True, False, True): (2, 512, 3, 4), + (768, 3072, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (768, 3072, 65536, 256, 256, True, False, True): (1, 256, 1, 32), + (768, 3072, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (768, 3072, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (768, 3072, 131072, 64, 64, False, True, True): (1, 2048, 1, 8), + (768, 3072, 131072, 64, 64, True, False, True): (2, 512, 3, 4), + (768, 3072, 131072, 128, 128, False, True, True): (2, 1024, 1, 16), + (768, 3072, 131072, 128, 128, True, False, True): (1, 1024, 3, 4), + (768, 3072, 131072, 256, 256, False, True, True): (1, 512, 1, 32), + (768, 3072, 131072, 256, 256, True, False, True): (1, 512, 1, 32), + (1024, 1024, 256, 32, 32, False, True, True): (1, 8, 1, 4), + (1024, 1024, 256, 32, 32, True, False, True): (1, 8, 5, 4), + (1024, 1024, 256, 64, 64, False, True, True): (1, 4, 1, 16), + (1024, 1024, 256, 64, 64, True, False, True): (4, 4, 4, 4), + (1024, 1024, 256, 128, 128, False, True, True): (1, 2, 1, 8), + (1024, 1024, 256, 128, 128, True, False, True): (1, 2, 3, 8), + (1024, 1024, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (1024, 1024, 256, 256, 256, True, False, True): (1, 1, 1, 32), + (1024, 1024, 512, 32, 32, False, True, True): (5, 16, 1, 4), + (1024, 1024, 512, 32, 32, True, False, True): (2, 8, 4, 2), + (1024, 1024, 512, 64, 64, False, True, True): (4, 8, 1, 16), + (1024, 1024, 512, 64, 64, True, False, True): (1, 4, 3, 4), + (1024, 1024, 512, 128, 128, False, True, True): (3, 4, 1, 16), + (1024, 1024, 512, 128, 128, True, False, True): (1, 4, 2, 4), + (1024, 1024, 512, 256, 256, False, True, True): (1, 2, 1, 32), + (1024, 1024, 512, 256, 256, True, False, True): (1, 2, 1, 32), + (1024, 1024, 1024, 32, 32, False, True, True): (1, 16, 1, 8), + (1024, 1024, 1024, 32, 32, True, False, True): (1, 8, 3, 4), + (1024, 1024, 1024, 64, 64, False, True, True): (3, 16, 1, 8), + (1024, 1024, 1024, 64, 64, True, False, True): (1, 16, 3, 2), + (1024, 1024, 1024, 128, 128, False, True, True): (1, 8, 1, 16), + (1024, 1024, 1024, 128, 128, True, False, True): (2, 8, 3, 8), + (1024, 1024, 1024, 256, 256, False, True, True): (1, 4, 1, 32), + (1024, 1024, 1024, 256, 256, True, False, True): (2, 4, 1, 32), + (1024, 1024, 2048, 32, 32, False, True, True): (2, 32, 1, 8), + (1024, 1024, 2048, 32, 32, True, False, True): (3, 16, 1, 4), + (1024, 1024, 2048, 64, 64, False, True, True): (1, 32, 1, 8), + (1024, 1024, 2048, 64, 64, True, False, True): (3, 32, 1, 4), + (1024, 1024, 2048, 128, 128, False, True, True): (4, 16, 1, 16), + (1024, 1024, 2048, 128, 128, True, False, True): (1, 16, 3, 4), + (1024, 1024, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (1024, 1024, 2048, 256, 256, True, False, True): (1, 8, 1, 32), + (1024, 1024, 4096, 32, 32, False, True, True): (4, 64, 1, 8), + (1024, 1024, 4096, 32, 32, True, False, True): (3, 32, 1, 4), + (1024, 1024, 4096, 64, 64, False, True, True): (3, 64, 1, 8), + (1024, 1024, 4096, 64, 64, True, False, True): 
(1, 32, 3, 2), + (1024, 1024, 4096, 128, 128, False, True, True): (4, 32, 1, 16), + (1024, 1024, 4096, 128, 128, True, False, True): (2, 32, 2, 4), + (1024, 1024, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (1024, 1024, 4096, 256, 256, True, False, True): (7, 16, 1, 32), + (1024, 1024, 8192, 32, 32, False, True, True): (1, 128, 1, 8), + (1024, 1024, 8192, 32, 32, True, False, True): (4, 64, 1, 4), + (1024, 1024, 8192, 64, 64, False, True, True): (2, 128, 1, 8), + (1024, 1024, 8192, 64, 64, True, False, True): (3, 32, 3, 4), + (1024, 1024, 8192, 128, 128, False, True, True): (4, 64, 1, 16), + (1024, 1024, 8192, 128, 128, True, False, True): (2, 64, 2, 4), + (1024, 1024, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (1024, 1024, 8192, 256, 256, True, False, True): (1, 32, 1, 32), + (1024, 1024, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (1024, 1024, 16384, 32, 32, True, False, True): (1, 128, 1, 4), + (1024, 1024, 16384, 64, 64, False, True, True): (1, 256, 1, 8), + (1024, 1024, 16384, 64, 64, True, False, True): (4, 64, 3, 4), + (1024, 1024, 16384, 128, 128, False, True, True): (4, 128, 1, 16), + (1024, 1024, 16384, 128, 128, True, False, True): (1, 128, 3, 4), + (1024, 1024, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (1024, 1024, 16384, 256, 256, True, False, True): (1, 64, 1, 32), + (1024, 1024, 32768, 32, 32, False, True, True): (1, 512, 1, 8), + (1024, 1024, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (1024, 1024, 32768, 64, 64, False, True, True): (1, 256, 1, 4), + (1024, 1024, 32768, 64, 64, True, False, True): (4, 128, 3, 4), + (1024, 1024, 32768, 128, 128, False, True, True): (4, 256, 1, 16), + (1024, 1024, 32768, 128, 128, True, False, True): (2, 256, 3, 4), + (1024, 1024, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (1024, 1024, 32768, 256, 256, True, False, True): (2, 128, 1, 32), + (1024, 1024, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (1024, 1024, 65536, 32, 32, True, False, True): (1, 512, 3, 2), + (1024, 1024, 65536, 64, 64, False, True, True): (1, 512, 1, 4), + (1024, 1024, 65536, 64, 64, True, False, True): (2, 256, 3, 4), + (1024, 1024, 65536, 128, 128, False, True, True): (4, 512, 1, 16), + (1024, 1024, 65536, 128, 128, True, False, True): (4, 512, 3, 4), + (1024, 1024, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (1024, 1024, 65536, 256, 256, True, False, True): (1, 256, 1, 32), + (1024, 1024, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (1024, 1024, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (1024, 1024, 131072, 64, 64, False, True, True): (2, 1024, 1, 4), + (1024, 1024, 131072, 64, 64, True, False, True): (2, 512, 3, 4), + (1024, 1024, 131072, 128, 128, False, True, True): (4, 1024, 1, 16), + (1024, 1024, 131072, 128, 128, True, False, True): (1, 1024, 3, 4), + (1024, 1024, 131072, 256, 256, False, True, True): (1, 512, 1, 32), + (1024, 1024, 131072, 256, 256, True, False, True): (1, 512, 1, 32), + (1536, 1536, 256, 32, 32, False, True, True): (1, 8, 1, 4), + (1536, 1536, 256, 32, 32, True, False, True): (2, 8, 1, 8), + (1536, 1536, 256, 64, 64, False, True, True): (4, 4, 1, 16), + (1536, 1536, 256, 64, 64, True, False, True): (1, 4, 4, 4), + (1536, 1536, 256, 128, 128, False, True, True): (2, 2, 1, 16), + (1536, 1536, 256, 128, 128, True, False, True): (2, 2, 3, 4), + (1536, 1536, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (1536, 1536, 256, 256, 256, True, False, True): (1, 1, 1, 32), + (1536, 1536, 512, 32, 32, False, True, True): (1, 8, 1, 8), + (1536, 
1536, 512, 32, 32, True, False, True): (3, 4, 4, 4), + (1536, 1536, 512, 64, 64, False, True, True): (3, 8, 1, 16), + (1536, 1536, 512, 64, 64, True, False, True): (1, 4, 3, 4), + (1536, 1536, 512, 128, 128, False, True, True): (1, 4, 1, 16), + (1536, 1536, 512, 128, 128, True, False, True): (2, 4, 4, 4), + (1536, 1536, 512, 256, 256, False, True, True): (1, 2, 1, 32), + (1536, 1536, 512, 256, 256, True, False, True): (1, 2, 1, 32), + (1536, 1536, 1024, 32, 32, False, True, True): (4, 16, 1, 8), + (1536, 1536, 1024, 32, 32, True, False, True): (2, 8, 1, 4), + (1536, 1536, 1024, 64, 64, False, True, True): (2, 16, 1, 16), + (1536, 1536, 1024, 64, 64, True, False, True): (2, 4, 3, 4), + (1536, 1536, 1024, 128, 128, False, True, True): (3, 8, 1, 32), + (1536, 1536, 1024, 128, 128, True, False, True): (4, 8, 3, 4), + (1536, 1536, 1024, 256, 256, False, True, True): (1, 4, 1, 32), + (1536, 1536, 1024, 256, 256, True, False, True): (1, 4, 1, 32), + (1536, 1536, 2048, 32, 32, False, True, True): (1, 32, 1, 8), + (1536, 1536, 2048, 32, 32, True, False, True): (1, 16, 1, 4), + (1536, 1536, 2048, 64, 64, False, True, True): (1, 32, 1, 8), + (1536, 1536, 2048, 64, 64, True, False, True): (1, 16, 2, 2), + (1536, 1536, 2048, 128, 128, False, True, True): (2, 16, 1, 16), + (1536, 1536, 2048, 128, 128, True, False, True): (4, 16, 2, 4), + (1536, 1536, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (1536, 1536, 2048, 256, 256, True, False, True): (1, 8, 1, 32), + (1536, 1536, 4096, 32, 32, False, True, True): (1, 64, 1, 8), + (1536, 1536, 4096, 32, 32, True, False, True): (1, 32, 1, 4), + (1536, 1536, 4096, 64, 64, False, True, True): (3, 64, 1, 8), + (1536, 1536, 4096, 64, 64, True, False, True): (1, 32, 3, 2), + (1536, 1536, 4096, 128, 128, False, True, True): (1, 32, 1, 8), + (1536, 1536, 4096, 128, 128, True, False, True): (2, 32, 2, 4), + (1536, 1536, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (1536, 1536, 4096, 256, 256, True, False, True): (2, 16, 1, 32), + (1536, 1536, 8192, 32, 32, False, True, True): (1, 128, 1, 8), + (1536, 1536, 8192, 32, 32, True, False, True): (1, 64, 1, 4), + (1536, 1536, 8192, 64, 64, False, True, True): (3, 128, 1, 8), + (1536, 1536, 8192, 64, 64, True, False, True): (1, 64, 3, 2), + (1536, 1536, 8192, 128, 128, False, True, True): (1, 64, 1, 8), + (1536, 1536, 8192, 128, 128, True, False, True): (1, 64, 2, 4), + (1536, 1536, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (1536, 1536, 8192, 256, 256, True, False, True): (2, 32, 1, 32), + (1536, 1536, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (1536, 1536, 16384, 32, 32, True, False, True): (1, 128, 3, 2), + (1536, 1536, 16384, 64, 64, False, True, True): (2, 128, 1, 4), + (1536, 1536, 16384, 64, 64, True, False, True): (2, 64, 3, 4), + (1536, 1536, 16384, 128, 128, False, True, True): (1, 128, 1, 8), + (1536, 1536, 16384, 128, 128, True, False, True): (2, 128, 2, 4), + (1536, 1536, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (1536, 1536, 16384, 256, 256, True, False, True): (2, 64, 1, 32), + (1536, 1536, 32768, 32, 32, False, True, True): (1, 512, 1, 8), + (1536, 1536, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (1536, 1536, 32768, 64, 64, False, True, True): (1, 256, 1, 4), + (1536, 1536, 32768, 64, 64, True, False, True): (3, 128, 3, 4), + (1536, 1536, 32768, 128, 128, False, True, True): (1, 256, 1, 8), + (1536, 1536, 32768, 128, 128, True, False, True): (1, 256, 2, 4), + (1536, 1536, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (1536, 1536, 32768, 256, 
256, True, False, True): (2, 128, 1, 32), + (1536, 1536, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (1536, 1536, 65536, 32, 32, True, False, True): (1, 512, 3, 2), + (1536, 1536, 65536, 64, 64, False, True, True): (1, 512, 1, 4), + (1536, 1536, 65536, 64, 64, True, False, True): (1, 512, 3, 2), + (1536, 1536, 65536, 128, 128, False, True, True): (1, 512, 1, 8), + (1536, 1536, 65536, 128, 128, True, False, True): (1, 512, 3, 4), + (1536, 1536, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (1536, 1536, 65536, 256, 256, True, False, True): (2, 256, 1, 32), + (1536, 1536, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (1536, 1536, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (1536, 1536, 131072, 64, 64, False, True, True): (3, 1024, 1, 4), + (1536, 1536, 131072, 64, 64, True, False, True): (3, 512, 3, 4), + (1536, 1536, 131072, 128, 128, False, True, True): (1, 1024, 1, 8), + (1536, 1536, 131072, 128, 128, True, False, True): (1, 1024, 3, 4), + (1536, 1536, 131072, 256, 256, False, True, True): (1, 512, 1, 32), + (1536, 1536, 131072, 256, 256, True, False, True): (2, 512, 1, 32), + (2048, 2048, 256, 32, 32, False, True, True): (3, 8, 1, 4), + (2048, 2048, 256, 32, 32, True, False, True): (1, 4, 4, 2), + (2048, 2048, 256, 64, 64, False, True, True): (2, 4, 1, 16), + (2048, 2048, 256, 64, 64, True, False, True): (1, 2, 3, 4), + (2048, 2048, 256, 128, 128, False, True, True): (1, 2, 1, 8), + (2048, 2048, 256, 128, 128, True, False, True): (1, 2, 4, 4), + (2048, 2048, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (2048, 2048, 256, 256, 256, True, False, True): (1, 1, 1, 32), + (2048, 2048, 512, 32, 32, False, True, True): (3, 8, 1, 8), + (2048, 2048, 512, 32, 32, True, False, True): (4, 4, 3, 2), + (2048, 2048, 512, 64, 64, False, True, True): (1, 8, 1, 8), + (2048, 2048, 512, 64, 64, True, False, True): (1, 8, 3, 4), + (2048, 2048, 512, 128, 128, False, True, True): (1, 4, 1, 8), + (2048, 2048, 512, 128, 128, True, False, True): (1, 4, 4, 4), + (2048, 2048, 512, 256, 256, False, True, True): (1, 2, 1, 32), + (2048, 2048, 512, 256, 256, True, False, True): (2, 2, 1, 32), + (2048, 2048, 1024, 32, 32, False, True, True): (1, 16, 1, 8), + (2048, 2048, 1024, 32, 32, True, False, True): (3, 8, 1, 4), + (2048, 2048, 1024, 64, 64, False, True, True): (4, 16, 1, 8), + (2048, 2048, 1024, 64, 64, True, False, True): (1, 8, 3, 2), + (2048, 2048, 1024, 128, 128, False, True, True): (4, 8, 1, 16), + (2048, 2048, 1024, 128, 128, True, False, True): (2, 8, 2, 4), + (2048, 2048, 1024, 256, 256, False, True, True): (1, 4, 1, 32), + (2048, 2048, 1024, 256, 256, True, False, True): (3, 4, 1, 32), + (2048, 2048, 2048, 32, 32, False, True, True): (1, 32, 1, 8), + (2048, 2048, 2048, 32, 32, True, False, True): (1, 16, 1, 4), + (2048, 2048, 2048, 64, 64, False, True, True): (1, 32, 1, 8), + (2048, 2048, 2048, 64, 64, True, False, True): (1, 16, 3, 2), + (2048, 2048, 2048, 128, 128, False, True, True): (4, 16, 1, 16), + (2048, 2048, 2048, 128, 128, True, False, True): (2, 16, 2, 4), + (2048, 2048, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (2048, 2048, 2048, 256, 256, True, False, True): (1, 8, 1, 32), + (2048, 2048, 4096, 32, 32, False, True, True): (1, 64, 1, 8), + (2048, 2048, 4096, 32, 32, True, False, True): (1, 32, 1, 4), + (2048, 2048, 4096, 64, 64, False, True, True): (4, 64, 1, 8), + (2048, 2048, 4096, 64, 64, True, False, True): (2, 16, 3, 4), + (2048, 2048, 4096, 128, 128, False, True, True): (4, 32, 1, 8), + (2048, 2048, 4096, 128, 128, True, False, 
True): (1, 32, 2, 4), + (2048, 2048, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (2048, 2048, 4096, 256, 256, True, False, True): (4, 16, 1, 32), + (2048, 2048, 8192, 32, 32, False, True, True): (1, 128, 1, 8), + (2048, 2048, 8192, 32, 32, True, False, True): (1, 64, 1, 4), + (2048, 2048, 8192, 64, 64, False, True, True): (2, 64, 1, 4), + (2048, 2048, 8192, 64, 64, True, False, True): (2, 32, 3, 4), + (2048, 2048, 8192, 128, 128, False, True, True): (4, 64, 1, 8), + (2048, 2048, 8192, 128, 128, True, False, True): (2, 64, 2, 4), + (2048, 2048, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (2048, 2048, 8192, 256, 256, True, False, True): (4, 32, 1, 32), + (2048, 2048, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (2048, 2048, 16384, 32, 32, True, False, True): (1, 128, 3, 2), + (2048, 2048, 16384, 64, 64, False, True, True): (2, 128, 1, 4), + (2048, 2048, 16384, 64, 64, True, False, True): (2, 64, 3, 4), + (2048, 2048, 16384, 128, 128, False, True, True): (1, 128, 1, 8), + (2048, 2048, 16384, 128, 128, True, False, True): (2, 128, 2, 4), + (2048, 2048, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (2048, 2048, 16384, 256, 256, True, False, True): (4, 64, 1, 32), + (2048, 2048, 32768, 32, 32, False, True, True): (1, 512, 1, 8), + (2048, 2048, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (2048, 2048, 32768, 64, 64, False, True, True): (2, 256, 1, 4), + (2048, 2048, 32768, 64, 64, True, False, True): (2, 128, 3, 4), + (2048, 2048, 32768, 128, 128, False, True, True): (1, 256, 1, 8), + (2048, 2048, 32768, 128, 128, True, False, True): (2, 256, 2, 4), + (2048, 2048, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (2048, 2048, 32768, 256, 256, True, False, True): (4, 128, 1, 32), + (2048, 2048, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (2048, 2048, 65536, 32, 32, True, False, True): (1, 512, 3, 2), + (2048, 2048, 65536, 64, 64, False, True, True): (1, 512, 1, 4), + (2048, 2048, 65536, 64, 64, True, False, True): (2, 256, 3, 4), + (2048, 2048, 65536, 128, 128, False, True, True): (1, 512, 1, 8), + (2048, 2048, 65536, 128, 128, True, False, True): (1, 512, 2, 4), + (2048, 2048, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (2048, 2048, 65536, 256, 256, True, False, True): (4, 256, 1, 32), + (2048, 2048, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (2048, 2048, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (2048, 2048, 131072, 64, 64, False, True, True): (1, 1024, 1, 4), + (2048, 2048, 131072, 64, 64, True, False, True): (2, 512, 3, 4), + (2048, 2048, 131072, 128, 128, False, True, True): (1, 1024, 1, 8), + (2048, 2048, 131072, 128, 128, True, False, True): (1, 1024, 3, 4), + (2048, 2048, 131072, 256, 256, False, True, True): (1, 512, 1, 32), + (2048, 2048, 131072, 256, 256, True, False, True): (4, 512, 1, 32), + (3072, 768, 256, 32, 32, False, True, True): (5, 4, 1, 8), + (3072, 768, 256, 32, 32, True, False, True): (2, 2, 4, 4), + (3072, 768, 256, 64, 64, False, True, True): (1, 4, 1, 16), + (3072, 768, 256, 64, 64, True, False, True): (2, 2, 3, 4), + (3072, 768, 256, 128, 128, False, True, True): (5, 2, 1, 16), + (3072, 768, 256, 128, 128, True, False, True): (1, 2, 5, 4), + (3072, 768, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (3072, 768, 256, 256, 256, True, False, True): (1, 1, 1, 32), + (3072, 768, 512, 32, 32, False, True, True): (1, 8, 1, 8), + (3072, 768, 512, 32, 32, True, False, True): (5, 4, 1, 4), + (3072, 768, 512, 64, 64, False, True, True): (1, 8, 1, 8), + (3072, 768, 512, 64, 64, 
True, False, True): (3, 2, 3, 4), + (3072, 768, 512, 128, 128, False, True, True): (3, 4, 1, 32), + (3072, 768, 512, 128, 128, True, False, True): (2, 4, 3, 4), + (3072, 768, 512, 256, 256, False, True, True): (1, 2, 1, 32), + (3072, 768, 512, 256, 256, True, False, True): (2, 2, 1, 32), + (3072, 768, 1024, 32, 32, False, True, True): (2, 16, 1, 8), + (3072, 768, 1024, 32, 32, True, False, True): (3, 8, 1, 4), + (3072, 768, 1024, 64, 64, False, True, True): (4, 16, 1, 8), + (3072, 768, 1024, 64, 64, True, False, True): (1, 8, 3, 2), + (3072, 768, 1024, 128, 128, False, True, True): (2, 8, 1, 32), + (3072, 768, 1024, 128, 128, True, False, True): (3, 8, 2, 4), + (3072, 768, 1024, 256, 256, False, True, True): (1, 4, 1, 32), + (3072, 768, 1024, 256, 256, True, False, True): (4, 4, 1, 32), + (3072, 768, 2048, 32, 32, False, True, True): (1, 32, 1, 8), + (3072, 768, 2048, 32, 32, True, False, True): (1, 16, 1, 4), + (3072, 768, 2048, 64, 64, False, True, True): (2, 32, 1, 8), + (3072, 768, 2048, 64, 64, True, False, True): (2, 8, 3, 4), + (3072, 768, 2048, 128, 128, False, True, True): (2, 16, 1, 16), + (3072, 768, 2048, 128, 128, True, False, True): (2, 16, 1, 4), + (3072, 768, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (3072, 768, 2048, 256, 256, True, False, True): (2, 8, 1, 32), + (3072, 768, 4096, 32, 32, False, True, True): (1, 64, 1, 8), + (3072, 768, 4096, 32, 32, True, False, True): (1, 32, 1, 2), + (3072, 768, 4096, 64, 64, False, True, True): (2, 64, 1, 8), + (3072, 768, 4096, 64, 64, True, False, True): (2, 32, 2, 2), + (3072, 768, 4096, 128, 128, False, True, True): (1, 32, 1, 8), + (3072, 768, 4096, 128, 128, True, False, True): (2, 32, 2, 4), + (3072, 768, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (3072, 768, 4096, 256, 256, True, False, True): (4, 16, 1, 32), + (3072, 768, 8192, 32, 32, False, True, True): (1, 128, 1, 8), + (3072, 768, 8192, 32, 32, True, False, True): (3, 64, 1, 2), + (3072, 768, 8192, 64, 64, False, True, True): (1, 128, 1, 8), + (3072, 768, 8192, 64, 64, True, False, True): (2, 64, 2, 2), + (3072, 768, 8192, 128, 128, False, True, True): (1, 64, 1, 8), + (3072, 768, 8192, 128, 128, True, False, True): (2, 64, 2, 4), + (3072, 768, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (3072, 768, 8192, 256, 256, True, False, True): (4, 32, 1, 32), + (3072, 768, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (3072, 768, 16384, 32, 32, True, False, True): (1, 128, 1, 2), + (3072, 768, 16384, 64, 64, False, True, True): (2, 128, 1, 4), + (3072, 768, 16384, 64, 64, True, False, True): (1, 128, 2, 2), + (3072, 768, 16384, 128, 128, False, True, True): (1, 128, 1, 8), + (3072, 768, 16384, 128, 128, True, False, True): (1, 128, 1, 4), + (3072, 768, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (3072, 768, 16384, 256, 256, True, False, True): (4, 64, 1, 32), + (3072, 768, 32768, 32, 32, False, True, True): (1, 512, 1, 8), + (3072, 768, 32768, 32, 32, True, False, True): (1, 256, 1, 2), + (3072, 768, 32768, 64, 64, False, True, True): (1, 256, 1, 4), + (3072, 768, 32768, 64, 64, True, False, True): (2, 256, 2, 2), + (3072, 768, 32768, 128, 128, False, True, True): (1, 256, 1, 8), + (3072, 768, 32768, 128, 128, True, False, True): (2, 256, 1, 4), + (3072, 768, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (3072, 768, 32768, 256, 256, True, False, True): (4, 128, 1, 32), + (3072, 768, 50432, 32, 32, False, True, True): (1, 788, 1, 8), + (3072, 768, 50432, 32, 32, True, False, True): (1, 394, 1, 2), + (3072, 768, 50432, 64, 
64, False, True, True): (2, 394, 1, 4), + (3072, 768, 50432, 64, 64, True, False, True): (2, 394, 2, 2), + (3072, 768, 50432, 128, 128, False, True, True): (1, 394, 1, 8), + (3072, 768, 50432, 128, 128, True, False, True): (2, 394, 1, 4), + (3072, 768, 50432, 256, 256, False, True, True): (1, 197, 1, 32), + (3072, 768, 50432, 256, 256, True, False, True): (1, 197, 1, 32), + (3072, 768, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (3072, 768, 65536, 32, 32, True, False, True): (1, 512, 1, 2), + (3072, 768, 65536, 64, 64, False, True, True): (1, 512, 1, 4), + (3072, 768, 65536, 64, 64, True, False, True): (2, 512, 2, 2), + (3072, 768, 65536, 128, 128, False, True, True): (1, 512, 1, 8), + (3072, 768, 65536, 128, 128, True, False, True): (2, 512, 1, 4), + (3072, 768, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (3072, 768, 65536, 256, 256, True, False, True): (4, 256, 1, 32), + (3072, 768, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (3072, 768, 131072, 32, 32, True, False, True): (1, 1024, 1, 2), + (3072, 768, 131072, 64, 64, False, True, True): (2, 1024, 1, 4), + (3072, 768, 131072, 64, 64, True, False, True): (2, 1024, 2, 2), + (3072, 768, 131072, 128, 128, False, True, True): (1, 1024, 1, 8), + (3072, 768, 131072, 128, 128, True, False, True): (2, 1024, 1, 4), + (3072, 768, 131072, 256, 256, False, True, True): (1, 512, 1, 32), + (3072, 768, 131072, 256, 256, True, False, True): (4, 512, 1, 32), + (3072, 3072, 256, 32, 32, False, True, True): (1, 4, 1, 8), + (3072, 3072, 256, 32, 32, True, False, True): (2, 2, 5, 4), + (3072, 3072, 256, 64, 64, False, True, True): (2, 4, 1, 16), + (3072, 3072, 256, 64, 64, True, False, True): (3, 2, 3, 4), + (3072, 3072, 256, 128, 128, False, True, True): (1, 2, 1, 8), + (3072, 3072, 256, 128, 128, True, False, True): (1, 2, 5, 4), + (3072, 3072, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (3072, 3072, 256, 256, 256, True, False, True): (1, 1, 1, 32), + (3072, 3072, 512, 32, 32, False, True, True): (1, 8, 1, 8), + (3072, 3072, 512, 32, 32, True, False, True): (3, 2, 3, 4), + (3072, 3072, 512, 64, 64, False, True, True): (1, 8, 1, 8), + (3072, 3072, 512, 64, 64, True, False, True): (3, 2, 3, 4), + (3072, 3072, 512, 128, 128, False, True, True): (2, 4, 1, 8), + (3072, 3072, 512, 128, 128, True, False, True): (2, 4, 4, 4), + (3072, 3072, 512, 256, 256, False, True, True): (1, 2, 1, 32), + (3072, 3072, 512, 256, 256, True, False, True): (1, 2, 1, 32), + (3072, 3072, 1024, 32, 32, False, True, True): (1, 16, 1, 8), + (3072, 3072, 1024, 32, 32, True, False, True): (3, 8, 3, 4), + (3072, 3072, 1024, 64, 64, False, True, True): (2, 16, 1, 8), + (3072, 3072, 1024, 64, 64, True, False, True): (2, 4, 3, 4), + (3072, 3072, 1024, 128, 128, False, True, True): (1, 8, 1, 8), + (3072, 3072, 1024, 128, 128, True, False, True): (3, 8, 2, 4), + (3072, 3072, 1024, 256, 256, False, True, True): (1, 4, 1, 32), + (3072, 3072, 1024, 256, 256, True, False, True): (3, 4, 1, 32), + (3072, 3072, 2048, 32, 32, False, True, True): (1, 32, 1, 8), + (3072, 3072, 2048, 32, 32, True, False, True): (1, 16, 1, 4), + (3072, 3072, 2048, 64, 64, False, True, True): (1, 32, 1, 8), + (3072, 3072, 2048, 64, 64, True, False, True): (1, 16, 3, 2), + (3072, 3072, 2048, 128, 128, False, True, True): (1, 16, 1, 8), + (3072, 3072, 2048, 128, 128, True, False, True): (2, 16, 2, 4), + (3072, 3072, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (3072, 3072, 2048, 256, 256, True, False, True): (3, 8, 1, 32), + (3072, 3072, 4096, 32, 32, False, True, True): (1, 
64, 1, 8), + (3072, 3072, 4096, 32, 32, True, False, True): (1, 32, 1, 4), + (3072, 3072, 4096, 64, 64, False, True, True): (1, 64, 1, 8), + (3072, 3072, 4096, 64, 64, True, False, True): (3, 16, 3, 4), + (3072, 3072, 4096, 128, 128, False, True, True): (1, 32, 1, 8), + (3072, 3072, 4096, 128, 128, True, False, True): (2, 32, 2, 4), + (3072, 3072, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (3072, 3072, 4096, 256, 256, True, False, True): (2, 16, 1, 32), + (3072, 3072, 8192, 32, 32, False, True, True): (1, 128, 1, 8), + (3072, 3072, 8192, 32, 32, True, False, True): (1, 64, 1, 2), + (3072, 3072, 8192, 64, 64, False, True, True): (1, 64, 1, 4), + (3072, 3072, 8192, 64, 64, True, False, True): (1, 64, 3, 2), + (3072, 3072, 8192, 128, 128, False, True, True): (1, 64, 1, 8), + (3072, 3072, 8192, 128, 128, True, False, True): (2, 64, 2, 4), + (3072, 3072, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (3072, 3072, 8192, 256, 256, True, False, True): (4, 32, 1, 32), + (3072, 3072, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (3072, 3072, 16384, 32, 32, True, False, True): (1, 128, 3, 2), + (3072, 3072, 16384, 64, 64, False, True, True): (1, 128, 1, 4), + (3072, 3072, 16384, 64, 64, True, False, True): (2, 64, 3, 4), + (3072, 3072, 16384, 128, 128, False, True, True): (1, 128, 1, 8), + (3072, 3072, 16384, 128, 128, True, False, True): (1, 128, 2, 4), + (3072, 3072, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (3072, 3072, 16384, 256, 256, True, False, True): (4, 64, 1, 32), + (3072, 3072, 32768, 32, 32, False, True, True): (1, 512, 1, 8), + (3072, 3072, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (3072, 3072, 32768, 64, 64, False, True, True): (1, 256, 1, 4), + (3072, 3072, 32768, 64, 64, True, False, True): (1, 256, 3, 2), + (3072, 3072, 32768, 128, 128, False, True, True): (1, 256, 1, 8), + (3072, 3072, 32768, 128, 128, True, False, True): (1, 256, 2, 4), + (3072, 3072, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (3072, 3072, 32768, 256, 256, True, False, True): (4, 128, 1, 32), + (3072, 3072, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (3072, 3072, 65536, 32, 32, True, False, True): (1, 512, 3, 2), + (3072, 3072, 65536, 64, 64, False, True, True): (1, 512, 1, 4), + (3072, 3072, 65536, 64, 64, True, False, True): (2, 256, 3, 4), + (3072, 3072, 65536, 128, 128, False, True, True): (1, 512, 1, 8), + (3072, 3072, 65536, 128, 128, True, False, True): (1, 512, 3, 4), + (3072, 3072, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (3072, 3072, 65536, 256, 256, True, False, True): (4, 256, 1, 32), + (3072, 3072, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (3072, 3072, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (3072, 3072, 131072, 64, 64, False, True, True): (1, 1024, 1, 4), + (3072, 3072, 131072, 64, 64, True, False, True): (1, 1024, 3, 2), + (3072, 3072, 131072, 128, 128, False, True, True): (1, 1024, 1, 8), + (3072, 3072, 131072, 128, 128, True, False, True): (1, 1024, 3, 4), + (3072, 3072, 131072, 256, 256, False, True, True): (1, 512, 1, 32), + (3072, 3072, 131072, 256, 256, True, False, True): (4, 512, 1, 32), + (4096, 4096, 256, 32, 32, False, True, True): (1, 4, 1, 8), + (4096, 4096, 256, 32, 32, True, False, True): (5, 2, 3, 4), + (4096, 4096, 256, 64, 64, False, True, True): (3, 4, 1, 8), + (4096, 4096, 256, 64, 64, True, False, True): (3, 4, 3, 2), + (4096, 4096, 256, 128, 128, False, True, True): (1, 2, 1, 8), + (4096, 4096, 256, 128, 128, True, False, True): (2, 2, 4, 4), + (4096, 4096, 256, 
256, 256, False, True, True): (1, 1, 1, 32), + (4096, 4096, 256, 256, 256, True, False, True): (1, 1, 1, 32), + (4096, 4096, 512, 32, 32, False, True, True): (1, 8, 1, 8), + (4096, 4096, 512, 32, 32, True, False, True): (1, 4, 1, 4), + (4096, 4096, 512, 64, 64, False, True, True): (1, 8, 1, 8), + (4096, 4096, 512, 64, 64, True, False, True): (3, 4, 2, 2), + (4096, 4096, 512, 128, 128, False, True, True): (2, 4, 1, 8), + (4096, 4096, 512, 128, 128, True, False, True): (2, 4, 2, 4), + (4096, 4096, 512, 256, 256, False, True, True): (2, 2, 1, 32), + (4096, 4096, 512, 256, 256, True, False, True): (2, 2, 1, 32), + (4096, 4096, 1024, 32, 32, False, True, True): (4, 16, 1, 8), + (4096, 4096, 1024, 32, 32, True, False, True): (1, 8, 1, 4), + (4096, 4096, 1024, 64, 64, False, True, True): (1, 16, 1, 8), + (4096, 4096, 1024, 64, 64, True, False, True): (4, 4, 3, 4), + (4096, 4096, 1024, 128, 128, False, True, True): (2, 8, 1, 8), + (4096, 4096, 1024, 128, 128, True, False, True): (1, 8, 3, 4), + (4096, 4096, 1024, 256, 256, False, True, True): (1, 4, 1, 32), + (4096, 4096, 1024, 256, 256, True, False, True): (6, 4, 1, 32), + (4096, 4096, 2048, 32, 32, False, True, True): (1, 32, 1, 8), + (4096, 4096, 2048, 32, 32, True, False, True): (1, 16, 1, 4), + (4096, 4096, 2048, 64, 64, False, True, True): (4, 32, 1, 8), + (4096, 4096, 2048, 64, 64, True, False, True): (4, 8, 3, 4), + (4096, 4096, 2048, 128, 128, False, True, True): (2, 16, 1, 8), + (4096, 4096, 2048, 128, 128, True, False, True): (1, 16, 3, 4), + (4096, 4096, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (4096, 4096, 2048, 256, 256, True, False, True): (4, 8, 1, 32), + (4096, 4096, 4096, 32, 32, False, True, True): (1, 64, 1, 8), + (4096, 4096, 4096, 32, 32, True, False, True): (1, 32, 1, 4), + (4096, 4096, 4096, 64, 64, False, True, True): (1, 64, 1, 8), + (4096, 4096, 4096, 64, 64, True, False, True): (1, 32, 3, 2), + (4096, 4096, 4096, 128, 128, False, True, True): (1, 32, 1, 8), + (4096, 4096, 4096, 128, 128, True, False, True): (2, 32, 3, 4), + (4096, 4096, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (4096, 4096, 4096, 256, 256, True, False, True): (4, 16, 1, 32), + (4096, 4096, 8192, 32, 32, False, True, True): (1, 128, 1, 8), + (4096, 4096, 8192, 32, 32, True, False, True): (1, 64, 1, 4), + (4096, 4096, 8192, 64, 64, False, True, True): (1, 128, 1, 8), + (4096, 4096, 8192, 64, 64, True, False, True): (1, 64, 3, 2), + (4096, 4096, 8192, 128, 128, False, True, True): (1, 64, 1, 8), + (4096, 4096, 8192, 128, 128, True, False, True): (1, 64, 3, 4), + (4096, 4096, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (4096, 4096, 8192, 256, 256, True, False, True): (4, 32, 1, 32), + (4096, 4096, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (4096, 4096, 16384, 32, 32, True, False, True): (1, 128, 3, 2), + (4096, 4096, 16384, 64, 64, False, True, True): (1, 128, 1, 4), + (4096, 4096, 16384, 64, 64, True, False, True): (4, 64, 3, 4), + (4096, 4096, 16384, 128, 128, False, True, True): (1, 128, 1, 8), + (4096, 4096, 16384, 128, 128, True, False, True): (1, 128, 3, 4), + (4096, 4096, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (4096, 4096, 16384, 256, 256, True, False, True): (4, 64, 1, 32), + (4096, 4096, 32768, 32, 32, False, True, True): (1, 512, 1, 8), + (4096, 4096, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (4096, 4096, 32768, 64, 64, False, True, True): (1, 256, 1, 4), + (4096, 4096, 32768, 64, 64, True, False, True): (1, 256, 3, 2), + (4096, 4096, 32768, 128, 128, False, True, True): (1, 
256, 1, 8), + (4096, 4096, 32768, 128, 128, True, False, True): (1, 256, 3, 4), + (4096, 4096, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (4096, 4096, 32768, 256, 256, True, False, True): (4, 128, 1, 32), + (4096, 4096, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (4096, 4096, 65536, 32, 32, True, False, True): (1, 512, 3, 2), + (4096, 4096, 65536, 64, 64, False, True, True): (1, 512, 1, 4), + (4096, 4096, 65536, 64, 64, True, False, True): (4, 256, 3, 4), + (4096, 4096, 65536, 128, 128, False, True, True): (1, 512, 1, 8), + (4096, 4096, 65536, 128, 128, True, False, True): (1, 512, 3, 4), + (4096, 4096, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (4096, 4096, 65536, 256, 256, True, False, True): (4, 256, 1, 32), + (4096, 4096, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (4096, 4096, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (4096, 4096, 131072, 64, 64, False, True, True): (1, 2048, 1, 8), + (4096, 4096, 131072, 64, 64, True, False, True): (1, 1024, 3, 2), + (4096, 4096, 131072, 128, 128, False, True, True): (1, 1024, 1, 8), + (4096, 4096, 131072, 128, 128, True, False, True): (1, 1024, 3, 4), + (4096, 4096, 131072, 256, 256, False, True, True): (1, 512, 1, 32), + (4096, 4096, 131072, 256, 256, True, False, True): (4, 512, 1, 32), + (6144, 6144, 256, 32, 32, False, True, True): (2, 4, 1, 8), + (6144, 6144, 256, 32, 32, True, False, True): (2, 1, 4, 4), + (6144, 6144, 256, 64, 64, False, True, True): (1, 4, 1, 8), + (6144, 6144, 256, 64, 64, True, False, True): (5, 1, 3, 4), + (6144, 6144, 256, 128, 128, False, True, True): (1, 2, 1, 8), + (6144, 6144, 256, 128, 128, True, False, True): (1, 2, 3, 4), + (6144, 6144, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (6144, 6144, 256, 256, 256, True, False, True): (1, 1, 1, 32), + (6144, 6144, 512, 32, 32, False, True, True): (1, 8, 1, 8), + (6144, 6144, 512, 32, 32, True, False, True): (1, 4, 4, 2), + (6144, 6144, 512, 64, 64, False, True, True): (2, 8, 1, 8), + (6144, 6144, 512, 64, 64, True, False, True): (2, 2, 3, 4), + (6144, 6144, 512, 128, 128, False, True, True): (3, 4, 1, 8), + (6144, 6144, 512, 128, 128, True, False, True): (2, 4, 3, 4), + (6144, 6144, 512, 256, 256, False, True, True): (1, 2, 1, 32), + (6144, 6144, 512, 256, 256, True, False, True): (2, 2, 1, 32), + (6144, 6144, 1024, 32, 32, False, True, True): (1, 16, 1, 8), + (6144, 6144, 1024, 32, 32, True, False, True): (1, 8, 1, 4), + (6144, 6144, 1024, 64, 64, False, True, True): (1, 16, 1, 8), + (6144, 6144, 1024, 64, 64, True, False, True): (4, 4, 3, 4), + (6144, 6144, 1024, 128, 128, False, True, True): (1, 8, 1, 8), + (6144, 6144, 1024, 128, 128, True, False, True): (3, 8, 3, 4), + (6144, 6144, 1024, 256, 256, False, True, True): (1, 4, 1, 32), + (6144, 6144, 1024, 256, 256, True, False, True): (1, 4, 1, 32), + (6144, 6144, 2048, 32, 32, False, True, True): (1, 32, 1, 8), + (6144, 6144, 2048, 32, 32, True, False, True): (1, 16, 1, 4), + (6144, 6144, 2048, 64, 64, False, True, True): (1, 32, 1, 8), + (6144, 6144, 2048, 64, 64, True, False, True): (4, 8, 3, 4), + (6144, 6144, 2048, 128, 128, False, True, True): (1, 16, 1, 8), + (6144, 6144, 2048, 128, 128, True, False, True): (3, 16, 3, 4), + (6144, 6144, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (6144, 6144, 2048, 256, 256, True, False, True): (4, 8, 1, 32), + (6144, 6144, 4096, 32, 32, False, True, True): (1, 64, 1, 8), + (6144, 6144, 4096, 32, 32, True, False, True): (1, 32, 1, 4), + (6144, 6144, 4096, 64, 64, False, True, True): (1, 64, 1, 8), + 
(6144, 6144, 4096, 64, 64, True, False, True): (4, 16, 3, 4), + (6144, 6144, 4096, 128, 128, False, True, True): (1, 32, 1, 8), + (6144, 6144, 4096, 128, 128, True, False, True): (4, 32, 3, 4), + (6144, 6144, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (6144, 6144, 4096, 256, 256, True, False, True): (4, 16, 1, 32), + (6144, 6144, 8192, 32, 32, False, True, True): (1, 128, 1, 8), + (6144, 6144, 8192, 32, 32, True, False, True): (1, 64, 1, 4), + (6144, 6144, 8192, 64, 64, False, True, True): (1, 128, 1, 8), + (6144, 6144, 8192, 64, 64, True, False, True): (4, 32, 3, 4), + (6144, 6144, 8192, 128, 128, False, True, True): (1, 64, 1, 8), + (6144, 6144, 8192, 128, 128, True, False, True): (1, 64, 3, 4), + (6144, 6144, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (6144, 6144, 8192, 256, 256, True, False, True): (4, 32, 1, 32), + (6144, 6144, 16384, 32, 32, False, True, True): (1, 256, 1, 8), + (6144, 6144, 16384, 32, 32, True, False, True): (1, 128, 1, 4), + (6144, 6144, 16384, 64, 64, False, True, True): (1, 256, 1, 8), + (6144, 6144, 16384, 64, 64, True, False, True): (4, 64, 3, 4), + (6144, 6144, 16384, 128, 128, False, True, True): (1, 128, 1, 8), + (6144, 6144, 16384, 128, 128, True, False, True): (4, 128, 3, 4), + (6144, 6144, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (6144, 6144, 16384, 256, 256, True, False, True): (4, 64, 1, 32), + (6144, 6144, 32768, 32, 32, False, True, True): (1, 512, 1, 8), + (6144, 6144, 32768, 32, 32, True, False, True): (1, 256, 1, 4), + (6144, 6144, 32768, 64, 64, False, True, True): (1, 512, 1, 8), + (6144, 6144, 32768, 64, 64, True, False, True): (4, 128, 3, 4), + (6144, 6144, 32768, 128, 128, False, True, True): (1, 256, 1, 8), + (6144, 6144, 32768, 128, 128, True, False, True): (1, 256, 3, 4), + (6144, 6144, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (6144, 6144, 32768, 256, 256, True, False, True): (4, 128, 1, 32), + (6144, 6144, 65536, 32, 32, False, True, True): (1, 1024, 1, 8), + (6144, 6144, 65536, 32, 32, True, False, True): (1, 512, 1, 4), + (6144, 6144, 65536, 64, 64, False, True, True): (1, 1024, 1, 8), + (6144, 6144, 65536, 64, 64, True, False, True): (4, 256, 3, 4), + (6144, 6144, 65536, 128, 128, False, True, True): (1, 512, 1, 8), + (6144, 6144, 65536, 128, 128, True, False, True): (1, 512, 3, 4), + (6144, 6144, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (6144, 6144, 65536, 256, 256, True, False, True): (4, 256, 1, 32), + (6144, 6144, 131072, 32, 32, False, True, True): (1, 2048, 1, 8), + (6144, 6144, 131072, 32, 32, True, False, True): (1, 1024, 1, 4), + (6144, 6144, 131072, 64, 64, False, True, True): (1, 2048, 1, 8), + (6144, 6144, 131072, 64, 64, True, False, True): (4, 512, 3, 4), + (6144, 6144, 131072, 128, 128, False, True, True): (1, 1024, 1, 8), + (6144, 6144, 131072, 128, 128, True, False, True): (1, 1024, 3, 4), + (6144, 6144, 131072, 256, 256, False, True, True): (1, 512, 1, 32), + (6144, 6144, 131072, 256, 256, True, False, True): (4, 512, 1, 32), + (8192, 8192, 256, 32, 32, False, True, True): (1, 4, 1, 8), + (8192, 8192, 256, 32, 32, True, False, True): (3, 2, 3, 4), + (8192, 8192, 256, 64, 64, False, True, True): (1, 4, 1, 4), + (8192, 8192, 256, 64, 64, True, False, True): (1, 4, 1, 4), + (8192, 8192, 256, 128, 128, False, True, True): (1, 2, 1, 8), + (8192, 8192, 256, 128, 128, True, False, True): (2, 2, 3, 4), + (8192, 8192, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (8192, 8192, 256, 256, 256, True, False, True): (1, 1, 1, 32), + (8192, 8192, 512, 32, 32, 
False, True, True): (4, 8, 1, 8), + (8192, 8192, 512, 32, 32, True, False, True): (2, 4, 4, 2), + (8192, 8192, 512, 64, 64, False, True, True): (4, 4, 1, 4), + (8192, 8192, 512, 64, 64, True, False, True): (3, 2, 3, 4), + (8192, 8192, 512, 128, 128, False, True, True): (1, 4, 1, 8), + (8192, 8192, 512, 128, 128, True, False, True): (1, 4, 3, 4), + (8192, 8192, 512, 256, 256, False, True, True): (1, 2, 1, 32), + (8192, 8192, 512, 256, 256, True, False, True): (1, 2, 1, 32), + (8192, 8192, 1024, 32, 32, False, True, True): (4, 16, 1, 8), + (8192, 8192, 1024, 32, 32, True, False, True): (1, 8, 3, 2), + (8192, 8192, 1024, 64, 64, False, True, True): (4, 8, 1, 4), + (8192, 8192, 1024, 64, 64, True, False, True): (4, 4, 3, 4), + (8192, 8192, 1024, 128, 128, False, True, True): (1, 8, 1, 8), + (8192, 8192, 1024, 128, 128, True, False, True): (1, 8, 3, 4), + (8192, 8192, 1024, 256, 256, False, True, True): (1, 4, 1, 32), + (8192, 8192, 1024, 256, 256, True, False, True): (4, 4, 1, 32), + (8192, 8192, 2048, 32, 32, False, True, True): (4, 32, 1, 8), + (8192, 8192, 2048, 32, 32, True, False, True): (1, 16, 3, 2), + (8192, 8192, 2048, 64, 64, False, True, True): (4, 32, 1, 8), + (8192, 8192, 2048, 64, 64, True, False, True): (4, 8, 3, 4), + (8192, 8192, 2048, 128, 128, False, True, True): (4, 16, 1, 8), + (8192, 8192, 2048, 128, 128, True, False, True): (4, 16, 3, 4), + (8192, 8192, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (8192, 8192, 2048, 256, 256, True, False, True): (4, 8, 1, 32), + (8192, 8192, 4096, 32, 32, False, True, True): (4, 64, 1, 8), + (8192, 8192, 4096, 32, 32, True, False, True): (2, 32, 3, 2), + (8192, 8192, 4096, 64, 64, False, True, True): (4, 64, 1, 8), + (8192, 8192, 4096, 64, 64, True, False, True): (4, 16, 3, 4), + (8192, 8192, 4096, 128, 128, False, True, True): (4, 32, 1, 8), + (8192, 8192, 4096, 128, 128, True, False, True): (4, 32, 3, 4), + (8192, 8192, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (8192, 8192, 4096, 256, 256, True, False, True): (2, 16, 1, 32), + (8192, 8192, 8192, 32, 32, False, True, True): (4, 128, 1, 8), + (8192, 8192, 8192, 32, 32, True, False, True): (1, 64, 3, 2), + (8192, 8192, 8192, 64, 64, False, True, True): (4, 64, 1, 4), + (8192, 8192, 8192, 64, 64, True, False, True): (4, 32, 3, 4), + (8192, 8192, 8192, 128, 128, False, True, True): (4, 64, 1, 16), + (8192, 8192, 8192, 128, 128, True, False, True): (4, 64, 3, 4), + (8192, 8192, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (8192, 8192, 8192, 256, 256, True, False, True): (4, 32, 1, 32), + (8192, 8192, 16384, 32, 32, False, True, True): (4, 256, 1, 8), + (8192, 8192, 16384, 32, 32, True, False, True): (4, 128, 3, 2), + (8192, 8192, 16384, 64, 64, False, True, True): (4, 128, 1, 4), + (8192, 8192, 16384, 64, 64, True, False, True): (4, 64, 3, 4), + (8192, 8192, 16384, 128, 128, False, True, True): (4, 128, 1, 16), + (8192, 8192, 16384, 128, 128, True, False, True): (4, 128, 3, 4), + (8192, 8192, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (8192, 8192, 16384, 256, 256, True, False, True): (4, 64, 1, 32), + (8192, 8192, 32768, 32, 32, False, True, True): (4, 512, 1, 8), + (8192, 8192, 32768, 32, 32, True, False, True): (2, 256, 3, 2), + (8192, 8192, 32768, 64, 64, False, True, True): (4, 256, 1, 4), + (8192, 8192, 32768, 64, 64, True, False, True): (4, 128, 3, 4), + (8192, 8192, 32768, 128, 128, False, True, True): (4, 256, 1, 16), + (8192, 8192, 32768, 128, 128, True, False, True): (4, 256, 3, 4), + (8192, 8192, 32768, 256, 256, False, True, True): (1, 
128, 1, 32), + (8192, 8192, 32768, 256, 256, True, False, True): (4, 128, 1, 32), + (8192, 8192, 65536, 32, 32, False, True, True): (4, 1024, 1, 8), + (8192, 8192, 65536, 32, 32, True, False, True): (4, 512, 3, 2), + (8192, 8192, 65536, 64, 64, False, True, True): (4, 512, 1, 4), + (8192, 8192, 65536, 64, 64, True, False, True): (4, 256, 3, 4), + (8192, 8192, 65536, 128, 128, False, True, True): (4, 512, 1, 16), + (8192, 8192, 65536, 128, 128, True, False, True): (4, 512, 3, 4), + (8192, 8192, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (8192, 8192, 65536, 256, 256, True, False, True): (4, 256, 1, 32), + (8192, 8192, 131072, 32, 32, False, True, True): (4, 2048, 1, 8), + (8192, 8192, 131072, 32, 32, True, False, True): (4, 1024, 3, 2), + (8192, 8192, 131072, 64, 64, False, True, True): (4, 1024, 1, 4), + (8192, 8192, 131072, 64, 64, True, False, True): (4, 512, 3, 4), + (8192, 8192, 131072, 128, 128, False, True, True): (4, 1024, 1, 16), + (8192, 8192, 131072, 128, 128, True, False, True): (4, 1024, 3, 4), + (8192, 8192, 131072, 256, 256, False, True, True): (1, 512, 1, 32), + (8192, 8192, 131072, 256, 256, True, False, True): (4, 512, 1, 32), + (16384, 16384, 256, 32, 32, False, True, True): (4, 4, 1, 8), + (16384, 16384, 256, 32, 32, True, False, True): (2, 2, 4, 2), + (16384, 16384, 256, 64, 64, False, True, True): (2, 2, 1, 4), + (16384, 16384, 256, 64, 64, True, False, True): (5, 1, 3, 4), + (16384, 16384, 256, 128, 128, False, True, True): (6, 2, 1, 8), + (16384, 16384, 256, 128, 128, True, False, True): (6, 2, 3, 4), + (16384, 16384, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (16384, 16384, 256, 256, 256, True, False, True): (1, 1, 1, 32), + (16384, 16384, 512, 32, 32, False, True, True): (4, 8, 1, 8), + (16384, 16384, 512, 32, 32, True, False, True): (1, 4, 4, 2), + (16384, 16384, 512, 64, 64, False, True, True): (4, 4, 1, 4), + (16384, 16384, 512, 64, 64, True, False, True): (2, 2, 3, 4), + (16384, 16384, 512, 128, 128, False, True, True): (4, 4, 1, 8), + (16384, 16384, 512, 128, 128, True, False, True): (4, 4, 3, 4), + (16384, 16384, 512, 256, 256, False, True, True): (1, 2, 1, 32), + (16384, 16384, 512, 256, 256, True, False, True): (2, 2, 1, 32), + (16384, 16384, 1024, 32, 32, False, True, True): (4, 16, 1, 8), + (16384, 16384, 1024, 32, 32, True, False, True): (1, 8, 3, 2), + (16384, 16384, 1024, 64, 64, False, True, True): (4, 8, 1, 4), + (16384, 16384, 1024, 64, 64, True, False, True): (4, 4, 3, 4), + (16384, 16384, 1024, 128, 128, False, True, True): (4, 4, 1, 8), + (16384, 16384, 1024, 128, 128, True, False, True): (4, 8, 3, 4), + (16384, 16384, 1024, 256, 256, False, True, True): (1, 4, 1, 32), + (16384, 16384, 1024, 256, 256, True, False, True): (4, 4, 1, 32), + (16384, 16384, 2048, 32, 32, False, True, True): (4, 32, 1, 8), + (16384, 16384, 2048, 32, 32, True, False, True): (2, 16, 3, 2), + (16384, 16384, 2048, 64, 64, False, True, True): (4, 16, 1, 4), + (16384, 16384, 2048, 64, 64, True, False, True): (4, 8, 3, 4), + (16384, 16384, 2048, 128, 128, False, True, True): (4, 16, 1, 8), + (16384, 16384, 2048, 128, 128, True, False, True): (4, 16, 3, 4), + (16384, 16384, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (16384, 16384, 2048, 256, 256, True, False, True): (4, 8, 1, 32), + (16384, 16384, 4096, 32, 32, False, True, True): (4, 64, 1, 8), + (16384, 16384, 4096, 32, 32, True, False, True): (2, 32, 3, 2), + (16384, 16384, 4096, 64, 64, False, True, True): (2, 32, 1, 4), + (16384, 16384, 4096, 64, 64, True, False, True): (4, 16, 3, 4), + 
(16384, 16384, 4096, 128, 128, False, True, True): (4, 32, 1, 8), + (16384, 16384, 4096, 128, 128, True, False, True): (4, 32, 3, 4), + (16384, 16384, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (16384, 16384, 4096, 256, 256, True, False, True): (4, 16, 1, 32), + (16384, 16384, 8192, 32, 32, False, True, True): (4, 128, 1, 8), + (16384, 16384, 8192, 32, 32, True, False, True): (2, 64, 3, 2), + (16384, 16384, 8192, 64, 64, False, True, True): (4, 64, 1, 4), + (16384, 16384, 8192, 64, 64, True, False, True): (4, 32, 3, 4), + (16384, 16384, 8192, 128, 128, False, True, True): (4, 64, 1, 16), + (16384, 16384, 8192, 128, 128, True, False, True): (4, 64, 3, 4), + (16384, 16384, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (16384, 16384, 8192, 256, 256, True, False, True): (4, 32, 1, 32), + (16384, 16384, 16384, 32, 32, False, True, True): (4, 256, 1, 8), + (16384, 16384, 16384, 32, 32, True, False, True): (2, 128, 3, 2), + (16384, 16384, 16384, 64, 64, False, True, True): (4, 128, 1, 4), + (16384, 16384, 16384, 64, 64, True, False, True): (4, 64, 3, 4), + (16384, 16384, 16384, 128, 128, False, True, True): (1, 64, 1, 8), + (16384, 16384, 16384, 128, 128, True, False, True): (4, 128, 3, 4), + (16384, 16384, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (16384, 16384, 16384, 256, 256, True, False, True): (4, 64, 1, 32), + (16384, 16384, 32768, 32, 32, False, True, True): (4, 512, 1, 8), + (16384, 16384, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (16384, 16384, 32768, 64, 64, False, True, True): (4, 256, 1, 4), + (16384, 16384, 32768, 64, 64, True, False, True): (4, 128, 3, 4), + (16384, 16384, 32768, 128, 128, False, True, True): (4, 256, 1, 16), + (16384, 16384, 32768, 128, 128, True, False, True): (4, 256, 3, 4), + (16384, 16384, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (16384, 16384, 32768, 256, 256, True, False, True): (4, 128, 1, 32), + (16384, 16384, 65536, 32, 32, False, True, True): (4, 1024, 1, 8), + (16384, 16384, 65536, 32, 32, True, False, True): (1, 512, 3, 2), + (16384, 16384, 65536, 64, 64, False, True, True): (2, 512, 1, 4), + (16384, 16384, 65536, 64, 64, True, False, True): (4, 256, 3, 4), + (16384, 16384, 65536, 128, 128, False, True, True): (4, 512, 1, 16), + (16384, 16384, 65536, 128, 128, True, False, True): (4, 512, 3, 4), + (16384, 16384, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (16384, 16384, 65536, 256, 256, True, False, True): (4, 256, 1, 32), + (16384, 16384, 131072, 32, 32, False, True, True): (4, 1024, 1, 8), + (16384, 16384, 131072, 32, 32, True, False, True): (4, 512, 3, 4), + (16384, 16384, 131072, 64, 64, False, True, True): (4, 1024, 1, 4), + (16384, 16384, 131072, 64, 64, True, False, True): (4, 1024, 3, 2), + (16384, 16384, 131072, 128, 128, False, True, True): (2, 1024, 3, 8), + (16384, 16384, 131072, 128, 128, True, False, True): (4, 1024, 3, 4), + (16384, 16384, 131072, 256, 256, False, True, True): (4, 512, 1, 32), + (16384, 16384, 131072, 256, 256, True, False, True): (4, 512, 1, 32), + }, + ("_int_bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.int8, 0.56)): { + (192, 192, 256, 64, 64, False, True, True): (3, 4, 3, 32), + (192, 192, 256, 64, 64, True, False, True): (1, 4, 3, 4), + (192, 192, 512, 64, 64, False, True, True): (1, 8, 1, 16), + (192, 192, 512, 64, 64, True, False, True): (1, 8, 5, 4), + (192, 192, 1024, 64, 64, False, True, True): (4, 16, 1, 16), + (192, 192, 1024, 64, 64, True, False, True): (3, 16, 3, 4), + (192, 192, 2048, 64, 64, False, True, True): (5, 32, 1, 8), + 
(192, 192, 2048, 64, 64, True, False, True): (2, 32, 4, 4), + (192, 192, 4096, 64, 64, False, True, True): (4, 64, 1, 16), + (192, 192, 4096, 64, 64, True, False, True): (1, 32, 4, 4), + (192, 192, 8192, 64, 64, False, True, True): (2, 128, 1, 8), + (192, 192, 8192, 64, 64, True, False, True): (3, 64, 1, 4), + (192, 192, 16384, 64, 64, False, True, True): (2, 256, 1, 8), + (192, 192, 16384, 64, 64, True, False, True): (1, 128, 3, 2), + (192, 192, 32768, 64, 64, False, True, True): (2, 512, 1, 8), + (192, 192, 32768, 64, 64, True, False, True): (3, 128, 1, 4), + (192, 192, 65536, 64, 64, False, True, True): (3, 1024, 1, 8), + (192, 192, 65536, 64, 64, True, False, True): (1, 512, 3, 4), + (192, 192, 131072, 64, 64, False, True, True): (1, 2048, 1, 8), + (192, 192, 131072, 64, 64, True, False, True): (1, 512, 1, 4), + (384, 384, 256, 128, 128, False, True, True): (4, 2, 1, 16), + (384, 384, 256, 128, 128, True, False, True): (1, 2, 3, 4), + (384, 384, 512, 128, 128, False, True, True): (2, 4, 1, 16), + (384, 384, 512, 128, 128, True, False, True): (2, 4, 3, 4), + (384, 384, 1024, 128, 128, False, True, True): (3, 8, 1, 32), + (384, 384, 1024, 128, 128, True, False, True): (3, 8, 3, 4), + (384, 384, 2048, 128, 128, False, True, True): (3, 16, 1, 32), + (384, 384, 2048, 128, 128, True, False, True): (2, 16, 3, 4), + (384, 384, 4096, 128, 128, False, True, True): (3, 32, 1, 32), + (384, 384, 4096, 128, 128, True, False, True): (3, 32, 3, 4), + (384, 384, 8192, 128, 128, False, True, True): (2, 64, 1, 32), + (384, 384, 8192, 128, 128, True, False, True): (4, 64, 1, 4), + (384, 384, 16384, 128, 128, False, True, True): (2, 128, 1, 32), + (384, 384, 16384, 128, 128, True, False, True): (2, 128, 1, 4), + (384, 384, 32768, 128, 128, False, True, True): (3, 256, 1, 16), + (384, 384, 32768, 128, 128, True, False, True): (1, 256, 1, 4), + (384, 384, 65536, 128, 128, False, True, True): (4, 512, 1, 16), + (384, 384, 65536, 128, 128, True, False, True): (1, 512, 1, 4), + (384, 384, 131072, 128, 128, False, True, True): (4, 1024, 1, 16), + (384, 384, 131072, 128, 128, True, False, True): (1, 1024, 1, 4), + (768, 768, 256, 256, 256, False, True, True): (1, 1, 1, 32), + (768, 768, 256, 256, 256, True, False, True): (3, 1, 1, 32), + (768, 768, 512, 256, 256, False, True, True): (1, 2, 1, 32), + (768, 768, 512, 256, 256, True, False, True): (1, 2, 1, 32), + (768, 768, 1024, 256, 256, False, True, True): (1, 4, 1, 32), + (768, 768, 1024, 256, 256, True, False, True): (2, 4, 1, 32), + (768, 768, 2048, 256, 256, False, True, True): (1, 8, 1, 32), + (768, 768, 2048, 256, 256, True, False, True): (2, 8, 1, 32), + (768, 768, 4096, 256, 256, False, True, True): (1, 16, 1, 32), + (768, 768, 4096, 256, 256, True, False, True): (1, 16, 1, 32), + (768, 768, 8192, 256, 256, False, True, True): (1, 32, 1, 32), + (768, 768, 8192, 256, 256, True, False, True): (2, 32, 1, 32), + (768, 768, 16384, 256, 256, False, True, True): (1, 64, 1, 32), + (768, 768, 16384, 256, 256, True, False, True): (7, 64, 1, 32), + (768, 768, 32768, 256, 256, False, True, True): (1, 128, 1, 32), + (768, 768, 32768, 256, 256, True, False, True): (1, 128, 1, 32), + (768, 768, 65536, 256, 256, False, True, True): (1, 256, 1, 32), + (768, 768, 65536, 256, 256, True, False, True): (1, 256, 1, 32), + (768, 768, 131072, 256, 256, False, True, True): (1, 512, 1, 32), + (768, 768, 131072, 256, 256, True, False, True): (1, 512, 1, 32), + }, + ("_int_bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.int8, 1.0)): { + (256, 256, 256, 256, 256, False, True, 
True): (2, 1, 1, 4), + (256, 256, 256, 256, 256, True, False, True): (2, 1, 2, 1), + (256, 256, 512, 256, 256, False, True, True): (2, 1, 1, 2), + (256, 256, 512, 256, 256, True, False, True): (2, 2, 2, 8), + (256, 256, 1024, 256, 256, False, True, True): (1, 4, 1, 4), + (256, 256, 1024, 256, 256, True, False, True): (1, 2, 2, 4), + (256, 256, 2048, 256, 256, False, True, True): (1, 4, 1, 2), + (256, 256, 2048, 256, 256, True, False, True): (1, 8, 1, 2), + (256, 256, 4096, 256, 256, False, True, True): (1, 16, 1, 4), + (256, 256, 4096, 256, 256, True, False, True): (1, 16, 1, 2), + (256, 256, 8192, 256, 256, False, True, True): (1, 16, 3, 4), + (256, 256, 8192, 256, 256, True, False, True): (1, 8, 1, 4), + (256, 256, 16384, 256, 256, False, True, True): (2, 16, 1, 8), + (256, 256, 16384, 256, 256, True, False, True): (1, 32, 1, 2), + (256, 256, 32768, 256, 256, False, True, True): (1, 128, 1, 8), + (256, 256, 32768, 256, 256, True, False, True): (1, 128, 1, 4), + (256, 256, 65536, 256, 256, False, True, True): (1, 4, 1, 1), + (256, 256, 65536, 256, 256, True, False, True): (1, 128, 1, 4), + (256, 256, 131072, 256, 256, False, True, True): (1, 512, 1, 4), + (256, 256, 131072, 256, 256, True, False, True): (1, 512, 1, 2), + }, + ("bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.bfloat16, 0.5)): { + (16, 16, 16, 16, 16, False, False, False): (2, 1, 1, 2), + (16, 16, 16, 16, 16, False, False, True): (1, 1, 1, 4), + (16, 16, 16, 16, 16, False, True, False): (1, 1, 3, 16), + (16, 16, 16, 16, 16, False, True, True): (1, 1, 1, 8), + (16, 16, 16, 16, 16, True, False, False): (2, 1, 1, 8), + (16, 16, 16, 16, 16, True, False, True): (1, 1, 1, 8), + (16, 16, 32, 16, 16, False, False, False): (1, 2, 1, 8), + (16, 16, 32, 16, 16, False, False, True): (1, 2, 2, 4), + (16, 16, 32, 16, 16, False, True, False): (1, 1, 2, 4), + (16, 16, 32, 16, 16, False, True, True): (1, 1, 2, 4), + (16, 16, 32, 16, 16, True, False, False): (1, 1, 2, 4), + (16, 16, 32, 16, 16, True, False, True): (2, 2, 1, 2), + (16, 16, 64, 16, 16, False, False, False): (1, 4, 2, 4), + (16, 16, 64, 16, 16, False, False, True): (1, 2, 1, 2), + (16, 16, 64, 16, 16, False, True, False): (2, 1, 1, 2), + (16, 16, 64, 16, 16, False, True, True): (1, 4, 1, 8), + (16, 16, 64, 16, 16, True, False, False): (1, 4, 1, 1), + (16, 16, 64, 16, 16, True, False, True): (1, 4, 2, 4), + (16, 32, 16, 16, 16, False, False, False): (1, 1, 2, 2), + (16, 32, 16, 16, 16, False, False, True): (1, 1, 1, 4), + (16, 32, 16, 16, 16, False, True, False): (1, 1, 1, 2), + (16, 32, 16, 16, 16, False, True, True): (1, 1, 1, 1), + (16, 32, 16, 16, 16, True, False, False): (1, 1, 1, 2), + (16, 32, 16, 16, 16, True, False, True): (2, 1, 1, 2), + (16, 32, 16, 16, 32, False, False, False): (1, 1, 1, 4), + (16, 32, 16, 16, 32, False, False, True): (1, 1, 1, 8), + (16, 32, 16, 16, 32, False, True, False): (1, 1, 1, 8), + (16, 32, 16, 16, 32, False, True, True): (1, 1, 2, 4), + (16, 32, 16, 16, 32, True, False, False): (1, 1, 1, 2), + (16, 32, 16, 16, 32, True, False, True): (1, 1, 1, 1), + (16, 32, 32, 16, 16, False, False, False): (2, 2, 1, 4), + (16, 32, 32, 16, 16, False, False, True): (2, 2, 1, 2), + (16, 32, 32, 16, 16, False, True, False): (1, 1, 2, 8), + (16, 32, 32, 16, 16, False, True, True): (1, 2, 1, 1), + (16, 32, 32, 16, 16, True, False, False): (1, 1, 1, 8), + (16, 32, 32, 16, 16, True, False, True): (1, 2, 1, 4), + (16, 32, 32, 16, 32, False, False, False): (1, 1, 2, 8), + (16, 32, 32, 16, 32, False, False, True): (2, 1, 1, 8), + (16, 32, 32, 16, 32, False, 
True, False): (1, 1, 1, 4), + (16, 32, 32, 16, 32, False, True, True): (1, 1, 1, 4), + (16, 32, 32, 16, 32, True, False, False): (1, 2, 1, 8), + (16, 32, 32, 16, 32, True, False, True): (1, 1, 1, 4), + (16, 32, 64, 16, 16, False, False, False): (1, 4, 3, 8), + (16, 32, 64, 16, 16, False, False, True): (1, 4, 1, 4), + (16, 32, 64, 16, 16, False, True, False): (1, 4, 1, 4), + (16, 32, 64, 16, 16, False, True, True): (2, 4, 1, 4), + (16, 32, 64, 16, 16, True, False, False): (1, 2, 1, 4), + (16, 32, 64, 16, 16, True, False, True): (1, 2, 1, 4), + (16, 32, 64, 16, 32, False, False, False): (1, 4, 1, 8), + (16, 32, 64, 16, 32, False, False, True): (1, 4, 1, 4), + (16, 32, 64, 16, 32, False, True, False): (1, 4, 1, 2), + (16, 32, 64, 16, 32, False, True, True): (1, 2, 1, 4), + (16, 32, 64, 16, 32, True, False, False): (1, 2, 1, 4), + (16, 32, 64, 16, 32, True, False, True): (1, 2, 1, 2), + (16, 64, 16, 16, 32, False, False, False): (1, 1, 1, 2), + (16, 64, 16, 16, 32, False, False, True): (1, 1, 2, 2), + (16, 64, 16, 16, 32, False, True, False): (1, 1, 2, 8), + (16, 64, 16, 16, 32, False, True, True): (1, 1, 1, 4), + (16, 64, 16, 16, 32, True, False, False): (1, 1, 1, 8), + (16, 64, 16, 16, 32, True, False, True): (1, 1, 1, 4), + (16, 64, 32, 16, 32, False, False, False): (1, 2, 1, 2), + (16, 64, 32, 16, 32, False, False, True): (1, 2, 1, 4), + (16, 64, 32, 16, 32, False, True, False): (1, 2, 1, 4), + (16, 64, 32, 16, 32, False, True, True): (2, 2, 1, 4), + (16, 64, 32, 16, 32, True, False, False): (1, 2, 1, 4), + (16, 64, 32, 16, 32, True, False, True): (1, 2, 1, 8), + (16, 64, 64, 16, 32, False, False, False): (1, 2, 1, 4), + (16, 64, 64, 16, 32, False, False, True): (1, 4, 2, 2), + (16, 64, 64, 16, 32, False, True, False): (1, 1, 1, 4), + (16, 64, 64, 16, 32, False, True, True): (1, 4, 1, 2), + (16, 64, 64, 16, 32, True, False, False): (1, 2, 1, 4), + (16, 64, 64, 16, 32, True, False, True): (1, 4, 1, 4), + (32, 16, 16, 16, 16, False, False, False): (1, 1, 1, 8), + (32, 16, 16, 16, 16, False, False, True): (1, 1, 2, 4), + (32, 16, 16, 16, 16, False, True, False): (1, 1, 1, 4), + (32, 16, 16, 16, 16, False, True, True): (1, 1, 2, 4), + (32, 16, 16, 16, 16, True, False, False): (1, 1, 1, 2), + (32, 16, 16, 16, 16, True, False, True): (1, 1, 1, 4), + (32, 16, 32, 16, 16, False, False, False): (1, 1, 1, 4), + (32, 16, 32, 16, 16, False, False, True): (2, 2, 1, 4), + (32, 16, 32, 16, 16, False, True, False): (1, 2, 2, 2), + (32, 16, 32, 16, 16, False, True, True): (2, 2, 1, 4), + (32, 16, 32, 16, 16, True, False, False): (1, 2, 2, 8), + (32, 16, 32, 16, 16, True, False, True): (1, 2, 1, 2), + (32, 16, 64, 16, 16, False, False, False): (1, 4, 1, 4), + (32, 16, 64, 16, 16, False, False, True): (1, 4, 2, 4), + (32, 16, 64, 16, 16, False, True, False): (1, 2, 2, 2), + (32, 16, 64, 16, 16, False, True, True): (3, 4, 1, 4), + (32, 16, 64, 16, 16, True, False, False): (1, 2, 1, 2), + (32, 16, 64, 16, 16, True, False, True): (1, 2, 1, 4), + (32, 32, 16, 16, 16, False, False, False): (1, 1, 3, 4), + (32, 32, 16, 16, 16, False, False, True): (1, 1, 1, 4), + (32, 32, 16, 16, 16, False, True, False): (1, 1, 1, 2), + (32, 32, 16, 16, 16, False, True, True): (1, 1, 1, 4), + (32, 32, 16, 16, 16, True, False, False): (1, 1, 1, 4), + (32, 32, 16, 16, 16, True, False, True): (1, 1, 2, 2), + (32, 32, 16, 16, 32, False, False, False): (2, 1, 1, 4), + (32, 32, 16, 16, 32, False, False, True): (1, 1, 1, 4), + (32, 32, 16, 16, 32, False, True, False): (1, 1, 1, 4), + (32, 32, 16, 16, 32, False, True, True): (3, 1, 2, 4), 
+ (32, 32, 16, 16, 32, True, False, False): (1, 1, 1, 4), + (32, 32, 16, 16, 32, True, False, True): (1, 1, 1, 4), + (32, 32, 16, 32, 32, False, False, False): (1, 1, 1, 8), + (32, 32, 16, 32, 32, False, False, True): (1, 1, 1, 4), + (32, 32, 16, 32, 32, False, True, False): (1, 1, 2, 1), + (32, 32, 16, 32, 32, False, True, True): (2, 1, 2, 2), + (32, 32, 16, 32, 32, True, False, False): (1, 1, 1, 8), + (32, 32, 16, 32, 32, True, False, True): (2, 1, 3, 4), + (32, 32, 32, 16, 16, False, False, False): (1, 2, 1, 4), + (32, 32, 32, 16, 16, False, False, True): (2, 2, 1, 4), + (32, 32, 32, 16, 16, False, True, False): (1, 1, 1, 8), + (32, 32, 32, 16, 16, False, True, True): (2, 2, 1, 4), + (32, 32, 32, 16, 16, True, False, False): (1, 1, 1, 4), + (32, 32, 32, 16, 16, True, False, True): (2, 2, 2, 4), + (32, 32, 32, 16, 32, False, False, False): (2, 2, 1, 8), + (32, 32, 32, 16, 32, False, False, True): (1, 2, 1, 2), + (32, 32, 32, 16, 32, False, True, False): (1, 2, 1, 4), + (32, 32, 32, 16, 32, False, True, True): (1, 2, 1, 4), + (32, 32, 32, 16, 32, True, False, False): (1, 2, 1, 4), + (32, 32, 32, 16, 32, True, False, True): (1, 2, 1, 2), + (32, 32, 32, 32, 32, False, False, False): (1, 1, 3, 8), + (32, 32, 32, 32, 32, False, False, True): (1, 1, 1, 8), + (32, 32, 32, 32, 32, False, True, False): (2, 1, 3, 4), + (32, 32, 32, 32, 32, False, True, True): (2, 1, 1, 2), + (32, 32, 32, 32, 32, True, False, False): (1, 1, 1, 2), + (32, 32, 32, 32, 32, True, False, True): (4, 1, 1, 1), + (32, 32, 64, 16, 16, False, False, False): (1, 4, 1, 4), + (32, 32, 64, 16, 16, False, False, True): (1, 4, 1, 4), + (32, 32, 64, 16, 16, False, True, False): (1, 2, 1, 8), + (32, 32, 64, 16, 16, False, True, True): (1, 4, 1, 2), + (32, 32, 64, 16, 16, True, False, False): (2, 4, 1, 2), + (32, 32, 64, 16, 16, True, False, True): (1, 4, 1, 2), + (32, 32, 64, 16, 32, False, False, False): (1, 2, 1, 8), + (32, 32, 64, 16, 32, False, False, True): (1, 4, 2, 2), + (32, 32, 64, 16, 32, False, True, False): (1, 2, 1, 4), + (32, 32, 64, 16, 32, False, True, True): (1, 4, 1, 4), + (32, 32, 64, 16, 32, True, False, False): (1, 4, 2, 2), + (32, 32, 64, 16, 32, True, False, True): (3, 4, 2, 2), + (32, 32, 64, 32, 32, False, False, False): (2, 2, 1, 4), + (32, 32, 64, 32, 32, False, False, True): (1, 2, 1, 4), + (32, 32, 64, 32, 32, False, True, False): (1, 1, 1, 8), + (32, 32, 64, 32, 32, False, True, True): (1, 1, 1, 4), + (32, 32, 64, 32, 32, True, False, False): (1, 2, 1, 2), + (32, 32, 64, 32, 32, True, False, True): (3, 2, 1, 8), + (32, 64, 16, 16, 32, False, False, False): (1, 1, 2, 2), + (32, 64, 16, 16, 32, False, False, True): (1, 1, 1, 4), + (32, 64, 16, 16, 32, False, True, False): (1, 1, 2, 4), + (32, 64, 16, 16, 32, False, True, True): (1, 1, 1, 4), + (32, 64, 16, 16, 32, True, False, False): (1, 1, 1, 2), + (32, 64, 16, 16, 32, True, False, True): (2, 1, 2, 2), + (32, 64, 16, 32, 32, False, False, False): (1, 1, 1, 1), + (32, 64, 16, 32, 32, False, False, True): (2, 1, 1, 4), + (32, 64, 16, 32, 32, False, True, False): (1, 1, 1, 1), + (32, 64, 16, 32, 32, False, True, True): (1, 1, 2, 2), + (32, 64, 16, 32, 32, True, False, False): (1, 1, 2, 4), + (32, 64, 16, 32, 32, True, False, True): (1, 1, 1, 4), + (32, 64, 32, 16, 32, False, False, False): (2, 2, 1, 4), + (32, 64, 32, 16, 32, False, False, True): (1, 2, 1, 4), + (32, 64, 32, 16, 32, False, True, False): (1, 1, 1, 4), + (32, 64, 32, 16, 32, False, True, True): (2, 2, 3, 4), + (32, 64, 32, 16, 32, True, False, False): (1, 1, 1, 2), + (32, 64, 32, 16, 32, 
True, False, True): (1, 2, 1, 2), + (32, 64, 32, 32, 32, False, False, False): (1, 1, 1, 2), + (32, 64, 32, 32, 32, False, False, True): (2, 1, 1, 4), + (32, 64, 32, 32, 32, False, True, False): (1, 1, 1, 8), + (32, 64, 32, 32, 32, False, True, True): (1, 1, 2, 4), + (32, 64, 32, 32, 32, True, False, False): (2, 1, 1, 4), + (32, 64, 32, 32, 32, True, False, True): (1, 1, 2, 4), + (32, 64, 64, 16, 32, False, False, False): (1, 4, 1, 4), + (32, 64, 64, 16, 32, False, False, True): (1, 4, 2, 4), + (32, 64, 64, 16, 32, False, True, False): (1, 4, 2, 2), + (32, 64, 64, 16, 32, False, True, True): (1, 4, 1, 4), + (32, 64, 64, 16, 32, True, False, False): (1, 4, 1, 8), + (32, 64, 64, 16, 32, True, False, True): (1, 4, 2, 1), + (32, 64, 64, 32, 32, False, False, False): (1, 1, 1, 4), + (32, 64, 64, 32, 32, False, False, True): (2, 2, 1, 4), + (32, 64, 64, 32, 32, False, True, False): (1, 1, 1, 4), + (32, 64, 64, 32, 32, False, True, True): (2, 2, 1, 4), + (32, 64, 64, 32, 32, True, False, False): (1, 2, 2, 4), + (32, 64, 64, 32, 32, True, False, True): (2, 2, 3, 4), + (64, 32, 16, 32, 32, False, False, False): (1, 1, 1, 4), + (64, 32, 16, 32, 32, False, False, True): (1, 1, 1, 4), + (64, 32, 16, 32, 32, False, True, False): (1, 1, 1, 8), + (64, 32, 16, 32, 32, False, True, True): (1, 1, 1, 4), + (64, 32, 16, 32, 32, True, False, False): (1, 1, 1, 16), + (64, 32, 16, 32, 32, True, False, True): (2, 1, 1, 4), + (64, 32, 32, 32, 32, False, False, False): (1, 1, 3, 4), + (64, 32, 32, 32, 32, False, False, True): (2, 1, 1, 4), + (64, 32, 32, 32, 32, False, True, False): (1, 1, 2, 4), + (64, 32, 32, 32, 32, False, True, True): (2, 1, 1, 4), + (64, 32, 32, 32, 32, True, False, False): (2, 1, 1, 16), + (64, 32, 32, 32, 32, True, False, True): (2, 1, 1, 4), + (64, 32, 64, 32, 32, False, False, False): (1, 2, 1, 4), + (64, 32, 64, 32, 32, False, False, True): (2, 2, 1, 4), + (64, 32, 64, 32, 32, False, True, False): (1, 1, 1, 4), + (64, 32, 64, 32, 32, False, True, True): (2, 2, 1, 4), + (64, 32, 64, 32, 32, True, False, False): (1, 2, 1, 8), + (64, 32, 64, 32, 32, True, False, True): (2, 2, 3, 4), + (64, 64, 16, 32, 32, False, False, False): (1, 1, 2, 16), + (64, 64, 16, 32, 32, False, False, True): (1, 1, 3, 4), + (64, 64, 16, 32, 32, False, True, False): (1, 1, 1, 2), + (64, 64, 16, 32, 32, False, True, True): (2, 1, 1, 4), + (64, 64, 16, 32, 32, True, False, False): (2, 1, 3, 2), + (64, 64, 16, 32, 32, True, False, True): (1, 1, 2, 4), + (64, 64, 32, 32, 32, False, False, False): (1, 1, 1, 8), + (64, 64, 32, 32, 32, False, False, True): (2, 1, 2, 4), + (64, 64, 32, 32, 32, False, True, False): (2, 1, 1, 4), + (64, 64, 32, 32, 32, False, True, True): (1, 1, 2, 4), + (64, 64, 32, 32, 32, True, False, False): (2, 1, 1, 4), + (64, 64, 32, 32, 32, True, False, True): (1, 1, 2, 4), + (64, 64, 64, 32, 32, False, False, False): (1, 2, 2, 4), + (64, 64, 64, 32, 32, False, False, True): (1, 2, 2, 2), + (64, 64, 64, 32, 32, False, True, False): (1, 2, 1, 2), + (64, 64, 64, 32, 32, False, True, True): (1, 2, 1, 4), + (64, 64, 64, 32, 32, True, False, False): (1, 2, 1, 4), + (64, 64, 64, 32, 32, True, False, True): (1, 2, 1, 4), + (192, 192, 256, 16, 16, False, True, True): (1, 8, 5, 4), + (192, 192, 256, 16, 16, True, False, True): (2, 8, 5, 2), + (192, 192, 256, 32, 32, False, True, True): (1, 8, 6, 4), + (192, 192, 256, 32, 32, True, False, True): (3, 8, 5, 2), + (192, 192, 512, 16, 16, False, True, True): (1, 16, 5, 2), + (192, 192, 512, 16, 16, True, False, True): (1, 8, 4, 2), + (192, 192, 512, 32, 32, False, 
True, True): (2, 16, 5, 4), + (192, 192, 512, 32, 32, True, False, True): (2, 8, 5, 2), + (192, 192, 1024, 16, 16, False, True, True): (1, 16, 3, 4), + (192, 192, 1024, 16, 16, True, False, True): (1, 16, 6, 2), + (192, 192, 1024, 32, 32, False, True, True): (1, 32, 3, 4), + (192, 192, 1024, 32, 32, True, False, True): (1, 16, 4, 2), + (192, 192, 2048, 16, 16, False, True, True): (1, 32, 1, 4), + (192, 192, 2048, 16, 16, True, False, True): (4, 32, 4, 2), + (192, 192, 2048, 32, 32, False, True, True): (1, 16, 3, 8), + (192, 192, 2048, 32, 32, True, False, True): (2, 32, 4, 2), + (192, 192, 4096, 16, 16, False, True, True): (2, 64, 1, 4), + (192, 192, 4096, 16, 16, True, False, True): (1, 32, 3, 2), + (192, 192, 4096, 32, 32, False, True, True): (1, 64, 1, 8), + (192, 192, 4096, 32, 32, True, False, True): (2, 32, 4, 4), + (192, 192, 8192, 16, 16, False, True, True): (1, 64, 1, 4), + (192, 192, 8192, 16, 16, True, False, True): (2, 32, 3, 1), + (192, 192, 8192, 32, 32, False, True, True): (3, 128, 1, 4), + (192, 192, 8192, 32, 32, True, False, True): (1, 64, 3, 4), + (192, 192, 16384, 16, 16, False, True, True): (1, 128, 1, 4), + (192, 192, 16384, 16, 16, True, False, True): (4, 64, 3, 1), + (192, 192, 16384, 32, 32, False, True, True): (1, 128, 1, 4), + (192, 192, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (192, 192, 32768, 16, 16, False, True, True): (2, 256, 1, 2), + (192, 192, 32768, 16, 16, True, False, True): (1, 128, 3, 4), + (192, 192, 32768, 32, 32, False, True, True): (2, 256, 1, 4), + (192, 192, 32768, 32, 32, True, False, True): (4, 128, 3, 4), + (192, 192, 65536, 16, 16, False, True, True): (2, 512, 1, 2), + (192, 192, 65536, 16, 16, True, False, True): (2, 256, 3, 2), + (192, 192, 65536, 32, 32, False, True, True): (2, 512, 1, 4), + (192, 192, 65536, 32, 32, True, False, True): (1, 256, 3, 4), + (192, 192, 131072, 16, 16, False, True, True): (4, 1024, 1, 2), + (192, 192, 131072, 16, 16, True, False, True): (3, 512, 3, 2), + (192, 192, 131072, 32, 32, False, True, True): (1, 1024, 1, 2), + (192, 192, 131072, 32, 32, True, False, True): (1, 512, 3, 4), + (256, 256, 256, 16, 16, False, True, True): (4, 8, 5, 1), + (256, 256, 256, 16, 16, True, False, True): (2, 8, 4, 2), + (256, 256, 256, 32, 32, False, True, True): (2, 8, 5, 2), + (256, 256, 256, 32, 32, True, False, True): (1, 8, 5, 4), + (256, 256, 256, 64, 64, False, True, True): (2, 4, 4, 4), + (256, 256, 256, 64, 64, True, False, True): (1, 4, 3, 4), + (256, 256, 256, 128, 128, False, True, True): (4, 2, 2, 8), + (256, 256, 256, 128, 128, True, False, True): (1, 2, 2, 8), + (256, 256, 512, 16, 16, False, True, True): (1, 16, 5, 1), + (256, 256, 512, 16, 16, True, False, True): (3, 16, 3, 2), + (256, 256, 512, 32, 32, False, True, True): (2, 8, 5, 2), + (256, 256, 512, 32, 32, True, False, True): (1, 16, 4, 4), + (256, 256, 512, 64, 64, False, True, True): (1, 8, 4, 4), + (256, 256, 512, 64, 64, True, False, True): (3, 8, 3, 4), + (256, 256, 512, 128, 128, False, True, True): (1, 4, 2, 8), + (256, 256, 512, 128, 128, True, False, True): (1, 4, 2, 8), + (256, 256, 1024, 16, 16, False, True, True): (1, 16, 5, 4), + (256, 256, 1024, 16, 16, True, False, True): (5, 16, 4, 2), + (256, 256, 1024, 32, 32, False, True, True): (1, 32, 5, 2), + (256, 256, 1024, 32, 32, True, False, True): (2, 16, 5, 2), + (256, 256, 1024, 64, 64, False, True, True): (1, 16, 4, 4), + (256, 256, 1024, 64, 64, True, False, True): (1, 16, 4, 4), + (256, 256, 1024, 128, 128, False, True, True): (1, 8, 2, 8), + (256, 256, 1024, 128, 128, True, 
False, True): (1, 8, 2, 8), + (256, 256, 2048, 16, 16, False, True, True): (1, 16, 4, 4), + (256, 256, 2048, 16, 16, True, False, True): (2, 32, 5, 1), + (256, 256, 2048, 32, 32, False, True, True): (1, 64, 4, 1), + (256, 256, 2048, 32, 32, True, False, True): (2, 32, 4, 2), + (256, 256, 2048, 64, 64, False, True, True): (8, 16, 5, 4), + (256, 256, 2048, 64, 64, True, False, True): (1, 16, 4, 4), + (256, 256, 2048, 128, 128, False, True, True): (2, 16, 2, 8), + (256, 256, 2048, 128, 128, True, False, True): (1, 16, 2, 8), + (256, 256, 4096, 16, 16, False, True, True): (1, 64, 1, 4), + (256, 256, 4096, 16, 16, True, False, True): (1, 16, 3, 2), + (256, 256, 4096, 32, 32, False, True, True): (6, 32, 3, 2), + (256, 256, 4096, 32, 32, True, False, True): (4, 32, 4, 2), + (256, 256, 4096, 64, 64, False, True, True): (6, 64, 3, 4), + (256, 256, 4096, 64, 64, True, False, True): (2, 64, 3, 4), + (256, 256, 4096, 128, 128, False, True, True): (1, 32, 2, 8), + (256, 256, 4096, 128, 128, True, False, True): (1, 32, 2, 8), + (256, 256, 8192, 16, 16, False, True, True): (2, 32, 3, 4), + (256, 256, 8192, 16, 16, True, False, True): (4, 64, 3, 2), + (256, 256, 8192, 32, 32, False, True, True): (1, 64, 3, 4), + (256, 256, 8192, 32, 32, True, False, True): (3, 128, 1, 2), + (256, 256, 8192, 64, 64, False, True, True): (9, 128, 1, 4), + (256, 256, 8192, 64, 64, True, False, True): (8, 128, 1, 4), + (256, 256, 8192, 128, 128, False, True, True): (7, 64, 1, 4), + (256, 256, 8192, 128, 128, True, False, True): (1, 32, 1, 16), + (256, 256, 16384, 16, 16, False, True, True): (3, 128, 3, 2), + (256, 256, 16384, 16, 16, True, False, True): (5, 64, 3, 2), + (256, 256, 16384, 32, 32, False, True, True): (3, 128, 3, 2), + (256, 256, 16384, 32, 32, True, False, True): (1, 128, 3, 2), + (256, 256, 16384, 64, 64, False, True, True): (3, 128, 1, 4), + (256, 256, 16384, 64, 64, True, False, True): (2, 128, 1, 4), + (256, 256, 16384, 128, 128, False, True, True): (7, 128, 1, 4), + (256, 256, 16384, 128, 128, True, False, True): (1, 128, 2, 8), + (256, 256, 32768, 16, 16, False, True, True): (2, 128, 3, 2), + (256, 256, 32768, 16, 16, True, False, True): (1, 128, 3, 2), + (256, 256, 32768, 32, 32, False, True, True): (1, 256, 3, 4), + (256, 256, 32768, 32, 32, True, False, True): (3, 256, 3, 2), + (256, 256, 32768, 64, 64, False, True, True): (1, 256, 1, 4), + (256, 256, 32768, 64, 64, True, False, True): (3, 256, 1, 4), + (256, 256, 32768, 128, 128, False, True, True): (9, 256, 1, 4), + (256, 256, 32768, 128, 128, True, False, True): (2, 256, 1, 4), + (256, 256, 65536, 16, 16, False, True, True): (1, 256, 3, 2), + (256, 256, 65536, 16, 16, True, False, True): (1, 256, 3, 2), + (256, 256, 65536, 32, 32, False, True, True): (2, 512, 3, 2), + (256, 256, 65536, 32, 32, True, False, True): (2, 512, 3, 2), + (256, 256, 65536, 64, 64, False, True, True): (2, 512, 1, 4), + (256, 256, 65536, 64, 64, True, False, True): (1, 512, 1, 4), + (256, 256, 65536, 128, 128, False, True, True): (7, 512, 1, 4), + (256, 256, 65536, 128, 128, True, False, True): (2, 512, 1, 4), + (256, 256, 131072, 16, 16, False, True, True): (1, 512, 3, 2), + (256, 256, 131072, 16, 16, True, False, True): (1, 512, 3, 2), + (256, 256, 131072, 32, 32, False, True, True): (1, 1024, 3, 2), + (256, 256, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (256, 256, 131072, 64, 64, False, True, True): (1, 1024, 1, 4), + (256, 256, 131072, 64, 64, True, False, True): (1, 1024, 1, 4), + (256, 256, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (256, 256, 
131072, 128, 128, True, False, True): (1, 1024, 1, 4), + (384, 384, 256, 16, 16, False, True, True): (1, 8, 5, 2), + (384, 384, 256, 16, 16, True, False, True): (3, 4, 5, 2), + (384, 384, 256, 32, 32, False, True, True): (2, 8, 4, 4), + (384, 384, 256, 32, 32, True, False, True): (1, 4, 6, 2), + (384, 384, 256, 64, 64, False, True, True): (2, 4, 4, 4), + (384, 384, 256, 64, 64, True, False, True): (2, 4, 4, 4), + (384, 384, 512, 16, 16, False, True, True): (1, 8, 4, 2), + (384, 384, 512, 16, 16, True, False, True): (1, 4, 5, 4), + (384, 384, 512, 32, 32, False, True, True): (1, 8, 4, 4), + (384, 384, 512, 32, 32, True, False, True): (3, 8, 5, 2), + (384, 384, 512, 64, 64, False, True, True): (3, 8, 3, 4), + (384, 384, 512, 64, 64, True, False, True): (5, 8, 5, 4), + (384, 384, 1024, 16, 16, False, True, True): (3, 16, 4, 2), + (384, 384, 1024, 16, 16, True, False, True): (1, 8, 4, 4), + (384, 384, 1024, 32, 32, False, True, True): (6, 32, 3, 2), + (384, 384, 1024, 32, 32, True, False, True): (3, 8, 4, 4), + (384, 384, 1024, 64, 64, False, True, True): (3, 16, 3, 4), + (384, 384, 1024, 64, 64, True, False, True): (2, 16, 4, 4), + (384, 384, 2048, 16, 16, False, True, True): (1, 32, 1, 4), + (384, 384, 2048, 16, 16, True, False, True): (1, 16, 5, 2), + (384, 384, 2048, 32, 32, False, True, True): (1, 32, 1, 8), + (384, 384, 2048, 32, 32, True, False, True): (1, 8, 4, 4), + (384, 384, 2048, 64, 64, False, True, True): (4, 16, 3, 4), + (384, 384, 2048, 64, 64, True, False, True): (1, 16, 3, 8), + (384, 384, 4096, 16, 16, False, True, True): (5, 32, 1, 4), + (384, 384, 4096, 16, 16, True, False, True): (6, 32, 3, 2), + (384, 384, 4096, 32, 32, False, True, True): (1, 32, 1, 8), + (384, 384, 4096, 32, 32, True, False, True): (1, 16, 3, 4), + (384, 384, 4096, 64, 64, False, True, True): (1, 64, 1, 4), + (384, 384, 4096, 64, 64, True, False, True): (2, 32, 3, 4), + (384, 384, 8192, 16, 16, False, True, True): (2, 64, 1, 4), + (384, 384, 8192, 16, 16, True, False, True): (3, 32, 3, 2), + (384, 384, 8192, 32, 32, False, True, True): (5, 64, 1, 8), + (384, 384, 8192, 32, 32, True, False, True): (1, 32, 3, 2), + (384, 384, 8192, 64, 64, False, True, True): (1, 128, 1, 4), + (384, 384, 8192, 64, 64, True, False, True): (3, 64, 3, 4), + (384, 384, 16384, 16, 16, False, True, True): (1, 128, 1, 2), + (384, 384, 16384, 16, 16, True, False, True): (4, 128, 3, 2), + (384, 384, 16384, 32, 32, False, True, True): (3, 128, 1, 4), + (384, 384, 16384, 32, 32, True, False, True): (1, 128, 3, 2), + (384, 384, 16384, 64, 64, False, True, True): (3, 256, 1, 4), + (384, 384, 16384, 64, 64, True, False, True): (2, 128, 3, 4), + (384, 384, 32768, 16, 16, False, True, True): (1, 256, 1, 2), + (384, 384, 32768, 16, 16, True, False, True): (1, 128, 3, 4), + (384, 384, 32768, 32, 32, False, True, True): (1, 256, 1, 2), + (384, 384, 32768, 32, 32, True, False, True): (1, 128, 3, 4), + (384, 384, 32768, 64, 64, False, True, True): (2, 256, 1, 4), + (384, 384, 32768, 64, 64, True, False, True): (1, 256, 3, 4), + (384, 384, 65536, 16, 16, False, True, True): (4, 512, 1, 2), + (384, 384, 65536, 16, 16, True, False, True): (1, 256, 3, 4), + (384, 384, 65536, 32, 32, False, True, True): (1, 512, 1, 2), + (384, 384, 65536, 32, 32, True, False, True): (1, 256, 3, 4), + (384, 384, 65536, 64, 64, False, True, True): (3, 512, 1, 4), + (384, 384, 65536, 64, 64, True, False, True): (3, 256, 3, 4), + (384, 384, 131072, 16, 16, False, True, True): (1, 512, 1, 1), + (384, 384, 131072, 16, 16, True, False, True): (1, 512, 3, 4), + (384, 
384, 131072, 32, 32, False, True, True): (1, 512, 1, 4), + (384, 384, 131072, 32, 32, True, False, True): (1, 512, 3, 4), + (384, 384, 131072, 64, 64, False, True, True): (3, 1024, 1, 4), + (384, 384, 131072, 64, 64, True, False, True): (3, 512, 3, 4), + (512, 512, 256, 16, 16, False, True, True): (2, 4, 5, 4), + (512, 512, 256, 16, 16, True, False, True): (3, 4, 5, 4), + (512, 512, 256, 32, 32, False, True, True): (1, 4, 5, 2), + (512, 512, 256, 32, 32, True, False, True): (4, 8, 5, 1), + (512, 512, 256, 64, 64, False, True, True): (4, 4, 5, 4), + (512, 512, 256, 64, 64, True, False, True): (5, 4, 5, 4), + (512, 512, 256, 128, 128, False, True, True): (3, 2, 2, 8), + (512, 512, 256, 128, 128, True, False, True): (2, 2, 2, 8), + (512, 512, 512, 16, 16, False, True, True): (1, 8, 5, 4), + (512, 512, 512, 16, 16, True, False, True): (4, 8, 5, 2), + (512, 512, 512, 32, 32, False, True, True): (1, 16, 4, 1), + (512, 512, 512, 32, 32, True, False, True): (1, 8, 5, 2), + (512, 512, 512, 64, 64, False, True, True): (4, 8, 5, 4), + (512, 512, 512, 64, 64, True, False, True): (2, 8, 5, 4), + (512, 512, 512, 128, 128, False, True, True): (2, 4, 2, 8), + (512, 512, 512, 128, 128, True, False, True): (1, 4, 2, 8), + (512, 512, 1024, 16, 16, False, True, True): (2, 8, 4, 4), + (512, 512, 1024, 16, 16, True, False, True): (1, 8, 4, 4), + (512, 512, 1024, 32, 32, False, True, True): (3, 16, 4, 2), + (512, 512, 1024, 32, 32, True, False, True): (1, 16, 5, 2), + (512, 512, 1024, 64, 64, False, True, True): (2, 8, 3, 4), + (512, 512, 1024, 64, 64, True, False, True): (2, 16, 3, 4), + (512, 512, 1024, 128, 128, False, True, True): (2, 8, 2, 8), + (512, 512, 1024, 128, 128, True, False, True): (3, 8, 2, 8), + (512, 512, 2048, 16, 16, False, True, True): (4, 16, 3, 2), + (512, 512, 2048, 16, 16, True, False, True): (1, 16, 4, 2), + (512, 512, 2048, 32, 32, False, True, True): (3, 32, 3, 2), + (512, 512, 2048, 32, 32, True, False, True): (2, 32, 3, 2), + (512, 512, 2048, 64, 64, False, True, True): (6, 32, 3, 2), + (512, 512, 2048, 64, 64, True, False, True): (1, 32, 3, 2), + (512, 512, 2048, 128, 128, False, True, True): (4, 16, 2, 8), + (512, 512, 2048, 128, 128, True, False, True): (1, 16, 2, 8), + (512, 512, 4096, 16, 16, False, True, True): (1, 16, 3, 2), + (512, 512, 4096, 16, 16, True, False, True): (4, 32, 3, 2), + (512, 512, 4096, 32, 32, False, True, True): (3, 32, 3, 2), + (512, 512, 4096, 32, 32, True, False, True): (2, 32, 3, 2), + (512, 512, 4096, 64, 64, False, True, True): (1, 32, 3, 4), + (512, 512, 4096, 64, 64, True, False, True): (1, 64, 3, 4), + (512, 512, 4096, 128, 128, False, True, True): (4, 32, 1, 4), + (512, 512, 4096, 128, 128, True, False, True): (4, 32, 2, 8), + (512, 512, 8192, 16, 16, False, True, True): (8, 64, 3, 2), + (512, 512, 8192, 16, 16, True, False, True): (4, 64, 3, 2), + (512, 512, 8192, 32, 32, False, True, True): (3, 64, 3, 2), + (512, 512, 8192, 32, 32, True, False, True): (3, 64, 3, 2), + (512, 512, 8192, 64, 64, False, True, True): (1, 64, 3, 4), + (512, 512, 8192, 64, 64, True, False, True): (7, 64, 3, 4), + (512, 512, 8192, 128, 128, False, True, True): (1, 64, 1, 4), + (512, 512, 8192, 128, 128, True, False, True): (4, 64, 2, 8), + (512, 512, 16384, 16, 16, False, True, True): (1, 64, 3, 2), + (512, 512, 16384, 16, 16, True, False, True): (1, 128, 3, 2), + (512, 512, 16384, 32, 32, False, True, True): (3, 128, 3, 2), + (512, 512, 16384, 32, 32, True, False, True): (1, 128, 3, 2), + (512, 512, 16384, 64, 64, False, True, True): (4, 64, 2, 4), + (512, 512, 16384, 
64, 64, True, False, True): (2, 64, 2, 4), + (512, 512, 16384, 128, 128, False, True, True): (4, 128, 1, 4), + (512, 512, 16384, 128, 128, True, False, True): (2, 128, 1, 4), + (512, 512, 32768, 16, 16, False, True, True): (1, 128, 3, 2), + (512, 512, 32768, 16, 16, True, False, True): (1, 128, 3, 2), + (512, 512, 32768, 32, 32, False, True, True): (1, 256, 3, 2), + (512, 512, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (512, 512, 32768, 64, 64, False, True, True): (1, 256, 3, 4), + (512, 512, 32768, 64, 64, True, False, True): (2, 256, 3, 4), + (512, 512, 32768, 128, 128, False, True, True): (5, 256, 1, 4), + (512, 512, 32768, 128, 128, True, False, True): (4, 256, 1, 4), + (512, 512, 65536, 16, 16, False, True, True): (1, 256, 3, 2), + (512, 512, 65536, 16, 16, True, False, True): (1, 256, 3, 1), + (512, 512, 65536, 32, 32, False, True, True): (1, 512, 3, 2), + (512, 512, 65536, 32, 32, True, False, True): (1, 512, 3, 2), + (512, 512, 65536, 64, 64, False, True, True): (4, 256, 2, 4), + (512, 512, 65536, 64, 64, True, False, True): (2, 512, 3, 4), + (512, 512, 65536, 128, 128, False, True, True): (6, 512, 1, 4), + (512, 512, 65536, 128, 128, True, False, True): (4, 512, 1, 4), + (512, 512, 131072, 16, 16, False, True, True): (1, 512, 3, 2), + (512, 512, 131072, 16, 16, True, False, True): (1, 512, 3, 1), + (512, 512, 131072, 32, 32, False, True, True): (1, 1024, 3, 2), + (512, 512, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (512, 512, 131072, 64, 64, False, True, True): (4, 512, 2, 4), + (512, 512, 131072, 64, 64, True, False, True): (4, 1024, 3, 4), + (512, 512, 131072, 128, 128, False, True, True): (6, 1024, 1, 4), + (512, 512, 131072, 128, 128, True, False, True): (4, 1024, 1, 4), + (768, 768, 256, 16, 16, False, True, True): (1, 8, 4, 1), + (768, 768, 256, 16, 16, True, False, True): (3, 2, 6, 4), + (768, 768, 256, 32, 32, False, True, True): (3, 8, 3, 4), + (768, 768, 256, 32, 32, True, False, True): (1, 4, 4, 2), + (768, 768, 256, 64, 64, False, True, True): (2, 4, 3, 4), + (768, 768, 256, 64, 64, True, False, True): (1, 4, 4, 4), + (768, 768, 256, 128, 128, False, True, True): (2, 2, 3, 8), + (768, 768, 256, 128, 128, True, False, True): (4, 2, 3, 8), + (768, 768, 512, 16, 16, False, True, True): (4, 8, 4, 2), + (768, 768, 512, 16, 16, True, False, True): (4, 8, 6, 2), + (768, 768, 512, 32, 32, False, True, True): (1, 8, 4, 4), + (768, 768, 512, 32, 32, True, False, True): (3, 8, 4, 2), + (768, 768, 512, 64, 64, False, True, True): (1, 8, 3, 4), + (768, 768, 512, 64, 64, True, False, True): (1, 8, 4, 4), + (768, 768, 512, 128, 128, False, True, True): (1, 4, 3, 8), + (768, 768, 512, 128, 128, True, False, True): (4, 4, 3, 8), + (768, 768, 1024, 16, 16, False, True, True): (3, 16, 1, 4), + (768, 768, 1024, 16, 16, True, False, True): (1, 8, 5, 2), + (768, 768, 1024, 32, 32, False, True, True): (3, 16, 1, 8), + (768, 768, 1024, 32, 32, True, False, True): (1, 16, 3, 2), + (768, 768, 1024, 64, 64, False, True, True): (1, 8, 3, 4), + (768, 768, 1024, 64, 64, True, False, True): (2, 8, 3, 8), + (768, 768, 1024, 128, 128, False, True, True): (1, 8, 3, 8), + (768, 768, 1024, 128, 128, True, False, True): (1, 8, 3, 8), + (768, 768, 2048, 16, 16, False, True, True): (2, 16, 1, 2), + (768, 768, 2048, 16, 16, True, False, True): (1, 16, 3, 2), + (768, 768, 2048, 32, 32, False, True, True): (5, 32, 1, 4), + (768, 768, 2048, 32, 32, True, False, True): (3, 8, 3, 4), + (768, 768, 2048, 64, 64, False, True, True): (1, 16, 1, 8), + (768, 768, 2048, 64, 64, True, False, True): 
(3, 16, 3, 4), + (768, 768, 2048, 128, 128, False, True, True): (2, 16, 3, 8), + (768, 768, 2048, 128, 128, True, False, True): (1, 16, 3, 8), + (768, 768, 4096, 16, 16, False, True, True): (3, 32, 1, 4), + (768, 768, 4096, 16, 16, True, False, True): (2, 32, 3, 1), + (768, 768, 4096, 32, 32, False, True, True): (2, 64, 1, 4), + (768, 768, 4096, 32, 32, True, False, True): (1, 16, 4, 4), + (768, 768, 4096, 64, 64, False, True, True): (3, 64, 3, 4), + (768, 768, 4096, 64, 64, True, False, True): (2, 16, 3, 4), + (768, 768, 4096, 128, 128, False, True, True): (1, 32, 3, 8), + (768, 768, 4096, 128, 128, True, False, True): (4, 32, 3, 8), + (768, 768, 8192, 16, 16, False, True, True): (1, 64, 1, 2), + (768, 768, 8192, 16, 16, True, False, True): (4, 64, 3, 2), + (768, 768, 8192, 32, 32, False, True, True): (1, 64, 1, 8), + (768, 768, 8192, 32, 32, True, False, True): (2, 32, 3, 4), + (768, 768, 8192, 64, 64, False, True, True): (4, 64, 3, 4), + (768, 768, 8192, 64, 64, True, False, True): (2, 32, 3, 4), + (768, 768, 8192, 128, 128, False, True, True): (2, 64, 3, 8), + (768, 768, 8192, 128, 128, True, False, True): (1, 64, 3, 8), + (768, 768, 16384, 16, 16, False, True, True): (1, 128, 1, 2), + (768, 768, 16384, 16, 16, True, False, True): (1, 64, 4, 4), + (768, 768, 16384, 32, 32, False, True, True): (1, 128, 1, 8), + (768, 768, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (768, 768, 16384, 64, 64, False, True, True): (4, 128, 3, 4), + (768, 768, 16384, 64, 64, True, False, True): (1, 64, 3, 4), + (768, 768, 16384, 128, 128, False, True, True): (3, 128, 1, 4), + (768, 768, 16384, 128, 128, True, False, True): (3, 128, 2, 4), + (768, 768, 32768, 16, 16, False, True, True): (1, 256, 1, 2), + (768, 768, 32768, 16, 16, True, False, True): (1, 128, 4, 4), + (768, 768, 32768, 32, 32, False, True, True): (1, 128, 1, 2), + (768, 768, 32768, 32, 32, True, False, True): (1, 128, 3, 4), + (768, 768, 32768, 64, 64, False, True, True): (1, 256, 1, 4), + (768, 768, 32768, 64, 64, True, False, True): (2, 128, 3, 4), + (768, 768, 32768, 128, 128, False, True, True): (3, 256, 1, 4), + (768, 768, 32768, 128, 128, True, False, True): (2, 256, 2, 4), + (768, 768, 65536, 16, 16, False, True, True): (4, 512, 1, 2), + (768, 768, 65536, 16, 16, True, False, True): (1, 256, 4, 4), + (768, 768, 65536, 32, 32, False, True, True): (1, 256, 1, 2), + (768, 768, 65536, 32, 32, True, False, True): (1, 256, 3, 4), + (768, 768, 65536, 64, 64, False, True, True): (3, 512, 1, 4), + (768, 768, 65536, 64, 64, True, False, True): (2, 256, 3, 4), + (768, 768, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (768, 768, 65536, 128, 128, True, False, True): (2, 512, 2, 4), + (768, 768, 131072, 16, 16, False, True, True): (4, 1024, 1, 2), + (768, 768, 131072, 16, 16, True, False, True): (1, 512, 4, 1), + (768, 768, 131072, 32, 32, False, True, True): (1, 512, 1, 2), + (768, 768, 131072, 32, 32, True, False, True): (1, 512, 3, 4), + (768, 768, 131072, 64, 64, False, True, True): (1, 1024, 1, 4), + (768, 768, 131072, 64, 64, True, False, True): (2, 512, 3, 4), + (768, 768, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (768, 768, 131072, 128, 128, True, False, True): (1, 1024, 2, 4), + (768, 3072, 256, 16, 16, False, True, True): (3, 8, 6, 1), + (768, 3072, 256, 16, 16, True, False, True): (1, 4, 6, 2), + (768, 3072, 256, 32, 32, False, True, True): (1, 8, 4, 4), + (768, 3072, 256, 32, 32, True, False, True): (3, 4, 6, 4), + (768, 3072, 256, 64, 64, False, True, True): (2, 4, 3, 4), + (768, 3072, 256, 64, 64, True, 
False, True): (1, 4, 4, 4), + (768, 3072, 256, 128, 128, False, True, True): (2, 2, 3, 8), + (768, 3072, 256, 128, 128, True, False, True): (1, 2, 3, 8), + (768, 3072, 512, 16, 16, False, True, True): (1, 8, 4, 2), + (768, 3072, 512, 16, 16, True, False, True): (1, 8, 5, 2), + (768, 3072, 512, 32, 32, False, True, True): (1, 16, 3, 2), + (768, 3072, 512, 32, 32, True, False, True): (1, 8, 5, 2), + (768, 3072, 512, 64, 64, False, True, True): (1, 8, 3, 4), + (768, 3072, 512, 64, 64, True, False, True): (3, 8, 4, 4), + (768, 3072, 512, 128, 128, False, True, True): (1, 4, 3, 8), + (768, 3072, 512, 128, 128, True, False, True): (2, 4, 3, 8), + (768, 3072, 1024, 16, 16, False, True, True): (1, 16, 1, 4), + (768, 3072, 1024, 16, 16, True, False, True): (5, 4, 4, 4), + (768, 3072, 1024, 32, 32, False, True, True): (3, 8, 3, 4), + (768, 3072, 1024, 32, 32, True, False, True): (1, 8, 4, 4), + (768, 3072, 1024, 64, 64, False, True, True): (2, 16, 3, 4), + (768, 3072, 1024, 64, 64, True, False, True): (2, 16, 4, 4), + (768, 3072, 1024, 128, 128, False, True, True): (1, 8, 3, 8), + (768, 3072, 1024, 128, 128, True, False, True): (5, 8, 3, 8), + (768, 3072, 2048, 16, 16, False, True, True): (3, 16, 1, 2), + (768, 3072, 2048, 16, 16, True, False, True): (1, 8, 3, 4), + (768, 3072, 2048, 32, 32, False, True, True): (4, 16, 1, 8), + (768, 3072, 2048, 32, 32, True, False, True): (3, 8, 3, 4), + (768, 3072, 2048, 64, 64, False, True, True): (2, 16, 3, 4), + (768, 3072, 2048, 64, 64, True, False, True): (2, 16, 3, 4), + (768, 3072, 2048, 128, 128, False, True, True): (3, 16, 3, 8), + (768, 3072, 2048, 128, 128, True, False, True): (4, 16, 3, 8), + (768, 3072, 4096, 16, 16, False, True, True): (1, 32, 1, 4), + (768, 3072, 4096, 16, 16, True, False, True): (1, 16, 3, 1), + (768, 3072, 4096, 32, 32, False, True, True): (3, 32, 1, 8), + (768, 3072, 4096, 32, 32, True, False, True): (3, 16, 4, 4), + (768, 3072, 4096, 64, 64, False, True, True): (2, 32, 3, 4), + (768, 3072, 4096, 64, 64, True, False, True): (2, 16, 3, 4), + (768, 3072, 4096, 128, 128, False, True, True): (5, 32, 1, 4), + (768, 3072, 4096, 128, 128, True, False, True): (9, 32, 3, 8), + (768, 3072, 8192, 16, 16, False, True, True): (1, 32, 1, 4), + (768, 3072, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (768, 3072, 8192, 32, 32, False, True, True): (1, 64, 1, 8), + (768, 3072, 8192, 32, 32, True, False, True): (2, 64, 4, 2), + (768, 3072, 8192, 64, 64, False, True, True): (1, 64, 3, 4), + (768, 3072, 8192, 64, 64, True, False, True): (2, 32, 3, 4), + (768, 3072, 8192, 128, 128, False, True, True): (2, 64, 3, 8), + (768, 3072, 8192, 128, 128, True, False, True): (2, 64, 3, 8), + (768, 3072, 16384, 16, 16, False, True, True): (1, 64, 1, 4), + (768, 3072, 16384, 16, 16, True, False, True): (1, 64, 4, 1), + (768, 3072, 16384, 32, 32, False, True, True): (1, 128, 1, 8), + (768, 3072, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (768, 3072, 16384, 64, 64, False, True, True): (1, 128, 3, 4), + (768, 3072, 16384, 64, 64, True, False, True): (4, 64, 3, 4), + (768, 3072, 16384, 128, 128, False, True, True): (2, 128, 3, 8), + (768, 3072, 16384, 128, 128, True, False, True): (2, 128, 3, 8), + (768, 3072, 32768, 16, 16, False, True, True): (1, 128, 1, 4), + (768, 3072, 32768, 16, 16, True, False, True): (1, 128, 4, 1), + (768, 3072, 32768, 32, 32, False, True, True): (1, 256, 1, 8), + (768, 3072, 32768, 32, 32, True, False, True): (1, 128, 3, 4), + (768, 3072, 32768, 64, 64, False, True, True): (1, 256, 3, 4), + (768, 3072, 32768, 64, 64, True, 
False, True): (1, 128, 3, 4), + (768, 3072, 32768, 128, 128, False, True, True): (3, 256, 1, 4), + (768, 3072, 32768, 128, 128, True, False, True): (2, 256, 3, 8), + (768, 3072, 50432, 16, 16, False, True, True): (1, 197, 1, 4), + (768, 3072, 50432, 16, 16, True, False, True): (4, 197, 4, 4), + (768, 3072, 50432, 32, 32, False, True, True): (1, 197, 1, 4), + (768, 3072, 50432, 32, 32, True, False, True): (4, 197, 3, 4), + (768, 3072, 50432, 64, 64, False, True, True): (1, 394, 3, 4), + (768, 3072, 50432, 64, 64, True, False, True): (3, 197, 3, 4), + (768, 3072, 50432, 128, 128, False, True, True): (3, 394, 1, 4), + (768, 3072, 50432, 128, 128, True, False, True): (1, 394, 3, 8), + (768, 3072, 65536, 16, 16, False, True, True): (1, 256, 1, 4), + (768, 3072, 65536, 16, 16, True, False, True): (5, 256, 4, 1), + (768, 3072, 65536, 32, 32, False, True, True): (1, 256, 1, 4), + (768, 3072, 65536, 32, 32, True, False, True): (3, 256, 3, 4), + (768, 3072, 65536, 64, 64, False, True, True): (2, 512, 3, 4), + (768, 3072, 65536, 64, 64, True, False, True): (3, 256, 3, 4), + (768, 3072, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (768, 3072, 65536, 128, 128, True, False, True): (2, 512, 3, 8), + (768, 3072, 131072, 16, 16, False, True, True): (1, 512, 1, 4), + (768, 3072, 131072, 16, 16, True, False, True): (5, 512, 4, 1), + (768, 3072, 131072, 32, 32, False, True, True): (1, 512, 1, 4), + (768, 3072, 131072, 32, 32, True, False, True): (4, 512, 3, 4), + (768, 3072, 131072, 64, 64, False, True, True): (1, 1024, 3, 4), + (768, 3072, 131072, 64, 64, True, False, True): (1, 512, 3, 4), + (768, 3072, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (768, 3072, 131072, 128, 128, True, False, True): (1, 1024, 3, 8), + (1024, 1024, 256, 16, 16, False, True, True): (1, 4, 5, 4), + (1024, 1024, 256, 16, 16, True, False, True): (3, 4, 4, 4), + (1024, 1024, 256, 32, 32, False, True, True): (4, 4, 5, 2), + (1024, 1024, 256, 32, 32, True, False, True): (3, 4, 5, 2), + (1024, 1024, 256, 64, 64, False, True, True): (1, 4, 5, 4), + (1024, 1024, 256, 64, 64, True, False, True): (1, 4, 5, 4), + (1024, 1024, 256, 128, 128, False, True, True): (1, 2, 2, 8), + (1024, 1024, 256, 128, 128, True, False, True): (2, 2, 2, 8), + (1024, 1024, 512, 16, 16, False, True, True): (3, 4, 4, 4), + (1024, 1024, 512, 16, 16, True, False, True): (4, 8, 5, 2), + (1024, 1024, 512, 32, 32, False, True, True): (1, 8, 4, 2), + (1024, 1024, 512, 32, 32, True, False, True): (1, 8, 4, 2), + (1024, 1024, 512, 64, 64, False, True, True): (4, 8, 4, 4), + (1024, 1024, 512, 64, 64, True, False, True): (2, 8, 3, 4), + (1024, 1024, 512, 128, 128, False, True, True): (2, 4, 2, 8), + (1024, 1024, 512, 128, 128, True, False, True): (1, 4, 2, 8), + (1024, 1024, 1024, 16, 16, False, True, True): (3, 8, 4, 4), + (1024, 1024, 1024, 16, 16, True, False, True): (4, 8, 4, 2), + (1024, 1024, 1024, 32, 32, False, True, True): (1, 16, 3, 2), + (1024, 1024, 1024, 32, 32, True, False, True): (1, 16, 3, 2), + (1024, 1024, 1024, 64, 64, False, True, True): (1, 16, 3, 4), + (1024, 1024, 1024, 64, 64, True, False, True): (3, 16, 3, 2), + (1024, 1024, 1024, 128, 128, False, True, True): (1, 8, 2, 8), + (1024, 1024, 1024, 128, 128, True, False, True): (2, 8, 2, 8), + (1024, 1024, 2048, 16, 16, False, True, True): (3, 8, 3, 4), + (1024, 1024, 2048, 16, 16, True, False, True): (3, 8, 3, 2), + (1024, 1024, 2048, 32, 32, False, True, True): (5, 16, 3, 4), + (1024, 1024, 2048, 32, 32, True, False, True): (1, 16, 3, 2), + (1024, 1024, 2048, 64, 64, False, True, 
True): (6, 16, 4, 4), + (1024, 1024, 2048, 64, 64, True, False, True): (5, 16, 3, 4), + (1024, 1024, 2048, 128, 128, False, True, True): (4, 16, 2, 8), + (1024, 1024, 2048, 128, 128, True, False, True): (4, 16, 2, 8), + (1024, 1024, 4096, 16, 16, False, True, True): (8, 32, 3, 2), + (1024, 1024, 4096, 16, 16, True, False, True): (4, 32, 3, 2), + (1024, 1024, 4096, 32, 32, False, True, True): (2, 32, 3, 4), + (1024, 1024, 4096, 32, 32, True, False, True): (3, 32, 3, 2), + (1024, 1024, 4096, 64, 64, False, True, True): (3, 32, 3, 4), + (1024, 1024, 4096, 64, 64, True, False, True): (1, 32, 3, 4), + (1024, 1024, 4096, 128, 128, False, True, True): (4, 32, 2, 8), + (1024, 1024, 4096, 128, 128, True, False, True): (1, 32, 2, 8), + (1024, 1024, 8192, 16, 16, False, True, True): (4, 64, 3, 2), + (1024, 1024, 8192, 16, 16, True, False, True): (4, 64, 3, 2), + (1024, 1024, 8192, 32, 32, False, True, True): (8, 64, 3, 4), + (1024, 1024, 8192, 32, 32, True, False, True): (4, 32, 3, 4), + (1024, 1024, 8192, 64, 64, False, True, True): (4, 64, 3, 4), + (1024, 1024, 8192, 64, 64, True, False, True): (2, 64, 3, 4), + (1024, 1024, 8192, 128, 128, False, True, True): (4, 64, 2, 8), + (1024, 1024, 8192, 128, 128, True, False, True): (4, 64, 1, 4), + (1024, 1024, 16384, 16, 16, False, True, True): (1, 64, 3, 2), + (1024, 1024, 16384, 16, 16, True, False, True): (1, 64, 3, 2), + (1024, 1024, 16384, 32, 32, False, True, True): (1, 128, 3, 2), + (1024, 1024, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (1024, 1024, 16384, 64, 64, False, True, True): (1, 128, 3, 4), + (1024, 1024, 16384, 64, 64, True, False, True): (1, 128, 3, 4), + (1024, 1024, 16384, 128, 128, False, True, True): (2, 128, 1, 4), + (1024, 1024, 16384, 128, 128, True, False, True): (4, 128, 1, 4), + (1024, 1024, 32768, 16, 16, False, True, True): (1, 128, 3, 2), + (1024, 1024, 32768, 16, 16, True, False, True): (1, 128, 3, 2), + (1024, 1024, 32768, 32, 32, False, True, True): (1, 256, 3, 2), + (1024, 1024, 32768, 32, 32, True, False, True): (1, 128, 3, 4), + (1024, 1024, 32768, 64, 64, False, True, True): (2, 128, 2, 4), + (1024, 1024, 32768, 64, 64, True, False, True): (1, 256, 3, 4), + (1024, 1024, 32768, 128, 128, False, True, True): (2, 256, 1, 4), + (1024, 1024, 32768, 128, 128, True, False, True): (4, 256, 1, 4), + (1024, 1024, 65536, 16, 16, False, True, True): (1, 256, 3, 4), + (1024, 1024, 65536, 16, 16, True, False, True): (1, 256, 3, 4), + (1024, 1024, 65536, 32, 32, False, True, True): (9, 256, 3, 4), + (1024, 1024, 65536, 32, 32, True, False, True): (7, 256, 3, 4), + (1024, 1024, 65536, 64, 64, False, True, True): (2, 256, 2, 4), + (1024, 1024, 65536, 64, 64, True, False, True): (2, 512, 3, 4), + (1024, 1024, 65536, 128, 128, False, True, True): (2, 512, 1, 4), + (1024, 1024, 65536, 128, 128, True, False, True): (4, 512, 1, 4), + (1024, 1024, 131072, 16, 16, False, True, True): (11, 512, 3, 2), + (1024, 1024, 131072, 16, 16, True, False, True): (11, 512, 3, 2), + (1024, 1024, 131072, 32, 32, False, True, True): (4, 512, 3, 4), + (1024, 1024, 131072, 32, 32, True, False, True): (6, 512, 3, 4), + (1024, 1024, 131072, 64, 64, False, True, True): (2, 512, 2, 4), + (1024, 1024, 131072, 64, 64, True, False, True): (2, 1024, 3, 4), + (1024, 1024, 131072, 128, 128, False, True, True): (4, 1024, 1, 4), + (1024, 1024, 131072, 128, 128, True, False, True): (4, 1024, 1, 4), + (1536, 1536, 256, 16, 16, False, True, True): (1, 4, 6, 2), + (1536, 1536, 256, 16, 16, True, False, True): (3, 4, 5, 2), + (1536, 1536, 256, 32, 32, False, True, 
True): (2, 4, 3, 4), + (1536, 1536, 256, 32, 32, True, False, True): (1, 4, 5, 2), + (1536, 1536, 256, 64, 64, False, True, True): (2, 4, 3, 4), + (1536, 1536, 256, 64, 64, True, False, True): (1, 4, 4, 4), + (1536, 1536, 256, 128, 128, False, True, True): (3, 2, 3, 8), + (1536, 1536, 256, 128, 128, True, False, True): (6, 2, 3, 8), + (1536, 1536, 512, 16, 16, False, True, True): (1, 8, 1, 4), + (1536, 1536, 512, 16, 16, True, False, True): (3, 4, 5, 2), + (1536, 1536, 512, 32, 32, False, True, True): (1, 8, 1, 8), + (1536, 1536, 512, 32, 32, True, False, True): (1, 4, 4, 4), + (1536, 1536, 512, 64, 64, False, True, True): (3, 8, 5, 4), + (1536, 1536, 512, 64, 64, True, False, True): (3, 8, 3, 4), + (1536, 1536, 512, 128, 128, False, True, True): (2, 4, 3, 8), + (1536, 1536, 512, 128, 128, True, False, True): (3, 4, 3, 8), + (1536, 1536, 1024, 16, 16, False, True, True): (1, 8, 1, 2), + (1536, 1536, 1024, 16, 16, True, False, True): (2, 8, 4, 2), + (1536, 1536, 1024, 32, 32, False, True, True): (8, 16, 1, 4), + (1536, 1536, 1024, 32, 32, True, False, True): (3, 8, 4, 2), + (1536, 1536, 1024, 64, 64, False, True, True): (1, 16, 3, 4), + (1536, 1536, 1024, 64, 64, True, False, True): (3, 8, 3, 4), + (1536, 1536, 1024, 128, 128, False, True, True): (3, 8, 3, 8), + (1536, 1536, 1024, 128, 128, True, False, True): (3, 8, 3, 8), + (1536, 1536, 2048, 16, 16, False, True, True): (1, 16, 1, 4), + (1536, 1536, 2048, 16, 16, True, False, True): (1, 8, 3, 1), + (1536, 1536, 2048, 32, 32, False, True, True): (3, 16, 1, 8), + (1536, 1536, 2048, 32, 32, True, False, True): (3, 8, 4, 4), + (1536, 1536, 2048, 64, 64, False, True, True): (1, 16, 3, 4), + (1536, 1536, 2048, 64, 64, True, False, True): (3, 8, 3, 4), + (1536, 1536, 2048, 128, 128, False, True, True): (4, 16, 1, 4), + (1536, 1536, 2048, 128, 128, True, False, True): (6, 16, 3, 8), + (1536, 1536, 4096, 16, 16, False, True, True): (1, 32, 1, 2), + (1536, 1536, 4096, 16, 16, True, False, True): (4, 32, 4, 2), + (1536, 1536, 4096, 32, 32, False, True, True): (1, 32, 1, 8), + (1536, 1536, 4096, 32, 32, True, False, True): (5, 32, 4, 2), + (1536, 1536, 4096, 64, 64, False, True, True): (2, 32, 3, 4), + (1536, 1536, 4096, 64, 64, True, False, True): (2, 16, 3, 4), + (1536, 1536, 4096, 128, 128, False, True, True): (4, 32, 3, 8), + (1536, 1536, 4096, 128, 128, True, False, True): (4, 32, 3, 8), + (1536, 1536, 8192, 16, 16, False, True, True): (1, 64, 1, 2), + (1536, 1536, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (1536, 1536, 8192, 32, 32, False, True, True): (2, 64, 1, 8), + (1536, 1536, 8192, 32, 32, True, False, True): (2, 32, 3, 4), + (1536, 1536, 8192, 64, 64, False, True, True): (1, 64, 3, 4), + (1536, 1536, 8192, 64, 64, True, False, True): (2, 32, 3, 4), + (1536, 1536, 8192, 128, 128, False, True, True): (4, 64, 3, 8), + (1536, 1536, 8192, 128, 128, True, False, True): (1, 64, 3, 8), + (1536, 1536, 16384, 16, 16, False, True, True): (1, 128, 1, 2), + (1536, 1536, 16384, 16, 16, True, False, True): (1, 64, 4, 4), + (1536, 1536, 16384, 32, 32, False, True, True): (1, 64, 1, 2), + (1536, 1536, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (1536, 1536, 16384, 64, 64, False, True, True): (1, 128, 3, 4), + (1536, 1536, 16384, 64, 64, True, False, True): (1, 64, 3, 4), + (1536, 1536, 16384, 128, 128, False, True, True): (1, 128, 1, 4), + (1536, 1536, 16384, 128, 128, True, False, True): (1, 128, 2, 4), + (1536, 1536, 32768, 16, 16, False, True, True): (1, 256, 1, 2), + (1536, 1536, 32768, 16, 16, True, False, True): (1, 128, 3, 2), + 
(1536, 1536, 32768, 32, 32, False, True, True): (1, 128, 1, 2), + (1536, 1536, 32768, 32, 32, True, False, True): (1, 128, 3, 4), + (1536, 1536, 32768, 64, 64, False, True, True): (1, 256, 3, 4), + (1536, 1536, 32768, 64, 64, True, False, True): (1, 128, 3, 4), + (1536, 1536, 32768, 128, 128, False, True, True): (1, 256, 1, 4), + (1536, 1536, 32768, 128, 128, True, False, True): (2, 256, 2, 4), + (1536, 1536, 65536, 16, 16, False, True, True): (2, 512, 1, 2), + (1536, 1536, 65536, 16, 16, True, False, True): (1, 256, 4, 4), + (1536, 1536, 65536, 32, 32, False, True, True): (1, 256, 1, 2), + (1536, 1536, 65536, 32, 32, True, False, True): (1, 256, 3, 4), + (1536, 1536, 65536, 64, 64, False, True, True): (1, 512, 3, 4), + (1536, 1536, 65536, 64, 64, True, False, True): (3, 256, 3, 4), + (1536, 1536, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (1536, 1536, 65536, 128, 128, True, False, True): (4, 512, 2, 4), + (1536, 1536, 131072, 16, 16, False, True, True): (2, 1024, 1, 2), + (1536, 1536, 131072, 16, 16, True, False, True): (9, 512, 4, 4), + (1536, 1536, 131072, 32, 32, False, True, True): (1, 512, 1, 2), + (1536, 1536, 131072, 32, 32, True, False, True): (5, 512, 3, 4), + (1536, 1536, 131072, 64, 64, False, True, True): (1, 1024, 3, 4), + (1536, 1536, 131072, 64, 64, True, False, True): (2, 512, 3, 4), + (1536, 1536, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (1536, 1536, 131072, 128, 128, True, False, True): (1, 1024, 2, 4), + (2048, 2048, 256, 16, 16, False, True, True): (1, 4, 5, 2), + (2048, 2048, 256, 16, 16, True, False, True): (4, 4, 5, 2), + (2048, 2048, 256, 32, 32, False, True, True): (3, 4, 6, 2), + (2048, 2048, 256, 32, 32, True, False, True): (2, 4, 5, 2), + (2048, 2048, 256, 64, 64, False, True, True): (2, 4, 4, 4), + (2048, 2048, 256, 64, 64, True, False, True): (2, 4, 3, 4), + (2048, 2048, 256, 128, 128, False, True, True): (3, 2, 2, 8), + (2048, 2048, 256, 128, 128, True, False, True): (3, 2, 2, 8), + (2048, 2048, 512, 16, 16, False, True, True): (3, 4, 4, 4), + (2048, 2048, 512, 16, 16, True, False, True): (1, 4, 4, 4), + (2048, 2048, 512, 32, 32, False, True, True): (1, 4, 3, 4), + (2048, 2048, 512, 32, 32, True, False, True): (1, 4, 4, 2), + (2048, 2048, 512, 64, 64, False, True, True): (1, 8, 3, 4), + (2048, 2048, 512, 64, 64, True, False, True): (1, 8, 3, 4), + (2048, 2048, 512, 128, 128, False, True, True): (3, 4, 2, 8), + (2048, 2048, 512, 128, 128, True, False, True): (2, 4, 2, 8), + (2048, 2048, 1024, 16, 16, False, True, True): (3, 4, 3, 4), + (2048, 2048, 1024, 16, 16, True, False, True): (4, 8, 3, 2), + (2048, 2048, 1024, 32, 32, False, True, True): (3, 8, 3, 4), + (2048, 2048, 1024, 32, 32, True, False, True): (1, 8, 3, 2), + (2048, 2048, 1024, 64, 64, False, True, True): (1, 8, 3, 4), + (2048, 2048, 1024, 64, 64, True, False, True): (1, 8, 3, 4), + (2048, 2048, 1024, 128, 128, False, True, True): (4, 8, 1, 4), + (2048, 2048, 1024, 128, 128, True, False, True): (2, 8, 1, 4), + (2048, 2048, 2048, 16, 16, False, True, True): (4, 16, 3, 2), + (2048, 2048, 2048, 16, 16, True, False, True): (4, 16, 3, 2), + (2048, 2048, 2048, 32, 32, False, True, True): (1, 16, 3, 2), + (2048, 2048, 2048, 32, 32, True, False, True): (1, 16, 3, 2), + (2048, 2048, 2048, 64, 64, False, True, True): (4, 16, 3, 4), + (2048, 2048, 2048, 64, 64, True, False, True): (4, 16, 3, 4), + (2048, 2048, 2048, 128, 128, False, True, True): (6, 16, 2, 8), + (2048, 2048, 2048, 128, 128, True, False, True): (3, 16, 1, 4), + (2048, 2048, 4096, 16, 16, False, True, True): (4, 32, 
4, 2), + (2048, 2048, 4096, 16, 16, True, False, True): (4, 32, 3, 2), + (2048, 2048, 4096, 32, 32, False, True, True): (4, 16, 3, 8), + (2048, 2048, 4096, 32, 32, True, False, True): (4, 16, 3, 8), + (2048, 2048, 4096, 64, 64, False, True, True): (1, 32, 3, 4), + (2048, 2048, 4096, 64, 64, True, False, True): (3, 32, 3, 4), + (2048, 2048, 4096, 128, 128, False, True, True): (2, 32, 1, 4), + (2048, 2048, 4096, 128, 128, True, False, True): (2, 32, 1, 4), + (2048, 2048, 8192, 16, 16, False, True, True): (4, 64, 4, 2), + (2048, 2048, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (2048, 2048, 8192, 32, 32, False, True, True): (4, 32, 4, 8), + (2048, 2048, 8192, 32, 32, True, False, True): (4, 32, 3, 8), + (2048, 2048, 8192, 64, 64, False, True, True): (4, 64, 3, 4), + (2048, 2048, 8192, 64, 64, True, False, True): (4, 64, 3, 4), + (2048, 2048, 8192, 128, 128, False, True, True): (2, 64, 1, 4), + (2048, 2048, 8192, 128, 128, True, False, True): (2, 64, 1, 4), + (2048, 2048, 16384, 16, 16, False, True, True): (4, 64, 3, 2), + (2048, 2048, 16384, 16, 16, True, False, True): (1, 64, 3, 2), + (2048, 2048, 16384, 32, 32, False, True, True): (4, 64, 3, 4), + (2048, 2048, 16384, 32, 32, True, False, True): (4, 64, 3, 4), + (2048, 2048, 16384, 64, 64, False, True, True): (4, 128, 3, 4), + (2048, 2048, 16384, 64, 64, True, False, True): (4, 128, 3, 4), + (2048, 2048, 16384, 128, 128, False, True, True): (2, 128, 1, 4), + (2048, 2048, 16384, 128, 128, True, False, True): (2, 128, 1, 4), + (2048, 2048, 32768, 16, 16, False, True, True): (8, 128, 3, 2), + (2048, 2048, 32768, 16, 16, True, False, True): (8, 128, 3, 4), + (2048, 2048, 32768, 32, 32, False, True, True): (8, 128, 3, 4), + (2048, 2048, 32768, 32, 32, True, False, True): (8, 128, 3, 4), + (2048, 2048, 32768, 64, 64, False, True, True): (1, 128, 2, 4), + (2048, 2048, 32768, 64, 64, True, False, True): (8, 256, 3, 4), + (2048, 2048, 32768, 128, 128, False, True, True): (2, 256, 1, 4), + (2048, 2048, 32768, 128, 128, True, False, True): (2, 256, 1, 4), + (2048, 2048, 65536, 16, 16, False, True, True): (9, 256, 4, 4), + (2048, 2048, 65536, 16, 16, True, False, True): (7, 256, 4, 4), + (2048, 2048, 65536, 32, 32, False, True, True): (7, 256, 3, 4), + (2048, 2048, 65536, 32, 32, True, False, True): (3, 256, 3, 4), + (2048, 2048, 65536, 64, 64, False, True, True): (2, 256, 2, 4), + (2048, 2048, 65536, 64, 64, True, False, True): (6, 512, 3, 4), + (2048, 2048, 65536, 128, 128, False, True, True): (2, 512, 1, 4), + (2048, 2048, 65536, 128, 128, True, False, True): (2, 512, 1, 4), + (2048, 2048, 131072, 16, 16, False, True, True): (9, 512, 4, 4), + (2048, 2048, 131072, 16, 16, True, False, True): (9, 512, 4, 4), + (2048, 2048, 131072, 32, 32, False, True, True): (7, 512, 4, 4), + (2048, 2048, 131072, 32, 32, True, False, True): (3, 512, 3, 4), + (2048, 2048, 131072, 64, 64, False, True, True): (2, 512, 2, 4), + (2048, 2048, 131072, 64, 64, True, False, True): (4, 1024, 3, 4), + (2048, 2048, 131072, 128, 128, False, True, True): (1, 1024, 1, 4), + (2048, 2048, 131072, 128, 128, True, False, True): (2, 1024, 1, 4), + (3072, 768, 256, 16, 16, False, True, True): (6, 4, 1, 4), + (3072, 768, 256, 16, 16, True, False, True): (3, 1, 4, 4), + (3072, 768, 256, 32, 32, False, True, True): (6, 8, 1, 2), + (3072, 768, 256, 32, 32, True, False, True): (1, 2, 4, 4), + (3072, 768, 256, 64, 64, False, True, True): (1, 4, 4, 4), + (3072, 768, 256, 64, 64, True, False, True): (4, 2, 4, 4), + (3072, 768, 256, 128, 128, False, True, True): (1, 2, 3, 8), + (3072, 768, 
256, 128, 128, True, False, True): (1, 2, 3, 8), + (3072, 768, 512, 16, 16, False, True, True): (2, 4, 1, 4), + (3072, 768, 512, 16, 16, True, False, True): (1, 4, 4, 1), + (3072, 768, 512, 32, 32, False, True, True): (3, 8, 1, 4), + (3072, 768, 512, 32, 32, True, False, True): (1, 2, 3, 4), + (3072, 768, 512, 64, 64, False, True, True): (1, 8, 1, 4), + (3072, 768, 512, 64, 64, True, False, True): (4, 4, 3, 4), + (3072, 768, 512, 128, 128, False, True, True): (1, 4, 3, 8), + (3072, 768, 512, 128, 128, True, False, True): (1, 4, 3, 8), + (3072, 768, 1024, 16, 16, False, True, True): (1, 8, 1, 4), + (3072, 768, 1024, 16, 16, True, False, True): (3, 4, 3, 1), + (3072, 768, 1024, 32, 32, False, True, True): (1, 8, 1, 8), + (3072, 768, 1024, 32, 32, True, False, True): (1, 4, 4, 4), + (3072, 768, 1024, 64, 64, False, True, True): (1, 16, 3, 4), + (3072, 768, 1024, 64, 64, True, False, True): (1, 4, 3, 4), + (3072, 768, 1024, 128, 128, False, True, True): (1, 8, 3, 8), + (3072, 768, 1024, 128, 128, True, False, True): (2, 8, 3, 8), + (3072, 768, 2048, 16, 16, False, True, True): (3, 8, 1, 4), + (3072, 768, 2048, 16, 16, True, False, True): (2, 8, 3, 4), + (3072, 768, 2048, 32, 32, False, True, True): (3, 16, 1, 8), + (3072, 768, 2048, 32, 32, True, False, True): (3, 8, 3, 4), + (3072, 768, 2048, 64, 64, False, True, True): (1, 16, 1, 4), + (3072, 768, 2048, 64, 64, True, False, True): (1, 16, 3, 4), + (3072, 768, 2048, 128, 128, False, True, True): (1, 16, 3, 8), + (3072, 768, 2048, 128, 128, True, False, True): (2, 16, 2, 4), + (3072, 768, 4096, 16, 16, False, True, True): (1, 16, 1, 4), + (3072, 768, 4096, 16, 16, True, False, True): (4, 32, 4, 2), + (3072, 768, 4096, 32, 32, False, True, True): (2, 32, 1, 8), + (3072, 768, 4096, 32, 32, True, False, True): (7, 16, 3, 4), + (3072, 768, 4096, 64, 64, False, True, True): (2, 32, 1, 4), + (3072, 768, 4096, 64, 64, True, False, True): (2, 16, 2, 4), + (3072, 768, 4096, 128, 128, False, True, True): (1, 32, 3, 8), + (3072, 768, 4096, 128, 128, True, False, True): (3, 32, 2, 4), + (3072, 768, 8192, 16, 16, False, True, True): (2, 32, 1, 4), + (3072, 768, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (3072, 768, 8192, 32, 32, False, True, True): (4, 32, 1, 4), + (3072, 768, 8192, 32, 32, True, False, True): (4, 32, 3, 4), + (3072, 768, 8192, 64, 64, False, True, True): (2, 64, 1, 4), + (3072, 768, 8192, 64, 64, True, False, True): (4, 32, 2, 4), + (3072, 768, 8192, 128, 128, False, True, True): (3, 64, 1, 4), + (3072, 768, 8192, 128, 128, True, False, True): (6, 64, 2, 4), + (3072, 768, 16384, 16, 16, False, True, True): (1, 64, 1, 4), + (3072, 768, 16384, 16, 16, True, False, True): (1, 64, 1, 1), + (3072, 768, 16384, 32, 32, False, True, True): (1, 64, 1, 4), + (3072, 768, 16384, 32, 32, True, False, True): (4, 64, 3, 4), + (3072, 768, 16384, 64, 64, False, True, True): (4, 128, 1, 4), + (3072, 768, 16384, 64, 64, True, False, True): (4, 64, 2, 4), + (3072, 768, 16384, 128, 128, False, True, True): (3, 128, 1, 4), + (3072, 768, 16384, 128, 128, True, False, True): (4, 128, 2, 4), + (3072, 768, 32768, 16, 16, False, True, True): (1, 128, 1, 4), + (3072, 768, 32768, 16, 16, True, False, True): (8, 128, 4, 1), + (3072, 768, 32768, 32, 32, False, True, True): (1, 128, 1, 4), + (3072, 768, 32768, 32, 32, True, False, True): (8, 128, 3, 4), + (3072, 768, 32768, 64, 64, False, True, True): (1, 256, 1, 4), + (3072, 768, 32768, 64, 64, True, False, True): (1, 128, 2, 4), + (3072, 768, 32768, 128, 128, False, True, True): (3, 256, 1, 4), + (3072, 768, 
32768, 128, 128, True, False, True): (8, 256, 2, 4), + (3072, 768, 50432, 16, 16, False, True, True): (1, 197, 1, 4), + (3072, 768, 50432, 16, 16, True, False, True): (7, 197, 4, 1), + (3072, 768, 50432, 32, 32, False, True, True): (1, 197, 1, 4), + (3072, 768, 50432, 32, 32, True, False, True): (4, 197, 3, 4), + (3072, 768, 50432, 64, 64, False, True, True): (1, 394, 1, 4), + (3072, 768, 50432, 64, 64, True, False, True): (3, 197, 2, 4), + (3072, 768, 50432, 128, 128, False, True, True): (3, 394, 1, 4), + (3072, 768, 50432, 128, 128, True, False, True): (8, 394, 2, 4), + (3072, 768, 65536, 16, 16, False, True, True): (1, 256, 1, 4), + (3072, 768, 65536, 16, 16, True, False, True): (15, 256, 4, 1), + (3072, 768, 65536, 32, 32, False, True, True): (1, 256, 1, 4), + (3072, 768, 65536, 32, 32, True, False, True): (15, 256, 3, 4), + (3072, 768, 65536, 64, 64, False, True, True): (1, 512, 1, 4), + (3072, 768, 65536, 64, 64, True, False, True): (2, 256, 2, 4), + (3072, 768, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (3072, 768, 65536, 128, 128, True, False, True): (3, 512, 2, 4), + (3072, 768, 131072, 16, 16, False, True, True): (1, 512, 1, 4), + (3072, 768, 131072, 16, 16, True, False, True): (15, 512, 4, 1), + (3072, 768, 131072, 32, 32, False, True, True): (1, 512, 1, 4), + (3072, 768, 131072, 32, 32, True, False, True): (9, 512, 3, 4), + (3072, 768, 131072, 64, 64, False, True, True): (1, 1024, 1, 4), + (3072, 768, 131072, 64, 64, True, False, True): (3, 512, 2, 4), + (3072, 768, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (3072, 768, 131072, 128, 128, True, False, True): (1, 1024, 2, 4), + (3072, 3072, 256, 16, 16, False, True, True): (5, 4, 1, 4), + (3072, 3072, 256, 16, 16, True, False, True): (1, 2, 5, 2), + (3072, 3072, 256, 32, 32, False, True, True): (5, 4, 1, 8), + (3072, 3072, 256, 32, 32, True, False, True): (1, 4, 4, 2), + (3072, 3072, 256, 64, 64, False, True, True): (2, 4, 4, 4), + (3072, 3072, 256, 64, 64, True, False, True): (2, 4, 4, 4), + (3072, 3072, 256, 128, 128, False, True, True): (1, 2, 3, 8), + (3072, 3072, 256, 128, 128, True, False, True): (1, 2, 3, 8), + (3072, 3072, 512, 16, 16, False, True, True): (5, 4, 1, 2), + (3072, 3072, 512, 16, 16, True, False, True): (1, 2, 3, 4), + (3072, 3072, 512, 32, 32, False, True, True): (3, 8, 1, 4), + (3072, 3072, 512, 32, 32, True, False, True): (1, 4, 4, 2), + (3072, 3072, 512, 64, 64, False, True, True): (1, 8, 2, 2), + (3072, 3072, 512, 64, 64, True, False, True): (2, 4, 3, 4), + (3072, 3072, 512, 128, 128, False, True, True): (2, 4, 3, 8), + (3072, 3072, 512, 128, 128, True, False, True): (1, 4, 3, 8), + (3072, 3072, 1024, 16, 16, False, True, True): (1, 8, 1, 4), + (3072, 3072, 1024, 16, 16, True, False, True): (2, 8, 3, 1), + (3072, 3072, 1024, 32, 32, False, True, True): (1, 16, 1, 4), + (3072, 3072, 1024, 32, 32, True, False, True): (1, 4, 4, 4), + (3072, 3072, 1024, 64, 64, False, True, True): (1, 8, 3, 4), + (3072, 3072, 1024, 64, 64, True, False, True): (2, 4, 3, 4), + (3072, 3072, 1024, 128, 128, False, True, True): (1, 8, 1, 4), + (3072, 3072, 1024, 128, 128, True, False, True): (2, 8, 3, 8), + (3072, 3072, 2048, 16, 16, False, True, True): (1, 16, 1, 2), + (3072, 3072, 2048, 16, 16, True, False, True): (2, 16, 4, 2), + (3072, 3072, 2048, 32, 32, False, True, True): (1, 16, 1, 8), + (3072, 3072, 2048, 32, 32, True, False, True): (3, 8, 4, 4), + (3072, 3072, 2048, 64, 64, False, True, True): (3, 16, 3, 4), + (3072, 3072, 2048, 64, 64, True, False, True): (3, 8, 3, 4), + (3072, 3072, 2048, 
128, 128, False, True, True): (1, 16, 3, 8), + (3072, 3072, 2048, 128, 128, True, False, True): (5, 16, 3, 8), + (3072, 3072, 4096, 16, 16, False, True, True): (1, 32, 1, 2), + (3072, 3072, 4096, 16, 16, True, False, True): (4, 32, 4, 2), + (3072, 3072, 4096, 32, 32, False, True, True): (1, 32, 1, 8), + (3072, 3072, 4096, 32, 32, True, False, True): (3, 16, 3, 4), + (3072, 3072, 4096, 64, 64, False, True, True): (1, 32, 3, 4), + (3072, 3072, 4096, 64, 64, True, False, True): (3, 16, 3, 4), + (3072, 3072, 4096, 128, 128, False, True, True): (3, 32, 3, 8), + (3072, 3072, 4096, 128, 128, True, False, True): (3, 32, 3, 8), + (3072, 3072, 8192, 16, 16, False, True, True): (1, 64, 1, 2), + (3072, 3072, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (3072, 3072, 8192, 32, 32, False, True, True): (1, 64, 1, 8), + (3072, 3072, 8192, 32, 32, True, False, True): (6, 32, 3, 4), + (3072, 3072, 8192, 64, 64, False, True, True): (1, 64, 3, 4), + (3072, 3072, 8192, 64, 64, True, False, True): (2, 32, 3, 4), + (3072, 3072, 8192, 128, 128, False, True, True): (2, 64, 3, 8), + (3072, 3072, 8192, 128, 128, True, False, True): (1, 64, 3, 8), + (3072, 3072, 16384, 16, 16, False, True, True): (1, 128, 1, 2), + (3072, 3072, 16384, 16, 16, True, False, True): (4, 128, 4, 2), + (3072, 3072, 16384, 32, 32, False, True, True): (1, 64, 1, 2), + (3072, 3072, 16384, 32, 32, True, False, True): (4, 64, 3, 4), + (3072, 3072, 16384, 64, 64, False, True, True): (1, 128, 3, 4), + (3072, 3072, 16384, 64, 64, True, False, True): (4, 64, 3, 4), + (3072, 3072, 16384, 128, 128, False, True, True): (1, 128, 1, 4), + (3072, 3072, 16384, 128, 128, True, False, True): (1, 128, 3, 8), + (3072, 3072, 32768, 16, 16, False, True, True): (1, 256, 1, 2), + (3072, 3072, 32768, 16, 16, True, False, True): (8, 128, 4, 4), + (3072, 3072, 32768, 32, 32, False, True, True): (1, 256, 1, 8), + (3072, 3072, 32768, 32, 32, True, False, True): (5, 128, 3, 4), + (3072, 3072, 32768, 64, 64, False, True, True): (1, 256, 3, 4), + (3072, 3072, 32768, 64, 64, True, False, True): (3, 128, 3, 4), + (3072, 3072, 32768, 128, 128, False, True, True): (1, 256, 1, 4), + (3072, 3072, 32768, 128, 128, True, False, True): (3, 256, 2, 4), + (3072, 3072, 65536, 16, 16, False, True, True): (1, 512, 1, 2), + (3072, 3072, 65536, 16, 16, True, False, True): (7, 256, 4, 4), + (3072, 3072, 65536, 32, 32, False, True, True): (1, 256, 1, 2), + (3072, 3072, 65536, 32, 32, True, False, True): (5, 256, 3, 4), + (3072, 3072, 65536, 64, 64, False, True, True): (1, 512, 3, 4), + (3072, 3072, 65536, 64, 64, True, False, True): (3, 256, 3, 4), + (3072, 3072, 65536, 128, 128, False, True, True): (1, 512, 1, 4), + (3072, 3072, 65536, 128, 128, True, False, True): (3, 512, 2, 4), + (3072, 3072, 131072, 16, 16, False, True, True): (1, 1024, 1, 2), + (3072, 3072, 131072, 16, 16, True, False, True): (5, 512, 4, 4), + (3072, 3072, 131072, 32, 32, False, True, True): (1, 512, 1, 2), + (3072, 3072, 131072, 32, 32, True, False, True): (5, 512, 3, 4), + (3072, 3072, 131072, 64, 64, False, True, True): (1, 1024, 3, 4), + (3072, 3072, 131072, 64, 64, True, False, True): (3, 512, 3, 4), + (3072, 3072, 131072, 128, 128, False, True, True): (1, 1024, 1, 4), + (3072, 3072, 131072, 128, 128, True, False, True): (6, 1024, 2, 4), + (4096, 4096, 256, 16, 16, False, True, True): (2, 2, 5, 4), + (4096, 4096, 256, 16, 16, True, False, True): (2, 2, 4, 2), + (4096, 4096, 256, 32, 32, False, True, True): (1, 2, 4, 4), + (4096, 4096, 256, 32, 32, True, False, True): (3, 2, 4, 2), + (4096, 4096, 256, 64, 
64, False, True, True): (3, 4, 3, 4), + (4096, 4096, 256, 64, 64, True, False, True): (1, 4, 3, 2), + (4096, 4096, 256, 128, 128, False, True, True): (1, 2, 2, 8), + (4096, 4096, 256, 128, 128, True, False, True): (1, 2, 2, 8), + (4096, 4096, 512, 16, 16, False, True, True): (4, 2, 3, 4), + (4096, 4096, 512, 16, 16, True, False, True): (1, 2, 3, 4), + (4096, 4096, 512, 32, 32, False, True, True): (1, 4, 3, 4), + (4096, 4096, 512, 32, 32, True, False, True): (3, 4, 3, 2), + (4096, 4096, 512, 64, 64, False, True, True): (4, 4, 4, 4), + (4096, 4096, 512, 64, 64, True, False, True): (3, 4, 3, 4), + (4096, 4096, 512, 128, 128, False, True, True): (2, 4, 2, 8), + (4096, 4096, 512, 128, 128, True, False, True): (2, 4, 1, 4), + (4096, 4096, 1024, 16, 16, False, True, True): (2, 8, 3, 2), + (4096, 4096, 1024, 16, 16, True, False, True): (2, 8, 3, 2), + (4096, 4096, 1024, 32, 32, False, True, True): (1, 8, 3, 4), + (4096, 4096, 1024, 32, 32, True, False, True): (1, 8, 3, 2), + (4096, 4096, 1024, 64, 64, False, True, True): (1, 8, 3, 4), + (4096, 4096, 1024, 64, 64, True, False, True): (1, 8, 3, 4), + (4096, 4096, 1024, 128, 128, False, True, True): (4, 8, 1, 4), + (4096, 4096, 1024, 128, 128, True, False, True): (2, 8, 2, 8), + (4096, 4096, 2048, 16, 16, False, True, True): (2, 8, 4, 4), + (4096, 4096, 2048, 16, 16, True, False, True): (2, 8, 4, 4), + (4096, 4096, 2048, 32, 32, False, True, True): (4, 8, 3, 8), + (4096, 4096, 2048, 32, 32, True, False, True): (4, 8, 4, 8), + (4096, 4096, 2048, 64, 64, False, True, True): (4, 16, 3, 4), + (4096, 4096, 2048, 64, 64, True, False, True): (4, 16, 3, 4), + (4096, 4096, 2048, 128, 128, False, True, True): (1, 16, 1, 4), + (4096, 4096, 2048, 128, 128, True, False, True): (4, 16, 1, 4), + (4096, 4096, 4096, 16, 16, False, True, True): (4, 32, 4, 4), + (4096, 4096, 4096, 16, 16, True, False, True): (2, 32, 4, 4), + (4096, 4096, 4096, 32, 32, False, True, True): (4, 16, 4, 8), + (4096, 4096, 4096, 32, 32, True, False, True): (4, 16, 4, 8), + (4096, 4096, 4096, 64, 64, False, True, True): (4, 32, 3, 4), + (4096, 4096, 4096, 64, 64, True, False, True): (2, 32, 3, 4), + (4096, 4096, 4096, 128, 128, False, True, True): (2, 32, 1, 4), + (4096, 4096, 4096, 128, 128, True, False, True): (2, 32, 1, 4), + (4096, 4096, 8192, 16, 16, False, True, True): (4, 64, 4, 2), + (4096, 4096, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (4096, 4096, 8192, 32, 32, False, True, True): (4, 32, 4, 8), + (4096, 4096, 8192, 32, 32, True, False, True): (4, 32, 4, 8), + (4096, 4096, 8192, 64, 64, False, True, True): (4, 64, 3, 4), + (4096, 4096, 8192, 64, 64, True, False, True): (4, 64, 3, 4), + (4096, 4096, 8192, 128, 128, False, True, True): (1, 64, 1, 4), + (4096, 4096, 8192, 128, 128, True, False, True): (1, 64, 1, 4), + (4096, 4096, 16384, 16, 16, False, True, True): (4, 64, 4, 4), + (4096, 4096, 16384, 16, 16, True, False, True): (4, 64, 4, 4), + (4096, 4096, 16384, 32, 32, False, True, True): (4, 64, 4, 8), + (4096, 4096, 16384, 32, 32, True, False, True): (4, 64, 4, 8), + (4096, 4096, 16384, 64, 64, False, True, True): (4, 128, 3, 4), + (4096, 4096, 16384, 64, 64, True, False, True): (4, 128, 3, 4), + (4096, 4096, 16384, 128, 128, False, True, True): (1, 128, 1, 4), + (4096, 4096, 16384, 128, 128, True, False, True): (1, 128, 1, 4), + (4096, 4096, 32768, 16, 16, False, True, True): (8, 128, 4, 4), + (4096, 4096, 32768, 16, 16, True, False, True): (5, 128, 4, 4), + (4096, 4096, 32768, 32, 32, False, True, True): (5, 128, 4, 4), + (4096, 4096, 32768, 32, 32, True, False, 
True): (3, 128, 4, 8), + (4096, 4096, 32768, 64, 64, False, True, True): (3, 256, 3, 4), + (4096, 4096, 32768, 64, 64, True, False, True): (2, 256, 3, 4), + (4096, 4096, 32768, 128, 128, False, True, True): (1, 256, 1, 4), + (4096, 4096, 32768, 128, 128, True, False, True): (1, 256, 1, 4), + (4096, 4096, 65536, 16, 16, False, True, True): (5, 256, 4, 4), + (4096, 4096, 65536, 16, 16, True, False, True): (5, 256, 4, 4), + (4096, 4096, 65536, 32, 32, False, True, True): (4, 256, 4, 8), + (4096, 4096, 65536, 32, 32, True, False, True): (4, 256, 4, 8), + (4096, 4096, 65536, 64, 64, False, True, True): (1, 512, 3, 4), + (4096, 4096, 65536, 64, 64, True, False, True): (3, 512, 3, 4), + (4096, 4096, 65536, 128, 128, False, True, True): (1, 512, 1, 4), + (4096, 4096, 65536, 128, 128, True, False, True): (1, 512, 1, 4), + (4096, 4096, 131072, 16, 16, False, True, True): (5, 512, 4, 4), + (4096, 4096, 131072, 16, 16, True, False, True): (5, 512, 4, 4), + (4096, 4096, 131072, 32, 32, False, True, True): (4, 512, 4, 4), + (4096, 4096, 131072, 32, 32, True, False, True): (2, 512, 3, 4), + (4096, 4096, 131072, 64, 64, False, True, True): (1, 1024, 3, 4), + (4096, 4096, 131072, 64, 64, True, False, True): (3, 1024, 3, 4), + (4096, 4096, 131072, 128, 128, False, True, True): (1, 1024, 1, 4), + (4096, 4096, 131072, 128, 128, True, False, True): (1, 1024, 1, 4), + (6144, 6144, 256, 16, 16, False, True, True): (1, 2, 1, 4), + (6144, 6144, 256, 16, 16, True, False, True): (3, 1, 4, 4), + (6144, 6144, 256, 32, 32, False, True, True): (3, 2, 1, 8), + (6144, 6144, 256, 32, 32, True, False, True): (1, 1, 4, 4), + (6144, 6144, 256, 64, 64, False, True, True): (4, 2, 3, 4), + (6144, 6144, 256, 64, 64, True, False, True): (3, 2, 4, 4), + (6144, 6144, 256, 128, 128, False, True, True): (2, 2, 3, 8), + (6144, 6144, 256, 128, 128, True, False, True): (1, 2, 3, 8), + (6144, 6144, 512, 16, 16, False, True, True): (4, 4, 1, 4), + (6144, 6144, 512, 16, 16, True, False, True): (3, 2, 3, 1), + (6144, 6144, 512, 32, 32, False, True, True): (1, 8, 1, 4), + (6144, 6144, 512, 32, 32, True, False, True): (1, 2, 3, 2), + (6144, 6144, 512, 64, 64, False, True, True): (2, 4, 3, 4), + (6144, 6144, 512, 64, 64, True, False, True): (2, 2, 3, 4), + (6144, 6144, 512, 128, 128, False, True, True): (1, 4, 3, 8), + (6144, 6144, 512, 128, 128, True, False, True): (1, 4, 3, 8), + (6144, 6144, 1024, 16, 16, False, True, True): (1, 8, 1, 2), + (6144, 6144, 1024, 16, 16, True, False, True): (4, 8, 4, 4), + (6144, 6144, 1024, 32, 32, False, True, True): (1, 8, 4, 2), + (6144, 6144, 1024, 32, 32, True, False, True): (1, 8, 4, 2), + (6144, 6144, 1024, 64, 64, False, True, True): (4, 8, 3, 4), + (6144, 6144, 1024, 64, 64, True, False, True): (1, 4, 3, 4), + (6144, 6144, 1024, 128, 128, False, True, True): (2, 8, 3, 8), + (6144, 6144, 1024, 128, 128, True, False, True): (1, 8, 3, 8), + (6144, 6144, 2048, 16, 16, False, True, True): (4, 4, 1, 4), + (6144, 6144, 2048, 16, 16, True, False, True): (2, 8, 4, 4), + (6144, 6144, 2048, 32, 32, False, True, True): (1, 16, 4, 2), + (6144, 6144, 2048, 32, 32, True, False, True): (4, 8, 4, 8), + (6144, 6144, 2048, 64, 64, False, True, True): (4, 16, 3, 4), + (6144, 6144, 2048, 64, 64, True, False, True): (2, 8, 3, 4), + (6144, 6144, 2048, 128, 128, False, True, True): (1, 16, 3, 8), + (6144, 6144, 2048, 128, 128, True, False, True): (4, 16, 3, 8), + (6144, 6144, 4096, 16, 16, False, True, True): (4, 8, 1, 4), + (6144, 6144, 4096, 16, 16, True, False, True): (4, 32, 4, 2), + (6144, 6144, 4096, 32, 32, False, 
True, True): (4, 16, 1, 2), + (6144, 6144, 4096, 32, 32, True, False, True): (2, 8, 3, 8), + (6144, 6144, 4096, 64, 64, False, True, True): (4, 32, 3, 4), + (6144, 6144, 4096, 64, 64, True, False, True): (4, 16, 3, 4), + (6144, 6144, 4096, 128, 128, False, True, True): (4, 32, 3, 8), + (6144, 6144, 4096, 128, 128, True, False, True): (4, 32, 3, 8), + (6144, 6144, 8192, 16, 16, False, True, True): (2, 16, 1, 2), + (6144, 6144, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (6144, 6144, 8192, 32, 32, False, True, True): (4, 32, 1, 2), + (6144, 6144, 8192, 32, 32, True, False, True): (4, 32, 4, 8), + (6144, 6144, 8192, 64, 64, False, True, True): (4, 64, 3, 4), + (6144, 6144, 8192, 64, 64, True, False, True): (4, 32, 3, 4), + (6144, 6144, 8192, 128, 128, False, True, True): (4, 64, 3, 8), + (6144, 6144, 8192, 128, 128, True, False, True): (4, 64, 3, 8), + (6144, 6144, 16384, 16, 16, False, True, True): (2, 32, 1, 2), + (6144, 6144, 16384, 16, 16, True, False, True): (4, 64, 4, 4), + (6144, 6144, 16384, 32, 32, False, True, True): (4, 64, 1, 2), + (6144, 6144, 16384, 32, 32, True, False, True): (4, 64, 3, 2), + (6144, 6144, 16384, 64, 64, False, True, True): (4, 128, 3, 4), + (6144, 6144, 16384, 64, 64, True, False, True): (2, 32, 3, 8), + (6144, 6144, 16384, 128, 128, False, True, True): (4, 128, 3, 8), + (6144, 6144, 16384, 128, 128, True, False, True): (4, 128, 3, 8), + (6144, 6144, 32768, 16, 16, False, True, True): (2, 64, 1, 2), + (6144, 6144, 32768, 16, 16, True, False, True): (3, 128, 4, 4), + (6144, 6144, 32768, 32, 32, False, True, True): (4, 128, 1, 2), + (6144, 6144, 32768, 32, 32, True, False, True): (3, 128, 3, 4), + (6144, 6144, 32768, 64, 64, False, True, True): (4, 256, 3, 4), + (6144, 6144, 32768, 64, 64, True, False, True): (2, 64, 3, 8), + (6144, 6144, 32768, 128, 128, False, True, True): (4, 256, 3, 8), + (6144, 6144, 32768, 128, 128, True, False, True): (4, 256, 3, 8), + (6144, 6144, 65536, 16, 16, False, True, True): (2, 128, 1, 2), + (6144, 6144, 65536, 16, 16, True, False, True): (4, 256, 4, 4), + (6144, 6144, 65536, 32, 32, False, True, True): (4, 256, 1, 2), + (6144, 6144, 65536, 32, 32, True, False, True): (4, 256, 3, 4), + (6144, 6144, 65536, 64, 64, False, True, True): (4, 512, 3, 4), + (6144, 6144, 65536, 64, 64, True, False, True): (2, 128, 3, 8), + (6144, 6144, 65536, 128, 128, False, True, True): (4, 512, 3, 8), + (6144, 6144, 65536, 128, 128, True, False, True): (4, 512, 3, 8), + (6144, 6144, 131072, 16, 16, False, True, True): (2, 256, 1, 2), + (6144, 6144, 131072, 16, 16, True, False, True): (5, 512, 4, 1), + (6144, 6144, 131072, 32, 32, False, True, True): (4, 512, 1, 2), + (6144, 6144, 131072, 32, 32, True, False, True): (4, 512, 3, 2), + (6144, 6144, 131072, 64, 64, False, True, True): (4, 1024, 3, 4), + (6144, 6144, 131072, 64, 64, True, False, True): (2, 256, 3, 8), + (6144, 6144, 131072, 128, 128, False, True, True): (4, 1024, 3, 8), + (6144, 6144, 131072, 128, 128, True, False, True): (4, 1024, 3, 8), + (8192, 8192, 256, 16, 16, False, True, True): (1, 1, 3, 4), + (8192, 8192, 256, 16, 16, True, False, True): (4, 1, 3, 4), + (8192, 8192, 256, 32, 32, False, True, True): (1, 2, 3, 4), + (8192, 8192, 256, 32, 32, True, False, True): (1, 2, 3, 4), + (8192, 8192, 256, 64, 64, False, True, True): (6, 2, 3, 8), + (8192, 8192, 256, 64, 64, True, False, True): (4, 2, 3, 8), + (8192, 8192, 256, 128, 128, False, True, True): (1, 2, 1, 4), + (8192, 8192, 256, 128, 128, True, False, True): (1, 2, 1, 4), + (8192, 8192, 512, 16, 16, False, True, True): (4, 4, 
3, 2), + (8192, 8192, 512, 16, 16, True, False, True): (4, 4, 3, 4), + (8192, 8192, 512, 32, 32, False, True, True): (1, 4, 3, 4), + (8192, 8192, 512, 32, 32, True, False, True): (3, 4, 3, 2), + (8192, 8192, 512, 64, 64, False, True, True): (1, 4, 3, 4), + (8192, 8192, 512, 64, 64, True, False, True): (1, 4, 3, 4), + (8192, 8192, 512, 128, 128, False, True, True): (4, 4, 2, 8), + (8192, 8192, 512, 128, 128, True, False, True): (4, 4, 2, 8), + (8192, 8192, 1024, 16, 16, False, True, True): (4, 8, 4, 4), + (8192, 8192, 1024, 16, 16, True, False, True): (2, 8, 4, 4), + (8192, 8192, 1024, 32, 32, False, True, True): (2, 4, 4, 8), + (8192, 8192, 1024, 32, 32, True, False, True): (1, 4, 3, 4), + (8192, 8192, 1024, 64, 64, False, True, True): (4, 8, 3, 4), + (8192, 8192, 1024, 64, 64, True, False, True): (2, 8, 3, 4), + (8192, 8192, 1024, 128, 128, False, True, True): (4, 8, 1, 4), + (8192, 8192, 1024, 128, 128, True, False, True): (4, 8, 1, 4), + (8192, 8192, 2048, 16, 16, False, True, True): (2, 8, 4, 4), + (8192, 8192, 2048, 16, 16, True, False, True): (2, 8, 4, 4), + (8192, 8192, 2048, 32, 32, False, True, True): (2, 8, 4, 8), + (8192, 8192, 2048, 32, 32, True, False, True): (2, 8, 4, 8), + (8192, 8192, 2048, 64, 64, False, True, True): (4, 8, 2, 4), + (8192, 8192, 2048, 64, 64, True, False, True): (4, 16, 3, 4), + (8192, 8192, 2048, 128, 128, False, True, True): (4, 16, 1, 4), + (8192, 8192, 2048, 128, 128, True, False, True): (4, 16, 1, 4), + (8192, 8192, 4096, 16, 16, False, True, True): (4, 16, 4, 4), + (8192, 8192, 4096, 16, 16, True, False, True): (4, 32, 4, 2), + (8192, 8192, 4096, 32, 32, False, True, True): (2, 16, 4, 8), + (8192, 8192, 4096, 32, 32, True, False, True): (2, 16, 4, 8), + (8192, 8192, 4096, 64, 64, False, True, True): (4, 32, 3, 4), + (8192, 8192, 4096, 64, 64, True, False, True): (4, 16, 2, 4), + (8192, 8192, 4096, 128, 128, False, True, True): (4, 32, 1, 4), + (8192, 8192, 4096, 128, 128, True, False, True): (4, 32, 1, 4), + (8192, 8192, 8192, 16, 16, False, True, True): (4, 64, 4, 2), + (8192, 8192, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (8192, 8192, 8192, 32, 32, False, True, True): (2, 32, 4, 8), + (8192, 8192, 8192, 32, 32, True, False, True): (2, 32, 4, 8), + (8192, 8192, 8192, 64, 64, False, True, True): (4, 32, 3, 8), + (8192, 8192, 8192, 64, 64, True, False, True): (4, 32, 2, 4), + (8192, 8192, 8192, 128, 128, False, True, True): (4, 64, 1, 4), + (8192, 8192, 8192, 128, 128, True, False, True): (4, 64, 1, 4), + (8192, 8192, 16384, 16, 16, False, True, True): (4, 64, 4, 4), + (8192, 8192, 16384, 16, 16, True, False, True): (4, 64, 4, 4), + (8192, 8192, 16384, 32, 32, False, True, True): (4, 64, 3, 4), + (8192, 8192, 16384, 32, 32, True, False, True): (4, 64, 4, 8), + (8192, 8192, 16384, 64, 64, False, True, True): (4, 64, 2, 4), + (8192, 8192, 16384, 64, 64, True, False, True): (4, 64, 2, 4), + (8192, 8192, 16384, 128, 128, False, True, True): (4, 128, 1, 4), + (8192, 8192, 16384, 128, 128, True, False, True): (4, 128, 1, 4), + (8192, 8192, 32768, 16, 16, False, True, True): (3, 128, 4, 4), + (8192, 8192, 32768, 16, 16, True, False, True): (3, 128, 4, 4), + (8192, 8192, 32768, 32, 32, False, True, True): (2, 128, 4, 8), + (8192, 8192, 32768, 32, 32, True, False, True): (2, 128, 4, 8), + (8192, 8192, 32768, 64, 64, False, True, True): (2, 128, 2, 4), + (8192, 8192, 32768, 64, 64, True, False, True): (2, 128, 2, 4), + (8192, 8192, 32768, 128, 128, False, True, True): (4, 256, 1, 4), + (8192, 8192, 32768, 128, 128, True, False, True): (4, 256, 1, 4), 
+ (8192, 8192, 65536, 16, 16, False, True, True): (3, 256, 4, 4), + (8192, 8192, 65536, 16, 16, True, False, True): (3, 256, 4, 4), + (8192, 8192, 65536, 32, 32, False, True, True): (2, 256, 3, 4), + (8192, 8192, 65536, 32, 32, True, False, True): (2, 256, 3, 4), + (8192, 8192, 65536, 64, 64, False, True, True): (2, 256, 2, 4), + (8192, 8192, 65536, 64, 64, True, False, True): (2, 256, 3, 8), + (8192, 8192, 65536, 128, 128, False, True, True): (4, 512, 1, 4), + (8192, 8192, 65536, 128, 128, True, False, True): (4, 512, 1, 4), + (8192, 8192, 131072, 16, 16, False, True, True): (3, 512, 4, 4), + (8192, 8192, 131072, 16, 16, True, False, True): (3, 512, 4, 4), + (8192, 8192, 131072, 32, 32, False, True, True): (2, 512, 4, 4), + (8192, 8192, 131072, 32, 32, True, False, True): (2, 512, 3, 4), + (8192, 8192, 131072, 64, 64, False, True, True): (4, 512, 2, 4), + (8192, 8192, 131072, 64, 64, True, False, True): (2, 512, 2, 4), + (8192, 8192, 131072, 128, 128, False, True, True): (4, 1024, 1, 4), + (8192, 8192, 131072, 128, 128, True, False, True): (4, 1024, 1, 4), + (16384, 16384, 256, 16, 16, False, True, True): (2, 2, 6, 4), + (16384, 16384, 256, 16, 16, True, False, True): (2, 2, 6, 4), + (16384, 16384, 256, 32, 32, False, True, True): (4, 2, 3, 2), + (16384, 16384, 256, 32, 32, True, False, True): (4, 2, 3, 2), + (16384, 16384, 256, 64, 64, False, True, True): (2, 2, 4, 4), + (16384, 16384, 256, 64, 64, True, False, True): (4, 2, 3, 8), + (16384, 16384, 256, 128, 128, False, True, True): (4, 2, 2, 8), + (16384, 16384, 256, 128, 128, True, False, True): (4, 2, 2, 8), + (16384, 16384, 512, 16, 16, False, True, True): (1, 2, 4, 4), + (16384, 16384, 512, 16, 16, True, False, True): (1, 2, 4, 4), + (16384, 16384, 512, 32, 32, False, True, True): (2, 2, 4, 8), + (16384, 16384, 512, 32, 32, True, False, True): (2, 2, 4, 8), + (16384, 16384, 512, 64, 64, False, True, True): (4, 4, 3, 4), + (16384, 16384, 512, 64, 64, True, False, True): (4, 4, 3, 4), + (16384, 16384, 512, 128, 128, False, True, True): (4, 4, 2, 8), + (16384, 16384, 512, 128, 128, True, False, True): (4, 4, 2, 8), + (16384, 16384, 1024, 16, 16, False, True, True): (3, 4, 4, 4), + (16384, 16384, 1024, 16, 16, True, False, True): (2, 8, 4, 4), + (16384, 16384, 1024, 32, 32, False, True, True): (2, 4, 4, 8), + (16384, 16384, 1024, 32, 32, True, False, True): (1, 4, 4, 8), + (16384, 16384, 1024, 64, 64, False, True, True): (2, 8, 3, 4), + (16384, 16384, 1024, 64, 64, True, False, True): (2, 8, 3, 4), + (16384, 16384, 1024, 128, 128, False, True, True): (4, 8, 1, 4), + (16384, 16384, 1024, 128, 128, True, False, True): (4, 8, 1, 4), + (16384, 16384, 2048, 16, 16, False, True, True): (2, 8, 4, 4), + (16384, 16384, 2048, 16, 16, True, False, True): (2, 8, 4, 4), + (16384, 16384, 2048, 32, 32, False, True, True): (1, 8, 4, 8), + (16384, 16384, 2048, 32, 32, True, False, True): (2, 8, 4, 8), + (16384, 16384, 2048, 64, 64, False, True, True): (2, 8, 2, 4), + (16384, 16384, 2048, 64, 64, True, False, True): (2, 8, 2, 4), + (16384, 16384, 2048, 128, 128, False, True, True): (4, 16, 1, 4), + (16384, 16384, 2048, 128, 128, True, False, True): (4, 16, 1, 4), + (16384, 16384, 4096, 16, 16, False, True, True): (2, 16, 4, 4), + (16384, 16384, 4096, 16, 16, True, False, True): (2, 16, 4, 4), + (16384, 16384, 4096, 32, 32, False, True, True): (1, 8, 3, 8), + (16384, 16384, 4096, 32, 32, True, False, True): (2, 16, 3, 4), + (16384, 16384, 4096, 64, 64, False, True, True): (2, 16, 2, 4), + (16384, 16384, 4096, 64, 64, True, False, True): (2, 16, 2, 4), + 
(16384, 16384, 4096, 128, 128, False, True, True): (4, 32, 1, 4), + (16384, 16384, 4096, 128, 128, True, False, True): (4, 32, 1, 4), + (16384, 16384, 8192, 16, 16, False, True, True): (4, 64, 4, 2), + (16384, 16384, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (16384, 16384, 8192, 32, 32, False, True, True): (2, 32, 4, 8), + (16384, 16384, 8192, 32, 32, True, False, True): (2, 32, 3, 4), + (16384, 16384, 8192, 64, 64, False, True, True): (2, 32, 4, 8), + (16384, 16384, 8192, 64, 64, True, False, True): (2, 32, 3, 8), + (16384, 16384, 8192, 128, 128, False, True, True): (4, 64, 1, 4), + (16384, 16384, 8192, 128, 128, True, False, True): (4, 64, 1, 4), + (16384, 16384, 16384, 16, 16, False, True, True): (1, 64, 4, 4), + (16384, 16384, 16384, 16, 16, True, False, True): (1, 64, 4, 4), + (16384, 16384, 16384, 32, 32, False, True, True): (1, 64, 3, 8), + (16384, 16384, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (16384, 16384, 16384, 64, 64, False, True, True): (1, 64, 2, 4), + (16384, 16384, 16384, 64, 64, True, False, True): (1, 64, 4, 8), + (16384, 16384, 16384, 128, 128, False, True, True): (4, 128, 1, 4), + (16384, 16384, 16384, 128, 128, True, False, True): (4, 128, 1, 4), + (16384, 16384, 32768, 16, 16, False, True, True): (1, 128, 4, 4), + (16384, 16384, 32768, 16, 16, True, False, True): (1, 128, 4, 4), + (16384, 16384, 32768, 32, 32, False, True, True): (1, 128, 4, 2), + (16384, 16384, 32768, 32, 32, True, False, True): (1, 128, 3, 8), + (16384, 16384, 32768, 64, 64, False, True, True): (2, 128, 2, 4), + (16384, 16384, 32768, 64, 64, True, False, True): (1, 128, 3, 8), + (16384, 16384, 32768, 128, 128, False, True, True): (4, 256, 1, 4), + (16384, 16384, 32768, 128, 128, True, False, True): (4, 256, 1, 4), + (16384, 16384, 65536, 16, 16, False, True, True): (1, 256, 4, 4), + (16384, 16384, 65536, 16, 16, True, False, True): (1, 256, 4, 4), + (16384, 16384, 65536, 32, 32, False, True, True): (1, 256, 3, 4), + (16384, 16384, 65536, 32, 32, True, False, True): (1, 256, 3, 4), + (16384, 16384, 65536, 64, 64, False, True, True): (1, 256, 2, 4), + (16384, 16384, 65536, 64, 64, True, False, True): (2, 256, 2, 4), + (16384, 16384, 65536, 128, 128, False, True, True): (4, 512, 1, 4), + (16384, 16384, 65536, 128, 128, True, False, True): (4, 512, 1, 4), + (16384, 16384, 131072, 16, 16, False, True, True): (2, 512, 4, 4), + (16384, 16384, 131072, 16, 16, True, False, True): (1, 512, 4, 4), + (16384, 16384, 131072, 32, 32, False, True, True): (1, 512, 4, 8), + (16384, 16384, 131072, 32, 32, True, False, True): (1, 512, 3, 4), + (16384, 16384, 131072, 64, 64, False, True, True): (2, 512, 2, 4), + (16384, 16384, 131072, 64, 64, True, False, True): (1, 512, 2, 4), + (16384, 16384, 131072, 128, 128, False, True, True): (4, 1024, 1, 4), + (16384, 16384, 131072, 128, 128, True, False, True): (4, 1024, 1, 4), + }, + ("bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.bfloat16, 0.56)): { + (192, 192, 256, 64, 64, False, True, True): (3, 4, 3, 4), + (192, 192, 256, 64, 64, True, False, True): (1, 4, 4, 4), + (192, 192, 512, 64, 64, False, True, True): (2, 8, 3, 4), + (192, 192, 512, 64, 64, True, False, True): (2, 8, 3, 4), + (192, 192, 1024, 64, 64, False, True, True): (1, 16, 3, 4), + (192, 192, 1024, 64, 64, True, False, True): (1, 16, 5, 4), + (192, 192, 2048, 64, 64, False, True, True): (3, 32, 3, 4), + (192, 192, 2048, 64, 64, True, False, True): (5, 32, 3, 4), + (192, 192, 4096, 64, 64, False, True, True): (1, 64, 4, 4), + (192, 192, 4096, 64, 64, True, False, True): (2, 32, 3, 
4), + (192, 192, 8192, 64, 64, False, True, True): (1, 128, 2, 4), + (192, 192, 8192, 64, 64, True, False, True): (1, 64, 3, 4), + (192, 192, 16384, 64, 64, False, True, True): (1, 256, 1, 4), + (192, 192, 16384, 64, 64, True, False, True): (1, 64, 3, 4), + (192, 192, 32768, 64, 64, False, True, True): (2, 512, 1, 2), + (192, 192, 32768, 64, 64, True, False, True): (2, 256, 2, 4), + (192, 192, 65536, 64, 64, False, True, True): (3, 512, 1, 4), + (192, 192, 65536, 64, 64, True, False, True): (1, 512, 2, 4), + (192, 192, 131072, 64, 64, False, True, True): (5, 1024, 1, 4), + (192, 192, 131072, 64, 64, True, False, True): (4, 512, 2, 4), + (384, 384, 256, 128, 128, False, True, True): (3, 2, 3, 8), + (384, 384, 256, 128, 128, True, False, True): (1, 2, 3, 8), + (384, 384, 512, 128, 128, False, True, True): (4, 4, 3, 8), + (384, 384, 512, 128, 128, True, False, True): (3, 4, 3, 8), + (384, 384, 1024, 128, 128, False, True, True): (1, 8, 3, 8), + (384, 384, 1024, 128, 128, True, False, True): (2, 8, 3, 8), + (384, 384, 2048, 128, 128, False, True, True): (5, 16, 3, 8), + (384, 384, 2048, 128, 128, True, False, True): (5, 16, 3, 8), + (384, 384, 4096, 128, 128, False, True, True): (3, 32, 3, 8), + (384, 384, 4096, 128, 128, True, False, True): (6, 32, 3, 8), + (384, 384, 8192, 128, 128, False, True, True): (2, 64, 3, 8), + (384, 384, 8192, 128, 128, True, False, True): (4, 32, 2, 8), + (384, 384, 16384, 128, 128, False, True, True): (2, 128, 3, 8), + (384, 384, 16384, 128, 128, True, False, True): (5, 128, 2, 4), + (384, 384, 32768, 128, 128, False, True, True): (2, 256, 3, 8), + (384, 384, 32768, 128, 128, True, False, True): (3, 256, 2, 4), + (384, 384, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (384, 384, 65536, 128, 128, True, False, True): (1, 512, 2, 4), + (384, 384, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (384, 384, 131072, 128, 128, True, False, True): (1, 1024, 2, 4), + }, + ("bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.float16, 0.5)): { + (16, 16, 16, 16, 16, False, False, False): (1, 1, 1, 1), + (16, 16, 16, 16, 16, False, False, True): (1, 1, 2, 2), + (16, 16, 16, 16, 16, False, True, False): (1, 1, 1, 1), + (16, 16, 16, 16, 16, False, True, True): (1, 1, 1, 8), + (16, 16, 16, 16, 16, True, False, False): (3, 1, 3, 4), + (16, 16, 16, 16, 16, True, False, True): (1, 1, 2, 1), + (16, 16, 32, 16, 16, False, False, False): (1, 2, 1, 8), + (16, 16, 32, 16, 16, False, False, True): (1, 2, 1, 2), + (16, 16, 32, 16, 16, False, True, False): (2, 1, 1, 4), + (16, 16, 32, 16, 16, False, True, True): (1, 2, 1, 4), + (16, 16, 32, 16, 16, True, False, False): (1, 1, 1, 4), + (16, 16, 32, 16, 16, True, False, True): (1, 2, 1, 2), + (16, 16, 64, 16, 16, False, False, False): (1, 4, 1, 1), + (16, 16, 64, 16, 16, False, False, True): (1, 2, 2, 4), + (16, 16, 64, 16, 16, False, True, False): (1, 4, 1, 4), + (16, 16, 64, 16, 16, False, True, True): (1, 2, 1, 4), + (16, 16, 64, 16, 16, True, False, False): (1, 4, 1, 2), + (16, 16, 64, 16, 16, True, False, True): (1, 1, 1, 2), + (16, 32, 16, 16, 16, False, False, False): (1, 1, 2, 4), + (16, 32, 16, 16, 16, False, False, True): (1, 1, 1, 4), + (16, 32, 16, 16, 16, False, True, False): (1, 1, 1, 2), + (16, 32, 16, 16, 16, False, True, True): (1, 1, 1, 2), + (16, 32, 16, 16, 16, True, False, False): (1, 1, 2, 16), + (16, 32, 16, 16, 16, True, False, True): (1, 1, 1, 4), + (16, 32, 16, 16, 32, False, False, False): (2, 1, 1, 8), + (16, 32, 16, 16, 32, False, False, True): (2, 1, 1, 8), + (16, 32, 16, 16, 32, False, True, 
False): (1, 1, 2, 1), + (16, 32, 16, 16, 32, False, True, True): (1, 1, 1, 4), + (16, 32, 16, 16, 32, True, False, False): (2, 1, 1, 8), + (16, 32, 16, 16, 32, True, False, True): (1, 1, 2, 4), + (16, 32, 32, 16, 16, False, False, False): (1, 1, 1, 16), + (16, 32, 32, 16, 16, False, False, True): (1, 2, 1, 2), + (16, 32, 32, 16, 16, False, True, False): (1, 2, 1, 8), + (16, 32, 32, 16, 16, False, True, True): (3, 2, 1, 4), + (16, 32, 32, 16, 16, True, False, False): (1, 2, 1, 4), + (16, 32, 32, 16, 16, True, False, True): (1, 2, 1, 2), + (16, 32, 32, 16, 32, False, False, False): (1, 2, 1, 2), + (16, 32, 32, 16, 32, False, False, True): (1, 1, 1, 4), + (16, 32, 32, 16, 32, False, True, False): (1, 1, 2, 4), + (16, 32, 32, 16, 32, False, True, True): (1, 2, 1, 2), + (16, 32, 32, 16, 32, True, False, False): (1, 2, 1, 2), + (16, 32, 32, 16, 32, True, False, True): (1, 2, 1, 16), + (16, 32, 64, 16, 16, False, False, False): (1, 4, 1, 4), + (16, 32, 64, 16, 16, False, False, True): (2, 4, 1, 4), + (16, 32, 64, 16, 16, False, True, False): (1, 4, 1, 4), + (16, 32, 64, 16, 16, False, True, True): (1, 4, 1, 4), + (16, 32, 64, 16, 16, True, False, False): (3, 4, 1, 2), + (16, 32, 64, 16, 16, True, False, True): (1, 4, 1, 1), + (16, 32, 64, 16, 32, False, False, False): (1, 4, 1, 16), + (16, 32, 64, 16, 32, False, False, True): (1, 2, 1, 2), + (16, 32, 64, 16, 32, False, True, False): (1, 4, 2, 2), + (16, 32, 64, 16, 32, False, True, True): (1, 4, 1, 8), + (16, 32, 64, 16, 32, True, False, False): (1, 4, 1, 8), + (16, 32, 64, 16, 32, True, False, True): (1, 2, 1, 4), + (16, 64, 16, 16, 32, False, False, False): (1, 1, 1, 2), + (16, 64, 16, 16, 32, False, False, True): (1, 1, 1, 4), + (16, 64, 16, 16, 32, False, True, False): (2, 1, 2, 4), + (16, 64, 16, 16, 32, False, True, True): (1, 1, 1, 4), + (16, 64, 16, 16, 32, True, False, False): (1, 1, 1, 4), + (16, 64, 16, 16, 32, True, False, True): (1, 1, 1, 4), + (16, 64, 32, 16, 32, False, False, False): (1, 2, 1, 2), + (16, 64, 32, 16, 32, False, False, True): (1, 1, 1, 4), + (16, 64, 32, 16, 32, False, True, False): (1, 1, 1, 4), + (16, 64, 32, 16, 32, False, True, True): (1, 2, 3, 2), + (16, 64, 32, 16, 32, True, False, False): (1, 1, 1, 4), + (16, 64, 32, 16, 32, True, False, True): (1, 1, 2, 4), + (16, 64, 64, 16, 32, False, False, False): (1, 4, 1, 8), + (16, 64, 64, 16, 32, False, False, True): (1, 4, 1, 4), + (16, 64, 64, 16, 32, False, True, False): (1, 4, 1, 1), + (16, 64, 64, 16, 32, False, True, True): (2, 4, 1, 4), + (16, 64, 64, 16, 32, True, False, False): (1, 4, 1, 4), + (16, 64, 64, 16, 32, True, False, True): (1, 4, 1, 4), + (32, 16, 16, 16, 16, False, False, False): (2, 1, 2, 4), + (32, 16, 16, 16, 16, False, False, True): (2, 1, 1, 2), + (32, 16, 16, 16, 16, False, True, False): (1, 1, 2, 4), + (32, 16, 16, 16, 16, False, True, True): (1, 1, 1, 2), + (32, 16, 16, 16, 16, True, False, False): (1, 1, 1, 4), + (32, 16, 16, 16, 16, True, False, True): (2, 1, 1, 2), + (32, 16, 32, 16, 16, False, False, False): (1, 1, 1, 4), + (32, 16, 32, 16, 16, False, False, True): (1, 1, 1, 4), + (32, 16, 32, 16, 16, False, True, False): (1, 2, 1, 4), + (32, 16, 32, 16, 16, False, True, True): (2, 2, 1, 4), + (32, 16, 32, 16, 16, True, False, False): (2, 1, 1, 4), + (32, 16, 32, 16, 16, True, False, True): (2, 2, 1, 2), + (32, 16, 64, 16, 16, False, False, False): (1, 4, 1, 2), + (32, 16, 64, 16, 16, False, False, True): (1, 4, 1, 4), + (32, 16, 64, 16, 16, False, True, False): (1, 2, 1, 4), + (32, 16, 64, 16, 16, False, True, True): (1, 4, 1, 2), + 
(32, 16, 64, 16, 16, True, False, False): (1, 4, 2, 8), + (32, 16, 64, 16, 16, True, False, True): (1, 4, 1, 1), + (32, 32, 16, 16, 16, False, False, False): (1, 1, 1, 4), + (32, 32, 16, 16, 16, False, False, True): (2, 1, 1, 4), + (32, 32, 16, 16, 16, False, True, False): (1, 1, 2, 4), + (32, 32, 16, 16, 16, False, True, True): (1, 1, 2, 2), + (32, 32, 16, 16, 16, True, False, False): (1, 1, 1, 8), + (32, 32, 16, 16, 16, True, False, True): (1, 1, 1, 4), + (32, 32, 16, 16, 32, False, False, False): (1, 1, 3, 2), + (32, 32, 16, 16, 32, False, False, True): (2, 1, 1, 4), + (32, 32, 16, 16, 32, False, True, False): (3, 1, 1, 4), + (32, 32, 16, 16, 32, False, True, True): (1, 1, 1, 4), + (32, 32, 16, 16, 32, True, False, False): (2, 1, 1, 8), + (32, 32, 16, 16, 32, True, False, True): (1, 1, 3, 2), + (32, 32, 16, 32, 32, False, False, False): (1, 1, 1, 2), + (32, 32, 16, 32, 32, False, False, True): (2, 1, 1, 8), + (32, 32, 16, 32, 32, False, True, False): (1, 1, 1, 2), + (32, 32, 16, 32, 32, False, True, True): (1, 1, 1, 8), + (32, 32, 16, 32, 32, True, False, False): (1, 1, 2, 4), + (32, 32, 16, 32, 32, True, False, True): (1, 1, 1, 2), + (32, 32, 32, 16, 16, False, False, False): (1, 1, 1, 4), + (32, 32, 32, 16, 16, False, False, True): (1, 2, 1, 4), + (32, 32, 32, 16, 16, False, True, False): (1, 2, 1, 4), + (32, 32, 32, 16, 16, False, True, True): (1, 2, 1, 2), + (32, 32, 32, 16, 16, True, False, False): (1, 2, 1, 4), + (32, 32, 32, 16, 16, True, False, True): (1, 2, 1, 4), + (32, 32, 32, 16, 32, False, False, False): (1, 2, 1, 4), + (32, 32, 32, 16, 32, False, False, True): (1, 2, 1, 2), + (32, 32, 32, 16, 32, False, True, False): (1, 2, 1, 4), + (32, 32, 32, 16, 32, False, True, True): (1, 2, 1, 2), + (32, 32, 32, 16, 32, True, False, False): (1, 2, 1, 1), + (32, 32, 32, 16, 32, True, False, True): (1, 2, 1, 2), + (32, 32, 32, 32, 32, False, False, False): (1, 1, 1, 4), + (32, 32, 32, 32, 32, False, False, True): (2, 1, 1, 4), + (32, 32, 32, 32, 32, False, True, False): (1, 1, 1, 8), + (32, 32, 32, 32, 32, False, True, True): (1, 1, 1, 8), + (32, 32, 32, 32, 32, True, False, False): (1, 1, 3, 4), + (32, 32, 32, 32, 32, True, False, True): (1, 1, 1, 8), + (32, 32, 64, 16, 16, False, False, False): (1, 4, 1, 4), + (32, 32, 64, 16, 16, False, False, True): (1, 4, 1, 2), + (32, 32, 64, 16, 16, False, True, False): (1, 1, 1, 4), + (32, 32, 64, 16, 16, False, True, True): (1, 4, 1, 4), + (32, 32, 64, 16, 16, True, False, False): (1, 4, 1, 8), + (32, 32, 64, 16, 16, True, False, True): (1, 4, 1, 2), + (32, 32, 64, 16, 32, False, False, False): (1, 1, 1, 4), + (32, 32, 64, 16, 32, False, False, True): (1, 4, 1, 4), + (32, 32, 64, 16, 32, False, True, False): (1, 1, 1, 4), + (32, 32, 64, 16, 32, False, True, True): (1, 4, 1, 4), + (32, 32, 64, 16, 32, True, False, False): (2, 2, 1, 8), + (32, 32, 64, 16, 32, True, False, True): (1, 2, 1, 2), + (32, 32, 64, 32, 32, False, False, False): (1, 2, 1, 4), + (32, 32, 64, 32, 32, False, False, True): (1, 2, 1, 1), + (32, 32, 64, 32, 32, False, True, False): (1, 2, 2, 8), + (32, 32, 64, 32, 32, False, True, True): (1, 1, 1, 4), + (32, 32, 64, 32, 32, True, False, False): (1, 2, 1, 4), + (32, 32, 64, 32, 32, True, False, True): (2, 2, 1, 4), + (32, 64, 16, 16, 32, False, False, False): (1, 1, 1, 8), + (32, 64, 16, 16, 32, False, False, True): (1, 1, 1, 4), + (32, 64, 16, 16, 32, False, True, False): (2, 1, 1, 4), + (32, 64, 16, 16, 32, False, True, True): (1, 1, 1, 4), + (32, 64, 16, 16, 32, True, False, False): (1, 1, 2, 4), + (32, 64, 16, 16, 32, True, 
False, True): (1, 1, 2, 2), + (32, 64, 16, 32, 32, False, False, False): (1, 1, 1, 8), + (32, 64, 16, 32, 32, False, False, True): (2, 1, 1, 4), + (32, 64, 16, 32, 32, False, True, False): (1, 1, 1, 4), + (32, 64, 16, 32, 32, False, True, True): (1, 1, 2, 2), + (32, 64, 16, 32, 32, True, False, False): (1, 1, 1, 2), + (32, 64, 16, 32, 32, True, False, True): (2, 1, 2, 4), + (32, 64, 32, 16, 32, False, False, False): (1, 1, 1, 4), + (32, 64, 32, 16, 32, False, False, True): (1, 2, 1, 2), + (32, 64, 32, 16, 32, False, True, False): (1, 2, 3, 4), + (32, 64, 32, 16, 32, False, True, True): (2, 2, 1, 4), + (32, 64, 32, 16, 32, True, False, False): (1, 1, 1, 4), + (32, 64, 32, 16, 32, True, False, True): (1, 2, 2, 1), + (32, 64, 32, 32, 32, False, False, False): (1, 1, 1, 8), + (32, 64, 32, 32, 32, False, False, True): (1, 1, 1, 4), + (32, 64, 32, 32, 32, False, True, False): (1, 1, 2, 4), + (32, 64, 32, 32, 32, False, True, True): (1, 1, 1, 4), + (32, 64, 32, 32, 32, True, False, False): (2, 1, 1, 2), + (32, 64, 32, 32, 32, True, False, True): (1, 1, 1, 4), + (32, 64, 64, 16, 32, False, False, False): (1, 4, 2, 1), + (32, 64, 64, 16, 32, False, False, True): (3, 4, 1, 4), + (32, 64, 64, 16, 32, False, True, False): (1, 1, 1, 8), + (32, 64, 64, 16, 32, False, True, True): (1, 4, 1, 4), + (32, 64, 64, 16, 32, True, False, False): (1, 4, 1, 4), + (32, 64, 64, 16, 32, True, False, True): (2, 2, 3, 4), + (32, 64, 64, 32, 32, False, False, False): (1, 2, 1, 4), + (32, 64, 64, 32, 32, False, False, True): (1, 2, 1, 4), + (32, 64, 64, 32, 32, False, True, False): (1, 2, 2, 8), + (32, 64, 64, 32, 32, False, True, True): (1, 2, 1, 4), + (32, 64, 64, 32, 32, True, False, False): (1, 2, 2, 4), + (32, 64, 64, 32, 32, True, False, True): (1, 2, 1, 4), + (64, 32, 16, 32, 32, False, False, False): (1, 1, 1, 1), + (64, 32, 16, 32, 32, False, False, True): (1, 1, 2, 4), + (64, 32, 16, 32, 32, False, True, False): (2, 1, 1, 8), + (64, 32, 16, 32, 32, False, True, True): (1, 1, 1, 4), + (64, 32, 16, 32, 32, True, False, False): (2, 1, 1, 2), + (64, 32, 16, 32, 32, True, False, True): (1, 1, 1, 4), + (64, 32, 32, 32, 32, False, False, False): (3, 1, 1, 4), + (64, 32, 32, 32, 32, False, False, True): (1, 1, 1, 4), + (64, 32, 32, 32, 32, False, True, False): (1, 1, 1, 8), + (64, 32, 32, 32, 32, False, True, True): (1, 1, 1, 2), + (64, 32, 32, 32, 32, True, False, False): (1, 1, 1, 2), + (64, 32, 32, 32, 32, True, False, True): (1, 1, 1, 4), + (64, 32, 64, 32, 32, False, False, False): (1, 2, 1, 2), + (64, 32, 64, 32, 32, False, False, True): (3, 2, 1, 4), + (64, 32, 64, 32, 32, False, True, False): (1, 1, 1, 1), + (64, 32, 64, 32, 32, False, True, True): (1, 2, 1, 4), + (64, 32, 64, 32, 32, True, False, False): (1, 1, 3, 4), + (64, 32, 64, 32, 32, True, False, True): (1, 2, 2, 4), + (64, 64, 16, 32, 32, False, False, False): (1, 1, 2, 2), + (64, 64, 16, 32, 32, False, False, True): (1, 1, 3, 2), + (64, 64, 16, 32, 32, False, True, False): (1, 1, 1, 8), + (64, 64, 16, 32, 32, False, True, True): (1, 1, 2, 4), + (64, 64, 16, 32, 32, True, False, False): (1, 1, 2, 4), + (64, 64, 16, 32, 32, True, False, True): (2, 1, 2, 4), + (64, 64, 32, 32, 32, False, False, False): (1, 1, 2, 8), + (64, 64, 32, 32, 32, False, False, True): (1, 1, 2, 4), + (64, 64, 32, 32, 32, False, True, False): (1, 1, 1, 4), + (64, 64, 32, 32, 32, False, True, True): (1, 1, 1, 4), + (64, 64, 32, 32, 32, True, False, False): (1, 1, 1, 4), + (64, 64, 32, 32, 32, True, False, True): (2, 1, 2, 4), + (64, 64, 64, 32, 32, False, False, False): (1, 2, 1, 
4), + (64, 64, 64, 32, 32, False, False, True): (1, 2, 1, 4), + (64, 64, 64, 32, 32, False, True, False): (1, 2, 1, 4), + (64, 64, 64, 32, 32, False, True, True): (3, 2, 1, 4), + (64, 64, 64, 32, 32, True, False, False): (1, 2, 1, 8), + (64, 64, 64, 32, 32, True, False, True): (1, 2, 3, 4), + (192, 192, 256, 16, 16, False, True, True): (1, 8, 4, 2), + (192, 192, 256, 16, 16, True, False, True): (1, 4, 4, 4), + (192, 192, 256, 32, 32, False, True, True): (2, 8, 5, 4), + (192, 192, 256, 32, 32, True, False, True): (2, 8, 5, 1), + (192, 192, 512, 16, 16, False, True, True): (3, 8, 4, 4), + (192, 192, 512, 16, 16, True, False, True): (5, 8, 5, 4), + (192, 192, 512, 32, 32, False, True, True): (1, 16, 5, 4), + (192, 192, 512, 32, 32, True, False, True): (1, 8, 6, 2), + (192, 192, 1024, 16, 16, False, True, True): (1, 16, 4, 4), + (192, 192, 1024, 16, 16, True, False, True): (3, 16, 5, 2), + (192, 192, 1024, 32, 32, False, True, True): (3, 16, 4, 4), + (192, 192, 1024, 32, 32, True, False, True): (1, 16, 5, 4), + (192, 192, 2048, 16, 16, False, True, True): (2, 16, 3, 4), + (192, 192, 2048, 16, 16, True, False, True): (1, 16, 4, 4), + (192, 192, 2048, 32, 32, False, True, True): (1, 32, 3, 4), + (192, 192, 2048, 32, 32, True, False, True): (3, 16, 4, 4), + (192, 192, 4096, 16, 16, False, True, True): (1, 64, 1, 4), + (192, 192, 4096, 16, 16, True, False, True): (1, 16, 3, 4), + (192, 192, 4096, 32, 32, False, True, True): (1, 128, 1, 4), + (192, 192, 4096, 32, 32, True, False, True): (2, 32, 4, 2), + (192, 192, 8192, 16, 16, False, True, True): (1, 64, 1, 4), + (192, 192, 8192, 16, 16, True, False, True): (2, 64, 3, 2), + (192, 192, 8192, 32, 32, False, True, True): (1, 128, 1, 4), + (192, 192, 8192, 32, 32, True, False, True): (4, 32, 3, 4), + (192, 192, 16384, 16, 16, False, True, True): (1, 128, 1, 4), + (192, 192, 16384, 16, 16, True, False, True): (1, 64, 3, 2), + (192, 192, 16384, 32, 32, False, True, True): (1, 128, 1, 4), + (192, 192, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (192, 192, 32768, 16, 16, False, True, True): (2, 256, 1, 2), + (192, 192, 32768, 16, 16, True, False, True): (1, 128, 3, 2), + (192, 192, 32768, 32, 32, False, True, True): (2, 256, 1, 4), + (192, 192, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (192, 192, 65536, 16, 16, False, True, True): (2, 512, 1, 2), + (192, 192, 65536, 16, 16, True, False, True): (1, 256, 3, 2), + (192, 192, 65536, 32, 32, False, True, True): (2, 512, 1, 4), + (192, 192, 65536, 32, 32, True, False, True): (2, 256, 3, 4), + (192, 192, 131072, 16, 16, False, True, True): (4, 1024, 1, 2), + (192, 192, 131072, 16, 16, True, False, True): (3, 512, 3, 2), + (192, 192, 131072, 32, 32, False, True, True): (1, 1024, 1, 4), + (192, 192, 131072, 32, 32, True, False, True): (3, 512, 3, 4), + (256, 256, 256, 16, 16, False, True, True): (4, 8, 6, 2), + (256, 256, 256, 16, 16, True, False, True): (5, 16, 5, 1), + (256, 256, 256, 32, 32, False, True, True): (1, 8, 7, 4), + (256, 256, 256, 32, 32, True, False, True): (1, 8, 5, 4), + (256, 256, 256, 64, 64, False, True, True): (1, 4, 5, 4), + (256, 256, 256, 64, 64, True, False, True): (2, 4, 3, 4), + (256, 256, 256, 128, 128, False, True, True): (1, 2, 2, 8), + (256, 256, 256, 128, 128, True, False, True): (1, 2, 2, 8), + (256, 256, 512, 16, 16, False, True, True): (4, 8, 4, 4), + (256, 256, 512, 16, 16, True, False, True): (4, 8, 6, 2), + (256, 256, 512, 32, 32, False, True, True): (3, 8, 5, 4), + (256, 256, 512, 32, 32, True, False, True): (2, 8, 5, 4), + (256, 256, 512, 64, 64, False, 
True, True): (2, 8, 4, 4), + (256, 256, 512, 64, 64, True, False, True): (1, 8, 7, 4), + (256, 256, 512, 128, 128, False, True, True): (2, 4, 2, 8), + (256, 256, 512, 128, 128, True, False, True): (5, 4, 2, 8), + (256, 256, 1024, 16, 16, False, True, True): (1, 8, 4, 4), + (256, 256, 1024, 16, 16, True, False, True): (1, 16, 4, 2), + (256, 256, 1024, 32, 32, False, True, True): (5, 32, 5, 1), + (256, 256, 1024, 32, 32, True, False, True): (1, 16, 4, 2), + (256, 256, 1024, 64, 64, False, True, True): (1, 16, 4, 4), + (256, 256, 1024, 64, 64, True, False, True): (2, 16, 3, 4), + (256, 256, 1024, 128, 128, False, True, True): (9, 8, 2, 8), + (256, 256, 1024, 128, 128, True, False, True): (1, 8, 2, 8), + (256, 256, 2048, 16, 16, False, True, True): (6, 32, 5, 2), + (256, 256, 2048, 16, 16, True, False, True): (2, 32, 4, 2), + (256, 256, 2048, 32, 32, False, True, True): (1, 32, 3, 2), + (256, 256, 2048, 32, 32, True, False, True): (1, 32, 3, 2), + (256, 256, 2048, 64, 64, False, True, True): (2, 32, 4, 4), + (256, 256, 2048, 64, 64, True, False, True): (2, 16, 4, 4), + (256, 256, 2048, 128, 128, False, True, True): (3, 16, 2, 8), + (256, 256, 2048, 128, 128, True, False, True): (4, 16, 2, 8), + (256, 256, 4096, 16, 16, False, True, True): (1, 32, 3, 4), + (256, 256, 4096, 16, 16, True, False, True): (3, 16, 3, 2), + (256, 256, 4096, 32, 32, False, True, True): (3, 32, 3, 2), + (256, 256, 4096, 32, 32, True, False, True): (1, 32, 3, 2), + (256, 256, 4096, 64, 64, False, True, True): (2, 32, 3, 4), + (256, 256, 4096, 64, 64, True, False, True): (2, 32, 3, 4), + (256, 256, 4096, 128, 128, False, True, True): (5, 32, 2, 8), + (256, 256, 4096, 128, 128, True, False, True): (1, 32, 2, 8), + (256, 256, 8192, 16, 16, False, True, True): (8, 32, 3, 4), + (256, 256, 8192, 16, 16, True, False, True): (1, 32, 3, 2), + (256, 256, 8192, 32, 32, False, True, True): (3, 64, 3, 4), + (256, 256, 8192, 32, 32, True, False, True): (2, 128, 1, 2), + (256, 256, 8192, 64, 64, False, True, True): (7, 128, 1, 4), + (256, 256, 8192, 64, 64, True, False, True): (4, 128, 1, 4), + (256, 256, 8192, 128, 128, False, True, True): (2, 64, 1, 4), + (256, 256, 8192, 128, 128, True, False, True): (4, 64, 1, 4), + (256, 256, 16384, 16, 16, False, True, True): (4, 128, 3, 2), + (256, 256, 16384, 16, 16, True, False, True): (5, 64, 3, 2), + (256, 256, 16384, 32, 32, False, True, True): (5, 128, 3, 2), + (256, 256, 16384, 32, 32, True, False, True): (5, 128, 3, 2), + (256, 256, 16384, 64, 64, False, True, True): (1, 256, 1, 4), + (256, 256, 16384, 64, 64, True, False, True): (5, 128, 3, 4), + (256, 256, 16384, 128, 128, False, True, True): (11, 128, 2, 8), + (256, 256, 16384, 128, 128, True, False, True): (3, 128, 1, 4), + (256, 256, 32768, 16, 16, False, True, True): (1, 128, 3, 4), + (256, 256, 32768, 16, 16, True, False, True): (2, 128, 3, 2), + (256, 256, 32768, 32, 32, False, True, True): (4, 256, 3, 2), + (256, 256, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (256, 256, 32768, 64, 64, False, True, True): (2, 256, 1, 4), + (256, 256, 32768, 64, 64, True, False, True): (2, 256, 1, 4), + (256, 256, 32768, 128, 128, False, True, True): (3, 256, 1, 4), + (256, 256, 32768, 128, 128, True, False, True): (2, 256, 1, 4), + (256, 256, 50432, 16, 16, False, True, True): (4, 197, 1, 4), + (256, 256, 50432, 16, 16, True, False, True): (4, 197, 3, 2), + (256, 256, 50432, 32, 32, False, True, True): (1, 394, 1, 2), + (256, 256, 50432, 32, 32, True, False, True): (4, 197, 3, 4), + (256, 256, 50432, 64, 64, False, True, True): (6, 
394, 1, 4), + (256, 256, 50432, 64, 64, True, False, True): (4, 394, 2, 4), + (256, 256, 50432, 128, 128, False, True, True): (3, 394, 1, 4), + (256, 256, 50432, 128, 128, True, False, True): (1, 394, 2, 4), + (256, 256, 65536, 16, 16, False, True, True): (1, 256, 3, 2), + (256, 256, 65536, 16, 16, True, False, True): (1, 256, 3, 2), + (256, 256, 65536, 32, 32, False, True, True): (1, 512, 3, 2), + (256, 256, 65536, 32, 32, True, False, True): (4, 512, 3, 2), + (256, 256, 65536, 64, 64, False, True, True): (2, 512, 1, 4), + (256, 256, 65536, 64, 64, True, False, True): (5, 512, 1, 4), + (256, 256, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (256, 256, 65536, 128, 128, True, False, True): (1, 512, 1, 4), + (256, 256, 131072, 16, 16, False, True, True): (1, 512, 3, 1), + (256, 256, 131072, 16, 16, True, False, True): (1, 512, 3, 2), + (256, 256, 131072, 32, 32, False, True, True): (2, 1024, 3, 2), + (256, 256, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (256, 256, 131072, 64, 64, False, True, True): (1, 1024, 1, 4), + (256, 256, 131072, 64, 64, True, False, True): (1, 1024, 1, 4), + (256, 256, 131072, 128, 128, False, True, True): (7, 1024, 1, 4), + (256, 256, 131072, 128, 128, True, False, True): (1, 1024, 1, 4), + (384, 384, 256, 16, 16, False, True, True): (3, 16, 4, 1), + (384, 384, 256, 16, 16, True, False, True): (2, 4, 6, 2), + (384, 384, 256, 32, 32, False, True, True): (1, 8, 4, 4), + (384, 384, 256, 32, 32, True, False, True): (1, 4, 5, 2), + (384, 384, 256, 64, 64, False, True, True): (3, 4, 3, 4), + (384, 384, 256, 64, 64, True, False, True): (4, 4, 5, 4), + (384, 384, 512, 16, 16, False, True, True): (1, 16, 4, 1), + (384, 384, 512, 16, 16, True, False, True): (1, 8, 5, 2), + (384, 384, 512, 32, 32, False, True, True): (4, 16, 4, 2), + (384, 384, 512, 32, 32, True, False, True): (1, 8, 5, 2), + (384, 384, 512, 64, 64, False, True, True): (2, 8, 3, 4), + (384, 384, 512, 64, 64, True, False, True): (1, 8, 4, 4), + (384, 384, 1024, 16, 16, False, True, True): (1, 16, 4, 2), + (384, 384, 1024, 16, 16, True, False, True): (7, 8, 5, 2), + (384, 384, 1024, 32, 32, False, True, True): (2, 16, 3, 4), + (384, 384, 1024, 32, 32, True, False, True): (1, 16, 4, 2), + (384, 384, 1024, 64, 64, False, True, True): (6, 16, 3, 4), + (384, 384, 1024, 64, 64, True, False, True): (4, 16, 4, 4), + (384, 384, 2048, 16, 16, False, True, True): (1, 32, 1, 4), + (384, 384, 2048, 16, 16, True, False, True): (1, 16, 3, 2), + (384, 384, 2048, 32, 32, False, True, True): (1, 32, 1, 8), + (384, 384, 2048, 32, 32, True, False, True): (1, 8, 4, 4), + (384, 384, 2048, 64, 64, False, True, True): (2, 32, 1, 8), + (384, 384, 2048, 64, 64, True, False, True): (3, 16, 3, 4), + (384, 384, 4096, 16, 16, False, True, True): (5, 32, 1, 4), + (384, 384, 4096, 16, 16, True, False, True): (1, 32, 3, 2), + (384, 384, 4096, 32, 32, False, True, True): (1, 32, 1, 8), + (384, 384, 4096, 32, 32, True, False, True): (2, 16, 4, 4), + (384, 384, 4096, 64, 64, False, True, True): (1, 64, 1, 4), + (384, 384, 4096, 64, 64, True, False, True): (2, 32, 3, 4), + (384, 384, 8192, 16, 16, False, True, True): (2, 64, 1, 4), + (384, 384, 8192, 16, 16, True, False, True): (3, 32, 3, 2), + (384, 384, 8192, 32, 32, False, True, True): (4, 128, 1, 4), + (384, 384, 8192, 32, 32, True, False, True): (1, 32, 3, 2), + (384, 384, 8192, 64, 64, False, True, True): (1, 128, 1, 4), + (384, 384, 8192, 64, 64, True, False, True): (1, 64, 3, 4), + (384, 384, 16384, 16, 16, False, True, True): (1, 128, 1, 2), + (384, 384, 16384, 16, 
16, True, False, True): (1, 64, 3, 2), + (384, 384, 16384, 32, 32, False, True, True): (1, 128, 1, 4), + (384, 384, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (384, 384, 16384, 64, 64, False, True, True): (5, 128, 3, 4), + (384, 384, 16384, 64, 64, True, False, True): (1, 128, 3, 4), + (384, 384, 32768, 16, 16, False, True, True): (2, 256, 1, 2), + (384, 384, 32768, 16, 16, True, False, True): (1, 128, 3, 4), + (384, 384, 32768, 32, 32, False, True, True): (1, 256, 1, 2), + (384, 384, 32768, 32, 32, True, False, True): (2, 128, 3, 4), + (384, 384, 32768, 64, 64, False, True, True): (3, 256, 1, 4), + (384, 384, 32768, 64, 64, True, False, True): (2, 256, 3, 4), + (384, 384, 65536, 16, 16, False, True, True): (2, 128, 1, 4), + (384, 384, 65536, 16, 16, True, False, True): (1, 256, 3, 4), + (384, 384, 65536, 32, 32, False, True, True): (1, 512, 1, 2), + (384, 384, 65536, 32, 32, True, False, True): (1, 256, 3, 4), + (384, 384, 65536, 64, 64, False, True, True): (3, 512, 1, 4), + (384, 384, 65536, 64, 64, True, False, True): (3, 256, 3, 4), + (384, 384, 131072, 16, 16, False, True, True): (2, 256, 1, 2), + (384, 384, 131072, 16, 16, True, False, True): (1, 512, 3, 4), + (384, 384, 131072, 32, 32, False, True, True): (1, 512, 1, 2), + (384, 384, 131072, 32, 32, True, False, True): (1, 512, 3, 4), + (384, 384, 131072, 64, 64, False, True, True): (3, 1024, 1, 4), + (384, 384, 131072, 64, 64, True, False, True): (3, 512, 3, 4), + (512, 512, 256, 16, 16, False, True, True): (1, 8, 5, 1), + (512, 512, 256, 16, 16, True, False, True): (2, 16, 5, 1), + (512, 512, 256, 32, 32, False, True, True): (2, 8, 5, 2), + (512, 512, 256, 32, 32, True, False, True): (4, 4, 5, 2), + (512, 512, 256, 64, 64, False, True, True): (1, 4, 5, 4), + (512, 512, 256, 64, 64, True, False, True): (3, 4, 5, 4), + (512, 512, 256, 128, 128, False, True, True): (1, 2, 2, 8), + (512, 512, 256, 128, 128, True, False, True): (1, 2, 2, 8), + (512, 512, 512, 16, 16, False, True, True): (1, 8, 4, 4), + (512, 512, 512, 16, 16, True, False, True): (4, 16, 5, 1), + (512, 512, 512, 32, 32, False, True, True): (4, 8, 5, 2), + (512, 512, 512, 32, 32, True, False, True): (7, 16, 4, 1), + (512, 512, 512, 64, 64, False, True, True): (3, 8, 5, 4), + (512, 512, 512, 64, 64, True, False, True): (1, 8, 4, 4), + (512, 512, 512, 128, 128, False, True, True): (4, 4, 2, 8), + (512, 512, 512, 128, 128, True, False, True): (4, 4, 2, 8), + (512, 512, 1024, 16, 16, False, True, True): (2, 8, 4, 4), + (512, 512, 1024, 16, 16, True, False, True): (2, 16, 4, 2), + (512, 512, 1024, 32, 32, False, True, True): (3, 16, 4, 2), + (512, 512, 1024, 32, 32, True, False, True): (3, 16, 3, 2), + (512, 512, 1024, 64, 64, False, True, True): (5, 8, 5, 4), + (512, 512, 1024, 64, 64, True, False, True): (4, 16, 3, 4), + (512, 512, 1024, 128, 128, False, True, True): (6, 8, 2, 8), + (512, 512, 1024, 128, 128, True, False, True): (4, 8, 2, 8), + (512, 512, 2048, 16, 16, False, True, True): (2, 16, 3, 4), + (512, 512, 2048, 16, 16, True, False, True): (1, 16, 4, 2), + (512, 512, 2048, 32, 32, False, True, True): (2, 32, 3, 2), + (512, 512, 2048, 32, 32, True, False, True): (2, 32, 3, 2), + (512, 512, 2048, 64, 64, False, True, True): (1, 32, 3, 4), + (512, 512, 2048, 64, 64, True, False, True): (1, 32, 3, 2), + (512, 512, 2048, 128, 128, False, True, True): (3, 16, 2, 8), + (512, 512, 2048, 128, 128, True, False, True): (1, 16, 2, 8), + (512, 512, 4096, 16, 16, False, True, True): (4, 32, 3, 2), + (512, 512, 4096, 16, 16, True, False, True): (1, 32, 3, 2), + (512, 
512, 4096, 32, 32, False, True, True): (3, 32, 3, 2), + (512, 512, 4096, 32, 32, True, False, True): (3, 32, 3, 2), + (512, 512, 4096, 64, 64, False, True, True): (1, 32, 3, 4), + (512, 512, 4096, 64, 64, True, False, True): (1, 64, 1, 4), + (512, 512, 4096, 128, 128, False, True, True): (7, 32, 2, 8), + (512, 512, 4096, 128, 128, True, False, True): (1, 32, 2, 8), + (512, 512, 8192, 16, 16, False, True, True): (4, 64, 3, 2), + (512, 512, 8192, 16, 16, True, False, True): (1, 64, 3, 2), + (512, 512, 8192, 32, 32, False, True, True): (3, 64, 3, 2), + (512, 512, 8192, 32, 32, True, False, True): (1, 64, 3, 2), + (512, 512, 8192, 64, 64, False, True, True): (1, 64, 3, 4), + (512, 512, 8192, 64, 64, True, False, True): (1, 64, 3, 4), + (512, 512, 8192, 128, 128, False, True, True): (7, 64, 2, 8), + (512, 512, 8192, 128, 128, True, False, True): (1, 64, 1, 4), + (512, 512, 16384, 16, 16, False, True, True): (1, 128, 3, 2), + (512, 512, 16384, 16, 16, True, False, True): (1, 64, 3, 2), + (512, 512, 16384, 32, 32, False, True, True): (1, 128, 3, 2), + (512, 512, 16384, 32, 32, True, False, True): (1, 128, 3, 2), + (512, 512, 16384, 64, 64, False, True, True): (1, 128, 3, 4), + (512, 512, 16384, 64, 64, True, False, True): (4, 128, 3, 4), + (512, 512, 16384, 128, 128, False, True, True): (5, 128, 2, 8), + (512, 512, 16384, 128, 128, True, False, True): (2, 128, 1, 4), + (512, 512, 32768, 16, 16, False, True, True): (1, 128, 3, 4), + (512, 512, 32768, 16, 16, True, False, True): (1, 128, 3, 2), + (512, 512, 32768, 32, 32, False, True, True): (1, 256, 3, 2), + (512, 512, 32768, 32, 32, True, False, True): (1, 256, 3, 2), + (512, 512, 32768, 64, 64, False, True, True): (1, 256, 3, 4), + (512, 512, 32768, 64, 64, True, False, True): (1, 256, 3, 4), + (512, 512, 32768, 128, 128, False, True, True): (5, 256, 1, 4), + (512, 512, 32768, 128, 128, True, False, True): (1, 256, 1, 4), + (512, 512, 50432, 16, 16, False, True, True): (4, 197, 1, 4), + (512, 512, 50432, 16, 16, True, False, True): (4, 197, 3, 2), + (512, 512, 50432, 32, 32, False, True, True): (2, 197, 1, 4), + (512, 512, 50432, 32, 32, True, False, True): (4, 197, 3, 4), + (512, 512, 50432, 64, 64, False, True, True): (2, 394, 1, 4), + (512, 512, 50432, 64, 64, True, False, True): (4, 197, 2, 4), + (512, 512, 50432, 128, 128, False, True, True): (5, 394, 1, 4), + (512, 512, 50432, 128, 128, True, False, True): (6, 394, 2, 4), + (512, 512, 65536, 16, 16, False, True, True): (1, 256, 3, 2), + (512, 512, 65536, 16, 16, True, False, True): (1, 256, 3, 1), + (512, 512, 65536, 32, 32, False, True, True): (1, 512, 3, 2), + (512, 512, 65536, 32, 32, True, False, True): (1, 512, 3, 2), + (512, 512, 65536, 64, 64, False, True, True): (2, 256, 2, 4), + (512, 512, 65536, 64, 64, True, False, True): (1, 512, 3, 4), + (512, 512, 65536, 128, 128, False, True, True): (7, 512, 1, 4), + (512, 512, 65536, 128, 128, True, False, True): (5, 512, 1, 4), + (512, 512, 131072, 16, 16, False, True, True): (1, 512, 3, 1), + (512, 512, 131072, 16, 16, True, False, True): (1, 512, 3, 1), + (512, 512, 131072, 32, 32, False, True, True): (1, 1024, 3, 2), + (512, 512, 131072, 32, 32, True, False, True): (1, 1024, 3, 2), + (512, 512, 131072, 64, 64, False, True, True): (4, 512, 2, 4), + (512, 512, 131072, 64, 64, True, False, True): (2, 512, 2, 4), + (512, 512, 131072, 128, 128, False, True, True): (5, 1024, 1, 4), + (512, 512, 131072, 128, 128, True, False, True): (4, 1024, 1, 4), + (768, 768, 256, 16, 16, False, True, True): (1, 8, 4, 1), + (768, 768, 256, 16, 16, True, 
False, True): (3, 2, 5, 2), + (768, 768, 256, 32, 32, False, True, True): (1, 8, 4, 2), + (768, 768, 256, 32, 32, True, False, True): (2, 4, 6, 2), + (768, 768, 256, 64, 64, False, True, True): (3, 4, 3, 4), + (768, 768, 256, 64, 64, True, False, True): (2, 4, 4, 4), + (768, 768, 256, 128, 128, False, True, True): (1, 2, 3, 8), + (768, 768, 256, 128, 128, True, False, True): (2, 2, 3, 8), + (768, 768, 512, 16, 16, False, True, True): (1, 8, 4, 2), + (768, 768, 512, 16, 16, True, False, True): (2, 8, 5, 2), + (768, 768, 512, 32, 32, False, True, True): (1, 16, 1, 4), + (768, 768, 512, 32, 32, True, False, True): (3, 8, 5, 2), + (768, 768, 512, 64, 64, False, True, True): (4, 8, 3, 4), + (768, 768, 512, 64, 64, True, False, True): (2, 8, 4, 4), + (768, 768, 512, 128, 128, False, True, True): (1, 4, 3, 8), + (768, 768, 512, 128, 128, True, False, True): (3, 4, 3, 8), + (768, 768, 1024, 16, 16, False, True, True): (1, 16, 1, 4), + (768, 768, 1024, 16, 16, True, False, True): (1, 8, 5, 2), + (768, 768, 1024, 32, 32, False, True, True): (1, 16, 1, 8), + (768, 768, 1024, 32, 32, True, False, True): (1, 4, 4, 4), + (768, 768, 1024, 64, 64, False, True, True): (2, 16, 1, 8), + (768, 768, 1024, 64, 64, True, False, True): (1, 8, 3, 8), + (768, 768, 1024, 128, 128, False, True, True): (1, 8, 3, 8), + (768, 768, 1024, 128, 128, True, False, True): (3, 8, 3, 8), + (768, 768, 2048, 16, 16, False, True, True): (6, 16, 1, 2), + (768, 768, 2048, 16, 16, True, False, True): (2, 16, 4, 2), + (768, 768, 2048, 32, 32, False, True, True): (3, 32, 1, 4), + (768, 768, 2048, 32, 32, True, False, True): (6, 8, 3, 4), + (768, 768, 2048, 64, 64, False, True, True): (2, 32, 2, 2), + (768, 768, 2048, 64, 64, True, False, True): (1, 16, 4, 4), + (768, 768, 2048, 128, 128, False, True, True): (2, 16, 3, 8), + (768, 768, 2048, 128, 128, True, False, True): (4, 16, 3, 8), + (768, 768, 4096, 16, 16, False, True, True): (1, 32, 1, 4), + (768, 768, 4096, 16, 16, True, False, True): (2, 16, 3, 2), + (768, 768, 4096, 32, 32, False, True, True): (3, 32, 1, 8), + (768, 768, 4096, 32, 32, True, False, True): (1, 16, 4, 4), + (768, 768, 4096, 64, 64, False, True, True): (1, 64, 2, 4), + (768, 768, 4096, 64, 64, True, False, True): (1, 8, 3, 8), + (768, 768, 4096, 128, 128, False, True, True): (1, 32, 3, 8), + (768, 768, 4096, 128, 128, True, False, True): (2, 32, 3, 8), + (768, 768, 8192, 16, 16, False, True, True): (1, 64, 1, 2), + (768, 768, 8192, 16, 16, True, False, True): (2, 64, 3, 2), + (768, 768, 8192, 32, 32, False, True, True): (2, 64, 1, 8), + (768, 768, 8192, 32, 32, True, False, True): (2, 32, 3, 4), + (768, 768, 8192, 64, 64, False, True, True): (4, 64, 3, 4), + (768, 768, 8192, 64, 64, True, False, True): (1, 64, 3, 4), + (768, 768, 8192, 128, 128, False, True, True): (4, 64, 3, 8), + (768, 768, 8192, 128, 128, True, False, True): (2, 64, 3, 8), + (768, 768, 16384, 16, 16, False, True, True): (4, 128, 1, 2), + (768, 768, 16384, 16, 16, True, False, True): (1, 64, 3, 4), + (768, 768, 16384, 32, 32, False, True, True): (1, 128, 1, 8), + (768, 768, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (768, 768, 16384, 64, 64, False, True, True): (1, 128, 3, 4), + (768, 768, 16384, 64, 64, True, False, True): (1, 128, 3, 4), + (768, 768, 16384, 128, 128, False, True, True): (3, 128, 1, 4), + (768, 768, 16384, 128, 128, True, False, True): (1, 128, 2, 4), + (768, 768, 32768, 16, 16, False, True, True): (2, 256, 1, 2), + (768, 768, 32768, 16, 16, True, False, True): (1, 128, 4, 4), + (768, 768, 32768, 32, 32, False, True, 
True): (1, 128, 1, 2), + (768, 768, 32768, 32, 32, True, False, True): (1, 128, 3, 4), + (768, 768, 32768, 64, 64, False, True, True): (1, 256, 1, 4), + (768, 768, 32768, 64, 64, True, False, True): (1, 128, 3, 4), + (768, 768, 32768, 128, 128, False, True, True): (3, 256, 1, 4), + (768, 768, 32768, 128, 128, True, False, True): (3, 256, 2, 4), + (768, 768, 65536, 16, 16, False, True, True): (4, 512, 1, 2), + (768, 768, 65536, 16, 16, True, False, True): (1, 256, 4, 4), + (768, 768, 65536, 32, 32, False, True, True): (1, 256, 1, 2), + (768, 768, 65536, 32, 32, True, False, True): (1, 256, 3, 4), + (768, 768, 65536, 64, 64, False, True, True): (1, 512, 1, 4), + (768, 768, 65536, 64, 64, True, False, True): (1, 256, 3, 4), + (768, 768, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (768, 768, 65536, 128, 128, True, False, True): (2, 512, 2, 4), + (768, 768, 131072, 16, 16, False, True, True): (1, 512, 1, 1), + (768, 768, 131072, 16, 16, True, False, True): (1, 512, 4, 4), + (768, 768, 131072, 32, 32, False, True, True): (1, 512, 1, 2), + (768, 768, 131072, 32, 32, True, False, True): (1, 512, 3, 4), + (768, 768, 131072, 64, 64, False, True, True): (1, 1024, 1, 4), + (768, 768, 131072, 64, 64, True, False, True): (3, 512, 3, 4), + (768, 768, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (768, 768, 131072, 128, 128, True, False, True): (1, 1024, 2, 4), + (768, 3072, 256, 16, 16, False, True, True): (1, 8, 5, 2), + (768, 3072, 256, 16, 16, True, False, True): (3, 4, 7, 2), + (768, 3072, 256, 32, 32, False, True, True): (1, 8, 4, 2), + (768, 3072, 256, 32, 32, True, False, True): (1, 4, 5, 4), + (768, 3072, 256, 64, 64, False, True, True): (1, 4, 3, 4), + (768, 3072, 256, 64, 64, True, False, True): (1, 4, 5, 4), + (768, 3072, 256, 128, 128, False, True, True): (2, 2, 3, 8), + (768, 3072, 256, 128, 128, True, False, True): (2, 2, 3, 8), + (768, 3072, 512, 16, 16, False, True, True): (1, 8, 5, 2), + (768, 3072, 512, 16, 16, True, False, True): (1, 8, 5, 2), + (768, 3072, 512, 32, 32, False, True, True): (3, 8, 3, 4), + (768, 3072, 512, 32, 32, True, False, True): (1, 8, 7, 4), + (768, 3072, 512, 64, 64, False, True, True): (3, 8, 3, 4), + (768, 3072, 512, 64, 64, True, False, True): (3, 8, 5, 4), + (768, 3072, 512, 128, 128, False, True, True): (1, 4, 3, 8), + (768, 3072, 512, 128, 128, True, False, True): (1, 4, 3, 8), + (768, 3072, 1024, 16, 16, False, True, True): (4, 16, 1, 4), + (768, 3072, 1024, 16, 16, True, False, True): (2, 8, 5, 2), + (768, 3072, 1024, 32, 32, False, True, True): (1, 16, 6, 2), + (768, 3072, 1024, 32, 32, True, False, True): (1, 8, 4, 4), + (768, 3072, 1024, 64, 64, False, True, True): (2, 16, 4, 4), + (768, 3072, 1024, 64, 64, True, False, True): (2, 16, 4, 4), + (768, 3072, 1024, 128, 128, False, True, True): (1, 8, 3, 8), + (768, 3072, 1024, 128, 128, True, False, True): (3, 8, 3, 8), + (768, 3072, 2048, 16, 16, False, True, True): (1, 16, 1, 2), + (768, 3072, 2048, 16, 16, True, False, True): (1, 16, 5, 2), + (768, 3072, 2048, 32, 32, False, True, True): (4, 16, 1, 8), + (768, 3072, 2048, 32, 32, True, False, True): (2, 8, 3, 4), + (768, 3072, 2048, 64, 64, False, True, True): (2, 16, 3, 4), + (768, 3072, 2048, 64, 64, True, False, True): (2, 16, 3, 4), + (768, 3072, 2048, 128, 128, False, True, True): (3, 16, 3, 8), + (768, 3072, 2048, 128, 128, True, False, True): (1, 16, 3, 8), + (768, 3072, 4096, 16, 16, False, True, True): (1, 32, 1, 4), + (768, 3072, 4096, 16, 16, True, False, True): (1, 16, 3, 1), + (768, 3072, 4096, 32, 32, False, True, 
True): (3, 32, 1, 8), + (768, 3072, 4096, 32, 32, True, False, True): (2, 16, 3, 8), + (768, 3072, 4096, 64, 64, False, True, True): (2, 32, 3, 4), + (768, 3072, 4096, 64, 64, True, False, True): (2, 16, 3, 4), + (768, 3072, 4096, 128, 128, False, True, True): (5, 32, 1, 4), + (768, 3072, 4096, 128, 128, True, False, True): (4, 32, 3, 8), + (768, 3072, 8192, 16, 16, False, True, True): (1, 32, 1, 4), + (768, 3072, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (768, 3072, 8192, 32, 32, False, True, True): (1, 64, 1, 8), + (768, 3072, 8192, 32, 32, True, False, True): (2, 32, 3, 8), + (768, 3072, 8192, 64, 64, False, True, True): (2, 64, 3, 4), + (768, 3072, 8192, 64, 64, True, False, True): (2, 32, 3, 4), + (768, 3072, 8192, 128, 128, False, True, True): (1, 64, 3, 8), + (768, 3072, 8192, 128, 128, True, False, True): (2, 64, 3, 8), + (768, 3072, 16384, 16, 16, False, True, True): (1, 64, 1, 4), + (768, 3072, 16384, 16, 16, True, False, True): (1, 64, 4, 1), + (768, 3072, 16384, 32, 32, False, True, True): (1, 128, 1, 8), + (768, 3072, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (768, 3072, 16384, 64, 64, False, True, True): (1, 128, 3, 4), + (768, 3072, 16384, 64, 64, True, False, True): (1, 64, 3, 4), + (768, 3072, 16384, 128, 128, False, True, True): (2, 128, 3, 8), + (768, 3072, 16384, 128, 128, True, False, True): (1, 128, 3, 8), + (768, 3072, 32768, 16, 16, False, True, True): (1, 128, 1, 4), + (768, 3072, 32768, 16, 16, True, False, True): (1, 128, 4, 1), + (768, 3072, 32768, 32, 32, False, True, True): (1, 256, 1, 8), + (768, 3072, 32768, 32, 32, True, False, True): (1, 128, 3, 4), + (768, 3072, 32768, 64, 64, False, True, True): (1, 256, 3, 4), + (768, 3072, 32768, 64, 64, True, False, True): (1, 128, 3, 4), + (768, 3072, 32768, 128, 128, False, True, True): (3, 256, 1, 4), + (768, 3072, 32768, 128, 128, True, False, True): (5, 256, 3, 8), + (768, 3072, 50432, 16, 16, False, True, True): (1, 197, 1, 4), + (768, 3072, 50432, 16, 16, True, False, True): (4, 197, 4, 1), + (768, 3072, 50432, 32, 32, False, True, True): (2, 197, 1, 4), + (768, 3072, 50432, 32, 32, True, False, True): (4, 197, 3, 4), + (768, 3072, 50432, 64, 64, False, True, True): (1, 394, 3, 4), + (768, 3072, 50432, 64, 64, True, False, True): (1, 197, 3, 4), + (768, 3072, 50432, 128, 128, False, True, True): (3, 394, 1, 4), + (768, 3072, 50432, 128, 128, True, False, True): (3, 394, 2, 4), + (768, 3072, 65536, 16, 16, False, True, True): (1, 256, 1, 4), + (768, 3072, 65536, 16, 16, True, False, True): (5, 256, 4, 1), + (768, 3072, 65536, 32, 32, False, True, True): (2, 256, 1, 4), + (768, 3072, 65536, 32, 32, True, False, True): (3, 256, 3, 4), + (768, 3072, 65536, 64, 64, False, True, True): (1, 512, 3, 4), + (768, 3072, 65536, 64, 64, True, False, True): (1, 256, 3, 4), + (768, 3072, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (768, 3072, 65536, 128, 128, True, False, True): (2, 512, 3, 8), + (768, 3072, 131072, 16, 16, False, True, True): (1, 512, 1, 4), + (768, 3072, 131072, 16, 16, True, False, True): (5, 512, 4, 1), + (768, 3072, 131072, 32, 32, False, True, True): (2, 512, 1, 4), + (768, 3072, 131072, 32, 32, True, False, True): (2, 512, 3, 4), + (768, 3072, 131072, 64, 64, False, True, True): (1, 1024, 3, 4), + (768, 3072, 131072, 64, 64, True, False, True): (2, 512, 3, 4), + (768, 3072, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (768, 3072, 131072, 128, 128, True, False, True): (2, 1024, 3, 8), + (1024, 1024, 256, 16, 16, False, True, True): (3, 4, 5, 4), + (1024, 1024, 
256, 16, 16, True, False, True): (3, 4, 5, 4), + (1024, 1024, 256, 32, 32, False, True, True): (2, 4, 6, 2), + (1024, 1024, 256, 32, 32, True, False, True): (2, 4, 6, 2), + (1024, 1024, 256, 64, 64, False, True, True): (1, 4, 4, 4), + (1024, 1024, 256, 64, 64, True, False, True): (2, 4, 6, 4), + (1024, 1024, 256, 128, 128, False, True, True): (1, 2, 2, 8), + (1024, 1024, 256, 128, 128, True, False, True): (1, 2, 2, 8), + (1024, 1024, 512, 16, 16, False, True, True): (3, 4, 5, 4), + (1024, 1024, 512, 16, 16, True, False, True): (3, 8, 4, 2), + (1024, 1024, 512, 32, 32, False, True, True): (1, 8, 4, 2), + (1024, 1024, 512, 32, 32, True, False, True): (1, 8, 4, 2), + (1024, 1024, 512, 64, 64, False, True, True): (2, 8, 3, 4), + (1024, 1024, 512, 64, 64, True, False, True): (1, 4, 4, 4), + (1024, 1024, 512, 128, 128, False, True, True): (7, 4, 2, 8), + (1024, 1024, 512, 128, 128, True, False, True): (1, 4, 2, 8), + (1024, 1024, 1024, 16, 16, False, True, True): (4, 8, 4, 2), + (1024, 1024, 1024, 16, 16, True, False, True): (3, 8, 5, 2), + (1024, 1024, 1024, 32, 32, False, True, True): (1, 8, 4, 4), + (1024, 1024, 1024, 32, 32, True, False, True): (1, 8, 4, 2), + (1024, 1024, 1024, 64, 64, False, True, True): (1, 16, 3, 4), + (1024, 1024, 1024, 64, 64, True, False, True): (3, 16, 3, 4), + (1024, 1024, 1024, 128, 128, False, True, True): (6, 8, 2, 8), + (1024, 1024, 1024, 128, 128, True, False, True): (4, 8, 2, 8), + (1024, 1024, 2048, 16, 16, False, True, True): (3, 8, 3, 4), + (1024, 1024, 2048, 16, 16, True, False, True): (3, 8, 3, 4), + (1024, 1024, 2048, 32, 32, False, True, True): (1, 16, 3, 4), + (1024, 1024, 2048, 32, 32, True, False, True): (1, 16, 3, 2), + (1024, 1024, 2048, 64, 64, False, True, True): (5, 16, 3, 4), + (1024, 1024, 2048, 64, 64, True, False, True): (5, 16, 3, 4), + (1024, 1024, 2048, 128, 128, False, True, True): (3, 16, 2, 8), + (1024, 1024, 2048, 128, 128, True, False, True): (4, 16, 2, 16), + (1024, 1024, 4096, 16, 16, False, True, True): (4, 32, 3, 2), + (1024, 1024, 4096, 16, 16, True, False, True): (8, 32, 3, 2), + (1024, 1024, 4096, 32, 32, False, True, True): (9, 32, 3, 2), + (1024, 1024, 4096, 32, 32, True, False, True): (1, 32, 3, 2), + (1024, 1024, 4096, 64, 64, False, True, True): (6, 32, 3, 4), + (1024, 1024, 4096, 64, 64, True, False, True): (1, 32, 3, 4), + (1024, 1024, 4096, 128, 128, False, True, True): (4, 32, 2, 8), + (1024, 1024, 4096, 128, 128, True, False, True): (4, 32, 1, 4), + (1024, 1024, 8192, 16, 16, False, True, True): (4, 64, 3, 2), + (1024, 1024, 8192, 16, 16, True, False, True): (4, 64, 3, 2), + (1024, 1024, 8192, 32, 32, False, True, True): (8, 64, 3, 2), + (1024, 1024, 8192, 32, 32, True, False, True): (6, 64, 3, 2), + (1024, 1024, 8192, 64, 64, False, True, True): (2, 64, 3, 4), + (1024, 1024, 8192, 64, 64, True, False, True): (2, 64, 3, 4), + (1024, 1024, 8192, 128, 128, False, True, True): (3, 64, 1, 4), + (1024, 1024, 8192, 128, 128, True, False, True): (2, 64, 1, 4), + (1024, 1024, 16384, 16, 16, False, True, True): (1, 64, 3, 4), + (1024, 1024, 16384, 16, 16, True, False, True): (1, 64, 3, 2), + (1024, 1024, 16384, 32, 32, False, True, True): (1, 128, 3, 4), + (1024, 1024, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (1024, 1024, 16384, 64, 64, False, True, True): (1, 128, 3, 4), + (1024, 1024, 16384, 64, 64, True, False, True): (1, 128, 3, 4), + (1024, 1024, 16384, 128, 128, False, True, True): (11, 128, 1, 4), + (1024, 1024, 16384, 128, 128, True, False, True): (4, 128, 1, 4), + (1024, 1024, 32768, 16, 16, False, True, 
True): (1, 128, 3, 4), + (1024, 1024, 32768, 16, 16, True, False, True): (1, 128, 3, 1), + (1024, 1024, 32768, 32, 32, False, True, True): (1, 256, 3, 2), + (1024, 1024, 32768, 32, 32, True, False, True): (1, 128, 3, 4), + (1024, 1024, 32768, 64, 64, False, True, True): (2, 128, 2, 4), + (1024, 1024, 32768, 64, 64, True, False, True): (1, 256, 3, 4), + (1024, 1024, 32768, 128, 128, False, True, True): (7, 256, 1, 4), + (1024, 1024, 32768, 128, 128, True, False, True): (4, 256, 1, 4), + (1024, 1024, 50432, 16, 16, False, True, True): (1, 197, 1, 4), + (1024, 1024, 50432, 16, 16, True, False, True): (4, 197, 3, 4), + (1024, 1024, 50432, 32, 32, False, True, True): (2, 197, 1, 4), + (1024, 1024, 50432, 32, 32, True, False, True): (1, 197, 3, 4), + (1024, 1024, 50432, 64, 64, False, True, True): (2, 394, 1, 4), + (1024, 1024, 50432, 64, 64, True, False, True): (1, 197, 2, 4), + (1024, 1024, 50432, 128, 128, False, True, True): (3, 394, 1, 4), + (1024, 1024, 50432, 128, 128, True, False, True): (2, 394, 2, 4), + (1024, 1024, 65536, 16, 16, False, True, True): (1, 256, 3, 4), + (1024, 1024, 65536, 16, 16, True, False, True): (1, 256, 3, 1), + (1024, 1024, 65536, 32, 32, False, True, True): (1, 512, 3, 2), + (1024, 1024, 65536, 32, 32, True, False, True): (1, 256, 3, 4), + (1024, 1024, 65536, 64, 64, False, True, True): (2, 256, 2, 4), + (1024, 1024, 65536, 64, 64, True, False, True): (1, 512, 3, 4), + (1024, 1024, 65536, 128, 128, False, True, True): (10, 512, 1, 4), + (1024, 1024, 65536, 128, 128, True, False, True): (4, 512, 1, 4), + (1024, 1024, 131072, 16, 16, False, True, True): (11, 512, 3, 2), + (1024, 1024, 131072, 16, 16, True, False, True): (11, 512, 3, 2), + (1024, 1024, 131072, 32, 32, False, True, True): (7, 1024, 3, 2), + (1024, 1024, 131072, 32, 32, True, False, True): (6, 512, 3, 4), + (1024, 1024, 131072, 64, 64, False, True, True): (1, 512, 2, 4), + (1024, 1024, 131072, 64, 64, True, False, True): (4, 1024, 3, 4), + (1024, 1024, 131072, 128, 128, False, True, True): (12, 1024, 1, 4), + (1024, 1024, 131072, 128, 128, True, False, True): (4, 1024, 1, 4), + (1536, 1536, 256, 16, 16, False, True, True): (5, 4, 4, 2), + (1536, 1536, 256, 16, 16, True, False, True): (3, 4, 5, 2), + (1536, 1536, 256, 32, 32, False, True, True): (2, 4, 4, 4), + (1536, 1536, 256, 32, 32, True, False, True): (1, 4, 6, 2), + (1536, 1536, 256, 64, 64, False, True, True): (5, 4, 4, 4), + (1536, 1536, 256, 64, 64, True, False, True): (2, 4, 4, 4), + (1536, 1536, 256, 128, 128, False, True, True): (1, 2, 3, 8), + (1536, 1536, 256, 128, 128, True, False, True): (2, 2, 3, 8), + (1536, 1536, 512, 16, 16, False, True, True): (1, 8, 1, 4), + (1536, 1536, 512, 16, 16, True, False, True): (3, 4, 4, 2), + (1536, 1536, 512, 32, 32, False, True, True): (1, 8, 1, 8), + (1536, 1536, 512, 32, 32, True, False, True): (1, 4, 4, 4), + (1536, 1536, 512, 64, 64, False, True, True): (3, 8, 3, 4), + (1536, 1536, 512, 64, 64, True, False, True): (5, 8, 3, 4), + (1536, 1536, 512, 128, 128, False, True, True): (3, 4, 3, 8), + (1536, 1536, 512, 128, 128, True, False, True): (1, 4, 3, 8), + (1536, 1536, 1024, 16, 16, False, True, True): (6, 8, 1, 2), + (1536, 1536, 1024, 16, 16, True, False, True): (2, 8, 5, 2), + (1536, 1536, 1024, 32, 32, False, True, True): (6, 8, 1, 8), + (1536, 1536, 1024, 32, 32, True, False, True): (2, 4, 3, 4), + (1536, 1536, 1024, 64, 64, False, True, True): (1, 16, 3, 4), + (1536, 1536, 1024, 64, 64, True, False, True): (3, 8, 3, 4), + (1536, 1536, 1024, 128, 128, False, True, True): (3, 8, 3, 8), + (1536, 
1536, 1024, 128, 128, True, False, True): (3, 8, 3, 8), + (1536, 1536, 2048, 16, 16, False, True, True): (1, 16, 1, 4), + (1536, 1536, 2048, 16, 16, True, False, True): (1, 8, 3, 1), + (1536, 1536, 2048, 32, 32, False, True, True): (1, 16, 1, 8), + (1536, 1536, 2048, 32, 32, True, False, True): (4, 8, 3, 2), + (1536, 1536, 2048, 64, 64, False, True, True): (1, 16, 3, 4), + (1536, 1536, 2048, 64, 64, True, False, True): (3, 8, 3, 4), + (1536, 1536, 2048, 128, 128, False, True, True): (6, 16, 1, 4), + (1536, 1536, 2048, 128, 128, True, False, True): (4, 16, 3, 8), + (1536, 1536, 4096, 16, 16, False, True, True): (1, 32, 1, 2), + (1536, 1536, 4096, 16, 16, True, False, True): (4, 32, 4, 2), + (1536, 1536, 4096, 32, 32, False, True, True): (1, 32, 1, 8), + (1536, 1536, 4096, 32, 32, True, False, True): (3, 16, 3, 4), + (1536, 1536, 4096, 64, 64, False, True, True): (1, 32, 3, 4), + (1536, 1536, 4096, 64, 64, True, False, True): (1, 16, 3, 4), + (1536, 1536, 4096, 128, 128, False, True, True): (4, 32, 3, 8), + (1536, 1536, 4096, 128, 128, True, False, True): (2, 32, 3, 8), + (1536, 1536, 8192, 16, 16, False, True, True): (2, 64, 1, 2), + (1536, 1536, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (1536, 1536, 8192, 32, 32, False, True, True): (1, 64, 1, 8), + (1536, 1536, 8192, 32, 32, True, False, True): (12, 32, 3, 4), + (1536, 1536, 8192, 64, 64, False, True, True): (2, 64, 3, 4), + (1536, 1536, 8192, 64, 64, True, False, True): (2, 32, 3, 4), + (1536, 1536, 8192, 128, 128, False, True, True): (3, 64, 1, 4), + (1536, 1536, 8192, 128, 128, True, False, True): (4, 64, 3, 8), + (1536, 1536, 16384, 16, 16, False, True, True): (1, 128, 1, 2), + (1536, 1536, 16384, 16, 16, True, False, True): (1, 64, 4, 4), + (1536, 1536, 16384, 32, 32, False, True, True): (1, 64, 1, 2), + (1536, 1536, 16384, 32, 32, True, False, True): (1, 64, 3, 4), + (1536, 1536, 16384, 64, 64, False, True, True): (1, 128, 3, 4), + (1536, 1536, 16384, 64, 64, True, False, True): (1, 64, 3, 4), + (1536, 1536, 16384, 128, 128, False, True, True): (3, 128, 1, 4), + (1536, 1536, 16384, 128, 128, True, False, True): (1, 128, 2, 4), + (1536, 1536, 32768, 16, 16, False, True, True): (1, 256, 1, 2), + (1536, 1536, 32768, 16, 16, True, False, True): (1, 128, 3, 2), + (1536, 1536, 32768, 32, 32, False, True, True): (1, 128, 1, 2), + (1536, 1536, 32768, 32, 32, True, False, True): (1, 128, 3, 4), + (1536, 1536, 32768, 64, 64, False, True, True): (3, 256, 3, 4), + (1536, 1536, 32768, 64, 64, True, False, True): (1, 128, 3, 4), + (1536, 1536, 32768, 128, 128, False, True, True): (3, 256, 1, 4), + (1536, 1536, 32768, 128, 128, True, False, True): (1, 256, 2, 4), + (1536, 1536, 65536, 16, 16, False, True, True): (4, 512, 1, 2), + (1536, 1536, 65536, 16, 16, True, False, True): (1, 256, 4, 4), + (1536, 1536, 65536, 32, 32, False, True, True): (1, 256, 1, 2), + (1536, 1536, 65536, 32, 32, True, False, True): (1, 256, 3, 4), + (1536, 1536, 65536, 64, 64, False, True, True): (2, 512, 3, 4), + (1536, 1536, 65536, 64, 64, True, False, True): (1, 256, 3, 4), + (1536, 1536, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (1536, 1536, 65536, 128, 128, True, False, True): (2, 512, 2, 4), + (1536, 1536, 131072, 16, 16, False, True, True): (2, 1024, 1, 2), + (1536, 1536, 131072, 16, 16, True, False, True): (9, 512, 4, 4), + (1536, 1536, 131072, 32, 32, False, True, True): (1, 512, 1, 2), + (1536, 1536, 131072, 32, 32, True, False, True): (9, 512, 3, 4), + (1536, 1536, 131072, 64, 64, False, True, True): (1, 1024, 3, 4), + (1536, 1536, 131072, 
64, 64, True, False, True): (1, 512, 3, 4), + (1536, 1536, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (1536, 1536, 131072, 128, 128, True, False, True): (1, 1024, 2, 4), + (2048, 2048, 256, 16, 16, False, True, True): (4, 4, 6, 2), + (2048, 2048, 256, 16, 16, True, False, True): (2, 8, 4, 1), + (2048, 2048, 256, 32, 32, False, True, True): (3, 4, 4, 2), + (2048, 2048, 256, 32, 32, True, False, True): (1, 4, 5, 2), + (2048, 2048, 256, 64, 64, False, True, True): (2, 4, 4, 4), + (2048, 2048, 256, 64, 64, True, False, True): (2, 4, 4, 4), + (2048, 2048, 256, 128, 128, False, True, True): (3, 2, 2, 8), + (2048, 2048, 256, 128, 128, True, False, True): (5, 2, 2, 8), + (2048, 2048, 512, 16, 16, False, True, True): (5, 4, 4, 4), + (2048, 2048, 512, 16, 16, True, False, True): (2, 4, 4, 2), + (2048, 2048, 512, 32, 32, False, True, True): (1, 4, 3, 4), + (2048, 2048, 512, 32, 32, True, False, True): (3, 4, 4, 2), + (2048, 2048, 512, 64, 64, False, True, True): (1, 8, 3, 4), + (2048, 2048, 512, 64, 64, True, False, True): (1, 8, 3, 2), + (2048, 2048, 512, 128, 128, False, True, True): (3, 4, 2, 8), + (2048, 2048, 512, 128, 128, True, False, True): (2, 4, 2, 8), + (2048, 2048, 1024, 16, 16, False, True, True): (3, 4, 3, 4), + (2048, 2048, 1024, 16, 16, True, False, True): (2, 8, 3, 2), + (2048, 2048, 1024, 32, 32, False, True, True): (3, 8, 3, 4), + (2048, 2048, 1024, 32, 32, True, False, True): (1, 8, 3, 2), + (2048, 2048, 1024, 64, 64, False, True, True): (1, 8, 3, 4), + (2048, 2048, 1024, 64, 64, True, False, True): (1, 8, 3, 4), + (2048, 2048, 1024, 128, 128, False, True, True): (4, 8, 2, 8), + (2048, 2048, 1024, 128, 128, True, False, True): (4, 8, 1, 4), + (2048, 2048, 2048, 16, 16, False, True, True): (4, 16, 3, 2), + (2048, 2048, 2048, 16, 16, True, False, True): (2, 16, 3, 2), + (2048, 2048, 2048, 32, 32, False, True, True): (1, 16, 3, 4), + (2048, 2048, 2048, 32, 32, True, False, True): (1, 16, 3, 2), + (2048, 2048, 2048, 64, 64, False, True, True): (1, 16, 3, 4), + (2048, 2048, 2048, 64, 64, True, False, True): (1, 16, 3, 4), + (2048, 2048, 2048, 128, 128, False, True, True): (6, 16, 2, 8), + (2048, 2048, 2048, 128, 128, True, False, True): (5, 16, 1, 4), + (2048, 2048, 4096, 16, 16, False, True, True): (4, 32, 4, 2), + (2048, 2048, 4096, 16, 16, True, False, True): (4, 32, 3, 2), + (2048, 2048, 4096, 32, 32, False, True, True): (4, 16, 3, 8), + (2048, 2048, 4096, 32, 32, True, False, True): (4, 16, 3, 4), + (2048, 2048, 4096, 64, 64, False, True, True): (4, 32, 3, 4), + (2048, 2048, 4096, 64, 64, True, False, True): (4, 32, 3, 4), + (2048, 2048, 4096, 128, 128, False, True, True): (4, 32, 2, 8), + (2048, 2048, 4096, 128, 128, True, False, True): (2, 32, 1, 4), + (2048, 2048, 8192, 16, 16, False, True, True): (4, 64, 4, 2), + (2048, 2048, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (2048, 2048, 8192, 32, 32, False, True, True): (4, 32, 3, 8), + (2048, 2048, 8192, 32, 32, True, False, True): (4, 32, 4, 8), + (2048, 2048, 8192, 64, 64, False, True, True): (2, 64, 3, 4), + (2048, 2048, 8192, 64, 64, True, False, True): (4, 64, 3, 4), + (2048, 2048, 8192, 128, 128, False, True, True): (3, 64, 1, 4), + (2048, 2048, 8192, 128, 128, True, False, True): (2, 64, 1, 4), + (2048, 2048, 16384, 16, 16, False, True, True): (4, 64, 3, 4), + (2048, 2048, 16384, 16, 16, True, False, True): (1, 64, 3, 4), + (2048, 2048, 16384, 32, 32, False, True, True): (4, 64, 3, 4), + (2048, 2048, 16384, 32, 32, True, False, True): (4, 64, 3, 4), + (2048, 2048, 16384, 64, 64, False, True, True): (4, 
128, 3, 4), + (2048, 2048, 16384, 64, 64, True, False, True): (4, 128, 3, 4), + (2048, 2048, 16384, 128, 128, False, True, True): (3, 128, 1, 4), + (2048, 2048, 16384, 128, 128, True, False, True): (2, 128, 1, 4), + (2048, 2048, 32768, 16, 16, False, True, True): (8, 128, 3, 2), + (2048, 2048, 32768, 16, 16, True, False, True): (8, 128, 3, 4), + (2048, 2048, 32768, 32, 32, False, True, True): (8, 128, 3, 4), + (2048, 2048, 32768, 32, 32, True, False, True): (8, 128, 3, 4), + (2048, 2048, 32768, 64, 64, False, True, True): (8, 256, 3, 4), + (2048, 2048, 32768, 64, 64, True, False, True): (8, 256, 3, 4), + (2048, 2048, 32768, 128, 128, False, True, True): (3, 256, 1, 4), + (2048, 2048, 32768, 128, 128, True, False, True): (1, 256, 1, 4), + (2048, 2048, 50432, 16, 16, False, True, True): (1, 197, 1, 4), + (2048, 2048, 50432, 16, 16, True, False, True): (4, 197, 4, 1), + (2048, 2048, 50432, 32, 32, False, True, True): (2, 197, 1, 4), + (2048, 2048, 50432, 32, 32, True, False, True): (4, 197, 3, 4), + (2048, 2048, 50432, 64, 64, False, True, True): (2, 394, 3, 4), + (2048, 2048, 50432, 64, 64, True, False, True): (4, 197, 2, 4), + (2048, 2048, 50432, 128, 128, False, True, True): (3, 394, 1, 4), + (2048, 2048, 50432, 128, 128, True, False, True): (4, 394, 2, 4), + (2048, 2048, 65536, 16, 16, False, True, True): (9, 256, 3, 2), + (2048, 2048, 65536, 16, 16, True, False, True): (9, 256, 4, 4), + (2048, 2048, 65536, 32, 32, False, True, True): (7, 256, 3, 4), + (2048, 2048, 65536, 32, 32, True, False, True): (7, 256, 3, 4), + (2048, 2048, 65536, 64, 64, False, True, True): (2, 256, 2, 4), + (2048, 2048, 65536, 64, 64, True, False, True): (9, 512, 3, 4), + (2048, 2048, 65536, 128, 128, False, True, True): (5, 512, 1, 4), + (2048, 2048, 65536, 128, 128, True, False, True): (1, 512, 1, 4), + (2048, 2048, 131072, 16, 16, False, True, True): (9, 512, 3, 2), + (2048, 2048, 131072, 16, 16, True, False, True): (9, 512, 4, 4), + (2048, 2048, 131072, 32, 32, False, True, True): (7, 512, 3, 4), + (2048, 2048, 131072, 32, 32, True, False, True): (3, 512, 3, 4), + (2048, 2048, 131072, 64, 64, False, True, True): (1, 512, 2, 4), + (2048, 2048, 131072, 64, 64, True, False, True): (2, 1024, 3, 4), + (2048, 2048, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (2048, 2048, 131072, 128, 128, True, False, True): (1, 1024, 1, 4), + (3072, 768, 256, 16, 16, False, True, True): (6, 4, 1, 4), + (3072, 768, 256, 16, 16, True, False, True): (2, 1, 5, 2), + (3072, 768, 256, 32, 32, False, True, True): (1, 4, 1, 8), + (3072, 768, 256, 32, 32, True, False, True): (4, 2, 4, 4), + (3072, 768, 256, 64, 64, False, True, True): (1, 2, 3, 4), + (3072, 768, 256, 64, 64, True, False, True): (3, 4, 3, 4), + (3072, 768, 256, 128, 128, False, True, True): (1, 2, 3, 8), + (3072, 768, 256, 128, 128, True, False, True): (3, 2, 3, 8), + (3072, 768, 512, 16, 16, False, True, True): (1, 4, 1, 4), + (3072, 768, 512, 16, 16, True, False, True): (3, 4, 4, 1), + (3072, 768, 512, 32, 32, False, True, True): (5, 8, 1, 4), + (3072, 768, 512, 32, 32, True, False, True): (3, 4, 4, 2), + (3072, 768, 512, 64, 64, False, True, True): (1, 8, 1, 4), + (3072, 768, 512, 64, 64, True, False, True): (1, 4, 3, 4), + (3072, 768, 512, 128, 128, False, True, True): (3, 4, 3, 8), + (3072, 768, 512, 128, 128, True, False, True): (1, 4, 3, 8), + (3072, 768, 1024, 16, 16, False, True, True): (1, 8, 1, 4), + (3072, 768, 1024, 16, 16, True, False, True): (3, 4, 3, 1), + (3072, 768, 1024, 32, 32, False, True, True): (1, 16, 1, 4), + (3072, 768, 1024, 32, 32, 
True, False, True): (1, 4, 3, 8), + (3072, 768, 1024, 64, 64, False, True, True): (8, 16, 3, 2), + (3072, 768, 1024, 64, 64, True, False, True): (1, 4, 3, 4), + (3072, 768, 1024, 128, 128, False, True, True): (2, 8, 3, 8), + (3072, 768, 1024, 128, 128, True, False, True): (3, 8, 2, 4), + (3072, 768, 2048, 16, 16, False, True, True): (1, 8, 1, 4), + (3072, 768, 2048, 16, 16, True, False, True): (6, 8, 4, 4), + (3072, 768, 2048, 32, 32, False, True, True): (1, 16, 1, 8), + (3072, 768, 2048, 32, 32, True, False, True): (6, 8, 3, 4), + (3072, 768, 2048, 64, 64, False, True, True): (8, 16, 3, 4), + (3072, 768, 2048, 64, 64, True, False, True): (3, 16, 3, 4), + (3072, 768, 2048, 128, 128, False, True, True): (1, 16, 3, 8), + (3072, 768, 2048, 128, 128, True, False, True): (2, 16, 2, 4), + (3072, 768, 4096, 16, 16, False, True, True): (1, 16, 1, 4), + (3072, 768, 4096, 16, 16, True, False, True): (4, 32, 4, 2), + (3072, 768, 4096, 32, 32, False, True, True): (1, 32, 1, 8), + (3072, 768, 4096, 32, 32, True, False, True): (4, 16, 3, 4), + (3072, 768, 4096, 64, 64, False, True, True): (2, 32, 1, 4), + (3072, 768, 4096, 64, 64, True, False, True): (2, 16, 2, 4), + (3072, 768, 4096, 128, 128, False, True, True): (2, 32, 1, 16), + (3072, 768, 4096, 128, 128, True, False, True): (3, 32, 2, 4), + (3072, 768, 8192, 16, 16, False, True, True): (2, 32, 1, 4), + (3072, 768, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (3072, 768, 8192, 32, 32, False, True, True): (2, 32, 1, 4), + (3072, 768, 8192, 32, 32, True, False, True): (6, 32, 3, 4), + (3072, 768, 8192, 64, 64, False, True, True): (2, 64, 1, 4), + (3072, 768, 8192, 64, 64, True, False, True): (2, 32, 2, 4), + (3072, 768, 8192, 128, 128, False, True, True): (3, 64, 1, 4), + (3072, 768, 8192, 128, 128, True, False, True): (2, 64, 2, 4), + (3072, 768, 16384, 16, 16, False, True, True): (1, 64, 1, 4), + (3072, 768, 16384, 16, 16, True, False, True): (1, 64, 1, 1), + (3072, 768, 16384, 32, 32, False, True, True): (2, 64, 1, 4), + (3072, 768, 16384, 32, 32, True, False, True): (4, 64, 3, 4), + (3072, 768, 16384, 64, 64, False, True, True): (2, 128, 1, 4), + (3072, 768, 16384, 64, 64, True, False, True): (4, 64, 2, 4), + (3072, 768, 16384, 128, 128, False, True, True): (3, 128, 1, 4), + (3072, 768, 16384, 128, 128, True, False, True): (1, 128, 2, 4), + (3072, 768, 32768, 16, 16, False, True, True): (1, 128, 1, 4), + (3072, 768, 32768, 16, 16, True, False, True): (8, 256, 3, 2), + (3072, 768, 32768, 32, 32, False, True, True): (2, 128, 1, 4), + (3072, 768, 32768, 32, 32, True, False, True): (8, 128, 3, 4), + (3072, 768, 32768, 64, 64, False, True, True): (1, 256, 1, 4), + (3072, 768, 32768, 64, 64, True, False, True): (8, 128, 2, 4), + (3072, 768, 32768, 128, 128, False, True, True): (3, 256, 1, 4), + (3072, 768, 32768, 128, 128, True, False, True): (3, 256, 2, 4), + (3072, 768, 50432, 16, 16, False, True, True): (1, 197, 1, 4), + (3072, 768, 50432, 16, 16, True, False, True): (7, 197, 4, 1), + (3072, 768, 50432, 32, 32, False, True, True): (2, 197, 1, 4), + (3072, 768, 50432, 32, 32, True, False, True): (10, 197, 3, 4), + (3072, 768, 50432, 64, 64, False, True, True): (1, 394, 1, 4), + (3072, 768, 50432, 64, 64, True, False, True): (3, 197, 2, 4), + (3072, 768, 50432, 128, 128, False, True, True): (3, 394, 1, 4), + (3072, 768, 50432, 128, 128, True, False, True): (2, 394, 2, 4), + (3072, 768, 65536, 16, 16, False, True, True): (1, 256, 1, 4), + (3072, 768, 65536, 16, 16, True, False, True): (15, 256, 4, 1), + (3072, 768, 65536, 32, 32, False, True, 
True): (2, 256, 1, 4), + (3072, 768, 65536, 32, 32, True, False, True): (10, 256, 3, 4), + (3072, 768, 65536, 64, 64, False, True, True): (1, 512, 1, 4), + (3072, 768, 65536, 64, 64, True, False, True): (3, 256, 2, 4), + (3072, 768, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (3072, 768, 65536, 128, 128, True, False, True): (3, 512, 2, 4), + (3072, 768, 131072, 16, 16, False, True, True): (1, 512, 1, 4), + (3072, 768, 131072, 16, 16, True, False, True): (15, 512, 4, 1), + (3072, 768, 131072, 32, 32, False, True, True): (2, 512, 1, 4), + (3072, 768, 131072, 32, 32, True, False, True): (9, 512, 3, 4), + (3072, 768, 131072, 64, 64, False, True, True): (1, 1024, 1, 4), + (3072, 768, 131072, 64, 64, True, False, True): (3, 512, 2, 4), + (3072, 768, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (3072, 768, 131072, 128, 128, True, False, True): (3, 1024, 2, 4), + (3072, 3072, 256, 16, 16, False, True, True): (5, 4, 1, 4), + (3072, 3072, 256, 16, 16, True, False, True): (1, 2, 5, 2), + (3072, 3072, 256, 32, 32, False, True, True): (1, 4, 1, 8), + (3072, 3072, 256, 32, 32, True, False, True): (3, 4, 4, 2), + (3072, 3072, 256, 64, 64, False, True, True): (2, 4, 3, 4), + (3072, 3072, 256, 64, 64, True, False, True): (3, 4, 4, 4), + (3072, 3072, 256, 128, 128, False, True, True): (1, 2, 3, 8), + (3072, 3072, 256, 128, 128, True, False, True): (1, 2, 3, 8), + (3072, 3072, 512, 16, 16, False, True, True): (5, 4, 1, 2), + (3072, 3072, 512, 16, 16, True, False, True): (1, 2, 4, 4), + (3072, 3072, 512, 32, 32, False, True, True): (3, 8, 1, 4), + (3072, 3072, 512, 32, 32, True, False, True): (4, 2, 3, 4), + (3072, 3072, 512, 64, 64, False, True, True): (1, 8, 2, 2), + (3072, 3072, 512, 64, 64, True, False, True): (2, 4, 3, 4), + (3072, 3072, 512, 128, 128, False, True, True): (1, 4, 3, 8), + (3072, 3072, 512, 128, 128, True, False, True): (4, 4, 3, 8), + (3072, 3072, 1024, 16, 16, False, True, True): (1, 8, 1, 4), + (3072, 3072, 1024, 16, 16, True, False, True): (4, 8, 5, 2), + (3072, 3072, 1024, 32, 32, False, True, True): (1, 8, 1, 8), + (3072, 3072, 1024, 32, 32, True, False, True): (1, 4, 4, 4), + (3072, 3072, 1024, 64, 64, False, True, True): (3, 8, 3, 4), + (3072, 3072, 1024, 64, 64, True, False, True): (2, 4, 3, 4), + (3072, 3072, 1024, 128, 128, False, True, True): (3, 8, 1, 4), + (3072, 3072, 1024, 128, 128, True, False, True): (1, 8, 3, 8), + (3072, 3072, 2048, 16, 16, False, True, True): (1, 16, 1, 2), + (3072, 3072, 2048, 16, 16, True, False, True): (4, 16, 4, 2), + (3072, 3072, 2048, 32, 32, False, True, True): (1, 16, 1, 8), + (3072, 3072, 2048, 32, 32, True, False, True): (3, 8, 4, 4), + (3072, 3072, 2048, 64, 64, False, True, True): (3, 16, 3, 4), + (3072, 3072, 2048, 64, 64, True, False, True): (3, 8, 3, 4), + (3072, 3072, 2048, 128, 128, False, True, True): (4, 16, 3, 8), + (3072, 3072, 2048, 128, 128, True, False, True): (3, 16, 3, 8), + (3072, 3072, 4096, 16, 16, False, True, True): (1, 32, 1, 2), + (3072, 3072, 4096, 16, 16, True, False, True): (4, 32, 4, 2), + (3072, 3072, 4096, 32, 32, False, True, True): (1, 32, 1, 8), + (3072, 3072, 4096, 32, 32, True, False, True): (3, 16, 3, 4), + (3072, 3072, 4096, 64, 64, False, True, True): (1, 32, 3, 4), + (3072, 3072, 4096, 64, 64, True, False, True): (3, 16, 3, 4), + (3072, 3072, 4096, 128, 128, False, True, True): (1, 32, 3, 8), + (3072, 3072, 4096, 128, 128, True, False, True): (3, 32, 3, 8), + (3072, 3072, 8192, 16, 16, False, True, True): (1, 64, 1, 2), + (3072, 3072, 8192, 16, 16, True, False, True): (4, 64, 4, 2), 
+ (3072, 3072, 8192, 32, 32, False, True, True): (1, 64, 1, 8), + (3072, 3072, 8192, 32, 32, True, False, True): (8, 32, 3, 4), + (3072, 3072, 8192, 64, 64, False, True, True): (3, 64, 3, 4), + (3072, 3072, 8192, 64, 64, True, False, True): (2, 32, 3, 4), + (3072, 3072, 8192, 128, 128, False, True, True): (2, 64, 3, 8), + (3072, 3072, 8192, 128, 128, True, False, True): (1, 64, 3, 8), + (3072, 3072, 16384, 16, 16, False, True, True): (1, 128, 1, 2), + (3072, 3072, 16384, 16, 16, True, False, True): (4, 128, 4, 2), + (3072, 3072, 16384, 32, 32, False, True, True): (1, 64, 1, 2), + (3072, 3072, 16384, 32, 32, True, False, True): (4, 64, 3, 4), + (3072, 3072, 16384, 64, 64, False, True, True): (1, 128, 3, 4), + (3072, 3072, 16384, 64, 64, True, False, True): (4, 64, 3, 4), + (3072, 3072, 16384, 128, 128, False, True, True): (3, 128, 1, 4), + (3072, 3072, 16384, 128, 128, True, False, True): (1, 128, 3, 8), + (3072, 3072, 32768, 16, 16, False, True, True): (1, 256, 1, 2), + (3072, 3072, 32768, 16, 16, True, False, True): (8, 128, 4, 4), + (3072, 3072, 32768, 32, 32, False, True, True): (1, 256, 1, 8), + (3072, 3072, 32768, 32, 32, True, False, True): (5, 128, 3, 4), + (3072, 3072, 32768, 64, 64, False, True, True): (1, 256, 3, 4), + (3072, 3072, 32768, 64, 64, True, False, True): (1, 128, 3, 4), + (3072, 3072, 32768, 128, 128, False, True, True): (3, 256, 1, 4), + (3072, 3072, 32768, 128, 128, True, False, True): (3, 256, 2, 4), + (3072, 3072, 65536, 16, 16, False, True, True): (1, 512, 1, 2), + (3072, 3072, 65536, 16, 16, True, False, True): (7, 256, 4, 4), + (3072, 3072, 65536, 32, 32, False, True, True): (1, 256, 1, 2), + (3072, 3072, 65536, 32, 32, True, False, True): (5, 256, 3, 4), + (3072, 3072, 65536, 64, 64, False, True, True): (1, 512, 3, 4), + (3072, 3072, 65536, 64, 64, True, False, True): (3, 256, 3, 4), + (3072, 3072, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (3072, 3072, 65536, 128, 128, True, False, True): (3, 512, 2, 4), + (3072, 3072, 131072, 16, 16, False, True, True): (1, 1024, 1, 2), + (3072, 3072, 131072, 16, 16, True, False, True): (5, 512, 4, 4), + (3072, 3072, 131072, 32, 32, False, True, True): (1, 512, 1, 2), + (3072, 3072, 131072, 32, 32, True, False, True): (3, 512, 3, 4), + (3072, 3072, 131072, 64, 64, False, True, True): (1, 1024, 3, 4), + (3072, 3072, 131072, 64, 64, True, False, True): (3, 512, 3, 4), + (3072, 3072, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (3072, 3072, 131072, 128, 128, True, False, True): (1, 1024, 2, 4), + (4096, 4096, 256, 16, 16, False, True, True): (2, 2, 6, 4), + (4096, 4096, 256, 16, 16, True, False, True): (2, 2, 5, 4), + (4096, 4096, 256, 32, 32, False, True, True): (7, 2, 4, 4), + (4096, 4096, 256, 32, 32, True, False, True): (1, 2, 4, 4), + (4096, 4096, 256, 64, 64, False, True, True): (3, 4, 3, 4), + (4096, 4096, 256, 64, 64, True, False, True): (3, 4, 3, 4), + (4096, 4096, 256, 128, 128, False, True, True): (1, 2, 2, 8), + (4096, 4096, 256, 128, 128, True, False, True): (1, 2, 2, 8), + (4096, 4096, 512, 16, 16, False, True, True): (4, 2, 3, 4), + (4096, 4096, 512, 16, 16, True, False, True): (2, 4, 3, 2), + (4096, 4096, 512, 32, 32, False, True, True): (3, 4, 3, 4), + (4096, 4096, 512, 32, 32, True, False, True): (3, 4, 3, 2), + (4096, 4096, 512, 64, 64, False, True, True): (3, 4, 3, 4), + (4096, 4096, 512, 64, 64, True, False, True): (3, 4, 3, 4), + (4096, 4096, 512, 128, 128, False, True, True): (2, 4, 2, 8), + (4096, 4096, 512, 128, 128, True, False, True): (2, 4, 1, 4), + (4096, 4096, 1024, 16, 16, 
False, True, True): (2, 8, 3, 2), + (4096, 4096, 1024, 16, 16, True, False, True): (2, 8, 3, 2), + (4096, 4096, 1024, 32, 32, False, True, True): (3, 8, 3, 4), + (4096, 4096, 1024, 32, 32, True, False, True): (1, 8, 3, 2), + (4096, 4096, 1024, 64, 64, False, True, True): (1, 8, 3, 4), + (4096, 4096, 1024, 64, 64, True, False, True): (1, 8, 3, 4), + (4096, 4096, 1024, 128, 128, False, True, True): (2, 8, 2, 8), + (4096, 4096, 1024, 128, 128, True, False, True): (2, 8, 2, 8), + (4096, 4096, 2048, 16, 16, False, True, True): (2, 8, 4, 4), + (4096, 4096, 2048, 16, 16, True, False, True): (2, 8, 4, 4), + (4096, 4096, 2048, 32, 32, False, True, True): (4, 8, 4, 8), + (4096, 4096, 2048, 32, 32, True, False, True): (4, 8, 4, 8), + (4096, 4096, 2048, 64, 64, False, True, True): (1, 16, 3, 4), + (4096, 4096, 2048, 64, 64, True, False, True): (4, 16, 3, 4), + (4096, 4096, 2048, 128, 128, False, True, True): (2, 16, 2, 8), + (4096, 4096, 2048, 128, 128, True, False, True): (4, 16, 1, 4), + (4096, 4096, 4096, 16, 16, False, True, True): (4, 32, 4, 4), + (4096, 4096, 4096, 16, 16, True, False, True): (4, 32, 4, 2), + (4096, 4096, 4096, 32, 32, False, True, True): (4, 16, 4, 8), + (4096, 4096, 4096, 32, 32, True, False, True): (4, 16, 3, 8), + (4096, 4096, 4096, 64, 64, False, True, True): (1, 32, 3, 4), + (4096, 4096, 4096, 64, 64, True, False, True): (1, 32, 3, 4), + (4096, 4096, 4096, 128, 128, False, True, True): (3, 32, 1, 4), + (4096, 4096, 4096, 128, 128, True, False, True): (2, 32, 1, 4), + (4096, 4096, 8192, 16, 16, False, True, True): (4, 64, 4, 2), + (4096, 4096, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (4096, 4096, 8192, 32, 32, False, True, True): (4, 32, 4, 8), + (4096, 4096, 8192, 32, 32, True, False, True): (4, 32, 4, 8), + (4096, 4096, 8192, 64, 64, False, True, True): (2, 64, 3, 4), + (4096, 4096, 8192, 64, 64, True, False, True): (2, 64, 3, 4), + (4096, 4096, 8192, 128, 128, False, True, True): (3, 64, 1, 4), + (4096, 4096, 8192, 128, 128, True, False, True): (1, 64, 1, 4), + (4096, 4096, 16384, 16, 16, False, True, True): (4, 64, 3, 4), + (4096, 4096, 16384, 16, 16, True, False, True): (4, 64, 4, 4), + (4096, 4096, 16384, 32, 32, False, True, True): (4, 64, 4, 8), + (4096, 4096, 16384, 32, 32, True, False, True): (4, 64, 4, 8), + (4096, 4096, 16384, 64, 64, False, True, True): (1, 64, 2, 4), + (4096, 4096, 16384, 64, 64, True, False, True): (1, 64, 3, 8), + (4096, 4096, 16384, 128, 128, False, True, True): (3, 128, 1, 4), + (4096, 4096, 16384, 128, 128, True, False, True): (1, 128, 1, 4), + (4096, 4096, 32768, 16, 16, False, True, True): (8, 128, 3, 2), + (4096, 4096, 32768, 16, 16, True, False, True): (5, 128, 4, 4), + (4096, 4096, 32768, 32, 32, False, True, True): (3, 128, 4, 4), + (4096, 4096, 32768, 32, 32, True, False, True): (3, 128, 4, 8), + (4096, 4096, 32768, 64, 64, False, True, True): (1, 128, 2, 4), + (4096, 4096, 32768, 64, 64, True, False, True): (3, 256, 3, 4), + (4096, 4096, 32768, 128, 128, False, True, True): (3, 256, 1, 4), + (4096, 4096, 32768, 128, 128, True, False, True): (1, 256, 1, 4), + (4096, 4096, 50432, 16, 16, False, True, True): (1, 197, 1, 4), + (4096, 4096, 50432, 16, 16, True, False, True): (4, 197, 4, 1), + (4096, 4096, 50432, 32, 32, False, True, True): (1, 197, 1, 4), + (4096, 4096, 50432, 32, 32, True, False, True): (2, 197, 3, 4), + (4096, 4096, 50432, 64, 64, False, True, True): (1, 394, 3, 4), + (4096, 4096, 50432, 64, 64, True, False, True): (1, 197, 2, 4), + (4096, 4096, 50432, 128, 128, False, True, True): (3, 394, 1, 4), + (4096, 
4096, 50432, 128, 128, True, False, True): (1, 394, 2, 4), + (4096, 4096, 65536, 16, 16, False, True, True): (5, 256, 4, 4), + (4096, 4096, 65536, 16, 16, True, False, True): (5, 256, 4, 4), + (4096, 4096, 65536, 32, 32, False, True, True): (4, 256, 4, 8), + (4096, 4096, 65536, 32, 32, True, False, True): (4, 256, 3, 8), + (4096, 4096, 65536, 64, 64, False, True, True): (1, 256, 2, 4), + (4096, 4096, 65536, 64, 64, True, False, True): (1, 512, 3, 4), + (4096, 4096, 65536, 128, 128, False, True, True): (3, 512, 1, 4), + (4096, 4096, 65536, 128, 128, True, False, True): (1, 512, 1, 4), + (4096, 4096, 131072, 16, 16, False, True, True): (4, 512, 3, 4), + (4096, 4096, 131072, 16, 16, True, False, True): (5, 512, 4, 4), + (4096, 4096, 131072, 32, 32, False, True, True): (1, 512, 4, 8), + (4096, 4096, 131072, 32, 32, True, False, True): (4, 512, 4, 8), + (4096, 4096, 131072, 64, 64, False, True, True): (1, 512, 2, 4), + (4096, 4096, 131072, 64, 64, True, False, True): (1, 512, 2, 4), + (4096, 4096, 131072, 128, 128, False, True, True): (3, 1024, 1, 4), + (4096, 4096, 131072, 128, 128, True, False, True): (1, 1024, 1, 4), + (6144, 6144, 256, 16, 16, False, True, True): (1, 2, 1, 4), + (6144, 6144, 256, 16, 16, True, False, True): (1, 1, 4, 4), + (6144, 6144, 256, 32, 32, False, True, True): (3, 2, 1, 8), + (6144, 6144, 256, 32, 32, True, False, True): (2, 1, 3, 4), + (6144, 6144, 256, 64, 64, False, True, True): (2, 2, 3, 4), + (6144, 6144, 256, 64, 64, True, False, True): (6, 2, 4, 4), + (6144, 6144, 256, 128, 128, False, True, True): (2, 2, 3, 8), + (6144, 6144, 256, 128, 128, True, False, True): (1, 2, 3, 8), + (6144, 6144, 512, 16, 16, False, True, True): (4, 4, 1, 4), + (6144, 6144, 512, 16, 16, True, False, True): (3, 2, 3, 1), + (6144, 6144, 512, 32, 32, False, True, True): (1, 8, 1, 4), + (6144, 6144, 512, 32, 32, True, False, True): (2, 2, 3, 8), + (6144, 6144, 512, 64, 64, False, True, True): (4, 4, 3, 4), + (6144, 6144, 512, 64, 64, True, False, True): (6, 2, 3, 4), + (6144, 6144, 512, 128, 128, False, True, True): (3, 4, 1, 4), + (6144, 6144, 512, 128, 128, True, False, True): (4, 4, 3, 8), + (6144, 6144, 1024, 16, 16, False, True, True): (1, 8, 1, 2), + (6144, 6144, 1024, 16, 16, True, False, True): (4, 8, 4, 2), + (6144, 6144, 1024, 32, 32, False, True, True): (1, 8, 4, 2), + (6144, 6144, 1024, 32, 32, True, False, True): (1, 8, 4, 2), + (6144, 6144, 1024, 64, 64, False, True, True): (4, 8, 3, 4), + (6144, 6144, 1024, 64, 64, True, False, True): (1, 4, 3, 4), + (6144, 6144, 1024, 128, 128, False, True, True): (3, 8, 1, 4), + (6144, 6144, 1024, 128, 128, True, False, True): (1, 8, 3, 8), + (6144, 6144, 2048, 16, 16, False, True, True): (4, 4, 1, 4), + (6144, 6144, 2048, 16, 16, True, False, True): (2, 8, 4, 4), + (6144, 6144, 2048, 32, 32, False, True, True): (4, 8, 3, 4), + (6144, 6144, 2048, 32, 32, True, False, True): (2, 8, 3, 4), + (6144, 6144, 2048, 64, 64, False, True, True): (4, 16, 3, 4), + (6144, 6144, 2048, 64, 64, True, False, True): (2, 8, 3, 4), + (6144, 6144, 2048, 128, 128, False, True, True): (3, 16, 1, 4), + (6144, 6144, 2048, 128, 128, True, False, True): (4, 16, 3, 8), + (6144, 6144, 4096, 16, 16, False, True, True): (4, 8, 1, 4), + (6144, 6144, 4096, 16, 16, True, False, True): (4, 32, 4, 2), + (6144, 6144, 4096, 32, 32, False, True, True): (4, 16, 1, 2), + (6144, 6144, 4096, 32, 32, True, False, True): (2, 8, 3, 8), + (6144, 6144, 4096, 64, 64, False, True, True): (4, 32, 3, 4), + (6144, 6144, 4096, 64, 64, True, False, True): (4, 16, 3, 4), + (6144, 6144, 4096, 
128, 128, False, True, True): (6, 32, 1, 4), + (6144, 6144, 4096, 128, 128, True, False, True): (4, 32, 3, 8), + (6144, 6144, 8192, 16, 16, False, True, True): (2, 16, 1, 2), + (6144, 6144, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (6144, 6144, 8192, 32, 32, False, True, True): (4, 32, 1, 2), + (6144, 6144, 8192, 32, 32, True, False, True): (4, 32, 3, 4), + (6144, 6144, 8192, 64, 64, False, True, True): (4, 64, 3, 4), + (6144, 6144, 8192, 64, 64, True, False, True): (4, 32, 3, 4), + (6144, 6144, 8192, 128, 128, False, True, True): (6, 64, 1, 4), + (6144, 6144, 8192, 128, 128, True, False, True): (4, 64, 3, 8), + (6144, 6144, 16384, 16, 16, False, True, True): (2, 32, 1, 2), + (6144, 6144, 16384, 16, 16, True, False, True): (4, 64, 4, 4), + (6144, 6144, 16384, 32, 32, False, True, True): (4, 64, 1, 2), + (6144, 6144, 16384, 32, 32, True, False, True): (4, 64, 3, 4), + (6144, 6144, 16384, 64, 64, False, True, True): (4, 128, 3, 4), + (6144, 6144, 16384, 64, 64, True, False, True): (1, 32, 3, 8), + (6144, 6144, 16384, 128, 128, False, True, True): (4, 128, 1, 4), + (6144, 6144, 16384, 128, 128, True, False, True): (4, 128, 3, 8), + (6144, 6144, 32768, 16, 16, False, True, True): (2, 64, 1, 2), + (6144, 6144, 32768, 16, 16, True, False, True): (5, 128, 4, 1), + (6144, 6144, 32768, 32, 32, False, True, True): (4, 128, 1, 2), + (6144, 6144, 32768, 32, 32, True, False, True): (3, 128, 3, 4), + (6144, 6144, 32768, 64, 64, False, True, True): (4, 256, 3, 4), + (6144, 6144, 32768, 64, 64, True, False, True): (2, 64, 3, 8), + (6144, 6144, 32768, 128, 128, False, True, True): (8, 256, 1, 4), + (6144, 6144, 32768, 128, 128, True, False, True): (4, 256, 3, 8), + (6144, 6144, 65536, 16, 16, False, True, True): (2, 128, 1, 2), + (6144, 6144, 65536, 16, 16, True, False, True): (5, 256, 4, 1), + (6144, 6144, 65536, 32, 32, False, True, True): (4, 256, 1, 2), + (6144, 6144, 65536, 32, 32, True, False, True): (2, 256, 3, 4), + (6144, 6144, 65536, 64, 64, False, True, True): (4, 512, 3, 4), + (6144, 6144, 65536, 64, 64, True, False, True): (1, 128, 3, 8), + (6144, 6144, 65536, 128, 128, False, True, True): (4, 512, 1, 4), + (6144, 6144, 65536, 128, 128, True, False, True): (4, 512, 3, 8), + (6144, 6144, 131072, 16, 16, False, True, True): (2, 256, 1, 2), + (6144, 6144, 131072, 16, 16, True, False, True): (3, 512, 4, 4), + (6144, 6144, 131072, 32, 32, False, True, True): (4, 512, 1, 2), + (6144, 6144, 131072, 32, 32, True, False, True): (4, 512, 3, 4), + (6144, 6144, 131072, 64, 64, False, True, True): (4, 1024, 3, 4), + (6144, 6144, 131072, 64, 64, True, False, True): (2, 256, 3, 8), + (6144, 6144, 131072, 128, 128, False, True, True): (4, 1024, 1, 4), + (6144, 6144, 131072, 128, 128, True, False, True): (4, 1024, 3, 8), + (8192, 8192, 256, 16, 16, False, True, True): (2, 2, 6, 4), + (8192, 8192, 256, 16, 16, True, False, True): (2, 4, 2, 2), + (8192, 8192, 256, 32, 32, False, True, True): (4, 2, 3, 4), + (8192, 8192, 256, 32, 32, True, False, True): (4, 2, 3, 4), + (8192, 8192, 256, 64, 64, False, True, True): (2, 2, 3, 8), + (8192, 8192, 256, 64, 64, True, False, True): (6, 2, 3, 8), + (8192, 8192, 256, 128, 128, False, True, True): (3, 2, 1, 4), + (8192, 8192, 256, 128, 128, True, False, True): (1, 2, 1, 4), + (8192, 8192, 512, 16, 16, False, True, True): (4, 4, 3, 2), + (8192, 8192, 512, 16, 16, True, False, True): (4, 4, 3, 4), + (8192, 8192, 512, 32, 32, False, True, True): (1, 4, 3, 4), + (8192, 8192, 512, 32, 32, True, False, True): (5, 4, 3, 2), + (8192, 8192, 512, 64, 64, False, True, True): 
(1, 4, 3, 4), + (8192, 8192, 512, 64, 64, True, False, True): (2, 2, 3, 8), + (8192, 8192, 512, 128, 128, False, True, True): (4, 4, 2, 8), + (8192, 8192, 512, 128, 128, True, False, True): (4, 4, 2, 8), + (8192, 8192, 1024, 16, 16, False, True, True): (4, 8, 4, 4), + (8192, 8192, 1024, 16, 16, True, False, True): (4, 8, 4, 4), + (8192, 8192, 1024, 32, 32, False, True, True): (2, 4, 4, 8), + (8192, 8192, 1024, 32, 32, True, False, True): (1, 4, 3, 4), + (8192, 8192, 1024, 64, 64, False, True, True): (4, 8, 3, 4), + (8192, 8192, 1024, 64, 64, True, False, True): (2, 8, 3, 4), + (8192, 8192, 1024, 128, 128, False, True, True): (4, 8, 2, 8), + (8192, 8192, 1024, 128, 128, True, False, True): (4, 8, 1, 4), + (8192, 8192, 2048, 16, 16, False, True, True): (2, 8, 4, 4), + (8192, 8192, 2048, 16, 16, True, False, True): (2, 8, 4, 4), + (8192, 8192, 2048, 32, 32, False, True, True): (2, 8, 4, 8), + (8192, 8192, 2048, 32, 32, True, False, True): (2, 8, 4, 8), + (8192, 8192, 2048, 64, 64, False, True, True): (4, 8, 2, 4), + (8192, 8192, 2048, 64, 64, True, False, True): (4, 16, 3, 4), + (8192, 8192, 2048, 128, 128, False, True, True): (6, 16, 1, 4), + (8192, 8192, 2048, 128, 128, True, False, True): (4, 16, 1, 4), + (8192, 8192, 4096, 16, 16, False, True, True): (4, 32, 4, 2), + (8192, 8192, 4096, 16, 16, True, False, True): (4, 32, 4, 2), + (8192, 8192, 4096, 32, 32, False, True, True): (2, 16, 4, 8), + (8192, 8192, 4096, 32, 32, True, False, True): (4, 16, 4, 8), + (8192, 8192, 4096, 64, 64, False, True, True): (4, 16, 2, 4), + (8192, 8192, 4096, 64, 64, True, False, True): (4, 16, 2, 4), + (8192, 8192, 4096, 128, 128, False, True, True): (6, 32, 1, 4), + (8192, 8192, 4096, 128, 128, True, False, True): (4, 32, 1, 4), + (8192, 8192, 8192, 16, 16, False, True, True): (4, 64, 4, 2), + (8192, 8192, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (8192, 8192, 8192, 32, 32, False, True, True): (2, 32, 4, 8), + (8192, 8192, 8192, 32, 32, True, False, True): (2, 32, 4, 8), + (8192, 8192, 8192, 64, 64, False, True, True): (2, 32, 2, 4), + (8192, 8192, 8192, 64, 64, True, False, True): (4, 32, 2, 4), + (8192, 8192, 8192, 128, 128, False, True, True): (6, 64, 1, 4), + (8192, 8192, 8192, 128, 128, True, False, True): (4, 64, 1, 4), + (8192, 8192, 16384, 16, 16, False, True, True): (4, 64, 3, 4), + (8192, 8192, 16384, 16, 16, True, False, True): (4, 64, 4, 4), + (8192, 8192, 16384, 32, 32, False, True, True): (4, 64, 4, 8), + (8192, 8192, 16384, 32, 32, True, False, True): (4, 64, 4, 8), + (8192, 8192, 16384, 64, 64, False, True, True): (4, 64, 2, 4), + (8192, 8192, 16384, 64, 64, True, False, True): (4, 64, 3, 8), + (8192, 8192, 16384, 128, 128, False, True, True): (6, 128, 1, 4), + (8192, 8192, 16384, 128, 128, True, False, True): (4, 128, 1, 4), + (8192, 8192, 32768, 16, 16, False, True, True): (3, 128, 4, 4), + (8192, 8192, 32768, 16, 16, True, False, True): (3, 128, 4, 4), + (8192, 8192, 32768, 32, 32, False, True, True): (2, 128, 4, 8), + (8192, 8192, 32768, 32, 32, True, False, True): (2, 128, 4, 8), + (8192, 8192, 32768, 64, 64, False, True, True): (2, 128, 2, 4), + (8192, 8192, 32768, 64, 64, True, False, True): (2, 128, 3, 8), + (8192, 8192, 32768, 128, 128, False, True, True): (6, 256, 1, 4), + (8192, 8192, 32768, 128, 128, True, False, True): (4, 256, 1, 4), + (8192, 8192, 50432, 16, 16, False, True, True): (1, 197, 1, 1), + (8192, 8192, 50432, 16, 16, True, False, True): (3, 197, 4, 1), + (8192, 8192, 50432, 32, 32, False, True, True): (2, 197, 1, 4), + (8192, 8192, 50432, 32, 32, True, False, 
True): (2, 197, 3, 4), + (8192, 8192, 50432, 64, 64, False, True, True): (2, 394, 3, 4), + (8192, 8192, 65536, 16, 16, False, True, True): (3, 256, 4, 4), + (8192, 8192, 65536, 16, 16, True, False, True): (4, 256, 4, 4), + (8192, 8192, 65536, 32, 32, False, True, True): (2, 256, 4, 8), + (8192, 8192, 65536, 32, 32, True, False, True): (2, 256, 3, 8), + (8192, 8192, 65536, 64, 64, False, True, True): (2, 256, 2, 4), + (8192, 8192, 65536, 64, 64, True, False, True): (4, 256, 3, 8), + (8192, 8192, 65536, 128, 128, False, True, True): (6, 512, 1, 4), + (8192, 8192, 65536, 128, 128, True, False, True): (4, 512, 1, 4), + (8192, 8192, 131072, 16, 16, False, True, True): (4, 512, 4, 4), + (8192, 8192, 131072, 16, 16, True, False, True): (3, 512, 4, 4), + (8192, 8192, 131072, 32, 32, False, True, True): (2, 512, 4, 8), + (8192, 8192, 131072, 32, 32, True, False, True): (2, 512, 4, 8), + (8192, 8192, 131072, 64, 64, False, True, True): (2, 512, 2, 4), + (8192, 8192, 131072, 64, 64, True, False, True): (2, 512, 2, 4), + (8192, 8192, 131072, 128, 128, False, True, True): (4, 1024, 1, 4), + (8192, 8192, 131072, 128, 128, True, False, True): (4, 1024, 1, 4), + (12288, 12288, 256, 16, 16, False, True, True): (4, 2, 1, 4), + (12288, 12288, 256, 16, 16, True, False, True): (1, 1, 3, 1), + (12288, 12288, 256, 32, 32, False, True, True): (4, 4, 1, 4), + (12288, 12288, 256, 32, 32, True, False, True): (2, 1, 3, 2), + (12288, 12288, 256, 64, 64, False, True, True): (4, 2, 3, 4), + (12288, 12288, 256, 64, 64, True, False, True): (3, 1, 3, 4), + (12288, 12288, 256, 128, 128, False, True, True): (6, 2, 1, 4), + (12288, 12288, 256, 128, 128, True, False, True): (4, 2, 3, 8), + (12288, 12288, 512, 16, 16, False, True, True): (4, 4, 1, 2), + (12288, 12288, 512, 16, 16, True, False, True): (4, 4, 4, 2), + (12288, 12288, 512, 32, 32, False, True, True): (4, 4, 4, 2), + (12288, 12288, 512, 32, 32, True, False, True): (2, 2, 3, 8), + (12288, 12288, 512, 64, 64, False, True, True): (4, 4, 3, 4), + (12288, 12288, 512, 64, 64, True, False, True): (8, 2, 3, 4), + (12288, 12288, 512, 128, 128, False, True, True): (4, 4, 3, 8), + (12288, 12288, 512, 128, 128, True, False, True): (4, 4, 3, 8), + (12288, 12288, 1024, 16, 16, False, True, True): (4, 8, 1, 2), + (12288, 12288, 1024, 16, 16, True, False, True): (2, 4, 4, 4), + (12288, 12288, 1024, 32, 32, False, True, True): (4, 4, 3, 4), + (12288, 12288, 1024, 32, 32, True, False, True): (1, 4, 3, 4), + (12288, 12288, 1024, 64, 64, False, True, True): (4, 8, 3, 4), + (12288, 12288, 1024, 64, 64, True, False, True): (2, 4, 3, 4), + (12288, 12288, 1024, 128, 128, False, True, True): (4, 8, 3, 8), + (12288, 12288, 1024, 128, 128, True, False, True): (4, 8, 3, 8), + (12288, 12288, 2048, 16, 16, False, True, True): (2, 4, 1, 4), + (12288, 12288, 2048, 16, 16, True, False, True): (2, 8, 4, 4), + (12288, 12288, 2048, 32, 32, False, True, True): (4, 8, 1, 2), + (12288, 12288, 2048, 32, 32, True, False, True): (2, 8, 4, 8), + (12288, 12288, 2048, 64, 64, False, True, True): (4, 16, 3, 4), + (12288, 12288, 2048, 64, 64, True, False, True): (2, 8, 3, 4), + (12288, 12288, 2048, 128, 128, False, True, True): (4, 16, 3, 8), + (12288, 12288, 2048, 128, 128, True, False, True): (4, 16, 3, 8), + (12288, 12288, 4096, 16, 16, False, True, True): (2, 8, 1, 4), + (12288, 12288, 4096, 16, 16, True, False, True): (2, 16, 4, 4), + (12288, 12288, 4096, 32, 32, False, True, True): (2, 16, 1, 2), + (12288, 12288, 4096, 32, 32, True, False, True): (2, 16, 3, 4), + (12288, 12288, 4096, 64, 64, False, True, 
True): (4, 32, 3, 4), + (12288, 12288, 4096, 64, 64, True, False, True): (2, 16, 3, 4), + (12288, 12288, 4096, 128, 128, False, True, True): (4, 32, 1, 4), + (12288, 12288, 4096, 128, 128, True, False, True): (4, 32, 3, 8), + (12288, 12288, 8192, 16, 16, False, True, True): (2, 32, 1, 1), + (12288, 12288, 8192, 16, 16, True, False, True): (4, 64, 4, 2), + (12288, 12288, 8192, 32, 32, False, True, True): (2, 32, 1, 2), + (12288, 12288, 8192, 32, 32, True, False, True): (2, 32, 3, 2), + (12288, 12288, 8192, 64, 64, False, True, True): (4, 64, 3, 4), + (12288, 12288, 8192, 64, 64, True, False, True): (2, 32, 3, 4), + (12288, 12288, 8192, 128, 128, False, True, True): (4, 64, 3, 8), + (12288, 12288, 8192, 128, 128, True, False, True): (2, 64, 3, 8), + (12288, 12288, 16384, 16, 16, False, True, True): (4, 128, 1, 2), + (12288, 12288, 16384, 16, 16, True, False, True): (4, 128, 4, 2), + (12288, 12288, 16384, 32, 32, False, True, True): (2, 64, 1, 2), + (12288, 12288, 16384, 32, 32, True, False, True): (2, 64, 3, 4), + (12288, 12288, 16384, 64, 64, False, True, True): (4, 128, 3, 4), + (12288, 12288, 16384, 64, 64, True, False, True): (2, 64, 3, 4), + (12288, 12288, 16384, 128, 128, False, True, True): (4, 128, 1, 4), + (12288, 12288, 16384, 128, 128, True, False, True): (4, 128, 3, 8), + (12288, 12288, 32768, 16, 16, False, True, True): (2, 128, 1, 1), + (12288, 12288, 32768, 16, 16, True, False, True): (3, 128, 4, 1), + (12288, 12288, 32768, 32, 32, False, True, True): (2, 128, 1, 2), + (12288, 12288, 32768, 32, 32, True, False, True): (2, 128, 3, 2), + (12288, 12288, 32768, 64, 64, False, True, True): (4, 256, 3, 4), + (12288, 12288, 32768, 64, 64, True, False, True): (1, 64, 3, 8), + (12288, 12288, 32768, 128, 128, False, True, True): (4, 256, 3, 8), + (12288, 12288, 32768, 128, 128, True, False, True): (4, 256, 3, 8), + (12288, 12288, 65536, 16, 16, False, True, True): (4, 512, 1, 2), + (12288, 12288, 65536, 16, 16, True, False, True): (3, 256, 4, 1), + (12288, 12288, 65536, 32, 32, False, True, True): (2, 256, 1, 2), + (12288, 12288, 65536, 32, 32, True, False, True): (2, 256, 3, 2), + (12288, 12288, 65536, 64, 64, False, True, True): (4, 512, 3, 4), + (12288, 12288, 65536, 64, 64, True, False, True): (2, 256, 3, 4), + (12288, 12288, 65536, 128, 128, False, True, True): (4, 512, 1, 4), + (12288, 12288, 65536, 128, 128, True, False, True): (4, 512, 3, 8), + (12288, 12288, 131072, 16, 16, False, True, True): (2, 512, 1, 1), + (12288, 12288, 131072, 16, 16, True, False, True): (2, 512, 4, 4), + (12288, 12288, 131072, 32, 32, False, True, True): (2, 512, 1, 2), + (12288, 12288, 131072, 32, 32, True, False, True): (2, 512, 3, 4), + (12288, 12288, 131072, 64, 64, False, True, True): (4, 1024, 3, 4), + (12288, 12288, 131072, 64, 64, True, False, True): (2, 512, 3, 4), + (12288, 12288, 131072, 128, 128, False, True, True): (4, 1024, 3, 8), + (12288, 12288, 131072, 128, 128, True, False, True): (4, 1024, 3, 8), + (16384, 16384, 256, 16, 16, False, True, True): (2, 2, 3, 2), + (16384, 16384, 256, 16, 16, True, False, True): (2, 2, 6, 4), + (16384, 16384, 256, 32, 32, False, True, True): (4, 2, 3, 4), + (16384, 16384, 256, 32, 32, True, False, True): (4, 2, 3, 2), + (16384, 16384, 256, 64, 64, False, True, True): (2, 2, 5, 4), + (16384, 16384, 256, 64, 64, True, False, True): (2, 2, 3, 8), + (16384, 16384, 256, 128, 128, False, True, True): (4, 2, 2, 8), + (16384, 16384, 256, 128, 128, True, False, True): (2, 2, 1, 4), + (16384, 16384, 512, 16, 16, False, True, True): (1, 2, 4, 4), + (16384, 16384, 
512, 16, 16, True, False, True): (1, 2, 4, 4), + (16384, 16384, 512, 32, 32, False, True, True): (2, 2, 3, 8), + (16384, 16384, 512, 32, 32, True, False, True): (2, 2, 4, 8), + (16384, 16384, 512, 64, 64, False, True, True): (4, 4, 3, 4), + (16384, 16384, 512, 64, 64, True, False, True): (2, 4, 3, 4), + (16384, 16384, 512, 128, 128, False, True, True): (4, 4, 2, 8), + (16384, 16384, 512, 128, 128, True, False, True): (4, 4, 2, 8), + (16384, 16384, 1024, 16, 16, False, True, True): (4, 8, 4, 4), + (16384, 16384, 1024, 16, 16, True, False, True): (2, 4, 4, 4), + (16384, 16384, 1024, 32, 32, False, True, True): (2, 4, 4, 8), + (16384, 16384, 1024, 32, 32, True, False, True): (2, 4, 4, 8), + (16384, 16384, 1024, 64, 64, False, True, True): (4, 4, 2, 4), + (16384, 16384, 1024, 64, 64, True, False, True): (2, 4, 2, 4), + (16384, 16384, 1024, 128, 128, False, True, True): (6, 8, 1, 4), + (16384, 16384, 1024, 128, 128, True, False, True): (4, 8, 1, 4), + (16384, 16384, 2048, 16, 16, False, True, True): (2, 8, 4, 4), + (16384, 16384, 2048, 16, 16, True, False, True): (2, 8, 4, 4), + (16384, 16384, 2048, 32, 32, False, True, True): (2, 8, 4, 8), + (16384, 16384, 2048, 32, 32, True, False, True): (2, 8, 4, 8), + (16384, 16384, 2048, 64, 64, False, True, True): (2, 8, 2, 4), + (16384, 16384, 2048, 64, 64, True, False, True): (2, 8, 2, 4), + (16384, 16384, 2048, 128, 128, False, True, True): (4, 16, 2, 8), + (16384, 16384, 2048, 128, 128, True, False, True): (4, 16, 1, 4), + (16384, 16384, 4096, 16, 16, False, True, True): (2, 16, 4, 4), + (16384, 16384, 4096, 16, 16, True, False, True): (2, 16, 4, 4), + (16384, 16384, 4096, 32, 32, False, True, True): (1, 16, 4, 8), + (16384, 16384, 4096, 32, 32, True, False, True): (2, 16, 3, 4), + (16384, 16384, 4096, 64, 64, False, True, True): (1, 16, 2, 4), + (16384, 16384, 4096, 64, 64, True, False, True): (2, 16, 2, 4), + (16384, 16384, 4096, 128, 128, False, True, True): (4, 32, 2, 8), + (16384, 16384, 4096, 128, 128, True, False, True): (4, 32, 1, 4), + (16384, 16384, 8192, 16, 16, False, True, True): (2, 64, 4, 2), + (16384, 16384, 8192, 16, 16, True, False, True): (2, 64, 4, 2), + (16384, 16384, 8192, 32, 32, False, True, True): (2, 32, 4, 8), + (16384, 16384, 8192, 32, 32, True, False, True): (2, 32, 4, 8), + (16384, 16384, 8192, 64, 64, False, True, True): (2, 32, 2, 4), + (16384, 16384, 8192, 64, 64, True, False, True): (2, 32, 4, 8), + (16384, 16384, 8192, 128, 128, False, True, True): (4, 64, 2, 8), + (16384, 16384, 8192, 128, 128, True, False, True): (4, 64, 1, 4), + (16384, 16384, 16384, 16, 16, False, True, True): (1, 64, 4, 4), + (16384, 16384, 16384, 16, 16, True, False, True): (1, 64, 4, 4), + (16384, 16384, 16384, 32, 32, False, True, True): (1, 64, 4, 8), + (16384, 16384, 16384, 32, 32, True, False, True): (1, 64, 4, 8), + (16384, 16384, 16384, 64, 64, False, True, True): (1, 64, 2, 4), + (16384, 16384, 16384, 64, 64, True, False, True): (1, 64, 3, 8), + (16384, 16384, 16384, 128, 128, False, True, True): (4, 128, 1, 4), + (16384, 16384, 16384, 128, 128, True, False, True): (4, 128, 1, 4), + (16384, 16384, 32768, 16, 16, False, True, True): (1, 128, 4, 4), + (16384, 16384, 32768, 16, 16, True, False, True): (1, 128, 4, 4), + (16384, 16384, 32768, 32, 32, False, True, True): (1, 128, 3, 4), + (16384, 16384, 32768, 32, 32, True, False, True): (1, 128, 3, 8), + (16384, 16384, 32768, 64, 64, False, True, True): (2, 128, 2, 4), + (16384, 16384, 32768, 64, 64, True, False, True): (1, 128, 4, 8), + (16384, 16384, 32768, 128, 128, False, True, True): 
(4, 256, 2, 8), + (16384, 16384, 32768, 128, 128, True, False, True): (4, 256, 1, 4), + (16384, 16384, 65536, 16, 16, False, True, True): (1, 256, 3, 4), + (16384, 16384, 65536, 16, 16, True, False, True): (1, 256, 4, 4), + (16384, 16384, 65536, 32, 32, False, True, True): (1, 256, 4, 8), + (16384, 16384, 65536, 32, 32, True, False, True): (1, 256, 3, 4), + (16384, 16384, 65536, 64, 64, False, True, True): (2, 256, 2, 4), + (16384, 16384, 65536, 64, 64, True, False, True): (1, 256, 3, 8), + (16384, 16384, 65536, 128, 128, False, True, True): (4, 512, 2, 8), + (16384, 16384, 65536, 128, 128, True, False, True): (4, 512, 1, 4), + (16384, 16384, 131072, 16, 16, False, True, True): (1, 512, 4, 4), + (16384, 16384, 131072, 16, 16, True, False, True): (1, 512, 3, 2), + (16384, 16384, 131072, 32, 32, False, True, True): (1, 512, 4, 8), + (16384, 16384, 131072, 32, 32, True, False, True): (1, 512, 3, 2), + (16384, 16384, 131072, 64, 64, False, True, True): (1, 512, 2, 4), + (16384, 16384, 131072, 64, 64, True, False, True): (1, 512, 2, 4), + (16384, 16384, 131072, 128, 128, False, True, True): (4, 1024, 1, 4), + (16384, 16384, 131072, 128, 128, True, False, True): (4, 1024, 1, 4), + (24576, 24576, 256, 16, 16, False, True, True): (6, 2, 1, 2), + (24576, 24576, 256, 16, 16, True, False, True): (2, 2, 5, 4), + (24576, 24576, 256, 32, 32, False, True, True): (4, 4, 1, 4), + (24576, 24576, 256, 32, 32, True, False, True): (2, 2, 4, 2), + (24576, 24576, 256, 64, 64, False, True, True): (2, 2, 3, 4), + (24576, 24576, 256, 64, 64, True, False, True): (1, 1, 3, 4), + (24576, 24576, 256, 128, 128, False, True, True): (6, 2, 1, 4), + (24576, 24576, 256, 128, 128, True, False, True): (2, 2, 3, 8), + (24576, 24576, 512, 16, 16, False, True, True): (4, 4, 1, 2), + (24576, 24576, 512, 16, 16, True, False, True): (2, 2, 4, 4), + (24576, 24576, 512, 32, 32, False, True, True): (1, 2, 3, 4), + (24576, 24576, 512, 32, 32, True, False, True): (1, 2, 3, 4), + (24576, 24576, 512, 64, 64, False, True, True): (4, 4, 3, 4), + (24576, 24576, 512, 64, 64, True, False, True): (1, 2, 3, 4), + (24576, 24576, 512, 128, 128, False, True, True): (4, 4, 3, 8), + (24576, 24576, 512, 128, 128, True, False, True): (4, 4, 3, 8), + (24576, 24576, 1024, 16, 16, False, True, True): (2, 8, 1, 2), + (24576, 24576, 1024, 16, 16, True, False, True): (2, 4, 4, 4), + (24576, 24576, 1024, 32, 32, False, True, True): (2, 4, 1, 2), + (24576, 24576, 1024, 32, 32, True, False, True): (1, 4, 3, 4), + (24576, 24576, 1024, 64, 64, False, True, True): (4, 8, 3, 4), + (24576, 24576, 1024, 64, 64, True, False, True): (1, 4, 3, 4), + (24576, 24576, 1024, 128, 128, False, True, True): (4, 8, 3, 8), + (24576, 24576, 1024, 128, 128, True, False, True): (4, 8, 3, 8), + (24576, 24576, 2048, 16, 16, False, True, True): (1, 4, 1, 4), + (24576, 24576, 2048, 16, 16, True, False, True): (1, 8, 4, 4), + (24576, 24576, 2048, 32, 32, False, True, True): (2, 8, 1, 2), + (24576, 24576, 2048, 32, 32, True, False, True): (1, 8, 3, 4), + (24576, 24576, 2048, 64, 64, False, True, True): (4, 16, 3, 4), + (24576, 24576, 2048, 64, 64, True, False, True): (1, 4, 3, 8), + (24576, 24576, 2048, 128, 128, False, True, True): (4, 16, 3, 8), + (24576, 24576, 2048, 128, 128, True, False, True): (2, 16, 3, 8), + (24576, 24576, 4096, 16, 16, False, True, True): (2, 32, 1, 2), + (24576, 24576, 4096, 16, 16, True, False, True): (1, 16, 4, 4), + (24576, 24576, 4096, 32, 32, False, True, True): (1, 16, 1, 2), + (24576, 24576, 4096, 32, 32, True, False, True): (1, 16, 3, 4), + (24576, 
24576, 4096, 64, 64, False, True, True): (4, 32, 3, 4), + (24576, 24576, 4096, 64, 64, True, False, True): (1, 8, 3, 8), + (24576, 24576, 4096, 128, 128, False, True, True): (4, 32, 3, 8), + (24576, 24576, 4096, 128, 128, True, False, True): (2, 32, 3, 8), + (24576, 24576, 8192, 16, 16, False, True, True): (1, 32, 1, 1), + (24576, 24576, 8192, 16, 16, True, False, True): (2, 64, 4, 2), + (24576, 24576, 8192, 32, 32, False, True, True): (1, 32, 1, 2), + (24576, 24576, 8192, 32, 32, True, False, True): (1, 32, 3, 4), + (24576, 24576, 8192, 64, 64, False, True, True): (4, 64, 3, 4), + (24576, 24576, 8192, 64, 64, True, False, True): (1, 32, 3, 4), + (24576, 24576, 8192, 128, 128, False, True, True): (4, 64, 3, 8), + (24576, 24576, 8192, 128, 128, True, False, True): (4, 64, 3, 8), + (24576, 24576, 16384, 16, 16, False, True, True): (2, 128, 1, 2), + (24576, 24576, 16384, 16, 16, True, False, True): (1, 64, 4, 4), + (24576, 24576, 16384, 32, 32, False, True, True): (1, 64, 1, 2), + (24576, 24576, 16384, 32, 32, True, False, True): (1, 64, 3, 2), + (24576, 24576, 16384, 64, 64, False, True, True): (2, 128, 3, 4), + (24576, 24576, 16384, 64, 64, True, False, True): (1, 32, 3, 8), + (24576, 24576, 16384, 128, 128, False, True, True): (4, 128, 3, 8), + (24576, 24576, 16384, 128, 128, True, False, True): (4, 128, 3, 8), + (24576, 24576, 32768, 16, 16, False, True, True): (1, 128, 1, 1), + (24576, 24576, 32768, 16, 16, True, False, True): (1, 128, 4, 4), + (24576, 24576, 32768, 32, 32, False, True, True): (1, 128, 1, 2), + (24576, 24576, 32768, 32, 32, True, False, True): (1, 128, 3, 4), + (24576, 24576, 32768, 64, 64, False, True, True): (2, 256, 3, 4), + (24576, 24576, 32768, 64, 64, True, False, True): (1, 128, 3, 4), + (24576, 24576, 32768, 128, 128, False, True, True): (4, 256, 3, 8), + (24576, 24576, 32768, 128, 128, True, False, True): (2, 256, 3, 8), + (24576, 24576, 65536, 16, 16, False, True, True): (2, 512, 1, 2), + (24576, 24576, 65536, 16, 16, True, False, True): (1, 256, 4, 4), + (32768, 32768, 256, 16, 16, False, True, True): (4, 2, 1, 2), + }, + ("bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.float16, 0.56)): { + (192, 192, 256, 64, 64, False, True, True): (1, 4, 3, 4), + (192, 192, 256, 64, 64, True, False, True): (1, 4, 3, 4), + (192, 192, 512, 64, 64, False, True, True): (1, 8, 5, 4), + (192, 192, 512, 64, 64, True, False, True): (1, 8, 3, 4), + (192, 192, 1024, 64, 64, False, True, True): (1, 16, 3, 2), + (192, 192, 1024, 64, 64, True, False, True): (1, 16, 3, 4), + (192, 192, 2048, 64, 64, False, True, True): (1, 32, 5, 4), + (192, 192, 2048, 64, 64, True, False, True): (4, 32, 5, 4), + (192, 192, 4096, 64, 64, False, True, True): (1, 64, 1, 8), + (192, 192, 4096, 64, 64, True, False, True): (1, 32, 3, 4), + (192, 192, 8192, 64, 64, False, True, True): (4, 128, 1, 4), + (192, 192, 8192, 64, 64, True, False, True): (3, 64, 3, 4), + (192, 192, 16384, 64, 64, False, True, True): (1, 256, 1, 4), + (192, 192, 16384, 64, 64, True, False, True): (3, 64, 2, 4), + (192, 192, 32768, 64, 64, False, True, True): (1, 512, 1, 2), + (192, 192, 32768, 64, 64, True, False, True): (2, 256, 2, 4), + (192, 192, 65536, 64, 64, False, True, True): (1, 512, 1, 4), + (192, 192, 65536, 64, 64, True, False, True): (2, 512, 2, 4), + (192, 192, 131072, 64, 64, False, True, True): (1, 1024, 1, 4), + (192, 192, 131072, 64, 64, True, False, True): (1, 512, 3, 4), + (384, 384, 256, 128, 128, False, True, True): (3, 2, 3, 8), + (384, 384, 256, 128, 128, True, False, True): (5, 2, 3, 8), + (384, 384, 512, 
128, 128, False, True, True): (4, 4, 3, 8), + (384, 384, 512, 128, 128, True, False, True): (1, 4, 3, 8), + (384, 384, 1024, 128, 128, False, True, True): (1, 8, 3, 8), + (384, 384, 1024, 128, 128, True, False, True): (1, 8, 2, 8), + (384, 384, 2048, 128, 128, False, True, True): (3, 16, 3, 8), + (384, 384, 2048, 128, 128, True, False, True): (1, 16, 3, 8), + (384, 384, 4096, 128, 128, False, True, True): (3, 32, 3, 8), + (384, 384, 4096, 128, 128, True, False, True): (3, 32, 3, 8), + (384, 384, 8192, 128, 128, False, True, True): (2, 64, 3, 8), + (384, 384, 8192, 128, 128, True, False, True): (2, 64, 2, 4), + (384, 384, 16384, 128, 128, False, True, True): (1, 128, 2, 8), + (384, 384, 16384, 128, 128, True, False, True): (3, 128, 2, 4), + (384, 384, 32768, 128, 128, False, True, True): (2, 256, 3, 8), + (384, 384, 32768, 128, 128, True, False, True): (1, 256, 2, 4), + (384, 384, 65536, 128, 128, False, True, True): (7, 512, 1, 4), + (384, 384, 65536, 128, 128, True, False, True): (3, 512, 2, 4), + (384, 384, 131072, 128, 128, False, True, True): (5, 1024, 1, 4), + (384, 384, 131072, 128, 128, True, False, True): (1, 1024, 2, 4), + }, + ("bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.float32, 0.5)): { + (16, 16, 16, 16, 16, False, False, False): (2, 1, 1, 16), + (16, 16, 16, 16, 16, False, False, True): (1, 1, 2, 4), + (16, 16, 16, 16, 16, False, True, False): (1, 1, 2, 16), + (16, 16, 16, 16, 16, False, True, True): (2, 1, 2, 8), + (16, 16, 16, 16, 16, True, False, False): (1, 1, 1, 2), + (16, 16, 16, 16, 16, True, False, True): (2, 1, 1, 4), + (16, 16, 32, 16, 16, False, False, False): (1, 1, 1, 2), + (16, 16, 32, 16, 16, False, False, True): (1, 1, 2, 8), + (16, 16, 32, 16, 16, False, True, False): (1, 2, 1, 4), + (16, 16, 32, 16, 16, False, True, True): (1, 2, 2, 4), + (16, 16, 32, 16, 16, True, False, False): (1, 1, 2, 4), + (16, 16, 32, 16, 16, True, False, True): (1, 2, 2, 4), + (16, 16, 64, 16, 16, False, False, False): (1, 4, 1, 4), + (16, 16, 64, 16, 16, False, False, True): (2, 2, 1, 4), + (16, 16, 64, 16, 16, False, True, False): (1, 4, 1, 4), + (16, 16, 64, 16, 16, False, True, True): (1, 4, 1, 8), + (16, 16, 64, 16, 16, True, False, False): (1, 2, 1, 4), + (16, 16, 64, 16, 16, True, False, True): (1, 4, 2, 8), + (16, 32, 16, 16, 16, False, False, False): (1, 1, 2, 8), + (16, 32, 16, 16, 16, False, False, True): (2, 1, 1, 4), + (16, 32, 16, 16, 16, False, True, False): (1, 1, 1, 4), + (16, 32, 16, 16, 16, False, True, True): (1, 1, 1, 4), + (16, 32, 16, 16, 16, True, False, False): (1, 1, 1, 4), + (16, 32, 16, 16, 16, True, False, True): (1, 1, 2, 8), + (16, 32, 16, 16, 32, False, False, False): (1, 1, 2, 4), + (16, 32, 16, 16, 32, False, False, True): (2, 1, 2, 2), + (16, 32, 16, 16, 32, False, True, False): (1, 1, 1, 8), + (16, 32, 16, 16, 32, False, True, True): (1, 1, 1, 2), + (16, 32, 16, 16, 32, True, False, False): (3, 1, 1, 4), + (16, 32, 16, 16, 32, True, False, True): (1, 1, 1, 4), + (16, 32, 32, 16, 16, False, False, False): (1, 2, 1, 4), + (16, 32, 32, 16, 16, False, False, True): (2, 2, 1, 4), + (16, 32, 32, 16, 16, False, True, False): (1, 2, 1, 2), + (16, 32, 32, 16, 16, False, True, True): (1, 2, 1, 4), + (16, 32, 32, 16, 16, True, False, False): (1, 2, 1, 4), + (16, 32, 32, 16, 16, True, False, True): (1, 2, 1, 4), + (16, 32, 32, 16, 32, False, False, False): (1, 1, 2, 4), + (16, 32, 32, 16, 32, False, False, True): (1, 2, 1, 4), + (16, 32, 32, 16, 32, False, True, False): (1, 2, 2, 8), + (16, 32, 32, 16, 32, False, True, True): (1, 2, 1, 1), + (16, 32, 
32, 16, 32, True, False, False): (1, 2, 1, 2), + (16, 32, 32, 16, 32, True, False, True): (1, 2, 1, 4), + (16, 32, 64, 16, 16, False, False, False): (1, 2, 1, 4), + (16, 32, 64, 16, 16, False, False, True): (2, 4, 1, 4), + (16, 32, 64, 16, 16, False, True, False): (1, 4, 2, 4), + (16, 32, 64, 16, 16, False, True, True): (1, 4, 1, 4), + (16, 32, 64, 16, 16, True, False, False): (1, 2, 2, 8), + (16, 32, 64, 16, 16, True, False, True): (1, 4, 1, 2), + (16, 32, 64, 16, 32, False, False, False): (1, 4, 1, 4), + (16, 32, 64, 16, 32, False, False, True): (1, 4, 3, 4), + (16, 32, 64, 16, 32, False, True, False): (1, 2, 1, 4), + (16, 32, 64, 16, 32, False, True, True): (1, 4, 1, 4), + (16, 32, 64, 16, 32, True, False, False): (1, 2, 1, 8), + (16, 32, 64, 16, 32, True, False, True): (1, 2, 1, 4), + (16, 64, 16, 16, 32, False, False, False): (1, 1, 1, 2), + (16, 64, 16, 16, 32, False, False, True): (1, 1, 1, 8), + (16, 64, 16, 16, 32, False, True, False): (1, 1, 1, 8), + (16, 64, 16, 16, 32, False, True, True): (1, 1, 1, 4), + (16, 64, 16, 16, 32, True, False, False): (1, 1, 1, 8), + (16, 64, 16, 16, 32, True, False, True): (1, 1, 1, 4), + (16, 64, 32, 16, 32, False, False, False): (1, 2, 1, 4), + (16, 64, 32, 16, 32, False, False, True): (1, 1, 1, 4), + (16, 64, 32, 16, 32, False, True, False): (1, 2, 1, 1), + (16, 64, 32, 16, 32, False, True, True): (1, 2, 1, 8), + (16, 64, 32, 16, 32, True, False, False): (2, 2, 1, 4), + (16, 64, 32, 16, 32, True, False, True): (2, 2, 1, 4), + (16, 64, 64, 16, 32, False, False, False): (1, 2, 1, 4), + (16, 64, 64, 16, 32, False, False, True): (1, 4, 1, 4), + (16, 64, 64, 16, 32, False, True, False): (1, 4, 1, 4), + (16, 64, 64, 16, 32, False, True, True): (1, 4, 1, 4), + (16, 64, 64, 16, 32, True, False, False): (1, 4, 1, 2), + (16, 64, 64, 16, 32, True, False, True): (3, 4, 1, 4), + (32, 16, 16, 16, 16, False, False, False): (1, 1, 2, 4), + (32, 16, 16, 16, 16, False, False, True): (1, 1, 1, 2), + (32, 16, 16, 16, 16, False, True, False): (1, 1, 2, 4), + (32, 16, 16, 16, 16, False, True, True): (1, 1, 2, 4), + (32, 16, 16, 16, 16, True, False, False): (1, 1, 3, 8), + (32, 16, 16, 16, 16, True, False, True): (1, 1, 2, 4), + (32, 16, 32, 16, 16, False, False, False): (1, 2, 1, 4), + (32, 16, 32, 16, 16, False, False, True): (1, 2, 3, 4), + (32, 16, 32, 16, 16, False, True, False): (1, 1, 1, 8), + (32, 16, 32, 16, 16, False, True, True): (1, 2, 1, 4), + (32, 16, 32, 16, 16, True, False, False): (1, 1, 1, 2), + (32, 16, 32, 16, 16, True, False, True): (1, 1, 1, 4), + (32, 16, 64, 16, 16, False, False, False): (1, 4, 1, 4), + (32, 16, 64, 16, 16, False, False, True): (3, 4, 1, 4), + (32, 16, 64, 16, 16, False, True, False): (1, 4, 1, 1), + (32, 16, 64, 16, 16, False, True, True): (1, 4, 1, 4), + (32, 16, 64, 16, 16, True, False, False): (1, 4, 1, 4), + (32, 16, 64, 16, 16, True, False, True): (1, 4, 1, 4), + (32, 32, 16, 16, 16, False, False, False): (1, 1, 1, 2), + (32, 32, 16, 16, 16, False, False, True): (2, 1, 1, 4), + (32, 32, 16, 16, 16, False, True, False): (1, 1, 1, 2), + (32, 32, 16, 16, 16, False, True, True): (2, 1, 1, 4), + (32, 32, 16, 16, 16, True, False, False): (3, 1, 2, 4), + (32, 32, 16, 16, 16, True, False, True): (1, 1, 2, 4), + (32, 32, 16, 16, 32, False, False, False): (2, 1, 1, 2), + (32, 32, 16, 16, 32, False, False, True): (1, 1, 1, 4), + (32, 32, 16, 16, 32, False, True, False): (1, 1, 1, 4), + (32, 32, 16, 16, 32, False, True, True): (1, 1, 1, 8), + (32, 32, 16, 16, 32, True, False, False): (1, 1, 1, 8), + (32, 32, 16, 16, 32, True, False, 
True): (1, 1, 1, 4), + (32, 32, 16, 32, 32, False, False, False): (2, 1, 1, 4), + (32, 32, 16, 32, 32, False, False, True): (1, 1, 2, 4), + (32, 32, 16, 32, 32, False, True, False): (2, 1, 1, 1), + (32, 32, 16, 32, 32, False, True, True): (2, 1, 2, 4), + (32, 32, 16, 32, 32, True, False, False): (1, 1, 1, 8), + (32, 32, 16, 32, 32, True, False, True): (1, 1, 1, 4), + (32, 32, 32, 16, 16, False, False, False): (1, 1, 1, 4), + (32, 32, 32, 16, 16, False, False, True): (1, 2, 1, 2), + (32, 32, 32, 16, 16, False, True, False): (2, 2, 1, 4), + (32, 32, 32, 16, 16, False, True, True): (1, 2, 2, 4), + (32, 32, 32, 16, 16, True, False, False): (1, 2, 1, 4), + (32, 32, 32, 16, 16, True, False, True): (2, 2, 1, 4), + (32, 32, 32, 16, 32, False, False, False): (1, 2, 1, 4), + (32, 32, 32, 16, 32, False, False, True): (1, 2, 1, 4), + (32, 32, 32, 16, 32, False, True, False): (1, 2, 1, 4), + (32, 32, 32, 16, 32, False, True, True): (1, 2, 1, 4), + (32, 32, 32, 16, 32, True, False, False): (2, 1, 1, 2), + (32, 32, 32, 16, 32, True, False, True): (2, 2, 2, 4), + (32, 32, 32, 32, 32, False, False, False): (1, 1, 1, 4), + (32, 32, 32, 32, 32, False, False, True): (1, 1, 1, 2), + (32, 32, 32, 32, 32, False, True, False): (1, 1, 1, 4), + (32, 32, 32, 32, 32, False, True, True): (1, 1, 2, 2), + (32, 32, 32, 32, 32, True, False, False): (1, 1, 1, 2), + (32, 32, 32, 32, 32, True, False, True): (1, 1, 2, 1), + (32, 32, 64, 16, 16, False, False, False): (2, 4, 1, 4), + (32, 32, 64, 16, 16, False, False, True): (1, 4, 2, 4), + (32, 32, 64, 16, 16, False, True, False): (1, 4, 1, 4), + (32, 32, 64, 16, 16, False, True, True): (1, 4, 1, 4), + (32, 32, 64, 16, 16, True, False, False): (1, 2, 1, 4), + (32, 32, 64, 16, 16, True, False, True): (2, 4, 1, 4), + (32, 32, 64, 16, 32, False, False, False): (1, 4, 1, 8), + (32, 32, 64, 16, 32, False, False, True): (1, 4, 1, 4), + (32, 32, 64, 16, 32, False, True, False): (1, 4, 1, 4), + (32, 32, 64, 16, 32, False, True, True): (2, 4, 1, 4), + (32, 32, 64, 16, 32, True, False, False): (1, 2, 2, 4), + (32, 32, 64, 16, 32, True, False, True): (2, 4, 1, 4), + (32, 32, 64, 32, 32, False, False, False): (2, 2, 1, 4), + (32, 32, 64, 32, 32, False, False, True): (1, 1, 1, 4), + (32, 32, 64, 32, 32, False, True, False): (1, 1, 1, 8), + (32, 32, 64, 32, 32, False, True, True): (2, 1, 1, 4), + (32, 32, 64, 32, 32, True, False, False): (1, 1, 1, 4), + (32, 32, 64, 32, 32, True, False, True): (1, 2, 1, 1), + (32, 64, 16, 16, 32, False, False, False): (1, 1, 2, 2), + (32, 64, 16, 16, 32, False, False, True): (2, 1, 1, 4), + (32, 64, 16, 16, 32, False, True, False): (1, 1, 1, 8), + (32, 64, 16, 16, 32, False, True, True): (1, 1, 3, 4), + (32, 64, 16, 16, 32, True, False, False): (1, 1, 1, 2), + (32, 64, 16, 16, 32, True, False, True): (1, 1, 2, 4), + (32, 64, 16, 32, 32, False, False, False): (1, 1, 1, 2), + (32, 64, 16, 32, 32, False, False, True): (1, 1, 3, 4), + (32, 64, 16, 32, 32, False, True, False): (1, 1, 2, 4), + (32, 64, 16, 32, 32, False, True, True): (1, 1, 1, 8), + (32, 64, 16, 32, 32, True, False, False): (1, 1, 2, 4), + (32, 64, 16, 32, 32, True, False, True): (1, 1, 1, 8), + (32, 64, 32, 16, 32, False, False, False): (1, 2, 1, 4), + (32, 64, 32, 16, 32, False, False, True): (1, 2, 3, 4), + (32, 64, 32, 16, 32, False, True, False): (1, 2, 1, 8), + (32, 64, 32, 16, 32, False, True, True): (3, 2, 1, 4), + (32, 64, 32, 16, 32, True, False, False): (1, 1, 1, 8), + (32, 64, 32, 16, 32, True, False, True): (1, 2, 1, 4), + (32, 64, 32, 32, 32, False, False, False): (1, 1, 1, 1), + 
(32, 64, 32, 32, 32, False, False, True): (1, 1, 1, 4), + (32, 64, 32, 32, 32, False, True, False): (1, 1, 1, 4), + (32, 64, 32, 32, 32, False, True, True): (1, 1, 1, 4), + (32, 64, 32, 32, 32, True, False, False): (1, 1, 1, 4), + (32, 64, 32, 32, 32, True, False, True): (1, 1, 2, 8), + (32, 64, 64, 16, 32, False, False, False): (2, 4, 1, 4), + (32, 64, 64, 16, 32, False, False, True): (1, 4, 1, 4), + (32, 64, 64, 16, 32, False, True, False): (1, 4, 1, 4), + (32, 64, 64, 16, 32, False, True, True): (2, 4, 1, 4), + (32, 64, 64, 16, 32, True, False, False): (1, 4, 1, 4), + (32, 64, 64, 16, 32, True, False, True): (1, 4, 1, 4), + (32, 64, 64, 32, 32, False, False, False): (2, 2, 1, 4), + (32, 64, 64, 32, 32, False, False, True): (1, 2, 1, 8), + (32, 64, 64, 32, 32, False, True, False): (1, 2, 1, 4), + (32, 64, 64, 32, 32, False, True, True): (1, 2, 1, 4), + (32, 64, 64, 32, 32, True, False, False): (2, 2, 1, 4), + (32, 64, 64, 32, 32, True, False, True): (1, 2, 3, 8), + (64, 32, 16, 32, 32, False, False, False): (1, 1, 1, 4), + (64, 32, 16, 32, 32, False, False, True): (3, 1, 2, 4), + (64, 32, 16, 32, 32, False, True, False): (2, 1, 1, 2), + (64, 32, 16, 32, 32, False, True, True): (1, 1, 1, 8), + (64, 32, 16, 32, 32, True, False, False): (1, 1, 1, 2), + (64, 32, 16, 32, 32, True, False, True): (1, 1, 1, 4), + (64, 32, 32, 32, 32, False, False, False): (1, 1, 1, 4), + (64, 32, 32, 32, 32, False, False, True): (1, 1, 2, 8), + (64, 32, 32, 32, 32, False, True, False): (1, 1, 1, 8), + (64, 32, 32, 32, 32, False, True, True): (1, 1, 1, 4), + (64, 32, 32, 32, 32, True, False, False): (1, 1, 2, 4), + (64, 32, 32, 32, 32, True, False, True): (1, 1, 3, 8), + (64, 32, 64, 32, 32, False, False, False): (1, 2, 1, 4), + (64, 32, 64, 32, 32, False, False, True): (2, 2, 1, 4), + (64, 32, 64, 32, 32, False, True, False): (1, 1, 1, 4), + (64, 32, 64, 32, 32, False, True, True): (1, 2, 1, 8), + (64, 32, 64, 32, 32, True, False, False): (2, 2, 1, 4), + (64, 32, 64, 32, 32, True, False, True): (1, 2, 1, 8), + (64, 64, 16, 32, 32, False, False, False): (1, 1, 2, 8), + (64, 64, 16, 32, 32, False, False, True): (2, 1, 2, 4), + (64, 64, 16, 32, 32, False, True, False): (1, 1, 1, 2), + (64, 64, 16, 32, 32, False, True, True): (1, 1, 2, 4), + (64, 64, 16, 32, 32, True, False, False): (1, 1, 1, 2), + (64, 64, 16, 32, 32, True, False, True): (1, 1, 2, 4), + (64, 64, 32, 32, 32, False, False, False): (1, 1, 1, 4), + (64, 64, 32, 32, 32, False, False, True): (2, 1, 1, 4), + (64, 64, 32, 32, 32, False, True, False): (1, 1, 1, 8), + (64, 64, 32, 32, 32, False, True, True): (2, 1, 1, 4), + (64, 64, 32, 32, 32, True, False, False): (1, 1, 1, 4), + (64, 64, 32, 32, 32, True, False, True): (1, 1, 1, 8), + (64, 64, 64, 32, 32, False, False, False): (2, 2, 1, 4), + (64, 64, 64, 32, 32, False, False, True): (1, 2, 1, 4), + (64, 64, 64, 32, 32, False, True, False): (1, 2, 1, 4), + (64, 64, 64, 32, 32, False, True, True): (2, 2, 1, 4), + (64, 64, 64, 32, 32, True, False, False): (1, 1, 1, 8), + (64, 64, 64, 32, 32, True, False, True): (1, 2, 2, 4), + (192, 192, 256, 16, 16, False, True, True): (1, 16, 3, 2), + (192, 192, 256, 16, 16, True, False, True): (1, 8, 5, 4), + (192, 192, 256, 32, 32, False, True, True): (2, 8, 4, 4), + (192, 192, 256, 32, 32, True, False, True): (1, 8, 5, 4), + (192, 192, 512, 16, 16, False, True, True): (2, 16, 3, 4), + (192, 192, 512, 16, 16, True, False, True): (1, 16, 5, 4), + (192, 192, 512, 32, 32, False, True, True): (1, 16, 3, 4), + (192, 192, 512, 32, 32, True, False, True): (2, 16, 3, 4), + (192, 
192, 1024, 16, 16, False, True, True): (3, 16, 3, 4), + (192, 192, 1024, 16, 16, True, False, True): (2, 8, 3, 4), + (192, 192, 1024, 32, 32, False, True, True): (3, 32, 1, 4), + (192, 192, 1024, 32, 32, True, False, True): (3, 16, 3, 4), + (192, 192, 2048, 16, 16, False, True, True): (1, 32, 3, 4), + (192, 192, 2048, 16, 16, True, False, True): (2, 16, 3, 4), + (192, 192, 2048, 32, 32, False, True, True): (1, 64, 1, 4), + (192, 192, 2048, 32, 32, True, False, True): (1, 64, 2, 4), + (192, 192, 4096, 16, 16, False, True, True): (1, 64, 2, 4), + (192, 192, 4096, 16, 16, True, False, True): (1, 32, 3, 4), + (192, 192, 4096, 32, 32, False, True, True): (3, 128, 2, 4), + (192, 192, 4096, 32, 32, True, False, True): (1, 128, 2, 4), + (192, 192, 8192, 16, 16, False, True, True): (2, 64, 3, 4), + (192, 192, 8192, 16, 16, True, False, True): (1, 64, 3, 4), + (192, 192, 8192, 32, 32, False, True, True): (3, 128, 3, 4), + (192, 192, 8192, 32, 32, True, False, True): (1, 128, 2, 4), + (192, 192, 16384, 16, 16, False, True, True): (1, 256, 3, 2), + (192, 192, 16384, 16, 16, True, False, True): (1, 256, 3, 2), + (192, 192, 16384, 32, 32, False, True, True): (2, 256, 3, 4), + (192, 192, 16384, 32, 32, True, False, True): (2, 256, 3, 4), + (192, 192, 32768, 16, 16, False, True, True): (2, 512, 3, 2), + (192, 192, 32768, 16, 16, True, False, True): (1, 128, 3, 4), + (192, 192, 32768, 32, 32, False, True, True): (2, 512, 3, 4), + (192, 192, 32768, 32, 32, True, False, True): (2, 512, 3, 4), + (192, 192, 65536, 16, 16, False, True, True): (2, 1024, 3, 2), + (192, 192, 65536, 16, 16, True, False, True): (1, 256, 3, 4), + (192, 192, 65536, 32, 32, False, True, True): (2, 1024, 3, 4), + (192, 192, 65536, 32, 32, True, False, True): (2, 1024, 3, 4), + (192, 192, 131072, 16, 16, False, True, True): (2, 512, 3, 4), + (192, 192, 131072, 16, 16, True, False, True): (1, 512, 3, 4), + (192, 192, 131072, 32, 32, False, True, True): (2, 1024, 3, 4), + (192, 192, 131072, 32, 32, True, False, True): (2, 1024, 3, 4), + (256, 256, 256, 16, 16, False, True, True): (1, 16, 3, 4), + (256, 256, 256, 16, 16, True, False, True): (2, 16, 1, 4), + (256, 256, 256, 32, 32, False, True, True): (1, 8, 4, 8), + (256, 256, 256, 32, 32, True, False, True): (4, 8, 4, 4), + (256, 256, 256, 64, 64, False, True, True): (1, 4, 4, 8), + (256, 256, 256, 64, 64, True, False, True): (1, 4, 3, 8), + (256, 256, 256, 128, 128, False, True, True): (7, 2, 1, 32), + (256, 256, 256, 128, 128, True, False, True): (3, 2, 1, 32), + (256, 256, 512, 16, 16, False, True, True): (1, 16, 5, 4), + (256, 256, 512, 16, 16, True, False, True): (1, 16, 3, 2), + (256, 256, 512, 32, 32, False, True, True): (4, 16, 4, 4), + (256, 256, 512, 32, 32, True, False, True): (4, 16, 3, 4), + (256, 256, 512, 64, 64, False, True, True): (1, 8, 3, 8), + (256, 256, 512, 64, 64, True, False, True): (1, 8, 3, 8), + (256, 256, 512, 128, 128, False, True, True): (1, 4, 1, 32), + (256, 256, 512, 128, 128, True, False, True): (3, 4, 1, 32), + (256, 256, 1024, 16, 16, False, True, True): (3, 32, 5, 2), + (256, 256, 1024, 16, 16, True, False, True): (2, 32, 5, 2), + (256, 256, 1024, 32, 32, False, True, True): (1, 32, 4, 4), + (256, 256, 1024, 32, 32, True, False, True): (1, 32, 5, 4), + (256, 256, 1024, 64, 64, False, True, True): (4, 16, 3, 8), + (256, 256, 1024, 64, 64, True, False, True): (1, 16, 3, 8), + (256, 256, 1024, 128, 128, False, True, True): (1, 8, 1, 32), + (256, 256, 1024, 128, 128, True, False, True): (3, 8, 1, 32), + (256, 256, 2048, 16, 16, False, True, True): (3, 32, 
3, 4), + (256, 256, 2048, 16, 16, True, False, True): (1, 64, 3, 2), + (256, 256, 2048, 32, 32, False, True, True): (1, 64, 3, 4), + (256, 256, 2048, 32, 32, True, False, True): (1, 64, 3, 4), + (256, 256, 2048, 64, 64, False, True, True): (2, 32, 1, 8), + (256, 256, 2048, 64, 64, True, False, True): (2, 32, 1, 8), + (256, 256, 2048, 128, 128, False, True, True): (4, 16, 1, 32), + (256, 256, 2048, 128, 128, True, False, True): (4, 16, 1, 32), + (256, 256, 4096, 16, 16, False, True, True): (1, 32, 2, 4), + (256, 256, 4096, 16, 16, True, False, True): (1, 32, 3, 4), + (256, 256, 4096, 32, 32, False, True, True): (1, 128, 2, 4), + (256, 256, 4096, 32, 32, True, False, True): (1, 128, 2, 4), + (256, 256, 4096, 64, 64, False, True, True): (2, 64, 4, 8), + (256, 256, 4096, 64, 64, True, False, True): (3, 64, 2, 8), + (256, 256, 4096, 128, 128, False, True, True): (3, 32, 1, 32), + (256, 256, 4096, 128, 128, True, False, True): (2, 32, 1, 32), + (256, 256, 8192, 16, 16, False, True, True): (1, 64, 3, 4), + (256, 256, 8192, 16, 16, True, False, True): (2, 128, 3, 2), + (256, 256, 8192, 32, 32, False, True, True): (3, 128, 3, 4), + (256, 256, 8192, 32, 32, True, False, True): (1, 128, 3, 4), + (256, 256, 8192, 64, 64, False, True, True): (3, 128, 1, 4), + (256, 256, 8192, 64, 64, True, False, True): (4, 128, 2, 8), + (256, 256, 8192, 128, 128, False, True, True): (6, 64, 1, 32), + (256, 256, 8192, 128, 128, True, False, True): (2, 64, 1, 32), + (256, 256, 16384, 16, 16, False, True, True): (4, 128, 3, 4), + (256, 256, 16384, 16, 16, True, False, True): (3, 128, 3, 4), + (256, 256, 16384, 32, 32, False, True, True): (4, 256, 3, 4), + (256, 256, 16384, 32, 32, True, False, True): (2, 256, 3, 4), + (256, 256, 16384, 64, 64, False, True, True): (3, 256, 1, 4), + (256, 256, 16384, 64, 64, True, False, True): (2, 256, 2, 4), + (256, 256, 16384, 128, 128, False, True, True): (1, 128, 1, 32), + (256, 256, 16384, 128, 128, True, False, True): (3, 128, 1, 32), + (256, 256, 32768, 16, 16, False, True, True): (1, 256, 3, 4), + (256, 256, 32768, 16, 16, True, False, True): (2, 128, 3, 4), + (256, 256, 32768, 32, 32, False, True, True): (2, 512, 3, 4), + (256, 256, 32768, 32, 32, True, False, True): (4, 512, 3, 4), + (256, 256, 32768, 64, 64, False, True, True): (1, 512, 1, 8), + (256, 256, 32768, 64, 64, True, False, True): (1, 512, 2, 4), + (256, 256, 32768, 128, 128, False, True, True): (1, 256, 1, 32), + (256, 256, 32768, 128, 128, True, False, True): (1, 256, 1, 32), + (256, 256, 65536, 16, 16, False, True, True): (2, 512, 3, 4), + (256, 256, 65536, 16, 16, True, False, True): (1, 256, 3, 4), + (256, 256, 65536, 32, 32, False, True, True): (1, 1024, 3, 4), + (256, 256, 65536, 32, 32, True, False, True): (2, 1024, 3, 4), + (256, 256, 65536, 64, 64, False, True, True): (1, 1024, 2, 4), + (256, 256, 65536, 64, 64, True, False, True): (1, 1024, 2, 4), + (256, 256, 65536, 128, 128, False, True, True): (1, 512, 1, 32), + (256, 256, 65536, 128, 128, True, False, True): (2, 512, 1, 32), + (256, 256, 131072, 16, 16, False, True, True): (1, 1024, 3, 4), + (256, 256, 131072, 16, 16, True, False, True): (1, 512, 3, 4), + (256, 256, 131072, 32, 32, False, True, True): (1, 2048, 3, 4), + (256, 256, 131072, 32, 32, True, False, True): (1, 2048, 3, 4), + (256, 256, 131072, 64, 64, False, True, True): (1, 2048, 1, 8), + (256, 256, 131072, 64, 64, True, False, True): (1, 2048, 2, 4), + (256, 256, 131072, 128, 128, False, True, True): (1, 1024, 1, 32), + (256, 256, 131072, 128, 128, True, False, True): (4, 1024, 1, 32), + 
(384, 384, 256, 16, 16, False, True, True): (1, 8, 3, 4), + (384, 384, 256, 16, 16, True, False, True): (1, 8, 3, 4), + (384, 384, 256, 32, 32, False, True, True): (2, 8, 3, 8), + (384, 384, 256, 32, 32, True, False, True): (1, 8, 3, 4), + (384, 384, 256, 64, 64, False, True, True): (1, 4, 4, 8), + (384, 384, 256, 64, 64, True, False, True): (2, 4, 3, 8), + (384, 384, 512, 16, 16, False, True, True): (3, 16, 3, 2), + (384, 384, 512, 16, 16, True, False, True): (3, 16, 3, 2), + (384, 384, 512, 32, 32, False, True, True): (2, 8, 3, 4), + (384, 384, 512, 32, 32, True, False, True): (1, 8, 3, 4), + (384, 384, 512, 64, 64, False, True, True): (2, 8, 3, 8), + (384, 384, 512, 64, 64, True, False, True): (2, 8, 4, 8), + (384, 384, 1024, 16, 16, False, True, True): (3, 16, 3, 2), + (384, 384, 1024, 16, 16, True, False, True): (4, 32, 3, 2), + (384, 384, 1024, 32, 32, False, True, True): (1, 32, 3, 4), + (384, 384, 1024, 32, 32, True, False, True): (2, 16, 3, 4), + (384, 384, 1024, 64, 64, False, True, True): (2, 16, 3, 8), + (384, 384, 1024, 64, 64, True, False, True): (4, 16, 4, 8), + (384, 384, 2048, 16, 16, False, True, True): (3, 16, 3, 4), + (384, 384, 2048, 16, 16, True, False, True): (1, 32, 3, 4), + (384, 384, 2048, 32, 32, False, True, True): (3, 64, 2, 4), + (384, 384, 2048, 32, 32, True, False, True): (1, 64, 3, 4), + (384, 384, 2048, 64, 64, False, True, True): (4, 32, 4, 8), + (384, 384, 2048, 64, 64, True, False, True): (5, 32, 4, 8), + (384, 384, 4096, 16, 16, False, True, True): (1, 32, 3, 4), + (384, 384, 4096, 16, 16, True, False, True): (3, 32, 3, 4), + (384, 384, 4096, 32, 32, False, True, True): (2, 64, 3, 4), + (384, 384, 4096, 32, 32, True, False, True): (2, 64, 3, 4), + (384, 384, 4096, 64, 64, False, True, True): (2, 64, 3, 8), + (384, 384, 4096, 64, 64, True, False, True): (2, 64, 3, 8), + (384, 384, 8192, 16, 16, False, True, True): (1, 128, 3, 2), + (384, 384, 8192, 16, 16, True, False, True): (1, 128, 3, 2), + (384, 384, 8192, 32, 32, False, True, True): (1, 128, 3, 4), + (384, 384, 8192, 32, 32, True, False, True): (1, 128, 3, 4), + (384, 384, 8192, 64, 64, False, True, True): (3, 128, 3, 4), + (384, 384, 8192, 64, 64, True, False, True): (2, 128, 3, 4), + (384, 384, 16384, 16, 16, False, True, True): (1, 256, 3, 2), + (384, 384, 16384, 16, 16, True, False, True): (1, 64, 3, 4), + (384, 384, 16384, 32, 32, False, True, True): (2, 256, 3, 4), + (384, 384, 16384, 32, 32, True, False, True): (4, 256, 3, 4), + (384, 384, 16384, 64, 64, False, True, True): (2, 256, 3, 4), + (384, 384, 16384, 64, 64, True, False, True): (1, 256, 3, 4), + (384, 384, 32768, 16, 16, False, True, True): (1, 128, 3, 4), + (384, 384, 32768, 16, 16, True, False, True): (1, 128, 3, 4), + (384, 384, 32768, 32, 32, False, True, True): (1, 512, 3, 4), + (384, 384, 32768, 32, 32, True, False, True): (1, 512, 2, 4), + (384, 384, 32768, 64, 64, False, True, True): (1, 512, 3, 4), + (384, 384, 32768, 64, 64, True, False, True): (1, 512, 3, 4), + (384, 384, 65536, 16, 16, False, True, True): (1, 256, 3, 4), + (384, 384, 65536, 16, 16, True, False, True): (1, 256, 3, 4), + (384, 384, 65536, 32, 32, False, True, True): (1, 1024, 3, 4), + (384, 384, 65536, 32, 32, True, False, True): (1, 512, 3, 4), + (384, 384, 65536, 64, 64, False, True, True): (1, 1024, 3, 4), + (384, 384, 65536, 64, 64, True, False, True): (1, 1024, 3, 4), + (384, 384, 131072, 16, 16, False, True, True): (1, 512, 3, 4), + (384, 384, 131072, 16, 16, True, False, True): (1, 512, 3, 4), + (384, 384, 131072, 32, 32, False, True, True): (1, 
1024, 3, 4), + (384, 384, 131072, 32, 32, True, False, True): (1, 1024, 3, 4), + (384, 384, 131072, 64, 64, False, True, True): (1, 2048, 3, 4), + (384, 384, 131072, 64, 64, True, False, True): (1, 2048, 3, 4), + (512, 512, 256, 16, 16, False, True, True): (1, 8, 4, 4), + (512, 512, 256, 16, 16, True, False, True): (1, 8, 3, 2), + (512, 512, 256, 32, 32, False, True, True): (4, 8, 3, 4), + (512, 512, 256, 32, 32, True, False, True): (4, 8, 3, 4), + (512, 512, 256, 64, 64, False, True, True): (3, 4, 3, 8), + (512, 512, 256, 64, 64, True, False, True): (5, 4, 3, 8), + (512, 512, 256, 128, 128, False, True, True): (1, 2, 1, 32), + (512, 512, 256, 128, 128, True, False, True): (3, 2, 1, 32), + (512, 512, 512, 16, 16, False, True, True): (2, 16, 3, 2), + (512, 512, 512, 16, 16, True, False, True): (1, 8, 4, 4), + (512, 512, 512, 32, 32, False, True, True): (3, 16, 3, 4), + (512, 512, 512, 32, 32, True, False, True): (5, 16, 2, 4), + (512, 512, 512, 64, 64, False, True, True): (1, 8, 3, 8), + (512, 512, 512, 64, 64, True, False, True): (3, 8, 3, 8), + (512, 512, 512, 128, 128, False, True, True): (1, 4, 1, 32), + (512, 512, 512, 128, 128, True, False, True): (3, 4, 1, 16), + (512, 512, 1024, 16, 16, False, True, True): (1, 16, 3, 4), + (512, 512, 1024, 16, 16, True, False, True): (3, 16, 3, 4), + (512, 512, 1024, 32, 32, False, True, True): (3, 32, 3, 4), + (512, 512, 1024, 32, 32, True, False, True): (3, 32, 2, 4), + (512, 512, 1024, 64, 64, False, True, True): (1, 16, 3, 8), + (512, 512, 1024, 64, 64, True, False, True): (4, 16, 3, 8), + (512, 512, 1024, 128, 128, False, True, True): (4, 8, 1, 32), + (512, 512, 1024, 128, 128, True, False, True): (4, 8, 1, 32), + (512, 512, 2048, 16, 16, False, True, True): (5, 16, 3, 4), + (512, 512, 2048, 16, 16, True, False, True): (5, 16, 3, 4), + (512, 512, 2048, 32, 32, False, True, True): (1, 32, 3, 4), + (512, 512, 2048, 32, 32, True, False, True): (1, 32, 4, 4), + (512, 512, 2048, 64, 64, False, True, True): (4, 32, 3, 8), + (512, 512, 2048, 64, 64, True, False, True): (4, 32, 3, 8), + (512, 512, 2048, 128, 128, False, True, True): (3, 16, 1, 32), + (512, 512, 2048, 128, 128, True, False, True): (3, 16, 1, 32), + (512, 512, 4096, 16, 16, False, True, True): (4, 32, 3, 4), + (512, 512, 4096, 16, 16, True, False, True): (4, 64, 3, 2), + (512, 512, 4096, 32, 32, False, True, True): (3, 64, 3, 4), + (512, 512, 4096, 32, 32, True, False, True): (3, 64, 3, 4), + (512, 512, 4096, 64, 64, False, True, True): (4, 64, 2, 4), + (512, 512, 4096, 64, 64, True, False, True): (1, 64, 2, 4), + (512, 512, 4096, 128, 128, False, True, True): (1, 32, 1, 32), + (512, 512, 4096, 128, 128, True, False, True): (1, 32, 1, 32), + (512, 512, 8192, 16, 16, False, True, True): (1, 64, 3, 4), + (512, 512, 8192, 16, 16, True, False, True): (4, 64, 3, 4), + (512, 512, 8192, 32, 32, False, True, True): (2, 128, 3, 4), + (512, 512, 8192, 32, 32, True, False, True): (3, 128, 3, 4), + (512, 512, 8192, 64, 64, False, True, True): (1, 128, 2, 4), + (512, 512, 8192, 64, 64, True, False, True): (1, 128, 2, 4), + (512, 512, 8192, 128, 128, False, True, True): (6, 64, 1, 32), + (512, 512, 8192, 128, 128, True, False, True): (4, 64, 1, 32), + (512, 512, 16384, 16, 16, False, True, True): (1, 128, 3, 4), + (512, 512, 16384, 16, 16, True, False, True): (1, 64, 3, 4), + (512, 512, 16384, 32, 32, False, True, True): (1, 256, 3, 4), + (512, 512, 16384, 32, 32, True, False, True): (4, 256, 3, 4), + (512, 512, 16384, 64, 64, False, True, True): (1, 256, 2, 4), + (512, 512, 16384, 64, 64, True, 
False, True): (1, 256, 2, 4), + (512, 512, 16384, 128, 128, False, True, True): (1, 128, 1, 32), + (512, 512, 16384, 128, 128, True, False, True): (2, 128, 1, 32), + (512, 512, 32768, 16, 16, False, True, True): (1, 256, 3, 4), + (512, 512, 32768, 16, 16, True, False, True): (1, 128, 3, 4), + (512, 512, 32768, 32, 32, False, True, True): (1, 512, 3, 4), + (512, 512, 32768, 32, 32, True, False, True): (1, 512, 3, 4), + (512, 512, 32768, 64, 64, False, True, True): (1, 512, 2, 4), + (512, 512, 32768, 64, 64, True, False, True): (2, 512, 2, 4), + (512, 512, 32768, 128, 128, False, True, True): (1, 256, 1, 32), + (512, 512, 32768, 128, 128, True, False, True): (2, 256, 1, 32), + (512, 512, 65536, 16, 16, False, True, True): (1, 512, 3, 4), + (512, 512, 65536, 16, 16, True, False, True): (1, 256, 3, 4), + (512, 512, 65536, 32, 32, False, True, True): (1, 1024, 3, 4), + (512, 512, 65536, 32, 32, True, False, True): (1, 1024, 3, 4), + (512, 512, 65536, 64, 64, False, True, True): (1, 1024, 2, 4), + (512, 512, 65536, 64, 64, True, False, True): (1, 1024, 2, 4), + (512, 512, 65536, 128, 128, False, True, True): (1, 512, 1, 32), + (512, 512, 65536, 128, 128, True, False, True): (4, 512, 1, 32), + (512, 512, 131072, 16, 16, False, True, True): (1, 512, 3, 4), + (512, 512, 131072, 16, 16, True, False, True): (1, 512, 3, 4), + (512, 512, 131072, 32, 32, False, True, True): (1, 2048, 3, 4), + (512, 512, 131072, 32, 32, True, False, True): (1, 2048, 3, 4), + (512, 512, 131072, 64, 64, False, True, True): (1, 2048, 2, 4), + (512, 512, 131072, 64, 64, True, False, True): (1, 2048, 2, 4), + (512, 512, 131072, 128, 128, False, True, True): (1, 1024, 1, 32), + (512, 512, 131072, 128, 128, True, False, True): (2, 1024, 1, 32), + (768, 768, 256, 16, 16, False, True, True): (1, 4, 5, 4), + (768, 768, 256, 16, 16, True, False, True): (3, 8, 3, 2), + (768, 768, 256, 32, 32, False, True, True): (2, 4, 3, 4), + (768, 768, 256, 32, 32, True, False, True): (3, 8, 4, 4), + (768, 768, 256, 64, 64, False, True, True): (1, 4, 4, 8), + (768, 768, 256, 64, 64, True, False, True): (3, 4, 3, 8), + (768, 768, 256, 128, 128, False, True, True): (3, 2, 1, 32), + (768, 768, 256, 128, 128, True, False, True): (2, 2, 2, 32), + (768, 768, 512, 16, 16, False, True, True): (2, 4, 5, 4), + (768, 768, 512, 16, 16, True, False, True): (2, 4, 4, 4), + (768, 768, 512, 32, 32, False, True, True): (1, 8, 3, 4), + (768, 768, 512, 32, 32, True, False, True): (3, 8, 4, 4), + (768, 768, 512, 64, 64, False, True, True): (2, 8, 3, 8), + (768, 768, 512, 64, 64, True, False, True): (5, 8, 3, 8), + (768, 768, 512, 128, 128, False, True, True): (2, 4, 1, 32), + (768, 768, 512, 128, 128, True, False, True): (2, 4, 2, 32), + (768, 768, 1024, 16, 16, False, True, True): (2, 16, 4, 2), + (768, 768, 1024, 16, 16, True, False, True): (4, 32, 3, 1), + (768, 768, 1024, 32, 32, False, True, True): (1, 32, 2, 4), + (768, 768, 1024, 32, 32, True, False, True): (1, 16, 5, 4), + (768, 768, 1024, 64, 64, False, True, True): (2, 16, 3, 8), + (768, 768, 1024, 64, 64, True, False, True): (2, 16, 3, 8), + (768, 768, 1024, 128, 128, False, True, True): (1, 8, 2, 32), + (768, 768, 1024, 128, 128, True, False, True): (1, 8, 1, 32), + (768, 768, 2048, 16, 16, False, True, True): (1, 16, 3, 4), + (768, 768, 2048, 16, 16, True, False, True): (1, 16, 3, 4), + (768, 768, 2048, 32, 32, False, True, True): (1, 32, 3, 4), + (768, 768, 2048, 32, 32, True, False, True): (5, 32, 3, 4), + (768, 768, 2048, 64, 64, False, True, True): (1, 32, 3, 8), + (768, 768, 2048, 64, 64, True, 
False, True): (1, 32, 3, 4), + (768, 768, 2048, 128, 128, False, True, True): (3, 16, 1, 32), + (768, 768, 2048, 128, 128, True, False, True): (4, 16, 1, 32), + (768, 768, 4096, 16, 16, False, True, True): (1, 64, 3, 2), + (768, 768, 4096, 16, 16, True, False, True): (3, 64, 3, 2), + (768, 768, 4096, 32, 32, False, True, True): (1, 64, 3, 4), + (768, 768, 4096, 32, 32, True, False, True): (1, 64, 3, 4), + (768, 768, 4096, 64, 64, False, True, True): (4, 64, 3, 4), + (768, 768, 4096, 64, 64, True, False, True): (4, 64, 3, 4), + (768, 768, 4096, 128, 128, False, True, True): (1, 32, 1, 32), + (768, 768, 4096, 128, 128, True, False, True): (1, 32, 2, 32), + (768, 768, 8192, 16, 16, False, True, True): (1, 128, 3, 2), + (768, 768, 8192, 16, 16, True, False, True): (2, 32, 3, 4), + (768, 768, 8192, 32, 32, False, True, True): (2, 128, 3, 4), + (768, 768, 8192, 32, 32, True, False, True): (1, 128, 2, 4), + (768, 768, 8192, 64, 64, False, True, True): (1, 128, 3, 4), + (768, 768, 8192, 64, 64, True, False, True): (2, 128, 3, 4), + (768, 768, 8192, 128, 128, False, True, True): (1, 64, 1, 32), + (768, 768, 8192, 128, 128, True, False, True): (2, 64, 1, 32), + (768, 768, 16384, 16, 16, False, True, True): (3, 64, 3, 4), + (768, 768, 16384, 16, 16, True, False, True): (1, 64, 3, 4), + (768, 768, 16384, 32, 32, False, True, True): (2, 256, 3, 4), + (768, 768, 16384, 32, 32, True, False, True): (4, 256, 2, 4), + (768, 768, 16384, 64, 64, False, True, True): (1, 256, 3, 4), + (768, 768, 16384, 64, 64, True, False, True): (1, 256, 3, 4), + (768, 768, 16384, 128, 128, False, True, True): (1, 128, 1, 32), + (768, 768, 16384, 128, 128, True, False, True): (2, 128, 1, 32), + (768, 768, 32768, 16, 16, False, True, True): (1, 128, 3, 4), + (768, 768, 32768, 16, 16, True, False, True): (2, 128, 3, 4), + (768, 768, 32768, 32, 32, False, True, True): (2, 256, 3, 4), + (768, 768, 32768, 32, 32, True, False, True): (1, 256, 3, 4), + (768, 768, 32768, 64, 64, False, True, True): (1, 512, 3, 4), + (768, 768, 32768, 64, 64, True, False, True): (1, 512, 3, 4), + (768, 768, 32768, 128, 128, False, True, True): (1, 256, 1, 32), + (768, 768, 32768, 128, 128, True, False, True): (1, 256, 1, 32), + (768, 768, 50432, 16, 16, False, True, True): (1, 197, 3, 4), + (768, 768, 50432, 32, 32, False, True, True): (1, 394, 3, 4), + (768, 768, 50432, 64, 64, False, True, True): (1, 788, 3, 4), + (768, 768, 50432, 128, 128, False, True, True): (3, 394, 1, 32), + (768, 768, 65536, 16, 16, False, True, True): (1, 256, 3, 4), + (768, 768, 65536, 16, 16, True, False, True): (1, 256, 3, 4), + (768, 768, 65536, 32, 32, False, True, True): (1, 512, 3, 4), + (768, 768, 65536, 32, 32, True, False, True): (1, 512, 3, 4), + (768, 768, 65536, 64, 64, False, True, True): (1, 1024, 3, 4), + (768, 768, 65536, 64, 64, True, False, True): (1, 1024, 3, 4), + (768, 768, 65536, 128, 128, False, True, True): (1, 512, 1, 32), + (768, 768, 65536, 128, 128, True, False, True): (1, 512, 1, 32), + (768, 768, 131072, 16, 16, False, True, True): (1, 512, 3, 4), + (768, 768, 131072, 16, 16, True, False, True): (1, 512, 3, 4), + (768, 768, 131072, 32, 32, False, True, True): (1, 1024, 3, 4), + (768, 768, 131072, 32, 32, True, False, True): (1, 1024, 3, 4), + (768, 768, 131072, 64, 64, False, True, True): (1, 2048, 3, 4), + (768, 768, 131072, 64, 64, True, False, True): (1, 2048, 3, 4), + (768, 768, 131072, 128, 128, False, True, True): (1, 1024, 1, 32), + (768, 768, 131072, 128, 128, True, False, True): (1, 1024, 1, 32), + (768, 3072, 256, 16, 16, False, True, 
True): (1, 2, 4, 4), + (768, 3072, 256, 16, 16, True, False, True): (1, 4, 3, 4), + (768, 3072, 256, 32, 32, False, True, True): (1, 4, 3, 4), + (768, 3072, 256, 32, 32, True, False, True): (3, 4, 3, 4), + (768, 3072, 256, 64, 64, False, True, True): (1, 4, 3, 8), + (768, 3072, 256, 64, 64, True, False, True): (1, 4, 3, 8), + (768, 3072, 256, 128, 128, False, True, True): (2, 2, 2, 32), + (768, 3072, 256, 128, 128, True, False, True): (2, 2, 1, 32), + (768, 3072, 512, 16, 16, False, True, True): (2, 4, 3, 4), + (768, 3072, 512, 16, 16, True, False, True): (1, 8, 3, 2), + (768, 3072, 512, 32, 32, False, True, True): (3, 8, 4, 4), + (768, 3072, 512, 32, 32, True, False, True): (3, 8, 3, 4), + (768, 3072, 512, 64, 64, False, True, True): (1, 8, 4, 8), + (768, 3072, 512, 64, 64, True, False, True): (1, 8, 3, 8), + (768, 3072, 512, 128, 128, False, True, True): (1, 4, 2, 32), + (768, 3072, 512, 128, 128, True, False, True): (1, 4, 1, 32), + (768, 3072, 1024, 16, 16, False, True, True): (4, 16, 3, 2), + (768, 3072, 1024, 16, 16, True, False, True): (4, 16, 3, 2), + (768, 3072, 1024, 32, 32, False, True, True): (4, 16, 5, 4), + (768, 3072, 1024, 32, 32, True, False, True): (4, 16, 5, 4), + (768, 3072, 1024, 64, 64, False, True, True): (2, 16, 3, 8), + (768, 3072, 1024, 64, 64, True, False, True): (2, 16, 3, 8), + (768, 3072, 1024, 128, 128, False, True, True): (1, 8, 1, 32), + (768, 3072, 1024, 128, 128, True, False, True): (1, 8, 1, 32), + (768, 3072, 2048, 16, 16, False, True, True): (2, 16, 3, 4), + (768, 3072, 2048, 16, 16, True, False, True): (2, 16, 3, 4), + (768, 3072, 2048, 32, 32, False, True, True): (4, 32, 5, 4), + (768, 3072, 2048, 32, 32, True, False, True): (2, 32, 3, 4), + (768, 3072, 2048, 64, 64, False, True, True): (2, 32, 3, 8), + (768, 3072, 2048, 64, 64, True, False, True): (2, 32, 3, 8), + (768, 3072, 2048, 128, 128, False, True, True): (1, 16, 1, 32), + (768, 3072, 2048, 128, 128, True, False, True): (2, 16, 1, 32), + (768, 3072, 4096, 16, 16, False, True, True): (1, 32, 5, 4), + (768, 3072, 4096, 16, 16, True, False, True): (3, 64, 3, 2), + (768, 3072, 4096, 32, 32, False, True, True): (5, 64, 3, 4), + (768, 3072, 4096, 32, 32, True, False, True): (5, 64, 3, 4), + (768, 3072, 4096, 64, 64, False, True, True): (1, 64, 3, 8), + (768, 3072, 4096, 64, 64, True, False, True): (5, 64, 3, 4), + (768, 3072, 4096, 128, 128, False, True, True): (1, 32, 1, 32), + (768, 3072, 4096, 128, 128, True, False, True): (1, 32, 1, 32), + (768, 3072, 8192, 16, 16, False, True, True): (1, 128, 3, 2), + (768, 3072, 8192, 16, 16, True, False, True): (1, 128, 3, 2), + (768, 3072, 8192, 32, 32, False, True, True): (1, 128, 3, 4), + (768, 3072, 8192, 32, 32, True, False, True): (1, 64, 3, 4), + (768, 3072, 8192, 64, 64, False, True, True): (3, 128, 3, 4), + (768, 3072, 8192, 64, 64, True, False, True): (3, 128, 3, 4), + (768, 3072, 8192, 128, 128, False, True, True): (4, 64, 2, 32), + (768, 3072, 8192, 128, 128, True, False, True): (2, 64, 1, 32), + (768, 3072, 16384, 16, 16, False, True, True): (1, 256, 2, 2), + (768, 3072, 16384, 16, 16, True, False, True): (1, 64, 3, 4), + (768, 3072, 16384, 32, 32, False, True, True): (8, 128, 3, 4), + (768, 3072, 16384, 32, 32, True, False, True): (1, 128, 3, 4), + (768, 3072, 16384, 64, 64, False, True, True): (1, 256, 3, 4), + (768, 3072, 16384, 64, 64, True, False, True): (3, 256, 3, 4), + (768, 3072, 16384, 128, 128, False, True, True): (3, 128, 1, 32), + (768, 3072, 16384, 128, 128, True, False, True): (2, 128, 2, 32), + (768, 3072, 32768, 16, 16, False, 
True, True): (1, 512, 3, 1), + (768, 3072, 32768, 16, 16, True, False, True): (1, 128, 3, 4), + (768, 3072, 32768, 32, 32, False, True, True): (1, 256, 3, 4), + (768, 3072, 32768, 32, 32, True, False, True): (1, 256, 3, 4), + (768, 3072, 32768, 64, 64, False, True, True): (2, 512, 3, 4), + (768, 3072, 32768, 64, 64, True, False, True): (1, 512, 3, 4), + (768, 3072, 32768, 128, 128, False, True, True): (1, 256, 1, 32), + (768, 3072, 32768, 128, 128, True, False, True): (2, 256, 2, 32), + (768, 3072, 50432, 16, 16, False, True, True): (1, 197, 3, 4), + (768, 3072, 50432, 16, 16, True, False, True): (1, 197, 3, 4), + (768, 3072, 50432, 32, 32, False, True, True): (1, 788, 2, 4), + (768, 3072, 50432, 32, 32, True, False, True): (1, 394, 3, 4), + (768, 3072, 50432, 64, 64, False, True, True): (1, 788, 3, 4), + (768, 3072, 50432, 64, 64, True, False, True): (2, 788, 3, 4), + (768, 3072, 50432, 128, 128, False, True, True): (1, 394, 1, 32), + (768, 3072, 50432, 128, 128, True, False, True): (2, 394, 2, 32), + (768, 3072, 65536, 16, 16, False, True, True): (1, 1024, 3, 1), + (768, 3072, 65536, 16, 16, True, False, True): (1, 256, 3, 4), + (768, 3072, 65536, 32, 32, False, True, True): (1, 512, 3, 4), + (768, 3072, 65536, 32, 32, True, False, True): (1, 512, 3, 4), + (768, 3072, 65536, 64, 64, False, True, True): (2, 1024, 3, 4), + (768, 3072, 65536, 64, 64, True, False, True): (5, 1024, 3, 4), + (768, 3072, 65536, 128, 128, False, True, True): (1, 512, 1, 32), + (768, 3072, 65536, 128, 128, True, False, True): (2, 512, 2, 32), + (768, 3072, 131072, 16, 16, False, True, True): (1, 2048, 3, 1), + (768, 3072, 131072, 16, 16, True, False, True): (1, 512, 3, 4), + (768, 3072, 131072, 32, 32, False, True, True): (1, 1024, 3, 4), + (768, 3072, 131072, 32, 32, True, False, True): (1, 1024, 3, 4), + (768, 3072, 131072, 64, 64, False, True, True): (1, 2048, 3, 4), + (768, 3072, 131072, 64, 64, True, False, True): (2, 2048, 3, 4), + (768, 3072, 131072, 128, 128, False, True, True): (1, 1024, 1, 32), + (768, 3072, 131072, 128, 128, True, False, True): (1, 1024, 2, 32), + (1024, 1024, 256, 16, 16, False, True, True): (4, 8, 3, 2), + (1024, 1024, 256, 16, 16, True, False, True): (2, 8, 3, 2), + (1024, 1024, 256, 32, 32, False, True, True): (1, 8, 3, 4), + (1024, 1024, 256, 32, 32, True, False, True): (1, 8, 3, 4), + (1024, 1024, 256, 64, 64, False, True, True): (1, 4, 3, 8), + (1024, 1024, 256, 64, 64, True, False, True): (2, 4, 3, 8), + (1024, 1024, 256, 128, 128, False, True, True): (3, 2, 1, 32), + (1024, 1024, 256, 128, 128, True, False, True): (5, 2, 1, 32), + (1024, 1024, 512, 16, 16, False, True, True): (3, 8, 3, 4), + (1024, 1024, 512, 16, 16, True, False, True): (3, 8, 3, 4), + (1024, 1024, 512, 32, 32, False, True, True): (1, 16, 3, 4), + (1024, 1024, 512, 32, 32, True, False, True): (3, 16, 3, 4), + (1024, 1024, 512, 64, 64, False, True, True): (6, 8, 3, 8), + (1024, 1024, 512, 64, 64, True, False, True): (8, 8, 3, 8), + (1024, 1024, 512, 128, 128, False, True, True): (1, 4, 1, 32), + (1024, 1024, 512, 128, 128, True, False, True): (1, 4, 1, 32), + (1024, 1024, 1024, 16, 16, False, True, True): (4, 8, 3, 4), + (1024, 1024, 1024, 16, 16, True, False, True): (1, 8, 3, 4), + (1024, 1024, 1024, 32, 32, False, True, True): (4, 16, 4, 4), + (1024, 1024, 1024, 32, 32, True, False, True): (5, 16, 3, 4), + (1024, 1024, 1024, 64, 64, False, True, True): (6, 16, 3, 8), + (1024, 1024, 1024, 64, 64, True, False, True): (3, 16, 2, 4), + (1024, 1024, 1024, 128, 128, False, True, True): (1, 8, 1, 32), + (1024, 1024, 
1024, 128, 128, True, False, True): (2, 8, 1, 32), + (1024, 1024, 2048, 16, 16, False, True, True): (4, 16, 3, 4), + (1024, 1024, 2048, 16, 16, True, False, True): (1, 16, 3, 4), + (1024, 1024, 2048, 32, 32, False, True, True): (1, 32, 3, 4), + (1024, 1024, 2048, 32, 32, True, False, True): (2, 32, 3, 4), + (1024, 1024, 2048, 64, 64, False, True, True): (4, 32, 2, 4), + (1024, 1024, 2048, 64, 64, True, False, True): (8, 32, 2, 4), + (1024, 1024, 2048, 128, 128, False, True, True): (1, 16, 1, 32), + (1024, 1024, 2048, 128, 128, True, False, True): (1, 16, 1, 32), + (1024, 1024, 4096, 16, 16, False, True, True): (4, 32, 3, 4), + (1024, 1024, 4096, 16, 16, True, False, True): (1, 64, 3, 2), + (1024, 1024, 4096, 32, 32, False, True, True): (1, 64, 3, 4), + (1024, 1024, 4096, 32, 32, True, False, True): (1, 64, 3, 4), + (1024, 1024, 4096, 64, 64, False, True, True): (2, 64, 2, 4), + (1024, 1024, 4096, 64, 64, True, False, True): (2, 64, 2, 4), + (1024, 1024, 4096, 128, 128, False, True, True): (1, 32, 1, 32), + (1024, 1024, 4096, 128, 128, True, False, True): (4, 32, 1, 32), + (1024, 1024, 8192, 16, 16, False, True, True): (1, 128, 3, 1), + (1024, 1024, 8192, 16, 16, True, False, True): (1, 128, 3, 1), + (1024, 1024, 8192, 32, 32, False, True, True): (1, 128, 3, 4), + (1024, 1024, 8192, 32, 32, True, False, True): (1, 128, 3, 4), + (1024, 1024, 8192, 64, 64, False, True, True): (2, 128, 2, 4), + (1024, 1024, 8192, 64, 64, True, False, True): (2, 128, 2, 4), + (1024, 1024, 8192, 128, 128, False, True, True): (1, 64, 1, 32), + (1024, 1024, 8192, 128, 128, True, False, True): (4, 64, 1, 32), + (1024, 1024, 16384, 16, 16, False, True, True): (1, 128, 2, 4), + (1024, 1024, 16384, 16, 16, True, False, True): (4, 256, 3, 1), + (1024, 1024, 16384, 32, 32, False, True, True): (1, 256, 3, 4), + (1024, 1024, 16384, 32, 32, True, False, True): (1, 256, 3, 4), + (1024, 1024, 16384, 64, 64, False, True, True): (1, 256, 2, 4), + (1024, 1024, 16384, 64, 64, True, False, True): (1, 256, 2, 4), + (1024, 1024, 16384, 128, 128, False, True, True): (1, 128, 1, 32), + (1024, 1024, 16384, 128, 128, True, False, True): (4, 128, 1, 32), + (1024, 1024, 32768, 16, 16, False, True, True): (1, 256, 2, 4), + (1024, 1024, 32768, 16, 16, True, False, True): (4, 512, 3, 1), + (1024, 1024, 32768, 32, 32, False, True, True): (1, 512, 3, 4), + (1024, 1024, 32768, 32, 32, True, False, True): (1, 512, 3, 4), + (1024, 1024, 32768, 64, 64, False, True, True): (1, 512, 2, 4), + (1024, 1024, 32768, 64, 64, True, False, True): (1, 512, 2, 4), + (1024, 1024, 32768, 128, 128, False, True, True): (1, 256, 1, 32), + (1024, 1024, 32768, 128, 128, True, False, True): (1, 256, 1, 32), + (1024, 1024, 65536, 16, 16, False, True, True): (1, 512, 2, 4), + (1024, 1024, 65536, 16, 16, True, False, True): (1, 1024, 3, 1), + (1024, 1024, 65536, 32, 32, False, True, True): (1, 1024, 3, 4), + (1024, 1024, 65536, 32, 32, True, False, True): (1, 512, 3, 4), + (1024, 1024, 65536, 64, 64, False, True, True): (1, 1024, 2, 4), + (1024, 1024, 65536, 64, 64, True, False, True): (1, 1024, 2, 4), + (1024, 1024, 65536, 128, 128, False, True, True): (1, 512, 1, 32), + (1024, 1024, 65536, 128, 128, True, False, True): (1, 512, 1, 32), + (1024, 1024, 131072, 16, 16, False, True, True): (4, 2048, 3, 1), + (1024, 1024, 131072, 16, 16, True, False, True): (4, 2048, 3, 1), + (1024, 1024, 131072, 32, 32, False, True, True): (1, 2048, 3, 4), + (1024, 1024, 131072, 32, 32, True, False, True): (1, 1024, 3, 4), + (1024, 1024, 131072, 64, 64, False, True, True): (1, 2048, 2, 
4), + (1024, 1024, 131072, 64, 64, True, False, True): (1, 2048, 2, 4), + (1024, 1024, 131072, 128, 128, False, True, True): (1, 1024, 1, 32), + (1024, 1024, 131072, 128, 128, True, False, True): (1, 1024, 1, 32), + (1536, 1536, 256, 16, 16, False, True, True): (5, 4, 3, 2), + (1536, 1536, 256, 16, 16, True, False, True): (2, 2, 3, 4), + (1536, 1536, 256, 32, 32, False, True, True): (1, 8, 2, 4), + (1536, 1536, 256, 32, 32, True, False, True): (2, 4, 3, 4), + (1536, 1536, 256, 64, 64, False, True, True): (1, 4, 3, 8), + (1536, 1536, 256, 64, 64, True, False, True): (2, 4, 3, 8), + (1536, 1536, 256, 128, 128, False, True, True): (1, 2, 1, 32), + (1536, 1536, 256, 128, 128, True, False, True): (2, 2, 2, 32), + (1536, 1536, 512, 16, 16, False, True, True): (1, 8, 3, 2), + (1536, 1536, 512, 16, 16, True, False, True): (1, 8, 3, 2), + (1536, 1536, 512, 32, 32, False, True, True): (1, 16, 3, 4), + (1536, 1536, 512, 32, 32, True, False, True): (1, 16, 3, 4), + (1536, 1536, 512, 64, 64, False, True, True): (3, 8, 3, 8), + (1536, 1536, 512, 64, 64, True, False, True): (3, 8, 3, 8), + (1536, 1536, 512, 128, 128, False, True, True): (1, 4, 1, 32), + (1536, 1536, 512, 128, 128, True, False, True): (2, 4, 2, 32), + (1536, 1536, 1024, 16, 16, False, True, True): (2, 8, 3, 4), + (1536, 1536, 1024, 16, 16, True, False, True): (2, 8, 3, 4), + (1536, 1536, 1024, 32, 32, False, True, True): (1, 16, 3, 4), + (1536, 1536, 1024, 32, 32, True, False, True): (1, 16, 3, 4), + (1536, 1536, 1024, 64, 64, False, True, True): (2, 16, 3, 8), + (1536, 1536, 1024, 64, 64, True, False, True): (2, 16, 3, 8), + (1536, 1536, 1024, 128, 128, False, True, True): (3, 8, 1, 32), + (1536, 1536, 1024, 128, 128, True, False, True): (1, 8, 2, 32), + (1536, 1536, 2048, 16, 16, False, True, True): (1, 32, 3, 2), + (1536, 1536, 2048, 16, 16, True, False, True): (1, 32, 3, 2), + (1536, 1536, 2048, 32, 32, False, True, True): (3, 32, 2, 4), + (1536, 1536, 2048, 32, 32, True, False, True): (4, 32, 3, 4), + (1536, 1536, 2048, 64, 64, False, True, True): (1, 32, 3, 4), + (1536, 1536, 2048, 64, 64, True, False, True): (1, 32, 3, 4), + (1536, 1536, 2048, 128, 128, False, True, True): (1, 16, 1, 32), + (1536, 1536, 2048, 128, 128, True, False, True): (2, 16, 1, 32), + (1536, 1536, 4096, 16, 16, False, True, True): (1, 64, 3, 2), + (1536, 1536, 4096, 16, 16, True, False, True): (1, 16, 3, 4), + (1536, 1536, 4096, 32, 32, False, True, True): (1, 64, 2, 4), + (1536, 1536, 4096, 32, 32, True, False, True): (1, 64, 2, 4), + (1536, 1536, 4096, 64, 64, False, True, True): (1, 64, 3, 4), + (1536, 1536, 4096, 64, 64, True, False, True): (1, 64, 3, 4), + (1536, 1536, 4096, 128, 128, False, True, True): (1, 32, 1, 32), + (1536, 1536, 4096, 128, 128, True, False, True): (4, 32, 2, 32), + (1536, 1536, 8192, 16, 16, False, True, True): (1, 32, 3, 4), + (1536, 1536, 8192, 16, 16, True, False, True): (5, 32, 3, 4), + (1536, 1536, 8192, 32, 32, False, True, True): (1, 128, 2, 4), + (1536, 1536, 8192, 32, 32, True, False, True): (1, 128, 2, 4), + (1536, 1536, 8192, 64, 64, False, True, True): (1, 128, 3, 4), + (1536, 1536, 8192, 64, 64, True, False, True): (1, 128, 3, 4), + (1536, 1536, 8192, 128, 128, False, True, True): (1, 64, 1, 32), + (1536, 1536, 8192, 128, 128, True, False, True): (4, 64, 2, 32), + (1536, 1536, 16384, 16, 16, False, True, True): (1, 64, 3, 4), + (1536, 1536, 16384, 16, 16, True, False, True): (1, 64, 3, 4), + (1536, 1536, 16384, 32, 32, False, True, True): (1, 256, 2, 4), + (1536, 1536, 16384, 32, 32, True, False, True): (1, 128, 3, 4), 
+ (1536, 1536, 16384, 64, 64, False, True, True): (1, 256, 3, 4), + (1536, 1536, 16384, 64, 64, True, False, True): (3, 256, 3, 4), + (1536, 1536, 16384, 128, 128, False, True, True): (1, 128, 1, 32), + (1536, 1536, 16384, 128, 128, True, False, True): (4, 128, 2, 32), + (1536, 1536, 32768, 16, 16, False, True, True): (1, 128, 3, 4), + (1536, 1536, 32768, 16, 16, True, False, True): (1, 128, 3, 4), + (1536, 1536, 32768, 32, 32, False, True, True): (1, 256, 3, 4), + (1536, 1536, 32768, 32, 32, True, False, True): (1, 256, 3, 4), + (1536, 1536, 32768, 64, 64, False, True, True): (1, 512, 3, 4), + (1536, 1536, 32768, 64, 64, True, False, True): (1, 512, 3, 4), + (1536, 1536, 32768, 128, 128, False, True, True): (1, 256, 1, 32), + (1536, 1536, 32768, 128, 128, True, False, True): (4, 256, 2, 32), + (1536, 1536, 65536, 16, 16, False, True, True): (5, 256, 3, 4), + (1536, 1536, 65536, 16, 16, True, False, True): (1, 256, 3, 4), + (1536, 1536, 65536, 32, 32, False, True, True): (1, 512, 3, 4), + (1536, 1536, 65536, 32, 32, True, False, True): (1, 512, 3, 4), + (1536, 1536, 65536, 64, 64, False, True, True): (1, 1024, 3, 4), + (1536, 1536, 65536, 64, 64, True, False, True): (1, 1024, 3, 4), + (1536, 1536, 65536, 128, 128, False, True, True): (1, 512, 1, 32), + (1536, 1536, 65536, 128, 128, True, False, True): (4, 512, 2, 32), + (1536, 1536, 131072, 16, 16, False, True, True): (3, 512, 3, 4), + (1536, 1536, 131072, 16, 16, True, False, True): (1, 512, 3, 4), + (1536, 1536, 131072, 32, 32, False, True, True): (1, 1024, 3, 4), + (1536, 1536, 131072, 32, 32, True, False, True): (1, 1024, 3, 4), + (1536, 1536, 131072, 64, 64, False, True, True): (1, 2048, 3, 4), + (1536, 1536, 131072, 64, 64, True, False, True): (1, 2048, 3, 4), + (1536, 1536, 131072, 128, 128, False, True, True): (1, 1024, 1, 32), + (1536, 1536, 131072, 128, 128, True, False, True): (4, 1024, 2, 32), + (2048, 2048, 256, 16, 16, False, True, True): (1, 4, 3, 4), + (2048, 2048, 256, 16, 16, True, False, True): (1, 4, 3, 4), + (2048, 2048, 256, 32, 32, False, True, True): (3, 8, 3, 4), + (2048, 2048, 256, 32, 32, True, False, True): (3, 8, 3, 4), + (2048, 2048, 256, 64, 64, False, True, True): (4, 4, 4, 8), + (2048, 2048, 256, 64, 64, True, False, True): (8, 4, 4, 8), + (2048, 2048, 256, 128, 128, False, True, True): (3, 2, 1, 32), + (2048, 2048, 256, 128, 128, True, False, True): (3, 2, 1, 32), + (2048, 2048, 512, 16, 16, False, True, True): (4, 8, 3, 2), + (2048, 2048, 512, 16, 16, True, False, True): (4, 8, 3, 2), + (2048, 2048, 512, 32, 32, False, True, True): (3, 8, 3, 4), + (2048, 2048, 512, 32, 32, True, False, True): (1, 16, 2, 4), + (2048, 2048, 512, 64, 64, False, True, True): (4, 8, 2, 4), + (2048, 2048, 512, 64, 64, True, False, True): (4, 8, 2, 4), + (2048, 2048, 512, 128, 128, False, True, True): (1, 4, 1, 32), + (2048, 2048, 512, 128, 128, True, False, True): (4, 4, 1, 32), + (2048, 2048, 1024, 16, 16, False, True, True): (4, 8, 3, 4), + (2048, 2048, 1024, 16, 16, True, False, True): (4, 8, 3, 4), + (2048, 2048, 1024, 32, 32, False, True, True): (4, 16, 3, 4), + (2048, 2048, 1024, 32, 32, True, False, True): (1, 16, 3, 4), + (2048, 2048, 1024, 64, 64, False, True, True): (2, 16, 2, 4), + (2048, 2048, 1024, 64, 64, True, False, True): (2, 16, 2, 4), + (2048, 2048, 1024, 128, 128, False, True, True): (8, 8, 1, 32), + (2048, 2048, 1024, 128, 128, True, False, True): (4, 8, 1, 32), + (2048, 2048, 2048, 16, 16, False, True, True): (4, 32, 3, 1), + (2048, 2048, 2048, 16, 16, True, False, True): (3, 32, 3, 2), + (2048, 2048, 
2048, 32, 32, False, True, True): (1, 32, 3, 4), + (2048, 2048, 2048, 32, 32, True, False, True): (1, 32, 3, 4), + (2048, 2048, 2048, 64, 64, False, True, True): (2, 32, 2, 4), + (2048, 2048, 2048, 64, 64, True, False, True): (2, 32, 2, 4), + (2048, 2048, 2048, 128, 128, False, True, True): (6, 16, 1, 32), + (2048, 2048, 2048, 128, 128, True, False, True): (4, 16, 1, 32), + (2048, 2048, 4096, 16, 16, False, True, True): (4, 64, 3, 1), + (2048, 2048, 4096, 16, 16, True, False, True): (1, 64, 3, 1), + (2048, 2048, 4096, 32, 32, False, True, True): (1, 64, 3, 4), + (2048, 2048, 4096, 32, 32, True, False, True): (4, 64, 3, 4), + (2048, 2048, 4096, 64, 64, False, True, True): (2, 64, 2, 4), + (2048, 2048, 4096, 64, 64, True, False, True): (2, 64, 2, 4), + (2048, 2048, 4096, 128, 128, False, True, True): (4, 32, 1, 32), + (2048, 2048, 4096, 128, 128, True, False, True): (4, 32, 1, 32), + (2048, 2048, 8192, 16, 16, False, True, True): (4, 128, 3, 1), + (2048, 2048, 8192, 16, 16, True, False, True): (1, 128, 3, 1), + (2048, 2048, 8192, 32, 32, False, True, True): (4, 128, 3, 4), + (2048, 2048, 8192, 32, 32, True, False, True): (4, 64, 3, 4), + (2048, 2048, 8192, 64, 64, False, True, True): (1, 128, 2, 4), + (2048, 2048, 8192, 64, 64, True, False, True): (2, 128, 2, 4), + (2048, 2048, 8192, 128, 128, False, True, True): (1, 64, 1, 32), + (2048, 2048, 8192, 128, 128, True, False, True): (4, 64, 1, 32), + (2048, 2048, 16384, 16, 16, False, True, True): (4, 256, 3, 1), + (2048, 2048, 16384, 16, 16, True, False, True): (1, 256, 3, 1), + (2048, 2048, 16384, 32, 32, False, True, True): (1, 256, 3, 4), + (2048, 2048, 16384, 32, 32, True, False, True): (1, 128, 3, 4), + (2048, 2048, 16384, 64, 64, False, True, True): (1, 256, 2, 4), + (2048, 2048, 16384, 64, 64, True, False, True): (1, 256, 2, 4), + (2048, 2048, 16384, 128, 128, False, True, True): (1, 128, 1, 32), + (2048, 2048, 16384, 128, 128, True, False, True): (4, 128, 1, 32), + (2048, 2048, 32768, 16, 16, False, True, True): (8, 512, 3, 1), + (2048, 2048, 32768, 16, 16, True, False, True): (1, 512, 3, 1), + (2048, 2048, 32768, 32, 32, False, True, True): (1, 512, 3, 4), + (2048, 2048, 32768, 32, 32, True, False, True): (1, 256, 3, 4), + (2048, 2048, 32768, 64, 64, False, True, True): (1, 512, 2, 4), + (2048, 2048, 32768, 64, 64, True, False, True): (1, 512, 2, 4), + (2048, 2048, 32768, 128, 128, False, True, True): (1, 256, 1, 32), + (2048, 2048, 32768, 128, 128, True, False, True): (4, 256, 1, 32), + (2048, 2048, 65536, 16, 16, False, True, True): (4, 1024, 3, 1), + (2048, 2048, 65536, 16, 16, True, False, True): (1, 1024, 3, 1), + (2048, 2048, 65536, 32, 32, False, True, True): (1, 1024, 3, 4), + (2048, 2048, 65536, 32, 32, True, False, True): (1, 512, 3, 4), + (2048, 2048, 65536, 64, 64, False, True, True): (1, 1024, 2, 4), + (2048, 2048, 65536, 64, 64, True, False, True): (1, 1024, 2, 4), + (2048, 2048, 65536, 128, 128, False, True, True): (1, 512, 1, 32), + (2048, 2048, 65536, 128, 128, True, False, True): (4, 512, 1, 32), + (2048, 2048, 131072, 16, 16, False, True, True): (4, 2048, 3, 1), + (2048, 2048, 131072, 16, 16, True, False, True): (1, 2048, 3, 1), + (2048, 2048, 131072, 32, 32, False, True, True): (1, 2048, 3, 4), + (2048, 2048, 131072, 32, 32, True, False, True): (1, 1024, 3, 4), + (2048, 2048, 131072, 64, 64, False, True, True): (1, 2048, 2, 4), + (2048, 2048, 131072, 64, 64, True, False, True): (1, 2048, 2, 4), + (2048, 2048, 131072, 128, 128, False, True, True): (1, 1024, 1, 32), + (2048, 2048, 131072, 128, 128, True, False, True): 
(4, 1024, 1, 32), + (3072, 768, 256, 16, 16, False, True, True): (4, 4, 3, 2), + (3072, 768, 256, 16, 16, True, False, True): (1, 2, 6, 4), + (3072, 768, 256, 32, 32, False, True, True): (1, 4, 6, 4), + (3072, 768, 256, 32, 32, True, False, True): (5, 4, 3, 4), + (3072, 768, 256, 64, 64, False, True, True): (4, 4, 3, 8), + (3072, 768, 256, 64, 64, True, False, True): (4, 4, 3, 8), + (3072, 768, 256, 128, 128, False, True, True): (1, 2, 1, 32), + (3072, 768, 256, 128, 128, True, False, True): (5, 2, 1, 32), + (3072, 768, 512, 16, 16, False, True, True): (4, 4, 3, 4), + (3072, 768, 512, 16, 16, True, False, True): (1, 4, 3, 4), + (3072, 768, 512, 32, 32, False, True, True): (3, 8, 3, 4), + (3072, 768, 512, 32, 32, True, False, True): (3, 8, 3, 4), + (3072, 768, 512, 64, 64, False, True, True): (2, 8, 3, 8), + (3072, 768, 512, 64, 64, True, False, True): (2, 8, 3, 8), + (3072, 768, 512, 128, 128, False, True, True): (1, 4, 2, 32), + (3072, 768, 512, 128, 128, True, False, True): (1, 4, 1, 32), + (3072, 768, 1024, 16, 16, False, True, True): (1, 16, 3, 2), + (3072, 768, 1024, 16, 16, True, False, True): (3, 16, 3, 2), + (3072, 768, 1024, 32, 32, False, True, True): (1, 16, 3, 4), + (3072, 768, 1024, 32, 32, True, False, True): (3, 16, 3, 4), + (3072, 768, 1024, 64, 64, False, True, True): (4, 16, 3, 8), + (3072, 768, 1024, 64, 64, True, False, True): (4, 16, 3, 4), + (3072, 768, 1024, 128, 128, False, True, True): (5, 8, 1, 32), + (3072, 768, 1024, 128, 128, True, False, True): (5, 8, 1, 32), + (3072, 768, 2048, 16, 16, False, True, True): (4, 32, 3, 2), + (3072, 768, 2048, 16, 16, True, False, True): (1, 32, 3, 2), + (3072, 768, 2048, 32, 32, False, True, True): (1, 32, 3, 4), + (3072, 768, 2048, 32, 32, True, False, True): (1, 32, 2, 4), + (3072, 768, 2048, 64, 64, False, True, True): (2, 32, 3, 4), + (3072, 768, 2048, 64, 64, True, False, True): (4, 32, 3, 4), + (3072, 768, 2048, 128, 128, False, True, True): (1, 16, 1, 32), + (3072, 768, 2048, 128, 128, True, False, True): (1, 16, 1, 32), + (3072, 768, 4096, 16, 16, False, True, True): (3, 64, 3, 2), + (3072, 768, 4096, 16, 16, True, False, True): (1, 64, 3, 2), + (3072, 768, 4096, 32, 32, False, True, True): (1, 64, 3, 4), + (3072, 768, 4096, 32, 32, True, False, True): (1, 32, 3, 4), + (3072, 768, 4096, 64, 64, False, True, True): (2, 64, 3, 4), + (3072, 768, 4096, 64, 64, True, False, True): (2, 64, 3, 4), + (3072, 768, 4096, 128, 128, False, True, True): (1, 32, 1, 32), + (3072, 768, 4096, 128, 128, True, False, True): (1, 32, 1, 32), + (3072, 768, 8192, 16, 16, False, True, True): (4, 128, 3, 1), + (3072, 768, 8192, 16, 16, True, False, True): (1, 32, 3, 4), + (3072, 768, 8192, 32, 32, False, True, True): (1, 64, 3, 4), + (3072, 768, 8192, 32, 32, True, False, True): (1, 64, 3, 4), + (3072, 768, 8192, 64, 64, False, True, True): (2, 128, 3, 4), + (3072, 768, 8192, 64, 64, True, False, True): (2, 128, 3, 4), + (3072, 768, 8192, 128, 128, False, True, True): (1, 64, 1, 32), + (3072, 768, 8192, 128, 128, True, False, True): (1, 64, 1, 32), + (3072, 768, 16384, 16, 16, False, True, True): (4, 256, 3, 1), + (3072, 768, 16384, 16, 16, True, False, True): (1, 64, 3, 4), + (3072, 768, 16384, 32, 32, False, True, True): (1, 128, 3, 4), + (3072, 768, 16384, 32, 32, True, False, True): (1, 128, 3, 4), + (3072, 768, 16384, 64, 64, False, True, True): (2, 256, 3, 4), + (3072, 768, 16384, 64, 64, True, False, True): (2, 256, 3, 4), + (3072, 768, 16384, 128, 128, False, True, True): (1, 128, 1, 32), + (3072, 768, 16384, 128, 128, True, False, True): 
(1, 128, 1, 32), + (3072, 768, 32768, 16, 16, False, True, True): (4, 512, 3, 1), + (3072, 768, 32768, 16, 16, True, False, True): (1, 128, 3, 4), + (3072, 768, 32768, 32, 32, False, True, True): (1, 256, 3, 4), + (3072, 768, 32768, 32, 32, True, False, True): (1, 256, 3, 4), + (3072, 768, 32768, 64, 64, False, True, True): (2, 512, 3, 4), + (3072, 768, 32768, 64, 64, True, False, True): (2, 512, 3, 4), + (3072, 768, 32768, 128, 128, False, True, True): (1, 256, 1, 32), + (3072, 768, 32768, 128, 128, True, False, True): (1, 256, 1, 32), + (3072, 768, 50432, 16, 16, False, True, True): (4, 788, 3, 1), + (3072, 768, 50432, 16, 16, True, False, True): (1, 197, 3, 4), + (3072, 768, 50432, 32, 32, False, True, True): (1, 394, 3, 4), + (3072, 768, 50432, 32, 32, True, False, True): (1, 394, 3, 4), + (3072, 768, 50432, 64, 64, False, True, True): (1, 788, 3, 4), + (3072, 768, 50432, 64, 64, True, False, True): (2, 788, 3, 4), + (3072, 768, 50432, 128, 128, False, True, True): (1, 394, 1, 32), + (3072, 768, 50432, 128, 128, True, False, True): (1, 394, 1, 32), + (3072, 768, 65536, 16, 16, False, True, True): (4, 1024, 3, 1), + (3072, 768, 65536, 16, 16, True, False, True): (1, 256, 3, 4), + (3072, 768, 65536, 32, 32, False, True, True): (1, 512, 3, 4), + (3072, 768, 65536, 32, 32, True, False, True): (1, 512, 3, 4), + (3072, 768, 65536, 64, 64, False, True, True): (2, 1024, 3, 4), + (3072, 768, 65536, 64, 64, True, False, True): (2, 1024, 3, 4), + (3072, 768, 65536, 128, 128, False, True, True): (1, 512, 1, 32), + (3072, 768, 65536, 128, 128, True, False, True): (1, 512, 1, 32), + (3072, 768, 131072, 16, 16, False, True, True): (4, 2048, 3, 1), + (3072, 768, 131072, 16, 16, True, False, True): (1, 512, 3, 4), + (3072, 768, 131072, 32, 32, False, True, True): (1, 1024, 3, 4), + (3072, 768, 131072, 32, 32, True, False, True): (1, 1024, 3, 4), + (3072, 768, 131072, 64, 64, False, True, True): (2, 2048, 3, 4), + (3072, 768, 131072, 64, 64, True, False, True): (2, 2048, 3, 4), + (3072, 768, 131072, 128, 128, False, True, True): (1, 1024, 1, 32), + (3072, 768, 131072, 128, 128, True, False, True): (1, 1024, 1, 32), + (3072, 3072, 256, 16, 16, False, True, True): (1, 4, 5, 2), + (3072, 3072, 256, 16, 16, True, False, True): (1, 4, 3, 2), + (3072, 3072, 256, 32, 32, False, True, True): (1, 4, 4, 4), + (3072, 3072, 256, 32, 32, True, False, True): (1, 4, 3, 4), + (3072, 3072, 256, 64, 64, False, True, True): (2, 4, 3, 8), + (3072, 3072, 256, 64, 64, True, False, True): (2, 4, 3, 8), + (3072, 3072, 256, 128, 128, False, True, True): (6, 2, 1, 32), + (3072, 3072, 256, 128, 128, True, False, True): (8, 2, 2, 32), + (3072, 3072, 512, 16, 16, False, True, True): (2, 4, 3, 4), + (3072, 3072, 512, 16, 16, True, False, True): (2, 4, 3, 4), + (3072, 3072, 512, 32, 32, False, True, True): (2, 8, 3, 4), + (3072, 3072, 512, 32, 32, True, False, True): (2, 8, 3, 4), + (3072, 3072, 512, 64, 64, False, True, True): (2, 8, 3, 8), + (3072, 3072, 512, 64, 64, True, False, True): (2, 8, 3, 8), + (3072, 3072, 512, 128, 128, False, True, True): (5, 4, 1, 32), + (3072, 3072, 512, 128, 128, True, False, True): (5, 4, 2, 32), + (3072, 3072, 1024, 16, 16, False, True, True): (1, 16, 3, 2), + (3072, 3072, 1024, 16, 16, True, False, True): (1, 16, 3, 2), + (3072, 3072, 1024, 32, 32, False, True, True): (2, 16, 3, 4), + (3072, 3072, 1024, 32, 32, True, False, True): (1, 16, 3, 4), + (3072, 3072, 1024, 64, 64, False, True, True): (1, 16, 3, 4), + (3072, 3072, 1024, 64, 64, True, False, True): (1, 16, 3, 4), + (3072, 3072, 1024, 128, 
128, False, True, True): (1, 8, 1, 32), + (3072, 3072, 1024, 128, 128, True, False, True): (3, 8, 2, 32), + (3072, 3072, 2048, 16, 16, False, True, True): (1, 32, 3, 2), + (3072, 3072, 2048, 16, 16, True, False, True): (1, 16, 2, 4), + (3072, 3072, 2048, 32, 32, False, True, True): (1, 32, 2, 4), + (3072, 3072, 2048, 32, 32, True, False, True): (1, 32, 3, 4), + (3072, 3072, 2048, 64, 64, False, True, True): (1, 32, 3, 4), + (3072, 3072, 2048, 64, 64, True, False, True): (1, 32, 3, 4), + (3072, 3072, 2048, 128, 128, False, True, True): (1, 16, 1, 32), + (3072, 3072, 2048, 128, 128, True, False, True): (4, 16, 2, 32), + (3072, 3072, 4096, 16, 16, False, True, True): (2, 16, 3, 4), + (3072, 3072, 4096, 16, 16, True, False, True): (2, 16, 3, 4), + (3072, 3072, 4096, 32, 32, False, True, True): (1, 64, 2, 4), + (3072, 3072, 4096, 32, 32, True, False, True): (1, 32, 3, 4), + (3072, 3072, 4096, 64, 64, False, True, True): (1, 64, 3, 4), + (3072, 3072, 4096, 64, 64, True, False, True): (1, 64, 3, 4), + (3072, 3072, 4096, 128, 128, False, True, True): (1, 32, 1, 32), + (3072, 3072, 4096, 128, 128, True, False, True): (2, 32, 2, 32), + (3072, 3072, 8192, 16, 16, False, True, True): (2, 32, 3, 4), + (3072, 3072, 8192, 16, 16, True, False, True): (2, 32, 3, 4), + (3072, 3072, 8192, 32, 32, False, True, True): (1, 64, 3, 4), + (3072, 3072, 8192, 32, 32, True, False, True): (1, 64, 3, 4), + (3072, 3072, 8192, 64, 64, False, True, True): (1, 128, 3, 4), + (3072, 3072, 8192, 64, 64, True, False, True): (1, 128, 3, 4), + (3072, 3072, 8192, 128, 128, False, True, True): (1, 64, 1, 32), + (3072, 3072, 8192, 128, 128, True, False, True): (4, 64, 2, 32), + (3072, 3072, 16384, 16, 16, False, True, True): (2, 64, 3, 4), + (3072, 3072, 16384, 16, 16, True, False, True): (1, 64, 3, 4), + (3072, 3072, 16384, 32, 32, False, True, True): (1, 128, 3, 4), + (3072, 3072, 16384, 32, 32, True, False, True): (1, 128, 3, 4), + (3072, 3072, 16384, 64, 64, False, True, True): (1, 256, 3, 4), + (3072, 3072, 16384, 64, 64, True, False, True): (1, 256, 3, 4), + (3072, 3072, 16384, 128, 128, False, True, True): (1, 128, 1, 32), + (3072, 3072, 16384, 128, 128, True, False, True): (4, 128, 2, 32), + (3072, 3072, 32768, 16, 16, False, True, True): (3, 128, 3, 4), + (3072, 3072, 32768, 16, 16, True, False, True): (1, 128, 3, 4), + (3072, 3072, 32768, 32, 32, False, True, True): (1, 256, 3, 4), + (3072, 3072, 32768, 32, 32, True, False, True): (1, 256, 3, 4), + (3072, 3072, 32768, 64, 64, False, True, True): (1, 512, 3, 4), + (3072, 3072, 32768, 64, 64, True, False, True): (1, 512, 3, 4), + (3072, 3072, 32768, 128, 128, False, True, True): (1, 256, 1, 32), + (3072, 3072, 32768, 128, 128, True, False, True): (4, 256, 2, 32), + (3072, 3072, 65536, 16, 16, False, True, True): (5, 256, 3, 4), + (3072, 3072, 65536, 16, 16, True, False, True): (2, 256, 3, 4), + (3072, 3072, 65536, 32, 32, False, True, True): (1, 512, 3, 4), + (3072, 3072, 65536, 32, 32, True, False, True): (1, 512, 3, 4), + (3072, 3072, 65536, 64, 64, False, True, True): (1, 1024, 3, 4), + (3072, 3072, 65536, 64, 64, True, False, True): (1, 1024, 3, 4), + (3072, 3072, 65536, 128, 128, False, True, True): (1, 512, 1, 32), + (3072, 3072, 65536, 128, 128, True, False, True): (4, 512, 2, 32), + (3072, 3072, 131072, 16, 16, False, True, True): (5, 512, 3, 4), + (3072, 3072, 131072, 16, 16, True, False, True): (1, 512, 3, 4), + (3072, 3072, 131072, 32, 32, False, True, True): (1, 1024, 3, 4), + (3072, 3072, 131072, 32, 32, True, False, True): (1, 1024, 3, 4), + (3072, 3072, 
131072, 64, 64, False, True, True): (1, 2048, 3, 4), + (3072, 3072, 131072, 64, 64, True, False, True): (1, 2048, 3, 4), + (3072, 3072, 131072, 128, 128, False, True, True): (1, 1024, 1, 32), + (3072, 3072, 131072, 128, 128, True, False, True): (4, 1024, 2, 32), + (4096, 4096, 256, 16, 16, False, True, True): (1, 4, 3, 2), + (4096, 4096, 256, 16, 16, True, False, True): (1, 2, 3, 4), + (4096, 4096, 256, 32, 32, False, True, True): (4, 4, 4, 4), + (4096, 4096, 256, 32, 32, True, False, True): (4, 4, 4, 4), + (4096, 4096, 256, 64, 64, False, True, True): (1, 4, 3, 8), + (4096, 4096, 256, 64, 64, True, False, True): (4, 4, 2, 4), + (4096, 4096, 256, 128, 128, False, True, True): (1, 2, 1, 32), + (4096, 4096, 256, 128, 128, True, False, True): (3, 2, 1, 32), + (4096, 4096, 512, 16, 16, False, True, True): (1, 4, 3, 4), + (4096, 4096, 512, 16, 16, True, False, True): (5, 8, 3, 2), + (4096, 4096, 512, 32, 32, False, True, True): (4, 8, 3, 4), + (4096, 4096, 512, 32, 32, True, False, True): (4, 8, 3, 4), + (4096, 4096, 512, 64, 64, False, True, True): (1, 8, 2, 4), + (4096, 4096, 512, 64, 64, True, False, True): (1, 8, 2, 4), + (4096, 4096, 512, 128, 128, False, True, True): (4, 4, 1, 32), + (4096, 4096, 512, 128, 128, True, False, True): (4, 4, 1, 32), + (4096, 4096, 1024, 16, 16, False, True, True): (1, 8, 3, 4), + (4096, 4096, 1024, 16, 16, True, False, True): (1, 8, 3, 4), + (4096, 4096, 1024, 32, 32, False, True, True): (1, 16, 3, 4), + (4096, 4096, 1024, 32, 32, True, False, True): (1, 16, 3, 4), + (4096, 4096, 1024, 64, 64, False, True, True): (4, 16, 2, 4), + (4096, 4096, 1024, 64, 64, True, False, True): (4, 16, 2, 4), + (4096, 4096, 1024, 128, 128, False, True, True): (4, 8, 1, 32), + (4096, 4096, 1024, 128, 128, True, False, True): (4, 8, 1, 32), + (4096, 4096, 2048, 16, 16, False, True, True): (1, 32, 3, 1), + (4096, 4096, 2048, 16, 16, True, False, True): (6, 8, 3, 4), + (4096, 4096, 2048, 32, 32, False, True, True): (1, 32, 3, 4), + (4096, 4096, 2048, 32, 32, True, False, True): (1, 32, 3, 4), + (4096, 4096, 2048, 64, 64, False, True, True): (4, 32, 2, 4), + (4096, 4096, 2048, 64, 64, True, False, True): (4, 32, 2, 4), + (4096, 4096, 2048, 128, 128, False, True, True): (4, 16, 1, 32), + (4096, 4096, 2048, 128, 128, True, False, True): (4, 16, 1, 32), + (4096, 4096, 4096, 16, 16, False, True, True): (1, 16, 3, 4), + (4096, 4096, 4096, 16, 16, True, False, True): (1, 64, 3, 1), + (4096, 4096, 4096, 32, 32, False, True, True): (1, 64, 3, 4), + (4096, 4096, 4096, 32, 32, True, False, True): (1, 32, 3, 4), + (4096, 4096, 4096, 64, 64, False, True, True): (4, 64, 2, 4), + (4096, 4096, 4096, 64, 64, True, False, True): (4, 64, 2, 4), + (4096, 4096, 4096, 128, 128, False, True, True): (4, 32, 1, 32), + (4096, 4096, 4096, 128, 128, True, False, True): (4, 32, 1, 32), + (4096, 4096, 8192, 16, 16, False, True, True): (4, 128, 3, 1), + (4096, 4096, 8192, 16, 16, True, False, True): (1, 128, 3, 1), + (4096, 4096, 8192, 32, 32, False, True, True): (1, 128, 3, 4), + (4096, 4096, 8192, 32, 32, True, False, True): (1, 64, 3, 4), + (4096, 4096, 8192, 64, 64, False, True, True): (4, 128, 2, 4), + (4096, 4096, 8192, 64, 64, True, False, True): (4, 128, 2, 4), + (4096, 4096, 8192, 128, 128, False, True, True): (4, 64, 1, 32), + (4096, 4096, 8192, 128, 128, True, False, True): (4, 64, 1, 32), + (4096, 4096, 16384, 16, 16, False, True, True): (1, 64, 3, 4), + (4096, 4096, 16384, 16, 16, True, False, True): (1, 256, 3, 1), + (4096, 4096, 16384, 32, 32, False, True, True): (1, 256, 3, 4), + (4096, 4096, 
16384, 32, 32, True, False, True): (1, 128, 3, 4), + (4096, 4096, 16384, 64, 64, False, True, True): (4, 256, 2, 4), + (4096, 4096, 16384, 64, 64, True, False, True): (4, 256, 2, 4), + (4096, 4096, 16384, 128, 128, False, True, True): (4, 128, 1, 32), + (4096, 4096, 16384, 128, 128, True, False, True): (4, 128, 1, 32), + (4096, 4096, 32768, 16, 16, False, True, True): (1, 128, 3, 4), + (4096, 4096, 32768, 16, 16, True, False, True): (1, 512, 3, 1), + (4096, 4096, 32768, 32, 32, False, True, True): (1, 512, 3, 4), + (4096, 4096, 32768, 32, 32, True, False, True): (1, 256, 3, 4), + (4096, 4096, 32768, 64, 64, False, True, True): (4, 512, 2, 4), + (4096, 4096, 32768, 64, 64, True, False, True): (4, 512, 2, 4), + (4096, 4096, 32768, 128, 128, False, True, True): (4, 256, 1, 32), + (4096, 4096, 32768, 128, 128, True, False, True): (4, 256, 1, 32), + (4096, 4096, 65536, 16, 16, False, True, True): (1, 256, 3, 4), + (4096, 4096, 65536, 16, 16, True, False, True): (1, 1024, 3, 1), + (4096, 4096, 65536, 32, 32, False, True, True): (1, 1024, 3, 4), + (4096, 4096, 65536, 32, 32, True, False, True): (1, 512, 3, 4), + (4096, 4096, 65536, 64, 64, False, True, True): (4, 1024, 2, 4), + (4096, 4096, 65536, 64, 64, True, False, True): (2, 1024, 2, 4), + (4096, 4096, 65536, 128, 128, False, True, True): (4, 512, 1, 32), + (4096, 4096, 65536, 128, 128, True, False, True): (4, 512, 1, 32), + (4096, 4096, 131072, 16, 16, False, True, True): (2, 2048, 3, 1), + (4096, 4096, 131072, 16, 16, True, False, True): (1, 2048, 3, 1), + (4096, 4096, 131072, 32, 32, False, True, True): (2, 2048, 3, 4), + (4096, 4096, 131072, 32, 32, True, False, True): (1, 1024, 3, 4), + (4096, 4096, 131072, 64, 64, False, True, True): (2, 2048, 2, 4), + (4096, 4096, 131072, 64, 64, True, False, True): (2, 2048, 2, 4), + (4096, 4096, 131072, 128, 128, False, True, True): (4, 1024, 1, 32), + (4096, 4096, 131072, 128, 128, True, False, True): (4, 1024, 1, 32), + (6144, 6144, 256, 16, 16, False, True, True): (2, 2, 3, 4), + (6144, 6144, 256, 16, 16, True, False, True): (2, 2, 3, 4), + (6144, 6144, 256, 32, 32, False, True, True): (2, 4, 3, 4), + (6144, 6144, 256, 32, 32, True, False, True): (2, 4, 3, 4), + (6144, 6144, 256, 64, 64, False, True, True): (1, 4, 3, 4), + (6144, 6144, 256, 64, 64, True, False, True): (1, 4, 3, 4), + (6144, 6144, 256, 128, 128, False, True, True): (1, 2, 1, 32), + (6144, 6144, 256, 128, 128, True, False, True): (5, 2, 2, 32), + (6144, 6144, 512, 16, 16, False, True, True): (4, 8, 3, 2), + (6144, 6144, 512, 16, 16, True, False, True): (4, 8, 3, 2), + (6144, 6144, 512, 32, 32, False, True, True): (2, 8, 3, 4), + (6144, 6144, 512, 32, 32, True, False, True): (2, 8, 3, 4), + (6144, 6144, 512, 64, 64, False, True, True): (1, 8, 3, 4), + (6144, 6144, 512, 64, 64, True, False, True): (1, 8, 3, 4), + (6144, 6144, 512, 128, 128, False, True, True): (1, 4, 1, 32), + (6144, 6144, 512, 128, 128, True, False, True): (4, 4, 2, 32), + (6144, 6144, 1024, 16, 16, False, True, True): (4, 16, 3, 2), + (6144, 6144, 1024, 16, 16, True, False, True): (4, 4, 3, 4), + (6144, 6144, 1024, 32, 32, False, True, True): (1, 16, 3, 4), + (6144, 6144, 1024, 32, 32, True, False, True): (1, 16, 3, 4), + (6144, 6144, 1024, 64, 64, False, True, True): (1, 16, 3, 4), + (6144, 6144, 1024, 64, 64, True, False, True): (1, 16, 3, 4), + (6144, 6144, 1024, 128, 128, False, True, True): (1, 8, 1, 32), + (6144, 6144, 1024, 128, 128, True, False, True): (4, 8, 2, 32), + (6144, 6144, 2048, 16, 16, False, True, True): (1, 8, 3, 4), + (6144, 6144, 2048, 16, 16, 
True, False, True): (4, 8, 3, 4), + (6144, 6144, 2048, 32, 32, False, True, True): (1, 16, 3, 4), + (6144, 6144, 2048, 32, 32, True, False, True): (1, 16, 3, 4), + (6144, 6144, 2048, 64, 64, False, True, True): (1, 32, 3, 4), + (6144, 6144, 2048, 64, 64, True, False, True): (3, 32, 3, 4), + (6144, 6144, 2048, 128, 128, False, True, True): (1, 16, 1, 32), + (6144, 6144, 2048, 128, 128, True, False, True): (1, 16, 2, 32), + (6144, 6144, 4096, 16, 16, False, True, True): (3, 16, 3, 4), + (6144, 6144, 4096, 16, 16, True, False, True): (4, 16, 3, 4), + (6144, 6144, 4096, 32, 32, False, True, True): (1, 32, 3, 4), + (6144, 6144, 4096, 32, 32, True, False, True): (1, 32, 3, 4), + (6144, 6144, 4096, 64, 64, False, True, True): (1, 64, 3, 4), + (6144, 6144, 4096, 64, 64, True, False, True): (1, 64, 3, 4), + (6144, 6144, 4096, 128, 128, False, True, True): (1, 32, 1, 32), + (6144, 6144, 4096, 128, 128, True, False, True): (4, 32, 2, 32), + (6144, 6144, 8192, 16, 16, False, True, True): (1, 32, 3, 4), + (6144, 6144, 8192, 16, 16, True, False, True): (4, 32, 3, 4), + (6144, 6144, 8192, 32, 32, False, True, True): (1, 64, 3, 4), + (6144, 6144, 8192, 32, 32, True, False, True): (1, 64, 3, 4), + (6144, 6144, 8192, 64, 64, False, True, True): (1, 128, 3, 4), + (6144, 6144, 8192, 64, 64, True, False, True): (1, 128, 3, 4), + (6144, 6144, 8192, 128, 128, False, True, True): (1, 64, 1, 32), + (6144, 6144, 8192, 128, 128, True, False, True): (4, 64, 2, 32), + (6144, 6144, 16384, 16, 16, False, True, True): (1, 64, 3, 4), + (6144, 6144, 16384, 16, 16, True, False, True): (4, 64, 3, 4), + (6144, 6144, 16384, 32, 32, False, True, True): (1, 128, 3, 4), + (6144, 6144, 16384, 32, 32, True, False, True): (1, 128, 3, 4), + (6144, 6144, 16384, 64, 64, False, True, True): (1, 256, 3, 4), + (6144, 6144, 16384, 64, 64, True, False, True): (1, 256, 3, 4), + (6144, 6144, 16384, 128, 128, False, True, True): (1, 128, 1, 32), + (6144, 6144, 16384, 128, 128, True, False, True): (4, 128, 2, 32), + (6144, 6144, 32768, 16, 16, False, True, True): (1, 128, 3, 4), + (6144, 6144, 32768, 16, 16, True, False, True): (4, 128, 3, 4), + (6144, 6144, 32768, 32, 32, False, True, True): (1, 256, 3, 4), + (6144, 6144, 32768, 32, 32, True, False, True): (1, 256, 3, 4), + (6144, 6144, 32768, 64, 64, False, True, True): (1, 512, 3, 4), + (6144, 6144, 32768, 64, 64, True, False, True): (1, 512, 3, 4), + (6144, 6144, 32768, 128, 128, False, True, True): (1, 256, 1, 32), + (6144, 6144, 32768, 128, 128, True, False, True): (4, 256, 2, 32), + (6144, 6144, 65536, 16, 16, False, True, True): (1, 256, 3, 4), + (6144, 6144, 65536, 16, 16, True, False, True): (2, 256, 3, 4), + (6144, 6144, 65536, 32, 32, False, True, True): (1, 512, 3, 4), + (6144, 6144, 65536, 32, 32, True, False, True): (1, 512, 3, 4), + (6144, 6144, 65536, 64, 64, False, True, True): (1, 1024, 3, 4), + (6144, 6144, 65536, 64, 64, True, False, True): (1, 1024, 3, 4), + (6144, 6144, 65536, 128, 128, False, True, True): (1, 512, 1, 32), + (6144, 6144, 65536, 128, 128, True, False, True): (4, 512, 2, 32), + (6144, 6144, 131072, 16, 16, False, True, True): (1, 512, 3, 4), + (6144, 6144, 131072, 16, 16, True, False, True): (2, 512, 3, 4), + (6144, 6144, 131072, 32, 32, False, True, True): (1, 1024, 3, 4), + (6144, 6144, 131072, 32, 32, True, False, True): (1, 1024, 3, 4), + (6144, 6144, 131072, 64, 64, False, True, True): (1, 2048, 3, 4), + (6144, 6144, 131072, 64, 64, True, False, True): (1, 2048, 3, 4), + (6144, 6144, 131072, 128, 128, False, True, True): (1, 1024, 1, 32), + (6144, 
6144, 131072, 128, 128, True, False, True): (4, 1024, 2, 32), + (8192, 8192, 256, 16, 16, False, True, True): (2, 2, 4, 4), + (8192, 8192, 256, 16, 16, True, False, True): (1, 1, 3, 4), + (8192, 8192, 256, 32, 32, False, True, True): (2, 4, 3, 4), + (8192, 8192, 256, 32, 32, True, False, True): (2, 4, 3, 4), + (8192, 8192, 256, 64, 64, False, True, True): (4, 4, 2, 4), + (8192, 8192, 256, 64, 64, True, False, True): (4, 4, 2, 4), + (8192, 8192, 256, 128, 128, False, True, True): (1, 2, 1, 32), + (8192, 8192, 256, 128, 128, True, False, True): (4, 2, 1, 32), + (8192, 8192, 512, 16, 16, False, True, True): (1, 4, 3, 4), + (8192, 8192, 512, 16, 16, True, False, True): (3, 4, 3, 4), + (8192, 8192, 512, 32, 32, False, True, True): (1, 8, 3, 4), + (8192, 8192, 512, 32, 32, True, False, True): (6, 8, 3, 4), + (8192, 8192, 512, 64, 64, False, True, True): (4, 8, 2, 4), + (8192, 8192, 512, 64, 64, True, False, True): (4, 8, 2, 4), + (8192, 8192, 512, 128, 128, False, True, True): (4, 4, 1, 32), + (8192, 8192, 512, 128, 128, True, False, True): (4, 4, 1, 32), + (8192, 8192, 1024, 16, 16, False, True, True): (1, 4, 3, 4), + (8192, 8192, 1024, 16, 16, True, False, True): (1, 32, 3, 1), + (8192, 8192, 1024, 32, 32, False, True, True): (1, 16, 3, 4), + (8192, 8192, 1024, 32, 32, True, False, True): (1, 16, 3, 4), + (8192, 8192, 1024, 64, 64, False, True, True): (4, 16, 2, 4), + (8192, 8192, 1024, 64, 64, True, False, True): (4, 16, 2, 4), + (8192, 8192, 1024, 128, 128, False, True, True): (4, 8, 1, 32), + (8192, 8192, 1024, 128, 128, True, False, True): (4, 8, 1, 32), + (8192, 8192, 2048, 16, 16, False, True, True): (4, 8, 3, 4), + (8192, 8192, 2048, 16, 16, True, False, True): (1, 32, 3, 1), + (8192, 8192, 2048, 32, 32, False, True, True): (1, 32, 3, 4), + (8192, 8192, 2048, 32, 32, True, False, True): (1, 16, 4, 4), + (8192, 8192, 2048, 64, 64, False, True, True): (4, 32, 2, 4), + (8192, 8192, 2048, 64, 64, True, False, True): (4, 32, 2, 4), + (8192, 8192, 2048, 128, 128, False, True, True): (4, 16, 1, 32), + (8192, 8192, 2048, 128, 128, True, False, True): (4, 16, 1, 32), + (8192, 8192, 4096, 16, 16, False, True, True): (3, 16, 3, 4), + (8192, 8192, 4096, 16, 16, True, False, True): (2, 64, 3, 1), + (8192, 8192, 4096, 32, 32, False, True, True): (1, 64, 3, 4), + (8192, 8192, 4096, 32, 32, True, False, True): (1, 32, 3, 4), + (8192, 8192, 4096, 64, 64, False, True, True): (4, 64, 2, 4), + (8192, 8192, 4096, 64, 64, True, False, True): (2, 64, 2, 4), + (8192, 8192, 4096, 128, 128, False, True, True): (4, 32, 1, 32), + (8192, 8192, 4096, 128, 128, True, False, True): (4, 32, 1, 32), + (8192, 8192, 8192, 16, 16, False, True, True): (2, 128, 3, 1), + (8192, 8192, 8192, 16, 16, True, False, True): (2, 128, 3, 1), + (8192, 8192, 8192, 32, 32, False, True, True): (1, 128, 3, 4), + (8192, 8192, 8192, 32, 32, True, False, True): (1, 64, 3, 4), + (8192, 8192, 8192, 64, 64, False, True, True): (4, 128, 2, 4), + (8192, 8192, 8192, 64, 64, True, False, True): (2, 128, 2, 4), + (8192, 8192, 8192, 128, 128, False, True, True): (4, 64, 1, 32), + (8192, 8192, 8192, 128, 128, True, False, True): (4, 64, 1, 32), + (8192, 8192, 16384, 16, 16, False, True, True): (1, 64, 3, 4), + (8192, 8192, 16384, 16, 16, True, False, True): (1, 256, 3, 1), + (8192, 8192, 16384, 32, 32, False, True, True): (1, 256, 3, 4), + (8192, 8192, 16384, 32, 32, True, False, True): (1, 128, 3, 4), + (8192, 8192, 16384, 64, 64, False, True, True): (2, 256, 2, 4), + (8192, 8192, 16384, 64, 64, True, False, True): (2, 256, 2, 4), + (8192, 8192, 
16384, 128, 128, False, True, True): (4, 128, 1, 32), + (8192, 8192, 16384, 128, 128, True, False, True): (4, 128, 1, 32), + (8192, 8192, 32768, 16, 16, False, True, True): (1, 512, 3, 1), + (8192, 8192, 32768, 16, 16, True, False, True): (1, 512, 3, 1), + (8192, 8192, 32768, 32, 32, False, True, True): (1, 512, 3, 4), + (8192, 8192, 32768, 32, 32, True, False, True): (1, 256, 3, 4), + (8192, 8192, 32768, 64, 64, False, True, True): (2, 512, 2, 4), + (8192, 8192, 32768, 64, 64, True, False, True): (2, 512, 2, 4), + (8192, 8192, 32768, 128, 128, False, True, True): (4, 256, 1, 32), + (8192, 8192, 32768, 128, 128, True, False, True): (4, 256, 1, 32), + (8192, 8192, 65536, 16, 16, False, True, True): (1, 256, 3, 4), + (8192, 8192, 65536, 16, 16, True, False, True): (1, 1024, 3, 1), + (8192, 8192, 65536, 32, 32, False, True, True): (1, 1024, 3, 4), + (8192, 8192, 65536, 32, 32, True, False, True): (1, 512, 3, 4), + (8192, 8192, 65536, 64, 64, False, True, True): (4, 1024, 2, 4), + (8192, 8192, 65536, 64, 64, True, False, True): (2, 1024, 2, 4), + (8192, 8192, 65536, 128, 128, False, True, True): (4, 512, 1, 32), + (8192, 8192, 65536, 128, 128, True, False, True): (4, 512, 1, 32), + (8192, 8192, 131072, 16, 16, False, True, True): (1, 2048, 3, 1), + (8192, 8192, 131072, 16, 16, True, False, True): (2, 2048, 3, 1), + (8192, 8192, 131072, 32, 32, False, True, True): (4, 2048, 3, 4), + (8192, 8192, 131072, 32, 32, True, False, True): (1, 1024, 3, 4), + (8192, 8192, 131072, 64, 64, False, True, True): (2, 2048, 2, 4), + (8192, 8192, 131072, 64, 64, True, False, True): (2, 2048, 2, 4), + (8192, 8192, 131072, 128, 128, False, True, True): (4, 1024, 1, 32), + (8192, 8192, 131072, 128, 128, True, False, True): (4, 1024, 1, 32), + (16384, 16384, 256, 16, 16, False, True, True): (1, 2, 3, 4), + (16384, 16384, 256, 16, 16, True, False, True): (1, 2, 3, 4), + (16384, 16384, 256, 32, 32, False, True, True): (1, 4, 3, 4), + (16384, 16384, 256, 32, 32, True, False, True): (1, 4, 3, 4), + (16384, 16384, 256, 64, 64, False, True, True): (2, 4, 2, 4), + (16384, 16384, 256, 64, 64, True, False, True): (2, 4, 2, 4), + (16384, 16384, 256, 128, 128, False, True, True): (2, 2, 1, 32), + (16384, 16384, 256, 128, 128, True, False, True): (2, 2, 1, 32), + (16384, 16384, 512, 16, 16, False, True, True): (1, 2, 3, 4), + (16384, 16384, 512, 16, 16, True, False, True): (5, 2, 3, 4), + (16384, 16384, 512, 32, 32, False, True, True): (1, 8, 3, 4), + (16384, 16384, 512, 32, 32, True, False, True): (1, 4, 3, 4), + (16384, 16384, 512, 64, 64, False, True, True): (4, 8, 2, 4), + (16384, 16384, 512, 64, 64, True, False, True): (4, 8, 2, 4), + (16384, 16384, 512, 128, 128, False, True, True): (4, 4, 1, 32), + (16384, 16384, 512, 128, 128, True, False, True): (4, 4, 1, 32), + (16384, 16384, 1024, 16, 16, False, True, True): (1, 4, 3, 4), + (16384, 16384, 1024, 16, 16, True, False, True): (2, 16, 3, 1), + (16384, 16384, 1024, 32, 32, False, True, True): (1, 16, 3, 4), + (16384, 16384, 1024, 32, 32, True, False, True): (1, 8, 3, 4), + (16384, 16384, 1024, 64, 64, False, True, True): (4, 16, 2, 4), + (16384, 16384, 1024, 64, 64, True, False, True): (4, 16, 2, 4), + (16384, 16384, 1024, 128, 128, False, True, True): (4, 8, 1, 32), + (16384, 16384, 1024, 128, 128, True, False, True): (4, 8, 1, 32), + (16384, 16384, 2048, 16, 16, False, True, True): (1, 8, 3, 4), + (16384, 16384, 2048, 16, 16, True, False, True): (2, 32, 3, 1), + (16384, 16384, 2048, 32, 32, False, True, True): (1, 32, 3, 4), + (16384, 16384, 2048, 32, 32, True, False, 
True): (1, 16, 3, 4), + (16384, 16384, 2048, 64, 64, False, True, True): (4, 32, 2, 4), + (16384, 16384, 2048, 64, 64, True, False, True): (2, 32, 2, 4), + (16384, 16384, 2048, 128, 128, False, True, True): (4, 16, 1, 32), + (16384, 16384, 2048, 128, 128, True, False, True): (4, 16, 1, 32), + (16384, 16384, 4096, 16, 16, False, True, True): (1, 16, 3, 4), + (16384, 16384, 4096, 16, 16, True, False, True): (2, 64, 3, 1), + (16384, 16384, 4096, 32, 32, False, True, True): (1, 64, 3, 4), + (16384, 16384, 4096, 32, 32, True, False, True): (1, 32, 3, 4), + (16384, 16384, 4096, 64, 64, False, True, True): (4, 64, 2, 4), + (16384, 16384, 4096, 64, 64, True, False, True): (2, 64, 2, 4), + (16384, 16384, 4096, 128, 128, False, True, True): (4, 32, 1, 32), + (16384, 16384, 4096, 128, 128, True, False, True): (4, 32, 1, 32), + (16384, 16384, 8192, 16, 16, False, True, True): (1, 128, 3, 1), + (16384, 16384, 8192, 16, 16, True, False, True): (2, 128, 3, 1), + (16384, 16384, 8192, 32, 32, False, True, True): (1, 128, 3, 4), + (16384, 16384, 8192, 32, 32, True, False, True): (1, 64, 3, 4), + (16384, 16384, 8192, 64, 64, False, True, True): (2, 128, 2, 4), + (16384, 16384, 8192, 64, 64, True, False, True): (2, 128, 2, 4), + (16384, 16384, 8192, 128, 128, False, True, True): (4, 64, 1, 32), + (16384, 16384, 8192, 128, 128, True, False, True): (4, 64, 1, 32), + (16384, 16384, 16384, 16, 16, False, True, True): (1, 64, 3, 4), + (16384, 16384, 16384, 16, 16, True, False, True): (2, 256, 3, 1), + (16384, 16384, 16384, 32, 32, False, True, True): (1, 256, 3, 4), + (16384, 16384, 16384, 32, 32, True, False, True): (1, 128, 3, 4), + (16384, 16384, 16384, 64, 64, False, True, True): (2, 256, 2, 4), + (16384, 16384, 16384, 64, 64, True, False, True): (2, 256, 2, 4), + (16384, 16384, 16384, 128, 128, False, True, True): (4, 128, 1, 32), + (16384, 16384, 16384, 128, 128, True, False, True): (4, 128, 1, 32), + (16384, 16384, 32768, 16, 16, False, True, True): (1, 512, 3, 1), + (16384, 16384, 32768, 16, 16, True, False, True): (1, 128, 3, 4), + (16384, 16384, 32768, 32, 32, False, True, True): (2, 512, 3, 4), + (16384, 16384, 32768, 32, 32, True, False, True): (1, 256, 4, 4), + (16384, 16384, 32768, 64, 64, False, True, True): (2, 512, 2, 4), + (16384, 16384, 32768, 64, 64, True, False, True): (2, 512, 2, 4), + (16384, 16384, 32768, 128, 128, False, True, True): (4, 256, 1, 32), + (16384, 16384, 32768, 128, 128, True, False, True): (4, 256, 1, 32), + (16384, 16384, 65536, 16, 16, False, True, True): (1, 256, 3, 4), + (16384, 16384, 65536, 16, 16, True, False, True): (1, 1024, 3, 1), + (16384, 16384, 65536, 32, 32, False, True, True): (1, 1024, 3, 4), + (16384, 16384, 65536, 32, 32, True, False, True): (1, 512, 4, 4), + (16384, 16384, 65536, 64, 64, False, True, True): (2, 1024, 2, 4), + (16384, 16384, 65536, 64, 64, True, False, True): (2, 1024, 2, 4), + (16384, 16384, 65536, 128, 128, False, True, True): (4, 512, 1, 32), + (16384, 16384, 65536, 128, 128, True, False, True): (4, 512, 1, 32), + (16384, 16384, 131072, 16, 16, False, True, True): (1, 1024, 4, 4), + (16384, 16384, 131072, 16, 16, True, False, True): (2, 2048, 3, 1), + (16384, 16384, 131072, 32, 32, False, True, True): (1, 1024, 2, 4), + (16384, 16384, 131072, 32, 32, True, False, True): (1, 1024, 2, 4), + (16384, 16384, 131072, 64, 64, False, True, True): (4, 2048, 2, 4), + (16384, 16384, 131072, 64, 64, True, False, True): (2, 2048, 2, 4), + (16384, 16384, 131072, 128, 128, False, True, True): (4, 1024, 1, 32), + (16384, 16384, 131072, 128, 128, True, 
False, True): (4, 1024, 1, 32), + }, + ("bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.float32, 0.56)): { + (192, 192, 256, 64, 64, False, True, True): (1, 4, 3, 8), + (192, 192, 256, 64, 64, True, False, True): (1, 4, 3, 8), + (192, 192, 512, 64, 64, False, True, True): (2, 8, 3, 8), + (192, 192, 512, 64, 64, True, False, True): (5, 8, 3, 8), + (192, 192, 1024, 64, 64, False, True, True): (2, 16, 4, 8), + (192, 192, 1024, 64, 64, True, False, True): (1, 16, 3, 8), + (192, 192, 2048, 64, 64, False, True, True): (3, 32, 3, 8), + (192, 192, 2048, 64, 64, True, False, True): (5, 32, 5, 8), + (192, 192, 4096, 64, 64, False, True, True): (3, 64, 2, 8), + (192, 192, 4096, 64, 64, True, False, True): (1, 64, 3, 8), + (192, 192, 8192, 64, 64, False, True, True): (3, 128, 3, 8), + (192, 192, 8192, 64, 64, True, False, True): (6, 128, 3, 4), + (192, 192, 16384, 64, 64, False, True, True): (1, 256, 1, 8), + (192, 192, 16384, 64, 64, True, False, True): (1, 256, 3, 4), + (192, 192, 32768, 64, 64, False, True, True): (1, 512, 1, 8), + (192, 192, 32768, 64, 64, True, False, True): (1, 512, 3, 4), + (192, 192, 65536, 64, 64, False, True, True): (1, 1024, 1, 8), + (192, 192, 65536, 64, 64, True, False, True): (1, 1024, 3, 4), + (192, 192, 131072, 64, 64, False, True, True): (1, 2048, 1, 8), + (192, 192, 131072, 64, 64, True, False, True): (3, 2048, 1, 4), + (384, 384, 256, 128, 128, False, True, True): (1, 2, 1, 32), + (384, 384, 256, 128, 128, True, False, True): (1, 2, 1, 32), + (384, 384, 512, 128, 128, False, True, True): (1, 4, 1, 32), + (384, 384, 512, 128, 128, True, False, True): (2, 4, 1, 32), + (384, 384, 1024, 128, 128, False, True, True): (1, 8, 1, 32), + (384, 384, 1024, 128, 128, True, False, True): (4, 8, 1, 32), + (384, 384, 2048, 128, 128, False, True, True): (1, 16, 1, 32), + (384, 384, 2048, 128, 128, True, False, True): (1, 16, 1, 32), + (384, 384, 4096, 128, 128, False, True, True): (1, 32, 1, 32), + (384, 384, 4096, 128, 128, True, False, True): (2, 32, 2, 32), + (384, 384, 8192, 128, 128, False, True, True): (1, 64, 1, 32), + (384, 384, 8192, 128, 128, True, False, True): (1, 64, 2, 32), + (384, 384, 16384, 128, 128, False, True, True): (1, 128, 1, 32), + (384, 384, 16384, 128, 128, True, False, True): (4, 128, 1, 32), + (384, 384, 32768, 128, 128, False, True, True): (3, 256, 1, 32), + (384, 384, 32768, 128, 128, True, False, True): (3, 256, 1, 32), + (384, 384, 65536, 128, 128, False, True, True): (3, 512, 1, 32), + (384, 384, 65536, 128, 128, True, False, True): (3, 512, 1, 32), + (384, 384, 131072, 128, 128, False, True, True): (1, 1024, 1, 32), + (384, 384, 131072, 128, 128, True, False, True): (3, 1024, 1, 32), + }, + ("scatter_mm", "NVIDIA A100-SXM4-80GB", (0, torch.bfloat16, 0.5)): { + (256, 256, 256, 16, 16): (1, 1, 16, 16, 1, 2), + (256, 256, 256, 32, 32): (1, 1, 16, 16, 1, 4), + (256, 256, 256, 64, 64): (1, 1, 16, 16, 1, 1), + (256, 256, 256, 128, 128): (2, 4, 16, 64, 1, 4), + (256, 256, 512, 16, 16): (1, 1, 16, 16, 1, 4), + (256, 256, 512, 32, 32): (1, 1, 16, 32, 1, 4), + (256, 256, 512, 64, 64): (1, 1, 16, 32, 1, 1), + (256, 256, 512, 128, 128): (1, 1, 32, 32, 1, 4), + (256, 256, 1024, 16, 16): (1, 1, 16, 16, 1, 4), + (256, 256, 1024, 32, 32): (1, 2, 16, 32, 1, 1), + (256, 256, 1024, 64, 64): (1, 1, 32, 32, 1, 2), + (256, 256, 1024, 128, 128): (1, 1, 32, 64, 1, 4), + (256, 256, 2048, 16, 16): (1, 1, 16, 64, 1, 8), + (256, 256, 2048, 32, 32): (2, 1, 32, 64, 1, 2), + (256, 256, 2048, 64, 64): (1, 1, 32, 32, 1, 1), + (256, 256, 2048, 128, 128): (1, 1, 64, 64, 1, 4), 
+ (256, 256, 4096, 16, 16): (1, 1, 16, 64, 1, 1), + (256, 256, 4096, 32, 32): (2, 2, 32, 64, 1, 2), + (256, 256, 4096, 64, 64): (1, 1, 32, 128, 1, 4), + (256, 256, 4096, 128, 128): (1, 1, 64, 64, 1, 4), + (256, 256, 8192, 16, 16): (1, 2, 16, 64, 1, 2), + (256, 256, 8192, 32, 32): (1, 1, 32, 64, 1, 2), + (256, 256, 8192, 64, 64): (1, 1, 32, 64, 1, 2), + (256, 256, 8192, 128, 128): (1, 1, 64, 64, 1, 4), + (256, 256, 16384, 16, 16): (1, 1, 16, 64, 1, 2), + (256, 256, 16384, 32, 32): (1, 1, 32, 64, 1, 2), + (256, 256, 16384, 64, 64): (1, 1, 64, 64, 1, 2), + (256, 256, 16384, 128, 128): (2, 16, 64, 64, 1, 4), + (256, 256, 32768, 16, 16): (1, 1, 16, 128, 1, 2), + (256, 256, 32768, 32, 32): (1, 1, 32, 64, 1, 2), + (256, 256, 32768, 64, 64): (1, 1, 64, 64, 1, 2), + (256, 256, 32768, 128, 128): (2, 32, 64, 64, 1, 4), + (256, 256, 65536, 16, 16): (1, 1, 16, 64, 1, 1), + (256, 256, 65536, 32, 32): (1, 1, 32, 64, 1, 2), + (256, 256, 65536, 64, 64): (1, 1, 64, 32, 1, 1), + (256, 256, 65536, 128, 128): (2, 32, 64, 64, 1, 4), + (256, 256, 131072, 16, 16): (1, 1, 16, 64, 1, 1), + (256, 256, 131072, 32, 32): (1, 1, 32, 64, 1, 2), + (256, 256, 131072, 64, 64): (4, 1, 64, 32, 1, 1), + (256, 256, 131072, 128, 128): (2, 64, 64, 64, 1, 4), + (512, 512, 256, 16, 16): (1, 1, 16, 16, 1, 2), + (512, 512, 256, 32, 32): (1, 1, 16, 32, 1, 1), + (512, 512, 256, 64, 64): (1, 2, 16, 32, 1, 1), + (512, 512, 256, 128, 128): (2, 16, 64, 16, 2, 4), + (512, 512, 512, 16, 16): (1, 1, 16, 16, 1, 4), + (512, 512, 512, 32, 32): (1, 1, 16, 32, 1, 1), + (512, 512, 512, 64, 64): (1, 1, 32, 32, 1, 2), + (512, 512, 512, 128, 128): (2, 8, 32, 64, 1, 4), + (512, 512, 1024, 16, 16): (1, 1, 16, 64, 1, 8), + (512, 512, 1024, 32, 32): (1, 1, 32, 32, 3, 1), + (512, 512, 1024, 64, 64): (1, 4, 32, 64, 1, 2), + (512, 512, 1024, 128, 128): (1, 4, 64, 64, 1, 4), + (512, 512, 2048, 16, 16): (1, 1, 16, 64, 1, 2), + (512, 512, 2048, 32, 32): (1, 1, 32, 64, 1, 2), + (512, 512, 2048, 64, 64): (1, 1, 64, 64, 3, 4), + (512, 512, 2048, 128, 128): (1, 1, 64, 64, 1, 4), + (512, 512, 4096, 16, 16): (1, 1, 16, 64, 1, 2), + (512, 512, 4096, 32, 32): (2, 64, 32, 64, 1, 2), + (512, 512, 4096, 64, 64): (1, 1, 64, 64, 3, 4), + (512, 512, 4096, 128, 128): (1, 1, 64, 64, 1, 4), + (512, 512, 8192, 16, 16): (1, 2, 16, 128, 1, 2), + (512, 512, 8192, 32, 32): (1, 1, 32, 64, 1, 2), + (512, 512, 8192, 64, 64): (1, 1, 64, 64, 1, 2), + (512, 512, 8192, 128, 128): (1, 1, 64, 64, 1, 4), + (512, 512, 16384, 16, 16): (1, 2, 16, 128, 1, 2), + (512, 512, 16384, 32, 32): (1, 1, 32, 64, 1, 2), + (512, 512, 16384, 64, 64): (1, 1, 64, 64, 3, 2), + (512, 512, 16384, 128, 128): (2, 1, 64, 64, 1, 4), + (512, 512, 32768, 16, 16): (1, 2, 16, 128, 1, 2), + (512, 512, 32768, 32, 32): (1, 1, 32, 64, 1, 2), + (512, 512, 32768, 64, 64): (1, 1, 64, 64, 3, 4), + (512, 512, 32768, 128, 128): (2, 1, 64, 64, 1, 4), + (512, 512, 65536, 16, 16): (1, 2, 16, 128, 1, 2), + (512, 512, 65536, 32, 32): (1, 1, 32, 64, 1, 2), + (512, 512, 65536, 64, 64): (1, 1, 64, 64, 3, 4), + (512, 512, 65536, 128, 128): (2, 1, 64, 64, 1, 4), + (512, 512, 131072, 16, 16): (1, 1, 16, 64, 1, 1), + (512, 512, 131072, 32, 32): (1, 1, 32, 64, 1, 2), + (512, 512, 131072, 64, 64): (1, 1, 64, 64, 3, 4), + (512, 512, 131072, 128, 128): (2, 4, 64, 64, 1, 4), + (1024, 1024, 256, 16, 16): (1, 1, 16, 16, 1, 4), + (1024, 1024, 256, 32, 32): (2, 16, 32, 16, 3, 4), + (1024, 1024, 256, 64, 64): (1, 4, 32, 32, 1, 2), + (1024, 1024, 256, 128, 128): (1, 4, 128, 16, 3, 16), + (1024, 1024, 512, 16, 16): (1, 1, 16, 64, 1, 2), + (1024, 1024, 
512, 32, 32): (2, 2, 32, 64, 1, 2), + (1024, 1024, 512, 64, 64): (2, 8, 64, 64, 3, 4), + (1024, 1024, 512, 128, 128): (1, 4, 64, 64, 1, 8), + (1024, 1024, 1024, 16, 16): (1, 1, 16, 64, 1, 2), + (1024, 1024, 1024, 32, 32): (1, 1, 32, 64, 1, 2), + (1024, 1024, 1024, 64, 64): (1, 8, 64, 64, 3, 4), + (1024, 1024, 1024, 128, 128): (1, 8, 64, 64, 1, 4), + (1024, 1024, 2048, 16, 16): (1, 2, 16, 64, 1, 2), + (1024, 1024, 2048, 32, 32): (1, 1, 32, 64, 1, 2), + (1024, 1024, 2048, 64, 64): (2, 16, 64, 64, 2, 2), + (1024, 1024, 2048, 128, 128): (2, 32, 64, 64, 1, 4), + (1024, 1024, 4096, 16, 16): (2, 16, 16, 128, 1, 2), + (1024, 1024, 4096, 32, 32): (1, 16, 32, 64, 3, 2), + (1024, 1024, 4096, 64, 64): (1, 1, 64, 64, 3, 4), + (1024, 1024, 4096, 128, 128): (2, 64, 128, 64, 1, 4), + (1024, 1024, 8192, 16, 16): (2, 16, 16, 128, 1, 2), + (1024, 1024, 8192, 32, 32): (1, 16, 32, 64, 3, 2), + (1024, 1024, 8192, 64, 64): (1, 1, 64, 64, 3, 4), + (1024, 1024, 8192, 128, 128): (2, 1, 64, 64, 1, 4), + (1024, 1024, 16384, 16, 16): (1, 2, 16, 128, 1, 2), + (1024, 1024, 16384, 32, 32): (1, 16, 32, 64, 3, 2), + (1024, 1024, 16384, 64, 64): (1, 1, 64, 64, 3, 4), + (1024, 1024, 16384, 128, 128): (2, 16, 128, 64, 1, 4), + (1024, 1024, 32768, 16, 16): (1, 1, 16, 128, 1, 2), + (1024, 1024, 32768, 32, 32): (1, 1, 32, 128, 1, 2), + (1024, 1024, 32768, 64, 64): (1, 32, 64, 32, 2, 1), + (1024, 1024, 32768, 128, 128): (2, 8, 128, 64, 1, 4), + (1024, 1024, 65536, 16, 16): (3, 2, 16, 128, 1, 2), + (1024, 1024, 65536, 32, 32): (1, 1, 32, 128, 1, 2), + (1024, 1024, 65536, 64, 64): (2, 4, 64, 32, 2, 1), + (1024, 1024, 65536, 128, 128): (2, 8, 128, 64, 1, 4), + (1024, 1024, 131072, 16, 16): (2, 1, 16, 128, 1, 2), + (1024, 1024, 131072, 32, 32): (1, 1, 32, 128, 1, 2), + (1024, 1024, 131072, 64, 64): (1, 4, 64, 32, 2, 1), + (1024, 1024, 131072, 128, 128): (4, 1, 128, 64, 1, 4), + (2048, 2048, 256, 16, 16): (1, 1, 16, 64, 1, 8), + (2048, 2048, 256, 32, 32): (1, 1, 32, 32, 3, 1), + (2048, 2048, 256, 64, 64): (1, 1, 32, 32, 2, 1), + (2048, 2048, 256, 128, 128): (1, 4, 64, 64, 1, 8), + (2048, 2048, 512, 16, 16): (1, 2, 16, 64, 1, 2), + (2048, 2048, 512, 32, 32): (1, 2, 32, 64, 1, 4), + (2048, 2048, 512, 64, 64): (1, 4, 64, 64, 1, 8), + (2048, 2048, 512, 128, 128): (1, 4, 64, 64, 1, 4), + (2048, 2048, 1024, 16, 16): (1, 2, 16, 128, 1, 2), + (2048, 2048, 1024, 32, 32): (1, 1, 32, 64, 1, 2), + (2048, 2048, 1024, 64, 64): (1, 8, 64, 64, 1, 4), + (2048, 2048, 1024, 128, 128): (1, 8, 128, 64, 1, 4), + (2048, 2048, 2048, 16, 16): (3, 4, 16, 128, 1, 2), + (2048, 2048, 2048, 32, 32): (1, 16, 32, 64, 5, 2), + (2048, 2048, 2048, 64, 64): (1, 1, 64, 64, 3, 4), + (2048, 2048, 2048, 128, 128): (1, 8, 128, 64, 1, 4), + (2048, 2048, 4096, 16, 16): (1, 2, 16, 128, 1, 2), + (2048, 2048, 4096, 32, 32): (1, 8, 32, 64, 3, 2), + (2048, 2048, 4096, 64, 64): (1, 1, 64, 64, 3, 4), + (2048, 2048, 4096, 128, 128): (1, 8, 128, 64, 1, 4), + (2048, 2048, 8192, 16, 16): (2, 4, 16, 128, 1, 2), + (2048, 2048, 8192, 32, 32): (1, 4, 32, 128, 3, 2), + (2048, 2048, 8192, 64, 64): (1, 8, 64, 64, 3, 2), + (2048, 2048, 8192, 128, 128): (1, 8, 128, 64, 1, 4), + (2048, 2048, 16384, 16, 16): (1, 2, 16, 128, 1, 2), + (2048, 2048, 16384, 32, 32): (1, 4, 32, 128, 3, 2), + (2048, 2048, 16384, 64, 64): (1, 8, 64, 64, 3, 2), + (2048, 2048, 16384, 128, 128): (1, 4, 128, 64, 1, 4), + (2048, 2048, 32768, 16, 16): (3, 2, 16, 128, 1, 2), + (2048, 2048, 32768, 32, 32): (1, 1, 32, 128, 3, 2), + (2048, 2048, 32768, 64, 64): (1, 1, 64, 64, 3, 2), + (2048, 2048, 32768, 128, 128): (1, 4, 128, 64, 
1, 4), + (2048, 2048, 65536, 16, 16): (1, 2, 16, 128, 1, 2), + (2048, 2048, 65536, 32, 32): (1, 4, 32, 128, 1, 2), + (2048, 2048, 65536, 64, 64): (1, 1, 64, 64, 3, 2), + (2048, 2048, 65536, 128, 128): (1, 2, 128, 64, 1, 4), + (2048, 2048, 131072, 16, 16): (4, 2, 16, 128, 1, 2), + (2048, 2048, 131072, 32, 32): (1, 1, 32, 128, 3, 2), + (2048, 2048, 131072, 64, 64): (1, 1, 64, 64, 3, 2), + (2048, 2048, 131072, 128, 128): (1, 2, 128, 64, 1, 4), + (4096, 4096, 256, 16, 16): (1, 1, 16, 64, 1, 2), + (4096, 4096, 256, 32, 32): (1, 1, 32, 64, 3, 4), + (4096, 4096, 256, 64, 64): (1, 1, 64, 64, 3, 4), + (4096, 4096, 256, 128, 128): (3, 4, 128, 32, 1, 4), + (4096, 4096, 512, 16, 16): (1, 2, 16, 128, 1, 2), + (4096, 4096, 512, 32, 32): (1, 2, 32, 64, 3, 2), + (4096, 4096, 512, 64, 64): (1, 4, 64, 64, 1, 4), + (4096, 4096, 512, 128, 128): (1, 4, 128, 64, 1, 4), + (4096, 4096, 1024, 16, 16): (1, 2, 16, 128, 1, 2), + (4096, 4096, 1024, 32, 32): (1, 8, 32, 64, 3, 2), + (4096, 4096, 1024, 64, 64): (1, 4, 64, 64, 1, 4), + (4096, 4096, 1024, 128, 128): (2, 4, 128, 64, 1, 4), + (4096, 4096, 2048, 16, 16): (1, 1, 16, 128, 1, 2), + (4096, 4096, 2048, 32, 32): (1, 4, 32, 128, 1, 4), + (4096, 4096, 2048, 64, 64): (1, 1, 64, 64, 3, 4), + (4096, 4096, 2048, 128, 128): (1, 16, 128, 64, 1, 4), + (4096, 4096, 4096, 16, 16): (1, 1, 16, 64, 3, 1), + (4096, 4096, 4096, 32, 32): (1, 4, 32, 64, 3, 2), + (4096, 4096, 4096, 64, 64): (1, 1, 64, 64, 3, 4), + (4096, 4096, 4096, 128, 128): (5, 1, 128, 64, 1, 4), + (4096, 4096, 8192, 16, 16): (1, 1, 16, 128, 1, 2), + (4096, 4096, 8192, 32, 32): (1, 1, 32, 128, 3, 2), + (4096, 4096, 8192, 64, 64): (1, 1, 64, 64, 3, 4), + (4096, 4096, 8192, 128, 128): (2, 1, 128, 64, 1, 4), + (4096, 4096, 16384, 16, 16): (1, 1, 16, 128, 1, 2), + (4096, 4096, 16384, 32, 32): (1, 1, 32, 128, 3, 2), + (4096, 4096, 16384, 64, 64): (1, 1, 64, 64, 4, 4), + (4096, 4096, 16384, 128, 128): (2, 1, 128, 64, 1, 4), + (4096, 4096, 32768, 16, 16): (3, 1, 16, 128, 1, 2), + (4096, 4096, 32768, 32, 32): (1, 1, 32, 128, 3, 2), + (4096, 4096, 32768, 64, 64): (1, 1, 64, 64, 3, 4), + (4096, 4096, 32768, 128, 128): (2, 1, 128, 64, 1, 4), + (4096, 4096, 65536, 16, 16): (2, 2, 16, 128, 1, 2), + (4096, 4096, 65536, 32, 32): (1, 1, 32, 128, 4, 2), + (4096, 4096, 65536, 64, 64): (1, 1, 64, 64, 4, 4), + (4096, 4096, 65536, 128, 128): (2, 1, 128, 64, 1, 4), + (4096, 4096, 131072, 16, 16): (2, 1, 16, 128, 1, 2), + (4096, 4096, 131072, 32, 32): (1, 1, 32, 128, 3, 2), + (4096, 4096, 131072, 64, 64): (1, 1, 64, 64, 3, 4), + (4096, 4096, 131072, 128, 128): (2, 1, 128, 64, 1, 4), + (8192, 8192, 256, 16, 16): (1, 2, 16, 64, 1, 2), + (8192, 8192, 256, 32, 32): (1, 1, 32, 64, 1, 2), + (8192, 8192, 256, 64, 64): (1, 2, 64, 64, 1, 4), + (8192, 8192, 256, 128, 128): (3, 16, 128, 16, 1, 2), + (8192, 8192, 512, 16, 16): (1, 2, 16, 128, 1, 2), + (8192, 8192, 512, 32, 32): (1, 4, 32, 64, 3, 2), + (8192, 8192, 512, 64, 64): (2, 8, 64, 64, 4, 4), + (8192, 8192, 512, 128, 128): (1, 8, 128, 64, 1, 4), + (8192, 8192, 1024, 16, 16): (4, 2, 16, 128, 1, 2), + (8192, 8192, 1024, 32, 32): (1, 8, 32, 128, 1, 2), + (8192, 8192, 1024, 64, 64): (1, 16, 64, 64, 3, 2), + (8192, 8192, 1024, 128, 128): (2, 16, 128, 64, 2, 4), + (8192, 8192, 2048, 16, 16): (2, 1, 16, 64, 4, 1), + (8192, 8192, 2048, 32, 32): (1, 16, 32, 64, 5, 2), + (8192, 8192, 2048, 64, 64): (1, 16, 64, 64, 3, 2), + (8192, 8192, 2048, 128, 128): (2, 16, 128, 64, 2, 4), + (8192, 8192, 4096, 16, 16): (1, 1, 16, 64, 4, 1), + (8192, 8192, 4096, 32, 32): (1, 16, 32, 64, 5, 2), + (8192, 8192, 4096, 
64, 64): (1, 16, 64, 64, 3, 2), + (8192, 8192, 4096, 128, 128): (2, 64, 128, 64, 2, 4), + (8192, 8192, 8192, 16, 16): (1, 1, 16, 64, 4, 1), + (8192, 8192, 8192, 32, 32): (1, 8, 32, 128, 5, 4), + (8192, 8192, 8192, 64, 64): (1, 8, 64, 64, 3, 2), + (8192, 8192, 8192, 128, 128): (2, 8, 128, 64, 1, 4), + (8192, 8192, 16384, 16, 16): (1, 1, 16, 64, 4, 1), + (8192, 8192, 16384, 32, 32): (1, 8, 32, 64, 5, 2), + (8192, 8192, 16384, 64, 64): (1, 8, 64, 64, 3, 2), + (8192, 8192, 16384, 128, 128): (1, 8, 128, 64, 1, 4), + (8192, 8192, 32768, 16, 16): (1, 1, 16, 64, 4, 1), + (8192, 8192, 32768, 32, 32): (1, 8, 32, 64, 5, 2), + (8192, 8192, 32768, 64, 64): (3, 8, 64, 64, 3, 2), + (8192, 8192, 32768, 128, 128): (2, 8, 128, 64, 1, 4), + (8192, 8192, 65536, 16, 16): (1, 1, 16, 64, 4, 1), + (8192, 8192, 65536, 32, 32): (5, 4, 32, 64, 3, 2), + (8192, 8192, 65536, 64, 64): (1, 8, 64, 64, 3, 2), + (8192, 8192, 65536, 128, 128): (2, 8, 128, 64, 1, 4), + (8192, 8192, 131072, 16, 16): (2, 1, 16, 64, 4, 1), + (8192, 8192, 131072, 32, 32): (1, 4, 32, 64, 5, 2), + (8192, 8192, 131072, 64, 64): (1, 4, 64, 128, 3, 4), + (8192, 8192, 131072, 128, 128): (2, 8, 128, 64, 1, 4), + (16384, 16384, 256, 16, 16): (1, 2, 16, 128, 1, 2), + (16384, 16384, 256, 32, 32): (1, 4, 32, 64, 3, 2), + (16384, 16384, 256, 64, 64): (2, 4, 64, 64, 4, 4), + (16384, 16384, 256, 128, 128): (1, 4, 128, 64, 1, 16), + (16384, 16384, 512, 16, 16): (1, 2, 16, 128, 3, 2), + (16384, 16384, 512, 32, 32): (1, 4, 32, 128, 5, 4), + (16384, 16384, 512, 64, 64): (1, 8, 64, 64, 3, 2), + (16384, 16384, 512, 128, 128): (2, 8, 128, 64, 1, 4), + (16384, 16384, 1024, 16, 16): (1, 2, 16, 128, 1, 2), + (16384, 16384, 1024, 32, 32): (1, 8, 32, 64, 5, 2), + (16384, 16384, 1024, 64, 64): (1, 16, 64, 64, 3, 2), + (16384, 16384, 1024, 128, 128): (5, 16, 128, 64, 2, 4), + (16384, 16384, 2048, 16, 16): (1, 2, 16, 128, 1, 2), + (16384, 16384, 2048, 32, 32): (1, 8, 32, 64, 5, 2), + (16384, 16384, 2048, 64, 64): (1, 16, 64, 64, 3, 2), + (16384, 16384, 2048, 128, 128): (4, 32, 128, 64, 2, 4), + (16384, 16384, 4096, 16, 16): (3, 2, 16, 128, 1, 2), + (16384, 16384, 4096, 32, 32): (1, 4, 32, 64, 5, 2), + (16384, 16384, 4096, 64, 64): (2, 16, 64, 64, 3, 2), + (16384, 16384, 4096, 128, 128): (3, 32, 128, 64, 2, 4), + (16384, 16384, 8192, 16, 16): (1, 2, 16, 128, 1, 2), + (16384, 16384, 8192, 32, 32): (1, 4, 32, 64, 5, 2), + (16384, 16384, 8192, 64, 64): (4, 8, 64, 64, 3, 2), + (16384, 16384, 8192, 128, 128): (5, 8, 128, 64, 1, 4), + (16384, 16384, 16384, 16, 16): (1, 2, 16, 128, 1, 2), + (16384, 16384, 16384, 32, 32): (1, 4, 32, 64, 5, 2), + (16384, 16384, 16384, 64, 64): (2, 4, 64, 128, 3, 4), + (16384, 16384, 16384, 128, 128): (4, 8, 128, 64, 1, 4), + (16384, 16384, 32768, 16, 16): (4, 2, 16, 128, 1, 2), + (16384, 16384, 32768, 32, 32): (1, 4, 32, 64, 5, 2), + (16384, 16384, 32768, 64, 64): (1, 8, 64, 64, 3, 2), + (16384, 16384, 32768, 128, 128): (2, 512, 128, 64, 2, 4), + (16384, 16384, 65536, 16, 16): (3, 2, 16, 128, 1, 2), + (16384, 16384, 65536, 32, 32): (1, 4, 32, 64, 5, 2), + (16384, 16384, 65536, 64, 64): (1, 4, 64, 128, 3, 4), + (16384, 16384, 65536, 128, 128): (2, 1024, 128, 64, 2, 4), + (16384, 16384, 131072, 16, 16): (1, 2, 16, 128, 1, 2), + (16384, 16384, 131072, 32, 32): (1, 4, 32, 64, 5, 2), + (16384, 16384, 131072, 64, 64): (3, 4, 64, 128, 3, 4), + (16384, 16384, 131072, 128, 128): (4, 2048, 128, 64, 2, 4), + }, + ("scatter_mm", "NVIDIA A100-SXM4-80GB", (0, torch.float16, 0.5)): { + (256, 256, 256, 16, 16): (5, 4, 16, 16, 1, 4), + (256, 256, 256, 32, 32): (5, 2, 
32, 16, 1, 4), + (256, 256, 256, 64, 64): (4, 1, 32, 32, 1, 8), + (256, 256, 256, 128, 128): (2, 1, 32, 32, 1, 4), + (256, 256, 512, 16, 16): (2, 2, 16, 32, 1, 4), + (256, 256, 512, 32, 32): (4, 8, 32, 32, 1, 8), + (256, 256, 512, 64, 64): (4, 8, 32, 64, 1, 4), + (256, 256, 512, 128, 128): (4, 8, 32, 64, 1, 4), + (256, 256, 1024, 16, 16): (4, 2, 16, 64, 1, 2), + (256, 256, 1024, 32, 32): (4, 16, 32, 64, 1, 2), + (256, 256, 1024, 64, 64): (4, 16, 32, 64, 1, 4), + (256, 256, 1024, 128, 128): (4, 16, 64, 64, 1, 8), + (256, 256, 2048, 16, 16): (2, 16, 16, 64, 1, 8), + (256, 256, 2048, 32, 32): (4, 16, 32, 64, 1, 2), + (256, 256, 2048, 64, 64): (4, 16, 32, 64, 1, 4), + (256, 256, 2048, 128, 128): (4, 16, 64, 64, 1, 4), + (256, 256, 4096, 16, 16): (4, 32, 16, 64, 1, 1), + (256, 256, 4096, 32, 32): (2, 64, 32, 64, 1, 2), + (256, 256, 4096, 64, 64): (4, 64, 64, 64, 1, 4), + (256, 256, 4096, 128, 128): (4, 32, 64, 64, 1, 4), + (256, 256, 8192, 16, 16): (4, 64, 16, 64, 1, 1), + (256, 256, 8192, 32, 32): (4, 128, 32, 64, 1, 2), + (256, 256, 8192, 64, 64): (4, 64, 64, 64, 1, 4), + (256, 256, 8192, 128, 128): (4, 64, 64, 64, 1, 4), + (256, 256, 16384, 16, 16): (4, 128, 16, 64, 1, 1), + (256, 256, 16384, 32, 32): (2, 128, 32, 64, 1, 2), + (256, 256, 16384, 64, 64): (4, 32, 32, 128, 1, 4), + (256, 256, 16384, 128, 128): (4, 16, 64, 64, 1, 4), + (256, 256, 32768, 16, 16): (4, 64, 16, 64, 1, 1), + (256, 256, 32768, 32, 32): (2, 256, 32, 64, 1, 2), + (256, 256, 32768, 64, 64): (4, 32, 32, 128, 1, 4), + (256, 256, 32768, 128, 128): (4, 32, 64, 64, 1, 4), + (256, 256, 65536, 16, 16): (4, 128, 16, 64, 1, 1), + (256, 256, 65536, 32, 32): (4, 1, 32, 64, 1, 2), + (256, 256, 65536, 64, 64): (2, 1, 64, 64, 1, 2), + (256, 256, 65536, 128, 128): (4, 32, 64, 64, 1, 4), + (256, 256, 131072, 16, 16): (4, 64, 16, 64, 1, 1), + (256, 256, 131072, 32, 32): (2, 1, 32, 64, 1, 2), + (256, 256, 131072, 64, 64): (4, 32, 32, 128, 1, 4), + (256, 256, 131072, 128, 128): (4, 32, 64, 64, 1, 4), + (512, 512, 256, 16, 16): (4, 16, 16, 16, 1, 4), + (512, 512, 256, 32, 32): (2, 4, 32, 16, 1, 4), + (512, 512, 256, 64, 64): (2, 16, 64, 16, 3, 8), + (512, 512, 256, 128, 128): (4, 16, 64, 16, 1, 4), + (512, 512, 512, 16, 16): (1, 1, 16, 64, 1, 8), + (512, 512, 512, 32, 32): (2, 4, 16, 32, 1, 1), + (512, 512, 512, 64, 64): (2, 1, 32, 32, 1, 2), + (512, 512, 512, 128, 128): (4, 8, 32, 64, 1, 4), + (512, 512, 1024, 16, 16): (2, 8, 16, 64, 1, 8), + (512, 512, 1024, 32, 32): (4, 16, 32, 64, 1, 2), + (512, 512, 1024, 64, 64): (4, 16, 64, 64, 1, 4), + (512, 512, 1024, 128, 128): (2, 8, 64, 64, 1, 4), + (512, 512, 2048, 16, 16): (4, 16, 16, 64, 1, 4), + (512, 512, 2048, 32, 32): (4, 16, 32, 64, 1, 2), + (512, 512, 2048, 64, 64): (4, 16, 64, 64, 1, 8), + (512, 512, 2048, 128, 128): (4, 16, 64, 64, 1, 4), + (512, 512, 4096, 16, 16): (4, 32, 16, 128, 1, 2), + (512, 512, 4096, 32, 32): (4, 32, 32, 64, 1, 2), + (512, 512, 4096, 64, 64): (4, 32, 64, 64, 1, 4), + (512, 512, 4096, 128, 128): (4, 32, 64, 64, 1, 4), + (512, 512, 8192, 16, 16): (2, 32, 16, 128, 1, 2), + (512, 512, 8192, 32, 32): (4, 64, 32, 64, 1, 2), + (512, 512, 8192, 64, 64): (4, 128, 64, 64, 1, 2), + (512, 512, 8192, 128, 128): (4, 64, 64, 64, 1, 4), + (512, 512, 16384, 16, 16): (4, 32, 16, 64, 1, 1), + (512, 512, 16384, 32, 32): (4, 64, 32, 64, 1, 2), + (512, 512, 16384, 64, 64): (4, 16, 64, 64, 1, 4), + (512, 512, 16384, 128, 128): (4, 32, 64, 64, 1, 4), + (512, 512, 32768, 16, 16): (7, 16, 16, 128, 1, 2), + (512, 512, 32768, 32, 32): (4, 64, 32, 64, 1, 2), + (512, 512, 32768, 64, 64): 
(2, 32, 64, 64, 3, 2), + (512, 512, 32768, 128, 128): (2, 32, 64, 64, 1, 4), + (512, 512, 65536, 16, 16): (2, 32, 16, 64, 1, 1), + (512, 512, 65536, 32, 32): (4, 64, 32, 64, 1, 2), + (512, 512, 65536, 64, 64): (3, 32, 64, 64, 3, 2), + (512, 512, 65536, 128, 128): (4, 16, 64, 64, 1, 4), + (512, 512, 131072, 16, 16): (3, 32, 16, 128, 1, 2), + (512, 512, 131072, 32, 32): (4, 64, 32, 64, 1, 2), + (512, 512, 131072, 64, 64): (2, 32, 64, 64, 3, 2), + (512, 512, 131072, 128, 128): (3, 1, 64, 64, 1, 4), + (1024, 1024, 256, 16, 16): (4, 16, 16, 16, 1, 4), + (1024, 1024, 256, 32, 32): (4, 16, 32, 16, 1, 4), + (1024, 1024, 256, 64, 64): (4, 4, 64, 32, 1, 16), + (1024, 1024, 256, 128, 128): (4, 16, 64, 16, 1, 8), + (1024, 1024, 512, 16, 16): (2, 8, 16, 64, 1, 8), + (1024, 1024, 512, 32, 32): (3, 2, 32, 64, 1, 2), + (1024, 1024, 512, 64, 64): (4, 8, 32, 64, 1, 8), + (1024, 1024, 512, 128, 128): (4, 8, 64, 64, 1, 8), + (1024, 1024, 1024, 16, 16): (2, 2, 16, 64, 1, 2), + (1024, 1024, 1024, 32, 32): (2, 8, 32, 64, 1, 2), + (1024, 1024, 1024, 64, 64): (2, 8, 32, 128, 1, 4), + (1024, 1024, 1024, 128, 128): (2, 8, 64, 64, 1, 4), + (1024, 1024, 2048, 16, 16): (2, 16, 16, 128, 3, 2), + (1024, 1024, 2048, 32, 32): (4, 32, 32, 64, 1, 2), + (1024, 1024, 2048, 64, 64): (4, 16, 64, 64, 1, 4), + (1024, 1024, 2048, 128, 128): (4, 32, 64, 64, 1, 4), + (1024, 1024, 4096, 16, 16): (4, 16, 16, 128, 1, 2), + (1024, 1024, 4096, 32, 32): (3, 32, 32, 64, 1, 2), + (1024, 1024, 4096, 64, 64): (4, 32, 64, 64, 1, 4), + (1024, 1024, 4096, 128, 128): (4, 32, 64, 64, 1, 4), + (1024, 1024, 8192, 16, 16): (5, 16, 16, 128, 1, 2), + (1024, 1024, 8192, 32, 32): (2, 32, 32, 64, 3, 2), + (1024, 1024, 8192, 64, 64): (1, 16, 64, 64, 3, 2), + (1024, 1024, 8192, 128, 128): (4, 32, 64, 64, 1, 4), + (1024, 1024, 16384, 16, 16): (4, 16, 16, 128, 1, 2), + (1024, 1024, 16384, 32, 32): (1, 32, 32, 64, 3, 2), + (1024, 1024, 16384, 64, 64): (4, 16, 64, 64, 3, 2), + (1024, 1024, 16384, 128, 128): (4, 32, 128, 64, 1, 4), + (1024, 1024, 32768, 16, 16): (3, 16, 16, 128, 1, 2), + (1024, 1024, 32768, 32, 32): (1, 8, 32, 64, 3, 2), + (1024, 1024, 32768, 64, 64): (4, 16, 64, 64, 3, 2), + (1024, 1024, 32768, 128, 128): (4, 8, 128, 64, 2, 4), + (1024, 1024, 65536, 16, 16): (1, 2, 16, 128, 1, 2), + (1024, 1024, 65536, 32, 32): (2, 4, 32, 64, 3, 2), + (1024, 1024, 65536, 64, 64): (5, 16, 64, 64, 3, 2), + (1024, 1024, 65536, 128, 128): (5, 8, 128, 64, 2, 4), + (1024, 1024, 131072, 16, 16): (5, 2, 16, 128, 1, 2), + (1024, 1024, 131072, 32, 32): (1, 2, 32, 64, 3, 2), + (1024, 1024, 131072, 64, 64): (5, 16, 64, 64, 3, 2), + (1024, 1024, 131072, 128, 128): (2, 1, 128, 64, 2, 4), + (2048, 2048, 256, 16, 16): (4, 4, 16, 64, 1, 8), + (2048, 2048, 256, 32, 32): (4, 8, 32, 32, 1, 8), + (2048, 2048, 256, 64, 64): (4, 16, 64, 16, 1, 8), + (2048, 2048, 256, 128, 128): (4, 4, 128, 32, 3, 8), + (2048, 2048, 512, 16, 16): (2, 2, 16, 64, 1, 2), + (2048, 2048, 512, 32, 32): (2, 4, 32, 64, 3, 2), + (2048, 2048, 512, 64, 64): (4, 4, 64, 64, 1, 8), + (2048, 2048, 512, 128, 128): (4, 8, 64, 64, 1, 4), + (2048, 2048, 1024, 16, 16): (1, 8, 16, 64, 1, 2), + (2048, 2048, 1024, 32, 32): (2, 16, 32, 64, 3, 2), + (2048, 2048, 1024, 64, 64): (4, 8, 64, 64, 1, 4), + (2048, 2048, 1024, 128, 128): (4, 8, 128, 64, 1, 4), + (2048, 2048, 2048, 16, 16): (5, 4, 16, 128, 1, 2), + (2048, 2048, 2048, 32, 32): (1, 16, 32, 64, 3, 2), + (2048, 2048, 2048, 64, 64): (2, 8, 64, 64, 1, 4), + (2048, 2048, 2048, 128, 128): (2, 8, 128, 64, 1, 4), + (2048, 2048, 4096, 16, 16): (4, 2, 16, 128, 1, 2), + (2048, 
2048, 4096, 32, 32): (2, 16, 32, 64, 3, 2), + (2048, 2048, 4096, 64, 64): (2, 8, 64, 64, 3, 2), + (2048, 2048, 4096, 128, 128): (4, 8, 128, 64, 1, 4), + (2048, 2048, 8192, 16, 16): (5, 4, 16, 128, 1, 2), + (2048, 2048, 8192, 32, 32): (2, 8, 32, 64, 3, 2), + (2048, 2048, 8192, 64, 64): (4, 8, 64, 64, 3, 2), + (2048, 2048, 8192, 128, 128): (4, 8, 128, 64, 1, 4), + (2048, 2048, 16384, 16, 16): (3, 2, 16, 128, 1, 2), + (2048, 2048, 16384, 32, 32): (2, 4, 32, 128, 3, 2), + (2048, 2048, 16384, 64, 64): (4, 8, 64, 64, 3, 2), + (2048, 2048, 16384, 128, 128): (4, 4, 128, 64, 1, 4), + (2048, 2048, 32768, 16, 16): (3, 2, 16, 128, 1, 2), + (2048, 2048, 32768, 32, 32): (3, 4, 32, 128, 3, 2), + (2048, 2048, 32768, 64, 64): (6, 4, 64, 64, 3, 2), + (2048, 2048, 32768, 128, 128): (3, 4, 128, 64, 1, 4), + (2048, 2048, 65536, 16, 16): (6, 2, 16, 128, 1, 2), + (2048, 2048, 65536, 32, 32): (1, 2, 32, 128, 1, 2), + (2048, 2048, 65536, 64, 64): (5, 4, 64, 64, 3, 2), + (2048, 2048, 65536, 128, 128): (5, 1, 128, 64, 2, 4), + (2048, 2048, 131072, 16, 16): (3, 2, 16, 128, 1, 2), + (2048, 2048, 131072, 32, 32): (2, 1, 32, 128, 3, 2), + (2048, 2048, 131072, 64, 64): (4, 1, 64, 64, 3, 2), + (2048, 2048, 131072, 128, 128): (3, 1, 128, 64, 2, 4), + (4096, 4096, 256, 16, 16): (5, 8, 16, 32, 1, 4), + (4096, 4096, 256, 32, 32): (4, 16, 32, 16, 2, 4), + (4096, 4096, 256, 64, 64): (2, 1, 64, 64, 3, 4), + (4096, 4096, 256, 128, 128): (4, 4, 128, 32, 1, 4), + (4096, 4096, 512, 16, 16): (4, 2, 16, 128, 1, 2), + (4096, 4096, 512, 32, 32): (4, 8, 32, 64, 1, 2), + (4096, 4096, 512, 64, 64): (4, 4, 64, 64, 1, 4), + (4096, 4096, 512, 128, 128): (4, 8, 128, 64, 2, 4), + (4096, 4096, 1024, 16, 16): (1, 2, 16, 128, 1, 2), + (4096, 4096, 1024, 32, 32): (6, 8, 32, 64, 3, 2), + (4096, 4096, 1024, 64, 64): (2, 16, 64, 64, 4, 4), + (4096, 4096, 1024, 128, 128): (2, 4, 128, 64, 2, 4), + (4096, 4096, 2048, 16, 16): (3, 1, 16, 128, 1, 2), + (4096, 4096, 2048, 32, 32): (1, 4, 32, 64, 5, 2), + (4096, 4096, 2048, 64, 64): (3, 16, 64, 64, 3, 2), + (4096, 4096, 2048, 128, 128): (4, 32, 128, 64, 2, 4), + (4096, 4096, 4096, 16, 16): (1, 2, 16, 128, 1, 2), + (4096, 4096, 4096, 32, 32): (1, 4, 32, 64, 3, 2), + (4096, 4096, 4096, 64, 64): (1, 1, 64, 64, 4, 4), + (4096, 4096, 4096, 128, 128): (2, 1, 128, 128, 1, 8), + (4096, 4096, 8192, 16, 16): (3, 1, 16, 128, 1, 2), + (4096, 4096, 8192, 32, 32): (2, 2, 32, 64, 5, 2), + (4096, 4096, 8192, 64, 64): (4, 16, 64, 64, 3, 2), + (4096, 4096, 8192, 128, 128): (4, 16, 128, 64, 2, 4), + (4096, 4096, 16384, 16, 16): (1, 2, 16, 128, 1, 2), + (4096, 4096, 16384, 32, 32): (4, 2, 32, 64, 5, 2), + (4096, 4096, 16384, 64, 64): (4, 16, 64, 64, 3, 2), + (4096, 4096, 16384, 128, 128): (4, 16, 128, 64, 2, 4), + (4096, 4096, 32768, 16, 16): (3, 1, 16, 128, 1, 2), + (4096, 4096, 32768, 32, 32): (3, 1, 32, 128, 1, 4), + (4096, 4096, 32768, 64, 64): (3, 1, 64, 64, 3, 4), + (4096, 4096, 32768, 128, 128): (5, 16, 128, 64, 2, 4), + (4096, 4096, 65536, 16, 16): (5, 1, 16, 128, 1, 2), + (4096, 4096, 65536, 32, 32): (5, 1, 32, 128, 1, 4), + (4096, 4096, 65536, 64, 64): (1, 1, 64, 64, 3, 4), + (4096, 4096, 65536, 128, 128): (3, 16, 128, 64, 2, 4), + (4096, 4096, 131072, 16, 16): (3, 1, 16, 128, 1, 2), + (4096, 4096, 131072, 32, 32): (3, 1, 32, 128, 3, 2), + (4096, 4096, 131072, 64, 64): (2, 1, 64, 64, 3, 4), + (4096, 4096, 131072, 128, 128): (1, 1, 128, 64, 1, 4), + (8192, 8192, 256, 16, 16): (4, 16, 16, 16, 1, 4), + (8192, 8192, 256, 32, 32): (1, 16, 32, 16, 4, 4), + (8192, 8192, 256, 64, 64): (4, 16, 64, 16, 3, 8), + (8192, 8192, 
256, 128, 128): (4, 16, 128, 16, 1, 2), + (8192, 8192, 512, 16, 16): (2, 8, 16, 64, 1, 4), + (8192, 8192, 512, 32, 32): (4, 8, 32, 64, 3, 2), + (8192, 8192, 512, 64, 64): (2, 8, 64, 64, 4, 4), + (8192, 8192, 512, 128, 128): (4, 8, 128, 64, 2, 4), + (8192, 8192, 1024, 16, 16): (4, 16, 16, 64, 1, 8), + (8192, 8192, 1024, 32, 32): (2, 8, 32, 64, 5, 2), + (8192, 8192, 1024, 64, 64): (1, 16, 64, 64, 3, 2), + (8192, 8192, 1024, 128, 128): (5, 16, 128, 64, 2, 4), + (8192, 8192, 2048, 16, 16): (7, 2, 16, 128, 1, 2), + (8192, 8192, 2048, 32, 32): (1, 16, 32, 64, 5, 2), + (8192, 8192, 2048, 64, 64): (4, 16, 64, 64, 3, 2), + (8192, 8192, 2048, 128, 128): (6, 16, 128, 64, 2, 4), + (8192, 8192, 4096, 16, 16): (4, 2, 16, 128, 1, 2), + (8192, 8192, 4096, 32, 32): (2, 8, 32, 64, 5, 2), + (8192, 8192, 4096, 64, 64): (3, 16, 64, 64, 3, 2), + (8192, 8192, 4096, 128, 128): (3, 64, 128, 64, 2, 4), + (8192, 8192, 8192, 16, 16): (4, 2, 16, 128, 1, 2), + (8192, 8192, 8192, 32, 32): (1, 4, 32, 128, 5, 4), + (8192, 8192, 8192, 64, 64): (4, 4, 64, 64, 1, 4), + (8192, 8192, 8192, 128, 128): (2, 2, 128, 128, 3, 8), + (8192, 8192, 16384, 16, 16): (1, 2, 16, 128, 1, 2), + (8192, 8192, 16384, 32, 32): (4, 8, 32, 64, 5, 2), + (8192, 8192, 16384, 64, 64): (5, 8, 64, 64, 3, 2), + (8192, 8192, 16384, 128, 128): (3, 16, 128, 64, 2, 4), + (8192, 8192, 32768, 16, 16): (7, 2, 16, 128, 1, 2), + (8192, 8192, 32768, 32, 32): (3, 4, 32, 64, 3, 2), + (8192, 8192, 32768, 64, 64): (2, 8, 64, 64, 3, 2), + (8192, 8192, 32768, 128, 128): (6, 16, 128, 64, 2, 4), + (8192, 8192, 65536, 16, 16): (9, 2, 16, 128, 1, 2), + (8192, 8192, 65536, 32, 32): (7, 4, 32, 64, 5, 2), + (8192, 8192, 65536, 64, 64): (4, 8, 64, 64, 3, 2), + (8192, 8192, 65536, 128, 128): (3, 16, 128, 64, 2, 4), + (8192, 8192, 131072, 16, 16): (9, 2, 16, 128, 1, 2), + (8192, 8192, 131072, 32, 32): (1, 8, 32, 64, 5, 2), + (8192, 8192, 131072, 64, 64): (1, 8, 64, 64, 3, 2), + (8192, 8192, 131072, 128, 128): (4, 16, 128, 64, 2, 4), + (16384, 16384, 256, 16, 16): (5, 16, 16, 16, 1, 4), + (16384, 16384, 256, 32, 32): (4, 16, 32, 16, 4, 4), + (16384, 16384, 256, 64, 64): (4, 16, 64, 16, 3, 8), + (16384, 16384, 256, 128, 128): (4, 16, 128, 16, 1, 2), + (16384, 16384, 512, 16, 16): (2, 8, 16, 64, 1, 4), + (16384, 16384, 512, 32, 32): (1, 4, 32, 64, 5, 2), + (16384, 16384, 512, 64, 64): (4, 8, 64, 64, 1, 4), + (16384, 16384, 512, 128, 128): (3, 8, 128, 64, 2, 4), + (16384, 16384, 1024, 16, 16): (4, 2, 16, 128, 1, 2), + (16384, 16384, 1024, 32, 32): (4, 8, 32, 64, 5, 2), + (16384, 16384, 1024, 64, 64): (6, 16, 64, 64, 3, 2), + (16384, 16384, 1024, 128, 128): (3, 16, 128, 64, 2, 4), + (16384, 16384, 2048, 16, 16): (3, 2, 16, 128, 1, 2), + (16384, 16384, 2048, 32, 32): (1, 8, 32, 64, 5, 2), + (16384, 16384, 2048, 64, 64): (5, 16, 64, 64, 3, 2), + (16384, 16384, 2048, 128, 128): (2, 32, 128, 64, 2, 4), + (16384, 16384, 4096, 16, 16): (2, 2, 16, 128, 1, 2), + (16384, 16384, 4096, 32, 32): (1, 4, 32, 64, 3, 2), + (16384, 16384, 4096, 64, 64): (2, 8, 64, 64, 3, 2), + (16384, 16384, 4096, 128, 128): (3, 16, 128, 64, 2, 4), + (16384, 16384, 8192, 16, 16): (3, 2, 16, 128, 1, 2), + (16384, 16384, 8192, 32, 32): (2, 4, 32, 64, 5, 2), + (16384, 16384, 8192, 64, 64): (4, 8, 64, 64, 3, 2), + (16384, 16384, 8192, 128, 128): (8, 32, 128, 64, 2, 4), + (16384, 16384, 16384, 16, 16): (1, 2, 16, 256, 1, 4), + (16384, 16384, 16384, 32, 32): (1, 4, 32, 128, 3, 4), + (16384, 16384, 16384, 64, 64): (5, 4, 64, 64, 1, 4), + (16384, 16384, 16384, 128, 128): (4, 8, 128, 64, 2, 4), + (16384, 16384, 32768, 16, 16): 
(2, 2, 16, 128, 1, 2), + (16384, 16384, 32768, 32, 32): (1, 4, 32, 64, 3, 2), + (16384, 16384, 32768, 64, 64): (5, 4, 64, 64, 1, 4), + (16384, 16384, 32768, 128, 128): (5, 8, 128, 64, 2, 4), + (16384, 16384, 65536, 16, 16): (8, 2, 16, 128, 1, 2), + (16384, 16384, 65536, 32, 32): (6, 4, 32, 64, 5, 2), + (16384, 16384, 65536, 64, 64): (2, 4, 64, 64, 1, 4), + (16384, 16384, 65536, 128, 128): (4, 8, 128, 64, 2, 4), + (16384, 16384, 131072, 16, 16): (3, 1, 16, 128, 1, 2), + (16384, 16384, 131072, 32, 32): (1, 4, 32, 64, 3, 2), + (16384, 16384, 131072, 64, 64): (4, 4, 64, 64, 1, 4), + (16384, 16384, 131072, 128, 128): (1, 8, 128, 64, 2, 4), + (32768, 32768, 256, 16, 16): (4, 16, 16, 16, 1, 4), + (32768, 32768, 512, 16, 16): (4, 2, 16, 128, 1, 2), + (32768, 32768, 1024, 16, 16): (3, 2, 16, 128, 1, 2), + (32768, 32768, 2048, 16, 16): (4, 2, 16, 128, 1, 2), + (32768, 32768, 4096, 16, 16): (5, 4, 16, 64, 1, 1), + (32768, 32768, 8192, 16, 16): (4, 4, 16, 64, 1, 1), + (32768, 32768, 16384, 16, 16): (4, 4, 16, 64, 1, 1), + (32768, 32768, 32768, 16, 16): (5, 4, 16, 64, 1, 1), + }, + ("scatter_mm", "NVIDIA A100-SXM4-80GB", (0, torch.float32, 0.5)): { + (256, 256, 256, 16, 16): (1, 1, 16, 16, 1, 8), + (256, 256, 256, 32, 32): (1, 1, 16, 16, 1, 4), + (256, 256, 256, 64, 64): (1, 1, 16, 16, 1, 4), + (256, 256, 256, 128, 128): (1, 1, 16, 16, 1, 1), + (256, 256, 512, 16, 16): (1, 1, 16, 16, 1, 4), + (256, 256, 512, 32, 32): (1, 16, 16, 16, 1, 1), + (256, 256, 512, 64, 64): (1, 1, 16, 16, 1, 1), + (256, 256, 512, 128, 128): (1, 1, 32, 32, 1, 4), + (256, 256, 1024, 16, 16): (1, 1, 16, 32, 1, 2), + (256, 256, 1024, 32, 32): (1, 4, 16, 16, 1, 1), + (256, 256, 1024, 64, 64): (1, 1, 32, 32, 1, 4), + (256, 256, 1024, 128, 128): (1, 1, 32, 32, 1, 4), + (256, 256, 2048, 16, 16): (1, 2, 16, 32, 1, 2), + (256, 256, 2048, 32, 32): (1, 1, 16, 32, 1, 2), + (256, 256, 2048, 64, 64): (2, 1, 16, 32, 1, 2), + (256, 256, 2048, 128, 128): (1, 1, 16, 16, 1, 1), + (256, 256, 4096, 16, 16): (1, 1, 16, 32, 1, 2), + (256, 256, 4096, 32, 32): (1, 1, 16, 32, 1, 2), + (256, 256, 4096, 64, 64): (1, 1, 32, 32, 1, 4), + (256, 256, 4096, 128, 128): (3, 1, 32, 64, 1, 4), + (256, 256, 8192, 16, 16): (1, 32, 16, 64, 1, 2), + (256, 256, 8192, 32, 32): (1, 1, 32, 64, 1, 4), + (256, 256, 8192, 64, 64): (1, 1, 32, 64, 1, 4), + (256, 256, 8192, 128, 128): (2, 1, 64, 32, 1, 4), + (256, 256, 16384, 16, 16): (1, 1, 16, 64, 1, 2), + (256, 256, 16384, 32, 32): (1, 1, 32, 64, 1, 4), + (256, 256, 16384, 64, 64): (1, 128, 64, 64, 1, 4), + (256, 256, 16384, 128, 128): (2, 1, 64, 32, 1, 4), + (256, 256, 32768, 16, 16): (2, 128, 16, 64, 1, 1), + (256, 256, 32768, 32, 32): (1, 1, 32, 64, 1, 4), + (256, 256, 32768, 64, 64): (1, 128, 64, 64, 1, 4), + (256, 256, 32768, 128, 128): (2, 1, 64, 64, 1, 4), + (256, 256, 65536, 16, 16): (1, 1, 16, 64, 1, 2), + (256, 256, 65536, 32, 32): (1, 1, 32, 64, 1, 4), + (256, 256, 65536, 64, 64): (2, 1, 64, 64, 1, 4), + (256, 256, 65536, 128, 128): (1, 1, 128, 32, 1, 4), + (256, 256, 131072, 16, 16): (3, 128, 16, 64, 1, 1), + (256, 256, 131072, 32, 32): (1, 1, 32, 64, 1, 4), + (256, 256, 131072, 64, 64): (2, 1, 64, 64, 1, 4), + (256, 256, 131072, 128, 128): (1, 8192, 64, 16, 1, 4), + (512, 512, 256, 16, 16): (1, 2, 16, 16, 1, 1), + (512, 512, 256, 32, 32): (1, 4, 16, 16, 1, 1), + (512, 512, 256, 64, 64): (1, 16, 16, 16, 1, 1), + (512, 512, 256, 128, 128): (1, 1, 16, 32, 1, 4), + (512, 512, 512, 16, 16): (1, 8, 16, 32, 1, 2), + (512, 512, 512, 32, 32): (1, 8, 16, 32, 1, 2), + (512, 512, 512, 64, 64): (1, 2, 16, 32, 1, 2), + (512, 
512, 512, 128, 128): (1, 1, 32, 32, 1, 4), + (512, 512, 1024, 16, 16): (1, 1, 16, 32, 1, 2), + (512, 512, 1024, 32, 32): (1, 1, 16, 32, 1, 2), + (512, 512, 1024, 64, 64): (1, 1, 16, 32, 1, 2), + (512, 512, 1024, 128, 128): (1, 1, 64, 32, 1, 4), + (512, 512, 2048, 16, 16): (1, 16, 16, 64, 1, 2), + (512, 512, 2048, 32, 32): (1, 1, 32, 32, 1, 4), + (512, 512, 2048, 64, 64): (1, 1, 32, 32, 1, 4), + (512, 512, 2048, 128, 128): (2, 1, 32, 32, 1, 4), + (512, 512, 4096, 16, 16): (2, 64, 16, 64, 1, 1), + (512, 512, 4096, 32, 32): (1, 64, 32, 64, 1, 4), + (512, 512, 4096, 64, 64): (1, 1, 32, 32, 1, 4), + (512, 512, 4096, 128, 128): (1, 1, 64, 32, 1, 4), + (512, 512, 8192, 16, 16): (2, 64, 16, 64, 1, 1), + (512, 512, 8192, 32, 32): (1, 256, 32, 32, 1, 1), + (512, 512, 8192, 64, 64): (1, 64, 64, 64, 1, 4), + (512, 512, 8192, 128, 128): (2, 1, 64, 32, 1, 8), + (512, 512, 16384, 16, 16): (2, 64, 16, 64, 1, 1), + (512, 512, 16384, 32, 32): (1, 128, 32, 32, 1, 1), + (512, 512, 16384, 64, 64): (1, 64, 64, 64, 1, 4), + (512, 512, 16384, 128, 128): (3, 1, 64, 32, 1, 8), + (512, 512, 32768, 16, 16): (2, 64, 16, 64, 1, 1), + (512, 512, 32768, 32, 32): (1, 128, 32, 32, 1, 1), + (512, 512, 32768, 64, 64): (1, 64, 64, 64, 1, 4), + (512, 512, 32768, 128, 128): (2, 1, 64, 32, 1, 8), + (512, 512, 65536, 16, 16): (2, 32, 16, 64, 1, 1), + (512, 512, 65536, 32, 32): (1, 128, 32, 32, 1, 1), + (512, 512, 65536, 64, 64): (1, 64, 64, 64, 1, 4), + (512, 512, 65536, 128, 128): (2, 1, 64, 32, 1, 8), + (512, 512, 131072, 16, 16): (2, 32, 16, 64, 1, 1), + (512, 512, 131072, 32, 32): (1, 128, 32, 32, 1, 1), + (512, 512, 131072, 64, 64): (3, 64, 64, 64, 1, 4), + (512, 512, 131072, 128, 128): (1, 8192, 64, 16, 1, 4), + (1024, 1024, 256, 16, 16): (1, 4, 16, 32, 1, 2), + (1024, 1024, 256, 32, 32): (1, 4, 16, 32, 1, 2), + (1024, 1024, 256, 64, 64): (1, 1, 16, 32, 1, 2), + (1024, 1024, 256, 128, 128): (1, 1, 16, 16, 1, 1), + (1024, 1024, 512, 16, 16): (1, 8, 16, 32, 1, 2), + (1024, 1024, 512, 32, 32): (1, 8, 16, 32, 1, 1), + (1024, 1024, 512, 64, 64): (1, 8, 32, 32, 1, 4), + (1024, 1024, 512, 128, 128): (2, 1, 32, 32, 1, 4), + (1024, 1024, 1024, 16, 16): (1, 16, 16, 32, 1, 2), + (1024, 1024, 1024, 32, 32): (1, 16, 32, 64, 1, 4), + (1024, 1024, 1024, 64, 64): (1, 16, 32, 64, 1, 4), + (1024, 1024, 1024, 128, 128): (1, 1, 32, 32, 1, 4), + (1024, 1024, 2048, 16, 16): (2, 32, 16, 64, 1, 1), + (1024, 1024, 2048, 32, 32): (1, 32, 32, 64, 1, 4), + (1024, 1024, 2048, 64, 64): (1, 32, 64, 64, 1, 4), + (1024, 1024, 2048, 128, 128): (1, 1, 32, 64, 1, 4), + (1024, 1024, 4096, 16, 16): (2, 16, 16, 64, 1, 1), + (1024, 1024, 4096, 32, 32): (1, 64, 32, 32, 1, 1), + (1024, 1024, 4096, 64, 64): (1, 64, 64, 64, 1, 4), + (1024, 1024, 4096, 128, 128): (2, 64, 64, 32, 1, 8), + (1024, 1024, 8192, 16, 16): (2, 16, 16, 64, 1, 1), + (1024, 1024, 8192, 32, 32): (1, 64, 32, 32, 1, 1), + (1024, 1024, 8192, 64, 64): (1, 64, 64, 64, 1, 4), + (1024, 1024, 8192, 128, 128): (4, 1, 32, 64, 1, 4), + (1024, 1024, 16384, 16, 16): (2, 16, 16, 64, 1, 1), + (1024, 1024, 16384, 32, 32): (1, 64, 32, 32, 1, 1), + (1024, 1024, 16384, 64, 64): (1, 32, 64, 64, 1, 4), + (1024, 1024, 16384, 128, 128): (2, 64, 64, 32, 1, 4), + (1024, 1024, 32768, 16, 16): (2, 16, 16, 64, 1, 1), + (1024, 1024, 32768, 32, 32): (1, 64, 32, 32, 1, 1), + (1024, 1024, 32768, 64, 64): (1, 32, 64, 64, 1, 4), + (1024, 1024, 32768, 128, 128): (4, 1, 32, 64, 1, 4), + (1024, 1024, 65536, 16, 16): (2, 16, 16, 64, 1, 1), + (1024, 1024, 65536, 32, 32): (1, 32, 32, 32, 1, 1), + (1024, 1024, 65536, 64, 64): (2, 32, 
64, 64, 1, 4), + (1024, 1024, 65536, 128, 128): (4, 1, 64, 32, 1, 4), + (1024, 1024, 131072, 16, 16): (2, 16, 16, 64, 1, 1), + (1024, 1024, 131072, 32, 32): (1, 32, 32, 32, 1, 1), + (1024, 1024, 131072, 64, 64): (1, 16, 64, 64, 1, 4), + (1024, 1024, 131072, 128, 128): (1, 8192, 64, 16, 1, 4), + (2048, 2048, 256, 16, 16): (1, 4, 16, 32, 1, 2), + (2048, 2048, 256, 32, 32): (1, 8, 16, 32, 1, 1), + (2048, 2048, 256, 64, 64): (1, 8, 32, 32, 1, 4), + (2048, 2048, 256, 128, 128): (1, 4, 64, 64, 1, 8), + (2048, 2048, 512, 16, 16): (2, 8, 16, 32, 1, 2), + (2048, 2048, 512, 32, 32): (2, 8, 32, 64, 1, 4), + (2048, 2048, 512, 64, 64): (2, 4, 64, 64, 1, 4), + (2048, 2048, 512, 128, 128): (1, 8, 32, 64, 1, 4), + (2048, 2048, 1024, 16, 16): (2, 16, 16, 64, 3, 1), + (2048, 2048, 1024, 32, 32): (1, 32, 32, 32, 1, 1), + (2048, 2048, 1024, 64, 64): (1, 16, 64, 64, 1, 4), + (2048, 2048, 1024, 128, 128): (2, 4, 64, 64, 1, 8), + (2048, 2048, 2048, 16, 16): (2, 16, 16, 64, 1, 1), + (2048, 2048, 2048, 32, 32): (1, 32, 32, 32, 1, 1), + (2048, 2048, 2048, 64, 64): (1, 16, 64, 64, 1, 4), + (2048, 2048, 2048, 128, 128): (2, 32, 32, 64, 1, 4), + (2048, 2048, 4096, 16, 16): (3, 2, 16, 64, 1, 1), + (2048, 2048, 4096, 32, 32): (3, 4, 32, 32, 1, 1), + (2048, 2048, 4096, 64, 64): (1, 16, 64, 64, 1, 4), + (2048, 2048, 4096, 128, 128): (2, 32, 64, 32, 1, 4), + (2048, 2048, 8192, 16, 16): (3, 4, 16, 64, 1, 1), + (2048, 2048, 8192, 32, 32): (2, 4, 32, 32, 1, 1), + (2048, 2048, 8192, 64, 64): (2, 32, 64, 32, 1, 2), + (2048, 2048, 8192, 128, 128): (4, 1, 32, 64, 1, 4), + (2048, 2048, 16384, 16, 16): (3, 4, 16, 64, 1, 1), + (2048, 2048, 16384, 32, 32): (1, 4, 32, 32, 1, 1), + (2048, 2048, 16384, 64, 64): (2, 8, 64, 32, 1, 2), + (2048, 2048, 16384, 128, 128): (2, 8, 64, 32, 1, 4), + (2048, 2048, 32768, 16, 16): (2, 4, 16, 64, 1, 1), + (2048, 2048, 32768, 32, 32): (2, 8, 32, 32, 1, 1), + (2048, 2048, 32768, 64, 64): (1, 16, 64, 32, 1, 2), + (2048, 2048, 32768, 128, 128): (4, 1, 32, 64, 1, 4), + (2048, 2048, 65536, 16, 16): (3, 4, 16, 64, 1, 1), + (2048, 2048, 65536, 32, 32): (1, 8, 32, 32, 1, 1), + (2048, 2048, 65536, 64, 64): (1, 8, 64, 32, 1, 2), + (2048, 2048, 65536, 128, 128): (4, 1, 64, 32, 1, 4), + (2048, 2048, 131072, 16, 16): (2, 4, 16, 64, 1, 1), + (2048, 2048, 131072, 32, 32): (1, 8, 32, 32, 1, 1), + (2048, 2048, 131072, 64, 64): (3, 1, 64, 32, 1, 2), + (2048, 2048, 131072, 128, 128): (1, 8192, 128, 16, 1, 8), + (4096, 4096, 256, 16, 16): (2, 4, 16, 32, 1, 2), + (4096, 4096, 256, 32, 32): (1, 4, 32, 64, 1, 4), + (4096, 4096, 256, 64, 64): (1, 4, 64, 64, 1, 4), + (4096, 4096, 256, 128, 128): (1, 4, 32, 64, 1, 4), + (4096, 4096, 512, 16, 16): (2, 8, 16, 64, 3, 1), + (4096, 4096, 512, 32, 32): (2, 16, 32, 32, 1, 1), + (4096, 4096, 512, 64, 64): (1, 8, 64, 64, 1, 4), + (4096, 4096, 512, 128, 128): (1, 8, 32, 64, 1, 4), + (4096, 4096, 1024, 16, 16): (1, 8, 16, 64, 3, 1), + (4096, 4096, 1024, 32, 32): (1, 16, 32, 32, 1, 1), + (4096, 4096, 1024, 64, 64): (1, 16, 64, 32, 1, 2), + (4096, 4096, 1024, 128, 128): (1, 16, 32, 64, 1, 4), + (4096, 4096, 2048, 16, 16): (1, 16, 16, 64, 3, 1), + (4096, 4096, 2048, 32, 32): (1, 16, 32, 32, 1, 1), + (4096, 4096, 2048, 64, 64): (3, 16, 64, 32, 1, 2), + (4096, 4096, 2048, 128, 128): (4, 8, 32, 64, 1, 4), + (4096, 4096, 4096, 16, 16): (1, 8, 16, 64, 3, 1), + (4096, 4096, 4096, 32, 32): (1, 1, 32, 32, 1, 1), + (4096, 4096, 4096, 64, 64): (2, 16, 64, 32, 1, 2), + (4096, 4096, 4096, 128, 128): (4, 8, 32, 64, 1, 4), + (4096, 4096, 8192, 16, 16): (1, 8, 16, 64, 3, 1), + (4096, 4096, 8192, 32, 32): 
(2, 1, 32, 32, 1, 1), + (4096, 4096, 8192, 64, 64): (1, 16, 64, 32, 1, 2), + (4096, 4096, 8192, 128, 128): (2, 1, 32, 64, 1, 4), + (4096, 4096, 16384, 16, 16): (1, 8, 16, 64, 3, 1), + (4096, 4096, 16384, 32, 32): (1, 1, 32, 32, 1, 1), + (4096, 4096, 16384, 64, 64): (2, 8, 64, 32, 1, 2), + (4096, 4096, 16384, 128, 128): (2, 1, 32, 64, 1, 4), + (4096, 4096, 32768, 16, 16): (1, 8, 16, 64, 3, 1), + (4096, 4096, 32768, 32, 32): (1, 1, 32, 32, 1, 1), + (4096, 4096, 32768, 64, 64): (1, 8, 64, 32, 1, 2), + (4096, 4096, 32768, 128, 128): (2, 1, 32, 64, 1, 4), + (4096, 4096, 65536, 16, 16): (1, 8, 16, 64, 3, 1), + (4096, 4096, 65536, 32, 32): (3, 1, 32, 32, 1, 1), + (4096, 4096, 65536, 64, 64): (3, 4, 64, 32, 1, 2), + (4096, 4096, 65536, 128, 128): (2, 1, 32, 64, 1, 4), + (4096, 4096, 131072, 16, 16): (1, 8, 16, 64, 3, 1), + (4096, 4096, 131072, 32, 32): (1, 1, 32, 32, 1, 1), + (4096, 4096, 131072, 64, 64): (2, 8, 64, 32, 1, 2), + (4096, 4096, 131072, 128, 128): (1, 8192, 128, 16, 1, 8), + (8192, 8192, 256, 16, 16): (2, 4, 16, 64, 3, 1), + (8192, 8192, 256, 32, 32): (1, 8, 32, 32, 1, 1), + (8192, 8192, 256, 64, 64): (1, 4, 64, 64, 1, 4), + (8192, 8192, 256, 128, 128): (1, 4, 32, 64, 1, 4), + (8192, 8192, 512, 16, 16): (1, 4, 16, 64, 3, 1), + (8192, 8192, 512, 32, 32): (1, 16, 32, 32, 1, 1), + (8192, 8192, 512, 64, 64): (2, 4, 64, 64, 1, 4), + (8192, 8192, 512, 128, 128): (2, 1, 32, 64, 1, 4), + (8192, 8192, 1024, 16, 16): (3, 8, 16, 64, 3, 1), + (8192, 8192, 1024, 32, 32): (1, 16, 32, 32, 1, 1), + (8192, 8192, 1024, 64, 64): (1, 8, 64, 32, 1, 2), + (8192, 8192, 1024, 128, 128): (2, 4, 32, 64, 1, 4), + (8192, 8192, 2048, 16, 16): (1, 8, 16, 64, 3, 1), + (8192, 8192, 2048, 32, 32): (1, 16, 32, 32, 1, 1), + (8192, 8192, 2048, 64, 64): (2, 8, 64, 32, 1, 2), + (8192, 8192, 2048, 128, 128): (4, 1, 32, 64, 1, 4), + (8192, 8192, 4096, 16, 16): (1, 8, 16, 64, 3, 1), + (8192, 8192, 4096, 32, 32): (1, 16, 32, 32, 1, 1), + (8192, 8192, 4096, 64, 64): (1, 4, 64, 32, 1, 2), + (8192, 8192, 4096, 128, 128): (3, 1, 32, 64, 1, 4), + (8192, 8192, 8192, 16, 16): (1, 8, 16, 64, 3, 1), + (8192, 8192, 8192, 32, 32): (1, 8, 32, 32, 1, 1), + (8192, 8192, 8192, 64, 64): (1, 8, 64, 32, 1, 2), + (8192, 8192, 8192, 128, 128): (4, 1, 32, 64, 1, 4), + (8192, 8192, 16384, 16, 16): (3, 4, 16, 64, 3, 1), + (8192, 8192, 16384, 32, 32): (1, 8, 32, 32, 1, 1), + (8192, 8192, 16384, 64, 64): (2, 2, 64, 32, 1, 2), + (8192, 8192, 16384, 128, 128): (7, 1, 32, 64, 1, 4), + (8192, 8192, 32768, 16, 16): (1, 4, 16, 64, 3, 1), + (8192, 8192, 32768, 32, 32): (1, 8, 32, 32, 1, 1), + (8192, 8192, 32768, 64, 64): (3, 2, 64, 32, 1, 2), + (8192, 8192, 32768, 128, 128): (6, 1, 32, 64, 1, 4), + (8192, 8192, 65536, 16, 16): (1, 4, 16, 64, 3, 1), + (8192, 8192, 65536, 32, 32): (4, 8, 32, 32, 1, 1), + (8192, 8192, 65536, 64, 64): (1, 2, 64, 32, 1, 2), + (8192, 8192, 65536, 128, 128): (4, 1, 32, 64, 1, 4), + (8192, 8192, 131072, 16, 16): (1, 4, 16, 64, 3, 1), + (8192, 8192, 131072, 32, 32): (1, 8, 32, 32, 1, 1), + (8192, 8192, 131072, 64, 64): (5, 4, 64, 32, 1, 2), + (8192, 8192, 131072, 128, 128): (1, 4096, 128, 16, 1, 8), + (16384, 16384, 256, 16, 16): (1, 4, 16, 64, 3, 1), + (16384, 16384, 256, 32, 32): (1, 8, 32, 32, 1, 1), + (16384, 16384, 256, 64, 64): (1, 4, 64, 32, 1, 2), + (16384, 16384, 256, 128, 128): (1, 4, 32, 64, 1, 4), + (16384, 16384, 512, 16, 16): (1, 8, 16, 64, 3, 1), + (16384, 16384, 512, 32, 32): (1, 16, 32, 32, 1, 1), + (16384, 16384, 512, 64, 64): (1, 4, 64, 32, 1, 2), + (16384, 16384, 512, 128, 128): (3, 1, 32, 64, 1, 4), + (16384, 
16384, 1024, 16, 16): (1, 8, 16, 64, 3, 1), + (16384, 16384, 1024, 32, 32): (1, 16, 32, 32, 1, 1), + (16384, 16384, 1024, 64, 64): (2, 4, 64, 32, 1, 2), + (16384, 16384, 1024, 128, 128): (1, 2, 32, 64, 1, 4), + (16384, 16384, 2048, 16, 16): (1, 4, 16, 64, 3, 1), + (16384, 16384, 2048, 32, 32): (1, 16, 32, 32, 1, 1), + (16384, 16384, 2048, 64, 64): (3, 4, 64, 32, 1, 2), + (16384, 16384, 2048, 128, 128): (2, 1, 32, 64, 1, 4), + (16384, 16384, 4096, 16, 16): (4, 8, 16, 64, 3, 1), + (16384, 16384, 4096, 32, 32): (5, 16, 32, 32, 1, 1), + (16384, 16384, 4096, 64, 64): (3, 2, 64, 32, 1, 2), + (16384, 16384, 4096, 128, 128): (2, 1, 32, 64, 1, 4), + (16384, 16384, 8192, 16, 16): (1, 4, 16, 64, 3, 1), + (16384, 16384, 8192, 32, 32): (1, 4, 32, 32, 1, 1), + (16384, 16384, 8192, 64, 64): (1, 2, 64, 32, 1, 2), + (16384, 16384, 8192, 128, 128): (2, 1, 32, 64, 1, 4), + (16384, 16384, 16384, 16, 16): (1, 8, 16, 64, 3, 1), + (16384, 16384, 16384, 32, 32): (1, 4, 32, 32, 1, 1), + (16384, 16384, 16384, 64, 64): (1, 2, 64, 32, 1, 2), + (16384, 16384, 16384, 128, 128): (3, 1, 32, 64, 1, 4), + (16384, 16384, 32768, 16, 16): (1, 4, 16, 64, 3, 1), + (16384, 16384, 32768, 32, 32): (1, 2, 32, 32, 1, 1), + (16384, 16384, 32768, 64, 64): (3, 2, 64, 32, 1, 2), + (16384, 16384, 32768, 128, 128): (3, 1, 32, 64, 1, 4), + (16384, 16384, 65536, 16, 16): (1, 8, 16, 64, 3, 1), + (16384, 16384, 65536, 32, 32): (1, 4, 32, 32, 1, 1), + (16384, 16384, 65536, 64, 64): (4, 4, 64, 32, 1, 2), + (16384, 16384, 65536, 128, 128): (5, 1, 32, 64, 1, 4), + (16384, 16384, 131072, 16, 16): (1, 2, 16, 64, 3, 1), + (16384, 16384, 131072, 32, 32): (1, 4, 32, 32, 1, 1), + (16384, 16384, 131072, 64, 64): (1, 2, 64, 32, 1, 2), + (16384, 16384, 131072, 128, 128): (1, 4096, 128, 16, 1, 8), + }, + # END GENERATED DATA +} + +if __name__ == "__main__": + for dtype in [torch.int8]: + for op in ["_int_bsr_dense_addmm"]: + main(op=op, force=False, dtype=dtype) + for dtype in [torch.float16, torch.bfloat16, torch.float32]: + for op in ["bsr_dense_addmm"]: + main(op=op, force=False, dtype=dtype) diff --git a/lib/python3.10/site-packages/torch/sparse/semi_structured.py b/lib/python3.10/site-packages/torch/sparse/semi_structured.py new file mode 100644 index 0000000000000000000000000000000000000000..0017a10e6771bb96f6c808ec9ece2e40dead4671 --- /dev/null +++ b/lib/python3.10/site-packages/torch/sparse/semi_structured.py @@ -0,0 +1,648 @@ +# mypy: allow-untyped-defs +import warnings +from collections import namedtuple +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +from torch.sparse._semi_structured_conversions import ( + sparse_semi_structured_from_dense_cutlass, + sparse_semi_structured_to_dense_cutlass, +) +from torch.sparse._semi_structured_ops import ( + fallback_dispatcher, + semi_sparse_addmm, + semi_sparse_detach, + semi_sparse_indices, + semi_sparse_linear, + semi_sparse_mm, + semi_sparse_t, + semi_sparse_values, + semi_sparse_view, +) + + +__all__ = [ + "SparseSemiStructuredTensor", + "SparseSemiStructuredTensorCUTLASS", + "SparseSemiStructuredTensorCUSPARSELT", + "to_sparse_semi_structured", +] + +_SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple( + "_SEMI_STRUCTURED_SPARSE_CONFIG", + "sparse_min_rows sparse_min_cols dense_min_rows dense_min_cols", +) + + +class SparseSemiStructuredTensor(torch.Tensor): + """ + This class implementes semi-structured sparsity as a Tensor subclass. + + Semi-structured sparsity describes a sparsity pattern where n in every 2n elements are sparse, + depending on the datatype. 
It is also referred to as 2:4 sparsity or fine-grained + structured sparsity. + + There are two backends available for semi_structured sparsity, either cuSPARSELt or CUTLASS. + This class is meant to serve as a base class for both implementations. SparseSemiStructuredCUTLASS + and SparseSemiStructuredCUSPARSELT both inherit from this class and define three backend-specific items. + Note that as such, this class cannot be instantiated directly. + + - `_DTYPE_SHAPE_CONSTRAINTS` - A dictionary holding backend-specific dense/sparse min shape constraints + - `def from_dense()` - backend-specific compression routines + - `def _mm()` - backend-specific mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_(mm|addmm)) + """ + + _DEFAULT_ALG_ID: int = 0 + _DTYPE_SHAPE_CONSTRAINTS: Dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG] + _FORCE_CUTLASS: bool = True + _FUSE_TRANSPOSE: bool = False + _PROTOTYPE_WARNING_SHOWN: bool = False + + BACKEND: str + SPARSE_DISPATCH: Dict[Callable, Callable] + + packed: Optional[torch.Tensor] + meta: Optional[torch.Tensor] + packed_t: Optional[torch.Tensor] + meta_t: Optional[torch.Tensor] + compressed_swizzled_bitmask: Optional[torch.Tensor] + fuse_transpose_cusparselt: bool + alg_id_cusparselt: int + + __slots__ = ["packed", "meta", "packed_t", "meta_t", "compressed_swizzled_bitmask"] + + @staticmethod + def __new__( # noqa: PYI034 + cls, + shape: torch.Size, + packed: Optional[torch.Tensor], + meta: Optional[torch.Tensor], + packed_t: Optional[torch.Tensor], + meta_t: Optional[torch.Tensor], + compressed_swizzled_bitmask: Optional[torch.Tensor], + fuse_transpose_cusparselt: bool = False, + alg_id_cusparselt: int = 0, + requires_grad: bool = False, + ): + """ + Create a new instance of the tensor subclass from the compressed sparse representation. + + We have the option to create the subclass with the compressed representations of both X and X', for training. + For inference, we only need a single representation (either X or X'), while the corresponding other set will be None. + + Depending on the backend selected, certain fields will be set to None. (CUSPARSELT vs CUTLASS) + + Args: + shape: The shape of the original dense tensor + packed: The compressed representation of the original dense tensor + meta: The metadata of the original dense tensor, if it is stored separately + packed_t: The compressed representation of the transposed original dense tensor + meta_t: The metadata of the transposed original dense tensor, if it is stored separately + compressed_swizzled_bitmask: The masks used by the CUTLASS backend to determine which threads should + participate in the computation. Used for pointwise ops. + fuse_transpose_cusparselt: When running with cuSPARSELt, we have the option to fuse a transposition + with a matmul, which is useful in the case of 2:4 sparse training. + alg_id_cusparselt: The algorithm id to use when using cuSPARSELT, will have an effect on performance + + Returns: + torch.Tensor: A torch.Tensor wrapper subclass. + + Raises: + ValueError: If all of the tensor arguments are None. + """ + if not cls._PROTOTYPE_WARNING_SHOWN: + warnings.warn( + ( + "The PyTorch API of SparseSemiStructuredTensor is in prototype stage " + "and will change in the near future. Please open a GitHub issue " + "for feature requests and see our documentation on the torch.sparse " + "module for further information about the project."
+ ), + UserWarning, + ) + cls._PROTOTYPE_WARNING_SHOWN = True + + # Because this only runs onces, we also load the dispatch table here as well. + # We can't define the dispatch table explicitly because of torch.ops import errors, so we do this instead + # But this is useful since it allows users to overload the dispatch table for debugging / testing. + cls._load_dispatch_table() + + # we can also register the classes with dynamo when the warning is shown. + torch._dynamo.allow_in_graph(cls) + + if packed is not None: + previous_tensor = packed + elif packed_t is not None: + previous_tensor = packed_t + else: + raise ValueError("At least one of packed or packed_t must be provided") + + kwargs = { + "device": previous_tensor.device, + "dtype": previous_tensor.dtype, + "layout": previous_tensor.layout, + "requires_grad": requires_grad, + } + tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + tensor.packed = packed + tensor.meta = meta + tensor.packed_t = packed_t + tensor.meta_t = meta_t + tensor.compressed_swizzled_bitmask = compressed_swizzled_bitmask + tensor.fuse_transpose_cusparselt = fuse_transpose_cusparselt + tensor.alg_id_cusparselt = alg_id_cusparselt + return tensor + + def __repr__(self) -> str: # type: ignore[override] + assert hasattr(self, "shape") + return f"{self.__class__.__name__}(shape={self.shape})" + + def __tensor_flatten__( + self, + ) -> Tuple[List[str], Tuple[torch.Size, bool, int, bool]]: + inner_tensors = list( + filter(lambda x: getattr(self, x) is not None, self.__slots__) + ) + tensor_meta = ( + self.shape, + self.fuse_transpose_cusparselt, + self.alg_id_cusparselt, + self.requires_grad, + ) + return inner_tensors, tensor_meta + + @classmethod + def __tensor_unflatten__( + cls, + inner_tensors, + tensor_meta: Tuple[torch.Size, bool, int, bool], + outer_size, + outer_stride, + ) -> torch.Tensor: + shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta + return cls( + shape=shape, + packed=inner_tensors.get("packed", None), + meta=inner_tensors.get("meta", None), + packed_t=inner_tensors.get("packed_t", None), + meta_t=inner_tensors.get("meta_t", None), + compressed_swizzled_bitmask=inner_tensors.get( + "compressed_swizzled_bitmask", None + ), + fuse_transpose_cusparselt=fuse_transpose_cusparselt, + alg_id_cusparselt=alg_id_cusparselt, + requires_grad=requires_grad, + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs) -> Any: + if func._overloadpacket not in cls.SPARSE_DISPATCH: + raise NotImplementedError( + f"{cls.__name__} only supports a specific set of operations, " + f"can't perform requested op ({func.__name__})" + ) + return cls.SPARSE_DISPATCH[func._overloadpacket](func, types, args, kwargs) + + @classmethod + def _load_dispatch_table(cls, custom_dispatch_table=None) -> None: + """ + Loads the op overload sparse dispatch table for the current class. 
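The comment in `__new__` above notes that users can overload this dispatch table for debugging or testing. A hedged sketch of how that could look, assuming the public exports shown at the top of this file; the `noisy_detach` wrapper is purely hypothetical and only logs before deferring to the stock fallback:

```
import torch
from torch.sparse import SparseSemiStructuredTensorCUTLASS
from torch.sparse._semi_structured_ops import fallback_dispatcher

def noisy_detach(func, types, args, kwargs):
    # Hypothetical handler: log the dispatched op, then use the default fallback behaviour.
    print(f"dispatching {func}")
    return fallback_dispatcher(func, types, args, kwargs)

# Entries passed here are merged into SPARSE_DISPATCH via .update(), overriding the defaults.
SparseSemiStructuredTensorCUTLASS._load_dispatch_table(
    custom_dispatch_table={torch.ops.aten.detach: noisy_detach}
)
```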
+ """ + if getattr(cls, "SPARSE_DISPATCH", None) is None: + cls.SPARSE_DISPATCH = { + torch.ops.aten.values: semi_sparse_values, + torch.ops.aten.indices: semi_sparse_indices, + torch.ops.aten.is_same_size: fallback_dispatcher, + torch.ops.aten.detach_: fallback_dispatcher, + torch.ops.aten.detach: semi_sparse_detach, + torch.ops.aten.t: semi_sparse_t, + torch.ops.aten.view: semi_sparse_view, + torch.ops.aten.mm: semi_sparse_mm, + torch.ops.aten.matmul: semi_sparse_mm, + torch.ops.aten.addmm: semi_sparse_addmm, + torch.ops.aten.linear: semi_sparse_linear, + torch.ops.aten._to_copy: fallback_dispatcher, + } + if custom_dispatch_table is not None: + cls.SPARSE_DISPATCH.update(custom_dispatch_table) + + @classmethod + def _validate_device_dim_dtype_shape(cls, original_tensor: torch.Tensor) -> None: + """ + Assert that the given tensor is valid for semi-structured sparse compression. + """ + # check device + if not original_tensor.is_cuda: + raise RuntimeError( + f"Error original_tensor.device= {original_tensor.device} is not supported! " + "Only CUDA tensors are currently supported." + ) + + # check dim + if original_tensor.dim() != 2: + raise RuntimeError( + f"Error original_tensor.dim = {original_tensor.dim()} is not supported! " + "Only 2d tensors are currently supported." + ) + + # check contiguous + if not original_tensor.is_contiguous(): + raise RuntimeError( + "Error original_tensor is not contiguous! " + "Only contiguous tensors are currently supported." + ) + + # check dtype + if original_tensor.dtype not in cls._DTYPE_SHAPE_CONSTRAINTS: + raise RuntimeError( + f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! " + f"dtype must be one of: {cls._DTYPE_SHAPE_CONSTRAINTS}" + ) + + # check shape + m, n = original_tensor.shape + min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_rows + min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_cols + if m < min_rows or m % min_rows or n < min_cols or n % min_cols: + # TODO in the future we can add in padding to support sparse dimensions that aren't perfect multiples + raise RuntimeError( + f"Error original_tensor.shape {original_tensor.shape} is not supported! " + f"Both dimensions must be greater than or equal to, and a multiple of, ({min_rows}, {min_cols})" + ) + + @classmethod + def _pad_dense_input(cls, dense_input: torch.Tensor) -> torch.Tensor: + """ + Calculates padding for the dense tensor and pads the tensor if necessary. + If padding is not required, this function returns the original tensor.
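The method body that follows rounds each dimension up to the next allowed multiple using Python's negative-modulo idiom. A minimal sketch with assumed example sizes:

```
# Illustrative values only: pad a 70-row input up to the next multiple of min_rows=8.
m, min_rows = 70, 8
to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0
print(to_pad_m)  # 2, so the padded input has 72 rows
```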
+ """ + # only 2d matmul + assert dense_input.dim() == 2 + + # check shape + m, n = dense_input.shape + min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_rows + min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_cols + + # calculate padding + to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0 + to_pad_n = -n % min_cols if n < min_cols or n % min_rows else 0 + if to_pad_m or to_pad_n: + return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m)) + else: + return dense_input + + def to_dense(self): + col = self.shape[-1] + return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device)) + + @classmethod + def from_dense(cls, original_tensor: torch.Tensor) -> "SparseSemiStructuredTensor": + raise NotImplementedError + + def _mm( + self, + B: torch.Tensor, + *, + bias: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + raise NotImplementedError + + +def to_sparse_semi_structured( + original_tensor: torch.Tensor, + transposed: bool = False, +) -> SparseSemiStructuredTensor: + """ + This function converts a dense tensor into a sparse semi-structured tensor. + It will return a SparseSemiStructuredTensor, a subclass of torch.Tensor. + + This function will check to ensure the dense tensor has the right dtype, size, dims, and device. + We currently only support semi-structured sparse tensors for 2d CUDA tensors. + Additionally, your tensor must be a positive multiple of the mininum sparse block size, given in + `_DTYPE_TO_SHAPE_CONSTRAINTS` for each dtype (float32, float16, bfloat16, int8). + + Args: + original_tensor (Tensor): the dense tensor to convert + transposed (bool, optional): deprecated arg to be removed in another release. Do not use. + Returns: + SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor + Raises: + None + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda() + tensor([[0., 0., 1., ..., 0., 1., 1.], + [0., 0., 1., ..., 0., 1., 1.], + [0., 0., 1., ..., 0., 1., 1.], + ..., + [0., 0., 1., ..., 0., 1., 1.], + [0., 0., 1., ..., 0., 1., 1.], + [0., 0., 1., ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16) + >>> A_sparse = to_sparse_semi_structured(A) + SparseSemiStructuredTensor(shape=torch.Size([128, 128])) + >>> A_sparse.values() + tensor([[1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.], + ..., + [1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16), + >>> A_sparse.indices() + tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370], + [-4370, -4370, -4370, ..., -4370, -4370, -4370], + [-4370, -4370, -4370, ..., -4370, -4370, -4370], + ..., + [-4370, -4370, -4370, ..., -4370, -4370, -4370], + [-4370, -4370, -4370, ..., -4370, -4370, -4370], + [-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0', dtype=torch.int16)) + """ + if transposed: + warnings.warn( + "Setting transpose from `to_sparse_semi_structured` is deprecated " + "and will be removed in a future release. 
" + "`SparseSemiStructuredTensor` only support contiguous input tensors.", + FutureWarning, + stacklevel=2, + ) + + # set from _FORCE_CUTLASS flag + SPARSE_SUBCLASS = ( + torch.sparse.SparseSemiStructuredTensorCUTLASS + if SparseSemiStructuredTensor._FORCE_CUTLASS + else torch.sparse.SparseSemiStructuredTensorCUSPARSELT + ) + + return SPARSE_SUBCLASS.from_dense(original_tensor) + + +class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor): + """ + This class implements semi-structured sparsity for the CUTLASS backend. + + + In this implementation, the specified elements and metadata are stored seprately, + in packed and meta respectively. + + When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_(mm|addmm) and + sparse_semi_structured_from_dense for conversion to the compressed format. + """ + + BACKEND = "cutlass" + _DTYPE_SHAPE_CONSTRAINTS = { + torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16), + torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8), + torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8), + torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 4, 4), + } + + @classmethod + def from_dense( + cls, original_tensor: torch.Tensor + ) -> "SparseSemiStructuredTensorCUTLASS": + cls._validate_device_dim_dtype_shape(original_tensor) + ( + sparse_tensor_cutlass, + meta_tensor_cutlass, + ) = sparse_semi_structured_from_dense_cutlass(original_tensor) + return cls( + original_tensor.shape, + packed=sparse_tensor_cutlass, + meta=meta_tensor_cutlass, + packed_t=None, + meta_t=None, + compressed_swizzled_bitmask=None, + requires_grad=original_tensor.requires_grad, + ) + + def to_dense(self): + assert self.meta is not None and self.packed is not None + return ( + sparse_semi_structured_to_dense_cutlass( + self.packed, + self.meta, + ) + if self.meta.ndim == 2 + else super().to_dense() + ) + + @classmethod + def prune_dense_static_sort( + cls, original_tensor: torch.Tensor, algorithm="" + ) -> "SparseSemiStructuredTensor": + """ + This function takes in a unpruned dense tensor and runs a (branchless) static sort across a 4x4 tile. + + It greedily picks the largest values in the tile, upholding the 2:4 sparsity constraint across both rows and columns. + The algorithm used to prune the matrix is implemented in `_sparse_semi_structured_tile`. + + Then it creates the packed and meta tensors for the compressed sparse representation of the pruned dense tensor. + It also calculates the packed_t and meta_t tensors for the compressed sparse representation of the transposed + pruned dense tensor. + Since we cannot transpose the compressed representations, we store both for the fw/bw pass respectively. + + Finally, this function also computes a compressed swizzled bitmask that encodes the sparsity pattern + This can be used in the backward pass to mask the gradients. 
+ + [9 1 7 4] [9 0 7 0] + [1 2 3 0] [0 2 0 0] + [8 3 5 4] -> prune 4x4 tile -> [8 0 0 4] -> pack to CUTLASS semi-structured -> packed + [1 2 6 2] [0 0 6 2] -> metadata + + -> pack to transposed CUTLASS -> packed_t + semi-structured representation -> metadata_t + + -> compute swizzled bitmask -> compressed_swizzled_bitmask + + + The equivalent PyTorch code to create the same five outputs from the dense tensor can be found below: + ``` + from torch.sparse import SparseSemiStructuredTensorCUTLASS + from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask + + pruned = _sparse_semi_structured_tile(dense) + packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned) + packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous()) + bitmask = _compute_compressed_swizzled_bitmask(pruned) + + SparseSemiStructuredTensorCUTLASS(dense.shape, packed_cutlass, meta_cutlass, packed_t_cutlass, meta_t_cutlass, bitmask) + ``` + """ + # We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag. + ( + packed, + meta, + packed_t, + meta_t, + compressed_swizzled_bitmask, + ) = torch._sparse_semi_structured_tile( + original_tensor, algorithm=algorithm, use_cutlass=True + ) + + return cls( + original_tensor.shape, + packed=packed, + meta=meta, + packed_t=packed_t, + meta_t=meta_t, + compressed_swizzled_bitmask=compressed_swizzled_bitmask, + requires_grad=False, + ) + + def _mm( + self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs + ) -> torch.Tensor: + if isinstance(B, SparseSemiStructuredTensor): + raise ValueError( + "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware" + ) + cls_name = self.__class__.__name__ + if self.ndim != 2 or B.ndim != 2: + raise NotImplementedError( + f"`{cls_name}` matmul: Broadcasting is not implemented" + ) + if self.packed is None or self.meta is None: + raise NotImplementedError( + f"`{cls_name}` matmul: operation is not supported" + ) + else: + if bias is None: + res = torch._sparse_semi_structured_mm(self.packed, self.meta, B) + else: + res = torch._sparse_semi_structured_addmm( + bias, self.packed, self.meta, B + ) + return res[: self.shape[0]] + + +class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor): + """ + The cuSPARSELt backend expects the specified elements and the metadata to be stored in a single tensor: + packed = [ specified elements of original tensor | metadata ] + For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements + The rest of the tensor is metadata. Since there is only one tensor, we only use the packed and packed_t + attributes respectively. + + cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well + as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes. 
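A small sketch of the packed layout just described, assuming a CUDA build with cuSPARSELt available; the shapes are illustrative, and per the docstring only the first m * k // 2 entries of the buffer are the kept values:

```
import torch

# Illustrative shapes; the repeated [0, 0, 1, 1] pattern gives a valid 2:4-sparse fp16 input.
m, k = 64, 128
dense = torch.tensor([0, 0, 1, 1], dtype=torch.float16, device="cuda").tile((m, k // 4))
packed = torch._cslt_compress(dense)  # single buffer: [specified elements | metadata]
print(m * k // 2, "kept values out of", packed.numel(), "packed entries")
```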
+ """ + + BACKEND = "cusparselt" + _DTYPE_SHAPE_CONSTRAINTS = { + torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16), + torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8), + torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8), + } + + @classmethod + def from_dense( + cls, original_tensor: torch.Tensor + ) -> "SparseSemiStructuredTensorCUSPARSELT": + cls._validate_device_dim_dtype_shape(original_tensor) + return cls( + shape=original_tensor.shape, + packed=torch._cslt_compress(original_tensor), + meta=None, + packed_t=None, + meta_t=None, + compressed_swizzled_bitmask=None, + fuse_transpose_cusparselt=SparseSemiStructuredTensor._FUSE_TRANSPOSE, + alg_id_cusparselt=SparseSemiStructuredTensor._DEFAULT_ALG_ID, + requires_grad=original_tensor.requires_grad, + ) + + @classmethod + def prune_dense_static_sort( + cls, original_tensor: torch.Tensor, algorithm="" + ) -> "SparseSemiStructuredTensor": + """ + This function does the same thing as described in SparseSemiStructuredTensorCUTLASS, but uses the cuSPARSELt metadata + layout and sparse matmul. + + The only functional difference is that cuSPARSELt stores `metadata` and `packed` together into a single tensor. + + [9 1 7 4] [9 0 7 0] + [1 2 3 0] [0 2 0 0] + [8 3 5 4] -> prune 4x4 tile -> [8 0 0 4] -> pack to cuSPARSELT semi-structured -> packed + [1 2 6 2] [0 0 6 2] + + -> pack to transposed cuSPARSELt -> packed_t + semi-structured representation + + -> compute swizzled bitmask -> compressed_swizzled_bitmask + + + The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below: + ``` + from torch.sparse import SparseSemiStructuredTensorCUSPARSELT + from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask + + pruned = _sparse_semi_structured_tile(dense) + packed_cusparselt = torch._cslt_compress(pruned) + packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous()) + bitmask = _compute_compressed_swizzled_bitmask(pruned) + + SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cusparselt, None, packed_t_cusparselt, None, bitmask) + ``` + """ + ( + packed, + meta, + packed_t, + meta_t, + compressed_swizzled_bitmask, + ) = torch._sparse_semi_structured_tile( + original_tensor, algorithm=algorithm, use_cutlass=False + ) + + return cls( + original_tensor.shape, + packed=packed, + meta=meta, + packed_t=packed_t, + meta_t=meta_t, + compressed_swizzled_bitmask=compressed_swizzled_bitmask, + requires_grad=False, + ) + + def _mm( + self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs + ) -> torch.Tensor: + if isinstance(B, SparseSemiStructuredTensor): + raise ValueError( + "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware" + ) + if self.ndim != 2 or B.ndim != 2: + raise NotImplementedError( + f"`{self.__class__.__name__}` matmul: Broadcasting is not implemented" + ) + if B.dtype != self.dtype: + raise NotImplementedError( + f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, " + f"with A.dtype={self.dtype} and B.dtype={B.dtype}. " + "This operation is only supported when A and B have the same data type." + ) + if bias is not None and bias.dtype != self.dtype: + raise NotImplementedError( + f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)} + C`, " + f"with A.dtype=B.dtype={self.dtype} and C.dtype={bias.dtype}. 
" + "This operation is only supported when A, B and C have the same data type." + ) + if self.packed is None: + raise NotImplementedError( + f"`{self.__class__.__name__}` matmul: operation is not supported" + ) + else: + res = torch._cslt_sparse_mm( + self.packed, + B, + bias=bias, + transpose_result=self.fuse_transpose_cusparselt, + alg_id=self.alg_id_cusparselt, + ) + return res.t() if self.fuse_transpose_cusparselt else res diff --git a/lib/python3.10/site-packages/torch/special/__init__.py b/lib/python3.10/site-packages/torch/special/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..07e104c4090ef8456f8acf5f2548bbbc3532a2d3 --- /dev/null +++ b/lib/python3.10/site-packages/torch/special/__init__.py @@ -0,0 +1,1283 @@ +import torch +from torch._C import _add_docstr, _special # type: ignore[attr-defined] +from torch._torch_docs import common_args, multi_dim_common + +__all__ = [ + 'airy_ai', + 'bessel_j0', + 'bessel_j1', + 'bessel_y0', + 'bessel_y1', + 'chebyshev_polynomial_t', + 'chebyshev_polynomial_u', + 'chebyshev_polynomial_v', + 'chebyshev_polynomial_w', + 'digamma', + 'entr', + 'erf', + 'erfc', + 'erfcx', + 'erfinv', + 'exp2', + 'expit', + 'expm1', + 'gammainc', + 'gammaincc', + 'gammaln', + 'hermite_polynomial_h', + 'hermite_polynomial_he', + 'i0', + 'i0e', + 'i1', + 'i1e', + 'laguerre_polynomial_l', + 'legendre_polynomial_p', + 'log1p', + 'log_ndtr', + 'log_softmax', + 'logit', + 'logsumexp', + 'modified_bessel_i0', + 'modified_bessel_i1', + 'modified_bessel_k0', + 'modified_bessel_k1', + 'multigammaln', + 'ndtr', + 'ndtri', + 'polygamma', + 'psi', + 'round', + 'shifted_chebyshev_polynomial_t', + 'shifted_chebyshev_polynomial_u', + 'shifted_chebyshev_polynomial_v', + 'shifted_chebyshev_polynomial_w', + 'scaled_modified_bessel_k0', + 'scaled_modified_bessel_k1', + 'sinc', + 'softmax', + 'spherical_bessel_j0', + 'xlog1py', + 'xlogy', + 'zeta', +] + +Tensor = torch.Tensor + +entr = _add_docstr(_special.special_entr, + r""" +entr(input, *, out=None) -> Tensor +Computes the entropy on :attr:`input` (as defined below), elementwise. + +.. math:: + \begin{align} + \text{entr(x)} = \begin{cases} + -x * \ln(x) & x > 0 \\ + 0 & x = 0.0 \\ + -\infty & x < 0 + \end{cases} + \end{align} +""" + """ + +Args: + input (Tensor): the input tensor. + +Keyword args: + out (Tensor, optional): the output tensor. + +Example:: + >>> a = torch.arange(-0.5, 1, 0.5) + >>> a + tensor([-0.5000, 0.0000, 0.5000]) + >>> torch.special.entr(a) + tensor([ -inf, 0.0000, 0.3466]) +""") + +psi = _add_docstr(_special.special_psi, + r""" +psi(input, *, out=None) -> Tensor + +Alias for :func:`torch.special.digamma`. +""") + +digamma = _add_docstr(_special.special_digamma, + r""" +digamma(input, *, out=None) -> Tensor + +Computes the logarithmic derivative of the gamma function on `input`. + +.. math:: + \digamma(x) = \frac{d}{dx} \ln\left(\Gamma\left(x\right)\right) = \frac{\Gamma'(x)}{\Gamma(x)} +""" + r""" +Args: + input (Tensor): the tensor to compute the digamma function on + +Keyword args: + {out} + +.. note:: This function is similar to SciPy's `scipy.special.digamma`. + +.. note:: From PyTorch 1.8 onwards, the digamma function returns `-Inf` for `0`. + Previously it returned `NaN` for `0`. 
+ +Example:: + + >>> a = torch.tensor([1, 0.5]) + >>> torch.special.digamma(a) + tensor([-0.5772, -1.9635]) + +""".format(**common_args)) + +gammaln = _add_docstr(_special.special_gammaln, + r""" +gammaln(input, *, out=None) -> Tensor + +Computes the natural logarithm of the absolute value of the gamma function on :attr:`input`. + +.. math:: + \text{out}_{i} = \ln \Gamma(|\text{input}_{i}|) +""" + """ +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.arange(0.5, 2, 0.5) + >>> torch.special.gammaln(a) + tensor([ 0.5724, 0.0000, -0.1208]) + +""".format(**common_args)) + +polygamma = _add_docstr(_special.special_polygamma, + r""" +polygamma(n, input, *, out=None) -> Tensor + +Computes the :math:`n^{th}` derivative of the digamma function on :attr:`input`. +:math:`n \geq 0` is called the order of the polygamma function. + +.. math:: + \psi^{(n)}(x) = \frac{d^{(n)}}{dx^{(n)}} \psi(x) + +.. note:: + This function is implemented only for nonnegative integers :math:`n \geq 0`. +""" + """ +Args: + n (int): the order of the polygamma function + {input} + +Keyword args: + {out} + +Example:: + >>> a = torch.tensor([1, 0.5]) + >>> torch.special.polygamma(1, a) + tensor([1.64493, 4.9348]) + >>> torch.special.polygamma(2, a) + tensor([ -2.4041, -16.8288]) + >>> torch.special.polygamma(3, a) + tensor([ 6.4939, 97.4091]) + >>> torch.special.polygamma(4, a) + tensor([ -24.8863, -771.4742]) +""".format(**common_args)) + +erf = _add_docstr(_special.special_erf, + r""" +erf(input, *, out=None) -> Tensor + +Computes the error function of :attr:`input`. The error function is defined as follows: + +.. math:: + \mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^2} dt +""" + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.special.erf(torch.tensor([0, -1., 10.])) + tensor([ 0.0000, -0.8427, 1.0000]) +""".format(**common_args)) + +erfc = _add_docstr(_special.special_erfc, + r""" +erfc(input, *, out=None) -> Tensor + +Computes the complementary error function of :attr:`input`. +The complementary error function is defined as follows: + +.. math:: + \mathrm{erfc}(x) = 1 - \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^2} dt +""" + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.special.erfc(torch.tensor([0, -1., 10.])) + tensor([ 1.0000, 1.8427, 0.0000]) +""".format(**common_args)) + +erfcx = _add_docstr(_special.special_erfcx, + r""" +erfcx(input, *, out=None) -> Tensor + +Computes the scaled complementary error function for each element of :attr:`input`. +The scaled complementary error function is defined as follows: + +.. math:: + \mathrm{erfcx}(x) = e^{x^2} \mathrm{erfc}(x) +""" + r""" + +""" + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.special.erfcx(torch.tensor([0, -1., 10.])) + tensor([ 1.0000, 5.0090, 0.0561]) +""".format(**common_args)) + +erfinv = _add_docstr(_special.special_erfinv, + r""" +erfinv(input, *, out=None) -> Tensor + +Computes the inverse error function of :attr:`input`. +The inverse error function is defined in the range :math:`(-1, 1)` as: + +.. math:: + \mathrm{erfinv}(\mathrm{erf}(x)) = x +""" + r""" + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.special.erfinv(torch.tensor([0, 0.5, -1.])) + tensor([ 0.0000, 0.4769, -inf]) +""".format(**common_args)) + +logit = _add_docstr(_special.special_logit, + r""" +logit(input, eps=None, *, out=None) -> Tensor + +Returns a new tensor with the logit of the elements of :attr:`input`. 
+:attr:`input` is clamped to [eps, 1 - eps] when eps is not None. +When eps is None and :attr:`input` < 0 or :attr:`input` > 1, the function will yields NaN. + +.. math:: + \begin{align} + y_{i} &= \ln(\frac{z_{i}}{1 - z_{i}}) \\ + z_{i} &= \begin{cases} + x_{i} & \text{if eps is None} \\ + \text{eps} & \text{if } x_{i} < \text{eps} \\ + x_{i} & \text{if } \text{eps} \leq x_{i} \leq 1 - \text{eps} \\ + 1 - \text{eps} & \text{if } x_{i} > 1 - \text{eps} + \end{cases} + \end{align} +""" + r""" +Args: + {input} + eps (float, optional): the epsilon for input clamp bound. Default: ``None`` + +Keyword args: + {out} + +Example:: + + >>> a = torch.rand(5) + >>> a + tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516]) + >>> torch.special.logit(a, eps=1e-6) + tensor([-0.9466, 2.6352, 0.6131, -1.7169, 0.6261]) +""".format(**common_args)) + +logsumexp = _add_docstr(_special.special_logsumexp, + r""" +logsumexp(input, dim, keepdim=False, *, out=None) + +Alias for :func:`torch.logsumexp`. +""".format(**multi_dim_common)) + +expit = _add_docstr(_special.special_expit, + r""" +expit(input, *, out=None) -> Tensor + +Computes the expit (also known as the logistic sigmoid function) of the elements of :attr:`input`. + +.. math:: + \text{out}_{i} = \frac{1}{1 + e^{-\text{input}_{i}}} +""" + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> t = torch.randn(4) + >>> t + tensor([ 0.9213, 1.0887, -0.8858, -1.7683]) + >>> torch.special.expit(t) + tensor([ 0.7153, 0.7481, 0.2920, 0.1458]) +""".format(**common_args)) + +exp2 = _add_docstr(_special.special_exp2, + r""" +exp2(input, *, out=None) -> Tensor + +Computes the base two exponential function of :attr:`input`. + +.. math:: + y_{i} = 2^{x_{i}} + +""" + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.special.exp2(torch.tensor([0, math.log2(2.), 3, 4])) + tensor([ 1., 2., 8., 16.]) +""".format(**common_args)) + +expm1 = _add_docstr(_special.special_expm1, + r""" +expm1(input, *, out=None) -> Tensor + +Computes the exponential of the elements minus 1 +of :attr:`input`. + +.. math:: + y_{i} = e^{x_{i}} - 1 + +.. note:: This function provides greater precision than exp(x) - 1 for small values of x. + +""" + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.special.expm1(torch.tensor([0, math.log(2.)])) + tensor([ 0., 1.]) +""".format(**common_args)) + +xlog1py = _add_docstr(_special.special_xlog1py, + r""" +xlog1py(input, other, *, out=None) -> Tensor + +Computes ``input * log1p(other)`` with the following cases. + +.. math:: + \text{out}_{i} = \begin{cases} + \text{NaN} & \text{if } \text{other}_{i} = \text{NaN} \\ + 0 & \text{if } \text{input}_{i} = 0.0 \text{ and } \text{other}_{i} != \text{NaN} \\ + \text{input}_{i} * \text{log1p}(\text{other}_{i})& \text{otherwise} + \end{cases} + +Similar to SciPy's `scipy.special.xlog1py`. + +""" + r""" + +Args: + input (Number or Tensor) : Multiplier + other (Number or Tensor) : Argument + +.. note:: At least one of :attr:`input` or :attr:`other` must be a tensor. 
+ +Keyword args: + {out} + +Example:: + + >>> x = torch.zeros(5,) + >>> y = torch.tensor([-1, 0, 1, float('inf'), float('nan')]) + >>> torch.special.xlog1py(x, y) + tensor([0., 0., 0., 0., nan]) + >>> x = torch.tensor([1, 2, 3]) + >>> y = torch.tensor([3, 2, 1]) + >>> torch.special.xlog1py(x, y) + tensor([1.3863, 2.1972, 2.0794]) + >>> torch.special.xlog1py(x, 4) + tensor([1.6094, 3.2189, 4.8283]) + >>> torch.special.xlog1py(2, y) + tensor([2.7726, 2.1972, 1.3863]) +""".format(**common_args)) + +xlogy = _add_docstr(_special.special_xlogy, + r""" +xlogy(input, other, *, out=None) -> Tensor + +Computes ``input * log(other)`` with the following cases. + +.. math:: + \text{out}_{i} = \begin{cases} + \text{NaN} & \text{if } \text{other}_{i} = \text{NaN} \\ + 0 & \text{if } \text{input}_{i} = 0.0 \\ + \text{input}_{i} * \log{(\text{other}_{i})} & \text{otherwise} + \end{cases} + +Similar to SciPy's `scipy.special.xlogy`. + +""" + r""" + +Args: + input (Number or Tensor) : Multiplier + other (Number or Tensor) : Argument + +.. note:: At least one of :attr:`input` or :attr:`other` must be a tensor. + +Keyword args: + {out} + +Example:: + + >>> x = torch.zeros(5,) + >>> y = torch.tensor([-1, 0, 1, float('inf'), float('nan')]) + >>> torch.special.xlogy(x, y) + tensor([0., 0., 0., 0., nan]) + >>> x = torch.tensor([1, 2, 3]) + >>> y = torch.tensor([3, 2, 1]) + >>> torch.special.xlogy(x, y) + tensor([1.0986, 1.3863, 0.0000]) + >>> torch.special.xlogy(x, 4) + tensor([1.3863, 2.7726, 4.1589]) + >>> torch.special.xlogy(2, y) + tensor([2.1972, 1.3863, 0.0000]) +""".format(**common_args)) + +i0 = _add_docstr(_special.special_i0, + r""" +i0(input, *, out=None) -> Tensor + +Computes the zeroth order modified Bessel function of the first kind for each element of :attr:`input`. + +.. math:: + \text{out}_{i} = I_0(\text{input}_{i}) = \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!)^2} + +""" + r""" +Args: + input (Tensor): the input tensor + +Keyword args: + {out} + +Example:: + + >>> torch.i0(torch.arange(5, dtype=torch.float32)) + tensor([ 1.0000, 1.2661, 2.2796, 4.8808, 11.3019]) + +""".format(**common_args)) + +i0e = _add_docstr(_special.special_i0e, + r""" +i0e(input, *, out=None) -> Tensor +Computes the exponentially scaled zeroth order modified Bessel function of the first kind (as defined below) +for each element of :attr:`input`. + +.. math:: + \text{out}_{i} = \exp(-|x|) * i0(x) = \exp(-|x|) * \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!)^2} + +""" + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + >>> torch.special.i0e(torch.arange(5, dtype=torch.float32)) + tensor([1.0000, 0.4658, 0.3085, 0.2430, 0.2070]) +""".format(**common_args)) + +i1 = _add_docstr(_special.special_i1, + r""" +i1(input, *, out=None) -> Tensor +Computes the first order modified Bessel function of the first kind (as defined below) +for each element of :attr:`input`. + +.. math:: + \text{out}_{i} = \frac{(\text{input}_{i})}{2} * \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!) * (k+1)!} + +""" + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + >>> torch.special.i1(torch.arange(5, dtype=torch.float32)) + tensor([0.0000, 0.5652, 1.5906, 3.9534, 9.7595]) +""".format(**common_args)) + +i1e = _add_docstr(_special.special_i1e, + r""" +i1e(input, *, out=None) -> Tensor +Computes the exponentially scaled first order modified Bessel function of the first kind (as defined below) +for each element of :attr:`input`. + +.. 
math:: + \text{out}_{i} = \exp(-|x|) * i1(x) = + \exp(-|x|) * \frac{(\text{input}_{i})}{2} * \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!) * (k+1)!} + +""" + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + >>> torch.special.i1e(torch.arange(5, dtype=torch.float32)) + tensor([0.0000, 0.2079, 0.2153, 0.1968, 0.1788]) +""".format(**common_args)) + +ndtr = _add_docstr(_special.special_ndtr, + r""" +ndtr(input, *, out=None) -> Tensor +Computes the area under the standard Gaussian probability density function, +integrated from minus infinity to :attr:`input`, elementwise. + +.. math:: + \text{ndtr}(x) = \frac{1}{\sqrt{2 \pi}}\int_{-\infty}^{x} e^{-\frac{1}{2}t^2} dt + +""" + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + >>> torch.special.ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3])) + tensor([0.0013, 0.0228, 0.1587, 0.5000, 0.8413, 0.9772, 0.9987]) +""".format(**common_args)) + +ndtri = _add_docstr(_special.special_ndtri, + r""" +ndtri(input, *, out=None) -> Tensor +Computes the argument, x, for which the area under the Gaussian probability density function +(integrated from minus infinity to x) is equal to :attr:`input`, elementwise. + +.. math:: + \text{ndtri}(p) = \sqrt{2}\text{erf}^{-1}(2p - 1) + +.. note:: + Also known as quantile function for Normal Distribution. + +""" + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + >>> torch.special.ndtri(torch.tensor([0, 0.25, 0.5, 0.75, 1])) + tensor([ -inf, -0.6745, 0.0000, 0.6745, inf]) +""".format(**common_args)) + +log_ndtr = _add_docstr(_special.special_log_ndtr, + r""" +log_ndtr(input, *, out=None) -> Tensor +Computes the log of the area under the standard Gaussian probability density function, +integrated from minus infinity to :attr:`input`, elementwise. + +.. math:: + \text{log\_ndtr}(x) = \log\left(\frac{1}{\sqrt{2 \pi}}\int_{-\infty}^{x} e^{-\frac{1}{2}t^2} dt \right) + +""" + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + >>> torch.special.log_ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3])) + tensor([-6.6077 -3.7832 -1.841 -0.6931 -0.1728 -0.023 -0.0014]) +""".format(**common_args)) + +log1p = _add_docstr(_special.special_log1p, + r""" +log1p(input, *, out=None) -> Tensor + +Alias for :func:`torch.log1p`. +""") + +sinc = _add_docstr(_special.special_sinc, + r""" +sinc(input, *, out=None) -> Tensor + +Computes the normalized sinc of :attr:`input.` + +.. math:: + \text{out}_{i} = + \begin{cases} + 1, & \text{if}\ \text{input}_{i}=0 \\ + \sin(\pi \text{input}_{i}) / (\pi \text{input}_{i}), & \text{otherwise} + \end{cases} +""" + r""" + +Args: + {input} + +Keyword args: + {out} + +Example:: + >>> t = torch.randn(4) + >>> t + tensor([ 0.2252, -0.2948, 1.0267, -1.1566]) + >>> torch.special.sinc(t) + tensor([ 0.9186, 0.8631, -0.0259, -0.1300]) +""".format(**common_args)) + +round = _add_docstr(_special.special_round, + r""" +round(input, *, out=None) -> Tensor + +Alias for :func:`torch.round`. +""") + +softmax = _add_docstr(_special.special_softmax, + r""" +softmax(input, dim, *, dtype=None) -> Tensor + +Computes the softmax function. + +Softmax is defined as: + +:math:`\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}` + +It is applied to all slices along dim, and will re-scale them so that the elements +lie in the range `[0, 1]` and sum to 1. + +Args: + input (Tensor): input + dim (int): A dimension along which softmax will be computed. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. 
+ If specified, the input tensor is cast to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + +Examples:: + >>> t = torch.ones(2, 2) + >>> torch.special.softmax(t, 0) + tensor([[0.5000, 0.5000], + [0.5000, 0.5000]]) + +""") + +log_softmax = _add_docstr(_special.special_log_softmax, + r""" +log_softmax(input, dim, *, dtype=None) -> Tensor + +Computes softmax followed by a logarithm. + +While mathematically equivalent to log(softmax(x)), doing these two +operations separately is slower and numerically unstable. This function +is computed as: + +.. math:: + \text{log\_softmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right) +""" + r""" + +Args: + input (Tensor): input + dim (int): A dimension along which log_softmax will be computed. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is cast to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + +Example:: + >>> t = torch.ones(2, 2) + >>> torch.special.log_softmax(t, 0) + tensor([[-0.6931, -0.6931], + [-0.6931, -0.6931]]) +""") + +zeta = _add_docstr(_special.special_zeta, + r""" +zeta(input, other, *, out=None) -> Tensor + +Computes the Hurwitz zeta function, elementwise. + +.. math:: + \zeta(x, q) = \sum_{k=0}^{\infty} \frac{1}{(k + q)^x} + +""" + r""" +Args: + input (Tensor): the input tensor corresponding to `x`. + other (Tensor): the input tensor corresponding to `q`. + +.. note:: + The Riemann zeta function corresponds to the case when `q = 1` + +Keyword args: + {out} + +Example:: + >>> x = torch.tensor([2., 4.]) + >>> torch.special.zeta(x, 1) + tensor([1.6449, 1.0823]) + >>> torch.special.zeta(x, torch.tensor([1., 2.])) + tensor([1.6449, 0.0823]) + >>> torch.special.zeta(2, torch.tensor([1., 2.])) + tensor([1.6449, 0.6449]) +""".format(**common_args)) + +multigammaln = _add_docstr(_special.special_multigammaln, + r""" +multigammaln(input, p, *, out=None) -> Tensor + +Computes the `multivariate log-gamma function +`_ with dimension +:math:`p` element-wise, given by + +.. math:: + \log(\Gamma_{p}(a)) = C + \displaystyle \sum_{i=1}^{p} \log\left(\Gamma\left(a - \frac{i - 1}{2}\right)\right) + +where :math:`C = \log(\pi) \cdot \frac{p (p - 1)}{4}` and :math:`\Gamma(-)` is the Gamma function. + +All elements must be greater than :math:`\frac{p - 1}{2}`, otherwise the behavior is undefiend. +""" + """ + +Args: + input (Tensor): the tensor to compute the multivariate log-gamma function + p (int): the number of dimensions + +Keyword args: + {out} + +Example:: + + >>> a = torch.empty(2, 3).uniform_(1, 2) + >>> a + tensor([[1.6835, 1.8474, 1.1929], + [1.0475, 1.7162, 1.4180]]) + >>> torch.special.multigammaln(a, 2) + tensor([[0.3928, 0.4007, 0.7586], + [1.0311, 0.3901, 0.5049]]) +""".format(**common_args)) + +gammainc = _add_docstr(_special.special_gammainc, + r""" +gammainc(input, other, *, out=None) -> Tensor + +Computes the regularized lower incomplete gamma function: + +.. math:: + \text{out}_{i} = \frac{1}{\Gamma(\text{input}_i)} \int_0^{\text{other}_i} t^{\text{input}_i-1} e^{-t} dt + +where both :math:`\text{input}_i` and :math:`\text{other}_i` are weakly positive +and at least one is strictly positive. +If both are zero or either is negative then :math:`\text{out}_i=\text{nan}`. +:math:`\Gamma(\cdot)` in the equation above is the gamma function, + +.. 
math:: + \Gamma(\text{input}_i) = \int_0^\infty t^{(\text{input}_i-1)} e^{-t} dt. + +See :func:`torch.special.gammaincc` and :func:`torch.special.gammaln` for related functions. + +Supports :ref:`broadcasting to a common shape ` +and float inputs. + +.. note:: + The backward pass with respect to :attr:`input` is not yet supported. + Please open an issue on PyTorch's GitHub to request it. + +""" + r""" +Args: + input (Tensor): the first non-negative input tensor + other (Tensor): the second non-negative input tensor + +Keyword args: + {out} + +Example:: + + >>> a1 = torch.tensor([4.0]) + >>> a2 = torch.tensor([3.0, 4.0, 5.0]) + >>> a = torch.special.gammainc(a1, a2) + tensor([0.3528, 0.5665, 0.7350]) + >>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2) + tensor([1., 1., 1.]) + +""".format(**common_args)) + +gammaincc = _add_docstr(_special.special_gammaincc, + r""" +gammaincc(input, other, *, out=None) -> Tensor + +Computes the regularized upper incomplete gamma function: + +.. math:: + \text{out}_{i} = \frac{1}{\Gamma(\text{input}_i)} \int_{\text{other}_i}^{\infty} t^{\text{input}_i-1} e^{-t} dt + +where both :math:`\text{input}_i` and :math:`\text{other}_i` are weakly positive +and at least one is strictly positive. +If both are zero or either is negative then :math:`\text{out}_i=\text{nan}`. +:math:`\Gamma(\cdot)` in the equation above is the gamma function, + +.. math:: + \Gamma(\text{input}_i) = \int_0^\infty t^{(\text{input}_i-1)} e^{-t} dt. + +See :func:`torch.special.gammainc` and :func:`torch.special.gammaln` for related functions. + +Supports :ref:`broadcasting to a common shape ` +and float inputs. + +.. note:: + The backward pass with respect to :attr:`input` is not yet supported. + Please open an issue on PyTorch's GitHub to request it. + +""" + r""" +Args: + input (Tensor): the first non-negative input tensor + other (Tensor): the second non-negative input tensor + +Keyword args: + {out} + +Example:: + + >>> a1 = torch.tensor([4.0]) + >>> a2 = torch.tensor([3.0, 4.0, 5.0]) + >>> a = torch.special.gammaincc(a1, a2) + tensor([0.6472, 0.4335, 0.2650]) + >>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2) + tensor([1., 1., 1.]) + +""".format(**common_args)) + +airy_ai = _add_docstr(_special.special_airy_ai, + r""" +airy_ai(input, *, out=None) -> Tensor + +Airy function :math:`\text{Ai}\left(\text{input}\right)`. + +""" + r""" +Args: + {input} + +Keyword args: + {out} +""".format(**common_args)) + +bessel_j0 = _add_docstr(_special.special_bessel_j0, + r""" +bessel_j0(input, *, out=None) -> Tensor + +Bessel function of the first kind of order :math:`0`. + +""" + r""" +Args: + {input} + +Keyword args: + {out} +""".format(**common_args)) + +bessel_j1 = _add_docstr(_special.special_bessel_j1, + r""" +bessel_j1(input, *, out=None) -> Tensor + +Bessel function of the first kind of order :math:`1`. + +""" + r""" +Args: + {input} + +Keyword args: + {out} +""".format(**common_args)) + +bessel_y0 = _add_docstr(_special.special_bessel_y0, + r""" +bessel_y0(input, *, out=None) -> Tensor + +Bessel function of the second kind of order :math:`0`. + +""" + r""" +Args: + {input} + +Keyword args: + {out} +""".format(**common_args)) + +bessel_y1 = _add_docstr(_special.special_bessel_y1, + r""" +bessel_y1(input, *, out=None) -> Tensor + +Bessel function of the second kind of order :math:`1`.
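+
+A doctest-style sketch (outputs omitted; for positive real inputs the values can be
+cross-checked against ``scipy.special.y1`` when SciPy is available)::
+
+    >>> import torch
+    >>> x = torch.linspace(0.5, 10.0, 5)
+    >>> torch.special.bessel_y1(x)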
+ +""" + r""" +Args: + {input} + +Keyword args: + {out} +""".format(**common_args)) + +chebyshev_polynomial_t = _add_docstr(_special.special_chebyshev_polynomial_t, + r""" +chebyshev_polynomial_t(input, n, *, out=None) -> Tensor + +Chebyshev polynomial of the first kind :math:`T_{n}(\text{input})`. + +If :math:`n = 0`, :math:`1` is returned. If :math:`n = 1`, :math:`\text{input}` +is returned. If :math:`n < 6` or :math:`|\text{input}| > 1` the recursion: + +.. math:: + T_{n + 1}(\text{input}) = 2 \times \text{input} \times T_{n}(\text{input}) - T_{n - 1}(\text{input}) + +is evaluated. Otherwise, the explicit trigonometric formula: + +.. math:: + T_{n}(\text{input}) = \text{cos}(n \times \text{arccos}(x)) + +is evaluated. + +""" + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. + +Keyword args: + {out} +""".format(**common_args)) + +chebyshev_polynomial_u = _add_docstr(_special.special_chebyshev_polynomial_u, + r""" +chebyshev_polynomial_t(input, n, *, out=None) -> Tensor + +Chebyshev polynomial of the second kind :math:`U_{n}(\text{input})`. + +If :math:`n = 0`, :math:`1` is returned. If :math:`n = 1`, +:math:`2 \times \text{input}` is returned. If :math:`n < 6` or +:math:`|\text{input}| > 1`, the recursion: + +.. math:: + T_{n + 1}(\text{input}) = 2 \times \text{input} \times T_{n}(\text{input}) - T_{n - 1}(\text{input}) + +is evaluated. Otherwise, the explicit trigonometric formula: + +.. math:: + \frac{\text{sin}((n + 1) \times \text{arccos}(\text{input}))}{\text{sin}(\text{arccos}(\text{input}))} + +is evaluated. + +""" + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. + +Keyword args: + {out} +""".format(**common_args)) + +chebyshev_polynomial_v = _add_docstr(_special.special_chebyshev_polynomial_v, + r""" +chebyshev_polynomial_v(input, n, *, out=None) -> Tensor + +Chebyshev polynomial of the third kind :math:`V_{n}^{\ast}(\text{input})`. + +""" + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. + +Keyword args: + {out} +""".format(**common_args)) + +chebyshev_polynomial_w = _add_docstr(_special.special_chebyshev_polynomial_w, + r""" +chebyshev_polynomial_w(input, n, *, out=None) -> Tensor + +Chebyshev polynomial of the fourth kind :math:`W_{n}^{\ast}(\text{input})`. + +""" + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. + +Keyword args: + {out} +""".format(**common_args)) + +hermite_polynomial_h = _add_docstr(_special.special_hermite_polynomial_h, + r""" +hermite_polynomial_h(input, n, *, out=None) -> Tensor + +Physicist's Hermite polynomial :math:`H_{n}(\text{input})`. + +If :math:`n = 0`, :math:`1` is returned. If :math:`n = 1`, :math:`\text{input}` +is returned. Otherwise, the recursion: + +.. math:: + H_{n + 1}(\text{input}) = 2 \times \text{input} \times H_{n}(\text{input}) - H_{n - 1}(\text{input}) + +is evaluated. + +""" + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. + +Keyword args: + {out} +""".format(**common_args)) + +hermite_polynomial_he = _add_docstr(_special.special_hermite_polynomial_he, + r""" +hermite_polynomial_he(input, n, *, out=None) -> Tensor + +Probabilist's Hermite polynomial :math:`He_{n}(\text{input})`. + +If :math:`n = 0`, :math:`1` is returned. If :math:`n = 1`, :math:`\text{input}` +is returned. Otherwise, the recursion: + +.. math:: + He_{n + 1}(\text{input}) = 2 \times \text{input} \times He_{n}(\text{input}) - He_{n - 1}(\text{input}) + +is evaluated. + +""" + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. 
+ +Keyword args: + {out} +""".format(**common_args)) + +laguerre_polynomial_l = _add_docstr(_special.special_laguerre_polynomial_l, + r""" +laguerre_polynomial_l(input, n, *, out=None) -> Tensor + +Laguerre polynomial :math:`L_{n}(\text{input})`. + +If :math:`n = 0`, :math:`1` is returned. If :math:`n = 1`, :math:`\text{input}` +is returned. Otherwise, the recursion: + +.. math:: + L_{n + 1}(\text{input}) = 2 \times \text{input} \times L_{n}(\text{input}) - L_{n - 1}(\text{input}) + +is evaluated. + +""" + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. + +Keyword args: + {out} +""".format(**common_args)) + +legendre_polynomial_p = _add_docstr(_special.special_legendre_polynomial_p, + r""" +legendre_polynomial_p(input, n, *, out=None) -> Tensor + +Legendre polynomial :math:`P_{n}(\text{input})`. + +If :math:`n = 0`, :math:`1` is returned. If :math:`n = 1`, :math:`\text{input}` +is returned. Otherwise, the recursion: + +.. math:: + P_{n + 1}(\text{input}) = 2 \times \text{input} \times P_{n}(\text{input}) - P_{n - 1}(\text{input}) + +is evaluated. + +""" + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. + +Keyword args: + {out} +""".format(**common_args)) + +modified_bessel_i0 = _add_docstr(_special.special_modified_bessel_i0, + r""" +modified_bessel_i0(input, *, out=None) -> Tensor + +Modified Bessel function of the first kind of order :math:`0`. + +""" + r""" +Args: + {input} + +Keyword args: + {out} +""".format(**common_args)) + +modified_bessel_i1 = _add_docstr(_special.special_modified_bessel_i1, + r""" +modified_bessel_i1(input, *, out=None) -> Tensor + +Modified Bessel function of the first kind of order :math:`1`. + +""" + r""" +Args: + {input} + +Keyword args: + {out} +""".format(**common_args)) + +modified_bessel_k0 = _add_docstr(_special.special_modified_bessel_k0, + r""" +modified_bessel_k0(input, *, out=None) -> Tensor + +Modified Bessel function of the second kind of order :math:`0`. + +""" + r""" +Args: + {input} + +Keyword args: + {out} +""".format(**common_args)) + +modified_bessel_k1 = _add_docstr(_special.special_modified_bessel_k1, + r""" +modified_bessel_k1(input, *, out=None) -> Tensor + +Modified Bessel function of the second kind of order :math:`1`. + +""" + r""" +Args: + {input} + +Keyword args: + {out} +""".format(**common_args)) + +scaled_modified_bessel_k0 = _add_docstr(_special.special_scaled_modified_bessel_k0, + r""" +scaled_modified_bessel_k0(input, *, out=None) -> Tensor + +Scaled modified Bessel function of the second kind of order :math:`0`. + +""" + r""" +Args: + {input} + +Keyword args: + {out} +""".format(**common_args)) + +scaled_modified_bessel_k1 = _add_docstr(_special.special_scaled_modified_bessel_k1, + r""" +scaled_modified_bessel_k1(input, *, out=None) -> Tensor + +Scaled modified Bessel function of the second kind of order :math:`1`. + +""" + r""" +Args: + {input} + +Keyword args: + {out} +""".format(**common_args)) + +shifted_chebyshev_polynomial_t = _add_docstr(_special.special_shifted_chebyshev_polynomial_t, + r""" +shifted_chebyshev_polynomial_t(input, n, *, out=None) -> Tensor + +Chebyshev polynomial of the first kind :math:`T_{n}^{\ast}(\text{input})`. + +""" + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. 
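+
+A minimal sketch of the usual relation to the unshifted polynomial,
+:math:`T_{n}^{\ast}(x) = T_{n}(2x - 1)`; this is the standard definition of the
+shifted polynomial and is assumed here to match the implementation (outputs omitted)::
+
+    >>> import torch
+    >>> x = torch.linspace(0, 1, 5)
+    >>> n = torch.tensor(3)
+    >>> torch.special.shifted_chebyshev_polynomial_t(x, n)
+    >>> torch.special.chebyshev_polynomial_t(2 * x - 1, n)  # expected to agree elementwise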
+ +Keyword args: + {out} +""".format(**common_args)) + +shifted_chebyshev_polynomial_u = _add_docstr(_special.special_shifted_chebyshev_polynomial_u, + r""" +shifted_chebyshev_polynomial_u(input, n, *, out=None) -> Tensor + +Chebyshev polynomial of the second kind :math:`U_{n}^{\ast}(\text{input})`. + +""" + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. + +Keyword args: + {out} +""".format(**common_args)) + +shifted_chebyshev_polynomial_v = _add_docstr(_special.special_shifted_chebyshev_polynomial_v, + r""" +shifted_chebyshev_polynomial_v(input, n, *, out=None) -> Tensor + +Chebyshev polynomial of the third kind :math:`V_{n}^{\ast}(\text{input})`. + +""" + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. + +Keyword args: + {out} +""".format(**common_args)) + +shifted_chebyshev_polynomial_w = _add_docstr(_special.special_shifted_chebyshev_polynomial_w, + r""" +shifted_chebyshev_polynomial_w(input, n, *, out=None) -> Tensor + +Chebyshev polynomial of the fourth kind :math:`W_{n}^{\ast}(\text{input})`. + +""" + r""" +Args: + {input} + n (Tensor): Degree of the polynomial. + +Keyword args: + {out} +""".format(**common_args)) + +spherical_bessel_j0 = _add_docstr(_special.special_spherical_bessel_j0, + r""" +spherical_bessel_j0(input, *, out=None) -> Tensor + +Spherical Bessel function of the first kind of order :math:`0`. + +""" + r""" +Args: + {input} + +Keyword args: + {out} +""".format(**common_args)) diff --git a/lib/python3.10/site-packages/torch/testing/_comparison.py b/lib/python3.10/site-packages/torch/testing/_comparison.py new file mode 100644 index 0000000000000000000000000000000000000000..668d6bff0d378ee5e3bb4b638b700a3801c17a3f --- /dev/null +++ b/lib/python3.10/site-packages/torch/testing/_comparison.py @@ -0,0 +1,1580 @@ +# mypy: allow-untyped-defs +import abc +import cmath +import collections.abc +import contextlib +from typing import ( + Any, + Callable, + Collection, + Dict, + List, + NoReturn, + Optional, + Sequence, + Tuple, + Type, + Union, +) +from typing_extensions import deprecated + +import torch + + +try: + import numpy as np + + HAS_NUMPY = True +except ModuleNotFoundError: + HAS_NUMPY = False + np = None # type: ignore[assignment] + + +class ErrorMeta(Exception): + """Internal testing exception that makes that carries error metadata.""" + + def __init__( + self, type: Type[Exception], msg: str, *, id: Tuple[Any, ...] = () + ) -> None: + super().__init__( + "If you are a user and see this message during normal operation " + "please file an issue at https://github.com/pytorch/pytorch/issues. " + "If you are a developer and working on the comparison functions, please `raise ErrorMeta.to_error()` " + "for user facing errors." + ) + self.type = type + self.msg = msg + self.id = id + + def to_error( + self, msg: Optional[Union[str, Callable[[str], str]]] = None + ) -> Exception: + if not isinstance(msg, str): + generated_msg = self.msg + if self.id: + generated_msg += f"\n\nThe failure occurred for item {''.join(str([item]) for item in self.id)}" + + msg = msg(generated_msg) if callable(msg) else generated_msg + + return self.type(msg) + + +# Some analysis of tolerance by logging tests from test_torch.py can be found in +# https://github.com/pytorch/pytorch/pull/32538. 
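+# As a worked example of how these pairs are used downstream (see `assert_close`):
+# with the float32 defaults rtol=1.3e-6 and atol=1e-5, actual=1.000011 and
+# expected=1.0 are considered close because
+# |1.000011 - 1.0| = 1.1e-5 <= atol + rtol * |expected| = 1e-5 + 1.3e-6 = 1.13e-5.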
+# {dtype: (rtol, atol)} +_DTYPE_PRECISIONS = { + torch.float16: (0.001, 1e-5), + torch.bfloat16: (0.016, 1e-5), + torch.float32: (1.3e-6, 1e-5), + torch.float64: (1e-7, 1e-7), + torch.complex32: (0.001, 1e-5), + torch.complex64: (1.3e-6, 1e-5), + torch.complex128: (1e-7, 1e-7), +} +# The default tolerances of torch.float32 are used for quantized dtypes, because quantized tensors are compared in +# their dequantized and floating point representation. For more details see `TensorLikePair._compare_quantized_values` +_DTYPE_PRECISIONS.update( + dict.fromkeys( + (torch.quint8, torch.quint2x4, torch.quint4x2, torch.qint8, torch.qint32), + _DTYPE_PRECISIONS[torch.float32], + ) +) + + +def default_tolerances( + *inputs: Union[torch.Tensor, torch.dtype], + dtype_precisions: Optional[Dict[torch.dtype, Tuple[float, float]]] = None, +) -> Tuple[float, float]: + """Returns the default absolute and relative testing tolerances for a set of inputs based on the dtype. + + See :func:`assert_close` for a table of the default tolerance for each dtype. + + Returns: + (Tuple[float, float]): Loosest tolerances of all input dtypes. + """ + dtypes = [] + for input in inputs: + if isinstance(input, torch.Tensor): + dtypes.append(input.dtype) + elif isinstance(input, torch.dtype): + dtypes.append(input) + else: + raise TypeError( + f"Expected a torch.Tensor or a torch.dtype, but got {type(input)} instead." + ) + dtype_precisions = dtype_precisions or _DTYPE_PRECISIONS + rtols, atols = zip(*[dtype_precisions.get(dtype, (0.0, 0.0)) for dtype in dtypes]) + return max(rtols), max(atols) + + +def get_tolerances( + *inputs: Union[torch.Tensor, torch.dtype], + rtol: Optional[float], + atol: Optional[float], + id: Tuple[Any, ...] = (), +) -> Tuple[float, float]: + """Gets absolute and relative to be used for numeric comparisons. + + If both ``rtol`` and ``atol`` are specified, this is a no-op. If both are not specified, the return value of + :func:`default_tolerances` is used. + + Raises: + ErrorMeta: With :class:`ValueError`, if only ``rtol`` or ``atol`` is specified. + + Returns: + (Tuple[float, float]): Valid absolute and relative tolerances. + """ + if (rtol is None) ^ (atol is None): + # We require both tolerance to be omitted or specified, because specifying only one might lead to surprising + # results. Imagine setting atol=0.0 and the tensors still match because rtol>0.0. + raise ErrorMeta( + ValueError, + f"Both 'rtol' and 'atol' must be either specified or omitted, " + f"but got no {'rtol' if rtol is None else 'atol'}.", + id=id, + ) + elif rtol is not None and atol is not None: + return rtol, atol + else: + return default_tolerances(*inputs) + + +def _make_mismatch_msg( + *, + default_identifier: str, + identifier: Optional[Union[str, Callable[[str], str]]] = None, + extra: Optional[str] = None, + abs_diff: float, + abs_diff_idx: Optional[Union[int, Tuple[int, ...]]] = None, + atol: float, + rel_diff: float, + rel_diff_idx: Optional[Union[int, Tuple[int, ...]]] = None, + rtol: float, +) -> str: + """Makes a mismatch error message for numeric values. + + Args: + default_identifier (str): Default description of the compared values, e.g. "Tensor-likes". + identifier (Optional[Union[str, Callable[[str], str]]]): Optional identifier that overrides + ``default_identifier``. Can be passed as callable in which case it will be called with + ``default_identifier`` to create the description at runtime. + extra (Optional[str]): Extra information to be placed after the message header and the mismatch statistics. 
+ abs_diff (float): Absolute difference. + abs_diff_idx (Optional[Union[int, Tuple[int, ...]]]): Optional index of the absolute difference. + atol (float): Allowed absolute tolerance. Will only be added to mismatch statistics if it or ``rtol`` are + ``> 0``. + rel_diff (float): Relative difference. + rel_diff_idx (Optional[Union[int, Tuple[int, ...]]]): Optional index of the relative difference. + rtol (float): Allowed relative tolerance. Will only be added to mismatch statistics if it or ``atol`` are + ``> 0``. + """ + equality = rtol == 0 and atol == 0 + + def make_diff_msg( + *, + type: str, + diff: float, + idx: Optional[Union[int, Tuple[int, ...]]], + tol: float, + ) -> str: + if idx is None: + msg = f"{type.title()} difference: {diff}" + else: + msg = f"Greatest {type} difference: {diff} at index {idx}" + if not equality: + msg += f" (up to {tol} allowed)" + return msg + "\n" + + if identifier is None: + identifier = default_identifier + elif callable(identifier): + identifier = identifier(default_identifier) + + msg = f"{identifier} are not {'equal' if equality else 'close'}!\n\n" + + if extra: + msg += f"{extra.strip()}\n" + + msg += make_diff_msg(type="absolute", diff=abs_diff, idx=abs_diff_idx, tol=atol) + msg += make_diff_msg(type="relative", diff=rel_diff, idx=rel_diff_idx, tol=rtol) + + return msg.strip() + + +def make_scalar_mismatch_msg( + actual: Union[bool, int, float, complex], + expected: Union[bool, int, float, complex], + *, + rtol: float, + atol: float, + identifier: Optional[Union[str, Callable[[str], str]]] = None, +) -> str: + """Makes a mismatch error message for scalars. + + Args: + actual (Union[bool, int, float, complex]): Actual scalar. + expected (Union[bool, int, float, complex]): Expected scalar. + rtol (float): Relative tolerance. + atol (float): Absolute tolerance. + identifier (Optional[Union[str, Callable[[str], str]]]): Optional description for the scalars. Can be passed + as callable in which case it will be called by the default value to create the description at runtime. + Defaults to "Scalars". + """ + abs_diff = abs(actual - expected) + rel_diff = float("inf") if expected == 0 else abs_diff / abs(expected) + return _make_mismatch_msg( + default_identifier="Scalars", + identifier=identifier, + extra=f"Expected {expected} but got {actual}.", + abs_diff=abs_diff, + atol=atol, + rel_diff=rel_diff, + rtol=rtol, + ) + + +def make_tensor_mismatch_msg( + actual: torch.Tensor, + expected: torch.Tensor, + matches: torch.Tensor, + *, + rtol: float, + atol: float, + identifier: Optional[Union[str, Callable[[str], str]]] = None, +): + """Makes a mismatch error message for tensors. + + Args: + actual (torch.Tensor): Actual tensor. + expected (torch.Tensor): Expected tensor. + matches (torch.Tensor): Boolean mask of the same shape as ``actual`` and ``expected`` that indicates the + location of matches. + rtol (float): Relative tolerance. + atol (float): Absolute tolerance. + identifier (Optional[Union[str, Callable[[str], str]]]): Optional description for the tensors. Can be passed + as callable in which case it will be called by the default value to create the description at runtime. + Defaults to "Tensor-likes". 
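+
+    A minimal usage sketch; the ``matches`` mask is typically produced with
+    :func:`torch.isclose`, mirroring how this helper is called elsewhere in this
+    module::
+
+        >>> actual = torch.tensor([1.0, 2.0, 3.0])
+        >>> expected = torch.tensor([1.0, 2.5, 3.0])
+        >>> matches = torch.isclose(actual, expected, rtol=0.0, atol=0.1)
+        >>> msg = make_tensor_mismatch_msg(actual, expected, matches, rtol=0.0, atol=0.1)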
+ """ + + def unravel_flat_index(flat_index: int) -> Tuple[int, ...]: + if not matches.shape: + return () + + inverse_index = [] + for size in matches.shape[::-1]: + div, mod = divmod(flat_index, size) + flat_index = div + inverse_index.append(mod) + + return tuple(inverse_index[::-1]) + + number_of_elements = matches.numel() + total_mismatches = number_of_elements - int(torch.sum(matches)) + extra = ( + f"Mismatched elements: {total_mismatches} / {number_of_elements} " + f"({total_mismatches / number_of_elements:.1%})" + ) + + actual_flat = actual.flatten() + expected_flat = expected.flatten() + matches_flat = matches.flatten() + + if not actual.dtype.is_floating_point and not actual.dtype.is_complex: + # TODO: Instead of always upcasting to int64, it would be sufficient to cast to the next higher dtype to avoid + # overflow + actual_flat = actual_flat.to(torch.int64) + expected_flat = expected_flat.to(torch.int64) + + abs_diff = torch.abs(actual_flat - expected_flat) + # Ensure that only mismatches are used for the max_abs_diff computation + abs_diff[matches_flat] = 0 + max_abs_diff, max_abs_diff_flat_idx = torch.max(abs_diff, 0) + + rel_diff = abs_diff / torch.abs(expected_flat) + # Ensure that only mismatches are used for the max_rel_diff computation + rel_diff[matches_flat] = 0 + max_rel_diff, max_rel_diff_flat_idx = torch.max(rel_diff, 0) + return _make_mismatch_msg( + default_identifier="Tensor-likes", + identifier=identifier, + extra=extra, + abs_diff=max_abs_diff.item(), + abs_diff_idx=unravel_flat_index(int(max_abs_diff_flat_idx)), + atol=atol, + rel_diff=max_rel_diff.item(), + rel_diff_idx=unravel_flat_index(int(max_rel_diff_flat_idx)), + rtol=rtol, + ) + + +class UnsupportedInputs(Exception): # noqa: B903 + """Exception to be raised during the construction of a :class:`Pair` in case it doesn't support the inputs.""" + + +class Pair(abc.ABC): + """ABC for all comparison pairs to be used in conjunction with :func:`assert_equal`. + + Each subclass needs to overwrite :meth:`Pair.compare` that performs the actual comparison. + + Each pair receives **all** options, so select the ones applicable for the subclass and forward the rest to the + super class. Raising an :class:`UnsupportedInputs` during constructions indicates that the pair is not able to + handle the inputs and the next pair type will be tried. + + All other errors should be raised as :class:`ErrorMeta`. After the instantiation, :meth:`Pair._make_error_meta` can + be used to automatically handle overwriting the message with a user supplied one and id handling. + """ + + def __init__( + self, + actual: Any, + expected: Any, + *, + id: Tuple[Any, ...] = (), + **unknown_parameters: Any, + ) -> None: + self.actual = actual + self.expected = expected + self.id = id + self._unknown_parameters = unknown_parameters + + @staticmethod + def _inputs_not_supported() -> NoReturn: + raise UnsupportedInputs + + @staticmethod + def _check_inputs_isinstance(*inputs: Any, cls: Union[Type, Tuple[Type, ...]]): + """Checks if all inputs are instances of a given class and raise :class:`UnsupportedInputs` otherwise.""" + if not all(isinstance(input, cls) for input in inputs): + Pair._inputs_not_supported() + + def _fail( + self, type: Type[Exception], msg: str, *, id: Tuple[Any, ...] = () + ) -> NoReturn: + """Raises an :class:`ErrorMeta` from a given exception type and message and the stored id. + + .. warning:: + + If you use this before the ``super().__init__(...)`` call in the constructor, you have to pass the ``id`` + explicitly. 
+ """ + raise ErrorMeta(type, msg, id=self.id if not id and hasattr(self, "id") else id) + + @abc.abstractmethod + def compare(self) -> None: + """Compares the inputs and raises an :class`ErrorMeta` in case they mismatch.""" + + def extra_repr(self) -> Sequence[Union[str, Tuple[str, Any]]]: + """Returns extra information that will be included in the representation. + + Should be overwritten by all subclasses that use additional options. The representation of the object will only + be surfaced in case we encounter an unexpected error and thus should help debug the issue. Can be a sequence of + key-value-pairs or attribute names. + """ + return [] + + def __repr__(self) -> str: + head = f"{type(self).__name__}(" + tail = ")" + body = [ + f" {name}={value!s}," + for name, value in [ + ("id", self.id), + ("actual", self.actual), + ("expected", self.expected), + *[ + (extra, getattr(self, extra)) if isinstance(extra, str) else extra + for extra in self.extra_repr() + ], + ] + ] + return "\n".join((head, *body, *tail)) + + +class ObjectPair(Pair): + """Pair for any type of inputs that will be compared with the `==` operator. + + .. note:: + + Since this will instantiate for any kind of inputs, it should only be used as fallback after all other pairs + couldn't handle the inputs. + + """ + + def compare(self) -> None: + try: + equal = self.actual == self.expected + except Exception as error: + # We are not using `self._raise_error_meta` here since we need the exception chaining + raise ErrorMeta( + ValueError, + f"{self.actual} == {self.expected} failed with:\n{error}.", + id=self.id, + ) from error + + if not equal: + self._fail(AssertionError, f"{self.actual} != {self.expected}") + + +class NonePair(Pair): + """Pair for ``None`` inputs.""" + + def __init__(self, actual: Any, expected: Any, **other_parameters: Any) -> None: + if not (actual is None or expected is None): + self._inputs_not_supported() + + super().__init__(actual, expected, **other_parameters) + + def compare(self) -> None: + if not (self.actual is None and self.expected is None): + self._fail( + AssertionError, f"None mismatch: {self.actual} is not {self.expected}" + ) + + +class BooleanPair(Pair): + """Pair for :class:`bool` inputs. + + .. note:: + + If ``numpy`` is available, also handles :class:`numpy.bool_` inputs. + + """ + + def __init__( + self, + actual: Any, + expected: Any, + *, + id: Tuple[Any, ...], + **other_parameters: Any, + ) -> None: + actual, expected = self._process_inputs(actual, expected, id=id) + super().__init__(actual, expected, **other_parameters) + + @property + def _supported_types(self) -> Tuple[Type, ...]: + cls: List[Type] = [bool] + if HAS_NUMPY: + cls.append(np.bool_) + return tuple(cls) + + def _process_inputs( + self, actual: Any, expected: Any, *, id: Tuple[Any, ...] 
+ ) -> Tuple[bool, bool]: + self._check_inputs_isinstance(actual, expected, cls=self._supported_types) + actual, expected = ( + self._to_bool(bool_like, id=id) for bool_like in (actual, expected) + ) + return actual, expected + + def _to_bool(self, bool_like: Any, *, id: Tuple[Any, ...]) -> bool: + if isinstance(bool_like, bool): + return bool_like + elif isinstance(bool_like, np.bool_): + return bool_like.item() + else: + raise ErrorMeta( + TypeError, f"Unknown boolean type {type(bool_like)}.", id=id + ) + + def compare(self) -> None: + if self.actual is not self.expected: + self._fail( + AssertionError, + f"Booleans mismatch: {self.actual} is not {self.expected}", + ) + + +class NumberPair(Pair): + """Pair for Python number (:class:`int`, :class:`float`, and :class:`complex`) inputs. + + .. note:: + + If ``numpy`` is available, also handles :class:`numpy.number` inputs. + + Kwargs: + rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default + values based on the type are selected with the below table. + atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default + values based on the type are selected with the below table. + equal_nan (bool): If ``True``, two ``NaN`` values are considered equal. Defaults to ``False``. + check_dtype (bool): If ``True``, the type of the inputs will be checked for equality. Defaults to ``False``. + + The following table displays correspondence between Python number type and the ``torch.dtype``'s. See + :func:`assert_close` for the corresponding tolerances. + + +------------------+-------------------------------+ + | ``type`` | corresponding ``torch.dtype`` | + +==================+===============================+ + | :class:`int` | :attr:`~torch.int64` | + +------------------+-------------------------------+ + | :class:`float` | :attr:`~torch.float64` | + +------------------+-------------------------------+ + | :class:`complex` | :attr:`~torch.complex64` | + +------------------+-------------------------------+ + """ + + _TYPE_TO_DTYPE = { + int: torch.int64, + float: torch.float64, + complex: torch.complex128, + } + _NUMBER_TYPES = tuple(_TYPE_TO_DTYPE.keys()) + + def __init__( + self, + actual: Any, + expected: Any, + *, + id: Tuple[Any, ...] = (), + rtol: Optional[float] = None, + atol: Optional[float] = None, + equal_nan: bool = False, + check_dtype: bool = False, + **other_parameters: Any, + ) -> None: + actual, expected = self._process_inputs(actual, expected, id=id) + super().__init__(actual, expected, id=id, **other_parameters) + + self.rtol, self.atol = get_tolerances( + *[self._TYPE_TO_DTYPE[type(input)] for input in (actual, expected)], + rtol=rtol, + atol=atol, + id=id, + ) + self.equal_nan = equal_nan + self.check_dtype = check_dtype + + @property + def _supported_types(self) -> Tuple[Type, ...]: + cls = list(self._NUMBER_TYPES) + if HAS_NUMPY: + cls.append(np.number) + return tuple(cls) + + def _process_inputs( + self, actual: Any, expected: Any, *, id: Tuple[Any, ...] + ) -> Tuple[Union[int, float, complex], Union[int, float, complex]]: + self._check_inputs_isinstance(actual, expected, cls=self._supported_types) + actual, expected = ( + self._to_number(number_like, id=id) for number_like in (actual, expected) + ) + return actual, expected + + def _to_number( + self, number_like: Any, *, id: Tuple[Any, ...] 
+ ) -> Union[int, float, complex]: + if HAS_NUMPY and isinstance(number_like, np.number): + return number_like.item() + elif isinstance(number_like, self._NUMBER_TYPES): + return number_like # type: ignore[return-value] + else: + raise ErrorMeta( + TypeError, f"Unknown number type {type(number_like)}.", id=id + ) + + def compare(self) -> None: + if self.check_dtype and type(self.actual) is not type(self.expected): + self._fail( + AssertionError, + f"The (d)types do not match: {type(self.actual)} != {type(self.expected)}.", + ) + + if self.actual == self.expected: + return + + if self.equal_nan and cmath.isnan(self.actual) and cmath.isnan(self.expected): + return + + abs_diff = abs(self.actual - self.expected) + tolerance = self.atol + self.rtol * abs(self.expected) + + if cmath.isfinite(abs_diff) and abs_diff <= tolerance: + return + + self._fail( + AssertionError, + make_scalar_mismatch_msg( + self.actual, self.expected, rtol=self.rtol, atol=self.atol + ), + ) + + def extra_repr(self) -> Sequence[str]: + return ( + "rtol", + "atol", + "equal_nan", + "check_dtype", + ) + + +class TensorLikePair(Pair): + """Pair for :class:`torch.Tensor`-like inputs. + + Kwargs: + allow_subclasses (bool): + rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default + values based on the type are selected. See :func:assert_close: for details. + atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default + values based on the type are selected. See :func:assert_close: for details. + equal_nan (bool): If ``True``, two ``NaN`` values are considered equal. Defaults to ``False``. + check_device (bool): If ``True`` (default), asserts that corresponding tensors are on the same + :attr:`~torch.Tensor.device`. If this check is disabled, tensors on different + :attr:`~torch.Tensor.device`'s are moved to the CPU before being compared. + check_dtype (bool): If ``True`` (default), asserts that corresponding tensors have the same ``dtype``. If this + check is disabled, tensors with different ``dtype``'s are promoted to a common ``dtype`` (according to + :func:`torch.promote_types`) before being compared. + check_layout (bool): If ``True`` (default), asserts that corresponding tensors have the same ``layout``. If this + check is disabled, tensors with different ``layout``'s are converted to strided tensors before being + compared. + check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride. + """ + + def __init__( + self, + actual: Any, + expected: Any, + *, + id: Tuple[Any, ...] 
= (), + allow_subclasses: bool = True, + rtol: Optional[float] = None, + atol: Optional[float] = None, + equal_nan: bool = False, + check_device: bool = True, + check_dtype: bool = True, + check_layout: bool = True, + check_stride: bool = False, + **other_parameters: Any, + ): + actual, expected = self._process_inputs( + actual, expected, id=id, allow_subclasses=allow_subclasses + ) + super().__init__(actual, expected, id=id, **other_parameters) + + self.rtol, self.atol = get_tolerances( + actual, expected, rtol=rtol, atol=atol, id=self.id + ) + self.equal_nan = equal_nan + self.check_device = check_device + self.check_dtype = check_dtype + self.check_layout = check_layout + self.check_stride = check_stride + + def _process_inputs( + self, actual: Any, expected: Any, *, id: Tuple[Any, ...], allow_subclasses: bool + ) -> Tuple[torch.Tensor, torch.Tensor]: + directly_related = isinstance(actual, type(expected)) or isinstance( + expected, type(actual) + ) + if not directly_related: + self._inputs_not_supported() + + if not allow_subclasses and type(actual) is not type(expected): + self._inputs_not_supported() + + actual, expected = (self._to_tensor(input) for input in (actual, expected)) + for tensor in (actual, expected): + self._check_supported(tensor, id=id) + return actual, expected + + def _to_tensor(self, tensor_like: Any) -> torch.Tensor: + if isinstance(tensor_like, torch.Tensor): + return tensor_like + + try: + return torch.as_tensor(tensor_like) + except Exception: + self._inputs_not_supported() + + def _check_supported(self, tensor: torch.Tensor, *, id: Tuple[Any, ...]) -> None: + if tensor.layout not in { + torch.strided, + torch.jagged, + torch.sparse_coo, + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }: + raise ErrorMeta( + ValueError, f"Unsupported tensor layout {tensor.layout}", id=id + ) + + def compare(self) -> None: + actual, expected = self.actual, self.expected + + self._compare_attributes(actual, expected) + if any(input.device.type == "meta" for input in (actual, expected)): + return + + actual, expected = self._equalize_attributes(actual, expected) + self._compare_values(actual, expected) + + def _compare_attributes( + self, + actual: torch.Tensor, + expected: torch.Tensor, + ) -> None: + """Checks if the attributes of two tensors match. + + Always checks + + - the :attr:`~torch.Tensor.shape`, + - whether both inputs are quantized or not, + - and if they use the same quantization scheme. + + Checks for + + - :attr:`~torch.Tensor.layout`, + - :meth:`~torch.Tensor.stride`, + - :attr:`~torch.Tensor.device`, and + - :attr:`~torch.Tensor.dtype` + + are optional and can be disabled through the corresponding ``check_*`` flag during construction of the pair. 
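+
+        For instance, with the default ``check_dtype=True`` a dtype mismatch is
+        reported before any values are compared (a sketch of the resulting failure)::
+
+            >>> torch.testing.assert_close(torch.zeros(2), torch.zeros(2, dtype=torch.float64))
+            Traceback (most recent call last):
+            ...
+            AssertionError: The values for attribute 'dtype' do not match: torch.float32 != torch.float64.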
+ """ + + def raise_mismatch_error( + attribute_name: str, actual_value: Any, expected_value: Any + ) -> NoReturn: + self._fail( + AssertionError, + f"The values for attribute '{attribute_name}' do not match: {actual_value} != {expected_value}.", + ) + + if actual.shape != expected.shape: + raise_mismatch_error("shape", actual.shape, expected.shape) + + if actual.is_quantized != expected.is_quantized: + raise_mismatch_error( + "is_quantized", actual.is_quantized, expected.is_quantized + ) + elif actual.is_quantized and actual.qscheme() != expected.qscheme(): + raise_mismatch_error("qscheme()", actual.qscheme(), expected.qscheme()) + + if actual.layout != expected.layout: + if self.check_layout: + raise_mismatch_error("layout", actual.layout, expected.layout) + elif ( + actual.layout == torch.strided + and self.check_stride + and actual.stride() != expected.stride() + ): + raise_mismatch_error("stride()", actual.stride(), expected.stride()) + + if self.check_device and actual.device != expected.device: + raise_mismatch_error("device", actual.device, expected.device) + + if self.check_dtype and actual.dtype != expected.dtype: + raise_mismatch_error("dtype", actual.dtype, expected.dtype) + + def _equalize_attributes( + self, actual: torch.Tensor, expected: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Equalizes some attributes of two tensors for value comparison. + + If ``actual`` and ``expected`` are ... + + - ... not on the same :attr:`~torch.Tensor.device`, they are moved CPU memory. + - ... not of the same ``dtype``, they are promoted to a common ``dtype`` (according to + :func:`torch.promote_types`). + - ... not of the same ``layout``, they are converted to strided tensors. + + Args: + actual (Tensor): Actual tensor. + expected (Tensor): Expected tensor. + + Returns: + (Tuple[Tensor, Tensor]): Equalized tensors. + """ + # The comparison logic uses operators currently not supported by the MPS backends. + # See https://github.com/pytorch/pytorch/issues/77144 for details. 
+ # TODO: Remove this conversion as soon as all operations are supported natively by the MPS backend + if actual.is_mps or expected.is_mps: # type: ignore[attr-defined] + actual = actual.cpu() + expected = expected.cpu() + + if actual.device != expected.device: + actual = actual.cpu() + expected = expected.cpu() + + if actual.dtype != expected.dtype: + actual_dtype = actual.dtype + expected_dtype = expected.dtype + # For uint64, this is not sound in general, which is why promote_types doesn't + # allow it, but for easy testing, we're unlikely to get confused + # by large uint64 overflowing into negative int64 + if actual_dtype in [torch.uint64, torch.uint32, torch.uint16]: + actual_dtype = torch.int64 + if expected_dtype in [torch.uint64, torch.uint32, torch.uint16]: + expected_dtype = torch.int64 + dtype = torch.promote_types(actual_dtype, expected_dtype) + actual = actual.to(dtype) + expected = expected.to(dtype) + + if actual.layout != expected.layout: + # These checks are needed, since Tensor.to_dense() fails on tensors that are already strided + actual = actual.to_dense() if actual.layout != torch.strided else actual + expected = ( + expected.to_dense() if expected.layout != torch.strided else expected + ) + + return actual, expected + + def _compare_values(self, actual: torch.Tensor, expected: torch.Tensor) -> None: + if actual.is_quantized: + compare_fn = self._compare_quantized_values + elif actual.is_sparse: + compare_fn = self._compare_sparse_coo_values + elif actual.layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }: + compare_fn = self._compare_sparse_compressed_values + elif actual.layout == torch.jagged: + actual, expected = actual.values(), expected.values() + compare_fn = self._compare_regular_values_close + else: + compare_fn = self._compare_regular_values_close + + compare_fn( + actual, expected, rtol=self.rtol, atol=self.atol, equal_nan=self.equal_nan + ) + + def _compare_quantized_values( + self, + actual: torch.Tensor, + expected: torch.Tensor, + *, + rtol: float, + atol: float, + equal_nan: bool, + ) -> None: + """Compares quantized tensors by comparing the :meth:`~torch.Tensor.dequantize`'d variants for closeness. + + .. note:: + + A detailed discussion about why only the dequantized variant is checked for closeness rather than checking + the individual quantization parameters for closeness and the integer representation for equality can be + found in https://github.com/pytorch/pytorch/issues/68548. + """ + return self._compare_regular_values_close( + actual.dequantize(), + expected.dequantize(), + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + identifier=lambda default_identifier: f"Quantized {default_identifier.lower()}", + ) + + def _compare_sparse_coo_values( + self, + actual: torch.Tensor, + expected: torch.Tensor, + *, + rtol: float, + atol: float, + equal_nan: bool, + ) -> None: + """Compares sparse COO tensors by comparing + + - the number of sparse dimensions, + - the number of non-zero elements (nnz) for equality, + - the indices for equality, and + - the values for closeness. 
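+
+        A minimal sketch of inputs that exercise this path (assuming a build with
+        sparse COO support)::
+
+            >>> i = torch.tensor([[0, 1], [1, 0]])
+            >>> v = torch.tensor([1.0, 2.0])
+            >>> s = torch.sparse_coo_tensor(i, v, (2, 2))
+            >>> torch.testing.assert_close(s, s.clone())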
+ """ + if actual.sparse_dim() != expected.sparse_dim(): + self._fail( + AssertionError, + ( + f"The number of sparse dimensions in sparse COO tensors does not match: " + f"{actual.sparse_dim()} != {expected.sparse_dim()}" + ), + ) + + if actual._nnz() != expected._nnz(): + self._fail( + AssertionError, + ( + f"The number of specified values in sparse COO tensors does not match: " + f"{actual._nnz()} != {expected._nnz()}" + ), + ) + + self._compare_regular_values_equal( + actual._indices(), + expected._indices(), + identifier="Sparse COO indices", + ) + self._compare_regular_values_close( + actual._values(), + expected._values(), + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + identifier="Sparse COO values", + ) + + def _compare_sparse_compressed_values( + self, + actual: torch.Tensor, + expected: torch.Tensor, + *, + rtol: float, + atol: float, + equal_nan: bool, + ) -> None: + """Compares sparse compressed tensors by comparing + + - the number of non-zero elements (nnz) for equality, + - the plain indices for equality, + - the compressed indices for equality, and + - the values for closeness. + """ + format_name, compressed_indices_method, plain_indices_method = { + torch.sparse_csr: ( + "CSR", + torch.Tensor.crow_indices, + torch.Tensor.col_indices, + ), + torch.sparse_csc: ( + "CSC", + torch.Tensor.ccol_indices, + torch.Tensor.row_indices, + ), + torch.sparse_bsr: ( + "BSR", + torch.Tensor.crow_indices, + torch.Tensor.col_indices, + ), + torch.sparse_bsc: ( + "BSC", + torch.Tensor.ccol_indices, + torch.Tensor.row_indices, + ), + }[actual.layout] + + if actual._nnz() != expected._nnz(): + self._fail( + AssertionError, + ( + f"The number of specified values in sparse {format_name} tensors does not match: " + f"{actual._nnz()} != {expected._nnz()}" + ), + ) + + # Compressed and plain indices in the CSR / CSC / BSR / BSC sparse formates can be `torch.int32` _or_ + # `torch.int64`. While the same dtype is enforced for the compressed and plain indices of a single tensor, it + # can be different between two tensors. Thus, we need to convert them to the same dtype, or the comparison will + # fail. 
+ actual_compressed_indices = compressed_indices_method(actual) + expected_compressed_indices = compressed_indices_method(expected) + indices_dtype = torch.promote_types( + actual_compressed_indices.dtype, expected_compressed_indices.dtype + ) + + self._compare_regular_values_equal( + actual_compressed_indices.to(indices_dtype), + expected_compressed_indices.to(indices_dtype), + identifier=f"Sparse {format_name} {compressed_indices_method.__name__}", + ) + self._compare_regular_values_equal( + plain_indices_method(actual).to(indices_dtype), + plain_indices_method(expected).to(indices_dtype), + identifier=f"Sparse {format_name} {plain_indices_method.__name__}", + ) + self._compare_regular_values_close( + actual.values(), + expected.values(), + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + identifier=f"Sparse {format_name} values", + ) + + def _compare_regular_values_equal( + self, + actual: torch.Tensor, + expected: torch.Tensor, + *, + equal_nan: bool = False, + identifier: Optional[Union[str, Callable[[str], str]]] = None, + ) -> None: + """Checks if the values of two tensors are equal.""" + self._compare_regular_values_close( + actual, expected, rtol=0, atol=0, equal_nan=equal_nan, identifier=identifier + ) + + def _compare_regular_values_close( + self, + actual: torch.Tensor, + expected: torch.Tensor, + *, + rtol: float, + atol: float, + equal_nan: bool, + identifier: Optional[Union[str, Callable[[str], str]]] = None, + ) -> None: + """Checks if the values of two tensors are close up to a desired tolerance.""" + matches = torch.isclose( + actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan + ) + if torch.all(matches): + return + + if actual.shape == torch.Size([]): + msg = make_scalar_mismatch_msg( + actual.item(), + expected.item(), + rtol=rtol, + atol=atol, + identifier=identifier, + ) + else: + msg = make_tensor_mismatch_msg( + actual, expected, matches, rtol=rtol, atol=atol, identifier=identifier + ) + self._fail(AssertionError, msg) + + def extra_repr(self) -> Sequence[str]: + return ( + "rtol", + "atol", + "equal_nan", + "check_device", + "check_dtype", + "check_layout", + "check_stride", + ) + + +def originate_pairs( + actual: Any, + expected: Any, + *, + pair_types: Sequence[Type[Pair]], + sequence_types: Tuple[Type, ...] = (collections.abc.Sequence,), + mapping_types: Tuple[Type, ...] = (collections.abc.Mapping,), + id: Tuple[Any, ...] = (), + **options: Any, +) -> List[Pair]: + """Originates pairs from the individual inputs. + + ``actual`` and ``expected`` can be possibly nested :class:`~collections.abc.Sequence`'s or + :class:`~collections.abc.Mapping`'s. In this case the pairs are originated by recursing through them. + + Args: + actual (Any): Actual input. + expected (Any): Expected input. + pair_types (Sequence[Type[Pair]]): Sequence of pair types that will be tried to construct with the inputs. + First successful pair will be used. + sequence_types (Tuple[Type, ...]): Optional types treated as sequences that will be checked elementwise. + mapping_types (Tuple[Type, ...]): Optional types treated as mappings that will be checked elementwise. + id (Tuple[Any, ...]): Optional id of a pair that will be included in an error message. + **options (Any): Options passed to each pair during construction. + + Raises: + ErrorMeta: With :class`AssertionError`, if the inputs are :class:`~collections.abc.Sequence`'s, but their + length does not match. 
+ ErrorMeta: With :class`AssertionError`, if the inputs are :class:`~collections.abc.Mapping`'s, but their set of + keys do not match. + ErrorMeta: With :class`TypeError`, if no pair is able to handle the inputs. + ErrorMeta: With any expected exception that happens during the construction of a pair. + + Returns: + (List[Pair]): Originated pairs. + """ + # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop: + # "a" == "a"[0][0]... + if ( + isinstance(actual, sequence_types) + and not isinstance(actual, str) + and isinstance(expected, sequence_types) + and not isinstance(expected, str) + ): + actual_len = len(actual) + expected_len = len(expected) + if actual_len != expected_len: + raise ErrorMeta( + AssertionError, + f"The length of the sequences mismatch: {actual_len} != {expected_len}", + id=id, + ) + + pairs = [] + for idx in range(actual_len): + pairs.extend( + originate_pairs( + actual[idx], + expected[idx], + pair_types=pair_types, + sequence_types=sequence_types, + mapping_types=mapping_types, + id=(*id, idx), + **options, + ) + ) + return pairs + + elif isinstance(actual, mapping_types) and isinstance(expected, mapping_types): + actual_keys = set(actual.keys()) + expected_keys = set(expected.keys()) + if actual_keys != expected_keys: + missing_keys = expected_keys - actual_keys + additional_keys = actual_keys - expected_keys + raise ErrorMeta( + AssertionError, + ( + f"The keys of the mappings do not match:\n" + f"Missing keys in the actual mapping: {sorted(missing_keys)}\n" + f"Additional keys in the actual mapping: {sorted(additional_keys)}" + ), + id=id, + ) + + keys: Collection = actual_keys + # Since the origination aborts after the first failure, we try to be deterministic + with contextlib.suppress(Exception): + keys = sorted(keys) + + pairs = [] + for key in keys: + pairs.extend( + originate_pairs( + actual[key], + expected[key], + pair_types=pair_types, + sequence_types=sequence_types, + mapping_types=mapping_types, + id=(*id, key), + **options, + ) + ) + return pairs + + else: + for pair_type in pair_types: + try: + return [pair_type(actual, expected, id=id, **options)] + # Raising an `UnsupportedInputs` during origination indicates that the pair type is not able to handle the + # inputs. Thus, we try the next pair type. + except UnsupportedInputs: + continue + # Raising an `ErrorMeta` during origination is the orderly way to abort and so we simply re-raise it. This + # is only in a separate branch, because the one below would also except it. + except ErrorMeta: + raise + # Raising any other exception during origination is unexpected and will give some extra information about + # what happened. If applicable, the exception should be expected in the future. + except Exception as error: + raise RuntimeError( + f"Originating a {pair_type.__name__}() at item {''.join(str([item]) for item in id)} with\n\n" + f"{type(actual).__name__}(): {actual}\n\n" + f"and\n\n" + f"{type(expected).__name__}(): {expected}\n\n" + f"resulted in the unexpected exception above. " + f"If you are a user and see this message during normal operation " + "please file an issue at https://github.com/pytorch/pytorch/issues. " + "If you are a developer and working on the comparison functions, " + "please except the previous error and raise an expressive `ErrorMeta` instead." 
+ ) from error + else: + raise ErrorMeta( + TypeError, + f"No comparison pair was able to handle inputs of type {type(actual)} and {type(expected)}.", + id=id, + ) + + +def not_close_error_metas( + actual: Any, + expected: Any, + *, + pair_types: Sequence[Type[Pair]] = (ObjectPair,), + sequence_types: Tuple[Type, ...] = (collections.abc.Sequence,), + mapping_types: Tuple[Type, ...] = (collections.abc.Mapping,), + **options: Any, +) -> List[ErrorMeta]: + """Asserts that inputs are equal. + + ``actual`` and ``expected`` can be possibly nested :class:`~collections.abc.Sequence`'s or + :class:`~collections.abc.Mapping`'s. In this case the comparison happens elementwise by recursing through them. + + Args: + actual (Any): Actual input. + expected (Any): Expected input. + pair_types (Sequence[Type[Pair]]): Sequence of :class:`Pair` types that will be tried to construct with the + inputs. First successful pair will be used. Defaults to only using :class:`ObjectPair`. + sequence_types (Tuple[Type, ...]): Optional types treated as sequences that will be checked elementwise. + mapping_types (Tuple[Type, ...]): Optional types treated as mappings that will be checked elementwise. + **options (Any): Options passed to each pair during construction. + """ + # Hide this function from `pytest`'s traceback + __tracebackhide__ = True + + try: + pairs = originate_pairs( + actual, + expected, + pair_types=pair_types, + sequence_types=sequence_types, + mapping_types=mapping_types, + **options, + ) + except ErrorMeta as error_meta: + # Explicitly raising from None to hide the internal traceback + raise error_meta.to_error() from None # noqa: RSE102 + + error_metas: List[ErrorMeta] = [] + for pair in pairs: + try: + pair.compare() + except ErrorMeta as error_meta: + error_metas.append(error_meta) + # Raising any exception besides `ErrorMeta` while comparing is unexpected and will give some extra information + # about what happened. If applicable, the exception should be expected in the future. + except Exception as error: + raise RuntimeError( + f"Comparing\n\n" + f"{pair}\n\n" + f"resulted in the unexpected exception above. " + f"If you are a user and see this message during normal operation " + "please file an issue at https://github.com/pytorch/pytorch/issues. " + "If you are a developer and working on the comparison functions, " + "please except the previous error and raise an expressive `ErrorMeta` instead." + ) from error + + # [ErrorMeta Cycles] + # ErrorMeta objects in this list capture + # tracebacks that refer to the frame of this function. + # The local variable `error_metas` refers to the error meta + # objects, creating a reference cycle. Frames in the traceback + # would not get freed until cycle collection, leaking cuda memory in tests. + # We break the cycle by removing the reference to the error_meta objects + # from this frame as it returns. + error_metas = [error_metas] + return error_metas.pop() + + +def assert_close( + actual: Any, + expected: Any, + *, + allow_subclasses: bool = True, + rtol: Optional[float] = None, + atol: Optional[float] = None, + equal_nan: bool = False, + check_device: bool = True, + check_dtype: bool = True, + check_layout: bool = True, + check_stride: bool = False, + msg: Optional[Union[str, Callable[[str], str]]] = None, +): + r"""Asserts that ``actual`` and ``expected`` are close. + + If ``actual`` and ``expected`` are strided, non-quantized, real-valued, and finite, they are considered close if + + .. 
math:: + + \lvert \text{actual} - \text{expected} \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert \text{expected} \rvert + + Non-finite values (``-inf`` and ``inf``) are only considered close if and only if they are equal. ``NaN``'s are + only considered equal to each other if ``equal_nan`` is ``True``. + + In addition, they are only considered close if they have the same + + - :attr:`~torch.Tensor.device` (if ``check_device`` is ``True``), + - ``dtype`` (if ``check_dtype`` is ``True``), + - ``layout`` (if ``check_layout`` is ``True``), and + - stride (if ``check_stride`` is ``True``). + + If either ``actual`` or ``expected`` is a meta tensor, only the attribute checks will be performed. + + If ``actual`` and ``expected`` are sparse (either having COO, CSR, CSC, BSR, or BSC layout), their strided members are + checked individually. Indices, namely ``indices`` for COO, ``crow_indices`` and ``col_indices`` for CSR and BSR, + or ``ccol_indices`` and ``row_indices`` for CSC and BSC layouts, respectively, + are always checked for equality whereas the values are checked for closeness according to the definition above. + + If ``actual`` and ``expected`` are quantized, they are considered close if they have the same + :meth:`~torch.Tensor.qscheme` and the result of :meth:`~torch.Tensor.dequantize` is close according to the + definition above. + + ``actual`` and ``expected`` can be :class:`~torch.Tensor`'s or any tensor-or-scalar-likes from which + :class:`torch.Tensor`'s can be constructed with :func:`torch.as_tensor`. Except for Python scalars the input types + have to be directly related. In addition, ``actual`` and ``expected`` can be :class:`~collections.abc.Sequence`'s + or :class:`~collections.abc.Mapping`'s in which case they are considered close if their structure matches and all + their elements are considered close according to the above definition. + + .. note:: + + Python scalars are an exception to the type relation requirement, because their :func:`type`, i.e. + :class:`int`, :class:`float`, and :class:`complex`, is equivalent to the ``dtype`` of a tensor-like. Thus, + Python scalars of different types can be checked, but require ``check_dtype=False``. + + Args: + actual (Any): Actual input. + expected (Any): Expected input. + allow_subclasses (bool): If ``True`` (default) and except for Python scalars, inputs of directly related types + are allowed. Otherwise type equality is required. + rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default + values based on the :attr:`~torch.Tensor.dtype` are selected with the below table. + atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default + values based on the :attr:`~torch.Tensor.dtype` are selected with the below table. + equal_nan (Union[bool, str]): If ``True``, two ``NaN`` values will be considered equal. + check_device (bool): If ``True`` (default), asserts that corresponding tensors are on the same + :attr:`~torch.Tensor.device`. If this check is disabled, tensors on different + :attr:`~torch.Tensor.device`'s are moved to the CPU before being compared. + check_dtype (bool): If ``True`` (default), asserts that corresponding tensors have the same ``dtype``. If this + check is disabled, tensors with different ``dtype``'s are promoted to a common ``dtype`` (according to + :func:`torch.promote_types`) before being compared. 
+ check_layout (bool): If ``True`` (default), asserts that corresponding tensors have the same ``layout``. If this + check is disabled, tensors with different ``layout``'s are converted to strided tensors before being + compared. + check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride. + msg (Optional[Union[str, Callable[[str], str]]]): Optional error message to use in case a failure occurs during + the comparison. Can also be passed as a callable in which case it will be called with the generated message and + should return the new message. + + Raises: + ValueError: If no :class:`torch.Tensor` can be constructed from an input. + ValueError: If only ``rtol`` or ``atol`` is specified. + AssertionError: If corresponding inputs are not Python scalars and are not directly related. + AssertionError: If ``allow_subclasses`` is ``False``, but corresponding inputs are not Python scalars and have + different types. + AssertionError: If the inputs are :class:`~collections.abc.Sequence`'s, but their length does not match. + AssertionError: If the inputs are :class:`~collections.abc.Mapping`'s, but their set of keys do not match. + AssertionError: If corresponding tensors do not have the same :attr:`~torch.Tensor.shape`. + AssertionError: If ``check_layout`` is ``True``, but corresponding tensors do not have the same + :attr:`~torch.Tensor.layout`. + AssertionError: If only one of corresponding tensors is quantized. + AssertionError: If corresponding tensors are quantized, but have different :meth:`~torch.Tensor.qscheme`'s. + AssertionError: If ``check_device`` is ``True``, but corresponding tensors are not on the same + :attr:`~torch.Tensor.device`. + AssertionError: If ``check_dtype`` is ``True``, but corresponding tensors do not have the same ``dtype``. + AssertionError: If ``check_stride`` is ``True``, but corresponding strided tensors do not have the same stride. + AssertionError: If the values of corresponding tensors are not close according to the definition above. + + The following table displays the default ``rtol`` and ``atol`` for different ``dtype``'s. In case of mismatching + ``dtype``'s, the maximum of both tolerances is used. 
+ + +---------------------------+------------+----------+ + | ``dtype`` | ``rtol`` | ``atol`` | + +===========================+============+==========+ + | :attr:`~torch.float16` | ``1e-3`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~torch.bfloat16` | ``1.6e-2`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~torch.float32` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~torch.float64` | ``1e-7`` | ``1e-7`` | + +---------------------------+------------+----------+ + | :attr:`~torch.complex32` | ``1e-3`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~torch.complex64` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~torch.complex128` | ``1e-7`` | ``1e-7`` | + +---------------------------+------------+----------+ + | :attr:`~torch.quint8` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~torch.quint2x4` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~torch.quint4x2` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~torch.qint8` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~torch.qint32` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | other | ``0.0`` | ``0.0`` | + +---------------------------+------------+----------+ + + .. note:: + + :func:`~torch.testing.assert_close` is highly configurable with strict default settings. Users are encouraged + to :func:`~functools.partial` it to fit their use case. For example, if an equality check is needed, one might + define an ``assert_equal`` that uses zero tolerances for every ``dtype`` by default: + + >>> import functools + >>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0) + >>> assert_equal(1e-9, 1e-10) + Traceback (most recent call last): + ... + AssertionError: Scalars are not equal! + + Expected 1e-10 but got 1e-09. + Absolute difference: 9.000000000000001e-10 + Relative difference: 9.0 + + Examples: + >>> # tensor to tensor comparison + >>> expected = torch.tensor([1e0, 1e-1, 1e-2]) + >>> actual = torch.acos(torch.cos(expected)) + >>> torch.testing.assert_close(actual, expected) + + >>> # scalar to scalar comparison + >>> import math + >>> expected = math.sqrt(2.0) + >>> actual = 2.0 / math.sqrt(2.0) + >>> torch.testing.assert_close(actual, expected) + + >>> # numpy array to numpy array comparison + >>> import numpy as np + >>> expected = np.array([1e0, 1e-1, 1e-2]) + >>> actual = np.arccos(np.cos(expected)) + >>> torch.testing.assert_close(actual, expected) + + >>> # sequence to sequence comparison + >>> import numpy as np + >>> # The types of the sequences do not have to match. They only have to have the same + >>> # length and their elements have to match. + >>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)] + >>> actual = tuple(expected) + >>> torch.testing.assert_close(actual, expected) + + >>> # mapping to mapping comparison + >>> from collections import OrderedDict + >>> import numpy as np + >>> foo = torch.tensor(1.0) + >>> bar = 2.0 + >>> baz = np.array(3.0) + >>> # The types and a possible ordering of mappings do not have to match. They only + >>> # have to have the same set of keys and their elements have to match. 
+ >>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)]) + >>> actual = {"baz": baz, "bar": bar, "foo": foo} + >>> torch.testing.assert_close(actual, expected) + + >>> expected = torch.tensor([1.0, 2.0, 3.0]) + >>> actual = expected.clone() + >>> # By default, directly related instances can be compared + >>> torch.testing.assert_close(torch.nn.Parameter(actual), expected) + >>> # This check can be made more strict with allow_subclasses=False + >>> torch.testing.assert_close( + ... torch.nn.Parameter(actual), expected, allow_subclasses=False + ... ) + Traceback (most recent call last): + ... + TypeError: No comparison pair was able to handle inputs of type + <class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'>. + >>> # If the inputs are not directly related, they are never considered close + >>> torch.testing.assert_close(actual.numpy(), expected) + Traceback (most recent call last): + ... + TypeError: No comparison pair was able to handle inputs of type + <class 'numpy.ndarray'> and <class 'torch.Tensor'>. + >>> # Exceptions to these rules are Python scalars. They can be checked regardless of + >>> # their type if check_dtype=False. + >>> torch.testing.assert_close(1.0, 1, check_dtype=False) + + >>> # NaN != NaN by default. + >>> expected = torch.tensor(float("Nan")) + >>> actual = expected.clone() + >>> torch.testing.assert_close(actual, expected) + Traceback (most recent call last): + ... + AssertionError: Scalars are not close! + + Expected nan but got nan. + Absolute difference: nan (up to 1e-05 allowed) + Relative difference: nan (up to 1.3e-06 allowed) + >>> torch.testing.assert_close(actual, expected, equal_nan=True) + + >>> expected = torch.tensor([1.0, 2.0, 3.0]) + >>> actual = torch.tensor([1.0, 4.0, 5.0]) + >>> # The default error message can be overwritten. + >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!") + Traceback (most recent call last): + ... + AssertionError: Argh, the tensors are not close! + >>> # If msg is a callable, it can be used to augment the generated message with + >>> # extra information + >>> torch.testing.assert_close( + ... actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter" + ... ) + Traceback (most recent call last): + ... + AssertionError: Header + + Tensor-likes are not close! + + Mismatched elements: 2 / 3 (66.7%) + Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed) + Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed) + + Footer + """ + # Hide this function from `pytest`'s traceback + __tracebackhide__ = True + + error_metas = not_close_error_metas( + actual, + expected, + pair_types=( + NonePair, + BooleanPair, + NumberPair, + TensorLikePair, + ), + allow_subclasses=allow_subclasses, + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + check_device=check_device, + check_dtype=check_dtype, + check_layout=check_layout, + check_stride=check_stride, + msg=msg, + ) + + if error_metas: + # TODO: compose all metas into one AssertionError + raise error_metas[0].to_error(msg) + + +@deprecated( + "`torch.testing.assert_allclose()` is deprecated since 1.12 and will be removed in a future release. " + "Please use `torch.testing.assert_close()` instead. " + "You can find detailed upgrade instructions in https://github.com/pytorch/pytorch/issues/61844.", + category=FutureWarning, +) +def assert_allclose( + actual: Any, + expected: Any, + rtol: Optional[float] = None, + atol: Optional[float] = None, + equal_nan: bool = True, + msg: str = "", +) -> None: + """ + .. 
warning:: + + :func:`torch.testing.assert_allclose` is deprecated since ``1.12`` and will be removed in a future release. + Please use :func:`torch.testing.assert_close` instead. You can find detailed upgrade instructions + `here `_. + """ + if not isinstance(actual, torch.Tensor): + actual = torch.tensor(actual) + if not isinstance(expected, torch.Tensor): + expected = torch.tensor(expected, dtype=actual.dtype) + + if rtol is None and atol is None: + rtol, atol = default_tolerances( + actual, + expected, + dtype_precisions={ + torch.float16: (1e-3, 1e-3), + torch.float32: (1e-4, 1e-5), + torch.float64: (1e-5, 1e-8), + }, + ) + + torch.testing.assert_close( + actual, + expected, + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + check_device=True, + check_dtype=False, + check_stride=False, + msg=msg or None, + ) diff --git a/lib/python3.10/site-packages/torchgen/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b279db754372dd3fa9e2bc0d50df0a2b3210512c Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/__pycache__/code_template.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/__pycache__/code_template.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3dcbe82999970f27f677324727b9ede8e35091ef Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/__pycache__/code_template.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/__pycache__/context.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/__pycache__/context.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e07841606c502201de09444d177c21d4de95658 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/__pycache__/context.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/__pycache__/gen.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/__pycache__/gen.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b10ba14ec1db1220f6c0ff8390a5b56c6a32d61 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/__pycache__/gen.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae88e971bbb9d812b0810982c30193adfe6b4bed Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40efa54a9a92b03b8a8488d899efbcb53d5e2ae1 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/__pycache__/gen_executorch.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/__pycache__/gen_executorch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71acf89d6c4b419bb4f937e81ad31cd4940e7097 Binary files /dev/null and 
b/lib/python3.10/site-packages/torchgen/__pycache__/gen_executorch.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba6437f1f67fa516ebb17abf8924d41109f5a66b Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe61927e00ffeb75216d154c201929382231e75d Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/__pycache__/gen_schema_utils.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/__pycache__/gen_schema_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f1bc111e59d5e8f6c1abdad37bfd01838e46baa Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/__pycache__/gen_schema_utils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5256e0eb95a27e314f585ce0467462e9ba69d7aa Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/__pycache__/local.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/__pycache__/local.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c588b8fc951906b38ea099c4366be49f766d86a2 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/__pycache__/local.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/__pycache__/model.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..959c63f797cc8afade99d3310dc5610b23a87d9d Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/__pycache__/model.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/__pycache__/native_function_generation.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/__pycache__/native_function_generation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..116d7e1ead619d91f65820804ca280b1dd4d4d56 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/__pycache__/native_function_generation.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/__pycache__/utils.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03428d6f6bb9f04c95ee64f3864f1199579afab4 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/__pycache__/utils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/__pycache__/yaml_utils.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/__pycache__/yaml_utils.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..d6b81be6c1c2948ea83d347ef2be24e063bb8b39 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/__pycache__/yaml_utils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/aoti/__init__.py b/lib/python3.10/site-packages/torchgen/aoti/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/torchgen/aoti/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/aoti/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..591598c79e874679d6eafa84137fceb638feadb9 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/aoti/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/aoti/__pycache__/fallback_ops.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/aoti/__pycache__/fallback_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..306eaa15b111d84ca8963ab50245ec4dd4f28829 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/aoti/__pycache__/fallback_ops.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/aoti/fallback_ops.py b/lib/python3.10/site-packages/torchgen/aoti/fallback_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..aa88214b3672f199b2858eeb18ec2917ba3c2d0b --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/aoti/fallback_ops.py @@ -0,0 +1,149 @@ +# Be extra careful when you edit this file, because it affects AOTInductor ABI compatibility. See +# https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 +# for details. +# +# The inductor_fallback_ops list is based on the fallback ops from torch/_inductor/lowering.py. +# Generally speaking, it is ok to add a new op to the list, but you need to run +# `python torchgen/gen.py --update-aoti-c-shim` in order to regenerate C shim header files. +# But it is NOT ok to remove an existing fallback op from the list, since that will break +# some existing AOTInductor-compiled models. 
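# Editorial sketch (not part of the diff above): a minimal, hypothetical illustration of the
# policy stated in the preceding comment -- an update to the fallback list may add ops, but
# removing an existing op would break already-compiled AOTInductor models. The two sets below
# are made-up example snapshots, not data taken from this repository.
def _assert_fallback_ops_only_grow(previous_ops: set, updated_ops: set) -> None:
    # Every op shipped previously must still be present after the update.
    removed = previous_ops - updated_ops
    if removed:
        raise RuntimeError(f"Fallback ops must never be removed: {sorted(removed)}")

# Adding "aten.mm.out" is fine; dropping either pre-existing op would raise.
_assert_fallback_ops_only_grow(
    {"aten.addmm.out", "aten.bmm.out"},
    {"aten.addmm.out", "aten.bmm.out", "aten.mm.out"},
)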
+inductor_fallback_ops = { + "aten._adaptive_avg_pool2d_backward.default", + "aten._adaptive_avg_pool2d.default", + "aten._adaptive_avg_pool3d.default", + "aten._adaptive_avg_pool3d_backward.default", + "aten.adaptive_max_pool2d_backward.default", + "aten.adaptive_max_pool2d.default", + "aten.adaptive_max_pool3d.default", + "aten.adaptive_max_pool3d_backward.default", + "aten.addbmm.default", + "aten._addmm_activation.default", + "aten.addmm.out", + "aten.addmv.default", + "aten.angle.default", + "aten.avg_pool2d_backward.default", + "aten.avg_pool2d.default", + "aten.avg_pool3d_backward.default", + "aten.avg_pool3d.default", + "aten.bernoulli_.float", + "aten.bernoulli_.Tensor", + "aten.bmm.out", + "aten.bucketize.Tensor", + "aten.cat.default", + "aten._cdist_backward.default", + "aten._cdist_forward.default", + "aten.cholesky_inverse.default", + "aten.cholesky_solve.default", + "aten.convolution_backward.default", + "aten._cudnn_rnn.default", + "aten._cudnn_rnn_backward.default", + "aten.convolution.default", + "aten.cummax.default", + "aten.cummin.default", + "aten.cumprod.default", + "aten.cumsum.default", + "aten._efficient_attention_backward.default", + "aten._efficient_attention_forward.default", + "aten._efficientzerotensor.default", + "aten._embedding_bag.default", + "aten._embedding_bag_dense_backward.default", + "aten._embedding_bag_forward_only.default", + "aten._embedding_bag_per_sample_weights_backward.default", + "aten.exponential.default", + "aten._fft_c2c.default", + "aten._fft_r2c.default", + "aten._flash_attention_backward.default", + "aten._flash_attention_forward.default", + "aten.fractional_max_pool2d_backward.default", + "aten.fractional_max_pool2d.default", + "aten.fractional_max_pool3d.default", + "aten.fractional_max_pool3d_backward.default", + "aten._fused_moving_avg_obs_fq_helper.default", + "aten._fused_moving_avg_obs_fq_helper_functional.default", + "aten.gcd.default", + "aten.geqrf.default", + "aten.grid_sampler_2d_backward.default", + "aten.histc.default", + "aten.histogram.bin_ct", + "aten._histogramdd_bin_edges.default", + "aten._histogramdd_from_bin_cts.default", + "aten.index_put.default", + "aten.index_reduce.default", + "aten.index.Tensor", + "aten.kthvalue.default", + "aten.logcumsumexp.default", + "aten.lu_unpack.default", + "aten.masked_scatter.default", + "aten.masked_scatter_backward.default", + "aten.max_pool2d_with_indices_backward.default", + "aten.max_pool2d_with_indices.default", + "aten.max_pool3d_with_indices.default", + "aten.max_pool3d_with_indices_backward.default", + "aten.max_unpool2d.default", + "aten.max_unpool3d.default", + "aten.median.default", + "aten.mm.out", + "aten.mode.default", + "aten.mul.Scalar", + "aten.mul.Tensor", + "aten.nanmedian.default", + "aten.native_dropout.default", + "aten.normal_functional.default", + "aten.nonzero.default", + "aten.ormqr.default", + "aten._pdist_backward.default", + "aten._pdist_forward.default", + "aten.polar.default", + "aten.pow.Scalar", + "aten.pow.Tensor_Scalar", + "aten.pow.Tensor_Tensor", + "aten.rand.default", + "aten.rand.generator", + "aten.randint.default", + "aten.randint.generator", + "aten.randint.low", + "aten.randint.low_out", + "aten.randn.default", + "aten.randn.generator", + "aten.randperm.default", + "aten.repeat_interleave.Tensor", + "aten.replication_pad1d_backward.default", + "aten.replication_pad2d_backward.default", + "aten.reshape.default", + "aten.resize_.default", + "aten.resize_as_.default", + "aten._scaled_dot_product_efficient_attention_backward.default", + 
"aten._scaled_dot_product_efficient_attention.default", + "aten._scaled_dot_product_flash_attention_backward.default", + "aten._scaled_dot_product_flash_attention.default", + "aten._scaled_dot_product_cudnn_attention_backward.default", + "aten._scaled_dot_product_cudnn_attention.default", + "aten._scaled_dot_product_flash_attention_for_cpu_backward.default", + "aten._scaled_dot_product_flash_attention_for_cpu.default", + "aten._scaled_mm.default", + "aten.scatter_reduce.two_out", + "aten.scatter.src_out", + "aten.scatter.value_out", + "aten.searchsorted.default", + "aten._segment_reduce_backward.default", + "aten.segment_reduce.default", + "aten.slice.Tensor", + "aten.soft_margin_loss_backward.default", + "aten.sort.default", + "aten.sort.stable", + "aten._sparse_coo_tensor_with_dims_and_tensors.default", + "aten._thnn_fused_lstm_cell.default", + "aten.topk.default", + "aten._to_sparse.default", + "aten.to_sparse.default", + "aten.triangular_solve.default", + "aten._trilinear.default", + "aten.uniform.default", + "aten.upsample_bicubic2d_backward.default", + "aten.upsample_linear1d_backward.default", + "aten.upsample_trilinear3d_backward.default", + "aten.view_as_complex.default", + "aten.view_as_real.default", + "aten.view.dtype", + "aten.zeros.names", +} diff --git a/lib/python3.10/site-packages/torchgen/api/__init__.py b/lib/python3.10/site-packages/torchgen/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/torchgen/api/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/api/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3f0e4a0ed5afa9d1551a465e19f4598dbcf38e2 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/api/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/api/__pycache__/autograd.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/api/__pycache__/autograd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46083da33fe521f2f23d41f03b54f8364283f647 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/api/__pycache__/autograd.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/api/__pycache__/cpp.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/api/__pycache__/cpp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a6436a1ad317ec921b2b548176d6666ba5bd7cb Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/api/__pycache__/cpp.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/api/__pycache__/dispatcher.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/api/__pycache__/dispatcher.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69213e311ca2ffeea92c7ba24eb6dc12dc5587a2 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/api/__pycache__/dispatcher.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/api/__pycache__/functionalization.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/api/__pycache__/functionalization.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aaa0746ae4b646efb3db1638f552b57df6eae6c0 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/api/__pycache__/functionalization.cpython-310.pyc differ diff --git 
a/lib/python3.10/site-packages/torchgen/api/__pycache__/lazy.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/api/__pycache__/lazy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a46629796dbec28d3e9f43569f9c10d8bd0e48a Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/api/__pycache__/lazy.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/api/__pycache__/meta.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/api/__pycache__/meta.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c95b184a8ae0a4c1089629b5f2472221cecbbf6 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/api/__pycache__/meta.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/api/__pycache__/native.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/api/__pycache__/native.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9dbb52899cf1f230c30bc64d1d3e15f3627417a0 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/api/__pycache__/native.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/api/__pycache__/python.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/api/__pycache__/python.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67d9e417e7de9a909b1ade50997b7672c99c601e Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/api/__pycache__/python.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/api/__pycache__/structured.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/api/__pycache__/structured.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f15593684ac0ea84a34a3b1c81fb111946cb5d5 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/api/__pycache__/structured.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/api/__pycache__/translate.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/api/__pycache__/translate.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21618f60fcd3f6edd2ff71554fa02418762ba337 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/api/__pycache__/translate.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/api/__pycache__/ufunc.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/api/__pycache__/ufunc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e622a21d85057282c1d62d2879741d24a5f0485 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/api/__pycache__/ufunc.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/api/__pycache__/unboxing.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/api/__pycache__/unboxing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f665ee5a62728be257a31f1a88a22fd0c2ba657 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/api/__pycache__/unboxing.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/api/autograd.py b/lib/python3.10/site-packages/torchgen/api/autograd.py new file mode 100644 index 0000000000000000000000000000000000000000..644069395e1dd86d7bc65c4d69473faa1a068b66 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/api/autograd.py @@ -0,0 +1,870 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass +from 
typing import cast, Sequence + +from torchgen import local +from torchgen.api import cpp +from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT +from torchgen.model import ( + BaseTy, + BaseType, + FunctionSchema, + ListType, + NativeFunction, + NativeFunctionsViewGroup, + SchemaKind, + Type, +) +from torchgen.utils import IDENT_REGEX + + +# Represents a saved attribute involved in backward calculation. +# Note that it can be a derived property of an input argument, e.g.: +# we could save `other.scalar_type()` instead of the entire `other` tensor. +@dataclass(frozen=True) +class SavedAttribute: + # The NamedCType holds the updated name and cpp type of the attribute + # for the name, Suffix is appended if it's derived property, e.g.: `other_scalar_type` + nctype: NamedCType + + # The expression to read the derived property at save time, e.g.: + # `other.scalar_type()`. + expr: str + + +# Represents a backward formula that calculates derivatives for one +# or more tensors. +@dataclass(frozen=True) +class Derivative: + # The formula string (legit C++ expression). + # Note that expressions against input arguments have been replaced with the + # corresponding saved attributes. + # E.g.: + # raw formula: `mul_tensor_backward(grad, self, other.scalar_type())` + # here: `mul_tensor_backward(grad, self, other_scalar_type)` + formula: str + + # The formula string before input argument replacement + original_formula: str + + # Names of the arguments for which this formula calculates derivatives. + var_names: tuple[str, ...] + + # Saved inputs that are referenced by the formula. + saved_inputs: tuple[SavedAttribute, ...] + + # Saved outputs that are referenced by the formula. + saved_outputs: tuple[SavedAttribute, ...] + + # Gradients that are referenced by name in the formula. + named_gradients: set[str] + + +# Represents a forward formula that calculates forward derivatives +# for one tensor. +@dataclass(frozen=True) +class ForwardDerivative: + # The formula string (legit C++ expression). + # Note that special keywords such as "linear" or "element_wise" have been + # replaced by the automatically generated formula. + formula: str + + # Name of the output arguments for which this formula calculates forward + # derivatives + var_names: tuple[str, ...] + + # Type of the output arguments for which this formula calculates forward + # derivatives + var_types: tuple[Type, ...] + + # Inputs for which the forward derivatives are required for this formula + required_inputs_fw_grad: tuple[str, ...] | None + + # Inputs for which the primal is required for this formula + required_inputs_primal: tuple[str, ...] | None + + # Flag to specify if this formula requires the original value of self + # This is only used by inplace operations + required_original_self_value: bool + + # If this formula is specified in derivatives.yaml or if we are re-using the + # out of place formula for inplace + is_reusing_outplace_formula: bool + + +# Represents differentiability info for a NativeFunction. +@dataclass(frozen=True) +class DifferentiabilityInfo: + # The base name read from derivatives.yaml. + name: str + + # The matching native function. + # + # There can be multiple NativeFunction having the same base name: + # - different overloads with different types of input arguments; + # - in-place/out/functional variants of the same function; + # + # We first use the schema string (under the 'name' key) in derivatives.yaml + # to find the NativeFunction having the same schema string. 
+ # Then we find the in-place/out/functional variants of the matching function. + # Among these variants, we choose the one having the same name as the + # derivatives.yaml entry. If there is no exact match, then we choose the + # in-place variant. + # TODO: maybe the logic to search for all variants is no longer necessary? + func: NativeFunction + + # The name of the generated autograd function. + # It's set only if we will calculate a derivative, i.e. + # 'args_with_derivatives' is not empty. + op: str | None + + # The derivatives formulae for this function. + # Note that the length of this sequence is the number of differentiable inputs + derivatives: Sequence[Derivative] + + # The forward derivatives formulae for this function. + # Note that the length of this sequence is the number of differentiable outputs + forward_derivatives: Sequence[ForwardDerivative] + + # The union of 'saved_inputs' of all 'derivatives'. + all_saved_inputs: Sequence[SavedAttribute] + + # The union of 'saved_outputs' of all 'derivatives'. + all_saved_outputs: Sequence[SavedAttribute] + + # All named gradients that are available for use, in the same + # order as in the grads vector. + available_named_gradients: Sequence[str] + + # The named gradients that are used in any of the derivatives. + # Invariant: all(name in available_named_gradients for name in used_named_gradients) + used_named_gradients: set[str] + + # The function's input arguments for which it calculates derivatives. + # It's the union of 'var_names' of all 'derivatives', sorted by the + # argument order in the function schema. + args_with_derivatives: Sequence[Binding] + + # Names of arguments whose derivative formula is 'non_differentiable'. + non_differentiable_arg_names: Sequence[str] + + # Raw data read from derivatives.yaml. + output_differentiability: list[bool] | None + + # output_differentiability in derivatives.yaml can be a list of + # conditions that express if the output is differentiable. In this case, + # the number of conditions must match the number of outputs + # (NB: we only support one condition right now). + # output_differentiability gets populated with True for each condition, + # while output_differentiability_conditions gets populated with the conditions + output_differentiability_conditions: list[str] | None + + @property + def has_derivatives(self) -> bool: + return len(self.args_with_derivatives) > 0 + + # Generates a new DifferentiabilityInfo using the exact same set of derivative information, + # but with a new operator name. + # This is used when generating "copy" variants of view ops, + # which are able to use the exact same derivative formula as the original view op + # See Note [Codegen'd {view}_copy Operators] + def create_view_copy_from_view_derivative( + self, g: NativeFunctionsViewGroup + ) -> DifferentiabilityInfo | None: + if g.view_copy is None: + return None + f = g.view_copy + + name_split_by_period = self.name.split(".", maxsplit=2) + # Append a "_copy" to the base name of the operator (but keep the overload name the same) + view_copy_name = f"{name_split_by_period[0]}_copy." 
+ ".".join( + name_split_by_period[1:] + ) + view_copy_op_name = None if self.op is None else f"{self.op}_copy" + + return DifferentiabilityInfo( + # Use the "_copy" version of name/func/op + name=view_copy_name, + func=f, + op=view_copy_op_name, + # But keep all derivative info the same + derivatives=self.derivatives, + forward_derivatives=self.forward_derivatives, + all_saved_inputs=self.all_saved_inputs, + all_saved_outputs=self.all_saved_outputs, + available_named_gradients=self.available_named_gradients, + used_named_gradients=self.used_named_gradients, + args_with_derivatives=self.args_with_derivatives, + non_differentiable_arg_names=self.non_differentiable_arg_names, + output_differentiability=self.output_differentiability, + output_differentiability_conditions=self.output_differentiability_conditions, + ) + + +def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool: + if info is None: + return False + for derivative in info.derivatives: + formula = derivative.formula + if re.search(IDENT_REGEX.format(ident), formula): + return True + return False + + +def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool: + return uses_ident(info, "retain_variables") + + +def uses_single_grad(info: DifferentiabilityInfo | None) -> bool: + return uses_ident(info, "grad") + + +# Represents a differentiable `Argument`. +# How is it different from the `Argument` type? +# - It's processed Arguments which are differentiable and only used in the +# context of the autograd codegen; +# - It can represent SelfArgument or regular Argument but not TensorOptionsArgument; +@dataclass(frozen=True) +class DifferentiableInput: + name: str + type: Type + + # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove. + cpp_type: str + + +# Represents a differentiable `Return`. +# How it it different from the `Return` type? +# - The name in `Return` is optional. Here it is always populated using the same +# `cpp.return_names()` method. +# TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant? +# - It's processed Returns which are differentiable, in compliance with the +# `output_differentiability` field defined in derivatives.yaml (if specified), +# and are only used in the context of the autograd codegen; +@dataclass(frozen=True) +class DifferentiableOutput: + name: str + type: Type + + # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove. + cpp_type: str + + +@dataclass(frozen=True) +class NativeFunctionWithDifferentiabilityInfo: + func: NativeFunction + info: dict[str, DifferentiabilityInfo] | None + fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None + + +# TODO: Update comment below since it is out of date. +def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str: + """How are we going to call the underlying implementation of a + declaration? There are two strategies: + - use_derived: we want to call the implementation on CPUDoubleType + (or a similar, derived Type instance). Because these derived + instances deal in Tensors, not Variables (it's a completely different + object, so it doesn't dispatch back to VariableType), code on + this dispatch path needs to wrap/unwrap tensors. 
If the + derived implementation takes and returns tensors, the + implementation is usually differentiable (although we also use + the derived dispatch path for non-differentiable functions + that we still want to dispatch on the derived Type instance; + e.g., size()) + - use_type: we want to call the implementation on Type, because + it is implemented concretely, and the functions it invokes will + get dispatched back to VariableType (which will ensure that they + are differentiable.) + """ + # fn is derived as long as any of its per-key differentiability infos + # has_derivatives. dispatch_strategy() is used to guard generation of fns in VariableType + # and ADInplaceOrViewType. We want to generate these functions as long as a + # derivative is defined for ANY dispatch key. + if fn.func.is_abstract or ( + fn.info is not None and any(info.has_derivatives for info in fn.info.values()) + ): + # If the function is abstract (not implemented on at::Type), we must + # call the implementation on the derived type with unpacked tensors. + + # If the function has a derivative specified and is concrete, we could + # call either implementation. We prefer the calling the derived + # type's implementation with unpacked tensors because it is more + # performant in some cases: any internal calls to other ATen functions + # won't have the history tracked. + + # If the function has a type dispatched argument (i.e. is a factory), + # we prefer calling the derived type's implementation both because it is + # more performant and to ensure factory functions return tensors with _version + # of 0 (probably not strictly necessary, but nice to have to keeps versions simple + # to understand. + + return "use_derived" + else: + # If the function is concrete (we don't have to override it) and we + # didn't declare it in derivatives.yaml, we'll assume that it is + # actually implemented out of differentiable functions. (This + # assumption might not hold, but then you'll see gradcheck fail.) + return "use_type" + + +def is_foreach_func(f: NativeFunction) -> bool: + return f.func.name.name.base.startswith("_foreach_") + + +# note(crcrpar): Most foreach functions can reference an out-place `torch` function whose schema kind +# is functional for their backward derivatives (and forward derivatives in the future), i.e., +# they would find such one in `functional_info_by_signature`. There however are some exceptions: +_foreach_with_inplace_ref = {"_foreach_zero_"} +_foreach_with_tensor_overload = { + "_foreach_add.Tensor", + "_foreach_mul.Tensor", + "_foreach_div.Tensor", +} +# The following do not support the alpha kwarg, which the nonforeach versions support. +_skip_argument_len_check = { + "_foreach_add.Scalar", + "_foreach_add_.Scalar", + "_foreach_add.ScalarList", + "_foreach_add_.ScalarList", + "_foreach_sub.Scalar", + "_foreach_sub_.Scalar", + "_foreach_sub.ScalarList", + "_foreach_sub_.ScalarList", +} + + +# Checks if `function_schema` is a native, non-foreach function which `f`, a foreach function +# reference to generate derivatives. 
+def is_reference_for_foreach( + f: NativeFunction, + function_schema: FunctionSchema, +) -> bool: + return ( + f.func.name.name.base.split("_foreach_")[-1] == function_schema.name.name.base + and ( + not function_schema.name.name.inplace + or str(f.func.name) in _foreach_with_inplace_ref + ) + and ( + str(f.func.name) in _skip_argument_len_check + or len(f.func.arguments.flat_non_out) + == len(function_schema.arguments.flat_non_out) + ) + and all( + ref_arg.type in (arg.type, getattr(arg.type, "elem", None)) + for arg, ref_arg in zip( + f.func.arguments.flat_non_out, + function_schema.arguments.flat_non_out, + ) + ) + ) + + +# TODO(crcrpar): Avoid hard coding "Default" ideally. +def gen_foreach_derivativeinfo( + foreach_function: NativeFunction, + functional_info_by_signature: dict[ + FunctionSchema, dict[str, DifferentiabilityInfo] + ], + non_functional_info_by_signature: dict[ + FunctionSchema, dict[str, DifferentiabilityInfo] + ], + dispatch_key: str = "Default", +) -> tuple[DifferentiabilityInfo | None, bool]: + """Generate DifferentiabilityInfo for out-place foreach function, return the existing one for in-place. + + The second return value indicates whether the info is generated in this function. + """ + ref_diff_info: DifferentiabilityInfo | None = None + + for function_schema, diff_info in functional_info_by_signature.items(): + if not is_reference_for_foreach(foreach_function, function_schema): + continue + ref_diff_info = diff_info[dispatch_key] + if ref_diff_info is not None: + break + # note(crcrpar): It seems like `zero`'s info isn't available in functional_info_by_signature + # while the info of `zero_` is in non_functional_info_by_signature + if ( + ref_diff_info is None + and foreach_function.func.kind() == SchemaKind.inplace + and str(foreach_function.func.name) in _foreach_with_inplace_ref + ): + for function_schema, diff_info in non_functional_info_by_signature.items(): + if not is_reference_for_foreach(foreach_function, function_schema): + continue + ref_diff_info = diff_info[dispatch_key] + if ref_diff_info is not None: + break + if ref_diff_info is None: + return None, False + + # non out-place uses the existing Derivative. 
+ if foreach_function.func.kind() == SchemaKind.inplace: + return ref_diff_info, False + + map_refarg2foreacharg, map_name2arg = {}, {} + for i, (arg, ref_arg) in enumerate( + zip( + foreach_function.func.arguments.flat_non_out, + function_schema.arguments.flat_non_out, + ) + ): + map_refarg2foreacharg[ref_arg.name] = arg.name + map_name2arg[arg.name] = arg + + all_saved_inputs, all_saved_outputs, all_var_names = [], [], [] + modified_derivative_formulas = [] + for i, derivative in enumerate(ref_diff_info.derivatives): + modified_formula = derivative.formula.replace("grad", "grads[i]").replace( + "result", "result[i]" + ) + saved_inputs, saved_outputs = [], [] + # note(crcrpar): This context seems necessary to call `cpp.argument_type` + with local.parametrize( + use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors, + use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group, + ): + for ref_input in derivative.saved_inputs: + ref_input_jit_name = ref_input.expr.split(".")[0] + mapped_name = map_refarg2foreacharg[ref_input_jit_name] + if isinstance(map_name2arg[mapped_name].type, ListType): + mapped_expr = mapped_name + "[i]" + else: + mapped_expr = mapped_name + new_expr = ref_input.expr.replace(ref_input_jit_name, mapped_expr) + modified_formula = modified_formula.replace( + cast(str, ref_input.nctype.name), new_expr + ) + + nctype = cpp.argument_type(map_name2arg[mapped_name], binds=mapped_name) + canonical_nctype = NamedCType( + nctype.name, nctype.type.remove_const_ref() + ) + saved_inputs.append( + SavedAttribute(nctype=canonical_nctype, expr=mapped_name) + ) + for ref_output in derivative.saved_outputs: + if ref_output.nctype.name == "result": + saved_outputs.append( + SavedAttribute( + nctype=NamedCType( + name="result", type=BaseCType(tensorListT) + ), + expr="result", + ) + ) + else: + raise RuntimeError("") + var_names = [map_refarg2foreacharg[var] for var in derivative.var_names] + all_var_names.extend(var_names) + all_saved_inputs.extend(saved_inputs) + all_saved_outputs.extend(saved_outputs) + modified_derivative = Derivative( + formula=modified_formula, + original_formula=derivative.formula, + var_names=tuple(var_names), + saved_inputs=tuple(saved_inputs), + saved_outputs=tuple(saved_outputs), + named_gradients=set(), + ) + modified_derivative_formulas.append(modified_derivative) + + with local.parametrize( + use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors, + use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group, + ): + args_with_derivatives = [ + Binding( + name=arg.name, + nctype=cpp.argument_type(arg, binds=arg.name), + argument=arg, + default=None, + ) + for arg in foreach_function.func.arguments.flat_non_out + if arg.name in all_var_names + ] + + forward_derivatives: list[ForwardDerivative] = [] + fw_derivative: ForwardDerivative + for fw_derivative in ref_diff_info.forward_derivatives: + var_names: list[str] = list(fw_derivative.var_names) # type: ignore[no-redef] + var_types: list[Type] = list(fw_derivative.var_types) + required_inputs_fw_grad: list[str] = [] + required_inputs_primal: list[str] = [] + if fw_derivative.required_inputs_fw_grad is not None: + required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad) + if fw_derivative.required_inputs_primal: + required_inputs_primal = list(fw_derivative.required_inputs_primal) + modified_formula = fw_derivative.formula + + # Foreach's result is TensorList + if "result" in modified_formula: + modified_formula = 
fw_derivative.formula.replace("result", "result[i]") + + for foreach_arg, ref_arg in zip( + foreach_function.func.arguments.flat_non_out, + ref_diff_info.func.func.arguments.flat_non_out, + ): + # Modify reference forward formula + if ( + isinstance(foreach_arg.type, ListType) + and not foreach_arg.type.is_tensor_like() + ): + # Assuming ScalarList + modified_formula = modified_formula.replace( + ref_arg.name, foreach_arg.name + "[i]" + ) + elif foreach_arg.type.is_tensor_like(): + # Assuming TensorList / Tensor + # assert isinstance(foreach_arg.type, ListType), f"{foreach_function.func.name}, {foreach_arg.type}" + assert isinstance(foreach_arg.type, ListType) or ( + foreach_arg.type == BaseType(BaseTy.Tensor) + and str(foreach_function.func.name) in _foreach_with_tensor_overload + ), f"{foreach_function.func.name}, {foreach_arg.type}" + for suffix in ("_p", "_t"): + curr_expr = ref_arg.name + suffix + if curr_expr in modified_formula: + new_expr = foreach_arg.name + suffix + modified_formula = modified_formula.replace(curr_expr, new_expr) + else: + # Assuming Scalar + if foreach_arg.name != ref_arg.name: + modified_formula = modified_formula.replace( + ref_arg.name, foreach_arg.name + ) + + # note(crcrpar): there should exist a cooler way... + for i, name in enumerate(var_names): + if name == ref_arg.name: + var_names[i] = foreach_arg.name + var_types[i] = foreach_arg.type + for i, name in enumerate(required_inputs_fw_grad): + if name == ref_arg.name: + required_inputs_fw_grad[i] = foreach_arg.name + for i, name in enumerate(required_inputs_primal): + if name == ref_arg.name: + required_inputs_primal[i] = foreach_arg.name + forward_derivatives.append( + ForwardDerivative( + formula=modified_formula, + var_names=tuple(var_names), + var_types=tuple(var_types), + required_inputs_fw_grad=tuple(required_inputs_fw_grad), + required_inputs_primal=tuple(required_inputs_primal), + required_original_self_value=fw_derivative.required_original_self_value, + is_reusing_outplace_formula=fw_derivative.is_reusing_outplace_formula, + ) + ) + + return ( + DifferentiabilityInfo( + name=foreach_function.func.name.name.base, + func=foreach_function, + op=f"Foreach{ref_diff_info.op}{foreach_function.func.name.overload_name}", + derivatives=modified_derivative_formulas, + forward_derivatives=forward_derivatives, + all_saved_inputs=tuple(set(all_saved_inputs)), + all_saved_outputs=tuple(set(all_saved_outputs)), + available_named_gradients=(), + used_named_gradients=set(), + args_with_derivatives=args_with_derivatives, + non_differentiable_arg_names=[], + output_differentiability=None, + output_differentiability_conditions=None, + ), + True, + ) + + +def match_differentiability_info( + native_functions: list[NativeFunction], + differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]], +) -> list[NativeFunctionWithDifferentiabilityInfo]: + """Sets the "derivative" key on declarations to matching autograd function + In-place functions will use the out-of-place derivative definition if there + is no in-place specific derivative. 
+ """ + + functional_info_by_signature = { + schema.signature(strip_default=True): info_dict + for schema, info_dict in differentiability_infos.items() + if schema.kind() == SchemaKind.functional + } + non_functional_info_by_signature = { + schema.signature(strip_default=True): info_dict + for schema, info_dict in differentiability_infos.items() + if schema.kind() != SchemaKind.functional + } + + def find_info( + f: NativeFunction, + ) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]: + # Don't bother matching info to generated out= variants + if "generated" in f.tags and f.func.kind() == SchemaKind.out: + return None, False + + # (1) Check for an exact match + if f.func in differentiability_infos: + return differentiability_infos[f.func], True + + # (2) If no exact match, check if the out-of-place variant + # of this operator has a match. + # i.e mul() for mul_() or mul_out() + # note(crcrpar): Check foreach or not because in-place foreach functions use backward defined for the existing + # native functions instead of the out-place counterparts. + f_sig = f.func.signature(strip_default=True) + if f_sig in functional_info_by_signature and not is_foreach_func(f): + return functional_info_by_signature[f_sig], False + + # (3) Some operators have a derivative explicitly defined for the mutable + # variant, but get a code-generated out-of-place variant which does *not* + # come with a derivative formula. + # For the generated out-of-place variant, use the mutable variant's formula + # if it exists. + if "generated" in f.tags and f_sig in non_functional_info_by_signature: + info_dict = non_functional_info_by_signature[f_sig] + # See https://github.com/pytorch/pytorch/pull/76320/files#r874816389 + assert not any( + any("self" in str(inpt.nctype.name) for inpt in info.all_saved_inputs) + for info in info_dict.values() + ), f"""\ +Attempted to convert a derivative formula for a mutable operator + to be used by automatically by its functional variant ("{str(f.func)}"). + this is not currently supported (we'd need to fix up the formula in the codegen).""" + return info_dict, False + + # (4) Generate derivative information of foreach functions if none is defined in `derivatives.yaml` + if is_foreach_func(f): + assert f.func not in differentiability_infos + diff_info, is_generated = gen_foreach_derivativeinfo( + f, + functional_info_by_signature, + non_functional_info_by_signature, + ) + if diff_info is None: + return None, False + # TODO(crcrpar): Avoid hard coding "Default" ideally. + diff_info_dict = {"Default": diff_info} + if is_generated: + differentiability_infos[f.func] = diff_info_dict + functional_info_by_signature[f.func] = diff_info_dict + return diff_info_dict, is_generated + + return None, False + + result: list[NativeFunctionWithDifferentiabilityInfo] = [] + for f in native_functions: + info_dict, is_exact_match = find_info(f) + + # Currently, the '.strides()' to 'strides_or_error' replacement does not support + # 'self' derivatives of an inplace function, so we must check for this case. 
+        if f.func.kind() == SchemaKind.inplace and (info_dict is not None):
+            for info in info_dict.values():
+                for derivative in info.derivatives:
+                    if "self" in derivative.var_names:
+                        for saved_input in derivative.saved_inputs:
+                            assert "strides_or_error" not in saved_input.expr, (
+                                "Calling '.strides()' in the 'self' derivative formula of an "
+                                f"in-place function is not supported: {f.func}"
+                            )
+
+        if not info_dict:
+            result.append(
+                NativeFunctionWithDifferentiabilityInfo(
+                    func=f, info=None, fw_derivatives=None
+                )
+            )
+            continue
+
+        fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {}
+        for key, info in info_dict.items():
+            if not info.forward_derivatives:
+                fw_derivative_dict[key] = []
+                continue
+
+            forward_derivatives = info.forward_derivatives
+
+            # For functions that have a single def for out-of-place and inplace (like abs())
+            if f.func.kind() == SchemaKind.inplace:
+                # For inplace functions there is a little bit of work to do:
+                #  1) Validate the formula and make sure the input that is modified is not used:
+                #     - If there is a formula for the inplace variant of the function (is_exact_match == True) then
+                #       we make sure that the original value of the input that is being modified inplace (self_p) is
+                #       not used in the formula. Note that the formula can use "original_self_p" here and that would
+                #       trigger a clone of the original input.
+                #     - If we are re-using the out-of-place formula (is_exact_match == False) then we replace every
+                #       occurrence of self_p and self_t by original_self_p and original_self_t. These will be
+                #       populated by a cloned version of the original input (either the clone done by the backward AD
+                #       logic if self is also used in a backward formula or a special clone that we add).
+                #  2) At this point, there cannot be a self_p in the formula.
+                #  3) Change "result" into "self_p" as, by design, in the inplace function codegen the result is
+                #     simply called self (as it is modified inplace).
+                #  4) Update the required primals data in case it used to contain "result" but should now contain
+                #     "self".
+                #  5) If it is not an exact match, the user formula is not modifying the existing forward grad
+                #     inplace as it should. So add some code that makes sure that we do so if the forward grad
+                #     already exists.
+
+                assert (
+                    len(info.forward_derivatives) == 1
+                )  # Only single output inplace should exist
+                fw_info = info.forward_derivatives[0]
+                formula = fw_info.formula
+
+                def replace_self_with_original_self(formula: str, postfix: str) -> str:
+                    def repl(m: re.Match[str]) -> str:
+                        return f"{m.group(1)}original_self{postfix}{m.group(2)}"
+
+                    return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)
+
+                if re.search(IDENT_REGEX.format("self_p"), formula):
+                    if is_exact_match:
+                        # For manually defined formulas, don't allow the original value to be used
+                        raise RuntimeError(
+                            f'The formula for "{f.func.name}" is using the original value of self '
+                            "that is being modified inplace. This would lead to wrong forward gradients. "
+                            'Please use "result" in the formula only.'
+ ) + else: + # When the original formula is out of place, we save a clone of the primal + # value to be able to access this value if needed + # replace "self_p"/"self_t" from the formula by "original_self_p"/"original_self_t" + formula = replace_self_with_original_self(formula, "_p") + formula = replace_self_with_original_self(formula, "_t") + + # replace "result" from the formula by "self_p" + def repl(m: re.Match[str]) -> str: + return f"{m.group(1)}self_p{m.group(2)}" + + formula = re.sub(IDENT_REGEX.format("result"), repl, formula) + + required_primals = fw_info.required_inputs_primal + if re.search(IDENT_REGEX.format("self_p"), formula): + required_primals = ( + required_primals + ("self",) if required_primals else ("self",) + ) + + if not is_exact_match: + # NOTE [In-place forward AD formula Optimization] + # + # This optimization transforms the formula to directly do inplace, i.e. + # instead of self_t.copy_(self_t.op()) we do self_t.op_() when the following are met: + # + # 1) the formula satisfies the pattern: "self_t.op(*args)" + # 2) "op" in (1) needs to be the same as the op the derivative is for + # + # (2) may seem too strict, but currently the only ops that satisfy (1) also satisfy (2) + # If there is a need, we can relax (2) to allow any op that has an in-place variant + is_single_method_on_self_t = False + directly_do_inplace = False + op_name: str | None = None + between_parens: str | None = None + match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula) + if match: + op_name, between_parens = match.group(1), match.group(2) + + # We want to... + # Match: self_t.op1(other_p.op2(arg)) + # Avoid: self_t.op1(args) + self_t.op2(args) + # Avoid: self_t.op1(other_p.op2(arg)) + self_t.op2(args) + def check_parens_nest_level_gt_zero(s: str) -> bool: + level = 1 + for ch in s: + if ch == ")": + level -= 1 + if level == 0: + return False + if ch == "(": + level += 1 + return True + + is_single_method_on_self_t = check_parens_nest_level_gt_zero( + between_parens + ) + directly_do_inplace = ( + is_single_method_on_self_t and op_name == info.name + ) + + if directly_do_inplace: + assert op_name is not None + assert between_parens is not None + formula = f"self_t_raw.defined() ? self_t_raw.{op_name}_({between_parens}) : {formula}" + else: + # Make sure that the forward grad is modified inplace when the original formula + # is out of place + formula = f"self_t_raw.defined() ? 
self_t_raw.copy_({formula}) : {formula}" + + required_original_self_value = bool( + re.search(IDENT_REGEX.format("original_self_p"), formula) + ) or bool(re.search(IDENT_REGEX.format("original_self_t"), formula)) + + forward_derivatives = [ + ForwardDerivative( + formula=formula, + var_names=("self",), + var_types=fw_info.var_types, + required_inputs_fw_grad=fw_info.required_inputs_fw_grad, + required_inputs_primal=required_primals, + required_original_self_value=required_original_self_value, + is_reusing_outplace_formula=not is_exact_match, + ), + ] + + fw_derivative_dict[key] = forward_derivatives + + result.append( + NativeFunctionWithDifferentiabilityInfo( + func=f, info=info_dict, fw_derivatives=fw_derivative_dict + ) + ) + + return result + + +def is_differentiable( + name: str, type: Type, info: DifferentiabilityInfo | None +) -> bool: + return type.is_tensor_like() and ( + info is None or name not in info.non_differentiable_arg_names + ) + + +def gen_differentiable_outputs( + fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default" +) -> list[DifferentiableOutput]: + f = fn.func + info = fn.info[key] if fn.info else None + outputs: list[DifferentiableOutput] = [ + DifferentiableOutput( + name=name, + type=ret.type, + cpp_type=cpp.return_type(ret, symint=True).cpp_type(), + ) + for name, ret in zip(cpp.return_names(f), f.func.returns) + ] + output_differentiability = info.output_differentiability if info else None + if output_differentiability is not None: + if len(output_differentiability) != len(outputs): + raise RuntimeError( + f"The length of output_differentiability ({len(output_differentiability)}), " + f"does not match the number of outputs ({len(outputs)})." + ) + differentiable_outputs: list[DifferentiableOutput] = [] + if False in output_differentiability and f.func.kind() == SchemaKind.inplace: + raise RuntimeError( + "output_differentiability=False for inplace operation (version_counter won't get updated)" + ) + for differentiable, output in zip(output_differentiability, outputs): + if differentiable: + differentiable_outputs.append(output) + return differentiable_outputs + candidate_differentiable_outputs = list( + filter(lambda r: is_differentiable(r.name, r.type, info), outputs) + ) + if uses_single_grad(info): + return candidate_differentiable_outputs[:1] + else: + return candidate_differentiable_outputs diff --git a/lib/python3.10/site-packages/torchgen/api/cpp.py b/lib/python3.10/site-packages/torchgen/api/cpp.py new file mode 100644 index 0000000000000000000000000000000000000000..c657570ee3e2494053e0a618d5707bcb5def0d19 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/api/cpp.py @@ -0,0 +1,472 @@ +from __future__ import annotations + +from typing import Sequence + +from torchgen import local +from torchgen.api.types import ( + ArgName, + ArrayCType, + ArrayRefCType, + BaseCType, + BaseTypeToCppMapping, + Binding, + boolT, + ConstRefCType, + CType, + dimnameListT, + intArrayRefT, + iTensorListRefT, + ListCType, + longT, + MutRefCType, + NamedCType, + OptionalCType, + optionalIntArrayRefT, + optionalSymIntArrayRefT, + scalarT, + SpecialArgName, + symIntArrayRefT, + SymIntT, + tensorListT, + tensorOptionsT, + tensorT, + TupleCType, + VectorCType, + voidT, +) +from torchgen.model import ( + Argument, + Arguments, + BaseTy, + BaseType, + FunctionSchema, + ListType, + NativeFunction, + OptionalType, + Return, + SelfArgument, + TensorOptionsArguments, + Type, +) +from torchgen.utils import assert_never + + +# This file describes the translation of JIT 
schema to the public C++ +# API, which is what people use when they call functions like at::add. +# +# Prominent characteristics of the C++ API: +# +# - dtype, layout, device and pin_memory are collected into +# a single C++ type TensorOptions (the native functions API +# also has this, but tensor options is really most relevant +# for the C++ API; it makes calling kwarg factory functions +# pleasant) +# +# - defaulting lives here (in fact, the dispatcher is completely +# oblivious of defaults!) +# +# BTW: policy on name collisions: we try not to have types with +# collisions, but functions are fair game to collide + + +def name( + func: FunctionSchema, + *, + faithful_name_for_out_overloads: bool = False, + symint_overload: bool = False, +) -> str: + name = str(func.name.name) + if symint_overload: + name += "_symint" + if func.is_out_fn(): + if faithful_name_for_out_overloads: + name += "_outf" + else: + name += "_out" + + return name + + +# Translation of "value types" in JIT schema to C++ API type. Value +# types look the same no matter if they are argument types or return +# types. Returns None if the type in question is not a value type. +def valuetype_type( + t: Type, + *, + binds: ArgName, + mutable: bool = True, + remove_non_owning_ref_types: bool = False, + symint: bool = False, +) -> NamedCType | None: + if isinstance(t, BaseType): + if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar: + return None + elif str(t) == "SymInt": + if symint: + return NamedCType(binds, BaseCType(SymIntT)) + else: + return NamedCType(binds, BaseCType(longT)) + if remove_non_owning_ref_types: + if t.name == BaseTy.str: + raise AssertionError( + "string ref->value conversion: not implemented yet" + ) + # All other BaseType currently map directly to BaseCppTypes. + return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name])) + elif isinstance(t, OptionalType): + elem = valuetype_type(t.elem, binds=binds, mutable=mutable, symint=symint) + if elem is None: + return None + return NamedCType(binds, OptionalCType(elem.type)) + elif isinstance(t, ListType): + if str(t.elem) == "bool": + assert t.size is not None + return NamedCType(binds, ArrayCType(BaseCType(boolT), t.size)) + else: + return None + else: + raise AssertionError(f"unrecognized type {repr(t)}") + + +# Translation of types occurring in JIT arguments to a C++ argument type. +# If remove_non_owning_ref_types is set, we'll guarantee that the outputed CType is not a non-owning reference type. +# For example, we'll return std::vector instead of IntArrayRef. 
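+# (Roughly: an `int[]` argument becomes `at::IntArrayRef` by default, or `std::vector<int64_t>`
+# when remove_non_owning_ref_types=True, and a mutable `Tensor` becomes `at::Tensor &` unless
+# use_const_ref_for_mutable_tensors is in effect.)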
+# See Note [translation from C++ reference to value types] +def argumenttype_type( + t: Type, + *, + mutable: bool, + binds: ArgName, + remove_non_owning_ref_types: bool = False, + symint: bool = False, +) -> NamedCType: + # If it's a value type, do the value type translation + r = valuetype_type( + t, + binds=binds, + mutable=mutable, + symint=symint, + remove_non_owning_ref_types=remove_non_owning_ref_types, + ) + if r is not None: + return r + + if isinstance(t, BaseType): + if t.name == BaseTy.Tensor: + if mutable and not local.use_const_ref_for_mutable_tensors(): + return NamedCType(binds, MutRefCType(BaseCType(tensorT))) + else: + return NamedCType(binds, ConstRefCType(BaseCType(tensorT))) + elif t.name == BaseTy.Scalar: + return NamedCType(binds, ConstRefCType(BaseCType(scalarT))) + else: + raise AssertionError(f"base type should have been value type {t}") + elif isinstance(t, OptionalType): + if str(t.elem) == "Tensor": + if mutable and not local.use_const_ref_for_mutable_tensors(): + return NamedCType( + binds, MutRefCType(BaseCType(tensorT)) + ) # TODO: fix this discrepancy + else: + return NamedCType( + binds, ConstRefCType(OptionalCType(BaseCType(tensorT))) + ) + elif str(t.elem) == "Scalar": + return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT)))) + elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int": + return NamedCType(binds, BaseCType(optionalIntArrayRefT)) + elif isinstance(t.elem, ListType) and str(t.elem.elem) == "SymInt": + if symint: + return NamedCType(binds, BaseCType(optionalSymIntArrayRefT)) + else: + return NamedCType(binds, BaseCType(optionalIntArrayRefT)) + elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint) + return NamedCType(binds, OptionalCType(elem.type)) + elif isinstance(t, ListType): + # TODO: remove these special cases, ArrayRef fallthrough works fine + if str(t.elem) == "int": + if remove_non_owning_ref_types: + return NamedCType(binds, VectorCType(BaseCType(longT))) + else: + return NamedCType(binds, BaseCType(intArrayRefT)) + if str(t.elem) == "SymInt": + if remove_non_owning_ref_types: + if symint: + return NamedCType(binds, VectorCType(BaseCType(SymIntT))) + else: + return NamedCType(binds, VectorCType(BaseCType(longT))) + else: + if symint: + return NamedCType(binds, BaseCType(symIntArrayRefT)) + else: + return NamedCType(binds, BaseCType(intArrayRefT)) + if str(t.elem) == "Tensor": + if local.use_ilistref_for_tensor_lists(): + return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT))) + else: + return NamedCType(binds, BaseCType(tensorListT)) + elif str(t.elem) == "Scalar": + return NamedCType(binds, ArrayRefCType(BaseCType(scalarT))) + elif str(t.elem) == "Dimname": + return NamedCType(binds, BaseCType(dimnameListT)) + elif str(t.elem) == "Tensor?": + return NamedCType( + binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))) + ) + elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint) + return NamedCType(binds, ArrayRefCType(elem.type)) + else: + raise AssertionError(f"unrecognized type {repr(t)}") + + +# Translate a JIT argument into its C++ type +def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> NamedCType: + return argumenttype_type(a.type, mutable=a.is_write, symint=symint, binds=binds) + + +# Translation of a (non-multi) return type from JIT to C++ +# N.B: returntype_type returns a CType, not a NamedCType. +# This is mostly because of the mismatch between return types and return names. +# e.g. 
a function with a return type of 'void' has 0 return names, +# and a function with a return type of 'std::tuple' has >1 return name. +def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType: + # placeholder is ignored + # NB: symint is ALWAYS respected for return types. So symint argument + # here is IGNORED + r = valuetype_type(t, binds="__placeholder__", mutable=mutable, symint=True) + if r is not None: + return r.type + + if isinstance(t, BaseType): + if t.name == BaseTy.Tensor: + if mutable: + if local.use_const_ref_for_mutable_tensors(): + return ConstRefCType(BaseCType(tensorT)) + else: + return MutRefCType(BaseCType(tensorT)) + else: + # Note [Tensor Copy Returns] + # Currently, we use "Argument.is_write" to determine + # whether or not Tensor return types should be copies or references. + # If that ever changes, take a look at other locations of this note! + return BaseCType(tensorT) + elif t.name == BaseTy.Scalar: + return BaseCType(scalarT) + elif isinstance(t, ListType): + assert ( + not mutable + ), "Native functions should never return a mutable tensor list. They should return void." + elem = returntype_type(t.elem, mutable=False) + assert t.size is None, f"fixed size list returns not supported: {t}" + return VectorCType(elem) + elif isinstance(t, OptionalType): + elem = returntype_type(t.elem, mutable=mutable) + if str(t.elem) == "Tensor": + return OptionalCType(elem) + + raise AssertionError(f"unrecognized return type {t}") + + +# Translation of a single return to its C++ type +def return_type(r: Return, *, symint: bool = False) -> CType: + return returntype_type(r.type, mutable=r.is_write, symint=symint) + + +# Translation of a full (possibly multi) return from JIT to its C++ type +def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType: + if len(rs) == 0: + return BaseCType(voidT) + elif len(rs) == 1: + return return_type(rs[0], symint=symint) + else: + return TupleCType([return_type(r, symint=symint) for r in rs]) + + +def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]: + returns: list[str] = [] + for i, r in enumerate(f.func.returns): + # If we have an inplace function, the return argument is + # implicitly named self. + # TODO: Consider incorporating this into the data model + if f.func.name.name.inplace: + assert i == 0, "illegal inplace function with multiple returns" + name = "self" + # If we are out function, the name is the name of the + # corresponding output function (r.name will get recorded + # in field_name later.) + elif f.func.is_out_fn(): + name = f.func.arguments.out[i].name + # If the return argument is explicitly named... 
+ elif r.name: + name_conflict = any( + r.name == a.name for a in f.func.schema_order_arguments() + ) + if name_conflict and not f.func.is_out_fn(): + name = f"{r.name}_return" + else: + name = r.name + # If there is no explicit name and no fallback name was passed in, we just name the output result, + # unless it's a multi-return, in which case it's result0, + # result1, etc (zero-indexed) + else: + name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}" + returns.append(name) + return returns + + +JIT_TO_CPP_DEFAULT = { + "False": "false", + "True": "true", + "None": "::std::nullopt", # UGH this one is type directed + "Mean": "at::Reduction::Mean", + "[]": "{}", + "contiguous_format": "c10::MemoryFormat::Contiguous", + "long": "at::kLong", +} + + +# Convert a JIT default into C++ expression representing the default +def default_expr(d: str, t: Type, *, symint: bool) -> str: + if d == "None" and str(t) == "Tensor?": + return "{}" + if isinstance(t, BaseType) and t.name is BaseTy.str: + # Schema allows single quotes but C++ needs double + if len(d) >= 2 and d[0] == "'" and d[-1] == "'": + s = "" + i = 1 + while i + 1 < len(d): + if d[i] != "\\": + if d[i] == '"': + s += '\\"' + else: + s += d[i] + i += 1 + else: + if d[i + 1] == "'": + s += "'" + else: + s += d[i : i + 2] + i += 2 + + return f'"{s}"' + + if isinstance(t, OptionalType): + if d == "None": + return "::std::nullopt" + + return default_expr(d, t.elem, symint=symint) + + if isinstance(t, ListType): + if d.startswith("[") and d.endswith("]"): + return "{" + d[1:-1] + "}" + elif symint and d.isdigit() and str(t.elem) == "SymInt": + return f"c10::SymInt({d})" + elif t.size is None: + # NOTE: Sized lists can have scalar defaults + raise ValueError(f"Expected a list default '[...]' but found: '{d}'") + + return JIT_TO_CPP_DEFAULT.get(d, d) + + +# Convert an argument into its C++ API form + + +def argument( + a: Argument | TensorOptionsArguments | SelfArgument, + *, + cpp_no_default_args: set[str], + method: bool, + faithful: bool, + symint: bool = False, + has_tensor_options: bool, +) -> list[Binding]: + def sub_argument( + a: Argument | TensorOptionsArguments | SelfArgument, + ) -> list[Binding]: + return argument( + a, + cpp_no_default_args=cpp_no_default_args, + method=method, + faithful=faithful, + symint=symint, + has_tensor_options=has_tensor_options, + ) + + if isinstance(a, Argument): + binds: ArgName + if a.name == "memory_format" and has_tensor_options: + binds = SpecialArgName.possibly_redundant_memory_format + else: + binds = a.name + default: str | None = None + if a.name not in cpp_no_default_args and a.default is not None: + default = default_expr(a.default, a.type, symint=symint) + return [ + Binding( + nctype=argument_type(a, binds=binds, symint=symint), + name=a.name, + default=default, + argument=a, + ) + ] + elif isinstance(a, TensorOptionsArguments): + if faithful: + return ( + sub_argument(a.dtype) + + sub_argument(a.layout) + + sub_argument(a.device) + + sub_argument(a.pin_memory) + ) + else: + default = None + # Enforced by NativeFunction.__post_init__ + assert "options" not in cpp_no_default_args + if all(x.default == "None" for x in a.all()): + default = "{}" + elif a.dtype.default == "long": + default = "at::kLong" # TODO: this is wrong + return [ + Binding( + nctype=NamedCType("options", BaseCType(tensorOptionsT)), + name="options", + default=default, + argument=a, + ) + ] + elif isinstance(a, SelfArgument): + if method: + # Caller is responsible for installing implicit this in context! 
+ return [] + else: + return sub_argument(a.argument) + else: + assert_never(a) + + +def arguments( + arguments: Arguments, + *, + faithful: bool, + symint: bool = False, + method: bool, + cpp_no_default_args: set[str], +) -> list[Binding]: + args: list[Argument | TensorOptionsArguments | SelfArgument] = [] + if faithful: + args.extend(arguments.non_out) + args.extend(arguments.out) + else: + args.extend(arguments.out) + args.extend(arguments.non_out) + return [ + r.no_default() if faithful else r + for a in args + for r in argument( + a, + faithful=faithful, + symint=symint, + method=method, + has_tensor_options=arguments.tensor_options is not None, + cpp_no_default_args=cpp_no_default_args, + ) + ] diff --git a/lib/python3.10/site-packages/torchgen/api/dispatcher.py b/lib/python3.10/site-packages/torchgen/api/dispatcher.py new file mode 100644 index 0000000000000000000000000000000000000000..103e6cf429907d1577c3d9caca6f3e28de9e129a --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/api/dispatcher.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +import itertools +from typing import Sequence + +from torchgen.api import cpp +from torchgen.api.types import ArgName, Binding, CType, NamedCType +from torchgen.model import ( + Argument, + FunctionSchema, + Return, + SelfArgument, + TensorOptionsArguments, + Type, +) +from torchgen.utils import assert_never, concatMap + + +# This file describes the translation of JIT schema to the dispatcher +# API, the *unboxed* calling convention by which invocations through +# the dispatcher are made. Historically, the dispatcher API matched +# the C++ API, but with the establishment of the boxed API, we've +# made changes to the dispatcher API to so that the unboxed API +# better aligns with the boxed API. The dispatcher API hooks heavily +# into our template based boxing/unboxing machinery, so changes +# to this convention will usually need template updates too. +# +# Prominent characteristics of the dispatcher API: +# +# - dtype, layout, device and pin_memory are represented as separate +# arguments. +# + + +def name(func: FunctionSchema) -> str: + return cpp.name(func) + + +def argumenttype_type( + t: Type, + *, + mutable: bool, + binds: ArgName, + remove_non_owning_ref_types: bool = False, + symint: bool = True, +) -> NamedCType: + # This is a faux amis. If it makes sense in the future to add + # more special cases here, or invert things so cpp.argument_type + # calls this, or just completely inline the function, please do + # it. + return cpp.argumenttype_type( + t, + mutable=mutable, + binds=binds, + symint=symint, + remove_non_owning_ref_types=remove_non_owning_ref_types, + ) + + +def argument_type( + a: Argument, + *, + binds: ArgName, + remove_non_owning_ref_types: bool = False, + symint: bool = True, +) -> NamedCType: + return argumenttype_type( + a.type, + mutable=a.is_write, + binds=binds, + remove_non_owning_ref_types=remove_non_owning_ref_types, + symint=symint, + ) + + +def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType: + # At present, there is no difference. But there could be! 
+ return cpp.returns_type(rs, symint=symint) + + +def jit_arguments(func: FunctionSchema) -> list[Argument]: + def to_argument( + a: Argument | TensorOptionsArguments | SelfArgument, + ) -> list[Argument]: + if isinstance(a, Argument): + return [a] + elif isinstance(a, SelfArgument): + return [a.argument] + elif isinstance(a, TensorOptionsArguments): + return [a.dtype, a.layout, a.device, a.pin_memory] + else: + assert_never(a) + + return list( + concatMap( + to_argument, + itertools.chain( + func.arguments.positional, func.arguments.kwarg_only, func.arguments.out + ), + ) + ) + + +def argument( + a: Argument, *, remove_non_owning_ref_types: bool = False, symint: bool = True +) -> Binding: + return Binding( + nctype=argument_type( + a, + binds=a.name, + remove_non_owning_ref_types=remove_non_owning_ref_types, + symint=symint, + ), + name=a.name, + argument=a, + ) + + +def arguments(func: FunctionSchema, *, symint: bool = True) -> list[Binding]: + return [argument(a, symint=symint) for a in jit_arguments(func)] diff --git a/lib/python3.10/site-packages/torchgen/api/functionalization.py b/lib/python3.10/site-packages/torchgen/api/functionalization.py new file mode 100644 index 0000000000000000000000000000000000000000..93667e39b17fa4ba82414fa92bb7200faf6f6515 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/api/functionalization.py @@ -0,0 +1,199 @@ +from __future__ import annotations + +from torchgen.api import dispatcher +from torchgen.api.types import ( + BaseCppType, + BaseCType, + Binding, + boolT, + ConstRefCType, + CType, + longT, + NamedCType, + tensorT, +) +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + FunctionSchema, + NativeFunction, + NativeFunctionsViewGroup, +) + + +# This file describes the translation of JIT schema to API's used +# when creating view lambdas that are used by the functionalization pass. +# There are two types of lambdas: forward lambdas and reverse lambdas. +# These API's mostly follow the dispatcher API, with a few quirks: +# - The lambda capture has to convert reference types to value types +# - While the forward lambda just directly calls into the at::_ops API +# (following the dispatcher convention), the logic here for the reverse lambda +# is responsible for generating both the call-site, and the declarations +# (which are implemented manually in the at::functionalization::impl namespace). + +# The lambdas generated for each view op in the functionalization pass are of the form +# [capture_arguments](outer_arguments) -> returns_type { +# return name(inner_arguments); +# } + +# Define some specific lambda input arguments. 
+base_binding = Binding( + name="base", + nctype=NamedCType(name="base", type=ConstRefCType(BaseCType(tensorT))), + argument=Argument( + name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None + ), + default=None, +) +mutated_view_binding = Binding( + name="mutated_view", + nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))), + argument=Argument( + name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None + ), + default=None, +) +mutated_view_idx_binding = Binding( + name="mutated_view_idx", + nctype=NamedCType(name="mutated_view_idx", type=BaseCType(longT)), + argument=Argument( + name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None + ), + default=None, +) +reapply_views_binding = Binding( + name="reapply_views", + nctype=NamedCType(name="reapply_views", type=BaseCType(boolT)), + argument=Argument( + name="reapply_views", type=BaseType(BaseTy.bool), default=None, annotation=None + ), + default=None, +) + +InverseReturnModeT = BaseCppType("at::functionalization", "InverseReturnMode") +inverse_return_mode_binding = Binding( + name="inverse_return_mode", + nctype=NamedCType(name="inverse_return_mode", type=BaseCType(InverseReturnModeT)), + argument=Argument( + name="inverse_return_mode", + # NB: not actually a bool but it doesn't matter because this isn't used + type=BaseType(BaseTy.bool), + default=None, + annotation=None, + ), + default=None, +) + + +# The lambda capture itself doesn't have a name. +# The name returned here corresponds to the name of the inner function called by the lambda. +def name( + g: NativeFunctionsViewGroup, + *, + is_reverse: bool, + include_namespace: bool, + reapply_views: bool | None = None, +) -> str: + if reapply_views is None: + # reapply_views is only important for the fwd lambda, + # since we always plumb the runtime "reapply_views" argument into the reverse function. + assert is_reverse + if is_reverse: + return reverse_name(g.view, include_namespace) + # in the forward case, we just directly call into the at::_ops API (so we always need the namespace) + assert include_namespace + assert g.view_copy is not None + api_name = ( + g.view.func.name.unambiguous_name() + if reapply_views + else g.view_copy.func.name.unambiguous_name() + ) + return f"at::_ops::{api_name}::call" + + +def reverse_name(f: NativeFunction, include_namespace: bool) -> str: + # for the reverse: we plumb the "reapply_views" flag into that function and support + # both copy and non-copy variants. (We could avoid doing that, but that would require + # writing out twice as many view inverse functions). + api_name = f.func.name.unambiguous_name() + # in the reverse case, we codegen both the call-sites (which need the full namespace) and the declarations (which don't) + if include_namespace: + return f"at::functionalization::FunctionalInverses::{api_name}_inverse" + else: + return f"{api_name}_inverse" + + +def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]: + # capture arguments include all arguments except `self`. 
+ # Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture), + # So any reference types (IntArrayRef) need to be converted to value types (vector) + args = func.arguments.flat_all + assert args[0].type == BaseType(BaseTy.Tensor) + non_self_args = args[1:] + non_self_value_bindings = [ + dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args + ] + + all_bindings = [ + inverse_return_mode_binding if is_reverse else reapply_views_binding + ] + all_bindings.extend(non_self_value_bindings) + return all_bindings + + +def returns_type(func: FunctionSchema) -> CType: + # Assertion: all view ops return tensor-like outputs + assert len(func.returns) >= 1 + for ret in func.returns: + assert ret.type.is_tensor_like() + # However, the return type of the lambda is always an individual tensor. + # For multi-tensor outputs, each tensor needs to be tracked individually. + return BaseCType(tensorT) + + +def outer_arguments(*, is_reverse: bool) -> list[Binding]: + if is_reverse: + return [base_binding, mutated_view_binding, mutated_view_idx_binding] + else: + return [base_binding, mutated_view_idx_binding] + + +def inner_call_index(func: FunctionSchema) -> Binding | None: + # For view ops that return multiple tensors (like `split`), we generate a separate lambda for each output. + # When we replay a view op that returns multiple tensors, we need to index into the output appropriately + if len(func.returns) > 1 or ( + len(func.returns) == 1 and func.returns[0].type.is_list_like() + ): + return mutated_view_idx_binding + return None + + +def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]: + args = func.arguments.flat_all + assert args[0].type == BaseType(BaseTy.Tensor) + non_self_args = args[1:] + # The forward lambda calls the at::_ops API, while the reverse lambda calls the view inverse API. + # Both of these follow the dispatcher API. + non_self_bindings = [dispatcher.argument(a) for a in non_self_args] + if not is_reverse: + # the forward lambda swaps out the original tensor argument with the lambd arg "base" + return [base_binding] + non_self_bindings + else: + # the reverse lambda does the same, but with an additional "mutated_view" arg + # additionally, we have a calling convention: for view ops that return multiple tensor outputs + # their corresponding view_inverse function takes in an additional index argument. 
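+        # (e.g. the inverse of a multi-output view like `split` receives
+        # `mutated_view_idx` so it knows which output is being written back into the base.)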
+ index_binding = inner_call_index(func) + if index_binding is not None: + return [ + base_binding, + mutated_view_binding, + inverse_return_mode_binding, + index_binding, + ] + non_self_bindings + else: + return [ + base_binding, + mutated_view_binding, + inverse_return_mode_binding, + ] + non_self_bindings diff --git a/lib/python3.10/site-packages/torchgen/api/lazy.py b/lib/python3.10/site-packages/torchgen/api/lazy.py new file mode 100644 index 0000000000000000000000000000000000000000..cfffa516b656b8f479b5bfe16d4a4620fc35f9b0 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/api/lazy.py @@ -0,0 +1,467 @@ +from __future__ import annotations + +from typing import Any + +from torchgen.api.types import ( + BaseCppType, + BaseCType, + boolT, + CType, + deviceT, + doubleT, + generatorT, + layoutT, + ListCType, + longT, + memoryFormatT, + NamedCType, + OptionalCType, + scalarT, + scalarTypeT, + stringT, + SymIntT, + VectorCType, +) +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + FunctionSchema, + ListType, + OperatorName, + OptionalType, + Return, + TensorOptionsArguments, + Type, +) + + +_valueT: BaseCppType | None = None + + +# A ValueT is an IR type which represents the computation of a Tensor. In other +# words, a PyTorch user will do operations on lazy tensors, and each output lazy +# tensor internally tracks a ValueT representing the IR node that would have +# actually produced the value of this tensor for real. +# +# This is configurable because different lazy tensor backends (LTC vs XLA) will +# have different IR representations. (Though, arguably, after unification they +# shouldn't!) +def getValueT() -> BaseCppType: + global _valueT + if not _valueT: + raise NotImplementedError( + "The value type needs to be set with setValueT() in run_gen_lazy_tensor()" + ) + + return _valueT + + +def setValueT(val: BaseCppType) -> None: + global _valueT + _valueT = val + + +# this is a bad hack. I need to refactor the data model to represent each arg in the schema as an object, +# making it easier to represent special properties of an arg. +tensorListValueT = BaseCppType("torch::lazy", "Value") + + +def process_ir_type( + typ: Type, properties: LazyIrProperties, *, symint: bool +) -> BaseCType | VectorCType | OptionalCType | ListCType: + """ + This function takes a type from NativeFunctions and converts it for use with + lazy tensor codegen. + + Type conversion for lazy currently consists of + (1) changing at::Tensors into lazy::Values + (2) wrapping everything in a BaseCType + (3) making cpp-reference types into cpp-value types (e.g. vector instead of IntArrayRef) + + (1) converts at::Tensors to lazy::Values (which wrap lazy::Nodes, with which Lazy IR represents tensors.) + There is special handling for Optional[Tensor] or List[Tensor], etc- hence 'tensor-like' + + This is incomplete- there are assertions in places that it's expected to need to add + more types as the codegen is used with more operators. 
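+
+    For example (with the backend-configured value type from getValueT()):
+    `Tensor` maps to BaseCType(getValueT()), `int[]` maps to
+    VectorCType(BaseCType(longT)), and `Tensor?` maps to
+    OptionalCType(BaseCType(getValueT())).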
+ """ + if isinstance(typ, BaseType): + if typ.name == BaseTy.Tensor: + return BaseCType(getValueT()) + elif typ.name == BaseTy.Scalar: + if properties.TreatScalarsAsConstants: + return BaseCType(scalarT) + # at::scalar has special handling, + # and is wrapped in an lazy::Value just like at::tensor + return BaseCType(getValueT()) + elif typ.name == BaseTy.ScalarType: + return BaseCType(scalarTypeT) + elif typ.name == BaseTy.int: + return BaseCType(longT) + elif typ.name == BaseTy.SymInt: + if symint: + return BaseCType(getValueT()) + else: + return BaseCType(longT) + elif typ.name == BaseTy.bool: + return BaseCType(boolT) + elif typ.name == BaseTy.float: + return BaseCType(doubleT) + elif typ.name == BaseTy.str: + return BaseCType(stringT) + elif typ.name == BaseTy.Device: + return BaseCType(deviceT) + elif typ.name == BaseTy.Generator: + return BaseCType(generatorT) + elif typ.name == BaseTy.Layout: + return BaseCType(layoutT) + elif typ.name == BaseTy.MemoryFormat: + return BaseCType(memoryFormatT) + else: + raise AssertionError(f"TODO add support for type {repr(typ)}") + elif isinstance(typ, OptionalType): + return OptionalCType(process_ir_type(typ.elem, properties, symint=symint)) + elif isinstance(typ, ListType): + if str(typ.elem) == "Tensor?": + # TODO(whc) is this actually correct? or should it use a Vector like above + return ListCType(OptionalCType(BaseCType(getValueT()))) + elif str(typ.elem) == "Tensor": + # this is a TensorList which comes in from GetTensorList as a Value + return BaseCType(tensorListValueT) + elif typ.elem == BaseType(BaseTy.SymInt): + # TODO: return a value type. The problem here is analogous to + # the problem with tensorListValueT: if you have SymInt[] you + # cannot conveniently save the list of Value directly, as nodes + # expect to save values as a vector for ALL arguments. So you + # need a separate IR node that represents all of the size nodes + # assembled into a list. I'm not an LTC dev so I don't want to + # figure it out right now. Y'all figure it out... + return VectorCType(BaseCType(longT)) + + else: + return VectorCType(process_ir_type(typ.elem, properties, symint=symint)) + else: + raise AssertionError(f"unrecognized type {repr(typ)}") + + +# TODO: Determining this based off of CType is bad; this should be computed +# from Type directly; then the same logic as process_ir_type can be used +# +# Invariant: passed typ should be an *owning* CType (e.g., we will report +# that ArrayRef is NOT a value type) +def isValueType(typ: CType, properties: LazyIrProperties | None = None) -> bool: + """ + Given a type, determine if it is a Value-like type. This is equivalent to + being Tensor-like, but assumes the type has already been transformed. 
+ """ + if isinstance(typ, BaseCType): + # I am regretting my naming conventions, but now we are wrapping at::scalar in + # lazy value, while preserving other 'scalar' types as scalars in the IR + treat_scalars_as_constants = properties and properties.TreatScalarsAsConstants + return ( + typ.type == getValueT() + or (typ.type == scalarT and not treat_scalars_as_constants) + or typ.type == SymIntT + ) + elif typ == VectorCType(BaseCType(SymIntT)): + # TODO: report True for this + return False + elif isinstance(typ, (OptionalCType, ListCType, VectorCType)): + return isValueType(typ.elem, properties) + return False + + +def isSymIntType(typ: Type) -> bool: + return isinstance(typ, BaseType) and typ.name == BaseTy.SymInt + + +def isWrappedScalarType(typ: Type) -> bool: + """ + Given a type, determine if it is a c10::scalar which we will wrap in a lazy Value. + Since we literally change the type from scalarT to valueT, information is lost. + This function helps build a list of wrapped scalars to save that information + """ + if isinstance(typ, BaseType): + # I am regretting my naming conventions, but now we are wrapping at::scalar in + # lazy value, while preserving other 'scalar' types as scalars in the IR + return typ.name == BaseTy.Scalar + elif isinstance(typ, (OptionalType, ListType)): + return isWrappedScalarType(typ.elem) + return False + + +# TODO: dedupe with Type.is_generator_like +def isGeneratorType(typ: Type) -> bool: + if isinstance(typ, BaseType): + return typ.name == BaseTy.Generator + elif isinstance(typ, (OptionalType)): + return isGeneratorType(typ.elem) + return False + + +# This class caches a few derived properties computed from an Argument +# and LazyIrProperties +class LazyArgument: + name: str + orig_type: Type + lazy_type_: CType | None + is_wrapped_scalar: bool + is_generator: bool + # TODO: this is lies, it is false for symint list + is_symint_or_list: bool + + # Whether or not we are treating this as symint or not + symint: bool + + # true if this argument is or contains a lazy IR value + is_lazy_value: bool + + def __init__( + self, arg: Argument, properties: LazyIrProperties, *, symint: bool + ) -> None: + self.name = arg.name + self.orig_type = arg.type + self.symint = symint + self.is_optional = isinstance(arg.type, OptionalType) + self.is_generator = isGeneratorType(arg.type) + self.lazy_type_ = process_ir_type(arg.type, properties, symint=symint) + self.is_wrapped_scalar = isWrappedScalarType(arg.type) + self.is_symint_or_list = symint and ( + isSymIntType(arg.type) + or (isinstance(arg.type, OptionalType) and isSymIntType(arg.type.elem)) + # TODO: lists of symints are not currently treated as value types + # or (isinstance(arg.type, ListType) and isSymIntType(arg.type.elem)) + ) + + self.is_lazy_value = isValueType(self.lazy_type, properties) + + @property + def lazy_type(self) -> CType: + assert ( + self.lazy_type_ is not None + ), f"Attempted to access lazy_type for invalid argument {self.name}" + return self.lazy_type_ + + +class LazyIrProperties: + """Collection of properties for an IR node + + The property groups are listed below. Each group is mutually + exclusive, meaning that only one property from each group can be True + at any one time. The properties can be accessed as if they were normal + attributes. The mutual exclusivity is automatically handled. + """ + + Properties: tuple[tuple[str, ...], ...] 
= ( + ( + "ShapePrecompute", # Assume shape has been precomputed + "ShapeCompute", # Need to compute the shape on construction + "ShapeCache", # Utilize the shape cache to defer computation + ), + ( + "Lower", # Codegen full lower function + "LowerDeclOnly", # Codegen only lower function declaration + ), + ( + "CanBeReused", # Codegen full reuse function + "CanBeReusedDeclOnly", # Codegen only reuse function declaration + ), + ( + "CreateFn", # Codegen full create function + "CreateFnDeclOnly", # Codegen only create function declaration + ), + ( + "TreatScalarsAsConstants", # Treat Scalars as constants instead of handling like values + ), + ) + + def __init__(self, *default_properties: str) -> None: + properties: dict[tuple[str, ...], str | None] = dict.fromkeys( + LazyIrProperties.Properties + ) + self.__dict__["properties"] = properties + for p in default_properties: + setattr(self, p, True) + + def __getattr__(self, key: str) -> Any: + properties = self.__dict__["properties"] + for values in LazyIrProperties.Properties: + if key in values: + return properties[values] == key + + return self.__getattribute__(key) + + def __setattr__(self, key: str, value: Any) -> Any: + properties = self.__dict__["properties"] + for values in LazyIrProperties.Properties: + if key in values: + properties[values] = key if value else None + return value + + raise KeyError(f"Invalid property: {key}") + + +# Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node. +# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML), +# but carries type information from a native FunctionSchema modified for use with IR nodes, +# and preserving original argument names. +# +# TODO: This is not idiomatic with how other torchgen APIs transform on schema. +class LazyIrSchema: + # The name of the operator this function schema describes. + name: OperatorName + + positional_args: tuple[LazyArgument, ...] + keyword_args: tuple[LazyArgument, ...] + + # TODO: Need to handle collisions with argument names at some point + returns: tuple[Return, ...] 
+ + # if this schema has a Generator arg, list its orig ctype/name but don't + # build a LazyArgument since lazy IR doesn't support it + generator_arg: NamedCType | None = None + + # original function schema + func: FunctionSchema + + # Whether or not we are code-genning for SymInt or not + symint: bool + + properties: LazyIrProperties = LazyIrProperties( + # default properties + "ShapePrecompute", + "Lower", + "CanBeReused", + ) + opkind: str | None = None + + def __init__( + self, + func: FunctionSchema, + properties: LazyIrProperties | None = None, + *, + symint: bool, + ) -> None: + if properties: + self.properties = properties + + self.func = func + self.symint = symint + positional_args: list[LazyArgument] = [] + for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]: + if arg_field == "self_arg" and func.arguments.self_arg is not None: + arg = func.arguments.self_arg.argument + positional_args.append( + LazyArgument(arg, self.properties, symint=symint) + ) + elif getattr(func.arguments, arg_field) is not None: + positional_args.extend( + LazyArgument(arg, self.properties, symint=symint) + for arg in getattr(func.arguments, arg_field) + ) + self.positional_args = tuple(positional_args) + + keyword_args: list[LazyArgument] = [] + for arg_field in [ + "pre_tensor_options_kwarg_only", + "tensor_options", + "post_tensor_options_kwarg_only", + "out", + ]: + curr_args = getattr(func.arguments, arg_field) + if curr_args is not None: + if isinstance(curr_args, TensorOptionsArguments): + curr_args = curr_args.all() + for arg in curr_args: + if isGeneratorType(arg.type): + assert ( + self.generator_arg is None + ), "We expect there is only one generator arg" + self.generator_arg = NamedCType( + arg.name, arg.type # type:ignore[arg-type] + ) + keyword_args.extend( + LazyArgument(arg, self.properties, symint=symint) + for arg in curr_args + ) + self.keyword_args = tuple(keyword_args) + self.name = func.name + self.returns = func.returns + + @property + def node_name(self) -> str: + """ + Return camel-case version of op in node. + + Note: This function also appends any `overload_name` in the operation. + For example, if the op is `bitwise_and.Tensor`, the returned name + will be `BitwiseAndTensor`. + """ + op_name = f"{self.name.name}_{self.name.overload_name}".lower() + return "".join(word.capitalize() or "" for word in op_name.split("_")) + + @property + def aten_name(self) -> str: + return str(self.name.name) + + @property + def base_name(self) -> str: + return f"{self.name.name.base}" + + def filtered_args( + self, + positional: bool = True, + keyword: bool = True, + values: bool = True, + scalars: bool = True, + generator: bool = True, + ) -> list[LazyArgument]: + # This function maintains the sorted order of arguments but provides different filtered views. + # Some parts of the code care about kwargs vs args (TS lowerings), + # other parts care about whether they need to wrap the arg in a lazy value or leave it alone. + # Generators are special cased, as they are needed for fallback/shape-inference but not supported + # in TS lowerings and therefore also omitted from lazy IR. 
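+        # (e.g. filtered_args(values=True, scalars=False) returns only the arguments
+        # that carry lazy IR values, i.e. those with is_lazy_value == True.)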
+ args: list[LazyArgument] = [] + if positional: + args.extend(self.positional_args) + if keyword: + args.extend(self.keyword_args) + + if values and scalars and generator: + return args + elif values and scalars: + return [a for a in args if not a.is_generator] + elif values: + return [a for a in args if a.is_lazy_value] + elif scalars: + return [ + a + for a in args + if not a.is_lazy_value and (generator or not a.is_generator) + ] + + return [] + + @property + def positional_values(self) -> list[LazyArgument]: + return self.filtered_args( + positional=True, keyword=False, values=True, scalars=False + ) + + @property + def positional_scalars(self) -> list[LazyArgument]: + return self.filtered_args( + positional=True, keyword=False, values=False, scalars=True + ) + + @property + def keyword_values(self) -> list[LazyArgument]: + return self.filtered_args( + positional=False, keyword=True, values=True, scalars=False + ) + + @property + def keyword_scalars(self) -> list[LazyArgument]: + return self.filtered_args( + positional=False, keyword=True, values=False, scalars=True + ) diff --git a/lib/python3.10/site-packages/torchgen/api/meta.py b/lib/python3.10/site-packages/torchgen/api/meta.py new file mode 100644 index 0000000000000000000000000000000000000000..2e99d151faeaccea7ca47f372fd26f9985ce7249 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/api/meta.py @@ -0,0 +1,13 @@ +from torchgen.model import NativeFunctionsGroup + + +# Follows dispatcher calling convention, but: +# - Mutable arguments not allowed. Meta functions are always +# written in functional form. Look at FunctionSchema.signature() +# - No tensor returns; instead we return a TensorMeta describing +# the tensor in question + + +def name(g: NativeFunctionsGroup) -> str: + # use the overload name from the functional version + return str(g.functional.func.name).replace(".", "_") diff --git a/lib/python3.10/site-packages/torchgen/api/native.py b/lib/python3.10/site-packages/torchgen/api/native.py new file mode 100644 index 0000000000000000000000000000000000000000..a00e8266b8daa7a2614e516a010cc23c497d6151 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/api/native.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +from typing import Sequence + +from torchgen import local +from torchgen.api import cpp +from torchgen.api.types import ( + ArgName, + BaseCType, + Binding, + boolT, + ConstRefCType, + CType, + deviceT, + layoutT, + ListCType, + MutRefCType, + NamedCType, + OptionalCType, + scalarT, + scalarTypeT, + tensorT, +) +from torchgen.model import ( + Argument, + FunctionSchema, + Return, + SelfArgument, + TensorOptionsArguments, + Type, +) +from torchgen.utils import assert_never + + +# This file describes the translation of JIT schema to the native functions API. +# This looks a lot like the C++ API (which makes historical sense, because the +# idea was you wrote native functions to implement functions in the C++ API), +# but over time we have evolved the C++ API without actually changing our +# native:: kernels. The intention is to make native API and dispatcher API +# line up as closely as possible, since this results in the least overhead +# (no translation is needed from dispatcher API to native API). +# +# NB: this is symint aware, you will get the non-SymInt variant for some +# dispatch entries and SymInt for others. + + +def name(func: FunctionSchema) -> str: + name = str(func.name.name) + # TODO: delete this! 
+ if func.is_out_fn(): + name += "_out" + if func.name.overload_name: + name += f"_{func.name.overload_name}" + return name + + +def argumenttype_type( + t: Type, *, mutable: bool, binds: ArgName, symint: bool +) -> NamedCType: + if str(t) == "Tensor?": + tensor_type: OptionalCType = OptionalCType(BaseCType(tensorT)) + if mutable and not local.use_const_ref_for_mutable_tensors(): + return NamedCType(binds, MutRefCType(tensor_type)) + else: + return NamedCType(binds, ConstRefCType(tensor_type)) + elif str(t) == "Tensor?[]": + return NamedCType( + binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))) + ) + elif str(t) == "Scalar": + return NamedCType(binds, ConstRefCType(BaseCType(scalarT))) + elif str(t) == "Scalar?": + return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT)))) + return cpp.argumenttype_type(t, mutable=mutable, binds=binds, symint=symint) + + +def returns_type(rs: Sequence[Return], *, symint: bool) -> CType: + return cpp.returns_type(rs, symint=symint) + + +def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType: + return argumenttype_type(a.type, mutable=a.is_write, binds=binds, symint=symint) + + +def argument( + a: Argument | SelfArgument | TensorOptionsArguments, + *, + is_out: bool, + symint: bool, +) -> list[Binding]: + # Ideally, we NEVER default native functions. However, there are a number + # of functions that call native:: directly and rely on the defaulting + # existing. So for BC, we generate defaults for non-out variants (but not + # for out variants, where it is impossible to generate an appropriate + # default) + should_default = not is_out + if isinstance(a, Argument): + default: str | None = None + if should_default and a.default is not None: + default = cpp.default_expr(a.default, a.type, symint=symint) + return [ + Binding( + nctype=argument_type(a, binds=a.name, symint=symint), + name=a.name, + default=default, + argument=a, + ) + ] + elif isinstance(a, SelfArgument): + # Erase SelfArgument from the distinction + return argument(a.argument, is_out=is_out, symint=symint) + elif isinstance(a, TensorOptionsArguments): + default = None + if should_default: + default = "{}" + # TODO: Not sure why the arguments assigned here are for + # TensorOptionsArguments and not the constituent pieces. 
It seems + # to matter + return [ + Binding( + nctype=NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))), + name="dtype", + default=default, + argument=a, + ), + Binding( + nctype=NamedCType("layout", OptionalCType(BaseCType(layoutT))), + name="layout", + default=default, + argument=a, + ), + Binding( + nctype=NamedCType("device", OptionalCType(BaseCType(deviceT))), + name="device", + default=default, + argument=a, + ), + Binding( + nctype=NamedCType("pin_memory", OptionalCType(BaseCType(boolT))), + name="pin_memory", + default=default, + argument=a, + ), + ] + else: + assert_never(a) + + +def arguments(func: FunctionSchema, *, symint: bool) -> list[Binding]: + args: list[Argument | TensorOptionsArguments | SelfArgument] = [] + args.extend(func.arguments.non_out) + args.extend(func.arguments.out) + return [ + r for arg in args for r in argument(arg, symint=symint, is_out=func.is_out_fn()) + ] diff --git a/lib/python3.10/site-packages/torchgen/api/python.py b/lib/python3.10/site-packages/torchgen/api/python.py new file mode 100644 index 0000000000000000000000000000000000000000..eb0f07489887225b1ee0df12815f1e17f506aaf7 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/api/python.py @@ -0,0 +1,1519 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Sequence + +from torchgen.api import cpp +from torchgen.api.types import Binding, CppSignature, CppSignatureGroup +from torchgen.gen import pythonify_default +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + FunctionSchema, + ListType, + NativeFunction, + OptionalType, + Return, + Type, + Variant, +) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Data Models +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# [Notes] python binding codegen +# +# The Python binding codegen produces code that takes the input list of +# PyObjects, finds the matching ATen C++ function using PythonArgParser, +# converts the PyObjects into C++ types and calls the ATen C++ function: +# +# +--------+ parsing +------------------------+ binding +-----------------------+ +# | PyObjs | ---------> | PythonArgParser Output | ---------> | Cpp Function Dispatch | +# +--------+ +------------------------+ +-----------------------+ +# +# The following examples demonstrate the data models the Python binding +# codegen needs to deal with and the tasks it needs to accomplish. It +# helps understand the purpose of the new data types we introduced below. +# +# - Function Schema (source of truth) +# +# aten::empty.names(int[] size, *, Dimname[]? names, +# ScalarType? dtype=None, Layout? layout=None, +# Device? device=None, bool? pin_memory=None, +# MemoryFormat? memory_format=None) -> Tensor +# +# - Python Signature +# +# It's used to generate input schema string for PythonArgParser. +# Note: TensorOptions fields are reordered and the additional +# 'requires_grad' field is added: +# +# empty(IntArrayRef size, *, DimnameList? names, +# MemoryFormat? memory_format=None, ScalarType dtype=None, +# Layout layout=torch.strided, Device device=None, +# bool pin_memory=False, bool requires_grad=False) +# +# - C++ Signature +# +# It's used to generate C++ lambda formals & dispatch call. +# Note: the scattered TensorOptions fields are packed into 'options'. 
+# +# auto dispatch_empty = +# [](IntArrayRef size, std::optional names, +# const TensorOptions & options, +# std::optional memory_format) -> Tensor { +# pybind11::gil_scoped_release no_gil; +# return torch::empty(size, names, options, memory_format); +# }; +# +# - Binding between Python Arguments and C++ Arguments +# +# Given a set of Python Arguments in scope, we need produce the +# binding expressions that translate the Python API into C++ API: +# +# Python Args Cpp Args Binding Exprs +# ----------------------------------------------------------------- +# 0: size size '_r.intlist(0)' +# 1: names names 'names' [special init] +# 2: memory_format -------+ +# 3: dtype -----+-|--> options 'options' [special packing] +# 4: layout / | +# 5: device / +--> memory_format '_r.memoryformatOptional(2)' +# 6: pin_memory / +# 7: requires_grad -+ +# +# So the full dispatch expression would look like: +# +# dispatch_empty(_r.intlist(0), names, options, +# _r.memoryformatOptional(2)) +# +# Where does 'names' come from? It involves special local init: +# +# auto __names = _r.toDimnameListOptional(1); +# std::optional names = +# __names ? std::make_optional(DimnameList(__names.value())) +# : std::nullopt; +# +# Where does 'options' come from? It involves special local init +# for TensorOptions. Note that Python side has the additional +# 'requires_grad' field: +# +# const auto options = TensorOptions() +# .dtype(_r.scalartype(3)) +# .device(_r.device(5)) +# .layout(_r.layoutOptional(4)) +# .requires_grad(_r.toBool(7)) +# .pinned_memory(_r.toBool(6)); +# +# In some other cases one Python Argument can map to multiple C++ +# Arguments. For example: +# +# aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False) +# -> (Tensor values, Tensor indices) +# +# Python Args Cpp Args Binding Exprs +# --------------------------------------------------------------------- +# +----> max 'out[0]' +# /-----> max_values 'out[1] +# 0: input / self '_r.tensor(0)' +# 1: dim / dim '_r.dimname(1)' +# 2: keepdim / keepdim '_r.toBool(2)' +# 3: out -----+ [local init] out '_r.tensorlist_n<2>(3)' +# +# As demonstrated above, the binding can involve reordering, +# packing, unpacking and special local inits. +# +# +# Let's look at a concrete example: +# +# static PythonArgParser parser({ +# "abs(Tensor input, *, Tensor out=None)", +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ^ +# +--- Python Schema, represented by PythonSignature and PythonArgument +# +# }, /*traceable=*/true); +# +# ParsedArgs<2> parsed_args; +# auto _r = parser.parse(nullptr, args, kwargs, parsed_args); +# +# ... 
+# +# if (_r.isNone(1)) { +# ~~~~~~~~~~~~ <--- Scattered PythonArgParser output (arg name = 'out') +# represented by PythonArgParserOutputExpr +# +# // aten::abs(Tensor self) -> Tensor +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ^ +# +--- NativeFunction schema, base version +# +# auto dispatch_abs = [](const Tensor & self) -> Tensor { +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ^ +# +--- dispatch_lambda_args / dispatch_lambda_return_str +# generated from NativeFunction / CppSignature +# (deprecated PythonSignature is special) +# arguments are represented by DispatchLambdaArgument +# +# pybind11::gil_scoped_release no_gil; +# return self.abs(); +# ~~~~~~~~~~~ <--- cpp_dispatch_target / cpp_dispatch_exprs +# generated from NativeFunction / CppSignature +# }; +# return wrap(dispatch_abs(_r.tensor(0))); +# ~~~~~~~~~~~~~ +# ^ +# +--- dispatch_lambda_exprs +# binding PythonArgParserOutputExpr (python args) +# and DispatchLambdaArgument (c++ args) +# +# } else { +# // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ^ +# +--- NativeFunction schema, out-variant +# +# auto dispatch_abs_out = [](Tensor out, const Tensor & self) -> Tensor { +# pybind11::gil_scoped_release no_gil; +# return at::abs_out(out, self); +# }; +# return wrap(dispatch_abs_out(_r.tensor(1), _r.tensor(0))); +# } +# +# +# [Notes] python interface codegen +# The python dataclasses below are used used to generate both python binding code +# and pyi type hint signatures. +# In theory these two should look very similar, but there are number of differences +# in how pyi signatures vs. python_arg_parser signatures are generated. +# These differences have been encapsulated in signature_str() vs. signature_str_pyi() +# to display the full signatures, and argument_str() vs argument_str_pyi() to display arguments. +# For examples, only pyi signatures include return types. + + +@dataclass(frozen=True) +class PythonReturns: + returns: tuple[Return, ...] + + +@dataclass(frozen=True) +class PythonArgument: + name: str + type: Type + default: str | None + + # Used to generate the default init expr for some PythonArgParser outputs, e.g.: + # + # _r.layoutWithDefault(3, layout_from_backend(self.options().backend()))) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # ^ + # +--- default_init str + default_init: str | None + + # Compute argument formal for python argument parsing. + # Needs to be consistent with torch/csrc/utils/python_arg_parser.h. + def argument_str(self, *, method: bool = False, symint: bool = True) -> str: + type_str = ( + argument_type_str(self.type, symint=symint) + .replace("const ", "") + .replace(" &", "") + ) + + name = self.name + # s/self/input/ outside method bindings + # [old codegen] TODO: remove this? doesn't rename in codegen, it's just + # for the parse string + if name == "self" and type_str in ["Tensor", "Number"] and not method: + name = "input" + + # add default + if self.default is not None: + default = { + "nullptr": "None", + "::std::nullopt": "None", + "std::nullopt": "None", + "{}": "None", + }.get(self.default, self.default) + return f"{type_str} {name}={default}" + else: + return f"{type_str} {name}" + + def argument_str_pyi( + self, *, method: bool = False, deprecated: bool = False + ) -> str: + type_str = argument_type_str_pyi(self.type) + + name = self.name + # s/self/input/ outside method bindings + # [old codegen] TODO: remove this? 
doesn't rename in codegen, it's just + # for the parse string + if name == "self" and type_str == "Tensor" and not method and not deprecated: + name = "input" + + if name == "from": # from is a Python keyword... + name += "_" + + # pyi merges the _out and functional variants into the same signature, with an optional out arg + if name == "out" and type_str == "Tensor" and not deprecated: + type_str = "Optional[" + type_str + "]" + + # pyi deprecated signatures don't get defaults for their out arg + treat_as_no_default = ( + deprecated + and isinstance(self, PythonOutArgument) + and self.default == "None" + ) + + # add default + if self.default is not None and not treat_as_no_default: + if ( + isinstance(self.type, ListType) + and self.type.elem == BaseType(BaseTy.int) + and self.default.startswith("{") + and self.default.endswith("}") + ): + default = ( + "(" + ", ".join(map(str.strip, self.default[1:-1].split(","))) + ")" + ) + else: + default = { + "nullptr": "None", + "::std::nullopt": "None", + "std::nullopt": "None", + "{}": "None", + "c10::MemoryFormat::Contiguous": "contiguous_format", + "QScheme::PER_TENSOR_AFFINE": "per_tensor_affine", + }.get(self.default, self.default) + return f"{name}: {type_str} = {default}" + else: + return f"{name}: {type_str}" + + +@dataclass(frozen=True) +class PythonOutArgument(PythonArgument): + # In Python signature multiple output fields are packed into one 'out' argument. + # When binding to C++, it's first binded to a local 'out' variable: + # 'auto out = _r.tensorlist_n<2>(2);', + # then binded to scattered C++ output arguments as 'out[0]', 'out[1]', and etc. + # TODO: maybe don't need keep scattered out fields for python signature? + outputs: tuple[PythonArgument, ...] + + @staticmethod + def from_outputs(outputs: tuple[PythonArgument, ...]) -> PythonOutArgument | None: + if not outputs: + return None + + size = len(outputs) + if size == 1: + return PythonOutArgument( + name=outputs[0].name, + type=outputs[0].type, + default="None", + default_init=None, + outputs=outputs, + ) + elif size > 1: + if any(not a.type.is_tensor_like() for a in outputs): + raise RuntimeError(f"Unsupported output type: {outputs}") + return PythonOutArgument( + name="out", + # TODO: shouldn't this be OptionalType[ListType[...]], since it defaults to None? + type=ListType(BaseType(BaseTy.Tensor), size), + default="None", + default_init=None, + outputs=outputs, + ) + raise AssertionError(r"Unexpected PythonOutArgument size") + + +@dataclass(frozen=True) +class PythonSignature: + # Base operator name, without inplace/outplace suffix. + name: str + + # Positional arguments. + # TODO: create a dedicated SelfArgument type for 'self'? + input_args: tuple[PythonArgument, ...] + + # Keyword arguments excluding the 'out' argument and scattered kwargs belonging + # to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc). + input_kwargs: tuple[PythonArgument, ...] + + output_args: PythonOutArgument | None + + # Return types, which are only used by pyi + returns: PythonReturns + + # These are scattered kwargs arguments belonging to TensorOptions. + # When binding to C++, they are packed into a TensorOptions object 'options'. + # It's possible that the C++ signature doesn't take TensorOptions object (e.g. + # for out variant), in which case they will be used as scattered fields without + # being packed into 'options'. + # TODO: maybe create a PythonTensorOptionsArgument? + tensor_options_args: tuple[PythonArgument, ...] + + # method or function signature? 
+ method: bool + + @property + def deprecated(self) -> bool: + return False + + def arguments( + self, *, skip_outputs: bool = False, skip_tensor_options: bool = False + ) -> tuple[PythonArgument | PythonOutArgument, ...]: + result: list[PythonArgument | PythonOutArgument] = [] + result.extend(self.input_args) + result.extend(self.input_kwargs) + if self.output_args is not None and not skip_outputs: + result.append(self.output_args) + if not skip_tensor_options: + result.extend(self.tensor_options_args) + return tuple(result) + + def arguments_count(self) -> int: + return len(self.arguments()) + + def output_idx(self) -> int: + return len(self.input_args) + len(self.input_kwargs) + + # [old codegen] Compute the Python function signature for argument parsing, + # as specified in torch/csrc/utils/python_arg_parser.h. WARNING: + # this is NOT the same type signature as specified by PEP 484 + # as understood by mypy; our format was independently developed + # and has some quirks to make it more suitable specifically + # for error parsing. + # + # For a translation to mypy-valid type signatures, see + # signature_str_pyi(). + def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str: + args = self.arguments(skip_outputs=skip_outputs) + schema_formals: list[str] = [ + a.argument_str(method=self.method, symint=symint) for a in args + ] + positional_argc = len(self.input_args) + if len(schema_formals) > positional_argc: + schema_formals.insert(positional_argc, "*") + + return f'{self.name}({", ".join(schema_formals)})' + + def signature_str_pyi(self, *, skip_outputs: bool = False) -> str: + args = self.arguments(skip_outputs=skip_outputs) + schema_formals: list[str] = [ + a.argument_str_pyi(method=self.method) for a in args + ] + positional_argc = len(self.input_args) + if len(schema_formals) > positional_argc: + schema_formals.insert(positional_argc, "*") + + # only pyi signatures include returns + returns_str = returns_str_pyi(self) + # pyi also includes self (with no typing/defaults) for methods + if self.method: + schema_formals.insert(0, "self") + return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...' + + def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None: + # only pyi uses vararg signatures + args = self.arguments(skip_outputs=skip_outputs) + schema_formals: list[str] = [ + a.argument_str_pyi(method=self.method) for a in args + ] + # vararg only applies to pyi signatures. vararg variants are not generated for all signatures + num_args = self.arguments_count() + num_positionalargs = len(self.input_args) + + have_vararg_version = False + if num_args > 0: + vararg_type = args[0].type + if ( + isinstance(vararg_type, ListType) + and str(vararg_type.elem) in ["int", "SymInt"] + and num_positionalargs == 1 + ): + have_vararg_version = True + + if not have_vararg_version: + return None + + # Below are the major changes in vararg vs. regular pyi signatures + # vararg signatures also omit the asterix + assert isinstance(vararg_type, ListType) + schema_formals[0] = ( + "*" + args[0].name + ": " + argument_type_str_pyi(vararg_type.elem) + ) + + returns_str = returns_str_pyi(self) + # pyi also includes self (with no typing/defaults) for methods + if self.method: + schema_formals.insert(0, "self") + return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...' + + +# The deprecated python signature involves some special logic, so create a +# dedicated data model to store these extra properties. 
+@dataclass(frozen=True) +class PythonSignatureDeprecated(PythonSignature): + # Schema for the deprecated function + deprecated_schema: FunctionSchema + + # The deprecated signature might miss some arguments that the corresponding + # C++ signature expects. We need store the constant default values to pass in. + # For example: + # [deprecate signature]: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2) + # [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + # [func call]: self.addmm(mat1, mat2, beta, 1) + # We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case. + deprecated_args_exprs: tuple[str, ...] + + @property + def deprecated(self) -> bool: + return True + + def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str: + return ( + PythonSignature.signature_str( + self, skip_outputs=skip_outputs, symint=symint + ) + + "|deprecated" + ) + + def signature_str_pyi(self, *, skip_outputs: bool = False) -> str: + args = self.arguments(skip_outputs=skip_outputs) + schema_formals: list[str] = [ + a.argument_str_pyi(method=self.method, deprecated=True) for a in args + ] + positional_argc = len(self.input_args) + if len(schema_formals) > positional_argc: + schema_formals.insert(positional_argc, "*") + + returns_str = returns_str_pyi(self) + return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...' + + def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None: + # the codegen doesn't include vararg variants for deprecated signatures + return None + + +# This struct is used to hold the PythonSignature and its corresponding +# NativeFunction BEFORE grouping base and out-variant functions. +# Why not store NativeFunction in PythonSignature or construct PythonSignature +# from NativeFunction? Because they are not 1-1 mapped. +# One native function could have both deprecated and non-deprecated python +# signatures - NativeFunction doesn't contain information to construct the +# deprecated python signature. +# One python signature is used to handle both the base and the out-variant +# function - see 'PythonSignatureGroup'. +@dataclass(frozen=True) +class PythonSignatureNativeFunctionPair: + signature: PythonSignature + function: NativeFunction + + +# We merge pairs of functions with signatures that are equivalent mod +# output arguments, and use a single entry in the python_arg_parser sig +# list for both (output arguments become optional). +@dataclass(frozen=True) +class PythonSignatureGroup: + # The signature used for Python argument parsing. The outplace signature + # is preferred if exists, because it can be used to parse inputs for both + # the out-place variant and the base version (with output omitted). + signature: PythonSignature + + # The regular ATen declaration (e.g. conv2d) + base: NativeFunction + + # The out variant (e.g. conv2d_out) + outplace: NativeFunction | None + + @classmethod + def from_pairs( + cls, + functional: PythonSignatureNativeFunctionPair, + out: PythonSignatureNativeFunctionPair | None, + ) -> PythonSignatureGroup: + if out is None: + return PythonSignatureGroup( + signature=functional.signature, + base=functional.function, + outplace=None, + ) + + # prefer the signature with optional out=... arguments because it's the + # superset that can be used to parse input for both base and outplace. 
+ signature_kwargs = out.signature.__dict__.copy() + + # Out overloads in C++ don't have TensorOptions arguments, + # so take these from the functional variant + signature_kwargs[ + "tensor_options_args" + ] = functional.signature.tensor_options_args + + return PythonSignatureGroup( + signature=type(out.signature)(**signature_kwargs), + base=functional.function, + outplace=out.function, + ) + + +# C++ function dispatch is wrapped in a lambda function. The lambda function +# has almost the same signature as the C++ function, only with some small +# variants - see details below. +# This data model is used to represent arguments of the lambda function +# signature. +@dataclass(frozen=True) +class DispatchLambdaArgument: + name: str + type_str: str + is_out_arg: bool + + +# To pass PyObjects arguments to C++ function (via the lambda wrapper), +# we need first convert PyObjects into simple C++ objects. This work +# is done by PythonArgParser. +# This data model is used to represent the output of PythonArgParser. +# It has 1-1 mapping with PythonArgument in PythonSignature. +@dataclass(frozen=True) +class PythonArgParserOutputExpr: + # argument name + name: str + + # RHS expression to reference PythonArgParser output. + expr: str + + # In some special cases we need create different expr, e.g.: + # '_r.isNone(1)' instead of '_r.tensor(1)'. + index: int + + # The python argument it maps to. + argument: PythonArgument + + @property + def is_none_expr(self) -> str: + return f"_r.isNone({self.index})" + + +# To pass PythonArgParser output to the lambda wrapper, we need bind +# PythonArgParserOutputExpr to DispatchLambdaArgument. +# They are not always 1-1 mapped, e.g. scattered TensorOptions fields +# need be packed into a TensorOptions object, which is the argument +# that the lambda function wrapper takes. +@dataclass(frozen=True) +class DispatchLambdaArgumentExprs: + # The exprs that provide the binding for lambda arguments, e.g.: + # + # 'self' -> '_r.tensor(0)' + # 'min' -> 'out[0]' / 'min_indices' -> 'out[1]' + # 'options' -> 'options' + # + # It has 1-1 mapping with DispatchLambdaArgument. + exprs: Sequence[str] + + # Special local inits, which might introduce new variables that + # the 'exprs' above reference, e.g.: + # + # 'auto out = _r.tensorlist_n<2>(2);' + # + inits: Sequence[str] + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Helper Functions +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def _cpp_signature(f: NativeFunction, *, method: bool = False) -> CppSignature: + return CppSignatureGroup.from_native_function(f, method=method).signature + + +def has_tensor_options(f: NativeFunction) -> bool: + return f.func.arguments.tensor_options is not None + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Python Signature +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +# 'simple_type' was introduced by the old codegen, which is slightly +# different from the python schema type, e.g.: doesn't have '?' suffix +# for optional Tensor/TensorList; doesn't have '[size]' suffix for list type. 
+def argument_type_str( + t: Type, *, simple_type: bool = False, symint: bool = True +) -> str: + if isinstance(t, BaseType): + if t.name == BaseTy.Tensor: + return "Tensor" + elif t.name == BaseTy.int: + return "int64_t" + elif t.name == BaseTy.float: + return "double" + elif t.name == BaseTy.str: + return "c10::string_view" + elif t.name in [ + BaseTy.bool, + BaseTy.QScheme, + BaseTy.Scalar, + BaseTy.ScalarType, + BaseTy.Generator, + BaseTy.Storage, + BaseTy.Layout, + BaseTy.Device, + BaseTy.DeviceIndex, + BaseTy.MemoryFormat, + BaseTy.Dimname, + BaseTy.Stream, + BaseTy.ConstQuantizerPtr, + BaseTy.SymInt, + ]: + # These python schema type names line up with their function schema names + return t.name.name + + elif isinstance(t, OptionalType): + if str(t.elem) == "Tensor": + # Is it desired to keep '?' for simple_type with new style dispatcher? + return "Tensor?" + elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint) + return f"{elem}?" + elif isinstance(t, ListType): + size = t.size if not simple_type else None + if str(t.elem) == "bool": + assert t.size is not None + return f"::std::array" + elif str(t.elem) == "int": + return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef" + elif str(t.elem) == "SymInt": + if symint: + return ( + f"SymIntArrayRef[{size}]" if size is not None else "SymIntArrayRef" + ) + else: + return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef" + elif str(t.elem) == "Tensor": + return f"TensorList[{size}]" if size is not None else "TensorList" + elif str(t.elem) == "Scalar": + return f"ScalarList[{size}]" if size is not None else "ScalarList" + elif str(t.elem) == "Tensor?": + if simple_type: + return "c10::List<::std::optional>" + else: + return "const c10::List<::std::optional> &" + elif str(t.elem) == "Dimname": + return f"DimnameList[{size}]" if size is not None else "DimnameList" + elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint) + return f"ArrayRef<{elem}>" + + raise RuntimeError(f"unrecognized type {repr(t)}") + + +def argument_type_size(t: Type) -> int | None: + l = t.is_list_like() + if l is not None and str(l.elem) != "bool": + return l.size + else: + return None + + +def argument(a: Argument) -> PythonArgument: + return PythonArgument( + name=a.name, + type=a.type, + # TODO: directly translate a.default to python default + default=( + str(pythonify_default(cpp.default_expr(a.default, a.type, symint=False))) + if a.default is not None + else None + ), + default_init=None, + ) + + +# Generates a PythonSignature that can be used for either .pyi or PythonArgParser codegen +def signature( + f: NativeFunction, *, method: bool = False, pyi: bool = False +) -> PythonSignature: + return signature_from_schema( + f.func, category_override=f.category_override, method=method, pyi=pyi + ) + + +def signature_from_schema( + func: FunctionSchema, + *, + category_override: str | None, + method: bool = False, + pyi: bool = False, +) -> PythonSignature: + args: list[Argument] = [] + args.extend(func.arguments.pre_self_positional) + # Skip SelfArgument if this is method. + if not method and func.arguments.self_arg is not None: + args.append(func.arguments.self_arg.argument) + args.extend(func.arguments.post_self_positional) + args.extend(func.arguments.pre_tensor_options_kwarg_only) + # Skip TensorOptionsArguments. Python side TensorOptions + # arguments are created based on different rules - see below. 
+ args.extend(func.arguments.post_tensor_options_kwarg_only) + args.extend(func.arguments.out) + + input_arg_set = {a.name for a in func.arguments.flat_positional} + kwarg_only_set = {a.name for a in func.arguments.flat_kwarg_only} + out_arg_set = {a.name for a in func.arguments.out} + + input_args = tuple(map(argument, filter(lambda a: a.name in input_arg_set, args))) + input_kwargs = tuple( + map(argument, filter(lambda a: a.name in kwarg_only_set, args)) + ) + outputs = tuple(map(argument, filter(lambda a: a.name in out_arg_set, args))) + + # Reintroduce the scattered fields of TensorOptions for Python. + # Compared to the cpp counterpart, the python arguments have new property + # (default_init) and a new argument 'requires_grad', which require some + # special handlings. + # [old codegen] TODO: because these aren't guaranteed to be 100% faithful + # to the original versions in the yaml, this recreation is a potential + # source of drift between eager and JIT. Pull this logic out to a shared place. + + has_tensor_input_arg = any( + a.type.is_tensor_like() for a in func.arguments.flat_non_out + ) + if any(a.name == "requires_grad" for a in func.schema_order_arguments()): + raise ValueError( + "argument named requires_grad is reserved, should not explicitly add it in the schema" + ) + + # [old codegen] this probably won't work if one of the returns is not a tensor, + # but it will produce a compile-time error that is obvious. + has_tensor_return = any(r.type.is_tensor_like() for r in func.returns) + + name: str = cpp.name(func) + is_factory_function = category_override == "factory" or ( + has_tensor_return and not has_tensor_input_arg + ) + is_like_or_new_function = ( + category_override in ("new", "like") + or name.startswith("new_") + or name.endswith("_like") + ) + is_dummy_function = category_override == "dummy" + + tensor_options_args: list[PythonArgument] = [] + if (is_factory_function or is_like_or_new_function) and not is_dummy_function: + + def topt_default_init(name: str) -> str | None: + topt_args = func.arguments.tensor_options + if topt_args is None: + return None + a = getattr(topt_args, name) + if a.default is None or a.default == "None": + return None + return cpp.default_expr(a.default, a.type, symint=False) + + tensor_options_args.append( + PythonArgument( + name="dtype", + type=OptionalType(BaseType(BaseTy.ScalarType)), + default="None", + default_init=( + None if is_like_or_new_function else topt_default_init("dtype") + ), + ) + ) + tensor_options_args.append( + PythonArgument( + name="layout", + type=OptionalType(BaseType(BaseTy.Layout)), + default="None", + default_init=( + None if is_like_or_new_function else topt_default_init("layout") + ), + ) + ) + tensor_options_args.append( + PythonArgument( + name="device", + type=OptionalType(BaseType(BaseTy.Device)), + default="None", + default_init=( + None + if is_like_or_new_function + else ( + topt_default_init("device") + or "torch::tensors::get_default_device()" + ) + ), + ) + ) + tensor_options_args.append( + PythonArgument( + name="pin_memory", + type=OptionalType(BaseType(BaseTy.bool)), + default="False", + default_init=None, + ) + ) + tensor_options_args.append( + PythonArgument( + name="requires_grad", + type=OptionalType(BaseType(BaseTy.bool)), + default="False", + default_init=None, + ) + ) + + returns = PythonReturns(returns=func.returns) + + return PythonSignature( + name=str(func.name.name), + input_args=input_args, + input_kwargs=input_kwargs, + output_args=PythonOutArgument.from_outputs(outputs), + 
tensor_options_args=tuple(tensor_options_args), + returns=returns, + method=method, + ) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Python Interface +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def structseq_fieldnames(returns: tuple[Return, ...]) -> list[str]: + if len(returns) <= 1 or all(r.name is None for r in returns): + return [] + else: + if any(r.name is None for r in returns): + # When building on Windows, `PyStructSequence_UnnamedField` could not be + # resolved by the linker for some reason, which cause error in building: + # + # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol + # PyStructSequence_UnnamedField + # + # Thus, at this point in time, we do not support unnamed + # fields in structseq; you must either name all fields, + # or none of them. + raise ValueError("Unnamed field is not supported by codegen") + + return [str(r.name) for r in returns] + + +def argument_type_str_pyi(t: Type) -> str: + add_optional = False + if isinstance(t, OptionalType): + t = t.elem + add_optional = True + + if isinstance(t, BaseType): + if t.name in [BaseTy.int, BaseTy.DeviceIndex]: + ret = "_int" + if t.name == BaseTy.SymInt: + ret = "Union[_int, SymInt]" + elif t.name == BaseTy.float: + ret = "_float" + elif t.name == BaseTy.str: + ret = "str" + elif t.name == BaseTy.Scalar: + ret = "Union[Number, _complex]" + elif t.name == BaseTy.ScalarType: + ret = "_dtype" + elif t.name == BaseTy.bool: + ret = "_bool" + elif t.name == BaseTy.QScheme: + ret = "_qscheme" + elif t.name == BaseTy.Layout: + ret = "_layout" + elif t.name == BaseTy.Device: + ret = "Optional[DeviceLikeType]" + elif t.name == BaseTy.MemoryFormat: + ret = "memory_format" + elif t.name == BaseTy.Dimname: + ret = "Union[str, ellipsis, None]" + elif t.name == BaseTy.Storage: + ret = "Union[Storage, UntypedStorage]" + elif t.name in [BaseTy.Tensor, BaseTy.Generator, BaseTy.Stream]: + # These python schema type names line up with their function schema names + ret = t.name.name + + elif isinstance(t, ListType): + if str(t.elem) == "int": + ret = "Union[_int, _size]" if t.size is not None else "_size" + elif t.is_tensor_like(): + # TODO: this doesn't seem right... 
+ # Tensor?[] currently translates to Optional[Union[Tuple[Tensor, ...], List[Tensor]]] + # It should probably translate to Union[Tuple[Optional[Tensor], ...], List[Optional[Tensor]]] + if isinstance(t.elem, OptionalType): + add_optional = True + ret = ( + "Union[Tensor, Tuple[Tensor, ...], List[Tensor]]" + if t.size is not None + else "Union[Tuple[Tensor, ...], List[Tensor]]" + ) + elif str(t.elem) == "float": + ret = "Sequence[_float]" + elif str(t.elem) == "SymInt" and t.size is not None: + elem = argument_type_str_pyi(t.elem) + ret = f"Union[{elem}, Sequence[{elem}]]" + else: + elem = argument_type_str_pyi(t.elem) + ret = f"Sequence[{elem}]" + + else: + raise RuntimeError(f"unrecognized type {repr(t)}") + + if add_optional: + ret = "Optional[" + ret + "]" + + return ret + + +def return_type_str_pyi(t: Type) -> str: + # Where arguments are open to accepting Union, return types should return + # concrete types + + if isinstance(t, OptionalType): + inner = return_type_str_pyi(t.elem) + return f"Optional[{inner}]" + + if isinstance(t, BaseType): + if t.name == BaseTy.Device: + return "_device" + elif t.name == BaseTy.Dimname: + ret = "Optional[str]" + else: + return argument_type_str_pyi(t) + + if isinstance(t, ListType): + inner = return_type_str_pyi(t.elem) + return f"Tuple[{inner}, ...]" + + return argument_type_str_pyi(t) + + +def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None: + python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns] + structseq_name = signature.name + field_names = structseq_fieldnames(signature.returns.returns) + if field_names: + # These types are structseq objects which act like named NamedTuples, but + # the constructor acts like the constructor of tuple. Using typing.NamedTuple + # does not allow us to override __init__. + seq_type = f"Tuple[{', '.join(python_returns)}]" + structseq_def_lines = [ + f"class {structseq_name}({seq_type}):", + ] + for name, typ in zip(field_names, python_returns): + structseq_def_lines.extend( + [ + " @property", + f" def {name}(self) -> {typ}: ...", + ] + ) + structseq_def_lines.extend( + [ + f" def __new__(cls, sequence: {seq_type}): ...", + f" n_fields: _int = {len(field_names)}", + f" n_sequeunce_fields: _int = {len(field_names)}", + " n_unnamed_fields: _int = 0", + " def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing", + "", # add an extra newline + ] + ) + structseq_def = "\n".join(structseq_def_lines) + # Example: + # structseq_def = ( + # "class max(Tuple[Tensor, Tensor]):\n" + # " @property\n" + # " def values(self) -> Tensor: ...\n" + # " @property\n" + # " def indices(self) -> Tensor: ...\n" + # " def __new__(cls, sequence: Tuple[Tensor, Tensor]): ...\n" + # " n_fields: _int = 2", + # " n_sequeunce_fields: _int = 2", + # " n_unnamed_fields: _int = 0", + # " def __init_subclass__(cls) -> NoReturn: ... 
# prohibit subclassing", + # ) + return structseq_name, structseq_def + return None + + +def returns_str_pyi(signature: PythonSignature) -> str: + field_names = structseq_fieldnames(signature.returns.returns) + if field_names: + return f"torch.return_types.{signature.name}" + + python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns] + if len(python_returns) > 1: + return "Tuple[" + ", ".join(python_returns) + "]" + if len(python_returns) == 1: + return python_returns[0] + return "None" + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# C++ Function Dispatch +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# This section provides APIs to generate the code that does C++ function +# dispatch. The C++ function call is wrapped by a lambda function. +# For example: +# +# // aten::selu_(Tensor(a!) self) -> Tensor(a!) +# auto dispatch_selu_ = [](Tensor self) -> Tensor { +# pybind11::gil_scoped_release no_gil; +# return at::selu_(self); +# }; +# +# The lambda function's signature follows the C++ signature in common +# cases, e.g.: +# +# // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor +# [](const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor +# +# For out variant the 'out' argument's type is changed from 'Tensor &' +# to 'Tensor'. It's because when calling the lambda it passes in the +# PythonArgParser output '_r.tensor(3)', which is stack allocated object +# and needs to pass by value. Also see comments in 'dispatch_lambda_return_str()'. +# +# // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +# [](Tensor out, const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor +# +# For multi-output case it can keep using reference type because the +# PythonArgParser output has been unpacked to local variables, e.g.: +# +# // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, +# // Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices) +# [](Tensor & max, Tensor & max_values, const Tensor & self, Dimname dim, bool keepdim) -> std::tuple +# +# For deprecated python signature, it should follow deprecated python arg order. +# TODO: This is to keep same byte-for-byte result as the old codegen - maybe unnecessary? + + +def dispatch_lambda_args( + ps: PythonSignature, f: NativeFunction, symint: bool = True +) -> tuple[DispatchLambdaArgument, ...]: + if isinstance(ps, PythonSignatureDeprecated): + schema = ps.deprecated_schema + else: + schema = f.func + + # Start with cpp arguments - dispatch lambda signature always include 'self' + cpp_args = cpp.arguments( + arguments=schema.arguments, + faithful=False, + symint=symint, + method=False, + cpp_no_default_args=f.cpp_no_default_args, + ) + out_args: set[str] = {a.name for a in schema.arguments.out} + + # Convert from cpp argument to lambda argument + def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument: + type_str = cpp_arg.type + is_out_arg = cpp_arg.name in out_args + if ps.method and cpp_arg.name == "self": + # For method's 'self', we can use 'const Tensor &' and simply ignore mutability! + type_str = "const at::Tensor &" + else: + # For other cases we need prevent dangling refs to temps (unless it's + # unpacked scattered output) + # The reason is explained in the comments above and in 'dispatch_lambda_return_str()'. + # TODO: avoid this special handling? 
+ ensure_temp_safe = len(out_args) <= 1 or not is_out_arg + if ensure_temp_safe: + type_str = { + "at::Tensor &": "at::Tensor", + }.get(type_str, type_str) + return DispatchLambdaArgument( + name=cpp_arg.name, + type_str=type_str, + is_out_arg=is_out_arg, + ) + + return tuple(map(dispatch_lambda_arg, cpp_args)) + + +# [old codegen] XXX: if you got here because of an assertion failure, it doesn't mean +# it's enough to just extend the list here. Before you do this, make sure +# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h. +SUPPORTED_RETURN_TYPES = { + "at::Tensor", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple>", + "::std::vector", + # Needed for flash attention forw/backward + "::std::tuple", + "at::Scalar", + "bool", + "int64_t", + "void*", + "void", + "at::QScheme", + "double", + "at::IntArrayRef", + "at::ScalarType", + "at::Stream", +} + + +def dispatch_lambda_return_str(f: NativeFunction) -> str: + # [old codegen] Remove type annotation (e.g. 'Tensor' rather than 'Tensor &') + # because the dispatch lambdas take mutable arguments *by value*, not + # by reference. If you then return a reference to such an argument, you + # will now have a pointer to a dangling stack entry. Not good. + # + # You want: + # + # auto dispatch_selu_ = [](Tensor self) -> Tensor { ...; return at::selu_(self); }; + # ^^^^^^ + # + # *not* + # + # auto dispatch_selu_ = [](Tensor self) -> Tensor& { ...; return at::selu_(self); }; + # ^^^^^^^ + # + # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing + # codegen looks like dispatch_selu_(_r.tensor(0)), and you can't take a + # mutable reference to temporary. Maybe we could assign it to a + # variable itself.) + returns_without_annotation = tuple( + Return(r.name, r.type, None) for r in f.func.returns + ) + return_str = cpp.returns_type(returns_without_annotation, symint=True).cpp_type() + if return_str not in SUPPORTED_RETURN_TYPES: + raise RuntimeError(f"{f.func.name} returns unsupported type {return_str}") + return return_str + + +def cpp_dispatch_target(f: NativeFunction) -> str: + symint = f.func.has_symint() + name = cpp.name(f.func, symint_overload=symint) + if Variant.method in f.variants: + return f"self.{name}" + if Variant.function in f.variants: + if has_tensor_options(f) or f.func.name.name.base.endswith("_like"): + namespace = "torch" + else: + namespace = "at" + return f"{namespace}::{name}" + raise RuntimeError(f"could not dispatch, neither function nor method: {f.func}") + + +def cpp_dispatch_exprs( + f: NativeFunction, + *, + python_signature: PythonSignature | None = None, +) -> tuple[str, ...]: + cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments() + + exprs: tuple[str, ...] = () + if not isinstance(python_signature, PythonSignatureDeprecated): + # By default the exprs are consistent with the C++ signature. + exprs = tuple(a.name for a in cpp_args) + else: + # For deprecated python signature we may need fill in some constants. 
+ exprs = tuple( + filter( + lambda n: n != "out" or f.func.is_out_fn(), + python_signature.deprecated_args_exprs, + ) + ) + + if Variant.method in f.variants: + exprs = tuple(filter("self".__ne__, exprs)) + + return exprs + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Python / C++ Args Binding +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +# We explicitly enumerate the PythonArgParser unpacking methods for all +# supported types. This might be more verbose than necessary, partially +# because of the irregularity of unpacking method naming, partially +# because we want to mimic the old codegen behavior - to reject +# unexpected and/or unsupported cases which the old codegen rejects. +# For certain cases it is intentionally more restrictive than necessary, +# e.g.: it doesn't accepts doublelist with definite size. +def arg_parser_unpack_method( + t: Type, default: str | None, default_init: str | None, *, symint: bool = True +) -> str: + has_default_init = default_init is not None + if has_default_init and str(t) not in ( + "ScalarType?", + "ScalarType", + "Device", + "Device?", + "Layout", + "Layout?", + "bool", + "bool?", + ): + raise RuntimeError(f"type '{t}' does not supported unpacking with default") + + if isinstance(t, BaseType): + if t.name in [ + BaseTy.Tensor, + BaseTy.Stream, + BaseTy.Storage, + BaseTy.Scalar, + BaseTy.Dimname, + ]: + # These unpack methods line up with their schema names + return t.name.name.lower() + elif t.name == BaseTy.ScalarType: + return "scalartypeWithDefault" if has_default_init else "scalartype" + elif t.name == BaseTy.Device: + return "deviceWithDefault" if has_default_init else "device" + elif t.name == BaseTy.DeviceIndex: + return "toInt64" + elif t.name == BaseTy.int: + return "toInt64" + elif t.name == BaseTy.SymInt: + return "toSymInt" if symint else "toInt64" + elif t.name == BaseTy.bool: + return "toBoolWithDefault" if has_default_init else "toBool" + elif t.name == BaseTy.float: + return "toDouble" + elif t.name == BaseTy.str: + return "stringView" + elif t.name == BaseTy.Layout: + return "layoutWithDefault" if has_default_init else "layout" + elif t.name == BaseTy.MemoryFormat: + return "memoryformat" + + elif isinstance(t, OptionalType): + if str(t.elem) == "Tensor": + return "optionalTensor" + elif str(t.elem) == "Generator": + return "generator" + elif str(t.elem) == "Dimname[]": + return "toDimnameListOptional" + elif not has_default_init and default in ( + None, + "None", + "::std::nullopt", + "std::nullopt", + ): + # If default is None: append 'Optional' to elem's unpacking method + return ( + arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional" + ) + else: + # Otherwise, load as underlying type with default + return arg_parser_unpack_method( + t.elem, default, default_init, symint=symint + ) + + elif isinstance(t, ListType): + if str(t.elem) == "Tensor": + # accept and use definite size + return f"tensorlist_n<{t.size}>" if t.size is not None else "tensorlist" + elif str(t.elem) == "Tensor?": + return "list_of_optional_tensors" + elif str(t.elem) == "Dimname": + # accept definite size + return "dimnamelist" + elif str(t.elem) == "int": + # accept definite size + return "intlist" + elif str(t.elem) == "float": + return "doublelist" + elif str(t.elem) == "SymInt": + # accept definite size + return "symintlist" if symint else "intlist" + elif str(t.elem) == "Scalar": + return "scalarlist" + raise RuntimeError(f"type '{t}' is not supported by 
PythonArgParser") + + +# Return RHS expression for python argument using PythonArgParser output. +# e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)' +def arg_parser_output_expr( + arg_index: int, a: PythonArgument, *, symint: bool = True +) -> PythonArgParserOutputExpr: + has_default = a.default_init is not None + unpack_method = arg_parser_unpack_method( + t=a.type, default=a.default, default_init=a.default_init, symint=symint + ) + default = f", {a.default_init}" if has_default else "" + expr = f"_r.{unpack_method}({arg_index}{default})" + + return PythonArgParserOutputExpr( + name=a.name, + expr=expr, + index=arg_index, + argument=a, + ) + + +# Returns a map with key = arg_name and value = PythonArgParserOutputExpr. +def arg_parser_output_exprs( + ps: PythonSignature, f: NativeFunction, *, symint: bool = True +) -> dict[str, PythonArgParserOutputExpr]: + return { + e.name: e + for i, a in enumerate(ps.arguments()) + for e in (arg_parser_output_expr(i, a, symint=symint),) + } + + +# argument name to type for scattered tensor options fields +TENSOR_OPTIONS_FIELDS = { + "dtype": "ScalarType?", + "device": "Device?", + "layout": "Layout?", + "pin_memory": "bool?", + "requires_grad": "bool?", +} + + +# bind arg parser outputs (python args) with dispatch lambda arguments (c++ args). +def dispatch_lambda_exprs( + ps: PythonSignature, f: NativeFunction, *, symint: bool = True +) -> DispatchLambdaArgumentExprs: + # This method is to bind 'arg_parser_outputs' and 'lambda_args' by producing + # 'inits' and 'lambda_args_exprs' for each lambda argument using arg parser + # outputs. + arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint) + lambda_args = dispatch_lambda_args(ps, f, symint=symint) + inits: list[str] = [] + lambda_args_exprs: dict[str, str] = {} + + has_toptions = has_tensor_options(f) + + # 1. special inits/unpacking to provide binding exprs for lambda arguments. + for a in ps.arguments(skip_tensor_options=True): + name = a.name + arg_parser_expr = arg_parser_outputs[a.name].expr + + if has_toptions and name == "self": + # TODO: why this needs to be special case? + inits.extend( + [ + f"auto self = {arg_parser_expr};", + ] + ) + lambda_args_exprs[name] = name + elif ( + isinstance(a, PythonOutArgument) + and len(a.outputs) > 1 + and f.func.is_out_fn() + ): + inits.extend( + [ + f"auto out = {arg_parser_expr};", + ] + ) + for i, out_arg in enumerate(a.outputs): + lambda_args_exprs[out_arg.name] = f"out[{i}]" + elif str(a.type) == "Dimname[]?": + # [old codegen] + # TODO: make this part of something more general, or get rid of it. + # optional> are special. The PythonArgParser returns an + # optional>, which cannot be implicitly converted to + # optional>. One needs to unwrap the optional and rewrap. + inits.extend( + [ + f"auto __{name} = {arg_parser_expr};", + f"::std::optional {name} = __{name} ? ::std::make_optional(DimnameList(__{name}.value())) : ::std::nullopt;", # noqa: B950 + ] + ) + lambda_args_exprs[name] = name + else: + # default case - directly using PythonArgParser output expr + lambda_args_exprs[name] = arg_parser_expr + + # method's self is passed directly to python binding, rather than parsed + if ps.method: + lambda_args_exprs["self"] = "self" + + # 2. special packing/checking for TensorOptions. 
+ tensor_options_args_names = [a.name for a in ps.tensor_options_args] + if has_toptions: + if f.func.is_out_fn(): + raise RuntimeError(f"{f.func}: tensor options with output arg") + for a in ps.tensor_options_args: + if a.name not in TENSOR_OPTIONS_FIELDS: + raise RuntimeError( + f"{f.func}: unrecognized tensor options field '{a.name}' in python binding arguments" + ) + if str(a.type) != TENSOR_OPTIONS_FIELDS.get(a.name): + raise RuntimeError( + f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'" + ) + if not all(a in tensor_options_args_names for a in TENSOR_OPTIONS_FIELDS): + raise RuntimeError( + f"{f.func}: incomplete tensor options args: {tensor_options_args_names}" + ) + + inits.append( + f"""\ +const auto options = TensorOptions() + .dtype({arg_parser_outputs['dtype'].expr}) + .device({arg_parser_outputs['device'].expr}) + .layout({arg_parser_outputs['layout'].expr}) + .requires_grad({arg_parser_outputs['requires_grad'].expr}) + .pinned_memory({arg_parser_outputs['pin_memory'].expr}); +torch::utils::maybe_initialize_device(options); +""" + ) + lambda_args_exprs["options"] = "options" + + # 3. special case - access scattered TensorOptions fields without packing + # TODO: maybe move to the generator side as it's not related to binding. + if not has_toptions and tensor_options_args_names: + if "dtype" in tensor_options_args_names: + # we're an output-arg variant, check these args against output tensor + if not f.func.is_out_fn(): + raise RuntimeError( + f"{f.func}: dtype in tensor_options_args without output arg, {ps} {ps.arguments}" + ) + if not all(a in tensor_options_args_names for a in ("layout", "device")): + raise RuntimeError( + f"{f.func}: incomplete tensor options for output check" + ) + + inits.append( + f"""\ +check_out_type_matches({arg_parser_outputs['out'].expr}, {arg_parser_outputs['dtype'].expr}, + {arg_parser_outputs['dtype'].is_none_expr}, {arg_parser_outputs['layout'].expr}, + {arg_parser_outputs['device'].expr}, {arg_parser_outputs['device'].is_none_expr}); +""" + ) + # we'll set requires_grad on outgoing tensor + if "requires_grad" not in tensor_options_args_names: + raise RuntimeError( + f'{f.func}: expected "requires_grad" in tensor_options_args absent, but found [{tensor_options_args_names}]' + ) + + return DispatchLambdaArgumentExprs( + exprs=tuple(lambda_args_exprs[a.name] for a in lambda_args), + inits=inits, + ) diff --git a/lib/python3.10/site-packages/torchgen/api/structured.py b/lib/python3.10/site-packages/torchgen/api/structured.py new file mode 100644 index 0000000000000000000000000000000000000000..93a72eb2b4a5c119ee8f60ce04f0517fe862b4d5 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/api/structured.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +from torchgen.api import cpp +from torchgen.api.types import ( + ArgName, + ArrayRefCType, + BaseCType, + Binding, + ConstRefCType, + dimnameListT, + intArrayRefT, + iOptTensorListRefT, + iTensorListRefT, + NamedCType, + OptionalCType, + optionalIntArrayRefT, + optionalScalarRefT, + optionalTensorRefT, + scalarT, + tensorT, +) +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + ListType, + NativeFunctionsGroup, + OptionalType, + SelfArgument, + TensorOptionsArguments, + Type, +) +from torchgen.utils import assert_never + + +# This file describes the translation of JIT schema to the structured functions API. +# This is similar to native API, but a number of historical problems with native +# API have been fixed. 
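+# Editor's illustrative sketch (not part of the upstream torchgen sources):
+# a tiny helper, assuming torchgen is importable, that shows which C++ types
+# this structured API assigns to a few common JIT argument types. The binder
+# name "x" is an arbitrary placeholder; names used in the body resolve lazily
+# at call time, after the real definitions later in this module have run.
+def _structured_type_demo() -> list[str]:
+    examples = [
+        BaseType(BaseTy.Tensor),                        # Tensor  -> const at::Tensor &
+        OptionalType(BaseType(BaseTy.Tensor)),          # Tensor? -> at::OptionalTensorRef
+        ListType(elem=BaseType(BaseTy.int), size=None), # int[]   -> at::IntArrayRef
+    ]
+    return [
+        argumenttype_type(t, mutable=False, binds="x").cpp_type() for t in examples
+    ]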
+ + +# Translation of types occurring in JIT arguments to a C++ argument type. +# NB: For now, mutable doesn't do anything; but it could if we make +# some more nominal types +def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType: + # If it's a value type, do the value type translation + # NB: structured kernels ALWAYS have symint off, since they involve actual + # kernels that require real ints. The one exception is the + # CompositeExplicitAutograd and the meta function (which could + # hypothetically be SymInt), but for simplicity we plan for these to just + # be handled in Python + r = cpp.valuetype_type(t, symint=False, binds=binds, mutable=mutable) + if r is not None: + return r + + if isinstance(t, BaseType): + if t.name == BaseTy.Tensor: + return NamedCType(binds, ConstRefCType(BaseCType(tensorT))) + elif t.name == BaseTy.Scalar: + return NamedCType(binds, ConstRefCType(BaseCType(scalarT))) + else: + raise AssertionError(f"base type should have been value type {t}") + elif isinstance(t, OptionalType): + if t.elem == BaseType(BaseTy.Tensor): + return NamedCType(binds, BaseCType(optionalTensorRefT)) + elif t.elem == BaseType(BaseTy.Scalar): + return NamedCType(binds, BaseCType(optionalScalarRefT)) + elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int": + return NamedCType(binds, BaseCType(optionalIntArrayRefT)) + elem = argumenttype_type(t.elem, mutable=mutable, binds=binds) + return NamedCType(binds, OptionalCType(elem.type)) + elif isinstance(t, ListType): + if t.elem == BaseType(BaseTy.Tensor): + return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT))) + elif t.elem == OptionalType(BaseType(BaseTy.Tensor)): + return NamedCType(binds, BaseCType(iOptTensorListRefT)) + # TODO: delete these special cases; see torchgen.api.cpp--these + # must be changed in tandem, but there are problems; see + # https://github.com/pytorch/pytorch/pull/51485 + elif str(t.elem) == "int": + return NamedCType(binds, BaseCType(intArrayRefT)) + elif str(t.elem) == "Dimname": + return NamedCType(binds, BaseCType(dimnameListT)) + elem = argumenttype_type(t.elem, mutable=mutable, binds=binds) + return NamedCType(binds, ArrayRefCType(elem.type)) + else: + raise AssertionError(f"unrecognized type {repr(t)}") + + +def argument_type(a: Argument, *, binds: ArgName) -> NamedCType: + return argumenttype_type(a.type, mutable=a.is_write, binds=binds) + + +# returns_type intentionally omitted, because structured kernels never "return"; +# instead, they always indirectly report their outputs (in the case of a meta +# function, by calling set_output; in the case of an impl function, by writing +# directly into the provided out argument). + + +# Structured kernels are never defaulted +def argument(a: Argument | SelfArgument | TensorOptionsArguments) -> list[Binding]: + if isinstance(a, Argument): + return [ + Binding( + nctype=argument_type(a, binds=a.name), + name=a.name, + default=None, + argument=a, + ) + ] + elif isinstance(a, SelfArgument): + return argument(a.argument) + elif isinstance(a, TensorOptionsArguments): + raise AssertionError("structured kernels don't support TensorOptions yet") + else: + assert_never(a) + + +def impl_arguments(g: NativeFunctionsGroup) -> list[Binding]: + args: list[Argument | TensorOptionsArguments | SelfArgument] = [] + + if g.out.precomputed: + # A list of parameters for the impl function with + # certain parameters replaced with precomputed counterparts + # as specified in native_functions.yaml. 
+ non_out_args_replaced: list[ + Argument | TensorOptionsArguments | SelfArgument + ] = [] + for a in g.out.func.arguments.non_out: + if isinstance(a, Argument) and a.name in g.out.precomputed.replace: + # If a is in precompute.replace, append the parameters + # that should replace it onto non_out_args_replaced. + non_out_args_replaced.extend(g.out.precomputed.replace[a.name]) + else: + # If not, push a as it is. + non_out_args_replaced.append(a) + + args.extend(non_out_args_replaced) + # g.out.precomputed.add is the list of parameters that are added + # without replacement after the non out args and just before the out args + args.extend(g.out.precomputed.add) + else: + args.extend(g.out.func.arguments.non_out) + + args.extend(g.out.func.arguments.out) + return [r for arg in args for r in argument(arg)] + + +def meta_arguments(g: NativeFunctionsGroup) -> list[Binding]: + args: list[Argument | TensorOptionsArguments | SelfArgument] = [] + args.extend(g.functional.func.arguments.non_out) + return [r for arg in args for r in argument(arg)] + + +def out_arguments(g: NativeFunctionsGroup) -> list[Binding]: + args: list[Argument | TensorOptionsArguments | SelfArgument] = [] + args.extend(g.out.func.arguments.out) + return [r for arg in args for r in argument(arg)] diff --git a/lib/python3.10/site-packages/torchgen/api/translate.py b/lib/python3.10/site-packages/torchgen/api/translate.py new file mode 100644 index 0000000000000000000000000000000000000000..761fb3c7c2b98707bd9b9f79a8a5842fc7ce11a8 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/api/translate.py @@ -0,0 +1,433 @@ +from __future__ import annotations + +from typing import NoReturn, Sequence + +from torchgen.api.types import ( + ArrayRefCType, + BaseCType, + Binding, + boolT, + ConstRefCType, + deviceT, + Expr, + intArrayRefT, + iOptTensorListRefT, + layoutT, + ListCType, + longT, + memoryFormatT, + MutRefCType, + NamedCType, + opmath_t, + OptionalCType, + optionalIntArrayRefT, + optionalScalarRefT, + optionalSymIntArrayRefT, + optionalTensorRefT, + scalar_t, + scalarT, + scalarTypeT, + SpecialArgName, + symIntArrayRefT, + SymIntT, + tensorOptionsT, + tensorT, + VectorCType, +) + + +# This file implements a small program synthesis engine that implements +# conversions between one API to another. +# +# The key data type in this file in NamedCType, short for Named C++ semantic type. A NamedCType +# represents a C++ type, plus semantic information about what it represents. +# For example, consider the argument "bool pin_memory"; its normal C++ type is +# "bool", but its C++ semantic type also keeps track that this represents a +# "pin_memory"; you can't just use a random other boolean in a context where you +# need a "pin_memory"! +# +# The translator takes a list of needed NamedCTypes, and then figures out how +# to construct expressions with these NamedCTypes from the given bindings. Many +# of these expressions are trivial (I need a Tensor other; there's a Tensor +# other scope); others are more nontrivial and may require packing/unpacking. +# Some examples of non-trivial action: +# +# - Need the "dtype" binding? Well, maybe "dtype" isn't available +# in the context, instead, "options" is, and you need to extract +# it from there. (Gather) +# +# - Need the "context" binding? Well, maybe "context" isn't available +# in the context, and you need to construct it from "dtype", "device", +# etc. (Scatter) +# +# - Need the "memory_format" binding? 
Well, actually, it's available +# from both "memory_format" and "options", so you had better make sure +# they are consistent. (Join) + +options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT))) + +out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT))) + +longVec_ctype = VectorCType(BaseCType(longT)) +longSymVec_ctype = VectorCType(BaseCType(SymIntT)) +optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT))) +optionalScalar_ctype = OptionalCType(BaseCType(scalarT)) +optionalTensor_ctype = OptionalCType(BaseCType(tensorT)) + + +class UnsatError(RuntimeError): + pass + + +# Given a set of in-scope bindings and a set of target bindings, synthesize +# a list of expressions that uses only the in-scope bindings (bindings) that +# have all of the types of goals. You may want to use this function if +# you're generating code for a function like: +# +# void f({args}) { +# g({exprs}); // g is a different API +# } +# +# and you need to generate "exprs". +# +# Typically, a list of Bindings is convenient to get (you usually call something +# like arguments() to get them); but technically you only need less information: +# for 'bindings' an (un-ordered) list of Exprs is sufficient; similarly, for +# 'goals', an (ordered) list of NamedCType goals is sufficient. If you are doing +# something more complicated, e.g., tracking the set of bindings in a context, +# you may find using these smaller types more convenient. +def translate( + bindings: Sequence[Expr | Binding], + goals: Sequence[NamedCType | Binding], + *, + method: bool = False, + allow_expensive_conversions: bool = False, +) -> list[Expr]: + binding_exprs: list[Expr] = [] + for b in bindings: + if isinstance(b, Binding): + binding_exprs.append( + Expr( + expr=b.name, + type=b.nctype, + ) + ) + else: + binding_exprs.append(b) + + goal_ctypes: list[NamedCType] = [] + for g in goals: + if isinstance(g, Binding): + goal_ctypes.append(g.nctype) + else: + goal_ctypes.append(g) + + # Add all the bindings to the context + ctx: dict[NamedCType, str] = {} + for b in binding_exprs: + ctx[b.type] = b.expr + + # While we're at it, do some simple forward inference, looking through + # constructors. + # + # NB: When should you do forward inference versus backward inference? + # The general idea: + # + # - Backward inference WHEN the goal gets smaller + # - Forward inference WHEN the hypothesis gets smaller + # + # This helps ensure termination: backward inference starts with a goal + # and tries to make it simpler and simpler until it's trivial; if the + # goal can grow in size, we blow up to a really huge goal size. + # Similarly, with forward inference we take hypotheses and decompose + # them into simpler hypotheses; if hypotheses could expand in size, + # we also have potential nontermination. (In the code below, forward + # inference is only ever carried out at a single step, but you could + # imagine repeated application of forward inference being profitable.) + # + # A good starting point in the literature for exploring more about proof + # search are these lecture notes + # https://www.cs.cmu.edu/~fp/courses/oregon-m10/04-focusing.pdf + # + # TODO: My kingdom for a pattern matcher + # https://www.python.org/dev/peps/pep-0634/ + # + # TODO: This could get us in recomputation trouble if b.expr is nontrivial. + # Fix this by implementing some sort of sharing so that if multiple + # goals share the same expression, we only compute it once. 
This seems + # to matter in practice as compiler is often unwilling to CSE nontrivial + # expressions like scalar.to() + t = b.type + if ( + isinstance(t, ConstRefCType) + and isinstance(t.elem, OptionalCType) + and isinstance(t.elem.elem, BaseCType) + and str(t.elem.elem.type) == "at::Tensor" + ): + ctx[ + NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT))) + ] = f"({b.expr}.has_value() ? *{b.expr} : at::Tensor())" + + if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))): + ctx[ + NamedCType(t.name, BaseCType(optionalTensorRefT)) + ] = f"(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())" + + if t.type == ConstRefCType(BaseCType(scalarT)): + ctx[NamedCType(t.name, BaseCType(opmath_t))] = f"({b.expr}).to()" + + if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))): + ctx[ + NamedCType(t.name, BaseCType(optionalScalarRefT)) + ] = f"({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())" + + if t.type == BaseCType(scalar_t): + ctx[ + NamedCType(t.name, BaseCType(opmath_t)) + ] = f"static_cast({b.expr})" + + # [Note: IOptTensorListRef] + if t.type == ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))): + ctx[ + NamedCType(t.name, BaseCType(iOptTensorListRefT)) + ] = f"at::IOptTensorListRef({b.expr})" + + # Add implicit bindings if the generated code is inside a Tensor method + if method: + ctx[ + NamedCType("self", MutRefCType(BaseCType(tensorT))) + ] = "const_cast(*this)" + ctx[ + NamedCType("self", ConstRefCType(BaseCType(tensorT))) + ] = "const_cast(*this)" + # This is better! Byte-for-byte compat + # ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this" + + def unsat(goal: NamedCType) -> NoReturn: + ctx_desc = "\n".join( + f" {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items() + ) + raise UnsatError( + f""" +Failed to synthesize the expression "{goal.cpp_type()} {goal.name}". +When I failed, the following bindings were available in the context: + +{ctx_desc} + +This probably means there is a missing rule in the rules of torchgen.api.translate. +Check this module for more information. +""" + ) + + # A shitty backtracking search implementation. It's shitty because it + # does backtracking via stack (bad idea!) and for the most part tries to + # avoid backtracking. In particular, if + # direct=True, we won't try to do any fancy synthesis, just trivial + # conversions (e.g., "T a" is OK for "const T& a"). So all of the + # existing rules in this function simply try to solve immediately, + # and bail if things don't work out. 
+ def solve(goal: NamedCType, *, direct: bool) -> str: + def direct_solve(goal: NamedCType) -> str: + return solve(goal, direct=True) + + if goal in ctx: + # Trivial + return ctx[goal] + + # const & is satisfied with mutable & + if isinstance(goal.type, ConstRefCType): + try: + # WARNING: not strictly decreasing; be careful not + # to add a direct conversion that goes satisfies + # mutable& with const& + return solve( + NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct + ) + except UnsatError: + pass + + # mutable & is satisfied with value + if isinstance(goal.type, MutRefCType): + try: + return solve(NamedCType(goal.name, goal.type.elem), direct=direct) + except UnsatError: + pass + + # TODO: These are referentially equal, shouldn't have to do this; + # ensuring we don't use type synonym IntArrayRef in codegen would + # help + if goal.type == ArrayRefCType(BaseCType(longT)): + return solve(NamedCType(goal.name, BaseCType(intArrayRefT)), direct=direct) + + if direct: + unsat(goal) + + # For now, all of these rules are mutually exclusive. + if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))): + memory_format = direct_solve( + NamedCType( + SpecialArgName.possibly_redundant_memory_format, + OptionalCType(BaseCType(memoryFormatT)), + ) + ) + # No need to join "memory_format" and "options" if the target API takes "options" directly. + # Otherwise it will cause the redundant memory_format error. + if options_ctype in goal_ctypes: + return memory_format + try: + options = direct_solve(options_ctype) + return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})" + except UnsatError: + return memory_format + elif goal == NamedCType("options", BaseCType(tensorOptionsT)): + dtype = direct_solve( + NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))) + ) + pin_memory = direct_solve( + NamedCType("pin_memory", OptionalCType(BaseCType(boolT))) + ) + device = direct_solve( + NamedCType("device", OptionalCType(BaseCType(deviceT))) + ) + layout = direct_solve( + NamedCType("layout", OptionalCType(BaseCType(layoutT))) + ) + return f"TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})" + + elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))): + try: + options = direct_solve(options_ctype) + return f"c10::optTypeMetaToScalarType({options}.dtype_opt())" + except UnsatError: + out_tensor = direct_solve(out_tensor_ctype) + return f"{out_tensor}.scalar_type()" + + elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))): + try: + options = direct_solve(options_ctype) + return f"{options}.layout_opt()" + except UnsatError: + out_tensor = direct_solve(out_tensor_ctype) + return f"{out_tensor}.layout()" + + elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))): + try: + options = direct_solve(options_ctype) + return f"{options}.device_opt()" + except UnsatError: + out_tensor = direct_solve(out_tensor_ctype) + return f"{out_tensor}.device()" + + elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))): + try: + options = direct_solve(options_ctype) + return f"{options}.pinned_memory_opt()" + except UnsatError: + # If we're calling a factory op from its out= variant, + # We don't actually care about the value of pin_memory. 
+ out_tensor = direct_solve(out_tensor_ctype) + return "::std::nullopt" + + # We can always do translations from value types to reference types, like vector -> IntArrayRef + elif goal.type == BaseCType(intArrayRefT): + try: + return direct_solve(NamedCType(goal.name, longVec_ctype)) + except UnsatError: + # We can also go SymIntArrayRef -> IntArrayRef + symIntArrayRef_type = direct_solve( + NamedCType(goal.name, BaseCType(symIntArrayRefT)) + ) + return f"C10_AS_INTARRAYREF_SLOW({symIntArrayRef_type})" + elif goal.type == BaseCType(symIntArrayRefT): + try: + r = direct_solve(NamedCType(goal.name, BaseCType(intArrayRefT))) + return f"c10::fromIntArrayRefSlow({r})" + except UnsatError: + return direct_solve(NamedCType(goal.name, longSymVec_ctype)) + elif goal.type == BaseCType(SymIntT): + return direct_solve(NamedCType(goal.name, BaseCType(longT))) + elif goal.type == OptionalCType(BaseCType(SymIntT)): + argname = direct_solve( + NamedCType(goal.name, OptionalCType(BaseCType(longT))) + ) + return f"{argname}.has_value() ? ::std::make_optional(c10::SymInt(*{argname})) : ::std::nullopt" + elif goal.type == BaseCType(longT): + symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT))) + return f"{symInt_type}.guard_int(__FILE__, __LINE__)" + elif goal.type == OptionalCType(BaseCType(longT)): + argname = direct_solve( + NamedCType(goal.name, OptionalCType(BaseCType(SymIntT))) + ) + return f"{argname}.has_value() ? ::std::make_optional({argname}->guard_int(__FILE__, __LINE__)) : ::std::nullopt" + elif goal.type == BaseCType(optionalIntArrayRefT): + try: + return direct_solve(NamedCType(goal.name, optionalLongVec_ctype)) + except UnsatError: + argname = direct_solve( + NamedCType(goal.name, BaseCType(optionalSymIntArrayRefT)) + ) + return f"{argname}.has_value() ? ::std::make_optional(C10_AS_INTARRAYREF_SLOW(*{argname})) : ::std::nullopt" + elif goal.type == BaseCType(optionalSymIntArrayRefT): + # TODO: You might also want to solve this from longSymVec_ctype or + # an optional version of it + argname = direct_solve( + NamedCType(goal.name, BaseCType(optionalIntArrayRefT)) + ) + return f"{argname}.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*{argname})) : ::std::nullopt" + elif goal.type == BaseCType(optionalScalarRefT): + return direct_solve(NamedCType(goal.name, optionalScalar_ctype)) + elif goal.type == BaseCType(optionalTensorRefT): + return direct_solve(NamedCType(goal.name, optionalTensor_ctype)) + + # Note [translation from C++ reference to value types] + # The below cases are all for when we have an argument with a reference type, + # and a corresponding goal with a value type. + # These are needed when we populate the inputs to a lambda capture and we need + # to guarantee the lifetime of each captured argument. + # We guard it with an explicit kwarg because converting to a value type is expensive + # (O(n)) to convert from IntArrayRef to vector), + # so the caller of translate() should be explicit that they need it. 
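# Editorial illustration (not part of the upstream file; assumes torchgen is
# importable): the kind of conversion this flag unlocks.
#
#   size = Expr("size", NamedCType("size", BaseCType(intArrayRefT)))
#   goal = NamedCType("size", VectorCType(BaseCType(longT)))
#   [e] = translate([size], [goal], allow_expensive_conversions=True)
#   assert e.expr == "size.vec()"   # copies the ArrayRef into a std::vector
#
# Without allow_expensive_conversions=True the same call raises UnsatError.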
+ if allow_expensive_conversions: + if goal.type == VectorCType(BaseCType(longT)): + intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT)) + argname = direct_solve(intArrayRef_ctype) + return f"{argname}.vec()" + if goal.type == VectorCType(BaseCType(SymIntT)): + symIntArrayRef_ctype = NamedCType(goal.name, BaseCType(symIntArrayRefT)) + argname = direct_solve(symIntArrayRef_ctype) + return f"{argname}.vec()" + elif goal.type == OptionalCType(VectorCType(BaseCType(longT))): + optionalIntArrayRef_ctype = NamedCType( + goal.name, BaseCType(optionalIntArrayRefT) + ) + argname = direct_solve(optionalIntArrayRef_ctype) + return f"{argname}.has_value() ? ::std::make_optional({argname}->vec()) : ::std::nullopt" + elif goal.type == OptionalCType(BaseCType(scalarT)): + optionalScalarRef_ctype = NamedCType( + goal.name, BaseCType(optionalScalarRefT) + ) + argname = direct_solve(optionalScalarRef_ctype) + return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt" + elif goal.type == OptionalCType(BaseCType(scalarT)): + optionalTensorRef_ctype = NamedCType( + goal.name, BaseCType(optionalTensorRefT) + ) + argname = direct_solve(optionalTensorRef_ctype) + return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt" + # Technically, we also need to handle cases of C++ containers holding reference types. + # But there currently aren't any ops that require lambda capture codegen + # With arguments like ::std::vector. + # If that changes, we'll have to add the translation here. + + # We allow const casting on tensors, since const-correctness is a bit broken for at::Tensor. + # We could probably generalize this to non-tensor types too. + if goal.type == MutRefCType(BaseCType(tensorT)): + const_ref_tensor_ctype = NamedCType( + goal.name, ConstRefCType(BaseCType(tensorT)) + ) + argname = direct_solve(const_ref_tensor_ctype) + return f"const_cast({argname})" + + unsat(goal) + + return [Expr(solve(g, direct=False), g) for g in goal_ctypes] diff --git a/lib/python3.10/site-packages/torchgen/api/types/__init__.py b/lib/python3.10/site-packages/torchgen/api/types/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4e98bb8df493f2375b514e6c6aeb897cebe8ec7d --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/api/types/__init__.py @@ -0,0 +1,5 @@ +from torchgen.api.types.types import * +from torchgen.api.types.types_base import * + + +from torchgen.api.types.signatures import * # usort: skip diff --git a/lib/python3.10/site-packages/torchgen/api/types/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/api/types/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ec272841253753b2345d503dd5fcb03da91700c Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/api/types/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/api/types/__pycache__/signatures.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/api/types/__pycache__/signatures.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9a92c168ab93499583bd32d8ee5b06eae5426ed Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/api/types/__pycache__/signatures.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/api/types/__pycache__/types.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/api/types/__pycache__/types.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..8e0a3294b99beb1615a3fbf7fed0abf8abd0b011 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/api/types/__pycache__/types.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/api/types/__pycache__/types_base.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/api/types/__pycache__/types_base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a20d33f21c7fecb5d2a9966aee5f7b7c9d0f6b0a Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/api/types/__pycache__/types_base.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/api/types/signatures.py b/lib/python3.10/site-packages/torchgen/api/types/signatures.py new file mode 100644 index 0000000000000000000000000000000000000000..f7d85ca6e2fe88e3a7047b2f3b1c887f5e583846 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/api/types/signatures.py @@ -0,0 +1,426 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterator, Sequence, TYPE_CHECKING + +from torchgen.api.types.types_base import Binding, CType, Expr + + +if TYPE_CHECKING: + from torchgen.model import ( + BackendIndex, + FunctionSchema, + NativeFunction, + NativeFunctionsGroup, + NativeFunctionsViewGroup, + ) + + +@dataclass(frozen=True) +class CppSignature: + """ + A CppSignature represents a single overload in the C++ API. For + any given function schema, there may be multiple CppSignatures + corresponding to it, based on how we desugar to C++. See also + CppSignatureGroup. + """ + + # The schema this signature is derived from + func: FunctionSchema + + # Is this a C++ signature for a method, i.e. Tensor::my_op(...)? + method: bool + + # Is this a faithful C++ signature (i.e. following the JIT schema) or a convenience API + # (i.e. with a potential TensorOptions argument and out arguments in the front) + faithful: bool + + # Is this a symint C++ signature. For BC reasons, functions that take + # SymInts still present as int64_t in C++, and the SymInt variant is + # offered at a different overload name + # + # NB: If a function RETURNS a SymInt, this is ALWAYS false + symint: bool + + # The set of C++ arguments which should not have defaults applied to them + cpp_no_default_args: set[str] + + # Is this a fallback C++ binding? Fallback bindings are enabled by + # manual_cpp_binding: True and are alternate, non-public API that + # lets manual C++ binding implementors access the binding that would + # have been automatically generated + fallback_binding: bool = False + + # Return the unpacked argument structure of this signature, + # discarding information about which arguments are semantically + # related to each other. 
+ def arguments(self) -> Sequence[Binding]: + return cpp.arguments( + self.func.arguments, + faithful=self.faithful, + symint=self.symint, + method=self.method, + cpp_no_default_args=self.cpp_no_default_args, + ) + + def name(self, *, suppress_symint_suffix: bool = False) -> str: + n = cpp.name( + self.func, + faithful_name_for_out_overloads=self.faithful, + symint_overload=False if suppress_symint_suffix else self.symint, + ) + if self.fallback_binding: + n = f"__dispatch_{n}" + return n + + # Render the C++ declaration for this signature + def decl( + self, + *, + name: str | None = None, + prefix: str = "", + is_redispatching_fn: bool = False, + suppress_symint_suffix: bool = False, + ) -> str: + returns_type = cpp.returns_type( + self.func.returns, symint=self.symint + ).cpp_type() + cpp_args = [a.decl() for a in self.arguments()] + if is_redispatching_fn: + cpp_args = ["c10::DispatchKeySet dispatchKeySet"] + cpp_args + cpp_args_str = ", ".join(cpp_args) + if name is None: + name = prefix + self.name(suppress_symint_suffix=suppress_symint_suffix) + return f"{returns_type} {name}({cpp_args_str})" + + # Render the C++ definition for this signature, not including + # the body (with curly braces) + def defn( + self, + *, + name: str | None = None, + prefix: str = "", + is_redispatching_fn: bool = False, + ) -> str: + returns_type = cpp.returns_type( + self.func.returns, symint=self.symint + ).cpp_type() + cpp_args = [a.defn() for a in self.arguments()] + if is_redispatching_fn: + cpp_args = ["c10::DispatchKeySet dispatchKeySet"] + cpp_args + cpp_args_str = ", ".join(cpp_args) + if name is None: + name = prefix + self.name() + return f"{returns_type} {name}({cpp_args_str})" + + def ptr_type(self) -> str: + args_types_str = ", ".join(a.type for a in self.arguments()) + return f"{cpp.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_types_str})" + + # Return the C++ function type, e.g., something like int(bool) + def type(self) -> str: + args_types_str = ", ".join(a.type for a in self.arguments()) + return f"{cpp.returns_type(self.func.returns, symint=self.symint).cpp_type()} ({args_types_str})" + + +# Represents group of all CppSignatures associated with a +# FunctionSchema. Right now, that's the regular, user-visible +# signature, as well as a "faithful" signature which doesn't +# have grouping. 
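A usage sketch (editorial, not part of the upstream file): enumerating the C++ overloads that one schema fans out into. It assumes torchgen is importable and that f is a torchgen.model.NativeFunction parsed elsewhere, for example from native_functions.yaml.

from torchgen.api.types import CppSignatureGroup

sig_group = CppSignatureGroup.from_native_function(f, method=False)
for sig in sig_group.signatures():
    # Each overload renders its own declaration; the faithful variant keeps
    # scattered dtype/layout/device arguments instead of a TensorOptions pack.
    print(sig.decl())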
+@dataclass(frozen=True) +class CppSignatureGroup: + func: FunctionSchema + signature: CppSignature + faithful_signature: CppSignature | None + symint_signature: CppSignature | None + symint_faithful_signature: CppSignature | None + + def most_faithful_signature(self) -> CppSignature: + if self.faithful_signature: + return self.faithful_signature + else: + return self.signature + + def signatures(self, *, symint: bool = True) -> Iterator[CppSignature]: + yield self.signature + if self.faithful_signature: + yield self.faithful_signature + if symint: + if self.symint_signature: + yield self.symint_signature + if self.symint_faithful_signature: + yield self.symint_faithful_signature + + @staticmethod + def from_native_function( + f: NativeFunction, *, method: bool, fallback_binding: bool = False + ) -> CppSignatureGroup: + func = f.func + + def make_sig(*, faithful: bool, symint: bool) -> CppSignature: + return CppSignature( + func=func, + faithful=faithful, + symint=symint, + method=method, + fallback_binding=fallback_binding, + cpp_no_default_args=f.cpp_no_default_args, + ) + + def make_sigs(*, symint: bool) -> tuple[CppSignature, CppSignature | None]: + faithful_signature: CppSignature | None = None + if func.arguments.tensor_options is not None or len(func.arguments.out) > 0: + faithful_signature = make_sig(faithful=True, symint=symint) + signature = make_sig(faithful=False, symint=symint) + return signature, faithful_signature + + signature, faithful_signature = make_sigs(symint=False) + symint_signature: CppSignature | None = None + symint_faithful_signature: CppSignature | None = None + if func.has_symint(): + symint_signature, symint_faithful_signature = make_sigs(symint=True) + + return CppSignatureGroup( + func=func, + signature=signature, + faithful_signature=faithful_signature, + symint_signature=symint_signature, + symint_faithful_signature=symint_faithful_signature, + ) + + +@dataclass(frozen=True) +class DispatcherSignature: + # The schema this signature is derived from + func: FunctionSchema + + # Allows you to prepend an arbitrary prefix to the signature name. + # This is useful for parts of the codegen that generate wrappers around kernels, + # and need to avoid naming collisions. 
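# Editorial illustration (not part of the upstream file; assumes torchgen is
# importable and that FunctionSchema.parse behaves as in upstream torchgen):
#
#   schema = FunctionSchema.parse("sin(Tensor self) -> Tensor")
#   sig = DispatcherSignature.from_schema(schema, prefix="wrapper_")
#   sig.name()  # -> "wrapper_sin"
#   sig.defn()  # -> "at::Tensor wrapper_sin(const at::Tensor & self)"
#
# The prefix keeps the generated wrapper symbol distinct from the kernel it wraps.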
+ prefix: str = "" + + symint: bool = True + + def arguments(self) -> list[Binding]: + return dispatcher.arguments(self.func, symint=self.symint) + + def name(self) -> str: + return self.prefix + dispatcher.name(self.func) + + def decl(self, name: str | None = None) -> str: + args_str = ", ".join(a.decl() for a in self.arguments()) + if name is None: + name = self.name() + return f"{self.returns_type().cpp_type()} {name}({args_str})" + + def defn( + self, name: str | None = None, *, is_redispatching_fn: bool = False + ) -> str: + args = [a.defn() for a in self.arguments()] + if is_redispatching_fn: + args = ["c10::DispatchKeySet dispatchKeySet"] + args + args_str = ", ".join(args) + if name is None: + name = self.name() + return f"{self.returns_type().cpp_type()} {name}({args_str})" + + def exprs(self) -> list[Expr]: + return [Expr(a.name, a.nctype) for a in self.arguments()] + + def returns_type(self) -> CType: + return dispatcher.returns_type(self.func.returns, symint=self.symint) + + def ptr_type(self) -> str: + dispatcher_args_types_str = ", ".join(a.type for a in self.arguments()) + return f"{self.returns_type().cpp_type()} (*)({dispatcher_args_types_str})" + + # Return the C++ function type, e.g., something like int(bool) + def type(self) -> str: + dispatcher_args_types_str = ", ".join(a.type for a in self.arguments()) + return f"{self.returns_type().cpp_type()} ({dispatcher_args_types_str})" + + @staticmethod + def from_schema( + func: FunctionSchema, *, prefix: str = "", symint: bool = True + ) -> DispatcherSignature: + return DispatcherSignature(func, prefix, symint) + + +@dataclass(frozen=True) +class NativeSignature: + # The schema this signature is derived from + func: FunctionSchema + + symint: bool + + prefix: str = "" + + def name(self) -> str: + return self.prefix + native.name(self.func) + + def decl(self, name: str | None = None) -> str: + args_str = ", ".join(a.decl() for a in self.arguments()) + if name is None: + name = self.name() + return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})" + + def defn(self, name: str | None = None) -> str: + args_str = ", ".join(a.defn() for a in self.arguments()) + if name is None: + name = self.name() + return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})" + + def ptr_type(self) -> str: + # don't include defaults in type signature! 
+ args_str = ", ".join(a.defn() for a in self.arguments()) + return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_str})" + + def arguments(self) -> list[Binding]: + return native.arguments(self.func, symint=self.symint) + + def returns_type(self) -> CType: + return native.returns_type(self.func.returns, symint=self.symint) + + def dispatcher_exprs(self) -> list[Expr]: + return translate.translate( + self.arguments(), dispatcher.arguments(self.func), method=False + ) + + +@dataclass(frozen=True) +class ViewInverseSignature: + g: NativeFunctionsViewGroup + + def name(self) -> str: + return functionalization.reverse_name(self.g.view, include_namespace=False) + + def decl(self) -> str: + return_type = functionalization.returns_type(self.g.view.func) + decls = [ + a.decl() + for a in functionalization.inner_arguments( + self.g.view.func, is_reverse=True + ) + ] + return f"static {return_type.cpp_type()} {self.name()}({', '.join(decls)});" + + +@dataclass(frozen=True) +class FunctionalizationLambda: + g: NativeFunctionsViewGroup + + # are we generating the forward lambda or the reverse lambda? + is_reverse: bool + + def captures(self) -> list[Expr]: + # The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments + # We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed, + # and plumb it into the lambda. + outer_ctx = dispatcher.arguments(self.g.view.func) + [ + functionalization.reapply_views_binding, + functionalization.inverse_return_mode_binding, + ] + capture_bindings = functionalization.capture_arguments( + self.g.view.func, is_reverse=self.is_reverse + ) + # allow_expensive_conversions is set because we want to convert + # some reference types (IntArrayRef) to value types (vector). 
+ capture_exprs = translate.translate( + outer_ctx, capture_bindings, method=False, allow_expensive_conversions=True + ) + return capture_exprs + + def decl(self) -> str: + return_type = functionalization.returns_type(self.g.view.func) + capture_str = ", ".join( + f"{val.type.name} = {val.expr}" for val in self.captures() + ) + decls = [ + a.decl() + for a in functionalization.outer_arguments(is_reverse=self.is_reverse) + ] + return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}" + + def inner_call(self, *, reapply_views: bool | None = None) -> str: + inner_call_name = functionalization.name( + self.g, + is_reverse=self.is_reverse, + include_namespace=True, + reapply_views=reapply_views, + ) + + arg_ctx = functionalization.outer_arguments(is_reverse=self.is_reverse) + capture_ctx = functionalization.capture_arguments( + self.g.view.func, is_reverse=self.is_reverse + ) + full_ctx = arg_ctx + capture_ctx + + assert self.g.view_copy is not None + call_bindings = functionalization.inner_arguments( + self.g.view_copy.func, is_reverse=self.is_reverse + ) + maybe_index = functionalization.inner_call_index(self.g.view_copy.func) + call_exprs = [ + e.expr for e in translate.translate(full_ctx, call_bindings, method=False) + ] + if not self.is_reverse and maybe_index is not None: + return f'{inner_call_name}({", ".join(call_exprs)})[{maybe_index.name}];' + else: + return f'{inner_call_name}({", ".join(call_exprs)});' + + @staticmethod + def from_func( + g: NativeFunctionsViewGroup, *, is_reverse: bool + ) -> FunctionalizationLambda: + return FunctionalizationLambda(g, is_reverse) + + +@dataclass(frozen=True) +class StructuredImplSignature: + g: NativeFunctionsGroup + name: str + + def defn(self, name: str | None = None) -> str: + args_str = ", ".join(a.defn() for a in self.arguments()) + return f"TORCH_IMPL_FUNC({self.name})({args_str})" + + def arguments(self) -> list[Binding]: + return structured.impl_arguments(self.g) + + +# Helper functions + + +def kernel_signature( + f: NativeFunction, backend_index: BackendIndex, *, prefix: str = "" +) -> NativeSignature | DispatcherSignature: + # Note [External Backends Follow Dispatcher API] + # Kernel signatures for in-tree backends follow the "native" API, + # while kernels for out-of-tree backends follow the dispatcher API. + # See the comments in `native.py` for details, but historically there have been + # some small differences in schema convention between them and the Dispatcher API. + # Any differences that require translating between the two will results in a runtime cost, + # so we'd like to keep the differences as small as possible. + # With external backends, we'd like to enforce that they write their kernels with schemas + # that match the Dispatcher API directly, if they can. 
+ meta = backend_index.get_kernel(f) + symint = meta is not None and meta.supports_symint() + if symint: + assert ( + f.func.has_symint() + ), f"attempted to define symint kernel for {backend_index.dispatch_key} without SymInt in schema" + if backend_index.external: + return DispatcherSignature.from_schema(f.func, prefix=prefix, symint=symint) + else: + return NativeSignature(f.func, prefix=prefix, symint=symint) + + +# Functions only, no types +from torchgen.api import ( + cpp, + dispatcher, + functionalization, + native, + structured, + translate, +) diff --git a/lib/python3.10/site-packages/torchgen/api/types/types.py b/lib/python3.10/site-packages/torchgen/api/types/types.py new file mode 100644 index 0000000000000000000000000000000000000000..30e027a631200029e01f337b96c77013193bfd4f --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/api/types/types.py @@ -0,0 +1,191 @@ +""" +Where should I add a new type? `types_base.py` vs `types.py` + +This file defines data model classes for torchgen typing system, as well as some base types such as int32_t. + +`types.py` defines ATen Tensor type and some c10 types, along with signatures that use these types. + +The difference between these two files, is `types_base.py` should be implementation-agnostic, meaning it shouldn't +contain any type definition that is tight to a specific C++ library (e.g., ATen), so that it can be easily reused +if we want to generate code for another C++ library. + +Add new types to `types.py` if these types are ATen/c10 related. +Add new types to `types_base.py` if they are basic and not attached to ATen/c10. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +from torchgen.api.types.types_base import ( + BaseCppType, + BaseCType, + boolT, + byteT, + charT, + CType, + doubleT, + floatT, + int32T, + longT, + shortT, +) +from torchgen.model import BaseTy, ScalarType + + +TENSOR_LIST_LIKE_CTYPES = [ + "at::TensorList", + "const c10::List<::std::optional> &", + "const at::ITensorListRef &", +] + + +halfT = BaseCppType("at", "Half") +complexHalfT = BaseCppType( + "c10", "complex" +) # stuffing template param here is an abuse +complexFloatT = BaseCppType("c10", "complex") +complexDoubleT = BaseCppType("c10", "complex") +bfloat16T = BaseCppType("at", "BFloat16") +float8_e5m2T = BaseCppType("at", "Float8_e5m2") +float8_e5m2fnuzT = BaseCppType("at", "Float8_e5m2fnuz") +float8_e4m3fnT = BaseCppType("at", "Float8_e4m3fn") +float8_e4m3fnuzT = BaseCppType("at", "Float8_e4m3fnuz") +stringT = BaseCppType("c10", "string_view") +generatorT = BaseCppType("at", "Generator") +scalarTypeT = BaseCppType("at", "ScalarType") +tensorT = BaseCppType("at", "Tensor") +optionalTensorRefT = BaseCppType("at", "OptionalTensorRef") +tensorListT = BaseCppType("at", "TensorList") +iTensorListRefT = BaseCppType("at", "ITensorListRef") +iOptTensorListRefT = BaseCppType("at", "IOptTensorListRef") +dimnameT = BaseCppType("at", "Dimname") +dimnameListT = BaseCppType("at", "DimnameList") +dimVectorT = BaseCppType("at", "DimVector") +layoutT = BaseCppType("at", "Layout") +deviceT = BaseCppType("at", "Device") +deviceIndexT = BaseCppType("at", "DeviceIndex") +scalarT = BaseCppType("at", "Scalar") +optionalScalarRefT = BaseCppType("at", "OptionalScalarRef") +memoryFormatT = BaseCppType("at", "MemoryFormat") +qschemeT = BaseCppType("at", "QScheme") +storageT = BaseCppType("at", "Storage") +streamT = BaseCppType("at", "Stream") +intArrayRefT = BaseCppType("at", "IntArrayRef") +optionalIntArrayRefT = BaseCppType("at", 
"OptionalIntArrayRef") +optionalSymIntArrayRefT = BaseCppType("at", "OptionalSymIntArrayRef") +tensorOptionsT = BaseCppType("at", "TensorOptions") +typeAndSizeT = BaseCppType("torch::autograd::generated", "TypeAndSize") +tensorGeometryT = BaseCppType("at", "TensorGeometry") +SymIntT = BaseCppType("c10", "SymInt") +symIntArrayRefT = BaseCppType("c10", "SymIntArrayRef") + +# Types representing template parameters. Technically, we probably shouldn't +# represent them this way in codegen, but it was pretty convenient. +scalar_t = BaseCppType("", "scalar_t") +opmath_t = BaseCppType("", "opmath_t") + +ScalarTypeToCppMapping: dict[ScalarType, BaseCppType] = { + ScalarType.Byte: byteT, + ScalarType.Char: charT, + ScalarType.Short: shortT, + ScalarType.Int: int32T, + ScalarType.Long: longT, + ScalarType.Half: halfT, + ScalarType.Float: floatT, + ScalarType.Double: doubleT, + ScalarType.ComplexHalf: complexHalfT, + ScalarType.ComplexFloat: complexFloatT, + ScalarType.ComplexDouble: complexDoubleT, + ScalarType.Bool: boolT, + ScalarType.Float8_e5m2: float8_e5m2T, + ScalarType.Float8_e5m2fnuz: float8_e5m2fnuzT, + ScalarType.Float8_e4m3fn: float8_e4m3fnT, + ScalarType.Float8_e4m3fnuz: float8_e4m3fnuzT, +} + +BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = { + BaseTy.int: longT, + BaseTy.float: doubleT, + BaseTy.bool: boolT, + BaseTy.str: stringT, + BaseTy.Generator: generatorT, + BaseTy.ScalarType: scalarTypeT, + BaseTy.Tensor: tensorT, + BaseTy.Dimname: dimnameT, + BaseTy.DimVector: dimVectorT, + BaseTy.Layout: layoutT, + BaseTy.Device: deviceT, + BaseTy.DeviceIndex: deviceIndexT, + BaseTy.Scalar: scalarT, + BaseTy.MemoryFormat: memoryFormatT, + BaseTy.QScheme: qschemeT, + BaseTy.Storage: storageT, + BaseTy.Stream: streamT, + BaseTy.SymInt: SymIntT, +} + +# CTypes encode C++ type structure as needed for translation. + + +@dataclass(frozen=True) +class OptionalCType(CType): + elem: CType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + # Do not pass `strip_ref` recursively. + return f"::std::optional<{self.elem.cpp_type()}>" + + def cpp_type_registration_declarations(self) -> str: + return f"::std::optional<{self.elem.cpp_type_registration_declarations()}>" + + def remove_const_ref(self) -> CType: + return OptionalCType(self.elem.remove_const_ref()) + + +@dataclass(frozen=True) +class ListCType(CType): + elem: CType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + # Do not pass `strip_ref` recursively. + return f"c10::List<{self.elem.cpp_type()}>" + + def cpp_type_registration_declarations(self) -> str: + return f"c10::List<{self.elem.cpp_type_registration_declarations()}>" + + def remove_const_ref(self) -> CType: + return ListCType(self.elem.remove_const_ref()) + + +@dataclass(frozen=True) +class ArrayRefCType(CType): + elem: CType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + # Do not pass `strip_ref` recursively. + return f"at::ArrayRef<{self.elem.cpp_type()}>" + + def cpp_type_registration_declarations(self) -> str: + return f"ArrayRef<{self.elem.cpp_type_registration_declarations()}>" + + def remove_const_ref(self) -> CType: + return ArrayRefCType(self.elem.remove_const_ref()) + + +@dataclass(frozen=True) +class VectorizedCType(CType): + # This template is explicitly specialized, so the only valid + # elems are those we have specializations for (e.g., float, double, ...) 
+ # scalar_t is also a common argument here (when we are codegen in + # a templated context) + elem: BaseCType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + return f"at::vec::Vectorized<{self.elem.cpp_type()}>" + + def cpp_type_registration_declarations(self) -> str: + raise NotImplementedError + + def remove_const_ref(self) -> CType: + return self diff --git a/lib/python3.10/site-packages/torchgen/api/types/types_base.py b/lib/python3.10/site-packages/torchgen/api/types/types_base.py new file mode 100644 index 0000000000000000000000000000000000000000..e031b79485e057769302149369500cdb3df4c1e2 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/api/types/types_base.py @@ -0,0 +1,276 @@ +""" +Where should I add a new type? `types_base.py` vs `types.py` + +This file defines data model classes for torchgen typing system, as well as some base types such as int32_t. + +`types.py` defines ATen Tensor type and some c10 types, along with signatures that use these types. + +The difference between these two files, is `types_base.py` should be implementation-agnostic, meaning it shouldn't +contain any type definition that is tight to a specific C++ library (e.g., ATen), so that it can be easily reused +if we want to generate code for another C++ library. + +Add new types to `types.py` if these types are ATen/c10 related. +Add new types to `types_base.py` if they are basic and not attached to ATen/c10. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import auto, Enum +from typing import TYPE_CHECKING, Union + + +if TYPE_CHECKING: + from torchgen.model import Argument, SelfArgument, TensorOptionsArguments + + +# An ArgName is just the str name of the argument in schema; +# but in some special circumstances, we may add a little extra +# context. The Enum SpecialArgName covers all of these cases; +# grep for their construction sites to see when they can occur. + + +class SpecialArgName(Enum): + possibly_redundant_memory_format = auto() + + +ArgName = Union[str, SpecialArgName] + + +# This class shouldn't be created directly; instead, use/create one of the singletons below. +@dataclass(frozen=True) +class BaseCppType: + ns: str | None + name: str + + def __str__(self) -> str: + if self.ns is None or self.ns == "": + return self.name + return f"{self.ns}::{self.name}" + + +# The set of all non-templated, valid, fully-qualified names of C++ types that are used in the codegen. +# Templated types get their own dataclass, mainly to make namespace parsing easier. 
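A small sketch (editorial, not part of the upstream file; assumes torchgen is importable) of how these CType wrappers compose structurally and render to C++ type strings:

from torchgen.api.types import (
    BaseCType,
    ConstRefCType,
    ListCType,
    NamedCType,
    OptionalCType,
    tensorT,
)

ct = ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
assert ct.cpp_type() == "const c10::List<::std::optional<at::Tensor>> &"
assert ct.remove_const_ref().cpp_type() == "c10::List<::std::optional<at::Tensor>>"

# A NamedCType pairs the structural type with the semantic argument name.
nct = NamedCType("tensors", ct)
assert nct.cpp_type() == ct.cpp_type()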
+byteT = BaseCppType("", "uint8_t") +charT = BaseCppType("", "int8_t") +shortT = BaseCppType("", "int16_t") +# It would be more symmetric for this to be called intT, but it easy to mix +# this up with JIT int (which is int64_t in C++), so we intentionally don't +# define intT to make it obvious when you've stuffed it up +int32T = BaseCppType("", "int32_t") +longT = BaseCppType("", "int64_t") +doubleT = BaseCppType("", "double") +floatT = BaseCppType("", "float") +boolT = BaseCppType("", "bool") +voidT = BaseCppType("", "void") + + +class CType(ABC): + @abstractmethod + def cpp_type(self, *, strip_ref: bool = False) -> str: + raise NotImplementedError + + @abstractmethod + def cpp_type_registration_declarations(self) -> str: + raise NotImplementedError + + @abstractmethod + def remove_const_ref(self) -> CType: + return self + + +@dataclass(frozen=True) +class BaseCType(CType): + type: BaseCppType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + return str(self.type) + + # For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml + # TODO: Kill this when we eventually remove it! + def cpp_type_registration_declarations(self) -> str: + return str(self.type).replace("at::", "") + + def remove_const_ref(self) -> CType: + return self + + +@dataclass(frozen=True) +class ConstRefCType(CType): + elem: CType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + if strip_ref: + return self.elem.cpp_type(strip_ref=strip_ref) + return f"const {self.elem.cpp_type()} &" + + def cpp_type_registration_declarations(self) -> str: + return f"const {self.elem.cpp_type_registration_declarations()} &" + + def remove_const_ref(self) -> CType: + return self.elem.remove_const_ref() + + +@dataclass(frozen=True) +class VectorCType(CType): + elem: CType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + # Do not pass `strip_ref` recursively. + return f"::std::vector<{self.elem.cpp_type()}>" + + def cpp_type_registration_declarations(self) -> str: + return f"::std::vector<{self.elem.cpp_type_registration_declarations()}>" + + def remove_const_ref(self) -> CType: + return VectorCType(self.elem.remove_const_ref()) + + +@dataclass(frozen=True) +class ArrayCType(CType): + elem: CType + size: int + + def cpp_type(self, *, strip_ref: bool = False) -> str: + # Do not pass `strip_ref` recursively. + return f"::std::array<{self.elem.cpp_type()},{self.size}>" + + def cpp_type_registration_declarations(self) -> str: + return f"::std::array<{self.elem.cpp_type_registration_declarations()},{self.size}>" + + def remove_const_ref(self) -> CType: + return ArrayCType(self.elem.remove_const_ref(), self.size) + + +@dataclass(frozen=True) +class TupleCType(CType): + elems: list[CType] + + def cpp_type(self, *, strip_ref: bool = False) -> str: + # Do not pass `strip_ref` recursively. 
+ return f'::std::tuple<{",".join([e.cpp_type() for e in self.elems])}>' + + def cpp_type_registration_declarations(self) -> str: + return f'::std::tuple<{",".join([e.cpp_type_registration_declarations() for e in self.elems])}>' + + def remove_const_ref(self) -> CType: + return TupleCType([e.remove_const_ref() for e in self.elems]) + + +@dataclass(frozen=True) +class MutRefCType(CType): + elem: CType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + if strip_ref: + return self.elem.cpp_type(strip_ref=strip_ref) + return f"{self.elem.cpp_type()} &" + + def cpp_type_registration_declarations(self) -> str: + return f"{self.elem.cpp_type_registration_declarations()} &" + + def remove_const_ref(self) -> CType: + return self.elem.remove_const_ref() + + +# A NamedCType is short for Named C++ semantic type. A NamedCType represents a C++ type, plus +# semantic information about what it represents. For example, consider the +# argument "bool pin_memory"; its normal C++ type is "bool", but its C++ +# semantic type also keeps track that this represents a "pin_memory"; you can't +# just use a random other boolean in a context where you need a "pin_memory"! +# + + +@dataclass(frozen=True) +class NamedCType: + name: ArgName + type: CType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + return self.type.cpp_type(strip_ref=strip_ref) + + # For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml + # TODO: Kill this when we eventually remove it! + def cpp_type_registration_declarations(self) -> str: + return self.type.cpp_type_registration_declarations() + + def remove_const_ref(self) -> NamedCType: + return NamedCType(self.name, self.type.remove_const_ref()) + + def with_name(self, name: str) -> NamedCType: + return NamedCType(name, self.type) + + +# A binding represents any C++ binding site for a formal parameter. +# We don't distinguish between binding sites for different APIs; +# instead, all of the important distinctions are encoded in CType, +# which you can use to figure out if a given Binding is appropriate +# for use in another context. (See torchgen.api.translate) + + +@dataclass(frozen=True) +class Binding: + name: str + nctype: NamedCType + argument: Argument | TensorOptionsArguments | SelfArgument + # TODO: maybe don't represent default here + default: str | None = None + + def rename(self, name: str) -> Binding: + return Binding( + name=name, + nctype=self.nctype, + argument=self.argument, + default=self.default, + ) + + @property + def type(self) -> str: + return self.nctype.cpp_type() + + def no_default(self) -> Binding: + return Binding( + name=self.name, + nctype=self.nctype, + default=None, + argument=self.argument, + ) + + def decl(self, *, func_ptr_cast: bool = False) -> str: + mb_default = "" + if self.default is not None: + mb_default = f"={self.default}" + + # casting only needs to know the type + if func_ptr_cast: + return f"{self.type}" + else: + return f"{self.type} {self.name}{mb_default}" + + # For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml + # TODO: Kill this when we eventually remove it! 
+ def decl_registration_declarations(self) -> str: + type_s = self.nctype.cpp_type_registration_declarations() + mb_default = "" + if self.default is not None: + mb_default = f"={self.default}" + return f"{type_s} {self.name}{mb_default}" + + def defn(self) -> str: + return f"{self.type} {self.name}" + + def with_name(self, name: str) -> Binding: + return Binding( + name=name, nctype=self.nctype, argument=self.argument, default=self.default + ) + + +# An Expr is a C++ expression. It has a C++ string representing its syntax, +# as well as a CType saying what it provides. + + +@dataclass(frozen=True) +class Expr: + expr: str + type: NamedCType diff --git a/lib/python3.10/site-packages/torchgen/api/ufunc.py b/lib/python3.10/site-packages/torchgen/api/ufunc.py new file mode 100644 index 0000000000000000000000000000000000000000..17adcccecab563b6a4003215c778a00d5e1399c4 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/api/ufunc.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import torchgen.api.types as api_types +from torchgen.api import cpp, structured +from torchgen.api.types import ( + ArgName, + BaseCppType, + BaseCType, + Binding, + ConstRefCType, + CType, + NamedCType, + scalarT, +) +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + DispatchKey, + FunctionSchema, + NativeFunctionsGroup, + Type, +) + + +def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str: + assert func.is_out_fn(), "ufunc.kernel_name should only be invoked on out schemas" + return f"ufunc_{func.name.name}_{dispatch_key}" + + +def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str: + return schema_kernel_name(g.out.func, dispatch_key) + + +# Tensors are omitted (as they are stored in TensorIterator), everything else is +# passed along (technically, we can pass tensors along too, it just wastes +# argument registers) +# +# NB: used for CPU only +def dispatchstub_type(t: Type, *, binds: ArgName) -> NamedCType | None: + # Dispatch stubs are always plain ints + r = cpp.valuetype_type(t, binds=binds, symint=False) + if r is not None: + return r + + if t == BaseType(BaseTy.Scalar): + return NamedCType(binds, ConstRefCType(BaseCType(scalarT))) + elif t == BaseType(BaseTy.Tensor): + return None + else: + raise AssertionError(f"unrecognized type {repr(t)}") + + +def opmath_type(scalar_t: BaseCppType) -> BaseCppType: + if scalar_t == api_types.scalar_t: + return api_types.opmath_t + raise NotImplementedError + + +# NB: Tensors in constructor are stored in opmath_t, not scalar_t +# because Tensor in constructor = its a scalar tensor partially applied = +# it can be higher precision and we want to compute in that higher precision +# +# NB: CUDA only +def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType: + r = cpp.valuetype_type(t, binds=binds, symint=False) + if r is not None: + return r + + if t == BaseType(BaseTy.Scalar): + return NamedCType(binds, BaseCType(opmath_type(scalar_t))) + elif t == BaseType(BaseTy.Tensor): + return NamedCType(binds, BaseCType(opmath_type(scalar_t))) + else: + raise AssertionError(f"unrecognized type {repr(t)}") + + +# Only Tensors ever get passed directly to operator() +# +# NB: CUDA only +# (Actually, this works for CPU too) +def ufunctor_apply_type( + t: Type, *, binds: ArgName, scalar_t: BaseCppType +) -> NamedCType: + if t == BaseType(BaseTy.Tensor): + return NamedCType(binds, BaseCType(scalar_t)) + else: + raise AssertionError(f"unrecognized type 
{repr(t)}") + + +# The actual ufunc template function the user writes. Everything here +# is done in the computation type. compute_t is opmath_t in CUDA and scalar_t +# in CPU +def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType: + r = cpp.valuetype_type(t, binds=binds, symint=False) + if r is not None: + return r + + if t == BaseType(BaseTy.Scalar): + return NamedCType(binds, compute_t) + elif t == BaseType(BaseTy.Tensor): + return NamedCType(binds, compute_t) + else: + raise AssertionError(f"unrecognized type {repr(t)}") + + +def ufunctor_ctor_argument(a: Argument, scalar_t: BaseCppType) -> Binding: + return Binding( + nctype=ufunctor_ctor_type(a.type, binds=a.name, scalar_t=scalar_t), + name=a.name, + default=None, + argument=a, + ) + + +def ufunctor_apply_argument(a: Argument, scalar_t: BaseCppType) -> Binding: + return Binding( + nctype=ufunctor_apply_type(a.type, binds=a.name, scalar_t=scalar_t), + name=a.name, + default=None, + argument=a, + ) + + +def ufunc_argument(a: Argument, compute_t: CType) -> Binding: + return Binding( + nctype=ufunc_type(a.type, binds=a.name, compute_t=compute_t), + name=a.name, + default=None, + argument=a, + ) + + +@dataclass(frozen=True) +class UfunctorBindings: + ctor: list[Binding] + apply: list[Binding] + + +# ufunctors are a CUDA-only concept representing functors that take some of +# their arguments on a host-side constructor, and the rest in the device-side +# apply. E.g., +# +# template +# struct CUDAFunctorOnSelf_add { +# using opmath_t = at::opmath_type; +# opmath_t other_; +# opmath_t alpha_; +# CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) : other_(other), alpha_(alpha) {} +# __device__ scalar_t operator()(scalar_t self) { +# return ufunc::add(static_cast(self), other_, alpha_); +# } +# }; +# +# The ctor refers to the constructor CUDAFunctorOnSelf_add, while apply refers +# to the operator() definition +def ufunctor_arguments( + g: NativeFunctionsGroup, *, scalar_tensor_idx: int | None, scalar_t: BaseCppType +) -> UfunctorBindings: + ctor = [] + apply = [] + for a in g.functional.func.arguments.flat_non_out: + if a.type.is_tensor_like(): + if scalar_tensor_idx == 0: + # put it in the ctor anyway + ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t)) + scalar_tensor_idx = None + else: + if scalar_tensor_idx is not None: + scalar_tensor_idx -= 1 + apply.append(ufunctor_apply_argument(a, scalar_t=scalar_t)) + else: + ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t)) + assert scalar_tensor_idx is None + return UfunctorBindings(ctor=ctor, apply=apply) + + +# ufuncs are the inner loop template functions that you wrote in ufunc/add.h +# which do the actual computation in question. E.g., +# +# template +# C10_HOST_DEVICE T add(T self, T other, T alpha) __ubsan_ignore_undefined__ { +# return self + alpha * other; +# } +# +# In this file, we refer to T as compute_t which is bound by caller +def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> list[Binding]: + return [ + ufunc_argument(a, compute_t=compute_t) + for a in g.functional.func.arguments.flat_non_out + ] + + +# Stubs are the DispatchStub trampolines that CPU kernels use to get to their +# vectorized versions. 
E.g., +# +# using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha); +# DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub); +def stub_arguments(g: NativeFunctionsGroup) -> list[Binding]: + # stubs drop all tensor arguments (they are implicit in the TensorIterator + # argument and keep everything else) + return [ + r + for a in g.out.func.arguments.flat_non_out + if not a.type.is_tensor_like() + for r in structured.argument(a) + ] diff --git a/lib/python3.10/site-packages/torchgen/api/unboxing.py b/lib/python3.10/site-packages/torchgen/api/unboxing.py new file mode 100644 index 0000000000000000000000000000000000000000..1e649b7517889d284bf13fe8d0bd737e4e81f5f5 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/api/unboxing.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +from torchgen.api import cpp +from torchgen.api.types import Binding, CppSignatureGroup, CType +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + ListType, + NativeFunction, + OptionalType, + Type, +) + + +# This file generates the code for unboxing wrappers, i.e., the glue logic to unbox a boxed operator and convert the +# ivalues from stack to correct arguments to the unboxed kernel, based on corresponding JIT schema. This codegen is +# an alternative way to generate unboxing wrappers similar to the existing C++ metaprogramming approach but gets the +# job done statically. These generated unboxing wrappers will be useful under the scenario where we need to register +# a fixed set of operators known at compile time and thus can save some time in runtime initialization phase. +# +# Here's an example on how the codegen works: +# +# - Function Schema (source of truth) +# +# aten::empty.names(int[] size, *, Dimname[]? names, +# ScalarType? dtype=None, Layout? layout=None, +# Device? device=None, bool? pin_memory=None, +# MemoryFormat? memory_format=None) -> Tensor +# - Argument Conversion +# Generates C++ code to convert an ivalue (from stack) to its underlying C++ type. +# - int[] size +# ```cpp +# const c10::List size_list_in = (std::move(peek(stack, 0, 7))).toList(); +# +# std::vector size_vec; +# for (c10::IValue size_elem: size_list_in) { +# int64_t size_base = size_elem.to(); +# size_vec.push_back(size_base); +# } +# at::ArrayRef size_list_out(size_vec); +# ~~~~~~~~~~~~~ <-- The converted argument from ivalues in the stack. +# Will be passed to unboxed kernel. +# ``` +# - Dimname[]? names +# ```cpp +# ::std::optional names_opt = (std::move(peek(stack, 1, 7))).toOptional(); +# ::std::optional> names_opt_out; +# if (names_opt.has_value()) { +# ~~~~~~~~~~~ <-- Unwrapping optional shell +# const c10::IValue names_opt_in = names_opt.value(); +# const c10::List names_list_in = names_opt_in.toList(); +# +# std::vector names_vec; +# for (c10::IValue names_elem: names_list_in) { +# ~~~~~~~~~~~~~~~~~~~~~~~~~ <-- Unrolling list, then convert elements one by one. +# at::Dimname names_base = names_elem.to(); +# names_vec.push_back(names_base); +# } +# at::ArrayRef names_list_out(names_vec); +# +# names_opt_out = ::std::optional>(names_list_out); +# } else { +# names_opt_out = ::std::optional>(); +# } +# ``` +# - ScalarType? 
dtype (similarly for the rest of the arguments) +# ```cpp +# ::std::optional dtype_opt = (std::move(peek(stack, 2, 7))).toOptional(); +# ::std::optional dtype_opt_out; +# if (dtype_opt.has_value()) { +# const c10::IValue dtype_opt_in = dtype_opt.value(); +# at::ScalarType dtype_base = dtype_opt_in.to(); +# ~~~~~~~~~~~~~~~~~~~~ <-- For base types, convert ivalue to it +# directly using ".to()" API. +# dtype_opt_out = ::std::optional(dtype_base); +# } else { +# dtype_opt_out = ::std::optional(); +# } +# ``` +# +# - Unboxed Kernel Call +# ```cpp +# auto result_ = torch::empty( +# size_list_out, +# names_opt_out, +# options, +# memory_format_opt_out +# ); +# ``` +# +# - Push Result Back to Stack +# ```cpp +# drop(stack, 7); +# pack(stack, std::move(result_)); +# ``` +connector = "\n\t" + + +# Return unboxing function name for a NativeFunction +def name(f: NativeFunction) -> str: + return f.func.name.unambiguous_name() + + +# Convert all the arguments in a NativeFunction to C++ code +def convert_arguments(f: NativeFunction) -> tuple[list[Binding], list[str]]: + # we need the 'self' argument so method needs to be False + args = ( + CppSignatureGroup.from_native_function(f, method=False) + .most_faithful_signature() + .arguments() + ) + code_list = [ + f"c10::IValue {args[i].name} = std::move(peek(stack, {i}, {len(args)}));" + for i in range(len(args)) + ] + [""] + binding_list = [] + for arg in args: + # expecting only Argument + if not isinstance(arg.argument, Argument): + raise Exception( # noqa: TRY002 + f"Unexpected argument type, expecting `Argument` but got {arg}" + ) + argument: Argument = arg.argument + unboxed_name, _, code, decl = argumenttype_ivalue_convert( + argument.type, + argument.name, + mutable=argument.is_write, + ) + code_list.extend(decl) + code_list.extend(code) + binding_list.append(arg.with_name(unboxed_name)) + return binding_list, code_list + + +# Takes in the type, name and mutability corresponding to an argument, and generates a tuple of: +# (1) the C++ code necessary to unbox the argument +# (2) A Binding corresponding to the newly created unboxed variable, including variable name and its CType +def argumenttype_ivalue_convert( + t: Type, arg_name: str, *, mutable: bool = False +) -> tuple[str, CType, list[str], list[str]]: + # Unboxing is for mobile, which doesn't care about SymInts + ctype = cpp.argumenttype_type( + t=t, mutable=mutable, binds=arg_name, symint=False + ).type + + if isinstance(t, BaseType): + out_name = f"{arg_name}_base" + code, decl = _gen_code_base_type( + arg_name=arg_name, out_name=out_name, ctype=ctype + ) + elif isinstance(t, OptionalType): + out_name = f"{arg_name}_opt_out" + code, decl = _gen_code_optional_type( + arg_name=arg_name, + out_name=out_name, + t=t, + ctype=ctype, + ) + elif isinstance(t, ListType): + out_name = f"{arg_name}_list_out" + code, decl = _gen_code_list_type( + arg_name=arg_name, + out_name=out_name, + t=t, + ctype=ctype, + ) + else: + raise Exception(f"Cannot handle type {t}. 
arg_name: {arg_name}") # noqa: TRY002 + return out_name, ctype, code, decl + + +def _gen_code_base_type( + arg_name: str, out_name: str, ctype: CType +) -> tuple[list[str], list[str]]: + return [ + f"{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();" + ], [] + + +def _gen_code_optional_type( + arg_name: str, out_name: str, t: OptionalType, ctype: CType +) -> tuple[list[str], list[str]]: + in_name = f"{arg_name}_opt_in" + res_name, _, res_code, decl = argumenttype_ivalue_convert(t.elem, in_name) + return ( + f""" +auto {arg_name}_opt = {arg_name}.toOptional(); +{ctype.cpp_type(strip_ref=True)} {out_name}; +if ({arg_name}_opt.has_value()) {{ + const c10::IValue {in_name} = {arg_name}_opt.value(); + {connector.join(res_code)} + {out_name} = {ctype.cpp_type(strip_ref=True)}({res_name}); +}} else {{ + {out_name} = {ctype.cpp_type(strip_ref=True)}(); +}} + """.split( + "\n" + ), + decl, + ) + + +def _gen_code_list_type( + arg_name: str, out_name: str, t: ListType, ctype: CType +) -> tuple[list[str], list[str]]: + in_name = f"{arg_name}_list_in" + elem_name = f"{arg_name}_elem" + code = [f"const c10::List {in_name} = {arg_name}.toList();"] + res_name, res_ctype, res_code, decl = argumenttype_ivalue_convert(t.elem, elem_name) + # handle list type with size, e.g., bool[4] + if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool and t.size: + code.extend( + f""" +{ctype.cpp_type(strip_ref=True)} {out_name} = as_array<{res_ctype.cpp_type(strip_ref=True)}, {t.size}>({in_name}); + """.split( + "\n" + ) + ) + # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional> + elif isinstance(t.elem, OptionalType): + code.extend( + f""" +{ctype.cpp_type(strip_ref=True)} {out_name}; +for (c10::IValue {elem_name}: {in_name}) {{ + {connector.join(res_code)} + {out_name}.push_back({res_name}); +}} + """.split( + "\n" + ) + ) + else: + # use ArrayRef as default. 
+ vec_name = arg_name + "_vec" + # need to bring vector instantiation out of scope so that ArrayRef has valid data + decl.append(f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};") + code.extend( + f""" +for (c10::IValue {elem_name}: {in_name}) {{ + {connector.join(res_code)} + {vec_name}.push_back({res_name}); +}} +{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name}); + """.split( + "\n" + ) + ) + return code, decl diff --git a/lib/python3.10/site-packages/torchgen/dest/__init__.py b/lib/python3.10/site-packages/torchgen/dest/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f08a743ae2dc766530fd8f93be9ebb8b7733f21 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/dest/__init__.py @@ -0,0 +1,19 @@ +from torchgen.dest.lazy_ir import ( + generate_non_native_lazy_ir_nodes as generate_non_native_lazy_ir_nodes, + GenLazyIR as GenLazyIR, + GenLazyNativeFuncDefinition as GenLazyNativeFuncDefinition, + GenLazyShapeInferenceDefinition as GenLazyShapeInferenceDefinition, +) +from torchgen.dest.native_functions import ( + compute_native_function_declaration as compute_native_function_declaration, +) +from torchgen.dest.register_dispatch_key import ( + gen_registration_headers as gen_registration_headers, + gen_registration_helpers as gen_registration_helpers, + RegisterDispatchKey as RegisterDispatchKey, +) +from torchgen.dest.ufunc import ( + compute_ufunc_cpu as compute_ufunc_cpu, + compute_ufunc_cpu_kernel as compute_ufunc_cpu_kernel, + compute_ufunc_cuda as compute_ufunc_cuda, +) diff --git a/lib/python3.10/site-packages/torchgen/dest/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/dest/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d6e6ba822b67599543b2cfb01922fae4e4f04fe Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/dest/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/dest/__pycache__/lazy_ir.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/dest/__pycache__/lazy_ir.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9c2e1ac91945db06eaa210e1cce5535703853fe Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/dest/__pycache__/lazy_ir.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/dest/__pycache__/lazy_ts_lowering.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/dest/__pycache__/lazy_ts_lowering.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c1d481c5fea926a5fb5acbe3ae18248b0e1ffeb Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/dest/__pycache__/lazy_ts_lowering.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/dest/__pycache__/native_functions.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/dest/__pycache__/native_functions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..592dd14d12b28210a5ef38d64988a67d78992cd5 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/dest/__pycache__/native_functions.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/dest/__pycache__/register_dispatch_key.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/dest/__pycache__/register_dispatch_key.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe55e7f1bab1c4a5213716fdb848dcfa806f0f60 Binary files /dev/null and 
b/lib/python3.10/site-packages/torchgen/dest/__pycache__/register_dispatch_key.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/dest/__pycache__/ufunc.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/dest/__pycache__/ufunc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9a589124d3fe1f456e0c34fe2fb3ace9c078f91 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/dest/__pycache__/ufunc.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/dest/lazy_ir.py b/lib/python3.10/site-packages/torchgen/dest/lazy_ir.py new file mode 100644 index 0000000000000000000000000000000000000000..976c823a1653746a9c4f1289dca9b5d883097838 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/dest/lazy_ir.py @@ -0,0 +1,707 @@ +from __future__ import annotations + +import itertools +from abc import ABC +from dataclasses import dataclass +from typing import Any + +import torchgen.api.dispatcher as dispatcher +from torchgen.api.lazy import ( + getValueT, + isValueType, + LazyArgument, + LazyIrProperties, + LazyIrSchema, + tensorListValueT, +) +from torchgen.api.translate import translate +from torchgen.api.types import ( + BaseCType, + Binding, + deviceT, + DispatcherSignature, + kernel_signature, + NativeSignature, + OptionalCType, + VectorCType, +) +from torchgen.context import method_with_native_function +from torchgen.dest.lazy_ts_lowering import ts_lowering_body +from torchgen.model import ( + Argument, + BackendIndex, + BackendMetadata, + BaseTy, + BaseType, + FunctionSchema, + ListType, + NativeFunction, + NativeFunctionsGroup, +) + + +def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str: + """ + Given a LazyArgument, + generate a c++ string for materializing an rvalue of that arg for passing into + a lazy Node constructor. + """ + + # TODO: Matching on CType seems wrong; should be matching on Type + if isValueType(arg.lazy_type): + if isinstance(arg.lazy_type, BaseCType): + if arg.is_wrapped_scalar: + return f"node_{arg.name}" + elif arg.lazy_type.type is tensorListValueT: + return f"lazy_{arg.name}_tensorlist" + elif arg.is_symint_or_list: + return f"GetSymIntValue({arg.name})" + return f"lazy_{arg.name}->GetIrValue()" + elif isinstance(arg.lazy_type, OptionalCType): + if arg.is_symint_or_list: + # TODO: I don't understand when you should put lazy_ in the name + # or not + return f"{arg.name} ? std::make_optional(GetSymIntValue(*{arg.name})) : ::std::nullopt" + elif arg.is_wrapped_scalar: + return f"node_{arg.name}" + return ( + f"lazy_{arg.name} ? 
" + f"std::make_optional(lazy_{arg.name}->GetIrValue()) : " + "::std::nullopt" + ) + else: + raise AssertionError( + f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})" + ) + else: + # NB: this is here because right now we aren't treating SymInt[] as a + # value type; when we do this needs to move above + # NB: we cannot test arg.lazy_type as we've already specified it is an + # int64_t and so we cannot distinguish between SymInt and int64_t + if isinstance(arg.orig_type, ListType) and arg.orig_type.elem == BaseType( + BaseTy.SymInt + ): + if arg.symint: + return f"GetSymIntArrayRefValue({arg.name})" + else: + return f"std::vector({arg.name}.begin(), {arg.name}.end())" + elif isinstance(arg.lazy_type, VectorCType) and isinstance( + arg.lazy_type.elem, BaseCType + ): + return f"std::vector<{arg.lazy_type.elem.type}>({arg.name}.begin(), {arg.name}.end())" + elif ( + isinstance(arg.lazy_type, OptionalCType) + and isinstance(arg.lazy_type.elem, VectorCType) + and isinstance(arg.lazy_type.elem.elem, BaseCType) + ): + return f"torch::lazy::ToOptionalVector<{arg.lazy_type.elem.elem.type}>({arg.name})" + else: + return f"{arg.name}" + + +def node_ctor_inputs(schema: LazyIrSchema) -> str: + """ + Produce a formatted string with the arguments as passed into the constructor of a node class. + """ + node_ctor_values = [ + node_ctor_arg_rvalue_string(arg) for arg in schema.filtered_args() + ] + return ", ".join(node_ctor_values) + + +def gen_fallback_code( + schema: LazyIrSchema, + sig: DispatcherSignature | NativeSignature, + overload_name: str, +) -> str: + """ + Generate code that falls back to eager conditioned on a predicate + """ + dispatcher_sig = DispatcherSignature.from_schema(schema.func) + exprs = translate(sig.arguments(), dispatcher_sig.arguments()) + fallback_args = ",\n ".join([a.expr for a in exprs]) + if len(overload_name): + aten_op_str = f"ATEN_OP2({schema.aten_name}, {overload_name})" + else: + aten_op_str = f"ATEN_OP({schema.aten_name})" + return f""" + if (force_eager_fallback({aten_symbol(schema)})) {{ + return at::native::call_fallback_fn_symint<<c_eager_fallback, {aten_op_str}>::call( + {fallback_args} + ); + }} +""" + + +def aten_symbol(schema: LazyIrSchema) -> str: + missing_interned_strings = { + "sigmoid_backward", + } + if schema.aten_name in missing_interned_strings: + return f'c10::Symbol::fromQualString("aten::{schema.aten_name}")' + + if not schema.aten_name.startswith("at::"): + return f"at::aten::{schema.aten_name}" + else: + return schema.aten_name + + +# converts all tensor-like arguments to meta tensors. Returns: +# (1) a string containing all of the logic that does the conversions. +# (2) a context, to be used by translate(), with all of the relevant bindings. 
+def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]: + context: list[Binding] = [] + unwrapped_tensor_args: list[str] = [] + for arg in sig.arguments(): + if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like(): + unwrapped_name = f"{arg.name}_meta" + unwrapped_tensor_args.append( + f"auto {unwrapped_name} = to_meta({arg.name});" + ) + context.append(arg.with_name(unwrapped_name)) + else: + context.append(arg) + unwrap_tensor_args_str = "\n ".join(unwrapped_tensor_args) + return unwrap_tensor_args_str, context + + +@dataclass(frozen=True) +class GenLazyIR(ABC): + backend_index: BackendIndex + backend_name: str + node_base: str + use_lazy_shape: bool + + @method_with_native_function + def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]: + func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func + metadata = self.backend_index.get_kernel( + f.functional if isinstance(f, NativeFunctionsGroup) else f + ) + schema = LazyIrSchema( + func, symint=metadata is not None and metadata.supports_symint() + ) + return self.gen(schema) + + # there is no lowering functionality generated unless this IR base class is subclassed and + # implemented as a backend-specific node + def lowering_function(self, schema: LazyIrSchema) -> str: + return "" + + def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str: + return "" + + def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str: + return f"""bool CanBeReused({node_ctor_args}) const {{ + return false; + }}""" + + def node_base_ctor_call(self, schema: LazyIrSchema) -> str: + value_args = schema.filtered_args(values=True, scalars=False) + # backends can customize the way the node base class constructor is called, + # as long as all of its arguments can be generated from information available from the schema + base_ctor_value_args_list = [] + for arg in value_args: + if isinstance(arg.lazy_type, (BaseCType, VectorCType)): + base_ctor_value_args_list.append(f"{arg.name}") + elif isinstance(arg.lazy_type, OptionalCType): + base_ctor_value_args_list.append(f"{arg.name}.value_or(kNullValue)") + else: + raise AssertionError( + f"Unsupported type ({arg.lazy_type}) - add support if necessary" + ) + base_ctor_value_args = ", ".join(base_ctor_value_args_list) + + scalar_args = schema.filtered_args(values=False, scalars=True) + + # Shape construction. 
+ # Conditionally build shape depending on specified shape property + if schema.properties.ShapePrecompute: + shape_ctor_arg = "std::move(shapes)," + elif schema.properties.ShapeCompute: + shape_args = [a.name for a in value_args] + shape_args.extend(a.name for a in scalar_args) + shape_ctor_arg = f"compute_shape_{schema.name}({', '.join(shape_args)})," + elif schema.properties.ShapeCache: + shape_args = [f"operand({i})" for i in range(len(value_args))] + shape_args.extend(a.name for a in scalar_args) + shape_ctor_arg = f"[&](){{ return compute_shape_{schema.name}({', '.join(shape_args)})[0]; }}," + else: + shape_ctor_arg = "" + + scalar_hashes = ", ".join(f"{a.name}" for a in scalar_args) + + return f"""{self.node_base}( + {schema.node_name}::ClassOpKind(), + OpList{{{base_ctor_value_args}}}, + {shape_ctor_arg} + /* num_outputs */ {len(schema.returns)}, + torch::lazy::MHash({scalar_hashes}))""" + + def gen(self, schema: LazyIrSchema) -> list[str]: + opkind = schema.opkind or aten_symbol(schema) + + # for now, we just want one IR class decl and soon after also the method defs + # and we use the functional version not out/inplace. + all_args = schema.filtered_args() + scalar_args = schema.filtered_args(values=False, scalars=True) + + ctor_args = [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args] + reuse_ctor_args = ", ".join(ctor_args) + if self.use_lazy_shape and schema.properties.ShapePrecompute: + ctor_args.append("std::vector&& shapes") + node_ctor_args = ", ".join(ctor_args) + + scalar_initializers = ",\n ".join( + [ + # This code is just special casing the mapping from string_view -> strings + f"{a.name}({a.name}.has_value() ? ::std::make_optional(std::string(*{a.name})) : ::std::nullopt)" + if a.lazy_type.cpp_type() == "::std::optional" + else f"{a.name}({a.name})" + for a in scalar_args + ] + ) + if len(scalar_initializers): + scalar_initializers = f",\n {scalar_initializers}" + scalar_decls = "\n ".join( + [ + f"std::string {a.name};" + if a.lazy_type.cpp_type() == "c10::string_view" + else f"::std::optional {a.name};" + if a.lazy_type.cpp_type() == "::std::optional" + else f"{a.lazy_type.cpp_type()} {a.name};" + for a in scalar_args + ] + ) + optional_values = [ + arg.name + for arg in schema.filtered_args(values=True, scalars=False) + if isinstance(arg.lazy_type, OptionalCType) + ] + has_optional_decls = "\n ".join( + [f"bool has_{value}: 1;" for value in optional_values] + ) + has_optional_defs = "\n ".join( + [f"has_{value} = !!{value};" for value in optional_values] + ) + members_to_string = [] + for arg in scalar_args: + if isinstance(arg.lazy_type, OptionalCType): + value = f"{arg.name}.value()" + if arg.is_generator: + value = '"torch.Generator()"' + members_to_string.append( + f"""if ({arg.name}.has_value()) {{ + ss << ", {arg.name}=" << {value}; + }} else {{ + ss << ", {arg.name}=null"; + }}""" + ) + else: + members_to_string.append(f'ss << ", {arg.name}=" << {arg.name};') + members_to_string_str = "\n ".join(members_to_string) + + return [ + f"""\ +class {schema.node_name} : public {self.node_base} {{ + public: + static torch::lazy::OpKind ClassOpKind() {{ + return torch::lazy::OpKind({opkind}); + }} + + {schema.node_name}({node_ctor_args}) + : {self.node_base_ctor_call(schema)}{scalar_initializers} + {{ + {has_optional_defs} + }} + + std::string ToString() const override {{ + std::stringstream ss; + ss << {self.node_base}::ToString(); + {members_to_string_str} + return ss.str(); + }} + + {self.create_function(schema, reuse_ctor_args)} + + 
{self.can_be_reused_function(schema, reuse_ctor_args)} + + {self.lowering_function(schema)} + + {scalar_decls} + {has_optional_decls} + +}}; + +""", + ] + + +@dataclass(frozen=True) +class GenTSLazyIR(GenLazyIR): + def lowering_function(self, schema: LazyIrSchema) -> str: + signature = """ + torch::lazy::TSOpVector Lower( + std::shared_ptr function, + torch::lazy::TSLoweringContext* loctx) const override""" + + if schema.properties.LowerDeclOnly: + return f"{signature};" + elif schema.properties.Lower: + return f"""{signature} {{ + {ts_lowering_body(schema)} + }} + """ + else: + return "" + + def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str: + signature = f"static NodePtr Create({node_ctor_args})" + if schema.properties.CreateFnDeclOnly: + return f"{signature};" + elif not schema.properties.CreateFn: + return "" + return f"""{signature} {{ + return ReuseOrMakeNode<{schema.node_name}>(data); + }}""" + + def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str: + signature = f"bool CanBeReused({node_ctor_args}) const" + if schema.properties.CanBeReusedDeclOnly: + return f"{signature};" + elif not schema.properties.CanBeReused: + return "" + value_comparison = [] + for arg in itertools.chain(schema.positional_values, schema.keyword_values): + if isinstance(arg.lazy_type, OptionalCType): + value_comparison.append( + f"nullable_operand(i++) == {arg.name}.value_or(kNullValue)" + ) + else: + value_comparison.append(f"operand(i++) == {arg.name}") + for arg in itertools.chain(schema.positional_scalars, schema.keyword_scalars): + if isinstance(arg.lazy_type, OptionalCType): + value_comparison.append( + f"((!this->{arg.name}&&!{arg.name}) || (this->{arg.name}&&{arg.name} && *(this->{arg.name}) == *{arg.name}))" + ) + else: + value_comparison.append(f"this->{arg.name} == {arg.name}") + value_comparison_str = " &&\n ".join(value_comparison) + + return f"""{signature} {{ + size_t i = 0; + return ({value_comparison_str}); + }}""" + + +@dataclass(frozen=True) +class GenLazyNativeFuncDefinition: + class_method_name: str + backend_index: BackendIndex + tensor_class: str + gen_forced_fallback_code: bool + backend_namespace: str + get_tensorlist: str + get_tensor_or_wrap_number: str + try_get_tensor: str + metrics_counter: str + create_tensor: str + create_from_first_tensor: bool + create_aten_from_ltc_tensor: str + tuple_aten_from_ltc_tensors: str + lazy_tensor_ptr: str + get_device_fn: str + + def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str: + value_args = schema.filtered_args(values=True, scalars=False) + # Generates lazy_{name} variables for LazyTensors wrapping input tensors + lazy_tensor_decls: list[str] = [] + for arg in value_args: + if arg.is_wrapped_scalar: + if isinstance(arg.lazy_type, OptionalCType): + lazy_tensor_decls.append( + f"""auto node_{arg.name} = {arg.name} ? 
+ std::make_optional(torch::lazy::LazyGraphExecutor::Get()-> + GetIrValueForScalarFromCodegen(*{arg.name}, *common_device)): + ::std::nullopt;""" + ) + else: + lazy_tensor_decls.append( + f"""auto node_{arg.name} = torch::lazy::LazyGraphExecutor::Get()-> + GetIrValueForScalarFromCodegen({arg.name}, *common_device);""" + ) + elif arg.is_symint_or_list: + continue # values are extracted in isValueType + elif isinstance(arg.lazy_type, BaseCType): + if arg.lazy_type.type is tensorListValueT: + lazy_tensor_decls.append( + f"auto lazy_{arg.name}_tensorlist = " + f"{self.backend_namespace}::{self.get_tensorlist}({arg.name});" + ) + else: + lazy_tensor_decls.append( + f"{self.lazy_tensor_ptr} lazy_{arg.name} = " + f"{self.backend_namespace}::{self.get_tensor_or_wrap_number}({arg.name}, *common_device);" + ) + elif isinstance(arg.lazy_type, OptionalCType): + assert arg.lazy_type.elem == BaseCType(getValueT()), arg.lazy_type.elem + # TODO(alanwaketan): Maybe we want to apply GetLtcTensorOrCreateForWrappedNumber here, but hold it + # until we encounter a real world example. + lazy_tensor_decls.append( + f"{self.lazy_tensor_ptr} lazy_{arg.name} = " + f"{self.backend_namespace}::{self.try_get_tensor}({arg.name}.value_or(at::Tensor()));" + ) + else: + raise AssertionError( + f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})" + ) + return ("\n ").join(lazy_tensor_decls) + + def force_eager_fallback( + self, + func: NativeFunction, + schema: LazyIrSchema, + metadata: BackendMetadata, + sig: DispatcherSignature | NativeSignature, + ) -> str: + if self.gen_forced_fallback_code: + return gen_fallback_code( + schema, sig, overload_name=func.func.name.overload_name + ) + return "" + + def metrics(self, func: NativeFunction, schema: LazyIrSchema) -> str: + return f"{self.metrics_counter};" + + def get_device(self, func: NativeFunction, schema: LazyIrSchema) -> str: + value_args = schema.filtered_args(values=True, scalars=False) + scalar_args = schema.filtered_args(values=False, scalars=True) + value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar] + optional_device = OptionalCType(BaseCType(deviceT)) + optional_devices = [ + a.name for a in scalar_args if a.lazy_type == optional_device + ] + assert ( + len(value_types_names) > 0 or len(optional_devices) > 0 + ), "Expected at least one Value or Device type" + get_device_str = ( + f"{self.get_device_fn}({', '.join(value_types_names + optional_devices)})" + ) + return f"""auto common_device = {get_device_str}; + TORCH_INTERNAL_ASSERT(common_device); + """ + + def shape_inference(self, func: NativeFunction, schema: LazyIrSchema) -> str: + metadata = self.backend_index.get_kernel(func) + assert metadata is not None + all_args = schema.filtered_args() + returns_length = len(schema.returns) + # call the meta kernel if it exists, to compute output shape/dtype for our IR + # Note [Generated LTC Shape Functions] + # LTC uses meta tensors from core to do shape inference when possible, and otherwise + # we generate a shape function declaration that needs to be manually implemented. + # How do we detect which ops are eligible to use meta tensors? + # In general we should be able to use meta tensors not just on structured operators, + # but also on composite operators that are implemented in terms of structured kernels. + # We don't currently have a way of knowing at codegen time which ops are implemented that way. 
+ # This is the case for all view and view_copy operators however, so we're going to + # use them specifically for all of the view_copy ops (instead of manually writing shape rules for all of them). + is_view_copy_op = "view_copy" in func.tags + is_structured = func.structured or func.structured_delegate is not None + if is_structured or is_view_copy_op: + meta_out = """ +std::vector shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};""" + if returns_length > 1: + + def this_shape(i: int) -> str: + return f"torch::lazy::Shape(std::get<{i}>(out_meta).scalar_type(), std::get<{i}>(out_meta).sizes().vec())" + + shapes_str = ",".join([this_shape(i) for i in range(returns_length)]) + meta_out = "std::vector shapes{" + shapes_str + "};" + + # Convert tensor args to the meta device and call it. + # (We can't pass in the input tensors directly, because they are "functional wrappers". + # If any of the meta kernels call a tensor op and redispatch, we don't want to hit the functionalize kernels.) + # Even at::meta:: functions might redispatch, e.g. if they call into view ops. + dispatcher_sig = DispatcherSignature.from_schema(func.func) + meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig) + meta_call_args = [ + e.expr + for e in translate( + meta_call_ctx, dispatcher_sig.arguments(), method=False + ) + ] + if is_view_copy_op: + # view_copy ops always have a CompositeExplicitAutogradNonFunctional kernel + assert func.has_composite_explicit_autograd_non_functional_kernel + dispatch_ns = "compositeexplicitautogradnonfunctional" + else: + dispatch_ns = "meta" + aten_name = schema.aten_name + # TODO: this is trolling + if func.func.has_symint() and metadata.supports_symint(): + aten_name += "_symint" + shape_str = f"""\ + {meta_conversion_str} + auto out_meta = at::{dispatch_ns}::{aten_name}({', '.join(meta_call_args)}); + {meta_out}""" + else: + shape_sig = ComputeShapeSignature( + metadata.kernel, func, symint=metadata.supports_symint() + ) + shape_str = f""" + auto shapes = {shape_sig.shape_call};""" + + shape_str += f""" + TORCH_INTERNAL_ASSERT(shapes.size() == {returns_length});""" + + # Calculating which dimensions are symbolic + func_schema_str = "aten::" + str(func.func) + shape_str += f""" + if(torch::lazy::symbolicShapeEnabled()){{ + std::vector inputs = {{ {', '.join(str(a.name) for a in all_args)} }}; + const char* schema_str = "{func_schema_str}"; + applySymbolicShapesOnLT(schema_str, inputs, shapes); + }} + """ + return shape_str + + def build_ir_node(self, func: NativeFunction, schema: LazyIrSchema) -> str: + node_ctor_input_str = node_ctor_inputs(schema) + return f"""torch::lazy::NodePtr node = torch::lazy::ReuseNode<{schema.node_name}>({node_ctor_input_str}); + if (!node) {{ + {self.shape_inference(func, schema)} + node = torch::lazy::MakeNode<{schema.node_name}>({node_ctor_input_str}, std::move(shapes)); + CacheNode(node); + }} + """ + + def create_lazy_tensor(self, first_tensor_name: str | None = None) -> str: + # xla uses an instance method for tensor creation, for the time being + if self.create_from_first_tensor: + # TODO(whc) remove this if XLA switches to using static method for creation + assert ( + first_tensor_name is not None + ), "Requires first tensor to create lazy tensor" + return f"{first_tensor_name}.{self.create_tensor}" + return f"{self.backend_namespace}::{self.create_tensor}" + + def return_aten_tensor(self, func: NativeFunction, schema: LazyIrSchema) -> str: + returns_length = len(schema.returns) + value_args = 
schema.filtered_args(values=True, scalars=False) + value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar] + first_tensor_name = value_types_names[0] if len(value_types_names) > 0 else None + bridge_str = f"""auto result = {self.create_aten_from_ltc_tensor}( + {self.create_lazy_tensor(first_tensor_name)}(std::move(node), *common_device));""" + + if returns_length > 1: + assert ( + len(value_types_names) > 0 + ), "Code below assumes there is at least one tensor arg" + bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors; + for (int i = 0; i < {returns_length}; i++) {{ + lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({getValueT()}(node, i), *common_device)); + }} + auto result = {self.tuple_aten_from_ltc_tensors}<{returns_length}>(lazy_tensors);""" + + if schema.name.name.inplace or func.func.is_out_fn(): + assert returns_length == 1, ( + "We assumed there was no such case where an op is an in-place variant " + f"and has tuple outputs, but got tuple of len {returns_length}." + ) + bridge_str = f"""lazy_{first_tensor_name}->SetInPlaceIrValue(node); + auto& result = {first_tensor_name};""" + + bridge_str += """ + return result;""" + return bridge_str + + @method_with_native_function + def __call__(self, func: NativeFunction) -> list[str]: + sig = kernel_signature(func, self.backend_index) + metadata = self.backend_index.get_kernel(func) + assert metadata is not None + schema = LazyIrSchema(func.func, symint=metadata.supports_symint()) + return [ + f"""\ + {sig.decl(name=f"{self.class_method_name}::{metadata.kernel}")} {{ + {self.force_eager_fallback(func, schema, metadata, sig)} + {self.metrics(func, schema)} + {self.get_device(func, schema)} + {self.lazy_tensor_decls(func, schema)} + {self.build_ir_node(func, schema)} + {self.return_aten_tensor(func, schema)} + }}\n + """ + ] + + +class ComputeShapeSignature: + """ + Here we use the base name as the suffix of the signature to avoid generating for in-place variants. 
+ """ + + def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool) -> None: + self.__schema = LazyIrSchema(f.func, symint=symint) + self.__dispatch_args = ", ".join( + [a.decl() for a in dispatcher.arguments(f.func, symint=symint)] + ) + self.__call_args = ", ".join( + [f"{arg.name}" for arg in self.__schema.filtered_args(generator=True)] + ) + self.__kernel_name = kernel_name + + def __decl_suffix(self) -> str: + return f"{self.__kernel_name}({self.__dispatch_args})" + + def __call_suffix(self) -> str: + return f"{self.__kernel_name}({self.__call_args})" + + @property + def shape_decl(self) -> str: + return f"TORCH_API std::vector compute_shape_{self.__decl_suffix()}" + + @property + def shape_call(self) -> str: + return f"torch::lazy::compute_shape_{self.__call_suffix()}" + + +@dataclass(frozen=True) +class GenLazyShapeInferenceDefinition: + backend_index: BackendIndex + tensor_class: str + + @method_with_native_function + def __call__(self, f: NativeFunction) -> list[str]: + metadata = self.backend_index.get_kernel(f) + assert metadata is not None + + # See Note [Generated LTC Shape Functions] + is_view_copy_op = "view_copy" in f.tags + is_structured = f.structured or f.structured_delegate is not None + if is_structured or is_view_copy_op: + return [] + else: + shape_sig = ComputeShapeSignature( + metadata.kernel, f, symint=metadata.supports_symint() + ) + return ["\n".join([f"{shape_sig.shape_decl};"])] + + +def generate_non_native_lazy_ir_nodes( + non_native: list[dict[str, Any]], gen_lazy_ir: GenLazyIR +) -> list[str]: + """Generate the non-native lazy IR node classes""" + nodes = [] + for op in non_native: + # Set default properties for Non-Native IRs + properties = LazyIrProperties("ShapeCache", "CanBeReused", "LowerDeclOnly") + for p in op.get("properties", []): + setattr(properties, p, True) + + # non-native is assumed to want symint bindings if you wrote symint + schema = LazyIrSchema(FunctionSchema.parse(op["func"]), properties, symint=True) + schema.opkind = op.get("opkind") + nodes.append(gen_lazy_ir.gen(schema)[0]) + + return nodes diff --git a/lib/python3.10/site-packages/torchgen/dest/lazy_ts_lowering.py b/lib/python3.10/site-packages/torchgen/dest/lazy_ts_lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..70161216d8e7c95e194b0d89b345e0da886ef989 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/dest/lazy_ts_lowering.py @@ -0,0 +1,48 @@ +from torchgen.api.lazy import LazyArgument, LazyIrSchema +from torchgen.api.types import OptionalCType + + +def ts_lowering_body(schema: LazyIrSchema) -> str: + # for now, we just want one IR class decl and soon after also the method defs + # and we use the functional version not out/inplace. + emplace_arguments = [] + + def get_value(arg: LazyArgument) -> str: + if isinstance(arg.lazy_type, OptionalCType): + return f"has_{arg.name} ? 
loctx->GetOutputOp(operand(i++)) : nullptr" + return "loctx->GetOutputOp(operand(i++))" + + for arg in schema.positional_args: + if arg.is_lazy_value: + emplace_arguments.append(get_value(arg)) + continue + emplace_arguments.append(f'"{arg.name}", {arg.name}') + + emplace_arguments_str = "\n ".join( + [f"arguments.emplace_back({a});" for a in emplace_arguments] + ) + emplace_kwarg_values = [ + f'"{arg.name}", {get_value(arg)}' for arg in schema.keyword_values + ] + emplace_kwarg_scalars = [ + f'"{arg.name}", {arg.name}' for arg in schema.keyword_scalars + ] + emplace_kwarguments = "\n ".join( + [ + f"kwarguments.emplace_back({a});" + for a in emplace_kwarg_values + emplace_kwarg_scalars + ] + ) + return f"""\ + std::vector arguments; + std::vector kwarguments; + arguments.reserve({len(emplace_arguments)}); + kwarguments.reserve({len(emplace_kwarg_values + emplace_kwarg_scalars)}); + size_t i = 0; + {emplace_arguments_str} + {emplace_kwarguments} + torch::lazy::TSOpVector {schema.aten_name}_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments); + TORCH_CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)}); + + return {schema.aten_name}_out; +""" diff --git a/lib/python3.10/site-packages/torchgen/dest/native_functions.py b/lib/python3.10/site-packages/torchgen/dest/native_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..a93405555bc229db19c7975108a32a8cbedf19e4 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/dest/native_functions.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import torchgen.api.meta as meta +import torchgen.api.structured as structured +from torchgen.api.types import kernel_signature +from torchgen.context import with_native_function_and_index +from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup +from torchgen.utils import mapMaybe + + +@with_native_function_and_index +def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> str | None: + sig = kernel_signature(f, backend_index) + metadata = backend_index.get_kernel(f) + if metadata is None: + return None + if "legacy::" in metadata.kernel: + return None + else: + prefix = "static" if backend_index.external else "TORCH_API" + return f"{prefix} {sig.decl(name=metadata.kernel)};" + + +@with_native_function_and_index +def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> list[str]: + meta_name = meta.name(g) + out_args = structured.impl_arguments(g) + metadata = backend_index.get_kernel(g) + if metadata is None: + return [] + prefix = "" if backend_index.external else "TORCH_API " + return [ + f"""\ +struct {prefix}structured_{metadata.kernel} : public at::meta::structured_{meta_name} {{ +void impl({', '.join(a.decl() for a in out_args)}); +}}; +""" + ] + + +# Generates NativeFunctions.h, a list of forward declarations of all +# actual kernel definitions we keep in aten/src/ATen/native/ +@with_native_function_and_index +def compute_native_function_declaration( + g: NativeFunctionsGroup | NativeFunction, backend_index: BackendIndex +) -> list[str]: + metadata = backend_index.get_kernel(g) + if isinstance(g, NativeFunctionsGroup): + if metadata is not None and metadata.structured: + if backend_index.external: + # Structured hasn't been tested with external backends yet. + raise AssertionError( + "Structured external backend functions are not implemented yet." 
+ ) + else: + return gen_structured(g, backend_index) + else: + return list( + mapMaybe(lambda f: gen_unstructured(f, backend_index), g.functions()) + ) + else: + x = gen_unstructured(g, backend_index) + return [] if x is None else [x] diff --git a/lib/python3.10/site-packages/torchgen/dest/register_dispatch_key.py b/lib/python3.10/site-packages/torchgen/dest/register_dispatch_key.py new file mode 100644 index 0000000000000000000000000000000000000000..091bec237238e2117697b54a0dc9d6816eef8146 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/dest/register_dispatch_key.py @@ -0,0 +1,1005 @@ +from __future__ import annotations + +import itertools +import textwrap +from dataclasses import dataclass +from typing import Literal, TYPE_CHECKING + +import torchgen.api.cpp as cpp +import torchgen.api.meta as meta +import torchgen.api.structured as structured +from torchgen.api.translate import translate +from torchgen.api.types import ( + BaseCType, + Binding, + ConstRefCType, + CppSignature, + CppSignatureGroup, + DispatcherSignature, + Expr, + kernel_signature, + MutRefCType, + NamedCType, + NativeSignature, + tensorT, +) +from torchgen.context import method_with_native_function, native_function_manager +from torchgen.model import ( + Argument, + BackendIndex, + DeviceCheckType, + DispatchKey, + gets_generated_out_inplace_wrapper, + is_cuda_dispatch_key, + NativeFunction, + NativeFunctionsGroup, + SchemaKind, + TensorOptionsArguments, +) +from torchgen.utils import assert_never, mapMaybe, Target + + +if TYPE_CHECKING: + from torchgen.selective_build.selector import SelectiveBuilder + + +def gen_registration_headers( + backend_index: BackendIndex, + per_operator_headers: bool, + rocm: bool, +) -> list[str]: + if per_operator_headers: + headers = ["#include "] + else: + headers = ["#include "] + + if backend_index.dispatch_key in (DispatchKey.CPU, DispatchKey.Meta): + headers.append("#include ") + elif backend_index.dispatch_key == DispatchKey.CUDA: + if rocm: + headers.append("#include ") + else: + headers.append("#include ") + elif backend_index.dispatch_key == DispatchKey.MPS: + headers.append("#include ") + elif backend_index.dispatch_key == DispatchKey.XPU: + # XPU specific, this header resides in third_party/torch-xpu-ops + headers.append("#include ") + elif per_operator_headers: + headers += [ + "#include ", + "#include ", + "#include ", + "#include ", + ] + else: + headers.append("#include ") + + headers.append("#include ") + return headers + + +def gen_empty_impl_names( + backend_index: BackendIndex, +) -> tuple[str | None, str | None]: + empty_impl = None + empty_strided_impl = None + + if backend_index.dispatch_key in ( + DispatchKey.Meta, + DispatchKey.CPU, + DispatchKey.CUDA, + DispatchKey.MPS, + DispatchKey.XPU, + ): + dispatch = str(backend_index.dispatch_key).lower() + empty_impl = f"at::detail::empty_{dispatch}" + empty_strided_impl = f"at::detail::empty_strided_{dispatch}" + elif backend_index.dispatch_key in ( + DispatchKey.CompositeExplicitAutogradNonFunctional, + DispatchKey.QuantizedCPU, + DispatchKey.QuantizedCUDA, + DispatchKey.XPU, + ): + empty_impl = "at::empty" + empty_strided_impl = "at::empty_strided" + + return empty_impl, empty_strided_impl + + +def gen_create_out_helper(backend_index: BackendIndex) -> list[str]: + if backend_index.dispatch_key == DispatchKey.Meta: + empty_options = "options.device(at::kMeta)" + else: + empty_options = "options" + + empty_impl, empty_strided_impl = gen_empty_impl_names(backend_index) + if empty_impl is None: + return [] + + 
return [ + f""" +Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{ + if (strides.empty()) {{ + return {empty_impl}(sizes, {empty_options}); + }} else {{ + return {empty_strided_impl}(sizes, strides, {empty_options}); + }} +}} +""" + ] + + +def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> list[str]: + _, empty_strided_impl = gen_empty_impl_names(backend_index) + return ( + [] + if empty_strided_impl is None + else [ + f""" +std::optional maybe_create_proxy(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{ + if (out.strides() != strides) {{ + return {empty_strided_impl}(sizes, strides, options); + }} + return std::nullopt; +}} +""" + ] + ) + + +def gen_resize_out_helper(backend_index: BackendIndex) -> list[str]: + if backend_index.dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional: + # The function isn't used by this key (since only functional ops have a kernel for this key), + # so we need to not include it to avoid a defined-but-not-used error. + return [] + return [ + """ +void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) { + TORCH_CHECK(options.dtype() == out.dtype(), + "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead"); + TORCH_CHECK(options.device() == out.device(), + "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead"); + const bool resized = at::native::resize_output(out, sizes); + // Only restride if a resize occurred; otherwise we ignore the (advisory) + // strides from the meta function and directly use the output tensor's + // preexisting strides + if (resized) { + if (!strides.empty()) { + TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value()); + // TODO: avoid the redispatch here + out.as_strided_(sizes, strides); + } else if (options.memory_format_opt().has_value()) { + out.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt()); + } + } +} +""" + ] + + +def gen_check_inplace_helper(backend_index: BackendIndex) -> list[str]: + return [ + """ +void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) { + // These checks are needed on those operators that: + // 1) don't use 'TensorIterator' (e.g. 'addmm' and 'baddbmm') + // 2) have particular typing rules (e.g. 'cumsum' and 'cumprod') + // For other operators (e.g. 'add'), 'TensorIterator' already checks + // these things separately. + TORCH_CHECK(options.dtype() == self.dtype(), + "Bad in-place call: ", + "input tensor dtype ", self.dtype(), " and output tensor dtype ", options.dtype(), " should match"); + TORCH_CHECK(options.device() == self.device(), + "Bad in-place call: ", + "input tensor device ", self.device(), " and output tensor device ", options.device(), " should match"); + TORCH_CHECK(sizes == self.sizes(), + "Bad in-place call: ", + "input tensor size ", self.sizes(), " and output tensor size ", sizes, " should match"); +} +""" + ] + + +def gen_registration_helpers(backend_index: BackendIndex) -> list[str]: + return [ + 'C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")', + *gen_create_out_helper(backend_index), + *gen_resize_out_helper(backend_index), + *gen_check_inplace_helper(backend_index), + *gen_maybe_create_proxy_helper(backend_index), + "C10_DIAGNOSTIC_POP()", + ] + + +# Generates Register{dispatch}.cpp (e.g., RegisterCPU.cpp). 
+# +# - The primary function of this file is to register all of the +# implementations for the given dispatch key to the dispatcher, +# so they are available for use in PyTorch. If dispatch is +# None, we generate schema (def) registrations and catchall +# registrations. +# - The secondary function of this file is to generate a wrapper +# around functions. In CPUType these wrappers do nothing +# (and should be removed), but in other cases they handle +# DeviceGuard. A small extra benefit of wrappers is they +# are not overloaded, so they can be used in the registration +# API without having to disambiguate which overload you want +# (as would be the case if you directly registered native:: +# functions). +# - The tertiary function of this file is to generate *static* +# cpp API bindings which can be used to bypass dispatcher +# directly to kernels, but with user-friendly cpp-style API +@dataclass(frozen=True) +class RegisterDispatchKey: + backend_index: BackendIndex + + target: Literal[ + Target.ANONYMOUS_DEFINITION, + Target.NAMESPACED_DEFINITION, + Target.NAMESPACED_DECLARATION, + Target.REGISTRATION, + ] + + # Selector object to determine which operators to generate + # registration code for. + selector: SelectiveBuilder + + # Whether or not we are actually code-genning for ROCm + rocm: bool + + # Whether or not to generate symint registrations or not. External users + # of codegen who don't care about symints can set this to false to get + # non-SymInt codegen + symint: bool + + # The class that all unstructured native functions live under. This is used to improve + # compiler error messages when a kernel writer adds a native function with the wrong signature. + # This is only used in unstructured kernels, since structured kernels already live in a class. + # Finally, this field is currently Optional because it is only used by external backends. + # It would be nice if we can add the same logic to in-tree kernels too, but that requires updating + # all of the existing kernel signatures scattered across aten/src/ATen/native. + class_method_name: str | None + + # Only set to true in lightweight dispatch. If lightweight dispatch is enabled we are registering + # operators into JIT op registry, thus we need to avoid generating code to register into the dispatcher. + skip_dispatcher_op_registration: bool + + @staticmethod + def gen_device_check( + type: DeviceCheckType, args: list[Argument], method_name: str + ) -> str: + if type == DeviceCheckType.NoCheck: + return " // No device check\n" + + device_check = "std::optional common_device = std::nullopt;\n" + device_check += "(void)common_device; // Suppress unused variable warning\n" + for arg in args: + # Only tensor like arguments are eligible + if arg.type.is_tensor_like(): + device_check += f""" + c10::impl::check_and_update_common_device(common_device, {arg.name}, "{method_name}", "{arg.name}");""" + return device_check + + @method_with_native_function + def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]: + if isinstance(f, NativeFunctionsGroup): + g: NativeFunctionsGroup = f + # Note: We call gen_structured() if the operator is marked structured, regardless of the backend. + # gen_structured() has special logic to handle auto-generated kernels. 
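        # Editor's note (hedged, names illustrative only): for a selected operator the
        # REGISTRATION target ultimately emits an entry such as
        #   m.impl("myop.out", TORCH_FN(wrapper_CPU_out_myop_out));
        # while the ANONYMOUS_DEFINITION target emits the corresponding wrapper_*
        # function that the entry points at; the structured/unstructured split below
        # only decides how that wrapper body is produced.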
+ if g.structured: + return self.gen_structured(g) + else: + return list( + mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions()) + ) + elif isinstance(f, NativeFunction): + r = self.gen_unstructured(f) + return [] if r is None else [r] + else: + assert_never(f) + + def wrapper_kernel_sig( + self, f: NativeFunction + ) -> NativeSignature | DispatcherSignature: + # The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names. + return DispatcherSignature.from_schema( + f.func, + prefix=f"wrapper_{self.backend_index.dispatch_key}_{f.func.name.overload_name}_", + symint=self.symint, + ) + + def gen_out_inplace_wrapper( + self, f: NativeFunction, g: NativeFunctionsGroup | None + ) -> str | None: + if g is None: + return None + k = f.func.kind() + if k is SchemaKind.inplace: + copy_op = "at::_copy_from" + elif k is SchemaKind.out: + copy_op = "at::_copy_from_and_resize" + else: + raise AssertionError("gen_out_inplace_wrapper called on a functional op") + + sig = self.wrapper_kernel_sig(f) + name = sig.name() + + func_res = f"{name}_tmp" + return_names = cpp.return_names(f) + if len(return_names) > 1: + updates = "\n ".join( + f"{copy_op}(std::get<{i}>({func_res}), {ret_name});" + for i, ret_name in enumerate(return_names) + ) + returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})' + elif len(return_names) == 1: + ret_name = return_names[0] + updates = f"{copy_op}({func_res}, {ret_name});" + returns = ret_name + else: + assert len(f.func.arguments.out) == 1 + returns = "" + out_arg = f.func.arguments.out[0] + if out_arg.type.is_list_like(): + updates = f"""\ + for (int64_t i = 0; i < {func_res}.size(); ++i) {{ + {copy_op}({func_res}[i], {out_arg.name}[i]); + }}""" + else: + updates = f"{copy_op}({func_res}, {out_arg.name});" + + functional_sig = self.wrapper_kernel_sig(g.functional) + wrapper_name = sig.name() + + return f"""\ +{sig.defn(name=wrapper_name)} {{ + auto {func_res} = {functional_sig.name()}({", ".join(e.expr for e in translate(sig.arguments(), functional_sig.arguments()))}); + {updates} + return {returns}; +}} +""" + + def gen_structured(self, g: NativeFunctionsGroup) -> list[str]: + metadata = self.backend_index.get_kernel(g) + if self.backend_index.dispatch_key == DispatchKey.Meta: + assert not self.backend_index.has_kernel(g.out), ( + "Do not explicitly specify Meta dispatch key on structured " + "functions, they will be automatically generated for you" + ) + elif ( + self.backend_index.dispatch_key + == DispatchKey.CompositeExplicitAutogradNonFunctional + ): + assert not self.backend_index.has_kernel(g.out), ( + "Do not explicitly specify CompositeExplicitAutograd dispatch key on structured " + "functions, they will be automatically generated for you" + ) + elif metadata is None or not metadata.structured: + return list(mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions())) + structured_gen = StructuredRegisterDispatchKey( + self.backend_index, + self.target, + self.selector, + self.rocm, + self.symint, + self.class_method_name, + self.skip_dispatcher_op_registration, + g, + ) + return list(mapMaybe(structured_gen.gen_one, g.functions())) + + def gen_unstructured( + self, f: NativeFunction, g: NativeFunctionsGroup | None = None + ) -> str | None: + with native_function_manager(f): + inplace_meta = False + gets_out_inplace_wrapper = False + if not self.backend_index.has_kernel(f): + if ( + self.backend_index.dispatch_key == DispatchKey.Meta + and f.func.kind() is SchemaKind.inplace + and + # Defer to 
composites for meta implementation + not f.has_composite_kernel + and + # Inplace list operations are not supported + len(f.func.returns) == 1 + ): + inplace_meta = True + elif ( + not self.backend_index.use_out_as_primary + and g is not None + and gets_generated_out_inplace_wrapper(f, g, self.backend_index) + ): + # We want to generate inplace/out wrappers, that don't have a kernel for the backend. + gets_out_inplace_wrapper = True + else: + return None + if f.manual_kernel_registration: + return None + + if ( + self.target is Target.REGISTRATION + and not self.selector.is_native_function_selected(f) + ): + return None + + sig = self.wrapper_kernel_sig(f) + + name = sig.name() + returns_type = sig.returns_type().cpp_type() + args = sig.arguments() + args_str = ", ".join(a.defn() for a in args) + + # See Note [Direct dispatch bindings] + cpp_sig_group = CppSignatureGroup.from_native_function( + f, method=False, fallback_binding=False + ) + + # TODO: dedupe this with the structured codegen + if self.target is Target.NAMESPACED_DECLARATION: + result = "" + for cpp_sig in cpp_sig_group.signatures(symint=self.symint): + result += f"TORCH_API {cpp_sig.decl()};\n" + return result + elif self.target is Target.NAMESPACED_DEFINITION: + + def generate_defn(cpp_sig: CppSignature) -> str: + return f""" +{cpp_sig.defn()} {{ +return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))}); +}} +""" + + result = "" + for cpp_sig in cpp_sig_group.signatures(symint=self.symint): + result += generate_defn(cpp_sig) + return result + + elif self.target is Target.ANONYMOUS_DEFINITION: + # short circuit for inplace_meta + if inplace_meta: + assert f.func.arguments.self_arg is not None + self_arg_name = f.func.arguments.self_arg.argument.name + # TODO: handle in place on tensor list + return f""" +{returns_type} {name}({args_str}) {{ + TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(), + "Cannot inplace into non-meta tensor with meta tensor argument"); + return {self_arg_name}; +}} +""" + + # short circuit for generated inplace/out wrappers + if gets_out_inplace_wrapper: + return self.gen_out_inplace_wrapper(f, g) + + metadata = self.backend_index.get_kernel(f) + if metadata is None: + return None + if self.class_method_name is None: + impl_name = f"{metadata.cpp_namespace}::{metadata.kernel}" + else: + impl_name = f"{metadata.cpp_namespace}::{self.class_method_name}::{metadata.kernel}" + + kernel_sig = kernel_signature(f, self.backend_index) + + args_exprs_str = ", ".join( + e.expr + for e in translate( + sig.arguments(), kernel_sig.arguments(), method=False + ) + ) + + device_check = " // No device check\n" + # Backends that require device guards presumably also require device checks. 
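            # As a hedged illustration (editor's sketch for a hypothetical CPU kernel
            # taking a single `self` tensor), the generated check/guard pair looks
            # roughly like:
            #   std::optional<Device> common_device = std::nullopt;
            #   (void)common_device; // Suppress unused variable warning
            #   c10::impl::check_and_update_common_device(
            #       common_device, self, "wrapper_CPU_myop", "self");
            #   const OptionalDeviceGuard device_guard(device_of(self));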
+ if self.backend_index.device_guard: + device_check_args = itertools.chain( + f.func.arguments.out, f.func.arguments.flat_positional + ) + device_check = RegisterDispatchKey.gen_device_check( + f.device_check, list(device_check_args), name + ) + + device_guard = "// DeviceGuard omitted" # default + if f.device_guard and self.backend_index.device_guard: + has_tensor_options = any( + isinstance(a, TensorOptionsArguments) + for a in f.func.arguments.non_out + ) + if has_tensor_options: + # kernel is creating a tensor + device_guard = """ + const DeviceGuard device_guard(device_or_default(device));""" + + # CUDA requires special handling + if is_cuda_dispatch_key(self.backend_index.dispatch_key): + device_guard = ( + f"globalContext().lazyInitCUDA();\n{device_guard}" + ) + else: + # kernel is operating on existing tensors + + # There is precedence for which argument we use to do + # device guard. This describes the precedence order. + self_arg = ( + [f.func.arguments.self_arg.argument] + if f.func.arguments.self_arg is not None + else [] + ) + candidate_args = itertools.chain( + self_arg, + f.func.arguments.out, + f.func.arguments.flat_positional, + ) + + # Only tensor like arguments are eligible + device_of = next( + ( + f"{a.name}" + for a in candidate_args + if a.type.is_tensor_like() + ), + None, + ) + if device_of is not None: + device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));" + + return f"""\ +namespace {{ + +{returns_type} {name}({args_str}) {{ + {device_check} + + {device_guard} + return {impl_name}({args_exprs_str}); +}} + +}} // anonymous namespace +""" + + elif self.target is Target.REGISTRATION: + if f.manual_kernel_registration or self.skip_dispatcher_op_registration: + return None + else: + payload = f"TORCH_FN({name})" + return f'm.impl("{f.func.name}",\n{payload});\n' + else: + assert_never(self.target) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# STRUCTURED +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +@dataclass(frozen=True) +class StructuredRegisterDispatchKey(RegisterDispatchKey): + g: NativeFunctionsGroup + + def gen_class_set_output_functions( + self, k: SchemaKind, parent_class: str, generate_super: bool + ) -> str: + if generate_super: + set_output_super = f"{parent_class}::set_output_raw_strided(output_idx, sizes, strides, options, names);" + else: + set_output_super = "" + + def gen_set_output_function(name: str, maybe_create_proxy: bool) -> str: + return f""" +void set_output_{name}( + int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, + TensorOptions options, DimnameList names +) override {{ +{textwrap.indent(self.gen_class_set_output_body(k, maybe_create_proxy), " ")} + if (!names.empty()) {{ + namedinference::propagate_names(outputs_[output_idx], names); + }} + // super must happen after, so that downstream can use maybe_get_output + // to retrieve the output +{textwrap.indent(set_output_super, " ")} +}} +""" + + return f""" +{gen_set_output_function("strided", maybe_create_proxy=True)} +{gen_set_output_function("raw_strided", maybe_create_proxy=False)} +""" + + def gen_class_set_output_body(self, k: SchemaKind, maybe_create_proxy: bool) -> str: + if self.backend_index.dispatch_key in [ + DispatchKey.CUDA, + DispatchKey.MPS, + DispatchKey.CompositeExplicitAutogradNonFunctional, + ]: + maybe_set_guard = """ +auto current_device = guard_.current_device(); +if (C10_UNLIKELY(current_device.has_value())) { + TORCH_INTERNAL_ASSERT(*current_device == 
options.device(), + "structured kernels don't support multi-device outputs"); +} else { + guard_.reset_device(options.device()); +} +""" + maybe_set_guard_line = maybe_set_guard + "\n" + else: + maybe_set_guard_line = maybe_set_guard = "" + + if maybe_create_proxy: + create_proxy = """ +auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options); +if (C10_UNLIKELY(maybe_proxy.has_value())) { + proxy_outputs_[output_idx] = std::move(maybe_proxy).value(); +} +""" + else: + create_proxy = "" + + if k is SchemaKind.functional: + assert self.backend_index.dispatch_key in ( + DispatchKey.Meta, + DispatchKey.CPU, + DispatchKey.CUDA, + DispatchKey.MPS, + DispatchKey.XPU, + DispatchKey.CompositeExplicitAutogradNonFunctional, + ) + return f"""{maybe_set_guard_line} +outputs_[output_idx] = create_out(sizes, strides, options);""" + elif k is SchemaKind.inplace: + return f"""{maybe_set_guard_line} +const auto& out = outputs_[output_idx].get(); +check_inplace(out, sizes, options); +{create_proxy}""" + elif k is SchemaKind.out: + return f"""{maybe_set_guard_line} +const auto& out = outputs_[output_idx].get(); +resize_out(out, sizes, strides, options); +{create_proxy}""" + elif k is SchemaKind.mutable or k is SchemaKind.scratch: + raise AssertionError( + f"{k} structured operators are currently not supported" + ) + else: + assert_never(k) + + # returns the definition of a ctor, as well as how to construct + # this class to a variable named op + def gen_class_ctor(self, k: SchemaKind, class_name: str, returns: int) -> str: + if k is SchemaKind.functional: + return "" + elif k is SchemaKind.inplace: + # TODO: Make sure out argument is guaranteed to be self + return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}" + elif k is SchemaKind.out: + out_args = ", ".join(f"Tensor& out{i}" for i in range(returns)) + out_refs = ", ".join(f"std::ref(out{i})" for i in range(returns)) + return f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}" + elif k is SchemaKind.mutable or k is SchemaKind.scratch: + raise AssertionError( + f"{k} structured operators are currently not supported" + ) + else: + assert_never(k) + + def gen_class( + self, + f: NativeFunction, + k: SchemaKind, + *, + class_name: str, + parent_class: str, + generate_super: bool, + ) -> str: + if k is SchemaKind.functional: + output_type = "Tensor" + output_value = "outputs_[output_idx]" + proxy_field = "" + elif k is SchemaKind.inplace: + output_type = "std::reference_wrapper" + output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()" + proxy_field = f"std::array<::std::optional, {len(f.func.returns)}> proxy_outputs_;" + elif k is SchemaKind.out: + output_type = "std::reference_wrapper" + output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()" + proxy_field = f"std::array<::std::optional, {len(f.func.returns)}> proxy_outputs_;" + else: + raise RuntimeError(f"Unsupported SchemaKind {k}") + + if self.backend_index.dispatch_key == DispatchKey.CUDA: + if self.rocm: + guard_field = "c10::hip::OptionalHIPGuardMasqueradingAsCUDA guard_;" + else: + guard_field = "c10::cuda::OptionalCUDAGuard guard_;" + elif ( + self.backend_index.dispatch_key + == DispatchKey.CompositeExplicitAutogradNonFunctional + ): + guard_field = "c10::OptionalDeviceGuard guard_;" + elif self.backend_index.dispatch_key == DispatchKey.MPS: + # TODO: Move to OptionalMPSGuard. 
+ guard_field = "c10::OptionalDeviceGuard guard_;" + else: + guard_field = "" + + indent = " " * 4 + class_ctor_str = self.gen_class_ctor(k, class_name, len(f.func.returns)) + lines = ( + f"struct {class_name} final : public {parent_class} {{", + f"{textwrap.indent(class_ctor_str, indent)}", + f"{textwrap.indent(self.gen_class_set_output_functions(k, parent_class, generate_super), indent)}", + " const Tensor& maybe_get_output(int64_t output_idx) override {", + f" return {output_value};\n", # type: ignore[possibly-undefined] # TODO: audit + " }", + # type: ignore[possibly-undefined] # TODO: audit + f" std::array<{output_type}, {len(f.func.returns)}> outputs_;", + f"{textwrap.indent(proxy_field, indent)}", # type: ignore[possibly-undefined] # TODO: audit + f"{textwrap.indent(guard_field, indent)}", + "};", + ) + return "\n".join(line for line in lines if line) + + @method_with_native_function + def gen_one(self, f: NativeFunction) -> str | None: + assert not f.manual_kernel_registration + + if ( + self.target is Target.REGISTRATION + and not self.selector.is_native_function_selected(f) + ): + return None + + # TODO: Now, there is something interesting going on here. In the code below, + # we generate CompositeExplicitAutogradNonFunctional implementations of functional and inplace + # based on the out implementation. But in fact, out is definable by + # functional too (just not very efficiently), and this is honestly the + # MORE likely situation for a backend implementor. How do we pick? + # Well, taking a page from Haskell type classes and default methods, + # we could conceivably register a circular definition (out in terms + # of functional, and functional in terms of out) and just require + # someone to implement one or the other. We'd have to do a little bit + # of work to not register one of these "weak" definitions unless there + # is a strong definition somewhere in the DAG! So it's not implemented yet. + if ( + self.backend_index.dispatch_key + == DispatchKey.CompositeExplicitAutogradNonFunctional + and f.func.kind() is SchemaKind.out + ): + # Never generate a default implementation for out, that's what you + # have to define as a backend implementor + return None + + # Note [Direct dispatch bindings] + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Signature of the non-dispatched function we'll expose in a header + # (e.g., at::cpu::add). We don't generate methods (TODO: do this + # when CPUTensor class is a thing); nor do we generate fallback + # bindings for manual_cpp_binding functions. 
+ cpp_sig_group = CppSignatureGroup.from_native_function( + f, method=False, fallback_binding=False + ) + + # Signature of the wrapper function we'll register to the dispatcher + kern = self.backend_index.get_kernel(f) + sig = NativeSignature( + f.func, + prefix=f"wrapper_{self.backend_index.dispatch_key}_", + symint=kern is not None and kern.supports_symint(), + ) + + if self.target is Target.NAMESPACED_DECLARATION: + result = "" + for cpp_sig in cpp_sig_group.signatures(symint=self.symint): + result += f"TORCH_API {cpp_sig.decl()};\n" + return result + + elif self.target is Target.NAMESPACED_DEFINITION: + + def generate_defn(cpp_sig: CppSignature) -> str: + return f""" +{cpp_sig.defn()} {{ +return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))}); +}} +""" + + result = "" + for cpp_sig in cpp_sig_group.signatures(symint=self.symint): + result += generate_defn(cpp_sig) + return result + + elif self.target is Target.ANONYMOUS_DEFINITION: + k = f.func.kind() + + # Construct the body of the wrapper function with signature sig + sig_body = [] + # We'll use context to keep track of any variables we've brought + # into scope while generating code + context: list[Binding | Expr] = list(sig.arguments()) + + # Initialize the class corresponding to this structured + # operator; feeding it the output argument(s) if it is known + if self.backend_index.dispatch_key is DispatchKey.Meta: + class_name = f"structured_{meta.name(self.g)}_meta_{k.name}" + parent_class = f"at::meta::structured_{meta.name(self.g)}" + elif ( + self.backend_index.dispatch_key + is DispatchKey.CompositeExplicitAutogradNonFunctional + ): + # TODO: dedup this branch + class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}" + parent_class = f"at::meta::structured_{meta.name(self.g)}" + else: + metadata = self.backend_index.get_kernel(self.g) + assert metadata is not None + class_name = f"structured_{metadata.kernel}_{k.name}" + parent_class = f"{metadata.cpp_namespace}::structured_{metadata.kernel}" + + if self.backend_index.device_guard: + device_check_args = itertools.chain( + f.func.arguments.out, f.func.arguments.flat_positional + ) + sig_body.append( + RegisterDispatchKey.gen_device_check( + f.device_check, list(device_check_args), sig.name() + ) + ) + + if k is SchemaKind.functional: + sig_body.append(f"{class_name} op;") + elif k is SchemaKind.inplace: + sig_body.append(f"{class_name} op(self);") + elif k is SchemaKind.out: + out_args_str = ", ".join(a.name for a in f.func.arguments.out) + sig_body.append(f"{class_name} op({out_args_str});") + + # Translate the input native arguments into structured + # arguments for the meta call + meta_exprs = ", ".join( + e.expr + for e in translate( + context, structured.meta_arguments(self.g), method=False + ) + ) + + if self.g.out.precomputed: + # If this function group has precomputed elements, the meta function + # returns a struct containing them which must be saved so that it + # can be unpacked when generating code to call the impl. + sig_body.append(f"auto precompute = op.meta({meta_exprs});") + + # Put all of the contents of the precompute struct into the context + # so that translate will be able to return the correct args for the + # call to the impl. 
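+                # Illustratively, a precomputed element named `dim` (hypothetical name)
+                # is surfaced to the impl call below as the expression `precompute.dim`.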
+ precomputed_values = [ + *self.g.out.precomputed.replace.values(), + self.g.out.precomputed.add, + ] + for precomputed_elems in precomputed_values: + for arg in precomputed_elems: + context.append( + Expr( + expr=f"precompute.{arg.name}", + type=structured.argument_type(arg, binds=arg.name), + ) + ) + + # Add a use of the precompute struct so FB internal compilers don't + # complain that there is an unused variable. + sig_body.append("(void)precompute;") + else: + sig_body.append(f"op.meta({meta_exprs});") + + # After running meta, op.outputs_ is guaranteed to be valid; + # add it to the context + out_args = structured.out_arguments(self.g) + for i, out_arg in enumerate(out_args): + assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type + + if k is SchemaKind.out: + expr = f"op.maybe_get_output({i})" + else: + expr = f"op.outputs_[{i}]" + + context.append( + Expr( + expr=expr, + # TODO: Stop hardcoding that the output type is a Tensor. Note + # that for the codegen here this is fine because outputs_ is + # hardcoded to be tensor already + type=NamedCType( + out_arg.nctype.name, MutRefCType(BaseCType(tensorT)) + ), + ) + ) + + # With the expanded context, do the impl call (if not a meta + # function) + if ( + self.backend_index.dispatch_key + == DispatchKey.CompositeExplicitAutogradNonFunctional + ): + # TODO: https://github.com/pytorch/pytorch/issues/53023 + out_sig_group = CppSignatureGroup.from_native_function( + self.g.out, method=False, fallback_binding=f.manual_cpp_binding + ) + out_sig = out_sig_group.most_faithful_signature() + api_name = out_sig.name() + out_exprs = ", ".join( + e.expr + for e in translate(context, out_sig.arguments(), method=False) + ) + # TODO: I think this means structured won't work with method + # only functions (but maybe you're saved by faithful? iunno.) + # NB: Originally I wrote this as an at::redispatch call, but + # I got in trouble because that meant I needed a DispatchKeySet + # in the wrapper function, which meant I needed a DispatchKeySet + # in the DispatchKeyFunctions declarations, but the defined API + # there does NOT permit a dispatch key set. I think you can + # probably unwind this by calling some function to do the TLS + # fetch and get the DispatchKeySet when you don't have it, but + # I didn't do it for this version + sig_body.append(f"at::{api_name}({out_exprs});") + elif self.backend_index.dispatch_key != DispatchKey.Meta: + impl_exprs = ", ".join( + e.expr + for e in translate( + context, structured.impl_arguments(self.g), method=False + ) + ) + sig_body.append(f"op.impl({impl_exprs});") + + # Go over each output, and check if there is a proxy created for it. + # If so, copy it over to the original output. 
+ if k is SchemaKind.out or k is SchemaKind.inplace: + for i in range(len(f.func.returns)): + sig_body.append( + f"if (op.proxy_outputs_[{i}].has_value()) op.outputs_[{i}].get().copy_(*op.proxy_outputs_[{i}]);" + ) + + # Destructively return the final tensors + # TODO: Do this in translate instead + if k is SchemaKind.functional: + if len(f.func.returns) == 1: + ret_expr = "std::move(op.outputs_[0])" # small optimization + else: + moved = ", ".join( + f"std::move(op.outputs_[{i}])" + for i in range(len(f.func.returns)) + ) + ret_expr = f"std::make_tuple({moved})" + elif k is SchemaKind.inplace: + ret_expr = "self" + elif k is SchemaKind.out: + if len(f.func.returns) == 1: + ret_expr = f.func.arguments.out[0].name + else: + refs = ", ".join(a.name for a in f.func.arguments.out) + ret_expr = f"std::forward_as_tuple({refs})" + sig_body.append(f"return {ret_expr};") # type: ignore[possibly-undefined] # TODO: audit + + sig_body_str = "\n".join(sig_body) + + # For an overview of what this template code looks like, see + # https://github.com/pytorch/rfcs/pull/9 + return f"""\ +{self.gen_class( +f, k, +class_name=class_name, +parent_class=parent_class, +generate_super=self.g.out.structured_inherits is not None +)} + +{sig.defn()} {{ +{sig_body_str} +}} +""" + + elif self.target is Target.REGISTRATION: + return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));' + else: + assert_never(self.target) + # Silence mypy's "Missing return statement" error + return None diff --git a/lib/python3.10/site-packages/torchgen/dest/ufunc.py b/lib/python3.10/site-packages/torchgen/dest/ufunc.py new file mode 100644 index 0000000000000000000000000000000000000000..073df2eb1849ba45fffad59d9afd98fa79ffceb9 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/dest/ufunc.py @@ -0,0 +1,551 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Sequence, TYPE_CHECKING + +import torchgen.api.ufunc as ufunc +from torchgen.api.translate import translate +from torchgen.api.types import ( + BaseCType, + Binding, + CType, + Expr, + NamedCType, + opmath_t, + scalar_t, + StructuredImplSignature, + VectorizedCType, +) +from torchgen.context import with_native_function +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + DispatchKey, + NativeFunctionsGroup, + ScalarType, + UfuncKey, +) +from torchgen.utils import OrderedSet + + +if TYPE_CHECKING: + from torchgen.api.ufunc import UfunctorBindings + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# CUDA STUFF +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + +# NB: not bothering to generate dispatch stub forward declaration in header, +# we can just paste it whereever necessary + +# TODO: use BackendIndex +# dispatch_key: DispatchKey # only CPU/CUDA right now + + +# Represents functors for implementing CUDA ufuncs. +# Functors are templated by scalar_t because when USERS instantiate functors +# they are templated. 
A functor looks something like this: +# +# template +# struct CUDAFunctorOnSelf_add { +# using opmath_t = at::opmath_type; +# opmath_t other_; +# opmath_t alpha_; +# CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) +# : other_(other), alpha_(alpha) {} +# __device__ scalar_t operator()(scalar_t self) { +# return ufunc::add(static_cast(self), other_, alpha_); +# } +# }; +# +@dataclass(frozen=True) +class UfunctorSignature: + g: NativeFunctionsGroup + scalar_tensor_idx: int | None + name: str + + def arguments(self) -> UfunctorBindings: + return ufunc.ufunctor_arguments( + self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t + ) + + def fields(self) -> list[Binding]: + # fields are renamed to have a trailing underscore, as is conventional + return [b.rename(f"{b.name}_") for b in self.arguments().ctor] + + def returns_type(self) -> CType: + # TODO: don't hardcode; return type will be inferred based on tags on + # the native function + return BaseCType(scalar_t) + + def decl_fields(self) -> str: + return "\n".join(f"{f.type} {f.name};" for f in self.fields()) + + def inline_defn_ctor(self) -> str: + args_str = ", ".join(a.decl() for a in self.arguments().ctor) + # NB: hypothetically could do this with translate but the + # transition here is very regular + init_str = ", ".join(f"{a.name}_({a.name})" for a in self.arguments().ctor) + return f"{self.name}({args_str}) : {init_str} {{}}" + + def decl_apply(self) -> str: + args_str = ", ".join(a.decl() for a in self.arguments().apply) + return f"{self.returns_type().cpp_type()} operator()({args_str}) const" + + +@dataclass(frozen=True) +class UfuncSignature: + g: NativeFunctionsGroup + name: str + compute_t: CType + + def arguments(self) -> list[Binding]: + return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t) + + def call(self, ctx: Sequence[Binding | Expr]) -> str: + return f"{self.name}({', '.join(a.expr for a in translate(ctx, self.arguments()))})" + + +# steps: +# 1. take the functional signature +# 2. use api.ufunc to convert it to template signature. this establishes +# the type of the template function +# 3. use api.ufunc (II) to generate a split struct / operator() signature. +# this establish context in which we call the template signature +# +# StructuredImplSignature context +# ~> functor constructor sig +# +# Functor constructor context +# ~> functor fields sig +# +# Functor apply context (functor fields + functor apply sig) +# ~> template sig +# + + +def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool: + num_tensors = sum( + 1 for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like() + ) + return num_tensors == 2 + + +def compute_ufunc_cuda_functors( + g: NativeFunctionsGroup, +) -> tuple[dict[ScalarType, dict[UfuncKey, UfunctorSignature]], str]: + # First, build the functors. 
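+    # The result is a per-dtype mapping of UfuncKey to the functor signature to use,
+    # e.g. (illustrative sketch only):
+    #   {ScalarType.Float: {UfuncKey.CUDAFunctor: UfunctorSignature(g, None, "CUDAFunctor_add")}}
+    # together with the C++ source of any functors generated in this function.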
+ ufunctor_sigs: dict[ScalarType, dict[UfuncKey, UfunctorSignature]] = {} + ufunctors: list[str] = [] + loops = g.out.ufunc_inner_loop + scalar_tensor_idx_lookup = { + UfuncKey.CUDAFunctorOnSelf: 1, + UfuncKey.CUDAFunctorOnOther: 0, + UfuncKey.CUDAFunctor: None, + } + if eligible_for_binary_scalar_specialization(g): + keys = [ + UfuncKey.CUDAFunctorOnSelf, + UfuncKey.CUDAFunctorOnOther, + UfuncKey.CUDAFunctor, + ] + else: + keys = [UfuncKey.CUDAFunctor] + for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]: + assert k not in loops, f"cannot use {k} on non-binary function" + for k in keys: + # If the key was directly defined, skip functor codegen; we assume the + # user already done it for us + if k in loops: + ufunctor_sig = UfunctorSignature( + g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name + ) + for dtype in loops[k].supported_dtypes: + ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig + continue + + # Note [ScalarOnly and Generic must match names for CUDA] + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Otherwise, look in ANY of the generic entries. For simplicity of + # codegen, both ScalarOnly and Generic are defined, the ufunc name + # must match (if they didn't match, we'd have to generate distinct + # functors per dtype, which is awful, so we're not going to do it unless + # someone really forces us to) + ufunc_name = None + supported_dtypes: OrderedSet[ScalarType] = OrderedSet() + for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]: + if lk not in loops: + continue + if ufunc_name is None: + ufunc_name = loops[lk].name + else: + # See Note [ScalarOnly and Generic must match names for CUDA] + assert ( + ufunc_name == loops[lk].name + ), "ScalarOnly and Generic must have same ufunc name" + supported_dtypes |= loops[lk].supported_dtypes + assert ufunc_name is not None + + name = f"{k}_{ufunc_name}" + ufunctor_sig = UfunctorSignature( + g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name + ) + for dtype in supported_dtypes: + ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig + + ufunc_sig = UfuncSignature( + g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t) + ) + apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply + ufunctors.append( + f""" +template +struct {ufunctor_sig.name} {{ + using opmath_t = at::opmath_type; + {ufunctor_sig.decl_fields()} + {ufunctor_sig.inline_defn_ctor()} + __device__ {ufunctor_sig.decl_apply()} {{ + return {ufunc_sig.call(apply_ctx)}; + }} +}}; +""" + ) + + return ufunctor_sigs, "\n".join(ufunctors) + + +@dataclass(frozen=True) +class BinaryScalarSpecializationConfig: + scalar_idx: int + ctor_tensor: str + ufunc_key: UfuncKey + + +BinaryScalarSpecializationConfigs = [ + BinaryScalarSpecializationConfig( + scalar_idx=0, + ctor_tensor="self", + ufunc_key=UfuncKey.CUDAFunctorOnOther, + ), + BinaryScalarSpecializationConfig( + scalar_idx=1, + ctor_tensor="other", + ufunc_key=UfuncKey.CUDAFunctorOnSelf, + ), +] + + +def compute_ufunc_cuda_dtype_body( + g: NativeFunctionsGroup, + dtype: ScalarType, + inner_loops: dict[UfuncKey, UfunctorSignature], + parent_ctx: Sequence[Binding], +) -> str: + body = "using opmath_t = at::opmath_type;" + body += "if (false) {}\n" # for ease of codegen + for config in BinaryScalarSpecializationConfigs: + if config.ufunc_key not in inner_loops: + continue + ufunctor_sig = inner_loops[config.ufunc_key] + scalar_idx = config.scalar_idx + 1 + # Make a copy and at the same time widen the type (not permissible + # without copy; we don't want to 
mutate the input argument anyway) + ctx: list[Expr | Binding] = list(parent_ctx) + ctx.append( + Expr( + expr=f"iter.scalar_value({scalar_idx})", + type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)), + ) + ) + ufunctor_ctor_exprs_str = ", ".join( + a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor) + ) + + # NB: ufunctor must be allocated before iter.remove_operand is called, + # as it relies on iter + body += f"""\ +else if (iter.is_cpu_scalar({scalar_idx})) {{ + {ufunctor_sig.name} ufunctor({ufunctor_ctor_exprs_str}); + iter.remove_operand({scalar_idx}); + gpu_kernel(iter, ufunctor); +}}""" + + ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor] + ufunctor_ctor_exprs_str = ", ".join( + a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor) + ) + body += f""" +else {{ + gpu_kernel(iter, {ufunctor_sig.name}({ufunctor_ctor_exprs_str})); +}} + """ + return body + + +@with_native_function +def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str: + # First, build the functors, indexing them by dtype + ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g) + + # Next, build the conditionals + sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA)) + dtype_cases = [] + for dtype, inner_ufunc_sigs in ufunctor_sigs.items(): + dtype_cases.append( + f""" +AT_DISPATCH_CASE(at::ScalarType::{dtype}, + [&]() {{ + {compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunc_sigs, sig.arguments())} + }} +) +""" + ) + + dtype_cases_str = "\n".join(dtype_cases) + + stub_sig = StubSignature(g) + + return f""" +{ufunctors} + +{stub_sig.type_defn()}; +{stub_sig.dispatch_decl()}; + +{stub_sig.kernel_defn()} {{ + AT_DISPATCH_SWITCH(iter.common_dtype(), "{sig.name}", + {dtype_cases_str} + ); +}} +REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name}); + +{sig.defn()} {{ + {stub_sig.direct_call(sig.arguments())}; +}} +""" + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# CPU STUFF +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +@dataclass(frozen=True) +class StubSignature: + g: NativeFunctionsGroup + + @property + def name(self) -> str: + return f"{str(self.g.functional.func.name.name)}_stub" + + @property + def kernel_name(self) -> str: + return f"{str(self.g.functional.func.name.name)}_kernel" + + @property + def type_name(self) -> str: + return f"{str(self.g.functional.func.name.name)}_fn" + + def arguments(self) -> list[Binding]: + return ufunc.stub_arguments(self.g) + + def type(self) -> str: + cpp_args = self.arguments() + return f"void(*)(TensorIteratorBase&, {', '.join(a.type for a in cpp_args)})" + + def dispatch_decl(self) -> str: + return f"DECLARE_DISPATCH({self.type_name}, {self.name})" + + def dispatch_defn(self) -> str: + return f"DEFINE_DISPATCH({self.name})" + + def kernel_defn(self) -> str: + return f"void {self.kernel_name}(TensorIteratorBase& iter, {', '.join(a.defn() for a in self.arguments())})" + + def type_defn(self) -> str: + return f"using {self.type_name} = {self.type()}" + + # must be called from context where this is TensorIteratorBase* + def call(self, ctx: Sequence[Binding]) -> str: + return f"{self.name}(device_type(), *this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})" + + # used in CUDA to skip the unnecessary dynamic dispatch + def direct_call(self, ctx: Sequence[Binding]) -> str: + return f"{self.kernel_name}(*this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})" + + +@with_native_function +def compute_ufunc_cpu(g: 
NativeFunctionsGroup) -> str: + stub_sig = StubSignature(g) + sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU)) + + return f""" +{stub_sig.type_defn()}; +{stub_sig.dispatch_decl()}; +{stub_sig.dispatch_defn()}; + +{sig.defn()} {{ + {stub_sig.call(sig.arguments())}; +}} +""" + + +def compute_ufunc_cpu_dtype_body( + g: NativeFunctionsGroup, + dtype: ScalarType, + inner_loops: dict[UfuncKey, UfuncSignature], + parent_ctx: Sequence[Binding], +) -> str: + assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}" + assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector} + scalar_loop = inner_loops[UfuncKey.CPUScalar] + vec_loop = None + if UfuncKey.CPUVector in inner_loops: + vec_loop = inner_loops[UfuncKey.CPUVector] + + # NB: We DON'T use translate here, because translate is + # incapable of CSE'ing the scalar accesses in case it is also + # used by Vectorized; also, the unpacking here is very simple + # and only affects Scalar; everything else is implicitly captured + # by the lambda + + # Setup scalar in scope + body = [] + ctx = [] + for b in parent_ctx: + if isinstance(b.argument, Argument) and b.argument.type != BaseType( + BaseTy.Scalar + ): + continue + body.append(f"auto _s_{b.name} = {b.name}.to();") + ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t)))) + if vec_loop is not None: + for b in parent_ctx: + if isinstance(b.argument, Argument) and b.argument.type != BaseType( + BaseTy.Scalar + ): + continue + body.append( + f"auto _v_{b.name} = at::vec::Vectorized(_s_{b.name});" + ) + ctx.append( + Expr( + f"_v_{b.name}", + NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))), + ) + ) + + # Setup lambda signature + # NB: simplified version of ufunctor_arguments + scalar_bindings = [] + vec_bindings = [] + for a in g.functional.func.arguments.flat_non_out: + if not a.type.is_tensor_like(): + continue + assert a.type == BaseType(BaseTy.Tensor) + scalar_bindings.append( + Binding( + name=a.name, + nctype=NamedCType(a.name, BaseCType(scalar_t)), + argument=a, + ) + ) + if vec_loop is not None: + vec_bindings.append( + Binding( + name=a.name, + nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))), + argument=a, + ) + ) + + def with_ctx(b: Sequence[Binding]) -> list[Expr | Binding]: + r: list[Expr | Binding] = [] + r.extend(ctx) + r.extend(b) + return r + + body_str = "\n".join(body) + if vec_loop is not None: + return f""" +{body_str} +cpu_kernel_vec(iter, + [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}, + [=]({', '.join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }} +); +""" + else: + return f""" +{body_str} +cpu_kernel(iter, + [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }} +); +""" + + +@with_native_function +def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str: + stub_sig = StubSignature(g) + + # Reindex the ufunc by dtypes; processing generic/scalaronly as well + loops = g.out.ufunc_inner_loop + ufunc_sigs: dict[ScalarType, dict[UfuncKey, UfuncSignature]] = {} + for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]: + lks = [] + # ORDER MATTERS: this specifies overriding precedence + if k in loops: # should happen rarely + lks.append(k) + if UfuncKey.ScalarOnly in loops and k is UfuncKey.CPUScalar: + lks.append(UfuncKey.ScalarOnly) + if UfuncKey.Generic in loops: + lks.append(UfuncKey.Generic) + # TODO: don't 
hardcode ufunc:: namespace here, should be centralized smh + for lk in lks: + for dtype in loops[lk].supported_dtypes: + compute_t: CType + if k is UfuncKey.CPUScalar: + compute_t = BaseCType(scalar_t) + elif k is UfuncKey.CPUVector: + compute_t = VectorizedCType(BaseCType(scalar_t)) + else: + raise AssertionError + inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {}) + if k not in inner_ufunc_sigs: + inner_ufunc_sigs[k] = UfuncSignature( + g, name=f"ufunc::{loops[lk].name}", compute_t=compute_t + ) + + # Build the conditionals + dtype_cases = [] + for dtype, inner_ufunc_sigs in ufunc_sigs.items(): + dtype_cases.append( + f""" +AT_DISPATCH_CASE(at::ScalarType::{dtype}, + [&]() {{ + {compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_sig.arguments())} + }} +) +""" + ) + + dtype_cases_str = "\n".join(dtype_cases) + return f""" +namespace {{ + +{stub_sig.kernel_defn()} {{ + AT_DISPATCH_SWITCH(iter.common_dtype(), "{stub_sig.name}", + {dtype_cases_str} + ); +}} + +}} // anonymous namespace + +{stub_sig.type_defn()}; +{stub_sig.dispatch_decl()}; +REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name}); +""" diff --git a/lib/python3.10/site-packages/torchgen/executorch/__init__.py b/lib/python3.10/site-packages/torchgen/executorch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/torchgen/executorch/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/executorch/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..455f387adcc00d68f5ee2759f66599d59a35f23d Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/executorch/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/executorch/__pycache__/model.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/executorch/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1b68f1879d08b57e2d1b05c9de30e9097ce219b Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/executorch/__pycache__/model.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/executorch/__pycache__/parse.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/executorch/__pycache__/parse.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54e4f41f9e6ca15b1a3e157ed27d0d4e130df771 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/executorch/__pycache__/parse.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/executorch/api/__init__.py b/lib/python3.10/site-packages/torchgen/executorch/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/torchgen/executorch/api/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/executorch/api/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7edd0478fc3e7bb651d98ba8eb9ddb8aef191c3 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/executorch/api/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/executorch/api/__pycache__/custom_ops.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/executorch/api/__pycache__/custom_ops.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..8342e59bcd31f45b0ea547a4a27a54d3601bff39 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/executorch/api/__pycache__/custom_ops.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/executorch/api/__pycache__/et_cpp.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/executorch/api/__pycache__/et_cpp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b72d9c38b3c230cfafb6eec54932dea0dd9b8796 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/executorch/api/__pycache__/et_cpp.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/executorch/api/__pycache__/unboxing.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/executorch/api/__pycache__/unboxing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7478dda10a1793f4e3d4c5e3a0da1c19486de9d Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/executorch/api/__pycache__/unboxing.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/executorch/api/custom_ops.py b/lib/python3.10/site-packages/torchgen/executorch/api/custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..bbe62c72f6882ee7d8595d1c36da39b406137bee --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/executorch/api/custom_ops.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +from collections import defaultdict +from dataclasses import dataclass +from typing import Sequence, TYPE_CHECKING + +from torchgen import dest + + +# disable import sorting to avoid circular dependency. +from torchgen.api.types import DispatcherSignature # usort: skip +from torchgen.context import method_with_native_function +from torchgen.model import BaseTy, BaseType, DispatchKey, NativeFunction, Variant +from torchgen.utils import concatMap, Target + + +if TYPE_CHECKING: + from torchgen.executorch.model import ETKernelIndex + from torchgen.selective_build.selector import SelectiveBuilder + + +# Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom operators. This will be used at +# model authoring side. 
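+# Illustratively (a sketch for a hypothetical custom op
+# `custom::foo(Tensor self) -> Tensor`), the emitted placeholder looks roughly like:
+#
+#   at::Tensor wrapper_CPU__foo(const at::Tensor & self) {
+#     return self;
+#   }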
+@dataclass(frozen=True) +class ComputeNativeFunctionStub: + @method_with_native_function + def __call__(self, f: NativeFunction) -> str | None: + if Variant.function not in f.variants: + return None + + sig = DispatcherSignature.from_schema( + f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False + ) + assert sig is not None + if len(f.func.returns) == 0: + ret_name = "" + elif len(f.func.returns) == 1: + if f.func.arguments.out: + ret_name = f.func.arguments.out[0].name + else: + ret_name = next( + ( + a.name + for a in f.func.arguments.flat_non_out + if a.type == f.func.returns[0].type + ), + "", + ) + if not ret_name: + # if return type is tensor + if f.func.returns[0].type == BaseType(BaseTy.Tensor): + # Returns an empty tensor + ret_name = "at::Tensor()" + else: + raise Exception( # noqa: TRY002 + f"Can't handle this return type {f.func}" + ) # noqa: TRY002 + elif len(f.func.arguments.out) == len(f.func.returns): + # Returns a tuple of out arguments + tensor_type = "at::Tensor &" + comma = ", " + ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>( + {comma.join([r.name for r in f.func.arguments.out])} + )""" + else: + assert all( + a.type == BaseType(BaseTy.Tensor) for a in f.func.returns + ), f"Only support tensor returns but got {f.func.returns}" + # Returns a tuple of empty tensors + tensor_type = "at::Tensor" + comma = ", " + ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>( + {comma.join(["at::Tensor()" for _ in f.func.returns])} + )""" + ret_str = f"return {ret_name};" if len(f.func.returns) > 0 else "" + return f""" +{sig.defn()} {{ + {ret_str} +}} + """ + + +def gen_custom_ops_registration( + *, + native_functions: Sequence[NativeFunction], + selector: SelectiveBuilder, + kernel_index: ETKernelIndex, + rocm: bool, +) -> tuple[str, str]: + """ + Generate custom ops registration code for dest.RegisterDispatchKey. + + :param native_functions: a sequence of `NativeFunction` + :param selector: for selective build. + :param kernel_index: kernels for all the ops. + :param rocm: bool for dest.RegisterDispatchKey. + :return: generated C++ code to register custom operators into PyTorch + """ + + # convert kernel index to BackendIndex. This is because we can't handle ETKernelIndex yet. + # TODO larryliu: evaluate if this code is still needed. If yes let it handle ETKernelIndex. 
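+    # Each namespace seen in `native_functions` gets its own TORCH_LIBRARY_IMPL
+    # block in `static_init_dispatch_registrations`; the kernel wrapper bodies
+    # themselves go into `anonymous_definition` (see the return value below).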
+ + dispatch_key = DispatchKey.CPU + backend_index = kernel_index._to_backend_index() + static_init_dispatch_registrations = "" + ns_grouped_native_functions: dict[str, list[NativeFunction]] = defaultdict(list) + for native_function in native_functions: + ns_grouped_native_functions[native_function.namespace].append(native_function) + + for namespace, functions in ns_grouped_native_functions.items(): + if len(functions) == 0: + continue + dispatch_registrations_body = "\n".join( + list( + concatMap( + dest.RegisterDispatchKey( + backend_index, + Target.REGISTRATION, + selector, + rocm=rocm, + symint=False, + class_method_name=None, + skip_dispatcher_op_registration=False, + ), + functions, + ) + ) + ) + static_init_dispatch_registrations += f""" +TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{ +{dispatch_registrations_body} +}};""" + anonymous_definition = "\n".join( + list( + concatMap( + dest.RegisterDispatchKey( + backend_index, + Target.ANONYMOUS_DEFINITION, + selector, + rocm=rocm, + symint=False, + class_method_name=None, + skip_dispatcher_op_registration=False, + ), + native_functions, + ) + ) + ) + return anonymous_definition, static_init_dispatch_registrations diff --git a/lib/python3.10/site-packages/torchgen/executorch/api/et_cpp.py b/lib/python3.10/site-packages/torchgen/executorch/api/et_cpp.py new file mode 100644 index 0000000000000000000000000000000000000000..76cebcd0f0f1dca2adf65d3f8029183d87cba85d --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/executorch/api/et_cpp.py @@ -0,0 +1,370 @@ +from __future__ import annotations + +from typing import Sequence + +from torchgen import local +from torchgen.api.types import ( + ArgName, + BaseCType, + Binding, + ConstRefCType, + CType, + MutRefCType, + NamedCType, + SpecialArgName, + TupleCType, + VectorCType, + voidT, +) +from torchgen.executorch.api.types import ( + ArrayRefCType, + BaseTypeToCppMapping, + OptionalCType, + scalarT, + tensorListT, + tensorT, +) +from torchgen.model import ( + Argument, + Arguments, + BaseTy, + BaseType, + ListType, + NativeFunction, + OptionalType, + Return, + SelfArgument, + TensorOptionsArguments, + Type, +) +from torchgen.utils import assert_never + + +""" +This file describes the translation of JIT schema to the public C++ API, which is what people use when they call +functions like at::add. It also serves as a native function API, which is the signature of kernels, +since in Executorch CppSignature is the same as NativeSignature. + +Difference between this file and torchgen.api.cpp.py: + + - Executorch doesn't support TensorOptions, however in this file we still keep the logic here to be compatible with + torchgen.api.cpp, so that we can do stuff like ATen mode (running ATen kernels in Executorch). + + - Executorch doesn't support Dimname. + + - Executorch runtime doesn't support SymInt, will treat it as int. +""" + + +# Translation of "value types" in JIT schema to C++ API type. Value +# types look the same no matter if they are argument types or return +# types. Returns None if the type in question is not a value type. +def valuetype_type( + t: Type, + *, + binds: ArgName, + remove_non_owning_ref_types: bool = False, +) -> NamedCType | None: + if isinstance(t, BaseType): + if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar: + return None + # For SymInt we simply treat it as int. 
+ elif str(t) == "SymInt": + return NamedCType(binds, BaseCType(BaseTypeToCppMapping[BaseTy.int])) + if remove_non_owning_ref_types: + if t.name == BaseTy.str: + raise AssertionError( + "string ref->value conversion: not implemented yet" + ) + # All other BaseType currently map directly to BaseCppTypes. + return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name])) + elif isinstance(t, OptionalType): + elem = valuetype_type(t.elem, binds=binds) + if elem is None: + return None + return NamedCType(binds, OptionalCType(elem.type)) + elif isinstance(t, ListType): + if str(t.elem) == "bool": + assert t.size is not None + return NamedCType( + binds, ArrayRefCType(BaseCType(BaseTypeToCppMapping[BaseTy.bool])) + ) + else: + return None + else: + raise AssertionError(f"unrecognized type {repr(t)}") + + +# Translation of types occurring in JIT arguments to a C++ argument type. +# If remove_non_owning_ref_types is set, we'll guarantee that the outputed CType is not a non-owning reference type. +# For example, we'll return std::vector instead of IntArrayRef. +# See Note [translation from C++ reference to value types] +def argumenttype_type( + t: Type, + *, + mutable: bool, + binds: ArgName, + remove_non_owning_ref_types: bool = False, +) -> NamedCType: + # If it's a value type, do the value type translation + r = valuetype_type( + t, + binds=binds, + remove_non_owning_ref_types=remove_non_owning_ref_types, + ) + if r is not None: + return r + if isinstance(t, BaseType): + if t.name == BaseTy.Tensor: + if mutable and not local.use_const_ref_for_mutable_tensors(): + return NamedCType(binds, MutRefCType(BaseCType(tensorT))) + else: + return NamedCType(binds, ConstRefCType(BaseCType(tensorT))) + elif t.name == BaseTy.Scalar: + return NamedCType(binds, ConstRefCType(BaseCType(scalarT))) + else: + raise AssertionError(f"base type should have been value type {t}") + elif isinstance(t, OptionalType): + if str(t.elem) == "Tensor": + if mutable and not local.use_const_ref_for_mutable_tensors(): + return NamedCType( + binds, MutRefCType(BaseCType(tensorT)) + ) # TODO: fix this discrepancy + else: + return NamedCType( + binds, ConstRefCType(OptionalCType(BaseCType(tensorT))) + ) + elif str(t.elem) == "Scalar": + return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT)))) + elem = argumenttype_type(t.elem, mutable=mutable, binds=binds) + return NamedCType(binds, OptionalCType(elem.type)) + elif isinstance(t, ListType): + # TODO: keeping these special cases for Tensor[] and Tensor?[] so that we can hookup with ATen kernels. + if str(t.elem) == "Tensor": + return NamedCType(binds, BaseCType(tensorListT)) + elif str(t.elem) == "Dimname": + raise NotImplementedError("Executorch doesn't support Dimname") + elif str(t.elem) == "Tensor?": + return NamedCType(binds, ArrayRefCType(OptionalCType(BaseCType(tensorT)))) + elem = argumenttype_type(t.elem, mutable=mutable, binds=binds) + return NamedCType(binds, ArrayRefCType(elem.type)) + else: + raise AssertionError(f"unrecognized type {repr(t)}") + + +# Translate a JIT argument into its C++ type +def argument_type(a: Argument, *, binds: ArgName) -> NamedCType: + return argumenttype_type(a.type, mutable=a.is_write, binds=binds) + + +# Translation of a (non-multi) return type from JIT to C++ +# N.B: returntype_type returns a CType, not a NamedCType. +# This is mostly because of the mismatch between return types and return names. +# e.g. 
a function with a return type of 'void' has 0 return names, +# and a function with a return type of 'std::tuple' has >1 return name. +def returntype_type(t: Type, *, mutable: bool) -> CType: + # placeholder is ignored + r = valuetype_type(t, binds="__placeholder__") + if r is not None: + return r.type + + if isinstance(t, BaseType): + if t.name == BaseTy.Tensor: + if mutable: + if local.use_const_ref_for_mutable_tensors(): + return ConstRefCType(BaseCType(tensorT)) + else: + return MutRefCType(BaseCType(tensorT)) + else: + # Note [Tensor Copy Returns] + # Currently, we use "Argument.is_write" to determine + # whether or not Tensor return types should be copies or references. + # If that ever changes, take a look at other locations of this note! + return BaseCType(tensorT) + elif t.name == BaseTy.Scalar: + return BaseCType(scalarT) + elif isinstance(t, ListType): + assert ( + not mutable + ), "Native functions should never return a mutable tensor list. They should return void." + elem = returntype_type(t.elem, mutable=False) + assert t.size is None, f"fixed size list returns not supported: {t}" + return VectorCType(elem) + + raise AssertionError(f"unrecognized return type {t}") + + +# Translation of a single return to its C++ type +def return_type(r: Return) -> CType: + return returntype_type(r.type, mutable=r.is_write) + + +# Translation of a full (possibly multi) return from JIT to its C++ type +def returns_type(rs: Sequence[Return]) -> CType: + if len(rs) == 0: + return BaseCType(voidT) + elif len(rs) == 1: + return return_type(rs[0]) + else: + return TupleCType([return_type(r) for r in rs]) + + +def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]: + returns: list[str] = [] + for i, r in enumerate(f.func.returns): + # If we have an inplace function, the return argument is + # implicitly named self. + # TODO: Consider incorporating this into the data model + if f.func.name.name.inplace: + assert i == 0, "illegal inplace function with multiple returns" + name = "self" + # If we are out function, the name is the name of the + # corresponding output function (r.name will get recorded + # in field_name later.) + elif f.func.is_out_fn(): + name = f.func.arguments.out[i].name + # If the return argument is explicitly named... 
+ elif r.name: + name_conflict = any( + r.name == a.name for a in f.func.schema_order_arguments() + ) + if name_conflict and not f.func.is_out_fn(): + name = f"{r.name}_return" + else: + name = r.name + # If there is no explicit name and no fallback name was passed in, we just name the output result, + # unless it's a multi-return, in which case it's result0, + # result1, etc (zero-indexed) + else: + name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}" + returns.append(name) + return returns + + +JIT_TO_CPP_DEFAULT = { + "False": "false", + "True": "true", + "None": "torch::executorch::nullopt", # UGH this one is type directed + "[]": "{}", + "contiguous_format": "torch::executorch::MemoryFormat::Contiguous", + "long": "torch::executorch::kLong", +} + + +# Convert a JIT default into C++ expression representing the default +def default_expr(d: str, t: Type) -> str: + if d == "None" and str(t) == "Tensor?": + return "{}" + if isinstance(t, BaseType) and t.name is BaseTy.str: + # Schema allows single quotes but C++ needs double + if len(d) >= 2 and d[0] == "'" and d[-1] == "'": + s = "" + i = 1 + while i + 1 < len(d): + if d[i] != "\\": + if d[i] == '"': + s += '\\"' + else: + s += d[i] + i += 1 + else: + if d[i + 1] == "'": + s += "'" + else: + s += d[i : i + 2] + i += 2 + + return f'"{s}"' + + if isinstance(t, OptionalType): + if d == "None": + return "torch::executor::nullopt" + + return default_expr(d, t.elem) + + if isinstance(t, ListType): + if d.startswith("[") and d.endswith("]"): + return "{" + d[1:-1] + "}" + elif t.size is None: + # NOTE: Sized lists can have scalar defaults + raise ValueError(f"Expected a list default '[...]' but found: '{d}'") + + return JIT_TO_CPP_DEFAULT.get(d, d) + + +# Convert an argument into its C++ API form + + +def argument( + a: Argument | TensorOptionsArguments | SelfArgument, + *, + cpp_no_default_args: set[str], + method: bool, + faithful: bool, + has_tensor_options: bool, +) -> list[Binding]: + def sub_argument( + a: Argument | TensorOptionsArguments | SelfArgument, + ) -> list[Binding]: + return argument( + a, + cpp_no_default_args=cpp_no_default_args, + method=method, + faithful=faithful, + has_tensor_options=has_tensor_options, + ) + + if isinstance(a, Argument): + binds: ArgName + if a.name == "memory_format" and has_tensor_options: + binds = SpecialArgName.possibly_redundant_memory_format + else: + binds = a.name + default: str | None = None + if a.name not in cpp_no_default_args and a.default is not None: + default = default_expr(a.default, a.type) + return [ + Binding( + nctype=argument_type(a, binds=binds), + name=a.name, + default=default, + argument=a, + ) + ] + elif isinstance(a, TensorOptionsArguments): + raise NotImplementedError("Need to implement type resolution for TensorOptions") + elif isinstance(a, SelfArgument): + if method: + # Caller is responsible for installing implicit this in context! 
+ return [] + else: + return sub_argument(a.argument) + else: + assert_never(a) + + +def arguments( + arguments: Arguments, + *, + faithful: bool, + method: bool, + cpp_no_default_args: set[str], +) -> list[Binding]: + args: list[Argument | TensorOptionsArguments | SelfArgument] = [] + if faithful: + args.extend(arguments.non_out) + args.extend(arguments.out) + else: + args.extend(arguments.out) + args.extend(arguments.non_out) + return [ + r.no_default() if faithful else r + for a in args + for r in argument( + a, + faithful=faithful, + method=method, + has_tensor_options=arguments.tensor_options is not None, + cpp_no_default_args=cpp_no_default_args, + ) + ] diff --git a/lib/python3.10/site-packages/torchgen/executorch/api/types/__init__.py b/lib/python3.10/site-packages/torchgen/executorch/api/types/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..08cb168df737163ae8b4189a421869c0c7718d7e --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/executorch/api/types/__init__.py @@ -0,0 +1,4 @@ +from torchgen.executorch.api.types.types import * + + +from torchgen.executorch.api.types.signatures import * # usort: skip diff --git a/lib/python3.10/site-packages/torchgen/executorch/api/types/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/executorch/api/types/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1caa4709977470109497f422ef58af21710f4de3 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/executorch/api/types/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/executorch/api/types/__pycache__/signatures.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/executorch/api/types/__pycache__/signatures.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..066a0d37535f687ed17d960b69a7e3c6ad61dced Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/executorch/api/types/__pycache__/signatures.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/executorch/api/types/__pycache__/types.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/executorch/api/types/__pycache__/types.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da5a69e88c7cbdb760113fb6ee2904085b88d87c Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/executorch/api/types/__pycache__/types.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/executorch/api/types/signatures.py b/lib/python3.10/site-packages/torchgen/executorch/api/types/signatures.py new file mode 100644 index 0000000000000000000000000000000000000000..ac3477cede6ed03114508c051055258761e6ee59 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/executorch/api/types/signatures.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import torchgen.api.cpp as aten_cpp +from torchgen.executorch.api.types.types import contextArg + + +if TYPE_CHECKING: + from torchgen.api.types import Binding, CType + from torchgen.model import FunctionSchema, NativeFunction + + +@dataclass(frozen=True) +class ExecutorchCppSignature: + """ + This signature is merely a CppSignature with Executorch types (optionally + contains KernelRuntimeContext as well). The inline definition of + CppSignature is generated in Functions.h and it's used by unboxing + functions. 
+ """ + + # The schema this signature is derived from + func: FunctionSchema + + # The set of C++ arguments which should not have defaults applied to them + cpp_no_default_args: set[str] + + # Allows you to prepend an arbitrary prefix to the signature name. + # This is useful for parts of the codegen that generate wrappers around kernels, + # and need to avoid naming collisions. + prefix: str = "" + + def arguments(self, *, include_context: bool = True) -> list[Binding]: + return ([contextArg] if include_context else []) + et_cpp.arguments( + self.func.arguments, + faithful=True, # always faithful, out argument at the end + method=False, # method not supported + cpp_no_default_args=self.cpp_no_default_args, + ) + + def name(self) -> str: + return self.prefix + aten_cpp.name( + self.func, + faithful_name_for_out_overloads=True, + ) + + def decl(self, name: str | None = None, *, include_context: bool = True) -> str: + args_str = ", ".join( + a.decl() for a in self.arguments(include_context=include_context) + ) + if name is None: + name = self.name() + return f"{self.returns_type().cpp_type()} {name}({args_str})" + + def defn(self, name: str | None = None) -> str: + args = [a.defn() for a in self.arguments()] + args_str = ", ".join(args) + if name is None: + name = self.name() + return f"{self.returns_type().cpp_type()} {name}({args_str})" + + def returns_type(self) -> CType: + return et_cpp.returns_type(self.func.returns) + + @staticmethod + def from_native_function( + f: NativeFunction, *, prefix: str = "" + ) -> ExecutorchCppSignature: + return ExecutorchCppSignature( + func=f.func, prefix=prefix, cpp_no_default_args=f.cpp_no_default_args + ) + + +from torchgen.executorch.api import et_cpp diff --git a/lib/python3.10/site-packages/torchgen/executorch/api/types/types.py b/lib/python3.10/site-packages/torchgen/executorch/api/types/types.py new file mode 100644 index 0000000000000000000000000000000000000000..b3a960a8246b98116cf1687cd9cba71132598b78 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/executorch/api/types/types.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from torchgen.api.types import ( + BaseCppType, + BaseCType, + Binding, + boolT, + CType, + doubleT, + Expr, + longT, + MutRefCType, + NamedCType, +) +from torchgen.model import BaseTy + + +halfT = BaseCppType("torch::executor", "Half") +bfloat16T = BaseCppType("torch::executor", "BFloat16") +stringT = BaseCppType("torch::executor", "string_view") +scalarTypeT = BaseCppType("torch::executor", "ScalarType") +tensorT = BaseCppType("torch::executor", "Tensor") +tensorListT = BaseCppType("torch::executor", "TensorList") +scalarT = BaseCppType("torch::executor", "Scalar") +memoryFormatT = BaseCppType("torch::executor", "MemoryFormat") +intArrayRefT = BaseCppType("torch::executor", "IntArrayRef") +optionalT = BaseCppType("torch::executor", "optional") +contextT = BaseCppType("torch::executor", "KernelRuntimeContext") + +contextExpr = Expr( + expr="context", + type=NamedCType(name="context", type=MutRefCType(BaseCType(contextT))), +) + +contextArg = Binding( + name="context", + nctype=contextExpr.type, + argument=None, # type: ignore[arg-type] + default=None, +) + +BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = { + BaseTy.int: longT, + BaseTy.float: doubleT, + BaseTy.bool: boolT, + BaseTy.str: stringT, + BaseTy.ScalarType: scalarTypeT, + BaseTy.Tensor: tensorT, + BaseTy.Scalar: scalarT, + BaseTy.MemoryFormat: memoryFormatT, +} + + +@dataclass(frozen=True) +class 
OptionalCType(CType): + elem: CType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + # Do not pass `strip_ref` recursively. + return f"torch::executor::optional<{self.elem.cpp_type()}>" + + def cpp_type_registration_declarations(self) -> str: + return f"torch::executor::optional<{self.elem.cpp_type_registration_declarations()}>" + + def remove_const_ref(self) -> CType: + return OptionalCType(self.elem.remove_const_ref()) + + +@dataclass(frozen=True) +class ArrayRefCType(CType): + elem: CType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + # Do not pass `strip_ref` recursively. + return f"torch::executor::ArrayRef<{self.elem.cpp_type()}>" + + def cpp_type_registration_declarations(self) -> str: + return f"torch::executor::ArrayRef<{self.elem.cpp_type_registration_declarations()}>" + + def remove_const_ref(self) -> CType: + return ArrayRefCType(self.elem.remove_const_ref()) diff --git a/lib/python3.10/site-packages/torchgen/executorch/api/unboxing.py b/lib/python3.10/site-packages/torchgen/executorch/api/unboxing.py new file mode 100644 index 0000000000000000000000000000000000000000..6845e72a22a5d884bd98db7739b0654f1ef88c10 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/executorch/api/unboxing.py @@ -0,0 +1,230 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Sequence, TYPE_CHECKING + +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + ListType, + NativeFunction, + OptionalType, + Type, +) + + +if TYPE_CHECKING: + from torchgen.api.types import Binding, CType, NamedCType + + +connector = "\n\t" + + +# Return unboxing function name for a NativeFunction +def name(f: NativeFunction) -> str: + return f.func.name.unambiguous_name() + + +@dataclass(frozen=True) +class Unboxing: + """ + Takes a sequence of Bindings and unbox EValues to these Bindings. Return generated code that performs correct unboxing. + A sample generated code: + // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + void mul_out(EValue** stack) { + EValue& self = *stack[0]; + EValue& other = *stack[1]; + EValue& out = *stack[2]; + const torch::executor::Tensor & self_base = self.to(); + const torch::executor::Tensor & other_base = other.to(); + torch::executor::Tensor & out_base = out.to(); + + EXECUTORCH_SCOPE_PROF("native_call_mul.out"); + torch::executor::mul_outf(self_base, other_base, out_base); + + + } + """ + + # this is a callable that converts a JIT argument, into its C++ type. + # Translates (type, mutability, binds) to NamedCType. E.g., torchgen.api.cpp.argumenttype_type. 
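+    # A hypothetical instantiation (sketch): Unboxing(argument_type_gen=argumenttype_type);
+    # its convert_arguments() then unboxes each EValue on the stack into the
+    # corresponding Binding.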
+ argument_type_gen: Callable[ + ..., + NamedCType, + ] + + # Convert all the arguments in a NativeFunction to C++ code + def convert_arguments( + self, args: Sequence[Binding] + ) -> tuple[list[Binding], list[str]]: + code_list = [f"EValue& {args[i].name} = *stack[{i}];" for i in range(len(args))] + binding_list = [] + for arg in args: + # expecting only Argument + if not isinstance(arg.argument, Argument): + raise Exception( # noqa: TRY002 + f"Unexpected argument type, expecting `Argument` but got {arg}" + ) + argument: Argument = arg.argument + unboxed_name, _, code, decl = self.argumenttype_evalue_convert( + argument.type, argument.name, mutable=argument.is_write + ) + code_list.extend(decl) + code_list.extend(code) + binding_list.append(arg.with_name(unboxed_name)) + return binding_list, code_list + + def argumenttype_evalue_convert( + self, t: Type, arg_name: str, *, mutable: bool = False + ) -> tuple[str, CType, list[str], list[str]]: + """ + Takes in the type, name and mutability corresponding to an argument, and generates a tuple of: + (1) the C++ code necessary to unbox the argument + (2) A Binding corresponding to the newly created unboxed variable, including variable name and its CType + :param t: a `Type` of an argument + :param arg_name: argument name + :param mutable: boolean for whether this argument type is mutable + :return: unboxed result + """ + ctype = self.argument_type_gen(t, mutable=mutable, binds=arg_name).type + + if isinstance(t, BaseType): + out_name = f"{arg_name}_base" + code, decl = self._gen_code_base_type( + arg_name=arg_name, out_name=out_name, ctype=ctype + ) + elif isinstance(t, OptionalType): + out_name = f"{arg_name}_opt_out" + code, decl = self._gen_code_optional_type( + arg_name=arg_name, out_name=out_name, t=t, ctype=ctype + ) + elif isinstance(t, ListType): + out_name = f"{arg_name}_list_out" + code, decl = self._gen_code_list_type( + arg_name=arg_name, out_name=out_name, t=t, ctype=ctype + ) + else: + raise Exception( # noqa: TRY002 + f"Cannot handle type {t}. 
arg_name: {arg_name}" + ) # noqa: TRY002 + return out_name, ctype, code, decl + + def _gen_code_base_type( + self, arg_name: str, out_name: str, ctype: CType + ) -> tuple[list[str], list[str]]: + return [ + f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();" + ], [] + + def _gen_code_optional_type( + self, arg_name: str, out_name: str, t: OptionalType, ctype: CType + ) -> tuple[list[str], list[str]]: + in_name = f"{arg_name}_opt_in" + res_name, base_type, res_code, decl = self.argumenttype_evalue_convert( + t.elem, in_name + ) + return ( + f""" + auto {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>(); + """.split( + "\n" + ), + decl, + ) + + def _gen_code_list_type( + self, arg_name: str, out_name: str, t: ListType, ctype: CType + ) -> tuple[list[str], list[str]]: + in_name = f"{arg_name}_list_in" + elem_name = f"{arg_name}_elem" + code = [] + res_name, res_ctype, res_code, decl = self.argumenttype_evalue_convert( + t.elem, elem_name + ) + + if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.Tensor: + code.extend( + f""" + auto {out_name} = {arg_name}.toTensorList(); + """.split( + "\n" + ) + ) + elif isinstance(t.elem, BaseType) and ( + t.elem.name == BaseTy.int or t.elem.name == BaseTy.SymInt + ): + code.extend( + f""" + auto {out_name} = {arg_name}.toIntList(); + """.split( + "\n" + ) + ) + elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.float: + code.extend( + f""" + auto {out_name} = {arg_name}.toDoubleList(); + """.split( + "\n" + ) + ) + elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool: + # handle list type with size, e.g., bool[4] + code.extend( + f""" +#ifdef USE_ATEN_LIB +std::array {out_name}; +auto {in_name} = {arg_name}.toBoolList(); +size_t _i = 0; +for (auto {elem_name}: {in_name}) {{ + {out_name}[_i++] = {elem_name}; +}} +#else +auto {out_name} = {arg_name}.toBoolList(); +#endif + """.split( + "\n" + ) + ) + # pytorch codegen: + # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional> + elif ( + isinstance(t.elem, OptionalType) + and isinstance(t.elem.elem, BaseType) + and t.elem.elem.name == BaseTy.Tensor + ): + code.extend( + f""" +#ifdef USE_ATEN_LIB +auto {in_name} = {arg_name}.toListOptionalTensor(); +c10::List<::std::optional> {out_name}; +for (auto {elem_name}: {in_name}) {{ + {out_name}.push_back({elem_name}); +}} +#else +auto {out_name} = {arg_name}.toListOptionalTensor(); +#endif + """.split( + "\n" + ) + ) + else: + # use ArrayRef as default. + vec_name = arg_name + "_vec" + # need to bring vector instantiation out of scope so that ArrayRef has valid data + decl.append( + f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};" + ) + code.extend( + f""" + for (EValue {elem_name}: {in_name}) {{ + {connector.join(res_code)} + {vec_name}.push_back({res_name}); + }} + {ctype.cpp_type(strip_ref=True)} {out_name}({vec_name}); + """.split( + "\n" + ) + ) + return code, decl diff --git a/lib/python3.10/site-packages/torchgen/executorch/model.py b/lib/python3.10/site-packages/torchgen/executorch/model.py new file mode 100644 index 0000000000000000000000000000000000000000..6aadfe41daed2e5d124736273cc6ec820aee6ee6 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/executorch/model.py @@ -0,0 +1,220 @@ +# Represents all kernels used by an Executorch model. +# It maintains a Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] structure. 
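+# Illustratively (sketch), one entry of that mapping might look like:
+#   aten::add.out -> { ETKernelKey(default=True): BackendMetadata(kernel="add_out", ...) }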
+ +from __future__ import annotations + +import itertools +from collections import defaultdict, namedtuple +from dataclasses import dataclass +from enum import IntEnum + +from torchgen.model import ( + BackendIndex, + BackendMetadata, + DispatchKey, + NativeFunction, + NativeFunctionsGroup, + OperatorName, +) +from torchgen.utils import assert_never + + +KERNEL_KEY_VERSION = 1 + + +# TODO: Duplicated Subset from codegen.tool.gen_oplist, remove declaration in codegen +class ScalarType(IntEnum): + Byte = 0 + Char = 1 + Short = 2 + Int = 3 + Long = 4 + Float = 6 + Double = 7 + Bool = 11 + + +ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "kernel_index"]) + + +@dataclass(frozen=True) +class ETKernelKeyOpArgMeta: + arg_name: str + dtype: str + # The order of the dimensions if entry is a Tensor + dim_order: tuple[int, ...] + + def to_native_string(self) -> str: + dtype_str = ScalarType[self.dtype].value + dim_str = str(self.dim_order)[1:-1].replace(" ", "") + return f"{dtype_str};{dim_str}" + + +@dataclass(frozen=True) +class ETKernelKey: + # Field undefined is default = True + arg_meta: tuple[ETKernelKeyOpArgMeta, ...] = () + + # Indicator for this kernel being used as a catch all + default: bool = False + + version: int = KERNEL_KEY_VERSION + + @staticmethod + def gen_from_yaml( + args: dict[str, tuple[str, str]], + type_alias_map: dict[str, list[str]], # TODO: Support unwrapped str val + dim_order_alias_map: dict[str, list[int]], + ) -> list[ETKernelKey]: + """Generate ETKernelKeys from arg kernel specs + Multiple ETKernelKeys are returned due to dtype permutations from utilizing + type_alias_map (actualizing each potential type permutation as a KernelKey) + + Args: + args: Mapping from argument name to kernel specs + Kernel specs are a tuple of (dtype, dim_order). 
+ Currently tuple entries must be aliased via the alias map arguments + type_alias_map: Mapping from type alias to potential type enums + i.e { T0 : [Double, Int] } means T0 can be either Double or Int + Used for lookup by args + dim_order_alias_map: Mapping from alias to a list of dimension orders + Used for lookup by args + """ + # Cast to dim order to int + dim_order_alias_map = { + k: [int(alias) for alias in v] for k, v in dim_order_alias_map.items() + } + kernel_keys = [] + + # Get all used Dtype Alias + dtype_alias_used = set() + for type_alias, dim_order in args.values(): + # Enforce usage of alias initially + # TODO: Support inlined arguments + assert type_alias in type_alias_map, "Undefined type alias: " + str( + type_alias + ) + assert ( + dim_order in dim_order_alias_map + ), "Undefined dim_order alias: " + str(dim_order) + dtype_alias_used.add(type_alias) + + # Generate all permutations of dtype alias values + alias_dtypes = [ + [(alias, dtype) for dtype in type_alias_map[alias]] + for alias in dtype_alias_used + ] + alias_permutations = [ + dict(permutation) for permutation in list(itertools.product(*alias_dtypes)) + ] + + # Using each alias value permutation, generate kernel keys + op_arg_cache = {} + for permutation in alias_permutations: + arg_list = [] + for arg_name, arg_spec in args.items(): + dtype = permutation[arg_spec[0]] + dim_order = dim_order_alias_map[arg_spec[1]] # type: ignore[assignment] + if ( + cache_key := (arg_name, dtype, tuple(dim_order)) + ) not in op_arg_cache: + op_arg_cache[cache_key] = ETKernelKeyOpArgMeta(*cache_key) # type: ignore[arg-type] + + arg_list.append(op_arg_cache[cache_key]) + kernel_keys.append(ETKernelKey(tuple(arg_list))) + + return kernel_keys + + def to_native_string(self) -> str: + if self.default: + return "default" + return ( + "v" + + str(KERNEL_KEY_VERSION) + + "/" + + "|".join([arg.to_native_string() for arg in self.arg_meta]) + ) + + +@dataclass(frozen=True) +class ETKernelIndex: + index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] + + def has_kernels(self, g: NativeFunction | NativeFunctionsGroup) -> bool: + m = self.get_kernels(g) + return m is not None + + def get_kernels( + self, g: NativeFunction | NativeFunctionsGroup + ) -> dict[ETKernelKey, BackendMetadata]: + if isinstance(g, NativeFunction): + f = g + elif isinstance(g, NativeFunctionsGroup): + f = g.functional + else: + assert_never(g) + if f.func.name not in self.index: + return {} + return self.index[f.func.name] + + @staticmethod + def grow_from_backend_indices( + kernel_index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]], + backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]], + ) -> None: + for dk in backend_indices: + index = backend_indices[dk] + for op, backend_metadata in index.items(): + if op in kernel_index: + kernel_index[op][ETKernelKey(default=True)] = backend_metadata + else: + kernel_index[op] = {ETKernelKey(default=True): backend_metadata} + + @staticmethod + def from_backend_indices( + backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] + ) -> ETKernelIndex: + kernel_index: dict[ + OperatorName, dict[ETKernelKey, BackendMetadata] + ] = defaultdict(dict) + ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices) + return ETKernelIndex(kernel_index) + + def grow( + self, backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] + ) -> ETKernelIndex: + ETKernelIndex.grow_from_backend_indices(self.index, backend_indices) + return self + + def 
_to_backend_index(self) -> BackendIndex: + """ + WARNING: this will be deprecated once all the codegen places know how to handle ETKernelIndex. + """ + index: dict[OperatorName, BackendMetadata] = {} + for op in self.index: + kernel_dict = self.index[op] + assert ( + len(kernel_dict.values()) == 1 + ), f"Can't convert ETKernelIndex to BackendIndex because {op} has more than one kernels. Got {kernel_dict}" + index[op] = kernel_dict.get( + ETKernelKey(default=True), + BackendMetadata(kernel="", structured=False, cpp_namespace=""), + ) + return BackendIndex( + dispatch_key=DispatchKey.CPU, + use_out_as_primary=False, + device_guard=False, + external=False, + index=index, + ) + + # Note duplicate ETKernelKey from index_b will clobber the metadata from index_a + @staticmethod + def merge_indices(index_a: ETKernelIndex, index_b: ETKernelIndex) -> ETKernelIndex: + combined = defaultdict(dict, index_a.index.copy()) + + for op, entry in index_b.index.items(): + for key, metadata in entry.items(): + combined[op][key] = metadata + + return ETKernelIndex(combined) diff --git a/lib/python3.10/site-packages/torchgen/executorch/parse.py b/lib/python3.10/site-packages/torchgen/executorch/parse.py new file mode 100644 index 0000000000000000000000000000000000000000..8095abd5b6bc33fb02b4af5b1643ad348fca6c1a --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/executorch/parse.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +from collections import defaultdict, namedtuple +from typing import Any + +import yaml + +from torchgen.executorch.model import ETKernelIndex, ETKernelKey +from torchgen.gen import LineLoader, parse_native_yaml +from torchgen.model import ( + BackendMetadata, + DispatchKey, + FunctionSchema, + NativeFunction, + OperatorName, +) +from torchgen.utils import NamespaceHelper + + +# Parse native_functions.yaml into a sequence of NativeFunctions and ET Backend Indices. 
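+# Editorial note (illustrative sketch, not part of the upstream module; the entry below is hypothetical and only approximates the yaml shape this parser expects): +# - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +#   type_alias: {T0: [Float, Double]} +#   dim_order_alias: {D0: [0, 1]} +#   kernels: +#     - arg_meta: {self: [T0, D0], other: [T0, D0], out: [T0, D0]} +#       kernel_name: torch::executor::add_out +# parse_from_yaml() would expand T0 into one ETKernelKey per dtype permutation and map each key to BackendMetadata(kernel="add_out", cpp_namespace="torch::executor::native").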
+ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "et_kernel_indices"]) + +# Fields in native_functions.yaml used to determine which kernels should be used +ET_FIELDS = ["kernels", "type_alias", "dim_order_alias"] + + +def parse_from_yaml(ei: dict[str, object]) -> dict[ETKernelKey, BackendMetadata]: + """Given a loaded yaml representing kernel assignment information, extract the + mapping from `kernel keys` to `BackendMetadata` (the latter representing the kernel instance) + + Args: + ei: Dict keys {kernels, type_alias, dim_order_alias} + See ETKernelKey for description of arguments + """ + e = ei.copy() + if (kernels := e.pop("kernels", None)) is None: + return {} + + type_alias: dict[str, list[str]] = e.pop("type_alias", {}) # type: ignore[assignment] + dim_order_alias: dict[str, list[str]] = e.pop("dim_order_alias", {}) # type: ignore[assignment] + dim_order_alias.pop("__line__", None) + + kernel_mapping: dict[ETKernelKey, BackendMetadata] = {} + + for entry in kernels: # type: ignore[attr-defined] + arg_meta = entry.get("arg_meta") + if arg_meta is not None: + arg_meta.pop("__line__") + + kernel_name = entry.get("kernel_name") + namespace_helper = NamespaceHelper.from_namespaced_entity( + kernel_name, max_level=3 + ) + kernel_namespace = namespace_helper.get_cpp_namespace(default="at") + backend_metadata = BackendMetadata( + kernel=namespace_helper.entity_name, + structured=False, + cpp_namespace=(kernel_namespace + "::native"), + ) + + kernel_keys = ( + [ETKernelKey((), default=True)] + if arg_meta is None + else ETKernelKey.gen_from_yaml(arg_meta, type_alias, dim_order_alias) # type: ignore[arg-type] + ) + + for kernel_key in kernel_keys: + assert kernel_key not in kernel_mapping, ( + "Duplicate kernel key: " + str(kernel_key) + " " + str(e) + ) + kernel_mapping[kernel_key] = backend_metadata + + return kernel_mapping + + +def parse_et_yaml_struct(es: object) -> ETKernelIndex: + """Given a loaded yaml representing a list of operators, for each op extract the mapping + of `kernel keys` to `BackendMetadata` (the latter representing the kernel instance + that should be used by the kernel key). + """ + indices: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] = {} + for ei in es: # type: ignore[attr-defined] + e = ei.copy() + + funcs = e.pop("func") + assert isinstance(funcs, str), f"not a str: {funcs}" + namespace_helper = NamespaceHelper.from_namespaced_entity( + namespaced_entity=funcs, max_level=1 + ) + opname = FunctionSchema.parse(namespace_helper.entity_name).name + + assert opname not in indices, f"Duplicate func found in yaml: {opname} already" + + if len(index := parse_from_yaml(e)) != 0: + indices[opname] = index + + return ETKernelIndex(indices) + + +def extract_kernel_fields(es: object) -> dict[OperatorName, dict[str, Any]]: + """Given a loaded yaml representing a list of operators, extract the + kernel key related fields indexed by the operator name. 
+ """ + fields: dict[OperatorName, dict[str, Any]] = defaultdict(dict) + for ei in es: # type: ignore[attr-defined] + funcs = ei.get("func") + assert isinstance(funcs, str), f"not a str: {funcs}" + namespace_helper = NamespaceHelper.from_namespaced_entity( + namespaced_entity=funcs, max_level=1 + ) + opname = FunctionSchema.parse(namespace_helper.entity_name).name + + for field in ET_FIELDS: + if (value := ei.get(field)) is not None: + fields[opname][field] = value + + return fields + + +def parse_et_yaml( + path: str, + tags_yaml_path: str, + ignore_keys: set[DispatchKey] | None = None, + skip_native_fns_gen: bool = False, +) -> tuple[list[NativeFunction], dict[OperatorName, dict[str, Any]]]: + """Parse native_functions.yaml into NativeFunctions and an Operator Indexed Dict + of fields to persist from native_functions.yaml to functions.yaml + """ + with open(path) as f: + es = yaml.load(f, Loader=LineLoader) + + et_kernel = extract_kernel_fields(es) + + # Remove ET specific fields from entries for BC compatibility + strip_et_fields(es) + + native_yaml = parse_native_yaml( + path, + tags_yaml_path, + ignore_keys, + skip_native_fns_gen=skip_native_fns_gen, + loaded_yaml=es, + ) + return native_yaml.native_functions, et_kernel + + +def strip_et_fields(es: object) -> None: + """Given a loaded yaml representing a list of operators, + remove ET specific fields from every entries for BC compatibility + """ + for entry in es: # type: ignore[attr-defined] + for field in ET_FIELDS: + entry.pop(field, None) diff --git a/lib/python3.10/site-packages/torchgen/operator_versions/__init__.py b/lib/python3.10/site-packages/torchgen/operator_versions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/torchgen/operator_versions/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/operator_versions/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aab06f068ac42f9e831752a2600d836c43cb5076 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/operator_versions/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5960a5dcaf98b5a6cd82b389dc9d4767a4ad09e8 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders_constant.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders_constant.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f179e068e9d004c9bdec133fa30ccbe565736f73 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders_constant.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/operator_versions/gen_mobile_upgraders.py b/lib/python3.10/site-packages/torchgen/operator_versions/gen_mobile_upgraders.py new file mode 100644 index 0000000000000000000000000000000000000000..362ce427d508ca7885803e013ef3ac4640314a1c --- /dev/null +++ 
b/lib/python3.10/site-packages/torchgen/operator_versions/gen_mobile_upgraders.py @@ -0,0 +1,395 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import os +from enum import Enum +from operator import itemgetter +from pathlib import Path +from typing import Any + +import torch +from torch.jit.generate_bytecode import generate_upgraders_bytecode +from torchgen.code_template import CodeTemplate +from torchgen.operator_versions.gen_mobile_upgraders_constant import ( + MOBILE_UPGRADERS_HEADER_DESCRIPTION, +) + + +class ByteCode(Enum): + instructions = 1 + constants = 2 + types = 3 + operators = 4 + register_size = 5 + + +EXCLUDED_OP_SET = [ + "aten::full.names", + "aten::full.out", + "aten::full", +] + +EXCLUE_UPGRADER_SET = ["full_0_4", "full_out_0_4"] + +ONE_INSTRUCTION = CodeTemplate( + """ + Instruction{OpCode::${operator_name}, ${X}, ${N}},""" +) + +INSTRUCTION_LIST = CodeTemplate( + """std::vector<Instruction>({ + ${instruction_list} + }), // instructions list""" +) + +ONE_CONSTANT = CodeTemplate( + """ + c10::IValue(${constant}),""" +) + +CONSTANT_LIST = CodeTemplate( + """std::vector<c10::IValue>({ + ${constant_list} + }), // constants list""" +) + +CONSTANTS_LIST_EMPTY = """std::vector<c10::IValue>(), // constants list""" + +ONE_TYPE = CodeTemplate("""c10::parseType("${type_str}"),""") + +TYPE_LIST = CodeTemplate( + """std::vector<c10::TypePtr>({ + ${type_list} + }), // types list""" +) + +TYPE_LIST_EMPTY = """std::vector<c10::TypePtr>(), // types list""" + +ONE_OPERATOTR_STRING = CodeTemplate( + """ + OperatorString({"${operator_name}", "${overload_name}", ${num_of_args}}),""" +) + +OPERATOR_STRING_LIST = CodeTemplate( + """ + std::vector<OperatorString>({ + ${operator_string_list} + }), // operators list""" +) + +ONE_UPGRADER_FUNCTION = CodeTemplate( + """ + mobile::Function::registerFunc( + "${upgrader_name}", + ${instruction_list}, + ${constant_list}, + ${type_list}, + ${register_size} + )""" +) + +ONE_UPGRADER_SRC = CodeTemplate( + """ + ByteCodeFunctionWithOperator({ + ${bytecode_function}, + ${operator_string_list} + }),""" +) + + +ONE_UPGRADER_IN_VERSION_MAP = CodeTemplate( + """Upgrader({${upgrader_min_version}, ${upgrader_max_version}, "${upgrader_name}", ${bytecode_func_index}})""" +) # noqa: E501 + +ONE_OPERATOR_IN_VERSION_MAP = CodeTemplate( + """ + {std::string("${operator_name}"), + std::vector<Upgrader>({ + ${upgrader_list_in_version_map} + })},""" +) + + +OPERATOR_VERSION_MAP = CodeTemplate( + """ +const std::unordered_map<std::string, std::vector<Upgrader>> +getOperatorVersionMapForMobile() { + static std::unordered_map<std::string, std::vector<Upgrader>> + operatorVersionMapForMobile({ + ${operator_list_in_version_map} + }); + return operatorVersionMapForMobile; +} +""" +) + + +UPGRADER_CPP_SRC = CodeTemplate( + MOBILE_UPGRADERS_HEADER_DESCRIPTION + + """ +#include <caffe2/serialize/versions.h> +#include <torch/csrc/jit/mobile/upgrader_mobile.h> + +namespace c10 { +TypePtr parseType(const std::string& pythonStr); +} // namespace c10 + +namespace torch { +namespace jit { + +// clang-format off + +// From operator_versions_map +${operator_version_map} + +const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() { + auto generate_upgrader_bytecode_list = []() { + std::vector<ByteCodeFunctionWithOperator> upgrader_function_list({ + ${upgrader_bytecode} + }); + for (const auto& upgrader_function : upgrader_function_list) { + for (const auto& op : upgrader_function.operators) { + upgrader_function.function.append_operator( + op.name, + op.overload_name, + op.num_specified_args); + } + } + return upgrader_function_list; + }; + static std::vector<ByteCodeFunctionWithOperator> upgraderBytecodeList = + generate_upgrader_bytecode_list(); + return upgraderBytecodeList; +} + +// clang-format on + +} // namespace jit +} // namespace torch +""" +) +
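+# Editorial note (illustrative sketch, not part of the upstream module; the upgrader contents are hypothetical): for an upgrader whose bytecode tables hold one instruction ["LOADC", 0, 0], one integer constant 4 and no types, the helpers below would render roughly +#   std::vector<Instruction>({ +#       Instruction{OpCode::LOADC, 0, 0}, +#   }), // instructions list +#   std::vector<c10::IValue>({ +#       c10::IValue(4), +#   }), // constants list +#   std::vector<c10::TypePtr>(), // types list +# via construct_instruction(), construct_constants() and construct_types() respectively.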
+UPGRADER_MOBILE_FILE_NAME = "upgrader_mobile.cpp" + +UPGRADER_ELEMENT = CodeTemplate( + """\ +Upgrader({${min_version}, ${max_version}, ${operator_name}, ${index}}), +""" +) + +PER_OPERATOR_UPGRADER_LIST = CodeTemplate( + """\ +{ + std::string(${operator_name}), + std::vector<Upgrader>({${upgrader_list}}); +} +""" +) + + +def construct_instruction(instruction_list_from_yaml: list[Any]) -> str: + instruction_list_part = [] + for instruction in instruction_list_from_yaml: + instruction_list_part.append( + ONE_INSTRUCTION.substitute( + operator_name=instruction[0], + X=instruction[1], + N=instruction[2], + ) + ) + return INSTRUCTION_LIST.substitute( + instruction_list="".join(instruction_list_part).lstrip("\n") + ) + + +def construct_constants(constants_list_from_yaml: list[Any]) -> str: + constants_list_part = [] + for constant_from_yaml in constants_list_from_yaml: + convert_constant = None + if isinstance(constant_from_yaml, str): + # Add quotes if it's a string + convert_constant = f'"{constant_from_yaml}"' + elif isinstance(constant_from_yaml, bool): + convert_constant = "true" if constant_from_yaml else "false" + elif constant_from_yaml is None: + convert_constant = "" + elif isinstance(constant_from_yaml, int): + convert_constant = str(constant_from_yaml) + else: + raise ValueError( + f"The type of {constant_from_yaml} is {type(constant_from_yaml)}. " + "Please add change in construct_constants function in gen_mobile_upgraders.py." + ) + constants_list_part.append(ONE_CONSTANT.substitute(constant=convert_constant)) + if len(constants_list_part) == 0: + return CONSTANTS_LIST_EMPTY + return CONSTANT_LIST.substitute( + constant_list="".join(constants_list_part).lstrip("\n") + ) + + +def construct_operators(operator_list_from_yaml: list[Any]) -> str: + operator_list_part = [] + for operator in operator_list_from_yaml: + operator_list_part.append( + ONE_OPERATOTR_STRING.substitute( + operator_name=operator[0], + overload_name=operator[1], + num_of_args=operator[2], + ) + ) + return OPERATOR_STRING_LIST.substitute( + operator_string_list="".join(operator_list_part).lstrip("\n") + ) + + +def construct_types(types_tr_list_from_yaml: list[Any]) -> str: + types_tr_list_part = [] + for types_tr in types_tr_list_from_yaml: + types_tr_list_part.append(ONE_TYPE.substitute(type_str=types_tr)) + if len(types_tr_list_part) == 0: + return TYPE_LIST_EMPTY + return TYPE_LIST.substitute(type_list="".join(types_tr_list_part).lstrip("\n")) + + +def construct_register_size(register_size_from_yaml: int) -> str: + if not isinstance(register_size_from_yaml, int): + raise ValueError( + f"Input register size is {register_size_from_yaml} and " + f"its type is {type(register_size_from_yaml)}. An int type is expected."
+ ) + return str(register_size_from_yaml) + + +def construct_version_maps( + upgrader_bytecode_function_to_index_map: dict[str, Any] +) -> str: + version_map = torch._C._get_operator_version_map() + sorted_version_map_ = sorted(version_map.items(), key=itemgetter(0)) # type: ignore[no-any-return] + sorted_version_map = dict(sorted_version_map_) + + operator_list_in_version_map_part = [] + for op_name in sorted_version_map: + upgraders_in_version_map_part = [] + # TODO: remove the skip after these two operators schemas are fixed + if op_name in EXCLUDED_OP_SET: + continue + upgrader_ranges = torch._C._get_upgrader_ranges(op_name) + upgrader_entries = sorted_version_map[op_name] + assert len(upgrader_ranges) == len(upgrader_entries) + for idx, upgrader_entry in enumerate(upgrader_entries): + upgrader_name = upgrader_entry.upgrader_name + bytecode_function_index = upgrader_bytecode_function_to_index_map[ + upgrader_name + ] + upgraders_in_version_map_part.append( + ONE_UPGRADER_IN_VERSION_MAP.substitute( + upgrader_min_version=upgrader_ranges[idx].min_version, + upgrader_max_version=upgrader_ranges[idx].max_version, + upgrader_name=upgrader_name, + bytecode_func_index=bytecode_function_index, + ) + ) + operator_list_in_version_map_part.append( + ONE_OPERATOR_IN_VERSION_MAP.substitute( + operator_name=op_name, + upgrader_list_in_version_map="".join(upgraders_in_version_map_part), + ) + ) + return OPERATOR_VERSION_MAP.substitute( + operator_list_in_version_map="".join(operator_list_in_version_map_part).lstrip( + "\n" + ) + ) + + +def get_upgrader_bytecode_function_to_index_map( + upgrader_dict: list[dict[str, Any]] +) -> dict[str, Any]: + upgrader_bytecode_function_to_index_map = {} + index = 0 + for upgrader_bytecode in upgrader_dict: + for upgrader_name in upgrader_bytecode.keys(): + if upgrader_name in EXCLUE_UPGRADER_SET: + continue + upgrader_bytecode_function_to_index_map[upgrader_name] = index + index += 1 + return upgrader_bytecode_function_to_index_map + + +def write_cpp(cpp_path: str, upgrader_dict: list[dict[str, Any]]) -> None: + body_parts = [] + upgrader_bytecode_function_to_index_map = ( + get_upgrader_bytecode_function_to_index_map(upgrader_dict) + ) + version_map_src = construct_version_maps(upgrader_bytecode_function_to_index_map) + all_upgrader_src_string = [] + for upgrader_bytecode in upgrader_dict: + for upgrader_name, bytecode in upgrader_bytecode.items(): + # TODO: remove the skip after these two operators schemas are fixed + if upgrader_name in EXCLUE_UPGRADER_SET: + continue + instruction_list_str = "" + constant_list_str = "" + type_list_str = "" + register_size_str = "" + operator_list_str = "" + for table_name, contents in bytecode.items(): + element = ByteCode[table_name] + body_string = "" + if element is ByteCode.instructions: + instruction_list_str = construct_instruction(contents) + elif element is ByteCode.constants: + constant_list_str = construct_constants(contents) + elif element is ByteCode.operators: + operator_list_str = construct_operators(contents) + elif element is ByteCode.types: + type_list_str = construct_types(contents) + elif element is ByteCode.register_size: + register_size_str = construct_register_size(contents) + + one_upgrader_function_string = ONE_UPGRADER_FUNCTION.substitute( + upgrader_name=upgrader_name, + instruction_list=instruction_list_str, + constant_list=constant_list_str, + type_list=type_list_str, + register_size=register_size_str, + ) + one_upgrader_src_string = ONE_UPGRADER_SRC.substitute( + 
bytecode_function=one_upgrader_function_string.lstrip("\n"), + operator_string_list=operator_list_str.lstrip("\n"), + ) + all_upgrader_src_string.append(one_upgrader_src_string) + + upgrader_file_content = UPGRADER_CPP_SRC.substitute( + operator_version_map=version_map_src, + upgrader_bytecode="".join(all_upgrader_src_string).lstrip("\n"), + ) + body_parts.append(upgrader_file_content) + print("writing file to : ", cpp_path + "/" + UPGRADER_MOBILE_FILE_NAME) + with open(os.path.join(cpp_path, UPGRADER_MOBILE_FILE_NAME), "wb") as out_file: + final_output = "".join(body_parts) + out_file.write(upgrader_file_content.encode("utf-8")) + + +def sort_upgrader(upgrader_list: list[dict[str, Any]]) -> list[dict[str, Any]]: + sorted_upgrader_list = sorted( + upgrader_list, key=lambda one_upgrader: next(iter(one_upgrader)) + ) + return sorted_upgrader_list + + +def main() -> None: + upgrader_list = generate_upgraders_bytecode() + sorted_upgrader_list = sort_upgrader(upgrader_list) + for up in sorted_upgrader_list: + print("after sort upgrader : ", next(iter(up))) + + pytorch_dir = Path(__file__).resolve().parents[2] + upgrader_path = pytorch_dir / "torch" / "csrc" / "jit" / "mobile" + write_cpp(str(upgrader_path), sorted_upgrader_list) + + +if __name__ == "__main__": + main() diff --git a/lib/python3.10/site-packages/torchgen/operator_versions/gen_mobile_upgraders_constant.py b/lib/python3.10/site-packages/torchgen/operator_versions/gen_mobile_upgraders_constant.py new file mode 100644 index 0000000000000000000000000000000000000000..04b5ad887e54153115eeca7b6686d7c2de8dfc06 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/operator_versions/gen_mobile_upgraders_constant.py @@ -0,0 +1,7 @@ +MOBILE_UPGRADERS_HEADER_DESCRIPTION = """/** + * @generated + * This is an auto-generated file. Please do not modify it by hand. + * To re-generate, please run: + * cd ~/pytorch && python torchgen/operator_versions/gen_mobile_upgraders.py + */ +""" diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/native/native_functions.yaml b/lib/python3.10/site-packages/torchgen/packaged/ATen/native/native_functions.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c7533e4ef854c9cfe0ed1cbfb5ba26e059f3940c --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/native/native_functions.yaml @@ -0,0 +1,15701 @@ +# See README.md in this directory for more guidance + +# *********NB: _cast_* operators are DEPRECATED and will be removed +# eventually. These were previously used before TorchScript IR supported +# representing ScalarType's. They are now superseded by usage of +# `aten::to()`. The ops remain here for backward compatibility purposes. + +# DEPRECATED. DO NOT USE +- func: _cast_Byte(Tensor self, bool non_blocking=False) -> Tensor + variants: function + +# DEPRECATED. DO NOT USE +- func: _cast_Char(Tensor self, bool non_blocking=False) -> Tensor + variants: function + +# DEPRECATED. DO NOT USE +- func: _cast_Double(Tensor self, bool non_blocking=False) -> Tensor + variants: function + +# DEPRECATED. DO NOT USE +- func: _cast_Float(Tensor self, bool non_blocking=False) -> Tensor + variants: function + +# DEPRECATED. DO NOT USE +- func: _cast_Int(Tensor self, bool non_blocking=False) -> Tensor + variants: function + +# DEPRECATED. DO NOT USE +- func: _cast_Long(Tensor self, bool non_blocking=False) -> Tensor + variants: function + +# DEPRECATED. DO NOT USE +- func: _cast_Short(Tensor self, bool non_blocking=False) -> Tensor + variants: function + +# DEPRECATED. 
DO NOT USE +- func: _cast_Half(Tensor self, bool non_blocking=False) -> Tensor + variants: function + +# Computes the gradient of current tensor w.r.t. graph leaves. +- func: _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> () + manual_cpp_binding: True + variants: method + +# DEPRECATED. Sets the tensor data held by this `Variable` to be the same as +# `new_data`. It requires that `new_data` and `Variable` have compatible tensor +# type, by checking `_has_compatible_shallow_copy_type(this, new_data)`. +# +# This function is deprecated because it doesn't really make sense in a world +# where Variables *are* Tensors (as opposed to them containing tensors, which +# is what the previous interpretation was.) +- func: set_data(Tensor(a!) self, Tensor new_data) -> () + manual_cpp_binding: True + variants: method + +- func: data(Tensor self) -> Tensor + manual_cpp_binding: True + variants: method + +# True if this `Variable` is a leaf and thus does not have a `grad_fn`. +- func: is_leaf(Tensor self) -> bool + manual_cpp_binding: True + variants: method + +# Returns the output index of this variable from the forward operation that +# produced it. Conversely, it returns the input index of the gradient `Node` to +# which this `Variable` is connected (because in the gradient computation, +# inputs and outputs switch meaning). For example: +# +# y0, y1, y2 = f(x) +# assert y0.output_nr == 0 +# assert y1.output_nr == 1 +# assert y2.output_nr == 2 +# +- func: output_nr(Tensor self) -> int + manual_cpp_binding: True + variants: method + +- func: _version(Tensor self) -> int + manual_cpp_binding: True + variants: method + +- func: requires_grad_(Tensor(a!) self, bool requires_grad=True) -> Tensor(a!) + manual_cpp_binding: True + variants: method + +# Enables .grad attribute for non-leaf Tensors. +- func: retain_grad(Tensor(a!) self) -> () + manual_cpp_binding: True + variants: method + +- func: retains_grad(Tensor self) -> bool + manual_cpp_binding: True + variants: method + +- func: _fw_primal(Tensor(a) self, int level) -> Tensor(a) + variants: method + dispatch: + CompositeExplicitAutograd: _fw_primal + +- func: _make_dual(Tensor(a) primal, Tensor tangent, int level) -> Tensor(a) + variants: function + dispatch: + CompositeExplicitAutograd: _make_dual + +- func: _unpack_dual(Tensor(a) dual, int level) -> (Tensor(a) primal, Tensor tangent) + variants: function + +# NOTE: [_new_zeros_with_same_feature_meta] +# This function creates a new tensor with the layout and TensorOptions +# of `other` but also takes into account the batch dimensions of `self` +# +# This function has a couple extra constraints because it is also used for `jvp` +# in functorch. +# - is used for forward AD because there is the restriction +# that the primal and tangent must have the same layout +# - We cannot assume that `self` and `other` have the same sizes or even dim +# because in the inplace over view case, `other` is the base tensor, and +# `self` is the forward grad with respect to the view, which can have an +# entirely different shape +# - takes the number of batch dims for `self` because we also handle +# some batching logic. We handle that here instead of a batching rule because +# we'd like to avoid calling as_strided in the batching rule (as to enable +# nested vmap in functorch). +# - needs to be CompositeExplicitAutograd for jvp support in functorch. 
+# functorch currently relies on TensorWrapper which does not have storage +# CompositeExplicitAutograd makes sure the TensorWrapper is unwrapped. +# - this function may eventually take on another int argument to store the +# the number of batch dims for other once we support that use case +- func: _new_zeros_with_same_feature_meta(Tensor self, Tensor other, *, int self_num_batch_dims=0) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: _new_zeros_with_same_feature_meta + autogen: _new_zeros_with_same_feature_meta.out + +# This function compares the storage numel of self with that of other, where +# storage numel is computed as: `other.storage().nbytes() / other.itemsize()`. +# We create this function for composite compliance purposes. The batching rule +# always returns true because vmapped as_strided does not support accessing +# storage locations not indexable by the input tensor. +# See the note above for more information. +- func: _has_same_storage_numel(Tensor self, Tensor other) -> bool + variants: function + dispatch: + CompositeExplicitAutograd: _has_same_storage_numel + +- func: rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!) + variants: method + tags: inplace_view + +- func: rename(Tensor(a) self, Dimname[]? names) -> Tensor(a) + variants: method + +- func: align_to(Tensor(a) self, Dimname[] names) -> Tensor(a) + variants: method + +- func: align_to.ellipsis_idx(Tensor(a) self, Dimname[] order, int ellipsis_idx) -> Tensor(a) + variants: method + +- func: align_as(Tensor self, Tensor other) -> Tensor + variants: method + +- func: align_tensors(Tensor[] tensors) -> Tensor[] + +# Not assert because it's a keyword; not Assert because FX already +# took that syntax +# TODO: need to specify this is side-effectful somehow +- func: _assert_async(Tensor self) -> () + dispatch: + CPU: _assert_async_cpu + CUDA: _assert_async_cuda + +- func: _assert_async.msg(Tensor self, str assert_msg) -> () + dispatch: + CPU: _assert_async_msg_cpu + CUDA: _assert_async_msg_cuda + +- func: _assert_scalar(Scalar self, str assert_msg) -> () + dispatch: + CompositeExplicitAutograd: _assert_scalar + +- func: _functional_assert_scalar(Scalar self, str assert_msg, Tensor dep_token) -> Tensor + dispatch: + CompositeExplicitAutograd: _functional_assert_scalar + +- func: _functional_assert_async.msg(Tensor self, str assert_msg, Tensor dep_token) -> Tensor + dispatch: + CPU: _functional_assert_async_msg_cpu + +- func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> () + +- func: _print(str s) -> () + dispatch: + CompositeExplicitAutograd: _print + +- func: sym_constrain_range(Scalar size, *, int? min=None, int? max=None) -> () + dispatch: + CompositeExplicitAutograd: sym_constrain_range + +- func: sym_constrain_range_for_size(Scalar size, *, int? min=None, int? max=None) -> () + dispatch: + CompositeExplicitAutograd: sym_constrain_range_for_size + +- func: _functional_sym_constrain_range(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor + dispatch: + CompositeExplicitAutograd: _functional_sym_constrain_range + +- func: _functional_sym_constrain_range_for_size(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor + dispatch: + CompositeExplicitAutograd: _functional_sym_constrain_range_for_size + +- func: _make_dep_token(*, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor + dispatch: + CPU: _make_dep_token_cpu + +- func: refine_names(Tensor(a) self, Dimname[] names) -> Tensor(a) + variants: method + +- func: _use_cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank) -> bool + device_check: NoCheck # Tensor arguments allowed to be on different devices, see also _cudnn_ctc_loss + dispatch: + CUDA: _use_cudnn_ctc_loss + +- func: _use_cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank) -> bool + device_check: NoCheck # Tensor arguments allowed to be on different devices, see also _cudnn_ctc_loss + dispatch: + CUDA: _use_cudnn_ctc_loss_tensor + +- func: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor) + device_check: NoCheck # log_probs is expected to be on CUDA while targets is expected to be on CPU + dispatch: + CUDA: _cudnn_ctc_loss + autogen: _cudnn_ctc_loss.out + +- func: _cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor) + device_check: NoCheck # log_probs is expected to be on CUDA while targets is expected to be on CPU + dispatch: + CUDA: _cudnn_ctc_loss_tensor + +- func: _use_cudnn_rnn_flatten_weight() -> bool + +- func: _cudnn_rnn_flatten_weight(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional) -> Tensor + dispatch: + CUDA: _cudnn_rnn_flatten_weight + autogen: _cudnn_rnn_flatten_weight.out + +- func: _cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + # rnn_tanh may or may not redispatch to _cudnn_rnn based on algorithm and build. Thus it might hit dispatch or kernel device check. + # Disable dispatch time device check for consistent behavior. + device_check: NoCheck + dispatch: + CUDA: _cudnn_rnn + autogen: _cudnn_rnn.out + tags: nondeterministic_seeded + +- func: _cudnn_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[]) + dispatch: + CUDA: _cudnn_rnn_backward + autogen: _cudnn_rnn_backward.out + +- func: _cudnn_init_dropout_state(float dropout, bool train, int dropout_seed, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + dispatch: + CUDA: _cudnn_init_dropout_state + autogen: _cudnn_init_dropout_state.out + tags: nondeterministic_seeded + +- func: _debug_has_internal_overlap(Tensor self) -> int + variants: function + +- func: _fused_dropout(Tensor self, float p, Generator? 
generator=None) -> (Tensor, Tensor) + variants: function + dispatch: + CUDA: fused_dropout_cuda + tags: nondeterministic_seeded + autogen: _fused_dropout.out + +- func: _masked_scale(Tensor self, Tensor mask, float scale) -> Tensor + variants: function + dispatch: + CUDA: masked_scale_cuda + autogen: _masked_scale.out + +- func: native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor) + variants: function + dispatch: + CPU: native_dropout_cpu + CUDA: native_dropout_cuda + NestedTensorCPU, NestedTensorCUDA: native_dropout_nested + tags: [nondeterministic_seeded, core] + autogen: native_dropout.out + +- func: native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor + dispatch: + CPU, NestedTensorCPU, NestedTensorCUDA: native_dropout_backward + CUDA: native_dropout_backward_cuda + autogen: native_dropout_backward.out + tags: pointwise + +- func: _sobol_engine_draw(Tensor quasi, int n, Tensor sobolstate, int dimension, int num_generated, ScalarType? dtype) -> (Tensor, Tensor) + +- func: _sobol_engine_ff_(Tensor(a!) self, int n, Tensor sobolstate, int dimension, int num_generated) -> Tensor(a!) + +- func: _sobol_engine_scramble_(Tensor(a!) self, Tensor ltm, int dimension) -> Tensor(a!) + +- func: _sobol_engine_initialize_state_(Tensor(a!) self, int dimension) -> Tensor(a!) + +- func: _reshape_from_tensor(Tensor self, Tensor shape) -> Tensor + +- func: _shape_as_tensor(Tensor self) -> Tensor + +- func: dropout(Tensor input, float p, bool train) -> Tensor + tags: nondeterministic_seeded + +- func: dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!) + tags: nondeterministic_seeded + +- func: feature_dropout(Tensor input, float p, bool train) -> Tensor + tags: nondeterministic_seeded + +- func: feature_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!) + tags: nondeterministic_seeded + +- func: alpha_dropout(Tensor input, float p, bool train) -> Tensor + tags: nondeterministic_seeded + +- func: alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!) + tags: nondeterministic_seeded + +- func: feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor + tags: nondeterministic_seeded + +- func: feature_alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!) + tags: nondeterministic_seeded + +- func: abs(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: abs + SparseCPU, SparseCUDA: abs_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: abs_sparse_csr + NestedTensorCPU, NestedTensorCUDA: NestedTensor_abs + tags: [core, pointwise] + +- func: abs_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: abs_ + SparseCPU, SparseCUDA: abs_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: abs_sparse_csr_ + NestedTensorCPU, NestedTensorCUDA: NestedTensor_abs_ + +- func: abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: abs_out + MPS: abs_out_mps + SparseCPU, SparseCUDA: abs_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: abs_sparse_csr_out + tags: pointwise + +# Note [Adding an alias] +# To add an alias do the following: +# +# 1) Copy the original functions native_functions.yaml entry, but replace the +# original function's name with their own and delete any dispatch +# keys for the aliases. 
Specifying a dispatch key will prevent +# autograd from recording the operations the alias performs, which +# will stop it from "inheriting" the original operation's autograd behavior. +# 2) Implement the corresponding functions and have them redispatch to the +# original function. +# 3) Add docstrings to the new function that reference the original function, +# and document the method as usual (if it exists.) +# (See torch/_torch_docs.py and docs/source/torch.rst if adding a function, +# torch/_tensor_docs.py and docs/source/tensors.rst if adding a method, +# or module-specific doc bindings (like torch/linalg/__init__.py) if +# adding an alias in a namespace.) +# 4) Update torch/overrides.py consistent with the original function. +# 5) Update the alias_map in torch/csrc/jit/passes/normalize_ops.cpp. +# 6) Add aliases argument to existing OpInfo/UnaryUfuncInfo or create new OpInfo/UnaryUfuncInfo entry +# in op_db list in torch/testing/_internal/common_methods_invocations.py +# +# See torch.absolute, an alias for torch.abs, as an example. +# Absolute, alias for abs + +- func: absolute(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + +- func: absolute_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + +- func: absolute.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + +- func: angle(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CPU, CUDA: angle + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: angle_sparse_csr + tags: pointwise + +- func: angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: angle_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: angle_sparse_csr_out + tags: pointwise + +- func: view_as_real(Tensor(a) self) -> Tensor(a) + variants: function + dispatch: + CPU, CUDA, MPS, Meta: view_as_real + +- func: view_as_complex(Tensor(a) self) -> Tensor(a) + variants: function + dispatch: + CPU, CUDA, MPS, Meta: view_as_complex + +- func: sgn(Tensor self) -> Tensor + variants: function, method + structured_delegate: sgn.out + dispatch: + SparseCPU, SparseCUDA: sgn_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sgn_sparse_csr + NestedTensorCPU, NestedTensorCUDA: NestedTensor_sgn + tags: pointwise + +- func: sgn_(Tensor(a!) self) -> Tensor(a!) + variants: method + structured_delegate: sgn.out + dispatch: + SparseCPU, SparseCUDA: sgn_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sgn_sparse_csr_ + NestedTensorCPU, NestedTensorCUDA: NestedTensor_sgn_ + tags: pointwise + +- func: sgn.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: sgn_out + MPS: sgn_out_mps + SparseCPU, SparseCUDA: sgn_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sgn_sparse_csr_out + tags: pointwise + +- func: chalf(Tensor self, *, MemoryFormat? 
memory_format=None) -> Tensor + variants: method + +- func: real(Tensor(a) self) -> Tensor(a) + device_check: NoCheck # TensorIterator + variants: function + +- func: imag(Tensor(a) self) -> Tensor(a) + device_check: NoCheck # TensorIterator + variants: function + +- func: _conj(Tensor(a) self) -> Tensor(a) + variants: function, method + dispatch: + CompositeExplicitAutograd: _conj + +- func: conj(Tensor(a) self) -> Tensor(a) + variants: function, method + manual_cpp_binding: True + +- func: _conj_physical(Tensor self) -> Tensor + variants: function, method + dispatch: + CompositeExplicitAutograd: _conj_physical + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: conj_physical_sparse_csr + autogen: _conj_physical.out + +- func: conj_physical(Tensor self) -> Tensor + variants: function, method + tags: pointwise + +- func: conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: conj_physical_out + MPS: conj_physical_out_mps + SparseCPU, SparseCUDA: conj_physical_out_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: conj_physical_sparse_csr_out + tags: pointwise + +- func: conj_physical_(Tensor(a!) self) -> Tensor(a!) + variants: function, method + dispatch: + CompositeExplicitAutograd: conj_physical_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: conj_physical_sparse_csr_ + tags: pointwise + +- func: resolve_conj(Tensor(a) self) -> Tensor(a) + variants: function, method + +- func: resolve_neg(Tensor(a) self) -> Tensor(a) + variants: function, method + +- func: _neg_view(Tensor(a) self) -> Tensor(a) + variants: function, method + dispatch: + CompositeExplicitAutograd: _neg_view + +- func: acos(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + structured_delegate: acos.out + tags: [core, pointwise] + +- func: acos_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: function, method + structured_delegate: acos.out + tags: pointwise + +- func: acos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: acos_out + MPS: acos_out_mps + tags: pointwise + +# arccos, alias of acos +- func: arccos(Tensor self) -> Tensor + variants: function, method + +- func: arccos_(Tensor(a!) self) -> Tensor(a!) + variants: function, method + +- func: arccos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + +- func: avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor + tags: core + autogen: avg_pool1d.out + +- func: adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor + tags: core + autogen: adaptive_avg_pool1d.out + +# Return: (Tensor output, Tensor indices) +- func: adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor) + +- func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: add.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA, SparseMeta: add_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: add_sparse_csr + MkldnnCPU: mkldnn_add + ZeroTensor: add_zerotensor + NestedTensorCPU, NestedTensorCUDA: NestedTensor_add_Tensor + tags: [core, pointwise] + +- func: add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + variants: method + structured_delegate: add.out + dispatch: + SparseCPU, SparseCUDA, SparseMeta: add_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: add_sparse_csr_ + MkldnnCPU: mkldnn_add_ + NestedTensorCPU, NestedTensorCUDA: NestedTensor_add__Tensor + tags: pointwise + +- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + ufunc_inner_loop: + Generic: add (AllAndComplex, BFloat16, Half, ComplexHalf) + ScalarOnly: add (Bool) + dispatch: + SparseCPU, SparseMeta: add_out_sparse_cpu + SparseCUDA: add_out_sparse_cuda + SparseCsrCPU, SparseCsrMeta: add_out_sparse_compressed_cpu + SparseCsrCUDA: add_out_sparse_compressed_cuda + MkldnnCPU: mkldnn_add_out + MPS: add_out_mps + tags: pointwise + +- func: _add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + variants: function + dispatch: + CPU: add_relu + +- func: _add_relu_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) + variants: function + dispatch: + CPU: add_relu_ + +- func: _add_relu.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + variants: function + dispatch: + CPU: add_relu_out + +- func: _add_relu.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + variants: function + dispatch: + CPU: add_relu + +- func: _add_relu_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) + variants: function + dispatch: + CPU: add_relu_ + autogen: _add_relu.Scalar_out + +# For C++ only, until we have conversion from C++ numbers to Tensor +- func: add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: add + tags: [core, pointwise] + +- func: add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CompositeExplicitAutograd: add_ + autogen: add.Scalar_out + tags: pointwise + +- func: addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor + structured_delegate: addmv.out + variants: function, method + +- func: addmv_(Tensor(a!) self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) + structured_delegate: addmv.out + variants: function, method + +- func: addmv.out(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + structured: True + dispatch: + CPU: addmv_out_cpu + CUDA: addmv_out_cuda + MPS: addmv_out_mps + SparseCsrCPU: addmv_out_sparse_compressed + SparseCsrCUDA: addmv_out_sparse_compressed_cuda + +- func: addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + variants: function, method + dispatch: + CPU, CUDA: addr + MPS: addr_mps + CompositeExplicitAutograd: math_addr + +- func: addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) + variants: method + dispatch: + CompositeExplicitAutograd: addr_ + +- func: addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 
+ dispatch: + CPU, CUDA: addr_out + MPS: addr_out_mps + CompositeExplicitAutograd: math_addr_out + +- func: affine_grid_generator(Tensor theta, SymInt[] size, bool align_corners) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: affine_grid_generator + autogen: affine_grid_generator.out + +- func: affine_grid_generator_backward(Tensor grad, SymInt[] size, bool align_corners) -> Tensor + variants: function + +- func: _is_all_true(Tensor self) -> Tensor + variants: function, method + dispatch: + CompositeExplicitAutograd: _is_all_true + +- func: _is_any_true(Tensor self) -> Tensor + variants: function, method + dispatch: + CompositeExplicitAutograd: _is_any_true + +# Note: this function is only for testing. +- func: _test_check_tensor(Tensor self) -> Tensor + variants: function + +# Note; this function is only for testing +- func: _test_functorch_fallback(Tensor self, Tensor other) -> Tensor + variants: function + dispatch: + CPU: _test_functorch_fallback + autogen: _test_functorch_fallback.out + +- func: all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: all.out + variants: function, method + dispatch: + NestedTensorCPU, NestedTensorCUDA: NestedTensor_all + + +- func: all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: all.dims_out + variants: function, method + cpp_no_default_args: ['dim'] + dispatch: + CompositeExplicitAutograd: all_dims_default + +- func: all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + dispatch: + CPU, CUDA: all_out + MPS: all_out_mps + +- func: all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + dispatch: + CPU, CUDA: all_dims_out + CompositeExplicitAutograd: all_dims_out_default + cpp_no_default_args: ['dim'] + +- func: all.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + +- func: all.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + +- func: allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool + variants: function, method + tags: data_dependent_output + dispatch: + CompositeExplicitAutograd: allclose + +- func: any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: any.out + variants: function, method + tags: core + +- func: any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: any.dims_out + variants: function, method + cpp_no_default_args: ['dim'] + tags: core + dispatch: + CompositeExplicitAutograd: any_dims_default + +- func: any.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + dispatch: + CPU, CUDA: any_out + MPS: any_out_mps + +- func: any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + structured: True + dispatch: + CPU, CUDA: any_dims_out + CompositeExplicitAutograd: any_dims_out_default + cpp_no_default_args: ['dim'] + +- func: any.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + +- func: any.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + +- func: arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: arange + +- func: arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: arange + +# This operator should be named `arange.start_out` if following the naming convention. However that +# name is already taken. Disabled because of CI job failures. +# FIXME: enable this +#- func: arange.start_out_(Scalar start, Scalar end, *, Tensor(a!) out) -> Tensor(a!) +# dispatch: +# CompositeExplicitAutograd: arange_start_out + +- func: arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: arange + cpp_no_default_args: ['step'] + tags: core + +- func: arange.out(Scalar end, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: arange_out + +- func: arange.start_out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, Meta: arange_out + CUDA: arange_cuda_out + MPS: arange_mps_out + cpp_no_default_args: ['step'] + +# This function is a temporary hack to allow tracing of arange like constructs with dynamic +# bounds on arange. Normal arange is not traceable because it does not take any tensor inputs; +# if the range you need is based on another tensor, calling this function directly will +# preserve tracing. Get rid of this when arange can directly take tensors for bounds +# (so that it can be traced directly). +- func: _dim_arange(Tensor like, int dim) -> Tensor + +- func: argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor + structured_delegate: argmax.out + device_check: NoCheck # TensorIterator + variants: function, method + tags: core + +- func: argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + structured: True + dispatch: + CPU, CUDA: argmax_out + MPS: argmax_out_mps + +- func: argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor + structured_delegate: argmin.out + device_check: NoCheck # TensorIterator + variants: function, method + tags: core + +- func: argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + structured: True + dispatch: + CPU, CUDA: argmin_out + MPS: argmin_out_mps + +- func: acosh(Tensor self) -> Tensor + variants: function, method + structured_delegate: acosh.out + tags: [core, pointwise] + +- func: acosh_(Tensor(a!) self) -> Tensor(a!) + variants: function, method + structured_delegate: acosh.out + tags: pointwise + +- func: acosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: acosh_out + MPS: acosh_out_mps + tags: pointwise +# arccosh, alias for acosh + +- func: arccosh(Tensor self) -> Tensor + variants: function, method + +- func: arccosh_(Tensor(a!) self) -> Tensor(a!) + variants: function, method + +- func: arccosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + +- func: asinh(Tensor self) -> Tensor + variants: function, method + structured_delegate: asinh.out + dispatch: + SparseCPU, SparseCUDA: asinh_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: asinh_sparse_csr + tags: [core, pointwise] + +- func: asinh_(Tensor(a!) self) -> Tensor(a!) + variants: function, method + structured_delegate: asinh.out + dispatch: + SparseCPU, SparseCUDA: asinh_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: asinh_sparse_csr_ + tags: pointwise + +- func: asinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: asinh_out + MPS: asinh_out_mps + SparseCPU, SparseCUDA: asinh_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: asinh_sparse_csr_out + tags: pointwise + +# arcsinh, alias for asinh +- func: arcsinh(Tensor self) -> Tensor + variants: function, method + +- func: arcsinh_(Tensor(a!) self) -> Tensor(a!) + variants: function, method + +- func: arcsinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + +- func: atanh(Tensor self) -> Tensor + structured_delegate: atanh.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: atanh_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: atanh_sparse_csr + tags: [core, pointwise] + +- func: atanh_(Tensor(a!) self) -> Tensor(a!) + structured_delegate: atanh.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: atanh_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: atanh_sparse_csr_ + tags: pointwise + +- func: atanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: atanh_out + MPS: atanh_out_mps + SparseCPU, SparseCUDA: atanh_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: atanh_sparse_csr_out + tags: pointwise +# arctanh, alias for atanh + +- func: arctanh(Tensor self) -> Tensor + variants: function, method + +- func: arctanh_(Tensor(a!) self) -> Tensor(a!) + variants: function, method + +- func: arctanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + +- func: as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a) + variants: function, method + dispatch: + ZeroTensor, CPU, CUDA: as_strided_tensorimpl + Meta: as_strided_tensorimpl_meta_symint + MPS: as_strided_tensorimpl_mps + QuantizedCPU, QuantizedCUDA: as_strided_qtensorimpl + device_check: NoCheck + device_guard: False + tags: core + +- func: as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!) + use_const_ref_for_mutable_tensors: True + variants: function, method + device_check: NoCheck + device_guard: False + tags: inplace_view + dispatch: + CompositeExplicitAutogradNonFunctional: as_strided__symint + +- func: asin(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + structured_delegate: asin.out + dispatch: + SparseCPU, SparseCUDA: asin_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: asin_sparse_csr + tags: [core, pointwise] + +- func: asin_(Tensor(a!) self) -> Tensor(a!) 
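For orientation (illustrative only, public API): `as_strided` above builds a view from an explicit size and stride over the same storage, which is why it needs per-backend TensorImpl handling.

    import torch

    base = torch.arange(6.)
    # Reinterpret the same storage as 2 rows of 3: row stride 3, column stride 1.
    view = torch.as_strided(base, size=(2, 3), stride=(3, 1))
    print(view)
    # tensor([[0., 1., 2.],
    #         [3., 4., 5.]])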
+ device_check: NoCheck # TensorIterator + variants: function, method + structured_delegate: asin.out + dispatch: + SparseCPU, SparseCUDA: asin_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: asin_sparse_csr_ + tags: pointwise + +- func: asin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: asin_out + MPS: asin_out_mps + SparseCPU, SparseCUDA: asin_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: asin_sparse_csr_out + tags: pointwise + +# arcsin, alias of asin +- func: arcsin(Tensor self) -> Tensor + variants: function, method + +- func: arcsin_(Tensor(a!) self) -> Tensor(a!) + variants: function, method + +- func: arcsin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + +- func: atan(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: atan.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: atan_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: atan_sparse_csr + tags: [core, pointwise] + +- func: atan_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: atan.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: atan_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: atan_sparse_csr_ + tags: pointwise + +- func: atan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: atan_out + MPS: atan_out_mps + SparseCPU, SparseCUDA: atan_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: atan_sparse_csr_out + tags: pointwise + +# arctan, alias of atan +- func: arctan(Tensor self) -> Tensor + variants: function, method + +- func: arctan_(Tensor(a!) self) -> Tensor(a!) + variants: function, method + +- func: arctan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + +- func: atleast_1d(Tensor self) -> Tensor + variants: function + +- func: atleast_1d.Sequence(Tensor[] tensors) -> Tensor[] + +- func: atleast_2d(Tensor self) -> Tensor + variants: function + +- func: atleast_2d.Sequence(Tensor[] tensors) -> Tensor[] + variants: function + +- func: atleast_3d(Tensor self) -> Tensor + variants: function + +- func: atleast_3d.Sequence(Tensor[] tensors) -> Tensor[] + variants: function + +- func: baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + variants: function, method + structured_delegate: baddbmm.out + +- func: baddbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) + variants: method + structured_delegate: baddbmm.out + +- func: baddbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + structured: True + variants: function + dispatch: + CPU: baddbmm_out_cpu + CUDA: baddbmm_out_cuda + MPS: baddbmm_out_mps + SparseCsrCUDA: baddbmm_out_sparse_csr_cuda + +- func: bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: bartlett_window + autogen: bartlett_window.out + +- func: bartlett_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: bartlett_window + autogen: bartlett_window.periodic_out + +- func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor + +- func: quantized_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor + dispatch: + QuantizedCPU: quantized_batch_norm + autogen: quantized_batch_norm.out + +- func: _batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int) + +- func: _batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor) + +# Sample bernoulli with values in `self` as probability. +- func: bernoulli(Tensor self, *, Generator? generator=None) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: bernoulli + tags: nondeterministic_seeded + +- func: bernoulli.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: function + tags: nondeterministic_seeded + dispatch: + CPU, CUDA: bernoulli_out + MPS: bernoulli_out_mps + +- func: bernoulli_.Tensor(Tensor(a!) self, Tensor p, *, Generator? generator=None) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + tags: nondeterministic_seeded + dispatch: + CPU, CUDA: bernoulli_ + MPS: bernoulli_mps_ + autogen: bernoulli.Tensor, bernoulli.Tensor_out + +- func: bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + tags: nondeterministic_seeded + dispatch: + CPU, CUDA: bernoulli_ + MPS: bernoulli_mps_ + autogen: bernoulli.float_out + +# Note [bernoulli.p schema] +# We should probably just fix the overload ambiguity by appending a _functional to the C++ API name (BC breaking) +# This out-of-place version isn't used explicitly, but needed by jit. +# There is no default valid on `p` here because it would introduce ambiguity +# with `bernoulli(Tensor self, *, Generator? generator=None)` declaration. +- func: bernoulli.p(Tensor self, float p, *, Generator? generator=None) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + tags: nondeterministic_seeded + dispatch: + CompositeExplicitAutogradNonFunctional: bernoulli + +- func: bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias=None) -> Tensor + +- func: binary_cross_entropy(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor + device_check: NoCheck # TensorIterator + python_module: nn + variants: function + dispatch: + CPU: binary_cross_entropy_cpu + CUDA: binary_cross_entropy_cuda + MPS: binary_cross_entropy_mps + +- func: binary_cross_entropy.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) 
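Relating the bernoulli overloads above to the Python surface (an illustrative sketch, not part of the schema): per the Note [bernoulli.p schema], the out-of-place `p` overload mainly serves the JIT; from Python one typically samples either with per-element probabilities or through the in-place float overload.

    import torch

    torch.manual_seed(0)
    probs = torch.tensor([0.1, 0.5, 0.9])
    a = torch.bernoulli(probs)              # each element uses its own probability
    b = torch.empty(3).bernoulli_(0.25)     # one probability p for every element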
+ device_check: NoCheck # TensorIterator + python_module: nn + variants: function + dispatch: + CPU: binary_cross_entropy_out_cpu + CUDA: binary_cross_entropy_out_cuda + MPS: binary_cross_entropy_out_mps + +- func: binary_cross_entropy_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor + python_module: nn + variants: function + dispatch: + CPU: binary_cross_entropy_backward_cpu + CUDA: binary_cross_entropy_backward_cuda + MPS: binary_cross_entropy_backward_mps + +- func: binary_cross_entropy_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + variants: function + dispatch: + CPU: binary_cross_entropy_backward_out_cpu + CUDA: binary_cross_entropy_backward_out_cuda + MPS: binary_cross_entropy_backward_out_mps + +- func: binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor + device_check: NoCheck # TensorIterator + variants: function + dispatch: + CompositeExplicitAutograd: binary_cross_entropy_with_logits + autogen: binary_cross_entropy_with_logits.out + +- func: bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor + variants: function, method + dispatch: + CPU: _bincount_cpu + CUDA: _bincount_cuda + MPS: _bincount_mps + tags: dynamic_output_shape + autogen: bincount.out + +- func: bitwise_not(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: bitwise_not.out + variants: function, method + tags: [core, pointwise] + +- func: bitwise_not_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: bitwise_not.out + variants: method + tags: pointwise + +- func: bitwise_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: bitwise_not_out + MPS: bitwise_not_out_mps + tags: pointwise + +- func: copysign.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA, MPS: copysign_out + tags: pointwise + +- func: copysign.Tensor(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + structured_delegate: copysign.out + tags: pointwise + +- func: copysign_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + structured_delegate: copysign.out + +- func: copysign.Scalar(Tensor self, Scalar other) -> Tensor + variants: function, method + dispatch: + CompositeExplicitAutograd: copysign + tags: pointwise + +- func: copysign_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + variants: method + dispatch: + CompositeExplicitAutograd: copysign_ + +- func: copysign.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: copysign_out + tags: pointwise + +- func: _lazy_clone(Tensor self) -> Tensor + # Like clone, but the copy takes place lazily, only if either the + # input or the output are written. 
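The `dynamic_output_shape` tag on `bincount` above is easiest to see with a small example (illustrative, public API): the output length depends on the largest value in the input, so the shape cannot be known ahead of time.

    import torch

    x = torch.tensor([1, 1, 3, 0])
    print(torch.bincount(x))    # tensor([1, 2, 0, 1]) -- length is max(x) + 1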
+ variants: function, method + dispatch: + CompositeExplicitAutograd: _lazy_clone + +- func: logical_not(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: logical_not + NestedTensorCPU, NestedTensorCUDA: NestedTensor_logical_not + tags: [core, pointwise] + +- func: logical_not_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CompositeExplicitAutograd: logical_not_ + NestedTensorCPU, NestedTensorCUDA: NestedTensor_logical_not_ + tags: pointwise + +- func: logical_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: logical_not_out + MPS: logical_not_out_mps + tags: pointwise + +- func: logical_xor(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: logical_xor + tags: [core, pointwise] + +- func: logical_xor_(Tensor(a!) self, Tensor other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CompositeExplicitAutograd: logical_xor_ + tags: pointwise + +- func: logical_xor.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: logical_xor_out + MPS: logical_xor_out_mps + tags: pointwise + +- func: logical_and(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: logical_and + tags: [core, pointwise] + +- func: logical_and_(Tensor(a!) self, Tensor other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CompositeExplicitAutograd: logical_and_ + tags: pointwise + +- func: logical_and.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: logical_and_out + MPS: logical_and_out_mps + tags: pointwise + +- func: logical_or(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: logical_or + tags: [core, pointwise] + +- func: logical_or_(Tensor(a!) self, Tensor other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CompositeExplicitAutograd: logical_or_ + tags: pointwise + +- func: logical_or.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: logical_or_out + MPS: logical_or_out_mps + tags: pointwise + +- func: blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: blackman_window + autogen: blackman_window.out + +- func: blackman_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: blackman_window + autogen: blackman_window.periodic_out + +- func: bmm(Tensor self, Tensor mat2) -> Tensor + structured_delegate: bmm.out + variants: function, method + dispatch: + SparseCPU: bmm_sparse_cpu + SparseCUDA: bmm_sparse_cuda + NestedTensorCPU: bmm_nested + NestedTensorCUDA: bmm_nested_cuda + tags: core + +- func: bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) 
+ structured: True + variants: function + dispatch: + CPU: bmm_out_cpu + CUDA: bmm_out_cuda + MPS: bmm_out_mps + SparseCPU: bmm_out_sparse_cpu + SparseCUDA: bmm_out_sparse_cuda + SparseCsrCUDA: bmm_out_sparse_csr_cuda + +- func: broadcast_tensors(Tensor[] tensors) -> Tensor[] + device_check: NoCheck + device_guard: False + +- func: broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a) + variants: function, method + dispatch: + CompositeImplicitAutograd: broadcast_to_symint + +- func: _sparse_broadcast_to(Tensor(a) self, int[] size) -> Tensor(a) + variants: function + dispatch: + SparseCPU, SparseCUDA: sparse_broadcast_to + +- func: cat(Tensor[] tensors, int dim=0) -> Tensor + structured_delegate: cat.out + dispatch: + SparseCPU, SparseCUDA: cat_sparse + QuantizedCPU: cat_quantized_cpu + NestedTensorCPU, NestedTensorCUDA: cat_nested + tags: core + +- func: cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + structured: True + precomputed: + - dim -> int dim, int valid, bool all_contiguous, bool all_same_dtype, bool all_same_sizes_and_stride, MemoryFormat memory_format + dispatch: + CPU: cat_out_cpu + CUDA: cat_out_cuda + MPS: cat_out_mps + QuantizedCPU: cat_out_quantized_cpu + +- func: cat.names(Tensor[] tensors, Dimname dim) -> Tensor + +- func: cat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) + +# alias for torch.cat +- func: concat(Tensor[] tensors, int dim=0) -> Tensor + +- func: concat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + +- func: concat.names(Tensor[] tensors, Dimname dim) -> Tensor + +- func: concat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) + +# alias for torch.cat +- func: concatenate(Tensor[] tensors, int dim=0) -> Tensor + +- func: concatenate.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + +- func: concatenate.names(Tensor[] tensors, Dimname dim) -> Tensor + +- func: concatenate.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) + +- func: block_diag(Tensor[] tensors) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: block_diag + autogen: block_diag.out + +- func: ceil(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: ceil.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: ceil_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: ceil_sparse_csr + tags: [core, pointwise] + +- func: ceil_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: ceil.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: ceil_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: ceil_sparse_csr_ + tags: pointwise + +- func: ceil.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: ceil_out + MPS: ceil_out_mps + SparseCPU, SparseCUDA: ceil_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: ceil_sparse_csr_out + tags: pointwise + +# alias for torch.linalg.multi_dot +- func: chain_matmul(Tensor[] matrices) -> Tensor + variants: function + +# alias for torch.linalg.multi_dot +- func: chain_matmul.out(Tensor[] matrices, *, Tensor(a!) out) -> Tensor(a!) 
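A small check that the alias entries behave as advertised (illustrative; uses only ops declared above): `concat` and `concatenate` are NumPy-style spellings that resolve to the same kernel as `cat`.

    import torch

    parts = [torch.ones(2, 3), torch.zeros(2, 3)]
    c = torch.cat(parts, dim=0)
    assert torch.equal(c, torch.concat(parts, dim=0))        # alias
    assert torch.equal(c, torch.concatenate(parts, dim=0))   # alias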
+ +- func: unsafe_chunk(Tensor self, int chunks, int dim=0) -> Tensor[] + variants: function, method + device_check: NoCheck + device_guard: False + +- func: chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[] + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeImplicitAutograd: chunk + NestedTensorCPU, NestedTensorCUDA: chunk_nested_tensor + +- func: tensor_split.sections(Tensor(a -> *) self, SymInt sections, int dim=0) -> Tensor(a)[] + variants: function, method + dispatch: + CompositeImplicitAutograd: tensor_split_sections_symint + +- func: tensor_split.indices(Tensor(a -> *) self, SymInt[] indices, int dim=0) -> Tensor(a)[] + variants: function, method + dispatch: + CompositeImplicitAutograd: tensor_split_indices_symint + +- func: tensor_split.tensor_indices_or_sections(Tensor(a -> *) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[] + variants: function, method + +- func: clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + cpp_no_default_args: ['min'] + structured_delegate: clamp.out + dispatch: + QuantizedCPU: clamp_quantized_cpu + tags: [core, pointwise] + +- func: clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor + variants: function, method + structured_delegate: clamp.Tensor_out + tags: [core, pointwise] + +- func: clamp_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: function, method + cpp_no_default_args: ['min'] + structured_delegate: clamp.out + tags: pointwise + +- func: clamp_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!) + variants: function, method + structured_delegate: clamp.Tensor_out + tags: pointwise + +- func: clamp.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + cpp_no_default_args: ['min'] + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: clamp_out + MPS: clamp_out_mps + tags: pointwise + +- func: clamp.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: clamp_Tensor_out + MPS: clamp_Tensor_out_mps + tags: pointwise + +- func: clamp_max(Tensor self, Scalar max) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + structured_delegate: clamp_max.out + tags: pointwise + +- func: clamp_max.Tensor(Tensor self, Tensor max) -> Tensor + variants: function, method + structured_delegate: clamp_max.Tensor_out + tags: pointwise + +- func: clamp_max_(Tensor(a!) self, Scalar max) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: function, method + structured_delegate: clamp_max.out + tags: pointwise + +- func: clamp_max_.Tensor(Tensor(a!) self, Tensor max) -> Tensor(a!) + variants: function, method + structured_delegate: clamp_max.Tensor_out + tags: pointwise + +- func: clamp_max.out(Tensor self, Scalar max, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: clamp_max_out + MPS: clamp_max_out_mps + tags: pointwise + +- func: clamp_max.Tensor_out(Tensor self, Tensor max, *, Tensor(a!) out) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: clamp_max_Tensor_out + MPS: clamp_max_Tensor_out_mps + tags: pointwise + +- func: clamp_min(Tensor self, Scalar min) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + structured_delegate: clamp_min.out + tags: pointwise + +- func: clamp_min.Tensor(Tensor self, Tensor min) -> Tensor + variants: function, method + structured_delegate: clamp_min.Tensor_out + tags: pointwise + +- func: clamp_min_(Tensor(a!) self, Scalar min) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: function, method + structured_delegate: clamp_min.out + tags: pointwise + +- func: clamp_min_.Tensor(Tensor(a!) self, Tensor min) -> Tensor(a!) + variants: function, method + structured_delegate: clamp_min.Tensor_out + tags: pointwise + +- func: clamp_min.out(Tensor self, Scalar min, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: clamp_min_out + MPS: clamp_min_out_mps + tags: pointwise + +- func: clamp_min.Tensor_out(Tensor self, Tensor min, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: clamp_min_Tensor_out + MPS: clamp_min_Tensor_out_mps + tags: pointwise + +# clip is an alias for clamp +- func: clip(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor + cpp_no_default_args: ['min'] + variants: function, method + tags: pointwise + +- func: clip.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor + variants: function, method + tags: pointwise + +- func: clip_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!) + cpp_no_default_args: ['min'] + variants: function, method + tags: pointwise + +- func: clip_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!) + variants: function, method + tags: pointwise + +- func: clip.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!) + cpp_no_default_args: ['min'] + tags: pointwise + +- func: clip.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!) + +- func: cudnn_is_acceptable(Tensor self) -> bool + device_check: NoCheck + device_guard: False + +- func: complex(Tensor real, Tensor imag) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: complex + +- func: complex.out(Tensor real, Tensor imag, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: complex_out + MPS: complex_out_mps + +- func: polar(Tensor abs, Tensor angle) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: polar + +- func: polar.out(Tensor abs, Tensor angle, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: polar_out + MPS: polar_out_mps + +- func: constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: constant_pad_nd + MPS: constant_pad_nd_mps + autogen: constant_pad_nd.out + tags: core + +- func: contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a) + variants: method + manual_cpp_binding: True + +- func: convolution(Tensor input, Tensor weight, Tensor? 
bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor + dispatch: + CompositeExplicitAutograd: convolution + autogen: convolution.out + tags: core + +- func: convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + dispatch: + CompositeExplicitAutograd, CUDA: convolution_backward + autogen: convolution_backward.out + tags: core + +- func: convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor + dispatch: + CompositeExplicitAutograd: convolution_overrideable + autogen: convolution_overrideable.out + +- func: convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) + dispatch: + CompositeExplicitAutograd: convolution_backward_overrideable + autogen: convolution_backward_overrideable.out + +- func: _convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor + dispatch: + CompositeExplicitAutograd: _convolution + autogen: _convolution.out + +- func: _convolution.deprecated(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, int[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor + +- func: _convolution_mode(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, str padding, SymInt[] dilation, SymInt groups) -> Tensor + dispatch: + CompositeImplicitAutograd: _convolution_mode_symint + +- func: _convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + +- func: conv1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] dilation=1, SymInt groups=1) -> Tensor + dispatch: + CompositeImplicitAutograd: conv1d_symint + +- func: conv2d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, SymInt groups=1) -> Tensor + dispatch: + CompositeImplicitAutograd: conv2d_symint + +- func: conv3d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, SymInt groups=1) -> Tensor + dispatch: + CompositeImplicitAutograd: conv3d_symint + +- func: conv1d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, str padding="valid", SymInt[1] dilation=1, SymInt groups=1) -> Tensor + cpp_no_default_args: ['bias', 'stride', 'padding'] + dispatch: + CompositeImplicitAutograd: conv1d_padding_symint + +- func: conv2d.padding(Tensor input, Tensor weight, Tensor? 
bias=None, SymInt[2] stride=1, str padding="valid", SymInt[2] dilation=1, SymInt groups=1) -> Tensor + cpp_no_default_args: ['bias', 'stride', 'padding'] + dispatch: + CompositeImplicitAutograd: conv2d_padding_symint + +- func: conv3d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, str padding="valid", SymInt[3] dilation=1, SymInt groups=1) -> Tensor + cpp_no_default_args: ['bias', 'stride', 'padding'] + dispatch: + CompositeImplicitAutograd: conv3d_padding_symint + +- func: conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad=0) -> Tensor + dispatch: + CompositeExplicitAutograd: conv_tbc + autogen: conv_tbc.out + +- func: conv_tbc_backward(Tensor self, Tensor input, Tensor weight, Tensor bias, int pad) -> (Tensor, Tensor, Tensor) + +# NB: we inherit the goofy argument order from PyTorch torch.nn.functional +- func: conv_transpose1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] output_padding=0, SymInt groups=1, SymInt[1] dilation=1) -> Tensor + dispatch: + CompositeImplicitAutograd: conv_transpose1d_symint + +- func: conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt groups=1, SymInt[2] dilation=1) -> Tensor + dispatch: + CompositeImplicitAutograd: conv_transpose2d_symint + +- func: conv_transpose3d.input(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt groups=1, SymInt[3] dilation=1) -> Tensor + dispatch: + CompositeImplicitAutograd: conv_transpose3d_symint + +- func: copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor + variants: function + dispatch: + Meta: copy_meta + CompositeExplicitAutogradNonFunctional: copy + tags: core + +- func: copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) + variants: method + device_check: NoCheck + device_guard: False + dispatch: + MkldnnCPU: copy_mkldnn_ + SparseCPU, SparseCUDA: copy_sparse_wrapper_ + CompositeExplicitAutograd: copy_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: copy_sparse_compressed_ + NestedTensorCPU, NestedTensorCUDA: copy_nested_ + autogen: copy.out + +- func: _copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor + dispatch: + MPS: _copy_from_mps + autogen: _copy_from.out + +# We need this to be able to properly copy from a CPU to an XLA tensor with different sizes. +# See https://github.com/pytorch/xla/issues/2881 +- func: _copy_from_and_resize(Tensor self, Tensor dst) -> Tensor + dispatch: + MPS: _copy_from_and_resize_mps + autogen: _copy_from_and_resize.out + +- func: cos(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + structured_delegate: cos.out + dispatch: + NestedTensorCPU, NestedTensorCUDA: cos_nested + tags: [core, pointwise] + +- func: cos_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: function, method + structured_delegate: cos.out + tags: pointwise + +- func: cos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: cos_out + MPS: cos_out_mps + tags: pointwise + +- func: cosh(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + structured_delegate: cosh.out + tags: [core, pointwise] + +- func: cosh_(Tensor(a!) self) -> Tensor(a!) 
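The `conv*.padding` overloads above accept a string padding mode instead of explicit sizes; in the functional API this is `padding="same"` or `"valid"` (a minimal sketch, not part of the schema; "same" requires the default stride of 1):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 8, 8)
    w = torch.randn(6, 3, 3, 3)
    y = F.conv2d(x, w, padding="same")
    print(y.shape)               # torch.Size([1, 6, 8, 8])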
+ device_check: NoCheck # TensorIterator + variants: function, method + structured_delegate: cosh.out + tags: pointwise + +- func: cosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: cosh_out + MPS: cosh_out_mps + tags: pointwise + +- func: cosine_embedding_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor + +- func: count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor + variants: function, method + dispatch: + CPU: count_nonzero_cpu + CUDA: count_nonzero_cuda + MPS: count_nonzero_mps + autogen: count_nonzero.dim_IntList_out + +- func: count_nonzero(Tensor self, int? dim=None) -> Tensor + variants: function, method + dispatch: + CompositeExplicitAutograd: count_nonzero + autogen: count_nonzero.out + +- func: cov(Tensor self, *, int correction=1, Tensor? fweights=None, Tensor? aweights=None) -> Tensor + variants: function, method + +- func: corrcoef(Tensor self) -> Tensor + variants: function, method + +- func: cudnn_affine_grid_generator(Tensor theta, int N, int C, int H, int W) -> Tensor grid + dispatch: + CUDA: cudnn_affine_grid_generator_forward + autogen: cudnn_affine_grid_generator.out + +# TODO: Why do I have to call this grad?! +- func: cudnn_affine_grid_generator_backward(Tensor grad, int N, int C, int H, int W) -> Tensor grad_theta + dispatch: + CUDA: cudnn_affine_grid_generator_backward + autogen: cudnn_affine_grid_generator_backward.out + +- func: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor) + dispatch: + CUDA: cudnn_batch_norm + autogen: cudnn_batch_norm.out + +# NB: You can only use this if you used cudnn_batch_norm training=True +- func: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor) + dispatch: + CUDA: cudnn_batch_norm_backward + autogen: cudnn_batch_norm_backward.out + +- func: cudnn_convolution(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor + dispatch: + CUDA: cudnn_convolution + +- func: cudnn_convolution.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) 
+ dispatch: + CUDA: cudnn_convolution_out + +- func: cudnn_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor + dispatch: + CUDA: cudnn_convolution_transpose + autogen: cudnn_convolution_transpose.out + +- func: _mps_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + dispatch: + MPS: _mps_convolution_transpose + autogen: _mps_convolution_transpose.out + +- func: mps_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask) -> (Tensor, Tensor) + dispatch: + MPS: mps_convolution_transpose_backward + autogen: mps_convolution_transpose_backward.out + +- func: cudnn_convolution_relu(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor + dispatch: + CUDA: cudnn_convolution_relu + autogen: cudnn_convolution_relu.out + +- func: cudnn_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor + dispatch: + CUDA: cudnn_convolution_add_relu + autogen: cudnn_convolution_add_relu.out + +# NB: input is special cased in a way I don't quite understand +- func: cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output + dispatch: + CUDA: cudnn_grid_sampler_forward + autogen: cudnn_grid_sampler.out + +- func: cudnn_grid_sampler_backward(Tensor self, Tensor grid, Tensor grad_output) -> (Tensor grad_self, Tensor grad_grid) + dispatch: + CUDA: cudnn_grid_sampler_backward + autogen: cudnn_grid_sampler_backward.out + +- func: cummax(Tensor self, int dim) -> (Tensor values, Tensor indices) + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: cummax + +- func: cummax.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + device_check: NoCheck # TensorIterator + dispatch: + CompositeExplicitAutograd: cummax_out + +- func: cummax.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices) + device_check: NoCheck # TensorIterator + variants: function, method + +- func: cummax.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + device_check: NoCheck # TensorIterator + +- func: _cummax_helper(Tensor self, Tensor(a!) values, Tensor(b!) indices, int dim) -> () + variants: function + dispatch: + CPU: cummax_helper_cpu + CUDA: cummax_helper_cuda + +- func: cummin(Tensor self, int dim) -> (Tensor values, Tensor indices) + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: cummin + +- func: cummin.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + device_check: NoCheck # TensorIterator + dispatch: + CompositeExplicitAutograd: cummin_out + +- func: cummin.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices) + device_check: NoCheck # TensorIterator + variants: function, method + +- func: cummin.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) 
indices) + device_check: NoCheck # TensorIterator + +- func: _cummin_helper(Tensor self, Tensor(a!) values, Tensor(b!) indices, int dim) -> () + variants: function + dispatch: + CPU: cummin_helper_cpu + CUDA: cummin_helper_cuda + +- func: cummaxmin_backward(Tensor grad, Tensor input, Tensor indices, int dim) -> Tensor + variants: function + device_check: NoCheck + device_guard: False + +- func: cumprod(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor + structured_delegate: cumprod.out + device_check: NoCheck # TensorIterator + variants: function, method + +- func: cumprod_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!) + structured_delegate: cumprod.out + variants: method + +- func: cumprod.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + structured: True + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: cumprod_out + MPS: cumprod_out_mps + +- func: cumprod.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + +- func: cumprod_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!) + variants: method + +- func: cumprod.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + +- func: cumprod_backward(Tensor grad, Tensor input, int dim, Tensor output) -> Tensor + variants: function + device_check: NoCheck + device_guard: False + +- func: cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor + structured_delegate: cumsum.out + device_check: NoCheck # TensorIterator + variants: function, method + tags: core + +- func: cumsum_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!) + structured_delegate: cumsum.out + variants: method + +- func: cumsum.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + structured: True + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: cumsum_out + MPS: cumsum_out_mps + +- func: cumsum.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + +- func: cumsum_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!) + variants: method + +- func: cumsum.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) 
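The cumulative scans declared above in one short example (illustrative, public API): each reduces along a single dimension, and `cummax`/`cummin` additionally return the index of the running extremum.

    import torch

    x = torch.tensor([1., 3., 2.])
    print(torch.cumsum(x, dim=0))           # tensor([1., 4., 6.])
    print(torch.cumprod(x, dim=0))          # tensor([1., 3., 6.])
    print(torch.cummax(x, dim=0).values)    # tensor([1., 3., 3.])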
+ device_check: NoCheck # TensorIterator + +- func: cumulative_trapezoid.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor + +- func: cumulative_trapezoid.dx(Tensor y, *, Scalar dx=1, int dim=-1) -> Tensor + +- func: ctc_loss.IntList(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor + +# convenience function that converts to intlists for you +- func: ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor + +- func: _ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor) + dispatch: + CPU: ctc_loss_cpu + CUDA: ctc_loss_gpu + Meta: ctc_loss_meta + autogen: _ctc_loss.out + tags: dynamic_output_shape # the shape of second output is data dependent + +- func: _ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor) + dispatch: + CPU, CUDA: ctc_loss_tensor + autogen: _ctc_loss.Tensor_out + tags: dynamic_output_shape # the shape of second output is data dependent + +- func: _ctc_loss_backward(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor + dispatch: + CPU: ctc_loss_backward_cpu + CUDA: ctc_loss_backward_gpu + autogen: _ctc_loss_backward.out + +- func: _ctc_loss_backward.Tensor(Tensor grad, Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor + dispatch: + CPU, CUDA: ctc_loss_backward_tensor + +- func: diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor + variants: function, method + dispatch: + CompositeExplicitAutogradNonFunctional: diag_embed + autogen: diag_embed.out + +- func: diagflat(Tensor self, int offset=0) -> Tensor + variants: function, method + +- func: diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a) + variants: function, method + dispatch: + CompositeExplicitAutograd: diagonal + tags: core + +- func: linalg_diagonal(Tensor(a) A, *, int offset=0, int dim1=-2, int dim2=-1) -> Tensor(a) + python_module: linalg + variants: function + +- func: diagonal.Dimname(Tensor(a) self, *, Dimname outdim, Dimname dim1, Dimname dim2, int offset=0) -> Tensor(a) + variants: function, method + +- func: diagonal_backward(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2) -> Tensor + variants: function + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: diagonal_backward_symint + autogen: diagonal_backward.out + +- func: fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!) + variants: method + +- func: diff(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None) -> Tensor + variants: function, method + +- func: diff.out(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None, *, Tensor(a!) out) -> Tensor(a!) + variants: function + +- func: gradient.scalarint(Tensor self, *, Scalar? spacing=None, int? 
dim=None, int edge_order=1) -> Tensor[] + variants: function + +- func: gradient.scalararray(Tensor self, *, Scalar spacing, int[] dim, int edge_order=1) -> Tensor[] + variants: function + +- func: gradient.array(Tensor self, *, int[] dim, int edge_order=1) -> Tensor[] + variants: function + +- func: gradient.scalarrayint(Tensor self, *, Scalar[] spacing, int? dim=None, int edge_order=1) -> Tensor[] + variants: function + +- func: gradient.scalarrayarray(Tensor self, *, Scalar[] spacing, int[] dim, int edge_order=1) -> Tensor[] + variants: function + +- func: gradient.tensorarrayint(Tensor self, *, Tensor[] spacing, int? dim=None, int edge_order=1) -> Tensor[] + variants: function + +- func: gradient.tensorarray(Tensor self, *, Tensor[] spacing, int[] dim, int edge_order=1) -> Tensor[] + variants: function + +- func: div.Tensor(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + structured_delegate: div.out + dispatch: + SparseCPU, SparseCUDA: div_sparse + ZeroTensor: div_zerotensor + NestedTensorCPU, NestedTensorCUDA: NestedTensor_div_Tensor + tags: [core, pointwise] + +- func: div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + structured_delegate: div.out + dispatch: + SparseCPU, SparseCUDA: div_sparse_ + tags: pointwise + +- func: div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: div_out + MPS: div_out_mps + SparseCPU, SparseCUDA: div_out_sparse_zerodim + tags: pointwise + +- func: div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + structured_delegate: div.out_mode + dispatch: + SparseCPU, SparseCUDA: div_sparse + tags: [core, pointwise] + +- func: div_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + structured_delegate: div.out_mode + dispatch: + SparseCPU, SparseCUDA: div_sparse_ + tags: pointwise + +- func: div.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: div_out_mode + MPS: div_out_mode_mps + SparseCPU, SparseCUDA: div_out_sparse_zerodim + tags: pointwise + +# For C++ only, until we have conversion from C++ numbers to Tensor +- func: div.Scalar(Tensor self, Scalar other) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: div + NestedTensorCPU, NestedTensorCUDA: NestedTensor_div_Scalar + tags: [core, pointwise] + +- func: div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CompositeExplicitAutograd: div_ + autogen: div.Scalar_out + tags: pointwise + +- func: div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor + variants: function, method + dispatch: + CompositeExplicitAutograd: div + tags: [core, pointwise] + +- func: div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!) 
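How the `*_mode` overloads above surface in Python (illustrative only): the optional `rounding_mode` is what expresses floor and truncating division on top of the same div kernel.

    import torch

    a = torch.tensor([7., -7.])
    print(torch.div(a, 2))                          # tensor([ 3.5000, -3.5000])
    print(torch.div(a, 2, rounding_mode="trunc"))   # tensor([ 3., -3.])
    print(torch.div(a, 2, rounding_mode="floor"))   # tensor([ 3., -4.])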
+ variants: method + dispatch: + CompositeExplicitAutograd: div_ + autogen: div.Scalar_mode_out + tags: pointwise + +# divide, alias for div +- func: divide.Tensor(Tensor self, Tensor other) -> Tensor + variants: function, method + +- func: divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + variants: method + +- func: divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + +- func: divide.Scalar(Tensor self, Scalar other) -> Tensor + variants: function, method + +- func: divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + variants: method + +- func: divide.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor + variants: function, method + +- func: divide_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!) + variants: method + +- func: divide.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!) + +- func: divide.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor + variants: function, method + +- func: divide_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!) + variants: method + + # true_divide, an alias for div +- func: true_divide.Tensor(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + tags: pointwise + +- func: true_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + +- func: true_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + +- func: true_divide.Scalar(Tensor self, Scalar other) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + +- func: true_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + +- func: dot(Tensor self, Tensor tensor) -> Tensor + variants: function, method + dispatch: + CPU: dot + CUDA: dot_cuda + MPS: dot_mps + +- func: dot.out(Tensor self, Tensor tensor, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: dot_out + +- func: vdot(Tensor self, Tensor other) -> Tensor + variants: function, method + dispatch: + CPU: vdot + CUDA: vdot_cuda + +- func: vdot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: vdot_out + +- func: einsum(str equation, Tensor[] tensors, *, int[]? path=None) -> Tensor + +- func: embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor + dispatch: + CompositeExplicitAutograd: embedding_symint + NestedTensorCPU, NestedTensorCUDA: NestedTensor_embedding + autogen: embedding.out + tags: core + +- func: embedding_backward(Tensor grad, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor + dispatch: + CompositeImplicitAutograd: embedding_backward_symint + +- func: embedding_dense_backward(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq) -> Tensor + dispatch: + CPU: embedding_dense_backward_cpu + CUDA: embedding_dense_backward_cuda + MPS: embedding_dense_backward_mps + autogen: embedding_dense_backward.out + tags: core + +- func: embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!) 
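Worth noting for the `embedding` entry above (an illustrative sketch, not part of the schema): the native signature takes `(weight, indices)`, while the Python functional wrapper takes `(input, weight)`.

    import torch
    import torch.nn.functional as F

    weight = torch.randn(10, 4)               # 10 embeddings of dimension 4
    indices = torch.tensor([[1, 2], [4, 9]])
    out = F.embedding(indices, weight)
    print(out.shape)                           # torch.Size([2, 2, 4])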
+ dispatch: + CPU: embedding_renorm_cpu_ + CUDA: embedding_renorm_cuda_ + autogen: embedding_renorm, embedding_renorm.out + +- func: embedding_sparse_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor + +# NOTE [ embedding_bag Native Functions ] +# The `_embedding_bag.*` variants assume that input tensors except for `weight`, +# e.g. `indices` and `offsets` (and `offset2bag`), are contiguous. +# We really only need to enforce this for `_embedding_bag` (the forward) because +# the backward inputs are the same as forward ones. +# The above `embedding_bag` wrapper is created to achieve this, e.g., +# applying indices = indices.contiguous(). +# The backward functions apply a check that these input tensors are contiguous. + + +- func: _embedding_bag_forward_only(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor) + dispatch: + CPU: _embedding_bag_forward_only_cpu + CUDA: _embedding_bag_forward_only_cuda + autogen: _embedding_bag_forward_only.out + +- func: _rowwise_prune(Tensor weight, Tensor mask, ScalarType compressed_indices_dtype) -> (Tensor, Tensor) + +# row_stack is the alias of vstack +- func: row_stack(Tensor[] tensors) -> Tensor + +- func: row_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + +- func: embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor) + +# To keep backward and forward compatibility, and to avoid ambiguity with the +# original signature above, scale_grad_by_freq, mode, sparse, +# per_sample_weights, and include_last_offset parameters do not have default +# values. Once the original signature is removed, default values can be added. +- func: embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor) + +- func: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor) + dispatch: + CPU: _embedding_bag_cpu + CUDA: _embedding_bag_cuda + autogen: _embedding_bag.out + tags: core + +- func: _embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor + dispatch: + CPU, CUDA: _embedding_bag_backward_symint + +- func: _embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor + dispatch: + CompositeImplicitAutograd: _embedding_bag_sparse_backward_symint + +- func: _embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? 
per_sample_weights, int padding_idx=-1) -> Tensor + dispatch: + CPU: _embedding_bag_dense_backward_cpu + CUDA: _embedding_bag_dense_backward_cuda + autogen: _embedding_bag_dense_backward.out + +- func: _embedding_bag_per_sample_weights_backward(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode, int padding_idx=-1) -> Tensor + dispatch: + CPU: _embedding_bag_per_sample_weights_backward_cpu + CUDA: _embedding_bag_per_sample_weights_backward_cuda + autogen: _embedding_bag_per_sample_weights_backward.out + +- func: empty.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: empty_names + autogen: empty.names_out + +- func: empty.memory_format(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + dispatch: + CPU: empty_cpu + CUDA: empty_cuda + MPS: empty_mps + Meta: empty_meta_symint + MkldnnCPU: empty_mkldnn + SparseCPU, SparseCUDA: empty_sparse + SparseMeta: empty_sparse_symint + SparseCsrCPU, SparseCsrCUDA: empty_sparse_compressed + SparseCsrMeta: empty_sparse_compressed_symint + QuantizedCPU, QuantizedCUDA, QuantizedMeta: empty_unknown_quantized + tags: core + +- func: empty_permuted(SymInt[] size, int[] physical_layout, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: empty_permuted_symint + autogen: empty_permuted.out + +# We do not make new_empty a composite that calls into new_empty_strided, as the strided version +# is significantly more difficult to implement by different backends +- func: new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + variants: method + dispatch: + CompositeExplicitAutograd: new_empty_symint + autogen: new_empty.out + +- func: new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + variants: method + dispatch: + CompositeExplicitAutogradNonFunctional: new_empty_strided_symint + autogen: new_empty_strided.out + +- func: new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + variants: method + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd: new_full + autogen: new_full.out + +- func: new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + variants: method + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd: new_zeros + autogen: new_zeros.out + +- func: new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + variants: method + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd: new_ones + autogen: new_ones.out + +# other overrides are to provide a more helpful error message that dtype is required +- func: _empty_affine_quantized(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor + dispatch: + CPU: empty_affine_quantized_other_backends_stub + QuantizedCPU, QuantizedCUDA: empty_affine_quantized + autogen: _empty_affine_quantized.out + +# it's a factory function receiving a tensor argument, thus overriding explicitly +# other overrides are to provide a more helpful error message that dtype is required +- func: _empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor + category_override: factory + dispatch: + CPU: empty_per_channel_affine_quantized_other_backends_stub + QuantizedCPU, QuantizedCUDA: empty_per_channel_affine_quantized + autogen: _empty_per_channel_affine_quantized.out + +- func: resize_(Tensor(a!) self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!) + use_const_ref_for_mutable_tensors: True + variants: method + device_check: NoCheck + device_guard: False + tags: [core, inplace_view] + dispatch: + Meta: resize__symint + CPU: resize_ + CUDA: resize_cuda_ + MPS: resize_mps_ + QuantizedCPU: quantized_resize_cpu_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: resize_sparse_csr_ + autogen: resize, resize.out + +# This is a utility function to enable users to resize out tensor while registering kernels for out variants. +# Eventually, we can consider exposing `resize_output` as a public API to ship it with python op registration +# to make it easy to register out variants for ops. +- func: _resize_output_(Tensor(a!) self, SymInt[] size, Device device) -> Tensor(a!) + use_const_ref_for_mutable_tensors: True + variants: function + dispatch: + Meta: _resize_output_ + autogen: _resize_output, _resize_output.out + +- func: empty_quantized(int[] size, Tensor qtensor, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + category_override: factory + variants: function + dispatch: + QuantizedCPU, QuantizedCUDA: empty_quantized + autogen: empty_quantized.out + +- func: empty.out(SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck + device_guard: False + +- func: empty_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: empty_like + QuantizedCPU, QuantizedCUDA: empty_like_quantized + SparseCPU, SparseCUDA, SparseMeta: empty_like_sparse_coo + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: empty_like_sparse_csr + NestedTensorCPU, NestedTensorCUDA: empty_like_nested + autogen: empty_like.out + +- func: empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + dispatch: + CPU: empty_strided_cpu + CUDA: empty_strided_cuda + MPS: empty_strided_mps + Meta: empty_strided_meta_symint + QuantizedCPU, QuantizedCUDA: empty_strided_unknown_quantized + autogen: empty_strided.out + tags: core + +- func: erf(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: erf.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: erf_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erf_sparse_csr + tags: [core, pointwise] + +- func: erf_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: erf.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: erf_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erf_sparse_csr_ + tags: pointwise + +- func: erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: erf_out + MPS: erf_out_mps + SparseCPU, SparseCUDA: erf_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erf_sparse_csr_out + tags: pointwise + +- func: erfc(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: erfc.out + variants: function, method + tags: pointwise + +- func: erfc_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: erfc.out + variants: function, method + tags: pointwise + +- func: erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: erfc_out + tags: pointwise + +- func: exp(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: exp.out + variants: function, method + tags: [core, pointwise] + +- func: exp_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: exp.out + variants: function, method + tags: pointwise + +- func: exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: exp_out + MPS: exp_out_mps + tags: pointwise + +- func: exp2(Tensor self) -> Tensor + structured_delegate: exp2.out + variants: function, method + tags: pointwise + +- func: exp2_(Tensor(a!) self) -> Tensor(a!) + structured_delegate: exp2.out + variants: function, method + tags: pointwise + +- func: exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: exp2_out + MPS: exp2_out_mps + tags: pointwise + +- func: expm1(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: expm1.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: expm1_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: expm1_sparse_csr + tags: [core, pointwise] + +- func: expm1_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: expm1.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: expm1_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: expm1_sparse_csr_ + tags: pointwise + +- func: expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: expm1_out + MPS: expm1_out_mps + SparseCPU, SparseCUDA: expm1_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: expm1_sparse_csr_out + tags: pointwise + +- func: expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) + variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: expand + tags: core + +- func: expand_as(Tensor(a) self, Tensor other) -> Tensor(a) + variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. + device_check: NoCheck + device_guard: False + +# decomposes to eye.m +- func: eye(SymInt n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: eye + +- func: eye.m(SymInt n, SymInt m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: eye + +- func: eye.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, Meta: eye_out_cpu + CUDA: eye_out_cuda + MPS: eye_out_mps + +- func: eye.m_out(SymInt n, SymInt m, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, Meta: eye_out_cpu + CUDA: eye_out_cuda + MPS: eye_out_mps + +- func: flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a) + variants: function, method + +- func: flatten.named_out_dim(Tensor(a) self, int start_dim, int end_dim, Dimname out_dim) -> Tensor(a) + variants: function, method + +- func: flatten.using_names(Tensor(a) self, Dimname start_dim, Dimname end_dim, Dimname out_dim) -> Tensor(a) + variants: function, method + +- func: flatten.DimnameList(Tensor(a) self, Dimname[] dims, Dimname out_dim) -> Tensor(a) + variants: function, method + +- func: unflatten.int(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a) + variants: function, method + dispatch: + CompositeImplicitAutograd: unflatten_symint + +- func: unflatten.Dimname(Tensor(a) self, Dimname dim, SymInt[] sizes, Dimname[] names) -> Tensor(a) + variants: function, method + dispatch: + CompositeImplicitAutograd: unflatten_dimname_symint + +- func: fill.Scalar(Tensor self, Scalar value) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: fill + tags: core + +- func: fill.Tensor(Tensor self, Tensor value) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: fill + +- func: fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CPU, CUDA: fill_ + MPS: fill_scalar_mps + QuantizedCPU, QuantizedCUDA: fill_quantized_ + Meta: fill_meta_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: fill_sparse_csr_ + NestedTensorCPU, NestedTensorCUDA: fill_nested_ + autogen: fill.Scalar_out + +- func: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CPU, CUDA: fill_ + MPS: fill_tensor_mps_ + QuantizedCPU, QuantizedCUDA: fill_quantized_ + Meta: fill_meta_ + NestedTensorCPU, NestedTensorCUDA: fill_nested_ + autogen: fill.Tensor_out + +- func: floor(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: floor.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: floor_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: floor_sparse_csr + tags: [core, pointwise] + +- func: floor_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: floor.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: floor_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: floor_sparse_csr_ + tags: pointwise + +- func: floor.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: floor_out + MPS: floor_out_mps + SparseCPU, SparseCUDA: floor_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: floor_sparse_csr_out + tags: pointwise + +- func: floor_divide(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CPU, CUDA: floor_divide + MPS: floor_divide_mps + SparseCPU, SparseCUDA: floor_divide_sparse + +- func: floor_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CPU, CUDA: floor_divide_ + MPS: floor_divide_mps_ + SparseCPU, SparseCUDA: floor_divide_sparse_ + +- func: floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: floor_divide_out + MPS: floor_divide_out_mps + SparseCPU, SparseCUDA: floor_divide_out_sparse_zerodim + +- func: floor_divide.Scalar(Tensor self, Scalar other) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: floor_divide + +- func: floor_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CompositeExplicitAutograd: floor_divide_ + autogen: floor_divide.Scalar_out + +- func: frac(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: frac.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: frac_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: frac_sparse_csr + tags: pointwise + +- func: frac_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: frac.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: frac_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: frac_sparse_csr_ + tags: pointwise + +- func: frac.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: frac_out + MPS: frac_out_mps + SparseCPU, SparseCUDA: frac_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: frac_sparse_csr_out + tags: pointwise + +- func: full.names(int[] size, Scalar fill_value, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: full + autogen: full.names_out + +- func: full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: full + tags: core + +- func: full.out(SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: full_out + +- func: full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd: full_like + autogen: full_like.out + +- func: from_file(str filename, bool? shared=None, int? size=0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CPU: from_file + autogen: from_file.out + +- func: gcd.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: gcd_out + tags: pointwise + +- func: gcd(Tensor self, Tensor other) -> Tensor + structured_delegate: gcd.out + variants: function, method + tags: pointwise + +- func: gcd_(Tensor(a!) self, Tensor other) -> Tensor(a!) + structured_delegate: gcd.out + variants: function, method + +- func: lcm.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: lcm_out + tags: pointwise + +- func: lcm(Tensor self, Tensor other) -> Tensor + structured_delegate: lcm.out + variants: function, method + tags: pointwise + +- func: lcm_(Tensor(a!) self, Tensor other) -> Tensor(a!) + structured_delegate: lcm.out + variants: function, method + +# NOTE [ grid_sampler Native Functions ] +# `grid_sampler` is _supposed to_ do all the shape checking and then dispatch to +# one of `cudnn_grid_sampler`, `grid_sampler_2d`, or `grid_sampler_3d`, each of +# which has the corresponding backward defined as native functions as well. +# However, we do shape checking everywhere for now since each of the mentioned +# functions can be called directly, which will lead to crashes otherwise. +# See https://github.com/pytorch/pytorch/issues/73187 for more information. +# +# There is also _grid_sampler_2d_cpu_fallback which is an +# implementation detail of grid_sampler_2d and is only exposed here for testing +# purposes. +# +# Additionally, arguments `padding_mode` and `interpolation_mode` are cast to +# enums defined in `native/GridSampler.h`. `cudnn_grid_sampler` doesn't take in +# `interpolation_mode` because it only supports Bilinear interpolation mode. +# Nor does it take in `align_corners` because it only supports the mode +# `align_corners = True`. +- func: grid_sampler(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + +- func: grid_sampler_2d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + dispatch: + CPU, QuantizedCPU: grid_sampler_2d_cpu + CUDA: grid_sampler_2d_cuda + MPS: grid_sampler_2d_mps + autogen: grid_sampler_2d.out + tags: core + +# `grid_sampler_2d_backward` takes in `output_mask` to optimize performance for +# the case where `input` doesn't require gradient. 
Gradient for `grid` is always +# computed (only `output_mask[0]` is checked by the implementations). +- func: grid_sampler_2d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask) -> (Tensor, Tensor) + dispatch: + CPU: grid_sampler_2d_backward_cpu + CUDA: grid_sampler_2d_backward_cuda + autogen: grid_sampler_2d_backward.out + +# See NOTE [ grid_sample CPU fallback ] +- func: _grid_sampler_2d_cpu_fallback(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + dispatch: + CompositeExplicitAutograd: _grid_sampler_2d_cpu_fallback + autogen: _grid_sampler_2d_cpu_fallback.out + +- func: _grid_sampler_2d_cpu_fallback_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> (Tensor, Tensor) + +- func: grid_sampler_3d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + dispatch: + CPU: grid_sampler_3d_cpu + CUDA: grid_sampler_3d_cuda + autogen: grid_sampler_3d.out + +# `grid_sampler_3d_backward` takes in `output_mask` to optimize performance for +# the case where `input` doesn't require gradient. Gradient for `grid` is always +# computed (only `output_mask[0]` is checked by the implementations). +- func: grid_sampler_3d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask) -> (Tensor, Tensor) + dispatch: + CPU: grid_sampler_3d_backward_cpu + CUDA: grid_sampler_3d_backward_cuda + autogen: grid_sampler_3d_backward.out + +- func: hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: hann_window + autogen: hann_window.out + +- func: hann_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: hann_window + autogen: hann_window.periodic_out + +- func: hamming_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: hamming_window + autogen: hamming_window.out + +- func: hamming_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: hamming_window + autogen: hamming_window.periodic_out + +- func: hamming_window.periodic_alpha(int window_length, bool periodic, float alpha, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: hamming_window + autogen: hamming_window.periodic_alpha_out + +- func: hamming_window.periodic_alpha_beta(int window_length, bool periodic, float alpha, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: hamming_window + autogen: hamming_window.periodic_alpha_beta_out + +- func: kaiser_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: kaiser_window + autogen: kaiser_window.out + +- func: kaiser_window.periodic(int window_length, bool periodic, *, ScalarType? 
dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: kaiser_window + autogen: kaiser_window.periodic_out + +- func: kaiser_window.beta(int window_length, bool periodic, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: kaiser_window + autogen: kaiser_window.beta_out + +- func: hinge_embedding_loss(Tensor self, Tensor target, float margin=1.0, int reduction=Mean) -> Tensor + +- func: group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor + +- func: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) + dispatch: + CPU, CUDA: native_group_norm + CompositeExplicitAutograd: math_group_norm + autogen: native_group_norm.out + tags: core + +- func: native_group_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + dispatch: + CPU, CUDA: native_group_norm_backward + autogen: native_group_norm_backward.out + tags: core + +# Real to complex forward FFT +- func: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor + variants: function + dispatch: + CPU: _fft_r2c_mkl + CUDA: _fft_r2c_cufft + MPS: _fft_r2c_mps + +- func: _fft_r2c.out(Tensor self, int[] dim, int normalization, bool onesided, *, Tensor(a!) out) -> Tensor(a!) + variants: function + dispatch: + CPU: _fft_r2c_mkl_out + CUDA: _fft_r2c_cufft_out + MPS: _fft_r2c_mps_out + +# Complex to real inverse FFT +- func: _fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor + variants: function + dispatch: + CPU: _fft_c2r_mkl + CUDA: _fft_c2r_cufft + MPS: _fft_c2r_mps + +- func: _fft_c2r.out(Tensor self, int[] dim, int normalization, SymInt last_dim_size, *, Tensor(a!) out) -> Tensor(a!) + variants: function + dispatch: + CPU: _fft_c2r_mkl_out + CUDA: _fft_c2r_cufft_out + MPS: _fft_c2r_mps_out + +# Standard complex to complex FFT (forward or backward) +- func: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor + variants: function + dispatch: + CPU: _fft_c2c_mkl + CUDA: _fft_c2c_cufft + MPS: _fft_c2c_mps + +- func: _fft_c2c.out(Tensor self, SymInt[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!) 
+ variants: function + dispatch: + CPU: _fft_c2c_mkl_out + CUDA: _fft_c2c_cufft_out + MPS: _fft_c2c_mps_out + +- func: _validate_compressed_sparse_indices(bool is_crow, Tensor compressed_idx, Tensor plain_idx, int cdim, int dim, int nnz) -> () + device_check: NoCheck + variants: function + dispatch: + CPU: _validate_compressed_sparse_indices_cpu + CUDA: _validate_compressed_sparse_indices_cuda + +- func: _cufft_get_plan_cache_size(DeviceIndex device_index) -> int + +- func: _cufft_get_plan_cache_max_size(DeviceIndex device_index) -> int + +- func: _cufft_set_plan_cache_max_size(DeviceIndex device_index, int max_size) -> () + +- func: _cufft_clear_plan_cache(DeviceIndex device_index) -> () + +- func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: index.Tensor_out + variants: function, method + dispatch: + QuantizedCPU: quantized_index + tags: [core, dynamic_output_shape] + # NB: This function is special-cased in tools/autograd/gen_variable_type.py + # NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp: + # - Tensor Tensor::index(ArrayRef indices) + # - Tensor Tensor::index(std::initializer_list indices) + +- func: index.Tensor_out(Tensor self, Tensor?[] indices, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck + structured: True + structured_inherits: TensorIteratorBase + precomputed: + - indices -> DimVector sizes, DimVector strides + dispatch: + CPU, CUDA, MPS: index_out + +# Used by inductor to signal indexing without bounds checks +# Note that we don't support boolean indexing, to avoid dynamic output shapes +- func: _unsafe_index.Tensor(Tensor self, Tensor?[] indices) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: _unsafe_index + +# Used by inductor to generate masked loads +# Note that we don't support boolean indexing, to avoid dynamic output shapes +- func: _unsafe_masked_index(Tensor self, Tensor mask, Tensor?[] indices, Scalar fill) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: _unsafe_masked_index + +- func: _unsafe_masked_index_put_accumulate(Tensor self, Tensor mask, Tensor?[] indices, Tensor values) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: _unsafe_masked_index_put_accumulate + +- func: index_copy.out(Tensor self, int dim, Tensor index, Tensor source, *, Tensor(a!) out) -> Tensor(a!) + structured: True + variants: function + precomputed: + - dim -> int dim + dispatch: + CPU, CUDA: index_copy_out + +- func: index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!) + variants: method + structured_delegate: index_copy.out + +- func: index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor + variants: function, method + structured_delegate: index_copy.out + +- func: index_copy_.dimname(Tensor(a!) self, Dimname dim, Tensor index, Tensor source) -> Tensor(a!) + variants: method + +- func: index_copy.dimname(Tensor self, Dimname dim, Tensor index, Tensor source) -> Tensor + variants: function, method + +- func: index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!) 
+ device_check: NoCheck # delegate to _index_put_impl_, which leverages TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: index_put_ + autogen: index_put.out + # NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp: + # - Tensor & Tensor::index_put_(ArrayRef indices, Tensor const & rhs) + # - Tensor & Tensor::index_put_(ArrayRef indices, Scalar v) + # - Tensor & Tensor::index_put_(std::initializer_list indices, Tensor const & rhs) + # - Tensor & Tensor::index_put_(std::initializer_list indices, Scalar v) + +- func: index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor + device_check: NoCheck # delegate to _index_put_impl_ after clone, which leverages TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: index_put + tags: core + +- func: _unsafe_index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor + device_check: NoCheck # delegate to _index_put_impl_ after clone, which leverages TensorIterator + variants: function + dispatch: + CompositeExplicitAutograd: _unsafe_index_put + +- func: _index_put_impl_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: function + dispatch: + CPU, CUDA, MPS: _index_put_impl_ + QuantizedCPU: _index_put_impl_quantized_cpu_ + QuantizedCUDA: _index_put_impl_quantized_cuda_ + autogen: _index_put_impl, _index_put_impl.out + +- func: instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor + variants: function + +- func: isclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> Tensor + variants: function, method + +- func: isin.Tensor_Tensor_out(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!) + variants: function + structured: True + dispatch: + CPU, CUDA: isin_Tensor_Tensor_out + MPS: isin_Tensor_Tensor_out_mps + +- func: isin.Tensor_Tensor(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor + variants: function + structured_delegate: isin.Tensor_Tensor_out + +- func: isin.Tensor_Scalar_out(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!) + variants: function + structured: True + dispatch: + CPU, CUDA: isin_Tensor_Scalar_out + +- func: isin.Tensor_Scalar(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False) -> Tensor + variants: function + structured_delegate: isin.Tensor_Scalar_out + +- func: isin.Scalar_Tensor_out(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!) 
+ variants: function + structured: True + dispatch: + CPU, CUDA: isin_Scalar_Tensor_out + +- func: isin.Scalar_Tensor(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor + variants: function + structured_delegate: isin.Scalar_Tensor_out + +- func: isnan(Tensor self) -> Tensor + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CPU, CUDA, MPS: isnan + SparseCPU, SparseCUDA: isnan_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: isnan_sparse_csr + autogen: isnan.out + tags: [core, pointwise] + +- func: is_distributed(Tensor self) -> bool + variants: function, method + device_check: NoCheck + device_guard: False + +- func: is_floating_point(Tensor self) -> bool + variants: function, method + device_check: NoCheck + device_guard: False + manual_cpp_binding: True + +- func: is_complex(Tensor self) -> bool + variants: function, method + device_check: NoCheck + device_guard: False + manual_cpp_binding: True + +- func: is_conj(Tensor self) -> bool + variants: function, method + device_guard: False + manual_cpp_binding: True + +- func: _is_zerotensor(Tensor self) -> bool + variants: function, method + device_guard: False + manual_cpp_binding: True + +- func: is_neg(Tensor self) -> bool + variants: function, method + device_guard: False + manual_cpp_binding: True + +- func: isreal(Tensor self) -> Tensor + variants: function, method + +- func: is_nonzero(Tensor self) -> bool + variants: function, method + device_check: NoCheck + device_guard: False + +- func: is_same_size(Tensor self, Tensor other) -> bool + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + NestedTensorCPU, NestedTensorCUDA: nested_is_same_size + CompositeExplicitAutograd: is_same_size + +- func: is_signed(Tensor self) -> bool + variants: function, method + device_check: NoCheck + device_guard: False + manual_cpp_binding: True + +- func: is_inference(Tensor self) -> bool + variants: function, method + device_check: NoCheck + device_guard: False + manual_cpp_binding: True + +- func: kl_div(Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor + +- func: kron(Tensor self, Tensor other) -> Tensor + variants: function, method + +- func: kron.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + +- func: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) + variants: function, method + dispatch: + CompositeExplicitAutograd: kthvalue + +- func: kthvalue.values(Tensor self, int k, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + dispatch: + CPU: kthvalue_out_cpu + CUDA: kthvalue_out_cuda + +- func: kthvalue.dimname(Tensor self, int k, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + variants: function, method + +- func: kthvalue.dimname_out(Tensor self, int k, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + +- func: layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor + dispatch: + CompositeImplicitAutograd: layer_norm_symint + +- func: native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? 
bias, float eps) -> (Tensor, Tensor, Tensor) + dispatch: + CPU: layer_norm_cpu + CUDA: layer_norm_cuda + MPS: layer_norm_mps + CompositeExplicitAutograd: math_native_layer_norm + NestedTensorCPU, NestedTensorCUDA: nested_layer_norm + autogen: native_layer_norm.out + tags: core + +- func: native_layer_norm_backward(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + dispatch: + CPU: layer_norm_backward_cpu + CUDA: layer_norm_backward_cuda + MPS: layer_norm_backward_mps + NestedTensorCPU, NestedTensorCUDA: layer_norm_backward_nested + autogen: native_layer_norm_backward.out + tags: core + +- func: rms_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor + +- func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor + variants: function, method + dispatch: + CompositeExplicitAutograd: nan_to_num + SparseCPU, SparseCUDA: nan_to_num_sparse + tags: pointwise + +- func: nan_to_num_(Tensor(a!) self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor(a!) + variants: function, method + dispatch: + CompositeExplicitAutograd: nan_to_num_ + SparseCPU, SparseCUDA: nan_to_num_sparse_ + tags: pointwise + +- func: nan_to_num.out(Tensor self, float? nan=None, float? posinf=None, float? neginf=None, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: nan_to_num_out + MPS: nan_to_num_out_mps + SparseCPU, SparseCUDA: nan_to_num_sparse_out + tags: pointwise + +- func: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor + python_module: nn + dispatch: + CompositeImplicitAutograd: linear + NestedTensorCPU, NestedTensorCUDA: nested_linear + MPS: _mps_linear + +- func: linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + dispatch: + NestedTensorCPU, NestedTensorCUDA: nested_linear_backward + MPS: mps_linear_backward + autogen: linear_backward.out + +- func: linear.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + dispatch: + CompositeExplicitAutograd: linear_out + +- func: mkldnn_linear(Tensor self, Tensor weight, Tensor? bias=None) -> Tensor + python_module: nn + dispatch: + MkldnnCPU: mkldnn_linear + autogen: mkldnn_linear.out + +- func: mkldnn_linear_backward_input(int[] input_size, Tensor grad_output, Tensor weight) -> Tensor + dispatch: + MkldnnCPU: mkldnn_linear_backward_input + autogen: mkldnn_linear_backward_input.out + +- func: mkldnn_linear_backward_weights(Tensor grad_output, Tensor input, Tensor weight, bool bias_defined) -> (Tensor, Tensor) + dispatch: + MkldnnCPU: mkldnn_linear_backward_weights + autogen: mkldnn_linear_backward_weights.out + +- func: mkldnn_linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + dispatch: + MkldnnCPU: mkldnn_linear_backward + autogen: mkldnn_linear_backward.out + +- func: _cslt_compress(Tensor input) -> Tensor + dispatch: + CUDA: _cslt_compress + +- func: _cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0) -> Tensor + dispatch: + CUDA: _cslt_sparse_mm + +- func: _cslt_sparse_mm_search(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? 
out_dtype=None, bool transpose_result=False) -> int + dispatch: + CUDA: _cslt_sparse_mm_search + +- func: _sparse_semi_structured_tile(Tensor input, str algorithm="", bool use_cutlass=True) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + dispatch: + CUDA: _sparse_semi_structured_tile + +- func: _sparse_semi_structured_apply(Tensor input, Tensor thread_masks) -> (Tensor, Tensor) + dispatch: + CUDA: _sparse_semi_structured_apply + +- func: _sparse_semi_structured_apply_dense(Tensor input, Tensor thread_masks) -> Tensor + dispatch: + CUDA: _sparse_semi_structured_apply_dense + +# DEPRECATED: Use torch._sparse_semi_structured_mm/torch._sparse_semi_structured_addmm instead +- func: _sparse_semi_structured_linear(Tensor input, Tensor weight, Tensor meta, *, Tensor? bias=None, str? activation=None, ScalarType? out_dtype=None) -> Tensor + dispatch: + CUDA: _sparse_semi_structured_linear + +- func: _sparse_semi_structured_mm(Tensor mat1, Tensor mat1_meta, Tensor mat2, *, ScalarType? out_dtype=None) -> Tensor + dispatch: + CUDA: _sparse_semi_structured_mm + +- func: _sparse_semi_structured_addmm(Tensor input, Tensor mat1, Tensor mat1_meta, Tensor mat2, *, Scalar alpha=1, Scalar beta=1, ScalarType? out_dtype=None) -> Tensor + dispatch: + CUDA: _sparse_semi_structured_addmm + +- func: _mixed_dtypes_linear(Tensor input, Tensor weight, Tensor scale, *, Tensor? bias=None, str? activation=None) -> Tensor + dispatch: + CUDA: _mixed_dtypes_linear + +- func: fbgemm_linear_int8_weight_fp32_activation(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor + +- func: fbgemm_linear_int8_weight(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor + +- func: fbgemm_linear_quantize_weight(Tensor input) -> (Tensor, Tensor, float, int) + +- func: fbgemm_pack_gemm_matrix_fp16(Tensor input) -> Tensor + +- func: _wrapped_linear_prepack(Tensor weight, Tensor weight_scale, Tensor weight_zero_point, Tensor bias) -> Tensor + +- func: _wrapped_quantized_linear_prepacked(Tensor input, Tensor input_scale, Tensor input_zero_point, Tensor packed_weight, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor + +- func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor + +- func: fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor + +- func: fbgemm_pack_quantized_matrix(Tensor input) -> Tensor + +- func: fbgemm_pack_quantized_matrix.KN(Tensor input, int K, int N) -> Tensor + +- func: ldexp.Tensor(Tensor self, Tensor other) -> Tensor + variants: function, method + +- func: ldexp_(Tensor(a!) self, Tensor other) -> Tensor(a!) + variants: function, method + tags: pointwise + +- func: ldexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + tags: pointwise + +- func: linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: linspace + +- func: linspace.Tensor_Tensor(Tensor start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + category_override: factory + dispatch: + CompositeExplicitAutograd: linspace + +- func: linspace.Tensor_Scalar(Tensor start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + category_override: factory + dispatch: + CompositeExplicitAutograd: linspace + +- func: linspace.Scalar_Tensor(Scalar start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + category_override: factory + dispatch: + CompositeExplicitAutograd: linspace + +- func: linspace.out(Scalar start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, Meta: linspace_out + CUDA: linspace_cuda_out + MPS: linspace_out_mps + +- func: linspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!) + category_override: factory + dispatch: + CompositeExplicitAutograd: linspace_out + +- func: linspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!) + category_override: factory + dispatch: + CompositeExplicitAutograd: linspace_out + +- func: linspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!) + category_override: factory + dispatch: + CompositeExplicitAutograd: linspace_out + +- func: log(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: log.out + variants: function, method + tags: [core, pointwise] + +- func: log_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: log.out + variants: function, method + tags: pointwise + +- func: log.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: log_out + MPS: log_out_mps + tags: pointwise + +- func: log10(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: log10.out + variants: function, method + tags: [core, pointwise] + +- func: log10_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: log10.out + variants: function, method + tags: pointwise + +- func: log10.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: log10_out + MPS: log10_out_mps + tags: pointwise + +- func: log1p(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: log1p.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: log1p_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: log1p_sparse_csr + tags: [core, pointwise] + +- func: log1p_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: log1p.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: log1p_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: log1p_sparse_csr_ + tags: pointwise + +- func: log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: log1p_out + MPS: log1p_out_mps + SparseCPU, SparseCUDA: log1p_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: log1p_sparse_csr_out + tags: pointwise + +- func: log2(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: log2.out + variants: function, method + tags: [core, pointwise] + +- func: log2_(Tensor(a!) self) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + structured_delegate: log2.out + variants: function, method + tags: pointwise + +- func: log2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: log2_out + MPS: log2_out_mps + tags: pointwise + +- func: logaddexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: logaddexp_out + MPS: logaddexp_out_mps + tags: pointwise + +- func: logaddexp(Tensor self, Tensor other) -> Tensor + variants: method, function + structured_delegate: logaddexp.out + tags: pointwise + +- func: logaddexp2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: logaddexp2_out + MPS: logaddexp2_out_mps + tags: pointwise + +- func: logaddexp2(Tensor self, Tensor other) -> Tensor + variants: method, function + structured_delegate: logaddexp2.out + tags: pointwise + +- func: xlogy.Tensor(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: xlogy.OutTensor + variants: function, method + tags: pointwise + +- func: xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: function + dispatch: + CompositeExplicitAutograd: xlogy + tags: pointwise + +- func: xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: xlogy + tags: pointwise + +# xlogy: inplace variant +- func: xlogy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: function, method + structured_delegate: xlogy.OutTensor + tags: pointwise + +- func: xlogy_.Scalar_Other(Tensor(a!) self, Scalar other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: xlogy_ + +# xlogy: out variant +- func: xlogy.OutTensor(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + variants: function + dispatch: + CPU, CUDA: xlogy_out + MPS: xlogy_out_mps + tags: pointwise + +- func: xlogy.OutScalar_Self(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: function + dispatch: + CompositeExplicitAutograd: xlogy_out + tags: pointwise + +- func: xlogy.OutScalar_Other(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: function + dispatch: + CompositeExplicitAutograd: xlogy_out + tags: pointwise + +- func: logspace(Scalar start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: logspace + +- func: logspace.Tensor_Tensor(Tensor start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + category_override: factory + dispatch: + CompositeExplicitAutograd: logspace + +- func: logspace.Tensor_Scalar(Tensor start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + category_override: factory + dispatch: + CompositeExplicitAutograd: logspace + +- func: logspace.Scalar_Tensor(Scalar start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + category_override: factory + dispatch: + CompositeExplicitAutograd: logspace + +- func: logspace.out(Scalar start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, Meta: logspace_out + CUDA: logspace_cuda_out + +- func: logspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) + category_override: factory + dispatch: + CompositeExplicitAutograd: logspace_out + +- func: logspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) + category_override: factory + dispatch: + CompositeExplicitAutograd: logspace_out + +- func: logspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) + category_override: factory + dispatch: + CompositeExplicitAutograd: logspace_out + +# log_softmax allows positional dtype, unlike most operators, because kwonly is BC-breaking when loading jit models. +- func: log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor + variants: function, method + +- func: log_softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!) + variants: function + dispatch: + CompositeExplicitAutograd: log_softmax_out + +- func: log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor + variants: function, method + +- func: _log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor + structured_delegate: _log_softmax.out + tags: core + +- func: _log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!) + structured: True + dispatch: + CPU: log_softmax_cpu_out + CUDA: log_softmax_cuda_out + MPS: log_softmax_mps_out + +- func: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor + structured_delegate: _log_softmax_backward_data.out + +- func: _log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) out) -> Tensor(a!) + structured: True + dispatch: + CPU: log_softmax_backward_cpu_out + CUDA: log_softmax_backward_cuda_out + MPS: log_softmax_backward_mps_out + +- func: _logcumsumexp(Tensor self, int dim) -> Tensor + dispatch: + CPU: _logcumsumexp_cpu + CUDA: _logcumsumexp_cuda + +- func: _logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU: _logcumsumexp_out_cpu + CUDA: _logcumsumexp_out_cuda + +- func: logcumsumexp(Tensor self, int dim) -> Tensor + variants: function, method + dispatch: + CompositeExplicitAutograd: logcumsumexp + +- func: logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: logcumsumexp_out + +- func: logcumsumexp.dimname(Tensor self, Dimname dim) -> Tensor + variants: function, method + +- func: logcumsumexp.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) + +- func: logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: logsumexp + +- func: logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) 
out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + dispatch: + # calls squeeze + CompositeExplicitAutogradNonFunctional: logsumexp_out + +- func: logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + +- func: logsumexp.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + +- func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor + +- func: matmul(Tensor self, Tensor other) -> Tensor + variants: function, method + dispatch: + CompositeImplicitAutograd: matmul + NestedTensorCPU, NestedTensorCUDA: matmul_nested + +- func: matmul_backward(Tensor grad, Tensor self, Tensor other, bool[2] mask) -> (Tensor, Tensor) + dispatch: + NestedTensorCPU, NestedTensorCUDA: matmul_backward_nested + autogen: matmul_backward.out + +- func: matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeImplicitAutograd: matmul_out + NestedTensorCPU, NestedTensorCUDA: matmul_out_nested + +# Alias to linalg.matrix_power +- func: matrix_power(Tensor self, int n) -> Tensor + variants: function, method + +# Alias to linalg.matrix_power +- func: matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!) + +# Alias to linalg.matrix_exp +- func: matrix_exp(Tensor self) -> Tensor + variants: function, method + +# This function should be deprecated in favor of differential_analytic_matrix_function in FunctionsManual.cpp +- func: matrix_exp_backward(Tensor self, Tensor grad) -> Tensor + +# DEPRECATED: Use torch.aminmax instead +- func: _aminmax(Tensor self) -> (Tensor, Tensor) + dispatch: + CPU, CUDA: _aminmax_all + autogen: _aminmax.out + +# DEPRECATED: Use torch.aminmax instead +- func: _aminmax.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor, Tensor) + dispatch: + CPU, CUDA: _aminmax + autogen: _aminmax.dim_out + +- func: aminmax(Tensor self, *, int? dim=None, bool keepdim=False) -> (Tensor min, Tensor max) + device_check: NoCheck # TensorIterator + structured_delegate: aminmax.out + variants: function, method + +- func: aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max) + device_check: NoCheck # TensorIterator + structured: True + dispatch: + CPU, CUDA: aminmax_out + MPS: aminmax_out_mps + +- func: _compute_linear_combination(Tensor input, Tensor coefficients) -> Tensor + dispatch: + CPU, CUDA: _compute_linear_combination + +- func: _compute_linear_combination.out(Tensor input, Tensor coefficients, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: _compute_linear_combination_out + +- func: max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + device_check: NoCheck # TensorIterator + structured_delegate: max.dim_max + variants: function, method + dispatch: + QuantizedCPU, QuantizedCUDA: qmax + tags: core + +- func: max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) 
indices) + device_check: NoCheck # TensorIterator + structured: True + precomputed: + - dim -> int dim + dispatch: + CPU, CUDA: max_out + MPS: max_out_mps + +- func: max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + device_check: NoCheck # TensorIterator + variants: function, method + +- func: max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices) + device_check: NoCheck # TensorIterator + +- func: value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, SymInt[] sizes, bool keepdim) -> Tensor + variants: function + device_check: NoCheck + device_guard: False + dispatch: + CompositeImplicitAutograd: value_selecting_reduction_backward_symint + +- func: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor + variants: function, method + structured_delegate: amax.out + tags: core + +- func: amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + structured: True + dispatch: + CPU, CUDA: amax_out + MPS: amax_out_mps + +# Return: (Tensor output, Tensor indices) +- func: max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) + +- func: max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor + +- func: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + dispatch: + CompositeImplicitAutograd: max_pool2d + MPS: mps_max_pool2d + +- func: max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + dispatch: + MPS: mps_max_pool2d_backward + autogen: max_pool2d_backward.out + +- func: mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + dispatch: + MkldnnCPU: mkldnn_max_pool2d + autogen: mkldnn_max_pool2d.out + +- func: mkldnn_max_pool2d_backward(Tensor grad_output, Tensor output, Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + dispatch: + MkldnnCPU: mkldnn_max_pool2d_backward + autogen: mkldnn_max_pool2d_backward.out + +- func: mkldnn_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor + dispatch: + MkldnnCPU: mkldnn_max_pool3d + autogen: mkldnn_max_pool3d.out + +- func: mkldnn_max_pool3d_backward(Tensor grad_output, Tensor output, Tensor input, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor + dispatch: + MkldnnCPU: mkldnn_max_pool3d_backward + autogen: mkldnn_max_pool3d_backward.out + +- func: quantized_max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor + dispatch: + QuantizedCPU: quantized_max_pool1d + autogen: quantized_max_pool1d.out + +- func: quantized_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + dispatch: + QuantizedCPU: quantized_max_pool2d + QuantizedCUDA: quantized_max_pool2d_cudnn + autogen: quantized_max_pool2d.out + +- func: quantized_max_pool3d(Tensor self, int[3] kernel_size, 
int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor + dispatch: + QuantizedCPU: quantized_max_pool3d + autogen: quantized_max_pool3d.out + +- func: max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor + +# The CPU and GPU dispatch variants are named weirdly here because otherwise there +# are namespacing issues in C++ +- func: mean(Tensor self, *, ScalarType? dtype=None) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: mean + tags: core + +# For normal naming convention this should be `mean.out`. However since we already have `mean.out` we have to rename this. +- func: mean.dtype_out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + dispatch: + CompositeExplicitAutograd: mean_dtype_out + +- func: mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + structured_delegate: mean.out + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + QuantizedCPU: mean_quantized_cpu + tags: core + +- func: mean.out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + structured: True + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: mean_out + MPS: mean_out_mps + QuantizedCPU: mean_out_quantized_cpu + +- func: mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + +- func: mean.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + +- func: nanmean(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + device_check: NoCheck # Composite + variants: function, method + +- func: nanmean.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # Composite + +- func: median(Tensor self) -> Tensor + variants: function, method + dispatch: + CPU: median_cpu + CUDA: median_cuda + MPS: median_mps + autogen: median.out + +- func: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + variants: function, method + dispatch: + CompositeExplicitAutograd: median + +- func: median.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + dispatch: + CPU: median_out_cpu + CUDA: median_out_cuda + MPS: median_out_mps + +- func: median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + variants: function, method + +- func: median.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + +- func: nanmedian(Tensor self) -> Tensor + variants: function, method + dispatch: + CPU: nanmedian_cpu + CUDA: nanmedian_cuda + autogen: nanmedian.out + +- func: nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + variants: function, method + dispatch: + CompositeExplicitAutograd: nanmedian + +- func: nanmedian.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) 
indices) + dispatch: + CPU: nanmedian_out_cpu + CUDA: nanmedian_out_cuda + +- func: nanmedian.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + variants: function, method + +- func: nanmedian.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + +- func: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + device_check: NoCheck # TensorIterator + structured_delegate: min.dim_min + variants: function, method + dispatch: + QuantizedCPU, QuantizedCUDA: qmin + tags: core + +- func: min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices) + device_check: NoCheck # TensorIterator + structured: True + precomputed: + - dim -> int dim + dispatch: + CPU, CUDA: min_out + MPS: min_out_mps + +- func: min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + device_check: NoCheck # TensorIterator + variants: function, method + +- func: min.names_dim_min(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices) + device_check: NoCheck # TensorIterator + +- func: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor + variants: function, method + structured_delegate: amin.out + tags: core + +- func: amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + structured: True + dispatch: + CPU, CUDA: amin_out + MPS: amin_out_mps + +# TODO: Add this function to MPS dispatch key so that we avoid declaring it in +# native_functions.yaml +# https://github.com/pytorch/pytorch/issues/77394 +- func: _mps_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + dispatch: + MPS: _mps_convolution + autogen: _mps_convolution.out + +- func: mps_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + dispatch: + MPS: mps_convolution_backward + autogen: mps_convolution_backward.out + +- func: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + dispatch: + CompositeExplicitAutograd: mkldnn_convolution + autogen: mkldnn_convolution.out + +- func: mkldnn_rnn_layer(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) -> (Tensor, Tensor, Tensor, Tensor) + dispatch: + CPU: mkldnn_rnn_layer + MkldnnCPU: mkldnn_rnn_layer + autogen: mkldnn_rnn_layer.out + +- func: mkldnn_rnn_layer_backward(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? 
grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) + dispatch: + CPU: mkldnn_rnn_layer_backward + autogen: mkldnn_rnn_layer_backward.out + +- func: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor) + dispatch: + CUDA: miopen_batch_norm + autogen: miopen_batch_norm.out + +- func: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor) + dispatch: + CUDA: miopen_batch_norm_backward + autogen: miopen_batch_norm_backward.out + +- func: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + dispatch: + CUDA: miopen_convolution + autogen: miopen_convolution.out + +- func: miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + dispatch: + CUDA: miopen_convolution_transpose + autogen: miopen_convolution_transpose.out + +- func: miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + dispatch: + CUDA: miopen_depthwise_convolution + autogen: miopen_depthwise_convolution.out + +- func: miopen_convolution_relu(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor + dispatch: + CUDA: miopen_convolution_relu + +- func: miopen_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor + dispatch: + CUDA: miopen_convolution_add_relu + +- func: miopen_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + dispatch: + CUDA: miopen_rnn + autogen: miopen_rnn.out + tags: nondeterministic_seeded + + +- func: miopen_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[]) + dispatch: + CUDA: miopen_rnn_backward + autogen: miopen_rnn_backward.out + +- func: mm(Tensor self, Tensor mat2) -> Tensor + structured_delegate: mm.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: _sparse_mm + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _sparse_csr_mm + tags: core + +- func: mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) 
+ structured: True + dispatch: + CPU: mm_out_cpu + CUDA: mm_out_cuda + MPS: mm_out_mps + SparseCPU, SparseCUDA: _sparse_mm_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _sparse_csr_mm_out + +- func: _int_mm(Tensor self, Tensor mat2) -> Tensor + dispatch: + CPU: _int_mm_cpu + CUDA: _int_mm_cuda + +- func: _int_mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU: _int_mm_out_cpu + CUDA: _int_mm_out_cuda + +- func: _convert_weight_to_int4pack(Tensor self, int innerKTiles) -> Tensor + dispatch: + CPU: _convert_weight_to_int4pack_cpu + CUDA: _convert_weight_to_int4pack_cuda + MPS: _convert_weight_to_int4pack_mps + +- func: _weight_int4pack_mm(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor + dispatch: + CPU: _weight_int4pack_mm_cpu + MPS: _weight_int4pack_mm_mps + CUDA: _weight_int4pack_mm_cuda + +- func: _weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor + dispatch: + CPU: _weight_int8pack_mm_cpu + MPS: _weight_int8pack_mm_mps + +- func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor + python_module: sparse + +- func: _sparse_mm.reduce(Tensor sparse, Tensor dense, str reduce) -> Tensor + python_module: sparse + +- func: _sparse_sparse_matmul(Tensor self, Tensor other) -> Tensor + dispatch: + SparseCPU: sparse_sparse_matmul_cpu + SparseCUDA: sparse_sparse_matmul_cuda + autogen: _sparse_sparse_matmul.out + +- func: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) + variants: function, method + dispatch: + CPU, CUDA: mode + +- func: mode.values(Tensor self, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + dispatch: + CompositeExplicitAutograd: mode_out + +- func: mode.dimname(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + variants: function, method + +- func: mode.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + +- func: mul.Tensor(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: mul.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: mul_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: mul_sparse_csr + MkldnnCPU: mkldnn_mul + ZeroTensor: mul_zerotensor + NestedTensorCPU, NestedTensorCUDA: NestedTensor_mul_Tensor + tags: [core, pointwise] + +- func: mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: mul.out + variants: method + dispatch: + SparseCPU, SparseCUDA: mul_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: mul_sparse_csr_ + MkldnnCPU: mkldnn_mul_ + NestedTensorCPU, NestedTensorCUDA: NestedTensor_mul__Tensor + tags: pointwise + +- func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: mul_out + MPS: mul_out_mps + SparseCPU: mul_out_sparse_cpu + SparseCUDA: mul_out_sparse_cuda + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: mul_out_sparse_csr + MkldnnCPU: mkldnn_mul_out + tags: pointwise + # For C++ only, until we have conversion from C++ numbers to Tensor + +- func: mul.Scalar(Tensor self, Scalar other) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: mul + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: mul_scalar_sparse_csr + NestedTensorCPU, NestedTensorCUDA: NestedTensor_mul_Scalar + tags: [core, pointwise] + +- func: mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CompositeExplicitAutograd: mul_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: mul__scalar_sparse_csr + NestedTensorCPU, NestedTensorCUDA: NestedTensor_mul__Scalar + autogen: mul.Scalar_out + tags: pointwise +# multiply, alias for mul + +- func: multiply.Tensor(Tensor self, Tensor other) -> Tensor + variants: function, method + +- func: multiply_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + variants: method + +- func: multiply.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + +- func: multiply.Scalar(Tensor self, Scalar other) -> Tensor + variants: function, method + +- func: multiply_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + variants: method + +- func: mv(Tensor self, Tensor vec) -> Tensor + variants: function, method + dispatch: + CompositeExplicitAutograd: mv + SparseCPU, SparseCUDA: mv_sparse + +- func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: mv_out + +- func: mvlgamma.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: mvlgamma_out + tags: pointwise + +- func: mvlgamma(Tensor self, int p) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: mvlgamma + tags: pointwise + +- func: mvlgamma_(Tensor(a!) self, int p) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CompositeExplicitAutograd: mvlgamma_ + tags: pointwise + +- func: narrow_copy(Tensor self, int dim, SymInt start, SymInt length) -> Tensor + variants: function, method + dispatch: + CPU: narrow_copy_dense_cpu + SparseCPU, SparseCUDA: narrow_copy_sparse + CompositeExplicitAutogradNonFunctional: narrow_copy_dense_symint + tags: view_copy + +- func: narrow_copy.out(Tensor self, int dim, SymInt start, SymInt length, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU: narrow_copy_dense_cpu_out + +- func: narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeImplicitAutograd: narrow_symint + NestedTensorCPU, NestedTensorCUDA: narrow_nested_symint + +- func: narrow.Tensor(Tensor(a) self, int dim, Tensor start, SymInt length) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeImplicitAutograd: narrow_tensor_symint + +- func: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? 
running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) + dispatch: + CPU: batch_norm_cpu + CUDA: batch_norm_cuda + MPS: batch_norm_mps + MkldnnCPU: mkldnn_batch_norm + +- func: native_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + dispatch: + CUDA: batch_norm_cuda_out + MPS: batch_norm_mps_out + CPU: batch_norm_cpu_out + +# TODO: In 2 weeks, we should make native_batch_norm composite implicit so that this correct schema percolates correctly through our dispatching +- func: _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) + dispatch: + CPU: _batch_norm_legit_cpu + CUDA: _batch_norm_legit_cuda + MPS: _batch_norm_legit_mps + MkldnnCPU: _mkldnn_batch_norm_legit + autogen: _native_batch_norm_legit_functional + tags: core + +# HACK: identical to _native_batch_norm_legit, but training is known to be False, +# so we know that running stats will not be mutated. +# The real fix here is batch norm consolidation. +- func: _native_batch_norm_legit_no_training(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor) + dispatch: + CompositeExplicitAutograd: _batch_norm_legit_no_training + autogen: _native_batch_norm_legit_no_training.out + tags: core + +- func: _native_batch_norm_legit.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd) -> (Tensor(d!), Tensor(e!), Tensor(f!)) + dispatch: + CPU: _batch_norm_legit_cpu_out + CUDA: _batch_norm_legit_cuda_out + MPS: _batch_norm_legit_mps_out + +- func: _native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) + dispatch: + CPU: _batch_norm_legit_no_stats_cpu + CUDA: _batch_norm_legit_no_stats_cuda + MPS: _batch_norm_legit_no_stats_mps + MkldnnCPU: _mkldnn_batch_norm_legit_no_stats + tags: core + +- func: _native_batch_norm_legit.no_stats_out(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + dispatch: + CPU: _batch_norm_legit_no_stats_cpu_out + CUDA: _batch_norm_legit_no_stats_cuda_out + MPS: _batch_norm_legit_no_stats_mps_out + +- func: batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor) + dispatch: + CUDA: batch_norm_stats_cuda + autogen: batch_norm_stats.out + +- func: batch_norm_elemt(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps) -> Tensor + dispatch: + CUDA: batch_norm_elemt_cuda + +- func: batch_norm_elemt.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CUDA: batch_norm_elemt_cuda_out + +# for backward compatibility +- func: batch_norm_gather_stats(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor?
running_var, float momentum, float eps, int count) -> (Tensor, Tensor) + dispatch: + CUDA: batch_norm_gather_stats_cuda + autogen: batch_norm_gather_stats.out + +- func: batch_norm_gather_stats_with_counts(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts) -> (Tensor, Tensor) + dispatch: + CUDA: batch_norm_gather_stats_with_counts_cuda + autogen: batch_norm_gather_stats_with_counts.out + +- func: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + dispatch: + CPU: batch_norm_backward_cpu + CUDA: batch_norm_backward_cuda + MPS: batch_norm_backward_mps + MkldnnCPU: mkldnn_batch_norm_backward + autogen: native_batch_norm_backward.out + +- func: batch_norm_backward_reduce(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g) -> (Tensor, Tensor, Tensor, Tensor) + dispatch: + CUDA: batch_norm_backward_reduce_cuda + autogen: batch_norm_backward_reduce.out + +- func: batch_norm_backward_elemt(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count) -> Tensor + dispatch: + CUDA: batch_norm_backward_elemt_cuda + autogen: batch_norm_backward_elemt.out + +- func: batch_norm_update_stats(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum) -> (Tensor, Tensor) + dispatch: + CPU: batch_norm_update_stats_cpu + CUDA: batch_norm_update_stats_cuda + autogen: batch_norm_update_stats.out + +- func: is_vulkan_available() -> bool + +- func: _nnpack_available() -> bool + +- func: _nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: _nnpack_spatial_convolution + autogen: _nnpack_spatial_convolution.out + +- func: ones.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: ones + autogen: ones.names_out + +- func: ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: ones + +- func: ones.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: ones_out + +- func: ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd: ones_like + NestedTensorCPU, NestedTensorCUDA: ones_like + autogen: ones_like.out + +- func: pairwise_distance(Tensor x1, Tensor x2, float p=2, float eps=1e-06, bool keepdim=False) -> Tensor + +- func: cdist(Tensor x1, Tensor x2, float p=2, int? compute_mode=None) -> Tensor + +- func: _euclidean_dist(Tensor x1, Tensor x2) -> Tensor + dispatch: + CompositeExplicitAutograd: _euclidean_dist + autogen: _euclidean_dist.out + +- func: _cdist_forward(Tensor x1, Tensor x2, float p, int? 
compute_mode) -> Tensor + dispatch: + CPU, CUDA: _cdist_forward + MPS: _cdist_forward_mps + autogen: _cdist_forward.out + tags: core + +- func: _cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor + dispatch: + CPU, CUDA: _cdist_backward + autogen: _cdist_backward.out + +- func: pdist(Tensor self, float p=2) -> Tensor + +- func: _pdist_forward(Tensor self, float p=2) -> Tensor + dispatch: + CPU, CUDA: _pdist_forward + autogen: _pdist_forward.out + tags: core + +- func: _pdist_backward(Tensor grad, Tensor self, float p, Tensor pdist) -> Tensor + dispatch: + CPU, CUDA: _pdist_backward + autogen: _pdist_backward.out + +- func: cosine_similarity(Tensor x1, Tensor x2, int dim=1, float eps=1e-08) -> Tensor + variants: function + +- func: permute(Tensor(a) self, int[] dims) -> Tensor(a) + variants: function, method + dispatch: + CompositeExplicitAutograd: permute + MPS: permute_mps + SparseCPU, SparseCUDA: permute_sparse_coo + tags: core + +- func: movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a) + variants: function, method + +- func: movedim.int(Tensor(a) self, int source, int destination) -> Tensor(a) + variants: function, method + +# moveaxis, alias for movedim +- func: moveaxis.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a) + variants: function, method + +- func: moveaxis.int(Tensor(a) self, int source, int destination) -> Tensor(a) + variants: function, method + +# Only exposed from C++ -- in Python, +# we expose it as an attribute `T`, not a function. +# +# I'd like to name this "T" in C++ too, but +# calling a native function "T" causes undefined +# behavior on Windows, for reasons I don't understand +# (maybe related to capital letter collation somehow...) +- func: numpy_T(Tensor(a) self) -> Tensor(a) + variants: method + +# Exposed on Python as an attribute 'H' +- func: matrix_H(Tensor(a) self) -> Tensor(a) + variants: method + +# Exposed on Python as an attribute 'mT' +- func: mT(Tensor(a) self) -> Tensor(a) + variants: method + +# Exposed on Python as an attribute 'mH' +- func: mH(Tensor(a) self) -> Tensor(a) + variants: method + +- func: adjoint(Tensor(a) self) -> Tensor(a) + variants: function, method + +- func: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor + dispatch: + CPU: pixel_shuffle_cpu + MPS: pixel_shuffle_mps + CompositeExplicitAutogradNonFunctional: math_pixel_shuffle + autogen: pixel_shuffle.out + +- func: pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor + dispatch: + CPU: pixel_unshuffle_cpu + MPS: pixel_unshuffle_mps + CompositeExplicitAutogradNonFunctional: math_pixel_unshuffle + autogen: pixel_unshuffle.out + +- func: channel_shuffle(Tensor self, SymInt groups) -> Tensor + dispatch: + CPU, CUDA: channel_shuffle + QuantizedCPU: channel_shuffle_quantized_cpu + autogen: channel_shuffle.out + +- func: native_channel_shuffle(Tensor self, SymInt groups) -> Tensor + dispatch: + CPU: channel_shuffle_cpu + CompositeImplicitAutograd: math_channel_shuffle + +- func: is_pinned(Tensor self, Device? 
device=None) -> bool + variants: method + dispatch: + # the NestedTensor keys are necessary because NestedTensor has been removed + # from the CompositeExplicitAutograd keyset see Note [NestedTensor Not Included in Backend Keys] + CompositeExplicitAutograd, NestedTensorCPU: is_pinned + SparseCsrCPU: is_pinned_sparse_compressed + SparseCPU: is_pinned_sparse_coo + +# TODO: add a copy kwarg that guarantees that the tensor is put into fresh +# pinned memory +- func: pin_memory(Tensor(a) self, Device? device=None) -> Tensor(a) + variants: method + +# Unlike pin_memory, this is guaranteed to give a new non-aliasing tensor +- func: _pin_memory(Tensor self, Device? device=None) -> Tensor + dispatch: + CompositeExplicitAutograd: _pin_memory + NestedTensorCPU: _pin_memory_nested + SparseCPU: _pin_memory_sparse_coo + SparseCsrCPU: _pin_memory_sparse_compressed + autogen: _pin_memory.out + +- func: pinverse(Tensor self, float rcond=1e-15) -> Tensor + variants: function, method + +- func: poisson_nll_loss(Tensor input, Tensor target, bool log_input, bool full, float eps, int reduction) -> Tensor + variants: function + +- func: rad2deg(Tensor self) -> Tensor + variants: function, method + dispatch: + CompositeExplicitAutograd: rad2deg + SparseCPU, SparseCUDA: rad2deg_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: rad2deg_sparse_csr + +- func: rad2deg_(Tensor(a!) self) -> Tensor(a!) + variants: function, method + dispatch: + CompositeExplicitAutograd: rad2deg_ + SparseCPU, SparseCUDA: rad2deg_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: rad2deg_sparse_csr_ + +- func: rad2deg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: rad2deg_out + SparseCPU, SparseCUDA: rad2deg_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: rad2deg_sparse_csr_out + +- func: deg2rad(Tensor self) -> Tensor + variants: function, method + dispatch: + CompositeExplicitAutograd: deg2rad + SparseCPU, SparseCUDA: deg2rad_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: deg2rad_sparse_csr + tags: pointwise + +- func: deg2rad_(Tensor(a!) self) -> Tensor(a!) + variants: function, method + dispatch: + CompositeExplicitAutograd: deg2rad_ + SparseCPU, SparseCUDA: deg2rad_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: deg2rad_sparse_csr_ + tags: pointwise + +- func: deg2rad.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: deg2rad_out + SparseCPU, SparseCUDA: deg2rad_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: deg2rad_sparse_csr_out + tags: pointwise + +- func: scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: scalar_tensor + autogen: scalar_tensor.out + tags: core + +- func: rand.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: rand + autogen: rand.names_out + tags: nondeterministic_seeded + +- func: rand.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + device_check: NoCheck + device_guard: False + tags: nondeterministic_seeded + dispatch: + CompositeExplicitAutograd: rand + autogen: rand.generator_with_names_out + +- func: rand(SymInt[] size, *, ScalarType? dtype=None, Layout? 
layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + tags: [core, nondeterministic_seeded] + dispatch: + CompositeExplicitAutograd: rand + +- func: rand.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + CompositeExplicitAutograd: rand + +- func: rand.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + tags: nondeterministic_seeded + dispatch: + CompositeExplicitAutograd: rand_out + +- func: rand.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + tags: nondeterministic_seeded + +- func: rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd: rand_like + autogen: rand_like.out + +- func: randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + CompositeExplicitAutograd: randint + +- func: randint.generator(SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + CompositeExplicitAutograd: randint + +- func: randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + CompositeExplicitAutograd: randint + +- func: randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + CompositeExplicitAutograd: randint + +- func: randint.out(SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + tags: nondeterministic_seeded + dispatch: + CompositeExplicitAutograd: randint_out + +- func: randint.generator_out(SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + tags: nondeterministic_seeded + dispatch: + CompositeExplicitAutograd: randint_out + +- func: randint.low_out(SymInt low, SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + tags: nondeterministic_seeded + dispatch: + CompositeExplicitAutograd: randint_out + +- func: randint.low_generator_out(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + tags: nondeterministic_seeded + dispatch: + CompositeExplicitAutograd: randint_out + +- func: randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd: randint_like + autogen: randint_like.out + +- func: randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd: randint_like + autogen: randint_like.low_dtype_out + +- func: randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + tags: [core, nondeterministic_seeded] + dispatch: + CompositeExplicitAutograd: randn + +- func: randn.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + CompositeExplicitAutograd: randn + +- func: randn.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + tags: nondeterministic_seeded + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: randn + autogen: randn.names_out + +- func: randn.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + tags: nondeterministic_seeded + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: randn + autogen: randn.generator_with_names_out + +- func: randn.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + tags: nondeterministic_seeded + +- func: randn.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + tags: nondeterministic_seeded + +- func: randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd, CompositeImplicitAutogradNestedTensor: randn_like + autogen: randn_like.out + +- func: randperm(SymInt n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + tags: [core, nondeterministic_seeded] + dispatch: + CompositeExplicitAutograd: randperm + +- func: randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + CompositeExplicitAutograd: randperm + +- func: randperm.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!) + tags: nondeterministic_seeded + dispatch: + CompositeExplicitAutograd: randperm_out + +- func: randperm.generator_out(SymInt n, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + tags: nondeterministic_seeded + dispatch: + CPU: randperm_out_cpu + CUDA: randperm_out_cuda + MPS: randperm_out_mps + +- func: range.step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: range + +- func: range(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: range + +- func: range.out_(Scalar start, Scalar end, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: range_out_no_step + +- func: range.out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!) 
+ dispatch: + CPU, Meta: range_out + CUDA: range_cuda_out + MPS: range_mps_out + cpp_no_default_args: ['step'] + +- func: ravel(Tensor(a) self) -> Tensor(a) + variants: function, method + +- func: reciprocal(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: reciprocal.out + variants: function, method + tags: [core, pointwise] + +- func: reciprocal_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: reciprocal.out + variants: function, method + tags: pointwise + +- func: reciprocal.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: reciprocal_out + MPS: reciprocal_out_mps + tags: pointwise + +- func: neg(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: neg.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: neg_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: neg_sparse_csr + NestedTensorCPU, NestedTensorCUDA: NestedTensor_neg + tags: [core, pointwise] + +- func: neg_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: neg.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: neg_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: neg_sparse_csr_ + NestedTensorCPU, NestedTensorCUDA: NestedTensor_neg_ + tags: pointwise + +- func: neg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: neg_out + MPS: neg_out_mps + SparseCPU, SparseCUDA: neg_out_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: neg_sparse_csr_out + tags: pointwise +# Alias for neg + +- func: negative(Tensor self) -> Tensor + variants: function, method + +- func: negative_(Tensor(a!) self) -> Tensor(a!) + variants: function, method + +- func: negative.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + +- func: repeat(Tensor self, SymInt[] repeats) -> Tensor + variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. + dispatch: + CompositeExplicitAutograd: repeat + MPS: repeat_mps + autogen: repeat.out + tags: core + +- func: repeat_interleave.Tensor(Tensor repeats, *, SymInt? output_size=None) -> Tensor + variants: function + dispatch: + CPU: repeat_interleave_cpu + CUDA: repeat_interleave_cuda + MPS: repeat_interleave_mps + tags: dynamic_output_shape + autogen: repeat_interleave.Tensor_out + +- func: repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor + variants: function, method + dispatch: + CompositeImplicitAutograd: repeat_interleave_symint + +- func: repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor + variants: function, method + dispatch: + CompositeImplicitAutograd: repeat_interleave_symint + +- func: reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeImplicitAutograd: reshape_symint + CompositeImplicitAutogradNestedTensor: reshape_nested_symint + +- func: _reshape_copy(Tensor self, SymInt[] size) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: _reshape_copy_symint + +# NOTE [ _reshape_alias ] is meant to be used in the implementation of reshape. 
+# They are not user-facing, hence the leading underscore. Please don't use it +# anywhere else. +- func: _reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA, ZeroTensor, MPS: _reshape_alias + # We don't need to support mkldnn since this is handled explicitly by the reshape operator. + +- func: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor + device_check: NoCheck + device_guard: False + dispatch: + MkldnnCPU: mkldnn_reshape + autogen: _mkldnn_reshape.out + +- func: reshape_as(Tensor(a) self, Tensor other) -> Tensor(a) + variants: method + device_check: NoCheck + device_guard: False + dispatch: + CompositeImplicitAutograd: reshape_as + CompositeImplicitAutogradNestedTensor: reshape_as_nested + +- func: round(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: round.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: round_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: round_sparse_csr + tags: [core, pointwise] + +- func: round_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: round.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: round_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: round_sparse_csr_ + tags: pointwise + +- func: round.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU: round_out + CUDA: round_out + MPS: round_out_mps + SparseCPU, SparseCUDA: round_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: round_sparse_csr_out + tags: pointwise + +- func: round.decimals(Tensor self, *, int decimals) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: round.decimals_out + variants: function, method + tags: pointwise + +- func: round_.decimals(Tensor(a!) self, *, int decimals) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: round.decimals_out + variants: function, method + tags: pointwise + +- func: round.decimals_out(Tensor self, *, int decimals, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU: round_decimals_out + CUDA: round_decimals_out + tags: pointwise + +- func: rrelu(Tensor self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor + device_check: NoCheck # TensorIterator + tags: nondeterministic_seeded + +- func: rrelu_(Tensor(a!) self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!) + tags: nondeterministic_seeded + device_check: NoCheck # TensorIterator + +- func: relu(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CPU, CUDA: relu + MPS: relu_mps + MkldnnCPU: mkldnn_relu + QuantizedCPU: relu_quantized_cpu + QuantizedCUDA: relu_quantized_cuda + NestedTensorCPU, NestedTensorCUDA: NestedTensor_relu + SparseCPU, SparseCUDA: relu_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: relu_sparse_csr + tags: [core, pointwise] + +- func: relu_(Tensor(a!) self) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CPU, CUDA: relu_ + MPS: relu_mps_ + MkldnnCPU: mkldnn_relu_ + QuantizedCPU: relu_quantized_cpu_ + QuantizedCUDA: relu_quantized_cuda_ + NestedTensorCPU, NestedTensorCUDA: NestedTensor_relu_ + SparseCPU, SparseCUDA: relu_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: relu_sparse_csr_ + autogen: relu.out + tags: pointwise + +- func: relu6(Tensor self) -> Tensor + python_module: nn + +- func: relu6_(Tensor(a!) self) -> Tensor(a!) + python_module: nn + +- func: prelu(Tensor self, Tensor weight) -> Tensor + variants: function, method + autogen: prelu.out + +- func: _prelu_kernel(Tensor self, Tensor weight) -> Tensor + dispatch: + CPU, CUDA: _prelu_kernel + QuantizedCPU: _prelu_kernel_quantized_cpu + MkldnnCPU: mkldnn_prelu + MPS: prelu_mps + +- func: _prelu_kernel_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor) + dispatch: + CPU, CUDA: _prelu_kernel_backward + MkldnnCPU: mkldnn_prelu_backward + MPS: prelu_backward_mps + +- func: gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + python_module: nn + dispatch: + CPU: gelu_out_cpu + CUDA: gelu_out_cuda + MPS: gelu_out_mps + +- func: gelu_(Tensor(a!) self, *, str approximate='none') -> Tensor(a!) + structured_delegate: gelu.out + device_check: NoCheck # TensorIterator + python_module: nn + dispatch: + QuantizedCPU: gelu_quantized_cpu_ + NestedTensorCPU, NestedTensorCUDA: NestedTensor_gelu_ + +- func: gelu(Tensor self, *, str approximate='none') -> Tensor + structured_delegate: gelu.out + device_check: NoCheck # TensorIterator + python_module: nn + dispatch: + MkldnnCPU: mkldnn_gelu + QuantizedCPU: gelu_quantized_cpu + QuantizedCUDA: gelu_quantized_cuda + NestedTensorCPU, NestedTensorCUDA: NestedTensor_gelu + tags: [core, pointwise] + +- func: gelu_backward.grad_input(Tensor grad_output, Tensor self, *, str approximate='none', Tensor(a!) grad_input) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + python_module: nn + dispatch: + CPU: gelu_backward_out_cpu + CUDA: gelu_backward_out_cuda + MPS: gelu_backward_out_mps + +- func: gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor + structured_delegate: gelu_backward.grad_input + python_module: nn + dispatch: + MkldnnCPU: mkldnn_gelu_backward + NestedTensorCPU, NestedTensorCUDA: gelu_backwards_nested + tags: pointwise + +- func: infinitely_differentiable_gelu_backward(Tensor grad, Tensor self) -> Tensor + variants: function + python_module: nn + device_check: NoCheck + device_guard: False + +- func: hardshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: hardshrink_out + +- func: hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor + structured_delegate: hardshrink.out + device_check: NoCheck # TensorIterator + variants: function, method + +- func: hardshrink_backward.grad_input(Tensor grad_out, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: hardshrink_backward_out + +- func: hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor + structured_delegate: hardshrink_backward.grad_input + variants: function, method + +- func: rsqrt(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: rsqrt.out + variants: function, method + tags: [core, pointwise] + +- func: rsqrt_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: rsqrt.out + variants: function, method + tags: pointwise + +- func: rsqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: rsqrt_out + MPS: rsqrt_out_mps + tags: pointwise + +- func: select.Dimname(Tensor(a) self, Dimname dim, int index) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + +- func: select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: select_symint + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: select_sparse_csr + NestedTensorCPU, NestedTensorCUDA: select_nested + tags: core + +- func: select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index) -> Tensor + variants: function + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutogradNonFunctional: select_backward_symint + autogen: select_backward.out + +- func: _nested_select_backward(Tensor grad_output, Tensor self, int dim, SymInt index) -> Tensor + variants: function + device_check: NoCheck + device_guard: False + dispatch: + NestedTensorCPU, NestedTensorCUDA: _nested_select_backward_symint + +- func: selu(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + +- func: selu_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + +- func: celu(Tensor self, Scalar alpha=1.0) -> Tensor + device_check: NoCheck # TensorIterator + dispatch: + CompositeExplicitAutograd: celu + +- func: celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!) + device_check: NoCheck # TensorIterator + dispatch: + CompositeExplicitAutograd: celu_ + autogen: celu.out + +- func: silu(Tensor self) -> Tensor + structured_delegate: silu.out + python_module: nn + dispatch: + NestedTensorCPU, NestedTensorCUDA: NestedTensor_silu + tags: pointwise + +- func: silu_(Tensor(a!) self) -> Tensor(a!) + structured_delegate: silu.out + python_module: nn + dispatch: + NestedTensorCPU, NestedTensorCUDA: NestedTensor_silu_ + tags: pointwise + +- func: silu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + python_module: nn + dispatch: + CPU, CUDA: silu_out + MPS: silu_out_mps + tags: pointwise + +- func: silu_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ structured: True + structured_inherits: TensorIteratorBase + python_module: nn + dispatch: + CPU, CUDA: silu_backward_out + MPS: silu_backward_out_mps + tags: pointwise + +- func: silu_backward(Tensor grad_output, Tensor self) -> Tensor + structured_delegate: silu_backward.grad_input + python_module: nn + dispatch: + CompositeImplicitAutograd: math_silu_backward + NestedTensorCPU, NestedTensorCUDA: silu_backward_nested + tags: pointwise + +- func: mish(Tensor self) -> Tensor + structured_delegate: mish.out + python_module: nn + +- func: mish_(Tensor(a!) self) -> Tensor(a!) + structured_delegate: mish.out + python_module: nn + +- func: mish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + python_module: nn + dispatch: + CPU, CUDA: mish_out + MPS: mish_out_mps + +- func: mish_backward(Tensor grad_output, Tensor self) -> Tensor + python_module: nn + dispatch: + CPU, CUDA: mish_backward + MPS: mish_backward_mps + CompositeImplicitAutograd: math_mish_backward + +- func: sigmoid(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: sigmoid.out + variants: function, method + dispatch: + QuantizedCPU: sigmoid_quantized_cpu + MkldnnCPU: mkldnn_sigmoid + tags: [core, pointwise] + +- func: sigmoid_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: sigmoid.out + variants: function, method + dispatch: + MkldnnCPU: mkldnn_sigmoid_ + tags: pointwise + +- func: sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: sigmoid_out + MPS: sigmoid_out_mps + tags: pointwise + +- func: logit(Tensor self, float? eps=None) -> Tensor + variants: function, method + dispatch: + CPU, CUDA: logit + MPS: logit_mps + tags: pointwise + +- func: logit_(Tensor(a!) self, float? eps=None) -> Tensor(a!) + variants: function, method + dispatch: + CPU, CUDA: logit_ + tags: pointwise + +- func: logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: logit_out + MPS: logit_out_mps + tags: pointwise + +- func: sin(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: sin.out + variants: function, method + dispatch: + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sin_sparse_csr + SparseCPU, SparseCUDA: sin_sparse + NestedTensorCPU, NestedTensorCUDA: sin_nested + tags: [core, pointwise] + +- func: sin_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: sin.out + variants: function, method + dispatch: + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sin_sparse_csr_ + SparseCPU, SparseCUDA: sin_sparse_ + tags: pointwise + +- func: sin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: sin_out + MPS: sin_out_mps + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sin_sparse_csr_out + SparseCPU, SparseCUDA: sin_sparse_out + tags: pointwise + +- func: sinc(Tensor self) -> Tensor + structured_delegate: sinc.out + variants: function, method + tags: pointwise + +- func: sinc_(Tensor(a!) self) -> Tensor(a!) + structured_delegate: sinc.out + variants: function, method + tags: pointwise + +- func: sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: sinc_out + tags: pointwise + +- func: sinh(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: sinh.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: sinh_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sinh_sparse_csr + tags: [core, pointwise] + +- func: sinh_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: sinh.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: sinh_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sinh_sparse_csr_ + tags: pointwise + +- func: sinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: sinh_out + MPS: sinh_out_mps + SparseCPU, SparseCUDA: sinh_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sinh_sparse_csr_out + tags: pointwise + +# Returns a copy of this `Variable` that is detached from its autograd graph. +# This method is OK to call if the `Variable` is a view. +# +# NOTE: Previously, if we changed the tensor metadata (e.g. sizes / strides / +# storage / storage_offset) of a tensor created from `detach()`, that metadata +# in the original tensor would also be updated. However, the new behavior is that +# metadata changes to the detached tensor no longer update the original tensor, +# and in the `detach()` function we need to set `allow_tensor_metadata_change_` +# to false to make such changes explicitly illegal, in order to prevent users from +# changing metadata of the detached tensor and expecting the original tensor to also +# be updated. +- func: detach(Tensor(a) self) -> Tensor(a) + variants: function, method + dispatch: + CompositeExplicitAutograd: detach + NestedTensorCPU, NestedTensorCUDA: detach + +# Like `detach()`, but modifies this `Variable` in-place. This method may +# only be called on non-view `Variable`s. You can use `is_view()` to check +# this. If this `Variable` is a view, throws an `std::runtime_error()`. +- func: detach_(Tensor(a!) self) -> Tensor(a!) + variants: function, method + tags: inplace_view + dispatch: + CompositeExplicitAutograd: detach_ + +- func: size.int(Tensor self, int dim) -> int + variants: function + device_check: NoCheck + device_guard: False + manual_cpp_binding: True + +- func: size.Dimname(Tensor self, Dimname dim) -> int + variants: function, method + device_check: NoCheck + device_guard: False + +- func: sym_size.int(Tensor self, int dim) -> SymInt + variants: function + device_check: NoCheck + device_guard: False + tags: core + manual_cpp_binding: True + +- func: sym_numel(Tensor self) -> SymInt + variants: function + device_check: NoCheck + device_guard: False + tags: core + manual_cpp_binding: True + +- func: sym_storage_offset(Tensor self) -> SymInt + variants: function + device_check: NoCheck + device_guard: False + tags: core + manual_cpp_binding: True + +- func: slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: slice + tags: core + +# NOTE: The implementation of split_with_sizes bypasses the dispatcher to call this; undo +# that if adding specific implementations here!
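+ +# A rough illustration (not part of the schema; the tensor and shapes below are made up for the example) of +# how the slice ops declared here surface on the Python side: +#   x = torch.randn(4, 10, requires_grad=True) +#   y = x[:, 2:8:2]     # basic indexing roughly lowers to slice.Tensor(x, dim=1, start=2, end=8, step=2) +#   y.sum().backward()  # autograd then routes through slice_backward below, which scatters the incoming +#                       # gradient into a zero tensor of the original input_sizes (x.shape here).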
+ +- func: slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor + variants: function + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: slice_backward + autogen: slice_backward.out + +# NB: This op exists to back the implementation of reverse view_funcs for various views (chunk, +# slice.Tensor, split_with_sizes, et al.). Currently, these are only used during fake-ification +# of PT2 graph input subclass instances that are views. This means: +# * This op shouldn't really show up in eager mode (so e.g. XLA shouldn't have to implement it) +# * This op shouldn't show up in a PT2 graph (so a PT2 backend shouldn't have to implement it) +# * A subclass will have to implement this to work in PT2 if a subclass view is used as a graph +# input AND the view utilizes this op in its inverse. The idea is that slice_inverse() is +# easier to implement for a subclass than as_strided() +- func: slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: slice_inverse_symint + +- func: slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutogradNonFunctional: slice_scatter + autogen: slice_scatter.out + tags: [core, view_copy] + +- func: select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutogradNonFunctional: select_scatter_symint + autogen: select_scatter.out + tags: core + +- func: diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutogradNonFunctional: diagonal_scatter + autogen: diagonal_scatter.out + +- func: as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutogradNonFunctional: as_strided_scatter_symint + autogen: as_strided_scatter.out + +- func: smm(Tensor self, Tensor mat2) -> Tensor + variants: function, method + +# softmax allows positional dtype, unlike most operators, because kwonly is BC-breaking when loading jit models. +- func: softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor + variants: function, method + +- func: softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!) + variants: function + dispatch: + CompositeExplicitAutograd: softmax_out + +- func: softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor + variants: function, method + +- func: _softmax(Tensor self, int dim, bool half_to_float) -> Tensor + structured_delegate: _softmax.out + dispatch: + MkldnnCPU: mkldnn_softmax + NestedTensorCPU, NestedTensorCUDA: softmax_nested + tags: core + +- func: _softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!) 
+ structured: True + dispatch: + CPU: softmax_cpu_out + CUDA: softmax_cuda_out + MPS: softmax_mps_out + +- func: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor + structured_delegate: _softmax_backward_data.out + dispatch: + NestedTensorCPU, NestedTensorCUDA: nested_softmax_backward + +- func: _softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) grad_input) -> Tensor(a!) + structured: True + dispatch: + CPU: softmax_backward_cpu_out + CUDA: softmax_backward_cuda_out + MPS: softmax_backward_mps_out + +- func: unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[] + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: unsafe_split + autogen: unsafe_split.Tensor_out + +- func: split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[] + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: split + +- func: split.sizes(Tensor(a -> *) self, SymInt[] split_size, int dim=0) -> Tensor(a)[] + variants: function, method + device_guard: False + dispatch: + CompositeImplicitAutograd: split_symint + +- func: unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[] + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: unsafe_split_with_sizes + autogen: unsafe_split_with_sizes.out + +- func: split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[] + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: split_with_sizes + NestedTensorCPU, NestedTensorCUDA: split_with_sizes_nested + tags: core + +- func: hsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[] + variants: function, method + +- func: hsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[] + variants: function, method + +- func: vsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[] + variants: function, method + +- func: vsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[] + variants: function, method + +- func: dsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[] + variants: function, method + +- func: dsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[] + variants: function, method + +- func: squeeze(Tensor(a) self) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: squeeze + QuantizedCPU, QuantizedCUDA: squeeze_quantized + NestedTensorCPU, NestedTensorCUDA: squeeze_nested + +- func: squeeze.dim(Tensor(a) self, int dim) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: squeeze + QuantizedCPU, QuantizedCUDA: squeeze_quantized + NestedTensorCPU, NestedTensorCUDA: squeeze_dim_nested + tags: core + +- func: squeeze.dimname(Tensor(a) self, Dimname dim) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + + +- func: squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: squeeze + QuantizedCPU, QuantizedCUDA: squeeze_quantized + NestedTensorCPU, NestedTensorCUDA: squeeze_dim_nested + tags: core + +- func: squeeze_(Tensor(a!) self) -> Tensor(a!) 
+ variants: method + device_check: NoCheck + device_guard: False + tags: inplace_view + dispatch: + CompositeExplicitAutograd: squeeze_ + +- func: squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!) + variants: method + device_check: NoCheck + device_guard: False + tags: inplace_view + dispatch: + CompositeExplicitAutograd: squeeze_ + +- func: squeeze_.dims(Tensor(a!) self, int[] dim) -> Tensor(a!) + variants: method + device_check: NoCheck + device_guard: False + tags: inplace_view + dispatch: + CompositeExplicitAutograd: squeeze_ + +- func: squeeze_.dimname(Tensor(a!) self, Dimname dim) -> Tensor(a!) + variants: method + device_check: NoCheck + device_guard: False + tags: inplace_view + +- func: sspaddmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + variants: function, method + +- func: sspaddmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU: _sspaddmm_out_only_sparse + CUDA: _sspaddmm_out_only_sparse_cuda + SparseCPU: _sspaddmm_out_cpu + SparseCUDA: _sspaddmm_out_cuda + +- func: _chunk_cat(Tensor[] tensors, int dim, int num_chunks) -> Tensor + dispatch: + CompositeExplicitAutograd: _chunk_cat + CUDA: _chunk_cat_cuda + +- func: _chunk_cat.out(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: _chunk_cat_out + CUDA: _chunk_cat_out_cuda + +- func: stack(Tensor[] tensors, int dim=0) -> Tensor + dispatch: + CompositeExplicitAutograd: stack + +- func: stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: stack_out + +- func: _stack(Tensor[] tensors, int dim=0) -> Tensor + dispatch: # match the backends supported by _cat + CPU: _stack_cpu + CompositeExplicitAutograd: _stack + +- func: _stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + dispatch: # match the backends supported by _cat_out + CPU: _stack_out_cpu + CompositeExplicitAutograd: _stack_out + +- func: hstack(Tensor[] tensors) -> Tensor + +- func: hstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + +- func: vstack(Tensor[] tensors) -> Tensor + +- func: vstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + +- func: dstack(Tensor[] tensors) -> Tensor + +- func: dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + +# Overload without center & pad mode, needed for forward-compatibility +- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor + variants: function, method + cpp_no_default_args: ['hop_length', 'win_length', 'window', 'normalized'] + +- func: stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor + variants: function, method + +- func: istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? 
length=None, bool return_complex=False) -> Tensor + variants: function, method + +- func: stride.int(Tensor self, int dim) -> int + variants: function + device_check: NoCheck + device_guard: False + manual_cpp_binding: True + +- func: stride.Dimname(Tensor self, Dimname dim) -> int + variants: function, method + device_check: NoCheck + device_guard: False + +- func: sym_stride.int(Tensor self, int dim) -> SymInt + variants: function + device_check: NoCheck + device_guard: False + tags: core + manual_cpp_binding: True + +- func: sum(Tensor self, *, ScalarType? dtype=None) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: sum + SparseCPU, SparseCUDA, SparseMeta: sum_coo + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sum_csr + autogen: sum.out + +- func: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + # TODO: Align the signature of sum.dim_IntList and _sparse_csr_sum.dim_dtype + structured_delegate: sum.IntList_out + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + NestedTensorCPU: NestedTensor_sum_dim_CPU + SparseCPU, SparseCUDA: sum_sparse_coo + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sum_sparse_compressed + tags: core + +- func: sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + +- func: sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + structured: True + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: sum_out + MPS: sum_out_mps + +- func: sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + +# TODO: this function will be replaced once nested expand semantics have been settled on +- func: _nested_sum_backward(Tensor grad, Tensor self, int[1]? dim, bool keepdim=False) -> Tensor + dispatch: + NestedTensorCPU: _nested_sum_backward_cpu + +- func: nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + variants: function, method + dispatch: + CPU, CUDA: nansum + MPS: nansum_mps + +- func: nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: nansum_out + MPS: nansum_out_mps + +- func: sum_to_size(Tensor self, SymInt[] size) -> Tensor + variants: method + device_check: NoCheck + device_guard: False + dispatch: + CompositeImplicitAutograd: sum_to_size_symint + +- func: sqrt(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: sqrt.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: sqrt_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sqrt_sparse_csr + tags: [core, pointwise] + +- func: sqrt_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: sqrt.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: sqrt_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sqrt_sparse_csr_ + tags: pointwise + +- func: sqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: sqrt_out + MPS: sqrt_out_mps + SparseCPU, SparseCUDA: sqrt_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sqrt_sparse_csr_out + tags: pointwise + +- func: square(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + tags: pointwise + +- func: square_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: function, method + tags: pointwise + +- func: square.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + tags: pointwise + +- func: std(Tensor self, bool unbiased=True) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + cpp_no_default_args: ["unbiased"] + +- func: std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + cpp_no_default_args: ["unbiased"] + +- func: std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CPU, CUDA: std + MPS: std_mps + QuantizedCPU: std_quantized_cpu + +- func: std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) + device_check: NoCheck # TensorIterator + variants: function + cpp_no_default_args: ["unbiased"] + +- func: std_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) + device_check: NoCheck # TensorIterator + variants: function + cpp_no_default_args: ["unbiased"] + +- func: std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) + device_check: NoCheck # TensorIterator + variants: function + dispatch: + CPU, CUDA: std_mean + MPS: std_mean_mps + autogen: std_mean.correction_out + +- func: std_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) + device_check: NoCheck # TensorIterator + variants: function + cpp_no_default_args: ["unbiased"] + +- func: std_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) + device_check: NoCheck # TensorIterator + variants: function + +- func: std.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + cpp_no_default_args: ["unbiased"] + +- func: std.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: std_out + QuantizedCPU: std_out_quantized_cpu + +- func: std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + cpp_no_default_args: ["unbiased"] + +- func: std.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + cpp_no_default_args: ["unbiased"] + +- func: std.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + +- func: std.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + variants: function + +- func: prod(Tensor self, *, ScalarType? dtype=None) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CPU, CUDA: prod + MPS: prod_mps + autogen: prod.out + tags: core + +- func: prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + structured_delegate: prod.int_out + device_check: NoCheck # TensorIterator + variants: function, method + tags: core + +- func: prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + structured: True + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: prod_out + MPS: prod_out_mps + +- func: prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + +- func: prod.Dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + +- func: t(Tensor(a) self) -> Tensor(a) + device_check: NoCheck + device_guard: False + variants: function, method + dispatch: + CompositeExplicitAutograd: t + +- func: t_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck + device_guard: False + variants: method + tags: inplace_view + dispatch: + CompositeExplicitAutograd: t_ + +- func: tan(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: tan.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: tan_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: tan_sparse_csr + tags: [core, pointwise] + +- func: tan_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: tan.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: tan_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: tan_sparse_csr_ + tags: pointwise + +- func: tan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: tan_out + MPS: tan_out_mps + SparseCPU, SparseCUDA: tan_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: tan_sparse_csr_out + tags: pointwise + +- func: tanh(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: tanh.out + variants: function, method + dispatch: + QuantizedCPU: tanh_quantized_cpu + MkldnnCPU: mkldnn_tanh + SparseCPU, SparseCUDA: tanh_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: tanh_sparse_csr + NestedTensorCPU, NestedTensorCUDA: NestedTensor_tanh + tags: [core, pointwise] + +- func: tanh_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: tanh.out + variants: function, method + dispatch: + MkldnnCPU: mkldnn_tanh_ + SparseCPU, SparseCUDA: tanh_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: tanh_sparse_csr_ + NestedTensorCPU, NestedTensorCUDA: NestedTensor_tanh_ + tags: pointwise + +- func: tanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: tanh_out + MPS: tanh_out_mps + SparseCPU, SparseCUDA: tanh_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: tanh_sparse_csr_out + tags: pointwise + +- func: tensordot(Tensor self, Tensor other, int[] dims_self, int[] dims_other) -> Tensor + variants: function + +- func: tensordot.out(Tensor self, Tensor other, int[] dims_self, int[] dims_other, *, Tensor(a!) out) -> Tensor(a!) + variants: function + +# TODO: namespace threshold in 'nn' +- func: threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor + device_check: NoCheck # TensorIterator + variants: function + structured_delegate: threshold.out + dispatch: + QuantizedCPU: threshold_quantized_cpu + +- func: threshold_(Tensor(a!) self, Scalar threshold, Scalar value) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: function + structured_delegate: threshold.out + +- func: threshold.out(Tensor self, Scalar threshold, Scalar value, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: threshold_out + MPS: threshold_out_mps + +- func: threshold_backward.grad_input(Tensor grad_output, Tensor self, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: threshold_backward_out + MPS: threshold_backward_out_mps + SparseCPU, SparseCUDA: threshold_backward_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: threshold_backward_sparse_compressed_out + +- func: threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor + variants: function + structured_delegate: threshold_backward.grad_input + dispatch: + MkldnnCPU: mkldnn_relu_backward + SparseCPU, SparseCUDA: threshold_backward_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: threshold_backward_sparse_compressed + NestedTensorCPU, NestedTensorCUDA: threshold_backwards_nested + tags: pointwise + +- func: tile(Tensor self, SymInt[] dims) -> Tensor + variants: function, method + dispatch: + CompositeImplicitAutograd: tile_symint + +- func: transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: transpose + NestedTensorCPU, NestedTensorCUDA: transpose_nested + +- func: transpose.Dimname(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + +- func: _mkldnn_transpose(Tensor self, int dim0, int dim1) -> Tensor + device_check: NoCheck + device_guard: False + dispatch: + MkldnnCPU: mkldnn_transpose + +- func: transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) + variants: method + device_check: NoCheck + device_guard: False + tags: inplace_view + dispatch: + CompositeExplicitAutograd: transpose_ + +- func: _mkldnn_transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) 
+ device_check: NoCheck + device_guard: False + dispatch: + MkldnnCPU: mkldnn_transpose_ + autogen: _mkldnn_transpose.out + +- func: one_hot(Tensor self, int num_classes=-1) -> Tensor + python_module: nn + variants: function + tags: dynamic_output_shape + +- func: flip(Tensor self, int[] dims) -> Tensor + variants: function, method + dispatch: + CPU, QuantizedCPU, CUDA, QuantizedCUDA: flip + MPS: flip_mps + autogen: flip.out + tags: core + +- func: fliplr(Tensor self) -> Tensor + variants: function, method + +- func: flipud(Tensor self) -> Tensor + variants: function, method + +- func: roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor + variants: function, method + dispatch: + CPU, MPS: roll + CUDA: roll_cuda + autogen: roll.out + +# default int[] value [0,1] should not add space after comma, since codegen parser uses ', ' to split args + +- func: rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor + variants: function, method + dispatch: + CompositeExplicitAutograd: rot90 + autogen: rot90.out + +- func: trapezoid.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor + +- func: trapezoid.dx(Tensor y, *, Scalar dx=1, int dim=-1) -> Tensor + +- func: trapz.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor + +- func: trapz.dx(Tensor y, *, float dx=1, int dim=-1) -> Tensor + +# Fused implementation detail for transformers. Adds in-projection bias to QKV and divides Q by sqrt(D/num_heads). +- func: _transform_bias_rescale_qkv(Tensor qkv, Tensor qkv_bias, int num_heads) -> (Tensor, Tensor, Tensor) + dispatch: + CPU, NestedTensorCPU: transform_bias_rescale_qkv_cpu + CUDA, NestedTensorCUDA: transform_bias_rescale_qkv_cuda + autogen: _transform_bias_rescale_qkv.out + +- func: _nested_tensor_from_mask(Tensor t, Tensor mask, bool mask_check=True) -> Tensor + dispatch: + CPU, CUDA: NestedTensor_nested_tensor_from_mask + autogen: _nested_tensor_from_mask.out + +- func: _nested_tensor_from_mask_left_aligned(Tensor t, Tensor mask) -> bool + dispatch: + CPU, CUDA: NestedTensor_nested_tensor_from_mask_left_aligned + +- func: _nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> Tensor + device_check: NoCheck # cpu_nested_shape_example will always be on CPU + dispatch: + CPU: nested_from_padded_generic + CUDA: nested_from_padded_cuda + autogen: _nested_from_padded.out + +# These private functions are temporary. They will be updated/deleted when nested tensors switch to using SymInts for their metadata representation +- func: _nested_tensor_size(Tensor self) -> Tensor + variants: method + dispatch: + NestedTensorCPU, NestedTensorCUDA: _nested_tensor_size + autogen: _nested_tensor_size.out + +- func: _nested_tensor_strides(Tensor self) -> Tensor + variants: method + dispatch: + NestedTensorCPU, NestedTensorCUDA: _nested_tensor_strides + autogen: _nested_tensor_strides.out + +- func: _nested_tensor_storage_offsets(Tensor self) -> Tensor + variants: method + dispatch: + NestedTensorCPU, NestedTensorCUDA, NestedTensorMeta: _nested_tensor_storage_offsets + autogen: _nested_tensor_storage_offsets.out + +# _nested_from_padded is not usable from Python, so +# _nested_from_padded_and_nested_example is available for testing. +- func: _nested_from_padded_and_nested_example(Tensor padded, Tensor nt_example) -> Tensor + dispatch: + NestedTensorCPU, NestedTensorCUDA: NestedTensor_from_padded_and_nested_example + autogen: _nested_from_padded_and_nested_example.out + +# The input arguments' types to this functions are temporary. 
When nested tensors switch to using SymInts for their metadata representation +# this will need to be updated +- func: _nested_view_from_buffer(Tensor(a) self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor(a) + variants: function + device_check: NoCheck + dispatch: + CPU, CUDA: _nested_view_from_buffer + +- func: _nested_view_from_buffer_copy(Tensor self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor + variants: function + device_check: NoCheck + tags: view_copy + dispatch: + CompositeExplicitAutogradNonFunctional: _nested_view_from_buffer_copy + autogen: _nested_view_from_buffer_copy.out + +- func: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor(a) + variants: function + device_check: NoCheck + dispatch: {} + +- func: _nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor + variants: function + device_check: NoCheck + tags: view_copy + dispatch: + CompositeExplicitAutogradNonFunctional: _nested_view_from_jagged_copy + autogen: _nested_view_from_jagged_copy.out + +- func: _nested_get_values(Tensor(a) self) -> Tensor(a) + variants: function + device_check: NoCheck + dispatch: {} + +- func: _nested_get_values_copy(Tensor self) -> Tensor + variants: function + device_check: NoCheck + tags: view_copy + dispatch: + CompositeExplicitAutogradNonFunctional: _nested_get_values_copy + autogen: _nested_get_values_copy.out + +- func: _nested_get_offsets(Tensor self) -> Tensor + variants: function + device_check: NoCheck + dispatch: {} + +# returns undefined Tensor if no lengths present +- func: _nested_get_lengths(Tensor self) -> Tensor + variants: function + device_check: NoCheck + dispatch: {} + +- func: _nested_get_ragged_idx(Tensor self) -> int + variants: function + device_check: NoCheck + dispatch: {} + +- func: _nested_get_min_seqlen(Tensor self) -> Tensor + variants: function + device_check: NoCheck + dispatch: {} + +- func: _nested_get_max_seqlen(Tensor self) -> Tensor + variants: function + device_check: NoCheck + dispatch: {} + +- func: _nested_get_jagged_dummy(Tensor any) -> Tensor + category_override: dummy + dispatch: {} + +- func: _nested_compute_contiguous_strides_offsets(Tensor nested_size) -> (Tensor, Tensor) + variants: function + device_check: NoCheck + dispatch: + CPU, CUDA: _nested_compute_contiguous_strides_offsets + +- func: _trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor + dispatch: + # calls unsqueeze + CompositeExplicitAutogradNonFunctional: _trilinear + autogen: _trilinear.out + +- func: triplet_margin_loss(Tensor anchor, Tensor positive, Tensor negative, float margin=1.0, float p=2, float eps=1e-06, bool swap=False, int reduction=Mean) -> Tensor + +- func: trunc(Tensor self) -> Tensor + structured_delegate: trunc.out + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + SparseCPU, SparseCUDA: trunc_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: trunc_sparse_csr + tags: [core, pointwise] + +- func: trunc_(Tensor(a!) self) -> Tensor(a!) 
+ structured_delegate: trunc.out + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + SparseCPU, SparseCUDA: trunc_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: trunc_sparse_csr_ + tags: pointwise + +- func: trunc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: trunc_out + MPS: trunc_out_mps + SparseCPU, SparseCUDA: trunc_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: trunc_sparse_csr_out + tags: pointwise +# Alias for trunc + +- func: fix(Tensor self) -> Tensor + variants: function, method + +- func: fix_(Tensor(a!) self) -> Tensor(a!) + variants: function, method + +- func: fix.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + +- func: type_as(Tensor self, Tensor other) -> Tensor + variants: method + +- func: _has_compatible_shallow_copy_type(Tensor self, Tensor from) -> bool + variants: function + +- func: _unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor) + variants: function + dispatch: + CPU: _unique_cpu + CUDA: _unique_cuda + autogen: _unique.out + +- func: unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + variants: function + dispatch: + CPU: unique_dim_cpu + CUDA: unique_dim_cuda + tags: dynamic_output_shape + autogen: unique_dim.out + +- func: unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor) + variants: function + dispatch: + CPU: unique_consecutive_cpu + CUDA: unique_consecutive_cuda + MPS: unique_consecutive_mps + tags: dynamic_output_shape + autogen: unique_consecutive.out + +- func: unique_dim_consecutive(Tensor self, int dim, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + variants: function + dispatch: + CPU: unique_dim_consecutive_cpu + CUDA: unique_dim_consecutive_cuda + MPS: unique_dim_consecutive_mps + tags: dynamic_output_shape + autogen: unique_dim_consecutive.out + +# _unique and _unique_dim are fragile and modifying them easily cause internal break +# the below operator is a temporary hack for adding return_counts support +# Please don't rely on these two operators, they will be removed soon + +- func: _unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + variants: function + dispatch: + CPU: _unique2_cpu + CUDA: _unique2_cuda + MPS: _unique2_mps + tags: dynamic_output_shape + autogen: _unique2.out + +- func: _unsafe_view(Tensor self, SymInt[] size) -> Tensor + dispatch: + CompositeExplicitAutograd: _unsafe_view + autogen: _unsafe_view.out + +- func: unsqueeze(Tensor(a) self, int dim) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: unsqueeze + SparseCPU, SparseCUDA: unsqueeze_sparse + QuantizedCPU, QuantizedCUDA: unsqueeze_quantized + NestedTensorCPU, NestedTensorCUDA: unsqueeze_nested + tags: core + +- func: unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!) + variants: method + device_check: NoCheck + device_guard: False + tags: inplace_view + dispatch: + CompositeExplicitAutograd: unsqueeze_ + +- func: vander(Tensor x, int? 
N=None, bool increasing=False) -> Tensor + +- func: var(Tensor self, bool unbiased=True) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + cpp_no_default_args: ["unbiased"] + +- func: var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + tags: core + cpp_no_default_args: ["unbiased"] + +- func: var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CPU, CUDA: var + MPS: var_mps + tags: core + +- func: var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + cpp_no_default_args: ["unbiased"] + +- func: var.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: var_out + +- func: var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + cpp_no_default_args: ["unbiased"] + +- func: var.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + cpp_no_default_args: ["unbiased"] + +- func: var.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + +- func: var.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: function + +- func: var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) + device_check: NoCheck # TensorIterator + variants: function + cpp_no_default_args: ["unbiased"] + +- func: var_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) + device_check: NoCheck # TensorIterator + variants: function + cpp_no_default_args: ["unbiased"] + +- func: var_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) + device_check: NoCheck # TensorIterator + variants: function + dispatch: + CPU, CUDA: var_mean + MPS: var_mean_mps + autogen: var_mean.correction_out + +- func: var_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) + device_check: NoCheck # TensorIterator + variants: function + cpp_no_default_args: ["unbiased"] + +- func: var_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) + device_check: NoCheck # TensorIterator + variants: function + +- func: view_as(Tensor(a) self, Tensor other) -> Tensor(a) + variants: method + device_check: NoCheck + device_guard: False + +- func: where.self(Tensor condition, Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CPU, CUDA, MPS: where + NestedTensorCPU, NestedTensorCUDA: NestedTensor_where + tags: [core, pointwise] + +- func: where.self_out(Tensor condition, Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA, MPS: where_self_out + +- func: where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor + variants: function + +- func: where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> Tensor + variants: function, method + +- func: where.Scalar(Tensor condition, Scalar self, Scalar other) -> Tensor + variants: function + +- func: where(Tensor condition) -> Tensor[] + device_check: NoCheck # TensorIterator + variants: function + +- func: norm_except_dim(Tensor v, int pow=2, int dim=0) -> Tensor + variants: function + +# VariableType::_weight_norm does not want to be given a gap in the autograd graph, +# so we don't define "dispatch" variants for it. +- func: _weight_norm(Tensor v, Tensor g, int dim=0) -> Tensor + variants: function + +- func: _weight_norm_interface(Tensor v, Tensor g, int dim=0) -> (Tensor, Tensor) + variants: function + dispatch: + CPU: weight_norm_cpu + CUDA: weight_norm_cuda + MPS: weight_norm_mps + autogen: _weight_norm_interface.out + +- func: _weight_norm_interface_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor) + variants: function + dispatch: + CPU: weight_norm_backward_cpu + CUDA: weight_norm_backward_cuda + MPS: weight_norm_backward_mps + autogen: _weight_norm_interface_backward.out + +- func: _weight_norm_differentiable_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor) + variants: function + +- func: zeros.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: zeros + autogen: zeros.names_out + +- func: _efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CPU: _efficientzerotensor + CUDA: _efficientzerotensor_cuda + MPS: _efficientzerotensor_mps + Meta: _efficientzerotensor_meta_symint + autogen: _efficientzerotensor.out + +- func: zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: zeros_symint + +- func: zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: zeros_out + SparseCPU, SparseCUDA, SparseMeta: zeros_sparse_out + +- func: zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd, CompositeImplicitAutogradNestedTensor: zeros_like + autogen: zeros_like.out + +- func: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor + variants: function + dispatch: + CPU: _standard_gamma_grad_cpu + CUDA: _standard_gamma_grad_cuda + autogen: _standard_gamma_grad.out + +- func: _standard_gamma(Tensor self, Generator? generator=None) -> Tensor + variants: function + dispatch: + CPU: _s_gamma_cpu + CUDA: _s_gamma_cuda + tags: nondeterministic_seeded + autogen: _standard_gamma.out + +- func: _dirichlet_grad(Tensor x, Tensor alpha, Tensor total) -> Tensor + dispatch: + CPU: _dirichlet_grad_cpu + CUDA: _dirichlet_grad_cuda + autogen: _dirichlet_grad.out + +- func: _sample_dirichlet(Tensor self, Generator? 
generator=None) -> Tensor + tags: nondeterministic_seeded + variants: function + dispatch: + CPU: _s_dirichlet_cpu + CUDA: _s_dirichlet_cuda + autogen: _sample_dirichlet.out + +- func: poisson(Tensor self, Generator? generator=None) -> Tensor + device_check: NoCheck # TensorIterator + dispatch: + CPU: _s_poisson_cpu + CUDA: _s_poisson_cuda + tags: nondeterministic_seeded + autogen: poisson.out + +- func: binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor + device_check: NoCheck # TensorIterator + dispatch: + CPU: _s_binomial_cpu + CUDA: _s_binomial_cuda + tags: nondeterministic_seeded + autogen: binomial.out + +# When more variants get ported to native, this dispatch will get more +# complicated + +- func: native_norm(Tensor self, Scalar p=2) -> Tensor + dispatch: + SparseCPU, SparseCUDA: norm_sparse + autogen: native_norm.out + +- func: native_norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype) -> Tensor + dispatch: + SparseCPU, SparseCUDA: norm_sparse + autogen: native_norm.ScalarOpt_dim_dtype_out + +- func: _batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor) + dispatch: + CPU: _batch_norm_with_update_cpu + CUDA: _batch_norm_with_update_cuda + MPS: _batch_norm_with_update_mps + MkldnnCPU: _batch_norm_with_update_mkldnn + autogen: _batch_norm_with_update_functional + +- func: _batch_norm_with_update.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd, Tensor(g!) reserve) -> (Tensor(d!), Tensor(e!), Tensor(f!), Tensor(g!)) + dispatch: + CPU: _batch_norm_with_update_cpu_out + CUDA: _batch_norm_with_update_cuda_out + MPS: _batch_norm_with_update_mps_out + +- func: _batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor) + dispatch: + CompositeExplicitAutograd: _batch_norm_no_update + autogen: _batch_norm_no_update.out + +- func: batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor) + dispatch: + CPU: _new_batch_norm_backward_cpu + CUDA: _new_batch_norm_backward_cuda + MPS: _new_batch_norm_backward_mps + MkldnnCPU: _new_batch_norm_backward_mkldnn + +# TODO: reduce signatures down to one when optional args is available +- func: _sparse_sum(Tensor self) -> Tensor + +- func: _sparse_sum.dtype(Tensor self, *, ScalarType dtype) -> Tensor + +- func: _sparse_sum.dim(Tensor self, int[1] dim) -> Tensor + dispatch: + CompositeExplicitAutograd: _sparse_sum + autogen: _sparse_sum.dim_out + +- func: _sparse_sum.dim_dtype(Tensor self, int[1] dim, *, ScalarType dtype) -> Tensor + +- func: _sparse_sum_backward(Tensor grad, Tensor self, int[] dim) -> Tensor + dispatch: + SparseCPU: _sparse_sum_backward_cpu + SparseCUDA: _sparse_sum_backward_cuda + autogen: _sparse_sum_backward.out + +- func: _sparse_csr_sum.dim_dtype(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor + dispatch: + SparseCsrCPU: _sparse_csr_sum_cpu + SparseCsrCUDA: _sparse_csr_sum_cuda + autogen: _sparse_csr_sum.dim_dtype_out + +- func: _sparse_csr_prod.dim_dtype(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + dispatch: + SparseCsrCPU: _sparse_csr_prod_cpu + SparseCsrCUDA: _sparse_csr_prod_cuda + autogen: _sparse_csr_prod.dim_dtype_out + +- func: _sparse_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor + python_module: sparse + variants: function + +- func: _sparse_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor + python_module: sparse + variants: function + +- func: _sparse_softmax(Tensor self, int dim, bool half_to_float) -> Tensor + python_module: sparse + dispatch: + SparseCPU: softmax_sparse_cpu + SparseCUDA: softmax_sparse_cuda + autogen: _sparse_softmax.out + +- func: _sparse_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor + dispatch: + SparseCPU: softmax_backward_sparse_cpu + SparseCUDA: softmax_backward_sparse_cuda + autogen: _sparse_softmax_backward_data.out + +- func: _sparse_log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor + python_module: sparse + variants: function + +- func: _sparse_log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor + python_module: sparse + variants: function + +- func: _sparse_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor + python_module: sparse + dispatch: + SparseCPU: log_softmax_sparse_cpu + SparseCUDA: log_softmax_sparse_cuda + autogen: _sparse_log_softmax.out + +- func: _sparse_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor + dispatch: + SparseCPU: log_softmax_backward_sparse_cpu + SparseCUDA: log_softmax_backward_sparse_cuda + autogen: _sparse_log_softmax_backward_data.out + +- func: _spdiags(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None) -> Tensor + python_module: sparse + dispatch: + CPU: spdiags + autogen: _spdiags.out + +- func: norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: norm + autogen: norm.ScalarOpt_dtype_out + +- func: norm.Scalar(Tensor self, Scalar p=2) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: norm + autogen: norm.Scalar_out + +- func: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor + structured_delegate: norm.dtype_out + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + SparseCPU, SparseCUDA: sparse_dtype_norm + +- func: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor + structured_delegate: norm.out + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + SparseCPU, SparseCUDA: sparse_norm + +- func: norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) + structured: True + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: norm_dtype_out + MPS: norm_dtype_out_mps + +- func: norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 
+ structured: True + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: norm_out + MPS: norm_out_mps + +# These four redispatch in their implementation, so OK to be CompositeImplicitAutograd +- func: norm.names_ScalarOpt_dim_dtype(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + +- func: norm.names_ScalarOpt_dim(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + +- func: norm.names_dtype_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + +- func: norm.names_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + +- func: frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent) + variants: method, function + dispatch: + CompositeExplicitAutograd: frexp + tags: pointwise + +- func: frexp.Tensor_out(Tensor self, *, Tensor(a!) mantissa, Tensor(b!) exponent) -> (Tensor(a!) mantissa, Tensor(b!) exponent) + dispatch: + CPU, CUDA: frexp_out + tags: pointwise + +# Deprecated (v.1.12) +- func: frobenius_norm.dim(Tensor self, int[1] dim, bool keepdim=False) -> Tensor + variants: function + +# Deprecated (v.1.12) +- func: frobenius_norm.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + variants: function + +# Deprecated (v.1.12) +- func: nuclear_norm(Tensor self, bool keepdim=False) -> Tensor + variants: function + +# Deprecated (v.1.12) +- func: nuclear_norm.out(Tensor self, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + variants: function + +# Deprecated (v.1.12) +- func: nuclear_norm.dim(Tensor self, int[2] dim, bool keepdim=False) -> Tensor + variants: function + +# Deprecated (v.1.12) +- func: nuclear_norm.dim_out(Tensor self, int[2] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + variants: function + +- func: clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor + variants: function, method + dispatch: + CompositeExplicitAutograd: clone + SparseCPU, SparseCUDA: clone_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: clone_sparse_compressed + MkldnnCPU: mkldnn_clone + QuantizedCPU, QuantizedCUDA: quantized_clone + NestedTensorCPU, NestedTensorCUDA: clone_nested + autogen: clone.out + tags: [core, pointwise] + +- func: positive(Tensor(a) self) -> Tensor(a) + variants: function, method + tags: pointwise + +- func: resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!) + use_const_ref_for_mutable_tensors: True + variants: function, method + dispatch: + CompositeExplicitAutograd: resize_as_ + autogen: resize_as, resize_as.out + tags: inplace_view + +- func: resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!) + use_const_ref_for_mutable_tensors: True + variants: function, method + dispatch: + SparseCPU, SparseCUDA: resize_as_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: resize_as_sparse_compressed_ + autogen: resize_as_sparse, resize_as_sparse.out + +- func: zero_(Tensor(a!) self) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + CPU, CUDA: zero_ + MPS: zero_mps_ + Meta: zero_meta_ + SparseCPU, SparseCUDA, SparseMeta: zero_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: zero_sparse_csr_ + MkldnnCPU: mkldnn_zero_ + NestedTensorCPU, NestedTensorCUDA: zero_nested_ + autogen: zero, zero.out + +- func: sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: sub_out + MPS: sub_out_mps + SparseCPU, SparseCUDA: sub_out_sparse + tags: pointwise + +- func: sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + structured_delegate: sub.out + dispatch: + SparseCPU, SparseCUDA: sub_sparse + ZeroTensor: sub_zerotensor + NestedTensorCPU, NestedTensorCUDA: NestedTensor_sub_Tensor + tags: [core, pointwise] + +- func: sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + structured_delegate: sub.out + dispatch: + SparseCPU, SparseCUDA: sub_sparse_ + tags: pointwise +# For C++ only, until we have conversion from C++ numbers to Tensor + +- func: sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: sub + tags: [core, pointwise] + +- func: sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CompositeExplicitAutograd: sub_ + autogen: sub.Scalar_out + tags: pointwise +# subtract, alias for sub + +- func: subtract.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + +- func: subtract.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + variants: function, method + +- func: subtract_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) + variants: method + +# For C++ only, until we have conversion from C++ numbers to Tensor +- func: subtract.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + variants: function, method + +- func: subtract_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) + variants: method + +- func: rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + device_check: NoCheck # TensorIterator + variants: function + dispatch: + CPU, CUDA: rsub + autogen: rsub.Tensor_out + +- func: heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: heaviside_out + tags: pointwise + +- func: heaviside(Tensor self, Tensor values) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + structured_delegate: heaviside.out + tags: pointwise + +- func: heaviside_(Tensor(a!) self, Tensor values) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + variants: method + structured_delegate: heaviside.out + +# For C++ only, until we have conversion from C++ numbers to Tensor +- func: rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + device_check: NoCheck # TensorIterator + variants: function + dispatch: + CompositeExplicitAutograd: rsub + autogen: rsub.Scalar_out + +# Functionally the same as addmm, but we give it a different derivative formula +# that doesn't propagate gradients to non-present entries on sparse. + tags: pointwise +- func: _sparse_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + python_module: sparse + dispatch: + CompositeExplicitAutograd: _sparse_addmm + autogen: _sparse_addmm.out + +- func: sparse_sampled_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + python_module: sparse + dispatch: + SparseCsrCUDA: sparse_sampled_addmm_out_sparse_csr_cuda + SparseCsrCPU: sparse_sampled_addmm_out_sparse_csr_cpu + +- func: sparse_sampled_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + python_module: sparse + dispatch: + SparseCsrCUDA: sparse_sampled_addmm_sparse_csr_cuda + SparseCsrCPU: sparse_sampled_addmm_sparse_csr_cpu + +- func: _sparse_mm_reduce_impl(Tensor self, Tensor other, str reduce) -> (Tensor, Tensor) + python_module: sparse + dispatch: + SparseCsrCPU: _sparse_mm_reduce_impl_sparse_csr_cpu + +- func: _sparse_mm_reduce_impl_backward(Tensor self, Tensor grad_out, Tensor weight, str reduce, Tensor arg_out, bool[2] output_mask) -> (Tensor, Tensor) + python_module: sparse + dispatch: + SparseCsrCPU: _sparse_mm_reduce_impl_backward_sparse_csr_cpu + +- func: addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + structured: True + dispatch: + CPU: addmm_out_cpu + CUDA: addmm_out_cuda + MPS: addmm_out_mps + SparseCPU: addmm_out_sparse_dense_cpu + SparseCUDA: addmm_out_sparse_dense_cuda + SparseCsrCPU: addmm_out_sparse_compressed_cpu + SparseCsrCUDA: addmm_out_sparse_compressed_cuda + +- func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + structured_delegate: addmm.out + variants: function, method + dispatch: + SparseCPU: addmm_sparse_dense_cpu + SparseCUDA: addmm_sparse_dense_cuda + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: addmm_sparse_compressed_dense + tags: core + +- func: addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) + structured_delegate: addmm.out + variants: method + dispatch: + # Warning! For whatever reason, the inplace sparse addmm is NON + # broadcasting + SparseCPU: s_addmm_sparse_dense_cpu_ + SparseCUDA: s_addmm_sparse_dense_cuda_ + +- func: _addmm_activation.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False, Tensor(a!) out) -> Tensor(a!) + structured: True + dispatch: + CPU: addmm_activation_out_cpu + CUDA: addmm_activation_out_cuda + +- func: _addmm_activation(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False) -> Tensor + structured_delegate: _addmm_activation.out + variants: function, method + +- func: _scaled_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? 
out_dtype=None, bool use_fast_accum=False) -> Tensor
+  variants: function
+  dispatch:
+    CUDA: _scaled_mm_cuda
+
+- func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CUDA: _scaled_mm_out_cuda
+
+# NOTE [ Sparse: autograd and API ]
+#
+#
+# Sparse Tensor Constructors
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# The API entry points to sparse tensor construction should be
+# `sparse_coo_tensor` and `_sparse_coo_tensor_unsafe`. Depending on whether the
+# indices and values tensors are given, they eventually dispatch to either
+# `sparse_coo_tensor_with_dims` or `sparse_coo_tensor_with_dims_and_tensors`.
+#
+# The autograd support for the ctor is implemented on `sparse_coo_tensor_with_dims_and_tensors`.
+#
+# The API methods `sparse_coo_tensor` and `_sparse_coo_tensor_unsafe`
+# **must not** have specific type dispatches because otherwise codegen will
+# consider them as abstract methods (see Note [Abstract ATen methods]), dispatch
+# using **Tensor** type, and thus lose autograd tracking on the actual method
+# they dispatch to, e.g., `sparse_coo_tensor_with_dims_and_tensors`.
+#
+#
+# Sparse Methods API Design
+# ~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Goals: 1. Flexible API for users to write custom sparse ops
+#        2. ctor and member accessor with autograd support
+#
+# To achieve 1, we need to provide a set of *dangerous* APIs (dangerous in the
+# sense that misusing them will break sparse tensor invariants and may result in
+# unexpected behavior, e.g., crash). These methods are all prefixed with
+# underscore "_" to indicate that they should be used with care. We provide:
+#
+# + `_indices()`: returns the *raw* indices within the sparse tensor (not just
+#                 sharing storage). Any inplace operation will change the
+#                 actual indices, including t_, set_, as_strided_, resize_,
+#                 etc.
+# + `_values()`: returns the *raw* values within the sparse tensor. Similar
+#                semantics as `_indices()`
+# + `_nnz()`: returns the number of non-zero entries. This will always be
+#             determined by the shapes of indices and values.
+# + `_coalesced_(bool)`: inplace sets whether the tensor is coalesced, and
+#                        returns itself.
+#
+# These methods are very useful in writing new operations, e.g., a custom
+# autograd Function.
+#
+# We also provide other public *safe* APIs:
+# + `indices()`: returns a **view** of the indices tensor if the sparse tensor
+#                is **coalesced**.
+# + `values()`: returns a **view** of the values tensor if the containing
+#               sparse tensor is **coalesced**.
+# + `sparse_dim()`: number of sparse dimensions
+# + `dense_dim()`: number of dense dimensions
+# + `is_coalesced()`: whether the sparse tensor is coalesced
+#
+# `_indices()` and `_values()` should return the raw indices and values dense
+# tensors within a sparse tensor. They can be quite unsafe with inplace
+# operations like `t_()`, and expose uncoalesced indices and values. The public
+# recommended API is `indices()` and `values()`, both of which first check that
+# the tensor is coalesced and return views on those tensors.
+#
+#
+# Autograd Support
+# ~~~~~~~~~~~~~~~~
+#
+# Autograd is supported on `values()` and sparse tensor ctor with indices and
+# values tensors. E.g., `torch.sparse_coo_tensor(i, v).values().sum()` is
+# differentiable w.r.t. `v`.
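+#
+# As an informal illustration of the API described above (a Python usage
+# sketch only, not part of the operator schema; it assumes nothing beyond the
+# public `torch.sparse_coo_tensor`, `coalesce()`, `values()`, `_indices()`,
+# `_values()` and `_nnz()` calls discussed in this note):
+#
+#   import torch
+#
+#   i = torch.tensor([[0, 1, 1],
+#                     [2, 0, 2]])                        # 2 sparse dims, 3 nnz
+#   v = torch.tensor([3., 4., 5.], requires_grad=True)
+#   s = torch.sparse_coo_tensor(i, v, (2, 3))
+#
+#   # Safe accessors require a coalesced tensor; `values()` returns an
+#   # autograd-tracked view, so gradients flow back to `v`.
+#   sc = s.coalesce()
+#   sc.values().sum().backward()
+#
+#   # Raw accessors work on uncoalesced tensors but bypass the safety checks.
+#   raw_i, raw_v, nnz = s._indices(), s._values(), s._nnz()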
+#
+# NB: The `values()` and `_values()` operators are special in that they are
+# layout-aware, i.e., the output depends not just on the data it represents, but
+# also on the input layout details (in this case, the `indices` tensor). See
+# NOTE [ as_strided Backward and layout-aware/agnostic autograd ] in Functions.cpp
+# for discussion on layout-aware vs layout-agnostic autograd. Since PyTorch ops
+# operate in the layout-agnostic mode, similar to `as_strided`, the backward of
+# these two operators needs to consider them in a layout-agnostic way:
+# + `values()`:
+#     Input is coalesced.
+#     We just pretend having `input.indices()` as an additional argument
+#     `input_indices`, then forward is similar to
+#     `input.to(kStrided).index_select(input_indices)` regardless of the layout.
+#     Note that `values()` normally is layout-aware even if we constrain
+#     ourselves on sparse inputs since it may include all zeros values entries
+#     as "present" entries.
+# + `_values()`:
+#     Input may be uncoalesced.
+#     It is not straightforward to construct a layout-agnostic version because
+#     duplicate indices entries may exist and additional parameterization is
+#     needed to distribute the value into different values entries. Furthermore,
+#     this op is intended to provide ways to write custom sparse ops, rather
+#     than being used in the autograd graph, so it is marked as *non-differentiable*
+#     in derivatives.yaml.
+#
+# Before reading the following, see NOTE [ Autograd Variable Views ] in
+# variable.h for details on views that are tracked by autograd, and views that
+# are not.
+#
+# Moreover, these methods return tensors that share storage with inputs, so we
+# mark these methods as view ops to support autograd history tracking.
+# The sparse tensor ctor output should technically be a view of both the input
+# indices and values tensors, but currently we only support setting it as a view
+# of a single Variable, so it is only a view of the values tensor.
+# TODO: clone indices in sparse tensor ctor.
+#
+# For other methods that return outputs that share storage with inputs, i.e.,
+# `indices()` and `_indices()`, we mark their outputs as non-differentiable, so
+# the view relation is not tracked by autograd, but the version counter is still
+# shared. In other words, their outputs are non-differentiable views of the
+# sparse tensor.
+# FIXME: would be nicer if TensorOptions was optional-based; not adding default arguments for options given
+# the default would never make sense.
+
+- func: _sparse_compressed_tensor_with_dims(int nnz, int dense_dim, int[] size, int[] blocksize, ScalarType index_dtype, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: sparse_compressed_tensor_with_dims
+
+- func: sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: sparse_compressed_tensor
+
+- func: sparse_csr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+- func: sparse_csc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool?
pin_memory=False) -> Tensor +- func: sparse_bsr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor +- func: sparse_bsc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + +- func: sparse_compressed_tensor.comp_plain_value(Tensor compressed_indices, Tensor plain_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + dispatch: + CompositeExplicitAutograd: sparse_compressed_tensor +- func: sparse_csr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor +- func: sparse_csc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor +- func: sparse_bsr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor +- func: sparse_bsc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + +- func: _sparse_compressed_tensor_unsafe(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeImplicitAutograd: _sparse_compressed_tensor_unsafe_symint + +- func: _sparse_csr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +- func: _sparse_csc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +- func: _sparse_bsr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +- func: _sparse_bsc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + +- func: sparse_coo_tensor.size(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + dispatch: + CompositeExplicitAutograd: sparse_coo_tensor + autogen: sparse_coo_tensor.size_out + +- func: sparse_coo_tensor.indices(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor + +- func: sparse_coo_tensor.indices_size(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor + +- func: _sparse_coo_tensor_unsafe(Tensor indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? 
is_coalesced=None) -> Tensor + dispatch: + CompositeImplicitAutograd: _sparse_coo_tensor_unsafe_symint + +- func: _validate_sparse_coo_tensor_args(Tensor indices, Tensor values, int[] size, bool? is_coalesced=None) -> () + +- func: _validate_sparse_compressed_tensor_args(Tensor compressed_indices, Tensor plain_indices, Tensor values, int[] size, Layout layout) -> () +- func: _validate_sparse_csr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size) -> () +- func: _validate_sparse_csc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size) -> () +- func: _validate_sparse_bsr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size) -> () +- func: _validate_sparse_bsc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size) -> () + +- func: _sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + dispatch: + SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_sparse + autogen: _sparse_coo_tensor_with_dims.out + +- func: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor + dispatch: + SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_and_tensor_sparse_symint + autogen: _sparse_coo_tensor_with_dims_and_tensors.out + +- func: sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!) + use_const_ref_for_mutable_tensors: True + variants: method + dispatch: + SparseCPU, SparseCUDA, SparseMeta: sparse_resize_ + autogen: sparse_resize, sparse_resize.out + +- func: sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!) + use_const_ref_for_mutable_tensors: True + variants: method + dispatch: + SparseCPU, SparseCUDA, SparseMeta: sparse_resize_and_clear_ + autogen: sparse_resize_and_clear, sparse_resize_and_clear.out + +- func: sparse_mask(Tensor self, Tensor mask) -> Tensor + variants: method + dispatch: + SparseCPU, SparseCUDA: sparse_mask + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_mask_sparse_compressed + autogen: sparse_mask.out + +- func: _sparse_mask_projection(Tensor self, Tensor mask, bool accumulate_matches=False) -> Tensor + variants: method + dispatch: + SparseCPU, SparseCUDA: sparse_mask_projection + autogen: _sparse_mask_projection.out + +- func: _to_cpu(Tensor[] tensors) -> Tensor[] + variants: function + +- func: to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor + variants: method + +# Special case of to_dense with custom derivative +- func: _to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor + variants: method + dispatch: + SparseCPU, SparseCUDA: sparse_to_dense + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_dense + MkldnnCPU: mkldnn_to_dense + autogen: _to_dense.out + +- func: to_dense_backward(Tensor grad, Tensor input, bool? 
masked_grad=None) -> Tensor + +- func: sparse_dim(Tensor self) -> int + variants: method + dispatch: + SparseCPU, SparseCUDA, SparseMeta: sparse_dim_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_dim_sparse_csr + CompositeExplicitAutograd: sparse_dim_default + device_check: NoCheck + device_guard: False + +# legacy method +- func: _dimI(Tensor self) -> int + variants: method + dispatch: + SparseCPU, SparseCUDA: sparse_dim_sparse + device_check: NoCheck + device_guard: False + +- func: dense_dim(Tensor self) -> int + variants: method + dispatch: + SparseCPU, SparseCUDA, SparseMeta: dense_dim_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: dense_dim_sparse_csr + CompositeExplicitAutograd: dense_dim_default + device_check: NoCheck + device_guard: False + +# legacy method +- func: _dimV(Tensor self) -> int + variants: method + dispatch: + SparseCPU, SparseCUDA, SparseMeta: dense_dim_sparse + device_check: NoCheck + device_guard: False + +- func: _nnz(Tensor self) -> int + variants: method + dispatch: + SparseCPU, SparseCUDA, SparseMeta: _nnz_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _nnz_sparse_csr + device_check: NoCheck + device_guard: False + +# NOTE: [ coalesce autograd ] +# coalesce returns self directly for already coalesced sparse tensors. +# This means coalesce cannot have a derivative registered, otherwise it creates +# circular references in the autograd graph (see gh-52874). +# Instead, the derivative is registered on the slow-path "_coalesce" +- func: coalesce(Tensor(a) self) -> Tensor(a) + variants: method + +- func: _coalesce(Tensor self) -> Tensor + dispatch: + SparseCPU: _coalesce_sparse_cpu + SparseCUDA: _coalesce_sparse_cuda + autogen: _coalesce.out + +- func: is_coalesced(Tensor self) -> bool + variants: method + dispatch: + SparseCPU, SparseCUDA, SparseMeta: is_coalesced_sparse + CompositeExplicitAutograd: is_coalesced_default + device_check: NoCheck + device_guard: False + +- func: _indices(Tensor(a) self) -> Tensor(a) + variants: method + dispatch: + SparseCPU, SparseCUDA, SparseMeta: _indices_sparse + device_check: NoCheck + device_guard: False + +- func: _values(Tensor(a) self) -> Tensor(a) + variants: method + dispatch: + SparseCPU, SparseCUDA, SparseMeta: _values_sparse + device_check: NoCheck + device_guard: False + +# This method doesn't do any check but only directly sets the flag. So it can be +# a bit unsafe. Similar to _indices and _values, this is useful for implementing +# custom sparse operations in Python/C++ extension. +- func: _coalesced_(Tensor(a!) self, bool coalesced) -> Tensor(a!) 
+ variants: method + dispatch: + SparseCPU, SparseCUDA, SparseMeta: _coalesced_sparse_ + device_check: NoCheck + device_guard: False + autogen: _coalesced, _coalesced.out + +- func: indices(Tensor(a) self) -> Tensor(a) + variants: method + dispatch: + SparseCPU, SparseCUDA, SparseMeta: indices_sparse + CompositeExplicitAutograd: indices_default + device_check: NoCheck + device_guard: False + +- func: values(Tensor(a) self) -> Tensor(a) + variants: method + dispatch: + SparseCPU, SparseCUDA, SparseMeta: values_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: values_sparse_csr + NestedTensorCPU, NestedTensorCUDA: values_nested + CompositeExplicitAutograd: values_default + device_check: NoCheck + device_guard: False + +- func: crow_indices(Tensor(a) self) -> Tensor(a) + variants: method + dispatch: + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: crow_indices_sparse_csr + CompositeExplicitAutograd: crow_indices_default + device_check: NoCheck + device_guard: False + +- func: col_indices(Tensor(a) self) -> Tensor(a) + variants: method + dispatch: + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: col_indices_sparse_csr + CompositeExplicitAutograd: col_indices_default + device_check: NoCheck + device_guard: False + +- func: ccol_indices(Tensor(a) self) -> Tensor(a) + variants: method + dispatch: + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: ccol_indices_sparse_csr + CompositeExplicitAutograd: ccol_indices_default + device_check: NoCheck + device_guard: False + +- func: row_indices(Tensor(a) self) -> Tensor(a) + variants: method + dispatch: + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: row_indices_sparse_csr + CompositeExplicitAutograd: row_indices_default + device_check: NoCheck + device_guard: False + +- func: hspmm.out(Tensor mat1, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + SparseCPU: hspmm_out_sparse_cpu + SparseCUDA: hspmm_out_sparse_cuda + +- func: hspmm(Tensor mat1, Tensor mat2) -> Tensor + dispatch: + SparseCPU: hspmm_sparse_cpu + SparseCUDA: hspmm_sparse_cuda + +- func: copy_sparse_to_sparse_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) + device_check: NoCheck # Allows copy into different device + variants: function + dispatch: + SparseCPU, SparseCUDA, SparseMeta: copy_sparse_ + autogen: copy_sparse_to_sparse, copy_sparse_to_sparse.out + +# By adding the AutogradNestedTensor this makes this function CompositeImplicit-like for nested tensors +- func: unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[] + variants: function, method + dispatch: + CompositeExplicitAutograd: unbind + NestedTensorCPU, NestedTensorCUDA: NestedTensor_unbind + +- func: unbind.Dimname(Tensor(a -> *) self, Dimname dim) -> Tensor(a)[] + variants: function, method + +- func: to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor + variants: method + +# Special case of to_sparse.sparse_dim with custom derivative +- func: _to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor + variants: method + dispatch: + CPU, CUDA: dense_to_sparse + SparseCPU, SparseCUDA: sparse_coo_to_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_sparse + autogen: _to_sparse.sparse_dim_out + +- func: to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor + variants: method + +# Special case of to_sparse with custom derivative +- func: _to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? 
dense_dim=None) -> Tensor + variants: method + dispatch: + CPU, CUDA: dense_to_sparse + SparseCPU, SparseCUDA: sparse_coo_to_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_sparse + autogen: _to_sparse.out + +- func: to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor + variants: method + +# Special case of to_sparse_csr with custom derivative +- func: _to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor + variants: method + dispatch: + CPU, CUDA: dense_to_sparse_csr + SparseCPU, SparseCUDA: coo_to_sparse_csr + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_sparse_csr + autogen: _to_sparse_csr.out + +- func: to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor + variants: method + +# Special case of to_sparse_csc with custom derivative +- func: _to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor + variants: method + dispatch: + CPU, CUDA: dense_to_sparse_csc + SparseCPU, SparseCUDA: coo_to_sparse_csc + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_sparse_csc + autogen: _to_sparse_csc.out + +- func: to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor + variants: method + +# Special case of to_sparse_bsr with custom derivative +- func: _to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor + variants: method + dispatch: + CPU, CUDA: dense_to_sparse_bsr + SparseCPU, SparseCUDA: coo_to_sparse_bsr + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_sparse_bsr + autogen: _to_sparse_bsr.out + +- func: to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor + variants: method + +# Special case of to_sparse_bsc with custom derivative +- func: _to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor + variants: method + dispatch: + CPU, CUDA: dense_to_sparse_bsc + SparseCPU, SparseCUDA: coo_to_sparse_bsc + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_sparse_bsc + autogen: _to_sparse_bsc.out + +- func: _to_sparse_semi_structured(Tensor dense) -> (Tensor, Tensor) + variants: function + dispatch: + CUDA: _to_sparse_semi_structured + +- func: to_mkldnn(Tensor self, ScalarType? dtype=None) -> Tensor + variants: method + dispatch: + CPU: dense_to_mkldnn + autogen: to_mkldnn.out + +- func: mkldnn_reorder_conv2d_weight(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None) -> Tensor + variants: function + python_module: nn + dispatch: + MkldnnCPU: mkldnn_reorder_conv2d_weight + autogen: mkldnn_reorder_conv2d_weight.out + +- func: mkldnn_reorder_conv3d_weight(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, SymInt[]? 
input_size=None) -> Tensor + variants: function + python_module: nn + dispatch: + MkldnnCPU: mkldnn_reorder_conv3d_weight + autogen: mkldnn_reorder_conv3d_weight.out + +- func: to_mkldnn_backward(Tensor grad, Tensor input) -> Tensor + +- func: quantize_per_tensor_dynamic(Tensor self, ScalarType dtype, bool reduce_range) -> Tensor + variants: function + dispatch: + CPU, CUDA: quantize_per_tensor_dynamic + autogen: quantize_per_tensor_dynamic.out + +- func: quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor + variants: function + dispatch: + CPU, CUDA: quantize_per_tensor + autogen: quantize_per_tensor.out + +- func: quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor + variants: function + dispatch: + CPU, CUDA: quantize_per_tensor_tensor_qparams + autogen: quantize_per_tensor.tensor_qparams_out + +- func: quantize_per_tensor.tensors(Tensor[] tensors, Tensor scales, Tensor zero_points, ScalarType dtype) -> Tensor[] + variants: function + dispatch: + CPU: quantize_per_tensor_list_cpu + autogen: quantize_per_tensor.tensors_out + +- func: quantize_per_channel(Tensor self, Tensor scales, Tensor zero_points, int axis, ScalarType dtype) -> Tensor + variants: function + dispatch: + CPU, CUDA: quantize_per_channel + autogen: quantize_per_channel.out + +- func: dequantize.self(Tensor self) -> Tensor + variants: function, method + dispatch: + CPU, CUDA: dequantize_cpu_or_cuda + QuantizedCPU, QuantizedCUDA: dequantize_quantized + autogen: dequantize.self_out + +- func: dequantize.tensors(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + QuantizedCPU: dequantize_tensors_quantized_cpu + autogen: dequantize.tensors_out + +- func: q_scale(Tensor self) -> float + variants: function, method + dispatch: + QuantizedCPU, QuantizedCUDA: q_scale_quant + +- func: q_zero_point(Tensor self) -> int + variants: function, method + dispatch: + QuantizedCPU, QuantizedCUDA: q_zero_point_quant + +- func: q_per_channel_scales(Tensor self) -> Tensor + variants: function, method + dispatch: + QuantizedCPU, QuantizedCUDA: q_per_channel_scales + autogen: q_per_channel_scales.out + +- func: q_per_channel_zero_points(Tensor self) -> Tensor + variants: function, method + dispatch: + QuantizedCPU, QuantizedCUDA: q_per_channel_zero_points + autogen: q_per_channel_zero_points.out + +- func: q_per_channel_axis(Tensor self) -> int + variants: function, method + dispatch: + QuantizedCPU, QuantizedCUDA: q_per_channel_axis + +- func: int_repr(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + QuantizedCPU: int_repr_quantized_cpu + QuantizedCUDA: int_repr_quantized_cuda + autogen: int_repr.out + +- func: _make_per_tensor_quantized_tensor(Tensor self, float scale, int zero_point) -> Tensor + dispatch: + CPU: make_per_tensor_quantized_tensor_cpu + CUDA: make_per_tensor_quantized_tensor_cuda + autogen: _make_per_tensor_quantized_tensor.out + +- func: _make_per_channel_quantized_tensor(Tensor self, Tensor scale, Tensor zero_point, int axis) -> Tensor + dispatch: + CPU: make_per_channel_quantized_tensor_cpu + CUDA: make_per_channel_quantized_tensor_cuda + autogen: _make_per_channel_quantized_tensor.out + +- func: qscheme(Tensor self) -> QScheme + variants: method + dispatch: + QuantizedCPU, QuantizedCUDA: qscheme_quant + +- func: fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor + device_check: NoCheck # 
TensorIterator + variants: function + +- func: fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor + device_check: NoCheck # TensorIterator + variants: function + +- func: fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask) + variants: function + dispatch: + CPU, CUDA: fake_quantize_per_tensor_affine_cachemask + autogen: fake_quantize_per_tensor_affine_cachemask.out + +- func: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, Tensor fake_quant_enabled, int quant_min, int quant_max) -> (Tensor output, Tensor mask) + variants: function + dispatch: + CPU, CUDA: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams + autogen: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams.out + +- func: fake_quantize_per_tensor_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor + variants: function + +- func: _fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor + variants: function + dispatch: + CPU, CUDA: _fake_quantize_learnable_per_tensor_affine + autogen: _fake_quantize_learnable_per_tensor_affine.out + +- func: _fake_quantize_learnable_per_tensor_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> (Tensor, Tensor, Tensor) + variants: function + dispatch: + CPU, CUDA: _fake_quantize_learnable_per_tensor_affine_backward + +- func: fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor + device_check: NoCheck # TensorIterator + variants: function + +- func: fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask) + variants: function + dispatch: + CPU, CUDA: fake_quantize_per_channel_affine_cachemask + autogen: fake_quantize_per_channel_affine_cachemask.out + +- func: fake_quantize_per_channel_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor + variants: function + +- func: _fake_quantize_learnable_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor + variants: function + dispatch: + CPU, CUDA: _fake_quantize_learnable_per_channel_affine + autogen: _fake_quantize_learnable_per_channel_affine.out + +- func: _fake_quantize_learnable_per_channel_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> (Tensor, Tensor, Tensor) + variants: function + dispatch: + CPU, CUDA: _fake_quantize_learnable_per_channel_affine_backward + +- func: fused_moving_avg_obs_fake_quant(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> Tensor + variants: function + +- func: _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) 
zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) + dispatch: + CPU: fused_moving_avg_obs_fake_quant_cpu + CUDA: fused_moving_avg_obs_fake_quant_cuda + autogen: _fused_moving_avg_obs_fq_helper_functional, _fused_moving_avg_obs_fq_helper.out + +- func: _choose_qparams_per_tensor(Tensor self, bool reduce_range=False) -> (float, int) + variants: function + +- func: _saturate_weight_to_fp16(Tensor weight) -> Tensor + variants: function + +- func: choose_qparams_optimized(Tensor input, int numel, int n_bins, float ratio, int bit_width) -> (Tensor, Tensor) + variants: function + +- func: _autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a) + variants: method + device_guard: False + +- func: _autocast_to_full_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled) -> Tensor(a) + variants: method + device_guard: False + +- func: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: _to_copy + NestedTensorCPU, NestedTensorCUDA: _to_copy_nested + autogen: _to_copy.out + tags: core + +# to(Device) must not exist because all constructors of Device also works for +# TensorOptions. Otherwise, an ambiguity error is thrown. +# See NOTE [ TensorOptions Constructors ]. +- func: to.dtype_layout(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) + variants: method + device_check: NoCheck + device_guard: False + +- func: to.device(Tensor(a) self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) + variants: method + device_check: NoCheck + device_guard: False + +- func: to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) + variants: method + device_check: NoCheck + device_guard: False + +- func: to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) + variants: method + device_check: NoCheck + device_guard: False + +- func: meshgrid(Tensor[] tensors) -> Tensor[] + +# TODO: Two weeks after this lands, combine these two overloads, +# making "indexing" optional. These are temporarily distinct for +# forward-compatibility reasons. 
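+# A short sketch of how the two overloads are called from Python (illustrative
+# only; assumes the public `torch.meshgrid` wrapper, whose legacy overload
+# behaves like indexing='ij'):
+#
+#   import torch
+#
+#   x = torch.arange(3)
+#   y = torch.arange(4)
+#   gx, gy = torch.meshgrid(x, y)                     # legacy overload (no `indexing`)
+#   gx2, gy2 = torch.meshgrid(x, y, indexing='ij')    # meshgrid.indexing overload
+#   assert torch.equal(gx, gx2) and torch.equal(gy, gy2)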
+- func: meshgrid.indexing(Tensor[] tensors, *, str indexing) -> Tensor[] + +- func: cartesian_prod(Tensor[] tensors) -> Tensor + variants: function + +- func: combinations(Tensor self, int r=2, bool with_replacement=False) -> Tensor + variants: function + +- func: item(Tensor self) -> Scalar + tags: data_dependent_output + variants: method + +- func: result_type.Tensor(Tensor tensor, Tensor other) -> ScalarType + variants: function + +- func: result_type.Scalar(Tensor tensor, Scalar other) -> ScalarType + variants: function + +- func: result_type.Scalar_Tensor(Scalar scalar, Tensor tensor) -> ScalarType + variants: function + +- func: result_type.Scalar_Scalar(Scalar scalar1, Scalar scalar2) -> ScalarType + +- func: can_cast(ScalarType from_, ScalarType to) -> bool + variants: function + +- func: promote_types(ScalarType type1, ScalarType type2) -> ScalarType + variants: function + +# NB: Does NOT check precondition that numel == 1 +- func: _local_scalar_dense(Tensor self) -> Scalar + tags: [core, data_dependent_output] + dispatch: + CPU: _local_scalar_dense_cpu + CUDA: _local_scalar_dense_cuda + MPS: _local_scalar_dense_mps + variants: function + +# MPS LSTM implementation + +- func: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) + dispatch: + MPS: _lstm_mps + autogen: _lstm_mps.out + tags: nondeterministic_seeded + +- func: lstm_mps_backward(Tensor? grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[]) + dispatch: + MPS: lstm_mps_backward + autogen: lstm_mps_backward.out + + +# Fused RNN kernels +- func: _thnn_fused_lstm_cell(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor, Tensor) + dispatch: + CUDA: _thnn_fused_lstm_cell_cuda + autogen: _thnn_fused_lstm_cell.out + +# NB: The composite version of this function below is a simple wrapper that duplicates some of the outputs +# It is necessary to avoid triggering TensorImpl use count checks in debug mode +# NB: this is function is NOT differentiable +- func: _thnn_fused_lstm_cell_backward_impl(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor) + dispatch: + CUDA: _thnn_fused_lstm_cell_backward_impl_cuda + autogen: _thnn_fused_lstm_cell_backward_impl.out + +- func: _thnn_fused_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + +- func: _thnn_differentiable_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor input_gates, Tensor hidden_gates, Tensor? input_bias, Tensor? hidden_bias, Tensor cx, Tensor cy) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + +- func: _thnn_fused_gru_cell(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? 
hidden_bias=None) -> (Tensor, Tensor) + dispatch: + CUDA: _thnn_fused_gru_cell_cuda + autogen: _thnn_fused_gru_cell.out + +- func: _thnn_fused_gru_cell_backward(Tensor grad_hy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + dispatch: + CUDA: _thnn_fused_gru_cell_backward_cuda + autogen: _thnn_fused_gru_cell_backward.out + +- func: _thnn_differentiable_gru_cell_backward(Tensor grad_hy, Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias, Tensor? hidden_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + +# RNN cells and layers +- func: lstm.input(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor) + tags: nondeterministic_seeded + +- func: lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor, Tensor) + tags: nondeterministic_seeded + +- func: gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) + tags: nondeterministic_seeded + +- func: gru.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor) + tags: nondeterministic_seeded + +- func: rnn_tanh.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) + tags: nondeterministic_seeded + +- func: rnn_tanh.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor) + tags: nondeterministic_seeded + +- func: rnn_relu.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) + tags: nondeterministic_seeded + +- func: rnn_relu.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor) + tags: nondeterministic_seeded + +- func: lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor) + +- func: gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor + +- func: rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor + +- func: rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor + +# Quantized RNN layer registration has been moved to C10 dispatch in `RNN.cpp` + +# Quantized RNN layers +# - func: quantized_lstm(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor) + + +# - func: quantized_lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, *, ScalarType? 
dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor) + + +# Quantized GRU layers + +# - func: quantized_gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) +# + +# - func: quantized_gru.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor) +# + +# Quantized RNN cells +- func: quantized_lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> (Tensor, Tensor) + +- func: quantized_gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor + +- func: quantized_rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor + +- func: quantized_rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor + +# PackedSequence utilities +- func: _pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor) + dispatch: + CompositeExplicitAutograd: _pack_padded_sequence + autogen: _pack_padded_sequence.out + +- func: _pack_padded_sequence_backward(Tensor grad, SymInt[] input_size, Tensor batch_sizes, bool batch_first) -> Tensor + dispatch: + CompositeImplicitAutograd: _pack_padded_sequence_backward_symint + +- func: _pad_packed_sequence(Tensor data, Tensor batch_sizes, bool batch_first, Scalar padding_value, int total_length) -> (Tensor, Tensor) + +# wrappers for legacy TH methods + +- func: set_.source_Storage(Tensor(a!) self, Storage source) -> Tensor(a!) + variants: method + device_check: NoCheck + device_guard: False + dispatch: + CPU, CUDA, Meta, MPS: set_ + autogen: set.source_Storage, set.source_Storage_out + tags: inplace_view + +- func: set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!) + variants: method + device_check: NoCheck + device_guard: False + dispatch: + CPU: set_storage_cpu_ + Meta: set_storage_meta__symint + CUDA: set_storage_cuda_ + MPS: set_storage_mps_ + QuantizedCPU, QuantizedCUDA: set_storage_quantized_ + autogen: set.source_Storage_storage_offset, set.source_Storage_storage_offset_out + tags: inplace_view + +- func: set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!) + variants: method + device_check: NoCheck + device_guard: False + dispatch: + CompositeImplicitAutograd: set__symint + tags: inplace_view + +- func: set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!) 
+ variants: method + device_check: NoCheck + device_guard: False + dispatch: + CPU, CUDA, Meta, MPS: set_tensor_ + autogen: set.source_Tensor, set.source_Tensor_out + tags: inplace_view + +- func: set_(Tensor(a!) self) -> Tensor(a!) + variants: method + dispatch: + CPU: set_cpu_ + CUDA: set_cuda_ + Meta: set_meta_ + MPS: set_mps_ + autogen: set, set.out + tags: inplace_view + +# Not making it CompositeImplicitAutograd because lift +# should be a primitive w.r.t. functorch + +# TODO: this should have a view annotation +# TODO: shouldn't be a method +- func: lift(Tensor self) -> Tensor + dispatch: + CompositeExplicitAutograd: lift + autogen: lift.out + +# lift_fresh is called with an argument that is guaranteed to be +# fresh (i.e., newly allocated). This is ONLY called from a +# torch.tensor call; if you FX trace a lift_fresh, you are obligated +# to convert this into a lift_fresh_copy (because FX will violate the +# freshness invariant when tracing). +- func: lift_fresh(Tensor(a) self) -> Tensor(a) + dispatch: + CompositeExplicitAutograd: lift_fresh + +# Like lift, but it clones the input. +- func: lift_fresh_copy(Tensor self) -> Tensor + tags: view_copy + dispatch: + CompositeExplicitAutogradNonFunctional: lift_fresh_copy + autogen: lift_fresh_copy.out + +- func: is_set_to(Tensor self, Tensor tensor) -> bool + variants: method + device_check: NoCheck + device_guard: False + dispatch: + CPU, CUDA, MPS: is_set_to + +- func: masked_fill_.Scalar(Tensor(a!) self, Tensor mask, Scalar value) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CPU: masked_fill__cpu + CUDA: masked_fill__cuda + QuantizedCPU: masked_fill__quantized_cpu + QuantizedCUDA: masked_fill__quantized_cuda + MPS: masked_fill__mps + autogen: masked_fill.Scalar_out + +- func: masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: masked_fill + NestedTensorCPU, NestedTensorCUDA: NestedTensor_masked_fill + tags: pointwise + +- func: masked_fill_.Tensor(Tensor(a!) self, Tensor mask, Tensor value) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CPU: masked_fill__cpu + CUDA: masked_fill__cuda + QuantizedCPU: masked_fill__quantized_cpu + QuantizedCUDA: masked_fill__quantized_cuda + MPS: masked_fill__mps + autogen: masked_fill.Tensor_out + +- func: masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: masked_fill + +- func: masked_scatter_(Tensor(a!) self, Tensor mask, Tensor source) -> Tensor(a!) + variants: method + dispatch: + CPU: masked_scatter__cpu + CUDA: masked_scatter__cuda + MPS: masked_scatter__mps + autogen: masked_scatter.out + +- func: masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor + variants: function, method + dispatch: + CompositeExplicitAutograd: masked_scatter + +- func: masked_scatter_backward(Tensor grad_output, Tensor mask, SymInt[] sizes) -> Tensor + dispatch: + CompositeExplicitAutograd: masked_scatter_backward_symint + +- func: _masked_softmax(Tensor self, Tensor mask, int? dim=None, int? mask_type=None) -> Tensor + dispatch: + CUDA: masked_softmax_cuda + CPU: masked_softmax_cpu + autogen: _masked_softmax.out + +- func: _masked_softmax_backward(Tensor grad_output, Tensor output, Tensor mask, int? 
dim=None) -> Tensor + dispatch: + CUDA: masked_softmax_backward_cuda + CPU: masked_softmax_backward_cpu + autogen: _masked_softmax_backward.out + +- func: view(Tensor(a) self, SymInt[] size) -> Tensor(a) + variants: method + device_check: NoCheck + device_guard: False + dispatch: + ZeroTensor, Meta, CPU, CUDA, QuantizedCPU, QuantizedCUDA, MPS: view + MkldnnCPU: mkldnn_view + NestedTensorCPU, NestedTensorCUDA: view_nested + tags: core + +# Warning: If you want to change the name or overload name of this +# operator, you might also want to change the `isBlockListedSchema` +# function in `torch/csrc/jit/frontend/schema_catching.cpp`. +# The name and overload name of this operator is hardcoded in that +# function in order to workaround a bug: +# https://github.com/pytorch/pytorch/issues/47964 +- func: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a) + variants: method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: view_dtype + +- func: put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!) + variants: method + dispatch: + CPU, CUDA: put_ + autogen: put.out + +- func: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor + variants: function, method + dispatch: + CompositeExplicitAutograd: put + +- func: index_add.out(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + structured: True + variants: function + precomputed: + - dim -> int dim + dispatch: + CPU: index_add_cpu_out + CUDA: index_add_cuda_out + MPS: index_add_mps_out + +- func: index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor(a!) + structured_delegate: index_add.out + variants: method + +- func: index_add(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor + structured_delegate: index_add.out + variants: function, method + +- func: index_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor + variants: function, method + +- func: index_reduce.out(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!) + structured: True + variants: function + precomputed: + - dim -> int dim + dispatch: + CPU: index_reduce_cpu_out + CUDA: index_reduce_cuda_out + +- func: index_reduce_(Tensor(a!) self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor(a!) + structured_delegate: index_reduce.out + variants: method + +- func: index_reduce(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor + structured_delegate: index_reduce.out + variants: function, method + +- func: index_fill_.int_Scalar(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CPU: index_fill_ + CUDA: index_fill_ + MPS: index_fill_mps_ + autogen: index_fill.int_Scalar_out + +- func: index_fill.int_Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: index_fill + +- func: index_fill_.int_Tensor(Tensor(a!) self, int dim, Tensor index, Tensor value) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + variants: method + dispatch: + CPU, CUDA: index_fill_ + MPS: index_fill_mps_ + autogen: index_fill.int_Tensor_out + +- func: index_fill.int_Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: index_fill + +- func: index_fill_.Dimname_Scalar(Tensor(a!) self, Dimname dim, Tensor index, Scalar value) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + +- func: index_fill_.Dimname_Tensor(Tensor(a!) self, Dimname dim, Tensor index, Tensor value) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + +- func: index_fill.Dimname_Scalar(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + +- func: index_fill.Dimname_Tensor(Tensor self, Dimname dim, Tensor index, Tensor value) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + +- func: scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor + structured_delegate: scatter.src_out + variants: function, method + tags: core + +- func: scatter_.src(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!) + structured_delegate: scatter.src_out + variants: method + +- func: scatter.src_out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!) + structured: True + variants: function + dispatch: + CPU, CUDA: scatter_src_out + MPS: scatter_src_out_mps + +- func: scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor + structured_delegate: scatter.value_out + variants: function, method + tags: core + +- func: scatter_.value(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!) + structured_delegate: scatter.value_out + variants: method + +- func: scatter.value_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!) + structured: True + variants: function + dispatch: + CPU, CUDA: scatter_value_out + MPS: scatter_value_out_mps + +- func: scatter.reduce(Tensor self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor + structured_delegate: scatter.reduce_out + variants: function, method + +- func: scatter_.reduce(Tensor(a!) self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor(a!) + structured_delegate: scatter.reduce_out + variants: method + +- func: scatter.reduce_out(Tensor self, int dim, Tensor index, Tensor src, *, str reduce, Tensor(a!) out) -> Tensor(a!) + structured: True + variants: function + dispatch: + CPU, CUDA: scatter_reduce_out + MPS: scatter_reduce_out_mps + +- func: scatter.value_reduce(Tensor self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor + structured_delegate: scatter.value_reduce_out + variants: function, method + +- func: scatter_.value_reduce(Tensor(a!) self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor(a!) + structured_delegate: scatter.value_reduce_out + variants: method + +- func: scatter.value_reduce_out(Tensor self, int dim, Tensor index, Scalar value, *, str reduce, Tensor(a!) out) -> Tensor(a!) 
+ structured: True + variants: function + dispatch: + CPU, CUDA: scatter_value_reduce_out + MPS: scatter_value_reduce_out_mps + +- func: scatter.dimname_src(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor + variants: function, method + +- func: scatter.dimname_value(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor + variants: function, method + +- func: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor + structured_delegate: scatter_add.out + variants: function, method + tags: core + +- func: scatter_add_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!) + structured_delegate: scatter_add.out + variants: method + +- func: scatter_add.out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!) + structured: True + variants: function + dispatch: + CPU, CUDA: scatter_add + MPS: scatter_add_mps_out + +- func: scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor + variants: function, method + +- func: scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor + structured_delegate: scatter_reduce.two_out + variants: function, method + tags: core + +- func: scatter_reduce_.two(Tensor(a!) self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor(a!) + structured_delegate: scatter_reduce.two_out + variants: method + +- func: scatter_reduce.two_out(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!) + structured: True + variants: function + dispatch: + CPU, CUDA: scatter_reduce_two + +- func: eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + structured_delegate: eq.Scalar_out + device_check: NoCheck # TensorIterator + variants: method + +- func: eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + structured_delegate: eq.Tensor_out + device_check: NoCheck # TensorIterator + variants: method + +- func: bitwise_and.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + variants: function + dispatch: + CPU, CUDA: bitwise_and_out + MPS: bitwise_and_out_mps + tags: pointwise + +- func: bitwise_and.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: function + dispatch: + CompositeExplicitAutograd: bitwise_and_out + tags: pointwise + +- func: bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + CompositeExplicitAutograd: bitwise_and + tags: [core, pointwise] + +- func: bitwise_and.Scalar_Tensor(Scalar self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: function + dispatch: + CompositeExplicitAutograd: bitwise_and + autogen: bitwise_and.Scalar_Tensor_out + tags: pointwise + +- func: bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + structured_delegate: bitwise_and.Tensor_out + tags: [core, pointwise] + +- func: bitwise_and_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CompositeExplicitAutograd: bitwise_and_ + tags: pointwise + +- func: bitwise_and_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + variants: method + structured_delegate: bitwise_and.Tensor_out + tags: pointwise + +- func: __and__.Scalar(Tensor self, Scalar other) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + +- func: __and__.Tensor(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + +- func: __iand__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + +- func: __iand__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + +- func: bitwise_or.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + variants: function + dispatch: + CPU, CUDA: bitwise_or_out + MPS: bitwise_or_out_mps + tags: pointwise + +- func: bitwise_or.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: function + dispatch: + CompositeExplicitAutograd: bitwise_or_out + tags: pointwise + +- func: bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + CompositeExplicitAutograd: bitwise_or + tags: [core, pointwise] + +- func: bitwise_or.Scalar_Tensor(Scalar self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: function + dispatch: + CompositeExplicitAutograd: bitwise_or + autogen: bitwise_or.Scalar_Tensor_out + tags: pointwise + +- func: bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + structured_delegate: bitwise_or.Tensor_out + tags: [core, pointwise] + +- func: bitwise_or_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CompositeExplicitAutograd: bitwise_or_ + tags: pointwise + +- func: bitwise_or_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + structured_delegate: bitwise_or.Tensor_out + tags: pointwise + +- func: __or__.Scalar(Tensor self, Scalar other) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + +- func: __or__.Tensor(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + +- func: __ior__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + +- func: __ior__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + +- func: bitwise_xor.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + variants: function + dispatch: + CPU, CUDA: bitwise_xor_out + MPS: bitwise_xor_out_mps + tags: pointwise + +- func: bitwise_xor.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + variants: function + dispatch: + CompositeExplicitAutograd: bitwise_xor_out + tags: pointwise + +- func: bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + CompositeExplicitAutograd: bitwise_xor + tags: [core, pointwise] + +- func: bitwise_xor.Scalar_Tensor(Scalar self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: function + dispatch: + CompositeExplicitAutograd: bitwise_xor + autogen: bitwise_xor.Scalar_Tensor_out + tags: pointwise + +- func: bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + structured_delegate: bitwise_xor.Tensor_out + tags: [core, pointwise] + +- func: bitwise_xor_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CompositeExplicitAutograd: bitwise_xor_ + tags: pointwise + +- func: bitwise_xor_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + structured_delegate: bitwise_xor.Tensor_out + tags: pointwise + +- func: __xor__.Scalar(Tensor self, Scalar other) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + tags: pointwise + +- func: __xor__.Tensor(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + tags: pointwise + +- func: __ixor__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + tags: pointwise + +- func: __ixor__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + tags: pointwise + +- func: __lshift__.Scalar(Tensor self, Scalar other) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + CPU, CUDA, MPS: __lshift__ + tags: pointwise + +- func: __lshift__.Tensor(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + CPU, CUDA, MPS: __lshift__ + tags: pointwise + +- func: __ilshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CPU, CUDA, MPS: __ilshift__ + autogen: __lshift__.Scalar_out + tags: pointwise + +- func: __ilshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CPU, CUDA, MPS: __ilshift__ + autogen: __lshift__.Tensor_out + tags: pointwise + +- func: bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + structured_delegate: bitwise_left_shift.Tensor_out + tags: pointwise + +- func: bitwise_left_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + structured_delegate: bitwise_left_shift.Tensor_out + tags: pointwise + +- func: bitwise_left_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA, MPS: bitwise_left_shift_out + tags: pointwise + +- func: bitwise_left_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + CompositeExplicitAutograd: bitwise_left_shift + tags: pointwise + +- func: bitwise_left_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CompositeExplicitAutograd: bitwise_left_shift_ + tags: pointwise + +- func: bitwise_left_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: function + dispatch: + CompositeExplicitAutograd: bitwise_left_shift_out + tags: pointwise + +- func: bitwise_left_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: function + dispatch: + CompositeExplicitAutograd: bitwise_left_shift + autogen: bitwise_left_shift.Scalar_Tensor_out + tags: pointwise + +- func: __rshift__.Scalar(Tensor self, Scalar other) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + CPU, CUDA, MPS: __rshift__ + tags: pointwise + +- func: __rshift__.Tensor(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + CPU, CUDA, MPS: __rshift__ + tags: pointwise + +- func: __irshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CPU, CUDA, MPS: __irshift__ + autogen: __rshift__.Scalar_out + +- func: __irshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CPU, CUDA, MPS: __irshift__ + autogen: __rshift__.Tensor_out + +- func: bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + structured_delegate: bitwise_right_shift.Tensor_out + tags: pointwise + +- func: bitwise_right_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + structured_delegate: bitwise_right_shift.Tensor_out + tags: pointwise + +- func: bitwise_right_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA, MPS: bitwise_right_shift_out + tags: pointwise + +- func: bitwise_right_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + CompositeExplicitAutograd: bitwise_right_shift + tags: pointwise + +- func: bitwise_right_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CompositeExplicitAutograd: bitwise_right_shift_ + tags: pointwise + +- func: bitwise_right_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + variants: function + dispatch: + CompositeExplicitAutograd: bitwise_right_shift_out + tags: pointwise + +- func: bitwise_right_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: function + dispatch: + CompositeExplicitAutograd: bitwise_right_shift + autogen: bitwise_right_shift.Scalar_Tensor_out + tags: pointwise + +- func: tril_(Tensor(a!) self, int diagonal=0) -> Tensor(a!) + structured_delegate: tril.out + variants: method + +- func: triu_(Tensor(a!) self, int diagonal=0) -> Tensor(a!) + structured_delegate: triu.out + variants: method + +- func: digamma_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: digamma.out + variants: method + tags: pointwise + +- func: lerp_.Scalar(Tensor(a!) self, Tensor end, Scalar weight) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + structured_delegate: lerp.Scalar_out + tags: pointwise + +- func: lerp_.Tensor(Tensor(a!) self, Tensor end, Tensor weight) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + structured_delegate: lerp.Tensor_out + tags: pointwise + +- func: addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) + variants: method + dispatch: + CPU, CUDA: addbmm_ + MPS: addbmm_mps_ + +- func: addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: addbmm_out + MPS: addbmm_out_mps + +- func: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + variants: method, function + dispatch: + CPU, CUDA: addbmm + MPS: addbmm_mps + +- func: random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + tags: nondeterministic_seeded + dispatch: + CPU, CUDA: random_ + Meta: random_meta_ + MPS: random_mps_ + autogen: random.from, random.from_out + +- func: random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!) + device_check: NoCheck # TensorIterator + tags: nondeterministic_seeded + variants: method + dispatch: + CPU, CUDA: random_ + Meta: random_meta_ + MPS: random_mps_ + autogen: random.to, random.to_out + +- func: random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!) + device_check: NoCheck # TensorIterator + tags: nondeterministic_seeded + variants: method + dispatch: + CPU, CUDA: random_ + MPS: random_mps_ + Meta: random_meta_ + autogen: random, random.out + +- func: uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!) + device_check: NoCheck # TensorIterator + tags: nondeterministic_seeded + variants: method + dispatch: + CPU, CUDA: uniform_ + MPS: uniform_mps_ + Meta: uniform_meta_ + autogen: uniform, uniform.out + +- func: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + tags: nondeterministic_seeded + dispatch: + CPU, CUDA: cauchy_ + autogen: cauchy, cauchy.out + +- func: log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!) + device_check: NoCheck # TensorIterator + tags: nondeterministic_seeded + variants: method + dispatch: + CPU, CUDA: log_normal_ + autogen: log_normal, log_normal.out + +- func: exponential_(Tensor(a!) 
self, float lambd=1, *, Generator? generator=None) -> Tensor(a!) + device_check: NoCheck # TensorIterator + tags: nondeterministic_seeded + variants: method + dispatch: + CPU, CUDA: exponential_ + MPS: exponential_mps_ + autogen: exponential, exponential.out + +- func: geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!) + device_check: NoCheck # TensorIterator + tags: nondeterministic_seeded + variants: method + dispatch: + CPU, CUDA: geometric_ + + # wrappers for TH functions + autogen: geometric, geometric.out + +- func: diag.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!) + +- func: diag(Tensor self, int diagonal=0) -> Tensor + variants: method, function + +- func: cross.out(Tensor self, Tensor other, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) + +- func: cross(Tensor self, Tensor other, int? dim=None) -> Tensor + variants: method, function + +- func: triu.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!) + structured: True + dispatch: + CPU: triu_cpu + CUDA: triu_cuda + MPS: triu_mps_out + +- func: triu(Tensor self, int diagonal=0) -> Tensor + structured_delegate: triu.out + variants: method, function + +- func: tril.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!) + structured: True + dispatch: + CPU: tril_cpu + CUDA: tril_cuda + MPS: tril_mps_out + +- func: tril(Tensor self, int diagonal=0) -> Tensor + structured_delegate: tril.out + variants: method, function + +- func: tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CPU: tril_indices_cpu + CUDA: tril_indices_cuda + autogen: tril_indices.out + +- func: triu_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CPU: triu_indices_cpu + CUDA: triu_indices_cuda + autogen: triu_indices.out + +- func: trace(Tensor self) -> Tensor + variants: method, function + dispatch: + CPU: trace_cpu + CUDA: trace_cuda + MPS: trace_mps + autogen: trace.out + +- func: trace_backward(Tensor grad, SymInt[] sizes) -> Tensor + variants: function + device_check: NoCheck + device_guard: False + dispatch: + CompositeImplicitAutograd: trace_backward_symint + +- func: ne.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: ne_Scalar_out + MPS: ne_scalar_out_mps + QuantizedCPU: ne_out_quantized_cpu + tags: pointwise + +- func: ne.Scalar(Tensor self, Scalar other) -> Tensor + structured_delegate: ne.Scalar_out + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + QuantizedCPU: ne_quantized_cpu + tags: [core, pointwise] + +- func: ne.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: ne_Tensor_out + MPS: ne_tensor_out_mps + QuantizedCPU: ne_out_quantized_cpu + tags: pointwise + +- func: ne.Tensor(Tensor self, Tensor other) -> Tensor + structured_delegate: ne.Tensor_out + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + QuantizedCPU: ne_quantized_cpu + tags: [core, pointwise] + +- func: ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 
+ structured_delegate: ne.Scalar_out + device_check: NoCheck # TensorIterator + variants: method + +- func: ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + structured_delegate: ne.Tensor_out + device_check: NoCheck # TensorIterator + variants: method + +# not_equal, alias for torch.ne +- func: not_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + +- func: not_equal.Scalar(Tensor self, Scalar other) -> Tensor + variants: method, function + +- func: not_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + +- func: not_equal.Tensor(Tensor self, Tensor other) -> Tensor + variants: method, function + +- func: not_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + variants: method + +- func: not_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + variants: method + +- func: eq.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: eq_Scalar_out + MPS: eq_scalar_out_mps + QuantizedCPU: eq_out_quantized_cpu + tags: pointwise + +- func: eq.Scalar(Tensor self, Scalar other) -> Tensor + structured_delegate: eq.Scalar_out + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + QuantizedCPU: eq_quantized_cpu + NestedTensorCPU, NestedTensorCUDA: eq_scalar_nested + tags: [core, pointwise] + +- func: eq.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: eq_Tensor_out + MPS: eq_tensor_out_mps + QuantizedCPU: eq_out_quantized_cpu + tags: pointwise + +- func: eq.Tensor(Tensor self, Tensor other) -> Tensor + structured_delegate: eq.Tensor_out + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + QuantizedCPU: eq_quantized_cpu + NestedTensorCPU, NestedTensorCUDA: eq_tensor_nested + tags: [core, pointwise] + +- func: ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: ge_Scalar_out + MPS: ge_scalar_out_mps + QuantizedCPU: ge_out_quantized_cpu + tags: pointwise + +- func: ge.Scalar(Tensor self, Scalar other) -> Tensor + structured_delegate: ge.Scalar_out + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + QuantizedCPU: ge_quantized_cpu + NestedTensorCPU, NestedTensorCUDA: ge_scalar_nested + tags: [core, pointwise] + +- func: ge.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: ge_Tensor_out + MPS: ge_tensor_out_mps + QuantizedCPU: ge_out_quantized_cpu + tags: pointwise + +- func: ge.Tensor(Tensor self, Tensor other) -> Tensor + structured_delegate: ge.Tensor_out + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + QuantizedCPU: ge_quantized_cpu + tags: [core, pointwise] + +- func: ge_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + structured_delegate: ge.Scalar_out + device_check: NoCheck # TensorIterator + variants: method + +- func: ge_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+ structured_delegate: ge.Tensor_out + device_check: NoCheck # TensorIterator + variants: method + +# greater_equal, alias for torch.ge +- func: greater_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + +- func: greater_equal.Scalar(Tensor self, Scalar other) -> Tensor + variants: method, function + +- func: greater_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + +- func: greater_equal.Tensor(Tensor self, Tensor other) -> Tensor + variants: method, function + +- func: greater_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + variants: method + +- func: greater_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + variants: method + +- func: le.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: le_Scalar_out + MPS: le_scalar_out_mps + QuantizedCPU: le_out_quantized_cpu + tags: pointwise + +- func: le.Scalar(Tensor self, Scalar other) -> Tensor + structured_delegate: le.Scalar_out + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + QuantizedCPU: le_quantized_cpu + tags: [core, pointwise] + +- func: le.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: le_Tensor_out + MPS: le_tensor_out_mps + QuantizedCPU: le_out_quantized_cpu + tags: pointwise + +- func: le.Tensor(Tensor self, Tensor other) -> Tensor + structured_delegate: le.Tensor_out + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + QuantizedCPU: le_quantized_cpu + tags: [core, pointwise] + +- func: le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + structured_delegate: le.Scalar_out + device_check: NoCheck # TensorIterator + variants: method + +- func: le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + structured_delegate: le.Tensor_out + device_check: NoCheck # TensorIterator + variants: method + +# less_equal, alias for torch.le +- func: less_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + +- func: less_equal.Scalar(Tensor self, Scalar other) -> Tensor + variants: method, function + +- func: less_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + +- func: less_equal.Tensor(Tensor self, Tensor other) -> Tensor + variants: method, function + +- func: less_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + variants: method + +- func: less_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + variants: method + +- func: gt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: gt_Scalar_out + MPS: gt_scalar_out_mps + QuantizedCPU: gt_out_quantized_cpu + tags: pointwise + +- func: gt.Scalar(Tensor self, Scalar other) -> Tensor + structured_delegate: gt.Scalar_out + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + QuantizedCPU: gt_quantized_cpu + NestedTensorCPU, NestedTensorCUDA: gt_scalar_nested + tags: [core, pointwise] + +- func: gt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: gt_Tensor_out + MPS: gt_tensor_out_mps + QuantizedCPU: gt_out_quantized_cpu + tags: pointwise + +- func: gt.Tensor(Tensor self, Tensor other) -> Tensor + structured_delegate: gt.Tensor_out + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + QuantizedCPU: gt_quantized_cpu + tags: [core, pointwise] + +- func: gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + structured_delegate: gt.Scalar_out + device_check: NoCheck # TensorIterator + variants: method + +- func: gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + structured_delegate: gt.Tensor_out + device_check: NoCheck # TensorIterator + variants: method + +# greater, alias for torch.gt +- func: greater.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + +- func: greater.Scalar(Tensor self, Scalar other) -> Tensor + variants: method, function + +- func: greater.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + +- func: greater.Tensor(Tensor self, Tensor other) -> Tensor + variants: method, function + +- func: greater_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + variants: method + +- func: greater_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + variants: method + +- func: lt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: lt_Scalar_out + MPS: lt_scalar_out_mps + QuantizedCPU: lt_out_quantized_cpu + tags: pointwise + +- func: lt.Scalar(Tensor self, Scalar other) -> Tensor + structured_delegate: lt.Scalar_out + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + QuantizedCPU: lt_quantized_cpu + tags: [core, pointwise] + +- func: lt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: lt_Tensor_out + MPS: lt_tensor_out_mps + QuantizedCPU: lt_out_quantized_cpu + tags: pointwise + +- func: lt.Tensor(Tensor self, Tensor other) -> Tensor + structured_delegate: lt.Tensor_out + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + QuantizedCPU: lt_quantized_cpu + tags: [core, pointwise] + +- func: lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + structured_delegate: lt.Scalar_out + device_check: NoCheck # TensorIterator + variants: method + +- func: lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + structured_delegate: lt.Tensor_out + device_check: NoCheck # TensorIterator + variants: method + +# less, alias for torch.lt +- func: less.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + +- func: less.Scalar(Tensor self, Scalar other) -> Tensor + variants: method, function + +- func: less.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + +- func: less.Tensor(Tensor self, Tensor other) -> Tensor + variants: method, function + +- func: less_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + variants: method + +- func: less_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + variants: method + +- func: take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!) 
+ dispatch: + CPU, CUDA: take_out + +- func: take(Tensor self, Tensor index) -> Tensor + variants: method, function + dispatch: + CPU, CUDA: take + +- func: take_along_dim.out(Tensor self, Tensor indices, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) + +- func: take_along_dim(Tensor self, Tensor indices, int? dim=None) -> Tensor + variants: method, function + +- func: index_select.out(Tensor self, int dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, QuantizedCPU: index_select_out_cpu_ + CUDA, QuantizedCUDA: index_select_out_cuda + MPS: index_select_out_mps + +- func: index_select(Tensor self, int dim, Tensor index) -> Tensor + variants: method, function + dispatch: + CPU: index_select_cpu_ + QuantizedCPU: index_select_quantized_cpu_ + CUDA: index_select_cuda + QuantizedCUDA: index_select_quantized_cuda + SparseCPU: index_select_sparse_cpu + SparseCUDA: index_select_sparse_cuda + MPS: index_select_mps + tags: core + +- func: index_select.dimname_out(Tensor self, Dimname dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!) + +- func: index_select.dimname(Tensor self, Dimname dim, Tensor index) -> Tensor + variants: method, function + +- func: index_select_backward(Tensor grad, SymInt[] self_sizes, int dim, Tensor index) -> Tensor + variants: function + device_check: NoCheck + device_guard: False + dispatch: + CompositeImplicitAutograd: index_select_backward_symint + +- func: masked_select.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU: masked_select_out_cpu + CUDA: masked_select_out_cuda + MPS: masked_select_out_mps + tags: dynamic_output_shape + +- func: masked_select(Tensor self, Tensor mask) -> Tensor + variants: method, function + dispatch: + CPU: masked_select_cpu + CUDA: masked_select_cuda + MPS: masked_select_mps + tags: dynamic_output_shape + +- func: masked_select_backward(Tensor grad, Tensor input, Tensor mask) -> Tensor + variants: function + device_check: NoCheck + device_guard: False + +- func: nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU: nonzero_out_cpu + CUDA: nonzero_out_cuda + MPS: nonzero_out_mps + tags: dynamic_output_shape + +- func: nonzero(Tensor self) -> Tensor + variants: method, function + dispatch: + CPU: nonzero_cpu + CUDA: nonzero_cuda + MPS: nonzero_mps + tags: [dynamic_output_shape, core] + +- func: nonzero_static.out(Tensor self, *, int size, int fill_value=-1, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU: nonzero_static_out_cpu + +- func: nonzero_static(Tensor self, *, int size, int fill_value=-1) -> Tensor + variants: method, function + dispatch: + CPU: nonzero_static_cpu + +- func: nonzero_numpy(Tensor self) -> Tensor[] + variants: method, function + +- func: argwhere(Tensor self) -> Tensor + variants: method, function + tags: dynamic_output_shape + +- func: gather.out(Tensor self, int dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!) + structured: True + dispatch: + CPU, CUDA: gather_out + MPS: gather_out_mps + +- func: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor + variants: method, function + structured_delegate: gather.out + tags: core + +- func: gather_backward(Tensor grad, Tensor self, int dim, Tensor index, bool sparse_grad) -> Tensor + variants: function + device_check: NoCheck + device_guard: False + +- func: gather.dimname_out(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!) 
+ +- func: gather.dimname(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False) -> Tensor + variants: method, function + +- func: _gather_sparse_backward(Tensor self, int dim, Tensor index, Tensor grad) -> Tensor + +- func: addcmul.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: addcmul_out + MPS: addcmul_out_mps + tags: pointwise + +- func: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor + structured_delegate: addcmul.out + device_check: NoCheck # TensorIterator + variants: method, function + tags: pointwise + +- func: addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!) + structured_delegate: addcmul.out + device_check: NoCheck # TensorIterator + variants: method + tags: pointwise + +- func: addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: addcdiv_out + MPS: addcdiv_out_mps + tags: pointwise + +- func: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor + structured_delegate: addcdiv.out + device_check: NoCheck # TensorIterator + variants: method, function + tags: pointwise + +- func: addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!) + structured_delegate: addcdiv.out + device_check: NoCheck # TensorIterator + variants: method + tags: pointwise + +- func: cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor + python_module: nn + dispatch: + CompositeImplicitAutograd: cross_entropy_loss_symint + +- func: triangular_solve.X(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> (Tensor(a!) solution, Tensor(b!) cloned_coefficient) + structured: True + dispatch: + CPU, CUDA: triangular_solve_out + MPS: triangular_solve_mps_out + SparseCsrCPU: triangular_solve_out_sparse_csr_cpu + SparseCsrCUDA: triangular_solve_out_sparse_csr_cuda + +- func: triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor solution, Tensor cloned_coefficient) + structured_delegate: triangular_solve.X + variants: method, function + +- func: _linalg_check_errors(Tensor info, str api_name, *, bool is_matrix) -> () + dispatch: + CompositeExplicitAutograd: _linalg_check_errors + +- func: linalg_solve_triangular.out(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + dispatch: + CPU, CUDA: linalg_solve_triangular_out + MPS: linalg_solve_triangular_mps_out + +- func: linalg_solve_triangular(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False) -> Tensor + python_module: linalg + variants: function + dispatch: + CPU, CUDA: linalg_solve_triangular + MPS: linalg_solve_triangular_mps + +- func: linalg_vander(Tensor x, *, SymInt? N=None) -> Tensor + python_module: linalg + dispatch: + CompositeImplicitAutograd: linalg_vander_symint + +- func: svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) 
V) + +- func: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) + variants: method, function + +# swapaxes, alias for transpose +- func: swapaxes(Tensor(a) self, int axis0, int axis1) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + +- func: swapaxes_(Tensor(a!) self, int axis0, int axis1) -> Tensor(a!) + variants: method + device_check: NoCheck + device_guard: False + tags: inplace_view + +# swapdims, alias for transpose +- func: swapdims(Tensor(a) self, int dim0, int dim1) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + +- func: swapdims_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) + variants: method + device_check: NoCheck + device_guard: False + tags: inplace_view + +- func: cholesky.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: cholesky_out + +- func: cholesky(Tensor self, bool upper=False) -> Tensor + variants: method, function + dispatch: + CPU, CUDA: cholesky + +- func: cholesky_solve.out(Tensor self, Tensor input2, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: cholesky_solve_out + +- func: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor + variants: method, function + dispatch: + CompositeExplicitAutograd: cholesky_solve + +- func: _cholesky_solve_helper(Tensor self, Tensor A, bool upper) -> Tensor + variants: function + dispatch: + CPU: _cholesky_solve_helper_cpu + CUDA: _cholesky_solve_helper_cuda + autogen: _cholesky_solve_helper.out + +- func: cholesky_inverse(Tensor self, bool upper=False) -> Tensor + variants: method, function + dispatch: + CPU, CUDA: cholesky_inverse + +- func: cholesky_inverse.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: cholesky_inverse_out + +- func: qr.Q(Tensor self, bool some=True, *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R) + +- func: qr(Tensor self, bool some=True) -> (Tensor Q, Tensor R) + variants: method, function + +- func: geqrf.a(Tensor self, *, Tensor(a!) a, Tensor(b!) tau) -> (Tensor(a!) a, Tensor(b!) tau) + dispatch: + CPU, CUDA: geqrf_out + +- func: geqrf(Tensor self) -> (Tensor a, Tensor tau) + variants: method, function + dispatch: + CPU, CUDA: geqrf + +# orgqr, alias for linalg_householder_product +- func: orgqr(Tensor self, Tensor input2) -> Tensor + variants: method, function + +- func: orgqr.out(Tensor self, Tensor input2, *, Tensor(a!) out) -> Tensor(a!) + +- func: ormqr.out(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: ormqr_out + +- func: ormqr(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False) -> Tensor + variants: method, function + dispatch: + CPU, CUDA: ormqr + +- func: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor LU, Tensor pivots, Tensor info) + variants: function + +- func: lu_solve.out(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!) 
+ +- func: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor + variants: method, function + +# lu_unpack +- func: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U) + structured_delegate: lu_unpack.out + variants: function + +- func: lu_unpack.out(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True, *, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) + variants: function + structured: True + dispatch: + CPU, CUDA: lu_unpack_out + +# TODO: remove dispatch section when porting TH CUDA to ATen +- func: multinomial.out(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + tags: nondeterministic_seeded + dispatch: + CPU, CUDA: multinomial_out + MPS: multinomial_out_mps + +- func: multinomial(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor + variants: method, function + dispatch: + CPU, CUDA: multinomial + MPS: multinomial_mps + tags: nondeterministic_seeded + +- func: lgamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: lgamma_out + MPS: lgamma_out_mps + tags: pointwise + +- func: lgamma_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: lgamma.out + variants: method + tags: pointwise + +- func: lgamma(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: lgamma.out + variants: method, function + tags: pointwise + +- func: digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: digamma_out + MPS: digamma_out_mps + tags: pointwise + +- func: digamma(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: digamma.out + variants: method, function + tags: pointwise + +- func: polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: polygamma_out + MPS: polygamma_out_mps + tags: pointwise + +- func: polygamma(int n, Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: polygamma.out + variants: method, function + tags: pointwise + +- func: polygamma_(Tensor(a!) self, int n) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CompositeExplicitAutograd: polygamma_ + tags: pointwise + +- func: erfinv(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: erfinv.out + variants: method, function + dispatch: + SparseCPU, SparseCUDA: erfinv_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr + tags: pointwise + +- func: erfinv_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: erfinv.out + variants: method + dispatch: + SparseCPU, SparseCUDA: erfinv_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_ + tags: pointwise + +- func: erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: erfinv_out + MPS: erfinv_out_mps + SparseCPU, SparseCUDA: erfinv_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_out + tags: pointwise + +- func: i0(Tensor self) -> Tensor + structured_delegate: i0.out + variants: function, method + tags: pointwise + +- func: i0_(Tensor(a!) self) -> Tensor(a!) + structured_delegate: i0.out + variants: function, method + tags: pointwise + +- func: i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: i0_out + tags: pointwise + +- func: sign(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: sign.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: sign_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sign_sparse_csr + tags: [core, pointwise] + +- func: sign_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: sign.out + variants: method + dispatch: + SparseCPU, SparseCUDA: sign_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sign_sparse_csr_ + tags: pointwise + +- func: sign.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: sign_out + MPS: sign_out_mps + SparseCPU, SparseCUDA: sign_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sign_sparse_csr_out + tags: pointwise + +- func: signbit(Tensor self) -> Tensor + variants: function, method + structured_delegate: signbit.out + dispatch: + SparseCPU, SparseCUDA: signbit_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: signbit_sparse_csr + tags: pointwise + +- func: signbit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU: signbit_out + CUDA: signbit_out + MPS: signbit_out_mps + SparseCPU, SparseCUDA: signbit_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: signbit_sparse_csr_out + tags: pointwise + +- func: dist(Tensor self, Tensor other, Scalar p=2) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + CompositeExplicitAutograd: dist + autogen: dist.out + +- func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: atan2_out + MPS: atan2_out_mps + tags: [core, pointwise] + +- func: atan2_(Tensor(a!) self, Tensor other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: atan2.out + variants: method + tags: pointwise + +- func: atan2(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: atan2.out + variants: method, function + tags: [core, pointwise] +# arctan2, alias of atan2 + +- func: arctan2(Tensor self, Tensor other) -> Tensor + variants: method, function + +- func: arctan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + +- func: arctan2_(Tensor(a!) self, Tensor other) -> Tensor(a!) + variants: method + +- func: lerp.Scalar_out(Tensor self, Tensor end, Scalar weight, *, Tensor(a!) out) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: lerp_Scalar + MPS: lerp_Scalar_mps + tags: pointwise + +- func: lerp.Tensor_out(Tensor self, Tensor end, Tensor weight, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: lerp_Tensor + MPS: lerp_Tensor_mps + tags: pointwise + +- func: lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + structured_delegate: lerp.Scalar_out + tags: pointwise + +- func: lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + structured_delegate: lerp.Tensor_out + tags: pointwise + +- func: histc.out(Tensor self, int bins=100, Scalar min=0, Scalar max=0, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, MPS: histogram_histc_out + CUDA: _histc_out_cuda + +- func: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor + variants: method, function + dispatch: + CPU, MPS: histogram_histc + CUDA: _histc_cuda + +- func: histogram.bins_tensor_out(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges) + dispatch: + CPU, MPS: histogram_out + +- func: histogram.bins_tensor(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges) + variants: method, function + dispatch: + CPU, MPS: histogram + +- func: histogram.bin_ct_out(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges) + dispatch: + CPU, MPS: histogram_out + +- func: histogram.bin_ct(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges) + variants: method, function + dispatch: + CPU, MPS: histogram + +- func: _histogramdd_bin_edges(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor[] + dispatch: + CPU, MPS: histogramdd_bin_edges + autogen: _histogramdd_bin_edges.out + +- func: _histogramdd_from_bin_cts(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor + dispatch: + CPU, MPS: _histogramdd + autogen: _histogramdd_from_bin_cts.out + +- func: _histogramdd_from_bin_tensors(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False) -> Tensor + dispatch: + CPU, MPS: _histogramdd + autogen: _histogramdd_from_bin_tensors.out + +- func: histogramdd(Tensor self, int[] bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges) + +- func: histogramdd.int_bins(Tensor self, int bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges) + +- func: histogramdd.TensorList_bins(Tensor self, Tensor[] bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges) + +- func: fmod.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + dispatch: + CompositeExplicitAutograd: fmod_out + tags: pointwise + +- func: fmod.Scalar(Tensor self, Scalar other) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + CompositeExplicitAutograd: fmod + tags: [core, pointwise] + +- func: fmod_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + dispatch: + CompositeExplicitAutograd: fmod_ + tags: pointwise + +- func: fmod.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: fmod_out + MPS: fmod_mps_out + tags: pointwise + +- func: fmod.Tensor(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: fmod.Tensor_out + variants: method, function + tags: [core, pointwise] + +- func: fmod_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: method + structured_delegate: fmod.Tensor_out + tags: pointwise + +- func: hypot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: hypot_out + MPS: hypot_out_mps + tags: pointwise + +- func: hypot(Tensor self, Tensor other) -> Tensor + structured_delegate: hypot.out + variants: method, function + tags: pointwise + +- func: hypot_(Tensor(a!) self, Tensor other) -> Tensor(a!) + structured_delegate: hypot.out + variants: method + tags: pointwise + +- func: igamma.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: igamma_out + tags: pointwise + +- func: igamma(Tensor self, Tensor other) -> Tensor + structured_delegate: igamma.out + variants: method, function + tags: pointwise + +- func: igamma_(Tensor(a!) self, Tensor other) -> Tensor(a!) + structured_delegate: igamma.out + variants: method + tags: pointwise + +- func: igammac.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: igammac_out + tags: pointwise + +- func: igammac(Tensor self, Tensor other) -> Tensor + structured_delegate: igammac.out + variants: method, function + tags: pointwise + +- func: igammac_(Tensor(a!) self, Tensor other) -> Tensor(a!) + structured_delegate: igammac.out + variants: method + tags: pointwise + +- func: nextafter.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA, MPS: nextafter_out + tags: pointwise + +- func: nextafter(Tensor self, Tensor other) -> Tensor + structured_delegate: nextafter.out + variants: method, function + tags: pointwise + +- func: nextafter_(Tensor(a!) self, Tensor other) -> Tensor(a!) + structured_delegate: nextafter.out + variants: method + tags: pointwise + +- func: remainder.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: remainder_out + tags: pointwise + +- func: remainder.Scalar(Tensor self, Scalar other) -> Tensor + variants: method, function + dispatch: + CompositeExplicitAutograd: remainder + tags: [core, pointwise] + +- func: remainder_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 
+ variants: method + dispatch: + CompositeExplicitAutograd: remainder_ + tags: pointwise + +- func: remainder.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: remainder_out + MPS: remainder_out_mps + tags: pointwise + +- func: remainder.Tensor(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: remainder.Tensor_out + variants: method, function + tags: [core, pointwise] + +- func: remainder_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: remainder.Tensor_out + variants: method + tags: pointwise + +- func: remainder.Scalar_Tensor(Scalar self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: function + dispatch: + CPU, CUDA, MPS: remainder + autogen: remainder.Scalar_Tensor_out + tags: pointwise + +- func: min(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + CPU, CUDA: min + MPS: min_mps + QuantizedCPU: min_quantized_cpu + +- func: min.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: min_unary_out + QuantizedCPU: min_quantized_unary_out + +- func: fmin(Tensor self, Tensor other) -> Tensor + structured_delegate: fmin.out + device_check: NoCheck # TensorIterator + variants: method, function + tags: pointwise + +- func: fmin.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA, MPS: fmin_out + tags: pointwise + +- func: max(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + CPU, CUDA: max + MPS: max_mps + QuantizedCPU: max_quantized_cpu + +- func: fmax(Tensor self, Tensor other) -> Tensor + structured_delegate: fmax.out + device_check: NoCheck # TensorIterator + variants: method, function + tags: pointwise + +- func: fmax.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA, MPS: fmax_out + tags: pointwise + +- func: maximum(Tensor self, Tensor other) -> Tensor + structured_delegate: maximum.out + device_check: NoCheck # TensorIterator + variants: method, function + tags: [core, pointwise] + +- func: maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: maximum_out + MPS: maximum_out_mps + tags: pointwise + +# binary max, alias of maximum +# NOTE: max is not an alias for maximum, since there is also unary max +- func: max.other(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + tags: pointwise + +- func: max.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + tags: pointwise + +- func: max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: max_unary_out + QuantizedCPU: max_quantized_unary_out + +- func: minimum(Tensor self, Tensor other) -> Tensor + structured_delegate: minimum.out + device_check: NoCheck # TensorIterator + variants: method, function + tags: [core, pointwise] + +- func: minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + dispatch: + CPU, CUDA: minimum_out + MPS: minimum_out_mps + tags: pointwise + +# binary min, alias for minimum +# NOTE: min is not an alias for minimum, since there is also unary min +- func: min.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + tags: pointwise + +- func: min.other(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + tags: pointwise + +- func: quantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor + variants: method, function + +- func: quantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!) + +- func: quantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor + variants: method, function + +- func: quantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!) + +- func: nanquantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor + variants: method, function + +- func: nanquantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!) + +- func: nanquantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor + variants: method, function + +- func: nanquantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!) + +- func: sort.values(Tensor self, int dim=-1, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + device_check: NoCheck # TensorIterator + dispatch: + CompositeExplicitAutograd: sort_out + +- func: sort.values_stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + structured: True + dispatch: + CPU, CUDA: sort_stable_out + MPS: sort_stable_out_mps + +- func: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) + device_check: NoCheck # TensorIterator + variants: method, function + dispatch: + CompositeExplicitAutograd: sort + tags: core + +- func: sort.stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) + structured_delegate: sort.values_stable + variants: method, function + dispatch: + QuantizedCPU: sort_quantized_cpu_stable + +- func: sort.dimname_values(Tensor self, Dimname dim, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + +- func: sort.dimname_values_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) 
indices) + +- func: sort.dimname(Tensor self, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices) + variants: method, function + +- func: sort.dimname_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices) + variants: method, function + +- func: msort.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + +- func: msort(Tensor self) -> Tensor + variants: method, function + +- func: argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + +- func: argsort.stable(Tensor self, *, bool stable, int dim=-1, bool descending=False) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + +- func: argsort.stable_out(Tensor self, *, bool stable, int dim=-1, bool descending=False, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + variants: function + +- func: argsort.dimname(Tensor self, Dimname dim, bool descending=False) -> Tensor + variants: method, function + +- func: topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + structured: True + dispatch: + CPU: topk_out_cpu + CUDA: topk_out_cuda + MPS: topk_out_mps + +- func: topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) + variants: method, function + structured_delegate: topk.values + dispatch: + QuantizedCPU: topk_quantized_cpu + tags: core + +- func: all(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: all.all_out + variants: method, function + +- func: all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck + structured: True + dispatch: + CPU, CUDA: all_all_out + MPS: all_all_out_mps + +- func: any(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: any.all_out + variants: method, function + dispatch: + SparseCPU, SparseCUDA: any_sparse + tags: core + +- func: any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck + structured: True + dispatch: + CPU, CUDA: any_all_out + MPS: any_all_out_mps + +- func: renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + dispatch: + CPU, CUDA: renorm_out + MPS: renorm_out_mps + +- func: renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor + device_check: NoCheck # TensorIterator + variants: method, function + structured_delegate: renorm.out + +- func: renorm_(Tensor(a!) self, Scalar p, int dim, Scalar maxnorm) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + variants: method + structured_delegate: renorm.out + +- func: unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a) + variants: method + device_check: NoCheck + device_guard: False + dispatch: + CPU, CUDA, Meta, MPS: unfold + QuantizedCPU, QuantizedCUDA: unfold + +- func: unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor + variants: function + dispatch: + CPU, CUDA: unfold_backward + autogen: unfold_backward.out + +- func: equal(Tensor self, Tensor other) -> bool + tags: [data_dependent_output, pointwise] + variants: method, function + dispatch: + CPU: cpu_equal + CUDA: cuda_equal + MPS: mps_equal + QuantizedCPU: equal_quantized_cpu + +- func: pow.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: pow_Tensor_Tensor_out + MPS: pow_tensor_tensor_out_mps + tags: pointwise + +- func: pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: pow.Tensor_Tensor_out + variants: method, function + tags: [core, pointwise] + +- func: pow.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + dispatch: + CPU, CUDA: pow_Scalar_out + MPS: pow_Scalar_out_mps + tags: pointwise + +- func: pow.Scalar(Scalar self, Tensor exponent) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: pow.Scalar_out + tags: [core, pointwise] + +- func: pow.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: pow_Tensor_Scalar_out + SparseCPU, SparseCUDA: pow_out_sparse_scalar + MPS: pow_tensor_scalar_out_mps + tags: pointwise + +- func: pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: pow.Tensor_Scalar_out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: pow_sparse_scalar + tags: [core, pointwise] + +- func: pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: pow.Tensor_Scalar_out + variants: method + tags: pointwise + +- func: pow_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured_delegate: pow.Tensor_Tensor_out + variants: method + tags: pointwise + +- func: float_power.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + tags: pointwise + +- func: float_power.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor + variants: function, method + tags: pointwise + +- func: float_power.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + tags: pointwise + +- func: float_power.Scalar(Scalar self, Tensor exponent) -> Tensor + tags: pointwise + +- func: float_power.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!) + tags: pointwise + +- func: float_power.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor + variants: function, method + tags: pointwise + +- func: float_power_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!) + variants: method + tags: pointwise + +- func: float_power_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!) 
+ variants: method + tags: pointwise + +- func: normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!) + device_check: NoCheck # TensorIterator + tags: nondeterministic_seeded + variants: method + dispatch: + CPU, CUDA: normal_ + MPS: normal_mps_ + Meta: normal_meta_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: normal_sparse_csr_ + NestedTensorCPU, NestedTensorCUDA: normal_nested_ + autogen: normal.out + +# Only used by the functionalization pass. +# Normally, the codegen would be able to generate a normal() NativeFunction, +# but we can't due to overload ambiguity with normal.Tensor_float. +- func: normal_functional(Tensor self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor + device_check: NoCheck # TensorIterator + tags: nondeterministic_seeded + dispatch: + CompositeExplicitAutograd: normal_functional + +- func: normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + tags: nondeterministic_seeded + dispatch: + CPU, CUDA: normal_out + MPS: normal_mps_out + Meta: normal_out_meta + +- func: normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor + dispatch: + CPU, CUDA: normal + MPS: normal_mps + Meta: normal_meta + tags: nondeterministic_seeded + +- func: normal.float_Tensor_out(float mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: normal_out + Meta: normal_out_meta + MPS: normal_mps_out + tags: nondeterministic_seeded + +- func: normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor + dispatch: + CPU, CUDA: normal + MPS: normal_mps + Meta: normal_meta + tags: nondeterministic_seeded + +- func: normal.Tensor_Tensor_out(Tensor mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: normal_out + Meta: normal_out_meta + MPS: normal_mps_out + tags: nondeterministic_seeded + +- func: normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor + dispatch: + CPU, CUDA: normal + MPS: normal_mps + Meta: normal_meta + tags: nondeterministic_seeded + +- func: normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: normal + tags: nondeterministic_seeded + +- func: normal.float_float_out(float mean, float std, SymInt[] size, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: normal_out + tags: nondeterministic_seeded + +- func: alias(Tensor(a) self) -> Tensor(a) + variants: method, function + dispatch: + CompositeExplicitAutograd: alias + NestedTensorCPU, NestedTensorCUDA: alias_nested + tags: core + +- func: _amp_foreach_non_finite_check_and_unscale_(Tensor(a!)[] self, Tensor(b!) found_inf, Tensor inv_scale) -> () + variants: function + dispatch: + CUDA: _amp_foreach_non_finite_check_and_unscale_cuda_ + CPU: _amp_foreach_non_finite_check_and_unscale_cpu_ + autogen: _amp_foreach_non_finite_check_and_unscale, _amp_foreach_non_finite_check_and_unscale.out + +- func: _amp_update_scale_(Tensor(a!) self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> Tensor(a!) 
+ variants: function + dispatch: + CUDA: _amp_update_scale_cuda_ + CPU: _amp_update_scale_cpu_ + autogen: _amp_update_scale, _amp_update_scale.out + + #- func: _cat(Tensor[] tensors, int dim=0) -> Tensor + #dispatch: + #CPU: _cat_cpu + #CUDA: cat_cuda + #MPS: cat_mps + #QuantizedCPU: cat_quantized_cpu + + #- func: _cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + #dispatch: + #CPU: _cat_out_cpu + #CUDA: cat_out_cuda + #QuantizedCPU: cat_out_quantized_cpu + +- func: _foreach_add.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_add_scalar_kernel_slow + CUDA: foreach_tensor_add_scalar_kernel_cuda + +- func: _foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_add_scalar_kernel_slow_ + CUDA: foreach_tensor_add_scalar_kernel_cuda_ + autogen: _foreach_add.Scalar_out + +- func: _foreach_add.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow + CUDA: foreach_tensor_add_list_kernel_cuda + +- func: _foreach_add_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow_ + CUDA: foreach_tensor_add_list_kernel_cuda_ + autogen: _foreach_add.List_out + +- func: _foreach_add.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_add_scalarlist_kernel_slow + CUDA: foreach_tensor_add_scalarlist_kernel_cuda + +- func: _foreach_add_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_add_scalarlist_kernel_slow_ + CUDA: foreach_tensor_add_scalarlist_kernel_cuda_ + autogen: _foreach_add.ScalarList_out + +- func: _foreach_add.Tensor(Tensor[] self, Tensor other, *, Scalar alpha=1) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_add_tensor_kernel_slow + CUDA: foreach_tensor_add_tensor_kernel_cuda + +- func: _foreach_add_.Tensor(Tensor(a!)[] self, Tensor other, *, Scalar alpha=1) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_add_tensor_kernel_slow_ + CUDA: foreach_tensor_add_tensor_kernel_cuda_ + autogen: _foreach_add.Tensor_out + +- func: _foreach_sub.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: 
foreach_tensor_sub_scalar_kernel_slow + CUDA: foreach_tensor_sub_scalar_kernel_cuda + +- func: _foreach_sub_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_sub_scalar_kernel_slow_ + CUDA: foreach_tensor_sub_scalar_kernel_cuda_ + autogen: _foreach_sub.Scalar_out + +- func: _foreach_sub.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_sub_list_kernel_slow + CUDA: foreach_tensor_sub_list_kernel_cuda + +- func: _foreach_sub_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_sub_list_kernel_slow_ + CUDA: foreach_tensor_sub_list_kernel_cuda_ + autogen: _foreach_sub.List_out + +- func: _foreach_sub.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_sub_scalarlist_kernel_slow + CUDA: foreach_tensor_sub_scalarlist_kernel_cuda + +- func: _foreach_sub_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_sub_scalarlist_kernel_slow_ + CUDA: foreach_tensor_sub_scalarlist_kernel_cuda_ + autogen: _foreach_sub.ScalarList_out + +- func: _foreach_mul.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_mul_scalar_kernel_slow + CUDA: foreach_tensor_mul_scalar_kernel_cuda + +- func: _foreach_mul_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_mul_scalar_kernel_slow_ + CUDA: foreach_tensor_mul_scalar_kernel_cuda_ + autogen: _foreach_mul.Scalar_out + +- func: _foreach_mul.List(Tensor[] self, Tensor[] other) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow + CUDA: foreach_tensor_mul_list_kernel_cuda + +- func: _foreach_mul_.List(Tensor(a!)[] self, Tensor[] other) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow_ + CUDA: foreach_tensor_mul_list_kernel_cuda_ + autogen: _foreach_mul.List_out + +- func: _foreach_mul.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_mul_scalarlist_kernel_slow + CUDA: foreach_tensor_mul_scalarlist_kernel_cuda + +- func: 
_foreach_mul_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_mul_scalarlist_kernel_slow_ + CUDA: foreach_tensor_mul_scalarlist_kernel_cuda_ + autogen: _foreach_mul.ScalarList_out + +- func: _foreach_mul.Tensor(Tensor[] self, Tensor other) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow + CUDA: foreach_tensor_mul_tensor_kernel_cuda + +- func: _foreach_mul_.Tensor(Tensor(a!)[] self, Tensor other) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow_ + CUDA: foreach_tensor_mul_tensor_kernel_cuda_ + autogen: _foreach_mul.Tensor_out + +- func: _foreach_div.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_div_scalar_kernel_slow + CUDA: foreach_tensor_div_scalar_kernel_cuda + +- func: _foreach_div_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_div_scalar_kernel_slow_ + CUDA: foreach_tensor_div_scalar_kernel_cuda_ + autogen: _foreach_div.Scalar_out + +- func: _foreach_div.List(Tensor[] self, Tensor[] other) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_div_list_kernel_slow + CUDA: foreach_tensor_div_list_kernel_cuda + +- func: _foreach_div_.List(Tensor(a!)[] self, Tensor[] other) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_div_list_kernel_slow_ + CUDA: foreach_tensor_div_list_kernel_cuda_ + autogen: _foreach_div.List_out + +- func: _foreach_div.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_div_scalarlist_kernel_slow + CUDA: foreach_tensor_div_scalarlist_kernel_cuda + +- func: _foreach_div_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_div_scalarlist_kernel_slow_ + CUDA: foreach_tensor_div_scalarlist_kernel_cuda_ + autogen: _foreach_div.ScalarList_out + +- func: _foreach_div.Tensor(Tensor[] self, Tensor other) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_div_tensor_kernel_slow + CUDA: foreach_tensor_div_tensor_kernel_cuda + +- func: _foreach_div_.Tensor(Tensor(a!)[] self, Tensor other) -> () + device_check: NoCheck # foreach kernels fall back to slow path 
when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_div_tensor_kernel_slow_ + CUDA: foreach_tensor_div_tensor_kernel_cuda_ + autogen: _foreach_div.Tensor_out + +- func: _foreach_clamp_max.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalar_kernel_slow + CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda + +- func: _foreach_clamp_max_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalar_kernel_slow_ + CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda_ + autogen: _foreach_clamp_max.Scalar_out + +- func: _foreach_clamp_max.List(Tensor[] self, Tensor[] other) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_max_list_kernel_slow + CUDA: foreach_tensor_clamp_max_list_kernel_cuda + +- func: _foreach_clamp_max_.List(Tensor(a!)[] self, Tensor[] other) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_max_list_kernel_slow_ + CUDA: foreach_tensor_clamp_max_list_kernel_cuda_ + autogen: _foreach_clamp_max.List_out + +- func: _foreach_clamp_max.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalarlist_kernel_slow + CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda + +- func: _foreach_clamp_max_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalarlist_kernel_slow_ + CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda_ + autogen: _foreach_clamp_max.ScalarList_out + +- func: _foreach_clamp_min.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalar_kernel_slow + CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda + +- func: _foreach_clamp_min_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalar_kernel_slow_ + CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda_ + autogen: _foreach_clamp_min.Scalar_out + +- func: _foreach_clamp_min.List(Tensor[] self, Tensor[] other) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_min_list_kernel_slow + CUDA: foreach_tensor_clamp_min_list_kernel_cuda + +- func: _foreach_clamp_min_.List(Tensor(a!)[] self, Tensor[] other) -> () + device_check: 
NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_min_list_kernel_slow_ + CUDA: foreach_tensor_clamp_min_list_kernel_cuda_ + autogen: _foreach_clamp_min.List_out + +- func: _foreach_clamp_min.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalarlist_kernel_slow + CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda + +- func: _foreach_clamp_min_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalarlist_kernel_slow_ + CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda_ + autogen: _foreach_clamp_min.ScalarList_out + +# foreach_minimum/maximum dispatches to clamp_max/min +- func: _foreach_maximum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalar_kernel_slow + CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda + +- func: _foreach_maximum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalar_kernel_slow_ + CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda_ + autogen: _foreach_maximum.Scalar_out + +# foreach_minimum/maximum dispatches to clamp_max/min +- func: _foreach_maximum.List(Tensor[] self, Tensor[] other) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_min_list_kernel_slow + CUDA: foreach_tensor_clamp_min_list_kernel_cuda + +- func: _foreach_maximum_.List(Tensor(a!)[] self, Tensor[] other) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_min_list_kernel_slow_ + CUDA: foreach_tensor_clamp_min_list_kernel_cuda_ + autogen: _foreach_maximum.List_out + +# foreach_minimum/maximum dispatches to clamp_max/min +- func: _foreach_maximum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalarlist_kernel_slow + CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda + +- func: _foreach_maximum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalarlist_kernel_slow_ + CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda_ + autogen: _foreach_maximum.ScalarList_out + +- func: _foreach_minimum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: 
function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalar_kernel_slow + CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda + +- func: _foreach_minimum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalar_kernel_slow_ + CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda_ + autogen: _foreach_minimum.Scalar_out + +- func: _foreach_minimum.List(Tensor[] self, Tensor[] other) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_max_list_kernel_slow + CUDA: foreach_tensor_clamp_max_list_kernel_cuda + +- func: _foreach_minimum_.List(Tensor(a!)[] self, Tensor[] other) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_max_list_kernel_slow_ + CUDA: foreach_tensor_clamp_max_list_kernel_cuda_ + autogen: _foreach_minimum.List_out + +- func: _foreach_minimum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalarlist_kernel_slow + CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda + +- func: _foreach_minimum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalarlist_kernel_slow_ + CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda_ + autogen: _foreach_minimum.ScalarList_out + +- func: _foreach_addcdiv.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_addcdiv_scalar_slow + CUDA: foreach_tensor_addcdiv_scalar_cuda + +- func: _foreach_addcdiv.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_addcdiv_scalarlist_slow + CUDA: foreach_tensor_addcdiv_scalarlist_cuda + +- func: _foreach_addcdiv.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_addcdiv_tensor_slow + CUDA: foreach_tensor_addcdiv_tensor_cuda + +- func: _foreach_addcdiv_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_addcdiv_scalar_slow_ + CUDA: foreach_tensor_addcdiv_scalar_cuda_ + autogen: _foreach_addcdiv.Scalar_out + +- func: _foreach_addcdiv_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] 
scalars) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_addcdiv_scalarlist_slow_ + CUDA: foreach_tensor_addcdiv_scalarlist_cuda_ + autogen: _foreach_addcdiv.ScalarList_out + +- func: _foreach_addcdiv_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_addcdiv_tensor_slow_ + CUDA: foreach_tensor_addcdiv_tensor_cuda_ + autogen: _foreach_addcdiv.Tensor_out + +- func: _foreach_addcmul.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow + CUDA: foreach_tensor_addcmul_scalar_cuda + +- func: _foreach_addcmul.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_addcmul_scalarlist_slow + CUDA: foreach_tensor_addcmul_scalarlist_cuda + +- func: _foreach_addcmul.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_addcmul_tensor_slow + CUDA: foreach_tensor_addcmul_tensor_cuda + +- func: _foreach_addcmul_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow_ + CUDA: foreach_tensor_addcmul_scalar_cuda_ + autogen: _foreach_addcmul.Scalar_out + +- func: _foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_addcmul_scalarlist_slow_ + CUDA: foreach_tensor_addcmul_scalarlist_cuda_ + autogen: _foreach_addcmul.ScalarList_out + +- func: _foreach_addcmul_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_addcmul_tensor_slow_ + CUDA: foreach_tensor_addcmul_tensor_cuda_ + autogen: _foreach_addcmul.Tensor_out + +- func: _foreach_abs(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_abs_slow + CUDA: foreach_tensor_abs_cuda + +- func: _foreach_abs_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_abs_slow_ + CUDA: foreach_tensor_abs_cuda_ + autogen: 
_foreach_abs.out + +- func: _foreach_acos(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_acos_slow + CUDA: foreach_tensor_acos_cuda + +- func: _foreach_acos_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_acos_slow_ + CUDA: foreach_tensor_acos_cuda_ + autogen: _foreach_acos.out + +- func: _foreach_asin(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_asin_slow + CUDA: foreach_tensor_asin_cuda + +- func: _foreach_asin_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_asin_slow_ + CUDA: foreach_tensor_asin_cuda_ + autogen: _foreach_asin.out + +- func: _foreach_atan(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_atan_slow + CUDA: foreach_tensor_atan_cuda + +- func: _foreach_atan_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_atan_slow_ + CUDA: foreach_tensor_atan_cuda_ + autogen: _foreach_atan.out + +- func: _foreach_ceil(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_ceil_slow + CUDA: foreach_tensor_ceil_cuda + +- func: _foreach_ceil_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_ceil_slow_ + CUDA: foreach_tensor_ceil_cuda_ + autogen: _foreach_ceil.out + +- func: _foreach_cos(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_cos_slow + CUDA: foreach_tensor_cos_cuda + +- func: _foreach_cos_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_cos_slow_ + CUDA: foreach_tensor_cos_cuda_ + autogen: _foreach_cos.out + +- func: _foreach_cosh(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_cosh_slow + CUDA: foreach_tensor_cosh_cuda + +- func: _foreach_cosh_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_cosh_slow_ + CUDA: foreach_tensor_cosh_cuda_ + autogen: _foreach_cosh.out + +- func: _foreach_erf(Tensor[] self) -> Tensor[] + device_check: NoCheck # 
foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_erf_slow + CUDA: foreach_tensor_erf_cuda + +- func: _foreach_erf_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_erf_slow_ + CUDA: foreach_tensor_erf_cuda_ + autogen: _foreach_erf.out + +- func: _foreach_erfc(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_erfc_slow + CUDA: foreach_tensor_erfc_cuda + +- func: _foreach_erfc_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_erfc_slow_ + CUDA: foreach_tensor_erfc_cuda_ + autogen: _foreach_erfc.out + +- func: _foreach_exp(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_exp_slow + CUDA: foreach_tensor_exp_cuda + +- func: _foreach_exp_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_exp_slow_ + CUDA: foreach_tensor_exp_cuda_ + autogen: _foreach_exp.out + +- func: _foreach_expm1(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_expm1_slow + CUDA: foreach_tensor_expm1_cuda + +- func: _foreach_expm1_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_expm1_slow_ + CUDA: foreach_tensor_expm1_cuda_ + autogen: _foreach_expm1.out + +- func: _foreach_floor(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_floor_slow + CUDA: foreach_tensor_floor_cuda + +- func: _foreach_floor_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_floor_slow_ + CUDA: foreach_tensor_floor_cuda_ + autogen: _foreach_floor.out + +- func: _foreach_frac(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_frac_slow + CUDA: foreach_tensor_frac_cuda + +- func: _foreach_frac_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_frac_slow_ + CUDA: foreach_tensor_frac_cuda_ + autogen: _foreach_frac.out + +- func: _foreach_lerp.List(Tensor[] self, Tensor[] tensors1, Tensor[] weights) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when 
tensors are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_ternary_lerp_slow + CUDA: foreach_tensor_lerp_ternary_cuda + autogen: _foreach_lerp.List_out + +- func: _foreach_lerp_.List(Tensor(a!)[] self, Tensor[] tensors1, Tensor[] weights) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_ternary_lerp_slow_ + CUDA: foreach_tensor_lerp_ternary_cuda_ + autogen: _foreach_lerp.List_out + +- func: _foreach_lerp.Scalar(Tensor[] self, Tensor[] tensors1, Scalar weight) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_lerp_list_kernel_slow + CUDA: foreach_tensor_lerp_list_cuda + autogen: _foreach_lerp.Scalar_out + +- func: _foreach_lerp_.Scalar(Tensor(a!)[] self, Tensor[] tensors1, Scalar weight) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_lerp_list_kernel_slow_ + CUDA: foreach_tensor_lerp_list_cuda_ + autogen: _foreach_lerp.Scalar_out + +- func: _foreach_lgamma(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_lgamma_slow + CUDA: foreach_tensor_lgamma_cuda + +- func: _foreach_lgamma_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_lgamma_slow_ + CUDA: foreach_tensor_lgamma_cuda_ + autogen: _foreach_lgamma.out + +- func: _foreach_log(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_log_slow + CUDA: foreach_tensor_log_cuda + +- func: _foreach_log_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_log_slow_ + CUDA: foreach_tensor_log_cuda_ + autogen: _foreach_log.out + +- func: _foreach_log10(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_log10_slow + CUDA: foreach_tensor_log10_cuda + +- func: _foreach_log10_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_log10_slow_ + CUDA: foreach_tensor_log10_cuda_ + autogen: _foreach_log10.out + +- func: _foreach_log1p(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_log1p_slow + CUDA: foreach_tensor_log1p_cuda + +- func: _foreach_log1p_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: 
foreach_tensor_log1p_slow_ + CUDA: foreach_tensor_log1p_cuda_ + autogen: _foreach_log1p.out + +- func: _foreach_log2(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_log2_slow + CUDA: foreach_tensor_log2_cuda + +- func: _foreach_log2_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_log2_slow_ + CUDA: foreach_tensor_log2_cuda_ + autogen: _foreach_log2.out + +- func: _foreach_max(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_max_slow + CUDA: foreach_tensor_max_cuda + autogen: _foreach_max.out + +- func: _foreach_neg(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_neg_slow + CUDA: foreach_tensor_neg_cuda + +- func: _foreach_neg_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_neg_slow_ + CUDA: foreach_tensor_neg_cuda_ + autogen: _foreach_neg.out + +- func: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2, ScalarType? dtype=None) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_norm_slow + CUDA: foreach_tensor_norm_cuda + autogen: _foreach_norm.Scalar_out + +- func: _foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_pow_list_kernel_slow + CUDA: foreach_tensor_pow_list_kernel_cuda + +- func: _foreach_pow.Scalar(Tensor[] self, Scalar exponent) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_pow_scalar_kernel_slow + CUDA: foreach_tensor_pow_scalar_kernel_cuda + +- func: _foreach_pow.ScalarList(Tensor[] self, Scalar[] exponent) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_pow_scalarlist_kernel_slow + CUDA: foreach_tensor_pow_scalarlist_kernel_cuda + +- func: _foreach_pow.ScalarAndTensor(Scalar self, Tensor[] exponent) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_scalar_pow_list_kernel_slow + CUDA: foreach_scalar_pow_list_kernel_cuda + +- func: _foreach_pow_.List(Tensor(a!)[] self, Tensor[] exponent) -> () + device_check: NoCheck + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_pow_list_kernel_slow_ + CUDA: foreach_tensor_pow_list_kernel_cuda_ + autogen: _foreach_pow.List_out + +- func: _foreach_pow_.Scalar(Tensor(a!)[] self, Scalar exponent) -> () + 
device_check: NoCheck + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_pow_scalar_kernel_slow_ + CUDA: foreach_tensor_pow_scalar_kernel_cuda_ + autogen: _foreach_pow.Scalar_out + +- func: _foreach_pow_.ScalarList(Tensor(a!)[] self, Scalar[] exponent) -> () + device_check: NoCheck + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_pow_scalarlist_kernel_slow_ + CUDA: foreach_tensor_pow_scalarlist_kernel_cuda_ + autogen: _foreach_pow.ScalarList_out + +- func: _foreach_reciprocal(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_reciprocal_slow + CUDA: foreach_tensor_reciprocal_cuda + +- func: _foreach_reciprocal_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_reciprocal_slow_ + CUDA: foreach_tensor_reciprocal_cuda_ + autogen: _foreach_reciprocal.out + +- func: _foreach_round(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_round_slow + CUDA: foreach_tensor_round_cuda + +- func: _foreach_round_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_round_slow_ + CUDA: foreach_tensor_round_cuda_ + autogen: _foreach_round.out + +- func: _foreach_sigmoid(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_sigmoid_slow + CUDA: foreach_tensor_sigmoid_cuda + +- func: _foreach_sigmoid_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_sigmoid_slow_ + CUDA: foreach_tensor_sigmoid_cuda_ + autogen: _foreach_sigmoid.out + +- func: _foreach_sign(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_sign_slow + CUDA: foreach_tensor_sign_cuda + +- func: _foreach_sign_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_sign_slow_ + CUDA: foreach_tensor_sign_cuda_ + autogen: _foreach_sign.out + +- func: _foreach_sin(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_sin_slow + CUDA: foreach_tensor_sin_cuda + +- func: _foreach_sin_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_sin_slow_ + CUDA: foreach_tensor_sin_cuda_ + autogen: _foreach_sin.out + +- func: _foreach_sinh(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels 
fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_sinh_slow + CUDA: foreach_tensor_sinh_cuda + +- func: _foreach_sinh_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_sinh_slow_ + CUDA: foreach_tensor_sinh_cuda_ + autogen: _foreach_sinh.out + +- func: _foreach_sqrt(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_sqrt_slow + CUDA: foreach_tensor_sqrt_cuda + +- func: _foreach_sqrt_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_sqrt_slow_ + CUDA: foreach_tensor_sqrt_cuda_ + autogen: _foreach_sqrt.out + +- func: _foreach_tan(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_tan_slow + CUDA: foreach_tensor_tan_cuda + +- func: _foreach_tan_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_tan_slow_ + CUDA: foreach_tensor_tan_cuda_ + autogen: _foreach_tan.out + +- func: _foreach_tanh(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_tanh_slow + CUDA: foreach_tensor_tanh_cuda + +- func: _foreach_tanh_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_tanh_slow_ + CUDA: foreach_tensor_tanh_cuda_ + autogen: _foreach_tanh.out + +- func: _foreach_trunc(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_trunc_slow + CUDA: foreach_tensor_trunc_cuda + +- func: _foreach_trunc_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_trunc_slow_ + CUDA: foreach_tensor_trunc_cuda_ + autogen: _foreach_trunc.out + +- func: _foreach_zero_(Tensor(a!)[] self) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_zero_slow_ + CUDA: foreach_tensor_zero_cuda_ + autogen: _foreach_zero, _foreach_zero.out + +- func: _foreach_copy_(Tensor(a!)[] self, Tensor[] src, bool non_blocking=False) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_copy_list_kernel_slow_ + CUDA: foreach_tensor_copy_list_kernel_cuda_ + autogen: _foreach_copy.out + +- func: _foreach_copy(Tensor[] self, Tensor[] src, bool 
non_blocking=False) -> Tensor[] self_out + device_check: NoCheck + variants: function + dispatch: + CompositeExplicitAutograd: _foreach_copy + +- func: bucketize.Tensor(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor + dispatch: + CPU: bucketize_cpu + CUDA: bucketize_cuda + MPS: bucketize_mps + +- func: bucketize.Tensor_out(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU: bucketize_out_cpu + CUDA: bucketize_out_cuda + MPS: bucketize_out_mps + +- func: bucketize.Scalar(Scalar self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor + dispatch: + CPU: bucketize_cpu + CUDA: bucketize_cuda + MPS: bucketize_mps + autogen: bucketize.Scalar_out + +- func: searchsorted.Tensor(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None) -> Tensor + dispatch: + CPU: searchsorted_cpu + CUDA: searchsorted_cuda + MPS: searchsorted_mps + +- func: searchsorted.Tensor_out(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU: searchsorted_out_cpu + CUDA: searchsorted_out_cuda + MPS: searchsorted_out_mps + +- func: searchsorted.Scalar(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None) -> Tensor + dispatch: + CPU: searchsorted_cpu + CUDA: searchsorted_cuda + MPS: searchsorted_mps + +- func: searchsorted.Scalar_out(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU: searchsorted_out_cpu + CUDA: searchsorted_out_cuda + MPS: searchsorted_out_mps + +- func: _convert_indices_from_coo_to_csr(Tensor self, int size, *, bool out_int32=False) -> Tensor + structured_delegate: _convert_indices_from_coo_to_csr.out + +- func: _convert_indices_from_coo_to_csr.out(Tensor self, int size, *, bool out_int32=False, Tensor(a!) out) -> Tensor(a!) + structured: True + dispatch: + CPU: _convert_indices_from_coo_to_csr_structured_cpu + CUDA: _convert_indices_from_coo_to_csr_structured_cuda + +- func: _convert_indices_from_csr_to_coo(Tensor crow_indices, Tensor col_indices, *, bool out_int32=False, bool transpose=False) -> Tensor + structured_delegate: _convert_indices_from_csr_to_coo.out + +- func: _convert_indices_from_csr_to_coo.out(Tensor crow_indices, Tensor col_indices, *, bool out_int32=False, bool transpose=False, Tensor(a!) out) -> Tensor(a!) + structured: True + dispatch: + CPU: _convert_indices_from_csr_to_coo_structured_cpu + CUDA: _convert_indices_from_csr_to_coo_structured_cuda + +## NN wrappers + +- func: mse_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + python_module: nn + dispatch: + CPU, CUDA: mse_loss_out + MPS: mse_loss_out_mps + +- func: mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: mse_loss.out + python_module: nn + +- func: mse_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!) 
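+# Illustrative aside (editor's note, not part of the schema): entries in this "NN wrappers"
+# block are tagged `python_module: nn`, which (roughly speaking) binds them under
+# torch._C._nn and surfaces them through torch.nn.functional rather than as Tensor methods.
+# A minimal usage sketch, assuming stock PyTorch:
+#   import torch
+#   import torch.nn.functional as F
+#   pred = torch.randn(8, 3, requires_grad=True)
+#   target = torch.randn(8, 3)
+#   loss = F.mse_loss(pred, target, reduction="mean")  # forward lowers to mse_loss / mse_loss.out
+#   loss.backward()                                    # backward uses mse_loss_backward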
+ python_module: nn + dispatch: + CPU, CUDA: mse_loss_backward_out + MPS: mse_loss_backward_out_mps + +- func: mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor + python_module: nn + dispatch: + CPU, CUDA: mse_loss_backward + MPS: mse_loss_backward_mps + +- func: l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor + python_module: nn + +- func: multi_margin_loss.out(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + dispatch: + CPU: multi_margin_loss_cpu_out + CUDA: multi_margin_loss_cuda_out + +- func: multi_margin_loss(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean) -> Tensor + python_module: nn + dispatch: + CPU: multi_margin_loss_cpu + CUDA: multi_margin_loss_cuda + +- func: multi_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + dispatch: + CPU: multi_margin_loss_cpu_backward_out + CUDA: multi_margin_loss_cuda_backward_out + +- func: multi_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean) -> Tensor + python_module: nn + dispatch: + CPU: multi_margin_loss_cpu_backward + CUDA: multi_margin_loss_cuda_backward + +- func: multilabel_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + +- func: multilabel_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor + python_module: nn + +- func: multilabel_margin_loss_forward.output(Tensor self, Tensor target, int reduction, *, Tensor(a!) output, Tensor(b!) is_target) -> (Tensor(a!), Tensor(b!)) + python_module: nn + dispatch: + CPU: multilabel_margin_loss_forward_out_cpu + CUDA: multilabel_margin_loss_forward_out_cuda + +- func: multilabel_margin_loss_forward(Tensor self, Tensor target, int reduction) -> (Tensor output, Tensor is_target) + python_module: nn + dispatch: + CPU: multilabel_margin_loss_forward_cpu + CUDA: multilabel_margin_loss_forward_cuda + +- func: multilabel_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + dispatch: + CPU: multilabel_margin_loss_backward_cpu_out + CUDA: multilabel_margin_loss_backward_cuda_out + +- func: multilabel_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target) -> Tensor + python_module: nn + dispatch: + CPU: multilabel_margin_loss_backward_cpu + CUDA: multilabel_margin_loss_backward_cuda + +- func: nll_loss.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + +- func: nll_loss_nd(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor + python_module: nn + dispatch: + CompositeImplicitAutograd: nll_loss_nd_symint + +- func: nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor + python_module: nn + dispatch: + CompositeImplicitAutograd: nll_loss_symint + +- func: nll_loss_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) 
total_weight) -> (Tensor(a!), Tensor(b!)) + python_module: nn + structured: True + dispatch: + CPU: nll_loss_forward_out_cpu + CUDA: nll_loss_forward_out_cuda + MPS: nll_loss_forward_out_mps + +- func: nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) + python_module: nn + structured_delegate: nll_loss_forward.output + +- func: nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: nll_loss_backward_out_cpu + CUDA: nll_loss_backward_out_cuda + MPS: nll_loss_backward_out_mps + +- func: nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor + python_module: nn + structured_delegate: nll_loss_backward.grad_input + +- func: nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + +- func: nll_loss2d(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor + python_module: nn + dispatch: + CompositeImplicitAutograd: nll_loss2d_symint + +- func: nll_loss2d_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!)) + python_module: nn + dispatch: + CPU: nll_loss2d_forward_out_cpu + CUDA: nll_loss2d_forward_out_cuda + MPS: nll_loss2d_forward_out_mps + +- func: nll_loss2d_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) + python_module: nn + dispatch: + CPU: nll_loss2d_forward_cpu + CUDA: nll_loss2d_forward_cuda + MPS: nll_loss2d_forward_mps + +- func: nll_loss2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + dispatch: + CPU: nll_loss2d_backward_out_cpu + CUDA: nll_loss2d_backward_out_cuda + MPS: nll_loss2d_backward_out_mps + +- func: nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor + python_module: nn + dispatch: + CPU: nll_loss2d_backward_cpu + CUDA: nll_loss2d_backward_cuda + MPS: nll_loss2d_backward_mps + +- func: smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, float beta=1.0, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + python_module: nn + dispatch: + CPU, CUDA: smooth_l1_loss_out + MPS: smooth_l1_loss_out_mps + +- func: smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean, float beta=1.0) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: smooth_l1_loss.out + python_module: nn + +- func: smooth_l1_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta, *, Tensor(a!) grad_input) -> Tensor(a!) 
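+# Illustrative aside (editor's note): nll_loss_forward returns (output, total_weight) so the
+# backward entry can renormalise correctly when reduction=Mean is combined with a class
+# `weight` vector or ignore_index. A hedged usage sketch, assuming stock PyTorch semantics:
+#   import torch
+#   import torch.nn.functional as F
+#   logits = torch.randn(5, 10, requires_grad=True)
+#   target = torch.randint(0, 10, (5,))
+#   loss = F.nll_loss(F.log_softmax(logits, dim=1), target, ignore_index=-100)
+#   loss.backward()   # routed through nll_loss_backward with the saved total_weight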
+ python_module: nn + dispatch: + CPU: smooth_l1_loss_backward_out + CUDA: smooth_l1_loss_backward_out + MPS: smooth_l1_loss_backward_out_mps + +- func: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta) -> Tensor + python_module: nn + dispatch: + CompositeExplicitAutograd: smooth_l1_loss_backward + +- func: huber_loss.out(Tensor self, Tensor target, int reduction=Mean, float delta=1.0, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + dispatch: + CPU, CUDA: huber_loss_out + MPS: huber_loss_out_mps + +- func: huber_loss(Tensor self, Tensor target, int reduction=Mean, float delta=1.0) -> Tensor + python_module: nn + dispatch: + CPU, CUDA: huber_loss + MPS: huber_loss_mps + +- func: huber_loss_backward.out(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + dispatch: + CPU, CUDA: huber_loss_backward_out + MPS: huber_loss_backward_out_mps + +- func: huber_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta) -> Tensor + python_module: nn + dispatch: + CompositeExplicitAutograd: huber_loss_backward + +- func: soft_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + dispatch: + CompositeExplicitAutograd: soft_margin_loss_out + +- func: soft_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor + python_module: nn + dispatch: + CompositeExplicitAutograd: soft_margin_loss + +- func: soft_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + dispatch: + CompositeExplicitAutograd: soft_margin_loss_backward_out + +- func: soft_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor + python_module: nn + dispatch: + CompositeExplicitAutograd: soft_margin_loss_backward + +- func: elu.out(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + python_module: nn + dispatch: + CPU, CUDA: elu_out + MPS: elu_out_mps + +- func: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor + structured_delegate: elu.out + device_check: NoCheck # TensorIterator + python_module: nn + +- func: elu_backward.grad_input(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result, *, Tensor(a!) grad_input) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + python_module: nn + dispatch: + CPU, CUDA: elu_backward_out + MPS: elu_backward_out_mps + +- func: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor + structured_delegate: elu_backward.grad_input + python_module: nn + +- func: elu_(Tensor(a!) self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor(a!) + structured_delegate: elu.out + device_check: NoCheck # TensorIterator + python_module: nn + +- func: glu.out(Tensor self, int dim=-1, *, Tensor(a!) out) -> Tensor(a!) 
+ structured: True + structured_inherits: TensorIteratorBase + python_module: nn + dispatch: + CPU, CUDA: glu_out + MPS: glu_out_mps + +- func: glu(Tensor self, int dim=-1) -> Tensor + structured_delegate: glu.out + device_check: NoCheck # TensorIterator + python_module: nn + +- func: glu_backward.grad_input(Tensor grad_output, Tensor self, int dim, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + dispatch: + CPU: glu_backward_cpu_out + CUDA: glu_backward_cuda_out + MPS: glu_backward_mps_out + +- func: glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor + python_module: nn + dispatch: + CPU: glu_backward_cpu + CUDA: glu_backward_cuda + MPS: glu_backward_mps + +- func: glu_jvp(Tensor glu, Tensor x, Tensor dx, int dim) -> Tensor + python_module: nn + dispatch: + CPU, CUDA: glu_jvp + autogen: glu_jvp.out + +- func: glu_backward_jvp(Tensor grad_x, Tensor grad_glu, Tensor x, Tensor dgrad_glu, Tensor dx, int dim) -> Tensor + python_module: nn + dispatch: + CPU, CUDA: glu_backward_jvp + autogen: glu_backward_jvp.out + +- func: hardsigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + python_module: nn + dispatch: + CPU, CUDA: hardsigmoid_out + MPS: hardsigmoid_out_mps + QuantizedCPU: hardsigmoid_out_quantized_cpu + +- func: hardsigmoid(Tensor self) -> Tensor + structured_delegate: hardsigmoid.out + device_check: NoCheck # TensorIterator + python_module: nn + dispatch: + QuantizedCPU: hardsigmoid_quantized_cpu + +- func: hardsigmoid_(Tensor(a!) self) -> Tensor(a!) + structured_delegate: hardsigmoid.out + device_check: NoCheck # TensorIterator + python_module: nn + +- func: hardsigmoid_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + python_module: nn + dispatch: + CPU, CUDA: hardsigmoid_backward_out + MPS: hardsigmoid_backward_out_mps + +- func: hardsigmoid_backward(Tensor grad_output, Tensor self) -> Tensor + structured_delegate: hardsigmoid_backward.grad_input + python_module: nn + +- func: hardtanh.out(Tensor self, Scalar min_val=-1, Scalar max_val=1, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + python_module: nn + dispatch: + CPU, CUDA, MPS: hardtanh_out + QuantizedCPU: hardtanh_out_quantized_cpu + +- func: hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor + device_check: NoCheck # TensorIterator + python_module: nn + dispatch: + CPU, CUDA, MPS: hardtanh + QuantizedCPU: hardtanh_quantized_cpu + tags: core + +- func: hardtanh_backward.grad_input(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + dispatch: + CPU, CUDA: hardtanh_backward_out + MPS: hardtanh_backward_out_mps + +- func: hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor + python_module: nn + dispatch: + CPU, CUDA: hardtanh_backward + MPS: hardtanh_backward_mps + +- func: hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> Tensor(a!) + device_check: NoCheck # TensorIterator + python_module: nn + dispatch: + CPU, CUDA, MPS: hardtanh_ + QuantizedCPU: hardtanh_quantized_cpu_ + +- func: hardswish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
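As a quick reminder of what the glu and hardsigmoid entries above compute (illustrative input; this is a reference formula, not the registered kernels):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 6)                             # size along the split dim must be even

    a, b = x.chunk(2, dim=-1)                         # glu: split in half, gate with a sigmoid
    assert torch.allclose(a * torch.sigmoid(b), F.glu(x, dim=-1))

    # hardsigmoid is the piecewise-linear approximation clamp(x / 6 + 1 / 2, 0, 1)
    assert torch.allclose(torch.clamp(x / 6 + 0.5, 0, 1), F.hardsigmoid(x))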
+ device_check: NoCheck # TensorIterator + python_module: nn + dispatch: + CPU, CUDA: hardswish_out + MPS: hardswish_out_mps + +- func: hardswish(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + python_module: nn + dispatch: + CPU, CUDA: hardswish + MPS: hardswish_mps + +- func: hardswish_(Tensor(a!) self) -> Tensor(a!) + device_check: NoCheck # TensorIterator + python_module: nn + dispatch: + CPU, CUDA: hardswish_ + MPS: hardswish_mps_ + +- func: hardswish_backward(Tensor grad_output, Tensor self) -> Tensor + python_module: nn + dispatch: + CPU, CUDA: hardswish_backward + MPS: hardswish_backward_mps + autogen: hardswish_backward.out + +- func: leaky_relu.out(Tensor self, Scalar negative_slope=0.01, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + python_module: nn + dispatch: + CPU, CUDA: leaky_relu_out + MPS: leaky_relu_out_mps + QuantizedCPU: leaky_relu_out_quantized_cpu + +- func: leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor + structured_delegate: leaky_relu.out + device_check: NoCheck # TensorIterator + python_module: nn + dispatch: + QuantizedCPU: leaky_relu_quantized_cpu + tags: core + +- func: leaky_relu_backward.grad_input(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result, *, Tensor(a!) grad_input) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + python_module: nn + dispatch: + CPU, CUDA: leaky_relu_backward_out + MPS: leaky_relu_backward_out_mps + +- func: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor + structured_delegate: leaky_relu_backward.grad_input + python_module: nn + +- func: leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> Tensor(a!) + structured_delegate: leaky_relu.out + device_check: NoCheck # TensorIterator + python_module: nn + dispatch: + QuantizedCPU: leaky_relu_quantized_cpu_ + +- func: log_sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + python_module: nn + +- func: log_sigmoid(Tensor self) -> Tensor + device_check: NoCheck # TensorIterator + python_module: nn + +- func: log_sigmoid_forward.output(Tensor self, *, Tensor(a!) output, Tensor(b!) buffer) -> (Tensor(a!), Tensor(b!)) + device_check: NoCheck # TensorIterator + python_module: nn + dispatch: + CPU: log_sigmoid_forward_out_cpu + CUDA: log_sigmoid_forward_out_cuda + MPS: log_sigmoid_forward_out_mps + +- func: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer) + device_check: NoCheck # TensorIterator + python_module: nn + dispatch: + CPU: log_sigmoid_forward_cpu + CUDA: log_sigmoid_forward_cuda + MPS: log_sigmoid_forward_mps + +- func: log_sigmoid_backward.grad_input(Tensor grad_output, Tensor self, Tensor buffer, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + dispatch: + CPU: log_sigmoid_backward_cpu_out + CUDA: log_sigmoid_backward_cuda_out + MPS: log_sigmoid_backward_mps_out + +- func: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor + python_module: nn + dispatch: + CPU: log_sigmoid_backward_cpu + CUDA: log_sigmoid_backward_cuda + MPS: log_sigmoid_backward_mps + +- func: rrelu_with_noise.out(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) 
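Likewise for the hardswish and leaky_relu families above, a minimal reference on an arbitrary input:

    import torch
    import torch.nn.functional as F

    x = torch.randn(8)

    # hardswish(x) = x * relu6(x + 3) / 6, i.e. x gated by hardsigmoid(x)
    assert torch.allclose(x * F.relu6(x + 3) / 6, F.hardswish(x))

    # leaky_relu keeps positive values and scales negative ones by negative_slope
    assert torch.allclose(torch.where(x >= 0, x, 0.01 * x), F.leaky_relu(x, negative_slope=0.01))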
+ python_module: nn + tags: nondeterministic_seeded + dispatch: + CPU: rrelu_with_noise_out_cpu + CUDA: rrelu_with_noise_out_cuda + +- func: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor + python_module: nn + dispatch: + CPU: rrelu_with_noise_cpu + CUDA: rrelu_with_noise_cuda + tags: nondeterministic_seeded + +- func: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor + python_module: nn + dispatch: + CompositeExplicitAutograd: rrelu_with_noise_backward + autogen: rrelu_with_noise_backward.out + +- func: rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!) + python_module: nn + tags: nondeterministic_seeded + dispatch: + CPU: rrelu_with_noise_cpu_ + CUDA: rrelu_with_noise_cuda_ + +- func: softplus.out(Tensor self, Scalar beta=1, Scalar threshold=20, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + python_module: nn + dispatch: + CPU, CUDA: softplus_out + MPS: softplus_out_mps + +- func: softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor + structured_delegate: softplus.out + device_check: NoCheck # TensorIterator + python_module: nn + +- func: softplus_backward.grad_input(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + python_module: nn + dispatch: + CPU, CUDA: softplus_backward_out + MPS: softplus_backward_out_mps + +- func: softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold) -> Tensor + structured_delegate: softplus_backward.grad_input + python_module: nn + +- func: softshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + device_check: NoCheck # TensorIterator + python_module: nn + dispatch: + CPU, CUDA: softshrink_out + MPS: softshrink_out_mps + +- func: softshrink(Tensor self, Scalar lambd=0.5) -> Tensor + structured_delegate: softshrink.out + device_check: NoCheck # TensorIterator + python_module: nn + +- func: softshrink_backward.grad_input(Tensor grad_output, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + python_module: nn + dispatch: + CPU, CUDA: softshrink_backward_out + MPS: softshrink_backward_out_mps + +- func: softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) -> Tensor + structured_delegate: softshrink_backward.grad_input + python_module: nn + +- func: adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + dispatch: + CPU: adaptive_avg_pool2d_out_cpu + CUDA: adaptive_avg_pool2d_out_cuda + MPS: adaptive_avg_pool2d_out_mps + MkldnnCPU: mkldnn_adaptive_avg_pool2d_out_stub + +- func: adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor + python_module: nn + dispatch: + CompositeImplicitAutograd: adaptive_avg_pool2d_symint + +- func: mkldnn_adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor + dispatch: + MkldnnCPU: mkldnn_adaptive_avg_pool2d + +- func: mkldnn_adaptive_avg_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out) -> Tensor(a!) 
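The softplus and softshrink schemas above carry their formulas in their defaults (beta=1, threshold=20, lambd=0.5); a sketch with non-default illustrative values:

    import torch
    import torch.nn.functional as F

    x = torch.randn(8)
    beta, threshold, lambd = 2.0, 20.0, 0.5           # illustrative values

    # softplus: log(1 + exp(beta * x)) / beta, reverting to the identity once beta * x > threshold
    sp = torch.where(beta * x > threshold, x, torch.log1p(torch.exp(beta * x)) / beta)
    assert torch.allclose(sp, F.softplus(x, beta=beta, threshold=threshold))

    # softshrink: shrink toward zero by lambd, zeroing anything inside [-lambd, lambd]
    ss = torch.sign(x) * torch.clamp(x.abs() - lambd, min=0.0)
    assert torch.allclose(ss, F.softshrink(x, lambd=lambd))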
+ dispatch: + MkldnnCPU: mkldnn_adaptive_avg_pool2d_out + +- func: mkldnn_adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor + dispatch: + MkldnnCPU: mkldnn_adaptive_avg_pool2d_backward + autogen: mkldnn_adaptive_avg_pool2d_backward.out + +- func: _adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor + dispatch: + CPU: adaptive_avg_pool2d_cpu + CUDA: adaptive_avg_pool2d_cuda + MPS: adaptive_avg_pool2d_mps + QuantizedCPU: adaptive_avg_pool2d_quantized_cpu + QuantizedCUDA: adaptive_avg_pool2d_quantized_cuda + autogen: _adaptive_avg_pool2d.out + tags: core + +- func: _adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor + python_module: nn + dispatch: + CPU: adaptive_avg_pool2d_backward_cpu + CUDA: adaptive_avg_pool2d_backward_cuda + MPS: adaptive_avg_pool2d_backward_mps + autogen: _adaptive_avg_pool2d_backward.out + tags: core + +- func: adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + dispatch: + CPU: adaptive_avg_pool3d_out_cpu + CUDA: adaptive_avg_pool3d_out_cuda + QuantizedCPU: adaptive_avg_pool3d_out_quantized_cpu + +- func: adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor + python_module: nn + dispatch: + CompositeImplicitAutograd: adaptive_avg_pool3d_symint + +- func: _adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor + dispatch: + CPU: adaptive_avg_pool3d_cpu + CUDA: adaptive_avg_pool3d_cuda + QuantizedCPU: adaptive_avg_pool3d_quantized_cpu + autogen: _adaptive_avg_pool3d.out + tags: core + +- func: adaptive_avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + dispatch: + CPU: adaptive_avg_pool3d_backward_out_cpu + CUDA: adaptive_avg_pool3d_backward_out_cuda + +- func: _adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self) -> Tensor + python_module: nn + dispatch: + CPU: adaptive_avg_pool3d_backward_cpu + CUDA: adaptive_avg_pool3d_backward_cuda + autogen: _adaptive_avg_pool3d_backward.out + +# Return: (Tensor output, Tensor indices) +- func: adaptive_max_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + python_module: nn + structured: True + dispatch: + CPU: adaptive_max_pool2d_out_cpu + CUDA: adaptive_max_pool2d_out_cuda + MPS: adaptive_max_pool2d_out_mps + +# Return: (Tensor output, Tensor indices) +- func: adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor) + python_module: nn + structured_delegate: adaptive_max_pool2d.out + +- func: adaptive_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: adaptive_max_pool2d_backward_out_cpu + CUDA: adaptive_max_pool2d_backward_out_cuda + MPS: adaptive_max_pool2d_backward_out_mps + +- func: adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor + python_module: nn + structured_delegate: adaptive_max_pool2d_backward.grad_input + +# Return: (Tensor output, Tensor indices) +- func: adaptive_max_pool3d.out(Tensor self, int[3] output_size, *, Tensor(a!) out, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!)) + python_module: nn + structured: True + dispatch: + CPU: adaptive_max_pool3d_out_cpu + CUDA: adaptive_max_pool3d_out_cuda + +# Return: (Tensor output, Tensor indices) +- func: adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor) + python_module: nn + structured_delegate: adaptive_max_pool3d.out + +- func: adaptive_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: adaptive_max_pool3d_backward_out_cpu + CUDA: adaptive_max_pool3d_backward_out_cuda + +- func: adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor + python_module: nn + structured_delegate: adaptive_max_pool3d_backward.grad_input + +- func: avg_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + precomputed: + - kernel_size -> int kH, int kW + - stride -> int dH, int dW + - padding -> int padH, int padW + dispatch: + CPU: avg_pool2d_out_cpu + CUDA: avg_pool2d_out_cuda + MPS: avg_pool2d_out_mps + MkldnnCPU: mkldnn_avg_pool2d_out + +- func: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor + python_module: nn + structured_delegate: avg_pool2d.out + dispatch: + MkldnnCPU: mkldnn_avg_pool2d + QuantizedCPU: avg_pool2d_quantized_cpu + tags: core + +- func: avg_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: avg_pool2d_backward_out_cpu + CUDA: avg_pool2d_backward_out_cuda + MPS: avg_pool2d_backward_out_mps + MkldnnCPU: mkldnn_avg_pool2d_backward_out + +- func: avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor + python_module: nn + structured_delegate: avg_pool2d_backward.grad_input + dispatch: + MkldnnCPU: mkldnn_avg_pool2d_backward + tags: core + +- func: avg_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: avg_pool3d_out_cpu + CUDA: avg_pool3d_out_cuda + MkldnnCPU: mkldnn_avg_pool3d_out + +- func: avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor + python_module: nn + structured_delegate: avg_pool3d.out + dispatch: + MkldnnCPU: mkldnn_avg_pool3d + QuantizedCPU: avg_pool3d_quantized_cpu + tags: core + +- func: avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!) 
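The precomputed kH/dH/padH fields in the avg_pool2d entry above feed the usual pooling shape arithmetic; with ceil_mode=False the output edge length is floor((in + 2*pad - kernel) / stride) + 1. An illustrative check (values arbitrary):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 1, 10, 10)
    kH = kW = 3                                       # names mirror the precomputed fields
    dH = dW = 2
    padH = padW = 1

    out_h = (10 + 2 * padH - kH) // dH + 1            # floor((10 + 2 - 3) / 2) + 1 = 5
    y = F.avg_pool2d(x, kernel_size=(kH, kW), stride=(dH, dW), padding=(padH, padW))
    assert y.shape[-2:] == (out_h, out_h)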
+ python_module: nn + structured: True + dispatch: + CPU: avg_pool3d_backward_out_cpu + CUDA: avg_pool3d_backward_out_cuda + MkldnnCPU: mkldnn_avg_pool3d_backward_out + +- func: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor + python_module: nn + structured_delegate: avg_pool3d_backward.grad_input + dispatch: + MkldnnCPU: mkldnn_avg_pool3d_backward + +# Return: (Tensor output, Tensor indices) +- func: fractional_max_pool2d.output(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + python_module: nn + structured: True + dispatch: + CPU: fractional_max_pool2d_out_cpu + CUDA: fractional_max_pool2d_out_cuda + +# Return: (Tensor output, Tensor indices) +- func: fractional_max_pool2d(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples) -> (Tensor, Tensor) + python_module: nn + structured_delegate: fractional_max_pool2d.output + +- func: fractional_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: fractional_max_pool2d_backward_cpu + CUDA: fractional_max_pool2d_backward_cuda + +- func: fractional_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices) -> Tensor + python_module: nn + structured_delegate: fractional_max_pool2d_backward.grad_input + +# Return: (Tensor output, Tensor indices) +- func: fractional_max_pool3d.output(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + python_module: nn + structured: True + precomputed: + - kernel_size -> int poolSizeT, int poolSizeH, int poolSizeW + - output_size -> int outputT, int outputH, int outputW + - int numBatch, int numPlanes, int inputT, int inputH, int inputW + dispatch: + CPU: fractional_max_pool3d_out_cpu + CUDA: fractional_max_pool3d_out_cuda + +# Return: (Tensor output, Tensor indices) +- func: fractional_max_pool3d(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples) -> (Tensor, Tensor) + python_module: nn + structured_delegate: fractional_max_pool3d.output + +- func: fractional_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + dispatch: + CPU: fractional_max_pool3d_backward_out_cpu + CUDA: fractional_max_pool3d_backward_out_cuda + +- func: fractional_max_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices) -> Tensor + python_module: nn + dispatch: + CPU: fractional_max_pool3d_backward_cpu + CUDA: fractional_max_pool3d_backward_cuda + +# Return: (Tensor output, Tensor indices) +- func: max_pool2d_with_indices.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!)) + python_module: nn + structured: True + dispatch: + CPU: max_pool2d_with_indices_out_cpu + CUDA: max_pool2d_with_indices_out_cuda + MPS: max_pool2d_with_indices_out_mps + +# Return: (Tensor output, Tensor indices) +- func: max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) + python_module: nn + structured_delegate: max_pool2d_with_indices.out + tags: core + +- func: max_pool2d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: max_pool2d_with_indices_backward_out_cpu + CUDA: max_pool2d_with_indices_backward_out_cuda + MPS: max_pool2d_with_indices_backward_out_mps + +- func: max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor + python_module: nn + structured_delegate: max_pool2d_with_indices_backward.grad_input + tags: core + +# Return: (Tensor output, Tensor indices) +- func: max_pool3d_with_indices.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + python_module: nn + dispatch: + CPU: max_pool3d_with_indices_out_cpu + CUDA: max_pool3d_with_indices_out_cuda + +# Return: (Tensor output, Tensor indices) +- func: max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) + python_module: nn + dispatch: + CPU: max_pool3d_with_indices_cpu + CUDA: max_pool3d_with_indices_cuda + tags: core + +- func: max_pool3d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + dispatch: + CPU: max_pool3d_with_indices_backward_out_cpu + CUDA: max_pool3d_with_indices_backward_out_cuda + +- func: max_pool3d_with_indices_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices) -> Tensor + python_module: nn + dispatch: + CPU: max_pool3d_with_indices_backward_cpu + CUDA: max_pool3d_with_indices_backward_cuda + +- func: max_unpool2d.out(Tensor self, Tensor indices, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + dispatch: + CPU: max_unpooling2d_forward_out_cpu + CUDA: max_unpooling2d_forward_out_cuda + +- func: max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor + python_module: nn + dispatch: + CPU: max_unpooling2d_forward_cpu + CUDA: max_unpooling2d_forward_cuda + +- func: max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + dispatch: + CPU: max_unpooling3d_forward_out_cpu + CUDA: max_unpooling3d_forward_out_cuda + +- func: max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor + python_module: nn + dispatch: + CPU: max_unpooling3d_forward_cpu + CUDA: max_unpooling3d_forward_cuda + +- func: reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) 
out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: reflection_pad1d_out_cpu + QuantizedCPU: reflection_pad1d_out_quantized_cpu + CUDA: reflection_pad1d_out_cuda + MPS: reflection_pad1d_out_mps + +- func: reflection_pad1d(Tensor self, SymInt[2] padding) -> Tensor + python_module: nn + structured_delegate: reflection_pad1d.out + tags: core + +- func: reflection_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: reflection_pad1d_backward_out_cpu + CUDA: reflection_pad1d_backward_out_cuda + MPS: reflection_pad1d_backward_out_mps + +- func: reflection_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor + python_module: nn + structured_delegate: reflection_pad1d_backward.grad_input + +- func: reflection_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + dispatch: + CPU, QuantizedCPU: reflection_pad2d_out_cpu + CUDA: reflection_pad2d_out_cuda + MPS: reflection_pad2d_out_mps + +- func: reflection_pad2d(Tensor self, SymInt[4] padding) -> Tensor + python_module: nn + dispatch: + CPU: reflection_pad2d_cpu + QuantizedCPU: reflection_pad2d_quantized_cpu + CUDA: reflection_pad2d_cuda + MPS: reflection_pad2d_mps + tags: core + +- func: reflection_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + dispatch: + CPU: reflection_pad2d_backward_out_cpu + CUDA: reflection_pad2d_backward_out_cuda + MPS: reflection_pad2d_backward_out_mps + +- func: reflection_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor + python_module: nn + dispatch: + CPU: reflection_pad2d_backward_cpu + CUDA: reflection_pad2d_backward_cuda + MPS: reflection_pad2d_backward_mps + +- func: reflection_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: reflection_pad3d_out_cpu + CUDA: reflection_pad3d_out_cuda + MPS: reflection_pad3d_out_mps + +- func: reflection_pad3d(Tensor self, SymInt[6] padding) -> Tensor + python_module: nn + structured_delegate: reflection_pad3d.out + tags: core + +- func: reflection_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: reflection_pad3d_backward_out_cpu + CUDA: reflection_pad3d_backward_out_cuda + MPS: reflection_pad3d_backward_out_mps + +- func: reflection_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor + python_module: nn + structured_delegate: reflection_pad3d_backward.grad_input + +- func: replication_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: replication_pad1d_out_cpu + CUDA: replication_pad1d_out_cuda + MPS: replication_pad1d_out_mps + +- func: replication_pad1d(Tensor self, SymInt[2] padding) -> Tensor + python_module: nn + structured_delegate: replication_pad1d.out + +- func: replication_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!) 
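The SymInt[2]/[4]/[6] padding arguments above are (left, right[, top, bottom[, front, back]]) pairs, given last dimension first; reflection padding mirrors the input without repeating the edge value. A small illustration through F.pad:

    import torch
    import torch.nn.functional as F

    x = torch.arange(5.0).view(1, 1, 5)               # (N, C, W)
    y = F.pad(x, (2, 2), mode='reflect')              # pad_left=2, pad_right=2
    # y is [[[2., 1., 0., 1., 2., 3., 4., 3., 2.]]]: the edge sample is not duplicated
    assert y.shape[-1] == 5 + 2 + 2
    assert torch.equal(y[0, 0, :3], torch.tensor([2.0, 1.0, 0.0]))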
+ python_module: nn + structured: True + dispatch: + CPU: replication_pad1d_backward_out_cpu + CUDA: replication_pad1d_backward_out_cuda + MPS: replication_pad1d_backward_out_mps + +- func: replication_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor + python_module: nn + structured_delegate: replication_pad1d_backward.grad_input + +- func: replication_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: replication_pad2d_out_cpu + CUDA: replication_pad2d_out_cuda + MPS: replication_pad2d_out_mps + +- func: replication_pad2d(Tensor self, SymInt[4] padding) -> Tensor + python_module: nn + structured_delegate: replication_pad2d.out + tags: core + +- func: replication_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + dispatch: + CPU: replication_pad2d_backward_out_cpu + CUDA: replication_pad2d_backward_out_cuda + MPS: replication_pad2d_backward_out_mps + +- func: replication_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor + python_module: nn + dispatch: + CPU: replication_pad2d_backward_cpu + CUDA: replication_pad2d_backward_cuda + MPS: replication_pad2d_backward_mps + +- func: replication_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: replication_pad3d_out_cpu + CUDA: replication_pad3d_out_cuda + MPS: replication_pad3d_out_mps + +- func: replication_pad3d(Tensor self, SymInt[6] padding) -> Tensor + python_module: nn + structured_delegate: replication_pad3d.out + tags: core + + +- func: replication_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + dispatch: + CPU: replication_pad3d_backward_out_cpu + CUDA: replication_pad3d_backward_out_cuda + MPS: replication_pad3d_backward_out_mps + +- func: replication_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor + python_module: nn + dispatch: + CPU: replication_pad3d_backward_cpu + CUDA: replication_pad3d_backward_cuda + MPS: replication_pad3d_backward_mps + +- func: _pad_circular(Tensor self, SymInt[] pad) -> Tensor + python_module: nn + dispatch: + CompositeImplicitAutograd: _pad_circular_symint + +- func: _pad_enum(Tensor self, SymInt[] pad, int mode, float? value=None) -> Tensor + python_module: nn + dispatch: + CompositeImplicitAutograd: _pad_enum_symint + +- func: pad(Tensor self, SymInt[] pad, str mode="constant", float? value=None) -> Tensor + python_module: nn + dispatch: + CompositeImplicitAutograd: pad_symint + +- func: upsample_linear1d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + python_module: nn + autogen: upsample_linear1d.vec_out + +- func: upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + python_module: nn + autogen: upsample_bilinear2d.vec_out + tags: core + +- func: _upsample_bilinear2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + python_module: nn + autogen: _upsample_bilinear2d_aa.vec_out + +- func: upsample_trilinear3d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + python_module: nn + autogen: upsample_trilinear3d.vec_out + +- func: upsample_bicubic2d.vec(Tensor input, SymInt[]? 
output_size, bool align_corners, float[]? scale_factors) -> Tensor + python_module: nn + autogen: upsample_bicubic2d.vec_out + +- func: _upsample_bicubic2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + python_module: nn + autogen: _upsample_bicubic2d_aa.vec_out + +- func: upsample_nearest1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + python_module: nn + autogen: upsample_nearest1d.vec_out + +- func: _upsample_nearest_exact1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + python_module: nn + autogen: _upsample_nearest_exact1d.vec_out + +- func: upsample_nearest2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + python_module: nn + autogen: upsample_nearest2d.vec_out + tags: core + +- func: _upsample_nearest_exact2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + python_module: nn + autogen: _upsample_nearest_exact2d.vec_out + +- func: upsample_nearest3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + python_module: nn + autogen: upsample_nearest3d.vec_out + +- func: _upsample_nearest_exact3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + python_module: nn + autogen: _upsample_nearest_exact3d.vec_out + +# NOTE: all of the non-"vec" upsample overloads are only kept for backward compatibility. +- func: upsample_linear1d.out(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: upsample_linear1d_out_cpu + CUDA: upsample_linear1d_out_cuda + MPS: upsample_linear1d_out_mps + +- func: upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor + python_module: nn + structured_delegate: upsample_linear1d.out + +- func: upsample_linear1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: upsample_linear1d_backward_out_cpu + CUDA: upsample_linear1d_backward_out_cuda + MPS: upsample_linear1d_backward_out_mps + +- func: upsample_linear1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None) -> Tensor + python_module: nn + structured_delegate: upsample_linear1d_backward.grad_input + +- func: upsample_bilinear2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: upsample_bilinear2d_out_cpu + CUDA: upsample_bilinear2d_out_cuda + MPS: upsample_bilinear2d_out_mps + +- func: upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: upsample_bilinear2d.out + dispatch: + QuantizedCPU: upsample_bilinear2d_quantized_cpu + +- func: upsample_bilinear2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
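The .vec upsample overloads above take either an explicit output_size or per-dimension scale_factors, with the other left None; this is essentially the interface behind torch.nn.functional.interpolate, while the per-overload variants further down keep the older calling convention. Illustrative usage, assuming the usual interpolate semantics:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 8, 8)

    a = F.interpolate(x, size=(16, 16), mode='bilinear', align_corners=False)
    b = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=False)
    assert a.shape == b.shape == (1, 3, 16, 16)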
+ python_module: nn + structured: True + dispatch: + CPU: upsample_bilinear2d_backward_out_cpu + CUDA: upsample_bilinear2d_backward_out_cuda + MPS: upsample_bilinear2d_backward_out_mps + +- func: upsample_bilinear2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: upsample_bilinear2d_backward.grad_input + +- func: _upsample_bilinear2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: _upsample_bilinear2d_aa_out_cpu + CUDA: _upsample_bilinear2d_aa_out_cuda + +- func: _upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: _upsample_bilinear2d_aa.out + +- func: _upsample_bilinear2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: _upsample_bilinear2d_aa_backward_out_cpu + CUDA: _upsample_bilinear2d_aa_backward_out_cuda + +- func: _upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: _upsample_bilinear2d_aa_backward.grad_input + +- func: upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: upsample_bicubic2d_out_cpu + CUDA: upsample_bicubic2d_out_cuda + +- func: upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: upsample_bicubic2d.out + +- func: upsample_bicubic2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: upsample_bicubic2d_backward_out_cpu + CUDA: upsample_bicubic2d_backward_out_cuda + +- func: upsample_bicubic2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: upsample_bicubic2d_backward.grad_input + +- func: _upsample_bicubic2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: _upsample_bicubic2d_aa_out_cpu + CUDA: _upsample_bicubic2d_aa_out_cuda + +- func: _upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: _upsample_bicubic2d_aa.out + +- func: _upsample_bicubic2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ python_module: nn + structured: True + dispatch: + CPU: _upsample_bicubic2d_aa_backward_out_cpu + CUDA: _upsample_bicubic2d_aa_backward_out_cuda + +- func: _upsample_bicubic2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: _upsample_bicubic2d_aa_backward.grad_input + +- func: upsample_trilinear3d.out(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: upsample_trilinear3d_out_cpu + CUDA: upsample_trilinear3d_out_cuda + +- func: upsample_trilinear3d(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: upsample_trilinear3d.out + +- func: upsample_trilinear3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: upsample_trilinear3d_backward_out_cpu + CUDA: upsample_trilinear3d_backward_out_cuda + +- func: upsample_trilinear3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: upsample_trilinear3d_backward.grad_input + +- func: upsample_nearest1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: upsample_nearest1d_out_cpu + CUDA: upsample_nearest1d_out_cuda + MPS: upsample_nearest1d_out_mps + +- func: _upsample_nearest_exact1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: _upsample_nearest_exact1d_out_cpu + CUDA: _upsample_nearest_exact1d_out_cuda + MPS: _upsample_nearest_exact1d_out_mps + +- func: upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor + python_module: nn + structured_delegate: upsample_nearest1d.out + +- func: _upsample_nearest_exact1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor + python_module: nn + structured_delegate: _upsample_nearest_exact1d.out + +- func: upsample_nearest1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: upsample_nearest1d_backward_out_cpu + CUDA: upsample_nearest1d_backward_out_cuda + MPS: upsample_nearest1d_backward_out_mps + +- func: _upsample_nearest_exact1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: _upsample_nearest_exact1d_backward_out_cpu + CUDA: _upsample_nearest_exact1d_backward_out_cuda + MPS: _upsample_nearest_exact1d_backward_out_mps + +- func: upsample_nearest1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? 
scales=None) -> Tensor + python_module: nn + structured_delegate: upsample_nearest1d_backward.grad_input + +- func: _upsample_nearest_exact1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor + python_module: nn + structured_delegate: _upsample_nearest_exact1d_backward.grad_input + +- func: upsample_nearest2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: upsample_nearest2d_out_cpu + CUDA: upsample_nearest2d_out_cuda + MPS: upsample_nearest2d_out_mps + +- func: _upsample_nearest_exact2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: _upsample_nearest_exact2d_out_cpu + CUDA: _upsample_nearest_exact2d_out_cuda + MPS: _upsample_nearest_exact2d_out_mps + +- func: upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: upsample_nearest2d.out + dispatch: + QuantizedCPU: upsample_nearest2d_quantized_cpu + +- func: _upsample_nearest_exact2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: _upsample_nearest_exact2d.out + dispatch: + QuantizedCPU: _upsample_nearest_exact2d_quantized_cpu + +- func: upsample_nearest2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: upsample_nearest2d_backward_out_cpu + CUDA: upsample_nearest2d_backward_out_cuda + MPS: upsample_nearest2d_backward_out_mps + +- func: _upsample_nearest_exact2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: _upsample_nearest_exact2d_backward_out_cpu + CUDA: _upsample_nearest_exact2d_backward_out_cuda + MPS: _upsample_nearest_exact2d_backward_out_mps + +- func: upsample_nearest2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: upsample_nearest2d_backward.grad_input + +- func: _upsample_nearest_exact2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: _upsample_nearest_exact2d_backward.grad_input + +- func: upsample_nearest3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: upsample_nearest3d_out_cpu + CUDA: upsample_nearest3d_out_cuda + +- func: _upsample_nearest_exact3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: _upsample_nearest_exact3d_out_cpu + CUDA: _upsample_nearest_exact3d_out_cuda + +- func: upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? 
scales_w=None) -> Tensor + python_module: nn + structured_delegate: upsample_nearest3d.out + dispatch: + QuantizedCPU: upsample_nearest3d_quantized_cpu + +- func: _upsample_nearest_exact3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: _upsample_nearest_exact3d.out + dispatch: + QuantizedCPU: _upsample_nearest_exact3d_quantized_cpu + +- func: upsample_nearest3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: upsample_nearest3d_backward_out_cpu + CUDA: upsample_nearest3d_backward_out_cuda + +- func: _upsample_nearest_exact3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: _upsample_nearest_exact3d_backward_out_cpu + CUDA: _upsample_nearest_exact3d_backward_out_cuda + +- func: upsample_nearest3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: upsample_nearest3d_backward.grad_input + +- func: _upsample_nearest_exact3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: _upsample_nearest_exact3d_backward.grad_input + +- func: sigmoid_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: sigmoid_backward_out + MPS: sigmoid_backward_out_mps + tags: pointwise + +- func: sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor + python_module: nn + structured_delegate: sigmoid_backward.grad_input + tags: pointwise + +- func: logit_backward.grad_input(Tensor grad_output, Tensor self, float? eps=None, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: logit_backward_out + MPS: logit_backward_out_mps + tags: pointwise + +- func: logit_backward(Tensor grad_output, Tensor self, float? eps=None) -> Tensor + python_module: nn + structured_delegate: logit_backward.grad_input + tags: pointwise + +- func: tanh_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: tanh_backward_out + MPS: tanh_backward_out_mps + tags: pointwise + +- func: tanh_backward(Tensor grad_output, Tensor output) -> Tensor + python_module: nn + structured_delegate: tanh_backward.grad_input + +# What's a thnn_conv_ versus a slow_conv_? +# +# Historically, we have inefficient implementations of convolutions +# coming from the THNN/THCUNN library. These convolutions typically +# operated by computing the Toeplitz matrix and then doing a matrix +# multiply with the input; this is very memory inefficient! 
However, +# occasionally, we really don't have anything better, so it's helpful +# to have these fallbacks when there is no more optimized implementation +# in cudnn or mkldnn, etc. Both thnn_ and slow_ convolutions fall +# into this bucket. +# +# The difference between these two designations, is that thnn_ refers +# to a convolution that is still written in the "legacy" style; that is, +# C code in the THNN/ or THCUNN/ directory. A slow_ convolution is +# one that is written in the native style: modern C++. Algorithmically, +# these are the same thing, but we give them different prefixes to +# make the operational distinction clear. + tags: pointwise + +- func: slow_conv_transpose2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: slow_conv_transpose2d_structured_cpu + CUDA: slow_conv_transpose2d_structured_cuda + +- func: slow_conv_transpose2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1) -> Tensor + python_module: nn + structured_delegate: slow_conv_transpose2d.out + +- func: slow_conv_transpose3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + dispatch: + CPU: slow_conv_transpose3d_out_cpu + CUDA: slow_conv_transpose3d_out_cuda + +- func: slow_conv_transpose3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1) -> Tensor + python_module: nn + dispatch: + CPU: slow_conv_transpose3d_cpu + CUDA: slow_conv_transpose3d_cuda + +- func: thnn_conv2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + +- func: thnn_conv2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0) -> Tensor + python_module: nn + +- func: _slow_conv2d_forward.output(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) output) -> Tensor(a!) + python_module: nn + dispatch: + CPU: slow_conv2d_forward_out_cpu + CUDA: slow_conv2d_forward_out_cuda + +- func: _slow_conv2d_forward(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding) -> Tensor + python_module: nn + dispatch: + CPU: slow_conv2d_forward_cpu + CUDA: slow_conv2d_forward_cuda + +- func: _slow_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) 
grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + python_module: nn + dispatch: + CPU: slow_conv2d_backward_out_cpu + CUDA: slow_conv2d_backward_out_cuda + +- func: _slow_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) + python_module: nn + dispatch: + CPU: slow_conv2d_backward_cpu + CUDA: slow_conv2d_backward_cuda + autogen: _slow_conv2d_backward.output_mask_out + +- func: _conv_depthwise2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, *, Tensor(a!) out) -> Tensor(a!) + use_const_ref_for_mutable_tensors: True + python_module: nn + dispatch: + CUDA: conv_depthwise2d_cuda_out + +- func: _conv_depthwise2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation) -> Tensor + python_module: nn + dispatch: + CUDA: conv_depthwise2d_cuda + +- func: conv_depthwise3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation) -> Tensor + python_module: nn + dispatch: + CUDA: conv_depthwise3d_cuda + autogen: conv_depthwise3d.out + +- func: slow_conv3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + +- func: slow_conv3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0) -> Tensor + python_module: nn + +- func: slow_conv3d_forward.output(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, *, Tensor(a!) output) -> Tensor(a!) + python_module: nn + dispatch: + CPU: slow_conv3d_forward_out_cpu + +- func: slow_conv3d_forward(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding) -> Tensor + python_module: nn + dispatch: + CPU: slow_conv3d_forward_cpu + +- func: slow_conv_dilated2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1) -> Tensor + python_module: nn + dispatch: + CPU: slow_conv_dilated2d_cpu + CUDA: slow_conv_dilated2d_cuda + autogen: slow_conv_dilated2d.out + +- func: slow_conv_dilated3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1) -> Tensor + python_module: nn + dispatch: + CPU: slow_conv_dilated3d_cpu + CUDA: slow_conv_dilated3d_cuda + autogen: slow_conv_dilated3d.out + +- func: col2im.out(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + dispatch: + CPU: col2im_out_cpu + CUDA: col2im_out_cuda + +- func: col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor + python_module: nn + dispatch: + CPU: col2im_cpu + CUDA: col2im_cuda + tags: core + +- func: column_stack(Tensor[] tensors) -> Tensor + +- func: column_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + +- func: im2col.out(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!) 
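The note above describes the slow/thnn convolutions as an im2col ("Toeplitz matrix") expansion followed by a matrix multiply, which is the decomposition the im2col/col2im entries expose. A rough sketch of that idea with illustrative shapes (not the registered kernels):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 3, 8, 8)                       # (N, C_in, H, W)
    w = torch.randn(5, 3, 3, 3)                       # (C_out, C_in, kH, kW)

    cols = F.unfold(x, kernel_size=3, padding=1)      # im2col -> (N, C_in*kH*kW, L)
    out = w.view(5, -1) @ cols                        # one matmul per batch element
    out = out.view(2, 5, 8, 8)                        # fold L back into the spatial dims

    assert torch.allclose(out, F.conv2d(x, w, padding=1), atol=1e-4)

The cols tensor is kH*kW times larger than the input, which is the memory blow-up the comment refers to.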
+ python_module: nn + dispatch: + CPU: im2col_out_cpu + CUDA: im2col_out_cuda + +- func: im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor + python_module: nn + dispatch: + CPU: im2col_cpu + CUDA: im2col_cuda + +- func: isfinite(Tensor self) -> Tensor + variants: function, method + device_check: NoCheck + device_guard: False + +- func: isinf(Tensor self) -> Tensor + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: isinf + SparseCPU, SparseCUDA: isinf_sparse + SparseMeta: isinf_sparse_meta + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: isinf_sparse_csr + autogen: isinf.out + tags: [core, pointwise] + +- func: record_stream(Tensor(a!) self, Stream s) -> () + variants: method + dispatch: + CUDA: record_stream_cuda + +- func: isposinf(Tensor self) -> Tensor + variants: function, method + structured_delegate: isposinf.out + dispatch: + SparseCPU, SparseCUDA: isposinf_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: isposinf_sparse_csr + tags: pointwise + +- func: isposinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: isposinf_out + SparseCPU, SparseCUDA: isposinf_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: isposinf_sparse_csr_out + tags: pointwise + +- func: isneginf(Tensor self) -> Tensor + variants: function, method + structured_delegate: isneginf.out + dispatch: + SparseCPU, SparseCUDA: isneginf_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: isneginf_sparse_csr + tags: pointwise + +- func: isneginf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: isneginf_out + SparseCPU, SparseCUDA: isneginf_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: isneginf_sparse_csr_out + tags: pointwise + +# NOTE [_add_batch_dim and _remove_batch_dim] +# _add_batch_dim and _remove_batch_dim are meant to be used in the implementation +# of the vmap frontend API (see torch/_vmap_internals.py). They are not +# user-facing, hence the leading underscore. Please don't use them anywhere else. +- func: _add_batch_dim(Tensor self, int batch_dim, int level) -> Tensor + variants: function + +# See NOTE [_add_batch_dim and _remove_batch_dim] +- func: _remove_batch_dim(Tensor self, int level, int batch_size, int out_dim) -> Tensor + variants: function + +## Functions related to the `torch.special` namespace +# Note [special namespace binding] +# Functions in the special python module should have their names start with +# the "special_" prefix and be bound to the desired Python name in +# torch/special/__init__.py, and the desired C++ name in torch/csrc/api/include/torch/special.h. +# The "special_" names should be hidden from the user and not documented. + +- func: special_entr(Tensor self) -> Tensor + structured_delegate: special_entr.out + python_module: special + variants: function + tags: pointwise + +- func: special_entr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + python_module: special + variants: function + dispatch: + CPU, CUDA: special_entr_out + tags: pointwise + +- func: special_ndtri(Tensor self) -> Tensor + structured_delegate: special_ndtri.out + python_module: special + variants: function + tags: pointwise + +- func: special_ndtri.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
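For special_entr above, the elementwise definition (matching scipy.special.entr) is -x*ln(x) for x > 0, 0 at x == 0, and -inf for x < 0. A quick check on illustrative values:

    import torch

    x = torch.tensor([-1.0, 0.0, 0.5, 2.0])
    ref = torch.where(x > 0, -x * torch.log(x),
                      torch.where(x == 0, torch.zeros_like(x), torch.full_like(x, float('-inf'))))
    assert torch.allclose(ref, torch.special.entr(x))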
+ structured: True + structured_inherits: TensorIteratorBase + python_module: special + variants: function + dispatch: + CPU, CUDA: special_ndtri_out + tags: pointwise + +- func: special_log_ndtr(Tensor self) -> Tensor + structured_delegate: special_log_ndtr.out + python_module: special + variants: function + tags: pointwise + +- func: special_log_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + python_module: special + variants: function + dispatch: + CPU, CUDA: special_log_ndtr_out + tags: pointwise + +- func: special_expm1(Tensor self) -> Tensor + python_module: special + variants: function + +- func: special_expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + variants: function + +- func: special_exp2(Tensor self) -> Tensor + python_module: special + variants: function + +- func: special_exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + variants: function + +- func: special_psi(Tensor self) -> Tensor + python_module: special + variants: function + +- func: special_psi.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + variants: function + +- func: special_digamma(Tensor self) -> Tensor + python_module: special + variants: function + +- func: special_digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + variants: function + +- func: special_gammaln(Tensor self) -> Tensor + python_module: special + variants: function + +- func: special_gammaln.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + variants: function + +- func: special_erf(Tensor self) -> Tensor + python_module: special + variants: function + +- func: special_erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + variants: function + +- func: special_erfc(Tensor self) -> Tensor + python_module: special + variants: function + +- func: special_erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + +- func: special_erfcx(Tensor self) -> Tensor + python_module: special + variants: function + structured_delegate: special_erfcx.out + tags: pointwise + +- func: special_erfcx.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: special_erfcx_out + tags: pointwise + +- func: special_erfinv(Tensor self) -> Tensor + python_module: special + variants: function + +- func: special_erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + +- func: special_ndtr(Tensor self) -> Tensor + python_module: special + variants: function + +- func: special_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ python_module: special + variants: function + +- func: special_xlog1py(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + python_module: special + variants: function + structured_delegate: special_xlog1py.out + tags: pointwise + +- func: special_xlog1py.self_scalar(Scalar self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + python_module: special + variants: function + dispatch: + CompositeExplicitAutograd: special_xlog1py + tags: pointwise + +- func: special_xlog1py.other_scalar(Tensor self, Scalar other) -> Tensor + device_check: NoCheck # TensorIterator + python_module: special + variants: function + dispatch: + CompositeExplicitAutograd: special_xlog1py + tags: pointwise + +- func: special_xlog1py.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + python_module: special + variants: function + dispatch: + CPU, CUDA: special_xlog1py_out + tags: pointwise + +- func: special_xlog1py.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + python_module: special + variants: function + dispatch: + CompositeExplicitAutograd: special_xlog1py_out + tags: pointwise + +- func: special_xlog1py.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + python_module: special + variants: function + dispatch: + CompositeExplicitAutograd: special_xlog1py_out + tags: pointwise + +- func: special_xlogy(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + python_module: special + variants: function + +- func: special_xlogy.self_scalar(Scalar self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + python_module: special + variants: function + +- func: special_xlogy.other_scalar(Tensor self, Scalar other) -> Tensor + device_check: NoCheck # TensorIterator + python_module: special + variants: function + +- func: special_xlogy.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + python_module: special + variants: function + +- func: special_xlogy.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + python_module: special + variants: function + +- func: special_xlogy.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + python_module: special + variants: function + +- func: special_zeta(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + python_module: special + variants: function + structured_delegate: special_zeta.out + tags: pointwise + +- func: special_zeta.self_scalar(Scalar self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + python_module: special + variants: function + dispatch: + CompositeExplicitAutograd: special_zeta + tags: pointwise + +- func: special_zeta.other_scalar(Tensor self, Scalar other) -> Tensor + device_check: NoCheck # TensorIterator + python_module: special + variants: function + dispatch: + CompositeExplicitAutograd: special_zeta + tags: pointwise + +- func: special_zeta.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + python_module: special + variants: function + dispatch: + CPU, CUDA: special_zeta_out + tags: pointwise + +- func: special_zeta.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + python_module: special + variants: function + dispatch: + CompositeExplicitAutograd: special_zeta_out + tags: pointwise + +- func: special_zeta.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + python_module: special + variants: function + dispatch: + CompositeExplicitAutograd: special_zeta_out + tags: pointwise + +- func: special_i0(Tensor self) -> Tensor + python_module: special + variants: function + +- func: special_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + variants: function + +- func: special_i0e(Tensor self) -> Tensor + python_module: special + variants: function + structured_delegate: special_i0e.out + tags: pointwise + +- func: special_i0e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: special_i0e_out + tags: pointwise + +- func: special_i1(Tensor self) -> Tensor + python_module: special + variants: function + structured_delegate: special_i1.out + tags: pointwise + +- func: special_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: special_i1_out + tags: pointwise + +- func: special_i1e(Tensor self) -> Tensor + python_module: special + variants: function + structured_delegate: special_i1e.out + tags: pointwise + +- func: special_i1e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: special_i1e_out + tags: pointwise + +- func: special_logit(Tensor self, float? eps=None) -> Tensor + python_module: special + variants: function + +- func: special_logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + +- func: special_polygamma(int n, Tensor self) -> Tensor + python_module: special + variants: function + +- func: special_polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + +- func: special_logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor + python_module: special + variants: function + +- func: special_logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + +- func: special_expit(Tensor self) -> Tensor + python_module: special + variants: function + +- func: special_expit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + variants: function + +- func: special_sinc(Tensor self) -> Tensor + python_module: special + variants: function + +- func: special_sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + variants: function + +- func: special_round(Tensor self, *, int decimals=0) -> Tensor + python_module: special + variants: function + +- func: special_round.out(Tensor self, *, int decimals=0, Tensor(a!) out) -> Tensor(a!) 
+ python_module: special + variants: function + +- func: special_log1p(Tensor self) -> Tensor + python_module: special + variants: function + +- func: special_log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + variants: function + +- func: special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor + python_module: special + variants: function + +- func: special_gammainc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + variants: function + +- func: special_gammainc(Tensor self, Tensor other) -> Tensor + python_module: special + variants: function + +- func: special_gammaincc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + variants: function + +- func: special_gammaincc(Tensor self, Tensor other) -> Tensor + python_module: special + variants: function + +- func: special_multigammaln(Tensor self, int p) -> Tensor + python_module: special + variants: function + +- func: special_multigammaln.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + variants: function + +- func: special_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor + python_module: special + variants: function + +## Functions related to the fast Fourier transform and the torch.fft namespace +# Note [FFT namespace binding] +# Functions in the fft python module should have their names start with +# "fft_" underscore and be bound to the desired Python name in +# torch/fft/__init__.py, and the desired C++ name in torch/csrc/api/include/torch/fft.h. +# The "fft_" names should be hidden from the user and not documented. +# +# See fft_fft as an example. + +# torch.fft.fft +# NOTE: NOT an alias for torch.fft, which has different semantics +- func: fft_fft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_fft_symint + +- func: fft_fft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_fft_symint_out + +- func: fft_ifft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_ifft_symint + +- func: fft_ifft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_ifft_symint_out + +- func: fft_rfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_rfft_symint + +- func: fft_rfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_rfft_symint_out + +- func: fft_irfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_irfft_symint + +- func: fft_irfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_irfft_symint_out + +- func: fft_hfft(Tensor self, SymInt? n=None, int dim=-1, str? 
norm=None) -> Tensor + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_hfft_symint + +- func: fft_hfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_hfft_symint_out + +- func: fft_ihfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_ihfft_symint + +- func: fft_ihfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_ihfft_symint_out + +- func: fft_fft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_fft2_symint + +- func: fft_fft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_fft2_symint_out + +- func: fft_ifft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_ifft2_symint + +- func: fft_ifft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_ifft2_symint_out + +- func: fft_rfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_rfft2_symint + +- func: fft_rfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_rfft2_symint_out + +- func: fft_irfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_irfft2_symint + +- func: fft_irfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_irfft2_symint_out + +- func: fft_hfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + use_const_ref_for_mutable_tensors: True + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_hfft2_symint + +- func: fft_hfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + use_const_ref_for_mutable_tensors: True + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_hfft2_symint_out + +- func: fft_ihfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + use_const_ref_for_mutable_tensors: True + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_ihfft2_symint + +- func: fft_ihfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + use_const_ref_for_mutable_tensors: True + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_ihfft2_symint_out + +- func: fft_fftn(Tensor self, SymInt[1]? s=None, int[1]? 
dim=None, str? norm=None) -> Tensor + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_fftn_symint + +- func: fft_fftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_fftn_symint_out + +- func: fft_ifftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_ifftn_symint + +- func: fft_ifftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_ifftn_symint_out + +- func: fft_rfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_rfftn_symint + +- func: fft_rfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_rfftn_symint_out + +- func: fft_irfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_irfftn_symint + +- func: fft_irfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_irfftn_symint_out + +- func: fft_hfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + use_const_ref_for_mutable_tensors: True + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_hfftn_symint + +- func: fft_hfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + use_const_ref_for_mutable_tensors: True + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_hfftn_symint_out + +- func: fft_ihfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + use_const_ref_for_mutable_tensors: True + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_ihfftn_symint + +- func: fft_ihfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + use_const_ref_for_mutable_tensors: True + python_module: fft + variants: function + dispatch: + CompositeImplicitAutograd: fft_ihfftn_symint_out + +- func: fft_fftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + python_module: fft + variants: function + dispatch: + CompositeExplicitAutograd: fft_fftfreq + +- func: fft_fftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!) + python_module: fft + variants: function + dispatch: + CompositeExplicitAutograd: fft_fftfreq_out + +- func: fft_rfftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + python_module: fft + variants: function + dispatch: + CompositeExplicitAutograd: fft_rfftfreq + +- func: fft_rfftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!) + python_module: fft + variants: function + dispatch: + CompositeExplicitAutograd: fft_rfftfreq_out + +- func: fft_fftshift(Tensor self, int[1]? 
dim=None) -> Tensor + python_module: fft + variants: function + +- func: fft_ifftshift(Tensor self, int[1]? dim=None) -> Tensor + python_module: fft + variants: function + +## Functions for linear algebra and the torch.linalg namespace +# Note [linalg namespace binding] +# Functions in the linalg python module should have their names start with +# "linalg_" and be bound to the desired Python name in +# torch/linalg/__init__.py, and the desired C++ name in torch/csrc/api/include/torch/linalg.h. +# The "linalg_" names should be hidden from the user and not documented. +# +# See linalg_det as an example. + +# "_ex" stands for experimental +- func: linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info) + python_module: linalg + structured_delegate: linalg_cholesky_ex.L + +- func: linalg_cholesky_ex.L(Tensor self, *, bool upper=False, bool check_errors=False, Tensor(a!) L, Tensor(b!) info) -> (Tensor(a!) L, Tensor(b!) info) + python_module: linalg + structured: True + dispatch: + CPU, CUDA: linalg_cholesky_ex_out + +- func: linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor + python_module: linalg + +- func: linalg_cholesky.out(Tensor self, *, bool upper=False, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + +- func: linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor + python_module: linalg + variants: function + structured_delegate: linalg_cross.out + dispatch: + ZeroTensor: linalg_cross_zerotensor + +- func: linalg_cross.out(Tensor self, Tensor other, *, int dim=-1, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + structured: True + dispatch: + CPU, CUDA, MPS: linalg_cross_out + +# linalg.lu_factor +- func: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots) + python_module: linalg + variants: function + dispatch: + CompositeImplicitAutograd: linalg_lu_factor + MPS: linalg_lu_factor_mps + +- func: linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots) + python_module: linalg + variants: function + dispatch: + CompositeImplicitAutograd: linalg_lu_factor_out + MPS: linalg_lu_factor_out_mps + +- func: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info) + python_module: linalg + structured_delegate: linalg_lu_factor_ex.out + variants: function + +- func: linalg_lu_factor_ex.out(Tensor A, *, bool pivot=True, bool check_errors=False, Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info) + python_module: linalg + variants: function + structured: True + dispatch: + CPU, CUDA: linalg_lu_factor_ex_out + +# linalg.lu +- func: linalg_lu(Tensor A, *, bool pivot=True) -> (Tensor P, Tensor L, Tensor U) + python_module: linalg + structured_delegate: linalg_lu.out + variants: function + +- func: linalg_lu.out(Tensor A, *, bool pivot=True, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) + python_module: linalg + variants: function + structured: True + dispatch: + CPU, CUDA: linalg_lu_out + +# linalg.lu_solve +- func: linalg_lu_solve(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False) -> Tensor + python_module: linalg + structured_delegate: linalg_lu_solve.out + variants: function + +- func: linalg_lu_solve.out(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False, Tensor(a!) out) -> Tensor(a!) 
+ python_module: linalg + variants: function + structured: True + dispatch: + CPU, CUDA: linalg_lu_solve_out + +# linalg.det +- func: _linalg_det(Tensor A) -> (Tensor result, Tensor LU, Tensor pivots) + structured_delegate: _linalg_det.result + +- func: _linalg_det.result(Tensor A, *, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) + structured: True + dispatch: + CPU, CUDA: _linalg_det_out + +- func: linalg_det(Tensor A) -> Tensor + python_module: linalg + variants: function + +- func: linalg_det.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + +# torch.det, alias for torch.linalg.det +- func: det(Tensor self) -> Tensor + variants: function, method + +- func: linalg_ldl_factor_ex(Tensor self, *, bool hermitian=False, bool check_errors=False) -> (Tensor LD, Tensor pivots, Tensor info) + structured_delegate: linalg_ldl_factor_ex.out + python_module: linalg + variants: function + +- func: linalg_ldl_factor_ex.out(Tensor self, *, bool hermitian=False, bool check_errors=False, Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) info) + structured: True + python_module: linalg + variants: function + dispatch: + CPU, CUDA: linalg_ldl_factor_ex_out + +- func: linalg_ldl_factor(Tensor self, *, bool hermitian=False) -> (Tensor LD, Tensor pivots) + python_module: linalg + variants: function + +- func: linalg_ldl_factor.out(Tensor self, *, bool hermitian=False, Tensor(a!) LD, Tensor(b!) pivots) -> (Tensor(a!) LD, Tensor(b!) pivots) + python_module: linalg + variants: function + +- func: linalg_ldl_solve(Tensor LD, Tensor pivots, Tensor B, *, bool hermitian=False) -> Tensor + structured_delegate: linalg_ldl_solve.out + python_module: linalg + variants: function + +- func: linalg_ldl_solve.out(Tensor LD, Tensor pivots, Tensor B, *, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + structured: True + python_module: linalg + variants: function + dispatch: + CPU, CUDA: linalg_ldl_solve_out + +- func: linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values) + python_module: linalg + variants: function + dispatch: + CompositeExplicitAutograd: linalg_lstsq + tags: dynamic_output_shape + +- func: linalg_lstsq.out(Tensor self, Tensor b, float? rcond=None, *, str? driver=None, Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) -> (Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) + python_module: linalg + variants: function + dispatch: + CPU, CUDA: linalg_lstsq_out + tags: dynamic_output_shape + +# torch.linalg.matmul, alias for torch.matmul +- func: linalg_matmul(Tensor self, Tensor other) -> Tensor + python_module: linalg + variants: function + +- func: linalg_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + +- func: linalg_vecdot(Tensor x, Tensor y, *, int dim=-1) -> Tensor + python_module: linalg + variants: function + +- func: linalg_vecdot.out(Tensor x, Tensor y, *, int dim=-1, Tensor(a!) out) -> Tensor(a!) 
+ python_module: linalg + +- func: linalg_matrix_exp(Tensor self) -> Tensor + python_module: linalg + variants: function + dispatch: + CPU, CUDA: linalg_matrix_exp + autogen: linalg_matrix_exp.out + +- func: _linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet, Tensor LU, Tensor pivots) + structured_delegate: _linalg_slogdet.sign + +- func: _linalg_slogdet.sign(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots) -> (Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots) + structured: True + dispatch: + CPU, CUDA: _linalg_slogdet_out + +- func: linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet) + python_module: linalg + +- func: linalg_slogdet.out(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet) + python_module: linalg + +- func: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet) + variants: function, method + +- func: slogdet.out(Tensor self, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet) + variants: function + +- func: logdet(Tensor self) -> Tensor + variants: function, method + +- func: linalg_eig(Tensor self) -> (Tensor eigenvalues, Tensor eigenvectors) + python_module: linalg + variants: function + dispatch: + CPU, CUDA: linalg_eig + +- func: linalg_eig.out(Tensor self, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) + python_module: linalg + dispatch: + CPU, CUDA: linalg_eig_out + +- func: _linalg_eigvals(Tensor self) -> Tensor + python_module: linalg + dispatch: + CPU, CUDA: _linalg_eigvals + +- func: linalg_eigvals(Tensor self) -> Tensor + python_module: linalg + +- func: linalg_eigvals.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + dispatch: + CPU, CUDA: linalg_eigvals_out + +# This function is exposes the `compute_v` flag, which is then used to implement `linalg.eigh` and +# `linalg.eigvalsh` as composite functions that call this one +- func: _linalg_eigh(Tensor A, str UPLO="L", bool compute_v=True) -> (Tensor eigenvalues, Tensor eigenvectors) + structured_delegate: _linalg_eigh.eigenvalues + +- func: _linalg_eigh.eigenvalues(Tensor A, str UPLO="L", bool compute_v=True, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) + structured: True + dispatch: + CPU, CUDA: _linalg_eigh_out + +- func: linalg_eigh(Tensor self, str UPLO="L") -> (Tensor eigenvalues, Tensor eigenvectors) + python_module: linalg + +- func: linalg_eigh.eigvals(Tensor self, str UPLO="L", *, Tensor(a!) eigvals, Tensor(b!) eigvecs) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) + python_module: linalg + +- func: linalg_eigvalsh(Tensor self, str UPLO="L") -> Tensor + python_module: linalg + +- func: linalg_eigvalsh.out(Tensor self, str UPLO="L", *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + +- func: linalg_householder_product(Tensor input, Tensor tau) -> Tensor + python_module: linalg + variants: function + dispatch: + CPU, CUDA: linalg_householder_product + +- func: linalg_householder_product.out(Tensor input, Tensor tau, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + dispatch: + CPU, CUDA: linalg_householder_product_out + +- func: linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info) + python_module: linalg + structured_delegate: linalg_inv_ex.inverse + +- func: linalg_inv_ex.inverse(Tensor A, *, bool check_errors=False, Tensor(a!) inverse, Tensor(b!) 
info) -> (Tensor(a!) inverse, Tensor(b!) info) + python_module: linalg + structured: True + dispatch: + CPU, CUDA: linalg_inv_ex_out + MPS: linalg_inv_ex_out_mps + +- func: linalg_inv(Tensor A) -> Tensor + python_module: linalg + +- func: linalg_inv.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + +- func: inverse(Tensor self) -> Tensor + variants: function, method + +- func: inverse.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + +- func: inner(Tensor self, Tensor other) -> Tensor + variants: function, method + +- func: inner.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + +- func: outer(Tensor self, Tensor vec2) -> Tensor + variants: function, method + +- func: outer.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!) + +# torch.ger, alias for torch.outer +- func: ger(Tensor self, Tensor vec2) -> Tensor + variants: function, method + +- func: ger.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!) + +- func: linalg_norm(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + python_module: linalg + variants: function + +- func: linalg_norm.ord_str(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + python_module: linalg + variants: function + +- func: linalg_norm.out(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + variants: function + +- func: linalg_norm.ord_str_out(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + variants: function + +- func: linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + python_module: linalg + variants: function + structured_delegate: linalg_vector_norm.out + +- func: linalg_vector_norm.out(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + structured: True + dispatch: + CPU, CUDA: linalg_vector_norm_out + MPS: linalg_vector_norm_out_mps + +- func: linalg_matrix_norm(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + python_module: linalg + +- func: linalg_matrix_norm.out(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + +- func: linalg_matrix_norm.str_ord(Tensor self, str ord='fro', int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + python_module: linalg + +- func: linalg_matrix_norm.str_ord_out(Tensor self, str ord='fro', int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + +# This function is exposes the `compute_uv` flag, which is then used to implement `linalg.svd` and +# `linalg.svdvals` as composite functions that call this one +- func: _linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh) + variants: function + structured_delegate: _linalg_svd.U + +- func: _linalg_svd.U(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) 
Vh) + structured: True + dispatch: + CPU, CUDA: _linalg_svd_out + +- func: linalg_svd(Tensor A, bool full_matrices=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh) + python_module: linalg + variants: function + +- func: linalg_svd.U(Tensor A, bool full_matrices=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) + python_module: linalg + variants: function + +- func: linalg_svdvals(Tensor A, *, str? driver=None) -> Tensor + python_module: linalg + variants: function + +- func: linalg_svdvals.out(Tensor A, *, str? driver=None, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + variants: function + +- func: linalg_cond(Tensor self, Scalar? p=None) -> Tensor + python_module: linalg + variants: function + +- func: linalg_cond.out(Tensor self, Scalar? p=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + variants: function + +- func: linalg_cond.p_str(Tensor self, str p) -> Tensor + python_module: linalg + variants: function + +- func: linalg_cond.p_str_out(Tensor self, str p, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + variants: function + +- func: linalg_pinv.atol_rtol_tensor(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor + python_module: linalg + variants: function + dispatch: + # calls svd, which calls mH() (view op) + # also calls narrow() + CompositeExplicitAutogradNonFunctional: linalg_pinv + +- func: linalg_pinv.atol_rtol_tensor_out(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + variants: function + dispatch: + CompositeExplicitAutograd: linalg_pinv_out + +- func: linalg_pinv.atol_rtol_float(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False) -> Tensor + cpp_no_default_args: ['atol', 'rtol'] + python_module: linalg + variants: function + +- func: linalg_pinv.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + cpp_no_default_args: ['atol', 'rtol'] + python_module: linalg + variants: function + +- func: linalg_pinv(Tensor self, float rcond, bool hermitian=False) -> Tensor + python_module: linalg + variants: function + +- func: linalg_pinv.rcond_tensor(Tensor self, Tensor rcond, bool hermitian=False) -> Tensor + python_module: linalg + variants: function + +- func: linalg_pinv.out(Tensor self, float rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + variants: function + +- func: linalg_pinv.out_rcond_tensor(Tensor self, Tensor rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + variants: function + +- func: _linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor LU, Tensor pivots, Tensor info) + structured_delegate: _linalg_solve_ex.result + +- func: _linalg_solve_ex.result(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info) + structured: True + dispatch: + CPU, CUDA: _linalg_solve_ex_out + +- func: linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor info) + python_module: linalg + +- func: linalg_solve_ex.out(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) info) -> (Tensor(a!) 
result, Tensor(b!) info) + python_module: linalg + +- func: linalg_solve(Tensor A, Tensor B, *, bool left=True) -> Tensor + python_module: linalg + +- func: _spsolve(Tensor A, Tensor B, *, bool left=True) -> Tensor + python_module: sparse + dispatch: + SparseCsrCUDA: _sparse_csr_linear_solve + +- func: linalg_solve.out(Tensor A, Tensor B, *, bool left=True, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + +- func: linalg_tensorinv(Tensor self, int ind=2) -> Tensor + python_module: linalg + variants: function + +- func: linalg_tensorinv.out(Tensor self, int ind=2, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + variants: function + +- func: linalg_tensorsolve(Tensor self, Tensor other, int[]? dims=None) -> Tensor + python_module: linalg + variants: function + +- func: linalg_tensorsolve.out(Tensor self, Tensor other, int[]? dims=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + variants: function + +- func: linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R) + python_module: linalg + variants: function + structured_delegate: linalg_qr.out + +- func: linalg_qr.out(Tensor A, str mode='reduced', *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R) + python_module: linalg + structured: True + dispatch: + CPU, CUDA: linalg_qr_out + +- func: linalg_matrix_power(Tensor self, int n) -> Tensor + python_module: linalg + +- func: linalg_matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + +- func: linalg_matrix_rank.atol_rtol_tensor(Tensor input, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor + python_module: linalg + variants: function + +- func: linalg_matrix_rank.atol_rtol_tensor_out(Tensor input, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + variants: function + +- func: linalg_matrix_rank.atol_rtol_float(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False) -> Tensor + cpp_no_default_args: ['atol', 'rtol'] + python_module: linalg + variants: function + +- func: linalg_matrix_rank.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + cpp_no_default_args: ['atol', 'rtol'] + python_module: linalg + variants: function + +- func: linalg_matrix_rank(Tensor self, float tol, bool hermitian=False) -> Tensor + python_module: linalg + variants: function + +- func: linalg_matrix_rank.out(Tensor self, float tol, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + variants: function + +- func: linalg_matrix_rank.tol_tensor(Tensor input, Tensor tol, bool hermitian=False) -> Tensor + python_module: linalg + variants: function + +- func: linalg_matrix_rank.out_tol_tensor(Tensor input, Tensor tol, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + variants: function + +- func: linalg_multi_dot(Tensor[] tensors) -> Tensor + python_module: linalg + +- func: linalg_multi_dot.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + +## Functions related to the `torch.nested` namespace +# Note [nested namespace binding] +# Functions in the nested python module should have their names start with +# "nested_" underscore and be bound to the desired Python name in +# torch/nested/__init__.py, and the desired C++ name in torch/csrc/api/include/torch/nested.h. +# The "nested_" names should be hidden from the user and not documented. 
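+#
+# Example (illustrative sketch, assuming a standard PyTorch build): the namespace-binding
+# notes above mean a schema entry such as `special_entr`, `fft_fft`, or `linalg_det` is
+# surfaced to users under the corresponding Python module rather than under its prefixed
+# internal name, e.g. in Python:
+#
+#   import torch
+#   x = torch.rand(3)
+#   torch.special.entr(x)             # bound from the `special_entr` schema
+#   torch.fft.fft(x)                  # bound from the `fft_fft` schema
+#   torch.linalg.det(torch.eye(2))    # bound from the `linalg_det` schema
+#
+# The "special_"/"fft_"/"linalg_"/"nested_"-prefixed names themselves are not part of the
+# documented user-facing API.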
+ +- func: nested_to_padded_tensor(Tensor self, float padding, int[]? output_size=None) -> Tensor + python_module: nested + variants: function + +## Functions that are only for testing +# It is undocumented and should not be used outside of tests. +- func: _test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor + +# Note: for testing COW materialization within `at::parallel_for` loop function +- func: _test_parallel_materialize(Tensor self, int num_parallel, bool skip_first=False) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: _test_parallel_materialize + +# Note: this function is only for testing. +- func: _test_optional_intlist(Tensor values, int[]? addends) -> Tensor + python_module: nn + dispatch: + CPU: _test_optional_intlist + autogen: _test_optional_intlist.out + +# Note: this function is only for testing. +- func: _test_optional_filled_intlist(Tensor values, int[2]? addends) -> Tensor + python_module: nn + dispatch: + CPU: _test_optional_intlist + autogen: _test_optional_filled_intlist.out + +# Note: this function is only for testing. +- func: _test_optional_floatlist(Tensor values, float[]? addends) -> Tensor + python_module: nn + dispatch: + CPU: _test_optional_floatlist + autogen: _test_optional_floatlist.out + +# Note: this function is only for testing. +- func: _test_string_default(Tensor dummy, str a="\"'\\", str b='"\'\\') -> Tensor + python_module: nn + +# Note: this function is only for testing. +- func: _test_ambiguous_defaults.a(Tensor dummy, int a=1, int b=1) -> Tensor + python_module: nn + +# Note: this function is only for testing. +- func: _test_ambiguous_defaults.b(Tensor dummy, int a=2, str b="2") -> Tensor + cpp_no_default_args: ['a', 'b'] + python_module: nn + +# Note: this function is only for testing. +- func: _test_warn_in_autograd(Tensor self) -> Tensor + python_module: nn + dispatch: + CompositeExplicitAutograd: _test_warn_in_autograd + autogen: _test_warn_in_autograd.out + +# Note: this function is only for testing. +- func: _test_autograd_multiple_dispatch.fullcoverage(Tensor self) -> Tensor + dispatch: + # the NestedTensor keys are necessary because NestedTensor has been removed + # from the CompositeExplicitAutograd keyset see Note [NestedTensor Not Included in Backend Keys] + CompositeExplicitAutograd, NestedTensorCPU, NestedTensorCUDA: _test_autograd_multiple_dispatch_fullcoverage + autogen: _test_autograd_multiple_dispatch.fullcoverage_out + +# Note: this function is only for testing. +- func: _test_autograd_multiple_dispatch.ntonly(Tensor self, bool b) -> Tensor + dispatch: + CompositeImplicitAutograd, NestedTensorCPU, NestedTensorCUDA: _test_autograd_multiple_dispatch_ntonly + +# Note: this function is only for testing. +- func: _test_autograd_multiple_dispatch_view(Tensor(a) self) -> Tensor(a) + dispatch: + CompositeExplicitAutograd: _test_autograd_multiple_dispatch_view + +# Note: this function is only for testing. +- func: _test_autograd_multiple_dispatch_view_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: _test_autograd_multiple_dispatch_view_copy + tags: view_copy + autogen: _test_autograd_multiple_dispatch_view_copy.out + +- func: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? 
initial=None) -> Tensor + variants: function + dispatch: + CPU, CUDA: segment_reduce_kernel + autogen: segment_reduce.out + +- func: _segment_reduce_backward(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? initial=None) -> Tensor + variants: function + dispatch: + CPU, CUDA: _segment_reduce_backward_kernel + autogen: _segment_reduce_backward.out + +- func: pad_sequence(Tensor[] sequences, bool batch_first=False, float padding_value=0.0, str padding_side="right") -> Tensor + python_module: nn + variants: function + +- func: flatten_dense_tensors(Tensor[] tensors) -> Tensor + variants: function + python_module: nn + +- func: unflatten_dense_tensors(Tensor flat, Tensor[] tensors) -> Tensor[] + variants: function + python_module: nn + +- func: _nested_tensor_from_tensor_list(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: _nested_tensor_from_tensor_list + autogen: _nested_tensor_from_tensor_list.out + +- func: _fw_primal_copy(Tensor self, int level) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: _fw_primal_copy + tags: view_copy + autogen: _fw_primal_copy.out + +- func: _make_dual_copy(Tensor primal, Tensor tangent, int level) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: _make_dual_copy + tags: view_copy + autogen: _make_dual_copy.out + +- func: view_as_real_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: view_as_real_copy + tags: view_copy + autogen: view_as_real_copy.out + +- func: view_as_complex_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: view_as_complex_copy + tags: view_copy + autogen: view_as_complex_copy.out + +- func: _conj_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: _conj_copy + tags: view_copy + autogen: _conj_copy.out + +- func: _neg_view_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: _neg_view_copy + tags: view_copy + autogen: _neg_view_copy.out + +- func: as_strided_copy(Tensor self, SymInt[] size, SymInt[] stride, SymInt? 
storage_offset=None) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: as_strided_copy_symint + tags: view_copy + autogen: as_strided_copy.out + +- func: _sparse_broadcast_to_copy(Tensor self, int[] size) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: _sparse_broadcast_to_copy + tags: view_copy + autogen: _sparse_broadcast_to_copy.out + +- func: diagonal_copy(Tensor self, int offset=0, int dim1=0, int dim2=1) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: diagonal_copy + tags: view_copy + autogen: diagonal_copy.out + +- func: expand_copy(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: expand_copy_symint + tags: view_copy + autogen: expand_copy.out + +- func: permute_copy(Tensor self, int[] dims) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: permute_copy + tags: view_copy + autogen: permute_copy.out + +- func: _reshape_alias_copy(Tensor self, SymInt[] size, SymInt[] stride) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: _reshape_alias_copy_symint + tags: view_copy + autogen: _reshape_alias_copy.out + +- func: select_copy.int(Tensor self, int dim, SymInt index) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: select_copy_symint + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: select_copy_sparse_csr + tags: view_copy + autogen: select_copy.int_out + +- func: detach_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: detach_copy + tags: view_copy + autogen: detach_copy.out + +- func: slice_copy.Tensor(Tensor self, int dim=0, SymInt? start=None, SymInt? 
end=None, SymInt step=1) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: slice_copy_Tensor_symint + tags: view_copy + autogen: slice_copy.Tensor_out + +- func: split_copy.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[] + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: split_copy_Tensor_symint + tags: view_copy + +- func: split_with_sizes_copy(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[] + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: split_with_sizes_copy_symint + tags: view_copy + +- func: squeeze_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: squeeze_copy + tags: view_copy + autogen: squeeze_copy.out + +- func: squeeze_copy.dim(Tensor self, int dim) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: squeeze_copy_dim + tags: view_copy + autogen: squeeze_copy.dim_out + +- func: squeeze_copy.dims(Tensor self, int[] dim) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: squeeze_copy_dims + tags: view_copy + autogen: squeeze_copy.dims_out + +- func: t_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: t_copy + tags: view_copy + autogen: t_copy.out + +- func: transpose_copy.int(Tensor self, int dim0, int dim1) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: transpose_copy_int + tags: view_copy + autogen: transpose_copy.int_out + +- func: unsqueeze_copy(Tensor self, int dim) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: unsqueeze_copy + tags: view_copy + autogen: unsqueeze_copy.out + +- func: _indices_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: _indices_copy + tags: view_copy + autogen: _indices_copy.out + +- func: _values_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: _values_copy + tags: view_copy + autogen: _values_copy.out + +- func: indices_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: indices_copy + tags: view_copy + autogen: indices_copy.out + +- func: values_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: values_copy + tags: view_copy + autogen: values_copy.out + +- func: crow_indices_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: crow_indices_copy + tags: view_copy + autogen: crow_indices_copy.out + +- func: col_indices_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: col_indices_copy + tags: view_copy + autogen: col_indices_copy.out + +- func: ccol_indices_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: ccol_indices_copy + tags: view_copy + autogen: ccol_indices_copy.out + +- func: row_indices_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: row_indices_copy + tags: view_copy + autogen: row_indices_copy.out + +- func: unbind_copy.int(Tensor self, int dim=0) -> Tensor[] + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: unbind_copy_int + tags: view_copy + +- func: unbind_copy.int_out(Tensor self, int dim=0, *, Tensor(a!)[] out) -> () + variants: function + 
dispatch: + CompositeExplicitAutograd: unbind_copy_int_out + +- func: split_copy.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> () + variants: function + dispatch: + CompositeExplicitAutograd: split_copy_Tensor_out + + +- func: split_with_sizes_copy.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> () + variants: function + dispatch: + CompositeExplicitAutograd: split_with_sizes_copy_out + CUDA: split_with_sizes_copy_out_cuda + +- func: view_copy(Tensor self, SymInt[] size) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: view_copy_symint + tags: view_copy + autogen: view_copy.out + +- func: view_copy.dtype(Tensor self, ScalarType dtype) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: view_copy_dtype + tags: view_copy + autogen: view_copy.dtype_out + +- func: unfold_copy(Tensor self, int dimension, int size, int step) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: unfold_copy + tags: view_copy + autogen: unfold_copy.out + +- func: alias_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: alias_copy + tags: view_copy + autogen: alias_copy.out + +- func: to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor + variants: method + dispatch: + NestedTensorCPU: NestedTensor_to_padded_tensor_generic + NestedTensorCUDA: NestedTensor_to_padded_tensor_cuda + autogen: to_padded_tensor.out + +- func: _jagged_to_padded_dense_forward(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value=0.0) -> Tensor + variants: function + dispatch: + CUDA: _fbgemm_jagged_to_padded_dense_forward + CPU: _jagged_to_padded_dense_forward_cpu + +- func: _padded_dense_to_jagged_forward(Tensor dense, Tensor[] offsets, SymInt? total_L=None) -> Tensor + variants: function + dispatch: + CUDA: _fbgemm_dense_to_jagged_forward_symint + CPU: _padded_dense_to_jagged_forward_cpu + +- func: _nested_tensor_softmax_with_shape(Tensor self, Tensor query) -> Tensor + dispatch: + NestedTensorCPU: NestedTensor_softmax_dropout + NestedTensorCUDA: NestedTensor_softmax_dropout_cuda + tags: nondeterministic_seeded + +- func: _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor + dispatch: + CompositeExplicitAutograd: _safe_softmax + NestedTensorCPU, NestedTensorCUDA: _safe_softmax + +# Apparently, putting "forward" in the name will cause Python bindings to be skipped, so "fwd" it is. +- func: _transformer_encoder_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, int? mask_type=None) -> Tensor + variants: function + dispatch: + CPU, CUDA, NestedTensorCPU, NestedTensorCUDA: transformer_encoder_layer_forward + autogen: _transformer_encoder_layer_fwd.out + +- func: _native_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? 
mask_type=None) -> (Tensor, Tensor) + variants: function + dispatch: + CPU, NestedTensorCPU: native_multi_head_attention_cpu + CUDA, NestedTensorCUDA: native_multi_head_attention_cuda + autogen: _native_multi_head_attention.out + +- func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor + python_module: nn + variants: function + autogen: scaled_dot_product_attention.out + tags: nondeterministic_seeded + +# This aten function is kept so that we can test the choice function from Python +- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> int + dispatch: + Meta: _fused_sdp_choice_meta + CPU, NestedTensorCPU: _fused_sdp_choice_cpp + CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda + tags: nondeterministic_seeded + +- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None, bool enable_gqa=False) -> (Tensor, Tensor) + variants: function + tags: nondeterministic_seeded + +- func: _scaled_dot_product_attention_math_for_mps(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None) -> (Tensor, Tensor) + dispatch: + MPS: _scaled_dot_product_attention_math_mps + tags: nondeterministic_seeded + +- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + dispatch: + CUDA: _scaled_dot_product_flash_attention_cuda + NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda + tags: nondeterministic_seeded + +- func: _scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor output, Tensor logsumexp) + dispatch: + CPU: _scaled_dot_product_flash_attention_cpu + tags: nondeterministic_seeded + +- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + dispatch: + CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable + tags: nondeterministic_seeded + +- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? 
scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value) + device_check: NoCheck + variants: function + dispatch: + CUDA: _scaled_dot_product_flash_attention_backward_cuda + NestedTensorCUDA: _scaled_dot_product_flash_attention_backward_nested + +- func: _scaled_dot_product_flash_attention_for_cpu_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, float dropout_p, bool is_causal, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value) + device_check: NoCheck + variants: function + dispatch: + CPU: _scaled_dot_product_flash_attention_cpu_backward + +- func: _scaled_dot_product_fused_attention_overrideable_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor attn_bias, bool[4] grad_input_mask, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value, Tensor grad_attn_bias) + device_check: NoCheck + variants: function + dispatch: + CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable_backward + +- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset) + dispatch: + CUDA: _scaled_dot_product_efficient_attention_cuda + NestedTensorCUDA: _scaled_dot_product_efficient_attention_nestedtensor_cuda + tags: nondeterministic_seeded + +- func: _scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor) + device_check: NoCheck + dispatch: + CUDA: _scaled_dot_product_efficient_attention_backward_cuda + tags: nondeterministic_seeded + +- func: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + dispatch: + CUDA: _scaled_dot_product_cudnn_attention_cuda + tags: nondeterministic_seeded + +- func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor) + dispatch: + CUDA: _scaled_dot_product_cudnn_attention_backward_cuda + tags: nondeterministic_seeded + +- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? 
alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + variants: function + dispatch: + CUDA: _flash_attention_forward + tags: nondeterministic_seeded + +- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor) + device_check: NoCheck + variants: function + dispatch: + CUDA: _flash_attention_backward + +# Returns output, logsumexp if compute_logsumexp +- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? seqlen_k=None, int? window_size=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k) + variants: function + dispatch: + CUDA: _efficient_attention_forward + tags: nondeterministic_seeded + +- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None, int? window_size=None, bool shared_storage_dqdkdv=False) -> (Tensor, Tensor, Tensor, Tensor) + device_check: NoCheck + variants: function + dispatch: + CUDA: _efficient_attention_backward + +- func: _triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor + variants: function + dispatch: + CUDA: triton_scaled_dot_attention + tags: nondeterministic_seeded + autogen: _triton_scaled_dot_attention.out + +- func: _fill_mem_eff_dropout_mask_(Tensor(a!) self, float dropout_p, int seed, int offset) -> Tensor(a!) + variants: function + dispatch: + CUDA: _fill_mem_eff_dropout_mask_ + tags: nondeterministic_seeded + +- func: _triton_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None) -> Tensor + variants: function + dispatch: + CUDA: triton_multi_head_attention + autogen: _triton_multi_head_attention.out + +- func: special_airy_ai(Tensor x) -> Tensor + python_module: special + structured_delegate: special_airy_ai.out + variants: function + tags: pointwise + +- func: special_airy_ai.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: special_airy_ai_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_bessel_j0(Tensor self) -> Tensor + python_module: special + structured_delegate: special_bessel_j0.out + variants: function + tags: pointwise + +- func: special_bessel_j0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ dispatch: + CPU, CUDA: special_bessel_j0_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_bessel_j1(Tensor self) -> Tensor + python_module: special + structured_delegate: special_bessel_j1.out + variants: function + tags: pointwise + +- func: special_bessel_j1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: special_bessel_j1_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_bessel_y0(Tensor self) -> Tensor + python_module: special + structured_delegate: special_bessel_y0.out + variants: function + tags: pointwise + +- func: special_bessel_y0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: special_bessel_y0_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_bessel_y1(Tensor self) -> Tensor + python_module: special + structured_delegate: special_bessel_y1.out + variants: function + tags: pointwise + +- func: special_bessel_y1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: special_bessel_y1_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor + device_check: NoCheck + python_module: special + structured_delegate: special_chebyshev_polynomial_t.out + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_chebyshev_polynomial_t + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_chebyshev_polynomial_t + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck + dispatch: + CPU, CUDA: special_chebyshev_polynomial_t_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: special_chebyshev_polynomial_t_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) 
+ dispatch: + CompositeExplicitAutograd: special_chebyshev_polynomial_t_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor + device_check: NoCheck + python_module: special + structured_delegate: special_chebyshev_polynomial_u.out + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_chebyshev_polynomial_u + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_chebyshev_polynomial_u + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck + dispatch: + CPU, CUDA: special_chebyshev_polynomial_u_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: special_chebyshev_polynomial_u_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: special_chebyshev_polynomial_u_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor + device_check: NoCheck + python_module: special + structured_delegate: special_chebyshev_polynomial_v.out + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_chebyshev_polynomial_v + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_chebyshev_polynomial_v + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck + dispatch: + CPU, CUDA: special_chebyshev_polynomial_v_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: special_chebyshev_polynomial_v_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) 
+ dispatch: + CompositeExplicitAutograd: special_chebyshev_polynomial_v_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor + device_check: NoCheck + python_module: special + structured_delegate: special_chebyshev_polynomial_w.out + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_chebyshev_polynomial_w + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_chebyshev_polynomial_w + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck + dispatch: + CPU, CUDA: special_chebyshev_polynomial_w_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: special_chebyshev_polynomial_w_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: special_chebyshev_polynomial_w_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_hermite_polynomial_h(Tensor x, Tensor n) -> Tensor + device_check: NoCheck + python_module: special + structured_delegate: special_hermite_polynomial_h.out + variants: function + tags: pointwise + +- func: special_hermite_polynomial_h.x_scalar(Scalar x, Tensor n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_hermite_polynomial_h + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_hermite_polynomial_h.n_scalar(Tensor x, Scalar n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_hermite_polynomial_h + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_hermite_polynomial_h.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck + dispatch: + CPU, CUDA: special_hermite_polynomial_h_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_hermite_polynomial_h.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: special_hermite_polynomial_h_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_hermite_polynomial_h.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) 
+ dispatch: + CompositeExplicitAutograd: special_hermite_polynomial_h_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_hermite_polynomial_he(Tensor x, Tensor n) -> Tensor + device_check: NoCheck + python_module: special + structured_delegate: special_hermite_polynomial_he.out + variants: function + tags: pointwise + +- func: special_hermite_polynomial_he.x_scalar(Scalar x, Tensor n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_hermite_polynomial_he + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_hermite_polynomial_he.n_scalar(Tensor x, Scalar n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_hermite_polynomial_he + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_hermite_polynomial_he.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck + dispatch: + CPU, CUDA: special_hermite_polynomial_he_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_hermite_polynomial_he.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: special_hermite_polynomial_he_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_hermite_polynomial_he.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: special_hermite_polynomial_he_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_laguerre_polynomial_l(Tensor x, Tensor n) -> Tensor + device_check: NoCheck + python_module: special + structured_delegate: special_laguerre_polynomial_l.out + variants: function + tags: pointwise + +- func: special_laguerre_polynomial_l.x_scalar(Scalar x, Tensor n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_laguerre_polynomial_l + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_laguerre_polynomial_l.n_scalar(Tensor x, Scalar n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_laguerre_polynomial_l + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_laguerre_polynomial_l.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck + dispatch: + CPU, CUDA: special_laguerre_polynomial_l_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_laguerre_polynomial_l.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: special_laguerre_polynomial_l_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_laguerre_polynomial_l.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) 
+ dispatch: + CompositeExplicitAutograd: special_laguerre_polynomial_l_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_legendre_polynomial_p(Tensor x, Tensor n) -> Tensor + device_check: NoCheck + python_module: special + structured_delegate: special_legendre_polynomial_p.out + variants: function + tags: pointwise + +- func: special_legendre_polynomial_p.x_scalar(Scalar x, Tensor n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_legendre_polynomial_p + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_legendre_polynomial_p.n_scalar(Tensor x, Scalar n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_legendre_polynomial_p + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_legendre_polynomial_p.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck + dispatch: + CPU, CUDA: special_legendre_polynomial_p_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_legendre_polynomial_p.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: special_legendre_polynomial_p_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_legendre_polynomial_p.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: special_legendre_polynomial_p_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_modified_bessel_i0(Tensor self) -> Tensor + python_module: special + structured_delegate: special_modified_bessel_i0.out + variants: function + tags: pointwise + +- func: special_modified_bessel_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: special_modified_bessel_i0_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_modified_bessel_i1(Tensor self) -> Tensor + python_module: special + structured_delegate: special_modified_bessel_i1.out + variants: function + tags: pointwise + +- func: special_modified_bessel_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: special_modified_bessel_i1_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_modified_bessel_k0(Tensor self) -> Tensor + python_module: special + structured_delegate: special_modified_bessel_k0.out + variants: function + tags: pointwise + +- func: special_modified_bessel_k0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: special_modified_bessel_k0_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_modified_bessel_k1(Tensor self) -> Tensor + python_module: special + structured_delegate: special_modified_bessel_k1.out + variants: function + tags: pointwise + +- func: special_modified_bessel_k1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ dispatch: + CPU, CUDA: special_modified_bessel_k1_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_scaled_modified_bessel_k0(Tensor x) -> Tensor + python_module: special + structured_delegate: special_scaled_modified_bessel_k0.out + variants: function + tags: pointwise + +- func: special_scaled_modified_bessel_k0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: special_scaled_modified_bessel_k0_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_scaled_modified_bessel_k1(Tensor x) -> Tensor + python_module: special + structured_delegate: special_scaled_modified_bessel_k1.out + variants: function + tags: pointwise + +- func: special_scaled_modified_bessel_k1.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: special_scaled_modified_bessel_k1_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor + device_check: NoCheck + python_module: special + structured_delegate: special_shifted_chebyshev_polynomial_t.out + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_t + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_t + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck + dispatch: + CPU, CUDA: special_shifted_chebyshev_polynomial_t_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_t_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) 
+ dispatch: + CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_t_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor + device_check: NoCheck + python_module: special + structured_delegate: special_shifted_chebyshev_polynomial_u.out + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_u + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_u + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck + dispatch: + CPU, CUDA: special_shifted_chebyshev_polynomial_u_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_u_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_u_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor + device_check: NoCheck + python_module: special + structured_delegate: special_shifted_chebyshev_polynomial_v.out + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_v + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_v + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck + dispatch: + CPU, CUDA: special_shifted_chebyshev_polynomial_v_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_v_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) 
+ dispatch: + CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_v_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor + device_check: NoCheck + python_module: special + structured_delegate: special_shifted_chebyshev_polynomial_w.out + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_w + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor + dispatch: + CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_w + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck + dispatch: + CPU, CUDA: special_shifted_chebyshev_polynomial_w_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_w_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_shifted_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_w_out + device_check: NoCheck + python_module: special + variants: function + tags: pointwise + +- func: special_spherical_bessel_j0(Tensor x) -> Tensor + python_module: special + structured_delegate: special_spherical_bessel_j0.out + variants: function + tags: pointwise + +- func: special_spherical_bessel_j0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: special_spherical_bessel_j0_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +# Aux function used in the test TestPythonDispatch.test_kwarg_only_and_positional_default +# within test/test_python_dispatch.py +- func: _foobar(Tensor self, bool arg1=True, bool arg2=True, *, bool arg3=True) -> Tensor + dispatch: + CPU: foobar + autogen: _foobar.out + +- func: _fused_adam_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> () + # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now). + variants: function + dispatch: + CPU: _fused_adam_kernel_cpu_ + CUDA: _fused_adam_kernel_cuda_ + MPS: _fused_adam_kernel_mps_ + autogen: _fused_adam, _fused_adam.out + +- func: _fused_adam_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None) -> () + # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now), + # but still skip the device check as the Tensor LR can be on CPU + device_check: NoCheck + variants: function + dispatch: + CPU: _fused_adam_kernel_cpu_ + CUDA: _fused_adam_kernel_cuda_ + MPS: _fused_adam_kernel_mps_ + autogen: _fused_adam.tensor_lr, _fused_adam.tensor_lr_out + +- func: _fused_adamw_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> () + # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now). + variants: function + dispatch: + CPU: _fused_adamw_kernel_cpu_ + CUDA: _fused_adamw_kernel_cuda_ + MPS: _fused_adamw_kernel_mps_ + autogen: _fused_adamw, _fused_adamw.out + +- func: _fused_adamw_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> () + # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now), + # but still skip the device check as the Tensor LR can be on CPU + device_check: NoCheck + variants: function + dispatch: + CPU: _fused_adamw_kernel_cpu_ + CUDA: _fused_adamw_kernel_cuda_ + MPS: _fused_adamw_kernel_mps_ + autogen: _fused_adamw.tensor_lr, _fused_adamw.tensor_lr_out + +- func: _fused_sgd_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> () + # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now). + variants: function + dispatch: + CPU: _fused_sgd_kernel_cpu_ + CUDA: _fused_sgd_kernel_cuda_ + MPS: _fused_sgd_kernel_mps_ + autogen: _fused_sgd, _fused_sgd.out + +- func: _fused_sgd_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> () + # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now). + # but still skip the device check as the Tensor LR can be on CPU + device_check: NoCheck + variants: function + dispatch: + CPU: _fused_sgd_kernel_cpu_ + CUDA: _fused_sgd_kernel_cuda_ + MPS: _fused_sgd_kernel_mps_ + autogen: _fused_sgd.tensor_lr, _fused_sgd.tensor_lr_out + +- func: _fused_adagrad_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor(d!)[] state_steps, *, float lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> () + variants: function + dispatch: + CPU: _fused_adagrad_kernel_cpu_ + autogen: _fused_adagrad, _fused_adagrad.out + +# This op is ONLY used by pytorch/XLA in functionalization, and should never show up in vanilla eager mode or in any pytorch tracing contexts. 
+- func: _propagate_xla_data(Tensor input, Tensor output) -> () + variants: function diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/native/tags.yaml b/lib/python3.10/site-packages/torchgen/packaged/ATen/native/tags.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3544a3cf0b16c671406d0aacf4d3359783245d2f --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/native/tags.yaml @@ -0,0 +1,74 @@ +# This yaml file contains all the possible tags that can be defined in `tags` in `native_functions.yaml` + +- tag: inplace_view + desc: | + This tag indicates if an operator *only* modifies the tensor metadata +- tag: pt2_compliant_tag + desc: | + This tag indicates if the operator is guaranteed to + work with the PT2 compilation APIs (torch.compile, + torch.export, etc). If you add this tag to an + operator, please use + `torch.testing._internal.optest.opcheck` to test that + the operator has been registered correctly and + works with torch.compile +- tag: view_copy + desc: | + This tag indicates operators that are *_copy* variants + of view/aliasing operators. If an operator has a view_copy tag, + then it should have the name {op}_copy, where {op} is a view operator. +- tag: dynamic_output_shape + desc: | + This tag indicates if an operator's output's shape depends on input Tensor + data. +- tag: data_dependent_output + desc: | + Operator has a non-Tensor output whose value is dependent on the data + of Tensor inputs. Among other things, this implies that this operator + cannot be run with meta tensor (since data is not available), nor + can it be symbolically traced. +- tag: generated + desc: | + This tag indicates that the operator doesn't have an explicit entry in + native_functions.yaml, and instead was generated automatically by the codegen. +- tag: nondeterministic_seeded + desc: | + This tag indicates if an operator is nondeterministically seeded + (i.e., is random) such that the operator intentionally produces + different results when run twice on the same inputs, but this randomness + is controlled by a Generator which, if reseeded would give you the + same result. +- tag: nondeterministic_bitwise + desc: | + This tag indicates if an operator doesn't guarantee bitwise equivalence + across different runs of an operator with identical inputs. +- tag: needs_fixed_stride_order + desc: | + This tag indicates that the operator should be passed Tensors following + the same stride permutation as observed in eager when compiled in inductor. + Only one of {needs_fixed_stride_order, flexible_layout} can apply; if + multiple are assigned then we assume the most restrictive one. +- tag: flexible_layout + desc: | + This tag indicates that the custom operator can accept inputs with varying + strides/storage_offset and that when compiled, Inductor is allowed to change + the strides/storage_offset of inputs to the custom operator. + Only one of {needs_fixed_stride_order, flexible_layout} can apply; if + multiple are assigned then we assume the most restrictive one. + +# NOTE [Core ATen Ops] +- tag: core + desc: | + Core aten ops is a subset of aten ops that remains after aten-to-aten decomposition and + functionalization pass. Core aten ops are fully functional and adhere to single static + assignment (SSA): this implies there will be no `inplace` or `_out` variants in this opset. + This opset is designed to serve as the functional IR to interface with compiler backends. 
+ In contrast to primTorch, core aten opset doesn't decompose ops into explicit + type promotion and broadcasting ops. + Core aten ops is also effectively the opset produced by torchdynamo.export(aten_graph=True), + and thus can be used as an opset for export purpose. +- tag: pointwise + desc: | + Pointwise operators are operators where each element of the output is computed only by accessing + the corresponding element of all the broadcasted inputs. The output shape will be the broadcasted + shape of the inputs. diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/ATenOpList.cpp b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/ATenOpList.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5de3424857e236917eb68940e7904446de59f586 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/ATenOpList.cpp @@ -0,0 +1,36 @@ +#include + +#include +#include +#include +#include +#include + +// ${generated_comment} + +namespace at { + +namespace { +struct OpNameEquals final { + bool operator()(const std::pair& lhs, const std::pair& rhs) const { + return 0 == strcmp(lhs.first, rhs.first) && 0 == strcmp(lhs.second, rhs.second); + } +}; + +struct OpNameHash final { + size_t operator()(const std::pair& p) const { + // use std::hash because std::hash would hash pointers and not pointed-to strings + return std::hash()(p.first) ^ (~ std::hash()(p.second)); + } +}; +} + +bool is_custom_op(const c10::OperatorName& opName) { + static std::unordered_set, OpNameHash, OpNameEquals> ops { + ${aten_ops} + {"", ""} + }; + return ops.count(std::make_pair( + opName.name.c_str(), opName.overload_name.c_str())) == 0; +} +} diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/CompositeViewCopyKernels.cpp b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/CompositeViewCopyKernels.cpp new file mode 100644 index 0000000000000000000000000000000000000000..47097d7aa4320674bec4bddbb5ac861309334f0c --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/CompositeViewCopyKernels.cpp @@ -0,0 +1,73 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// ${generated_comment} + +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +$ops_headers +#endif + +namespace at { +namespace native { + +// This file contains a number of kernels for aten functions that are fully code-generated. +// TODO: rename this file to something more generic. 
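The `pointwise` and `nondeterministic_seeded` descriptions in tags.yaml above correspond to behaviour that is easy to observe from Python. The following is a minimal sketch, assuming a recent PyTorch build; the shapes and seed value are arbitrary illustrations and are not part of the generated files.

import torch

# pointwise: each output element is computed from the corresponding
# elements of the broadcasted inputs, so the result takes the broadcast shape.
a = torch.randn(3, 1)
b = torch.randn(1, 4)
assert torch.add(a, b).shape == (3, 4)

# nondeterministic_seeded: two runs differ by default, but reseeding the
# controlling Generator to the same state reproduces the same result.
g = torch.Generator().manual_seed(0)
x1 = torch.randn(2, 2, generator=g)
g.manual_seed(0)
x2 = torch.randn(2, 2, generator=g)
assert torch.equal(x1, x2)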
+ +namespace { +at::Tensor clone_arg(const at::Tensor& t) { + return t.clone(); +} + +std::vector clone_arg(const at::TensorList& t_list) { + std::vector out(t_list.size()); + for (const auto& i : c10::irange(t_list.size())) { + out[i] = t_list[i].clone(); + } + return out; +} + +// duped with gen_resize_out_helper from structured kernels +void copy_arg(const at::Tensor& dst, const at::Tensor& src) { + TORCH_CHECK(src.dtype() == dst.dtype(), + "Expected out tensor to have dtype ", src.dtype(), ", but got ", dst.dtype(), " instead"); + TORCH_CHECK(src.device() == dst.device(), + "Expected out tensor to have device ", src.device(), ", but got ", dst.device(), " instead"); + dst.copy_(src); +} + +void copy_arg(const at::TensorList& dst, const at::TensorList& src) { + TORCH_INTERNAL_ASSERT(dst.size() == src.size()); + for (const auto& i : c10::irange(dst.size())) { + copy_arg(dst[i], src[i]); + } +} + +// TODO: this doesn't handle restriding empty tensors correctly; see +// gen_resize_out_helper for the correct algorithm + +void resize_out_helper(const at::Tensor& dst, const at::Tensor& src) { + at::native::resize_output(dst, src.sizes()); +} + +void resize_out_helper(const at::TensorList& dst, const at::TensorList& src) { + TORCH_INTERNAL_ASSERT(dst.size() == src.size()); + for (const auto& i : c10::irange(dst.size())) { + at::native::resize_output(dst[i], src[i].sizes()); + } +} +} + + +${CompositeViewCopyKernel_Definitions} + +${GeneratedCompositeFunctional_Definitions} + +${GeneratedCompositeOut_Definitions} + +} // namespace native +} // namespace at diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunction.h b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunction.h new file mode 100644 index 0000000000000000000000000000000000000000..c92d5eb3898ecea0fb9e1f79c2725d1bc6dfa7fb --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunction.h @@ -0,0 +1,23 @@ +#pragma once +// ${generated_comment} + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace ${dispatch_namespace} { + +${dispatch_namespaced_declarations} + +} // namespace ${dispatch_namespace} +} // namespace at diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions.h b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..35f43297fdd9ca9f932c8c53b5b773f1b9b8a427 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions.h @@ -0,0 +1,29 @@ +#include + +// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch] +// Code introduced to avoid cyclic dependency in static dispatch is no longer +// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place, +// to Operators.cpp for supporting multiple backends with multiple kernels. 
+// +// Note [Avoiding Include Cycles In Static Dispatch] +// In order to avoid #include cycles in the static dispatch build, we've carefully split out +// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h. +// +// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h. +// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods +// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all +// directly inlined into TensorBody.h. +// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API, +// which include functions that have defaultable std::optional arguments. +// That requires knowing the full Tensor class definition. +// +// We break the cycle by doing the following: +// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h +// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl., +// - CPUFunctions_inl.h includes everything else +// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class, +// and then it includes CPUFunctions_inl.h. +// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly. +// - This also means that static dispatch build, CPUFunctions.h only needs to +// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h. +${inline_headers} diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions_inl.h b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..fbb71c2cb123cb21fb57ec32341d86bff06f6a17 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions_inl.h @@ -0,0 +1,22 @@ +#pragma once +// ${generated_comment} + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from \ + . \ + See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. 
+#endif + +${DispatchKeyFunctions_inl_includes} + + +${dispatch_namespaced_declarations} diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.cpp b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7647f459a744b2eacfac6aaea4f49b86babbb234 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.cpp @@ -0,0 +1,13 @@ +// ${generated_comment} +${includes} +${native_functions_include} + +namespace { +${helper_fns} +} // namespace + +${namespace_prologue} + +${native_function_definitions} + +${namespace_epilogue} diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.h b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..b45a17b5922f8a0b76e0237616914ce9969efca5 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.h @@ -0,0 +1,19 @@ +#pragma once + +// an external backend might generate file within its code tree +// and check all the source files within the tree with clang-format. +// so, disable it since the backend might have a different config. +// clang-format off + +// ${generated_comment} + +#include + +${namespace_prologue} + +struct ${class_name} { + +${dispatch_declarations} + +}; +${namespace_epilogue} diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/Function.h b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/Function.h new file mode 100644 index 0000000000000000000000000000000000000000..db430a3ffc4977ca3037a6698849e55669216e22 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/Function.h @@ -0,0 +1,26 @@ +#pragma once + +// ${generated_comment} + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +${static_dispatch_ops_headers} + +${operator_includes} + +namespace at { + +${function_definitions} + +} diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/FunctionalInverses.h b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/FunctionalInverses.h new file mode 100644 index 0000000000000000000000000000000000000000..3217e097d7adf37ab7041c45f5dd413024fc33ef --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/FunctionalInverses.h @@ -0,0 +1,33 @@ +#pragma once + +// ${generated_comment} + +#include + +namespace at { +namespace functionalization { + +enum class InverseReturnMode { + /// Specifies that functional inverses should always return a view. + AlwaysView, + /// Specifies that functional inverses should always return a non-view / copy. + NeverView, + /// Specifies that functional inverses should return a view unless a (copying) scatter + /// inverse exists, in which case that will be used instead. + /// This avoids as_strided() calls that can be difficult for subclasses to handle. + ViewOrScatterInverse, +}; + +struct FunctionalInverses { + +${view_inverse_declarations} + +// NB: These are not generated! They're manually implemented in the template. +// TODO: Change codegen to generate these. 
See the following link: +// https://github.com/pytorch/pytorch/blob/main/torchgen/model.py#L2583-L2585 +static at::Tensor chunk_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, int chunks, int dim); +static at::Tensor narrow_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int dim, c10::SymInt start, c10::SymInt length); + +}; +} +} diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/Functions.cpp b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/Functions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1a2ffaa238163831a1e177670c20c7fed06d600a --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/Functions.cpp @@ -0,0 +1,103 @@ +#include + +#include +#include +#include + +namespace at { + +Tensor TensorMaker::make_tensor() { + AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove. + tracer::impl::NoTracerDispatchMode tracer_guard{}; + + check_size_nonnegative(sizes_); + + TORCH_CHECK_VALUE( + !deleter_ || !ctx_, + "The deleter and context arguments are mutually exclusive."); + + if (device_ == std::nullopt) { + device_ = globalContext().getDeviceFromPtr(data_, opts_.device().type()); + } + + if (opts_.device().has_index()) { + // clang-format off + TORCH_CHECK_VALUE( + opts_.device() == *device_, + "Specified device ", opts_.device(), " does not match device of data ", *device_); + // clang-format on + } + + std::size_t size_bytes = computeStorageSize(); + + DataPtr data_ptr{}; + if (deleter_) { + data_ptr = makeDataPtrFromDeleter(); + } else { + data_ptr = makeDataPtrFromContext(); + } + + TORCH_CHECK(!resizeable_ || allocator_ != nullptr, "Must specify an allocator with allocator() if you want to use resizeable_storage()"); + Storage storage{Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr), /*allocator=*/allocator_, /*resizable=*/resizeable_}; + + Tensor tensor = detail::make_tensor( + std::move(storage), opts_.computeDispatchKey(), opts_.dtype()); + + TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl(); + if (strides_) { + tensor_impl->set_sizes_and_strides(sizes_, *strides_); + } else { + tensor_impl->set_sizes_contiguous(sizes_); + } + if (storage_offset_) { + tensor_impl->set_storage_offset(*storage_offset_); + } + + return tensor; + } + + std::size_t TensorMaker::computeStorageSize() const noexcept { + std::size_t itemsize = opts_.dtype().itemsize(); + + if (strides_) { + auto storage_size = detail::computeStorageNbytes(sizes_, *strides_, itemsize); + if (storage_offset_) { + storage_size += storage_offset_.value(); + } + return storage_size; + } + + std::size_t size = 1; + for (std::int64_t s : sizes_) { + size *= static_cast(s); + } + auto storage_size = size * itemsize; + if (storage_offset_) { + storage_size += storage_offset_.value(); + } + return storage_size; + } + + inline DataPtr TensorMaker::makeDataPtrFromDeleter() noexcept { + return InefficientStdFunctionContext::makeDataPtr(data_, std::move(deleter_), *device_); + } + + inline DataPtr TensorMaker::makeDataPtrFromContext() noexcept { + return DataPtr{data_, ctx_.release(), ctx_.get_deleter(), *device_}; + } + + IntArrayRef TensorMaker::makeTempSizes() const noexcept { + static std::int64_t zeros[5] = {0, 0, 0, 0, 0}; + if (opts_.has_memory_format()) { + MemoryFormat format = *opts_.memory_format_opt(); + if (format == MemoryFormat::ChannelsLast) { + return IntArrayRef(zeros, 4); 
+ } + if (format == MemoryFormat::ChannelsLast3d) { + return IntArrayRef(zeros, 5); + } + } + return IntArrayRef(zeros, 1); + } + +} // namespace at diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/Functions.h b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/Functions.h new file mode 100644 index 0000000000000000000000000000000000000000..1f010ccec48b18116db2327f13cebd57c9effe23 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/Functions.h @@ -0,0 +1,143 @@ +#pragma once + +// ${generated_comment} + +#ifdef TORCH_ASSERT_NO_OPERATORS +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be re-compiled every time an operator \ + is changed or added. Consider if your change would be better placed in \ + another file, or if a more specific header might achieve the same goal. \ + See NOTE: [Tensor vs. TensorBase] +#endif + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from and \ + see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +// NOTE: [TORCH_ASSERT_ONLY_METHOD_OPERATORS] +// +// In ATen, certain generated headers files include the definitions of +// every single operator in PyTorch. Unfortunately this means every +// time an operator signature is updated or changed in +// native_functions.yaml, you (and every other PyTorch developer) need +// to recompile every source file that includes any of these headers. +// +// To break up these header dependencies, and improve incremental +// build times for all PyTorch developers. These headers are split +// into per-operator headers in the `ATen/ops` folder. This limits +// incremental builds to only changes to methods of `Tensor`, or files +// that use the specific operator being changed. With `at::sum` as an +// example, you should include +// +// // instead of ATen/Functions.h +// // instead of ATen/NativeFunctions.h +// // instead of ATen/Operators.h +// // instead of ATen/CPUFunctions.h +// +// However, even if you're careful to use this in your own code. +// `Functions.h` might be included indirectly through another header +// without you realising. To avoid this, you can add +// +// #define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// +// to the top of your source file. This way any time the non-specific +// headers are included, the compiler will error out. +// +// Also, be aware that `ops` are not available in all build +// configurations (namely fb-internal) so you must guard these +// includes with `#ifdef AT_PER_OPERATOR_HEADERS`. e.g. 
+// +// #ifndef AT_PER_OPERATOR_HEADERS +// #include +// #else +// #include +// #endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +${Functions_includes} + +namespace at { + +${Functions_declarations} + +// Special C++ only overloads for std()-like functions (See gh-40287) +// These are needed because int -> bool conversion takes precedence over int -> IntArrayRef +// So, for example std(0) would select the std(unbiased=False) overload +TORCH_API inline Tensor var(const Tensor& self, int dim) { + return at::var(self, IntArrayRef{dim}); +} +TORCH_API inline std::tuple var_mean(const Tensor& self, int dim) { + return at::var_mean(self, IntArrayRef{dim}); +} +TORCH_API inline Tensor std(const Tensor& self, int dim) { + return at::std(self, IntArrayRef{dim}); +} +TORCH_API inline std::tuple std_mean(const Tensor& self, int dim) { + return at::std_mean(self, IntArrayRef{dim}); +} + +inline int64_t numel(const Tensor& tensor) { + return tensor.numel(); +} + +inline int64_t size(const Tensor& tensor, int64_t dim) { + return tensor.size(dim); +} + +inline int64_t stride(const Tensor& tensor, int64_t dim) { + return tensor.stride(dim); +} + +inline bool is_complex(const Tensor& tensor) { + return tensor.is_complex(); +} + +inline bool is_floating_point(const Tensor& tensor) { + return tensor.is_floating_point(); +} + +inline bool is_signed(const Tensor& tensor) { + return tensor.is_signed(); +} + +inline bool is_inference(const Tensor& tensor) { + return tensor.is_inference(); +} + +inline bool _is_zerotensor(const Tensor& tensor) { + return tensor._is_zerotensor(); +} + +inline bool is_conj(const Tensor& tensor) { + return tensor.is_conj(); +} + +inline Tensor conj(const Tensor& tensor) { + return tensor.conj(); +} + +inline bool is_neg(const Tensor& tensor) { + return tensor.is_neg(); +} + +} diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/LazyIr.h b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/LazyIr.h new file mode 100644 index 0000000000000000000000000000000000000000..9190ff8243d316fd2bd472bb3f0603701761bdb7 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/LazyIr.h @@ -0,0 +1,19 @@ +#pragma once + +// This file contains autogenerated LazyTensor IR nodes +${lazy_ir_sysinc} +${lazy_ir_inc} + +${namespace_prologue} +using at::operator<<; + +// kNullValue is used to contribute a static hash value any time +// a node has an Optional input that is nullopt. It is important +// to differentiate between HASH(std::nullopt, something) and HASH(something, std::nullopt), +// and using kNullValue in the hash function in the order of arguments +// serves this purpose. 
+static const torch::lazy::Value kNullValue = torch::lazy::Value(); + +${ir_declarations} + +${namespace_epilogue} diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/LazyNonNativeIr.h b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/LazyNonNativeIr.h new file mode 100644 index 0000000000000000000000000000000000000000..18eaf6da52e4b3654becac6cc89849bc0806ae09 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/LazyNonNativeIr.h @@ -0,0 +1,11 @@ +#pragma once + +${lazy_non_native_ir_inc} + +// This file contains autogenerated LazyTensor Non Native IR nodes + +${namespace_prologue} + +${non_native_ir_nodes} + +${namespace_epilogue} diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/MethodOperators.h b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/MethodOperators.h new file mode 100644 index 0000000000000000000000000000000000000000..0e192cd05ef3c78fa74848c93de32150c1e3fd8b --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/MethodOperators.h @@ -0,0 +1,24 @@ +#pragma once + +// ${generated_comment} + +#ifdef TORCH_ASSERT_NO_OPERATORS +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be re-compiled every time an operator \ + is changed or added. Consider if your change would be better placed in \ + another file, or if a more specific header might achieve the same goal. \ + See NOTE: [Tensor vs. TensorBase] +#endif + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +${MethodOperators_includes} + +namespace at { +namespace _ops { +${MethodOperators_declarations} +} // namespace _ops +} // namespace at diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/NativeFunction.h b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/NativeFunction.h new file mode 100644 index 0000000000000000000000000000000000000000..a5441ad85d1d5e28c4e31dd3f0dc7f66dfbff9e7 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/NativeFunction.h @@ -0,0 +1,17 @@ +#pragma once + +// ${generated_comment} + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +${extra_includes} + +${native_function_declarations} diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/NativeFunctions.h b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/NativeFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..9dc972495ca038bddb7b887c39c2e0507e487213 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/NativeFunctions.h @@ -0,0 +1,33 @@ +#pragma once + +// ${generated_comment} + +#ifdef TORCH_ASSERT_NO_OPERATORS +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be re-compiled every time an operator \ + is changed or added. Consider if your change would be better placed in \ + another file, or if a more specific header might achieve the same goal. \ + See NOTE: [Tensor vs. 
TensorBase] +#endif + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from \ + and see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +${NativeFunctions_includes} + +${NativeFunctions_declarations} diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/NativeMetaFunction.h b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/NativeMetaFunction.h new file mode 100644 index 0000000000000000000000000000000000000000..6522c97546d0498e4b3825fb4eafefbb34c71911 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/NativeMetaFunction.h @@ -0,0 +1,23 @@ +#pragma once + +// ${generated_comment} + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +${meta_function_declarations} + +} // namespace native +} // namespace at diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/NativeMetaFunctions.h b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/NativeMetaFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..89989e2121c9aa34a4583205c3541a04edd36700 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/NativeMetaFunctions.h @@ -0,0 +1,19 @@ +#pragma once + +// ${generated_comment} + +#include +#include +#include +#include + +${NativeMetaFunctions_includes} + +namespace at { + +namespace meta { + +${NativeMetaFunctions_declarations} + +} // namespace meta +} // namespace at diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/Operator.h b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/Operator.h new file mode 100644 index 0000000000000000000000000000000000000000..8b3989b66debc86e3782169c29a6f83fea222ac6 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/Operator.h @@ -0,0 +1,18 @@ +#pragma once + +// ${generated_comment} + +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
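The forward-declaration comment above describes a standard way to break a circular include between a declarations header and the header that defines the class it mentions. A minimal sketch of the pattern under toy names (example::Tensor and example::add are illustrative, not ATen types):

    // Illustrative only: a "declarations" header can refer to a class in
    // function signatures via a forward declaration, without including the
    // header that defines it (which in turn includes the declarations).
    namespace example {

    class Tensor;                                  // forward declaration (incomplete type)
    Tensor add(const Tensor& a, const Tensor& b);  // OK: a declaration only needs the name

    // The full definition lives in the header that is allowed to include the
    // declarations above; it is inlined here to keep the sketch self-contained.
    class Tensor {
     public:
      int value = 0;
    };

    inline Tensor add(const Tensor& a, const Tensor& b) {
      Tensor out;
      out.value = a.value + b.value;
      return out;
    }

    }  // namespace example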
+#include + +namespace at { +namespace _ops { + +${declarations} + +}} // namespace at::_ops diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/Operators.cpp b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/Operators.cpp new file mode 100644 index 0000000000000000000000000000000000000000..082bb67c3e2043f2c36b29345f57048ec2e9eea7 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/Operators.cpp @@ -0,0 +1,19 @@ +#include +#include + +// ${generated_comment} +// NOTE See [Sharded File] comment in VariableType + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +${operator_headers} +#endif + +${static_dispatch_extra_headers} + +namespace at { namespace _ops { + +${definitions} + +}} // namespace at::_ops diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/Operators.h b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/Operators.h new file mode 100644 index 0000000000000000000000000000000000000000..e74b96ef3d5c6b6d50fe63eac4dca51f0655daa5 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/Operators.h @@ -0,0 +1,74 @@ +#pragma once + +// ${generated_comment} + +#ifdef TORCH_ASSERT_NO_OPERATORS +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be re-compiled every time an operator \ + is changed or added. Consider if your change would be better placed in \ + another file, or if a more specific header might achieve the same goal. \ + See NOTE: [Tensor vs. TensorBase] +#endif + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from \ + and see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +${Operators_includes} + +// Extension writers: do you write wrapper functions? Are you frustrated with +// resolving overloads of operators? Are you frustrated with dealing with +// pointer-to-methods and resolving overloads of pointer-to-methods?? Look no +// further, this is the utility for you. +// +// Given an operator schema: aten::op.overload(... +// +// Use ATEN_FN2(op, overload) to get a *function* version of the operator +// that is guaranteed to not be overloaded. This means that you can safely +// decltype(&ATEN_FN2(op, overload)) it. NB: the 2 means this macro takes 2 args. +// +// Given an operator schema without an overload name: aten::op(... +// +// Use ATEN_FN(op) to get an unambiguous *function* version of the operator. +// +// There is some interesting behavior for out= operations. +// ATEN_FN2(sin, out) gives a function that is *faithful* to the schema; +// that is, the order of arguments is exactly what it looks like in the schema. + +#define ATEN_FN2(op_name, overload) at::_ops::op_name##_##overload::call +#define ATEN_FN(op_name) at::_ops::op_name::call + +// Separately, ATEN_OP(op) and ATEN_OP2(op, overload) define a class containing compile-time +// metadata about a given aten operator. 
+// Notable data on the class includes: +// - ATEN_OP2(add, Tensor)::name // returns the string name: "add" +// - ATEN_OP2(add, Tensor)::overload_name // returns the string overload name: "Tensor" +// - ATEN_OP2(add, Tensor)::schema // returns the C++ schema type: at::Tensor (const at::Tensor &, const at::Tensor &, const at::Scalar &) +// - ATEN_OP2(add, Tensor)::schema_str // returns the string jit type: "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor" + +#define ATEN_OP2(op_name, overload) at::_ops::op_name##_##overload +#define ATEN_OP(op_name) at::_ops::op_name + +// WARNING: Please do not call any of the ops in the _ops namespace directly. +// Use the ATEN_FN macros. We do not guarantee stability of the naming +// scheme for the functions in at::_ops + +// See Note [The ATen Operators API] for details of the at::_ops namespace + +namespace at { +namespace _ops { +${Operators_declarations} +} // namespace _ops +} // namespace at diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RedispatchFunctions.cpp b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RedispatchFunctions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..58102bd97fca4eaef477818b0b0a92b7995e38b1 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RedispatchFunctions.cpp @@ -0,0 +1,15 @@ +// ${generated_comment} + +#include +#include + +#include +#include + +namespace at { + +namespace redispatch { + ${function_redispatch_definitions} +} // namespace redispatch + +} // namespace at diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RedispatchFunctions.h b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RedispatchFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..2422cdd409cfdd59c2a05df27d28bb25ee610463 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RedispatchFunctions.h @@ -0,0 +1,32 @@ +#pragma once + +// ${generated_comment} + +#ifdef TORCH_ASSERT_ONLY_METHOD_OPERATORS +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider using the at::_ops::{name}::redispatch() interface by including \ + the specific operator from +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { + +namespace redispatch { + ${function_redispatch_definitions} +} // namespace redispatch + +} diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RegisterBackendSelect.cpp b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RegisterBackendSelect.cpp new file mode 100644 index 0000000000000000000000000000000000000000..018cf358f11237d5bdc9bca01aa8d09d1462f574 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RegisterBackendSelect.cpp @@ -0,0 +1,29 @@ +// We register ops with a higher priority dispatch key (BackendSelect) than the usual backend-specific keys (e.g. CPU) +// which makes calls to the factory functions dispatch to here. +// We then 'manually' compute a lower-priority to re-dispatch to (e.g. CPU) to get to the eventually correct backend. 
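As a rough illustration of the redispatch pattern described above: a factory op has no Tensor argument to carry a backend key, so the BackendSelect kernel computes the backend key from the dtype/layout/device arguments and redispatches with only that key. In the sketch below, c10::computeDispatchKey and c10::DispatchKeySet are real c10 helpers, but the wrapper and the redispatch_empty target are invented stand-ins for the generated at::_ops redispatch entry points, not the actual generated code:

    #include <ATen/ATen.h>
    #include <optional>

    // Hypothetical stand-in for a generated at::_ops::*::redispatch entry point;
    // the real target and its exact signature come from codegen.
    at::Tensor redispatch_empty(
        c10::DispatchKeySet ks, at::IntArrayRef size,
        std::optional<at::ScalarType> dtype, std::optional<at::Layout> layout,
        std::optional<at::Device> device, std::optional<bool> pin_memory);

    // Schematic BackendSelect wrapper for a factory op.
    at::Tensor empty_backend_select(
        at::IntArrayRef size,
        std::optional<at::ScalarType> dtype, std::optional<at::Layout> layout,
        std::optional<at::Device> device, std::optional<bool> pin_memory) {
      // Pick the backend-specific key (e.g. CPU or CUDA) from the options.
      c10::DispatchKey key = c10::computeDispatchKey(dtype, layout, device);
      // Redispatch below BackendSelect with only that key enabled.
      return redispatch_empty(c10::DispatchKeySet(key), size,
                              dtype, layout, device, pin_memory);
    }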
+// ${generated_comment} + +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else + +${ops_headers} +#endif + +namespace at { + +namespace { + +${backend_select_method_definitions} + +TORCH_LIBRARY_IMPL(aten, BackendSelect, m) { + ${backend_select_function_registrations}; +} + +} // namespace +} // at diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RegisterCodegenUnboxedKernels.cpp b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RegisterCodegenUnboxedKernels.cpp new file mode 100644 index 0000000000000000000000000000000000000000..279f987c66a26c2eb5d11c664c85b3604b67684b --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RegisterCodegenUnboxedKernels.cpp @@ -0,0 +1,41 @@ +#include +#include +#include + +#include + +// ${generated_comment} + +// NOTE [Sharded File]: This file is generated in a sharded fashion to speed up +// incremental rebuilds. See the comment at the top of +// templates/VariableType.cpp for an analogous, in-depth discussion. +// +// Generated by tools/jit/gen_unboxing.py. This file registers all ATen ops into JIT op registry instead of c10 +// dispatcher. JIT op registry only takes boxed kernels, so we are calling unboxing functions in UnboxingFunctions.h +// to cast arguments into C++ types (instead of IValue) and delegate to unboxed kernels. + +namespace torch { namespace jit { + +using autograd::Variable; +using autograd::variable_list; +using at::Scalar; +using at::ScalarType; +using at::Tensor; +using at::TensorOptions; +using at::DeviceGuard; + +using ::c10::fmap; +using ::c10::filter; + +namespace { + +RegisterOperators reg({ + + // Generated operators + ${unboxed_ops} +}); + +} // anon namespace + + +}} // namespace torch::jit diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RegisterDispatchDefinitions.ini b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RegisterDispatchDefinitions.ini new file mode 100644 index 0000000000000000000000000000000000000000..3bf7f9b1bb32112a126e88a2e23e47c91e58dd9c --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RegisterDispatchDefinitions.ini @@ -0,0 +1,24 @@ +${ns_prologue} + +// NB: TORCH_LIBRARY_IMPL must be in an anonymous namespace to avoid +// ambiguity with conflicting identifiers that may have been defined in +// at namespace already. +namespace { + +${dispatch_helpers} + +${dispatch_anonymous_definitions} + +${static_init_dispatch_registrations} + +} // anonymous namespace + +${deferred_dispatch_registrations} + +namespace ${dispatch_namespace} { + +${dispatch_namespaced_definitions} + +} // namespace ${dispatch_namespace} + +${ns_epilogue} diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RegisterDispatchKey.cpp b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RegisterDispatchKey.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ec34841034bad15f22cc520514e420c8838725bb --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RegisterDispatchKey.cpp @@ -0,0 +1,54 @@ +// required for old g++ to compile PRId64 macros, see +// https://github.com/pytorch/pytorch/issues/3571 +// for context +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS +#endif + +// an external backend might generate file within its code tree +// and check all the source files within the tree with clang-format. 
+// so, disable it since the backend might have a different config. +// clang-format off + +// NOTE: This condition is true for all PyTorch internal libraries, it +// just excludes external projects such as torch_xla which +// re-use some of the PyTorch codegen machinery. +#if defined(CAFFE2_BUILD_MAIN_LIB) || \ + defined(TORCH_CUDA_BUILD_MAIN_LIB) || \ + defined(TORCH_HIP_BUILD_MAIN_LIB) || \ + defined(TORCH_CUDA_CU_BUILD_MAIN_LIB) || \ + defined(TORCH_CUDA_CPP_BUILD_MAIN_LIB) +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#endif + +// ${generated_comment} + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +$extra_cuda_headers +$external_backend_headers +$dispatch_headers +$ops_headers + +// See template file RegisterDispatchDefinitions.ini +$dispatch_definitions diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RegisterFunctionalization.cpp b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RegisterFunctionalization.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6a4dafa4c049667aa74be6eb91cbd7e82d7efc92 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RegisterFunctionalization.cpp @@ -0,0 +1,110 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// ${generated_comment} + +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +// needed for the meta tensor calls to get stride info in functionalization +#include +// needed for special handling of copy_(). +// See Note [functionalizating copy_() and not preserving strides] +#include +#include + +$ops_headers +#endif + +namespace at { +namespace functionalization { + +// This keyset is used by functionalization when it calls into meta kernels +// to accurately propagate stride metadata. +// Exclude any modes: the purpose of calling into meta kernels is only as an implementation +// detail to perform shape inference, and we don't want any modal keys to run. +// Specifically, we want to prevent functionalization and Python modes from running. +constexpr auto exclude_keys_for_meta_dispatch = + c10::functorch_transforms_ks | + c10::DispatchKeySet({ + c10::DispatchKey::FuncTorchDynamicLayerBackMode, + c10::DispatchKey::FuncTorchDynamicLayerFrontMode, + c10::DispatchKey::Python, + c10::DispatchKey::PreDispatch, + + }); + +// Helper around at::has_internal_overlap. +// The ATen util is used in hot-path eager mode: it's always fast, +// but might return TOO_HARD sometimes. +// During functionalization, we're ok taking a bit longer +// to detect memory overlap. 
+inline bool has_internal_overlap_helper(const at::Tensor t) { + auto has_overlap = at::has_internal_overlap(t); + if (has_overlap == at::MemOverlap::Yes) return true; + if (has_overlap == at::MemOverlap::No) return false; + return false; +} + + +inline Tensor to_meta(const Tensor& t) { + if (!t.defined()) return t; + return at::native::empty_strided_meta_symint(t.sym_sizes(), t.sym_strides(), +/*dtype=*/std::make_optional(t.scalar_type()), /*layout=*/std::make_optional(t.layout()), +/*device=*/std::make_optional(c10::Device(kMeta)), /*pin_memory=*/std::nullopt); +} + +inline std::optional to_meta(const std::optional& t) { + if (t.has_value()) { + return std::make_optional(to_meta(*t)); + } + return std::nullopt; +} + +inline std::vector to_meta(at::ITensorListRef t_list) { + std::vector outputs; + outputs.reserve(t_list.size()); + for (const auto& tensor : t_list) { + outputs.push_back(to_meta(tensor)); + } + return outputs; +} + +inline c10::List to_meta(const c10::List& t_list) { + c10::List outputs; + outputs.reserve(t_list.size()); + for (const auto i : c10::irange(t_list.size())) { + outputs.push_back(to_meta(t_list[i])); + } + return outputs; +} + +inline c10::List<::std::optional> to_meta(const c10::List<::std::optional>& t_list) { + c10::List<::std::optional> outputs; + outputs.reserve(t_list.size()); + for (const auto i : c10::irange(t_list.size())) { + outputs.push_back(to_meta(t_list[i])); + } + return outputs; +} + + +${func_definitions} + +} // namespace functionalization + +namespace { + +TORCH_LIBRARY_IMPL(aten, Functionalize, m) { + ${func_registrations}; +} + +} // namespace + +} // namespace at diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RegisterSchema.cpp b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RegisterSchema.cpp new file mode 100644 index 0000000000000000000000000000000000000000..029796d3e575b2bde85cfd44af9e6fcbb56466cd --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RegisterSchema.cpp @@ -0,0 +1,13 @@ +// ${generated_comment} +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +namespace at { +TORCH_LIBRARY(aten, m) { + ${aten_schema_registrations}; + // Distributed Ops + // Implementations located in torch/csrc/jit/runtime/register_distributed_ops.cpp + m.def("get_gradients(int context_id) -> Dict(Tensor, Tensor)"); +} +${schema_registrations} +} // namespace at diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RegistrationDeclarations.h b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RegistrationDeclarations.h new file mode 100644 index 0000000000000000000000000000000000000000..5a0f0d0c7b44dabb60061d32ced243fe607069d8 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/RegistrationDeclarations.h @@ -0,0 +1,4 @@ +// This file contains all native_functions that can be registered to +// and the schema string that they should be registered with + +${registration_declarations} diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/TensorBody.h b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/TensorBody.h new file mode 100644 index 0000000000000000000000000000000000000000..2e1520392ef927719c7fe8cbfdb20d9fbeb921d0 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/TensorBody.h @@ -0,0 +1,753 @@ +#pragma once + +#ifdef TORCH_ASSERT_NO_OPERATORS +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be 
re-compiled every time an operator \ + is changed or added. Consider if your change would be better placed in \ + another file, or if a more specific header might achieve the same goal. \ + See NOTE: [Tensor vs. TensorBase] +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#include + +namespace c10{ +template class List; +template class IListRef; +} +namespace at { +struct Generator; +struct Type; +class DeprecatedTypeProperties; +class Tensor; +} // namespace at +namespace at { +namespace indexing { +struct TensorIndex; +} // namespace indexing +} // namespace at + +namespace torch { namespace autograd { + +struct Node; + +}} // namespace torch::autograd + +namespace at { + +class OptionalTensorRef; +class TensorRef; +class Tensor; +using TensorList = ArrayRef; +using ITensorList = c10::IListRef; + +using Stream = c10::Stream; + +// Tensor is a "generic" object holding a pointer to the underlying TensorImpl object, which +// has an embedded reference count. In this way, Tensor is similar to boost::intrusive_ptr. +// +// For example: +// +// void func(Tensor a) { +// Tensor b = a; +// ... +// } +// +// In this example, when we say Tensor b = a, we are creating a new object that points to the +// same underlying TensorImpl, and bumps its reference count. When b goes out of scope, the +// destructor decrements the reference count by calling release() on the TensorImpl it points to. +// The existing constructors, operator overloads, etc. take care to implement the correct semantics. +// +// Note that Tensor can also be NULL, i.e. it is not associated with any underlying TensorImpl, and +// special care must be taken to handle this. +class TORCH_API Tensor: public TensorBase { + protected: + // Create a Tensor with a +0 reference count. Special care must be + // taken to avoid decrementing this reference count at destruction + // time. Intended to support MaybeOwnedTraits. + explicit Tensor(unsafe_borrow_t, const TensorBase& rhs): TensorBase(unsafe_borrow_t{}, rhs) {} + friend MaybeOwnedTraits; + friend OptionalTensorRef; + friend TensorRef; + + public: + Tensor() = default; + // This constructor should not be used by end users and is an implementation + // detail invoked by autogenerated code. + explicit Tensor( + c10::intrusive_ptr tensor_impl) + : TensorBase(std::move(tensor_impl)) {} + Tensor(const Tensor &tensor) = default; + Tensor(Tensor &&tensor) = default; + + // Implicitly move-constructible from TensorBase, but must be explicit to increase refcount + explicit Tensor(const TensorBase &base): TensorBase(base) {} + /*implicit*/ Tensor(TensorBase &&base): TensorBase(std::move(base)) {} + + // Creates a new wrapper from TensorImpl. Intentionally a free method because + // it should be used with care. 
Checks necessary invariants + static Tensor wrap_tensor_impl( + c10::intrusive_ptr tensor_impl) { + return TensorBase::wrap_tensor_impl(std::move(tensor_impl)); + } + + Tensor contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const { + return TensorBase::contiguous(memory_format); + } + + Tensor conj() const { + if (!this->is_complex()) { + return *this; + } + + switch (this->layout()) { + case at::kSparse: + case at::kSparseCsr: + case at::kSparseCsc: + case at::kSparseBsr: + case at::kSparseBsc: + return this->conj_physical(); + default: + return this->_conj(); + } + } + + // Aliased by Dimname overloads, so need explicit using + using TensorBase::size; + using TensorBase::sym_size; + using TensorBase::stride; + + /// Should be used if *this can reasonably be expected to be contiguous and + /// performance is important. + /// Compared to contiguous, it saves a reference count + /// increment/decrement if *this is already contiguous, at the cost + /// in all cases of an extra pointer of stack usage, an extra branch + /// to access, and an extra branch at destruction time. + c10::MaybeOwned expect_contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const &; + + // Use .contiguous() instead. Trying to borrow from a prvalue Tensor + // will only lead to trouble and dangling references. + c10::MaybeOwned expect_contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) && = delete; + + // The following overloads are very intruiging. Consider the following + // program: + // + // x[1] = 3; + // + // We would expect that the first entry of x is written to 3. But how can we + // actually achieve this? x[1] evaluates to a tensor... + // + // The answer is, using a ref-qualifier. x[1] is an rvalue, which cannot be + // (profitably) assigned to in the traditional sense, so we overload + // assignment to mean, "Actually, copy 3 into the tensor data." This is done + // with an rvalue-reference ref-qualified overload (the methods with && at the + // end of their type.) + // + // There's one more fly in the ointment: We also want + // + // Tensor x = y; + // + // to work, and we want it NOT to copy. So we need a traditional operator= + // overload. But we MUST specify a mutable lvalue ref-qualifier, to + // disambiguate the traditional overload from the rvalue-reference + // ref-qualified overload. Otherwise, it will be ambiguous, because + // a non ref-qualified method is eligible for all situations. + + // Unfortunately, we have to write these constructors out manually + // to work around an MSVC bug: + // error C2580: 'at::Tensor &at::Tensor::operator =(const at::Tensor &) &': + // multiple versions of a defaulted special member functions are not allowed + // Tensor& operator=(const Tensor&) & = default; + // Tensor& operator=(Tensor&&) & = default; + + // Also MSVC will wrongly issue the following warning with the aforementioned fix + // warning C4522: 'at::Tensor': multiple assignment operators specified + // Let's just skip the warning. 
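A short usage sketch of the ref-qualified assignment overloads discussed above, written against the public ATen API; it only demonstrates the behavior the comments describe and assumes a recent libtorch:

    #include <ATen/ATen.h>

    void ref_qualifier_demo() {
      at::Tensor x = at::zeros({3});

      // x[1] is a prvalue Tensor viewing one element, so the &&-qualified
      // operator=(const Scalar&) is selected and fills the underlying storage.
      x[1] = 3;

      // Plain lvalue copy keeps the usual semantics: y shares the same
      // TensorImpl as x and no element data is copied.
      at::Tensor y = x;

      // Lvalue-qualified operator=(Tensor&&) & simply rebinds y; x is untouched.
      y = at::ones({3});
    }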
+ // + // TODO: temporarily disabled + + Tensor& operator=(const TensorBase& x) & { + impl_ = x.getIntrusivePtr(); + return *this; + } + Tensor& operator=(TensorBase&& x) & noexcept { + impl_ = x.unsafeReleaseIntrusivePtr(); + return *this; + } + + Tensor& operator=(const Tensor &x) & { + return operator=(static_cast(x)); + } + Tensor& operator=(Tensor &&x) & noexcept { + return operator=(static_cast(x)); + } + + Tensor& operator=(const Scalar &v) && { + return fill_(v); + } + Tensor& operator=(const Tensor &rhs) && { + return copy_(rhs); + } + Tensor& operator=(Tensor&& rhs) && { + return copy_(rhs); + } + + C10_DEPRECATED_MESSAGE("Tensor.type() is deprecated. Instead use Tensor.options(), which in many cases (e.g. in a constructor) is a drop-in replacement. If you were using data from type(), that is now available from Tensor itself, so instead of tensor.type().scalar_type(), use tensor.scalar_type() instead and instead of tensor.type().backend() use tensor.device().") + DeprecatedTypeProperties & type() const { + return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( + dispatchKeyToBackend(legacyExtractDispatchKey(key_set())), + scalar_type()); + } + + Tensor toType(ScalarType t) const { + return to(options().dtype(t), /*non_blocking*/ false, /*copy*/ false); + } + + // TODO: Deprecate me + Tensor toBackend(Backend b) const { + return to(options().device(backendToDeviceType(b)).layout(layout_from_backend(b)), /*non_blocking*/ false, /*copy*/ false); + } + + C10_DEPRECATED_MESSAGE("Tensor.is_variable() is deprecated; everything is a variable now. (If you want to assert that variable has been appropriately handled already, use at::impl::variable_excluded_from_dispatch())") + bool is_variable() const noexcept { + return !at::impl::variable_excluded_from_dispatch(); + } + + template + C10_DEPRECATED_MESSAGE("Tensor.data() is deprecated. 
Please use Tensor.data_ptr() instead.") + T * data() const { + return data_ptr(); + } + + template + T item() const; + + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead") + GenericPackedTensorAccessor packed_accessor() const & { + return generic_packed_accessor(); + } + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead") + GenericPackedTensorAccessor packed_accessor() && = delete; + + Tensor operator~() const { + return bitwise_not(); + } + Tensor operator-() const { + return neg(); + } + Tensor& operator+=(const Tensor & other) { + return add_(other); + } + Tensor& operator+=(const Scalar & other) { + return add_(other); + } + Tensor& operator-=(const Tensor & other) { + return sub_(other); + } + Tensor& operator-=(const Scalar & other) { + return sub_(other); + } + Tensor& operator*=(const Tensor & other) { + return mul_(other); + } + Tensor& operator*=(const Scalar & other) { + return mul_(other); + } + Tensor& operator/=(const Tensor & other) { + return div_(other); + } + Tensor& operator/=(const Scalar & other) { + return div_(other); + } + Tensor& operator&=(const Tensor & other) { + return bitwise_and_(other); + } + Tensor& operator|=(const Tensor & other) { + return bitwise_or_(other); + } + Tensor& operator^=(const Tensor & other) { + return bitwise_xor_(other); + } + Tensor operator[](const Scalar & index) const { + if (!index.isIntegral(false)) { + TORCH_CHECK_INDEX(false, "Can only index tensors with integral scalars"); + } + return this->operator[](index.toLong()); + } + Tensor operator[](const Tensor & index) const { + // These properties are checked in the Scalar constructor, but we already + // check them here to provide more useful diagnostics for the user. + if (!index.defined()) { + TORCH_CHECK_INDEX(false, "Can only index with tensors that are defined"); + } + if (index.dim() != 0) { + TORCH_CHECK_INDEX(false, + "Can only index with tensors that are scalars (zero-dim)"); + } + // The Scalar(Tensor) constructor is explicit, so we need to call it. 
+ return this->operator[](index.item()); + } + Tensor operator[](int64_t index) const { + return select(0, index); + } + + Tensor index(ArrayRef indices) const; + Tensor index(std::initializer_list indices) const; + + Tensor & index_put_(ArrayRef indices, Tensor const & rhs); + Tensor & index_put_(ArrayRef indices, const Scalar& v); + Tensor & index_put_(std::initializer_list indices, Tensor const & rhs); + Tensor & index_put_(std::initializer_list indices, const Scalar& v); + + Tensor cpu() const { + return to(options().device(c10::DeviceType::CPU), /*non_blocking*/ false, /*copy*/ false); + } + + // TODO: The Python version also accepts arguments + Tensor cuda() const { + return to(options().device(c10::DeviceType::CUDA), /*non_blocking*/ false, /*copy*/ false); + } + + Tensor hip() const { + return to(options().device(c10::DeviceType::HIP), /*non_blocking*/ false, /*copy*/ false); + } + + Tensor ve() const { + return to(options().device(c10::DeviceType::VE), /*non_blocking*/ false, /*copy*/ false); + } + + Tensor vulkan() const { + return to(options().device(c10::DeviceType::Vulkan), /*non_blocking*/ false, /*copy*/ false); + } + + Tensor metal() const { + return to(options().device(c10::DeviceType::Metal), /*non_blocking*/ false, /*copy*/ false); + } + + Tensor meta() const { + return to(options().device(c10::DeviceType::Meta), /*non_blocking*/ false, /*copy*/ false); + } + + // ~~~~~ Autograd API ~~~~~ + + /// \fn bool is_leaf() const; + /// + /// All Tensors that have `requires_grad()` which is ``false`` will be leaf Tensors by convention. + /// + /// For Tensors that have `requires_grad()` which is ``true``, they will be leaf Tensors if they were + /// created by the user. This means that they are not the result of an operation and so + /// `grad_fn()` is `nullptr`. + /// + /// Only leaf Tensors will have their `grad()` populated during a call to `backward()`. + /// To get `grad()` populated for non-leaf Tensors, you can use `retain_grad()`. + /// + /// Example: + /// @code + /// auto a = torch::rand(10, torch::requires_grad()); + /// std::cout << a.is_leaf() << std::endl; // prints `true` + /// + /// auto b = torch::rand(10, torch::requires_grad()).to(torch::kCUDA); + /// std::cout << b.is_leaf() << std::endl; // prints `false` + /// // b was created by the operation that cast a cpu Tensor into a cuda Tensor + /// + /// auto c = torch::rand(10, torch::requires_grad()) + 2; + /// std::cout << c.is_leaf() << std::endl; // prints `false` + /// // c was created by the addition operation + /// + /// auto d = torch::rand(10).cuda(); + /// std::cout << d.is_leaf() << std::endl; // prints `true` + /// // d does not require gradients and so has no operation creating it (that is tracked by the autograd engine) + /// + /// auto e = torch::rand(10).cuda().requires_grad_(); + /// std::cout << e.is_leaf() << std::endl; // prints `true` + /// // e requires gradients and has no operations creating it + /// + /// auto f = torch::rand(10, torch::device(torch::kCUDA).requires_grad(true)); + /// std::cout << f.is_leaf() << std::endl; // prints `true` + /// // f requires grad, has no operation creating it + /// @endcode + + /// \fn void backward(const Tensor & gradient={}, std::optional retain_graph=std::nullopt, bool create_graph=false, std::optional inputs=std::nullopt) const; + /// + /// Computes the gradient of current tensor with respect to graph leaves. + /// + /// The graph is differentiated using the chain rule. If the tensor is + /// non-scalar (i.e. 
its data has more than one element) and requires + /// gradient, the function additionally requires specifying ``gradient``. + /// It should be a tensor of matching type and location, that contains + /// the gradient of the differentiated function w.r.t. this Tensor. + /// + /// This function accumulates gradients in the leaves - you might need to + /// zero them before calling it. + /// + /// \param gradient Gradient w.r.t. the + /// tensor. If it is a tensor, it will be automatically converted + /// to a Tensor that does not require grad unless ``create_graph`` is True. + /// None values can be specified for scalar Tensors or ones that + /// don't require grad. If a None value would be acceptable then + /// this argument is optional. + /// \param retain_graph If ``false``, the graph used to compute + /// the grads will be freed. Note that in nearly all cases setting + /// this option to True is not needed and often can be worked around + /// in a much more efficient way. Defaults to the value of + /// ``create_graph``. + /// \param create_graph If ``true``, graph of the derivative will + /// be constructed, allowing to compute higher order derivative + /// products. Defaults to ``false``. + /// \param inputs Inputs w.r.t. which the gradient will be accumulated into + /// ``at::Tensor::grad``. All other Tensors will be ignored. If not + /// provided, the gradient is accumulated into all the leaf Tensors + /// that were used to compute the current tensor. + /// When inputs are provided and a given input is not a leaf, + /// the current implementation will call its grad_fn (even though it is not strictly needed to get this gradients). + /// It is an implementation detail on which the user should not rely. + /// See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details. + void backward(const Tensor & gradient={}, std::optional retain_graph=std::nullopt, bool create_graph=false, std::optional inputs=std::nullopt) const { + // NB: Adding this wrapper to _backward here because we'd like our + // 'backwards' api to accept the 'inputs' argument optionally. Since code gen + // currently does not support optional of TensorList our approach is to replace + // backward in native_functions.yaml with _backward and call it here instead. + if (inputs.has_value()) { + TORCH_CHECK(inputs.value().size() > 0, "'inputs' argument to backward cannot be empty") + this->_backward(inputs.value(), gradient, retain_graph, create_graph); + } else { + this->_backward({}, gradient, retain_graph, create_graph); + } + } + + /// \fn Tensor detach() const; + /// + /// Returns a new Tensor, detached from the current graph. + /// The result will never require gradient. + + /// \fn Tensor & detach_() const; + /// + /// Detaches the Tensor from the graph that created it, making it a leaf. + /// Views cannot be detached in-place. + + /// \fn void retain_grad() const; + /// + /// Enables this Tensor to have their :attr:`grad` populated during + /// :func:`backward`. This is a no-op for leaf tensors. + + /// \fn bool retains_grad() const; + /// + /// Is ``true`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be + /// populated during :func:`backward`, ``false`` otherwise. + + const Tensor& set_requires_grad(bool requires_grad) const { + TensorBase::set_requires_grad(requires_grad); + return *this; + } + + /// Return a mutable reference to the gradient. This is conventionally + /// used as `t.grad() = x` to set a gradient to a completely new tensor. 
+ /// Note that this function work with a non-const Tensor and is not + /// thread safe. + Tensor& mutable_grad() const { + return impl_->mutable_grad(); + } + + /// This function returns an undefined tensor by default and returns a defined tensor + /// the first time a call to `backward()` computes gradients for this Tensor. + /// The attribute will then contain the gradients computed and future calls + /// to `backward()` will accumulate (add) gradients into it. + const Tensor& grad() const { + const Tensor& maybe_grad = impl_->grad(); + if (!is_leaf() && !retains_grad() && !maybe_grad.defined()) { + TORCH_WARN( + "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad " + "attribute won't be populated during autograd.backward(). If you indeed want the .grad " + "field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. " + "If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor " + "instead. See github.com/pytorch/pytorch/pull/30531 for more informations."); + } + return maybe_grad; + } + + // The Forward AD API functions below are low level and are not to be used by end + // users who should use the API provided in torch/csrc/autograd.h + + /// This function returns the forward gradient for this Tensor at the given level. + const Tensor& _fw_grad(uint64_t level) const { + return impl_->_fw_grad(level, *this); + } + + /// This function can be used to set the value of the forward grad. + /// Note that the given new_grad might not be used directly if it has different + /// metadata (size/stride/storage offset) compared to this Tensor. In that case, + /// new_grad content will be copied into a new Tensor + void _set_fw_grad(const TensorBase& new_grad, uint64_t level, bool is_inplace_op) const { + impl_->_set_fw_grad(new_grad, *this, level, is_inplace_op); + } + + + // STOP. Thinking of adding a method here, which only makes use + // of other ATen methods? Define it in native_functions.yaml. + + //example + //Tensor * add(Tensor & b); + ${tensor_method_declarations} + + // Special C++ only overloads for std()-like functions (See gh-40287) + // These are needed because int -> bool conversion takes precedence over int -> IntArrayRef + // So, for example std(0) would select the std(unbiased=False) overload + + Tensor var(int dim) const { + return var(IntArrayRef{dim}); + } + + Tensor std(int dim) const { + return std(IntArrayRef{dim}); + } + + // We changed .dtype() to return a TypeMeta in #12766. Ideally, we want the + // at::kDouble and its friends to be TypeMeta's, but that hasn't happened yet. + // Before that change, we make this method to maintain BC for C++ usage like + // `x.to(y.dtype)`. + // TODO: remove following two after at::kDouble and its friends are TypeMeta's. + inline Tensor to(caffe2::TypeMeta type_meta, bool non_blocking=false, bool copy=false) const { + return this->to(/*scalar_type=*/typeMetaToScalarType(type_meta), non_blocking, copy); + } + inline Tensor to(Device device, caffe2::TypeMeta type_meta, bool non_blocking=false, bool copy=false) const { + return this->to(device, /*scalar_type=*/typeMetaToScalarType(type_meta), non_blocking, copy); + } + + template + decltype(auto) m(F func, Args&&... params) const { + return func(*this, std::forward(params)...); + } + + /// NOTE: This is similar to the legacy `.data()` function on `Variable`, and is intended + /// to be used from functions that need to access the `Variable`'s equivalent `Tensor` + /// (i.e. 
`Tensor` that shares the same storage and tensor metadata with the `Variable`). + /// + /// One notable difference with the legacy `.data()` function is that changes to the + /// returned `Tensor`'s tensor metadata (e.g. sizes / strides / storage / storage_offset) + /// will not update the original `Variable`, due to the fact that this function + /// shallow-copies the `Variable`'s underlying TensorImpl. + at::Tensor tensor_data() const { + return TensorBase::tensor_data(); + } + + /// NOTE: `var.variable_data()` in C++ has the same semantics as `tensor.data` + /// in Python, which create a new `Variable` that shares the same storage and + /// tensor metadata with the original `Variable`, but with a completely new + /// autograd history. + /// + /// NOTE: If we change the tensor metadata (e.g. sizes / strides / + /// storage / storage_offset) of a variable created from `var.variable_data()`, those + /// changes will not update the original variable `var`. In `.variable_data()`, we set + /// `allow_tensor_metadata_change_` to false to make such changes explicitly illegal, + /// in order to prevent users from changing metadata of `var.variable_data()` + /// and expecting the original variable `var` to also be updated. + at::Tensor variable_data() const { + return TensorBase::variable_data(); + } + + // Hooks + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + template + using hook_return_void_t = std::enable_if_t>::value, unsigned>; + template + using hook_return_var_t = std::enable_if_t, Tensor>::value, unsigned>; + + /// Registers a backward hook. + /// + /// The hook will be called every time a gradient with respect to the Tensor is computed. + /// The hook should have one of the following signature: + /// ``` + /// hook(Tensor grad) -> Tensor + /// ``` + /// ``` + /// hook(Tensor grad) -> void + /// ``` + /// The hook should not modify its argument, but it can optionally return a new gradient + /// which will be used in place of `grad`. + /// + /// This function returns the index of the hook in the list which can be used to remove hook. + /// + /// Example: + /// @code + /// auto v = torch::tensor({0., 0., 0.}, torch::requires_grad()); + /// auto h = v.register_hook([](torch::Tensor grad){ return grad * 2; }); // double the gradient + /// v.backward(torch::tensor({1., 2., 3.})); + /// // This prints: + /// // ``` + /// // 2 + /// // 4 + /// // 6 + /// // [ CPUFloatType{3} ] + /// // ``` + /// std::cout << v.grad() << std::endl; + /// v.remove_hook(h); // removes the hook + /// @endcode + template + hook_return_void_t register_hook(T&& hook) const; + template + hook_return_var_t register_hook(T&& hook) const; + + // Variable methods + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Tensor data() const { + return TensorBase::data(); + } + + void _backward(TensorList inputs, const std::optional& gradient, std::optional keep_graph, bool create_graph) const; + + const Tensor& requires_grad_(bool _requires_grad=true) const { + TensorBase::requires_grad_(_requires_grad); + return *this; + } +}; + +namespace detail { +// Helper creator for Tensor class which doesn't requires the users to pass +// in an intrusive_ptr instead it just converts the argument passed to +// requested intrusive_ptr type. +template +Tensor make_tensor(Args&&... 
args) { + return Tensor(c10::make_intrusive(std::forward(args)...)); +} + +} // namespace detail + +} // namespace at + + +namespace at { +${tensor_method_definitions} +} // namespace at + + +namespace c10 { +template <> +struct MaybeOwnedTraits { + using owned_type = at::Tensor; + using borrow_type = at::Tensor; + + static borrow_type createBorrow(const owned_type& from) { + // NOTE: this can be implemented without the special + // unsafe_borrow_t Tensor constructor as + // + // return borrow_type(c10::intrusive_ptr::reclaim(from.unsafeGetTensorImpl())); + // + // but that hurts inlining due to the nullptr check in the + // Tensor(c10::intrusive_ptr<...>) constructor. We already know + // that from.impl_ isn't null because from is a valid Tensor, so + // we needn't do the check again. (using __builtin_assume can + // avoid this, but wouldn't be portable to MSVC.) + return borrow_type(borrow_type::unsafe_borrow_t{}, from); + } + + static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) { + lhs.unsafeReleaseTensorImpl(); + // See above note: this can be implemented with public API + // similarly to createBorrow(), but that would hurt inlining. + lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs); + } + + static void destroyBorrow(borrow_type& toDestroy) { + toDestroy.unsafeReleaseTensorImpl(); // "leak" it, but it was already +0. + } + + static const owned_type& referenceFromBorrow(const borrow_type& borrow) { + return borrow; + } + + static const owned_type* pointerFromBorrow(const borrow_type& borrow) { + return &borrow; + } + + static bool debugBorrowIsValid(const borrow_type& /*borrow*/) { + return true; + } +}; + +template <> +struct ExclusivelyOwnedTraits { + using repr_type = at::Tensor; + using pointer_type = at::Tensor*; + using const_pointer_type = const at::Tensor*; + + static repr_type nullRepr() { + return at::Tensor(); + } + + template + static repr_type createInPlace(Args&&... args) { + return at::Tensor(std::forward(args)...); + } + + static repr_type moveToRepr(at::Tensor&& x) { + return std::move(x); + } + + static void destroyOwned(at::Tensor& x) { + return ExclusivelyOwnedTraits::destroyOwned(x); + } + + static at::Tensor take(at::Tensor& x) { + return std::move(x); + } + + static pointer_type getImpl(repr_type& x) { + return &x; + } + + static const_pointer_type getImpl(const repr_type& x) { + return &x; + } +}; +} // namespace c10 + +namespace at { + +inline c10::MaybeOwned borrow_from_optional_tensor( + const std::optional& opt) { + return opt.has_value() + ? c10::MaybeOwned::borrowed(*opt) + : c10::MaybeOwned::owned(std::in_place); +} + +inline c10::MaybeOwned Tensor::expect_contiguous(MemoryFormat memory_format) const & { + if (is_contiguous(memory_format)) { + return c10::MaybeOwned::borrowed(*this); + } else { + return c10::MaybeOwned::owned(__dispatch_contiguous(memory_format)); + } +} +} // namespace at diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/TensorMethods.cpp b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/TensorMethods.cpp new file mode 100644 index 0000000000000000000000000000000000000000..76439040eda45ec34f627298260e7bf081fd728c --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/TensorMethods.cpp @@ -0,0 +1,61 @@ +#include +#include + +#include + +namespace at { + +namespace { + +// Verifies the requested type is the same as the Tensor's type. 
+void check_type(const TensorBase& tensor, ScalarType type, c10::string_view type_name) { + TORCH_CHECK( + tensor.scalar_type() == type + || (isQIntType(tensor.scalar_type()) + && toUnderlying(tensor.scalar_type()) == type), + "expected scalar type ", type_name, " but found ", tensor.scalar_type()); +} + +} // namespace + +#define DEFINE_CAST(T, name) \ + template <> \ + TORCH_API const T* TensorBase::const_data_ptr() const { \ + check_type(*this, ScalarType::name, #name); \ + return this->unsafeGetTensorImpl()->data_ptr_impl(); \ + } \ + \ + template <> \ + TORCH_API const T* TensorBase::const_data_ptr() const { \ + check_type(*this, ScalarType::name, #name); \ + return this->unsafeGetTensorImpl()->data_ptr_impl>(); \ + } \ + \ + template <> \ + TORCH_API T* TensorBase::mutable_data_ptr() const { \ + check_type(*this, ScalarType::name, #name); \ + return this->unsafeGetTensorImpl()->mutable_data_ptr_impl(); \ + } \ + \ + template <> \ + TORCH_API T* TensorBase::data_ptr() const { \ + return mutable_data_ptr(); \ + } \ + + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CAST) + AT_FORALL_QINT_TYPES(DEFINE_CAST) + DEFINE_CAST(uint16_t, UInt16) + DEFINE_CAST(uint32_t, UInt32) + DEFINE_CAST(uint64_t, UInt64) + #undef DEFINE_CAST + + #define DEFINE_ITEM(T, name) \ + template <> \ + TORCH_API T Tensor::item() const { \ + return item().to##name(); \ + } + + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ITEM) + #undef DEFINE_ITEM + + } //namespace at diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/UfuncCPU.cpp b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/UfuncCPU.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6b363a508907cc064e41794720657541fc28c301 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/UfuncCPU.cpp @@ -0,0 +1,19 @@ +#define TORCH_ASSERT_NO_OPERATORS + +#include +#include +#include + +namespace at { + +// NB: this is explicitly copied here (via codegen) rather than +// included via NativeFunctions.h to avoid recompiling this file when +// NativeFunctions.h changes +namespace meta { +${meta_declaration} +} + +namespace native { +${native_declaration} +${native_definitions} +}} // namespace at::native diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/UfuncCPUKernel.cpp b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/UfuncCPUKernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0cac55664d6125287bdee0bd94c150462b81d5b9 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/UfuncCPUKernel.cpp @@ -0,0 +1,14 @@ +#define TORCH_ASSERT_NO_OPERATORS + +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +${native_definitions} +}} // namespace at::native diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/UfuncCUDA.cu b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/UfuncCUDA.cu new file mode 100644 index 0000000000000000000000000000000000000000..e75d82d9cc84bd8fddfd303f610412e5d0a98729 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/UfuncCUDA.cu @@ -0,0 +1,21 @@ +#define TORCH_ASSERT_NO_OPERATORS + +#include +#include +#include +#include +${cuda_headers} + +namespace at { + +// NB: this is explicitly copied here (via codegen) rather than +// included via NativeFunctions.h to avoid recompiling this file when +// NativeFunctions.h changes +namespace meta { +${meta_declaration} +} + 
+namespace native { +${native_declaration} +${native_definitions} +}} // namespace at::native diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/UnboxingFunctions.cpp b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/UnboxingFunctions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..86c13235d8623964d734e743f5f15cf68a8df63c --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/UnboxingFunctions.cpp @@ -0,0 +1,35 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace at { +namespace unboxing { + +using ::c10::fmap; +using ::c10::filter; +using torch::jit::peek; +using torch::jit::drop; +using torch::jit::pack; +using torch::jit::pop; + +// Generated function declaration +${definitions} + +} // namespace unboxing +} // namespace at diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/UnboxingFunctions.h b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/UnboxingFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..a65469a9b0123cbfd4075ff3c263276aa47f137f --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/UnboxingFunctions.h @@ -0,0 +1,32 @@ +// ${generated_comment} + +// Generated by tools/jit/gen_unboxing.py. This file declares code generated boxed C++ functions for operators, +// base off of native_functions.yaml (or similar yaml file with the same syntax). The definition of such a boxed +// function will pop out IValues from the stack then convert them into the correct C++ types based on given schema. This +// unboxing logic is an alternative to template-based metaprogramming unboxing. + +#pragma once + +#include +namespace at { +namespace unboxing { +namespace { + +template +std::array as_array(const c10::List& list) { + std::array res; + AT_ASSERT(list.size() == N); + std::vector vec; + for (c10::IValue elem : list) { + vec.push_back(elem.to()); + } + std::copy(vec.begin(), vec.end(), res.begin()); + return res; +} +} // namespace +using Stack = std::vector; +// Generated function declaration +${declarations} + +} // namespace unboxing +} // namespace at diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/aten_interned_strings.h b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/aten_interned_strings.h new file mode 100644 index 0000000000000000000000000000000000000000..326d4622334a776f4f1f94fb49a70f2c53c7e6eb --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/aten_interned_strings.h @@ -0,0 +1,22 @@ +#pragma once + +// ${generated_comment} + +#if defined(TORCH_ASSERT_NO_OPERATORS) || defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be re-compiled every time an operator \ + is changed or added. Consider if including for \ + the c10::Symbol class would be sufficient, or if your change would be \ + better placed in another file. +#endif + +// ATen symbols correspond exactly to operators defined in ATen. Every +// symbol here corresponds exactly to an ATen operation defined in +// native_functions.yaml; attributes are in one-to-one correspondence +// with their ATen name. 
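Before the generated symbol lists below, a small illustration of how an X-macro of this shape is typically consumed: the caller passes a macro that is stamped out once per entry. The symbol list and the consumer macro here are toy stand-ins, not the generated ${aten_symbols} content:

    // Toy stand-in for the generated list (the real one comes from native_functions.yaml).
    #define FORALL_EXAMPLE_SYMBOLS(_) \
      _(aten, add)                    \
      _(aten, mul)                    \
      _(aten, relu)

    // Typical use: stamp out one enumerator (or declaration) per symbol.
    #define DEFINE_ENUMERATOR(ns, s) k_##ns##_##s,
    enum class ExampleSymbol {
      FORALL_EXAMPLE_SYMBOLS(DEFINE_ENUMERATOR)
    };
    #undef DEFINE_ENUMERATOR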
+ +#define FORALL_ATEN_BASE_SYMBOLS(_) \ +${aten_symbols} + +#define FORALL_ATTR_BASE_SYMBOLS(_) \ +${attr_symbols} diff --git a/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/enum_tag.h b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/enum_tag.h new file mode 100644 index 0000000000000000000000000000000000000000..1320fbc28ab8f7d72655816292f49a4c9a9b727d --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/ATen/templates/enum_tag.h @@ -0,0 +1,10 @@ +#pragma once + +// ${generated_comment} + +namespace at { + // Enum of valid tags obtained from the entries in tags.yaml + enum class Tag { + ${enum_of_valid_tags} + }; +} diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/BUILD.bazel b/lib/python3.10/site-packages/torchgen/packaged/autograd/BUILD.bazel new file mode 100644 index 0000000000000000000000000000000000000000..d1a0db360d230fe0f027c19869c6307f17010503 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/BUILD.bazel @@ -0,0 +1,4 @@ +load("//:tools/bazel.bzl", "rules") +load(":build.bzl", "define_targets") + +define_targets(rules = rules) diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/README.md b/lib/python3.10/site-packages/torchgen/packaged/autograd/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bfa43899cc590959c2bfd74e38662ec03aaee3d6 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/README.md @@ -0,0 +1,3 @@ +If you add a file to this directory, you **MUST** update +`torch/CMakeLists.txt` and add the file as a dependency to +the `add_custom_command` call. diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/__init__.py b/lib/python3.10/site-packages/torchgen/packaged/autograd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30873b8501a20e4f5a857dbfaf62819a8e1db910 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/context.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/context.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24eaf8e3a39414ef26951b66dabc1f75ae6a4019 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/context.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_annotated_fn_args.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_annotated_fn_args.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e81fafb9cdf44df56529d9d430766027dd27db20 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_annotated_fn_args.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_autograd.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_autograd.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..e10edea52ee4ad5c36a158696ac291611e0c4e26 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_autograd.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_autograd_functions.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_autograd_functions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..247e39ba2b6326d34b228a6340f1193d0fc27866 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_autograd_functions.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_inplace_or_view_type.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_inplace_or_view_type.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..214d1f39cd36f35b786077a71a7fce453e2ec301 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_inplace_or_view_type.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_python_functions.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_python_functions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01b86c28d22a820ffeb62467e16994423affd014 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_python_functions.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_trace_type.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_trace_type.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d55ad6265f01956887d13e1a29b742fa0dd75ea Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_trace_type.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_variable_factories.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_variable_factories.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d41c0654c5081720710f4f4cbd120237a72456f5 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_variable_factories.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_variable_type.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_variable_type.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6136c8af6cec608e36a0a3722b48be12bf27c445 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_variable_type.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_view_funcs.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_view_funcs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e1e001ad640077f30668b951e10b0005bcce02a Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/gen_view_funcs.cpython-310.pyc differ diff --git 
a/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/load_derivatives.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/load_derivatives.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1306e42dcbff5e6252dca7f889bcf6139baf61d6 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/packaged/autograd/__pycache__/load_derivatives.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/build.bzl b/lib/python3.10/site-packages/torchgen/packaged/autograd/build.bzl new file mode 100644 index 0000000000000000000000000000000000000000..588bd5944e29477119782591b231fd80a7a57cf4 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/build.bzl @@ -0,0 +1,14 @@ +def define_targets(rules): + rules.py_library( + name = "autograd", + srcs = rules.glob(["*.py"]), + data = rules.glob([ + "*.yaml", + "templates/*", + ]), + visibility = ["//:__subpackages__"], + deps = [ + rules.requirement("PyYAML"), + "//torchgen", + ], + ) diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/context.py b/lib/python3.10/site-packages/torchgen/packaged/autograd/context.py new file mode 100644 index 0000000000000000000000000000000000000000..d838aa3c77bbbc0f37cd7fa6e005d85c9e9dd624 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/context.py @@ -0,0 +1,31 @@ +import functools +from typing import Callable + +from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo as NFWDI +from torchgen.context import native_function_manager +from torchgen.utils import T + + +# Like tools.api.context.with_native_function, but for +# NativeFunctionWithDifferentiabilityInfo. +def with_native_function_with_differentiability_info( + func: Callable[[NFWDI], T] +) -> Callable[[NFWDI], T]: + @functools.wraps(func) + def wrapper(f: NFWDI) -> T: + with native_function_manager(f.func): + return func(f) + + return wrapper + + +# Like the above but with an additional dispatch key string argument +def with_native_function_with_differentiability_info_and_key( + func: Callable[[NFWDI, str], T] +) -> Callable[[NFWDI, str], T]: + @functools.wraps(func) + def wrapper(f: NFWDI, key: str) -> T: + with native_function_manager(f.func): + return func(f, key) + + return wrapper diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/deprecated.yaml b/lib/python3.10/site-packages/torchgen/packaged/autograd/deprecated.yaml new file mode 100644 index 0000000000000000000000000000000000000000..52f7ec50b6ea15dae1c3308358997950d295c924 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/deprecated.yaml @@ -0,0 +1,134 @@ +# Deprecated function signatures. These are exposed in Python, but not included +# in the error message suggestions. + +- name: add(Tensor self, Scalar alpha, Tensor other) -> Tensor + aten: add(self, other, alpha) + +- name: add_(Tensor(a!) self, Scalar alpha, Tensor other) -> Tensor(a!) + aten: add_(self, other, alpha) + +- name: add(Tensor self, Scalar alpha, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + aten: add_out(out, self, other, alpha) + +- name: addbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2) -> Tensor + aten: addbmm(self, batch1, batch2, beta, alpha) + +- name: addbmm_(Scalar beta, Tensor(a!) self, Scalar alpha, Tensor batch1, Tensor batch2) -> Tensor(a!) 
+ aten: addbmm_(self, batch1, batch2, beta, alpha) + +- name: addbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2, *, Tensor(a!) out) -> Tensor(a!) + aten: addbmm_out(out, self, batch1, batch2, beta, alpha) + +- name: addbmm(Scalar beta, Tensor self, Tensor batch1, Tensor batch2) -> Tensor + aten: addbmm(self, batch1, batch2, beta, 1) + +- name: addbmm_(Scalar beta, Tensor(a!) self, Tensor batch1, Tensor batch2) -> Tensor(a!) + aten: addbmm_(self, batch1, batch2, beta, 1) + +- name: addbmm(Scalar beta, Tensor self, Tensor batch1, Tensor batch2, *, Tensor(a!) out) -> Tensor(a!) + aten: addbmm_out(out, self, batch1, batch2, beta, 1) + +- name: addcdiv(Tensor self, Scalar value, Tensor tensor1, Tensor tensor2) -> Tensor + aten: addcdiv(self, tensor1, tensor2, value) + +- name: addcdiv_(Tensor(a!) self, Scalar value, Tensor tensor1, Tensor tensor2) -> Tensor(a!) + aten: addcdiv_(self, tensor1, tensor2, value) + +- name: addcdiv(Tensor self, Scalar value, Tensor tensor1, Tensor tensor2, *, Tensor(a!) out) -> Tensor(a!) + aten: addcdiv_out(out, self, tensor1, tensor2, value) + +- name: addcmul(Tensor self, Scalar value, Tensor tensor1, Tensor tensor2) -> Tensor + aten: addcmul(self, tensor1, tensor2, value) + +- name: addcmul_(Tensor(a!) self, Scalar value, Tensor tensor1, Tensor tensor2) -> Tensor(a!) + aten: addcmul_(self, tensor1, tensor2, value) + +- name: addcmul(Tensor self, Scalar value, Tensor tensor1, Tensor tensor2, *, Tensor(a!) out) -> Tensor(a!) + aten: addcmul_out(out, self, tensor1, tensor2, value) + +- name: addmm(Scalar beta, Tensor self, Scalar alpha, Tensor mat1, Tensor mat2) -> Tensor + aten: addmm(self, mat1, mat2, beta, alpha) + +- name: addmm_(Scalar beta, Tensor(a!) self, Scalar alpha, Tensor mat1, Tensor mat2) -> Tensor(a!) + aten: addmm_(self, mat1, mat2, beta, alpha) + +- name: addmm(Scalar beta, Tensor self, Scalar alpha, Tensor mat1, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) + aten: addmm_out(out, self, mat1, mat2, beta, alpha) + +- name: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2) -> Tensor + aten: addmm(self, mat1, mat2, beta, 1) + +- name: addmm_(Scalar beta, Tensor(a!) self, Tensor mat1, Tensor mat2) -> Tensor(a!) + aten: addmm_(self, mat1, mat2, beta, 1) + +- name: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) + aten: addmm_out(out, self, mat1, mat2, beta, 1) + +- name: sspaddmm(Scalar beta, Tensor self, Scalar alpha, Tensor mat1, Tensor mat2) -> Tensor + aten: sspaddmm(self, mat1, mat2, beta, alpha) + +- name: sspaddmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2) -> Tensor + aten: sspaddmm(self, mat1, mat2, beta, 1) + +- name: addmv(Scalar beta, Tensor self, Scalar alpha, Tensor mat, Tensor vec) -> Tensor + aten: addmv(self, mat, vec, beta, alpha) + +- name: addmv_(Scalar beta, Tensor(a!) self, Scalar alpha, Tensor mat, Tensor vec) -> Tensor(a!) + aten: addmv_(self, mat, vec, beta, alpha) + +- name: addmv(Scalar beta, Tensor self, Scalar alpha, Tensor mat, Tensor vec, *, Tensor(a!) out) -> Tensor(a!) + aten: addmv_out(out, self, mat, vec, beta, alpha) + +- name: addmv(Scalar beta, Tensor self, Tensor mat, Tensor vec) -> Tensor + aten: addmv(self, mat, vec, beta, 1) + +- name: addmv_(Scalar beta, Tensor(a!) self, Tensor mat, Tensor vec) -> Tensor(a!) + aten: addmv_(self, mat, vec, beta, 1) + +- name: addmv(Scalar beta, Tensor self, Tensor mat, Tensor vec, *, Tensor(a!) out) -> Tensor(a!) 
+ aten: addmv_out(out, self, mat, vec, beta, 1) + +- name: addr(Scalar beta, Tensor self, Scalar alpha, Tensor vec1, Tensor vec2) -> Tensor + aten: addr(self, vec1, vec2, beta, alpha) + +- name: addr_(Scalar beta, Tensor(a!) self, Scalar alpha, Tensor vec1, Tensor vec2) -> Tensor(a!) + aten: addr_(self, vec1, vec2, beta, alpha) + +- name: addr(Scalar beta, Tensor self, Scalar alpha, Tensor vec1, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!) + aten: addr_out(out, self, vec1, vec2, beta, alpha) + +- name: addr(Scalar beta, Tensor self, Tensor vec1, Tensor vec2) -> Tensor + aten: addr(self, vec1, vec2, beta, 1) + +- name: addr_(Scalar beta, Tensor(a!) self, Tensor vec1, Tensor vec2) -> Tensor(a!) + aten: addr_(self, vec1, vec2, beta, 1) + +- name: addr(Scalar beta, Tensor self, Tensor vec1, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!) + aten: addr_out(out, self, vec1, vec2, beta, 1) + +- name: baddbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2) -> Tensor + aten: baddbmm(self, batch1, batch2, beta, alpha) + +- name: baddbmm_(Scalar beta, Tensor(a!) self, Scalar alpha, Tensor batch1, Tensor batch2) -> Tensor(a!) + aten: baddbmm_(self, batch1, batch2, beta, alpha) + +- name: baddbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2, *, Tensor(a!) out) -> Tensor(a!) + aten: baddbmm_out(out, self, batch1, batch2, beta, alpha) + +- name: baddbmm(Scalar beta, Tensor self, Tensor batch1, Tensor batch2) -> Tensor + aten: baddbmm(self, batch1, batch2, beta, 1) + +- name: baddbmm_(Scalar beta, Tensor(a!) self, Tensor batch1, Tensor batch2) -> Tensor(a!) + aten: baddbmm_(self, batch1, batch2, beta, 1) + +- name: baddbmm(Scalar beta, Tensor self, Tensor batch1, Tensor batch2, *, Tensor(a!) out) -> Tensor(a!) + aten: baddbmm_out(out, self, batch1, batch2, beta, 1) + +- name: sub(Tensor self, Scalar alpha, Tensor other) -> Tensor + aten: sub(self, other, alpha) + +- name: sub_(Tensor(a!) self, Scalar alpha, Tensor other) -> Tensor(a!) + aten: sub_(self, other, alpha) + +- name: sub(Tensor self, Scalar alpha, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + aten: sub_out(out, self, other, alpha) diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/derivatives.yaml b/lib/python3.10/site-packages/torchgen/packaged/autograd/derivatives.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9f7ea3fbeb4ff4a5ec04578f3c3751b2870fc3b0 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/derivatives.yaml @@ -0,0 +1,3206 @@ +# Defines derivative formulas and Python signatures of methods on Variable +# +# Note about possibly confusing nomenclature: An 'output gradient' is the +# gradient of an output of a forward function. Output gradients are used as +# the inputs to backward functions. `grads` is a vector of output gradients, +# and `grad == grads[0]`, in all the derivative formulas in this file. +# An 'input gradient' is the gradient of an input to a forward function. +# Input gradients are the outputs of backward functions, corresponding to the +# input names included in the derivative formulas defined in this file. +# Also, every time we talk computing "gradient" we actually mean computing +# the vector jacobian product using the given 'output gradient' as the vector. +# +# Each entry consists of: +# - A 'name', which specifies the ATen name of the function you +# are defining derivatives for, and an argument specification. 
+# - An optional 'dispatch' entry which can be used to specify +# per-autograd dispatch key derivatives. If this entry is not +# specified, then the gradient entries will be taken as the +# default gradients (i.e. registered for every backward dispatch +# key). (see _test_autograd_multiple_dispatch for an example +# of how to register separate derivatives for different dispatch keys). +# The list of allowed dispatch keys (in addition to 'Default' which +# represents the Autograd alias key) is torchgen/model.py:AUTOGRAD_KEYS. +# - One or more gradient entries, mapping differentiable input +# names to a formula specifying how to compute its gradient. +# Note that a single gradient entry can specify the gradient +# formula for multiple input names, by specifying a key +# "input1, input2" (see atan2 for an example). +# - An argument can be flagged as 'non_differentiable'. +# - Optional entry with key 'output_differentiability' and value a list of the +# same length as the number of outputs from the forward function. The list +# should contain only booleans, specifying whether each output Tensor +# is differentiable. +# If it is not specified for a function that returns multiple elements but +# uses `grad` instead of `grads[idx]`, then all but the first output will +# be marked as non-differentiable. +# If none of the outputs are differentiable, you can also add the function +# name to `gen_variable_type.py`'s `DONT_REQUIRE_DERIVATIVE` list. +# +# There are two cases for Tensor and TensorList arguments here: +# - If that argument is differentiable, in the sense that a gradient with respect +# to that argument could exist. You should either: +# - Specify the formula for that gradient +# - Specify not_implemented("function_name") as a formula to say that this is not +# implemented yet (but might be in the future and the user can request that on an issue) +# - If that argument is not differentiable, because it is not a floating point dtype or the +# function is not differentiable with respect to that argument, for +# example. You should either: +# - Not specify any formula for this argument +# - Specify explicitly that this argument is "non_differentiable". Note that in this case, +# we trust you that this argument will never have requires_grad=True and it will be silently +# ignored if it does. +# +# If a function has out-of-place and in-place variants, then the derivative +# definition for the in-place variant is optional. It will default to the +# definition for the out-of-place variant. Note that _out variants are never +# differentiable. +# +# Gradient expressions are standard C++ expressions operating on ATen +# variables. In a gradient expression, the following variables/functions +# are in scope: +# +# - 'grad', the gradient of the output (often spelled grad_output +# in Python) which we are going to left-multiply. +# +# When a function returns multiple *differentiable* outputs, +# you can refer to the gradient of each output using 'grads', +# e.g., 'grads[0]', 'grads[1]'. +# +# When a function returns multiple *differentiable* outputs that +# are named, you can refer to the gradient of each output using +# 'grad_{name}', e.g., 'grad_x', 'grad_y'. +# +# When a function returns *one* differentiable output (the +# first output) and some more nondifferentiable outputs, +# you MUST refer to the gradient of the differentiable output with +# 'grad' (this case is special-cased in our code generation).
+# +# Note that the number of differentiable outputs can be modified by the +# 'output_differentiability' entry (see above). +# +# Across a differentiable function's derivatives set, it is not +# permitted to mix the use of "grad", "grads", and +# "grad_{name}". You must be consistent for that differentiable +# function. +# +# - Any of the input arguments, tensor or non-tensor, including +# argument names that only appear in Declarations.yaml, e.g. 'output'. +# +# - 'result', representing the result of evaluating the forward +# expression for ATen native function declarations. If the forward +# expression outputs a tuple, use 'resultX' instead to access the +# X-th entry. +# +# - 'grad_input_mask', a std::array<bool, N>, specifies which input +# gradients are actually needed. For example, in the entry +# `input0, input1: foo(grad_input_mask)`, `grad_input_mask` is a size +# two array, where `grad_input_mask[0]` is true if `input0` requires +# grad, and `grad_input_mask[1]` is true if `input1` requires grad. +# +# (NB: if your function computes the gradient for a list of tensors, +# the `grad_input_mask` will only have a single entry for the list, +# specifying whether at least one tensor from the list requires +# grad. If we want to support more fine-grained signalling, +# we'll need some alternate variable which is not a std::array) +# +# - 'retain_variables', a bool which is true if a user has specified +# that saved variables should be retained in case the backwards is +# run again later. This allows an optimization where we can +# destroy saved buffers if we know variables are not going to be retained, +# e.g., it is used by _cudnn_rnn +# +# - `wrap_opt_if` is a 2-argument function that accepts a tensor +# variable and a boolean condition that dictates whether to save that +# variable in a graph. The result of this function is `c10::optional<Tensor>`, +# and it is `::std::nullopt` when the condition evaluates to `false`, +# otherwise it is the variable wrapped in `c10::optional<Tensor>`. +# For example, wrap_opt_if(var_0, grad_input_mask[1] || grad_input_mask[2]) +# would mean that `var_0` is saved as long as the second (grad_input_mask[1]) +# or the third (grad_input_mask[2]) argument requires gradients. +# Another interpretation of this expression would read as `var_0` is needed +# in the backward computation of the second or the third argument. +# NOTE: the usage of `var_i.requires_grad()` in the conditional expression +# is not supported, use `grad_input_mask[i]` instead. +# NOTE: `wrap_opt_if` could be used to prevent saving redundant variables +# with multi-output backward formulas. +# See https://github.com/pytorch/pytorch/issues/97575 for more details +# on the issue. +# +# If you need a complex expression, e.g., with local variables, +# write a _backward function in torch/csrc/autograd/FunctionsManual.cpp +# and invoke it from here. By the way, go read +# https://github.com/zdevito/ATen/issues/163; it describes an +# important hazard that occurs when porting backwards from Python to C++. +# +# Double backwards gradient expressions can be somewhat confusing; +# the most important thing to remember is: (1) you need to define a +# derivative formula for every input, including inputs named things +# like 'grad_output', and (2) the gradient to multiply with is always +# called 'grad' (even though it really is a grad-grad). +# +# You can also add a forward derivative definition by defining a formula for +# a returned value (in general "result" if the name is not specified).
This +# formula works the same way as the backward one and advanced implementations +# should also be placed in the FunctionsManual file. +# This formula should compute a single Jacobian-vector product using the (primal) +# value of the argument "foo_p", its forward grad "foo_t" and the result of the +# function as "result". +# Note that the forward derivative can be automatically generated in two cases: +# - if your function is linear (NOT affine or multi-linear), then you can +# specify so by just using the string "auto_linear" for the formula. +# - if your function is applied element-wise (and has a single input), you +# can specify so by just using the string "auto_element_wise" for the formula. +# +# Note that to avoid unpacking overhead, functions taking TensorList as inputs +# will always have their forward grad formula called. This function is responsible +# for checking if any computation is needed and should return an undefined Tensor when +# there is nothing to do. You can check "cat_forward" for a full example. +# +# NB: There are a number of gradient definitions in here which are bogus +# (implemented using zeros_like). These gradients are (hopefully) not +# used by our frontend. You MUST check the frontend code; search for +# OpName.apply to see if it's still using a legacy Python style API. +# +# Note: Returning views. +# The following cases exist: +# - If a function returns no view, it can have arbitrary outputs. +# - If a function returns at least one Tensor that is a differentiable view +# of one of its inputs: +# - If there is only one differentiable output, this Tensor is marked as a +# differentiable view. (alias or transpose for example) +# - If there is more than one differentiable output, by default all the views are +# marked as differentiable views and created with allow_rebase_history=false. +# Meaning that any inplace operation on them will raise an error. (unbind for example) +# +# Notes about undefined output gradients: +# All backward functions must support all combinations of undefined output +# gradient Tensors, where `grad[i].defined() == false`. Depending on the +# number of input and output grads your derivative formula uses, code +# generation may automatically add some level of undefined grad support, +# according to these three cases: +# +# * 1 input grad and 1 output grad: +# Complete undefined grad support is automatically added, so you +# shouldn't have to think about it, unless there is a bug in the code +# generation. +# +# * 1 input grad and multiple output grads: +# Undefined grad support is automatically added ONLY in the case where +# all output grads are undefined. You will have to add explicit support +# for cases where a subset of output grads is undefined. +# +# * multiple input grads: +# No automatic support, so you will need to add it. +# +# If your derivative formula uses more than one output grad, it is usually +# preferable to add undefined grad support in the backward function itself +# (if you're using one), rather than in the derivative formula in this file. +# +# Undefined Tensors are created with the default constructor `at::Tensor()`. +# It is an efficient way to represent a Tensor filled with zeros because +# the Tensor holds no sizing information and no Storage data is allocated. +# But consequently, Tensor operations cannot be performed on them. +# Therefore, your backward function should treat an undefined output grad as +# a zero, and it needs to be a special case.
+# +# If all output grads are undefined, then it should be correct for the +# backward function to return undefined input grads. Since we use the chain +# rule, output grads equal to zero should result in input grads equal to zero, +# unless there is some rare special case. +# +# If a subset of output grads is undefined, then it may be acceptable for +# the backward function to return undefined input grads--it depends on the +# specific function, so you'll have to determine that yourself. If returning +# an undefined Tensor is correct for a given input grad, it is also logically +# correct to return a defined grad full of zeros, but that would not be +# preferable since it would be less efficient. +# +# NB: The parameter names here MUST be consistent with the parameter names +# in native_functions.yaml +- name: abs(Tensor self) -> Tensor + self: grad * self.sgn() + result: handle_r_to_c(result.scalar_type(), self_t.conj() * self_p.sgn()) + +- name: acos(Tensor self) -> Tensor + self: grad * -((-self * self + 1).rsqrt()).conj() + result: auto_element_wise + +- name: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), grad) + other: handle_r_to_c(other.scalar_type(), maybe_multiply(grad, alpha.conj())) + result: self_t + maybe_multiply(other_t, alpha) + +- name: add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), grad) + result: self_t.clone() + +- name: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + self: maybe_multiply(grad, beta.conj()) + batch1: maybe_multiply(grad.unsqueeze(0).expand_symint({ batch1.sym_size(0), batch1.sym_size(1), batch2.sym_size(2) }).bmm(batch2.transpose(1, 2).conj()), alpha.conj()) + batch2: maybe_multiply(batch1.transpose(1, 2).conj().bmm(grad.unsqueeze(0).expand_symint({ batch1.sym_size(0), batch1.sym_size(1), batch2.sym_size(2) })), alpha.conj()) + result: maybe_multiply(self_t, beta) + maybe_multiply(batch1_t.bmm(batch2_p).sum(0), alpha) + maybe_multiply(batch1_p.bmm(batch2_t).sum(0), alpha) + +- name: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), grad) + tensor1: handle_r_to_c(tensor1.scalar_type(), grad * (value / tensor2).conj()) + tensor2: handle_r_to_c(tensor2.scalar_type(), -grad * (value * tensor1 / (tensor2 * tensor2)).conj()) + result: self_t + maybe_multiply(tensor1_t / tensor2_p, value) - maybe_multiply(tensor2_t * (tensor1_p / tensor2_p) / tensor2_p, value) + +- name: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), grad) + tensor1: handle_r_to_c(tensor1.scalar_type(), grad * (tensor2 * value).conj()) + tensor2: handle_r_to_c(tensor2.scalar_type(), grad * (tensor1 * value).conj()) + result: self_t + maybe_multiply(tensor1_t * tensor2_p, value) + maybe_multiply(tensor2_t * tensor1_p, value) + +- name: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + self: maybe_multiply(grad, beta.conj()) + mat1: mm_mat1_backward(grad, mat2, mat1.sym_sizes(), mat1.sym_strides(), mat1.layout(), alpha) + mat2: mm_mat2_backward(grad, mat1, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), alpha) + result: maybe_multiply(self_t, beta) + maybe_multiply(mat1_t.mm(mat2_p), alpha) + maybe_multiply(mat1_p.mm(mat2_t), alpha) + +- name: _sparse_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> 
Tensor + self: maybe_multiply(grad, beta) + mat1: mm_mat1_sparse_backward(grad, mat1, mat2, alpha) + mat2: mm_mat2_backward(grad, mat1, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), alpha) + +- name: addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor + self: maybe_multiply(grad, beta.conj()) + mat: maybe_multiply(grad.ger(vec.conj()), alpha.conj()) + vec: maybe_multiply(mat.t().conj().mv(grad), alpha.conj()) + result: maybe_multiply(self_t, beta) + maybe_multiply(mat_t.mv(vec_p), alpha) + maybe_multiply(mat_p.mv(vec_t), alpha) + +- name: addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + self: maybe_multiply(grad, beta.conj()) + vec1: maybe_multiply(grad.mv(vec2.conj()), alpha.conj()) + vec2: maybe_multiply(grad.t().mv(vec1.conj()), alpha.conj()) + result: maybe_multiply(self_t, beta) + maybe_multiply(vec1_t.outer(vec2_p), alpha) + maybe_multiply(vec1_p.outer(vec2_t), alpha) + +- name: affine_grid_generator(Tensor theta, SymInt[] size, bool align_corners) -> Tensor + theta: affine_grid_generator_backward_symint(grad, size, align_corners) + +- name: alias(Tensor(a) self) -> Tensor(a) + self: grad + result: self_t + +- name: angle(Tensor self) -> Tensor + self: angle_backward(grad, self) + result: handle_r_to_c(result.scalar_type(), angle_backward(self_t.conj(), self_p).conj()) + +# The four items below are necessary because TensorIterator doesn't work on +# Variables (codegen does not unwrap the input Tensor for all() and any() ). +- name: any(Tensor self) -> Tensor + output_differentiability: [False] + +- name: any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor + output_differentiability: [False] + +- name: any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor + output_differentiability: [False] + +- name: _is_all_true(Tensor self) -> Tensor + self: non_differentiable + +- name: _is_any_true(Tensor self) -> Tensor + self: non_differentiable + +- name: all(Tensor self) -> Tensor + output_differentiability: [False] + +- name: all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor + output_differentiability: [False] + +- name: all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor + output_differentiability: [False] + +- name: acosh(Tensor self) -> Tensor +# Save one rsqrt in the real case by using that for x real and positive sqrt(x*y) = sqrt(x)*sqrt(y) (not true in the complex case) + self: "self.is_complex() ? grad * ((self + 1).rsqrt() * (self - 1).rsqrt()).conj() : grad * (self * self - 1).rsqrt()" + result: auto_element_wise + +- name: acosh_(Tensor(a!) self) -> Tensor(a!) + self: not_implemented("inplace version of acosh") + +- name: asinh(Tensor self) -> Tensor + self: grad * (self.pow(2) + 1).rsqrt().conj() + result: auto_element_wise + +- name: asinh_(Tensor(a!) self) -> Tensor(a!) + self: not_implemented("inplace version of asinh") + +- name: atanh(Tensor self) -> Tensor + self: grad * 1 / (1 - self.pow(2)).conj() + result: auto_element_wise + +- name: atanh_(Tensor(a!) self) -> Tensor(a!) + self: not_implemented("inplace version of atanh") + +- name: as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a) + self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset) + result: auto_linear + +- name: as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!) 
+ self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset) + result: auto_linear + +- name: asin(Tensor self) -> Tensor + self: grad * (-self * self + 1).rsqrt().conj() + result: auto_element_wise + +- name: atan(Tensor self) -> Tensor + self: grad / (self * self + 1).conj() + result: auto_element_wise + +- name: atan2(Tensor self, Tensor other) -> Tensor + self, other: atan2_backward(grad, self, other, grad_input_mask) + result: (-self_p * other_t + other_p * self_t) / (self_p.pow(2) + other_p.pow(2)) + +- name: baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + self: maybe_multiply(grad, beta.conj()) + batch1: maybe_multiply(grad.bmm(batch2.transpose(1, 2).conj()), alpha.conj()) + batch2: maybe_multiply(batch1.transpose(1, 2).conj().bmm(grad), alpha.conj()) + result: maybe_multiply(self_t, beta) + maybe_multiply(batch1_t.bmm(batch2_p), alpha) + maybe_multiply(batch1_p.bmm(batch2_t), alpha) + +- name: bernoulli(Tensor self, *, Generator? generator=None) -> Tensor + self: zeros_like(grad) + result: auto_element_wise + +- name: bernoulli_.Tensor(Tensor(a!) self, Tensor p, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + p: zeros_like(p) + result: self_t.zero_() + +- name: bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: bmm(Tensor self, Tensor mat2) -> Tensor + self: grad.bmm(mat2.transpose(1, 2).conj()) + mat2: self.transpose(1, 2).conj().bmm(grad) + result: self_t.bmm(mat2_p) + self_p.bmm(mat2_t) + +- name: matmul(Tensor self, Tensor other) -> Tensor + self, other: matmul_backward(grad, self, other, grad_input_mask) + +- name: cat(Tensor[] tensors, int dim=0) -> Tensor + tensors: cat_tensors_backward(grad, to_args_sizes_symint(tensors), to_args_scalartypes(tensors), dim) + result: cat_jvp(tensors, dim) + +- name: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: ceil(Tensor self) -> Tensor + self: zeros_like(grad) + result: auto_element_wise + +- name: cholesky(Tensor self, bool upper=False) -> Tensor + self: cholesky_backward(grad, upper, result) + +- name: chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[] + dispatch: + Default: + # the default case will use the CompositeImplicitAutograd + self: not_implemented("chunk") + AutogradNestedTensor: + self: chunk_backward_nested(grads, self, chunks, dim) + +- name: linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info) + self: cholesky_backward(grad, upper, L) + L: cholesky_jvp(self_t, L, upper) + +- name: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor + self, input2: cholesky_solve_backward(grad, self, input2, result, upper, grad_input_mask) + result: cholesky_solve_jvp(result, input2_p, input2_t, self_t, upper) + +- name: cholesky_inverse(Tensor self, bool upper=False) -> Tensor + self: cholesky_inverse_backward(grad, self, upper, result) + result: cholesky_inverse_jvp(self_p, self_t, result, upper) + +# For clamp, gradient is not defined at the boundaries. But empirically it's helpful +# to be able to get gradient on min and max, so we return the subgradient 1 for these cases. +- name: clamp.Tensor(Tensor self, Tensor? min=None, Tensor? 
max=None) -> Tensor + self: clamp_backward(grad, self, min, max) + min, max: clamp_backward_min_max(grad, self, min, max, grad_input_mask) + result: clamp_jvp(self_p, self_t, min_p, min_t, max_p, max_t) + +- name: clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor + self: clamp_backward(grad, self, min, max) + result: auto_element_wise + +- name: clamp_min(Tensor self, Scalar min) -> Tensor + self: where(self >= min, grad, at::scalar_tensor(0., grad.options())) + result: auto_element_wise + +- name: clamp_min.Tensor(Tensor self, Tensor min) -> Tensor + self: where(self >= min, grad, at::scalar_tensor(0., grad.options())) + min: where(self < min, grad, at::scalar_tensor(0., grad.options())) + result: where(self_p >= min_p, self_t, min_t) + +- name: clamp_max(Tensor self, Scalar max) -> Tensor + self: where(self <= max, grad, at::scalar_tensor(0., grad.options())) + result: auto_element_wise + +- name: clamp_max.Tensor(Tensor self, Tensor max) -> Tensor + self: where(self <= max, grad, at::scalar_tensor(0., grad.options())) + max: where(self > max, grad, at::scalar_tensor(0., grad.options())) + result: where(self_p <= max_p, self_t, max_t) + +- name: clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor + self: grad + result: auto_linear + +- name: _lazy_clone(Tensor self) -> Tensor + self: grad + result: auto_linear + +- name: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor + self: _to_copy_backward(grad, self.options()) + result: _to_copy(self_t, dtype, layout, device, pin_memory, non_blocking, memory_format) + # The condition is: if dtype is not nullopt, then isDifferentiableType(*dtype) + # (If dtype IS nullopt, we rely on the regular check that any input requires grad). + output_differentiability: ["!dtype || isDifferentiableType(*dtype)"] + +- name: _coalesce(Tensor self) -> Tensor + self: grad + +- name: complex(Tensor real, Tensor imag) -> Tensor + real: at::real(grad) + imag: at::imag(grad) + result: at::complex(real_t, imag_t) + +- name: polar(Tensor abs, Tensor angle) -> Tensor + abs, angle: polar_backward(grad, result) + result: at::complex(abs_t*angle_p.cos() - angle_t*abs_p*angle_p.sin(), abs_t*angle_p.sin() + angle_t*abs_p*angle_p.cos()) + +- name: _conj(Tensor(a) self) -> Tensor(a) + self: grad.conj() + result: self_t.conj() + +- name: _neg_view(Tensor(a) self) -> Tensor(a) + self: grad.neg() + result: self_t._neg_view() + +- name: _conj_physical(Tensor self) -> Tensor + self: grad.conj_physical() + result: self_t.conj_physical() + +- name: conj_physical_(Tensor(a!) self) -> Tensor(a!) + self: grad.conj_physical() + result: self_t.conj_physical_() + +- name: copysign.Tensor(Tensor self, Tensor other) -> Tensor + self: copysign_tensor_self_backward(grad, self, result) + other: zeros_like(other) + result: copysign_tensor_self_backward(self_t, self_p, result) + +- name: copysign.Scalar(Tensor self, Scalar other) -> Tensor + self: copysign_tensor_self_backward(grad, self, result) + result: auto_element_wise + +- name: cos(Tensor self) -> Tensor + self: grad * -self.sin().conj() + result: auto_element_wise + +- name: cosh(Tensor self) -> Tensor + self: grad * self.sinh().conj() + result: auto_element_wise + +- name: count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor + output_differentiability: [False] + +- name: count_nonzero(Tensor self, int? 
dim=None) -> Tensor + output_differentiability: [False] + +- name: linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor + self: at::linalg_cross(other.conj(), grad, dim) + other: at::linalg_cross(grad, self.conj(), dim) + result: "at::linalg_cross(self_t, other_p, dim) + at::linalg_cross(self_p, other_t, dim)" + +- name: logcumsumexp(Tensor self, int dim) -> Tensor + self: logcumsumexp_backward(grad, self, result, dim) + result: logcumsumexp_jvp(self_p, self_t, dim) + +- name: cumprod(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor + self: cumprod_backward(grad.to(self.scalar_type()), self, dim, result) + result: "cumprod_jvp(self_t, self_p, result, dim).to(dtype.has_value() ? *dtype : self_p.scalar_type())" + +- name: cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor + self: cumsum_backward(grad.to(self.scalar_type()), dim) + result: auto_linear + +- name: cummax(Tensor self, int dim) -> (Tensor values, Tensor indices) + self: cummaxmin_backward(grad, self, indices, dim) + values: self_t.gather(dim, indices) + +- name: cummin(Tensor self, int dim) -> (Tensor values, Tensor indices) + self: cummaxmin_backward(grad, self, indices, dim) + values: self_t.gather(dim, indices) + +- name: conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad=0) -> Tensor + self, weight, bias: "grad.defined() ? conv_tbc_backward(grad, self, weight, bias, pad) : std::tuple()" + +- name: _ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor) + log_probs: _ctc_loss_backward(grad, log_probs, targets, input_lengths, target_lengths, result0, result1, blank, zero_infinity) + +- name: _ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor) + log_probs: _ctc_loss_backward(grad, log_probs, targets, input_lengths, target_lengths, result0, result1, blank, zero_infinity) + +- name: deg2rad(Tensor self) -> Tensor + self: deg2rad_backward(grad) + result: auto_element_wise + +- name: _linalg_det(Tensor A) -> (Tensor result, Tensor LU, Tensor pivots) + A: linalg_det_backward(grad, result, A, LU, pivots) + result: linalg_det_jvp(A_t, result, LU, pivots, A_p.is_contiguous() && !A_p.is_complex()) + output_differentiability: [True, False, False] + +- name: _linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet, Tensor LU, Tensor pivots) + A: slogdet_backward(grad_sign, grad_logabsdet, A, sign, LU, pivots) + sign, logabsdet: slogdet_jvp(LU, pivots, A_t, sign, A_p.is_contiguous() && !A_p.is_complex()) + output_differentiability: [True, True, False, False] + +- name: block_diag(Tensor[] tensors) -> Tensor + tensors: block_diag_backward(grad, to_args_sizes(tensors), to_args_scalartypes(tensors)) + result: block_diag_jvp(tensors) + +- name: diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor + self: grad.diagonal(offset, dim1, dim2) + result: auto_linear + +- name: diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a) + self: diagonal_backward_symint(grad, self.sym_sizes(), offset, dim1, dim2) + result: auto_linear + +- name: diagonal_backward(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2) -> Tensor + grad_output: grad.diagonal(offset, dim1, dim2) + result: auto_linear + +- name: dist(Tensor self, Tensor other, Scalar p=2) -> Tensor + self: norm_backward(grad, self - other, p, result) + other: -norm_backward(grad, self - other, p, 
result) + result: norm_jvp(self_p - other_p, self_t - other_t, p, result, {}, false) + +# The backward formula is done in this order to improve numerical stability +# of the higher order derivatives, see https://github.com/pytorch/pytorch/issues/43414 +# Note that we don't use "result" because saving it would be BC-breaking when it is used in an inplace operation later +- name: div.Tensor(Tensor self, Tensor other) -> Tensor + self: div_tensor_self_backward(grad, other, self.scalar_type()) + other: div_tensor_other_backward(grad, self, other) + result: (self_t - other_t * result) / other_p + +- name: div.Scalar(Tensor self, Scalar other) -> Tensor + self: div_tensor_self_backward(grad, other, self.scalar_type()) + result: self_t / other + +- name: div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor + self: div_tensor_self_backward(grad, other, self.scalar_type(), rounding_mode) + other: div_tensor_other_backward(grad, self, other, rounding_mode) + result: "rounding_mode.has_value() ? result.new_zeros_symint(result.sym_sizes()) : self_t / other_p - other_t * (self_p / other_p) / other_p" + +- name: div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor + self: div_tensor_self_backward(grad, other, self.scalar_type(), rounding_mode) + result: "rounding_mode.has_value() ? result.new_zeros_symint(result.sym_sizes()) : self_t / other" + +- name: dot(Tensor self, Tensor tensor) -> Tensor + self: grad * tensor.conj() + tensor: grad * self.conj() + result: at::dot(self_t, tensor_p) + at::dot(self_p, tensor_t) + +- name: vdot(Tensor self, Tensor other) -> Tensor + self: grad.conj() * other + other: grad * self + result: at::vdot(self_t, other_p) + at::vdot(self_p, other_t) + +- name: _fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor) + self: _fused_dropout_backward(grad, result1, p) + +- name: native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor) + input: "GradMode::is_enabled() ? infinitely_differentiable_native_dropout_backward(grad, result1, (!train.has_value() || !train.value() ? 1 : (p == 1 ? 0.0 : 1.0 / (1.0 - p)))) : native_dropout_backward(grad, result1, (!train.has_value() || !train.value() ? 1 : (p == 1 ? 0.0 : 1.0 / (1.0 - p))))" + result0: "(!train.has_value() || train.value()) ? (p == 1 ? 0.0 : 1.0 / (1.0 - p)) * input_t * result1 : input_t" + +- name: native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor + grad_output: "native_dropout_double_backward(grad, grad_output, mask, scale)" + mask: 'not_implemented("native_dropout_backward: mask")' + +- name: eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + self: zeros_like(self) + result: self_t.zero_() + +- name: eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+ self: zeros_like(self) + other: zeros_like(other) + result: self_t.zero_() + +- name: erf(Tensor self) -> Tensor + self: 2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad + result: auto_element_wise + +- name: erfc(Tensor self) -> Tensor + self: -2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad + result: auto_element_wise + +- name: special_erfcx(Tensor self) -> Tensor + self: (2.0 * self * result - 2.0 / sqrt(M_PI)) * grad + result: auto_element_wise + +- name: erfinv(Tensor self) -> Tensor + self: 0.5 * sqrt(M_PI) * exp(self.erfinv().pow(2)) * grad + result: auto_element_wise + +- name: exp(Tensor self) -> Tensor + self: grad * result.conj() + result: auto_element_wise + +- name: exp2(Tensor self) -> Tensor + self: grad * result.conj() * M_LN2 + result: auto_element_wise + +- name: expm1(Tensor self) -> Tensor + self: grad * (result.conj() + 1) + result: auto_element_wise + +# TODO: this derivative is not SymInt safe, need sum_to support +- name: expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) + self: at::sum_to(grad, self.sym_sizes()) + result: auto_linear + +- name: exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask) + self: fake_quantize_per_tensor_affine_cachemask_backward(grad, mask) + +- name: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, Tensor fake_quant_enabled, int quant_min, int quant_max) -> (Tensor output, Tensor mask) + self: fake_quantize_per_tensor_affine_cachemask_backward(grad, mask) + +- name: _fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor + self, scale, zero_point: "grad.defined() ? _fake_quantize_learnable_per_tensor_affine_backward(grad, self, scale, zero_point, quant_min, quant_max, grad_factor) : std::tuple()" + +- name: fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask) + self: fake_quantize_per_channel_affine_cachemask_backward(grad, mask) + +- name: _fake_quantize_learnable_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor + self, scale, zero_point: "grad.defined() ? _fake_quantize_learnable_per_channel_affine_backward(grad, self, scale, zero_point, axis, quant_min, quant_max, grad_factor) : std::tuple()" + +- name: _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) + self: fake_quantize_per_tensor_affine_cachemask_backward(grad, mask) + +- name: fill.Scalar(Tensor self, Scalar value) -> Tensor + self: zeros_like(grad) + result: at::fill(self_t, 0) + +- name: fill.Tensor(Tensor self, Tensor value) -> Tensor + self: zeros_like(grad) + value: grad.sum() + result: at::fill(self_t, value_t) + +- name: fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.fill_(0) + +- name: fill_.Tensor(Tensor(a!) 
self, Tensor value) -> Tensor(a!) + self: zeros_like(grad) + value: grad.sum() + result: self_t.fill_(value_t) + +- name: floor(Tensor self) -> Tensor + self: zeros_like(grad) + result: auto_element_wise + +- name: fmod.Scalar(Tensor self, Scalar other) -> Tensor + self: grad + result: auto_element_wise + +- name: fmod.Tensor(Tensor self, Tensor other) -> Tensor + self: grad + other: -grad * self.div(other, /*rounding_mode=*/"trunc") + result: self_t - other_t * self_p.div(other_p, /*rounding_mode=*/"trunc") + +- name: frac(Tensor self) -> Tensor + self: grad + result: self_t + +- name: frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent) + self: grad / exponent.exp2() + mantissa: self_t / exponent.exp2() + +- name: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor + self: gather_backward(grad, self, dim, index, sparse_grad) + index: non_differentiable + result: auto_linear + +- name: ge_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + self: zeros_like(self) + result: self_t.zero_() + +- name: ge_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + self: zeros_like(self) + other: zeros_like(other) + result: self_t.zero_() + +- name: geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: geqrf(Tensor self) -> (Tensor a, Tensor tau) + self: not_implemented("geqrf") + +- name: indices(Tensor(a) self) -> Tensor(a) + output_differentiability: [False] + +- name: _indices(Tensor(a) self) -> Tensor(a) + output_differentiability: [False] + +- name: crow_indices(Tensor(a) self) -> Tensor(a) + output_differentiability: [False] + +- name: col_indices(Tensor(a) self) -> Tensor(a) + output_differentiability: [False] + +- name: ccol_indices(Tensor(a) self) -> Tensor(a) + output_differentiability: [False] + +- name: row_indices(Tensor(a) self) -> Tensor(a) + output_differentiability: [False] + +- name: grid_sampler_2d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + input, grid: "grad.defined() ? grid_sampler_2d_backward(grad, input, grid, interpolation_mode, padding_mode, align_corners, grad_input_mask) : std::tuple()" + +- name: grid_sampler_3d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + input, grid: "grad.defined() ? grid_sampler_3d_backward(grad, input, grid, interpolation_mode, padding_mode, align_corners, grad_input_mask) : std::tuple()" + +# See NOTE [ grid_sample CPU fallback ] +- name: _grid_sampler_2d_cpu_fallback(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + input, grid: "grad.defined() ? _grid_sampler_2d_cpu_fallback_backward(grad, input, grid, interpolation_mode, padding_mode, align_corners) : std::tuple()" + +- name: gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + self: zeros_like(self) + result: self_t.zero_() + +- name: gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+ self: zeros_like(self) + other: zeros_like(other) + result: self_t.zero_() + +- name: hardsigmoid(Tensor self) -> Tensor + self: hardsigmoid_backward(grad, self) + result: auto_element_wise + +- name: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor + output_differentiability: [False] + +- name: hardswish(Tensor self) -> Tensor + self: hardswish_backward(grad, self) + result: auto_element_wise + +- name: hardswish_backward(Tensor grad_output, Tensor self) -> Tensor + grad_output: hardswish_backward(grad, self) + self: at::where(at::logical_and(-3.0 < self, self < 3.0), grad * grad_output / 3.0, at::zeros({}, self.options())) + result: "hardswish_backward(grad_output_t, self_p) + + at::where(at::logical_and(-3.0 < self_p, self_p < 3.0), self_t * grad_output_p / 3.0, at::zeros({}, self_p.options()))" + +- name: hypot(Tensor self, Tensor other) -> Tensor + self: grad * self / result + other: grad * other / result + result: self_t * self_p / result + other_t * other_p / result + +- name: i0(Tensor self) -> Tensor + self: grad * at::special_i1(self) + result: auto_element_wise + +- name: special_i0e(Tensor self) -> Tensor + self: grad * (at::special_i1e(self) - self.sgn() * result) + result: auto_element_wise + +- name: special_i1(Tensor self) -> Tensor + self: i1_backward(grad, self, result) + result: auto_element_wise + +- name: special_i1e(Tensor self) -> Tensor + self: i1e_backward(grad, self, result) + result: auto_element_wise + +- name: igamma(Tensor self, Tensor other) -> Tensor + self: 'not_implemented("igamma: input")' + other: grad * exp((self - 1) * log(other) - other - lgamma(self)) + +- name: igammac(Tensor self, Tensor other) -> Tensor + self: 'not_implemented("igammac: input")' + other: -grad * exp((self - 1) * log(other) - other - lgamma(self)) + +- name: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor + self: index_backward(grad.new_zeros_symint(self.sym_sizes(), self.options()), indices, grad) + result: auto_linear + +- name: _unsafe_index.Tensor(Tensor self, Tensor?[] indices) -> Tensor + self: at::_unsafe_index_put(grad.new_zeros_symint(self.sym_sizes(), self.options()), indices, grad, true) + result: auto_linear + +- name: _unsafe_masked_index(Tensor self, Tensor mask, Tensor?[] indices, Scalar fill) -> Tensor + self: at::_unsafe_masked_index_put_accumulate(grad.new_zeros_symint(self.sym_sizes(), self.options()), mask, indices, grad) + mask: non_differentiable + result: _unsafe_masked_index(self_t, mask, indices, 0) + +- name: _unsafe_masked_index_put_accumulate(Tensor self, Tensor mask, Tensor?[] indices, Tensor values) -> Tensor + self: grad + mask: non_differentiable + values: at::_unsafe_masked_index(grad, mask, indices, 0) + result: at::_unsafe_masked_index_put_accumulate(self_t, mask, indices, values_t) + +- name: index_add(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor + self: grad + # The case source.dim() == 0 is necessary to support scalar tensors of the form + # source.dim() == 0 and index.dim() == 1 and index.size() == (1,), + # This is because source is not broadcastable to index, as source.dim() < index.dim() + source: "maybe_multiply(source.dim() > 0 ? 
grad.index_select(dim, index).expand_as(source) : grad.index_select(dim, index.squeeze(0)), alpha)" + index: non_differentiable + result: at::index_add(self_t, dim, index, maybe_multiply(source_t, alpha)) + +- name: index_reduce(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor + self, source: index_reduce_backward(grad, self, dim, index, source, reduce, include_self, result) + index: non_differentiable + +- name: index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor + self: grad.index_fill(dim, index, 0) + # The case source.dim() == 0 is necessary to support scalar tensors of the form + # source.dim() == 0 and index.dim() == 1 and index.size() == (1,), + # This is because source is not broadcastable to index, as source.dim() < index.dim() + source: "source.dim() > 0 ? grad.index_select(dim, index).expand_as(source) : grad.index_select(dim, index.squeeze(0))" + index: non_differentiable + result: self_t.index_copy(dim, index, source_t) + +- name: index_fill.int_Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor + self: grad.index_fill(dim, index, 0) + index: non_differentiable + result: self_t.index_fill(dim, index, 0) + +- name: index_fill.int_Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor + self: grad.index_fill(dim, index, 0) + value: grad.index_select(dim, std::get<0>(at::_unique(index, /*sorted=*/false))).sum() + index: non_differentiable + result: self_t.index_fill(dim, index, value_t) + +- name: index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor + self: "accumulate ? grad : grad.index_put(indices, zeros_like(values), false)" + values: grad.index(indices) + result: self_t.index_put(indices, values_t, accumulate) + +- name: _unsafe_index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor + self: "accumulate ? grad : at::_unsafe_index_put(grad, indices, zeros_like(values), false)" + values: at::_unsafe_index(grad, indices) + result: at::_unsafe_index_put(self_t, indices, values_t, accumulate) + +- name: _index_put_impl_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!) + self: "accumulate ? grad : grad.index_put(indices, zeros_like(values), false)" + values: grad.index(indices) + result: at::_index_put_impl_(self_t, indices, values_t, accumulate, unsafe) + +- name: index_select(Tensor self, int dim, Tensor index) -> Tensor + self: index_select_backward_symint(grad, self.sym_sizes(), dim, index) + index: non_differentiable + result: auto_linear + +- name: linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info) + A: -at::matmul(inverse.mH(), at::matmul(grad, inverse.mH())) + inverse: -at::matmul(at::matmul(inverse, A_t), inverse) + output_differentiability: [True, False] + +- name: linalg_pinv.atol_rtol_tensor(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor + self: pinv_backward(grad, result, self) + result: pinv_jvp(self_p, result, self_t) + +- name: isnan(Tensor self) -> Tensor + self: non_differentiable + +- name: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) + +- name: le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 
+ self: zeros_like(self) + result: self_t.zero_() + +- name: le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + self: zeros_like(self) + other: zeros_like(other) + result: self_t.zero_() + +- name: lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor + self: "weight.isComplex() ? grad * (1 - weight.conj().toComplexDouble()) : grad * (1 - weight.toDouble())" + end: grad * weight.conj() + result: at::lerp(self_t, end_t, weight) + +- name: lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor + self: grad * (1 - weight).conj() + end: grad * weight.conj() + weight: grad * (end - self).conj() + result: at::lerp(self_t, end_t, weight_p) + weight_t * (end_p - self_p) + +- name: lgamma(Tensor self) -> Tensor + self: grad * digamma(self) + result: auto_element_wise + +- name: digamma(Tensor self) -> Tensor + self: grad * polygamma(1, self) + result: auto_element_wise + +- name: polygamma(int n, Tensor self) -> Tensor + self: grad * polygamma(n + 1, self) + result: auto_element_wise + +- name: polygamma_(Tensor(a!) self, int n) -> Tensor(a!) + self: grad * polygamma(n + 1, self) + result: self_t.mul_(polygamma(n + 1, original_self_p)) + +- name: log(Tensor self) -> Tensor + self: grad.div(self.conj()) + result: auto_element_wise + +- name: log10(Tensor self) -> Tensor + self: grad / (self.conj() * 2.3025850929940456) + result: auto_element_wise + +- name: log1p(Tensor self) -> Tensor + self: log1p_backward(grad, self) + result: auto_element_wise + +- name: log2(Tensor self) -> Tensor + self: grad / (self.conj() * 0.6931471805599453) + result: auto_element_wise + +- name: logaddexp(Tensor self, Tensor other) -> Tensor + self: grad / (1 + exp(other - self)).conj() + other: grad / (1 + exp(self - other)).conj() + result: self_t / (1 + exp(other_p - self_p)) + other_t / (1 + exp(self_p - other_p)) + +- name: logaddexp2(Tensor self, Tensor other) -> Tensor + self: grad / (1 + pow(2, other - self)) + other: grad / (1 + pow(2, self - other)) + result: self_t / (1 + pow(2, other_p - self_p)) + other_t / (1 + pow(2, self_p - other_p)) + +# Note [Gradient formula for xlogy at x = 0, y <= 0] +# x * log(y) is not defined at y <= 0, so we cannot even talk about differentiability +# Now, xlogy(0, y) = 0 by definition. +# This does not make it differentiable as it's not defined in a neighbourhood of a point +# (0, y) when y <= 0. +# Now, when a function is non-differentiable, sometimes we return "a relatively sensible value" +# In this case, as per the discussion in https://github.com/pytorch/pytorch/issues/80770, we choose +# this value to be zero, which is the directional derivative along the line {x = 0}. +- name: xlogy.Tensor(Tensor self, Tensor other) -> Tensor + self: at::xlogy(grad, other).masked_fill((self == 0.) & (other <= 0.), 0.) + other: grad * self / other + result: at::xlogy(self_t, other_p).masked_fill((self_p == 0.) & (other_p <= 0.), 0.) + other_t * self_p / other_p + +- name: xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor + other: grad * self / other + result: auto_element_wise + +- name: xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor + self: "other.toDouble() > 0. + ? at::xlogy(grad, other) + : at::xlogy(grad, other).masked_fill(self == 0., 0.)" + result: auto_element_wise + +# See Note [Gradient formula for xlogy at x = 0, y <= 0] +# Same here but with y <= -1 +- name: special_xlog1py(Tensor self, Tensor other) -> Tensor + self: at::special_xlog1py(grad, other).masked_fill((self == 0.) & (other <= -1.), 0.) 
+ other: grad * self / (other + 1) + result: at::special_xlog1py(self_t, other_p).masked_fill((self_p == 0.) & (other_p <= -1.), 0.) + other_t * self_p / (other_p + 1) + +- name: special_xlog1py.self_scalar(Scalar self, Tensor other) -> Tensor + other: grad * self / (other + 1) + result: auto_element_wise + +- name: special_xlog1py.other_scalar(Tensor self, Scalar other) -> Tensor + self: "other.toDouble() > -1. + ? at::special_xlog1py(grad, other) + : at::special_xlog1py(grad, other).masked_fill(self == 0., 0.)" + result: auto_element_wise + +- name: special_zeta(Tensor self, Tensor other) -> Tensor + self: not_implemented("zeta") + other: grad * -self * special_zeta(self + 1., other) + +- name: special_zeta.self_scalar(Scalar self, Tensor other) -> Tensor + other: grad * -self * special_zeta(self.toDouble() + 1., other) + +- name: special_zeta.other_scalar(Tensor self, Scalar other) -> Tensor + self: not_implemented("zeta") + +- name: log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor + self: logsumexp_backward(grad, self, result, dim, keepdim) + result: logsumexp_jvp(self_p, self_t, dim, keepdim) + +- name: linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values) + self, b: linalg_lstsq_backward(grad, self, b, grad_input_mask) + solution: linalg_lstsq_jvp(self_p, b_p, self_t, b_t) + output_differentiability: [True, False, False, False] + +- name: lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + self: zeros_like(self) + result: self_t.zero_() + +- name: lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + self: zeros_like(self) + other: zeros_like(other) + result: self_t.zero_() + +- name: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info) + A: lu_factor_ex_backward(grad, LU, pivots, pivot) + LU: lu_factor_ex_jvp(A_t, LU, pivots, pivot) + output_differentiability: [True, False, False] + +- name: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots) + A: lu_factor_ex_backward(grad, LU, pivots, pivot) + LU: lu_factor_ex_jvp(A_t, LU, pivots, pivot) + output_differentiability: [True, False] + +- name: linalg_lu(Tensor A, *, bool pivot=True) -> (Tensor P, Tensor L, Tensor U) + A: linalg_lu_backward(grad_L, grad_U, P, L, U, pivot) + L: std::get<0>(linalg_lu_jvp(A_t, P, L, U, pivot)) + U: std::get<1>(linalg_lu_jvp(A_t, P, L, U, pivot)) + output_differentiability: [False, True, True] + +- name: linalg_lu_solve(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False) -> Tensor + LU: linalg_lu_solve_LU(grad, LU, pivots, result, left, adjoint) + B: "at::linalg_lu_solve(LU, pivots, grad, left, !adjoint)" + result: linalg_lu_solve_jvp(result, LU_p, pivots, LU_t, B_t, left, adjoint) + +- name: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U) + LU_data: lu_unpack_backward(grad_L, grad_U, LU_data.sym_size(-2), LU_data.sym_size(-1)) + LU_pivots: non_differentiable + L: "LU_data_t.sym_size(-2) >= LU_data_t.sym_size(-1) ? LU_data_t.tril(-1) : LU_data_t.narrow_symint(-1, 0, LU_data_t.sym_size(-2)).tril(-1)" + U: "LU_data_t.sym_size(-1) >= LU_data_t.sym_size(-2) ? 
LU_data_t.triu() : LU_data_t.narrow_symint(-2, 0, LU_data_t.sym_size(-1)).triu()" + output_differentiability: [False, True, True] + +- name: masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor + self: grad.masked_fill(mask, 0) + mask: non_differentiable + result: self_t.masked_fill(mask, 0) + +- name: masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor + self: grad.masked_fill(mask, 0) + value: masked_fill_backward(grad, mask) + mask: non_differentiable + result: self_t.masked_fill(mask, value_t) + +- name: masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor + self: grad.masked_fill(mask, 0) + source: masked_scatter_backward_symint(grad, mask, source.sym_sizes()) + mask: non_differentiable + result: self_t.masked_scatter(mask, source_t) + +- name: masked_scatter_backward(Tensor grad_output, Tensor mask, SymInt[] sizes) -> Tensor + grad_output: zeros_like(grad_output).masked_scatter(mask, grad) + mask: non_differentiable + result: masked_scatter_backward(grad_output_t, mask, grad_output_t.sizes()) + +- name: masked_select(Tensor self, Tensor mask) -> Tensor + self: masked_select_backward(grad, self, mask) + mask: non_differentiable + result: auto_linear + +- name: linalg_matrix_exp(Tensor self) -> Tensor + self: linalg_matrix_exp_differential(self, grad, /*adjoint*/ true) + result: linalg_matrix_exp_differential(self_p, self_t, /*adjoint*/ false) + +- name: max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) + +- name: max(Tensor self) -> Tensor + self: evenly_distribute_backward(grad, self, result) + result: evenly_read_jvp(self_t, self_p, result) + +- name: maximum(Tensor self, Tensor other) -> Tensor + self: at::where(self == other, grad / 2, grad).masked_fill_(self < other, 0) + other: at::where(self == other, grad / 2, grad).masked_fill_(self > other, 0) + result: other_t + at::where(self_p == other_p, at::scalar_tensor(0.5, result.options()), (self_p > other_p).to(result.scalar_type())) * (self_t - other_t) + +- name: fmax(Tensor self, Tensor other) -> Tensor + self: grad.masked_fill((self >= other).logical_or_(other.isnan()).logical_not_(), 0) + other: grad.masked_fill((self >= other).logical_or_(other.isnan()), 0) + result: other_t + (self_p > other_p).logical_or_(other_p.isnan()) * (self_t - other_t) + +- name: mean(Tensor self, *, ScalarType? dtype=None) -> Tensor + self: grad.expand_symint(self.sym_sizes()) / self.sym_numel() + result: auto_linear + +- name: mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + self: mean_backward(grad, self.sym_sizes(), dim, self.sym_numel(), keepdim) + result: auto_linear + +- name: median(Tensor self) -> Tensor + self: evenly_distribute_backward(grad, self, result) + result: evenly_read_jvp(self_t, self_p, result) + +- name: nanmedian(Tensor self) -> Tensor + self: evenly_distribute_backward(grad, self, result) + result: evenly_read_jvp(self_t, self_p, result) + +# This is in theory incorrect in the following case: +# sorted list: [..., a, b, b, ..., b, b, c, ...] with median = b and the value +# | at middle position of the +# | list between two `b`s. E.g., +# | +# ^the middle position +# The gradient exists and is essentially 0 in this case. 
+# +# In case where the middle position is at the boundary of `b` range, e.g., +# sorted list: [..., a, b, b, ..., b, b, c, ...] +# | +# ^the middle position +# The backward implementation is correct in the sense that it returns the +# subgradient on one side. +- name: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) + +- name: nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) + +- name: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) + +- name: min(Tensor self) -> Tensor + self: evenly_distribute_backward(grad, self, result) + result: evenly_read_jvp(self_t, self_p, result) + +- name: minimum(Tensor self, Tensor other) -> Tensor + self: at::where(self == other, grad / 2, grad).masked_fill_(self > other, 0) + other: at::where(self == other, grad / 2, grad).masked_fill_(self < other, 0) + result: other_t + at::where(self_p == other_p, at::scalar_tensor(0.5, result.options()), (self_p < other_p).to(result.scalar_type())) * (self_t - other_t) + +- name: fmin(Tensor self, Tensor other) -> Tensor + self: grad.masked_fill((self <= other).logical_or_(other.isnan()).logical_not_(), 0) + other: grad.masked_fill((self <= other).logical_or_(other.isnan()), 0) + result: other_t + (self_p <= other_p).logical_or_(other_p.isnan()) * (self_t - other_t) + +- name: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor + self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim) + result: amaxamin_jvp(self_p, self_t, result, dim, keepdim) + +- name: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor + self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim) + result: amaxamin_jvp(self_p, self_t, result, dim, keepdim) + +- name: mm(Tensor self, Tensor mat2) -> Tensor + self: mm_mat1_backward(grad, mat2, self.sym_sizes(), self.sym_strides(), self.layout(), 1) + mat2: mm_mat2_backward(grad, self, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), 1) + result: at::mm(self_t, mat2_p) + at::mm(self_p, mat2_t) + +- name: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) + +- name: mul.Tensor(Tensor self, Tensor other) -> Tensor + self: mul_tensor_backward(grad, other, self.scalar_type()) + other: mul_tensor_backward(grad, self, other.scalar_type()) + result: other_t * self_p + self_t * other_p + +- name: mul.Scalar(Tensor self, Scalar other) -> Tensor + self: mul_tensor_backward(grad, other, self.scalar_type()) + result: self_t * other + +- name: mv(Tensor self, Tensor vec) -> Tensor + self: grad.ger(vec.conj()) + vec: self.conj().t().mv(grad) + result: mv(self_t, vec_p) + mv(self_p, vec_t) + +- name: mvlgamma(Tensor self, int p) -> Tensor + self: 
mvlgamma_backward(grad, self, p) + result: auto_element_wise + +- name: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor + self: grad * at::isfinite(self) + result: auto_element_wise + +- name: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, eps) + +- name: _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, eps) + +- name: _native_batch_norm_legit_no_training(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*training=*/false, eps, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, /*training=*/false, eps) + +- name: _native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, Tensor(), Tensor(), result1, result2, training, eps, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, Tensor(), Tensor(), result1, result2, training, eps) + +- name: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, train, eps, save_mean, save_invstd, grad_input_mask) + save_mean: not_implemented("native_batch_norm_backward save_mean") + save_invstd: not_implemented("native_batch_norm_backward save_invstd") + +- name: native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? native_layer_norm_backward_symint(grad, input, normalized_shape, result1, result2, weight, bias, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" + result0: layer_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, normalized_shape) + +- name: native_layer_norm_backward(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor?
bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + input, weight, grad_out: layer_norm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, mean, rstd, normalized_shape, grad_input_mask) + bias: Tensor() + mean: not_implemented("native_layer_norm_backward mean") + rstd: not_implemented("native_layer_norm_backward rstd") + +- name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) + input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward_symint(grads[0].device().is_xpu() ? grads[0] : grads[0].contiguous(grads[0].device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), input.device().is_xpu() ? input : input.contiguous(input.device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>())" + result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group) + result1: group_norm_mean_jvp(input_t, result1, group) + result2: group_norm_invstd_jvp(input_p, input_t, result1, result2, group) + +- name: ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + self: zeros_like(self) + result: self_t.zero_() + +- name: ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + self: zeros_like(self) + other: zeros_like(other) + result: self_t.zero_() + +- name: neg(Tensor self) -> Tensor + self: grad.neg() + result: auto_element_wise + +- name: _batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*update*/true, eps, grad_input_mask, retain_variables ? result3.clone() : result3) : std::tuple<Tensor, Tensor, Tensor>()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, true, eps) + +- name: _batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*update*/false, eps, grad_input_mask, retain_variables ? result3.clone() : result3) : std::tuple<Tensor, Tensor, Tensor>()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, false, eps) + +- name: batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor?
save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor) + input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, update, eps, save_mean, save_var, grad_input_mask) + save_mean: not_implemented("batch_norm_backward save_mean") + save_var: not_implemented("batch_norm_backward save_var") + reserve: not_implemented("batch_norm_backward reserve") + +- name: nextafter(Tensor self, Tensor other) -> Tensor + self: not_implemented("nextafter") + other: not_implemented("nextafter") + +- name: norm.Scalar(Tensor self, Scalar p=2) -> Tensor + self: norm_backward(grad, self, p, result) + result: norm_jvp(self_p, self_t, p, result) + +- name: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor + self: norm_backward(grad, self, p, result, dim, keepdim) + result: norm_jvp(self_p, self_t, p, result, dim, keepdim) + +- name: norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor + self: norm_backward(grad, self.to(grad.scalar_type()), p, result) + result: norm_jvp(self_p, self_t, p, result) + +- name: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor + self: norm_backward(grad, self.to(grad.scalar_type()), p, result, dim, keepdim) + result: norm_jvp(self_p, self_t, p, result, dim, keepdim) + +- name: linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + self: linalg_vector_norm_backward(grad, self, ord, result, dim, keepdim) + result: linalg_vector_norm_jvp(self_p, self_t, ord, result, dim, keepdim) + +- name: _pdist_forward(Tensor self, float p=2) -> Tensor + self: _pdist_backward(grad, self, p, result) + +- name: _pdist_backward(Tensor grad, Tensor self, float p, Tensor pdist) -> Tensor + grad: not_implemented("_pdist_backward") + self: not_implemented("_pdist_backward") + pdist: not_implemented("_pdist_backward") + +- name: _euclidean_dist(Tensor x1, Tensor x2) -> Tensor + x1, x2: _euclidean_dist_backward(grad, x1, x2, result) + +- name: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor + x1: _cdist_backward(grad.contiguous(), x1, x2, p, result) + x2: _cdist_backward(grad.mT().contiguous(), x2, x1, p, result.mT().contiguous()) + +- name: _cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor + grad: not_implemented("_cdist_backward") + x1: not_implemented("_cdist_backward") + x2: not_implemented("_cdist_backward") + cdist: not_implemented("_cdist_backward") + +- name: normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor + mean: at::zeros_symint(mean.sym_sizes(), grad.options()) + result: auto_element_wise + +- name: normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor + std: at::zeros_symint(std.sym_sizes(), grad.options()) + result: auto_element_wise + +- name: normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? 
generator=None) -> Tensor + mean: at::zeros_symint(mean.sym_sizes(), grad.options()) + std: at::zeros_symint(std.sym_sizes(), grad.options()) + result: zeros_like(mean_t) + +- name: linalg_householder_product(Tensor input, Tensor tau) -> Tensor + input, tau: householder_product_backward(grad, result, input, tau) + result: householder_product_jvp(input_t, tau_t, result, input_p, tau_p) + +- name: ormqr(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False) -> Tensor + self, input2, input3: ormqr_backward(grad, result, self, input2, input3, left, transpose, grad_input_mask) + +- name: permute(Tensor(a) self, int[] dims) -> Tensor(a) + self: permute_backwards(grad, dims) + result: auto_linear + +- name: poisson(Tensor self, Generator? generator=None) -> Tensor + self: zeros_like(self) + result: auto_element_wise + +- name: pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor + self: pow_backward(grad, self, exponent) + result: auto_element_wise + +- name: pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor + self: pow_backward_self(grad, self, exponent) + exponent: pow_backward_exponent(grad, self, exponent, result) + result: (pow_backward_self(self_t.conj(), self_p, exponent_p) + pow_backward_exponent(exponent_t.conj(), self_p, exponent_p, result)).conj() + +- name: pow.Scalar(Scalar self, Tensor exponent) -> Tensor + exponent: pow_backward_exponent(grad, self, exponent, result) + result: auto_element_wise + +- name: prod(Tensor self, *, ScalarType? dtype=None) -> Tensor + self: prod_backward(grad, self.to(grad.scalar_type()), result) + result: (prod_backward(at::ones({}, result.options()).expand_as(result), self_p.to(result.scalar_type()), result) * self_t.conj()).sum().conj() + +- name: prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + self: prod_backward(grad, self.to(grad.scalar_type()), result, dim, keepdim) + result: (prod_backward(at::ones({}, result.options()).expand_as(result), self_p.to(result.scalar_type()), result, dim, keepdim) * self_t.conj()).sum(dim, keepdim).conj() + +- name: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor + self: "accumulate ? grad : grad.put(index, zeros_like(source), false)" + index: non_differentiable + source: grad.take(index).reshape_as(source) + result: self_t.put(index, source_t, accumulate) + +- name: linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R) + A: linalg_qr_backward(grad_Q, grad_R, Q, R, mode) + Q, R: linalg_qr_jvp(A_t, Q, R, mode) + +- name: rad2deg(Tensor self) -> Tensor + self: rad2deg_backward(grad) + result: auto_element_wise + +- name: random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!) 
+ self: zeros_like(grad) + result: self_t.zero_() + +- name: reciprocal(Tensor self) -> Tensor + self: -grad * (result * result).conj() + result: auto_element_wise + +- name: remainder.Scalar(Tensor self, Scalar other) -> Tensor + self: grad + result: auto_element_wise + +- name: remainder.Tensor(Tensor self, Tensor other) -> Tensor + self: grad + other: -grad * self.div(other, /*rounding_mode=*/"floor") + result: self_t - other_t * self_p.div(other_p, /*rounding_mode=*/"floor") + +- name: renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor + self: renorm_backward(grad, self, p, dim, maxnorm) + result: renorm_jvp(self_p, self_t, p, dim, maxnorm) + +- name: repeat(Tensor self, SymInt[] repeats) -> Tensor + self: repeat_backward(grad, repeats, self.sym_sizes()) + result: auto_linear + +- name: special_entr(Tensor self) -> Tensor + self: grad * (-(1 + self.log())) + result: auto_element_wise + +- name: special_ndtri(Tensor self) -> Tensor + self: grad * std::sqrt(2 * M_PI) * (result.square() / 2).exp() + result: auto_element_wise + +- name: special_log_ndtr(Tensor self) -> Tensor + self: grad / std::sqrt(2 * M_PI) * (result + self.pow(2) / 2).neg().exp() + result: auto_element_wise + +# [Note: Sometimes view derivatives] +# The following situation applies to other operations as well. +# TODO: This note is only referenced by to_dense and to_sparse*. Make +# this more generic if it's been referenced more than once. +# +# DO NOT define a backward for reshape! +# reshape is special in that it sometimes returns a view, and sometimes not. +# Defining a backward will make codegen spit out the forward call as +# as_variable(baseType->reshape(self)), +# making it impossible (hard) to detect when it is actually a view. +# - name: reshape(Tensor self, IntArrayRef shape) + +- name: _reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a) + self: grad.reshape_symint(self.sym_sizes()) + result: auto_linear + +- name: round(Tensor self) -> Tensor + self: zeros_like(grad) + result: auto_element_wise + +- name: round.decimals(Tensor self, *, int decimals) -> Tensor + self: zeros_like(grad) + result: auto_element_wise + +- name: rsqrt(Tensor self) -> Tensor + self: -0.5 * grad * result.pow(3).conj() + result: auto_element_wise + +- name: scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor + self: grad.scatter(dim, index, 0) + index: non_differentiable + src: grad.gather(dim, index) + result: self_t.scatter(dim, index, src_t) + +- name: scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor + self: grad.scatter(dim, index, 0) + index: non_differentiable + result: self_t.scatter(dim, index, 0) + +- name: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor + self: grad + index: non_differentiable + src: grad.gather(dim, index) + result: scatter_add(self_t, dim, index, src_t) + +- name: select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a) + dispatch: + Default: + self: select_backward_symint(grad, self.sym_sizes(), dim, index) + result: auto_linear + AutogradNestedTensor: + self: _nested_select_backward_symint(grad, self, dim, index) + +- name: select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index) -> Tensor + grad_output: grad.select_symint(dim, index) + result: auto_linear + +- name: sigmoid(Tensor self) -> Tensor + self: sigmoid_backward(grad, result) + result: auto_element_wise + +- name: logit(Tensor self, float? eps=None) -> Tensor + self: "GradMode::is_enabled() ? 
infinitely_differentiable_logit_backward(grad, self, eps) : logit_backward(grad, self, eps)" + result: auto_element_wise + +- name: sign(Tensor self) -> Tensor + self: zeros_like(grad) + result: auto_element_wise + +- name: sgn(Tensor self) -> Tensor + self: sgn_backward(self, grad, result) + # Cannot use auto_element_wise here because the Jacobian is *not* Hermitian (in fact, it is symmetric) + # The function is not holomorphic, so there's no reason for its Jacobian to be Hermitian + # auto_element_wise has a name that's a bit deceiving in the complex case + result: sgn_backward(self_p, self_t, result) + +- name: sin(Tensor self) -> Tensor + self: grad * self.cos().conj() + result: auto_element_wise + +- name: sinc(Tensor self) -> Tensor + self: sinc_backward(grad, self) + result: auto_element_wise + +- name: sinh(Tensor self) -> Tensor + self: grad * self.cosh().conj() + result: auto_element_wise + +- name: slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) + self: slice_backward_wrapper(grad, self.sym_sizes(), dim, start, end, step) + result: auto_linear + +- name: slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor + grad_output: grad.slice_symint(dim, start, end, step) + result: auto_linear + +- name: slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) + self: grad.slice_symint(dim, start, end, step) + src: slice_scatter_symint(grad, zeros_like(self), dim, start, end, step) + result: auto_linear + +- name: slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor + self: slice_scatter_symint(grad, zeros_like(src), dim, start, end, step) + src: grad.slice_symint(dim, start, end, step) + result: auto_linear + +- name: select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor + self: select_scatter_symint(grad, zeros_like(src), dim, index) + src: grad.select_symint(dim, index) + result: auto_linear + +- name: diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor + self: diagonal_scatter(grad, zeros_like(src), offset, dim1, dim2) + src: grad.diagonal(offset, dim1, dim2) + result: auto_linear + +- name: as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor + self: as_strided_scatter_backward(grad, TensorGeometry(self), TensorGeometry(src), size, stride, storage_offset) + # See Note [as_strided_scatter backward support] + src: grad.contiguous().as_strided_symint(size, stride, storage_offset) + result: auto_linear + +- name: _linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor LU, Tensor pivots, Tensor info) + A, B: linalg_solve_backward(grad, result, A, LU, pivots, left, grad_input_mask[1]) + result: "linalg_solve_jvp(A_t, B_t, result, LU, pivots, left, A_p.is_contiguous() && !A_p.is_complex())" + output_differentiability: [True, False, False, False] # LU is an auxiliary tensor not exposed to the user + +- name: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true) + output_differentiability: [True, False] + values: gather_with_keepdimed_indices(self_t, dim, indices, true) + +- name: sort.stable(Tensor self, *, bool? 
stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true) + output_differentiability: [True, False] + values: gather_with_keepdimed_indices(self_t, dim, indices, true) + +- name: split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[] + self: split_backward(grads, split_size, dim, self.sym_sizes(), self.options()) + result: auto_linear + +- name: unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[] + self: split_backward(grads, split_size, dim, self.sym_sizes(), self.options()) + result: auto_linear + +- name: split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[] + dispatch: + Default: + self: split_with_sizes_backward(grads, split_sizes, dim, self.sym_sizes(), self.options()) + result: auto_linear + AutogradNestedTensor: + self: _nested_split_with_sizes_backward(grads, split_sizes, dim, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), self.options()) + +- name: unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[] + self: split_with_sizes_backward(grads, split_sizes, dim, self.sym_sizes(), self.options()) + result: auto_linear + +- name: sqrt(Tensor self) -> Tensor + self: grad / (2 * result.conj()) + result: auto_element_wise + +- name: squeeze(Tensor(a) self) -> Tensor(a) + self: unsqueeze_to(grad, self.sym_sizes()) + result: auto_linear + +- name: squeeze.dim(Tensor(a) self, int dim) -> Tensor(a) + dispatch: + Default: + self: unsqueeze_to(grad, dim, self.sym_sizes()) + result: auto_linear + AutogradNestedTensor: + self: grad.unsqueeze(dim) + +- name: squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a) + dispatch: + Default: + self: unsqueeze_to(grad, dim, self.sym_sizes()) + result: auto_linear + AutogradNestedTensor: + self: unsqueeze_multiple(grad, dim, self.dim()) + +- name: squeeze_(Tensor(a!) self) -> Tensor(a!) + self: unsqueeze_to(grad, self.sym_sizes()) + result: auto_linear + +- name: squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!) + self: unsqueeze_to(grad, dim, self.sym_sizes()) + result: auto_linear + +- name: squeeze_.dims(Tensor(a!) self, int[] dim) -> Tensor(a!) + self: unsqueeze_to(grad, dim, self.sym_sizes()) + result: auto_linear + +- name: std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor + self: std_backward(result, grad, self, dim, correction, keepdim) + # pointwise (variance) + sum + sqrt + result: (at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) / (2. * result)).masked_fill_(result == 0, 0) + +- name: std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) + self: std_mean_backward(grads[0], grads[1], self, result0, dim, correction, keepdim) + result0: (at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) / (2. 
* result0)).masked_fill_(result0 == 0, 0) + # linear + result1: mean(self_t, dim.value_or(IntArrayRef({})), keepdim) + +- name: sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), grad) + other: handle_r_to_c(other.scalar_type(), maybe_multiply(-grad, alpha.conj())) + result: self_t - maybe_multiply(other_t, alpha) + +- name: sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), grad) + result: auto_element_wise + +- name: rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), maybe_multiply(-grad, alpha.conj())) + other: handle_r_to_c(other.scalar_type(), grad) + result: -maybe_multiply(self_t, alpha) + other_t + +- name: rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), maybe_multiply(-grad, alpha.conj())) + result: auto_element_wise + +- name: sum(Tensor self, *, ScalarType? dtype=None) -> Tensor + self: grad.expand_symint(self.sym_sizes()) + result: auto_linear + +- name: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + dispatch: + Default: + self: sum_backward(grad, self.sym_sizes(), dim, keepdim) + result: auto_linear + AutogradNestedTensor: + # TODO: replace this function once semantics for nested tensor expand have been settled on + self: _nested_sum_backward(grad, self, dim, keepdim) + +- name: nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim) + result: at::where(self_p.isnan(), 0, self_t).sum(dim, keepdim, dtype) + +# We never call _linalg_svd with compute_uv=False in an autograd context, so we don't even consider it here +- name: _linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh) + A: "svd_backward(full_matrices && grad_U.defined() ? grad_U.narrow_symint(-1, 0, S.sym_size(-1)) : grad_U, + grad_S, + full_matrices && grad_Vh.defined() ? grad_Vh.narrow_symint(-2, 0, S.sym_size(-1)) : grad_Vh, + full_matrices ? U.narrow_symint(-1, 0, S.sym_size(-1)) : U, + S, + full_matrices ? Vh.narrow_symint(-2, 0, S.sym_size(-1)) : Vh)" + U, S, Vh: linalg_svd_jvp(A_t, U, S, Vh, full_matrices) + +- name: _linalg_eigh(Tensor A, str UPLO="L", bool compute_v=True) -> (Tensor eigenvalues, Tensor eigenvectors) + A: linalg_eig_backward(grads[0], grads[1], eigenvalues, eigenvectors, /*is_hermitian=*/true) + eigenvalues, eigenvectors: linalg_eig_jvp(A_t, eigenvalues, eigenvectors, /*is_hermitian=*/true) + +- name: linalg_eig(Tensor self) -> (Tensor eigenvalues, Tensor eigenvectors) + self: handle_r_to_c(self.scalar_type(), linalg_eig_backward(grads[0], grads[1], eigenvalues, eigenvectors, /*is_hermitian=*/false)) + eigenvalues, eigenvectors: linalg_eig_jvp(self_t, eigenvalues, eigenvectors, /*is_hermitian=*/false) + +- name: t(Tensor(a) self) -> Tensor(a) + self: grad.t() + result: auto_linear + +- name: t_(Tensor(a!) self) -> Tensor(a!) 
+ self: grad.t() + result: auto_linear + +- name: one_hot(Tensor self, int num_classes=-1) -> Tensor + self: non_differentiable + +- name: flip(Tensor self, int[] dims) -> Tensor + self: grad.flip(dims) + result: auto_linear + +- name: roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor + self: grad.roll_symint(fmap(reverse_list_symint(shifts), [](c10::SymInt i){return -i;}), reverse_list(dims)) + result: auto_linear + +- name: rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor + self: grad.rot90(-k, dims) + result: auto_linear + +- name: take(Tensor self, Tensor index) -> Tensor + self: take_backward(grad, self, index) + index: non_differentiable + result: auto_linear + +- name: tan(Tensor self) -> Tensor + self: grad * (1 + result.pow(2)).conj() + result: auto_element_wise + +- name: tanh(Tensor self) -> Tensor + self: tanh_backward(grad, result) + result: auto_element_wise + +- name: topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true) + output_differentiability: [True, False] + values: gather(self_t, dim, indices) + +- name: trace(Tensor self) -> Tensor + self: trace_backward_symint(grad, self.sym_sizes()) + result: auto_linear + +- name: transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a) + self: grad.transpose(dim0, dim1) + result: auto_linear + +- name: transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) + self: grad.transpose(dim0, dim1) + result: auto_linear + +- name: triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor solution, Tensor cloned_coefficient) + self, A: triangular_solve_backward(grad_solution, grad_cloned_coefficient, self, A, solution, upper, transpose, unitriangular, grad_input_mask) + solution: triangular_solve_jvp(solution, A_p, A_t, self_t, upper, transpose, unitriangular) + cloned_coefficient: A_t + +- name: linalg_solve_triangular(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False) -> Tensor + self, B: linalg_solve_triangular_backward(grad, self, result, upper, left, unitriangular, grad_input_mask) + result: linalg_solve_triangular_forward_AD(self_t, B_t, self_p, result, upper, left, unitriangular) + +- name: tril(Tensor self, int diagonal=0) -> Tensor + self: grad.tril(diagonal) + result: auto_linear + +- name: triu(Tensor self, int diagonal=0) -> Tensor + self: grad.triu(diagonal) + result: auto_linear + +- name: trunc(Tensor self) -> Tensor + self: zeros_like(grad) + result: auto_element_wise + +# DO NOT define a backward for to_dense +# See [Note: Sometimes view derivatives] +# - name: to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor +# +- name: _to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor + self: to_dense_backward(grad, self, masked_grad) + +# DO NOT define a backward for to_sparse.sparse_dim +# See [Note: Sometimes view derivatives] +# - name: to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor +# +- name: _to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor + self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) + +# DO NOT define a backward for to_sparse +# See [Note: Sometimes view derivatives] +# - name: to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor +# +- name: _to_sparse(Tensor self, *, Layout? layout=None, int[2]? 
blocksize=None, int? dense_dim=None) -> Tensor + self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) + +# DO NOT define a backward for to_sparse_csr +# See [Note: Sometimes view derivatives] +# - name: to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor +# +- name: _to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor + self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) + +# DO NOT define a backward for to_sparse_csc +# See [Note: Sometimes view derivatives] +# - name: to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor +# +- name: _to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor + self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) + +# DO NOT define a backward for to_sparse_bsr +# See [Note: Sometimes view derivatives] +# - name: to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor +# +- name: _to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor + self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) + +# DO NOT define a backward for to_sparse_bsc +# See [Note: Sometimes view derivatives] +# - name: to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor +# +- name: _to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor + self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) + +- name: to_mkldnn(Tensor self, ScalarType? dtype=None) -> Tensor + self: to_mkldnn_backward(grad, self) + +- name: unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a) + self: unfold_backward_symint(grad, self.sym_sizes(), dimension, size, step) + result: auto_linear + +- name: unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor + grad_in: grad.unfold(dim, size, step) + result: auto_linear + +- name: uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: _unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor) + output_differentiability: [True, False] + self: not_implemented("_unique") + +- name: unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + output_differentiability: [True, False, False] + self: not_implemented("unique_dim") + +- name: unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor) + output_differentiability: [True, False, False] + self: not_implemented("unique_consecutive") + +- name: unique_dim_consecutive(Tensor self, int dim, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + output_differentiability: [True, False, False] + self: not_implemented("unique_dim_consecutive") + +- name: _unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + output_differentiability: [True, False, False] + self: not_implemented("_unique2") + +- name: _unsafe_view(Tensor self, SymInt[] size) -> Tensor + self: grad.reshape_symint(self.sym_sizes()) + result: auto_linear + +- name: lift(Tensor self) -> Tensor + self: grad + result: auto_linear + +- name: lift_fresh(Tensor(a) self) -> Tensor(a) + self: grad + result: auto_linear + +- name: unsqueeze(Tensor(a) self, int dim) -> Tensor(a) + self: grad.squeeze(dim) + result: auto_linear + +- name: unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!) 
+ self: grad.squeeze(dim) + result: auto_linear + +- name: var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor + self: var_backward(grad, self, dim, correction, keepdim) + # pointwise + sum + result: at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) + +- name: var_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) + self: var_mean_backward(grads[0], grads[1], self, dim, correction, keepdim) + result0: at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) + # linear + result1: mean(self_t, dim.value_or(IntArrayRef({})), keepdim) + +- name: view(Tensor(a) self, SymInt[] size) -> Tensor(a) + dispatch: + Default: + self: grad.reshape_symint(self.sym_sizes()) + result: auto_linear + AutogradNestedTensor: + self: grad.reshape_as(self) + result: auto_linear + +- name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a) + output_differentiability: [False] + +- name: view_as_real(Tensor(a) self) -> Tensor(a) + self: at::view_as_complex(grad.contiguous()) # gx0 + 1j * gx1 + result: at::view_as_real(self_t) + +- name: view_as_complex(Tensor(a) self) -> Tensor(a) + self: at::view_as_real(grad.contiguous().resolve_conj()) # [gx, gy] + result: at::view_as_complex(self_t) + +- name: where.self(Tensor condition, Tensor self, Tensor other) -> Tensor + condition: non_differentiable + self: where(condition, grad, 0) + other: where(condition, 0, grad) + result: where(condition, self_t, other_t) + +# weight_norm_cuda_interface_backward does not have an explicitly defined derivative, so if we do happen +# to be running backward with create_graph=True, fall back to a backward function that uses +# differentiable ops. +- name: _weight_norm_interface(Tensor v, Tensor g, int dim=0) -> (Tensor, Tensor) + v, g: "grad.defined() ? (GradMode::is_enabled() ? _weight_norm_differentiable_backward(grad.contiguous(), v, g, result1, dim) : _weight_norm_interface_backward(grad.contiguous(), v, g, result1, dim)) : std::tuple<Tensor, Tensor>()" + +- name: zero_(Tensor(a!) self) -> Tensor(a!) + self: zeros_like(grad) + result: auto_linear + +- name: sparse_mask(Tensor self, Tensor mask) -> Tensor + self: sparse_mask_backward(grad, mask, self.layout()) + mask: non_differentiable + +- name: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor + indices: non_differentiable + values: grad.sparse_mask(result)._values() + +- name: sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + compressed_indices: non_differentiable + plain_indices: non_differentiable + # TODO: remove to_dense after gh-107381 is fixed + values: grad.to_dense().sparse_mask(result).values() + +- name: _sparse_sum.dim(Tensor self, int[1] dim) -> Tensor + self: at::_sparse_sum_backward(grad, self, dim) + +- name: _standard_gamma(Tensor self, Generator?
generator=None) -> Tensor + self: grad * _standard_gamma_grad(self, result) + +- name: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor + self: not_implemented("_standard_gamma_grad") + +- name: values(Tensor(a) self) -> Tensor(a) + dispatch: + Default: + self: values_backward(grad, self) + AutogradNestedTensor: + self: at::_nested_view_from_buffer(grad.contiguous(), self._nested_tensor_size(), self._nested_tensor_strides(), self._nested_tensor_storage_offsets()) + +# Why is _values() not differentiable? +# See NOTE [ Sparse: autograd and API ] +- name: _values(Tensor(a) self) -> Tensor(a) + output_differentiability: [False] + +# NN +- name: _trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor + i1, i2, i3: "_trilinear_backward(grad, + wrap_opt_if(i1, grad_input_mask[1] || grad_input_mask[2]), + wrap_opt_if(i2, grad_input_mask[0] || grad_input_mask[2]), + wrap_opt_if(i3, grad_input_mask[0] || grad_input_mask[1]), + expand1, expand2, expand3, sumdim, grad_input_mask)" + result: "_trilinear(i1_t, i2_p, i3_p, expand1, expand2, expand3, sumdim, unroll_dim) + + _trilinear(i1_p, i2_t, i3_p, expand1, expand2, expand3, sumdim, unroll_dim) + + _trilinear(i1_p, i2_p, i3_t, expand1, expand2, expand3, sumdim, unroll_dim)" + +- name: constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor + self: constant_pad_nd_backward(grad, pad) + result: constant_pad_nd_symint(self_t, pad, 0) + +- name: binary_cross_entropy(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor + self: binary_cross_entropy_backward(grad, self, target, weight, reduction) + target: binary_cross_entropy_target_backward(grad, self, target, weight, reduction) + result: "apply_loss_reduction( + binary_cross_entropy_backward(self_t, self_p, target_p, weight, at::Reduction::None) + + binary_cross_entropy_target_backward(target_t, self_p, target_p, weight, at::Reduction::None), + reduction)" + +- name: binary_cross_entropy_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor + self: binary_cross_entropy_double_backward(grad_output, grad, self, target, weight, reduction) + target: binary_cross_entropy_double_backward_target(grad, grad_output, self, target, weight, reduction) + grad_output: binary_cross_entropy_double_backward_grad_output(grad, self, target, weight, reduction) + result: " binary_cross_entropy_double_backward(grad_output_p, self_t, self_p, target_p, weight, reduction) + + binary_cross_entropy_double_backward_target(target_t, grad_output_p, self_p, target_p, weight, reduction) + + binary_cross_entropy_double_backward_grad_output(grad_output_t, self_p, target_p, weight, reduction)" + +- name: binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor? weight=None, Tensor? 
pos_weight=None, int reduction=Mean) -> Tensor + self: binary_cross_entropy_with_logits_backward(grad, self, target, weight, pos_weight, reduction) + target: binary_cross_entropy_with_logits_target_backward(grad, self, target, weight, pos_weight, reduction) + result: "apply_loss_reduction( + binary_cross_entropy_with_logits_backward(self_t, self_p, target_p, weight, pos_weight, at::Reduction::None) + + binary_cross_entropy_with_logits_target_backward(target_t, self_p, target_p, weight, pos_weight, at::Reduction::None), + reduction)" + +- name: embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor + indices: non_differentiable + weight: embedding_backward_symint(grad, indices, weight.sym_size(0), padding_idx, scale_grad_by_freq, sparse) + result: auto_linear + +- name: embedding_dense_backward(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq) -> Tensor + grad_output: embedding_dense_double_backward_symint(grad, indices, padding_idx) + indices: non_differentiable + result: auto_linear + +- name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor) + indices: non_differentiable + offsets: non_differentiable + weight: _embedding_bag_backward_symint(grad, indices, offsets, result1, result2, result3, weight.sym_size(0), scale_grad_by_freq, mode, sparse, per_sample_weights, padding_idx) + per_sample_weights: _embedding_bag_per_sample_weights_backward(grad, weight, indices, offsets, result1, mode, padding_idx) + +- name: _embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor + grad: not_implemented("_embedding_bag_backward") + indices: non_differentiable + offsets: non_differentiable + offset2bag: non_differentiable + bag_size: non_differentiable + maximum_indices: non_differentiable + per_sample_weights: not_implemented("_embedding_bag_backward") + +- name: _embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor + grad: not_implemented("_embedding_bag_dense_backward") + indices: non_differentiable + offset2bag: non_differentiable + bag_size: non_differentiable + maximum_indices: non_differentiable + per_sample_weights: not_implemented("_embedding_bag_dense_backward") + +- name: embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!) + indices: non_differentiable + self: not_implemented("embedding_renorm") + +- name: mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor + self: mse_loss_backward(grad, self, target, reduction) + target: mse_loss_backward(grad, target, self, reduction) + result: apply_loss_reduction(mse_loss_backward(self_t.conj(), self_p, target_p, at::Reduction::None).conj() + mse_loss_backward(target_t.conj(), target_p, self_p, at::Reduction::None).conj(), reduction) + +- name: multi_margin_loss(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? 
weight=None, int reduction=Mean) -> Tensor + self: multi_margin_loss_backward(grad, self, target, p, margin, weight, reduction) + target: non_differentiable + +- name: multilabel_margin_loss_forward(Tensor self, Tensor target, int reduction) -> (Tensor output, Tensor is_target) + self: multilabel_margin_loss_backward(grad, self, target, reduction, is_target) + target: non_differentiable + +- name: nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) + self: nll_loss_backward_symint(grad, self, target, weight, reduction, ignore_index, total_weight) + target: non_differentiable + output: std::get<0>(nll_loss_forward_symint(self_t, target, weight, reduction, ignore_index)) + +- name: nll_loss2d_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) + self: nll_loss2d_backward_symint(grad, self, target, weight, reduction, ignore_index, total_weight) + target: non_differentiable + output: std::get<0>(nll_loss2d_forward_symint(self_t, target, weight, reduction, ignore_index)) + +- name: smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean, float beta=1.0) -> Tensor + self: smooth_l1_loss_backward(grad, self, target, reduction, beta) + target: smooth_l1_loss_backward(grad, target, self, reduction, beta) + result: apply_loss_reduction(smooth_l1_loss_backward(self_t.conj(), self_p, target_p, at::Reduction::None, beta).conj() + smooth_l1_loss_backward(target_t.conj(), target_p, self_p, at::Reduction::None, beta).conj(), reduction) + +- name: huber_loss(Tensor self, Tensor target, int reduction=Mean, float delta=1.0) -> Tensor + self: huber_loss_backward(grad, self, target, reduction, delta) + target: huber_loss_backward(grad, target, self, reduction, delta) + result: apply_loss_reduction(huber_loss_backward(self_t.conj(), self_p, target_p, at::Reduction::None, delta).conj() + huber_loss_backward(target_t.conj(), target_p, self_p, at::Reduction::None, delta).conj(), reduction) + +- name: soft_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor + self: soft_margin_loss_backward(grad, self, target, reduction) + result: apply_loss_reduction(soft_margin_loss_backward(self_t.conj(), self_p, target, at::Reduction::None).conj(), reduction) + +- name: relu(Tensor self) -> Tensor + self: threshold_backward(grad, result, 0) + result: auto_element_wise + +- name: silu(Tensor self) -> Tensor + self: "GradMode::is_enabled() ? infinitely_differentiable_silu_backward(grad, self) : silu_backward(grad, self)" + result: auto_element_wise + +- name: mish(Tensor self) -> Tensor + self: "GradMode::is_enabled() ? infinitely_differentiable_mish_backward(grad, self) : mish_backward(grad, self)" + result: auto_element_wise + +- name: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor + self: elu_backward(grad, alpha, scale, input_scale, /* is_result */ false, self) + result: auto_element_wise + +- name: elu_(Tensor(a!) self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor(a!) + self: elu_backward(grad, alpha, scale, input_scale, /* is_result */ true, result) + result: self_t.copy_(elu_backward(original_self_t, alpha, scale, input_scale, /* is_result */ true, result)) + +- name: celu(Tensor self, Scalar alpha=1.0) -> Tensor + self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ false, self) + result: auto_element_wise + +- name: celu_(Tensor(a!) 
self, Scalar alpha=1.0) -> Tensor(a!) + self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ true, result) + result: self_t.copy_(elu_backward(original_self_t, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ true, result)) + +- name: gelu(Tensor self, *, str approximate='none') -> Tensor + self: gelu_backward(grad, self, approximate) + result: auto_element_wise + +- name: gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor + grad_output: gelu_backward(grad, self, approximate) + self: gelu_double_backward(grad, grad_output, self, approximate) + result: gelu_backward(grad_output_t, self_p, approximate) + gelu_double_backward(self_t, grad_output_p, self_p, approximate) + +- name: glu(Tensor self, int dim=-1) -> Tensor + # TODO: glu_backward can benefit from forward result, + # and forward ad/forward over reverse ad for that matter + self: glu_backward(grad, self, dim) + result: glu_jvp(result, self_p, self_t, dim) + +- name: hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor + self: hardshrink_backward(grad, self, lambd) + result: auto_element_wise + +- name: hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor + grad_out: hardshrink_backward(grad, self, lambd) + self: zeros_like(grad) + result: at::where((self_p > lambd).logical_or(self_p < -lambd), grad_out_t, at::zeros({}, result.options()).expand_as(result)) + +- name: hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor + self: hardtanh_backward(grad, self, min_val, max_val) + result: auto_element_wise + +- name: leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor + self: leaky_relu_backward(grad, self, negative_slope, false) + result: auto_element_wise + +- name: leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> Tensor(a!) + self: leaky_relu_backward(grad, result, negative_slope, true) + result: self_t.copy_(leaky_relu_backward(original_self_t.conj(), result, negative_slope, true).conj()) + +- name: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer) + self: log_sigmoid_backward(grad, self, buffer) + output: log_sigmoid_backward(self_t.conj(), self_p, buffer).conj() + output_differentiability: [True, False] + +- name: _log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor + self: _log_softmax_backward_data(grad, result, dim, self.scalar_type()) + result: self_t - logsumexp_jvp(self_p, self_t, {dim}, true) + +- name: _sparse_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor + self: _sparse_log_softmax_backward_data(grad, result, dim, self) + +- name: _masked_softmax(Tensor self, Tensor mask, int? dim=None, int? mask_type=None) -> Tensor + self: _masked_softmax_backward(grad, result, mask, dim) + mask: non_differentiable + +- name: _prelu_kernel(Tensor self, Tensor weight) -> Tensor + self, weight: "grad.defined() ? _prelu_kernel_backward(grad, self, weight) : std::tuple()" + result: at::where(self_p >= 0, self_t, weight_p * self_t + weight_t * self_p) + +- name: _prelu_kernel_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor) + grad_output: "grads[0].defined() ? + (grads[1].defined() ? at::where(self >= 0, grads[0], grads[0] * weight + grads[1] * self) + : at::where(self >= 0, grads[0], grads[0] * weight)) + : at::where(self >= 0, at::zeros({}, grad_output.options()), grads[1] * self)" + self: "grads[1].defined() ? at::where(self >= 0, at::zeros({}, self.options()), grad_output * grads[1]) : zeros_like(self)" + weight: "grads[0].defined() ? 
at::where(self >= 0, at::zeros({}, weight.options()), grad_output * grads[0]) : zeros_like(self)" + result0: at::where(self_p >= 0, grad_output_t, grad_output_t * weight_p + grad_output_p * weight_t) + result1: at::where(self_p >= 0, at::zeros({}, self_p.options()), grad_output_p * self_t + grad_output_t * self_p) + +- name: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor + self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false) + result: auto_element_wise + +- name: rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!) + self: rrelu_with_noise_backward(grad, result, noise, lower, upper, training, true) + +- name: _softmax(Tensor self, int dim, bool half_to_float) -> Tensor + self: _softmax_backward_data(grad, result, dim, self.scalar_type()) + result: result * (self_t - logsumexp_jvp(self_p, self_t, {dim}, true)) + +- name: _sparse_softmax(Tensor self, int dim, bool half_to_float) -> Tensor + self: _sparse_softmax_backward_data(grad, result, dim, self) + +- name: _sparse_sparse_matmul(Tensor self, Tensor other) -> Tensor + self: sparse_sparse_matmul_backward(grad, self, other, 0) + other: sparse_sparse_matmul_backward(grad, self, other, 1) + +- name: softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor + self: softplus_backward(grad, self, beta, threshold) + result: auto_element_wise + +- name: softshrink(Tensor self, Scalar lambd=0.5) -> Tensor + self: softshrink_backward(grad, self, lambd) + result: auto_element_wise + +- name: threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor + self: threshold_backward(grad, self, threshold) + result: auto_element_wise + +- name: threshold_(Tensor(a!) self, Scalar threshold, Scalar value) -> Tensor(a!) + self: threshold_backward(grad, self, threshold) + result: self_t.copy_(threshold_backward(self_t.conj(), original_self_p, threshold).conj()) + +- name: reflection_pad1d(Tensor self, SymInt[2] padding) -> Tensor + self: reflection_pad1d_backward_symint(grad, self, padding) + result: auto_linear + +- name: reflection_pad2d(Tensor self, SymInt[4] padding) -> Tensor + self: reflection_pad2d_backward_symint(grad, self, padding) + result: auto_linear + +- name: reflection_pad3d(Tensor self, SymInt[6] padding) -> Tensor + self: reflection_pad3d_backward_symint(grad, self, padding) + result: auto_linear + +- name: replication_pad1d(Tensor self, SymInt[2] padding) -> Tensor + self: replication_pad1d_backward_symint(grad, self, padding) + result: auto_linear + +- name: replication_pad2d(Tensor self, SymInt[4] padding) -> Tensor + self: replication_pad2d_backward_symint(grad, self, padding) + result: auto_linear + +- name: replication_pad3d(Tensor self, SymInt[6] padding) -> Tensor + self: replication_pad3d_backward_symint(grad, self, padding) + result: auto_linear + +- name: upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor + self: upsample_linear1d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales) + result: auto_linear + +- name: upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? 
scales_w=None) -> Tensor + self: upsample_bilinear2d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w) + result: auto_linear + +- name: _upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + self: _upsample_bilinear2d_aa_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w) + result: auto_linear + +- name: upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + self: upsample_bicubic2d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w) + result: auto_linear + +- name: _upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + self: _upsample_bicubic2d_aa_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w) + result: auto_linear + +- name: upsample_trilinear3d(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + self: upsample_trilinear3d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_d, scales_h, scales_w) + result: auto_linear + +- name: upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor + self: upsample_nearest1d_backward_symint(grad, output_size, self.sym_sizes(), scales) + result: auto_linear + +- name: _upsample_nearest_exact1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor + self: _upsample_nearest_exact1d_backward_symint(grad, output_size, self.sym_sizes(), scales) + result: auto_linear + +- name: upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor + self: upsample_nearest2d_backward_symint(grad, output_size, self.sym_sizes(), scales_h, scales_w) + result: auto_linear + +- name: _upsample_nearest_exact2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor + self: _upsample_nearest_exact2d_backward_symint(grad, output_size, self.sym_sizes(), scales_h, scales_w) + result: auto_linear + +- name: upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + self: upsample_nearest3d_backward_symint(grad, output_size, self.sym_sizes(), scales_d, scales_h, scales_w) + result: auto_linear + +- name: _upsample_nearest_exact3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? 
scales_w=None) -> Tensor + self: _upsample_nearest_exact3d_backward_symint(grad, output_size, self.sym_sizes(), scales_d, scales_h, scales_w) + result: auto_linear + +- name: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor + self: pixel_unshuffle(grad, upscale_factor) + result: auto_linear + +- name: pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor + self: pixel_shuffle(grad, downscale_factor) + result: auto_linear + +- name: channel_shuffle(Tensor self, SymInt groups) -> Tensor + self: channel_shuffle_symint(grad, grad.sym_size(1) / groups) + result: auto_linear + +- name: _adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor + self: _adaptive_avg_pool2d_backward(grad, self) + result: auto_linear + +- name: _adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor + self: _adaptive_avg_pool3d_backward(grad, self) + result: auto_linear + +- name: adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor) + self: adaptive_max_pool2d_backward(grad, self, result1) + result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1) + output_differentiability: [True, False] + +- name: adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor) + self: adaptive_max_pool3d_backward(grad, self, result1) + result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1) + output_differentiability: [True, False] + +- name: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor + self: avg_pool2d_backward(grad, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + result: auto_linear + +- name: avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor + self: avg_pool3d_backward(grad, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + result: auto_linear + +- name: fractional_max_pool2d(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples) -> (Tensor, Tensor) + self: fractional_max_pool2d_backward(grad, self, kernel_size, output_size, result1) + result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1) + output_differentiability: [True, False] + +- name: fractional_max_pool3d(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples) -> (Tensor, Tensor) + self: fractional_max_pool3d_backward(grad, self, kernel_size, output_size, result1) + result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1) + output_differentiability: [True, False] + +- name: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor + input, weight, bias: "grad.defined() ? linear_backward(input, grad, weight, grad_input_mask) : std::tuple()" + +- name: linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + self, grad_output, weight: linear_double_backward(grads, self, grad_output, weight) + +#mps +- name: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + self: max_pool2d_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode) + +- name: _mps_convolution(Tensor self, Tensor weight, Tensor? 
bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + self, weight, bias: "grad.defined() ? mps_convolution_backward_symint(self, grad, weight, padding, stride, dilation, groups, grad_input_mask) : std::tuple()" + +- name: mps_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + grad_output, self, weight: _convolution_double_backward_symint(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector(padding.size(), 0), groups, grad_input_mask) + +- name: max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) + self: max_pool2d_with_indices_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, result1) + result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1) + output_differentiability: [True, False] + +- name: max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) + self: max_pool3d_with_indices_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, result1) + result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1) + output_differentiability: [True, False] + +- name: max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor + self: max_pool_double_backward(grad, indices, 2) + indices: non_differentiable + result: auto_linear + +- name: max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor + self: max_pool_double_backward(grad, indices, 3) + indices: non_differentiable + result: auto_linear + +- name: convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor + input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple()" + result: convolution_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, stride, padding, dilation, transposed, output_padding, groups) + +# TorchScript serializes calls to _convolution so this entry is present until that is changed to use convolution. +# Note that the benchmark, deterministic, cudnn_enabled, and allow_tf32 flags are queried from the global context +# by convolution_backward instead of being passed along from the forward pass. +- name: _convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor + input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple()" + result: _convolution_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32) + +- name: convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? 
bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + grad_output, input, weight: _convolution_double_backward_symint(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) + result0: std::get<0>(convolution_backward_symint(grad_output_p, input_p, weight_t, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {true, false, false})) + std::get<0>(convolution_backward_symint(grad_output_t, input_p, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {true, false, false})) + result1: std::get<1>(convolution_backward_symint(grad_output_p, input_t, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {false, true, false})) + std::get<1>(convolution_backward_symint(grad_output_t, input_p, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {false, true, false})) + result2: convolution_backward_jvp_grad_bias(grad_output_t, result2) + +- name: convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor + input, weight, bias: "grad.defined() ? convolution_backward_overrideable_symint(grad, input, weight, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple()" + +- name: convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) + grad_output, input, weight: _convolution_double_backward_symint(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) + +- name: slow_conv_transpose2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple()" + +- name: slow_conv_transpose3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple()" + +- name: _slow_conv2d_forward(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding) -> Tensor + self, weight, bias: "grad.defined() ? 
_slow_conv2d_backward_symint(grad, self, weight, kernel_size, stride, padding, grad_input_mask) : std::tuple()" + +- name: _slow_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) + grad_output, self, weight: _convolution_double_backward_symint(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1}}, false, {{0, 0}}, 1, grad_input_mask) + +- name: _conv_depthwise2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad.contiguous(), self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0}}, /*groups=*/ 1, grad_input_mask) : std::tuple()" + +- name: conv_depthwise3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad.contiguous(), self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0, 0}}, /*groups=*/ 1, grad_input_mask) : std::tuple()" + +- name: slow_conv3d_forward(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, /*dilation=*/ {{1, 1, 1}}, false, /*output_padding=*/ {{0, 0, 0}}, 1, grad_input_mask) : std::tuple()" + +- name: slow_conv_dilated2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()" + +- name: slow_conv_dilated3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1) -> Tensor + self, weight, bias: "grad.defined() ? 
convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()" + +- name: col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor + self: im2col(grad, kernel_size, dilation, padding, stride) + result: auto_linear + +- name: im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor + self: col2im_symint(grad, {self.sym_size(-2), self.sym_size(-1)}, kernel_size, dilation, padding, stride) + result: auto_linear + +- name: _adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor + grad_output: _adaptive_avg_pool2d_symint(grad, {grad_output.sym_size(-2), grad_output.sym_size(-1)}) + self: zeros_like(self) + result: _adaptive_avg_pool2d_backward(grad_output_t, self_p) + +- name: _adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self) -> Tensor + grad_output: _adaptive_avg_pool3d_symint(grad, { grad_output.sym_size(-3), grad_output.sym_size(-2), grad_output.sym_size(-1) }) + self: zeros_like(self) + result: _adaptive_avg_pool3d_backward(grad_output_t, self_p) + +- name: adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor + grad_output: max_pool_double_backward(grad, indices, 2) + self: zeros_like(self) + result: auto_linear + +- name: adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor + grad_output: max_pool_double_backward(grad, indices, 3) + self: zeros_like(self) + result: auto_linear + +- name: avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor + grad_output: avg_pool2d(grad, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + self: zeros_like(self) + result: avg_pool2d_backward(grad_output_t, self_p, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + +- name: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? 
divisor_override) -> Tensor + grad_output: avg_pool3d(grad, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + self: zeros_like(self) + result: avg_pool3d_backward(grad_output_t, self_p, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + +- name: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor + grad_output: elu_backward(grad, alpha, scale, input_scale, is_result, self_or_result) + self_or_result: elu_double_backward(grad, grad_output, alpha, scale, input_scale, is_result, self_or_result) + result: elu_backward(grad_output_t, alpha, scale, input_scale, is_result, self_or_result_p) + elu_double_backward(self_or_result_t, grad_output_p, alpha, scale, input_scale, is_result, self_or_result_p) + +- name: fractional_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices) -> Tensor + grad_output: max_pool_double_backward(grad, indices, 2) + self: zeros_like(self) + result: auto_linear + +- name: fractional_max_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices) -> Tensor + grad_output: max_pool_double_backward(grad, indices, 3) + self: zeros_like(self) + result: auto_linear + +- name: glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor + grad_output: glu_double_backward_grad_output(grad, self, dim) + self: glu_double_backward(grad, grad_output, self, dim) + result: glu_backward_jvp(result, grad_output_p, self_p, grad_output_t, self_t, dim) + +- name: hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor + grad_output: hardtanh_backward(grad, self, min_val, max_val) + self: zeros_like(grad) + result: at::where((self_p > min_val).logical_and(self_p < max_val), grad_output_t, at::zeros({}, result.options()).expand_as(result)) + +- name: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor + grad_output: log_sigmoid_backward(grad, self, buffer) + self: log_sigmoid_double_backward(grad * grad_output, self) + result: log_sigmoid_backward(grad_output_t, self_p, buffer) + log_sigmoid_double_backward(self_t * grad_output_p, self_p) + +- name: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor + grad_output: grad.to(output.dtype()) - (grad.to(output.dtype()) * output.exp()).sum(dim, true) + output: (-grad_output.sum(dim, true) * output.exp() * grad.to(output.dtype())).to(output.dtype()) + +- name: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor + # self_is_result is always false here since double backward call is an out-of-place call, self is input itself + grad_output: leaky_relu_backward(grad, self, negative_slope, false) + self: zeros_like(grad) + # leaky_relu_backward(grad_output, self, negative_slope, false) + # computes grad_output * at::where(self_p > 0, 1, negative_slope) + # so the jvp formula is the following: + # grad_output_t * at::where(self_p > 0, self_p.new_ones([]), negative_slope); + # + # leaky_relu_backward(grad_output, result, negative_slope, true) + # computes grad_output * at::where(result > 0, 1, negative_slope) + # under the assumption that `negative_slope` is positive (otherwise, + # it is not possible to compute the gradient). 
+ # + # so the jvp formula is the following: + # grad_output_t * at::where(result_p > 0, result_p.new_ones([]), negative_slope); + # with the assumption that negative_slope is positive. + # + # Combined together that results in the following optimized kernel which + # also checks the assumption that negative_slope is positive when self_is_result + # is True: + result: leaky_relu_backward(grad_output_t, self_p, negative_slope, self_is_result) + +# This derivative is mps-only, and `error_for_max_pool2d_double_backward` just raises an error. +- name: max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + grad_output: error_for_max_pool2d_double_backward() + self: zeros_like(self) + result: auto_linear + +- name: max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor + grad_output: max_pool_double_backward(grad, indices, 2) + self: zeros_like(self) + indices: non_differentiable + result: auto_linear + +- name: max_pool3d_with_indices_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices) -> Tensor + grad_output: max_pool_double_backward(grad, indices, 3) + self: zeros_like(self) + indices: non_differentiable + result: auto_linear + +- name: mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor + grad_output: mse_loss_backward(grad, self, target, reduction) + self: mse_loss_double_backward(grad * grad_output, self, reduction) + target: -mse_loss_double_backward(grad * grad_output, target, reduction) + result: " mse_loss_double_backward(self_t * grad_output_p, self_p, reduction) + - mse_loss_double_backward(target_t * grad_output_p, target_p, reduction) + + mse_loss_backward(grad_output_t, self_p, target_p, reduction) + " + +- name: nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor + grad_output: nll_loss_symint(grad, target, weight, reduction, ignore_index) + self: zeros_like(grad) + target: non_differentiable + +- name: nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? 
weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor + grad_output: nll_loss2d_symint(grad, target, weight, reduction, ignore_index) + self: zeros_like(grad) + target: non_differentiable + +- name: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor + # self_is_result is always false here since double backward call is an out-of-place call, self is input itself + grad_output: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false) + self: zeros_like(grad) + result: rrelu_with_noise_backward(grad_output_t, self_p, noise, lower, upper, training, false) + +- name: reflection_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor + grad_output: reflection_pad1d_symint(grad, padding) + self: zeros_like(self) + result: reflection_pad1d_backward_symint(grad_output_t, self_p, padding) + +- name: reflection_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor + grad_output: reflection_pad2d_symint(grad, padding) + self: zeros_like(self) + result: reflection_pad2d_backward_symint(grad_output_t, self_p, padding) + +- name: reflection_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor + grad_output: reflection_pad3d_symint(grad, padding) + self: zeros_like(self) + result: reflection_pad3d_backward_symint(grad_output_t, self_p, padding) + +- name: replication_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor + grad_output: replication_pad1d_symint(grad, padding) + self: zeros_like(self) + result: replication_pad1d_backward_symint(grad_output_t, self_p, padding) + +- name: replication_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor + grad_output: replication_pad2d_symint(grad, padding) + self: zeros_like(self) + result: replication_pad2d_backward_symint(grad_output_t, self_p, padding) + +- name: replication_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor + grad_output: replication_pad3d_symint(grad, padding) + self: zeros_like(self) + result: replication_pad3d_backward_symint(grad_output_t, self_p, padding) + +- name: sparse_sampled_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + self, mat1, mat2: "sparse_sampled_addmm_backward(grad, + self, + wrap_opt_if(mat1, grad_input_mask[2]), + wrap_opt_if(mat2, grad_input_mask[1]), + alpha, beta, grad_input_mask)" + +- name: _sparse_mm_reduce_impl(Tensor self, Tensor other, str reduce) -> (Tensor, Tensor) + output_differentiability: [True, False] + self, other: "grad.defined() ? 
_sparse_mm_reduce_impl_backward(self, grad, other, reduce, result1, grad_input_mask) : std::tuple()" + +- name: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta) -> Tensor + grad_output: smooth_l1_loss_backward(grad, self, target, reduction, beta) + self: smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction, beta) + target: -smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction, beta) + result: " smooth_l1_loss_double_backward(self_t * grad_output_p, self_p, target_p, reduction, beta) + - smooth_l1_loss_double_backward(target_t * grad_output_p, self_p, target_p, reduction, beta) + + smooth_l1_loss_backward(grad_output_t, self_p, target_p, reduction, beta) + " + +- name: huber_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta) -> Tensor + grad_output: huber_loss_double_backward_grad_output(grad, grad_output, self, target, reduction, delta) + self: huber_loss_double_backward(grad * grad_output, self, target, reduction, delta) + target: -huber_loss_double_backward(grad * grad_output, self, target, reduction, delta) + +- name: softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold) -> Tensor + grad_output: softplus_backward(grad, self, beta, threshold) + self: softplus_double_backward(grad * grad_output, self, beta, threshold) + result: "softplus_backward(grad_output_t, self_p, beta, threshold) + + softplus_double_backward(self_t * grad_output_p, self_p, beta, threshold)" + +- name: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor + grad_output: _softmax_backward_data(grad.to(output.dtype()), output, dim, input_dtype) + output: softmax_double_backward(grad.to(output.dtype()), grad_output, dim, output).to(output.dtype()) + +- name: soft_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor + grad_output: soft_margin_loss_double_backward_grad_output(grad, grad_output, self, target, reduction) + self: soft_margin_loss_double_backward(grad * grad_output, self, target, reduction) + +- name: softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) -> Tensor + grad_output: softshrink_backward(grad, self, lambd) + self: zeros_like(grad) + result: at::where((self_p > lambd).logical_or(self_p < -lambd), grad_output_t, at::zeros({}, result.options()).expand_as(result)) + +- name: threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor + grad_output: threshold_backward(grad, self, threshold) + self: zeros_like(grad) + result: zeros_like(self_t) + threshold_backward(grad_output_t, self_p, threshold) + +- name: upsample_linear1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None) -> Tensor + grad_output: upsample_linear1d_symint(grad, output_size, align_corners, scales) + result: auto_linear + +- name: upsample_bilinear2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: upsample_bilinear2d_symint(grad, output_size, align_corners, scales_h, scales_w) + result: auto_linear + +- name: _upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? 
scales_w=None) -> Tensor + grad_output: _upsample_bilinear2d_aa_symint(grad, output_size, align_corners, scales_h, scales_w) + result: auto_linear + +- name: upsample_bicubic2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: upsample_bicubic2d_symint(grad, output_size, align_corners, scales_h, scales_w) + result: auto_linear + +- name: _upsample_bicubic2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: _upsample_bicubic2d_aa_symint(grad, output_size, align_corners, scales_h, scales_w) + result: auto_linear + +- name: upsample_trilinear3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: upsample_trilinear3d_symint(grad, output_size, align_corners, scales_d, scales_h, scales_w) + result: auto_linear + +- name: upsample_nearest1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor + grad_output: upsample_nearest1d_symint(grad, output_size, scales) + result: auto_linear + +- name: _upsample_nearest_exact1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor + grad_output: _upsample_nearest_exact1d_symint(grad, output_size, scales) + result: auto_linear + +- name: upsample_nearest2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: upsample_nearest2d_symint(grad, output_size, scales_h, scales_w) + result: auto_linear + +- name: _upsample_nearest_exact2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: _upsample_nearest_exact2d_symint(grad, output_size, scales_h, scales_w) + result: auto_linear + +- name: upsample_nearest3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: upsample_nearest3d_symint(grad, output_size, scales_d, scales_h, scales_w) + result: auto_linear + +- name: _upsample_nearest_exact3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? 
scales_w=None) -> Tensor + grad_output: _upsample_nearest_exact3d_symint(grad, output_size, scales_d, scales_h, scales_w) + result: auto_linear + +- name: sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor + grad_output: sigmoid_backward(grad, output.conj()) + output: grad.conj() * grad_output * (-2 * output.conj() + 1) + result: sigmoid_backward(grad_output_t, output_p) + output_t.conj() * grad_output_p * (-2 * output_p.conj() + 1) + +- name: tanh_backward(Tensor grad_output, Tensor output) -> Tensor + grad_output: tanh_backward(grad, output.conj()) + output: grad.conj() * (-2 * output.conj() * grad_output) + result: tanh_backward(grad_output_t, output_p) + output_t.conj() * (-2 * output_p.conj() * grad_output_p) + +# cudnn +- name: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor) + log_probs: _cudnn_ctc_loss_backward(grad, result0, result1, zero_infinity) + +- name: _cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor) + log_probs: _cudnn_ctc_loss_backward(grad, result0, result1, zero_infinity) + +- name: cudnn_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor + self, weight: "_cudnn_convolution_backward(self, grad, weight, padding, output_padding, stride, dilation, true, groups, {grad_input_mask[0], grad_input_mask[1]})" + +- name: _mps_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + self, weight: "grad.defined() ? mps_convolution_transpose_backward_symint(self, grad, weight, padding, output_padding, stride, dilation, groups, grad_input_mask) : std::tuple()" + +- name: cudnn_convolution(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor + self, weight: "_cudnn_convolution_backward(self, grad, weight, padding, std::vector(padding.size(), 0), stride, dilation, false, groups, {grad_input_mask[0], grad_input_mask[1]})" + +- name: cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output + self, grid: "grad.defined() ? cudnn_grid_sampler_backward(self, grid, grad) : std::tuple()" + +- name: cudnn_affine_grid_generator(Tensor theta, int N, int C, int H, int W) -> Tensor grid + theta: cudnn_affine_grid_generator_backward(grad, N, C, H, W) + +# NB: Why is the backwards here so complicated? CuDNN cannot be used to compute +# backward in evaluation mode, because the math for backward in evaluation mode +# is different (since the forward math is different), and CuDNN does not support +# it. And in any case, you shouldn't be using this bn in evaluation mode, +# because it should be merged into the previous convolution (left for future +# work.) +# NB2: The quotes around the gradient are needed to appease YAML parsing rules. +- name: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? (training ? 
cudnn_batch_norm_backward(input, grad.contiguous(input.suggest_memory_format()), weight, running_mean, running_var, result1, result2, epsilon, retain_variables ? result3.clone() : result3) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon) + +# HACK: save_mean and save_var are going to be passed in as +# requires_grad variables (even though we'll never backprop through +# them) so we need to prevent the unpacking from triggering an error. +- name: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor) + save_mean: not_implemented("cudnn_batch_norm_backward save_mean") + save_var: not_implemented("cudnn_batch_norm_backward save_var") + reserveSpace: not_implemented("cudnn_batch_norm_backward reserveSpace") + input, weight, grad_output: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, true, epsilon, save_mean, save_var, grad_input_mask) + +# nnpack + +- name: _nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1) -> Tensor + # NNPACK does not support strided convolutions in the backwards path, which is the reason why we are using the closest available function that does here. + input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, std::vector(padding.size(), 1), false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()" + +# LSTM MPS +- name: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) + output_differentiability: [True, True, True, False, False, False] + input, hx, params: "lstm_mps_backward(grads[0], grads[1], grads[2], result3, result4, input, result5, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first)" + +- name: lstm_mps_backward(Tensor? grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[]) + + + +# Only first three of _cudnn_rnn outputs can have gradients. +# _cudnn_rnn outputs: (output, hy, cy, reserve, weight_buf) +- name: _cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + dropout_state: non_differentiable + output_differentiability: [True, True, True, False, False] + input, hx, cx, weight: "_cudnn_rnn_backward_symint(input, weight, weight_stride0, result4, hx, cx, result0, grads[0], grads[1], grads[2], mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, retain_variables ?
result3.clone() : result3, grad_input_mask)" + +- name: _cudnn_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[]) + dropout_state: non_differentiable + input: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + weight: not_implemented_list("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + hx: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + cx: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + output: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + grad_output: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + grad_hy: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + grad_cy: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + +# miopen + +- name: miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, groups, grad_input_mask) : std::tuple()" + +- name: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" + +- name: miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" + +- name: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon) + +- name: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? 
save_var, float epsilon) -> (Tensor, Tensor, Tensor) + save_mean: not_implemented("miopen_batch_norm_backward save_mean") + save_var: not_implemented("miopen_batch_norm_backward save_var") + input, weight, grad_output: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, true, epsilon, save_mean, save_var, grad_input_mask) + +- name: miopen_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + dropout_state: non_differentiable + output_differentiability: [True, True, True, False, False] + input, hx, cx, weight: "miopen_rnn_backward(input, weight, weight_stride0, result4, hx, cx, result0, grads[0], grads[1], grads[2], mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, retain_variables ? result3.clone() : result3, grad_input_mask)" + +- name: miopen_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[]) + dropout_state: non_differentiable + +- name: mkldnn_rnn_layer(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) -> (Tensor, Tensor, Tensor, Tensor) + output_differentiability: [True, True, True, False] + input, weight0, weight1, weight2, weight3, hx_, cx_: "GradMode::is_enabled() ? mkldnn_rnn_layer_differentiable_backward(input, weight0, weight1, weight2, weight3, hx_, cx_, result0, result1, result2, grads[0], grads[1], grads[2], reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, batch_first, result3) : mkldnn_rnn_layer_backward(input, weight0, weight1, weight2, weight3, hx_, cx_, result0, result1, result2, grads[0], grads[1], grads[2], reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, batch_first, result3)" + +- name: mkldnn_rnn_layer_backward(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) + +# mkldnn +- name: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" + +- name: mkldnn_linear(Tensor self, Tensor weight, Tensor? 
bias=None) -> Tensor + self, weight, bias: mkldnn_linear_backward(self, grad, weight, grad_input_mask) + +- name: mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + self: mkldnn_max_pool2d_backward(grad, result, self, kernel_size, stride, padding, dilation, ceil_mode) + +- name: mkldnn_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor + self: mkldnn_max_pool3d_backward(grad, result, self, kernel_size, stride, padding, dilation, ceil_mode) + +- name: mkldnn_adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor + self: mkldnn_adaptive_avg_pool2d_backward(grad, self) + +- name: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor + self: grad.reshape_symint(self.sym_sizes()) + +# NestedTensor +- name: _nested_tensor_from_tensor_list(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + list: "grad.defined()? at::unbind(grad) : std::vector(list.size())" + +- name: _nested_tensor_from_mask(Tensor t, Tensor mask, bool mask_check=True) -> Tensor + t: grad.to_padded_tensor_symint(0, t.sym_sizes()) + mask: non_differentiable + +- name: _nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> Tensor + padded: _nested_from_padded_backward(grad, padded, fuse_transform_0213) + cpu_nested_shape_example: non_differentiable + +- name: to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor + self: at::_nested_from_padded(grad, self._nested_tensor_size()) + padding: non_differentiable + +- name: _nested_view_from_buffer(Tensor(a) self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor(a) + self: grad.values() + nested_size: non_differentiable + nested_strides: non_differentiable + +- name: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor(a) + self: grad.values() + offsets: non_differentiable + lengths: non_differentiable + dummy: non_differentiable + +- name: _nested_get_values(Tensor(a) self) -> Tensor(a) + self: "_nested_view_from_jagged(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_lengths(self), at::_nested_get_ragged_idx(self), at::_nested_get_min_seqlen(self).defined() ? std::optional(at::_nested_get_min_seqlen(self)) : ::std::nullopt, at::_nested_get_max_seqlen(self).defined() ? std::optional(at::_nested_get_max_seqlen(self)) : ::std::nullopt)" + +# Transformer +- name: _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor + self: _softmax_backward_data(grad, result, dim, self.scalar_type()) + result: result * (self_t - safe_logsumexp_jvp(self_p, self_t, {dim}, true)) + +- name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? 
scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset) + output_differentiability: [True, False, False, False] + query, key, value, attn_bias: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, attn_bias, output, log_sumexp, philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, scale) + +- name: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + output_differentiability: [True, False, False, False, False, False, False, False, False] + query, key, value: _scaled_dot_product_flash_attention_backward_symint(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale) + +- name: _scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor output, Tensor logsumexp) + output_differentiability: [True, False] + query, key, value: _scaled_dot_product_flash_attention_for_cpu_backward(grad, query, key, value, output, logsumexp, dropout_p, is_causal, attn_mask, scale) + +- name: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + output_differentiability: [True, False, False, False, False] + query, key, value: _flash_attention_backward_symint(grad, query, key, value, output, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale, window_size_left, window_size_right) + +- name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? seqlen_k=None, int? window_size=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k) + output_differentiability: [True, False, False, False, False, False] + query, key, value, bias: _efficient_attention_backward_symint(grad, query, key, value, bias, output, cu_seqlens_q, cu_seqlens_k, max_seqlen_batch_q, max_seqlen_batch_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias.requires_grad(), scale) + +- name: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? 
scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + output_differentiability: [True, False, False, False, False, False, False, False, False] + query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale) + +- name: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + output_differentiability: [True, False, False, False, False, False, False, False, False] + query, key, value, attn_bias: _scaled_dot_product_fused_attention_overrideable_backward_symint(grad, query, key, value, attn_bias, grad_input_mask, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale) + +# fft +- name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor + self: fft_r2c_backward(grad, dim, normalization, onesided, self.sym_size(dim.back())) + result: auto_linear + +- name: _fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor + self: fft_c2r_backward(grad, dim, normalization) + result: auto_linear + +- name: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor + self: _fft_c2c_symint(grad, dim, normalization, !forward) + result: auto_linear + +- name: unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[] + dispatch: + Default: + self: unbind_backward(grads, dim) + result: auto_linear + AutogradNestedTensor: + self: "self.layout() == c10::kJagged ? unbind_backward_nested_jagged(grads, self, dim) : unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options())" + result: auto_linear + +- name: stack(Tensor[] tensors, int dim=0) -> Tensor + tensors: stack_tensors_backward(grad, dim, to_args_scalartypes(tensors)) + result: stack_jvp(tensors, dim) + +# fused RNN kernels + +# Only frst two of _thnn_fused_lstm_cell outputs can have gradients. +# _thnn_fused_lstm_cell outputs: (hy, cy, workspace) +- name: _thnn_fused_lstm_cell(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor, Tensor) + output_differentiability: [True, True, False] + input_gates, hidden_gates, cx, input_bias, hidden_bias: "GradMode::is_enabled() ? _thnn_differentiable_lstm_cell_backward(grads[0], grads[1], input_gates, hidden_gates, input_bias, hidden_bias, cx, result1) : _thnn_fused_lstm_cell_backward(grads[0], grads[1], cx, result1, result2, input_bias.defined())" + +- name: _thnn_fused_gru_cell(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor) + input_gates, hidden_gates, hx, input_bias, hidden_bias: "grad.defined() ? (GradMode::is_enabled() ? 
_thnn_differentiable_gru_cell_backward(grad, input_gates, hidden_gates, hx, input_bias, hidden_bias) : _thnn_fused_gru_cell_backward(grad, result1, input_bias.defined())) : std::tuple()" + +# PackedSequence helpers +- name: _pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor) + input: _pack_padded_sequence_backward_symint(grad, input.sym_sizes(), result1, batch_first) + +# TH wrappers +- name: eq.Scalar(Tensor self, Scalar other) -> Tensor + output_differentiability: [False] + +- name: eq.Tensor(Tensor self, Tensor other) -> Tensor + output_differentiability: [False] + +- name: ge.Scalar(Tensor self, Scalar other) -> Tensor + output_differentiability: [False] + +- name: ge.Tensor(Tensor self, Tensor other) -> Tensor + output_differentiability: [False] + +- name: gt.Scalar(Tensor self, Scalar other) -> Tensor + output_differentiability: [False] + +- name: gt.Tensor(Tensor self, Tensor other) -> Tensor + output_differentiability: [False] + +- name: le.Scalar(Tensor self, Scalar other) -> Tensor + output_differentiability: [False] + +- name: le.Tensor(Tensor self, Tensor other) -> Tensor + output_differentiability: [False] + +- name: lt.Scalar(Tensor self, Scalar other) -> Tensor + output_differentiability: [False] + +- name: lt.Tensor(Tensor self, Tensor other) -> Tensor + output_differentiability: [False] + +- name: ne.Scalar(Tensor self, Scalar other) -> Tensor + output_differentiability: [False] + +- name: ne.Tensor(Tensor self, Tensor other) -> Tensor + output_differentiability: [False] + +- name: multinomial(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor + output_differentiability: [False] + +- name: nonzero(Tensor self) -> Tensor + output_differentiability: [False] + +- name: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor + data: _segment_reduce_backward(grad, result, data, reduce, lengths, offsets, axis, initial) + +- name: _pin_memory(Tensor self, Device? device=None) -> Tensor + self: grad + +- name: _new_zeros_with_same_feature_meta(Tensor self, Tensor other, *, int self_num_batch_dims=0) -> Tensor + self: non_differentiable + other: non_differentiable + output_differentiability: [False] + +- name: _test_warn_in_autograd(Tensor self) -> Tensor + self: warn_backwards(grad) + +- name: _test_autograd_multiple_dispatch.fullcoverage(Tensor self) -> Tensor + dispatch: + Default: + self: grad.expand_symint(self.sym_sizes()) + 1 + result: auto_linear + AutogradNestedTensor: + self: grad.mul(grad) + AutogradCUDA: + self: grad.expand_symint(self.sym_sizes()) * 2 + +- name: _test_autograd_multiple_dispatch.ntonly(Tensor self, bool b) -> Tensor + dispatch: + AutogradNestedTensor: + self: grad.mul(grad).add(grad) + +- name: _test_autograd_multiple_dispatch_view(Tensor(a) self) -> Tensor(a) + dispatch: + Default: + self: grad.reshape_as(self) + AutogradCUDA: + self: grad.reshape_as(self) + 1 + +- name: _efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + output_differentiability: [False] + +- name: scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor + self, src: scatter_reduce_backward(grad, self, dim, index, src, reduce, include_self, result) + index: non_differentiable + result: scatter_reduce_jvp(self_p, self_t, dim, index, src_p, src_t, reduce, include_self, result) + +- name: special_airy_ai(Tensor x) -> Tensor + x: non_differentiable + +- name: special_bessel_j0(Tensor self) -> Tensor + self: non_differentiable + +- name: special_bessel_j1(Tensor self) -> Tensor + self: non_differentiable + +- name: special_bessel_y0(Tensor self) -> Tensor + self: non_differentiable + +- name: special_bessel_y1(Tensor self) -> Tensor + self: non_differentiable + +- name: special_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_hermite_polynomial_h(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_hermite_polynomial_h.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_hermite_polynomial_h.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_hermite_polynomial_he(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_hermite_polynomial_he.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_hermite_polynomial_he.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_laguerre_polynomial_l(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_laguerre_polynomial_l.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_laguerre_polynomial_l.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_legendre_polynomial_p(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_legendre_polynomial_p.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_legendre_polynomial_p.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_modified_bessel_i0(Tensor self) -> Tensor + self: non_differentiable + +- name: 
special_modified_bessel_i1(Tensor self) -> Tensor + self: non_differentiable + +- name: special_modified_bessel_k0(Tensor self) -> Tensor + self: non_differentiable + +- name: special_modified_bessel_k1(Tensor self) -> Tensor + self: non_differentiable + +- name: special_scaled_modified_bessel_k0(Tensor x) -> Tensor + x: non_differentiable + +- name: special_scaled_modified_bessel_k1(Tensor x) -> Tensor + x: non_differentiable + +- name: special_shifted_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_shifted_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_shifted_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_shifted_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_spherical_bessel_j0(Tensor x) -> Tensor + x: non_differentiable + +- name: _reshape_copy(Tensor self, SymInt[] size) -> Tensor + self: grad.reshape_symint(self.sym_sizes()) + result: auto_linear + +# note(crcrpar): `torchgen/api/autograd` logic would unwantedly replace substrings of `self` and `other` of function names. +- name: _foreach_div.List(Tensor[] self, Tensor[] other) -> Tensor[] + self: div_tensor_self_backward(grads[i], other[i], self[i].scalar_type()) + other: div_tensor_other_backward(grads[i], self[i], other[i]) + result: (self_t - other_t * result[i]) / other_p + +- name: _foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[] + self: pow_backward_self(grads[i], self[i], exponent[i]) + exponent: pow_backward_exponent(grads[i], self[i], exponent[i], result[i]) + result: (pow_backward_self(self_t.conj(), self_p, exponent_p) + pow_backward_exponent(exponent_t.conj(), self_p, exponent_p, result[i])).conj() + +- name: _foreach_pow.ScalarList(Tensor[] self, Scalar[] exponent) -> Tensor[] + self: pow_backward(grads[i], self[i], exponent[i]) + result: pow_backward(self_t.conj(), self_p, exponent[i]).conj() + +- name: _foreach_pow.ScalarAndTensor(Scalar self, Tensor[] exponent) -> Tensor[] + exponent: pow_backward_exponent(grads[i], self, exponent[i], result[i]) + +# note(crcrpar): following definitions seem necessary because the reference native functions +# of `maximum` and `minimum` don't have the overload def with Scalar as their second argument. 
+- name: _foreach_minimum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + self: at::where(self[i] == scalar, grads[i] / 2, grads[i]).masked_fill_(self[i] > scalar, 0) + result: scalar + at::where(self_p == scalar, at::scalar_tensor(0.5, result[i].options()), (self_p < scalar).to(result[i].scalar_type())) * (self_t - scalar) + +- name: _foreach_minimum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + self: at::where(self[i] == scalars[i], grads[i] / 2, grads[i]).masked_fill_(self[i] > scalars[i], 0) + result: scalars[i] + at::where(self_p == scalars[i], at::scalar_tensor(0.5, result[i].options()), (self_p < scalars[i]).to(result[i].scalar_type())) * (self_t - scalars[i]) + +- name: _foreach_maximum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + self: at::where(self[i] == scalar, grads[i] / 2, grads[i]).masked_fill_(self[i] < scalar, 0) + result: scalar + at::where(self_p == scalar, at::scalar_tensor(0.5, result[i].options()), (self_p > scalar).to(result[i].scalar_type())) * (self_t - scalar) + +- name: _foreach_maximum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + self: at::where(self[i] == scalars[i], grads[i] / 2, grads[i]).masked_fill_(self[i] < scalars[i], 0) + result: scalars[i] + at::where(self_p == scalars[i], at::scalar_tensor(0.5, result[i].options()), (self_p > scalars[i]).to(result[i].scalar_type())) * (self_t - scalars[i]) + +# note(crcrpar): forward-mode AD is tricky for a simple string replace to handle: +# formula.replace("p", "ord") produces `norm_jvord(self_ord, self_t, ord, result)` +- name: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2, ScalarType? dtype=None) -> Tensor[] + self: norm_backward(grads[i], self[i], ord, result[i]) + result: norm_jvp(self_p, self_t, ord, result[i]) diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_annotated_fn_args.py b/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_annotated_fn_args.py new file mode 100644 index 0000000000000000000000000000000000000000..c32779b3a2825e82d18a57bdeea76c47707e4284 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_annotated_fn_args.py @@ -0,0 +1,132 @@ +""" +For procedural tests needed for __torch_function__, we use this function +to export method names and signatures as needed by the tests in +test/test_overrides.py. + +python -m tools.autograd.gen_annotated_fn_args \ + aten/src/ATen/native/native_functions.yaml \ + aten/src/ATen/native/tags.yaml \ + $OUTPUT_DIR \ + tools/autograd + +Where $OUTPUT_DIR is where you would like the files to be +generated. 
In the full build system, OUTPUT_DIR is +torch/testing/_internal/generated +""" + +from __future__ import annotations + +import argparse +import os +import textwrap +from collections import defaultdict +from typing import Any, Sequence, TYPE_CHECKING + +import torchgen.api.python as python +from torchgen.context import with_native_function +from torchgen.gen import parse_native_yaml +from torchgen.utils import FileManager + +from .gen_python_functions import ( + is_py_fft_function, + is_py_linalg_function, + is_py_nn_function, + is_py_special_function, + is_py_torch_function, + is_py_variable_method, + should_generate_py_binding, +) + + +if TYPE_CHECKING: + from torchgen.model import Argument, BaseOperatorName, NativeFunction + + +def gen_annotated( + native_yaml_path: str, tags_yaml_path: str, out: str, autograd_dir: str +) -> None: + native_functions = parse_native_yaml( + native_yaml_path, tags_yaml_path + ).native_functions + mappings = ( + (is_py_torch_function, "torch._C._VariableFunctions"), + (is_py_nn_function, "torch._C._nn"), + (is_py_linalg_function, "torch._C._linalg"), + (is_py_special_function, "torch._C._special"), + (is_py_fft_function, "torch._C._fft"), + (is_py_variable_method, "torch.Tensor"), + ) + annotated_args: list[str] = [] + for pred, namespace in mappings: + groups: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list) + for f in native_functions: + if not should_generate_py_binding(f) or not pred(f): + continue + groups[f.func.name.name].append(f) + for group in groups.values(): + for f in group: + annotated_args.append(f"{namespace}.{gen_annotated_args(f)}") + + template_path = os.path.join(autograd_dir, "templates") + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + fm.write_with_template( + "annotated_fn_args.py", + "annotated_fn_args.py.in", + lambda: { + "annotated_args": textwrap.indent("\n".join(annotated_args), " "), + }, + ) + + +@with_native_function +def gen_annotated_args(f: NativeFunction) -> str: + def _get_kwargs_func_exclusion_list() -> list[str]: + # functions that currently don't work with kwargs in test_overrides.py + return [ + "diagonal", + "round_", + "round", + "scatter_", + ] + + def _add_out_arg( + out_args: list[dict[str, Any]], args: Sequence[Argument], *, is_kwarg_only: bool + ) -> None: + for arg in args: + if arg.default is not None: + continue + out_arg: dict[str, Any] = {} + out_arg["is_kwarg_only"] = str(is_kwarg_only) + out_arg["name"] = arg.name + out_arg["simple_type"] = python.argument_type_str( + arg.type, simple_type=True + ) + size_t = python.argument_type_size(arg.type) + if size_t: + out_arg["size"] = size_t + out_args.append(out_arg) + + out_args: list[dict[str, Any]] = [] + _add_out_arg(out_args, f.func.arguments.flat_positional, is_kwarg_only=False) + if f"{f.func.name.name}" not in _get_kwargs_func_exclusion_list(): + _add_out_arg(out_args, f.func.arguments.flat_kwarg_only, is_kwarg_only=True) + + return f"{f.func.name.name}: {repr(out_args)}," + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate annotated_fn_args script") + parser.add_argument( + "native_functions", metavar="NATIVE", help="path to native_functions.yaml" + ) + parser.add_argument("tags", metavar="TAGS", help="path to tags.yaml") + parser.add_argument("out", metavar="OUT", help="path to output directory") + parser.add_argument( + "autograd", metavar="AUTOGRAD", help="path to template directory" + ) + args = parser.parse_args() + gen_annotated(args.native_functions, args.tags, 
args.out, args.autograd) + + +if __name__ == "__main__": + main() diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_autograd.py b/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_autograd.py new file mode 100644 index 0000000000000000000000000000000000000000..f6e7be149ad6d8043126a9420217eb3bfe4d42e6 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_autograd.py @@ -0,0 +1,147 @@ +""" +To run this file by hand from the root of the PyTorch +repository, run: + +python -m tools.autograd.gen_autograd \ + aten/src/ATen/native/native_functions.yaml \ + aten/src/ATen/native/tags.yaml \ + $OUTPUT_DIR \ + tools/autograd + +Where $OUTPUT_DIR is where you would like the files to be +generated. In the full build system, OUTPUT_DIR is +torch/csrc/autograd/generated/ +""" + +# gen_autograd.py generates C++ autograd functions and Python bindings. +# +# It delegates to the following scripts: +# +# gen_autograd_functions.py: generates subclasses of torch::autograd::Node +# gen_variable_type.py: generates VariableType.h which contains all tensor methods +# gen_python_functions.py: generates Python bindings to THPVariable +# + +from __future__ import annotations + +import argparse +import os + +from torchgen.api import cpp +from torchgen.api.autograd import ( + match_differentiability_info, + NativeFunctionWithDifferentiabilityInfo, +) +from torchgen.gen import parse_native_yaml +from torchgen.selective_build.selector import SelectiveBuilder + +from . import gen_python_functions +from .gen_autograd_functions import ( + gen_autograd_functions_lib, + gen_autograd_functions_python, +) +from .gen_inplace_or_view_type import gen_inplace_or_view_type +from .gen_trace_type import gen_trace_type +from .gen_variable_factories import gen_variable_factories +from .gen_variable_type import gen_variable_type +from .gen_view_funcs import gen_view_funcs +from .load_derivatives import load_derivatives + + +def gen_autograd( + native_functions_path: str, + tags_path: str, + out: str, + autograd_dir: str, + operator_selector: SelectiveBuilder, + disable_autograd: bool = False, +) -> None: + # Parse and load derivatives.yaml + differentiability_infos, used_dispatch_keys = load_derivatives( + os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path, tags_path + ) + + template_path = os.path.join(autograd_dir, "templates") + + native_funcs = parse_native_yaml(native_functions_path, tags_path).native_functions + fns = sorted( + filter( + operator_selector.is_native_function_selected_for_training, native_funcs + ), + key=lambda f: cpp.name(f.func), + ) + fns_with_diff_infos: list[ + NativeFunctionWithDifferentiabilityInfo + ] = match_differentiability_info(fns, differentiability_infos) + + # Generate VariableType.h/cpp + if not disable_autograd: + gen_variable_type( + out, + native_functions_path, + tags_path, + fns_with_diff_infos, + template_path, + used_dispatch_keys, + ) + + gen_inplace_or_view_type( + out, native_functions_path, tags_path, fns_with_diff_infos, template_path + ) + + # operator filter not applied as tracing sources are excluded in selective build + gen_trace_type(out, native_funcs, template_path) + # Generate Functions.h/cpp + gen_autograd_functions_lib(out, differentiability_infos, template_path) + + # Generate variable_factories.h + gen_variable_factories(out, native_functions_path, tags_path, template_path) + + # Generate ViewFuncs.h/cpp + gen_view_funcs(out, fns_with_diff_infos, template_path) + + +def gen_autograd_python( + 
native_functions_path: str, + tags_path: str, + out: str, + autograd_dir: str, +) -> None: + differentiability_infos, _ = load_derivatives( + os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path, tags_path + ) + + template_path = os.path.join(autograd_dir, "templates") + + # Generate Functions.h/cpp + gen_autograd_functions_python(out, differentiability_infos, template_path) + + # Generate Python bindings + deprecated_path = os.path.join(autograd_dir, "deprecated.yaml") + gen_python_functions.gen( + out, native_functions_path, tags_path, deprecated_path, template_path + ) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate autograd C++ files script") + parser.add_argument( + "native_functions", metavar="NATIVE", help="path to native_functions.yaml" + ) + parser.add_argument("tags", metavar="NATIVE", help="path to tags.yaml") + parser.add_argument("out", metavar="OUT", help="path to output directory") + parser.add_argument( + "autograd", metavar="AUTOGRAD", help="path to autograd directory" + ) + args = parser.parse_args() + gen_autograd( + args.native_functions, + args.tags, + args.out, + args.autograd, + SelectiveBuilder.get_nop_selector(), + ) + + +if __name__ == "__main__": + main() diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_autograd_functions.py b/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_autograd_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..785ea68315b7621528ecfc20a204ef660a083880 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_autograd_functions.py @@ -0,0 +1,925 @@ +# Generates C++ autograd functions for the derivatives of ATen operations +# +# This writes two files: +# Functions.h/cpp: subclasses of autograd::Node +# python_functions.h/cpp: Python bindings for the above classes +# + +from __future__ import annotations + +from typing import Sequence + +from torchgen.api.autograd import ( + Derivative, + DifferentiabilityInfo, + SavedAttribute, + uses_retain_variables, + uses_single_grad, +) +from torchgen.api.types import ( + ArrayRefCType, + BaseCppType, + BaseCType, + Binding, + boolT, + doubleT, + intArrayRefT, + iTensorListRefT, + ListCType, + longT, + MutRefCType, + OptionalCType, + optionalIntArrayRefT, + optionalSymIntArrayRefT, + scalarT, + stringT, + symIntArrayRefT, + SymIntT, + TENSOR_LIST_LIKE_CTYPES, + tensorListT, + tensorT, + VectorCType, +) +from torchgen.code_template import CodeTemplate +from torchgen.model import Argument, FunctionSchema +from torchgen.utils import FileManager + +from .gen_inplace_or_view_type import VIEW_FUNCTIONS + + +FUNCTION_DECLARATION = CodeTemplate( + """\ +#ifdef _WIN32 +struct ${op} : public ${superclass} { + TORCH_API ${op}() = default; +#else +struct TORCH_API ${op} : public ${superclass} { +#endif + using ${superclass}::${superclass}; + variable_list apply(variable_list&& grads) override; + std::string name() const override { return "${op}"; } + void release_variables() override { + ${thread_lock} + ${release_variables} + } + ${will_release_variables} + void compiled_args(CompiledNodeArgs& args) override; + variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override; + ${saved_variables} + ${saved_list_sizes} +}; +""" +) + +WILL_RELEASE_VARIABLES = CodeTemplate( + """\ +bool retain_variables = true; +void will_release_variables() override { + retain_variables = false; +} +""" +) + +FUNCTION_DEFINITION = CodeTemplate( + """\ +variable_list 
${op}::apply(variable_list&& grads) { + ${thread_lock} + ${asserts} + IndexRangeGenerator gen; + ${compute_index_ranges} + variable_list grad_inputs(gen.size()); + ${body} + return grad_inputs; +} +void ${op}::compiled_args(CompiledNodeArgs& args) { + ${compiled_args} +} +variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) { + ${apply_with_saved_before} + variable_list result = apply(variable_list(grads)); + ${apply_with_saved_after} + return result; +} +""" +) + +GRAD_INPUT_MASK = CodeTemplate( + """\ + auto grad_input_mask = std::array{ + ${masks} + };\ +""" +) + +DERIVATIVE_SINGLE = CodeTemplate( + """\ +if (task_should_compute_output({ ${name}_ix })) { + auto grad_result = ${derivative}; + copy_range(grad_inputs, ${name}_ix, grad_result); +} +""" +) + +# note(crcrpar): `self` argument and other optional positional argument +# of foreach functions are basically a list of n `Tensor`s thus iterating over +# `grads` in order to utilize and apply the existing derivative definitions +# to each `Tensor`(s) of `self`, and the others. +DERIVATIVE_SINGLE_FOREACH = CodeTemplate( + """\ +if (task_should_compute_output({ ${name}_ix })) { + std::vector grad_result; + grad_result.reserve(grads.size()); + for (const auto & i : c10::irange(grads.size())) { + if (grads[i].defined()) { + grad_result.emplace_back(${derivative}); + } else { + grad_result.emplace_back(Tensor()); + } + } + copy_range(grad_inputs, ${name}_ix, grad_result); +} +""" +) + +DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate( + """\ + if (task_should_compute_output({ ${name}_ix })) { + copy_range(grad_inputs, ${name}_ix, std::get<${i}>(grad_result)); + } +""" +) + +DERIVATIVE_MULTI = CodeTemplate( + """\ +if (task_should_compute_output({ ${idx_ranges} })) { + ${grad_input_mask} + auto grad_result = ${derivative}; + ${copy_ranges} +} +""" +) + +# Generates python bindings +# +# This generates the definitions for: +# (1) The PyTypeObject for each backward grad_fn subclassing Node +# (2) The entry for PyTypeObject's tp_getset slot (an array of PyGetSetDef structs) +# We generate one PyGetSetDef struct for each of grad_fn's saved inputs and outputs +# Each PyGetSetDef has a function ptr to a getter, also defined here (3). +# (3) Getters for each of grad_fn's saved inputs and outputs. 
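Reviewer note, not part of the generated file: the PyGetSetDef getters emitted by this codegen surface on the Python side as read-only `_saved_*` (and `_raw_saved_*`) attributes of a tensor's `grad_fn`. A minimal sketch of how they look from user code, assuming a stock PyTorch build where `pow` saves `self` and `exponent` per its derivatives.yaml entry:

import torch

x = torch.tensor([2.0, 3.0], requires_grad=True)
y = x.pow(3)

print(type(y.grad_fn).__name__)    # PowBackward0, a generated Node subclass
print(y.grad_fn._saved_self)       # saved input tensor, unpacked from its SavedVariable
print(y.grad_fn._saved_exponent)   # 3, returned through the Scalar getter body
# y.grad_fn._raw_saved_self exposes the SavedVariable itself (the _raw_ getters below)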
+# +PY_FUNCTION_DEFINITION = CodeTemplate( + """\ +static PyTypeObject ${op}Class; +addClass<${op}>(module, ${op}Class, "${op}", ${op}_properties); +""" +) + +PY_FUNCTION_PROPS_AND_GETTERS = CodeTemplate( + """\ +${all_getter_definitions} + +static struct PyGetSetDef ${op}_properties[] = { + THP_FUNCTION_DEFAULT_PROPERTIES, + ${all_getsetdef_structs} + {nullptr} /* sentinel */ +}; + +""" +) + +PY_GETSETDEF_STRUCT = CodeTemplate( + """\ +{(char*)"_saved_${name}", (getter)THP${op}_${name}_getter, nullptr, nullptr, nullptr}""" +) + +PY_RAW_GETSETDEF_STRUCT = CodeTemplate( + """\ +{(char*)"_raw_saved_${name}", (getter)THP${op}_${name}_raw_getter, nullptr, nullptr, nullptr}""" +) + +# Getter templates +GETTER_DEFINITION = CodeTemplate( + """\ +PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { + HANDLE_TH_ERRORS + auto prop = static_cast<${op}*>(self->cdata.get())->${name}; + ${body} + END_HANDLE_TH_ERRORS +} +""" +) + +GETTER_DEFINITION_SAVEDVAR = CodeTemplate( + """\ +PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { + HANDLE_TH_ERRORS + const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_; + ${body} + END_HANDLE_TH_ERRORS +} +""" +) + +GETTER_DEFINITION_RAW_SAVEDVAR = CodeTemplate( + """\ +PyObject* THP${op}_${name}_raw_getter(THPCppFunction *self, void *_unused) { + HANDLE_TH_ERRORS + const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_; + ${body} + END_HANDLE_TH_ERRORS +} +""" +) + +GETTER_DEFINITION_VEC_SAVEDVAR = CodeTemplate( + """\ +PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { + HANDLE_TH_ERRORS + const auto *node = static_cast<${op}*>(self->cdata.get()); + const auto& prop = node->${name}_; + if (node->${name}_released_) { + PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE); + return nullptr; + } + ${body} + END_HANDLE_TH_ERRORS +} +""" +) + +GETTER_DEFINITION_RAW_VEC_SAVEDVAR = CodeTemplate( + """\ +PyObject* THP${op}_${name}_raw_getter(THPCppFunction *self, void *_unused) { + HANDLE_TH_ERRORS + const auto *node = static_cast<${op}*>(self->cdata.get()); + const auto& prop = node->${name}_; + if (node->${name}_released_) { + PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE); + return nullptr; + } + ${body} + END_HANDLE_TH_ERRORS +} +""" +) + +GETTER_DEFINITION_OPT = CodeTemplate( + """\ +PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { + HANDLE_TH_ERRORS + auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name}; + if (!opt_prop.has_value()) { + Py_RETURN_NONE; + } + auto prop = opt_prop.value(); + ${body} + END_HANDLE_TH_ERRORS +} +""" +) + +GETTER_DEFINITION_OPT_ARRAYREF = CodeTemplate( + """\ +PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { + HANDLE_TH_ERRORS + auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name}; + if (!opt_prop.list.has_value()) { + Py_RETURN_NONE; + } + auto prop = opt_prop.list.value(); + ${body} + END_HANDLE_TH_ERRORS +} +""" +) + +# Getter body +GETTER_BODY_SAVEDVAR = """\ +return THPVariable_Wrap(prop.unpack(self->cdata)); +""" + +GETTER_BODY_RAW_SAVEDVAR = """\ +pybind11::object obj = pybind11::cast(prop, pybind11::return_value_policy::reference); +return obj.release().ptr(); +""" + +GETTER_BODY_VEC_SAVEDVAR = """\ +PyObject* tup = PyTuple_New((Py_ssize_t) prop.size()); +for (auto i: c10::irange(prop.size())) { + PyTuple_SetItem(tup, (Py_ssize_t) i, THPVariable_Wrap(prop[i].unpack(self->cdata))); +} +return tup; +""" + +GETTER_BODY_RAW_VEC_SAVEDVAR = """\ +PyObject* tup = 
PyTuple_New((Py_ssize_t) prop.size()); +for (auto i : c10::irange(prop.size())) { + pybind11::object obj = pybind11::cast(prop[i], pybind11::return_value_policy::reference); + PyTuple_SetItem(tup, (Py_ssize_t) i, obj.release().ptr()); +} +return tup; +""" + +GETTER_BODY_ARRAYREF_LONG = """\ +PyObject* tup = PyTuple_New((Py_ssize_t) prop.size()); +for (auto i : c10::irange(prop.size())) { + PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong((uint64_t) prop[i])); +} +return tup; +""" + +GETTER_BODY_ARRAYREF_SYMINT = """\ +PyObject* tup = PyTuple_New((Py_ssize_t) prop.size()); +for (auto i : c10::irange(prop.size())) { + auto si = prop[i]; + if (auto m = si.maybe_as_int()) { + PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong(*m)); + } else { + auto py_symint = py::cast(si).release().ptr(); + PyTuple_SetItem(tup, (Py_ssize_t) i, py_symint); + } +} +return tup; +""" + +GETTER_BODY_ARRAYREF_DOUBLE = """\ +PyObject* tup = PyTuple_New((Py_ssize_t) prop.size()); +for (auto i : c10::irange(prop.size())) { + PyTuple_SetItem(tup, (Py_ssize_t) i, PyFloat_FromDouble((double) prop[i])); +} +return tup; +""" + +GETTER_BODY_INT64_T = """\ +return PyLong_FromUnsignedLong((int64_t) prop); +""" + +GETTER_BODY_SYMINT = """\ +if (auto m = prop.maybe_as_int()) { + return PyLong_FromUnsignedLong(*m); +} else { + return py::cast(prop).release().ptr(); +} +""" + +GETTER_BODY_DOUBLE = """\ +return PyFloat_FromDouble((double) prop); +""" + +GETTER_BODY_BOOL = """\ +if (prop) { + Py_RETURN_TRUE; +} else { + Py_RETURN_FALSE; +} +""" + +GETTER_BODY_STRING = """\ +return PyUnicode_FromStringAndSize(prop.data(), prop.size()); +""" + +GETTER_BODY_SCALAR = """\ +if (prop.isComplex()) { + auto cprop = prop.to>(); + return PyComplex_FromDoubles(cprop.real(), cprop.imag()); +} else if (prop.isFloatingPoint()) { + return PyFloat_FromDouble(prop.to()); +} else if (prop.isIntegral(/*includeBool=*/false)) { + return PyLong_FromLong(prop.to()); +} else if (prop.isBoolean()) { + if (prop.to()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} else { + PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type"); + return nullptr; +} +""" + + +GETTER_BODY_VEC_SCALAR = """\ +PyObject* tup = PyTuple_New((Py_ssize_t) prop.size()); +for (auto i: c10::irange(prop.size())) { + if (prop[i].isComplex()) { + auto cprop = prop[i].to>(); + PyTuple_SetItem(tup, (Py_ssize_t) i, PyComplex_FromDoubles(cprop.real(), cprop.imag())); + } else if (prop[i].isFloatingPoint()) { + auto double_prop = prop[i].to(); + PyTuple_SetItem(tup, (Py_ssize_t) i, PyFloat_FromDouble(double_prop)); + } else if (prop[i].isIntegral(/*includeBool=*/false)) { + auto long_prop = prop[i].to(); + PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromLong(long_prop)); + } else if (prop[i].isBoolean()) { + if (prop[i].to()) { + PyTuple_SetItem(tup, (Py_ssize_t) i, Py_True); + } else { + PyTuple_SetItem(tup, (Py_ssize_t) i, Py_False); + } + } else { + PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type"); + return nullptr; + } +} +return tup; +""" + + +MISC_GETTER_DEFS = { + OptionalCType(BaseCType(longT)): (GETTER_DEFINITION_OPT, GETTER_BODY_INT64_T), + OptionalCType(BaseCType(SymIntT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SYMINT), + BaseCType(doubleT): (GETTER_DEFINITION, GETTER_BODY_DOUBLE), + OptionalCType(BaseCType(doubleT)): (GETTER_DEFINITION_OPT, GETTER_BODY_DOUBLE), + BaseCType(boolT): (GETTER_DEFINITION, GETTER_BODY_BOOL), + BaseCType(scalarT): (GETTER_DEFINITION, GETTER_BODY_SCALAR), + OptionalCType(BaseCType(scalarT)): 
(GETTER_DEFINITION_OPT, GETTER_BODY_SCALAR), +} + +# These functions have backwards which cannot be traced, and so must have +# their backward functions traced opaquely. +# VIEW_FUNCTIONS are not traceable because they use as_strided, which +# has an untraceable backwards, see +# https://github.com/pytorch/pytorch/issues/4250 +# TODO: This is probably not exhaustive, but it's a start +UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS + + +def get_infos_with_derivatives_list( + differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]] +) -> list[DifferentiabilityInfo]: + diff_info_list = [ + info + for diffinfo_dict in differentiability_infos.values() + for info in diffinfo_dict.values() + ] + + return list(filter(lambda info: info.args_with_derivatives, diff_info_list)) + + +def gen_autograd_functions_lib( + out: str, + differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]], + template_path: str, +) -> None: + """Functions.h and Functions.cpp body + + These contain the auto-generated subclasses of torch::autograd::Node + for each every differentiable torch function. + """ + + # get a 1D list of diffinfos, we do not need them to be per FunctionSchema/DispatchKey here + # infos with the diff dispatchkeys but the same name will still be in the same shard. + infos = get_infos_with_derivatives_list(differentiability_infos) + declarations = [process_function(f, FUNCTION_DECLARATION) for f in infos] + definitions = [process_function(f, FUNCTION_DEFINITION) for f in infos] + + file_basename = "Functions" + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + for suffix in [".h", ".cpp"]: + fname = file_basename + suffix + fm.write_with_template( + fname, + fname, + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/" + + fname, + "autograd_function_declarations": declarations, + "autograd_function_definitions": definitions, + }, + ) + + +def gen_autograd_functions_python( + out: str, + differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]], + template_path: str, +) -> None: + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + num_shards = 5 + fm.write( + "python_functions.h", + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/python_functions.h", + "shard_forward_declare": [ + f"void initialize_autogenerated_functions_{i}(PyObject* module);" + for i in range(num_shards) + ], + "shard_call": [ + f"initialize_autogenerated_functions_{i}(module);" + for i in range(num_shards) + ], + }, + ) + + # get a 1D list of diffinfos, we do not need them to be per FunctionSchema/DispatchKey here + # infos with the diff dispatchkeys but the same name will still be in the same shard. 
+ infos = get_infos_with_derivatives_list(differentiability_infos) + fm.write_sharded( + "python_functions.cpp", + infos, + key_fn=lambda info: info.name, + base_env={ + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/python_functions.cpp", + }, + env_callable=lambda info: { + "py_function_initializers": [ + process_function(info, PY_FUNCTION_DEFINITION) + ], + "py_function_props_and_getters": [ + process_function(info, PY_FUNCTION_PROPS_AND_GETTERS) + ], + }, + num_shards=num_shards, + sharded_keys={"py_function_initializers", "py_function_props_and_getters"}, + ) + + +def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str: + saved_variables: list[str] = [] + release_variables: list[str] = [] + saved_list_sizes: list[str] = [] + unpack: list[str] = [] + asserts: list[str] = [] + compute_index_ranges: list[str] = [] + getter_definitions: list[str] = [] + py_getsetdef_structs: list[str] = [] + compiled_args: list[str] = [] + apply_with_saved_before: list[str] = [] + apply_with_saved_after: list[str] = [] + + for arg in info.args_with_derivatives: + if arg.type in TENSOR_LIST_LIKE_CTYPES: + size = f"{arg.name}_size_" + saved_list_sizes.append(f"size_t {arg.name}_size_;") + else: + size = "1" + compute_index_ranges.append(f"auto {arg.name}_ix = gen.range({size});") + + def save_var(var: SavedAttribute, is_output: bool) -> None: + name = var.nctype.name + type = var.nctype.type + should_append_getsetdef = True + should_append_raw_getsetdef = False + visit_name = name + uses_cpp_saved_variable_cls = False + + if ( + type == BaseCType(tensorT) + or type == OptionalCType(BaseCType(tensorT)) + or type == MutRefCType(OptionalCType(BaseCType(tensorT))) + or (type == BaseCType(scalarT) and is_output) + ): + uses_cpp_saved_variable_cls = True + saved_variables.append(f"SavedVariable {name}_;") + release_variables.append(f"{name}_.reset_data();") + ptr = "shared_from_this()" if is_output else "" + unpack.append(f"auto {name} = {name}_.unpack({ptr});") + getter_definitions.append( + GETTER_DEFINITION_SAVEDVAR.substitute( + op=info.op, name=name, body=GETTER_BODY_SAVEDVAR + ) + ) + getter_definitions.append( + GETTER_DEFINITION_RAW_SAVEDVAR.substitute( + op=info.op, name=name, body=GETTER_BODY_RAW_SAVEDVAR + ) + ) + should_append_raw_getsetdef = True + visit_name = f"{name}_" + elif ( + type == BaseCType(tensorListT) + or type == BaseCType(iTensorListRefT) + or type == VectorCType(BaseCType(tensorT)) + ): + # note(crcrpar): [nuanced return type of out-of-place foreach functions] + # When an out-of-place foreach function whose return signature is `Tensor[]` + # spells out its backward definitions in `derivatives.yaml`, and some of them depend on + # `result`, `result`'s type is interpreted and treated as `std::vector`. + # An out-of-place foreach whose backwards rely on their output doesn't suffer from this + # difference if the definitions are codegen'ed. + # This special case is needed for `_foreach_pow.List` and `_foreach_pow.ScalarAndTensor` + # as of https://github.com/pytorch/pytorch/pull/105504. + if type == VectorCType(BaseCType(tensorT)): + assert ( + info.func.func.name.name.base.startswith("_foreach") and is_output + ) + uses_cpp_saved_variable_cls = True + saved_variables.append(f"std::vector {name}_;") + saved_variables.append(f"bool {name}_released_ = false;") + # Just clear() is sufficient, we don't need to loop and clear each variable. 
+ # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well. + release_variables.append(f"{name}_.clear();") + release_variables.append(f"{name}_released_ = true;") + ptr = "shared_from_this()" if is_output else "nullptr" + unpack.append(f"auto {name} = unpack_list({name}_, {ptr});") + asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);") + getter_definitions.append( + GETTER_DEFINITION_VEC_SAVEDVAR.substitute( + op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR + ) + ) + getter_definitions.append( + GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute( + op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR + ) + ) + should_append_raw_getsetdef = True + visit_name = f"{name}_" + elif type == ListCType(OptionalCType(BaseCType(tensorT))): + uses_cpp_saved_variable_cls = True + saved_variables.append(f"std::vector {name}_;") + saved_variables.append(f"bool {name}_released_ = false;") + # Just clear() is sufficient, we don't need to loop and clear each variable. + # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well. + release_variables.append(f"{name}_.clear();") + release_variables.append(f"{name}_released_ = true;") + unpack.append(f"auto {name} = unpack_opt_list({name}_);") + asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);") + getter_definitions.append( + GETTER_DEFINITION_VEC_SAVEDVAR.substitute( + op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR + ) + ) + getter_definitions.append( + GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute( + op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR + ) + ) + should_append_raw_getsetdef = True + visit_name = f"{name}_" + elif type == BaseCType(intArrayRefT): + saved_variables.append(f"std::vector {name};") + getter_definitions.append( + GETTER_DEFINITION.substitute( + op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG + ) + ) + elif type == BaseCType(symIntArrayRefT): + saved_variables.append(f"std::vector {name};") + getter_definitions.append( + GETTER_DEFINITION.substitute( + op=info.op, name=name, body=GETTER_BODY_ARRAYREF_SYMINT + ) + ) + elif type == BaseCType(optionalIntArrayRefT): + saved_variables.append(f"c10::OptionalArray {name};") + getter_definitions.append( + GETTER_DEFINITION_OPT_ARRAYREF.substitute( + op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG + ) + ) + elif type == BaseCType(optionalSymIntArrayRefT): + saved_variables.append(f"c10::OptionalArray {name};") + getter_definitions.append( + GETTER_DEFINITION_OPT_ARRAYREF.substitute( + op=info.op, name=name, body=GETTER_BODY_ARRAYREF_SYMINT + ) + ) + elif type == OptionalCType(BaseCType(intArrayRefT)): + saved_variables.append(f"c10::OptionalArray {name};") + getter_definitions.append( + GETTER_DEFINITION_OPT_ARRAYREF.substitute( + op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG + ) + ) + elif type == OptionalCType(BaseCType(symIntArrayRefT)): + saved_variables.append(f"c10::OptionalArray {name};") + getter_definitions.append( + GETTER_DEFINITION_OPT_ARRAYREF.substitute( + op=info.op, name=name, body=GETTER_BODY_ARRAYREF_SYMINT + ) + ) + elif type == OptionalCType(ArrayRefCType(BaseCType(doubleT))): + saved_variables.append(f"c10::OptionalArray {name};") + getter_definitions.append( + GETTER_DEFINITION_OPT_ARRAYREF.substitute( + op=info.op, name=name, body=GETTER_BODY_ARRAYREF_DOUBLE + ) + ) + elif type == BaseCType(longT): + saved_variables.append(f"{type.cpp_type()} {name} = 0;") + getter_definitions.append( + 
GETTER_DEFINITION.substitute( + op=info.op, name=name, body=GETTER_BODY_INT64_T + ) + ) + elif type == BaseCType(SymIntT): + saved_variables.append(f"c10::SymInt {name};") + getter_definitions.append( + GETTER_DEFINITION.substitute( + op=info.op, name=name, body=GETTER_BODY_SYMINT + ) + ) + elif type == BaseCType(stringT): + saved_variables.append(f"std::string {name};") + getter_definitions.append( + GETTER_DEFINITION.substitute( + op=info.op, name=name, body=GETTER_BODY_STRING + ) + ) + elif type == OptionalCType(BaseCType(stringT)): + saved_variables.append(f"std::optional {name};") + getter_definitions.append( + GETTER_DEFINITION_OPT.substitute( + op=info.op, name=name, body=GETTER_BODY_STRING + ) + ) + elif type == ArrayRefCType( + elem=BaseCType(type=BaseCppType(ns="at", name="Scalar")) + ): + saved_variables.append(f"std::vector {name};") + saved_variables.append(f"bool {name}_released_ = false;") + # Just clear() is sufficient, we don't need to loop and clear each variable. + # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well. + release_variables.append(f"{name}.clear();") + # release_variables.append(f"{name}_released_ = true;") + # unpack.append(f"auto {name} = unpack_list({name}_);") + # asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);") + getter_definitions.append( + CodeTemplate( + """\ +PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { + HANDLE_TH_ERRORS + const auto *node = static_cast<${op}*>(self->cdata.get()); + const auto& prop = node->${name}; + if (node->${name}_released_) { + PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE); + return nullptr; + } + ${body} + END_HANDLE_TH_ERRORS +} + """ + ).substitute( + op=info.op, + name=name, + body=GETTER_BODY_VEC_SCALAR, + ) + ) + else: + # Check for indicators that you're putting a non-owning reference + # into the saved variable field. If this is spuriously firing, + # edit this field. Otherwise, you probably need to add a case + # above. 
+ assert ( + "ref" not in type.cpp_type().lower() + and "view" not in type.cpp_type().lower() + and "*" not in type.cpp_type() + and "&" not in type.cpp_type() + ), f"{type.cpp_type()} looks like it contains a non-owning reference" + saved_variables.append(f"{type.cpp_type()} {name};") + + if type in MISC_GETTER_DEFS: + getter_def, body = MISC_GETTER_DEFS[type] + getter_definitions.append( + getter_def.substitute(op=info.op, name=name, body=body) + ) + else: + # Types we don't expose python bindings to yet: + # TypeAndSize, at::ScalarType, TensorOptions, TensorGeometry, + # std::vector>, std::vector + should_append_getsetdef = False + + if should_append_getsetdef: + py_getsetdef_structs.append( + PY_GETSETDEF_STRUCT.substitute(op=info.op, name=name) + ) + if should_append_raw_getsetdef: + py_getsetdef_structs.append( + PY_RAW_GETSETDEF_STRUCT.substitute(op=info.op, name=name) + ) + + if uses_cpp_saved_variable_cls: + compiled_args.append( + f"args.collect({visit_name}, {'true' if is_output else 'false'});" + ) + else: + compiled_args.append(f"args.collect({visit_name});") + apply_with_saved_before.append(f"saved.before({visit_name});") + apply_with_saved_after.append(f"saved.after({visit_name});") + + for var in sorted(info.all_saved_inputs, key=lambda sa: str(sa.nctype.name)): + save_var(var, is_output=False) + for var in sorted(info.all_saved_outputs, key=lambda sa: str(sa.nctype.name)): + save_var(var, is_output=True) + + # lock the mutex when we release variables and in Node::apply to protect thread safety + # see Note [Thread Safety on Autograd Node] + if len(release_variables) > 0: + thread_lock = "std::lock_guard lock(mutex_);" + else: + thread_lock = "" + + if uses_retain_variables(info): + will_release_variables = WILL_RELEASE_VARIABLES.substitute() + else: + will_release_variables = "" + + body: list[str] = [] + + if uses_single_grad(info): + body.append("const auto& grad = grads[0];") + else: + # Generate aliases for gradients named for returned values. + body.extend( + f"const auto& {name} = grads[{info.available_named_gradients.index(name)}];" + for name in sorted(info.used_named_gradients) + ) + + def emit_derivative( + derivative: Derivative, + args_with_derivatives: Sequence[Binding], + ) -> tuple[bool, str]: + formula = derivative.formula + var_names = derivative.var_names + if len(var_names) == 1: + checks_any_grad_defined = False + if "not_implemented" not in formula: + matching_args = [ + arg for arg in args_with_derivatives if arg.name == var_names[0] + ] + if len(matching_args) == 1: + # We can add undefined grad support if the input variable is a Tensor + arg = matching_args[0] + if isinstance(arg.argument, Argument) and str( + arg.argument.type + ) in ("Tensor", "Tensor?"): + formula = "any_grad_defined ? 
(" + formula + ") : Tensor()" + checks_any_grad_defined = True + if info.name.startswith("_foreach_"): + derivative_template = DERIVATIVE_SINGLE_FOREACH + else: + derivative_template = DERIVATIVE_SINGLE + return ( + checks_any_grad_defined, + derivative_template.substitute(name=var_names[0], derivative=formula), + ) + else: + if "grad_input_mask" in formula: + masks = [ + f"task_should_compute_output({{ {n}_ix }})," for n in var_names + ] + grad_input_mask = GRAD_INPUT_MASK.substitute( + masks=masks, n=len(var_names) + ) + else: + grad_input_mask = "" + idx_ranges = ", ".join(f"{n}_ix" for n in var_names) + copy_ranges: list[str] = [] + for i, n in enumerate(var_names): + copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i)) + return False, DERIVATIVE_MULTI.substitute( + idx_ranges=idx_ranges, + copy_ranges=copy_ranges, + derivative=formula, + grad_input_mask=grad_input_mask, + ) + + body.extend(unpack) + need_any_grad_defined_var = False + for derivative in info.derivatives: + checks_any_grad_defined, derivative_text = emit_derivative( + derivative, info.args_with_derivatives + ) + body.append(derivative_text) + need_any_grad_defined_var |= checks_any_grad_defined + # Since single-output derivative formulas need to check if grads are + # defined, only perform the check once, before all the formulas + if need_any_grad_defined_var: + body.insert( + -len(info.derivatives), + "bool any_grad_defined = any_variable_defined(grads);", + ) + + if info.name in UNTRACEABLE_FUNCTIONS: + superclass = "Node" + else: + superclass = "TraceableFunction" + + all_getsetdef_structs = ( + ",\n".join(py_getsetdef_structs) + "," if len(py_getsetdef_structs) != 0 else "" + ) + all_getter_definitions = "\n".join(getter_definitions) + + return template.substitute( + op=info.op, + compute_index_ranges=compute_index_ranges, + saved_variables=saved_variables, + release_variables=release_variables, + saved_list_sizes=saved_list_sizes, + asserts=asserts, + thread_lock=thread_lock, + will_release_variables=will_release_variables, + body=body, + superclass=superclass, + all_getter_definitions=all_getter_definitions, + all_getsetdef_structs=all_getsetdef_structs, + compiled_args=compiled_args, + apply_with_saved_before=apply_with_saved_before, + apply_with_saved_after=apply_with_saved_after, + ) diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_inplace_or_view_type.py b/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_inplace_or_view_type.py new file mode 100644 index 0000000000000000000000000000000000000000..e8141658b0335cb7272a4bb885b49fdb934d1bbd --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_inplace_or_view_type.py @@ -0,0 +1,675 @@ +# Generates ADInplaceOrViewType.h/cpp +# +# NOTE: If any changes are being made to the ADInplaceOrView codegen please also check +# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp +# The fallback is expected to mimick this codegen, so we should keep the two in sync. 
+ +from __future__ import annotations + +from torchgen.api import cpp +from torchgen.api.autograd import ( + dispatch_strategy, + gen_differentiable_outputs, + NativeFunctionWithDifferentiabilityInfo, +) +from torchgen.api.types import ( + BaseCType, + Binding, + boolT, + ConstRefCType, + CType, + DispatcherSignature, + intArrayRefT, + longT, + OptionalCType, + symIntArrayRefT, + SymIntT, + tensorT, +) +from torchgen.code_template import CodeTemplate +from torchgen.context import with_native_function +from torchgen.model import ( + NativeFunction, + SchemaKind, + SelfArgument, + TensorOptionsArguments, + Type, +) +from torchgen.utils import FileManager + +from .context import with_native_function_with_differentiability_info +from .gen_trace_type import ( + get_return_value, + MANUAL_AUTOGRAD, + tie_return_values, + type_wrapper_name, +) + + +# See NOTE [ Autograd View Variables ] in variable.h for details. +# If you update list VIEW_FUNCTIONS or RETURNS_VIEWS_OF_INPUT, +# you **MUST** also update the public list of view ops accordingly in +# docs/source/tensor_view.rst. Note not all ATen functions are exposed to public, +# e.g alias & sparse_coo_tensor_with_dims_and_tensors. +# +# A map: function name => name of the argument that all outputs are view of + +VIEW_FUNCTIONS_WITH_METADATA_CHANGE = [ + "view_as_complex", + "view_as_real", + "_conj", + "_neg_view", + "_nested_get_values", + "_nested_view_from_buffer", + "_nested_view_from_jagged", +] + +VIEW_FUNCTIONS = { + "numpy_T": "self", + "alias": "self", + "as_strided": "self", + "diagonal": "self", + "expand": "self", + "permute": "self", + "select": "self", + "slice": "self", + "slice_inverse": "self", + "split": "self", + "split_with_sizes": "self", + "squeeze": "self", + "t": "self", + "transpose": "self", + "unfold": "self", + "unsqueeze": "self", + "flatten": "self", + "view": "self", + "unbind": "self", + "_indices": "self", + "_values": "self", + "indices": "self", + "values": "self", + "crow_indices": "self", + "col_indices": "self", + "ccol_indices": "self", + "row_indices": "self", + # sparse_coo ctor output should really be views of both indices and values, + # but we only supports making as view of a single variable, and indices is + # discrete anyways. + # FIXME: clone indices on construction. + "sparse_coo_tensor_with_dims_and_tensors": "values", + "_reshape_alias": "self", + "_test_autograd_multiple_dispatch_view": "self", +} + +for key in VIEW_FUNCTIONS_WITH_METADATA_CHANGE: + VIEW_FUNCTIONS[key] = "self" + +# note: some VIEW_FUNCTIONS are just compositions of the view functions above +# this list contains both the root view functions and any that are purely composed +# of viewing functions, and is used by the JIT to determine when an operator +# may return a view of its inputs; however they may sometimes return a copy. +# (e.g. `contiguous`) +RETURNS_VIEWS_OF_INPUT = set(VIEW_FUNCTIONS.keys()).union( + { + "chunk", + "detach", + "contiguous", + "reshape", + "reshape_as", + "expand_as", + "view_as", + "real", + "imag", + "narrow", + "movedim", + "tensor_split", + "swapdims", + "swapaxes", + "mT", + "mH", + "adjoint", + "matrix_H", + } +) + +# These are the functions we consider views for the purposes of validating +# StorageImpl and TensorImpl in gen_variable_type. +# `_unsafe_view` is not included in VIEW_FUNCTIONS above because it is not a +# view for the purposes of ADInplaceOrView kernel, we do not want to call as_view +# See NOTE [Unsafe View] for more info. 
+ALL_VIEW_FUNCTIONS = { + **VIEW_FUNCTIONS, + "_unsafe_view": "self", +} + +ARRAYREF_TO_VEC = CodeTemplate( + """\ +auto ${vec} = ${arg}.vec(); +""" +) + +OPTIONAL_TO_VAL = CodeTemplate( + """\ +auto ${val} = ${arg}.value_or(${default}); +""" +) + +CALL_DISPATCH = CodeTemplate( + """\ +at::_ops::${unambiguous_name}::call(${unpacked_args})""" +) + +REVERSE_VIEW_DISPATCH = CodeTemplate( + """\ +${reverse_name}(${unpacked_args})""" +) + +MULTI_OUTPUT_VIEW_ITERATION = CodeTemplate( + """\ +for (auto ${view_idx} : c10::irange(${var}.size())) { + ${body} +} +""" +) + +SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE = CodeTemplate( + """\ +std::unique_ptr func(nullptr); +std::function rev_func=nullptr; +if (${is_view_with_metadata_change} || + !self.unsafeGetTensorImpl()->support_as_strided() || + self.unsafeGetTensorImpl()->is_python_dispatch() || + c10::AutogradState::get_tls_state().get_view_replay_enabled()) { + ${replay_view_func} + ${reverse_replay_view_func} +} +""" +) + +REPLAY_VIEW_FUNC = CodeTemplate( + """\ +func = std::make_unique<${view_func_name}>(${view_func_args}); +""" +) + +REVERSE_REPLAY_VIEW_LAMBDA_FUNC = CodeTemplate( + """\ +rev_func = [=](const at::Tensor& ${input_view}) { + return ${reverse_replay_view_call}; +}; +""" +) + +METHOD_DEFINITION = CodeTemplate( + """\ +${return_type} ${type_wrapper_name}(${formals}) { + ${type_definition_body} +} +""" +) + +WRAPPER_REGISTRATION = CodeTemplate( + """\ +m.impl("${unqual_operator_name_with_overload}", + TORCH_FN(${class_type}::${type_wrapper_name}) +); +""" +) + +AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION = CodeTemplate( + """\ +m.impl("${unqual_operator_name_with_overload}", torch::autograd::autogradNotImplementedFallback()); +""" +) + +INPLACE_REDISPATCH = CodeTemplate( + """\ +{ + at::AutoDispatchBelowADInplaceOrView guard; + at::_ops::${unambiguous_name}::redispatch(${unpacked_args}); +} +""" +) + +ASSIGN_RETURN_VALUE = CodeTemplate( + """\ +${return_values} = ${rhs_value}; +""" +) + +VIEW_REDISPATCH = CodeTemplate( + """\ +${assign_return_values} ([&]() { + at::AutoDispatchBelowADInplaceOrView guard; + return at::_ops::${unambiguous_name}::redispatch(${unpacked_args}); +})(); +""" +) + +TMP_VAR = "_tmp" + + +# FIXME: Ideally these functions should be methods on Type class, but we have a +# comment in codegen/model.py there saying these concepts are not well defined. +# Thus we put a version that commonly used by autograd codegen here. +def is_tensor_type(t: Type) -> bool: + # TODO: Should handle optional here? + return t.is_tensor_like() and t.is_list_like() is None + + +def is_tensor_list_type(t: Type) -> bool: + # TODO: Should handle optional here? + return t.is_tensor_like() and t.is_list_like() is not None + + +UNPACK_TENSOR = CodeTemplate( + """\ +auto${ref} ${arg_name}_ = unpack${suffix}(${arg_name}, "${arg_name}", ${arg_pos});""" +) + + +def unpacked_name(arg_name: str) -> str: + return arg_name + "_" + + +# e.g. 
select.int -> select_copy_int_inverse() +def inverse_view_name(f: NativeFunction) -> str: + copy_variant = f"{f.root_name}_copy" + overload = f"{f.func.name.overload_name}" + if overload != "": + overload = "_" + overload + return f"{copy_variant}{overload}_inverse" + + +def extract_bindings(f: NativeFunction) -> list[Binding]: + return [ + r + for a in f.func.schema_order_arguments() + for r in cpp.argument( + a, + method=False, + symint=True, + cpp_no_default_args=set(), + faithful=False, + has_tensor_options=False, + ) + ] + + +@with_native_function +def unpack_args(f: NativeFunction) -> tuple[list[str], list[Binding]]: + body: list[str] = [] + unpacked_bindings: list[Binding] = [] + + for i, binding in enumerate(extract_bindings(f)): + assert not isinstance(binding.argument, SelfArgument) + if isinstance(binding.argument, TensorOptionsArguments): + raise RuntimeError("VariableKernel shouldn't take TensorOptions") + + is_nullable = binding.argument.type.is_nullable() + if not binding.argument.type.is_tensor_like() or is_nullable: + unpacked_bindings.append(binding) + continue + + is_tensor_list = is_tensor_list_type(binding.argument.type) + ref = (not is_nullable) and not is_tensor_list + suffix = "_opt" if is_nullable and not is_tensor_list else "" + body.append( + UNPACK_TENSOR.substitute( + arg_name=binding.name, + arg_pos=i, + suffix=suffix, + ref="&" if ref else "", + ) + ) + unpacked_bindings.append( + Binding( + name=unpacked_name(binding.name), + nctype=binding.nctype, + argument=binding.argument, + default=binding.default, + ) + ) + + return body, unpacked_bindings + + +def get_base_name(f: NativeFunction) -> str: + return f.func.name.name.base # TODO: should be str(f.func.name.name)? + + +def get_view_info(f: NativeFunction) -> str | None: + base_name = get_base_name(f) + view_info = VIEW_FUNCTIONS.get(base_name, None) + if view_info is None and base_name in RETURNS_VIEWS_OF_INPUT: + view_info = "self" + return view_info + + +def emit_view_func( + f: NativeFunction, bindings: list[Binding], view_idx: str | None = None +) -> str: + """Generate an additional lambda function to recover views in backward when as_strided is not supported. + See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details. + """ + # TODO: Clean this logic up if we get rid of reverse view funcs or reify them. + input_base = "input_base" + replay_view_func = "" + updated_args: list[str] = [] + known_view_arg_simple_types: list[CType] = [ + BaseCType(longT), + OptionalCType(BaseCType(longT)), + BaseCType(SymIntT), + OptionalCType(BaseCType(SymIntT)), + BaseCType(boolT), + BaseCType(intArrayRefT), + BaseCType(symIntArrayRefT), + ConstRefCType(BaseCType(tensorT)), + ConstRefCType(OptionalCType(BaseCType(tensorT))), + ] + for binding in bindings: + arg, arg_type = binding.name, binding.nctype.type + if arg == "self": + updated_args.append(input_base) + continue + if arg_type not in known_view_arg_simple_types: + known_types_str = ", ".join([str(t) for t in known_view_arg_simple_types]) + raise TypeError( + f"You are adding an {arg_type} {arg} argument to op {cpp.name(f.func)} in addition to known types: " + f"{known_types_str}. Please update the list or materialize it so that it can be closed " + "over by value, also add a test in pytorch/xla/test/test_operations.py where this code " + "is exercised." 
+ ) + if arg_type == BaseCType(intArrayRefT) or arg_type == BaseCType( + symIntArrayRefT + ): + # It's not safe to close over IntArrayRef by value, since this is a + # reference type, so materialize a vector to close over by value + arg_vec = arg + "_vec" + replay_view_func += ARRAYREF_TO_VEC.substitute(arg=arg, vec=arg_vec) + updated_args.append(arg_vec) + elif arg_type == OptionalCType(BaseCType(longT)): + # Materialize int64_t? to int64_t + arg_value = arg + "_val" + replay_view_func += OPTIONAL_TO_VAL.substitute( + arg=arg, val=arg_value, default="0" + ) + updated_args.append(arg_value) + elif arg_type == ConstRefCType(BaseCType(tensorT)) or arg_type == ConstRefCType( + OptionalCType(BaseCType(tensorT)) + ): + # NB: Closing over a tensor. If a user modifies this tensor, this will be silently + # incorrect. The proper thing to do is to store the version counter and copy on write. + updated_args.append(arg) + else: + updated_args.append(arg) + + from .gen_view_funcs import view_func_name + + view_func_args = [b.name for b in bindings if b.name != "self"] + if view_idx is not None: + view_func_args.append(f"{view_idx}") + replay_view_func += REPLAY_VIEW_FUNC.substitute( + view_func_name=view_func_name(f, include_namespace=True), + view_func_args=view_func_args, + ) + + input_view = "input_view" + reverse_unpacked_args = [ + "self", + f"{input_view}", + # inverse_return_mode= + "at::functionalization::InverseReturnMode::AlwaysView", + *(() if view_idx is None else (f"{view_idx}",)), + # skip input_base arg + *updated_args[1:], + ] + + from torchgen.api.functionalization import reverse_name + + reverse_replay_view_call = REVERSE_VIEW_DISPATCH.substitute( + reverse_name=reverse_name(f, include_namespace=True), + unpacked_args=reverse_unpacked_args, + ) + reverse_replay_view_func = REVERSE_REPLAY_VIEW_LAMBDA_FUNC.substitute( + input_view=input_view, reverse_replay_view_call=reverse_replay_view_call + ) + + is_view_with_metadata_change = ( + "true" if cpp.name(f.func) in VIEW_FUNCTIONS_WITH_METADATA_CHANGE else "false" + ) + + return SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE.substitute( + is_view_with_metadata_change=is_view_with_metadata_change, + replay_view_func=replay_view_func, + reverse_replay_view_func=reverse_replay_view_func, + ) + + +def emit_view_body( + fn: NativeFunctionWithDifferentiabilityInfo, var: str +) -> tuple[str, str]: + # See NOTE [ Autograd View Variables ] in variable.h for details. 
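    # Illustrative note (added annotation, not upstream): the creation_meta
    # chosen below is what later produces the familiar eager-mode errors.
    # E.g. (hedged, behaviour observed on recent releases) mutating one output
    # of a multi-output view such as `unbind`/`split` of a tensor that
    # requires grad raises immediately:
    #   a = torch.randn(3, requires_grad=True)
    #   b, _, _ = a.unbind()
    #   b.add_(1)   # RuntimeError: view is the output of a function that returns multiple views
    # which corresponds to the CreationMeta::MULTI_OUTPUT_NODE branch emitted
    # by this function.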
+ f = fn.func + base_name = get_base_name(f) + view_info = get_view_info(f) + call = "" + differentiable_outputs = gen_differentiable_outputs(fn) + differentiable_output_vars = {r.name for r in differentiable_outputs} + if not isinstance(view_info, str): + raise TypeError( + f"The view info should be a string for {base_name}, but it is: {view_info}" + ) + if len(differentiable_output_vars) == 0: + # no output is differentiable (.indices() for SparseTensors for example) + rhs_value = ( + f"as_view({view_info}, {var}, " + f"/* is_bw_differentiable */ false, /* is_fw_differentiable */ false)" + ) + elif len(differentiable_output_vars) == 1: + # Single differentiable output (Tensor or Tensor[]) + return_info = differentiable_outputs[0] + # We only support simple Tensor or a TensorList for functions that return views + if not is_tensor_type(return_info.type) and not is_tensor_list_type( + return_info.type + ): + raise RuntimeError( + f"{base_name} that return differentiable views can only return Tensor or Tensor[]" + ) + + # See Note [ View + Inplace detection] + def get_creation_meta_in_mode(original: str) -> str: + creation_meta_with_grad_mode = f"(at::GradMode::is_enabled() ? {original} : CreationMeta::NO_GRAD_MODE)" + return f"InferenceMode::is_enabled() ? CreationMeta::INFERENCE_MODE : {creation_meta_with_grad_mode}" + + # Only allow rebasing of the history if we return a single Tensor + # If we are in a no grad block, raise a warning + # See NOTE [ View + Inplace detection ] for more details about this logic + if is_tensor_list_type(return_info.type): + creation_meta = get_creation_meta_in_mode("CreationMeta::MULTI_OUTPUT_NODE") + view_idx = "view_idx" + view_func = emit_view_func( + f, extract_bindings(f), view_idx=view_idx + ).strip() + as_view_call = ( + f"as_view(/* base */ {view_info}, /* output */ {var}[{view_idx}], " + "/* is_bw_differentiable */ true, /* is_fw_differentiable */ true, " + "/* view_func */ std::move(func), /* rev_view_func */ rev_func, " + f"/* creation_meta */ {creation_meta});" + ) + call += MULTI_OUTPUT_VIEW_ITERATION.substitute( + var=var, view_idx=view_idx, body=f"{view_func}\n{as_view_call}" + ) + rhs_value = f"std::move({var})" + else: + call += emit_view_func(f, extract_bindings(f), view_idx=None) + creation_meta = get_creation_meta_in_mode("CreationMeta::DEFAULT") + rhs_value = ( + f"as_view(/* base */ {view_info}, /* output */ {var}, /* is_bw_differentiable */ true, " + "/* is_fw_differentiable */ true, " + f"/* view_func */ std::move(func), /* rev_view_func */ rev_func, /* creation_meta */ {creation_meta})" + ) + else: + # This could be supported but we don't need it at the moment, so keeping things simple. + raise RuntimeError( + "Function that return multiple differentiable output " + "when at least one of them is view is not supported." + ) + return call, rhs_value + + +def modifies_arguments(f: NativeFunction) -> bool: + return f.func.kind() in [SchemaKind.inplace, SchemaKind.out] + + +@with_native_function_with_differentiability_info +def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> list[str]: + f = fn.func + inplace_view_body: list[str] = [] + + dispatcher_sig = DispatcherSignature.from_schema(f.func) + dispatcher_exprs = dispatcher_sig.exprs() + + # code-generated ADInplaceOrView kernels plumb and recompute dispatch keys directly through the kernel for performance. + # See Note [Plumbing Keys Through The Dispatcher] for details. 
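    # Illustrative note (added annotation, not upstream): for an in-place op
    # such as add_.Tensor, the INPLACE_REDISPATCH template above is expected
    # to expand to roughly
    #   {
    #     at::AutoDispatchBelowADInplaceOrView guard;
    #     at::_ops::add__Tensor::redispatch(ks & c10::after_ADInplaceOrView_keyset, self, other, alpha);
    #   }
    #   increment_version(self);
    # (the exact argument list depends on the schema; shown only as a sketch).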
+ dispatch_key_set = "ks & c10::after_ADInplaceOrView_keyset" + redispatch_args = ", ".join([dispatch_key_set] + [a.expr for a in dispatcher_exprs]) + + # Note that this calls the slow, dispatching variants of manual_cpp_binding ops. + # We could probably work harder to ensure that the fast variants are called instead, but the perf benefit would be minimal. + if modifies_arguments(f): # inplace op + inplace_view_body.append( + INPLACE_REDISPATCH.substitute( + unambiguous_name=f.func.name.unambiguous_name(), + unpacked_args=redispatch_args, + ) + ) + for r in cpp.return_names(f): + inplace_view_body.append(f"increment_version({r});") + else: + assert get_view_info(f) is not None + inplace_view_body.append( + VIEW_REDISPATCH.substitute( + assign_return_values="auto " + TMP_VAR + " = ", + unambiguous_name=f.func.name.unambiguous_name(), + unpacked_args=redispatch_args, + ) + ) + call, rhs_value = emit_view_body(fn, TMP_VAR) + inplace_view_body.append(call) + assert rhs_value is not None + inplace_view_body.append( + ASSIGN_RETURN_VALUE.substitute( + return_values=tie_return_values(f), rhs_value=rhs_value + ) + ) + if f.func.returns: + inplace_view_body.append(f"return {get_return_value(f)};") + return inplace_view_body + + +@with_native_function +def gen_formals(f: NativeFunction) -> str: + return ", ".join( + # code-generated autograd kernels plumb and recompute dispatch keys directly through the kernel for performance. + # See Note [Plumbing Keys Through The Dispatcher] for details. + ["c10::DispatchKeySet ks"] + + [ + f'{cpp.argument_type(a, binds="__placeholder__", symint=True).cpp_type()} {a.name}' + for a in f.func.schema_order_arguments() + ] + ) + + +@with_native_function_with_differentiability_info +def inplace_or_view_method_definition( + fn: NativeFunctionWithDifferentiabilityInfo, +) -> str | None: + f = fn.func + if get_view_info(f) is None and ( + # For functions that modify their inputs but don't return them, + # we can't give them autograd support. 
+ # See https://github.com/pytorch/pytorch/issues/53796 + not modifies_arguments(f) + or len(f.func.returns) == 0 + ): + return None + return METHOD_DEFINITION.substitute( + return_type=cpp.returns_type(f.func.returns, symint=True).cpp_type(), + type_wrapper_name=type_wrapper_name(f), + formals=gen_formals(f), + type_definition_body=emit_inplace_or_view_body(fn), + ) + + +@with_native_function_with_differentiability_info +def inplace_or_view_method_registration( + fn: NativeFunctionWithDifferentiabilityInfo, +) -> str | None: + f = fn.func + if get_view_info(f) is None and ( + not modifies_arguments(f) or len(f.func.returns) == 0 + ): + return None + return WRAPPER_REGISTRATION.substitute( + unqual_operator_name_with_overload=f.func.name, + type_wrapper_name=type_wrapper_name(f), + class_type="ADInplaceOrView", + ) + + +def use_derived(fn: NativeFunctionWithDifferentiabilityInfo) -> bool: + f = fn.func + name = cpp.name(f.func) + return name not in MANUAL_AUTOGRAD and dispatch_strategy(fn) == "use_derived" + + +def gen_inplace_or_view_type_env( + fn: NativeFunctionWithDifferentiabilityInfo, +) -> dict[str, list[str]]: + definition = inplace_or_view_method_definition(fn) + registration = inplace_or_view_method_registration(fn) + + return { + "ops_headers": ( + [f"#include "] + if definition is not None + else [] + ), + "inplace_or_view_method_definitions": [definition] + if definition is not None + else [], + "inplace_or_view_wrapper_registrations": [registration] + if registration is not None + else [], + } + + +def gen_inplace_or_view_type( + out: str, + native_yaml_path: str, + tags_yaml_path: str, + fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo], + template_path: str, +) -> None: + # NOTE: see Note [Sharded File] at the top of the VariableType.cpp + # template regarding sharding of the generated files. + num_shards = 2 + + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + fm.write_sharded( + "ADInplaceOrViewType.cpp", + [fn for fn in fns_with_infos if use_derived(fn)], + key_fn=lambda fn: fn.func.root_name, + base_env={ + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/ADInplaceOrViewType.cpp", + }, + env_callable=gen_inplace_or_view_type_env, + num_shards=2, + sharded_keys={ + "ops_headers", + "inplace_or_view_method_definitions", + "inplace_or_view_wrapper_registrations", + }, + ) diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_python_functions.py b/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_python_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..44453306a0ecbf65452c0287a8c903b9d11f0600 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_python_functions.py @@ -0,0 +1,1402 @@ +# Generates Python bindings for ATen functions +# +# The bindings are generated as methods on python_variable or functions on the +# torch._C._nn. torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._sparse +# or torch._C._special objects. +# + +# Code tries to stick to the following rules: +# +# - templates should be colocated with the functions that use them. +# no templates are currently shared between functions, but if that +# happens, maybe put the template with the first one +# +# - don't use environment dictionaries when calling template.substitute(). 
+# pass named arguments directly for everything, otherwise it's much too +# hard to track what's actually being used and by who +# +# - colocate any new hacks/adjustments with existing ones of the same kind. +# ideally in a data structure rather than code if possible. See e.g. +# SCHEMA_DEFAULT_CONVERSION_HACKS, etc. +# +# - similarly, conversions from one format to another should ideally happen +# all at once in a single place. +# +# - no nontrivial nested functions. couple-liners are ok but please no more. +# especially avoid functions that read/write outer variables defined far away. +# +# - raise RuntimeError instead of asserting, and put as much +# information as is available into the message. I.e. no need to +# plumb in new params whose only purpose is to fill out an error +# message, but use what's there +# + +from __future__ import annotations + +import itertools +import re +from collections import defaultdict +from typing import Callable, Iterable, Sequence + +import yaml + +from torchgen.api import cpp +from torchgen.api.python import ( + arg_parser_output_exprs, + cpp_dispatch_exprs, + cpp_dispatch_target, + dispatch_lambda_args, + dispatch_lambda_exprs, + dispatch_lambda_return_str, + has_tensor_options, + PythonSignature, + PythonSignatureDeprecated, + PythonSignatureGroup, + PythonSignatureNativeFunctionPair, + signature, + signature_from_schema, + structseq_fieldnames, +) +from torchgen.code_template import CodeTemplate +from torchgen.context import with_native_function +from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml +from torchgen.model import ( + Argument, + BaseOperatorName, + FunctionSchema, + NativeFunction, + SchemaKind, + Type, + Variant, +) +from torchgen.utils import FileManager, split_name_params +from torchgen.yaml_utils import YamlLoader + +from .gen_inplace_or_view_type import is_tensor_list_type +from .gen_trace_type import should_trace + + +# +# declarations blocklist +# We skip codegen for these functions, for various reasons. +# Future PRs will categorize this list and eliminate or hoist +# them out of eager-only codegen. 
+# See https://github.com/pytorch/pytorch/issues/30788 +# + +# These functions require manual Python bindings or are not exposed to Python +_SKIP_PYTHON_BINDINGS = [ + "alias", + "contiguous", + "is_cuda", + "is_sparse", + "is_sparse_csr", + "size", + "stride", + "sym_size", + "sym_stride", + "sym_storage_offset", + "sym_numel", + ".*_backward", + ".*_backward_(out|input|weight|bias)", + ".*_forward", + ".*_forward_out", + ".*_jvp", + "_unsafe_view", + "tensor", + "_?sparse_(coo|compressed|csr|csc|bsr|bsc)_tensor.*", + "_range.*", + "_sparse_add_out", + "_sparse_div.*", + "_sparse_mul.*", + "_sparse_sub.*", + "_sparse_dense_add_out", + "index", + "index_out", + "unique_dim_consecutive", + "_cumsum.*", + "_cumprod.*", + "_sum.*", + "_prod.*", + "_th_.*", + "_thnn_.*", + "range.*", + "_solve.*", + "_inverse.*", + "_cholesky.*", + "_triangular_solve.*", + "_qr.*", + "_svd.*", + "slice", + "item", + "_local_scalar_dense", + "to", + "_to_copy", + "_to_copy_out", + "_reshape_copy", + "_reshape_copy_out", + "copy_sparse_to_sparse_", + "copy_", + "_foreach_copy", + "numpy_T", + "matrix_H", + "mT", + "mH", # these need to be an attributes in Python, not functions + "nonzero(_(out|numpy))?", + "set_data", + ".*_overrideable", # overrideable functions for backend extension + "data", + "is_leaf", + "output_nr", + "_version", + "requires_grad_", + "retains_grad", + "set_", + "_fw_primal", + "fake_quantize_per_tensor_affine_cachemask", + "fake_quantize_per_channel_affine_cachemask", + "_new_zeros_with_same_feature_meta", + "_has_same_storage_numel", # used for forward AD internals + "_reshape_alias", + "replace_", # only used by the functionalization pass, doesn't need to be exposed to python + "copy", # only used by the functionalization pass + "fill.Tensor", # only used by the functionalization pass + "fill.Scalar", # only used by the functionalization pass + "lift.*", + "normal_functional", # only used by the functionalization pass + "nbytes", + "itemsize", + "_batch_norm_with_update", + "_batch_norm_with_update_out", + "_batch_norm_no_update", +] + +SKIP_PYTHON_BINDINGS = [ + re.compile(rf"^{pattern}$") for pattern in _SKIP_PYTHON_BINDINGS +] + +# These function signatures are not exposed to Python. Note that this signature +# list does not support regex. +SKIP_PYTHON_BINDINGS_SIGNATURES = [ + "add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", + "add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)", + "sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", + "sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)", + "mul.Scalar(Tensor self, Scalar other) -> Tensor", + "mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", + "div.Scalar(Tensor self, Scalar other) -> Tensor", + "div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", +] + + +@with_native_function +def should_generate_py_binding(f: NativeFunction) -> bool: + # NativeFunctions that are entirely code-generated should not get python bindings + # because these codegen implementations are often inefficient. A handful of + # view_copy style ops were exposed accidentally when they were handwritten and now + # that we are moving them to codegen for bc reasons we need to keep them exposed in + # python. 
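    # Illustrative note (added annotation, not upstream): the skip patterns
    # are compiled above as rf"^{pattern}$", i.e. anchored on both ends, so a
    # pattern must match the whole cpp name. E.g. (hedged) "index" skips
    # aten::index but leaves index_add bound, and "slice" skips aten::slice
    # while slice_scatter keeps its Python binding.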
+ if "generated" in f.tags and "view_copy" not in f.tags: + return False + + name = cpp.name(f.func) + for skip_regex in SKIP_PYTHON_BINDINGS: + if skip_regex.match(name): + return False + + signature = str(f.func) + for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES: + if pattern == signature: + return False + return True + + +def get_pycname(name: BaseOperatorName) -> str: + return f"THPVariable_{name}" + + +def is_noarg(overloads: Sequence[PythonSignatureNativeFunctionPair]) -> bool: + return len(overloads) == 1 and overloads[0].signature.arguments_count() == 0 + + +def is_py_variable_method(f: NativeFunction) -> bool: + return f.python_module is None and Variant.method in f.variants + + +def is_py_torch_function(f: NativeFunction) -> bool: + return f.python_module is None and Variant.function in f.variants + + +def is_py_nn_function(f: NativeFunction) -> bool: + return f.python_module == "nn" + + +def is_py_fft_function(f: NativeFunction) -> bool: + return f.python_module == "fft" + + +def is_py_linalg_function(f: NativeFunction) -> bool: + return f.python_module == "linalg" + + +def is_py_nested_function(f: NativeFunction) -> bool: + return f.python_module == "nested" + + +def is_py_sparse_function(f: NativeFunction) -> bool: + return f.python_module == "sparse" + + +def is_py_special_function(f: NativeFunction) -> bool: + return f.python_module == "special" + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Main Function +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def gen( + out: str, + native_yaml_path: str, + tags_yaml_path: str, + deprecated_yaml_path: str, + template_path: str, + *, + symint: bool = True, +) -> None: + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + native_functions = parse_native_yaml( + native_yaml_path, tags_yaml_path + ).native_functions + native_functions = list(filter(should_generate_py_binding, native_functions)) + + methods = load_signatures(native_functions, deprecated_yaml_path, method=True) + create_python_bindings( + fm, + methods, + is_py_variable_method, + None, + "python_variable_methods.cpp", + method=True, + symint=symint, + ) + + # NOTE: num_shards here must be synced with gatherTorchFunctions in + # torch/csrc/autograd/python_torch_functions_manual.cpp + functions = load_signatures(native_functions, deprecated_yaml_path, method=False) + create_python_bindings_sharded( + fm, + functions, + is_py_torch_function, + "torch", + "python_torch_functions.cpp", + method=False, + num_shards=3, + symint=symint, + ) + + create_python_bindings( + fm, + functions, + is_py_nn_function, + "torch.nn", + "python_nn_functions.cpp", + method=False, + symint=symint, + ) + + create_python_bindings( + fm, + functions, + is_py_fft_function, + "torch.fft", + "python_fft_functions.cpp", + method=False, + symint=symint, + ) + + create_python_bindings( + fm, + functions, + is_py_linalg_function, + "torch.linalg", + "python_linalg_functions.cpp", + method=False, + symint=symint, + ) + + create_python_bindings( + fm, + functions, + is_py_nested_function, + "torch.nested", + "python_nested_functions.cpp", + method=False, + ) + + create_python_bindings( + fm, + functions, + is_py_sparse_function, + "torch.sparse", + "python_sparse_functions.cpp", + method=False, + symint=symint, + ) + + create_python_bindings( + fm, + functions, + is_py_special_function, + "torch.special", + "python_special_functions.cpp", + method=False, + symint=symint, + ) + + # Currently, we only use 
`functions` to generate `return_types` bindings. + # All methods which return structseq have function variant at this point. + # If any method only operator with structseq is added in the future, + # we will have to address that. + create_python_return_type_bindings( + fm, functions, lambda fn: True, "python_return_types.cpp" + ) + create_python_return_type_bindings_header( + fm, functions, lambda fn: True, "python_return_types.h" + ) + + valid_tags = parse_tags_yaml(tags_yaml_path) + + def gen_tags_enum() -> dict[str, str]: + return { + "enum_of_valid_tags": ( + "".join( + [f'\n.value("{tag}", at::Tag::{tag})' for tag in sorted(valid_tags)] + ) + ) + } + + fm.write("python_enum_tag.cpp", gen_tags_enum) + + +def group_filter_overloads( + pairs: Sequence[PythonSignatureNativeFunctionPair], + pred: Callable[[NativeFunction], bool], +) -> dict[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]: + grouped: dict[ + BaseOperatorName, list[PythonSignatureNativeFunctionPair] + ] = defaultdict(list) + for pair in pairs: + if pred(pair.function): + grouped[pair.function.func.name.name].append(pair) + return grouped + + +def create_python_bindings( + fm: FileManager, + pairs: Sequence[PythonSignatureNativeFunctionPair], + pred: Callable[[NativeFunction], bool], + module: str | None, + filename: str, + *, + method: bool, + symint: bool = True, +) -> None: + """Generates Python bindings to ATen functions""" + py_methods: list[str] = [] + ops_headers: list[str] = [] + py_method_defs: list[str] = [] + py_forwards: list[str] = [] + + grouped = group_filter_overloads(pairs, pred) + + for name in sorted(grouped.keys(), key=str): + overloads = grouped[name] + py_methods.append( + method_impl(name, module, overloads, method=method, symint=symint) + ) + py_method_defs.append(method_def(name, module, overloads, method=method)) + py_forwards.extend(forward_decls(name, overloads, method=method)) + ops_headers.append(f"#include ") + + fm.write_with_template( + filename, + filename, + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/{filename}", + "ops_headers": ops_headers, + "py_forwards": py_forwards, + "py_methods": py_methods, + "py_method_defs": py_method_defs, + }, + ) + + +def create_python_return_type_bindings( + fm: FileManager, + pairs: Sequence[PythonSignatureNativeFunctionPair], + pred: Callable[[NativeFunction], bool], + filename: str, +) -> None: + """ + Generate function to initialize and return named tuple for native functions + which returns named tuple and registration invocations in `python_return_types.cpp`. 
+ """ + py_return_types_definition: list[str] = [] + py_return_types_registrations: list[str] = [] + + grouped = group_filter_overloads(pairs, pred) + + for name in sorted(grouped.keys(), key=str): + overloads = grouped[name] + definitions, registrations = generate_return_type_definition_and_registrations( + overloads + ) + py_return_types_definition.append( + "" if not definitions else "\n".join(definitions) + ) + py_return_types_registrations.append( + "" if not registrations else "\n".join(registrations) + ) + + fm.write_with_template( + filename, + filename, + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/{filename}", + "py_return_types": py_return_types_definition, + "py_return_types_registrations": py_return_types_registrations, + }, + ) + + +def create_python_return_type_bindings_header( + fm: FileManager, + pairs: Sequence[PythonSignatureNativeFunctionPair], + pred: Callable[[NativeFunction], bool], + filename: str, +) -> None: + """ + Generate function to initialize and return named tuple for native functions + which returns named tuple and relevant entry for the map in `python_return_types.cpp`. + """ + py_return_types_declarations: list[str] = [] + + grouped = group_filter_overloads(pairs, pred) + + for name in sorted(grouped.keys(), key=str): + overloads = grouped[name] + declarations = generate_return_type_declarations(overloads) + py_return_types_declarations.append( + "" if not declarations else "\n".join(declarations) + ) + + fm.write_with_template( + filename, + filename, + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/{filename}", + "py_return_types_declarations": py_return_types_declarations, + }, + ) + + +def create_python_bindings_sharded( + fm: FileManager, + pairs: Sequence[PythonSignatureNativeFunctionPair], + pred: Callable[[NativeFunction], bool], + module: str | None, + filename: str, + *, + method: bool, + num_shards: int, + symint: bool = True, +) -> None: + """Generates Python bindings to ATen functions""" + grouped = group_filter_overloads(pairs, pred) + + def key_func( + kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]] + ) -> str: + return kv[0].base + + def env_func( + kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]] + ) -> dict[str, list[str]]: + name, fn_pairs = kv + return { + "ops_headers": [f"#include "], + "py_forwards": list(forward_decls(name, fn_pairs, method=method)), + "py_methods": [ + method_impl(name, module, fn_pairs, method=method, symint=symint) + ], + "py_method_defs": [method_def(name, module, fn_pairs, method=method)], + } + + fm.write_sharded( + filename, + grouped.items(), + base_env={ + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/{filename}", + }, + key_fn=key_func, + env_callable=env_func, + num_shards=num_shards, + sharded_keys={"ops_headers", "py_forwards", "py_methods", "py_method_defs"}, + ) + + +def load_signatures( + native_functions: list[NativeFunction], + deprecated_yaml_path: str, + *, + method: bool, + skip_deprecated: bool = False, + pyi: bool = False, +) -> Sequence[PythonSignatureNativeFunctionPair]: + @with_native_function + def gen_signature_pairs(f: NativeFunction) -> PythonSignatureNativeFunctionPair: + return PythonSignatureNativeFunctionPair( + signature=signature(f, method=method, pyi=pyi), + function=f, + ) + + pairs = list(map(gen_signature_pairs, native_functions)) + deprecated = load_deprecated_signatures( + pairs, deprecated_yaml_path, 
method=method, pyi=pyi + ) + return pairs if skip_deprecated else pairs + deprecated + + +def load_deprecated_signatures( + pairs: Sequence[PythonSignatureNativeFunctionPair], + deprecated_yaml_path: str, + *, + method: bool, + pyi: bool, +) -> list[PythonSignatureNativeFunctionPair]: + # The deprecated.yaml doesn't have complete type information, we need + # find and leverage the original ATen signature (to which it delegates + # the call) to generate the full python signature. + # We join the deprecated and the original signatures using type-only form. + + # group the original ATen signatures by name + grouped: dict[str, list[PythonSignatureNativeFunctionPair]] = defaultdict(list) + for pair in pairs: + grouped[pair.signature.name].append(pair) + + # find matching original signatures for each deprecated signature + results: list[PythonSignatureNativeFunctionPair] = [] + + with open(deprecated_yaml_path) as f: + deprecated_defs = yaml.load(f, Loader=YamlLoader) + + for deprecated in deprecated_defs: + schema = FunctionSchema.parse(deprecated["name"]) + aten_name, call_args = split_name_params(deprecated["aten"]) + is_out = aten_name.endswith("_out") + if is_out: + aten_name = aten_name.replace("_out", "") + + # HACK: these are fixed constants used to pass the aten function. + # The type must be known ahead of time + known_constants = { + "1": Type.parse("Scalar"), + } + schema_args_by_name = {a.name: a for a in schema.arguments.flat_all} + for name in call_args: + assert ( + name in schema_args_by_name or name in known_constants + ), f"deprecation definiton: Unrecognized value {name}" + + # Map deprecated signature arguments to their aten signature and test + # if the types and alias annotation match. + def is_schema_compatible( + aten_schema: FunctionSchema, + ) -> bool: + arguments: Iterable[Argument] + if is_out: + arguments = itertools.chain( + aten_schema.arguments.out, aten_schema.arguments.flat_non_out + ) + else: + arguments = aten_schema.arguments.flat_all + + for i, arg in enumerate(arguments): + if i < len(call_args): + arg_name = call_args[i] + if arg_name in known_constants: + schema_type = known_constants[arg_name] + schema_annotation = None + else: + schema_arg = schema_args_by_name[arg_name] + schema_type = schema_arg.type + schema_annotation = schema_arg.annotation + + if schema_type != arg.type or schema_annotation != arg.annotation: + return False + else: + if arg.default is None: + return False + + return len(schema.returns) == len(aten_schema.returns) and all( + a == b for a, b in zip(schema.returns, aten_schema.returns) + ) + + any_schema_found = False + for pair in grouped[aten_name]: + if not is_schema_compatible(pair.function.func): + continue + any_schema_found = True + + python_sig = signature_from_schema( + schema, + category_override=pair.function.category_override, + method=method, + pyi=pyi, + ) + + results.append( + PythonSignatureNativeFunctionPair( + signature=PythonSignatureDeprecated( + name=python_sig.name, + input_args=python_sig.input_args, + input_kwargs=python_sig.input_kwargs, + output_args=python_sig.output_args, + tensor_options_args=python_sig.tensor_options_args, + method=python_sig.method, + deprecated_schema=schema, + deprecated_args_exprs=tuple(call_args), + returns=python_sig.returns, + ), + function=pair.function, + ) + ) + assert ( + any_schema_found + ), f"No native function with name {aten_name} matched signature:\n {str(schema)}" + + return results + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# 
Named Tuple Codegen +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +@with_native_function +def gen_structseq_typename_key(f: NativeFunction) -> str: + name = cpp.name(f.func) + fieldnames = structseq_fieldnames(f.func.returns) + return "_".join([name] + fieldnames) + + +def emit_structseq_call( + overloads: Sequence[PythonSignatureNativeFunctionPair], +) -> tuple[list[str], dict[str, str]]: + """ + Generate block of named tuple type def inits, and add typeref snippets + to declarations that use them + """ + typenames: dict[ + str, str + ] = {} # map from unique name + field name lists to typedef name + typedefs: list[str] = [] # typedef declarations and init code + + for overload in overloads: + fieldnames = structseq_fieldnames(overload.function.func.returns) + if not fieldnames: + continue + + name = cpp.name(overload.function.func) # use @with_native_function? + tn_key = gen_structseq_typename_key(overload.function) + typename = typenames.get(tn_key) + if typename is None: + typename = f'NamedTuple{"" if not typedefs else len(typedefs)}' + typenames[tn_key] = typename + typedefs.append( + f"""\ +static PyTypeObject* {typename} = generated::get_{name}_structseq();""" + ) + + return typedefs, typenames + + +def generate_return_type_definition_and_registrations( + overloads: Sequence[PythonSignatureNativeFunctionPair], +) -> tuple[list[str], list[str]]: + """ + Generate block of function in `python_return_types.cpp` to initialize + and return named tuple for a native function which returns named tuple + and registration invocations in same file. + """ + typenames: dict[ + str, str + ] = {} # map from unique name + field name lists to typedef name + definitions: list[str] = [] # function definition to register the typedef + registrations: list[str] = [] # register call for the typedef + + for overload in overloads: + fieldnames = structseq_fieldnames(overload.function.func.returns) + if not fieldnames: + continue + + fields = ", ".join(f'{{"{fn}", ""}}' for fn in fieldnames) + + name = cpp.name(overload.function.func) # use @with_native_function? + tn_key = gen_structseq_typename_key(overload.function) + typename = typenames.get(tn_key) + + if typename is None: + typename = f'{name}NamedTuple{"" if not definitions else len(definitions)}' + typenames[tn_key] = typename + definitions.append( + f"""\ +PyTypeObject* get_{name}_structseq() {{ + static PyStructSequence_Field NamedTuple_fields[] = {{ {fields}, {{nullptr}} }}; + static PyTypeObject {typename}; + static bool is_initialized = false; + static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, NamedTuple_fields, {len(fieldnames)} }}; + if (!is_initialized) {{ + PyStructSequence_InitType(&{typename}, &desc); + {typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr; + is_initialized = true; + }} + return &{typename}; +}} +""" + ) + registrations.append( + f'addReturnType(return_types_module, "{name}", generated::get_{name}_structseq());' + ) + + return definitions, registrations + + +def generate_return_type_declarations( + overloads: Sequence[PythonSignatureNativeFunctionPair], +) -> list[str]: + """ + Generate block of function declarations in `python_return_types.h` to initialize + and return named tuple for a native function. 
+ """ + typenames: dict[ + str, str + ] = {} # map from unique name + field name lists to typedef name + declarations: list[str] = [] # function declaration to register the typedef + + for overload in overloads: + fieldnames = structseq_fieldnames(overload.function.func.returns) + if not fieldnames: + continue + + name = cpp.name(overload.function.func) # use @with_native_function? + tn_key = gen_structseq_typename_key(overload.function) + typename = typenames.get(tn_key) + + if typename is None: + typename = ( + f'{name}NamedTuple{"" if not declarations else len(declarations)}' + ) + typenames[tn_key] = typename + declarations.append(f"PyTypeObject* get_{name}_structseq();") + + return declarations + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Method Impl Codegen +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + +# python binding for all overloads of a particular function/method +PY_VARIABLE_METHOD_VARARGS = CodeTemplate( + r"""\ +// ${name} +static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs) +{ + ${method_header} + static PythonArgParser parser({ + ${signatures} + }, /*traceable=*/${traceable}); + + ParsedArgs<${max_args}> parsed_args; + auto _r = parser.parse(${self_}, args, kwargs, parsed_args); + ${check_has_torch_function} + switch (_r.idx) { + ${dispatch} + } + ${method_footer} +} + +""" +) + +# handler for a single parsed signature - may be a single overload or +# a pair of overloads that whose signatures only differ in output params +# (plugged into PY_VARIABLE_METHOD_VARARGS as an item in ${dispatch}) +PY_VARIABLE_CASE = CodeTemplate( + """\ +case ${overload_index}: { + ${body} +} +""" +) + +# python binding for single-overload function/method +PY_VARIABLE_METHOD_VARARGS_SINGLETON = CodeTemplate( + """\ +// ${name} +static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs) +{ + ${method_header} + static PythonArgParser parser({ + ${signatures} + }, /*traceable=*/${traceable}); + + ParsedArgs<${max_args}> parsed_args; + auto _r = parser.parse(${self_}, args, kwargs, parsed_args); + ${check_has_torch_function} + ${dispatch} + ${method_footer} +} + +""" +) + +# python binding for a method with no args, shortcuts parsing +PY_VARIABLE_METHOD_NOARGS = CodeTemplate( + """\ +// ${name} +static PyObject * ${pycname}(PyObject* self_, PyObject* args) +{ + ${method_header} + ${check_has_torch_function} + ${dispatch} + ${method_footer} +} + +""" +) + + +def method_impl( + name: BaseOperatorName, + module: str | None, + overloads: Sequence[PythonSignatureNativeFunctionPair], + *, + method: bool, + symint: bool = True, +) -> str: + """ + Generate a python binding for all overloads of an op. 
+ """ + pycname = get_pycname(name) + noarg = is_noarg(overloads) + structseq_inits, structseq_typenames = emit_structseq_call(overloads) + + method_header = ["HANDLE_TH_ERRORS"] + method_header += structseq_inits + method_header += ( + ["const Tensor& self = THPVariable_Unpack(self_);"] if method else [] + ) + + method_footer = ([] if noarg else ["Py_RETURN_NONE;"]) + ["END_HANDLE_TH_ERRORS"] + + traceable = "true" if all(should_trace(o.function) for o in overloads) else "false" + + grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads( + overloads, symint=symint + ) + is_singleton = len(grouped_overloads) == 1 + signatures: list[str] = [] + dispatch: list[str] = [] + for overload_index, overload in enumerate(grouped_overloads): + signature = overload.signature.signature_str(symint=symint) + signatures.append(f"{cpp_string(str(signature))},") + dispatch_body = emit_dispatch_case(overload, structseq_typenames, symint=symint) + dispatch.append( + PY_VARIABLE_CASE.substitute( + overload_index=overload_index, body=dispatch_body + ) + if not is_singleton + else dispatch_body + ) + + if noarg: + template = PY_VARIABLE_METHOD_NOARGS + elif is_singleton: + template = PY_VARIABLE_METHOD_VARARGS_SINGLETON + else: + template = PY_VARIABLE_METHOD_VARARGS + + return template.substitute( + name=name, + pycname=pycname, + method_header=method_header, + max_args=max(o.signature.arguments_count() for o in overloads), + signatures=signatures, + traceable=traceable, + check_has_torch_function=gen_has_torch_function_check( + name=name, + module=module, + noarg=noarg, + method=method, + ), + dispatch=dispatch, + method_footer=method_footer, + self_="self_" if method else "nullptr", + ) + + +def gen_has_torch_function_check( + name: BaseOperatorName, module: str | None, *, noarg: bool, method: bool +) -> str: + if noarg: + if method: + return f"""\ +if(check_has_torch_function(self_)) {{ + return handle_torch_function(self_, "{name}"); +}} +""" + else: + return "" + + self_ = "self_" if method else "nullptr" + namespace = ( + { + "torch": "THPVariableFunctionsModule", + "torch.nn": "THPNNVariableFunctionsModule", + "torch.fft": "THPFFTVariableFunctionsModule", + "torch.linalg": "THPLinalgVariableFunctionsModule", + "torch.nested": "THPNestedVariableFunctionsModule", + "torch.sparse": "THPSparseVariableFunctionsModule", + "torch.special": "THPSpecialVariableFunctionsModule", + }[module] + if module + else "THPVariableClass" + ) + + return f"""\ +if(_r.has_torch_function()) {{ + return handle_torch_function(_r, {self_}, args, kwargs, {namespace}, "{module or "torch.Tensor"}"); +}} +""" + + +# handler for output/no-output overload pair +PY_VARIABLE_OUT = CodeTemplate( + """\ +if (_r.isNone(${out_idx})) { + ${call_dispatch} +} else { + ${call_dispatch_out} +} +""" +) + + +def emit_dispatch_case( + overload: PythonSignatureGroup, + structseq_typenames: dict[str, str], + *, + symint: bool = True, +) -> str: + """ + Emit dispatch code for a single parsed signature. This corresponds to either + a single native function, or a pair that differ only in output params. In the + latter case, a single python signature is used for both and dispatching + switches on the presence/absence of passed output args. 
+ """ + if overload.outplace is not None: + # dispatch output and no-output variants, branch on _r.isNone() + return PY_VARIABLE_OUT.substitute( + out_idx=overload.signature.output_idx(), + call_dispatch=emit_single_dispatch( + overload.signature, overload.base, structseq_typenames, symint=symint + ), + call_dispatch_out=emit_single_dispatch( + overload.signature, + overload.outplace, + structseq_typenames, + symint=symint, + ), + ) + else: + # no-output version only + return emit_single_dispatch( + overload.signature, overload.base, structseq_typenames, symint=symint + ) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Forward Declarations Codegen +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def forward_decls( + name: BaseOperatorName, + overloads: Sequence[PythonSignatureNativeFunctionPair], + *, + method: bool, +) -> tuple[str, ...]: + if method: + return () + + pycname = get_pycname(name) + if is_noarg(overloads): + return ( + f"""\ +static PyObject * {pycname}(PyObject* self_, PyObject* args); +""", + ) + else: + return ( + f"""\ +static PyObject * {pycname}(PyObject* self_, PyObject* args, PyObject* kwargs); +""", + ) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Method Def (Binding Table Entry) Codegen +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def method_def( + name: BaseOperatorName, + module: str | None, + overloads: Sequence[PythonSignatureNativeFunctionPair], + *, + method: bool, +) -> str: + """ + Generate method def entry. + """ + pycname = get_pycname(name) + + if name.dunder_method: + # PyMethodDef entry for binary op, throws not implemented error + pycname = f"TypeError_to_NotImplemented_<{pycname}>" + + if is_noarg(overloads): + flags = "METH_NOARGS" if method else "METH_VARARGS | METH_KEYWORDS" + else: + pycname = f"castPyCFunctionWithKeywords({pycname})" + flags = "METH_VARARGS | METH_KEYWORDS" + + if module == "torch": + flags += " | METH_STATIC" + + return f'{{"{name}", {pycname}, {flags}, NULL}},' + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Overload Sorting and Grouping +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def group_overloads( + overloads: Sequence[PythonSignatureNativeFunctionPair], *, symint: bool = True +) -> Sequence[PythonSignatureGroup]: + bases: dict[str, PythonSignatureNativeFunctionPair] = {} + outplaces: dict[str, PythonSignatureNativeFunctionPair] = {} + + # first group by signature ignoring out arguments + for overload in overloads: + sig = overload.signature.signature_str(skip_outputs=True, symint=symint) + if overload.function.func.is_out_fn(): + if sig in outplaces: + raise RuntimeError( + f"Found duplicated function definition:\n- {overload.function.func}.\n" + f"Existing definition:\n- {outplaces[sig].function.func}." + ) + outplaces[sig] = overload + else: + if sig in bases: + raise RuntimeError( + f"Found duplicated function definition:\n- {overload.function.func}.\n" + f"Existing definition:\n- {bases[sig].function.func}." 
+ ) + bases[sig] = overload + + for sig, out in outplaces.items(): + if sig not in bases: + candidates: list[str] = [] + for overload in overloads: + if ( + str(overload.function.func.name.name) + == str(out.function.func.name.name) + and not overload.function.func.is_out_fn() + and not overload.signature.deprecated + ): + candidates.append( + overload.signature.signature_str( + skip_outputs=True, symint=symint + ) + ) + out_sig = out.signature.signature_str(symint=symint) + raise RuntimeError( + f"While identifying overloads, we found an out schema {out_sig} without a corresponding non-out variant. " + f"We expected the non-out variant to have schema: \n- {sig}\nPlease check that you spelled the schema " + "correctly in native_functions.yaml. We discovered the following candidate(s): \n" + + "\n".join(f"- {candidate}" for candidate in candidates) + ) + + grouped = [ + PythonSignatureGroup.from_pairs( + functional=base, + out=outplaces.get(sig), + ) + for sig, base in bases.items() + ] + return sort_overloads(grouped, symint=symint) + + +# This function declares a partial order on declarations, and sorts them according +# to its linear extension. This is necessary, because there's some ambiguity in the +# choice of overload, and we want a different order. +# +# See Note[Order of overloads matters] +# +# A few examples of ambiguous python signature pairs. +# +# All parameters have the same type, except one taking Tensor the other taking +# Scalar. A numeric PyObject can be casted into Tensor, and a zero-dim Tensor +# object can be accepted as Scalar type parameter (see python_arg_parser.cpp). +# Therefore, same input arguments might be accepted by either python signature. +# We want to always parse the one taking Tensor first. +# +# bitwise_and(Tensor input, Tensor other, *, Tensor out=None) +# bitwise_and(Tensor input, Scalar other, *, Tensor out=None) +# +# If they have different number of parameters then they are not ambiguous - but +# the difference on output param can be ignored as it's optional. +# +# multiply(Tensor input, Tensor other, *, Tensor out=None) +# multiply(Tensor input, Scalar other) +# +# Both positional args and keyword-only args are considered together. +# +# subtract(Tensor other, *, Scalar alpha=1) +# subtract(Scalar other, Scalar alpha=1) +# +# A few ambiguous cases which it does NOT handle yet. +# +# If there is any difference in other parameters besides the Tensor/Scalar +# difference, then they are not considered ambiguous by this method anymore. +# However, the difference could be too trivial to disambiguate. +# +# foo(Tensor input, Scalar other, Scalar bar) +# foo(Tensor input, Tensor other, double bar) +# +# If they are taking different number of parameters then they are not considered +# ambiguous anymore, even if the difference is only on optional kwargs. +# +# foo(Scalar other, Scalar alpha=1) +# foo(Tensor other, *, Scalar alpha=1, Scalar beta=1) +# + + +def sort_overloads( + grouped_overloads: Sequence[PythonSignatureGroup], *, symint: bool = True +) -> Sequence[PythonSignatureGroup]: + # NB: Smaller here means lower priority + + def is_arg_smaller(t1: Type, t2: Type) -> bool: + return ( + str(t1) == "Scalar" + and str(t2) == "Tensor" + or str(t1) == "Scalar?" + and str(t2) == "Tensor?" + or "Dimname" in str(t1) + and "Dimname" not in str(t2) + or + # In the discussion https://github.com/pytorch/pytorch/issues/54555 it has been + # discussed why it is important to prioritize int/int? 
over int[] + str(t1) == "int[]" + and (str(t2) == "int" or str(t2) == "int?") + or + # TensorList currently throws an error during argument parsing, that's why it needs to be + # last in signature ordering. See discussion: https://github.com/pytorch/pytorch/issues/58087 + str(t1) == "Tensor[]" + and str(t2).find("[]") != -1 + or + # Prioritize IntArrayRef overload over SymIntArrayRef + str(t1) == "SymInt[]" + and str(t2) == "int[]" + or + # Make sure both in, SymInt are sorted consistently w.r.t. Tensor since Tensor can be implicitly + # converted to either int or SymInt. Prioritize the Tensor overload since it otherwise gets shadowed. + (str(t1) == "SymInt" or str(t1) == "int") + and str(t2) == "Tensor" + ) + + def is_smaller(s1: PythonSignature, s2: PythonSignature) -> bool: + """Returns True if s1 < s2 in the partial order.""" + args1, args2 = s1.arguments(skip_outputs=True), s2.arguments(skip_outputs=True) + if len(args1) != len(args2): + return False + # TODO: should use some canonical form instead of 'str(arg.type)' - see comments + # above. The old codegen used the deprecated 'dynamic_type(arg.type)', which + # ignores the optional annotation, i.e. 'Scalar' and 'Scalar?'. + equal = all(arg1.type == arg2.type for arg1, arg2 in zip(args1, args2)) + smaller_or_equal = all( + str(arg1.type) == str(arg2.type) or is_arg_smaller(arg1.type, arg2.type) + for arg1, arg2 in zip(args1, args2) + ) + return smaller_or_equal and not equal + + # First sort by signature + grouped_overloads = sorted( + grouped_overloads, key=lambda x: x.signature.signature_str(symint=symint) + ) + + # Construct the relation graph + larger_than: dict[int, set[int]] = defaultdict(set) + for i1, overload1 in enumerate(grouped_overloads): + for i2, overload2 in enumerate(grouped_overloads): + if is_smaller(overload1.signature, overload2.signature): + larger_than[i1].add(i2) + + if not larger_than: + return list(grouped_overloads) + + # Use a topological sort to sort overloads according to the partial order. + N = len(grouped_overloads) + sorted_ids: list[int] = list(filter(lambda x: x not in larger_than, range(N))) + + for idx in range(N): + # The size of sorted_ids will grow to N eventually. + i = sorted_ids[idx] + for j in sorted(larger_than.keys()): + larger = larger_than[j] + larger.discard(i) + if not larger: + del larger_than[j] + sorted_ids.append(j) + + return [grouped_overloads[x] for x in sorted_ids] + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Codegen API Integration +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def emit_single_dispatch( + ps: PythonSignature, + f: NativeFunction, + structseq_typenames: dict[str, str], + *, + symint: bool = True, +) -> str: + """ + Emit dispatch code for a single native function. 
+ """ + + @with_native_function + def go(f: NativeFunction) -> str: + # header comments + if isinstance(ps, PythonSignatureDeprecated): + schema_comment = f"// [deprecated] aten::{ps.deprecated_schema}" + else: + schema_comment = f"// aten::{f.func}" + + deprecated = "[deprecated] " if ps.deprecated else "" + + # dispatch lambda signature + name = cpp.name(f.func) + lambda_formals = ", ".join( + f"{a.type_str} {a.name}" for a in dispatch_lambda_args(ps, f, symint=symint) + ) + lambda_return = dispatch_lambda_return_str(f) + + # dispatch lambda body + dispatch_callee = cpp_dispatch_target(f) + dispatch_args = ", ".join(cpp_dispatch_exprs(f, python_signature=ps)) + + # from arg parser outputs to dispatch lambda arguments + parser_outputs = arg_parser_output_exprs(ps, f, symint=symint) + lambda_arg_exprs = dispatch_lambda_exprs(ps, f, symint=symint) + inits = "\n".join(lambda_arg_exprs.inits) + lambda_args = ", ".join(lambda_arg_exprs.exprs) + + # scatter fields + # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky + # solution for enabling the 'requires_grad' argument for tensor methods + # new_full, new_empty, and new_zeros. A much better but more difficult to + # implement solution involves refactoring according to Ed's description here: + # https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589 + need_set_requires_grad = ps.tensor_options_args and ( + not has_tensor_options(f) + or (ps.method and ("requires_grad" in parser_outputs)) + ) + set_requires_grad = ( + f'.set_requires_grad({parser_outputs["requires_grad"].expr})' + if need_set_requires_grad + else "" + ) + + if lambda_return == "void": + # Make in-place foreach return `self` at python-binding level. + # ref: https://github.com/pytorch/pytorch/pull/118622#pullrequestreview-1904804954 + self_arg = f.func.arguments.self_arg + return_stmt: str + if ( + str(f.func.name).startswith("_foreach_") + and f.func.kind() == SchemaKind.inplace + ): + # note(crcrpar): `_foreach_pow.ScalarAndTensor` does NOT have its in-place + # variant and it unlikely to have it in the future. Thus it's safe to have the following assert. 
+                assert self_arg is not None and is_tensor_list_type(
+                    self_arg.argument.type
+                )
+                return_stmt = """PyObject* self_tensorlist = _r.args[0];
+Py_INCREF(self_tensorlist);
+return self_tensorlist;
+"""
+            else:
+                return_stmt = "Py_RETURN_NONE;"
+            return f"""\
+{schema_comment}
+{inits}
+auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
+  pybind11::gil_scoped_release no_gil;
+  {dispatch_callee}({dispatch_args});
+}};
+dispatch_{name}({lambda_args}){set_requires_grad};
+{return_stmt}
+"""
+        else:
+            typename = structseq_typenames.get(gen_structseq_typename_key(f))
+            structseq_typeref = f"{typename}, " if typename is not None else ""
+            return f"""\
+{schema_comment}
+{inits}
+auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
+  pybind11::gil_scoped_release no_gil;
+  return {dispatch_callee}({dispatch_args});
+}};
+return wrap({structseq_typeref}dispatch_{name}({lambda_args}){set_requires_grad});
+"""
+
+    return go(f)
diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_trace_type.py b/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_trace_type.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b462655010417a655efec0114b118ba2fa0bd6a
--- /dev/null
+++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_trace_type.py
@@ -0,0 +1,536 @@
+from __future__ import annotations
+
+import itertools
+from typing import Sequence
+
+from torchgen.api import cpp
+from torchgen.api.types import DispatcherSignature
+from torchgen.code_template import CodeTemplate
+from torchgen.context import with_native_function
+from torchgen.model import Argument, NativeFunction, SchemaKind, TensorOptionsArguments
+from torchgen.utils import FileManager
+
+
+# Note [Manual Backend kernels]
+# For these ops, we want to manually register to dispatch key Backend and
+# skip codegen-ed registration to all keys before Backend.
+# For codegen this means:
+#   - op set below must match ops with manual_kernel_registration=True in native_functions.yaml
+#     where we skip codegen backend kernels
+#   - all ops below are part of MANUAL_AUTOGRAD to skip codegen Autograd kernel registration
+#   - all ops below are part of MANUAL_TRACER to skip codegen Tracer kernel registration
+# Note: we still register to dispatch key Profiler for these ops, keeping it untouched for now.
+# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp
+MANUAL_BACKEND = {
+    "options",
+    "data",
+    "set_data",
+    "is_leaf",
+    "output_nr",
+    "_version",
+    "retain_grad",
+    "_backward",
+    "requires_grad_",
+}
+
+# For these ops we want to skip the codegen-ed registration to both Autograd and Tracer keys.
+# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp
+MANUAL_AUTOGRAD_AND_TRACER = {
+    "resize_",
+    "resize_as_",
+    "detach",
+    "detach_",
+    "copy_",
+    "_fw_primal",
+    "_make_dual",
+}
+
+# Currently MANUAL_AUTOGRAD and MANUAL_TRACER share the same set of ops:
+#   union(MANUAL_BACKEND, MANUAL_AUTOGRAD_AND_TRACER)
+# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp
+MANUAL_AUTOGRAD = MANUAL_TRACER = MANUAL_BACKEND | MANUAL_AUTOGRAD_AND_TRACER
+
+# These functions we don't want to record for tracing, because we always want
+# to trace their constituent parts. This is a temporary hack in lieu
+# of proper scopes, where subsequent compilation passes can ask for the unfolding
+# on demand. Only concrete ATen methods can be disabled this way; it will have
+# NO EFFECT otherwise.
+DONT_RECORD_TRACE = { + "convolution", + "conv1d", + "conv2d", + "conv3d", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + "lstm_cell", + "gru_cell", + "rnn_tanh_cell", + "rnn_relu_cell", + # FIXME: figure out a better way when we support sparse tensors in jit + "_coalesced", +} + + +def should_trace(f: NativeFunction) -> bool: + # Operations involving Storage or Type are not traceable at the moment + if any( + str(arg.type) in {"Storage", "Type", "ConstQuantizerPtr"} + for arg in f.func.schema_order_arguments() + ): + return False + # We can't trace functions which don't have any Tensor or TensorList returns + if not any(r.type.is_tensor_like() for r in f.func.returns): + return False + return f.func.name.name.base not in DONT_RECORD_TRACE + + +SELECT = CodeTemplate( + """\ + +if (${cond}) { + ${true} +} else { + ${false} +} +""" +) + +OP_NAME = CodeTemplate( + """\ +op_name = c10::Symbol::fromQualString("aten::${trace_name}"); +""" +) + +# These functions have their names recorded under trace renamed, +RENAME_TRACE = { + "zero": "zeros_like", # replacing aten::zero_ with aten::zeros_like + "fill": "full_like", # replacing aten::fill_ with aten::full_like +} + + +def format_trace_op_name(f: NativeFunction) -> str: + # TODO: byte-for-byte compatible with old codegen behavior - should clean up + if ( + f.func.kind() in (SchemaKind.functional, SchemaKind.out) + or f.func.name.name.dunder_method + ): + # special case for *_out functions: the in-place and out-of-place ops + # are overloaded with the same name in the JIT + trace_name = str(f.func.name.name) + trace_name = RENAME_TRACE.get(trace_name, trace_name) + return OP_NAME.substitute(trace_name=trace_name) + + # otherwise, this is an in-place op and we need to emit both in- and + # out-of-place versions + outplace_trace_name = f.func.name.name.base + inplace_trace_name = cpp.name(f.func) + outplace_trace_name = RENAME_TRACE.get(outplace_trace_name, outplace_trace_name) + inplace_trace_name = RENAME_TRACE.get(inplace_trace_name, inplace_trace_name) + + return SELECT.substitute( + cond="tracer_state->force_outplace", + true=OP_NAME.substitute(trace_name=outplace_trace_name), + false=OP_NAME.substitute(trace_name=inplace_trace_name), + ) + + +ADD_TRACE_INPUT = CodeTemplate("""jit::tracer::addInputs(node, "${name}", ${input});""") + + +def format_trace_inputs(f: NativeFunction) -> str: + def dispatch_trace_input(arg: Argument | TensorOptionsArguments) -> Sequence[str]: + if isinstance(arg, TensorOptionsArguments): + name = "options" + return [ + ADD_TRACE_INPUT.substitute( + name=name, input="c10::optTypeMetaToScalarType(options.dtype_opt())" + ), + ADD_TRACE_INPUT.substitute(name=name, input="options.layout()"), + ADD_TRACE_INPUT.substitute(name=name, input="options.device()"), + ADD_TRACE_INPUT.substitute(name=name, input="options.pinned_memory()"), + ] + else: + name = arg.name + if str(arg.type) == "Tensor?[]": + return [f'jit::tracer::addInputs(node, "{name}", {name});'] + else: + return [ADD_TRACE_INPUT.substitute(name=name, input=name)] + + args: list[Argument | TensorOptionsArguments] = list( + f.func.schema_order_arguments() + ) + + if f.func.is_out_fn(): + # *_out functions take the result as a separate argument, but we don't want to + # trace that argument directly. Instead, we trace its TensorOptions. + # So first, we need to remove the out argument from the list of arguments to trace. 
+ num_out_args = len(f.func.arguments.out) + args = args[:-num_out_args] + + trace_inputs = itertools.chain.from_iterable( + dispatch_trace_input(arg) for arg in args + ) + + if f.func.is_out_fn(): + # for *_out functions, handle the result argument differently for inplace/outplace. + # For inplace: just add the input to the end to confirm with the JIT schema + inplace = [ + ADD_TRACE_INPUT.substitute( + name=f.func.arguments.out[i].name, input=f.func.arguments.out[i].name + ) + for i in range(num_out_args) + ] + + # for outplace: do nothing, except if the function is a factory. + # Factories are a bit special because their out-of-place overloads + # take an extra TensorOptions argument, which is missing in the _out function + has_tensor_return = any(r.type.is_tensor_like() for r in f.func.returns) + has_tensor_input_arg = any( + a.type.is_tensor_like() for a in f.func.arguments.flat_non_out + ) + is_factory_method = f.category_override == "factory" or ( + has_tensor_return and not has_tensor_input_arg + ) + + # HACK: preserve old codegen behavior - the old codegen set the `is_factory_method` + # flag for the whole family of ops with the same basename if any of them is a + # factory method. For most cases the whole family of ops are indeed all factory + # method - 'normal' is the only exception. So we handle it specially here to avoid + # cloning the old logic. + if f.func.name.name.base == "normal": + is_factory_method = True + + if is_factory_method: + outplace = [ + ADD_TRACE_INPUT.substitute( + name="out", + input="c10::optTypeMetaToScalarType(out.options().dtype_opt())", + ), + ADD_TRACE_INPUT.substitute(name="out", input="out.options().layout()"), + ADD_TRACE_INPUT.substitute(name="out", input="out.options().device()"), + ADD_TRACE_INPUT.substitute( + name="out", input="out.options().pinned_memory()" + ), + ] + else: + outplace = [] + + trace_inputs = itertools.chain( + trace_inputs, + [ + SELECT.substitute( + cond="tracer_state->force_outplace", + true="\n".join(outplace), + false="\n".join(inplace), + ) + ], + ) + + return "\n".join(trace_inputs) + + +# `torch.jit.trace` have undocumented keyword argument `_force_outplace`, +# which force jit to replace functions with outplace variants (for +# example `aten::add_` becomes `aten::add`). +# +# This replacement implemented in-place with minimum modifications of +# arguments stack (as it assumes that outplace call has the same arguments +# as inplace version). +# +# However there are no such substitutions available for `aten::fill_` +# and `aten::zero_` operators, as we never implemented `aten::fill` +# and `aten::zero`. So jit tracing hack replacing `aten::zero_` with +# `aten::zeros_like` and replacing `aten::fill_` with `aten::full_like`. +# +# But as they potentially can have different arguments, we also have +# to hack into the stack and add missing ones. 
+# +# A possible alternative would be: +# +# - Add `aten::fill` and `aten::zero` +# +# - Or keep `aten::zeros_like` arguments aligned with `aten::zero_` +# arguments (inside of the `native_functions.yaml`) +RENAME_TRACE_ADD_ARGS = { + "fill": """\ + jit::tracer::addInputs(node, "options", ::std::optional()); + jit::tracer::addInputs(node, "options", layout_or_default(::std::nullopt)); + jit::tracer::addInputs(node, "options", device_or_default(::std::nullopt)); + jit::tracer::addInputs(node, "options", pinned_memory_or_default(::std::nullopt)); + ::std::optional memory_format = c10::MemoryFormat::Preserve; + jit::tracer::addInputs(node, "memory_format", memory_format); +""", + "zero": """\ + jit::tracer::addInputs(node, "options", ::std::optional()); + jit::tracer::addInputs(node, "options", layout_or_default(::std::nullopt)); + jit::tracer::addInputs(node, "options", device_or_default(::std::nullopt)); + jit::tracer::addInputs(node, "options", pinned_memory_or_default(::std::nullopt)); + ::std::optional memory_format = c10::MemoryFormat::Preserve; + jit::tracer::addInputs(node, "memory_format", memory_format); +""", +} + +INPLACE_GUARD = CodeTemplate( + """\ +jit::tracer::ensureUniqueIfOutOfPlaced("${name}", ${mutable_input}); +""" +) + +PRE_RECORD_TRACE = CodeTemplate( + """\ +torch::jit::Node* node = nullptr; +std::shared_ptr tracer_state; +if (jit::tracer::isTracing()) { + tracer_state = jit::tracer::getTracingState(); + at::Symbol op_name; + ${set_op_name} + node = tracer_state->createNode(op_name, /*num_outputs=*/0); + jit::tracer::recordSourceLocation(node); + ${add_trace_inputs} + tracer_state->insertNode(node); + ${inplace_guard} + jit::tracer::setTracingState(nullptr); +} +""" +) + + +def format_prerecord_trace(f: NativeFunction) -> str: + if not should_trace(f): + return "" + + # TODO: clean up old codegen behavior + is_inplace = ( + f.func.kind() in (SchemaKind.inplace, SchemaKind.out) + and not f.func.name.name.dunder_method + ) + add_args = ( + RENAME_TRACE_ADD_ARGS.get(f.func.name.name.base, "") if is_inplace else "" + ) + additional_inputs = ( + SELECT.substitute( + cond="tracer_state->force_outplace", + true=add_args, + false="", + ) + if add_args + else "" + ) + + return PRE_RECORD_TRACE.substitute( + set_op_name=format_trace_op_name(f), + add_trace_inputs=format_trace_inputs(f) + additional_inputs, + inplace_guard=INPLACE_GUARD.substitute( + name=cpp.name(f.func), + mutable_input=f.func.arguments.out[0].name + if f.func.arguments.out + else "self", + ) + if is_inplace + else "", + ) + + +POST_RECORD_TRACE = CodeTemplate( + """\ +if (tracer_state) { + jit::tracer::setTracingState(std::move(tracer_state)); + ${add_trace_outputs} +} +""" +) + + +def format_postrecord_trace(f: NativeFunction) -> str: + if not should_trace(f): + return "" + + # For outplacing ops, *_out overloads require special handling to move the + # output *argument* to a return value + if f.func.is_out_fn(): + output_names_outplace = [arg.name for arg in f.func.arguments.out] + output_names_inplace = cpp.return_names(f) + + # Code size optimization: the common case is that the return value is + # the same for both variants + if output_names_outplace == output_names_inplace: + outputs = [ + f"jit::tracer::addOutput(node, {n});" for n in output_names_outplace + ] + return POST_RECORD_TRACE.substitute(add_trace_outputs=outputs) + + selection = SELECT.substitute( + cond="force_outplace", + true="\n".join( + f"jit::tracer::addOutput(node, {n});" for n in output_names_outplace + ), + false="\n".join( + 
f"jit::tracer::addOutput(node, {n});" for n in output_names_inplace + ), + ) + return POST_RECORD_TRACE.substitute(add_trace_outputs=selection) + else: + output_names = cpp.return_names(f) + outputs = [f"jit::tracer::addOutput(node, {n});" for n in output_names] + return POST_RECORD_TRACE.substitute(add_trace_outputs=outputs) + + +def tie_return_values(f: NativeFunction) -> str: + if len(f.func.returns) == 1: + return f'auto {f.func.returns[0].name or "result"}' + names = cpp.return_names(f) + return f'auto [{", ".join(names)}]' + + +def get_return_value(f: NativeFunction) -> str: + names = cpp.return_names(f) + if len(f.func.returns) == 1: + return names[0] + if f.func.kind() == SchemaKind.out: + return f'std::forward_as_tuple({", ".join(names)})' + else: + moved = ", ".join(f"std::move({name})" for name in names) + return f"std::make_tuple({moved})" + + +TRACE_DISPATCH = CodeTemplate( + """\ +${assign_return_values}at::_ops::${unambiguous_name}::redispatch(${unpacked_args});""" +) + + +def emit_trace_body(f: NativeFunction) -> list[str]: + trace_body: list[str] = [] + + trace_body.append(format_prerecord_trace(f)) + + dispatcher_sig = DispatcherSignature.from_schema(f.func) + dispatcher_exprs = dispatcher_sig.exprs() + + # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance. + # See Note [Plumbing Keys Through The Dispatcher] for details. + dispatch_key_set = "ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Tracer)" + redispatch_args = ", ".join([dispatch_key_set] + [a.expr for a in dispatcher_exprs]) + + assign_return_values = ( + f"{tie_return_values(f)} = " + if f.func.kind() in [SchemaKind.functional, SchemaKind.mutable] + and f.func.returns + else "" + ) + + # Note that this calls the slow, dispatching variants of manual_cpp_binding ops. + # We could probably work harder to ensure that the fast variants are + # called instead, but the perf benefit would be minimal. + trace_body.append( + TRACE_DISPATCH.substitute( + assign_return_values=assign_return_values, + unambiguous_name=f.func.name.unambiguous_name(), + unpacked_args=redispatch_args, + ) + ) + + trace_body.append(format_postrecord_trace(f)) + if f.func.returns: + trace_body.append(f"return {get_return_value(f)};") + return trace_body + + +METHOD_DEFINITION = CodeTemplate( + """\ +${return_type} ${type_wrapper_name}(${formals}) { + ${type_definition_body} +} +""" +) + + +def type_wrapper_name(f: NativeFunction, key: str = "Default") -> str: + if f.func.name.overload_name: + name = f"{cpp.name(f.func)}_{f.func.name.overload_name}" + else: + name = cpp.name(f.func) + + # The key argument is only used in gen_variable_type where we need fns per autograd dispatch key. + # In gen_trace_type and gen_inplace_view_type where only one fn per native_fn must be generated, + # the key argument should not be passed. + # We do not append key if it is Default so that generated functions from + # before per-dispatch-key derivatives were added retain the same names. + if key != "Default": + name = name + f"_{key}" + return name + + +@with_native_function +def method_definition(f: NativeFunction) -> str: + assert cpp.name(f.func) not in MANUAL_TRACER + + formals = ", ".join( + # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance. + # See Note [Plumbing Keys Through The Dispatcher] for details. 
+ ["c10::DispatchKeySet ks"] + + [ + f'{cpp.argument_type(a, binds="__placeholder__", symint=True).cpp_type()} {a.name}' + for a in f.func.schema_order_arguments() + ] + ) + + return METHOD_DEFINITION.substitute( + return_type=cpp.returns_type(f.func.returns, symint=True).cpp_type(), + type_wrapper_name=type_wrapper_name(f), + formals=formals, + type_definition_body=emit_trace_body(f), + ) + + +WRAPPER_REGISTRATION = CodeTemplate( + """\ +m.impl("${name}", + TORCH_FN(${class_type}::${type_wrapper_name}) +); +""" +) + + +@with_native_function +def method_registration(f: NativeFunction) -> str: + assert cpp.name(f.func) not in MANUAL_TRACER + + return WRAPPER_REGISTRATION.substitute( + name=f.func.name, + type_wrapper_name=type_wrapper_name(f), + class_type="TraceType", + ) + + +def gen_trace_type_func(fn: NativeFunction) -> dict[str, list[str]]: + return { + "ops_headers": [f"#include "], + "trace_method_definitions": [method_definition(fn)], + "trace_wrapper_registrations": [method_registration(fn)], + } + + +def gen_trace_type( + out: str, native_functions: list[NativeFunction], template_path: str +) -> None: + # NOTE: see Note [Sharded File] at the top of the VariableType.cpp + # template regarding sharding of the generated files. + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + fm.write_sharded( + "TraceType.cpp", + [fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER], + key_fn=lambda fn: fn.root_name, + base_env={ + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/TraceType.cpp", + }, + env_callable=gen_trace_type_func, + num_shards=5, + sharded_keys={ + "ops_headers", + "trace_method_definitions", + "trace_wrapper_registrations", + }, + ) diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_variable_factories.py b/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_variable_factories.py new file mode 100644 index 0000000000000000000000000000000000000000..f206939bd535a887827a8f8170e99e6d37a71aef --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_variable_factories.py @@ -0,0 +1,116 @@ +# Generates C++ functions that wrap ATen tensor factory methods to turn them into Variables. +# +# This writes one file: variable_factories.h + +from __future__ import annotations + +import re + +import torchgen.api.python as python +from torchgen.api import cpp +from torchgen.api.types import CppSignatureGroup +from torchgen.context import with_native_function +from torchgen.gen import parse_native_yaml +from torchgen.model import NativeFunction, TensorOptionsArguments, Variant +from torchgen.utils import FileManager, mapMaybe + + +OPTIONAL_TYPE_PATTERN = re.compile(r"std::optional<(.+)>") +TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)") + + +# Add 'at::' to types defined in ATen namespace, e.g. Tensor, TensorList, IntArrayRef and etc. +# TODO: maybe update the cpp argument API to take optional namespace argument? 
+def fully_qualified_type(argument_type: str) -> str: + def maybe_optional_type(type: str, is_opt: bool) -> str: + return f"std::optional<{type}>" if is_opt else type + + opt_match = OPTIONAL_TYPE_PATTERN.match(argument_type) + is_opt = opt_match is not None + if opt_match: + argument_type = argument_type[opt_match.start(1) : opt_match.end(1)] + match = TYPE_PATTERN.match(argument_type) + if match is None: + return maybe_optional_type(argument_type, is_opt) + index = match.start(1) + qualified_type = f"{argument_type[:index]}at::{argument_type[index:]}" + return maybe_optional_type(qualified_type, is_opt) + + +def gen_variable_factories( + out: str, native_yaml_path: str, tags_yaml_path: str, template_path: str +) -> None: + native_functions = parse_native_yaml( + native_yaml_path, tags_yaml_path + ).native_functions + factory_functions = [fn for fn in native_functions if is_factory_function(fn)] + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + fm.write_with_template( + "variable_factories.h", + "variable_factories.h", + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/variable_factories.h", + "ops_headers": [ + f"#include " for fn in factory_functions + ], + "function_definitions": list(mapMaybe(process_function, factory_functions)), + }, + ) + + +@with_native_function +def is_factory_function(f: NativeFunction) -> bool: + if Variant.function not in f.variants: + return False + + name = cpp.name(f.func) + has_tensor_options = python.has_tensor_options(f) + return has_tensor_options or name.endswith("_like") + + +@with_native_function +def process_function(f: NativeFunction) -> str | None: + name = cpp.name(f.func) + has_tensor_options = python.has_tensor_options(f) + is_factory = has_tensor_options or name.endswith("_like") + + if Variant.function not in f.variants or not is_factory: + return None + + cpp_sigs = CppSignatureGroup.from_native_function(f, method=False) + sigs = [cpp_sigs.signature] + if cpp_sigs.symint_signature is not None: + sigs.append(cpp_sigs.symint_signature) + r = "" + for sig in sigs: + formals: list[str] = [] + exprs: list[str] = [] + requires_grad = "false" + for arg in sig.arguments(): + qualified_type = fully_qualified_type(arg.type) + if arg.default: + formals.append(f"{qualified_type} {arg.name} = {arg.default}") + else: + formals.append(f"{qualified_type} {arg.name}") + + if isinstance(arg.argument, TensorOptionsArguments): + # note: we remove the requires_grad setting from the TensorOptions because + # it is ignored anyways (and we actually have an assertion that it isn't set + # which would fail otherwise). We handle requires_grad explicitly here + # instead of passing it through to the kernel. + exprs.append( + f"at::TensorOptions({arg.name}).requires_grad(::std::nullopt)" + ) + # Manually set the requires_grad bit on the result tensor. 
+ requires_grad = f"{arg.name}.requires_grad()" + else: + exprs.append(arg.name) + + r += f"""\ +inline at::Tensor {sig.name()}({', '.join(formals)}) {{ + at::AutoDispatchBelowADInplaceOrView guard; + return autograd::make_variable(at::{sig.name()}({', '.join(exprs)}), /*requires_grad=*/{requires_grad}); +}} +""" + return r diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_variable_type.py b/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_variable_type.py new file mode 100644 index 0000000000000000000000000000000000000000..4bec1871ae483ea7b12f7c3ff9ecc6198ea8c383 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_variable_type.py @@ -0,0 +1,2180 @@ +# Generates VariableType.h/cpp +# +# **If any changes are being made to the VariableType codegen please also check +# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp +# +# VariableType is a subclass of at::Type that provides the binding code +# necessary to provide a differentiable version of ATen operators. There are a +# number of different things we could mean: +# +# - Given a non-differentiable forward implementation, we might +# directly associate it with a backward implementation to make +# it differentiable. This is the common case. +# +# - Some functions don't need a backwards implementation, because +# backpropagation will never propagate beyond them. There are a +# number of different reasons why this may be the case: +# +# - The function has no differentiable inputs +# - The function's output is not differentiable +# - The function has no data dependency on its input +# +# - Some function don't need a backwards implementation because they +# are implemented as a composition of other (differentiable) ATen +# functions. These are dispatched directly to the Type superclass, +# which will in turn dispatch back to VariableType for its +# differentiable subcomponents. 
+# + +from __future__ import annotations + +import re +from typing import Callable, Sequence + +from torchgen.api import cpp +from torchgen.api.autograd import ( + DifferentiableInput, + dispatch_strategy, + ForwardDerivative, + gen_differentiable_outputs, + is_differentiable, + NativeFunctionWithDifferentiabilityInfo, + SavedAttribute, +) +from torchgen.api.types import ( + ArrayRefCType, + BaseCppType, + BaseCType, + Binding, + DispatcherSignature, + intArrayRefT, + iTensorListRefT, + ListCType, + MutRefCType, + OptionalCType, + scalarT, + SpecialArgName, + stringT, + symIntArrayRefT, + TENSOR_LIST_LIKE_CTYPES, + tensorListT, + tensorT, + TupleCType, + VectorCType, +) +from torchgen.code_template import CodeTemplate +from torchgen.context import ( + native_function_manager, + with_native_function, + with_native_function_and, +) +from torchgen.model import ( + Argument, + BaseType, + ListType, + NativeFunction, + SchemaKind, + SelfArgument, + TensorOptionsArguments, +) +from torchgen.utils import FileManager, mapMaybe + +from .context import with_native_function_with_differentiability_info_and_key +from .gen_inplace_or_view_type import ( + ALL_VIEW_FUNCTIONS, + ASSIGN_RETURN_VALUE, + AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION, + gen_formals, + get_base_name, + get_view_info, + is_tensor_list_type, + is_tensor_type, + METHOD_DEFINITION, + modifies_arguments, + TMP_VAR, + unpack_args, + unpacked_name, + use_derived, + WRAPPER_REGISTRATION, +) +from .gen_trace_type import ( + get_return_value, + MANUAL_AUTOGRAD_AND_TRACER, + MANUAL_BACKEND, + tie_return_values, + type_wrapper_name, +) + + +# We don't set or modify grad_fn on these methods. Generally, they return +# tensors that have requires_grad=False. In-place functions listed here will +# not examine or modify requires_grad or grad_fn. +# NB: this does NOT include overload name +DONT_REQUIRE_DERIVATIVE = { + # These only depend on the input Tensor's shape and device, not the data + "empty_like", + "ones_like", + "full_like", + "zeros_like", + "rand_like", + "randn_like", + "new_empty", + "new_empty_strided", + "new_full", + "new_zeros", + "new_ones", + # These are only implemented on integral types + "__and__", + "__iand__", + "__ilshift__", + "__ior__", + "__irshift__", + "__ixor__", + "__lshift__", + "__or__", + "__rshift__", + "__xor__", + # These work on integral data types, and hence don't require derivative + "_sobol_engine_draw", + "_sobol_engine_ff", + "_sobol_engine_scramble_", + "_sobol_engine_initialize_state_", + # This is an unsafe method that is meant to be out of reach of autograd. + "_coalesced_", + # Quantize functions should not record gradients + "quantize_per_tensor", + "quantize_per_channel", + # Functions that return integers should not have output that require gradients + "argmax", + "argmin", + "argsort", + "searchsorted", + "bucketize", + # Functions that return booleans are not differentiable + "isnan", + "isposinf", + "isneginf", + "isinf", + "signbit", + "isin", + "allclose", + # Functions return none are not differentiable + "record_stream", + # These functions are not differentiable + "logical_and", + "logical_xor", + "logical_not", + "logical_or", + # This function returns nested_tensor shape as a tensor that is non-differentiable + "_nested_tensor_size", + "_nested_tensor_strides", + "_nested_tensor_storage_offsets", +} + +# The C -> R functions at the time of adding this are still being audited and tested +# but will not error out. 
+# C -> C, R -> C functions for which backward is correctly implemented and tested +GRADIENT_IMPLEMENTED_FOR_COMPLEX = { + "fill", + "t", + "t_copy", + "view", + "reshape", + "reshape_as", + "view_as", + "view_copy", + "roll", + "clone", + "block_diag", + "diag_embed", + "repeat", + "expand", + "expand_copy", + "flip", + "fliplr", + "flipud", + "rot90", + "nanmean", + "nansum", + "transpose", + "permute", + "squeeze", + "unsqueeze", + "unsqueeze_copy", + "resize", + "resize_as", + "tril", + "triu", + "chunk", + "zero_", + "eq_", + "ne_", + "add", + "__radd__", + "sum", + "_conj", + "sin", + "cos", + "mul", + "sinc", + "sinh", + "cosh", + "__rmul__", + "sgn", + "asin", + "acos", + "sub", + "div", + "cat", + "view_as_complex", + "index_put", + "neg", + "complex", + "select", + "where", + "as_strided", + "as_strided_copy", + "as_strided_scatter", + "slice", + "constant_pad_nd", + "unbind", + "split", + "split_with_sizes", + "unsafe_split", + "split_with_sizes_backward", + "dot", + "vdot", + "cholesky", + "triangular_solve", + "mm", + "_unsafe_view", + "mv", + "outer", + "bmm", + "diagonal", + "alias", + "atan", + "log", + "log10", + "log1p", + "log2", + "logaddexp", + "logsumexp", + "logcumsumexp", + "reciprocal", + "tan", + "pow", + "rsqrt", + "tanh", + "tanh_backward", + "asinh", + "acosh", + "atanh", + "take", + "fill_", + "exp", + "exp2", + "expm1", + "nonzero", + "mean", + "std_mean", + "var_mean", + "inverse", + "solve", + "linalg_cholesky", + "addcmul", + "addcdiv", + "matrix_exp", + "linalg_matrix_exp", + "_linalg_eigh", + "cholesky_solve", + "linalg_qr", + "_linalg_svd", + "_fft_c2c", + "_fft_r2c", + "linalg_solve", + "sqrt", + "stack", + "gather", + "index_select", + "index_add_", + "linalg_inv", + "linalg_inv_ex", + "baddbmm", + "addbmm", + "addmm", + "addmv", + "addr", + "linalg_householder_product", + "ormqr", + "reflection_pad1d", + "reflection_pad2d", + "reflection_pad3d", + "linalg_cholesky_ex", + "linalg_eig", + "diagonal_copy", + "diagonal_scatter", + "alias_copy", + "select_backward", + "diagonal_backward", + "slice_backward", + "reflection_pad1d_backward", + "reflection_pad2d_backward", + "reflection_pad3d_backward", + "_sparse_sparse_matmul", + "replication_pad1d", + "replication_pad2d", + "replication_pad3d", + "put", + "put_", + "_to_copy", + "replication_pad1d_backward", + "replication_pad2d_backward", + "replication_pad3d_backward", + "diag", + "masked_scatter", + "masked_select", + "index_add", + "index_fill", + "trace", + "polar", + "cumsum", + "rsub", + "eig", + "lerp", + "linalg_vector_norm", + "cumprod", + "prod", + "index_copy", + "lu", + "unfold", + "unfold_backward", + "index", + "masked_fill", + "masked_scatter_backward", + "linalg_cross", + "lu_unpack", + "renorm", + "_conj_physical", + "linalg_lu_factor_ex", + "scatter", + "scatter_add", + "sigmoid", + "sigmoid_backward", + "sparse_mask", + "trapezoid", + "cumulative_trapezoid", + "conj_physical_", + "_neg_view", + "_reshape_alias", + "_reshape_copy", + "_linalg_det", + "lu_solve", + "linalg_solve_triangular", + "linalg_pinv", + "linalg_lstsq", + "unfold_copy", + "col2im", + "im2col", + "cholesky_inverse", + "to_sparse", + "sparse_sampled_addmm", + "linalg_lu", + "pixel_shuffle", + "pixel_unshuffle", + "channel_shuffle", + "linalg_lu_solve", + "_linalg_slogdet", + "_linalg_solve_ex", + "_unsafe_index", + "_unsafe_index_put", + "_unsafe_masked_index", + "_unsafe_masked_index_put_accumulate", +} + +GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = { + "_to_dense", + "_coalesce", + "coalesce", + "values", + 
"_sparse_coo_tensor_with_dims_and_tensors", + "_sparse_addmm", +} + +GRADIENT_IMPLEMENTED_FOR_COMPLEX.update(GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX) + +# Some operators invalidate the grad_accumulator. Let's reset it. +RESET_GRAD_ACCUMULATOR = {"set_", "resize_"} + +# NOTE [ TensorImpl and Storage Pointer Sanity Checks ] +# +# We check the following properties: +# 1) A function should never change the input tensors' underlying c10::TensorImpl +# pointers or c10::Storage pointers, even if it modifies its input tensors (via +# inplace or out-variants) +# If the function does not modify its arguments, we also check the following properties +# pertaining to its output: +# 2) Its TensorImpl has use_count of 1 +# 3) If the function is a view function, it has the same StorageImpl as that of +# the input it is aliased with. Otherwise, its StorageImpl has use_count of 1 +# +# The following code templates implement the checks for this invariant: +SAVE_TENSOR_STORAGE = CodeTemplate( + """\ +auto ${tensor_name}_storage_saved = + ${tensor_name}.has_storage() ? ::std::optional(${tensor_name}.storage()) : ::std::nullopt; +""" +) + + +# If tensor_name == out_tensor_name, used to enforce (1), otherwise used for (2) +ENFORCE_SAME_TENSOR_STORAGE = CodeTemplate( + """\ +if (${tensor_name}_storage_saved.has_value() && + !at::impl::dispatch_mode_enabled() && + !at::impl::tensor_has_dispatch(${tensor_name}) && + !at::impl::tensor_has_dispatch(${out_tensor_name})) + TORCH_INTERNAL_ASSERT(${tensor_name}_storage_saved.value().is_alias_of(${out_tensor_name}.storage())); +""" +) + +SAVE_TENSORLIST_STORAGE = CodeTemplate( + """\ +std::vector<::std::optional> ${tensorlist_name}_storage_saved(${tensorlist_name}.size()); +for (const Tensor& tensor : ${tensorlist_name}) + ${tensorlist_name}_storage_saved.push_back( + tensor.has_storage() ? ::std::optional(tensor.storage()) : ::std::nullopt); +""" +) + +ENFORCE_SAME_TENSORLIST_STORAGE = CodeTemplate( + """\ +for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) { + if (${tensorlist_name}_storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(${tensorlist_name})) + TORCH_INTERNAL_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of(${tensorlist_name}[i].storage())); +} +""" +) + +SAVE_OPTIONALTENSORLIST_STORAGE = CodeTemplate( + """\ +std::vector<::std::optional> ${tensorlist_name}_storage_saved(${tensorlist_name}.size()); +for (const ::std::optional& tensor : ${tensorlist_name}) + ${tensorlist_name}_storage_saved.push_back( + tensor.has_value() && tensor->has_storage() ? 
::std::optional(tensor->storage()) : ::std::nullopt); +""" +) + +ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE = CodeTemplate( + """\ +for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) { + if (${tensorlist_name}_storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(${tensorlist_name})) + TORCH_INTERNAL_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of( + static_cast<::std::optional>(${tensorlist_name}[i])->storage())); +} +""" +) + +SAVE_TENSOR_IMPL = CodeTemplate( + """\ +c10::intrusive_ptr ${tensor_name}_impl_saved; +if (${tensor_name}.defined()) ${tensor_name}_impl_saved = ${tensor_name}.getIntrusivePtr(); +""" +) + +ENFORCE_SAME_TENSOR_IMPL = CodeTemplate( + """\ +if (${tensor_name}_impl_saved && !at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name})) + TORCH_INTERNAL_ASSERT(${tensor_name}_impl_saved == ${tensor_name}.getIntrusivePtr()); +""" +) + +ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE = CodeTemplate( + """\ +if (!at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name})) + TORCH_INTERNAL_ASSERT(${tensor_name}.use_count() <= 1, "function: ${fn_name}"); +""" +) + +ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE = CodeTemplate( + """\ +if (${tensor_name}.has_storage() && !at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name})) { + TORCH_INTERNAL_ASSERT(${tensor_name}.storage().use_count() == 1, "function: ${fn_name}"); +} +""" +) + +SAVE_TENSORLIST_IMPL = CodeTemplate( + """\ +std::vector> ${tensorlist_name}_impl_saved(${tensorlist_name}.size()); +for (size_t i=0; i<${tensorlist_name}.size(); i++) + if (${tensorlist_name}[i].defined()) ${tensorlist_name}_impl_saved[i] = ${tensorlist_name}[i].getIntrusivePtr(); +""" +) + +ENFORCE_SAME_TENSORLIST_IMPL = CodeTemplate( + """\ +for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) { + if (${tensorlist_name}_impl_saved[i] && !at::impl::tensorlist_has_dispatch(${tensorlist_name})) + TORCH_INTERNAL_ASSERT(${tensorlist_name}_impl_saved[i] == ${tensorlist_name}[i].getIntrusivePtr()); +} +""" +) + +SAVE_OPTIONALTENSORLIST_IMPL = CodeTemplate( + """\ +std::vector> ${tensorlist_name}_impl_saved(${tensorlist_name}.size()); +for (size_t i=0; i<${tensorlist_name}.size(); i++) { + ::std::optional t = ${tensorlist_name}[i]; + if (t.has_value() && t->defined()) ${tensorlist_name}_impl_saved[i] = t->getIntrusivePtr(); +} +""" +) + +ENFORCE_SAME_OPTIONALTENSORLIST_IMPL = CodeTemplate( + """\ +for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) { + if (${tensorlist_name}_impl_saved[i]) + TORCH_INTERNAL_ASSERT( + ${tensorlist_name}_impl_saved[i] == static_cast<::std::optional>(${tensorlist_name}[i])->getIntrusivePtr()); +} +""" +) + +# The following list contains functions that we don't enforce the invariant on. 
+DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE = { + # These functions are expected to change impl or storage of input tensors + "set_", + "_cudnn_rnn_flatten_weight", + "_unsafe_masked_index", + "_unsafe_masked_index_put_accumulate", +} +DONT_ENFORCE_TENSOR_IMPL_USE_COUNT = { + # These non-inplace, non-out functions return tensors with use_count > 1 + # Therefore, they MAY (but not necessarily) return one of its inputs as-is + # See https://github.com/pytorch/pytorch/issues/60426 for more information + "_embedding_bag", + "_embedding_bag_forward_only", + "q_per_channel_scales", + "q_per_channel_zero_points", + "lu_unpack", + "_cudnn_rnn_backward", + # The below failed StorageImpl use_count check but we skip tensor_impl check + # just in case + "_cudnn_rnn", + "dequantize_self", + # lift() should never actually be called with a requires_grad=True tensor, + "lift", + "lift_fresh", + "lift_fresh_copy", + # Nested Tensors related functions + # _nested_tensor_size() should never actually be called with requires_grad=True tensor + "_nested_tensor_size", + "_nested_tensor_strides", + "_nested_tensor_storage_offsets", +} + +DONT_ENFORCE_STORAGE_IMPL_USE_COUNT = { + # These non-view functions return tensors with storage use_count != 1 + "_slow_conv2d_forward", + "slow_conv3d_forward", + "channel_shuffle", + # If an input is returned as-is in output, we cannot guarantee its storage_impl + # use count to be 1 either. + *DONT_ENFORCE_TENSOR_IMPL_USE_COUNT, +} +# END CHECKS FOR [ TensorImpl and Storage Pointer Sanity Checks ] + +DECLARE_GRAD_FN = CodeTemplate( + """\ +std::shared_ptr<${op}> grad_fn; +""" +) + +DECLARE_VECTOR_OF_GRAD_FN = CodeTemplate( + """\ +std::vector> grad_fns; +""" +) + +SETUP_ANY_REQUIRES_GRAD = CodeTemplate( + """\ +[[maybe_unused]] auto _any_requires_grad = compute_requires_grad( ${args_with_derivatives} ); +${extra_differentiability_conditions} +""" +) + +SETUP_DERIVATIVE = CodeTemplate( + """\ +if (_any_requires_grad) { + ${setup} +} +""" +) + +SETUP_NONE_REQUIRES_GRAD = CodeTemplate( + """\ +if (compute_requires_grad( ${args_to_check} )) { + throw_error_out_requires_grad("${base_name}"); +} +""" +) + +ASSIGN_GRAD_FN = CodeTemplate( + """\ +grad_fn = std::shared_ptr<${op}>(new ${op}(${op_ctor}), deleteNode); +grad_fn->set_next_edges(collect_next_edges( ${args_with_derivatives} )); +""" +) + +# note(crcrpar): `compute_requires_grad` in the template below is supplied with arguments indexed with `i` +# while the `SETUP_ANY_REQUIRES_GRAD` above takes whole tensors and scalars. +ASSIGN_VECTOR_OF_GRAD_FN = CodeTemplate( + """\ +for (const auto& i : c10::irange( ${irange} )) { + const auto ith_requires_grad = compute_requires_grad(${args_with_derivatives}); + check_inplace(self[i], ith_requires_grad); + grad_fns.push_back([&]() -> std::shared_ptr<${op}> { + if (!ith_requires_grad) { + return nullptr; + } else { + auto grad_fn = std::shared_ptr<${op}>(new ${op}(${op_ctor}), deleteNode); + grad_fn->set_next_edges(collect_next_edges( ${args_with_derivatives} )); + return grad_fn; + } + }()); +} +""" +) + +CALL_REDISPATCH = CodeTemplate( + """\ +at::redispatch::${api_name}(${unpacked_args})""" +) +# If the non-variable operation has return values, we use the `tmp` variable to hold the +# values temporarily and pass the values to the return variables outside of the +# `at::AutoDispatchBelowAutograd` guard block. 
+DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES_JVP_DECOMP = CodeTemplate( + """\ +auto ${tmp_var} = ([&]() { + if (${any_has_forward_grad}) { + static c10::OperatorName full_name("aten::${op_name}", "${op_overload}"); + static ::std::optional opt_op = c10::Dispatcher::singleton().findSchema(full_name); + return impl::run_jit_decomposition_with_args_for_jvp<${return_types}>("${op_name}", *opt_op, ks, ${arg_names}); + } else { + ${guard} + return ${base_type_call}; + } +})(); +""" +) + +DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES = CodeTemplate( + """\ +auto ${tmp_var} = ([&]() { + ${guard} + return ${base_type_call}; +})(); +""" +) + +DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES = CodeTemplate( + """\ +{ + ${guard} + ${base_type_call}; +} +""" +) + +SET_HISTORY = CodeTemplate( + """\ +if (grad_fn) { + ${fn}_history(${differentiable_outputs}, grad_fn); +} +""" +) + +LOOP_OVER_VECTOR_OF_GRAD_FNS = CodeTemplate( + """\ +if (!grad_fns.empty()) { + ${preamble} + for (const auto& i : c10::irange(grad_fns.size())) { + auto grad_fn = grad_fns[i]; + if (grad_fn != nullptr) { + ${statements} + } + } +} +""" +) + +CONDITIONAL = CodeTemplate( + """\ +if (${cond}) { + ${statements} +} +""" +) + +RUN_ONLY_IN_DEBUG_MODE = CodeTemplate( + """\ +#ifndef NDEBUG +${statements} +#endif +""" +) + +FW_DERIVATIVE_CHECK_TEMPLATE = CodeTemplate( + """\ +isFwGradDefined(${req_inp})\ +""" +) +FW_DERIVATIVE_SIZE_CHECK_TEMPLATE = CodeTemplate( + """\ +TORCH_CHECK( + self.size() == ${inp_name}.size(), + "Tensor lists must have the same number of tensors, got ", + self.size(), + " and ", + ${inp_name}.size()); +""" +) + +FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE = CodeTemplate( + """\ +isFwGradDefinedTensorList(${req_inp})\ +""" +) + +FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE = CodeTemplate( + """\ +auto ${inp_name}_t_raw = toNonOptFwGrad(${inp}); +auto ${inp_name}_tensor = toNonOptTensor(${inp}); +auto ${inp_name}_t = (${inp_name}_t_raw.defined() || !${inp_name}_tensor.defined()) + ? ${inp_name}_t_raw : at::${zeros_fn}(${inp_name}_tensor.sym_sizes(), ${inp_name}_tensor.options()); +""" +) + +FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE = CodeTemplate( + """\ +auto ${inp_name}_p = toNonOptPrimal(${inp}); +""" +) + +FW_DERIVATIVE_SETTER_TENSOR = CodeTemplate( + """\ +if (${out_arg}_new_fw_grad_opt.has_value() && ${out_arg}_new_fw_grad_opt.value().defined() && ${out_arg}.defined()) { + // The hardcoded 0 here will need to be updated once we support multiple levels. + ${out_arg}._set_fw_grad(${out_arg}_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ ${is_inplace}); +} +""" +) + +FW_DERIVATIVE_SETTER_TENSOR_FOREACH = CodeTemplate( + """\ +for (const auto& i : c10::irange(${out_arg}_new_fw_grad_opts.size())) { + auto& ${out_arg}_new_fw_grad_opt = ${out_arg}_new_fw_grad_opts[i]; + if (${out_arg}_new_fw_grad_opt.has_value() && ${out_arg}_new_fw_grad_opt.value().defined() && ${out_arg}[i].defined()) { + // The hardcoded 0 here will need to be updated once we support multiple levels. 
+ ${out_arg}[i]._set_fw_grad(${out_arg}_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ ${is_inplace}); + } +} +""" +) + +FW_DERIVATIVE_SETTER_MULTI_OUTPUT = CodeTemplate( + """\ +if (${all_res}_new_fw_grad_opt.has_value() && std::get<${idx}>(${all_res}_new_fw_grad_opt.value()).defined() + && ${out_arg}.defined()) { + ${out_arg}._set_fw_grad(std::get<${idx}>(${all_res}_new_fw_grad_opt.value()), /* level */ 0, /* is_inplace_op */ false); +} +""" +) + +FW_DERIVATIVE_SETTER_TENSOR_LIST = CodeTemplate( + """\ +if (${out_arg}_new_fw_grad_opt.has_value()) { + auto ${out_arg}_new_fw_grad = ${out_arg}_new_fw_grad_opt.value(); + TORCH_INTERNAL_ASSERT(${out_arg}.size() == ${out_arg}_new_fw_grad.size()); + for (const auto i : c10::irange(${out_arg}.size())) { + if (${out_arg}_new_fw_grad[i].defined() && ${out_arg}[i].defined()) { + // The hardcoded 0 here will need to be updated once we support multiple levels. + ${out_arg}[i]._set_fw_grad(${out_arg}_new_fw_grad[i], /* level */ 0, /* is_inplace_op */ ${is_inplace}); + } + } +} +""" +) + +FW_DERIVATIVE_TEMPLATE = CodeTemplate( + """\ +${fw_grad_opt_definition} +if (${requires_fw_grad}) { + ${unpacked_arguments} + ${out_arg}_new_fw_grad_opt = ${formula}; +} +""" +) + +FW_DERIVATIVE_FOREACH_TEMPLATE = CodeTemplate( + """\ +${fw_grad_opt_definition} +for (const auto& i : c10::irange(${vector_of_optional_tensor}.size())) { + if (${any_has_forward_grad_for_current_index}) { + ${unpacked_arguments} + ${vector_of_optional_tensor}[i] = ${formula}; + } +} +""" +) + +FW_DERIVATIVE_FORBID_TEMPLATE = CodeTemplate( + """\ +TORCH_CHECK_NOT_IMPLEMENTED(!(${cond}), "Trying to use forward AD with ${name} that does not support it ${msg}"); +""" +) + +FW_DERIVATIVE_FORBID_LIST_TEMPLATE = CodeTemplate( + """\ +for (const auto& _t: ${arg}) { + TORCH_CHECK_NOT_IMPLEMENTED(!(${cond}), "Trying to use forward AD with ${name} that does not support it ${msg}"); +} +""" +) + + +def gen_variable_type( + out: str, + native_yaml_path: str, + tags_yaml_path: str, + fns_with_diff_infos: list[NativeFunctionWithDifferentiabilityInfo], + template_path: str, + used_keys: set[str], +) -> None: + """VariableType.h and VariableType.cpp body + + This is the at::Type subclass for differentiable tensors. The + implementation of each function dispatches to the base tensor type to + compute the output. The grad_fn is attached to differentiable functions. 
+ """ + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + fm.write( + "VariableType.h", + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/VariableType.h" + }, + ) + + # helper that generates a TORCH_LIBRARY_IMPL macro for each + # dispatch key that appears in derivatives.yaml + def wrapper_registrations(used_keys: set[str]) -> str: + library_impl_macro_list: list[str] = [] + for key in sorted(used_keys): + dispatch_key = key + if key == "Default": + dispatch_key = "Autograd" + library_impl_macro = ( + f"TORCH_LIBRARY_IMPL(aten, {dispatch_key}, m) " + + "{\n" + + "${" + + f"wrapper_registrations_{key}" + + "}\n}" + ) + library_impl_macro_list += [library_impl_macro] + return "\n\n".join(library_impl_macro_list) + + # Generate a new template from VariableType.cpp which replaces ${wrapper_registrations} + # with per key TORCH_LIBRARY_IMPL macros for each key that appears in derivatives.yaml + fm1 = FileManager( + install_dir=out + "/templates", template_dir=template_path, dry_run=False + ) + fm1.write( + "VariableType.cpp", + lambda: { + "type_derived_method_definitions": "\n\n".join( + [ + "${" + f"type_derived_method_definitions_{key}" + "}" + for key in sorted(used_keys) + ] + ), + "wrapper_registrations": wrapper_registrations(used_keys), + }, + ) + + # Generate final VariableType_*.cpp files from the generated template + fm2 = FileManager(install_dir=out, template_dir=out + "/templates", dry_run=False) + + sharded_keys = set( + [f"type_derived_method_definitions_{key}" for key in sorted(used_keys)] + + [f"wrapper_registrations_{key}" for key in sorted(used_keys)] + ) + # NOTE: see Note [Sharded File] at the top of the VariableType.cpp + # template regarding sharding of the generated files. + fm2.write_sharded( + "VariableType.cpp", + [fn for fn in fns_with_diff_infos if use_derived(fn)], + key_fn=lambda fn: cpp.name(fn.func.func), + base_env={ + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/VariableType.cpp", + }, + env_callable=gen_variable_type_func, + num_shards=5, + sharded_keys=sharded_keys, + ) + + +@with_native_function_and +def gen_wrapper_registration(f: NativeFunction, key: str = "Default") -> str: + return WRAPPER_REGISTRATION.substitute( + unqual_operator_name_with_overload=f.func.name, + type_wrapper_name=type_wrapper_name(f, key), + class_type="VariableType", + ) + + +def gen_variable_type_func( + fn: NativeFunctionWithDifferentiabilityInfo, +) -> dict[str, list[str]]: + f = fn.func + result = {} + with native_function_manager(f): + name = cpp.name(f.func) + formals = gen_formals(f) + + if ( + fn.info is None + and str(f.func.name.name) not in RESET_GRAD_ACCUMULATOR + and get_base_name(f) not in DONT_REQUIRE_DERIVATIVE + and len(gen_differentiable_outputs(fn)) > 0 + and cpp.name(f.func) not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE + and type_wrapper_name(f) not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT + and type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT + ): + # NOTE: [ Registering AutogradNotImplemented boxed kernel ] + # + # When there is no derivatives.yaml entry, we register a generic boxed + # NotImplemented kernel to set grad_fn to be NotImplemented, so that forward + # proceeds as usual but an error is properly produced on backward. 
+ # TODO: it would be nice to not have these special cases + # + # There are several cases where still let codegen handle it: + # 1) ops that need to reset grad accumulator (we let codegen handle this case + # because) the list is (currently) only accessible in Python. + # 2) User explicitly specifies DONT_REQUIRE_DERIVATIVE. This basically makes + # autograd a fallthrough with NDEBUG checks. This can be useful for when all + # outputs are integral. + # 3) When there are no differentiable outputs. This is similar to (2). + # 4) There are certain ops where we skip certain NDEBUG checks. this is similar + # to (1). + type_definition = "" + wrapper_registration = AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION.substitute( + unqual_operator_name_with_overload=f.func.name + ) + result["type_derived_method_definitions_Default"] = [type_definition] + result["wrapper_registrations_Default"] = [wrapper_registration] + else: + if not fn.info: + key = "Default" + type_definition = METHOD_DEFINITION.substitute( + return_type=cpp.returns_type( + f.func.returns, symint=True + ).cpp_type(), + type_wrapper_name=type_wrapper_name(f, key), + type_definition_body=emit_body(fn, key), + formals=formals, + ) + wrapper_registration = gen_wrapper_registration(f, key) + result[f"type_derived_method_definitions_{key}"] = [type_definition] + result[f"wrapper_registrations_{key}"] = [wrapper_registration] + else: + for key in fn.info.keys(): + type_definition = METHOD_DEFINITION.substitute( + return_type=cpp.returns_type( + f.func.returns, symint=True + ).cpp_type(), + type_wrapper_name=type_wrapper_name(f, key), + type_definition_body=emit_body(fn, key), + formals=formals, + ) + wrapper_registration = gen_wrapper_registration(f, key) + result[f"type_derived_method_definitions_{key}"] = [type_definition] + result[f"wrapper_registrations_{key}"] = [wrapper_registration] + # See Note [Manual Backend kernels] + assert (name in MANUAL_BACKEND) == f.manual_kernel_registration + # If you want to register a kernel to Autograd, you must make the op abstract. + # In other words, this op must have dispatch section in native_functions.yaml. + if name in MANUAL_AUTOGRAD_AND_TRACER or ( + fn.info and any(info.has_derivatives for info in fn.info.values()) + ): + msg = ( + f"There's a formula for {name}(or its functional variant) in derivatives.yaml. " + f"It's required to add a dispatch section for it with explicit supported backends e.g CPU/CUDA " + f"or CompositeExplicitAutograd in native_functions.yaml. Please see " + f"https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword " + f"for instructions to choose the right dispatch keyword." + ) + assert f.is_abstract, msg + + return result + + +_foreach_ops_without_differentiability_info = { + # No reference backward available due to the lack of `{maximum, minimum}(tensor, scalar)`. + ("_foreach_maximum", "Scalar"), + ("_foreach_maximum", "ScalarList"), + ("_foreach_minimum", "Scalar"), + ("_foreach_minimum", "ScalarList"), + # No reference backward available as addcdiv/addcmul don't support Tensor as scaling factor. + ("_foreach_addcdiv", "Tensor"), + ("_foreach_addcmul", "Tensor"), + ("_foreach_copy", ""), +} + +_foreach_ops_with_different_arity = { + # These ops lack `alpha` of scaling factor to applied to the right hand side argument. 
+ ("_foreach_add", "Scalar"), + ("_foreach_add", "ScalarList"), + ("_foreach_sub", "Scalar"), + ("_foreach_sub", "ScalarList"), +} + + +@with_native_function_with_differentiability_info_and_key +def emit_body( + fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default" +) -> list[str]: + assert dispatch_strategy(fn) == "use_derived" + f = fn.func + info = fn.info[key] if fn.info else None + fw_derivatives = fn.fw_derivatives.get(key, []) if fn.fw_derivatives else [] + + name = cpp.name(f.func) + inplace = f.func.kind() == SchemaKind.inplace + is_out_fn = f.func.kind() == SchemaKind.out + returns_void = len(f.func.returns) == 0 + base_name = get_base_name(f) + view_info = get_view_info(f) + + is_foreach = name.startswith("_foreach") + is_inplace_foreach = is_foreach and inplace + if is_inplace_foreach: + inplace_foreacharg2refarg: dict[Argument, Argument] = {} + refargname2inplace_foreacharg: dict[str, Argument] = {} + base_name_and_overload_name = (f.func.name.name.base, f.func.name.overload_name) + if info is None: + assert ( + base_name_and_overload_name + in _foreach_ops_without_differentiability_info + ), f"{'.'.join(base_name_and_overload_name)} should have a differentiability info" + else: + assert ( + len(f.func.arguments.flat_non_out) + == len(info.func.func.arguments.flat_non_out) + ) or (base_name_and_overload_name in _foreach_ops_with_different_arity), ( + f"{'.'.join(base_name_and_overload_name)} has {len(f.func.arguments.flat_non_out)} args " + f"but the reference has {len(info.func.func.arguments.flat_non_out)}" + ) + for foreach_arg, ref_arg in zip( + f.func.arguments.flat_non_out, info.func.func.arguments.flat_non_out + ): + foreach_arg_type = foreach_arg.type + if isinstance(foreach_arg_type, ListType): + foreach_arg_type = foreach_arg_type.elem + assert foreach_arg_type == ref_arg.type + inplace_foreacharg2refarg[foreach_arg] = ref_arg + refargname2inplace_foreacharg[ref_arg.name] = foreach_arg + + def gen_differentiable_input( + arg: Argument | SelfArgument | TensorOptionsArguments, + ) -> DifferentiableInput | None: + if isinstance(arg, TensorOptionsArguments): + return None + a: Argument = arg.argument if isinstance(arg, SelfArgument) else arg + + # TODO: `cpp_type` is only to keep it byte-for-byte compatible with the old codegen, should remove. + # NB: This is not a clone of cpp.argument() - TensorOptionsArguments / faithful / binds are + # not handled properly as they are irrelevant for this codegen. + cpp_type = cpp.argument_type(a, binds=a.name, symint=True).cpp_type() + + if not is_differentiable(a.name, a.type, info): + return None + return DifferentiableInput( + name=a.name, + type=a.type, + cpp_type=cpp_type, + ) + + @with_native_function + def gen_differentiable_inputs(f: NativeFunction) -> list[DifferentiableInput]: + arguments = list(f.func.arguments.non_out) + if is_inplace_foreach and info is not None: + for i, arg in enumerate(f.func.arguments.flat_non_out): + if arg in inplace_foreacharg2refarg: + # note(crcrpar): From what I understand, what matters is only the name. + # Thus originally I only replace argument only when the names are different. + # TODO(crcrpar): Make it simpler. 
+ mapped_arg = inplace_foreacharg2refarg[arg] + arguments[i] = Argument( + mapped_arg.name, + mapped_arg.type, + mapped_arg.default, + mapped_arg.annotation, + ) + return list(mapMaybe(gen_differentiable_input, arguments)) + + def find_args_with_derivatives( + differentiable_inputs: list[DifferentiableInput], + ) -> list[DifferentiableInput]: + """Find arguments that have derivative definitions""" + if info is None or not info.has_derivatives: + return differentiable_inputs + names = {name for d in info.derivatives for name in d.var_names} + differentiable = [arg for arg in differentiable_inputs if arg.name in names] + if len(differentiable) != len(names): + missing = names - {arg.name for arg in differentiable} + raise RuntimeError( + f"Missing arguments for derivatives: {missing} in {info.name}" + ) + return differentiable + + differentiable_inputs = gen_differentiable_inputs(f) + args_with_derivatives = find_args_with_derivatives(differentiable_inputs) + differentiable_outputs = gen_differentiable_outputs(fn, key) + + undifferentiable = (base_name in DONT_REQUIRE_DERIVATIVE) or ( + name in DONT_REQUIRE_DERIVATIVE + ) + + requires_derivative = ( + (not undifferentiable) + and (len(differentiable_inputs) > 0) + and ( + (len(differentiable_outputs) > 0) + # note(crcrpar): In-place foreach functions are a void function. + or is_inplace_foreach + ) + ) + + if ( + info is not None + and info.has_derivatives + and not requires_derivative + # out= ops are allowed to have zero returns which cause requires_derivative to be False + # we shouldn't error out though (out= ops for autograd just redispatch) + and len(f.func.returns) > 0 + ): + raise RuntimeError( + f"ERROR: derivative ignored for {name} -- specified an autograd function without derivative" + ) + + # note(crcrpar): In-place foreach functions do not support forward AD + if requires_derivative and len(fw_derivatives) > 0 and not is_inplace_foreach: + assert sum(len(derivative.var_names) for derivative in fw_derivatives) == len( + differentiable_outputs + ), ( + "Expected the number of forward derivatives implemented to match the " + "number of differentiable outputs. NB: This only applies when at least " + "one forward derivative is implemented. Not implementing any forward " + "derivatives is also okay, and we would require inputs to the op to " + "not have associated tangents in that case." + ) + + try_jit_decomposition = ( + requires_derivative + and len(fw_derivatives) == 0 + and (not modifies_arguments(f)) + and (not returns_void) + ) + + def emit_save_inputs() -> list[str]: + setup: list[str] = [] + if info is None or not info.has_derivatives: + return setup + + has_tensorlist_arg = any( + is_tensor_list_type(arg.type) for arg in args_with_derivatives + ) + + # We don't want to save tensors if we know that they will never be used + # when computing the derivative, so we add guards to those statements + def guard_for(arg: SavedAttribute) -> str | None: + assert info is not None + + # It's hard to determine the edge offset if we have TensorLists + # NOTE(crcrpar): in-place foreach functions' arguments include tensorlist + # but their derivatives don't use it, so let them bypass this check. + if has_tensorlist_arg and (not is_inplace_foreach): + return None + + # Empirical evaluation of the cases where we insert those guards in + # backward show that they are somewhat useless. E.g. there's no need + # to guard on some values captured from forward, because they had to + # require_grad if the backward function even gets executed. 
I don't + # have any good ideas for detecting those cases, so I simply disabled the + # checks. + if "backward" in info.name: + return None + + # If there's a single derivative we could compute, we already have + # a requires_grad check that is sufficient + if len(args_with_derivatives) <= 1: + return None + + # We really only care about trimming down the amount of tensors we save + if arg.nctype.type != BaseCType(tensorT): + return None + + # We want to emit simple guards, so we only allow that if checking one + # input is enough to determine whether we need that value + used_in = [d for d in info.derivatives if arg in d.saved_inputs] + assert len(used_in) > 0 + if len(used_in) != 1: + return None + derivative = used_in[0] + + # Case with multioutput formulas + # TODO: process all derivative formulas!!! + if len(derivative.var_names) != 1: + wrap_opt_if_start = derivative.formula.find( + f"wrap_opt_if({arg.nctype.name}" + ) + if wrap_opt_if_start == -1: + return None + + wrap_opt_if_match = re.match( + rf"wrap_opt_if\({arg.nctype.name},(.*?)\)", + derivative.formula[wrap_opt_if_start:], + ) + assert wrap_opt_if_match is not None + + # Condition is between 'wrap_opt_if(var_name,' and ')'. + condition_slice = slice(len(rf"wrap_opt_if\({arg.nctype.name},"), -1) + wrap_opt_if_condition = wrap_opt_if_match.group(0)[ + condition_slice + ].strip() + # replace 'grad_input_mask[num]' with 'grad_fn->should_compute_output(num)' + wrap_opt_if_condition = re.sub( + r"grad_input_mask\[(\d+)\]", + r"grad_fn->should_compute_output(\1)", + wrap_opt_if_condition, + ) + return f"{wrap_opt_if_condition}" + + # Figure out the offset of the edge that uses this variable + derivative_var_name = derivative.var_names[0] + for edge_off, a in enumerate(args_with_derivatives): + if a.name == derivative_var_name: + break + else: + raise AssertionError + return f"grad_fn->should_compute_output({edge_off})" + + if is_inplace_foreach: + save_input_stmts = save_variables(info.all_saved_inputs, False, guard_for) + if save_input_stmts: + setup.append( + LOOP_OVER_VECTOR_OF_GRAD_FNS.substitute( + preamble="", statements=save_input_stmts + ) + ) + else: + setup.extend(save_variables(info.all_saved_inputs, False, guard_for)) + for arg in args_with_derivatives: + if is_tensor_list_type(arg.type): + setup.append(f"grad_fn->{arg.name}_size_ = {arg.name}.size();") + return setup + + def setup_derivative(differentiable_inputs: list[DifferentiableInput]) -> list[str]: + body: list[str] = [] + if is_out_fn: + # For out functions, ensure that no input or output requires grad + body.append(DECLARE_GRAD_FN.substitute(op="Node")) + body.append( + SETUP_NONE_REQUIRES_GRAD.substitute( + base_name=base_name, + args_to_check=[arg.name for arg in differentiable_inputs], + ) + ) + body.append( + SETUP_NONE_REQUIRES_GRAD.substitute( + base_name=base_name, + args_to_check=[arg.name for arg in differentiable_outputs], + ) + ) + return body + + op = info.op if info is not None and info.has_derivatives else "NotImplemented" + setup = [] + if not is_inplace_foreach: + setup.extend( + ASSIGN_GRAD_FN.substitute( + op=op, + op_ctor="" + if info is not None and info.has_derivatives + else f'"{cpp.name(f.func)}"', + args_with_derivatives=[arg.name for arg in args_with_derivatives], + ).split("\n") + ) + else: + # note(crcrpar): Assuming in-place foreach function's self_arg is always TensorList. 
+ list_like_arg = "self" + args = [arg.name for arg in args_with_derivatives] + for i, arg in enumerate(args): + if is_inplace_foreach and info is not None: + if arg in refargname2inplace_foreacharg: + foreach_arg = refargname2inplace_foreacharg[arg] + args[i] = foreach_arg.name + ( + "[i]" if isinstance(foreach_arg.type, ListType) else "" + ) + else: + if arg == list_like_arg: + args[i] = arg + "[i]" + setup.extend( + ASSIGN_VECTOR_OF_GRAD_FN.substitute( + op=op, + op_ctor="" + if info is not None and info.has_derivatives + else f'"{cpp.name(f.func)}"', + args_with_derivatives=args, + irange=f"{list_like_arg}.size()", + ).split("\n") + ) + setup.extend(emit_save_inputs()) + + body.extend( + emit_check_no_requires_grad(differentiable_inputs, args_with_derivatives) + ) + declare_grad_fn_template = ( + DECLARE_GRAD_FN if not is_inplace_foreach else DECLARE_VECTOR_OF_GRAD_FN + ) + body.append(declare_grad_fn_template.substitute(op=op)) + body.append(SETUP_DERIVATIVE.substitute(setup=setup)) + return body + + def emit_check_if_in_complex_autograd_allowlist() -> list[str]: + body: list[str] = [] + if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX: + return body + for arg in differentiable_outputs: + name = arg.name + # TODO: should be `arg.type.is_tensor_like()`? + if arg.cpp_type == "at::Tensor" or arg.cpp_type in TENSOR_LIST_LIKE_CTYPES: + body.append(f'throw_error_for_complex_autograd({name}, "{base_name}");') + return body + + def emit_check_no_requires_grad( + tensor_args: list[DifferentiableInput], + args_with_derivatives: list[DifferentiableInput], + ) -> list[str]: + """Checks that arguments without derivatives don't require grad""" + body: list[str] = [] + for arg in tensor_args: + if arg in args_with_derivatives: + continue + arg_name = arg.name + if info and arg_name in info.non_differentiable_arg_names: + continue + if arg_name == "output": + # Double-backwards definitions sometimes take in 'input' and + # 'output', but only define the derivative for input. 
+ continue + body.append(f'check_no_requires_grad({arg_name}, "{arg_name}", "{name}");') + return body + + def emit_original_self_definition() -> list[str]: + body: list[str] = [] + if inplace: + if is_inplace_foreach: + body.append( + "std::vector<::std::optional<at::Tensor>> original_selfs(self.size());" + ) + else: + body.append("::std::optional<at::Tensor> original_self;") + + all_forward_grad_cond = [] + for derivative in fw_derivatives: + if derivative.required_original_self_value: + all_forward_grad_cond.append( + get_any_has_forward_grad_name(derivative.var_names) + ) + + if all_forward_grad_cond: + if not is_inplace_foreach: + body.append(f'if ({" || ".join(all_forward_grad_cond)}) {{') + body.append(" original_self = self.clone();") + body.append("}") + else: + current_all_forward_grad_cond = [ + f"{cond}[i]" for cond in all_forward_grad_cond + ] + body.append("for (const auto& i : c10::irange(self.size())) {") + body.append( + f" if ({' || '.join(current_all_forward_grad_cond)}) {{" + ) + body.append(" original_selfs[i] = self[i].clone();") + body.append(" }") + body.append("}") + + return body + + def save_variables( + saved_variables: Sequence[SavedAttribute], + is_output: bool, + guard_for: Callable[[SavedAttribute], str | None] = lambda name: None, + ) -> Sequence[str]: + # assign the saved variables to the generated grad_fn + stmts: list[str] = [] + for arg in sorted(saved_variables, key=lambda sa: str(sa.nctype.name)): + name = ( + arg.nctype.name.name + if isinstance(arg.nctype.name, SpecialArgName) + else arg.nctype.name + ) + foreacharg: Argument | None = None + is_foreacharg_list_type: bool = False + type = arg.nctype.type + expr = arg.expr + stmts_prepend = None + if is_inplace_foreach and info is not None: + # todo(crcrpar): See if we can add some check e.g. `assert foreacharg is not None`. + # for now the example assert would fail. + name_to_query = name.split("_scalar_type")[0] + if name_to_query in refargname2inplace_foreacharg: + foreacharg = refargname2inplace_foreacharg[name_to_query] + is_foreacharg_list_type = isinstance(foreacharg.type, ListType) + if foreacharg is not None: + name_in_expr = ( + f"{foreacharg.name}{'[i]' if is_foreacharg_list_type else ''}" + ) + src_name = name + if "_scalar_type" in src_name: + split_src_name = src_name.split("_scalar_type") + assert len(split_src_name) == 2 + src_name = split_src_name[0] + expr = expr.replace(src_name, name_in_expr) + if ( + type == BaseCType(tensorT) + or type == OptionalCType(BaseCType(tensorT)) + or type == MutRefCType(OptionalCType(BaseCType(tensorT))) + or (is_output and type == BaseCType(scalarT)) + ): + # note(crcrpar): Here `expr` is generated from scratch, `arg.expr` is ignored.
+ var = name + name += "_" + if var == "self" and inplace: + original_self_var = ( + "original_self" + if not is_inplace_foreach + else "original_selfs[i]" + ) + self_var = var if not is_inplace_foreach else var + "[i]" + stmts_prepend = f"if (!{original_self_var}.has_value()) {original_self_var} = {self_var}.clone()" + var = f"{original_self_var}.value()" + assert not is_output + if inplace and is_output: + assert name == "result_" + var = ( + "self[i]" + if is_inplace_foreach or is_foreacharg_list_type + else "self" + ) + is_inplace_view = f"{var}.is_view()" + expr = f"SavedVariable({var}, {str(is_output).lower()}, {is_inplace_view})" + else: + expr = f"SavedVariable({var}, {str(is_output).lower()})" + if foreacharg is not None and "original_selfs" not in expr: + expr = expr.replace(src_name, name_in_expr) + elif ( + type == BaseCType(tensorListT) + or type == ListCType(OptionalCType(BaseCType(tensorT))) + or type == BaseCType(iTensorListRefT) + or type == VectorCType(BaseCType(tensorT)) + ): + # See Note [nuanced return type of out-of-place foreach functions] + if type == VectorCType(BaseCType(tensorT)): + assert is_foreach and is_output + expr = f"make_saved_variable_list({name}, {str(is_foreach and is_output).lower()})" + name += "_" + elif type == BaseCType(intArrayRefT): + expr = expr + ".vec()" + elif type == BaseCType(symIntArrayRefT): + expr = expr + ".vec()" + elif type == BaseCType(stringT): + expr = f"std::string({expr})" + elif type == OptionalCType(BaseCType(stringT)): + expr = f"{expr}.has_value() ? ::std::optional<std::string>(std::string({expr}.value())) : ::std::nullopt" + elif type == ArrayRefCType( + elem=BaseCType(type=BaseCppType(ns="at", name="Scalar")) + ): + expr = expr + ".vec()" + + guard = guard_for(arg) + if guard is None: + if stmts_prepend: + stmts.append(f"{stmts_prepend};") + stmts.append(f"grad_fn->{name} = {expr};") + else: + stmts.append(f"if ({guard}) {{") + if stmts_prepend: + stmts.append(f" {stmts_prepend};") + stmts.append(f" grad_fn->{name} = {expr};") + stmts.append("}") + return stmts + + # Generates a Dispatcher::redispatch() call into the dispatcher. We do this mainly for performance reasons: + # - Pre-compute the full DispatchKeySet. This saves the dispatcher from having to read from TLS. + # - redispatch() avoids a redundant call to RecordFunction, which was already called right before + # we entered this autograd kernel. + def emit_dispatch_call( + f: NativeFunction, input_base: str, unpacked_args: Sequence[str] + ) -> str: + """Dispatch call via function in a namespace or method on Tensor.""" + dispatcher_sig = DispatcherSignature.from_schema(f.func) + dispatcher_exprs = dispatcher_sig.exprs() + + # code-generated autograd kernels plumb and recompute dispatch keys directly through the kernel for performance. + # Ops also always have a function variant of the redispatch API. + # See Note [Plumbing Keys Through The Dispatcher] for details.
+ dispatch_key_set = "ks & c10::after_autograd_keyset" + call = CALL_REDISPATCH.substitute( + api_name=cpp.name( + f.func, + faithful_name_for_out_overloads=True, + symint_overload=f.func.has_symint(), + ), + unpacked_args=[dispatch_key_set] + list(unpacked_args), + ) + return call + + def wrap_output( + f: NativeFunction, unpacked_bindings: list[Binding], var: str + ) -> str: + call = "" + rhs_value: str | None = None + if not any(r.type.is_tensor_like() for r in f.func.returns): + rhs_value = var + else: + rhs_value = f"std::move({var})" + assert rhs_value is not None + call += ASSIGN_RETURN_VALUE.substitute( + return_values=tie_return_values(f), rhs_value=rhs_value + ) + return call + + def check_tensorimpl_and_storage( + call: str, unpacked_bindings: list[Binding] + ) -> str: + # See NOTE [ TensorImpl and Storage Pointer Sanity Checks ] + stmts_before_call: list[str] = [] + stmts_after_call: list[str] = [] + + if cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE: + return call + + # Check properties of inputs (enforce (1)) + for unpacked_binding in unpacked_bindings: + arg = unpacked_binding.name + noref_cpp_type = unpacked_binding.nctype.type.remove_const_ref() + if noref_cpp_type == BaseCType(tensorListT) or noref_cpp_type == BaseCType( + iTensorListRefT + ): + stmts_before_call += [ + SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg), + SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg), + ] + stmts_after_call += [ + ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg), + ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg), + ] + elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))): + stmts_before_call += [ + SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg), + SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg), + ] + stmts_after_call += [ + ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute( + tensorlist_name=arg + ), + ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute( + tensorlist_name=arg + ), + ] + elif noref_cpp_type == BaseCType(tensorT): + stmts_before_call += [ + SAVE_TENSOR_STORAGE.substitute(tensor_name=arg), + SAVE_TENSOR_IMPL.substitute(tensor_name=arg), + ] + stmts_after_call += [ + ENFORCE_SAME_TENSOR_STORAGE.substitute( + tensor_name=arg, out_tensor_name=arg + ), + ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg), + ] + + assert (stmts_before_call and stmts_after_call) or ( + not stmts_before_call and not stmts_after_call + ) + + # Check properties of outputs (enforce (2), (3)) + if f.func.kind() not in (SchemaKind.inplace, SchemaKind.out): + base_name = f.func.name.name.base # TODO: should be str(f.func.name.name)? 
+ aliased_arg_name = ALL_VIEW_FUNCTIONS.get(base_name, None) + if aliased_arg_name is not None: + aliased_arg_name = unpacked_name(aliased_arg_name) + for i, (ret, ret_name) in enumerate( + zip(f.func.returns, cpp.return_names(f)) + ): + noref_cpp_type = cpp.return_type(ret, symint=True).remove_const_ref() + if noref_cpp_type == BaseCType(tensorT): + if aliased_arg_name is not None: + assert ( + i == 0 + ), "Expect non-CompositeImplicitAutograd view function {base} to return single output" + stmts_after_call += [ + ENFORCE_SAME_TENSOR_STORAGE.substitute( + tensor_name=aliased_arg_name, out_tensor_name=ret_name + ) + ] + else: + if ( + type_wrapper_name(f) + not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT + ): + stmts_after_call += [ + ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE.substitute( + tensor_name=ret_name, fn_name=type_wrapper_name(f) + ) + ] + + if type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT: + stmts_after_call += [ + ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE.substitute( + tensor_name=ret_name, fn_name=type_wrapper_name(f) + ) + ] + + # Currently we don't have any functions that return the following types, but + # we should update the checks once we do + elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))): + raise AssertionError( + f"Please add use_count checks for {noref_cpp_type}" + ) + elif noref_cpp_type == BaseCType(tensorListT): + raise AssertionError( + f"Please add use_count checks for {noref_cpp_type}" + ) + + if stmts_before_call and stmts_after_call: + call = ( + RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_before_call) + + call + + RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_after_call) + ) + return call + + def emit_call( + f: NativeFunction, unpacked_bindings: list[Binding], try_jit_decomposition: bool + ) -> str: + # We only care about adding `at::AutoDispatchBelowAutograd` guard for non-variable dispatch + # (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure + # the baseType operations still dispatch to non-Variable type, even if the arguments passed + # in are now Variables. + # See NOTE [ Treating Variables as non-Variables in type dispatch ] for details. 
+ unpacked_args = [b.name for b in unpacked_bindings] + base_type_call = emit_dispatch_call(f, "self_", unpacked_args) + + if get_view_info(f) is not None or modifies_arguments(f): + guard = "at::AutoDispatchBelowAutograd guard;" + else: + guard = "at::AutoDispatchBelowADInplaceOrView guard;" + + any_has_forward_grad = ( + get_any_has_fw_grad_cond(derivative=None) + if requires_derivative + else "false" + ) + return_types = ", ".join( + [cpp.return_type(a, symint=True).cpp_type() for a in f.func.returns] + ) + if len(f.func.returns) > 1: + return_types = f"std::tuple<{return_types}>" + + arg_names = [ + a.name + for a in cpp.arguments( + f.func.arguments, + faithful=True, + symint=True, + method=False, + cpp_no_default_args=set(), + ) + ] + + if not modifies_arguments(f) and not returns_void: + if try_jit_decomposition: + call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES_JVP_DECOMP.substitute( + base_type_call=base_type_call, + tmp_var=TMP_VAR, + guard=guard, + any_has_forward_grad=any_has_forward_grad, + op_name=cpp.name(f.func), + op_overload=f.func.name.overload_name, + return_types=return_types, + arg_names=arg_names, + ) + else: + call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES.substitute( + base_type_call=base_type_call, + tmp_var=TMP_VAR, + guard=guard, + ) + + call += wrap_output(f, unpacked_bindings, TMP_VAR) + else: + assert not try_jit_decomposition + call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute( + base_type_call=base_type_call, guard=guard + ) + call = check_tensorimpl_and_storage(call, unpacked_bindings) + return call + + def emit_history() -> str: + fn = "rebase" if modifies_arguments(f) and view_info is None else "set" + output_names = [r.name for r in differentiable_outputs] + # TODO: flatten allocates a std::vector, which could be expensive + outs = CodeTemplate("flatten_tensor_args( ${outs} )").substitute( + outs=output_names if not is_inplace_foreach else "self" + ) + if not is_inplace_foreach: + return SET_HISTORY.substitute(fn=fn, differentiable_outputs=outs) + else: + return LOOP_OVER_VECTOR_OF_GRAD_FNS.substitute( + preamble=( + f"auto differentiable_outputs = {outs};\n" + f"TORCH_INTERNAL_ASSERT(differentiable_outputs.size() == grad_fns.size());" + ), + statements=f"{fn}_history(differentiable_outputs[i], grad_fns[i]);", + ) + + def emit_save_outputs() -> str: + if is_out_fn: + # out functions don't currently support differentiation + return "" + if info is not None and info.has_derivatives: + stmts = save_variables(info.all_saved_outputs, True) + if len(stmts) == 0: + return "" + if not is_inplace_foreach: + return CONDITIONAL.substitute(cond="grad_fn", statements=stmts) + else: + return LOOP_OVER_VECTOR_OF_GRAD_FNS.substitute( + preamble="", statements=stmts + ) + return "" + + def emit_any_requires_grad() -> list[str]: + extra_condition = "" + if info and info.output_differentiability_conditions: + assert len(info.output_differentiability_conditions) == 1 + extra_condition = f"_any_requires_grad &= ({info.output_differentiability_conditions[0]});" + names_of_args_with_derivatives = [arg.name for arg in args_with_derivatives] + if is_inplace_foreach and info is not None: + for i, arg in enumerate(names_of_args_with_derivatives): + for f_arg, r_arg in inplace_foreacharg2refarg.items(): + if arg == r_arg.name: + names_of_args_with_derivatives[i] = f_arg.name + return [ + SETUP_ANY_REQUIRES_GRAD.substitute( + args_with_derivatives=names_of_args_with_derivatives, + extra_differentiability_conditions=extra_condition, + ) + ] + + def 
get_any_has_forward_grad_name(var_names: tuple[str, ...]) -> str: + if len(var_names) == 1: + return f"_any_has_forward_grad_{var_names[0]}" + else: + return f'_any_has_forward_grad_{"_".join(var_names)}' + + def emit_any_has_forward_grad() -> list[str]: + content: list[str] = [] + if not is_foreach: + for derivative in fw_derivatives: + requires_fw_grad = get_any_has_fw_grad_cond(derivative=derivative) + if info and info.output_differentiability_conditions: + assert len(info.output_differentiability_conditions) == 1 + requires_fw_grad = f"({info.output_differentiability_conditions[0]}) && {requires_fw_grad}" + content.append( + f"[[maybe_unused]] auto {get_any_has_forward_grad_name(derivative.var_names)} = {requires_fw_grad};" + ) + else: + for derivative in fw_derivatives: + bool_vector_name = get_any_has_forward_grad_name(derivative.var_names) + cur_derivative_conditions = [] + for inp in differentiable_inputs: + if derivative.required_inputs_fw_grad is None: + continue + if inp.name not in derivative.required_inputs_fw_grad: + continue + inp_name = ( + inp.name + if not inplace + else refargname2inplace_foreacharg[inp.name].name + ) + inp_type = ( + inp.type + if not inplace + else refargname2inplace_foreacharg[inp.name].type + ) + is_list_type = is_tensor_list_type(inp_type) + if is_list_type: + if inp_name != "self": + content.append( + FW_DERIVATIVE_SIZE_CHECK_TEMPLATE.substitute( + inp_name=inp_name + ) + ) + cur_derivative_conditions.append( + FW_DERIVATIVE_CHECK_TEMPLATE.substitute( + req_inp=inp_name + "[i]" + ) + ) + else: + cur_derivative_conditions.append( + FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp_name) + ) + + content.append(f"std::vector<bool> {bool_vector_name}(self.size());") + content.append("for (const auto& i : c10::irange(self.size())) {") + content.append( + f" {bool_vector_name}[i] = {' || '.join(cur_derivative_conditions)};" + ) + content.append("}") + return content + + def emit_check_inplace() -> list[str]: + if not inplace: + return [] + return [ + f"check_inplace({arg.name}, _any_requires_grad);" + for arg in differentiable_outputs + ] + + def emit_fw_derivatives() -> list[str]: + content: list[str] = [] + fw_grad_setters: list[str] = [] + for derivative in fw_derivatives: + res = derivative.var_names + if f.func.name.name.inplace: + assert ( + len(res) == 1 + ), "Expected number of outputs to be 1 if function is inplace" + # TODO update this when inplace namings are unified + res = ("self",) + + assert derivative.required_inputs_fw_grad is not None + + unpacked_arguments = "" + for inp in differentiable_inputs: + inp_name = inp.name + is_input_tensorlist = is_foreach and is_tensor_list_type( + inp.type + if not inplace + else refargname2inplace_foreacharg[inp.name].type + ) + input_suffix = "[i]" if is_input_tensorlist else "" + if is_inplace_foreach: + if inp.name in refargname2inplace_foreacharg: + inp_name = refargname2inplace_foreacharg[inp.name].name + zeros_fn = ( + "zeros_symint" + if inplace and inp.name == "self" + else "_efficientzerotensor_symint" + ) + if inp.name in derivative.required_inputs_fw_grad: + unpacked_arguments += ( + FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute( + inp_name=inp.name, + inp=inp_name + input_suffix, + zeros_fn=zeros_fn, + ) + ) + if inp.name in (derivative.required_inputs_primal or []): + unpacked_arguments += ( + FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute( + inp_name=inp.name, + inp=inp_name + input_suffix, + ) + ) + if derivative.required_original_self_value: + input_suffix = "s[i]" if is_inplace_foreach
else "" + unpacked_arguments += FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute( + inp_name="original_self", + inp="original_self" + input_suffix, + zeros_fn=zeros_fn, + ) + unpacked_arguments += FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute( + inp_name="original_self", + inp="original_self" + input_suffix, + ) + elif inplace and derivative.is_reusing_outplace_formula: + # The gradient wasn't already cloned, do it if grad mode is enabled + unpacked_arguments += ( + "self_t = GradMode::is_enabled() ? self_t.clone() : self_t;" + ) + + if inplace: + is_inplace_str = "true" + else: + is_inplace_str = "false" + + requires_fw_grad = get_any_has_forward_grad_name(derivative.var_names) + + if all( + (isinstance(var_type, BaseType) and var_type.is_tensor_like()) + for var_type in derivative.var_types + ): + # Is there a way to get from BaseType to BaseCType + if len(derivative.var_types) == 1: + opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type() + if not is_foreach: + fw_grad_setters.append( + FW_DERIVATIVE_SETTER_TENSOR.substitute( + out_arg=res[0], is_inplace=is_inplace_str + ) + ) + else: + assert res[0] == ("result" if not inplace else "self") + fw_grad_setters.append( + FW_DERIVATIVE_SETTER_TENSOR_FOREACH.substitute( + out_arg=res[0], is_inplace=is_inplace_str + ) + ) + requires_fw_grad += f" && ({derivative.var_names[0]}.defined())" + else: + tuple_type = TupleCType( + [BaseCType(tensorT)] * len(derivative.var_types) + ) + opt_res_grad_type = OptionalCType(tuple_type).cpp_type() + for idx, single_res in enumerate(res): + fw_grad_setters.append( + FW_DERIVATIVE_SETTER_MULTI_OUTPUT.substitute( + idx=idx, all_res="_".join(res), out_arg=single_res + ) + ) + elif ( + isinstance(derivative.var_types[0], ListType) + and derivative.var_types[0].is_tensor_like() + ): + assert ( + len(derivative.var_types) == 1 + ), "Expected number of outputs to be 1 if function returns ListType" + if not is_foreach: + opt_res_grad_type = OptionalCType( + VectorCType(BaseCType(tensorT)) + ).cpp_type() + fw_grad_setters.append( + FW_DERIVATIVE_SETTER_TENSOR_LIST.substitute( + out_arg=res[0], is_inplace=is_inplace_str + ) + ) + else: + # TODO(crcrpar): Should this (= the foreach specific logic) be refactored somehow? + # Only out-place foreach functions that have entries in `tools/autograd/derivatives.yaml` + # can reach here. + opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type() + fw_grad_setters.append( + FW_DERIVATIVE_SETTER_TENSOR_FOREACH.substitute( + out_arg=res[0], is_inplace=is_inplace_str + ) + ) + else: + raise RuntimeError("Unsupported output type for forward derivative") + + if not is_foreach: + fw_grad_opt_definition = f"{opt_res_grad_type} {'_'.join(res)}_new_fw_grad_opt = ::std::nullopt;" + # View ops create fw_grad that already is a view of the base's fw_grad so just use that + content.append( + FW_DERIVATIVE_TEMPLATE.substitute( + fw_grad_opt_definition=fw_grad_opt_definition, + requires_fw_grad=requires_fw_grad, + formula=derivative.formula, + out_arg="_".join(res), + unpacked_arguments=unpacked_arguments, + ) + ) + else: + # note(crcrpar): Assuming `self` is TensorList. + fw_grad_opt_definition = ( + f"std::vector<{opt_res_grad_type}> {'_'.join(res)}_new_fw_grad_opts" + "(self.size(), ::std::nullopt);" + ) + foreach_forward_grad_formula = derivative.formula + _foreach_arg: Argument | DifferentiableInput + if inplace: + for _foreach_arg, _ref_arg in inplace_foreacharg2refarg.items(): + # note(crcrpar): Massage only Scalar and ArrayRef here. 
+ if not ( + is_tensor_type(_foreach_arg.type) + or is_tensor_list_type(_foreach_arg.type) + ): + pattern = _foreach_arg.name + if isinstance(_foreach_arg.type, ListType): + pattern += "[i]" + foreach_forward_grad_formula = ( + foreach_forward_grad_formula.replace( + _ref_arg.name, pattern + ) + ) + else: + if ( + "result" in foreach_forward_grad_formula + and "result[i]" not in foreach_forward_grad_formula + ): + foreach_forward_grad_formula = ( + foreach_forward_grad_formula.replace("result", "result[i]") + ) + + content.append( + FW_DERIVATIVE_FOREACH_TEMPLATE.substitute( + fw_grad_opt_definition=fw_grad_opt_definition, + vector_of_optional_tensor=f"{'_'.join(res)}_new_fw_grad_opts", + any_has_forward_grad_for_current_index=" || ".join( + get_any_has_forward_grad_name(derivative.var_names) + "[i]" + for derivative in fw_derivatives + ), + formula=foreach_forward_grad_formula, + unpacked_arguments=unpacked_arguments, + ) + ) + + # Set all the grads at the end to avoid: https://github.com/pytorch/pytorch/issues/67367 + content.append("\n".join(fw_grad_setters)) + return content + + def get_any_has_fw_grad_cond(derivative: ForwardDerivative | None) -> str: + # + # Produces a condition string (e.g, "isFwGradDefined(grad_output) || isFwGradDefined(output)") + # + if derivative is None: + # (1) If a derivative is NOT provided, cond will check fw_grad of ALL differentiable inputs + # - Used in the out_fn case when we want to forbid fw derivatives + # - Used in the case where the fw_derivative is not defined, but we want + # To check if there is a decomposition registered for jvp + to_check: list[str] = [] + for inp in list( + mapMaybe( + gen_differentiable_input, + f.func.arguments.non_out + list(f.func.arguments.out), # type: ignore[operator] + ) + ): + if is_tensor_type(inp.type): + to_check.append( + FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name) + ) + elif is_tensor_list_type(inp.type): + to_check.append( + FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE.substitute( + req_inp=inp.name + ) + ) + else: + raise RuntimeError( + f'Unsupported input type for "{name}" when forbidding forward AD usage.' + ) + return f'({" || ".join(to_check)})' + else: + # (2) If derivative is provided, use that information to determine which inputs + # to check fw_grad for + assert derivative.required_inputs_fw_grad is not None + + if len(derivative.required_inputs_fw_grad) == 0: + # Handle functions like stack + # For these, we don't unpack anything and always call the user function + if not ( + len(differentiable_inputs) == 1 + and is_tensor_list_type(differentiable_inputs[0].type) + ): + raise RuntimeError( + f'No differentiable input to "{name}" is a differentiable Tensor (as the provided ' + "forward AD formula does not use any input tangent) even though a forward gradient " + "formula has been defined for it. This case should only happen for function that " + "take a single TensorList as input. All other cases are not supported right now." 
+ ) + any_has_fw_grad = "true" + else: + any_has_fw_grad = " || ".join( + [ + ( + FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE + if is_tensor_list_type(inp.type) + else FW_DERIVATIVE_CHECK_TEMPLATE + ).substitute(req_inp=inp.name) + for inp in differentiable_inputs + if inp.name in derivative.required_inputs_fw_grad + ] + ) + any_has_fw_grad = f"({any_has_fw_grad})" + + return any_has_fw_grad + + def emit_forbid_fw_derivatives(is_out_fn: bool = False) -> str: + if is_out_fn: + msg = "because it is an out= function" + else: + msg = ( + "because it has not been implemented yet.\\nPlease file an issue " + "to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml " + "so that we can prioritize its implementation." + ) + cond = get_any_has_fw_grad_cond(derivative=None) + return ( + FW_DERIVATIVE_FORBID_TEMPLATE.substitute(cond=cond, name=name, msg=msg) + if cond != "" + else "" + ) + + body: list[str] = [] + unpack_args_stats, unpacked_bindings = unpack_args(f) + + body.extend(unpack_args_stats) + if requires_derivative: + body.extend(emit_any_requires_grad()) + body.extend(emit_any_has_forward_grad()) + body.extend(emit_check_inplace()) + body.extend(emit_original_self_definition()) + body.extend(setup_derivative(differentiable_inputs)) + + body.append(emit_call(f, unpacked_bindings, try_jit_decomposition)) + if requires_derivative: + # set_flags has to appear after version_counter, because rebase_history + # requires that the counter is incremented before it is called + body.append(emit_history()) + body.extend(emit_check_if_in_complex_autograd_allowlist()) + + if is_out_fn: + body.append(emit_forbid_fw_derivatives(is_out_fn=True)) + else: + if requires_derivative and not try_jit_decomposition: + if len(fw_derivatives) > 0: + body.extend(emit_fw_derivatives()) + else: + body.append(emit_forbid_fw_derivatives()) + + if requires_derivative: + # Save only after the forward AD has been set up + body.append(emit_save_outputs()) + + if str(f.func.name.name) in RESET_GRAD_ACCUMULATOR: + # `inplace` implies that there is exactly one output named `self`, + # so we can keep the generated code easy. If you need to + # `reset_grad_accumulator` in an operator that's not `inplace`, you can + # remove this assert but the code generation will get more elaborate + assert inplace + body.append("reset_grad_accumulator(self);") + if not returns_void: + body.append(f"return {get_return_value(f)};") + return body diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_view_funcs.py b/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_view_funcs.py new file mode 100644 index 0000000000000000000000000000000000000000..245a77106dc65a2b9ab89c9006ff317eabf1ed1c --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_view_funcs.py @@ -0,0 +1,340 @@ +# Generates ViewFuncs.h/cpp +# +# NOTE: If any changes are being made to the ViewFunc codegen please also check +# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp +# The fallback is expected to mimic this codegen, so we should keep the two in sync. 
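+# Illustrative sketch (added commentary, not upstream code): each generated ViewFunc
+# captures the non-self arguments of a view op so the view can be replayed on a new
+# base tensor later. Assuming the templates defined below, the generated C++ for a
+# hypothetical slice view func would look roughly like:
+#
+#   struct SliceTensorViewFunc : public torch::autograd::ViewFunc {
+#     SliceTensorViewFunc(int64_t dim, ::std::optional<c10::SymInt> start,
+#                         ::std::optional<c10::SymInt> end, c10::SymInt step)
+#       : dim(dim), start(start), end(end), step(step) {};
+#     at::Tensor operator()(const at::Tensor& input_base) const override {
+#       return at::_ops::slice_Tensor::call(input_base, dim, start, end, step);
+#     }
+#     // get_symints/set_symints, get_tensors/set_tensors, clone_and_set elided
+#   };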
+ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torchgen.api.dispatcher as dispatcher +from torchgen.api.translate import translate +from torchgen.api.types import ( + BaseCType, + Binding, + NamedCType, + SymIntT, + tensorT, + VectorCType, +) +from torchgen.code_template import CodeTemplate +from torchgen.model import Argument, NativeFunction, OptionalType +from torchgen.utils import FileManager + +from .gen_inplace_or_view_type import ( + CALL_DISPATCH, + extract_bindings, + get_view_info, + modifies_arguments, + use_derived, +) + + +if TYPE_CHECKING: + from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo + + +FUNCTION_DECLARATION = CodeTemplate( + """\ +#define ${uppercase_op}_AVAILABLE +struct ${op} : public ${superclass} { + ${op}(${constructor_args}) ${initializer_list} + {}; + virtual ~${op}() override {}; + virtual std::vector<c10::SymInt> get_symints() const override; + virtual size_t num_symints() const override; + virtual std::vector<at::Tensor> get_tensors() const override; + virtual size_t num_tensors() const override; + virtual at::Tensor operator()(const at::Tensor&) const override; + virtual std::unique_ptr<ViewFunc> clone_and_set( + std::optional<std::vector<c10::SymInt>> = ::std::nullopt, + std::optional<std::vector<at::Tensor>> = ::std::nullopt) const override; + +protected: + virtual void set_symints(std::vector<c10::SymInt>) override; + virtual void set_tensors(std::vector<at::Tensor>) override; + +private: + ${state} +}; + +""" +) + +FUNCTION_DEFINITION = CodeTemplate( + """\ +std::vector<c10::SymInt> ${op}::get_symints() const { + ${get_symints} +} + +size_t ${op}::num_symints() const { + return static_cast<size_t>(${num_symints}); +} + +void ${op}::set_symints(std::vector<c10::SymInt> ${symints_vec}) { + TORCH_INTERNAL_ASSERT(${symints_vec}.size() == num_symints()); + ${set_symints} +} + +std::vector<at::Tensor> ${op}::get_tensors() const { + ${get_tensors} +} + +size_t ${op}::num_tensors() const { + return static_cast<size_t>(${num_tensors}); +} + +void ${op}::set_tensors(std::vector<at::Tensor> ${tensors_vec}) { + TORCH_INTERNAL_ASSERT(${tensors_vec}.size() == num_tensors()); + ${set_tensors} +} + +at::Tensor ${op}::operator()(const at::Tensor& ${call_input_name}) const { + return ${op_call}; +} + +std::unique_ptr<ViewFunc> ${op}::clone_and_set( + std::optional<std::vector<c10::SymInt>> ${symints_vec}, + std::optional<std::vector<at::Tensor>> ${tensors_vec}) const { + auto output = std::make_unique<${op}>(${clone_args}); + if (${symints_vec}.has_value()) { + output->set_symints(std::move(*(${symints_vec}))); + } + if (${tensors_vec}.has_value()) { + output->set_tensors(std::move(*(${tensors_vec}))); + } + return output; +} + +""" +) + + +# e.g.
as_strided -> AsStridedViewFunc for camel case or +# as_strided_view_func otherwise +def view_func_name( + f: NativeFunction, include_namespace: bool = False, camel_case: bool = True +) -> str: + name = f.func.name.unambiguous_name() + view_func_name = f"{name.replace('.', '_')}_view_func" + if camel_case: + is_private = view_func_name.startswith("_") + view_func_name = "".join( + [p.title() for p in view_func_name.replace(".", "_").split("_")] + ) + if is_private: + # put the leading underscore back in + view_func_name = f"_{view_func_name}" + namespace = "torch::autograd::generated::" if include_namespace else "" + return f"{namespace}{view_func_name}" + + +def is_symint_or_tensor(arg: Argument) -> bool: + return arg.type.is_tensor_like() or arg.type.is_symint_like() + + +def remove_const_ref(binding: Binding) -> Binding: + return Binding( + name=binding.name, + nctype=binding.nctype.remove_const_ref(), + argument=binding.argument, + default=binding.default, + ) + + +def returns_multi_tensor(fn: NativeFunction) -> bool: + returns = fn.func.returns + assert len(returns) == 1 + returns_list_like = returns[0].type.is_list_like() is not None + returns_tensor_like = returns[0].type.is_tensor_like() + return returns_list_like and returns_tensor_like + + +# Generates strings with logic for getting / setting state of a particular type. +# +# Args: +# bindings (list): List of state bindings of interest (may be empty) +# state_vec_type (NamedCType): Type of vector to either return or copy from +# +# Returns: +# tuple: (list of getter logic strings, list of setter logic strings, string +# with num items expression) +def generate_state_getter_setter( + bindings: list[Binding], + state_vec_type: NamedCType, +) -> tuple[list[str], list[str], str]: + getter_logic = [] + setter_logic = [] + + state_vec = state_vec_type.name + getter_logic.append(f"{state_vec_type.cpp_type()} {state_vec};") + if len(bindings) > 0: + setter_logic.append("auto i = 0;") + + num_exprs = [] + for i, b in enumerate(bindings): + assert isinstance(b.argument, Argument) + if b.argument.type.is_list_like(): + # Handle list-likes. + num_expr = f"{b.name}.size()" + num_exprs.append(num_expr) + getter = f"{state_vec}.insert({state_vec}.end(), {b.name}.begin(), {b.name}.end());" + setter = f"std::copy({state_vec}.begin() + i, {state_vec}.begin() + i + {b.name}.size(), {b.name}.begin());" + elif isinstance(b.argument.type, OptionalType): + # Handle optionals. + num_expr = f"({b.name}.has_value() ? 1 : 0)" + num_exprs.append(num_expr) + conditional = f"if({b.name}.has_value())" + getter = ( + f"{conditional} {state_vec}.insert({state_vec}.end(), *({b.name}));" + ) + setter = f"{conditional} {b.name} = {state_vec}[i];" + else: + num_expr = "1" + num_exprs.append(num_expr) + getter = f"{state_vec}.push_back({b.name});" + setter = f"{b.name} = {state_vec}[i];" + + getter_logic.append(getter) + setter_logic.append(setter) + if i < len(bindings) - 1: + setter_logic.append(f"i += {num_expr};") + + # Reserve / assert based on the total number of items expression. 
+ num_items = "0" if len(num_exprs) == 0 else " + ".join(num_exprs) + if len(bindings) > 0: + getter_logic.insert(1, f"{state_vec}.reserve({num_items});") + + getter_logic.append(f"return {state_vec};") + + return getter_logic, setter_logic, num_items + + +def process_function(fn: NativeFunction, template: CodeTemplate) -> str: + bindings = extract_bindings(fn) + non_self_bindings = [b for b in bindings if b.name != "self"] + + non_self_args = fn.func.arguments.flat_all[1:] + non_self_value_bindings = [ + dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args + ] + + # Generate constructor / clone args for the generated struct. + constructor_args = [b.defn() for b in non_self_bindings] + clone_args = [b.name for b in non_self_bindings] + + # Generate state variable declarations for the generated struct. + state_variables = [ + f"{remove_const_ref(b).defn()};" for b in non_self_value_bindings + ] + + # Generate initializer list expressions for the generated struct. + # allow_expensive_conversions=True because we need to store e.g. SymIntArrayRefs as + # vectors. + init_exprs = translate( + non_self_bindings, non_self_value_bindings, allow_expensive_conversions=True + ) + initializers = [] + for b, init_expr in zip(non_self_bindings, init_exprs): + name = b.nctype.name + assert isinstance(name, str) + initializers.append(f"{name}({init_expr.expr})") + + # Generate call to underlying view op + call_input_name = "input_base" + op_call_args = [call_input_name, *(b.name for b in non_self_bindings)] + op_call = CALL_DISPATCH.substitute( + unambiguous_name=fn.func.name.unambiguous_name(), + unpacked_args=op_call_args, + ) + + # Multi-output views additionally require a view_idx for disambiguation. + if returns_multi_tensor(fn): + view_idx_name = "view_idx" + view_idx_typename = "int64_t" + view_idx_decl = f"{view_idx_typename} {view_idx_name}" + constructor_args.append(view_idx_decl) + clone_args.append(view_idx_name) + state_variables.append(f"{view_idx_decl};") + initializers.append(f"{view_idx_name}({view_idx_name})") + op_call += f"[{view_idx_name}]" + + # Generate initializer list for the generated struct. + initializer_list = f": {', '.join(initializers)}" if len(initializers) > 0 else "" + + # Generate getter / setter logic for any symints. + symint_bindings = [ + b + for b in non_self_bindings + if isinstance(b.argument, Argument) and b.argument.type.is_symint_like() + ] + symints_vec_type = NamedCType("symints", VectorCType(BaseCType(SymIntT))) + get_symints, set_symints, num_symints = generate_state_getter_setter( + symint_bindings, symints_vec_type + ) + + # Generate getter / setter logic for any tensors. 
+ tensor_bindings = [ + b + for b in non_self_bindings + if isinstance(b.argument, Argument) and b.argument.type.is_tensor_like() + ] + tensors_vec_type = NamedCType("tensors", VectorCType(BaseCType(tensorT))) + get_tensors, set_tensors, num_tensors = generate_state_getter_setter( + tensor_bindings, tensors_vec_type + ) + + return template.substitute( + op=view_func_name(fn), + uppercase_op=view_func_name(fn, camel_case=False).upper(), + superclass="torch::autograd::ViewFunc", + initializer_list=initializer_list, + state=state_variables, + constructor_args=constructor_args, + clone_args=clone_args, + symints_vec=symints_vec_type.name, + get_symints=get_symints, + set_symints=set_symints, + num_symints=num_symints, + tensors_vec=tensors_vec_type.name, + get_tensors=get_tensors, + set_tensors=set_tensors, + num_tensors=num_tensors, + call_input_name=call_input_name, + op_call=op_call, + ) + + +def gen_view_funcs( + out: str, + fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo], + template_path: str, +) -> None: + # don't need the info parts, just the function + fns = [fn.func for fn in fns_with_infos if use_derived(fn)] + # only want out-of-place views + view_fns = [ + fn for fn in fns if get_view_info(fn) is not None and not modifies_arguments(fn) + ] + + declarations = [process_function(fn, FUNCTION_DECLARATION) for fn in view_fns] + definitions = [process_function(fn, FUNCTION_DEFINITION) for fn in view_fns] + ops_headers = [f"#include <ATen/ops/{fn.root_name}_ops.h>" for fn in view_fns] + + file_basename = "ViewFuncs" + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + for suffix in [".h", ".cpp"]: + fname = file_basename + suffix + fm.write_with_template( + fname, + fname, + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/" + + fname, + "view_func_declarations": declarations, + "view_func_definitions": definitions, + "ops_headers": ops_headers, + }, + ) diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/load_derivatives.py b/lib/python3.10/site-packages/torchgen/packaged/autograd/load_derivatives.py new file mode 100644 index 0000000000000000000000000000000000000000..645a569c45e3dc9877f61b4329d2434fe987cf76 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/load_derivatives.py @@ -0,0 +1,1014 @@ +# Parses derivatives.yaml into autograd functions +# +# Each autograd function is represented by `DifferentiabilityInfo` containing +# a list of `Derivative`. See `torchgen.api.autograd` for the data models.
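+# Illustrative sketch (added commentary, not the exact upstream yaml): an entry in
+# derivatives.yaml pairs a function schema with one backward formula per
+# differentiable input, plus an optional forward-AD formula keyed by the output
+# name. A plausible mul-like entry would look roughly like:
+#
+#   - name: mul.Tensor(Tensor self, Tensor other) -> Tensor
+#     self: grad * other.conj()
+#     other: grad * self.conj()
+#     result: self_t * other_p + other_t * self_p
+#
+# Here "grad" is the incoming gradient of the (single) output, and "*_p" / "*_t"
+# denote primal and tangent values used by forward-mode AD (see
+# postprocess_forward_derivatives below).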
+ +from __future__ import annotations + +import re +from collections import defaultdict +from typing import Any, Counter, Dict, Sequence, Set, Tuple + +import yaml + +from torchgen.api import cpp +from torchgen.api.autograd import ( + Derivative, + DifferentiabilityInfo, + ForwardDerivative, + SavedAttribute, +) +from torchgen.api.types import ( + BaseCType, + Binding, + boolT, + CppSignatureGroup, + layoutT, + longT, + NamedCType, + OptionalCType, + scalarTypeT, + SpecialArgName, + stringT, + symIntArrayRefT, + SymIntT, + tensorGeometryT, + tensorOptionsT, + typeAndSizeT, + VectorCType, +) +from torchgen.context import with_native_function +from torchgen.gen import get_grouped_by_view_native_functions, parse_native_yaml +from torchgen.model import ( + AUTOGRAD_KEYS, + FunctionSchema, + NativeFunction, + NativeFunctionsViewGroup, + OperatorName, + SchemaKind, + Type, + Variant, +) +from torchgen.utils import concatMap, IDENT_REGEX, split_name_params +from torchgen.yaml_utils import YamlLoader + + +DerivativeRet = Tuple[Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], Set[str]] + +_GLOBAL_LOAD_DERIVATIVE_CACHE: dict[tuple[str, str], DerivativeRet] = {} + +_VALID_AUTOGRAD_KEYS = set(AUTOGRAD_KEYS) + + +# This function directly adds per-dispatchkey derivative entries for {view}_copy variants of each view op. +# Since every {view} and {view}_copy op shares the same derivative formula, +# we generate them here instead of duplicating them in the yaml. +# See Note [Codegen'd {view}_copy Operators] +def add_view_copy_derivatives( + infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]], + view_groups: list[NativeFunctionsViewGroup], +) -> None: + # Get the map from each view op's name to its corresponding view group + view_name_to_group: dict[OperatorName, NativeFunctionsViewGroup] = { + g.view.func.name: g for g in view_groups + } + + view_infos = {} + + for info_dispatch_dict in infos.values(): + # maybe_view_group only needs to be calculated once per info_dispatch_dict + maybe_view_group = None + view_copy_differentiability_infos = {} + for dispatch_key, info in info_dispatch_dict.items(): + maybe_view_group = view_name_to_group.get(info.func.func.name, None) + if maybe_view_group is not None and maybe_view_group.view_copy is not None: + view_copy_info = info.create_view_copy_from_view_derivative( + maybe_view_group + ) + if view_copy_info is not None: + fn_schema = view_copy_info.func.func + view_copy_differentiability_infos[dispatch_key] = view_copy_info + else: + break + # prefer manually-defined derivatives if any + if len(view_copy_differentiability_infos) > 0 and fn_schema not in infos: + assert fn_schema is not None + view_infos[fn_schema] = view_copy_differentiability_infos + + infos.update(view_infos) + + +def load_derivatives( + derivatives_yaml_path: str, native_yaml_path: str, tags_yaml_path: str +) -> DerivativeRet: + # Do some caching as this is a deterministic function + global _GLOBAL_LOAD_DERIVATIVE_CACHE + key = (derivatives_yaml_path, native_yaml_path) + if key not in _GLOBAL_LOAD_DERIVATIVE_CACHE: + with open(derivatives_yaml_path) as f: + definitions = yaml.load(f, Loader=YamlLoader) + + funcs = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions + # From the parsed native functions, separate out the (generated) view_copy functions, + # so we can generate derivatives for them separately. 
+ native_functions_with_view_groups = get_grouped_by_view_native_functions(funcs) + native_functions = concatMap( + lambda g: [g] + if isinstance(g, NativeFunction) + else list(g.functions(include_copy=True)), + native_functions_with_view_groups, + ) + view_groups = [ + g + for g in native_functions_with_view_groups + if isinstance(g, NativeFunctionsViewGroup) + ] + + # What's the difference between function schema v.s. signature? + # function schema is the complete declaration including mutability annotation / default value and etc. + # signature is the canonical schema for a group of functions (in-place/out/functional variants) + # that are semantically related. + functions_by_signature: dict[ + FunctionSchema, list[NativeFunction] + ] = defaultdict(list) + functions_by_schema: dict[str, NativeFunction] = {} + for function in native_functions: + functions_by_signature[function.func.signature()].append(function) + assert str(function.func) not in functions_by_schema + functions_by_schema[str(function.func)] = function + + # Keep track of how many of which ops we've seen so we can + # disambiguate them with a numeric suffix. + op_counter = Counter[str]() + + # infos is a dict that maps FunctionSchema -> a dict of per dispatch key DifferentiabilityInfos + # this is useful because in tools/autograd/gen_autograd.py:match_differentiability_info + # we ultimately need to categorize the DifferentiabilityInfos by FunctionSchema + infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]] = {} + used_dispatch_keys: set[str] = set() + for defn_dict in definitions: + # Ensure that the old derivatives.yaml schema with no dispatch key can be loaded. + if "dispatch" not in defn_dict: + specification = defn_dict.pop("name") + output_differentiability = defn_dict.pop( + "output_differentiability", None + ) + defn_dict = {"name": specification, "dispatch": {"Default": defn_dict}} + if output_differentiability: + defn_dict["output_differentiability"] = output_differentiability + name, per_dispatch_diffinfos = create_differentiability_info( + defn_dict, + functions_by_signature, + functions_by_schema, + op_counter, + used_dispatch_keys, + ) + infos[name] = per_dispatch_diffinfos + + add_view_copy_derivatives(infos, view_groups) + + # cache both loaded infos as well a a set of all the dispatch_keys/aliases + # that appear in derivatives.yaml. used_dispatch_keys is useful for generating + # VariableType.cpp where we need a TORCH_LIBRARY_IMPL for every autograd dispatch key used + _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos, used_dispatch_keys + + return _GLOBAL_LOAD_DERIVATIVE_CACHE[key] + + +# TODO: Why is this going through CppSignatureGroup, that doesn't make sense... 
+@with_native_function +def cpp_arguments(f: NativeFunction) -> Sequence[Binding]: + sigs = CppSignatureGroup.from_native_function(f, method=False) + if sigs.symint_signature is not None: + return sigs.symint_signature.arguments() + else: + return sigs.signature.arguments() + + +def create_derivative( + f: NativeFunction, + formula: str, + var_names: tuple[str, ...], + available_named_gradients: Sequence[str], +) -> Derivative: + original_formula = formula + arguments: list[NamedCType] = [ + a.nctype.remove_const_ref() for a in cpp_arguments(f) + ] + + return_names = tuple(n if n != "self" else "result" for n in cpp.return_names(f)) + return_types = tuple( + cpp.return_type(r, symint=True).remove_const_ref() for r in f.func.returns + ) + + named_returns = [ + NamedCType(name, type) for name, type in zip(return_names, return_types) + ] + + formula, saved_inputs = saved_variables(formula, arguments, var_names) + formula, saved_outputs = saved_variables(formula, named_returns, var_names) + + used_named_gradients = { + name + for name in available_named_gradients + if re.search(IDENT_REGEX.format(name), formula) + } + + # Check that the referenced derivatives in the formula are in bounds + for i in used_gradient_indices(formula): + if i >= len(f.func.returns): + raise RuntimeError( + f"Out of bounds grads access: derivative formula for {cpp.name(f.func)} " + f"used grads[{i}], but the forward only returns {len(f.func.returns)} outputs." + ) + + return Derivative( + formula=formula, + original_formula=original_formula, + var_names=var_names, + saved_inputs=saved_inputs, + saved_outputs=saved_outputs, + named_gradients=used_named_gradients, + ) + + +def create_forward_derivative( + f: NativeFunction, formula: str, names: tuple[str, ...] +) -> ForwardDerivative: + var_names = names + var_types: tuple[Type, ...] 
| None = None + for r in f.func.returns: + if r.name in var_names: + if var_types is None: + var_types = () + var_types = var_types + (r.type,) + + # Handle default return names + if var_types is None: + if var_names == ("result",): + assert len(f.func.returns) == 1 + var_types = (f.func.returns[0].type,) + else: + for var_name in var_names: + res = re.findall(r"^result(\d+)$", var_name) + if len(res) == 1: + if var_types is None: + var_types = () + arg_idx = int(res[0]) + var_types = var_types + (f.func.returns[arg_idx].type,) + + assert var_types is not None, "No matching output for forward derivative definition" + return ForwardDerivative( + formula=formula, + var_names=var_names, + var_types=var_types, + required_inputs_fw_grad=None, + required_inputs_primal=None, + required_original_self_value=False, + is_reusing_outplace_formula=False, + ) + + +def postprocess_forward_derivatives( + f: NativeFunction, + defn_name: str, + all_arg_names: list[str], + derivatives: list[Derivative], + forward_derivatives: list[ForwardDerivative], + args_with_derivatives: Sequence[Binding], +) -> list[ForwardDerivative]: + def find_required_inputs(formula: str, postfix: str) -> tuple[str, ...]: + is_foreach = f.func.name.name.base.startswith("_foreach_") + required_inputs = set() + for arg in args_with_derivatives: + if ( + arg.type in ("at::TensorList", "const at::ITensorListRef &") + and not is_foreach + ): + # The functions taking TensorList handle everything internally + continue + arg_name = arg.name + + found = re.search(IDENT_REGEX.format(arg_name), formula) + if found: + raise RuntimeError( + f"The forward formula for {defn_name} is using the base name of the {arg_name} " + f"argument which is ambiguous. You should use {arg_name}_p to access the primal " + f"value and {arg_name}_t to access the tangent." + ) + + found = re.search(IDENT_REGEX.format(arg_name + postfix), formula) + if found: + required_inputs.add(arg_name) + + return tuple(required_inputs) + + updated_derivatives: list[ForwardDerivative] = [] + + for defn in forward_derivatives: + formula = defn.formula + required_inputs_tangent = find_required_inputs(formula, "_t") + if formula == "auto_element_wise": + assert ( + f.func.kind() != SchemaKind.inplace + ), f"Cannot use auto_element_wise with {f.func.name} because it is an in-place variant" + if ( + (not len(args_with_derivatives) == 1) + or len(forward_derivatives) > 1 + or len(forward_derivatives[0].var_names) > 1 + ): + raise RuntimeError( + f"Derivative definition of {defn_name} in derivatives.yaml defines the " + "forward definition of gradient as element_wise but this only " + "works for functions with a single differentiable input and a " + "single differentiable output." + ) + if not len(derivatives) == 1: + raise RuntimeError( + f"Derivative definition of {defn_name} in derivatives.yaml defines the " + "forward definition of gradient as element_wise but it does not " + "defines the gradient formula for its argument which is required." + ) + # This transformation is based on the observation that for element-wise functions, the Jacobian + # matrix is diagonal and thus doing J * v is the same as (v^T J)^T (in practice, we ignore the transpositions) + # For the complex case, we use hermitian transpose and get (v.conj() J).conj() + # So here we are going to re-use the backward formula and replace two things: + # 1) all occurrences of "grad" with "foo_t.conj()", where foo is the name of the unique differentiable input. 
+ # 2) all usage of an original input "foo" with its primal value "foo_p". + # 3) conjugate the final result + # For example, for abs, the backward formula is: + # grad * self.sgn() + # And this function generates a forward formula that is: + # (self_t.conj() * self_p.sgn()).conj() + + backward_formula = derivatives[0].original_formula + input_name = args_with_derivatives[0].name + + # Do replacement 1) of the grad + def repl(m: Any) -> str: + return f"{m.group(1)}{input_name}_t.conj(){m.group(2)}" + + fw_formula = re.sub(IDENT_REGEX.format("grad"), repl, backward_formula) + + # Do replacement 2) of the input variables + for arg in args_with_derivatives: + arg_name = arg.name + + def repl(m: Any) -> str: + return f"{m.group(1)}{arg_name}_p{m.group(2)}" + + fw_formula = re.sub(IDENT_REGEX.format(arg_name), repl, fw_formula) + + # Do the final conjugate 3) + fw_formula = f"({fw_formula}).conj()" + + # Since there is a single differentiable inputs and we necessarily need its tangent we can + # simply require all differentiable input's tangent. + required_inputs_tangent = tuple(all_arg_names) + formula = fw_formula + elif formula == "auto_linear": + if ( + len(forward_derivatives) > 1 + or len(forward_derivatives[0].var_names) > 1 + ): + raise RuntimeError( + f"Derivative definition of {defn_name} in derivatives.yaml defines the " + "forward definition of gradient as linear but this only works " + "for functions with a single differentiable output." + ) + # This transformation is based on the observation that linear functions can be written as: + # y = f(x) = A * x + # For some matrix A and the Jacobian of the function f is also A. + # So doing J * v = A * v = f(v). + # Hence to do the jvp, we simply need to evaluate the function at the point v instead of x. + # We do this by calling the forward again by replacing any occurrence of the differentiable + # input "foo" by it's tangent "foo_t". + # Note that multiple inputs are not a problem as long as the function is truly linear wrt to + # the vector where all the differentiable inputs are stacked. + + diff_arg_names = [arg.name for arg in args_with_derivatives] + assert len(diff_arg_names) > 0 + + # Do replacement of input variables + new_args = [] + for arg_name in all_arg_names: + if arg_name in diff_arg_names: + arg_name = arg_name + "_t" + new_args.append(arg_name) + + # TODO we are trolling + if f.func.has_symint(): + defn_name += "_symint" + + # Call into the forward again. We need two cases here to handle both Tensor methods and at:: functions. + if Variant.function in f.variants: + fw_formula = f"at::{defn_name}({', '.join(new_args)})" + else: + assert Variant.method in f.variants + fw_formula = f"{new_args[0]}.{defn_name}({', '.join(new_args[1:])})" + + # All of the input tangents are always used so all of them are required here. + required_inputs_tangent = tuple(diff_arg_names) + formula = fw_formula + + # At this point, the formula is final and is not modified anymore. + + # During forward formula, we use the primal instead of the input Tensors. + # This call inspects the formula to find for which input's primal are used. 
+ required_inputs_primal = find_required_inputs(formula, "_p") + + updated_derivatives.append( + ForwardDerivative( + formula=formula, + var_names=defn.var_names, + var_types=defn.var_types, + required_inputs_fw_grad=required_inputs_tangent, + required_inputs_primal=required_inputs_primal, + required_original_self_value=False, + is_reusing_outplace_formula=False, + ) + ) + + return updated_derivatives + + +def is_forward_derivative_definition( + all_arg_names: list[str], names: tuple[str, ...] +) -> bool: + for name in names: + return name not in all_arg_names + raise RuntimeError("Expected `names` to be non-empty") + + +def create_differentiability_info( + defn_dict: dict[Any, Any], + functions_by_signature: dict[FunctionSchema, list[NativeFunction]], + functions_by_schema: dict[str, NativeFunction], + op_counter: Counter[str], + used_dispatch_keys: set[str], +) -> tuple[FunctionSchema, dict[str, DifferentiabilityInfo]]: + """Processes a single entry `defn` in derivatives.yaml""" + + def canonical_function( + functions: Sequence[NativeFunction], name: str + ) -> NativeFunction: + for f in functions: + if ( + not f.func.is_functional_fn() + and not f.func.is_out_fn() + and name == str(f.func.name.name) + ): + return f + # some functions only have in-place variants + assert name + "_" == cpp.name(functions[0].func) + return functions[0] + + def split_names(raw_names: str) -> tuple[str, ...]: + """Given "foo, bar", return ["foo", "bar"].""" + return tuple(x.strip() for x in raw_names.split(",")) + + def check_grad_usage(defn_name: str, derivatives: Sequence[Derivative]) -> None: + """ + Check for some subtle mistakes one might make when writing derivatives. + These mistakes will compile, but will be latent until a function is + used with double backwards. + """ + + uses_grad = False # true if any derivative uses "grad" + num_grads_uses = 0 # count of uses of "grads" or "grads[INDEX]" + uses_named_grads = False # true if any derivative uses "grad_{name}" + used_grads_indices: list[int] = [] # which indices of grads are used + for d in derivatives: + formula = d.formula + uses_grad = uses_grad or bool( + re.findall(IDENT_REGEX.format("grad"), formula) + ) + num_grads_uses += len(re.findall(IDENT_REGEX.format("grads"), formula)) + uses_named_grads = uses_named_grads or bool(d.named_gradients) + used_grads_indices.extend(used_gradient_indices(formula)) + # This is a basic sanity check: the number of places we see + # "grads" should be no fewer than the number of indices we see + # inside "grads". They may not be equal because we may use + # "grads" without an index. + assert num_grads_uses >= len(used_grads_indices) + # Thus if the number is equal, every use of grads is also + # indexed. + only_used_grads_indices = num_grads_uses == len(used_grads_indices) + + if uses_grad and num_grads_uses > 0: + raise RuntimeError( + f"Derivative definition of {defn_name} in derivatives.yaml illegally " + "mixes use of 'grad' and 'grads'. Consider replacing " + "occurrences of 'grad' with 'grads[0]'" + ) + + if only_used_grads_indices and set(used_grads_indices) == {0}: + raise RuntimeError( + f"Derivative definition of {defn_name} in derivatives.yaml solely " + "refers to 'grads[0]'. If the first output is indeed the " + "only differentiable output, replace 'grads[0]' with 'grad'; " + "otherwise, there is a likely error in your derivatives " + "declaration." 
+ ) + + if uses_named_grads and (uses_grad or num_grads_uses > 0): + raise RuntimeError( + f"Derivative definition of {defn_name} in derivatives.yaml illegally " + 'mixes use of "grad_RETURN_NAME" and "grad" or "grads[x]". Use ' + "only one method for identifying gradients." + ) + + @with_native_function + def set_up_derivatives( + f: NativeFunction, + ) -> tuple[ + Sequence[Derivative], + Sequence[ForwardDerivative], + Sequence[Binding], + Sequence[str], + Sequence[str], + ]: + # Set up the derivative information + derivatives: list[Derivative] = [] + forward_derivatives: list[ForwardDerivative] = [] + non_differentiable_arg_names: list[str] = [] + args_with_derivatives_set: set[str] = set() + + all_arg_names = [a.name for a in cpp_arguments(f)] + all_ret_names = [ + r.name for r in f.func.returns + ] # only used for the assert below + # output_differentiability is captured from the enclosed + # scope. Don't modify it. + # + # If it is not present, then no output is explicitly + # undifferentiable. + # + # It may be present and shorter than the length of return + # values. If that's the case, any return value that does not + # have a corresponding entry is considered not differentiable. + differentiability = output_differentiability or [True] * len(f.func.returns) + # A return is available as a named gradient ... + available_named_gradients = [ + f"grad_{ret.name}" + for ret, differentiable in zip(f.func.returns, differentiability) + # if it has not been explicitly made undifferentiable + if differentiable + # and if it has a name + and ret.name is not None + # and if its type is differentiable + and ret.type.is_tensor_like() + ] + + for raw_names in sorted(defn.keys()): + formula = defn[raw_names] + names = split_names(raw_names) + + for name in names: + assert not (name in all_arg_names and name in all_ret_names), ( + f"While processing the derivative formula for '{f.func.name}' wrt '{name}', " + f"expected '{name}' to not be both an input arg and named return. " + ) + + if is_forward_derivative_definition(all_arg_names, names): + forward_derivatives.append(create_forward_derivative(f, formula, names)) + else: + if formula.lower().strip() == "non_differentiable": + non_differentiable_arg_names += names + else: + derivative = create_derivative( + f, formula, names, available_named_gradients + ) + derivatives.append(derivative) + args_with_derivatives_set |= set(names) + + overlap = args_with_derivatives_set.intersection(non_differentiable_arg_names) + if overlap: + raise RuntimeError( + f"derivatives definition for {defn} have overlapped non_differentiable " + f"and differentiable variables: {overlap}" + ) + + # Next, let us determine the list of inputs in order. + # TODO: do we need eagerly calculate and save it here? Can it be derived + # from NativeFunction and `derivatives` on callsites instead? + args_with_derivatives = [ + a for a in cpp_arguments(f) if a.name in args_with_derivatives_set + ] + + # Postprocess forward derivatives definitions now that we know the differentiable arguments + forward_derivatives = postprocess_forward_derivatives( + f, + defn_name, + all_arg_names, + derivatives, + forward_derivatives, + args_with_derivatives, + ) + + # Test to see if the use of 'grads' makes sense. 
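A condensed, standalone illustration of the usage rules that check_grad_usage above enforces on a single formula string; IDENT_REGEX and GRAD_INDEX_REGEX are local stand-ins shaped like the ones defined in this file.

import re

IDENT_REGEX = r"(^|\W){}($|\W)"               # local stand-in
GRAD_INDEX_REGEX = r"(?:^|\W)grads\[(\d+)\]"  # local stand-in

def grad_usage(formula: str) -> dict:
    return {
        "uses_grad": bool(re.search(IDENT_REGEX.format("grad"), formula)),
        "grads_uses": len(re.findall(IDENT_REGEX.format("grads"), formula)),
        "grads_indices": [int(i) for i in re.findall(GRAD_INDEX_REGEX, formula)],
    }

# Rule 1: mixing 'grad' and 'grads' in one entry is rejected.
u = grad_usage("grad * grads[1]")
assert u["uses_grad"] and u["grads_uses"] == 1

# Rule 2: referring only to grads[0] is rejected in favour of plain 'grad'.
u = grad_usage("grads[0] * self")
assert u["grads_uses"] == len(u["grads_indices"]) and set(u["grads_indices"]) == {0}

# Rule 3 (not shown): grad_RETURN_NAME may not be mixed with 'grad' or 'grads[i]'.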
+ check_grad_usage(defn_name, derivatives) + + return ( + derivatives, + forward_derivatives, + args_with_derivatives, + non_differentiable_arg_names, + available_named_gradients, + ) + + # NB: Removes 'name' from defn dictionary + specification = defn_dict.pop("name") + defn_name, _ = split_name_params(specification) + # NB: Removes 'output_differentiability' from defn dictionary + # `None` means all differentiable. + output_differentiability = defn_dict.pop("output_differentiability", None) + output_differentiability_conditions = None + if output_differentiability and any( + isinstance(diff, str) for diff in output_differentiability + ): + if len(output_differentiability) != 1: + raise RuntimeError( + f"Not supported: for {specification}," + f"output_differentiability must either be " + f"List[bool] or a List[str] where each str is a " + f"condition. In the case where it is a condition, " + f"we only support single-output functions. " + f"Please file us an issue. " + ) + output_differentiability_conditions = output_differentiability + output_differentiability = [True] + + schema_function = functions_by_schema.get(specification) + if not schema_function: + avail = "\n".join( + k for k, v in functions_by_schema.items() if cpp.name(v.func) == defn_name + ) + raise RuntimeError( + f"could not find ATen function for schema: {specification} " + f". Available signatures:\n{avail}" + ) + + # now map this to the legacy schema; this isn't technically necessary, but we'd need some logic here + # to map in-place schemas to the out-of-place variants. + # TODO: maybe the logic to handle the legacy schema is no longer necessary? + signature = schema_function.func.signature() + functions = functions_by_signature[signature] + if len(functions) == 0: + avail = "\n".join( + str(k) + for k, v in functions_by_signature.items() + if cpp.name(k) == defn_name + ) + raise RuntimeError( + f"could not find ATen function for legacy signature: {signature} " + f"corresponding to schema {specification}. Please report a bug to PyTorch. " + f"Available signatures:\n{avail}" + ) + + canonical = canonical_function(functions, defn_name) + if "grad_input_mask" in (a.name for a in cpp_arguments(canonical)): + raise RuntimeError( + f"Schema for {defn_name} has an argument named grad_input_mask, " + "but this name would be shadowed by our codegen. " + "Please use a different name in native_functions.yaml." + ) + + if "result" in (a.name for a in cpp_arguments(canonical)): + raise RuntimeError( + f"Schema for {defn_name} has an argument named result, " + "but this is only allowed for outputs." + "Please use a different name in native_functions.yaml." 
+ ) + + diffinfo_dict = {} + for key, defn in defn_dict["dispatch"].items(): + if key != "Default" and key not in _VALID_AUTOGRAD_KEYS: + raise RuntimeError( + f"Invalid dispatch key {key} in derivatives.yaml for {specification}," + f" expected key to be one of {_VALID_AUTOGRAD_KEYS}" + ) + if key not in used_dispatch_keys: + used_dispatch_keys.add(key) + + ( + derivatives, + forward_derivatives, + args_with_derivatives, + non_differentiable_arg_names, + available_named_gradients, + ) = set_up_derivatives(canonical) + + used_named_gradients: set[str] = set() + for d in derivatives: + used_named_gradients |= d.named_gradients + + # only assign an op name if we are actually going to calculate a derivative + op = None + if args_with_derivatives: + op_prefix = _create_op_prefix(defn_name) + if key != "Default": + op_prefix = op_prefix + key + op = f"{op_prefix}{op_counter[op_prefix]}" + op_counter[op_prefix] += 1 + + diffinfo_dict[key] = DifferentiabilityInfo( + name=defn_name, + func=canonical, + op=op, + derivatives=derivatives, + forward_derivatives=forward_derivatives, + all_saved_inputs=dedup_vars( + [v for d in derivatives for v in d.saved_inputs] + ), + all_saved_outputs=dedup_vars( + [v for d in derivatives for v in d.saved_outputs] + ), + available_named_gradients=available_named_gradients, + used_named_gradients=used_named_gradients, + args_with_derivatives=args_with_derivatives, + non_differentiable_arg_names=non_differentiable_arg_names, + output_differentiability=output_differentiability, + output_differentiability_conditions=output_differentiability_conditions, + ) + + return canonical.func, diffinfo_dict + + +GRAD_INDEX_REGEX = r"(?:^|\W)grads\[(\d+)\]" + + +def used_gradient_indices(formula: str) -> list[int]: + """Determine a list of gradient indices (the i in grads[i]) that + are used by the formula. + + >>> used_gradient_indices("foo(grads[0], grads[1])") + [0, 1] + """ + return [int(i) for i in re.findall(GRAD_INDEX_REGEX, formula)] + + +def saved_variables( + formula: str, + nctypes: list[NamedCType], + var_names: tuple[str, ...], +) -> tuple[str, tuple[SavedAttribute, ...]]: + def stride_expr(name: str) -> str: + assert var_names == (name,), ( + 'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor ' + 'that ".strides()" is being called on.' + ) + return f'strides_or_error({name}, "{name}")' + + REPLACEMENTS: list[tuple[str, dict[str, Any]]] = [ + # replace self.sym_sizes() with self_sym_sizes + ( + r"{}.sym_sizes\(\)", + { + "suffix": "_sym_sizes", + "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)), + }, + ), + # replace self->sym_sizes() with self_sym_sizes_opt + ( + r"{}->sym_sizes\(\)", + { + "suffix": "_sym_sizes_opt", + "nctype": lambda name: NamedCType( + name, OptionalCType(BaseCType(symIntArrayRefT)) + ), + "expr": lambda name: f"{name}.has_value() ? 
std::optional({name}->sym_sizes()) : std::nullopt", + }, + ), + # replace self.sym_blocksize() with self_sym_blocksize_opt + ( + r"{}.sym_blocksize\(\)", + { + "suffix": "_self_sym_blocksize_opt", + "nctype": lambda name: NamedCType( + name, OptionalCType(BaseCType(symIntArrayRefT)) + ), + "expr": lambda name: f"at::sparse_csr::getSymIntBlockSize({name})", + }, + ), + # replace self.options() with self_options + ( + r"{}.options\(\)", + { + "suffix": "_options", + "nctype": lambda name: NamedCType(name, BaseCType(tensorOptionsT)), + }, + ), + # replace zeros_like(self) with self_info + ( + r"zeros_like\({}\)", + { + "suffix": "_info", + "nctype": lambda name: NamedCType(name, BaseCType(typeAndSizeT)), + "expr": lambda name: name, # at save-time + "res": lambda name: name + "_info.zeros()", # at eval-time + }, + ), + # replace self.sym_size(2) with self_sym_size_2 + ( + r"{}.sym_size\((-?\w+)\)", + { + "suffix": lambda m: f"_sym_argsize_{m.groups()[0].replace('-', 'minus_')}", + "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)), + }, + ), + # replace self.numel() with self_numel + ( + r"{}.numel\(\)", + { + "suffix": "_numel", + "nctype": lambda name: NamedCType(name, BaseCType(longT)), + }, + ), + # replace self.sym_numel() with self_sym_numel + ( + r"{}.sym_numel\(\)", + { + "suffix": "_sym_numel", + "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)), + }, + ), + # replace to_args_sizes(self) with self_args_sizes + ( + r"to_args_sizes\({}\)", + { + "suffix": "_args_sizes", + "nctype": lambda name: NamedCType( + name, VectorCType(VectorCType(BaseCType(longT))) + ), + }, + ), + # replace to_args_sizes_symint(self) with self_args_sizes + ( + r"to_args_sizes_symint\({}\)", + { + "suffix": "_args_sizes_symint", + "nctype": lambda name: NamedCType( + name, VectorCType(VectorCType(BaseCType(SymIntT))) + ), + }, + ), + # replace to_args_scalartypes(self) with self_args_scalartypes + ( + r"to_args_scalartypes\({}\)", + { + "suffix": "_args_scalartypes", + "nctype": lambda name: NamedCType( + name, VectorCType(BaseCType(scalarTypeT)) + ), + }, + ), + # replace TensorGeometry(self) with self_geometry + ( + r"TensorGeometry\({}\)", + { + "suffix": "_geometry", + "nctype": lambda name: NamedCType(name, BaseCType(tensorGeometryT)), + }, + ), + ( + r"{}.scalar_type\(\)", + { + "suffix": "_scalar_type", + "nctype": lambda name: NamedCType(name, BaseCType(scalarTypeT)), + }, + ), + # replace self.dim() with self_dim + ( + r"{}.dim\(\)", + { + "suffix": "_dim", + "nctype": lambda name: NamedCType(name, BaseCType(longT)), + }, + ), + # replace self.sym_strides() with self_sym_strides + ( + r"{}.sym_strides\(\)", + { + "suffix": "_sym_strides", + "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)), + "expr": stride_expr, + }, + ), + # replace self.layout() with self_layout + ( + r"{}.layout\(\)", + { + "suffix": "_layout", + "nctype": lambda name: NamedCType(name, BaseCType(layoutT)), + }, + ), + # replace self.is_conj() with self_conjugate + ( + r"{}.is_conj\(\)", + { + "suffix": "_conjugate", + "nctype": lambda name: NamedCType(name, BaseCType(boolT)), + }, + ), + ] + + # find which arguments need to be saved + saved: list[SavedAttribute] = [] + + if ".sizes()" in formula or "->sizes()" in formula: + raise RuntimeError( + ".sizes() is not supported in derivative formulas. Instead, please use the SymInt version," + + f".sym_sizes(), which returned a c10::SymIntArrayRef. 
formula={formula}" + ) + if re.search(r"\.size\([-]?\d+\)", formula) or re.search( + r"->size\([-]?\d+\)", formula + ): + raise RuntimeError( + ".size(int) is not supported in derivative formulas. Instead, please use the SymInt version," + + f".sym_size(int), which returned a c10::SymIntArrayRef. formula={formula}" + ) + if ".strides()" in formula or "->strides()" in formula: + raise RuntimeError( + ".strides() is not supported in derivative formulas. Instead, please use the SymInt version," + + f".sym_strides(), which returned a c10::SymIntArrayRef. formula={formula}" + ) + for nctype in nctypes: + name = ( + nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name + ) + # First search the formula for expressions which can be evaluated + # when the autograd Function is created to avoid saving variables + for regex, info in REPLACEMENTS: + + def repl(m: re.Match[str]) -> str: + suffix: str = ( + info["suffix"](m) if callable(info["suffix"]) else info["suffix"] + ) + expr: str = info["expr"](name) if "expr" in info else m.group(0) + saved.append( + SavedAttribute( + nctype=info["nctype"](name + suffix), + expr=expr, + ) + ) + if "res" in info: + replacement: str = info["res"](name) + return replacement + return name + suffix + + formula = re.sub(regex.format(name), repl, formula) + + # std::optional types stored in Backward nodes must be + # converted to std::optional before being passed into + # the backward function + if nctype.type == OptionalCType(BaseCType(stringT)): + formula = re.sub( + rf"\b{name}\b", + f"{name}.has_value() ? std::optional({name}.value()) : std::nullopt", + formula, + ) + + # Find any variables which remain in the formula and save them + if re.search(IDENT_REGEX.format(name), formula): + saved.append( + SavedAttribute( + nctype=nctype, + expr=name, + ) + ) + + return formula, tuple(saved) + + +def _create_op_prefix(name: str) -> str: + """Takes a native function name converts to a op prefix name. + + Note that the "name" parameter must be the native function name + without the optional variant suffix, so "add" instead of + "add.out". + + OP names correspond to classes, hence the change to title case. 
+ + Example:: + >>> _create_op_prefix('add') + 'AddBackward' + """ + camel_case = "".join([p.title() for p in name.split("_")]) + return (camel_case + "Backward").replace("ForwardBackward", "Backward") + + +def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]: + seen: set[str] = set() + saved: list[SavedAttribute] = [] + for var in vars: + name = ( + var.nctype.name.name + if isinstance(var.nctype.name, SpecialArgName) + else var.nctype.name + ) + if name in seen: + continue + seen.add(name) + saved.append(var) + return saved diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/ADInplaceOrViewType.cpp b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/ADInplaceOrViewType.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e8276697eee065a36d1b16e583a5f011f92541c2 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/ADInplaceOrViewType.cpp @@ -0,0 +1,38 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include "torch/csrc/autograd/VariableTypeUtils.h" +#include "torch/csrc/autograd/generated/ViewFuncs.h" + +#include +#include +#include + +// ${generated_comment} + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +$ops_headers +#endif + +using namespace at; +using torch::autograd::CreationMeta; +using torch::autograd::as_view; +using torch::autograd::increment_version; + +namespace torch { + +namespace ADInplaceOrView { + +namespace { +${inplace_or_view_method_definitions} +} // namespace +} // namespace ADInplaceOrView + +namespace { + +TORCH_LIBRARY_IMPL(aten, ADInplaceOrView, m) { + ${inplace_or_view_wrapper_registrations}; +} + +} // namespace +} // namespace torch diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/Functions.cpp b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/Functions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5bc089f67df74b300bc8de6568b702d48e0cb6c2 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/Functions.cpp @@ -0,0 +1,20 @@ +#include "torch/csrc/autograd/FunctionsManual.h" +#include "torch/csrc/dynamo/compiled_autograd.h" + +// ${generated_comment} + +// The manual function definitions that used to be here are now in torch/csrc/autograd/FunctionsManual.cpp +// This speeds up re-compilation and allow to share these implementations so that they can be +// used for forward mode AD formulas as well. 
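Circling back to saved_variables above: a self-contained sketch of the rewrite-and-record step that a single REPLACEMENTS rule performs (the .sym_sizes() rule here). The two-tuple records stand in for SavedAttribute and are a simplification of this sketch, not the real data structure.

import re

IDENT_REGEX = r"(^|\W){}($|\W)"  # local stand-in

def rewrite_and_save(formula: str, arg: str):
    saved = []  # (attribute name, expression evaluated at save time)

    def repl(m: re.Match) -> str:
        # record the saved attribute and substitute its name into the formula
        saved.append((arg + "_sym_sizes", m.group(0)))
        return arg + "_sym_sizes"

    # one REPLACEMENTS-style rule: <arg>.sym_sizes() -> <arg>_sym_sizes
    formula = re.sub(rf"{re.escape(arg)}\.sym_sizes\(\)", repl, formula)

    # anything still referring to the argument itself is saved verbatim
    if re.search(IDENT_REGEX.format(arg), formula):
        saved.append((arg, arg))
    return formula, saved

formula, saved = rewrite_and_save("grad.reshape_symint(self.sym_sizes()) * self", "self")
assert formula == "grad.reshape_symint(self_sym_sizes) * self"
assert saved == [("self_sym_sizes", "self.sym_sizes()"), ("self", "self")]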
+ +using namespace torch::autograd::generated::details; +using at::Tensor; +using at::Scalar; +using at::IntArrayRef; +using at::TensorList; + +namespace torch::autograd::generated { + +${autograd_function_definitions} + +} // namespace torch::autograd::generated diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/Functions.h b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/Functions.h new file mode 100644 index 0000000000000000000000000000000000000000..911d7d905c002b29941167ccff112a8079d48266 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/Functions.h @@ -0,0 +1,51 @@ +#pragma once + +// ${generated_comment} + +#include +#include +#include + +#include "torch/csrc/autograd/function.h" +#include "torch/csrc/autograd/variable.h" +#include "torch/csrc/autograd/saved_variable.h" +#include + +#include + +namespace torch { namespace autograd { namespace generated { + +using at::Scalar; +using at::Tensor; +using at::IntArrayRef; +using at::ArrayRef; +using at::Type; +using at::TensorGeometry; +using at::ScalarType; +using std::optional; +using c10::fmap; + +inline std::vector unpack_list(at::ArrayRef xs, std::shared_ptr saved_for = nullptr) { + // NB: we must explicitly do the conversion in the lambda, otherwise template + // deduction will give a Tensor of Variable which is not convertible + return fmap(xs, [&saved_for](const SavedVariable& x) { + // TODO(crcrpar): Use `std::move(saved_for)` to avoid incrementing refcount, which would need refactoring. + return static_cast(x.unpack(saved_for)); + }); +} + +inline c10::List> unpack_opt_list(at::ArrayRef xs, std::shared_ptr saved_for = nullptr) { + torch::List> result; + result.reserve(xs.size()); + for (const SavedVariable& v : xs) { + auto var = v.unpack(saved_for); + result.push_back(var.defined() ? std::optional(var) : ::std::nullopt); + } + return result; +} + +using torch::autograd::TypeAndSize; + +${autograd_function_declarations} + +}}} // namespace torch::autograd::generated diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/TraceType.cpp b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/TraceType.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fb5e7ae44a5353a3cc2a90858fe33b7fc0ef8bfd --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/TraceType.cpp @@ -0,0 +1,40 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include "torch/csrc/jit/frontend/tracer.h" + +#include + +#include "torch/csrc/autograd/function.h" + +#include "ATen/quantized/Quantizer.h" + +// ${generated_comment} + +// See the `Tracer` section in `torch/csrc/jit/OVERVIEW.md`. 
+// NOTE See [Sharded File] comment in VariableType + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +$ops_headers +#endif + +using namespace at; + +namespace torch { + +namespace TraceType { + +namespace { +${trace_method_definitions} +} // namespace +} // namespace TraceType + +namespace { + +TORCH_LIBRARY_IMPL(aten, Tracer, m) { + ${trace_wrapper_registrations}; +} + +} // namespace + +} // namespace torch diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/VariableType.cpp b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/VariableType.cpp new file mode 100644 index 0000000000000000000000000000000000000000..08f1f8b698e528ca382ead2fb64ee0a45a708b08 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/VariableType.cpp @@ -0,0 +1,65 @@ +#include "torch/csrc/autograd/VariableTypeUtils.h" +#include "torch/csrc/autograd/generated/VariableType.h" +#include "torch/csrc/autograd/FunctionsManual.h" + +#include +#include +#include +#include + +#include + + +// ${generated_comment} + +// NOTE [Sharded File]: on this file's split-into-shards state +// +// Back in the good old days, VariableType.cpp was generated as one +// file with every function in it, and everything was great and +// simple. +// +// However, this file was also very large (over 36,000 lines), and +// compiling it was very slow, and in fact was a significant +// bottleneck for incremental rebuilds. To address this, we now +// generate the file split across multiple shards, named +// VariableType_0.cpp and so on, which can be compiled in parallel. +// +// For ease of inspection and debugging, so that it's not necessary to +// go rooting around in multiple files, we also generate all the +// functions together in VariableTypeEverything.cpp. This generated +// file is only for convenience; it's not actually used in the +// build. If the file you're looking at now is one of the shards, you +// may want to switch over to the Everything variant to make you +// grepping smoother. 
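As a rough illustration of the sharded layout described in the note above. The actual split is done by torchgen's file-writing machinery; the round-robin policy, shard count, and helper name below are arbitrary choices of this sketch.

def shard_definitions(definitions, num_shards=2):
    # distribute generated definitions across VariableType_<i>.cpp shards and
    # also emit the convenience "Everything" file described in the note above
    shards = [[] for _ in range(num_shards)]
    for i, d in enumerate(definitions):
        shards[i % num_shards].append(d)
    files = {f"VariableType_{i}.cpp": "\n".join(s) for i, s in enumerate(shards)}
    files["VariableTypeEverything.cpp"] = "\n".join(definitions)
    return files

files = shard_definitions([f"// definition {i}" for i in range(5)])
assert sorted(files) == ["VariableTypeEverything.cpp", "VariableType_0.cpp", "VariableType_1.cpp"]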
+ +using namespace at; +using namespace torch::autograd::generated; +using namespace torch::autograd::generated::details; + + +namespace torch::autograd { + +namespace VariableType { +namespace{ + C10_UNUSED void reset_grad_accumulator(Variable & self) { + AutogradMeta* meta = torch::autograd::impl::get_autograd_meta(self); + if (meta != nullptr) { + meta->grad_accumulator_.reset(); + } + } +} + +namespace { + + +${type_derived_method_definitions} +} +} + +namespace { + +${wrapper_registrations} + +} + +} // namespace torch::autograd diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/VariableType.h b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/VariableType.h new file mode 100644 index 0000000000000000000000000000000000000000..08da173f94bf868517ed6a52fd449e6f144904ce --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/VariableType.h @@ -0,0 +1,59 @@ +#pragma once + +// ${generated_comment} + +#include +#include + +#include + +#include +#include + +#include // for size_t +#include // for function +#include // for unique_ptr +#include +#include + +namespace at { + struct Quantizer; +}; + +namespace torch { namespace autograd { + +using Variable = at::Tensor; +using at::Context; +using at::Device; +using at::Dimname; +using at::DimnameList; +using at::Generator; +using at::IntArrayRef; +using at::MemoryFormat; +using at::QScheme; +using at::Scalar; +using at::ScalarType; +using at::Storage; +using at::Tensor; +using at::TensorList; +using at::TensorOptions; +using at::Quantizer; +// This is temporary typedef to enable Quantizer in aten native function API +// we'll remove them when we are actually exposing Quantizer class +// to frontend +using ConstQuantizerPtr = const c10::intrusive_ptr&; +using std::optional; + +namespace VariableType { + TORCH_API std::vector allCUDATypes(); + TORCH_API std::vector allXPUTypes(); + TORCH_API std::vector allCPUTypes(); + TORCH_API std::vector allPrivateUser1Types(); + + at::Tensor & unpack(Tensor & t, const char * name, int pos); + const at::Tensor & unpack(const Tensor & t, const char * name, int pos); + at::Tensor unpack_opt(const Tensor & t, const char * name, int pos); + std::vector unpack(const at::ITensorListRef& tl, const char *name, int pos); +}; + +}} // namespace torch::autograd diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/ViewFuncs.cpp b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/ViewFuncs.cpp new file mode 100644 index 0000000000000000000000000000000000000000..11b9b194fb46f924e863c4c1dab5cbb8dbb0601b --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/ViewFuncs.cpp @@ -0,0 +1,14 @@ +#include + +// ${generated_comment} + +using at::Tensor; +using at::Scalar; +using at::IntArrayRef; +using at::TensorList; + +namespace torch::autograd::generated { + +${view_func_definitions} + +} // namespace torch::autograd::generated diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/ViewFuncs.h b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/ViewFuncs.h new file mode 100644 index 0000000000000000000000000000000000000000..1f69c062d344e4cd5f98cf5f34fd4278019fdf8a --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/ViewFuncs.h @@ -0,0 +1,28 @@ +#pragma once + +// ${generated_comment} + +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +$ops_headers +#endif + +namespace 
torch::autograd::generated { + +using at::Scalar; +using at::Tensor; +using at::IntArrayRef; +using at::ArrayRef; +using at::Type; +using at::ScalarType; +using std::optional; +using c10::fmap; + +${view_func_declarations} + +} // namespace torch::autograd::generated diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/annotated_fn_args.py.in b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/annotated_fn_args.py.in new file mode 100644 index 0000000000000000000000000000000000000000..1012c008451745b8f1ed1454a864f666caf2618a --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/annotated_fn_args.py.in @@ -0,0 +1,11 @@ +""" +This file is needed for generating procedural tests required for +testing __torch_function__. See tests/test_overrides.py. +""" + +# flake8: noqa +import torch + +annotated_args = { +${annotated_args} +} diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_enum_tag.cpp b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_enum_tag.cpp new file mode 100644 index 0000000000000000000000000000000000000000..83cfad1d7ba4d6fc3529caf78e036c5883e7bc23 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_enum_tag.cpp @@ -0,0 +1,15 @@ +#include +#include +#include +#include + +namespace py = pybind11; +namespace torch { + namespace autograd { + void initEnumTag(PyObject* module) { + auto m = py::handle(module).cast(); + py::enum_(m, "Tag") + ${enum_of_valid_tags}; + m.doc() = "An Enum that contains tags that can be assigned to an operator registered in C++."; + } +}} diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_fft_functions.cpp b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_fft_functions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..71ac4e2226d2db418eba5690995424d3f007e620 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_fft_functions.cpp @@ -0,0 +1,81 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// ${generated_comment} + +#include "torch/csrc/Device.h" +#include "torch/csrc/DynamicTypes.h" +#include "torch/csrc/Exceptions.h" +#include "torch/csrc/autograd/python_fft_functions.h" +#include "torch/csrc/autograd/generated/python_return_types.h" +#include "torch/csrc/autograd/python_variable.h" +#include "torch/csrc/autograd/utils/wrap_outputs.h" +#include "torch/csrc/autograd/utils/python_arg_parsing.h" +#include "torch/csrc/autograd/generated/variable_factories.h" +#include "torch/csrc/utils/out_types.h" +#include "torch/csrc/utils/pycfunction_helpers.h" +#include "torch/csrc/utils/python_arg_parser.h" +#include "torch/csrc/utils/structseq.h" +#include "torch/csrc/utils/device_lazy_init.h" + +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +$ops_headers +#endif + +using at::Tensor; +using at::Device; +using at::Layout; +using at::Scalar; +using at::ScalarType; +using at::Backend; +using at::OptionalDeviceGuard; +using at::DeviceGuard; +using at::TensorOptions; +using at::IntArrayRef; +using at::Generator; +using at::TensorList; +using at::Dimname; +using at::DimnameList; + +using torch::utils::check_out_type_matches; +using namespace torch::autograd::utils; + +namespace torch::autograd { + +// generated forward declarations start here + +${py_forwards} + +static PyMethodDef fft_functions[] = { + ${py_method_defs} + {NULL} +}; + +static PyObject* 
THPFFTVariableFunctionsModule = NULL; + +void initFFTFunctions(PyObject* module) { + static struct PyModuleDef def = { + PyModuleDef_HEAD_INIT, + "torch._C._fft", + NULL, + -1, + fft_functions + }; + PyObject* fft = PyModule_Create(&def); + THPFFTVariableFunctionsModule = fft; + if (!fft) { + throw python_error(); + } + // steals a reference to fft + if (PyModule_AddObject(module, "_fft", fft) != 0) { + throw python_error(); + } +} + +// generated methods start here + +${py_methods} + +} // namespace torch::autograd diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_functions.cpp b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_functions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1522d6cd0f5a2a1fc0188bf9d6d0d59fe1b27d85 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_functions.cpp @@ -0,0 +1,37 @@ +#include + +// ${generated_comment} + +#include +#include + +#include +#include "torch/csrc/autograd/generated/Functions.h" +#include "torch/csrc/autograd/python_cpp_function.h" +#include +#include +#include +#include +#include + +// NOTE: See [Sharded File] comment in VariableType + +namespace torch::autograd::generated { + +template +static void addClass(PyObject* module, PyTypeObject& type, const char* name, + PyGetSetDef* function_properties=NULL, PyMethodDef* function_methods=NULL) +{ + _initFunctionPyTypeObject(type, name, function_properties, function_methods); + Py_INCREF(&type); + PyModule_AddObject(module, name, (PyObject*)&type); + registerCppFunction(typeid(C), &type); +} + +${py_function_props_and_getters} + +void initialize_autogenerated_functions${shard_id}(PyObject* module) { + ${py_function_initializers} +} + +} // namespace torch::autograd::generated diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_functions.h b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_functions.h new file mode 100644 index 0000000000000000000000000000000000000000..22e37207e219431100fefaf21b02e3ed0f63d956 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_functions.h @@ -0,0 +1,17 @@ +#pragma once + +#include + +// ${generated_comment} + +// Python bindings for automatically generated autograd functions + +namespace torch { namespace autograd { namespace generated { + +${shard_forward_declare} + +inline void initialize_autogenerated_functions(PyObject* module) { + ${shard_call} +} + +}}} // namespace torch::autograd::generated diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_linalg_functions.cpp b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_linalg_functions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c93752a3ddbfcf111426f98c3ea68fc625e94def --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_linalg_functions.cpp @@ -0,0 +1,68 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// ${generated_comment} + +#include "torch/csrc/Device.h" +#include "torch/csrc/DynamicTypes.h" +#include "torch/csrc/Exceptions.h" +#include "torch/csrc/autograd/python_linalg_functions.h" +#include "torch/csrc/autograd/generated/python_return_types.h" +#include "torch/csrc/autograd/python_variable.h" +#include "torch/csrc/autograd/utils/wrap_outputs.h" +#include "torch/csrc/autograd/utils/python_arg_parsing.h" +#include "torch/csrc/utils/pycfunction_helpers.h" 
+#include "torch/csrc/utils/python_arg_parser.h" +#include "torch/csrc/utils/structseq.h" + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +$ops_headers +#endif + +using at::Tensor; +using at::Scalar; +using at::ScalarType; +using at::MemoryFormat; +using at::Generator; +using at::IntArrayRef; +using at::TensorList; + +using namespace torch::autograd::utils; + +namespace torch::autograd { + +// generated forward declarations start here + +${py_forwards} + +static PyMethodDef linalg_functions[] = { + ${py_method_defs} + {NULL} +}; + +static PyObject* THPLinalgVariableFunctionsModule = NULL; + +void initLinalgFunctions(PyObject* module) { + static struct PyModuleDef def = { + PyModuleDef_HEAD_INIT, + "torch._C._linalg", + NULL, + -1, + linalg_functions + }; + PyObject* linalg = PyModule_Create(&def); + THPLinalgVariableFunctionsModule = linalg; + if (!linalg) { + throw python_error(); + } + // steals a reference to linalg + if (PyModule_AddObject(module, "_linalg", linalg) != 0) { + throw python_error(); + } +} + +// generated methods start here + +${py_methods} + +} // namespace torch::autograd diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_nested_functions.cpp b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_nested_functions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3acb5128cee1e180de887080106e7cf5559f15ee --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_nested_functions.cpp @@ -0,0 +1,81 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// ${generated_comment} + +#include "torch/csrc/Device.h" +#include "torch/csrc/DynamicTypes.h" +#include "torch/csrc/Exceptions.h" +#include "torch/csrc/autograd/python_nested_functions.h" +#include "torch/csrc/autograd/generated/python_return_types.h" +#include "torch/csrc/autograd/python_variable.h" +#include "torch/csrc/autograd/utils/wrap_outputs.h" +#include "torch/csrc/autograd/utils/python_arg_parsing.h" +#include "torch/csrc/autograd/generated/variable_factories.h" +#include "torch/csrc/utils/out_types.h" +#include "torch/csrc/utils/pycfunction_helpers.h" +#include "torch/csrc/utils/python_arg_parser.h" +#include "torch/csrc/utils/structseq.h" +#include "torch/csrc/utils/device_lazy_init.h" + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +$ops_headers +#endif + +using at::Tensor; +using at::Device; +using at::Layout; +using at::Scalar; +using at::ScalarType; +using at::Backend; +using at::OptionalDeviceGuard; +using at::DeviceGuard; +using at::TensorOptions; +using at::IntArrayRef; +using at::OptionalIntArrayRef; +using at::Generator; +using at::TensorList; +using at::Dimname; +using at::DimnameList; + +using namespace torch::autograd::utils; + +namespace torch::autograd { + +// generated forward declarations start here + +${py_forwards} + +static PyMethodDef nested_functions[] = { + {NULL, NULL, 0, NULL}, + ${py_method_defs} + {NULL} +}; + +static PyObject* THPNestedVariableFunctionsModule = NULL; + +void initNestedFunctions(PyObject* module) { + nested_functions[0] = get_nested_functions_manual()[0]; + static struct PyModuleDef def = { + PyModuleDef_HEAD_INIT, + "torch._C._nested", + NULL, + -1, + nested_functions + }; + PyObject* nested = PyModule_Create(&def); + THPNestedVariableFunctionsModule = nested; + if (!nested) { + throw python_error(); + } + // steals a reference to nested + if (PyModule_AddObject(module, "_nested", nested) != 0) { + throw python_error(); + } +} + +// generated methods start 
here + +${py_methods} + +} // namespace torch::autograd diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_nn_functions.cpp b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_nn_functions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4877df6584bd6702f259f0797e2ff45d3c719bd3 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_nn_functions.cpp @@ -0,0 +1,113 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// ${generated_comment} + +#include "torch/csrc/Device.h" +#include "torch/csrc/DynamicTypes.h" +#include "torch/csrc/Exceptions.h" +#include "torch/csrc/autograd/python_nn_functions.h" +#include "torch/csrc/autograd/generated/python_return_types.h" +#include "torch/csrc/autograd/python_variable.h" +#include "torch/csrc/autograd/utils/wrap_outputs.h" +#include "torch/csrc/autograd/utils/python_arg_parsing.h" +#include "torch/csrc/utils/pycfunction_helpers.h" +#include "torch/csrc/utils/python_arg_parser.h" +#include "torch/csrc/utils/structseq.h" +#include "torch/csrc/utils/tensor_memoryformats.h" + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +$ops_headers +#endif + +using at::Tensor; +using at::Scalar; +using at::MemoryFormat; +using at::Generator; +using at::IntArrayRef; +using at::ArrayRef; + +using namespace torch::autograd::utils; + +namespace torch::autograd { + +static PyObject* THPNNVariableFunctionsModule = NULL; + +static PyObject * THPVariable__parse_to(PyObject* module, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "to(Device device=None, ScalarType dtype=None, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)", + "to(ScalarType dtype, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)", + "to(Tensor tensor, bool non_blocking=False, bool copy=False, *, MemoryFormat? 
memory_format=None)", + }); + ParsedArgs<5> parsed_args; + auto r = parser.parse(args, kwargs, parsed_args); + if (r.has_torch_function()) { + return handle_torch_function(r, args, kwargs, THPNNVariableFunctionsModule, "torch.nn", "_parse_to"); + } + auto parsed = parse_to_conversion(r, /*allow_copy*/ false); // we don't want copy for nn.Module.to + auto& device = std::get<0>(parsed); + auto& scalarType = std::get<1>(parsed); + auto non_blocking = std::get<2>(parsed); + auto opt_memory_format = std::get<4>(parsed); + auto tuple = THPObjectPtr{PyTuple_New(4)}; + if (!tuple) throw python_error(); + if (device) { + PyTuple_SET_ITEM(tuple.get(), 0, THPDevice_New(*device)); + } else { + Py_INCREF(Py_None); + PyTuple_SET_ITEM(tuple.get(), 0, Py_None); + } + if (scalarType) { + PyTuple_SET_ITEM(tuple.get(), 1, Py_NewRef(torch::getTHPDtype(*scalarType))); + } else { + Py_INCREF(Py_None); + PyTuple_SET_ITEM(tuple.get(), 1, Py_None); + } + PyTuple_SET_ITEM(tuple.get(), 2, torch::autograd::utils::wrap(non_blocking)); + if (opt_memory_format.has_value()) { + PyTuple_SET_ITEM(tuple.get(), 3, Py_NewRef(torch::utils::getTHPMemoryFormat(opt_memory_format.value()))); + } else { + Py_INCREF(Py_None); + PyTuple_SET_ITEM(tuple.get(), 3, Py_None); + } + return tuple.release(); + END_HANDLE_TH_ERRORS +} + +// generated forward declarations start here + +${py_forwards} + +static PyMethodDef nn_functions[] = { + {"_parse_to", castPyCFunctionWithKeywords(THPVariable__parse_to), + METH_VARARGS | METH_KEYWORDS, nullptr}, + ${py_method_defs} + {NULL} +}; + +void initNNFunctions(PyObject* module) { + static struct PyModuleDef def = { + PyModuleDef_HEAD_INIT, + "torch._C._nn", + NULL, + -1, + nn_functions + }; + PyObject* nn = PyModule_Create(&def); + THPNNVariableFunctionsModule = nn; + if (!nn) { + throw python_error(); + } + // steals a reference to nn + if (PyModule_AddObject(module, "_nn", nn) != 0) { + throw python_error(); + } +} + +// generated methods start here + +${py_methods} + +} // namespace torch::autograd diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_return_types.cpp b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_return_types.cpp new file mode 100644 index 0000000000000000000000000000000000000000..139e6b8958336cfcc8328fa33581e9f1ab6d5532 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_return_types.cpp @@ -0,0 +1,52 @@ +#include + +#include +#include +#include + +#include "torch/csrc/autograd/generated/python_return_types.h" +#include "torch/csrc/utils/structseq.h" +#include "torch/csrc/Exceptions.h" + +namespace torch { namespace autograd { namespace generated { + +${py_return_types} + +}}} + +namespace torch::autograd { + +static void addReturnType( + PyObject* module, + const char* name, + PyTypeObject* type) { + // hold onto the TypeObject for the unlikely case of user + // deleting or overriding it. 
+ Py_INCREF(type); + if (PyModule_AddObject( + module, + name, + (PyObject*)type) != 0) { + Py_DECREF(type); + throw python_error(); + } +} + +void initReturnTypes(PyObject* module) { + static struct PyModuleDef def = { + PyModuleDef_HEAD_INIT, "torch._C._return_types", nullptr, -1, {}}; + PyObject* return_types_module = PyModule_Create(&def); + if (!return_types_module) { + throw python_error(); + } + + ${py_return_types_registrations} + + // steals a reference to return_types on success + if (PyModule_AddObject(module, "_return_types", return_types_module) != 0) { + Py_DECREF(return_types_module); + throw python_error(); + } +} + +} // namespace torch::autograd diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_return_types.h b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_return_types.h new file mode 100644 index 0000000000000000000000000000000000000000..ce6c355ea146a272709255b898603764112168b9 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_return_types.h @@ -0,0 +1,14 @@ +#pragma once + +namespace torch { +namespace autograd { +namespace generated { + +${py_return_types_declarations} + +} + +void initReturnTypes(PyObject* module); + +} // namespace autograd +} // namespace torch diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_sparse_functions.cpp b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_sparse_functions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..648d91442102e9b950cb2ddb8db545c4b4e1100e --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_sparse_functions.cpp @@ -0,0 +1,67 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// ${generated_comment} + +#include "torch/csrc/Device.h" +#include "torch/csrc/DynamicTypes.h" +#include "torch/csrc/Exceptions.h" +#include "torch/csrc/autograd/python_sparse_functions.h" +#include "torch/csrc/autograd/python_variable.h" +#include "torch/csrc/autograd/utils/wrap_outputs.h" +#include "torch/csrc/autograd/utils/python_arg_parsing.h" +#include "torch/csrc/utils/pycfunction_helpers.h" +#include "torch/csrc/utils/python_arg_parser.h" +#include "torch/csrc/utils/structseq.h" + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +$ops_headers +#endif + +using at::Tensor; +using at::Scalar; +using at::ScalarType; +using at::MemoryFormat; +using at::Generator; +using at::IntArrayRef; +using at::TensorList; + +using namespace torch::autograd::utils; + +namespace torch::autograd { + +// generated forward declarations start here + +${py_forwards} + +static PyMethodDef sparse_functions[] = { + ${py_method_defs} + {NULL} +}; + +static PyObject* THPSparseVariableFunctionsModule = NULL; + +void initSparseFunctions(PyObject* module) { + static struct PyModuleDef def = { + PyModuleDef_HEAD_INIT, + "torch._C._sparse", + NULL, + -1, + sparse_functions + }; + PyObject* sparse = PyModule_Create(&def); + THPSparseVariableFunctionsModule = sparse; + if (!sparse) { + throw python_error(); + } + // steals a reference to sparse + if (PyModule_AddObject(module, "_sparse", sparse) != 0) { + throw python_error(); + } +} + +// generated methods start here + +${py_methods} + +} // namespace torch::autograd diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_special_functions.cpp b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_special_functions.cpp new file mode 100644 index 
0000000000000000000000000000000000000000..bf9e109b4a77352cd85ba828b97d67d329543867 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_special_functions.cpp @@ -0,0 +1,79 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// ${generated_comment} + +#include "torch/csrc/Device.h" +#include "torch/csrc/DynamicTypes.h" +#include "torch/csrc/Exceptions.h" +#include "torch/csrc/autograd/python_special_functions.h" +#include "torch/csrc/autograd/generated/python_return_types.h" +#include "torch/csrc/autograd/python_variable.h" +#include "torch/csrc/autograd/utils/wrap_outputs.h" +#include "torch/csrc/autograd/utils/python_arg_parsing.h" +#include "torch/csrc/autograd/generated/variable_factories.h" +#include "torch/csrc/utils/out_types.h" +#include "torch/csrc/utils/pycfunction_helpers.h" +#include "torch/csrc/utils/python_arg_parser.h" +#include "torch/csrc/utils/structseq.h" +#include "torch/csrc/utils/device_lazy_init.h" + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +$ops_headers +#endif + +using at::Tensor; +using at::Device; +using at::Layout; +using at::Scalar; +using at::ScalarType; +using at::Backend; +using at::OptionalDeviceGuard; +using at::DeviceGuard; +using at::TensorOptions; +using at::IntArrayRef; +using at::Generator; +using at::TensorList; +using at::Dimname; +using at::DimnameList; + +using torch::utils::check_out_type_matches; +using namespace torch::autograd::utils; + +namespace torch::autograd { + +// generated forward declarations start here + +${py_forwards} + +static PyMethodDef special_functions[] = { + ${py_method_defs} + {NULL} +}; + +static PyObject* THPSpecialVariableFunctionsModule = NULL; + +void initSpecialFunctions(PyObject* module) { + static struct PyModuleDef def = { + PyModuleDef_HEAD_INIT, + "torch._C._special", + NULL, + -1, + special_functions + }; + PyObject* special = PyModule_Create(&def); + THPSpecialVariableFunctionsModule = special; + if (!special) { + throw python_error(); + } + // steals a reference to special + if (PyModule_AddObject(module, "_special", special) != 0) { + throw python_error(); + } +} + +// generated methods start here + +${py_methods} + +} // namespace torch::autograd diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_torch_functions.cpp b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_torch_functions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c17d1040e1892b6a215a8c4264fe5a5345265bc7 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_torch_functions.cpp @@ -0,0 +1,93 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// ${generated_comment} + +// Python bindings for torch.* functions implemented through ATen. +// +// The functions are bound as static methods on a class +// torch._C._VariableFunctions which is also aliased as Variable._torch +// and also copied into 'torch' module. 
+ +#include + +// Undefine the copysign macro so that at::copysign works as intended with MSVC +// https://github.com/python/cpython/blob/c60394c7fc9cc09b16e9675a3eeb5844b6d8523f/PC/pyconfig.h#L196 +#ifdef _MSC_VER +#undef copysign +#endif // _MSC_VER + +#include "torch/csrc/autograd/python_torch_functions.h" +#include "torch/csrc/autograd/python_variable.h" +#include "torch/csrc/autograd/utils/wrap_outputs.h" +#include "torch/csrc/Dtype.h" +#include "torch/csrc/DynamicTypes.h" +#include "torch/csrc/Exceptions.h" +#include "torch/csrc/utils/out_types.h" +#include "torch/csrc/utils/pybind.h" +#include "torch/csrc/utils/pycfunction_helpers.h" +#include "torch/csrc/utils/python_arg_parser.h" +#include "torch/csrc/utils/tensor_layouts.h" +#include "torch/csrc/utils/tensor_new.h" +#include "torch/csrc/utils/tensor_numpy.h" +#include "torch/csrc/jit/frontend/tracer.h" +#include "torch/csrc/autograd/generated/variable_factories.h" +#include "torch/csrc/utils/structseq.h" +#include "torch/csrc/utils/device_lazy_init.h" +#include "torch/csrc/autograd/generated/python_return_types.h" + +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +$ops_headers +#endif + +#include +#include +#include +#include + +using at::Tensor; +using at::Device; +using at::Layout; +using at::Scalar; +using at::ScalarType; +using at::Backend; +using at::OptionalDeviceGuard; +using at::DeviceGuard; +using at::TensorOptions; +using at::IntArrayRef; +using at::Generator; +using at::TensorList; +using at::Dimname; +using at::DimnameList; +using at::ArrayRef; + +using torch::utils::check_out_type_matches; +using namespace torch::autograd::utils; + +// NOTE: See [Sharded File] comment in VariableType + +namespace torch::autograd { + +// generated forward declarations start here + +${py_forwards} + +static PyMethodDef torch_functions_shard[] = { + ${py_method_defs} +}; + +void gatherTorchFunctions${shard_id}(std::vector &torch_functions) { + constexpr size_t num_functions = sizeof(torch_functions_shard) / sizeof(torch_functions_shard[0]); + torch_functions.insert( + torch_functions.end(), + torch_functions_shard, + torch_functions_shard + num_functions); +} + +// generated methods start here + +${py_methods} + +} // namespace torch::autograd diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_variable_methods.cpp b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_variable_methods.cpp new file mode 100644 index 0000000000000000000000000000000000000000..16c3b9e5efd6a6eab58f4d29557386be6f893a2c --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/python_variable_methods.cpp @@ -0,0 +1,1333 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// ${generated_comment} + +#include + +// Undefine the copysign macro so that at::copysign works as intended with MSVC +// https://github.com/python/cpython/blob/c60394c7fc9cc09b16e9675a3eeb5844b6d8523f/PC/pyconfig.h#L196 +#ifdef _MSC_VER +#undef copysign +#endif // _MSC_VER + +#include "torch/csrc/DynamicTypes.h" +#include "torch/csrc/Exceptions.h" +#include "torch/csrc/Size.h" +#include "torch/csrc/autograd/generated/VariableType.h" +#include "torch/csrc/autograd/python_variable.h" +#include "torch/csrc/autograd/utils/python_arg_parsing.h" +#include "torch/csrc/autograd/utils/error_messages.h" +#include "torch/csrc/autograd/utils/wrap_outputs.h" +#include "torch/csrc/jit/frontend/tracer.h" +#ifdef USE_CUDA +#include "torch/csrc/cuda/Event.h" +#endif +#include "torch/csrc/utils/device_lazy_init.h" 
+#include +#include "torch/csrc/utils/object_ptr.h" +#include "torch/csrc/utils/pycfunction_helpers.h" +#include "torch/csrc/utils/python_arg_parser.h" +#include "torch/csrc/utils/python_numbers.h" +#include "torch/csrc/utils/python_strings.h" +#include "torch/csrc/utils/python_tuples.h" +#include "torch/csrc/utils/tensor_apply.h" +#include "torch/csrc/utils/tensor_list.h" +#include "torch/csrc/utils/tensor_new.h" +#include "torch/csrc/utils/tensor_numpy.h" +#include "torch/csrc/utils/tensor_types.h" +#include "torch/csrc/utils/structseq.h" +#include "torch/csrc/autograd/generated/python_return_types.h" + +#include +#include +#include "c10/util/Optional.h" +#include "c10/core/Stream.h" + +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +$ops_headers +#include +#endif + +using at::DeviceGuard; +using at::device_of; +using at::OptionalDeviceGuard; +using at::Backend; +using at::Scalar; +using at::ScalarType; +using at::Tensor; +using c10::Stream; +using namespace torch::autograd::utils; + +namespace torch::autograd { + +static PyObject * THPVariable__is_view(PyObject *self, PyObject* args) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "_is_view", args); + } + auto& self_ = THPVariable_Unpack(self); + if (self_.is_view()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } + END_HANDLE_TH_ERRORS +} + +// implemented on the python object bc no support for first-class functions in native_functions.yaml +// See: ATen/native/README.md for more context +static PyObject * THPVariable_apply_(PyObject* self, PyObject* arg) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + auto args = py::make_tuple(py::handle(arg)); + return handle_torch_function(self, "apply_", args.ptr()); + } + auto& self_ = THPVariable_Unpack(self); + if (self_.requires_grad()) { + throw std::runtime_error( + "Can't call apply_() on Variable that requires grad. Use " + "var.detach().apply_() instead."); + } + return THPVariable_Wrap(torch::utils::apply_(self_, arg)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "size(int64_t? dim=None)", + "size(Dimname dim)", + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<3> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + if (r.idx == 0) { + if (!r.toInt64Optional(0).has_value()) { + return THPSize_NewFromSymSizes(self_); + } + if (jit::tracer::isTracing()) { + // will error out if a tensor has symints + return wrap(jit::tracer::getSizeOf(self_, r.toInt64(0))); + } else { + return torch::toPyObject(self_.sym_size(r.toInt64(0))); + } + } else if (r.idx == 1) { + if (jit::tracer::isTracing()) { + TORCH_INTERNAL_ASSERT(false, "NYI: Named tensors w/ JIT"); + } + return wrap(self_.size(r.dimname(0))); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_stride(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "stride(int64_t? 
dim=None)", + "stride(Dimname dim)", + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<3> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + if (r.idx == 0) { + if (r.toInt64Optional(0).has_value()) { + return torch::toPyObject(self_.sym_stride(r.toInt64(0))); + } + // yes, this is called strides in ATen. + at::SymIntArrayRef strides = self_.sym_strides(); + // we can't do the normal wrapping here because IntArrayRef maps to both + // torch.Size and tuple in python + // TODO: consider factoring this out + THPObjectPtr tuple(PyTuple_New(strides.size())); + if (!tuple) throw python_error(); + for (size_t i = 0; i != strides.size(); i++) { + PyObject* s = torch::toPyObject(strides[i]); + if (!s) throw python_error(); + PyTuple_SET_ITEM(tuple.get(), i, s); + } + return tuple.release(); + } else if (r.idx == 1) { + return wrap(self_.stride(r.dimname(0))); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +// implemented on the python object to avoid dispatch overhead +static PyObject * THPVariable_get_device(PyObject* self_, PyObject* args) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self_)) { + return handle_torch_function(self_, "get_device", args, nullptr); + } + auto& self = THPVariable_Unpack(self_); + return wrap(self.get_device()); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_has_names(PyObject* self_, PyObject* args) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self_)) { + return handle_torch_function(self_, "has_names", args); + } + auto& self = THPVariable_Unpack(self_); + return wrap(self.has_names()); + END_HANDLE_TH_ERRORS +} + +// implemented on the python object to avoid dispatch overhead +static PyObject * THPVariable_data_ptr(PyObject* self_, PyObject* args) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self_)) { + return handle_torch_function(self_, "data_ptr", args); + } + auto& self = THPVariable_Unpack(self_); + return wrap(self.data_ptr()); + END_HANDLE_TH_ERRORS +} + +// implemented on the python object to avoid dispatch overhead +static PyObject * THPVariable_storage_offset(PyObject* self_, PyObject* args) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self_)) { + return handle_torch_function(self_, "storage_offset"); + } + auto& self = THPVariable_Unpack(self_); + return py::cast(self.sym_storage_offset()).release().ptr(); + END_HANDLE_TH_ERRORS +} + +// implemented on the python object to avoid dispatch overhead +static PyObject * THPVariable_dim(PyObject* self, PyObject* args) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "dim", args); + } + auto& self_ = THPVariable_Unpack(self); + return THPUtils_packInt64(self_.dim()); + END_HANDLE_TH_ERRORS +} + +// implemented on the python object to avoid dispatch overhead +static PyObject * THPVariable_numel(PyObject* self, PyObject* args) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "numel", args); + } + auto& self_ = THPVariable_Unpack(self); + if (jit::tracer::isTracing()) { + return wrap(jit::tracer::getNumelOf(self_)); + } else { + return py::cast(self_.sym_numel()).release().ptr(); + } + END_HANDLE_TH_ERRORS +} + +static Tensor dispatch_contiguous(const Tensor & self, at::MemoryFormat memory_format) { + pybind11::gil_scoped_release no_gil; + OptionalDeviceGuard device_guard(device_of(self)); + return 
self.contiguous(memory_format); +} + +static PyObject * THPVariable_contiguous(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "contiguous(*, MemoryFormat memory_format=contiguous_format)", + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto& self_ = THPVariable_Unpack(self); + auto memory_format = r.memoryformat(0); + // avoids touching the GIL or current device if self is already contiguous + if (self_.is_contiguous(memory_format)) { + // NOTE: this logic is duplicated from VariableType.cpp. Since we need to + // record this call to contiguous() in the trace regardless of whether + // we actually call contiguous here, we need to record this information + // manually. + if (jit::tracer::isTracing()) { + auto tracer_state = jit::tracer::getTracingState(); + auto op_name = c10::Symbol::fromQualString("aten::contiguous"); + auto node = tracer_state->createNode(op_name, /*num_outputs=*/0); + jit::tracer::recordSourceLocation(node); + jit::tracer::addInputs(node, "self", self_); + jit::tracer::addInputs(node, "memory_format", memory_format); + tracer_state->insertNode(node); + jit::tracer::addOutput(node, self_); + } + Py_INCREF(self); + return self; + } + return THPVariable_Wrap(dispatch_contiguous(self_, memory_format)); + END_HANDLE_TH_ERRORS +} + +static Tensor dispatch_copy_(const Tensor & self, const Tensor & other, bool non_blocking) { + pybind11::gil_scoped_release no_gil; + OptionalDeviceGuard device_guard(device_of(self)); + return self.copy_(other, non_blocking); +} + + static PyObject * THPVariable_copy_(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "copy_(Tensor other, bool non_blocking=False)", + "copy_(Tensor other, bool async=False)|deprecated" + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<2> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + return THPVariable_Wrap(dispatch_copy_(self_, r.tensor(0), r.toBool(1))); + END_HANDLE_TH_ERRORS +} + +template +static T dispatch_to(const Tensor & self) { + pybind11::gil_scoped_release no_gil; + OptionalDeviceGuard device_guard(device_of(self)); + TORCH_CHECK_VALUE(self.sym_numel() == 1, "only one element tensors can be converted to Python scalars"); + return self.template item(); +} + +static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) { + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "__float__", args); + } + jit::tracer::warn("Converting a tensor to a Python float", jit::tracer::WARN_PYTHON_DATAFLOW); + auto& self_ = THPVariable_Unpack(self); + return wrap(dispatch_to(self_)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_complex_scalar(PyObject* self, PyObject* args) { + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "__complex__", args); + } + jit::tracer::warn("Converting a tensor to a Python complex", jit::tracer::WARN_PYTHON_DATAFLOW); + auto& self_ = THPVariable_Unpack(self); + return wrap(dispatch_to>(self_)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_integral_scalar(PyObject* self, PyObject* args) { + HANDLE_TH_ERRORS + if 
(check_has_torch_function(self)) { + return handle_torch_function(self, "__int__", args); + } + jit::tracer::warn("Converting a tensor to a Python integer", jit::tracer::WARN_PYTHON_DATAFLOW); + auto& self_ = THPVariable_Unpack(self); + if (isFloatingType(self_.scalar_type())) { + // we can't dispatch to item here because we want to avoid ATen overflow checks; + // the python integral type (long in python2) can't overflow. + return THPUtils_packDoubleAsInt(dispatch_to(self_)); + } else { + return wrap(dispatch_to(self_)); + } + END_HANDLE_TH_ERRORS +} + +// This is the __index__ function in Python which is similar to __int__, but +// called when used as a slice. +static PyObject * THPVariable_index_scalar(PyObject* self, PyObject* args) { + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "__index__", args); + } + auto& self_ = THPVariable_Unpack(self); + // TODO: change the condition to `self_.dim() != 0` once we expose scalars + // in PyTorch. + if (!isIntegralType(self_.scalar_type(), /*includeBool=*/true) || self_.sym_numel() != 1) { + throw TypeError("only integer tensors of a single element can be converted to an index"); + } + return wrap(dispatch_to(self_)); + END_HANDLE_TH_ERRORS +} + +static Tensor dispatch_invert(const Tensor & self) { + pybind11::gil_scoped_release no_gil; + OptionalDeviceGuard device_guard(device_of(self)); + return self.bitwise_not(); +} + +static PyObject * THPVariable_invert(PyObject* self, PyObject* args) { + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "__invert__", args); + } + auto& self_ = THPVariable_Unpack(self); + if (!isIntegralType(self_.scalar_type(), /*includeBool=*/true)) { + throw TypeError("~ (operator.invert) is only implemented on integer and Boolean-type tensors"); + } + return THPVariable_Wrap(dispatch_invert(self_)); + END_HANDLE_TH_ERRORS +} + +static Tensor dispatch_to(const Tensor & self, Device device, bool non_blocking, bool copy, std::optional optional_memory_format) { + pybind11::gil_scoped_release no_gil; + // NOTE: this is where we record aten::to in the graph during tracing. However, the behavior of aten::to + // is different with respect to TensorOptions fields that are not present: aten::to inherits fields that + // are missing from the self argument while the tracer assumes that they should be populated with the + // default values (eg. float for scalar type). By explicitly copying over the tensor options here we fully + // specify all tensor options and thus record the proper trace + return self.to(self.options().device(device).memory_format(optional_memory_format), non_blocking, copy); +} + +static Tensor dispatch_to(const Tensor & self, bool non_blocking, bool copy, std::optional optional_memory_format) { + pybind11::gil_scoped_release no_gil; + return self.to(self.options().memory_format(optional_memory_format), non_blocking, copy); +} + +static Tensor dispatch_to(const Tensor & self, ScalarType dtype, bool non_blocking, bool copy, std::optional optional_memory_format) { + pybind11::gil_scoped_release no_gil; + // TODO: Make this call the TensorOptions version, maybe? + return self.to(dtype, non_blocking, copy, optional_memory_format); +} + +static Tensor dispatch_to(const Tensor & self, Device device, ScalarType dtype, bool non_blocking, bool copy, std::optional optional_memory_format) { + pybind11::gil_scoped_release no_gil; + // TODO: Make this call the TensorOptions version, maybe? 
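+  // The GIL was released above, so the device/dtype conversion (and any copy
+  // it implies) runs without holding it.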
+ return self.to(device, dtype, non_blocking, copy, optional_memory_format); +} + +static PyObject * THPVariable_cpu(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "cpu(*, MemoryFormat? memory_format=None)" + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_Wrap(dispatch_to(self_, at::Device(at::DeviceType::CPU), false, false, opt_memory_format)); + END_HANDLE_TH_ERRORS +} + +static Tensor dispatch_nonzero(const Tensor & self) { + pybind11::gil_scoped_release no_gil; + OptionalDeviceGuard device_guard(device_of(self)); + return self.nonzero(); +} + +static std::vector dispatch_nonzero_numpy(const Tensor & self) { + pybind11::gil_scoped_release no_gil; + OptionalDeviceGuard device_guard(device_of(self)); + return self.nonzero_numpy(); +} + +static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "nonzero()", + "nonzero(*, bool as_tuple)", + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<2> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + if (r.idx == 0 || (r.idx == 1 && !r.toBool(0))) { + return wrap(dispatch_nonzero(self_)); + } else { + return wrap(dispatch_nonzero_numpy(self_)); + } + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_cuda(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "cuda(Device? device=None, bool non_blocking=False, *, MemoryFormat? memory_format=None)", + "cuda(Device? device=None, bool async=False, *, MemoryFormat? memory_format=None)|deprecated" + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<3> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto device = r.isNone(0) ? at::Device(at::DeviceType::CUDA) : r.device(0); + auto opt_memory_format = r.memoryformatOptional(2); + TORCH_CHECK(device.is_cuda(), "Invalid device, must be cuda device"); + torch::utils::device_lazy_init(at::kCUDA); + return THPVariable_Wrap(dispatch_to(self_, device, r.toBool(1), false, opt_memory_format)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_mtia(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "mtia(Device? device=None, bool non_blocking=False, *, MemoryFormat? memory_format=None)", + "mtia(Device? device=None, bool async=False, *, MemoryFormat? memory_format=None)|deprecated" + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<3> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if (r.has_torch_function()) { + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto device = r.isNone(0) ? 
at::Device(at::DeviceType::MTIA) : r.device(0); + auto opt_memory_format = r.memoryformatOptional(2); + TORCH_CHECK(device.is_mtia(), "Invalid device, must be MTIA device"); + torch::utils::device_lazy_init(at::kMTIA); + return THPVariable_Wrap(dispatch_to(self_, device, r.toBool(1), false, opt_memory_format)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_xpu(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "xpu(Device? device=None, bool non_blocking=False, *, MemoryFormat? memory_format=None)", + "xpu(Device? device=None, bool async=False, *, MemoryFormat? memory_format=None)|deprecated" + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<3> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if (r.has_torch_function()) { + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto device = r.isNone(0) ? at::Device(at::DeviceType::XPU) : r.device(0); + auto opt_memory_format = r.memoryformatOptional(2); + TORCH_CHECK(device.is_xpu(), "Invalid device, must be xpu device"); + torch::utils::device_lazy_init(at::kXPU); + return THPVariable_Wrap(dispatch_to(self_, device, r.toBool(1), false, opt_memory_format)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_ipu(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "ipu(Device? device=None, bool non_blocking=False, *, MemoryFormat? memory_format=None)", + "ipu(Device? device=None, bool async=False, *, MemoryFormat? memory_format=None)|deprecated" + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<3> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if (r.has_torch_function()) { + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto device = r.isNone(0) ? at::Device(at::DeviceType::IPU) : r.device(0); + auto opt_memory_format = r.memoryformatOptional(2); + TORCH_CHECK(device.is_ipu(), "Invalid device, must be ipu device"); + return THPVariable_Wrap(dispatch_to(self_, device, r.toBool(1), false, opt_memory_format)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_to_type(PyObject* self, ScalarType scalarType, std::optional optional_memory_format) { + HANDLE_TH_ERRORS + auto& self_ = THPVariable_Unpack(self); + return THPVariable_Wrap(dispatch_to(self_, scalarType, false, false, optional_memory_format)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_byte(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "byte(*, MemoryFormat? memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::Byte, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_char(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "char(*, MemoryFormat? 
memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::Char, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_double(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "double(*, MemoryFormat? memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::Double, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_float(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "float(*, MemoryFormat? memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::Float, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_cdouble(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "cdouble(*, MemoryFormat? memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::ComplexDouble, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_cfloat(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "cfloat(*, MemoryFormat? memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::ComplexFloat, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_half(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "half(*, MemoryFormat? memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::Half, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_int(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "int(*, MemoryFormat? 
memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::Int, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_long(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "long(*, MemoryFormat? memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::Long, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_short(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "short(*, MemoryFormat? memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::Short, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_bool(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "bool(*, MemoryFormat? memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::Bool, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_bfloat16(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "bfloat16(*, MemoryFormat? 
memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::BFloat16, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_element_size(PyObject* self, PyObject* args) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "element_size", args); + } + auto& self_ = THPVariable_Unpack(self); + return THPUtils_packInt64(self_.element_size()); + END_HANDLE_TH_ERRORS +} + +// implemented on the python object bc PyObjects not declarable in native_functions.yaml +// See: ATen/native/README.md for more context +static PyObject * THPVariable_numpy(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "numpy(*, bool force=False)" + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if (r.has_torch_function()) { + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + jit::tracer::warn("Converting a tensor to a NumPy array", jit::tracer::WARN_PYTHON_DATAFLOW); + return torch::utils::tensor_to_numpy(self_, r.toBool(0)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_requires_grad_(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "requires_grad_(bool requires_grad=True)", + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + // temporary hack to improve functorch UX. + const auto& functorch_tls = at::functorch::functorchTLSAccessor(); + if (functorch_tls) { + functorch_tls->checkSupportsInplaceRequiresGrad(); + } + + auto requires_grad = r.toBool(0); + // should we throw if requires_grad is true? var.requires_grad = True throws here + // but it's nice to let this be a no-op. + if (!self_.is_leaf() && !requires_grad) { + throw std::runtime_error(autograd::utils::requires_grad_leaf_error(requires_grad)); + } + if (requires_grad && ! 
isDifferentiableType(at::typeMetaToScalarType(self_.dtype()))) { + throw std::runtime_error("only Tensors of floating point dtype can require gradients"); + } + self_.set_requires_grad(requires_grad); + return THPVariable_Wrap(self_); + END_HANDLE_TH_ERRORS +} + +inline bool dispatch_is_contiguous(const Tensor & self, MemoryFormat memory_format) { + return self.is_contiguous(memory_format); +} + +// implemented on the python object to avoid dispatch overhead +static PyObject * THPVariable_is_contiguous(PyObject* self_, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "is_contiguous(*, MemoryFormat memory_format=contiguous_format)", + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self_, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self_, args, kwargs, PyObject_Type(self_), "torch.Tensor"); + } + + auto memory_format = r.memoryformat(0); + auto& self = THPVariable_Unpack(self_); + return wrap(dispatch_is_contiguous(self, memory_format)); + END_HANDLE_TH_ERRORS +} + +// implemented on the python object to avoid dispatch overhead +static PyObject * THPVariable_item(PyObject* self, PyObject* args) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "item", args); + } + jit::tracer::warn("Converting a tensor to a Python number", jit::tracer::WARN_PYTHON_DATAFLOW); + auto& self_ = THPVariable_Unpack(self); + auto dispatch_item_ = [](const Tensor& self) -> at::Scalar { + pybind11::gil_scoped_release no_gil; + return self.item(); + }; + return py::cast(dispatch_item_(self_)).release().ptr(); + END_HANDLE_TH_ERRORS +} + +// implemented on the python object bc no support for first class functions in native_functions.yaml +// See: ATen/native/README.md for more context +static PyObject * THPVariable_map_(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ "map_(Tensor other, PyObject* callable)" }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<2> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + Variable other = r.tensor(0); + if (self_.requires_grad() || other.requires_grad()) { + throw std::runtime_error( + "Can't call map_() on Variable that requires grad. 
Use " + "var.detach().map_() instead."); + } + TORCH_CHECK( + !self_.unsafeGetTensorImpl()->is_python_dispatch() && !other.unsafeGetTensorImpl()->is_python_dispatch(), + ".map_ is not supported for tensor subclasses."); + + return THPVariable_Wrap(torch::utils::map_(self_, other, r.pyobject(1))); + END_HANDLE_TH_ERRORS +} + +// implemented on the python object bc no support for first class functions in native_functions.yaml +// See: ATen/native/README.md for more context +static PyObject * THPVariable_map2_(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ "map2_(Tensor x, Tensor y, PyObject* callable)" }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<3> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + Variable x = r.tensor(0); + Variable y = r.tensor(1); + if (self_.requires_grad() || x.requires_grad() || y.requires_grad()) { + throw std::runtime_error( + "Can't call map2_() on Variable that requires grad. Use " + "var.detach().map2_() instead."); + } + TORCH_CHECK( + !x.unsafeGetTensorImpl()->is_python_dispatch() && !y.unsafeGetTensorImpl()->is_python_dispatch(), + ".map2_ is not supported for tensor subclasses."); + return THPVariable_Wrap(torch::utils::map2_(self_, x, y, r.pyobject(2))); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_new(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "new", args, kwargs); + } + auto& self_ = THPVariable_Unpack(self); + OptionalDeviceGuard device_guard(device_of(self_)); + return THPVariable_Wrap(torch::utils::legacy_tensor_new(legacyExtractDispatchKey(self_), self_.scalar_type(), args, kwargs)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_new_tensor(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "new_tensor", args, kwargs); + } + auto& self_ = THPVariable_Unpack(self); + OptionalDeviceGuard device_guard(device_of(self_)); + return THPVariable_Wrap(torch::utils::new_tensor(legacyExtractDispatchKey(self_), self_.scalar_type(), args, kwargs)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_storage(PyObject* self, PyObject* arg) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "untyped_storage"); + } + auto& self_ = THPVariable_Unpack(self); + return createPyObject(self_.storage()); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_to(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "to(Device device=None, ScalarType dtype=None, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)", + "to(ScalarType dtype, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)", + "to(Tensor tensor, bool non_blocking=False, bool copy=False, *, MemoryFormat? 
memory_format=None)", + }); + ParsedArgs<5> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + if (r.has_torch_function()) { + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + auto parsed = parse_to_conversion(r, /*allow_copy*/ true); + auto& device = std::get<0>(parsed); + auto& scalarType = std::get<1>(parsed); + auto non_blocking = std::get<2>(parsed); + auto copy = std::get<3>(parsed); + auto opt_memory_format = std::get<4>(parsed); + auto& self_ = THPVariable_Unpack(self); + torch::utils::maybe_initialize_device(device); + if (!device && !scalarType && !copy && !opt_memory_format.has_value()) { + Py_INCREF(self); + return self; + } else if (!device && !scalarType) { + return THPVariable_Wrap( + dispatch_to(self_, non_blocking, copy, opt_memory_format)); + } else if (!device) { + return THPVariable_Wrap(dispatch_to(self_, *scalarType, non_blocking, copy, opt_memory_format)); + } else if (!scalarType) { + return THPVariable_Wrap(dispatch_to(self_, *device, non_blocking, copy, opt_memory_format)); + } else { + return THPVariable_Wrap(dispatch_to(self_, *device, *scalarType, non_blocking, copy, opt_memory_format)); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +// implemented on the python object b/c arbitrarily nested list not declarable in native_functions.yaml +// See: ATen/native/README.md for more context +static PyObject * THPVariable_tolist(PyObject* self, PyObject* args) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "tolist", args); + } + jit::tracer::warn("Converting a tensor to a Python list", jit::tracer::WARN_PYTHON_DATAFLOW); + auto self_ = THPVariable_Unpack(self); + return torch::utils::tensor_to_list(self_); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "type(PyObject* dtype=None, bool non_blocking=False, *, MemoryFormat? memory_format=None)", + "type(PyObject* dtype=None, bool async=False, *, MemoryFormat? 
memory_format=None)|deprecated" + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<3> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + if (r.isNone(0)) { + return THPUtils_packString(torch::utils::options_to_string(self_.options())); + } + auto obj = r.pyobject(0); + auto opt_memory_format = r.memoryformatOptional(2); + std::string type_name; + bool is_dtype = false; + if (PyType_Check(obj)) { + if (obj == THPVariableClass) { + type_name = "torch.Tensor"; + } else { + type_name = ((PyTypeObject*)obj)->tp_name; + } + } else if (THPUtils_checkString(obj)) { + type_name = THPUtils_unpackString(obj); + } else if (THPDtype_Check(obj)) { + is_dtype = true; + } else { + throw TypeError("dtype must be a type, str, or dtype object"); + } + ScalarType scalar_type; + Device device = self_.device(); + if (is_dtype) { + scalar_type = r.scalartype(0); + return THPVariable_Wrap(dispatch_to(self_, scalar_type, /*non_blocking=*/ r.toBool(1), /*copy=*/ false, opt_memory_format)); + } + at::TensorOptions options = torch::utils::options_from_string(type_name); + scalar_type = at::typeMetaToScalarType(options.dtype()); + auto device_type = options.device().type(); + if (device_type != device.type()) { + device = at::Device(device_type); + } + torch::utils::maybe_initialize_device(device); + return THPVariable_Wrap(dispatch_to(self_, device, scalar_type, /*non_blocking=*/ r.toBool(1), /*copy=*/ false, opt_memory_format)); + END_HANDLE_TH_ERRORS +} + +// generated methods start here + +${py_methods} + +static PyObject * THPVariable_bool_scalar(PyObject* self, PyObject* args) { + if (check_has_torch_function(self)) { + HANDLE_TH_ERRORS + return handle_torch_function(self, "__bool__", args); + END_HANDLE_TH_ERRORS + } + jit::tracer::warn("Converting a tensor to a Python boolean", jit::tracer::WARN_PYTHON_DATAFLOW); + return THPVariable_is_nonzero(self, args); +} + +static PyObject * THPVariable___eq__(PyObject* self_, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS +#ifdef USE_NUMPY + if (torch::utils::is_numpy_available()) { + static PythonArgParser parser({ + "__eq__(PyObject* other)", + }, /*traceable=*/true); + + ParsedArgs<1> parsed_args; + auto _r = parser.parse(self_, args, kwargs, parsed_args); + if(_r.has_torch_function()) { + return handle_torch_function(_r, self_, args, kwargs, THPVariableClass, "torch.Tensor"); + } + switch (_r.idx) { + case 0: { + auto other = _r.pyobject(0); + if (PyArray_Check(other)) { + auto other_tensor = torch::utils::tensor_from_numpy(other); + auto dispatch_eq = [](const at::Tensor & self, const at::Tensor & other) -> at::Tensor { + pybind11::gil_scoped_release no_gil; + return self.eq(other); + }; + const Tensor& self = THPVariable_Unpack(self_); + return wrap(dispatch_eq(self, other_tensor)); + } + } + } + } +#endif + return THPVariable_eq(self_, args, kwargs); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +// Wrapper converts a raised TypeError into returning NotImplemented +// Used to implement binary arithmetic operators +template +static PyObject * TypeError_to_NotImplemented_(PyObject* self, PyObject* args, PyObject* kwargs) { + + PyObject* ret = Func(self, args, kwargs); + if (!ret && PyErr_ExceptionMatches(PyExc_TypeError)) { + PyErr_Clear(); + Py_INCREF(Py_NotImplemented); + ret = Py_NotImplemented; + } + return ret; +} + +// set_ has to be defined in the template because the c10::Storage object +// 
does not have a type, and we need to make sure the Python storage object's +// type matches the tensor's type +static PyObject* THPVariable_set_( + PyObject* self_, + PyObject* args, + PyObject* kwargs) { + HANDLE_TH_ERRORS + const Tensor& self = THPVariable_Unpack(self_); + static PythonArgParser parser( + { + "set_()", + "set_(Storage source)", + "set_(Storage source, SymInt storage_offset, SymIntArrayRef size, SymIntArrayRef stride=None)", + "set_(Tensor source)", + "set_(Tensor source, SymInt storage_offset, SymIntArrayRef size, SymIntArrayRef stride=None)", + }, + /*traceable=*/false); + + ParsedArgs<4> parsed_args; + auto _r = parser.parse(args, kwargs, parsed_args); + + switch (_r.idx) { + case 0: { + // aten::set_(Tensor(a!) self) -> Tensor(a!) + auto dispatch_set_ = [](const Tensor& self) -> Tensor { + pybind11::gil_scoped_release no_gil; + return self.set_(); + }; + return wrap(dispatch_set_(self)); + } + case 1: { + // aten::set_.source_Storage(Tensor(a!) self, Storage source) -> + // Tensor(a!) + at::ScalarType storage_scalar_type; + bool is_typed_storage = true; + at::Storage storage = _r.storage(0, storage_scalar_type, is_typed_storage); + TORCH_CHECK(storage_scalar_type == self.dtype() || !is_typed_storage, + "Expected a Storage of type ", self.dtype(), + " or an UntypedStorage, but got type ", storage_scalar_type, + " for argument 1 'storage'"); + auto dispatch_set_ = [](const Tensor& self, Storage source) -> Tensor { + pybind11::gil_scoped_release no_gil; + return self.set_(source); + }; + return wrap(dispatch_set_(self, storage)); + } + case 2: { + // aten::set_.source_Storage_storage_offset(Tensor(a!) self, Storage + // source, int storage_offset, int[] size, int[] stride=[]) -> Tensor(a!) + at::ScalarType storage_scalar_type; + bool is_typed_storage = true; + at::Storage storage = _r.storage(0, storage_scalar_type, is_typed_storage); + TORCH_CHECK(storage_scalar_type == self.dtype() || !is_typed_storage, + "Expected a Storage of type ", self.dtype(), + " or an UntypedStorage, but got type ", storage_scalar_type, + " for argument 1 'storage'"); + auto dispatch_set_ = [](const Tensor& self, + Storage source, + c10::SymInt storage_offset, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride) -> Tensor { + pybind11::gil_scoped_release no_gil; + return self.set__symint(source, storage_offset, size, stride); + }; + return wrap(dispatch_set_( + self, storage, _r.toSymInt(1), _r.symintlist(2), _r.symintlist(3))); + } + case 3: { + // aten::set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!) + auto dispatch_set_ = [](const Tensor& self, const Tensor& source) -> Tensor { + TORCH_CHECK(source.dtype() == self.dtype(), "Could not set tensor of type ", source.dtype(), " to a tensor of type ", self.dtype()); + pybind11::gil_scoped_release no_gil; + return self.set_(source); + }; + return wrap(dispatch_set_(self, _r.tensor(0))); + } + case 4: { + // aten::set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor + // source, int storage_offset, int[] size, int[] stride=[]) -> Tensor(a!) 
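+      // Unlike the Storage overload in case 2, the source here is a Tensor;
+      // self is re-pointed at that tensor's data with the given symbolic
+      // storage offset, sizes and strides via set__symint.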
+ at::Tensor storage = _r.tensor(0); + auto dispatch_set_ = [](const Tensor& self, + const Tensor& source, + c10::SymInt storage_offset, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride) -> Tensor { + pybind11::gil_scoped_release no_gil; + return self.set__symint(source, storage_offset, size, stride); + }; + return wrap(dispatch_set_( + self, storage, _r.toSymInt(1), _r.symintlist(2), _r.symintlist(3))); + } + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +// XXX: ops that are bound here are not exposed to the C++ api nor the JIT. +// Any new ops added here should be accompanied with a comment why they are not +// being registered through native_functions.yaml, and be tagged cpp / JIT +PyMethodDef variable_methods[] = { + // These magic methods are all implemented on python object to wrap NotImplementedError + {"__add__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__radd__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__iadd__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__rmul__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__mul__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__imul__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__sub__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__isub__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__div__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__truediv__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__floordiv__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__idiv__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__ifloordiv__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__mod__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__imod__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__eq__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__ne__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__lt__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__le__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__gt__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__ge__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__rand__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__ror__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__rxor__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__bool__", THPVariable_bool_scalar, METH_NOARGS, NULL}, + 
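+  // The next few entries bind Python's numeric protocol (__float__, __complex__,
+  // __int__, __long__, __index__) to the scalar-conversion helpers defined
+  // earlier in this file.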
{"__float__", THPVariable_float_scalar, METH_NOARGS, NULL}, + {"__complex__", THPVariable_complex_scalar, METH_NOARGS, NULL}, + {"__int__", THPVariable_integral_scalar, METH_NOARGS, NULL}, + {"__long__", THPVariable_integral_scalar, METH_NOARGS, NULL}, + {"__index__", THPVariable_index_scalar, METH_NOARGS, NULL}, + {"__nonzero__", THPVariable_bool_scalar, METH_NOARGS, NULL}, + {"__invert__", THPVariable_invert, METH_NOARGS, NULL}, + {"__matmul__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"_is_view", THPVariable__is_view, METH_NOARGS, NULL}, + {"apply_", THPVariable_apply_, METH_O, NULL}, + {"bfloat16", castPyCFunctionWithKeywords(THPVariable_bfloat16), METH_VARARGS | METH_KEYWORDS, NULL}, + {"byte", castPyCFunctionWithKeywords(THPVariable_byte), METH_VARARGS | METH_KEYWORDS, NULL}, + {"char", castPyCFunctionWithKeywords(THPVariable_char), METH_VARARGS | METH_KEYWORDS, NULL}, + {"contiguous", castPyCFunctionWithKeywords(THPVariable_contiguous), METH_VARARGS | METH_KEYWORDS, NULL}, + {"copy_", castPyCFunctionWithKeywords(THPVariable_copy_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"cpu", castPyCFunctionWithKeywords(THPVariable_cpu), METH_VARARGS | METH_KEYWORDS, NULL}, + {"cuda", castPyCFunctionWithKeywords(THPVariable_cuda), METH_VARARGS | METH_KEYWORDS, NULL}, + {"mtia", castPyCFunctionWithKeywords(THPVariable_mtia), METH_VARARGS | METH_KEYWORDS, NULL}, + {"xpu", castPyCFunctionWithKeywords(THPVariable_xpu), METH_VARARGS | METH_KEYWORDS, NULL}, + {"ipu", castPyCFunctionWithKeywords(THPVariable_ipu), METH_VARARGS | METH_KEYWORDS, NULL}, + {"data_ptr", THPVariable_data_ptr, METH_NOARGS, NULL}, + {"dim", THPVariable_dim, METH_NOARGS, NULL}, + {"has_names", THPVariable_has_names, METH_NOARGS, NULL}, + {"double", castPyCFunctionWithKeywords(THPVariable_double), METH_VARARGS | METH_KEYWORDS, NULL}, + {"cdouble", castPyCFunctionWithKeywords(THPVariable_cdouble), METH_VARARGS | METH_KEYWORDS, NULL}, + {"element_size", THPVariable_element_size, METH_NOARGS, NULL}, + {"float", castPyCFunctionWithKeywords(THPVariable_float), METH_VARARGS | METH_KEYWORDS, NULL}, + {"cfloat", castPyCFunctionWithKeywords(THPVariable_cfloat), METH_VARARGS | METH_KEYWORDS, NULL}, + {"get_device", THPVariable_get_device, METH_NOARGS, NULL}, + {"bool", castPyCFunctionWithKeywords(THPVariable_bool), METH_VARARGS | METH_KEYWORDS, NULL}, + {"half", castPyCFunctionWithKeywords(THPVariable_half), METH_VARARGS | METH_KEYWORDS, NULL}, + {"int", castPyCFunctionWithKeywords(THPVariable_int), METH_VARARGS | METH_KEYWORDS, NULL}, + {"is_contiguous", castPyCFunctionWithKeywords(THPVariable_is_contiguous), METH_VARARGS | METH_KEYWORDS, NULL}, + {"item", THPVariable_item, METH_NOARGS, NULL}, + {"long", castPyCFunctionWithKeywords(THPVariable_long), METH_VARARGS | METH_KEYWORDS, NULL}, + {"map_", castPyCFunctionWithKeywords(THPVariable_map_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"map2_", castPyCFunctionWithKeywords(THPVariable_map2_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"ndimension", THPVariable_dim, METH_NOARGS, NULL}, + {"nelement", THPVariable_numel, METH_NOARGS, NULL}, + {"new", castPyCFunctionWithKeywords(THPVariable_new), METH_VARARGS | METH_KEYWORDS, NULL}, + {"new_tensor", castPyCFunctionWithKeywords(THPVariable_new_tensor), METH_VARARGS | METH_KEYWORDS, NULL}, + {"nonzero", castPyCFunctionWithKeywords(THPVariable_nonzero), METH_VARARGS | METH_KEYWORDS, NULL}, + {"numel", THPVariable_numel, METH_NOARGS, NULL}, + {"numpy", 
castPyCFunctionWithKeywords(THPVariable_numpy), METH_VARARGS | METH_KEYWORDS, NULL}, + {"requires_grad_", castPyCFunctionWithKeywords(THPVariable_requires_grad_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"set_", castPyCFunctionWithKeywords(THPVariable_set_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"short", castPyCFunctionWithKeywords(THPVariable_short), METH_VARARGS | METH_KEYWORDS, NULL}, + {"size", castPyCFunctionWithKeywords(THPVariable_size), METH_VARARGS | METH_KEYWORDS, NULL}, + {"untyped_storage", THPVariable_storage, METH_NOARGS, NULL}, + {"storage_offset", THPVariable_storage_offset, METH_NOARGS, NULL}, + {"stride", castPyCFunctionWithKeywords(THPVariable_stride), METH_VARARGS | METH_KEYWORDS, NULL}, + {"to", castPyCFunctionWithKeywords(THPVariable_to), METH_VARARGS | METH_KEYWORDS, NULL}, + {"tolist", THPVariable_tolist, METH_NOARGS, NULL}, + {"type", castPyCFunctionWithKeywords(THPVariable_type), METH_VARARGS | METH_KEYWORDS, NULL}, + ${py_method_defs} + {NULL} +}; + +} // namespace torch::autograd diff --git a/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/variable_factories.h b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/variable_factories.h new file mode 100644 index 0000000000000000000000000000000000000000..2b55f441ab6249cb7963c5e4a15070f626f775b7 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/packaged/autograd/templates/variable_factories.h @@ -0,0 +1,135 @@ +#pragma once + +// ${generated_comment} + +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +$ops_headers +#endif + +#include +#include +#include + +namespace torch { + +/// NOTE: Currently `torch::tensor(...)` doesn't support mixed data types +/// (i.e. `torch::tensor({{bool, 2.0}})` doesn't work). We might be able to +/// support it in the future by iterating over all sub-lists to find +/// the largest data type that can represent all of the elements, or by using +/// variadic templates. +/// +/// NOTE: C++ `torch::tensor` with a floating-point type or an `at::ArrayRef` / `std::vector` / +/// (nested) braced-init-list of floating-point types always produces a tensor of dtype +/// `torch::get_default_dtype()`, matching Python `torch.tensor` behavior. +/// +/// NOTE: C++ `torch::tensor` with an integer type or an `at::ArrayRef` / `std::vector` / +/// (nested) braced-init-list of integer types always produces a tensor of dtype `at::kLong` +/// (aka. int64_t), matching Python `torch.tensor` behavior. +/// +/// NOTE: The following dtypes are not supported by `torch::tensor` currently: +/// - `unsigned int` +/// - `unsigned long int` +/// - `unsigned long long int` +/// - `long long int` +inline at::Tensor tensor(detail::TensorDataContainer tensor_data_container, const at::TensorOptions& options = {}) { + return autograd::make_variable( + // note: we remove the requires_grad setting from the TensorOptions because + // it is ignored anyways (and we actually have an assertion that it isn't set + // which would fail otherwise). We handle requires_grad explicitly here + // instead of passing it through to the kernel. + tensor_data_container.convert_to_tensor(options.requires_grad(::std::nullopt)), + options.requires_grad()); +} + +/// A generic deleter function. +using Deleter = std::function; +using at::MemoryFormat; + +/// Exposes the given `data` as a `Tensor` without taking ownership of the +/// original data. 
`sizes` should specify the shape of the tensor, `strides` the +/// stride in each dimension. The `deleter` function (a +/// `std::function`) will be called on the `data` when the Tensor +/// data would normally be deallocated. The `TensorOptions` specify additional +/// configuration options for the returned tensor, such as what type to +/// interpret the `data` as. +inline at::Tensor from_blob( + void* data, + at::IntArrayRef sizes, + at::IntArrayRef strides, + const Deleter& deleter, + const at::TensorOptions& options = at::TensorOptions()) { + at::Tensor tensor = ([&]() { + at::AutoDispatchBelowAutograd guard; // TODO: remove + at::tracer::impl::NoTracerDispatchMode tracer_guard; + return at::from_blob(data, sizes, strides, deleter, options.requires_grad(::std::nullopt)); + })(); + return autograd::make_variable(tensor, options.requires_grad()); +} + +/// Exposes the given `data` as a `Tensor` without taking ownership of the +/// original data. `sizes` should specify the shape of the tensor, `strides` the +/// stride in each dimension. The `TensorOptions` +/// specify additional configuration options for the returned tensor, such as +/// what type to interpret the `data` as. +inline at::Tensor from_blob( + void* data, + at::IntArrayRef sizes, + at::IntArrayRef strides, + const at::TensorOptions& options = at::TensorOptions()) { + at::Tensor tensor = ([&]() { + at::AutoDispatchBelowAutograd guard; // TODO: remove + at::tracer::impl::NoTracerDispatchMode tracer_guard; + return at::from_blob(data, sizes, strides, options.requires_grad(::std::nullopt)); + })(); + return autograd::make_variable(tensor, options.requires_grad()); +} + +/// Exposes the given `data` as a `Tensor` without taking ownership of the +/// original data. `sizes` should specify the shape of the tensor. The `deleter` +/// (a `std::function`) function will be called on the `data` when +/// the Tensor data would normally be deallocated. The `TensorOptions` specify +/// additional configuration options for the returned tensor, such as what type +/// to interpret the `data` as. +inline at::Tensor from_blob( + void* data, + at::IntArrayRef sizes, + const Deleter& deleter, + const at::TensorOptions& options = at::TensorOptions()) { + at::Tensor tensor = ([&]() { + at::AutoDispatchBelowAutograd guard; // TODO: remove + at::tracer::impl::NoTracerDispatchMode tracer_guard; + return at::from_blob(data, sizes, deleter, options.requires_grad(::std::nullopt)); + })(); + return autograd::make_variable(tensor, options.requires_grad()); +} + +/// Exposes the given `data` as a `Tensor` without taking ownership of the +/// original data. `sizes` should specify the shape of the tensor. The +/// `TensorOptions` specify additional configuration options for the returned +/// tensor, such as what type to interpret the `data` as. 
+inline at::Tensor from_blob( + void* data, + at::IntArrayRef sizes, + const at::TensorOptions& options = at::TensorOptions()) { + at::Tensor tensor = ([&]() { + at::AutoDispatchBelowAutograd guard; // TODO: remove + at::tracer::impl::NoTracerDispatchMode tracer_guard; + return at::from_blob(data, sizes, options.requires_grad(::std::nullopt)); + })(); + return autograd::make_variable(tensor, options.requires_grad()); +} + +${function_definitions} + +} // namespace torch diff --git a/lib/python3.10/site-packages/torchgen/selective_build/__init__.py b/lib/python3.10/site-packages/torchgen/selective_build/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/torchgen/selective_build/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/selective_build/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05348bf4d7c8f2d33cc7b9f4d72d337990a5884e Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/selective_build/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/selective_build/__pycache__/operator.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/selective_build/__pycache__/operator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3501f789b99b709c8a42e9292340fdd88423a34e Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/selective_build/__pycache__/operator.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/selective_build/__pycache__/selector.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/selective_build/__pycache__/selector.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f1819f0884805fcf7e00b1ecd024687af99b99d Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/selective_build/__pycache__/selector.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/selective_build/operator.py b/lib/python3.10/site-packages/torchgen/selective_build/operator.py new file mode 100644 index 0000000000000000000000000000000000000000..0cb92dfc09e28c7c98ab7230af362c363a30d621 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/selective_build/operator.py @@ -0,0 +1,171 @@ +from __future__ import annotations + +from dataclasses import dataclass + + +# This class holds information about a single operator used to determine +# the outcome of a selective/custom PyTorch build that doesn't include +# registration code for all the supported operators. This is done to +# reduce the size of the generated binary so that it can be deployed in +# situations where binary size comes at a premium. +# +@dataclass(frozen=True) +class SelectiveBuildOperator: + # The name of the operator. This includes the aten::, etc... prefix + # The operator name may or may not have the overload name. If this + # operator name does not specify an overload name, the way to determine + # if this entry refers to the family of operators with this base name + # or just the operator with this name is to look at the value of the + # 'include_all_overloads' flag in this class. + name: str + + # True if this is a root operator (i.e. called directly from a + # TorchScript model, etc...). An operator is considered to be a + # root operator if it is called directly from any one of the models + # that this instance of the pytorch library was built for. 
Hence, it + # may not be a root operator in all of the models that are used in + # this instance of the pytorch library. + is_root_operator: bool + + # Is this operator used for on-device training? If True, then we need to + # use the information to generate code in VariableType_N.cpp for registration + # of training related operators. Again, this is True if this operator + # is used for training in one or more models used by this instance of the + # pytorch library. + is_used_for_training: bool + + # If True, it indicates that this operator instance (object) refers to an + # operator without the overload name and should apply to all overloads + # which have this operator name as the base name. This flag is applicable + # only for objects that have operator names without a DOT (period) character + # in them. + # + # Note: This flag is a temporary workaround to grandfather in the current + # static selective (custom) build mechanism, which largely ignores overload + # names when determining whether to select operators for registration + # purposes. + include_all_overloads: bool + + # Debug Information at the operator level + _debug_info: tuple[str, ...] | None + + @staticmethod + def from_yaml_dict( + op_name: str, op_info: dict[str, object] + ) -> SelectiveBuildOperator: + allowed_keys = { + "name", + "is_root_operator", + "is_used_for_training", + "include_all_overloads", + "debug_info", + } + + if len(set(op_info.keys()) - allowed_keys) > 0: + raise Exception( # noqa: TRY002 + "Got unexpected top level keys: {}".format( + ",".join(set(op_info.keys()) - allowed_keys), + ) + ) + + if "name" in op_info: + assert op_name == op_info["name"] + + is_root_operator = op_info.get("is_root_operator", True) + assert isinstance(is_root_operator, bool) + + is_used_for_training = op_info.get("is_used_for_training", True) + assert isinstance(is_used_for_training, bool) + + include_all_overloads = op_info.get("include_all_overloads", True) + assert isinstance(include_all_overloads, bool) + + debug_info: tuple[str, ...] | None = None + if "debug_info" in op_info: + di_list = op_info["debug_info"] + assert isinstance(di_list, list) + debug_info = tuple(str(x) for x in di_list) + + return SelectiveBuildOperator( + name=op_name, + is_root_operator=is_root_operator, + is_used_for_training=is_used_for_training, + include_all_overloads=include_all_overloads, + _debug_info=debug_info, + ) + + @staticmethod + def from_legacy_operator_name_without_overload( + name: str, + ) -> SelectiveBuildOperator: + return SelectiveBuildOperator( + name=name, + is_root_operator=True, + is_used_for_training=True, + include_all_overloads=True, + _debug_info=None, + ) + + def to_dict(self) -> dict[str, object]: + ret: dict[str, object] = { + "is_root_operator": self.is_root_operator, + "is_used_for_training": self.is_used_for_training, + "include_all_overloads": self.include_all_overloads, + } + if self._debug_info is not None: + ret["debug_info"] = self._debug_info + + return ret + + +def merge_debug_info( + lhs: tuple[str, ...] | None, + rhs: tuple[str, ...] | None, +) -> tuple[str, ...] | None: + # Ensure that when merging, each entry shows up just once. 
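+    # Note that order is not preserved: the merged tuple is rebuilt from a set
+    # union of both sides, so duplicates are dropped but entries may be reordered.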
+ if lhs is None and rhs is None: + return None + + return tuple(set((lhs or ()) + (rhs or ()))) + + +def combine_operators( + lhs: SelectiveBuildOperator, rhs: SelectiveBuildOperator +) -> SelectiveBuildOperator: + if str(lhs.name) != str(rhs.name): + raise Exception( # noqa: TRY002 + f"Expected both arguments to have the same name, but got '{str(lhs.name)}' and '{str(rhs.name)}' instead" + ) + + return SelectiveBuildOperator( + name=lhs.name, + # Consider this operator to be a root operator if it is a + # root operator in any of the models used in this instance of + # the pytorch library. + is_root_operator=lhs.is_root_operator or rhs.is_root_operator, + # Consider this operator to be a training operator if it is + # an operator used for training in any of the models used + # in this instance of the pytorch library. + is_used_for_training=lhs.is_used_for_training or rhs.is_used_for_training, + include_all_overloads=lhs.include_all_overloads or rhs.include_all_overloads, + _debug_info=merge_debug_info(lhs._debug_info, rhs._debug_info), + ) + + +def merge_operator_dicts( + lhs: dict[str, SelectiveBuildOperator], + rhs: dict[str, SelectiveBuildOperator], +) -> dict[str, SelectiveBuildOperator]: + operators: dict[str, SelectiveBuildOperator] = {} + for op_name, op in list(lhs.items()) + list(rhs.items()): + new_op = op + if op_name in operators: + new_op = combine_operators(operators[op_name], op) + + operators[op_name] = new_op + + return operators + + +def strip_operator_overload_name(op_name: str) -> str: + return op_name.split(".")[0] diff --git a/lib/python3.10/site-packages/torchgen/selective_build/selector.py b/lib/python3.10/site-packages/torchgen/selective_build/selector.py new file mode 100644 index 0000000000000000000000000000000000000000..04acc354203ade2f48dcef56fd9d9ef70c82ad1d --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/selective_build/selector.py @@ -0,0 +1,352 @@ +from __future__ import annotations + +from collections import defaultdict +from collections.abc import Iterable +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import yaml + +from torchgen.selective_build.operator import ( + merge_debug_info, + merge_operator_dicts, + SelectiveBuildOperator, + strip_operator_overload_name, +) + + +if TYPE_CHECKING: + from torchgen.model import NativeFunction + + +# A SelectiveBuilder holds information extracted from the selective build +# YAML specification. +# +# It includes information about the build's selectivity, the debug_info +# associated with this selective build (opaque string), and the set of +# operators that should be included in the build. +# +@dataclass(frozen=True) +class SelectiveBuilder: + # If true, then the build is not selective, and includes all + # operators. + include_all_operators: bool + + # Debug Information at the selective/custom build level. + _debug_info: tuple[str, ...] | None + + # A dictionary of operator -> operator metadata. + operators: dict[str, SelectiveBuildOperator] + + # A dictionary of selected kernel tags and dtypes. Typically a + # PyTorch Operator Kernel (function) may have many code paths + # that are specialized for many many Tensor dtypes, so it's not + # one per kernel function, but there could be many per kernel + # function. The tag isn't a kernel function name, but some fragment + # of the kernel function implementation itself. + kernel_metadata: dict[str, list[str]] + + # ExecuTorch only. A dictionary of kernel tag -> list of (list of input + # dtypes for tensor-like input args). 
+ # This is from selective.yaml + et_kernel_metadata: dict[str, list[str]] + + # A set of all the custom torch bind classes used by the selected models + # Stored as a set internally to remove duplicates proactively, but written + # as a list to yamls + custom_classes: set[str] + + # A set of all the build features used by the selected models + # Stored as a set internally to remove duplicates proactively, but written + # as a list to yamls + build_features: set[str] + + # If true, then fragments for all dtypes for all kernel functions + # are included as well as all custom classes. This is typically set when any one of the + # operator lists is generated from a mechanism other than + # tracing based selective build. + include_all_non_op_selectives: bool + + @staticmethod + def get_nop_selector() -> SelectiveBuilder: + return SelectiveBuilder.from_yaml_dict({"include_all_operators": True}) + + @staticmethod + def from_yaml_dict(data: dict[str, object]) -> SelectiveBuilder: + valid_top_level_keys = { + "include_all_non_op_selectives", + "include_all_operators", + "debug_info", + "operators", + "kernel_metadata", + "et_kernel_metadata", + "custom_classes", + "build_features", + } + top_level_keys = set(data.keys()) + if len(top_level_keys - valid_top_level_keys) > 0: + raise Exception( # noqa: TRY002 + "Got unexpected top level keys: {}".format( + ",".join(top_level_keys - valid_top_level_keys), + ) + ) + include_all_operators = data.get("include_all_operators", False) + assert isinstance(include_all_operators, bool) + + debug_info = None + if "debug_info" in data: + di_list = data["debug_info"] + assert isinstance(di_list, list) + + debug_info = tuple(str(x) for x in di_list) + + operators = {} + operators_dict = data.get("operators", {}) + assert isinstance(operators_dict, dict) + + for k, v in operators_dict.items(): + operators[k] = SelectiveBuildOperator.from_yaml_dict(k, v) + + kernel_metadata = {} + kernel_metadata_dict = data.get("kernel_metadata", {}) + assert isinstance(kernel_metadata_dict, dict) + + for k, v in kernel_metadata_dict.items(): + kernel_metadata[str(k)] = [str(dtype) for dtype in v] + + et_kernel_metadata = data.get("et_kernel_metadata", {}) + assert isinstance(et_kernel_metadata, dict) + + custom_classes = data.get("custom_classes", []) + assert isinstance(custom_classes, Iterable) + custom_classes = set(custom_classes) + + build_features = data.get("build_features", []) + assert isinstance(build_features, Iterable) + build_features = set(build_features) + + include_all_non_op_selectives = data.get("include_all_non_op_selectives", False) + assert isinstance(include_all_non_op_selectives, bool) + + return SelectiveBuilder( + include_all_operators, + debug_info, + operators, + kernel_metadata, + et_kernel_metadata, + custom_classes, # type: ignore[arg-type] + build_features, # type: ignore[arg-type] + include_all_non_op_selectives, + ) + + @staticmethod + def from_yaml_str(config_contents: str) -> SelectiveBuilder: + contents = yaml.safe_load(config_contents) + return SelectiveBuilder.from_yaml_dict(contents) + + @staticmethod + def from_yaml_path(config_path: str) -> SelectiveBuilder: + with open(config_path) as f: + contents = yaml.safe_load(f) + return SelectiveBuilder.from_yaml_dict(contents) + + @staticmethod + def from_legacy_op_registration_allow_list( + allow_list: set[str], is_root_operator: bool, is_used_for_training: bool + ) -> SelectiveBuilder: + operators = {} + for op in allow_list: + operators[op] = { + "name": op, + "is_root_operator": 
is_root_operator, + "is_used_for_training": is_used_for_training, + "include_all_overloads": True, + } + return SelectiveBuilder.from_yaml_dict( + { + "operators": operators, + "include_all_non_op_selectives": True, + } + ) + + def is_operator_selected(self, name: str) -> bool: + if self.include_all_operators: + return True + + if name in self.operators: + return True + name = strip_operator_overload_name(name) + return name in self.operators and self.operators[name].include_all_overloads + + def is_native_function_selected(self, func: NativeFunction) -> bool: + op_name = op_name_from_native_function(func) + return self.is_operator_selected(op_name) + + def is_operator_selected_for_training(self, name: str) -> bool: + if not self.is_operator_selected(name): + return False + if self.include_all_operators: + return True + + not_training_op = SelectiveBuildOperator( + name="", + is_root_operator=False, + is_used_for_training=False, + include_all_overloads=False, + _debug_info=None, + ) + op = not_training_op + if name in self.operators: + op = self.operators[name] + + name = strip_operator_overload_name(name) + base_op = not_training_op + if name in self.operators: + base_op = self.operators[name] + + return op.is_used_for_training or ( + base_op.include_all_overloads and base_op.is_used_for_training + ) + + def is_native_function_selected_for_training(self, func: NativeFunction) -> bool: + op_name = op_name_from_native_function(func) + return self.is_operator_selected_for_training(op_name) + + def is_root_operator(self, name: str) -> bool: + if not self.is_operator_selected(name): + return False + if self.include_all_operators: + return True + + if name in self.operators: + op: SelectiveBuildOperator = self.operators[name] + return op.is_root_operator + name = strip_operator_overload_name(name) + if name not in self.operators: + return False + base_op: SelectiveBuildOperator = self.operators[name] + return base_op.include_all_overloads and base_op.is_root_operator + + def is_kernel_dtype_selected(self, kernel_tag: str, dtype: str) -> bool: + if self.include_all_operators or self.include_all_non_op_selectives: + return True + + return ( + kernel_tag in self.kernel_metadata + and dtype in self.kernel_metadata[kernel_tag] + ) + + def et_get_selected_kernels(self, op_name: str, kernel_key: list[str]) -> list[str]: + """ + Return a list of kernel keys that cover the used ops + """ + # If no kernel metadata, either it's implied by include_all_operators=True or the op is not used. + if op_name not in self.et_kernel_metadata: + return kernel_key if self.include_all_operators else [] + # Otherwise, only return the specific kernel keys. 
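+        # Note (descriptive, based on the matching logic below): kernel keys are
+        # compared only by the portion after the first '/', so the version prefix
+        # is ignored, and "default" serves as a catch-all when no specific key matches.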
+ + result_set = set() + + for model_kernel_keys in self.et_kernel_metadata[op_name]: + key_found = False + for key in kernel_key: + # Don't compare the version for now + if ( + key != "default" + and key.split("/")[1] == model_kernel_keys.split("/")[1] + ): + result_set.add(key) + key_found = True + break + if not key_found: + if "default" not in kernel_key: + raise Exception("Missing kernel for the model") # noqa: TRY002 + else: + result_set.add("default") + + return list(result_set) + + def to_dict(self) -> dict[str, object]: + ret: dict[str, object] = { + "include_all_non_op_selectives": self.include_all_non_op_selectives, + "include_all_operators": self.include_all_operators, + } + operators = {} + for op_name, op in self.operators.items(): + operators[op_name] = op.to_dict() + ret["operators"] = operators + + if self._debug_info is not None: + ret["debug_info"] = sorted(self._debug_info) + + ret["kernel_metadata"] = { + k: sorted(v) for (k, v) in self.kernel_metadata.items() + } + + ret["et_kernel_metadata"] = self.et_kernel_metadata + + ret["custom_classes"] = sorted(self.custom_classes) + + ret["build_features"] = sorted(self.build_features) + + return ret + + +def merge_kernel_metadata( + lhs: dict[str, list[str]], + rhs: dict[str, list[str]], +) -> dict[str, list[str]]: + kernel_metadata: dict[str, list[str]] = {} + for tag_name, dtypes in list(lhs.items()) + list(rhs.items()): + dtypes_copy = set(dtypes) + if tag_name in kernel_metadata: + dtypes_copy |= set(kernel_metadata[tag_name]) + + kernel_metadata[tag_name] = list(dtypes_copy) + + return kernel_metadata + + +def merge_et_kernel_metadata( + lhs: dict[str, list[str]], + rhs: dict[str, list[str]], +) -> dict[str, list[str]]: + merge_et_kernel_metadata: dict[str, set[str]] = defaultdict(set) + for op in list(lhs.keys()) + list(rhs.keys()): + merge_et_kernel_metadata[op].update(lhs.get(op, [])) + merge_et_kernel_metadata[op].update(rhs.get(op, [])) + + return {op: sorted(val) for op, val in merge_et_kernel_metadata.items()} + + +def combine_selective_builders( + lhs: SelectiveBuilder, rhs: SelectiveBuilder +) -> SelectiveBuilder: + include_all_operators = lhs.include_all_operators or rhs.include_all_operators + debug_info = merge_debug_info(lhs._debug_info, rhs._debug_info) + operators = merge_operator_dicts(lhs.operators, rhs.operators) + kernel_metadata = merge_kernel_metadata(lhs.kernel_metadata, rhs.kernel_metadata) + et_kernel_metadata = merge_et_kernel_metadata( + lhs.et_kernel_metadata, rhs.et_kernel_metadata + ) + include_all_non_op_selectives = ( + lhs.include_all_non_op_selectives or rhs.include_all_non_op_selectives + ) + custom_classes = lhs.custom_classes.union(rhs.custom_classes) + build_features = lhs.build_features.union(rhs.build_features) + return SelectiveBuilder( + include_all_operators, + debug_info, + operators, + kernel_metadata, + et_kernel_metadata, + custom_classes, + build_features, + include_all_non_op_selectives, + ) + + +def op_name_from_native_function(f: NativeFunction) -> str: + # This was originally read from the 'operator_name_with_overload' field in the + # declaration dict, which was the part before the first '(' in 'schema_string'. 
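+    # Illustrative example (assuming torchgen's usual OperatorName stringification,
+    # which keeps the overload name): a function in the "aten" namespace with schema
+    # name "add.Tensor" produces "aten::add.Tensor".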
+ return f"{f.namespace}::{f.func.name}" diff --git a/lib/python3.10/site-packages/torchgen/static_runtime/__init__.py b/lib/python3.10/site-packages/torchgen/static_runtime/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/torchgen/static_runtime/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/static_runtime/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..386aff89f8b380e8695bbd21fda65844a667b0ab Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/static_runtime/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/static_runtime/__pycache__/config.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/static_runtime/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f8b24872682f69daa5c2a2e42c4dab6f4d29b10 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/static_runtime/__pycache__/config.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/static_runtime/__pycache__/gen_static_runtime_ops.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/static_runtime/__pycache__/gen_static_runtime_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbd3d831b0b4e8cf87a4ffb91bf145a4f8c7fb35 Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/static_runtime/__pycache__/gen_static_runtime_ops.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/static_runtime/__pycache__/generator.cpython-310.pyc b/lib/python3.10/site-packages/torchgen/static_runtime/__pycache__/generator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ee663b1c7e710502d299aa11e7ca010e70b694d Binary files /dev/null and b/lib/python3.10/site-packages/torchgen/static_runtime/__pycache__/generator.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchgen/static_runtime/config.py b/lib/python3.10/site-packages/torchgen/static_runtime/config.py new file mode 100644 index 0000000000000000000000000000000000000000..1e7b541fa2c1287921613384aec2fee2cd7d4e97 --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/static_runtime/config.py @@ -0,0 +1,388 @@ +from __future__ import annotations + +from torchgen.model import NativeFunctionsGroup, NativeFunctionsViewGroup + + +def func_name_base_str(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> str: + if isinstance(g, NativeFunctionsGroup): + return str(g.functional.func.name.name.base) + else: + return str(g.view.root_name) + + +is_hand_written_ops_ = frozenset( + ( + "abs", + "add", + "addmm", + "all", + "any", + "argmin", + "bmm", + "clamp", + "clamp_min", + "cumsum", + "div", + "fmod", + "index_select", + "leaky_relu", + "linear", + "log", + "matmul", + "mul", + "narrow_copy", + "nonzero", + "pow", + "remainder", + "sigmoid", + "sign", + "sub", + "tanh", + "detach", + "expand_as", + "flatten", + "narrow", + "reshape_as", + "select", + "slice", + "softmax", + "split", + "squeeze", + "transpose", + "view", + "where", + ) +) + + +def is_hand_written(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool: + name_base = func_name_base_str(g) + return name_base in is_hand_written_ops_ + + +def override_test_values(arg_map: dict[str, str], op_name: str, index: int) -> None: + assert index == 0 or index == 1 + if op_name 
== "addr": + if index == 0: + arg_map["self"] = "at::rand({6, 6})" + arg_map["vec1"] = "at::rand({6})" + arg_map["vec2"] = "at::rand({6})" + else: + arg_map["self"] = "at::rand({22, 22})" + arg_map["vec1"] = "at::rand({22})" + arg_map["vec2"] = "at::rand({22})" + return + if op_name == "mv": + if index == 0: + arg_map["self"] = "at::rand({6, 6})" + arg_map["vec"] = "at::rand({6})" + else: + arg_map["self"] = "at::rand({22, 22})" + arg_map["vec"] = "at::rand({22})" + return + if op_name == "addbmm": + if index == 0: + arg_map["self"] = "at::rand({6, 6})" + else: + arg_map["self"] = "at::rand({22, 22})" + return + if op_name == "cross": + if index == 0: + arg_map["self"] = "at::rand({3, 3, 3})" + arg_map["other"] = "at::rand({3, 3, 3})" + else: + arg_map["self"] = "at::rand({22, 3, 22})" + arg_map["other"] = "at::rand({22, 3, 22})" + return + if op_name == "take": + if index == 0: + arg_map["index"] = "at::randint(0, 216, {20}, torch::kInt64)" + else: + arg_map["index"] = "at::randint(0, 1000, {100}, torch::kInt64)" + return + if op_name == "take_along_dim": + if index == 0: + arg_map["indices"] = "at::argsort(self0, 1, true)" + else: + arg_map["indices"] = "at::argsort(self1, 1, true)" + return + if op_name == "masked_select": + if index == 0: + arg_map["mask"] = "at::randn({6, 6, 6}) > 0.5" + else: + arg_map["mask"] = "at::rand({22, 22, 22}) > 0.5" + return + if op_name == "orgqr": + if index == 0: + arg_map["input2"] = "at::rand({6, 6})" + else: + arg_map["input2"] = "at::rand({22, 22})" + return + if op_name == "ormqr": + if index == 0: + arg_map["input2"] = "at::rand({6, 6})" + else: + arg_map["input2"] = "at::rand({22, 22})" + return + if op_name == "quantile": + if index == 0: + arg_map["q"] = "at::rand({6})" + arg_map["interpolation"] = '"linear"' + else: + arg_map["q"] = "at::rand({22})" + arg_map["interpolation"] = '"linear"' + return + if op_name == "nanquantile": + if index == 0: + arg_map["q"] = "at::rand({6})" + arg_map["interpolation"] = '"linear"' + else: + arg_map["q"] = "at::rand({22})" + arg_map["interpolation"] = '"linear"' + return + if op_name == "multi_margin_loss": + if index == 0: + arg_map["self"] = "at::rand({6, 6})" + arg_map["target"] = "at::randint(6, {6}, torch::kInt64)" + arg_map["weight"] = "at::rand({6})" + else: + arg_map["self"] = "at::rand({22, 22})" + arg_map["target"] = "at::randint(22, {22}, torch::kInt64)" + arg_map["weight"] = "at::rand({22})" + return + if op_name == "multilabel_margin_loss": + if index == 0: + arg_map["self"] = "at::rand({6, 6})" + arg_map["target"] = "at::randint(6, {6, 6}, torch::kInt64)" + else: + arg_map["self"] = "at::rand({22, 22})" + arg_map["target"] = "at::randint(22, {22, 22}, torch::kInt64)" + return + if op_name == "nll_loss": + if index == 0: + arg_map["self"] = "at::rand({6, 6})" + arg_map["target"] = "at::randint(6, {6}, torch::kInt64)" + arg_map["weight"] = "at::rand({6})" + else: + arg_map["self"] = "at::rand({22, 22})" + arg_map["target"] = "at::randint(22, {22}, torch::kInt64)" + arg_map["weight"] = "at::rand({22})" + return + if op_name == "nll_loss2d": + if index == 0: + arg_map["self"] = "at::rand({6, 6, 6, 6})" + arg_map["target"] = "at::randint(6, {6, 6, 6}, torch::kInt64)" + arg_map["weight"] = "at::rand({6})" + else: + arg_map["self"] = "at::rand({22, 22, 22, 22})" + arg_map["target"] = "at::randint(22, {22, 22, 22}, torch::kInt64)" + arg_map["weight"] = "at::rand({22})" + return + if op_name in ( + "fft_fft", + "fft_ifft", + "fft_rfft", + "fft_irfft", + "fft_hfft", + "fft_ihfft", + ): + arg_map["norm"] = 
'"forward"' + return + if op_name == "linalg_tensorinv": + if index == 0: + arg_map["self"] = "at::rand({6, 6, 6, 6})" + arg_map["ind"] = "2" + else: + arg_map["self"] = "at::rand({22, 22, 22, 22})" + arg_map["ind"] = "2" + return + if op_name == "addmv": + if index == 0: + arg_map["self"] = "at::rand({2})" + arg_map["mat"] = "at::rand({2, 2})" + arg_map["vec"] = "at::rand({2})" + else: + arg_map["self"] = "at::rand({35})" + arg_map["mat"] = "at::rand({35, 35})" + arg_map["vec"] = "at::rand({35})" + return + if op_name == "acosh": + if index == 0: + arg_map["self"] = "at::rand({2, 2, 2}) + at::ones({2, 2, 2})" + else: + arg_map["self"] = "at::rand({5, 5, 5}) + at::ones({5, 5, 5})" + return + if op_name == "adaptive_max_pool2d_backward": + if index == 0: + arg_map["grad_output"] = "at::rand({2, 2, 2}, at::kFloat)" + arg_map["self"] = "at::rand({2, 2, 2}, at::kFloat)" + arg_map["indices"] = "at::randint(0, 1, {2, 2, 2}, at::kLong)" + else: + arg_map["grad_output"] = "at::rand({3, 3, 3}, at::kFloat)" + arg_map["self"] = "at::rand({3, 3, 3}, at::kFloat)" + arg_map["indices"] = "at::randint(0, 1, {3, 3, 3}, at::kLong)" + return + if op_name == "adaptive_max_pool3d_backward": + if index == 0: + arg_map["grad_output"] = "at::rand({2, 2, 2, 2}, at::kFloat)" + arg_map["self"] = "at::rand({2, 2, 2, 2}, at::kFloat)" + arg_map["indices"] = "at::randint(0, 1, {2, 2, 2, 2}, at::kLong)" + else: + arg_map["grad_output"] = "at::rand({3, 3, 3, 3}, at::kFloat)" + arg_map["self"] = "at::rand({3, 3, 3, 3}, at::kFloat)" + arg_map["indices"] = "at::randint(0, 1, {3, 3, 3, 3}, at::kLong)" + return + if op_name == "bitwise_left_shift": + if index == 0: + arg_map["self"] = "at::randint(1, 1 << 4, {6, 6, 6}, at::kInt)" + arg_map["other"] = "at::randint(1, 26, {6, 6, 6}, at::kInt)" + else: + arg_map["self"] = "at::randint(1, 1 << 4, {22, 22, 22}, at::kInt)" + arg_map["other"] = "at::randint(1, 26, {22, 22, 22}, at::kInt)" + return + if op_name == "bitwise_right_shift": + if index == 0: + arg_map["self"] = "at::randint(1 << 21, 1 << 30, {6, 6, 6}, at::kInt)" + arg_map["other"] = "at::randint(1, 22, {6, 6, 6}, at::kInt)" + else: + arg_map["self"] = "at::randint(1 << 21, 1 << 30, {22, 22, 22}, at::kInt)" + arg_map["other"] = "at::randint(1, 22, {22, 22, 22}, at::kInt)" + return + if op_name == "gather": + if index == 0: + arg_map["self"] = "at::randint(1, 100, {2,2,2}, at::kInt)" + arg_map["dim"] = "1" + arg_map["index"] = "at::randint(0, 1, {2,2,2}, torch::kInt64)" + arg_map["sparse_grad"] = "false" + else: + arg_map["self"] = "at::randint(1, 100, {5,5,5}, at::kInt)" + arg_map["dim"] = "1" + arg_map["index"] = "at::randint(0, 4, {5,5,5}, torch::kInt64)" + arg_map["sparse_grad"] = "false" + return + if op_name == "gelu": + if index == 0: + arg_map["self"] = "at::rand({6, 6, 6})" + arg_map["approximate"] = '"tanh"' + else: + arg_map["self"] = "at::rand({22, 22, 22})" + arg_map["approximate"] = '"tanh"' + return + if op_name == "gelu_backward": + if index == 0: + arg_map["grad_output"] = "at::rand({6, 6, 6})" + arg_map["self"] = "at::rand({6, 6, 6})" + arg_map["approximate"] = '"tanh"' + else: + arg_map["grad_output"] = "at::rand({22, 22, 22})" + arg_map["self"] = "at::rand({22, 22, 22})" + arg_map["approximate"] = '"tanh"' + return + if op_name == "index_add": + if index == 0: + arg_map["self"] = "at::rand({2})" + arg_map["dim"] = "0" + arg_map["index"] = "at::randint(0, 1, {2}, at::kInt)" + arg_map["source"] = "at::rand({2})" + arg_map["alpha"] = "2" + else: + arg_map["self"] = "at::rand({16})" + arg_map["dim"] = "0" + 
arg_map["index"] = "at::randint(0, 10, {16}, at::kInt)" + arg_map["source"] = "at::rand({16})" + arg_map["alpha"] = "2" + return + if op_name == "index_copy": + if index == 0: + arg_map["self"] = "at::rand({2})" + arg_map["dim"] = "0" + arg_map["index"] = "at::randint(0, 1, {2}, at::kLong)" + arg_map["source"] = "at::rand({2})" + else: + arg_map["self"] = "at::rand({32})" + arg_map["dim"] = "0" + arg_map["index"] = "at::randint(0, 10, {32}, at::kLong)" + arg_map["source"] = "at::rand({32})" + return + if op_name == "linalg_cross": + if index == 0: + arg_map["self"] = "at::rand({6, 3, 6})" + arg_map["other"] = "at::rand({6, 3, 6})" + arg_map["dim"] = "1" + else: + arg_map["self"] = "at::rand({22, 3, 22})" + arg_map["other"] = "at::rand({22, 3, 22})" + arg_map["dim"] = "1" + return + if op_name == "nll_loss_backward": + if index == 0: + arg_map["grad_output"] = "at::rand({})" + arg_map["self"] = "at::rand({6})" + arg_map["target"] = "at::randint(0, 5, {6}, torch::kInt64)" + arg_map["weight"] = "at::rand({6})" + arg_map["reduction"] = "1" + arg_map["ignore_index"] = "1" + arg_map["total_weight"] = "at::rand({})" + else: + arg_map["grad_output"] = "at::rand({})" + arg_map["self"] = "at::rand({36})" + arg_map["target"] = "at::randint(0, 11, {36}, torch::kInt64)" + arg_map["weight"] = "at::rand({36})" + arg_map["reduction"] = "1" + arg_map["ignore_index"] = "1" + arg_map["total_weight"] = "at::rand({})" + return + if op_name in ["scatter", "scatter_add", "_scatter_reduce"]: + if index == 0: + arg_map["self"] = "at::randint(1, 100, {2,2,2}, torch::kInt64)" + arg_map["index"] = "at::randint(0, 1, {2,2,2}, torch::kInt64)" + arg_map["src"] = "at::randint(1, 100, {2,2,2}, torch::kInt64)" + else: + arg_map["self"] = "at::randint(1, 100, {5,5,5}, torch::kInt64)" + arg_map["index"] = "at::randint(0, 1, {5,5,5}, torch::kInt64)" + arg_map["src"] = "at::randint(1, 100, {5,5,5}, torch::kInt64)" + if "reduce" in arg_map: + arg_map["reduce"] = '"sum"' if op_name == "_scatter_reduce" else '"add"' + return + if op_name == "scatter_reduce": + arg_map["reduce"] = '"mean"' + if index == 0: + arg_map["index"] = "at::randint(6, {6, 6, 6}, torch::kInt64)" + else: + arg_map["index"] = "at::randint(22, {22, 22, 22}, torch::kInt64)" + return + if op_name == "special_zeta": + if index == 0: + arg_map["self"] = "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})" + arg_map["other"] = "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})" + else: + arg_map["self"] = "at::rand({5,5,5}, at::kDouble) + at::ones({5,5,5})" + arg_map["other"] = "at::rand({5,5,5}, at::kDouble) + at::ones({5,5,5})" + return + if op_name == "_convert_indices_from_csr_to_coo": + if index == 0: + arg_map["crow_indices"] = "torch::tensor({1}, torch::kInt32)" + arg_map["col_indices"] = "torch::tensor({0, 1, 0}, torch::kInt32)" + arg_map["out_int32"] = "false" + else: + arg_map["crow_indices"] = "torch::tensor({0}, torch::kInt32)" + arg_map[ + "col_indices" + ] = "torch::tensor({0, 1, 0, 2, 1, 2, 0, 1, 0, 2, 1, 2}, torch::kInt32)" + arg_map["out_int32"] = "false" + return + if op_name == "_convert_indices_from_coo_to_csr": + if index == 0: + arg_map["self"] = "at::randint(0, 3, {2}, at::kInt)" + arg_map["size"] = "10" + arg_map["out_int32"] = "false" + else: + arg_map["self"] = "at::randint(0, 3, {12}, at::kInt)" + arg_map["size"] = "24" + arg_map["out_int32"] = "false" + return + if op_name in ("diagonal", "linalg_diagonal"): + arg_map["offset"] = "0" + arg_map["dim1"] = "2" + arg_map["dim2"] = "1" + return diff --git 
a/lib/python3.10/site-packages/torchgen/static_runtime/gen_static_runtime_ops.py b/lib/python3.10/site-packages/torchgen/static_runtime/gen_static_runtime_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9f7357173746748bacfc3e540ebcf37426b5455e --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/static_runtime/gen_static_runtime_ops.py @@ -0,0 +1,229 @@ +from __future__ import annotations + +import argparse +import itertools +import os +from typing import Sequence, TypeVar, Union + +from libfb.py.log import set_simple_logging # type: ignore[import] + +from torchgen import gen +from torchgen.context import native_function_manager +from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsViewGroup +from torchgen.static_runtime import config, generator + + +# Given a list of `grouped_native_functions` sorted by their op names, return a list of +# lists each of which groups ops that share the base name. For example, `mean` and +# `mean.dim` are grouped together by this function. + +NativeGroupT = TypeVar( + "NativeGroupT", + bound=Union[NativeFunctionsGroup, NativeFunctionsViewGroup], +) + + +def group_functions_by_op_name( + grouped_native_functions: Sequence[NativeGroupT], +) -> Sequence[Sequence[NativeGroupT]]: + if not grouped_native_functions: + return [] + groups = [] + + def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool: + with native_function_manager(g): + return generator.is_supported(g) + + eligible_ops = (g for g in grouped_native_functions if is_supported(g)) + groups = [ + list(group) + for k, group in ( + itertools.groupby( + eligible_ops, + key=config.func_name_base_str, + ) + ) + ] + + return groups + + +def clang_format(cpp_file_path: str) -> None: + import subprocess + + subprocess.check_call(["clang-format", "-i", cpp_file_path]) + + +def write_cpp(cpp_ops: Sequence[str], file_path: str) -> None: + code = "\n".join(cpp_ops) + generated = f"""// @lint-ignore-every CLANGTIDY HOWTOEVEN +// AUTO-GENERATED FROM: torchgen/static_runtime/gen_static_runtime_ops.py +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch {{ +namespace jit {{ + +{code} + +}} // namespace jit +}} // namespace torch +""" + with open(file_path, "w") as f: + f.write(generated) + clang_format(file_path) + + +def write_test_cpp(cpp_ops: Sequence[str], file_path: str) -> None: + code = "\n".join(cpp_ops) + generated = f"""// @lint-ignore-every CLANGTIDY HOWTOEVEN +// AUTO-GENERATED FROM: torchgen/static_runtime/gen_static_runtime_ops.py +#include +#include +#include + +#include "test_utils.h" + +using namespace caffe2; +using namespace torch; +using namespace torch::jit; +using namespace torch::jit::test; +using c10::IValue; + +{code} + +""" + with open(file_path, "w") as f: + f.write(generated) + clang_format(file_path) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate ATen source files") + parser.add_argument( + "-s", + "--source-path", + help="path to source directory for ATen", + default="caffe2/aten/src/ATen", + ) + parser.add_argument( + "-p", + "--generated-ops-cpp-path", + help="path to directory to generate op dispatcher .cpp file", + 
default="caffe2/torch/csrc/jit/runtime/static/generated_ops.cpp", + ) + parser.add_argument( + "-t", + "--generated-ops-test-cpp-path", + help="path to directory to generate op dispatcher .cpp file", + default="caffe2/benchmarks/static_runtime/test_generated_ops.cc", + ) + options = parser.parse_args() + native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml") + tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml") + parsed_yaml = gen.parse_native_yaml(native_yaml_path, tags_yaml_path) + native_functions, backend_indices = ( + parsed_yaml.native_functions, + parsed_yaml.backend_indices, + ) + + op_generator = generator.GenOpDispatcher() + test_case_generator = generator.GenOpTestCase() + + native_functions_groups = [ + g + for g in gen.get_grouped_native_functions(native_functions) + if isinstance(g, NativeFunctionsGroup) + ] + + supported_functions_groups = group_functions_by_op_name(native_functions_groups) + + out_variant_op_result = [ + op_generator.out_variant(groups, backend_indices[DispatchKey.CPU]) + for groups in supported_functions_groups + ] + out_variant_test_result = [ + test_case_generator.out_variant(groups) for groups in supported_functions_groups + ] + + native_functions_view_groups = [ + g + for g in gen.get_grouped_by_view_native_functions(native_functions) + if isinstance(g, NativeFunctionsViewGroup) + ] + + supported_functions_view_groups = group_functions_by_op_name( + native_functions_view_groups + ) + + view_op_result = [ + op_generator.view(groups, backend_indices[DispatchKey.CPU]) + for groups in supported_functions_view_groups + ] + view_test_result = [ + test_case_generator.view(groups) for groups in supported_functions_view_groups + ] + + op_result = out_variant_op_result + ["\n\n"] + view_op_result + test_result = out_variant_test_result + ["\n\n"] + view_test_result + + write_cpp(op_result, options.generated_ops_cpp_path) + write_test_cpp(test_result, options.generated_ops_test_cpp_path) + + print( + "\ntotal grouped native ops: %d" + % len(gen.get_grouped_native_functions(native_functions)) + ) + + print("grouped native ops with out variant: %d" % len(native_functions_groups)) + supported_functions_num = sum(len(groups) for groups in supported_functions_groups) + print("generated functions groups with out variant: %d" % supported_functions_num) + + print("\nview grouped native ops: %d" % len(native_functions_view_groups)) + supported_view_functions_num = sum( + len(groups) for groups in supported_functions_view_groups + ) + print("generated functions view groups: %d" % supported_view_functions_num) + + print( + "\noverall generated : %d" + % (supported_functions_num + supported_view_functions_num) + ) + + +if __name__ == "__main__": + set_simple_logging(escape_newlines=False) + main() diff --git a/lib/python3.10/site-packages/torchgen/static_runtime/generator.py b/lib/python3.10/site-packages/torchgen/static_runtime/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..7bbb7f64d8644252cd6a92492c0c36b40d623b2f --- /dev/null +++ b/lib/python3.10/site-packages/torchgen/static_runtime/generator.py @@ -0,0 +1,809 @@ +from __future__ import annotations + +import json +import logging +import math +from typing import Sequence + +import torchgen.api.cpp as cpp +from torchgen.context import native_function_manager +from torchgen.model import ( + Argument, + BackendIndex, + BaseTy, + BaseType, + FunctionSchema, + NativeFunctionsGroup, + NativeFunctionsViewGroup, + OptionalType, + SelfArgument, + 
TensorOptionsArguments, + Type, +) +from torchgen.static_runtime import config + + +logger: logging.Logger = logging.getLogger() + + +def has_alias( + arguments: Sequence[Argument | SelfArgument | TensorOptionsArguments], +) -> bool: + for arg in arguments: + annotation = getattr(arg, "annotation", None) + if not annotation: + continue + alias_set = getattr(annotation, "alias_set", ()) + if alias_set: + return True + return False + + +BLOCKED_OPS = frozenset( + ( + # non cpu ops + "sparse_sampled_addmm", + "hspmm", + "linalg_svdvals", + # sparse ops + "sspaddmm", + "coalesce", + "_indices", + "indices", + "_values", + "values", + "crow_indices", + "col_indices", + # deprecated ops + "floor_divide", + "ger", + # buggy ops + "conj_physical", # P495807361 + "binary_cross_entropy", # P496394764 + "arccosh", + # uncommon ops + "cholesky", + "lu_solve", + "linalg_cholesky", + "linalg_householder_product", + "linalg_ldl_solve", + "_compute_linear_combination", + # training related ops + "_make_dual", + # cannot call directly + "_fw_primal", + # no documentation + "_index_reduce", + # TODO: these ones got added recently and need manual inspection + "_new_zeros_with_same_feature_meta", + "_conj_physical", + "binary_cross_entropy_with_logits", + "bincount", + "conv_tbc", + "copy", + "_copy_from", + "_copy_from_and_resize", + "count_nonzero", + "cudnn_affine_grid_generator", + "cudnn_affine_grid_generator_backward", + "cudnn_grid_sampler", + "diag_embed", + "embedding", + "embedding_dense_backward", + "_embedding_bag_dense_backward", + "_embedding_bag_per_sample_weights_backward", + "grid_sampler_2d", + "_grid_sampler_2d_cpu_fallback", + "grid_sampler_3d", + "isnan", + "mkldnn_linear", + "median", + "nanmedian", + "_sparse_sparse_matmul", + "batch_norm_backward_elemt", + "_euclidean_dist", + "pixel_shuffle", + "pixel_unshuffle", + "channel_shuffle", + "_reshape_nested_backward", + "relu", + "prelu", + "celu", + "slice_scatter", + "select_scatter", + "diagonal_scatter", + "sum", + "_mkldnn_transpose", + "_nested_tensor_from_mask", + "_nested_from_padded", + "_nested_tensor_size", + "_nested_from_padded_and_nested_example", + "_standard_gamma_grad", + "_dirichlet_grad", + "native_norm", + "_sparse_softmax", + "_sparse_softmax_backward_data", + "_sparse_log_softmax", + "_sparse_log_softmax_backward_data", + "zero", + "_sparse_addmm", + "sparse_mask", + "_sparse_mask_projection", + "_to_dense", + "_coalesce", + "_coalesced", + "copy_sparse_to_sparse", + "to_sparse", + "to_sparse_csr", + "to_sparse_csc", + "to_mkldnn", + "quantize_per_tensor_dynamic", + "quantize_per_channel", + "q_per_channel_scales", + "q_per_channel_zero_points", + "int_repr", + "_make_per_channel_quantized_tensor", + "set", + "lift", + "lift_fresh", + "lift_fresh_copy", + "masked_scatter", + "_masked_softmax", + "_masked_softmax_backward", + "put", + "index_reduce", + "trace", + "_cholesky_solve_helper", + "dist", + "max", + "_torch_cuda_cu_linker_symbol_op", + "glu_jvp", + "glu_backward_jvp", + "hardswish_backward", + "rrelu_with_noise_backward", + "mkldnn_adaptive_avg_pool2d_backward", + "_adaptive_avg_pool2d_backward", + "_adaptive_avg_pool3d_backward", + "isinf", + "linalg_lu_solve", + "linalg_vecdot", + "linalg_matrix_exp", + "linalg_eigvalsh", + "_test_warn_in_autograd", + "_test_autograd_multiple_dispatch_view", + "_test_autograd_multiple_dispatch_view_copy", + "_segment_reduce", + "_segment_reduce_backward", + "_fw_primal_copy", + "_make_dual_copy", + "view_as_real_copy", + "view_as_complex_copy", + "_conj_copy", + 
"_neg_view_copy", + "diagonal_copy", + "detach_copy", + "squeeze_copy", + "t_copy", + "unsqueeze_copy", + "_indices_copy", + "_values_copy", + "indices_copy", + "values_copy", + "crow_indices_copy", + "col_indices_copy", + "ccol_indices", + "ccol_indices_copy", + "row_indices", + "row_indices_copy", + "unfold_copy", + "alias_copy", + "_triton_multi_head_attention", + "special_airy_ai", + "special_bessel_j0", + "special_bessel_j1", + "special_bessel_y0", + "special_bessel_y1", + "special_chebyshev_polynomial_t", + "special_chebyshev_polynomial_u", + "special_chebyshev_polynomial_v", + "special_chebyshev_polynomial_w", + "special_hermite_polynomial_h", + "special_hermite_polynomial_he", + "special_laguerre_polynomial_l", + "special_legendre_polynomial_p", + "special_modified_bessel_i0", + "special_modified_bessel_i1", + "special_modified_bessel_k0", + "special_modified_bessel_k1", + "special_scaled_modified_bessel_k0", + "special_scaled_modified_bessel_k1", + "special_shifted_chebyshev_polynomial_t", + "special_shifted_chebyshev_polynomial_u", + "special_shifted_chebyshev_polynomial_v", + "special_shifted_chebyshev_polynomial_w", + "special_spherical_bessel_j0", + "_foobar", + "_nested_tensor_strides", + "_nested_tensor_storage_offsets", + "_nested_get_values", # no CPU backend + "_nested_get_values_copy", # no CPU backend + "_nested_view_from_jagged", # testing needs to be patched + "_nested_view_from_jagged_copy", # testing needs to be patched + "_nested_view_from_buffer", # testing needs to be patched + "_nested_view_from_buffer_copy", # testing needs to be patched + "_int_mm", # testing needs to be patched + "_to_sparse_csc", # testing needs to be patched + "_to_sparse_csr", # testing needs to be patched + "segment_reduce", # testing needs to be patched + ) +) + + +def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool: + base_op_name = "" + func = None + if isinstance(g, NativeFunctionsViewGroup): + base_op_name = g.view.root_name + func = g.view.func + else: + base_op_name = g.out.func.name.name.base + func = g.out.func + if config.is_hand_written(g): + logger.info("HAND WRITTEN: %s", base_op_name) + return False + if base_op_name in BLOCKED_OPS: + logger.info("BLOCKED: %s", base_op_name) + return False + for arg in func.schema_order_arguments(): + maybe_method = ivalue_type_conversion_method(arg.type) + if not maybe_method: + # Type converting is unsupported yet. + logger.info("NOT SUPPORTED TYPE CONVERTING: %s", func) + return False + + if isinstance(g, NativeFunctionsViewGroup): + # TODO: stop doing type tests by converting to C++ and then testing + # the string, just test the dang thing directly + if "at::Tensor" != cpp.returns_type(func.returns, symint=False).cpp_type(): + # Returns a non-Tensor value. + logger.info("NON-TENSOR RET TYPE: %s", str(func)) + return False + return True + + # For out variant ops, we need to check the arguments of its functional func. + for arg in g.functional.func.schema_order_arguments(): + maybe_method = ivalue_type_conversion_method(arg.type) + if not maybe_method: + # Type converting is unsupported yet. + logger.info("NOT SUPPORTED TYPE CONVERTING: %s", g.functional.func) + return False + + if not g.structured: + # In case of unstructured op, we check if it has out variant implementation. + # The out variant implementation satisfies the minimum requirement that it has the output tensor as the last + # parameter. + if ( + not hasattr(g, "out") + or not str(func).endswith("Tensor(a!) 
out) -> Tensor(a!)") + or not str(func.name).endswith(".out") + ): + return False + # TODO: stop type testing by converting to C++ + if "at::Tensor &" != cpp.returns_type(func.returns, symint=False).cpp_type(): + logger.info("NON_TENSOR RET TYPE: %s", func) + return False + if has_alias(func.arguments.non_out): + # This op may create an alias of inputs. + logger.info("INPUTS ALIAS: %s", base_op_name) + return False + return True + + +def ivalue_type_conversion_method( + arg_type: BaseType | OptionalType | Type, +) -> tuple[bool, str] | None: + """ + Return the method call expression of `c10::ivalue' to convert its contained value to + the expected value of `arg_type` type. For example, for `arg_type` == BaseTy.Tensor, + this function returns ".toTensor()", so that it can be appended to the ivalue's + variable name to get the value of the expected type. + """ + type_conversion_methods = { + BaseTy.Tensor: ((True, "toTensor()"), (False, "toOptional()")), + BaseTy.int: ((False, "toInt()"), (False, "toOptional()")), + BaseTy.bool: ((False, "toBool()"), (False, "toOptional()")), + BaseTy.Scalar: ((False, "toScalar()"), (False, "toOptional()")), + BaseTy.ScalarType: ( + (False, "toScalarType()"), + (False, "toOptional()"), + ), + BaseTy.str: ( + (False, "toStringView()"), + (False, "toOptional()"), + ), + } + + base_ty_object = None + if isinstance(arg_type, BaseType): + base_ty_object = arg_type.name + elif isinstance(arg_type, OptionalType): + if not isinstance(arg_type.elem, BaseType): + # ListType is currently unsupported. + return None + base_ty_object = arg_type.elem.name + else: + return None + + if base_ty_object not in type_conversion_methods: + return None + methods = type_conversion_methods[base_ty_object] + if isinstance(arg_type, BaseType): + return methods[0] + return methods[1] + + +should_use_int_tensor_ops_ = frozenset( + ( + "bitwise_not", + "bitwise_and", + "bitwise_or", + "bitwise_xor", + "bitwise_left_shift", + "bitwise_right_shift", + "gcd", + "lcm", + "scatter", + "gather", + "_convert_indices_from_coo_to_csr", + "_convert_indices_from_csr_to_coo", + ) +) +should_use_complex_tensor_ops_ = frozenset(("view_as_real", "imag", "_conj")) + + +def should_use_int_tensor(op_name: str) -> bool: + return op_name in should_use_int_tensor_ops_ + + +def should_use_complex_tensor(op_name: str) -> bool: + return op_name in should_use_complex_tensor_ops_ + + +test_tensor_dim_ops_1_ = frozenset( + ( + "addmv", + "index_add", + "_convert_indices_from_coo_to_csr", + "_convert_indices_from_csr_to_coo", + "nll_loss_backward", + "dot", + "vdot", + "outer", + "ger", + ) +) +test_tensor_dim_ops_2_ = frozenset( + ("addmm", "mm", "nuclear_norm", "diag", "_addmm_activation", "matrix_H", "t") +) + + +def test_tensor_dim(op_name: str) -> int: + if op_name in test_tensor_dim_ops_1_: + return 1 + if op_name in test_tensor_dim_ops_2_: + return 2 + return 3 + + +test_tensor_shapes_string = '{"view_as_complex": "{2, 2}"}' +test_tensor_shape_json: dict[str, str] = json.loads(test_tensor_shapes_string) + + +def test_tensor_shape(op_name: str) -> str: + if op_name in test_tensor_shape_json: + return test_tensor_shape_json[op_name] + else: + return "" + + +def test_value_expression( + arg_type: BaseType | OptionalType | Type, index: int, op_name: str +) -> str: + tensor_size_ex = test_tensor_shape(op_name) + if tensor_size_ex == "": + num_tensors = 16 if index == 0 else 64 + num_dim = test_tensor_dim(op_name) + size_per_dim = math.ceil(num_tensors / float(num_dim)) + size_per_dim += size_per_dim % 2 + 
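+        # Worked example (illustrative): for a default 3-d op at index 0,
+        # ceil(16 / 3) = 6, which is already even, so the size expression built
+        # below becomes "{6,6,6}".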
tensor_size_ex = "{{{}}}".format(",".join([f"{size_per_dim}"] * num_dim)) + if should_use_int_tensor(op_name): + tensor_expression = f"at::randint(1, 100, {tensor_size_ex}, at::kInt)" + elif should_use_complex_tensor(op_name): + tensor_expression = f"at::randn({tensor_size_ex}, at::kComplexFloat)" + else: + tensor_expression = f"at::rand({tensor_size_ex})" + + value_expressions = { + BaseTy.Tensor: tensor_expression, + BaseTy.int: "1", + BaseTy.bool: "false", + BaseTy.Scalar: "2", + BaseTy.ScalarType: "at::ScalarType::Float", + BaseTy.str: '"floor"', + } + + base_ty_object = None + if isinstance(arg_type, BaseType): + base_ty_object = arg_type.name + else: + assert isinstance(arg_type, OptionalType) and isinstance( + arg_type.elem, BaseType + ) + base_ty_object = arg_type.elem.name + assert base_ty_object in value_expressions, "not expected type" + value_expression = value_expressions[base_ty_object] + return value_expression + + +def generate_test_value_definitions(schema: FunctionSchema, index: int) -> str: + assert not schema.is_out_fn() + schema_name = schema.name.name.base + arg_map = {} + for arg in schema.schema_order_arguments(): + test_value_exp = test_value_expression(arg.type, index, schema_name) + arg_map[arg.name] = test_value_exp + config.override_test_values(arg_map, schema_name, index) + arg_populations = [] + for arg_name, arg_value in arg_map.items(): + arg_populations.append(f"auto {arg_name}{index} = {arg_value}") + return ";\n ".join(arg_populations) + ";" + + +def generate_test_value_names(schema: FunctionSchema, index: int) -> str: + assert not schema.is_out_fn() + return ",".join(f"{arg.name}{index}" for arg in schema.schema_order_arguments()) + + +generate_test_ir_arguments_base_ty_to_type_str_ = { + BaseTy.Tensor: "Tensor", + BaseTy.int: "int", + BaseTy.float: "float", + BaseTy.str: "str", + BaseTy.Scalar: "int", + BaseTy.ScalarType: "int", + BaseTy.bool: "bool", +} + + +def generate_test_ir_arguments( + schema: FunctionSchema, +) -> list[tuple[str, str | None]]: + def ir_argument(arg: Argument) -> tuple[str, str | None]: + t = arg.type + add_optional = False + if isinstance(t, OptionalType): + t = t.elem + add_optional = True + assert isinstance(t, BaseType) + type_str = None + if t.name in generate_test_ir_arguments_base_ty_to_type_str_: + type_str = generate_test_ir_arguments_base_ty_to_type_str_[t.name] + if type_str and add_optional: + type_str = f"{type_str}?" 
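+        # Descriptive note: IR arguments are rendered as "%<arg name>"; optional
+        # arguments get a trailing "?" on their type string (e.g. "Tensor?"), and
+        # base types missing from the map above fall back to an untyped argument
+        # (type_str stays None).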
+ return ("%" + arg.name, type_str) + + return [ir_argument(arg) for arg in schema.schema_order_arguments()] + + +def generate_arg_extraction(schema: FunctionSchema) -> str: + arg_populations = [] + for i, arg in enumerate(schema.schema_order_arguments()): + maybe_method = ivalue_type_conversion_method(arg.type) + assert maybe_method + is_reference, type_conversion_method = maybe_method + reference = "&" if is_reference else "" + arg_populations.append( + f"const auto{reference} {arg.name} = p_node->Input({i}).{type_conversion_method}" + ) + return ";\n ".join(arg_populations) + ";" + + +def get_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str: + kernel = backend_index.get_kernel(g.functional) + if g.structured or kernel is None: + return cpp.name(g.functional.func) + return kernel.kernel + + +def get_out_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str: + kernel = backend_index.get_kernel(g.out) + if g.structured or kernel is None: + return cpp.name(g.out.func) + return kernel.kernel + + +def generate_non_out_variant_call( + g: NativeFunctionsGroup, backend_index: BackendIndex +) -> str: + schema = g.functional.func + assert not schema.is_out_fn() + kernel_name = get_kernel_name(g, backend_index) + arg_names = (arg.name for arg in schema.schema_order_arguments()) + namespace_name = "cpu" if g.structured else "native" + return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})' + + +def generate_call_to_view_ops( + g: NativeFunctionsViewGroup, backend_index: BackendIndex +) -> str: + schema = g.view.func + kernel_name = cpp.name(schema) + kernel = backend_index.get_kernel(g.view) + if kernel: + kernel_name = kernel.kernel + arg_names = (arg.name for arg in schema.schema_order_arguments()) + namespace_name = "native" + return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})' + + +def generate_out_variant_call( + g: NativeFunctionsGroup, backend_index: BackendIndex +) -> str: + schema = g.out.func + assert schema.is_out_fn() + arg_names = [] + kernel_name = get_out_kernel_name(g, backend_index) + if g.structured: + # structured op starts with the output tensor argument. 
+ arg_names = [out_arg.name for out_arg in schema.arguments.out] + else: + arg_names = [] + for arg in schema.arguments.non_out: + if isinstance(arg, SelfArgument): + arg_names.append(arg.argument.name) + else: + assert isinstance(arg, Argument) + arg_names.append(arg.name) + if not g.structured: + assert len(schema.arguments.out) == 1 + arg_names.append(schema.arguments.out[0].name) + cpp_arg_names = ",".join(arg_names) + namespace_name = "cpu" if g.structured else "native" + return f"at::{namespace_name}::{kernel_name}({cpp_arg_names})" + + +no_memory_resize_ops = frozenset( + ( + "isin.Scalar_Tensor", + "index_add", + "dot", + "vdot", + "nuclear_norm", + "histc", + "l1_loss", + "multi_margin_loss", + "multilabel_margin_loss", + "nll_loss", + "nll_loss2d", + "prod", + ) +) + + +def should_check_resize(schema: FunctionSchema) -> bool: + schema_str = str(schema) + type_variant_op_name = schema_str[: schema_str.find("(")] + return type_variant_op_name not in no_memory_resize_ops + + +def op_name_from_group(g: NativeFunctionsGroup) -> str: + return g.functional.func.name.name.base + + +class GenOpDispatcher: + def out_variant( + self, groups: Sequence[NativeFunctionsGroup], backend_index: BackendIndex + ) -> str: + if not groups: + return "" + generated_type_variants = [] + for g in groups: + with native_function_manager(g): + assert is_supported(g) + assert isinstance(g, NativeFunctionsGroup) + generated_type_variant = self.out_variant_op_generator(g, backend_index) + generated_type_variants.append(generated_type_variant) + op_name = op_name_from_group(groups[0]) + body = "\n".join(generated_type_variants) + generated = f""" +REGISTER_OPERATOR_FUNCTOR( + aten::{op_name}, + aten_{op_name}, + [](Node* n) -> SROperator {{ + {body} + LogAndDumpSchema(n); + return nullptr; + }}); +""" + return generated + + def view( + self, groups: Sequence[NativeFunctionsViewGroup], backend_index: BackendIndex + ) -> str: + if not groups: + return "" + generated_type_variants = [] + for g in groups: + with native_function_manager(g): + assert is_supported(g) + assert isinstance(g, NativeFunctionsViewGroup) + generated_type_variant = self.view_op_generator(g, backend_index) + generated_type_variants.append(generated_type_variant) + op_name = config.func_name_base_str(groups[0]) + body = "\n".join(generated_type_variants) + generated = f""" +REGISTER_NATIVE_OPERATOR_FUNCTOR( + aten::{op_name}, + aten_{op_name}, + [](Node* n) -> SROperator {{ + {body} + LogAndDumpSchema(n); + return nullptr; + }}); +""" + return generated + + def out_variant_op_generator( + self, g: NativeFunctionsGroup, backend_index: BackendIndex + ) -> str: + functional = g.functional + schema = str(functional.func) + populated_argument = generate_arg_extraction(g.functional.func) + functional_variant_call = generate_non_out_variant_call(g, backend_index) + assert len(g.out.func.arguments.out) == 1 + out_variable_name = str(g.out.func.arguments.out[0].name) + out_variant_call = generate_out_variant_call(g, backend_index) + generated = f""" + if (n->matches(torch::schema("aten::{schema}"))) {{ + return [](ProcessedNode* p_node) {{ + {populated_argument} + if (p_node->Output(0).isNone()) {{ + p_node->Output(0) = {functional_variant_call}; + return; + }} + auto& {out_variable_name} = p_node->Output(0).toTensor(); + fastResizeToZero({out_variable_name}); + {out_variant_call}; + }}; + }}""" + return generated + + def view_op_generator( + self, g: NativeFunctionsViewGroup, backend_index: BackendIndex + ) -> str: + schema = str(g.view.func) + 
populated_argument = generate_arg_extraction(g.view.func) + functional_variant_call = generate_call_to_view_ops(g, backend_index) + generated = f""" + if (n->matches(torch::schema("aten::{schema}"))) {{ + return [](ProcessedNode* p_node) {{ + {populated_argument} + p_node->Output(0) = {functional_variant_call}; + }}; + }}""" + return generated + + +class GenOpTestCase: + def out_variant(self, groups: Sequence[NativeFunctionsGroup]) -> str: + if not groups: + return "" + generated_type_variants = [] + for g in groups: + with native_function_manager(g): + assert is_supported(g) + assert isinstance(g, NativeFunctionsGroup) + generated_type_variant = self.out_variant_op_test_case_generator(g) + generated_type_variants.append(generated_type_variant) + return "\n".join(generated_type_variants) + + def view(self, groups: Sequence[NativeFunctionsViewGroup]) -> str: + if not groups: + return "" + generated_type_variants = [] + for g in groups: + with native_function_manager(g): + assert is_supported(g) + assert isinstance(g, NativeFunctionsViewGroup) + generated_type_variant = self.view_op_test_case_generator(g) + generated_type_variants.append(generated_type_variant) + return "\n".join(generated_type_variants) + + def out_variant_op_test_case_generator(self, g: NativeFunctionsGroup) -> str: + schema = g.functional.func + schema_str = str(schema) + assert schema_str.find("(") > 0 + type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_") + op_name = op_name_from_group(g) + assert type_variant_op_name.startswith(op_name) + + arg_types = generate_test_ir_arguments(schema) + arg_declarations = ", ".join( + ( + arg_name if arg_type is None else f"{arg_name}: {arg_type}" + for arg_name, arg_type in arg_types + ) + ) + arg_names = ", ".join((arg_name for arg_name, _ in arg_types)) + assert ( + len(schema.returns) == 1 + and isinstance(schema.returns[0].type, BaseType) + and schema.returns[0].type.name is BaseTy.Tensor + ) + test_value_definitions = generate_test_value_definitions(schema, 0) + test_value_names = generate_test_value_names(schema, 0) + test_value_definitions2 = generate_test_value_definitions(schema, 1) + test_value_names2 = generate_test_value_names(schema, 1) + check_resize = "true" if should_check_resize(schema) else "false" + generated = f""" +TEST(StaticRuntime, autogen_{type_variant_op_name}) {{ + const std::string script = R"IR( + graph({arg_declarations}): + %bias: None = prim::Constant() + %ret = aten::{op_name}({arg_names}) + %cloned = aten::clone(%ret, %bias) + return (%cloned) + )IR"; + + {test_value_definitions} + std::vector args{{{test_value_names}}}; + testStaticRuntime(script, args, {{}}, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize}); + + {test_value_definitions2} + std::vector args2{{{test_value_names2}}}; + testStaticRuntime(script, args, args2, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize}); + +}} +""" + return generated + + def view_op_test_case_generator(self, g: NativeFunctionsViewGroup) -> str: + schema = g.view.func + schema_str = str(schema) + assert schema_str.find("(") > 0 + type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_") + op_name = g.view.root_name + assert type_variant_op_name.startswith(op_name) + + arg_types = generate_test_ir_arguments(schema) + arg_declarations = ", ".join( + ( + arg_name if arg_type is None else f"{arg_name}: {arg_type}" + for arg_name, arg_type in arg_types + ) + ) + arg_names = ", ".join((arg_name for arg_name, _ in 
arg_types)) + assert ( + len(schema.returns) == 1 + and isinstance(schema.returns[0].type, BaseType) + and schema.returns[0].type.name is BaseTy.Tensor + ) + test_value_definitions = generate_test_value_definitions(schema, 0) + test_value_names = generate_test_value_names(schema, 0) + generated = f""" +TEST(StaticRuntime, autogen_{type_variant_op_name}) {{ + const std::string script = R"IR( + graph({arg_declarations}): + %bias: None = prim::Constant() + %ret = aten::{op_name}({arg_names}) + %cloned = aten::clone(%ret, %bias) + return (%cloned) + )IR"; + + {test_value_definitions} + std::vector args{{{test_value_names}}}; + testStaticRuntime(script, args); +}} +""" + + return generated diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f0f88a14473da9f1b409ed089fdb62ea37f85b3 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/_optical_flow.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/_optical_flow.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38365f061c615a85d54f0a6d60d94dcb25b195eb Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/_optical_flow.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/_stereo_matching.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/_stereo_matching.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ac9adadd03084f9147e513996380fc8aaea58ce Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/_stereo_matching.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/caltech.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/caltech.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5394924444e34d1787a051891ea1ddc641b31e13 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/caltech.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/celeba.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/celeba.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2077b0cce4ebd803acfe409972d29c6d4cc23eb7 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/celeba.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/cifar.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/cifar.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6015b25d59936350751c2f84c1bf6169a422ea7 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/cifar.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/cityscapes.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/cityscapes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca17a8cf6f47336717d9e5fd64ccb68aa3e0e2ea 
Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/cityscapes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/clevr.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/clevr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34d30e37cdc1305077f9043cb2a5989f572a84d6 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/clevr.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/coco.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/coco.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e417705b0bcc630b62b7b270a481c466a1e8f8a Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/coco.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/country211.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/country211.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..192ddcfcf133e15a2b65ea3ff33a4c7069137799 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/country211.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/dtd.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/dtd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df66b64a2ae18fcca949d27bb2e513070e1d2450 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/dtd.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/eurosat.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/eurosat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b2a5e2638bdabd93cd36d43863c8411c925c952 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/eurosat.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/fakedata.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/fakedata.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32d6e912c9515a15e938430176db968a972861eb Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/fakedata.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/fer2013.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/fer2013.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e755e48d5e9e1047b96f90ca461b48d81db5d56c Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/fer2013.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/fgvc_aircraft.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/fgvc_aircraft.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02b32d065ea1d331fbeafc0811ec3bc8bdbf8891 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/fgvc_aircraft.cpython-310.pyc differ diff --git 
a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/flickr.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/flickr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35783cfba976c4a09e1ae1c23b43d2cd3c8f029d Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/flickr.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/flowers102.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/flowers102.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebf57d027be68c983cab3b79127742be897ceb5e Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/flowers102.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/folder.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/folder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad31732e215a8ce944fa1aeb63e34699bf4e9315 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/folder.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/food101.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/food101.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6cbda4e48ee6bc9038f9253f78e5ded0e356eee Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/food101.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/gtsrb.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/gtsrb.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79bcf7dec2061a0df5bfc91e254a46dbc6aebbf8 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/gtsrb.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/hmdb51.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/hmdb51.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..618061706fb8abf337e525a5db528d0b0877acf5 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/hmdb51.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/imagenet.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/imagenet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6f285d43efaacc038228864ba6bf2f452370e5c Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/imagenet.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/rendered_sst2.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/rendered_sst2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b21b549c47f049353fd92df5ec478be4a8f1b02 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/rendered_sst2.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/sbd.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/sbd.cpython-310.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..c323fbf128f26ca64224beed56b56203277b8471 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/sbd.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/sbu.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/sbu.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..639f437a59c7c9fec932d6455a07fafb88934c81 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/sbu.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/semeion.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/semeion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0333d7859cfab8af6ca6a69b22c05d67cfa4f3b2 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/semeion.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/stanford_cars.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/stanford_cars.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca2c09ed878fe9de67aacea55fba10ead9cd9519 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/stanford_cars.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/stl10.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/stl10.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef457083034003377a6a2ac3b88a2066a9f7615a Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/stl10.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/sun397.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/sun397.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8c703cdf75947fd1f9a1434a28f3dcd60fdb7f4 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/sun397.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/svhn.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/svhn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8cb02d32e7ea8f745fc9e7864689283e526105f Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/svhn.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/ucf101.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/ucf101.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02d8a26a0d353d47d572a5385ab72845b3ce7d64 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/ucf101.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/usps.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/usps.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..919fca8f23b4ad5df78b62a6328f5efa5c46a9cf Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/usps.cpython-310.pyc differ diff --git 
a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/utils.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8816cb7c2e92656337562b0b9f012d7bd6bb78cb Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/utils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/video_utils.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/video_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e191de3ba1d716e258b6e5c6ef6544868ef7f7fb Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/video_utils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/vision.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/vision.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1fb105811c541ed25aa8fbf603c63db40d34eb9 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/vision.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/voc.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/voc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc28f1e7e9bde961ed6f903408a2f6f6dd0a6e7c Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/voc.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/__pycache__/widerface.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/widerface.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb719a3d1c44034b1b8a4e71892eb3f513d38d2a Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/__pycache__/widerface.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/samplers/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/samplers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7149a67900e65cf1f386ce3357c87e1a7a590fb4 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/samplers/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/datasets/samplers/__pycache__/clip_sampler.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/datasets/samplers/__pycache__/clip_sampler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a06ec5f5905624afbb7a787f70255d7251b5ff6 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/datasets/samplers/__pycache__/clip_sampler.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b31fa241a6febddcc9f1c2bb90830bdd8723739 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/__pycache__/_api.cpython-310.pyc 
b/lib/python3.10/site-packages/torchvision/models/__pycache__/_api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40dce79299f57c7fa56a601d71f1845a3f9aecd7 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/__pycache__/_api.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/__pycache__/_meta.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/__pycache__/_meta.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7830990c5697f3c37d5fbc16d168e0fe95adb69a Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/__pycache__/_meta.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/__pycache__/_utils.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/__pycache__/_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc52aaef9b27bd9281158f4ede45f08a91d3ff21 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/__pycache__/_utils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/__pycache__/alexnet.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/__pycache__/alexnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..996cab3498d08e9f3520a27764b8d1873dce427f Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/__pycache__/alexnet.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/__pycache__/convnext.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/__pycache__/convnext.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1023d456caa66928cc2ff771b6f3287480de1cde Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/__pycache__/convnext.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/__pycache__/densenet.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/__pycache__/densenet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..914f888f40db2304f1a0c2b30f65c3d53f5d277b Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/__pycache__/densenet.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/__pycache__/efficientnet.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/__pycache__/efficientnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6830e937ebf835c0d313bd0184c48876f7142024 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/__pycache__/efficientnet.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/__pycache__/feature_extraction.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/__pycache__/feature_extraction.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7554fea50c82791933c3b77f19f5b739ecbd907 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/__pycache__/feature_extraction.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/__pycache__/googlenet.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/__pycache__/googlenet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1853d735e4f7cb14bc2686b383082b2181eb056b Binary 
files /dev/null and b/lib/python3.10/site-packages/torchvision/models/__pycache__/googlenet.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/__pycache__/inception.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/__pycache__/inception.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a48c8cd859c7177888f20e98bb8876c878b78c8 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/__pycache__/inception.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/__pycache__/maxvit.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/__pycache__/maxvit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..743823cc0cc414a00cf1c9523b0163f7dc2c6597 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/__pycache__/maxvit.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/__pycache__/mnasnet.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/__pycache__/mnasnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c78cefa48ce349b77a081e039cde77578488c873 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/__pycache__/mnasnet.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/__pycache__/mobilenet.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/__pycache__/mobilenet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..110086f6f6d4c2db7aecf0d295a061bb7ab0bd55 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/__pycache__/mobilenet.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/__pycache__/mobilenetv2.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/__pycache__/mobilenetv2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec3fed2ce2921c76abc5fda7df3b60efe4cbfc67 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/__pycache__/mobilenetv2.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/__pycache__/mobilenetv3.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/__pycache__/mobilenetv3.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3b822e0b0b60f33fb330aaf4483f59103647ed5 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/__pycache__/mobilenetv3.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/__pycache__/regnet.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/__pycache__/regnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57b2469a66872c6314d2b144aad4e9d98686e0b3 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/__pycache__/regnet.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/__pycache__/resnet.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/__pycache__/resnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..170f9eaabd30f5cecdb271401487b93d19b8aafa Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/__pycache__/resnet.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/__pycache__/shufflenetv2.cpython-310.pyc 
b/lib/python3.10/site-packages/torchvision/models/__pycache__/shufflenetv2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b544b5b491b0f73ab7064c0cd8176c154673f9e Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/__pycache__/shufflenetv2.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/__pycache__/squeezenet.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/__pycache__/squeezenet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ae5cdcaf88894022ed80ea854d709876e00d4de Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/__pycache__/squeezenet.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/__pycache__/swin_transformer.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/__pycache__/swin_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec59c35e34c00fa161c9f433a128ea6f124f31c3 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/__pycache__/swin_transformer.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/__pycache__/vgg.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/__pycache__/vgg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a5a29da30f42677dcbb01ffdd722600855322aa Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/__pycache__/vgg.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/__pycache__/vision_transformer.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/__pycache__/vision_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..401abe1c99cfa8e31d7478faf28ade25a348d10a Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/__pycache__/vision_transformer.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/detection/__init__.py b/lib/python3.10/site-packages/torchvision/models/detection/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4146651c737971cc5a883b6750f2ded3051bc8ea --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/detection/__init__.py @@ -0,0 +1,7 @@ +from .faster_rcnn import * +from .fcos import * +from .keypoint_rcnn import * +from .mask_rcnn import * +from .retinanet import * +from .ssd import * +from .ssdlite import * diff --git a/lib/python3.10/site-packages/torchvision/models/detection/_utils.py b/lib/python3.10/site-packages/torchvision/models/detection/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..559db858ac32f3b9f157aff3c22da83abece2a73 --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/detection/_utils.py @@ -0,0 +1,540 @@ +import math +from collections import OrderedDict +from typing import Dict, List, Optional, Tuple + +import torch +from torch import nn, Tensor +from torch.nn import functional as F +from torchvision.ops import complete_box_iou_loss, distance_box_iou_loss, FrozenBatchNorm2d, generalized_box_iou_loss + + +class BalancedPositiveNegativeSampler: + """ + This class samples batches, ensuring that they contain a fixed proportion of positives + """ + + def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None: + """ + Args: + batch_size_per_image (int): number of elements to be 
selected per image + positive_fraction (float): percentage of positive elements per batch + """ + self.batch_size_per_image = batch_size_per_image + self.positive_fraction = positive_fraction + + def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]: + """ + Args: + matched_idxs: list of tensors containing -1, 0 or positive values. + Each tensor corresponds to a specific image. + -1 values are ignored, 0 are considered as negatives and > 0 as + positives. + + Returns: + pos_idx (list[tensor]) + neg_idx (list[tensor]) + + Returns two lists of binary masks for each image. + The first list contains the positive elements that were selected, + and the second list the negative example. + """ + pos_idx = [] + neg_idx = [] + for matched_idxs_per_image in matched_idxs: + positive = torch.where(matched_idxs_per_image >= 1)[0] + negative = torch.where(matched_idxs_per_image == 0)[0] + + num_pos = int(self.batch_size_per_image * self.positive_fraction) + # protect against not enough positive examples + num_pos = min(positive.numel(), num_pos) + num_neg = self.batch_size_per_image - num_pos + # protect against not enough negative examples + num_neg = min(negative.numel(), num_neg) + + # randomly select positive and negative examples + perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos] + perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg] + + pos_idx_per_image = positive[perm1] + neg_idx_per_image = negative[perm2] + + # create binary mask from indices + pos_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8) + neg_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8) + + pos_idx_per_image_mask[pos_idx_per_image] = 1 + neg_idx_per_image_mask[neg_idx_per_image] = 1 + + pos_idx.append(pos_idx_per_image_mask) + neg_idx.append(neg_idx_per_image_mask) + + return pos_idx, neg_idx + + +@torch.jit._script_if_tracing +def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor: + """ + Encode a set of proposals with respect to some + reference boxes + + Args: + reference_boxes (Tensor): reference boxes + proposals (Tensor): boxes to be encoded + weights (Tensor[4]): the weights for ``(x, y, w, h)`` + """ + + # perform some unpacking to make it JIT-fusion friendly + wx = weights[0] + wy = weights[1] + ww = weights[2] + wh = weights[3] + + proposals_x1 = proposals[:, 0].unsqueeze(1) + proposals_y1 = proposals[:, 1].unsqueeze(1) + proposals_x2 = proposals[:, 2].unsqueeze(1) + proposals_y2 = proposals[:, 3].unsqueeze(1) + + reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1) + reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1) + reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1) + reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1) + + # implementation starts here + ex_widths = proposals_x2 - proposals_x1 + ex_heights = proposals_y2 - proposals_y1 + ex_ctr_x = proposals_x1 + 0.5 * ex_widths + ex_ctr_y = proposals_y1 + 0.5 * ex_heights + + gt_widths = reference_boxes_x2 - reference_boxes_x1 + gt_heights = reference_boxes_y2 - reference_boxes_y1 + gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths + gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights + + targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths + targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights + targets_dw = ww * torch.log(gt_widths / ex_widths) + targets_dh = wh * torch.log(gt_heights / ex_heights) + + targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1) + 
return targets + + +class BoxCoder: + """ + This class encodes and decodes a set of bounding boxes into + the representation used for training the regressors. + """ + + def __init__( + self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16) + ) -> None: + """ + Args: + weights (4-element tuple) + bbox_xform_clip (float) + """ + self.weights = weights + self.bbox_xform_clip = bbox_xform_clip + + def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]: + boxes_per_image = [len(b) for b in reference_boxes] + reference_boxes = torch.cat(reference_boxes, dim=0) + proposals = torch.cat(proposals, dim=0) + targets = self.encode_single(reference_boxes, proposals) + return targets.split(boxes_per_image, 0) + + def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor: + """ + Encode a set of proposals with respect to some + reference boxes + + Args: + reference_boxes (Tensor): reference boxes + proposals (Tensor): boxes to be encoded + """ + dtype = reference_boxes.dtype + device = reference_boxes.device + weights = torch.as_tensor(self.weights, dtype=dtype, device=device) + targets = encode_boxes(reference_boxes, proposals, weights) + + return targets + + def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor: + torch._assert( + isinstance(boxes, (list, tuple)), + "This function expects boxes of type list or tuple.", + ) + torch._assert( + isinstance(rel_codes, torch.Tensor), + "This function expects rel_codes of type torch.Tensor.", + ) + boxes_per_image = [b.size(0) for b in boxes] + concat_boxes = torch.cat(boxes, dim=0) + box_sum = 0 + for val in boxes_per_image: + box_sum += val + if box_sum > 0: + rel_codes = rel_codes.reshape(box_sum, -1) + pred_boxes = self.decode_single(rel_codes, concat_boxes) + if box_sum > 0: + pred_boxes = pred_boxes.reshape(box_sum, -1, 4) + return pred_boxes + + def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor: + """ + From a set of original boxes and encoded relative box offsets, + get the decoded boxes. + + Args: + rel_codes (Tensor): encoded boxes + boxes (Tensor): reference boxes. + """ + + boxes = boxes.to(rel_codes.dtype) + + widths = boxes[:, 2] - boxes[:, 0] + heights = boxes[:, 3] - boxes[:, 1] + ctr_x = boxes[:, 0] + 0.5 * widths + ctr_y = boxes[:, 1] + 0.5 * heights + + wx, wy, ww, wh = self.weights + dx = rel_codes[:, 0::4] / wx + dy = rel_codes[:, 1::4] / wy + dw = rel_codes[:, 2::4] / ww + dh = rel_codes[:, 3::4] / wh + + # Prevent sending too large values into torch.exp() + dw = torch.clamp(dw, max=self.bbox_xform_clip) + dh = torch.clamp(dh, max=self.bbox_xform_clip) + + pred_ctr_x = dx * widths[:, None] + ctr_x[:, None] + pred_ctr_y = dy * heights[:, None] + ctr_y[:, None] + pred_w = torch.exp(dw) * widths[:, None] + pred_h = torch.exp(dh) * heights[:, None] + + # Distance from center to box's corner. + c_to_c_h = torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h + c_to_c_w = torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w + + pred_boxes1 = pred_ctr_x - c_to_c_w + pred_boxes2 = pred_ctr_y - c_to_c_h + pred_boxes3 = pred_ctr_x + c_to_c_w + pred_boxes4 = pred_ctr_y + c_to_c_h + pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1) + return pred_boxes + + +class BoxLinearCoder: + """ + The linear box-to-box transform defined in FCOS. 
The transformation is parameterized + by the distance from the center of (square) src box to 4 edges of the target box. + """ + + def __init__(self, normalize_by_size: bool = True) -> None: + """ + Args: + normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes. + """ + self.normalize_by_size = normalize_by_size + + def encode(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor: + """ + Encode a set of proposals with respect to some reference boxes + + Args: + reference_boxes (Tensor): reference boxes + proposals (Tensor): boxes to be encoded + + Returns: + Tensor: the encoded relative box offsets that can be used to + decode the boxes. + + """ + + # get the center of reference_boxes + reference_boxes_ctr_x = 0.5 * (reference_boxes[..., 0] + reference_boxes[..., 2]) + reference_boxes_ctr_y = 0.5 * (reference_boxes[..., 1] + reference_boxes[..., 3]) + + # get box regression transformation deltas + target_l = reference_boxes_ctr_x - proposals[..., 0] + target_t = reference_boxes_ctr_y - proposals[..., 1] + target_r = proposals[..., 2] - reference_boxes_ctr_x + target_b = proposals[..., 3] - reference_boxes_ctr_y + + targets = torch.stack((target_l, target_t, target_r, target_b), dim=-1) + + if self.normalize_by_size: + reference_boxes_w = reference_boxes[..., 2] - reference_boxes[..., 0] + reference_boxes_h = reference_boxes[..., 3] - reference_boxes[..., 1] + reference_boxes_size = torch.stack( + (reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=-1 + ) + targets = targets / reference_boxes_size + return targets + + def decode(self, rel_codes: Tensor, boxes: Tensor) -> Tensor: + + """ + From a set of original boxes and encoded relative box offsets, + get the decoded boxes. + + Args: + rel_codes (Tensor): encoded boxes + boxes (Tensor): reference boxes. + + Returns: + Tensor: the predicted boxes with the encoded relative box offsets. + + .. note:: + This method assumes that ``rel_codes`` and ``boxes`` have same size for 0th dimension. i.e. ``len(rel_codes) == len(boxes)``. + + """ + + boxes = boxes.to(dtype=rel_codes.dtype) + + ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2]) + ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3]) + + if self.normalize_by_size: + boxes_w = boxes[..., 2] - boxes[..., 0] + boxes_h = boxes[..., 3] - boxes[..., 1] + + list_box_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=-1) + rel_codes = rel_codes * list_box_size + + pred_boxes1 = ctr_x - rel_codes[..., 0] + pred_boxes2 = ctr_y - rel_codes[..., 1] + pred_boxes3 = ctr_x + rel_codes[..., 2] + pred_boxes4 = ctr_y + rel_codes[..., 3] + + pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=-1) + return pred_boxes + + +class Matcher: + """ + This class assigns to each predicted "element" (e.g., a box) a ground-truth + element. Each predicted element will have exactly zero or one matches; each + ground-truth element may be assigned to zero or more predicted elements. + + Matching is based on the MxN match_quality_matrix, that characterizes how well + each (ground-truth, predicted)-pair match. For example, if the elements are + boxes, the matrix may contain box IoU overlap values. + + The matcher returns a tensor of size N containing the index of the ground-truth + element m that matches to prediction n. If there is no match, a negative value + is returned. 
+ """ + + BELOW_LOW_THRESHOLD = -1 + BETWEEN_THRESHOLDS = -2 + + __annotations__ = { + "BELOW_LOW_THRESHOLD": int, + "BETWEEN_THRESHOLDS": int, + } + + def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None: + """ + Args: + high_threshold (float): quality values greater than or equal to + this value are candidate matches. + low_threshold (float): a lower quality threshold used to stratify + matches into three levels: + 1) matches >= high_threshold + 2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold) + 3) BELOW_LOW_THRESHOLD matches in [0, low_threshold) + allow_low_quality_matches (bool): if True, produce additional matches + for predictions that have only low-quality match candidates. See + set_low_quality_matches_ for more details. + """ + self.BELOW_LOW_THRESHOLD = -1 + self.BETWEEN_THRESHOLDS = -2 + torch._assert(low_threshold <= high_threshold, "low_threshold should be <= high_threshold") + self.high_threshold = high_threshold + self.low_threshold = low_threshold + self.allow_low_quality_matches = allow_low_quality_matches + + def __call__(self, match_quality_matrix: Tensor) -> Tensor: + """ + Args: + match_quality_matrix (Tensor[float]): an MxN tensor, containing the + pairwise quality between M ground-truth elements and N predicted elements. + + Returns: + matches (Tensor[int64]): an N tensor where N[i] is a matched gt in + [0, M - 1] or a negative value indicating that prediction i could not + be matched. + """ + if match_quality_matrix.numel() == 0: + # empty targets or proposals not supported during training + if match_quality_matrix.shape[0] == 0: + raise ValueError("No ground-truth boxes available for one of the images during training") + else: + raise ValueError("No proposal boxes available for one of the images during training") + + # match_quality_matrix is M (gt) x N (predicted) + # Max over gt elements (dim 0) to find best gt candidate for each prediction + matched_vals, matches = match_quality_matrix.max(dim=0) + if self.allow_low_quality_matches: + all_matches = matches.clone() + else: + all_matches = None # type: ignore[assignment] + + # Assign candidate matches with low quality to negative (unassigned) values + below_low_threshold = matched_vals < self.low_threshold + between_thresholds = (matched_vals >= self.low_threshold) & (matched_vals < self.high_threshold) + matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD + matches[between_thresholds] = self.BETWEEN_THRESHOLDS + + if self.allow_low_quality_matches: + if all_matches is None: + torch._assert(False, "all_matches should not be None") + else: + self.set_low_quality_matches_(matches, all_matches, match_quality_matrix) + + return matches + + def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None: + """ + Produce additional matches for predictions that have only low-quality matches. + Specifically, for each ground-truth find the set of predictions that have + maximum overlap with it (including ties); for each prediction in that set, if + it is unmatched, then match it to the ground-truth with which it has the highest + quality value. 
+ """ + # For each gt, find the prediction with which it has the highest quality + highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1) + # Find the highest quality match available, even if it is low, including ties + gt_pred_pairs_of_highest_quality = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None]) + # Example gt_pred_pairs_of_highest_quality: + # (tensor([0, 1, 1, 2, 2, 3, 3, 4, 5, 5]), + # tensor([39796, 32055, 32070, 39190, 40255, 40390, 41455, 45470, 45325, 46390])) + # Each element in the first tensor is a gt index, and each element in second tensor is a prediction index + # Note how gt items 1, 2, 3, and 5 each have two ties + + pred_inds_to_update = gt_pred_pairs_of_highest_quality[1] + matches[pred_inds_to_update] = all_matches[pred_inds_to_update] + + +class SSDMatcher(Matcher): + def __init__(self, threshold: float) -> None: + super().__init__(threshold, threshold, allow_low_quality_matches=False) + + def __call__(self, match_quality_matrix: Tensor) -> Tensor: + matches = super().__call__(match_quality_matrix) + + # For each gt, find the prediction with which it has the highest quality + _, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1) + matches[highest_quality_pred_foreach_gt] = torch.arange( + highest_quality_pred_foreach_gt.size(0), dtype=torch.int64, device=highest_quality_pred_foreach_gt.device + ) + + return matches + + +def overwrite_eps(model: nn.Module, eps: float) -> None: + """ + This method overwrites the default eps values of all the + FrozenBatchNorm2d layers of the model with the provided value. + This is necessary to address the BC-breaking change introduced + by the bug-fix at pytorch/vision#2933. The overwrite is applied + only when the pretrained weights are loaded to maintain compatibility + with previous versions. + + Args: + model (nn.Module): The model on which we perform the overwrite. + eps (float): The new value of eps. + """ + for module in model.modules(): + if isinstance(module, FrozenBatchNorm2d): + module.eps = eps + + +def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]: + """ + This method retrieves the number of output channels of a specific model. + + Args: + model (nn.Module): The model for which we estimate the out_channels. + It should return a single Tensor or an OrderedDict[Tensor]. + size (Tuple[int, int]): The size (wxh) of the input. + + Returns: + out_channels (List[int]): A list of the output channels of the model. + """ + in_training = model.training + model.eval() + + with torch.no_grad(): + # Use dummy data to retrieve the feature map sizes to avoid hard-coding their values + device = next(model.parameters()).device + tmp_img = torch.zeros((1, 3, size[1], size[0]), device=device) + features = model(tmp_img) + if isinstance(features, torch.Tensor): + features = OrderedDict([("0", features)]) + out_channels = [x.size(1) for x in features.values()] + + if in_training: + model.train() + + return out_channels + + +@torch.jit.unused +def _fake_cast_onnx(v: Tensor) -> int: + return v # type: ignore[return-value] + + +def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int: + """ + ONNX spec requires the k-value to be less than or equal to the number of inputs along + provided dim. Certain models use the number of elements along a particular axis instead of K + if K exceeds the number of elements along that axis. Previously, python's min() function was + used to determine whether to use the provided k-value or the specified dim axis value. 
+ + However, in cases where the model is being exported in tracing mode, python min() is + static causing the model to be traced incorrectly and eventually fail at the topk node. + In order to avoid this situation, in tracing mode, torch.min() is used instead. + + Args: + input (Tensor): The original input tensor. + orig_kval (int): The provided k-value. + axis(int): Axis along which we retrieve the input size. + + Returns: + min_kval (int): Appropriately selected k-value. + """ + if not torch.jit.is_tracing(): + return min(orig_kval, input.size(axis)) + axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0) + min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0)) + return _fake_cast_onnx(min_kval) + + +def _box_loss( + type: str, + box_coder: BoxCoder, + anchors_per_image: Tensor, + matched_gt_boxes_per_image: Tensor, + bbox_regression_per_image: Tensor, + cnf: Optional[Dict[str, float]] = None, +) -> Tensor: + torch._assert(type in ["l1", "smooth_l1", "ciou", "diou", "giou"], f"Unsupported loss: {type}") + + if type == "l1": + target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image) + return F.l1_loss(bbox_regression_per_image, target_regression, reduction="sum") + elif type == "smooth_l1": + target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image) + beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0 + return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta) + else: + bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image) + eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7 + if type == "ciou": + return complete_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps) + if type == "diou": + return distance_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps) + # otherwise giou + return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps) diff --git a/lib/python3.10/site-packages/torchvision/models/detection/anchor_utils.py b/lib/python3.10/site-packages/torchvision/models/detection/anchor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..253f6502a9b6344f5a3da239f2394179a256424e --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/detection/anchor_utils.py @@ -0,0 +1,268 @@ +import math +from typing import List, Optional + +import torch +from torch import nn, Tensor + +from .image_list import ImageList + + +class AnchorGenerator(nn.Module): + """ + Module that generates anchors for a set of feature maps and + image sizes. + + The module support computing anchors at multiple sizes and aspect ratios + per feature map. This module assumes aspect ratio = height / width for + each anchor. + + sizes and aspect_ratios should have the same number of elements, and it should + correspond to the number of feature maps. + + sizes[i] and aspect_ratios[i] can have an arbitrary number of elements, + and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors + per spatial location for feature map i. 
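+ For example, sizes=((32,), (64,), (128,)) together with aspect_ratios=((0.5, 1.0, 2.0),) * 3 yields 1 * 3 = 3 anchors + per spatial location on each of the three feature maps (an illustrative configuration, not one required by the module).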
+ + Args: + sizes (Tuple[Tuple[int]]): + aspect_ratios (Tuple[Tuple[float]]): + """ + + __annotations__ = { + "cell_anchors": List[torch.Tensor], + } + + def __init__( + self, + sizes=((128, 256, 512),), + aspect_ratios=((0.5, 1.0, 2.0),), + ): + super().__init__() + + if not isinstance(sizes[0], (list, tuple)): + # TODO change this + sizes = tuple((s,) for s in sizes) + if not isinstance(aspect_ratios[0], (list, tuple)): + aspect_ratios = (aspect_ratios,) * len(sizes) + + self.sizes = sizes + self.aspect_ratios = aspect_ratios + self.cell_anchors = [ + self.generate_anchors(size, aspect_ratio) for size, aspect_ratio in zip(sizes, aspect_ratios) + ] + + # TODO: https://github.com/pytorch/pytorch/issues/26792 + # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values. + # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios) + # This method assumes aspect ratio = height / width for an anchor. + def generate_anchors( + self, + scales: List[int], + aspect_ratios: List[float], + dtype: torch.dtype = torch.float32, + device: torch.device = torch.device("cpu"), + ) -> Tensor: + scales = torch.as_tensor(scales, dtype=dtype, device=device) + aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device) + h_ratios = torch.sqrt(aspect_ratios) + w_ratios = 1 / h_ratios + + ws = (w_ratios[:, None] * scales[None, :]).view(-1) + hs = (h_ratios[:, None] * scales[None, :]).view(-1) + + base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2 + return base_anchors.round() + + def set_cell_anchors(self, dtype: torch.dtype, device: torch.device): + self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors] + + def num_anchors_per_location(self) -> List[int]: + return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)] + + # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2), + # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a. + def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]: + anchors = [] + cell_anchors = self.cell_anchors + torch._assert(cell_anchors is not None, "cell_anchors should not be None") + torch._assert( + len(grid_sizes) == len(strides) == len(cell_anchors), + "Anchors should be Tuple[Tuple[int]] because each feature " + "map could potentially have different sizes and aspect ratios. " + "There needs to be a match between the number of " + "feature maps passed and the number of sizes / aspect ratios specified.", + ) + + for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors): + grid_height, grid_width = size + stride_height, stride_width = stride + device = base_anchors.device + + # For output anchor, compute [x_center, y_center, x_center, y_center] + shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device) * stride_width + shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device) * stride_height + shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij") + shift_x = shift_x.reshape(-1) + shift_y = shift_y.reshape(-1) + shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1) + + # For every (base anchor, output anchor) pair, + # offset each zero-centered base anchor by the center of the output anchor. 
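+ # shifts holds one (x, y, x, y) offset per spatial location (grid_height * grid_width rows) and base_anchors + # holds one zero-centered template per anchor (A rows); broadcasting the (-1, 1, 4) and (1, -1, 4) views pairs + # every location with every template, and the final reshape flattens the result to (grid_height * grid_width * A, 4).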
+ anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)) + + return anchors + + def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]: + grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps] + image_size = image_list.tensors.shape[-2:] + dtype, device = feature_maps[0].dtype, feature_maps[0].device + strides = [ + [ + torch.empty((), dtype=torch.int64, device=device).fill_(image_size[0] // g[0]), + torch.empty((), dtype=torch.int64, device=device).fill_(image_size[1] // g[1]), + ] + for g in grid_sizes + ] + self.set_cell_anchors(dtype, device) + anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides) + anchors: List[List[torch.Tensor]] = [] + for _ in range(len(image_list.image_sizes)): + anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps] + anchors.append(anchors_in_image) + anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors] + return anchors + + +class DefaultBoxGenerator(nn.Module): + """ + This module generates the default boxes of SSD for a set of feature maps and image sizes. + + Args: + aspect_ratios (List[List[int]]): A list with all the aspect ratios used in each feature map. + min_ratio (float): The minimum scale :math:`\text{s}_{\text{min}}` of the default boxes used in the estimation + of the scales of each feature map. It is used only if the ``scales`` parameter is not provided. + max_ratio (float): The maximum scale :math:`\text{s}_{\text{max}}` of the default boxes used in the estimation + of the scales of each feature map. It is used only if the ``scales`` parameter is not provided. + scales (List[float]], optional): The scales of the default boxes. If not provided it will be estimated using + the ``min_ratio`` and ``max_ratio`` parameters. + steps (List[int]], optional): It's a hyper-parameter that affects the tiling of default boxes. If not provided + it will be estimated from the data. + clip (bool): Whether the standardized values of default boxes should be clipped between 0 and 1. The clipping + is applied while the boxes are encoded in format ``(cx, cy, w, h)``. 
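+ For instance, with six feature maps and the default ``min_ratio=0.15`` / ``max_ratio=0.9``, the estimated scales are + ``[0.15, 0.3, 0.45, 0.6, 0.75, 0.9]`` with a trailing ``1.0`` appended, which is only used to compute the extra + ``s'_k`` box of the last feature map.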
+ """ + + def __init__( + self, + aspect_ratios: List[List[int]], + min_ratio: float = 0.15, + max_ratio: float = 0.9, + scales: Optional[List[float]] = None, + steps: Optional[List[int]] = None, + clip: bool = True, + ): + super().__init__() + if steps is not None and len(aspect_ratios) != len(steps): + raise ValueError("aspect_ratios and steps should have the same length") + self.aspect_ratios = aspect_ratios + self.steps = steps + self.clip = clip + num_outputs = len(aspect_ratios) + + # Estimation of default boxes scales + if scales is None: + if num_outputs > 1: + range_ratio = max_ratio - min_ratio + self.scales = [min_ratio + range_ratio * k / (num_outputs - 1.0) for k in range(num_outputs)] + self.scales.append(1.0) + else: + self.scales = [min_ratio, max_ratio] + else: + self.scales = scales + + self._wh_pairs = self._generate_wh_pairs(num_outputs) + + def _generate_wh_pairs( + self, num_outputs: int, dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cpu") + ) -> List[Tensor]: + _wh_pairs: List[Tensor] = [] + for k in range(num_outputs): + # Adding the 2 default width-height pairs for aspect ratio 1 and scale s'k + s_k = self.scales[k] + s_prime_k = math.sqrt(self.scales[k] * self.scales[k + 1]) + wh_pairs = [[s_k, s_k], [s_prime_k, s_prime_k]] + + # Adding 2 pairs for each aspect ratio of the feature map k + for ar in self.aspect_ratios[k]: + sq_ar = math.sqrt(ar) + w = self.scales[k] * sq_ar + h = self.scales[k] / sq_ar + wh_pairs.extend([[w, h], [h, w]]) + + _wh_pairs.append(torch.as_tensor(wh_pairs, dtype=dtype, device=device)) + return _wh_pairs + + def num_anchors_per_location(self) -> List[int]: + # Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feaure map. + return [2 + 2 * len(r) for r in self.aspect_ratios] + + # Default Boxes calculation based on page 6 of SSD paper + def _grid_default_boxes( + self, grid_sizes: List[List[int]], image_size: List[int], dtype: torch.dtype = torch.float32 + ) -> Tensor: + default_boxes = [] + for k, f_k in enumerate(grid_sizes): + # Now add the default boxes for each width-height pair + if self.steps is not None: + x_f_k = image_size[1] / self.steps[k] + y_f_k = image_size[0] / self.steps[k] + else: + y_f_k, x_f_k = f_k + + shifts_x = ((torch.arange(0, f_k[1]) + 0.5) / x_f_k).to(dtype=dtype) + shifts_y = ((torch.arange(0, f_k[0]) + 0.5) / y_f_k).to(dtype=dtype) + shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij") + shift_x = shift_x.reshape(-1) + shift_y = shift_y.reshape(-1) + + shifts = torch.stack((shift_x, shift_y) * len(self._wh_pairs[k]), dim=-1).reshape(-1, 2) + # Clipping the default boxes while the boxes are encoded in format (cx, cy, w, h) + _wh_pair = self._wh_pairs[k].clamp(min=0, max=1) if self.clip else self._wh_pairs[k] + wh_pairs = _wh_pair.repeat((f_k[0] * f_k[1]), 1) + + default_box = torch.cat((shifts, wh_pairs), dim=1) + + default_boxes.append(default_box) + + return torch.cat(default_boxes, dim=0) + + def __repr__(self) -> str: + s = ( + f"{self.__class__.__name__}(" + f"aspect_ratios={self.aspect_ratios}" + f", clip={self.clip}" + f", scales={self.scales}" + f", steps={self.steps}" + ")" + ) + return s + + def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]: + grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps] + image_size = image_list.tensors.shape[-2:] + dtype, device = feature_maps[0].dtype, feature_maps[0].device + default_boxes = self._grid_default_boxes(grid_sizes, image_size, 
dtype=dtype) + default_boxes = default_boxes.to(device) + + dboxes = [] + x_y_size = torch.tensor([image_size[1], image_size[0]], device=default_boxes.device) + for _ in image_list.image_sizes: + dboxes_in_image = default_boxes + dboxes_in_image = torch.cat( + [ + (dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:]) * x_y_size, + (dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]) * x_y_size, + ], + -1, + ) + dboxes.append(dboxes_in_image) + return dboxes diff --git a/lib/python3.10/site-packages/torchvision/models/detection/backbone_utils.py b/lib/python3.10/site-packages/torchvision/models/detection/backbone_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..668e6b31696eb949513d07878eada9d468dc99cd --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/detection/backbone_utils.py @@ -0,0 +1,244 @@ +import warnings +from typing import Callable, Dict, List, Optional, Union + +from torch import nn, Tensor +from torchvision.ops import misc as misc_nn_ops +from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool + +from .. import mobilenet, resnet +from .._api import _get_enum_from_fn, WeightsEnum +from .._utils import handle_legacy_interface, IntermediateLayerGetter + + +class BackboneWithFPN(nn.Module): + """ + Adds a FPN on top of a model. + Internally, it uses torchvision.models._utils.IntermediateLayerGetter to + extract a submodel that returns the feature maps specified in return_layers. + The same limitations of IntermediateLayerGetter apply here. + Args: + backbone (nn.Module) + return_layers (Dict[name, new_name]): a dict containing the names + of the modules for which the activations will be returned as + the key of the dict, and the value of the dict is the name + of the returned activation (which the user can specify). + in_channels_list (List[int]): number of channels for each feature map + that is returned, in the order they are present in the OrderedDict + out_channels (int): number of channels in the FPN. + norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None + Attributes: + out_channels (int): the number of channels in the FPN + """ + + def __init__( + self, + backbone: nn.Module, + return_layers: Dict[str, str], + in_channels_list: List[int], + out_channels: int, + extra_blocks: Optional[ExtraFPNBlock] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super().__init__() + + if extra_blocks is None: + extra_blocks = LastLevelMaxPool() + + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + self.fpn = FeaturePyramidNetwork( + in_channels_list=in_channels_list, + out_channels=out_channels, + extra_blocks=extra_blocks, + norm_layer=norm_layer, + ) + self.out_channels = out_channels + + def forward(self, x: Tensor) -> Dict[str, Tensor]: + x = self.body(x) + x = self.fpn(x) + return x + + +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"], + ), +) +def resnet_fpn_backbone( + *, + backbone_name: str, + weights: Optional[WeightsEnum], + norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d, + trainable_layers: int = 3, + returned_layers: Optional[List[int]] = None, + extra_blocks: Optional[ExtraFPNBlock] = None, +) -> BackboneWithFPN: + """ + Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone. 
+ + Examples:: + + >>> import torch + >>> from torchvision.models import ResNet50_Weights + >>> from torchvision.models.detection.backbone_utils import resnet_fpn_backbone + >>> backbone = resnet_fpn_backbone(backbone_name='resnet50', weights=ResNet50_Weights.DEFAULT, trainable_layers=3) + >>> # get some dummy image + >>> x = torch.rand(1,3,64,64) + >>> # compute the output + >>> output = backbone(x) + >>> print([(k, v.shape) for k, v in output.items()]) + >>> # returns + >>> [('0', torch.Size([1, 256, 16, 16])), + >>> ('1', torch.Size([1, 256, 8, 8])), + >>> ('2', torch.Size([1, 256, 4, 4])), + >>> ('3', torch.Size([1, 256, 2, 2])), + >>> ('pool', torch.Size([1, 256, 1, 1]))] + + Args: + backbone_name (string): resnet architecture. Possible values are 'resnet18', 'resnet34', 'resnet50', + 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2' + weights (WeightsEnum, optional): The pretrained weights for the model + norm_layer (callable): it is recommended to use the default value. For details visit: + (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267) + trainable_layers (int): number of trainable (not frozen) layers starting from final block. + Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. + returned_layers (list of int): The layers of the network to return. Each entry must be in ``[1, 4]``. + By default, all layers are returned. + extra_blocks (ExtraFPNBlock or None): if provided, extra operations will + be performed. It is expected to take the fpn features, the original + features and the names of the original features as input, and returns + a new list of feature maps and their corresponding names. By + default, a ``LastLevelMaxPool`` is used. + """ + backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer) + return _resnet_fpn_extractor(backbone, trainable_layers, returned_layers, extra_blocks) + + +def _resnet_fpn_extractor( + backbone: resnet.ResNet, + trainable_layers: int, + returned_layers: Optional[List[int]] = None, + extra_blocks: Optional[ExtraFPNBlock] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, +) -> BackboneWithFPN: + + # select layers that won't be frozen + if trainable_layers < 0 or trainable_layers > 5: + raise ValueError(f"Trainable layers should be in the range [0,5], got {trainable_layers}") + layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers] + if trainable_layers == 5: + layers_to_train.append("bn1") + for name, parameter in backbone.named_parameters(): + if all([not name.startswith(layer) for layer in layers_to_train]): + parameter.requires_grad_(False) + + if extra_blocks is None: + extra_blocks = LastLevelMaxPool() + + if returned_layers is None: + returned_layers = [1, 2, 3, 4] + if min(returned_layers) <= 0 or max(returned_layers) >= 5: + raise ValueError(f"Each returned layer should be in the range [1,4]. 
Got {returned_layers}") + return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)} + + in_channels_stage2 = backbone.inplanes // 8 + in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers] + out_channels = 256 + return BackboneWithFPN( + backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer + ) + + +def _validate_trainable_layers( + is_trained: bool, + trainable_backbone_layers: Optional[int], + max_value: int, + default_value: int, +) -> int: + # don't freeze any layers if pretrained model or backbone is not used + if not is_trained: + if trainable_backbone_layers is not None: + warnings.warn( + "Changing trainable_backbone_layers has no effect if " + "neither pretrained nor pretrained_backbone have been set to True, " + f"falling back to trainable_backbone_layers={max_value} so that all layers are trainable" + ) + trainable_backbone_layers = max_value + + # by default freeze first blocks + if trainable_backbone_layers is None: + trainable_backbone_layers = default_value + if trainable_backbone_layers < 0 or trainable_backbone_layers > max_value: + raise ValueError( + f"Trainable backbone layers should be in the range [0,{max_value}], got {trainable_backbone_layers} " + ) + return trainable_backbone_layers + + +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"], + ), +) +def mobilenet_backbone( + *, + backbone_name: str, + weights: Optional[WeightsEnum], + fpn: bool, + norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d, + trainable_layers: int = 2, + returned_layers: Optional[List[int]] = None, + extra_blocks: Optional[ExtraFPNBlock] = None, +) -> nn.Module: + backbone = mobilenet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer) + return _mobilenet_extractor(backbone, fpn, trainable_layers, returned_layers, extra_blocks) + + +def _mobilenet_extractor( + backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3], + fpn: bool, + trainable_layers: int, + returned_layers: Optional[List[int]] = None, + extra_blocks: Optional[ExtraFPNBlock] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, +) -> nn.Module: + backbone = backbone.features + # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. + # The first and last blocks are always included because they are the C0 (conv1) and Cn. 
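+    # Blocks that downsample are expected to expose ``_is_cn = True`` (an attribute set by
+    # the mobilenet model code); ``trainable_layers`` is counted in these stages from the
+    # end, so only the last ``trainable_layers`` stages keep requires_grad=True.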
+ stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1] + num_stages = len(stage_indices) + + # find the index of the layer from which we won't freeze + if trainable_layers < 0 or trainable_layers > num_stages: + raise ValueError(f"Trainable layers should be in the range [0,{num_stages}], got {trainable_layers} ") + freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers] + + for b in backbone[:freeze_before]: + for parameter in b.parameters(): + parameter.requires_grad_(False) + + out_channels = 256 + if fpn: + if extra_blocks is None: + extra_blocks = LastLevelMaxPool() + + if returned_layers is None: + returned_layers = [num_stages - 2, num_stages - 1] + if min(returned_layers) < 0 or max(returned_layers) >= num_stages: + raise ValueError(f"Each returned layer should be in the range [0,{num_stages - 1}], got {returned_layers} ") + return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)} + + in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers] + return BackboneWithFPN( + backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer + ) + else: + m = nn.Sequential( + backbone, + # depthwise linear combination of channels to reduce their size + nn.Conv2d(backbone[-1].out_channels, out_channels, 1), + ) + m.out_channels = out_channels # type: ignore[assignment] + return m diff --git a/lib/python3.10/site-packages/torchvision/models/detection/faster_rcnn.py b/lib/python3.10/site-packages/torchvision/models/detection/faster_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..18474ee84f4539cfec99d24534acb1e1e74a14b3 --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/detection/faster_rcnn.py @@ -0,0 +1,846 @@ +from typing import Any, Callable, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn +from torchvision.ops import MultiScaleRoIAlign + +from ...ops import misc as misc_nn_ops +from ...transforms._presets import ObjectDetection +from .._api import register_model, Weights, WeightsEnum +from .._meta import _COCO_CATEGORIES +from .._utils import _ovewrite_value_param, handle_legacy_interface +from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights +from ..resnet import resnet50, ResNet50_Weights +from ._utils import overwrite_eps +from .anchor_utils import AnchorGenerator +from .backbone_utils import _mobilenet_extractor, _resnet_fpn_extractor, _validate_trainable_layers +from .generalized_rcnn import GeneralizedRCNN +from .roi_heads import RoIHeads +from .rpn import RegionProposalNetwork, RPNHead +from .transform import GeneralizedRCNNTransform + + +__all__ = [ + "FasterRCNN", + "FasterRCNN_ResNet50_FPN_Weights", + "FasterRCNN_ResNet50_FPN_V2_Weights", + "FasterRCNN_MobileNet_V3_Large_FPN_Weights", + "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights", + "fasterrcnn_resnet50_fpn", + "fasterrcnn_resnet50_fpn_v2", + "fasterrcnn_mobilenet_v3_large_fpn", + "fasterrcnn_mobilenet_v3_large_320_fpn", +] + + +def _default_anchorgen(): + anchor_sizes = ((32,), (64,), (128,), (256,), (512,)) + aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) + return AnchorGenerator(anchor_sizes, aspect_ratios) + + +class FasterRCNN(GeneralizedRCNN): + """ + Implements Faster R-CNN. 
+ + The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each + image, and should be in 0-1 range. Different images can have different sizes. + + The behavior of the model changes depending on if it is in training or evaluation mode. + + During training, the model expects both the input tensors and targets (list of dictionary), + containing: + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (Int64Tensor[N]): the class label for each ground-truth box + + The model returns a Dict[Tensor] during training, containing the classification and regression + losses for both the RPN and the R-CNN. + + During inference, the model requires only the input tensors, and returns the post-processed + predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as + follows: + - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (Int64Tensor[N]): the predicted labels for each image + - scores (Tensor[N]): the scores or each prediction + + Args: + backbone (nn.Module): the network used to compute the features for the model. + It should contain an out_channels attribute, which indicates the number of output + channels that each feature map has (and it should be the same for all feature maps). + The backbone should return a single Tensor or and OrderedDict[Tensor]. + num_classes (int): number of output classes of the model (including the background). + If box_predictor is specified, num_classes should be None. + min_size (int): Images are rescaled before feeding them to the backbone: + we attempt to preserve the aspect ratio and scale the shorter edge + to ``min_size``. If the resulting longer edge exceeds ``max_size``, + then downscale so that the longer edge does not exceed ``max_size``. + This may result in the shorter edge beeing lower than ``min_size``. + max_size (int): See ``min_size``. + image_mean (Tuple[float, float, float]): mean values used for input normalization. + They are generally the mean values of the dataset on which the backbone has been trained + on + image_std (Tuple[float, float, float]): std values used for input normalization. + They are generally the std values of the dataset on which the backbone has been trained on + rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature + maps. + rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN + rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training + rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing + rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training + rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing + rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals + rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be + considered as positive during training of the RPN. + rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be + considered as negative during training of the RPN. 
+ rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN + for computing the loss + rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training + of the RPN + rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh + box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in + the locations indicated by the bounding boxes + box_head (nn.Module): module that takes the cropped feature maps as input + box_predictor (nn.Module): module that takes the output of box_head and returns the + classification logits and box regression deltas. + box_score_thresh (float): during inference, only return proposals with a classification score + greater than box_score_thresh + box_nms_thresh (float): NMS threshold for the prediction head. Used during inference + box_detections_per_img (int): maximum number of detections per image, for all classes. + box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be + considered as positive during training of the classification head + box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be + considered as negative during training of the classification head + box_batch_size_per_image (int): number of proposals that are sampled during training of the + classification head + box_positive_fraction (float): proportion of positive proposals in a mini-batch during training + of the classification head + bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the + bounding boxes + + Example:: + + >>> import torch + >>> import torchvision + >>> from torchvision.models.detection import FasterRCNN + >>> from torchvision.models.detection.rpn import AnchorGenerator + >>> # load a pre-trained model for classification and return + >>> # only the features + >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features + >>> # FasterRCNN needs to know the number of + >>> # output channels in a backbone. For mobilenet_v2, it's 1280, + >>> # so we need to add it here + >>> backbone.out_channels = 1280 + >>> + >>> # let's make the RPN generate 5 x 3 anchors per spatial + >>> # location, with 5 different sizes and 3 different aspect + >>> # ratios. We have a Tuple[Tuple[int]] because each feature + >>> # map could potentially have different sizes and + >>> # aspect ratios + >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),), + >>> aspect_ratios=((0.5, 1.0, 2.0),)) + >>> + >>> # let's define what are the feature maps that we will + >>> # use to perform the region of interest cropping, as well as + >>> # the size of the crop after rescaling. + >>> # if your backbone returns a Tensor, featmap_names is expected to + >>> # be ['0']. More generally, the backbone should return an + >>> # OrderedDict[Tensor], and in featmap_names you can choose which + >>> # feature maps to use. 
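+        >>> # (output_size=7 pools every RoI to a 7x7 feature map; sampling_ratio
+        >>> # sets how many sampling points roi_align uses per output bin)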
+ >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'], + >>> output_size=7, + >>> sampling_ratio=2) + >>> + >>> # put the pieces together inside a FasterRCNN model + >>> model = FasterRCNN(backbone, + >>> num_classes=2, + >>> rpn_anchor_generator=anchor_generator, + >>> box_roi_pool=roi_pooler) + >>> model.eval() + >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> predictions = model(x) + """ + + def __init__( + self, + backbone, + num_classes=None, + # transform parameters + min_size=800, + max_size=1333, + image_mean=None, + image_std=None, + # RPN parameters + rpn_anchor_generator=None, + rpn_head=None, + rpn_pre_nms_top_n_train=2000, + rpn_pre_nms_top_n_test=1000, + rpn_post_nms_top_n_train=2000, + rpn_post_nms_top_n_test=1000, + rpn_nms_thresh=0.7, + rpn_fg_iou_thresh=0.7, + rpn_bg_iou_thresh=0.3, + rpn_batch_size_per_image=256, + rpn_positive_fraction=0.5, + rpn_score_thresh=0.0, + # Box parameters + box_roi_pool=None, + box_head=None, + box_predictor=None, + box_score_thresh=0.05, + box_nms_thresh=0.5, + box_detections_per_img=100, + box_fg_iou_thresh=0.5, + box_bg_iou_thresh=0.5, + box_batch_size_per_image=512, + box_positive_fraction=0.25, + bbox_reg_weights=None, + **kwargs, + ): + + if not hasattr(backbone, "out_channels"): + raise ValueError( + "backbone should contain an attribute out_channels " + "specifying the number of output channels (assumed to be the " + "same for all the levels)" + ) + + if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))): + raise TypeError( + f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}" + ) + if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))): + raise TypeError( + f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}" + ) + + if num_classes is not None: + if box_predictor is not None: + raise ValueError("num_classes should be None when box_predictor is specified") + else: + if box_predictor is None: + raise ValueError("num_classes should not be None when box_predictor is not specified") + + out_channels = backbone.out_channels + + if rpn_anchor_generator is None: + rpn_anchor_generator = _default_anchorgen() + if rpn_head is None: + rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0]) + + rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test) + rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test) + + rpn = RegionProposalNetwork( + rpn_anchor_generator, + rpn_head, + rpn_fg_iou_thresh, + rpn_bg_iou_thresh, + rpn_batch_size_per_image, + rpn_positive_fraction, + rpn_pre_nms_top_n, + rpn_post_nms_top_n, + rpn_nms_thresh, + score_thresh=rpn_score_thresh, + ) + + if box_roi_pool is None: + box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2) + + if box_head is None: + resolution = box_roi_pool.output_size[0] + representation_size = 1024 + box_head = TwoMLPHead(out_channels * resolution**2, representation_size) + + if box_predictor is None: + representation_size = 1024 + box_predictor = FastRCNNPredictor(representation_size, num_classes) + + roi_heads = RoIHeads( + # Box + box_roi_pool, + box_head, + box_predictor, + box_fg_iou_thresh, + box_bg_iou_thresh, + box_batch_size_per_image, + box_positive_fraction, + bbox_reg_weights, + box_score_thresh, + box_nms_thresh, + box_detections_per_img, + ) + + if image_mean is None: + image_mean = 
[0.485, 0.456, 0.406] + if image_std is None: + image_std = [0.229, 0.224, 0.225] + transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs) + + super().__init__(backbone, rpn, roi_heads, transform) + + +class TwoMLPHead(nn.Module): + """ + Standard heads for FPN-based models + + Args: + in_channels (int): number of input channels + representation_size (int): size of the intermediate representation + """ + + def __init__(self, in_channels, representation_size): + super().__init__() + + self.fc6 = nn.Linear(in_channels, representation_size) + self.fc7 = nn.Linear(representation_size, representation_size) + + def forward(self, x): + x = x.flatten(start_dim=1) + + x = F.relu(self.fc6(x)) + x = F.relu(self.fc7(x)) + + return x + + +class FastRCNNConvFCHead(nn.Sequential): + def __init__( + self, + input_size: Tuple[int, int, int], + conv_layers: List[int], + fc_layers: List[int], + norm_layer: Optional[Callable[..., nn.Module]] = None, + ): + """ + Args: + input_size (Tuple[int, int, int]): the input size in CHW format. + conv_layers (list): feature dimensions of each Convolution layer + fc_layers (list): feature dimensions of each FCN layer + norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None + """ + in_channels, in_height, in_width = input_size + + blocks = [] + previous_channels = in_channels + for current_channels in conv_layers: + blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer)) + previous_channels = current_channels + blocks.append(nn.Flatten()) + previous_channels = previous_channels * in_height * in_width + for current_channels in fc_layers: + blocks.append(nn.Linear(previous_channels, current_channels)) + blocks.append(nn.ReLU(inplace=True)) + previous_channels = current_channels + + super().__init__(*blocks) + for layer in self.modules(): + if isinstance(layer, nn.Conv2d): + nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu") + if layer.bias is not None: + nn.init.zeros_(layer.bias) + + +class FastRCNNPredictor(nn.Module): + """ + Standard classification + bounding box regression layers + for Fast R-CNN. 
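+
+    A typical fine-tuning sketch (illustrative; assumes a pre-trained detector and a
+    custom number of classes)::
+
+        >>> import torchvision
+        >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
+        >>> in_features = model.roi_heads.box_predictor.cls_score.in_features
+        >>> # swap in a new predictor for 5 classes (including background)
+        >>> model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=5)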
+ + Args: + in_channels (int): number of input channels + num_classes (int): number of output classes (including background) + """ + + def __init__(self, in_channels, num_classes): + super().__init__() + self.cls_score = nn.Linear(in_channels, num_classes) + self.bbox_pred = nn.Linear(in_channels, num_classes * 4) + + def forward(self, x): + if x.dim() == 4: + torch._assert( + list(x.shape[2:]) == [1, 1], + f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}", + ) + x = x.flatten(start_dim=1) + scores = self.cls_score(x) + bbox_deltas = self.bbox_pred(x) + + return scores, bbox_deltas + + +_COMMON_META = { + "categories": _COCO_CATEGORIES, + "min_size": (1, 1), +} + + +class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", + transforms=ObjectDetection, + meta={ + **_COMMON_META, + "num_params": 41755286, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn", + "_metrics": { + "COCO-val2017": { + "box_map": 37.0, + } + }, + "_ops": 134.38, + "_file_size": 159.743, + "_docs": """These weights were produced by following a similar training recipe as on the paper.""", + }, + ) + DEFAULT = COCO_V1 + + +class FasterRCNN_ResNet50_FPN_V2_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth", + transforms=ObjectDetection, + meta={ + **_COMMON_META, + "num_params": 43712278, + "recipe": "https://github.com/pytorch/vision/pull/5763", + "_metrics": { + "COCO-val2017": { + "box_map": 46.7, + } + }, + "_ops": 280.371, + "_file_size": 167.104, + "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""", + }, + ) + DEFAULT = COCO_V1 + + +class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", + transforms=ObjectDetection, + meta={ + **_COMMON_META, + "num_params": 19386354, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn", + "_metrics": { + "COCO-val2017": { + "box_map": 32.8, + } + }, + "_ops": 4.494, + "_file_size": 74.239, + "_docs": """These weights were produced by following a similar training recipe as on the paper.""", + }, + ) + DEFAULT = COCO_V1 + + +class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", + transforms=ObjectDetection, + meta={ + **_COMMON_META, + "num_params": 19386354, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn", + "_metrics": { + "COCO-val2017": { + "box_map": 22.8, + } + }, + "_ops": 0.719, + "_file_size": 74.239, + "_docs": """These weights were produced by following a similar training recipe as on the paper.""", + }, + ) + DEFAULT = COCO_V1 + + +@register_model() +@handle_legacy_interface( + weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) +def fasterrcnn_resnet50_fpn( + *, + weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, + 
trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FasterRCNN: + """ + Faster R-CNN model with a ResNet-50-FPN backbone from the `Faster R-CNN: Towards Real-Time Object + Detection with Region Proposal Networks `__ + paper. + + .. betastatus:: detection module + + The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each + image, and should be in ``0-1`` range. Different images can have different sizes. + + The behavior of the model changes depending on if it is in training or evaluation mode. + + During training, the model expects both the input tensors and a targets (list of dictionary), + containing: + + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (``Int64Tensor[N]``): the class label for each ground-truth box + + The model returns a ``Dict[Tensor]`` during training, containing the classification and regression + losses for both the RPN and the R-CNN. + + During inference, the model requires only the input tensors, and returns the post-processed + predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as + follows, where ``N`` is the number of detections: + + - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (``Int64Tensor[N]``): the predicted labels for each detection + - scores (``Tensor[N]``): the scores of each detection + + For more details on the output, you may refer to :ref:`instance_seg_output`. + + Faster R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size. + + Example:: + + >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT) + >>> # For training + >>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4) + >>> boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4] + >>> labels = torch.randint(1, 91, (4, 11)) + >>> images = list(image for image in images) + >>> targets = [] + >>> for i in range(len(images)): + >>> d = {} + >>> d['boxes'] = boxes[i] + >>> d['labels'] = labels[i] + >>> targets.append(d) + >>> output = model(images, targets) + >>> # For inference + >>> model.eval() + >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> predictions = model(x) + >>> + >>> # optionally, if you want to export the model to ONNX: + >>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version = 11) + + Args: + weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The + pretrained weights for the backbone. + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from + final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are + trainable. If ``None`` is passed (the default) this value is set to 3. 
+ **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights + :members: + """ + weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + is_trained = weights is not None or weights_backbone is not None + trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) + norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d + + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) + backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) + model = FasterRCNN(backbone, num_classes=num_classes, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1: + overwrite_eps(model, 0.0) + + return model + + +@register_model() +@handle_legacy_interface( + weights=("pretrained", FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) +def fasterrcnn_resnet50_fpn_v2( + *, + weights: Optional[FasterRCNN_ResNet50_FPN_V2_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = None, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FasterRCNN: + """ + Constructs an improved Faster R-CNN model with a ResNet-50-FPN backbone from `Benchmarking Detection + Transfer Learning with Vision Transformers `__ paper. + + .. betastatus:: detection module + + It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See + :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more + details. + + Args: + weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The + pretrained weights for the backbone. + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from + final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are + trainable. If ``None`` is passed (the default) this value is set to 3. + **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights + :members: + """ + weights = FasterRCNN_ResNet50_FPN_V2_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + is_trained = weights is not None or weights_backbone is not None + trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) + + backbone = resnet50(weights=weights_backbone, progress=progress) + backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d) + rpn_anchor_generator = _default_anchorgen() + rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2) + box_head = FastRCNNConvFCHead( + (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d + ) + model = FasterRCNN( + backbone, + num_classes=num_classes, + rpn_anchor_generator=rpn_anchor_generator, + rpn_head=rpn_head, + box_head=box_head, + **kwargs, + ) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +def _fasterrcnn_mobilenet_v3_large_fpn( + *, + weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]], + progress: bool, + num_classes: Optional[int], + weights_backbone: Optional[MobileNet_V3_Large_Weights], + trainable_backbone_layers: Optional[int], + **kwargs: Any, +) -> FasterRCNN: + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + is_trained = weights is not None or weights_backbone is not None + trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3) + norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d + + backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer) + backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers) + anchor_sizes = ( + ( + 32, + 64, + 128, + 256, + 512, + ), + ) * 3 + aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) + model = FasterRCNN( + backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs + ) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +@register_model() +@handle_legacy_interface( + weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), +) +def fasterrcnn_mobilenet_v3_large_320_fpn( + *, + weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FasterRCNN: + """ + Low resolution Faster R-CNN model with a MobileNetV3-Large backbone tuned for mobile use cases. + + .. betastatus:: detection module + + It works similarly to Faster R-CNN with ResNet-50 FPN backbone. 
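+    The builder applies low-resolution defaults before building the model
+    (``min_size=320``, ``max_size=640``, 150 RPN proposals kept at test time and
+    ``rpn_score_thresh=0.05``); they can be overridden through ``**kwargs``, e.g.
+    (illustrative)::
+
+        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(
+        >>>     weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT, min_size=480
+        >>> )
+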
See + :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more + details. + + Example:: + + >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT) + >>> model.eval() + >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> predictions = model(x) + + Args: + weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The + pretrained weights for the backbone. + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from + final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are + trainable. If ``None`` is passed (the default) this value is set to 3. + **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights + :members: + """ + weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) + + defaults = { + "min_size": 320, + "max_size": 640, + "rpn_pre_nms_top_n_test": 150, + "rpn_post_nms_top_n_test": 150, + "rpn_score_thresh": 0.05, + } + + kwargs = {**defaults, **kwargs} + return _fasterrcnn_mobilenet_v3_large_fpn( + weights=weights, + progress=progress, + num_classes=num_classes, + weights_backbone=weights_backbone, + trainable_backbone_layers=trainable_backbone_layers, + **kwargs, + ) + + +@register_model() +@handle_legacy_interface( + weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), +) +def fasterrcnn_mobilenet_v3_large_fpn( + *, + weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FasterRCNN: + """ + Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone. + + .. betastatus:: detection module + + It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See + :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more + details. + + Example:: + + >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT) + >>> model.eval() + >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> predictions = model(x) + + Args: + weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights` below for + more details, and possible values. 
By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The + pretrained weights for the backbone. + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from + final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are + trainable. If ``None`` is passed (the default) this value is set to 3. + **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights + :members: + """ + weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) + + defaults = { + "rpn_score_thresh": 0.05, + } + + kwargs = {**defaults, **kwargs} + return _fasterrcnn_mobilenet_v3_large_fpn( + weights=weights, + progress=progress, + num_classes=num_classes, + weights_backbone=weights_backbone, + trainable_backbone_layers=trainable_backbone_layers, + **kwargs, + ) diff --git a/lib/python3.10/site-packages/torchvision/models/detection/fcos.py b/lib/python3.10/site-packages/torchvision/models/detection/fcos.py new file mode 100644 index 0000000000000000000000000000000000000000..a86ad2f424c32bd1cf951d474d3ef14bd1bddbb7 --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/detection/fcos.py @@ -0,0 +1,775 @@ +import math +import warnings +from collections import OrderedDict +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +from torch import nn, Tensor + +from ...ops import boxes as box_ops, generalized_box_iou_loss, misc as misc_nn_ops, sigmoid_focal_loss +from ...ops.feature_pyramid_network import LastLevelP6P7 +from ...transforms._presets import ObjectDetection +from ...utils import _log_api_usage_once +from .._api import register_model, Weights, WeightsEnum +from .._meta import _COCO_CATEGORIES +from .._utils import _ovewrite_value_param, handle_legacy_interface +from ..resnet import resnet50, ResNet50_Weights +from . import _utils as det_utils +from .anchor_utils import AnchorGenerator +from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers +from .transform import GeneralizedRCNNTransform + + +__all__ = [ + "FCOS", + "FCOS_ResNet50_FPN_Weights", + "fcos_resnet50_fpn", +] + + +class FCOSHead(nn.Module): + """ + A regression and classification head for use in FCOS. + + Args: + in_channels (int): number of channels of the input feature + num_anchors (int): number of anchors to be predicted + num_classes (int): number of classes to be predicted + num_convs (Optional[int]): number of conv layer of head. Default: 4. 
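+
+    Example (illustrative; FCOS uses a single anchor per location)::
+
+        >>> import torch
+        >>> head = FCOSHead(in_channels=256, num_anchors=1, num_classes=91)
+        >>> feats = [torch.rand(1, 256, s, s) for s in (32, 16, 8)]
+        >>> out = head(feats)
+        >>> sorted(out.keys())
+        ['bbox_ctrness', 'bbox_regression', 'cls_logits']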
+ """ + + __annotations__ = { + "box_coder": det_utils.BoxLinearCoder, + } + + def __init__(self, in_channels: int, num_anchors: int, num_classes: int, num_convs: Optional[int] = 4) -> None: + super().__init__() + self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True) + self.classification_head = FCOSClassificationHead(in_channels, num_anchors, num_classes, num_convs) + self.regression_head = FCOSRegressionHead(in_channels, num_anchors, num_convs) + + def compute_loss( + self, + targets: List[Dict[str, Tensor]], + head_outputs: Dict[str, Tensor], + anchors: List[Tensor], + matched_idxs: List[Tensor], + ) -> Dict[str, Tensor]: + + cls_logits = head_outputs["cls_logits"] # [N, HWA, C] + bbox_regression = head_outputs["bbox_regression"] # [N, HWA, 4] + bbox_ctrness = head_outputs["bbox_ctrness"] # [N, HWA, 1] + + all_gt_classes_targets = [] + all_gt_boxes_targets = [] + for targets_per_image, matched_idxs_per_image in zip(targets, matched_idxs): + if len(targets_per_image["labels"]) == 0: + gt_classes_targets = targets_per_image["labels"].new_zeros((len(matched_idxs_per_image),)) + gt_boxes_targets = targets_per_image["boxes"].new_zeros((len(matched_idxs_per_image), 4)) + else: + gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)] + gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)] + gt_classes_targets[matched_idxs_per_image < 0] = -1 # background + all_gt_classes_targets.append(gt_classes_targets) + all_gt_boxes_targets.append(gt_boxes_targets) + + # List[Tensor] to Tensor conversion of `all_gt_boxes_target`, `all_gt_classes_targets` and `anchors` + all_gt_boxes_targets, all_gt_classes_targets, anchors = ( + torch.stack(all_gt_boxes_targets), + torch.stack(all_gt_classes_targets), + torch.stack(anchors), + ) + + # compute foregroud + foregroud_mask = all_gt_classes_targets >= 0 + num_foreground = foregroud_mask.sum().item() + + # classification loss + gt_classes_targets = torch.zeros_like(cls_logits) + gt_classes_targets[foregroud_mask, all_gt_classes_targets[foregroud_mask]] = 1.0 + loss_cls = sigmoid_focal_loss(cls_logits, gt_classes_targets, reduction="sum") + + # amp issue: pred_boxes need to convert float + pred_boxes = self.box_coder.decode(bbox_regression, anchors) + + # regression loss: GIoU loss + loss_bbox_reg = generalized_box_iou_loss( + pred_boxes[foregroud_mask], + all_gt_boxes_targets[foregroud_mask], + reduction="sum", + ) + + # ctrness loss + + bbox_reg_targets = self.box_coder.encode(anchors, all_gt_boxes_targets) + + if len(bbox_reg_targets) == 0: + gt_ctrness_targets = bbox_reg_targets.new_zeros(bbox_reg_targets.size()[:-1]) + else: + left_right = bbox_reg_targets[:, :, [0, 2]] + top_bottom = bbox_reg_targets[:, :, [1, 3]] + gt_ctrness_targets = torch.sqrt( + (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) + * (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]) + ) + pred_centerness = bbox_ctrness.squeeze(dim=2) + loss_bbox_ctrness = nn.functional.binary_cross_entropy_with_logits( + pred_centerness[foregroud_mask], gt_ctrness_targets[foregroud_mask], reduction="sum" + ) + + return { + "classification": loss_cls / max(1, num_foreground), + "bbox_regression": loss_bbox_reg / max(1, num_foreground), + "bbox_ctrness": loss_bbox_ctrness / max(1, num_foreground), + } + + def forward(self, x: List[Tensor]) -> Dict[str, Tensor]: + cls_logits = self.classification_head(x) + bbox_regression, bbox_ctrness = self.regression_head(x) + return { + "cls_logits": cls_logits, + "bbox_regression": 
bbox_regression, + "bbox_ctrness": bbox_ctrness, + } + + +class FCOSClassificationHead(nn.Module): + """ + A classification head for use in FCOS. + + Args: + in_channels (int): number of channels of the input feature. + num_anchors (int): number of anchors to be predicted. + num_classes (int): number of classes to be predicted. + num_convs (Optional[int]): number of conv layer. Default: 4. + prior_probability (Optional[float]): probability of prior. Default: 0.01. + norm_layer: Module specifying the normalization layer to use. + """ + + def __init__( + self, + in_channels: int, + num_anchors: int, + num_classes: int, + num_convs: int = 4, + prior_probability: float = 0.01, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super().__init__() + + self.num_classes = num_classes + self.num_anchors = num_anchors + + if norm_layer is None: + norm_layer = partial(nn.GroupNorm, 32) + + conv = [] + for _ in range(num_convs): + conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)) + conv.append(norm_layer(in_channels)) + conv.append(nn.ReLU()) + self.conv = nn.Sequential(*conv) + + for layer in self.conv.children(): + if isinstance(layer, nn.Conv2d): + torch.nn.init.normal_(layer.weight, std=0.01) + torch.nn.init.constant_(layer.bias, 0) + + self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1) + torch.nn.init.normal_(self.cls_logits.weight, std=0.01) + torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability)) + + def forward(self, x: List[Tensor]) -> Tensor: + all_cls_logits = [] + + for features in x: + cls_logits = self.conv(features) + cls_logits = self.cls_logits(cls_logits) + + # Permute classification output from (N, A * K, H, W) to (N, HWA, K). + N, _, H, W = cls_logits.shape + cls_logits = cls_logits.view(N, -1, self.num_classes, H, W) + cls_logits = cls_logits.permute(0, 3, 4, 1, 2) + cls_logits = cls_logits.reshape(N, -1, self.num_classes) # Size=(N, HWA, 4) + + all_cls_logits.append(cls_logits) + + return torch.cat(all_cls_logits, dim=1) + + +class FCOSRegressionHead(nn.Module): + """ + A regression head for use in FCOS, which combines regression branch and center-ness branch. + This can obtain better performance. + + Reference: `FCOS: A simple and strong anchor-free object detector `_. + + Args: + in_channels (int): number of channels of the input feature + num_anchors (int): number of anchors to be predicted + num_convs (Optional[int]): number of conv layer. Default: 4. + norm_layer: Module specifying the normalization layer to use. 
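+
+    Example (illustrative)::
+
+        >>> import torch
+        >>> reg_head = FCOSRegressionHead(in_channels=256, num_anchors=1)
+        >>> feats = [torch.rand(1, 256, 16, 16), torch.rand(1, 256, 8, 8)]
+        >>> bbox_regression, bbox_ctrness = reg_head(feats)
+        >>> bbox_regression.shape, bbox_ctrness.shape
+        (torch.Size([1, 320, 4]), torch.Size([1, 320, 1]))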
+ """ + + def __init__( + self, + in_channels: int, + num_anchors: int, + num_convs: int = 4, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ): + super().__init__() + + if norm_layer is None: + norm_layer = partial(nn.GroupNorm, 32) + + conv = [] + for _ in range(num_convs): + conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)) + conv.append(norm_layer(in_channels)) + conv.append(nn.ReLU()) + self.conv = nn.Sequential(*conv) + + self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1) + self.bbox_ctrness = nn.Conv2d(in_channels, num_anchors * 1, kernel_size=3, stride=1, padding=1) + for layer in [self.bbox_reg, self.bbox_ctrness]: + torch.nn.init.normal_(layer.weight, std=0.01) + torch.nn.init.zeros_(layer.bias) + + for layer in self.conv.children(): + if isinstance(layer, nn.Conv2d): + torch.nn.init.normal_(layer.weight, std=0.01) + torch.nn.init.zeros_(layer.bias) + + def forward(self, x: List[Tensor]) -> Tuple[Tensor, Tensor]: + all_bbox_regression = [] + all_bbox_ctrness = [] + + for features in x: + bbox_feature = self.conv(features) + bbox_regression = nn.functional.relu(self.bbox_reg(bbox_feature)) + bbox_ctrness = self.bbox_ctrness(bbox_feature) + + # permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4). + N, _, H, W = bbox_regression.shape + bbox_regression = bbox_regression.view(N, -1, 4, H, W) + bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2) + bbox_regression = bbox_regression.reshape(N, -1, 4) # Size=(N, HWA, 4) + all_bbox_regression.append(bbox_regression) + + # permute bbox ctrness output from (N, 1 * A, H, W) to (N, HWA, 1). + bbox_ctrness = bbox_ctrness.view(N, -1, 1, H, W) + bbox_ctrness = bbox_ctrness.permute(0, 3, 4, 1, 2) + bbox_ctrness = bbox_ctrness.reshape(N, -1, 1) + all_bbox_ctrness.append(bbox_ctrness) + + return torch.cat(all_bbox_regression, dim=1), torch.cat(all_bbox_ctrness, dim=1) + + +class FCOS(nn.Module): + """ + Implements FCOS. + + The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each + image, and should be in 0-1 range. Different images can have different sizes. + + The behavior of the model changes depending on if it is in training or evaluation mode. + + During training, the model expects both the input tensors and targets (list of dictionary), + containing: + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (Int64Tensor[N]): the class label for each ground-truth box + + The model returns a Dict[Tensor] during training, containing the classification, regression + and centerness losses. + + During inference, the model requires only the input tensors, and returns the post-processed + predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as + follows: + - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (Int64Tensor[N]): the predicted labels for each image + - scores (Tensor[N]): the scores for each prediction + + Args: + backbone (nn.Module): the network used to compute the features for the model. + It should contain an out_channels attribute, which indicates the number of output + channels that each feature map has (and it should be the same for all feature maps). + The backbone should return a single Tensor or an OrderedDict[Tensor]. 
+ num_classes (int): number of output classes of the model (including the background). + min_size (int): Images are rescaled before feeding them to the backbone: + we attempt to preserve the aspect ratio and scale the shorter edge + to ``min_size``. If the resulting longer edge exceeds ``max_size``, + then downscale so that the longer edge does not exceed ``max_size``. + This may result in the shorter edge beeing lower than ``min_size``. + max_size (int): See ``min_size``. + image_mean (Tuple[float, float, float]): mean values used for input normalization. + They are generally the mean values of the dataset on which the backbone has been trained + on + image_std (Tuple[float, float, float]): std values used for input normalization. + They are generally the std values of the dataset on which the backbone has been trained on + anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature + maps. For FCOS, only set one anchor for per position of each level, the width and height equal to + the stride of feature map, and set aspect ratio = 1.0, so the center of anchor is equivalent to the point + in FCOS paper. + head (nn.Module): Module run on top of the feature pyramid. + Defaults to a module containing a classification and regression module. + center_sampling_radius (int): radius of the "center" of a groundtruth box, + within which all anchor points are labeled positive. + score_thresh (float): Score threshold used for postprocessing the detections. + nms_thresh (float): NMS threshold used for postprocessing the detections. + detections_per_img (int): Number of best detections to keep after NMS. + topk_candidates (int): Number of best detections to keep before NMS. + + Example: + + >>> import torch + >>> import torchvision + >>> from torchvision.models.detection import FCOS + >>> from torchvision.models.detection.anchor_utils import AnchorGenerator + >>> # load a pre-trained model for classification and return + >>> # only the features + >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features + >>> # FCOS needs to know the number of + >>> # output channels in a backbone. For mobilenet_v2, it's 1280, + >>> # so we need to add it here + >>> backbone.out_channels = 1280 + >>> + >>> # let's make the network generate 5 x 3 anchors per spatial + >>> # location, with 5 different sizes and 3 different aspect + >>> # ratios. 
We have a Tuple[Tuple[int]] because each feature + >>> # map could potentially have different sizes and + >>> # aspect ratios + >>> anchor_generator = AnchorGenerator( + >>> sizes=((8,), (16,), (32,), (64,), (128,)), + >>> aspect_ratios=((1.0,),) + >>> ) + >>> + >>> # put the pieces together inside a FCOS model + >>> model = FCOS( + >>> backbone, + >>> num_classes=80, + >>> anchor_generator=anchor_generator, + >>> ) + >>> model.eval() + >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> predictions = model(x) + """ + + __annotations__ = { + "box_coder": det_utils.BoxLinearCoder, + } + + def __init__( + self, + backbone: nn.Module, + num_classes: int, + # transform parameters + min_size: int = 800, + max_size: int = 1333, + image_mean: Optional[List[float]] = None, + image_std: Optional[List[float]] = None, + # Anchor parameters + anchor_generator: Optional[AnchorGenerator] = None, + head: Optional[nn.Module] = None, + center_sampling_radius: float = 1.5, + score_thresh: float = 0.2, + nms_thresh: float = 0.6, + detections_per_img: int = 100, + topk_candidates: int = 1000, + **kwargs, + ): + super().__init__() + _log_api_usage_once(self) + + if not hasattr(backbone, "out_channels"): + raise ValueError( + "backbone should contain an attribute out_channels " + "specifying the number of output channels (assumed to be the " + "same for all the levels)" + ) + self.backbone = backbone + + if not isinstance(anchor_generator, (AnchorGenerator, type(None))): + raise TypeError( + f"anchor_generator should be of type AnchorGenerator or None, instead got {type(anchor_generator)}" + ) + + if anchor_generator is None: + anchor_sizes = ((8,), (16,), (32,), (64,), (128,)) # equal to strides of multi-level feature map + aspect_ratios = ((1.0,),) * len(anchor_sizes) # set only one anchor + anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios) + self.anchor_generator = anchor_generator + if self.anchor_generator.num_anchors_per_location()[0] != 1: + raise ValueError( + f"anchor_generator.num_anchors_per_location()[0] should be 1 instead of {anchor_generator.num_anchors_per_location()[0]}" + ) + + if head is None: + head = FCOSHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes) + self.head = head + + self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True) + + if image_mean is None: + image_mean = [0.485, 0.456, 0.406] + if image_std is None: + image_std = [0.229, 0.224, 0.225] + self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs) + + self.center_sampling_radius = center_sampling_radius + self.score_thresh = score_thresh + self.nms_thresh = nms_thresh + self.detections_per_img = detections_per_img + self.topk_candidates = topk_candidates + + # used only on torchscript mode + self._has_warned = False + + @torch.jit.unused + def eager_outputs( + self, losses: Dict[str, Tensor], detections: List[Dict[str, Tensor]] + ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: + if self.training: + return losses + + return detections + + def compute_loss( + self, + targets: List[Dict[str, Tensor]], + head_outputs: Dict[str, Tensor], + anchors: List[Tensor], + num_anchors_per_level: List[int], + ) -> Dict[str, Tensor]: + matched_idxs = [] + for anchors_per_image, targets_per_image in zip(anchors, targets): + if targets_per_image["boxes"].numel() == 0: + matched_idxs.append( + torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device) + ) + continue + + gt_boxes = 
targets_per_image["boxes"] + gt_centers = (gt_boxes[:, :2] + gt_boxes[:, 2:]) / 2 # Nx2 + anchor_centers = (anchors_per_image[:, :2] + anchors_per_image[:, 2:]) / 2 # N + anchor_sizes = anchors_per_image[:, 2] - anchors_per_image[:, 0] + # center sampling: anchor point must be close enough to gt center. + pairwise_match = (anchor_centers[:, None, :] - gt_centers[None, :, :]).abs_().max( + dim=2 + ).values < self.center_sampling_radius * anchor_sizes[:, None] + # compute pairwise distance between N points and M boxes + x, y = anchor_centers.unsqueeze(dim=2).unbind(dim=1) # (N, 1) + x0, y0, x1, y1 = gt_boxes.unsqueeze(dim=0).unbind(dim=2) # (1, M) + pairwise_dist = torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2) # (N, M) + + # anchor point must be inside gt + pairwise_match &= pairwise_dist.min(dim=2).values > 0 + + # each anchor is only responsible for certain scale range. + lower_bound = anchor_sizes * 4 + lower_bound[: num_anchors_per_level[0]] = 0 + upper_bound = anchor_sizes * 8 + upper_bound[-num_anchors_per_level[-1] :] = float("inf") + pairwise_dist = pairwise_dist.max(dim=2).values + pairwise_match &= (pairwise_dist > lower_bound[:, None]) & (pairwise_dist < upper_bound[:, None]) + + # match the GT box with minimum area, if there are multiple GT matches + gt_areas = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1]) # N + pairwise_match = pairwise_match.to(torch.float32) * (1e8 - gt_areas[None, :]) + min_values, matched_idx = pairwise_match.max(dim=1) # R, per-anchor match + matched_idx[min_values < 1e-5] = -1 # unmatched anchors are assigned -1 + + matched_idxs.append(matched_idx) + + return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs) + + def postprocess_detections( + self, head_outputs: Dict[str, List[Tensor]], anchors: List[List[Tensor]], image_shapes: List[Tuple[int, int]] + ) -> List[Dict[str, Tensor]]: + class_logits = head_outputs["cls_logits"] + box_regression = head_outputs["bbox_regression"] + box_ctrness = head_outputs["bbox_ctrness"] + + num_images = len(image_shapes) + + detections: List[Dict[str, Tensor]] = [] + + for index in range(num_images): + box_regression_per_image = [br[index] for br in box_regression] + logits_per_image = [cl[index] for cl in class_logits] + box_ctrness_per_image = [bc[index] for bc in box_ctrness] + anchors_per_image, image_shape = anchors[index], image_shapes[index] + + image_boxes = [] + image_scores = [] + image_labels = [] + + for box_regression_per_level, logits_per_level, box_ctrness_per_level, anchors_per_level in zip( + box_regression_per_image, logits_per_image, box_ctrness_per_image, anchors_per_image + ): + num_classes = logits_per_level.shape[-1] + + # remove low scoring boxes + scores_per_level = torch.sqrt( + torch.sigmoid(logits_per_level) * torch.sigmoid(box_ctrness_per_level) + ).flatten() + keep_idxs = scores_per_level > self.score_thresh + scores_per_level = scores_per_level[keep_idxs] + topk_idxs = torch.where(keep_idxs)[0] + + # keep only topk scoring predictions + num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0) + scores_per_level, idxs = scores_per_level.topk(num_topk) + topk_idxs = topk_idxs[idxs] + + anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor") + labels_per_level = topk_idxs % num_classes + + boxes_per_level = self.box_coder.decode( + box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs] + ) + boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape) + + image_boxes.append(boxes_per_level) + 
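# the per-level score above is the geometric mean sqrt(sigmoid(cls_logit) * sigmoid(ctrness)),
+ # flattened over (num_anchors_at_level * num_classes); topk_idxs indexes that flattened view,
+ # which is why the anchor index and the class label are recovered with a floor-division and
+ # a modulo by num_classes before the regression offsets are decoded against the anchors. +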
image_scores.append(scores_per_level) + image_labels.append(labels_per_level) + + image_boxes = torch.cat(image_boxes, dim=0) + image_scores = torch.cat(image_scores, dim=0) + image_labels = torch.cat(image_labels, dim=0) + + # non-maximum suppression + keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh) + keep = keep[: self.detections_per_img] + + detections.append( + { + "boxes": image_boxes[keep], + "scores": image_scores[keep], + "labels": image_labels[keep], + } + ) + + return detections + + def forward( + self, + images: List[Tensor], + targets: Optional[List[Dict[str, Tensor]]] = None, + ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: + """ + Args: + images (list[Tensor]): images to be processed + targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional) + + Returns: + result (list[BoxList] or dict[Tensor]): the output from the model. + During training, it returns a dict[Tensor] which contains the losses. + During testing, it returns list[BoxList] contains additional fields + like `scores`, `labels` and `mask` (for Mask R-CNN models). + """ + if self.training: + + if targets is None: + torch._assert(False, "targets should not be none when in training mode") + else: + for target in targets: + boxes = target["boxes"] + torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.") + torch._assert( + len(boxes.shape) == 2 and boxes.shape[-1] == 4, + f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.", + ) + + original_image_sizes: List[Tuple[int, int]] = [] + for img in images: + val = img.shape[-2:] + torch._assert( + len(val) == 2, + f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}", + ) + original_image_sizes.append((val[0], val[1])) + + # transform the input + images, targets = self.transform(images, targets) + + # Check for degenerate boxes + if targets is not None: + for target_idx, target in enumerate(targets): + boxes = target["boxes"] + degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] + if degenerate_boxes.any(): + # print the first degenerate box + bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0] + degen_bb: List[float] = boxes[bb_idx].tolist() + torch._assert( + False, + f"All bounding boxes should have positive height and width. 
Found invalid box {degen_bb} for target at index {target_idx}.", + ) + + # get the features from the backbone + features = self.backbone(images.tensors) + if isinstance(features, torch.Tensor): + features = OrderedDict([("0", features)]) + + features = list(features.values()) + + # compute the fcos heads outputs using the features + head_outputs = self.head(features) + + # create the set of anchors + anchors = self.anchor_generator(images, features) + # recover level sizes + num_anchors_per_level = [x.size(2) * x.size(3) for x in features] + + losses = {} + detections: List[Dict[str, Tensor]] = [] + if self.training: + if targets is None: + torch._assert(False, "targets should not be none when in training mode") + else: + # compute the losses + losses = self.compute_loss(targets, head_outputs, anchors, num_anchors_per_level) + else: + # split outputs per level + split_head_outputs: Dict[str, List[Tensor]] = {} + for k in head_outputs: + split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1)) + split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors] + + # compute the detections + detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes) + detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) + + if torch.jit.is_scripting(): + if not self._has_warned: + warnings.warn("FCOS always returns a (Losses, Detections) tuple in scripting") + self._has_warned = True + return losses, detections + return self.eager_outputs(losses, detections) + + +class FCOS_ResNet50_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth", + transforms=ObjectDetection, + meta={ + "num_params": 32269600, + "categories": _COCO_CATEGORIES, + "min_size": (1, 1), + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#fcos-resnet-50-fpn", + "_metrics": { + "COCO-val2017": { + "box_map": 39.2, + } + }, + "_ops": 128.207, + "_file_size": 123.608, + "_docs": """These weights were produced by following a similar training recipe as on the paper.""", + }, + ) + DEFAULT = COCO_V1 + + +@register_model() +@handle_legacy_interface( + weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) +def fcos_resnet50_fpn( + *, + weights: Optional[FCOS_ResNet50_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FCOS: + """ + Constructs a FCOS model with a ResNet-50-FPN backbone. + + .. betastatus:: detection module + + Reference: `FCOS: Fully Convolutional One-Stage Object Detection `_. + `FCOS: A simple and strong anchor-free object detector `_. + + The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each + image, and should be in ``0-1`` range. Different images can have different sizes. + + The behavior of the model changes depending on if it is in training or evaluation mode. + + During training, the model expects both the input tensors and targets (list of dictionary), + containing: + + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. 
+ - labels (``Int64Tensor[N]``): the class label for each ground-truth box + + The model returns a ``Dict[Tensor]`` during training, containing the classification and regression + losses. + + During inference, the model requires only the input tensors, and returns the post-processed + predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as + follows, where ``N`` is the number of detections: + + - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (``Int64Tensor[N]``): the predicted labels for each detection + - scores (``Tensor[N]``): the scores of each detection + + For more details on the output, you may refer to :ref:`instance_seg_output`. + + Example: + + >>> model = torchvision.models.detection.fcos_resnet50_fpn(weights=FCOS_ResNet50_FPN_Weights.DEFAULT) + >>> model.eval() + >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> predictions = model(x) + + Args: + weights (:class:`~torchvision.models.detection.FCOS_ResNet50_FPN_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.detection.FCOS_ResNet50_FPN_Weights` + below for more details, and possible values. By default, no + pre-trained weights are used. + progress (bool): If True, displays a progress bar of the download to stderr + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for + the backbone. + trainable_backbone_layers (int, optional): number of trainable (not frozen) resnet layers starting + from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are + trainable. If ``None`` is passed (the default) this value is set to 3. Default: None + **kwargs: parameters passed to the ``torchvision.models.detection.FCOS`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.detection.FCOS_ResNet50_FPN_Weights + :members: + """ + weights = FCOS_ResNet50_FPN_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + is_trained = weights is not None or weights_backbone is not None + trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) + norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d + + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) + backbone = _resnet_fpn_extractor( + backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) + ) + model = FCOS(backbone, num_classes, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model diff --git a/lib/python3.10/site-packages/torchvision/models/detection/generalized_rcnn.py b/lib/python3.10/site-packages/torchvision/models/detection/generalized_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..b481265077fb5a582402d81aeb3516ffca063653 --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/detection/generalized_rcnn.py @@ -0,0 +1,118 @@ +""" +Implements the Generalized R-CNN framework +""" + +import warnings +from collections import OrderedDict +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch import nn, Tensor + +from ...utils import _log_api_usage_once + + +class GeneralizedRCNN(nn.Module): + """ + Main class for Generalized R-CNN. + + Args: + backbone (nn.Module): + rpn (nn.Module): + roi_heads (nn.Module): takes the features + the proposals from the RPN and computes + detections / masks from it. + transform (nn.Module): performs the data transformation from the inputs to feed into + the model + """ + + def __init__(self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module) -> None: + super().__init__() + _log_api_usage_once(self) + self.transform = transform + self.backbone = backbone + self.rpn = rpn + self.roi_heads = roi_heads + # used only on torchscript mode + self._has_warned = False + + @torch.jit.unused + def eager_outputs(self, losses, detections): + # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]] + if self.training: + return losses + + return detections + + def forward(self, images, targets=None): + # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] + """ + Args: + images (list[Tensor]): images to be processed + targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional) + + Returns: + result (list[BoxList] or dict[Tensor]): the output from the model. + During training, it returns a dict[Tensor] which contains the losses. + During testing, it returns list[BoxList] contains additional fields + like `scores`, `labels` and `mask` (for Mask R-CNN models). 
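+
+        A rough usage sketch (assuming ``model`` is an already constructed detection model and
+        ``images`` / ``targets`` follow the formats described above)::
+
+            >>> model.train()
+            >>> losses = model(images, targets)     # Dict[str, Tensor] of training losses
+            >>> model.eval()
+            >>> detections = model(images)          # List[Dict[str, Tensor]], one per image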
+ + """ + if self.training: + if targets is None: + torch._assert(False, "targets should not be none when in training mode") + else: + for target in targets: + boxes = target["boxes"] + if isinstance(boxes, torch.Tensor): + torch._assert( + len(boxes.shape) == 2 and boxes.shape[-1] == 4, + f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.", + ) + else: + torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.") + + original_image_sizes: List[Tuple[int, int]] = [] + for img in images: + val = img.shape[-2:] + torch._assert( + len(val) == 2, + f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}", + ) + original_image_sizes.append((val[0], val[1])) + + images, targets = self.transform(images, targets) + + # Check for degenerate boxes + # TODO: Move this to a function + if targets is not None: + for target_idx, target in enumerate(targets): + boxes = target["boxes"] + degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] + if degenerate_boxes.any(): + # print the first degenerate box + bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0] + degen_bb: List[float] = boxes[bb_idx].tolist() + torch._assert( + False, + "All bounding boxes should have positive height and width." + f" Found invalid box {degen_bb} for target at index {target_idx}.", + ) + + features = self.backbone(images.tensors) + if isinstance(features, torch.Tensor): + features = OrderedDict([("0", features)]) + proposals, proposal_losses = self.rpn(images, features, targets) + detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets) + detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) # type: ignore[operator] + + losses = {} + losses.update(detector_losses) + losses.update(proposal_losses) + + if torch.jit.is_scripting(): + if not self._has_warned: + warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting") + self._has_warned = True + return losses, detections + else: + return self.eager_outputs(losses, detections) diff --git a/lib/python3.10/site-packages/torchvision/models/detection/image_list.py b/lib/python3.10/site-packages/torchvision/models/detection/image_list.py new file mode 100644 index 0000000000000000000000000000000000000000..583866557e4c9ec178e7cc268272db3de1698e41 --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/detection/image_list.py @@ -0,0 +1,25 @@ +from typing import List, Tuple + +import torch +from torch import Tensor + + +class ImageList: + """ + Structure that holds a list of images (of possibly + varying sizes) as a single tensor. + This works by padding the images to the same size, + and storing in a field the original sizes of each image + + Args: + tensors (tensor): Tensor containing images. + image_sizes (list[tuple[int, int]]): List of Tuples each containing size of images. 
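+
+    A small illustrative example (the 512x512 padded batch below is only an assumption for
+    the sketch)::
+
+        >>> import torch
+        >>> batched = torch.zeros(2, 3, 512, 512)
+        >>> image_list = ImageList(batched, [(480, 512), (512, 384)])
+        >>> image_list.tensors.shape
+        torch.Size([2, 3, 512, 512])
+        >>> image_list.image_sizes
+        [(480, 512), (512, 384)]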
+ """ + + def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]) -> None: + self.tensors = tensors + self.image_sizes = image_sizes + + def to(self, device: torch.device) -> "ImageList": + cast_tensor = self.tensors.to(device) + return ImageList(cast_tensor, self.image_sizes) diff --git a/lib/python3.10/site-packages/torchvision/models/detection/keypoint_rcnn.py b/lib/python3.10/site-packages/torchvision/models/detection/keypoint_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..5d7ff0ea433a681064a11a22c3e276e253997772 --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/detection/keypoint_rcnn.py @@ -0,0 +1,474 @@ +from typing import Any, Optional + +import torch +from torch import nn +from torchvision.ops import MultiScaleRoIAlign + +from ...ops import misc as misc_nn_ops +from ...transforms._presets import ObjectDetection +from .._api import register_model, Weights, WeightsEnum +from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES +from .._utils import _ovewrite_value_param, handle_legacy_interface +from ..resnet import resnet50, ResNet50_Weights +from ._utils import overwrite_eps +from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers +from .faster_rcnn import FasterRCNN + + +__all__ = [ + "KeypointRCNN", + "KeypointRCNN_ResNet50_FPN_Weights", + "keypointrcnn_resnet50_fpn", +] + + +class KeypointRCNN(FasterRCNN): + """ + Implements Keypoint R-CNN. + + The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each + image, and should be in 0-1 range. Different images can have different sizes. + + The behavior of the model changes depending on if it is in training or evaluation mode. + + During training, the model expects both the input tensors and targets (list of dictionary), + containing: + + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (Int64Tensor[N]): the class label for each ground-truth box + - keypoints (FloatTensor[N, K, 3]): the K keypoints location for each of the N instances, in the + format [x, y, visibility], where visibility=0 means that the keypoint is not visible. + + The model returns a Dict[Tensor] during training, containing the classification and regression + losses for both the RPN and the R-CNN, and the keypoint loss. + + During inference, the model requires only the input tensors, and returns the post-processed + predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as + follows: + + - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (Int64Tensor[N]): the predicted labels for each image + - scores (Tensor[N]): the scores or each prediction + - keypoints (FloatTensor[N, K, 3]): the locations of the predicted keypoints, in [x, y, v] format. + + Args: + backbone (nn.Module): the network used to compute the features for the model. + It should contain an out_channels attribute, which indicates the number of output + channels that each feature map has (and it should be the same for all feature maps). + The backbone should return a single Tensor or and OrderedDict[Tensor]. + num_classes (int): number of output classes of the model (including the background). + If box_predictor is specified, num_classes should be None. 
+ min_size (int): Images are rescaled before feeding them to the backbone: + we attempt to preserve the aspect ratio and scale the shorter edge + to ``min_size``. If the resulting longer edge exceeds ``max_size``, + then downscale so that the longer edge does not exceed ``max_size``. + This may result in the shorter edge beeing lower than ``min_size``. + max_size (int): See ``min_size``. + image_mean (Tuple[float, float, float]): mean values used for input normalization. + They are generally the mean values of the dataset on which the backbone has been trained + on + image_std (Tuple[float, float, float]): std values used for input normalization. + They are generally the std values of the dataset on which the backbone has been trained on + rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature + maps. + rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN + rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training + rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing + rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training + rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing + rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals + rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be + considered as positive during training of the RPN. + rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be + considered as negative during training of the RPN. + rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN + for computing the loss + rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training + of the RPN + rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh + box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in + the locations indicated by the bounding boxes + box_head (nn.Module): module that takes the cropped feature maps as input + box_predictor (nn.Module): module that takes the output of box_head and returns the + classification logits and box regression deltas. + box_score_thresh (float): during inference, only return proposals with a classification score + greater than box_score_thresh + box_nms_thresh (float): NMS threshold for the prediction head. Used during inference + box_detections_per_img (int): maximum number of detections per image, for all classes. 
+ box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be + considered as positive during training of the classification head + box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be + considered as negative during training of the classification head + box_batch_size_per_image (int): number of proposals that are sampled during training of the + classification head + box_positive_fraction (float): proportion of positive proposals in a mini-batch during training + of the classification head + bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the + bounding boxes + keypoint_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in + the locations indicated by the bounding boxes, which will be used for the keypoint head. + keypoint_head (nn.Module): module that takes the cropped feature maps as input + keypoint_predictor (nn.Module): module that takes the output of the keypoint_head and returns the + heatmap logits + + Example:: + + >>> import torch + >>> import torchvision + >>> from torchvision.models.detection import KeypointRCNN + >>> from torchvision.models.detection.anchor_utils import AnchorGenerator + >>> + >>> # load a pre-trained model for classification and return + >>> # only the features + >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features + >>> # KeypointRCNN needs to know the number of + >>> # output channels in a backbone. For mobilenet_v2, it's 1280, + >>> # so we need to add it here + >>> backbone.out_channels = 1280 + >>> + >>> # let's make the RPN generate 5 x 3 anchors per spatial + >>> # location, with 5 different sizes and 3 different aspect + >>> # ratios. We have a Tuple[Tuple[int]] because each feature + >>> # map could potentially have different sizes and + >>> # aspect ratios + >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),), + >>> aspect_ratios=((0.5, 1.0, 2.0),)) + >>> + >>> # let's define what are the feature maps that we will + >>> # use to perform the region of interest cropping, as well as + >>> # the size of the crop after rescaling. + >>> # if your backbone returns a Tensor, featmap_names is expected to + >>> # be ['0']. More generally, the backbone should return an + >>> # OrderedDict[Tensor], and in featmap_names you can choose which + >>> # feature maps to use. 
+ >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'], + >>> output_size=7, + >>> sampling_ratio=2) + >>> + >>> keypoint_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'], + >>> output_size=14, + >>> sampling_ratio=2) + >>> # put the pieces together inside a KeypointRCNN model + >>> model = KeypointRCNN(backbone, + >>> num_classes=2, + >>> rpn_anchor_generator=anchor_generator, + >>> box_roi_pool=roi_pooler, + >>> keypoint_roi_pool=keypoint_roi_pooler) + >>> model.eval() + >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> predictions = model(x) + """ + + def __init__( + self, + backbone, + num_classes=None, + # transform parameters + min_size=None, + max_size=1333, + image_mean=None, + image_std=None, + # RPN parameters + rpn_anchor_generator=None, + rpn_head=None, + rpn_pre_nms_top_n_train=2000, + rpn_pre_nms_top_n_test=1000, + rpn_post_nms_top_n_train=2000, + rpn_post_nms_top_n_test=1000, + rpn_nms_thresh=0.7, + rpn_fg_iou_thresh=0.7, + rpn_bg_iou_thresh=0.3, + rpn_batch_size_per_image=256, + rpn_positive_fraction=0.5, + rpn_score_thresh=0.0, + # Box parameters + box_roi_pool=None, + box_head=None, + box_predictor=None, + box_score_thresh=0.05, + box_nms_thresh=0.5, + box_detections_per_img=100, + box_fg_iou_thresh=0.5, + box_bg_iou_thresh=0.5, + box_batch_size_per_image=512, + box_positive_fraction=0.25, + bbox_reg_weights=None, + # keypoint parameters + keypoint_roi_pool=None, + keypoint_head=None, + keypoint_predictor=None, + num_keypoints=None, + **kwargs, + ): + + if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))): + raise TypeError( + "keypoint_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(keypoint_roi_pool)}" + ) + if min_size is None: + min_size = (640, 672, 704, 736, 768, 800) + + if num_keypoints is not None: + if keypoint_predictor is not None: + raise ValueError("num_keypoints should be None when keypoint_predictor is specified") + else: + num_keypoints = 17 + + out_channels = backbone.out_channels + + if keypoint_roi_pool is None: + keypoint_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2) + + if keypoint_head is None: + keypoint_layers = tuple(512 for _ in range(8)) + keypoint_head = KeypointRCNNHeads(out_channels, keypoint_layers) + + if keypoint_predictor is None: + keypoint_dim_reduced = 512 # == keypoint_layers[-1] + keypoint_predictor = KeypointRCNNPredictor(keypoint_dim_reduced, num_keypoints) + + super().__init__( + backbone, + num_classes, + # transform parameters + min_size, + max_size, + image_mean, + image_std, + # RPN-specific parameters + rpn_anchor_generator, + rpn_head, + rpn_pre_nms_top_n_train, + rpn_pre_nms_top_n_test, + rpn_post_nms_top_n_train, + rpn_post_nms_top_n_test, + rpn_nms_thresh, + rpn_fg_iou_thresh, + rpn_bg_iou_thresh, + rpn_batch_size_per_image, + rpn_positive_fraction, + rpn_score_thresh, + # Box parameters + box_roi_pool, + box_head, + box_predictor, + box_score_thresh, + box_nms_thresh, + box_detections_per_img, + box_fg_iou_thresh, + box_bg_iou_thresh, + box_batch_size_per_image, + box_positive_fraction, + bbox_reg_weights, + **kwargs, + ) + + self.roi_heads.keypoint_roi_pool = keypoint_roi_pool + self.roi_heads.keypoint_head = keypoint_head + self.roi_heads.keypoint_predictor = keypoint_predictor + + +class KeypointRCNNHeads(nn.Sequential): + def __init__(self, in_channels, layers): + d = [] + next_feature = in_channels + for out_channels in layers: + d.append(nn.Conv2d(next_feature, 
out_channels, 3, stride=1, padding=1)) + d.append(nn.ReLU(inplace=True)) + next_feature = out_channels + super().__init__(*d) + for m in self.children(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + nn.init.constant_(m.bias, 0) + + +class KeypointRCNNPredictor(nn.Module): + def __init__(self, in_channels, num_keypoints): + super().__init__() + input_features = in_channels + deconv_kernel = 4 + self.kps_score_lowres = nn.ConvTranspose2d( + input_features, + num_keypoints, + deconv_kernel, + stride=2, + padding=deconv_kernel // 2 - 1, + ) + nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu") + nn.init.constant_(self.kps_score_lowres.bias, 0) + self.up_scale = 2 + self.out_channels = num_keypoints + + def forward(self, x): + x = self.kps_score_lowres(x) + return torch.nn.functional.interpolate( + x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False + ) + + +_COMMON_META = { + "categories": _COCO_PERSON_CATEGORIES, + "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES, + "min_size": (1, 1), +} + + +class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): + COCO_LEGACY = Weights( + url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth", + transforms=ObjectDetection, + meta={ + **_COMMON_META, + "num_params": 59137258, + "recipe": "https://github.com/pytorch/vision/issues/1606", + "_metrics": { + "COCO-val2017": { + "box_map": 50.6, + "kp_map": 61.1, + } + }, + "_ops": 133.924, + "_file_size": 226.054, + "_docs": """ + These weights were produced by following a similar training recipe as on the paper but use a checkpoint + from an early epoch. + """, + }, + ) + COCO_V1 = Weights( + url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", + transforms=ObjectDetection, + meta={ + **_COMMON_META, + "num_params": 59137258, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn", + "_metrics": { + "COCO-val2017": { + "box_map": 54.6, + "kp_map": 65.0, + } + }, + "_ops": 137.42, + "_file_size": 226.054, + "_docs": """These weights were produced by following a similar training recipe as on the paper.""", + }, + ) + DEFAULT = COCO_V1 + + +@register_model() +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY + if kwargs["pretrained"] == "legacy" + else KeypointRCNN_ResNet50_FPN_Weights.COCO_V1, + ), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) +def keypointrcnn_resnet50_fpn( + *, + weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + num_keypoints: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> KeypointRCNN: + """ + Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone. + + .. betastatus:: detection module + + Reference: `Mask R-CNN `__. + + The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each + image, and should be in ``0-1`` range. Different images can have different sizes. + + The behavior of the model changes depending on if it is in training or evaluation mode. 
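+
+    For orientation, a hypothetical training target for a single image (one instance with the
+    default ``K = 17`` COCO keypoints; the field semantics are spelled out in the paragraphs
+    below) could be assembled as::
+
+        >>> boxes = torch.tensor([[50.0, 60.0, 200.0, 300.0]])    # [x1, y1, x2, y2]
+        >>> labels = torch.tensor([1], dtype=torch.int64)         # e.g. person when num_classes=2
+        >>> keypoints = torch.zeros(1, 17, 3)                     # [x, y, visibility] per keypoint
+        >>> keypoints[0, 0] = torch.tensor([120.0, 90.0, 1.0])    # e.g. a visible nose keypoint
+        >>> target = {"boxes": boxes, "labels": labels, "keypoints": keypoints}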
+ + During training, the model expects both the input tensors and targets (list of dictionary), + containing: + + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (``Int64Tensor[N]``): the class label for each ground-truth box + - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the + format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible. + + The model returns a ``Dict[Tensor]`` during training, containing the classification and regression + losses for both the RPN and the R-CNN, and the keypoint loss. + + During inference, the model requires only the input tensors, and returns the post-processed + predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as + follows, where ``N`` is the number of detected instances: + + - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (``Int64Tensor[N]``): the predicted labels for each instance + - scores (``Tensor[N]``): the scores or each instance + - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format. + + For more details on the output, you may refer to :ref:`instance_seg_output`. + + Keypoint R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size. + + Example:: + + >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.DEFAULT) + >>> model.eval() + >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> predictions = model(x) + >>> + >>> # optionally, if you want to export the model to ONNX: + >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11) + + Args: + weights (:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights` + below for more details, and possible values. By default, no + pre-trained weights are used. + progress (bool): If True, displays a progress bar of the download to stderr + num_classes (int, optional): number of output classes of the model (including the background) + num_keypoints (int, optional): number of keypoints + weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The + pretrained weights for the backbone. + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. + Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is + passed (the default) this value is set to 3. + + .. 
autoclass:: torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights + :members: + """ + weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + num_keypoints = _ovewrite_value_param("num_keypoints", num_keypoints, len(weights.meta["keypoint_names"])) + else: + if num_classes is None: + num_classes = 2 + if num_keypoints is None: + num_keypoints = 17 + + is_trained = weights is not None or weights_backbone is not None + trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) + norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d + + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) + backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) + model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1: + overwrite_eps(model, 0.0) + + return model diff --git a/lib/python3.10/site-packages/torchvision/models/detection/mask_rcnn.py b/lib/python3.10/site-packages/torchvision/models/detection/mask_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..cdabbfd26ca8bbefaefdb6fb8b098afac217b595 --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/detection/mask_rcnn.py @@ -0,0 +1,590 @@ +from collections import OrderedDict +from typing import Any, Callable, Optional + +from torch import nn +from torchvision.ops import MultiScaleRoIAlign + +from ...ops import misc as misc_nn_ops +from ...transforms._presets import ObjectDetection +from .._api import register_model, Weights, WeightsEnum +from .._meta import _COCO_CATEGORIES +from .._utils import _ovewrite_value_param, handle_legacy_interface +from ..resnet import resnet50, ResNet50_Weights +from ._utils import overwrite_eps +from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers +from .faster_rcnn import _default_anchorgen, FasterRCNN, FastRCNNConvFCHead, RPNHead + + +__all__ = [ + "MaskRCNN", + "MaskRCNN_ResNet50_FPN_Weights", + "MaskRCNN_ResNet50_FPN_V2_Weights", + "maskrcnn_resnet50_fpn", + "maskrcnn_resnet50_fpn_v2", +] + + +class MaskRCNN(FasterRCNN): + """ + Implements Mask R-CNN. + + The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each + image, and should be in 0-1 range. Different images can have different sizes. + + The behavior of the model changes depending on if it is in training or evaluation mode. + + During training, the model expects both the input tensors and targets (list of dictionary), + containing: + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (Int64Tensor[N]): the class label for each ground-truth box + - masks (UInt8Tensor[N, H, W]): the segmentation binary masks for each instance + + The model returns a Dict[Tensor] during training, containing the classification and regression + losses for both the RPN and the R-CNN, and the mask loss. 
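+
+    As a rough sketch (hypothetical values; ``model`` stands for an already constructed
+    ``MaskRCNN``, e.g. as in the example further below), a training step with such targets
+    could look like::
+
+        >>> images = [torch.rand(3, 300, 400)]
+        >>> h, w = images[0].shape[-2:]
+        >>> targets = [{
+        >>>     "boxes": torch.tensor([[20.0, 30.0, 120.0, 150.0]]),
+        >>>     "labels": torch.tensor([1], dtype=torch.int64),
+        >>>     "masks": torch.zeros(1, h, w, dtype=torch.uint8),
+        >>> }]
+        >>> model.train()
+        >>> loss_dict = model(images, targets)   # includes RPN, box and mask losses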
+ + During inference, the model requires only the input tensors, and returns the post-processed + predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as + follows: + - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (Int64Tensor[N]): the predicted labels for each image + - scores (Tensor[N]): the scores or each prediction + - masks (UInt8Tensor[N, 1, H, W]): the predicted masks for each instance, in 0-1 range. In order to + obtain the final segmentation masks, the soft masks can be thresholded, generally + with a value of 0.5 (mask >= 0.5) + + Args: + backbone (nn.Module): the network used to compute the features for the model. + It should contain an out_channels attribute, which indicates the number of output + channels that each feature map has (and it should be the same for all feature maps). + The backbone should return a single Tensor or and OrderedDict[Tensor]. + num_classes (int): number of output classes of the model (including the background). + If box_predictor is specified, num_classes should be None. + min_size (int): Images are rescaled before feeding them to the backbone: + we attempt to preserve the aspect ratio and scale the shorter edge + to ``min_size``. If the resulting longer edge exceeds ``max_size``, + then downscale so that the longer edge does not exceed ``max_size``. + This may result in the shorter edge beeing lower than ``min_size``. + max_size (int): See ``min_size``. + image_mean (Tuple[float, float, float]): mean values used for input normalization. + They are generally the mean values of the dataset on which the backbone has been trained + on + image_std (Tuple[float, float, float]): std values used for input normalization. + They are generally the std values of the dataset on which the backbone has been trained on + rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature + maps. + rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN + rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training + rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing + rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training + rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing + rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals + rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be + considered as positive during training of the RPN. + rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be + considered as negative during training of the RPN. 
+ rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN + for computing the loss + rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training + of the RPN + rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh + box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in + the locations indicated by the bounding boxes + box_head (nn.Module): module that takes the cropped feature maps as input + box_predictor (nn.Module): module that takes the output of box_head and returns the + classification logits and box regression deltas. + box_score_thresh (float): during inference, only return proposals with a classification score + greater than box_score_thresh + box_nms_thresh (float): NMS threshold for the prediction head. Used during inference + box_detections_per_img (int): maximum number of detections per image, for all classes. + box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be + considered as positive during training of the classification head + box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be + considered as negative during training of the classification head + box_batch_size_per_image (int): number of proposals that are sampled during training of the + classification head + box_positive_fraction (float): proportion of positive proposals in a mini-batch during training + of the classification head + bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the + bounding boxes + mask_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in + the locations indicated by the bounding boxes, which will be used for the mask head. + mask_head (nn.Module): module that takes the cropped feature maps as input + mask_predictor (nn.Module): module that takes the output of the mask_head and returns the + segmentation mask logits + + Example:: + + >>> import torch + >>> import torchvision + >>> from torchvision.models.detection import MaskRCNN + >>> from torchvision.models.detection.anchor_utils import AnchorGenerator + >>> + >>> # load a pre-trained model for classification and return + >>> # only the features + >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features + >>> # MaskRCNN needs to know the number of + >>> # output channels in a backbone. For mobilenet_v2, it's 1280 + >>> # so we need to add it here, + >>> backbone.out_channels = 1280 + >>> + >>> # let's make the RPN generate 5 x 3 anchors per spatial + >>> # location, with 5 different sizes and 3 different aspect + >>> # ratios. We have a Tuple[Tuple[int]] because each feature + >>> # map could potentially have different sizes and + >>> # aspect ratios + >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),), + >>> aspect_ratios=((0.5, 1.0, 2.0),)) + >>> + >>> # let's define what are the feature maps that we will + >>> # use to perform the region of interest cropping, as well as + >>> # the size of the crop after rescaling. + >>> # if your backbone returns a Tensor, featmap_names is expected to + >>> # be ['0']. More generally, the backbone should return an + >>> # OrderedDict[Tensor], and in featmap_names you can choose which + >>> # feature maps to use. 
+ >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'], + >>> output_size=7, + >>> sampling_ratio=2) + >>> + >>> mask_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'], + >>> output_size=14, + >>> sampling_ratio=2) + >>> # put the pieces together inside a MaskRCNN model + >>> model = MaskRCNN(backbone, + >>> num_classes=2, + >>> rpn_anchor_generator=anchor_generator, + >>> box_roi_pool=roi_pooler, + >>> mask_roi_pool=mask_roi_pooler) + >>> model.eval() + >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> predictions = model(x) + """ + + def __init__( + self, + backbone, + num_classes=None, + # transform parameters + min_size=800, + max_size=1333, + image_mean=None, + image_std=None, + # RPN parameters + rpn_anchor_generator=None, + rpn_head=None, + rpn_pre_nms_top_n_train=2000, + rpn_pre_nms_top_n_test=1000, + rpn_post_nms_top_n_train=2000, + rpn_post_nms_top_n_test=1000, + rpn_nms_thresh=0.7, + rpn_fg_iou_thresh=0.7, + rpn_bg_iou_thresh=0.3, + rpn_batch_size_per_image=256, + rpn_positive_fraction=0.5, + rpn_score_thresh=0.0, + # Box parameters + box_roi_pool=None, + box_head=None, + box_predictor=None, + box_score_thresh=0.05, + box_nms_thresh=0.5, + box_detections_per_img=100, + box_fg_iou_thresh=0.5, + box_bg_iou_thresh=0.5, + box_batch_size_per_image=512, + box_positive_fraction=0.25, + bbox_reg_weights=None, + # Mask parameters + mask_roi_pool=None, + mask_head=None, + mask_predictor=None, + **kwargs, + ): + + if not isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None))): + raise TypeError( + f"mask_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(mask_roi_pool)}" + ) + + if num_classes is not None: + if mask_predictor is not None: + raise ValueError("num_classes should be None when mask_predictor is specified") + + out_channels = backbone.out_channels + + if mask_roi_pool is None: + mask_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2) + + if mask_head is None: + mask_layers = (256, 256, 256, 256) + mask_dilation = 1 + mask_head = MaskRCNNHeads(out_channels, mask_layers, mask_dilation) + + if mask_predictor is None: + mask_predictor_in_channels = 256 # == mask_layers[-1] + mask_dim_reduced = 256 + mask_predictor = MaskRCNNPredictor(mask_predictor_in_channels, mask_dim_reduced, num_classes) + + super().__init__( + backbone, + num_classes, + # transform parameters + min_size, + max_size, + image_mean, + image_std, + # RPN-specific parameters + rpn_anchor_generator, + rpn_head, + rpn_pre_nms_top_n_train, + rpn_pre_nms_top_n_test, + rpn_post_nms_top_n_train, + rpn_post_nms_top_n_test, + rpn_nms_thresh, + rpn_fg_iou_thresh, + rpn_bg_iou_thresh, + rpn_batch_size_per_image, + rpn_positive_fraction, + rpn_score_thresh, + # Box parameters + box_roi_pool, + box_head, + box_predictor, + box_score_thresh, + box_nms_thresh, + box_detections_per_img, + box_fg_iou_thresh, + box_bg_iou_thresh, + box_batch_size_per_image, + box_positive_fraction, + bbox_reg_weights, + **kwargs, + ) + + self.roi_heads.mask_roi_pool = mask_roi_pool + self.roi_heads.mask_head = mask_head + self.roi_heads.mask_predictor = mask_predictor + + +class MaskRCNNHeads(nn.Sequential): + _version = 2 + + def __init__(self, in_channels, layers, dilation, norm_layer: Optional[Callable[..., nn.Module]] = None): + """ + Args: + in_channels (int): number of input channels + layers (list): feature dimensions of each FCN layer + dilation (int): dilation rate of kernel + norm_layer (callable, optional): Module 
specifying the normalization layer to use. Default: None + """ + blocks = [] + next_feature = in_channels + for layer_features in layers: + blocks.append( + misc_nn_ops.Conv2dNormActivation( + next_feature, + layer_features, + kernel_size=3, + stride=1, + padding=dilation, + dilation=dilation, + norm_layer=norm_layer, + ) + ) + next_feature = layer_features + + super().__init__(*blocks) + for layer in self.modules(): + if isinstance(layer, nn.Conv2d): + nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu") + if layer.bias is not None: + nn.init.zeros_(layer.bias) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + + if version is None or version < 2: + num_blocks = len(self) + for i in range(num_blocks): + for type in ["weight", "bias"]: + old_key = f"{prefix}mask_fcn{i+1}.{type}" + new_key = f"{prefix}{i}.0.{type}" + if old_key in state_dict: + state_dict[new_key] = state_dict.pop(old_key) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +class MaskRCNNPredictor(nn.Sequential): + def __init__(self, in_channels, dim_reduced, num_classes): + super().__init__( + OrderedDict( + [ + ("conv5_mask", nn.ConvTranspose2d(in_channels, dim_reduced, 2, 2, 0)), + ("relu", nn.ReLU(inplace=True)), + ("mask_fcn_logits", nn.Conv2d(dim_reduced, num_classes, 1, 1, 0)), + ] + ) + ) + + for name, param in self.named_parameters(): + if "weight" in name: + nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu") + # elif "bias" in name: + # nn.init.constant_(param, 0) + + +_COMMON_META = { + "categories": _COCO_CATEGORIES, + "min_size": (1, 1), +} + + +class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", + transforms=ObjectDetection, + meta={ + **_COMMON_META, + "num_params": 44401393, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn", + "_metrics": { + "COCO-val2017": { + "box_map": 37.9, + "mask_map": 34.6, + } + }, + "_ops": 134.38, + "_file_size": 169.84, + "_docs": """These weights were produced by following a similar training recipe as on the paper.""", + }, + ) + DEFAULT = COCO_V1 + + +class MaskRCNN_ResNet50_FPN_V2_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_v2_coco-73cbd019.pth", + transforms=ObjectDetection, + meta={ + **_COMMON_META, + "num_params": 46359409, + "recipe": "https://github.com/pytorch/vision/pull/5773", + "_metrics": { + "COCO-val2017": { + "box_map": 47.4, + "mask_map": 41.8, + } + }, + "_ops": 333.577, + "_file_size": 177.219, + "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""", + }, + ) + DEFAULT = COCO_V1 + + +@register_model() +@handle_legacy_interface( + weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) +def maskrcnn_resnet50_fpn( + *, + weights: Optional[MaskRCNN_ResNet50_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> MaskRCNN: + """Mask R-CNN model with a ResNet-50-FPN backbone 
from the `Mask R-CNN + `_ paper. + + .. betastatus:: detection module + + The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each + image, and should be in ``0-1`` range. Different images can have different sizes. + + The behavior of the model changes depending on if it is in training or evaluation mode. + + During training, the model expects both the input tensors and targets (list of dictionary), + containing: + + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (``Int64Tensor[N]``): the class label for each ground-truth box + - masks (``UInt8Tensor[N, H, W]``): the segmentation binary masks for each instance + + The model returns a ``Dict[Tensor]`` during training, containing the classification and regression + losses for both the RPN and the R-CNN, and the mask loss. + + During inference, the model requires only the input tensors, and returns the post-processed + predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as + follows, where ``N`` is the number of detected instances: + + - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (``Int64Tensor[N]``): the predicted labels for each instance + - scores (``Tensor[N]``): the scores or each instance + - masks (``UInt8Tensor[N, 1, H, W]``): the predicted masks for each instance, in ``0-1`` range. In order to + obtain the final segmentation masks, the soft masks can be thresholded, generally + with a value of 0.5 (``mask >= 0.5``) + + For more details on the output and on how to plot the masks, you may refer to :ref:`instance_seg_output`. + + Mask R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size. + + Example:: + + >>> model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT) + >>> model.eval() + >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> predictions = model(x) + >>> + >>> # optionally, if you want to export the model to ONNX: + >>> torch.onnx.export(model, x, "mask_rcnn.onnx", opset_version = 11) + + Args: + weights (:class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The + pretrained weights for the backbone. + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from + final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are + trainable. If ``None`` is passed (the default) this value is set to 3. + **kwargs: parameters passed to the ``torchvision.models.detection.mask_rcnn.MaskRCNN`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights + :members: + """ + weights = MaskRCNN_ResNet50_FPN_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + is_trained = weights is not None or weights_backbone is not None + trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) + norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d + + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) + backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) + model = MaskRCNN(backbone, num_classes=num_classes, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + if weights == MaskRCNN_ResNet50_FPN_Weights.COCO_V1: + overwrite_eps(model, 0.0) + + return model + + +@register_model() +@handle_legacy_interface( + weights=("pretrained", MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) +def maskrcnn_resnet50_fpn_v2( + *, + weights: Optional[MaskRCNN_ResNet50_FPN_V2_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = None, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> MaskRCNN: + """Improved Mask R-CNN model with a ResNet-50-FPN backbone from the `Benchmarking Detection Transfer + Learning with Vision Transformers `_ paper. + + .. betastatus:: detection module + + :func:`~torchvision.models.detection.maskrcnn_resnet50_fpn` for more details. + + Args: + weights (:class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_V2_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_V2_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The + pretrained weights for the backbone. + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from + final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are + trainable. If ``None`` is passed (the default) this value is set to 3. + **kwargs: parameters passed to the ``torchvision.models.detection.mask_rcnn.MaskRCNN`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.detection.MaskRCNN_ResNet50_FPN_V2_Weights + :members: + """ + weights = MaskRCNN_ResNet50_FPN_V2_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + is_trained = weights is not None or weights_backbone is not None + trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) + + backbone = resnet50(weights=weights_backbone, progress=progress) + backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d) + rpn_anchor_generator = _default_anchorgen() + rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2) + box_head = FastRCNNConvFCHead( + (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d + ) + mask_head = MaskRCNNHeads(backbone.out_channels, [256, 256, 256, 256], 1, norm_layer=nn.BatchNorm2d) + model = MaskRCNN( + backbone, + num_classes=num_classes, + rpn_anchor_generator=rpn_anchor_generator, + rpn_head=rpn_head, + box_head=box_head, + mask_head=mask_head, + **kwargs, + ) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model diff --git a/lib/python3.10/site-packages/torchvision/models/detection/retinanet.py b/lib/python3.10/site-packages/torchvision/models/detection/retinanet.py new file mode 100644 index 0000000000000000000000000000000000000000..a8cc7755014b6010965108a46c080f71b2d609db --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/detection/retinanet.py @@ -0,0 +1,903 @@ +import math +import warnings +from collections import OrderedDict +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +from torch import nn, Tensor + +from ...ops import boxes as box_ops, misc as misc_nn_ops, sigmoid_focal_loss +from ...ops.feature_pyramid_network import LastLevelP6P7 +from ...transforms._presets import ObjectDetection +from ...utils import _log_api_usage_once +from .._api import register_model, Weights, WeightsEnum +from .._meta import _COCO_CATEGORIES +from .._utils import _ovewrite_value_param, handle_legacy_interface +from ..resnet import resnet50, ResNet50_Weights +from . 
import _utils as det_utils +from ._utils import _box_loss, overwrite_eps +from .anchor_utils import AnchorGenerator +from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers +from .transform import GeneralizedRCNNTransform + + +__all__ = [ + "RetinaNet", + "RetinaNet_ResNet50_FPN_Weights", + "RetinaNet_ResNet50_FPN_V2_Weights", + "retinanet_resnet50_fpn", + "retinanet_resnet50_fpn_v2", +] + + +def _sum(x: List[Tensor]) -> Tensor: + res = x[0] + for i in x[1:]: + res = res + i + return res + + +def _v1_to_v2_weights(state_dict, prefix): + for i in range(4): + for type in ["weight", "bias"]: + old_key = f"{prefix}conv.{2*i}.{type}" + new_key = f"{prefix}conv.{i}.0.{type}" + if old_key in state_dict: + state_dict[new_key] = state_dict.pop(old_key) + + +def _default_anchorgen(): + anchor_sizes = tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512]) + aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) + anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios) + return anchor_generator + + +class RetinaNetHead(nn.Module): + """ + A regression and classification head for use in RetinaNet. + + Args: + in_channels (int): number of channels of the input feature + num_anchors (int): number of anchors to be predicted + num_classes (int): number of classes to be predicted + norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None + """ + + def __init__(self, in_channels, num_anchors, num_classes, norm_layer: Optional[Callable[..., nn.Module]] = None): + super().__init__() + self.classification_head = RetinaNetClassificationHead( + in_channels, num_anchors, num_classes, norm_layer=norm_layer + ) + self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors, norm_layer=norm_layer) + + def compute_loss(self, targets, head_outputs, anchors, matched_idxs): + # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Dict[str, Tensor] + return { + "classification": self.classification_head.compute_loss(targets, head_outputs, matched_idxs), + "bbox_regression": self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs), + } + + def forward(self, x): + # type: (List[Tensor]) -> Dict[str, Tensor] + return {"cls_logits": self.classification_head(x), "bbox_regression": self.regression_head(x)} + + +class RetinaNetClassificationHead(nn.Module): + """ + A classification head for use in RetinaNet. + + Args: + in_channels (int): number of channels of the input feature + num_anchors (int): number of anchors to be predicted + num_classes (int): number of classes to be predicted + norm_layer (callable, optional): Module specifying the normalization layer to use. 
Default: None + """ + + _version = 2 + + def __init__( + self, + in_channels, + num_anchors, + num_classes, + prior_probability=0.01, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ): + super().__init__() + + conv = [] + for _ in range(4): + conv.append(misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer)) + self.conv = nn.Sequential(*conv) + + for layer in self.conv.modules(): + if isinstance(layer, nn.Conv2d): + torch.nn.init.normal_(layer.weight, std=0.01) + if layer.bias is not None: + torch.nn.init.constant_(layer.bias, 0) + + self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1) + torch.nn.init.normal_(self.cls_logits.weight, std=0.01) + torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability)) + + self.num_classes = num_classes + self.num_anchors = num_anchors + + # This is to fix using det_utils.Matcher.BETWEEN_THRESHOLDS in TorchScript. + # TorchScript doesn't support class attributes. + # https://github.com/pytorch/vision/pull/1697#issuecomment-630255584 + self.BETWEEN_THRESHOLDS = det_utils.Matcher.BETWEEN_THRESHOLDS + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + + if version is None or version < 2: + _v1_to_v2_weights(state_dict, prefix) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def compute_loss(self, targets, head_outputs, matched_idxs): + # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor + losses = [] + + cls_logits = head_outputs["cls_logits"] + + for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs): + # determine only the foreground + foreground_idxs_per_image = matched_idxs_per_image >= 0 + num_foreground = foreground_idxs_per_image.sum() + + # create the target classification + gt_classes_target = torch.zeros_like(cls_logits_per_image) + gt_classes_target[ + foreground_idxs_per_image, + targets_per_image["labels"][matched_idxs_per_image[foreground_idxs_per_image]], + ] = 1.0 + + # find indices for which anchors should be ignored + valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS + + # compute the classification loss + losses.append( + sigmoid_focal_loss( + cls_logits_per_image[valid_idxs_per_image], + gt_classes_target[valid_idxs_per_image], + reduction="sum", + ) + / max(1, num_foreground) + ) + + return _sum(losses) / len(targets) + + def forward(self, x): + # type: (List[Tensor]) -> Tensor + all_cls_logits = [] + + for features in x: + cls_logits = self.conv(features) + cls_logits = self.cls_logits(cls_logits) + + # Permute classification output from (N, A * K, H, W) to (N, HWA, K). + N, _, H, W = cls_logits.shape + cls_logits = cls_logits.view(N, -1, self.num_classes, H, W) + cls_logits = cls_logits.permute(0, 3, 4, 1, 2) + cls_logits = cls_logits.reshape(N, -1, self.num_classes) # Size=(N, HWA, 4) + + all_cls_logits.append(cls_logits) + + return torch.cat(all_cls_logits, dim=1) + + +class RetinaNetRegressionHead(nn.Module): + """ + A regression head for use in RetinaNet. + + Args: + in_channels (int): number of channels of the input feature + num_anchors (int): number of anchors to be predicted + norm_layer (callable, optional): Module specifying the normalization layer to use. 
Default: None + """ + + _version = 2 + + __annotations__ = { + "box_coder": det_utils.BoxCoder, + } + + def __init__(self, in_channels, num_anchors, norm_layer: Optional[Callable[..., nn.Module]] = None): + super().__init__() + + conv = [] + for _ in range(4): + conv.append(misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer)) + self.conv = nn.Sequential(*conv) + + self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1) + torch.nn.init.normal_(self.bbox_reg.weight, std=0.01) + torch.nn.init.zeros_(self.bbox_reg.bias) + + for layer in self.conv.modules(): + if isinstance(layer, nn.Conv2d): + torch.nn.init.normal_(layer.weight, std=0.01) + if layer.bias is not None: + torch.nn.init.zeros_(layer.bias) + + self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) + self._loss_type = "l1" + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + + if version is None or version < 2: + _v1_to_v2_weights(state_dict, prefix) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def compute_loss(self, targets, head_outputs, anchors, matched_idxs): + # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor + losses = [] + + bbox_regression = head_outputs["bbox_regression"] + + for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in zip( + targets, bbox_regression, anchors, matched_idxs + ): + # determine only the foreground indices, ignore the rest + foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0] + num_foreground = foreground_idxs_per_image.numel() + + # select only the foreground boxes + matched_gt_boxes_per_image = targets_per_image["boxes"][matched_idxs_per_image[foreground_idxs_per_image]] + bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :] + anchors_per_image = anchors_per_image[foreground_idxs_per_image, :] + + # compute the loss + losses.append( + _box_loss( + self._loss_type, + self.box_coder, + anchors_per_image, + matched_gt_boxes_per_image, + bbox_regression_per_image, + ) + / max(1, num_foreground) + ) + + return _sum(losses) / max(1, len(targets)) + + def forward(self, x): + # type: (List[Tensor]) -> Tensor + all_bbox_regression = [] + + for features in x: + bbox_regression = self.conv(features) + bbox_regression = self.bbox_reg(bbox_regression) + + # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4). + N, _, H, W = bbox_regression.shape + bbox_regression = bbox_regression.view(N, -1, 4, H, W) + bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2) + bbox_regression = bbox_regression.reshape(N, -1, 4) # Size=(N, HWA, 4) + + all_bbox_regression.append(bbox_regression) + + return torch.cat(all_bbox_regression, dim=1) + + +class RetinaNet(nn.Module): + """ + Implements RetinaNet. + + The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each + image, and should be in 0-1 range. Different images can have different sizes. + + The behavior of the model changes depending on if it is in training or evaluation mode. 
+ + During training, the model expects both the input tensors and targets (list of dictionary), + containing: + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (Int64Tensor[N]): the class label for each ground-truth box + + The model returns a Dict[Tensor] during training, containing the classification and regression + losses. + + During inference, the model requires only the input tensors, and returns the post-processed + predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as + follows: + - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (Int64Tensor[N]): the predicted labels for each image + - scores (Tensor[N]): the scores for each prediction + + Args: + backbone (nn.Module): the network used to compute the features for the model. + It should contain an out_channels attribute, which indicates the number of output + channels that each feature map has (and it should be the same for all feature maps). + The backbone should return a single Tensor or an OrderedDict[Tensor]. + num_classes (int): number of output classes of the model (including the background). + min_size (int): Images are rescaled before feeding them to the backbone: + we attempt to preserve the aspect ratio and scale the shorter edge + to ``min_size``. If the resulting longer edge exceeds ``max_size``, + then downscale so that the longer edge does not exceed ``max_size``. + This may result in the shorter edge beeing lower than ``min_size``. + max_size (int): See ``min_size``. + image_mean (Tuple[float, float, float]): mean values used for input normalization. + They are generally the mean values of the dataset on which the backbone has been trained + on + image_std (Tuple[float, float, float]): std values used for input normalization. + They are generally the std values of the dataset on which the backbone has been trained on + anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature + maps. + head (nn.Module): Module run on top of the feature pyramid. + Defaults to a module containing a classification and regression module. + score_thresh (float): Score threshold used for postprocessing the detections. + nms_thresh (float): NMS threshold used for postprocessing the detections. + detections_per_img (int): Number of best detections to keep after NMS. + fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be + considered as positive during training. + bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be + considered as negative during training. + topk_candidates (int): Number of best detections to keep before NMS. + + Example: + + >>> import torch + >>> import torchvision + >>> from torchvision.models.detection import RetinaNet + >>> from torchvision.models.detection.anchor_utils import AnchorGenerator + >>> # load a pre-trained model for classification and return + >>> # only the features + >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features + >>> # RetinaNet needs to know the number of + >>> # output channels in a backbone. 
For mobilenet_v2, it's 1280, + >>> # so we need to add it here + >>> backbone.out_channels = 1280 + >>> + >>> # let's make the network generate 5 x 3 anchors per spatial + >>> # location, with 5 different sizes and 3 different aspect + >>> # ratios. We have a Tuple[Tuple[int]] because each feature + >>> # map could potentially have different sizes and + >>> # aspect ratios + >>> anchor_generator = AnchorGenerator( + >>> sizes=((32, 64, 128, 256, 512),), + >>> aspect_ratios=((0.5, 1.0, 2.0),) + >>> ) + >>> + >>> # put the pieces together inside a RetinaNet model + >>> model = RetinaNet(backbone, + >>> num_classes=2, + >>> anchor_generator=anchor_generator) + >>> model.eval() + >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> predictions = model(x) + """ + + __annotations__ = { + "box_coder": det_utils.BoxCoder, + "proposal_matcher": det_utils.Matcher, + } + + def __init__( + self, + backbone, + num_classes, + # transform parameters + min_size=800, + max_size=1333, + image_mean=None, + image_std=None, + # Anchor parameters + anchor_generator=None, + head=None, + proposal_matcher=None, + score_thresh=0.05, + nms_thresh=0.5, + detections_per_img=300, + fg_iou_thresh=0.5, + bg_iou_thresh=0.4, + topk_candidates=1000, + **kwargs, + ): + super().__init__() + _log_api_usage_once(self) + + if not hasattr(backbone, "out_channels"): + raise ValueError( + "backbone should contain an attribute out_channels " + "specifying the number of output channels (assumed to be the " + "same for all the levels)" + ) + self.backbone = backbone + + if not isinstance(anchor_generator, (AnchorGenerator, type(None))): + raise TypeError( + f"anchor_generator should be of type AnchorGenerator or None instead of {type(anchor_generator)}" + ) + + if anchor_generator is None: + anchor_generator = _default_anchorgen() + self.anchor_generator = anchor_generator + + if head is None: + head = RetinaNetHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes) + self.head = head + + if proposal_matcher is None: + proposal_matcher = det_utils.Matcher( + fg_iou_thresh, + bg_iou_thresh, + allow_low_quality_matches=True, + ) + self.proposal_matcher = proposal_matcher + + self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) + + if image_mean is None: + image_mean = [0.485, 0.456, 0.406] + if image_std is None: + image_std = [0.229, 0.224, 0.225] + self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs) + + self.score_thresh = score_thresh + self.nms_thresh = nms_thresh + self.detections_per_img = detections_per_img + self.topk_candidates = topk_candidates + + # used only on torchscript mode + self._has_warned = False + + @torch.jit.unused + def eager_outputs(self, losses, detections): + # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] + if self.training: + return losses + + return detections + + def compute_loss(self, targets, head_outputs, anchors): + # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Dict[str, Tensor] + matched_idxs = [] + for anchors_per_image, targets_per_image in zip(anchors, targets): + if targets_per_image["boxes"].numel() == 0: + matched_idxs.append( + torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device) + ) + continue + + match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image) + matched_idxs.append(self.proposal_matcher(match_quality_matrix)) + + return 
self.head.compute_loss(targets, head_outputs, anchors, matched_idxs) + + def postprocess_detections(self, head_outputs, anchors, image_shapes): + # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]] + class_logits = head_outputs["cls_logits"] + box_regression = head_outputs["bbox_regression"] + + num_images = len(image_shapes) + + detections: List[Dict[str, Tensor]] = [] + + for index in range(num_images): + box_regression_per_image = [br[index] for br in box_regression] + logits_per_image = [cl[index] for cl in class_logits] + anchors_per_image, image_shape = anchors[index], image_shapes[index] + + image_boxes = [] + image_scores = [] + image_labels = [] + + for box_regression_per_level, logits_per_level, anchors_per_level in zip( + box_regression_per_image, logits_per_image, anchors_per_image + ): + num_classes = logits_per_level.shape[-1] + + # remove low scoring boxes + scores_per_level = torch.sigmoid(logits_per_level).flatten() + keep_idxs = scores_per_level > self.score_thresh + scores_per_level = scores_per_level[keep_idxs] + topk_idxs = torch.where(keep_idxs)[0] + + # keep only topk scoring predictions + num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0) + scores_per_level, idxs = scores_per_level.topk(num_topk) + topk_idxs = topk_idxs[idxs] + + anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor") + labels_per_level = topk_idxs % num_classes + + boxes_per_level = self.box_coder.decode_single( + box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs] + ) + boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape) + + image_boxes.append(boxes_per_level) + image_scores.append(scores_per_level) + image_labels.append(labels_per_level) + + image_boxes = torch.cat(image_boxes, dim=0) + image_scores = torch.cat(image_scores, dim=0) + image_labels = torch.cat(image_labels, dim=0) + + # non-maximum suppression + keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh) + keep = keep[: self.detections_per_img] + + detections.append( + { + "boxes": image_boxes[keep], + "scores": image_scores[keep], + "labels": image_labels[keep], + } + ) + + return detections + + def forward(self, images, targets=None): + # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] + """ + Args: + images (list[Tensor]): images to be processed + targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional) + + Returns: + result (list[BoxList] or dict[Tensor]): the output from the model. + During training, it returns a dict[Tensor] which contains the losses. + During testing, it returns list[BoxList] contains additional fields + like `scores`, `labels` and `mask` (for Mask R-CNN models). 
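+                For RetinaNet, each returned dict contains the fields ``boxes``, ``scores`` and ``labels``.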
+ + """ + if self.training: + if targets is None: + torch._assert(False, "targets should not be none when in training mode") + else: + for target in targets: + boxes = target["boxes"] + torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.") + torch._assert( + len(boxes.shape) == 2 and boxes.shape[-1] == 4, + "Expected target boxes to be a tensor of shape [N, 4].", + ) + + # get the original image sizes + original_image_sizes: List[Tuple[int, int]] = [] + for img in images: + val = img.shape[-2:] + torch._assert( + len(val) == 2, + f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}", + ) + original_image_sizes.append((val[0], val[1])) + + # transform the input + images, targets = self.transform(images, targets) + + # Check for degenerate boxes + # TODO: Move this to a function + if targets is not None: + for target_idx, target in enumerate(targets): + boxes = target["boxes"] + degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] + if degenerate_boxes.any(): + # print the first degenerate box + bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0] + degen_bb: List[float] = boxes[bb_idx].tolist() + torch._assert( + False, + "All bounding boxes should have positive height and width." + f" Found invalid box {degen_bb} for target at index {target_idx}.", + ) + + # get the features from the backbone + features = self.backbone(images.tensors) + if isinstance(features, torch.Tensor): + features = OrderedDict([("0", features)]) + + # TODO: Do we want a list or a dict? + features = list(features.values()) + + # compute the retinanet heads outputs using the features + head_outputs = self.head(features) + + # create the set of anchors + anchors = self.anchor_generator(images, features) + + losses = {} + detections: List[Dict[str, Tensor]] = [] + if self.training: + if targets is None: + torch._assert(False, "targets should not be none when in training mode") + else: + # compute the losses + losses = self.compute_loss(targets, head_outputs, anchors) + else: + # recover level sizes + num_anchors_per_level = [x.size(2) * x.size(3) for x in features] + HW = 0 + for v in num_anchors_per_level: + HW += v + HWA = head_outputs["cls_logits"].size(1) + A = HWA // HW + num_anchors_per_level = [hw * A for hw in num_anchors_per_level] + + # split outputs per level + split_head_outputs: Dict[str, List[Tensor]] = {} + for k in head_outputs: + split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1)) + split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors] + + # compute the detections + detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes) + detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) + + if torch.jit.is_scripting(): + if not self._has_warned: + warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting") + self._has_warned = True + return losses, detections + return self.eager_outputs(losses, detections) + + +_COMMON_META = { + "categories": _COCO_CATEGORIES, + "min_size": (1, 1), +} + + +class RetinaNet_ResNet50_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", + transforms=ObjectDetection, + meta={ + **_COMMON_META, + "num_params": 34014999, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet", + "_metrics": { + "COCO-val2017": { + "box_map": 36.4, + } + }, + "_ops": 
151.54, + "_file_size": 130.267, + "_docs": """These weights were produced by following a similar training recipe as on the paper.""", + }, + ) + DEFAULT = COCO_V1 + + +class RetinaNet_ResNet50_FPN_V2_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/retinanet_resnet50_fpn_v2_coco-5905b1c5.pth", + transforms=ObjectDetection, + meta={ + **_COMMON_META, + "num_params": 38198935, + "recipe": "https://github.com/pytorch/vision/pull/5756", + "_metrics": { + "COCO-val2017": { + "box_map": 41.5, + } + }, + "_ops": 152.238, + "_file_size": 146.037, + "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""", + }, + ) + DEFAULT = COCO_V1 + + +@register_model() +@handle_legacy_interface( + weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) +def retinanet_resnet50_fpn( + *, + weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> RetinaNet: + """ + Constructs a RetinaNet model with a ResNet-50-FPN backbone. + + .. betastatus:: detection module + + Reference: `Focal Loss for Dense Object Detection `_. + + The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each + image, and should be in ``0-1`` range. Different images can have different sizes. + + The behavior of the model changes depending on if it is in training or evaluation mode. + + During training, the model expects both the input tensors and targets (list of dictionary), + containing: + + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (``Int64Tensor[N]``): the class label for each ground-truth box + + The model returns a ``Dict[Tensor]`` during training, containing the classification and regression + losses. + + During inference, the model requires only the input tensors, and returns the post-processed + predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as + follows, where ``N`` is the number of detections: + + - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (``Int64Tensor[N]``): the predicted labels for each detection + - scores (``Tensor[N]``): the scores of each detection + + For more details on the output, you may refer to :ref:`instance_seg_output`. + + Example:: + + >>> model = torchvision.models.detection.retinanet_resnet50_fpn(weights=RetinaNet_ResNet50_FPN_Weights.DEFAULT) + >>> model.eval() + >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> predictions = model(x) + + Args: + weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights` + below for more details, and possible values. By default, no + pre-trained weights are used. + progress (bool): If True, displays a progress bar of the download to stderr. Default is True. 
+ num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for + the backbone. + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. + Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is + passed (the default) this value is set to 3. + **kwargs: parameters passed to the ``torchvision.models.detection.RetinaNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights + :members: + """ + weights = RetinaNet_ResNet50_FPN_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + is_trained = weights is not None or weights_backbone is not None + trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) + norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d + + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) + # skip P2 because it generates too many anchors (according to their paper) + backbone = _resnet_fpn_extractor( + backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) + ) + model = RetinaNet(backbone, num_classes, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1: + overwrite_eps(model, 0.0) + + return model + + +@register_model() +@handle_legacy_interface( + weights=("pretrained", RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) +def retinanet_resnet50_fpn_v2( + *, + weights: Optional[RetinaNet_ResNet50_FPN_V2_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = None, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> RetinaNet: + """ + Constructs an improved RetinaNet model with a ResNet-50-FPN backbone. + + .. betastatus:: detection module + + Reference: `Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection + `_. + + :func:`~torchvision.models.detection.retinanet_resnet50_fpn` for more details. + + Args: + weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights` + below for more details, and possible values. By default, no + pre-trained weights are used. + progress (bool): If True, displays a progress bar of the download to stderr. Default is True. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for + the backbone. + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. + Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. 
If ``None`` is + passed (the default) this value is set to 3. + **kwargs: parameters passed to the ``torchvision.models.detection.RetinaNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights + :members: + """ + weights = RetinaNet_ResNet50_FPN_V2_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + is_trained = weights is not None or weights_backbone is not None + trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) + + backbone = resnet50(weights=weights_backbone, progress=progress) + backbone = _resnet_fpn_extractor( + backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(2048, 256) + ) + anchor_generator = _default_anchorgen() + head = RetinaNetHead( + backbone.out_channels, + anchor_generator.num_anchors_per_location()[0], + num_classes, + norm_layer=partial(nn.GroupNorm, 32), + ) + head.regression_head._loss_type = "giou" + model = RetinaNet(backbone, num_classes, anchor_generator=anchor_generator, head=head, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model diff --git a/lib/python3.10/site-packages/torchvision/models/detection/roi_heads.py b/lib/python3.10/site-packages/torchvision/models/detection/roi_heads.py new file mode 100644 index 0000000000000000000000000000000000000000..51b210cb6f368c1f4914ffe99287efef6057cba4 --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/detection/roi_heads.py @@ -0,0 +1,876 @@ +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +import torchvision +from torch import nn, Tensor +from torchvision.ops import boxes as box_ops, roi_align + +from . import _utils as det_utils + + +def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): + # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] + """ + Computes the loss for Faster R-CNN. 
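+    The classification term is a cross-entropy over all sampled proposals, while the box term is a
+    smooth L1 loss computed on the foreground proposals only and normalized by the number of sampled labels.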
+
+    Args:
+        class_logits (Tensor)
+        box_regression (Tensor)
+        labels (list[BoxList])
+        regression_targets (Tensor)
+
+    Returns:
+        classification_loss (Tensor)
+        box_loss (Tensor)
+    """
+
+    labels = torch.cat(labels, dim=0)
+    regression_targets = torch.cat(regression_targets, dim=0)
+
+    classification_loss = F.cross_entropy(class_logits, labels)
+
+    # get indices that correspond to the regression targets for
+    # the corresponding ground truth labels, to be used with
+    # advanced indexing
+    sampled_pos_inds_subset = torch.where(labels > 0)[0]
+    labels_pos = labels[sampled_pos_inds_subset]
+    N, num_classes = class_logits.shape
+    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
+
+    box_loss = F.smooth_l1_loss(
+        box_regression[sampled_pos_inds_subset, labels_pos],
+        regression_targets[sampled_pos_inds_subset],
+        beta=1 / 9,
+        reduction="sum",
+    )
+    box_loss = box_loss / labels.numel()
+
+    return classification_loss, box_loss
+
+
+def maskrcnn_inference(x, labels):
+    # type: (Tensor, List[Tensor]) -> List[Tensor]
+    """
+    From the results of the CNN, post process the masks
+    by taking the mask corresponding to the class with max
+    probability (which are of fixed size and directly output
+    by the CNN) and return the masks in the mask field of the BoxList.
+
+    Args:
+        x (Tensor): the mask logits
+        labels (list[BoxList]): bounding boxes that are used as
+            reference, one for each image
+
+    Returns:
+        results (list[BoxList]): one BoxList for each image, containing
+            the extra field mask
+    """
+    mask_prob = x.sigmoid()
+
+    # select masks corresponding to the predicted classes
+    num_masks = x.shape[0]
+    boxes_per_image = [label.shape[0] for label in labels]
+    labels = torch.cat(labels)
+    index = torch.arange(num_masks, device=labels.device)
+    mask_prob = mask_prob[index, labels][:, None]
+    mask_prob = mask_prob.split(boxes_per_image, dim=0)
+
+    return mask_prob
+
+
+def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
+    # type: (Tensor, Tensor, Tensor, int) -> Tensor
+    """
+    Given segmentation masks and the bounding boxes corresponding
+    to the location of the masks in the image, this function
+    crops and resizes the masks in the position defined by the
+    boxes. This prepares the masks for them to be fed to the
+    loss computation as the targets.
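+    The result is a tensor of shape ``[num_boxes, M, M]``, one ``M x M`` mask target per proposal.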
+ """ + matched_idxs = matched_idxs.to(boxes) + rois = torch.cat([matched_idxs[:, None], boxes], dim=1) + gt_masks = gt_masks[:, None].to(rois) + return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0] + + +def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs): + # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor + """ + Args: + proposals (list[BoxList]) + mask_logits (Tensor) + targets (list[BoxList]) + + Return: + mask_loss (Tensor): scalar tensor containing the loss + """ + + discretization_size = mask_logits.shape[-1] + labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)] + mask_targets = [ + project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs) + ] + + labels = torch.cat(labels, dim=0) + mask_targets = torch.cat(mask_targets, dim=0) + + # torch.mean (in binary_cross_entropy_with_logits) doesn't + # accept empty tensors, so handle it separately + if mask_targets.numel() == 0: + return mask_logits.sum() * 0 + + mask_loss = F.binary_cross_entropy_with_logits( + mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets + ) + return mask_loss + + +def keypoints_to_heatmap(keypoints, rois, heatmap_size): + # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor] + offset_x = rois[:, 0] + offset_y = rois[:, 1] + scale_x = heatmap_size / (rois[:, 2] - rois[:, 0]) + scale_y = heatmap_size / (rois[:, 3] - rois[:, 1]) + + offset_x = offset_x[:, None] + offset_y = offset_y[:, None] + scale_x = scale_x[:, None] + scale_y = scale_y[:, None] + + x = keypoints[..., 0] + y = keypoints[..., 1] + + x_boundary_inds = x == rois[:, 2][:, None] + y_boundary_inds = y == rois[:, 3][:, None] + + x = (x - offset_x) * scale_x + x = x.floor().long() + y = (y - offset_y) * scale_y + y = y.floor().long() + + x[x_boundary_inds] = heatmap_size - 1 + y[y_boundary_inds] = heatmap_size - 1 + + valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size) + vis = keypoints[..., 2] > 0 + valid = (valid_loc & vis).long() + + lin_ind = y * heatmap_size + x + heatmaps = lin_ind * valid + + return heatmaps, valid + + +def _onnx_heatmaps_to_keypoints( + maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i +): + num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64) + + width_correction = widths_i / roi_map_width + height_correction = heights_i / roi_map_height + + roi_map = F.interpolate( + maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False + )[:, 0] + + w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64) + pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1) + + x_int = pos % w + y_int = (pos - x_int) // w + + x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to( + dtype=torch.float32 + ) + y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to( + dtype=torch.float32 + ) + + xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32) + xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32) + xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32) + xy_preds_i = torch.stack( + [ + xy_preds_i_0.to(dtype=torch.float32), + xy_preds_i_1.to(dtype=torch.float32), + xy_preds_i_2.to(dtype=torch.float32), + ], + 0, + ) + + # TODO: simplify when indexing without rank will be supported by ONNX + base = num_keypoints * num_keypoints + num_keypoints + 1 + ind 
= torch.arange(num_keypoints) + ind = ind.to(dtype=torch.int64) * base + end_scores_i = ( + roi_map.index_select(1, y_int.to(dtype=torch.int64)) + .index_select(2, x_int.to(dtype=torch.int64)) + .view(-1) + .index_select(0, ind.to(dtype=torch.int64)) + ) + + return xy_preds_i, end_scores_i + + +@torch.jit._script_if_tracing +def _onnx_heatmaps_to_keypoints_loop( + maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints +): + xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device) + end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device) + + for i in range(int(rois.size(0))): + xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints( + maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i] + ) + xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0) + end_scores = torch.cat( + (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0 + ) + return xy_preds, end_scores + + +def heatmaps_to_keypoints(maps, rois): + """Extract predicted keypoint locations from heatmaps. Output has shape + (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob) + for each keypoint. + """ + # This function converts a discrete image coordinate in a HEATMAP_SIZE x + # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain + # consistency with keypoints_to_heatmap_labels by using the conversion from + # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a + # continuous coordinate. + offset_x = rois[:, 0] + offset_y = rois[:, 1] + + widths = rois[:, 2] - rois[:, 0] + heights = rois[:, 3] - rois[:, 1] + widths = widths.clamp(min=1) + heights = heights.clamp(min=1) + widths_ceil = widths.ceil() + heights_ceil = heights.ceil() + + num_keypoints = maps.shape[1] + + if torchvision._is_tracing(): + xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop( + maps, + rois, + widths_ceil, + heights_ceil, + widths, + heights, + offset_x, + offset_y, + torch.scalar_tensor(num_keypoints, dtype=torch.int64), + ) + return xy_preds.permute(0, 2, 1), end_scores + + xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device) + end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device) + for i in range(len(rois)): + roi_map_width = int(widths_ceil[i].item()) + roi_map_height = int(heights_ceil[i].item()) + width_correction = widths[i] / roi_map_width + height_correction = heights[i] / roi_map_height + roi_map = F.interpolate( + maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False + )[:, 0] + # roi_map_probs = scores_to_probs(roi_map.copy()) + w = roi_map.shape[2] + pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1) + + x_int = pos % w + y_int = torch.div(pos - x_int, w, rounding_mode="floor") + # assert (roi_map_probs[k, y_int, x_int] == + # roi_map_probs[k, :, :].max()) + x = (x_int.float() + 0.5) * width_correction + y = (y_int.float() + 0.5) * height_correction + xy_preds[i, 0, :] = x + offset_x[i] + xy_preds[i, 1, :] = y + offset_y[i] + xy_preds[i, 2, :] = 1 + end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int] + + return xy_preds.permute(0, 2, 1), end_scores + + +def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs): + # type: (Tensor, List[Tensor], List[Tensor], 
List[Tensor]) -> Tensor + N, K, H, W = keypoint_logits.shape + if H != W: + raise ValueError( + f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}" + ) + discretization_size = H + heatmaps = [] + valid = [] + for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs): + kp = gt_kp_in_image[midx] + heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size) + heatmaps.append(heatmaps_per_image.view(-1)) + valid.append(valid_per_image.view(-1)) + + keypoint_targets = torch.cat(heatmaps, dim=0) + valid = torch.cat(valid, dim=0).to(dtype=torch.uint8) + valid = torch.where(valid)[0] + + # torch.mean (in binary_cross_entropy_with_logits) doesn't + # accept empty tensors, so handle it sepaartely + if keypoint_targets.numel() == 0 or len(valid) == 0: + return keypoint_logits.sum() * 0 + + keypoint_logits = keypoint_logits.view(N * K, H * W) + + keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid]) + return keypoint_loss + + +def keypointrcnn_inference(x, boxes): + # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]] + kp_probs = [] + kp_scores = [] + + boxes_per_image = [box.size(0) for box in boxes] + x2 = x.split(boxes_per_image, dim=0) + + for xx, bb in zip(x2, boxes): + kp_prob, scores = heatmaps_to_keypoints(xx, bb) + kp_probs.append(kp_prob) + kp_scores.append(scores) + + return kp_probs, kp_scores + + +def _onnx_expand_boxes(boxes, scale): + # type: (Tensor, float) -> Tensor + w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5 + h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5 + x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5 + y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5 + + w_half = w_half.to(dtype=torch.float32) * scale + h_half = h_half.to(dtype=torch.float32) * scale + + boxes_exp0 = x_c - w_half + boxes_exp1 = y_c - h_half + boxes_exp2 = x_c + w_half + boxes_exp3 = y_c + h_half + boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1) + return boxes_exp + + +# the next two functions should be merged inside Masker +# but are kept here for the moment while we need them +# temporarily for paste_mask_in_image +def expand_boxes(boxes, scale): + # type: (Tensor, float) -> Tensor + if torchvision._is_tracing(): + return _onnx_expand_boxes(boxes, scale) + w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5 + h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5 + x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5 + y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5 + + w_half *= scale + h_half *= scale + + boxes_exp = torch.zeros_like(boxes) + boxes_exp[:, 0] = x_c - w_half + boxes_exp[:, 2] = x_c + w_half + boxes_exp[:, 1] = y_c - h_half + boxes_exp[:, 3] = y_c + h_half + return boxes_exp + + +@torch.jit.unused +def expand_masks_tracing_scale(M, padding): + # type: (int, int) -> float + return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32) + + +def expand_masks(mask, padding): + # type: (Tensor, int) -> Tuple[Tensor, float] + M = mask.shape[-1] + if torch._C._get_tracing_state(): # could not import is_tracing(), not sure why + scale = expand_masks_tracing_scale(M, padding) + else: + scale = float(M + 2 * padding) / M + padded_mask = F.pad(mask, (padding,) * 4) + return padded_mask, scale + + +def paste_mask_in_image(mask, box, im_h, im_w): + # type: (Tensor, Tensor, int, int) -> Tensor + TO_REMOVE = 1 + w = int(box[2] - box[0] + TO_REMOVE) + h = int(box[3] - box[1] + TO_REMOVE) + w = max(w, 1) + h = max(h, 1) + + # 
Set shape to [batchxCxHxW] + mask = mask.expand((1, 1, -1, -1)) + + # Resize mask + mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False) + mask = mask[0][0] + + im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device) + x_0 = max(box[0], 0) + x_1 = min(box[2] + 1, im_w) + y_0 = max(box[1], 0) + y_1 = min(box[3] + 1, im_h) + + im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])] + return im_mask + + +def _onnx_paste_mask_in_image(mask, box, im_h, im_w): + one = torch.ones(1, dtype=torch.int64) + zero = torch.zeros(1, dtype=torch.int64) + + w = box[2] - box[0] + one + h = box[3] - box[1] + one + w = torch.max(torch.cat((w, one))) + h = torch.max(torch.cat((h, one))) + + # Set shape to [batchxCxHxW] + mask = mask.expand((1, 1, mask.size(0), mask.size(1))) + + # Resize mask + mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False) + mask = mask[0][0] + + x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero))) + x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0)))) + y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero))) + y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0)))) + + unpaded_im_mask = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])] + + # TODO : replace below with a dynamic padding when support is added in ONNX + + # pad y + zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1)) + zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1)) + concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :] + # pad x + zeros_x0 = torch.zeros(concat_0.size(0), x_0) + zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1) + im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w] + return im_mask + + +@torch.jit._script_if_tracing +def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w): + res_append = torch.zeros(0, im_h, im_w) + for i in range(masks.size(0)): + mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w) + mask_res = mask_res.unsqueeze(0) + res_append = torch.cat((res_append, mask_res)) + return res_append + + +def paste_masks_in_image(masks, boxes, img_shape, padding=1): + # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor + masks, scale = expand_masks(masks, padding=padding) + boxes = expand_boxes(boxes, scale).to(dtype=torch.int64) + im_h, im_w = img_shape + + if torchvision._is_tracing(): + return _onnx_paste_masks_in_image_loop( + masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64) + )[:, None] + res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)] + if len(res) > 0: + ret = torch.stack(res, dim=0)[:, None] + else: + ret = masks.new_empty((0, 1, im_h, im_w)) + return ret + + +class RoIHeads(nn.Module): + __annotations__ = { + "box_coder": det_utils.BoxCoder, + "proposal_matcher": det_utils.Matcher, + "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler, + } + + def __init__( + self, + box_roi_pool, + box_head, + box_predictor, + # Faster R-CNN training + fg_iou_thresh, + bg_iou_thresh, + batch_size_per_image, + positive_fraction, + bbox_reg_weights, + # Faster R-CNN inference + score_thresh, + nms_thresh, + detections_per_img, + # Mask + mask_roi_pool=None, + mask_head=None, + mask_predictor=None, + keypoint_roi_pool=None, + keypoint_head=None, + keypoint_predictor=None, + ): + super().__init__() + + self.box_similarity = box_ops.box_iou 
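+        # box_iou computes the pairwise IoU matrix between two sets of boxes; proposals are
+        # matched to ground-truth boxes by IoU in assign_targets_to_proposals below.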
+ # assign ground-truth boxes for each proposal + self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False) + + self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction) + + if bbox_reg_weights is None: + bbox_reg_weights = (10.0, 10.0, 5.0, 5.0) + self.box_coder = det_utils.BoxCoder(bbox_reg_weights) + + self.box_roi_pool = box_roi_pool + self.box_head = box_head + self.box_predictor = box_predictor + + self.score_thresh = score_thresh + self.nms_thresh = nms_thresh + self.detections_per_img = detections_per_img + + self.mask_roi_pool = mask_roi_pool + self.mask_head = mask_head + self.mask_predictor = mask_predictor + + self.keypoint_roi_pool = keypoint_roi_pool + self.keypoint_head = keypoint_head + self.keypoint_predictor = keypoint_predictor + + def has_mask(self): + if self.mask_roi_pool is None: + return False + if self.mask_head is None: + return False + if self.mask_predictor is None: + return False + return True + + def has_keypoint(self): + if self.keypoint_roi_pool is None: + return False + if self.keypoint_head is None: + return False + if self.keypoint_predictor is None: + return False + return True + + def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels): + # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]] + matched_idxs = [] + labels = [] + for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels): + + if gt_boxes_in_image.numel() == 0: + # Background image + device = proposals_in_image.device + clamped_matched_idxs_in_image = torch.zeros( + (proposals_in_image.shape[0],), dtype=torch.int64, device=device + ) + labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device) + else: + # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands + match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image) + matched_idxs_in_image = self.proposal_matcher(match_quality_matrix) + + clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0) + + labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image] + labels_in_image = labels_in_image.to(dtype=torch.int64) + + # Label background (below the low threshold) + bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD + labels_in_image[bg_inds] = 0 + + # Label ignore proposals (between low and high thresholds) + ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS + labels_in_image[ignore_inds] = -1 # -1 is ignored by sampler + + matched_idxs.append(clamped_matched_idxs_in_image) + labels.append(labels_in_image) + return matched_idxs, labels + + def subsample(self, labels): + # type: (List[Tensor]) -> List[Tensor] + sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels) + sampled_inds = [] + for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)): + img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0] + sampled_inds.append(img_sampled_inds) + return sampled_inds + + def add_gt_proposals(self, proposals, gt_boxes): + # type: (List[Tensor], List[Tensor]) -> List[Tensor] + proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)] + + return proposals + + def check_targets(self, targets): + # type: (Optional[List[Dict[str, Tensor]]]) -> None + if targets is None: + raise ValueError("targets should not be None") + if 
not all(["boxes" in t for t in targets]):
+            raise ValueError("Every element of targets should have a boxes key")
+        if not all(["labels" in t for t in targets]):
+            raise ValueError("Every element of targets should have a labels key")
+        if self.has_mask():
+            if not all(["masks" in t for t in targets]):
+                raise ValueError("Every element of targets should have a masks key")
+
+    def select_training_samples(
+        self,
+        proposals,  # type: List[Tensor]
+        targets,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
+        self.check_targets(targets)
+        if targets is None:
+            raise ValueError("targets should not be None")
+        dtype = proposals[0].dtype
+        device = proposals[0].device
+
+        gt_boxes = [t["boxes"].to(dtype) for t in targets]
+        gt_labels = [t["labels"] for t in targets]
+
+        # append ground-truth bboxes to proposals
+        proposals = self.add_gt_proposals(proposals, gt_boxes)
+
+        # get matching gt indices for each proposal
+        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
+        # sample a fixed proportion of positive-negative proposals
+        sampled_inds = self.subsample(labels)
+        matched_gt_boxes = []
+        num_images = len(proposals)
+        for img_id in range(num_images):
+            img_sampled_inds = sampled_inds[img_id]
+            proposals[img_id] = proposals[img_id][img_sampled_inds]
+            labels[img_id] = labels[img_id][img_sampled_inds]
+            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
+
+            gt_boxes_in_image = gt_boxes[img_id]
+            if gt_boxes_in_image.numel() == 0:
+                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
+            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
+
+        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
+        return proposals, matched_idxs, labels, regression_targets
+
+    def postprocess_detections(
+        self,
+        class_logits,  # type: Tensor
+        box_regression,  # type: Tensor
+        proposals,  # type: List[Tensor]
+        image_shapes,  # type: List[Tuple[int, int]]
+    ):
+        # type: (...)
-> Tuple[List[Tensor], List[Tensor], List[Tensor]] + device = class_logits.device + num_classes = class_logits.shape[-1] + + boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals] + pred_boxes = self.box_coder.decode(box_regression, proposals) + + pred_scores = F.softmax(class_logits, -1) + + pred_boxes_list = pred_boxes.split(boxes_per_image, 0) + pred_scores_list = pred_scores.split(boxes_per_image, 0) + + all_boxes = [] + all_scores = [] + all_labels = [] + for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes): + boxes = box_ops.clip_boxes_to_image(boxes, image_shape) + + # create labels for each prediction + labels = torch.arange(num_classes, device=device) + labels = labels.view(1, -1).expand_as(scores) + + # remove predictions with the background label + boxes = boxes[:, 1:] + scores = scores[:, 1:] + labels = labels[:, 1:] + + # batch everything, by making every class prediction be a separate instance + boxes = boxes.reshape(-1, 4) + scores = scores.reshape(-1) + labels = labels.reshape(-1) + + # remove low scoring boxes + inds = torch.where(scores > self.score_thresh)[0] + boxes, scores, labels = boxes[inds], scores[inds], labels[inds] + + # remove empty boxes + keep = box_ops.remove_small_boxes(boxes, min_size=1e-2) + boxes, scores, labels = boxes[keep], scores[keep], labels[keep] + + # non-maximum suppression, independently done per class + keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh) + # keep only topk scoring predictions + keep = keep[: self.detections_per_img] + boxes, scores, labels = boxes[keep], scores[keep], labels[keep] + + all_boxes.append(boxes) + all_scores.append(scores) + all_labels.append(labels) + + return all_boxes, all_scores, all_labels + + def forward( + self, + features, # type: Dict[str, Tensor] + proposals, # type: List[Tensor] + image_shapes, # type: List[Tuple[int, int]] + targets=None, # type: Optional[List[Dict[str, Tensor]]] + ): + # type: (...) 
-> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]] + """ + Args: + features (List[Tensor]) + proposals (List[Tensor[N, 4]]) + image_shapes (List[Tuple[H, W]]) + targets (List[Dict]) + """ + if targets is not None: + for t in targets: + # TODO: https://github.com/pytorch/pytorch/issues/26731 + floating_point_types = (torch.float, torch.double, torch.half) + if not t["boxes"].dtype in floating_point_types: + raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}") + if not t["labels"].dtype == torch.int64: + raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}") + if self.has_keypoint(): + if not t["keypoints"].dtype == torch.float32: + raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}") + + if self.training: + proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets) + else: + labels = None + regression_targets = None + matched_idxs = None + + box_features = self.box_roi_pool(features, proposals, image_shapes) + box_features = self.box_head(box_features) + class_logits, box_regression = self.box_predictor(box_features) + + result: List[Dict[str, torch.Tensor]] = [] + losses = {} + if self.training: + if labels is None: + raise ValueError("labels cannot be None") + if regression_targets is None: + raise ValueError("regression_targets cannot be None") + loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets) + losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg} + else: + boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes) + num_images = len(boxes) + for i in range(num_images): + result.append( + { + "boxes": boxes[i], + "labels": labels[i], + "scores": scores[i], + } + ) + + if self.has_mask(): + mask_proposals = [p["boxes"] for p in result] + if self.training: + if matched_idxs is None: + raise ValueError("if in training, matched_idxs should not be None") + + # during training, only focus on positive boxes + num_images = len(proposals) + mask_proposals = [] + pos_matched_idxs = [] + for img_id in range(num_images): + pos = torch.where(labels[img_id] > 0)[0] + mask_proposals.append(proposals[img_id][pos]) + pos_matched_idxs.append(matched_idxs[img_id][pos]) + else: + pos_matched_idxs = None + + if self.mask_roi_pool is not None: + mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes) + mask_features = self.mask_head(mask_features) + mask_logits = self.mask_predictor(mask_features) + else: + raise Exception("Expected mask_roi_pool to be not None") + + loss_mask = {} + if self.training: + if targets is None or pos_matched_idxs is None or mask_logits is None: + raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training") + + gt_masks = [t["masks"] for t in targets] + gt_labels = [t["labels"] for t in targets] + rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs) + loss_mask = {"loss_mask": rcnn_loss_mask} + else: + labels = [r["labels"] for r in result] + masks_probs = maskrcnn_inference(mask_logits, labels) + for mask_prob, r in zip(masks_probs, result): + r["masks"] = mask_prob + + losses.update(loss_mask) + + # keep none checks in if conditional so torchscript will conditionally + # compile each branch + if ( + self.keypoint_roi_pool is not None + and self.keypoint_head is not None + and self.keypoint_predictor is not None + ): + 
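+            # Editorial note (illustrative): this branch mirrors the mask branch above.
+            # Roughly, for the selected per-image boxes it does
+            #     feats  = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
+            #     feats  = self.keypoint_head(feats)
+            #     logits = self.keypoint_predictor(feats)   # per-box keypoint heatmaps
+            # and then either computes keypointrcnn_loss(...) in training mode or decodes
+            # keypoints with keypointrcnn_inference(...) at inference time, as written below.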
keypoint_proposals = [p["boxes"] for p in result] + if self.training: + # during training, only focus on positive boxes + num_images = len(proposals) + keypoint_proposals = [] + pos_matched_idxs = [] + if matched_idxs is None: + raise ValueError("if in trainning, matched_idxs should not be None") + + for img_id in range(num_images): + pos = torch.where(labels[img_id] > 0)[0] + keypoint_proposals.append(proposals[img_id][pos]) + pos_matched_idxs.append(matched_idxs[img_id][pos]) + else: + pos_matched_idxs = None + + keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes) + keypoint_features = self.keypoint_head(keypoint_features) + keypoint_logits = self.keypoint_predictor(keypoint_features) + + loss_keypoint = {} + if self.training: + if targets is None or pos_matched_idxs is None: + raise ValueError("both targets and pos_matched_idxs should not be None when in training mode") + + gt_keypoints = [t["keypoints"] for t in targets] + rcnn_loss_keypoint = keypointrcnn_loss( + keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs + ) + loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint} + else: + if keypoint_logits is None or keypoint_proposals is None: + raise ValueError( + "both keypoint_logits and keypoint_proposals should not be None when not in training mode" + ) + + keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals) + for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result): + r["keypoints"] = keypoint_prob + r["keypoints_scores"] = kps + losses.update(loss_keypoint) + + return result, losses diff --git a/lib/python3.10/site-packages/torchvision/models/detection/rpn.py b/lib/python3.10/site-packages/torchvision/models/detection/rpn.py new file mode 100644 index 0000000000000000000000000000000000000000..f103181e4c6cba48c1a3b4c97583c5fb6785a8c4 --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/detection/rpn.py @@ -0,0 +1,388 @@ +from typing import Dict, List, Optional, Tuple + +import torch +from torch import nn, Tensor +from torch.nn import functional as F +from torchvision.ops import boxes as box_ops, Conv2dNormActivation + +from . import _utils as det_utils + +# Import AnchorGenerator to keep compatibility. 
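+# Editorial note (illustrative): older user code frequently imported the class from here,
+#     from torchvision.models.detection.rpn import AnchorGenerator
+# so it is re-exported below even though it now lives in anchor_utils.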
+from .anchor_utils import AnchorGenerator # noqa: 401 +from .image_list import ImageList + + +class RPNHead(nn.Module): + """ + Adds a simple RPN Head with classification and regression heads + + Args: + in_channels (int): number of channels of the input feature + num_anchors (int): number of anchors to be predicted + conv_depth (int, optional): number of convolutions + """ + + _version = 2 + + def __init__(self, in_channels: int, num_anchors: int, conv_depth=1) -> None: + super().__init__() + convs = [] + for _ in range(conv_depth): + convs.append(Conv2dNormActivation(in_channels, in_channels, kernel_size=3, norm_layer=None)) + self.conv = nn.Sequential(*convs) + self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1) + self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1) + + for layer in self.modules(): + if isinstance(layer, nn.Conv2d): + torch.nn.init.normal_(layer.weight, std=0.01) # type: ignore[arg-type] + if layer.bias is not None: + torch.nn.init.constant_(layer.bias, 0) # type: ignore[arg-type] + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + + if version is None or version < 2: + for type in ["weight", "bias"]: + old_key = f"{prefix}conv.{type}" + new_key = f"{prefix}conv.0.0.{type}" + if old_key in state_dict: + state_dict[new_key] = state_dict.pop(old_key) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def forward(self, x: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]: + logits = [] + bbox_reg = [] + for feature in x: + t = self.conv(feature) + logits.append(self.cls_logits(t)) + bbox_reg.append(self.bbox_pred(t)) + return logits, bbox_reg + + +def permute_and_flatten(layer: Tensor, N: int, A: int, C: int, H: int, W: int) -> Tensor: + layer = layer.view(N, -1, C, H, W) + layer = layer.permute(0, 3, 4, 1, 2) + layer = layer.reshape(N, -1, C) + return layer + + +def concat_box_prediction_layers(box_cls: List[Tensor], box_regression: List[Tensor]) -> Tuple[Tensor, Tensor]: + box_cls_flattened = [] + box_regression_flattened = [] + # for each feature level, permute the outputs to make them be in the + # same format as the labels. Note that the labels are computed for + # all feature levels concatenated, so we keep the same representation + # for the objectness and the box_regression + for box_cls_per_level, box_regression_per_level in zip(box_cls, box_regression): + N, AxC, H, W = box_cls_per_level.shape + Ax4 = box_regression_per_level.shape[1] + A = Ax4 // 4 + C = AxC // A + box_cls_per_level = permute_and_flatten(box_cls_per_level, N, A, C, H, W) + box_cls_flattened.append(box_cls_per_level) + + box_regression_per_level = permute_and_flatten(box_regression_per_level, N, A, 4, H, W) + box_regression_flattened.append(box_regression_per_level) + # concatenate on the first dimension (representing the feature levels), to + # take into account the way the labels were generated (with all feature maps + # being concatenated as well) + box_cls = torch.cat(box_cls_flattened, dim=1).flatten(0, -2) + box_regression = torch.cat(box_regression_flattened, dim=1).reshape(-1, 4) + return box_cls, box_regression + + +class RegionProposalNetwork(torch.nn.Module): + """ + Implements Region Proposal Network (RPN). 
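+
+    A minimal construction sketch (illustrative only; the anchor sizes, head channels and
+    thresholds below are placeholder values borrowed from common Faster R-CNN configurations,
+    not a recommendation):
+
+        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128),), aspect_ratios=((0.5, 1.0, 2.0),))
+        >>> head = RPNHead(in_channels=256, num_anchors=anchor_generator.num_anchors_per_location()[0])
+        >>> rpn = RegionProposalNetwork(
+        ...     anchor_generator, head,
+        ...     fg_iou_thresh=0.7, bg_iou_thresh=0.3,
+        ...     batch_size_per_image=256, positive_fraction=0.5,
+        ...     pre_nms_top_n={"training": 2000, "testing": 1000},
+        ...     post_nms_top_n={"training": 2000, "testing": 1000},
+        ...     nms_thresh=0.7,
+        ... )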
+ + Args: + anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature + maps. + head (nn.Module): module that computes the objectness and regression deltas + fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be + considered as positive during training of the RPN. + bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be + considered as negative during training of the RPN. + batch_size_per_image (int): number of anchors that are sampled during training of the RPN + for computing the loss + positive_fraction (float): proportion of positive anchors in a mini-batch during training + of the RPN + pre_nms_top_n (Dict[str, int]): number of proposals to keep before applying NMS. It should + contain two fields: training and testing, to allow for different values depending + on training or evaluation + post_nms_top_n (Dict[str, int]): number of proposals to keep after applying NMS. It should + contain two fields: training and testing, to allow for different values depending + on training or evaluation + nms_thresh (float): NMS threshold used for postprocessing the RPN proposals + score_thresh (float): only return proposals with an objectness score greater than score_thresh + + """ + + __annotations__ = { + "box_coder": det_utils.BoxCoder, + "proposal_matcher": det_utils.Matcher, + "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler, + } + + def __init__( + self, + anchor_generator: AnchorGenerator, + head: nn.Module, + # Faster-RCNN Training + fg_iou_thresh: float, + bg_iou_thresh: float, + batch_size_per_image: int, + positive_fraction: float, + # Faster-RCNN Inference + pre_nms_top_n: Dict[str, int], + post_nms_top_n: Dict[str, int], + nms_thresh: float, + score_thresh: float = 0.0, + ) -> None: + super().__init__() + self.anchor_generator = anchor_generator + self.head = head + self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) + + # used during training + self.box_similarity = box_ops.box_iou + + self.proposal_matcher = det_utils.Matcher( + fg_iou_thresh, + bg_iou_thresh, + allow_low_quality_matches=True, + ) + + self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction) + # used during testing + self._pre_nms_top_n = pre_nms_top_n + self._post_nms_top_n = post_nms_top_n + self.nms_thresh = nms_thresh + self.score_thresh = score_thresh + self.min_size = 1e-3 + + def pre_nms_top_n(self) -> int: + if self.training: + return self._pre_nms_top_n["training"] + return self._pre_nms_top_n["testing"] + + def post_nms_top_n(self) -> int: + if self.training: + return self._post_nms_top_n["training"] + return self._post_nms_top_n["testing"] + + def assign_targets_to_anchors( + self, anchors: List[Tensor], targets: List[Dict[str, Tensor]] + ) -> Tuple[List[Tensor], List[Tensor]]: + + labels = [] + matched_gt_boxes = [] + for anchors_per_image, targets_per_image in zip(anchors, targets): + gt_boxes = targets_per_image["boxes"] + + if gt_boxes.numel() == 0: + # Background image (negative example) + device = anchors_per_image.device + matched_gt_boxes_per_image = torch.zeros(anchors_per_image.shape, dtype=torch.float32, device=device) + labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device) + else: + match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image) + matched_idxs = self.proposal_matcher(match_quality_matrix) + # get the targets corresponding GT for each proposal + # NB: need to 
clamp the indices because we can have a single + # GT in the image, and matched_idxs can be -2, which goes + # out of bounds + matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)] + + labels_per_image = matched_idxs >= 0 + labels_per_image = labels_per_image.to(dtype=torch.float32) + + # Background (negative examples) + bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD + labels_per_image[bg_indices] = 0.0 + + # discard indices that are between thresholds + inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS + labels_per_image[inds_to_discard] = -1.0 + + labels.append(labels_per_image) + matched_gt_boxes.append(matched_gt_boxes_per_image) + return labels, matched_gt_boxes + + def _get_top_n_idx(self, objectness: Tensor, num_anchors_per_level: List[int]) -> Tensor: + r = [] + offset = 0 + for ob in objectness.split(num_anchors_per_level, 1): + num_anchors = ob.shape[1] + pre_nms_top_n = det_utils._topk_min(ob, self.pre_nms_top_n(), 1) + _, top_n_idx = ob.topk(pre_nms_top_n, dim=1) + r.append(top_n_idx + offset) + offset += num_anchors + return torch.cat(r, dim=1) + + def filter_proposals( + self, + proposals: Tensor, + objectness: Tensor, + image_shapes: List[Tuple[int, int]], + num_anchors_per_level: List[int], + ) -> Tuple[List[Tensor], List[Tensor]]: + + num_images = proposals.shape[0] + device = proposals.device + # do not backprop through objectness + objectness = objectness.detach() + objectness = objectness.reshape(num_images, -1) + + levels = [ + torch.full((n,), idx, dtype=torch.int64, device=device) for idx, n in enumerate(num_anchors_per_level) + ] + levels = torch.cat(levels, 0) + levels = levels.reshape(1, -1).expand_as(objectness) + + # select top_n boxes independently per level before applying nms + top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level) + + image_range = torch.arange(num_images, device=device) + batch_idx = image_range[:, None] + + objectness = objectness[batch_idx, top_n_idx] + levels = levels[batch_idx, top_n_idx] + proposals = proposals[batch_idx, top_n_idx] + + objectness_prob = torch.sigmoid(objectness) + + final_boxes = [] + final_scores = [] + for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_shapes): + boxes = box_ops.clip_boxes_to_image(boxes, img_shape) + + # remove small boxes + keep = box_ops.remove_small_boxes(boxes, self.min_size) + boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep] + + # remove low scoring boxes + # use >= for Backwards compatibility + keep = torch.where(scores >= self.score_thresh)[0] + boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep] + + # non-maximum suppression, independently done per level + keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh) + + # keep only topk scoring predictions + keep = keep[: self.post_nms_top_n()] + boxes, scores = boxes[keep], scores[keep] + + final_boxes.append(boxes) + final_scores.append(scores) + return final_boxes, final_scores + + def compute_loss( + self, objectness: Tensor, pred_bbox_deltas: Tensor, labels: List[Tensor], regression_targets: List[Tensor] + ) -> Tuple[Tensor, Tensor]: + """ + Args: + objectness (Tensor) + pred_bbox_deltas (Tensor) + labels (List[Tensor]) + regression_targets (List[Tensor]) + + Returns: + objectness_loss (Tensor) + box_loss (Tensor) + """ + + sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels) + sampled_pos_inds = torch.where(torch.cat(sampled_pos_inds, dim=0))[0] + sampled_neg_inds = 
torch.where(torch.cat(sampled_neg_inds, dim=0))[0] + + sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0) + + objectness = objectness.flatten() + + labels = torch.cat(labels, dim=0) + regression_targets = torch.cat(regression_targets, dim=0) + + box_loss = F.smooth_l1_loss( + pred_bbox_deltas[sampled_pos_inds], + regression_targets[sampled_pos_inds], + beta=1 / 9, + reduction="sum", + ) / (sampled_inds.numel()) + + objectness_loss = F.binary_cross_entropy_with_logits(objectness[sampled_inds], labels[sampled_inds]) + + return objectness_loss, box_loss + + def forward( + self, + images: ImageList, + features: Dict[str, Tensor], + targets: Optional[List[Dict[str, Tensor]]] = None, + ) -> Tuple[List[Tensor], Dict[str, Tensor]]: + + """ + Args: + images (ImageList): images for which we want to compute the predictions + features (Dict[str, Tensor]): features computed from the images that are + used for computing the predictions. Each tensor in the list + correspond to different feature levels + targets (List[Dict[str, Tensor]]): ground-truth boxes present in the image (optional). + If provided, each element in the dict should contain a field `boxes`, + with the locations of the ground-truth boxes. + + Returns: + boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per + image. + losses (Dict[str, Tensor]): the losses for the model during training. During + testing, it is an empty dict. + """ + # RPN uses all feature maps that are available + features = list(features.values()) + objectness, pred_bbox_deltas = self.head(features) + anchors = self.anchor_generator(images, features) + + num_images = len(anchors) + num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness] + num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors] + objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas) + # apply pred_bbox_deltas to anchors to obtain the decoded proposals + # note that we detach the deltas because Faster R-CNN do not backprop through + # the proposals + proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors) + proposals = proposals.view(num_images, -1, 4) + boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level) + + losses = {} + if self.training: + if targets is None: + raise ValueError("targets should not be None") + labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets) + regression_targets = self.box_coder.encode(matched_gt_boxes, anchors) + loss_objectness, loss_rpn_box_reg = self.compute_loss( + objectness, pred_bbox_deltas, labels, regression_targets + ) + losses = { + "loss_objectness": loss_objectness, + "loss_rpn_box_reg": loss_rpn_box_reg, + } + return boxes, losses diff --git a/lib/python3.10/site-packages/torchvision/models/detection/ssd.py b/lib/python3.10/site-packages/torchvision/models/detection/ssd.py new file mode 100644 index 0000000000000000000000000000000000000000..87062d2bc88a5bf17625e0530116aba22941c538 --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/detection/ssd.py @@ -0,0 +1,682 @@ +import warnings +from collections import OrderedDict +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from ...ops import boxes as box_ops +from ...transforms._presets import ObjectDetection +from ...utils import _log_api_usage_once +from .._api import register_model, Weights, WeightsEnum +from 
.._meta import _COCO_CATEGORIES +from .._utils import _ovewrite_value_param, handle_legacy_interface +from ..vgg import VGG, vgg16, VGG16_Weights +from . import _utils as det_utils +from .anchor_utils import DefaultBoxGenerator +from .backbone_utils import _validate_trainable_layers +from .transform import GeneralizedRCNNTransform + + +__all__ = [ + "SSD300_VGG16_Weights", + "ssd300_vgg16", +] + + +class SSD300_VGG16_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth", + transforms=ObjectDetection, + meta={ + "num_params": 35641826, + "categories": _COCO_CATEGORIES, + "min_size": (1, 1), + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16", + "_metrics": { + "COCO-val2017": { + "box_map": 25.1, + } + }, + "_ops": 34.858, + "_file_size": 135.988, + "_docs": """These weights were produced by following a similar training recipe as on the paper.""", + }, + ) + DEFAULT = COCO_V1 + + +def _xavier_init(conv: nn.Module): + for layer in conv.modules(): + if isinstance(layer, nn.Conv2d): + torch.nn.init.xavier_uniform_(layer.weight) + if layer.bias is not None: + torch.nn.init.constant_(layer.bias, 0.0) + + +class SSDHead(nn.Module): + def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int): + super().__init__() + self.classification_head = SSDClassificationHead(in_channels, num_anchors, num_classes) + self.regression_head = SSDRegressionHead(in_channels, num_anchors) + + def forward(self, x: List[Tensor]) -> Dict[str, Tensor]: + return { + "bbox_regression": self.regression_head(x), + "cls_logits": self.classification_head(x), + } + + +class SSDScoringHead(nn.Module): + def __init__(self, module_list: nn.ModuleList, num_columns: int): + super().__init__() + self.module_list = module_list + self.num_columns = num_columns + + def _get_result_from_module_list(self, x: Tensor, idx: int) -> Tensor: + """ + This is equivalent to self.module_list[idx](x), + but torchscript doesn't support this yet + """ + num_blocks = len(self.module_list) + if idx < 0: + idx += num_blocks + out = x + for i, module in enumerate(self.module_list): + if i == idx: + out = module(x) + return out + + def forward(self, x: List[Tensor]) -> Tensor: + all_results = [] + + for i, features in enumerate(x): + results = self._get_result_from_module_list(features, i) + + # Permute output from (N, A * K, H, W) to (N, HWA, K). 
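+            # Editorial illustration (hypothetical numbers): with N=2 images, A=4 anchors per
+            # location, K=self.num_columns=91 and a 10x10 feature map, the (2, 364, 10, 10)
+            # head output becomes (2, 4*10*10, 91) = (2, 400, 91), i.e. one row per anchor
+            # position in (H, W, A) order.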
+ N, _, H, W = results.shape + results = results.view(N, -1, self.num_columns, H, W) + results = results.permute(0, 3, 4, 1, 2) + results = results.reshape(N, -1, self.num_columns) # Size=(N, HWA, K) + + all_results.append(results) + + return torch.cat(all_results, dim=1) + + +class SSDClassificationHead(SSDScoringHead): + def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int): + cls_logits = nn.ModuleList() + for channels, anchors in zip(in_channels, num_anchors): + cls_logits.append(nn.Conv2d(channels, num_classes * anchors, kernel_size=3, padding=1)) + _xavier_init(cls_logits) + super().__init__(cls_logits, num_classes) + + +class SSDRegressionHead(SSDScoringHead): + def __init__(self, in_channels: List[int], num_anchors: List[int]): + bbox_reg = nn.ModuleList() + for channels, anchors in zip(in_channels, num_anchors): + bbox_reg.append(nn.Conv2d(channels, 4 * anchors, kernel_size=3, padding=1)) + _xavier_init(bbox_reg) + super().__init__(bbox_reg, 4) + + +class SSD(nn.Module): + """ + Implements SSD architecture from `"SSD: Single Shot MultiBox Detector" `_. + + The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each + image, and should be in 0-1 range. Different images can have different sizes, but they will be resized + to a fixed size before passing it to the backbone. + + The behavior of the model changes depending on if it is in training or evaluation mode. + + During training, the model expects both the input tensors and targets (list of dictionary), + containing: + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (Int64Tensor[N]): the class label for each ground-truth box + + The model returns a Dict[Tensor] during training, containing the classification and regression + losses. + + During inference, the model requires only the input tensors, and returns the post-processed + predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as + follows, where ``N`` is the number of detections: + + - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (Int64Tensor[N]): the predicted labels for each detection + - scores (Tensor[N]): the scores for each detection + + Args: + backbone (nn.Module): the network used to compute the features for the model. + It should contain an out_channels attribute with the list of the output channels of + each feature map. The backbone should return a single Tensor or an OrderedDict[Tensor]. + anchor_generator (DefaultBoxGenerator): module that generates the default boxes for a + set of feature maps. + size (Tuple[int, int]): the width and height to which images will be rescaled before feeding them + to the backbone. + num_classes (int): number of output classes of the model (including the background). + image_mean (Tuple[float, float, float]): mean values used for input normalization. + They are generally the mean values of the dataset on which the backbone has been trained + on + image_std (Tuple[float, float, float]): std values used for input normalization. + They are generally the std values of the dataset on which the backbone has been trained on + head (nn.Module, optional): Module run on top of the backbone features. Defaults to a module containing + a classification and regression module. 
+ score_thresh (float): Score threshold used for postprocessing the detections. + nms_thresh (float): NMS threshold used for postprocessing the detections. + detections_per_img (int): Number of best detections to keep after NMS. + iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be + considered as positive during training. + topk_candidates (int): Number of best detections to keep before NMS. + positive_fraction (float): a number between 0 and 1 which indicates the proportion of positive + proposals used during the training of the classification head. It is used to estimate the negative to + positive ratio. + """ + + __annotations__ = { + "box_coder": det_utils.BoxCoder, + "proposal_matcher": det_utils.Matcher, + } + + def __init__( + self, + backbone: nn.Module, + anchor_generator: DefaultBoxGenerator, + size: Tuple[int, int], + num_classes: int, + image_mean: Optional[List[float]] = None, + image_std: Optional[List[float]] = None, + head: Optional[nn.Module] = None, + score_thresh: float = 0.01, + nms_thresh: float = 0.45, + detections_per_img: int = 200, + iou_thresh: float = 0.5, + topk_candidates: int = 400, + positive_fraction: float = 0.25, + **kwargs: Any, + ): + super().__init__() + _log_api_usage_once(self) + + self.backbone = backbone + + self.anchor_generator = anchor_generator + + self.box_coder = det_utils.BoxCoder(weights=(10.0, 10.0, 5.0, 5.0)) + + if head is None: + if hasattr(backbone, "out_channels"): + out_channels = backbone.out_channels + else: + out_channels = det_utils.retrieve_out_channels(backbone, size) + + if len(out_channels) != len(anchor_generator.aspect_ratios): + raise ValueError( + f"The length of the output channels from the backbone ({len(out_channels)}) do not match the length of the anchor generator aspect ratios ({len(anchor_generator.aspect_ratios)})" + ) + + num_anchors = self.anchor_generator.num_anchors_per_location() + head = SSDHead(out_channels, num_anchors, num_classes) + self.head = head + + self.proposal_matcher = det_utils.SSDMatcher(iou_thresh) + + if image_mean is None: + image_mean = [0.485, 0.456, 0.406] + if image_std is None: + image_std = [0.229, 0.224, 0.225] + self.transform = GeneralizedRCNNTransform( + min(size), max(size), image_mean, image_std, size_divisible=1, fixed_size=size, **kwargs + ) + + self.score_thresh = score_thresh + self.nms_thresh = nms_thresh + self.detections_per_img = detections_per_img + self.topk_candidates = topk_candidates + self.neg_to_pos_ratio = (1.0 - positive_fraction) / positive_fraction + + # used only on torchscript mode + self._has_warned = False + + @torch.jit.unused + def eager_outputs( + self, losses: Dict[str, Tensor], detections: List[Dict[str, Tensor]] + ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: + if self.training: + return losses + + return detections + + def compute_loss( + self, + targets: List[Dict[str, Tensor]], + head_outputs: Dict[str, Tensor], + anchors: List[Tensor], + matched_idxs: List[Tensor], + ) -> Dict[str, Tensor]: + bbox_regression = head_outputs["bbox_regression"] + cls_logits = head_outputs["cls_logits"] + + # Match original targets with default boxes + num_foreground = 0 + bbox_loss = [] + cls_targets = [] + for ( + targets_per_image, + bbox_regression_per_image, + cls_logits_per_image, + anchors_per_image, + matched_idxs_per_image, + ) in zip(targets, bbox_regression, cls_logits, anchors, matched_idxs): + # produce the matching between boxes and targets + foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 
0)[0] + foreground_matched_idxs_per_image = matched_idxs_per_image[foreground_idxs_per_image] + num_foreground += foreground_matched_idxs_per_image.numel() + + # Calculate regression loss + matched_gt_boxes_per_image = targets_per_image["boxes"][foreground_matched_idxs_per_image] + bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :] + anchors_per_image = anchors_per_image[foreground_idxs_per_image, :] + target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image) + bbox_loss.append( + torch.nn.functional.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum") + ) + + # Estimate ground truth for class targets + gt_classes_target = torch.zeros( + (cls_logits_per_image.size(0),), + dtype=targets_per_image["labels"].dtype, + device=targets_per_image["labels"].device, + ) + gt_classes_target[foreground_idxs_per_image] = targets_per_image["labels"][ + foreground_matched_idxs_per_image + ] + cls_targets.append(gt_classes_target) + + bbox_loss = torch.stack(bbox_loss) + cls_targets = torch.stack(cls_targets) + + # Calculate classification loss + num_classes = cls_logits.size(-1) + cls_loss = F.cross_entropy(cls_logits.view(-1, num_classes), cls_targets.view(-1), reduction="none").view( + cls_targets.size() + ) + + # Hard Negative Sampling + foreground_idxs = cls_targets > 0 + num_negative = self.neg_to_pos_ratio * foreground_idxs.sum(1, keepdim=True) + # num_negative[num_negative < self.neg_to_pos_ratio] = self.neg_to_pos_ratio + negative_loss = cls_loss.clone() + negative_loss[foreground_idxs] = -float("inf") # use -inf to detect positive values that creeped in the sample + values, idx = negative_loss.sort(1, descending=True) + # background_idxs = torch.logical_and(idx.sort(1)[1] < num_negative, torch.isfinite(values)) + background_idxs = idx.sort(1)[1] < num_negative + + N = max(1, num_foreground) + return { + "bbox_regression": bbox_loss.sum() / N, + "classification": (cls_loss[foreground_idxs].sum() + cls_loss[background_idxs].sum()) / N, + } + + def forward( + self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None + ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: + if self.training: + if targets is None: + torch._assert(False, "targets should not be none when in training mode") + else: + for target in targets: + boxes = target["boxes"] + if isinstance(boxes, torch.Tensor): + torch._assert( + len(boxes.shape) == 2 and boxes.shape[-1] == 4, + f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.", + ) + else: + torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.") + + # get the original image sizes + original_image_sizes: List[Tuple[int, int]] = [] + for img in images: + val = img.shape[-2:] + torch._assert( + len(val) == 2, + f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}", + ) + original_image_sizes.append((val[0], val[1])) + + # transform the input + images, targets = self.transform(images, targets) + + # Check for degenerate boxes + if targets is not None: + for target_idx, target in enumerate(targets): + boxes = target["boxes"] + degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] + if degenerate_boxes.any(): + bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0] + degen_bb: List[float] = boxes[bb_idx].tolist() + torch._assert( + False, + "All bounding boxes should have positive height and width." 
+ f" Found invalid box {degen_bb} for target at index {target_idx}.", + ) + + # get the features from the backbone + features = self.backbone(images.tensors) + if isinstance(features, torch.Tensor): + features = OrderedDict([("0", features)]) + + features = list(features.values()) + + # compute the ssd heads outputs using the features + head_outputs = self.head(features) + + # create the set of anchors + anchors = self.anchor_generator(images, features) + + losses = {} + detections: List[Dict[str, Tensor]] = [] + if self.training: + matched_idxs = [] + if targets is None: + torch._assert(False, "targets should not be none when in training mode") + else: + for anchors_per_image, targets_per_image in zip(anchors, targets): + if targets_per_image["boxes"].numel() == 0: + matched_idxs.append( + torch.full( + (anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device + ) + ) + continue + + match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image) + matched_idxs.append(self.proposal_matcher(match_quality_matrix)) + + losses = self.compute_loss(targets, head_outputs, anchors, matched_idxs) + else: + detections = self.postprocess_detections(head_outputs, anchors, images.image_sizes) + detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) + + if torch.jit.is_scripting(): + if not self._has_warned: + warnings.warn("SSD always returns a (Losses, Detections) tuple in scripting") + self._has_warned = True + return losses, detections + return self.eager_outputs(losses, detections) + + def postprocess_detections( + self, head_outputs: Dict[str, Tensor], image_anchors: List[Tensor], image_shapes: List[Tuple[int, int]] + ) -> List[Dict[str, Tensor]]: + bbox_regression = head_outputs["bbox_regression"] + pred_scores = F.softmax(head_outputs["cls_logits"], dim=-1) + + num_classes = pred_scores.size(-1) + device = pred_scores.device + + detections: List[Dict[str, Tensor]] = [] + + for boxes, scores, anchors, image_shape in zip(bbox_regression, pred_scores, image_anchors, image_shapes): + boxes = self.box_coder.decode_single(boxes, anchors) + boxes = box_ops.clip_boxes_to_image(boxes, image_shape) + + image_boxes = [] + image_scores = [] + image_labels = [] + for label in range(1, num_classes): + score = scores[:, label] + + keep_idxs = score > self.score_thresh + score = score[keep_idxs] + box = boxes[keep_idxs] + + # keep only topk scoring predictions + num_topk = det_utils._topk_min(score, self.topk_candidates, 0) + score, idxs = score.topk(num_topk) + box = box[idxs] + + image_boxes.append(box) + image_scores.append(score) + image_labels.append(torch.full_like(score, fill_value=label, dtype=torch.int64, device=device)) + + image_boxes = torch.cat(image_boxes, dim=0) + image_scores = torch.cat(image_scores, dim=0) + image_labels = torch.cat(image_labels, dim=0) + + # non-maximum suppression + keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh) + keep = keep[: self.detections_per_img] + + detections.append( + { + "boxes": image_boxes[keep], + "scores": image_scores[keep], + "labels": image_labels[keep], + } + ) + return detections + + +class SSDFeatureExtractorVGG(nn.Module): + def __init__(self, backbone: nn.Module, highres: bool): + super().__init__() + + _, _, maxpool3_pos, maxpool4_pos, _ = (i for i, layer in enumerate(backbone) if isinstance(layer, nn.MaxPool2d)) + + # Patch ceil_mode for maxpool3 to get the same WxH output sizes as the paper + backbone[maxpool3_pos].ceil_mode = 
True + + # parameters used for L2 regularization + rescaling + self.scale_weight = nn.Parameter(torch.ones(512) * 20) + + # Multiple Feature maps - page 4, Fig 2 of SSD paper + self.features = nn.Sequential(*backbone[:maxpool4_pos]) # until conv4_3 + + # SSD300 case - page 4, Fig 2 of SSD paper + extra = nn.ModuleList( + [ + nn.Sequential( + nn.Conv2d(1024, 256, kernel_size=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2), # conv8_2 + nn.ReLU(inplace=True), + ), + nn.Sequential( + nn.Conv2d(512, 128, kernel_size=1), + nn.ReLU(inplace=True), + nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2), # conv9_2 + nn.ReLU(inplace=True), + ), + nn.Sequential( + nn.Conv2d(256, 128, kernel_size=1), + nn.ReLU(inplace=True), + nn.Conv2d(128, 256, kernel_size=3), # conv10_2 + nn.ReLU(inplace=True), + ), + nn.Sequential( + nn.Conv2d(256, 128, kernel_size=1), + nn.ReLU(inplace=True), + nn.Conv2d(128, 256, kernel_size=3), # conv11_2 + nn.ReLU(inplace=True), + ), + ] + ) + if highres: + # Additional layers for the SSD512 case. See page 11, footernote 5. + extra.append( + nn.Sequential( + nn.Conv2d(256, 128, kernel_size=1), + nn.ReLU(inplace=True), + nn.Conv2d(128, 256, kernel_size=4), # conv12_2 + nn.ReLU(inplace=True), + ) + ) + _xavier_init(extra) + + fc = nn.Sequential( + nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=False), # add modified maxpool5 + nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=6, dilation=6), # FC6 with atrous + nn.ReLU(inplace=True), + nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1), # FC7 + nn.ReLU(inplace=True), + ) + _xavier_init(fc) + extra.insert( + 0, + nn.Sequential( + *backbone[maxpool4_pos:-1], # until conv5_3, skip maxpool5 + fc, + ), + ) + self.extra = extra + + def forward(self, x: Tensor) -> Dict[str, Tensor]: + # L2 regularization + Rescaling of 1st block's feature map + x = self.features(x) + rescaled = self.scale_weight.view(1, -1, 1, 1) * F.normalize(x) + output = [rescaled] + + # Calculating Feature maps for the rest blocks + for block in self.extra: + x = block(x) + output.append(x) + + return OrderedDict([(str(i), v) for i, v in enumerate(output)]) + + +def _vgg_extractor(backbone: VGG, highres: bool, trainable_layers: int): + backbone = backbone.features + # Gather the indices of maxpools. These are the locations of output blocks. + stage_indices = [0] + [i for i, b in enumerate(backbone) if isinstance(b, nn.MaxPool2d)][:-1] + num_stages = len(stage_indices) + + # find the index of the layer from which we won't freeze + torch._assert( + 0 <= trainable_layers <= num_stages, + f"trainable_layers should be in the range [0, {num_stages}]. 
Instead got {trainable_layers}", + ) + freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers] + + for b in backbone[:freeze_before]: + for parameter in b.parameters(): + parameter.requires_grad_(False) + + return SSDFeatureExtractorVGG(backbone, highres) + + +@register_model() +@handle_legacy_interface( + weights=("pretrained", SSD300_VGG16_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", VGG16_Weights.IMAGENET1K_FEATURES), +) +def ssd300_vgg16( + *, + weights: Optional[SSD300_VGG16_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[VGG16_Weights] = VGG16_Weights.IMAGENET1K_FEATURES, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> SSD: + """The SSD300 model is based on the `SSD: Single Shot MultiBox Detector + `_ paper. + + .. betastatus:: detection module + + The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each + image, and should be in 0-1 range. Different images can have different sizes, but they will be resized + to a fixed size before passing it to the backbone. + + The behavior of the model changes depending on if it is in training or evaluation mode. + + During training, the model expects both the input tensors and targets (list of dictionary), + containing: + + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (Int64Tensor[N]): the class label for each ground-truth box + + The model returns a Dict[Tensor] during training, containing the classification and regression + losses. + + During inference, the model requires only the input tensors, and returns the post-processed + predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as + follows, where ``N`` is the number of detections: + + - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (Int64Tensor[N]): the predicted labels for each detection + - scores (Tensor[N]): the scores for each detection + + Example: + + >>> model = torchvision.models.detection.ssd300_vgg16(weights=SSD300_VGG16_Weights.DEFAULT) + >>> model.eval() + >>> x = [torch.rand(3, 300, 300), torch.rand(3, 500, 400)] + >>> predictions = model(x) + + Args: + weights (:class:`~torchvision.models.detection.SSD300_VGG16_Weights`, optional): The pretrained + weights to use. See + :class:`~torchvision.models.detection.SSD300_VGG16_Weights` + below for more details, and possible values. By default, no + pre-trained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr + Default is True. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (:class:`~torchvision.models.VGG16_Weights`, optional): The pretrained weights for the + backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. + Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is + passed (the default) this value is set to 4. + **kwargs: parameters passed to the ``torchvision.models.detection.SSD`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.detection.SSD300_VGG16_Weights + :members: + """ + weights = SSD300_VGG16_Weights.verify(weights) + weights_backbone = VGG16_Weights.verify(weights_backbone) + + if "size" in kwargs: + warnings.warn("The size of the model is already fixed; ignoring the parameter.") + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + trainable_backbone_layers = _validate_trainable_layers( + weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 4 + ) + + # Use custom backbones more appropriate for SSD + backbone = vgg16(weights=weights_backbone, progress=progress) + backbone = _vgg_extractor(backbone, False, trainable_backbone_layers) + anchor_generator = DefaultBoxGenerator( + [[2], [2, 3], [2, 3], [2, 3], [2], [2]], + scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05], + steps=[8, 16, 32, 64, 100, 300], + ) + + defaults = { + # Rescale the input in a way compatible to the backbone + "image_mean": [0.48235, 0.45882, 0.40784], + "image_std": [1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0], # undo the 0-1 scaling of toTensor + } + kwargs: Any = {**defaults, **kwargs} + model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model diff --git a/lib/python3.10/site-packages/torchvision/models/detection/ssdlite.py b/lib/python3.10/site-packages/torchvision/models/detection/ssdlite.py new file mode 100644 index 0000000000000000000000000000000000000000..eda21bf941ef0d4a9051312ebdba6911c6760e8d --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/detection/ssdlite.py @@ -0,0 +1,331 @@ +import warnings +from collections import OrderedDict +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from torch import nn, Tensor + +from ...ops.misc import Conv2dNormActivation +from ...transforms._presets import ObjectDetection +from ...utils import _log_api_usage_once +from .. import mobilenet +from .._api import register_model, Weights, WeightsEnum +from .._meta import _COCO_CATEGORIES +from .._utils import _ovewrite_value_param, handle_legacy_interface +from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights +from . 
import _utils as det_utils +from .anchor_utils import DefaultBoxGenerator +from .backbone_utils import _validate_trainable_layers +from .ssd import SSD, SSDScoringHead + + +__all__ = [ + "SSDLite320_MobileNet_V3_Large_Weights", + "ssdlite320_mobilenet_v3_large", +] + + +# Building blocks of SSDlite as described in section 6.2 of MobileNetV2 paper +def _prediction_block( + in_channels: int, out_channels: int, kernel_size: int, norm_layer: Callable[..., nn.Module] +) -> nn.Sequential: + return nn.Sequential( + # 3x3 depthwise with stride 1 and padding 1 + Conv2dNormActivation( + in_channels, + in_channels, + kernel_size=kernel_size, + groups=in_channels, + norm_layer=norm_layer, + activation_layer=nn.ReLU6, + ), + # 1x1 projetion to output channels + nn.Conv2d(in_channels, out_channels, 1), + ) + + +def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., nn.Module]) -> nn.Sequential: + activation = nn.ReLU6 + intermediate_channels = out_channels // 2 + return nn.Sequential( + # 1x1 projection to half output channels + Conv2dNormActivation( + in_channels, intermediate_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation + ), + # 3x3 depthwise with stride 2 and padding 1 + Conv2dNormActivation( + intermediate_channels, + intermediate_channels, + kernel_size=3, + stride=2, + groups=intermediate_channels, + norm_layer=norm_layer, + activation_layer=activation, + ), + # 1x1 projetion to output channels + Conv2dNormActivation( + intermediate_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation + ), + ) + + +def _normal_init(conv: nn.Module): + for layer in conv.modules(): + if isinstance(layer, nn.Conv2d): + torch.nn.init.normal_(layer.weight, mean=0.0, std=0.03) + if layer.bias is not None: + torch.nn.init.constant_(layer.bias, 0.0) + + +class SSDLiteHead(nn.Module): + def __init__( + self, in_channels: List[int], num_anchors: List[int], num_classes: int, norm_layer: Callable[..., nn.Module] + ): + super().__init__() + self.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, num_classes, norm_layer) + self.regression_head = SSDLiteRegressionHead(in_channels, num_anchors, norm_layer) + + def forward(self, x: List[Tensor]) -> Dict[str, Tensor]: + return { + "bbox_regression": self.regression_head(x), + "cls_logits": self.classification_head(x), + } + + +class SSDLiteClassificationHead(SSDScoringHead): + def __init__( + self, in_channels: List[int], num_anchors: List[int], num_classes: int, norm_layer: Callable[..., nn.Module] + ): + cls_logits = nn.ModuleList() + for channels, anchors in zip(in_channels, num_anchors): + cls_logits.append(_prediction_block(channels, num_classes * anchors, 3, norm_layer)) + _normal_init(cls_logits) + super().__init__(cls_logits, num_classes) + + +class SSDLiteRegressionHead(SSDScoringHead): + def __init__(self, in_channels: List[int], num_anchors: List[int], norm_layer: Callable[..., nn.Module]): + bbox_reg = nn.ModuleList() + for channels, anchors in zip(in_channels, num_anchors): + bbox_reg.append(_prediction_block(channels, 4 * anchors, 3, norm_layer)) + _normal_init(bbox_reg) + super().__init__(bbox_reg, 4) + + +class SSDLiteFeatureExtractorMobileNet(nn.Module): + def __init__( + self, + backbone: nn.Module, + c4_pos: int, + norm_layer: Callable[..., nn.Module], + width_mult: float = 1.0, + min_depth: int = 16, + ): + super().__init__() + _log_api_usage_once(self) + + if backbone[c4_pos].use_res_connect: + raise ValueError("backbone[c4_pos].use_res_connect 
should be False") + + self.features = nn.Sequential( + # As described in section 6.3 of MobileNetV3 paper + nn.Sequential(*backbone[:c4_pos], backbone[c4_pos].block[0]), # from start until C4 expansion layer + nn.Sequential(backbone[c4_pos].block[1:], *backbone[c4_pos + 1 :]), # from C4 depthwise until end + ) + + get_depth = lambda d: max(min_depth, int(d * width_mult)) # noqa: E731 + extra = nn.ModuleList( + [ + _extra_block(backbone[-1].out_channels, get_depth(512), norm_layer), + _extra_block(get_depth(512), get_depth(256), norm_layer), + _extra_block(get_depth(256), get_depth(256), norm_layer), + _extra_block(get_depth(256), get_depth(128), norm_layer), + ] + ) + _normal_init(extra) + + self.extra = extra + + def forward(self, x: Tensor) -> Dict[str, Tensor]: + # Get feature maps from backbone and extra. Can't be refactored due to JIT limitations. + output = [] + for block in self.features: + x = block(x) + output.append(x) + + for block in self.extra: + x = block(x) + output.append(x) + + return OrderedDict([(str(i), v) for i, v in enumerate(output)]) + + +def _mobilenet_extractor( + backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3], + trainable_layers: int, + norm_layer: Callable[..., nn.Module], +): + backbone = backbone.features + # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. + # The first and last blocks are always included because they are the C0 (conv1) and Cn. + stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1] + num_stages = len(stage_indices) + + # find the index of the layer from which we won't freeze + if not 0 <= trainable_layers <= num_stages: + raise ValueError("trainable_layers should be in the range [0, {num_stages}], instead got {trainable_layers}") + freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers] + + for b in backbone[:freeze_before]: + for parameter in b.parameters(): + parameter.requires_grad_(False) + + return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer) + + +class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth", + transforms=ObjectDetection, + meta={ + "num_params": 3440060, + "categories": _COCO_CATEGORIES, + "min_size": (1, 1), + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssdlite320-mobilenetv3-large", + "_metrics": { + "COCO-val2017": { + "box_map": 21.3, + } + }, + "_ops": 0.583, + "_file_size": 13.418, + "_docs": """These weights were produced by following a similar training recipe as on the paper.""", + }, + ) + DEFAULT = COCO_V1 + + +@register_model() +@handle_legacy_interface( + weights=("pretrained", SSDLite320_MobileNet_V3_Large_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), +) +def ssdlite320_mobilenet_v3_large( + *, + weights: Optional[SSDLite320_MobileNet_V3_Large_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1, + trainable_backbone_layers: Optional[int] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + **kwargs: Any, +) -> SSD: + """SSDlite model architecture with input size 320x320 and a MobileNetV3 Large backbone, as + described at `Searching for MobileNetV3 `__ and + `MobileNetV2: 
Inverted Residuals and Linear Bottlenecks `__. + + .. betastatus:: detection module + + See :func:`~torchvision.models.detection.ssd300_vgg16` for more details. + + Example: + + >>> model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(weights=SSDLite320_MobileNet_V3_Large_Weights.DEFAULT) + >>> model.eval() + >>> x = [torch.rand(3, 320, 320), torch.rand(3, 500, 400)] + >>> predictions = model(x) + + Args: + weights (:class:`~torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + num_classes (int, optional): number of output classes of the model + (including the background). + weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The pretrained + weights for the backbone. + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers + starting from final block. Valid values are between 0 and 6, with 6 meaning all + backbone layers are trainable. If ``None`` is passed (the default) this value is + set to 6. + norm_layer (callable, optional): Module specifying the normalization layer to use. + **kwargs: parameters passed to the ``torchvision.models.detection.ssd.SSD`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights + :members: + """ + + weights = SSDLite320_MobileNet_V3_Large_Weights.verify(weights) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) + + if "size" in kwargs: + warnings.warn("The size of the model is already fixed; ignoring the parameter.") + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + trainable_backbone_layers = _validate_trainable_layers( + weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 6 + ) + + # Enable reduced tail if no pretrained backbone is selected. See Table 6 of MobileNetV3 paper. 
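+    # Editorial note (illustrative): the assignment below effectively chooses between, roughly,
+    #     mobilenet_v3_large(weights=None, reduced_tail=True,  ...)   # from scratch, slimmer tail
+    #     mobilenet_v3_large(weights=W,    reduced_tail=False, ...)   # pretrained, full-width tail
+    # i.e. the slimmer detection tail is only used when no pretrained backbone weights are requested.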
+ reduce_tail = weights_backbone is None + + if norm_layer is None: + norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03) + + backbone = mobilenet_v3_large( + weights=weights_backbone, progress=progress, norm_layer=norm_layer, reduced_tail=reduce_tail, **kwargs + ) + if weights_backbone is None: + # Change the default initialization scheme if not pretrained + _normal_init(backbone) + backbone = _mobilenet_extractor( + backbone, + trainable_backbone_layers, + norm_layer, + ) + + size = (320, 320) + anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95) + out_channels = det_utils.retrieve_out_channels(backbone, size) + num_anchors = anchor_generator.num_anchors_per_location() + if len(out_channels) != len(anchor_generator.aspect_ratios): + raise ValueError( + f"The length of the output channels from the backbone {len(out_channels)} do not match the length of the anchor generator aspect ratios {len(anchor_generator.aspect_ratios)}" + ) + + defaults = { + "score_thresh": 0.001, + "nms_thresh": 0.55, + "detections_per_img": 300, + "topk_candidates": 300, + # Rescale the input in a way compatible to the backbone: + # The following mean/std rescale the data from [0, 1] to [-1, 1] + "image_mean": [0.5, 0.5, 0.5], + "image_std": [0.5, 0.5, 0.5], + } + kwargs: Any = {**defaults, **kwargs} + model = SSD( + backbone, + anchor_generator, + size, + num_classes, + head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer), + **kwargs, + ) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model diff --git a/lib/python3.10/site-packages/torchvision/models/detection/transform.py b/lib/python3.10/site-packages/torchvision/models/detection/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..9c569b0aafb0c5464815654c0f343d7fb927dc6c --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/detection/transform.py @@ -0,0 +1,319 @@ +import math +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torchvision +from torch import nn, Tensor + +from .image_list import ImageList +from .roi_heads import paste_masks_in_image + + +@torch.jit.unused +def _get_shape_onnx(image: Tensor) -> Tensor: + from torch.onnx import operators + + return operators.shape_as_tensor(image)[-2:] + + +@torch.jit.unused +def _fake_cast_onnx(v: Tensor) -> float: + # ONNX requires a tensor but here we fake its type for JIT. 
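+    # Editorial note: when exporting to ONNX the computed scale stays a Tensor, while the
+    # TorchScript signature of the resize path expects a float scale_factor; this no-op wrapper
+    # simply relabels the type, e.g. `scale_factor = _fake_cast_onnx(scale)` further down.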
+ return v + + +def _resize_image_and_masks( + image: Tensor, + self_min_size: int, + self_max_size: int, + target: Optional[Dict[str, Tensor]] = None, + fixed_size: Optional[Tuple[int, int]] = None, +) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + if torchvision._is_tracing(): + im_shape = _get_shape_onnx(image) + elif torch.jit.is_scripting(): + im_shape = torch.tensor(image.shape[-2:]) + else: + im_shape = image.shape[-2:] + + size: Optional[List[int]] = None + scale_factor: Optional[float] = None + recompute_scale_factor: Optional[bool] = None + if fixed_size is not None: + size = [fixed_size[1], fixed_size[0]] + else: + if torch.jit.is_scripting() or torchvision._is_tracing(): + min_size = torch.min(im_shape).to(dtype=torch.float32) + max_size = torch.max(im_shape).to(dtype=torch.float32) + self_min_size_f = float(self_min_size) + self_max_size_f = float(self_max_size) + scale = torch.min(self_min_size_f / min_size, self_max_size_f / max_size) + + if torchvision._is_tracing(): + scale_factor = _fake_cast_onnx(scale) + else: + scale_factor = scale.item() + + else: + # Do it the normal way + min_size = min(im_shape) + max_size = max(im_shape) + scale_factor = min(self_min_size / min_size, self_max_size / max_size) + + recompute_scale_factor = True + + image = torch.nn.functional.interpolate( + image[None], + size=size, + scale_factor=scale_factor, + mode="bilinear", + recompute_scale_factor=recompute_scale_factor, + align_corners=False, + )[0] + + if target is None: + return image, target + + if "masks" in target: + mask = target["masks"] + mask = torch.nn.functional.interpolate( + mask[:, None].float(), size=size, scale_factor=scale_factor, recompute_scale_factor=recompute_scale_factor + )[:, 0].byte() + target["masks"] = mask + return image, target + + +class GeneralizedRCNNTransform(nn.Module): + """ + Performs input / target transformation before feeding the data to a GeneralizedRCNN + model. 
+ + The transformations it performs are: + - input normalization (mean subtraction and std division) + - input / target resizing to match min_size / max_size + + It returns a ImageList for the inputs, and a List[Dict[Tensor]] for the targets + """ + + def __init__( + self, + min_size: int, + max_size: int, + image_mean: List[float], + image_std: List[float], + size_divisible: int = 32, + fixed_size: Optional[Tuple[int, int]] = None, + **kwargs: Any, + ): + super().__init__() + if not isinstance(min_size, (list, tuple)): + min_size = (min_size,) + self.min_size = min_size + self.max_size = max_size + self.image_mean = image_mean + self.image_std = image_std + self.size_divisible = size_divisible + self.fixed_size = fixed_size + self._skip_resize = kwargs.pop("_skip_resize", False) + + def forward( + self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None + ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]: + images = [img for img in images] + if targets is not None: + # make a copy of targets to avoid modifying it in-place + # once torchscript supports dict comprehension + # this can be simplified as follows + # targets = [{k: v for k,v in t.items()} for t in targets] + targets_copy: List[Dict[str, Tensor]] = [] + for t in targets: + data: Dict[str, Tensor] = {} + for k, v in t.items(): + data[k] = v + targets_copy.append(data) + targets = targets_copy + for i in range(len(images)): + image = images[i] + target_index = targets[i] if targets is not None else None + + if image.dim() != 3: + raise ValueError(f"images is expected to be a list of 3d tensors of shape [C, H, W], got {image.shape}") + image = self.normalize(image) + image, target_index = self.resize(image, target_index) + images[i] = image + if targets is not None and target_index is not None: + targets[i] = target_index + + image_sizes = [img.shape[-2:] for img in images] + images = self.batch_images(images, size_divisible=self.size_divisible) + image_sizes_list: List[Tuple[int, int]] = [] + for image_size in image_sizes: + torch._assert( + len(image_size) == 2, + f"Input tensors expected to have in the last two elements H and W, instead got {image_size}", + ) + image_sizes_list.append((image_size[0], image_size[1])) + + image_list = ImageList(images, image_sizes_list) + return image_list, targets + + def normalize(self, image: Tensor) -> Tensor: + if not image.is_floating_point(): + raise TypeError( + f"Expected input images to be of floating type (in range [0, 1]), " + f"but found type {image.dtype} instead" + ) + dtype, device = image.dtype, image.device + mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device) + std = torch.as_tensor(self.image_std, dtype=dtype, device=device) + return (image - mean[:, None, None]) / std[:, None, None] + + def torch_choice(self, k: List[int]) -> int: + """ + Implements `random.choice` via torch ops, so it can be compiled with + TorchScript and we use PyTorch's RNG (not native RNG) + """ + index = int(torch.empty(1).uniform_(0.0, float(len(k))).item()) + return k[index] + + def resize( + self, + image: Tensor, + target: Optional[Dict[str, Tensor]] = None, + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + h, w = image.shape[-2:] + if self.training: + if self._skip_resize: + return image, target + size = self.torch_choice(self.min_size) + else: + size = self.min_size[-1] + image, target = _resize_image_and_masks(image, size, self.max_size, target, self.fixed_size) + + if target is None: + return image, target + + bbox = target["boxes"] + bbox = 
resize_boxes(bbox, (h, w), image.shape[-2:]) + target["boxes"] = bbox + + if "keypoints" in target: + keypoints = target["keypoints"] + keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:]) + target["keypoints"] = keypoints + return image, target + + # _onnx_batch_images() is an implementation of + # batch_images() that is supported by ONNX tracing. + @torch.jit.unused + def _onnx_batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor: + max_size = [] + for i in range(images[0].dim()): + max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64) + max_size.append(max_size_i) + stride = size_divisible + max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64) + max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # which is not yet supported in onnx + padded_imgs = [] + for img in images: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + return torch.stack(padded_imgs) + + def max_by_axis(self, the_list: List[List[int]]) -> List[int]: + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor: + if torchvision._is_tracing(): + # batch_images() does not export well to ONNX + # call _onnx_batch_images() instead + return self._onnx_batch_images(images, size_divisible) + + max_size = self.max_by_axis([list(img.shape) for img in images]) + stride = float(size_divisible) + max_size = list(max_size) + max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride) + max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride) + + batch_shape = [len(images)] + max_size + batched_imgs = images[0].new_full(batch_shape, 0) + for i in range(batched_imgs.shape[0]): + img = images[i] + batched_imgs[i, : img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + + return batched_imgs + + def postprocess( + self, + result: List[Dict[str, Tensor]], + image_shapes: List[Tuple[int, int]], + original_image_sizes: List[Tuple[int, int]], + ) -> List[Dict[str, Tensor]]: + if self.training: + return result + for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)): + boxes = pred["boxes"] + boxes = resize_boxes(boxes, im_s, o_im_s) + result[i]["boxes"] = boxes + if "masks" in pred: + masks = pred["masks"] + masks = paste_masks_in_image(masks, boxes, o_im_s) + result[i]["masks"] = masks + if "keypoints" in pred: + keypoints = pred["keypoints"] + keypoints = resize_keypoints(keypoints, im_s, o_im_s) + result[i]["keypoints"] = keypoints + return result + + def __repr__(self) -> str: + format_string = f"{self.__class__.__name__}(" + _indent = "\n " + format_string += f"{_indent}Normalize(mean={self.image_mean}, std={self.image_std})" + format_string += f"{_indent}Resize(min_size={self.min_size}, max_size={self.max_size}, mode='bilinear')" + format_string += "\n)" + return format_string + + +def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor: + ratios = [ + torch.tensor(s, dtype=torch.float32, device=keypoints.device) + / 
torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device) + for s, s_orig in zip(new_size, original_size) + ] + ratio_h, ratio_w = ratios + resized_data = keypoints.clone() + if torch._C._get_tracing_state(): + resized_data_0 = resized_data[:, :, 0] * ratio_w + resized_data_1 = resized_data[:, :, 1] * ratio_h + resized_data = torch.stack((resized_data_0, resized_data_1, resized_data[:, :, 2]), dim=2) + else: + resized_data[..., 0] *= ratio_w + resized_data[..., 1] *= ratio_h + return resized_data + + +def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor: + ratios = [ + torch.tensor(s, dtype=torch.float32, device=boxes.device) + / torch.tensor(s_orig, dtype=torch.float32, device=boxes.device) + for s, s_orig in zip(new_size, original_size) + ] + ratio_height, ratio_width = ratios + xmin, ymin, xmax, ymax = boxes.unbind(1) + + xmin = xmin * ratio_width + xmax = xmax * ratio_width + ymin = ymin * ratio_height + ymax = ymax * ratio_height + return torch.stack((xmin, ymin, xmax, ymax), dim=1) diff --git a/lib/python3.10/site-packages/torchvision/models/optical_flow/__init__.py b/lib/python3.10/site-packages/torchvision/models/optical_flow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..89d2302f825ff0dbe25d02f6dc7c84d3c0790ad0 --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/optical_flow/__init__.py @@ -0,0 +1 @@ +from .raft import * diff --git a/lib/python3.10/site-packages/torchvision/models/optical_flow/_utils.py b/lib/python3.10/site-packages/torchvision/models/optical_flow/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fa2454a27315d6e560dccb6ea2ce6083da03e256 --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/optical_flow/_utils.py @@ -0,0 +1,48 @@ +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import Tensor + + +def grid_sample(img: Tensor, absolute_grid: Tensor, mode: str = "bilinear", align_corners: Optional[bool] = None): + """Same as torch's grid_sample, with absolute pixel coordinates instead of normalized coordinates.""" + h, w = img.shape[-2:] + + xgrid, ygrid = absolute_grid.split([1, 1], dim=-1) + xgrid = 2 * xgrid / (w - 1) - 1 + # Adding condition if h > 1 to enable this function be reused in raft-stereo + if h > 1: + ygrid = 2 * ygrid / (h - 1) - 1 + normalized_grid = torch.cat([xgrid, ygrid], dim=-1) + + return F.grid_sample(img, normalized_grid, mode=mode, align_corners=align_corners) + + +def make_coords_grid(batch_size: int, h: int, w: int, device: str = "cpu"): + device = torch.device(device) + coords = torch.meshgrid(torch.arange(h, device=device), torch.arange(w, device=device), indexing="ij") + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch_size, 1, 1, 1) + + +def upsample_flow(flow, up_mask: Optional[Tensor] = None, factor: int = 8): + """Upsample flow by the input factor (default 8). + + If up_mask is None we just interpolate. + If up_mask is specified, we upsample using a convex combination of its weights. See paper page 8 and appendix B. + Note that in appendix B the picture assumes a downsample factor of 4 instead of 8. 
+ """ + batch_size, num_channels, h, w = flow.shape + new_h, new_w = h * factor, w * factor + + if up_mask is None: + return factor * F.interpolate(flow, size=(new_h, new_w), mode="bilinear", align_corners=True) + + up_mask = up_mask.view(batch_size, 1, 9, factor, factor, h, w) + up_mask = torch.softmax(up_mask, dim=2) # "convex" == weights sum to 1 + + upsampled_flow = F.unfold(factor * flow, kernel_size=3, padding=1).view(batch_size, num_channels, 9, 1, 1, h, w) + upsampled_flow = torch.sum(up_mask * upsampled_flow, dim=2) + + return upsampled_flow.permute(0, 1, 4, 2, 5, 3).reshape(batch_size, num_channels, new_h, new_w) diff --git a/lib/python3.10/site-packages/torchvision/models/optical_flow/raft.py b/lib/python3.10/site-packages/torchvision/models/optical_flow/raft.py new file mode 100644 index 0000000000000000000000000000000000000000..c294777ee6ffc0a9151f76f13bf2bde018580f9e --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/optical_flow/raft.py @@ -0,0 +1,947 @@ +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.nn.modules.instancenorm import InstanceNorm2d +from torchvision.ops import Conv2dNormActivation + +from ...transforms._presets import OpticalFlow +from ...utils import _log_api_usage_once +from .._api import register_model, Weights, WeightsEnum +from .._utils import handle_legacy_interface +from ._utils import grid_sample, make_coords_grid, upsample_flow + + +__all__ = ( + "RAFT", + "raft_large", + "raft_small", + "Raft_Large_Weights", + "Raft_Small_Weights", +) + + +class ResidualBlock(nn.Module): + """Slightly modified Residual block with extra relu and biases.""" + + def __init__(self, in_channels, out_channels, *, norm_layer, stride=1, always_project: bool = False): + super().__init__() + + # Note regarding bias=True: + # Usually we can pass bias=False in conv layers followed by a norm layer. + # But in the RAFT training reference, the BatchNorm2d layers are only activated for the first dataset, + # and frozen for the rest of the training process (i.e. set as eval()). The bias term is thus still useful + # for the rest of the datasets. Technically, we could remove the bias for other norm layers like Instance norm + # because these aren't frozen, but we don't bother (also, we wouldn't be able to load the original weights). 
+ self.convnormrelu1 = Conv2dNormActivation( + in_channels, out_channels, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True + ) + self.convnormrelu2 = Conv2dNormActivation( + out_channels, out_channels, norm_layer=norm_layer, kernel_size=3, bias=True + ) + + # make mypy happy + self.downsample: nn.Module + + if stride == 1 and not always_project: + self.downsample = nn.Identity() + else: + self.downsample = Conv2dNormActivation( + in_channels, + out_channels, + norm_layer=norm_layer, + kernel_size=1, + stride=stride, + bias=True, + activation_layer=None, + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + y = x + y = self.convnormrelu1(y) + y = self.convnormrelu2(y) + + x = self.downsample(x) + + return self.relu(x + y) + + +class BottleneckBlock(nn.Module): + """Slightly modified BottleNeck block (extra relu and biases)""" + + def __init__(self, in_channels, out_channels, *, norm_layer, stride=1): + super().__init__() + + # See note in ResidualBlock for the reason behind bias=True + self.convnormrelu1 = Conv2dNormActivation( + in_channels, out_channels // 4, norm_layer=norm_layer, kernel_size=1, bias=True + ) + self.convnormrelu2 = Conv2dNormActivation( + out_channels // 4, out_channels // 4, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True + ) + self.convnormrelu3 = Conv2dNormActivation( + out_channels // 4, out_channels, norm_layer=norm_layer, kernel_size=1, bias=True + ) + self.relu = nn.ReLU(inplace=True) + + if stride == 1: + self.downsample = nn.Identity() + else: + self.downsample = Conv2dNormActivation( + in_channels, + out_channels, + norm_layer=norm_layer, + kernel_size=1, + stride=stride, + bias=True, + activation_layer=None, + ) + + def forward(self, x): + y = x + y = self.convnormrelu1(y) + y = self.convnormrelu2(y) + y = self.convnormrelu3(y) + + x = self.downsample(x) + + return self.relu(x + y) + + +class FeatureEncoder(nn.Module): + """The feature encoder, used both as the actual feature encoder, and as the context encoder. + + It must downsample its input by 8. 
+ """ + + def __init__( + self, *, block=ResidualBlock, layers=(64, 64, 96, 128, 256), strides=(2, 1, 2, 2), norm_layer=nn.BatchNorm2d + ): + super().__init__() + + if len(layers) != 5: + raise ValueError(f"The expected number of layers is 5, instead got {len(layers)}") + + # See note in ResidualBlock for the reason behind bias=True + self.convnormrelu = Conv2dNormActivation( + 3, layers[0], norm_layer=norm_layer, kernel_size=7, stride=strides[0], bias=True + ) + + self.layer1 = self._make_2_blocks(block, layers[0], layers[1], norm_layer=norm_layer, first_stride=strides[1]) + self.layer2 = self._make_2_blocks(block, layers[1], layers[2], norm_layer=norm_layer, first_stride=strides[2]) + self.layer3 = self._make_2_blocks(block, layers[2], layers[3], norm_layer=norm_layer, first_stride=strides[3]) + + self.conv = nn.Conv2d(layers[3], layers[4], kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + num_downsamples = len(list(filter(lambda s: s == 2, strides))) + self.output_dim = layers[-1] + self.downsample_factor = 2**num_downsamples + + def _make_2_blocks(self, block, in_channels, out_channels, norm_layer, first_stride): + block1 = block(in_channels, out_channels, norm_layer=norm_layer, stride=first_stride) + block2 = block(out_channels, out_channels, norm_layer=norm_layer, stride=1) + return nn.Sequential(block1, block2) + + def forward(self, x): + x = self.convnormrelu(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv(x) + + return x + + +class MotionEncoder(nn.Module): + """The motion encoder, part of the update block. + + Takes the current predicted flow and the correlation features as input and returns an encoded version of these. 
+ """ + + def __init__(self, *, in_channels_corr, corr_layers=(256, 192), flow_layers=(128, 64), out_channels=128): + super().__init__() + + if len(flow_layers) != 2: + raise ValueError(f"The expected number of flow_layers is 2, instead got {len(flow_layers)}") + if len(corr_layers) not in (1, 2): + raise ValueError(f"The number of corr_layers should be 1 or 2, instead got {len(corr_layers)}") + + self.convcorr1 = Conv2dNormActivation(in_channels_corr, corr_layers[0], norm_layer=None, kernel_size=1) + if len(corr_layers) == 2: + self.convcorr2 = Conv2dNormActivation(corr_layers[0], corr_layers[1], norm_layer=None, kernel_size=3) + else: + self.convcorr2 = nn.Identity() + + self.convflow1 = Conv2dNormActivation(2, flow_layers[0], norm_layer=None, kernel_size=7) + self.convflow2 = Conv2dNormActivation(flow_layers[0], flow_layers[1], norm_layer=None, kernel_size=3) + + # out_channels - 2 because we cat the flow (2 channels) at the end + self.conv = Conv2dNormActivation( + corr_layers[-1] + flow_layers[-1], out_channels - 2, norm_layer=None, kernel_size=3 + ) + + self.out_channels = out_channels + + def forward(self, flow, corr_features): + corr = self.convcorr1(corr_features) + corr = self.convcorr2(corr) + + flow_orig = flow + flow = self.convflow1(flow) + flow = self.convflow2(flow) + + corr_flow = torch.cat([corr, flow], dim=1) + corr_flow = self.conv(corr_flow) + return torch.cat([corr_flow, flow_orig], dim=1) + + +class ConvGRU(nn.Module): + """Convolutional Gru unit.""" + + def __init__(self, *, input_size, hidden_size, kernel_size, padding): + super().__init__() + self.convz = nn.Conv2d(hidden_size + input_size, hidden_size, kernel_size=kernel_size, padding=padding) + self.convr = nn.Conv2d(hidden_size + input_size, hidden_size, kernel_size=kernel_size, padding=padding) + self.convq = nn.Conv2d(hidden_size + input_size, hidden_size, kernel_size=kernel_size, padding=padding) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + return h + + +def _pass_through_h(h, _): + # Declared here for torchscript + return h + + +class RecurrentBlock(nn.Module): + """Recurrent block, part of the update block. + + Takes the current hidden state and the concatenation of (motion encoder output, context) as input. + Returns an updated hidden state. + """ + + def __init__(self, *, input_size, hidden_size, kernel_size=((1, 5), (5, 1)), padding=((0, 2), (2, 0))): + super().__init__() + + if len(kernel_size) != len(padding): + raise ValueError( + f"kernel_size should have the same length as padding, instead got len(kernel_size) = {len(kernel_size)} and len(padding) = {len(padding)}" + ) + if len(kernel_size) not in (1, 2): + raise ValueError(f"kernel_size should either 1 or 2, instead got {len(kernel_size)}") + + self.convgru1 = ConvGRU( + input_size=input_size, hidden_size=hidden_size, kernel_size=kernel_size[0], padding=padding[0] + ) + if len(kernel_size) == 2: + self.convgru2 = ConvGRU( + input_size=input_size, hidden_size=hidden_size, kernel_size=kernel_size[1], padding=padding[1] + ) + else: + self.convgru2 = _pass_through_h + + self.hidden_size = hidden_size + + def forward(self, h, x): + h = self.convgru1(h, x) + h = self.convgru2(h, x) + return h + + +class FlowHead(nn.Module): + """Flow head, part of the update block. + + Takes the hidden state of the recurrent unit as input, and outputs the predicted "delta flow". 
+ """ + + def __init__(self, *, in_channels, hidden_size): + super().__init__() + self.conv1 = nn.Conv2d(in_channels, hidden_size, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_size, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + + +class UpdateBlock(nn.Module): + """The update block which contains the motion encoder, the recurrent block, and the flow head. + + It must expose a ``hidden_state_size`` attribute which is the hidden state size of its recurrent block. + """ + + def __init__(self, *, motion_encoder, recurrent_block, flow_head): + super().__init__() + self.motion_encoder = motion_encoder + self.recurrent_block = recurrent_block + self.flow_head = flow_head + + self.hidden_state_size = recurrent_block.hidden_size + + def forward(self, hidden_state, context, corr_features, flow): + motion_features = self.motion_encoder(flow, corr_features) + x = torch.cat([context, motion_features], dim=1) + + hidden_state = self.recurrent_block(hidden_state, x) + delta_flow = self.flow_head(hidden_state) + return hidden_state, delta_flow + + +class MaskPredictor(nn.Module): + """Mask predictor to be used when upsampling the predicted flow. + + It takes the hidden state of the recurrent unit as input and outputs the mask. + This is not used in the raft-small model. + """ + + def __init__(self, *, in_channels, hidden_size, multiplier=0.25): + super().__init__() + self.convrelu = Conv2dNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3) + # 8 * 8 * 9 because the predicted flow is downsampled by 8, from the downsampling of the initial FeatureEncoder, + # and we interpolate with all 9 surrounding neighbors. See paper and appendix B. + self.conv = nn.Conv2d(hidden_size, 8 * 8 * 9, 1, padding=0) + + # In the original code, they use a factor of 0.25 to "downweight the gradients" of that branch. + # See e.g. https://github.com/princeton-vl/RAFT/issues/119#issuecomment-953950419 + # or https://github.com/princeton-vl/RAFT/issues/24. + # It doesn't seem to affect epe significantly and can likely be set to 1. + self.multiplier = multiplier + + def forward(self, x): + x = self.convrelu(x) + x = self.conv(x) + return self.multiplier * x + + +class CorrBlock(nn.Module): + """The correlation block. + + Creates a correlation pyramid with ``num_levels`` levels from the outputs of the feature encoder, + and then indexes from this pyramid to create correlation features. + The "indexing" of a given centroid pixel x' is done by concatenating its surrounding neighbors that + are within a ``radius``, according to the infinity norm (see paper section 3.2). + Note: typo in the paper, it should be infinity norm, not 1-norm. + """ + + def __init__(self, *, num_levels: int = 4, radius: int = 4): + super().__init__() + self.num_levels = num_levels + self.radius = radius + + self.corr_pyramid: List[Tensor] = [torch.tensor(0)] # useless, but torchscript is otherwise confused :') + + # The neighborhood of a centroid pixel x' is {x' + delta, ||delta||_inf <= radius} + # so it's a square surrounding x', and its sides have a length of 2 * radius + 1 + # The paper claims that it's ||.||_1 instead of ||.||_inf but it's a typo: + # https://github.com/princeton-vl/RAFT/issues/122 + self.out_channels = num_levels * (2 * radius + 1) ** 2 + + def build_pyramid(self, fmap1, fmap2): + """Build the correlation pyramid from two feature maps. 
+ + The correlation volume is first computed as the dot product of each pair (pixel_in_fmap1, pixel_in_fmap2) + The last 2 dimensions of the correlation volume are then pooled num_levels times at different resolutions + to build the correlation pyramid. + """ + + if fmap1.shape != fmap2.shape: + raise ValueError( + f"Input feature maps should have the same shape, instead got {fmap1.shape} (fmap1.shape) != {fmap2.shape} (fmap2.shape)" + ) + + # Explaining min_fmap_size below: the fmaps are down-sampled (num_levels - 1) times by a factor of 2. + # The last corr_volume most have at least 2 values (hence the 2* factor), otherwise grid_sample() would + # produce nans in its output. + min_fmap_size = 2 * (2 ** (self.num_levels - 1)) + if any(fmap_size < min_fmap_size for fmap_size in fmap1.shape[-2:]): + raise ValueError( + "Feature maps are too small to be down-sampled by the correlation pyramid. " + f"H and W of feature maps should be at least {min_fmap_size}; got: {fmap1.shape[-2:]}. " + "Remember that input images to the model are downsampled by 8, so that means their " + f"dimensions should be at least 8 * {min_fmap_size} = {8 * min_fmap_size}." + ) + + corr_volume = self._compute_corr_volume(fmap1, fmap2) + + batch_size, h, w, num_channels, _, _ = corr_volume.shape # _, _ = h, w + corr_volume = corr_volume.reshape(batch_size * h * w, num_channels, h, w) + self.corr_pyramid = [corr_volume] + for _ in range(self.num_levels - 1): + corr_volume = F.avg_pool2d(corr_volume, kernel_size=2, stride=2) + self.corr_pyramid.append(corr_volume) + + def index_pyramid(self, centroids_coords): + """Return correlation features by indexing from the pyramid.""" + neighborhood_side_len = 2 * self.radius + 1 # see note in __init__ about out_channels + di = torch.linspace(-self.radius, self.radius, neighborhood_side_len) + dj = torch.linspace(-self.radius, self.radius, neighborhood_side_len) + delta = torch.stack(torch.meshgrid(di, dj, indexing="ij"), dim=-1).to(centroids_coords.device) + delta = delta.view(1, neighborhood_side_len, neighborhood_side_len, 2) + + batch_size, _, h, w = centroids_coords.shape # _ = 2 + centroids_coords = centroids_coords.permute(0, 2, 3, 1).reshape(batch_size * h * w, 1, 1, 2) + + indexed_pyramid = [] + for corr_volume in self.corr_pyramid: + sampling_coords = centroids_coords + delta # end shape is (batch_size * h * w, side_len, side_len, 2) + indexed_corr_volume = grid_sample(corr_volume, sampling_coords, align_corners=True, mode="bilinear").view( + batch_size, h, w, -1 + ) + indexed_pyramid.append(indexed_corr_volume) + centroids_coords = centroids_coords / 2 + + corr_features = torch.cat(indexed_pyramid, dim=-1).permute(0, 3, 1, 2).contiguous() + + expected_output_shape = (batch_size, self.out_channels, h, w) + if corr_features.shape != expected_output_shape: + raise ValueError( + f"Output shape of index pyramid is incorrect. 
Should be {expected_output_shape}, got {corr_features.shape}" + ) + + return corr_features + + def _compute_corr_volume(self, fmap1, fmap2): + batch_size, num_channels, h, w = fmap1.shape + fmap1 = fmap1.view(batch_size, num_channels, h * w) + fmap2 = fmap2.view(batch_size, num_channels, h * w) + + corr = torch.matmul(fmap1.transpose(1, 2), fmap2) + corr = corr.view(batch_size, h, w, 1, h, w) + return corr / torch.sqrt(torch.tensor(num_channels)) + + +class RAFT(nn.Module): + def __init__(self, *, feature_encoder, context_encoder, corr_block, update_block, mask_predictor=None): + """RAFT model from + `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. + + args: + feature_encoder (nn.Module): The feature encoder. It must downsample the input by 8. + Its input is the concatenation of ``image1`` and ``image2``. + context_encoder (nn.Module): The context encoder. It must downsample the input by 8. + Its input is ``image1``. As in the original implementation, its output will be split into 2 parts: + + - one part will be used as the actual "context", passed to the recurrent unit of the ``update_block`` + - one part will be used to initialize the hidden state of the recurrent unit of + the ``update_block`` + + These 2 parts are split according to the ``hidden_state_size`` of the ``update_block``, so the output + of the ``context_encoder`` must be strictly greater than ``hidden_state_size``. + + corr_block (nn.Module): The correlation block, which creates a correlation pyramid from the output of the + ``feature_encoder``, and then indexes from this pyramid to create correlation features. It must expose + 2 methods: + + - a ``build_pyramid`` method that takes ``feature_map_1`` and ``feature_map_2`` as input (these are the + output of the ``feature_encoder``). + - a ``index_pyramid`` method that takes the coordinates of the centroid pixels as input, and returns + the correlation features. See paper section 3.2. + + It must expose an ``out_channels`` attribute. + + update_block (nn.Module): The update block, which contains the motion encoder, the recurrent unit, and the + flow head. It takes as input the hidden state of its recurrent unit, the context, the correlation + features, and the current predicted flow. It outputs an updated hidden state, and the ``delta_flow`` + prediction (see paper appendix A). It must expose a ``hidden_state_size`` attribute. + mask_predictor (nn.Module, optional): Predicts the mask that will be used to upsample the predicted flow. + The output channel must be 8 * 8 * 9 - see paper section 3.3, and Appendix B. + If ``None`` (default), the flow is upsampled using interpolation. 
+ """ + super().__init__() + _log_api_usage_once(self) + + self.feature_encoder = feature_encoder + self.context_encoder = context_encoder + self.corr_block = corr_block + self.update_block = update_block + + self.mask_predictor = mask_predictor + + if not hasattr(self.update_block, "hidden_state_size"): + raise ValueError("The update_block parameter should expose a 'hidden_state_size' attribute.") + + def forward(self, image1, image2, num_flow_updates: int = 12): + + batch_size, _, h, w = image1.shape + if (h, w) != image2.shape[-2:]: + raise ValueError(f"input images should have the same shape, instead got ({h}, {w}) != {image2.shape[-2:]}") + if not (h % 8 == 0) and (w % 8 == 0): + raise ValueError(f"input image H and W should be divisible by 8, instead got {h} (h) and {w} (w)") + + fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0)) + fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0) + if fmap1.shape[-2:] != (h // 8, w // 8): + raise ValueError("The feature encoder should downsample H and W by 8") + + self.corr_block.build_pyramid(fmap1, fmap2) + + context_out = self.context_encoder(image1) + if context_out.shape[-2:] != (h // 8, w // 8): + raise ValueError("The context encoder should downsample H and W by 8") + + # As in the original paper, the actual output of the context encoder is split in 2 parts: + # - one part is used to initialize the hidden state of the recurent units of the update block + # - the rest is the "actual" context. + hidden_state_size = self.update_block.hidden_state_size + out_channels_context = context_out.shape[1] - hidden_state_size + if out_channels_context <= 0: + raise ValueError( + f"The context encoder outputs {context_out.shape[1]} channels, but it should have at strictly more than hidden_state={hidden_state_size} channels" + ) + hidden_state, context = torch.split(context_out, [hidden_state_size, out_channels_context], dim=1) + hidden_state = torch.tanh(hidden_state) + context = F.relu(context) + + coords0 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device) + coords1 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device) + + flow_predictions = [] + for _ in range(num_flow_updates): + coords1 = coords1.detach() # Don't backpropagate gradients through this branch, see paper + corr_features = self.corr_block.index_pyramid(centroids_coords=coords1) + + flow = coords1 - coords0 + hidden_state, delta_flow = self.update_block(hidden_state, context, corr_features, flow) + + coords1 = coords1 + delta_flow + + up_mask = None if self.mask_predictor is None else self.mask_predictor(hidden_state) + upsampled_flow = upsample_flow(flow=(coords1 - coords0), up_mask=up_mask) + flow_predictions.append(upsampled_flow) + + return flow_predictions + + +_COMMON_META = { + "min_size": (128, 128), +} + + +class Raft_Large_Weights(WeightsEnum): + """The metrics reported here are as follows. + + ``epe`` is the "end-point-error" and indicates how far (in pixels) the + predicted flow is from its true value. This is averaged over all pixels + of all images. ``per_image_epe`` is similar, but the average is different: + the epe is first computed on each image independently, and then averaged + over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe") + in the original paper, and it's only used on Kitti. ``fl-all`` is also a + Kitti-specific metric, defined by the author of the dataset and used for the + Kitti leaderboard. It corresponds to the average of pixels whose epe is + either <3px, or <5% of flow's 2-norm. 
+ """ + + C_T_V1 = Weights( + # Weights ported from https://github.com/princeton-vl/RAFT + url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/princeton-vl/RAFT", + "_metrics": { + "Sintel-Train-Cleanpass": {"epe": 1.4411}, + "Sintel-Train-Finalpass": {"epe": 2.7894}, + "Kitti-Train": {"per_image_epe": 5.0172, "fl_all": 17.4506}, + }, + "_ops": 211.007, + "_file_size": 20.129, + "_docs": """These weights were ported from the original paper. They + are trained on :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D`.""", + }, + ) + + C_T_V2 = Weights( + url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + "_metrics": { + "Sintel-Train-Cleanpass": {"epe": 1.3822}, + "Sintel-Train-Finalpass": {"epe": 2.7161}, + "Kitti-Train": {"per_image_epe": 4.5118, "fl_all": 16.0679}, + }, + "_ops": 211.007, + "_file_size": 20.129, + "_docs": """These weights were trained from scratch on + :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D`.""", + }, + ) + + C_T_SKHT_V1 = Weights( + # Weights ported from https://github.com/princeton-vl/RAFT + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/princeton-vl/RAFT", + "_metrics": { + "Sintel-Test-Cleanpass": {"epe": 1.94}, + "Sintel-Test-Finalpass": {"epe": 3.18}, + }, + "_ops": 211.007, + "_file_size": 20.129, + "_docs": """ + These weights were ported from the original paper. They are + trained on :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D` and fine-tuned on + Sintel. The Sintel fine-tuning step is a combination of + :class:`~torchvision.datasets.Sintel`, + :class:`~torchvision.datasets.KittiFlow`, + :class:`~torchvision.datasets.HD1K`, and + :class:`~torchvision.datasets.FlyingThings3D` (clean pass). + """, + }, + ) + + C_T_SKHT_V2 = Weights( + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + "_metrics": { + "Sintel-Test-Cleanpass": {"epe": 1.819}, + "Sintel-Test-Finalpass": {"epe": 3.067}, + }, + "_ops": 211.007, + "_file_size": 20.129, + "_docs": """ + These weights were trained from scratch. They are + pre-trained on :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D` and then + fine-tuned on Sintel. The Sintel fine-tuning step is a + combination of :class:`~torchvision.datasets.Sintel`, + :class:`~torchvision.datasets.KittiFlow`, + :class:`~torchvision.datasets.HD1K`, and + :class:`~torchvision.datasets.FlyingThings3D` (clean pass). 
+ """, + }, + ) + + C_T_SKHT_K_V1 = Weights( + # Weights ported from https://github.com/princeton-vl/RAFT + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/princeton-vl/RAFT", + "_metrics": { + "Kitti-Test": {"fl_all": 5.10}, + }, + "_ops": 211.007, + "_file_size": 20.129, + "_docs": """ + These weights were ported from the original paper. They are + pre-trained on :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D`, + fine-tuned on Sintel, and then fine-tuned on + :class:`~torchvision.datasets.KittiFlow`. The Sintel fine-tuning + step was described above. + """, + }, + ) + + C_T_SKHT_K_V2 = Weights( + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + "_metrics": { + "Kitti-Test": {"fl_all": 5.19}, + }, + "_ops": 211.007, + "_file_size": 20.129, + "_docs": """ + These weights were trained from scratch. They are + pre-trained on :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D`, + fine-tuned on Sintel, and then fine-tuned on + :class:`~torchvision.datasets.KittiFlow`. The Sintel fine-tuning + step was described above. + """, + }, + ) + + DEFAULT = C_T_SKHT_V2 + + +class Raft_Small_Weights(WeightsEnum): + """The metrics reported here are as follows. + + ``epe`` is the "end-point-error" and indicates how far (in pixels) the + predicted flow is from its true value. This is averaged over all pixels + of all images. ``per_image_epe`` is similar, but the average is different: + the epe is first computed on each image independently, and then averaged + over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe") + in the original paper, and it's only used on Kitti. ``fl-all`` is also a + Kitti-specific metric, defined by the author of the dataset and used for the + Kitti leaderboard. It corresponds to the average of pixels whose epe is + either <3px, or <5% of flow's 2-norm. + """ + + C_T_V1 = Weights( + # Weights ported from https://github.com/princeton-vl/RAFT + url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 990162, + "recipe": "https://github.com/princeton-vl/RAFT", + "_metrics": { + "Sintel-Train-Cleanpass": {"epe": 2.1231}, + "Sintel-Train-Finalpass": {"epe": 3.2790}, + "Kitti-Train": {"per_image_epe": 7.6557, "fl_all": 25.2801}, + }, + "_ops": 47.655, + "_file_size": 3.821, + "_docs": """These weights were ported from the original paper. 
They + are trained on :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D`.""", + }, + ) + C_T_V2 = Weights( + url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 990162, + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + "_metrics": { + "Sintel-Train-Cleanpass": {"epe": 1.9901}, + "Sintel-Train-Finalpass": {"epe": 3.2831}, + "Kitti-Train": {"per_image_epe": 7.5978, "fl_all": 25.2369}, + }, + "_ops": 47.655, + "_file_size": 3.821, + "_docs": """These weights were trained from scratch on + :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D`.""", + }, + ) + + DEFAULT = C_T_V2 + + +def _raft( + *, + weights=None, + progress=False, + # Feature encoder + feature_encoder_layers, + feature_encoder_block, + feature_encoder_norm_layer, + # Context encoder + context_encoder_layers, + context_encoder_block, + context_encoder_norm_layer, + # Correlation block + corr_block_num_levels, + corr_block_radius, + # Motion encoder + motion_encoder_corr_layers, + motion_encoder_flow_layers, + motion_encoder_out_channels, + # Recurrent block + recurrent_block_hidden_state_size, + recurrent_block_kernel_size, + recurrent_block_padding, + # Flow Head + flow_head_hidden_size, + # Mask predictor + use_mask_predictor, + **kwargs, +): + feature_encoder = kwargs.pop("feature_encoder", None) or FeatureEncoder( + block=feature_encoder_block, layers=feature_encoder_layers, norm_layer=feature_encoder_norm_layer + ) + context_encoder = kwargs.pop("context_encoder", None) or FeatureEncoder( + block=context_encoder_block, layers=context_encoder_layers, norm_layer=context_encoder_norm_layer + ) + + corr_block = kwargs.pop("corr_block", None) or CorrBlock(num_levels=corr_block_num_levels, radius=corr_block_radius) + + update_block = kwargs.pop("update_block", None) + if update_block is None: + motion_encoder = MotionEncoder( + in_channels_corr=corr_block.out_channels, + corr_layers=motion_encoder_corr_layers, + flow_layers=motion_encoder_flow_layers, + out_channels=motion_encoder_out_channels, + ) + + # See comments in forward pass of RAFT class about why we split the output of the context encoder + out_channels_context = context_encoder_layers[-1] - recurrent_block_hidden_state_size + recurrent_block = RecurrentBlock( + input_size=motion_encoder.out_channels + out_channels_context, + hidden_size=recurrent_block_hidden_state_size, + kernel_size=recurrent_block_kernel_size, + padding=recurrent_block_padding, + ) + + flow_head = FlowHead(in_channels=recurrent_block_hidden_state_size, hidden_size=flow_head_hidden_size) + + update_block = UpdateBlock(motion_encoder=motion_encoder, recurrent_block=recurrent_block, flow_head=flow_head) + + mask_predictor = kwargs.pop("mask_predictor", None) + if mask_predictor is None and use_mask_predictor: + mask_predictor = MaskPredictor( + in_channels=recurrent_block_hidden_state_size, + hidden_size=256, + multiplier=0.25, # See comment in MaskPredictor about this + ) + + model = RAFT( + feature_encoder=feature_encoder, + context_encoder=context_encoder, + corr_block=corr_block, + update_block=update_block, + mask_predictor=mask_predictor, + **kwargs, # not really needed, all params should be consumed by now + ) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +@register_model() 
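+# A minimal usage sketch (illustrative only, relying on the weights' bundled transforms for
+# preprocessing):
+#
+#     weights = Raft_Large_Weights.DEFAULT
+#     img1, img2 = weights.transforms()(img1, img2)   # float batches (N, 3, H, W), H and W divisible by 8
+#     model = raft_large(weights=weights).eval()
+#     flow = model(img1, img2)[-1]                     # last prediction is the most refined (N, 2, H, W) flow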
+@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_SKHT_V2)) +def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs) -> RAFT: + """RAFT model from + `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. + + Please see the example below for a tutorial on how to use this model. + + Args: + weights(:class:`~torchvision.models.optical_flow.Raft_Large_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.optical_flow.Raft_Large_Weights` + below for more details, and possible values. By default, no + pre-trained weights are used. + progress (bool): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.optical_flow.Raft_Large_Weights + :members: + """ + + weights = Raft_Large_Weights.verify(weights) + + return _raft( + weights=weights, + progress=progress, + # Feature encoder + feature_encoder_layers=(64, 64, 96, 128, 256), + feature_encoder_block=ResidualBlock, + feature_encoder_norm_layer=InstanceNorm2d, + # Context encoder + context_encoder_layers=(64, 64, 96, 128, 256), + context_encoder_block=ResidualBlock, + context_encoder_norm_layer=BatchNorm2d, + # Correlation block + corr_block_num_levels=4, + corr_block_radius=4, + # Motion encoder + motion_encoder_corr_layers=(256, 192), + motion_encoder_flow_layers=(128, 64), + motion_encoder_out_channels=128, + # Recurrent block + recurrent_block_hidden_state_size=128, + recurrent_block_kernel_size=((1, 5), (5, 1)), + recurrent_block_padding=((0, 2), (2, 0)), + # Flow head + flow_head_hidden_size=256, + # Mask predictor + use_mask_predictor=True, + **kwargs, + ) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2)) +def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs) -> RAFT: + """RAFT "small" model from + `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `__. + + Please see the example below for a tutorial on how to use this model. + + Args: + weights(:class:`~torchvision.models.optical_flow.Raft_Small_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.optical_flow.Raft_Small_Weights` + below for more details, and possible values. By default, no + pre-trained weights are used. + progress (bool): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.optical_flow.Raft_Small_Weights + :members: + """ + weights = Raft_Small_Weights.verify(weights) + + return _raft( + weights=weights, + progress=progress, + # Feature encoder + feature_encoder_layers=(32, 32, 64, 96, 128), + feature_encoder_block=BottleneckBlock, + feature_encoder_norm_layer=InstanceNorm2d, + # Context encoder + context_encoder_layers=(32, 32, 64, 96, 160), + context_encoder_block=BottleneckBlock, + context_encoder_norm_layer=None, + # Correlation block + corr_block_num_levels=4, + corr_block_radius=3, + # Motion encoder + motion_encoder_corr_layers=(96,), + motion_encoder_flow_layers=(64, 32), + motion_encoder_out_channels=82, + # Recurrent block + recurrent_block_hidden_state_size=96, + recurrent_block_kernel_size=(3,), + recurrent_block_padding=(1,), + # Flow head + flow_head_hidden_size=128, + # Mask predictor + use_mask_predictor=False, + **kwargs, + ) diff --git a/lib/python3.10/site-packages/torchvision/models/quantization/__init__.py b/lib/python3.10/site-packages/torchvision/models/quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da8bbba3567b0b9110429354d89b65ec679a2fd5 --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/quantization/__init__.py @@ -0,0 +1,5 @@ +from .googlenet import * +from .inception import * +from .mobilenet import * +from .resnet import * +from .shufflenetv2 import * diff --git a/lib/python3.10/site-packages/torchvision/models/quantization/googlenet.py b/lib/python3.10/site-packages/torchvision/models/quantization/googlenet.py new file mode 100644 index 0000000000000000000000000000000000000000..30ef3356ba13108b9bdc4c90a9ab4cb7f92e445a --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/quantization/googlenet.py @@ -0,0 +1,210 @@ +import warnings +from functools import partial +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F + +from ...transforms._presets import ImageClassification +from .._api import register_model, Weights, WeightsEnum +from .._meta import _IMAGENET_CATEGORIES +from .._utils import _ovewrite_named_param, handle_legacy_interface +from ..googlenet import BasicConv2d, GoogLeNet, GoogLeNet_Weights, GoogLeNetOutputs, Inception, InceptionAux +from .utils import _fuse_modules, _replace_relu, quantize_model + + +__all__ = [ + "QuantizableGoogLeNet", + "GoogLeNet_QuantizedWeights", + "googlenet", +] + + +class QuantizableBasicConv2d(BasicConv2d): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.relu = nn.ReLU() + + def forward(self, x: Tensor) -> Tensor: + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + _fuse_modules(self, ["conv", "bn", "relu"], is_qat, inplace=True) + + +class QuantizableInception(Inception): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs) # type: ignore[misc] + self.cat = nn.quantized.FloatFunctional() + + def forward(self, x: Tensor) -> Tensor: + outputs = self._forward(x) + return self.cat.cat(outputs, 1) + + +class QuantizableInceptionAux(InceptionAux): + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs) # type: ignore[misc] + self.relu = nn.ReLU() 
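+        # The parent class applies relu functionally; re-declaring it as an nn.ReLU module here lets
+        # eager-mode quantization observe the activation (functional ops cannot be swapped or fused).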
+ + def forward(self, x: Tensor) -> Tensor: + # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14 + x = F.adaptive_avg_pool2d(x, (4, 4)) + # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4 + x = self.conv(x) + # N x 128 x 4 x 4 + x = torch.flatten(x, 1) + # N x 2048 + x = self.relu(self.fc1(x)) + # N x 1024 + x = self.dropout(x) + # N x 1024 + x = self.fc2(x) + # N x 1000 (num_classes) + + return x + + +class QuantizableGoogLeNet(GoogLeNet): + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__( # type: ignore[misc] + *args, blocks=[QuantizableBasicConv2d, QuantizableInception, QuantizableInceptionAux], **kwargs + ) + self.quant = torch.ao.quantization.QuantStub() + self.dequant = torch.ao.quantization.DeQuantStub() + + def forward(self, x: Tensor) -> GoogLeNetOutputs: + x = self._transform_input(x) + x = self.quant(x) + x, aux1, aux2 = self._forward(x) + x = self.dequant(x) + aux_defined = self.training and self.aux_logits + if torch.jit.is_scripting(): + if not aux_defined: + warnings.warn("Scripted QuantizableGoogleNet always returns GoogleNetOutputs Tuple") + return GoogLeNetOutputs(x, aux2, aux1) + else: + return self.eager_outputs(x, aux2, aux1) + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + r"""Fuse conv/bn/relu modules in googlenet model + + Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization. + Model is modified in place. Note that this operation does not change numerics + and the model after modification is in floating point + """ + + for m in self.modules(): + if type(m) is QuantizableBasicConv2d: + m.fuse_model(is_qat) + + +class GoogLeNet_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c81f6644.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + "num_params": 6624904, + "min_size": (15, 15), + "categories": _IMAGENET_CATEGORIES, + "backend": "fbgemm", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", + "unquantized": GoogLeNet_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 69.826, + "acc@5": 89.404, + } + }, + "_ops": 1.498, + "_file_size": 12.618, + "_docs": """ + These weights were produced by doing Post Training Quantization (eager mode) on top of the unquantized + weights listed below. + """, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +@register_model(name="quantized_googlenet") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: GoogLeNet_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else GoogLeNet_Weights.IMAGENET1K_V1, + ) +) +def googlenet( + *, + weights: Optional[Union[GoogLeNet_QuantizedWeights, GoogLeNet_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableGoogLeNet: + """GoogLeNet (Inception v1) model architecture from `Going Deeper with Convolutions `__. + + .. note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. + GPU inference is not yet supported. + + Args: + weights (:class:`~torchvision.models.quantization.GoogLeNet_QuantizedWeights` or :class:`~torchvision.models.GoogLeNet_Weights`, optional): The + pretrained weights for the model. 
See + :class:`~torchvision.models.quantization.GoogLeNet_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + quantize (bool, optional): If True, return a quantized version of the model. Default is False. + **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableGoogLeNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.quantization.GoogLeNet_QuantizedWeights + :members: + + .. autoclass:: torchvision.models.GoogLeNet_Weights + :members: + :noindex: + """ + weights = (GoogLeNet_QuantizedWeights if quantize else GoogLeNet_Weights).verify(weights) + + original_aux_logits = kwargs.get("aux_logits", False) + if weights is not None: + if "transform_input" not in kwargs: + _ovewrite_named_param(kwargs, "transform_input", True) + _ovewrite_named_param(kwargs, "aux_logits", True) + _ovewrite_named_param(kwargs, "init_weights", False) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "fbgemm") + + model = QuantizableGoogLeNet(**kwargs) + _replace_relu(model) + if quantize: + quantize_model(model, backend) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + if not original_aux_logits: + model.aux_logits = False + model.aux1 = None # type: ignore[assignment] + model.aux2 = None # type: ignore[assignment] + else: + warnings.warn( + "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" + ) + + return model diff --git a/lib/python3.10/site-packages/torchvision/models/quantization/inception.py b/lib/python3.10/site-packages/torchvision/models/quantization/inception.py new file mode 100644 index 0000000000000000000000000000000000000000..75c126697e99befd6ae7d3c1ee88fb8542e06d31 --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/quantization/inception.py @@ -0,0 +1,273 @@ +import warnings +from functools import partial +from typing import Any, List, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torchvision.models import inception as inception_module +from torchvision.models.inception import Inception_V3_Weights, InceptionOutputs + +from ...transforms._presets import ImageClassification +from .._api import register_model, Weights, WeightsEnum +from .._meta import _IMAGENET_CATEGORIES +from .._utils import _ovewrite_named_param, handle_legacy_interface +from .utils import _fuse_modules, _replace_relu, quantize_model + + +__all__ = [ + "QuantizableInception3", + "Inception_V3_QuantizedWeights", + "inception_v3", +] + + +class QuantizableBasicConv2d(inception_module.BasicConv2d): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.relu = nn.ReLU() + + def forward(self, x: Tensor) -> Tensor: + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + _fuse_modules(self, ["conv", "bn", "relu"], is_qat, inplace=True) + + +class QuantizableInceptionA(inception_module.InceptionA): + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def 
__init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs) # type: ignore[misc] + self.myop = nn.quantized.FloatFunctional() + + def forward(self, x: Tensor) -> Tensor: + outputs = self._forward(x) + return self.myop.cat(outputs, 1) + + +class QuantizableInceptionB(inception_module.InceptionB): + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs) # type: ignore[misc] + self.myop = nn.quantized.FloatFunctional() + + def forward(self, x: Tensor) -> Tensor: + outputs = self._forward(x) + return self.myop.cat(outputs, 1) + + +class QuantizableInceptionC(inception_module.InceptionC): + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs) # type: ignore[misc] + self.myop = nn.quantized.FloatFunctional() + + def forward(self, x: Tensor) -> Tensor: + outputs = self._forward(x) + return self.myop.cat(outputs, 1) + + +class QuantizableInceptionD(inception_module.InceptionD): + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs) # type: ignore[misc] + self.myop = nn.quantized.FloatFunctional() + + def forward(self, x: Tensor) -> Tensor: + outputs = self._forward(x) + return self.myop.cat(outputs, 1) + + +class QuantizableInceptionE(inception_module.InceptionE): + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs) # type: ignore[misc] + self.myop1 = nn.quantized.FloatFunctional() + self.myop2 = nn.quantized.FloatFunctional() + self.myop3 = nn.quantized.FloatFunctional() + + def _forward(self, x: Tensor) -> List[Tensor]: + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3)] + branch3x3 = self.myop1.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = self.myop2.cat(branch3x3dbl, 1) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x: Tensor) -> Tensor: + outputs = self._forward(x) + return self.myop3.cat(outputs, 1) + + +class QuantizableInceptionAux(inception_module.InceptionAux): + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs) # type: ignore[misc] + + +class QuantizableInception3(inception_module.Inception3): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__( # type: ignore[misc] + *args, + inception_blocks=[ + QuantizableBasicConv2d, + QuantizableInceptionA, + QuantizableInceptionB, + QuantizableInceptionC, + QuantizableInceptionD, + QuantizableInceptionE, + QuantizableInceptionAux, + ], + **kwargs, + ) + self.quant = 
torch.ao.quantization.QuantStub() + self.dequant = torch.ao.quantization.DeQuantStub() + + def forward(self, x: Tensor) -> InceptionOutputs: + x = self._transform_input(x) + x = self.quant(x) + x, aux = self._forward(x) + x = self.dequant(x) + aux_defined = self.training and self.aux_logits + if torch.jit.is_scripting(): + if not aux_defined: + warnings.warn("Scripted QuantizableInception3 always returns QuantizableInception3 Tuple") + return InceptionOutputs(x, aux) + else: + return self.eager_outputs(x, aux) + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + r"""Fuse conv/bn/relu modules in inception model + + Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization. + Model is modified in place. Note that this operation does not change numerics + and the model after modification is in floating point + """ + + for m in self.modules(): + if type(m) is QuantizableBasicConv2d: + m.fuse_model(is_qat) + + +class Inception_V3_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-a2837893.pth", + transforms=partial(ImageClassification, crop_size=299, resize_size=342), + meta={ + "num_params": 27161264, + "min_size": (75, 75), + "categories": _IMAGENET_CATEGORIES, + "backend": "fbgemm", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", + "unquantized": Inception_V3_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 77.176, + "acc@5": 93.354, + } + }, + "_ops": 5.713, + "_file_size": 23.146, + "_docs": """ + These weights were produced by doing Post Training Quantization (eager mode) on top of the unquantized + weights listed below. + """, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +@register_model(name="quantized_inception_v3") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: Inception_V3_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else Inception_V3_Weights.IMAGENET1K_V1, + ) +) +def inception_v3( + *, + weights: Optional[Union[Inception_V3_QuantizedWeights, Inception_V3_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableInception3: + r"""Inception v3 model architecture from + `Rethinking the Inception Architecture for Computer Vision `__. + + .. note:: + **Important**: In contrast to the other models the inception_v3 expects tensors with a size of + N x 3 x 299 x 299, so ensure your images are sized accordingly. + + .. note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. + GPU inference is not yet supported. + + Args: + weights (:class:`~torchvision.models.quantization.Inception_V3_QuantizedWeights` or :class:`~torchvision.models.Inception_V3_Weights`, optional): The pretrained + weights for the model. See + :class:`~torchvision.models.quantization.Inception_V3_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. + Default is True. + quantize (bool, optional): If True, return a quantized version of the model. + Default is False. + **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableInception3`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.quantization.Inception_V3_QuantizedWeights + :members: + + .. autoclass:: torchvision.models.Inception_V3_Weights + :members: + :noindex: + """ + weights = (Inception_V3_QuantizedWeights if quantize else Inception_V3_Weights).verify(weights) + + original_aux_logits = kwargs.get("aux_logits", False) + if weights is not None: + if "transform_input" not in kwargs: + _ovewrite_named_param(kwargs, "transform_input", True) + _ovewrite_named_param(kwargs, "aux_logits", True) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "fbgemm") + + model = QuantizableInception3(**kwargs) + _replace_relu(model) + if quantize: + quantize_model(model, backend) + + if weights is not None: + if quantize and not original_aux_logits: + model.aux_logits = False + model.AuxLogits = None + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + if not quantize and not original_aux_logits: + model.aux_logits = False + model.AuxLogits = None + + return model diff --git a/lib/python3.10/site-packages/torchvision/models/quantization/mobilenet.py b/lib/python3.10/site-packages/torchvision/models/quantization/mobilenet.py new file mode 100644 index 0000000000000000000000000000000000000000..0a270d14d3a4ad9eda62b68c2c01e9fdb710ef38 --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/quantization/mobilenet.py @@ -0,0 +1,6 @@ +from .mobilenetv2 import * # noqa: F401, F403 +from .mobilenetv3 import * # noqa: F401, F403 +from .mobilenetv2 import __all__ as mv2_all +from .mobilenetv3 import __all__ as mv3_all + +__all__ = mv2_all + mv3_all diff --git a/lib/python3.10/site-packages/torchvision/models/quantization/mobilenetv2.py b/lib/python3.10/site-packages/torchvision/models/quantization/mobilenetv2.py new file mode 100644 index 0000000000000000000000000000000000000000..4700bb4af931072f1aee3403c1e8c461ec33c76d --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/quantization/mobilenetv2.py @@ -0,0 +1,154 @@ +from functools import partial +from typing import Any, Optional, Union + +from torch import nn, Tensor +from torch.ao.quantization import DeQuantStub, QuantStub +from torchvision.models.mobilenetv2 import InvertedResidual, MobileNet_V2_Weights, MobileNetV2 + +from ...ops.misc import Conv2dNormActivation +from ...transforms._presets import ImageClassification +from .._api import register_model, Weights, WeightsEnum +from .._meta import _IMAGENET_CATEGORIES +from .._utils import _ovewrite_named_param, handle_legacy_interface +from .utils import _fuse_modules, _replace_relu, quantize_model + + +__all__ = [ + "QuantizableMobileNetV2", + "MobileNet_V2_QuantizedWeights", + "mobilenet_v2", +] + + +class QuantizableInvertedResidual(InvertedResidual): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x: Tensor) -> Tensor: + if self.use_res_connect: + return self.skip_add.add(x, self.conv(x)) + else: + return self.conv(x) + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + for idx in range(len(self.conv)): + if type(self.conv[idx]) is nn.Conv2d: + _fuse_modules(self.conv, [str(idx), str(idx + 1)], is_qat, inplace=True) + + +class QuantizableMobileNetV2(MobileNetV2): + def __init__(self, *args: Any, **kwargs: Any) -> None: + """ + MobileNet V2 main class + + 
Args: + Inherits args from floating point MobileNetV2 + """ + super().__init__(*args, **kwargs) + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x: Tensor) -> Tensor: + x = self.quant(x) + x = self._forward_impl(x) + x = self.dequant(x) + return x + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + for m in self.modules(): + if type(m) is Conv2dNormActivation: + _fuse_modules(m, ["0", "1", "2"], is_qat, inplace=True) + if type(m) is QuantizableInvertedResidual: + m.fuse_model(is_qat) + + +class MobileNet_V2_QuantizedWeights(WeightsEnum): + IMAGENET1K_QNNPACK_V1 = Weights( + url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + "num_params": 3504872, + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "backend": "qnnpack", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2", + "unquantized": MobileNet_V2_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 71.658, + "acc@5": 90.150, + } + }, + "_ops": 0.301, + "_file_size": 3.423, + "_docs": """ + These weights were produced by doing Quantization Aware Training (eager mode) on top of the unquantized + weights listed below. + """, + }, + ) + DEFAULT = IMAGENET1K_QNNPACK_V1 + + +@register_model(name="quantized_mobilenet_v2") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1 + if kwargs.get("quantize", False) + else MobileNet_V2_Weights.IMAGENET1K_V1, + ) +) +def mobilenet_v2( + *, + weights: Optional[Union[MobileNet_V2_QuantizedWeights, MobileNet_V2_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableMobileNetV2: + """ + Constructs a MobileNetV2 architecture from + `MobileNetV2: Inverted Residuals and Linear Bottlenecks + `_. + + .. note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. + GPU inference is not yet supported. + + Args: + weights (:class:`~torchvision.models.quantization.MobileNet_V2_QuantizedWeights` or :class:`~torchvision.models.MobileNet_V2_Weights`, optional): The + pretrained weights for the model. See + :class:`~torchvision.models.quantization.MobileNet_V2_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + quantize (bool, optional): If True, returns a quantized version of the model. Default is False. + **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableMobileNetV2`` + base class. Please refer to the `source code + `_ + for more details about this class. + .. autoclass:: torchvision.models.quantization.MobileNet_V2_QuantizedWeights + :members: + .. 
autoclass:: torchvision.models.MobileNet_V2_Weights + :members: + :noindex: + """ + weights = (MobileNet_V2_QuantizedWeights if quantize else MobileNet_V2_Weights).verify(weights) + + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "qnnpack") + + model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs) + _replace_relu(model) + if quantize: + quantize_model(model, backend) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model diff --git a/lib/python3.10/site-packages/torchvision/models/quantization/mobilenetv3.py b/lib/python3.10/site-packages/torchvision/models/quantization/mobilenetv3.py new file mode 100644 index 0000000000000000000000000000000000000000..f1fdcfec9570d35683efb10344e667d3f4487fce --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/quantization/mobilenetv3.py @@ -0,0 +1,237 @@ +from functools import partial +from typing import Any, List, Optional, Union + +import torch +from torch import nn, Tensor +from torch.ao.quantization import DeQuantStub, QuantStub + +from ...ops.misc import Conv2dNormActivation, SqueezeExcitation +from ...transforms._presets import ImageClassification +from .._api import register_model, Weights, WeightsEnum +from .._meta import _IMAGENET_CATEGORIES +from .._utils import _ovewrite_named_param, handle_legacy_interface +from ..mobilenetv3 import ( + _mobilenet_v3_conf, + InvertedResidual, + InvertedResidualConfig, + MobileNet_V3_Large_Weights, + MobileNetV3, +) +from .utils import _fuse_modules, _replace_relu + + +__all__ = [ + "QuantizableMobileNetV3", + "MobileNet_V3_Large_QuantizedWeights", + "mobilenet_v3_large", +] + + +class QuantizableSqueezeExcitation(SqueezeExcitation): + _version = 2 + + def __init__(self, *args: Any, **kwargs: Any) -> None: + kwargs["scale_activation"] = nn.Hardsigmoid + super().__init__(*args, **kwargs) + self.skip_mul = nn.quantized.FloatFunctional() + + def forward(self, input: Tensor) -> Tensor: + return self.skip_mul.mul(self._scale(input), input) + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + _fuse_modules(self, ["fc1", "activation"], is_qat, inplace=True) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + + if hasattr(self, "qconfig") and (version is None or version < 2): + default_state_dict = { + "scale_activation.activation_post_process.scale": torch.tensor([1.0]), + "scale_activation.activation_post_process.activation_post_process.scale": torch.tensor([1.0]), + "scale_activation.activation_post_process.zero_point": torch.tensor([0], dtype=torch.int32), + "scale_activation.activation_post_process.activation_post_process.zero_point": torch.tensor( + [0], dtype=torch.int32 + ), + "scale_activation.activation_post_process.fake_quant_enabled": torch.tensor([1]), + "scale_activation.activation_post_process.observer_enabled": torch.tensor([1]), + } + for k, v in default_state_dict.items(): + full_key = prefix + k + if full_key not in state_dict: + state_dict[full_key] = v + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +class QuantizableInvertedResidual(InvertedResidual): 
+ # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, se_layer=QuantizableSqueezeExcitation, **kwargs) # type: ignore[misc] + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x: Tensor) -> Tensor: + if self.use_res_connect: + return self.skip_add.add(x, self.block(x)) + else: + return self.block(x) + + +class QuantizableMobileNetV3(MobileNetV3): + def __init__(self, *args: Any, **kwargs: Any) -> None: + """ + MobileNet V3 main class + + Args: + Inherits args from floating point MobileNetV3 + """ + super().__init__(*args, **kwargs) + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x: Tensor) -> Tensor: + x = self.quant(x) + x = self._forward_impl(x) + x = self.dequant(x) + return x + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + for m in self.modules(): + if type(m) is Conv2dNormActivation: + modules_to_fuse = ["0", "1"] + if len(m) == 3 and type(m[2]) is nn.ReLU: + modules_to_fuse.append("2") + _fuse_modules(m, modules_to_fuse, is_qat, inplace=True) + elif type(m) is QuantizableSqueezeExcitation: + m.fuse_model(is_qat) + + +def _mobilenet_v3_model( + inverted_residual_setting: List[InvertedResidualConfig], + last_channel: int, + weights: Optional[WeightsEnum], + progress: bool, + quantize: bool, + **kwargs: Any, +) -> QuantizableMobileNetV3: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "qnnpack") + + model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs) + _replace_relu(model) + + if quantize: + # Instead of quantizing the model and then loading the quantized weights we take a different approach. + # We prepare the QAT model, load the QAT weights from training and then convert it. + # This is done to avoid extremely low accuracies observed on the specific model. This is rather a workaround + # for an unresolved bug on the eager quantization API detailed at: https://github.com/pytorch/vision/issues/5890 + model.fuse_model(is_qat=True) + model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend) + torch.ao.quantization.prepare_qat(model, inplace=True) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + if quantize: + torch.ao.quantization.convert(model, inplace=True) + model.eval() + + return model + + +class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): + IMAGENET1K_QNNPACK_V1 = Weights( + url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + "num_params": 5483032, + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "backend": "qnnpack", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3", + "unquantized": MobileNet_V3_Large_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 73.004, + "acc@5": 90.858, + } + }, + "_ops": 0.217, + "_file_size": 21.554, + "_docs": """ + These weights were produced by doing Quantization Aware Training (eager mode) on top of the unquantized + weights listed below. 
+ """, + }, + ) + DEFAULT = IMAGENET1K_QNNPACK_V1 + + +@register_model(name="quantized_mobilenet_v3_large") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1 + if kwargs.get("quantize", False) + else MobileNet_V3_Large_Weights.IMAGENET1K_V1, + ) +) +def mobilenet_v3_large( + *, + weights: Optional[Union[MobileNet_V3_Large_QuantizedWeights, MobileNet_V3_Large_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableMobileNetV3: + """ + MobileNetV3 (Large) model from + `Searching for MobileNetV3 `_. + + .. note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. + GPU inference is not yet supported. + + Args: + weights (:class:`~torchvision.models.quantization.MobileNet_V3_Large_QuantizedWeights` or :class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The + pretrained weights for the model. See + :class:`~torchvision.models.quantization.MobileNet_V3_Large_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool): If True, displays a progress bar of the + download to stderr. Default is True. + quantize (bool): If True, return a quantized version of the model. Default is False. + **kwargs: parameters passed to the ``torchvision.models.quantization.MobileNet_V3_Large_QuantizedWeights`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.quantization.MobileNet_V3_Large_QuantizedWeights + :members: + .. autoclass:: torchvision.models.MobileNet_V3_Large_Weights + :members: + :noindex: + """ + weights = (MobileNet_V3_Large_QuantizedWeights if quantize else MobileNet_V3_Large_Weights).verify(weights) + + inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) + return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs) diff --git a/lib/python3.10/site-packages/torchvision/models/quantization/resnet.py b/lib/python3.10/site-packages/torchvision/models/quantization/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..39958a010fbd335709bc77a1aaf26c996584a398 --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/quantization/resnet.py @@ -0,0 +1,484 @@ +from functools import partial +from typing import Any, List, Optional, Type, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torchvision.models.resnet import ( + BasicBlock, + Bottleneck, + ResNet, + ResNet18_Weights, + ResNet50_Weights, + ResNeXt101_32X8D_Weights, + ResNeXt101_64X4D_Weights, +) + +from ...transforms._presets import ImageClassification +from .._api import register_model, Weights, WeightsEnum +from .._meta import _IMAGENET_CATEGORIES +from .._utils import _ovewrite_named_param, handle_legacy_interface +from .utils import _fuse_modules, _replace_relu, quantize_model + + +__all__ = [ + "QuantizableResNet", + "ResNet18_QuantizedWeights", + "ResNet50_QuantizedWeights", + "ResNeXt101_32X8D_QuantizedWeights", + "ResNeXt101_64X4D_QuantizedWeights", + "resnet18", + "resnet50", + "resnext101_32x8d", + "resnext101_64x4d", +] + + +class QuantizableBasicBlock(BasicBlock): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.add_relu = torch.nn.quantized.FloatFunctional() + + def 
forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out = self.add_relu.add_relu(out, identity) + + return out + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + _fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], is_qat, inplace=True) + if self.downsample: + _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True) + + +class QuantizableBottleneck(Bottleneck): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.skip_add_relu = nn.quantized.FloatFunctional() + self.relu1 = nn.ReLU(inplace=False) + self.relu2 = nn.ReLU(inplace=False) + + def forward(self, x: Tensor) -> Tensor: + identity = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu1(out) + out = self.conv2(out) + out = self.bn2(out) + out = self.relu2(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + out = self.skip_add_relu.add_relu(out, identity) + + return out + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + _fuse_modules( + self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], is_qat, inplace=True + ) + if self.downsample: + _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True) + + +class QuantizableResNet(ResNet): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + self.quant = torch.ao.quantization.QuantStub() + self.dequant = torch.ao.quantization.DeQuantStub() + + def forward(self, x: Tensor) -> Tensor: + x = self.quant(x) + # Ensure scriptability + # super(QuantizableResNet,self).forward(x) + # is not scriptable + x = self._forward_impl(x) + x = self.dequant(x) + return x + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + r"""Fuse conv/bn/relu modules in resnet models + + Fuse conv+bn+relu/ Conv+relu/conv+Bn modules to prepare for quantization. + Model is modified in place. Note that this operation does not change numerics + and the model after modification is in floating point + """ + _fuse_modules(self, ["conv1", "bn1", "relu"], is_qat, inplace=True) + for m in self.modules(): + if type(m) is QuantizableBottleneck or type(m) is QuantizableBasicBlock: + m.fuse_model(is_qat) + + +def _resnet( + block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]], + layers: List[int], + weights: Optional[WeightsEnum], + progress: bool, + quantize: bool, + **kwargs: Any, +) -> QuantizableResNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "fbgemm") + + model = QuantizableResNet(block, layers, **kwargs) + _replace_relu(model) + if quantize: + quantize_model(model, backend) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +_COMMON_META = { + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "backend": "fbgemm", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", + "_docs": """ + These weights were produced by doing Post Training Quantization (eager mode) on top of the unquantized + weights listed below. 
+ """, +} + + +class ResNet18_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 11689512, + "unquantized": ResNet18_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 69.494, + "acc@5": 88.882, + } + }, + "_ops": 1.814, + "_file_size": 11.238, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +class ResNet50_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 25557032, + "unquantized": ResNet50_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 75.920, + "acc@5": 92.814, + } + }, + "_ops": 4.089, + "_file_size": 24.759, + }, + ) + IMAGENET1K_FBGEMM_V2 = Weights( + url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 25557032, + "unquantized": ResNet50_Weights.IMAGENET1K_V2, + "_metrics": { + "ImageNet-1K": { + "acc@1": 80.282, + "acc@5": 94.976, + } + }, + "_ops": 4.089, + "_file_size": 24.953, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V2 + + +class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 88791336, + "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 78.986, + "acc@5": 94.480, + } + }, + "_ops": 16.414, + "_file_size": 86.034, + }, + ) + IMAGENET1K_FBGEMM_V2 = Weights( + url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 88791336, + "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V2, + "_metrics": { + "ImageNet-1K": { + "acc@1": 82.574, + "acc@5": 96.132, + } + }, + "_ops": 16.414, + "_file_size": 86.645, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V2 + + +class ResNeXt101_64X4D_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/resnext101_64x4d_fbgemm-605a1cb3.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 83455272, + "recipe": "https://github.com/pytorch/vision/pull/5935", + "unquantized": ResNeXt101_64X4D_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 82.898, + "acc@5": 96.326, + } + }, + "_ops": 15.46, + "_file_size": 81.556, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +@register_model(name="quantized_resnet18") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ResNet18_Weights.IMAGENET1K_V1, + ) +) +def resnet18( + *, + weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableResNet: + """ResNet-18 model from + `Deep Residual Learning for Image Recognition `_ + + .. 
note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. + GPU inference is not yet supported. + + Args: + weights (:class:`~torchvision.models.quantization.ResNet18_QuantizedWeights` or :class:`~torchvision.models.ResNet18_Weights`, optional): The + pretrained weights for the model. See + :class:`~torchvision.models.quantization.ResNet18_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + quantize (bool, optional): If True, return a quantized version of the model. Default is False. + **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.quantization.ResNet18_QuantizedWeights + :members: + + .. autoclass:: torchvision.models.ResNet18_Weights + :members: + :noindex: + """ + weights = (ResNet18_QuantizedWeights if quantize else ResNet18_Weights).verify(weights) + + return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs) + + +@register_model(name="quantized_resnet50") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ResNet50_Weights.IMAGENET1K_V1, + ) +) +def resnet50( + *, + weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableResNet: + """ResNet-50 model from + `Deep Residual Learning for Image Recognition `_ + + .. note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. + GPU inference is not yet supported. + + Args: + weights (:class:`~torchvision.models.quantization.ResNet50_QuantizedWeights` or :class:`~torchvision.models.ResNet50_Weights`, optional): The + pretrained weights for the model. See + :class:`~torchvision.models.quantization.ResNet50_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + quantize (bool, optional): If True, return a quantized version of the model. Default is False. + **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.quantization.ResNet50_QuantizedWeights + :members: + + .. 
autoclass:: torchvision.models.ResNet50_Weights + :members: + :noindex: + """ + weights = (ResNet50_QuantizedWeights if quantize else ResNet50_Weights).verify(weights) + + return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs) + + +@register_model(name="quantized_resnext101_32x8d") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ResNeXt101_32X8D_Weights.IMAGENET1K_V1, + ) +) +def resnext101_32x8d( + *, + weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableResNet: + """ResNeXt-101 32x8d model from + `Aggregated Residual Transformation for Deep Neural Networks `_ + + .. note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. + GPU inference is not yet supported. + + Args: + weights (:class:`~torchvision.models.quantization.ResNeXt101_32X8D_QuantizedWeights` or :class:`~torchvision.models.ResNeXt101_32X8D_Weights`, optional): The + pretrained weights for the model. See + :class:`~torchvision.models.quantization.ResNet101_32X8D_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + quantize (bool, optional): If True, return a quantized version of the model. Default is False. + **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.quantization.ResNeXt101_32X8D_QuantizedWeights + :members: + + .. autoclass:: torchvision.models.ResNeXt101_32X8D_Weights + :members: + :noindex: + """ + weights = (ResNeXt101_32X8D_QuantizedWeights if quantize else ResNeXt101_32X8D_Weights).verify(weights) + + _ovewrite_named_param(kwargs, "groups", 32) + _ovewrite_named_param(kwargs, "width_per_group", 8) + return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs) + + +@register_model(name="quantized_resnext101_64x4d") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ResNeXt101_64X4D_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ResNeXt101_64X4D_Weights.IMAGENET1K_V1, + ) +) +def resnext101_64x4d( + *, + weights: Optional[Union[ResNeXt101_64X4D_QuantizedWeights, ResNeXt101_64X4D_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableResNet: + """ResNeXt-101 64x4d model from + `Aggregated Residual Transformation for Deep Neural Networks `_ + + .. note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. + GPU inference is not yet supported. + + Args: + weights (:class:`~torchvision.models.quantization.ResNeXt101_64X4D_QuantizedWeights` or :class:`~torchvision.models.ResNeXt101_64X4D_Weights`, optional): The + pretrained weights for the model. See + :class:`~torchvision.models.quantization.ResNet101_64X4D_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. 
Default is True. + quantize (bool, optional): If True, return a quantized version of the model. Default is False. + **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.quantization.ResNeXt101_64X4D_QuantizedWeights + :members: + + .. autoclass:: torchvision.models.ResNeXt101_64X4D_Weights + :members: + :noindex: + """ + weights = (ResNeXt101_64X4D_QuantizedWeights if quantize else ResNeXt101_64X4D_Weights).verify(weights) + + _ovewrite_named_param(kwargs, "groups", 64) + _ovewrite_named_param(kwargs, "width_per_group", 4) + return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs) diff --git a/lib/python3.10/site-packages/torchvision/models/quantization/shufflenetv2.py b/lib/python3.10/site-packages/torchvision/models/quantization/shufflenetv2.py new file mode 100644 index 0000000000000000000000000000000000000000..3e1b01356a74b8e4f16d66811060be698cfed199 --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/quantization/shufflenetv2.py @@ -0,0 +1,427 @@ +from functools import partial +from typing import Any, List, Optional, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torchvision.models import shufflenetv2 + +from ...transforms._presets import ImageClassification +from .._api import register_model, Weights, WeightsEnum +from .._meta import _IMAGENET_CATEGORIES +from .._utils import _ovewrite_named_param, handle_legacy_interface +from ..shufflenetv2 import ( + ShuffleNet_V2_X0_5_Weights, + ShuffleNet_V2_X1_0_Weights, + ShuffleNet_V2_X1_5_Weights, + ShuffleNet_V2_X2_0_Weights, +) +from .utils import _fuse_modules, _replace_relu, quantize_model + + +__all__ = [ + "QuantizableShuffleNetV2", + "ShuffleNet_V2_X0_5_QuantizedWeights", + "ShuffleNet_V2_X1_0_QuantizedWeights", + "ShuffleNet_V2_X1_5_QuantizedWeights", + "ShuffleNet_V2_X2_0_QuantizedWeights", + "shufflenet_v2_x0_5", + "shufflenet_v2_x1_0", + "shufflenet_v2_x1_5", + "shufflenet_v2_x2_0", +] + + +class QuantizableInvertedResidual(shufflenetv2.InvertedResidual): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.cat = nn.quantized.FloatFunctional() + + def forward(self, x: Tensor) -> Tensor: + if self.stride == 1: + x1, x2 = x.chunk(2, dim=1) + out = self.cat.cat([x1, self.branch2(x2)], dim=1) + else: + out = self.cat.cat([self.branch1(x), self.branch2(x)], dim=1) + + out = shufflenetv2.channel_shuffle(out, 2) + + return out + + +class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2): + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, inverted_residual=QuantizableInvertedResidual, **kwargs) # type: ignore[misc] + self.quant = torch.ao.quantization.QuantStub() + self.dequant = torch.ao.quantization.DeQuantStub() + + def forward(self, x: Tensor) -> Tensor: + x = self.quant(x) + x = self._forward_impl(x) + x = self.dequant(x) + return x + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + r"""Fuse conv/bn/relu modules in shufflenetv2 model + + Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization. + Model is modified in place. + + .. 
note:: + Note that this operation does not change numerics + and the model after modification is in floating point + """ + for name, m in self._modules.items(): + if name in ["conv1", "conv5"] and m is not None: + _fuse_modules(m, [["0", "1", "2"]], is_qat, inplace=True) + for m in self.modules(): + if type(m) is QuantizableInvertedResidual: + if len(m.branch1._modules.items()) > 0: + _fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], is_qat, inplace=True) + _fuse_modules( + m.branch2, + [["0", "1", "2"], ["3", "4"], ["5", "6", "7"]], + is_qat, + inplace=True, + ) + + +def _shufflenetv2( + stages_repeats: List[int], + stages_out_channels: List[int], + *, + weights: Optional[WeightsEnum], + progress: bool, + quantize: bool, + **kwargs: Any, +) -> QuantizableShuffleNetV2: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "fbgemm") + + model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs) + _replace_relu(model) + if quantize: + quantize_model(model, backend) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +_COMMON_META = { + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "backend": "fbgemm", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", + "_docs": """ + These weights were produced by doing Post Training Quantization (eager mode) on top of the unquantized + weights listed below. + """, +} + + +class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 1366792, + "unquantized": ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 57.972, + "acc@5": 79.780, + } + }, + "_ops": 0.04, + "_file_size": 1.501, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-1e62bb32.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 2278604, + "unquantized": ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 68.360, + "acc@5": 87.582, + } + }, + "_ops": 0.145, + "_file_size": 2.334, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +class ShuffleNet_V2_X1_5_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_5_fbgemm-d7401f05.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/pull/5906", + "num_params": 3503624, + "unquantized": ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 72.052, + "acc@5": 90.700, + } + }, + "_ops": 0.296, + "_file_size": 3.672, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +class ShuffleNet_V2_X2_0_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/shufflenetv2_x2_0_fbgemm-5cac526c.pth", + 
transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/pull/5906", + "num_params": 7393996, + "unquantized": ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 75.354, + "acc@5": 92.488, + } + }, + "_ops": 0.583, + "_file_size": 7.467, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +@register_model(name="quantized_shufflenet_v2_x0_5") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ShuffleNet_V2_X0_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1, + ) +) +def shufflenet_v2_x0_5( + *, + weights: Optional[Union[ShuffleNet_V2_X0_5_QuantizedWeights, ShuffleNet_V2_X0_5_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableShuffleNetV2: + """ + Constructs a ShuffleNetV2 with 0.5x output channels, as described in + `ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design + `__. + + .. note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. + GPU inference is not yet supported. + + Args: + weights (:class:`~torchvision.models.quantization.ShuffleNet_V2_X0_5_QuantizedWeights` or :class:`~torchvision.models.ShuffleNet_V2_X0_5_Weights`, optional): The + pretrained weights for the model. See + :class:`~torchvision.models.quantization.ShuffleNet_V2_X0_5_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. + Default is True. + quantize (bool, optional): If True, return a quantized version of the model. + Default is False. + **kwargs: parameters passed to the ``torchvision.models.quantization.ShuffleNet_V2_X0_5_QuantizedWeights`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.quantization.ShuffleNet_V2_X0_5_QuantizedWeights + :members: + + .. autoclass:: torchvision.models.ShuffleNet_V2_X0_5_Weights + :members: + :noindex: + """ + weights = (ShuffleNet_V2_X0_5_QuantizedWeights if quantize else ShuffleNet_V2_X0_5_Weights).verify(weights) + return _shufflenetv2( + [4, 8, 4], [24, 48, 96, 192, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs + ) + + +@register_model(name="quantized_shufflenet_v2_x1_0") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ShuffleNet_V2_X1_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1, + ) +) +def shufflenet_v2_x1_0( + *, + weights: Optional[Union[ShuffleNet_V2_X1_0_QuantizedWeights, ShuffleNet_V2_X1_0_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableShuffleNetV2: + """ + Constructs a ShuffleNetV2 with 1.0x output channels, as described in + `ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design + `__. + + .. note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. + GPU inference is not yet supported. 
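
To make the ``quantize`` flag described in the note above concrete, here is a minimal usage sketch for the quantized ShuffleNetV2 entry point defined in this module. This is an editorial illustration, not part of the diffed file; it assumes a CPU build of torchvision where this ``torchvision.models.quantization`` package and the fbgemm backend are available, and the random input tensor is only a stand-in for a preprocessed 224x224 image.

    import torch
    from torchvision.models.quantization import (
        shufflenet_v2_x1_0,
        ShuffleNet_V2_X1_0_QuantizedWeights,
    )

    # quantize=True returns the INT8 model and loads the fbgemm post-training-quantized weights.
    model = shufflenet_v2_x1_0(
        weights=ShuffleNet_V2_X1_0_QuantizedWeights.DEFAULT, quantize=True
    )
    model.eval()  # quantized models support inference only, on CPU

    with torch.inference_mode():
        logits = model(torch.rand(1, 3, 224, 224))  # illustrative input; crop_size=224 per the weights' preset
    print(logits.shape)  # expected: torch.Size([1, 1000])

The same pattern applies to the other constructors in this file (``shufflenet_v2_x0_5``, ``shufflenet_v2_x1_5``, ``shufflenet_v2_x2_0``), each paired with its own ``*_QuantizedWeights`` enum.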
+ + Args: + weights (:class:`~torchvision.models.quantization.ShuffleNet_V2_X1_0_QuantizedWeights` or :class:`~torchvision.models.ShuffleNet_V2_X1_0_Weights`, optional): The + pretrained weights for the model. See + :class:`~torchvision.models.quantization.ShuffleNet_V2_X1_0_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. + Default is True. + quantize (bool, optional): If True, return a quantized version of the model. + Default is False. + **kwargs: parameters passed to the ``torchvision.models.quantization.ShuffleNet_V2_X1_0_QuantizedWeights`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.quantization.ShuffleNet_V2_X1_0_QuantizedWeights + :members: + + .. autoclass:: torchvision.models.ShuffleNet_V2_X1_0_Weights + :members: + :noindex: + """ + weights = (ShuffleNet_V2_X1_0_QuantizedWeights if quantize else ShuffleNet_V2_X1_0_Weights).verify(weights) + return _shufflenetv2( + [4, 8, 4], [24, 116, 232, 464, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs + ) + + +@register_model(name="quantized_shufflenet_v2_x1_5") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ShuffleNet_V2_X1_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1, + ) +) +def shufflenet_v2_x1_5( + *, + weights: Optional[Union[ShuffleNet_V2_X1_5_QuantizedWeights, ShuffleNet_V2_X1_5_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableShuffleNetV2: + """ + Constructs a ShuffleNetV2 with 1.5x output channels, as described in + `ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design + `__. + + .. note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. + GPU inference is not yet supported. + + Args: + weights (:class:`~torchvision.models.quantization.ShuffleNet_V2_X1_5_QuantizedWeights` or :class:`~torchvision.models.ShuffleNet_V2_X1_5_Weights`, optional): The + pretrained weights for the model. See + :class:`~torchvision.models.quantization.ShuffleNet_V2_X1_5_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. + Default is True. + quantize (bool, optional): If True, return a quantized version of the model. + Default is False. + **kwargs: parameters passed to the ``torchvision.models.quantization.ShuffleNet_V2_X1_5_QuantizedWeights`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.quantization.ShuffleNet_V2_X1_5_QuantizedWeights + :members: + + .. 
autoclass:: torchvision.models.ShuffleNet_V2_X1_5_Weights + :members: + :noindex: + """ + weights = (ShuffleNet_V2_X1_5_QuantizedWeights if quantize else ShuffleNet_V2_X1_5_Weights).verify(weights) + return _shufflenetv2( + [4, 8, 4], [24, 176, 352, 704, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs + ) + + +@register_model(name="quantized_shufflenet_v2_x2_0") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ShuffleNet_V2_X2_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1, + ) +) +def shufflenet_v2_x2_0( + *, + weights: Optional[Union[ShuffleNet_V2_X2_0_QuantizedWeights, ShuffleNet_V2_X2_0_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableShuffleNetV2: + """ + Constructs a ShuffleNetV2 with 2.0x output channels, as described in + `ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design + `__. + + .. note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. + GPU inference is not yet supported. + + Args: + weights (:class:`~torchvision.models.quantization.ShuffleNet_V2_X2_0_QuantizedWeights` or :class:`~torchvision.models.ShuffleNet_V2_X2_0_Weights`, optional): The + pretrained weights for the model. See + :class:`~torchvision.models.quantization.ShuffleNet_V2_X2_0_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. + Default is True. + quantize (bool, optional): If True, return a quantized version of the model. + Default is False. + **kwargs: parameters passed to the ``torchvision.models.quantization.ShuffleNet_V2_X2_0_QuantizedWeights`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.quantization.ShuffleNet_V2_X2_0_QuantizedWeights + :members: + + .. 
autoclass:: torchvision.models.ShuffleNet_V2_X2_0_Weights + :members: + :noindex: + """ + weights = (ShuffleNet_V2_X2_0_QuantizedWeights if quantize else ShuffleNet_V2_X2_0_Weights).verify(weights) + return _shufflenetv2( + [4, 8, 4], [24, 244, 488, 976, 2048], weights=weights, progress=progress, quantize=quantize, **kwargs + ) diff --git a/lib/python3.10/site-packages/torchvision/models/quantization/utils.py b/lib/python3.10/site-packages/torchvision/models/quantization/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a21e2af8e016568e79b25e35ec774f39f0595c3a --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/quantization/utils.py @@ -0,0 +1,51 @@ +from typing import Any, List, Optional, Union + +import torch +from torch import nn + + +def _replace_relu(module: nn.Module) -> None: + reassign = {} + for name, mod in module.named_children(): + _replace_relu(mod) + # Checking for explicit type instead of instance + # as we only want to replace modules of the exact type + # not inherited classes + if type(mod) is nn.ReLU or type(mod) is nn.ReLU6: + reassign[name] = nn.ReLU(inplace=False) + + for key, value in reassign.items(): + module._modules[key] = value + + +def quantize_model(model: nn.Module, backend: str) -> None: + _dummy_input_data = torch.rand(1, 3, 299, 299) + if backend not in torch.backends.quantized.supported_engines: + raise RuntimeError("Quantized backend not supported ") + torch.backends.quantized.engine = backend + model.eval() + # Make sure that weight qconfig matches that of the serialized models + if backend == "fbgemm": + model.qconfig = torch.ao.quantization.QConfig( # type: ignore[assignment] + activation=torch.ao.quantization.default_observer, + weight=torch.ao.quantization.default_per_channel_weight_observer, + ) + elif backend == "qnnpack": + model.qconfig = torch.ao.quantization.QConfig( # type: ignore[assignment] + activation=torch.ao.quantization.default_observer, weight=torch.ao.quantization.default_weight_observer + ) + + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + model.fuse_model() # type: ignore[operator] + torch.ao.quantization.prepare(model, inplace=True) + model(_dummy_input_data) + torch.ao.quantization.convert(model, inplace=True) + + +def _fuse_modules( + model: nn.Module, modules_to_fuse: Union[List[str], List[List[str]]], is_qat: Optional[bool], **kwargs: Any +): + if is_qat is None: + is_qat = model.training + method = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules + return method(model, modules_to_fuse, **kwargs) diff --git a/lib/python3.10/site-packages/torchvision/models/segmentation/__init__.py b/lib/python3.10/site-packages/torchvision/models/segmentation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d6f37f958a131b76ce80306718b77d78bc3f045 --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/segmentation/__init__.py @@ -0,0 +1,3 @@ +from .deeplabv3 import * +from .fcn import * +from .lraspp import * diff --git a/lib/python3.10/site-packages/torchvision/models/segmentation/__pycache__/_utils.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/segmentation/__pycache__/_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f5fdd4a7d6c704486e924a638d983b6c5c3f5d5 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/segmentation/__pycache__/_utils.cpython-310.pyc differ diff --git 
a/lib/python3.10/site-packages/torchvision/models/segmentation/__pycache__/deeplabv3.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/segmentation/__pycache__/deeplabv3.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8902a3f9d177e05dd1ae87e0de952916a7de78bb Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/segmentation/__pycache__/deeplabv3.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/segmentation/__pycache__/fcn.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/segmentation/__pycache__/fcn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8d8ccb9894f88a73d18c7d52c2f21df5b92a6c0 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/segmentation/__pycache__/fcn.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/segmentation/__pycache__/lraspp.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/segmentation/__pycache__/lraspp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e18c00111807d1711398551372f9e84cc126e2da Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/segmentation/__pycache__/lraspp.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/segmentation/_utils.py b/lib/python3.10/site-packages/torchvision/models/segmentation/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..56560e9dab5c143699c918fa28236a902e530daf --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/segmentation/_utils.py @@ -0,0 +1,37 @@ +from collections import OrderedDict +from typing import Dict, Optional + +from torch import nn, Tensor +from torch.nn import functional as F + +from ...utils import _log_api_usage_once + + +class _SimpleSegmentationModel(nn.Module): + __constants__ = ["aux_classifier"] + + def __init__(self, backbone: nn.Module, classifier: nn.Module, aux_classifier: Optional[nn.Module] = None) -> None: + super().__init__() + _log_api_usage_once(self) + self.backbone = backbone + self.classifier = classifier + self.aux_classifier = aux_classifier + + def forward(self, x: Tensor) -> Dict[str, Tensor]: + input_shape = x.shape[-2:] + # contract: features is a dict of tensors + features = self.backbone(x) + + result = OrderedDict() + x = features["out"] + x = self.classifier(x) + x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False) + result["out"] = x + + if self.aux_classifier is not None: + x = features["aux"] + x = self.aux_classifier(x) + x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False) + result["aux"] = x + + return result diff --git a/lib/python3.10/site-packages/torchvision/models/segmentation/deeplabv3.py b/lib/python3.10/site-packages/torchvision/models/segmentation/deeplabv3.py new file mode 100644 index 0000000000000000000000000000000000000000..a92ddfe3b7af41a9ffd371d34cd459ba57965c53 --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/segmentation/deeplabv3.py @@ -0,0 +1,390 @@ +from functools import partial +from typing import Any, Optional, Sequence + +import torch +from torch import nn +from torch.nn import functional as F + +from ...transforms._presets import SemanticSegmentation +from .._api import register_model, Weights, WeightsEnum +from .._meta import _VOC_CATEGORIES +from .._utils import _ovewrite_value_param, handle_legacy_interface, 
IntermediateLayerGetter +from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights, MobileNetV3 +from ..resnet import ResNet, resnet101, ResNet101_Weights, resnet50, ResNet50_Weights +from ._utils import _SimpleSegmentationModel +from .fcn import FCNHead + + +__all__ = [ + "DeepLabV3", + "DeepLabV3_ResNet50_Weights", + "DeepLabV3_ResNet101_Weights", + "DeepLabV3_MobileNet_V3_Large_Weights", + "deeplabv3_mobilenet_v3_large", + "deeplabv3_resnet50", + "deeplabv3_resnet101", +] + + +class DeepLabV3(_SimpleSegmentationModel): + """ + Implements DeepLabV3 model from + `"Rethinking Atrous Convolution for Semantic Image Segmentation" + `_. + + Args: + backbone (nn.Module): the network used to compute the features for the model. + The backbone should return an OrderedDict[Tensor], with the key being + "out" for the last feature map used, and "aux" if an auxiliary classifier + is used. + classifier (nn.Module): module that takes the "out" element returned from + the backbone and returns a dense prediction. + aux_classifier (nn.Module, optional): auxiliary classifier used during training + """ + + pass + + +class DeepLabHead(nn.Sequential): + def __init__(self, in_channels: int, num_classes: int, atrous_rates: Sequence[int] = (12, 24, 36)) -> None: + super().__init__( + ASPP(in_channels, atrous_rates), + nn.Conv2d(256, 256, 3, padding=1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(), + nn.Conv2d(256, num_classes, 1), + ) + + +class ASPPConv(nn.Sequential): + def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None: + modules = [ + nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ] + super().__init__(*modules) + + +class ASPPPooling(nn.Sequential): + def __init__(self, in_channels: int, out_channels: int) -> None: + super().__init__( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + size = x.shape[-2:] + for mod in self: + x = mod(x) + return F.interpolate(x, size=size, mode="bilinear", align_corners=False) + + +class ASPP(nn.Module): + def __init__(self, in_channels: int, atrous_rates: Sequence[int], out_channels: int = 256) -> None: + super().__init__() + modules = [] + modules.append( + nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU()) + ) + + rates = tuple(atrous_rates) + for rate in rates: + modules.append(ASPPConv(in_channels, out_channels, rate)) + + modules.append(ASPPPooling(in_channels, out_channels)) + + self.convs = nn.ModuleList(modules) + + self.project = nn.Sequential( + nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + nn.Dropout(0.5), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + _res = [] + for conv in self.convs: + _res.append(conv(x)) + res = torch.cat(_res, dim=1) + return self.project(res) + + +def _deeplabv3_resnet( + backbone: ResNet, + num_classes: int, + aux: Optional[bool], +) -> DeepLabV3: + return_layers = {"layer4": "out"} + if aux: + return_layers["layer3"] = "aux" + backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) + + aux_classifier = FCNHead(1024, num_classes) if aux else None + classifier = DeepLabHead(2048, num_classes) + return DeepLabV3(backbone, classifier, aux_classifier) + + +_COMMON_META = { + "categories": _VOC_CATEGORIES, + 
"min_size": (1, 1), + "_docs": """ + These weights were trained on a subset of COCO, using only the 20 categories that are present in the Pascal VOC + dataset. + """, +} + + +class DeepLabV3_ResNet50_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", + transforms=partial(SemanticSegmentation, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 42004074, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50", + "_metrics": { + "COCO-val2017-VOC-labels": { + "miou": 66.4, + "pixel_acc": 92.4, + } + }, + "_ops": 178.722, + "_file_size": 160.515, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +class DeepLabV3_ResNet101_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", + transforms=partial(SemanticSegmentation, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 60996202, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101", + "_metrics": { + "COCO-val2017-VOC-labels": { + "miou": 67.4, + "pixel_acc": 92.4, + } + }, + "_ops": 258.743, + "_file_size": 233.217, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", + transforms=partial(SemanticSegmentation, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 11029328, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_large", + "_metrics": { + "COCO-val2017-VOC-labels": { + "miou": 60.3, + "pixel_acc": 91.2, + } + }, + "_ops": 10.452, + "_file_size": 42.301, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +def _deeplabv3_mobilenetv3( + backbone: MobileNetV3, + num_classes: int, + aux: Optional[bool], +) -> DeepLabV3: + backbone = backbone.features + # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. + # The first and last blocks are always included because they are the C0 (conv1) and Cn. + stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1] + out_pos = stage_indices[-1] # use C5 which has output_stride = 16 + out_inplanes = backbone[out_pos].out_channels + aux_pos = stage_indices[-4] # use C2 here which has output_stride = 8 + aux_inplanes = backbone[aux_pos].out_channels + return_layers = {str(out_pos): "out"} + if aux: + return_layers[str(aux_pos)] = "aux" + backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) + + aux_classifier = FCNHead(aux_inplanes, num_classes) if aux else None + classifier = DeepLabHead(out_inplanes, num_classes) + return DeepLabV3(backbone, classifier, aux_classifier) + + +@register_model() +@handle_legacy_interface( + weights=("pretrained", DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) +def deeplabv3_resnet50( + *, + weights: Optional[DeepLabV3_ResNet50_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + aux_loss: Optional[bool] = None, + weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, + **kwargs: Any, +) -> DeepLabV3: + """Constructs a DeepLabV3 model with a ResNet-50 backbone. + + .. 
betastatus:: segmentation module + + Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation `__. + + Args: + weights (:class:`~torchvision.models.segmentation.DeepLabV3_ResNet50_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.segmentation.DeepLabV3_ResNet50_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + num_classes (int, optional): number of output classes of the model (including the background) + aux_loss (bool, optional): If True, it uses an auxiliary loss + weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for the + backbone + **kwargs: unused + + .. autoclass:: torchvision.models.segmentation.DeepLabV3_ResNet50_Weights + :members: + """ + weights = DeepLabV3_ResNet50_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True) + elif num_classes is None: + num_classes = 21 + + backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) + model = _deeplabv3_resnet(backbone, num_classes, aux_loss) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +@register_model() +@handle_legacy_interface( + weights=("pretrained", DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1), +) +def deeplabv3_resnet101( + *, + weights: Optional[DeepLabV3_ResNet101_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + aux_loss: Optional[bool] = None, + weights_backbone: Optional[ResNet101_Weights] = ResNet101_Weights.IMAGENET1K_V1, + **kwargs: Any, +) -> DeepLabV3: + """Constructs a DeepLabV3 model with a ResNet-101 backbone. + + .. betastatus:: segmentation module + + Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation `__. + + Args: + weights (:class:`~torchvision.models.segmentation.DeepLabV3_ResNet101_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.segmentation.DeepLabV3_ResNet101_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + num_classes (int, optional): number of output classes of the model (including the background) + aux_loss (bool, optional): If True, it uses an auxiliary loss + weights_backbone (:class:`~torchvision.models.ResNet101_Weights`, optional): The pretrained weights for the + backbone + **kwargs: unused + + .. 
autoclass:: torchvision.models.segmentation.DeepLabV3_ResNet101_Weights + :members: + """ + weights = DeepLabV3_ResNet101_Weights.verify(weights) + weights_backbone = ResNet101_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True) + elif num_classes is None: + num_classes = 21 + + backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) + model = _deeplabv3_resnet(backbone, num_classes, aux_loss) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +@register_model() +@handle_legacy_interface( + weights=("pretrained", DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), +) +def deeplabv3_mobilenet_v3_large( + *, + weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + aux_loss: Optional[bool] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1, + **kwargs: Any, +) -> DeepLabV3: + """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone. + + Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation `__. + + Args: + weights (:class:`~torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + num_classes (int, optional): number of output classes of the model (including the background) + aux_loss (bool, optional): If True, it uses an auxiliary loss + weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The pretrained weights + for the backbone + **kwargs: unused + + .. 
autoclass:: torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights + :members: + """ + weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True) + elif num_classes is None: + num_classes = 21 + + backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True) + model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model diff --git a/lib/python3.10/site-packages/torchvision/models/segmentation/fcn.py b/lib/python3.10/site-packages/torchvision/models/segmentation/fcn.py new file mode 100644 index 0000000000000000000000000000000000000000..fb2e242adac0e7430bab6155ae0347770e29fee9 --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/segmentation/fcn.py @@ -0,0 +1,232 @@ +from functools import partial +from typing import Any, Optional + +from torch import nn + +from ...transforms._presets import SemanticSegmentation +from .._api import register_model, Weights, WeightsEnum +from .._meta import _VOC_CATEGORIES +from .._utils import _ovewrite_value_param, handle_legacy_interface, IntermediateLayerGetter +from ..resnet import ResNet, resnet101, ResNet101_Weights, resnet50, ResNet50_Weights +from ._utils import _SimpleSegmentationModel + + +__all__ = ["FCN", "FCN_ResNet50_Weights", "FCN_ResNet101_Weights", "fcn_resnet50", "fcn_resnet101"] + + +class FCN(_SimpleSegmentationModel): + """ + Implements FCN model from + `"Fully Convolutional Networks for Semantic Segmentation" + `_. + + Args: + backbone (nn.Module): the network used to compute the features for the model. + The backbone should return an OrderedDict[Tensor], with the key being + "out" for the last feature map used, and "aux" if an auxiliary classifier + is used. + classifier (nn.Module): module that takes the "out" element returned from + the backbone and returns a dense prediction. + aux_classifier (nn.Module, optional): auxiliary classifier used during training + """ + + pass + + +class FCNHead(nn.Sequential): + def __init__(self, in_channels: int, channels: int) -> None: + inter_channels = in_channels // 4 + layers = [ + nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), + nn.BatchNorm2d(inter_channels), + nn.ReLU(), + nn.Dropout(0.1), + nn.Conv2d(inter_channels, channels, 1), + ] + + super().__init__(*layers) + + +_COMMON_META = { + "categories": _VOC_CATEGORIES, + "min_size": (1, 1), + "_docs": """ + These weights were trained on a subset of COCO, using only the 20 categories that are present in the Pascal VOC + dataset. 
+ """, +} + + +class FCN_ResNet50_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", + transforms=partial(SemanticSegmentation, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 35322218, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet50", + "_metrics": { + "COCO-val2017-VOC-labels": { + "miou": 60.5, + "pixel_acc": 91.4, + } + }, + "_ops": 152.717, + "_file_size": 135.009, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +class FCN_ResNet101_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", + transforms=partial(SemanticSegmentation, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 54314346, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet101", + "_metrics": { + "COCO-val2017-VOC-labels": { + "miou": 63.7, + "pixel_acc": 91.9, + } + }, + "_ops": 232.738, + "_file_size": 207.711, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +def _fcn_resnet( + backbone: ResNet, + num_classes: int, + aux: Optional[bool], +) -> FCN: + return_layers = {"layer4": "out"} + if aux: + return_layers["layer3"] = "aux" + backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) + + aux_classifier = FCNHead(1024, num_classes) if aux else None + classifier = FCNHead(2048, num_classes) + return FCN(backbone, classifier, aux_classifier) + + +@register_model() +@handle_legacy_interface( + weights=("pretrained", FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) +def fcn_resnet50( + *, + weights: Optional[FCN_ResNet50_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + aux_loss: Optional[bool] = None, + weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, + **kwargs: Any, +) -> FCN: + """Fully-Convolutional Network model with a ResNet-50 backbone from the `Fully Convolutional + Networks for Semantic Segmentation `_ paper. + + .. betastatus:: segmentation module + + Args: + weights (:class:`~torchvision.models.segmentation.FCN_ResNet50_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.segmentation.FCN_ResNet50_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + num_classes (int, optional): number of output classes of the model (including the background). + aux_loss (bool, optional): If True, it uses an auxiliary loss. + weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained + weights for the backbone. + **kwargs: parameters passed to the ``torchvision.models.segmentation.fcn.FCN`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.segmentation.FCN_ResNet50_Weights + :members: + """ + + weights = FCN_ResNet50_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True) + elif num_classes is None: + num_classes = 21 + + backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) + model = _fcn_resnet(backbone, num_classes, aux_loss) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +@register_model() +@handle_legacy_interface( + weights=("pretrained", FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1), +) +def fcn_resnet101( + *, + weights: Optional[FCN_ResNet101_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + aux_loss: Optional[bool] = None, + weights_backbone: Optional[ResNet101_Weights] = ResNet101_Weights.IMAGENET1K_V1, + **kwargs: Any, +) -> FCN: + """Fully-Convolutional Network model with a ResNet-101 backbone from the `Fully Convolutional + Networks for Semantic Segmentation `_ paper. + + .. betastatus:: segmentation module + + Args: + weights (:class:`~torchvision.models.segmentation.FCN_ResNet101_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.segmentation.FCN_ResNet101_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + num_classes (int, optional): number of output classes of the model (including the background). + aux_loss (bool, optional): If True, it uses an auxiliary loss. + weights_backbone (:class:`~torchvision.models.ResNet101_Weights`, optional): The pretrained + weights for the backbone. + **kwargs: parameters passed to the ``torchvision.models.segmentation.fcn.FCN`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.segmentation.FCN_ResNet101_Weights + :members: + """ + + weights = FCN_ResNet101_Weights.verify(weights) + weights_backbone = ResNet101_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True) + elif num_classes is None: + num_classes = 21 + + backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) + model = _fcn_resnet(backbone, num_classes, aux_loss) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model diff --git a/lib/python3.10/site-packages/torchvision/models/segmentation/lraspp.py b/lib/python3.10/site-packages/torchvision/models/segmentation/lraspp.py new file mode 100644 index 0000000000000000000000000000000000000000..70bced70fd37c3c681915492cea0c68c87cf0a7e --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/segmentation/lraspp.py @@ -0,0 +1,178 @@ +from collections import OrderedDict +from functools import partial +from typing import Any, Dict, Optional + +from torch import nn, Tensor +from torch.nn import functional as F + +from ...transforms._presets import SemanticSegmentation +from ...utils import _log_api_usage_once +from .._api import register_model, Weights, WeightsEnum +from .._meta import _VOC_CATEGORIES +from .._utils import _ovewrite_value_param, handle_legacy_interface, IntermediateLayerGetter +from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights, MobileNetV3 + + +__all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_large"] + + +class LRASPP(nn.Module): + """ + Implements a Lite R-ASPP Network for semantic segmentation from + `"Searching for MobileNetV3" + `_. + + Args: + backbone (nn.Module): the network used to compute the features for the model. + The backbone should return an OrderedDict[Tensor], with the key being + "high" for the high level feature map and "low" for the low level feature map. + low_channels (int): the number of channels of the low level features. + high_channels (int): the number of channels of the high level features. + num_classes (int, optional): number of output classes of the model (including the background). + inter_channels (int, optional): the number of channels for intermediate computations. 
+ """ + + def __init__( + self, backbone: nn.Module, low_channels: int, high_channels: int, num_classes: int, inter_channels: int = 128 + ) -> None: + super().__init__() + _log_api_usage_once(self) + self.backbone = backbone + self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels) + + def forward(self, input: Tensor) -> Dict[str, Tensor]: + features = self.backbone(input) + out = self.classifier(features) + out = F.interpolate(out, size=input.shape[-2:], mode="bilinear", align_corners=False) + + result = OrderedDict() + result["out"] = out + + return result + + +class LRASPPHead(nn.Module): + def __init__(self, low_channels: int, high_channels: int, num_classes: int, inter_channels: int) -> None: + super().__init__() + self.cbr = nn.Sequential( + nn.Conv2d(high_channels, inter_channels, 1, bias=False), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + ) + self.scale = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(high_channels, inter_channels, 1, bias=False), + nn.Sigmoid(), + ) + self.low_classifier = nn.Conv2d(low_channels, num_classes, 1) + self.high_classifier = nn.Conv2d(inter_channels, num_classes, 1) + + def forward(self, input: Dict[str, Tensor]) -> Tensor: + low = input["low"] + high = input["high"] + + x = self.cbr(high) + s = self.scale(high) + x = x * s + x = F.interpolate(x, size=low.shape[-2:], mode="bilinear", align_corners=False) + + return self.low_classifier(low) + self.high_classifier(x) + + +def _lraspp_mobilenetv3(backbone: MobileNetV3, num_classes: int) -> LRASPP: + backbone = backbone.features + # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. + # The first and last blocks are always included because they are the C0 (conv1) and Cn. + stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1] + low_pos = stage_indices[-4] # use C2 here which has output_stride = 8 + high_pos = stage_indices[-1] # use C5 which has output_stride = 16 + low_channels = backbone[low_pos].out_channels + high_channels = backbone[high_pos].out_channels + backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): "low", str(high_pos): "high"}) + + return LRASPP(backbone, low_channels, high_channels, num_classes) + + +class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", + transforms=partial(SemanticSegmentation, resize_size=520), + meta={ + "num_params": 3221538, + "categories": _VOC_CATEGORIES, + "min_size": (1, 1), + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#lraspp_mobilenet_v3_large", + "_metrics": { + "COCO-val2017-VOC-labels": { + "miou": 57.9, + "pixel_acc": 91.2, + } + }, + "_ops": 2.086, + "_file_size": 12.49, + "_docs": """ + These weights were trained on a subset of COCO, using only the 20 categories that are present in the + Pascal VOC dataset. 
+ """, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +@register_model() +@handle_legacy_interface( + weights=("pretrained", LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), +) +def lraspp_mobilenet_v3_large( + *, + weights: Optional[LRASPP_MobileNet_V3_Large_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1, + **kwargs: Any, +) -> LRASPP: + """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone from + `Searching for MobileNetV3 `_ paper. + + .. betastatus:: segmentation module + + Args: + weights (:class:`~torchvision.models.segmentation.LRASPP_MobileNet_V3_Large_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.segmentation.LRASPP_MobileNet_V3_Large_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + num_classes (int, optional): number of output classes of the model (including the background). + aux_loss (bool, optional): If True, it uses an auxiliary loss. + weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The pretrained + weights for the backbone. + **kwargs: parameters passed to the ``torchvision.models.segmentation.LRASPP`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.segmentation.LRASPP_MobileNet_V3_Large_Weights + :members: + """ + if kwargs.pop("aux_loss", False): + raise NotImplementedError("This model does not use auxiliary loss") + + weights = LRASPP_MobileNet_V3_Large_Weights.verify(weights) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 21 + + backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True) + model = _lraspp_mobilenetv3(backbone, num_classes) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model diff --git a/lib/python3.10/site-packages/torchvision/models/video/__init__.py b/lib/python3.10/site-packages/torchvision/models/video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1eedd3116001af22ec202d2ccec6eefad8090ae --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/video/__init__.py @@ -0,0 +1,4 @@ +from .mvit import * +from .resnet import * +from .s3d import * +from .swin_transformer import * diff --git a/lib/python3.10/site-packages/torchvision/models/video/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/video/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..237633989ca13e30fb0345171d8a0931c533b485 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/video/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/video/__pycache__/mvit.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/video/__pycache__/mvit.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..32fa2b539c5a2f58c6f17a02fb54856fb35cd7a2 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/video/__pycache__/mvit.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/video/__pycache__/resnet.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/video/__pycache__/resnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf19b3627348ffe2fb56b7672de80aae61c32c09 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/video/__pycache__/resnet.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/video/__pycache__/s3d.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/video/__pycache__/s3d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a20509612fc036eae4471e03d031727d0164d2d7 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/video/__pycache__/s3d.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/video/__pycache__/swin_transformer.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/models/video/__pycache__/swin_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d5c5c35e5095c1a90b0e4aa4ff10332d8cb46a7 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/models/video/__pycache__/swin_transformer.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/models/video/mvit.py b/lib/python3.10/site-packages/torchvision/models/video/mvit.py new file mode 100644 index 0000000000000000000000000000000000000000..159c12a4f3eac579f4e122741f57cc60f5cd0a23 --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/video/mvit.py @@ -0,0 +1,897 @@ +import math +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple + +import torch +import torch.fx +import torch.nn as nn + +from ...ops import MLP, StochasticDepth +from ...transforms._presets import VideoClassification +from ...utils import _log_api_usage_once +from .._api import register_model, Weights, WeightsEnum +from .._meta import _KINETICS400_CATEGORIES +from .._utils import _ovewrite_named_param, handle_legacy_interface + + +__all__ = [ + "MViT", + "MViT_V1_B_Weights", + "mvit_v1_b", + "MViT_V2_S_Weights", + "mvit_v2_s", +] + + +@dataclass +class MSBlockConfig: + num_heads: int + input_channels: int + output_channels: int + kernel_q: List[int] + kernel_kv: List[int] + stride_q: List[int] + stride_kv: List[int] + + +def _prod(s: Sequence[int]) -> int: + product = 1 + for v in s: + product *= v + return product + + +def _unsqueeze(x: torch.Tensor, target_dim: int, expand_dim: int) -> Tuple[torch.Tensor, int]: + tensor_dim = x.dim() + if tensor_dim == target_dim - 1: + x = x.unsqueeze(expand_dim) + elif tensor_dim != target_dim: + raise ValueError(f"Unsupported input dimension {x.shape}") + return x, tensor_dim + + +def _squeeze(x: torch.Tensor, target_dim: int, expand_dim: int, tensor_dim: int) -> torch.Tensor: + if tensor_dim == target_dim - 1: + x = x.squeeze(expand_dim) + return x + + +torch.fx.wrap("_unsqueeze") +torch.fx.wrap("_squeeze") + + +class Pool(nn.Module): + def __init__( + self, + pool: nn.Module, + norm: Optional[nn.Module], + activation: Optional[nn.Module] = None, + norm_before_pool: bool = False, + ) -> None: + super().__init__() + self.pool = pool + 
layers = [] + if norm is not None: + layers.append(norm) + if activation is not None: + layers.append(activation) + self.norm_act = nn.Sequential(*layers) if layers else None + self.norm_before_pool = norm_before_pool + + def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]: + x, tensor_dim = _unsqueeze(x, 4, 1) + + # Separate the class token and reshape the input + class_token, x = torch.tensor_split(x, indices=(1,), dim=2) + x = x.transpose(2, 3) + B, N, C = x.shape[:3] + x = x.reshape((B * N, C) + thw).contiguous() + + # normalizing prior pooling is useful when we use BN which can be absorbed to speed up inference + if self.norm_before_pool and self.norm_act is not None: + x = self.norm_act(x) + + # apply the pool on the input and add back the token + x = self.pool(x) + T, H, W = x.shape[2:] + x = x.reshape(B, N, C, -1).transpose(2, 3) + x = torch.cat((class_token, x), dim=2) + + if not self.norm_before_pool and self.norm_act is not None: + x = self.norm_act(x) + + x = _squeeze(x, 4, 1, tensor_dim) + return x, (T, H, W) + + +def _interpolate(embedding: torch.Tensor, d: int) -> torch.Tensor: + if embedding.shape[0] == d: + return embedding + + return ( + nn.functional.interpolate( + embedding.permute(1, 0).unsqueeze(0), + size=d, + mode="linear", + ) + .squeeze(0) + .permute(1, 0) + ) + + +def _add_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + q_thw: Tuple[int, int, int], + k_thw: Tuple[int, int, int], + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + rel_pos_t: torch.Tensor, +) -> torch.Tensor: + # Modified code from: https://github.com/facebookresearch/SlowFast/commit/1aebd71a2efad823d52b827a3deaf15a56cf4932 + q_t, q_h, q_w = q_thw + k_t, k_h, k_w = k_thw + dh = int(2 * max(q_h, k_h) - 1) + dw = int(2 * max(q_w, k_w) - 1) + dt = int(2 * max(q_t, k_t) - 1) + + # Scale up rel pos if shapes for q and k are different. + q_h_ratio = max(k_h / q_h, 1.0) + k_h_ratio = max(q_h / k_h, 1.0) + dist_h = torch.arange(q_h)[:, None] * q_h_ratio - (torch.arange(k_h)[None, :] + (1.0 - k_h)) * k_h_ratio + q_w_ratio = max(k_w / q_w, 1.0) + k_w_ratio = max(q_w / k_w, 1.0) + dist_w = torch.arange(q_w)[:, None] * q_w_ratio - (torch.arange(k_w)[None, :] + (1.0 - k_w)) * k_w_ratio + q_t_ratio = max(k_t / q_t, 1.0) + k_t_ratio = max(q_t / k_t, 1.0) + dist_t = torch.arange(q_t)[:, None] * q_t_ratio - (torch.arange(k_t)[None, :] + (1.0 - k_t)) * k_t_ratio + + # Interpolate rel pos if needed. + rel_pos_h = _interpolate(rel_pos_h, dh) + rel_pos_w = _interpolate(rel_pos_w, dw) + rel_pos_t = _interpolate(rel_pos_t, dt) + Rh = rel_pos_h[dist_h.long()] + Rw = rel_pos_w[dist_w.long()] + Rt = rel_pos_t[dist_t.long()] + + B, n_head, _, dim = q.shape + + r_q = q[:, :, 1:].reshape(B, n_head, q_t, q_h, q_w, dim) + rel_h_q = torch.einsum("bythwc,hkc->bythwk", r_q, Rh) # [B, H, q_t, qh, qw, k_h] + rel_w_q = torch.einsum("bythwc,wkc->bythwk", r_q, Rw) # [B, H, q_t, qh, qw, k_w] + # [B, H, q_t, q_h, q_w, dim] -> [q_t, B, H, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim] + r_q = r_q.permute(2, 0, 1, 3, 4, 5).reshape(q_t, B * n_head * q_h * q_w, dim) + # [q_t, B*H*q_h*q_w, dim] * [q_t, dim, k_t] = [q_t, B*H*q_h*q_w, k_t] -> [B*H*q_h*q_w, q_t, k_t] + rel_q_t = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1) + # [B*H*q_h*q_w, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t] + rel_q_t = rel_q_t.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5) + + # Combine rel pos. 
+ rel_pos = ( + rel_h_q[:, :, :, :, :, None, :, None] + + rel_w_q[:, :, :, :, :, None, None, :] + + rel_q_t[:, :, :, :, :, :, None, None] + ).reshape(B, n_head, q_t * q_h * q_w, k_t * k_h * k_w) + + # Add it to attention + attn[:, :, 1:, 1:] += rel_pos + + return attn + + +def _add_shortcut(x: torch.Tensor, shortcut: torch.Tensor, residual_with_cls_embed: bool): + if residual_with_cls_embed: + x.add_(shortcut) + else: + x[:, :, 1:, :] += shortcut[:, :, 1:, :] + return x + + +torch.fx.wrap("_add_rel_pos") +torch.fx.wrap("_add_shortcut") + + +class MultiscaleAttention(nn.Module): + def __init__( + self, + input_size: List[int], + embed_dim: int, + output_dim: int, + num_heads: int, + kernel_q: List[int], + kernel_kv: List[int], + stride_q: List[int], + stride_kv: List[int], + residual_pool: bool, + residual_with_cls_embed: bool, + rel_pos_embed: bool, + dropout: float = 0.0, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.output_dim = output_dim + self.num_heads = num_heads + self.head_dim = output_dim // num_heads + self.scaler = 1.0 / math.sqrt(self.head_dim) + self.residual_pool = residual_pool + self.residual_with_cls_embed = residual_with_cls_embed + + self.qkv = nn.Linear(embed_dim, 3 * output_dim) + layers: List[nn.Module] = [nn.Linear(output_dim, output_dim)] + if dropout > 0.0: + layers.append(nn.Dropout(dropout, inplace=True)) + self.project = nn.Sequential(*layers) + + self.pool_q: Optional[nn.Module] = None + if _prod(kernel_q) > 1 or _prod(stride_q) > 1: + padding_q = [int(q // 2) for q in kernel_q] + self.pool_q = Pool( + nn.Conv3d( + self.head_dim, + self.head_dim, + kernel_q, # type: ignore[arg-type] + stride=stride_q, # type: ignore[arg-type] + padding=padding_q, # type: ignore[arg-type] + groups=self.head_dim, + bias=False, + ), + norm_layer(self.head_dim), + ) + + self.pool_k: Optional[nn.Module] = None + self.pool_v: Optional[nn.Module] = None + if _prod(kernel_kv) > 1 or _prod(stride_kv) > 1: + padding_kv = [int(kv // 2) for kv in kernel_kv] + self.pool_k = Pool( + nn.Conv3d( + self.head_dim, + self.head_dim, + kernel_kv, # type: ignore[arg-type] + stride=stride_kv, # type: ignore[arg-type] + padding=padding_kv, # type: ignore[arg-type] + groups=self.head_dim, + bias=False, + ), + norm_layer(self.head_dim), + ) + self.pool_v = Pool( + nn.Conv3d( + self.head_dim, + self.head_dim, + kernel_kv, # type: ignore[arg-type] + stride=stride_kv, # type: ignore[arg-type] + padding=padding_kv, # type: ignore[arg-type] + groups=self.head_dim, + bias=False, + ), + norm_layer(self.head_dim), + ) + + self.rel_pos_h: Optional[nn.Parameter] = None + self.rel_pos_w: Optional[nn.Parameter] = None + self.rel_pos_t: Optional[nn.Parameter] = None + if rel_pos_embed: + size = max(input_size[1:]) + q_size = size // stride_q[1] if len(stride_q) > 0 else size + kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size + spatial_dim = 2 * max(q_size, kv_size) - 1 + temporal_dim = 2 * input_size[0] - 1 + self.rel_pos_h = nn.Parameter(torch.zeros(spatial_dim, self.head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(spatial_dim, self.head_dim)) + self.rel_pos_t = nn.Parameter(torch.zeros(temporal_dim, self.head_dim)) + nn.init.trunc_normal_(self.rel_pos_h, std=0.02) + nn.init.trunc_normal_(self.rel_pos_w, std=0.02) + nn.init.trunc_normal_(self.rel_pos_t, std=0.02) + + def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]: + B, N, C = x.shape + q, k, v = 
self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(dim=2) + + if self.pool_k is not None: + k, k_thw = self.pool_k(k, thw) + else: + k_thw = thw + if self.pool_v is not None: + v = self.pool_v(v, thw)[0] + if self.pool_q is not None: + q, thw = self.pool_q(q, thw) + + attn = torch.matmul(self.scaler * q, k.transpose(2, 3)) + if self.rel_pos_h is not None and self.rel_pos_w is not None and self.rel_pos_t is not None: + attn = _add_rel_pos( + attn, + q, + thw, + k_thw, + self.rel_pos_h, + self.rel_pos_w, + self.rel_pos_t, + ) + attn = attn.softmax(dim=-1) + + x = torch.matmul(attn, v) + if self.residual_pool: + _add_shortcut(x, q, self.residual_with_cls_embed) + x = x.transpose(1, 2).reshape(B, -1, self.output_dim) + x = self.project(x) + + return x, thw + + +class MultiscaleBlock(nn.Module): + def __init__( + self, + input_size: List[int], + cnf: MSBlockConfig, + residual_pool: bool, + residual_with_cls_embed: bool, + rel_pos_embed: bool, + proj_after_attn: bool, + dropout: float = 0.0, + stochastic_depth_prob: float = 0.0, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + ) -> None: + super().__init__() + self.proj_after_attn = proj_after_attn + + self.pool_skip: Optional[nn.Module] = None + if _prod(cnf.stride_q) > 1: + kernel_skip = [s + 1 if s > 1 else s for s in cnf.stride_q] + padding_skip = [int(k // 2) for k in kernel_skip] + self.pool_skip = Pool( + nn.MaxPool3d(kernel_skip, stride=cnf.stride_q, padding=padding_skip), None # type: ignore[arg-type] + ) + + attn_dim = cnf.output_channels if proj_after_attn else cnf.input_channels + + self.norm1 = norm_layer(cnf.input_channels) + self.norm2 = norm_layer(attn_dim) + self.needs_transposal = isinstance(self.norm1, nn.BatchNorm1d) + + self.attn = MultiscaleAttention( + input_size, + cnf.input_channels, + attn_dim, + cnf.num_heads, + kernel_q=cnf.kernel_q, + kernel_kv=cnf.kernel_kv, + stride_q=cnf.stride_q, + stride_kv=cnf.stride_kv, + rel_pos_embed=rel_pos_embed, + residual_pool=residual_pool, + residual_with_cls_embed=residual_with_cls_embed, + dropout=dropout, + norm_layer=norm_layer, + ) + self.mlp = MLP( + attn_dim, + [4 * attn_dim, cnf.output_channels], + activation_layer=nn.GELU, + dropout=dropout, + inplace=None, + ) + + self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") + + self.project: Optional[nn.Module] = None + if cnf.input_channels != cnf.output_channels: + self.project = nn.Linear(cnf.input_channels, cnf.output_channels) + + def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]: + x_norm1 = self.norm1(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm1(x) + x_attn, thw_new = self.attn(x_norm1, thw) + x = x if self.project is None or not self.proj_after_attn else self.project(x_norm1) + x_skip = x if self.pool_skip is None else self.pool_skip(x, thw)[0] + x = x_skip + self.stochastic_depth(x_attn) + + x_norm2 = self.norm2(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm2(x) + x_proj = x if self.project is None or self.proj_after_attn else self.project(x_norm2) + + return x_proj + self.stochastic_depth(self.mlp(x_norm2)), thw_new + + +class PositionalEncoding(nn.Module): + def __init__(self, embed_size: int, spatial_size: Tuple[int, int], temporal_size: int, rel_pos_embed: bool) -> None: + super().__init__() + self.spatial_size = spatial_size + self.temporal_size = temporal_size + + self.class_token = nn.Parameter(torch.zeros(embed_size)) + self.spatial_pos: 
Optional[nn.Parameter] = None + self.temporal_pos: Optional[nn.Parameter] = None + self.class_pos: Optional[nn.Parameter] = None + if not rel_pos_embed: + self.spatial_pos = nn.Parameter(torch.zeros(self.spatial_size[0] * self.spatial_size[1], embed_size)) + self.temporal_pos = nn.Parameter(torch.zeros(self.temporal_size, embed_size)) + self.class_pos = nn.Parameter(torch.zeros(embed_size)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + class_token = self.class_token.expand(x.size(0), -1).unsqueeze(1) + x = torch.cat((class_token, x), dim=1) + + if self.spatial_pos is not None and self.temporal_pos is not None and self.class_pos is not None: + hw_size, embed_size = self.spatial_pos.shape + pos_embedding = torch.repeat_interleave(self.temporal_pos, hw_size, dim=0) + pos_embedding.add_(self.spatial_pos.unsqueeze(0).expand(self.temporal_size, -1, -1).reshape(-1, embed_size)) + pos_embedding = torch.cat((self.class_pos.unsqueeze(0), pos_embedding), dim=0).unsqueeze(0) + x.add_(pos_embedding) + + return x + + +class MViT(nn.Module): + def __init__( + self, + spatial_size: Tuple[int, int], + temporal_size: int, + block_setting: Sequence[MSBlockConfig], + residual_pool: bool, + residual_with_cls_embed: bool, + rel_pos_embed: bool, + proj_after_attn: bool, + dropout: float = 0.5, + attention_dropout: float = 0.0, + stochastic_depth_prob: float = 0.0, + num_classes: int = 400, + block: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + patch_embed_kernel: Tuple[int, int, int] = (3, 7, 7), + patch_embed_stride: Tuple[int, int, int] = (2, 4, 4), + patch_embed_padding: Tuple[int, int, int] = (1, 3, 3), + ) -> None: + """ + MViT main class. + + Args: + spatial_size (tuple of ints): The spacial size of the input as ``(H, W)``. + temporal_size (int): The temporal size ``T`` of the input. + block_setting (sequence of MSBlockConfig): The Network structure. + residual_pool (bool): If True, use MViTv2 pooling residual connection. + residual_with_cls_embed (bool): If True, the addition on the residual connection will include + the class embedding. + rel_pos_embed (bool): If True, use MViTv2's relative positional embeddings. + proj_after_attn (bool): If True, apply the projection after the attention. + dropout (float): Dropout rate. Default: 0.0. + attention_dropout (float): Attention dropout rate. Default: 0.0. + stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. + num_classes (int): The number of classes. + block (callable, optional): Module specifying the layer which consists of the attention and mlp. + norm_layer (callable, optional): Module specifying the normalization layer to use. + patch_embed_kernel (tuple of ints): The kernel of the convolution that patchifies the input. + patch_embed_stride (tuple of ints): The stride of the convolution that patchifies the input. + patch_embed_padding (tuple of ints): The padding of the convolution that patchifies the input. + """ + super().__init__() + # This implementation employs a different parameterization scheme than the one used at PyTorch Video: + # https://github.com/facebookresearch/pytorchvideo/blob/718d0a4/pytorchvideo/models/vision_transformers.py + # We remove any experimental configuration that didn't make it to the final variants of the models. To represent + # the configuration of the architecture we use the simplified form suggested at Table 1 of the paper. 
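+        # Descriptive note (added): when block / norm_layer are not supplied, the defaults chosen
+        # below are MultiscaleBlock for the attention+MLP block and LayerNorm with eps=1e-6.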
+ _log_api_usage_once(self) + total_stage_blocks = len(block_setting) + if total_stage_blocks == 0: + raise ValueError("The configuration parameter can't be empty.") + + if block is None: + block = MultiscaleBlock + + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + # Patch Embedding module + self.conv_proj = nn.Conv3d( + in_channels=3, + out_channels=block_setting[0].input_channels, + kernel_size=patch_embed_kernel, + stride=patch_embed_stride, + padding=patch_embed_padding, + ) + + input_size = [size // stride for size, stride in zip((temporal_size,) + spatial_size, self.conv_proj.stride)] + + # Spatio-Temporal Class Positional Encoding + self.pos_encoding = PositionalEncoding( + embed_size=block_setting[0].input_channels, + spatial_size=(input_size[1], input_size[2]), + temporal_size=input_size[0], + rel_pos_embed=rel_pos_embed, + ) + + # Encoder module + self.blocks = nn.ModuleList() + for stage_block_id, cnf in enumerate(block_setting): + # adjust stochastic depth probability based on the depth of the stage block + sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0) + + self.blocks.append( + block( + input_size=input_size, + cnf=cnf, + residual_pool=residual_pool, + residual_with_cls_embed=residual_with_cls_embed, + rel_pos_embed=rel_pos_embed, + proj_after_attn=proj_after_attn, + dropout=attention_dropout, + stochastic_depth_prob=sd_prob, + norm_layer=norm_layer, + ) + ) + + if len(cnf.stride_q) > 0: + input_size = [size // stride for size, stride in zip(input_size, cnf.stride_q)] + self.norm = norm_layer(block_setting[-1].output_channels) + + # Classifier module + self.head = nn.Sequential( + nn.Dropout(dropout, inplace=True), + nn.Linear(block_setting[-1].output_channels, num_classes), + ) + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0.0) + elif isinstance(m, nn.LayerNorm): + if m.weight is not None: + nn.init.constant_(m.weight, 1.0) + if m.bias is not None: + nn.init.constant_(m.bias, 0.0) + elif isinstance(m, PositionalEncoding): + for weights in m.parameters(): + nn.init.trunc_normal_(weights, std=0.02) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Convert if necessary (B, C, H, W) -> (B, C, 1, H, W) + x = _unsqueeze(x, 5, 2)[0] + # patchify and reshape: (B, C, T, H, W) -> (B, embed_channels[0], T', H', W') -> (B, THW', embed_channels[0]) + x = self.conv_proj(x) + x = x.flatten(2).transpose(1, 2) + + # add positional encoding + x = self.pos_encoding(x) + + # pass patches through the encoder + thw = (self.pos_encoding.temporal_size,) + self.pos_encoding.spatial_size + for block in self.blocks: + x, thw = block(x, thw) + x = self.norm(x) + + # classifier "token" as used by standard language architectures + x = x[:, 0] + x = self.head(x) + + return x + + +def _mvit( + block_setting: List[MSBlockConfig], + stochastic_depth_prob: float, + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> MViT: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + assert weights.meta["min_size"][0] == weights.meta["min_size"][1] + _ovewrite_named_param(kwargs, "spatial_size", weights.meta["min_size"]) + _ovewrite_named_param(kwargs, "temporal_size", weights.meta["min_temporal_size"]) + spatial_size = kwargs.pop("spatial_size", (224, 224)) + temporal_size = kwargs.pop("temporal_size", 16) + + model = MViT( + 
spatial_size=spatial_size, + temporal_size=temporal_size, + block_setting=block_setting, + residual_pool=kwargs.pop("residual_pool", False), + residual_with_cls_embed=kwargs.pop("residual_with_cls_embed", True), + rel_pos_embed=kwargs.pop("rel_pos_embed", False), + proj_after_attn=kwargs.pop("proj_after_attn", False), + stochastic_depth_prob=stochastic_depth_prob, + **kwargs, + ) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +class MViT_V1_B_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/mvit_v1_b-dbeb1030.pth", + transforms=partial( + VideoClassification, + crop_size=(224, 224), + resize_size=(256,), + mean=(0.45, 0.45, 0.45), + std=(0.225, 0.225, 0.225), + ), + meta={ + "min_size": (224, 224), + "min_temporal_size": 16, + "categories": _KINETICS400_CATEGORIES, + "recipe": "https://github.com/facebookresearch/pytorchvideo/blob/main/docs/source/model_zoo.md", + "_docs": ( + "The weights were ported from the paper. The accuracies are estimated on video-level " + "with parameters `frame_rate=7.5`, `clips_per_video=5`, and `clip_len=16`" + ), + "num_params": 36610672, + "_metrics": { + "Kinetics-400": { + "acc@1": 78.477, + "acc@5": 93.582, + } + }, + "_ops": 70.599, + "_file_size": 139.764, + }, + ) + DEFAULT = KINETICS400_V1 + + +class MViT_V2_S_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/mvit_v2_s-ae3be167.pth", + transforms=partial( + VideoClassification, + crop_size=(224, 224), + resize_size=(256,), + mean=(0.45, 0.45, 0.45), + std=(0.225, 0.225, 0.225), + ), + meta={ + "min_size": (224, 224), + "min_temporal_size": 16, + "categories": _KINETICS400_CATEGORIES, + "recipe": "https://github.com/facebookresearch/SlowFast/blob/main/MODEL_ZOO.md", + "_docs": ( + "The weights were ported from the paper. The accuracies are estimated on video-level " + "with parameters `frame_rate=7.5`, `clips_per_video=5`, and `clip_len=16`" + ), + "num_params": 34537744, + "_metrics": { + "Kinetics-400": { + "acc@1": 80.757, + "acc@5": 94.665, + } + }, + "_ops": 64.224, + "_file_size": 131.884, + }, + ) + DEFAULT = KINETICS400_V1 + + +@register_model() +@handle_legacy_interface(weights=("pretrained", MViT_V1_B_Weights.KINETICS400_V1)) +def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT: + """ + Constructs a base MViTV1 architecture from + `Multiscale Vision Transformers `__. + + .. betastatus:: video module + + Args: + weights (:class:`~torchvision.models.video.MViT_V1_B_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.MViT_V1_B_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.MViT`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.video.MViT_V1_B_Weights + :members: + """ + weights = MViT_V1_B_Weights.verify(weights) + + config: Dict[str, List] = { + "num_heads": [1, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, 8], + "input_channels": [96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768], + "output_channels": [192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768, 768], + "kernel_q": [[], [3, 3, 3], [], [3, 3, 3], [], [], [], [], [], [], [], [], [], [], [3, 3, 3], []], + "kernel_kv": [ + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + ], + "stride_q": [[], [1, 2, 2], [], [1, 2, 2], [], [], [], [], [], [], [], [], [], [], [1, 2, 2], []], + "stride_kv": [ + [1, 8, 8], + [1, 4, 4], + [1, 4, 4], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 1, 1], + [1, 1, 1], + ], + } + + block_setting = [] + for i in range(len(config["num_heads"])): + block_setting.append( + MSBlockConfig( + num_heads=config["num_heads"][i], + input_channels=config["input_channels"][i], + output_channels=config["output_channels"][i], + kernel_q=config["kernel_q"][i], + kernel_kv=config["kernel_kv"][i], + stride_q=config["stride_q"][i], + stride_kv=config["stride_kv"][i], + ) + ) + + return _mvit( + spatial_size=(224, 224), + temporal_size=16, + block_setting=block_setting, + residual_pool=False, + residual_with_cls_embed=False, + stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2), + weights=weights, + progress=progress, + **kwargs, + ) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", MViT_V2_S_Weights.KINETICS400_V1)) +def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT: + """Constructs a small MViTV2 architecture from + `Multiscale Vision Transformers `__ and + `MViTv2: Improved Multiscale Vision Transformers for Classification + and Detection `__. + + .. betastatus:: video module + + Args: + weights (:class:`~torchvision.models.video.MViT_V2_S_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.MViT_V2_S_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.MViT`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.video.MViT_V2_S_Weights + :members: + """ + weights = MViT_V2_S_Weights.verify(weights) + + config: Dict[str, List] = { + "num_heads": [1, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, 8], + "input_channels": [96, 96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768], + "output_channels": [96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768], + "kernel_q": [ + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + ], + "kernel_kv": [ + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + ], + "stride_q": [ + [1, 1, 1], + [1, 2, 2], + [1, 1, 1], + [1, 2, 2], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 2, 2], + [1, 1, 1], + ], + "stride_kv": [ + [1, 8, 8], + [1, 4, 4], + [1, 4, 4], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 1, 1], + [1, 1, 1], + ], + } + + block_setting = [] + for i in range(len(config["num_heads"])): + block_setting.append( + MSBlockConfig( + num_heads=config["num_heads"][i], + input_channels=config["input_channels"][i], + output_channels=config["output_channels"][i], + kernel_q=config["kernel_q"][i], + kernel_kv=config["kernel_kv"][i], + stride_q=config["stride_q"][i], + stride_kv=config["stride_kv"][i], + ) + ) + + return _mvit( + spatial_size=(224, 224), + temporal_size=16, + block_setting=block_setting, + residual_pool=True, + residual_with_cls_embed=False, + rel_pos_embed=True, + proj_after_attn=True, + stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2), + weights=weights, + progress=progress, + **kwargs, + ) diff --git a/lib/python3.10/site-packages/torchvision/models/video/resnet.py b/lib/python3.10/site-packages/torchvision/models/video/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a1cb2884013c053118555344617e4b1efb8ddaab --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/video/resnet.py @@ -0,0 +1,503 @@ +from functools import partial +from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union + +import torch.nn as nn +from torch import Tensor + +from ...transforms._presets import VideoClassification +from ...utils import _log_api_usage_once +from .._api import register_model, Weights, WeightsEnum +from .._meta import _KINETICS400_CATEGORIES +from .._utils import _ovewrite_named_param, handle_legacy_interface + + +__all__ = [ + "VideoResNet", + "R3D_18_Weights", + "MC3_18_Weights", + "R2Plus1D_18_Weights", + "r3d_18", + "mc3_18", + "r2plus1d_18", +] + + +class Conv3DSimple(nn.Conv3d): + def __init__( + self, in_planes: int, out_planes: int, midplanes: Optional[int] = None, stride: int = 1, padding: int = 1 + ) -> None: + + super().__init__( + in_channels=in_planes, + out_channels=out_planes, + kernel_size=(3, 3, 3), + stride=stride, + padding=padding, + bias=False, + ) + + @staticmethod + def get_downsample_stride(stride: int) -> Tuple[int, int, int]: + return stride, stride, stride + + +class Conv2Plus1D(nn.Sequential): + def __init__(self, in_planes: int, out_planes: int, midplanes: int, stride: int = 1, padding: 
int = 1) -> None: + super().__init__( + nn.Conv3d( + in_planes, + midplanes, + kernel_size=(1, 3, 3), + stride=(1, stride, stride), + padding=(0, padding, padding), + bias=False, + ), + nn.BatchNorm3d(midplanes), + nn.ReLU(inplace=True), + nn.Conv3d( + midplanes, out_planes, kernel_size=(3, 1, 1), stride=(stride, 1, 1), padding=(padding, 0, 0), bias=False + ), + ) + + @staticmethod + def get_downsample_stride(stride: int) -> Tuple[int, int, int]: + return stride, stride, stride + + +class Conv3DNoTemporal(nn.Conv3d): + def __init__( + self, in_planes: int, out_planes: int, midplanes: Optional[int] = None, stride: int = 1, padding: int = 1 + ) -> None: + + super().__init__( + in_channels=in_planes, + out_channels=out_planes, + kernel_size=(1, 3, 3), + stride=(1, stride, stride), + padding=(0, padding, padding), + bias=False, + ) + + @staticmethod + def get_downsample_stride(stride: int) -> Tuple[int, int, int]: + return 1, stride, stride + + +class BasicBlock(nn.Module): + + expansion = 1 + + def __init__( + self, + inplanes: int, + planes: int, + conv_builder: Callable[..., nn.Module], + stride: int = 1, + downsample: Optional[nn.Module] = None, + ) -> None: + midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) + + super().__init__() + self.conv1 = nn.Sequential( + conv_builder(inplanes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True) + ) + self.conv2 = nn.Sequential(conv_builder(planes, planes, midplanes), nn.BatchNorm3d(planes)) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + residual = x + + out = self.conv1(x) + out = self.conv2(out) + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__( + self, + inplanes: int, + planes: int, + conv_builder: Callable[..., nn.Module], + stride: int = 1, + downsample: Optional[nn.Module] = None, + ) -> None: + + super().__init__() + midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) + + # 1x1x1 + self.conv1 = nn.Sequential( + nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), nn.BatchNorm3d(planes), nn.ReLU(inplace=True) + ) + # Second kernel + self.conv2 = nn.Sequential( + conv_builder(planes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True) + ) + + # 1x1x1 + self.conv3 = nn.Sequential( + nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False), + nn.BatchNorm3d(planes * self.expansion), + ) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + residual = x + + out = self.conv1(x) + out = self.conv2(out) + out = self.conv3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class BasicStem(nn.Sequential): + """The default conv-batchnorm-relu stem""" + + def __init__(self) -> None: + super().__init__( + nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False), + nn.BatchNorm3d(64), + nn.ReLU(inplace=True), + ) + + +class R2Plus1dStem(nn.Sequential): + """R(2+1)D stem is different than the default one as it uses separated 3D convolution""" + + def __init__(self) -> None: + super().__init__( + nn.Conv3d(3, 45, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False), + nn.BatchNorm3d(45), + 
nn.ReLU(inplace=True), + nn.Conv3d(45, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False), + nn.BatchNorm3d(64), + nn.ReLU(inplace=True), + ) + + +class VideoResNet(nn.Module): + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]], + layers: List[int], + stem: Callable[..., nn.Module], + num_classes: int = 400, + zero_init_residual: bool = False, + ) -> None: + """Generic resnet video generator. + + Args: + block (Type[Union[BasicBlock, Bottleneck]]): resnet building block + conv_makers (List[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]]): generator + function for each layer + layers (List[int]): number of blocks per layer + stem (Callable[..., nn.Module]): module specifying the ResNet stem. + num_classes (int, optional): Dimension of the final FC layer. Defaults to 400. + zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False. + """ + super().__init__() + _log_api_usage_once(self) + self.inplanes = 64 + + self.stem = stem() + + self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1) + self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2) + + self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + # init weights + for m in self.modules(): + if isinstance(m, nn.Conv3d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm3d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) # type: ignore[union-attr, arg-type] + + def forward(self, x: Tensor) -> Tensor: + x = self.stem(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + # Flatten the layer to fc + x = x.flatten(1) + x = self.fc(x) + + return x + + def _make_layer( + self, + block: Type[Union[BasicBlock, Bottleneck]], + conv_builder: Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]], + planes: int, + blocks: int, + stride: int = 1, + ) -> nn.Sequential: + downsample = None + + if stride != 1 or self.inplanes != planes * block.expansion: + ds_stride = conv_builder.get_downsample_stride(stride) + downsample = nn.Sequential( + nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=ds_stride, bias=False), + nn.BatchNorm3d(planes * block.expansion), + ) + layers = [] + layers.append(block(self.inplanes, planes, conv_builder, stride, downsample)) + + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, conv_builder)) + + return nn.Sequential(*layers) + + +def _video_resnet( + block: Type[Union[BasicBlock, Bottleneck]], + conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]], + layers: List[int], + stem: Callable[..., nn.Module], + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> VideoResNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", 
len(weights.meta["categories"])) + + model = VideoResNet(block, conv_makers, layers, stem, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +_COMMON_META = { + "min_size": (1, 1), + "categories": _KINETICS400_CATEGORIES, + "recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification", + "_docs": ( + "The weights reproduce closely the accuracy of the paper. The accuracies are estimated on video-level " + "with parameters `frame_rate=15`, `clips_per_video=5`, and `clip_len=16`." + ), +} + + +class R3D_18_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth", + transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)), + meta={ + **_COMMON_META, + "num_params": 33371472, + "_metrics": { + "Kinetics-400": { + "acc@1": 63.200, + "acc@5": 83.479, + } + }, + "_ops": 40.697, + "_file_size": 127.359, + }, + ) + DEFAULT = KINETICS400_V1 + + +class MC3_18_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", + transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)), + meta={ + **_COMMON_META, + "num_params": 11695440, + "_metrics": { + "Kinetics-400": { + "acc@1": 63.960, + "acc@5": 84.130, + } + }, + "_ops": 43.343, + "_file_size": 44.672, + }, + ) + DEFAULT = KINETICS400_V1 + + +class R2Plus1D_18_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", + transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)), + meta={ + **_COMMON_META, + "num_params": 31505325, + "_metrics": { + "Kinetics-400": { + "acc@1": 67.463, + "acc@5": 86.175, + } + }, + "_ops": 40.519, + "_file_size": 120.318, + }, + ) + DEFAULT = KINETICS400_V1 + + +@register_model() +@handle_legacy_interface(weights=("pretrained", R3D_18_Weights.KINETICS400_V1)) +def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: + """Construct 18 layer Resnet3D model. + + .. betastatus:: video module + + Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition `__. + + Args: + weights (:class:`~torchvision.models.video.R3D_18_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.R3D_18_Weights` + below for more details, and possible values. By default, no + pre-trained weights are used. + progress (bool): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class. + Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.video.R3D_18_Weights + :members: + """ + weights = R3D_18_Weights.verify(weights) + + return _video_resnet( + BasicBlock, + [Conv3DSimple] * 4, + [2, 2, 2, 2], + BasicStem, + weights, + progress, + **kwargs, + ) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", MC3_18_Weights.KINETICS400_V1)) +def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: + """Construct 18 layer Mixed Convolution network as in + + .. betastatus:: video module + + Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition `__. 
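Illustrative aside, not part of the vendored resnet.py being added above: a minimal sketch of how the r3d_18 constructor and R3D_18_Weights enum defined in this file are typically used for inference. The 16-frame 112x112 clip shape is an assumption chosen to match the preset crop size; the model accepts any sufficiently large (B, 3, T, H, W) input.

    import torch
    from torchvision.models.video import r3d_18, R3D_18_Weights

    # Build the 18-layer 3D ResNet with its Kinetics-400 weights.
    weights = R3D_18_Weights.KINETICS400_V1
    model = r3d_18(weights=weights).eval()

    # VideoResNet expects clips as (batch, channels, time, height, width).
    clip = torch.randn(1, 3, 16, 112, 112)
    with torch.no_grad():
        logits = model(clip)  # shape (1, 400): one logit per Kinetics-400 class
    label = weights.meta["categories"][logits.argmax(dim=1).item()]
    print(label)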
+ + Args: + weights (:class:`~torchvision.models.video.MC3_18_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.MC3_18_Weights` + below for more details, and possible values. By default, no + pre-trained weights are used. + progress (bool): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class. + Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.video.MC3_18_Weights + :members: + """ + weights = MC3_18_Weights.verify(weights) + + return _video_resnet( + BasicBlock, + [Conv3DSimple] + [Conv3DNoTemporal] * 3, # type: ignore[list-item] + [2, 2, 2, 2], + BasicStem, + weights, + progress, + **kwargs, + ) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.KINETICS400_V1)) +def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: + """Construct 18 layer deep R(2+1)D network as in + + .. betastatus:: video module + + Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition `__. + + Args: + weights (:class:`~torchvision.models.video.R2Plus1D_18_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.R2Plus1D_18_Weights` + below for more details, and possible values. By default, no + pre-trained weights are used. + progress (bool): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class. + Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.video.R2Plus1D_18_Weights + :members: + """ + weights = R2Plus1D_18_Weights.verify(weights) + + return _video_resnet( + BasicBlock, + [Conv2Plus1D] * 4, + [2, 2, 2, 2], + R2Plus1dStem, + weights, + progress, + **kwargs, + ) + + +# The dictionary below is internal implementation detail and will be removed in v0.15 +from .._utils import _ModelURLs + + +model_urls = _ModelURLs( + { + "r3d_18": R3D_18_Weights.KINETICS400_V1.url, + "mc3_18": MC3_18_Weights.KINETICS400_V1.url, + "r2plus1d_18": R2Plus1D_18_Weights.KINETICS400_V1.url, + } +) diff --git a/lib/python3.10/site-packages/torchvision/models/video/s3d.py b/lib/python3.10/site-packages/torchvision/models/video/s3d.py new file mode 100644 index 0000000000000000000000000000000000000000..4b202829b24fb1dc314452d38a521dfe6c8e446f --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/video/s3d.py @@ -0,0 +1,219 @@ +from functools import partial +from typing import Any, Callable, Optional + +import torch +from torch import nn +from torchvision.ops.misc import Conv3dNormActivation + +from ...transforms._presets import VideoClassification +from ...utils import _log_api_usage_once +from .._api import register_model, Weights, WeightsEnum +from .._meta import _KINETICS400_CATEGORIES +from .._utils import _ovewrite_named_param, handle_legacy_interface + + +__all__ = [ + "S3D", + "S3D_Weights", + "s3d", +] + + +class TemporalSeparableConv(nn.Sequential): + def __init__( + self, + in_planes: int, + out_planes: int, + kernel_size: int, + stride: int, + padding: int, + norm_layer: Callable[..., nn.Module], + ): + super().__init__( + Conv3dNormActivation( + in_planes, + out_planes, + kernel_size=(1, kernel_size, kernel_size), + stride=(1, stride, stride), + padding=(0, padding, padding), + bias=False, + norm_layer=norm_layer, + ), + Conv3dNormActivation( + out_planes, + out_planes, + kernel_size=(kernel_size, 1, 1), + stride=(stride, 1, 1), + padding=(padding, 0, 0), + bias=False, + norm_layer=norm_layer, + ), + ) + + +class SepInceptionBlock3D(nn.Module): + def __init__( + self, + in_planes: int, + b0_out: int, + b1_mid: int, + b1_out: int, + b2_mid: int, + b2_out: int, + b3_out: int, + norm_layer: Callable[..., nn.Module], + ): + super().__init__() + + self.branch0 = Conv3dNormActivation(in_planes, b0_out, kernel_size=1, stride=1, norm_layer=norm_layer) + self.branch1 = nn.Sequential( + Conv3dNormActivation(in_planes, b1_mid, kernel_size=1, stride=1, norm_layer=norm_layer), + TemporalSeparableConv(b1_mid, b1_out, kernel_size=3, stride=1, padding=1, norm_layer=norm_layer), + ) + self.branch2 = nn.Sequential( + Conv3dNormActivation(in_planes, b2_mid, kernel_size=1, stride=1, norm_layer=norm_layer), + TemporalSeparableConv(b2_mid, b2_out, kernel_size=3, stride=1, padding=1, norm_layer=norm_layer), + ) + self.branch3 = nn.Sequential( + nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1), + Conv3dNormActivation(in_planes, b3_out, kernel_size=1, stride=1, norm_layer=norm_layer), + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + + return out + + +class S3D(nn.Module): + """S3D main class. + + Args: + num_class (int): number of classes for the classification task. + dropout (float): dropout probability. + norm_layer (Optional[Callable]): Module specifying the normalization layer to use. 
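Illustrative aside, not part of the s3d.py file being added: a quick check of the temporal-separable factorization implemented by TemporalSeparableConv above. The import path simply mirrors the module added by this diff; TemporalSeparableConv is an internal helper rather than a public export, and the feature-map shape is an arbitrary assumption.

    import torch
    from torch import nn
    from torchvision.models.video.s3d import TemporalSeparableConv

    # A dense 3x3x3 conv mapping 64 -> 192 channels would hold 64*192*27 = 331,776 weights.
    # The separable version uses a 1x3x3 conv (64*192*9 = 110,592) followed by a 3x1x1 conv
    # (192*192*3 = 110,592), i.e. 221,184 weights in total, with BatchNorm/ReLU inserted
    # between the two by Conv3dNormActivation.
    sep = TemporalSeparableConv(64, 192, kernel_size=3, stride=1, padding=1, norm_layer=nn.BatchNorm3d)
    out = sep(torch.randn(1, 64, 16, 56, 56))
    print(out.shape)  # torch.Size([1, 192, 16, 56, 56]); size is preserved at stride 1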
+ + Inputs: + x (Tensor): batch of videos with dimensions (batch, channel, time, height, width) + """ + + def __init__( + self, + num_classes: int = 400, + dropout: float = 0.2, + norm_layer: Optional[Callable[..., torch.nn.Module]] = None, + ) -> None: + super().__init__() + _log_api_usage_once(self) + + if norm_layer is None: + norm_layer = partial(nn.BatchNorm3d, eps=0.001, momentum=0.001) + + self.features = nn.Sequential( + TemporalSeparableConv(3, 64, 7, 2, 3, norm_layer), + nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)), + Conv3dNormActivation( + 64, + 64, + kernel_size=1, + stride=1, + norm_layer=norm_layer, + ), + TemporalSeparableConv(64, 192, 3, 1, 1, norm_layer), + nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)), + SepInceptionBlock3D(192, 64, 96, 128, 16, 32, 32, norm_layer), + SepInceptionBlock3D(256, 128, 128, 192, 32, 96, 64, norm_layer), + nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)), + SepInceptionBlock3D(480, 192, 96, 208, 16, 48, 64, norm_layer), + SepInceptionBlock3D(512, 160, 112, 224, 24, 64, 64, norm_layer), + SepInceptionBlock3D(512, 128, 128, 256, 24, 64, 64, norm_layer), + SepInceptionBlock3D(512, 112, 144, 288, 32, 64, 64, norm_layer), + SepInceptionBlock3D(528, 256, 160, 320, 32, 128, 128, norm_layer), + nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0)), + SepInceptionBlock3D(832, 256, 160, 320, 32, 128, 128, norm_layer), + SepInceptionBlock3D(832, 384, 192, 384, 48, 128, 128, norm_layer), + ) + self.avgpool = nn.AvgPool3d(kernel_size=(2, 7, 7), stride=1) + self.classifier = nn.Sequential( + nn.Dropout(p=dropout), + nn.Conv3d(1024, num_classes, kernel_size=1, stride=1, bias=True), + ) + + def forward(self, x): + x = self.features(x) + x = self.avgpool(x) + x = self.classifier(x) + x = torch.mean(x, dim=(2, 3, 4)) + return x + + +class S3D_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/s3d-d76dad2f.pth", + transforms=partial( + VideoClassification, + crop_size=(224, 224), + resize_size=(256, 256), + ), + meta={ + "min_size": (224, 224), + "min_temporal_size": 14, + "categories": _KINETICS400_CATEGORIES, + "recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification#s3d", + "_docs": ( + "The weights aim to approximate the accuracy of the paper. The accuracies are estimated on clip-level " + "with parameters `frame_rate=15`, `clips_per_video=1`, and `clip_len=128`." + ), + "num_params": 8320048, + "_metrics": { + "Kinetics-400": { + "acc@1": 68.368, + "acc@5": 88.050, + } + }, + "_ops": 17.979, + "_file_size": 31.972, + }, + ) + DEFAULT = KINETICS400_V1 + + +@register_model() +@handle_legacy_interface(weights=("pretrained", S3D_Weights.KINETICS400_V1)) +def s3d(*, weights: Optional[S3D_Weights] = None, progress: bool = True, **kwargs: Any) -> S3D: + """Construct Separable 3D CNN model. + + Reference: `Rethinking Spatiotemporal Feature Learning `__. + + .. betastatus:: video module + + Args: + weights (:class:`~torchvision.models.video.S3D_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.S3D_Weights` + below for more details, and possible values. By default, no + pre-trained weights are used. + progress (bool): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.S3D`` base class. + Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.video.S3D_Weights + :members: + """ + weights = S3D_Weights.verify(weights) + + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = S3D(**kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model diff --git a/lib/python3.10/site-packages/torchvision/models/video/swin_transformer.py b/lib/python3.10/site-packages/torchvision/models/video/swin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..a8d87ffbe5af6caa1de0a3760fa5c506fdf8e231 --- /dev/null +++ b/lib/python3.10/site-packages/torchvision/models/video/swin_transformer.py @@ -0,0 +1,743 @@ +# Modified from 2d Swin Transformers in torchvision: +# https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py + +from functools import partial +from typing import Any, Callable, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from ...transforms._presets import VideoClassification + +from ...utils import _log_api_usage_once + +from .._api import register_model, Weights, WeightsEnum + +from .._meta import _KINETICS400_CATEGORIES +from .._utils import _ovewrite_named_param, handle_legacy_interface +from ..swin_transformer import PatchMerging, SwinTransformerBlock + +__all__ = [ + "SwinTransformer3d", + "Swin3D_T_Weights", + "Swin3D_S_Weights", + "Swin3D_B_Weights", + "swin3d_t", + "swin3d_s", + "swin3d_b", +] + + +def _get_window_and_shift_size( + shift_size: List[int], size_dhw: List[int], window_size: List[int] +) -> Tuple[List[int], List[int]]: + for i in range(3): + if size_dhw[i] <= window_size[i]: + # In this case, window_size will adapt to the input size, and no need to shift + window_size[i] = size_dhw[i] + shift_size[i] = 0 + + return window_size, shift_size + + +torch.fx.wrap("_get_window_and_shift_size") + + +def _get_relative_position_bias( + relative_position_bias_table: torch.Tensor, relative_position_index: torch.Tensor, window_size: List[int] +) -> Tensor: + window_vol = window_size[0] * window_size[1] * window_size[2] + # In 3d case we flatten the relative_position_bias + relative_position_bias = relative_position_bias_table[ + relative_position_index[:window_vol, :window_vol].flatten() # type: ignore[index] + ] + relative_position_bias = relative_position_bias.view(window_vol, window_vol, -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) + return relative_position_bias + + +torch.fx.wrap("_get_relative_position_bias") + + +def _compute_pad_size_3d(size_dhw: Tuple[int, int, int], patch_size: Tuple[int, int, int]) -> Tuple[int, int, int]: + pad_size = [(patch_size[i] - size_dhw[i] % patch_size[i]) % patch_size[i] for i in range(3)] + return pad_size[0], pad_size[1], pad_size[2] + + +torch.fx.wrap("_compute_pad_size_3d") + + +def _compute_attention_mask_3d( + x: Tensor, + size_dhw: Tuple[int, int, int], + window_size: Tuple[int, int, int], + shift_size: Tuple[int, int, int], +) -> Tensor: + # generate attention mask + attn_mask = x.new_zeros(*size_dhw) + num_windows = (size_dhw[0] // window_size[0]) * (size_dhw[1] // window_size[1]) * (size_dhw[2] // window_size[2]) + slices = [ + ( + (0, -window_size[i]), + (-window_size[i], -shift_size[i]), + (-shift_size[i], None), + ) + for i in range(3) + ] + count = 0 + for d in slices[0]: + for h in slices[1]: + for w in slices[2]: + attn_mask[d[0] : d[1], h[0] : 
h[1], w[0] : w[1]] = count + count += 1 + + # Partition window on attn_mask + attn_mask = attn_mask.view( + size_dhw[0] // window_size[0], + window_size[0], + size_dhw[1] // window_size[1], + window_size[1], + size_dhw[2] // window_size[2], + window_size[2], + ) + attn_mask = attn_mask.permute(0, 2, 4, 1, 3, 5).reshape( + num_windows, window_size[0] * window_size[1] * window_size[2] + ) + attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + return attn_mask + + +torch.fx.wrap("_compute_attention_mask_3d") + + +def shifted_window_attention_3d( + input: Tensor, + qkv_weight: Tensor, + proj_weight: Tensor, + relative_position_bias: Tensor, + window_size: List[int], + num_heads: int, + shift_size: List[int], + attention_dropout: float = 0.0, + dropout: float = 0.0, + qkv_bias: Optional[Tensor] = None, + proj_bias: Optional[Tensor] = None, + training: bool = True, +) -> Tensor: + """ + Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + input (Tensor[B, T, H, W, C]): The input tensor, 5-dimensions. + qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value. + proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection. + relative_position_bias (Tensor): The learned relative position bias added to attention. + window_size (List[int]): 3-dimensions window size, T, H, W . + num_heads (int): Number of attention heads. + shift_size (List[int]): Shift size for shifted window attention (T, H, W). + attention_dropout (float): Dropout ratio of attention weight. Default: 0.0. + dropout (float): Dropout ratio of output. Default: 0.0. + qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None. + proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None. + training (bool, optional): Training flag used by the dropout parameters. Default: True. + Returns: + Tensor[B, T, H, W, C]: The output tensor after shifted window attention. 
+ """ + b, t, h, w, c = input.shape + # pad feature maps to multiples of window size + pad_size = _compute_pad_size_3d((t, h, w), (window_size[0], window_size[1], window_size[2])) + x = F.pad(input, (0, 0, 0, pad_size[2], 0, pad_size[1], 0, pad_size[0])) + _, tp, hp, wp, _ = x.shape + padded_size = (tp, hp, wp) + + # cyclic shift + if sum(shift_size) > 0: + x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3)) + + # partition windows + num_windows = ( + (padded_size[0] // window_size[0]) * (padded_size[1] // window_size[1]) * (padded_size[2] // window_size[2]) + ) + x = x.view( + b, + padded_size[0] // window_size[0], + window_size[0], + padded_size[1] // window_size[1], + window_size[1], + padded_size[2] // window_size[2], + window_size[2], + c, + ) + x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).reshape( + b * num_windows, window_size[0] * window_size[1] * window_size[2], c + ) # B*nW, Wd*Wh*Ww, C + + # multi-head attention + qkv = F.linear(x, qkv_weight, qkv_bias) + qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, c // num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * (c // num_heads) ** -0.5 + attn = q.matmul(k.transpose(-2, -1)) + # add relative position bias + attn = attn + relative_position_bias + + if sum(shift_size) > 0: + # generate attention mask to handle shifted windows with varying size + attn_mask = _compute_attention_mask_3d( + x, + (padded_size[0], padded_size[1], padded_size[2]), + (window_size[0], window_size[1], window_size[2]), + (shift_size[0], shift_size[1], shift_size[2]), + ) + attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1)) + attn = attn + attn_mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, num_heads, x.size(1), x.size(1)) + + attn = F.softmax(attn, dim=-1) + attn = F.dropout(attn, p=attention_dropout, training=training) + + x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), c) + x = F.linear(x, proj_weight, proj_bias) + x = F.dropout(x, p=dropout, training=training) + + # reverse windows + x = x.view( + b, + padded_size[0] // window_size[0], + padded_size[1] // window_size[1], + padded_size[2] // window_size[2], + window_size[0], + window_size[1], + window_size[2], + c, + ) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).reshape(b, tp, hp, wp, c) + + # reverse cyclic shift + if sum(shift_size) > 0: + x = torch.roll(x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3)) + + # unpad features + x = x[:, :t, :h, :w, :].contiguous() + return x + + +torch.fx.wrap("shifted_window_attention_3d") + + +class ShiftedWindowAttention3d(nn.Module): + """ + See :func:`shifted_window_attention_3d`. 
+ """ + + def __init__( + self, + dim: int, + window_size: List[int], + shift_size: List[int], + num_heads: int, + qkv_bias: bool = True, + proj_bias: bool = True, + attention_dropout: float = 0.0, + dropout: float = 0.0, + ) -> None: + super().__init__() + if len(window_size) != 3 or len(shift_size) != 3: + raise ValueError("window_size and shift_size must be of length 2") + + self.window_size = window_size # Wd, Wh, Ww + self.shift_size = shift_size + self.num_heads = num_heads + self.attention_dropout = attention_dropout + self.dropout = dropout + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + + self.define_relative_position_bias_table() + self.define_relative_position_index() + + def define_relative_position_bias_table(self) -> None: + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros( + (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1), + self.num_heads, + ) + ) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH + nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) + + def define_relative_position_index(self) -> None: + # get pair-wise relative position index for each token inside the window + coords_dhw = [torch.arange(self.window_size[i]) for i in range(3)] + coords = torch.stack( + torch.meshgrid(coords_dhw[0], coords_dhw[1], coords_dhw[2], indexing="ij") + ) # 3, Wd, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + + relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 + # We don't flatten the relative_position_index here in 3d case. + relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + def get_relative_position_bias(self, window_size: List[int]) -> torch.Tensor: + return _get_relative_position_bias(self.relative_position_bias_table, self.relative_position_index, window_size) # type: ignore + + def forward(self, x: Tensor) -> Tensor: + _, t, h, w, _ = x.shape + size_dhw = [t, h, w] + window_size, shift_size = self.window_size.copy(), self.shift_size.copy() + # Handle case where window_size is larger than the input tensor + window_size, shift_size = _get_window_and_shift_size(shift_size, size_dhw, window_size) + + relative_position_bias = self.get_relative_position_bias(window_size) + + return shifted_window_attention_3d( + x, + self.qkv.weight, + self.proj.weight, + relative_position_bias, + window_size, + self.num_heads, + shift_size=shift_size, + attention_dropout=self.attention_dropout, + dropout=self.dropout, + qkv_bias=self.qkv.bias, + proj_bias=self.proj.bias, + training=self.training, + ) + + +# Modified from: +# https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/mmaction/models/backbones/swin_transformer.py +class PatchEmbed3d(nn.Module): + """Video to Patch Embedding. + + Args: + patch_size (List[int]): Patch token size. + in_channels (int): Number of input channels. 
Default: 3 + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__( + self, + patch_size: List[int], + in_channels: int = 3, + embed_dim: int = 96, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super().__init__() + _log_api_usage_once(self) + self.tuple_patch_size = (patch_size[0], patch_size[1], patch_size[2]) + + self.proj = nn.Conv3d( + in_channels, + embed_dim, + kernel_size=self.tuple_patch_size, + stride=self.tuple_patch_size, + ) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + """Forward function.""" + # padding + _, _, t, h, w = x.size() + pad_size = _compute_pad_size_3d((t, h, w), self.tuple_patch_size) + x = F.pad(x, (0, pad_size[2], 0, pad_size[1], 0, pad_size[0])) + x = self.proj(x) # B C T Wh Ww + x = x.permute(0, 2, 3, 4, 1) # B T Wh Ww C + if self.norm is not None: + x = self.norm(x) + return x + + +class SwinTransformer3d(nn.Module): + """ + Implements 3D Swin Transformer from the `"Video Swin Transformer" `_ paper. + Args: + patch_size (List[int]): Patch size. + embed_dim (int): Patch embedding dimension. + depths (List(int)): Depth of each Swin Transformer layer. + num_heads (List(int)): Number of attention heads in different layers. + window_size (List[int]): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + dropout (float): Dropout rate. Default: 0.0. + attention_dropout (float): Attention dropout rate. Default: 0.0. + stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1. + num_classes (int): Number of classes for classification head. Default: 400. + norm_layer (nn.Module, optional): Normalization layer. Default: None. + block (nn.Module, optional): SwinTransformer Block. Default: None. + downsample_layer (nn.Module): Downsample layer (patch merging). Default: PatchMerging. + patch_embed (nn.Module, optional): Patch Embedding layer. Default: None. 
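Illustrative aside, not part of the file being added: constructing SwinTransformer3d directly with the tiny configuration that swin3d_t (defined further down) forwards to it. Stage i runs at embed_dim * 2**i channels, so the four stages here use 96/192/384/768 dims. The clip shape is an assumption; num_classes defaults to 400.

    import torch
    from torchvision.models.video import SwinTransformer3d

    model = SwinTransformer3d(
        patch_size=[2, 4, 4],
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[8, 7, 7],
        stochastic_depth_prob=0.1,
    ).eval()

    clip = torch.randn(1, 3, 16, 224, 224)  # (B, C, T, H, W)
    with torch.no_grad():
        print(model(clip).shape)  # torch.Size([1, 400])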
+ """ + + def __init__( + self, + patch_size: List[int], + embed_dim: int, + depths: List[int], + num_heads: List[int], + window_size: List[int], + mlp_ratio: float = 4.0, + dropout: float = 0.0, + attention_dropout: float = 0.0, + stochastic_depth_prob: float = 0.1, + num_classes: int = 400, + norm_layer: Optional[Callable[..., nn.Module]] = None, + block: Optional[Callable[..., nn.Module]] = None, + downsample_layer: Callable[..., nn.Module] = PatchMerging, + patch_embed: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super().__init__() + _log_api_usage_once(self) + self.num_classes = num_classes + + if block is None: + block = partial(SwinTransformerBlock, attn_layer=ShiftedWindowAttention3d) + + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-5) + + if patch_embed is None: + patch_embed = PatchEmbed3d + + # split image into non-overlapping patches + self.patch_embed = patch_embed(patch_size=patch_size, embed_dim=embed_dim, norm_layer=norm_layer) + self.pos_drop = nn.Dropout(p=dropout) + + layers: List[nn.Module] = [] + total_stage_blocks = sum(depths) + stage_block_id = 0 + # build SwinTransformer blocks + for i_stage in range(len(depths)): + stage: List[nn.Module] = [] + dim = embed_dim * 2**i_stage + for i_layer in range(depths[i_stage]): + # adjust stochastic depth probability based on the depth of the stage block + sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1) + stage.append( + block( + dim, + num_heads[i_stage], + window_size=window_size, + shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size], + mlp_ratio=mlp_ratio, + dropout=dropout, + attention_dropout=attention_dropout, + stochastic_depth_prob=sd_prob, + norm_layer=norm_layer, + attn_layer=ShiftedWindowAttention3d, + ) + ) + stage_block_id += 1 + layers.append(nn.Sequential(*stage)) + # add patch merging layer + if i_stage < (len(depths) - 1): + layers.append(downsample_layer(dim, norm_layer)) + self.features = nn.Sequential(*layers) + + self.num_features = embed_dim * 2 ** (len(depths) - 1) + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool3d(1) + self.head = nn.Linear(self.num_features, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, x: Tensor) -> Tensor: + # x: B C T H W + x = self.patch_embed(x) # B _T _H _W C + x = self.pos_drop(x) + x = self.features(x) # B _T _H _W C + x = self.norm(x) + x = x.permute(0, 4, 1, 2, 3) # B, C, _T, _H, _W + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.head(x) + return x + + +def _swin_transformer3d( + patch_size: List[int], + embed_dim: int, + depths: List[int], + num_heads: List[int], + window_size: List[int], + stochastic_depth_prob: float, + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> SwinTransformer3d: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = SwinTransformer3d( + patch_size=patch_size, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + window_size=window_size, + stochastic_depth_prob=stochastic_depth_prob, + **kwargs, + ) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +_COMMON_META = { + "categories": _KINETICS400_CATEGORIES, + "min_size": (1, 1), + "min_temporal_size": 1, +} + + +class Swin3D_T_Weights(WeightsEnum): + 
KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/swin3d_t-7615ae03.pth", + transforms=partial( + VideoClassification, + crop_size=(224, 224), + resize_size=(256,), + mean=(0.4850, 0.4560, 0.4060), + std=(0.2290, 0.2240, 0.2250), + ), + meta={ + **_COMMON_META, + "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400", + "_docs": ( + "The weights were ported from the paper. The accuracies are estimated on video-level " + "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`" + ), + "num_params": 28158070, + "_metrics": { + "Kinetics-400": { + "acc@1": 77.715, + "acc@5": 93.519, + } + }, + "_ops": 43.882, + "_file_size": 121.543, + }, + ) + DEFAULT = KINETICS400_V1 + + +class Swin3D_S_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/swin3d_s-da41c237.pth", + transforms=partial( + VideoClassification, + crop_size=(224, 224), + resize_size=(256,), + mean=(0.4850, 0.4560, 0.4060), + std=(0.2290, 0.2240, 0.2250), + ), + meta={ + **_COMMON_META, + "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400", + "_docs": ( + "The weights were ported from the paper. The accuracies are estimated on video-level " + "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`" + ), + "num_params": 49816678, + "_metrics": { + "Kinetics-400": { + "acc@1": 79.521, + "acc@5": 94.158, + } + }, + "_ops": 82.841, + "_file_size": 218.288, + }, + ) + DEFAULT = KINETICS400_V1 + + +class Swin3D_B_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/swin3d_b_1k-24f7c7c6.pth", + transforms=partial( + VideoClassification, + crop_size=(224, 224), + resize_size=(256,), + mean=(0.4850, 0.4560, 0.4060), + std=(0.2290, 0.2240, 0.2250), + ), + meta={ + **_COMMON_META, + "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400", + "_docs": ( + "The weights were ported from the paper. The accuracies are estimated on video-level " + "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`" + ), + "num_params": 88048984, + "_metrics": { + "Kinetics-400": { + "acc@1": 79.427, + "acc@5": 94.386, + } + }, + "_ops": 140.667, + "_file_size": 364.134, + }, + ) + KINETICS400_IMAGENET22K_V1 = Weights( + url="https://download.pytorch.org/models/swin3d_b_22k-7c6ae6fa.pth", + transforms=partial( + VideoClassification, + crop_size=(224, 224), + resize_size=(256,), + mean=(0.4850, 0.4560, 0.4060), + std=(0.2290, 0.2240, 0.2250), + ), + meta={ + **_COMMON_META, + "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400", + "_docs": ( + "The weights were ported from the paper. The accuracies are estimated on video-level " + "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`" + ), + "num_params": 88048984, + "_metrics": { + "Kinetics-400": { + "acc@1": 81.643, + "acc@5": 95.574, + } + }, + "_ops": 140.667, + "_file_size": 364.134, + }, + ) + DEFAULT = KINETICS400_V1 + + +@register_model() +@handle_legacy_interface(weights=("pretrained", Swin3D_T_Weights.KINETICS400_V1)) +def swin3d_t(*, weights: Optional[Swin3D_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer3d: + """ + Constructs a swin_tiny architecture from + `Video Swin Transformer `_. + + Args: + weights (:class:`~torchvision.models.video.Swin3D_T_Weights`, optional): The + pretrained weights to use. 
See + :class:`~torchvision.models.video.Swin3D_T_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.video.Swin3D_T_Weights + :members: + """ + weights = Swin3D_T_Weights.verify(weights) + + return _swin_transformer3d( + patch_size=[2, 4, 4], + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=[8, 7, 7], + stochastic_depth_prob=0.1, + weights=weights, + progress=progress, + **kwargs, + ) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", Swin3D_S_Weights.KINETICS400_V1)) +def swin3d_s(*, weights: Optional[Swin3D_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer3d: + """ + Constructs a swin_small architecture from + `Video Swin Transformer `_. + + Args: + weights (:class:`~torchvision.models.video.Swin3D_S_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.Swin3D_S_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.video.Swin3D_S_Weights + :members: + """ + weights = Swin3D_S_Weights.verify(weights) + + return _swin_transformer3d( + patch_size=[2, 4, 4], + embed_dim=96, + depths=[2, 2, 18, 2], + num_heads=[3, 6, 12, 24], + window_size=[8, 7, 7], + stochastic_depth_prob=0.1, + weights=weights, + progress=progress, + **kwargs, + ) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", Swin3D_B_Weights.KINETICS400_V1)) +def swin3d_b(*, weights: Optional[Swin3D_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer3d: + """ + Constructs a swin_base architecture from + `Video Swin Transformer `_. + + Args: + weights (:class:`~torchvision.models.video.Swin3D_B_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.Swin3D_B_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.video.Swin3D_B_Weights + :members: + """ + weights = Swin3D_B_Weights.verify(weights) + + return _swin_transformer3d( + patch_size=[2, 4, 4], + embed_dim=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=[8, 7, 7], + stochastic_depth_prob=0.1, + weights=weights, + progress=progress, + **kwargs, + ) diff --git a/lib/python3.10/site-packages/torchvision/ops/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/ops/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a98a0f90ce278a391e47ea01db4dad78555c0c73 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/ops/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/ops/__pycache__/_box_convert.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/ops/__pycache__/_box_convert.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a25252862037305b97ce0b5f4bdacaba68789aad Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/ops/__pycache__/_box_convert.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/ops/__pycache__/_register_onnx_ops.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/ops/__pycache__/_register_onnx_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8799be4c96b6dda63714c512be97bd615f495b6f Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/ops/__pycache__/_register_onnx_ops.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/ops/__pycache__/_utils.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/ops/__pycache__/_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f15e9d85f721bdfc6c0427b929130a3d2e389fc Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/ops/__pycache__/_utils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/ops/__pycache__/boxes.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/ops/__pycache__/boxes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9b2dbd9bc226256ec57e52964f038fe9dc51f82 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/ops/__pycache__/boxes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/ops/__pycache__/ciou_loss.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/ops/__pycache__/ciou_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9781b3c68da85476d2b40fc6155ff680e1f92891 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/ops/__pycache__/ciou_loss.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/ops/__pycache__/deform_conv.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/ops/__pycache__/deform_conv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..336e6c8046162580a355a219c052b25c07d4bd27 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/ops/__pycache__/deform_conv.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/ops/__pycache__/diou_loss.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/ops/__pycache__/diou_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..403660e8bf65ec21ce069a51892ebe34270ff011 Binary files /dev/null 
and b/lib/python3.10/site-packages/torchvision/ops/__pycache__/diou_loss.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/ops/__pycache__/drop_block.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/ops/__pycache__/drop_block.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df3660334afe8dbb34bd784e340f95d24f642a32 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/ops/__pycache__/drop_block.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/ops/__pycache__/feature_pyramid_network.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/ops/__pycache__/feature_pyramid_network.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7fb9b3aca4b868df5fa5cd8fecf6e319c9f912d Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/ops/__pycache__/feature_pyramid_network.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/ops/__pycache__/focal_loss.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/ops/__pycache__/focal_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e51a0ebd35c2caa1ec20db45dc6310d0beb6f16a Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/ops/__pycache__/focal_loss.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/ops/__pycache__/giou_loss.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/ops/__pycache__/giou_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8b1b53c8435905f3e26bdafddf7c0c2e8db2af5 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/ops/__pycache__/giou_loss.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/ops/__pycache__/misc.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/ops/__pycache__/misc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec78cbab3f5e28bfab39cc9e9b5ade4e69ff00ee Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/ops/__pycache__/misc.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/ops/__pycache__/poolers.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/ops/__pycache__/poolers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af0416ccfec09860b63e2f4d18afa29a5fb02d8d Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/ops/__pycache__/poolers.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/ops/__pycache__/ps_roi_align.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/ops/__pycache__/ps_roi_align.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b10383bafde0f44887a4b9a61a329cf0dedebf99 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/ops/__pycache__/ps_roi_align.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/ops/__pycache__/ps_roi_pool.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/ops/__pycache__/ps_roi_pool.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74812381bcf2725d7d3f1de6e287ea81744d7120 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/ops/__pycache__/ps_roi_pool.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/ops/__pycache__/roi_align.cpython-310.pyc 
b/lib/python3.10/site-packages/torchvision/ops/__pycache__/roi_align.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a18411fbbdcb50d132732916632dacde4d6085db Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/ops/__pycache__/roi_align.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/ops/__pycache__/roi_pool.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/ops/__pycache__/roi_pool.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0339c18c2521b12468a65f719ef72097190c9a5d Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/ops/__pycache__/roi_pool.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/ops/__pycache__/stochastic_depth.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/ops/__pycache__/stochastic_depth.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b7d59685b229d122cb04ae6467c838118d3b07a Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/ops/__pycache__/stochastic_depth.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/transforms/v2/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/transforms/v2/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1590dcfaeb3180d851d780d2da7e7ee3406d4ca Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/transforms/v2/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc61a0a17d906464a61fd76fc0fedcb1fbed0b9a Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_augment.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_augment.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..829c4b4db61f89422f45c001aeabafbae3d152f3 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_augment.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_color.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_color.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26b2b6a4518bdc63a4e86c426ae1eea70ca58aca Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_color.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_deprecated.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_deprecated.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da19bc59089e32f36be0370b57e6c28ab155623b Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_deprecated.cpython-310.pyc differ diff --git 
a/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_geometry.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_geometry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f1431c85d80186789cf347f88a61801f8148c63 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_geometry.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_meta.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_meta.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..147518229f8fb6771a4ba272bd3ecbf6742e7595 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_meta.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_misc.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_misc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ae5629416881dca65b626cbdcb33b760647a27e Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_misc.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_temporal.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_temporal.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3d30d3cf44d1999a47a0c9adc0d799d7a0db7b4 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_temporal.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_type_conversion.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_type_conversion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80f70f99c1c566fa0e2282f1ed87cbbf34641da8 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_type_conversion.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_utils.cpython-310.pyc b/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..145116b6d7b42f3288dfa789fa8e7f6f9ac55b07 Binary files /dev/null and b/lib/python3.10/site-packages/torchvision/transforms/v2/functional/__pycache__/_utils.cpython-310.pyc differ
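Closing illustrative aside, placed after the binary .pyc hunks and not part of the patch: the pretrained swin3d_t constructor added by swin_transformer.py above is used the same way as the other video classifiers in this diff. The 32-frame 224x224 clip is an assumption matching the preset's crop size.

    import torch
    from torchvision.models.video import swin3d_t, Swin3D_T_Weights

    weights = Swin3D_T_Weights.KINETICS400_V1
    model = swin3d_t(weights=weights).eval()

    clip = torch.randn(1, 3, 32, 224, 224)  # (B, C, T, H, W)
    with torch.no_grad():
        logits = model(clip)
    print(weights.meta["categories"][logits.argmax(dim=1).item()])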